class VMCTSPlayer(object):
    """Player whose moves come from vanilla Monte Carlo tree search.

    The player owns an ``Othello`` game and the ``VMCTSNode`` built for that
    game; ``play`` moves both forward together so they stay in sync.
    """

    def __init__(self, cfg: OthelloConfig):
        """Create a fresh game and a search node for it from ``cfg``."""
        self._cfg = cfg
        self._name = "Vanilla MCTS"
        game = Othello(cfg)
        self._game = game
        self._node = VMCTSNode(cfg, game)

    def name(self) -> str:
        """Return the display name of this player."""
        return self._name

    def game(self) -> Othello:
        """Return the game object the player is currently tracking."""
        return self._game

    def play(self, action: int):
        """Advance the tracked game and search node by ``action``.

        Raises:
            ValueError: if ``action`` is not among the game's legal actions.
        """
        if action not in self._game.legal_actions():
            raise ValueError(str(action) + " is invalid move")
        next_node = self._node.child(action)
        if next_node is None:
            # The current node has no child for this action: build one from
            # a clone of the game with the move applied.
            next_game = self._game.clone()
            next_game.apply_action(action)
            next_node = VMCTSNode(self._cfg, next_game)
        self._node = next_node
        self._game = next_node.game()

    def choose_action(self) -> int:
        """Run the configured number of searches and return the chosen action.

        Returns -1 without searching when the game is already terminal.
        """
        if self._game.is_terminal():
            return -1
        for _ in range(self._cfg.num_simulations_vmcts):
            vmcts(self._node, self._cfg)
        return self._node.select_optimal_action()
示例#2
0
 def __init__(self, cfg: OthelloConfig, network: Network,
              device: torch.device):
     """Set up an "AlphaZero" player around ``network``.

     The network is switched to eval mode, a new ``Othello`` game is
     created from ``cfg``, and the starting search node is obtained from
     ``Node.get_new_node`` (only the node of its result is kept).
     """
     self._name = "AlphaZero"
     self._cfg, self._device = cfg, device
     network.eval()
     self._network = network
     game = Othello(cfg)
     self._game = game
     self._node = Node.get_new_node(cfg, game, network, device)[0]
示例#3
0
 def get_new_node(
         cfg: OthelloConfig, game: Othello, network: Network,
         device: torch.device) -> Tuple[Node, np.ndarray, np.ndarray]:
     """Build a ``Node`` for ``game`` from one network inference.

     The game's input image is converted to a tensor on ``device`` and fed
     to ``network.inference``; the resulting policy is passed through
     ``filter_legal_action_probs`` together with the game's legal-action
     mask. Policy and value are then converted to NumPy arrays.

     Returns:
         ``(node, p, v)``: the new node, the filtered policy array it was
         built with, and the value array.
     """
     board = image_to_tensor(game.make_input_image(), device)
     with torch.no_grad():
         priors, value = network.inference(board)
         legal = torch.as_tensor(game.legal_actions_mask(),
                                 dtype=torch.float32)
         legal = legal.to(device)
         priors = filter_legal_action_probs(priors.unsqueeze(0),
                                            legal.unsqueeze(0))
         # NOTE(review): Tensor.squeeze is not in-place and the result is
         # discarded, so the batch dimension added above is NOT removed
         # here. Presumably `priors = priors.squeeze(0)` was intended -
         # confirm what shape Node and the callers expect before changing.
         priors.squeeze(0)
     priors_np = priors.cpu().numpy()
     value_np = value.cpu().numpy()
     node = Node(cfg, game, priors_np)
     return node, priors_np, value_np
 def __init__(self, name: str, message_queue: Queue, log_queue: Queue,
              shared_state_dicts: Dict[str, Union[Dict[str, torch.Tensor],
                                                  OrderedDict[str,
                                                              torch.Tensor],
                                                  int]],
              replay_buffer: ReplayBuffer, device_name: str,
              cfg: OthelloConfig):
     """Store the worker's queues, shared state and config.

     ``name`` is forwarded to the parent constructor. ``device_name`` is
     turned into a ``torch.device``; a fresh ``Network`` and a fresh
     ``Othello`` game (built from ``cfg``) are created for this worker.
     """
     super().__init__(name=name)
     # Handles shared with whoever created this worker.
     self._message_queue, self._log_queue = message_queue, log_queue
     self._shared_state_dicts = shared_state_dicts
     self._replay_buffer = replay_buffer
     # Worker-local objects.
     self._cfg = cfg
     self._device = torch.device(device_name)
     self._network = Network()
     self._game = Othello(cfg)
     self._interrupted = False
示例#5
0
def generate_training_data(
    cfg: OthelloConfig, game: Othello, target_policies: np.ndarray,
    final_returns: np.ndarray
) -> List[Tuple[np.ndarray, np.ndarray, float, np.ndarray]]:
    """Turn a recorded game into one training sample per recorded position.

    For every history index ``i`` of ``game`` an input image is assembled
    from a sliding window of the most recent
    ``cfg.total_input_channels // 2`` states (newest first, zero-padded at
    the start of the game):

    * the first half of the channels holds plane 0 of each state in the
      window,
    * the following channels hold plane 1 of each state,
    * the last channel is OR-ed with ``bool(player)`` for the player
      recorded at index ``i``.

    Args:
        cfg: configuration; only ``total_input_channels`` is read.
        game: finished game exposing ``len()``, ``history_state``,
            ``history_player`` and ``history_actions_mask``.
        target_policies: one policy per history index of ``game``.
        final_returns: final returns, indexed by player.

    Returns:
        List of ``(input_image, pi, z, action_mask)`` tuples.

    Raises:
        AssertionError: if ``target_policies`` and ``game`` differ in length.
    """
    assert len(target_policies) == len(game)
    history_len = cfg.total_input_channels // 2
    # ``np.bool_`` instead of the ``np.bool`` alias: the alias was deprecated
    # in NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    window = deque(
        (np.zeros((2, 8, 8), dtype=np.bool_) for _ in range(history_len)),
        maxlen=history_len)
    training_data = []  # list of (input_image, pi, z, action_mask)
    for i in range(len(game)):
        state = game.history_state(i)
        player = game.history_player(i)
        action_mask = game.history_actions_mask(i)
        # Newest state goes to the front; the oldest one drops off the back.
        window.appendleft(state)
        x = np.zeros((cfg.total_input_channels, 8, 8), dtype=np.bool_)
        for ch, frame in enumerate(window):
            x[ch] |= frame[0]
            x[history_len + ch] |= frame[1]
        x[-1] |= bool(player)
        training_data.append(
            (x, target_policies[i], float(final_returns[player]), action_mask))
    return training_data
示例#6
0
class HumanPlayer(object):
    """Player whose actions are typed in on standard input."""

    def __init__(self, cfg: OthelloConfig):
        """Create a fresh ``Othello`` game from ``cfg``."""
        self._cfg = cfg
        self._name = "Human"
        self._game = Othello(cfg)

    def name(self) -> str:
        """Return the display name of this player."""
        return self._name

    def game(self) -> Othello:
        """Return the game object the player is tracking."""
        return self._game

    def play(self, action: int):
        """Apply ``action`` to the tracked game.

        Raises:
            ValueError: if ``action`` is not among the game's legal actions.
        """
        legal = self._game.legal_actions()
        if action not in legal:
            raise ValueError(str(action) + " is invalid move")
        self._game.apply_action(action)

    def choose_action(self) -> int:
        """Prompt for an action and return it as an ``int``.

        Returns -1 without prompting when the game is already terminal.
        The typed text is converted with ``int()``; no legality check is
        done here.
        """
        if self._game.is_terminal():
            return -1
        return int(input("Choose an action: "))
示例#7
0
class AZPlayer(object):
    """Player whose moves come from network-guided tree search.

    The player owns an ``Othello`` game and the search ``Node`` for the
    current position; ``play`` moves both forward together.
    """

    def __init__(self, cfg: OthelloConfig, network: Network,
                 device: torch.device):
        """Put ``network`` in eval mode and build the starting node."""
        self._name = "AlphaZero"
        self._cfg, self._device = cfg, device
        network.eval()
        self._network = network
        game = Othello(cfg)
        self._game = game
        # get_new_node returns more than the node; only the node is kept.
        self._node = Node.get_new_node(cfg, game, network, device)[0]

    def name(self) -> str:
        """Return the display name of this player."""
        return self._name

    def game(self) -> Othello:
        """Return the game object the player is currently tracking."""
        return self._game

    def play(self, action: int):
        """Advance the tracked game and search node by ``action``.

        Raises:
            ValueError: if ``action`` is not among the game's legal actions.
        """
        if action not in self._game.legal_actions():
            raise ValueError(str(action) + " is invalid move")
        next_node = self._node.child(action)
        if next_node is None:
            # The current node has no child for this action: evaluate a
            # clone of the game with the move applied to get a new node.
            next_game = self._game.clone()
            next_game.apply_action(action)
            next_node = Node.get_new_node(self._cfg, next_game,
                                          self._network, self._device)[0]
        self._node = next_node
        self._game = next_node.game()

    def choose_action(self) -> int:
        """Run the configured number of searches and return the chosen action.

        Returns -1 without searching when the game is already terminal.
        """
        if self._game.is_terminal():
            return -1
        for _ in range(self._cfg.num_simulations_eval_player):
            mcts(self._node, self._cfg, self._network, self._device)
        return self._node.select_optimal_action()
 def __init__(self, cfg: OthelloConfig):
     """Create a fresh game and a ``VMCTSNode`` for it from ``cfg``."""
     self._cfg = cfg
     self._name = "Vanilla MCTS"
     game = Othello(cfg)
     self._game = game
     self._node = VMCTSNode(cfg, game)
示例#9
0
 def __init__(self, cfg: OthelloConfig):
     """Remember ``cfg`` and create a fresh ``Othello`` game from it."""
     self._cfg = cfg
     self._name = "Human"
     self._game = Othello(cfg)
示例#10
0
class SelfPlayWorker(Process):
    """Process that plays games against itself and stores training data.

    Each pass of ``run`` loads the network weights found under the
    ``"network"`` key of the shared state dicts, plays one game with tree
    search, and saves the resulting samples to the replay buffer. The loop
    ends once the interrupt message is read from the message queue.
    """

    def __init__(self, name: str, message_queue: Queue, log_queue: Queue,
                 shared_state_dicts: Dict[str, Union[Dict[str, torch.Tensor],
                                                     OrderedDict[str,
                                                                 torch.Tensor],
                                                     int]],
                 replay_buffer: ReplayBuffer, device_name: str,
                 cfg: OthelloConfig):
        """Store the shared handles and create worker-local objects."""
        super().__init__(name=name)
        # Handles shared with whoever created this worker.
        self._message_queue, self._log_queue = message_queue, log_queue
        self._shared_state_dicts = shared_state_dicts
        self._replay_buffer = replay_buffer
        # Worker-local objects.
        self._cfg = cfg
        self._device = torch.device(device_name)
        self._network = Network()
        self._game = Othello(cfg)
        self._interrupted = False

    def run(self):
        """Play games until interrupted, saving each finished game."""
        worker = super().name
        print(worker, "started.")
        self._network.to(self._device).eval()
        while True:
            self._load_latest_network()
            started = time.time()
            self._game.reset()
            policies = []
            root = Node.get_new_node(self._cfg, self._game, self._network,
                                     self._device)[0]
            while not self._game.is_terminal():
                self._check_message_queue()
                if self._interrupted:
                    break
                for _ in range(self._cfg.num_simulations):
                    mcts(root, self._cfg, self._network, self._device)
                # The policy is read before the action is selected, and
                # then the search continues from the chosen child.
                policies.append(root.get_policy())
                action = root.select_optimal_action()
                root = root.child(action)
                self._game = root.game()
            if self._interrupted:
                # An interrupted game is dropped, not saved.
                break
            final_returns = np.array(self._game.returns()).astype(np.float32)
            policy_array = np.array(policies).astype(np.float32)
            self._replay_buffer.save_training_data(
                generate_training_data(self._cfg, self._game, policy_array,
                                       final_returns))
            elapsed = time.time() - started
            if self._cfg.debug:
                print(worker, "completed one simulation in", elapsed,
                      "seconds.")
        print(worker, "terminated.")

    def _check_message_queue(self):
        """Read one pending message, if any, and note an interrupt request."""
        if self._message_queue.empty():
            return
        msg = self._message_queue.get()
        if msg == self._cfg.message_interrupt:
            self._interrupted = True

    # noinspection DuplicatedCode
    def _load_latest_network(self):
        """Load the shared ``"network"`` weights, waiting until they exist.

        While the key is missing the message queue is polled once per
        second; the method returns without loading if an interrupt arrives.
        """
        while True:
            try:
                weights = self._shared_state_dicts["network"]
                for key in list(weights):
                    weights[key] = weights[key].to(self._device)
                self._network.load_state_dict(weights)
                self._network.eval()
            except KeyError:
                pass
            else:
                return
            self._check_message_queue()
            if self._interrupted:
                return
            time.sleep(1.0)