WeightedBackfillPropagation is implemented
This commit is contained in:
parent
37b1f56f74
commit
6cc6e6a7ba
@ -29,7 +29,7 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
||||
/// Creates a new instance with the given initial state and configuration
|
||||
pub fn new(initial_state: S, config: &'conf MCTSConfig<S>) -> Self {
|
||||
let mut arena: Arena<S> = Arena::new(config.tree_size_allocation);
|
||||
let root: Node<S> = Node::new(initial_state.clone(), None, None);
|
||||
let root: Node<S> = Node::new(initial_state.clone(), None, None, 0);
|
||||
let root_id: usize = arena.add_node(root);
|
||||
MCTS {
|
||||
arena,
|
||||
@ -100,9 +100,10 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
||||
let parent: &Node<S> = self.arena.get_node(id);
|
||||
let legal_actions: Vec<S::Action> = parent.state.get_legal_actions();
|
||||
let parent_state: S = parent.state.clone();
|
||||
let depth: usize = parent.depth + 1;
|
||||
for action in legal_actions {
|
||||
let state = parent_state.state_after_action(&action);
|
||||
let new_node = Node::new(state, Some(action), Some(id));
|
||||
let new_node = Node::new(state, Some(action), Some(id), depth);
|
||||
let new_id = self.arena.add_node(new_node);
|
||||
self.arena.get_node_mut(id).children.push(new_id);
|
||||
}
|
||||
|
@ -66,10 +66,6 @@ fn standard_backprop<S: GameState>(
|
||||
arena: &mut Arena<S>,
|
||||
rewards: &HashMap<S::Player, RewardVal>,
|
||||
) {
|
||||
// TODO:
|
||||
// - each node needs the perspective of the different players not just one view
|
||||
// - e.g. reward_sum(player), reward_avg(player), rewards(player)[], visits(player)
|
||||
// - we could make special version for 2-player zero-sum games like below
|
||||
let mut current_id: usize = node_id;
|
||||
loop {
|
||||
let node = arena.get_node_mut(current_id);
|
||||
@ -90,10 +86,35 @@ fn standard_backprop<S: GameState>(
|
||||
}
|
||||
|
||||
fn weighted_backprop<S: GameState>(
|
||||
_depth_factor: f64,
|
||||
_node_id: usize,
|
||||
_arena: &mut Arena<S>,
|
||||
_rewards: &HashMap<S::Player, RewardVal>,
|
||||
depth_factor: f64,
|
||||
node_id: usize,
|
||||
arena: &mut Arena<S>,
|
||||
rewards: &HashMap<S::Player, RewardVal>,
|
||||
) {
|
||||
// TODO
|
||||
let mut current_id: usize = node_id;
|
||||
loop {
|
||||
let node = arena.get_node_mut(current_id);
|
||||
let player = node.state.get_current_player().clone();
|
||||
let weight = weight_for_depth(depth_factor, node.depth);
|
||||
match rewards.get(&player) {
|
||||
Some(reward) => {
|
||||
node.increment_visits();
|
||||
node.record_player_reward(player, (*reward) * weight);
|
||||
if let Some(parent_id) = node.parent {
|
||||
current_id = parent_id;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the weight based on the current depth
|
||||
///
|
||||
/// The weight can be multiplied agains the reward to provide more or less
|
||||
/// influence based on depth as determined by the depth factor
|
||||
fn weight_for_depth(depth_factor: f64, depth: usize) -> f64 {
|
||||
1.0 / (1.0 + depth_factor * depth as f64)
|
||||
}
|
||||
|
@ -32,18 +32,22 @@ pub struct Node<S: GameState> {
|
||||
/// The player's evaluation of the node
|
||||
pub player_view: HashMap<S::Player, PlayerNodeView>,
|
||||
|
||||
/// The depth of the node in the tree, this is 0 for the root node
|
||||
pub depth: usize,
|
||||
|
||||
/// The identifiers of children nodes, states reachable from this one
|
||||
pub children: Vec<usize>,
|
||||
}
|
||||
|
||||
impl<S: GameState> Node<S> {
|
||||
pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>) -> Self {
|
||||
pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>, depth: usize) -> Self {
|
||||
Node {
|
||||
state,
|
||||
action,
|
||||
parent,
|
||||
visits: 0,
|
||||
player_view: HashMap::with_capacity(2),
|
||||
depth: depth,
|
||||
children: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user