Adding really stupid benchmark
This commit is contained in:
parent
6a33818238
commit
44ef9ebdd8
255
Cargo.lock
generated
255
Cargo.lock
generated
@ -2,6 +2,12 @@
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.9.1"
|
||||
@ -14,6 +20,73 @@ version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"clap_lex",
|
||||
"terminal_size",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
|
||||
|
||||
[[package]]
|
||||
name = "condtype"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af"
|
||||
|
||||
[[package]]
|
||||
name = "divan"
|
||||
version = "0.1.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a405457ec78b8fe08b0e32b4a3570ab5dff6dd16eb9e76a5ee0a9d9cbd898933"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"clap",
|
||||
"condtype",
|
||||
"divan-macros",
|
||||
"libc",
|
||||
"regex-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "divan-macros"
|
||||
version = "0.1.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9556bc800956545d6420a640173e5ba7dfa82f38d3ea5a167eb555bc69ac3323"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.3.3"
|
||||
@ -32,6 +105,12 @@ version = "0.2.174"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.21"
|
||||
@ -94,14 +173,34 @@ dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-lite"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a"
|
||||
|
||||
[[package]]
|
||||
name = "rustic_mcts"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"divan",
|
||||
"rand",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.103"
|
||||
@ -113,6 +212,16 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "terminal_size"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed"
|
||||
dependencies = [
|
||||
"rustix",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.12"
|
||||
@ -148,6 +257,152 @@ dependencies = [
|
||||
"wit-bindgen-rt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.59.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
|
||||
dependencies = [
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.60.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
|
||||
dependencies = [
|
||||
"windows-targets 0.53.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.52.6",
|
||||
"windows_aarch64_msvc 0.52.6",
|
||||
"windows_i686_gnu 0.52.6",
|
||||
"windows_i686_gnullvm 0.52.6",
|
||||
"windows_i686_msvc 0.52.6",
|
||||
"windows_x86_64_gnu 0.52.6",
|
||||
"windows_x86_64_gnullvm 0.52.6",
|
||||
"windows_x86_64_msvc 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.53.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.53.0",
|
||||
"windows_aarch64_msvc 0.53.0",
|
||||
"windows_i686_gnu 0.53.0",
|
||||
"windows_i686_gnullvm 0.53.0",
|
||||
"windows_i686_msvc 0.53.0",
|
||||
"windows_x86_64_gnu 0.53.0",
|
||||
"windows_x86_64_gnullvm 0.53.0",
|
||||
"windows_x86_64_msvc 0.53.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rt"
|
||||
version = "0.39.0"
|
||||
|
@ -14,3 +14,10 @@ categories = ["algorithms", "data-structures"]
|
||||
[dependencies]
|
||||
rand = "~0.9"
|
||||
thiserror = "~2.0"
|
||||
|
||||
[dev-dependencies]
|
||||
divan = "0.1.21"
|
||||
|
||||
[[bench]]
|
||||
name = "example"
|
||||
harness = false
|
||||
|
192
benches/example.rs
Normal file
192
benches/example.rs
Normal file
@ -0,0 +1,192 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rustic_mcts::policy::backprop::BackpropagationPolicy;
|
||||
use rustic_mcts::policy::decision::DecisionPolicy;
|
||||
use rustic_mcts::policy::selection::SelectionPolicy;
|
||||
use rustic_mcts::policy::simulation::SimulationPolicy;
|
||||
use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS};
|
||||
|
||||
fn main() {
|
||||
// Run registered benchmarks.
|
||||
divan::main();
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn tic_tac_toe() {
|
||||
// Set up a new game
|
||||
let mut game = TicTacToe::new();
|
||||
|
||||
// Create MCTS configuration
|
||||
let config = MCTSConfig {
|
||||
max_iterations: 10_000,
|
||||
max_time: None,
|
||||
tree_size_allocation: 10_000,
|
||||
selection_policy: SelectionPolicy::UCB1Tuned(1.414),
|
||||
simulation_policy: SimulationPolicy::Random,
|
||||
backprop_policy: BackpropagationPolicy::Standard,
|
||||
decision_policy: DecisionPolicy::MostVisits,
|
||||
};
|
||||
|
||||
// Main game loop
|
||||
while !game.is_terminal() {
|
||||
// Create a new MCTS search
|
||||
let mut mcts = MCTS::new(game.clone(), &config);
|
||||
|
||||
// Find the best move
|
||||
match mcts.search() {
|
||||
Ok(action) => {
|
||||
// Apply the AI's move
|
||||
game = game.state_after_action(&action);
|
||||
}
|
||||
Err(_) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Players in Tic-Tac-Toe
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
enum Player {
|
||||
X,
|
||||
O,
|
||||
}
|
||||
|
||||
impl rustic_mcts::Player for Player {}
|
||||
|
||||
/// Tic-Tac-Toe move
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct Move {
|
||||
/// Board position index (0-8)
|
||||
index: usize,
|
||||
}
|
||||
|
||||
impl Action for Move {
|
||||
fn id(&self) -> usize {
|
||||
self.index
|
||||
}
|
||||
}
|
||||
|
||||
/// Tic-Tac-Toe game state
|
||||
#[derive(Debug, Clone)]
|
||||
struct TicTacToe {
|
||||
/// Board representation (None = empty, Some(Player) = occupied)
|
||||
board: [Option<Player>; 9],
|
||||
|
||||
/// Current player's turn
|
||||
current_player: Player,
|
||||
|
||||
/// Number of moves played so far
|
||||
moves_played: usize,
|
||||
}
|
||||
|
||||
impl TicTacToe {
|
||||
/// Creates a new empty Tic-Tac-Toe board
|
||||
fn new() -> Self {
|
||||
TicTacToe {
|
||||
board: [None; 9],
|
||||
current_player: Player::X,
|
||||
moves_played: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the winner of the game, if any
|
||||
fn get_winner(&self) -> Option<Player> {
|
||||
// Check rows
|
||||
for row in 0..3 {
|
||||
let i = row * 3;
|
||||
if self.board[i].is_some()
|
||||
&& self.board[i] == self.board[i + 1]
|
||||
&& self.board[i] == self.board[i + 2]
|
||||
{
|
||||
return self.board[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Check columns
|
||||
for col in 0..3 {
|
||||
if self.board[col].is_some()
|
||||
&& self.board[col] == self.board[col + 3]
|
||||
&& self.board[col] == self.board[col + 6]
|
||||
{
|
||||
return self.board[col];
|
||||
}
|
||||
}
|
||||
|
||||
// Check diagonals
|
||||
if self.board[0].is_some()
|
||||
&& self.board[0] == self.board[4]
|
||||
&& self.board[0] == self.board[8]
|
||||
{
|
||||
return self.board[0];
|
||||
}
|
||||
if self.board[2].is_some()
|
||||
&& self.board[2] == self.board[4]
|
||||
&& self.board[2] == self.board[6]
|
||||
{
|
||||
return self.board[2];
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl GameState for TicTacToe {
|
||||
type Action = Move;
|
||||
type Player = Player;
|
||||
|
||||
fn get_legal_actions(&self) -> Vec<Self::Action> {
|
||||
let mut actions = Vec::new();
|
||||
for i in 0..9 {
|
||||
if self.board[i].is_none() {
|
||||
actions.push(Move { index: i });
|
||||
}
|
||||
}
|
||||
actions
|
||||
}
|
||||
|
||||
fn state_after_action(&self, action: &Self::Action) -> Self {
|
||||
let mut new_state = self.clone();
|
||||
|
||||
// Make the move
|
||||
new_state.board[action.index] = Some(self.current_player);
|
||||
new_state.moves_played = self.moves_played + 1;
|
||||
|
||||
// Switch player
|
||||
new_state.current_player = match self.current_player {
|
||||
Player::X => Player::O,
|
||||
Player::O => Player::X,
|
||||
};
|
||||
|
||||
new_state
|
||||
}
|
||||
|
||||
fn is_terminal(&self) -> bool {
|
||||
self.get_winner().is_some() || self.moves_played == 9
|
||||
}
|
||||
|
||||
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
|
||||
if let Some(winner) = self.get_winner() {
|
||||
if winner == *player {
|
||||
return 1.0; // Win
|
||||
} else {
|
||||
return 0.0; // Loss
|
||||
}
|
||||
}
|
||||
|
||||
// Draw
|
||||
0.5
|
||||
}
|
||||
|
||||
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
||||
HashMap::from_iter(vec![
|
||||
(Player::X, self.reward_for_player(&Player::X)),
|
||||
(Player::O, self.reward_for_player(&Player::O)),
|
||||
])
|
||||
}
|
||||
|
||||
fn get_current_player(&self) -> &Self::Player {
|
||||
&self.current_player
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user