WeightedBackfillPropagation is implemented

This commit is contained in:
David Kruger 2025-06-27 14:06:19 -07:00
parent 37b1f56f74
commit 6cc6e6a7ba
3 changed files with 38 additions and 12 deletions

View File

@ -29,7 +29,7 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
/// Creates a new instance with the given initial state and configuration
pub fn new(initial_state: S, config: &'conf MCTSConfig<S>) -> Self {
let mut arena: Arena<S> = Arena::new(config.tree_size_allocation);
let root: Node<S> = Node::new(initial_state.clone(), None, None);
let root: Node<S> = Node::new(initial_state.clone(), None, None, 0);
let root_id: usize = arena.add_node(root);
MCTS {
arena,
@ -100,9 +100,10 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
let parent: &Node<S> = self.arena.get_node(id);
let legal_actions: Vec<S::Action> = parent.state.get_legal_actions();
let parent_state: S = parent.state.clone();
let depth: usize = parent.depth + 1;
for action in legal_actions {
let state = parent_state.state_after_action(&action);
let new_node = Node::new(state, Some(action), Some(id));
let new_node = Node::new(state, Some(action), Some(id), depth);
let new_id = self.arena.add_node(new_node);
self.arena.get_node_mut(id).children.push(new_id);
}

View File

@ -66,10 +66,6 @@ fn standard_backprop<S: GameState>(
arena: &mut Arena<S>,
rewards: &HashMap<S::Player, RewardVal>,
) {
// TODO:
// - each node needs the perspective of the different players not just one view
// - e.g. reward_sum(player), reward_avg(player), rewards(player)[], visits(player)
// - we could make special version for 2-player zero-sum games like below
let mut current_id: usize = node_id;
loop {
let node = arena.get_node_mut(current_id);
@ -90,10 +86,35 @@ fn standard_backprop<S: GameState>(
}
fn weighted_backprop<S: GameState>(
_depth_factor: f64,
_node_id: usize,
_arena: &mut Arena<S>,
_rewards: &HashMap<S::Player, RewardVal>,
depth_factor: f64,
node_id: usize,
arena: &mut Arena<S>,
rewards: &HashMap<S::Player, RewardVal>,
) {
// TODO
let mut current_id: usize = node_id;
loop {
let node = arena.get_node_mut(current_id);
let player = node.state.get_current_player().clone();
let weight = weight_for_depth(depth_factor, node.depth);
match rewards.get(&player) {
Some(reward) => {
node.increment_visits();
node.record_player_reward(player, (*reward) * weight);
if let Some(parent_id) = node.parent {
current_id = parent_id;
} else {
break;
}
}
None => (),
}
}
}
/// Calculate the weight based on the current depth
///
/// The weight can be multiplied agains the reward to provide more or less
/// influence based on depth as determined by the depth factor
fn weight_for_depth(depth_factor: f64, depth: usize) -> f64 {
1.0 / (1.0 + depth_factor * depth as f64)
}

View File

@ -32,18 +32,22 @@ pub struct Node<S: GameState> {
/// The player's evaluation of the node
pub player_view: HashMap<S::Player, PlayerNodeView>,
/// The depth of the node in the tree, this is 0 for the root node
pub depth: usize,
/// The identifiers of children nodes, states reachable from this one
pub children: Vec<usize>,
}
impl<S: GameState> Node<S> {
pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>) -> Self {
pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>, depth: usize) -> Self {
Node {
state,
action,
parent,
visits: 0,
player_view: HashMap::with_capacity(2),
depth: depth,
children: Vec::new(),
}
}