MCTS: Random search parallelisation
parent
03471fc24a
commit
8e78ac49c0
|
|
@ -0,0 +1,49 @@
|
|||
package model.ai.mcts;
|
||||
|
||||
import model.board.Board;
|
||||
import model.board.Move;
|
||||
import model.board.utils.PawnColor;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
class MCTSRunJob extends Thread {
|
||||
|
||||
private int score = 0;
|
||||
private Board board;
|
||||
private PawnColor playerTurn;
|
||||
private PawnColor myColor;
|
||||
|
||||
MCTSRunJob(@NotNull Board currentBoard, @NotNull PawnColor currentPlayer, @NotNull PawnColor myColor){
|
||||
super();
|
||||
|
||||
board = new Board(currentBoard);
|
||||
playerTurn = currentPlayer.getOpposite();
|
||||
this.myColor = myColor;
|
||||
}
|
||||
|
||||
//***** Getters/Setters *****//
|
||||
|
||||
public int getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
//***** Run method *****//
|
||||
|
||||
@Override
|
||||
public void run(){
|
||||
List<Move> allPossibleMoves;
|
||||
|
||||
while(!board.isFinished()){
|
||||
allPossibleMoves = board.getAllPossibleMoves(playerTurn);
|
||||
Move moveToDo = allPossibleMoves.get((int) Math.floor(Math.random() * allPossibleMoves.size()));
|
||||
|
||||
board.movePawn(moveToDo);
|
||||
playerTurn = playerTurn.getOpposite();
|
||||
}
|
||||
|
||||
if(board.colorHasWon(myColor)) {
|
||||
score = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -8,15 +8,24 @@ import org.jetbrains.annotations.NotNull;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* MCTSTree is the class representing the search tree used in the MCTS algorithm.
|
||||
* It has one root which describe the current state of the game.
|
||||
* Getting the best move to do involve running dozen of thousands random runs of the game from the root.
|
||||
*/
|
||||
class MCTSTree{
|
||||
|
||||
/**
|
||||
* Internal class of MCTSTree which represent a node in the search tree.
|
||||
* All the magic happens here.
|
||||
*/
|
||||
private class Node{
|
||||
private Board boardState;
|
||||
private PawnColor playerTurn;
|
||||
private Move precedentMove = null;
|
||||
|
||||
private Node parent = null;
|
||||
private List<Node> sons = new ArrayList<>();
|
||||
private Node parent = null;
|
||||
|
||||
private int nbSuccess = 0;
|
||||
private int nbTries = 0;
|
||||
|
|
@ -24,31 +33,39 @@ class MCTSTree{
|
|||
|
||||
/**
|
||||
* For root node only.
|
||||
* @param boardState
|
||||
* @param boardState The current board state when the AI needs to get to know which move is the best to do.
|
||||
*/
|
||||
public Node(@NotNull Board boardState){
|
||||
Node(@NotNull Board boardState){
|
||||
this.boardState = boardState;
|
||||
this.playerTurn = myColor;
|
||||
}
|
||||
|
||||
/**
|
||||
* For non-root nodes only.
|
||||
* @param boardState
|
||||
* @param playerTurn
|
||||
* @param precedentMove
|
||||
* @param parent
|
||||
* @param boardState The state of the board when a new node is created.
|
||||
* @param playerTurn The color of the player who is able to move a pawn.
|
||||
* @param precedentMove The previous move done, i.e the link between the previous board state and the current board state.
|
||||
* @param depth The depth of the node in the tree. With the root being 0.
|
||||
*/
|
||||
public Node(@NotNull Board boardState, @NotNull PawnColor playerTurn, @NotNull Move precedentMove, @NotNull Node parent, int depth){
|
||||
private Node(@NotNull Board boardState, @NotNull PawnColor playerTurn, @NotNull Move precedentMove, @NotNull Node parent, int depth){
|
||||
this.boardState = boardState;
|
||||
this.playerTurn = playerTurn;
|
||||
this.precedentMove = precedentMove;
|
||||
|
||||
this.parent = parent;
|
||||
|
||||
this.depth = depth;
|
||||
}
|
||||
|
||||
//***** Getters/Setters *****//
|
||||
|
||||
public void addTries(int tries){
|
||||
nbTries += tries;
|
||||
if(parent != null){
|
||||
parent.addTries(tries);
|
||||
}
|
||||
}
|
||||
|
||||
public List<Node> getSons(){
|
||||
return sons;
|
||||
}
|
||||
|
|
@ -63,27 +80,39 @@ class MCTSTree{
|
|||
|
||||
//***** *****//
|
||||
|
||||
public int playOneTurnRandom(){
|
||||
nbTries++;
|
||||
|
||||
if(boardState.isFinished()){
|
||||
if(boardState.colorHasWon(myColor)){
|
||||
nbSuccess++;
|
||||
return 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
Node nextNode = null;
|
||||
int playOneTurn(){
|
||||
int score = 0;
|
||||
int tries = 1;
|
||||
|
||||
if(boardState.isFinished() && boardState.colorHasWon(myColor)){
|
||||
score = 1;
|
||||
addTries(tries);
|
||||
} else if(!boardState.isFinished()) {
|
||||
if(depth > maxDepth) { // Case where we need to play randomly
|
||||
List<Move> allPossibleMoves = boardState.getAllPossibleMoves(playerTurn);
|
||||
Move moveToDo = allPossibleMoves.get((int) Math.floor(Math.random() * allPossibleMoves.size()));
|
||||
|
||||
nextNode = new Node(new Board(boardState, moveToDo), playerTurn.getOpposite(), moveToDo, this, depth + 1);
|
||||
// Launch four jobs.
|
||||
MCTSRunJob[] jobs = new MCTSRunJob[4];
|
||||
for(int i = 0; i < jobs.length; i++){
|
||||
jobs[i] = new MCTSRunJob(boardState, playerTurn, myColor);
|
||||
jobs[i].start();
|
||||
}
|
||||
|
||||
// Wait for all jobs to finish and collect score.
|
||||
for (MCTSRunJob job : jobs) {
|
||||
try {
|
||||
job.join();
|
||||
score += job.getScore();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
tries += jobs.length-1;
|
||||
|
||||
addTries(tries);
|
||||
} else { // Case where we are in the beginning of the tree
|
||||
List<Move> allPossibleMoves = boardState.getAllPossibleMoves(playerTurn);
|
||||
Node nextNode;
|
||||
Move moveToDo;
|
||||
|
||||
// Remove the moves already known
|
||||
|
|
@ -115,13 +144,15 @@ class MCTSTree{
|
|||
|
||||
sons.add(nextNode);
|
||||
}
|
||||
|
||||
// Play the next turn.
|
||||
score = nextNode.playOneTurn();
|
||||
}
|
||||
}
|
||||
|
||||
int success = nextNode.playOneTurnRandom();
|
||||
nbSuccess += success;
|
||||
nbSuccess += score;
|
||||
|
||||
return success;
|
||||
}
|
||||
return score;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -145,19 +176,19 @@ class MCTSTree{
|
|||
private PawnColor myColor;
|
||||
private Node root;
|
||||
|
||||
public MCTSTree(@NotNull Board board, @NotNull PawnColor turnColor){
|
||||
MCTSTree(@NotNull Board board, @NotNull PawnColor turnColor){
|
||||
myColor = turnColor;
|
||||
root = new Node(board);
|
||||
}
|
||||
|
||||
public Move getBestMove(long timeBudget){
|
||||
Move getBestMove(long timeBudgetMillis){
|
||||
|
||||
// Explore the MCTS Tree
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
|
||||
while(System.currentTimeMillis() - start < timeBudget){
|
||||
root.playOneTurnRandom();
|
||||
while(System.currentTimeMillis() - start < timeBudgetMillis){
|
||||
root.playOneTurn();
|
||||
}
|
||||
|
||||
// Get the results
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ public class MainTest {
|
|||
|
||||
System.err.println("-----------------");
|
||||
System.err.println("Move: "+ aiMove);
|
||||
System.err.println("Node Visited: "+ MCTS.nbNodeVisited);
|
||||
System.err.println("Node Visited: ");
|
||||
System.err.println("-----------------");
|
||||
System.err.println(board);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue