From 6a338182389a8847615d71dd8d51275df144a64e Mon Sep 17 00:00:00 2001 From: David Kruger Date: Fri, 27 Jun 2025 16:07:33 -0700 Subject: [PATCH] We propagate the reward for both sides The AI now properly choses the optimal path for the active player --- src/policy/backprop/mod.rs | 39 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/src/policy/backprop/mod.rs b/src/policy/backprop/mod.rs index 1c98a14..7980461 100644 --- a/src/policy/backprop/mod.rs +++ b/src/policy/backprop/mod.rs @@ -69,18 +69,14 @@ fn standard_backprop( let mut current_id: usize = node_id; loop { let node = arena.get_node_mut(current_id); - let player = node.state.get_current_player().clone(); - match rewards.get(&player) { - Some(reward) => { - node.increment_visits(); - node.record_player_reward(player, *reward); - if let Some(parent_id) = node.parent { - current_id = parent_id; - } else { - break; - } - } - None => (), + node.increment_visits(); + for (player, reward) in rewards.iter() { + node.record_player_reward(player.clone(), *reward); + } + if let Some(parent_id) = node.parent { + current_id = parent_id; + } else { + break; } } } @@ -94,19 +90,14 @@ fn weighted_backprop( let mut current_id: usize = node_id; loop { let node = arena.get_node_mut(current_id); - let player = node.state.get_current_player().clone(); let weight = weight_for_depth(depth_factor, node.depth); - match rewards.get(&player) { - Some(reward) => { - node.increment_visits(); - node.record_player_reward(player, (*reward) * weight); - if let Some(parent_id) = node.parent { - current_id = parent_id; - } else { - break; - } - } - None => (), + for (player, reward) in rewards.iter() { + node.record_player_reward(player.clone(), (*reward) * weight); + } + if let Some(parent_id) = node.parent { + current_id = parent_id; + } else { + break; } } }