14using namespace protocols::quic;
21 : udp_socket_(std::move(socket))
23 , retransmit_timer_(udp_socket_.get_executor())
24 , idle_timer_(udp_socket_.get_executor())
45 : udp_socket_(std::move(other.udp_socket_))
46 , remote_endpoint_(std::move(other.remote_endpoint_))
47 , recv_buffer_(std::move(other.recv_buffer_))
49 , state_(other.state_.load())
50 , crypto_(std::move(other.crypto_))
51 , local_conn_id_(std::move(other.local_conn_id_))
52 , remote_conn_id_(std::move(other.remote_conn_id_))
53 , next_packet_number_(other.next_packet_number_)
54 , largest_received_pn_(other.largest_received_pn_)
55 , next_stream_id_(other.next_stream_id_)
56 , pending_crypto_data_(std::move(other.pending_crypto_data_))
57 , pending_stream_data_(std::move(other.pending_stream_data_))
58 , stream_data_cb_(std::move(other.stream_data_cb_))
59 , connected_cb_(std::move(other.connected_cb_))
60 , error_cb_(std::move(other.error_cb_))
61 , close_cb_(std::move(other.close_cb_))
62 , is_receiving_(other.is_receiving_.load())
63 , handshake_complete_(other.handshake_complete_.load())
64 , retransmit_timer_(std::move(other.retransmit_timer_))
65 , idle_timer_(std::move(other.idle_timer_))
75 udp_socket_ = std::move(other.udp_socket_);
76 remote_endpoint_ = std::move(other.remote_endpoint_);
77 recv_buffer_ = std::move(other.recv_buffer_);
79 state_.store(other.state_.load());
80 crypto_ = std::move(other.crypto_);
81 local_conn_id_ = std::move(other.local_conn_id_);
82 remote_conn_id_ = std::move(other.remote_conn_id_);
83 next_packet_number_ = other.next_packet_number_;
84 largest_received_pn_ = other.largest_received_pn_;
85 next_stream_id_ = other.next_stream_id_;
86 pending_crypto_data_ = std::move(other.pending_crypto_data_);
87 pending_stream_data_ = std::move(other.pending_stream_data_);
89 std::lock_guard<std::mutex> lock(callback_mutex_);
90 stream_data_cb_ = std::move(other.stream_data_cb_);
91 connected_cb_ = std::move(other.connected_cb_);
92 error_cb_ = std::move(other.error_cb_);
93 close_cb_ = std::move(other.close_cb_);
95 is_receiving_.store(other.is_receiving_.load());
96 handshake_complete_.store(other.handshake_complete_.load());
97 retransmit_timer_ = std::move(other.retransmit_timer_);
98 idle_timer_ = std::move(other.idle_timer_);
109 std::lock_guard<std::mutex> lock(callback_mutex_);
110 stream_data_cb_ = std::move(cb);
115 std::lock_guard<std::mutex> lock(callback_mutex_);
116 connected_cb_ = std::move(cb);
121 std::lock_guard<std::mutex> lock(callback_mutex_);
122 error_cb_ = std::move(cb);
127 std::lock_guard<std::mutex> lock(callback_mutex_);
128 close_cb_ = std::move(cb);
142 "connect() can only be called on client sockets",
150 "Connection already in progress or established",
154 remote_endpoint_ = endpoint;
157 auto init_result = crypto_.init_client(
158 server_name.empty() ? endpoint.address().to_string() : server_name);
159 if (init_result.is_err())
163 "Failed to initialize TLS client",
165 init_result.error().message);
170 remote_conn_id_ = generate_connection_id();
172 auto derive_result = crypto_.derive_initial_secrets(remote_conn_id_);
173 if (derive_result.is_err())
177 "Failed to derive initial secrets",
179 derive_result.error().message);
185 auto handshake_result = crypto_.start_handshake();
186 if (handshake_result.is_err())
191 "Failed to start TLS handshake",
193 handshake_result.error().message);
197 if (!handshake_result.value().empty())
199 queue_crypto_data(std::move(handshake_result.value()));
208 send_pending_packets();
220 "accept() can only be called on server sockets",
228 "Connection already in progress",
233 auto init_result = crypto_.init_server(cert_file, key_file);
234 if (init_result.is_err())
238 "Failed to initialize TLS server",
240 init_result.error().message);
253 auto current_state = state_.load();
270 std::vector<frame> frames;
271 frames.push_back(close_frame);
273 auto level = handshake_complete_.load()
274 ? encryption_level::application
275 : encryption_level::initial;
277 auto send_result = send_packet(level, std::move(frames));
278 if (send_result.is_err())
286 idle_timer_.expires_after(std::chrono::milliseconds(300));
287 idle_timer_.async_wait(
288 [self = shared_from_this()](
const std::error_code& ec)
293 self->stop_receive();
306 is_receiving_.store(
true);
312 is_receiving_.store(
false);
316 std::vector<uint8_t>&& data,
323 "Connection not established",
328 std::lock_guard<std::mutex> lock(state_mutex_);
329 pending_stream_data_[stream_id].push_back({std::move(data), fin});
332 send_pending_packets();
347 "Connection not established",
351 std::lock_guard<std::mutex> lock(state_mutex_);
358 uint64_t type_bits = 0;
368 uint64_t stream_id = (next_stream_id_ << 2) | type_bits;
372 pending_stream_data_[stream_id] = {};
374 return ok(std::move(stream_id));
379 std::lock_guard<std::mutex> lock(state_mutex_);
381 auto it = pending_stream_data_.find(stream_id);
382 if (it == pending_stream_data_.end())
391 it->second.push_back({{},
true});
393 send_pending_packets();
443 if (!is_receiving_.load())
448 auto self = shared_from_this();
449 udp_socket_.async_receive_from(
450 asio::buffer(recv_buffer_),
452 [
this, self](std::error_code ec, std::size_t bytes_transferred)
454 if (!is_receiving_.load())
461 if (ec != asio::error::operation_aborted)
463 std::lock_guard<std::mutex> lock(callback_mutex_);
472 if (bytes_transferred > 0)
474 handle_packet(std::span(recv_buffer_.data(), bytes_transferred));
478 if (is_receiving_.load())
489 if (header_result.is_err())
495 auto& [header, header_length] = header_result.value();
498 auto level = determine_encryption_level(header);
504 if (std::holds_alternative<long_header>(header))
506 const auto& lh = std::get<long_header>(header);
507 if (lh.type() == packet_type::initial)
510 remote_conn_id_ = lh.src_conn_id;
513 auto derive_result = crypto_.derive_initial_secrets(lh.dest_conn_id);
514 if (derive_result.is_err())
525 auto keys_result = crypto_.get_read_keys(level);
526 if (keys_result.is_err())
534 size_t pn_offset = header_length;
535 if (std::holds_alternative<long_header>(header))
542 size_t sample_offset = pn_offset + 4;
548 std::span<const uint8_t> sample(data.data() + sample_offset,
hp_sample_size);
551 std::vector<uint8_t> packet_copy(data.begin(), data.end());
555 std::span(packet_copy.data(), header_length + 4),
559 if (unprotect_header_result.is_err())
564 auto& [first_byte, pn_length] = unprotect_header_result.value();
567 uint64_t truncated_pn = 0;
568 for (
size_t i = 0; i < pn_length; ++i)
570 truncated_pn = (truncated_pn << 8) | packet_copy[pn_offset + i];
574 auto level_idx =
static_cast<size_t>(level);
576 truncated_pn, pn_length, largest_received_pn_[level_idx]);
579 if (full_pn > largest_received_pn_[level_idx])
581 largest_received_pn_[level_idx] = full_pn;
585 size_t payload_offset = pn_offset + pn_length;
588 std::span(packet_copy),
592 if (unprotect_result.is_err())
597 auto& [unprotected_header, payload] = unprotect_result.value();
601 if (frames_result.is_err())
607 for (
const auto& f : frames_result.value())
613 send_pending_packets();
618 std::visit([
this](
auto&& arg) {
619 using T = std::decay_t<
decltype(arg)>;
621 if constexpr (std::is_same_v<T, crypto_frame>)
623 process_crypto_frame(arg);
625 else if constexpr (std::is_same_v<T, stream_frame>)
627 process_stream_frame(arg);
629 else if constexpr (std::is_same_v<T, ack_frame>)
631 process_ack_frame(arg);
633 else if constexpr (std::is_same_v<T, connection_close_frame>)
635 process_connection_close_frame(arg);
637 else if constexpr (std::is_same_v<T, handshake_done_frame>)
639 process_handshake_done_frame();
641 else if constexpr (std::is_same_v<T, ping_frame>)
645 else if constexpr (std::is_same_v<T, padding_frame>)
655 auto level = crypto_.current_level();
658 auto response_result = crypto_.process_crypto_data(level, f.data);
659 if (response_result.is_err())
665 if (!response_result.value().empty())
667 queue_crypto_data(std::move(response_result.value()));
671 if (crypto_.is_handshake_complete() && !handshake_complete_.load())
673 handshake_complete_.store(
true);
681 std::lock_guard<std::mutex> lock(callback_mutex_);
690 std::vector<frame> frames;
693 (void)send_packet(encryption_level::application, std::move(frames));
697 std::lock_guard<std::mutex> lock(callback_mutex_);
708 std::lock_guard<std::mutex> lock(callback_mutex_);
711 stream_data_cb_(f.stream_id, f.data, f.fin);
726 std::lock_guard<std::mutex> lock(callback_mutex_);
729 close_cb_(f.error_code, f.reason_phrase);
733 idle_timer_.expires_after(std::chrono::milliseconds(300));
734 idle_timer_.async_wait(
735 [self = shared_from_this()](
const std::error_code& ec)
740 self->stop_receive();
750 handshake_complete_.store(
true);
753 std::lock_guard<std::mutex> lock(callback_mutex_);
763 auto current_state = state_.load();
773 std::vector<frame> frames;
777 std::lock_guard<std::mutex> lock(state_mutex_);
778 auto level_idx =
static_cast<size_t>(level);
779 while (!pending_crypto_data_[level_idx].empty())
781 auto& data = pending_crypto_data_[level_idx].front();
784 cf.
data = std::move(data);
785 frames.push_back(std::move(cf));
786 pending_crypto_data_[level_idx].pop_front();
793 std::lock_guard<std::mutex> lock(state_mutex_);
794 for (
auto& [stream_id, queue] : pending_stream_data_)
796 while (!queue.empty())
798 auto& [data, fin] = queue.front();
802 sf.
data = std::move(data);
804 frames.push_back(std::move(sf));
812 (void)send_packet(level, std::move(frames));
820 auto keys_result = crypto_.get_write_keys(level);
821 if (keys_result.is_err())
825 "Write keys not available",
830 std::vector<uint8_t> payload;
831 for (
const auto& f : frames)
834 payload.insert(payload.end(), frame_bytes.begin(), frame_bytes.end());
838 auto level_idx =
static_cast<size_t>(level);
839 uint64_t pn = next_packet_number_[level_idx]++;
842 std::vector<uint8_t> header;
843 if (level == encryption_level::initial)
846 remote_conn_id_, local_conn_id_, {}, pn);
848 else if (level == encryption_level::handshake)
851 remote_conn_id_, local_conn_id_, pn);
856 remote_conn_id_, pn, crypto_.key_phase());
861 keys_result.value(), header, payload, pn);
863 if (protect_result.is_err())
867 "Failed to protect packet",
869 protect_result.error().message);
873 auto& protected_packet = protect_result.value();
875 auto self = shared_from_this();
876 auto buffer = std::make_shared<std::vector<uint8_t>>(std::move(protected_packet));
878 udp_socket_.async_send_to(
879 asio::buffer(*buffer),
881 [self, buffer](std::error_code ec, std::size_t )
883 if (ec && ec != asio::error::operation_aborted)
885 std::lock_guard<std::mutex> lock(self->callback_mutex_);
898 auto level = crypto_.current_level();
899 auto level_idx =
static_cast<size_t>(level);
901 std::lock_guard<std::mutex> lock(state_mutex_);
902 pending_crypto_data_[level_idx].push_back(std::move(data));
908 if (std::holds_alternative<long_header>(header))
910 const auto& lh = std::get<long_header>(header);
913 case packet_type::initial:
914 return encryption_level::initial;
915 case packet_type::zero_rtt:
916 return encryption_level::zero_rtt;
917 case packet_type::handshake:
918 return encryption_level::handshake;
920 return encryption_level::initial;
926 return encryption_level::application;
932 std::random_device rd;
933 std::mt19937 gen(rd());
934 std::uniform_int_distribution<unsigned int> dis(0, 255);
936 std::array<uint8_t, 8> id_bytes;
937 for (
auto&
byte : id_bytes)
939 byte =
static_cast<uint8_t
>(dis(gen));
949 send_pending_packets();
954 state_.store(new_state);
A QUIC socket that wraps UDP and integrates QUIC packet protection.
auto local_connection_id() const -> const protocols::quic::connection_id &
Get the local connection ID.
auto determine_encryption_level(const protocols::quic::packet_header &header) const noexcept -> protocols::quic::encryption_level
Determine encryption level from packet header.
std::function< void( uint64_t stream_id, std::span< const uint8_t > data, bool fin)> stream_data_callback
Callback for receiving stream data.
auto on_retransmit_timeout() -> void
Retransmission timeout handler.
auto remote_connection_id() const -> const protocols::quic::connection_id &
Get the remote connection ID.
std::function< void(uint64_t error_code, const std::string &reason)> close_callback
Callback when connection is closed.
std::atomic< quic_connection_state > state_
Connection state.
auto process_frame(const protocols::quic::frame &f) -> void
Process a parsed frame.
auto connect(const asio::ip::udp::endpoint &endpoint, const std::string &server_name="") -> VoidResult
Connect to a remote server (client only).
auto process_connection_close_frame(const protocols::quic::connection_close_frame &f) -> void
Process CONNECTION_CLOSE frame.
auto transition_state(quic_connection_state new_state) -> void
Transition to a new connection state.
auto state() const noexcept -> quic_connection_state
Get the current connection state.
protocols::quic::connection_id remote_conn_id_
Remote connection ID.
auto send_stream_data(uint64_t stream_id, std::vector< uint8_t > &&data, bool fin=false) -> VoidResult
Send data on a stream.
auto process_handshake_done_frame() -> void
Process HANDSHAKE_DONE frame.
auto is_handshake_complete() const noexcept -> bool
Check if the TLS handshake is complete.
auto create_stream(bool unidirectional=false) -> Result< uint64_t >
Create a new stream.
auto queue_crypto_data(std::vector< uint8_t > &&data) -> void
Queue crypto data for sending.
std::atomic< bool > handshake_complete_
Is handshake complete.
auto process_ack_frame(const protocols::quic::ack_frame &f) -> void
Process ACK frame.
auto send_pending_packets() -> void
Send pending outgoing packets.
auto process_stream_frame(const protocols::quic::stream_frame &f) -> void
Process STREAM frame data.
auto close(uint64_t error_code=0, const std::string &reason="") -> VoidResult
Close the connection gracefully.
auto send_packet(protocols::quic::encryption_level level, std::vector< protocols::quic::frame > &&frames) -> VoidResult
Build and send a packet with frames.
auto role() const noexcept -> quic_role
Get the role (client or server).
quic_socket(asio::ip::udp::socket socket, quic_role role)
Constructs a QUIC socket.
auto remote_endpoint() const -> asio::ip::udp::endpoint
Get the remote endpoint.
auto handle_packet(std::span< const uint8_t > data) -> void
Handle received packet data.
auto do_receive() -> void
Internal receive loop implementation.
auto is_connected() const noexcept -> bool
Check if the connection is established.
std::function< void(std::error_code)> error_callback
Callback for error handling.
quic_role role_
Socket role (client or server).
uint64_t next_stream_id_
Next stream ID to allocate.
auto generate_connection_id() -> protocols::quic::connection_id
Generate a new connection ID.
~quic_socket()
Destructor.
std::function< void()> connected_callback
Callback when connection is established.
auto close_stream(uint64_t stream_id) -> VoidResult
Close a stream.
protocols::quic::connection_id local_conn_id_
Local connection ID.
auto stop_receive() -> void
Stop the receive loop.
asio::ip::udp::endpoint remote_endpoint_
Remote endpoint.
auto set_stream_data_callback(stream_data_callback cb) -> void
Set callback for stream data reception.
asio::steady_timer retransmit_timer_
Retransmission timer.
auto start_receive() -> void
Start the receive loop.
asio::steady_timer idle_timer_
Idle timeout timer.
auto accept(const std::string &cert_file, const std::string &key_file) -> VoidResult
Accept an incoming connection (server only).
auto set_close_callback(close_callback cb) -> void
Set callback for connection close.
quic_socket & operator=(const quic_socket &)=delete
auto set_connected_callback(connected_callback cb) -> void
Set callback for connection establishment.
auto process_crypto_frame(const protocols::quic::crypto_frame &f) -> void
Process CRYPTO frame data.
auto set_error_callback(error_callback cb) -> void
Set callback for errors.
QUIC Connection ID (RFC 9000 Section 5.1)
static auto build(const frame &f) -> std::vector< uint8_t >
Build any frame from variant.
static auto parse_all(std::span< const uint8_t > data) -> Result< std::vector< frame > >
Parse all frames from buffer.
static auto build_handshake(const connection_id &dest_cid, const connection_id &src_cid, uint64_t packet_number, uint32_t version=quic_version::version_1) -> std::vector< uint8_t >
Build a Handshake packet header.
static auto build_short(const connection_id &dest_cid, uint64_t packet_number, bool key_phase=false, bool spin_bit=false) -> std::vector< uint8_t >
Build a Short Header (1-RTT) packet.
static auto build_initial(const connection_id &dest_cid, const connection_id &src_cid, const std::vector< uint8_t > &token, uint64_t packet_number, uint32_t version=quic_version::version_1) -> std::vector< uint8_t >
Build an Initial packet header.
static auto decode(uint64_t truncated_pn, size_t pn_length, uint64_t largest_pn) noexcept -> uint64_t
Decode a packet number from received data.
static auto parse_header(std::span< const uint8_t > data) -> Result< std::pair< packet_header, size_t > >
Parse a packet header without removing header protection.
static auto unprotect_header(const quic_keys &keys, std::span< uint8_t > header, size_t pn_offset, std::span< const uint8_t > sample) -> Result< std::pair< uint8_t, size_t > >
Remove header protection.
static auto unprotect(const quic_keys &keys, std::span< const uint8_t > packet, size_t header_length, uint64_t packet_number) -> Result< std::pair< std::vector< uint8_t >, std::vector< uint8_t > > >
Unprotect (decrypt) a QUIC packet.
static auto protect(const quic_keys &keys, std::span< const uint8_t > header, std::span< const uint8_t > payload, uint64_t packet_number) -> Result< std::vector< uint8_t > >
Protect (encrypt) a QUIC packet.
constexpr int invalid_argument
constexpr int not_initialized
constexpr int internal_error
constexpr int connection_failed
quic_role
Role of the QUIC endpoint (client or server).
quic_connection_state
QUIC connection state machine states.
@ connected
Connection established.
@ closed
Connection closed.
@ closing
Closing connection.
@ draining
Draining period before close.
@ handshake
Handshake in progress.
@ handshake_start
Initiating handshake.
@ error
Connection error.
constexpr size_t hp_sample_size
Header protection sample size.
encryption_level
QUIC encryption levels (RFC 9001 Section 4)
std::variant< padding_frame, ping_frame, ack_frame, reset_stream_frame, stop_sending_frame, crypto_frame, new_token_frame, stream_frame, max_data_frame, max_stream_data_frame, max_streams_frame, data_blocked_frame, stream_data_blocked_frame, streams_blocked_frame, new_connection_id_frame, retire_connection_id_frame, path_challenge_frame, path_response_frame, connection_close_frame, handshake_done_frame > frame
Variant type holding any QUIC frame.
std::variant< long_header, short_header > packet_header
Variant type for packet headers.
VoidResult error_void(int code, const std::string &message, const std::string &source="network_system", const std::string &details="")
ACK frame (RFC 9000 Section 19.3)
CONNECTION_CLOSE frame (RFC 9000 Section 19.19)
uint64_t error_code
Error code indicating reason.
std::string reason_phrase
Human-readable reason.
bool is_application_error
True if application-level error.
CRYPTO frame (RFC 9000 Section 19.6)
uint64_t offset
Byte offset in crypto stream.
std::vector< uint8_t > data
Cryptographic handshake data.
HANDSHAKE_DONE frame (RFC 9000 Section 19.20)
STREAM frame (RFC 9000 Section 19.8)
std::vector< uint8_t > data
Stream data.
uint64_t stream_id
Stream identifier.
uint64_t offset
Byte offset in stream (0 if not present).
bool fin
True if this is the final data.