upgrade to diesel 2.0

This commit is contained in:
Ilion Beyst 2022-10-12 22:52:15 +02:00
parent ed016773b1
commit ae57359353
24 changed files with 357 additions and 220 deletions

View file

@ -43,7 +43,7 @@ jobs:
- name: Setup tests - name: Setup tests
run: | run: |
docker pull python:3.10-slim-buster docker pull python:3.10-slim-buster
cargo install diesel_cli --version ^1.4 || true cargo install diesel_cli --version ^2.0 || true
cd planetwars-server cd planetwars-server
diesel migration run --locked-schema diesel migration run --locked-schema
env: env:

149
Cargo.lock generated
View file

@ -29,6 +29,15 @@ dependencies = [
"alloc-no-stdlib", "alloc-no-stdlib",
] ]
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.58" version = "1.0.58"
@ -171,22 +180,21 @@ checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]] [[package]]
name = "bb8" name = "bb8"
version = "0.7.1" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e9f4fa9768efd269499d8fba693260cfc670891cf6de3adc935588447a77cc8" checksum = "1627eccf3aa91405435ba240be23513eeca466b5dc33866422672264de061582"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"futures-channel", "futures-channel",
"futures-util", "futures-util",
"parking_lot 0.11.2", "parking_lot 0.12.1",
"tokio", "tokio",
] ]
[[package]] [[package]]
name = "bb8-diesel" name = "bb8-diesel"
version = "0.2.1" version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/overdrivenpotato/bb8-diesel.git#89b76207bbca35082687c765074f402200fcc51f"
checksum = "79c87e12b0086ff7850d98a19d2a70f5fd901b463412d499514d8e2e16ad0826"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"bb8", "bb8",
@ -372,15 +380,17 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "chrono" name = "chrono"
version = "0.4.19" version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" checksum = "bfd4d1b31faaa3a89d7934dbded3111da0d2ef28e3ebccdb4f0179f5929d1ef1"
dependencies = [ dependencies = [
"libc", "iana-time-zone",
"js-sys",
"num-integer", "num-integer",
"num-traits", "num-traits",
"serde", "serde",
"time", "time",
"wasm-bindgen",
"winapi", "winapi",
] ]
@ -441,6 +451,16 @@ dependencies = [
"cc", "cc",
] ]
[[package]]
name = "codespan-reporting"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
dependencies = [
"termcolor",
"unicode-width",
]
[[package]] [[package]]
name = "config" name = "config"
version = "0.12.0" version = "0.12.0"
@ -520,6 +540,50 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "cxx"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19f39818dcfc97d45b03953c1292efc4e80954e1583c4aa770bac1383e2310a4"
dependencies = [
"cc",
"cxxbridge-flags",
"cxxbridge-macro",
"link-cplusplus",
]
[[package]]
name = "cxx-build"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e580d70777c116df50c390d1211993f62d40302881e54d4b79727acb83d0199"
dependencies = [
"cc",
"codespan-reporting",
"once_cell",
"proc-macro2",
"quote",
"scratch",
"syn",
]
[[package]]
name = "cxxbridge-flags"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56a46460b88d1cec95112c8c363f0e2c39afdb237f60583b0b36343bf627ea9c"
[[package]]
name = "cxxbridge-macro"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "747b608fecf06b0d72d440f27acc99288207324b793be2c17991839f3d4995ea"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "darling" name = "darling"
version = "0.13.1" version = "0.13.1"
@ -557,23 +621,24 @@ dependencies = [
[[package]] [[package]]
name = "diesel" name = "diesel"
version = "1.4.8" version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b28135ecf6b7d446b43e27e225622a038cc4e2930a1022f51cdb97ada19b8e4d" checksum = "68c186a7418a2aac330bb76cde82f16c36b03a66fb91db32d20214311f9f6545"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"byteorder", "byteorder",
"chrono", "chrono",
"diesel_derives", "diesel_derives",
"itoa 1.0.2",
"pq-sys", "pq-sys",
"r2d2", "r2d2",
] ]
[[package]] [[package]]
name = "diesel-derive-enum" name = "diesel-derive-enum"
version = "1.1.2" version = "2.0.0-rc.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8910921b014e2af16298f006de12aa08af894b71f0f49a486ab6d74b17bbed" checksum = "5f28fc9f5bf184ebc58ad9105dede024981e2303fe878a0fe16557f3a979064a"
dependencies = [ dependencies = [
"heck", "heck",
"proc-macro2", "proc-macro2",
@ -583,10 +648,11 @@ dependencies = [
[[package]] [[package]]
name = "diesel_derives" name = "diesel_derives"
version = "1.4.1" version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45f5098f628d02a7a0f68ddba586fb61e80edec3bdc1be3b921f4ceec60858d3" checksum = "143b758c91dbc3fe1fdcb0dba5bd13276c6a66422f2ef5795b58488248a310aa"
dependencies = [ dependencies = [
"proc-macro-error",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn",
@ -1014,6 +1080,30 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "iana-time-zone"
version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5a6ef98976b22b3b7f2f3a806f858cb862044cfa66805aa3ad84cb3d3b785ed"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"winapi",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde6edd6cef363e9359ed3c98ba64590ba9eecba2293eb5a723ab32aee8926aa"
dependencies = [
"cxx",
"cxx-build",
]
[[package]] [[package]]
name = "ident_case" name = "ident_case"
version = "1.0.1" version = "1.0.1"
@ -1112,6 +1202,15 @@ version = "0.2.126"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836"
[[package]]
name = "link-cplusplus"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9272ab7b96c9046fbc5bc56c06c117cb639fe2d509df0c421cad82d2915cf369"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.4" version = "0.5.4"
@ -1195,9 +1294,9 @@ dependencies = [
[[package]] [[package]]
name = "mio" name = "mio"
version = "0.8.3" version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "713d550d9b44d89174e066b7a6217ae06234c10cb47819a88290d2b353c31799" checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf"
dependencies = [ dependencies = [
"libc", "libc",
"log", "log",
@ -1999,6 +2098,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "scratch"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8132065adcfd6e02db789d9285a0deb2f3fcb04002865ab67d5fb103533898"
[[package]] [[package]]
name = "sct" name = "sct"
version = "0.7.0" version = "0.7.0"
@ -2292,16 +2397,16 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.19.2" version = "1.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" checksum = "a9e03c497dc955702ba729190dc4aac6f2a0ce97f913e5b1b5912fc5039d9099"
dependencies = [ dependencies = [
"autocfg 1.1.0",
"bytes", "bytes",
"libc", "libc",
"memchr", "memchr",
"mio", "mio",
"num_cpus", "num_cpus",
"once_cell",
"parking_lot 0.12.1", "parking_lot 0.12.1",
"pin-project-lite", "pin-project-lite",
"signal-hook-registry", "signal-hook-registry",
@ -2592,6 +2697,12 @@ dependencies = [
"tinyvec", "tinyvec",
] ]
[[package]]
name = "unicode-width"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.7.1" version = "0.7.1"

View file

@ -15,15 +15,15 @@ path = "src/cli.rs"
[dependencies] [dependencies]
futures = "0.3" futures = "0.3"
tokio = { version = "1.15", features = ["full"] } tokio = { version = "1.21", features = ["full"] }
tokio-stream = "0.1.9" tokio-stream = "0.1.9"
hyper = "0.14" hyper = "0.14"
tower-http = { version = "0.3.4", features = ["full"] } tower-http = { version = "0.3.4", features = ["full"] }
axum = { version = "0.5", features = ["json", "headers", "multipart"] } axum = { version = "0.5", features = ["json", "headers", "multipart"] }
diesel = { version = "1.4.4", features = ["postgres", "chrono"] } diesel = { version = "2.0", features = ["postgres", "chrono"] }
diesel-derive-enum = { version = "1.1", features = ["postgres"] } diesel-derive-enum = { version = "2.0.0-rc.0", features = ["postgres"] }
bb8 = "0.7" bb8 = "0.8"
bb8-diesel = "0.2" bb8-diesel = { git = "https://github.com/overdrivenpotato/bb8-diesel.git" }
dotenv = "0.15.0" dotenv = "0.15.0"
rust-argon2 = "0.8" rust-argon2 = "0.8"
rand = "0.8.4" rand = "0.8.4"

View file

@ -38,12 +38,12 @@ impl SetPassword {
let global_config = get_config().unwrap(); let global_config = get_config().unwrap();
let pool = create_db_pool(&global_config).await; let pool = create_db_pool(&global_config).await;
let conn = pool.get().await.expect("could not get database connection"); let mut conn = pool.get().await.expect("could not get database connection");
let credentials = db::users::Credentials { let credentials = db::users::Credentials {
username: &self.username, username: &self.username,
password: &self.new_password, password: &self.new_password,
}; };
db::users::set_user_password(credentials, &conn).expect("could not set password"); db::users::set_user_password(credentials, &mut conn).expect("could not set password");
} }
} }

View file

@ -5,7 +5,7 @@ use crate::schema::{bot_versions, bots};
use chrono; use chrono;
#[derive(Insertable)] #[derive(Insertable)]
#[table_name = "bots"] #[diesel(table_name = bots)]
pub struct NewBot<'a> { pub struct NewBot<'a> {
pub owner_id: Option<i32>, pub owner_id: Option<i32>,
pub name: &'a str, pub name: &'a str,
@ -19,29 +19,29 @@ pub struct Bot {
pub active_version: Option<i32>, pub active_version: Option<i32>,
} }
pub fn create_bot(new_bot: &NewBot, conn: &PgConnection) -> QueryResult<Bot> { pub fn create_bot(new_bot: &NewBot, conn: &mut PgConnection) -> QueryResult<Bot> {
diesel::insert_into(bots::table) diesel::insert_into(bots::table)
.values(new_bot) .values(new_bot)
.get_result(conn) .get_result(conn)
} }
pub fn find_bot(id: i32, conn: &PgConnection) -> QueryResult<Bot> { pub fn find_bot(id: i32, conn: &mut PgConnection) -> QueryResult<Bot> {
bots::table.find(id).first(conn) bots::table.find(id).first(conn)
} }
pub fn find_bots_by_owner(owner_id: i32, conn: &PgConnection) -> QueryResult<Vec<Bot>> { pub fn find_bots_by_owner(owner_id: i32, conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table bots::table
.filter(bots::owner_id.eq(owner_id)) .filter(bots::owner_id.eq(owner_id))
.get_results(conn) .get_results(conn)
} }
pub fn find_bot_by_name(name: &str, conn: &PgConnection) -> QueryResult<Bot> { pub fn find_bot_by_name(name: &str, conn: &mut PgConnection) -> QueryResult<Bot> {
bots::table.filter(bots::name.eq(name)).first(conn) bots::table.filter(bots::name.eq(name)).first(conn)
} }
pub fn find_bot_with_version_by_name( pub fn find_bot_with_version_by_name(
bot_name: &str, bot_name: &str,
conn: &PgConnection, conn: &mut PgConnection,
) -> QueryResult<(Bot, BotVersion)> { ) -> QueryResult<(Bot, BotVersion)> {
bots::table bots::table
.inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable()))) .inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable())))
@ -49,26 +49,28 @@ pub fn find_bot_with_version_by_name(
.first(conn) .first(conn)
} }
pub fn all_active_bots_with_version(conn: &PgConnection) -> QueryResult<Vec<(Bot, BotVersion)>> { pub fn all_active_bots_with_version(
conn: &mut PgConnection,
) -> QueryResult<Vec<(Bot, BotVersion)>> {
bots::table bots::table
.inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable()))) .inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable())))
.get_results(conn) .get_results(conn)
} }
pub fn find_all_bots(conn: &PgConnection) -> QueryResult<Vec<Bot>> { pub fn find_all_bots(conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table.get_results(conn) bots::table.get_results(conn)
} }
/// Find all bots that have an associated active version. /// Find all bots that have an associated active version.
/// These are the bots that can be run. /// These are the bots that can be run.
pub fn find_active_bots(conn: &PgConnection) -> QueryResult<Vec<Bot>> { pub fn find_active_bots(conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table bots::table
.filter(bots::active_version.is_not_null()) .filter(bots::active_version.is_not_null())
.get_results(conn) .get_results(conn)
} }
#[derive(Insertable)] #[derive(Insertable)]
#[table_name = "bot_versions"] #[diesel(table_name = bot_versions)]
pub struct NewBotVersion<'a> { pub struct NewBotVersion<'a> {
pub bot_id: Option<i32>, pub bot_id: Option<i32>,
pub code_bundle_path: Option<&'a str>, pub code_bundle_path: Option<&'a str>,
@ -86,7 +88,7 @@ pub struct BotVersion {
pub fn create_bot_version( pub fn create_bot_version(
new_bot_version: &NewBotVersion, new_bot_version: &NewBotVersion,
conn: &PgConnection, conn: &mut PgConnection,
) -> QueryResult<BotVersion> { ) -> QueryResult<BotVersion> {
diesel::insert_into(bot_versions::table) diesel::insert_into(bot_versions::table)
.values(new_bot_version) .values(new_bot_version)
@ -96,7 +98,7 @@ pub fn create_bot_version(
pub fn set_active_version( pub fn set_active_version(
bot_id: i32, bot_id: i32,
version_id: Option<i32>, version_id: Option<i32>,
conn: &PgConnection, conn: &mut PgConnection,
) -> QueryResult<()> { ) -> QueryResult<()> {
diesel::update(bots::table.filter(bots::id.eq(bot_id))) diesel::update(bots::table.filter(bots::id.eq(bot_id)))
.set(bots::active_version.eq(version_id)) .set(bots::active_version.eq(version_id))
@ -104,13 +106,13 @@ pub fn set_active_version(
Ok(()) Ok(())
} }
pub fn find_bot_version(version_id: i32, conn: &PgConnection) -> QueryResult<BotVersion> { pub fn find_bot_version(version_id: i32, conn: &mut PgConnection) -> QueryResult<BotVersion> {
bot_versions::table bot_versions::table
.filter(bot_versions::id.eq(version_id)) .filter(bot_versions::id.eq(version_id))
.first(conn) .first(conn)
} }
pub fn find_bot_versions(bot_id: i32, conn: &PgConnection) -> QueryResult<Vec<BotVersion>> { pub fn find_bot_versions(bot_id: i32, conn: &mut PgConnection) -> QueryResult<Vec<BotVersion>> {
bot_versions::table bot_versions::table
.filter(bot_versions::bot_id.eq(bot_id)) .filter(bot_versions::bot_id.eq(bot_id))
.get_results(conn) .get_results(conn)

View file

@ -3,7 +3,7 @@ use diesel::prelude::*;
use crate::schema::maps; use crate::schema::maps;
#[derive(Insertable)] #[derive(Insertable)]
#[table_name = "maps"] #[diesel(table_name = maps)]
pub struct NewMap<'a> { pub struct NewMap<'a> {
pub name: &'a str, pub name: &'a str,
pub file_path: &'a str, pub file_path: &'a str,
@ -16,20 +16,20 @@ pub struct Map {
pub file_path: String, pub file_path: String,
} }
pub fn create_map(new_map: NewMap, conn: &PgConnection) -> QueryResult<Map> { pub fn create_map(new_map: NewMap, conn: &mut PgConnection) -> QueryResult<Map> {
diesel::insert_into(maps::table) diesel::insert_into(maps::table)
.values(new_map) .values(new_map)
.get_result(conn) .get_result(conn)
} }
pub fn find_map(id: i32, conn: &PgConnection) -> QueryResult<Map> { pub fn find_map(id: i32, conn: &mut PgConnection) -> QueryResult<Map> {
maps::table.find(id).get_result(conn) maps::table.find(id).get_result(conn)
} }
pub fn find_map_by_name(name: &str, conn: &PgConnection) -> QueryResult<Map> { pub fn find_map_by_name(name: &str, conn: &mut PgConnection) -> QueryResult<Map> {
maps::table.filter(maps::name.eq(name)).first(conn) maps::table.filter(maps::name.eq(name)).first(conn)
} }
pub fn list_maps(conn: &PgConnection) -> QueryResult<Vec<Map>> { pub fn list_maps(conn: &mut PgConnection) -> QueryResult<Vec<Map>> {
maps::table.get_results(conn) maps::table.get_results(conn)
} }

View file

@ -1,9 +1,6 @@
pub use crate::db_types::MatchState; pub use crate::db_types::MatchState;
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use diesel::associations::BelongsTo; use diesel::associations::BelongsTo;
use diesel::pg::Pg;
use diesel::query_builder::BoxedSelectStatement;
use diesel::query_source::{AppearsInFromClause, Once};
use diesel::sql_types::*; use diesel::sql_types::*;
use diesel::{ use diesel::{
BelongingToDsl, ExpressionMethods, JoinOnDsl, NullableExpressionMethods, QueryDsl, RunQueryDsl, BelongingToDsl, ExpressionMethods, JoinOnDsl, NullableExpressionMethods, QueryDsl, RunQueryDsl,
@ -18,7 +15,7 @@ use super::bots::{Bot, BotVersion};
use super::maps::Map; use super::maps::Map;
#[derive(Insertable)] #[derive(Insertable)]
#[table_name = "matches"] #[diesel(table_name = matches)]
pub struct NewMatch<'a> { pub struct NewMatch<'a> {
pub state: MatchState, pub state: MatchState,
pub log_path: &'a str, pub log_path: &'a str,
@ -27,7 +24,7 @@ pub struct NewMatch<'a> {
} }
#[derive(Insertable)] #[derive(Insertable)]
#[table_name = "match_players"] #[diesel(table_name = match_players)]
pub struct NewMatchPlayer { pub struct NewMatchPlayer {
/// id of the match this player is in /// id of the match this player is in
pub match_id: i32, pub match_id: i32,
@ -38,7 +35,7 @@ pub struct NewMatchPlayer {
} }
#[derive(Queryable, Identifiable)] #[derive(Queryable, Identifiable)]
#[table_name = "matches"] #[diesel(table_name = matches)]
pub struct MatchBase { pub struct MatchBase {
pub id: i32, pub id: i32,
pub state: MatchState, pub state: MatchState,
@ -50,8 +47,8 @@ pub struct MatchBase {
} }
#[derive(Queryable, Identifiable, Associations, Clone)] #[derive(Queryable, Identifiable, Associations, Clone)]
#[primary_key(match_id, player_id)] #[diesel(primary_key(match_id, player_id))]
#[belongs_to(MatchBase, foreign_key = "match_id")] #[diesel(belongs_to(MatchBase, foreign_key = match_id))]
pub struct MatchPlayer { pub struct MatchPlayer {
pub match_id: i32, pub match_id: i32,
pub player_id: i32, pub player_id: i32,
@ -65,9 +62,9 @@ pub struct MatchPlayerData {
pub fn create_match( pub fn create_match(
new_match_base: &NewMatch, new_match_base: &NewMatch,
new_match_players: &[MatchPlayerData], new_match_players: &[MatchPlayerData],
conn: &PgConnection, conn: &mut PgConnection,
) -> QueryResult<MatchData> { ) -> QueryResult<MatchData> {
conn.transaction(|| { conn.transaction(|conn| {
let match_base = diesel::insert_into(matches::table) let match_base = diesel::insert_into(matches::table)
.values(new_match_base) .values(new_match_base)
.get_result::<MatchBase>(conn)?; .get_result::<MatchBase>(conn)?;
@ -101,7 +98,7 @@ pub struct MatchData {
/// Add player information to MatchBase instances /// Add player information to MatchBase instances
fn fetch_full_match_data( fn fetch_full_match_data(
matches: Vec<MatchBase>, matches: Vec<MatchBase>,
conn: &PgConnection, conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> { ) -> QueryResult<Vec<FullMatchData>> {
let map_ids: HashSet<i32> = matches.iter().filter_map(|m| m.map_id).collect(); let map_ids: HashSet<i32> = matches.iter().filter_map(|m| m.map_id).collect();
@ -140,8 +137,8 @@ fn fetch_full_match_data(
} }
// TODO: this method should disappear // TODO: this method should disappear
pub fn list_matches(amount: i64, conn: &PgConnection) -> QueryResult<Vec<FullMatchData>> { pub fn list_matches(amount: i64, conn: &mut PgConnection) -> QueryResult<Vec<FullMatchData>> {
conn.transaction(|| { conn.transaction(|conn| {
let matches = matches::table let matches = matches::table
.filter(matches::state.eq(MatchState::Finished)) .filter(matches::state.eq(MatchState::Finished))
.order_by(matches::created_at.desc()) .order_by(matches::created_at.desc())
@ -164,17 +161,32 @@ pub fn list_public_matches(
amount: i64, amount: i64,
before: Option<NaiveDateTime>, before: Option<NaiveDateTime>,
after: Option<NaiveDateTime>, after: Option<NaiveDateTime>,
conn: &PgConnection, conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> { ) -> QueryResult<Vec<FullMatchData>> {
conn.transaction(|| { conn.transaction(|conn| {
// TODO: how can this common logic be abstracted? // TODO: how can this common logic be abstracted?
let query = matches::table let mut query = matches::table
.filter(matches::state.eq(MatchState::Finished)) .filter(matches::state.eq(MatchState::Finished))
.filter(matches::is_public.eq(true)) .filter(matches::is_public.eq(true))
.into_boxed(); .into_boxed();
let matches = // TODO: how to remove this duplication?
select_matches_page(query, amount, before, after).get_results::<MatchBase>(conn)?; query = match (before, after) {
(None, None) => query.order_by(matches::created_at.desc()),
(Some(before), None) => query
.filter(matches::created_at.lt(before))
.order_by(matches::created_at.desc()),
(None, Some(after)) => query
.filter(matches::created_at.gt(after))
.order_by(matches::created_at.asc()),
(Some(before), Some(after)) => query
.filter(matches::created_at.lt(before))
.filter(matches::created_at.gt(after))
.order_by(matches::created_at.desc()),
};
query = query.limit(amount);
let matches = query.get_results::<MatchBase>(conn)?;
fetch_full_match_data(matches, conn) fetch_full_match_data(matches, conn)
}) })
} }
@ -185,7 +197,7 @@ pub fn list_bot_matches(
amount: i64, amount: i64,
before: Option<NaiveDateTime>, before: Option<NaiveDateTime>,
after: Option<NaiveDateTime>, after: Option<NaiveDateTime>,
conn: &PgConnection, conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> { ) -> QueryResult<Vec<FullMatchData>> {
let mut query = matches::table let mut query = matches::table
.filter(matches::state.eq(MatchState::Finished)) .filter(matches::state.eq(MatchState::Finished))
@ -211,22 +223,8 @@ pub fn list_bot_matches(
}; };
} }
let matches = // TODO: how to remove this duplication?
select_matches_page(query, amount, before, after).get_results::<MatchBase>(conn)?; query = match (before, after) {
fetch_full_match_data(matches, conn)
}
fn select_matches_page<QS>(
query: BoxedSelectStatement<'static, matches::SqlType, QS, Pg>,
amount: i64,
before: Option<NaiveDateTime>,
after: Option<NaiveDateTime>,
) -> BoxedSelectStatement<'static, matches::SqlType, QS, Pg>
where
QS: AppearsInFromClause<matches::table, Count = Once>,
{
// TODO: this is not nice. Replace this with proper cursor logic.
match (before, after) {
(None, None) => query.order_by(matches::created_at.desc()), (None, None) => query.order_by(matches::created_at.desc()),
(Some(before), None) => query (Some(before), None) => query
.filter(matches::created_at.lt(before)) .filter(matches::created_at.lt(before))
@ -238,8 +236,11 @@ where
.filter(matches::created_at.lt(before)) .filter(matches::created_at.lt(before))
.filter(matches::created_at.gt(after)) .filter(matches::created_at.gt(after))
.order_by(matches::created_at.desc()), .order_by(matches::created_at.desc()),
} };
.limit(amount) query = query.limit(amount);
let matches = query.get_results::<MatchBase>(conn)?;
fetch_full_match_data(matches, conn)
} }
// TODO: maybe unify this with matchdata? // TODO: maybe unify this with matchdata?
@ -270,8 +271,8 @@ impl BelongsTo<MatchBase> for FullMatchPlayerData {
} }
} }
pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult<FullMatchData> { pub fn find_match(id: i32, conn: &mut PgConnection) -> QueryResult<FullMatchData> {
conn.transaction(|| { conn.transaction(|conn| {
let match_base = matches::table.find(id).get_result::<MatchBase>(conn)?; let match_base = matches::table.find(id).get_result::<MatchBase>(conn)?;
let map = match match_base.map_id { let map = match match_base.map_id {
@ -298,7 +299,7 @@ pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult<FullMatchData> {
}) })
} }
pub fn find_match_base(id: i32, conn: &PgConnection) -> QueryResult<MatchBase> { pub fn find_match_base(id: i32, conn: &mut PgConnection) -> QueryResult<MatchBase> {
matches::table.find(id).get_result::<MatchBase>(conn) matches::table.find(id).get_result::<MatchBase>(conn)
} }
@ -306,7 +307,7 @@ pub enum MatchResult {
Finished { winner: Option<i32> }, Finished { winner: Option<i32> },
} }
pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> QueryResult<()> { pub fn save_match_result(id: i32, result: MatchResult, conn: &mut PgConnection) -> QueryResult<()> {
let MatchResult::Finished { winner } = result; let MatchResult::Finished { winner } = result;
diesel::update(matches::table.find(id)) diesel::update(matches::table.find(id))
@ -320,17 +321,20 @@ pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> Q
#[derive(QueryableByName)] #[derive(QueryableByName)]
pub struct BotStatsRecord { pub struct BotStatsRecord {
#[sql_type = "Text"] #[diesel(sql_type = Text)]
pub opponent: String, pub opponent: String,
#[sql_type = "Text"] #[diesel(sql_type = Text)]
pub map: String, pub map: String,
#[sql_type = "Nullable<Bool>"] #[diesel(sql_type = Nullable<Bool>)]
pub win: Option<bool>, pub win: Option<bool>,
#[sql_type = "Int8"] #[diesel(sql_type = Int8)]
pub count: i64, pub count: i64,
} }
pub fn fetch_bot_stats(bot_name: &str, db_conn: &PgConnection) -> QueryResult<Vec<BotStatsRecord>> { pub fn fetch_bot_stats(
bot_name: &str,
db_conn: &mut PgConnection,
) -> QueryResult<Vec<BotStatsRecord>> {
diesel::sql_query( diesel::sql_query(
" "
SELECT opponent, map, win, COUNT(*) as count SELECT opponent, map, win, COUNT(*) as count

View file

@ -10,7 +10,7 @@ pub struct Rating {
pub rating: f64, pub rating: f64,
} }
pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult<Option<f64>> { pub fn get_rating(bot_id: i32, db_conn: &mut PgConnection) -> QueryResult<Option<f64>> {
ratings::table ratings::table
.filter(ratings::bot_id.eq(bot_id)) .filter(ratings::bot_id.eq(bot_id))
.select(ratings::rating) .select(ratings::rating)
@ -18,7 +18,7 @@ pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult<Option<f64
.optional() .optional()
} }
pub fn set_rating(bot_id: i32, rating: f64, db_conn: &PgConnection) -> QueryResult<usize> { pub fn set_rating(bot_id: i32, rating: f64, db_conn: &mut PgConnection) -> QueryResult<usize> {
diesel::insert_into(ratings::table) diesel::insert_into(ratings::table)
.values(Rating { bot_id, rating }) .values(Rating { bot_id, rating })
.on_conflict(ratings::bot_id) .on_conflict(ratings::bot_id)
@ -40,7 +40,7 @@ pub struct RankedBot {
pub rating: f64, pub rating: f64,
} }
pub fn get_bot_ranking(db_conn: &PgConnection) -> QueryResult<Vec<RankedBot>> { pub fn get_bot_ranking(db_conn: &mut PgConnection) -> QueryResult<Vec<RankedBot>> {
bots::table bots::table
.left_join(users::table) .left_join(users::table)
.inner_join(ratings::table) .inner_join(ratings::table)

View file

@ -6,7 +6,7 @@ use diesel::{insert_into, prelude::*, Insertable, RunQueryDsl};
use rand::{self, Rng}; use rand::{self, Rng};
#[derive(Insertable)] #[derive(Insertable)]
#[table_name = "sessions"] #[diesel(table_name = sessions)]
struct NewSession { struct NewSession {
token: String, token: String,
user_id: i32, user_id: i32,
@ -19,7 +19,7 @@ pub struct Session {
pub token: String, pub token: String,
} }
pub fn create_session(user: &User, conn: &PgConnection) -> Session { pub fn create_session(user: &User, conn: &mut PgConnection) -> Session {
let new_session = NewSession { let new_session = NewSession {
token: gen_session_token(), token: gen_session_token(),
user_id: user.id, user_id: user.id,
@ -31,7 +31,7 @@ pub fn create_session(user: &User, conn: &PgConnection) -> Session {
.unwrap() .unwrap()
} }
pub fn find_user_by_session(token: &str, conn: &PgConnection) -> QueryResult<(Session, User)> { pub fn find_user_by_session(token: &str, conn: &mut PgConnection) -> QueryResult<(Session, User)> {
sessions::table sessions::table
.inner_join(users::table) .inner_join(users::table)
.filter(sessions::token.eq(&token)) .filter(sessions::token.eq(&token))

View file

@ -11,7 +11,7 @@ pub struct Credentials<'a> {
} }
#[derive(Insertable)] #[derive(Insertable)]
#[table_name = "users"] #[diesel(table_name = users)]
pub struct NewUser<'a> { pub struct NewUser<'a> {
pub username: &'a str, pub username: &'a str,
pub password_hash: &'a [u8], pub password_hash: &'a [u8],
@ -50,7 +50,7 @@ pub fn hash_password(password: &str) -> (Vec<u8>, [u8; 32]) {
(hash, salt) (hash, salt)
} }
pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResult<User> { pub fn create_user(credentials: &Credentials, conn: &mut PgConnection) -> QueryResult<User> {
let (hash, salt) = hash_password(&credentials.password); let (hash, salt) = hash_password(&credentials.password);
let new_user = NewUser { let new_user = NewUser {
@ -63,19 +63,19 @@ pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResul
.get_result::<User>(conn) .get_result::<User>(conn)
} }
pub fn find_user(user_id: i32, db_conn: &PgConnection) -> QueryResult<User> { pub fn find_user(user_id: i32, db_conn: &mut PgConnection) -> QueryResult<User> {
users::table users::table
.filter(users::id.eq(user_id)) .filter(users::id.eq(user_id))
.first::<User>(db_conn) .first::<User>(db_conn)
} }
pub fn find_user_by_name(username: &str, db_conn: &PgConnection) -> QueryResult<User> { pub fn find_user_by_name(username: &str, db_conn: &mut PgConnection) -> QueryResult<User> {
users::table users::table
.filter(users::username.eq(username)) .filter(users::username.eq(username))
.first::<User>(db_conn) .first::<User>(db_conn)
} }
pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> QueryResult<()> { pub fn set_user_password(credentials: Credentials, db_conn: &mut PgConnection) -> QueryResult<()> {
let (hash, salt) = hash_password(&credentials.password); let (hash, salt) = hash_password(&credentials.password);
let n_changes = diesel::update(users::table.filter(users::username.eq(&credentials.username))) let n_changes = diesel::update(users::table.filter(users::username.eq(&credentials.username)))
@ -91,7 +91,7 @@ pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> Qu
} }
} }
pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option<User> { pub fn authenticate_user(credentials: &Credentials, db_conn: &mut PgConnection) -> Option<User> {
find_user_by_name(credentials.username, db_conn) find_user_by_name(credentials.username, db_conn)
.optional() .optional()
.unwrap() .unwrap()

View file

@ -2,7 +2,7 @@ use diesel_derive_enum::DbEnum;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(DbEnum, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] #[derive(DbEnum, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[DieselType = "Match_state"] #[DieselTypePath = "crate::schema::sql_types::MatchState"]
pub enum MatchState { pub enum MatchState {
Playing, Playing,

View file

@ -8,7 +8,7 @@ pub mod routes;
pub mod schema; pub mod schema;
pub mod util; pub mod util;
use std::ops::Deref; use std::ops::{Deref, DerefMut};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::{fs, net::SocketAddr}; use std::{fs, net::SocketAddr};
@ -70,9 +70,9 @@ pub struct GlobalConfig {
const SIMPLEBOT_PATH: &str = "../simplebot/simplebot.py"; const SIMPLEBOT_PATH: &str = "../simplebot/simplebot.py";
pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) { pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) {
let conn = pool.get().await.expect("could not get database connection"); let mut conn = pool.get().await.expect("could not get database connection");
// This transaction is expected to fail when simplebot already exists. // This transaction is expected to fail when simplebot already exists.
let _res = conn.transaction::<(), diesel::result::Error, _>(|| { let _res = conn.transaction::<(), diesel::result::Error, _>(|conn| {
use db::bots::NewBot; use db::bots::NewBot;
let new_bot = NewBot { let new_bot = NewBot {
@ -80,12 +80,12 @@ pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) {
owner_id: None, owner_id: None,
}; };
let simplebot = db::bots::create_bot(&new_bot, &conn)?; let simplebot = db::bots::create_bot(&new_bot, conn)?;
let simplebot_code = let simplebot_code =
std::fs::read_to_string(SIMPLEBOT_PATH).expect("could not read simplebot code"); std::fs::read_to_string(SIMPLEBOT_PATH).expect("could not read simplebot code");
modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), &conn, config)?; modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), conn, config)?;
println!("initialized simplebot"); println!("initialized simplebot");
@ -209,6 +209,12 @@ impl Deref for DatabaseConnection {
} }
} }
impl DerefMut for DatabaseConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for DatabaseConnection impl<B> FromRequest<B> for DatabaseConnection
where where

View file

@ -9,7 +9,7 @@ use crate::{db, util::gen_alphanumeric, GlobalConfig};
pub fn save_code_string( pub fn save_code_string(
bot_code: &str, bot_code: &str,
bot_id: Option<i32>, bot_id: Option<i32>,
conn: &PgConnection, conn: &mut PgConnection,
config: &GlobalConfig, config: &GlobalConfig,
) -> QueryResult<db::bots::BotVersion> { ) -> QueryResult<db::bots::BotVersion> {
let bundle_name = gen_alphanumeric(16); let bundle_name = gen_alphanumeric(16);

View file

@ -149,19 +149,19 @@ impl pb::client_api_service_server::ClientApiService for ClientApiServer {
req: Request<pb::CreateMatchRequest>, req: Request<pb::CreateMatchRequest>,
) -> Result<Response<pb::CreateMatchResponse>, Status> { ) -> Result<Response<pb::CreateMatchResponse>, Status> {
// TODO: unify with matchrunner module // TODO: unify with matchrunner module
let conn = self.conn_pool.get().await.unwrap(); let mut conn = self.conn_pool.get().await.unwrap();
let match_request = req.get_ref(); let match_request = req.get_ref();
let (opponent_bot, opponent_bot_version) = let (opponent_bot, opponent_bot_version) =
db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &conn) db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &mut conn)
.map_err(|_| Status::not_found("opponent not found"))?; .map_err(|_| Status::not_found("opponent not found"))?;
let map_name = match match_request.map_name.as_str() { let map_name = match match_request.map_name.as_str() {
"" => "hex", "" => "hex",
name => name, name => name,
}; };
let map = db::maps::find_map_by_name(map_name, &conn) let map = db::maps::find_map_by_name(map_name, &mut conn)
.map_err(|_| Status::not_found("map not found"))?; .map_err(|_| Status::not_found("map not found"))?;
let player_key = gen_alphanumeric(32); let player_key = gen_alphanumeric(32);

View file

@ -80,8 +80,8 @@ impl RunMatch {
let match_data = { let match_data = {
// TODO: it would be nice to get an already-open connection here when possible. // TODO: it would be nice to get an already-open connection here when possible.
// Maybe we need an additional abstraction, bundling a connection and connection pool? // Maybe we need an additional abstraction, bundling a connection and connection pool?
let db_conn = conn_pool.get().await.expect("could not get a connection"); let mut db_conn = conn_pool.get().await.expect("could not get a connection");
self.store_in_database(&db_conn)? self.store_in_database(&mut db_conn)?
}; };
let runner_config = self.into_runner_config(); let runner_config = self.into_runner_config();
@ -90,7 +90,7 @@ impl RunMatch {
Ok((match_data, handle)) Ok((match_data, handle))
} }
fn store_in_database(&self, db_conn: &PgConnection) -> QueryResult<MatchData> { fn store_in_database(&self, db_conn: &mut PgConnection) -> QueryResult<MatchData> {
let new_match_data = db::matches::NewMatch { let new_match_data = db::matches::NewMatch {
state: db::matches::MatchState::Playing, state: db::matches::MatchState::Playing,
log_path: &self.log_file_name, log_path: &self.log_file_name,
@ -167,7 +167,7 @@ async fn run_match_task(
let outcome = runner::run_match(match_config).await; let outcome = runner::run_match(match_config).await;
// update match state in database // update match state in database
let conn = connection_pool let mut conn = connection_pool
.get() .get()
.await .await
.expect("could not get database connection"); .expect("could not get database connection");
@ -176,7 +176,8 @@ async fn run_match_task(
winner: outcome.winner.map(|w| (w - 1) as i32), // player numbers in matchrunner start at 1 winner: outcome.winner.map(|w| (w - 1) as i32), // player numbers in matchrunner start at 1
}; };
db::matches::save_match_result(match_id, result, &conn).expect("could not save match result"); db::matches::save_match_result(match_id, result, &mut conn)
.expect("could not save match result");
outcome outcome
} }

View file

@ -20,13 +20,14 @@ pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) {
// TODO: make this configurable // TODO: make this configurable
// play at most one match every n seconds // play at most one match every n seconds
let mut interval = tokio::time::interval(Duration::from_secs(RANKER_INTERVAL)); let mut interval = tokio::time::interval(Duration::from_secs(RANKER_INTERVAL));
let db_conn = db_pool let mut db_conn = db_pool
.get() .get()
.await .await
.expect("could not get database connection"); .expect("could not get database connection");
loop { loop {
interval.tick().await; interval.tick().await;
let bots = db::bots::all_active_bots_with_version(&db_conn).expect("could not load bots"); let bots =
db::bots::all_active_bots_with_version(&mut db_conn).expect("could not load bots");
if bots.len() < 2 { if bots.len() < 2 {
// not enough bots to play a match // not enough bots to play a match
continue; continue;
@ -37,14 +38,14 @@ pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) {
.cloned() .cloned()
.collect(); .collect();
let maps = db::maps::list_maps(&db_conn).expect("could not load map"); let maps = db::maps::list_maps(&mut db_conn).expect("could not load map");
let map = match maps.choose(&mut rand::thread_rng()).cloned() { let map = match maps.choose(&mut rand::thread_rng()).cloned() {
None => continue, // no maps available None => continue, // no maps available
Some(map) => map, Some(map) => map,
}; };
play_ranking_match(config.clone(), map, selected_bots, db_pool.clone()).await; play_ranking_match(config.clone(), map, selected_bots, db_pool.clone()).await;
recalculate_ratings(&db_conn).expect("could not recalculate ratings"); recalculate_ratings(&mut db_conn).expect("could not recalculate ratings");
} }
} }
@ -71,7 +72,7 @@ async fn play_ranking_match(
let _outcome = handle.await; let _outcome = handle.await;
} }
fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> { fn recalculate_ratings(db_conn: &mut PgConnection) -> QueryResult<()> {
let start = Instant::now(); let start = Instant::now();
let match_stats = fetch_match_stats(db_conn)?; let match_stats = fetch_match_stats(db_conn)?;
let ratings = estimate_ratings_from_stats(match_stats); let ratings = estimate_ratings_from_stats(match_stats);
@ -91,7 +92,7 @@ struct MatchStats {
num_matches: usize, num_matches: usize,
} }
fn fetch_match_stats(db_conn: &PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> { fn fetch_match_stats(db_conn: &mut PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> {
let matches = db::matches::list_matches(RANKER_NUM_MATCHES, db_conn)?; let matches = db::matches::list_matches(RANKER_NUM_MATCHES, db_conn)?;
let mut match_stats = HashMap::<(i32, i32), MatchStats>::new(); let mut match_stats = HashMap::<(i32, i32), MatchStats>::new();

View file

@ -112,8 +112,8 @@ where
Err(RegistryAuthError::InvalidCredentials) Err(RegistryAuthError::InvalidCredentials)
} }
} else { } else {
let db_conn = DatabaseConnection::from_request(req).await.unwrap(); let mut db_conn = DatabaseConnection::from_request(req).await.unwrap();
let user = authenticate_user(&credentials, &db_conn) let user = authenticate_user(&credentials, &mut db_conn)
.ok_or(RegistryAuthError::InvalidCredentials)?; .ok_or(RegistryAuthError::InvalidCredentials)?;
Ok(RegistryAuth::User(user)) Ok(RegistryAuth::User(user))
@ -159,12 +159,12 @@ pub struct RegistryError {
} }
async fn check_blob_exists( async fn check_blob_exists(
db_conn: DatabaseConnection, mut db_conn: DatabaseConnection,
auth: RegistryAuth, auth: RegistryAuth,
Path((repository_name, raw_digest)): Path<(String, String)>, Path((repository_name, raw_digest)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?; check_access(&repository_name, &auth, &mut db_conn)?;
let digest = raw_digest.strip_prefix("sha256:").unwrap(); let digest = raw_digest.strip_prefix("sha256:").unwrap();
let blob_path = PathBuf::from(&config.registry_directory) let blob_path = PathBuf::from(&config.registry_directory)
@ -179,12 +179,12 @@ async fn check_blob_exists(
} }
async fn get_blob( async fn get_blob(
db_conn: DatabaseConnection, mut db_conn: DatabaseConnection,
auth: RegistryAuth, auth: RegistryAuth,
Path((repository_name, raw_digest)): Path<(String, String)>, Path((repository_name, raw_digest)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?; check_access(&repository_name, &auth, &mut db_conn)?;
let digest = raw_digest.strip_prefix("sha256:").unwrap(); let digest = raw_digest.strip_prefix("sha256:").unwrap();
let blob_path = PathBuf::from(&config.registry_directory) let blob_path = PathBuf::from(&config.registry_directory)
@ -200,12 +200,12 @@ async fn get_blob(
} }
async fn create_upload( async fn create_upload(
db_conn: DatabaseConnection, mut db_conn: DatabaseConnection,
auth: RegistryAuth, auth: RegistryAuth,
Path(repository_name): Path<String>, Path(repository_name): Path<String>,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?; check_access(&repository_name, &auth, &mut db_conn)?;
let uuid = gen_alphanumeric(16); let uuid = gen_alphanumeric(16);
tokio::fs::File::create( tokio::fs::File::create(
@ -229,13 +229,13 @@ async fn create_upload(
} }
async fn patch_upload( async fn patch_upload(
db_conn: DatabaseConnection, mut db_conn: DatabaseConnection,
auth: RegistryAuth, auth: RegistryAuth,
Path((repository_name, uuid)): Path<(String, String)>, Path((repository_name, uuid)): Path<(String, String)>,
mut stream: BodyStream, mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?; check_access(&repository_name, &auth, &mut db_conn)?;
// TODO: support content range header in request // TODO: support content range header in request
let upload_path = PathBuf::from(&config.registry_directory) let upload_path = PathBuf::from(&config.registry_directory)
@ -275,14 +275,14 @@ struct UploadParams {
} }
async fn put_upload( async fn put_upload(
db_conn: DatabaseConnection, mut db_conn: DatabaseConnection,
auth: RegistryAuth, auth: RegistryAuth,
Path((repository_name, uuid)): Path<(String, String)>, Path((repository_name, uuid)): Path<(String, String)>,
Query(params): Query<UploadParams>, Query(params): Query<UploadParams>,
mut stream: BodyStream, mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?; check_access(&repository_name, &auth, &mut db_conn)?;
let upload_path = PathBuf::from(&config.registry_directory) let upload_path = PathBuf::from(&config.registry_directory)
.join("uploads") .join("uploads")
@ -332,12 +332,12 @@ async fn put_upload(
} }
async fn get_manifest( async fn get_manifest(
db_conn: DatabaseConnection, mut db_conn: DatabaseConnection,
auth: RegistryAuth, auth: RegistryAuth,
Path((repository_name, reference)): Path<(String, String)>, Path((repository_name, reference)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?; check_access(&repository_name, &auth, &mut db_conn)?;
let manifest_path = PathBuf::from(&config.registry_directory) let manifest_path = PathBuf::from(&config.registry_directory)
.join("manifests") .join("manifests")
@ -357,13 +357,13 @@ async fn get_manifest(
} }
async fn put_manifest( async fn put_manifest(
db_conn: DatabaseConnection, mut db_conn: DatabaseConnection,
auth: RegistryAuth, auth: RegistryAuth,
Path((repository_name, reference)): Path<(String, String)>, Path((repository_name, reference)): Path<(String, String)>,
mut stream: BodyStream, mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
let bot = check_access(&repository_name, &auth, &db_conn)?; let bot = check_access(&repository_name, &auth, &mut db_conn)?;
let repository_dir = PathBuf::from(&config.registry_directory) let repository_dir = PathBuf::from(&config.registry_directory)
.join("manifests") .join("manifests")
@ -399,9 +399,9 @@ async fn put_manifest(
code_bundle_path: None, code_bundle_path: None,
container_digest: Some(&content_digest), container_digest: Some(&content_digest),
}; };
let version = let version = db::bots::create_bot_version(&new_version, &mut db_conn)
db::bots::create_bot_version(&new_version, &db_conn).expect("could not save bot version"); .expect("could not save bot version");
db::bots::set_active_version(bot.id, Some(version.id), &db_conn) db::bots::set_active_version(bot.id, Some(version.id), &mut db_conn)
.expect("could not update bot version"); .expect("could not update bot version");
Ok(Response::builder() Ok(Response::builder()
@ -421,7 +421,7 @@ async fn put_manifest(
fn check_access( fn check_access(
repository_name: &str, repository_name: &str,
auth: &RegistryAuth, auth: &RegistryAuth,
db_conn: &DatabaseConnection, db_conn: &mut DatabaseConnection,
) -> Result<db::bots::Bot, StatusCode> { ) -> Result<db::bots::Bot, StatusCode> {
use diesel::OptionalExtension; use diesel::OptionalExtension;

View file

@ -100,10 +100,10 @@ pub fn validate_bot_name(bot_name: &str) -> Result<(), SaveBotError> {
pub async fn save_bot( pub async fn save_bot(
Json(params): Json<SaveBotParams>, Json(params): Json<SaveBotParams>,
user: User, user: User,
conn: DatabaseConnection, mut conn: DatabaseConnection,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Json<Bot>, SaveBotError> { ) -> Result<Json<Bot>, SaveBotError> {
let res = bots::find_bot_by_name(&params.bot_name, &conn) let res = bots::find_bot_by_name(&params.bot_name, &mut conn)
.optional() .optional()
.expect("could not run query"); .expect("could not run query");
@ -122,10 +122,10 @@ pub async fn save_bot(
name: &params.bot_name, name: &params.bot_name,
}; };
bots::create_bot(&new_bot, &conn).expect("could not create bot") bots::create_bot(&new_bot, &mut conn).expect("could not create bot")
} }
}; };
let _code_bundle = save_code_string(&params.code, Some(bot.id), &conn, &config) let _code_bundle = save_code_string(&params.code, Some(bot.id), &mut conn, &config)
.expect("failed to save code bundle"); .expect("failed to save code bundle");
Ok(Json(bot)) Ok(Json(bot))
} }
@ -137,12 +137,12 @@ pub struct BotParams {
// TODO: can we unify this with save_bot? // TODO: can we unify this with save_bot?
pub async fn create_bot( pub async fn create_bot(
conn: DatabaseConnection, mut conn: DatabaseConnection,
user: User, user: User,
params: Json<BotParams>, params: Json<BotParams>,
) -> Result<(StatusCode, Json<Bot>), SaveBotError> { ) -> Result<(StatusCode, Json<Bot>), SaveBotError> {
validate_bot_name(&params.name)?; validate_bot_name(&params.name)?;
let existing_bot = bots::find_bot_by_name(&params.name, &conn) let existing_bot = bots::find_bot_by_name(&params.name, &mut conn)
.optional() .optional()
.expect("could not run query"); .expect("could not run query");
if existing_bot.is_some() { if existing_bot.is_some() {
@ -152,26 +152,27 @@ pub async fn create_bot(
owner_id: Some(user.id), owner_id: Some(user.id),
name: &params.name, name: &params.name,
}; };
let bot = bots::create_bot(&bot_params, &conn).unwrap(); let bot = bots::create_bot(&bot_params, &mut conn).unwrap();
Ok((StatusCode::CREATED, Json(bot))) Ok((StatusCode::CREATED, Json(bot)))
} }
// TODO: handle errors // TODO: handle errors
pub async fn get_bot( pub async fn get_bot(
conn: DatabaseConnection, mut conn: DatabaseConnection,
Path(bot_name): Path<String>, Path(bot_name): Path<String>,
) -> Result<Json<JsonValue>, StatusCode> { ) -> Result<Json<JsonValue>, StatusCode> {
let bot = db::bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; let bot =
db::bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let owner: Option<UserData> = match bot.owner_id { let owner: Option<UserData> = match bot.owner_id {
Some(user_id) => { Some(user_id) => {
let user = db::users::find_user(user_id, &conn) let user = db::users::find_user(user_id, &mut conn)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Some(user.into()) Some(user.into())
} }
None => None, None => None,
}; };
let versions = let versions = bots::find_bot_versions(bot.id, &mut conn)
bots::find_bot_versions(bot.id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(json!({ Ok(Json(json!({
"bot": bot, "bot": bot,
"owner": owner, "owner": owner,
@ -180,32 +181,32 @@ pub async fn get_bot(
} }
pub async fn get_user_bots( pub async fn get_user_bots(
conn: DatabaseConnection, mut conn: DatabaseConnection,
Path(user_name): Path<String>, Path(user_name): Path<String>,
) -> Result<Json<Vec<Bot>>, StatusCode> { ) -> Result<Json<Vec<Bot>>, StatusCode> {
let user = let user =
db::users::find_user_by_name(&user_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; db::users::find_user_by_name(&user_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
db::bots::find_bots_by_owner(user.id, &conn) db::bots::find_bots_by_owner(user.id, &mut conn)
.map(Json) .map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
/// List all active bots /// List all active bots
pub async fn list_bots(conn: DatabaseConnection) -> Result<Json<Vec<Bot>>, StatusCode> { pub async fn list_bots(mut conn: DatabaseConnection) -> Result<Json<Vec<Bot>>, StatusCode> {
bots::find_active_bots(&conn) bots::find_active_bots(&mut conn)
.map(Json) .map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
pub async fn get_ranking(conn: DatabaseConnection) -> Result<Json<Vec<RankedBot>>, StatusCode> { pub async fn get_ranking(mut conn: DatabaseConnection) -> Result<Json<Vec<RankedBot>>, StatusCode> {
ratings::get_bot_ranking(&conn) ratings::get_bot_ranking(&mut conn)
.map(Json) .map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
// TODO: currently this only implements the happy flow // TODO: currently this only implements the happy flow
pub async fn upload_code_multipart( pub async fn upload_code_multipart(
conn: DatabaseConnection, mut conn: DatabaseConnection,
user: User, user: User,
Path(bot_name): Path<String>, Path(bot_name): Path<String>,
mut multipart: Multipart, mut multipart: Multipart,
@ -213,7 +214,7 @@ pub async fn upload_code_multipart(
) -> Result<Json<BotVersion>, StatusCode> { ) -> Result<Json<BotVersion>, StatusCode> {
let bots_dir = PathBuf::from(&config.bots_directory); let bots_dir = PathBuf::from(&config.bots_directory);
let bot = bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; let bot = bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
if Some(user.id) != bot.owner_id { if Some(user.id) != bot.owner_id {
return Err(StatusCode::FORBIDDEN); return Err(StatusCode::FORBIDDEN);
@ -246,21 +247,22 @@ pub async fn upload_code_multipart(
container_digest: None, container_digest: None,
}; };
let code_bundle = let code_bundle =
bots::create_bot_version(&bot_version, &conn).expect("Failed to create code bundle"); bots::create_bot_version(&bot_version, &mut conn).expect("Failed to create code bundle");
Ok(Json(code_bundle)) Ok(Json(code_bundle))
} }
pub async fn get_code( pub async fn get_code(
conn: DatabaseConnection, mut conn: DatabaseConnection,
user: User, user: User,
Path(bundle_id): Path<i32>, Path(bundle_id): Path<i32>,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Vec<u8>, StatusCode> { ) -> Result<Vec<u8>, StatusCode> {
let version = let version =
db::bots::find_bot_version(bundle_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?; db::bots::find_bot_version(bundle_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let bot_id = version.bot_id.ok_or(StatusCode::FORBIDDEN)?; let bot_id = version.bot_id.ok_or(StatusCode::FORBIDDEN)?;
let bot = db::bots::find_bot(bot_id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let bot =
db::bots::find_bot(bot_id, &mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if bot.owner_id != Some(user.id) { if bot.owner_id != Some(user.id) {
return Err(StatusCode::FORBIDDEN); return Err(StatusCode::FORBIDDEN);
@ -297,10 +299,10 @@ impl MatchupStats {
type BotStats = HashMap<String, HashMap<String, MatchupStats>>; type BotStats = HashMap<String, HashMap<String, MatchupStats>>;
pub async fn get_bot_stats( pub async fn get_bot_stats(
conn: DatabaseConnection, mut conn: DatabaseConnection,
Path(bot_name): Path<String>, Path(bot_name): Path<String>,
) -> Result<Json<BotStats>, StatusCode> { ) -> Result<Json<BotStats>, StatusCode> {
let stats_records = db::matches::fetch_bot_stats(&bot_name, &conn) let stats_records = db::matches::fetch_bot_stats(&bot_name, &mut conn)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let mut bot_stats: BotStats = HashMap::new(); let mut bot_stats: BotStats = HashMap::new();
for record in stats_records { for record in stats_records {

View file

@ -35,7 +35,7 @@ pub async fn submit_bot(
Extension(pool): Extension<ConnectionPool>, Extension(pool): Extension<ConnectionPool>,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Json<SubmitBotResponse>, StatusCode> { ) -> Result<Json<SubmitBotResponse>, StatusCode> {
let conn = pool.get().await.expect("could not get database connection"); let mut conn = pool.get().await.expect("could not get database connection");
let opponent_name = params let opponent_name = params
.opponent_name .opponent_name
@ -46,12 +46,13 @@ pub async fn submit_bot(
.unwrap_or_else(|| DEFAULT_MAP_NAME.to_string()); .unwrap_or_else(|| DEFAULT_MAP_NAME.to_string());
let (opponent_bot, opponent_bot_version) = let (opponent_bot, opponent_bot_version) =
db::bots::find_bot_with_version_by_name(&opponent_name, &conn) db::bots::find_bot_with_version_by_name(&opponent_name, &mut conn)
.map_err(|_| StatusCode::BAD_REQUEST)?; .map_err(|_| StatusCode::BAD_REQUEST)?;
let map = db::maps::find_map_by_name(&map_name, &conn).map_err(|_| StatusCode::BAD_REQUEST)?; let map =
db::maps::find_map_by_name(&map_name, &mut conn).map_err(|_| StatusCode::BAD_REQUEST)?;
let player_bot_version = save_code_string(&params.code, None, &conn, &config) let player_bot_version = save_code_string(&params.code, None, &mut conn, &config)
// TODO: can we recover from this? // TODO: can we recover from this?
.expect("could not save bot code"); .expect("could not save bot code");

View file

@ -8,8 +8,8 @@ pub struct ApiMap {
pub name: String, pub name: String,
} }
pub async fn list_maps(conn: DatabaseConnection) -> Result<Json<Vec<ApiMap>>, StatusCode> { pub async fn list_maps(mut conn: DatabaseConnection) -> Result<Json<Vec<ApiMap>>, StatusCode> {
let maps = db::maps::list_maps(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let maps = db::maps::list_maps(&mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let api_maps = maps let api_maps = maps
.into_iter() .into_iter()

View file

@ -56,7 +56,7 @@ pub struct ListMatchesResponse {
pub async fn list_recent_matches( pub async fn list_recent_matches(
Query(params): Query<ListRecentMatchesParams>, Query(params): Query<ListRecentMatchesParams>,
conn: DatabaseConnection, mut conn: DatabaseConnection,
) -> Result<Json<ListMatchesResponse>, StatusCode> { ) -> Result<Json<ListMatchesResponse>, StatusCode> {
let requested_count = std::cmp::min( let requested_count = std::cmp::min(
params.count.unwrap_or(DEFAULT_NUM_RETURNED_MATCHES), params.count.unwrap_or(DEFAULT_NUM_RETURNED_MATCHES),
@ -68,7 +68,7 @@ pub async fn list_recent_matches(
let matches_result = match params.bot { let matches_result = match params.bot {
Some(bot_name) => { Some(bot_name) => {
let bot = db::bots::find_bot_by_name(&bot_name, &conn) let bot = db::bots::find_bot_by_name(&bot_name, &mut conn)
.map_err(|_| StatusCode::BAD_REQUEST)?; .map_err(|_| StatusCode::BAD_REQUEST)?;
matches::list_bot_matches( matches::list_bot_matches(
bot.id, bot.id,
@ -76,10 +76,10 @@ pub async fn list_recent_matches(
count, count,
params.before, params.before,
params.after, params.after,
&conn, &mut conn,
) )
} }
None => matches::list_public_matches(count, params.before, params.after, &conn), None => matches::list_public_matches(count, params.before, params.after, &mut conn),
}; };
let mut matches = matches_result.map_err(|_| StatusCode::BAD_REQUEST)?; let mut matches = matches_result.map_err(|_| StatusCode::BAD_REQUEST)?;
@ -119,9 +119,9 @@ pub fn match_data_to_api(data: matches::FullMatchData) -> ApiMatch {
pub async fn get_match_data( pub async fn get_match_data(
Path(match_id): Path<i32>, Path(match_id): Path<i32>,
conn: DatabaseConnection, mut conn: DatabaseConnection,
) -> Result<Json<ApiMatch>, StatusCode> { ) -> Result<Json<ApiMatch>, StatusCode> {
let match_data = matches::find_match(match_id, &conn) let match_data = matches::find_match(match_id, &mut conn)
.map_err(|_| StatusCode::NOT_FOUND) .map_err(|_| StatusCode::NOT_FOUND)
.map(match_data_to_api)?; .map(match_data_to_api)?;
Ok(Json(match_data)) Ok(Json(match_data))
@ -129,11 +129,11 @@ pub async fn get_match_data(
pub async fn get_match_log( pub async fn get_match_log(
Path(match_id): Path<i32>, Path(match_id): Path<i32>,
conn: DatabaseConnection, mut conn: DatabaseConnection,
Extension(config): Extension<Arc<GlobalConfig>>, Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Vec<u8>, StatusCode> { ) -> Result<Vec<u8>, StatusCode> {
let match_base = let match_base =
matches::find_match_base(match_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?; matches::find_match_base(match_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let log_path = PathBuf::from(&config.match_logs_directory).join(&match_base.log_path); let log_path = PathBuf::from(&config.match_logs_directory).join(&match_base.log_path);
let log_contents = std::fs::read(log_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let log_contents = std::fs::read(log_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(log_contents) Ok(log_contents)

View file

@ -23,13 +23,13 @@ where
type Rejection = (StatusCode, String); type Rejection = (StatusCode, String);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let conn = DatabaseConnection::from_request(req).await?; let mut conn = DatabaseConnection::from_request(req).await?;
let TypedHeader(Authorization(bearer)) = AuthorizationHeader::from_request(req) let TypedHeader(Authorization(bearer)) = AuthorizationHeader::from_request(req)
.await .await
.map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?; .map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?;
let (_session, user) = sessions::find_user_by_session(bearer.token(), &conn) let (_session, user) = sessions::find_user_by_session(bearer.token(), &mut conn)
.map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?; .map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?;
Ok(user) Ok(user)
@ -66,7 +66,7 @@ pub enum RegistrationError {
} }
impl RegistrationParams { impl RegistrationParams {
fn validate(&self, conn: &DatabaseConnection) -> Result<(), RegistrationError> { fn validate(&self, conn: &mut DatabaseConnection) -> Result<(), RegistrationError> {
let mut errors = Vec::new(); let mut errors = Vec::new();
// TODO: do we want to support cased usernames? // TODO: do we want to support cased usernames?
@ -95,7 +95,7 @@ impl RegistrationParams {
errors.push("that username is not allowed".to_string()); errors.push("that username is not allowed".to_string());
} }
if users::find_user_by_name(&self.username, &conn).is_ok() { if users::find_user_by_name(&self.username, conn).is_ok() {
errors.push("username is already taken".to_string()); errors.push("username is already taken".to_string());
} }
@ -137,16 +137,16 @@ impl IntoResponse for RegistrationError {
} }
pub async fn register( pub async fn register(
conn: DatabaseConnection, mut conn: DatabaseConnection,
params: Json<RegistrationParams>, params: Json<RegistrationParams>,
) -> Result<Json<UserData>, RegistrationError> { ) -> Result<Json<UserData>, RegistrationError> {
params.validate(&conn)?; params.validate(&mut conn)?;
let credentials = Credentials { let credentials = Credentials {
username: &params.username, username: &params.username,
password: &params.password, password: &params.password,
}; };
let user = users::create_user(&credentials, &conn)?; let user = users::create_user(&credentials, &mut conn)?;
Ok(Json(user.into())) Ok(Json(user.into()))
} }
@ -156,18 +156,18 @@ pub struct LoginParams {
pub password: String, pub password: String,
} }
pub async fn login(conn: DatabaseConnection, params: Json<LoginParams>) -> Response { pub async fn login(mut conn: DatabaseConnection, params: Json<LoginParams>) -> Response {
let credentials = Credentials { let credentials = Credentials {
username: &params.username, username: &params.username,
password: &params.password, password: &params.password,
}; };
// TODO: handle failures // TODO: handle failures
let authenticated = users::authenticate_user(&credentials, &conn); let authenticated = users::authenticate_user(&credentials, &mut conn);
match authenticated { match authenticated {
None => StatusCode::FORBIDDEN.into_response(), None => StatusCode::FORBIDDEN.into_response(),
Some(user) => { Some(user) => {
let session = sessions::create_session(&user, &conn); let session = sessions::create_session(&user, &mut conn);
let user_data: UserData = user.into(); let user_data: UserData = user.into();
let headers = [("Token", &session.token)]; let headers = [("Token", &session.token)];

View file

@ -1,7 +1,15 @@
// This file is autogenerated by diesel // This file is autogenerated by diesel
#![allow(unused_imports)] #![allow(unused_imports)]
table! { // @generated automatically by Diesel CLI.
pub mod sql_types {
#[derive(diesel::sql_types::SqlType)]
#[diesel(postgres_type(name = "match_state"))]
pub struct MatchState;
}
diesel::table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::db_types::*; use crate::db_types::*;
@ -14,7 +22,7 @@ table! {
} }
} }
table! { diesel::table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::db_types::*; use crate::db_types::*;
@ -26,7 +34,7 @@ table! {
} }
} }
table! { diesel::table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::db_types::*; use crate::db_types::*;
@ -37,7 +45,7 @@ table! {
} }
} }
table! { diesel::table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::db_types::*; use crate::db_types::*;
@ -48,13 +56,14 @@ table! {
} }
} }
table! { diesel::table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::db_types::*; use crate::db_types::*;
use super::sql_types::MatchState;
matches (id) { matches (id) {
id -> Int4, id -> Int4,
state -> Match_state, state -> MatchState,
log_path -> Text, log_path -> Text,
created_at -> Timestamp, created_at -> Timestamp,
winner -> Nullable<Int4>, winner -> Nullable<Int4>,
@ -63,7 +72,7 @@ table! {
} }
} }
table! { diesel::table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::db_types::*; use crate::db_types::*;
@ -73,7 +82,7 @@ table! {
} }
} }
table! { diesel::table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::db_types::*; use crate::db_types::*;
@ -84,7 +93,7 @@ table! {
} }
} }
table! { diesel::table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::db_types::*; use crate::db_types::*;
@ -96,14 +105,14 @@ table! {
} }
} }
joinable!(bots -> users (owner_id)); diesel::joinable!(bots -> users (owner_id));
joinable!(match_players -> bot_versions (bot_version_id)); diesel::joinable!(match_players -> bot_versions (bot_version_id));
joinable!(match_players -> matches (match_id)); diesel::joinable!(match_players -> matches (match_id));
joinable!(matches -> maps (map_id)); diesel::joinable!(matches -> maps (map_id));
joinable!(ratings -> bots (bot_id)); diesel::joinable!(ratings -> bots (bot_id));
joinable!(sessions -> users (user_id)); diesel::joinable!(sessions -> users (user_id));
allow_tables_to_appear_in_same_query!( diesel::allow_tables_to_appear_in_same_query!(
bot_versions, bot_versions,
bots, bots,
maps, maps,

View file

@ -27,7 +27,7 @@ fn create_subdir<P: AsRef<Path>>(base_path: &Path, p: P) -> io::Result<String> {
Ok(dir_path_string) Ok(dir_path_string)
} }
fn clear_database(conn: &PgConnection) { fn clear_database(conn: &mut PgConnection) {
diesel::sql_query( diesel::sql_query(
"TRUNCATE TABLE "TRUNCATE TABLE
bots, bots,
@ -45,20 +45,20 @@ fn clear_database(conn: &PgConnection) {
/// Setup a simple text fixture, having simplebot and the hex map. /// Setup a simple text fixture, having simplebot and the hex map.
/// This is enough to run a simple match. /// This is enough to run a simple match.
fn setup_simple_fixture(db_conn: &PgConnection, config: &GlobalConfig) { fn setup_simple_fixture(db_conn: &mut PgConnection, config: &GlobalConfig) {
let bot = db::bots::create_bot( let bot = db::bots::create_bot(
&db::bots::NewBot { &db::bots::NewBot {
owner_id: None, owner_id: None,
name: "simplebot", name: "simplebot",
}, },
&db_conn, db_conn,
) )
.expect("could not create simplebot"); .expect("could not create simplebot");
let simplebot_code = std::fs::read_to_string("../simplebot/simplebot.py") let simplebot_code = std::fs::read_to_string("../simplebot/simplebot.py")
.expect("could not read simplebot code"); .expect("could not read simplebot code");
let _bot_version = let _bot_version =
modules::bots::save_code_string(&simplebot_code, Some(bot.id), &db_conn, &config) modules::bots::save_code_string(&simplebot_code, Some(bot.id), db_conn, &config)
.expect("could not save bot version"); .expect("could not save bot version");
std::fs::copy( std::fs::copy(
@ -71,7 +71,7 @@ fn setup_simple_fixture(db_conn: &PgConnection, config: &GlobalConfig) {
name: "hex", name: "hex",
file_path: "hex.json", file_path: "hex.json",
}, },
&db_conn, db_conn,
) )
.expect("could not save map"); .expect("could not save map");
} }
@ -119,14 +119,14 @@ impl<'a> TestApp<'a> {
async fn with_db_conn<F, R>(&self, function: F) -> R async fn with_db_conn<F, R>(&self, function: F) -> R
where where
F: FnOnce(&PgConnection) -> R, F: FnOnce(&mut PgConnection) -> R,
{ {
let db_conn = self let mut db_conn = self
.db_pool .db_pool
.get() .get()
.await .await
.expect("could not get db connection"); .expect("could not get db connection");
function(&db_conn) function(&mut db_conn)
} }
} }