Network System 0.1.1
High-performance modular networking library for scalable client-server applications
Loading...
Searching...
No Matches
websocket_handshake.cpp
Go to the documentation of this file.
1// BSD 3-Clause License
2// Copyright (c) 2024, 🍀☀🌕🌥 🌊
3// See the LICENSE file in the project root for full license information.
4
6
#include <algorithm>
#include <cctype>
#include <random>
#include <sstream>

#include <openssl/evp.h>
#include <openssl/sha.h>
13
15{
namespace
{
    // WebSocket GUID constant from RFC 6455. It is appended to the client's
    // Sec-WebSocket-Key before hashing to derive Sec-WebSocket-Accept.
    constexpr std::string_view WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Standard Base64 alphabet, indexed by 6-bit value (0..63).
    // '=' padding is appended by the encoder and is not part of this table.
    constexpr std::string_view BASE64_CHARS =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    /**
     * @brief Returns a lowercased copy of @p str.
     *
     * Each byte is passed to std::tolower as an unsigned char. Bytes >= 0x80
     * therefore never reach tolower as negative values, which would be
     * undefined behaviour. Used for case-insensitive header-name and token
     * comparisons.
     *
     * @param str Input string (taken by value and modified in place).
     * @return The lowercased string.
     */
    auto to_lower(std::string str) -> std::string
    {
        std::transform(str.begin(), str.end(), str.begin(),
                       [](unsigned char c)
                       { return static_cast<char>(std::tolower(c)); });
        return str;
    }

    /**
     * @brief Removes leading and trailing spaces, tabs, CRs and LFs.
     *
     * @param str Text to trim (not modified).
     * @return A new string without the surrounding whitespace. The result is
     *         empty if @p str is empty or consists only of those characters.
     */
    auto trim(std::string_view str) -> std::string
    {
        const auto start = str.find_first_not_of(" \t\r\n");
        if (start == std::string_view::npos)
        {
            return {};
        }

        // A non-whitespace character exists, so `end` is never npos here.
        const auto end = str.find_last_not_of(" \t\r\n");
        return std::string(str.substr(start, end - start + 1));
    }

} // anonymous namespace
51
52 auto websocket_handshake::base64_encode(const std::vector<uint8_t>& data)
53 -> std::string
54 {
55 std::string result;
56 result.reserve(((data.size() + 2) / 3) * 4);
57
58 int val = 0;
59 int valb = -6;
60
61 for (uint8_t c : data)
62 {
63 val = (val << 8) + c;
64 valb += 8;
65 while (valb >= 0)
66 {
67 result.push_back(BASE64_CHARS[(val >> valb) & 0x3F]);
68 valb -= 6;
69 }
70 }
71
72 if (valb > -6)
73 {
74 result.push_back(BASE64_CHARS[((val << 8) >> (valb + 8)) & 0x3F]);
75 }
76
77 while (result.size() % 4)
78 {
79 result.push_back('=');
80 }
81
82 return result;
83 }
84
85 auto websocket_handshake::sha1_hash(const std::string& data)
86 -> std::vector<uint8_t>
87 {
88 std::vector<uint8_t> hash(SHA_DIGEST_LENGTH);
89 SHA1(reinterpret_cast<const unsigned char*>(data.c_str()), data.size(),
90 hash.data());
91 return hash;
92 }
93
95 {
96 // Generate 16 random bytes
97 static thread_local std::random_device rd;
98 static thread_local std::mt19937 gen(rd());
99 static thread_local std::uniform_int_distribution<uint32_t> dis(0, 255);
100
101 std::vector<uint8_t> random_bytes(16);
102 for (auto& byte : random_bytes)
103 {
104 byte = static_cast<uint8_t>(dis(gen));
105 }
106
107 return base64_encode(random_bytes);
108 }
109
110 auto websocket_handshake::calculate_accept_key(const std::string& client_key)
111 -> std::string
112 {
113 // Concatenate client key with GUID
114 std::string combined = client_key;
115 combined.append(WEBSOCKET_GUID);
116
117 // Hash with SHA-1
118 auto hash = sha1_hash(combined);
119
120 // Encode as Base64
121 return base64_encode(hash);
122 }
123
124 auto websocket_handshake::parse_headers(const std::string& http_message)
125 -> std::map<std::string, std::string>
126 {
127 std::map<std::string, std::string> headers;
128
129 // Find the end of headers (empty line)
130 const auto header_end = http_message.find("\r\n\r\n");
131 if (header_end == std::string::npos)
132 {
133 return headers;
134 }
135
136 std::istringstream stream(http_message.substr(0, header_end));
137 std::string line;
138
139 // Skip the first line (request/status line)
140 std::getline(stream, line);
141
142 // Parse headers
143 while (std::getline(stream, line))
144 {
145 // Remove trailing \r if present
146 if (!line.empty() && line.back() == '\r')
147 {
148 line.pop_back();
149 }
150
151 const auto colon_pos = line.find(':');
152 if (colon_pos != std::string::npos)
153 {
154 auto name = to_lower(trim(line.substr(0, colon_pos)));
155 auto value = trim(line.substr(colon_pos + 1));
156 headers[name] = value;
157 }
158 }
159
160 return headers;
161 }
162
163 auto websocket_handshake::extract_status_code(const std::string& response)
164 -> int
165 {
166 // HTTP response format: "HTTP/1.1 101 Switching Protocols\r\n"
167 const auto first_line_end = response.find("\r\n");
168 if (first_line_end == std::string::npos)
169 {
170 return 0;
171 }
172
173 const auto status_line = response.substr(0, first_line_end);
174 std::istringstream stream(status_line);
175
176 std::string http_version;
177 int status_code = 0;
178
179 stream >> http_version >> status_code;
180 return status_code;
181 }
182
184 std::string_view host, std::string_view path, uint16_t port,
185 const std::map<std::string, std::string>& extra_headers) -> std::string
186 {
187 std::ostringstream request;
188
189 // Request line
190 request << "GET " << path << " HTTP/1.1\r\n";
191
192 // Host header (include port if non-standard)
193 request << "Host: " << host;
194 if (port != 80 && port != 443)
195 {
196 request << ":" << port;
197 }
198 request << "\r\n";
199
200 // Required WebSocket headers
201 request << "Upgrade: websocket\r\n";
202 request << "Connection: Upgrade\r\n";
203 request << "Sec-WebSocket-Key: " << generate_websocket_key() << "\r\n";
204 request << "Sec-WebSocket-Version: 13\r\n";
205
206 // Extra headers
207 for (const auto& [name, value] : extra_headers)
208 {
209 request << name << ": " << value << "\r\n";
210 }
211
212 // Empty line to end headers
213 request << "\r\n";
214
215 return request.str();
216 }
217
219 const std::string& response, const std::string& expected_key)
221 {
222 ws_handshake_result result;
223 result.success = false;
224
225 // Check status code
226 const int status_code = extract_status_code(response);
227 if (status_code != 101)
228 {
229 result.error_message =
230 "Invalid status code: " + std::to_string(status_code);
231 return result;
232 }
233
234 // Parse headers
235 result.headers = parse_headers(response);
236
237 // Validate Upgrade header
238 auto it = result.headers.find("upgrade");
239 if (it == result.headers.end() || to_lower(it->second) != "websocket")
240 {
241 result.error_message = "Missing or invalid Upgrade header";
242 return result;
243 }
244
245 // Validate Connection header
246 it = result.headers.find("connection");
247 if (it == result.headers.end() || to_lower(it->second) != "upgrade")
248 {
249 result.error_message = "Missing or invalid Connection header";
250 return result;
251 }
252
253 // Validate Sec-WebSocket-Accept header
254 it = result.headers.find("sec-websocket-accept");
255 if (it == result.headers.end())
256 {
257 result.error_message = "Missing Sec-WebSocket-Accept header";
258 return result;
259 }
260
261 const std::string expected_accept = calculate_accept_key(expected_key);
262 if (it->second != expected_accept)
263 {
264 result.error_message = "Invalid Sec-WebSocket-Accept value";
265 return result;
266 }
267
268 result.success = true;
269 return result;
270 }
271
272 auto websocket_handshake::parse_client_request(const std::string& request)
274 {
275 ws_handshake_result result;
276 result.success = false;
277
278 // Parse headers
279 result.headers = parse_headers(request);
280
281 // Validate request line (should start with "GET")
282 if (request.substr(0, 3) != "GET")
283 {
284 result.error_message = "Invalid HTTP method (expected GET)";
285 return result;
286 }
287
288 // Validate Upgrade header
289 auto it = result.headers.find("upgrade");
290 if (it == result.headers.end() || to_lower(it->second) != "websocket")
291 {
292 result.error_message = "Missing or invalid Upgrade header";
293 return result;
294 }
295
296 // Validate Connection header
297 it = result.headers.find("connection");
298 if (it == result.headers.end() || to_lower(it->second) != "upgrade")
299 {
300 result.error_message = "Missing or invalid Connection header";
301 return result;
302 }
303
304 // Validate Sec-WebSocket-Key header
305 it = result.headers.find("sec-websocket-key");
306 if (it == result.headers.end() || it->second.empty())
307 {
308 result.error_message = "Missing or empty Sec-WebSocket-Key header";
309 return result;
310 }
311
312 // Validate Sec-WebSocket-Version header
313 it = result.headers.find("sec-websocket-version");
314 if (it == result.headers.end() || it->second != "13")
315 {
316 result.error_message = "Missing or invalid Sec-WebSocket-Version header";
317 return result;
318 }
319
320 result.success = true;
321 return result;
322 }
323
324 auto websocket_handshake::create_server_response(const std::string& client_key)
325 -> std::string
326 {
327 std::ostringstream response;
328
329 // Status line
330 response << "HTTP/1.1 101 Switching Protocols\r\n";
331
332 // Required headers
333 response << "Upgrade: websocket\r\n";
334 response << "Connection: Upgrade\r\n";
335 response << "Sec-WebSocket-Accept: " << calculate_accept_key(client_key)
336 << "\r\n";
337
338 // Empty line to end headers
339 response << "\r\n";
340
341 return response.str();
342 }
343
344} // namespace kcenon::network::internal
static auto generate_websocket_key() -> std::string
Generates a random Sec-WebSocket-Key (client-side).
static auto base64_encode(const std::vector< uint8_t > &data) -> std::string
Encodes binary data to Base64.
static auto extract_status_code(const std::string &response) -> int
Extracts the status code from an HTTP response.
static auto create_server_response(const std::string &client_key) -> std::string
Creates a WebSocket handshake response (server-side).
static auto create_client_handshake(std::string_view host, std::string_view path, uint16_t port, const std::map< std::string, std::string > &extra_headers={}) -> std::string
Creates a WebSocket handshake request (client-side).
static auto validate_server_response(const std::string &response, const std::string &expected_key) -> ws_handshake_result
Validates a WebSocket handshake response (client-side).
static auto parse_headers(const std::string &http_message) -> std::map< std::string, std::string >
Parses HTTP headers from a request or response.
static auto sha1_hash(const std::string &data) -> std::vector< uint8_t >
Computes SHA-1 hash of input data.
static auto parse_client_request(const std::string &request) -> ws_handshake_result
Parses a WebSocket handshake request (server-side).
static auto calculate_accept_key(const std::string &client_key) -> std::string
Calculates Sec-WebSocket-Accept from client key.
QUIC stream implementation (RFC 9000 Sections 2-4)
Definition stream.h:147
http_version
HTTP protocol version.
Definition http_types.h:41
Result of a WebSocket handshake operation.
bool success
Whether the handshake was successful.
std::map< std::string, std::string > headers
Parsed HTTP headers.
std::string error_message
Error message if success is false.