diff --git a/planetwars-server/src/modules/client_api.rs b/planetwars-server/src/modules/client_api.rs index 3402964..d58e1bc 100644 --- a/planetwars-server/src/modules/client_api.rs +++ b/planetwars-server/src/modules/client_api.rs @@ -33,10 +33,26 @@ pub struct ClientApiServer { router: PlayerRouter, } +type ClientMessages = Streaming; +type ServerMessages = mpsc::UnboundedReceiver>; + +enum PlayerConnectionState { + Reserved, + ClientConnected { + tx: oneshot::Sender, + client_messages: ClientMessages, + }, + ServerConnected { + tx: oneshot::Sender, + server_messages: ServerMessages, + }, + // In connected state, the connection is removed from the PlayerRouter +} + /// Routes players to their handler #[derive(Clone)] struct PlayerRouter { - routing_table: Arc>>, + routing_table: Arc>>, } impl PlayerRouter { @@ -55,12 +71,12 @@ impl Default for PlayerRouter { // TODO: implement a way to expire entries impl PlayerRouter { - fn put(&self, player_key: String, entry: SyncThingData) { + fn put(&self, player_key: String, entry: PlayerConnectionState) { let mut routing_table = self.routing_table.lock().unwrap(); routing_table.insert(player_key, entry); } - fn take(&self, player_key: &str) -> Option { + fn take(&self, player_key: &str) -> Option { // TODO: this design does not allow for reconnects. Is this desired? let mut routing_table = self.routing_table.lock().unwrap(); routing_table.remove(player_key) @@ -81,21 +97,62 @@ impl pb::client_api_service_server::ClientApiService for ClientApiServer { .get("player_key") .ok_or_else(|| Status::unauthenticated("no player_key provided"))?; - let player_key_str = player_key + let player_key_string = player_key .to_str() - .map_err(|_| Status::invalid_argument("unreadable string"))?; + .map_err(|_| Status::invalid_argument("unreadable string"))? + .to_string(); - let sync_data = self - .router - .take(player_key_str) - .ok_or_else(|| Status::not_found("player_key not found"))?; + let client_messages = req.into_inner(); - let stream = req.into_inner(); + enum ConnState { + Connected { + server_messages: ServerMessages, + }, + Awaiting { + rx: oneshot::Receiver, + }, + } - sync_data.tx.send(stream).unwrap(); - Ok(Response::new(UnboundedReceiverStream::new( - sync_data.server_messages, - ))) + let conn_state = { + // during this block, a lack is held on the routing table + + let mut routing_table = self.router.routing_table.lock().unwrap(); + let connection_state = routing_table + .remove(&player_key_string) + .ok_or_else(|| Status::not_found("player_key not found"))?; + match connection_state { + PlayerConnectionState::Reserved => { + let (tx, rx) = oneshot::channel(); + + routing_table.insert( + player_key_string, + PlayerConnectionState::ClientConnected { + tx, + client_messages, + }, + ); + + ConnState::Awaiting { rx } + } + PlayerConnectionState::ServerConnected { + tx, + server_messages, + } => { + tx.send(client_messages).unwrap(); + ConnState::Connected { server_messages } + } + PlayerConnectionState::ClientConnected { .. } => panic!("player already connected"), + } + }; + + let server_messages = match conn_state { + ConnState::Connected { server_messages } => server_messages, + ConnState::Awaiting { rx } => rx + .await + .map_err(|_| Status::internal("failed to connect player to game"))?, + }; + + Ok(Response::new(UnboundedReceiverStream::new(server_messages))) } async fn create_match( @@ -119,6 +176,9 @@ impl pb::client_api_service_server::ClientApiService for ClientApiServer { .map_err(|_| Status::not_found("map not found"))?; let player_key = gen_alphanumeric(32); + // ensure that the player key is registered in the router when we send a response + self.router + .put(player_key.clone(), PlayerConnectionState::Reserved); let remote_bot_spec = Box::new(RemoteBotSpec { player_key: player_key.clone(), @@ -155,12 +215,6 @@ impl pb::client_api_service_server::ClientApiService for ClientApiServer { } } -// TODO: please rename me -struct SyncThingData { - tx: oneshot::Sender>, - server_messages: mpsc::UnboundedReceiver>, -} - struct RemoteBotSpec { player_key: String, router: PlayerRouter, @@ -174,32 +228,71 @@ impl runner::BotSpec for RemoteBotSpec { event_bus: Arc>, _match_logger: MatchLogger, ) -> Box { - let (tx, rx) = oneshot::channel(); let (server_msg_snd, server_msg_recv) = mpsc::unbounded_channel(); - self.router.put( - self.player_key.clone(), - SyncThingData { - tx, - server_messages: server_msg_recv, - }, - ); - let fut = tokio::time::timeout(Duration::from_secs(10), rx); - match fut.await { - Ok(Ok(client_messages)) => { - // let client_messages = rx.await.unwrap(); - tokio::spawn(handle_bot_messages( - player_id, - event_bus.clone(), + enum ConnState { + Connected { + client_messages: ClientMessages, + }, + Awaiting { + rx: oneshot::Receiver, + }, + } + + let conn_state = { + // during this block, we hold a lock on the routing table. + + let mut routing_table = self.router.routing_table.lock().unwrap(); + let connection_state = routing_table + .remove(&self.player_key) + .expect("player key not found in routing table"); + + match connection_state { + PlayerConnectionState::Reserved => { + let (tx, rx) = oneshot::channel(); + routing_table.insert( + self.player_key.clone(), + PlayerConnectionState::ServerConnected { + tx, + server_messages: server_msg_recv, + }, + ); + ConnState::Awaiting { rx } + } + PlayerConnectionState::ClientConnected { + tx, client_messages, - )); - } - _ => { - // ensure router cleanup - self.router.take(&self.player_key); + } => { + tx.send(server_msg_recv).unwrap(); + ConnState::Connected { client_messages } + } + PlayerConnectionState::ServerConnected { .. } => panic!("server already connected"), } }; + let maybe_client_messages = match conn_state { + ConnState::Connected { client_messages } => Some(client_messages), + ConnState::Awaiting { rx } => { + let fut = tokio::time::timeout(Duration::from_secs(10), rx); + match fut.await { + Ok(Ok(client_messages)) => Some(client_messages), + _ => { + // ensure router cleanup + self.router.take(&self.player_key); + None + } + } + } + }; + + if let Some(client_messages) = maybe_client_messages { + tokio::spawn(handle_bot_messages( + player_id, + event_bus.clone(), + client_messages, + )); + } + // If the player did not connect, the receiving half of `sender` // will be dropped here, resulting in a time-out for every turn. // This is fine for now, but @@ -217,7 +310,7 @@ async fn handle_bot_messages( event_bus: Arc>, mut messages: Streaming, ) { - // TODO: can this be writte nmore nicely? + // TODO: can this be written more nicely? while let Some(message) = messages.message().await.unwrap() { match message.client_message { Some(pb::PlayerApiClientMessageType::Action(resp)) => {