Adding a basic print_tree function to visualize the MCTS search tree

This probably is not a good thing to run on a very large tree.
This commit is contained in:
David Kruger 2025-06-27 14:56:51 -07:00
parent 6cc6e6a7ba
commit b80f039b93
5 changed files with 45 additions and 9 deletions

View File

@ -40,6 +40,7 @@ fn main() {
// Find the best move // Find the best move
match mcts.search() { match mcts.search() {
Ok(action) => { Ok(action) => {
mcts.print_tree();
println!( println!(
"AI chooses: {} (row {}, col {})", "AI chooses: {} (row {}, col {})",
action.index, action.index,

View File

@ -6,6 +6,7 @@ use crate::policy::simulation::simulate_reward;
use crate::state::GameState; use crate::state::GameState;
use crate::tree::arena::Arena; use crate::tree::arena::Arena;
use crate::tree::node::{Node, RewardVal}; use crate::tree::node::{Node, RewardVal};
use crate::tree::print_tree;
use rand::prelude::SliceRandom; use rand::prelude::SliceRandom;
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Instant; use std::time::Instant;
@ -130,6 +131,10 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
None => Err(MCTSError::NoBestAction), None => Err(MCTSError::NoBestAction),
} }
} }
pub fn print_tree(&self) {
print_tree(self.root_id, &self.arena)
}
} }
/// Errors returned by the MCTS algorithm /// Errors returned by the MCTS algorithm

View File

@ -7,7 +7,7 @@ use std::hash::Hash;
/// ///
/// When leveraging MCTS for your game, you must implement this trait to provide /// When leveraging MCTS for your game, you must implement this trait to provide
/// the specifics for your game. /// the specifics for your game.
pub trait GameState: Clone { pub trait GameState: Clone + Debug {
/// The type of actions that can be taken in the game /// The type of actions that can be taken in the game
type Action: Action; type Action: Action;

View File

@ -1,2 +1,27 @@
pub mod arena; pub mod arena;
pub mod node; pub mod node;
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::Node;
pub fn print_tree<S: GameState>(node_id: usize, arena: &Arena<S>) {
let mut to_print: Vec<usize> = Vec::new();
to_print.push(node_id);
while let Some(node_id) = to_print.pop() {
let node: &Node<S> = arena.get_node(node_id);
if node.depth > 0 {
for _ in 0..node.depth - 1 {
print!("| ")
}
print!("|- ");
}
println!(
"{:?} a:{:?} v:{} {:?}",
node_id, node.action, node.visits, node.player_view
);
for child_id in node.children.iter().rev() {
to_print.push(*child_id);
}
}
}

View File

@ -1,5 +1,4 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug;
use crate::state::GameState; use crate::state::GameState;
@ -15,7 +14,7 @@ pub type RewardVal = f64;
/// ///
/// This class is not thread safe, as the library does not provide for parallel /// This class is not thread safe, as the library does not provide for parallel
/// search. /// search.
#[derive(Debug)] #[derive(std::fmt::Debug)]
pub struct Node<S: GameState> { pub struct Node<S: GameState> {
/// The game state at the given node, after `action` /// The game state at the given node, after `action`
pub state: S, pub state: S,
@ -30,7 +29,7 @@ pub struct Node<S: GameState> {
pub visits: u64, pub visits: u64,
/// The player's evaluation of the node /// The player's evaluation of the node
pub player_view: HashMap<S::Player, PlayerNodeView>, pub player_view: HashMap<S::Player, PlayerRewardView>,
/// The depth of the node in the tree, this is 0 for the root node /// The depth of the node in the tree, this is 0 for the root node
pub depth: usize, pub depth: usize,
@ -85,7 +84,7 @@ impl<S: GameState> Node<S> {
let pv = self let pv = self
.player_view .player_view
.entry(player) .entry(player)
.or_insert(PlayerNodeView::default()); .or_insert(PlayerRewardView::default());
pv.rewards.push(reward); pv.rewards.push(reward);
pv.reward_sum += reward; pv.reward_sum += reward;
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64; pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
@ -95,8 +94,7 @@ impl<S: GameState> Node<S> {
/// A player's specific perspective of a node's value /// A player's specific perspective of a node's value
/// ///
/// Each player has their own idea of the value of a node. /// Each player has their own idea of the value of a node.
#[derive(Debug)] pub struct PlayerRewardView {
pub struct PlayerNodeView {
/// The total reward from simulations through this node /// The total reward from simulations through this node
pub reward_sum: RewardVal, pub reward_sum: RewardVal,
@ -107,12 +105,19 @@ pub struct PlayerNodeView {
pub rewards: Vec<RewardVal>, pub rewards: Vec<RewardVal>,
} }
impl Default for PlayerNodeView { impl Default for PlayerRewardView {
fn default() -> Self { fn default() -> Self {
PlayerNodeView { PlayerRewardView {
reward_sum: 0.0, reward_sum: 0.0,
reward_average: 0.0, reward_average: 0.0,
rewards: Vec::new(), rewards: Vec::new(),
} }
} }
} }
impl std::fmt::Debug for PlayerRewardView {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{{sum={} avg={}}}", self.reward_sum, self.reward_average)?;
Ok(())
}
}