diff --git a/Cargo.lock b/Cargo.lock index 8a80c72..098e30e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "anstyle" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" + [[package]] name = "bitflags" version = "2.9.1" @@ -14,6 +20,73 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +[[package]] +name = "clap" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +dependencies = [ + "anstyle", + "clap_lex", + "terminal_size", +] + +[[package]] +name = "clap_lex" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" + +[[package]] +name = "condtype" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" + +[[package]] +name = "divan" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a405457ec78b8fe08b0e32b4a3570ab5dff6dd16eb9e76a5ee0a9d9cbd898933" +dependencies = [ + "cfg-if", + "clap", + "condtype", + "divan-macros", + "libc", + "regex-lite", +] + +[[package]] +name = "divan-macros" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9556bc800956545d6420a640173e5ba7dfa82f38d3ea5a167eb555bc69ac3323" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "errno" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + [[package]] name = "getrandom" version = "0.3.3" @@ -32,6 +105,12 @@ version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -94,14 +173,34 @@ dependencies = [ "getrandom", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "rustic_mcts" version = "0.1.0" dependencies = [ + "divan", "rand", "thiserror", ] +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + [[package]] name = "syn" version = "2.0.103" @@ -113,6 +212,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "terminal_size" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed" +dependencies = [ + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "thiserror" version = "2.0.12" @@ -148,6 +257,152 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.2", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "wit-bindgen-rt" version = "0.39.0" diff --git a/Cargo.toml b/Cargo.toml index f7bc4df..0b20f3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,10 @@ categories = ["algorithms", "data-structures"] [dependencies] rand = "~0.9" thiserror = "~2.0" + +[dev-dependencies] +divan = "0.1.21" + +[[bench]] +name = "example" +harness = false diff --git a/benches/example.rs b/benches/example.rs new file mode 100644 index 0000000..af951c5 --- /dev/null +++ b/benches/example.rs @@ -0,0 +1,192 @@ +use std::collections::HashMap; + +use rustic_mcts::policy::backprop::BackpropagationPolicy; +use rustic_mcts::policy::decision::DecisionPolicy; +use rustic_mcts::policy::selection::SelectionPolicy; +use rustic_mcts::policy::simulation::SimulationPolicy; +use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS}; + +fn main() { + // Run registered benchmarks. + divan::main(); +} + +#[divan::bench] +fn tic_tac_toe() { + // Set up a new game + let mut game = TicTacToe::new(); + + // Create MCTS configuration + let config = MCTSConfig { + max_iterations: 10_000, + max_time: None, + tree_size_allocation: 10_000, + selection_policy: SelectionPolicy::UCB1Tuned(1.414), + simulation_policy: SimulationPolicy::Random, + backprop_policy: BackpropagationPolicy::Standard, + decision_policy: DecisionPolicy::MostVisits, + }; + + // Main game loop + while !game.is_terminal() { + // Create a new MCTS search + let mut mcts = MCTS::new(game.clone(), &config); + + // Find the best move + match mcts.search() { + Ok(action) => { + // Apply the AI's move + game = game.state_after_action(&action); + } + Err(_) => { + break; + } + } + } +} + +/// Players in Tic-Tac-Toe +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum Player { + X, + O, +} + +impl rustic_mcts::Player for Player {} + +/// Tic-Tac-Toe move +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct Move { + /// Board position index (0-8) + index: usize, +} + +impl Action for Move { + fn id(&self) -> usize { + self.index + } +} + +/// Tic-Tac-Toe game state +#[derive(Debug, Clone)] +struct TicTacToe { + /// Board representation (None = empty, Some(Player) = occupied) + board: [Option; 9], + + /// Current player's turn + current_player: Player, + + /// Number of moves played so far + moves_played: usize, +} + +impl TicTacToe { + /// Creates a new empty Tic-Tac-Toe board + fn new() -> Self { + TicTacToe { + board: [None; 9], + current_player: Player::X, + moves_played: 0, + } + } + + /// Returns the winner of the game, if any + fn get_winner(&self) -> Option { + // Check rows + for row in 0..3 { + let i = row * 3; + if self.board[i].is_some() + && self.board[i] == self.board[i + 1] + && self.board[i] == self.board[i + 2] + { + return self.board[i]; + } + } + + // Check columns + for col in 0..3 { + if self.board[col].is_some() + && self.board[col] == self.board[col + 3] + && self.board[col] == self.board[col + 6] + { + return self.board[col]; + } + } + + // Check diagonals + if self.board[0].is_some() + && self.board[0] == self.board[4] + && self.board[0] == self.board[8] + { + return self.board[0]; + } + if self.board[2].is_some() + && self.board[2] == self.board[4] + && self.board[2] == self.board[6] + { + return self.board[2]; + } + + None + } +} + +impl GameState for TicTacToe { + type Action = Move; + type Player = Player; + + fn get_legal_actions(&self) -> Vec { + let mut actions = Vec::new(); + for i in 0..9 { + if self.board[i].is_none() { + actions.push(Move { index: i }); + } + } + actions + } + + fn state_after_action(&self, action: &Self::Action) -> Self { + let mut new_state = self.clone(); + + // Make the move + new_state.board[action.index] = Some(self.current_player); + new_state.moves_played = self.moves_played + 1; + + // Switch player + new_state.current_player = match self.current_player { + Player::X => Player::O, + Player::O => Player::X, + }; + + new_state + } + + fn is_terminal(&self) -> bool { + self.get_winner().is_some() || self.moves_played == 9 + } + + fn reward_for_player(&self, player: &Self::Player) -> RewardVal { + if let Some(winner) = self.get_winner() { + if winner == *player { + return 1.0; // Win + } else { + return 0.0; // Loss + } + } + + // Draw + 0.5 + } + + fn rewards_for_players(&self) -> HashMap { + HashMap::from_iter(vec![ + (Player::X, self.reward_for_player(&Player::X)), + (Player::O, self.reward_for_player(&Player::O)), + ]) + } + + fn get_current_player(&self) -> &Self::Player { + &self.current_player + } +} +