Network System 0.1.1
High-performance modular networking library for scalable client-server applications
Loading...
Searching...
No Matches
websocket_protocol.cpp
Go to the documentation of this file.
1// BSD 3-Clause License
2// Copyright (c) 2024, 🍀☀🌕🌥 🌊
3// See the LICENSE file in the project root for full license information.
4
6
7#include <algorithm>
8#include <span>
9
11{
12 // ========================================================================
13 // ws_message implementation
14 // ========================================================================
15
16 auto ws_message::as_text() const -> std::string
17 {
18 return std::string(data.begin(), data.end());
19 }
20
21 auto ws_message::as_binary() const -> const std::vector<uint8_t>&
22 {
23 return data;
24 }
25
26 // ========================================================================
27 // websocket_protocol implementation
28 // ========================================================================
29
31 : is_client_(is_client)
32 , fragmented_type_(ws_opcode::continuation)
33 {
34 }
35
36 auto websocket_protocol::process_data(std::span<const uint8_t> data) -> void
37 {
38 // Append incoming data to buffer (single copy from span view)
39 buffer_.insert(buffer_.end(), data.begin(), data.end());
40
41 // Process frames from buffer
42 process_frames();
43 }
44
46 {
47 while (true)
48 {
49 // Try to decode frame header
50 auto header_opt = websocket_frame::decode_header(buffer_);
51 if (!header_opt.has_value())
52 {
53 // Not enough data for header, wait for more
54 return;
55 }
56
57 const auto& header = header_opt.value();
58
59 // Calculate total frame size
60 const size_t header_size =
61 websocket_frame::calculate_header_size(header.payload_len, header.mask);
62 const size_t frame_size = header_size + header.payload_len;
63
64 // Check if we have the complete frame
65 if (buffer_.size() < frame_size)
66 {
67 // Incomplete frame, wait for more data
68 return;
69 }
70
71 // Decode payload
72 auto payload = websocket_frame::decode_payload(header, buffer_);
73
74 // Remove processed frame from buffer
75 buffer_.erase(buffer_.begin(), buffer_.begin() + frame_size);
76
77 // Dispatch frame based on opcode
78 if (header.opcode == ws_opcode::text || header.opcode == ws_opcode::binary ||
79 header.opcode == ws_opcode::continuation)
80 {
81 handle_data_frame(header, payload);
82 }
83 else
84 {
85 handle_control_frame(header, payload);
86 }
87 }
88 }
89
91 const std::vector<uint8_t>& payload)
92 -> void
93 {
94 // Handle fragmentation
95 if (header.opcode == ws_opcode::continuation)
96 {
97 // Continuation frame
98 if (fragmented_message_.empty())
99 {
100 // Protocol error: continuation without initial frame
101 return;
102 }
103
104 // Append to fragmented message
105 fragmented_message_.insert(fragmented_message_.end(), payload.begin(),
106 payload.end());
107
108 if (header.fin)
109 {
110 // Final fragment - complete message
111 ws_message msg;
112 msg.type = (fragmented_type_ == ws_opcode::text) ? ws_message_type::text
114 msg.data = std::move(fragmented_message_);
115
116 // Validate UTF-8 for text messages
117 if (msg.type == ws_message_type::text && !is_valid_utf8(msg.data))
118 {
119 // Invalid UTF-8, protocol error
120 fragmented_message_.clear();
121 return;
122 }
123
124 // Clear fragmentation state
125 fragmented_message_.clear();
126 fragmented_type_ = ws_opcode::continuation;
127
128 // Invoke callback
129 if (message_callback_)
130 {
131 message_callback_(msg);
132 }
133 }
134 }
135 else
136 {
137 // Initial frame (text or binary)
138 if (!header.fin)
139 {
140 // Start of fragmented message
141 fragmented_type_ = header.opcode;
142 fragmented_message_ = payload;
143 }
144 else
145 {
146 // Complete message in single frame
147 ws_message msg;
148 msg.type = (header.opcode == ws_opcode::text) ? ws_message_type::text
150 msg.data = payload;
151
152 // Validate UTF-8 for text messages
153 if (msg.type == ws_message_type::text && !is_valid_utf8(msg.data))
154 {
155 // Invalid UTF-8, protocol error
156 return;
157 }
158
159 // Invoke callback
160 if (message_callback_)
161 {
162 message_callback_(msg);
163 }
164 }
165 }
166 }
167
169 const std::vector<uint8_t>& payload)
170 -> void
171 {
172 // Control frames must not be fragmented
173 if (!header.fin)
174 {
175 // Protocol error
176 return;
177 }
178
179 // Control frames must have payload <= 125 bytes
180 if (header.payload_len > 125)
181 {
182 // Protocol error
183 return;
184 }
185
186 switch (header.opcode)
187 {
188 case ws_opcode::ping:
189 handle_ping(payload);
190 break;
191 case ws_opcode::pong:
192 handle_pong(payload);
193 break;
194 case ws_opcode::close:
195 handle_close(payload);
196 break;
197 default:
198 // Unknown control frame
199 break;
200 }
201 }
202
203 auto websocket_protocol::handle_ping(const std::vector<uint8_t>& payload) -> void
204 {
205 if (ping_callback_)
206 {
207 ping_callback_(payload);
208 }
209 }
210
211 auto websocket_protocol::handle_pong(const std::vector<uint8_t>& payload) -> void
212 {
213 if (pong_callback_)
214 {
215 pong_callback_(payload);
216 }
217 }
218
219 auto websocket_protocol::handle_close(const std::vector<uint8_t>& payload) -> void
220 {
222 std::string reason;
223
224 if (payload.size() >= 2)
225 {
226 // Extract close code (network byte order)
227 code = static_cast<ws_close_code>((payload[0] << 8) | payload[1]);
228
229 // Extract reason (if present)
230 if (payload.size() > 2)
231 {
232 reason = std::string(payload.begin() + 2, payload.end());
233
234 // Validate UTF-8 in reason
235 std::vector<uint8_t> reason_bytes(reason.begin(), reason.end());
236 if (!is_valid_utf8(reason_bytes))
237 {
238 // Invalid UTF-8 in close reason
240 reason.clear();
241 }
242 }
243 }
244
245 if (close_callback_)
246 {
247 close_callback_(code, reason);
248 }
249 }
250
252 -> std::vector<uint8_t>
253 {
254 // Convert string to bytes
255 std::vector<uint8_t> payload(text.begin(), text.end());
256
257 // Validate UTF-8
258 if (!is_valid_utf8(payload))
259 {
260 // Return empty frame on invalid UTF-8
261 return {};
262 }
263
264 // Encode as WebSocket frame
265 return websocket_frame::encode_frame(ws_opcode::text, std::move(payload), true,
266 is_client_);
267 }
268
269 auto websocket_protocol::create_binary_message(std::vector<uint8_t>&& data)
270 -> std::vector<uint8_t>
271 {
272 return websocket_frame::encode_frame(ws_opcode::binary, std::move(data), true,
273 is_client_);
274 }
275
276 auto websocket_protocol::create_ping(std::vector<uint8_t>&& payload)
277 -> std::vector<uint8_t>
278 {
279 // Ping payload must be <= 125 bytes
280 if (payload.size() > 125)
281 {
282 payload.resize(125);
283 }
284
285 return websocket_frame::encode_frame(ws_opcode::ping, std::move(payload), true,
286 is_client_);
287 }
288
289 auto websocket_protocol::create_pong(std::vector<uint8_t>&& payload)
290 -> std::vector<uint8_t>
291 {
292 // Pong payload must be <= 125 bytes
293 if (payload.size() > 125)
294 {
295 payload.resize(125);
296 }
297
298 return websocket_frame::encode_frame(ws_opcode::pong, std::move(payload), true,
299 is_client_);
300 }
301
302 auto websocket_protocol::create_close(ws_close_code code, std::string&& reason)
303 -> std::vector<uint8_t>
304 {
305 std::vector<uint8_t> payload;
306
307 // Encode close code (network byte order)
308 payload.push_back(static_cast<uint8_t>((static_cast<uint16_t>(code) >> 8) & 0xFF));
309 payload.push_back(static_cast<uint8_t>(static_cast<uint16_t>(code) & 0xFF));
310
311 // Append reason (if provided)
312 if (!reason.empty())
313 {
314 // Validate UTF-8 in reason
315 std::vector<uint8_t> reason_bytes(reason.begin(), reason.end());
316 if (is_valid_utf8(reason_bytes))
317 {
318 payload.insert(payload.end(), reason.begin(), reason.end());
319 }
320 }
321
322 // Ensure payload <= 125 bytes
323 if (payload.size() > 125)
324 {
325 payload.resize(125);
326 }
327
328 return websocket_frame::encode_frame(ws_opcode::close, std::move(payload), true,
329 is_client_);
330 }
331
333 std::function<void(const ws_message&)> callback) -> void
334 {
335 message_callback_ = std::move(callback);
336 }
337
339 std::function<void(const std::vector<uint8_t>&)> callback) -> void
340 {
341 ping_callback_ = std::move(callback);
342 }
343
345 std::function<void(const std::vector<uint8_t>&)> callback) -> void
346 {
347 pong_callback_ = std::move(callback);
348 }
349
351 std::function<void(ws_close_code, const std::string&)> callback) -> void
352 {
353 close_callback_ = std::move(callback);
354 }
355
356 auto websocket_protocol::is_valid_utf8(const std::vector<uint8_t>& data) -> bool
357 {
358 size_t i = 0;
359 while (i < data.size())
360 {
361 uint8_t byte = data[i];
362
363 // Determine the number of bytes in this UTF-8 character
364 int bytes_to_follow = 0;
365
366 if ((byte & 0x80) == 0x00)
367 {
368 // 1-byte character (0xxxxxxx)
369 bytes_to_follow = 0;
370 }
371 else if ((byte & 0xE0) == 0xC0)
372 {
373 // 2-byte character (110xxxxx)
374 bytes_to_follow = 1;
375 }
376 else if ((byte & 0xF0) == 0xE0)
377 {
378 // 3-byte character (1110xxxx)
379 bytes_to_follow = 2;
380 }
381 else if ((byte & 0xF8) == 0xF0)
382 {
383 // 4-byte character (11110xxx)
384 bytes_to_follow = 3;
385 }
386 else
387 {
388 // Invalid UTF-8 start byte
389 return false;
390 }
391
392 // Check if we have enough bytes
393 if (i + bytes_to_follow >= data.size())
394 {
395 return false;
396 }
397
398 // Validate continuation bytes
399 for (int j = 1; j <= bytes_to_follow; ++j)
400 {
401 if ((data[i + j] & 0xC0) != 0x80)
402 {
403 // Invalid continuation byte (should be 10xxxxxx)
404 return false;
405 }
406 }
407
408 i += bytes_to_follow + 1;
409 }
410
411 return true;
412 }
413
414} // namespace kcenon::network::internal
static auto encode_frame(ws_opcode opcode, std::vector< uint8_t > &&payload, bool fin=true, bool mask=false) -> std::vector< uint8_t >
Encodes data into a WebSocket frame.
static auto decode_payload(const ws_frame_header &header, const std::vector< uint8_t > &data) -> std::vector< uint8_t >
Decodes the payload from a WebSocket frame.
static auto calculate_header_size(uint64_t payload_len, bool mask) -> size_t
Calculates the size of the frame header.
static auto decode_header(const std::vector< uint8_t > &data) -> std::optional< ws_frame_header >
Decodes a WebSocket frame header from raw data.
auto handle_ping(const std::vector< uint8_t > &payload) -> void
Handles a ping frame.
auto handle_pong(const std::vector< uint8_t > &payload) -> void
Handles a pong frame.
auto set_message_callback(std::function< void(const ws_message &)> callback) -> void
Sets the callback for data messages.
auto create_pong(std::vector< uint8_t > &&payload={}) -> std::vector< uint8_t >
Creates a pong control frame.
auto create_close(ws_close_code code, std::string &&reason="") -> std::vector< uint8_t >
Creates a close control frame.
static auto is_valid_utf8(const std::vector< uint8_t > &data) -> bool
Validates UTF-8 encoding.
auto handle_close(const std::vector< uint8_t > &payload) -> void
Handles a close frame.
auto create_ping(std::vector< uint8_t > &&payload={}) -> std::vector< uint8_t >
Creates a ping control frame.
auto handle_data_frame(const ws_frame_header &header, const std::vector< uint8_t > &payload) -> void
Handles a data frame (text or binary).
websocket_protocol(bool is_client)
Constructs a WebSocket protocol handler.
auto process_frames() -> void
Processes incoming frames from the buffer.
auto set_close_callback(std::function< void(ws_close_code, const std::string &)> callback) -> void
Sets the callback for close frames.
auto set_ping_callback(std::function< void(const std::vector< uint8_t > &)> callback) -> void
Sets the callback for ping frames.
auto create_text_message(std::string &&text) -> std::vector< uint8_t >
Creates a text message frame.
auto create_binary_message(std::vector< uint8_t > &&data) -> std::vector< uint8_t >
Creates a binary message frame.
auto handle_control_frame(const ws_frame_header &header, const std::vector< uint8_t > &payload) -> void
Handles a control frame (ping, pong, close).
auto set_pong_callback(std::function< void(const std::vector< uint8_t > &)> callback) -> void
Sets the callback for pong frames.
auto process_data(std::span< const uint8_t > data) -> void
Processes incoming WebSocket data.
@ text
Text message (UTF-8 encoded)
ws_opcode
WebSocket frame operation codes as defined in RFC 6455.
@ text
Text frame (UTF-8 encoded)
@ close
Connection close frame.
ws_close_code
WebSocket close status codes (RFC 6455 Section 7.4).
Represents a decoded WebSocket frame header.
Represents a complete WebSocket message.
auto as_binary() const -> const std::vector< uint8_t > &
Returns message data as binary (reference).
ws_message_type type
Message type.
std::vector< uint8_t > data
Message payload.
auto as_text() const -> std::string
Converts message data to string (for text messages).