前言

期中黑白棋比賽主要利用 MCTS(Monte Carlo Tree Search)演算法來決定下一步要下的位置

程式說明

程式主要包含了以下五個 Class

  • BOT:主要提供 getAction method 給 competition callback 呼叫
  • State:MCTS 狀態的 Interface
  • OthelloState:黑白棋的 MCTS State,繼承自 State
  • Node:MCTS Node
  • MonteCarloTreeSearch:MCTS 演算法 Class

各 Class 詳細說明

程式大多使用註解說明

Bot

class BOT:
    """Competition-facing player that chooses its moves with MCTS."""

    def __init__(self, *args, **kargs):
        """Build the searcher used for every move.

        Positional and keyword arguments are accepted but not used.
        """
        # Create the MCTS object for the purpose of searching the next action;
        # the time limit is expressed in milliseconds.
        self.mcts = MonteCarloTreeSearch(timeLimit=2000)

    def getAction(self, board, color):
        """Return the action chosen for ``color`` on ``board``.

        This method is the one called by the competition callback. The chosen
        action is also printed to stdout.
        """
        rootState = OthelloState(board, color)
        action = self.mcts.search(rootState)
        print('action = ', action)
        return action

State

State is the interface class used by the MCTS class.

class State:
    """Interface a game state must provide so that MCTS can search it.

    Subclasses override every method; the bodies here are neutral defaults
    (no actions, reward 0, weight 1).
    """

    def getValidActions(self) -> np.ndarray:
        """Returns an iterable of all actions which can be taken from this state"""
        return np.array([])

    def takeAction(self, action) -> "State":
        """Returns the state which results from taking action ``action``

        The base implementation does nothing and returns ``None``.
        """
        return None

    def isTerminal(self):
        """Returns whether this state is a terminal state

        The base implementation returns ``None`` (falsy).
        """
        return None

    def getReward(self) -> int:
        """Returns the reward for this state"""
        return 0

    def getWeight(self, action: Optional[tuple[int, int]]) -> int:
        """Returns the weighting applied to ``action``.

        ``action`` may be ``None`` (a node that was not reached by any
        action); the base implementation returns 1 for every input.
        """
        return 1

OthelloState

# inherits from State
class OthelloState(State):
    """Othello position (board plus side to move) behind the MCTS State API."""

    def __init__(self, board, color):
        self.board, self.color = board, color
        # Edge length of the board; getWeight uses it to locate the corners.
        self.size = len(board)

    def getValidActions(self):
        """Return the legal moves for the side to move."""
        return getValidMoves(self.board, self.color)

    def takeAction(self, action):
        """Return the new state reached after playing ``action``.

        The move is applied to a copy of the board (the return value of
        ``executeMove`` is ignored), and the turn passes to ``-self.color``.
        """
        nextBoard = self.board.copy()
        executeMove(nextBoard, self.color, action)
        return OthelloState(nextBoard, -self.color)

    def isTerminal(self):
        """Return True when the search should stop at this state.

        A state is treated as terminal if the side to move has no legal move,
        or if ``isEndGame`` reports a result (anything other than ``None``).
        """
        if not len(self.getValidActions()):
            return True
        return isEndGame(self.board) is not None

    def getReward(self):
        """Return the reward value used when computing UCT.

        Reward = (own legal moves - opponent legal moves)
                 + (own disc count - opponent disc count).
        """
        ownMoves = getValidMoves(self.board, self.color)
        rivalMoves = getValidMoves(self.board, -self.color)
        mobility = len(ownMoves) - len(rivalMoves)

        ownDiscs = np.sum(self.board == self.color)
        rivalDiscs = np.sum(self.board == -self.color)
        return ownDiscs - rivalDiscs + mobility

    def getWeight(self, action: Optional[tuple[int, int]]) -> int:
        """Return the reward weighting for ``action``.

        The weighting is used to favour or avoid particular squares:
        corners -> 100 (always take a corner), squares touching a corner
        -> -100 (avoid them), the ring of squares around those -> 50
        (preferred), anything else, or ``action is None`` -> 1.
        """
        if action is None:
            return 1

        last = self.size - 1
        # Each entry is (how deep the corner block reaches, weight). Blocks
        # are tested from the corner outwards, so a square already matched by
        # a smaller block never reaches a larger one.
        for depth, weight in ((0, 100), (1, -100), (2, 50)):
            band = [*range(depth + 1), *range(last - depth, last + 1)]
            if action in [(row, col) for row in band for col in band]:
                return weight
        return 1

Node

MCTS 節點

class Node:
    """A single node of the MCTS search tree."""

    def __init__(self, state: State, action: Optional[tuple[int, int]] = None, parent: Optional[Self] = None):
        """Wrap ``state`` as a tree node.

        ``action`` is the action that led to this node (``None`` when there is
        none, e.g. the root) and ``parent`` is the node it was expanded from.
        """
        terminal = state.isTerminal()

        self.state = state
        self.parent = parent
        # A terminal node starts out marked as fully expanded.
        self.isTerminal = terminal
        self.isFullyExpanded = terminal
        # Visit count and summed reward, both starting at zero.
        self.numVisits, self.totalReward = 0, 0
        # Child nodes; starts empty.
        self.children = {}
        # Weight the state assigns to the action that produced this node.
        self.weight = state.getWeight(action)

MCTS

class MonteCarloTreeSearch:
    """Monte Carlo Tree Search driver working on any ``State`` implementation.

    Exactly one budget must be configured: a wall-time limit in milliseconds
    or a fixed number of iterations.
    """

    def __init__(
        self,
        timeLimit=None,  # format in milliseconds
        iterationLimit=None,
        explorationConstant=2,
        rolloutPolicy=randomPolicy,
    ):
        """Configure the search budget and policies.

        Args:
            timeLimit: search budget per ``search`` call, in milliseconds.
            iterationLimit: number of MCTS iterations per ``search`` call.
            explorationConstant: factor applied to the UCT exploration term.
            rolloutPolicy: callable taking a state and returning the reward
                that is backpropagated.

        Raises:
            ValueError: if both limits are given, if neither is given, or if
                the given limit is not positive.
        """
        if timeLimit is not None and iterationLimit is not None:
            raise ValueError("Cannot have both a time limit and an iteration limit")
        if timeLimit is not None:
            if timeLimit <= 0:
                raise ValueError("The time limit should be positive")
            self.timeLimit = timeLimit
            self.limitType = "time"
        elif iterationLimit is not None:
            if iterationLimit <= 0:
                raise ValueError("The iteration limit should be positive")
            self.iterationLimit = iterationLimit
            self.limitType = "iterations"
        else:
            # Without this check limitType would stay unset and search() would
            # fail later with an AttributeError.
            raise ValueError("Must have either a time limit or an iteration limit")
        self.exploreConstant = explorationConstant
        self.rollout = rolloutPolicy

    # MCTS Search
    def search(self, initialState: State):
        """Run MCTS from ``initialState`` and return the chosen action.

        Within the configured budget the four MCTS steps are repeated:
        - Selection: pick the leaf node whose next level will be expanded.
        - Expansion: expand the next level.
        - Rollout (playout, simulation): estimate the value of the node.
        - Backpropagation: update the statistics upwards.
        Ref: https://liaowc.github.io/blog/mcts-monte-carlo-tree-search/

        Returns:
            The key (a tuple) under which the best child of the root is stored.

        Raises:
            ValueError: if ``initialState`` is terminal, since the root then
                has no child action to return.
        """
        # Create the root from the initial state
        self.root = Node(initialState)
        if self.root.isTerminal:
            raise ValueError("Cannot search from a terminal state: no action is available")

        if self.limitType == "time":
            # Monotonic clock: unaffected by system clock adjustments.
            deadline = time.monotonic() + self.timeLimit / 1000
            # Always run at least one iteration so the root has a child to
            # return even when the budget is extremely small.
            while True:
                self.execute()
                if time.monotonic() >= deadline:
                    break
        else:
            for _ in range(self.iterationLimit):
                self.execute()

        # find action from children (dict[action, node]) of root with best score
        bestNode = self.getBestChild(self.root)
        return next(
            action for action, node in self.root.children.items() if node is bestNode
        )

    def execute(self):
        """Run one MCTS iteration: select/expand, rollout, backpropagate."""
        newNode = self.select(self.root)
        reward = self.rollout(newNode.state)
        self.backpropagate(newNode, reward)

    # NOTE: originally generated by Copilot
    def select(self, node: Node) -> Node:
        """Walk down from ``node`` and return the node to roll out from.

        Fully expanded nodes are descended through their best child; the first
        node that still has an unexpanded action is expanded and the new child
        is returned. A terminal node (game over) is returned as is.
        """
        while not node.isTerminal:
            if node.isFullyExpanded:
                # Every action already has a child: continue with the best one
                node = self.getBestChild(node)
            else:
                # Not fully expanded: expand and return the new node
                return self.expand(node)
        # node.isTerminal (game over): return the current node
        return node

    def expand(self, node: Node):
        """Create and return the child for the first unexpanded action of ``node``.

        Marks ``node`` as fully expanded once every valid action has a child.

        Raises:
            RuntimeError: if every action already has a child (callers only
                expand nodes that are not fully expanded).
        """
        # Get all valid actions
        actions = node.state.getValidActions()
        for action in actions:
            # since ndarray is not hashable, convert it to tuple[int, int]
            key = tuple(action)
            if key in node.children:
                continue
            # Add the new node to children (hashmap from action to node)
            child = Node(node.state.takeAction(action), key, node)
            node.children[key] = child
            if len(actions) == len(node.children):
                node.isFullyExpanded = True
            return child
        raise RuntimeError("Should never reach here")

    def backpropagate(self, node: Node, reward: int):
        """Add one visit and ``reward`` to ``node`` and every ancestor."""
        while node is not None:
            node.numVisits += 1
            node.totalReward += reward
            node = node.parent

    def getBestChild(self, node: Node) -> Node:
        """Return the child of ``node`` with the highest weighted UCT score.

        score = mean reward + c * sqrt(ln(parent visits) / child visits) * weight

        The child's weight multiplies only the exploration term, so a negative
        weight turns the exploration bonus into a penalty.

        Raises:
            ValueError: if ``node`` has no children.
        """
        def UCT(parent: Node, child: Node) -> float:
            exploitation = child.totalReward / child.numVisits
            exploration = (
                self.exploreConstant
                * math.sqrt(math.log(parent.numVisits) / child.numVisits)
                * child.weight  # apply the action weighting
            )
            return exploitation + exploration

        # Compute the UCT value of every child and pick the best one
        return max(node.children.values(), key=lambda child: UCT(node, child))