MCTS: Random search parallelisation

master
Gregory Martin 2018-03-15 10:00:00 +01:00
parent 03471fc24a
commit 8e78ac49c0
No known key found for this signature in database
GPG Key ID: 8791DD65FA92D9F0
3 changed files with 113 additions and 33 deletions

View File

@ -0,0 +1,49 @@
package model.ai.mcts;
import model.board.Board;
import model.board.Move;
import model.board.utils.PawnColor;
import org.jetbrains.annotations.NotNull;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
/**
 * Worker thread that plays one random game to its end, starting from a given position,
 * and records whether a given color won.
 *
 * <p>Typical use: construct, {@link #start()}, {@link #join()}, then read {@link #getScore()}.
 * The job plays on a new {@code Board} built from the one it was given (presumably a copy,
 * via {@code Board}'s one-argument constructor), so several jobs can be started from the
 * same position.
 */
class MCTSRunJob extends Thread {
    /** 1 if {@code myColor} had won when the playout ended, 0 otherwise. Written only by {@link #run()}. */
    private int score = 0;
    /** Private working board; mutated by every random move of the playout. */
    private final Board board;
    /** Color that makes the next random move. */
    private PawnColor playerTurn;
    /** Color whose victory is counted as a success. */
    private final PawnColor myColor;

    /**
     * @param currentBoard  The position the playout starts from.
     * @param currentPlayer The color that is to move in {@code currentBoard}; it makes the first random move.
     * @param myColor       The color whose win yields a score of 1.
     */
    MCTSRunJob(@NotNull Board currentBoard, @NotNull PawnColor currentPlayer, @NotNull PawnColor myColor) {
        super();
        board = new Board(currentBoard);
        // The caller hands over the color that is about to move, so it must not be flipped here:
        // flipping would let the side that just played move twice in a row.
        playerTurn = currentPlayer;
        this.myColor = myColor;
    }

    //***** Getters/Setters *****//

    /**
     * Returns the result of the playout.
     * Only meaningful once the thread has terminated (i.e. after {@link #join()} returned);
     * before that it is 0.
     *
     * @return 1 if {@code myColor} won the random game, 0 otherwise.
     */
    public int getScore() {
        return score;
    }

    //***** Run method *****//

    /**
     * Plays uniformly random moves, alternating colors, until the board reports that the game
     * is finished, then sets the score to 1 if {@code myColor} has won.
     */
    @Override
    public void run() {
        while (!board.isFinished()) {
            List<Move> allPossibleMoves = board.getAllPossibleMoves(playerTurn);
            if (allPossibleMoves.isEmpty()) {
                // No move available although the game is not reported as finished: stop the
                // playout and keep the score at 0 rather than letting the thread die on an
                // out-of-range index.
                return;
            }
            // ThreadLocalRandom avoids contention on the single generator behind Math.random()
            // when several jobs run in parallel.
            int pick = ThreadLocalRandom.current().nextInt(allPossibleMoves.size());
            board.movePawn(allPossibleMoves.get(pick));
            playerTurn = playerTurn.getOpposite();
        }
        if (board.colorHasWon(myColor)) {
            score = 1;
        }
    }
}

View File

@ -8,15 +8,24 @@ import org.jetbrains.annotations.NotNull;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
/**
* MCTSTree is the class representing the search tree used in the MCTS algorithm.
* It has one root, which describes the current state of the game.
* Getting the best move to do involves running tens of thousands of random runs of the game from the root.
*/
class MCTSTree{ class MCTSTree{
/**
* Internal class of MCTSTree which represents a node in the search tree.
* All the magic happens here.
*/
private class Node{ private class Node{
private Board boardState; private Board boardState;
private PawnColor playerTurn; private PawnColor playerTurn;
private Move precedentMove = null; private Move precedentMove = null;
private Node parent = null;
private List<Node> sons = new ArrayList<>(); private List<Node> sons = new ArrayList<>();
private Node parent = null;
private int nbSuccess = 0; private int nbSuccess = 0;
private int nbTries = 0; private int nbTries = 0;
@ -24,31 +33,39 @@ class MCTSTree{
/** /**
* For root node only. * For root node only.
* @param boardState * @param boardState The current board state when the AI needs to get to know which move is the best to do.
*/ */
public Node(@NotNull Board boardState){ Node(@NotNull Board boardState){
this.boardState = boardState; this.boardState = boardState;
this.playerTurn = myColor; this.playerTurn = myColor;
} }
/** /**
* For non-root nodes only. * For non-root nodes only.
* @param boardState * @param boardState The state of the board when a new node is created.
* @param playerTurn * @param playerTurn The color of the player who is able to move a pawn.
* @param precedentMove * @param precedentMove The previous move done, i.e the link between the previous board state and the current board state.
* @param parent * @param depth The depth of the node in the tree. With the root being 0.
*/ */
public Node(@NotNull Board boardState, @NotNull PawnColor playerTurn, @NotNull Move precedentMove, @NotNull Node parent, int depth){ private Node(@NotNull Board boardState, @NotNull PawnColor playerTurn, @NotNull Move precedentMove, @NotNull Node parent, int depth){
this.boardState = boardState; this.boardState = boardState;
this.playerTurn = playerTurn; this.playerTurn = playerTurn;
this.precedentMove = precedentMove; this.precedentMove = precedentMove;
this.parent = parent; this.parent = parent;
this.depth = depth; this.depth = depth;
} }
//***** Getters/Setters *****// //***** Getters/Setters *****//
public void addTries(int tries){
nbTries += tries;
if(parent != null){
parent.addTries(tries);
}
}
public List<Node> getSons(){ public List<Node> getSons(){
return sons; return sons;
} }
@ -63,27 +80,39 @@ class MCTSTree{
//***** *****// //***** *****//
public int playOneTurnRandom(){ int playOneTurn(){
nbTries++; int score = 0;
int tries = 1;
if(boardState.isFinished()){
if(boardState.colorHasWon(myColor)){
nbSuccess++;
return 1;
} else {
return 0;
}
} else {
Node nextNode = null;
if(boardState.isFinished() && boardState.colorHasWon(myColor)){
score = 1;
addTries(tries);
} else if(!boardState.isFinished()) {
if(depth > maxDepth) { // Case where we need to play randomly if(depth > maxDepth) { // Case where we need to play randomly
List<Move> allPossibleMoves = boardState.getAllPossibleMoves(playerTurn);
Move moveToDo = allPossibleMoves.get((int) Math.floor(Math.random() * allPossibleMoves.size()));
nextNode = new Node(new Board(boardState, moveToDo), playerTurn.getOpposite(), moveToDo, this, depth + 1); // Launch four jobs.
MCTSRunJob[] jobs = new MCTSRunJob[4];
for(int i = 0; i < jobs.length; i++){
jobs[i] = new MCTSRunJob(boardState, playerTurn, myColor);
jobs[i].start();
}
// Wait for all jobs to finish and collect score.
for (MCTSRunJob job : jobs) {
try {
job.join();
score += job.getScore();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
tries += jobs.length-1;
addTries(tries);
} else { // Case where we are in the beginning of the tree } else { // Case where we are in the beginning of the tree
List<Move> allPossibleMoves = boardState.getAllPossibleMoves(playerTurn); List<Move> allPossibleMoves = boardState.getAllPossibleMoves(playerTurn);
Node nextNode;
Move moveToDo; Move moveToDo;
// Remove the moves already known // Remove the moves already known
@ -115,13 +144,15 @@ class MCTSTree{
sons.add(nextNode); sons.add(nextNode);
} }
// Play the next turn.
score = nextNode.playOneTurn();
}
} }
int success = nextNode.playOneTurnRandom(); nbSuccess += score;
nbSuccess += success;
return success; return score;
}
} }
@Override @Override
@ -145,19 +176,19 @@ class MCTSTree{
private PawnColor myColor; private PawnColor myColor;
private Node root; private Node root;
public MCTSTree(@NotNull Board board, @NotNull PawnColor turnColor){ MCTSTree(@NotNull Board board, @NotNull PawnColor turnColor){
myColor = turnColor; myColor = turnColor;
root = new Node(board); root = new Node(board);
} }
public Move getBestMove(long timeBudget){ Move getBestMove(long timeBudgetMillis){
// Explore the MCTS Tree // Explore the MCTS Tree
long start = System.currentTimeMillis(); long start = System.currentTimeMillis();
while(System.currentTimeMillis() - start < timeBudget){ while(System.currentTimeMillis() - start < timeBudgetMillis){
root.playOneTurnRandom(); root.playOneTurn();
} }
// Get the results // Get the results

View File

@ -19,7 +19,7 @@ public class MainTest {
System.err.println("-----------------"); System.err.println("-----------------");
System.err.println("Move: "+ aiMove); System.err.println("Move: "+ aiMove);
System.err.println("Node Visited: "+ MCTS.nbNodeVisited); System.err.println("Node Visited: ");
System.err.println("-----------------"); System.err.println("-----------------");
System.err.println(board); System.err.println(board);