Improve UCB1-Tuned performance
We now calculate the sample variance of the rewards online and store the running value in the node. This greatly reduces the number of summations needed to compute the variance during the selection phase. While this adds a small bookkeeping cost to the other selection algorithms, it is not substantial.
This commit is contained in:
parent 44ef9ebdd8
commit 854772a63d
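The bookkeeping described in the message is the standard online (Welford/Knuth) variance update: each new reward adjusts the running mean and a weighted sum of squared deviations, so the variance can be read back later without iterating over stored rewards. Below is a minimal, self-contained sketch of that idea under illustrative names; the struct and functions here are not the crate's actual types, only a stand-in for the accumulator this commit adds to each node.

// Sketch of the online variance bookkeeping used by this commit
// (illustrative names, not the crate's API). After each reward x,
// the running mean and the accumulator
//     m2 = sum((x_i - mean_{i-1}) * (x_i - mean_i))
// are updated; the variance is then m2 / n, with no pass over stored rewards.
#[derive(Default)]
struct OnlineStats {
    n: u64,
    sum: f64,
    mean: f64,
    m2: f64, // corresponds to `weighted_variance` in the commit
}

impl OnlineStats {
    fn record(&mut self, x: f64) {
        self.n += 1;
        let prev_mean = self.mean;
        self.sum += x;
        self.mean = self.sum / self.n as f64;
        self.m2 += (x - prev_mean) * (x - self.mean);
    }

    fn variance(&self) -> f64 {
        if self.n == 0 { 0.0 } else { self.m2 / self.n as f64 }
    }
}

fn main() {
    let mut stats = OnlineStats::default();
    for &x in &[1.0, 0.0, 1.0, 1.0, 0.0] {
        stats.record(x);
    }
    // Matches the batch formula sum(x_i^2)/n - mean^2 for these rewards (0.24).
    println!("mean = {}, variance = {}", stats.mean, stats.variance());
}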
@@ -36,15 +36,15 @@ pub fn select_best_child<S: GameState>(
             exploration_constant,
             parent_visits,
             node_a.visits,
-            node_a.rewards(player),
             node_a.reward_average(player),
+            node_a.reward_variance(player),
         );
         let ucb_b = ucb1_tuned_value(
             exploration_constant,
             parent_visits,
             node_b.visits,
-            node_b.rewards(player),
             node_b.reward_average(player),
+            node_b.reward_variance(player),
         );
         ucb_a.partial_cmp(&ucb_b).unwrap()
     })
@@ -57,41 +57,31 @@ pub fn ucb1_tuned_value(
     exploration_constant: f64,
     parent_visits: u64,
     child_visits: u64,
-    child_rewards: Option<&Vec<RewardVal>>,
     reward_avg: RewardVal,
+    reward_variance: f64,
 ) -> RewardVal {
-    match child_rewards {
-        None => {
-            RewardVal::INFINITY // Always explore nodes that have never been visited
-        }
-        Some(child_rewards) => {
-            if child_visits == 0 {
-                RewardVal::INFINITY // Always explore nodes that have never been visited
-            } else {
-                let parent_visits: RewardVal = parent_visits as RewardVal;
-                let child_visits: RewardVal = child_visits as RewardVal;
-
-                // N: number of visits to the parent node
-                // n: number of visits to the child node
-                // x_i: reward of the ith visit to the child node
-                // X: average reward of the child
-                // C: exploration constant
-                //
-                // UCB1-Tuned = X + C * sqrt(ln(parent_visits) / child_visits * min(1/4, V(n)))
-                // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
-                let exploitation = reward_avg;
-                let mut variance = (child_rewards.iter().map(|&x| x * x).sum::<RewardVal>()
-                    / child_visits)
-                    - (reward_avg * reward_avg)
-                    + (2.0 * parent_visits.ln() / child_visits).sqrt();
-                if variance > 0.25 {
-                    variance = 0.25;
-                }
-                let exploration =
-                    exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
-
-                exploitation + exploration
-            }
-        }
+    if child_visits == 0 {
+        RewardVal::INFINITY // Always explore nodes that have never been visited
+    } else {
+        let parent_visits: RewardVal = parent_visits as RewardVal;
+        let child_visits: RewardVal = child_visits as RewardVal;
+
+        // N: number of visits to the parent node
+        // n: number of visits to the child node
+        // x_i: reward of the ith visit to the child node
+        // X: average reward of the child
+        // C: exploration constant
+        //
+        // UCB1-Tuned = X + C * sqrt(ln(parent_visits) / child_visits * min(1/4, V(n)))
+        // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
+        let exploitation = reward_avg;
+        let mut variance = reward_variance + (2.0 * parent_visits.ln() / child_visits).sqrt();
+        if variance > 0.25 {
+            variance = 0.25;
+        }
+        let exploration =
+            exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
+
+        exploitation + exploration
     }
 }
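For a concrete sense of how the updated formula behaves, the sketch below re-states the new body of ucb1_tuned_value as a standalone helper and evaluates it on example numbers; the helper name and the inputs are illustrative, not part of the crate.

// Illustrative re-statement of the UCB1-Tuned value computed above:
//   V(n) = reward_variance + sqrt(2 * ln(N) / n), capped at 1/4
//   UCB1-Tuned = X + C * sqrt(ln(N) / n * V(n))
fn ucb1_tuned_sketch(c: f64, parent_visits: f64, child_visits: f64, avg: f64, var: f64) -> f64 {
    if child_visits == 0.0 {
        return f64::INFINITY;
    }
    let v = (var + (2.0 * parent_visits.ln() / child_visits).sqrt()).min(0.25);
    avg + c * (parent_visits.ln() / child_visits * v).sqrt()
}

fn main() {
    // Example: N = 100 parent visits, n = 10 child visits, X = 0.6, variance = 0.24, C = 1.0.
    let value = ucb1_tuned_sketch(1.0, 100.0, 10.0, 0.6, 0.24);
    println!("UCB1-Tuned ≈ {value:.3}"); // ≈ 0.939; the variance term is capped at 0.25 here
}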
@@ -69,10 +69,10 @@ impl<S: GameState> Node<S> {
         }
     }
 
-    pub fn rewards(&self, player: &S::Player) -> Option<&Vec<RewardVal>> {
+    pub fn reward_variance(&self, player: &S::Player) -> f64 {
         match self.player_view.get(player) {
-            Some(pv) => Some(&pv.rewards),
-            None => None,
+            Some(pv) => pv.weighted_variance / self.visits as f64,
+            None => 0.0,
         }
     }
 
@@ -82,9 +82,10 @@ impl<S: GameState> Node<S> {
 
     pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
         let pv = self.player_view.entry(player).or_default();
-        pv.rewards.push(reward);
+        let prev_reward_average = pv.reward_average;
         pv.reward_sum += reward;
-        pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
+        pv.reward_average = pv.reward_sum / self.visits as f64;
+        pv.weighted_variance += (reward - prev_reward_average) * (reward - pv.reward_average);
     }
 }
 
@@ -98,8 +99,11 @@ pub struct PlayerRewardView
     /// The average reward from simulations through this node, often called the node value
     pub reward_average: RewardVal,
 
-    /// The rewards we have gotten so far for simulations through this node
-    pub rewards: Vec<RewardVal>,
+    /// The weighted variance from simulations through this node
+    ///
+    /// This is used to calculate online sample variance.
+    /// See Donald E. Knuth. Seminumerical Algorithms, volume 2 of The Art of Computer Programming, chapter 4.2.2, page 232
+    pub weighted_variance: f64,
 }
 
 impl Default for PlayerRewardView {
@@ -107,7 +111,7 @@ impl Default for PlayerRewardView {
         PlayerRewardView {
             reward_sum: 0.0,
             reward_average: 0.0,
-            rewards: Vec::new(),
+            weighted_variance: 0.0,
         }
     }
 }