Example #1
from functools import partial

import numpy as np

# MCTS, Player, config and available_moves come from the surrounding project.


class BasicPlayer(Player):
    """Pure-MCTS player that uses a uniform prior instead of a learned network."""

    def __init__(self, game, num_playouts: int, c_puct: float = 5):
        self.num_playouts = num_playouts
        self._mcts = MCTS(partial(self.policy_value_fn, to_dict=True), c_puct)
        self._self_play = False
        self._c_puct = c_puct
        super().__init__(game)

    def get_action(self, last_move, return_probs=False, temperature=0.1):
        # Advance the tree past the opponent's last move, then search.
        if last_move:
            self._mcts.update_with_move(last_move)
        action_probs = self._mcts.get_action_probs(
            self._game, num_playouts=self.num_playouts, temperature=temperature)
        actions, probs = zip(*action_probs.items())
        # Sample a move from the visit-count distribution returned by MCTS.
        i = np.random.choice(np.arange(len(actions)), p=np.asarray(probs))
        action = actions[i]
        self._mcts.update_with_move(action)
        if return_probs:
            # Scatter the sparse move probabilities onto the flat board vector.
            full_probs = np.zeros(config.BOARD_SIZE * config.BOARD_SIZE)
            rows, cols = zip(*actions)
            full_probs[np.array(rows) * config.BOARD_SIZE + np.array(cols)] = np.array(probs)
            return action, full_probs
        else:
            return action

    def policy_value_fn(self, board, to_dict=False):
        # Uniform prior over all legal moves and a neutral value estimate.
        # `to_dict` exists only for signature compatibility with AlphaPlayer.
        value = 0
        moves = available_moves(board)
        p = 1 / len(moves)
        prior_probs = {move: p for move in moves}
        return prior_probs, value

    def reset(self):
        # Discard the accumulated search tree before a new game.
        self._mcts = MCTS(partial(self.policy_value_fn, to_dict=True), self._c_puct)
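
A minimal driver loop for two such players, shown as a sketch: the Game class and its is_over()/play(move) methods are assumptions, not part of the example above. Each player receives the opponent's last move so both search trees stay in sync.

# Hypothetical usage sketch; Game, is_over() and play() are assumptions.
game = Game()
players = [BasicPlayer(game, num_playouts=400),
           BasicPlayer(game, num_playouts=400)]

last_move, turn = None, 0
while not game.is_over():
    move = players[turn].get_action(last_move)  # advances this player's tree
    game.play(move)
    last_move = move  # the opponent's tree is advanced on its next call
    turn = 1 - turn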
Example #2
from functools import partial

import numpy as np
import torch as th
from torch import nn

# MCTS, Player, config, available_moves and board_to_state come from the
# surrounding project.


class AlphaPlayer(Player):
    """AlphaZero-style player whose priors and value come from a neural network."""

    def __init__(self, game, network: nn.Module, c_puct: float = 5):
        self._network = network
        network.eval()  # inference mode: freezes dropout/batch-norm behaviour
        self._c_puct = c_puct
        self._mcts = MCTS(partial(self.policy_value_fn, to_dict=True), c_puct)
        self._self_play = False
        super().__init__(game)

    def get_action(self, last_move, return_probs=False, temperature=0.1):
        if last_move:
            self._mcts.update_with_move(last_move)
        action_probs = self._mcts.get_action_probs(
            self._game, num_playouts=config.NUM_PLAYOUTS, temperature=temperature)
        actions, probs = zip(*action_probs.items())
        if self._self_play:
            # AlphaZero-style exploration: mix Dirichlet noise into the search
            # probabilities during self-play.
            noise = np.random.dirichlet(0.3 * np.ones(len(actions)))
            i = np.random.choice(np.arange(len(actions)),
                                 p=np.asarray(probs) * 0.75 + noise * 0.25)
        else:
            i = np.random.choice(np.arange(len(actions)), p=np.asarray(probs))
        action = actions[i]
        self._mcts.update_with_move(action)
        if return_probs:
            # Scatter the sparse move probabilities onto the flat board vector.
            full_probs = np.zeros(config.BOARD_SIZE * config.BOARD_SIZE)
            rows, cols = zip(*actions)
            full_probs[np.array(rows) * config.BOARD_SIZE + np.array(cols)] = np.array(probs)
            return action, full_probs
        else:
            return action

    def set_self_play(self, value: bool):
        self._self_play = value

    def policy_value_fn(self, board, to_dict=False):
        # Encode the board and run the network without tracking gradients.
        x = board_to_state(board)
        x = th.tensor(x).float().to(self._network.device).unsqueeze(0)
        with th.no_grad():
            prior_probs, value = self._network(x)
        if to_dict:
            # Keep only the probabilities of the legal moves, keyed by (row, col).
            prior_probs = prior_probs.cpu().view(*board.shape).numpy()
            moves = available_moves(board)
            rows, cols = zip(*moves)
            prior_probs = dict(zip(moves, prior_probs[np.array(rows), np.array(cols)]))
            value = value[0, 0].item()
        return prior_probs, value

    def reset(self):
        # Discard the accumulated search tree before a new game.
        self._mcts = MCTS(partial(self.policy_value_fn, to_dict=True), self._c_puct)
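
policy_value_fn only requires that the network map a batched board tensor to a (policy, value) pair and expose a device attribute. Below is a minimal sketch of a compatible network, under stated assumptions: the name PolicyValueNet, the single-channel input encoding, and all layer sizes are hypothetical.

# Hypothetical network sketch; only the (policy, value) output contract and
# the `device` attribute are required by AlphaPlayer above.
class PolicyValueNet(nn.Module):
    def __init__(self, board_size: int):
        super().__init__()
        self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.policy_head = nn.Linear(32 * board_size * board_size,
                                     board_size * board_size)
        self.value_head = nn.Linear(32 * board_size * board_size, 1)

    @property
    def device(self):
        # Report the device of the parameters, as AlphaPlayer expects.
        return next(self.parameters()).device

    def forward(self, x):
        h = th.relu(self.conv(x)).flatten(1)
        policy = th.softmax(self.policy_head(h), dim=1)  # indexable probabilities
        value = th.tanh(self.value_head(h))              # shape (N, 1), in [-1, 1]
        return policy, value

Returning already-normalized probabilities matches how policy_value_fn indexes them directly; a training setup would more likely return logits and normalize inside the loss.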