Ejemplo n.º 1
0
def test_alive_sieve_sieve_self():
    """Check self_sieve_() after one and then two rounds of mark_dead().

    mark_dead() on its own is expected to leave out_idxes and batch_size
    alone; self_sieve_() is expected to shrink batch_size and reduce
    out_idxes to the indices (in the original batch) of the survivors.
    """
    def same(actual, expected):
        # True when no element of the two tensors differs.
        return (actual != expected).max() == 0

    n_examples = 12
    sieve = alive_sieve.AliveSieve(n_examples, enable_cuda=False)

    # Round one: examples 1, 3 and 8..11 die, leaving six survivors.
    sieve.mark_dead(
        torch.ByteTensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1]))
    survivors = torch.LongTensor([0, 2, 4, 5, 6, 7])

    assert len(sieve.alive_mask) == 12
    assert len(sieve.alive_idxes) == 6
    assert same(
        sieve.alive_mask,
        torch.ByteTensor([1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0]))
    assert same(sieve.alive_idxes, survivors)

    # Nothing has been sieved yet, so the output mapping is still identity.
    assert same(sieve.out_idxes, torch.LongTensor(list(range(12))))
    assert sieve.batch_size == 12

    sieve.self_sieve_()
    assert same(sieve.out_idxes, survivors)
    assert sieve.batch_size == 6

    # Round two: the mask is now expressed in the shrunken, 6-wide batch.
    sieve.mark_dead(torch.ByteTensor([0, 1, 1, 0, 0, 1]))
    assert same(sieve.out_idxes, survivors)
    sieve.self_sieve_()
    assert same(sieve.out_idxes, torch.LongTensor([0, 5, 6]))
Ejemplo n.º 2
0
def test_alive_sieve_init():
    """A freshly built sieve reports every example alive, indexed 0..N-1."""
    n_examples = 12
    sieve = alive_sieve.AliveSieve(n_examples, enable_cuda=False)

    assert len(sieve.alive_mask) == n_examples
    assert len(sieve.alive_idxes) == n_examples

    position = 0
    while position < n_examples:
        # Each slot is flagged alive and maps to its own position.
        assert sieve.alive_mask[position] == 1
        assert sieve.alive_idxes[position] == position
        position += 1
Ejemplo n.º 3
0
def test_alive_sieve_mark_dead():
    """mark_dead() should clear the flagged entries of alive_mask and
    leave alive_idxes listing only the surviving positions."""
    def same(actual, expected):
        # True when no element of the two tensors differs.
        return (actual != expected).max() == 0

    n_examples = 12
    sieve = alive_sieve.AliveSieve(n_examples, enable_cuda=False)
    sieve.mark_dead(
        torch.ByteTensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1]))

    # The alive mask is the complement of the dead mask above.
    expected_mask = torch.ByteTensor([1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0])
    expected_idxes = torch.LongTensor([0, 2, 4, 5, 6, 7])

    assert len(sieve.alive_mask) == 12
    assert len(sieve.alive_idxes) == 6
    assert same(sieve.alive_mask, expected_mask)
    assert same(sieve.alive_idxes, expected_idxes)
Ejemplo n.º 4
0
def test_playback():
    """Record alive masks over four rounds, then replay them.

    The recorded masks are handed to SievePlayback; iterating it is expected
    to yield (t, global_idxes) pairs for t = 0..3 while its batch_size
    shrinks 12 -> 6 -> 3 -> 2.
    """
    n_examples = 12
    recorder = alive_sieve.AliveSieve(n_examples, enable_cuda=False)

    # One dead mask per round, each sized to the batch left by earlier rounds.
    dead_per_round = [
        torch.ByteTensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1]),
        torch.ByteTensor([0, 1, 1, 0, 0, 1]),
        torch.ByteTensor([0, 0, 1]),
        torch.ByteTensor([1, 1]),
    ]
    alive_masks = []
    for round_no, dead in enumerate(dead_per_round):
        if round_no > 0:
            # Compact between rounds; there is no sieve after the last one.
            recorder.self_sieve_()
        recorder.mark_dead(dead)
        alive_masks.append(recorder.alive_mask)

    playback = alive_sieve.SievePlayback(alive_masks, enable_cuda=False)
    ts = []
    global_idxes_s = []
    batch_sizes = []
    for t, global_idxes in playback:
        ts.append(t)
        global_idxes_s.append(global_idxes)
        batch_sizes.append(playback.batch_size)

    assert len(ts) == 4
    for step in range(4):
        assert ts[step] == step

    expected_idxes = [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        [0, 2, 4, 5, 6, 7],
        [0, 5, 6],
        [0, 5],
    ]
    for step in range(4):
        wanted = torch.LongTensor(expected_idxes[step])
        assert (global_idxes_s[step] - wanted).abs().max() == 0

    expected_sizes = [12, 6, 3, 2]
    for step in range(4):
        assert batch_sizes[step] == expected_sizes[step]
Ejemplo n.º 5
0
def test_set_dead_global():
    """set_dead_global() should write into rows addressed by global index.

    After one sieve round the live batch holds global examples
    0, 2, 4, 5, 6, 7.  Marking local positions 1, 2 and 5 dead therefore
    targets global rows 2, 4 and 7 of the full-size ``target`` tensor, which
    are expected to receive rows 0, 1 and 2 of ``new_v`` respectively.
    """
    batch_size = 12
    sieve = alive_sieve.AliveSieve(batch_size, enable_cuda=False)

    dead_mask = torch.ByteTensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
    # so alive will be            1   1    1 1 1 1
    sieve.mark_dead(dead_mask)

    sieve.self_sieve_()

    # This mask is relative to the 6-wide batch left after the sieve.
    sieve.mark_dead(torch.ByteTensor([0, 1, 1, 0, 0, 1]))

    # target stays full-size (one row per original example).
    target = torch.rand(batch_size, 3)
    new_v = torch.rand(3, 3)
    sieve.set_dead_global(target, new_v)

    assert target[2][0] == new_v[0][0]
    assert target[4][0] == new_v[1][0]
    assert target[7][0] == new_v[2][0]
Ejemplo n.º 6
0
def run_episode(
        batch,
        enable_cuda,
        enable_comms,
        enable_proposal,
        prosocial,
        agent_models,
        testing,
        render=False):
    """
    Run one batch of episodes for at most 10 timesteps, the two agent models
    taking turns: ``agent_models[t % 2]`` acts at timestep ``t``.

    Examples that finish are removed from the working batch via an
    ``AliveSieve``; ``rewards`` and ``num_steps`` keep the full batch size and
    are written through the sieve's global indices.

    turning testing on means, we disable stochasticity: always pick the argmax

    Args:
        batch: keyword arguments used to build ``State``; ``batch['N']`` must
            expose ``.size()``, whose first entry is taken as the batch size.
        enable_cuda: use ``torch.cuda`` tensor constructors and move the
            state to the GPU with ``s.cuda()``.
        enable_comms: when false, the acting agent receives an all-zero
            previous message instead of ``s.m_prev``.
        enable_proposal: when false, the acting agent receives an all-zero
            previous proposal instead of ``s.last_proposal``.
        prosocial: not read anywhere in this function.
        agent_models: indexable with 0 and 1; each entry is called with the
            keywords ``pool``, ``utility``, ``m_prev``, ``prev_proposal`` and
            ``testing`` and must return ten values.
        testing: forwarded unchanged to the agent model.
        render: when true, print the episode framing and call
            ``render_action`` for as long as the first live example is global
            example 0.

    Returns:
        An 11-tuple ``(actions_by_timestep, rewards, num_steps, alive_masks,
        entropy_loss_by_agent, term_matches_argmax_count, num_policy_runs,
        utt_matches_argmax_count, utt_stochastic_draws,
        prop_matches_argmax_count, prop_stochastic_draws)``.
        ``actions_by_timestep`` and ``alive_masks`` have one entry per
        timestep actually run; the counters are sums over those timesteps of
        the values returned by the agent model (``num_policy_runs`` sums the
        live batch size instead).
    """
    max_timesteps = 10

    type_constr = torch.cuda if enable_cuda else torch
    batch_size = batch['N'].size()[0]
    s = State(**batch)
    if enable_cuda:
        s.cuda()

    sieve = alive_sieve.AliveSieve(batch_size=batch_size,
                                   enable_cuda=enable_cuda)
    actions_by_timestep = []
    alive_masks = []

    # next two tensors won't be sieved, they will stay same size throughout
    # entire batch, we will update them using sieve.out_idxes[...]
    rewards = type_constr.FloatTensor(batch_size, SEQ_LEN).fill_(0)
    num_steps = type_constr.LongTensor(batch_size).fill_(max_timesteps)
    term_matches_argmax_count = 0
    utt_matches_argmax_count = 0
    utt_stochastic_draws = 0
    num_policy_runs = 0
    prop_matches_argmax_count = 0
    prop_stochastic_draws = 0

    entropy_loss_by_agent = [
        Variable(type_constr.FloatTensor(1).fill_(0)),
        Variable(type_constr.FloatTensor(1).fill_(0))
    ]
    if render:
        print('  ')
    for t in range(max_timesteps):
        agent = t % 2

        agent_model = agent_models[agent]
        if enable_comms:
            _prev_message = s.m_prev
        else:
            # we dont strictly need to blank them, since they'll be all zeros anyway,
            # but defense in depth and all that :)
            _prev_message = type_constr.LongTensor(sieve.batch_size,
                                                   6).fill_(0)
        if enable_proposal:
            _prev_proposal = s.last_proposal
        else:
            # we do need to blank this one though :)
            _prev_proposal = type_constr.LongTensor(sieve.batch_size,
                                                    SEQ_LEN).fill_(0)
        # Feed the (possibly blanked) message to the model. Passing s.m_prev
        # directly would leave _prev_message unused and make enable_comms a
        # no-op in this function.
        (nodes, term_a, s.m_prev, this_proposal, _entropy_loss,
         _term_matches_argmax_count, _utt_matches_argmax_count,
         _utt_stochastic_draws, _prop_matches_argmax_count,
         _prop_stochastic_draws) = agent_model(
            pool=Variable(s.pool),
            utility=Variable(s.utilities[:, agent]),
            m_prev=Variable(_prev_message),
            prev_proposal=Variable(_prev_proposal),
            testing=testing
        )
        entropy_loss_by_agent[agent] += _entropy_loss
        actions_by_timestep.append(nodes)
        term_matches_argmax_count += _term_matches_argmax_count
        num_policy_runs += sieve.batch_size
        utt_matches_argmax_count += _utt_matches_argmax_count
        utt_stochastic_draws += _utt_stochastic_draws
        prop_matches_argmax_count += _prop_matches_argmax_count
        prop_stochastic_draws += _prop_stochastic_draws

        # Only render while the first live row is still global example 0.
        if render and sieve.out_idxes[0] == 0:
            render_action(t=t, s=s, term=term_a, prop=this_proposal)

        new_rewards = rewards_lib.calc_rewards(t=t, s=s, term=term_a)
        # Scatter the live-batch rewards back into the full-size tensor.
        rewards[sieve.out_idxes] = new_rewards
        s.last_proposal = this_proposal

        # An example dies when the agent terminates or its step budget
        # s.N is reached.
        sieve.mark_dead(term_a)
        sieve.mark_dead(t + 1 >= s.N)
        # Clone so later changes to the sieve's mask don't touch the record.
        alive_masks.append(sieve.alive_mask.clone())
        sieve.set_dead_global(num_steps, t + 1)
        if sieve.all_dead():
            break

        # Shrink the state first, while alive_idxes still refers to the
        # current (unsieved) batch, then shrink the sieve itself.
        s.sieve_(sieve.alive_idxes)
        sieve.self_sieve_()

    if render:
        print('  r: %.2f' % rewards[0].mean())
        print('  ')

    return (actions_by_timestep, rewards, num_steps, alive_masks,
            entropy_loss_by_agent, term_matches_argmax_count,
            num_policy_runs, utt_matches_argmax_count, utt_stochastic_draws,
            prop_matches_argmax_count, prop_stochastic_draws)