From 17884f4b90d1f3d460a3a997b2c9174b73bacf48 Mon Sep 17 00:00:00 2001
From: David Kruger
Date: Mon, 23 Jun 2025 13:46:04 -0700
Subject: [PATCH] Working MCTS implementation

This is a basic working implementation of the MCTS algorithm. The algorithm is
currently slow compared with other implementations and makes sub-optimal choices
when playing tic-tac-toe, so some modifications are still needed.
---
 Cargo.lock                         | 154 +++++++++++++++
 Cargo.toml                         |  16 ++
 examples/tic_tac_toe.rs            | 298 +++++++++++++++++++++++++++++
 src/config.rs                      |  67 +++++++
 src/lib.rs                         |  17 ++
 src/mcts.rs                        | 147 ++++++++++++++
 src/policy/backprop/mod.rs         |  99 ++++++++++
 src/policy/decision/mod.rs         |  65 +++++++
 src/policy/mod.rs                  |   4 +
 src/policy/selection/mod.rs        |  59 ++++++
 src/policy/selection/ucb1.rs       |  79 ++++++++
 src/policy/selection/ucb1_tuned.rs |  97 ++++++++++
 src/policy/simulation/mod.rs       |  44 +++++
 src/policy/simulation/random.rs    |  18 ++
 src/state.rs                       |  90 +++++++++
 src/tree/arena.rs                  |  46 +++++
 src/tree/mod.rs                    |   2 +
 src/tree/node.rs                   | 114 +++++++++++
 18 files changed, 1416 insertions(+)
 create mode 100644 Cargo.lock
 create mode 100644 Cargo.toml
 create mode 100644 examples/tic_tac_toe.rs
 create mode 100644 src/config.rs
 create mode 100644 src/lib.rs
 create mode 100644 src/mcts.rs
 create mode 100644 src/policy/backprop/mod.rs
 create mode 100644 src/policy/decision/mod.rs
 create mode 100644 src/policy/mod.rs
 create mode 100644 src/policy/selection/mod.rs
 create mode 100644 src/policy/selection/ucb1.rs
 create mode 100644 src/policy/selection/ucb1_tuned.rs
 create mode 100644 src/policy/simulation/mod.rs
 create mode 100644 src/policy/simulation/random.rs
 create mode 100644 src/state.rs
 create mode 100644 src/tree/arena.rs
 create mode 100644 src/tree/mod.rs
 create mode 100644 src/tree/node.rs

diff --git a/Cargo.lock b/Cargo.lock
new file mode 100644
index 0000000..739db4d
--- /dev/null
+++ b/Cargo.lock
@@ -0,0 +1,154 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 4 + +[[package]] +name = "cfg-if" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" + +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "libc" +version = "0.2.174" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rustic_mcts" +version = "0.1.0" +dependencies = [ + "rand", + "thiserror", +] + +[[package]] +name = "syn" +version = "2.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4307e30089d6fd6aff212f2da3a1f9e32f3223b1f010fb09b7c95f90f3ca1e8" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "zerocopy" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0852c35 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "rustic_mcts" +version = "0.1.0" +edition = "2021" +authors = ["David Kruger "] +description = "An extensible implementation of Monte Carlo Tree Search (MCTS) using an arena allocator." +license = "MIT" +repository = "https://gitlabs.krugerlabs.us/krugd/rustic_mcts" +readme = "README.md" +keywords = ["mcts", "rust", "monte_carlo", "tree", "ai", "ml"] +categories = ["algorithms", "data-structures"] + + +[dependencies] +rand = "~0.8" +thiserror = "~2.0" diff --git a/examples/tic_tac_toe.rs b/examples/tic_tac_toe.rs new file mode 100644 index 0000000..263b5c7 --- /dev/null +++ b/examples/tic_tac_toe.rs @@ -0,0 +1,298 @@ +use std::collections::HashMap; +use std::fmt; +use std::io::{self, Write}; + +use rustic_mcts::policy::backprop::BackpropagationPolicy; +use rustic_mcts::policy::decision::DecisionPolicy; +use rustic_mcts::policy::selection::SelectionPolicy; +use rustic_mcts::policy::simulation::SimulationPolicy; +use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS}; + +fn main() { + println!("MCTS Tic-Tac-Toe Example"); + println!("========================"); + println!(); + + // Set up a new game + let mut game = TicTacToe::new(); + + // Create MCTS configuration + let config = MCTSConfig { + max_iterations: 10_000, + max_time: None, + tree_size_allocation: 10_000, + selection_policy: SelectionPolicy::UCB1Tuned(1.414), + simulation_policy: SimulationPolicy::Random, + backprop_policy: BackpropagationPolicy::Standard, + decision_policy: DecisionPolicy::MostVisits, + }; + + // Main game loop + while !game.is_terminal() { + // Display the board + println!("{}", game); + + if game.current_player == Player::X { + // Human player (X) + println!("Your move (enter row column, e.g. '1 2'): "); + io::stdout().flush().unwrap(); + + let mut input = String::new(); + io::stdin().read_line(&mut input).unwrap(); + + let coords: Vec = input + .trim() + .split_whitespace() + .filter_map(|s| s.parse::().ok()) + .collect(); + + if coords.len() != 2 || coords[0] > 2 || coords[1] > 2 { + println!("Invalid move! Enter row and column (0-2)."); + continue; + } + + let row = coords[0]; + let col = coords[1]; + + let move_index = row * 3 + col; + let action = Move { index: move_index }; + + if !game.is_legal_move(&action) { + println!("Illegal move! 
+                continue;
+            }
+
+            // Apply the human's move
+            game = game.state_after_action(&action);
+        } else {
+            // AI player (O)
+            println!("AI is thinking...");
+
+            // Create a new MCTS search
+            let mut mcts = MCTS::new(game.clone(), &config);
+
+            // Find the best move
+            match mcts.search() {
+                Ok(action) => {
+                    println!(
+                        "AI chooses: {} (row {}, col {})",
+                        action.index,
+                        action.index / 3,
+                        action.index % 3
+                    );
+
+                    // Apply the AI's move
+                    game = game.state_after_action(&action);
+                }
+                Err(e) => {
+                    println!("Error: {:?}", e);
+                    break;
+                }
+            }
+        }
+    }
+
+    // Display final state
+    println!("{}", game);
+
+    // Report the result
+    if let Some(winner) = game.get_winner() {
+        println!("Player {:?} wins!", winner);
+    } else {
+        println!("The game is a draw!");
+    }
+}
+
+/// Players in Tic-Tac-Toe
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+enum Player {
+    X,
+    O,
+}
+
+impl rustic_mcts::Player for Player {}
+
+/// Tic-Tac-Toe move
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+struct Move {
+    /// Board position index (0-8)
+    index: usize,
+}
+
+impl Action for Move {
+    fn id(&self) -> usize {
+        self.index
+    }
+}
+
+/// Tic-Tac-Toe game state
+#[derive(Clone)]
+struct TicTacToe {
+    /// Board representation (None = empty, Some(Player) = occupied)
+    board: [Option<Player>; 9],
+
+    /// Current player's turn
+    current_player: Player,
+
+    /// Number of moves played so far
+    moves_played: usize,
+}
+
+impl TicTacToe {
+    /// Creates a new empty Tic-Tac-Toe board
+    fn new() -> Self {
+        TicTacToe {
+            board: [None; 9],
+            current_player: Player::X,
+            moves_played: 0,
+        }
+    }
+
+    /// Checks if a move is legal
+    fn is_legal_move(&self, action: &Move) -> bool {
+        if action.index >= 9 {
+            return false;
+        }
+        self.board[action.index].is_none()
+    }
+
+    /// Returns the winner of the game, if any
+    fn get_winner(&self) -> Option<Player> {
+        // Check rows
+        for row in 0..3 {
+            let i = row * 3;
+            if self.board[i].is_some()
+                && self.board[i] == self.board[i + 1]
+                && self.board[i] == self.board[i + 2]
+            {
+                return self.board[i];
+            }
+        }
+
+        // Check columns
+        for col in 0..3 {
+            if self.board[col].is_some()
+                && self.board[col] == self.board[col + 3]
+                && self.board[col] == self.board[col + 6]
+            {
+                return self.board[col];
+            }
+        }
+
+        // Check diagonals
+        if self.board[0].is_some()
+            && self.board[0] == self.board[4]
+            && self.board[0] == self.board[8]
+        {
+            return self.board[0];
+        }
+        if self.board[2].is_some()
+            && self.board[2] == self.board[4]
+            && self.board[2] == self.board[6]
+        {
+            return self.board[2];
+        }
+
+        None
+    }
+}
+
+impl GameState for TicTacToe {
+    type Action = Move;
+    type Player = Player;
+
+    fn get_legal_actions(&self) -> Vec<Self::Action> {
+        let mut actions = Vec::new();
+        for i in 0..9 {
+            if self.board[i].is_none() {
+                actions.push(Move { index: i });
+            }
+        }
+        actions
+    }
+
+    fn state_after_action(&self, action: &Self::Action) -> Self {
+        let mut new_state = self.clone();
+
+        // Make the move
+        new_state.board[action.index] = Some(self.current_player);
+        new_state.moves_played = self.moves_played + 1;
+
+        // Switch player
+        new_state.current_player = match self.current_player {
+            Player::X => Player::O,
+            Player::O => Player::X,
+        };
+
+        new_state
+    }
+
+    fn is_terminal(&self) -> bool {
+        self.get_winner().is_some() || self.moves_played == 9
+    }
+
+    fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
+        if let Some(winner) = self.get_winner() {
+            if winner == *player {
+                return 1.0; // Win
+            } else {
+                return 0.0; // Loss
+            }
+        }
+
+        // Draw
+        0.5
+    }
+
+    fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
+        HashMap::from_iter(vec![
+            (Player::X, self.reward_for_player(&Player::X)),
+            (Player::O, self.reward_for_player(&Player::O)),
+        ])
+    }
+
+    fn get_current_player(&self) -> &Self::Player {
+        &self.current_player
+    }
+}
+
+impl fmt::Display for TicTacToe {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        writeln!(f, "  0 1 2")?;
+        for row in 0..3 {
+            write!(f, "{} ", row)?;
+            for col in 0..3 {
+                let index = row * 3 + col;
+                let symbol = match self.board[index] {
+                    Some(Player::X) => "X",
+                    Some(Player::O) => "O",
+                    None => ".",
+                };
+                write!(f, "{} ", symbol)?;
+            }
+            writeln!(f)?;
+        }
+
+        writeln!(f, "\nPlayer {:?}'s turn", self.current_player)?;
+        Ok(())
+    }
+}
+
+impl fmt::Debug for TicTacToe {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "\n")?;
+        for row in 0..3 {
+            for col in 0..3 {
+                let index = row * 3 + col;
+                let symbol = match self.board[index] {
+                    Some(Player::X) => "X",
+                    Some(Player::O) => "O",
+                    None => ".",
+                };
+                write!(f, "{} ", symbol)?;
+            }
+            writeln!(f)?;
+        }
+        Ok(())
+    }
+}
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..1557810
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,67 @@
+use crate::policy::backprop::BackpropagationPolicy;
+use crate::policy::decision::DecisionPolicy;
+use crate::policy::selection::SelectionPolicy;
+use crate::policy::simulation::SimulationPolicy;
+use crate::state::GameState;
+use std::time::Duration;
+
+/// Configuration for the MCTS algorithm
+#[derive(Debug)]
+pub struct MCTSConfig<S: GameState> {
+    /// The maximum number of iterations to run when searching
+    ///
+    /// The search will stop after the given number of iterations, even if the
+    /// search time has not exceeded `max_time`.
+    pub max_iterations: usize,
+
+    /// The maximum time to run the search
+    ///
+    /// If set, the search will stop after this duration even if the maximum
+    /// number of iterations hasn't been reached.
+    pub max_time: Option<Duration>,
+
+    /// The size to initially allocate for the search tree
+    ///
+    /// This pre-allocates memory for the search tree, which ensures contiguous
+    /// memory and improves performance by preventing the resizing of the tree
+    /// as we explore.
+    pub tree_size_allocation: usize,
+
+    /// The selection policy
+    ///
+    /// This dictates the path through which the game tree is searched. As such,
+    /// the policy has a large impact on the overall algorithm execution.
+    pub selection_policy: SelectionPolicy<S>,
+
+    /// The simulation policy
+    ///
+    /// This dictates the game simulation used when expanding and evaluating the
+    /// search tree. Random is generally a good default.
+    pub simulation_policy: SimulationPolicy<S>,
+
+    /// The backpropagation policy
+    ///
+    /// This dictates how the results of the simulation playouts are propagated
+    /// back up the tree.
+    pub backprop_policy: BackpropagationPolicy<S>,
+
+    /// The decision policy
+    ///
+    /// This dictates how the MCTS algorithm determines its final decision
+    /// after iterating through the search tree.
+    pub decision_policy: DecisionPolicy,
+}
+
+impl<S: GameState> Default for MCTSConfig<S> {
+    fn default() -> Self {
+        MCTSConfig {
+            max_iterations: 10_000,
+            max_time: None,
+            tree_size_allocation: 10_000,
+            selection_policy: SelectionPolicy::UCB1Tuned(1.414),
+            simulation_policy: SimulationPolicy::Random,
+            backprop_policy: BackpropagationPolicy::Standard,
+            decision_policy: DecisionPolicy::MostVisits,
+        }
+    }
+}
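+
+// Illustrative sketch (not part of the original implementation): individual fields can
+// be overridden while keeping the remaining defaults via struct update syntax, e.g.
+//
+//     let config: MCTSConfig<MyGame> = MCTSConfig {
+//         max_iterations: 1_000,
+//         ..MCTSConfig::default()
+//     };
+//
+// where `MyGame` is a hypothetical `GameState` implementation.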
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..b958c63
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,17 @@
+//! # rustic_mcts
+//!
+//! An extensible implementation of Monte Carlo Tree Search (MCTS) using arena allocation and
+//! configurable policies.
+
+pub mod config;
+pub mod mcts;
+pub mod policy;
+pub mod state;
+pub mod tree;
+
+pub use config::MCTSConfig;
+pub use mcts::MCTS;
+pub use state::Action;
+pub use state::GameState;
+pub use state::Player;
+pub use tree::node::RewardVal;
diff --git a/src/mcts.rs b/src/mcts.rs
new file mode 100644
index 0000000..37e2133
--- /dev/null
+++ b/src/mcts.rs
@@ -0,0 +1,147 @@
+use crate::config::MCTSConfig;
+use crate::policy::backprop::backpropagate_rewards;
+use crate::policy::decision::decide_on_action;
+use crate::policy::selection::select_best_child;
+use crate::policy::simulation::simulate_reward;
+use crate::state::GameState;
+use crate::tree::arena::Arena;
+use crate::tree::node::{Node, RewardVal};
+use rand::prelude::SliceRandom;
+use std::collections::HashMap;
+use std::time::Instant;
+
+/// Monte Carlo Tree Search implementation
+///
+/// This provides the interface for performing optimal searches on a tree using
+/// the MCTS algorithm.
+pub struct MCTS<'conf, S: GameState> {
+    /// The arena used for the tree
+    arena: Arena<S>,
+
+    /// The identifier of the root node of the search tree
+    root_id: usize,
+
+    /// The configuration used for the search
+    config: &'conf MCTSConfig<S>,
+}
+
+impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
+    /// Creates a new instance with the given initial state and configuration
+    pub fn new(initial_state: S, config: &'conf MCTSConfig<S>) -> Self {
+        let mut arena: Arena<S> = Arena::new(config.tree_size_allocation);
+        let root: Node<S> = Node::new(initial_state.clone(), None, None);
+        let root_id: usize = arena.add_node(root);
+        MCTS {
+            arena,
+            root_id,
+            config,
+        }
+    }
+
+    /// Runs the MCTS algorithm, returning the "best" action
+    ///
+    /// The search will stop once `max_iterations` or `max_time` from
+    /// the assigned configuration is reached.
+    pub fn search(&mut self) -> Result<S::Action> {
+        self.search_for_iterations(self.config.max_iterations)
+    }
+
+    /// Runs the MCTS algorithm, returning the "best" action after the given iterations
+    ///
+    /// This ignores the `max_iterations` provided in the config; however, it will
+    /// return early if `max_time` is specified and reached before the iterations are complete.
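+    ///
+    /// Illustrative sketch (not part of the original implementation), assuming a
+    /// `GameState` implementation such as the `TicTacToe` type from the bundled example:
+    ///
+    /// ```ignore
+    /// let config = MCTSConfig::default();
+    /// let mut mcts = MCTS::new(TicTacToe::new(), &config);
+    /// let best_action = mcts.search_for_iterations(1_000)?;
+    /// ```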
+    pub fn search_for_iterations(&mut self, iterations: usize) -> Result<S::Action> {
+        let start_time = Instant::now();
+        for _ in 0..iterations {
+            match self.config.max_time {
+                Some(max_time) => {
+                    if start_time.elapsed() >= max_time {
+                        break; // ending early due to time
+                    }
+                }
+                None => {}
+            }
+            self.execute_iteration()?;
+        }
+
+        self.best_action()
+    }
+
+    /// Runs the MCTS algorithm for a single iteration
+    fn execute_iteration(&mut self) -> Result<()> {
+        let mut selected_id: usize = self.select();
+        let selected_node: &Node<S> = self.arena.get_node(selected_id);
+        if !selected_node.state.is_terminal() {
+            self.expand(selected_id);
+            let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
+            let random_child: usize = *children.choose(&mut rand::thread_rng()).unwrap();
+            selected_id = random_child;
+        }
+        let rewards = self.simulate(selected_id);
+        self.backprop(selected_id, &rewards);
+
+        Ok(())
+    }
+
+    /// MCTS Phase 1: Selection - Find the "best" node to expand
+    fn select(&mut self) -> usize {
+        let mut current_id: usize = self.root_id;
+        loop {
+            let node = self.arena.get_node(current_id);
+            if node.is_leaf() || node.state.is_terminal() {
+                return current_id;
+            }
+            current_id = select_best_child(&self.config.selection_policy, node, &self.arena);
+        }
+    }
+
+    /// MCTS Phase 2: Expansion - Expand the selected node on the tree
+    fn expand(&mut self, id: usize) {
+        let parent: &Node<S> = self.arena.get_node(id);
+        let legal_actions: Vec<S::Action> = parent.state.get_legal_actions();
+        let parent_state: S = parent.state.clone();
+        for action in legal_actions {
+            let state = parent_state.state_after_action(&action);
+            let new_node = Node::new(state, Some(action), Some(id));
+            let new_id = self.arena.add_node(new_node);
+            self.arena.get_node_mut(id).children.push(new_id);
+        }
+    }
+
+    /// MCTS Phase 3: Simulation - Play out the game from the selected node
+    fn simulate(&self, id: usize) -> HashMap<S::Player, RewardVal> {
+        let node = self.arena.get_node(id);
+        simulate_reward(&self.config.simulation_policy, node, &self.arena)
+    }
+
+    /// MCTS Phase 4: Backpropagation - Propagate the playout rewards up the tree
+    fn backprop(&mut self, selected_id: usize, rewards: &HashMap<S::Player, RewardVal>) {
+        backpropagate_rewards(
+            &self.config.backprop_policy,
+            selected_id,
+            &mut self.arena,
+            rewards,
+        )
+    }
+
+    fn best_action(&self) -> Result<S::Action> {
+        let root_node: &Node<S> = self.arena.get_node(self.root_id);
+        match decide_on_action(&self.config.decision_policy, root_node, &self.arena) {
+            Some(action) => Ok(action),
+            None => Err(MCTSError::NoBestAction),
+        }
+    }
+}
+
+/// Errors returned by the MCTS algorithm
+#[derive(Debug, thiserror::Error)]
+pub enum MCTSError {
+    /// The best action doesn't exist
+    #[error("Unable to determine a best action for the game")]
+    NoBestAction,
+
+    /// The search tree was exhausted without finding a terminal node
+    #[error("Search tree exhausted without finding terminal node")]
+    NonTerminalGame,
+}
+
+/// Result returned by the MCTS algorithm
+pub type Result<T> = std::result::Result<T, MCTSError>;
diff --git a/src/policy/backprop/mod.rs b/src/policy/backprop/mod.rs
new file mode 100644
index 0000000..4b07037
--- /dev/null
+++ b/src/policy/backprop/mod.rs
@@ -0,0 +1,99 @@
+use crate::state::GameState;
+use crate::tree::arena::Arena;
+use crate::tree::node::RewardVal;
+use std::collections::HashMap;
+
+/// The backpropagation policy dictating the propagation of playout results
+///
+/// This policy drives how the backpropagation phase of the MCTS algorithm is
+/// executed, allowing for some minor customization.
+///
+/// Typically the Standard policy, used by most implementations of MCTS, is
+/// sufficient.
+#[derive(Debug)]
+pub enum BackpropagationPolicy<S: GameState> {
+    /// Standard backpropagation
+    ///
+    /// This increments the visitation count and adds the simulated rewards
+    /// results to the aggregate values.
+    ///
+    /// This is the standard policy used in most MCTS implementations.
+    Standard,
+
+    /// Weighted backpropagation
+    ///
+    /// This weights the value of the simulated rewards based on the depth,
+    /// allowing us to put more or less influence on deeper branches:
+    /// - Positive weight factor makes deeper nodes less influential
+    /// - Negative weight factor makes deeper nodes more influential
+    Weighted(f64),
+
+    /// Custom backpropagation policy
+    Custom(Box<dyn CustomBackpropagationPolicy<S>>),
+}
+
+/// Trait for an object implementing the backpropagation logic when exploring the MCTS
+/// search tree.
+pub trait CustomBackpropagationPolicy<S: GameState>: std::fmt::Debug {
+    /// Backpropagate the given rewards values from the node up the tree
+    fn backprop(
+        &self,
+        node_id: usize,
+        arena: &mut Arena<S>,
+        rewards: &HashMap<S::Player, RewardVal>,
+    );
+}
+
+pub fn backpropagate_rewards<S: GameState>(
+    policy: &BackpropagationPolicy<S>,
+    node_id: usize,
+    arena: &mut Arena<S>,
+    rewards: &HashMap<S::Player, RewardVal>,
+) {
+    match policy {
+        BackpropagationPolicy::Standard => standard_backprop(node_id, arena, rewards),
+        BackpropagationPolicy::Weighted(depth_factor) => {
+            weighted_backprop(*depth_factor, node_id, arena, rewards)
+        }
+        BackpropagationPolicy::Custom(custom_policy) => {
+            custom_policy.backprop(node_id, arena, rewards)
+        }
+    }
+}
+
+fn standard_backprop<S: GameState>(
+    node_id: usize,
+    arena: &mut Arena<S>,
+    rewards: &HashMap<S::Player, RewardVal>,
+) {
+    // TODO:
+    // - each node needs the perspective of the different players not just one view
+    // - e.g. reward_sum(player), reward_avg(player), rewards(player)[], visits(player)
+    // - we could make special version for 2-player zero-sum games like below
+    let mut current_id: usize = node_id;
+    loop {
+        let node = arena.get_node_mut(current_id);
+        let player = node.state.get_current_player().clone();
+        // Record the playout reward for the player to move at this node, if one was provided
+        if let Some(reward) = rewards.get(&player) {
+            node.increment_visits();
+            node.record_player_reward(player, *reward);
+        }
+        // Always walk up to the parent so the loop terminates at the root
+        match node.parent {
+            Some(parent_id) => current_id = parent_id,
+            None => break,
+        }
+    }
+}
+
+fn weighted_backprop<S: GameState>(
+    _depth_factor: f64,
+    _node_id: usize,
+    _arena: &mut Arena<S>,
+    _rewards: &HashMap<S::Player, RewardVal>,
+) {
+    // TODO
+}
diff --git a/src/policy/decision/mod.rs b/src/policy/decision/mod.rs
new file mode 100644
index 0000000..e9f6800
--- /dev/null
+++ b/src/policy/decision/mod.rs
@@ -0,0 +1,65 @@
+use crate::state::GameState;
+use crate::tree::arena::Arena;
+use crate::tree::node::Node;
+
+/// The decision policy when determining the action in the final MCTS phase
+///
+/// This policy drives how the MCTS algorithm chooses which action is the
+/// "best" from the exploration.
+#[derive(Debug)]
+pub enum DecisionPolicy {
+    /// Decide on the action with the most visits
+    ///
+    /// This option relies on the statistical confidence driven by the MCTS
+    /// algorithm instead of the potentially more noisy value estimates.
+    ///
+    /// This is the standard policy used in most MCTS implementations, and
+    /// is a good selection when not hyper-maximizing for potential gain.
+    MostVisits,
+
+    /// Decide on the action with the highest average value
+    ///
+    /// This is non-standard, but is more aggressive in attempting to gain
+    /// the highest value in a decision.
+    HighestValue,
+}
+
+pub fn decide_on_action<S: GameState>(
+    policy: &DecisionPolicy,
+    root_node: &Node<S>,
+    arena: &Arena<S>,
+) -> Option<S::Action> {
+    match policy {
+        DecisionPolicy::MostVisits => most_visits(root_node, arena),
+        DecisionPolicy::HighestValue => highest_value(root_node, arena),
+    }
+}
+
+fn most_visits<S: GameState>(root_node: &Node<S>, arena: &Arena<S>) -> Option<S::Action> {
+    let best_child_id: &usize = root_node
+        .children
+        .iter()
+        .max_by(|&a, &b| {
+            let node_a_visits = arena.get_node(*a).visits;
+            let node_b_visits = arena.get_node(*b).visits;
+            node_a_visits.partial_cmp(&node_b_visits).unwrap()
+        })
+        .unwrap();
+
+    arena.get_node(*best_child_id).action.clone()
+}
+
+fn highest_value<S: GameState>(root_node: &Node<S>, arena: &Arena<S>) -> Option<S::Action> {
+    let player = root_node.state.get_current_player();
+    let best_child_id: &usize = root_node
+        .children
+        .iter()
+        .max_by(|&a, &b| {
+            let node_a_score = arena.get_node(*a).reward_average(player);
+            let node_b_score = arena.get_node(*b).reward_average(player);
+            node_a_score.partial_cmp(&node_b_score).unwrap()
+        })
+        .unwrap();
+
+    arena.get_node(*best_child_id).action.clone()
+}
diff --git a/src/policy/mod.rs b/src/policy/mod.rs
new file mode 100644
index 0000000..2f9161b
--- /dev/null
+++ b/src/policy/mod.rs
@@ -0,0 +1,4 @@
+pub mod backprop;
+pub mod decision;
+pub mod selection;
+pub mod simulation;
diff --git a/src/policy/selection/mod.rs b/src/policy/selection/mod.rs
new file mode 100644
index 0000000..cbd3a89
--- /dev/null
+++ b/src/policy/selection/mod.rs
@@ -0,0 +1,59 @@
+mod ucb1;
+mod ucb1_tuned;
+
+use crate::state::GameState;
+use crate::tree::arena::Arena;
+use crate::tree::node::Node;
+
+/// The selection policy used in the MCTS selection phase
+///
+/// This drives the selection of the nodes in the search tree, determining
+/// which paths are explored and evaluated.
+///
+/// In general UCB1-Tuned or UCB1 should be effective; however, if necessary,
+/// a custom selection policy can be provided.
+#[derive(Debug)]
+pub enum SelectionPolicy<S: GameState> {
+    /// Upper Confidence Bound 1 (UCB1) with the given exploration constant
+    ///
+    /// The exploration constant controls the balance between exploration and
+    /// exploitation. The higher the value, the more likely the search will
+    /// explore less-visited nodes. A standard value is √2 ≈ 1.414.
+    UCB1(f64),
+
+    /// Upper Confidence Bound 1 Tuned (UCB1-Tuned)
+    ///
+    /// A tuned version of UCB1 instead using the empirical
+    /// standard deviation of the rewards to drive exploration.
+    ///
+    /// Auer, P., Cesa-Bianchi, N. & Fischer, P. Finite-time Analysis of the Multiarmed Bandit Problem. Machine Learning 47, 235–256 (2002). https://doi.org/10.1023/A:1013689704352
+    UCB1Tuned(f64),
+
+    /// Custom selection policy
+    Custom(Box<dyn CustomSelectionPolicy<S>>),
+}
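+
+// Illustrative sketch (not part of the original implementation): a purely greedy
+// policy could implement the `CustomSelectionPolicy` trait defined below, e.g.
+//
+//     #[derive(Debug)]
+//     struct GreedySelection;
+//
+//     impl<S: GameState> CustomSelectionPolicy<S> for GreedySelection {
+//         fn select_child(&self, node: &Node<S>, arena: &Arena<S>) -> usize {
+//             let player = node.state.get_current_player();
+//             *node
+//                 .children
+//                 .iter()
+//                 .max_by(|a, b| {
+//                     let avg_a = arena.get_node(**a).reward_average(player);
+//                     let avg_b = arena.get_node(**b).reward_average(player);
+//                     avg_a.partial_cmp(&avg_b).unwrap()
+//                 })
+//                 .expect("select_child called on a leaf node")
+//         }
+//     }
+//
+// It would then be plugged in via `SelectionPolicy::Custom(Box::new(GreedySelection))`.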
+
+/// Trait for an object implementing the selection logic when exploring the MCTS
+/// search tree.
+///
+/// The policy should select the child of the given node which is "best" for the current player.
+pub trait CustomSelectionPolicy<S: GameState>: std::fmt::Debug {
+    /// Selects a child based on the policy, returning the node ID
+    fn select_child(&self, node: &Node<S>, arena: &Arena<S>) -> usize;
+}
+
+pub fn select_best_child<S: GameState>(
+    policy: &SelectionPolicy<S>,
+    node: &Node<S>,
+    arena: &Arena<S>,
+) -> usize {
+    match policy {
+        SelectionPolicy::UCB1(exploration_constant) => {
+            ucb1::select_best_child(*exploration_constant, node, arena)
+        }
+        SelectionPolicy::UCB1Tuned(exploration_constant) => {
+            ucb1_tuned::select_best_child(*exploration_constant, node, arena)
+        }
+        SelectionPolicy::Custom(custom_policy) => custom_policy.select_child(node, arena),
+    }
+}
diff --git a/src/policy/selection/ucb1.rs b/src/policy/selection/ucb1.rs
new file mode 100644
index 0000000..9a5c2f5
--- /dev/null
+++ b/src/policy/selection/ucb1.rs
@@ -0,0 +1,79 @@
+//! Upper Confidence Bound 1 (UCB1) selection policy
+//!
+//! This is the classic selection policy for MCTS, which balances
+//! exploration and exploitation using the UCB1 formula:
+//!
+//! ```text
+//! UCB1 = average_reward + exploration_constant * sqrt(ln(parent_visits) / child_visits)
+//! ```
+//!
+//! Where:
+//! - `average_reward` is the average reward from simulations through this node
+//! - `exploration_constant` controls the balance between exploration and exploitation
+//! - `parent_visits` is the number of visits to the parent node
+//! - `child_visits` is the number of visits to the child node
+//!
+//! Higher exploration constants favor exploration (trying less-visited nodes),
+//! while lower values favor exploitation (choosing nodes with higher values).
+//!
+//! The commonly used value for the exploration constant is sqrt(2) ≈ 1.414,
+//! which is the default in this implementation.
+
+use crate::state::GameState;
+use crate::tree::arena::Arena;
+use crate::tree::node::{Node, RewardVal};
+
+/// Selects the index of the "best" child using the UCB1 selection policy
+pub fn select_best_child<S: GameState>(
+    exploration_constant: f64,
+    node: &Node<S>,
+    arena: &Arena<S>,
+) -> usize {
+    if node.is_leaf() {
+        panic!("select_best_child called on leaf node");
+    }
+
+    let player = node.state.get_current_player();
+    let parent_visits = node.visits;
+    let best_child = node
+        .children
+        .iter()
+        .max_by(|&a, &b| {
+            let node_a = arena.get_node(*a);
+            let node_b = arena.get_node(*b);
+            let ucb_a = ucb1_value(
+                exploration_constant,
+                node_a.reward_average(player),
+                node_a.visits,
+                parent_visits,
+            );
+            let ucb_b = ucb1_value(
+                exploration_constant,
+                node_b.reward_average(player),
+                node_b.visits,
+                parent_visits,
+            );
+            ucb_a.partial_cmp(&ucb_b).unwrap()
+        })
+        .unwrap();
+    *best_child
+}
+
+/// Calculates the UCB1 value for a node
+pub fn ucb1_value(
+    exploration_constant: f64,
+    child_value: RewardVal,
+    child_visits: u64,
+    parent_visits: u64,
+) -> RewardVal {
+    if child_visits == 0 {
+        return f64::INFINITY; // Always explore nodes that have never been visited
+    }
+
+    // UCB1 formula: value + C * sqrt(ln(parent_visits) / child_visits)
+    let exploitation = child_value;
+    let exploration =
+        exploration_constant * ((parent_visits as f64).ln() / child_visits as f64).sqrt();
+
+    exploitation + exploration
+}
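+
+// Illustrative check (not part of the original implementation): with C = 1.414, an
+// average reward of 0.6, 10 child visits and 100 parent visits,
+// ucb1_value(1.414, 0.6, 10, 100) = 0.6 + 1.414 * sqrt(ln(100) / 10) ≈ 1.56.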
diff --git a/src/policy/selection/ucb1_tuned.rs b/src/policy/selection/ucb1_tuned.rs
new file mode 100644
index 0000000..971b601
--- /dev/null
+++ b/src/policy/selection/ucb1_tuned.rs
@@ -0,0 +1,97 @@
+//! Upper Confidence Bound 1 Tuned (UCB1-Tuned) selection policy
+//!
+//! This is a fine-tuned version of UCB1 which takes into account the
+//! empirically measured variance of the rewards to drive the exploration.
+//!
+//! This has been found to perform substantially better than UCB1 in most
+//! situations.
+//!
+//! Auer, P., Cesa-Bianchi, N. & Fischer, P.
+//! Finite-time Analysis of the Multiarmed Bandit Problem.
+//! Machine Learning 47, 235–256 (2002). https://doi.org/10.1023/A:1013689704352
+
+use crate::state::GameState;
+use crate::tree::arena::Arena;
+use crate::tree::node::{Node, RewardVal};
+
+/// Selects the index of the "best" child using the UCB1-Tuned selection policy
+pub fn select_best_child<S: GameState>(
+    exploration_constant: f64,
+    node: &Node<S>,
+    arena: &Arena<S>,
+) -> usize {
+    if node.is_leaf() {
+        panic!("select_best_child called on leaf node");
+    }
+
+    let player = node.state.get_current_player();
+    let parent_visits = node.visits;
+    let best_child = node
+        .children
+        .iter()
+        .max_by(|&a, &b| {
+            let node_a = arena.get_node(*a);
+            let node_b = arena.get_node(*b);
+            let ucb_a = ucb1_tuned_value(
+                exploration_constant,
+                parent_visits,
+                node_a.visits,
+                node_a.rewards(player),
+                node_a.reward_average(player),
+            );
+            let ucb_b = ucb1_tuned_value(
+                exploration_constant,
+                parent_visits,
+                node_b.visits,
+                node_b.rewards(player),
+                node_b.reward_average(player),
+            );
+            ucb_a.partial_cmp(&ucb_b).unwrap()
+        })
+        .unwrap();
+    *best_child
+}
+
+/// Calculates the UCB1-Tuned value for a node
+pub fn ucb1_tuned_value(
+    exploration_constant: f64,
+    parent_visits: u64,
+    child_visits: u64,
+    child_rewards: Option<&Vec<RewardVal>>,
+    reward_avg: RewardVal,
+) -> RewardVal {
+    match child_rewards {
+        None => {
+            RewardVal::INFINITY // Always explore nodes that have never been visited
+        }
+        Some(child_rewards) => {
+            if child_visits == 0 {
+                RewardVal::INFINITY // Always explore nodes that have never been visited
+            } else {
+                let parent_visits: RewardVal = parent_visits as RewardVal;
+                let child_visits: RewardVal = child_visits as RewardVal;
+
+                // N: number of visits to the parent node
+                // n: number of visits to the child node
+                // x_i: reward of the ith visit to the child node
+                // X: average reward of the child
+                // C: exploration constant
+                //
+                // UCB1-Tuned = X + C * sqrt(ln(N) / n * min(1/4, V(n)))
+                // V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
+                let exploitation = reward_avg;
+                let mut variance = (child_rewards.iter().map(|&x| x * x).sum::<RewardVal>()
+                    / child_visits)
+                    - (reward_avg * reward_avg)
+                    + (2.0 * parent_visits.ln() / child_visits).sqrt();
+                if variance > 0.25 {
+                    variance = 0.25;
+                }
+                let exploration =
+                    exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
+
+                exploitation + exploration
+            }
+        }
+    }
+}
diff --git a/src/policy/simulation/mod.rs b/src/policy/simulation/mod.rs
new file mode 100644
index 0000000..9f5b97e
--- /dev/null
+++ b/src/policy/simulation/mod.rs
@@ -0,0 +1,44 @@
+mod random;
+
+use crate::state::GameState;
+use crate::tree::arena::Arena;
+use crate::tree::node::{Node, RewardVal};
+use std::collections::HashMap;
+
+/// The simulation policy used in the MCTS simulation phase
+///
+/// This policy drives the game simulations while evaluating the tree. While
+/// a random policy works well, a game-specific policy can be provided as a
+/// custom policy.
+#[derive(Debug)]
+pub enum SimulationPolicy<S: GameState> {
+    /// Random simulation policy
+    ///
+    /// The sequential actions are selected randomly from the available actions
+    /// at each state until a terminal state is found.
+    Random,
+
+    /// Custom simulation policy
+    Custom(Box<dyn CustomSimulationPolicy<S>>),
+}
+
+/// Trait for an object implementing the simulation logic when exploring the MCTS
+/// search tree.
+pub trait CustomSimulationPolicy<S: GameState>: std::fmt::Debug {
+    /// Simulates the gameplay from the current node onward, returning the rewards
+    ///
+    /// This should simulate the game until a terminal node is reached, returning
+    /// the final reward for each player at the terminal node
+    fn simulate(&self, node: &Node<S>, arena: &Arena<S>) -> HashMap<S::Player, RewardVal>;
+}
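+
+// Illustrative sketch (not part of the original implementation): a deterministic
+// playout policy could implement the trait above, e.g. always taking the first
+// legal action:
+//
+//     #[derive(Debug)]
+//     struct FirstActionRollout;
+//
+//     impl<S: GameState> CustomSimulationPolicy<S> for FirstActionRollout {
+//         fn simulate(&self, node: &Node<S>, _arena: &Arena<S>) -> HashMap<S::Player, RewardVal> {
+//             let mut state = node.state.clone();
+//             while !state.is_terminal() {
+//                 let action = state.get_legal_actions().remove(0);
+//                 state = state.state_after_action(&action);
+//             }
+//             state.rewards_for_players()
+//         }
+//     }
+//
+// It would then be plugged in via `SimulationPolicy::Custom(Box::new(FirstActionRollout))`.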
+
+pub fn simulate_reward<S: GameState>(
+    policy: &SimulationPolicy<S>,
+    node: &Node<S>,
+    arena: &Arena<S>,
+) -> HashMap<S::Player, RewardVal> {
+    match policy {
+        SimulationPolicy::Random => random::simulate(node),
+        SimulationPolicy::Custom(custom_policy) => custom_policy.simulate(node, arena),
+    }
+}
diff --git a/src/policy/simulation/random.rs b/src/policy/simulation/random.rs
new file mode 100644
index 0000000..a47275d
--- /dev/null
+++ b/src/policy/simulation/random.rs
@@ -0,0 +1,18 @@
+//! Random play simulation policy
+//!
+//! Actions are chosen at random
+
+use crate::state::GameState;
+use crate::tree::node::{Node, RewardVal};
+use rand::prelude::SliceRandom;
+use std::collections::HashMap;
+
+pub fn simulate<S: GameState>(node: &Node<S>) -> HashMap<S::Player, RewardVal> {
+    let mut state: S = node.state.clone();
+    while !state.is_terminal() {
+        let legal_actions = state.get_legal_actions();
+        let action = legal_actions.choose(&mut rand::thread_rng()).unwrap();
+        state = state.state_after_action(action);
+    }
+    state.rewards_for_players()
+}
diff --git a/src/state.rs b/src/state.rs
new file mode 100644
index 0000000..047ed6e
--- /dev/null
+++ b/src/state.rs
@@ -0,0 +1,90 @@
+use crate::tree::node::RewardVal;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::hash::Hash;
+
+/// Trait for the game state used in MCTS
+///
+/// When leveraging MCTS for your game, you must implement this trait to provide
+/// the specifics for your game.
+pub trait GameState: Clone {
+    /// The type of actions that can be taken in the game
+    type Action: Action;
+
+    /// The type of players in the game
+    type Player: Player;
+
+    /// Returns whether the game state is terminal, i.e. the game is over
+    ///
+    /// A game state is terminal when no other actions are possible. This can be
+    /// the result of a player winning, a draw, or because some other conditions
+    /// have been met leading to a game with no further possible states.
+    ///
+    /// The default implementation returns `true` if `get_legal_actions()` returns
+    /// an empty list. It is recommended to override this with a more efficient
+    /// implementation if possible.
+    fn is_terminal(&self) -> bool {
+        let actions = self.get_legal_actions();
+        actions.len() == 0
+    }
+
+    /// Returns the list of legal actions for the game state
+    ///
+    /// This method must return all possible actions that can be made from the
+    /// current game state.
+    fn get_legal_actions(&self) -> Vec<Self::Action>;
+
+    /// Returns the game state resulting from applying the action to the state
+    ///
+    /// This function should not modify the current state directly, and
+    /// instead should modify a copy of the state and return that.
+    fn state_after_action(&self, action: &Self::Action) -> Self;
+
+    /// Returns the reward from the perspective of the given player for the game state
+    ///
+    /// This evaluates the current state from the perspective of the given player, and
+    /// returns the reward indicating how good of a result the given state is for the
+    /// player.
+    ///
+    /// This is used in the MCTS backpropagation and simulation phases to evaluate
+    /// the value of a given node in the search tree.
+    ///
+    /// A general rule of thumb for values is:
+    /// - 1.0 => a win for the player
+    /// - 0.5 => a draw
+    /// - 0.0 => a loss for the player
+    ///
+    /// Other values can be used for relative wins or losses
+    fn reward_for_player(&self, player: &Self::Player) -> RewardVal;
+
+    /// Returns the rewards for all players at the current state
+    fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal>;
+
+    /// Returns the player whose turn it is for the game state
+    ///
+    /// This is used for evaluating the state, so for simultaneous games
+    /// consider the "current player" as the one from whose perspective we are
+    /// evaluating the game state
+    fn get_current_player(&self) -> &Self::Player;
+}
+
+/// Trait used for actions that can be taken in a game
+///
+/// An action is dependent upon the specific game being defined, and includes
+/// things like moves, attacks, and other decisions.
+pub trait Action: Clone + Debug {
+    /// Returns a unique identifier for this action
+    fn id(&self) -> usize;
+}
+
+/// Trait used for players participating in a game
+pub trait Player: Clone + Debug + PartialEq + Eq + Hash {}
+
+/// Convenience implementation of a Player for usize
+impl Player for usize {}
+
+/// Convenience implementation of a Player for char
+impl Player for char {}
+
+/// Convenience implementation of a Player for String
+impl Player for String {}
diff --git a/src/tree/arena.rs b/src/tree/arena.rs
new file mode 100644
index 0000000..6dd1380
--- /dev/null
+++ b/src/tree/arena.rs
@@ -0,0 +1,46 @@
+use crate::state::GameState;
+use crate::tree::node::Node;
+
+/// An arena for Node allocation
+///
+/// We use an arena for node allocation to improve performance of our search.
+/// The memory is contiguous which allows for faster movement through the tree,
+/// as well as more efficient destruction as our MCTS search will destroy the
+/// entire tree at once.
+pub struct Arena<S: GameState> {
+    pub nodes: Vec<Node<S>>,
+}
+
+impl<S: GameState> Arena<S> {
+    /// Create a new Arena with the given initial capacity
+    ///
+    /// The arena creates a contiguous block. By reserving an initial capacity
+    /// that is sufficient to encapsulate a full search tree we can reduce the
+    /// number of reallocs that are required. This number is highly game
+    /// dependent.
+    pub fn new(initial_capacity: usize) -> Self {
+        Arena {
+            nodes: Vec::with_capacity(initial_capacity),
+        }
+    }
+
+    /// Adds a node to the Arena, returning its identifier
+    ///
+    /// This appends the node to the allocated Arena, and returns the node's
+    /// index in the arena, which is used as an identifier for later retrieval.
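+    ///
+    /// Illustrative sketch (not part of the original implementation), assuming a
+    /// `GameState` implementation `S` and an initial `state`:
+    ///
+    /// ```ignore
+    /// let mut arena: Arena<S> = Arena::new(1024);
+    /// let root_id = arena.add_node(Node::new(state, None, None));
+    /// assert_eq!(root_id, 0);
+    /// ```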
+    pub fn add_node(&mut self, node: Node<S>) -> usize {
+        let id = self.nodes.len();
+        self.nodes.push(node);
+        id
+    }
+
+    /// Retrieves a mutable reference to a Node in the Arena
+    pub fn get_node_mut(&mut self, id: usize) -> &mut Node<S> {
+        &mut self.nodes[id]
+    }
+
+    /// Retrieves a reference to a Node in the Arena
+    pub fn get_node(&self, id: usize) -> &Node<S> {
+        &self.nodes[id]
+    }
+}
diff --git a/src/tree/mod.rs b/src/tree/mod.rs
new file mode 100644
index 0000000..9c02d29
--- /dev/null
+++ b/src/tree/mod.rs
@@ -0,0 +1,2 @@
+pub mod arena;
+pub mod node;
diff --git a/src/tree/node.rs b/src/tree/node.rs
new file mode 100644
index 0000000..36522fd
--- /dev/null
+++ b/src/tree/node.rs
@@ -0,0 +1,114 @@
+use std::collections::HashMap;
+use std::fmt::Debug;
+
+use crate::state::GameState;
+
+/// The type used for reward values
+pub type RewardVal = f64;
+
+/// A node in the MCTS tree
+///
+/// A node represents a given game state and, using the path from the root node,
+/// the actions that led to the given state. A node has a number of children
+/// nodes representing the game states reachable from the given state, after
+/// a given action. This creates the tree that MCTS iterates through.
+///
+/// This type is not thread-safe, as the library does not provide for parallel
+/// search.
+#[derive(Debug)]
+pub struct Node<S: GameState> {
+    /// The game state at the given node, after `action`
+    pub state: S,
+
+    /// The action that led to this state from its parent
+    pub action: Option<S::Action>,
+
+    /// The identifier of the parent Node
+    pub parent: Option<usize>,
+
+    /// The number of times this node has been visited
+    pub visits: u64,
+
+    /// The player's evaluation of the node
+    pub player_view: HashMap<S::Player, PlayerNodeView>,
+
+    /// The identifiers of children nodes, states reachable from this one
+    pub children: Vec<usize>,
+}
+
+impl<S: GameState> Node<S> {
+    pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>) -> Self {
+        Node {
+            state,
+            action,
+            parent,
+            visits: 0,
+            player_view: HashMap::with_capacity(2),
+            children: Vec::new(),
+        }
+    }
+
+    pub fn is_leaf(&self) -> bool {
+        self.children.is_empty()
+    }
+
+    pub fn reward_sum(&self, player: &S::Player) -> RewardVal {
+        match self.player_view.get(player) {
+            Some(pv) => pv.reward_sum,
+            None => 0.0,
+        }
+    }
+
+    pub fn reward_average(&self, player: &S::Player) -> RewardVal {
+        match self.player_view.get(player) {
+            Some(pv) => pv.reward_average,
+            None => 0.0,
+        }
+    }
+
+    pub fn rewards(&self, player: &S::Player) -> Option<&Vec<RewardVal>> {
+        match self.player_view.get(player) {
+            Some(pv) => Some(&pv.rewards),
+            None => None,
+        }
+    }
+
+    pub fn increment_visits(&mut self) {
+        self.visits += 1
+    }
+
+    pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
+        let pv = self
+            .player_view
+            .entry(player)
+            .or_insert(PlayerNodeView::default());
+        pv.rewards.push(reward);
+        pv.reward_sum += reward;
+        pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
+    }
+}
+
+/// A player's specific perspective of a node's value
+///
+/// Each player has their own idea of the value of a node.
+#[derive(Debug)]
+pub struct PlayerNodeView {
+    /// The total reward from simulations through this node
+    pub reward_sum: RewardVal,
+
+    /// The average reward from simulations through this node, often called the node value
+    pub reward_average: RewardVal,
+
+    /// The rewards we have gotten so far for simulations through this node
+    pub rewards: Vec<RewardVal>,
+}
+
+impl Default for PlayerNodeView {
+    fn default() -> Self {
+        PlayerNodeView {
+            reward_sum: 0.0,
+            reward_average: 0.0,
+            rewards: Vec::new(),
+        }
+    }
+}
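+
+// Illustrative check (not part of the original implementation): after recording rewards
+// 1.0 and 0.0 for the same player, that player's view holds `reward_sum` 1.0,
+// `rewards.len()` 2, and `reward_average` 0.5.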