Rust REST API Rate Limiting with Sliding Window Algorithm per IP
Rate limiting is crucial for protecting REST APIs from abuse, ensuring fair usage, and maintaining service availability. The sliding window algorithm is a popular choice for rate limiting due to its accuracy and flexibility. This article will guide you through implementing a rate limiter using the sliding window algorithm in Rust, specifically limiting requests per IP address.
Understanding the Sliding Window Algorithm
The sliding window algorithm tracks requests within a moving time window. Unlike the simpler fixed-window counter, which resets abruptly at interval boundaries, the sliding window provides more accurate rate limiting by considering the timestamps of individual requests. Here's how it works:
- Timestamped Requests: Each incoming request's timestamp is recorded.
- Fixed-Size Window: A time window of a predefined duration (e.g., 1 minute) is maintained.
- Request Count: The algorithm counts the number of requests within the current window.
- Sliding the Window: As time progresses, the window slides forward, automatically excluding older requests.
- Rate Limit Enforcement: If the request count exceeds the defined limit within the window, the request is rejected.
This approach effectively distributes the rate limit over time, preventing sudden bursts of traffic from exhausting the limit.
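The steps above can be sketched for a single client as a short, self-contained program. The function and parameter names here are ours, chosen for illustration, and a 200 ms window is used so the sliding behavior is visible in a quick run:

```rust
use std::thread::sleep;
use std::time::{Duration, Instant};

// One sliding-window check for a single client: prune, count, record.
fn is_allowed(timestamps: &mut Vec<Instant>, limit: usize, window: Duration) -> bool {
    let now = Instant::now();
    // Slide the window: drop any timestamp older than `window`.
    timestamps.retain(|&t| now.duration_since(t) < window);
    if timestamps.len() >= limit {
        return false; // limit already reached within the current window
    }
    timestamps.push(now); // record this request
    true
}

fn main() {
    let window = Duration::from_millis(200);
    let mut ts = Vec::new();
    // The first three requests pass; the fourth is rejected.
    assert!(is_allowed(&mut ts, 3, window));
    assert!(is_allowed(&mut ts, 3, window));
    assert!(is_allowed(&mut ts, 3, window));
    assert!(!is_allowed(&mut ts, 3, window));
    // Once the window slides past the old requests, capacity frees up.
    sleep(window);
    assert!(is_allowed(&mut ts, 3, window));
    println!("single-client sliding window works as described");
}
```

The full implementation below generalizes this to many clients by keying the timestamp history on IP address.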
Implementation in Rust
Let's create a basic implementation in Rust. We'll use std::collections::HashMap to store request timestamps for each IP address and std::time::Instant to measure time.
use std::collections::HashMap;
use std::net::IpAddr;
use std::time::{Duration, Instant};

#[derive(Debug)]
struct RateLimiter {
    limit: u32,
    window: Duration,
    requests: HashMap<IpAddr, Vec<Instant>>,
}

impl RateLimiter {
    fn new(limit: u32, window: Duration) -> Self {
        RateLimiter {
            limit,
            window,
            requests: HashMap::new(),
        }
    }

    fn is_allowed(&mut self, ip: IpAddr) -> bool {
        let now = Instant::now();
        // Clean up old requests
        self.cleanup(ip, now);
        // Get the request history for this IP
        let request_history = self.requests.entry(ip).or_insert(Vec::new());
        // Check if the limit is exceeded
        if request_history.len() >= self.limit as usize {
            return false;
        }
        // Add the new request
        request_history.push(now);
        true
    }

    fn cleanup(&mut self, ip: IpAddr, now: Instant) {
        let window_start = now - self.window;
        if let Some(request_history) = self.requests.get_mut(&ip) {
            request_history.retain(|&request_time| request_time >= window_start);
        }
    }
}
// Requires tokio = { version = "1", features = ["full"] } in Cargo.toml
#[tokio::main]
async fn main() {
    use std::net::Ipv4Addr;

    let mut rate_limiter = RateLimiter::new(5, Duration::from_secs(60)); // 5 requests per 60 seconds
    let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

    for _ in 0..7 {
        if rate_limiter.is_allowed(ip1) {
            println!("Request from {} allowed", ip1);
        } else {
            println!("Request from {} blocked", ip1);
        }
        tokio::time::sleep(Duration::from_secs(5)).await;
    }

    for _ in 0..3 {
        if rate_limiter.is_allowed(ip2) {
            println!("Request from {} allowed", ip2);
        } else {
            println!("Request from {} blocked", ip2);
        }
        tokio::time::sleep(Duration::from_secs(5)).await;
    }
}
Explanation:
- RateLimiter struct: limit is the maximum number of requests allowed within the window; window is the duration of the sliding window; requests is a HashMap where the key is the IP address (IpAddr) and the value is a vector of Instant timestamps for requests from that IP.
- new() function: creates a new RateLimiter instance with the specified limit and window duration.
- is_allowed() function: gets the current time using Instant::now(), calls cleanup() to remove outdated requests, then retrieves the request history for the given IP address using self.requests.entry(ip).or_insert(Vec::new()), which either returns the existing vector or creates a new one if none exists. If the number of requests in the history has already reached the limit, the function returns false (request blocked); otherwise the current timestamp is appended to the history and the function returns true.
- cleanup() function: calculates the start of the current window (window_start), retrieves the request history for the given IP address, and uses retain() to keep only the timestamps that fall within the current window (i.e., timestamps greater than or equal to window_start). This efficiently removes outdated requests.
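One design note on cleanup(): because timestamps are pushed in arrival order, each history is always sorted, so a VecDeque can stop pruning at the first in-window entry instead of scanning every element the way retain() does. A minimal sketch of that alternative (function and variable names are ours):

```rust
use std::collections::VecDeque;
use std::time::{Duration, Instant};

// Pop expired timestamps from the front of a sorted history and stop
// at the first entry that is still inside the window.
fn cleanup(history: &mut VecDeque<Instant>, window: Duration, now: Instant) {
    while let Some(&oldest) = history.front() {
        if now.duration_since(oldest) >= window {
            history.pop_front(); // outside the window: discard
        } else {
            break; // everything behind the front is newer, so stop early
        }
    }
}

fn main() {
    let window = Duration::from_millis(60);
    let now = Instant::now();
    let mut history: VecDeque<Instant> = VecDeque::new();
    history.push_back(now - Duration::from_millis(120)); // outside the window
    history.push_back(now - Duration::from_millis(10)); // inside the window
    cleanup(&mut history, window, now);
    assert_eq!(history.len(), 1);
    println!("{} timestamp(s) remain in the window", history.len());
}
```

For small histories the difference is negligible, but the early exit avoids touching in-window entries on every request.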
Considerations for Production
This is a basic in-memory implementation. For production use, you'll need to address several key considerations:
- Concurrency: The HashMap is not thread-safe. You'll need a Mutex or RwLock to protect it from concurrent access in a multi-threaded environment like a web server.
- Persistence: In-memory data is lost when the application restarts. Persist the request history to a store such as Redis or PostgreSQL to maintain rate limiting across restarts and to share state between instances.
- IP Address Extraction: You'll need to extract the client IP address from the incoming request. How you do this depends on the web framework you're using (e.g., Actix-web, Rocket) and on whether the service sits behind a proxy that sets headers like X-Forwarded-For.
- Error Handling: Add proper error handling for issues like database connection errors or invalid IP addresses.
- Configuration: Externalize the rate limit and window duration as configuration parameters so they can be adjusted without code changes.
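The concurrency point can be sketched by wrapping the in-memory limiter in Arc<Mutex<...>> and sharing it across threads. This is a minimal, self-contained illustration: the simplified RateLimiter here mirrors the earlier struct but omits the constructor, and the thread counts are arbitrary:

```rust
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

// A minimal per-IP limiter; Arc<Mutex<_>> makes it safe to share
// across the worker threads of a multi-threaded server.
struct RateLimiter {
    limit: usize,
    window: Duration,
    requests: HashMap<IpAddr, Vec<Instant>>,
}

impl RateLimiter {
    fn is_allowed(&mut self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let history = self.requests.entry(ip).or_default();
        history.retain(|&t| now.duration_since(t) < self.window);
        if history.len() >= self.limit {
            return false;
        }
        history.push(now);
        true
    }
}

fn main() {
    let limiter = Arc::new(Mutex::new(RateLimiter {
        limit: 5,
        window: Duration::from_secs(60),
        requests: HashMap::new(),
    }));
    let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);

    // Ten threads race to make one request each; the lock serializes them.
    let handles: Vec<_> = (0..10)
        .map(|_| {
            let limiter = Arc::clone(&limiter);
            thread::spawn(move || limiter.lock().unwrap().is_allowed(ip))
        })
        .collect();
    let allowed = handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .filter(|&ok| ok)
        .count();
    // Exactly `limit` of the ten concurrent requests are admitted.
    assert_eq!(allowed, 5);
    println!("{allowed} of 10 requests allowed");
}
```

A coarse global lock like this is the simplest correct option; sharded locks or a concurrent map can reduce contention if the limiter becomes a bottleneck.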
Production-Ready Example (Illustrative)
This example shows how to integrate the rate limiter with Actix-web, using Redis for persistence and tokio::sync::Mutex for concurrency.
// Requires adding dependencies to Cargo.toml:
// actix-web = "4"
// redis = "0.21"
// tokio = { version = "1", features = ["full"] }
// serde = { version = "1", features = ["derive"] }
// serde_json = "1"
// env_logger = "0.10"
use actix_web::{middleware::Logger, web, App, HttpResponse, HttpServer, Responder};
use redis::{Client, Commands};
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::Mutex;
#[derive(Clone)]
struct AppState {
    redis_client: Client,
    rate_limiter: web::Data<Mutex<RateLimiter>>,
}

#[derive(Debug)]
struct RateLimiter {
    limit: u32,
    window: Duration,
}

#[derive(Serialize, Deserialize, Debug)]
struct RequestLog {
    timestamp: u64, // Unix timestamp in seconds
}

impl RateLimiter {
    fn new(limit: u32, window: Duration) -> Self {
        RateLimiter { limit, window }
    }
    async fn is_allowed(&self, state: &AppState, ip: IpAddr) -> Result<bool, redis::RedisError> {
        // Note: get_connection() is blocking; in a real service prefer the
        // redis crate's async connection support.
        let mut conn = state.redis_client.get_connection()?;
        let key = format!("rate_limit:{}", ip);
        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        // Fetch existing logs
        let logs_str: Option<String> = conn.get(&key)?;
        let mut logs: Vec<RequestLog> = logs_str
            .map(|s| serde_json::from_str(&s).unwrap_or_default())
            .unwrap_or_default();

        // Clean up old logs
        let window_start = now - self.window.as_secs();
        logs.retain(|log| log.timestamp >= window_start);

        // Check if limit is exceeded
        if logs.len() >= self.limit as usize {
            return Ok(false);
        }

        // Add new log
        logs.push(RequestLog { timestamp: now });

        // Persist logs back to Redis as a JSON string. The `let _: () =`
        // annotations pin down the return types, which `redis` leaves generic.
        let logs_str = serde_json::to_string(&logs).unwrap();
        let _: () = conn.set(&key, logs_str)?;
        // Expire the key after the window duration so stale data is removed
        let _: () = conn.expire(&key, self.window.as_secs() as usize)?;

        Ok(true)
    }
}
async fn index(req: actix_web::HttpRequest, data: web::Data<AppState>) -> impl Responder {
    // realip_remote_addr() honors X-Forwarded-For when present. Binding
    // connection_info() first keeps the borrowed string alive; note that
    // the split(':') below assumes IPv4 and would truncate an IPv6 address.
    let conn_info = req.connection_info();
    let ip_string = conn_info.realip_remote_addr().unwrap_or("0.0.0.0:0");
    let ip = IpAddr::from_str(ip_string.split(':').next().unwrap_or("0.0.0.0")).unwrap();
    let rate_limiter = data.rate_limiter.lock().await;
    match rate_limiter.is_allowed(&data, ip).await {
        Ok(true) => HttpResponse::Ok().body("Request allowed"),
        Ok(false) => HttpResponse::TooManyRequests().body("Rate limit exceeded"),
        Err(e) => {
            eprintln!("Redis error: {:?}", e);
            HttpResponse::InternalServerError().body("Internal Server Error")
        }
    }
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
    env_logger::init_from_env(env_logger::Env::default().default_filter_or("info"));

    let redis_client = redis::Client::open("redis://127.0.0.1/").expect("Failed to connect to Redis");
    let rate_limiter = web::Data::new(Mutex::new(RateLimiter::new(5, Duration::from_secs(60)))); // 5 requests per minute
    let app_state = AppState {
        redis_client: redis_client.clone(),
        rate_limiter: rate_limiter.clone(),
    };

    HttpServer::new(move || {
        App::new()
            .app_data(web::Data::new(app_state.clone()))
            .wrap(Logger::default())
            .route("/", web::get().to(index))
    })
    .bind("127.0.0.1:8080")?
    .run()
    .await
}
Key improvements in the production-ready example:
- Redis Persistence: The request history is stored in Redis, allowing the rate limiter to persist across application restarts.
- Concurrency Control: A tokio::sync::Mutex protects the RateLimiter from concurrent access by multiple threads.
- Actix-web Integration: The example integrates with the Actix-web framework to handle HTTP requests.
- IP Address Extraction: The code extracts the client IP address from the HttpRequest object.
- Error Handling: Basic error handling is included for Redis operations.
- JSON Serialization: serde and serde_json serialize and deserialize the request logs for storage in Redis.
- Unix Timestamps: Timestamps are stored as Unix timestamps (seconds since the epoch) for easier storage and retrieval in Redis.
- Key Expiry: An expiry on the Redis key automatically removes rate limit data after the window duration.
Further Enhancements
- Customizable Limits: Allow different rate limits for different API endpoints or user roles.
- Dynamic Configuration: Implement a mechanism to update rate limits dynamically without restarting the application.
- Metrics and Monitoring: Collect metrics on rate limiting activity (e.g., number of requests blocked) and monitor the performance of the rate limiter.
- Distributed Rate Limiting: For highly distributed systems, consider using a distributed rate limiting service like Redis Cluster or a dedicated rate limiting solution.
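The first enhancement, customizable limits, can be sketched by keying histories on (endpoint, IP) and looking up a per-route policy. TieredLimiter, its fields, and the route names here are hypothetical, chosen only for illustration:

```rust
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};

// Hypothetical per-endpoint policies: route -> (limit, window).
struct TieredLimiter {
    policies: HashMap<&'static str, (usize, Duration)>,
    requests: HashMap<(String, IpAddr), Vec<Instant>>,
}

impl TieredLimiter {
    fn is_allowed(&mut self, endpoint: &str, ip: IpAddr) -> bool {
        // Fall back to a default policy for routes without an explicit one.
        let (limit, window) = self
            .policies
            .get(endpoint)
            .copied()
            .unwrap_or((100, Duration::from_secs(60)));
        // Each (endpoint, IP) pair gets its own sliding-window history.
        let history = self.requests.entry((endpoint.to_string(), ip)).or_default();
        let now = Instant::now();
        history.retain(|&t| now.duration_since(t) < window);
        if history.len() >= limit {
            return false;
        }
        history.push(now);
        true
    }
}

fn main() {
    let mut limiter = TieredLimiter {
        policies: HashMap::from([
            ("/login", (2, Duration::from_secs(60))), // strict for auth
            ("/search", (5, Duration::from_secs(60))),
        ]),
        requests: HashMap::new(),
    };
    let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
    assert!(limiter.is_allowed("/login", ip));
    assert!(limiter.is_allowed("/login", ip));
    assert!(!limiter.is_allowed("/login", ip)); // third login attempt blocked
    assert!(limiter.is_allowed("/search", ip)); // separate budget per route
    println!("per-endpoint limits enforced");
}
```

The same keying scheme extends naturally to user roles or API keys: anything that identifies a budget can replace or join the IP in the map key.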
By following this guide, you can implement a robust and effective rate limiter for your Rust REST API using the sliding window algorithm. Remember to adapt the code to your specific requirements and environment.