Working MCTS implementation
This is a basic working implementation of the MCTS algorithm. It is currently slow compared with other implementations and makes sub-optimal choices when playing tic-tac-toe, so further modifications are needed.
This commit is contained in:
parent
197a46996a
commit
17884f4b90
154
Cargo.lock
generated
Normal file
@ -0,0 +1,154 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.174"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
|
||||
dependencies = [
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.95"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustic_mcts"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"rand",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.103"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4307e30089d6fd6aff212f2da3a1f9e32f3223b1f010fb09b7c95f90f3ca1e8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "2.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.1+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
16
Cargo.toml
Normal file
@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "rustic_mcts"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["David Kruger <david@krugerlabs.us>"]
|
||||
description = "An extensible implementation of Monte Carlo Tree Search (MCTS) using an arena allocator."
|
||||
license = "MIT"
|
||||
repository = "https://gitlabs.krugerlabs.us/krugd/rustic_mcts"
|
||||
readme = "README.md"
|
||||
keywords = ["mcts", "rust", "monte_carlo", "tree", "ai", "ml"]
|
||||
categories = ["algorithms", "data-structures"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
rand = "~0.8"
|
||||
thiserror = "~2.0"
|
298
examples/tic_tac_toe.rs
Normal file
@ -0,0 +1,298 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::io::{self, Write};
|
||||
|
||||
use rustic_mcts::policy::backprop::BackpropagationPolicy;
|
||||
use rustic_mcts::policy::decision::DecisionPolicy;
|
||||
use rustic_mcts::policy::selection::SelectionPolicy;
|
||||
use rustic_mcts::policy::simulation::SimulationPolicy;
|
||||
use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS};
|
||||
|
||||
fn main() {
|
||||
println!("MCTS Tic-Tac-Toe Example");
|
||||
println!("========================");
|
||||
println!();
|
||||
|
||||
// Set up a new game
|
||||
let mut game = TicTacToe::new();
|
||||
|
||||
// Create MCTS configuration
|
||||
let config = MCTSConfig {
|
||||
max_iterations: 10_000,
|
||||
max_time: None,
|
||||
tree_size_allocation: 10_000,
|
||||
selection_policy: SelectionPolicy::UCB1Tuned(1.414),
|
||||
simulation_policy: SimulationPolicy::Random,
|
||||
backprop_policy: BackpropagationPolicy::Standard,
|
||||
decision_policy: DecisionPolicy::MostVisits,
|
||||
};
|
||||
|
||||
// Main game loop
|
||||
while !game.is_terminal() {
|
||||
// Display the board
|
||||
println!("{}", game);
|
||||
|
||||
if game.current_player == Player::X {
|
||||
// Human player (X)
|
||||
println!("Your move (enter row column, e.g. '1 2'): ");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
let mut input = String::new();
|
||||
io::stdin().read_line(&mut input).unwrap();
|
||||
|
||||
let coords: Vec<usize> = input
|
||||
.trim()
|
||||
.split_whitespace()
|
||||
.filter_map(|s| s.parse::<usize>().ok())
|
||||
.collect();
|
||||
|
||||
if coords.len() != 2 || coords[0] > 2 || coords[1] > 2 {
|
||||
println!("Invalid move! Enter row and column (0-2).");
|
||||
continue;
|
||||
}
|
||||
|
||||
let row = coords[0];
|
||||
let col = coords[1];
|
||||
|
||||
let move_index = row * 3 + col;
|
||||
let action = Move { index: move_index };
|
||||
|
||||
if !game.is_legal_move(&action) {
|
||||
println!("Illegal move! Try again.");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply the human's move
|
||||
game = game.state_after_action(&action);
|
||||
} else {
|
||||
// AI player (O)
|
||||
println!("AI is thinking...");
|
||||
|
||||
// Create a new MCTS search
|
||||
let mut mcts = MCTS::new(game.clone(), &config);
|
||||
|
||||
// Find the best move
|
||||
match mcts.search() {
|
||||
Ok(action) => {
|
||||
println!(
|
||||
"AI chooses: {} (row {}, col {})",
|
||||
action.index,
|
||||
action.index / 3,
|
||||
action.index % 3
|
||||
);
|
||||
|
||||
// Apply the AI's move
|
||||
game = game.state_after_action(&action);
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error: {:?}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Display final state
|
||||
println!("{}", game);
|
||||
|
||||
// Report the result
|
||||
if let Some(winner) = game.get_winner() {
|
||||
println!("Player {:?} wins!", winner);
|
||||
} else {
|
||||
println!("The game is a draw!");
|
||||
}
|
||||
}
|
||||
|
||||
/// Players in Tic-Tac-Toe
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
enum Player {
|
||||
X,
|
||||
O,
|
||||
}
|
||||
|
||||
impl rustic_mcts::Player for Player {}
|
||||
|
||||
/// Tic-Tac-Toe move
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct Move {
|
||||
/// Board position index (0-8)
|
||||
index: usize,
|
||||
}
|
||||
|
||||
impl Action for Move {
|
||||
fn id(&self) -> usize {
|
||||
self.index
|
||||
}
|
||||
}
|
||||
|
||||
/// Tic-Tac-Toe game state
|
||||
#[derive(Clone)]
|
||||
struct TicTacToe {
|
||||
/// Board representation (None = empty, Some(Player) = occupied)
|
||||
board: [Option<Player>; 9],
|
||||
|
||||
/// Current player's turn
|
||||
current_player: Player,
|
||||
|
||||
/// Number of moves played so far
|
||||
moves_played: usize,
|
||||
}
|
||||
|
||||
impl TicTacToe {
|
||||
/// Creates a new empty Tic-Tac-Toe board
|
||||
fn new() -> Self {
|
||||
TicTacToe {
|
||||
board: [None; 9],
|
||||
current_player: Player::X,
|
||||
moves_played: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if a move is legal
|
||||
fn is_legal_move(&self, action: &Move) -> bool {
|
||||
if action.index >= 9 {
|
||||
return false;
|
||||
}
|
||||
self.board[action.index].is_none()
|
||||
}
|
||||
|
||||
/// Returns the winner of the game, if any
|
||||
fn get_winner(&self) -> Option<Player> {
|
||||
// Check rows
|
||||
for row in 0..3 {
|
||||
let i = row * 3;
|
||||
if self.board[i].is_some()
|
||||
&& self.board[i] == self.board[i + 1]
|
||||
&& self.board[i] == self.board[i + 2]
|
||||
{
|
||||
return self.board[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Check columns
|
||||
for col in 0..3 {
|
||||
if self.board[col].is_some()
|
||||
&& self.board[col] == self.board[col + 3]
|
||||
&& self.board[col] == self.board[col + 6]
|
||||
{
|
||||
return self.board[col];
|
||||
}
|
||||
}
|
||||
|
||||
// Check diagonals
|
||||
if self.board[0].is_some()
|
||||
&& self.board[0] == self.board[4]
|
||||
&& self.board[0] == self.board[8]
|
||||
{
|
||||
return self.board[0];
|
||||
}
|
||||
if self.board[2].is_some()
|
||||
&& self.board[2] == self.board[4]
|
||||
&& self.board[2] == self.board[6]
|
||||
{
|
||||
return self.board[2];
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl GameState for TicTacToe {
|
||||
type Action = Move;
|
||||
type Player = Player;
|
||||
|
||||
fn get_legal_actions(&self) -> Vec<Self::Action> {
|
||||
let mut actions = Vec::new();
|
||||
for i in 0..9 {
|
||||
if self.board[i].is_none() {
|
||||
actions.push(Move { index: i });
|
||||
}
|
||||
}
|
||||
actions
|
||||
}
|
||||
|
||||
fn state_after_action(&self, action: &Self::Action) -> Self {
|
||||
let mut new_state = self.clone();
|
||||
|
||||
// Make the move
|
||||
new_state.board[action.index] = Some(self.current_player);
|
||||
new_state.moves_played = self.moves_played + 1;
|
||||
|
||||
// Switch player
|
||||
new_state.current_player = match self.current_player {
|
||||
Player::X => Player::O,
|
||||
Player::O => Player::X,
|
||||
};
|
||||
|
||||
new_state
|
||||
}
|
||||
|
||||
fn is_terminal(&self) -> bool {
|
||||
self.get_winner().is_some() || self.moves_played == 9
|
||||
}
|
||||
|
||||
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
|
||||
if let Some(winner) = self.get_winner() {
|
||||
if winner == *player {
|
||||
return 1.0; // Win
|
||||
} else {
|
||||
return 0.0; // Loss
|
||||
}
|
||||
}
|
||||
|
||||
// Draw
|
||||
0.5
|
||||
}
|
||||
|
||||
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
||||
HashMap::from_iter(vec![
|
||||
(Player::X, self.reward_for_player(&Player::X)),
|
||||
(Player::O, self.reward_for_player(&Player::O)),
|
||||
])
|
||||
}
|
||||
|
||||
fn get_current_player(&self) -> &Self::Player {
|
||||
&self.current_player
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for TicTacToe {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(f, " 0 1 2")?;
|
||||
for row in 0..3 {
|
||||
write!(f, "{} ", row)?;
|
||||
for col in 0..3 {
|
||||
let index = row * 3 + col;
|
||||
let symbol = match self.board[index] {
|
||||
Some(Player::X) => "X",
|
||||
Some(Player::O) => "O",
|
||||
None => ".",
|
||||
};
|
||||
write!(f, "{} ", symbol)?;
|
||||
}
|
||||
writeln!(f)?;
|
||||
}
|
||||
|
||||
writeln!(f, "\nPlayer {:?}'s turn", self.current_player)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for TicTacToe {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "\n")?;
|
||||
for row in 0..3 {
|
||||
for col in 0..3 {
|
||||
let index = row * 3 + col;
|
||||
let symbol = match self.board[index] {
|
||||
Some(Player::X) => "X",
|
||||
Some(Player::O) => "O",
|
||||
None => ".",
|
||||
};
|
||||
write!(f, "{} ", symbol)?;
|
||||
}
|
||||
writeln!(f)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
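A note on running this example (not part of the commit itself): because the file lives under `examples/`, Cargo's standard example runner should be able to launch it by file name.

```text
cargo run --example tic_tac_toe
```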
67
src/config.rs
Normal file
@ -0,0 +1,67 @@
|
||||
use crate::policy::backprop::BackpropagationPolicy;
|
||||
use crate::policy::decision::DecisionPolicy;
|
||||
use crate::policy::selection::SelectionPolicy;
|
||||
use crate::policy::simulation::SimulationPolicy;
|
||||
use crate::state::GameState;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Configuration for the MCTS algorithm
|
||||
#[derive(Debug)]
|
||||
pub struct MCTSConfig<S: GameState> {
|
||||
/// The maximum number of iterations to run when searching
|
||||
///
|
||||
/// The search will stop after the given number of iterations, even if the
/// search time has not exceeded `max_time`.
|
||||
pub max_iterations: usize,
|
||||
|
||||
/// The maximum time to run the search
|
||||
///
|
||||
/// If set, the search will stop after this duration even if the maximum
|
||||
/// iterations hasn't been reached.
|
||||
pub max_time: Option<Duration>,
|
||||
|
||||
/// The size to initially allocate for the search tree
|
||||
///
|
||||
/// This pre-allocates memory for the search tree which ensures contiguous
|
||||
/// memory and improves performance by preventing the resizing of the tree
|
||||
/// as we explore.
|
||||
pub tree_size_allocation: usize,
|
||||
|
||||
/// The selection policy
|
||||
///
|
||||
/// This dictates the path through which the game tree is searched. As such
|
||||
/// the policy has a large impact on the overall algorithm execution.
|
||||
pub selection_policy: SelectionPolicy<S>,
|
||||
|
||||
/// The simulation policy
|
||||
///
|
||||
/// This dictates the game simulation when expanding and evaluating the
|
||||
/// search tree. Random is generally a good default.
|
||||
pub simulation_policy: SimulationPolicy<S>,
|
||||
|
||||
/// The backpropagation policy
|
||||
///
|
||||
/// This dictates how the results of the simulation playouts are propagated
|
||||
/// back up the tree.
|
||||
pub backprop_policy: BackpropagationPolicy<S>,
|
||||
|
||||
/// The decision policy
|
||||
///
|
||||
/// This dictates how the MCTS algorithm determines its final decision
|
||||
/// after iterating through the search tree
|
||||
pub decision_policy: DecisionPolicy,
|
||||
}
|
||||
|
||||
impl<S: GameState> Default for MCTSConfig<S> {
|
||||
fn default() -> Self {
|
||||
MCTSConfig {
|
||||
max_iterations: 10_000,
|
||||
max_time: None,
|
||||
tree_size_allocation: 10_000,
|
||||
selection_policy: SelectionPolicy::UCB1Tuned(1.414),
|
||||
simulation_policy: SimulationPolicy::Random,
|
||||
backprop_policy: BackpropagationPolicy::Standard,
|
||||
decision_policy: DecisionPolicy::MostVisits,
|
||||
}
|
||||
}
|
||||
}
|
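Since `MCTSConfig` implements `Default`, individual fields can be overridden with struct-update syntax. A minimal sketch, assuming some `GameState` implementation is in scope (here `TicTacToe` stands in for the type defined in the bundled example):

```rust
use std::time::Duration;

use rustic_mcts::MCTSConfig;

fn make_config() -> MCTSConfig<TicTacToe> {
    // Keep the defaults (UCB1-Tuned, random playouts, 10_000 iterations)
    // but additionally cap the search at 50 ms of wall-clock time.
    MCTSConfig {
        max_time: Some(Duration::from_millis(50)),
        ..Default::default()
    }
}
```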
17
src/lib.rs
Normal file
@ -0,0 +1,17 @@
|
||||
//! # rustic_mcts
|
||||
//!
|
||||
//! An extensible implementation of Monte Carlo Tree Search (MCTS) using arena allocation and
|
||||
//! configurable policies.
|
||||
|
||||
pub mod config;
|
||||
pub mod mcts;
|
||||
pub mod policy;
|
||||
pub mod state;
|
||||
pub mod tree;
|
||||
|
||||
pub use config::MCTSConfig;
|
||||
pub use mcts::MCTS;
|
||||
pub use state::Action;
|
||||
pub use state::GameState;
|
||||
pub use state::Player;
|
||||
pub use tree::node::RewardVal;
|
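Putting the re-exports together, the intended call sequence mirrors the tic-tac-toe example above: build a config, construct an `MCTS` search over the current state, and ask it for an action. A rough sketch, where `TicTacToe` and `Move` are the caller's hypothetical state and action types:

```rust
use rustic_mcts::{MCTSConfig, MCTS};

fn next_move(game: &TicTacToe) -> Option<Move> {
    let config = MCTSConfig::default();
    // The search owns a clone of the state and borrows the config.
    let mut mcts = MCTS::new(game.clone(), &config);
    // `search()` runs until `max_iterations` or `max_time` is hit and
    // returns the action picked by the decision policy.
    mcts.search().ok()
}
```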
147
src/mcts.rs
Normal file
@ -0,0 +1,147 @@
|
||||
use crate::config::MCTSConfig;
|
||||
use crate::policy::backprop::backpropagate_rewards;
|
||||
use crate::policy::decision::decide_on_action;
|
||||
use crate::policy::selection::select_best_child;
|
||||
use crate::policy::simulation::simulate_reward;
|
||||
use crate::state::GameState;
|
||||
use crate::tree::arena::Arena;
|
||||
use crate::tree::node::{Node, RewardVal};
|
||||
use rand::prelude::SliceRandom;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Monte Carlo Tree Search implementation
|
||||
///
|
||||
/// This provides the interface for performing optimal searches on a tree using
|
||||
/// the MCTS algorithm.
|
||||
pub struct MCTS<'conf, S: GameState> {
|
||||
/// The arena used for the tree
|
||||
arena: Arena<S>,
|
||||
|
||||
/// The identifier of the root node of the search tree
|
||||
root_id: usize,
|
||||
|
||||
/// The configuration used for the search
|
||||
config: &'conf MCTSConfig<S>,
|
||||
}
|
||||
|
||||
impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
||||
/// Creates a new instance with the given initial state and configuration
|
||||
pub fn new(initial_state: S, config: &'conf MCTSConfig<S>) -> Self {
|
||||
let mut arena: Arena<S> = Arena::new(config.tree_size_allocation);
|
||||
let root: Node<S> = Node::new(initial_state.clone(), None, None);
|
||||
let root_id: usize = arena.add_node(root);
|
||||
MCTS {
|
||||
arena,
|
||||
root_id,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the MCTS algorithm, returning the "best" action
|
||||
///
|
||||
/// The search will stop once `max_iterations` or `max_time` from
|
||||
/// the assigned configuration is reached.
|
||||
pub fn search(&mut self) -> Result<S::Action> {
|
||||
self.search_for_iterations(self.config.max_iterations)
|
||||
}
|
||||
|
||||
/// Runs the MCTS algorithm, returning the "best" action after the given iterations
|
||||
///
|
||||
/// This ignores the `max_iterations` provided in the config, but it will
/// return early if `max_time` is specified and reached before the iterations are complete.
|
||||
pub fn search_for_iterations(&mut self, iterations: usize) -> Result<S::Action> {
|
||||
let start_time = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
match self.config.max_time {
|
||||
Some(max_time) => {
|
||||
if start_time.elapsed() >= max_time {
|
||||
break; // ending early due to time
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
self.execute_iteration()?;
|
||||
}
|
||||
|
||||
self.best_action()
|
||||
}
|
||||
|
||||
/// Runs the MCTS algorithm for a single iteration
|
||||
fn execute_iteration(&mut self) -> Result<()> {
|
||||
let mut selected_id: usize = self.select();
|
||||
let selected_node: &Node<S> = self.arena.get_node(selected_id);
|
||||
if !selected_node.state.is_terminal() {
|
||||
self.expand(selected_id);
|
||||
let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
|
||||
let random_child: usize = *children.choose(&mut rand::thread_rng()).unwrap();
|
||||
selected_id = random_child;
|
||||
}
|
||||
let rewards = self.simulate(selected_id);
|
||||
self.backprop(selected_id, &rewards);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// MCTS Phase 1: Selection - Find the "best" node to expand
|
||||
fn select(&mut self) -> usize {
|
||||
let mut current_id: usize = self.root_id;
|
||||
loop {
|
||||
let node = &self.arena.get_node(current_id);
|
||||
if node.is_leaf() || node.state.is_terminal() {
|
||||
return current_id;
|
||||
}
|
||||
current_id = select_best_child(&self.config.selection_policy, &node, &self.arena);
|
||||
}
|
||||
}
|
||||
|
||||
/// MCTS Phase 2: Expansion - Expand the selected node on the tree
|
||||
fn expand(&mut self, id: usize) {
|
||||
let parent: &Node<S> = self.arena.get_node_mut(id);
|
||||
let legal_actions: Vec<S::Action> = parent.state.get_legal_actions();
|
||||
let parent_state: S = parent.state.clone();
|
||||
for action in legal_actions {
|
||||
let state = parent_state.state_after_action(&action);
|
||||
let new_node = Node::new(state, Some(action), Some(id));
|
||||
let new_id = self.arena.add_node(new_node);
|
||||
self.arena.get_node_mut(id).children.push(new_id);
|
||||
}
|
||||
}
|
||||
|
||||
fn simulate(&self, id: usize) -> HashMap<S::Player, RewardVal> {
|
||||
let node = &self.arena.get_node(id);
|
||||
simulate_reward(&self.config.simulation_policy, &node, &self.arena)
|
||||
}
|
||||
|
||||
fn backprop(&mut self, selected_id: usize, rewards: &HashMap<S::Player, RewardVal>) {
|
||||
backpropagate_rewards(
|
||||
&self.config.backprop_policy,
|
||||
selected_id,
|
||||
&mut self.arena,
|
||||
&rewards,
|
||||
)
|
||||
}
|
||||
|
||||
fn best_action(&self) -> Result<S::Action> {
|
||||
let root_node: &Node<S> = self.arena.get_node(self.root_id);
|
||||
match decide_on_action(&self.config.decision_policy, &root_node, &self.arena) {
|
||||
Some(action) => Ok(action),
|
||||
None => Err(MCTSError::NoBestAction),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors returned by the MCTS algorithm
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MCTSError {
|
||||
/// The best action doesn't exist
|
||||
#[error("Unable to determine a best action for the game")]
|
||||
NoBestAction,
|
||||
|
||||
/// The search tree was exhausted without finding a terminal node
|
||||
#[error("Search tree exhausted without finding terminal node")]
|
||||
NonTerminalGame,
|
||||
}
|
||||
|
||||
/// Result returned by the MCTS algorithm
|
||||
pub type Result<T> = std::result::Result<T, MCTSError>;
|
99
src/policy/backprop/mod.rs
Normal file
@ -0,0 +1,99 @@
|
||||
use crate::state::GameState;
|
||||
use crate::tree::arena::Arena;
|
||||
use crate::tree::node::RewardVal;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// The back propagation policy dictating the propagation of playout results
|
||||
///
|
||||
/// This policy drives how the backpropagation phase of the MCTS algorithm is
|
||||
/// executed, allowing for some minor customization.
|
||||
///
|
||||
/// Typically the Standard policy, used by most implementations of MCTS, is
|
||||
/// sufficient
|
||||
#[derive(Debug)]
|
||||
pub enum BackpropagationPolicy<S: GameState> {
|
||||
/// Standard back propagation
|
||||
///
|
||||
/// This increments the visitation count and adds the simulated rewards
|
||||
/// results to the aggregate values.
|
||||
///
|
||||
/// This is the standard policy used in most MCTS implementations.
|
||||
Standard,
|
||||
|
||||
/// Weighted back propagation
|
||||
///
|
||||
/// This weights the value of the simulated rewards based on the depth,
|
||||
/// allowing us to put more-or-less influence on deeper branches
|
||||
/// - Positive weight factor makes deeper nodes less influential
|
||||
/// - Negative weight factor makes deeper nodes more influential
|
||||
Weighted(f64),
|
||||
|
||||
/// Custom backpropagation policy
|
||||
Custom(Box<dyn CustomBackpropagationPolicy<S>>),
|
||||
}
|
||||
|
||||
/// Trait for an object implementing the backpropagation logic when exploring the MCTS
|
||||
/// search tree.
|
||||
pub trait CustomBackpropagationPolicy<S: GameState>: std::fmt::Debug {
|
||||
/// Backpropagate the given rewards values from the node up the tree
|
||||
fn backprop(
|
||||
&self,
|
||||
node_id: usize,
|
||||
arena: &mut Arena<S>,
|
||||
rewards: &HashMap<S::Player, RewardVal>,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn backpropagate_rewards<S: GameState>(
|
||||
policy: &BackpropagationPolicy<S>,
|
||||
node_id: usize,
|
||||
arena: &mut Arena<S>,
|
||||
rewards: &HashMap<S::Player, RewardVal>,
|
||||
) {
|
||||
match policy {
|
||||
BackpropagationPolicy::Standard => standard_backprop(node_id, arena, rewards),
|
||||
BackpropagationPolicy::Weighted(depth_factor) => {
|
||||
weighted_backprop(*depth_factor, node_id, arena, rewards)
|
||||
}
|
||||
BackpropagationPolicy::Custom(custom_policy) => {
|
||||
custom_policy.backprop(node_id, arena, rewards)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn standard_backprop<S: GameState>(
|
||||
node_id: usize,
|
||||
arena: &mut Arena<S>,
|
||||
rewards: &HashMap<S::Player, RewardVal>,
|
||||
) {
|
||||
// TODO:
|
||||
// - each node needs the perspective of the different players not just one view
|
||||
// - e.g. reward_sum(player), reward_avg(player), rewards(player)[], visits(player)
|
||||
// - we could make special version for 2-player zero-sum games like below
|
||||
let mut current_id: usize = node_id;
|
||||
loop {
|
||||
let node = arena.get_node_mut(current_id);
|
||||
let player = node.state.get_current_player().clone();
|
||||
match rewards.get(&player) {
|
||||
Some(reward) => {
|
||||
node.increment_visits();
|
||||
node.record_player_reward(player, *reward);
|
||||
if let Some(parent_id) = node.parent {
|
||||
current_id = parent_id;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => {
    // Even without a reward entry for this player, keep walking toward the
    // root so the loop still terminates.
    if let Some(parent_id) = node.parent {
        current_id = parent_id;
    } else {
        break;
    }
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn weighted_backprop<S: GameState>(
|
||||
_depth_factor: f64,
|
||||
_node_id: usize,
|
||||
_arena: &mut Arena<S>,
|
||||
_rewards: &HashMap<S::Player, RewardVal>,
|
||||
) {
|
||||
// TODO
|
||||
}
|
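For completeness, here is a sketch of what plugging into the `Custom` variant could look like: a hypothetical policy that behaves like `Standard` but halves the reward at each step up the tree, so distant ancestors are influenced less. It only uses `Arena` and `Node` methods already defined in this commit, and it is an illustration rather than a recommended policy.

```rust
use std::collections::HashMap;

use rustic_mcts::policy::backprop::CustomBackpropagationPolicy;
use rustic_mcts::state::GameState;
use rustic_mcts::tree::arena::Arena;
use rustic_mcts::tree::node::RewardVal;

/// Hypothetical example policy: standard backprop with a 0.5 decay per level.
#[derive(Debug)]
struct HalvingBackprop;

impl<S: GameState> CustomBackpropagationPolicy<S> for HalvingBackprop {
    fn backprop(
        &self,
        node_id: usize,
        arena: &mut Arena<S>,
        rewards: &HashMap<S::Player, RewardVal>,
    ) {
        let mut current_id = node_id;
        let mut scale = 1.0;
        loop {
            let node = arena.get_node_mut(current_id);
            let player = node.state.get_current_player().clone();
            if let Some(reward) = rewards.get(&player) {
                node.increment_visits();
                node.record_player_reward(player, *reward * scale);
            }
            scale *= 0.5;
            // Walk up toward the root regardless of whether a reward was recorded.
            match node.parent {
                Some(parent_id) => current_id = parent_id,
                None => break,
            }
        }
    }
}
```

It would be installed via `BackpropagationPolicy::Custom(Box::new(HalvingBackprop))` in the config.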
65
src/policy/decision/mod.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use crate::state::GameState;
|
||||
use crate::tree::arena::Arena;
|
||||
use crate::tree::node::Node;
|
||||
|
||||
/// The decision policy when determining the action in final MCTS phase
|
||||
///
|
||||
/// This policy drives how the MCTS algorithm chooses which action is the
|
||||
/// "best" from the exploration.
|
||||
#[derive(Debug)]
|
||||
pub enum DecisionPolicy {
|
||||
/// Decide on the action with the most visits
|
||||
///
|
||||
/// This option relies on the statistical confidence driven by the MCTS
|
||||
/// algorithm instead of the potentially more noisy value estimates.
|
||||
///
|
||||
/// This is the standard policy used in most MCTS implementations, and
|
||||
/// is a good selection when not hyper-maximizing for potential gain
|
||||
MostVisits,
|
||||
|
||||
/// Decide on the action with the highest average value
|
||||
///
|
||||
/// This is non-standard, but is more aggressive in attempting to gain
|
||||
/// the highest value in a decision.
|
||||
HighestValue,
|
||||
}
|
||||
|
||||
pub fn decide_on_action<S: GameState>(
|
||||
policy: &DecisionPolicy,
|
||||
root_node: &Node<S>,
|
||||
arena: &Arena<S>,
|
||||
) -> Option<S::Action> {
|
||||
match policy {
|
||||
DecisionPolicy::MostVisits => most_visits(root_node, arena),
|
||||
DecisionPolicy::HighestValue => highest_value(root_node, arena),
|
||||
}
|
||||
}
|
||||
|
||||
fn most_visits<S: GameState>(root_node: &Node<S>, arena: &Arena<S>) -> Option<S::Action> {
|
||||
let best_child_id: &usize = root_node
|
||||
.children
|
||||
.iter()
|
||||
.max_by(|&a, &b| {
|
||||
let node_a_visits = arena.get_node(*a).visits;
|
||||
let node_b_visits = arena.get_node(*b).visits;
|
||||
node_a_visits.partial_cmp(&node_b_visits).unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
arena.get_node(*best_child_id).action.clone()
|
||||
}
|
||||
|
||||
fn highest_value<S: GameState>(root_node: &Node<S>, arena: &Arena<S>) -> Option<S::Action> {
|
||||
let player = root_node.state.get_current_player();
|
||||
let best_child_id: &usize = root_node
|
||||
.children
|
||||
.iter()
|
||||
.max_by(|&a, &b| {
|
||||
let node_a_score = arena.get_node(*a).reward_average(player);
|
||||
let node_b_score = arena.get_node(*b).reward_average(player);
|
||||
node_a_score.partial_cmp(&node_b_score).unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
arena.get_node(*best_child_id).action.clone()
|
||||
}
|
4
src/policy/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod backprop;
|
||||
pub mod decision;
|
||||
pub mod selection;
|
||||
pub mod simulation;
|
59
src/policy/selection/mod.rs
Normal file
@ -0,0 +1,59 @@
|
||||
mod ucb1;
|
||||
mod ucb1_tuned;
|
||||
|
||||
use crate::state::GameState;
|
||||
use crate::tree::arena::Arena;
|
||||
use crate::tree::node::Node;
|
||||
|
||||
/// The selection policy used in the MCTS selection phase
|
||||
///
|
||||
/// This drives the selection of the nodes in the search tree, determining
|
||||
/// which paths are explored and evaluated.
|
||||
///
|
||||
/// In general UCB1-Tuned or UCB1 should be effective, however if necessary
|
||||
/// a custom selection policy can be provided.
|
||||
#[derive(Debug)]
|
||||
pub enum SelectionPolicy<S: GameState> {
|
||||
/// Upper Confidence Bound 1 (UCB1) with the given exploration constant
|
||||
///
|
||||
/// The exploration constant controls the balance between exploration and
|
||||
/// exploitation. The higher the value, the more likely the search will
|
||||
/// explore less-visited nodes. A standard value is √2 ≈ 1.414.
|
||||
UCB1(f64),
|
||||
|
||||
/// Upper Confidence Bound 1 Tuned (UCB1-Tuned)
|
||||
///
|
||||
/// A tuned version of UCB1 instead using the empirical
|
||||
/// standard deviation of the rewards to drive exploration.
|
||||
///
|
||||
/// Auer, P., Cesa-Bianchi, N. & Fischer, P. Finite-time Analysis of the Multiarmed Bandit Problem. Machine Learning 47, 235–256 (2002). https://doi.org/10.1023/A:1013689704352
|
||||
UCB1Tuned(f64),
|
||||
|
||||
/// Custom selection policy
|
||||
Custom(Box<dyn CustomSelectionPolicy<S>>),
|
||||
}
|
||||
|
||||
/// Trait for an object implementing the selection logic when exploring the MCTS
|
||||
/// search tree.
|
||||
///
|
||||
/// The policy should select the child of the given node which is "best" for the current player
|
||||
pub trait CustomSelectionPolicy<S: GameState>: std::fmt::Debug {
|
||||
/// Selects a child based on the policy, returning the node ID
|
||||
fn select_child(&self, node: &Node<S>, arena: &Arena<S>) -> usize;
|
||||
}
|
||||
|
||||
pub fn select_best_child<S: GameState>(
|
||||
policy: &SelectionPolicy<S>,
|
||||
node: &Node<S>,
|
||||
arena: &Arena<S>,
|
||||
) -> usize {
|
||||
match policy {
|
||||
SelectionPolicy::UCB1(exploration_constant) => {
|
||||
ucb1::select_best_child(*exploration_constant, node, arena)
|
||||
}
|
||||
SelectionPolicy::UCB1Tuned(exploration_constant) => {
|
||||
ucb1_tuned::select_best_child(*exploration_constant, node, arena)
|
||||
}
|
||||
SelectionPolicy::Custom(custom_policy) => custom_policy.select_child(node, arena),
|
||||
}
|
||||
}
|
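Similarly, the `Custom` variant only needs a `CustomSelectionPolicy` implementation. A purely illustrative sketch that always descends into the child with the highest average reward for the player to move (no exploration bonus, so it is not a good general-purpose policy):

```rust
use rustic_mcts::policy::selection::CustomSelectionPolicy;
use rustic_mcts::state::GameState;
use rustic_mcts::tree::arena::Arena;
use rustic_mcts::tree::node::Node;

#[derive(Debug)]
struct GreedySelection;

impl<S: GameState> CustomSelectionPolicy<S> for GreedySelection {
    fn select_child(&self, node: &Node<S>, arena: &Arena<S>) -> usize {
        let player = node.state.get_current_player();
        // Pick the child id with the best average reward for the current player.
        *node
            .children
            .iter()
            .max_by(|&a, &b| {
                let score_a = arena.get_node(*a).reward_average(player);
                let score_b = arena.get_node(*b).reward_average(player);
                score_a.partial_cmp(&score_b).unwrap()
            })
            .expect("select_child called on a leaf node")
    }
}
```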
79
src/policy/selection/ucb1.rs
Normal file
@ -0,0 +1,79 @@
|
||||
//! Upper Confidence Bound 1 (UCB1) selection policy
|
||||
//!
|
||||
//! This is the classic selection policy for MCTS, which balances
|
||||
//! exploration and exploitation using the UCB1 formula:
|
||||
//!
|
||||
//! ```text
|
||||
//! UCB1 = average_reward + exploration_constant * sqrt(ln(parent_visits) / child_visits)
|
||||
//! ```
|
||||
//!
|
||||
//! Where:
|
||||
//! - `average_reward` is the average reward from simulations through this node
|
||||
//! - `exploration_constant` controls the balance between exploration and exploitation
|
||||
//! - `parent_visits` is the number of visits to the parent node
|
||||
//! - `child_visits` is the number of visits to the child node
|
||||
//!
|
||||
//! Higher exploration constants favor exploration (trying less-visited nodes),
|
||||
//! while lower values favor exploitation (choosing nodes with higher values).
|
||||
//!
|
||||
//! The commonly used value for the exploration constant is sqrt(2) ≈ 1.414,
|
||||
//! which is the default in this implementation.
|
||||
|
||||
use crate::state::GameState;
|
||||
use crate::tree::arena::Arena;
|
||||
use crate::tree::node::{Node, RewardVal};
|
||||
|
||||
/// Selects the index of the "best" child using the UCB1 selection policy
|
||||
pub fn select_best_child<S: GameState>(
|
||||
exploration_constant: f64,
|
||||
node: &Node<S>,
|
||||
arena: &Arena<S>,
|
||||
) -> usize {
|
||||
if node.is_leaf() {
|
||||
panic!("select_best_child called on leaf node");
|
||||
}
|
||||
|
||||
let player = node.state.get_current_player();
|
||||
let parent_visits = node.visits;
|
||||
let best_child = node
|
||||
.children
|
||||
.iter()
|
||||
.max_by(|&a, &b| {
|
||||
let node_a = arena.get_node(*a);
|
||||
let node_b = arena.get_node(*b);
|
||||
let ucb_a = ucb1_value(
|
||||
exploration_constant,
|
||||
node_a.reward_average(player),
|
||||
node_a.visits,
|
||||
parent_visits,
|
||||
);
|
||||
let ucb_b = ucb1_value(
|
||||
exploration_constant,
|
||||
node_b.reward_average(player),
|
||||
node_b.visits,
|
||||
parent_visits,
|
||||
);
|
||||
ucb_a.partial_cmp(&ucb_b).unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
*best_child
|
||||
}
|
||||
|
||||
/// Calculates the UCB1 value for a node
|
||||
pub fn ucb1_value(
|
||||
exploration_constant: f64,
|
||||
child_value: RewardVal,
|
||||
child_visits: u64,
|
||||
parent_visits: u64,
|
||||
) -> RewardVal {
|
||||
if child_visits == 0 {
|
||||
return f64::INFINITY; // Always explore nodes that have never been visited
|
||||
}
|
||||
|
||||
// UCB1 formula: value + C * sqrt(ln(parent_visits) / child_visits)
|
||||
let exploitation = child_value;
|
||||
let exploration =
|
||||
exploration_constant * ((parent_visits as f64).ln() / child_visits as f64).sqrt();
|
||||
|
||||
exploitation + exploration
|
||||
}
|
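A quick numeric check of the formula, written as it might appear in a unit test inside this module (a sketch, not part of the commit): with 100 parent visits, 10 child visits, an average reward of 0.6, and the default constant 1.414, the exploration bonus is 1.414 * sqrt(ln(100)/10) ≈ 0.96, so the UCB1 value is about 1.56.

```rust
#[cfg(test)]
mod tests {
    use super::ucb1_value;

    #[test]
    fn ucb1_matches_hand_computation() {
        // 0.6 + 1.414 * sqrt(ln(100) / 10) ≈ 0.6 + 0.9596 ≈ 1.5596
        let v = ucb1_value(1.414, 0.6, 10, 100);
        assert!((v - 1.5596).abs() < 1e-3);
    }
}
```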
97
src/policy/selection/ucb1_tuned.rs
Normal file
@ -0,0 +1,97 @@
|
||||
//! Upper Confidence Bound 1 Tuned (UCB1-Tuned) selection policy
|
||||
//!
|
||||
//! This is a fine-tuned version of UCB which takes into account the
|
||||
//! empirically measured variance of the rewards to drive the exploration.
|
||||
//!
|
||||
//! This has been found to perform substantially better than UCB1 in most
|
||||
//! situations.
|
||||
//!
|
||||
//! Auer, P., Cesa-Bianchi, N. & Fischer, P.
|
||||
//! Finite-time Analysis of the Multiarmed Bandit Problem.
|
||||
//! Machine Learning 47, 235–256 (2002). https://doi.org/10.1023/A:1013689704352
|
||||
|
||||
use crate::state::GameState;
|
||||
use crate::tree::arena::Arena;
|
||||
use crate::tree::node::{Node, RewardVal};
|
||||
|
||||
/// Selects the index of the "best" child using the UCB1-Tuned selection policy
|
||||
pub fn select_best_child<S: GameState>(
|
||||
exploration_constant: f64,
|
||||
node: &Node<S>,
|
||||
arena: &Arena<S>,
|
||||
) -> usize {
|
||||
if node.is_leaf() {
|
||||
panic!("select_best_child called on leaf node");
|
||||
}
|
||||
|
||||
let player = node.state.get_current_player();
|
||||
let parent_visits = node.visits;
|
||||
let best_child = node
|
||||
.children
|
||||
.iter()
|
||||
.max_by(|&a, &b| {
|
||||
let node_a = arena.get_node(*a);
|
||||
let node_b = arena.get_node(*b);
|
||||
let ucb_a = ucb1_tuned_value(
|
||||
exploration_constant,
|
||||
parent_visits,
|
||||
node_a.visits,
|
||||
node_a.rewards(player),
|
||||
node_a.reward_average(player),
|
||||
);
|
||||
let ucb_b = ucb1_tuned_value(
|
||||
exploration_constant,
|
||||
parent_visits,
|
||||
node_b.visits,
|
||||
node_b.rewards(player),
|
||||
node_b.reward_average(player),
|
||||
);
|
||||
ucb_a.partial_cmp(&ucb_b).unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
*best_child
|
||||
}
|
||||
|
||||
/// Calculates the UCB1-Tuned value for a node
|
||||
pub fn ucb1_tuned_value(
|
||||
exploration_constant: f64,
|
||||
parent_visits: u64,
|
||||
child_visits: u64,
|
||||
child_rewards: Option<&Vec<RewardVal>>,
|
||||
reward_avg: RewardVal,
|
||||
) -> RewardVal {
|
||||
match child_rewards {
|
||||
None => {
|
||||
RewardVal::INFINITY // Always explore nodes that have never been visited
|
||||
}
|
||||
Some(child_rewards) => {
|
||||
if child_visits == 0 {
|
||||
RewardVal::INFINITY // Always explore nodes that have never been visited
|
||||
} else {
|
||||
let parent_visits: RewardVal = parent_visits as RewardVal;
|
||||
let child_visits: RewardVal = child_visits as RewardVal;
|
||||
|
||||
// N: number of visits to the parent node
|
||||
// n: number of visits to the child node
|
||||
// x_i: reward of the ith visit to the child node
|
||||
// X: average reward of the child
|
||||
// C: exploration constant
|
||||
//
|
||||
// UCB1-Tuned = X + C * sqrt(ln(N) / n * min(1/4, V(n)))
// V(n) = sum(x_i^2)/n - X^2 + sqrt(2*ln(N)/n)
|
||||
let exploitation = reward_avg;
|
||||
let mut variance = (child_rewards.iter().map(|&x| x * x).sum::<RewardVal>()
|
||||
/ child_visits)
|
||||
- (reward_avg * reward_avg)
|
||||
+ (2.0 * parent_visits.ln() / child_visits).sqrt();
|
||||
if variance > 0.25 {
|
||||
variance = 0.25;
|
||||
}
|
||||
let exploration =
|
||||
exploration_constant * (parent_visits.ln() / child_visits * variance).sqrt();
|
||||
|
||||
exploitation + exploration
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
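Again as a sketch of a possible in-module unit test (not part of the commit), here is a hand-checked case that exercises the min(1/4, V(n)) cap: with rewards [1, 0, 1, 1] under a parent visited 20 times, V(n) ≈ 1.41, so the variance term is clamped to 0.25 and the bonus reduces to C * sqrt(ln(20)/4 * 0.25).

```rust
#[cfg(test)]
mod tests {
    use super::ucb1_tuned_value;

    #[test]
    fn variance_term_is_capped_at_one_quarter() {
        // mean = 0.75, V(n) = 3/4 - 0.75^2 + sqrt(2*ln(20)/4) ≈ 1.41 -> capped at 0.25
        // value = 0.75 + 1.414 * sqrt(ln(20)/4 * 0.25) ≈ 1.3618
        let rewards = vec![1.0, 0.0, 1.0, 1.0];
        let v = ucb1_tuned_value(1.414, 20, 4, Some(&rewards), 0.75);
        assert!((v - 1.3618).abs() < 1e-3);
    }
}
```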
44
src/policy/simulation/mod.rs
Normal file
@ -0,0 +1,44 @@
|
||||
mod random;
|
||||
|
||||
use crate::state::GameState;
|
||||
use crate::tree::arena::Arena;
|
||||
use crate::tree::node::{Node, RewardVal};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// The simulation policy used in the MCTS simulation phase
|
||||
///
|
||||
/// This policy drives the game simulations while evaluating the tree. While
|
||||
/// a random policy works well, a game-specific policy can be provided
|
||||
/// as a custom policy.
|
||||
#[derive(Debug)]
|
||||
pub enum SimulationPolicy<S: GameState> {
|
||||
/// Random simulation policy
|
||||
///
|
||||
/// The sequential actions are selected randomly from the available actions
|
||||
/// at each state until a terminal state is found.
|
||||
Random,
|
||||
|
||||
/// Custom simulation policy
|
||||
Custom(Box<dyn CustomSimulationPolicy<S>>),
|
||||
}
|
||||
|
||||
/// Trait for an object implementing the simulation logic when exploring the MCTS
|
||||
/// search tree.
|
||||
pub trait CustomSimulationPolicy<S: GameState>: std::fmt::Debug {
|
||||
/// Simulates the gameplay from the current node onward, returning the rewards
|
||||
///
|
||||
/// This should simulate the game until a terminal node is reached, returning
|
||||
/// the final reward for each player at the terminal node
|
||||
fn simulate(&self, node: &Node<S>, arena: &Arena<S>) -> HashMap<S::Player, RewardVal>;
|
||||
}
|
||||
|
||||
pub fn simulate_reward<S: GameState>(
|
||||
policy: &SimulationPolicy<S>,
|
||||
node: &Node<S>,
|
||||
arena: &Arena<S>,
|
||||
) -> HashMap<S::Player, RewardVal> {
|
||||
match policy {
|
||||
SimulationPolicy::Random => random::simulate(node),
|
||||
SimulationPolicy::Custom(custom_policy) => custom_policy.simulate(node, arena),
|
||||
}
|
||||
}
|
18
src/policy/simulation/random.rs
Normal file
@ -0,0 +1,18 @@
|
||||
//! Random play simulation policy
|
||||
//!
|
||||
//! Actions are chosen at random
|
||||
|
||||
use crate::state::GameState;
|
||||
use crate::tree::node::{Node, RewardVal};
|
||||
use rand::prelude::SliceRandom;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub fn simulate<S: GameState>(node: &Node<S>) -> HashMap<S::Player, RewardVal> {
|
||||
let mut state: S = node.state.clone();
|
||||
while !state.is_terminal() {
|
||||
let legal_actions = state.get_legal_actions();
|
||||
let action = legal_actions.choose(&mut rand::thread_rng()).unwrap();
|
||||
state = state.state_after_action(&action);
|
||||
}
|
||||
state.rewards_for_players()
|
||||
}
|
90
src/state.rs
Normal file
@ -0,0 +1,90 @@
|
||||
use crate::tree::node::RewardVal;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::hash::Hash;
|
||||
|
||||
/// Trait for the game state used in MCTS
|
||||
///
|
||||
/// When leveraging MCTS for your game, you must implement this trait to provide
|
||||
/// the specifics for your game.
|
||||
pub trait GameState: Clone {
|
||||
/// The type of actions that can be taken in the game
|
||||
type Action: Action;
|
||||
|
||||
/// The type of players in the game
|
||||
type Player: Player;
|
||||
|
||||
/// Returns if the game state is terminal, i.e. the game is over
|
||||
///
|
||||
/// A game state is terminal when no other actions are possible. This can be
|
||||
/// the result of a player winning, a draw, or because some other conditions
|
||||
/// have been met leading to a game with no further possible states.
|
||||
///
|
||||
/// The default implementation returns `true` if `get_legal_actions()` returns
|
||||
/// an empty list. It is recommended to override this for a more efficient
|
||||
/// implementation if possible.
|
||||
fn is_terminal(&self) -> bool {
|
||||
let actions = self.get_legal_actions();
|
||||
actions.len() == 0
|
||||
}
|
||||
|
||||
/// Returns the list of legal actions for the game state
|
||||
///
|
||||
/// This method must return all possible actions that can be made from the
|
||||
/// current game state.
|
||||
fn get_legal_actions(&self) -> Vec<Self::Action>;
|
||||
|
||||
/// Returns the game state resulting from applying the action to the state
|
||||
///
|
||||
/// This function should not modify the current state directly, and
|
||||
/// instead should modify a copy of the state and return that.
|
||||
fn state_after_action(&self, action: &Self::Action) -> Self;
|
||||
|
||||
/// Returns the reward from the perspective of the given player for the game state
|
||||
///
|
||||
/// This evaluates the current state from the perspective of the given player, and
|
||||
/// returns the reward indicating how good of a result the given state is for the
|
||||
/// player.
|
||||
///
|
||||
/// This is used in the MCTS backpropagation and simulation phases to evaluate
|
||||
/// the value of a given node in the search tree.
|
||||
///
|
||||
/// A general rule of thumb for values are:
|
||||
/// - 1.0 => a win for the player
|
||||
/// - 0.5 => a draw
|
||||
/// - 0.0 => a loss for the player
|
||||
///
|
||||
/// Other values can be used for relative wins or losses
|
||||
fn reward_for_player(&self, player: &Self::Player) -> RewardVal;
|
||||
|
||||
/// Returns the rewards for all players at the current state
|
||||
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal>;
|
||||
|
||||
/// Returns the player whose turn it is for the game state
|
||||
///
|
||||
/// This is used for evaluating the state, so for simultaneous games
|
||||
/// consider the "current player" as the one from whose perspective we are
|
||||
/// evaluating the game state.
|
||||
fn get_current_player(&self) -> &Self::Player;
|
||||
}
|
||||
|
||||
/// Trait used for actions that can be taken in a game
|
||||
///
|
||||
/// An action is dependent upon the specific game being defined, and includes
|
||||
/// things like moves, attacks, and other decisions.
|
||||
pub trait Action: Clone + Debug {
|
||||
/// Returns a unique identifier for this action
|
||||
fn id(&self) -> usize;
|
||||
}
|
||||
|
||||
/// Trait used for players participating in a game
|
||||
pub trait Player: Clone + Debug + PartialEq + Eq + Hash {}
|
||||
|
||||
/// Convenience implementation of a Player for usize
|
||||
impl Player for usize {}
|
||||
|
||||
/// Convenience implementation of a Player for char
|
||||
impl Player for char {}
|
||||
|
||||
/// Convenience implementation of a Player for String
|
||||
impl Player for String {}
|
46
src/tree/arena.rs
Normal file
@ -0,0 +1,46 @@
|
||||
use crate::state::GameState;
|
||||
use crate::tree::node::Node;
|
||||
|
||||
/// An arena for Node allocation
|
||||
///
|
||||
/// We use an arena for node allocation to improve performance of our search.
|
||||
/// The memory is contiguous which allows for faster movement through the tree,
|
||||
/// as well as more efficient destruction as our MCTS search will destroy the
|
||||
/// entire tree at once.
|
||||
pub struct Arena<S: GameState> {
|
||||
pub nodes: Vec<Node<S>>,
|
||||
}
|
||||
|
||||
impl<S: GameState> Arena<S> {
|
||||
/// Create a new Arena with the given initial capacity
|
||||
///
|
||||
/// The arena creates a contiguous block. By reserving an initial capacity
|
||||
/// that is sufficient to encapsulate a full search tree we can reduce the
|
||||
/// number of reallocs that are required. This number is highly game
|
||||
/// dependent.
|
||||
pub fn new(initial_capacity: usize) -> Self {
|
||||
Arena {
|
||||
nodes: Vec::with_capacity(initial_capacity),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a node to the Arena, returning its identifier
|
||||
///
|
||||
/// This appends the node to the allocated Arena, and returns the nodes
|
||||
/// index in the arena which is used as an identifier for later retrieval.
|
||||
pub fn add_node(&mut self, node: Node<S>) -> usize {
|
||||
let id = self.nodes.len();
|
||||
self.nodes.push(node);
|
||||
id
|
||||
}
|
||||
|
||||
/// Retrieves a mutable reference to a Node in the Arena
|
||||
pub fn get_node_mut(&mut self, id: usize) -> &mut Node<S> {
|
||||
&mut self.nodes[id]
|
||||
}
|
||||
|
||||
/// Retrieves a reference to a Node in the Arena
|
||||
pub fn get_node(&self, id: usize) -> &Node<S> {
|
||||
&self.nodes[id]
|
||||
}
|
||||
}
|
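A short sketch of the id-based wiring the search relies on (here `TicTacToe` again stands in for any `GameState` implementation):

```rust
use rustic_mcts::tree::arena::Arena;
use rustic_mcts::tree::node::Node;

fn build_root(state: TicTacToe) -> (Arena<TicTacToe>, usize) {
    // Pre-allocate roughly one slot per expected tree node.
    let mut arena: Arena<TicTacToe> = Arena::new(10_000);
    // The returned index is the node's identifier from here on.
    let root_id = arena.add_node(Node::new(state, None, None));
    assert!(arena.get_node(root_id).is_leaf());
    (arena, root_id)
}
```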
2
src/tree/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod arena;
|
||||
pub mod node;
|
114
src/tree/node.rs
Normal file
@ -0,0 +1,114 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::state::GameState;
|
||||
|
||||
/// The type used for reward values
|
||||
pub type RewardVal = f64;
|
||||
|
||||
/// A node in the MCTS tree
|
||||
///
|
||||
/// A node represents a given game state and, using the path from the root node,
|
||||
/// the actions that led to the given state. A node has a number of children
|
||||
/// nodes representing the game states reachable from the given state, after
|
||||
/// a given action. This creates the tree that MCTS iterates through.
|
||||
///
|
||||
/// This struct is not thread safe, as the library does not provide for parallel
|
||||
/// search.
|
||||
#[derive(Debug)]
|
||||
pub struct Node<S: GameState> {
|
||||
/// The game state at the given node, after `action`
|
||||
pub state: S,
|
||||
|
||||
/// The action that led to this state from its parent
|
||||
pub action: Option<S::Action>,
|
||||
|
||||
/// The identifier of the parent Node
|
||||
pub parent: Option<usize>,
|
||||
|
||||
/// The number of times this node has been visited
|
||||
pub visits: u64,
|
||||
|
||||
/// The player's evaluation of the node
|
||||
pub player_view: HashMap<S::Player, PlayerNodeView>,
|
||||
|
||||
/// The identifiers of children nodes, states reachable from this one
|
||||
pub children: Vec<usize>,
|
||||
}
|
||||
|
||||
impl<S: GameState> Node<S> {
|
||||
pub fn new(state: S, action: Option<S::Action>, parent: Option<usize>) -> Self {
|
||||
Node {
|
||||
state,
|
||||
action,
|
||||
parent,
|
||||
visits: 0,
|
||||
player_view: HashMap::with_capacity(2),
|
||||
children: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_leaf(&self) -> bool {
|
||||
self.children.is_empty()
|
||||
}
|
||||
|
||||
pub fn reward_sum(&self, player: &S::Player) -> RewardVal {
|
||||
match self.player_view.get(player) {
|
||||
Some(pv) => pv.reward_sum,
|
||||
None => 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reward_average(&self, player: &S::Player) -> RewardVal {
|
||||
match self.player_view.get(player) {
|
||||
Some(pv) => pv.reward_average,
|
||||
None => 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rewards(&self, player: &S::Player) -> Option<&Vec<RewardVal>> {
|
||||
match self.player_view.get(player) {
|
||||
Some(pv) => Some(&pv.rewards),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_visits(&mut self) {
|
||||
self.visits += 1
|
||||
}
|
||||
|
||||
pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
|
||||
let pv = self
|
||||
.player_view
|
||||
.entry(player)
|
||||
.or_insert(PlayerNodeView::default());
|
||||
pv.rewards.push(reward);
|
||||
pv.reward_sum += reward;
|
||||
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
|
||||
}
|
||||
}
|
||||
|
||||
/// A player's specific perspective of a node's value
|
||||
///
|
||||
/// Each player has their own idea of the value of a node.
|
||||
#[derive(Debug)]
|
||||
pub struct PlayerNodeView {
|
||||
/// The total reward from simulations through this node
|
||||
pub reward_sum: RewardVal,
|
||||
|
||||
/// The average reward from simulations through this node, often called the node value
|
||||
pub reward_average: RewardVal,
|
||||
|
||||
/// The rewards we have gotten so far for simulations through this node
|
||||
pub rewards: Vec<RewardVal>,
|
||||
}
|
||||
|
||||
impl Default for PlayerNodeView {
|
||||
fn default() -> Self {
|
||||
PlayerNodeView {
|
||||
reward_sum: 0.0,
|
||||
reward_average: 0.0,
|
||||
rewards: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
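To make the running-average bookkeeping concrete, a tiny standalone sketch that mirrors what `Node::record_player_reward` does for one player's `PlayerNodeView`:

```rust
use rustic_mcts::tree::node::PlayerNodeView;

fn main() {
    let mut pv = PlayerNodeView::default();
    // Record two simulated rewards, exactly as record_player_reward would.
    for reward in [1.0, 0.0] {
        pv.rewards.push(reward);
        pv.reward_sum += reward;
        pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
    }
    assert_eq!(pv.reward_average, 0.5);
}
```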