Working MCTS implementation

This is a basic working implementation of the MCTS algorithm. However, the
algorithm is currently slow compared with other implementations and makes
sub-optimal choices when playing tic-tac-toe, so some modifications are still
needed.
David Kruger 2025-06-23 13:46:04 -07:00
parent 197a46996a
commit 17884f4b90
18 changed files with 1416 additions and 0 deletions

Cargo.lock (generated, new file)
@@ -0,0 +1,154 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "cfg-if"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
[[package]]
name = "getrandom"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "libc"
version = "0.2.174"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
[[package]]
name = "ppv-lite86"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro2"
version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "rustic_mcts"
version = "0.1.0"
dependencies = [
"rand",
"thiserror",
]
[[package]]
name = "syn"
version = "2.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4307e30089d6fd6aff212f2da3a1f9e32f3223b1f010fb09b7c95f90f3ca1e8"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "thiserror"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "unicode-ident"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "wasi"
version = "0.11.1+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
[[package]]
name = "zerocopy"
version = "0.8.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

Cargo.toml (new file)
@@ -0,0 +1,16 @@
[package]
name = "rustic_mcts"
version = "0.1.0"
edition = "2021"
authors = ["David Kruger <david@krugerlabs.us>"]
description = "An extensible implementation of Monte Carlo Tree Search (MCTS) using an arena allocator."
license = "MIT"
repository = "https://gitlabs.krugerlabs.us/krugd/rustic_mcts"
readme = "README.md"
keywords = ["mcts", "rust", "monte_carlo", "tree", "ai", "ml"]
categories = ["algorithms", "data-structures"]
[dependencies]
rand = "~0.8"
thiserror = "~2.0"

examples/tic_tac_toe.rs (new file)
@@ -0,0 +1,298 @@
use std::collections::HashMap;
use std::fmt;
use std::io::{self, Write};
use rustic_mcts::policy::backprop::BackpropagationPolicy;
use rustic_mcts::policy::decision::DecisionPolicy;
use rustic_mcts::policy::selection::SelectionPolicy;
use rustic_mcts::policy::simulation::SimulationPolicy;
use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS};
fn main() {
println!("MCTS Tic-Tac-Toe Example");
println!("========================");
println!();
// Set up a new game
let mut game = TicTacToe::new();
// Create MCTS configuration
let config = MCTSConfig {
max_iterations: 10_000,
max_time: None,
tree_size_allocation: 10_000,
selection_policy: SelectionPolicy::UCB1Tuned(1.414),
simulation_policy: SimulationPolicy::Random,
backprop_policy: BackpropagationPolicy::Standard,
decision_policy: DecisionPolicy::MostVisits,
};
// Main game loop
while !game.is_terminal() {
// Display the board
println!("{}", game);
if game.current_player == Player::X {
// Human player (X)
println!("Your move (enter row column, e.g. '1 2'): ");
io::stdout().flush().unwrap();
let mut input = String::new();
io::stdin().read_line(&mut input).unwrap();
let coords: Vec<usize> = input
.trim()
.split_whitespace()
.filter_map(|s| s.parse::<usize>().ok())
.collect();
if coords.len() != 2 || coords[0] > 2 || coords[1] > 2 {
println!("Invalid move! Enter row and column (0-2).");
continue;
}
let row = coords[0];
let col = coords[1];
let move_index = row * 3 + col;
let action = Move { index: move_index };
if !game.is_legal_move(&action) {
println!("Illegal move! Try again.");
continue;
}
// Apply the human's move
game = game.state_after_action(&action);
} else {
// AI player (O)
println!("AI is thinking...");
// Create a new MCTS search
let mut mcts = MCTS::new(game.clone(), &config);
// Find the best move
match mcts.search() {
Ok(action) => {
println!(
"AI chooses: {} (row {}, col {})",
action.index,
action.index / 3,
action.index % 3
);
// Apply the AI's move
game = game.state_after_action(&action);
}
Err(e) => {
println!("Error: {:?}", e);
break;
}
}
}
}
// Display final state
println!("{}", game);
// Report the result
if let Some(winner) = game.get_winner() {
println!("Player {:?} wins!", winner);
} else {
println!("The game is a draw!");
}
}
/// Players in Tic-Tac-Toe
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Player {
X,
O,
}
impl rustic_mcts::Player for Player {}
/// Tic-Tac-Toe move
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Move {
/// Board position index (0-8)
index: usize,
}
impl Action for Move {
fn id(&self) -> usize {
self.index
}
}
/// Tic-Tac-Toe game state
#[derive(Clone)]
struct TicTacToe {
/// Board representation (None = empty, Some(Player) = occupied)
board: [Option<Player>; 9],
/// Current player's turn
current_player: Player,
/// Number of moves played so far
moves_played: usize,
}
impl TicTacToe {
/// Creates a new empty Tic-Tac-Toe board
fn new() -> Self {
TicTacToe {
board: [None; 9],
current_player: Player::X,
moves_played: 0,
}
}
/// Checks if a move is legal
fn is_legal_move(&self, action: &Move) -> bool {
if action.index >= 9 {
return false;
}
self.board[action.index].is_none()
}
/// Returns the winner of the game, if any
fn get_winner(&self) -> Option<Player> {
// Check rows
for row in 0..3 {
let i = row * 3;
if self.board[i].is_some()
&& self.board[i] == self.board[i + 1]
&& self.board[i] == self.board[i + 2]
{
return self.board[i];
}
}
// Check columns
for col in 0..3 {
if self.board[col].is_some()
&& self.board[col] == self.board[col + 3]
&& self.board[col] == self.board[col + 6]
{
return self.board[col];
}
}
// Check diagonals
if self.board[0].is_some()
&& self.board[0] == self.board[4]
&& self.board[0] == self.board[8]
{
return self.board[0];
}
if self.board[2].is_some()
&& self.board[2] == self.board[4]
&& self.board[2] == self.board[6]
{
return self.board[2];
}
None
}
}
impl GameState for TicTacToe {
type Action = Move;
type Player = Player;
fn get_legal_actions(&self) -> Vec<Self::Action> {
let mut actions = Vec::new();
for i in 0..9 {
if self.board[i].is_none() {
actions.push(Move { index: i });
}
}
actions
}
fn state_after_action(&self, action: &Self::Action) -> Self {
let mut new_state = self.clone();
// Make the move
new_state.board[action.index] = Some(self.current_player);
new_state.moves_played = self.moves_played + 1;
// Switch player
new_state.current_player = match self.current_player {
Player::X => Player::O,
Player::O => Player::X,
};
new_state
}
fn is_terminal(&self) -> bool {
self.get_winner().is_some() || self.moves_played == 9
}
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
if let Some(winner) = self.get_winner() {
if winner == *player {
return 1.0; // Win
} else {
return 0.0; // Loss
}
}
// Draw
0.5
}
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
HashMap::from_iter(vec![
(Player::X, self.reward_for_player(&Player::X)),
(Player::O, self.reward_for_player(&Player::O)),
])
}
fn get_current_player(&self) -> &Self::Player {
&self.current_player
}
}
impl fmt::Display for TicTacToe {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, " 0 1 2")?;
for row in 0..3 {
write!(f, "{} ", row)?;
for col in 0..3 {
let index = row * 3 + col;
let symbol = match self.board[index] {
Some(Player::X) => "X",
Some(Player::O) => "O",
None => ".",
};
write!(f, "{} ", symbol)?;
}
writeln!(f)?;
}
writeln!(f, "\nPlayer {:?}'s turn", self.current_player)?;
Ok(())
}
}
impl fmt::Debug for TicTacToe {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "\n")?;
for row in 0..3 {
for col in 0..3 {
let index = row * 3 + col;
let symbol = match self.board[index] {
Some(Player::X) => "X",
Some(Player::O) => "O",
None => ".",
};
write!(f, "{} ", symbol)?;
}
writeln!(f)?;
}
Ok(())
}
}

src/config.rs (new file)
@@ -0,0 +1,67 @@
use crate::policy::backprop::BackpropagationPolicy;
use crate::policy::decision::DecisionPolicy;
use crate::policy::selection::SelectionPolicy;
use crate::policy::simulation::SimulationPolicy;
use crate::state::GameState;
use std::time::Duration;
/// Configuration for the MCTS algorithm
#[derive(Debug)]
pub struct MCTSConfig<S: GameState> {
/// The maximum number of iterations to run when searching
///
/// The search will stop after the given number of iterations, even if the
/// search time has not exceeded `max_time`.
pub max_iterations: usize,
/// The maximum time to run the search
///
/// If set, the search will stop after this duration even if the maximum
/// iterations hasn't been reached.
pub max_time: Option<Duration>,
/// The size to initially allocate for the search tree
///
/// This pre-allocates memory for the search tree which ensures contiguous
/// memory and improves performance by preventing the resizing of tree
/// as we explore.
pub tree_size_allocation: usize,
/// The selection policy
///
/// This dictates the path through which the game tree is searched. As such
/// the policy has a large impact on the overall algorithm execution.
pub selection_policy: SelectionPolicy<S>,
/// The simulation policy
///
/// This dictates the game simulation used when expanding and evaluating the
/// search tree. Random is generally a good default.
pub simulation_policy: SimulationPolicy<S>,
/// The backpropagation policy
///
/// This dictates how the results of the simulation playouts are propagated
/// back up the tree.
pub backprop_policy: BackpropagationPolicy<S>,
/// The decision policy
///
/// This dictates how the MCTS algorithm determines its final decision
/// after iterating through the search tree
pub decision_policy: DecisionPolicy,
}
impl<S: GameState> Default for MCTSConfig<S> {
fn default() -> Self {
MCTSConfig {
max_iterations: 10_000,
max_time: None,
tree_size_allocation: 10_000,
selection_policy: SelectionPolicy::UCB1Tuned(1.414),
simulation_policy: SimulationPolicy::Random,
backprop_policy: BackpropagationPolicy::Standard,
decision_policy: DecisionPolicy::MostVisits,
}
}
}
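As a quick illustration of how these fields compose, here is a hypothetical snippet (not part of this commit) that overrides the search limits while keeping the default policies; `TicTacToe` stands in for any `GameState` implementation such as the one in examples/tic_tac_toe.rs:

```rust
use std::time::Duration;

use rustic_mcts::MCTSConfig;

// Hypothetical helper; `TicTacToe` is assumed to implement GameState.
fn make_config() -> MCTSConfig<TicTacToe> {
    MCTSConfig {
        // Stop at whichever limit is reached first.
        max_iterations: 50_000,
        max_time: Some(Duration::from_millis(500)),
        // Keep the default selection, simulation, backprop, and decision policies.
        ..MCTSConfig::default()
    }
}
```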

src/lib.rs (new file)
@@ -0,0 +1,17 @@
//! # rustic_mcts
//!
//! An extensible implementation of Monte Carlo Tree Search (MCTS) using arena allocation and
//! configurable policies.
pub mod config;
pub mod mcts;
pub mod policy;
pub mod state;
pub mod tree;
pub use config::MCTSConfig;
pub use mcts::MCTS;
pub use state::Action;
pub use state::GameState;
pub use state::Player;
pub use tree::node::RewardVal;

src/mcts.rs (new file)
@@ -0,0 +1,147 @@
use crate::config::MCTSConfig;
use crate::policy::backprop::backpropagate_rewards;
use crate::policy::decision::decide_on_action;
use crate::policy::selection::select_best_child;
use crate::policy::simulation::simulate_reward;
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::{Node, RewardVal};
use rand::prelude::SliceRandom;
use std::collections::HashMap;
use std::time::Instant;
/// Monte Carlo Tree Search implementation
///
/// This provides the interface for performing optimal searches on a tree using
/// the MCTS algorithm.
pub struct MCTS<'conf, S: GameState> {
/// The arena used for the tree
arena: Arena<S>,
/// The identifier of the root node of the search tree
root_id: usize,
/// The configuration used for the search
config: &'conf MCTSConfig<S>,
}
impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
/// Creates a new instance with the given initial state and configuration
pub fn new(initial_state: S, config: &'conf MCTSConfig<S>) -> Self {
let mut arena: Arena<S> = Arena::new(config.tree_size_allocation);
let root: Node<S> = Node::new(initial_state.clone(), None, None);
let root_id: usize = arena.add_node(root);
MCTS {
arena,
root_id,
config,
}
}
/// Runs the MCTS algorithm, returning the "best" action
///
/// The search will stop once `max_iterations` or `max_time` from
/// the assigned configuration is reached.
pub fn search(&mut self) -> Result<S::Action> {
self.search_for_iterations(self.config.max_iterations)
}
/// Runs the MCTS algorithm, returning the "best" action after the given iterations
///
/// This ignores the `max_iterations` provided in the config, but will still
/// return early if `max_time` is specified and reached before the iterations are complete.
pub fn search_for_iterations(&mut self, iterations: usize) -> Result<S::Action> {
let start_time = Instant::now();
for _ in 0..iterations {
match self.config.max_time {
Some(max_time) => {
if start_time.elapsed() >= max_time {
break; // ending early due to time
}
}
None => {}
}
self.execute_iteration()?;
}
self.best_action()
}
/// Runs the MCTS algorithm for a single iteration
fn execute_iteration(&mut self) -> Result<()> {
let mut selected_id: usize = self.select();
let selected_node: &Node<S> = self.arena.get_node(selected_id);
if !selected_node.state.is_terminal() {
self.expand(selected_id);
let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
let random_child: usize = *children.choose(&mut rand::thread_rng()).unwrap();
selected_id = random_child;
}
let rewards = self.simulate(selected_id);
self.backprop(selected_id, &rewards);
Ok(())
}
/// MCTS Phase 1: Selection - Find the "best" node to expand
fn select(&mut self) -> usize {
let mut current_id: usize = self.root_id;
loop {
let node = &self.arena.get_node(current_id);
if node.is_leaf() || node.state.is_terminal() {
return current_id;
}
current_id = select_best_child(&self.config.selection_policy, &node, &self.arena);
}
}
/// MCTS Phase 2: Expansion - Expand the selected node on the tree
fn expand(&mut self, id: usize) {
let parent: &Node<S> = self.arena.get_node_mut(id);
let legal_actions: Vec<S::Action> = parent.state.get_legal_actions();
let parent_state: S = parent.state.clone();
for action in legal_actions {
let state = parent_state.state_after_action(&action);
let new_node = Node::new(state, Some(action), Some(id));
let new_id = self.arena.add_node(new_node);
self.arena.get_node_mut(id).children.push(new_id);
}
}
fn simulate(&self, id: usize) -> HashMap<S::Player, RewardVal> {
let node = &self.arena.get_node(id);
simulate_reward(&self.config.simulation_policy, &node, &self.arena)
}
fn backprop(&mut self, selected_id: usize, rewards: &HashMap<S::Player, RewardVal>) {
backpropagate_rewards(
&self.config.backprop_policy,
selected_id,
&mut self.arena,
&rewards,
)
}
fn best_action(&self) -> Result<S::Action> {
let root_node: &Node<S> = self.arena.get_node(self.root_id);
match decide_on_action(&self.config.decision_policy, &root_node, &self.arena) {
Some(action) => Ok(action),
None => Err(MCTSError::NoBestAction),
}
}
}
/// Errors returned by the MCTS algorithm
#[derive(Debug, thiserror::Error)]
pub enum MCTSError {
/// The best action doesn't exist
#[error("Unable to determine a best action for the game")]
NoBestAction,
/// The search tree was exhausted without finding a terminal node
#[error("Search tree exhausted without finding terminal node")]
NonTerminalGame,
}
/// Result returned by the MCTS algorithm
pub type Result<T> = std::result::Result<T, MCTSError>;
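A minimal usage sketch of the API above (hypothetical, not part of the diff); `TicTacToe` and `Move` are assumed to be the types from examples/tic_tac_toe.rs:

```rust
use rustic_mcts::mcts::MCTSError;
use rustic_mcts::{MCTSConfig, MCTS};

// Hypothetical helper; `TicTacToe`/`Move` come from examples/tic_tac_toe.rs.
fn choose_move(game: TicTacToe) -> Result<Move, MCTSError> {
    let config = MCTSConfig::default();
    let mut mcts = MCTS::new(game, &config);
    // Run fewer iterations than the configured maximum; `max_time` still applies.
    mcts.search_for_iterations(1_000)
}
```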

@@ -0,0 +1,99 @@
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::RewardVal;
use std::collections::HashMap;
/// The back propagation policy dictating the propagation of playout results
///
/// This policy drives how the backpropagation phase of the MCTS algorithm is
/// executed, allowing for some minor customization.
///
/// Typically the Standard policy, used by most implementations of MCTS, is
/// sufficient.
#[derive(Debug)]
pub enum BackpropagationPolicy<S: GameState> {
/// Standard back propagation
///
/// This increments the visitation count and adds the simulated rewards
/// results to the aggregate values.
///
/// This is the standard policy used in most MCTS implementations.
Standard,
/// Weighted back propagation
///
/// This weights the value of the simulated rewards based on the depth,
/// allowing us to put more-or-less influence on deeper branches
/// - Positive weight factor makes deeper nodes less influential
/// - Negative weight factor makes deeper nodes more influential
Weighted(f64),
/// Custom backpropagation policy
Custom(Box<dyn CustomBackpropagationPolicy<S>>),
}
/// Trait for an object implementing the backpropagation logic when exploring the MCTS
/// search tree.
pub trait CustomBackpropagationPolicy<S: GameState>: std::fmt::Debug {
/// Backpropagate the given rewards values from the node up the tree
fn backprop(
&self,
node_id: usize,
arena: &mut Arena<S>,
rewards: &HashMap<S::Player, RewardVal>,
);
}
pub fn backpropagate_rewards<S: GameState>(
policy: &BackpropagationPolicy<S>,
node_id: usize,
arena: &mut Arena<S>,
rewards: &HashMap<S::Player, RewardVal>,
) {
match policy {
BackpropagationPolicy::Standard => standard_backprop(node_id, arena, rewards),
BackpropagationPolicy::Weighted(depth_factor) => {
weighted_backprop(*depth_factor, node_id, arena, rewards)
}
BackpropagationPolicy::Custom(custom_policy) => {
custom_policy.backprop(node_id, arena, rewards)
}
}
}
fn standard_backprop<S: GameState>(
node_id: usize,
arena: &mut Arena<S>,
rewards: &HashMap<S::Player, RewardVal>,
) {
// TODO:
// - each node needs the perspective of the different players not just one view
// - e.g. reward_sum(player), reward_avg(player), rewards(player)[], visits(player)
// - we could make special version for 2-player zero-sum games like below
let mut current_id: usize = node_id;
loop {
let node = arena.get_node_mut(current_id);
let player = node.state.get_current_player().clone();
if let Some(reward) = rewards.get(&player) {
node.increment_visits();
node.record_player_reward(player, *reward);
}
// Walk up to the parent regardless of whether a reward was recorded for the
// current player, so the loop always terminates once the root is reached.
match node.parent {
Some(parent_id) => current_id = parent_id,
None => break,
}
}
}
fn weighted_backprop<S: GameState>(
_depth_factor: f64,
_node_id: usize,
_arena: &mut Arena<S>,
_rewards: &HashMap<S::Player, RewardVal>,
) {
// TODO
}
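Since `weighted_backprop` is still a stub, here is a hypothetical sketch of how the same idea could be expressed today through the `Custom` variant. It discounts rewards by `1 / (1 + depth_factor * depth)`, which is one possible reading of the doc comment above, not the crate's implementation:

```rust
use std::collections::HashMap;

use rustic_mcts::policy::backprop::CustomBackpropagationPolicy;
use rustic_mcts::tree::arena::Arena;
use rustic_mcts::{GameState, RewardVal};

// Hypothetical custom policy: a positive depth_factor shrinks the influence
// of deeper nodes, a negative one amplifies it.
#[derive(Debug)]
struct DepthDiscounted {
    depth_factor: f64,
}

impl<S: GameState> CustomBackpropagationPolicy<S> for DepthDiscounted {
    fn backprop(
        &self,
        node_id: usize,
        arena: &mut Arena<S>,
        rewards: &HashMap<S::Player, RewardVal>,
    ) {
        // Measure the depth of the starting node by walking up to the root.
        let mut depth = 0usize;
        let mut cursor = node_id;
        while let Some(parent_id) = arena.get_node(cursor).parent {
            depth += 1;
            cursor = parent_id;
        }
        // Walk up again, recording progressively less-discounted rewards.
        let mut current_id = node_id;
        loop {
            let weight = 1.0 / (1.0 + self.depth_factor * depth as f64);
            let node = arena.get_node_mut(current_id);
            let player = node.state.get_current_player().clone();
            if let Some(reward) = rewards.get(&player) {
                node.increment_visits();
                node.record_player_reward(player, *reward * weight);
            }
            match node.parent {
                Some(parent_id) => {
                    current_id = parent_id;
                    depth = depth.saturating_sub(1);
                }
                None => break,
            }
        }
    }
}
```

Such a policy would be plugged in as `BackpropagationPolicy::Custom(Box::new(DepthDiscounted { depth_factor: 0.1 }))`.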

@@ -0,0 +1,65 @@
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::Node;
/// The decision policy when determining the action in final MCTS phase
///
/// This policy drives how the MCTS algorithm chooses which action is the
/// "best" from the exploration.
#[derive(Debug)]
pub enum DecisionPolicy {
/// Decide on the action with the most visits
///
/// This option relies on the statistical confidence driven by the MCTS
/// algorithm instead of the potentially more noisy value estimates.
///
/// This is the standard policy used in most MCTS implementations, and
/// is a good selection when not hyper-maximizing for potential gain
MostVisits,
/// Decide on the action with the highest average value
///
/// This is non-standard, but is more aggressive in attempting to gain
/// the highest value in a decision.
HighestValue,
}
pub fn decide_on_action<S: GameState>(
policy: &DecisionPolicy,
root_node: &Node<S>,
arena: &Arena<S>,
) -> Option<S::Action> {
match policy {
DecisionPolicy::MostVisits => most_visits(root_node, arena),
DecisionPolicy::HighestValue => highest_value(root_node, arena),
}
}
fn most_visits<S: GameState>(root_node: &Node<S>, arena: &Arena<S>) -> Option<S::Action> {
let best_child_id: &usize = root_node
.children
.iter()
.max_by(|&a, &b| {
let node_a_visits = arena.get_node(*a).visits;
let node_b_visits = arena.get_node(*b).visits;
node_a_visits.partial_cmp(&node_b_visits).unwrap()
})
.unwrap();
arena.get_node(*best_child_id).action.clone()
}
fn highest_value<S: GameState>(root_node: &Node<S>, arena: &Arena<S>) -> Option<S::Action> {
let player = root_node.state.get_current_player();
let best_child_id: &usize = root_node
.children
.iter()
.max_by(|&a, &b| {
let node_a_score = arena.get_node(*a).reward_average(player);
let node_b_score = arena.get_node(*b).reward_average(player);
node_a_score.partial_cmp(&node_b_score).unwrap()
})
.unwrap();
arena.get_node(*best_child_id).action.clone()
}

src/policy/mod.rs (new file)
@@ -0,0 +1,4 @@
pub mod backprop;
pub mod decision;
pub mod selection;
pub mod simulation;

@@ -0,0 +1,59 @@
mod ucb1;
mod ucb1_tuned;
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::Node;
/// The selection policy used in the MCTS selection phase
///
/// This drives the selection of the nodes in the search tree, determining
/// which paths are explored and evaluated.
///
/// In general UCB1-Tuned or UCB1 should be effective; however, if necessary
/// a custom selection policy can be provided.
#[derive(Debug)]
pub enum SelectionPolicy<S: GameState> {
/// Upper Confidence Bound 1 (UCB1) with the given exploration constant
///
/// The exploration constant controls the balance between exploration and
/// exploitation. The higher the value, the more likely the search will
/// explore less-visited nodes. A standard value is √2 ≈ 1.414.
UCB1(f64),
/// Upper Confidence Bound 1 Tuned (UCB1-Tuned)
///
/// A tuned version of UCB1 instead using the empirical
/// standard deviation of the rewards to drive exploration.
///
/// Auer, P., Cesa-Bianchi, N. & Fischer, P. Finite-time Analysis of the Multiarmed Bandit Problem. Machine Learning 47, 235–256 (2002). https://doi.org/10.1023/A:1013689704352
UCB1Tuned(f64),
/// Custom selection policy
Custom(Box<dyn CustomSelectionPolicy<S>>),
}
/// Trait for an object implementing the selection logic when exploring the MCTS
/// search tree.
///
/// The policy should select the child of the given node which is "best" for the current player
pub trait CustomSelectionPolicy<S: GameState>: std::fmt::Debug {
/// Selects a child based on the policy, returning the node ID
fn select_child(&self, node: &Node<S>, arena: &Arena<S>) -> usize;
}
pub fn select_best_child<S: GameState>(
policy: &SelectionPolicy<S>,
node: &Node<S>,
arena: &Arena<S>,
) -> usize {
match policy {
SelectionPolicy::UCB1(exploration_constant) => {
ucb1::select_best_child(*exploration_constant, node, arena)
}
SelectionPolicy::UCB1Tuned(exploration_constant) => {
ucb1_tuned::select_best_child(*exploration_constant, node, arena)
}
SelectionPolicy::Custom(custom_policy) => custom_policy.select_child(node, arena),
}
}

@@ -0,0 +1,79 @@
//! Upper Confidence Bound 1 (UCB1) selection policy
//!
//! This is the classic selection policy for MCTS, which balances
//! exploration and exploitation using the UCB1 formula:
//!
//! ```text
//! UCB1 = average_reward + exploration_constant * sqrt(ln(parent_visits) / child_visits)
//! ```
//!
//! Where:
//! - `average_reward` is the average reward from simulations through this node
//! - `exploration_constant` controls the balance between exploration and exploitation
//! - `parent_visits` is the number of visits to the parent node
//! - `child_visits` is the number of visits to the child node
//!
//! Higher exploration constants favor exploration (trying less-visited nodes),
//! while lower values favor exploitation (choosing nodes with higher values).
//!
//! The commonly used value for the exploration constant is sqrt(2) ≈ 1.414,
//! which is the default in this implementation.
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::{Node, RewardVal};
/// Selects the index of the "best" child using the UCB1 selection policy
pub fn select_best_child<S: GameState>(
exploration_constant: f64,
node: &Node<S>,
arena: &Arena<S>,
) -> usize {
if node.is_leaf() {
panic!("select_best_child called on leaf node");
}
let player = node.state.get_current_player();
let parent_visits = node.visits;
let best_child = node
.children
.iter()
.max_by(|&a, &b| {
let node_a = arena.get_node(*a);
let node_b = arena.get_node(*b);
let ucb_a = ucb1_value(
exploration_constant,
node_a.reward_average(player),
node_a.visits,
parent_visits,
);
let ucb_b = ucb1_value(
exploration_constant,
node_b.reward_average(player),
node_b.visits,
parent_visits,
);
ucb_a.partial_cmp(&ucb_b).unwrap()
})
.unwrap();
*best_child
}
/// Calculates the UCB1 value for a node
pub fn ucb1_value(
exploration_constant: f64,
child_value: RewardVal,
child_visits: u64,
parent_visits: u64,
) -> RewardVal {
if child_visits == 0 {
return f64::INFINITY; // Always explore nodes that have never been visited
}
// UCB1 formula: value + C * sqrt(ln(parent_visits) / child_visits)
let exploitation = child_value;
let exploration =
exploration_constant * ((parent_visits as f64).ln() / child_visits as f64).sqrt();
exploitation + exploration
}
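As a sanity check on the formula, a unit test along these lines could be added to this module (hypothetical, not part of the commit); the expected value is just the hand computation 0.6 + 1.414 * sqrt(ln(100) / 10) ≈ 1.56:

```rust
#[cfg(test)]
mod tests {
    use super::ucb1_value;

    #[test]
    fn ucb1_matches_hand_computation() {
        // 0.6 + 1.414 * sqrt(ln(100) / 10) ≈ 1.5596
        let v = ucb1_value(1.414, 0.6, 10, 100);
        assert!((v - 1.5596).abs() < 1e-3);
        // Unvisited children always get an infinite score so they are explored first.
        assert_eq!(ucb1_value(1.414, 0.0, 0, 100), f64::INFINITY);
    }
}
```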

@@ -0,0 +1,97 @@
//! Upper Confidence Bound 1 Tuned (UCB1-Tuned) selection policy
//!
//! This is a fine-tuned version of UCB which takes into account the
//! empirically measured variance of the rewards to drive the exploration.
//!
//! This has been found to perform substantially better than UCB1 in most
//! situations.
//!
//! Auer, P., Cesa-Bianchi, N. & Fischer, P.
//! Finite-time Analysis of the Multiarmed Bandit Problem.
//! Machine Learning 47, 235–256 (2002). https://doi.org/10.1023/A:1013689704352
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::{Node, RewardVal};
/// Selects the index of the "best" child using the UCB1-Tuned selection policy
pub fn select_best_child<S: GameState>(
exploration_constant: f64,
node: &Node<S>,
arena: &Arena<S>,
) -> usize {
if node.is_leaf() {
panic!("select_best_child called on leaf node");
}
let player = node.state.get_current_player();
let parent_visits = node.visits;
let best_child = node
.children
.iter()
.max_by(|&a, &b| {
let node_a = arena.get_node(*a);
let node_b = arena.get_node(*b);
let ucb_a = ucb1_tuned_value(
exploration_constant,
parent_visits,
node_a.visits,
node_a.rewards(player),
node_a.reward_average(player),
);
let ucb_b = ucb1_tuned_value(
exploration_constant,
parent_visits,
node_b.visits,
node_b.rewards(player),
node_b.reward_average(player),
);
ucb_a.partial_cmp(&ucb_b).unwrap()
})
.unwrap();
*best_child
}
/// Calculates the UCB1-Tuned value for a node
pub fn ucb1_tuned_value(
exploration_constant: f64,
parent_visits: u64,
child_visits: u64,
child_rewards: Option<&Vec<RewardVal>>,
reward_avg: RewardVal,
) -> RewardVal {
match child_rewards {
None => {
RewardVal::INFINITY // Always explore nodes that have never been visited
}
Some(child_rewards) => {
if child_visits == 0 {
RewardVal::INFINITY // Always explore nodes that have never been visited
} else {
let parent_visits: RewardVal = parent_visits as RewardVal;
let child_visits: RewardVal = child_visits as RewardVal;
// N: number of visits to the parent node
// n: number of visits to the child node
// x_i: reward of the ith visit to the child node
// X: average reward of the child
// C: exploration constant
//
// UCB1-Tuned = X + C * sqrt(ln(N) / n * min(1/4, V(n)))
// V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
let exploitation = reward_avg;
let mut variance = (child_rewards.iter().map(|&x| x * x).sum::<RewardVal>()
/ child_visits)
- (reward_avg * reward_avg)
+ (2.0 * parent_visits.ln() / child_visits).sqrt();
if variance > 0.25 {
variance = 0.25;
}
let exploration =
exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
exploitation + exploration
}
}
}
}
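Similarly, a hypothetical test pinning the formula to a hand computation: with rewards [1, 0, 1, 1], 20 parent visits, and C = 1.414, the variance term exceeds 1/4 and is capped, giving 0.75 + 1.414 * sqrt(ln(20) / 4 * 0.25) ≈ 1.362.

```rust
#[cfg(test)]
mod tests {
    use super::ucb1_tuned_value;

    #[test]
    fn ucb1_tuned_matches_hand_computation() {
        // Average reward 0.75 over four playouts, variance term capped at 1/4.
        let rewards = vec![1.0, 0.0, 1.0, 1.0];
        let v = ucb1_tuned_value(1.414, 20, 4, Some(&rewards), 0.75);
        assert!((v - 1.3618).abs() < 1e-3);
        // A node with no recorded rewards is always explored first.
        assert_eq!(ucb1_tuned_value(1.414, 20, 0, None, 0.0), f64::INFINITY);
    }
}
```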

@@ -0,0 +1,44 @@
mod random;
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::{Node, RewardVal};
use std::collections::HashMap;
/// The simulation policy used in the MCTS simulation phase
///
/// This policy drives the game simulations while evaluating the tree. While
/// a random policy generally works well, a game-specific policy can be provided
/// as a custom policy.
#[derive(Debug)]
pub enum SimulationPolicy<S: GameState> {
/// Random simulation policy
///
/// The sequential actions are selected randomly from the available actions
/// at each state until a terminal state is found.
Random,
/// Custom simulation policy
Custom(Box<dyn CustomSimulationPolicy<S>>),
}
/// Trait for an object implementing the simulation logic when exploring the MCTS
/// search tree.
pub trait CustomSimulationPolicy<S: GameState>: std::fmt::Debug {
/// Simulates the gameplay from the current node onward, returning the rewards
///
/// This should simulate the game until a terminal node is reached, returning
/// the final reward for each player at the terminal node
fn simulate(&self, node: &Node<S>, arena: &Arena<S>) -> HashMap<S::Player, RewardVal>;
}
pub fn simulate_reward<S: GameState>(
policy: &SimulationPolicy<S>,
node: &Node<S>,
arena: &Arena<S>,
) -> HashMap<S::Player, RewardVal> {
match policy {
SimulationPolicy::Random => random::simulate(node),
SimulationPolicy::Custom(custom_policy) => custom_policy.simulate(node, arena),
}
}
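For illustration, a hypothetical game-agnostic `Custom` policy that takes an immediately winning move when one exists and otherwise plays randomly; it assumes the reward convention described in state.rs, where 1.0 denotes a win:

```rust
use std::collections::HashMap;

use rand::prelude::SliceRandom;
use rustic_mcts::policy::simulation::CustomSimulationPolicy;
use rustic_mcts::tree::arena::Arena;
use rustic_mcts::tree::node::{Node, RewardVal};
use rustic_mcts::GameState;

// Hypothetical "win-if-you-can, otherwise random" rollout policy.
#[derive(Debug)]
struct GreedyRollout;

impl<S: GameState> CustomSimulationPolicy<S> for GreedyRollout {
    fn simulate(&self, node: &Node<S>, _arena: &Arena<S>) -> HashMap<S::Player, RewardVal> {
        let mut state = node.state.clone();
        while !state.is_terminal() {
            let player = state.get_current_player().clone();
            let actions = state.get_legal_actions();
            // Prefer an action that ends the game as a win for the player to move.
            let winning = actions.iter().find(|a| {
                let next = state.state_after_action(a);
                next.is_terminal() && next.reward_for_player(&player) >= 1.0
            });
            let action = winning
                .or_else(|| actions.choose(&mut rand::thread_rng()))
                .expect("non-terminal states must have at least one legal action");
            state = state.state_after_action(action);
        }
        state.rewards_for_players()
    }
}
```

It would then be selected via `SimulationPolicy::Custom(Box::new(GreedyRollout))`.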

@@ -0,0 +1,18 @@
//! Random play simulation policy
//!
//! Actions are chosen at random
use crate::state::GameState;
use crate::tree::node::{Node, RewardVal};
use rand::prelude::SliceRandom;
use std::collections::HashMap;
pub fn simulate<S: GameState>(node: &Node<S>) -> HashMap<S::Player, RewardVal> {
let mut state: S = node.state.clone();
while !state.is_terminal() {
let legal_actions = state.get_legal_actions();
let action = legal_actions.choose(&mut rand::thread_rng()).unwrap();
state = state.state_after_action(&action);
}
state.rewards_for_players()
}

src/state.rs (new file)
@@ -0,0 +1,90 @@
use crate::tree::node::RewardVal;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
/// Trait for the game state used in MCTS
///
/// When leveraging MCTS for your game, you must implement this trait to provide
/// the specifics for your game.
pub trait GameState: Clone {
/// The type of actions that can be taken in the game
type Action: Action;
/// The type of players in the game
type Player: Player;
/// Returns if the game state is terminal, i.e. the game is over
///
/// A game state is terminal when no other actions are possible. This can be
/// the result of a player winning, a draw, or because some other conditions
/// have been met leading to a game with no further possible states.
///
/// The default implementation returns true if `get_legal_actions()` returns
/// an empty list. It is recommended to override this for a more efficient
/// implementation if possible.
fn is_terminal(&self) -> bool {
let actions = self.get_legal_actions();
actions.is_empty()
}
/// Returns the list of legal actions for the game state
///
/// This method must return all possible actions that can be made from the
/// current game state.
fn get_legal_actions(&self) -> Vec<Self::Action>;
/// Returns the game state resulting from applying the action to the state
///
/// This function should not modify the current state directly, and
/// instead should modify a copy of the state and return that.
fn state_after_action(&self, action: &Self::Action) -> Self;
/// Returns the reward from the perspective of the given player for the game state
///
/// This evaluates the current state from the perspective of the given player, and
/// returns the reward indicating how good of a result the given state is for the
/// player.
///
/// This is used in the MCTS backpropagation and simulation phases to evaluate
/// the value of a given node in the search tree.
///
/// A general rule of thumb for values is:
/// - 1.0 => a win for the player
/// - 0.5 => a draw
/// - 0.0 => a loss for the player
///
/// Other values can be used for relative wins or losses
fn reward_for_player(&self, player: &Self::Player) -> RewardVal;
/// Returns the rewards for all players at the current state
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal>;
/// Returns the player whose turn it is for the game state
///
/// This is used for evaluating the state, so for simultaneous games
/// consider the "current player" as the one from whose perspective we are
/// evaluating the game state.
fn get_current_player(&self) -> &Self::Player;
}
/// Trait used for actions that can be taken in a game
///
/// An action is dependent upon the specific game being defined, and includes
/// things like moves, attacks, and other decisions.
pub trait Action: Clone + Debug {
/// Returns a unique identifier for this action
fn id(&self) -> usize;
}
/// Trait used for players participating in a game
pub trait Player: Clone + Debug + PartialEq + Eq + Hash {}
/// Convenience implementation of a Player for usize
impl Player for usize {}
/// Convenience implementation of a Player for char
impl Player for char {}
/// Convenience implementation of a Player for String
impl Player for String {}

src/tree/arena.rs (new file)
@@ -0,0 +1,46 @@
use crate::state::GameState;
use crate::tree::node::Node;
/// An arena for Node allocation
///
/// We use an arena for node allocation to improve performance of our search.
/// The memory is contiguous which allows for faster movement through the tree,
/// as well as more efficient destruction as our MCTS search will destroy the
/// entire tree at once.
pub struct Arena<S: GameState> {
pub nodes: Vec<Node<S>>,
}
impl<S: GameState> Arena<S> {
/// Create a new Arena with the given initial capacity
///
/// The arena creates a contiguous block. By reserving an initial capacity
/// that is sufficient to encapsulate a full search tree we can reduce the
/// number of reallocs that are required. This number is highly game
/// dependent.
pub fn new(initial_capacity: usize) -> Self {
Arena {
nodes: Vec::with_capacity(initial_capacity),
}
}
/// Adds a node to the Arena, returning its identifier
///
/// This appends the node to the allocated Arena, and returns the nodes
/// index in the arena which is used as an identifier for later retrieval.
pub fn add_node(&mut self, node: Node<S>) -> usize {
let id = self.nodes.len();
self.nodes.push(node);
id
}
/// Retrieves a mutable reference to a Node in the Arena
pub fn get_node_mut(&mut self, id: usize) -> &mut Node<S> {
&mut self.nodes[id]
}
/// Retrieves a reference to a Node in the Arena
pub fn get_node(&self, id: usize) -> &Node<S> {
&self.nodes[id]
}
}
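A small, hypothetical round trip showing that `add_node` returns the index later used by the getters; `TicTacToe` again stands in for any `GameState` implementation:

```rust
use rustic_mcts::tree::arena::Arena;
use rustic_mcts::tree::node::Node;

// Hypothetical demo; `TicTacToe` is assumed to implement GameState.
fn arena_round_trip() {
    let mut arena: Arena<TicTacToe> = Arena::new(1024);
    // add_node returns the index that identifies the node from here on.
    let root_id = arena.add_node(Node::new(TicTacToe::new(), None, None));
    arena.get_node_mut(root_id).increment_visits();
    assert_eq!(arena.get_node(root_id).visits, 1);
}
```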

src/tree/mod.rs (new file)
@@ -0,0 +1,2 @@
pub mod arena;
pub mod node;

src/tree/node.rs (new file)
@@ -0,0 +1,114 @@
use std::collections::HashMap;
use std::fmt::Debug;
use crate::state::GameState;
/// The type used for reward values
pub type RewardVal = f64;
/// A node in the MCTS tree
///
/// A node represents a given game state and, using the path from the root node,
/// the actions that led to the given state. A node has a number of children
/// nodes representing the game states reachable from the given state, after
/// a given action. This creates the tree that MCTS iterates through.
///
/// This struct is not thread-safe, as the library does not provide for parallel
/// search.
#[derive(Debug)]
pub struct Node<S: GameState> {
/// The game state at the given node, after `action`
pub state: S,
/// The action that led to this state from its parent
pub action: Option<S::Action>,
/// The identifier of the parent Node
pub parent: Option<usize>,
/// The number of times this node has been visited
pub visits: u64,
/// The player's evaluation of the node
pub player_view: HashMap<S::Player, PlayerNodeView>,
/// The identifiers of children nodes, states reachable from this one
pub children: Vec<usize>,
}
impl<S: GameState> Node<S> {
pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>) -> Self {
Node {
state,
action,
parent,
visits: 0,
player_view: HashMap::with_capacity(2),
children: Vec::new(),
}
}
pub fn is_leaf(&self) -> bool {
self.children.is_empty()
}
pub fn reward_sum(&self, player: &S::Player) -> RewardVal {
match self.player_view.get(player) {
Some(pv) => pv.reward_sum,
None => 0.0,
}
}
pub fn reward_average(&self, player: &S::Player) -> RewardVal {
match self.player_view.get(player) {
Some(pv) => pv.reward_average,
None => 0.0,
}
}
pub fn rewards(&self, player: &S::Player) -> Option<&Vec<RewardVal>> {
match self.player_view.get(player) {
Some(pv) => Some(&pv.rewards),
None => None,
}
}
pub fn increment_visits(&mut self) {
self.visits += 1
}
pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
let pv = self
.player_view
.entry(player)
.or_insert(PlayerNodeView::default());
pv.rewards.push(reward);
pv.reward_sum += reward;
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
}
}
/// A player's specific perspective of a node's value
///
/// Each player has their own idea of the value of a node.
#[derive(Debug)]
pub struct PlayerNodeView {
/// The total reward from simulations through this node
pub reward_sum: RewardVal,
/// The average reward from simulations through this node, often called the node value
pub reward_average: RewardVal,
/// The rewards we have gotten so far for simulations through this node
pub rewards: Vec<RewardVal>,
}
impl Default for PlayerNodeView {
fn default() -> Self {
PlayerNodeView {
reward_sum: 0.0,
reward_average: 0.0,
rewards: Vec::new(),
}
}
}