Network System 0.1.1
High-performance modular networking library for scalable client-server applications
Loading...
Searching...
No Matches
rate_limiter.h
Go to the documentation of this file.
1// BSD 3-Clause License
2// Copyright (c) 2021-2025, 🍀☀🌕🌥 🌊
3// See the LICENSE file in the project root for full license information.
4
32#pragma once
33
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <string_view>
#include <unordered_map>
43
44namespace kcenon::network {
45
52
54 size_t burst_size = 20;
55
57 std::chrono::seconds window = std::chrono::seconds(1);
58
60 bool auto_cleanup = true;
61
63 std::chrono::seconds stale_timeout = std::chrono::seconds(300);
64};
65
83public:
89 : config_(std::move(config))
90 , last_cleanup_(std::chrono::steady_clock::now()) {}
91
98 [[nodiscard]] bool allow(std::string_view client_id) {
99 std::unique_lock<std::shared_mutex> lock(mutex_);
100
101 auto now = std::chrono::steady_clock::now();
102
103 // Periodic cleanup
104 if (config_.auto_cleanup) {
105 maybe_cleanup(now);
106 }
107
108 std::string key(client_id);
109 auto& bucket = buckets_[key];
110
111 // Initialize new bucket
112 if (bucket.last_refill.time_since_epoch().count() == 0) {
113 bucket.tokens = static_cast<double>(config_.burst_size);
114 bucket.last_refill = now;
115 }
116
117 // Refill tokens based on elapsed time
118 auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
119 now - bucket.last_refill);
120
121 double tokens_to_add = elapsed.count() *
122 (static_cast<double>(config_.max_requests_per_second) / 1000.0);
123
124 bucket.tokens = std::min(
125 static_cast<double>(config_.burst_size),
126 bucket.tokens + tokens_to_add);
127 bucket.last_refill = now;
128
129 // Try to consume a token
130 if (bucket.tokens >= 1.0) {
131 bucket.tokens -= 1.0;
132 return true;
133 }
134
135 return false;
136 }
137
149 [[nodiscard]] bool allow(std::string_view client_id, std::string_view session_id) {
150 if (session_id.empty()) {
151 return allow(client_id);
152 }
153 std::string composite_key;
154 composite_key.reserve(client_id.size() + 1 + session_id.size());
155 composite_key.append(client_id);
156 composite_key.push_back(':');
157 composite_key.append(session_id);
158 return allow(std::string_view(composite_key));
159 }
160
167 [[nodiscard]] bool would_allow(std::string_view client_id) const {
168 std::shared_lock<std::shared_mutex> lock(mutex_);
169
170 auto it = buckets_.find(std::string(client_id));
171 if (it == buckets_.end()) {
172 return true; // New client would have full bucket
173 }
174
175 auto now = std::chrono::steady_clock::now();
176 auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
177 now - it->second.last_refill);
178
179 double tokens_to_add = elapsed.count() *
180 (static_cast<double>(config_.max_requests_per_second) / 1000.0);
181
182 double available = std::min(
183 static_cast<double>(config_.burst_size),
184 it->second.tokens + tokens_to_add);
185
186 return available >= 1.0;
187 }
188
195 [[nodiscard]] double remaining_tokens(std::string_view client_id) const {
196 std::shared_lock<std::shared_mutex> lock(mutex_);
197
198 auto it = buckets_.find(std::string(client_id));
199 if (it == buckets_.end()) {
200 return static_cast<double>(config_.burst_size);
201 }
202
203 auto now = std::chrono::steady_clock::now();
204 auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
205 now - it->second.last_refill);
206
207 double tokens_to_add = elapsed.count() *
208 (static_cast<double>(config_.max_requests_per_second) / 1000.0);
209
210 return std::min(
211 static_cast<double>(config_.burst_size),
212 it->second.tokens + tokens_to_add);
213 }
214
219 void reset(std::string_view client_id) {
220 std::unique_lock<std::shared_mutex> lock(mutex_);
221 buckets_.erase(std::string(client_id));
222 }
223
227 void reset_all() {
228 std::unique_lock<std::shared_mutex> lock(mutex_);
229 buckets_.clear();
230 }
231
236 [[nodiscard]] size_t client_count() const {
237 std::shared_lock<std::shared_mutex> lock(mutex_);
238 return buckets_.size();
239 }
240
246 std::unique_lock<std::shared_mutex> lock(mutex_);
247 config_ = std::move(config);
248 }
249
254 [[nodiscard]] rate_limiter_config config() const {
255 std::shared_lock<std::shared_mutex> lock(mutex_);
256 return config_;
257 }
258
259private:
260 struct bucket {
261 double tokens = 0.0;
262 std::chrono::steady_clock::time_point last_refill{};
263 };
264
265 void maybe_cleanup(std::chrono::steady_clock::time_point now) {
266 // Cleanup at most once per minute
267 auto since_cleanup = std::chrono::duration_cast<std::chrono::seconds>(
268 now - last_cleanup_);
269
270 if (since_cleanup < std::chrono::seconds(60)) {
271 return;
272 }
273
274 last_cleanup_ = now;
275
276 // Remove stale entries
277 for (auto it = buckets_.begin(); it != buckets_.end();) {
278 auto since_refill = std::chrono::duration_cast<std::chrono::seconds>(
279 now - it->second.last_refill);
280
281 if (since_refill > config_.stale_timeout) {
282 it = buckets_.erase(it);
283 } else {
284 ++it;
285 }
286 }
287 }
288
290 mutable std::shared_mutex mutex_;
291 std::unordered_map<std::string, bucket> buckets_;
292 std::chrono::steady_clock::time_point last_cleanup_;
293};
294
306public:
311 explicit connection_limiter(size_t max_connections = 1000)
312 : max_connections_(max_connections)
314
319 [[nodiscard]] bool can_accept() const noexcept {
320 return current_connections_.load(std::memory_order_acquire) < max_connections_;
321 }
322
327 [[nodiscard]] bool try_accept() noexcept {
328 size_t current = current_connections_.load(std::memory_order_acquire);
329 while (current < max_connections_) {
330 if (current_connections_.compare_exchange_weak(
331 current, current + 1,
332 std::memory_order_acq_rel,
333 std::memory_order_acquire)) {
334 return true;
335 }
336 }
337 return false;
338 }
339
345 void on_connect() noexcept {
346 current_connections_.fetch_add(1, std::memory_order_release);
347 }
348
352 void on_disconnect() noexcept {
353 size_t prev = current_connections_.fetch_sub(1, std::memory_order_release);
354 // Prevent underflow (shouldn't happen in correct usage)
355 if (prev == 0) {
356 current_connections_.fetch_add(1, std::memory_order_release);
357 }
358 }
359
364 [[nodiscard]] size_t current() const noexcept {
365 return current_connections_.load(std::memory_order_acquire);
366 }
367
372 [[nodiscard]] size_t max() const noexcept {
373 return max_connections_;
374 }
375
380 void set_max(size_t max_connections) noexcept {
381 max_connections_ = max_connections;
382 }
383
388 [[nodiscard]] size_t available() const noexcept {
389 size_t current = current_connections_.load(std::memory_order_acquire);
391 }
392
397 [[nodiscard]] bool at_capacity() const noexcept {
398 return current_connections_.load(std::memory_order_acquire) >= max_connections_;
399 }
400
401private:
403 std::atomic<size_t> current_connections_;
404};
405
412public:
418 : limiter_(&limiter)
419 , accepted_(limiter.try_accept()) {}
420
424
427 : limiter_(other.limiter_)
428 , accepted_(other.accepted_) {
429 other.accepted_ = false;
430 }
431
433 if (this != &other) {
434 release();
435 limiter_ = other.limiter_;
436 accepted_ = other.accepted_;
437 other.accepted_ = false;
438 }
439 return *this;
440 }
441
443 release();
444 }
445
450 [[nodiscard]] bool accepted() const noexcept {
451 return accepted_;
452 }
453
457 explicit operator bool() const noexcept {
458 return accepted_;
459 }
460
464 void release() noexcept {
465 if (accepted_ && limiter_) {
467 accepted_ = false;
468 }
469 }
470
471private:
474};
475
482public:
489 size_t max_per_client = 10,
490 size_t max_total = 1000)
491 : max_per_client_(max_per_client)
492 , total_limiter_(max_total) {}
493
499 [[nodiscard]] bool try_accept(std::string_view client_id) {
500 // First check total limit
501 if (!total_limiter_.can_accept()) {
502 return false;
503 }
504
505 std::unique_lock<std::mutex> lock(mutex_);
506 std::string key(client_id);
507
508 auto& count = client_connections_[key];
509 if (count >= max_per_client_) {
510 return false;
511 }
512
513 if (!total_limiter_.try_accept()) {
514 return false;
515 }
516
517 ++count;
518 return true;
519 }
520
525 void release(std::string_view client_id) {
526 std::unique_lock<std::mutex> lock(mutex_);
527 std::string key(client_id);
528
529 auto it = client_connections_.find(key);
530 if (it != client_connections_.end() && it->second > 0) {
531 --it->second;
532 if (it->second == 0) {
533 client_connections_.erase(it);
534 }
536 }
537 }
538
544 [[nodiscard]] size_t client_connections(std::string_view client_id) const {
545 std::unique_lock<std::mutex> lock(mutex_);
546 auto it = client_connections_.find(std::string(client_id));
547 return (it != client_connections_.end()) ? it->second : 0;
548 }
549
554 [[nodiscard]] size_t total_connections() const noexcept {
555 return total_limiter_.current();
556 }
557
558private:
561 mutable std::mutex mutex_;
562 std::unordered_map<std::string, size_t> client_connections_;
563};
564
565} // namespace kcenon::network
RAII guard for connection limiting.
connection_guard & operator=(const connection_guard &)=delete
bool accepted() const noexcept
Check if connection was accepted.
connection_guard & operator=(connection_guard &&other) noexcept
connection_guard(connection_limiter &limiter)
Construct guard and try to accept connection.
connection_guard(connection_guard &&other) noexcept
Movable.
void release() noexcept
Release the connection early.
connection_guard(const connection_guard &)=delete
Non-copyable.
Connection count limiter.
bool try_accept() noexcept
Try to accept a connection.
void on_disconnect() noexcept
Unregister a connection.
connection_limiter(size_t max_connections=1000)
Construct connection limiter.
size_t max() const noexcept
Get maximum connection limit.
void set_max(size_t max_connections) noexcept
Set maximum connection limit.
std::atomic< size_t > current_connections_
bool can_accept() const noexcept
Check if a new connection can be accepted.
size_t available() const noexcept
Get available connection slots.
size_t current() const noexcept
Get current connection count.
void on_connect() noexcept
Register a new connection.
bool at_capacity() const noexcept
Check if at capacity.
void release(std::string_view client_id)
Release a connection from a client.
size_t client_connections(std::string_view client_id) const
Get connection count for a client.
std::unordered_map< std::string, size_t > client_connections_
per_client_connection_limiter(size_t max_per_client=10, size_t max_total=1000)
Construct limiter.
bool try_accept(std::string_view client_id)
Try to accept a connection from a client.
size_t total_connections() const noexcept
Get total connection count.
Token bucket rate limiter.
rate_limiter(rate_limiter_config config={})
Construct rate limiter with configuration.
size_t client_count() const
Get number of tracked clients.
void set_config(rate_limiter_config config)
Update configuration.
bool allow(std::string_view client_id)
Check if request should be allowed.
void reset(std::string_view client_id)
Reset rate limit for a specific client.
rate_limiter_config config() const
Get current configuration.
bool would_allow(std::string_view client_id) const
Check if request would be allowed (without consuming token)
void reset_all()
Reset all rate limits.
rate_limiter_config config_
void maybe_cleanup(std::chrono::steady_clock::time_point now)
bool allow(std::string_view client_id, std::string_view session_id)
Check if request should be allowed (session-aware)
std::unordered_map< std::string, bucket > buckets_
std::chrono::steady_clock::time_point last_cleanup_
double remaining_tokens(std::string_view client_id) const
Get remaining tokens for a client.
Main namespace for all Network System components.
std::chrono::steady_clock::time_point last_refill
Configuration for rate limiter.
size_t burst_size
Maximum burst size (token bucket capacity)
std::chrono::seconds stale_timeout
Stale entry expiration time.
size_t max_requests_per_second
Maximum requests per second.
bool auto_cleanup
Enable automatic cleanup of stale entries.
std::chrono::seconds window
Time window for rate calculation.