diff --git a/examples/auto_tic_tac_toe.rs b/examples/auto_tic_tac_toe.rs index b6a5387..22c3bff 100644 --- a/examples/auto_tic_tac_toe.rs +++ b/examples/auto_tic_tac_toe.rs @@ -40,6 +40,7 @@ fn main() { // Find the best move match mcts.search() { Ok(action) => { + mcts.print_tree(); println!( "AI chooses: {} (row {}, col {})", action.index, diff --git a/src/mcts.rs b/src/mcts.rs index fb25f18..bce530a 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -6,6 +6,7 @@ use crate::policy::simulation::simulate_reward; use crate::state::GameState; use crate::tree::arena::Arena; use crate::tree::node::{Node, RewardVal}; +use crate::tree::print_tree; use rand::prelude::SliceRandom; use std::collections::HashMap; use std::time::Instant; @@ -130,6 +131,10 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> { None => Err(MCTSError::NoBestAction), } } + + pub fn print_tree(&self) { + print_tree(self.root_id, &self.arena) + } } /// Errors returned by the MCTS algorithm diff --git a/src/state.rs b/src/state.rs index 047ed6e..747063c 100644 --- a/src/state.rs +++ b/src/state.rs @@ -7,7 +7,7 @@ use std::hash::Hash; /// /// When leveraging MCTS for your game, you must implement this trait to provide /// the specifics for your game. 
-pub trait GameState: Clone { +pub trait GameState: Clone + std::fmt::Debug { /// The type of actions that can be taken in the game type Action: Action; diff --git a/src/tree/mod.rs b/src/tree/mod.rs index 9c02d29..79ac9cb 100644 --- a/src/tree/mod.rs +++ b/src/tree/mod.rs @@ -1,2 +1,27 @@ pub mod arena; pub mod node; + +use crate::state::GameState; +use crate::tree::arena::Arena; +use crate::tree::node::Node; + +pub fn print_tree<S: GameState>(node_id: usize, arena: &Arena<S>) { +    let mut to_print: Vec<usize> = Vec::new(); +    to_print.push(node_id); +    while let Some(node_id) = to_print.pop() { +        let node: &Node<S> = arena.get_node(node_id); +        if node.depth > 0 { +            for _ in 0..node.depth - 1 { +                print!("| ") +            } +            print!("|- "); +        } +        println!( +            "{:?} a:{:?} v:{} {:?}", +            node_id, node.action, node.visits, node.player_view +        ); +        for child_id in node.children.iter().rev() { +            to_print.push(*child_id); +        } +    } +} diff --git a/src/tree/node.rs b/src/tree/node.rs index fb6b644..0c66b95 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::fmt::Debug; use crate::state::GameState; @@ -15,7 +14,7 @@ pub type RewardVal = f64; /// /// This class is not thread safe, as the library does not provide for parallel /// search. 
-#[derive(Debug)] +#[derive(std::fmt::Debug)] pub struct Node<S: GameState> { /// The game state at the given node, after `action` pub state: S, @@ -30,7 +29,7 @@ pub struct Node<S: GameState> { pub visits: u64, /// The player's evaluation of the node -    pub player_view: HashMap<S::Player, PlayerNodeView>, +    pub player_view: HashMap<S::Player, PlayerRewardView>, /// The depth of the node in the tree, this is 0 for the root node pub depth: usize, @@ -85,7 +84,7 @@ impl<S: GameState> Node<S> { let pv = self .player_view .entry(player) -            .or_insert(PlayerNodeView::default()); +            .or_insert(PlayerRewardView::default()); pv.rewards.push(reward); pv.reward_sum += reward; pv.reward_average = pv.reward_sum / pv.rewards.len() as f64; @@ -95,8 +94,7 @@ /// A player's specific perspective of a node's value /// /// Each player has their own idea of the value of a node. -#[derive(Debug)] -pub struct PlayerNodeView { +pub struct PlayerRewardView { /// The total reward from simulations through this node pub reward_sum: RewardVal, @@ -107,12 +105,19 @@ pub rewards: Vec<RewardVal>, } -impl Default for PlayerNodeView { +impl Default for PlayerRewardView { fn default() -> Self { -        PlayerNodeView { +        PlayerRewardView { reward_sum: 0.0, reward_average: 0.0, rewards: Vec::new(), } } } + +impl std::fmt::Debug for PlayerRewardView { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        write!(f, "{{sum={} avg={}}}", self.reward_sum, self.reward_average)?; +        Ok(()) +    } +}