💾 Archived View for alchemi.dev › en › projects › kochab › files › src › user_management › routes.rs captured on 2022-07-16 at 16:43:54.

View Raw

More Information

-=-=-=-=-=-=-

use anyhow::Result;
use serde::{Serialize, de::DeserializeOwned};

#[cfg(feature = "dashmap")]
use dashmap::DashMap;
#[cfg(not(feature = "dashmap"))]
use std::collections::HashMap;
#[cfg(not(feature = "dashmap"))]
use std::sync::RwLock;

use std::future::Future;

use crate::{Document, Request, Response};
use crate::types::document::HeadingLevel;
use crate::user_management::{
    User,
    RegisteredUser,
    UserManagerError,
    user::NotSignedInUser,
};

/// Import this trait to use [`add_um_routes()`](Self::add_um_routes())
///
/// The trait is sealed: its `private::Sealed` supertrait lives in a module that is not
/// exported, so the trait can only be implemented inside this crate.
pub trait UserManagementRoutes: private::Sealed {
    /// Add pre-configured routes to the server to handle authentication
    ///
    /// Specifically, the following routes are added:
    /// * `/account`, the main settings & login page
    /// * `/account/askcert`, a page which always prompts for a certificate
    /// * `/account/register`, for users to register a new account
    /// * `/account/clients`, a static informational page
    /// * `/account/login`, for users to link their certificate to an existing account
    ///   (only with the `user_management_advanced` feature)
    /// * `/account/password`, to set or change the user's password
    ///   (only with the `user_management_advanced` feature)
    ///
    /// If this method is used, no more routes should be added under `/account`.  If you
    /// would like to direct a user to login from your application, you should send them
    /// to `/account`, which will start the login/registration flow.
    ///
    /// Any path after `/account/` is remembered as the page to return to: a user who
    /// visits `/account/some/page` is offered a link back to `/some/page` once their
    /// account has been created or linked.  Without such a path, they are sent to `/`.
    fn add_um_routes<UserData: Serialize + DeserializeOwned + Default + 'static>(self) -> Self;

    /// Add a special route that requires users to be logged in
    ///
    /// In addition to the normal [`Request`], your handler will receive a copy of the
    /// [`RegisteredUser`] for the current user.  If a user tries to connect to the page
    /// without logging in, they will be prompted to register or link an account, and the
    /// requested path is remembered so they can be sent back to it afterwards.
    ///
    /// To use this method, ensure that [`add_um_routes()`](Self::add_um_routes()) has
    /// also been called.
    fn add_authenticated_route<UserData, Handler, F>(
        self,
        path: &'static str,
        handler: Handler,
    ) -> Self
    where
        UserData: Serialize + DeserializeOwned + 'static + Send + Sync,
        Handler: Clone + Send + Sync + 'static + Fn(Request, RegisteredUser<UserData>) -> F,
        F: Send + Sync + 'static + Future<Output = Result<Response>>;

    /// Add a special route that requires users to be logged in AND takes input
    ///
    /// Like with [`add_authenticated_route()`](Self::add_authenticated_route()), this
    /// prompts the user to log in if they haven't already, but additionally prompts the
    /// user for input before running the handler with both the user object and the input
    /// they provided.
    ///
    /// To a user, this might look something like this:
    /// * Click a link to `/your/route`
    /// * See a screen asking you to sign in or create an account
    /// * Create a new account, and return to the app.
    /// * Now, clicking the link shows the prompt provided.
    /// * After entering some value, the user receives the response from the handler.
    ///
    /// For a user who is already logged in, this will just look like a normal input route,
    /// where they enter some query and see a page.  This method just takes the burden of
    /// having to check if the user sent a query string and respond with an INPUT response
    /// if not.
    ///
    /// To use this method, ensure that [`add_um_routes()`](Self::add_um_routes()) has
    /// also been called.
    fn add_authenticated_input_route<UserData, Handler, F>(
        self,
        path: &'static str,
        prompt: &'static str,
        handler: Handler,
    ) -> Self
    where
        UserData: Serialize + DeserializeOwned + 'static + Send + Sync,
        Handler: Clone + Send + Sync + 'static + Fn(Request, RegisteredUser<UserData>, String) -> F,
        F: Send + Sync + 'static + Future<Output = Result<Response>>;
}

impl UserManagementRoutes for crate::Server {
    /// Add pre-configured routes to the server to handle authentication
    ///
    /// See [`UserManagementRoutes::add_um_routes()`]
    fn add_um_routes<UserData: Serialize + DeserializeOwned + Default + 'static>(self) -> Self {
        let clients_page = Response::success_gemini(include_str!("pages/clients.gmi"));

        // `mut` is only needed when `user_management_advanced` adds the extra routes below.
        #[allow(unused_mut)]
        let mut modified_self = self.add_route("/account", handle_base::<UserData>)
            .add_route("/account/askcert", handle_ask_cert::<UserData>)
            .add_route("/account/register", handle_register::<UserData>)
            .add_route("/account/clients", clients_page);

        #[cfg(feature = "user_management_advanced")] {
            modified_self = modified_self
                .add_route("/account/login", handle_login::<UserData>)
                .add_route("/account/password", handle_password::<UserData>);
        }

        modified_self
    }

    /// Add a special route that requires users to be logged in
    ///
    /// See [`UserManagementRoutes::add_authenticated_route()`]
    fn add_authenticated_route<UserData, Handler, F>(
        self,
        path: &'static str,
        handler: Handler,
    ) -> Self
    where
        UserData: Serialize + DeserializeOwned + 'static + Send + Sync,
        Handler: Clone + Send + Sync + 'static + Fn(Request, RegisteredUser<UserData>) -> F,
        F: Send + Sync + 'static + Future<Output = Result<Response>>
    {
        self.add_route(path, move |request: Request| {
            // The route closure may run once per request; each `async move` block needs
            // its own owned copy of the handler.
            let handler = handler.clone();
            async move {
                // The requested path, remembered so the user can come back here after
                // signing in.
                let segments = request.path_segments();
                let segments = segments.iter().map(String::as_ref).collect::<Vec<&str>>();
                Ok(match request.user::<UserData>()? {
                    User::Unauthenticated => {
                        render_unauth_page(segments)
                    },
                    User::NotSignedIn(user) => {
                        save_redirect(&user, segments);
                        Response::success_gemini(NSI)
                    },
                    User::SignedIn(user) => {
                        handler(request, user).await?
                    },
                })
            }
        })
    }

    /// Add a special route that requires users to be logged in AND takes input
    ///
    /// See [`UserManagementRoutes::add_authenticated_input_route()`]
    fn add_authenticated_input_route<UserData, Handler, F>(
        self,
        path: &'static str,
        prompt: &'static str,
        handler: Handler,
    ) -> Self
    where
        UserData: Serialize + DeserializeOwned + 'static + Send + Sync,
        Handler: Clone + Send + Sync + 'static + Fn(Request, RegisteredUser<UserData>, String) -> F,
        F: Send + Sync + 'static + Future<Output = Result<Response>>
    {
        self.add_authenticated_route(path, move |request, user| {
            // One clone per request is enough; the async block owns it from here on.
            let handler = handler.clone();
            async move {
                match request.input().map(str::to_owned) {
                    Some(input) => handler(request, user, input).await,
                    // No query string yet: ask the client for one.
                    None => Ok(Response::input(prompt)),
                }
            }
        })
    }
}

// Body of the page shown to a client whose certificate is not yet linked to an account
// (the `User::NotSignedIn` case in the handlers below).  A different page file is used
// depending on whether the password-based routes (`user_management_advanced`) are
// compiled in.
#[cfg(feature = "user_management_advanced")]
const NSI: &str = include_str!("pages/nsi.gmi");
#[cfg(not(feature = "user_management_advanced"))]
const NSI: &str = include_str!("pages/nopass/nsi.gmi");

// Post-sign-in redirect targets, keyed by the 32-byte certificate identifier of the
// client they were recorded for.  `save_redirect` inserts entries and `get_redirect`
// reads them; nothing in this module ever removes an entry, so the map only grows for
// the lifetime of the process.
// TODO periodically clean these
#[cfg(feature = "dashmap")]
lazy_static::lazy_static! {
    static ref PENDING_REDIRECTS: DashMap<[u8; 32], String> = DashMap::new();
}

// Same table without the `dashmap` feature: a plain `HashMap` behind a read/write lock.
#[cfg(not(feature = "dashmap"))]
lazy_static::lazy_static! {
    static ref PENDING_REDIRECTS: RwLock<HashMap<[u8; 32], String>> =
        RwLock::new(HashMap::new());
}

/// Handler for `/account`, the entry point of the login/registration flow.
///
/// Any path segments after `/account` name the page to return to after signing in.
/// * Without a certificate, the "unauthenticated" page is shown, carrying those segments.
/// * With a certificate not yet linked to an account, the segments are saved as the
///   pending redirect and the "not signed in" page is shown.
/// * With a signed-in certificate, the settings menu is shown.
async fn handle_base<UserData: Serialize + DeserializeOwned>(request: Request) -> Result<Response> {
    let trailing: Vec<&str> = request
        .trailing_segments()
        .iter()
        .map(String::as_str)
        .collect();

    let response = match request.user::<UserData>()? {
        User::Unauthenticated => render_unauth_page(trailing),
        User::NotSignedIn(usr) => {
            save_redirect(&usr, trailing);
            Response::success_gemini(NSI)
        }
        User::SignedIn(user) => render_settings_menu(user),
    };
    Ok(response)
}

/// Handler for `/account/askcert`, which always asks the client for a certificate.
///
/// * Without a certificate, responds with a client-certificate-required status.
/// * With a certificate not yet linked to an account, saves the trailing path segments as
///   the pending redirect and shows the askcert success page (the `nopass` variant when
///   `user_management_advanced` is disabled).
/// * With a signed-in certificate, shows a page naming the existing account and linking
///   to the pending redirect.
async fn handle_ask_cert<UserData: Serialize + DeserializeOwned>(request: Request) -> Result<Response> {
    #[cfg(feature = "user_management_advanced")]
    const SUCCESS_PAGE: &str = include_str!("pages/askcert/success.gmi");
    #[cfg(not(feature = "user_management_advanced"))]
    const SUCCESS_PAGE: &str = include_str!("pages/nopass/askcert/success.gmi");

    let response = match request.user::<UserData>()? {
        User::Unauthenticated => Response::client_certificate_required(
            "Please select a client certificate to proceed.",
        ),
        User::NotSignedIn(nsi) => {
            let trailing: Vec<&str> = request
                .trailing_segments()
                .iter()
                .map(String::as_str)
                .collect();
            save_redirect(&nsi, trailing);
            Response::success_gemini(SUCCESS_PAGE)
        }
        User::SignedIn(user) => Response::success_gemini(format!(
            include_str!("pages/askcert/exists.gmi"),
            username = user.username(),
            redirect = get_redirect(&user),
        )),
    };
    Ok(response)
}

/// Handler for `/account/register`: asks for a username and registers the current
/// certificate under it.
///
/// * Without a certificate, shows the "unauthenticated" page.
/// * With a certificate not yet linked to an account, prompts for a username.  When the
///   username is provided, attempts to register it and reports either success (with a link
///   to the pending redirect) or that the name is already taken.
/// * With a signed-in certificate, shows the settings menu.
///
/// # Errors
/// Propagates errors from looking up the user, and any registration error other than
/// [`UserManagerError::UsernameNotUnique`].
async fn handle_register<UserData: Serialize + DeserializeOwned + Default>(request: Request) -> Result<Response> {
    Ok(match request.user::<UserData>()? {
        User::Unauthenticated => {
            render_unauth_page(&[""])
        },
        User::NotSignedIn(nsi) => {
            if let Some(username) = request.input() {
                match nsi.register::<UserData>(username.to_owned()) {
                    Err(UserManagerError::UsernameNotUnique) => {
                        // The same page is served with and without `user_management_advanced`.
                        // NOTE(review): other pages here have `pages/nopass/...` variants;
                        // confirm this page makes sense when password login is disabled.
                        Response::success_gemini(format!(
                            include_str!("pages/register/exists.gmi"),
                            username = username,
                        ))
                    },
                    Ok(user) => {
                        #[cfg(feature = "user_management_advanced")] {
                            Response::success_gemini(format!(
                                include_str!("pages/register/success.gmi"),
                                username = username,
                                redirect = get_redirect(&user),
                            ))
                        }
                        #[cfg(not(feature = "user_management_advanced"))] {
                            Response::success_gemini(format!(
                                include_str!("pages/nopass/register/success.gmi"),
                                username = username,
                                redirect = get_redirect(&user),
                            ))
                        }
                    },
                    Err(e) => return Err(e.into()),
                }
            } else {
                Response::input("Please pick a username")
            }
        },
        User::SignedIn(user) => {
            render_settings_menu(user)
        },
    })
}

/// Handler for `/account/login`, which links the current certificate to an existing
/// account by checking that account's password.
///
/// This takes two INPUT prompts:
/// 1. `/account/login` asks for a username and redirects to `/account/login/<username>`.
/// 2. `/account/login/<username>` asks for the password and tries to attach the current
///    certificate to that account.
///
/// Clients without a certificate see the "unauthenticated" page; clients that are already
/// signed in see the settings menu.
///
/// # Errors
/// Propagates errors from looking up the user, and any attach error other than
/// [`UserManagerError::PasswordNotSet`].
#[cfg(feature = "user_management_advanced")]
async fn handle_login<UserData: Serialize + DeserializeOwned + Default>(request: Request) -> Result<Response> {
    Ok(match request.user::<UserData>()? {
        User::Unauthenticated => {
            render_unauth_page(&[""])
        },
        User::NotSignedIn(nsi) => {
            if let Some(username) = request.trailing_segments().first() {
                // Second step: the username is in the path, so any input is the password.
                if let Some(password) = request.input() {
                    match nsi.attach::<UserData>(username, Some(password.as_bytes())) {
                        // An account with no password set and a failed attach (`Ok(None)`)
                        // are reported with the same "wrong" page.
                        Err(UserManagerError::PasswordNotSet) | Ok(None) => {
                            Response::success_gemini(format!(
                                include_str!("pages/login/wrong.gmi"),
                                username = username,
                            ))
                        },
                        Ok(Some(user)) => {
                            Response::success_gemini(format!(
                                include_str!("pages/login/success.gmi"),
                                username = username,
                                redirect = get_redirect(&user),
                            ))
                        },
                        Err(e) => return Err(e.into()),
                    }
                } else {
                    Response::sensitive_input("Please enter your password")
                }
            } else if let Some(username) = request.input() {
                // First step: move the username into the path so the next prompt can ask
                // for the password.
                // NOTE(review): the username is inserted into the URL without escaping;
                // confirm usernames cannot contain `/`, `?`, whitespace or line breaks.
                Response::redirect_temporary(
                    format!("/account/login/{}", username).as_str()
                )
            } else {
                Response::input("Please enter your username")
            }
        },
        User::SignedIn(user) => {
            render_settings_menu(user)
        },
    })
}

/// Handler for `/account/password`, which lets a signed-in user set or change their
/// password.
///
/// * Without a certificate, shows the "unauthenticated" page.
/// * With a certificate not yet linked to an account, shows the "not signed in" page.
/// * With a signed-in certificate, prompts (as sensitive input) for a password and stores
///   it.  An empty submission is treated like no submission and prompts again.
///
/// # Errors
/// Propagates errors from looking up the user and from storing the password.
#[cfg(feature = "user_management_advanced")]
async fn handle_password<UserData: Serialize + DeserializeOwned + Default>(request: Request) -> Result<Response> {
    Ok(match request.user::<UserData>()? {
        User::Unauthenticated => {
            render_unauth_page(&[""])
        },
        User::NotSignedIn(nsi) => {
            // Joining `[""]` yields an empty path, so this records no redirect and leaves
            // any previously saved one in place.
            save_redirect(&nsi, &[""]);
            Response::success_gemini(NSI)
        },
        User::SignedIn(mut user) => {
            // Never store an empty string as the password; ask again instead.
            if let Some(password) = request.input().filter(|password| !password.is_empty()) {
                user.set_password(password)?;
                Response::success_gemini(include_str!("pages/password/success.gmi"))
            } else {
                let qualifier = if user.has_password() { "new " } else { "" };
                Response::sensitive_input(format!("Please enter a {}password", qualifier))
            }
        },
    })
}

/// Build the account settings page for a signed-in user.
///
/// The page greets the user by username and links back to their pending redirect (or
/// `/`).  With `user_management_advanced`, it also explains whether a password is set
/// and links to `/account/password` to set or change it.
fn render_settings_menu<UserData: Serialize + DeserializeOwned>(
    user: RegisteredUser<UserData>
) -> Response {
    let greeting = format!("Welcome {}!", user.username());
    let back_link = get_redirect(&user);

    let mut page = Document::new();
    page.add_heading(HeadingLevel::H1, "User Settings");
    page.add_blank_line();
    page.add_text(&greeting);
    page.add_blank_line();
    page.add_link(back_link.as_str(), "Back to the app");
    page.add_blank_line();

    #[cfg(feature = "user_management_advanced")]
    {
        let (blurb, link_label) = if user.has_password() {
            (
                concat!(
                "You currently have a password set.  This can be used to link any new",
                " certificates or clients to your account.  If you don't remember your",
                " password, or would like to change it, you may do so here.",
                ),
                "Change password",
            )
        } else {
            (
                concat!(
                "You don't currently have a password set!  Without a password, you cannot",
                " link any new certificates to your account, and if you lose your current",
                " client or certificate, you won't be able to recover your account.",
                ),
                "Set password",
            )
        };
        page.add_text(blurb)
            .add_blank_line()
            .add_link("/account/password", link_label);
    }

    page.into()
}

/// Render the page shown to clients that have not presented a certificate.
///
/// The `redirect` segments are joined with `/` and substituted for the `{redirect}`
/// placeholder in `pages/unauth.gmi`.
fn render_unauth_page<'a>(
    redirect: impl AsRef<[&'a str]>,
) -> Response {
    let joined_path = redirect.as_ref().join("/");
    let body = format!(include_str!("pages/unauth.gmi"), redirect = joined_path);
    Response::success_gemini(body)
}

/// Remember the page `user` should be sent back to once they have signed in.
///
/// The segments are joined with `/` and prefixed with `/`.  If the joined path is empty,
/// nothing is stored and any existing entry for the certificate is kept.
fn save_redirect<'a>(
    user: &NotSignedInUser,
    redirect: impl AsRef<[&'a str]>,
) {
    let joined = redirect.as_ref().join("/");
    if joined.is_empty() {
        return;
    }
    let target = format!("/{}", joined);

    debug!("Added \"{}\" as redirect for cert {:x?}", target, &user.certificate);

    #[cfg(feature = "dashmap")]
    PENDING_REDIRECTS.insert(user.certificate, target);
    #[cfg(not(feature = "dashmap"))]
    PENDING_REDIRECTS.write().unwrap().insert(user.certificate, target);
}

/// Look up where `user` should be sent after signing in.
///
/// This returns the path stored by `save_redirect` for the user's active certificate.
/// It returns `/` when no path is stored, or when the user has no active certificate,
/// instead of panicking inside a request handler.
fn get_redirect<T: Serialize + DeserializeOwned>(user: &RegisteredUser<T>) -> String {
    let cert = match user.active_certificate() {
        Some(cert) => cert,
        // Nothing to key the lookup on: fall back to the default redirect.
        None => return "/".to_string(),
    };

    #[cfg(feature = "dashmap")]
    let maybe_redir = PENDING_REDIRECTS.get(cert).map(|r| r.clone());
    // The read guard is a temporary, so the lock is released at the end of this statement.
    #[cfg(not(feature = "dashmap"))]
    let maybe_redir = PENDING_REDIRECTS.read().unwrap().get(cert).cloned();

    let redirect = maybe_redir.unwrap_or_else(|| "/".to_string());
    debug!("Accessed redirect to \"{}\" for cert {:x?}", redirect, cert);
    redirect
}

mod private {
    pub trait Sealed {}
    impl Sealed for crate::Server {}
}