def backpropagate(self, node: Node, value: Value,
                  min_max_stats: MinMaxStats) -> None:
    # Walk from the leaf back to the root along parent pointers.
    while node is not None:
        min_max_stats.update(node.update_value(value))
        value = (node.reward + self.effective_discount * value
                 if node.reward is not None else Value(float('nan')))
        node = node.parent
def backpropagate(search_path: List[Node], value: float, to_play: Player,
                  discount: float, min_max_stats: MinMaxStats):
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())
        value = node.reward + discount * value
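Both styles of backpropagation, and the UCB functions below, assume a search-tree Node that stores visit statistics, a prior, a reward, and a hidden state. The following is a minimal sketch in the spirit of the MuZero reference pseudocode; some of the variants below use slightly different member names (e.g. node.player instead of node.to_play, or value exposed as a property rather than a method).

class Node:
    """Search-tree node; minimal sketch following the MuZero reference pseudocode."""

    def __init__(self, prior: float):
        self.visit_count = 0
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0.0
        self.children = {}          # action -> Node
        self.hidden_state = None    # set when the node is expanded
        self.reward = 0.0
        self.parent = None          # only needed by the parent-pointer variants

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> float:
        # Mean value over all simulations that passed through this node.
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count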
def run_mcts(self, obs: np.ndarray, network: Network) -> Node:
    # Expand the root node
    root = Node(0)
    state, policy, value = network.initial_inference(obs)
    root.expand_node(0, state.squeeze().detach().numpy(), 0,
                     policy.squeeze().detach().numpy())
    root.add_exploration_noise(self.dirichlet_alpha, self.exploration_fraction)  # if train:
    min_max_stats = MinMaxStats(None)

    for _ in range(self.num_simulations):
        node = root
        search_path = [node]
        while node.expanded:  # Descend until we reach an unexpanded child
            action, node = self._select_child(node, min_max_stats)
            search_path.append(node)

        # Expand the child node
        parent = search_path[-2]
        next_state, reward, policy, value = network.recurrent_inference(
            torch.from_numpy(parent.hidden_state).unsqueeze(0),
            np.array([action]))
        node.expand_node(reward.item(), next_state.squeeze().detach().numpy(), 0,
                         policy.squeeze().detach().numpy())

        # Propagate the search result back up to the root
        self._backpropagate(search_path, value.item(), 0, min_max_stats)

    return root
def play_game(self) -> Game:
    game = Game(self.config.discount)
    min_max_stats = MinMaxStats(self.config.known_bounds)

    # Use exponential decay to reduce the temperature over time.
    temperature = max(
        self.temperature *
        (1 - self.config.temperature_decay_factor)**self.network.training_steps(),
        self.config.temperature_min)
    self.metrics_temperature(temperature)

    while not game.terminal() and len(game.history) < self.config.max_moves:
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        current_observation = game.get_observation_from_index(-1)
        network_output = self.network.initial_inference(current_observation)
        expand_node(root, game.to_play(), game.legal_actions(), network_output)
        backpropagate([root], network_output.value, game.to_play(),
                      self.config.discount, min_max_stats)
        add_exploration_noise(self.config, root)

        # We then run a Monte Carlo Tree Search using only action sequences and the
        # model learned by the network.
        run_mcts(self.config, root, game.action_history(), self.network,
                 min_max_stats)
        action = select_action(root, temperature)
        game.apply(action)
        game.store_search_statistics(root)

    return game
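play_game relies on module-level helpers expand_node and add_exploration_noise that are not shown in this section. The sketches below follow the MuZero reference pseudocode; the config field names (root_dirichlet_alpha, root_exploration_fraction) come from that pseudocode and are assumptions with respect to this particular MuZeroConfig.

import math
from typing import List
import numpy as np


def expand_node(node: Node, to_play: Player, actions: List[Action],
                network_output: NetworkOutput):
    # Attach the network predictions to the node and create one child per legal
    # action, with priors given by the softmax of the policy logits.
    node.to_play = to_play
    node.hidden_state = network_output.hidden_state
    node.reward = network_output.reward
    policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(p / policy_sum)


def add_exploration_noise(config: MuZeroConfig, node: Node):
    # Mix Dirichlet noise into the root priors to encourage exploration.
    actions = list(node.children.keys())
    noise = np.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
    frac = config.root_exploration_fraction
    for a, n in zip(actions, noise):
        node.children[a].prior = node.children[a].prior * (1 - frac) + frac * n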
def _ucb_score(self, parent: Node, child: Node,
               min_max_stats: MinMaxStats) -> float:
    """Compute the UCB score."""
    pb_c = np.log((parent.visit_count + self.pb_c_base + 1) /
                  self.pb_c_base) + self.pb_c_init
    pb_c *= np.sqrt(parent.visit_count) / (child.visit_count + 1)

    prior_score = pb_c * child.prior
    value_score = min_max_stats.normalize(child.value)
    return prior_score + value_score
def ucb_score(config: MuZeroConfig, parent: Node, child: Node,
              min_max_stats: MinMaxStats) -> float:
    pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                    config.pb_c_base) + config.pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

    prior_score = pb_c * child.prior
    if child.visit_count > 0:
        value_score = min_max_stats.normalize(
            child.reward + config.discount * child.value())
    else:
        value_score = 0
    return prior_score + value_score
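This form of ucb_score is typically consumed by a select_child helper while descending the tree; a minimal sketch, assuming the Node and MinMaxStats interfaces used above:

def select_child(config: MuZeroConfig, node: Node, min_max_stats: MinMaxStats):
    # Pick the (action, child) pair with the highest UCB score.
    action, child = max(
        node.children.items(),
        key=lambda item: ucb_score(config, node, item[1], min_max_stats))
    return action, child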
def run_mcts(self, root, num_moves):
    min_max_stats = MinMaxStats(self.config.known_bounds)

    for _ in range(self.config.num_simulations):
        # root.print()
        action, leaf, cur_moves = self.select_leaf(root, num_moves, min_max_stats)
        to_play = Player(cur_moves % self.config.game_config.num_players)

        batch_hidden_state = tf.expand_dims(leaf.parent.hidden_state, axis=0)
        network_output = self.network.recurrent_inference(
            batch_hidden_state, [action]).split_batch()[0]
        self.expand_node(node=leaf,
                         to_play=to_play,
                         actions=self.config.game_config.action_space,
                         network_output=network_output)
        self.backpropagate(leaf, network_output.value, to_play, min_max_stats)
def make_move(self, game: Game) -> Action:
    root = Node()
    min_max_stats = MinMaxStats(
        known_bounds=self.config.value_config.known_bounds)

    observation = ObservationBatch(
        tf.expand_dims(game.history.make_image(), axis=0))
    self.expand_node(node=root,
                     actions=game.legal_actions(),
                     network_output=self.network.initial_inference(
                         observation).split_batch()[0])
    self.add_exploration_noise(root)
    self.run_mcts(root, min_max_stats)

    action_space = self.config.action_space()
    policy = [
        root.children[a].visit_count / root.visit_count
        if a in root.children else 0 for a in action_space
    ]
    game.store_search_statistics(root.value, Policy(tf.constant(policy)))
    return self.select_action(root, len(game.history))
def _backpropagate(self, search_path: List[Node], value: float, player: int,
                   min_max_stats: MinMaxStats):
    for node in reversed(search_path):
        node.value_sum += value if node.player == player else -value
        node.visit_count += 1
        min_max_stats.update(node.value)
        value = node.reward + self.discount * value
def ucb_score(self, node: Node, min_max_stats: MinMaxStats) -> float:
    # Exploitation term: reward plus discounted value, or the configured
    # default score while the node's value is still NaN.
    exploitation_score = (self.config.mcts_config.default_value
                          if isnan(node.value)
                          else node.reward + self.effective_discount * node.value)
    exploration_score = node.prior * self.config.exploration_function(
        node.parent.visit_count, node.visit_count)
    return min_max_stats.normalize(exploitation_score) + exploration_score
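Every snippet above depends on a MinMaxStats helper that tracks the smallest and largest values seen during the search and rescales value estimates into [0, 1]. A minimal sketch along the lines of the MuZero reference pseudocode, assuming known_bounds is either None or an object with min and max attributes:

MAXIMUM_FLOAT_VALUE = float('inf')


class MinMaxStats:
    """Tracks the minimum and maximum values observed in the search tree."""

    def __init__(self, known_bounds=None):
        self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
        self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

    def update(self, value: float) -> None:
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value: float) -> float:
        # Only normalize once both bounds have actually been observed.
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value

Normalizing per search tree keeps the value term on the same scale as the prior term in the UCB formulas above, which matters in environments whose returns are not bounded in advance.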