WeightedBackfillPropagation is implemented
This commit is contained in:
parent
37b1f56f74
commit
6cc6e6a7ba
@ -29,7 +29,7 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
|||||||
/// Creates a new instance with the given initial state and configuration
|
/// Creates a new instance with the given initial state and configuration
|
||||||
pub fn new(initial_state: S, config: &'conf MCTSConfig<S>) -> Self {
|
pub fn new(initial_state: S, config: &'conf MCTSConfig<S>) -> Self {
|
||||||
let mut arena: Arena<S> = Arena::new(config.tree_size_allocation);
|
let mut arena: Arena<S> = Arena::new(config.tree_size_allocation);
|
||||||
let root: Node<S> = Node::new(initial_state.clone(), None, None);
|
let root: Node<S> = Node::new(initial_state.clone(), None, None, 0);
|
||||||
let root_id: usize = arena.add_node(root);
|
let root_id: usize = arena.add_node(root);
|
||||||
MCTS {
|
MCTS {
|
||||||
arena,
|
arena,
|
||||||
@ -100,9 +100,10 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
|||||||
let parent: &Node<S> = self.arena.get_node(id);
|
let parent: &Node<S> = self.arena.get_node(id);
|
||||||
let legal_actions: Vec<S::Action> = parent.state.get_legal_actions();
|
let legal_actions: Vec<S::Action> = parent.state.get_legal_actions();
|
||||||
let parent_state: S = parent.state.clone();
|
let parent_state: S = parent.state.clone();
|
||||||
|
let depth: usize = parent.depth + 1;
|
||||||
for action in legal_actions {
|
for action in legal_actions {
|
||||||
let state = parent_state.state_after_action(&action);
|
let state = parent_state.state_after_action(&action);
|
||||||
let new_node = Node::new(state, Some(action), Some(id));
|
let new_node = Node::new(state, Some(action), Some(id), depth);
|
||||||
let new_id = self.arena.add_node(new_node);
|
let new_id = self.arena.add_node(new_node);
|
||||||
self.arena.get_node_mut(id).children.push(new_id);
|
self.arena.get_node_mut(id).children.push(new_id);
|
||||||
}
|
}
|
||||||
|
@ -66,10 +66,6 @@ fn standard_backprop<S: GameState>(
|
|||||||
arena: &mut Arena<S>,
|
arena: &mut Arena<S>,
|
||||||
rewards: &HashMap<S::Player, RewardVal>,
|
rewards: &HashMap<S::Player, RewardVal>,
|
||||||
) {
|
) {
|
||||||
// TODO:
|
|
||||||
// - each node needs the perspective of the different players not just one view
|
|
||||||
// - e.g. reward_sum(player), reward_avg(player), rewards(player)[], visits(player)
|
|
||||||
// - we could make special version for 2-player zero-sum games like below
|
|
||||||
let mut current_id: usize = node_id;
|
let mut current_id: usize = node_id;
|
||||||
loop {
|
loop {
|
||||||
let node = arena.get_node_mut(current_id);
|
let node = arena.get_node_mut(current_id);
|
||||||
@ -90,10 +86,35 @@ fn standard_backprop<S: GameState>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn weighted_backprop<S: GameState>(
|
fn weighted_backprop<S: GameState>(
|
||||||
_depth_factor: f64,
|
depth_factor: f64,
|
||||||
_node_id: usize,
|
node_id: usize,
|
||||||
_arena: &mut Arena<S>,
|
arena: &mut Arena<S>,
|
||||||
_rewards: &HashMap<S::Player, RewardVal>,
|
rewards: &HashMap<S::Player, RewardVal>,
|
||||||
) {
|
) {
|
||||||
// TODO
|
let mut current_id: usize = node_id;
|
||||||
|
loop {
|
||||||
|
let node = arena.get_node_mut(current_id);
|
||||||
|
let player = node.state.get_current_player().clone();
|
||||||
|
let weight = weight_for_depth(depth_factor, node.depth);
|
||||||
|
match rewards.get(&player) {
|
||||||
|
Some(reward) => {
|
||||||
|
node.increment_visits();
|
||||||
|
node.record_player_reward(player, (*reward) * weight);
|
||||||
|
if let Some(parent_id) = node.parent {
|
||||||
|
current_id = parent_id;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the weight based on the current depth
|
||||||
|
///
|
||||||
|
/// The weight can be multiplied agains the reward to provide more or less
|
||||||
|
/// influence based on depth as determined by the depth factor
|
||||||
|
fn weight_for_depth(depth_factor: f64, depth: usize) -> f64 {
|
||||||
|
1.0 / (1.0 + depth_factor * depth as f64)
|
||||||
}
|
}
|
||||||
|
@ -32,18 +32,22 @@ pub struct Node<S: GameState> {
|
|||||||
/// The player's evaluation of the node
|
/// The player's evaluation of the node
|
||||||
pub player_view: HashMap<S::Player, PlayerNodeView>,
|
pub player_view: HashMap<S::Player, PlayerNodeView>,
|
||||||
|
|
||||||
|
/// The depth of the node in the tree, this is 0 for the root node
|
||||||
|
pub depth: usize,
|
||||||
|
|
||||||
/// The identifiers of children nodes, states reachable from this one
|
/// The identifiers of children nodes, states reachable from this one
|
||||||
pub children: Vec<usize>,
|
pub children: Vec<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: GameState> Node<S> {
|
impl<S: GameState> Node<S> {
|
||||||
pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>) -> Self {
|
pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>, depth: usize) -> Self {
|
||||||
Node {
|
Node {
|
||||||
state,
|
state,
|
||||||
action,
|
action,
|
||||||
parent,
|
parent,
|
||||||
visits: 0,
|
visits: 0,
|
||||||
player_view: HashMap::with_capacity(2),
|
player_view: HashMap::with_capacity(2),
|
||||||
|
depth: depth,
|
||||||
children: Vec::new(),
|
children: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user