feat: add ai castling (Fixes #4)
All checks were successful
Linux arm64 / Build (push) Successful in 41s

This commit is contained in:
2025-05-06 17:15:55 +02:00
parent 58f02f681c
commit 810a0f2159
12 changed files with 252 additions and 150 deletions

View File

@@ -2,20 +2,15 @@ package chess.ai;
import java.util.List; import java.util.List;
import chess.controller.Command; import chess.ai.actions.AIAction;
import chess.ai.actions.AIActions;
import chess.controller.CommandExecutor; import chess.controller.CommandExecutor;
import chess.controller.Command.CommandResult;
import chess.controller.commands.GetPieceAtCommand;
import chess.controller.commands.GetPlayerMovesCommand;
import chess.controller.commands.GetAllowedCastlingsCommand;
import chess.controller.commands.GetAllowedCastlingsCommand.CastlingResult;
import chess.controller.event.GameAdapter; import chess.controller.event.GameAdapter;
import chess.model.Color; import chess.model.Color;
import chess.model.Coordinate; import chess.model.Coordinate;
import chess.model.Move;
import chess.model.Piece; import chess.model.Piece;
public abstract class AI extends GameAdapter{ public abstract class AI extends GameAdapter {
protected final CommandExecutor commandExecutor; protected final CommandExecutor commandExecutor;
protected final Color color; protected final Color color;
@@ -26,7 +21,6 @@ public abstract class AI extends GameAdapter{
} }
protected abstract void play(); protected abstract void play();
protected abstract void promote(Coordinate pawnCoords);
@Override @Override
public void onPlayerTurn(Color color, boolean undone) { public void onPlayerTurn(Color color, boolean undone) {
@@ -36,44 +30,12 @@ public abstract class AI extends GameAdapter{
play(); play();
} }
@Override protected List<AIAction> getAllowedActions() {
public void onPromotePawn(Coordinate pieceCoords) { return AIActions.getAllowedActions(this.commandExecutor);
Piece pawn = pieceAt(pieceCoords);
if (pawn.getColor() != this.color)
return;
promote(pieceCoords);
} }
protected Piece pieceAt(Coordinate coordinate) { protected Piece pieceAt(Coordinate coordinate) {
GetPieceAtCommand command = new GetPieceAtCommand(coordinate); return AIActions.pieceAt(coordinate, this.commandExecutor);
sendCommand(command);
return command.getPiece();
}
protected List<Move> getAllowedMoves() {
return getAllowedMoves(this.commandExecutor);
}
protected List<Move> getAllowedMoves(CommandExecutor commandExecutor) {
GetPlayerMovesCommand cmd = new GetPlayerMovesCommand();
sendCommand(cmd, commandExecutor);
return cmd.getMoves();
}
protected CastlingResult getAllowedCastlings() {
GetAllowedCastlingsCommand cmd2 = new GetAllowedCastlingsCommand();
sendCommand(cmd2);
return cmd2.getCastlingResult();
}
protected CommandResult sendCommand(Command command) {
return sendCommand(command, this.commandExecutor);
}
protected CommandResult sendCommand(Command command, CommandExecutor commandExecutor) {
CommandResult result = commandExecutor.executeCommand(command);
assert result != CommandResult.NotAllowed : "Command not allowed!";
return result;
} }
} }

View File

@@ -3,15 +3,9 @@ package chess.ai;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import chess.ai.actions.AIAction;
import chess.controller.CommandExecutor; import chess.controller.CommandExecutor;
import chess.controller.commands.CastlingCommand;
import chess.controller.commands.GetAllowedCastlingsCommand.CastlingResult;
import chess.controller.commands.MoveCommand;
import chess.controller.commands.PromoteCommand;
import chess.controller.commands.PromoteCommand.PromoteType;
import chess.model.Color; import chess.model.Color;
import chess.model.Coordinate;
import chess.model.Move;
public class DumbAI extends AI { public class DumbAI extends AI {
@@ -23,39 +17,11 @@ public class DumbAI extends AI {
@Override @Override
protected void play() { protected void play() {
CastlingResult castlings = getAllowedCastlings(); List<AIAction> actions = getAllowedActions();
List<Move> moves = getAllowedMoves();
switch (castlings) { int randomAction = this.random.nextInt(actions.size());
case Both: {
int randomMove = this.random.nextInt(moves.size() + 2);
if (randomMove < moves.size() - 2)
break;
sendCommand(new CastlingCommand(randomMove == moves.size()));
return;
}
case Small: actions.get(randomAction).applyAction();
case Big: {
int randomMove = this.random.nextInt(moves.size() + 1);
if (randomMove != moves.size())
break;
sendCommand(new CastlingCommand(castlings == CastlingResult.Big));
return;
}
default:
break;
}
int randomMove = this.random.nextInt(moves.size());
sendCommand(new MoveCommand(moves.get(randomMove)));
}
@Override
protected void promote(Coordinate pawnCoords) {
int promote = this.random.nextInt(PromoteType.values().length);
sendCommand(new PromoteCommand(PromoteType.values()[promote]));
} }
} }

View File

@@ -4,12 +4,10 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import chess.ai.actions.AIAction;
import chess.ai.actions.AIActionMove;
import chess.controller.CommandExecutor; import chess.controller.CommandExecutor;
import chess.controller.commands.MoveCommand;
import chess.controller.commands.PromoteCommand;
import chess.controller.commands.PromoteCommand.PromoteType;
import chess.model.Color; import chess.model.Color;
import chess.model.Coordinate;
import chess.model.Move; import chess.model.Move;
import chess.model.Piece; import chess.model.Piece;
@@ -26,21 +24,23 @@ public class HungryAI extends AI {
private int getMoveCost(Move move) { private int getMoveCost(Move move) {
Piece piece = pieceAt(move.getDeadPieceCoords()); Piece piece = pieceAt(move.getDeadPieceCoords());
return - (int) pieceCost.getCost(piece); return -(int) pieceCost.getCost(piece);
} }
private List<Move> getBestMoves() { private List<AIAction> getBestMoves() {
List<Move> moves = getAllowedMoves(); List<AIAction> actions = getAllowedActions();
List<Move> bestMoves = new ArrayList<>(); List<AIAction> bestMoves = new ArrayList<>();
int bestCost = 0; int bestCost = 0;
for (Move move : moves) { for (AIAction action : actions) {
int moveCost = getMoveCost(move); if (action instanceof AIActionMove move) {
if (moveCost == bestCost) { int moveCost = getMoveCost(move.getMove());
bestMoves.add(move); if (moveCost == bestCost) {
} else if (moveCost > bestCost) { bestMoves.add(move);
bestMoves.clear(); } else if (moveCost > bestCost) {
bestMoves.add(move); bestMoves.clear();
bestCost = moveCost; bestMoves.add(move);
bestCost = moveCost;
}
} }
} }
return bestMoves; return bestMoves;
@@ -48,14 +48,7 @@ public class HungryAI extends AI {
@Override @Override
protected void play() { protected void play() {
List<Move> bestMoves = getBestMoves(); List<AIAction> bestMoves = getBestMoves();
int randomMove = this.random.nextInt(bestMoves.size()); bestMoves.get(this.random.nextInt(bestMoves.size())).applyAction();
this.commandExecutor.executeCommand(new MoveCommand(bestMoves.get(randomMove)));
} }
@Override
protected void promote(Coordinate pawnCoords) {
sendCommand(new PromoteCommand(PromoteType.Queen));
}
} }

View File

@@ -0,0 +1,33 @@
package chess.ai.actions;
import chess.controller.Command;
import chess.controller.CommandExecutor;
import chess.controller.Command.CommandResult;
import chess.controller.commands.UndoCommand;
public abstract class AIAction {
private final CommandExecutor commandExecutor;
public AIAction(CommandExecutor commandExecutor) {
this.commandExecutor = commandExecutor;
}
protected CommandResult sendCommand(Command cmd, CommandExecutor commandExecutor) {
return commandExecutor.executeCommand(cmd);
}
public void undoAction(CommandExecutor commandExecutor) {
sendCommand(new UndoCommand(), commandExecutor);
}
public void undoAction() {
undoAction(this.commandExecutor);
}
public void applyAction() {
applyAction(this.commandExecutor);
}
public abstract void applyAction(CommandExecutor commandExecutor);
}

View File

@@ -0,0 +1,20 @@
package chess.ai.actions;
import chess.controller.CommandExecutor;
import chess.controller.commands.CastlingCommand;
public class AIActionCastling extends AIAction{
private final boolean bigCastling;
public AIActionCastling(CommandExecutor commandExecutor, boolean bigCastling) {
super(commandExecutor);
this.bigCastling = bigCastling;
}
@Override
public void applyAction(CommandExecutor commandExecutor) {
sendCommand(new CastlingCommand(this.bigCastling), commandExecutor);
}
}

View File

@@ -0,0 +1,25 @@
package chess.ai.actions;
import chess.controller.CommandExecutor;
import chess.controller.commands.MoveCommand;
import chess.model.Move;
public class AIActionMove extends AIAction{
private final Move move;
public AIActionMove(CommandExecutor commandExecutor, Move move) {
super(commandExecutor);
this.move = move;
}
public Move getMove() {
return move;
}
@Override
public void applyAction(CommandExecutor commandExecutor) {
sendCommand(new MoveCommand(move), commandExecutor);
}
}

View File

@@ -0,0 +1,26 @@
package chess.ai.actions;
import chess.controller.CommandExecutor;
import chess.controller.commands.MoveCommand;
import chess.controller.commands.PromoteCommand;
import chess.controller.commands.PromoteCommand.PromoteType;
import chess.model.Move;
public class AIActionMoveAndPromote extends AIAction{
private final Move move;
private final PromoteType promoteType;
public AIActionMoveAndPromote(CommandExecutor commandExecutor, Move move, PromoteType promoteType) {
super(commandExecutor);
this.move = move;
this.promoteType = promoteType;
}
@Override
public void applyAction(CommandExecutor commandExecutor) {
sendCommand(new MoveCommand(move), commandExecutor);
sendCommand(new PromoteCommand(promoteType), commandExecutor);
}
}

View File

@@ -0,0 +1,87 @@
package chess.ai.actions;
import java.util.ArrayList;
import java.util.List;
import chess.controller.Command;
import chess.controller.Command.CommandResult;
import chess.controller.CommandExecutor;
import chess.controller.commands.GetAllowedCastlingsCommand;
import chess.controller.commands.GetAllowedCastlingsCommand.CastlingResult;
import chess.controller.commands.PromoteCommand.PromoteType;
import chess.controller.commands.GetPieceAtCommand;
import chess.controller.commands.GetPlayerMovesCommand;
import chess.model.Color;
import chess.model.Coordinate;
import chess.model.Move;
import chess.model.Piece;
import chess.model.pieces.Pawn;
public class AIActions {
public static List<AIAction> getAllowedActions(CommandExecutor commandExecutor) {
List<Move> moves = getAllowedMoves(commandExecutor);
CastlingResult castlingResult = getAllowedCastlings(commandExecutor);
List<AIAction> actions = new ArrayList<>(moves.size() + 10);
for (Move move : moves) {
Piece movingPiece = pieceAt(move.getStart(), commandExecutor);
if (movingPiece instanceof Pawn) {
int enemyLineY = movingPiece.getColor() == Color.White ? 0 : 7;
if (move.getFinish().getY() == enemyLineY) {
PromoteType[] promotes = PromoteType.values();
for (PromoteType promote : promotes) {
actions.add(new AIActionMoveAndPromote(commandExecutor, move, promote));
}
continue;
}
}
actions.add(new AIActionMove(commandExecutor, move));
}
switch (castlingResult) {
case Both:
actions.add(new AIActionCastling(commandExecutor, true));
actions.add(new AIActionCastling(commandExecutor, false));
break;
case Small:
actions.add(new AIActionCastling(commandExecutor, false));
break;
case Big:
actions.add(new AIActionCastling(commandExecutor, true));
break;
case None:
break;
}
return actions;
}
private static CastlingResult getAllowedCastlings(CommandExecutor commandExecutor) {
GetAllowedCastlingsCommand cmd2 = new GetAllowedCastlingsCommand();
sendCommand(cmd2, commandExecutor);
return cmd2.getCastlingResult();
}
private static List<Move> getAllowedMoves(CommandExecutor commandExecutor) {
GetPlayerMovesCommand cmd = new GetPlayerMovesCommand();
sendCommand(cmd, commandExecutor);
return cmd.getMoves();
}
private static CommandResult sendCommand(Command command, CommandExecutor commandExecutor) {
CommandResult result = commandExecutor.executeCommand(command);
assert result != CommandResult.NotAllowed : "Command not allowed!";
return result;
}
public static Piece pieceAt(Coordinate coordinate, CommandExecutor commandExecutor) {
GetPieceAtCommand command = new GetPieceAtCommand(coordinate);
commandExecutor.executeCommand(command);
return command.getPiece();
}
}

View File

@@ -8,13 +8,9 @@ import java.util.concurrent.Executors;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import chess.ai.AI; import chess.ai.AI;
import chess.ai.actions.AIAction;
import chess.controller.CommandExecutor; import chess.controller.CommandExecutor;
import chess.controller.commands.MoveCommand;
import chess.controller.commands.PromoteCommand;
import chess.controller.commands.PromoteCommand.PromoteType;
import chess.model.Color; import chess.model.Color;
import chess.model.Coordinate;
import chess.model.Move;
import common.Signal1; import common.Signal1;
public class AlphaBetaAI extends AI { public class AlphaBetaAI extends AI {
@@ -38,23 +34,23 @@ public class AlphaBetaAI extends AI {
new AlphaBetaThreadCreator(commandExecutor, color, threadCount)); new AlphaBetaThreadCreator(commandExecutor, color, threadCount));
} }
private Move getBestMove() { private AIAction getBestMove() {
List<Move> moves = getAllowedMoves(); List<AIAction> actions = getAllowedActions();
List<Future<Float>> moveEvaluations = new ArrayList<>(50); List<Future<Float>> moveEvaluations = new ArrayList<>(actions.size());
float bestMoveValue = MIN_FLOAT; float bestMoveValue = MIN_FLOAT;
Move bestMove = null; AIAction bestMove = null;
this.onStartEval.emit(moves.size()); this.onStartEval.emit(actions.size());
for (Move move : moves) { for (AIAction action : actions) {
moveEvaluations.add(this.threadPool.submit(() -> { moveEvaluations.add(this.threadPool.submit(() -> {
return AlphaBetaThreadCreator.getMoveValue(move, this.searchDepth); return AlphaBetaThreadCreator.getMoveValue(action, this.searchDepth);
})); }));
} }
for (int i = 0; i < moves.size(); i++) { for (int i = 0; i < actions.size(); i++) {
this.onProgress.emit((float) i / (float) moves.size()); this.onProgress.emit((float) i / (float) actions.size());
Move move = moves.get(i); AIAction action = actions.get(i);
float value = MIN_FLOAT; float value = MIN_FLOAT;
try { try {
@@ -64,7 +60,7 @@ public class AlphaBetaAI extends AI {
} }
if (value > bestMoveValue) { if (value > bestMoveValue) {
bestMoveValue = value; bestMoveValue = value;
bestMove = move; bestMove = action;
} }
} }
@@ -80,13 +76,8 @@ public class AlphaBetaAI extends AI {
@Override @Override
protected void play() { protected void play() {
Move move = getBestMove(); AIAction move = getBestMove();
sendCommand(new MoveCommand(move)); move.applyAction();
}
@Override
protected void promote(Coordinate pawnCoords) {
sendCommand(new PromoteCommand(PromoteType.Queen));
} }
} }

View File

@@ -8,10 +8,10 @@ import java.util.Map.Entry;
import chess.ai.PieceCost; import chess.ai.PieceCost;
import chess.ai.PiecePosCost; import chess.ai.PiecePosCost;
import chess.ai.actions.AIAction;
import chess.model.ChessBoard; import chess.model.ChessBoard;
import chess.model.Color; import chess.model.Color;
import chess.model.Coordinate; import chess.model.Coordinate;
import chess.model.Move;
import chess.model.Piece; import chess.model.Piece;
public class AlphaBetaThread extends Thread { public class AlphaBetaThread extends Thread {
@@ -52,27 +52,27 @@ public class AlphaBetaThread extends Thread {
return result; return result;
} }
public float getMoveValue(Move move, int searchDepth) { public float getMoveValue(AIAction move, int searchDepth) {
this.simulation.tryMove(move); move.applyAction(this.simulation.getCommandExecutor());
float value = -negaMax(searchDepth - 1, MIN_FLOAT, MAX_FLOAT); float value = -negaMax(searchDepth - 1, MIN_FLOAT, MAX_FLOAT);
this.simulation.undoMove(); move.undoAction(this.simulation.getCommandExecutor());
return value; return value;
} }
private float negaMax(int depth, float alpha, float beta) { private float negaMax(int depth, float alpha, float beta) {
float value = MIN_FLOAT; float value = MIN_FLOAT;
List<Move> moves = this.simulation.getAllowedMoves(); List<AIAction> moves = this.simulation.getAllowedActions();
if (moves.isEmpty()) if (moves.isEmpty())
return -getEndGameEvaluation(); return -getEndGameEvaluation();
List<Entry<Move, Float>> movesCost = new ArrayList<>(moves.size()); List<Entry<AIAction, Float>> movesCost = new ArrayList<>(moves.size());
for (Move move : moves) { for (AIAction move : moves) {
this.simulation.tryMove(move); move.applyAction();
movesCost.add(Map.entry(move, -getBoardEvaluation())); movesCost.add(Map.entry(move, -getBoardEvaluation()));
this.simulation.undoMove(); move.undoAction();
} }
Collections.sort(movesCost, (first, second) -> { Collections.sort(movesCost, (first, second) -> {
@@ -83,10 +83,10 @@ public class AlphaBetaThread extends Thread {
return -movesCost.getFirst().getValue(); return -movesCost.getFirst().getValue();
for (var moveEntry : movesCost) { for (var moveEntry : movesCost) {
Move move = moveEntry.getKey(); AIAction move = moveEntry.getKey();
this.simulation.tryMove(move); move.applyAction();
value = Float.max(value, -negaMax(depth - 1, -beta, -alpha)); value = Float.max(value, -negaMax(depth - 1, -beta, -alpha));
this.simulation.undoMove(); move.undoAction();
alpha = Float.max(alpha, value); alpha = Float.max(alpha, value);
if (alpha >= beta) if (alpha >= beta)
return value; return value;

View File

@@ -2,9 +2,9 @@ package chess.ai.minimax;
import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadFactory;
import chess.ai.actions.AIAction;
import chess.controller.CommandExecutor; import chess.controller.CommandExecutor;
import chess.model.Color; import chess.model.Color;
import chess.model.Move;
public class AlphaBetaThreadCreator implements ThreadFactory{ public class AlphaBetaThreadCreator implements ThreadFactory{
@@ -21,7 +21,7 @@ public class AlphaBetaThreadCreator implements ThreadFactory{
} }
} }
public static float getMoveValue(Move move, int searchDepth) { public static float getMoveValue(AIAction move, int searchDepth) {
AlphaBetaThread t = (AlphaBetaThread) Thread.currentThread(); AlphaBetaThread t = (AlphaBetaThread) Thread.currentThread();
return t.getMoveValue(move, searchDepth); return t.getMoveValue(move, searchDepth);
} }

View File

@@ -2,16 +2,17 @@ package chess.ai.minimax;
import java.util.List; import java.util.List;
import chess.ai.actions.AIAction;
import chess.ai.actions.AIActions;
import chess.controller.Command; import chess.controller.Command;
import chess.controller.Command.CommandResult; import chess.controller.Command.CommandResult;
import chess.controller.CommandExecutor; import chess.controller.CommandExecutor;
import chess.controller.commands.CastlingCommand; import chess.controller.commands.CastlingCommand;
import chess.controller.commands.GetPlayerMovesCommand;
import chess.controller.commands.MoveCommand; import chess.controller.commands.MoveCommand;
import chess.controller.commands.NewGameCommand; import chess.controller.commands.NewGameCommand;
import chess.controller.commands.PromoteCommand; import chess.controller.commands.PromoteCommand;
import chess.controller.commands.UndoCommand;
import chess.controller.commands.PromoteCommand.PromoteType; import chess.controller.commands.PromoteCommand.PromoteType;
import chess.controller.commands.UndoCommand;
import chess.controller.event.EmptyGameDispatcher; import chess.controller.event.EmptyGameDispatcher;
import chess.controller.event.GameAdapter; import chess.controller.event.GameAdapter;
import chess.model.ChessBoard; import chess.model.ChessBoard;
@@ -90,10 +91,8 @@ public class GameSimulation extends GameAdapter {
return this.gameSimulation.getPlayerTurn(); return this.gameSimulation.getPlayerTurn();
} }
public List<Move> getAllowedMoves() { public List<AIAction> getAllowedActions() {
GetPlayerMovesCommand cmd = new GetPlayerMovesCommand(); return AIActions.getAllowedActions(this.simulation);
sendCommand(cmd);
return cmd.getMoves();
} }
public void close() { public void close() {