Compare commits
3 Commits
multithrea
...
master
Author | SHA1 | Date | |
---|---|---|---|
6aa9002e92 | |||
76051cd76b | |||
a7102a0e44 |
@ -10,15 +10,13 @@ readme = "README.md"
|
|||||||
keywords = ["mcts", "rust", "monte_carlo", "tree", "ai", "ml"]
|
keywords = ["mcts", "rust", "monte_carlo", "tree", "ai", "ml"]
|
||||||
categories = ["algorithms", "data-structures"]
|
categories = ["algorithms", "data-structures"]
|
||||||
|
|
||||||
[features]
|
|
||||||
threads = []
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
rand = "~0.9"
|
rand = "~0.9"
|
||||||
thiserror = "~2.0"
|
thiserror = "~2.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
divan = "0.1.21"
|
divan = "~0.1"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "e2e"
|
name = "e2e"
|
||||||
|
@ -61,11 +61,7 @@ struct Move {
|
|||||||
index: usize,
|
index: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Action for Move {
|
impl Action for Move {}
|
||||||
fn id(&self) -> usize {
|
|
||||||
self.index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tic-Tac-Toe game state
|
/// Tic-Tac-Toe game state
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -129,6 +125,19 @@ impl TicTacToe {
|
|||||||
|
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reward_for_player(&self, player: &Player) -> RewardVal {
|
||||||
|
if let Some(winner) = self.get_winner() {
|
||||||
|
if winner == *player {
|
||||||
|
return 1.0; // Win
|
||||||
|
} else {
|
||||||
|
return 0.0; // Loss
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw
|
||||||
|
0.5
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GameState for TicTacToe {
|
impl GameState for TicTacToe {
|
||||||
@ -165,19 +174,6 @@ impl GameState for TicTacToe {
|
|||||||
self.get_winner().is_some() || self.moves_played == 9
|
self.get_winner().is_some() || self.moves_played == 9
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
|
|
||||||
if let Some(winner) = self.get_winner() {
|
|
||||||
if winner == *player {
|
|
||||||
return 1.0; // Win
|
|
||||||
} else {
|
|
||||||
return 0.0; // Loss
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Draw
|
|
||||||
0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
||||||
HashMap::from_iter(vec![
|
HashMap::from_iter(vec![
|
||||||
(Player::X, self.reward_for_player(&Player::X)),
|
(Player::X, self.reward_for_player(&Player::X)),
|
||||||
|
@ -84,11 +84,7 @@ struct Move {
|
|||||||
index: usize,
|
index: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Action for Move {
|
impl Action for Move {}
|
||||||
fn id(&self) -> usize {
|
|
||||||
self.index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tic-Tac-Toe game state
|
/// Tic-Tac-Toe game state
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -152,6 +148,19 @@ impl TicTacToe {
|
|||||||
|
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reward_for_player(&self, player: &Player) -> RewardVal {
|
||||||
|
if let Some(winner) = self.get_winner() {
|
||||||
|
if winner == *player {
|
||||||
|
return 1.0; // Win
|
||||||
|
} else {
|
||||||
|
return 0.0; // Loss
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw
|
||||||
|
0.5
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GameState for TicTacToe {
|
impl GameState for TicTacToe {
|
||||||
@ -188,19 +197,6 @@ impl GameState for TicTacToe {
|
|||||||
self.get_winner().is_some() || self.moves_played == 9
|
self.get_winner().is_some() || self.moves_played == 9
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
|
|
||||||
if let Some(winner) = self.get_winner() {
|
|
||||||
if winner == *player {
|
|
||||||
return 1.0; // Win
|
|
||||||
} else {
|
|
||||||
return 0.0; // Loss
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Draw
|
|
||||||
0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
||||||
HashMap::from_iter(vec![
|
HashMap::from_iter(vec![
|
||||||
(Player::X, self.reward_for_player(&Player::X)),
|
(Player::X, self.reward_for_player(&Player::X)),
|
||||||
|
@ -119,11 +119,7 @@ struct Move {
|
|||||||
index: usize,
|
index: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Action for Move {
|
impl Action for Move {}
|
||||||
fn id(&self) -> usize {
|
|
||||||
self.index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tic-Tac-Toe game state
|
/// Tic-Tac-Toe game state
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -195,6 +191,19 @@ impl TicTacToe {
|
|||||||
|
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reward_for_player(&self, player: &Player) -> RewardVal {
|
||||||
|
if let Some(winner) = self.get_winner() {
|
||||||
|
if winner == *player {
|
||||||
|
return 1.0; // Win
|
||||||
|
} else {
|
||||||
|
return 0.0; // Loss
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw
|
||||||
|
0.5
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GameState for TicTacToe {
|
impl GameState for TicTacToe {
|
||||||
@ -231,19 +240,6 @@ impl GameState for TicTacToe {
|
|||||||
self.get_winner().is_some() || self.moves_played == 9
|
self.get_winner().is_some() || self.moves_played == 9
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
|
|
||||||
if let Some(winner) = self.get_winner() {
|
|
||||||
if winner == *player {
|
|
||||||
return 1.0; // Win
|
|
||||||
} else {
|
|
||||||
return 0.0; // Loss
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Draw
|
|
||||||
0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
||||||
HashMap::from_iter(vec![
|
HashMap::from_iter(vec![
|
||||||
(Player::X, self.reward_for_player(&Player::X)),
|
(Player::X, self.reward_for_player(&Player::X)),
|
||||||
|
11
src/mcts.rs
11
src/mcts.rs
@ -75,8 +75,15 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
|||||||
if !selected_node.state.is_terminal() {
|
if !selected_node.state.is_terminal() {
|
||||||
self.expand(selected_id);
|
self.expand(selected_id);
|
||||||
let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
|
let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
|
||||||
let random_child: usize = *children.choose(&mut rand::rng()).unwrap();
|
match children.choose(&mut rand::rng()) {
|
||||||
selected_id = random_child;
|
Some(&random_child) => {
|
||||||
|
selected_id = random_child;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
// We ran out of nodes
|
||||||
|
return Err(MCTSError::NonTerminalGame);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let rewards = self.simulate(selected_id);
|
let rewards = self.simulate(selected_id);
|
||||||
self.backprop(selected_id, &rewards);
|
self.backprop(selected_id, &rewards);
|
||||||
|
16
src/state.rs
16
src/state.rs
@ -40,11 +40,11 @@ pub trait GameState: Clone + Debug {
|
|||||||
/// instead should modify a copy of the state and return that.
|
/// instead should modify a copy of the state and return that.
|
||||||
fn state_after_action(&self, action: &Self::Action) -> Self;
|
fn state_after_action(&self, action: &Self::Action) -> Self;
|
||||||
|
|
||||||
/// Returns the reward from the perspective of the given player for the game state
|
/// Returns the rewards for all players from their perspective for the game state
|
||||||
///
|
///
|
||||||
/// This evaluates the current state from the perspective of the given player, and
|
/// This evaluates the current state from the perspective of each player, and
|
||||||
/// returns the reward indicating how good of a result the given state is for the
|
/// returns a HashMap mapping each player to the result of this evaluation, which
|
||||||
/// player.
|
/// we call the reward.
|
||||||
///
|
///
|
||||||
/// This is used in the MCTS backpropagation and simulation phases to evaluate
|
/// This is used in the MCTS backpropagation and simulation phases to evaluate
|
||||||
/// the value of a given node in the search tree.
|
/// the value of a given node in the search tree.
|
||||||
@ -55,9 +55,6 @@ pub trait GameState: Clone + Debug {
|
|||||||
/// - 0.0 => a loss for the player
|
/// - 0.0 => a loss for the player
|
||||||
///
|
///
|
||||||
/// Other values can be used for relative wins or losses
|
/// Other values can be used for relative wins or losses
|
||||||
fn reward_for_player(&self, player: &Self::Player) -> RewardVal;
|
|
||||||
|
|
||||||
/// Returns the rewards for all players at the current state
|
|
||||||
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal>;
|
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal>;
|
||||||
|
|
||||||
/// Returns the player whose turn it is for the game state
|
/// Returns the player whose turn it is for the game state
|
||||||
@ -72,10 +69,7 @@ pub trait GameState: Clone + Debug {
|
|||||||
///
|
///
|
||||||
/// An action is dependent upon the specific game being defined, and includes
|
/// An action is dependent upon the specific game being defined, and includes
|
||||||
/// things like moves, attacks, and other decisions.
|
/// things like moves, attacks, and other decisions.
|
||||||
pub trait Action: Clone + Debug {
|
pub trait Action: Clone + Debug {}
|
||||||
/// Returns a uniqie identifier for this action
|
|
||||||
fn id(&self) -> usize;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait used for players participating in a game
|
/// Trait used for players participating in a game
|
||||||
pub trait Player: Clone + Debug + PartialEq + Eq + Hash {}
|
pub trait Player: Clone + Debug + PartialEq + Eq + Hash {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user