From 6cc6e6a7ba5589b60656c2e1aec3f9f544e9abd6 Mon Sep 17 00:00:00 2001 From: David Kruger Date: Fri, 27 Jun 2025 14:06:19 -0700 Subject: [PATCH] WeightedBackfillPropagation is implemented --- src/mcts.rs | 5 +++-- src/policy/backprop/mod.rs | 39 +++++++++++++++++++++++++++++--------- src/tree/node.rs | 6 +++++- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index bd50edf..fb25f18 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -29,7 +29,7 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> { /// Creates a new instance with the given initial state and configuration pub fn new(initial_state: S, config: &'conf MCTSConfig) -> Self { let mut arena: Arena = Arena::new(config.tree_size_allocation); - let root: Node = Node::new(initial_state.clone(), None, None); + let root: Node = Node::new(initial_state.clone(), None, None, 0); let root_id: usize = arena.add_node(root); MCTS { arena, @@ -100,9 +100,10 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> { let parent: &Node = self.arena.get_node(id); let legal_actions: Vec = parent.state.get_legal_actions(); let parent_state: S = parent.state.clone(); + let depth: usize = parent.depth + 1; for action in legal_actions { let state = parent_state.state_after_action(&action); - let new_node = Node::new(state, Some(action), Some(id)); + let new_node = Node::new(state, Some(action), Some(id), depth); let new_id = self.arena.add_node(new_node); self.arena.get_node_mut(id).children.push(new_id); } diff --git a/src/policy/backprop/mod.rs b/src/policy/backprop/mod.rs index 4b07037..1c98a14 100644 --- a/src/policy/backprop/mod.rs +++ b/src/policy/backprop/mod.rs @@ -66,10 +66,6 @@ fn standard_backprop( arena: &mut Arena, rewards: &HashMap, ) { - // TODO: - // - each node needs the perspective of the different players not just one view - // - e.g. reward_sum(player), reward_avg(player), rewards(player)[], visits(player) - // - we could make special version for 2-player zero-sum games like below let mut current_id: usize = node_id; loop { let node = arena.get_node_mut(current_id); @@ -90,10 +86,35 @@ fn standard_backprop( } fn weighted_backprop( - _depth_factor: f64, - _node_id: usize, - _arena: &mut Arena, - _rewards: &HashMap, + depth_factor: f64, + node_id: usize, + arena: &mut Arena, + rewards: &HashMap, ) { - // TODO + let mut current_id: usize = node_id; + loop { + let node = arena.get_node_mut(current_id); + let player = node.state.get_current_player().clone(); + let weight = weight_for_depth(depth_factor, node.depth); + match rewards.get(&player) { + Some(reward) => { + node.increment_visits(); + node.record_player_reward(player, (*reward) * weight); + if let Some(parent_id) = node.parent { + current_id = parent_id; + } else { + break; + } + } + None => (), + } + } +} + +/// Calculate the weight based on the current depth +/// +/// The weight can be multiplied agains the reward to provide more or less +/// influence based on depth as determined by the depth factor +fn weight_for_depth(depth_factor: f64, depth: usize) -> f64 { + 1.0 / (1.0 + depth_factor * depth as f64) } diff --git a/src/tree/node.rs b/src/tree/node.rs index 36522fd..fb6b644 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -32,18 +32,22 @@ pub struct Node { /// The player's evaluation of the node pub player_view: HashMap, + /// The depth of the node in the tree, this is 0 for the root node + pub depth: usize, + /// The identifiers of children nodes, states reachable from this one pub children: Vec, } impl Node { - pub fn new(state: S, action: Option, parent: Option) -> Self { + pub fn new(state: S, action: Option, parent: Option, depth: usize) -> Self { Node { state, action, parent, visits: 0, player_view: HashMap::with_capacity(2), + depth: depth, children: Vec::new(), } }