# Preface
# -------
# The midterm Othello (Reversi) competition bot uses the Monte Carlo Tree
# Search (MCTS) algorithm to decide which move to play.
#
# Program overview
# ----------------
# The program consists mainly of the following five classes:
#   - BOT:                  provides the getAction method called by the
#                           competition callback
#   - State:                interface for an MCTS state
#   - OthelloState:         the Othello MCTS state, inherits from State
#   - Node:                 an MCTS tree node
#   - MonteCarloTreeSearch: the MCTS algorithm itself
#
# Detailed description of each class
# ----------------------------------
# The code is mostly explained through comments.
#
# NOTE(review): getValidMoves, executeMove, isEndGame and randomPolicy are
# referenced below but defined outside this excerpt.


# --- Bot -------------------------------------------------------------------
class BOT:
    """Competition entry point that picks moves with MCTS."""

    def __init__(self, *args, **kargs):
        # Create the MCTS object used to search for the next action
        # (time budget is given in milliseconds).
        self.mcts = MonteCarloTreeSearch(timeLimit=2000)

    def getAction(self, board, color):
        """Return the action chosen by MCTS for ``color`` on ``board``.

        This method is called by the competition callback.
        """
        action = self.mcts.search(OthelloState(board, color))
        print('action = ', action)
        return action


# --- State -----------------------------------------------------------------
# State is the interface class used by the MCTS class.
class State:
    """Interface that every state searched by MonteCarloTreeSearch implements."""

    def getValidActions(self) -> np.ndarray:
        """Returns an iterable of all actions which can be taken from this state"""
        return np.array([])

    def takeAction(self, action) -> Self:
        """Returns the state which results from taking action action"""
        raise NotImplementedError

    def isTerminal(self):
        """Returns whether this state is a terminal state"""
        raise NotImplementedError

    def getReward(self) -> int:
        """Returns the reward for this state"""
        return 0

    def getWeight(self, action: Optional[tuple[int, int]]) -> int:
        """Returns the weighting factor applied to ``action`` (1 = neutral)."""
        return 1


# --- OthelloState ----------------------------------------------------------
# Inherits from State.
class OthelloState(State):
    """MCTS state for Othello: a board plus the color of the side to move."""

    def __init__(self, board, color):
        self.board = board
        self.color = color
        self.size = len(board)

    def getValidActions(self):
        """Return the legal moves for the side to move."""
        return getValidMoves(self.board, self.color)

    def takeAction(self, action):
        """Apply ``action`` to a copy of the board and return the new state.

        The returned state has the opposite color to move; this state's
        board is left untouched.
        """
        board = self.board.copy()
        executeMove(board, self.color, action)
        return OthelloState(board, -self.color)

    def isTerminal(self):
        """Return True when the side to move has no legal move or the game is decided."""
        if len(self.getValidActions()) == 0:
            return True
        # isEndGame returns None while the game is still undecided.
        return isEndGame(self.board) is not None

    def getReward(self):
        """Return the reward used when computing UCT.

        Reward = (own legal moves - opponent legal moves)
                 + (own disc count - opponent disc count).
        """
        diff = len(getValidMoves(self.board, self.color)) - len(
            getValidMoves(self.board, -self.color)
        )
        return (
            np.sum(self.board == self.color)
            - np.sum(self.board == -self.color)
            + diff
        )

    def getWeight(self, action: Optional[tuple[int, int]]) -> int:
        """Return the weight of ``action``, used to favour or avoid squares.

        Returns 100 for corners, -100 for the squares around a corner, 50 for
        the squares around those, and 1 otherwise (also when ``action`` is
        None). The groups are tested in that order, so on a small board where
        they overlap the earlier group wins.
        """
        if action is None:
            return 1

        last = self.size - 1
        second_last = self.size - 2
        third_last = self.size - 3

        # Take a corner whenever possible.
        corners = (
            (0, 0),
            (0, last),
            (last, 0),
            (last, last),
        )
        # Avoid the squares around a corner.
        around_corners = (
            (0, 1), (1, 0), (1, 1),
            (0, second_last), (1, last), (1, second_last),
            (second_last, 0), (last, 1), (second_last, 1),
            (last, second_last), (second_last, last), (second_last, second_last),
        )
        # Prefer the squares surrounding the squares around a corner.
        outer_ring = (
            (0, 2), (1, 2), (2, 2), (2, 1), (2, 0),
            (0, third_last), (1, third_last), (2, third_last),
            (2, second_last), (2, last),
            (third_last, 0), (third_last, 1), (third_last, 2),
            (second_last, 2), (last, 2),
            (third_last, last), (third_last, second_last),
            (third_last, third_last), (second_last, third_last),
            (last, third_last),
        )

        if action in corners:
            return 100
        if action in around_corners:
            return -100
        if action in outer_ring:
            return 50
        return 1


# --- Node ------------------------------------------------------------------
# MCTS node.
class Node:
    """One node of the MCTS tree."""

    def __init__(
        self,
        state: State,
        action: Optional[tuple[int, int]] = None,
        parent: Optional[Self] = None,
    ):
        self.state = state
        self.isTerminal = state.isTerminal()
        # A terminal node has nothing to expand, so it starts fully expanded.
        self.isFullyExpanded = self.isTerminal
        self.parent = parent
        self.numVisits = 0
        self.totalReward = 0
        # Maps action (as a tuple) -> child Node.
        self.children = {}
        # Weight of the action that led to this node (root passes None).
        self.weight = state.getWeight(action)


# --- MCTS ------------------------------------------------------------------
class MonteCarloTreeSearch:
    """Monte Carlo Tree Search bounded by either a time or an iteration limit."""

    def __init__(
        self,
        timeLimit=None,  # format in milliseconds
        iterationLimit=None,
        explorationConstant=2,
        rolloutPolicy=randomPolicy,
    ):
        """Configure the search.

        Exactly one of ``timeLimit`` (milliseconds) and ``iterationLimit``
        must be given.

        Raises:
            ValueError: if both or neither limit is given, or if the given
                limit is not positive.
        """
        if timeLimit is not None and iterationLimit is not None:
            raise ValueError("Cannot have both a time limit and an iteration limit")

        if timeLimit is not None:
            if timeLimit <= 0:
                raise ValueError("The time limit should be positive")
            self.timeLimit = timeLimit
            self.limitType = "time"
        elif iterationLimit is not None:
            if iterationLimit <= 0:
                raise ValueError("The iteration limit should be positive")
            self.iterationLimit = iterationLimit
            self.limitType = "iterations"
        else:
            # Fail fast: without a limit, search() would have no limitType.
            raise ValueError("Must have either a time limit or an iteration limit")

        self.exploreConstant = explorationConstant
        self.rollout = rolloutPolicy

    def search(self, initialState: State):
        """Run MCTS from ``initialState`` and return the best root action.

        Within the configured limit, the four MCTS steps are repeated:
          - Selection: choose the leaf node to expand next.
          - Expansion: expand the next level.
          - Rollout (playout, simulation): estimate the value of the node.
          - Backpropagation: update the statistics upwards.
        Ref: https://liaowc.github.io/blog/mcts-monte-carlo-tree-search/

        Raises:
            ValueError: if ``initialState`` is terminal, i.e. there is no
                action to choose from.
        """
        # Create the root from the initial state.
        self.root = Node(initialState)
        if self.root.isTerminal:
            # A terminal root never gets children, so no action can be picked.
            raise ValueError("Cannot search from a terminal state: no action available")

        if self.limitType == "time":
            # Time-limited; monotonic clock is immune to wall-clock changes.
            deadline = time.monotonic() + self.timeLimit / 1000
            while time.monotonic() < deadline:
                self.execute()
        else:
            # Iteration-limited.
            for _ in range(self.iterationLimit):
                self.execute()

        if not self.root.children:
            # The time budget expired before a single iteration could run;
            # run one so that there is a child to choose.
            self.execute()

        # Find the action whose child (children: dict[action, node]) of the
        # root has the best score.
        bestNode = self.getBestChild(self.root)
        return next(
            action
            for action, node in self.root.children.items()
            if node is bestNode
        )

    def execute(self):
        """Run one select -> rollout -> backpropagate iteration."""
        newNode = self.select(self.root)
        reward = self.rollout(newNode.state)
        self.backpropagate(newNode, reward)

    def select(self, node: Node) -> Node:
        """Descend from ``node`` and return the node to roll out from.

        Fully expanded nodes are descended through their best child; the
        first node that is not fully expanded is expanded and the new child
        is returned. A terminal node (game over) is returned as is.
        """
        while not node.isTerminal:
            if not node.isFullyExpanded:
                return self.expand(node)
            node = self.getBestChild(node)
        return node

    def expand(self, node: Node):
        """Create and return one not-yet-created child of ``node``.

        Marks ``node`` as fully expanded once every valid action has a child.
        """
        actions = node.state.getValidActions()
        for action in actions:
            # An ndarray is not hashable, so convert it to tuple[int, int].
            key = tuple(action)
            if key in node.children:
                continue
            # children is a hash map from action to node.
            newNode = Node(node.state.takeAction(action), key, node)
            node.children[key] = newNode
            if len(actions) == len(node.children):
                node.isFullyExpanded = True
            return newNode

        raise Exception("Should never reach here")

    def backpropagate(self, node: Node, reward: int):
        """Add one visit and ``reward`` to ``node`` and all of its ancestors."""
        while node is not None:
            node.numVisits += 1
            node.totalReward += reward
            node = node.parent

    def getBestChild(self, node: Node) -> Node:
        """Return the child of ``node`` with the highest weighted UCT value."""

        def uct(child: Node) -> float:
            exploitation = child.totalReward / child.numVisits
            exploration = self.exploreConstant * math.sqrt(
                math.log(node.numVisits) / child.numVisits
            )
            # The exploration term is multiplied by the positional weight.
            return exploitation + exploration * child.weight

        return max(node.children.values(), key=uct)