upgrade to diesel 2.0

This commit is contained in:
Ilion Beyst 2022-10-12 22:52:15 +02:00
parent ed016773b1
commit ae57359353
24 changed files with 357 additions and 220 deletions

View file

@ -43,7 +43,7 @@ jobs:
- name: Setup tests
run: |
docker pull python:3.10-slim-buster
cargo install diesel_cli --version ^1.4 || true
cargo install diesel_cli --version ^2.0 || true
cd planetwars-server
diesel migration run --locked-schema
env:

149
Cargo.lock generated
View file

@ -29,6 +29,15 @@ dependencies = [
"alloc-no-stdlib",
]
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anyhow"
version = "1.0.58"
@ -171,22 +180,21 @@ checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]]
name = "bb8"
version = "0.7.1"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e9f4fa9768efd269499d8fba693260cfc670891cf6de3adc935588447a77cc8"
checksum = "1627eccf3aa91405435ba240be23513eeca466b5dc33866422672264de061582"
dependencies = [
"async-trait",
"futures-channel",
"futures-util",
"parking_lot 0.11.2",
"parking_lot 0.12.1",
"tokio",
]
[[package]]
name = "bb8-diesel"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79c87e12b0086ff7850d98a19d2a70f5fd901b463412d499514d8e2e16ad0826"
source = "git+https://github.com/overdrivenpotato/bb8-diesel.git#89b76207bbca35082687c765074f402200fcc51f"
dependencies = [
"async-trait",
"bb8",
@ -372,15 +380,17 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.19"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73"
checksum = "bfd4d1b31faaa3a89d7934dbded3111da0d2ef28e3ebccdb4f0179f5929d1ef1"
dependencies = [
"libc",
"iana-time-zone",
"js-sys",
"num-integer",
"num-traits",
"serde",
"time",
"wasm-bindgen",
"winapi",
]
@ -441,6 +451,16 @@ dependencies = [
"cc",
]
[[package]]
name = "codespan-reporting"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
dependencies = [
"termcolor",
"unicode-width",
]
[[package]]
name = "config"
version = "0.12.0"
@ -520,6 +540,50 @@ dependencies = [
"typenum",
]
[[package]]
name = "cxx"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19f39818dcfc97d45b03953c1292efc4e80954e1583c4aa770bac1383e2310a4"
dependencies = [
"cc",
"cxxbridge-flags",
"cxxbridge-macro",
"link-cplusplus",
]
[[package]]
name = "cxx-build"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e580d70777c116df50c390d1211993f62d40302881e54d4b79727acb83d0199"
dependencies = [
"cc",
"codespan-reporting",
"once_cell",
"proc-macro2",
"quote",
"scratch",
"syn",
]
[[package]]
name = "cxxbridge-flags"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56a46460b88d1cec95112c8c363f0e2c39afdb237f60583b0b36343bf627ea9c"
[[package]]
name = "cxxbridge-macro"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "747b608fecf06b0d72d440f27acc99288207324b793be2c17991839f3d4995ea"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "darling"
version = "0.13.1"
@ -557,23 +621,24 @@ dependencies = [
[[package]]
name = "diesel"
version = "1.4.8"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b28135ecf6b7d446b43e27e225622a038cc4e2930a1022f51cdb97ada19b8e4d"
checksum = "68c186a7418a2aac330bb76cde82f16c36b03a66fb91db32d20214311f9f6545"
dependencies = [
"bitflags",
"byteorder",
"chrono",
"diesel_derives",
"itoa 1.0.2",
"pq-sys",
"r2d2",
]
[[package]]
name = "diesel-derive-enum"
version = "1.1.2"
version = "2.0.0-rc.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8910921b014e2af16298f006de12aa08af894b71f0f49a486ab6d74b17bbed"
checksum = "5f28fc9f5bf184ebc58ad9105dede024981e2303fe878a0fe16557f3a979064a"
dependencies = [
"heck",
"proc-macro2",
@ -583,10 +648,11 @@ dependencies = [
[[package]]
name = "diesel_derives"
version = "1.4.1"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45f5098f628d02a7a0f68ddba586fb61e80edec3bdc1be3b921f4ceec60858d3"
checksum = "143b758c91dbc3fe1fdcb0dba5bd13276c6a66422f2ef5795b58488248a310aa"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
@ -1014,6 +1080,30 @@ dependencies = [
"tokio",
]
[[package]]
name = "iana-time-zone"
version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5a6ef98976b22b3b7f2f3a806f858cb862044cfa66805aa3ad84cb3d3b785ed"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"winapi",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde6edd6cef363e9359ed3c98ba64590ba9eecba2293eb5a723ab32aee8926aa"
dependencies = [
"cxx",
"cxx-build",
]
[[package]]
name = "ident_case"
version = "1.0.1"
@ -1112,6 +1202,15 @@ version = "0.2.126"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836"
[[package]]
name = "link-cplusplus"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9272ab7b96c9046fbc5bc56c06c117cb639fe2d509df0c421cad82d2915cf369"
dependencies = [
"cc",
]
[[package]]
name = "linked-hash-map"
version = "0.5.4"
@ -1195,9 +1294,9 @@ dependencies = [
[[package]]
name = "mio"
version = "0.8.3"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "713d550d9b44d89174e066b7a6217ae06234c10cb47819a88290d2b353c31799"
checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf"
dependencies = [
"libc",
"log",
@ -1999,6 +2098,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "scratch"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8132065adcfd6e02db789d9285a0deb2f3fcb04002865ab67d5fb103533898"
[[package]]
name = "sct"
version = "0.7.0"
@ -2292,16 +2397,16 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tokio"
version = "1.19.2"
version = "1.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439"
checksum = "a9e03c497dc955702ba729190dc4aac6f2a0ce97f913e5b1b5912fc5039d9099"
dependencies = [
"autocfg 1.1.0",
"bytes",
"libc",
"memchr",
"mio",
"num_cpus",
"once_cell",
"parking_lot 0.12.1",
"pin-project-lite",
"signal-hook-registry",
@ -2592,6 +2697,12 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-width"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
[[package]]
name = "untrusted"
version = "0.7.1"

View file

@ -15,15 +15,15 @@ path = "src/cli.rs"
[dependencies]
futures = "0.3"
tokio = { version = "1.15", features = ["full"] }
tokio = { version = "1.21", features = ["full"] }
tokio-stream = "0.1.9"
hyper = "0.14"
tower-http = { version = "0.3.4", features = ["full"] }
axum = { version = "0.5", features = ["json", "headers", "multipart"] }
diesel = { version = "1.4.4", features = ["postgres", "chrono"] }
diesel-derive-enum = { version = "1.1", features = ["postgres"] }
bb8 = "0.7"
bb8-diesel = "0.2"
diesel = { version = "2.0", features = ["postgres", "chrono"] }
diesel-derive-enum = { version = "2.0.0-rc.0", features = ["postgres"] }
bb8 = "0.8"
bb8-diesel = { git = "https://github.com/overdrivenpotato/bb8-diesel.git" }
dotenv = "0.15.0"
rust-argon2 = "0.8"
rand = "0.8.4"

View file

@ -38,12 +38,12 @@ impl SetPassword {
let global_config = get_config().unwrap();
let pool = create_db_pool(&global_config).await;
let conn = pool.get().await.expect("could not get database connection");
let mut conn = pool.get().await.expect("could not get database connection");
let credentials = db::users::Credentials {
username: &self.username,
password: &self.new_password,
};
db::users::set_user_password(credentials, &conn).expect("could not set password");
db::users::set_user_password(credentials, &mut conn).expect("could not set password");
}
}

View file

@ -5,7 +5,7 @@ use crate::schema::{bot_versions, bots};
use chrono;
#[derive(Insertable)]
#[table_name = "bots"]
#[diesel(table_name = bots)]
pub struct NewBot<'a> {
pub owner_id: Option<i32>,
pub name: &'a str,
@ -19,29 +19,29 @@ pub struct Bot {
pub active_version: Option<i32>,
}
pub fn create_bot(new_bot: &NewBot, conn: &PgConnection) -> QueryResult<Bot> {
pub fn create_bot(new_bot: &NewBot, conn: &mut PgConnection) -> QueryResult<Bot> {
diesel::insert_into(bots::table)
.values(new_bot)
.get_result(conn)
}
pub fn find_bot(id: i32, conn: &PgConnection) -> QueryResult<Bot> {
pub fn find_bot(id: i32, conn: &mut PgConnection) -> QueryResult<Bot> {
bots::table.find(id).first(conn)
}
pub fn find_bots_by_owner(owner_id: i32, conn: &PgConnection) -> QueryResult<Vec<Bot>> {
pub fn find_bots_by_owner(owner_id: i32, conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table
.filter(bots::owner_id.eq(owner_id))
.get_results(conn)
}
pub fn find_bot_by_name(name: &str, conn: &PgConnection) -> QueryResult<Bot> {
pub fn find_bot_by_name(name: &str, conn: &mut PgConnection) -> QueryResult<Bot> {
bots::table.filter(bots::name.eq(name)).first(conn)
}
pub fn find_bot_with_version_by_name(
bot_name: &str,
conn: &PgConnection,
conn: &mut PgConnection,
) -> QueryResult<(Bot, BotVersion)> {
bots::table
.inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable())))
@ -49,26 +49,28 @@ pub fn find_bot_with_version_by_name(
.first(conn)
}
pub fn all_active_bots_with_version(conn: &PgConnection) -> QueryResult<Vec<(Bot, BotVersion)>> {
pub fn all_active_bots_with_version(
conn: &mut PgConnection,
) -> QueryResult<Vec<(Bot, BotVersion)>> {
bots::table
.inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable())))
.get_results(conn)
}
pub fn find_all_bots(conn: &PgConnection) -> QueryResult<Vec<Bot>> {
pub fn find_all_bots(conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table.get_results(conn)
}
/// Find all bots that have an associated active version.
/// These are the bots that can be run.
pub fn find_active_bots(conn: &PgConnection) -> QueryResult<Vec<Bot>> {
pub fn find_active_bots(conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table
.filter(bots::active_version.is_not_null())
.get_results(conn)
}
#[derive(Insertable)]
#[table_name = "bot_versions"]
#[diesel(table_name = bot_versions)]
pub struct NewBotVersion<'a> {
pub bot_id: Option<i32>,
pub code_bundle_path: Option<&'a str>,
@ -86,7 +88,7 @@ pub struct BotVersion {
pub fn create_bot_version(
new_bot_version: &NewBotVersion,
conn: &PgConnection,
conn: &mut PgConnection,
) -> QueryResult<BotVersion> {
diesel::insert_into(bot_versions::table)
.values(new_bot_version)
@ -96,7 +98,7 @@ pub fn create_bot_version(
pub fn set_active_version(
bot_id: i32,
version_id: Option<i32>,
conn: &PgConnection,
conn: &mut PgConnection,
) -> QueryResult<()> {
diesel::update(bots::table.filter(bots::id.eq(bot_id)))
.set(bots::active_version.eq(version_id))
@ -104,13 +106,13 @@ pub fn set_active_version(
Ok(())
}
pub fn find_bot_version(version_id: i32, conn: &PgConnection) -> QueryResult<BotVersion> {
pub fn find_bot_version(version_id: i32, conn: &mut PgConnection) -> QueryResult<BotVersion> {
bot_versions::table
.filter(bot_versions::id.eq(version_id))
.first(conn)
}
pub fn find_bot_versions(bot_id: i32, conn: &PgConnection) -> QueryResult<Vec<BotVersion>> {
pub fn find_bot_versions(bot_id: i32, conn: &mut PgConnection) -> QueryResult<Vec<BotVersion>> {
bot_versions::table
.filter(bot_versions::bot_id.eq(bot_id))
.get_results(conn)

View file

@ -3,7 +3,7 @@ use diesel::prelude::*;
use crate::schema::maps;
#[derive(Insertable)]
#[table_name = "maps"]
#[diesel(table_name = maps)]
pub struct NewMap<'a> {
pub name: &'a str,
pub file_path: &'a str,
@ -16,20 +16,20 @@ pub struct Map {
pub file_path: String,
}
pub fn create_map(new_map: NewMap, conn: &PgConnection) -> QueryResult<Map> {
pub fn create_map(new_map: NewMap, conn: &mut PgConnection) -> QueryResult<Map> {
diesel::insert_into(maps::table)
.values(new_map)
.get_result(conn)
}
pub fn find_map(id: i32, conn: &PgConnection) -> QueryResult<Map> {
pub fn find_map(id: i32, conn: &mut PgConnection) -> QueryResult<Map> {
maps::table.find(id).get_result(conn)
}
pub fn find_map_by_name(name: &str, conn: &PgConnection) -> QueryResult<Map> {
pub fn find_map_by_name(name: &str, conn: &mut PgConnection) -> QueryResult<Map> {
maps::table.filter(maps::name.eq(name)).first(conn)
}
pub fn list_maps(conn: &PgConnection) -> QueryResult<Vec<Map>> {
pub fn list_maps(conn: &mut PgConnection) -> QueryResult<Vec<Map>> {
maps::table.get_results(conn)
}

View file

@ -1,9 +1,6 @@
pub use crate::db_types::MatchState;
use chrono::NaiveDateTime;
use diesel::associations::BelongsTo;
use diesel::pg::Pg;
use diesel::query_builder::BoxedSelectStatement;
use diesel::query_source::{AppearsInFromClause, Once};
use diesel::sql_types::*;
use diesel::{
BelongingToDsl, ExpressionMethods, JoinOnDsl, NullableExpressionMethods, QueryDsl, RunQueryDsl,
@ -18,7 +15,7 @@ use super::bots::{Bot, BotVersion};
use super::maps::Map;
#[derive(Insertable)]
#[table_name = "matches"]
#[diesel(table_name = matches)]
pub struct NewMatch<'a> {
pub state: MatchState,
pub log_path: &'a str,
@ -27,7 +24,7 @@ pub struct NewMatch<'a> {
}
#[derive(Insertable)]
#[table_name = "match_players"]
#[diesel(table_name = match_players)]
pub struct NewMatchPlayer {
/// id of the match this player is in
pub match_id: i32,
@ -38,7 +35,7 @@ pub struct NewMatchPlayer {
}
#[derive(Queryable, Identifiable)]
#[table_name = "matches"]
#[diesel(table_name = matches)]
pub struct MatchBase {
pub id: i32,
pub state: MatchState,
@ -50,8 +47,8 @@ pub struct MatchBase {
}
#[derive(Queryable, Identifiable, Associations, Clone)]
#[primary_key(match_id, player_id)]
#[belongs_to(MatchBase, foreign_key = "match_id")]
#[diesel(primary_key(match_id, player_id))]
#[diesel(belongs_to(MatchBase, foreign_key = match_id))]
pub struct MatchPlayer {
pub match_id: i32,
pub player_id: i32,
@ -65,9 +62,9 @@ pub struct MatchPlayerData {
pub fn create_match(
new_match_base: &NewMatch,
new_match_players: &[MatchPlayerData],
conn: &PgConnection,
conn: &mut PgConnection,
) -> QueryResult<MatchData> {
conn.transaction(|| {
conn.transaction(|conn| {
let match_base = diesel::insert_into(matches::table)
.values(new_match_base)
.get_result::<MatchBase>(conn)?;
@ -101,7 +98,7 @@ pub struct MatchData {
/// Add player information to MatchBase instances
fn fetch_full_match_data(
matches: Vec<MatchBase>,
conn: &PgConnection,
conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> {
let map_ids: HashSet<i32> = matches.iter().filter_map(|m| m.map_id).collect();
@ -140,8 +137,8 @@ fn fetch_full_match_data(
}
// TODO: this method should disappear
pub fn list_matches(amount: i64, conn: &PgConnection) -> QueryResult<Vec<FullMatchData>> {
conn.transaction(|| {
pub fn list_matches(amount: i64, conn: &mut PgConnection) -> QueryResult<Vec<FullMatchData>> {
conn.transaction(|conn| {
let matches = matches::table
.filter(matches::state.eq(MatchState::Finished))
.order_by(matches::created_at.desc())
@ -164,17 +161,32 @@ pub fn list_public_matches(
amount: i64,
before: Option<NaiveDateTime>,
after: Option<NaiveDateTime>,
conn: &PgConnection,
conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> {
conn.transaction(|| {
conn.transaction(|conn| {
// TODO: how can this common logic be abstracted?
let query = matches::table
let mut query = matches::table
.filter(matches::state.eq(MatchState::Finished))
.filter(matches::is_public.eq(true))
.into_boxed();
let matches =
select_matches_page(query, amount, before, after).get_results::<MatchBase>(conn)?;
// TODO: how to remove this duplication?
query = match (before, after) {
(None, None) => query.order_by(matches::created_at.desc()),
(Some(before), None) => query
.filter(matches::created_at.lt(before))
.order_by(matches::created_at.desc()),
(None, Some(after)) => query
.filter(matches::created_at.gt(after))
.order_by(matches::created_at.asc()),
(Some(before), Some(after)) => query
.filter(matches::created_at.lt(before))
.filter(matches::created_at.gt(after))
.order_by(matches::created_at.desc()),
};
query = query.limit(amount);
let matches = query.get_results::<MatchBase>(conn)?;
fetch_full_match_data(matches, conn)
})
}
@ -185,7 +197,7 @@ pub fn list_bot_matches(
amount: i64,
before: Option<NaiveDateTime>,
after: Option<NaiveDateTime>,
conn: &PgConnection,
conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> {
let mut query = matches::table
.filter(matches::state.eq(MatchState::Finished))
@ -211,22 +223,8 @@ pub fn list_bot_matches(
};
}
let matches =
select_matches_page(query, amount, before, after).get_results::<MatchBase>(conn)?;
fetch_full_match_data(matches, conn)
}
fn select_matches_page<QS>(
query: BoxedSelectStatement<'static, matches::SqlType, QS, Pg>,
amount: i64,
before: Option<NaiveDateTime>,
after: Option<NaiveDateTime>,
) -> BoxedSelectStatement<'static, matches::SqlType, QS, Pg>
where
QS: AppearsInFromClause<matches::table, Count = Once>,
{
// TODO: this is not nice. Replace this with proper cursor logic.
match (before, after) {
// TODO: how to remove this duplication?
query = match (before, after) {
(None, None) => query.order_by(matches::created_at.desc()),
(Some(before), None) => query
.filter(matches::created_at.lt(before))
@ -238,8 +236,11 @@ where
.filter(matches::created_at.lt(before))
.filter(matches::created_at.gt(after))
.order_by(matches::created_at.desc()),
}
.limit(amount)
};
query = query.limit(amount);
let matches = query.get_results::<MatchBase>(conn)?;
fetch_full_match_data(matches, conn)
}
// TODO: maybe unify this with matchdata?
@ -270,8 +271,8 @@ impl BelongsTo<MatchBase> for FullMatchPlayerData {
}
}
pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult<FullMatchData> {
conn.transaction(|| {
pub fn find_match(id: i32, conn: &mut PgConnection) -> QueryResult<FullMatchData> {
conn.transaction(|conn| {
let match_base = matches::table.find(id).get_result::<MatchBase>(conn)?;
let map = match match_base.map_id {
@ -298,7 +299,7 @@ pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult<FullMatchData> {
})
}
pub fn find_match_base(id: i32, conn: &PgConnection) -> QueryResult<MatchBase> {
pub fn find_match_base(id: i32, conn: &mut PgConnection) -> QueryResult<MatchBase> {
matches::table.find(id).get_result::<MatchBase>(conn)
}
@ -306,7 +307,7 @@ pub enum MatchResult {
Finished { winner: Option<i32> },
}
pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> QueryResult<()> {
pub fn save_match_result(id: i32, result: MatchResult, conn: &mut PgConnection) -> QueryResult<()> {
let MatchResult::Finished { winner } = result;
diesel::update(matches::table.find(id))
@ -320,17 +321,20 @@ pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> Q
#[derive(QueryableByName)]
pub struct BotStatsRecord {
#[sql_type = "Text"]
#[diesel(sql_type = Text)]
pub opponent: String,
#[sql_type = "Text"]
#[diesel(sql_type = Text)]
pub map: String,
#[sql_type = "Nullable<Bool>"]
#[diesel(sql_type = Nullable<Bool>)]
pub win: Option<bool>,
#[sql_type = "Int8"]
#[diesel(sql_type = Int8)]
pub count: i64,
}
pub fn fetch_bot_stats(bot_name: &str, db_conn: &PgConnection) -> QueryResult<Vec<BotStatsRecord>> {
pub fn fetch_bot_stats(
bot_name: &str,
db_conn: &mut PgConnection,
) -> QueryResult<Vec<BotStatsRecord>> {
diesel::sql_query(
"
SELECT opponent, map, win, COUNT(*) as count

View file

@ -10,7 +10,7 @@ pub struct Rating {
pub rating: f64,
}
pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult<Option<f64>> {
pub fn get_rating(bot_id: i32, db_conn: &mut PgConnection) -> QueryResult<Option<f64>> {
ratings::table
.filter(ratings::bot_id.eq(bot_id))
.select(ratings::rating)
@ -18,7 +18,7 @@ pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult<Option<f64
.optional()
}
pub fn set_rating(bot_id: i32, rating: f64, db_conn: &PgConnection) -> QueryResult<usize> {
pub fn set_rating(bot_id: i32, rating: f64, db_conn: &mut PgConnection) -> QueryResult<usize> {
diesel::insert_into(ratings::table)
.values(Rating { bot_id, rating })
.on_conflict(ratings::bot_id)
@ -40,7 +40,7 @@ pub struct RankedBot {
pub rating: f64,
}
pub fn get_bot_ranking(db_conn: &PgConnection) -> QueryResult<Vec<RankedBot>> {
pub fn get_bot_ranking(db_conn: &mut PgConnection) -> QueryResult<Vec<RankedBot>> {
bots::table
.left_join(users::table)
.inner_join(ratings::table)

View file

@ -6,7 +6,7 @@ use diesel::{insert_into, prelude::*, Insertable, RunQueryDsl};
use rand::{self, Rng};
#[derive(Insertable)]
#[table_name = "sessions"]
#[diesel(table_name = sessions)]
struct NewSession {
token: String,
user_id: i32,
@ -19,7 +19,7 @@ pub struct Session {
pub token: String,
}
pub fn create_session(user: &User, conn: &PgConnection) -> Session {
pub fn create_session(user: &User, conn: &mut PgConnection) -> Session {
let new_session = NewSession {
token: gen_session_token(),
user_id: user.id,
@ -31,7 +31,7 @@ pub fn create_session(user: &User, conn: &PgConnection) -> Session {
.unwrap()
}
pub fn find_user_by_session(token: &str, conn: &PgConnection) -> QueryResult<(Session, User)> {
pub fn find_user_by_session(token: &str, conn: &mut PgConnection) -> QueryResult<(Session, User)> {
sessions::table
.inner_join(users::table)
.filter(sessions::token.eq(&token))

View file

@ -11,7 +11,7 @@ pub struct Credentials<'a> {
}
#[derive(Insertable)]
#[table_name = "users"]
#[diesel(table_name = users)]
pub struct NewUser<'a> {
pub username: &'a str,
pub password_hash: &'a [u8],
@ -50,7 +50,7 @@ pub fn hash_password(password: &str) -> (Vec<u8>, [u8; 32]) {
(hash, salt)
}
pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResult<User> {
pub fn create_user(credentials: &Credentials, conn: &mut PgConnection) -> QueryResult<User> {
let (hash, salt) = hash_password(&credentials.password);
let new_user = NewUser {
@ -63,19 +63,19 @@ pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResul
.get_result::<User>(conn)
}
pub fn find_user(user_id: i32, db_conn: &PgConnection) -> QueryResult<User> {
pub fn find_user(user_id: i32, db_conn: &mut PgConnection) -> QueryResult<User> {
users::table
.filter(users::id.eq(user_id))
.first::<User>(db_conn)
}
pub fn find_user_by_name(username: &str, db_conn: &PgConnection) -> QueryResult<User> {
pub fn find_user_by_name(username: &str, db_conn: &mut PgConnection) -> QueryResult<User> {
users::table
.filter(users::username.eq(username))
.first::<User>(db_conn)
}
pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> QueryResult<()> {
pub fn set_user_password(credentials: Credentials, db_conn: &mut PgConnection) -> QueryResult<()> {
let (hash, salt) = hash_password(&credentials.password);
let n_changes = diesel::update(users::table.filter(users::username.eq(&credentials.username)))
@ -91,7 +91,7 @@ pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> Qu
}
}
pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option<User> {
pub fn authenticate_user(credentials: &Credentials, db_conn: &mut PgConnection) -> Option<User> {
find_user_by_name(credentials.username, db_conn)
.optional()
.unwrap()

View file

@ -2,7 +2,7 @@ use diesel_derive_enum::DbEnum;
use serde::{Deserialize, Serialize};
#[derive(DbEnum, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[DieselType = "Match_state"]
#[DieselTypePath = "crate::schema::sql_types::MatchState"]
pub enum MatchState {
Playing,

View file

@ -8,7 +8,7 @@ pub mod routes;
pub mod schema;
pub mod util;
use std::ops::Deref;
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::sync::Arc;
use std::{fs, net::SocketAddr};
@ -70,9 +70,9 @@ pub struct GlobalConfig {
const SIMPLEBOT_PATH: &str = "../simplebot/simplebot.py";
pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) {
let conn = pool.get().await.expect("could not get database connection");
let mut conn = pool.get().await.expect("could not get database connection");
// This transaction is expected to fail when simplebot already exists.
let _res = conn.transaction::<(), diesel::result::Error, _>(|| {
let _res = conn.transaction::<(), diesel::result::Error, _>(|conn| {
use db::bots::NewBot;
let new_bot = NewBot {
@ -80,12 +80,12 @@ pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) {
owner_id: None,
};
let simplebot = db::bots::create_bot(&new_bot, &conn)?;
let simplebot = db::bots::create_bot(&new_bot, conn)?;
let simplebot_code =
std::fs::read_to_string(SIMPLEBOT_PATH).expect("could not read simplebot code");
modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), &conn, config)?;
modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), conn, config)?;
println!("initialized simplebot");
@ -209,6 +209,12 @@ impl Deref for DatabaseConnection {
}
}
impl DerefMut for DatabaseConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[async_trait]
impl<B> FromRequest<B> for DatabaseConnection
where

View file

@ -9,7 +9,7 @@ use crate::{db, util::gen_alphanumeric, GlobalConfig};
pub fn save_code_string(
bot_code: &str,
bot_id: Option<i32>,
conn: &PgConnection,
conn: &mut PgConnection,
config: &GlobalConfig,
) -> QueryResult<db::bots::BotVersion> {
let bundle_name = gen_alphanumeric(16);

View file

@ -149,19 +149,19 @@ impl pb::client_api_service_server::ClientApiService for ClientApiServer {
req: Request<pb::CreateMatchRequest>,
) -> Result<Response<pb::CreateMatchResponse>, Status> {
// TODO: unify with matchrunner module
let conn = self.conn_pool.get().await.unwrap();
let mut conn = self.conn_pool.get().await.unwrap();
let match_request = req.get_ref();
let (opponent_bot, opponent_bot_version) =
db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &conn)
db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &mut conn)
.map_err(|_| Status::not_found("opponent not found"))?;
let map_name = match match_request.map_name.as_str() {
"" => "hex",
name => name,
};
let map = db::maps::find_map_by_name(map_name, &conn)
let map = db::maps::find_map_by_name(map_name, &mut conn)
.map_err(|_| Status::not_found("map not found"))?;
let player_key = gen_alphanumeric(32);

View file

@ -80,8 +80,8 @@ impl RunMatch {
let match_data = {
// TODO: it would be nice to get an already-open connection here when possible.
// Maybe we need an additional abstraction, bundling a connection and connection pool?
let db_conn = conn_pool.get().await.expect("could not get a connection");
self.store_in_database(&db_conn)?
let mut db_conn = conn_pool.get().await.expect("could not get a connection");
self.store_in_database(&mut db_conn)?
};
let runner_config = self.into_runner_config();
@ -90,7 +90,7 @@ impl RunMatch {
Ok((match_data, handle))
}
fn store_in_database(&self, db_conn: &PgConnection) -> QueryResult<MatchData> {
fn store_in_database(&self, db_conn: &mut PgConnection) -> QueryResult<MatchData> {
let new_match_data = db::matches::NewMatch {
state: db::matches::MatchState::Playing,
log_path: &self.log_file_name,
@ -167,7 +167,7 @@ async fn run_match_task(
let outcome = runner::run_match(match_config).await;
// update match state in database
let conn = connection_pool
let mut conn = connection_pool
.get()
.await
.expect("could not get database connection");
@ -176,7 +176,8 @@ async fn run_match_task(
winner: outcome.winner.map(|w| (w - 1) as i32), // player numbers in matchrunner start at 1
};
db::matches::save_match_result(match_id, result, &conn).expect("could not save match result");
db::matches::save_match_result(match_id, result, &mut conn)
.expect("could not save match result");
outcome
}

View file

@ -20,13 +20,14 @@ pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) {
// TODO: make this configurable
// play at most one match every n seconds
let mut interval = tokio::time::interval(Duration::from_secs(RANKER_INTERVAL));
let db_conn = db_pool
let mut db_conn = db_pool
.get()
.await
.expect("could not get database connection");
loop {
interval.tick().await;
let bots = db::bots::all_active_bots_with_version(&db_conn).expect("could not load bots");
let bots =
db::bots::all_active_bots_with_version(&mut db_conn).expect("could not load bots");
if bots.len() < 2 {
// not enough bots to play a match
continue;
@ -37,14 +38,14 @@ pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) {
.cloned()
.collect();
let maps = db::maps::list_maps(&db_conn).expect("could not load map");
let maps = db::maps::list_maps(&mut db_conn).expect("could not load map");
let map = match maps.choose(&mut rand::thread_rng()).cloned() {
None => continue, // no maps available
Some(map) => map,
};
play_ranking_match(config.clone(), map, selected_bots, db_pool.clone()).await;
recalculate_ratings(&db_conn).expect("could not recalculate ratings");
recalculate_ratings(&mut db_conn).expect("could not recalculate ratings");
}
}
@ -71,7 +72,7 @@ async fn play_ranking_match(
let _outcome = handle.await;
}
fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> {
fn recalculate_ratings(db_conn: &mut PgConnection) -> QueryResult<()> {
let start = Instant::now();
let match_stats = fetch_match_stats(db_conn)?;
let ratings = estimate_ratings_from_stats(match_stats);
@ -91,7 +92,7 @@ struct MatchStats {
num_matches: usize,
}
fn fetch_match_stats(db_conn: &PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> {
fn fetch_match_stats(db_conn: &mut PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> {
let matches = db::matches::list_matches(RANKER_NUM_MATCHES, db_conn)?;
let mut match_stats = HashMap::<(i32, i32), MatchStats>::new();

View file

@ -112,8 +112,8 @@ where
Err(RegistryAuthError::InvalidCredentials)
}
} else {
let db_conn = DatabaseConnection::from_request(req).await.unwrap();
let user = authenticate_user(&credentials, &db_conn)
let mut db_conn = DatabaseConnection::from_request(req).await.unwrap();
let user = authenticate_user(&credentials, &mut db_conn)
.ok_or(RegistryAuthError::InvalidCredentials)?;
Ok(RegistryAuth::User(user))
@ -159,12 +159,12 @@ pub struct RegistryError {
}
async fn check_blob_exists(
db_conn: DatabaseConnection,
mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, raw_digest)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?;
check_access(&repository_name, &auth, &mut db_conn)?;
let digest = raw_digest.strip_prefix("sha256:").unwrap();
let blob_path = PathBuf::from(&config.registry_directory)
@ -179,12 +179,12 @@ async fn check_blob_exists(
}
async fn get_blob(
db_conn: DatabaseConnection,
mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, raw_digest)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?;
check_access(&repository_name, &auth, &mut db_conn)?;
let digest = raw_digest.strip_prefix("sha256:").unwrap();
let blob_path = PathBuf::from(&config.registry_directory)
@ -200,12 +200,12 @@ async fn get_blob(
}
async fn create_upload(
db_conn: DatabaseConnection,
mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path(repository_name): Path<String>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?;
check_access(&repository_name, &auth, &mut db_conn)?;
let uuid = gen_alphanumeric(16);
tokio::fs::File::create(
@ -229,13 +229,13 @@ async fn create_upload(
}
async fn patch_upload(
db_conn: DatabaseConnection,
mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, uuid)): Path<(String, String)>,
mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?;
check_access(&repository_name, &auth, &mut db_conn)?;
// TODO: support content range header in request
let upload_path = PathBuf::from(&config.registry_directory)
@ -275,14 +275,14 @@ struct UploadParams {
}
async fn put_upload(
db_conn: DatabaseConnection,
mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, uuid)): Path<(String, String)>,
Query(params): Query<UploadParams>,
mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?;
check_access(&repository_name, &auth, &mut db_conn)?;
let upload_path = PathBuf::from(&config.registry_directory)
.join("uploads")
@ -332,12 +332,12 @@ async fn put_upload(
}
async fn get_manifest(
db_conn: DatabaseConnection,
mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, reference)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
check_access(&repository_name, &auth, &db_conn)?;
check_access(&repository_name, &auth, &mut db_conn)?;
let manifest_path = PathBuf::from(&config.registry_directory)
.join("manifests")
@ -357,13 +357,13 @@ async fn get_manifest(
}
async fn put_manifest(
db_conn: DatabaseConnection,
mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, reference)): Path<(String, String)>,
mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
let bot = check_access(&repository_name, &auth, &db_conn)?;
let bot = check_access(&repository_name, &auth, &mut db_conn)?;
let repository_dir = PathBuf::from(&config.registry_directory)
.join("manifests")
@ -399,9 +399,9 @@ async fn put_manifest(
code_bundle_path: None,
container_digest: Some(&content_digest),
};
let version =
db::bots::create_bot_version(&new_version, &db_conn).expect("could not save bot version");
db::bots::set_active_version(bot.id, Some(version.id), &db_conn)
let version = db::bots::create_bot_version(&new_version, &mut db_conn)
.expect("could not save bot version");
db::bots::set_active_version(bot.id, Some(version.id), &mut db_conn)
.expect("could not update bot version");
Ok(Response::builder()
@ -421,7 +421,7 @@ async fn put_manifest(
fn check_access(
repository_name: &str,
auth: &RegistryAuth,
db_conn: &DatabaseConnection,
db_conn: &mut DatabaseConnection,
) -> Result<db::bots::Bot, StatusCode> {
use diesel::OptionalExtension;

View file

@ -100,10 +100,10 @@ pub fn validate_bot_name(bot_name: &str) -> Result<(), SaveBotError> {
pub async fn save_bot(
Json(params): Json<SaveBotParams>,
user: User,
conn: DatabaseConnection,
mut conn: DatabaseConnection,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Json<Bot>, SaveBotError> {
let res = bots::find_bot_by_name(&params.bot_name, &conn)
let res = bots::find_bot_by_name(&params.bot_name, &mut conn)
.optional()
.expect("could not run query");
@ -122,10 +122,10 @@ pub async fn save_bot(
name: &params.bot_name,
};
bots::create_bot(&new_bot, &conn).expect("could not create bot")
bots::create_bot(&new_bot, &mut conn).expect("could not create bot")
}
};
let _code_bundle = save_code_string(&params.code, Some(bot.id), &conn, &config)
let _code_bundle = save_code_string(&params.code, Some(bot.id), &mut conn, &config)
.expect("failed to save code bundle");
Ok(Json(bot))
}
@ -137,12 +137,12 @@ pub struct BotParams {
// TODO: can we unify this with save_bot?
pub async fn create_bot(
conn: DatabaseConnection,
mut conn: DatabaseConnection,
user: User,
params: Json<BotParams>,
) -> Result<(StatusCode, Json<Bot>), SaveBotError> {
validate_bot_name(&params.name)?;
let existing_bot = bots::find_bot_by_name(&params.name, &conn)
let existing_bot = bots::find_bot_by_name(&params.name, &mut conn)
.optional()
.expect("could not run query");
if existing_bot.is_some() {
@ -152,26 +152,27 @@ pub async fn create_bot(
owner_id: Some(user.id),
name: &params.name,
};
let bot = bots::create_bot(&bot_params, &conn).unwrap();
let bot = bots::create_bot(&bot_params, &mut conn).unwrap();
Ok((StatusCode::CREATED, Json(bot)))
}
// TODO: handle errors
pub async fn get_bot(
conn: DatabaseConnection,
mut conn: DatabaseConnection,
Path(bot_name): Path<String>,
) -> Result<Json<JsonValue>, StatusCode> {
let bot = db::bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
let bot =
db::bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let owner: Option<UserData> = match bot.owner_id {
Some(user_id) => {
let user = db::users::find_user(user_id, &conn)
let user = db::users::find_user(user_id, &mut conn)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Some(user.into())
}
None => None,
};
let versions =
bots::find_bot_versions(bot.id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let versions = bots::find_bot_versions(bot.id, &mut conn)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(json!({
"bot": bot,
"owner": owner,
@ -180,32 +181,32 @@ pub async fn get_bot(
}
pub async fn get_user_bots(
conn: DatabaseConnection,
mut conn: DatabaseConnection,
Path(user_name): Path<String>,
) -> Result<Json<Vec<Bot>>, StatusCode> {
let user =
db::users::find_user_by_name(&user_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
db::bots::find_bots_by_owner(user.id, &conn)
db::users::find_user_by_name(&user_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
db::bots::find_bots_by_owner(user.id, &mut conn)
.map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
/// List all active bots
pub async fn list_bots(conn: DatabaseConnection) -> Result<Json<Vec<Bot>>, StatusCode> {
bots::find_active_bots(&conn)
pub async fn list_bots(mut conn: DatabaseConnection) -> Result<Json<Vec<Bot>>, StatusCode> {
bots::find_active_bots(&mut conn)
.map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
pub async fn get_ranking(conn: DatabaseConnection) -> Result<Json<Vec<RankedBot>>, StatusCode> {
ratings::get_bot_ranking(&conn)
pub async fn get_ranking(mut conn: DatabaseConnection) -> Result<Json<Vec<RankedBot>>, StatusCode> {
ratings::get_bot_ranking(&mut conn)
.map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
// TODO: currently this only implements the happy flow
pub async fn upload_code_multipart(
conn: DatabaseConnection,
mut conn: DatabaseConnection,
user: User,
Path(bot_name): Path<String>,
mut multipart: Multipart,
@ -213,7 +214,7 @@ pub async fn upload_code_multipart(
) -> Result<Json<BotVersion>, StatusCode> {
let bots_dir = PathBuf::from(&config.bots_directory);
let bot = bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
let bot = bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
if Some(user.id) != bot.owner_id {
return Err(StatusCode::FORBIDDEN);
@ -246,21 +247,22 @@ pub async fn upload_code_multipart(
container_digest: None,
};
let code_bundle =
bots::create_bot_version(&bot_version, &conn).expect("Failed to create code bundle");
bots::create_bot_version(&bot_version, &mut conn).expect("Failed to create code bundle");
Ok(Json(code_bundle))
}
pub async fn get_code(
conn: DatabaseConnection,
mut conn: DatabaseConnection,
user: User,
Path(bundle_id): Path<i32>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Vec<u8>, StatusCode> {
let version =
db::bots::find_bot_version(bundle_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
db::bots::find_bot_version(bundle_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let bot_id = version.bot_id.ok_or(StatusCode::FORBIDDEN)?;
let bot = db::bots::find_bot(bot_id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let bot =
db::bots::find_bot(bot_id, &mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if bot.owner_id != Some(user.id) {
return Err(StatusCode::FORBIDDEN);
@ -297,10 +299,10 @@ impl MatchupStats {
type BotStats = HashMap<String, HashMap<String, MatchupStats>>;
pub async fn get_bot_stats(
conn: DatabaseConnection,
mut conn: DatabaseConnection,
Path(bot_name): Path<String>,
) -> Result<Json<BotStats>, StatusCode> {
let stats_records = db::matches::fetch_bot_stats(&bot_name, &conn)
let stats_records = db::matches::fetch_bot_stats(&bot_name, &mut conn)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let mut bot_stats: BotStats = HashMap::new();
for record in stats_records {

View file

@ -35,7 +35,7 @@ pub async fn submit_bot(
Extension(pool): Extension<ConnectionPool>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Json<SubmitBotResponse>, StatusCode> {
let conn = pool.get().await.expect("could not get database connection");
let mut conn = pool.get().await.expect("could not get database connection");
let opponent_name = params
.opponent_name
@ -46,12 +46,13 @@ pub async fn submit_bot(
.unwrap_or_else(|| DEFAULT_MAP_NAME.to_string());
let (opponent_bot, opponent_bot_version) =
db::bots::find_bot_with_version_by_name(&opponent_name, &conn)
db::bots::find_bot_with_version_by_name(&opponent_name, &mut conn)
.map_err(|_| StatusCode::BAD_REQUEST)?;
let map = db::maps::find_map_by_name(&map_name, &conn).map_err(|_| StatusCode::BAD_REQUEST)?;
let map =
db::maps::find_map_by_name(&map_name, &mut conn).map_err(|_| StatusCode::BAD_REQUEST)?;
let player_bot_version = save_code_string(&params.code, None, &conn, &config)
let player_bot_version = save_code_string(&params.code, None, &mut conn, &config)
// TODO: can we recover from this?
.expect("could not save bot code");

View file

@ -8,8 +8,8 @@ pub struct ApiMap {
pub name: String,
}
pub async fn list_maps(conn: DatabaseConnection) -> Result<Json<Vec<ApiMap>>, StatusCode> {
let maps = db::maps::list_maps(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
pub async fn list_maps(mut conn: DatabaseConnection) -> Result<Json<Vec<ApiMap>>, StatusCode> {
let maps = db::maps::list_maps(&mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let api_maps = maps
.into_iter()

View file

@ -56,7 +56,7 @@ pub struct ListMatchesResponse {
pub async fn list_recent_matches(
Query(params): Query<ListRecentMatchesParams>,
conn: DatabaseConnection,
mut conn: DatabaseConnection,
) -> Result<Json<ListMatchesResponse>, StatusCode> {
let requested_count = std::cmp::min(
params.count.unwrap_or(DEFAULT_NUM_RETURNED_MATCHES),
@ -68,7 +68,7 @@ pub async fn list_recent_matches(
let matches_result = match params.bot {
Some(bot_name) => {
let bot = db::bots::find_bot_by_name(&bot_name, &conn)
let bot = db::bots::find_bot_by_name(&bot_name, &mut conn)
.map_err(|_| StatusCode::BAD_REQUEST)?;
matches::list_bot_matches(
bot.id,
@ -76,10 +76,10 @@ pub async fn list_recent_matches(
count,
params.before,
params.after,
&conn,
&mut conn,
)
}
None => matches::list_public_matches(count, params.before, params.after, &conn),
None => matches::list_public_matches(count, params.before, params.after, &mut conn),
};
let mut matches = matches_result.map_err(|_| StatusCode::BAD_REQUEST)?;
@ -119,9 +119,9 @@ pub fn match_data_to_api(data: matches::FullMatchData) -> ApiMatch {
pub async fn get_match_data(
Path(match_id): Path<i32>,
conn: DatabaseConnection,
mut conn: DatabaseConnection,
) -> Result<Json<ApiMatch>, StatusCode> {
let match_data = matches::find_match(match_id, &conn)
let match_data = matches::find_match(match_id, &mut conn)
.map_err(|_| StatusCode::NOT_FOUND)
.map(match_data_to_api)?;
Ok(Json(match_data))
@ -129,11 +129,11 @@ pub async fn get_match_data(
pub async fn get_match_log(
Path(match_id): Path<i32>,
conn: DatabaseConnection,
mut conn: DatabaseConnection,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Vec<u8>, StatusCode> {
let match_base =
matches::find_match_base(match_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
matches::find_match_base(match_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let log_path = PathBuf::from(&config.match_logs_directory).join(&match_base.log_path);
let log_contents = std::fs::read(log_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(log_contents)

View file

@ -23,13 +23,13 @@ where
type Rejection = (StatusCode, String);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let conn = DatabaseConnection::from_request(req).await?;
let mut conn = DatabaseConnection::from_request(req).await?;
let TypedHeader(Authorization(bearer)) = AuthorizationHeader::from_request(req)
.await
.map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?;
let (_session, user) = sessions::find_user_by_session(bearer.token(), &conn)
let (_session, user) = sessions::find_user_by_session(bearer.token(), &mut conn)
.map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?;
Ok(user)
@ -66,7 +66,7 @@ pub enum RegistrationError {
}
impl RegistrationParams {
fn validate(&self, conn: &DatabaseConnection) -> Result<(), RegistrationError> {
fn validate(&self, conn: &mut DatabaseConnection) -> Result<(), RegistrationError> {
let mut errors = Vec::new();
// TODO: do we want to support cased usernames?
@ -95,7 +95,7 @@ impl RegistrationParams {
errors.push("that username is not allowed".to_string());
}
if users::find_user_by_name(&self.username, &conn).is_ok() {
if users::find_user_by_name(&self.username, conn).is_ok() {
errors.push("username is already taken".to_string());
}
@ -137,16 +137,16 @@ impl IntoResponse for RegistrationError {
}
pub async fn register(
conn: DatabaseConnection,
mut conn: DatabaseConnection,
params: Json<RegistrationParams>,
) -> Result<Json<UserData>, RegistrationError> {
params.validate(&conn)?;
params.validate(&mut conn)?;
let credentials = Credentials {
username: &params.username,
password: &params.password,
};
let user = users::create_user(&credentials, &conn)?;
let user = users::create_user(&credentials, &mut conn)?;
Ok(Json(user.into()))
}
@ -156,18 +156,18 @@ pub struct LoginParams {
pub password: String,
}
pub async fn login(conn: DatabaseConnection, params: Json<LoginParams>) -> Response {
pub async fn login(mut conn: DatabaseConnection, params: Json<LoginParams>) -> Response {
let credentials = Credentials {
username: &params.username,
password: &params.password,
};
// TODO: handle failures
let authenticated = users::authenticate_user(&credentials, &conn);
let authenticated = users::authenticate_user(&credentials, &mut conn);
match authenticated {
None => StatusCode::FORBIDDEN.into_response(),
Some(user) => {
let session = sessions::create_session(&user, &conn);
let session = sessions::create_session(&user, &mut conn);
let user_data: UserData = user.into();
let headers = [("Token", &session.token)];

View file

@ -1,7 +1,15 @@
// This file is autogenerated by diesel
#![allow(unused_imports)]
table! {
// @generated automatically by Diesel CLI.
pub mod sql_types {
#[derive(diesel::sql_types::SqlType)]
#[diesel(postgres_type(name = "match_state"))]
pub struct MatchState;
}
diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@ -14,7 +22,7 @@ table! {
}
}
table! {
diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@ -26,7 +34,7 @@ table! {
}
}
table! {
diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@ -37,7 +45,7 @@ table! {
}
}
table! {
diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@ -48,13 +56,14 @@ table! {
}
}
table! {
diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
use super::sql_types::MatchState;
matches (id) {
id -> Int4,
state -> Match_state,
state -> MatchState,
log_path -> Text,
created_at -> Timestamp,
winner -> Nullable<Int4>,
@ -63,7 +72,7 @@ table! {
}
}
table! {
diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@ -73,7 +82,7 @@ table! {
}
}
table! {
diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@ -84,7 +93,7 @@ table! {
}
}
table! {
diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@ -96,14 +105,14 @@ table! {
}
}
joinable!(bots -> users (owner_id));
joinable!(match_players -> bot_versions (bot_version_id));
joinable!(match_players -> matches (match_id));
joinable!(matches -> maps (map_id));
joinable!(ratings -> bots (bot_id));
joinable!(sessions -> users (user_id));
diesel::joinable!(bots -> users (owner_id));
diesel::joinable!(match_players -> bot_versions (bot_version_id));
diesel::joinable!(match_players -> matches (match_id));
diesel::joinable!(matches -> maps (map_id));
diesel::joinable!(ratings -> bots (bot_id));
diesel::joinable!(sessions -> users (user_id));
allow_tables_to_appear_in_same_query!(
diesel::allow_tables_to_appear_in_same_query!(
bot_versions,
bots,
maps,

View file

@ -27,7 +27,7 @@ fn create_subdir<P: AsRef<Path>>(base_path: &Path, p: P) -> io::Result<String> {
Ok(dir_path_string)
}
fn clear_database(conn: &PgConnection) {
fn clear_database(conn: &mut PgConnection) {
diesel::sql_query(
"TRUNCATE TABLE
bots,
@ -45,20 +45,20 @@ fn clear_database(conn: &PgConnection) {
/// Setup a simple text fixture, having simplebot and the hex map.
/// This is enough to run a simple match.
fn setup_simple_fixture(db_conn: &PgConnection, config: &GlobalConfig) {
fn setup_simple_fixture(db_conn: &mut PgConnection, config: &GlobalConfig) {
let bot = db::bots::create_bot(
&db::bots::NewBot {
owner_id: None,
name: "simplebot",
},
&db_conn,
db_conn,
)
.expect("could not create simplebot");
let simplebot_code = std::fs::read_to_string("../simplebot/simplebot.py")
.expect("could not read simplebot code");
let _bot_version =
modules::bots::save_code_string(&simplebot_code, Some(bot.id), &db_conn, &config)
modules::bots::save_code_string(&simplebot_code, Some(bot.id), db_conn, &config)
.expect("could not save bot version");
std::fs::copy(
@ -71,7 +71,7 @@ fn setup_simple_fixture(db_conn: &PgConnection, config: &GlobalConfig) {
name: "hex",
file_path: "hex.json",
},
&db_conn,
db_conn,
)
.expect("could not save map");
}
@ -119,14 +119,14 @@ impl<'a> TestApp<'a> {
async fn with_db_conn<F, R>(&self, function: F) -> R
where
F: FnOnce(&PgConnection) -> R,
F: FnOnce(&mut PgConnection) -> R,
{
let db_conn = self
let mut db_conn = self
.db_pool
.get()
.await
.expect("could not get db connection");
function(&db_conn)
function(&mut db_conn)
}
}