Create an AI vs AI copy of tic-tac-toe
This commit is contained in:
parent
17884f4b90
commit
3317c29480
255
examples/auto_tic_tac_toe.rs
Normal file
255
examples/auto_tic_tac_toe.rs
Normal file
@ -0,0 +1,255 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
|
||||
use rustic_mcts::policy::backprop::BackpropagationPolicy;
|
||||
use rustic_mcts::policy::decision::DecisionPolicy;
|
||||
use rustic_mcts::policy::selection::SelectionPolicy;
|
||||
use rustic_mcts::policy::simulation::SimulationPolicy;
|
||||
use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS};
|
||||
|
||||
fn main() {
|
||||
println!("MCTS Tic-Tac-Toe Example");
|
||||
println!("========================");
|
||||
println!();
|
||||
|
||||
// Set up a new game
|
||||
let mut game = TicTacToe::new();
|
||||
|
||||
// Create MCTS configuration
|
||||
let config = MCTSConfig {
|
||||
max_iterations: 10_000,
|
||||
max_time: None,
|
||||
tree_size_allocation: 10_000,
|
||||
selection_policy: SelectionPolicy::UCB1Tuned(1.414),
|
||||
simulation_policy: SimulationPolicy::Random,
|
||||
backprop_policy: BackpropagationPolicy::Standard,
|
||||
decision_policy: DecisionPolicy::MostVisits,
|
||||
};
|
||||
|
||||
// Main game loop
|
||||
while !game.is_terminal() {
|
||||
// Display the board
|
||||
println!("{}", game);
|
||||
|
||||
// AI player (O)
|
||||
println!("{:?} is thinking...", game.current_player);
|
||||
|
||||
// Create a new MCTS search
|
||||
let mut mcts = MCTS::new(game.clone(), &config);
|
||||
|
||||
// Find the best move
|
||||
match mcts.search() {
|
||||
Ok(action) => {
|
||||
println!(
|
||||
"AI chooses: {} (row {}, col {})",
|
||||
action.index,
|
||||
action.index / 3,
|
||||
action.index % 3
|
||||
);
|
||||
|
||||
// Apply the AI's move
|
||||
game = game.state_after_action(&action);
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error: {:?}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Display final state
|
||||
println!("{}", game);
|
||||
|
||||
// Report the result
|
||||
if let Some(winner) = game.get_winner() {
|
||||
println!("Player {:?} wins!", winner);
|
||||
} else {
|
||||
println!("The game is a draw!");
|
||||
}
|
||||
}
|
||||
|
||||
/// The two sides of a Tic-Tac-Toe game.
///
/// The type is `Copy` and hashable so it can be stored directly in board
/// cells and used as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Player {
    /// The side playing crosses.
    X,
    /// The side playing noughts.
    O,
}
|
||||
|
||||
// Opts the local `Player` enum into rustic_mcts's `Player` trait. The impl
// body is empty, so no trait items are defined or overridden here.
impl rustic_mcts::Player for Player {}
|
||||
|
||||
/// A single Tic-Tac-Toe move: the cell a player marks.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Move {
    /// Target cell, numbered 0-8 in row-major order
    /// (row = `index / 3`, column = `index % 3`).
    index: usize,
}
|
||||
|
||||
impl Action for Move {
|
||||
fn id(&self) -> usize {
|
||||
self.index
|
||||
}
|
||||
}
|
||||
|
||||
/// Tic-Tac-Toe game state
|
||||
#[derive(Clone)]
|
||||
struct TicTacToe {
|
||||
/// Board representation (None = empty, Some(Player) = occupied)
|
||||
board: [Option<Player>; 9],
|
||||
|
||||
/// Current player's turn
|
||||
current_player: Player,
|
||||
|
||||
/// Number of moves played so far
|
||||
moves_played: usize,
|
||||
}
|
||||
|
||||
impl TicTacToe {
|
||||
/// Creates a new empty Tic-Tac-Toe board
|
||||
fn new() -> Self {
|
||||
TicTacToe {
|
||||
board: [None; 9],
|
||||
current_player: Player::X,
|
||||
moves_played: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the winner of the game, if any
|
||||
fn get_winner(&self) -> Option<Player> {
|
||||
// Check rows
|
||||
for row in 0..3 {
|
||||
let i = row * 3;
|
||||
if self.board[i].is_some()
|
||||
&& self.board[i] == self.board[i + 1]
|
||||
&& self.board[i] == self.board[i + 2]
|
||||
{
|
||||
return self.board[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Check columns
|
||||
for col in 0..3 {
|
||||
if self.board[col].is_some()
|
||||
&& self.board[col] == self.board[col + 3]
|
||||
&& self.board[col] == self.board[col + 6]
|
||||
{
|
||||
return self.board[col];
|
||||
}
|
||||
}
|
||||
|
||||
// Check diagonals
|
||||
if self.board[0].is_some()
|
||||
&& self.board[0] == self.board[4]
|
||||
&& self.board[0] == self.board[8]
|
||||
{
|
||||
return self.board[0];
|
||||
}
|
||||
if self.board[2].is_some()
|
||||
&& self.board[2] == self.board[4]
|
||||
&& self.board[2] == self.board[6]
|
||||
{
|
||||
return self.board[2];
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl GameState for TicTacToe {
|
||||
type Action = Move;
|
||||
type Player = Player;
|
||||
|
||||
fn get_legal_actions(&self) -> Vec<Self::Action> {
|
||||
let mut actions = Vec::new();
|
||||
for i in 0..9 {
|
||||
if self.board[i].is_none() {
|
||||
actions.push(Move { index: i });
|
||||
}
|
||||
}
|
||||
actions
|
||||
}
|
||||
|
||||
fn state_after_action(&self, action: &Self::Action) -> Self {
|
||||
let mut new_state = self.clone();
|
||||
|
||||
// Make the move
|
||||
new_state.board[action.index] = Some(self.current_player);
|
||||
new_state.moves_played = self.moves_played + 1;
|
||||
|
||||
// Switch player
|
||||
new_state.current_player = match self.current_player {
|
||||
Player::X => Player::O,
|
||||
Player::O => Player::X,
|
||||
};
|
||||
|
||||
new_state
|
||||
}
|
||||
|
||||
fn is_terminal(&self) -> bool {
|
||||
self.get_winner().is_some() || self.moves_played == 9
|
||||
}
|
||||
|
||||
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
|
||||
if let Some(winner) = self.get_winner() {
|
||||
if winner == *player {
|
||||
return 1.0; // Win
|
||||
} else {
|
||||
return 0.0; // Loss
|
||||
}
|
||||
}
|
||||
|
||||
// Draw
|
||||
0.5
|
||||
}
|
||||
|
||||
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
||||
HashMap::from_iter(vec![
|
||||
(Player::X, self.reward_for_player(&Player::X)),
|
||||
(Player::O, self.reward_for_player(&Player::O)),
|
||||
])
|
||||
}
|
||||
|
||||
fn get_current_player(&self) -> &Self::Player {
|
||||
&self.current_player
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for TicTacToe {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(f, " 0 1 2")?;
|
||||
for row in 0..3 {
|
||||
write!(f, "{} ", row)?;
|
||||
for col in 0..3 {
|
||||
let index = row * 3 + col;
|
||||
let symbol = match self.board[index] {
|
||||
Some(Player::X) => "X",
|
||||
Some(Player::O) => "O",
|
||||
None => ".",
|
||||
};
|
||||
write!(f, "{} ", symbol)?;
|
||||
}
|
||||
writeln!(f)?;
|
||||
}
|
||||
|
||||
writeln!(f, "\nPlayer {:?}'s turn", self.current_player)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for TicTacToe {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "\n")?;
|
||||
for row in 0..3 {
|
||||
for col in 0..3 {
|
||||
let index = row * 3 + col;
|
||||
let symbol = match self.board[index] {
|
||||
Some(Player::X) => "X",
|
||||
Some(Player::O) => "O",
|
||||
None => ".",
|
||||
};
|
||||
write!(f, "{} ", symbol)?;
|
||||
}
|
||||
writeln!(f)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user