From a1d81ac774c0ae52f155cd764fd74fd1ba928a5f Mon Sep 17 00:00:00 2001 From: Ilion Beyst Date: Fri, 23 Sep 2022 21:34:57 +0200 Subject: [PATCH] ensure bots cleanly stop before a match completes --- planetwars-matchrunner/src/bot_runner.rs | 12 ++++++++++-- planetwars-matchrunner/src/docker_runner.rs | 19 +++++++++++++------ planetwars-matchrunner/src/lib.rs | 6 ++---- planetwars-matchrunner/src/match_context.rs | 10 ++++++++++ planetwars-matchrunner/src/pw_match.rs | 5 ++++- .../tests/test_matchrunner.rs | 3 ++- 6 files changed, 41 insertions(+), 14 deletions(-) diff --git a/planetwars-matchrunner/src/bot_runner.rs b/planetwars-matchrunner/src/bot_runner.rs index d40a133..8597e26 100644 --- a/planetwars-matchrunner/src/bot_runner.rs +++ b/planetwars-matchrunner/src/bot_runner.rs @@ -6,14 +6,18 @@ use std::sync::Mutex; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines}; use tokio::process; use tokio::sync::mpsc; +use tokio::task::JoinHandle; use tokio::time::timeout; use super::match_context::EventBus; use super::match_context::PlayerHandle; use super::match_context::RequestError; use super::match_context::RequestMessage; +// TODO: this is exactly the same as the docker bot handle. +// should this abstraction be removed? pub struct LocalBotHandle { tx: mpsc::UnboundedSender, + join_handle: JoinHandle<()>, } impl PlayerHandle for LocalBotHandle { @@ -22,6 +26,10 @@ impl PlayerHandle for LocalBotHandle { .send(r) .expect("failed to send message to local bot"); } + + fn into_join_handle(self: Box) -> JoinHandle<()> { + self.join_handle + } } pub fn run_local_bot(player_id: u32, event_bus: Arc>, bot: Bot) -> LocalBotHandle { @@ -33,9 +41,9 @@ pub fn run_local_bot(player_id: u32, event_bus: Arc>, bot: Bot) player_id, bot, }; - tokio::spawn(runner.run()); + let join_handle = tokio::spawn(runner.run()); - LocalBotHandle { tx } + LocalBotHandle { tx, join_handle } } pub struct LocalBotRunner { diff --git a/planetwars-matchrunner/src/docker_runner.rs b/planetwars-matchrunner/src/docker_runner.rs index 6de9bb1..939d734 100644 --- a/planetwars-matchrunner/src/docker_runner.rs +++ b/planetwars-matchrunner/src/docker_runner.rs @@ -9,6 +9,7 @@ use bytes::Bytes; use futures::{Stream, StreamExt}; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc; +use tokio::task::JoinHandle; use tokio::time::timeout; use crate::match_context::{EventBus, PlayerHandle, RequestError, RequestMessage}; @@ -42,8 +43,7 @@ impl BotSpec for DockerBotSpec { match_logger: MatchLogger, ) -> Box { let process = spawn_docker_process(self).await.unwrap(); - let (handle, runner) = create_docker_bot(process, player_id, event_bus, match_logger); - tokio::spawn(runner.run()); + let handle = run_docker_bot(process, player_id, event_bus, match_logger); return Box::new(handle); } } @@ -155,14 +155,13 @@ impl ContainerProcess { } } -fn create_docker_bot( +fn run_docker_bot( process: ContainerProcess, player_id: u32, event_bus: Arc>, match_logger: MatchLogger, -) -> (DockerBotHandle, DockerBotRunner) { +) -> DockerBotHandle { let (tx, rx) = mpsc::unbounded_channel(); - let bot_handle = DockerBotHandle { tx }; let bot_runner = DockerBotRunner { process, player_id, @@ -170,11 +169,15 @@ fn create_docker_bot( match_logger, rx, }; - (bot_handle, bot_runner) + + let join_handle = tokio::spawn(bot_runner.run()); + + DockerBotHandle { tx, join_handle } } pub struct DockerBotHandle { tx: mpsc::UnboundedSender, + join_handle: JoinHandle<()>, } impl PlayerHandle for DockerBotHandle { @@ -183,6 +186,10 @@ impl PlayerHandle for DockerBotHandle { .send(r) .expect("failed to send message to local bot"); } + + fn into_join_handle(self: Box) -> JoinHandle<()> { + self.join_handle + } } pub struct DockerBotRunner { diff --git a/planetwars-matchrunner/src/lib.rs b/planetwars-matchrunner/src/lib.rs index d26e810..50cff70 100644 --- a/planetwars-matchrunner/src/lib.rs +++ b/planetwars-matchrunner/src/lib.rs @@ -106,11 +106,9 @@ pub async fn run_match(config: MatchConfig) -> MatchOutcome { // ) // .unwrap(); - let mut match_state = pw_match::PwMatch::create(match_ctx, pw_config); - match_state.run().await; + let final_state = pw_match::PwMatch::create(match_ctx, pw_config).run().await; - let final_state = match_state.match_state.state(); - let survivors = final_state.living_players(); + let survivors = final_state.state().living_players(); let winner = if survivors.len() == 1 { Some(survivors[0]) } else { diff --git a/planetwars-matchrunner/src/match_context.rs b/planetwars-matchrunner/src/match_context.rs index 859b11d..bdc87a3 100644 --- a/planetwars-matchrunner/src/match_context.rs +++ b/planetwars-matchrunner/src/match_context.rs @@ -7,6 +7,7 @@ use std::{ collections::HashMap, sync::{Arc, Mutex}, }; +use tokio::task::JoinHandle; use crate::match_log::{MatchLogMessage, MatchLogger}; @@ -71,10 +72,19 @@ impl MatchCtx { pub fn log(&mut self, message: MatchLogMessage) { self.match_logger.send(message).expect("write failed"); } + + pub async fn shutdown(self) { + let join_handles = self + .players + .into_iter() + .map(|(_player_id, player_data)| player_data.handle.into_join_handle()); + futures::future::join_all(join_handles).await; + } } pub trait PlayerHandle: Send { fn send_request(&mut self, r: RequestMessage); + fn into_join_handle(self: Box) -> JoinHandle<()>; } struct PlayerData { diff --git a/planetwars-matchrunner/src/pw_match.rs b/planetwars-matchrunner/src/pw_match.rs index 4af215e..c5650f7 100644 --- a/planetwars-matchrunner/src/pw_match.rs +++ b/planetwars-matchrunner/src/pw_match.rs @@ -39,7 +39,7 @@ impl PwMatch { } } - pub async fn run(&mut self) { + pub async fn run(mut self) -> PlanetWars { // log initial state self.log_game_state(); @@ -53,6 +53,9 @@ impl PwMatch { self.match_state.step(); self.log_game_state(); } + + self.match_ctx.shutdown().await; + self.match_state } async fn prompt_players(&mut self) -> Vec<(usize, RequestResult>)> { diff --git a/planetwars-matchrunner/tests/test_matchrunner.rs b/planetwars-matchrunner/tests/test_matchrunner.rs index 5ac8649..5d2fc83 100644 --- a/planetwars-matchrunner/tests/test_matchrunner.rs +++ b/planetwars-matchrunner/tests/test_matchrunner.rs @@ -70,5 +70,6 @@ async fn docker_runner_timeout() { .request(1, b"sup".to_vec(), Duration::from_millis(1000)) .await; - assert_eq!(resp, Err(RequestError::Timeout)) + assert_eq!(resp, Err(RequestError::Timeout)); + ctx.shutdown().await; }