diff --git a/src/policy/selection/ucb1_tuned.rs b/src/policy/selection/ucb1_tuned.rs
index 971b601..04d029b 100644
--- a/src/policy/selection/ucb1_tuned.rs
+++ b/src/policy/selection/ucb1_tuned.rs
@@ -36,15 +36,15 @@ pub fn select_best_child(
                 exploration_constant,
                 parent_visits,
                 node_a.visits,
-                node_a.rewards(player),
                 node_a.reward_average(player),
+                node_a.reward_variance(player),
             );
             let ucb_b = ucb1_tuned_value(
                 exploration_constant,
                 parent_visits,
                 node_b.visits,
-                node_b.rewards(player),
                 node_b.reward_average(player),
+                node_b.reward_variance(player),
             );
             ucb_a.partial_cmp(&ucb_b).unwrap()
         })
@@ -57,41 +57,31 @@ pub fn ucb1_tuned_value(
     exploration_constant: f64,
     parent_visits: u64,
     child_visits: u64,
-    child_rewards: Option<&Vec<RewardVal>>,
     reward_avg: RewardVal,
+    reward_variance: f64,
 ) -> RewardVal {
-    match child_rewards {
-        None => {
-            RewardVal::INFINITY // Always explore nodes that have never been visited
-        }
-        Some(child_rewards) => {
-            if child_visits == 0 {
-                RewardVal::INFINITY // Always explore nodes that have never been visited
-            } else {
-                let parent_visits: RewardVal = parent_visits as RewardVal;
-                let child_visits: RewardVal = child_visits as RewardVal;
+    if child_visits == 0 {
+        RewardVal::INFINITY // Always explore nodes that have never been visited
+    } else {
+        let parent_visits: RewardVal = parent_visits as RewardVal;
+        let child_visits: RewardVal = child_visits as RewardVal;
 
-                // N: number of visits to the parent node
-                // n: number of visits to the child node
-                // x_i: reward of the ith visit to the child node
-                // X: average reward of the child
-                // C: exploration constant
-                //
-                // UCB1-Tuned = X + C * sqrt(Ln(parent_visits) / child_visits * min(1/4, V_n)
-                // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
-                let exploitation = reward_avg;
-                let mut variance = (child_rewards.iter().map(|&x| x * x).sum::<RewardVal>()
-                    / child_visits)
-                    - (reward_avg * reward_avg)
-                    + (2.0 * parent_visits.ln() / child_visits).sqrt();
-                if variance > 0.25 {
-                    variance = 0.25;
-                }
-                let exploration =
-                    exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
-
-                exploitation + exploration
-            }
+        // N: number of visits to the parent node
+        // n: number of visits to the child node
+        // x_i: reward of the ith visit to the child node
+        // X: average reward of the child
+        // C: exploration constant
+        //
+        // UCB1-Tuned = X + C * sqrt(ln(N) / n * min(1/4, V(n)))
+        // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
+        let exploitation = reward_avg;
+        let mut variance = reward_variance + (2.0 * parent_visits.ln() / child_visits).sqrt();
+        if variance > 0.25 {
+            variance = 0.25;
         }
+        let exploration =
+            exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
+
+        exploitation + exploration
     }
 }
diff --git a/src/tree/node.rs b/src/tree/node.rs
index 5acc5e0..9a48743 100644
--- a/src/tree/node.rs
+++ b/src/tree/node.rs
@@ -69,10 +69,10 @@ impl Node {
         }
     }
 
-    pub fn rewards(&self, player: &S::Player) -> Option<&Vec<RewardVal>> {
+    pub fn reward_variance(&self, player: &S::Player) -> f64 {
         match self.player_view.get(player) {
-            Some(pv) => Some(&pv.rewards),
-            None => None,
+            Some(pv) => pv.weighted_variance / self.visits as f64,
+            None => 0.0,
         }
     }
 
@@ -82,9 +82,10 @@
 
     pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
         let pv = self.player_view.entry(player).or_default();
-        pv.rewards.push(reward);
+        let prev_reward_average = pv.reward_average;
         pv.reward_sum += reward;
-        pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
+        pv.reward_average = pv.reward_sum / self.visits as f64;
+        pv.weighted_variance += (reward - prev_reward_average) * (reward - pv.reward_average);
     }
 }
 
@@ -98,8 +99,11 @@ pub struct PlayerRewardView {
     /// The average reward from simulations through this node, often called the node value
     pub reward_average: RewardVal,
 
-    /// The rewards we have gotten so far for simulations through this node
-    pub rewards: Vec<RewardVal>,
+    /// The weighted variance from simulations through this node
+    ///
+    /// This is used to calculate online sample variance.
+    /// See Donald E. Knuth. Seminumerical Algorithms, volume 2 of The Art of Computer Programming, chapter 4.2.2, page 232
+    pub weighted_variance: f64,
 }
 
 impl Default for PlayerRewardView {
@@ -107,7 +111,7 @@
         PlayerRewardView {
             reward_sum: 0.0,
             reward_average: 0.0,
-            rewards: Vec::new(),
+            weighted_variance: 0.0,
         }
     }
 }
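
Note (not part of the patch): a minimal, self-contained sketch of the idea behind the change. The Welford/Knuth-style update added in record_player_reward accumulates the same variance that the removed code recomputed from the stored rewards vector, so ucb1_tuned_value no longer needs the per-visit reward history. The reward values, the free function below, and the assumption that RewardVal is an alias for f64 are illustrative only; this is not the crate's Node/PlayerRewardView API.

// Standalone sketch: online (Welford/Knuth) variance vs. the removed batch formula,
// feeding the UCB1-Tuned value the same way the patched ucb1_tuned_value does.
fn ucb1_tuned_value(c: f64, parent_visits: u64, child_visits: u64, avg: f64, var: f64) -> f64 {
    if child_visits == 0 {
        return f64::INFINITY; // always explore unvisited nodes
    }
    let (n_parent, n_child) = (parent_visits as f64, child_visits as f64);
    // V(n) = variance + sqrt(2*ln(N)/n), capped at 1/4
    let v = (var + (2.0 * n_parent.ln() / n_child).sqrt()).min(0.25);
    avg + c * (n_parent.ln() / n_child * v).sqrt()
}

fn main() {
    let rewards = [1.0_f64, 0.0, 0.5, 1.0, 0.25]; // made-up simulation rewards

    // Online accumulation, mirroring the new record_player_reward.
    let (mut sum, mut mean, mut weighted_variance) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (i, &x) in rewards.iter().enumerate() {
        let prev_mean = mean;
        sum += x;
        mean = sum / (i + 1) as f64;
        weighted_variance += (x - prev_mean) * (x - mean);
    }
    let online_var = weighted_variance / rewards.len() as f64; // reward_variance()

    // Batch variance the removed code computed: sum(x_i^2)/n - X^2.
    let n = rewards.len() as f64;
    let batch_var = rewards.iter().map(|&x| x * x).sum::<f64>() / n - mean * mean;
    assert!((online_var - batch_var).abs() < 1e-12);

    let ucb = ucb1_tuned_value(1.0, 100, rewards.len() as u64, mean, online_var);
    println!("variance = {online_var}, UCB1-Tuned = {ucb}");
}

The practical gain is that each PlayerRewardView now carries two floats (reward_average and weighted_variance) instead of a growing Vec of every reward, while reward_variance still yields the same population variance the old sum-of-squares computation produced.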