Network System 0.1.1
High-performance modular networking library for scalable client-server applications
Loading...
Searching...
No Matches
websocket_frame.cpp
Go to the documentation of this file.
1// BSD 3-Clause License
2// Copyright (c) 2024, 🍀☀🌕🌥 🌊
3// See the LICENSE file in the project root for full license information.
4
6
7#include <algorithm>
8#include <random>
9
11{
namespace
{
    // Bit layout of the first two header bytes (RFC 6455, section 5.2).
    constexpr uint8_t FIN_BIT = 0x80;
    constexpr uint8_t RSV1_BIT = 0x40;
    constexpr uint8_t RSV2_BIT = 0x20;
    constexpr uint8_t RSV3_BIT = 0x10;
    constexpr uint8_t OPCODE_MASK = 0x0F;
    constexpr uint8_t MASK_BIT = 0x80;
    constexpr uint8_t PAYLOAD_LEN_MASK = 0x7F;

    // Values of the 7-bit length field that announce an extended length
    // field of 16 or 64 bits, respectively.
    constexpr uint8_t PAYLOAD_LEN_16BIT = 126;
    constexpr uint8_t PAYLOAD_LEN_64BIT = 127;

    // The two mandatory header bytes, and the size of the optional key.
    constexpr size_t MIN_HEADER_SIZE = 2;
    constexpr size_t MASKING_KEY_SIZE = 4;

    /**
     * @brief Serializes a 16-bit value in network (big-endian) byte order.
     * @param value Value to serialize.
     * @return {high byte, low byte}.
     */
    auto to_network_u16(uint16_t value) -> std::array<uint8_t, 2>
    {
        const auto high = static_cast<uint8_t>(value >> 8);
        const auto low = static_cast<uint8_t>(value);
        return {high, low};
    }

    /**
     * @brief Serializes a 64-bit value in network (big-endian) byte order.
     * @param value Value to serialize.
     * @return Eight bytes, most significant first.
     */
    auto to_network_u64(uint64_t value) -> std::array<uint8_t, 8>
    {
        std::array<uint8_t, 8> bytes{};
        // Peel off the least significant byte each round, filling from the back.
        for (auto it = bytes.rbegin(); it != bytes.rend(); ++it)
        {
            *it = static_cast<uint8_t>(value & 0xFF);
            value >>= 8;
        }
        return bytes;
    }

    /**
     * @brief Reads a big-endian 16-bit value.
     * @param data Pointer to at least two readable bytes.
     */
    auto from_network_u16(const uint8_t* data) -> uint16_t
    {
        const auto high = static_cast<uint16_t>(data[0]);
        const auto low = static_cast<uint16_t>(data[1]);
        return static_cast<uint16_t>((high << 8) | low);
    }

    /**
     * @brief Reads a big-endian 64-bit value.
     * @param data Pointer to at least eight readable bytes.
     */
    auto from_network_u64(const uint8_t* data) -> uint64_t
    {
        uint64_t result = 0;
        for (size_t i = 0; i < 8; ++i)
        {
            result = (result << 8) | data[i];
        }
        return result;
    }

} // anonymous namespace
77
78 auto websocket_frame::calculate_header_size(uint64_t payload_len, bool mask) -> size_t
79 {
80 size_t header_size = MIN_HEADER_SIZE;
81
82 if (payload_len >= 65536)
83 {
84 header_size += 8; // 64-bit length field
85 }
86 else if (payload_len >= 126)
87 {
88 header_size += 2; // 16-bit length field
89 }
90
91 if (mask)
92 {
93 header_size += MASKING_KEY_SIZE;
94 }
95
96 return header_size;
97 }
98
99 auto websocket_frame::encode_frame(ws_opcode opcode, std::vector<uint8_t>&& payload,
100 bool fin, bool mask) -> std::vector<uint8_t>
101 {
102 const uint64_t payload_len = payload.size();
103 const size_t header_size = calculate_header_size(payload_len, mask);
104
105 std::vector<uint8_t> frame;
106 frame.reserve(header_size + payload_len);
107
108 // Byte 0: FIN, RSV1-3, Opcode
109 uint8_t byte0 = static_cast<uint8_t>(opcode);
110 if (fin)
111 {
112 byte0 |= FIN_BIT;
113 }
114 frame.push_back(byte0);
115
116 // Byte 1: Mask, Payload length
117 uint8_t byte1 = 0;
118 if (mask)
119 {
120 byte1 |= MASK_BIT;
121 }
122
123 if (payload_len < 126)
124 {
125 byte1 |= static_cast<uint8_t>(payload_len);
126 frame.push_back(byte1);
127 }
128 else if (payload_len < 65536)
129 {
130 byte1 |= PAYLOAD_LEN_16BIT;
131 frame.push_back(byte1);
132
133 // Extended payload length (16-bit)
134 auto len_bytes = to_network_u16(static_cast<uint16_t>(payload_len));
135 frame.insert(frame.end(), len_bytes.begin(), len_bytes.end());
136 }
137 else
138 {
139 byte1 |= PAYLOAD_LEN_64BIT;
140 frame.push_back(byte1);
141
142 // Extended payload length (64-bit)
143 auto len_bytes = to_network_u64(payload_len);
144 frame.insert(frame.end(), len_bytes.begin(), len_bytes.end());
145 }
146
147 // Masking key (if mask is enabled)
148 std::array<uint8_t, 4> masking_key{};
149 if (mask)
150 {
151 masking_key = generate_mask();
152 frame.insert(frame.end(), masking_key.begin(), masking_key.end());
153
154 // Apply masking to payload
155 apply_mask(payload, masking_key);
156 }
157
158 // Append payload
159 frame.insert(frame.end(), payload.begin(), payload.end());
160
161 return frame;
162 }
163
164 auto websocket_frame::decode_header(const std::vector<uint8_t>& data)
165 -> std::optional<ws_frame_header>
166 {
167 if (data.size() < MIN_HEADER_SIZE)
168 {
169 return std::nullopt;
170 }
171
172 ws_frame_header header{};
173
174 // Parse byte 0: FIN, RSV1-3, Opcode
175 const uint8_t byte0 = data[0];
176 header.fin = (byte0 & FIN_BIT) != 0;
177 header.rsv1 = (byte0 & RSV1_BIT) != 0;
178 header.rsv2 = (byte0 & RSV2_BIT) != 0;
179 header.rsv3 = (byte0 & RSV3_BIT) != 0;
180 header.opcode = static_cast<ws_opcode>(byte0 & OPCODE_MASK);
181
182 // Parse byte 1: Mask, Payload length
183 const uint8_t byte1 = data[1];
184 header.mask = (byte1 & MASK_BIT) != 0;
185 uint8_t payload_len = byte1 & PAYLOAD_LEN_MASK;
186
187 size_t offset = MIN_HEADER_SIZE;
188
189 // Extended payload length
190 if (payload_len == PAYLOAD_LEN_16BIT)
191 {
192 if (data.size() < offset + 2)
193 {
194 return std::nullopt;
195 }
196 header.payload_len = from_network_u16(&data[offset]);
197 offset += 2;
198 }
199 else if (payload_len == PAYLOAD_LEN_64BIT)
200 {
201 if (data.size() < offset + 8)
202 {
203 return std::nullopt;
204 }
205 header.payload_len = from_network_u64(&data[offset]);
206 offset += 8;
207 }
208 else
209 {
210 header.payload_len = payload_len;
211 }
212
213 // Masking key
214 if (header.mask)
215 {
216 if (data.size() < offset + MASKING_KEY_SIZE)
217 {
218 return std::nullopt;
219 }
220 std::copy_n(data.begin() + offset, MASKING_KEY_SIZE, header.masking_key.begin());
221 }
222
223 return header;
224 }
225
227 const std::vector<uint8_t>& data)
228 -> std::vector<uint8_t>
229 {
230 const size_t header_size = calculate_header_size(header.payload_len, header.mask);
231
232 if (data.size() < header_size + header.payload_len)
233 {
234 return {};
235 }
236
237 std::vector<uint8_t> payload(data.begin() + header_size,
238 data.begin() + header_size + header.payload_len);
239
240 // Unmask if necessary
241 if (header.mask)
242 {
243 apply_mask(payload, header.masking_key);
244 }
245
246 return payload;
247 }
248
249 auto websocket_frame::apply_mask(std::vector<uint8_t>& data,
250 const std::array<uint8_t, 4>& mask) -> void
251 {
252 for (size_t i = 0; i < data.size(); ++i)
253 {
254 data[i] ^= mask[i % 4];
255 }
256 }
257
258 auto websocket_frame::generate_mask() -> std::array<uint8_t, 4>
259 {
260 static thread_local std::random_device rd;
261 static thread_local std::mt19937 gen(rd());
262 static thread_local std::uniform_int_distribution<uint32_t> dis;
263
264 const uint32_t random_value = dis(gen);
265 return {static_cast<uint8_t>((random_value >> 24) & 0xFF),
266 static_cast<uint8_t>((random_value >> 16) & 0xFF),
267 static_cast<uint8_t>((random_value >> 8) & 0xFF),
268 static_cast<uint8_t>(random_value & 0xFF)};
269 }
270
271} // namespace kcenon::network::internal
static auto generate_mask() -> std::array< uint8_t, 4 >
Generates a random 4-byte masking key.
static auto encode_frame(ws_opcode opcode, std::vector< uint8_t > &&payload, bool fin=true, bool mask=false) -> std::vector< uint8_t >
Encodes data into a WebSocket frame.
static auto decode_payload(const ws_frame_header &header, const std::vector< uint8_t > &data) -> std::vector< uint8_t >
Decodes the payload from a WebSocket frame.
static auto calculate_header_size(uint64_t payload_len, bool mask) -> size_t
Calculates the size of the frame header.
static auto apply_mask(std::vector< uint8_t > &data, const std::array< uint8_t, 4 > &mask) -> void
Applies or removes XOR masking on data.
static auto decode_header(const std::vector< uint8_t > &data) -> std::optional< ws_frame_header >
Decodes a WebSocket frame header from raw data.
ws_opcode
WebSocket frame operation codes as defined in RFC 6455.
std::variant< padding_frame, ping_frame, ack_frame, reset_stream_frame, stop_sending_frame, crypto_frame, new_token_frame, stream_frame, max_data_frame, max_stream_data_frame, max_streams_frame, data_blocked_frame, stream_data_blocked_frame, streams_blocked_frame, new_connection_id_frame, retire_connection_id_frame, path_challenge_frame, path_response_frame, connection_close_frame, handshake_done_frame > frame
Variant type holding any QUIC frame.
Represents a decoded WebSocket frame header.