Compare commits
4 Commits
6cc6e6a7ba
...
6a33818238
Author | SHA1 | Date | |
---|---|---|---|
6a33818238 | |||
0f9d4f0c4e | |||
9f893b0005 | |||
b80f039b93 |
48
Cargo.lock
generated
48
Cargo.lock
generated
@ -2,6 +2,12 @@
|
|||||||
# It is not intended for manual editing.
|
# It is not intended for manual editing.
|
||||||
version = 4
|
version = 4
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bitflags"
|
||||||
|
version = "2.9.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cfg-if"
|
name = "cfg-if"
|
||||||
version = "1.0.1"
|
version = "1.0.1"
|
||||||
@ -10,12 +16,13 @@ checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getrandom"
|
name = "getrandom"
|
||||||
version = "0.2.16"
|
version = "0.3.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
|
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"libc",
|
"libc",
|
||||||
|
"r-efi",
|
||||||
"wasi",
|
"wasi",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -53,21 +60,26 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand"
|
name = "r-efi"
|
||||||
version = "0.8.5"
|
version = "5.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand"
|
||||||
|
version = "0.9.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
|
||||||
"rand_chacha",
|
"rand_chacha",
|
||||||
"rand_core",
|
"rand_core",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand_chacha"
|
name = "rand_chacha"
|
||||||
version = "0.3.1"
|
version = "0.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ppv-lite86",
|
"ppv-lite86",
|
||||||
"rand_core",
|
"rand_core",
|
||||||
@ -75,9 +87,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand_core"
|
name = "rand_core"
|
||||||
version = "0.6.4"
|
version = "0.9.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"getrandom",
|
"getrandom",
|
||||||
]
|
]
|
||||||
@ -129,9 +141,21 @@ checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasi"
|
name = "wasi"
|
||||||
version = "0.11.1+wasi-snapshot-preview1"
|
version = "0.14.2+wasi-0.2.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
|
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
|
||||||
|
dependencies = [
|
||||||
|
"wit-bindgen-rt",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wit-bindgen-rt"
|
||||||
|
version = "0.39.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zerocopy"
|
name = "zerocopy"
|
||||||
|
@ -12,5 +12,5 @@ categories = ["algorithms", "data-structures"]
|
|||||||
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
rand = "~0.8"
|
rand = "~0.9"
|
||||||
thiserror = "~2.0"
|
thiserror = "~2.0"
|
||||||
|
@ -40,6 +40,7 @@ fn main() {
|
|||||||
// Find the best move
|
// Find the best move
|
||||||
match mcts.search() {
|
match mcts.search() {
|
||||||
Ok(action) => {
|
Ok(action) => {
|
||||||
|
mcts.print_tree();
|
||||||
println!(
|
println!(
|
||||||
"AI chooses: {} (row {}, col {})",
|
"AI chooses: {} (row {}, col {})",
|
||||||
action.index,
|
action.index,
|
||||||
|
@ -6,7 +6,8 @@ use crate::policy::simulation::simulate_reward;
|
|||||||
use crate::state::GameState;
|
use crate::state::GameState;
|
||||||
use crate::tree::arena::Arena;
|
use crate::tree::arena::Arena;
|
||||||
use crate::tree::node::{Node, RewardVal};
|
use crate::tree::node::{Node, RewardVal};
|
||||||
use rand::prelude::SliceRandom;
|
use crate::tree::print_tree;
|
||||||
|
use rand::prelude::IndexedRandom;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
|||||||
if !selected_node.state.is_terminal() {
|
if !selected_node.state.is_terminal() {
|
||||||
self.expand(selected_id);
|
self.expand(selected_id);
|
||||||
let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
|
let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
|
||||||
let random_child: usize = *children.choose(&mut rand::thread_rng()).unwrap();
|
let random_child: usize = *children.choose(&mut rand::rng()).unwrap();
|
||||||
selected_id = random_child;
|
selected_id = random_child;
|
||||||
}
|
}
|
||||||
let rewards = self.simulate(selected_id);
|
let rewards = self.simulate(selected_id);
|
||||||
@ -130,6 +131,10 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> {
|
|||||||
None => Err(MCTSError::NoBestAction),
|
None => Err(MCTSError::NoBestAction),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn print_tree(&self) {
|
||||||
|
print_tree(self.root_id, &self.arena)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Errors returned by the MCTS algorithm
|
/// Errors returned by the MCTS algorithm
|
||||||
|
@ -69,18 +69,14 @@ fn standard_backprop<S: GameState>(
|
|||||||
let mut current_id: usize = node_id;
|
let mut current_id: usize = node_id;
|
||||||
loop {
|
loop {
|
||||||
let node = arena.get_node_mut(current_id);
|
let node = arena.get_node_mut(current_id);
|
||||||
let player = node.state.get_current_player().clone();
|
node.increment_visits();
|
||||||
match rewards.get(&player) {
|
for (player, reward) in rewards.iter() {
|
||||||
Some(reward) => {
|
node.record_player_reward(player.clone(), *reward);
|
||||||
node.increment_visits();
|
}
|
||||||
node.record_player_reward(player, *reward);
|
if let Some(parent_id) = node.parent {
|
||||||
if let Some(parent_id) = node.parent {
|
current_id = parent_id;
|
||||||
current_id = parent_id;
|
} else {
|
||||||
} else {
|
break;
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -94,19 +90,14 @@ fn weighted_backprop<S: GameState>(
|
|||||||
let mut current_id: usize = node_id;
|
let mut current_id: usize = node_id;
|
||||||
loop {
|
loop {
|
||||||
let node = arena.get_node_mut(current_id);
|
let node = arena.get_node_mut(current_id);
|
||||||
let player = node.state.get_current_player().clone();
|
|
||||||
let weight = weight_for_depth(depth_factor, node.depth);
|
let weight = weight_for_depth(depth_factor, node.depth);
|
||||||
match rewards.get(&player) {
|
for (player, reward) in rewards.iter() {
|
||||||
Some(reward) => {
|
node.record_player_reward(player.clone(), (*reward) * weight);
|
||||||
node.increment_visits();
|
}
|
||||||
node.record_player_reward(player, (*reward) * weight);
|
if let Some(parent_id) = node.parent {
|
||||||
if let Some(parent_id) = node.parent {
|
current_id = parent_id;
|
||||||
current_id = parent_id;
|
} else {
|
||||||
} else {
|
break;
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,14 +4,14 @@
|
|||||||
|
|
||||||
use crate::state::GameState;
|
use crate::state::GameState;
|
||||||
use crate::tree::node::{Node, RewardVal};
|
use crate::tree::node::{Node, RewardVal};
|
||||||
use rand::prelude::SliceRandom;
|
use rand::prelude::IndexedRandom;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
pub fn simulate<S: GameState>(node: &Node<S>) -> HashMap<S::Player, RewardVal> {
|
pub fn simulate<S: GameState>(node: &Node<S>) -> HashMap<S::Player, RewardVal> {
|
||||||
let mut state: S = node.state.clone();
|
let mut state: S = node.state.clone();
|
||||||
while !state.is_terminal() {
|
while !state.is_terminal() {
|
||||||
let legal_actions = state.get_legal_actions();
|
let legal_actions = state.get_legal_actions();
|
||||||
let action = legal_actions.choose(&mut rand::thread_rng()).unwrap();
|
let action = legal_actions.choose(&mut rand::rng()).unwrap();
|
||||||
state = state.state_after_action(&action);
|
state = state.state_after_action(&action);
|
||||||
}
|
}
|
||||||
state.rewards_for_players()
|
state.rewards_for_players()
|
||||||
|
@ -7,7 +7,7 @@ use std::hash::Hash;
|
|||||||
///
|
///
|
||||||
/// When leveraging MCTS for your game, you must implement this trait to provide
|
/// When leveraging MCTS for your game, you must implement this trait to provide
|
||||||
/// the specifics for your game.
|
/// the specifics for your game.
|
||||||
pub trait GameState: Clone {
|
pub trait GameState: Clone + Debug {
|
||||||
/// The type of actions that can be taken in the game
|
/// The type of actions that can be taken in the game
|
||||||
type Action: Action;
|
type Action: Action;
|
||||||
|
|
||||||
|
@ -1,2 +1,27 @@
|
|||||||
pub mod arena;
|
pub mod arena;
|
||||||
pub mod node;
|
pub mod node;
|
||||||
|
|
||||||
|
use crate::state::GameState;
|
||||||
|
use crate::tree::arena::Arena;
|
||||||
|
use crate::tree::node::Node;
|
||||||
|
|
||||||
|
pub fn print_tree<S: GameState>(node_id: usize, arena: &Arena<S>) {
|
||||||
|
let mut to_print: Vec<usize> = Vec::new();
|
||||||
|
to_print.push(node_id);
|
||||||
|
while let Some(node_id) = to_print.pop() {
|
||||||
|
let node: &Node<S> = arena.get_node(node_id);
|
||||||
|
if node.depth > 0 {
|
||||||
|
for _ in 0..node.depth - 1 {
|
||||||
|
print!("| ")
|
||||||
|
}
|
||||||
|
print!("|- ");
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"{:?} a:{:?} v:{} {:?}",
|
||||||
|
node_id, node.action, node.visits, node.player_view
|
||||||
|
);
|
||||||
|
for child_id in node.children.iter().rev() {
|
||||||
|
to_print.push(*child_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Debug;
|
|
||||||
|
|
||||||
use crate::state::GameState;
|
use crate::state::GameState;
|
||||||
|
|
||||||
@ -15,7 +14,7 @@ pub type RewardVal = f64;
|
|||||||
///
|
///
|
||||||
/// This class is not thread safe, as the library does not provide for parallel
|
/// This class is not thread safe, as the library does not provide for parallel
|
||||||
/// search.
|
/// search.
|
||||||
#[derive(Debug)]
|
#[derive(std::fmt::Debug)]
|
||||||
pub struct Node<S: GameState> {
|
pub struct Node<S: GameState> {
|
||||||
/// The game state at the given node, after `action`
|
/// The game state at the given node, after `action`
|
||||||
pub state: S,
|
pub state: S,
|
||||||
@ -30,7 +29,7 @@ pub struct Node<S: GameState> {
|
|||||||
pub visits: u64,
|
pub visits: u64,
|
||||||
|
|
||||||
/// The player's evaluation of the node
|
/// The player's evaluation of the node
|
||||||
pub player_view: HashMap<S::Player, PlayerNodeView>,
|
pub player_view: HashMap<S::Player, PlayerRewardView>,
|
||||||
|
|
||||||
/// The depth of the node in the tree, this is 0 for the root node
|
/// The depth of the node in the tree, this is 0 for the root node
|
||||||
pub depth: usize,
|
pub depth: usize,
|
||||||
@ -82,10 +81,7 @@ impl<S: GameState> Node<S> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
|
pub fn record_player_reward(&mut self, player: S::Player, reward: RewardVal) {
|
||||||
let pv = self
|
let pv = self.player_view.entry(player).or_default();
|
||||||
.player_view
|
|
||||||
.entry(player)
|
|
||||||
.or_insert(PlayerNodeView::default());
|
|
||||||
pv.rewards.push(reward);
|
pv.rewards.push(reward);
|
||||||
pv.reward_sum += reward;
|
pv.reward_sum += reward;
|
||||||
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
|
pv.reward_average = pv.reward_sum / pv.rewards.len() as f64;
|
||||||
@ -95,8 +91,7 @@ impl<S: GameState> Node<S> {
|
|||||||
/// A player's specific perspective of a node's value
|
/// A player's specific perspective of a node's value
|
||||||
///
|
///
|
||||||
/// Each player has their own idea of the value of a node.
|
/// Each player has their own idea of the value of a node.
|
||||||
#[derive(Debug)]
|
pub struct PlayerRewardView {
|
||||||
pub struct PlayerNodeView {
|
|
||||||
/// The total reward from simulations through this node
|
/// The total reward from simulations through this node
|
||||||
pub reward_sum: RewardVal,
|
pub reward_sum: RewardVal,
|
||||||
|
|
||||||
@ -107,12 +102,19 @@ pub struct PlayerNodeView {
|
|||||||
pub rewards: Vec<RewardVal>,
|
pub rewards: Vec<RewardVal>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for PlayerNodeView {
|
impl Default for PlayerRewardView {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
PlayerNodeView {
|
PlayerRewardView {
|
||||||
reward_sum: 0.0,
|
reward_sum: 0.0,
|
||||||
reward_average: 0.0,
|
reward_average: 0.0,
|
||||||
rewards: Vec::new(),
|
rewards: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for PlayerRewardView {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{{sum={} avg={}}}", self.reward_sum, self.reward_average)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user