From 854772a63dd28c642b74997bc99929833aed384a Mon Sep 17 00:00:00 2001 From: David Kruger Date: Sun, 29 Jun 2025 18:43:52 -0700 Subject: [PATCH] Improve UCB1-Tuned performance We calculate the sample variance of the rewards online storing the value in the node. This greatly reduces the amount of summations that need to be done to calculate the variance during the selection phase. While this burdens other selection algorithms, the cost is not substantial. --- src/policy/selection/ucb1_tuned.rs | 58 +++++++++++++----------------- src/tree/node.rs | 20 ++++++----- 2 files changed, 36 insertions(+), 42 deletions(-) diff --git a/src/policy/selection/ucb1_tuned.rs b/src/policy/selection/ucb1_tuned.rs index 971b601..04d029b 100644 --- a/src/policy/selection/ucb1_tuned.rs +++ b/src/policy/selection/ucb1_tuned.rs @@ -36,15 +36,15 @@ pub fn select_best_child( exploration_constant, parent_visits, node_a.visits, - node_a.rewards(player), node_a.reward_average(player), + node_a.reward_variance(player), ); let ucb_b = ucb1_tuned_value( exploration_constant, parent_visits, node_b.visits, - node_b.rewards(player), node_b.reward_average(player), + node_b.reward_variance(player), ); ucb_a.partial_cmp(&ucb_b).unwrap() }) @@ -57,41 +57,31 @@ pub fn ucb1_tuned_value( exploration_constant: f64, parent_visits: u64, child_visits: u64, - child_rewards: Option<&Vec>, reward_avg: RewardVal, + reward_variance: f64, ) -> RewardVal { - match child_rewards { - None => { - RewardVal::INFINITY // Always explore nodes that have never been visited - } - Some(child_rewards) => { - if child_visits == 0 { - RewardVal::INFINITY // Always explore nodes that have never been visited - } else { - let parent_visits: RewardVal = parent_visits as RewardVal; - let child_visits: RewardVal = child_visits as RewardVal; + if child_visits == 0 { + RewardVal::INFINITY // Always explore nodes that have never been visited + } else { + let parent_visits: RewardVal = parent_visits as RewardVal; + let child_visits: RewardVal = child_visits as RewardVal; - // N: number of visits to the parent node - // n: number of visits to the child node - // x_i: reward of the ith visit to the child node - // X: average reward of the child - // C: exploration constant - // - // UCB1-Tuned = X + C * sqrt(Ln(parent_visits) / child_visits * min(1/4, V_n) - // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n) - let exploitation = reward_avg; - let mut variance = (child_rewards.iter().map(|&x| x * x).sum::() - / child_visits) - - (reward_avg * reward_avg) - + (2.0 * parent_visits.ln() / child_visits).sqrt(); - if variance > 0.25 { - variance = 0.25; - } - let exploration = - exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt(); - - exploitation + exploration - } + // N: number of visits to the parent node + // n: number of visits to the child node + // x_i: reward of the ith visit to the child node + // X: average reward of the child + // C: exploration constant + // + // UCB1-Tuned = X + C * sqrt(Ln(parent_visits) / child_visits * min(1/4, V_n) + // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n) + let exploitation = reward_avg; + let mut variance = reward_variance + (2.0 * parent_visits.ln() / child_visits).sqrt(); + if variance > 0.25 { + variance = 0.25; } + let exploration = + exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt(); + + exploitation + exploration } } diff --git a/src/tree/node.rs b/src/tree/node.rs index 5acc5e0..9a48743 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -69,10 +69,10 @@ impl Node { } } - pub fn rewards(&self, player: &S::Player) -> Option<&Vec> { + pub fn reward_variance(&self, player: &S::Player) -> f64 { match self.player_view.get(player) { - Some(pv) => Some(&pv.rewards), - None => None, + Some(pv) => pv.weighted_variance / self.visits as f64, + None => 0.0, } } @@ -82,9 +82,10 @@ impl Node { pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) { let pv = self.player_view.entry(player).or_default(); - pv.rewards.push(reward); + let prev_reward_average = pv.reward_average; pv.reward_sum += reward; - pv.reward_average = pv.reward_sum / pv.rewards.len() as f64; + pv.reward_average = pv.reward_sum / self.visits as f64; + pv.weighted_variance += (reward - prev_reward_average) * (reward - pv.reward_average); } } @@ -98,8 +99,11 @@ pub struct PlayerRewardView { /// The average reward from simulations through this node, often called the node value pub reward_average: RewardVal, - /// The rewards we have gotten so far for simulations through this node - pub rewards: Vec, + /// The weighted variance from simulations through this node + /// + /// This is used to calculate online sample variance. + /// See Donald E. Knuth. Seminumerical Algorithms, volume 2 of The Art of Computer Programming, chapter 4.2.2, page 232 + pub weighted_variance: f64, } impl Default for PlayerRewardView { @@ -107,7 +111,7 @@ impl Default for PlayerRewardView { PlayerRewardView { reward_sum: 0.0, reward_average: 0.0, - rewards: Vec::new(), + weighted_variance: 0.0, } } }