Compare commits

..

1 Commits

Author SHA1 Message Date
04a55f0a56 Add a "threads" feature for mutli-threading 2025-06-30 19:49:45 -07:00
6 changed files with 70 additions and 57 deletions

View File

@ -10,13 +10,15 @@ readme = "README.md"
keywords = ["mcts", "rust", "monte_carlo", "tree", "ai", "ml"] keywords = ["mcts", "rust", "monte_carlo", "tree", "ai", "ml"]
categories = ["algorithms", "data-structures"] categories = ["algorithms", "data-structures"]
[features]
threads = []
[dependencies] [dependencies]
rand = "~0.9" rand = "~0.9"
thiserror = "~2.0" thiserror = "~2.0"
[dev-dependencies] [dev-dependencies]
divan = "~0.1" divan = "0.1.21"
[[bench]] [[bench]]
name = "e2e" name = "e2e"

View File

@ -61,7 +61,11 @@ struct Move {
index: usize, index: usize,
} }
impl Action for Move {} impl Action for Move {
fn id(&self) -> usize {
self.index
}
}
/// Tic-Tac-Toe game state /// Tic-Tac-Toe game state
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -125,19 +129,6 @@ impl TicTacToe {
None None
} }
fn reward_for_player(&self, player: &Player) -> RewardVal {
if let Some(winner) = self.get_winner() {
if winner == *player {
return 1.0; // Win
} else {
return 0.0; // Loss
}
}
// Draw
0.5
}
} }
impl GameState for TicTacToe { impl GameState for TicTacToe {
@ -174,6 +165,19 @@ impl GameState for TicTacToe {
self.get_winner().is_some() || self.moves_played == 9 self.get_winner().is_some() || self.moves_played == 9
} }
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
if let Some(winner) = self.get_winner() {
if winner == *player {
return 1.0; // Win
} else {
return 0.0; // Loss
}
}
// Draw
0.5
}
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> { fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
HashMap::from_iter(vec![ HashMap::from_iter(vec![
(Player::X, self.reward_for_player(&Player::X)), (Player::X, self.reward_for_player(&Player::X)),

View File

@ -84,7 +84,11 @@ struct Move {
index: usize, index: usize,
} }
impl Action for Move {} impl Action for Move {
fn id(&self) -> usize {
self.index
}
}
/// Tic-Tac-Toe game state /// Tic-Tac-Toe game state
#[derive(Clone)] #[derive(Clone)]
@ -148,19 +152,6 @@ impl TicTacToe {
None None
} }
fn reward_for_player(&self, player: &Player) -> RewardVal {
if let Some(winner) = self.get_winner() {
if winner == *player {
return 1.0; // Win
} else {
return 0.0; // Loss
}
}
// Draw
0.5
}
} }
impl GameState for TicTacToe { impl GameState for TicTacToe {
@ -197,6 +188,19 @@ impl GameState for TicTacToe {
self.get_winner().is_some() || self.moves_played == 9 self.get_winner().is_some() || self.moves_played == 9
} }
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
if let Some(winner) = self.get_winner() {
if winner == *player {
return 1.0; // Win
} else {
return 0.0; // Loss
}
}
// Draw
0.5
}
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> { fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
HashMap::from_iter(vec![ HashMap::from_iter(vec![
(Player::X, self.reward_for_player(&Player::X)), (Player::X, self.reward_for_player(&Player::X)),

View File

@ -119,7 +119,11 @@ struct Move {
index: usize, index: usize,
} }
impl Action for Move {} impl Action for Move {
fn id(&self) -> usize {
self.index
}
}
/// Tic-Tac-Toe game state /// Tic-Tac-Toe game state
#[derive(Clone)] #[derive(Clone)]
@ -191,19 +195,6 @@ impl TicTacToe {
None None
} }
fn reward_for_player(&self, player: &Player) -> RewardVal {
if let Some(winner) = self.get_winner() {
if winner == *player {
return 1.0; // Win
} else {
return 0.0; // Loss
}
}
// Draw
0.5
}
} }
impl GameState for TicTacToe { impl GameState for TicTacToe {
@ -240,6 +231,19 @@ impl GameState for TicTacToe {
self.get_winner().is_some() || self.moves_played == 9 self.get_winner().is_some() || self.moves_played == 9
} }
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
if let Some(winner) = self.get_winner() {
if winner == *player {
return 1.0; // Win
} else {
return 0.0; // Loss
}
}
// Draw
0.5
}
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> { fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
HashMap::from_iter(vec![ HashMap::from_iter(vec![
(Player::X, self.reward_for_player(&Player::X)), (Player::X, self.reward_for_player(&Player::X)),

View File

@ -75,15 +75,8 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
if !selected_node.state.is_terminal() { if !selected_node.state.is_terminal() {
self.expand(selected_id); self.expand(selected_id);
let children: &Vec<usize> = &self.arena.get_node(selected_id).children; let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
match children.choose(&mut rand::rng()) { let random_child: usize = *children.choose(&mut rand::rng()).unwrap();
Some(&random_child) => { selected_id = random_child;
selected_id = random_child;
}
None => {
// We ran out of nodes
return Err(MCTSError::NonTerminalGame);
}
}
} }
let rewards = self.simulate(selected_id); let rewards = self.simulate(selected_id);
self.backprop(selected_id, &rewards); self.backprop(selected_id, &rewards);

View File

@ -40,11 +40,11 @@ pub trait GameState: Clone + Debug {
/// instead should modify a copy of the state and return that. /// instead should modify a copy of the state and return that.
fn state_after_action(&self, action: &Self::Action) -> Self; fn state_after_action(&self, action: &Self::Action) -> Self;
/// Returns the rewards for all players from their perspective for the game state /// Returns the reward from the perspective of the given player for the game state
/// ///
/// This evaluates the current state from the perspective of each player, and /// This evaluates the current state from the perspective of the given player, and
/// returns a HashMap mapping each player to the result of this evaluation, which /// returns the reward indicating how good of a result the given state is for the
/// we call the reward. /// player.
/// ///
/// This is used in the MCTS backpropagation and simulation phases to evaluate /// This is used in the MCTS backpropagation and simulation phases to evaluate
/// the value of a given node in the search tree. /// the value of a given node in the search tree.
@ -55,6 +55,9 @@ pub trait GameState: Clone + Debug {
/// - 0.0 => a loss for the player /// - 0.0 => a loss for the player
/// ///
/// Other values can be used for relative wins or losses /// Other values can be used for relative wins or losses
fn reward_for_player(&self, player: &Self::Player) -> RewardVal;
/// Returns the rewards for all players at the current state
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal>; fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal>;
/// Returns the player whose turn it is for the game state /// Returns the player whose turn it is for the game state
@ -69,7 +72,10 @@ pub trait GameState: Clone + Debug {
/// ///
/// An action is dependent upon the specific game being defined, and includes /// An action is dependent upon the specific game being defined, and includes
/// things like moves, attacks, and other decisions. /// things like moves, attacks, and other decisions.
pub trait Action: Clone + Debug {} pub trait Action: Clone + Debug {
/// Returns a uniqie identifier for this action
fn id(&self) -> usize;
}
/// Trait used for players participating in a game /// Trait used for players participating in a game
pub trait Player: Clone + Debug + PartialEq + Eq + Hash {} pub trait Player: Clone + Debug + PartialEq + Eq + Hash {}