mod session; use std::str::FromStr; use session::Session; use actix_csrf::extractor::{CsrfCookie, CsrfToken}; use actix_csrf::Csrf; use actix_session::CookieSession; use actix_web::cookie::SameSite; use actix_web::error::InternalError; use actix_web::http::header::LOCATION; use actix_web::http::{Method, StatusCode}; use actix_web::middleware::Logger; use actix_web::web::{Form, Query}; use actix_web::{get, post, App, HttpRequest, HttpResponse, HttpServer, Responder}; use handlebars::Handlebars; use lettre::EmailAddress; use once_cell::sync::{Lazy, OnceCell}; use rand::prelude::StdRng; use serde::de::Visitor; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::pwhash::argon2id13::{self, HashedPassword, HASHEDPASSWORDBYTES}; use sqlx::sqlite::SqliteConnectOptions; use sqlx::SqlitePool; use tracing::error; static TEMPLATE_ENGINE: Lazy = Lazy::new(|| { let mut handlebars = Handlebars::new(); handlebars.set_strict_mode(true); handlebars.set_dev_mode(true); handlebars .register_templates_directory(".hbs", "src/templates") .expect("failed to load template directory"); handlebars }); static DB_POOL: OnceCell = OnceCell::new(); #[actix_rt::main] async fn main() -> std::io::Result<()> { tracing_subscriber::fmt::init(); sodiumoxide::init().unwrap(); let db = { let db_options = SqliteConnectOptions::from_str("db.sqlite") .unwrap() .create_if_missing(true); let pool = SqlitePool::connect_with(db_options).await.unwrap(); sqlx::query_file!("db_queries/init.sql") .execute(&pool) .await .unwrap(); pool }; DB_POOL.set(db).unwrap(); HttpServer::new(|| { App::new() .wrap(Logger::default()) .wrap( CookieSession::private(&[0; 32]) .name("session") .path("/") .secure(true) .http_only(true) .same_site(SameSite::Strict), ) .wrap( Csrf::::new() .set_cookie(Method::GET, "/login") .validate_cookie(Method::POST, "/login") .set_cookie(Method::GET, "/register") .validate_cookie(Method::POST, "/register"), ) .service(index) .service(login_ui) .service(login) .service(register_ui) .service(register) .service(account_ui) .service(logout) .service(actix_files::Files::new("/static", "src/static")) }) .bind("127.0.0.1:8080")? .run() .await } #[derive(Deserialize, Serialize)] enum SessionState { Anonymous, } #[get("/")] async fn index() -> impl Responder { match TEMPLATE_ENGINE.render("index", &()) { Ok(resp) => Ok(HttpResponse::Ok().body(resp)), Err(e) => { error!("{}", e); Err(InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR)) } } } #[derive(Deserialize)] struct LoginQuery { error: Option, } #[get("/login")] async fn login_ui(csrf: CsrfToken, mut query: Query) -> impl Responder { #[derive(Serialize)] struct TemplateArgs { error: Option, csrf: CsrfToken, } match TEMPLATE_ENGINE.render( "login", &TemplateArgs { error: query.error.take(), csrf, }, ) { Ok(resp) => Ok(HttpResponse::Ok().body(resp)), Err(e) => { error!("{}", e); Err(InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR)) } } } #[derive(Deserialize)] struct Login { csrf: CsrfToken, email: EmailAddress, password: Password, } #[derive(Serialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] struct Password(String); impl<'de> Deserialize<'de> for Password { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { use serde::de::{Error, Unexpected}; struct SecretDeserializer; impl<'de> Visitor<'de> for SecretDeserializer { type Value = Password; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("a password between 8 and 64 bytes") } fn visit_str(self, v: &str) -> Result { if v.len() < 8 || v.len() > 64 { println!("password failed"); return Err(Error::invalid_value( Unexpected::Str("password with invalid size"), &"a password between 8 and 64 bytes", )); } Ok(Password(v.to_owned())) } } deserializer.deserialize_string(SecretDeserializer) } } #[post("/login")] async fn login(csrf_cookie: CsrfCookie, form: Form, session: Session) -> impl Responder { if !csrf_cookie.validate(form.csrf.as_ref()) { return HttpResponse::BadRequest().finish(); } let email: &str = form.email.as_ref(); let verified = sqlx::query!("SELECT password FROM users WHERE email = ?", email) .fetch_one(DB_POOL.get().expect("db connection to be set")) .await; if let Ok(record) = verified { let verified = argon2id13::pwhash_verify( &HashedPassword::from_slice(&record.password).unwrap(), form.password.0.as_bytes(), ); if verified { let redirect_to = session.get_redirect_url(); session.init(&form.email); let mut resp = HttpResponse::SeeOther(); if let Some(path) = redirect_to { resp.insert_header((LOCATION, path.to_string())); } else { resp.insert_header((LOCATION, "/account")); } return resp.finish(); } } else { // To guard against timing attacks, we'll construct a fake password to // hash. We won't even check if it's successful, we just need to compute // the hash. Since we don't check the result, this must be in a separate // branch from the success branch, else it's possible to actually // succeed from a bogey hash. let mut data = [0_u8; HASHEDPASSWORDBYTES]; assert!(argon2id13::STRPREFIX.len() < HASHEDPASSWORDBYTES); for (i, c) in argon2id13::STRPREFIX.iter().enumerate() { data[i] = *c; } // Rust shouldn't optimize this out since it ultimately calls out to // a C function, so it shouldn't find out that the function is pure. argon2id13::pwhash_verify( &HashedPassword::from_slice(&data).unwrap(), form.password.0.as_bytes(), ); } HttpResponse::SeeOther() .insert_header((LOCATION, "/login?error=true")) .finish() } #[get("/register")] async fn register_ui(csrf: CsrfToken) -> impl Responder { #[derive(Serialize)] struct TemplateArgs { csrf: CsrfToken, } match TEMPLATE_ENGINE.render("register", &TemplateArgs { csrf }) { Ok(resp) => Ok(HttpResponse::Ok().body(resp)), Err(e) => { error!("{}", e); Err(InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR)) } } } #[derive(Deserialize)] struct RegistrationInfo { csrf: CsrfToken, email: EmailAddress, password: Password, } #[post("/register")] async fn register( csrf_cookie: CsrfCookie, form: Form, session: Session, ) -> impl Responder { if !csrf_cookie.validate(form.csrf.as_ref()) { return HttpResponse::BadRequest().finish(); } let hashed = { let res = argon2id13::pwhash( form.password.0.as_bytes(), argon2id13::OPSLIMIT_INTERACTIVE, argon2id13::MEMLIMIT_INTERACTIVE, ); if let Ok(res) = res { res } else { return HttpResponse::InternalServerError().finish(); } }; let hashed = hashed.as_ref(); let email: &str = form.email.as_ref(); let insert_res = sqlx::query!( "INSERT INTO users (email, password) VALUES (?, ?)", email, hashed, ) .execute(DB_POOL.get().expect("db connection to be set")) .await; if insert_res.is_ok() { session.init(&form.email); HttpResponse::SeeOther() .insert_header((LOCATION, "/account")) .finish() } else { todo!() } } #[get("/account")] async fn account_ui(req: HttpRequest, session: Session) -> impl Responder { if let Err(error) = session.validate_or_redirect(req.uri()) { return error; } HttpResponse::Ok().body(format!("{:?}", session.email())) } #[get("/logout")] async fn logout(session: Session) -> impl Responder { // It should be ok to logout without a CSRF token; the worst case is that // the user is logged out, which is fail-safe. session.purge(); HttpResponse::SeeOther() .append_header((LOCATION, "/")) .finish() }