💾 Archived View for mozz.us › jetforce › jetforce › app › base.py captured on 2024-02-05 at 10:13:25.

View Raw

More Information

⬅️ Previous capture (2023-03-20)

-=-=-=-=-=-=-

from __future__ import annotations

import dataclasses
import re
import time
import typing
from collections import defaultdict
from urllib.parse import unquote, urlparse

from twisted.internet.defer import Deferred

# Per-request environment mapping handed to the application. Keys read in
# this module are "GEMINI_URL", "HOSTNAME", "SERVER_PORT" and "REMOTE_ADDR".
EnvironDict = typing.Dict[str, typing.Any]
# A single piece of response body: text, raw bytes, or a twisted Deferred.
ResponseType = typing.Union[str, bytes, Deferred]
# What an application call produces: an iterable of body pieces.
ApplicationResponse = typing.Iterable[ResponseType]
# Callback used by the application to report (status code, meta string).
WriteStatusCallable = typing.Callable[[int, str], None]
# Signature of an application object: called with the environ and the status
# callback, it returns the iterable response body.
ApplicationCallable = typing.Callable[
    [EnvironDict, WriteStatusCallable], ApplicationResponse
]


class Status:
    """
    Gemini response status codes.

    The constants are plain integers grouped by their leading digit; they are
    passed as the ``status`` of a ``Response``.
    """

    # 1x: the client is asked for input.
    INPUT = 10
    SENSITIVE_INPUT = 11

    # 2x: success.
    SUCCESS = 20

    # 3x: redirects.
    REDIRECT_TEMPORARY = 30
    REDIRECT_PERMANENT = 31

    # 4x: temporary failures.
    TEMPORARY_FAILURE = 40
    SERVER_UNAVAILABLE = 41
    CGI_ERROR = 42
    PROXY_ERROR = 43
    SLOW_DOWN = 44

    # 5x: permanent failures.
    PERMANENT_FAILURE = 50
    NOT_FOUND = 51
    GONE = 52
    PROXY_REQUEST_REFUSED = 53
    BAD_REQUEST = 59

    # 6x: client certificate problems.
    CLIENT_CERTIFICATE_REQUIRED = 60
    CERTIFICATE_NOT_AUTHORISED = 61
    CERTIFICATE_NOT_VALID = 62


class Request:
    """
    Parsed view of one incoming gemini request.

    The raw URL is read from ``environ["GEMINI_URL"]`` and split into its
    components, which are exposed as attributes next to the original environ.
    """

    environ: EnvironDict
    url: str
    scheme: str
    hostname: str
    port: int | None
    path: str
    params: str
    query: str
    fragment: str

    def __init__(self, environ: EnvironDict):
        """
        Populate the request from the server-provided environ.

        Raises ValueError when the URL has no hostname, has no scheme, or is
        a gemini:// URL carrying a userinfo component. Errors raised while
        encoding the hostname or reading the port propagate unchanged.
        """
        self.environ = environ
        self.url = typing.cast(str, environ["GEMINI_URL"])

        parsed = urlparse(self.url)

        # The hostname is validated before the scheme, which decides the
        # message reported when both are absent.
        if not parsed.hostname:
            raise ValueError("Missing hostname component")
        if not parsed.scheme:
            raise ValueError("Missing scheme component")
        self.scheme = parsed.scheme

        # The specification forbids gemini://username@host/... URLs.
        if self.scheme == "gemini":
            if parsed.username:
                raise ValueError("Invalid userinfo component")

        # Store the punycode form of the domain so URLs containing encoded
        # IDNs compare consistently (follows RFC 3490).
        self.hostname = parsed.hostname.encode("idna").decode("ascii")
        self.port = parsed.port

        # The remaining components are kept percent-decoded.
        raw_components = (
            parsed.path,
            parsed.params,
            parsed.query,
            parsed.fragment,
        )
        self.path, self.params, self.query, self.fragment = map(
            unquote, raw_components
        )


@dataclasses.dataclass
class Response:
    """
    Object that encapsulates information about a single gemini response.

    Route handlers return instances of this class. The application reports
    ``status`` and ``meta`` through its status callback and then yields
    ``body`` (when there is one) back to the caller.
    """

    # Numeric status code, e.g. one of the Status constants.
    status: int
    # Meta string reported together with the status code.
    meta: str
    # Optional payload: a single str/bytes/Deferred is yielded as-is, any
    # other truthy value is iterated and its items are yielded one by one.
    body: None | ResponseType | ApplicationResponse = None


# A route handler is called with the Request plus any named groups captured
# by the route's path pattern (as keyword arguments) and returns a Response.
RouteHandler = typing.Callable[..., Response]


@dataclasses.dataclass
class RoutePattern:
    """
    A pattern for matching URLs with a single endpoint or route.

    Attributes:
        path: Regular expression that must match the entire request path.
        scheme: Required URL scheme; a falsy value disables the check.
        hostname: Hostname to compare against. When None, the ``HOSTNAME``
            entry of the request environ is used instead.
        strict_hostname: Reject requests whose hostname differs.
        strict_port: Reject requests whose URL names an explicit port that
            differs from the environ's ``SERVER_PORT``.
        strict_trailing_slash: When False, trailing slashes are stripped
            from the request path before it is matched.
    """

    path: str = ".*"
    scheme: str = "gemini"
    hostname: str | None = None

    strict_hostname: bool = True
    strict_port: bool = True
    strict_trailing_slash: bool = False

    def match(self, request: Request) -> re.Match[str] | None:
        """
        Check if the given request URL matches this route pattern.

        Returns the match object produced by ``re.fullmatch`` on the request
        path (its named groups become the handler's keyword arguments), or
        None when the hostname, port, scheme or path does not match.
        """
        if self.hostname is None:
            server_hostname = request.environ["HOSTNAME"]
        else:
            server_hostname = self.hostname
        server_port = request.environ["SERVER_PORT"]

        if self.strict_hostname and request.hostname != server_hostname:
            return None
        if self.strict_port and request.port is not None:
            # request.port is an int parsed from the URL, while environ
            # values are untyped and may be a CGI-style string. Comparing the
            # string forms lets "1965" and 1965 match without ever raising.
            if str(request.port) != str(server_port):
                return None
        if self.scheme and self.scheme != request.scheme:
            return None

        if self.strict_trailing_slash:
            request_path = request.path
        else:
            request_path = request.path.rstrip("/")

        return re.fullmatch(self.path, request_path)


class RateLimiter:
    """
    A class that can be used to apply rate-limiting to endpoints.

    Rates are defined as human-readable strings, e.g.

        "5/s (5 requests per-second)
        "10/5m" (10 requests per-5 minutes)
        "100/2h" (100 requests per-2 hours)
        "1000/d" (1k requests per-day)
    """

    RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")

    number: int
    period: int
    next_timestamp: float
    rate_counter: dict[typing.Any, int]

    def __init__(self, rate: str) -> None:
        match = self.RE.fullmatch(rate)
        if not match:
            raise ValueError(f"Invalid rate format: {rate}")

        rate_data = match.groupdict()

        self.number = int(rate_data["number"])
        self.period = int(rate_data["period"] or 1)
        if rate_data["unit"] == "m":
            self.period *= 60
        elif rate_data["unit"] == "h":
            self.period += 60 * 60
        elif rate_data["unit"] == "d":
            self.period *= 60 * 60 * 24

        self.reset()

    def reset(self) -> None:
        self.next_timestamp = time.time() + self.period
        self.rate_counter = defaultdict(int)

    def get_key(self, request: Request) -> typing.Any:
        """
        Rate limit based on the client's IP-address.
        """
        return request.environ["REMOTE_ADDR"]

    def check(self, request: Request) -> Response | None:
        """
        Check if the given request should be rate limited.

        This method will return a failure response if the request should be
        rate limited.
        """
        time_left = self.next_timestamp - time.time()
        if time_left < 0:
            self.reset()

        key = self.get_key(request)
        if key is not None:
            self.rate_counter[key] += 1
            if self.rate_counter[key] > self.number:
                msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
                return Response(Status.SLOW_DOWN, msg)

        return None

    def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
        """
        Decorator to apply rate limiting to an individual application route.

        Usage:
            rate_limiter = RateLimiter("10/m")

            @app.route("/endpoint")
            @rate_limiter.apply
            def my_endpoint(request):
                return Response(Status.SUCCESS, "text/gemini", "hello world!")
        """

        def wrapper(request: Request, **kwargs: typing.Any) -> Response:
            response = self.check(request)
            if response:
                return response
            return wrapped_func(request, **kwargs)

        return wrapper


class JetforceApplication:
    """
    Minimal application base class offering simple URL routing.

    On its own this class serves nothing except the default "not found"
    response. It exists so that custom servers can be written in python by
    registering endpoints through the ``route`` decorator, either on an
    instance or on a subclass. See the examples/ directory for sample usage.
    """

    rate_limiter: RateLimiter | None
    routes: list[tuple[RoutePattern, RouteHandler]]

    request_class: type[Request] = Request

    def __init__(self, rate_limiter: RateLimiter | None = None):
        """
        Create an application with no routes and an optional global limiter.
        """
        self.routes = []
        self.rate_limiter = rate_limiter

    def __call__(
        self, environ: EnvironDict, send_status: WriteStatusCallable
    ) -> ApplicationResponse:
        """
        Handle one request: report the status, then yield the response body.

        A request that cannot be constructed is answered with BAD_REQUEST.
        If the application-wide rate limiter rejects the request, its
        response is reported and nothing is yielded.
        """
        try:
            request = self.request_class(environ)
        except Exception:
            send_status(Status.BAD_REQUEST, "Invalid URL")
            return

        limiter = self.rate_limiter
        if limiter:
            throttled = limiter.check(request)
            if throttled:
                send_status(throttled.status, throttled.meta)
                return

        # Routes registered later are tried first.
        for pattern, handler in reversed(self.routes):
            found = pattern.match(request)
            if not found:
                continue
            handler_kwargs = found.groupdict()
            break
        else:
            handler = self.default_callback
            handler_kwargs = {}

        result = handler(request, **handler_kwargs)
        send_status(result.status, result.meta)

        # A single chunk is passed through as-is; anything else that is
        # truthy is treated as an iterable of chunks.
        payload = result.body
        if isinstance(payload, (str, bytes, Deferred)):
            yield payload
        elif payload:
            yield from payload

    def route(
        self,
        path: str = ".*",
        scheme: str = "gemini",
        hostname: str | None = None,
        strict_hostname: bool = True,
        strict_port: bool = True,
        strict_trailing_slash: bool = False,
    ) -> typing.Callable[[RouteHandler], RouteHandler]:
        """
        Decorator that registers a function as the handler for a URL pattern.

            app = JetforceApplication()

            @app.route('/my-path')
            def my_path(request):
                return Response(Status.SUCCESS, 'text/plain', 'Hello world!')

        The decorated function is returned unchanged.
        """
        pattern = RoutePattern(
            path=path,
            scheme=scheme,
            hostname=hostname,
            strict_hostname=strict_hostname,
            strict_port=strict_port,
            strict_trailing_slash=strict_trailing_slash,
        )

        def register(func: RouteHandler) -> RouteHandler:
            entry = (pattern, func)
            self.routes.append(entry)
            return func

        return register

    def default_callback(self, request: Request, **_: typing.Any) -> Response:
        """
        Fallback handler used when no registered route matches the request.
        """
        not_found = Response(Status.PERMANENT_FAILURE, "Not Found")
        return not_found