Network System 0.1.1
High-performance modular networking library for scalable client-server applications
Loading...
Searching...
No Matches
dtls_socket.cpp
Go to the documentation of this file.
1// BSD 3-Clause License
2// Copyright (c) 2024, 🍀☀🌕🌥 🌊
3// See the LICENSE file in the project root for full license information.
4
7
8#include <cstring>
9
11{
12
13namespace
14{
15 // Custom error category for SSL errors
16 class ssl_error_category : public std::error_category
17 {
18 public:
19 const char* name() const noexcept override { return "ssl"; }
20
21 std::string message(int ev) const override
22 {
23 char buf[256];
24 ERR_error_string_n(static_cast<unsigned long>(ev), buf, sizeof(buf));
25 return buf;
26 }
27 };
28
29 const ssl_error_category& get_ssl_category()
30 {
31 static ssl_error_category instance;
32 return instance;
33 }
34
35 std::error_code make_ssl_error_code(int ssl_error)
36 {
37 return std::error_code(ssl_error, get_ssl_category());
38 }
39
40} // anonymous namespace
41
42dtls_socket::dtls_socket(asio::ip::udp::socket socket, SSL_CTX* ssl_ctx)
43 : socket_(std::move(socket))
44 , ssl_ctx_(ssl_ctx)
45 , ssl_(nullptr)
46 , rbio_(nullptr)
47 , wbio_(nullptr)
48{
49 // Create SSL object
50 ssl_ = SSL_new(ssl_ctx_);
51 if (!ssl_)
52 {
53 throw std::runtime_error("Failed to create SSL object");
54 }
55
56 // Create memory BIOs for non-blocking I/O
57 rbio_ = BIO_new(BIO_s_mem());
58 wbio_ = BIO_new(BIO_s_mem());
59 if (!rbio_ || !wbio_)
60 {
61 if (rbio_) BIO_free(rbio_);
62 if (wbio_) BIO_free(wbio_);
63 SSL_free(ssl_);
64 throw std::runtime_error("Failed to create BIO objects");
65 }
66
67 // Set BIOs to non-blocking mode
68 BIO_set_nbio(rbio_, 1);
69 BIO_set_nbio(wbio_, 1);
70
71 // Connect BIOs to SSL object (SSL takes ownership)
72 SSL_set_bio(ssl_, rbio_, wbio_);
73
74 // Enable DTLS cookie exchange for servers (DoS protection)
75 // This is optional and can be configured later
76}
77
79{
    // Releases the OpenSSL session. SSL_free() also frees rbio_ and wbio_,
    // which the constructor handed to the SSL object via SSL_set_bio().
81
82 if (ssl_)
83 {
    // NOTE(review): SSL_shutdown() writes its close_notify into the write
    // memory BIO, but nothing below sends that output to the peer before
    // SSL_free() discards it. On a session whose handshake never finished,
    // SSL_shutdown() may also leave an entry on this thread's error queue.
84 SSL_shutdown(ssl_);
85 SSL_free(ssl_); // Also frees the BIOs
86 }
87}
88
90 handshake_type type,
91 std::function<void(std::error_code)> handler) -> void
92{
93 handshake_type_ = type;
94 handshake_callback_ = std::move(handler);
95 handshake_in_progress_.store(true);
96
97 {
98 std::lock_guard<std::mutex> lock(ssl_mutex_);
99
100 if (type == handshake_type::client)
101 {
102 SSL_set_connect_state(ssl_);
103 }
104 else
105 {
106 SSL_set_accept_state(ssl_);
107 }
108 }
109
110 // Start receiving for handshake packets
111 start_receive();
112
113 // Initiate handshake
114 continue_handshake();
115}
116
118{
    // Advances the handshake by one SSL_do_handshake() step. This is called
    // from async_handshake() and from process_received_data() for each
    // inbound datagram while handshake_in_progress_ is set. On completion or
    // fatal failure, it moves the stored handshake callback out (under
    // callback_mutex_) and invokes it outside the lock.
    //
    // NOTE(review): no DTLSv1_get_timeout()/DTLSv1_handle_timeout() call is
    // visible in this listing, so a lost handshake flight is presumably never
    // retransmitted by this side. Confirm where DTLS timers are driven.
119 if (!handshake_in_progress_.load())
120 {
121 return;
122 }
123
124 int result;
125 {
    // NOTE(review): the thread's OpenSSL error queue is not cleared before
    // this call, and SSL_get_error() below runs in a separate critical
    // section. Another thread may operate on ssl_ in between, which makes the
    // classification unreliable.
126 std::lock_guard<std::mutex> lock(ssl_mutex_);
127 result = SSL_do_handshake(ssl_);
128 }
129
130 if (result == 1)
131 {
132 // Handshake complete
133 handshake_in_progress_.store(false);
134 handshake_complete_.store(true);
135
136 // Flush any remaining output
137 flush_bio_output();
138
139 std::function<void(std::error_code)> callback;
140 {
141 std::lock_guard<std::mutex> lock(callback_mutex_);
142 callback = std::move(handshake_callback_);
143 }
144
145 if (callback)
146 {
147 callback(std::error_code{});
148 }
149 return;
150 }
151
152 int ssl_error;
153 {
154 std::lock_guard<std::mutex> lock(ssl_mutex_);
155 ssl_error = SSL_get_error(ssl_, result);
156 }
157
158 // Flush any output generated by the handshake
159 flush_bio_output();
160
161 if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE)
162 {
163 // Need more data - continue receiving
164 return;
165 }
166
167 // Handshake failed
168 handshake_in_progress_.store(false);
169
170 std::function<void(std::error_code)> callback;
171 {
172 std::lock_guard<std::mutex> lock(callback_mutex_);
173 callback = std::move(handshake_callback_);
174 }
175
176 if (callback)
177 {
178 callback(make_ssl_error());
179 }
180}
181
183 std::function<void(const std::vector<uint8_t>&,
184 const asio::ip::udp::endpoint&)> callback) -> void
185{
    // Installs the handler that process_received_data() calls with each
    // decrypted payload and its sender. Guarded by callback_mutex_ because
    // the receive path copies the handler under the same mutex.
186 std::lock_guard<std::mutex> lock(callback_mutex_);
187 receive_callback_ = std::move(callback);
188}
189
191 std::function<void(std::error_code)> callback) -> void
192{
    // Installs the handler that receives socket errors (from do_receive()) and
    // SSL read failures (from process_received_data()). Guarded by
    // callback_mutex_.
193 std::lock_guard<std::mutex> lock(callback_mutex_);
194 error_callback_ = std::move(callback);
195}
196
198{
    // Starts the receive loop at most once. Only the caller whose
    // compare-exchange flips is_receiving_ from false to true calls
    // do_receive(); concurrent or repeated calls are no-ops.
199 bool expected = false;
200 if (is_receiving_.compare_exchange_strong(expected, true))
201 {
202 do_receive();
203 }
204}
205
207{
    // Clears the receive flag. The outstanding async_receive_from() is not
    // cancelled; its completion handler sees the flag and returns without
    // re-arming.
    // NOTE(review): if start_receive() runs before that pending handler fires,
    // the handler sees the flag set again and re-arms too. Two receive loops
    // then share read_buffer_ and sender_endpoint_.
208 is_receiving_.store(false);
209}
210
212{
    // Posts one async_receive_from() into read_buffer_/sender_endpoint_. The
    // completion handler processes the datagram and re-arms the receive while
    // is_receiving_ remains set. Capturing `self` (shared_from_this()) keeps
    // this object alive until the handler has run.
213 if (!is_receiving_.load())
214 {
215 return;
216 }
217
218 auto self = shared_from_this();
219 socket_.async_receive_from(
220 asio::buffer(read_buffer_),
221 sender_endpoint_,
222 [this, self](std::error_code ec, std::size_t length)
223 {
224 if (!is_receiving_.load())
225 {
226 return;
227 }
228
229 if (ec)
230 {
    // NOTE(review): the loop stops here without clearing is_receiving_,
    // so a later start_receive() cannot restart it.
231 std::function<void(std::error_code)> callback;
232 {
233 std::lock_guard<std::mutex> lock(callback_mutex_);
234 callback = error_callback_;
235 }
236 if (callback)
237 {
238 callback(ec);
239 }
240 return;
241 }
242
243 if (length > 0)
244 {
    // Copy the datagram out of the shared read_buffer_ before the next
    // receive is posted below.
245 std::vector<uint8_t> data(read_buffer_.begin(),
246 read_buffer_.begin() + length);
247 process_received_data(data, sender_endpoint_);
248 }
249
250 // Continue receiving
251 if (is_receiving_.load())
252 {
253 do_receive();
254 }
255 });
256}
257
258auto dtls_socket::process_received_data(const std::vector<uint8_t>& data,
259 const asio::ip::udp::endpoint& sender) -> void
260{
261 // Write received data to the read BIO
262 {
263 std::lock_guard<std::mutex> lock(ssl_mutex_);
264 int written = BIO_write(rbio_, data.data(), static_cast<int>(data.size()));
265 if (written <= 0)
266 {
267 // BIO write failed
268 return;
269 }
270 }
271
272 // If handshake is in progress, continue it
273 if (handshake_in_progress_.load())
274 {
275 continue_handshake();
276 return;
277 }
278
279 // Handshake complete, try to read decrypted data
280 if (handshake_complete_.load())
281 {
282 std::vector<uint8_t> decrypted;
283 decrypted.resize(65536);
284
285 int read_len;
286 {
287 std::lock_guard<std::mutex> lock(ssl_mutex_);
288 read_len = SSL_read(ssl_, decrypted.data(), static_cast<int>(decrypted.size()));
289 }
290
291 if (read_len > 0)
292 {
293 decrypted.resize(static_cast<std::size_t>(read_len));
294
295 std::function<void(const std::vector<uint8_t>&, const asio::ip::udp::endpoint&)> callback;
296 {
297 std::lock_guard<std::mutex> lock(callback_mutex_);
298 callback = receive_callback_;
299 }
300 if (callback)
301 {
302 callback(decrypted, sender);
303 }
304 }
305 else
306 {
307 int ssl_error;
308 {
309 std::lock_guard<std::mutex> lock(ssl_mutex_);
310 ssl_error = SSL_get_error(ssl_, read_len);
311 }
312
313 if (ssl_error != SSL_ERROR_WANT_READ && ssl_error != SSL_ERROR_WANT_WRITE)
314 {
315 // Real error
316 std::function<void(std::error_code)> callback;
317 {
318 std::lock_guard<std::mutex> lock(callback_mutex_);
319 callback = error_callback_;
320 }
321 if (callback)
322 {
323 callback(make_ssl_error());
324 }
325 }
326 }
327 }
328}
329
331{
    // Drains everything currently buffered in wbio_ and sends it to
    // peer_endpoint_ as a single UDP datagram. Sending is fire-and-forget:
    // the completion handler only keeps the buffer alive.
332 // Check if there's data to send in the write BIO
333 std::vector<uint8_t> output;
334 {
335 std::lock_guard<std::mutex> lock(ssl_mutex_);
    // NOTE(review): BIO_ctrl_pending() returns size_t; it is narrowed to int here.
336 int pending = BIO_ctrl_pending(wbio_);
337 if (pending <= 0)
338 {
339 return;
340 }
341
342 output.resize(static_cast<std::size_t>(pending));
343 int read_len = BIO_read(wbio_, output.data(), pending);
344 if (read_len <= 0)
345 {
346 return;
347 }
348 output.resize(static_cast<std::size_t>(read_len));
349 }
350
351 // Send the encrypted data
352 asio::ip::udp::endpoint target;
353 {
354 std::lock_guard<std::mutex> lock(endpoint_mutex_);
355 target = peer_endpoint_;
356 }
357
    // NOTE(review): the bytes have already been drained from wbio_ above, so
    // when no peer endpoint is set (port 0) they are silently discarded.
    // Everything pending is also sent as one datagram, which may contain
    // several DTLS records.
358 if (target.port() != 0)
359 {
360 auto buffer = std::make_shared<std::vector<uint8_t>>(std::move(output));
361 socket_.async_send_to(
362 asio::buffer(*buffer),
363 target,
364 [buffer](std::error_code /*ec*/, std::size_t /*bytes*/)
365 {
366 // Fire and forget for handshake messages
367 });
368 }
369}
370
372 std::vector<uint8_t>&& data,
373 std::function<void(std::error_code, std::size_t)> handler) -> void
374{
    // Connected-mode send. It snapshots peer_endpoint_ under endpoint_mutex_
    // and delegates encryption and transmission to async_send_to().
375 asio::ip::udp::endpoint target;
376 {
377 std::lock_guard<std::mutex> lock(endpoint_mutex_);
378 target = peer_endpoint_;
379 }
380
381 async_send_to(std::move(data), target, std::move(handler));
382}
383
385 std::vector<uint8_t>&& data,
386 const asio::ip::udp::endpoint& endpoint,
387 std::function<void(std::error_code, std::size_t)> handler) -> void
388{
    // Encrypts `data` with SSL_write(), drains the resulting ciphertext from
    // wbio_, and sends it to `endpoint`. On success the handler receives the
    // plaintext size. Error paths before the send invoke the handler
    // synchronously; the success path invokes it from asio's completion
    // handler.
389 if (!handshake_complete_.load())
390 {
391 if (handler)
392 {
393 handler(std::make_error_code(std::errc::not_connected), 0);
394 }
395 return;
396 }
397
398 // Encrypt the data
    // NOTE(review): the error queue is not cleared before SSL_write(), and a
    // failure is not classified with SSL_get_error(). Payloads larger than a
    // single DTLS record are presumably rejected by SSL_write(); no
    // fragmentation is done here.
399 int written;
400 {
401 std::lock_guard<std::mutex> lock(ssl_mutex_);
402 written = SSL_write(ssl_, data.data(), static_cast<int>(data.size()));
403 }
404
405 if (written <= 0)
406 {
407 if (handler)
408 {
409 handler(make_ssl_error(), 0);
410 }
411 return;
412 }
413
414 // Get encrypted data from write BIO
    // NOTE(review): this drain runs in a separate critical section from the
    // SSL_write() above. Output from another thread can therefore end up in
    // this datagram, or this record can be taken by a concurrent drain.
415 std::vector<uint8_t> encrypted;
416 {
417 std::lock_guard<std::mutex> lock(ssl_mutex_);
418 int pending = BIO_ctrl_pending(wbio_);
419 if (pending <= 0)
420 {
421 if (handler)
422 {
423 handler(std::make_error_code(std::errc::io_error), 0);
424 }
425 return;
426 }
427
428 encrypted.resize(static_cast<std::size_t>(pending));
429 int read_len = BIO_read(wbio_, encrypted.data(), pending);
430 if (read_len <= 0)
431 {
432 if (handler)
433 {
434 handler(std::make_error_code(std::errc::io_error), 0);
435 }
436 return;
437 }
438 encrypted.resize(static_cast<std::size_t>(read_len));
439 }
440
441 // Send encrypted data
    // NOTE(review): `self` is created here but not captured by the completion
    // lambda below, so it does not extend this object's lifetime. Either
    // capture it or remove it.
442 auto self = shared_from_this();
443 auto buffer = std::make_shared<std::vector<uint8_t>>(std::move(encrypted));
444 auto original_size = data.size();
445
446 socket_.async_send_to(
447 asio::buffer(*buffer),
448 endpoint,
449 [handler = std::move(handler), buffer, original_size](
450 std::error_code ec, std::size_t /*bytes_sent*/)
451 {
452 if (handler)
453 {
454 // Report original (plaintext) size to the handler
455 handler(ec, ec ? 0 : original_size);
456 }
457 });
458}
459
460auto dtls_socket::set_peer_endpoint(const asio::ip::udp::endpoint& endpoint) -> void
461{
462 std::lock_guard<std::mutex> lock(endpoint_mutex_);
463 peer_endpoint_ = endpoint;
464}
465
466auto dtls_socket::peer_endpoint() const -> asio::ip::udp::endpoint
467{
468 std::lock_guard<std::mutex> lock(const_cast<std::mutex&>(endpoint_mutex_));
469 return peer_endpoint_;
470}
471
472auto dtls_socket::make_ssl_error() const -> std::error_code
473{
474 unsigned long err = ERR_get_error();
475 if (err == 0)
476 {
477 return std::make_error_code(std::errc::io_error);
478 }
479 return make_ssl_error_code(static_cast<int>(err));
480}
481
482} // namespace kcenon::network::internal
auto async_send(std::vector< uint8_t > &&data, std::function< void(std::error_code, std::size_t)> handler) -> void
Initiates an asynchronous encrypted send.
auto async_handshake(handshake_type type, std::function< void(std::error_code)> handler) -> void
Performs asynchronous DTLS handshake.
asio::ip::udp::endpoint peer_endpoint_
auto make_ssl_error() const -> std::error_code
Creates an OpenSSL error code from the current error state.
handshake_type
Handshake type enumeration.
Definition dtls_socket.h:55
auto set_receive_callback(std::function< void(const std::vector< uint8_t > &, const asio::ip::udp::endpoint &)> callback) -> void
Sets a callback to receive decrypted inbound datagrams.
auto set_error_callback(std::function< void(std::error_code)> callback) -> void
Sets a callback to handle socket errors.
auto stop_receive() -> void
Stops the receive loop.
auto set_peer_endpoint(const asio::ip::udp::endpoint &endpoint) -> void
Sets the peer endpoint for connected mode.
auto flush_bio_output() -> void
Flushes pending DTLS output to the network.
auto start_receive() -> void
Begins the continuous asynchronous receive loop.
auto process_received_data(const std::vector< uint8_t > &data, const asio::ip::udp::endpoint &sender) -> void
Processes received encrypted data through DTLS.
dtls_socket(asio::ip::udp::socket socket, SSL_CTX *ssl_ctx)
Constructs a dtls_socket with an existing UDP socket.
auto peer_endpoint() const -> asio::ip::udp::endpoint
Returns the peer endpoint.
auto continue_handshake() -> void
Continues the handshake process.
auto do_receive() -> void
Internal function to handle the receive logic.
auto async_send_to(std::vector< uint8_t > &&data, const asio::ip::udp::endpoint &endpoint, std::function< void(std::error_code, std::size_t)> handler) -> void
Initiates an asynchronous encrypted send to a specific endpoint.
~dtls_socket()
Destructor. Cleans up OpenSSL resources.
struct ssl_ctx_st SSL_CTX
Definition crypto.h:20
OpenSSL utilities and version definitions.