diff --git a/backend/Cargo.toml b/backend/Cargo.toml index ed64ac3..de98df7 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -6,8 +6,12 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -rocket = { version= "0.5.0-rc.1", features = ["json"] } -diesel = { version = "1.4.4", features = ["postgres", "r2d2", "chrono"] } +tokio = { version = "1.15", features = ["full"] } +hyper = "0.14" +axum = { version = "0.4", features = ["json", "headers"] } +diesel = { version = "1.4.4", features = ["postgres", "chrono"] } +bb8 = "0.7" +bb8-diesel = "0.2" dotenv = "0.15.0" rust-argon2 = "0.8" rand = "0.8.4" @@ -18,10 +22,5 @@ serde_json = "1.0" base64 = "0.13.0" zip = "0.5" - -[dependencies.rocket_sync_db_pools] -version = "0.1.0-rc.1" -features = ["diesel_postgres_pool"] - [dev-dependencies] parking_lot = "0.11" \ No newline at end of file diff --git a/backend/src/db/bots.rs b/backend/src/db/bots.rs index d359e28..bc9cb11 100644 --- a/backend/src/db/bots.rs +++ b/backend/src/db/bots.rs @@ -2,7 +2,6 @@ use diesel::prelude::*; use serde::{Deserialize, Serialize}; use crate::schema::{bots, code_bundles}; -use crate::DbConn; use chrono; #[derive(Insertable)] diff --git a/backend/src/db/users.rs b/backend/src/db/users.rs index 0817766..663f173 100644 --- a/backend/src/db/users.rs +++ b/backend/src/db/users.rs @@ -1,4 +1,4 @@ -use crate::{schema::users, DbConn}; +use crate::schema::users; use argon2; use diesel::{prelude::*, PgConnection}; use rand::Rng; diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 8807637..665523f 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -1,10 +1,5 @@ #![feature(proc_macro_hygiene, decl_macro)] -use rocket::{Build, Rocket}; -use rocket_sync_db_pools::database; - -#[macro_use] -extern crate rocket; #[macro_use] extern crate diesel; @@ -12,27 +7,79 @@ pub mod db; pub mod routes; pub mod schema; -#[database("postgresql_database")] -pub struct DbConn(diesel::PgConnection); +use std::ops::Deref; -#[get("/")] -fn index() -> &'static str { +use axum; +use bb8::PooledConnection; +use bb8_diesel::{self, DieselConnectionManager}; +use diesel::PgConnection; + +use axum::{ + async_trait, + extract::{Extension, FromRequest, RequestParts}, + http::StatusCode, + routing::{get, post}, + AddExtensionLayer, Router, +}; + +async fn index_handler() -> &'static str { "Hello, world!" } -pub fn rocket() -> Rocket { - rocket::build() - .mount( - "/", - routes![ - index, - routes::users::register, - routes::users::login, - routes::users::current_user, - routes::bots::create_bot, - routes::bots::get_bot, - routes::bots::upload_bot_code, - ], - ) - .attach(DbConn::fairing()) +type ConnectionPool = bb8::Pool>; + +pub async fn app() -> Router { + let database_url = "postgresql://planetwars:planetwars@localhost/planetwars"; + let manager = DieselConnectionManager::::new(database_url); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); + + let app = Router::new() + .route("/", get(index_handler)) + .route("/users/register", post(routes::users::register)) + .route("/users/login", post(routes::users::login)) + .route("/users/me", get(routes::users::current_user)) + .route("/bots", post(routes::bots::create_bot)) + .route("/bots/:bot_id", get(routes::bots::get_bot)) + .route("/bots/:bot_id/upload", post(routes::bots::upload_bot_code)) + .layer(AddExtensionLayer::new(pool)); + app +} + +// we can also write a custom extractor that grabs a connection from the pool +// which setup is appropriate depends on your application +pub struct DatabaseConnection(PooledConnection<'static, DieselConnectionManager>); + +impl Deref for DatabaseConnection { + type Target = PooledConnection<'static, DieselConnectionManager>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[async_trait] +impl FromRequest for DatabaseConnection +where + B: Send, +{ + type Rejection = (StatusCode, String); + + async fn from_request(req: &mut RequestParts) -> Result { + let Extension(pool) = Extension::::from_request(req) + .await + .map_err(internal_error)?; + + let conn = pool.get_owned().await.map_err(internal_error)?; + + Ok(Self(conn)) + } +} + +/// Utility function for mapping any error into a `500 Internal Server Error` +/// response. +fn internal_error(err: E) -> (StatusCode, String) +where + E: std::error::Error, +{ + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) } diff --git a/backend/src/main.rs b/backend/src/main.rs index 3c0efa8..c75aaf6 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,8 +1,16 @@ -#[macro_use] -extern crate rocket; -extern crate mozaic4_backend; +use std::net::SocketAddr; -#[launch] -fn launch() -> rocket::Rocket { - mozaic4_backend::rocket() +extern crate mozaic4_backend; +extern crate tokio; + +#[tokio::main] +async fn main() { + let app = mozaic4_backend::app().await; + + let addr = SocketAddr::from(([127, 0, 0, 1], 9000)); + + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } diff --git a/backend/src/routes/bots.rs b/backend/src/routes/bots.rs index 413c145..da09669 100644 --- a/backend/src/routes/bots.rs +++ b/backend/src/routes/bots.rs @@ -1,16 +1,14 @@ +use axum::extract::{Path, RawBody}; +use axum::http::StatusCode; +use axum::Json; use rand::Rng; -use rocket::data::ToByteUnit; -use rocket::fs::TempFile; -use rocket::Data; -use rocket::{response::status, serde::json::Json}; use serde::{Deserialize, Serialize}; use std::io::Cursor; -use std::path::Path; - -use crate::DbConn; +use std::path; use crate::db::bots::{self, CodeBundle}; use crate::db::users::User; +use crate::DatabaseConnection; use bots::Bot; #[derive(Serialize, Deserialize, Debug)] @@ -18,52 +16,36 @@ pub struct BotParams { name: String, } -// TODO: handle errors -#[post("/bots", data = "")] pub async fn create_bot( - db_conn: DbConn, + conn: DatabaseConnection, user: User, params: Json, -) -> status::Created> { - db_conn - .run(move |conn| { - let bot_params = bots::NewBot { - owner_id: user.id, - name: ¶ms.name, - }; - let bot = bots::create_bot(&bot_params, conn).unwrap(); - let bot_url = uri!(get_bot(bot.id)).to_string(); - status::Created::new(bot_url).body(Json(bot)) - }) - .await +) -> (StatusCode, Json) { + let bot_params = bots::NewBot { + owner_id: user.id, + name: ¶ms.name, + }; + let bot = bots::create_bot(&bot_params, &conn).unwrap(); + (StatusCode::CREATED, Json(bot)) } // TODO: handle errors -#[get("/bots/")] -pub async fn get_bot(db_conn: DbConn, bot_id: i32) -> Json { - db_conn - .run(move |conn| { - let bot = bots::find_bot(bot_id, conn).unwrap(); - Json(bot) - }) - .await +pub async fn get_bot(conn: DatabaseConnection, Path(bot_id): Path) -> Json { + let bot = bots::find_bot(bot_id, &conn).unwrap(); + Json(bot) } // TODO: proper error handling -#[post("/bots//upload", data = "")] pub async fn upload_bot_code( - db_conn: DbConn, + conn: DatabaseConnection, user: User, - bot_id: i32, - data: Data<'_>, -) -> status::Created> { + Path(bot_id): Path, + RawBody(body): RawBody, +) -> (StatusCode, Json) { // TODO: put in config somewhere let data_path = "./data/bots"; - let bot = db_conn - .run(move |conn| bots::find_bot(bot_id, conn)) - .await - .expect("Bot not found"); + let bot = bots::find_bot(bot_id, &conn).expect("Bot not found"); assert_eq!(user.id, bot.owner_id); @@ -71,26 +53,23 @@ pub async fn upload_bot_code( let token: [u8; 16] = rand::thread_rng().gen(); let name = base64::encode(&token); - let path = Path::new(data_path).join(name); - let capped_buf = data.open(10usize.megabytes()).into_bytes().await.unwrap(); - assert!(capped_buf.is_complete()); - let buf = capped_buf.into_inner(); + let path = path::Path::new(data_path).join(name); + // let capped_buf = data.open(10usize.megabytes()).into_bytes().await.unwrap(); + // assert!(capped_buf.is_complete()); + // let buf = capped_buf.into_inner(); + let buf = hyper::body::to_bytes(body).await.unwrap(); zip::ZipArchive::new(Cursor::new(buf)) .unwrap() .extract(&path) .unwrap(); - let code_bundle = db_conn - .run(move |conn| { - let bundle = bots::NewCodeBundle { - bot_id: bot.id, - path: path.to_str().unwrap(), - }; - bots::create_code_bundle(&bundle, conn).expect("Failed to create code bundle") - }) - .await; + let bundle = bots::NewCodeBundle { + bot_id: bot.id, + path: path.to_str().unwrap(), + }; + let code_bundle = + bots::create_code_bundle(&bundle, &conn).expect("Failed to create code bundle"); - // TODO: proper location - status::Created::new("").body(Json(code_bundle)) + (StatusCode::CREATED, Json(code_bundle)) } diff --git a/backend/src/routes/users.rs b/backend/src/routes/users.rs index 45a94b9..fc77d7b 100644 --- a/backend/src/routes/users.rs +++ b/backend/src/routes/users.rs @@ -1,48 +1,32 @@ +use crate::db::users::{Credentials, User}; use crate::db::{sessions, users}; -use crate::{ - db::users::{Credentials, User}, - DbConn, -}; -use rocket::serde::json::Json; +use crate::DatabaseConnection; +use axum::extract::{FromRequest, RequestParts, TypedHeader}; +use axum::headers::authorization::Bearer; +use axum::headers::Authorization; +use axum::http::StatusCode; +use axum::{async_trait, Json}; use serde::{Deserialize, Serialize}; -use rocket::http::Status; -use rocket::request::{FromRequest, Outcome, Request}; -use rocket::response::status; +type AuthorizationHeader = TypedHeader>; -#[derive(Debug)] -pub enum AuthTokenError { - BadCount, - Missing, - Invalid, -} +#[async_trait] +impl FromRequest for User +where + B: Send, +{ + type Rejection = (StatusCode, String); -// TODO: error handling and proper lifetimes -#[rocket::async_trait] -impl<'r> FromRequest<'r> for User { - type Error = AuthTokenError; + async fn from_request(req: &mut RequestParts) -> Result { + let conn = DatabaseConnection::from_request(req).await?; + let TypedHeader(Authorization(bearer)) = AuthorizationHeader::from_request(req) + .await + .map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?; - async fn from_request(request: &'r Request<'_>) -> Outcome { - let keys: Vec<_> = request.headers().get("Authorization").collect(); - let auth_header = match keys.len() { - 0 => return Outcome::Failure((Status::BadRequest, AuthTokenError::Missing)), - 1 => keys[0], - _ => return Outcome::Failure((Status::BadRequest, AuthTokenError::BadCount)), - }; + let (_session, user) = sessions::find_user_by_session(bearer.token(), &conn) + .map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?; - let token = match auth_header.strip_prefix("Bearer ") { - Some(token) => token.to_string(), - None => return Outcome::Failure((Status::BadRequest, AuthTokenError::Invalid)), - }; - - let db = request.guard::().await.unwrap(); - let res = db - .run(move |conn| sessions::find_user_by_session(&token, conn)) - .await; - match res { - Ok((_session, user)) => Outcome::Success(user), - Err(_) => Outcome::Failure((Status::Unauthorized, AuthTokenError::Invalid)), - } + Ok(user) } } @@ -67,18 +51,16 @@ pub struct RegistrationParams { pub password: String, } -#[post("/register", data = "")] -pub async fn register(db_conn: DbConn, params: Json) -> Json { - db_conn - .run(move |conn| { - let credentials = Credentials { - username: ¶ms.username, - password: ¶ms.password, - }; - let user = users::create_user(&credentials, conn).unwrap(); - Json(user.into()) - }) - .await +pub async fn register( + conn: DatabaseConnection, + params: Json, +) -> Json { + let credentials = Credentials { + username: ¶ms.username, + password: ¶ms.password, + }; + let user = users::create_user(&credentials, &conn).unwrap(); + Json(user.into()) } #[derive(Deserialize)] @@ -87,32 +69,26 @@ pub struct LoginParams { pub password: String, } -#[post("/login", data = "")] pub async fn login( - db_conn: DbConn, + conn: DatabaseConnection, params: Json, -) -> Result> { - db_conn - .run(move |conn| { - let credentials = Credentials { - username: ¶ms.username, - password: ¶ms.password, - }; - // TODO: handle failures - let authenticated = users::authenticate_user(&credentials, conn); +) -> Result { + let credentials = Credentials { + username: ¶ms.username, + password: ¶ms.password, + }; + // TODO: handle failures + let authenticated = users::authenticate_user(&credentials, &conn); - match authenticated { - None => Err(status::Forbidden(Some("invalid auth"))), - Some(user) => { - let session = sessions::create_session(&user, conn); - Ok(session.token) - } - } - }) - .await + match authenticated { + None => Err(StatusCode::FORBIDDEN), + Some(user) => { + let session = sessions::create_session(&user, &conn); + Ok(session.token) + } + } } -#[get("/users/me")] pub async fn current_user(user: User) -> Json { Json(user.into()) }