// rustic_mcts/examples/auto_tic_tac_toe.rs
// 2025-06-29 22:51:17 -07:00
//
// 256 lines
// 6.5 KiB
// Rust

use std::collections::HashMap;
use std::fmt;
use rustic_mcts::policy::backprop::BackpropagationPolicy;
use rustic_mcts::policy::decision::DecisionPolicy;
use rustic_mcts::policy::selection::SelectionPolicy;
use rustic_mcts::policy::simulation::SimulationPolicy;
use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS};
/// Runs a self-play Tic-Tac-Toe game in which MCTS picks the move for both
/// X and O, printing the board before every move and the result at the end.
///
/// If a search returns an error, the error is printed to stderr, the game is
/// reported as aborted (not as a draw), and the process exits with status 1.
fn main() {
    println!("MCTS Tic-Tac-Toe Example");
    println!("========================");
    println!();

    let mut game = TicTacToe::new();

    // One configuration shared by every move; a new MCTS instance is created
    // from the current position each turn.
    let config = MCTSConfig {
        max_iterations: 10_000,
        max_time: None,
        tree_size_allocation: 10_000,
        // NOTE(review): presumably the UCB exploration constant (~sqrt(2)) — confirm.
        selection_policy: SelectionPolicy::UCB1Tuned(1.414),
        simulation_policy: SimulationPolicy::Random,
        backprop_policy: BackpropagationPolicy::Standard,
        decision_policy: DecisionPolicy::MostVisits,
    };

    // Self-play loop: both X and O are driven by MCTS.
    while !game.is_terminal() {
        println!("{}", game);
        println!("{:?} is thinking...", game.current_player);

        let mut mcts = MCTS::new(game.clone(), &config);
        match mcts.search() {
            Ok(action) => {
                println!(
                    "AI chooses: {} (row {}, col {})",
                    action.index,
                    action.index / 3,
                    action.index % 3
                );
                game = game.state_after_action(&action);
            }
            Err(e) => {
                // The game is unfinished here, so it must not fall through to
                // the winner/draw report below.
                eprintln!("Error: {:?}", e);
                println!("{}", game);
                println!("Game aborted before completion.");
                std::process::exit(1);
            }
        }
    }

    // The loop only exits normally on a terminal position.
    println!("{}", game);
    match game.get_winner() {
        Some(winner) => println!("Player {:?} wins!", winner),
        None => println!("The game is a draw!"),
    }
}
/// One of the two sides in a Tic-Tac-Toe game.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum Player {
    /// The cross side; a freshly created board has X to move.
    X,
    /// The nought side.
    O,
}
/// Registers `Player` as the MCTS library's player type.
impl rustic_mcts::Player for Player {
    // Intentionally empty. NOTE(review): presumably any supertrait bounds of
    // `rustic_mcts::Player` are satisfied by `Player`'s derives — confirm.
}
/// A single Tic-Tac-Toe move: the square the current player marks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Move {
    /// Row-major square index in `0..9`, i.e. `row * 3 + col`.
    index: usize,
}
impl Action for Move {
fn id(&self) -> usize {
self.index
}
}
/// A complete Tic-Tac-Toe position: board contents, side to move, and the
/// number of moves made so far.
#[derive(Clone)]
struct TicTacToe {
    /// Row-major cells (`row * 3 + col`); `None` is an empty square and
    /// `Some(p)` is a square marked by `p`.
    board: [Option<Player>; 9],
    /// The side whose turn it is.
    current_player: Player,
    /// Marks placed so far; reaching 9 means the board is full.
    moves_played: usize,
}
impl TicTacToe {
    /// Returns an empty board with X to move and no moves played.
    fn new() -> Self {
        Self {
            board: [None; 9],
            current_player: Player::X,
            moves_played: 0,
        }
    }

    /// Returns the player who owns a complete line (row, column, or
    /// diagonal), or `None` if no line is complete.
    ///
    /// Lines are checked in the order rows, columns, main diagonal,
    /// anti-diagonal, and the first complete one decides the result.
    fn get_winner(&self) -> Option<Player> {
        const LINES: [[usize; 3]; 8] = [
            // rows
            [0, 1, 2],
            [3, 4, 5],
            [6, 7, 8],
            // columns
            [0, 3, 6],
            [1, 4, 7],
            [2, 5, 8],
            // diagonals
            [0, 4, 8],
            [2, 4, 6],
        ];

        LINES.iter().find_map(|&[a, b, c]| {
            // An empty first cell can never start a winning line.
            let owner = self.board[a]?;
            if self.board[b] == Some(owner) && self.board[c] == Some(owner) {
                Some(owner)
            } else {
                None
            }
        })
    }
}
impl GameState for TicTacToe {
    type Action = Move;
    type Player = Player;

    /// Returns one move per empty square, in increasing square order.
    ///
    /// A position that already has a winner is terminal (see `is_terminal`),
    /// so it yields no moves even if empty squares remain.
    fn get_legal_actions(&self) -> Vec<Self::Action> {
        if self.get_winner().is_some() {
            return Vec::new();
        }
        self.board
            .iter()
            .enumerate()
            .filter(|(_, cell)| cell.is_none())
            .map(|(index, _)| Move { index })
            .collect()
    }

    /// Returns the position after the current player marks `action.index`.
    /// `self` is not modified.
    ///
    /// # Panics
    ///
    /// Panics if `action.index >= 9`. In debug builds, also panics if the
    /// target square is already occupied (an illegal move).
    fn state_after_action(&self, action: &Self::Action) -> Self {
        debug_assert!(
            self.board[action.index].is_none(),
            "square {} is already occupied",
            action.index
        );
        let mut new_state = self.clone();
        new_state.board[action.index] = Some(self.current_player);
        new_state.moves_played = self.moves_played + 1;
        new_state.current_player = match self.current_player {
            Player::X => Player::O,
            Player::O => Player::X,
        };
        new_state
    }

    /// True once someone has won or all nine squares have been played.
    fn is_terminal(&self) -> bool {
        self.get_winner().is_some() || self.moves_played == 9
    }

    /// Returns 1.0 if `player` has won and 0.0 if the opponent has won.
    /// Returns 0.5 when there is no winner (a draw, or a position that has
    /// no winner yet).
    fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
        match self.get_winner() {
            Some(winner) if winner == *player => 1.0,
            Some(_) => 0.0,
            None => 0.5,
        }
    }

    /// Returns the reward of each side, keyed by player.
    fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
        HashMap::from([
            (Player::X, self.reward_for_player(&Player::X)),
            (Player::O, self.reward_for_player(&Player::O)),
        ])
    }

    /// Returns the side to move.
    fn get_current_player(&self) -> &Self::Player {
        &self.current_player
    }
}
impl fmt::Display for TicTacToe {
    /// Renders the board as a labelled 3x3 grid: column indices across the
    /// top and the row index at the start of each line. A blank line and a
    /// status line follow, giving whose turn it is or `Game over` once the
    /// position is terminal.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Two leading spaces so each column label sits above its cells, since
        // every row line starts with "<row> ".
        writeln!(f, "  0 1 2")?;
        for (row, cells) in self.board.chunks(3).enumerate() {
            write!(f, "{} ", row)?;
            for cell in cells {
                let symbol = match cell {
                    Some(Player::X) => "X",
                    Some(Player::O) => "O",
                    None => ".",
                };
                write!(f, "{} ", symbol)?;
            }
            writeln!(f)?;
        }
        // A finished position has no side to move, so don't claim one.
        if self.is_terminal() {
            writeln!(f, "\nGame over")
        } else {
            writeln!(f, "\nPlayer {:?}'s turn", self.current_player)
        }
    }
}
impl fmt::Debug for TicTacToe {
    /// Compact diagnostic rendering: a leading newline, then the three board
    /// rows, each cell followed by a space. Unlike `Display`, it prints no
    /// coordinate labels and no turn line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `writeln!(f)` instead of `write!(f, "\n")` (clippy::write_with_newline).
        writeln!(f)?;
        for row in self.board.chunks(3) {
            for cell in row {
                let symbol = match cell {
                    Some(Player::X) => "X",
                    Some(Player::O) => "O",
                    None => ".",
                };
                write!(f, "{} ", symbol)?;
            }
            writeln!(f)?;
        }
        Ok(())
    }
}