From 3317c29480d05684f3e18a68042d7fbf09226645 Mon Sep 17 00:00:00 2001 From: David Kruger Date: Fri, 27 Jun 2025 13:45:49 -0700 Subject: [PATCH] Create a AI vs AI copy of tic-tac-toe --- examples/auto_tic_tac_toe.rs | 255 +++++++++++++++++++++++++++++++++++ 1 file changed, 255 insertions(+) create mode 100644 examples/auto_tic_tac_toe.rs diff --git a/examples/auto_tic_tac_toe.rs b/examples/auto_tic_tac_toe.rs new file mode 100644 index 0000000..b6a5387 --- /dev/null +++ b/examples/auto_tic_tac_toe.rs @@ -0,0 +1,255 @@ +use std::collections::HashMap; +use std::fmt; + +use rustic_mcts::policy::backprop::BackpropagationPolicy; +use rustic_mcts::policy::decision::DecisionPolicy; +use rustic_mcts::policy::selection::SelectionPolicy; +use rustic_mcts::policy::simulation::SimulationPolicy; +use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS}; + +fn main() { + println!("MCTS Tic-Tac-Toe Example"); + println!("========================"); + println!(); + + // Set up a new game + let mut game = TicTacToe::new(); + + // Create MCTS configuration + let config = MCTSConfig { + max_iterations: 10_000, + max_time: None, + tree_size_allocation: 10_000, + selection_policy: SelectionPolicy::UCB1Tuned(1.414), + simulation_policy: SimulationPolicy::Random, + backprop_policy: BackpropagationPolicy::Standard, + decision_policy: DecisionPolicy::MostVisits, + }; + + // Main game loop + while !game.is_terminal() { + // Display the board + println!("{}", game); + + // AI player (O) + println!("{:?} is thinking...", game.current_player); + + // Create a new MCTS search + let mut mcts = MCTS::new(game.clone(), &config); + + // Find the best move + match mcts.search() { + Ok(action) => { + println!( + "AI chooses: {} (row {}, col {})", + action.index, + action.index / 3, + action.index % 3 + ); + + // Apply the AI's move + game = game.state_after_action(&action); + } + Err(e) => { + println!("Error: {:?}", e); + break; + } + } + } + + // Display final state + println!("{}", game); + + // Report the result + if let Some(winner) = game.get_winner() { + println!("Player {:?} wins!", winner); + } else { + println!("The game is a draw!"); + } +} + +/// Players in Tic-Tac-Toe +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum Player { + X, + O, +} + +impl rustic_mcts::Player for Player {} + +/// Tic-Tac-Toe move +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct Move { + /// Board position index (0-8) + index: usize, +} + +impl Action for Move { + fn id(&self) -> usize { + self.index + } +} + +/// Tic-Tac-Toe game state +#[derive(Clone)] +struct TicTacToe { + /// Board representation (None = empty, Some(Player) = occupied) + board: [Option; 9], + + /// Current player's turn + current_player: Player, + + /// Number of moves played so far + moves_played: usize, +} + +impl TicTacToe { + /// Creates a new empty Tic-Tac-Toe board + fn new() -> Self { + TicTacToe { + board: [None; 9], + current_player: Player::X, + moves_played: 0, + } + } + + /// Returns the winner of the game, if any + fn get_winner(&self) -> Option { + // Check rows + for row in 0..3 { + let i = row * 3; + if self.board[i].is_some() + && self.board[i] == self.board[i + 1] + && self.board[i] == self.board[i + 2] + { + return self.board[i]; + } + } + + // Check columns + for col in 0..3 { + if self.board[col].is_some() + && self.board[col] == self.board[col + 3] + && self.board[col] == self.board[col + 6] + { + return self.board[col]; + } + } + + // Check diagonals + if self.board[0].is_some() + && self.board[0] == self.board[4] + && self.board[0] == self.board[8] + { + return self.board[0]; + } + if self.board[2].is_some() + && self.board[2] == self.board[4] + && self.board[2] == self.board[6] + { + return self.board[2]; + } + + None + } +} + +impl GameState for TicTacToe { + type Action = Move; + type Player = Player; + + fn get_legal_actions(&self) -> Vec { + let mut actions = Vec::new(); + for i in 0..9 { + if self.board[i].is_none() { + actions.push(Move { index: i }); + } + } + actions + } + + fn state_after_action(&self, action: &Self::Action) -> Self { + let mut new_state = self.clone(); + + // Make the move + new_state.board[action.index] = Some(self.current_player); + new_state.moves_played = self.moves_played + 1; + + // Switch player + new_state.current_player = match self.current_player { + Player::X => Player::O, + Player::O => Player::X, + }; + + new_state + } + + fn is_terminal(&self) -> bool { + self.get_winner().is_some() || self.moves_played == 9 + } + + fn reward_for_player(&self, player: &Self::Player) -> RewardVal { + if let Some(winner) = self.get_winner() { + if winner == *player { + return 1.0; // Win + } else { + return 0.0; // Loss + } + } + + // Draw + 0.5 + } + + fn rewards_for_players(&self) -> HashMap { + HashMap::from_iter(vec![ + (Player::X, self.reward_for_player(&Player::X)), + (Player::O, self.reward_for_player(&Player::O)), + ]) + } + + fn get_current_player(&self) -> &Self::Player { + &self.current_player + } +} + +impl fmt::Display for TicTacToe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, " 0 1 2")?; + for row in 0..3 { + write!(f, "{} ", row)?; + for col in 0..3 { + let index = row * 3 + col; + let symbol = match self.board[index] { + Some(Player::X) => "X", + Some(Player::O) => "O", + None => ".", + }; + write!(f, "{} ", symbol)?; + } + writeln!(f)?; + } + + writeln!(f, "\nPlayer {:?}'s turn", self.current_player)?; + Ok(()) + } +} + +impl fmt::Debug for TicTacToe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "\n")?; + for row in 0..3 { + for col in 0..3 { + let index = row * 3 + col; + let symbol = match self.board[index] { + Some(Player::X) => "X", + Some(Player::O) => "O", + None => ".", + }; + write!(f, "{} ", symbol)?; + } + writeln!(f)?; + } + Ok(()) + } +}