💾 Archived View for alchemi.dev › en › projects › kochab › files › src › ratelimiting.rs captured on 2022-07-16 at 16:43:05.

View Raw

More Information

-=-=-=-=-=-=-

use dashmap::DashMap;

use std::{fmt::Display, collections::VecDeque, hash::Hash, time::{Duration, Instant}};

/// A per-key rate limiter.
///
/// Every key gets its own log of recent connection times, and stale entries are
/// discarded as new checks for that key arrive.  Because of that, no leaky bucket
/// thread is required to empty it out.  Keys which stop connecting are not removed on
/// their own, however, so it may occasionally be necessary to trim old keys using
/// [`trim_keys()`].
///
/// [`trim_keys()`]: Self::trim_keys()
pub struct RateLimiter<K>
where
    K: Eq + Hash,
{
    /// Times of the connections recorded for each key, oldest at the front
    log: DashMap<K, VecDeque<Instant>>,
    /// Maximum number of connections a key may make within one `period`
    burst: usize,
    /// Length of the window over which a key's connections are counted
    period: Duration,
}

impl<K: Eq + Hash> RateLimiter<K> {
    /// Create a new ratelimiter that allows at most `burst` connections in `period`
    pub fn new(period: Duration, burst: usize) -> Self {
        Self {
            log: DashMap::with_capacity(8),
            period,
            burst,
        }
    }

    /// Check if a key may pass
    ///
    /// If the key has made less than `self.burst` connections in the last `self.period`,
    /// then the key is allowed to connect, which is denoted by an `Ok` result.  This will
    /// register as a new connection from that key.
    ///
    /// If the key is not allowed to connect, than a [`Duration`] denoting the amount of
    /// time until the key is permitted is returned, wrapped in an `Err`
    pub fn check_key(&self, key: K) -> Result<(), Duration> {
        let now = Instant::now();
        let count_after = now - self.period;

        let mut connections = self.log.entry(key)
            .or_insert_with(||VecDeque::with_capacity(self.burst));
        let connections = connections.value_mut();

        // Chcek if space can be made available.  We don't need to trim all expired
        // connections, just the one in question to allow this connection.
        if let Some(earliest_conn) = connections.front() {
            if earliest_conn < &count_after {
                connections.pop_front();
            }
        }

        // Check if the connection should be allowed
        if connections.len() == self.burst {
            Err(connections[0] + self.period - now)
        } else {
            connections.push_back(now);
            Ok(())
        }
    }

    /// Remove any expired keys from the ratelimiter
    ///
    /// This only needs to be called if keys are continuously being added.  If keys are
    /// being reused, or come from a finite set, then you don't need to worry about this.
    ///
    /// If you have many keys coming from a large set, you should infrequently call this
    /// to prevent a memory leak.
    ///
    /// If debug level logging is enabled, this prints an *approximate* number of keys
    /// removed to the log.  For more precise output, use [`trim_keys_verbose()`]
    ///
    /// [`trim_keys_verbose()`]: RateLimiter::trim_keys_verbose()
    pub fn trim_keys(&self) {
        let count_after = Instant::now() - self.period;

        let len: isize = self.log.len() as isize;
        self.log.retain(|_, conns| conns.back().unwrap() > &count_after);
        let removed = len - self.log.len() as isize;
        if removed.is_positive() {
            debug!("Pruned approximately {} expired ratelimit keys", removed);
        }
    }
}

impl<K: Eq + Hash + Display> RateLimiter<K> {

    /// Remove any expired keys from the ratelimiter
    ///
    /// This only needs to be called if keys are continuously being added.  If keys are
    /// being reused, or come from a finite set, then you don't need to worry about this.
    ///
    /// If you have many keys coming from a large set, you should infrequently call this
    /// to prevent a memory leak.
    ///
    /// If debug level logging is on, this prints out any removed keys.
    pub fn trim_keys_verbose(&self) {
        let count_after = Instant::now() - self.period;

        self.log.retain(|ip, conns| {
            let should_keep = conns.back().unwrap() > &count_after;
            if !should_keep {
                debug!("Pruned expired ratelimit key: {}", ip);
            }
            should_keep
        });
    }
}