256 lines
6.5 KiB
Rust
256 lines
6.5 KiB
Rust
use std::collections::HashMap;
|
|
use std::fmt;
|
|
|
|
use rustic_mcts::policy::backprop::BackpropagationPolicy;
|
|
use rustic_mcts::policy::decision::DecisionPolicy;
|
|
use rustic_mcts::policy::selection::SelectionPolicy;
|
|
use rustic_mcts::policy::simulation::SimulationPolicy;
|
|
use rustic_mcts::{Action, GameState, MCTSConfig, RewardVal, MCTS};
|
|
|
|
fn main() {
|
|
println!("MCTS Tic-Tac-Toe Example");
|
|
println!("========================");
|
|
println!();
|
|
|
|
// Set up a new game
|
|
let mut game = TicTacToe::new();
|
|
|
|
// Create MCTS configuration
|
|
let config = MCTSConfig {
|
|
max_iterations: 10_000,
|
|
max_time: None,
|
|
tree_size_allocation: 10_000,
|
|
selection_policy: SelectionPolicy::UCB1Tuned(1.414),
|
|
simulation_policy: SimulationPolicy::Random,
|
|
backprop_policy: BackpropagationPolicy::Standard,
|
|
decision_policy: DecisionPolicy::MostVisits,
|
|
};
|
|
|
|
// Main game loop
|
|
while !game.is_terminal() {
|
|
// Display the board
|
|
println!("{}", game);
|
|
|
|
// AI player (O)
|
|
println!("{:?} is thinking...", game.current_player);
|
|
|
|
// Create a new MCTS search
|
|
let mut mcts = MCTS::new(game.clone(), &config);
|
|
|
|
// Find the best move
|
|
match mcts.search() {
|
|
Ok(action) => {
|
|
println!(
|
|
"AI chooses: {} (row {}, col {})",
|
|
action.index,
|
|
action.index / 3,
|
|
action.index % 3
|
|
);
|
|
|
|
// Apply the AI's move
|
|
game = game.state_after_action(&action);
|
|
}
|
|
Err(e) => {
|
|
println!("Error: {:?}", e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Display final state
|
|
println!("{}", game);
|
|
|
|
// Report the result
|
|
if let Some(winner) = game.get_winner() {
|
|
println!("Player {:?} wins!", winner);
|
|
} else {
|
|
println!("The game is a draw!");
|
|
}
|
|
}
|
|
|
|
/// One of the two sides in a Tic-Tac-Toe game.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum Player {
    /// The side whose marks are drawn as `X`.
    X,
    /// The side whose marks are drawn as `O`.
    O,
}
|
|
|
|
// Lets `Player` be used as the player type of `TicTacToe`'s `GameState` impl.
// The body is empty, so rustic_mcts's `Player` trait has no items that must be
// implemented here.
impl rustic_mcts::Player for Player {}
|
|
|
|
/// A single Tic-Tac-Toe move: the square that the side to move will mark.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Move {
    /// Square index in row-major order (`row * 3 + col`), expected in `0..9`.
    index: usize,
}
|
|
|
|
impl Action for Move {
|
|
fn id(&self) -> usize {
|
|
self.index
|
|
}
|
|
}
|
|
|
|
/// Tic-Tac-Toe game state
|
|
#[derive(Clone)]
|
|
struct TicTacToe {
|
|
/// Board representation (None = empty, Some(Player) = occupied)
|
|
board: [Option<Player>; 9],
|
|
|
|
/// Current player's turn
|
|
current_player: Player,
|
|
|
|
/// Number of moves played so far
|
|
moves_played: usize,
|
|
}
|
|
|
|
impl TicTacToe {
|
|
/// Creates a new empty Tic-Tac-Toe board
|
|
fn new() -> Self {
|
|
TicTacToe {
|
|
board: [None; 9],
|
|
current_player: Player::X,
|
|
moves_played: 0,
|
|
}
|
|
}
|
|
|
|
/// Returns the winner of the game, if any
|
|
fn get_winner(&self) -> Option<Player> {
|
|
// Check rows
|
|
for row in 0..3 {
|
|
let i = row * 3;
|
|
if self.board[i].is_some()
|
|
&& self.board[i] == self.board[i + 1]
|
|
&& self.board[i] == self.board[i + 2]
|
|
{
|
|
return self.board[i];
|
|
}
|
|
}
|
|
|
|
// Check columns
|
|
for col in 0..3 {
|
|
if self.board[col].is_some()
|
|
&& self.board[col] == self.board[col + 3]
|
|
&& self.board[col] == self.board[col + 6]
|
|
{
|
|
return self.board[col];
|
|
}
|
|
}
|
|
|
|
// Check diagonals
|
|
if self.board[0].is_some()
|
|
&& self.board[0] == self.board[4]
|
|
&& self.board[0] == self.board[8]
|
|
{
|
|
return self.board[0];
|
|
}
|
|
if self.board[2].is_some()
|
|
&& self.board[2] == self.board[4]
|
|
&& self.board[2] == self.board[6]
|
|
{
|
|
return self.board[2];
|
|
}
|
|
|
|
None
|
|
}
|
|
}
|
|
|
|
impl GameState for TicTacToe {
|
|
type Action = Move;
|
|
type Player = Player;
|
|
|
|
fn get_legal_actions(&self) -> Vec<Self::Action> {
|
|
let mut actions = Vec::new();
|
|
for i in 0..9 {
|
|
if self.board[i].is_none() {
|
|
actions.push(Move { index: i });
|
|
}
|
|
}
|
|
actions
|
|
}
|
|
|
|
fn state_after_action(&self, action: &Self::Action) -> Self {
|
|
let mut new_state = self.clone();
|
|
|
|
// Make the move
|
|
new_state.board[action.index] = Some(self.current_player);
|
|
new_state.moves_played = self.moves_played + 1;
|
|
|
|
// Switch player
|
|
new_state.current_player = match self.current_player {
|
|
Player::X => Player::O,
|
|
Player::O => Player::X,
|
|
};
|
|
|
|
new_state
|
|
}
|
|
|
|
fn is_terminal(&self) -> bool {
|
|
self.get_winner().is_some() || self.moves_played == 9
|
|
}
|
|
|
|
fn reward_for_player(&self, player: &Self::Player) -> RewardVal {
|
|
if let Some(winner) = self.get_winner() {
|
|
if winner == *player {
|
|
return 1.0; // Win
|
|
} else {
|
|
return 0.0; // Loss
|
|
}
|
|
}
|
|
|
|
// Draw
|
|
0.5
|
|
}
|
|
|
|
fn rewards_for_players(&self) -> HashMap<Self::Player, RewardVal> {
|
|
HashMap::from_iter(vec![
|
|
(Player::X, self.reward_for_player(&Player::X)),
|
|
(Player::O, self.reward_for_player(&Player::O)),
|
|
])
|
|
}
|
|
|
|
fn get_current_player(&self) -> &Self::Player {
|
|
&self.current_player
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for TicTacToe {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
writeln!(f, " 0 1 2")?;
|
|
for row in 0..3 {
|
|
write!(f, "{} ", row)?;
|
|
for col in 0..3 {
|
|
let index = row * 3 + col;
|
|
let symbol = match self.board[index] {
|
|
Some(Player::X) => "X",
|
|
Some(Player::O) => "O",
|
|
None => ".",
|
|
};
|
|
write!(f, "{} ", symbol)?;
|
|
}
|
|
writeln!(f)?;
|
|
}
|
|
|
|
writeln!(f, "\nPlayer {:?}'s turn", self.current_player)?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl fmt::Debug for TicTacToe {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
write!(f, "\n")?;
|
|
for row in 0..3 {
|
|
for col in 0..3 {
|
|
let index = row * 3 + col;
|
|
let symbol = match self.board[index] {
|
|
Some(Player::X) => "X",
|
|
Some(Player::O) => "O",
|
|
None => ".",
|
|
};
|
|
write!(f, "{} ", symbol)?;
|
|
}
|
|
writeln!(f)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|