def test_rnn_generator():
    """Tests the RNNGenerator."""
    # pylint: disable=too-many-locals,unused-variable
    import common

    batch_size = 4
    seqlen = 4
    vocab_size = 32
    tok_emb_dim = 8
    rnn_dim = 12
    num_layers = 1
    debug = True
    gen = RNNGenerator(**locals())

    # single-step generation: feed one token, sample the next from the output
    toks = Variable(torch.LongTensor(batch_size, 1).fill_(1))
    gen_probs, gen_state = gen(toks)
    gen_toks = torch.multinomial(gen_probs.exp(), 1).detach()
    gen_probs, gen_state = gen(toks=gen_toks, prev_state=gen_state)

    # test basic rollout
    init_toks = Variable(torch.LongTensor(batch_size, seqlen).fill_(1))
    ro_seqs, ro_log_probs = gen.rollout(init_toks, seqlen, 0)
    assert len(ro_seqs) == seqlen and len(ro_log_probs) == seqlen
    assert torch.np.allclose(
        torch.stack(ro_log_probs).data.exp().sum(-1).numpy(), 1)

    # test reproducibility
    init_rand_state = torch.get_rng_state()
    with common.rand_state(torch, 42) as rand_state:
        ro1, _ = gen.rollout(init_toks, 8)
    with common.rand_state(torch, rand_state):
        ro2, _ = gen.rollout(init_toks, 8)
    assert all((t1.data == t2.data).all() for t1, t2 in zip(ro1, ro2))
    assert (torch.get_rng_state() == init_rand_state).all()

    # test continuation
    rand_toks = Variable(torch.LongTensor(batch_size, 2).random_(vocab_size))
    ro_seqs, _, (ro_hid, ro_rng) = gen.rollout(
        rand_toks, 2, return_first_state=True)
    with common.rand_state(torch, ro_rng):
        next_ro, _ = gen.rollout((ro_seqs[0], ro_hid), 1)
    assert (ro_seqs[1].data == next_ro[0].data).all()

    # test double-backward
    sum(gen_probs).sum().backward(create_graph=True)
    sum(p.grad.norm() for p in gen.parameters(dx2=True)).backward()
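# Everything here leans on `common.rand_state`, whose implementation is not
# shown. A minimal sketch consistent with how the tests use it -- seed (or
# restore) the RNG on entry, yield a replayable state, and restore the outer
# state on exit -- might look like this; the real helper may differ.
import contextlib

@contextlib.contextmanager
def rand_state(th, seed_or_state):
    """Run the block under a seeded RNG, then restore the outer RNG state."""
    saved_state = th.get_rng_state()
    if isinstance(seed_or_state, int):
        th.manual_seed(seed_or_state)
    else:
        th.set_rng_state(seed_or_state)
    try:
        # yield the state at the start of the block so callers can replay it
        yield th.get_rng_state()
    finally:
        th.set_rng_state(saved_state)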
def __init__(self, generator, label, seqlen, num_samples, gen_init_toks,
             seed, eos_idx=None, **unused_kwargs):
    super(GenDataset, self).__init__()
    self.label = label
    th = torch.cuda if gen_init_toks.is_cuda else torch
    with common.rand_state(th, seed):
        init_toks = gen_init_toks.data.cpu()
        batch_size = gen_init_toks.size(0)
        num_batches = (num_samples + batch_size - 1) // batch_size
        samples = []
        for _ in range(num_batches):
            gen_seqs, _ = generator.rollout(gen_init_toks, seqlen)
            samps = torch.cat([gen_init_toks] + gen_seqs, -1).data
            if eos_idx is not None:  # `if eos_idx:` would skip an index of 0
                self.mask_gen_seqs_(samps, eos_idx)
            samples.append(samps.cpu())
        self.samples = torch.cat(samples)
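# `mask_gen_seqs_` is called above but not defined here. A plausible sketch,
# assuming it pads out everything after the first EOS token in each sequence
# in place; both the behavior and the `pad_idx=0` default are assumptions.
@staticmethod
def mask_gen_seqs_(gen_seqs, eos_idx, pad_idx=0):
    """In place, replace every token after the first EOS with `pad_idx`."""
    for seq in gen_seqs:
        eos_positions = (seq == eos_idx).nonzero()
        if len(eos_positions) > 0:
            # keep the EOS itself, pad the rest of the sequence
            seq[eos_positions[0][0] + 1:] = pad_idx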
def _compute_eval_metric(self, num_samples=256):
    """Estimates the oracle NLL of generator rollouts over `num_samples`."""
    test_nll = 0
    num_test_batches = max(num_samples // len(self.init_toks), 1)
    with common.rand_state(torch.cuda, -1):
        for _ in range(num_test_batches):
            gen_seqs, _ = self.g.rollout(self.init_toks, self.opts.seqlen)
            test_nll += self.compute_oracle_nll(gen_seqs)
    test_nll /= num_test_batches
    return test_nll
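# `compute_oracle_nll` is assumed to score the rolled-out tokens under the
# fixed oracle. A rough sketch, assuming the oracle shares RNNGenerator's step
# interface (log-probs plus recurrent state) and that `self.oracle` and
# `self.init_toks` exist; the repo's actual helper may differ substantially.
def compute_oracle_nll(self, gen_seqs):
    """Mean negative log-likelihood of `gen_seqs` under the oracle (sketch)."""
    nll = 0
    toks, oracle_state = self.init_toks, None
    for gen_toks in gen_seqs:
        log_probs, oracle_state = self.oracle(toks, prev_state=oracle_state)
        # accumulate the log-prob the oracle assigns to each sampled token
        nll -= log_probs.gather(-1, gen_toks).mean()
        toks = gen_toks
    return nll.data[0] / len(gen_seqs)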
def _create_oracle(self):
    """Returns a randomly initialized generator."""
    with common.rand_state(torch, self.opts.seed):
        # copy the options dict: popping from vars(self.opts) directly would
        # remove rnn_dim from self.opts itself
        opt_vars = dict(vars(self.opts))
        opt_vars.pop('rnn_dim')
        oracle = model.generator.create(
            gen_type=self.opts.oracle_type,
            rnn_dim=self.opts.oracle_dim,
            **opt_vars)
        for param in oracle.parameters():
            nn.init.normal(param, std=1)
    return oracle
def _compute_d_test_acc(self, num_samples=256):
    num_test_batches = max(num_samples // len(self.init_toks), 1)

    # accuracy on real sequences from the held-out set
    test_loader = self._create_dataloader(self.test_dataset)
    acc_real = 0
    for i, (batch_toks, _) in enumerate(test_loader):
        if i == num_test_batches:
            break
        toks = Variable(batch_toks[:, 1:].cuda())  # no init toks
        acc_real += self.compute_acc(self.d(toks), LABEL_REAL)
    acc_real /= num_test_batches

    # accuracy on freshly sampled generator rollouts
    acc_gen = 0
    with common.rand_state(torch.cuda, -1):
        for _ in range(num_test_batches):
            gen_seqs, _ = self.g.rollout(self.init_toks, self.opts.seqlen)
            init_gen_seqs = torch.cat([self.init_toks] + gen_seqs, -1)
            dataset.GenDataset.mask_gen_seqs_(
                init_gen_seqs.data, self.opts.eos_idx)
            acc_gen += self.compute_acc(
                self.d(init_gen_seqs[:, 1:]), LABEL_GEN)
    acc_gen /= num_test_batches

    return acc_real, acc_gen
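# `compute_acc` is assumed to measure how often the discriminator assigns the
# target label. A minimal sketch, assuming `self.d` returns one real/gen logit
# per sequence and that the label constants are 0 and 1; the actual output
# shape and constants in the repo may differ.
def compute_acc(self, d_out, label):
    """Fraction of discriminator outputs classified as `label` (sketch)."""
    preds = (d_out.data.view(-1) > 0).long()
    return (preds == label).float().mean()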