Improve UCB1-Tuned performance
We now calculate the sample variance of the rewards online and store the running value in the node. This greatly reduces the number of summations needed to compute the variance during the selection phase. The extra bookkeeping does add a small per-update cost for the other selection algorithms, but it is not substantial.
parent 44ef9ebdd8
commit 854772a63d
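For reference, the online update adopted here is the running-variance recurrence attributed to Welford and described in the Knuth citation added to the doc comment below. The following is a minimal, self-contained sketch of that technique; RunningStats and its fields are illustrative names, not this crate's actual types.

// Sketch of the online (Welford-style) variance update used by this commit.
// Storing the running sum of squared deviations lets the selection phase read
// the variance in O(1) instead of re-summing every recorded reward.
struct RunningStats {
    count: u64,
    sum: f64,
    mean: f64,
    // Running sum of squared deviations from the mean (often called M2).
    sum_sq_dev: f64,
}

impl RunningStats {
    fn new() -> Self {
        RunningStats { count: 0, sum: 0.0, mean: 0.0, sum_sq_dev: 0.0 }
    }

    fn record(&mut self, reward: f64) {
        self.count += 1;
        let prev_mean = self.mean;
        self.sum += reward;
        self.mean = self.sum / self.count as f64;
        // Same shape as the record_player_reward change in the diff:
        // M2 += (x - old_mean) * (x - new_mean)
        self.sum_sq_dev += (reward - prev_mean) * (reward - self.mean);
    }

    // Population variance; returns 0.0 before any samples are recorded.
    fn variance(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.sum_sq_dev / self.count as f64
        }
    }
}

fn main() {
    let mut stats = RunningStats::new();
    for reward in [1.0, 0.0, 1.0, 1.0, 0.0] {
        stats.record(reward);
    }
    // For this sample: mean = 0.6, variance = 0.24
    println!("mean = {}, variance = {}", stats.mean, stats.variance());
}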
@@ -36,15 +36,15 @@ pub fn select_best_child<S: GameState>(
             exploration_constant,
             parent_visits,
             node_a.visits,
-            node_a.rewards(player),
             node_a.reward_average(player),
+            node_a.reward_variance(player),
         );
         let ucb_b = ucb1_tuned_value(
             exploration_constant,
             parent_visits,
             node_b.visits,
-            node_b.rewards(player),
             node_b.reward_average(player),
+            node_b.reward_variance(player),
         );
         ucb_a.partial_cmp(&ucb_b).unwrap()
     })
@@ -57,14 +57,9 @@ pub fn ucb1_tuned_value(
     exploration_constant: f64,
     parent_visits: u64,
     child_visits: u64,
-    child_rewards: Option<&Vec<RewardVal>>,
     reward_avg: RewardVal,
+    reward_variance: f64,
 ) -> RewardVal {
-    match child_rewards {
-        None => {
-            RewardVal::INFINITY // Always explore nodes that have never been visited
-        }
-        Some(child_rewards) => {
     if child_visits == 0 {
         RewardVal::INFINITY // Always explore nodes that have never been visited
     } else {
@@ -80,10 +75,7 @@ pub fn ucb1_tuned_value(
         // UCB1-Tuned = X + C * sqrt(Ln(parent_visits) / child_visits * min(1/4, V_n)
         // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
         let exploitation = reward_avg;
-        let mut variance = (child_rewards.iter().map(|&x| x * x).sum::<RewardVal>()
-            / child_visits)
-            - (reward_avg * reward_avg)
-            + (2.0 * parent_visits.ln() / child_visits).sqrt();
+        let mut variance = reward_variance + (2.0 * parent_visits.ln() / child_visits).sqrt();
         if variance > 0.25 {
             variance = 0.25;
         }
@@ -93,5 +85,3 @@ pub fn ucb1_tuned_value(
         exploitation + exploration
     }
 }
-    }
-}

@@ -69,10 +69,10 @@ impl<S: GameState> Node<S> {
         }
     }
 
-    pub fn rewards(&self, player: &S::Player) -> Option<&Vec<RewardVal>> {
+    pub fn reward_variance(&self, player: &S::Player) -> f64 {
         match self.player_view.get(player) {
-            Some(pv) => Some(&pv.rewards),
-            None => None,
+            Some(pv) => pv.weighted_variance / self.visits as f64,
+            None => 0.0,
         }
     }
 
@@ -82,9 +82,10 @@ impl<S: GameState> Node<S> {
 
     pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
         let pv = self.player_view.entry(player).or_default();
-        pv.rewards.push(reward);
+        let prev_reward_average = pv.reward_average;
         pv.reward_sum += reward;
-        pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
+        pv.reward_average = pv.reward_sum / self.visits as f64;
+        pv.weighted_variance += (reward - prev_reward_average) * (reward - pv.reward_average);
     }
 }
 
@@ -98,8 +99,11 @@ pub struct PlayerRewardView {
     /// The average reward from simulations through this node, often called the node value
     pub reward_average: RewardVal,
 
-    /// The rewards we have gotten so far for simulations through this node
-    pub rewards: Vec<RewardVal>,
+    /// The weighted variance from simulations through this node
+    ///
+    /// This is used to calculate online sample variance.
+    /// See Donald E. Knuth. Seminumerical Algorithms, volume 2 of The Art of Computer Programming, chapter 4.2.2, page 232
+    pub weighted_variance: f64,
 }
 
 impl Default for PlayerRewardView {
@@ -107,7 +111,7 @@ impl Default for PlayerRewardView {
         PlayerRewardView {
             reward_sum: 0.0,
            reward_average: 0.0,
-            rewards: Vec::new(),
+            weighted_variance: 0.0,
         }
     }
 }
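The hunks above omit the lines that turn the capped variance into the exploration term, so as a reading aid here is a hedged sketch of the full post-change computation, following the UCB1-Tuned formula quoted in the comments. The function name and body are illustrative, not the crate's exact code.

// Sketch of UCB1-Tuned using the precomputed variance, per the comments:
//   UCB1-Tuned = X + C * sqrt(ln(N) / n * min(1/4, V(n)))
//   V(n)       = sample variance + sqrt(2 * ln(N) / n)
type RewardVal = f64;

fn ucb1_tuned_value_sketch(
    exploration_constant: f64,
    parent_visits: u64,
    child_visits: u64,
    reward_avg: RewardVal,
    reward_variance: f64,
) -> RewardVal {
    if child_visits == 0 {
        return RewardVal::INFINITY; // always explore unvisited nodes
    }
    let n = child_visits as f64;
    let ln_parent = (parent_visits as f64).ln();

    // V(n): stored sample variance plus the confidence correction, capped at 1/4.
    let v = (reward_variance + (2.0 * ln_parent / n).sqrt()).min(0.25);

    let exploitation = reward_avg;
    let exploration = exploration_constant * (ln_parent / n * v).sqrt();
    exploitation + exploration
}

fn main() {
    // Example: a child visited 10 times out of 50 parent visits.
    let value = ucb1_tuned_value_sketch(1.0, 50, 10, 0.6, 0.24);
    println!("ucb1-tuned value = {value}");
}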
|
Loading…
x
Reference in New Issue
Block a user