Adding a basic print_tree function to visualize the MCTS search tree
This probably is not a good thing to run on a very large tree.
This commit is contained in:
parent
6cc6e6a7ba
commit
b80f039b93
@ -40,6 +40,7 @@ fn main() {
|
|||||||
// Find the best move
|
// Find the best move
|
||||||
match mcts.search() {
|
match mcts.search() {
|
||||||
Ok(action) => {
|
Ok(action) => {
|
||||||
|
mcts.print_tree();
|
||||||
println!(
|
println!(
|
||||||
"AI chooses: {} (row {}, col {})",
|
"AI chooses: {} (row {}, col {})",
|
||||||
action.index,
|
action.index,
|
||||||
|
@ -6,6 +6,7 @@ use crate::policy::simulation::simulate_reward;
|
|||||||
use crate::state::GameState;
|
use crate::state::GameState;
|
||||||
use crate::tree::arena::Arena;
|
use crate::tree::arena::Arena;
|
||||||
use crate::tree::node::{Node, RewardVal};
|
use crate::tree::node::{Node, RewardVal};
|
||||||
|
use crate::tree::print_tree;
|
||||||
use rand::prelude::SliceRandom;
|
use rand::prelude::SliceRandom;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
@ -130,6 +131,10 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
|||||||
None => Err(MCTSError::NoBestAction),
|
None => Err(MCTSError::NoBestAction),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn print_tree(&self) {
|
||||||
|
print_tree(self.root_id, &self.arena)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Errors returned by the MCTS algorithm
|
/// Errors returned by the MCTS algorithm
|
||||||
|
@ -7,7 +7,7 @@ use std::hash::Hash;
|
|||||||
///
|
///
|
||||||
/// When leveraging MCTS for your game, you must implement this trait to provide
|
/// When leveraging MCTS for your game, you must implement this trait to provide
|
||||||
/// the specifics for your game.
|
/// the specifics for your game.
|
||||||
pub trait GameState: Clone {
|
pub trait GameState: Clone + Debug {
|
||||||
/// The type of actions that can be taken in the game
|
/// The type of actions that can be taken in the game
|
||||||
type Action: Action;
|
type Action: Action;
|
||||||
|
|
||||||
|
@ -1,2 +1,27 @@
|
|||||||
pub mod arena;
|
pub mod arena;
|
||||||
pub mod node;
|
pub mod node;
|
||||||
|
|
||||||
|
use crate::state::GameState;
|
||||||
|
use crate::tree::arena::Arena;
|
||||||
|
use crate::tree::node::Node;
|
||||||
|
|
||||||
|
pub fn print_tree<S: GameState>(node_id: usize, arena: &Arena<S>) {
|
||||||
|
let mut to_print: Vec<usize> = Vec::new();
|
||||||
|
to_print.push(node_id);
|
||||||
|
while let Some(node_id) = to_print.pop() {
|
||||||
|
let node: &Node<S> = arena.get_node(node_id);
|
||||||
|
if node.depth > 0 {
|
||||||
|
for _ in 0..node.depth - 1 {
|
||||||
|
print!("| ")
|
||||||
|
}
|
||||||
|
print!("|- ");
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"{:?} a:{:?} v:{} {:?}",
|
||||||
|
node_id, node.action, node.visits, node.player_view
|
||||||
|
);
|
||||||
|
for child_id in node.children.iter().rev() {
|
||||||
|
to_print.push(*child_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Debug;
|
|
||||||
|
|
||||||
use crate::state::GameState;
|
use crate::state::GameState;
|
||||||
|
|
||||||
@ -15,7 +14,7 @@ pub type RewardVal = f64;
|
|||||||
///
|
///
|
||||||
/// This class is not thread safe, as the library does not provide for parallel
|
/// This class is not thread safe, as the library does not provide for parallel
|
||||||
/// search.
|
/// search.
|
||||||
#[derive(Debug)]
|
#[derive(std::fmt::Debug)]
|
||||||
pub struct Node<S: GameState> {
|
pub struct Node<S: GameState> {
|
||||||
/// The game state at the given node, after `action`
|
/// The game state at the given node, after `action`
|
||||||
pub state: S,
|
pub state: S,
|
||||||
@ -30,7 +29,7 @@ pub struct Node<S: GameState> {
|
|||||||
pub visits: u64,
|
pub visits: u64,
|
||||||
|
|
||||||
/// The player's evaluation of the node
|
/// The player's evaluation of the node
|
||||||
pub player_view: HashMap<S::Player, PlayerNodeView>,
|
pub player_view: HashMap<S::Player, PlayerRewardView>,
|
||||||
|
|
||||||
/// The depth of the node in the tree, this is 0 for the root node
|
/// The depth of the node in the tree, this is 0 for the root node
|
||||||
pub depth: usize,
|
pub depth: usize,
|
||||||
@ -85,7 +84,7 @@ impl<S: GameState> Node<S> {
|
|||||||
let pv = self
|
let pv = self
|
||||||
.player_view
|
.player_view
|
||||||
.entry(player)
|
.entry(player)
|
||||||
.or_insert(PlayerNodeView::default());
|
.or_insert(PlayerRewardView::default());
|
||||||
pv.rewards.push(reward);
|
pv.rewards.push(reward);
|
||||||
pv.reward_sum += reward;
|
pv.reward_sum += reward;
|
||||||
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
|
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
|
||||||
@ -95,8 +94,7 @@ impl<S: GameState> Node<S> {
|
|||||||
/// A player's specific perspective of a node's value
|
/// A player's specific perspective of a node's value
|
||||||
///
|
///
|
||||||
/// Each player has their own idea of the value of a node.
|
/// Each player has their own idea of the value of a node.
|
||||||
#[derive(Debug)]
|
pub struct PlayerRewardView {
|
||||||
pub struct PlayerNodeView {
|
|
||||||
/// The total reward from simulations through this node
|
/// The total reward from simulations through this node
|
||||||
pub reward_sum: RewardVal,
|
pub reward_sum: RewardVal,
|
||||||
|
|
||||||
@ -107,12 +105,19 @@ pub struct PlayerNodeView {
|
|||||||
pub rewards: Vec<RewardVal>,
|
pub rewards: Vec<RewardVal>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for PlayerNodeView {
|
impl Default for PlayerRewardView {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
PlayerNodeView {
|
PlayerRewardView {
|
||||||
reward_sum: 0.0,
|
reward_sum: 0.0,
|
||||||
reward_average: 0.0,
|
reward_average: 0.0,
|
||||||
rewards: Vec::new(),
|
rewards: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for PlayerRewardView {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{{sum={} avg={}}}", self.reward_sum, self.reward_average)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user