diff --git a/src/main/java/model/ai/mcts/MCTSRunJob.java b/src/main/java/model/ai/mcts/MCTSRunJob.java new file mode 100644 index 0000000..0eaf092 --- /dev/null +++ b/src/main/java/model/ai/mcts/MCTSRunJob.java @@ -0,0 +1,49 @@ +package model.ai.mcts; + +import model.board.Board; +import model.board.Move; +import model.board.utils.PawnColor; +import org.jetbrains.annotations.NotNull; + +import java.util.List; + +class MCTSRunJob extends Thread { + + private int score = 0; + private Board board; + private PawnColor playerTurn; + private PawnColor myColor; + + MCTSRunJob(@NotNull Board currentBoard, @NotNull PawnColor currentPlayer, @NotNull PawnColor myColor){ + super(); + + board = new Board(currentBoard); + playerTurn = currentPlayer.getOpposite(); + this.myColor = myColor; + } + + //***** Getters/Setters *****// + + public int getScore() { + return score; + } + + //***** Run method *****// + + @Override + public void run(){ + List allPossibleMoves; + + while(!board.isFinished()){ + allPossibleMoves = board.getAllPossibleMoves(playerTurn); + Move moveToDo = allPossibleMoves.get((int) Math.floor(Math.random() * allPossibleMoves.size())); + + board.movePawn(moveToDo); + playerTurn = playerTurn.getOpposite(); + } + + if(board.colorHasWon(myColor)) { + score = 1; + } + } +} diff --git a/src/main/java/model/ai/mcts/MCTSTree.java b/src/main/java/model/ai/mcts/MCTSTree.java index 546599e..623dcac 100644 --- a/src/main/java/model/ai/mcts/MCTSTree.java +++ b/src/main/java/model/ai/mcts/MCTSTree.java @@ -8,15 +8,24 @@ import org.jetbrains.annotations.NotNull; import java.util.ArrayList; import java.util.List; +/** + * MCTSTree is the class representing the search tree used in the MCTS algorithm. + * It has one root which describe the current state of the game. + * Getting the best move to do involve running dozen of thousands random runs of the game from the root. + */ class MCTSTree{ + /** + * Internal class of MCTSTree which represent a node in the search tree. + * All the magic happens here. + */ private class Node{ private Board boardState; private PawnColor playerTurn; private Move precedentMove = null; - private Node parent = null; private List sons = new ArrayList<>(); + private Node parent = null; private int nbSuccess = 0; private int nbTries = 0; @@ -24,31 +33,39 @@ class MCTSTree{ /** * For root node only. - * @param boardState + * @param boardState The current board state when the AI needs to get to know which move is the best to do. */ - public Node(@NotNull Board boardState){ + Node(@NotNull Board boardState){ this.boardState = boardState; this.playerTurn = myColor; } /** * For non-root nodes only. - * @param boardState - * @param playerTurn - * @param precedentMove - * @param parent + * @param boardState The state of the board when a new node is created. + * @param playerTurn The color of the player who is able to move a pawn. + * @param precedentMove The previous move done, i.e the link between the previous board state and the current board state. + * @param depth The depth of the node in the tree. With the root being 0. */ - public Node(@NotNull Board boardState, @NotNull PawnColor playerTurn, @NotNull Move precedentMove, @NotNull Node parent, int depth){ + private Node(@NotNull Board boardState, @NotNull PawnColor playerTurn, @NotNull Move precedentMove, @NotNull Node parent, int depth){ this.boardState = boardState; this.playerTurn = playerTurn; this.precedentMove = precedentMove; this.parent = parent; + this.depth = depth; } //***** Getters/Setters *****// + public void addTries(int tries){ + nbTries += tries; + if(parent != null){ + parent.addTries(tries); + } + } + public List getSons(){ return sons; } @@ -63,27 +80,39 @@ class MCTSTree{ //***** *****// - public int playOneTurnRandom(){ - nbTries++; - - if(boardState.isFinished()){ - if(boardState.colorHasWon(myColor)){ - nbSuccess++; - return 1; - } else { - return 0; - } - } else { - Node nextNode = null; + int playOneTurn(){ + int score = 0; + int tries = 1; + if(boardState.isFinished() && boardState.colorHasWon(myColor)){ + score = 1; + addTries(tries); + } else if(!boardState.isFinished()) { if(depth > maxDepth) { // Case where we need to play randomly - List allPossibleMoves = boardState.getAllPossibleMoves(playerTurn); - Move moveToDo = allPossibleMoves.get((int) Math.floor(Math.random() * allPossibleMoves.size())); - nextNode = new Node(new Board(boardState, moveToDo), playerTurn.getOpposite(), moveToDo, this, depth + 1); + // Launch four jobs. + MCTSRunJob[] jobs = new MCTSRunJob[4]; + for(int i = 0; i < jobs.length; i++){ + jobs[i] = new MCTSRunJob(boardState, playerTurn, myColor); + jobs[i].start(); + } + // Wait for all jobs to finish and collect score. + for (MCTSRunJob job : jobs) { + try { + job.join(); + score += job.getScore(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + tries += jobs.length-1; + + addTries(tries); } else { // Case where we are in the beginning of the tree List allPossibleMoves = boardState.getAllPossibleMoves(playerTurn); + Node nextNode; Move moveToDo; // Remove the moves already known @@ -115,13 +144,15 @@ class MCTSTree{ sons.add(nextNode); } + + // Play the next turn. + score = nextNode.playOneTurn(); } - - int success = nextNode.playOneTurnRandom(); - nbSuccess += success; - - return success; } + + nbSuccess += score; + + return score; } @Override @@ -145,19 +176,19 @@ class MCTSTree{ private PawnColor myColor; private Node root; - public MCTSTree(@NotNull Board board, @NotNull PawnColor turnColor){ + MCTSTree(@NotNull Board board, @NotNull PawnColor turnColor){ myColor = turnColor; root = new Node(board); } - public Move getBestMove(long timeBudget){ + Move getBestMove(long timeBudgetMillis){ // Explore the MCTS Tree long start = System.currentTimeMillis(); - while(System.currentTimeMillis() - start < timeBudget){ - root.playOneTurnRandom(); + while(System.currentTimeMillis() - start < timeBudgetMillis){ + root.playOneTurn(); } // Get the results diff --git a/src/test/java/MainTest.java b/src/test/java/MainTest.java index 9587884..9f5e41b 100644 --- a/src/test/java/MainTest.java +++ b/src/test/java/MainTest.java @@ -19,7 +19,7 @@ public class MainTest { System.err.println("-----------------"); System.err.println("Move: "+ aiMove); - System.err.println("Node Visited: "+ MCTS.nbNodeVisited); + System.err.println("Node Visited: "); System.err.println("-----------------"); System.err.println(board);