diff --git a/Cargo.toml b/Cargo.toml index 500991d..b130500 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ rand = "~0.9" thiserror = "~2.0" [dev-dependencies] -divan = "0.1.21" +divan = "~0.1" [[bench]] name = "e2e" diff --git a/src/mcts.rs b/src/mcts.rs index 17ee3cf..fc2b09a 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -75,8 +75,15 @@ impl<'conf, S: GameState + std::fmt::Debug> MCTS<'conf, S> { if !selected_node.state.is_terminal() { self.expand(selected_id); let children: &Vec = &self.arena.get_node(selected_id).children; - let random_child: usize = *children.choose(&mut rand::rng()).unwrap(); - selected_id = random_child; + match children.choose(&mut rand::rng()) { + Some(&random_child) => { + selected_id = random_child; + } + None => { + // We ran out of nodes + return Err(MCTSError::NonTerminalGame); + } + } } let rewards = self.simulate(selected_id); self.backprop(selected_id, &rewards);