Code Example #1
File: Supervised.py  Project: thomasj02/nfsp-pytorch
    def action_prob(self, infoset: LeducInfoset):
        state = infoset_to_state(infoset)
        state = torch.from_numpy(
            np.array(state)).float().unsqueeze(0).to(device)
        # Raw logits from the average-policy network, moved to CPU for numpy conversion
        nn_retval = self.network.forward(state).cpu().detach()
        # Softmax over the action dimension turns the logits into a probability distribution
        nn_retval = nn.Softmax(dim=1)(nn_retval)

        # Already on CPU and detached, so it can be converted to numpy directly
        retval = nn_retval.numpy()[0]

        return retval
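
The pattern above is: run the average-policy network, apply a softmax over the action dimension, and return the resulting probabilities as a numpy vector. A minimal, self-contained sketch of that conversion, using a throwaway linear layer in place of the repository's trained network:

import torch
import torch.nn as nn

toy_network = nn.Linear(in_features=16, out_features=3)  # stand-in for the trained policy network
toy_state = torch.randn(16)

with torch.no_grad():
    logits = toy_network(toy_state.unsqueeze(0))     # shape (1, 3): one row, one logit per action
    probs = nn.Softmax(dim=1)(logits).numpy()[0]     # probabilities over the 3 actions

print(probs, probs.sum())  # the probabilities sum to 1.0 (up to floating point error)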
Code Example #2
File: Agent.py  Project: thomasj02/nfsp-pytorch
    def action_prob(self, infoset: LeducInfoset):
        state = infoset_to_state(infoset)

        if self.use_q_policy:
            retval = self.leduc_rl_policy.action_prob(infoset)
        else:
            retval = self.leduc_supervised_policy.action_prob(infoset)

        self.last_state = state

        return retval
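
This action_prob dispatches to either the best-response (Q) policy or the average (supervised) policy depending on use_q_policy. In NFSP that flag is typically sampled once per episode using the anticipatory parameter nu; the sketch below illustrates that selection rule as an assumption about how use_q_policy is set, not as the repository's exact code:

import random

def choose_policy_for_episode(nu: float) -> bool:
    """Return True to follow the best-response (Q) policy this episode,
    False to follow the average (supervised) policy."""
    return random.random() < nu

# Example: nu = 0.1 means roughly 10% of episodes are played with the Q policy
use_q_policy = choose_policy_for_episode(0.1)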
Code Example #3
    def test_notify_reward(self):
        self.sut = LeducPoker.NFSP.Agent.NfspAgent(self.mock_q_policy, self.mock_supervised_trainer, nu=0)
        self.sut.leduc_supervised_policy.action_prob = MagicMock(return_value=[0, 1, 0])

        infoset = LeducInfoset(card=1, bet_sequences=[(PlayerActions.CHECK_CALL,), ()], board_card=None)
        infoset_state = infoset_to_state(infoset)
        self.sut.get_action(infoset)

        self.mock_q_policy.add_sars = MagicMock()

        infoset_next = LeducInfoset(card=1, bet_sequences=[(PlayerActions.CHECK_CALL, PlayerActions.BET_RAISE), ()], board_card=None)
        infoset_next_state = infoset_to_state(infoset_next)

        self.sut.notify_reward(next_infoset=infoset_next, reward=123, is_terminal=True)

        # call_args[0] holds the positional args; call_args[1] holds the keyword args
        self.assertEqual(self.mock_q_policy.add_sars.call_args[0], tuple())
        self.assertEqual(self.mock_q_policy.add_sars.call_args[1]["state"].tolist(), infoset_state.tolist())
        self.assertEqual(self.mock_q_policy.add_sars.call_args[1]["action"], PlayerActions.CHECK_CALL)
        self.assertEqual(self.mock_q_policy.add_sars.call_args[1]["reward"], 123)
        self.assertEqual(self.mock_q_policy.add_sars.call_args[1]["next_state"].tolist(), infoset_next_state.tolist())
        self.assertEqual(self.mock_q_policy.add_sars.call_args[1]["is_terminal"], True)
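
These assertions use unittest.mock's call_args, which records the most recent call as an (args, kwargs) pair. A short standalone illustration of that structure on a plain MagicMock:

from unittest.mock import MagicMock

m = MagicMock()
m(1, 2, keyword="value")

assert m.call_args[0] == (1, 2)                 # positional arguments
assert m.call_args[1] == {"keyword": "value"}   # keyword arguments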
Code Example #4
File: Trainer.py  Project: thomasj02/nfsp-pytorch
def make_agent(q_policy_parameters, supervised_trainer_parameters, nu):
    network_units = [64]
    state_size = infoset_to_state(LeducInfoset(card=0, bet_sequences=[(), ()], board_card=None)).shape[0]
    q_network_local = QNetwork(state_size=state_size, action_size=3, hidden_units=network_units).to(device)
    # No separate target network is used in this configuration
    #q_network_target = QNetwork(state_size=state_size, action_size=3, hidden_units=network_units).to(device)
    q_network_target = None

    q_policy = QPolicy(
        nn_local=q_network_local,
        nn_target=q_network_target,
        parameters=q_policy_parameters)

    supervised_network = SupervisedNetwork(state_size=state_size, action_size=3, hidden_units=network_units).to(device)
    supervised_trainer = SupervisedTrainer(
        supervised_trainer_parameters=supervised_trainer_parameters, network=supervised_network)

    return NfspAgent(q_policy=q_policy, supervised_trainer=supervised_trainer, nu=nu)
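
make_agent builds QNetwork and SupervisedNetwork from a state_size probed from an initial infoset, an action_size of 3, and a list of hidden layer widths. The repository's network classes are not shown in this listing; as an assumption, a network with that constructor signature could be a small fully connected model along these lines:

import torch.nn as nn

class TinyQNetwork(nn.Module):
    """Hypothetical stand-in for the repository's QNetwork: one hidden layer per
    entry in hidden_units, ReLU activations, and a linear output over the actions."""
    def __init__(self, state_size: int, action_size: int, hidden_units):
        super().__init__()
        layers = []
        in_features = state_size
        for width in hidden_units:
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        layers.append(nn.Linear(in_features, action_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# e.g. TinyQNetwork(state_size=30, action_size=3, hidden_units=[64])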
Code Example #5
File: Trainer.py  Project: thomasj02/nfsp-pytorch
def log_qvals(
        writer: SummaryWriter,
        policy: QPolicy,
        infoset: Optional[LeducInfoset],
        global_step: int,
        text_only: bool):
    def recurse(new_action):
        after_action_infoset = copy.deepcopy(infoset)
        after_action_infoset.add_action(new_action)
        log_qvals(writer, policy, after_action_infoset, global_step, text_only)

    if infoset is None:
        for card in range(3):
            infoset = LeducInfoset(card, bet_sequences=[(), ()], board_card=None)
            log_qvals(writer, policy, infoset, global_step, text_only)
    elif infoset.player_to_act == -1:
        for board_card in range(3):
            infoset = LeducInfoset(card=infoset.card, bet_sequences=infoset.bet_sequences, board_card=board_card)
            log_qvals(writer, policy, infoset, global_step, text_only)
    elif infoset.is_terminal:
        return
    else:
        state = infoset_to_state(infoset)
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        q_vals = policy.qnetwork_local.forward(state).detach().cpu().numpy()[0]

        node_name = "q_vals/" + str(infoset)
        node_name = node_name.replace(":", "_")

        for action in PlayerActions.ALL_ACTIONS:
            if action == PlayerActions.FOLD and infoset.can_fold:
                if not text_only:
                    writer.add_scalar(node_name+"/f", q_vals[action], global_step=global_step)
                logger.debug("Epoch %s QValue %s %s", e, node_name+"/f", q_vals[action])
                recurse(action)
            elif action == PlayerActions.BET_RAISE and infoset.can_raise:
                if not text_only:
                    writer.add_scalar(node_name+"/r", q_vals[action], global_step=global_step)
                logger.debug("Epoch %s QValue %s %s", e, node_name+"/r", q_vals[action])
                recurse(action)
            elif action == PlayerActions.CHECK_CALL:
                if not text_only:
                    writer.add_scalar(node_name + "/c", q_vals[action], global_step=global_step)
                logger.debug("Epoch %s QValue %s %s", e, node_name+"/c", q_vals[action])
                recurse(action)
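
log_qvals walks the Leduc game tree recursively and writes one scalar per (infoset, action) tag. The tagging itself is plain SummaryWriter.add_scalar usage; a minimal standalone sketch with a made-up tag and values:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/qval_demo")  # hypothetical log directory
for step, q_val in enumerate([0.10, 0.30, 0.25]):
    # Slashes in the tag group related curves together in the TensorBoard UI,
    # mirroring the q_vals/<infoset>/c naming used above
    writer.add_scalar("q_vals/example_infoset/c", q_val, global_step=step)
writer.close()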
Code Example #6
File: Agent.py  Project: thomasj02/nfsp-pytorch
    def notify_reward(self, next_infoset: Optional[LeducInfoset],
                      reward: float, is_terminal: bool):
        # Nothing to record before the agent has taken its first action
        if self.last_action is None:
            assert reward == 0
            return False

        if next_infoset is None:
            assert is_terminal

        assert self.last_state is not None
        assert self.last_action is not None

        next_state = infoset_to_state(next_infoset)
        self.q_policy.add_sars(state=self.last_state,
                               action=self.last_action,
                               reward=reward,
                               next_state=next_state,
                               is_terminal=is_terminal)
        return True
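
notify_reward hands a complete (state, action, reward, next_state, is_terminal) transition to q_policy.add_sars. The QPolicy implementation is not part of this listing; an add_sars with that signature would typically append to a bounded replay buffer, roughly like this illustrative sketch (not the repository's class):

import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "is_terminal"])

class ReplayBuffer:
    """Illustrative replay buffer, not the repository's QPolicy."""
    def __init__(self, capacity: int = 100_000):
        self.buffer = deque(maxlen=capacity)

    def add_sars(self, state, action, reward, next_state, is_terminal):
        self.buffer.append(Transition(state, action, reward, next_state, is_terminal))

    def sample(self, batch_size: int):
        return random.sample(self.buffer, batch_size)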
Code Example #7
    def test_action_prob_q(self):
        self.sut = LeducPoker.NFSP.Agent.NfspAgent(self.mock_q_policy, self.mock_supervised_trainer, nu=1.1)

        self.sut.use_q_policy = True
        self.sut.leduc_rl_policy.get_action = MagicMock(return_value=1)
        self.sut.supervised_trainer.add_observation = MagicMock()

        infoset = LeducInfoset(card=1, bet_sequences=[(PlayerActions.CHECK_CALL,), ()], board_card=None)
        infoset_state = infoset_to_state(infoset)

        retval = self.sut.action_prob(infoset)

        self.assertListEqual([0, 1, 0], retval.tolist())
        self.assertEqual(infoset_state.tolist(), self.sut.last_state.tolist())

        self.sut.leduc_rl_policy.get_action.assert_called_with(infoset)

        self.assertEqual(self.sut.supervised_trainer.add_observation.call_args[0][0].tolist(), infoset_state.tolist())
        self.assertEqual(self.sut.supervised_trainer.add_observation.call_args[0][1], 1)
Code Example #8
File: Dqn.py  Project: thomasj02/nfsp-pytorch
    def get_action(self, infoset: LeducInfoset) -> PlayerActions:
        state = infoset_to_state(infoset)
        q_policy_action, self.last_action_greedy = self.q_policy.act(
            state, greedy=False)
        return q_policy_action
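
q_policy.act(state, greedy=False) returns both the selected action and whether the selection was greedy. A common way to implement such a method is epsilon-greedy selection over the network's Q-values; the sketch below shows that rule in isolation (epsilon and the Q-value source are assumptions, not the repository's code):

import random

def epsilon_greedy(q_values, epsilon: float = 0.1):
    """Return (action_index, was_greedy) for a sequence of Q-values."""
    if random.random() < epsilon:
        return random.randrange(len(q_values)), False   # explore
    best = max(range(len(q_values)), key=lambda i: q_values[i])
    return best, True                                    # exploit

action, was_greedy = epsilon_greedy([0.2, 0.5, -0.1])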
Code Example #9
File: Dqn.py  Project: thomasj02/nfsp-pytorch
    def action_prob(self, infoset: LeducInfoset):
        state = infoset_to_state(infoset)
        return self.q_policy.get_action_probs(state)