implement MLE ranker
This commit is contained in:
parent
ef19e3a9e7
commit
f899fba8ad
1 changed files with 179 additions and 30 deletions
|
@ -2,14 +2,14 @@ use crate::{db::bots::Bot, DbPool};
|
|||
|
||||
use crate::db;
|
||||
use crate::modules::matches::RunMatch;
|
||||
use diesel::{PgConnection, QueryResult};
|
||||
use rand::seq::SliceRandom;
|
||||
use std::time::Duration;
|
||||
use std::collections::HashMap;
|
||||
use std::mem;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio;
|
||||
|
||||
const RANKER_INTERVAL: u64 = 60;
|
||||
const START_RATING: f64 = 0.0;
|
||||
const SCALE: f64 = 100.0;
|
||||
const MAX_UPDATE: f64 = 0.1;
|
||||
|
||||
pub async fn run_ranker(db_pool: DbPool) {
|
||||
// TODO: make this configurable
|
||||
|
@ -31,6 +31,7 @@ pub async fn run_ranker(db_pool: DbPool) {
|
|||
bots.choose_multiple(&mut rng, 2).cloned().collect()
|
||||
};
|
||||
play_ranking_match(selected_bots, db_pool.clone()).await;
|
||||
recalculate_ratings(&db_conn).expect("could not recalculate ratings");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,45 +44,193 @@ async fn play_ranking_match(selected_bots: Vec<Bot>, db_pool: DbPool) {
|
|||
code_bundles.push(code_bundle);
|
||||
}
|
||||
|
||||
let code_bundle_refs = code_bundles.iter().map(|b| b).collect::<Vec<_>>();
|
||||
let code_bundle_refs = code_bundles.iter().collect::<Vec<_>>();
|
||||
|
||||
let mut run_match = RunMatch::from_players(code_bundle_refs);
|
||||
run_match
|
||||
.store_in_database(&db_conn)
|
||||
.expect("could not store match in db");
|
||||
let outcome = run_match
|
||||
run_match
|
||||
.spawn(db_pool.clone())
|
||||
.await
|
||||
.expect("running match failed");
|
||||
|
||||
let mut ratings = Vec::new();
|
||||
for bot in &selected_bots {
|
||||
let rating = db::ratings::get_rating(bot.id, &db_conn)
|
||||
.expect("could not get bot rating")
|
||||
.unwrap_or(START_RATING);
|
||||
ratings.push(rating);
|
||||
}
|
||||
|
||||
// simple elo rating
|
||||
fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> {
|
||||
let start = Instant::now();
|
||||
let match_stats = fetch_match_stats(db_conn)?;
|
||||
let ratings = estimate_ratings_from_stats(match_stats);
|
||||
|
||||
let scores = match outcome.winner {
|
||||
None => vec![0.5; 2],
|
||||
Some(player_num) => {
|
||||
// TODO: please get rid of this offset
|
||||
let player_ix = player_num - 1;
|
||||
let mut scores = vec![0.0; 2];
|
||||
scores[player_ix] = 1.0;
|
||||
scores
|
||||
for (bot_id, rating) in ratings {
|
||||
db::ratings::set_rating(bot_id, rating, db_conn).expect("could not update bot rating");
|
||||
}
|
||||
let elapsed = Instant::now() - start;
|
||||
println!("computed ratings in {} ms", elapsed.subsec_millis());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct MatchStats {
|
||||
total_score: f64,
|
||||
num_matches: usize,
|
||||
}
|
||||
|
||||
fn fetch_match_stats(db_conn: &PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> {
|
||||
let matches = db::matches::list_matches(db_conn)?;
|
||||
|
||||
let mut match_stats = HashMap::<(i32, i32), MatchStats>::new();
|
||||
for m in matches {
|
||||
if m.match_players.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
let (mut a_id, mut b_id) = match (&m.match_players[0].bot, &m.match_players[1].bot) {
|
||||
(Some(ref a), Some(ref b)) => (a.id, b.id),
|
||||
_ => continue,
|
||||
};
|
||||
// score of player a
|
||||
let mut score = match m.base.winner {
|
||||
None => 0.5,
|
||||
Some(0) => 1.0,
|
||||
Some(1) => 0.0,
|
||||
_ => panic!("invalid winner"),
|
||||
};
|
||||
|
||||
for i in 0..2 {
|
||||
let j = 1 - i;
|
||||
// put players in canonical order: smallest id first
|
||||
if b_id < a_id {
|
||||
mem::swap(&mut a_id, &mut b_id);
|
||||
score = 1.0 - score;
|
||||
}
|
||||
|
||||
let scaled_difference = (ratings[j] - ratings[i]) / SCALE;
|
||||
let expected = 1.0 / (1.0 + 10f64.powf(scaled_difference));
|
||||
let new_rating = ratings[i] + MAX_UPDATE * (scores[i] - expected);
|
||||
db::ratings::set_rating(selected_bots[i].id, new_rating, &db_conn)
|
||||
.expect("could not update bot rating");
|
||||
let entry = match_stats.entry((a_id, b_id)).or_default();
|
||||
entry.num_matches += 1;
|
||||
entry.total_score += score;
|
||||
}
|
||||
Ok(match_stats)
|
||||
}
|
||||
|
||||
/// Tokenizes player ids to a set of consecutive numbers
|
||||
struct PlayerTokenizer {
|
||||
id_to_ix: HashMap<i32, usize>,
|
||||
ids: Vec<i32>,
|
||||
}
|
||||
|
||||
impl PlayerTokenizer {
|
||||
fn new() -> Self {
|
||||
PlayerTokenizer {
|
||||
id_to_ix: HashMap::new(),
|
||||
ids: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn tokenize(&mut self, id: i32) -> usize {
|
||||
match self.id_to_ix.get(&id) {
|
||||
Some(&ix) => ix,
|
||||
None => {
|
||||
let ix = self.ids.len();
|
||||
self.ids.push(id);
|
||||
self.id_to_ix.insert(id, ix);
|
||||
ix
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn detokenize(&self, ix: usize) -> i32 {
|
||||
self.ids[ix]
|
||||
}
|
||||
|
||||
fn player_count(&self) -> usize {
|
||||
self.ids.len()
|
||||
}
|
||||
}
|
||||
|
||||
fn sigmoid(logit: f64) -> f64 {
|
||||
1.0 / (1.0 + (-logit).exp())
|
||||
}
|
||||
|
||||
fn estimate_ratings_from_stats(match_stats: HashMap<(i32, i32), MatchStats>) -> Vec<(i32, f64)> {
|
||||
// map player ids to player indexes in the ratings array
|
||||
let mut input_records = Vec::<RatingInputRecord>::with_capacity(match_stats.len());
|
||||
let mut player_tokenizer = PlayerTokenizer::new();
|
||||
|
||||
for ((a_id, b_id), stats) in match_stats.into_iter() {
|
||||
input_records.push(RatingInputRecord {
|
||||
p1_ix: player_tokenizer.tokenize(a_id),
|
||||
p2_ix: player_tokenizer.tokenize(b_id),
|
||||
score: stats.total_score / stats.num_matches as f64,
|
||||
})
|
||||
}
|
||||
|
||||
let mut ratings = vec![0f64; player_tokenizer.player_count()];
|
||||
optimize_ratings(&mut ratings, &input_records);
|
||||
|
||||
ratings
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(ix, rating)| {
|
||||
(
|
||||
player_tokenizer.detokenize(ix),
|
||||
rating * 100f64 / 10f64.ln(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct RatingInputRecord {
|
||||
/// index of first player
|
||||
p1_ix: usize,
|
||||
/// index of secord player
|
||||
p2_ix: usize,
|
||||
/// score of player 1 (= 1 - score of player 2)
|
||||
score: f64,
|
||||
}
|
||||
|
||||
fn optimize_ratings(ratings: &mut [f64], input_records: &[RatingInputRecord]) {
|
||||
// TODO: group this in a params struct
|
||||
let tolerance = 10f64.powi(-6);
|
||||
let learning_rate = 0.1;
|
||||
let max_iterations = 10000;
|
||||
|
||||
for _iteration in 0..max_iterations {
|
||||
let mut gradients = vec![0f64; ratings.len()];
|
||||
|
||||
// calculate gradients
|
||||
for record in input_records.iter() {
|
||||
let predicted = sigmoid(ratings[record.p1_ix] - ratings[record.p2_ix]);
|
||||
let gradient = predicted - record.score;
|
||||
gradients[record.p1_ix] += gradient;
|
||||
gradients[record.p2_ix] -= gradient;
|
||||
}
|
||||
|
||||
// apply update step
|
||||
let mut converged = true;
|
||||
for (rating, gradient) in ratings.iter_mut().zip(&gradients) {
|
||||
let update = learning_rate * gradient / input_records.len() as f64;
|
||||
if update > tolerance {
|
||||
converged = false;
|
||||
}
|
||||
*rating -= update;
|
||||
}
|
||||
|
||||
if converged {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_optimize_ratings() {
|
||||
let input_records = vec![RatingInputRecord {
|
||||
p1_ix: 0,
|
||||
p2_ix: 1,
|
||||
score: 0.8,
|
||||
}];
|
||||
|
||||
let mut ratings = vec![0.0; 2];
|
||||
optimize_ratings(&mut ratings, &input_records);
|
||||
assert!(sigmoid(ratings[0] - ratings[1]) - 0.8 < 10f64.powi(-6));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue