From 6aa72b3c8717f32e62c772aeed327d3cd9a6fa65 Mon Sep 17 00:00:00 2001 From: Ilion Beyst Date: Wed, 15 Dec 2021 22:40:55 +0100 Subject: [PATCH] gracefully handle invalid login credentials --- backend/src/db/users.rs | 32 +++++++++++---------- backend/src/main.rs | 2 +- backend/src/routes/users.rs | 40 +++++++++++++++++++-------- backend/tests/{common.rs => login.rs} | 33 +++++++++++++++++++++- 4 files changed, 79 insertions(+), 28 deletions(-) rename backend/tests/{common.rs => login.rs} (75%) diff --git a/backend/src/db/users.rs b/backend/src/db/users.rs index c06e5b3..29cee88 100644 --- a/backend/src/db/users.rs +++ b/backend/src/db/users.rs @@ -58,24 +58,26 @@ pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResul } pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option { - let user = users::table + users::table .filter(users::username.eq(&credentials.username)) .first::(db_conn) - .unwrap(); + .optional() + .unwrap() + .and_then(|user| { + let password_matches = argon2::verify_raw( + credentials.password.as_bytes(), + &user.password_salt, + &user.password_hash, + &argon2_config(), + ) + .unwrap(); - let password_matches = argon2::verify_raw( - credentials.password.as_bytes(), - &user.password_salt, - &user.password_hash, - &argon2_config(), - ) - .unwrap(); - - if password_matches { - return Some(user); - } else { - return None; - } + if password_matches { + return Some(user); + } else { + return None; + } + }) } #[test] diff --git a/backend/src/main.rs b/backend/src/main.rs index 65be48d..3c0efa8 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -3,6 +3,6 @@ extern crate rocket; extern crate mozaic4_backend; #[launch] -fn launch() -> _ { +fn launch() -> rocket::Rocket { mozaic4_backend::rocket() } diff --git a/backend/src/routes/users.rs b/backend/src/routes/users.rs index 274b712..72a857f 100644 --- a/backend/src/routes/users.rs +++ b/backend/src/routes/users.rs @@ -7,7 +7,8 @@ use rocket::serde::json::Json; use serde::{Deserialize, Serialize}; use rocket::http::Status; -use rocket::request::{self, FromRequest, Outcome, Request}; +use rocket::request::{FromRequest, Outcome, Request}; +use rocket::response::status; #[derive(Debug)] pub enum AuthTokenError { @@ -23,17 +24,25 @@ impl<'r> FromRequest<'r> for User { async fn from_request(request: &'r Request<'_>) -> Outcome { let keys: Vec<_> = request.headers().get("Authorization").collect(); - let token = match keys.len() { + let auth_header = match keys.len() { 0 => return Outcome::Failure((Status::BadRequest, AuthTokenError::Missing)), - 1 => keys[0].to_string(), + 1 => keys[0], _ => return Outcome::Failure((Status::BadRequest, AuthTokenError::BadCount)), }; + + let token = match auth_header.strip_prefix("Bearer ") { + Some(token) => token.to_string(), + None => return Outcome::Failure((Status::BadRequest, AuthTokenError::Invalid)), + }; + let db = request.guard::().await.unwrap(); - let (_session, user) = db + let res = db .run(move |conn| sessions::find_user_by_session(&token, conn)) - .await - .unwrap(); - Outcome::Success(user) + .await; + match res { + Ok((_session, user)) => Outcome::Success(user), + Err(_) => Outcome::Failure((Status::Unauthorized, AuthTokenError::Invalid)), + } } } @@ -79,7 +88,10 @@ pub struct LoginParams { } #[post("/login", data = "")] -pub async fn login(db_conn: DbConn, params: Json) -> String { +pub async fn login( + db_conn: DbConn, + params: Json, +) -> Result> { db_conn .run(move |conn| { let credentials = Credentials { @@ -87,9 +99,15 @@ pub async fn login(db_conn: DbConn, params: Json) -> String { password: ¶ms.password, }; // TODO: handle failures - let user = users::authenticate_user(&credentials, conn).unwrap(); - let session = sessions::create_session(&user, conn); - return session.token; + let authenticated = users::authenticate_user(&credentials, conn); + + match authenticated { + None => Err(status::Forbidden(Some("invalid auth"))), + Some(user) => { + let session = sessions::create_session(&user, conn); + Ok(session.token) + } + } }) .await } diff --git a/backend/tests/common.rs b/backend/tests/login.rs similarity index 75% rename from backend/tests/common.rs rename to backend/tests/login.rs index 8ab68a1..9c70af2 100644 --- a/backend/tests/common.rs +++ b/backend/tests/login.rs @@ -37,6 +37,22 @@ macro_rules! run_test { }}; } +pub struct BearerAuth { + token: String, +} + +impl BearerAuth { + pub fn new(token: String) -> Self { + Self { token } + } +} + +impl<'a> Into> for BearerAuth { + fn into(self) -> Header<'a> { + Header::new("Authorization", format!("Bearer {}", self.token)) + } +} + #[test] fn test_registration() { run_test!(|client, _conn| { @@ -62,7 +78,7 @@ fn test_registration() { let response = client .get("/users/me") - .header(Header::new("Authorization", token)) + .header(BearerAuth::new(token)) .dispatch() .await; @@ -73,3 +89,18 @@ fn test_registration() { assert_eq!(json["username"], "piepkonijn"); }); } + +#[test] +fn test_reject_invalid_credentials() { + run_test!(|client, _conn| { + let response = client + .post("/login") + .header(ContentType::JSON) + .body(r#"{"username": "piepkonijn", "password": "letmeinplease"}"#) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Forbidden); + // assert_eq!(response.content_type(), Some(ContentType::JSON)); + }); +}