Improve UCB1-Tuned performance

We calculate the sample variance of the rewards online storing the value
in the node. This greatly reduces the amount of summations that need to
be done to calculate the variance during the selection phase.

While this burdens other selection algorithms, the cost is not
substantial.
This commit is contained in:
David Kruger 2025-06-29 18:43:52 -07:00
parent 44ef9ebdd8
commit 854772a63d
2 changed files with 36 additions and 42 deletions

View File

@ -36,15 +36,15 @@ pub fn select_best_child<S: GameState>(
exploration_constant, exploration_constant,
parent_visits, parent_visits,
node_a.visits, node_a.visits,
node_a.rewards(player),
node_a.reward_average(player), node_a.reward_average(player),
node_a.reward_variance(player),
); );
let ucb_b = ucb1_tuned_value( let ucb_b = ucb1_tuned_value(
exploration_constant, exploration_constant,
parent_visits, parent_visits,
node_b.visits, node_b.visits,
node_b.rewards(player),
node_b.reward_average(player), node_b.reward_average(player),
node_b.reward_variance(player),
); );
ucb_a.partial_cmp(&ucb_b).unwrap() ucb_a.partial_cmp(&ucb_b).unwrap()
}) })
@ -57,41 +57,31 @@ pub fn ucb1_tuned_value(
exploration_constant: f64, exploration_constant: f64,
parent_visits: u64, parent_visits: u64,
child_visits: u64, child_visits: u64,
child_rewards: Option<&Vec<RewardVal>>,
reward_avg: RewardVal, reward_avg: RewardVal,
reward_variance: f64,
) -> RewardVal { ) -> RewardVal {
match child_rewards { if child_visits == 0 {
None => { RewardVal::INFINITY // Always explore nodes that have never been visited
RewardVal::INFINITY // Always explore nodes that have never been visited } else {
} let parent_visits: RewardVal = parent_visits as RewardVal;
Some(child_rewards) => { let child_visits: RewardVal = child_visits as RewardVal;
if child_visits == 0 {
RewardVal::INFINITY // Always explore nodes that have never been visited
} else {
let parent_visits: RewardVal = parent_visits as RewardVal;
let child_visits: RewardVal = child_visits as RewardVal;
// N: number of visits to the parent node // N: number of visits to the parent node
// n: number of visits to the child node // n: number of visits to the child node
// x_i: reward of the ith visit to the child node // x_i: reward of the ith visit to the child node
// X: average reward of the child // X: average reward of the child
// C: exploration constant // C: exploration constant
// //
// UCB1-Tuned = X + C * sqrt(Ln(parent_visits) / child_visits * min(1/4, V_n) // UCB1-Tuned = X + C * sqrt(Ln(parent_visits) / child_visits * min(1/4, V_n)
// V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n) // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
let exploitation = reward_avg; let exploitation = reward_avg;
let mut variance = (child_rewards.iter().map(|&x| x * x).sum::<RewardVal>() let mut variance = reward_variance + (2.0 * parent_visits.ln() / child_visits).sqrt();
/ child_visits) if variance > 0.25 {
- (reward_avg * reward_avg) variance = 0.25;
+ (2.0 * parent_visits.ln() / child_visits).sqrt();
if variance > 0.25 {
variance = 0.25;
}
let exploration =
exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
exploitation + exploration
}
} }
let exploration =
exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
exploitation + exploration
} }
} }

View File

@ -69,10 +69,10 @@ impl<S: GameState> Node<S> {
} }
} }
pub fn rewards(&self, player: &S::Player) -> Option<&Vec<RewardVal>> { pub fn reward_variance(&self, player: &S::Player) -> f64 {
match self.player_view.get(player) { match self.player_view.get(player) {
Some(pv) => Some(&pv.rewards), Some(pv) => pv.weighted_variance / self.visits as f64,
None => None, None => 0.0,
} }
} }
@ -82,9 +82,10 @@ impl<S: GameState> Node<S> {
pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) { pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
let pv = self.player_view.entry(player).or_default(); let pv = self.player_view.entry(player).or_default();
pv.rewards.push(reward); let prev_reward_average = pv.reward_average;
pv.reward_sum += reward; pv.reward_sum += reward;
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64; pv.reward_average = pv.reward_sum / self.visits as f64;
pv.weighted_variance += (reward - prev_reward_average) * (reward - pv.reward_average);
} }
} }
@ -98,8 +99,11 @@ pub struct PlayerRewardView {
/// The average reward from simulations through this node, often called the node value /// The average reward from simulations through this node, often called the node value
pub reward_average: RewardVal, pub reward_average: RewardVal,
/// The rewards we have gotten so far for simulations through this node /// The weighted variance from simulations through this node
pub rewards: Vec<RewardVal>, ///
/// This is used to calculate online sample variance.
/// See Donald E. Knuth. Seminumerical Algorithms, volume 2 of The Art of Computer Programming, chapter 4.2.2, page 232
pub weighted_variance: f64,
} }
impl Default for PlayerRewardView { impl Default for PlayerRewardView {
@ -107,7 +111,7 @@ impl Default for PlayerRewardView {
PlayerRewardView { PlayerRewardView {
reward_sum: 0.0, reward_sum: 0.0,
reward_average: 0.0, reward_average: 0.0,
rewards: Vec::new(), weighted_variance: 0.0,
} }
} }
} }