def test_perplexity(self):
    """Perplexity of a single batch must equal exp() of that batch's NLL loss."""
    nll_metric = NLLLoss()
    ppl_metric = Perplexity()
    # Feed the identical (outputs, batch) pair to both metrics, NLL first.
    for metric in (nll_metric, ppl_metric):
        metric.eval_batch(self.outputs, self.batch)
    nll_value = nll_metric.get_loss()
    ppl_value = ppl_metric.get_loss()
    expected = math.exp(nll_value)
    self.assertAlmostEqual(ppl_value, expected)
def test_perplexity(self):
    """Perplexity accumulated over several batches must equal exp() of the accumulated NLL loss."""
    nll_metric, ppl_metric = NLLLoss(), Perplexity()
    for batch_output, batch_target in zip(self.outputs, self.targets):
        # Both metrics see exactly the same batches, in the same order (NLL first).
        for metric in (nll_metric, ppl_metric):
            metric.eval_batch(batch_output, batch_target)
    nll_value = nll_metric.get_loss()
    ppl_value = ppl_metric.get_loss()
    self.assertAlmostEqual(ppl_value, math.exp(nll_value))
def test_perplexity(self):
    """Perplexity over random batches must equal exp() of the NLL loss.

    Builds ``num_batch`` batches of random class distributions together with
    random integer targets. It feeds every batch to both metrics and checks that
    ``Perplexity.get_loss()`` equals ``math.exp(NLLLoss.get_loss())``.
    """
    num_class = 5
    num_batch = 10
    batch_size = 5
    # Normalize over the class axis explicitly. Calling softmax without `dim`
    # is deprecated and warns; for 2-D input the implicit choice was dim=1,
    # so the generated values are unchanged.
    # NOTE(review): these are probabilities, not log-probabilities. The
    # ppl == exp(nll) identity checked below presumably does not depend on
    # that, but confirm against the loss implementations.
    outputs = [F.softmax(Variable(torch.randn(batch_size, num_class)), dim=1)
               for _ in range(num_batch)]
    targets = [Variable(torch.LongTensor([random.randint(0, num_class - 1)
                                          for _ in range(batch_size)]))
               for _ in range(num_batch)]
    nll = NLLLoss()
    ppl = Perplexity()
    for output, target in zip(outputs, targets):
        nll.eval_batch(output, target)
        ppl.eval_batch(output, target)
    nll_loss = nll.get_loss()
    ppl_loss = ppl.get_loss()
    self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
def apply_gradient_attack(data, model, input_vocab, replace_tokens, field_name, opt,
                          max_gradient_len=250):
    """Choose adversarial token replacements for every example in ``data``.

    Each batch's input tokens are converted to one-hot vectors. The model is run
    on them and the Perplexity loss is back-propagated, which gives the gradient
    of the loss with respect to every (position, vocabulary entry). Those
    gradients are passed to ``get_best_token_replacement``. A batch whose longest
    input exceeds ``max_gradient_len`` skips the gradient computation, which
    would run out of memory, and uses ``get_random_token_replacement`` instead.

    Args:
        data: torchtext dataset to attack. Batches are read for the ``index``,
            ``tgt`` and ``field_name`` attributes, and examples need ``src``
            for bucketing.
        model: seq2seq model. It is called as
            ``model(onehot, lengths, targets, already_one_hot=True)`` and must
            return ``(decoder_outputs, decoder_hidden, other)``.
        input_vocab: vocabulary of ``field_name``; its length sets the one-hot width.
        replace_tokens: forwarded unchanged to the replacement helpers.
        field_name: batch attribute holding ``(token_ids, lengths)``.
        opt: options object; ``opt.batch_size`` and ``opt.distinct`` are read.
        max_gradient_len: batches whose max input length exceeds this use the
            random fallback. Defaults to 250, the previously hard-coded value.

    Returns:
        dict: the union of the dicts returned by the replacement helpers for
        all batches. Later batches overwrite earlier keys.

    NOTE(review): relies on module-level ``device`` and ``tgt`` (the target
    field), which are not parameters. Confirm they are defined by the caller's module.
    """
    def convert_to_onehot(inp, vocab_size):
        # Rank-2 token ids -> rank-3 float one-hot, with the vocab as the new last axis.
        return torch.zeros(inp.size(0), inp.size(1), vocab_size,
                           device=device).scatter_(2, inp.unsqueeze(2), 1.)

    batch_iterator = torchtext.data.BucketIterator(
        dataset=data, batch_size=opt.batch_size,
        sort=True, sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        device=device, repeat=False)
    batch_generator = iter(batch_iterator)

    # Perplexity loss over the target vocabulary; pad index passed as the mask.
    weight = torch.ones(len(tgt.vocab)).half()
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    # NOTE(review): training mode is kept from the original, so dropout (if any)
    # is active during the attack. Presumably this is required so the backward
    # pass works (e.g. cuDNN RNN backward); confirm.
    model.train()

    replacements = {}
    for batch in tqdm.tqdm(batch_generator, total=len(batch_iterator)):
        indices = getattr(batch, 'index')
        input_variables, input_lengths = getattr(batch, field_name)
        target_variables = getattr(batch, 'tgt')

        # Long inputs would OOM under the gradient attack; use random replacements.
        if max(input_lengths) > max_gradient_len:
            replacements.update(get_random_token_replacement(
                input_variables.cpu().numpy(), input_vocab,
                indices.cpu().numpy(), replace_tokens, opt.distinct))
            continue

        # The one-hot input is what we differentiate with respect to.
        input_onehot = Variable(
            convert_to_onehot(input_variables, vocab_size=len(input_vocab)),
            requires_grad=True).half()

        decoder_outputs, decoder_hidden, other = model(
            input_onehot, input_lengths, target_variables, already_one_hot=True)

        # Accumulate the loss over decoding steps. Step `step` predicts target
        # token `step + 1`, because column 0 is skipped.
        loss.reset()
        batch_size = target_variables.size(0)
        for step, step_output in enumerate(decoder_outputs):
            loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                            target_variables[:, step + 1])

        model.zero_grad()
        # `.half()` made input_onehot a non-leaf tensor, so its .grad is only
        # populated if explicitly retained before backward().
        input_onehot.retain_grad()
        # NOTE(review): only one backward pass is made per graph, so
        # retain_graph=True looks unnecessary. It is kept because the project's
        # Loss.backward signature is not visible here.
        loss.backward(retain_graph=True)
        grads = input_onehot.grad
        del input_onehot

        replacements.update(get_best_token_replacement(
            input_variables.cpu().numpy(), grads.cpu().numpy(), input_vocab,
            indices.cpu().numpy(), replace_tokens, opt.distinct))

        # Release this batch's autograd graph now. With retain_graph=True its
        # saved activations would otherwise stay alive, through these names and
        # the loss accumulator, until the next batch's forward pass finished.
        del decoder_outputs, decoder_hidden, other, grads
        loss.reset()
    return replacements