11#include <openssl/evp.h>
12#include <openssl/sha.h>
19 constexpr std::string_view WEBSOCKET_GUID =
"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
22 constexpr std::string_view BASE64_CHARS =
23 "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
28 auto to_lower(std::string str) -> std::string
30 std::transform(str.begin(), str.end(), str.begin(),
31 [](
unsigned char c) { return std::tolower(c); });
38 auto trim(std::string_view str) -> std::string
40 const auto start = str.find_first_not_of(
" \t\r\n");
41 if (start == std::string_view::npos)
46 const auto end = str.find_last_not_of(
" \t\r\n");
47 return std::string(str.substr(start, end - start + 1));
56 result.reserve(((data.size() + 2) / 3) * 4);
61 for (uint8_t c : data)
67 result.push_back(BASE64_CHARS[(val >> valb) & 0x3F]);
74 result.push_back(BASE64_CHARS[((val << 8) >> (valb + 8)) & 0x3F]);
77 while (result.size() % 4)
79 result.push_back(
'=');
86 -> std::vector<uint8_t>
88 std::vector<uint8_t> hash(SHA_DIGEST_LENGTH);
89 SHA1(
reinterpret_cast<const unsigned char*
>(data.c_str()), data.size(),
97 static thread_local std::random_device rd;
98 static thread_local std::mt19937 gen(rd());
99 static thread_local std::uniform_int_distribution<uint32_t> dis(0, 255);
101 std::vector<uint8_t> random_bytes(16);
102 for (
auto&
byte : random_bytes)
104 byte =
static_cast<uint8_t
>(dis(gen));
107 return base64_encode(random_bytes);
114 std::string combined = client_key;
115 combined.append(WEBSOCKET_GUID);
118 auto hash = sha1_hash(combined);
121 return base64_encode(hash);
125 -> std::map<std::string, std::string>
127 std::map<std::string, std::string> headers;
130 const auto header_end = http_message.find(
"\r\n\r\n");
131 if (header_end == std::string::npos)
136 std::istringstream
stream(http_message.substr(0, header_end));
140 std::getline(
stream, line);
143 while (std::getline(
stream, line))
146 if (!line.empty() && line.back() ==
'\r')
151 const auto colon_pos = line.find(
':');
152 if (colon_pos != std::string::npos)
154 auto name = to_lower(trim(line.substr(0, colon_pos)));
155 auto value = trim(line.substr(colon_pos + 1));
156 headers[name] = value;
167 const auto first_line_end = response.find(
"\r\n");
168 if (first_line_end == std::string::npos)
173 const auto status_line = response.substr(0, first_line_end);
174 std::istringstream
stream(status_line);
184 std::string_view host, std::string_view path, uint16_t port,
185 const std::map<std::string, std::string>& extra_headers) -> std::string
187 std::ostringstream request;
190 request <<
"GET " << path <<
" HTTP/1.1\r\n";
193 request <<
"Host: " << host;
194 if (port != 80 && port != 443)
196 request <<
":" << port;
201 request <<
"Upgrade: websocket\r\n";
202 request <<
"Connection: Upgrade\r\n";
203 request <<
"Sec-WebSocket-Key: " << generate_websocket_key() <<
"\r\n";
204 request <<
"Sec-WebSocket-Version: 13\r\n";
207 for (
const auto& [name, value] : extra_headers)
209 request << name <<
": " << value <<
"\r\n";
215 return request.str();
219 const std::string& response,
const std::string& expected_key)
226 const int status_code = extract_status_code(response);
227 if (status_code != 101)
230 "Invalid status code: " + std::to_string(status_code);
235 result.
headers = parse_headers(response);
238 auto it = result.
headers.find(
"upgrade");
239 if (it == result.
headers.end() || to_lower(it->second) !=
"websocket")
246 it = result.
headers.find(
"connection");
247 if (it == result.
headers.end() || to_lower(it->second) !=
"upgrade")
249 result.
error_message =
"Missing or invalid Connection header";
254 it = result.
headers.find(
"sec-websocket-accept");
255 if (it == result.
headers.end())
257 result.
error_message =
"Missing Sec-WebSocket-Accept header";
261 const std::string expected_accept = calculate_accept_key(expected_key);
262 if (it->second != expected_accept)
279 result.
headers = parse_headers(request);
282 if (request.substr(0, 3) !=
"GET")
289 auto it = result.
headers.find(
"upgrade");
290 if (it == result.
headers.end() || to_lower(it->second) !=
"websocket")
297 it = result.
headers.find(
"connection");
298 if (it == result.
headers.end() || to_lower(it->second) !=
"upgrade")
300 result.
error_message =
"Missing or invalid Connection header";
305 it = result.
headers.find(
"sec-websocket-key");
306 if (it == result.
headers.end() || it->second.empty())
308 result.
error_message =
"Missing or empty Sec-WebSocket-Key header";
313 it = result.
headers.find(
"sec-websocket-version");
314 if (it == result.
headers.end() || it->second !=
"13")
316 result.
error_message =
"Missing or invalid Sec-WebSocket-Version header";
327 std::ostringstream response;
330 response <<
"HTTP/1.1 101 Switching Protocols\r\n";
333 response <<
"Upgrade: websocket\r\n";
334 response <<
"Connection: Upgrade\r\n";
335 response <<
"Sec-WebSocket-Accept: " << calculate_accept_key(client_key)
341 return response.str();
static auto generate_websocket_key() -> std::string
Generates a random Sec-WebSocket-Key (client-side).
static auto base64_encode(const std::vector< uint8_t > &data) -> std::string
Encodes binary data to Base64.
static auto extract_status_code(const std::string &response) -> int
Extracts the status code from an HTTP response.
static auto create_server_response(const std::string &client_key) -> std::string
Creates a WebSocket handshake response (server-side).
static auto create_client_handshake(std::string_view host, std::string_view path, uint16_t port, const std::map< std::string, std::string > &extra_headers={}) -> std::string
Creates a WebSocket handshake request (client-side).
static auto validate_server_response(const std::string &response, const std::string &expected_key) -> ws_handshake_result
Validates a WebSocket handshake response (client-side).
static auto parse_headers(const std::string &http_message) -> std::map< std::string, std::string >
Parses HTTP headers from a request or response.
static auto sha1_hash(const std::string &data) -> std::vector< uint8_t >
Computes SHA-1 hash of input data.
static auto parse_client_request(const std::string &request) -> ws_handshake_result
Parses a WebSocket handshake request (server-side).
static auto calculate_accept_key(const std::string &client_key) -> std::string
Calculates Sec-WebSocket-Accept from client key.
QUIC stream implementation (RFC 9000 Sections 2-4)
http_version
HTTP protocol version.
Result of a WebSocket handshake operation.
bool success
Whether the handshake was successful.
std::map< std::string, std::string > headers
Parsed HTTP headers.
std::string error_message
Error message if success is false.