From ae57359353cf31ff374a8932999742920878bf00 Mon Sep 17 00:00:00 2001 From: Ilion Beyst Date: Wed, 12 Oct 2022 22:52:15 +0200 Subject: [PATCH] upgrade to diesel 2.0 --- .github/workflows/rust-tests.yml | 2 +- Cargo.lock | 149 +++++++++++++++++--- planetwars-server/Cargo.toml | 10 +- planetwars-server/src/cli.rs | 4 +- planetwars-server/src/db/bots.rs | 30 ++-- planetwars-server/src/db/maps.rs | 10 +- planetwars-server/src/db/matches.rs | 96 +++++++------ planetwars-server/src/db/ratings.rs | 6 +- planetwars-server/src/db/sessions.rs | 6 +- planetwars-server/src/db/users.rs | 12 +- planetwars-server/src/db_types.rs | 2 +- planetwars-server/src/lib.rs | 16 ++- planetwars-server/src/modules/bots.rs | 2 +- planetwars-server/src/modules/client_api.rs | 6 +- planetwars-server/src/modules/matches.rs | 11 +- planetwars-server/src/modules/ranking.rs | 13 +- planetwars-server/src/modules/registry.rs | 40 +++--- planetwars-server/src/routes/bots.rs | 56 ++++---- planetwars-server/src/routes/demo.rs | 9 +- planetwars-server/src/routes/maps.rs | 4 +- planetwars-server/src/routes/matches.rs | 16 +-- planetwars-server/src/routes/users.rs | 20 +-- planetwars-server/src/schema.rs | 41 +++--- planetwars-server/tests/integration.rs | 16 +-- 24 files changed, 357 insertions(+), 220 deletions(-) diff --git a/.github/workflows/rust-tests.yml b/.github/workflows/rust-tests.yml index b410c01..6be3ce3 100644 --- a/.github/workflows/rust-tests.yml +++ b/.github/workflows/rust-tests.yml @@ -43,7 +43,7 @@ jobs: - name: Setup tests run: | docker pull python:3.10-slim-buster - cargo install diesel_cli --version ^1.4 || true + cargo install diesel_cli --version ^2.0 || true cd planetwars-server diesel migration run --locked-schema env: diff --git a/Cargo.lock b/Cargo.lock index 323f9d2..947c300 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,15 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anyhow" version = "1.0.58" @@ -171,22 +180,21 @@ checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" [[package]] name = "bb8" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e9f4fa9768efd269499d8fba693260cfc670891cf6de3adc935588447a77cc8" +checksum = "1627eccf3aa91405435ba240be23513eeca466b5dc33866422672264de061582" dependencies = [ "async-trait", "futures-channel", "futures-util", - "parking_lot 0.11.2", + "parking_lot 0.12.1", "tokio", ] [[package]] name = "bb8-diesel" version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79c87e12b0086ff7850d98a19d2a70f5fd901b463412d499514d8e2e16ad0826" +source = "git+https://github.com/overdrivenpotato/bb8-diesel.git#89b76207bbca35082687c765074f402200fcc51f" dependencies = [ "async-trait", "bb8", @@ -372,15 +380,17 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.19" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" +checksum = "bfd4d1b31faaa3a89d7934dbded3111da0d2ef28e3ebccdb4f0179f5929d1ef1" dependencies = [ - "libc", + "iana-time-zone", + "js-sys", "num-integer", "num-traits", "serde", "time", + "wasm-bindgen", "winapi", ] @@ -441,6 +451,16 @@ dependencies = [ "cc", ] +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "config" version = "0.12.0" @@ -520,6 +540,50 @@ dependencies = [ "typenum", ] +[[package]] +name = "cxx" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19f39818dcfc97d45b03953c1292efc4e80954e1583c4aa770bac1383e2310a4" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e580d70777c116df50c390d1211993f62d40302881e54d4b79727acb83d0199" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56a46460b88d1cec95112c8c363f0e2c39afdb237f60583b0b36343bf627ea9c" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "747b608fecf06b0d72d440f27acc99288207324b793be2c17991839f3d4995ea" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "darling" version = "0.13.1" @@ -557,23 +621,24 @@ dependencies = [ [[package]] name = "diesel" -version = "1.4.8" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b28135ecf6b7d446b43e27e225622a038cc4e2930a1022f51cdb97ada19b8e4d" +checksum = "68c186a7418a2aac330bb76cde82f16c36b03a66fb91db32d20214311f9f6545" dependencies = [ "bitflags", "byteorder", "chrono", "diesel_derives", + "itoa 1.0.2", "pq-sys", "r2d2", ] [[package]] name = "diesel-derive-enum" -version = "1.1.2" +version = "2.0.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8910921b014e2af16298f006de12aa08af894b71f0f49a486ab6d74b17bbed" +checksum = "5f28fc9f5bf184ebc58ad9105dede024981e2303fe878a0fe16557f3a979064a" dependencies = [ "heck", "proc-macro2", @@ -583,10 +648,11 @@ dependencies = [ [[package]] name = "diesel_derives" -version = "1.4.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45f5098f628d02a7a0f68ddba586fb61e80edec3bdc1be3b921f4ceec60858d3" +checksum = "143b758c91dbc3fe1fdcb0dba5bd13276c6a66422f2ef5795b58488248a310aa" dependencies = [ + "proc-macro-error", "proc-macro2", "quote", "syn", @@ -1014,6 +1080,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "iana-time-zone" +version = "0.1.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5a6ef98976b22b3b7f2f3a806f858cb862044cfa66805aa3ad84cb3d3b785ed" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "winapi", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde6edd6cef363e9359ed3c98ba64590ba9eecba2293eb5a723ab32aee8926aa" +dependencies = [ + "cxx", + "cxx-build", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -1112,6 +1202,15 @@ version = "0.2.126" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" +[[package]] +name = "link-cplusplus" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9272ab7b96c9046fbc5bc56c06c117cb639fe2d509df0c421cad82d2915cf369" +dependencies = [ + "cc", +] + [[package]] name = "linked-hash-map" version = "0.5.4" @@ -1195,9 +1294,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713d550d9b44d89174e066b7a6217ae06234c10cb47819a88290d2b353c31799" +checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" dependencies = [ "libc", "log", @@ -1999,6 +2098,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "scratch" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8132065adcfd6e02db789d9285a0deb2f3fcb04002865ab67d5fb103533898" + [[package]] name = "sct" version = "0.7.0" @@ -2292,16 +2397,16 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.19.2" +version = "1.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" +checksum = "a9e03c497dc955702ba729190dc4aac6f2a0ce97f913e5b1b5912fc5039d9099" dependencies = [ + "autocfg 1.1.0", "bytes", "libc", "memchr", "mio", "num_cpus", - "once_cell", "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", @@ -2592,6 +2697,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-width" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" + [[package]] name = "untrusted" version = "0.7.1" diff --git a/planetwars-server/Cargo.toml b/planetwars-server/Cargo.toml index f5641d1..183bb90 100644 --- a/planetwars-server/Cargo.toml +++ b/planetwars-server/Cargo.toml @@ -15,15 +15,15 @@ path = "src/cli.rs" [dependencies] futures = "0.3" -tokio = { version = "1.15", features = ["full"] } +tokio = { version = "1.21", features = ["full"] } tokio-stream = "0.1.9" hyper = "0.14" tower-http = { version = "0.3.4", features = ["full"] } axum = { version = "0.5", features = ["json", "headers", "multipart"] } -diesel = { version = "1.4.4", features = ["postgres", "chrono"] } -diesel-derive-enum = { version = "1.1", features = ["postgres"] } -bb8 = "0.7" -bb8-diesel = "0.2" +diesel = { version = "2.0", features = ["postgres", "chrono"] } +diesel-derive-enum = { version = "2.0.0-rc.0", features = ["postgres"] } +bb8 = "0.8" +bb8-diesel = { git = "https://github.com/overdrivenpotato/bb8-diesel.git" } dotenv = "0.15.0" rust-argon2 = "0.8" rand = "0.8.4" diff --git a/planetwars-server/src/cli.rs b/planetwars-server/src/cli.rs index f33506e..e1eeac3 100644 --- a/planetwars-server/src/cli.rs +++ b/planetwars-server/src/cli.rs @@ -38,12 +38,12 @@ impl SetPassword { let global_config = get_config().unwrap(); let pool = create_db_pool(&global_config).await; - let conn = pool.get().await.expect("could not get database connection"); + let mut conn = pool.get().await.expect("could not get database connection"); let credentials = db::users::Credentials { username: &self.username, password: &self.new_password, }; - db::users::set_user_password(credentials, &conn).expect("could not set password"); + db::users::set_user_password(credentials, &mut conn).expect("could not set password"); } } diff --git a/planetwars-server/src/db/bots.rs b/planetwars-server/src/db/bots.rs index a0a31b0..cf8bbb5 100644 --- a/planetwars-server/src/db/bots.rs +++ b/planetwars-server/src/db/bots.rs @@ -5,7 +5,7 @@ use crate::schema::{bot_versions, bots}; use chrono; #[derive(Insertable)] -#[table_name = "bots"] +#[diesel(table_name = bots)] pub struct NewBot<'a> { pub owner_id: Option, pub name: &'a str, @@ -19,29 +19,29 @@ pub struct Bot { pub active_version: Option, } -pub fn create_bot(new_bot: &NewBot, conn: &PgConnection) -> QueryResult { +pub fn create_bot(new_bot: &NewBot, conn: &mut PgConnection) -> QueryResult { diesel::insert_into(bots::table) .values(new_bot) .get_result(conn) } -pub fn find_bot(id: i32, conn: &PgConnection) -> QueryResult { +pub fn find_bot(id: i32, conn: &mut PgConnection) -> QueryResult { bots::table.find(id).first(conn) } -pub fn find_bots_by_owner(owner_id: i32, conn: &PgConnection) -> QueryResult> { +pub fn find_bots_by_owner(owner_id: i32, conn: &mut PgConnection) -> QueryResult> { bots::table .filter(bots::owner_id.eq(owner_id)) .get_results(conn) } -pub fn find_bot_by_name(name: &str, conn: &PgConnection) -> QueryResult { +pub fn find_bot_by_name(name: &str, conn: &mut PgConnection) -> QueryResult { bots::table.filter(bots::name.eq(name)).first(conn) } pub fn find_bot_with_version_by_name( bot_name: &str, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<(Bot, BotVersion)> { bots::table .inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable()))) @@ -49,26 +49,28 @@ pub fn find_bot_with_version_by_name( .first(conn) } -pub fn all_active_bots_with_version(conn: &PgConnection) -> QueryResult> { +pub fn all_active_bots_with_version( + conn: &mut PgConnection, +) -> QueryResult> { bots::table .inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable()))) .get_results(conn) } -pub fn find_all_bots(conn: &PgConnection) -> QueryResult> { +pub fn find_all_bots(conn: &mut PgConnection) -> QueryResult> { bots::table.get_results(conn) } /// Find all bots that have an associated active version. /// These are the bots that can be run. -pub fn find_active_bots(conn: &PgConnection) -> QueryResult> { +pub fn find_active_bots(conn: &mut PgConnection) -> QueryResult> { bots::table .filter(bots::active_version.is_not_null()) .get_results(conn) } #[derive(Insertable)] -#[table_name = "bot_versions"] +#[diesel(table_name = bot_versions)] pub struct NewBotVersion<'a> { pub bot_id: Option, pub code_bundle_path: Option<&'a str>, @@ -86,7 +88,7 @@ pub struct BotVersion { pub fn create_bot_version( new_bot_version: &NewBotVersion, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult { diesel::insert_into(bot_versions::table) .values(new_bot_version) @@ -96,7 +98,7 @@ pub fn create_bot_version( pub fn set_active_version( bot_id: i32, version_id: Option, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<()> { diesel::update(bots::table.filter(bots::id.eq(bot_id))) .set(bots::active_version.eq(version_id)) @@ -104,13 +106,13 @@ pub fn set_active_version( Ok(()) } -pub fn find_bot_version(version_id: i32, conn: &PgConnection) -> QueryResult { +pub fn find_bot_version(version_id: i32, conn: &mut PgConnection) -> QueryResult { bot_versions::table .filter(bot_versions::id.eq(version_id)) .first(conn) } -pub fn find_bot_versions(bot_id: i32, conn: &PgConnection) -> QueryResult> { +pub fn find_bot_versions(bot_id: i32, conn: &mut PgConnection) -> QueryResult> { bot_versions::table .filter(bot_versions::bot_id.eq(bot_id)) .get_results(conn) diff --git a/planetwars-server/src/db/maps.rs b/planetwars-server/src/db/maps.rs index dffe4fd..8972461 100644 --- a/planetwars-server/src/db/maps.rs +++ b/planetwars-server/src/db/maps.rs @@ -3,7 +3,7 @@ use diesel::prelude::*; use crate::schema::maps; #[derive(Insertable)] -#[table_name = "maps"] +#[diesel(table_name = maps)] pub struct NewMap<'a> { pub name: &'a str, pub file_path: &'a str, @@ -16,20 +16,20 @@ pub struct Map { pub file_path: String, } -pub fn create_map(new_map: NewMap, conn: &PgConnection) -> QueryResult { +pub fn create_map(new_map: NewMap, conn: &mut PgConnection) -> QueryResult { diesel::insert_into(maps::table) .values(new_map) .get_result(conn) } -pub fn find_map(id: i32, conn: &PgConnection) -> QueryResult { +pub fn find_map(id: i32, conn: &mut PgConnection) -> QueryResult { maps::table.find(id).get_result(conn) } -pub fn find_map_by_name(name: &str, conn: &PgConnection) -> QueryResult { +pub fn find_map_by_name(name: &str, conn: &mut PgConnection) -> QueryResult { maps::table.filter(maps::name.eq(name)).first(conn) } -pub fn list_maps(conn: &PgConnection) -> QueryResult> { +pub fn list_maps(conn: &mut PgConnection) -> QueryResult> { maps::table.get_results(conn) } diff --git a/planetwars-server/src/db/matches.rs b/planetwars-server/src/db/matches.rs index 1dded43..bfec892 100644 --- a/planetwars-server/src/db/matches.rs +++ b/planetwars-server/src/db/matches.rs @@ -1,9 +1,6 @@ pub use crate::db_types::MatchState; use chrono::NaiveDateTime; use diesel::associations::BelongsTo; -use diesel::pg::Pg; -use diesel::query_builder::BoxedSelectStatement; -use diesel::query_source::{AppearsInFromClause, Once}; use diesel::sql_types::*; use diesel::{ BelongingToDsl, ExpressionMethods, JoinOnDsl, NullableExpressionMethods, QueryDsl, RunQueryDsl, @@ -18,7 +15,7 @@ use super::bots::{Bot, BotVersion}; use super::maps::Map; #[derive(Insertable)] -#[table_name = "matches"] +#[diesel(table_name = matches)] pub struct NewMatch<'a> { pub state: MatchState, pub log_path: &'a str, @@ -27,7 +24,7 @@ pub struct NewMatch<'a> { } #[derive(Insertable)] -#[table_name = "match_players"] +#[diesel(table_name = match_players)] pub struct NewMatchPlayer { /// id of the match this player is in pub match_id: i32, @@ -38,7 +35,7 @@ pub struct NewMatchPlayer { } #[derive(Queryable, Identifiable)] -#[table_name = "matches"] +#[diesel(table_name = matches)] pub struct MatchBase { pub id: i32, pub state: MatchState, @@ -50,8 +47,8 @@ pub struct MatchBase { } #[derive(Queryable, Identifiable, Associations, Clone)] -#[primary_key(match_id, player_id)] -#[belongs_to(MatchBase, foreign_key = "match_id")] +#[diesel(primary_key(match_id, player_id))] +#[diesel(belongs_to(MatchBase, foreign_key = match_id))] pub struct MatchPlayer { pub match_id: i32, pub player_id: i32, @@ -65,9 +62,9 @@ pub struct MatchPlayerData { pub fn create_match( new_match_base: &NewMatch, new_match_players: &[MatchPlayerData], - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult { - conn.transaction(|| { + conn.transaction(|conn| { let match_base = diesel::insert_into(matches::table) .values(new_match_base) .get_result::(conn)?; @@ -101,7 +98,7 @@ pub struct MatchData { /// Add player information to MatchBase instances fn fetch_full_match_data( matches: Vec, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult> { let map_ids: HashSet = matches.iter().filter_map(|m| m.map_id).collect(); @@ -140,8 +137,8 @@ fn fetch_full_match_data( } // TODO: this method should disappear -pub fn list_matches(amount: i64, conn: &PgConnection) -> QueryResult> { - conn.transaction(|| { +pub fn list_matches(amount: i64, conn: &mut PgConnection) -> QueryResult> { + conn.transaction(|conn| { let matches = matches::table .filter(matches::state.eq(MatchState::Finished)) .order_by(matches::created_at.desc()) @@ -164,17 +161,32 @@ pub fn list_public_matches( amount: i64, before: Option, after: Option, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult> { - conn.transaction(|| { + conn.transaction(|conn| { // TODO: how can this common logic be abstracted? - let query = matches::table + let mut query = matches::table .filter(matches::state.eq(MatchState::Finished)) .filter(matches::is_public.eq(true)) .into_boxed(); - let matches = - select_matches_page(query, amount, before, after).get_results::(conn)?; + // TODO: how to remove this duplication? + query = match (before, after) { + (None, None) => query.order_by(matches::created_at.desc()), + (Some(before), None) => query + .filter(matches::created_at.lt(before)) + .order_by(matches::created_at.desc()), + (None, Some(after)) => query + .filter(matches::created_at.gt(after)) + .order_by(matches::created_at.asc()), + (Some(before), Some(after)) => query + .filter(matches::created_at.lt(before)) + .filter(matches::created_at.gt(after)) + .order_by(matches::created_at.desc()), + }; + query = query.limit(amount); + + let matches = query.get_results::(conn)?; fetch_full_match_data(matches, conn) }) } @@ -185,7 +197,7 @@ pub fn list_bot_matches( amount: i64, before: Option, after: Option, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult> { let mut query = matches::table .filter(matches::state.eq(MatchState::Finished)) @@ -211,22 +223,8 @@ pub fn list_bot_matches( }; } - let matches = - select_matches_page(query, amount, before, after).get_results::(conn)?; - fetch_full_match_data(matches, conn) -} - -fn select_matches_page( - query: BoxedSelectStatement<'static, matches::SqlType, QS, Pg>, - amount: i64, - before: Option, - after: Option, -) -> BoxedSelectStatement<'static, matches::SqlType, QS, Pg> -where - QS: AppearsInFromClause, -{ - // TODO: this is not nice. Replace this with proper cursor logic. - match (before, after) { + // TODO: how to remove this duplication? + query = match (before, after) { (None, None) => query.order_by(matches::created_at.desc()), (Some(before), None) => query .filter(matches::created_at.lt(before)) @@ -238,8 +236,11 @@ where .filter(matches::created_at.lt(before)) .filter(matches::created_at.gt(after)) .order_by(matches::created_at.desc()), - } - .limit(amount) + }; + query = query.limit(amount); + + let matches = query.get_results::(conn)?; + fetch_full_match_data(matches, conn) } // TODO: maybe unify this with matchdata? @@ -270,8 +271,8 @@ impl BelongsTo for FullMatchPlayerData { } } -pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult { - conn.transaction(|| { +pub fn find_match(id: i32, conn: &mut PgConnection) -> QueryResult { + conn.transaction(|conn| { let match_base = matches::table.find(id).get_result::(conn)?; let map = match match_base.map_id { @@ -298,7 +299,7 @@ pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult { }) } -pub fn find_match_base(id: i32, conn: &PgConnection) -> QueryResult { +pub fn find_match_base(id: i32, conn: &mut PgConnection) -> QueryResult { matches::table.find(id).get_result::(conn) } @@ -306,7 +307,7 @@ pub enum MatchResult { Finished { winner: Option }, } -pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> QueryResult<()> { +pub fn save_match_result(id: i32, result: MatchResult, conn: &mut PgConnection) -> QueryResult<()> { let MatchResult::Finished { winner } = result; diesel::update(matches::table.find(id)) @@ -320,17 +321,20 @@ pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> Q #[derive(QueryableByName)] pub struct BotStatsRecord { - #[sql_type = "Text"] + #[diesel(sql_type = Text)] pub opponent: String, - #[sql_type = "Text"] + #[diesel(sql_type = Text)] pub map: String, - #[sql_type = "Nullable"] + #[diesel(sql_type = Nullable)] pub win: Option, - #[sql_type = "Int8"] + #[diesel(sql_type = Int8)] pub count: i64, } -pub fn fetch_bot_stats(bot_name: &str, db_conn: &PgConnection) -> QueryResult> { +pub fn fetch_bot_stats( + bot_name: &str, + db_conn: &mut PgConnection, +) -> QueryResult> { diesel::sql_query( " SELECT opponent, map, win, COUNT(*) as count diff --git a/planetwars-server/src/db/ratings.rs b/planetwars-server/src/db/ratings.rs index 8262fed..0a510d4 100644 --- a/planetwars-server/src/db/ratings.rs +++ b/planetwars-server/src/db/ratings.rs @@ -10,7 +10,7 @@ pub struct Rating { pub rating: f64, } -pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult> { +pub fn get_rating(bot_id: i32, db_conn: &mut PgConnection) -> QueryResult> { ratings::table .filter(ratings::bot_id.eq(bot_id)) .select(ratings::rating) @@ -18,7 +18,7 @@ pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult QueryResult { +pub fn set_rating(bot_id: i32, rating: f64, db_conn: &mut PgConnection) -> QueryResult { diesel::insert_into(ratings::table) .values(Rating { bot_id, rating }) .on_conflict(ratings::bot_id) @@ -40,7 +40,7 @@ pub struct RankedBot { pub rating: f64, } -pub fn get_bot_ranking(db_conn: &PgConnection) -> QueryResult> { +pub fn get_bot_ranking(db_conn: &mut PgConnection) -> QueryResult> { bots::table .left_join(users::table) .inner_join(ratings::table) diff --git a/planetwars-server/src/db/sessions.rs b/planetwars-server/src/db/sessions.rs index a91d954..f8108cc 100644 --- a/planetwars-server/src/db/sessions.rs +++ b/planetwars-server/src/db/sessions.rs @@ -6,7 +6,7 @@ use diesel::{insert_into, prelude::*, Insertable, RunQueryDsl}; use rand::{self, Rng}; #[derive(Insertable)] -#[table_name = "sessions"] +#[diesel(table_name = sessions)] struct NewSession { token: String, user_id: i32, @@ -19,7 +19,7 @@ pub struct Session { pub token: String, } -pub fn create_session(user: &User, conn: &PgConnection) -> Session { +pub fn create_session(user: &User, conn: &mut PgConnection) -> Session { let new_session = NewSession { token: gen_session_token(), user_id: user.id, @@ -31,7 +31,7 @@ pub fn create_session(user: &User, conn: &PgConnection) -> Session { .unwrap() } -pub fn find_user_by_session(token: &str, conn: &PgConnection) -> QueryResult<(Session, User)> { +pub fn find_user_by_session(token: &str, conn: &mut PgConnection) -> QueryResult<(Session, User)> { sessions::table .inner_join(users::table) .filter(sessions::token.eq(&token)) diff --git a/planetwars-server/src/db/users.rs b/planetwars-server/src/db/users.rs index 9676dae..60cc20a 100644 --- a/planetwars-server/src/db/users.rs +++ b/planetwars-server/src/db/users.rs @@ -11,7 +11,7 @@ pub struct Credentials<'a> { } #[derive(Insertable)] -#[table_name = "users"] +#[diesel(table_name = users)] pub struct NewUser<'a> { pub username: &'a str, pub password_hash: &'a [u8], @@ -50,7 +50,7 @@ pub fn hash_password(password: &str) -> (Vec, [u8; 32]) { (hash, salt) } -pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResult { +pub fn create_user(credentials: &Credentials, conn: &mut PgConnection) -> QueryResult { let (hash, salt) = hash_password(&credentials.password); let new_user = NewUser { @@ -63,19 +63,19 @@ pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResul .get_result::(conn) } -pub fn find_user(user_id: i32, db_conn: &PgConnection) -> QueryResult { +pub fn find_user(user_id: i32, db_conn: &mut PgConnection) -> QueryResult { users::table .filter(users::id.eq(user_id)) .first::(db_conn) } -pub fn find_user_by_name(username: &str, db_conn: &PgConnection) -> QueryResult { +pub fn find_user_by_name(username: &str, db_conn: &mut PgConnection) -> QueryResult { users::table .filter(users::username.eq(username)) .first::(db_conn) } -pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> QueryResult<()> { +pub fn set_user_password(credentials: Credentials, db_conn: &mut PgConnection) -> QueryResult<()> { let (hash, salt) = hash_password(&credentials.password); let n_changes = diesel::update(users::table.filter(users::username.eq(&credentials.username))) @@ -91,7 +91,7 @@ pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> Qu } } -pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option { +pub fn authenticate_user(credentials: &Credentials, db_conn: &mut PgConnection) -> Option { find_user_by_name(credentials.username, db_conn) .optional() .unwrap() diff --git a/planetwars-server/src/db_types.rs b/planetwars-server/src/db_types.rs index 1b99e49..29b1e9b 100644 --- a/planetwars-server/src/db_types.rs +++ b/planetwars-server/src/db_types.rs @@ -2,7 +2,7 @@ use diesel_derive_enum::DbEnum; use serde::{Deserialize, Serialize}; #[derive(DbEnum, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] -#[DieselType = "Match_state"] +#[DieselTypePath = "crate::schema::sql_types::MatchState"] pub enum MatchState { Playing, diff --git a/planetwars-server/src/lib.rs b/planetwars-server/src/lib.rs index 1696f1a..316458c 100644 --- a/planetwars-server/src/lib.rs +++ b/planetwars-server/src/lib.rs @@ -8,7 +8,7 @@ pub mod routes; pub mod schema; pub mod util; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::path::PathBuf; use std::sync::Arc; use std::{fs, net::SocketAddr}; @@ -70,9 +70,9 @@ pub struct GlobalConfig { const SIMPLEBOT_PATH: &str = "../simplebot/simplebot.py"; pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) { - let conn = pool.get().await.expect("could not get database connection"); + let mut conn = pool.get().await.expect("could not get database connection"); // This transaction is expected to fail when simplebot already exists. - let _res = conn.transaction::<(), diesel::result::Error, _>(|| { + let _res = conn.transaction::<(), diesel::result::Error, _>(|conn| { use db::bots::NewBot; let new_bot = NewBot { @@ -80,12 +80,12 @@ pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) { owner_id: None, }; - let simplebot = db::bots::create_bot(&new_bot, &conn)?; + let simplebot = db::bots::create_bot(&new_bot, conn)?; let simplebot_code = std::fs::read_to_string(SIMPLEBOT_PATH).expect("could not read simplebot code"); - modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), &conn, config)?; + modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), conn, config)?; println!("initialized simplebot"); @@ -209,6 +209,12 @@ impl Deref for DatabaseConnection { } } +impl DerefMut for DatabaseConnection { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + #[async_trait] impl FromRequest for DatabaseConnection where diff --git a/planetwars-server/src/modules/bots.rs b/planetwars-server/src/modules/bots.rs index 6a2883c..6893581 100644 --- a/planetwars-server/src/modules/bots.rs +++ b/planetwars-server/src/modules/bots.rs @@ -9,7 +9,7 @@ use crate::{db, util::gen_alphanumeric, GlobalConfig}; pub fn save_code_string( bot_code: &str, bot_id: Option, - conn: &PgConnection, + conn: &mut PgConnection, config: &GlobalConfig, ) -> QueryResult { let bundle_name = gen_alphanumeric(16); diff --git a/planetwars-server/src/modules/client_api.rs b/planetwars-server/src/modules/client_api.rs index 6e5d05a..9c0bbe7 100644 --- a/planetwars-server/src/modules/client_api.rs +++ b/planetwars-server/src/modules/client_api.rs @@ -149,19 +149,19 @@ impl pb::client_api_service_server::ClientApiService for ClientApiServer { req: Request, ) -> Result, Status> { // TODO: unify with matchrunner module - let conn = self.conn_pool.get().await.unwrap(); + let mut conn = self.conn_pool.get().await.unwrap(); let match_request = req.get_ref(); let (opponent_bot, opponent_bot_version) = - db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &conn) + db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &mut conn) .map_err(|_| Status::not_found("opponent not found"))?; let map_name = match match_request.map_name.as_str() { "" => "hex", name => name, }; - let map = db::maps::find_map_by_name(map_name, &conn) + let map = db::maps::find_map_by_name(map_name, &mut conn) .map_err(|_| Status::not_found("map not found"))?; let player_key = gen_alphanumeric(32); diff --git a/planetwars-server/src/modules/matches.rs b/planetwars-server/src/modules/matches.rs index ecc7976..71e8a98 100644 --- a/planetwars-server/src/modules/matches.rs +++ b/planetwars-server/src/modules/matches.rs @@ -80,8 +80,8 @@ impl RunMatch { let match_data = { // TODO: it would be nice to get an already-open connection here when possible. // Maybe we need an additional abstraction, bundling a connection and connection pool? - let db_conn = conn_pool.get().await.expect("could not get a connection"); - self.store_in_database(&db_conn)? + let mut db_conn = conn_pool.get().await.expect("could not get a connection"); + self.store_in_database(&mut db_conn)? }; let runner_config = self.into_runner_config(); @@ -90,7 +90,7 @@ impl RunMatch { Ok((match_data, handle)) } - fn store_in_database(&self, db_conn: &PgConnection) -> QueryResult { + fn store_in_database(&self, db_conn: &mut PgConnection) -> QueryResult { let new_match_data = db::matches::NewMatch { state: db::matches::MatchState::Playing, log_path: &self.log_file_name, @@ -167,7 +167,7 @@ async fn run_match_task( let outcome = runner::run_match(match_config).await; // update match state in database - let conn = connection_pool + let mut conn = connection_pool .get() .await .expect("could not get database connection"); @@ -176,7 +176,8 @@ async fn run_match_task( winner: outcome.winner.map(|w| (w - 1) as i32), // player numbers in matchrunner start at 1 }; - db::matches::save_match_result(match_id, result, &conn).expect("could not save match result"); + db::matches::save_match_result(match_id, result, &mut conn) + .expect("could not save match result"); outcome } diff --git a/planetwars-server/src/modules/ranking.rs b/planetwars-server/src/modules/ranking.rs index 90c4a56..92f0f8a 100644 --- a/planetwars-server/src/modules/ranking.rs +++ b/planetwars-server/src/modules/ranking.rs @@ -20,13 +20,14 @@ pub async fn run_ranker(config: Arc, db_pool: DbPool) { // TODO: make this configurable // play at most one match every n seconds let mut interval = tokio::time::interval(Duration::from_secs(RANKER_INTERVAL)); - let db_conn = db_pool + let mut db_conn = db_pool .get() .await .expect("could not get database connection"); loop { interval.tick().await; - let bots = db::bots::all_active_bots_with_version(&db_conn).expect("could not load bots"); + let bots = + db::bots::all_active_bots_with_version(&mut db_conn).expect("could not load bots"); if bots.len() < 2 { // not enough bots to play a match continue; @@ -37,14 +38,14 @@ pub async fn run_ranker(config: Arc, db_pool: DbPool) { .cloned() .collect(); - let maps = db::maps::list_maps(&db_conn).expect("could not load map"); + let maps = db::maps::list_maps(&mut db_conn).expect("could not load map"); let map = match maps.choose(&mut rand::thread_rng()).cloned() { None => continue, // no maps available Some(map) => map, }; play_ranking_match(config.clone(), map, selected_bots, db_pool.clone()).await; - recalculate_ratings(&db_conn).expect("could not recalculate ratings"); + recalculate_ratings(&mut db_conn).expect("could not recalculate ratings"); } } @@ -71,7 +72,7 @@ async fn play_ranking_match( let _outcome = handle.await; } -fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> { +fn recalculate_ratings(db_conn: &mut PgConnection) -> QueryResult<()> { let start = Instant::now(); let match_stats = fetch_match_stats(db_conn)?; let ratings = estimate_ratings_from_stats(match_stats); @@ -91,7 +92,7 @@ struct MatchStats { num_matches: usize, } -fn fetch_match_stats(db_conn: &PgConnection) -> QueryResult> { +fn fetch_match_stats(db_conn: &mut PgConnection) -> QueryResult> { let matches = db::matches::list_matches(RANKER_NUM_MATCHES, db_conn)?; let mut match_stats = HashMap::<(i32, i32), MatchStats>::new(); diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs index 4a79d59..5e1e05b 100644 --- a/planetwars-server/src/modules/registry.rs +++ b/planetwars-server/src/modules/registry.rs @@ -112,8 +112,8 @@ where Err(RegistryAuthError::InvalidCredentials) } } else { - let db_conn = DatabaseConnection::from_request(req).await.unwrap(); - let user = authenticate_user(&credentials, &db_conn) + let mut db_conn = DatabaseConnection::from_request(req).await.unwrap(); + let user = authenticate_user(&credentials, &mut db_conn) .ok_or(RegistryAuthError::InvalidCredentials)?; Ok(RegistryAuth::User(user)) @@ -159,12 +159,12 @@ pub struct RegistryError { } async fn check_blob_exists( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, raw_digest)): Path<(String, String)>, Extension(config): Extension>, ) -> Result { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let digest = raw_digest.strip_prefix("sha256:").unwrap(); let blob_path = PathBuf::from(&config.registry_directory) @@ -179,12 +179,12 @@ async fn check_blob_exists( } async fn get_blob( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, raw_digest)): Path<(String, String)>, Extension(config): Extension>, ) -> Result { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let digest = raw_digest.strip_prefix("sha256:").unwrap(); let blob_path = PathBuf::from(&config.registry_directory) @@ -200,12 +200,12 @@ async fn get_blob( } async fn create_upload( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path(repository_name): Path, Extension(config): Extension>, ) -> Result { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let uuid = gen_alphanumeric(16); tokio::fs::File::create( @@ -229,13 +229,13 @@ async fn create_upload( } async fn patch_upload( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, uuid)): Path<(String, String)>, mut stream: BodyStream, Extension(config): Extension>, ) -> Result { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; // TODO: support content range header in request let upload_path = PathBuf::from(&config.registry_directory) @@ -275,14 +275,14 @@ struct UploadParams { } async fn put_upload( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, uuid)): Path<(String, String)>, Query(params): Query, mut stream: BodyStream, Extension(config): Extension>, ) -> Result { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let upload_path = PathBuf::from(&config.registry_directory) .join("uploads") @@ -332,12 +332,12 @@ async fn put_upload( } async fn get_manifest( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, reference)): Path<(String, String)>, Extension(config): Extension>, ) -> Result { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let manifest_path = PathBuf::from(&config.registry_directory) .join("manifests") @@ -357,13 +357,13 @@ async fn get_manifest( } async fn put_manifest( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, reference)): Path<(String, String)>, mut stream: BodyStream, Extension(config): Extension>, ) -> Result { - let bot = check_access(&repository_name, &auth, &db_conn)?; + let bot = check_access(&repository_name, &auth, &mut db_conn)?; let repository_dir = PathBuf::from(&config.registry_directory) .join("manifests") @@ -399,9 +399,9 @@ async fn put_manifest( code_bundle_path: None, container_digest: Some(&content_digest), }; - let version = - db::bots::create_bot_version(&new_version, &db_conn).expect("could not save bot version"); - db::bots::set_active_version(bot.id, Some(version.id), &db_conn) + let version = db::bots::create_bot_version(&new_version, &mut db_conn) + .expect("could not save bot version"); + db::bots::set_active_version(bot.id, Some(version.id), &mut db_conn) .expect("could not update bot version"); Ok(Response::builder() @@ -421,7 +421,7 @@ async fn put_manifest( fn check_access( repository_name: &str, auth: &RegistryAuth, - db_conn: &DatabaseConnection, + db_conn: &mut DatabaseConnection, ) -> Result { use diesel::OptionalExtension; diff --git a/planetwars-server/src/routes/bots.rs b/planetwars-server/src/routes/bots.rs index f8087fd..f0ff9bf 100644 --- a/planetwars-server/src/routes/bots.rs +++ b/planetwars-server/src/routes/bots.rs @@ -100,10 +100,10 @@ pub fn validate_bot_name(bot_name: &str) -> Result<(), SaveBotError> { pub async fn save_bot( Json(params): Json, user: User, - conn: DatabaseConnection, + mut conn: DatabaseConnection, Extension(config): Extension>, ) -> Result, SaveBotError> { - let res = bots::find_bot_by_name(¶ms.bot_name, &conn) + let res = bots::find_bot_by_name(¶ms.bot_name, &mut conn) .optional() .expect("could not run query"); @@ -122,10 +122,10 @@ pub async fn save_bot( name: ¶ms.bot_name, }; - bots::create_bot(&new_bot, &conn).expect("could not create bot") + bots::create_bot(&new_bot, &mut conn).expect("could not create bot") } }; - let _code_bundle = save_code_string(¶ms.code, Some(bot.id), &conn, &config) + let _code_bundle = save_code_string(¶ms.code, Some(bot.id), &mut conn, &config) .expect("failed to save code bundle"); Ok(Json(bot)) } @@ -137,12 +137,12 @@ pub struct BotParams { // TODO: can we unify this with save_bot? pub async fn create_bot( - conn: DatabaseConnection, + mut conn: DatabaseConnection, user: User, params: Json, ) -> Result<(StatusCode, Json), SaveBotError> { validate_bot_name(¶ms.name)?; - let existing_bot = bots::find_bot_by_name(¶ms.name, &conn) + let existing_bot = bots::find_bot_by_name(¶ms.name, &mut conn) .optional() .expect("could not run query"); if existing_bot.is_some() { @@ -152,26 +152,27 @@ pub async fn create_bot( owner_id: Some(user.id), name: ¶ms.name, }; - let bot = bots::create_bot(&bot_params, &conn).unwrap(); + let bot = bots::create_bot(&bot_params, &mut conn).unwrap(); Ok((StatusCode::CREATED, Json(bot))) } // TODO: handle errors pub async fn get_bot( - conn: DatabaseConnection, + mut conn: DatabaseConnection, Path(bot_name): Path, ) -> Result, StatusCode> { - let bot = db::bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; + let bot = + db::bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; let owner: Option = match bot.owner_id { Some(user_id) => { - let user = db::users::find_user(user_id, &conn) + let user = db::users::find_user(user_id, &mut conn) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Some(user.into()) } None => None, }; - let versions = - bots::find_bot_versions(bot.id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let versions = bots::find_bot_versions(bot.id, &mut conn) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(Json(json!({ "bot": bot, "owner": owner, @@ -180,32 +181,32 @@ pub async fn get_bot( } pub async fn get_user_bots( - conn: DatabaseConnection, + mut conn: DatabaseConnection, Path(user_name): Path, ) -> Result>, StatusCode> { let user = - db::users::find_user_by_name(&user_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; - db::bots::find_bots_by_owner(user.id, &conn) + db::users::find_user_by_name(&user_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; + db::bots::find_bots_by_owner(user.id, &mut conn) .map(Json) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } /// List all active bots -pub async fn list_bots(conn: DatabaseConnection) -> Result>, StatusCode> { - bots::find_active_bots(&conn) +pub async fn list_bots(mut conn: DatabaseConnection) -> Result>, StatusCode> { + bots::find_active_bots(&mut conn) .map(Json) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } -pub async fn get_ranking(conn: DatabaseConnection) -> Result>, StatusCode> { - ratings::get_bot_ranking(&conn) +pub async fn get_ranking(mut conn: DatabaseConnection) -> Result>, StatusCode> { + ratings::get_bot_ranking(&mut conn) .map(Json) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } // TODO: currently this only implements the happy flow pub async fn upload_code_multipart( - conn: DatabaseConnection, + mut conn: DatabaseConnection, user: User, Path(bot_name): Path, mut multipart: Multipart, @@ -213,7 +214,7 @@ pub async fn upload_code_multipart( ) -> Result, StatusCode> { let bots_dir = PathBuf::from(&config.bots_directory); - let bot = bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; + let bot = bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; if Some(user.id) != bot.owner_id { return Err(StatusCode::FORBIDDEN); @@ -246,21 +247,22 @@ pub async fn upload_code_multipart( container_digest: None, }; let code_bundle = - bots::create_bot_version(&bot_version, &conn).expect("Failed to create code bundle"); + bots::create_bot_version(&bot_version, &mut conn).expect("Failed to create code bundle"); Ok(Json(code_bundle)) } pub async fn get_code( - conn: DatabaseConnection, + mut conn: DatabaseConnection, user: User, Path(bundle_id): Path, Extension(config): Extension>, ) -> Result, StatusCode> { let version = - db::bots::find_bot_version(bundle_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?; + db::bots::find_bot_version(bundle_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; let bot_id = version.bot_id.ok_or(StatusCode::FORBIDDEN)?; - let bot = db::bots::find_bot(bot_id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let bot = + db::bots::find_bot(bot_id, &mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if bot.owner_id != Some(user.id) { return Err(StatusCode::FORBIDDEN); @@ -297,10 +299,10 @@ impl MatchupStats { type BotStats = HashMap>; pub async fn get_bot_stats( - conn: DatabaseConnection, + mut conn: DatabaseConnection, Path(bot_name): Path, ) -> Result, StatusCode> { - let stats_records = db::matches::fetch_bot_stats(&bot_name, &conn) + let stats_records = db::matches::fetch_bot_stats(&bot_name, &mut conn) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let mut bot_stats: BotStats = HashMap::new(); for record in stats_records { diff --git a/planetwars-server/src/routes/demo.rs b/planetwars-server/src/routes/demo.rs index 1ec8825..cd490ef 100644 --- a/planetwars-server/src/routes/demo.rs +++ b/planetwars-server/src/routes/demo.rs @@ -35,7 +35,7 @@ pub async fn submit_bot( Extension(pool): Extension, Extension(config): Extension>, ) -> Result, StatusCode> { - let conn = pool.get().await.expect("could not get database connection"); + let mut conn = pool.get().await.expect("could not get database connection"); let opponent_name = params .opponent_name @@ -46,12 +46,13 @@ pub async fn submit_bot( .unwrap_or_else(|| DEFAULT_MAP_NAME.to_string()); let (opponent_bot, opponent_bot_version) = - db::bots::find_bot_with_version_by_name(&opponent_name, &conn) + db::bots::find_bot_with_version_by_name(&opponent_name, &mut conn) .map_err(|_| StatusCode::BAD_REQUEST)?; - let map = db::maps::find_map_by_name(&map_name, &conn).map_err(|_| StatusCode::BAD_REQUEST)?; + let map = + db::maps::find_map_by_name(&map_name, &mut conn).map_err(|_| StatusCode::BAD_REQUEST)?; - let player_bot_version = save_code_string(¶ms.code, None, &conn, &config) + let player_bot_version = save_code_string(¶ms.code, None, &mut conn, &config) // TODO: can we recover from this? .expect("could not save bot code"); diff --git a/planetwars-server/src/routes/maps.rs b/planetwars-server/src/routes/maps.rs index 689b11e..188089f 100644 --- a/planetwars-server/src/routes/maps.rs +++ b/planetwars-server/src/routes/maps.rs @@ -8,8 +8,8 @@ pub struct ApiMap { pub name: String, } -pub async fn list_maps(conn: DatabaseConnection) -> Result>, StatusCode> { - let maps = db::maps::list_maps(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; +pub async fn list_maps(mut conn: DatabaseConnection) -> Result>, StatusCode> { + let maps = db::maps::list_maps(&mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let api_maps = maps .into_iter() diff --git a/planetwars-server/src/routes/matches.rs b/planetwars-server/src/routes/matches.rs index 1d7403c..3ad10cf 100644 --- a/planetwars-server/src/routes/matches.rs +++ b/planetwars-server/src/routes/matches.rs @@ -56,7 +56,7 @@ pub struct ListMatchesResponse { pub async fn list_recent_matches( Query(params): Query, - conn: DatabaseConnection, + mut conn: DatabaseConnection, ) -> Result, StatusCode> { let requested_count = std::cmp::min( params.count.unwrap_or(DEFAULT_NUM_RETURNED_MATCHES), @@ -68,7 +68,7 @@ pub async fn list_recent_matches( let matches_result = match params.bot { Some(bot_name) => { - let bot = db::bots::find_bot_by_name(&bot_name, &conn) + let bot = db::bots::find_bot_by_name(&bot_name, &mut conn) .map_err(|_| StatusCode::BAD_REQUEST)?; matches::list_bot_matches( bot.id, @@ -76,10 +76,10 @@ pub async fn list_recent_matches( count, params.before, params.after, - &conn, + &mut conn, ) } - None => matches::list_public_matches(count, params.before, params.after, &conn), + None => matches::list_public_matches(count, params.before, params.after, &mut conn), }; let mut matches = matches_result.map_err(|_| StatusCode::BAD_REQUEST)?; @@ -119,9 +119,9 @@ pub fn match_data_to_api(data: matches::FullMatchData) -> ApiMatch { pub async fn get_match_data( Path(match_id): Path, - conn: DatabaseConnection, + mut conn: DatabaseConnection, ) -> Result, StatusCode> { - let match_data = matches::find_match(match_id, &conn) + let match_data = matches::find_match(match_id, &mut conn) .map_err(|_| StatusCode::NOT_FOUND) .map(match_data_to_api)?; Ok(Json(match_data)) @@ -129,11 +129,11 @@ pub async fn get_match_data( pub async fn get_match_log( Path(match_id): Path, - conn: DatabaseConnection, + mut conn: DatabaseConnection, Extension(config): Extension>, ) -> Result, StatusCode> { let match_base = - matches::find_match_base(match_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?; + matches::find_match_base(match_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; let log_path = PathBuf::from(&config.match_logs_directory).join(&match_base.log_path); let log_contents = std::fs::read(log_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(log_contents) diff --git a/planetwars-server/src/routes/users.rs b/planetwars-server/src/routes/users.rs index 264e5b9..d072d0a 100644 --- a/planetwars-server/src/routes/users.rs +++ b/planetwars-server/src/routes/users.rs @@ -23,13 +23,13 @@ where type Rejection = (StatusCode, String); async fn from_request(req: &mut RequestParts) -> Result { - let conn = DatabaseConnection::from_request(req).await?; + let mut conn = DatabaseConnection::from_request(req).await?; let TypedHeader(Authorization(bearer)) = AuthorizationHeader::from_request(req) .await .map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?; - let (_session, user) = sessions::find_user_by_session(bearer.token(), &conn) + let (_session, user) = sessions::find_user_by_session(bearer.token(), &mut conn) .map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?; Ok(user) @@ -66,7 +66,7 @@ pub enum RegistrationError { } impl RegistrationParams { - fn validate(&self, conn: &DatabaseConnection) -> Result<(), RegistrationError> { + fn validate(&self, conn: &mut DatabaseConnection) -> Result<(), RegistrationError> { let mut errors = Vec::new(); // TODO: do we want to support cased usernames? @@ -95,7 +95,7 @@ impl RegistrationParams { errors.push("that username is not allowed".to_string()); } - if users::find_user_by_name(&self.username, &conn).is_ok() { + if users::find_user_by_name(&self.username, conn).is_ok() { errors.push("username is already taken".to_string()); } @@ -137,16 +137,16 @@ impl IntoResponse for RegistrationError { } pub async fn register( - conn: DatabaseConnection, + mut conn: DatabaseConnection, params: Json, ) -> Result, RegistrationError> { - params.validate(&conn)?; + params.validate(&mut conn)?; let credentials = Credentials { username: ¶ms.username, password: ¶ms.password, }; - let user = users::create_user(&credentials, &conn)?; + let user = users::create_user(&credentials, &mut conn)?; Ok(Json(user.into())) } @@ -156,18 +156,18 @@ pub struct LoginParams { pub password: String, } -pub async fn login(conn: DatabaseConnection, params: Json) -> Response { +pub async fn login(mut conn: DatabaseConnection, params: Json) -> Response { let credentials = Credentials { username: ¶ms.username, password: ¶ms.password, }; // TODO: handle failures - let authenticated = users::authenticate_user(&credentials, &conn); + let authenticated = users::authenticate_user(&credentials, &mut conn); match authenticated { None => StatusCode::FORBIDDEN.into_response(), Some(user) => { - let session = sessions::create_session(&user, &conn); + let session = sessions::create_session(&user, &mut conn); let user_data: UserData = user.into(); let headers = [("Token", &session.token)]; diff --git a/planetwars-server/src/schema.rs b/planetwars-server/src/schema.rs index adc6555..27ebebe 100644 --- a/planetwars-server/src/schema.rs +++ b/planetwars-server/src/schema.rs @@ -1,7 +1,15 @@ // This file is autogenerated by diesel #![allow(unused_imports)] -table! { +// @generated automatically by Diesel CLI. + +pub mod sql_types { + #[derive(diesel::sql_types::SqlType)] + #[diesel(postgres_type(name = "match_state"))] + pub struct MatchState; +} + +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -14,7 +22,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -26,7 +34,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -37,7 +45,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -48,13 +56,14 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; + use super::sql_types::MatchState; matches (id) { id -> Int4, - state -> Match_state, + state -> MatchState, log_path -> Text, created_at -> Timestamp, winner -> Nullable, @@ -63,7 +72,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -73,7 +82,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -84,7 +93,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -96,14 +105,14 @@ table! { } } -joinable!(bots -> users (owner_id)); -joinable!(match_players -> bot_versions (bot_version_id)); -joinable!(match_players -> matches (match_id)); -joinable!(matches -> maps (map_id)); -joinable!(ratings -> bots (bot_id)); -joinable!(sessions -> users (user_id)); +diesel::joinable!(bots -> users (owner_id)); +diesel::joinable!(match_players -> bot_versions (bot_version_id)); +diesel::joinable!(match_players -> matches (match_id)); +diesel::joinable!(matches -> maps (map_id)); +diesel::joinable!(ratings -> bots (bot_id)); +diesel::joinable!(sessions -> users (user_id)); -allow_tables_to_appear_in_same_query!( +diesel::allow_tables_to_appear_in_same_query!( bot_versions, bots, maps, diff --git a/planetwars-server/tests/integration.rs b/planetwars-server/tests/integration.rs index ad63e0e..83de912 100644 --- a/planetwars-server/tests/integration.rs +++ b/planetwars-server/tests/integration.rs @@ -27,7 +27,7 @@ fn create_subdir>(base_path: &Path, p: P) -> io::Result { Ok(dir_path_string) } -fn clear_database(conn: &PgConnection) { +fn clear_database(conn: &mut PgConnection) { diesel::sql_query( "TRUNCATE TABLE bots, @@ -45,20 +45,20 @@ fn clear_database(conn: &PgConnection) { /// Setup a simple text fixture, having simplebot and the hex map. /// This is enough to run a simple match. -fn setup_simple_fixture(db_conn: &PgConnection, config: &GlobalConfig) { +fn setup_simple_fixture(db_conn: &mut PgConnection, config: &GlobalConfig) { let bot = db::bots::create_bot( &db::bots::NewBot { owner_id: None, name: "simplebot", }, - &db_conn, + db_conn, ) .expect("could not create simplebot"); let simplebot_code = std::fs::read_to_string("../simplebot/simplebot.py") .expect("could not read simplebot code"); let _bot_version = - modules::bots::save_code_string(&simplebot_code, Some(bot.id), &db_conn, &config) + modules::bots::save_code_string(&simplebot_code, Some(bot.id), db_conn, &config) .expect("could not save bot version"); std::fs::copy( @@ -71,7 +71,7 @@ fn setup_simple_fixture(db_conn: &PgConnection, config: &GlobalConfig) { name: "hex", file_path: "hex.json", }, - &db_conn, + db_conn, ) .expect("could not save map"); } @@ -119,14 +119,14 @@ impl<'a> TestApp<'a> { async fn with_db_conn(&self, function: F) -> R where - F: FnOnce(&PgConnection) -> R, + F: FnOnce(&mut PgConnection) -> R, { - let db_conn = self + let mut db_conn = self .db_pool .get() .await .expect("could not get db connection"); - function(&db_conn) + function(&mut db_conn) } }