class VMCTSPlayer(object):
    def __init__(self, cfg: OthelloConfig):
        self._name = "Vanilla MCTS"
        self._cfg = cfg
        self._game = Othello(self._cfg)
        self._node = VMCTSNode(self._cfg, self._game)

    def name(self) -> str:
        return self._name

    def game(self) -> Othello:
        return self._game

    def play(self, action: int):
        if action not in self._game.legal_actions():
            raise ValueError(str(action) + " is an invalid move")
        # Reuse the search subtree for the chosen action when it exists;
        # otherwise start a fresh node from a cloned game state.
        child_node = self._node.child(action)
        if child_node is None:
            child_game = self._game.clone()
            child_game.apply_action(action)
            child_node = VMCTSNode(self._cfg, child_game)
        self._node = child_node
        self._game = self._node.game()

    def choose_action(self) -> int:
        if self._game.is_terminal():
            return -1
        for _ in range(self._cfg.num_simulations_vmcts):
            vmcts(self._node, self._cfg)
        return self._node.select_optimal_action()
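# vmcts() and VMCTSNode are used above but not listed here. As a reference
# point, the sketch below is one generic UCT simulation -- selection by UCB1,
# single-node expansion, a uniformly random rollout, and backpropagation. It
# is NOT this repository's implementation; game.current_player() is an
# assumed accessor, while clone()/apply_action()/legal_actions()/is_terminal()
# and returns() match the Othello API used elsewhere in this section.
import math
import random


class UCTNode:
    def __init__(self, game, parent=None):
        self.game = game
        self.parent = parent
        self.children = {}      # action -> UCTNode
        self.visits = 0
        self.total_value = 0.0


def uct_simulate(root: "UCTNode", c_uct: float = 1.4):
    node = root
    # 1. Selection: descend while every legal action has been tried.
    while not node.game.is_terminal() and len(node.children) == len(node.game.legal_actions()):
        parent = node

        def ucb(item):
            _, ch = item
            q = ch.total_value / ch.visits if ch.visits else 0.0
            u = c_uct * math.sqrt(math.log(parent.visits + 1) / (ch.visits + 1))
            return q + u

        _, node = max(parent.children.items(), key=ucb)
    # 2. Expansion: add one untried child, if the game is not over.
    if not node.game.is_terminal():
        action = random.choice([a for a in node.game.legal_actions() if a not in node.children])
        child_game = node.game.clone()
        child_game.apply_action(action)
        node.children[action] = UCTNode(child_game, parent=node)
        node = node.children[action]
    # 3. Rollout: play random moves until the game ends.
    rollout = node.game.clone()
    while not rollout.is_terminal():
        rollout.apply_action(random.choice(rollout.legal_actions()))
    returns = rollout.returns()
    # 4. Backpropagation: credit each node from the perspective of the
    #    player who moved into it (its parent's player to move).
    while node is not None:
        node.visits += 1
        if node.parent is not None:
            node.total_value += returns[node.parent.game.current_player()]
        node = node.parent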
def get_new_node(
        cfg: OthelloConfig, game: Othello, network: Network,
        device: torch.device) -> Tuple[Node, np.ndarray, np.ndarray]:
    state_tensor = image_to_tensor(game.make_input_image(), device)
    with torch.no_grad():
        p, v = network.inference(state_tensor)
    actions_mask = torch.as_tensor(
        game.legal_actions_mask(), dtype=torch.float32).to(device)
    # Zero out illegal actions and renormalise the policy.
    p = filter_legal_action_probs(p.unsqueeze(0), actions_mask.unsqueeze(0))
    p = p.squeeze(0)  # squeeze() is not in-place, so keep the returned tensor
    p, v = p.cpu().numpy(), v.cpu().numpy()
    new_node = Node(cfg, game, p)
    return new_node, p, v
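# filter_legal_action_probs() is called above but not listed here. A plausible
# implementation (an assumption, not necessarily this repository's): zero the
# probabilities of illegal actions and renormalise each batch row so the
# remaining legal probabilities sum to 1.
import torch


def filter_legal_action_probs(p: torch.Tensor, actions_mask: torch.Tensor) -> torch.Tensor:
    # p:            (batch, num_actions) policy output of the network
    # actions_mask: (batch, num_actions), 1.0 for legal actions, else 0.0
    masked = p * actions_mask
    return masked / masked.sum(dim=1, keepdim=True).clamp_min(1e-12)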
def generate_training_data(
        cfg: OthelloConfig, game: Othello, target_policies: np.ndarray,
        final_returns: np.ndarray
) -> List[Tuple[np.ndarray, np.ndarray, float, np.ndarray]]:
    assert len(target_policies) == len(game)
    # Rolling window over the most recent board states, newest at index 0.
    dq = deque(maxlen=cfg.total_input_channels // 2)
    training_data = []  # list of (input_image, pi, z, action_mask)
    for _ in range(cfg.total_input_channels // 2):
        dq.appendleft(np.zeros((2, 8, 8), dtype=bool))
    for i in range(len(game)):
        img = game.history_state(i)
        player = game.history_player(i)
        action_mask = game.history_actions_mask(i)
        dq.appendleft(img)
        x = np.zeros((cfg.total_input_channels, 8, 8), dtype=bool)
        for ch, planes in enumerate(dq):
            x[ch] |= planes[0]
            x[(cfg.total_input_channels // 2) + ch] |= planes[1]
        x[-1] |= bool(player)  # constant player-to-move plane
        training_data.append(
            (x, target_policies[i], float(final_returns[player]), action_mask))
    return training_data
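# A quick shape check of the plane stacking above, with a hypothetical
# total_input_channels = 7 (i.e. 3 history steps of 2 piece planes each plus
# the final player-to-move plane; the real value comes from OthelloConfig):
import numpy as np
from collections import deque

total_input_channels = 7                       # hypothetical, for illustration
dq = deque(maxlen=total_input_channels // 2)   # the 3 most recent states
for _ in range(total_input_channels // 2):
    dq.appendleft(np.zeros((2, 8, 8), dtype=bool))
x = np.zeros((total_input_channels, 8, 8), dtype=bool)
for ch, planes in enumerate(dq):               # ch = 0 is the newest state
    x[ch] |= planes[0]                                  # first piece plane
    x[(total_input_channels // 2) + ch] |= planes[1]    # second piece plane
x[-1] |= True                                  # player 1 is to move
print(x.shape)  # (7, 8, 8)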
class HumanPlayer(object):
    def __init__(self, cfg: OthelloConfig):
        self._name = "Human"
        self._cfg = cfg
        self._game = Othello(self._cfg)

    def name(self) -> str:
        return self._name

    def game(self) -> Othello:
        return self._game

    def play(self, action: int):
        if action not in self._game.legal_actions():
            raise ValueError(str(action) + " is an invalid move")
        self._game.apply_action(action)

    def choose_action(self) -> int:
        if self._game.is_terminal():
            return -1
        return int(input("Choose an action: "))
class AZPlayer(object):
    def __init__(self, cfg: OthelloConfig, network: Network, device: torch.device):
        self._name = "AlphaZero"
        self._cfg = cfg
        self._network = network
        self._network.eval()
        self._device = device
        self._game = Othello(self._cfg)
        self._node, *_ = Node.get_new_node(self._cfg, self._game, self._network, self._device)

    def name(self) -> str:
        return self._name

    def game(self) -> Othello:
        return self._game

    def play(self, action: int):
        if action not in self._game.legal_actions():
            raise ValueError(str(action) + " is an invalid move")
        # Reuse the search subtree when possible; otherwise evaluate the
        # child position with the network and expand a fresh node.
        child_node = self._node.child(action)
        if child_node is None:
            child_game = self._game.clone()
            child_game.apply_action(action)
            child_node, *_ = Node.get_new_node(self._cfg, child_game, self._network, self._device)
        self._node = child_node
        self._game = self._node.game()

    def choose_action(self) -> int:
        if self._game.is_terminal():
            return -1
        for _ in range(self._cfg.num_simulations_eval_player):
            mcts(self._node, self._cfg, self._network, self._device)
        return self._node.select_optimal_action()
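# A minimal match loop between the two players above (assumed import paths,
# a hypothetical checkpoint file, and an assumed Othello.current_player()
# accessor). Both players receive every move via play(), so their internal
# game copies and search trees stay in sync.
import torch
from othello import OthelloConfig            # assumed import path
from network import Network                  # assumed import path
from players import AZPlayer, HumanPlayer    # assumed import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = OthelloConfig()
net = Network()
net.load_state_dict(torch.load("checkpoint.pt", map_location=device))  # hypothetical file
net.to(device)

az, human = AZPlayer(cfg, net, device), HumanPlayer(cfg)
players = {0: az, 1: human}                  # assumed: player 0 moves first
while not az.game().is_terminal():
    action = players[az.game().current_player()].choose_action()
    az.play(action)
    human.play(action)
print("Final returns:", az.game().returns())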
def __init__(self, cfg: OthelloConfig): self._name = "Vanilla MCTS" self._cfg = cfg self._game = Othello(self._cfg) self._node = VMCTSNode(self._cfg, self._game)
def __init__(self, cfg: OthelloConfig): self._name = "Human" self._cfg = cfg self._game = Othello(self._cfg)
class SelfPlayWorker(Process):
    def __init__(self, name: str, message_queue: Queue, log_queue: Queue,
                 shared_state_dicts: Dict[str, Union[Dict[str, torch.Tensor],
                                                     OrderedDict[str, torch.Tensor], int]],
                 replay_buffer: ReplayBuffer, device_name: str, cfg: OthelloConfig):
        super().__init__(name=name)
        self._message_queue = message_queue
        self._log_queue = log_queue
        self._shared_state_dicts = shared_state_dicts
        self._replay_buffer = replay_buffer
        self._cfg = cfg
        self._device = torch.device(device_name)
        self._network = Network()
        self._game = Othello(self._cfg)
        self._interrupted = False

    def run(self):
        print(self.name, "started.")
        self._network.to(self._device).eval()
        while True:
            self._load_latest_network()
            t1 = time.time()
            self._game.reset()
            target_policies = []
            node, *_ = Node.get_new_node(self._cfg, self._game, self._network, self._device)
            while not self._game.is_terminal():
                self._check_message_queue()
                if self._interrupted:
                    break
                for _ in range(self._cfg.num_simulations):
                    mcts(node, self._cfg, self._network, self._device)
                target_policy = node.get_policy()
                action = node.select_optimal_action()
                target_policies.append(target_policy)
                child = node.child(action)
                self._game = child.game()
                node = child
            if self._interrupted:
                break
            final_returns = np.array(self._game.returns()).astype(np.float32)
            target_policies = np.array(target_policies).astype(np.float32)
            training_data = generate_training_data(self._cfg, self._game, target_policies, final_returns)
            self._replay_buffer.save_training_data(training_data)
            t2 = time.time()
            if self._cfg.debug:
                print(self.name, "completed one self-play game in", t2 - t1, "seconds.")
        print(self.name, "terminated.")

    def _check_message_queue(self):
        if not self._message_queue.empty():
            msg = self._message_queue.get()
            if msg == self._cfg.message_interrupt:
                self._interrupted = True

    # noinspection DuplicatedCode
    def _load_latest_network(self):
        # Spin until the trainer has published weights under "network",
        # moving them to this worker's device before loading.
        while True:
            try:
                state_dict = self._shared_state_dicts["network"]
                for k, v in state_dict.items():
                    state_dict[k] = v.to(self._device)
                self._network.load_state_dict(state_dict)
                self._network.eval()
                return
            except KeyError:
                pass
            self._check_message_queue()
            if self._interrupted:
                return
            time.sleep(1.0)
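# A minimal launch sketch for the worker above (the Manager-backed shared
# dict and the ReplayBuffer constructor are assumptions about this repo's
# wiring): start two CPU workers, publish initial weights, let them self-play
# for a while, then interrupt them.
import time
from multiprocessing import Manager, Queue

if __name__ == "__main__":
    cfg = OthelloConfig()
    manager = Manager()
    shared_state_dicts = manager.dict()
    shared_state_dicts["network"] = Network().state_dict()  # initial weights
    replay_buffer = ReplayBuffer(cfg)          # assumed constructor signature
    log_queue = Queue()

    workers, queues = [], []
    for i in range(2):
        q = Queue()
        w = SelfPlayWorker(f"self-play-{i}", q, log_queue,
                           shared_state_dicts, replay_buffer, "cpu", cfg)
        w.start()
        workers.append(w)
        queues.append(q)

    time.sleep(600)                            # generate games for ten minutes
    for q in queues:
        q.put(cfg.message_interrupt)           # ask each worker to stop
    for w in workers:
        w.join()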