def test_alive_sieve_sieve_self():
    batch_size = 12
    sieve = alive_sieve.AliveSieve(batch_size, enable_cuda=False)
    dead_mask = torch.ByteTensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
    sieve.mark_dead(dead_mask)
    assert len(sieve.alive_mask) == 12
    assert len(sieve.alive_idxes) == 6
    assert (sieve.alive_mask != torch.ByteTensor(
        [1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0])).max() == 0
    assert (sieve.alive_idxes != torch.LongTensor([0, 2, 4, 5, 6, 7])).max() == 0
    assert (sieve.out_idxes != torch.LongTensor(
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])).max() == 0
    assert sieve.batch_size == 12

    sieve.self_sieve_()
    assert (sieve.out_idxes != torch.LongTensor([0, 2, 4, 5, 6, 7])).max() == 0
    assert sieve.batch_size == 6

    # sieve again...
    sieve.mark_dead(torch.ByteTensor([0, 1, 1, 0, 0, 1]))
    assert (sieve.out_idxes != torch.LongTensor([0, 2, 4, 5, 6, 7])).max() == 0
    sieve.self_sieve_()
    assert (sieve.out_idxes != torch.LongTensor([0, 5, 6])).max() == 0

def test_alive_sieve_init():
    batch_size = 12
    sieve = alive_sieve.AliveSieve(batch_size, enable_cuda=False)
    assert len(sieve.alive_mask) == batch_size
    assert len(sieve.alive_idxes) == batch_size
    for b in range(batch_size):
        assert sieve.alive_mask[b] == 1
        assert sieve.alive_idxes[b] == b

def test_alive_sieve_mark_dead():
    batch_size = 12
    sieve = alive_sieve.AliveSieve(batch_size, enable_cuda=False)
    dead_mask = torch.ByteTensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
    sieve.mark_dead(dead_mask)
    assert len(sieve.alive_mask) == 12
    assert len(sieve.alive_idxes) == 6
    assert (sieve.alive_mask != torch.ByteTensor(
        [1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0])).max() == 0
    assert (sieve.alive_idxes != torch.LongTensor([0, 2, 4, 5, 6, 7])).max() == 0

def test_playback():
    batch_size = 12
    sieve = alive_sieve.AliveSieve(batch_size, enable_cuda=False)
    alive_masks = []

    dead_mask = torch.ByteTensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
    sieve.mark_dead(dead_mask)
    alive_masks.append(sieve.alive_mask)
    sieve.self_sieve_()

    sieve.mark_dead(torch.ByteTensor([0, 1, 1, 0, 0, 1]))
    alive_masks.append(sieve.alive_mask)
    sieve.self_sieve_()

    sieve.mark_dead(torch.ByteTensor([0, 0, 1]))
    alive_masks.append(sieve.alive_mask)
    sieve.self_sieve_()

    sieve.mark_dead(torch.ByteTensor([1, 1]))
    alive_masks.append(sieve.alive_mask)

    sieve = alive_sieve.SievePlayback(alive_masks, enable_cuda=False)
    ts = []
    global_idxes_s = []
    batch_sizes = []
    for t, global_idxes in sieve:
        ts.append(t)
        global_idxes_s.append(global_idxes)
        batch_sizes.append(sieve.batch_size)

    assert len(ts) == 4
    assert ts[0] == 0
    assert ts[1] == 1
    assert ts[2] == 2
    assert ts[3] == 3
    assert (global_idxes_s[0] - torch.LongTensor(
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])).abs().max() == 0
    assert (global_idxes_s[1] - torch.LongTensor([0, 2, 4, 5, 6, 7])).abs().max() == 0
    assert (global_idxes_s[2] - torch.LongTensor([0, 5, 6])).abs().max() == 0
    assert (global_idxes_s[3] - torch.LongTensor([0, 5])).abs().max() == 0
    assert batch_sizes[0] == 12
    assert batch_sizes[1] == 6
    assert batch_sizes[2] == 3
    assert batch_sizes[3] == 2

def test_set_dead_global():
    batch_size = 12
    sieve = alive_sieve.AliveSieve(batch_size, enable_cuda=False)
    dead_mask = torch.ByteTensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
    # so alive will be 1 1 1 1 1 1
    sieve.mark_dead(dead_mask)
    sieve.self_sieve_()
    sieve.mark_dead(torch.ByteTensor([0, 1, 1, 0, 0, 1]))

    target = torch.rand(batch_size, 3)
    target_orig = target.clone()
    new_v = torch.rand(3, 3)
    sieve.set_dead_global(target, new_v)
    assert target[2][0] == new_v[0][0]
    assert target[4][0] == new_v[1][0]
    assert target[7][0] == new_v[2][0]

def run_episode(
        batch,
        enable_cuda,
        enable_comms,
        enable_proposal,
        prosocial,
        agent_models,
        # batch_size,
        testing,
        render=False):
    """
    Turning testing on disables stochasticity: we always pick the argmax.
    """
    type_constr = torch.cuda if enable_cuda else torch
    batch_size = batch['N'].size()[0]
    s = State(**batch)
    if enable_cuda:
        s.cuda()

    sieve = alive_sieve.AliveSieve(batch_size=batch_size, enable_cuda=enable_cuda)
    actions_by_timestep = []
    alive_masks = []

    # the next two tensors won't be sieved: they stay the same size throughout the
    # entire batch, and we update them using sieve.out_idxes[...]
    rewards = type_constr.FloatTensor(batch_size, SEQ_LEN).fill_(0)
    num_steps = type_constr.LongTensor(batch_size).fill_(10)

    term_matches_argmax_count = 0
    utt_matches_argmax_count = 0
    utt_stochastic_draws = 0
    num_policy_runs = 0
    prop_matches_argmax_count = 0
    prop_stochastic_draws = 0
    entropy_loss_by_agent = [
        Variable(type_constr.FloatTensor(1).fill_(0)),
        Variable(type_constr.FloatTensor(1).fill_(0))
    ]
    if render:
        print(' ')
    for t in range(10):
        agent = t % 2
        agent_model = agent_models[agent]

        if enable_comms:
            _prev_message = s.m_prev
        else:
            # we don't strictly need to blank them, since they'll be all zeros anyway,
            # but defense in depth and all that :)
            _prev_message = type_constr.LongTensor(sieve.batch_size, 6).fill_(0)
        if enable_proposal:
            _prev_proposal = s.last_proposal
        else:
            # we do need to blank this one though :)
            _prev_proposal = type_constr.LongTensor(sieve.batch_size, SEQ_LEN).fill_(0)

        nodes, term_a, s.m_prev, this_proposal, _entropy_loss, \
            _term_matches_argmax_count, _utt_matches_argmax_count, _utt_stochastic_draws, \
            _prop_matches_argmax_count, _prop_stochastic_draws = agent_model(
                pool=Variable(s.pool),
                utility=Variable(s.utilities[:, agent]),
                m_prev=Variable(_prev_message),
                prev_proposal=Variable(_prev_proposal),
                testing=testing
            )
        entropy_loss_by_agent[agent] += _entropy_loss
        actions_by_timestep.append(nodes)
        term_matches_argmax_count += _term_matches_argmax_count
        num_policy_runs += sieve.batch_size
        utt_matches_argmax_count += _utt_matches_argmax_count
        utt_stochastic_draws += _utt_stochastic_draws
        prop_matches_argmax_count += _prop_matches_argmax_count
        prop_stochastic_draws += _prop_stochastic_draws

        if render and sieve.out_idxes[0] == 0:
            render_action(t=t, s=s, term=term_a, prop=this_proposal)

        new_rewards = rewards_lib.calc_rewards(t=t, s=s, term=term_a)
        rewards[sieve.out_idxes] = new_rewards
        s.last_proposal = this_proposal

        sieve.mark_dead(term_a)
        sieve.mark_dead(t + 1 >= s.N)
        alive_masks.append(sieve.alive_mask.clone())
        sieve.set_dead_global(num_steps, t + 1)
        if sieve.all_dead():
            break

        s.sieve_(sieve.alive_idxes)
        sieve.self_sieve_()

    if render:
        print(' r: %.2f' % rewards[0].mean())
        print(' ')

    return actions_by_timestep, rewards, num_steps, alive_masks, entropy_loss_by_agent, \
        term_matches_argmax_count, num_policy_runs, utt_matches_argmax_count, utt_stochastic_draws, \
        prop_matches_argmax_count, prop_stochastic_draws
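

# Illustrative sketch (not part of the original code): the alive_masks returned by
# run_episode can be replayed with alive_sieve.SievePlayback, as exercised in
# test_playback above, to map per-timestep, sieved-batch quantities back onto rows
# of the full-size batch. The reward aggregation below is an assumed example for
# illustration only, not the project's training objective; `rewards` is the
# (batch_size, SEQ_LEN) tensor returned by run_episode.
def example_replay(alive_masks, rewards, enable_cuda=False):
    playback = alive_sieve.SievePlayback(alive_masks, enable_cuda=enable_cuda)
    totals = []
    for t, global_idxes in playback:
        # global_idxes holds the original batch indices still alive at timestep t;
        # playback.batch_size shrinks in step with them
        totals.append(rewards[global_idxes].sum())
    return totals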