Make minmax slighly less ugly
Can't vouch for correctness yet tho
This commit is contained in:
@@ -27,6 +27,12 @@ impl Color {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Move_ {
|
||||
pub source: Position,
|
||||
pub target: Position,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
|
||||
pub struct Position {
|
||||
pub file: File,
|
||||
@@ -381,4 +387,31 @@ impl Board {
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
pub fn get_legal_moves(
|
||||
&self,
|
||||
position: &Position,
|
||||
) -> Result<impl Iterator<Item = Position> + '_, ()> {
|
||||
Ok(self.find_moves(position)?.chain(self.find_captures(position)?))
|
||||
}
|
||||
pub fn all_moves_for_color(
|
||||
&self,
|
||||
color: Color,
|
||||
) -> impl Iterator<Item = Move_> + '_ {
|
||||
self.iter()
|
||||
.filter_map(move |(source, piece)| {
|
||||
if piece.color == color {
|
||||
if let Ok(targets) = self.get_legal_moves(source) {
|
||||
Some(targets.map(|target| Move_ {
|
||||
source: source.clone(),
|
||||
target,
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
}
|
||||
}
|
||||
|
||||
150
rs/src/engine.rs
150
rs/src/engine.rs
@@ -1,5 +1,5 @@
|
||||
use crate::board;
|
||||
use crate::rand::Rng;
|
||||
use crate::board::Move_;
|
||||
|
||||
impl board::PieceType {
|
||||
fn value(&self) -> f32 {
|
||||
@@ -10,10 +10,13 @@ impl board::PieceType {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Move_ {
|
||||
pub source: board::Position,
|
||||
pub target: board::Position,
|
||||
impl board::Color {
|
||||
fn sign(&self) -> f32 {
|
||||
match self {
|
||||
board::Color::Black => -1.0,
|
||||
board::Color::White => 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Engine {
|
||||
@@ -37,121 +40,56 @@ impl Engine {
|
||||
pub fn set_state(&mut self, board: board::Board) {
|
||||
self.board = board;
|
||||
}
|
||||
pub fn get_legal_moves<'a>(
|
||||
board: &'a board::Board,
|
||||
position: &board::Position,
|
||||
) -> Result<impl Iterator<Item = board::Position> + 'a, ()> {
|
||||
Ok(board.find_moves(position)?.chain(board.find_captures(position)?))
|
||||
}
|
||||
pub fn make_move(
|
||||
&mut self,
|
||||
source: &board::Position,
|
||||
target: board::Position,
|
||||
) -> Result<(), ()> {
|
||||
if !Engine::get_legal_moves(&self.board, source)?
|
||||
.any(|pos| pos == target)
|
||||
{
|
||||
if !self.board.get_legal_moves(source)?.any(|pos| pos == target) {
|
||||
Err(())
|
||||
} else {
|
||||
// We checked that there is a piece at source in get_legal_moves
|
||||
self.board.relocate(source, target)
|
||||
}
|
||||
}
|
||||
pub fn evaluate_position(board: &board::Board) -> f32 {
|
||||
board
|
||||
.iter()
|
||||
.map(|(_, piece)| match piece.color {
|
||||
board::Color::White => 1.0,
|
||||
board::Color::Black => -1.0,
|
||||
} * piece.piece_type.value())
|
||||
.sum()
|
||||
}
|
||||
pub fn evaluate_move(
|
||||
move_: &Move_,
|
||||
board: &board::Board,
|
||||
color: &board::Color,
|
||||
) -> Result<f32, ()> {
|
||||
let mut board = board.clone();
|
||||
board.relocate(&move_.source, move_.target.clone())?;
|
||||
Ok(match color {
|
||||
board::Color::White => 1.0,
|
||||
board::Color::Black => -1.0,
|
||||
} * Engine::evaluate_position(&board))
|
||||
}
|
||||
fn minmax(
|
||||
depth: i8,
|
||||
board: &board::Board,
|
||||
color: &board::Color,
|
||||
) -> Option<(f32, Move_)> {
|
||||
eprintln!("Entering minmax for {} at depth {}", color, depth);
|
||||
let all_moves = board
|
||||
.iter()
|
||||
.filter_map(|(source, piece)| {
|
||||
if &piece.color == color {
|
||||
if let Ok(targets) = Engine::get_legal_moves(board, source)
|
||||
{
|
||||
Some(targets.map(|target| Move_ {
|
||||
source: source.clone(),
|
||||
target,
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.flatten();
|
||||
if depth == 0 {
|
||||
Engine::best_immediate_move(board, color, all_moves)
|
||||
} else if depth > 0 {
|
||||
let best_given_opponent = all_moves.map(|move_| {
|
||||
eprintln!("Finding opponent moves for {:?}", move_);
|
||||
let mut board = board.clone();
|
||||
board.relocate(&move_.source, move_.target.clone()).unwrap();
|
||||
let result = if let Some((opponent_score, _)) =
|
||||
Engine::minmax(depth - 1, &board, &color.other())
|
||||
{
|
||||
(-opponent_score, move_)
|
||||
} else {
|
||||
(0.0, move_)
|
||||
};
|
||||
eprintln!(
|
||||
"Move {:?} got opponent minmax score {} for {:?}",
|
||||
result.1, result.0, color
|
||||
);
|
||||
result
|
||||
});
|
||||
best_given_opponent
|
||||
.max_by(|(s1, _), (s2, _)| s1.partial_cmp(s2).unwrap())
|
||||
} else {
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
fn best_immediate_move(
|
||||
board: &board::Board,
|
||||
color: &board::Color,
|
||||
moves: impl Iterator<Item = Move_>,
|
||||
) -> Option<(f32, Move_)> {
|
||||
let mut rng = rand::thread_rng();
|
||||
pub fn choose_move(&self, color: &board::Color) -> Option<Move_> {
|
||||
Some(
|
||||
moves
|
||||
.map(|move_| {
|
||||
let score = Engine::evaluate_move(&move_, board, color)
|
||||
.unwrap()
|
||||
+ rng.gen::<f32>() * 0.05;
|
||||
eprintln!(
|
||||
"Move {:?} got immediate score {} for {:?}",
|
||||
move_, score, color
|
||||
);
|
||||
(score, move_)
|
||||
self.board
|
||||
.all_moves_for_color(color.clone())
|
||||
.map(|m| {
|
||||
let mut board = self.board.clone();
|
||||
board.relocate(&m.source, m.target.clone()).unwrap();
|
||||
(minmax(1, &board, color), m)
|
||||
})
|
||||
.max_by(|(score1, _), (score2, _)| {
|
||||
score1.partial_cmp(score2).unwrap()
|
||||
})?,
|
||||
.max_by(|(s1, _), (s2, _)| s1.partial_cmp(s2).unwrap())?
|
||||
.1,
|
||||
)
|
||||
}
|
||||
pub fn choose_move(&self, color: &board::Color) -> Option<Move_> {
|
||||
Some(Engine::minmax(1, &self.board, color)?.1)
|
||||
}
|
||||
|
||||
pub fn evaluate_position(board: &board::Board) -> f32 {
|
||||
board
|
||||
.iter()
|
||||
.map(|(_, piece)| piece.color.sign() * piece.piece_type.value())
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn minmax(depth: i8, board: &board::Board, color: &board::Color) -> f32 {
|
||||
if depth == 0 {
|
||||
evaluate_position(board) * color.sign()
|
||||
} else {
|
||||
let best_opponent_move_score = board
|
||||
.all_moves_for_color(color.other())
|
||||
.map(|m| {
|
||||
let mut board = board.clone();
|
||||
board.relocate(&m.source, m.target).unwrap();
|
||||
minmax(depth - 1, &board, &color.other())
|
||||
})
|
||||
.max_by(|s1, s2| s1.partial_cmp(s2).unwrap());
|
||||
if let Some(s) = best_opponent_move_score {
|
||||
-s
|
||||
} else {
|
||||
minmax(0, board, color)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,10 +219,7 @@ impl Ui {
|
||||
board::Position::parse(position_str).map_err(|_| {
|
||||
format!("Error parsing position {}", position_str)
|
||||
})?;
|
||||
match engine::Engine::get_legal_moves(
|
||||
&self.engine.board,
|
||||
&position,
|
||||
) {
|
||||
match self.engine.board.get_legal_moves(&position) {
|
||||
Err(_) => {
|
||||
let error = format!("No moves possible from {}", position);
|
||||
Err(error)
|
||||
|
||||
Reference in New Issue
Block a user