gracefully handle invalid login credentials
This commit is contained in:
parent
13cdbc7ff7
commit
6aa72b3c87
4 changed files with 79 additions and 28 deletions
|
@ -58,24 +58,26 @@ pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResul
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option<User> {
|
pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option<User> {
|
||||||
let user = users::table
|
users::table
|
||||||
.filter(users::username.eq(&credentials.username))
|
.filter(users::username.eq(&credentials.username))
|
||||||
.first::<User>(db_conn)
|
.first::<User>(db_conn)
|
||||||
.unwrap();
|
.optional()
|
||||||
|
.unwrap()
|
||||||
|
.and_then(|user| {
|
||||||
|
let password_matches = argon2::verify_raw(
|
||||||
|
credentials.password.as_bytes(),
|
||||||
|
&user.password_salt,
|
||||||
|
&user.password_hash,
|
||||||
|
&argon2_config(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let password_matches = argon2::verify_raw(
|
if password_matches {
|
||||||
credentials.password.as_bytes(),
|
return Some(user);
|
||||||
&user.password_salt,
|
} else {
|
||||||
&user.password_hash,
|
return None;
|
||||||
&argon2_config(),
|
}
|
||||||
)
|
})
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
if password_matches {
|
|
||||||
return Some(user);
|
|
||||||
} else {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
@ -3,6 +3,6 @@ extern crate rocket;
|
||||||
extern crate mozaic4_backend;
|
extern crate mozaic4_backend;
|
||||||
|
|
||||||
#[launch]
|
#[launch]
|
||||||
fn launch() -> _ {
|
fn launch() -> rocket::Rocket<rocket::Build> {
|
||||||
mozaic4_backend::rocket()
|
mozaic4_backend::rocket()
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,8 @@ use rocket::serde::json::Json;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use rocket::http::Status;
|
use rocket::http::Status;
|
||||||
use rocket::request::{self, FromRequest, Outcome, Request};
|
use rocket::request::{FromRequest, Outcome, Request};
|
||||||
|
use rocket::response::status;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum AuthTokenError {
|
pub enum AuthTokenError {
|
||||||
|
@ -23,17 +24,25 @@ impl<'r> FromRequest<'r> for User {
|
||||||
|
|
||||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
let keys: Vec<_> = request.headers().get("Authorization").collect();
|
let keys: Vec<_> = request.headers().get("Authorization").collect();
|
||||||
let token = match keys.len() {
|
let auth_header = match keys.len() {
|
||||||
0 => return Outcome::Failure((Status::BadRequest, AuthTokenError::Missing)),
|
0 => return Outcome::Failure((Status::BadRequest, AuthTokenError::Missing)),
|
||||||
1 => keys[0].to_string(),
|
1 => keys[0],
|
||||||
_ => return Outcome::Failure((Status::BadRequest, AuthTokenError::BadCount)),
|
_ => return Outcome::Failure((Status::BadRequest, AuthTokenError::BadCount)),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let token = match auth_header.strip_prefix("Bearer ") {
|
||||||
|
Some(token) => token.to_string(),
|
||||||
|
None => return Outcome::Failure((Status::BadRequest, AuthTokenError::Invalid)),
|
||||||
|
};
|
||||||
|
|
||||||
let db = request.guard::<DbConn>().await.unwrap();
|
let db = request.guard::<DbConn>().await.unwrap();
|
||||||
let (_session, user) = db
|
let res = db
|
||||||
.run(move |conn| sessions::find_user_by_session(&token, conn))
|
.run(move |conn| sessions::find_user_by_session(&token, conn))
|
||||||
.await
|
.await;
|
||||||
.unwrap();
|
match res {
|
||||||
Outcome::Success(user)
|
Ok((_session, user)) => Outcome::Success(user),
|
||||||
|
Err(_) => Outcome::Failure((Status::Unauthorized, AuthTokenError::Invalid)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +88,10 @@ pub struct LoginParams {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/login", data = "<params>")]
|
#[post("/login", data = "<params>")]
|
||||||
pub async fn login(db_conn: DbConn, params: Json<LoginParams>) -> String {
|
pub async fn login(
|
||||||
|
db_conn: DbConn,
|
||||||
|
params: Json<LoginParams>,
|
||||||
|
) -> Result<String, status::Forbidden<&'static str>> {
|
||||||
db_conn
|
db_conn
|
||||||
.run(move |conn| {
|
.run(move |conn| {
|
||||||
let credentials = Credentials {
|
let credentials = Credentials {
|
||||||
|
@ -87,9 +99,15 @@ pub async fn login(db_conn: DbConn, params: Json<LoginParams>) -> String {
|
||||||
password: ¶ms.password,
|
password: ¶ms.password,
|
||||||
};
|
};
|
||||||
// TODO: handle failures
|
// TODO: handle failures
|
||||||
let user = users::authenticate_user(&credentials, conn).unwrap();
|
let authenticated = users::authenticate_user(&credentials, conn);
|
||||||
let session = sessions::create_session(&user, conn);
|
|
||||||
return session.token;
|
match authenticated {
|
||||||
|
None => Err(status::Forbidden(Some("invalid auth"))),
|
||||||
|
Some(user) => {
|
||||||
|
let session = sessions::create_session(&user, conn);
|
||||||
|
Ok(session.token)
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,6 +37,22 @@ macro_rules! run_test {
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct BearerAuth {
|
||||||
|
token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BearerAuth {
|
||||||
|
pub fn new(token: String) -> Self {
|
||||||
|
Self { token }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Into<Header<'a>> for BearerAuth {
|
||||||
|
fn into(self) -> Header<'a> {
|
||||||
|
Header::new("Authorization", format!("Bearer {}", self.token))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_registration() {
|
fn test_registration() {
|
||||||
run_test!(|client, _conn| {
|
run_test!(|client, _conn| {
|
||||||
|
@ -62,7 +78,7 @@ fn test_registration() {
|
||||||
|
|
||||||
let response = client
|
let response = client
|
||||||
.get("/users/me")
|
.get("/users/me")
|
||||||
.header(Header::new("Authorization", token))
|
.header(BearerAuth::new(token))
|
||||||
.dispatch()
|
.dispatch()
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
@ -73,3 +89,18 @@ fn test_registration() {
|
||||||
assert_eq!(json["username"], "piepkonijn");
|
assert_eq!(json["username"], "piepkonijn");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_reject_invalid_credentials() {
|
||||||
|
run_test!(|client, _conn| {
|
||||||
|
let response = client
|
||||||
|
.post("/login")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(r#"{"username": "piepkonijn", "password": "letmeinplease"}"#)
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Forbidden);
|
||||||
|
// assert_eq!(response.content_type(), Some(ContentType::JSON));
|
||||||
|
});
|
||||||
|
}
|
Loading…
Reference in a new issue