Example 1
    def test_perplexity(self):
        nll = NLLLoss()
        ppl = Perplexity()
        nll.eval_batch(self.outputs, self.batch)
        ppl.eval_batch(self.outputs, self.batch)

        nll_loss = nll.get_loss()
        ppl_loss = ppl.get_loss()

        self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
Example 2
    def test_perplexity(self):
        nll = NLLLoss()
        ppl = Perplexity()
        for output, target in zip(self.outputs, self.targets):
            nll.eval_batch(output, target)
            ppl.eval_batch(output, target)

        nll_loss = nll.get_loss()
        ppl_loss = ppl.get_loss()

        self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
Example 3
    def test_perplexity(self):
        num_class = 5
        num_batch = 10
        batch_size = 5

        outputs = [F.softmax(Variable(torch.randn(batch_size, num_class)), dim=1)
                   for _ in range(num_batch)]
        targets = [Variable(torch.LongTensor([random.randint(0, num_class - 1)
                                              for _ in range(batch_size)]))
                   for _ in range(num_batch)]

        nll = NLLLoss()
        ppl = Perplexity()
        for output, target in zip(outputs, targets):
            nll.eval_batch(output, target)
            ppl.eval_batch(output, target)

        nll_loss = nll.get_loss()
        ppl_loss = ppl.get_loss()

        self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
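
All three tests above lean on the seq2seq library's NLLLoss and Perplexity classes and on fixtures (self.outputs, self.targets, self.batch) defined elsewhere in the test class. The identity they assert, perplexity = exp(mean negative log-likelihood), can be checked with stock PyTorch alone. A minimal self-contained sketch, in which F.nll_loss stands in for the library's NLLLoss and the perplexity is computed directly as the inverse geometric mean of the target probabilities:

import math
import random

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_class = 5, 10

# Random predictions and targets, mirroring the setup in Example 3.
log_probs = F.log_softmax(torch.randn(batch_size, num_class), dim=1)
targets = torch.tensor([random.randint(0, num_class - 1)
                        for _ in range(batch_size)])

nll_loss = F.nll_loss(log_probs, targets)  # mean NLL over the batch

# Perplexity as the inverse geometric mean of the target probabilities:
# (prod p_i)^(-1/N) == exp(-(1/N) * sum(log p_i)) == exp(mean NLL).
target_probs = log_probs[torch.arange(batch_size), targets].exp()
ppl_direct = float(target_probs.prod() ** (-1.0 / batch_size))

assert math.isclose(ppl_direct, math.exp(nll_loss.item()), rel_tol=1e-4)
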
Example 4
def apply_gradient_attack(data, model, input_vocab, replace_tokens, field_name,
                          opt):
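    # NOTE: `device` and `tgt` are not parameters here; they are assumed to be
    # module-level globals (the torch device and the target-side torchtext field).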
    def convert_to_onehot(inp, vocab_size):
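        # Scatter 1s at the token indices along dim 2, producing a
        # (batch, seq_len, vocab_size) one-hot tensor.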
        return torch.zeros(inp.size(0), inp.size(1), vocab_size,
                           device=device).scatter_(2, inp.unsqueeze(2), 1.)

    batch_iterator = torchtext.data.BucketIterator(
        dataset=data,
        batch_size=opt.batch_size,
        sort=True,
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        device=device,
        repeat=False)
    batch_generator = batch_iterator.__iter__()

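    # All-ones class weights (in fp16 to match the model) plus the pad index;
    # the pad entry presumably lets Perplexity mask padding out of the loss.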
    weight = torch.ones(len(tgt.vocab)).half()
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()
    model.train()

    d = {}

    for batch in tqdm.tqdm(batch_generator, total=len(batch_iterator)):
        indices = getattr(batch, 'index')
        input_variables, input_lengths = getattr(batch, field_name)
        target_variables = getattr(batch, 'tgt')

        # Do random attack if inputs are too long and will OOM under gradient attack
        if max(getattr(batch, field_name)[1]) > 250:
            rand_replacements = get_random_token_replacement(
                input_variables.cpu().numpy(), input_vocab,
                indices.cpu().numpy(), replace_tokens, opt.distinct)

            d.update(rand_replacements)
            continue

        # convert input_variables to one_hot
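        # requires_grad=True makes the one-hot input a leaf tensor, so the
        # backward pass below yields a gradient for every (position, token) pair.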
        input_onehot = Variable(convert_to_onehot(input_variables,
                                                  vocab_size=len(input_vocab)),
                                requires_grad=True).half()

        # Forward propagation
        decoder_outputs, decoder_hidden, other = model(input_onehot,
                                                       input_lengths,
                                                       target_variables,
                                                       already_one_hot=True)

        # print outputs for debugging
        # for i, output_seq_len in enumerate(other['length']):
        #     print(i, output_seq_len)
        #     tgt_id_seq = [other['sequence'][di][i].data[0] for di in range(output_seq_len)]
        #     tgt_seq = [output_vocab.itos[tok] for tok in tgt_id_seq]
        #     print(' '.join([x for x in tgt_seq if x not in ['<sos>', '<eos>', '<pad>']]), end=', ')
        #     gt = [output_vocab.itos[tok] for tok in target_variables[i]]
        #     print(' '.join([x for x in gt if x not in ['<sos>', '<eos>', '<pad>']]))

        # Get loss
        loss.reset()
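        # Accumulate the loss one decoder step at a time; target_variables[:, step + 1]
        # shifts the targets by one, skipping the first token (presumably <sos>).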
        for step, step_output in enumerate(decoder_outputs):
            batch_size = target_variables.size(0)
            loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                            target_variables[:, step + 1])
        # Backward propagation
        model.zero_grad()
        input_onehot.retain_grad()
        loss.backward(retain_graph=True)
        grads = input_onehot.grad
        del input_onehot

        best_replacements = get_best_token_replacement(
            input_variables.cpu().numpy(),
            grads.cpu().numpy(), input_vocab,
            indices.cpu().numpy(), replace_tokens, opt.distinct)

        d.update(best_replacements)

    return d