Compare commits

..

No commits in common. "6a338182389a8847615d71dd8d51275df144a64e" and "6cc6e6a7ba5589b60656c2e1aec3f9f544e9abd6" have entirely different histories.

9 changed files with 52 additions and 100 deletions

46
Cargo.lock generated
View File

@ -2,12 +2,6 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 4 version = 4
[[package]]
name = "bitflags"
version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
version = "1.0.1" version = "1.0.1"
@ -16,13 +10,12 @@ checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.3.3" version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"r-efi",
"wasi", "wasi",
] ]
@ -59,27 +52,22 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]] [[package]]
name = "rand" name = "rand"
version = "0.9.1" version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [ dependencies = [
"libc",
"rand_chacha", "rand_chacha",
"rand_core", "rand_core",
] ]
[[package]] [[package]]
name = "rand_chacha" name = "rand_chacha"
version = "0.9.0" version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [ dependencies = [
"ppv-lite86", "ppv-lite86",
"rand_core", "rand_core",
@ -87,9 +75,9 @@ dependencies = [
[[package]] [[package]]
name = "rand_core" name = "rand_core"
version = "0.9.3" version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [ dependencies = [
"getrandom", "getrandom",
] ]
@ -141,21 +129,9 @@ checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.14.2+wasi-0.2.4" version = "0.11.1+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags",
]
[[package]] [[package]]
name = "zerocopy" name = "zerocopy"

View File

@ -12,5 +12,5 @@ categories = ["algorithms", "data-structures"]
[dependencies] [dependencies]
rand = "~0.9" rand = "~0.8"
thiserror = "~2.0" thiserror = "~2.0"

View File

@ -40,7 +40,6 @@ fn main() {
// Find the best move // Find the best move
match mcts.search() { match mcts.search() {
Ok(action) => { Ok(action) => {
mcts.print_tree();
println!( println!(
"AI chooses: {} (row {}, col {})", "AI chooses: {} (row {}, col {})",
action.index, action.index,

View File

@ -6,8 +6,7 @@ use crate::policy::simulation::simulate_reward;
use crate::state::GameState; use crate::state::GameState;
use crate::tree::arena::Arena; use crate::tree::arena::Arena;
use crate::tree::node::{Node, RewardVal}; use crate::tree::node::{Node, RewardVal};
use crate::tree::print_tree; use rand::prelude::SliceRandom;
use rand::prelude::IndexedRandom;
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Instant; use std::time::Instant;
@ -75,7 +74,7 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
if !selected_node.state.is_terminal() { if !selected_node.state.is_terminal() {
self.expand(selected_id); self.expand(selected_id);
let children: &Vec<usize> = &self.arena.get_node(selected_id).children; let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
let random_child: usize = *children.choose(&mut rand::rng()).unwrap(); let random_child: usize = *children.choose(&mut rand::thread_rng()).unwrap();
selected_id = random_child; selected_id = random_child;
} }
let rewards = self.simulate(selected_id); let rewards = self.simulate(selected_id);
@ -131,10 +130,6 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
None => Err(MCTSError::NoBestAction), None => Err(MCTSError::NoBestAction),
} }
} }
/// Prints the search tree held by this instance to stdout.
///
/// Delegates to the free `print_tree` function, starting at `self.root_id`
/// and reading nodes from `self.arena`.
pub fn print_tree(&self) {
    let root = self.root_id;
    print_tree(root, &self.arena);
}
} }
/// Errors returned by the MCTS algorithm /// Errors returned by the MCTS algorithm

View File

@ -69,14 +69,18 @@ fn standard_backprop<S: GameState>(
let mut current_id: usize = node_id; let mut current_id: usize = node_id;
loop { loop {
let node = arena.get_node_mut(current_id); let node = arena.get_node_mut(current_id);
node.increment_visits(); let player = node.state.get_current_player().clone();
for (player, reward) in rewards.iter() { match rewards.get(&player) {
node.record_player_reward(player.clone(), *reward); Some(reward) => {
} node.increment_visits();
if let Some(parent_id) = node.parent { node.record_player_reward(player, *reward);
current_id = parent_id; if let Some(parent_id) = node.parent {
} else { current_id = parent_id;
break; } else {
break;
}
}
None => (),
} }
} }
} }
@ -90,14 +94,19 @@ fn weighted_backprop<S: GameState>(
let mut current_id: usize = node_id; let mut current_id: usize = node_id;
loop { loop {
let node = arena.get_node_mut(current_id); let node = arena.get_node_mut(current_id);
let player = node.state.get_current_player().clone();
let weight = weight_for_depth(depth_factor, node.depth); let weight = weight_for_depth(depth_factor, node.depth);
for (player, reward) in rewards.iter() { match rewards.get(&player) {
node.record_player_reward(player.clone(), (*reward) * weight); Some(reward) => {
} node.increment_visits();
if let Some(parent_id) = node.parent { node.record_player_reward(player, (*reward) * weight);
current_id = parent_id; if let Some(parent_id) = node.parent {
} else { current_id = parent_id;
break; } else {
break;
}
}
None => (),
} }
} }
} }

View File

@ -4,14 +4,14 @@
use crate::state::GameState; use crate::state::GameState;
use crate::tree::node::{Node, RewardVal}; use crate::tree::node::{Node, RewardVal};
use rand::prelude::IndexedRandom; use rand::prelude::SliceRandom;
use std::collections::HashMap; use std::collections::HashMap;
pub fn simulate<S: GameState>(node: &Node<S>) -> HashMap<S::Player, RewardVal> { pub fn simulate<S: GameState>(node: &Node<S>) -> HashMap<S::Player, RewardVal> {
let mut state: S = node.state.clone(); let mut state: S = node.state.clone();
while !state.is_terminal() { while !state.is_terminal() {
let legal_actions = state.get_legal_actions(); let legal_actions = state.get_legal_actions();
let action = legal_actions.choose(&mut rand::rng()).unwrap(); let action = legal_actions.choose(&mut rand::thread_rng()).unwrap();
state = state.state_after_action(&action); state = state.state_after_action(&action);
} }
state.rewards_for_players() state.rewards_for_players()

View File

@ -7,7 +7,7 @@ use std::hash::Hash;
/// ///
/// When leveraging MCTS for your game, you must implement this trait to provide /// When leveraging MCTS for your game, you must implement this trait to provide
/// the specifics for your game. /// the specifics for your game.
pub trait GameState: Clone + Debug { pub trait GameState: Clone {
/// The type of actions that can be taken in the game /// The type of actions that can be taken in the game
type Action: Action; type Action: Action;

View File

@ -1,27 +1,2 @@
pub mod arena; pub mod arena;
pub mod node; pub mod node;
use crate::state::GameState;
use crate::tree::arena::Arena;
use crate::tree::node::Node;
/// Prints the subtree rooted at `node_id` to stdout, one node per line.
///
/// Nodes are visited depth-first using an explicit stack, so a node is always
/// printed before its children and children appear in their stored order.
/// Every node below depth 0 is prefixed with one `"| "` per level above its
/// parent, followed by `"|- "`. Each line shows the node id, the node's
/// action, its visit count and the per-player view, using their `Debug`
/// output where applicable.
pub fn print_tree<S: GameState>(node_id: usize, arena: &Arena<S>) {
    // Explicit stack instead of recursion; the entry on top is printed next.
    let mut pending: Vec<usize> = vec![node_id];
    while let Some(current_id) = pending.pop() {
        let current: &Node<S> = arena.get_node(current_id);
        if current.depth > 0 {
            print!("{}", "| ".repeat(current.depth - 1));
            print!("|- ");
        }
        println!(
            "{:?} a:{:?} v:{} {:?}",
            current_id, current.action, current.visits, current.player_view
        );
        // Push in reverse so the first child ends up on top of the stack.
        pending.extend(current.children.iter().rev().copied());
    }
}

View File

@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug;
use crate::state::GameState; use crate::state::GameState;
@ -14,7 +15,7 @@ pub type RewardVal = f64;
/// ///
/// This class is not thread safe, as the library does not provide for parallel /// This class is not thread safe, as the library does not provide for parallel
/// search. /// search.
#[derive(std::fmt::Debug)] #[derive(Debug)]
pub struct Node<S: GameState> { pub struct Node<S: GameState> {
/// The game state at the given node, after `action` /// The game state at the given node, after `action`
pub state: S, pub state: S,
@ -29,7 +30,7 @@ pub struct Node<S: GameState> {
pub visits: u64, pub visits: u64,
/// The player's evaluation of the node /// The player's evaluation of the node
pub player_view: HashMap<S::Player, PlayerRewardView>, pub player_view: HashMap<S::Player, PlayerNodeView>,
/// The depth of the node in the tree, this is 0 for the root node /// The depth of the node in the tree, this is 0 for the root node
pub depth: usize, pub depth: usize,
@ -81,7 +82,10 @@ impl<S: GameState> Node<S> {
} }
pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) { pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
let pv = self.player_view.entry(player).or_default(); let pv = self
.player_view
.entry(player)
.or_insert(PlayerNodeView::default());
pv.rewards.push(reward); pv.rewards.push(reward);
pv.reward_sum += reward; pv.reward_sum += reward;
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64; pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
@ -91,7 +95,8 @@ impl<S: GameState> Node<S> {
/// A player's specific perspective of a node's value /// A player's specific perspective of a node's value
/// ///
/// Each player has their own idea of the value of a node. /// Each player has their own idea of the value of a node.
pub struct PlayerRewardView { #[derive(Debug)]
pub struct PlayerNodeView {
/// The total reward from simulations through this node /// The total reward from simulations through this node
pub reward_sum: RewardVal, pub reward_sum: RewardVal,
@ -102,19 +107,12 @@ pub struct PlayerRewardView {
pub rewards: Vec<RewardVal>, pub rewards: Vec<RewardVal>,
} }
impl Default for PlayerRewardView { impl Default for PlayerNodeView {
fn default() -> Self { fn default() -> Self {
PlayerRewardView { PlayerNodeView {
reward_sum: 0.0, reward_sum: 0.0,
reward_average: 0.0, reward_average: 0.0,
rewards: Vec::new(), rewards: Vec::new(),
} }
} }
} }
/// Compact debug rendering that shows only the aggregate figures.
///
/// The output has the form `{sum=<reward_sum> avg=<reward_average>}`; the
/// individual entries stored in `rewards` are not printed.
impl std::fmt::Debug for PlayerRewardView {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            reward_sum,
            reward_average,
            ..
        } = self;
        // `write!` already yields `fmt::Result`, so it is returned directly.
        write!(f, "{{sum={} avg={}}}", reward_sum, reward_average)
    }
}