# Example 1
def log_strategy(writer: SummaryWriter, policy: NnPolicyWrapper,
                 global_step: int):
    """Log the policy's aggressive-action probability at every Kuhn
    poker decision point to TensorBoard.

    Args:
        writer: TensorBoard summary writer receiving the scalars.
        policy: Wrapped policy network queried for action probabilities.
        global_step: Training step used as the scalar's x-axis value.
    """
    # Every non-terminal betting sequence in Kuhn poker, paired with the
    # TensorBoard node suffix naming whose decision it is. Previously this
    # was four copy-pasted blocks differing only in these two values.
    decision_points = (
        ((), "p0_open"),
        ((0, ), "p0_check/p1"),
        ((0, 1), "p0_check/p1_bet/p0"),
        ((1, ), "p0_bet/p1"),
    )

    infoset = KuhnInfoset(0, ())
    for card in range(3):  # Kuhn poker uses a 3-card deck.
        infoset.card = card
        for bet_sequence, suffix in decision_points:
            infoset.bet_sequence = bet_sequence
            aggressive_action_prob = policy.aggressive_action_prob(infoset)
            node_name = "strategy/%s/%s" % (card_to_str(card), suffix)
            writer.add_scalar(node_name,
                              aggressive_action_prob,
                              global_step=global_step)
# Example 2
def log_qvals(writer: SummaryWriter, policy: QPolicy, global_step: int):
    """Log the Q-value advantage of the aggressive action (index 1 minus
    index 0) at every Kuhn poker decision point to TensorBoard.

    Args:
        writer: TensorBoard summary writer receiving the scalars.
        policy: Q-learning policy whose local Q-network is evaluated.
        global_step: Training step used as the scalar's x-axis value.
    """
    # Every non-terminal betting sequence in Kuhn poker, paired with the
    # TensorBoard node suffix naming whose decision it is. Previously this
    # was four copy-pasted blocks differing only in these two values.
    decision_points = (
        ((), "p0_open"),
        ((0, ), "p0_check/p1"),
        ((0, 1), "p0_check/p1_bet/p0"),
        ((1, ), "p0_bet/p1"),
    )

    infoset = KuhnInfoset(0, ())
    for card in range(3):  # Kuhn poker uses a 3-card deck.
        infoset.card = card
        for bet_sequence, suffix in decision_points:
            infoset.bet_sequence = bet_sequence
            state = torch.from_numpy(
                infoset_to_state(infoset)).float().unsqueeze(0).to(device)
            # Inference only: skip autograd-graph construction. Call the
            # module itself rather than .forward() so nn.Module hooks run.
            with torch.no_grad():
                q_vals = policy.qnetwork_local(state).cpu().numpy()[0]
            node_name = "q_vals/%s/%s" % (card_to_str(card), suffix)
            writer.add_scalar(node_name,
                              q_vals[1] - q_vals[0],
                              global_step=global_step)