16 class ssl_error_category :
public std::error_category
19 const char* name() const noexcept
override {
return "ssl"; }
21 std::string
message(
int ev)
const override
24 ERR_error_string_n(
static_cast<unsigned long>(ev), buf,
sizeof(buf));
29 const ssl_error_category& get_ssl_category()
31 static ssl_error_category instance;
35 std::error_code make_ssl_error_code(
int ssl_error)
37 return std::error_code(ssl_error, get_ssl_category());
43 : socket_(std::move(socket))
53 throw std::runtime_error(
"Failed to create SSL object");
57 rbio_ = BIO_new(BIO_s_mem());
58 wbio_ = BIO_new(BIO_s_mem());
64 throw std::runtime_error(
"Failed to create BIO objects");
68 BIO_set_nbio(
rbio_, 1);
69 BIO_set_nbio(
wbio_, 1);
91 std::function<
void(std::error_code)> handler) ->
void
93 handshake_type_ = type;
94 handshake_callback_ = std::move(handler);
95 handshake_in_progress_.store(
true);
98 std::lock_guard<std::mutex> lock(ssl_mutex_);
100 if (type == handshake_type::client)
102 SSL_set_connect_state(ssl_);
106 SSL_set_accept_state(ssl_);
114 continue_handshake();
119 if (!handshake_in_progress_.load())
126 std::lock_guard<std::mutex> lock(ssl_mutex_);
127 result = SSL_do_handshake(ssl_);
133 handshake_in_progress_.store(
false);
134 handshake_complete_.store(
true);
139 std::function<void(std::error_code)> callback;
141 std::lock_guard<std::mutex> lock(callback_mutex_);
142 callback = std::move(handshake_callback_);
147 callback(std::error_code{});
154 std::lock_guard<std::mutex> lock(ssl_mutex_);
155 ssl_error = SSL_get_error(ssl_, result);
161 if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE)
168 handshake_in_progress_.store(
false);
170 std::function<void(std::error_code)> callback;
172 std::lock_guard<std::mutex> lock(callback_mutex_);
173 callback = std::move(handshake_callback_);
178 callback(make_ssl_error());
183 std::function<
void(
const std::vector<uint8_t>&,
184 const asio::ip::udp::endpoint&)> callback) ->
void
186 std::lock_guard<std::mutex> lock(callback_mutex_);
187 receive_callback_ = std::move(callback);
191 std::function<
void(std::error_code)> callback) ->
void
193 std::lock_guard<std::mutex> lock(callback_mutex_);
194 error_callback_ = std::move(callback);
199 bool expected =
false;
200 if (is_receiving_.compare_exchange_strong(expected,
true))
208 is_receiving_.store(
false);
213 if (!is_receiving_.load())
218 auto self = shared_from_this();
219 socket_.async_receive_from(
220 asio::buffer(read_buffer_),
222 [
this, self](std::error_code ec, std::size_t length)
224 if (!is_receiving_.load())
231 std::function<void(std::error_code)> callback;
233 std::lock_guard<std::mutex> lock(callback_mutex_);
234 callback = error_callback_;
245 std::vector<uint8_t> data(read_buffer_.begin(),
246 read_buffer_.begin() + length);
247 process_received_data(data, sender_endpoint_);
251 if (is_receiving_.load())
259 const asio::ip::udp::endpoint& sender) ->
void
263 std::lock_guard<std::mutex> lock(ssl_mutex_);
264 int written = BIO_write(rbio_, data.data(),
static_cast<int>(data.size()));
273 if (handshake_in_progress_.load())
275 continue_handshake();
280 if (handshake_complete_.load())
282 std::vector<uint8_t> decrypted;
283 decrypted.resize(65536);
287 std::lock_guard<std::mutex> lock(ssl_mutex_);
288 read_len = SSL_read(ssl_, decrypted.data(),
static_cast<int>(decrypted.size()));
293 decrypted.resize(
static_cast<std::size_t
>(read_len));
295 std::function<void(
const std::vector<uint8_t>&,
const asio::ip::udp::endpoint&)> callback;
297 std::lock_guard<std::mutex> lock(callback_mutex_);
298 callback = receive_callback_;
302 callback(decrypted, sender);
309 std::lock_guard<std::mutex> lock(ssl_mutex_);
310 ssl_error = SSL_get_error(ssl_, read_len);
313 if (ssl_error != SSL_ERROR_WANT_READ && ssl_error != SSL_ERROR_WANT_WRITE)
316 std::function<void(std::error_code)> callback;
318 std::lock_guard<std::mutex> lock(callback_mutex_);
319 callback = error_callback_;
323 callback(make_ssl_error());
333 std::vector<uint8_t> output;
335 std::lock_guard<std::mutex> lock(ssl_mutex_);
336 int pending = BIO_ctrl_pending(wbio_);
342 output.resize(
static_cast<std::size_t
>(pending));
343 int read_len = BIO_read(wbio_, output.data(), pending);
348 output.resize(
static_cast<std::size_t
>(read_len));
352 asio::ip::udp::endpoint target;
354 std::lock_guard<std::mutex> lock(endpoint_mutex_);
355 target = peer_endpoint_;
358 if (target.port() != 0)
360 auto buffer = std::make_shared<std::vector<uint8_t>>(std::move(output));
361 socket_.async_send_to(
362 asio::buffer(*buffer),
364 [buffer](std::error_code , std::size_t )
372 std::vector<uint8_t>&& data,
373 std::function<
void(std::error_code, std::size_t)> handler) ->
void
375 asio::ip::udp::endpoint target;
377 std::lock_guard<std::mutex> lock(endpoint_mutex_);
378 target = peer_endpoint_;
381 async_send_to(std::move(data), target, std::move(handler));
385 std::vector<uint8_t>&& data,
386 const asio::ip::udp::endpoint& endpoint,
387 std::function<
void(std::error_code, std::size_t)> handler) ->
void
389 if (!handshake_complete_.load())
393 handler(std::make_error_code(std::errc::not_connected), 0);
401 std::lock_guard<std::mutex> lock(ssl_mutex_);
402 written = SSL_write(ssl_, data.data(),
static_cast<int>(data.size()));
409 handler(make_ssl_error(), 0);
415 std::vector<uint8_t> encrypted;
417 std::lock_guard<std::mutex> lock(ssl_mutex_);
418 int pending = BIO_ctrl_pending(wbio_);
423 handler(std::make_error_code(std::errc::io_error), 0);
428 encrypted.resize(
static_cast<std::size_t
>(pending));
429 int read_len = BIO_read(wbio_, encrypted.data(), pending);
434 handler(std::make_error_code(std::errc::io_error), 0);
438 encrypted.resize(
static_cast<std::size_t
>(read_len));
442 auto self = shared_from_this();
443 auto buffer = std::make_shared<std::vector<uint8_t>>(std::move(encrypted));
444 auto original_size = data.size();
446 socket_.async_send_to(
447 asio::buffer(*buffer),
449 [handler = std::move(handler), buffer, original_size](
450 std::error_code ec, std::size_t )
455 handler(ec, ec ? 0 : original_size);
462 std::lock_guard<std::mutex> lock(endpoint_mutex_);
463 peer_endpoint_ = endpoint;
468 std::lock_guard<std::mutex> lock(
const_cast<std::mutex&
>(
endpoint_mutex_));
474 unsigned long err = ERR_get_error();
477 return std::make_error_code(std::errc::io_error);
479 return make_ssl_error_code(
static_cast<int>(err));
auto async_send(std::vector< uint8_t > &&data, std::function< void(std::error_code, std::size_t)> handler) -> void
Initiates an asynchronous encrypted send to the currently configured peer endpoint.
auto async_handshake(handshake_type type, std::function< void(std::error_code)> handler) -> void
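Starts an asynchronous DTLS handshake in client or server mode; the handler is invoked with the result when the handshake completes or fails.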
asio::ip::udp::endpoint peer_endpoint_
auto make_ssl_error() const -> std::error_code
Creates an OpenSSL error code from the current error state.
handshake_type
Handshake type enumeration.
auto set_receive_callback(std::function< void(const std::vector< uint8_t > &, const asio::ip::udp::endpoint &)> callback) -> void
Sets a callback to receive decrypted inbound datagrams.
auto set_error_callback(std::function< void(std::error_code)> callback) -> void
Sets a callback invoked on socket receive errors and on fatal DTLS read errors.
auto stop_receive() -> void
Stops the receive loop.
auto set_peer_endpoint(const asio::ip::udp::endpoint &endpoint) -> void
Sets the peer endpoint for connected mode.
auto flush_bio_output() -> void
Flushes pending DTLS output to the network.
auto start_receive() -> void
Begins the continuous asynchronous receive loop.
auto process_received_data(const std::vector< uint8_t > &data, const asio::ip::udp::endpoint &sender) -> void
Processes received encrypted data through DTLS.
std::mutex endpoint_mutex_
dtls_socket(asio::ip::udp::socket socket, SSL_CTX *ssl_ctx)
Constructs a dtls_socket with an existing UDP socket.
auto peer_endpoint() const -> asio::ip::udp::endpoint
Returns the peer endpoint.
auto continue_handshake() -> void
Advances the handshake via SSL_do_handshake and invokes the stored handshake callback on success or on a fatal error.
auto do_receive() -> void
Posts one asynchronous receive; its completion handler passes the received datagram to process_received_data.
auto async_send_to(std::vector< uint8_t > &&data, const asio::ip::udp::endpoint &endpoint, std::function< void(std::error_code, std::size_t)> handler) -> void
Initiates an asynchronous encrypted send to a specific endpoint.
~dtls_socket()
Destructor. Cleans up OpenSSL resources.
struct ssl_ctx_st SSL_CTX
OpenSSL SSL context type (typedef of struct ssl_ctx_st), passed to the dtls_socket constructor.