MCTS: Random search parallelisation
parent
03471fc24a
commit
8e78ac49c0
|
|
@ -0,0 +1,49 @@
|
||||||
|
package model.ai.mcts;
|
||||||
|
|
||||||
|
import model.board.Board;
|
||||||
|
import model.board.Move;
|
||||||
|
import model.board.utils.PawnColor;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
class MCTSRunJob extends Thread {
|
||||||
|
|
||||||
|
private int score = 0;
|
||||||
|
private Board board;
|
||||||
|
private PawnColor playerTurn;
|
||||||
|
private PawnColor myColor;
|
||||||
|
|
||||||
|
MCTSRunJob(@NotNull Board currentBoard, @NotNull PawnColor currentPlayer, @NotNull PawnColor myColor){
|
||||||
|
super();
|
||||||
|
|
||||||
|
board = new Board(currentBoard);
|
||||||
|
playerTurn = currentPlayer.getOpposite();
|
||||||
|
this.myColor = myColor;
|
||||||
|
}
|
||||||
|
|
||||||
|
//***** Getters/Setters *****//
|
||||||
|
|
||||||
|
public int getScore() {
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
//***** Run method *****//
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(){
|
||||||
|
List<Move> allPossibleMoves;
|
||||||
|
|
||||||
|
while(!board.isFinished()){
|
||||||
|
allPossibleMoves = board.getAllPossibleMoves(playerTurn);
|
||||||
|
Move moveToDo = allPossibleMoves.get((int) Math.floor(Math.random() * allPossibleMoves.size()));
|
||||||
|
|
||||||
|
board.movePawn(moveToDo);
|
||||||
|
playerTurn = playerTurn.getOpposite();
|
||||||
|
}
|
||||||
|
|
||||||
|
if(board.colorHasWon(myColor)) {
|
||||||
|
score = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -8,15 +8,24 @@ import org.jetbrains.annotations.NotNull;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MCTSTree is the class representing the search tree used in the MCTS algorithm.
|
||||||
|
* It has one root, which describes the current state of the game.
|
||||||
|
* Finding the best move involves running tens of thousands of random playouts of the game from the root.
|
||||||
|
*/
|
||||||
class MCTSTree{
|
class MCTSTree{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Internal class of MCTSTree that represents a node in the search tree.
|
||||||
|
* All the magic happens here.
|
||||||
|
*/
|
||||||
private class Node{
|
private class Node{
|
||||||
private Board boardState;
|
private Board boardState;
|
||||||
private PawnColor playerTurn;
|
private PawnColor playerTurn;
|
||||||
private Move precedentMove = null;
|
private Move precedentMove = null;
|
||||||
|
|
||||||
private Node parent = null;
|
|
||||||
private List<Node> sons = new ArrayList<>();
|
private List<Node> sons = new ArrayList<>();
|
||||||
|
private Node parent = null;
|
||||||
|
|
||||||
private int nbSuccess = 0;
|
private int nbSuccess = 0;
|
||||||
private int nbTries = 0;
|
private int nbTries = 0;
|
||||||
|
|
@ -24,31 +33,39 @@ class MCTSTree{
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For root node only.
|
* For root node only.
|
||||||
* @param boardState
|
* @param boardState The current board state, from which the AI must determine the best move to play.
|
||||||
*/
|
*/
|
||||||
public Node(@NotNull Board boardState){
|
Node(@NotNull Board boardState){
|
||||||
this.boardState = boardState;
|
this.boardState = boardState;
|
||||||
this.playerTurn = myColor;
|
this.playerTurn = myColor;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For non-root nodes only.
|
* For non-root nodes only.
|
||||||
* @param boardState
|
* @param boardState The state of the board when a new node is created.
|
||||||
* @param playerTurn
|
* @param playerTurn The color of the player who is able to move a pawn.
|
||||||
* @param precedentMove
|
* @param precedentMove The previous move done, i.e. the link between the previous board state and the current board state.
|
||||||
* @param parent
|
* @param depth The depth of the node in the tree. With the root being 0.
|
||||||
*/
|
*/
|
||||||
public Node(@NotNull Board boardState, @NotNull PawnColor playerTurn, @NotNull Move precedentMove, @NotNull Node parent, int depth){
|
private Node(@NotNull Board boardState, @NotNull PawnColor playerTurn, @NotNull Move precedentMove, @NotNull Node parent, int depth){
|
||||||
this.boardState = boardState;
|
this.boardState = boardState;
|
||||||
this.playerTurn = playerTurn;
|
this.playerTurn = playerTurn;
|
||||||
this.precedentMove = precedentMove;
|
this.precedentMove = precedentMove;
|
||||||
|
|
||||||
this.parent = parent;
|
this.parent = parent;
|
||||||
|
|
||||||
this.depth = depth;
|
this.depth = depth;
|
||||||
}
|
}
|
||||||
|
|
||||||
//***** Getters/Setters *****//
|
//***** Getters/Setters *****//
|
||||||
|
|
||||||
|
public void addTries(int tries){
|
||||||
|
nbTries += tries;
|
||||||
|
if(parent != null){
|
||||||
|
parent.addTries(tries);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public List<Node> getSons(){
|
public List<Node> getSons(){
|
||||||
return sons;
|
return sons;
|
||||||
}
|
}
|
||||||
|
|
@ -63,27 +80,39 @@ class MCTSTree{
|
||||||
|
|
||||||
//***** *****//
|
//***** *****//
|
||||||
|
|
||||||
public int playOneTurnRandom(){
|
int playOneTurn(){
|
||||||
nbTries++;
|
int score = 0;
|
||||||
|
int tries = 1;
|
||||||
if(boardState.isFinished()){
|
|
||||||
if(boardState.colorHasWon(myColor)){
|
|
||||||
nbSuccess++;
|
|
||||||
return 1;
|
|
||||||
} else {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Node nextNode = null;
|
|
||||||
|
|
||||||
|
if(boardState.isFinished() && boardState.colorHasWon(myColor)){
|
||||||
|
score = 1;
|
||||||
|
addTries(tries);
|
||||||
|
} else if(!boardState.isFinished()) {
|
||||||
if(depth > maxDepth) { // Case where we need to play randomly
|
if(depth > maxDepth) { // Case where we need to play randomly
|
||||||
List<Move> allPossibleMoves = boardState.getAllPossibleMoves(playerTurn);
|
|
||||||
Move moveToDo = allPossibleMoves.get((int) Math.floor(Math.random() * allPossibleMoves.size()));
|
|
||||||
|
|
||||||
nextNode = new Node(new Board(boardState, moveToDo), playerTurn.getOpposite(), moveToDo, this, depth + 1);
|
// Launch four jobs.
|
||||||
|
MCTSRunJob[] jobs = new MCTSRunJob[4];
|
||||||
|
for(int i = 0; i < jobs.length; i++){
|
||||||
|
jobs[i] = new MCTSRunJob(boardState, playerTurn, myColor);
|
||||||
|
jobs[i].start();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all jobs to finish and collect score.
|
||||||
|
for (MCTSRunJob job : jobs) {
|
||||||
|
try {
|
||||||
|
job.join();
|
||||||
|
score += job.getScore();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tries += jobs.length-1;
|
||||||
|
|
||||||
|
addTries(tries);
|
||||||
} else { // Case where we are in the beginning of the tree
|
} else { // Case where we are in the beginning of the tree
|
||||||
List<Move> allPossibleMoves = boardState.getAllPossibleMoves(playerTurn);
|
List<Move> allPossibleMoves = boardState.getAllPossibleMoves(playerTurn);
|
||||||
|
Node nextNode;
|
||||||
Move moveToDo;
|
Move moveToDo;
|
||||||
|
|
||||||
// Remove the moves already known
|
// Remove the moves already known
|
||||||
|
|
@ -115,13 +144,15 @@ class MCTSTree{
|
||||||
|
|
||||||
sons.add(nextNode);
|
sons.add(nextNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Play the next turn.
|
||||||
|
score = nextNode.playOneTurn();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int success = nextNode.playOneTurnRandom();
|
nbSuccess += score;
|
||||||
nbSuccess += success;
|
|
||||||
|
|
||||||
return success;
|
return score;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
@ -145,19 +176,19 @@ class MCTSTree{
|
||||||
private PawnColor myColor;
|
private PawnColor myColor;
|
||||||
private Node root;
|
private Node root;
|
||||||
|
|
||||||
public MCTSTree(@NotNull Board board, @NotNull PawnColor turnColor){
|
MCTSTree(@NotNull Board board, @NotNull PawnColor turnColor){
|
||||||
myColor = turnColor;
|
myColor = turnColor;
|
||||||
root = new Node(board);
|
root = new Node(board);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Move getBestMove(long timeBudget){
|
Move getBestMove(long timeBudgetMillis){
|
||||||
|
|
||||||
// Explore the MCTS Tree
|
// Explore the MCTS Tree
|
||||||
|
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
|
|
||||||
while(System.currentTimeMillis() - start < timeBudget){
|
while(System.currentTimeMillis() - start < timeBudgetMillis){
|
||||||
root.playOneTurnRandom();
|
root.playOneTurn();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the results
|
// Get the results
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ public class MainTest {
|
||||||
|
|
||||||
System.err.println("-----------------");
|
System.err.println("-----------------");
|
||||||
System.err.println("Move: "+ aiMove);
|
System.err.println("Move: "+ aiMove);
|
||||||
System.err.println("Node Visited: "+ MCTS.nbNodeVisited);
|
System.err.println("Node Visited: ");
|
||||||
System.err.println("-----------------");
|
System.err.println("-----------------");
|
||||||
System.err.println(board);
|
System.err.println(board);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue