    def test_seq2seq_max_target_length(self):
        batch = self.tokenizer.prepare_seq2seq_batch(
            self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
        )
        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
        self.assertEqual(batch.input_ids.shape[1], 3)
        self.assertEqual(batch.decoder_input_ids.shape[1], 10)
        # max_target_length will default to max_length if not specified
        batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
        self.assertEqual(batch.input_ids.shape[1], 3)
        self.assertEqual(batch.decoder_input_ids.shape[1], 3)
Example #2
    def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
            decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
            use_cache=False, is_training=False):

        if is_training:
            _decoder_input_ids = shift_tokens_right(decoder_input_ids, self.config.pad_token_id)
        else:
            _decoder_input_ids = decoder_input_ids

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
            use_cache=use_cache,
        )
        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
        if is_training:
            # loss_fct = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.config.pad_token_id)
            # loss = loss_fct(lm_logits.view(-1, self.config.vocab_size),
            #                   decoder_input_ids.view(-1))
            lprobs = F.log_softmax(lm_logits, dim=-1)
            loss, _ = label_smoothed_nll_loss(lprobs, decoder_input_ids, epsilon=0.1, ignore_index=self.config.pad_token_id)
            return loss
        return (lm_logits, ) + outputs[1:]
Example #3
    def training_step(self, batch, batch_idx):
        src, tgt = batch[0], batch[1]
        src_input = self.tokenizer.encode_batch(src, max_length=self.max_len)
        tgt_input = self.tokenizer.encode_batch(tgt, max_length=self.max_len)

        input_ids = src_input["input_ids"].to(self.device)
        attention_mask = src_input["attention_mask"].to(self.device)
        labels = tgt_input["input_ids"].to(self.device)
        decoder_input_ids = shift_tokens_right(
            labels, self.tokenizer.token2idx["[PAD]"])

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
        )

        lm_logits = outputs[0]
        loss_fn = torch.nn.CrossEntropyLoss(
            ignore_index=self.tokenizer.token2idx["[PAD]"])

        lm_loss = loss_fn(lm_logits.view(-1, lm_logits.shape[-1]),
                          labels.view(-1))
        self.save_model()
        return {"loss": lm_loss}
Example #4
def convert_to_features(example_batch: Dict[str, Any]) -> Dict[str, Any]:
    input_encodings = tokenizer.batch_encode_plus(
        example_batch["utterance"],
        max_length=32,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    target_encodings = tokenizer.batch_encode_plus(
        example_batch["code"],
        max_length=32,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )

    labels = target_encodings["input_ids"]
    decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)
    labels[labels[:, :] == model.config.pad_token_id] = -100

    encodings = {
        "input_ids": input_encodings["input_ids"].numpy().copy(),
        "attention_mask": input_encodings["attention_mask"].numpy().copy(),
        "decoder_input_ids": decoder_input_ids.numpy().copy(),
        "labels": labels.numpy().copy(),
    }

    return encodings
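
A minimal usage sketch for convert_to_features (assuming a Hugging Face datasets.Dataset with "utterance" and "code" columns, plus module-level tokenizer and model objects as in the snippet above; the data file name is hypothetical):

from datasets import load_dataset

# Hypothetical data file; any dataset with "utterance" and "code" columns works.
dataset = load_dataset("json", data_files="train.json")["train"]
dataset = dataset.map(convert_to_features, batched=True, batch_size=32)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "labels"])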
Example #5
    def _step(self, batch):
        if batch['task'][0] == 'response':
            pad_token_id = self.tokenizer.pad_token_id
            target_ids = batch['target_ids']

            decoder_input_ids = shift_tokens_right(target_ids, pad_token_id)

            outputs = self(input_ids=batch['source_ids'],
                           attention_mask=batch['source_mask'],
                           decoder_input_ids=decoder_input_ids,
                           use_cache=False,
                           task=batch['task'][0])

            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                target_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        elif batch['task'][0] in ['cfemotion', 'emotion', 'sentiment']:
            outputs = self(input_ids=batch['source_ids'],
                           attention_mask=batch['source_mask'],
                           lm_labels=batch['label'],
                           task=batch['task'][0])
            loss = outputs[0]

        else:
            raise ValueError('The dataset contains an invalid task.')

        if self.hparams.task_weights:
            loss = (self.task_weights[self.tasks.index(batch['task'][0])] *
                    loss)

        return loss
Example #6
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids,
                       attention_mask=src_mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                               tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)
        return (loss, )
Example #7
def preprocess_data_mbart(data):
    input_text, target_text, tokenizer, args = data

    tokenized_example = tokenizer.prepare_seq2seq_batch(
        src_texts=[input_text],
        tgt_texts=[target_text],
        src_lang=args.src_lang,
        tgt_lang=args.tgt_lang,
        max_length=args.max_seq_length,
        padding="max_length",  # pad_to_max_length=True won't work in this case
        return_tensors="pt",
        truncation=True,
    )

    decoder_input_ids = tokenized_example["labels"].clone()
    decoder_input_ids = shift_tokens_right(decoder_input_ids, tokenizer.pad_token_id)

    labels = tokenized_example["labels"]
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "input_ids": tokenized_example["input_ids"].squeeze(),
        "attention_mask": tokenized_example["attention_mask"].squeeze(),
        "decoder_input_ids": decoder_input_ids.squeeze(),
        "labels": labels.squeeze(),
    }
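
A minimal usage sketch for preprocess_data_mbart (assuming an MBartTokenizer and an args namespace with src_lang, tgt_lang, and max_seq_length attributes; the model name and sentence pair are illustrative):

from argparse import Namespace
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
args = Namespace(src_lang="en_XX", tgt_lang="ro_RO", max_seq_length=128)

pairs = [("UN Chief Says There Is No Military Solution in Syria",
          "Şeful ONU declară că nu există o soluţie militară în Siria")]
features = [preprocess_data_mbart((src, tgt, tokenizer, args)) for src, tgt in pairs]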
Example #8
    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            batch = self._encode(batch)
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])

            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(
                input_ids, self.pad_token_id, attention_mask=attention_mask)

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch
Example #9
    def test_shift_tokens_right(self):
        input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
        shifted = shift_tokens_right(input_ids, 1)
        n_pad_before = input_ids.eq(1).float().sum()
        n_pad_after = shifted.eq(1).float().sum()
        self.assertEqual(shifted.shape, input_ids.shape)
        self.assertEqual(n_pad_after, n_pad_before - 1)
        self.assertTrue(torch.eq(shifted[:, 0], 2).all())
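
For reference, the behavior this test checks (the last non-pad token, normally </s>, is wrapped around to position 0 and everything else shifts one place to the right) can be reproduced with a sketch like the following. It matches the old two-argument shift_tokens_right(input_ids, pad_token_id) signature used throughout these examples, not necessarily the current transformers implementation:

import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Shift input ids one token to the right and wrap the last non-pad token (usually </s>) to position 0."""
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens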
Example #10
    def test_batch_fairseq_parity(self):
        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
            self.src_text, tgt_texts=self.tgt_text, return_tensors="pt")
        batch["decoder_input_ids"] = shift_tokens_right(
            batch.labels, self.tokenizer.pad_token_id)
        for k in batch:
            batch[k] = batch[k].tolist()
        # batch = {k: v.tolist() for k,v in batch.items()}
        # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
        # batch.decoder_inputs_ids[0][0] ==
        assert batch.input_ids[1][-2:] == [2, EN_CODE]
        assert batch.decoder_input_ids[1][0] == RO_CODE
        assert batch.decoder_input_ids[1][-1] == 2
        assert batch.labels[1][-2:] == [2, RO_CODE]
Example #11
    def _step(self, batch):
        labels = torch.full(
            (len(batch['task']), 1),
            int(batch['task'][0] != 'response'),  # fake
            dtype=torch.long).cuda()

        if batch['task'][0] == 'response':
            pad_token_id = self.tokenizer.pad_token_id
            target_ids = batch['target_ids']

            decoder_input_ids = shift_tokens_right(target_ids, pad_token_id)

            outputs = self(input_ids=batch['source_ids'],
                           attention_mask=batch['source_mask'],
                           decoder_input_ids=decoder_input_ids,
                           use_cache=False,
                           task=batch['task'][0])

            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                target_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

            x = outputs[1]  # last hidden state

        elif batch['task'][0] in ['cfemotion', 'emotion', 'sentiment']:
            outputs = self(input_ids=batch['source_ids'],
                           attention_mask=batch['source_mask'],
                           lm_labels=batch['label'],
                           task=batch['task'][0])
            loss = outputs[0]

            x = outputs[2]  # last hidden state

        else:
            raise ValueError('The dataset contains an invalid task.')

        eos_mask = batch['source_ids'].eq(self.model.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError(
                'All examples must have the same number of <eos> tokens.')
        sentence_representation = x[eos_mask, :].view(x.size(0), -1,
                                                      x.size(-1))[:, -1, :]
        logits = self.discriminator(sentence_representation)
        adv_loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))

        return loss + adv_loss
Example #12
    def test_enro_tokenizer_prepare_seq2seq_batch(self):
        batch = self.tokenizer.prepare_seq2seq_batch(
            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
        )
        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 14), batch.input_ids.shape)
        self.assertEqual((2, 14), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
        # Test that special tokens are reset
        self.assertEqual(self.tokenizer.prefix_tokens, [])
        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
Example #13
def test_seq2seq_dataset_truncation(tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    max_src_len = 4
    max_tgt_len = 8
    assert max_len_target > max_src_len  # Will be truncated
    assert max_len_source > max_src_len  # Will be truncated
    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=max_src_len,
        max_target_length=max_tgt_len,  # ignored
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset,
                            batch_size=2,
                            collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert isinstance(batch, dict)
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_src_len
        # show that targets are the same len
        assert batch["labels"].shape[1] == max_tgt_len
        if tok_name != MBART_TINY:
            continue
        # check language codes in correct place
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], tokenizer.pad_token_id)
        assert batch["decoder_input_ids"][
            0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
        assert batch["decoder_input_ids"][0,
                                          -1].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
        assert batch["input_ids"][
            0, -1].item() == tokenizer.lang_code_to_id[src_lang]

        break  # No need to test every batch
Example #14
    def test_forward(self):
        src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
        expected_ids = [38, 121, 14, 697, 38848, 0]

        model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)

        self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())

        desired_keys = {
            "input_ids",
            "attention_mask",
            "labels",
        }
        self.assertSetEqual(desired_keys, set(model_inputs.keys()))
        model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id)
        model_inputs["return_dict"] = True
        model_inputs["use_cache"] = False
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        max_indices = outputs.logits.argmax(-1)
        self.tokenizer.batch_decode(max_indices)
Example #15
    def _step_D(self, batch):
        labels = torch.full(
            (len(batch['task']), 1),
            int(batch['task'][0] == 'response'),  # real
            dtype=torch.long).cuda()

        if batch['task'][0] == 'response':
            pad_token_id = self.tokenizer.pad_token_id
            target_ids = batch['target_ids']

            decoder_input_ids = shift_tokens_right(target_ids, pad_token_id)

            outputs = self(input_ids=batch['source_ids'],
                           attention_mask=batch['source_mask'],
                           decoder_input_ids=decoder_input_ids,
                           use_cache=False,
                           task=batch['task'][0])

        elif batch['task'][0] in ['cfemotion', 'emotion', 'sentiment']:
            outputs = self(
                input_ids=batch['source_ids'],
                attention_mask=batch['source_mask'],
                # lm_labels=batch['label'],
                task=batch['task'][0])

        else:
            raise ValueError('The dataset contains an invalid task.')

        x = outputs[1]  # last hidden state
        eos_mask = batch['source_ids'].eq(self.model.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError(
                'All examples must have the same number of <eos> tokens.')
        sentence_representation = x[eos_mask, :].view(x.size(0), -1,
                                                      x.size(-1))[:, -1, :]
        logits = self.discriminator(sentence_representation)
        loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))

        return loss
Example #16
def main():
    parser = argparse.ArgumentParser(
        description="Convert BART to LongBART. Replaces BART encoder's SelfAttention with LongformerSelfAttention"
    )
    parser.add_argument(
        '--base_model',
        type=str,
        default='facebook/bart-large',
        help='The name or path of the base model you want to convert')
    parser.add_argument('--tokenizer_name_or_path',
                        type=str,
                        default='facebook/bart-large',
                        help='The name or path of the tokenizer')
    parser.add_argument('--save_model_to',
                        type=str,
                        required=True,
                        help='The path to save the converted model')
    parser.add_argument(
        '--attention_window',
        type=int,
        default=512,
        help='attention window size for longformer self attention (one sided)')
    parser.add_argument('--max_pos',
                        type=int,
                        default=4096 * 4,
                        help='maximum encoder positions')

    args = parser.parse_args()

    if not os.path.exists(args.save_model_to):
        os.mkdir(args.save_model_to)

    create_long_model(save_model_to=args.save_model_to,
                      base_model=args.base_model,
                      tokenizer_name_or_path=args.tokenizer_name_or_path,
                      attention_window=args.attention_window,
                      max_pos=args.max_pos)

    tokenizer = BartTokenizer.from_pretrained(args.save_model_to)
    TXT = "My friends are <mask> but they eat too many carbs."
    model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(
        args.save_model_to)
    model.model.encoder.config.gradient_checkpointing = True
    model.model.decoder.config.gradient_checkpointing = True
    data = tokenizer([TXT],
                     return_tensors='pt',
                     padding='max_length',
                     max_length=2048)
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    decoder_input_ids = shift_tokens_right(input_ids[:, :5],
                                           tokenizer.pad_token_id)
    logits = model(input_ids,
                   attention_mask=attention_mask,
                   decoder_input_ids=decoder_input_ids,
                   use_cache=False)[0]
    masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    probs = logits[0, masked_index].softmax(dim=0)
    values, predictions = probs.topk(5)
    print(tokenizer.convert_ids_to_tokens(predictions))
Example #17
    def _step(self, batch):
        # assert is_frozen(self.teacher)
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, tgt_ids = batch["input_ids"], batch[
            "attention_mask"], batch["labels"]
        decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
        # noinspection PyCallingNonCallable
        lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            output_attentions=False,
            use_cache=False,
        )  # TODO(@sshleifer): return_dict=True cleanup

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                                       tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)

        loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
        if self.different_encoder:
            with torch.no_grad():
                teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
                    input_ids,
                    attention_mask=src_mask,
                    output_hidden_states=True)
            if self.hparams.alpha_encoder_loss > 0:
                loss_encoder = self.calc_mse_loss(enc_outputs,
                                                  teacher_enc_outputs,
                                                  src_mask)

            hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state,
                                                 teacher_enc_hid,
                                                 self.hparams.e_layer_to_copy)

        teacher_enc_outputs = (enc_outputs, )
        assert isinstance(teacher_enc_outputs,
                          tuple), type(teacher_enc_outputs)

        with torch.no_grad():
            tloss, tlogits, tdec_hidden, _ = self.teacher(
                input_ids,
                attention_mask=src_mask,
                encoder_outputs=teacher_enc_outputs,
                decoder_input_ids=decoder_input_ids,
                lm_labels=tgt_ids,
                output_hidden_states=True,
            )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(
            dec_mask, lm_logits, tlogits)
        if self.alpha_hid > 0:
            hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden,
                                                 tdec_hidden,
                                                 self.hparams.d_matches)

        blended_loss = (self.alpha_ce * loss_ce +
                        self.alpha_mlm * student_lm_loss +
                        self.hparams.alpha_encoder_loss * loss_encoder +
                        self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
Example #18
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids,
                       attention_mask=src_mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False)
        lm_logits = outputs[0]

        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                               tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        batch_size = src_ids.shape[0]
        loss_log = {'ce': loss.item()}

        if self.unlikelihood_training_tokens:
            ul_token_loss = self.unlikelihood_loss_token(
                decoder_input_ids, tgt_ids, lm_logits)
            ul_token_loss_weighted = ul_token_loss * self.unlikelihood_tokens_alpha

            print("*******************")
            print(f"loss: {loss}")
            print(f"UL token loss: {ul_token_loss_weighted}")
            print("*******************")

            loss_log['ul_token'] = ul_token_loss_weighted.item() / batch_size
            loss += ul_token_loss_weighted

        if self.unlikelihood_training_copy:
            ul_copy_loss = self.unlikelihood_loss_copying(
                src_ids, decoder_input_ids, lm_logits,
                self.unlikelihood_copy_n)
            ul_copy_loss_weighted = self.unlikelihood_copy_alpha * ul_copy_loss

            print("*******************")
            print(f"loss: {loss}")
            print(f"UL copy loss: {ul_copy_loss_weighted}")
            print("*******************")

            loss_log['ul_copy'] = ul_copy_loss_weighted.item() / batch_size
            loss += ul_copy_loss_weighted

        if self.unlikelihood_training:
            ul_loss = self.unlikelihood_loss(
                decoder_input_ids, lm_logits, self.weight_vector,
                self.unlikelihood_selective_penalty)
            ul_loss_weighted = ul_loss * self.unlikelihood_alpha

            print("*******************")
            print(f"loss: {loss}")
            print(f"UL loss: {ul_loss_weighted}")
            print("*******************")

            loss_log['ul_logr'] = ul_loss_weighted.item() / batch_size
            loss += ul_loss_weighted

        self.losses.append(loss_log)
        return (loss, )
Example #19
def generate_decode_copy(model, args, input_ids):
    assert input_ids.shape[0] == 1, "Ngram repetition penalty only supported for batch size of 1 (for now)"

    alpha = args.decode_state_change_copy_alpha
    n = args.decode_state_change_copy_n

    temp = model
    model = SummarizationModule(args)
    model.model = temp.to('cuda')
    model.model.train()
    input_ids = input_ids.to('cuda')

    # initialize optimizer
    #optimizer = SGD(model.model.parameters(), lr=1.0)
    #optimizer.zero_grad()

    # first generate output normally
    tgt_ids, logs = model.model.generate(
        input_ids,
        do_sample=False,
        num_beams=1,
        max_length=1024,
        early_stopping=False,
        num_return_sequences=1,
        decoder_start_token_id=model.model.config.pad_token_id)

    # treat generated text as decoder target
    pad_token_id = model.model.config.pad_token_id
    decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
    outputs = model(input_ids,
                    decoder_input_ids=decoder_input_ids,
                    use_cache=False)
    lm_logits = outputs[0]

    input_ids = input_ids.tolist()
    decoder_input_ids = decoder_input_ids.tolist()
    tgt_ids = tgt_ids.tolist()

    open('input_ids.txt', 'w').write(json.dumps(input_ids))
    open('decoder_input_ids.txt', 'w').write(json.dumps(decoder_input_ids))
    open('tgt_ids.txt', 'w').write(json.dumps(tgt_ids))
    return None  # NOTE: everything below this early return is currently unreachable; kept from the original flow

    # apply repetition penalty
    loss = model.unlikelihood_loss_copying(input_ids, decoder_input_ids,
                                           lm_logits, n) * alpha
    #loss.backward()
    #optimizer.step()

    # now re-generate output
    if args.decode_method == 'greedy':
        gen_ids, logs = model.model.generate(
            input_ids,
            do_sample=False,
            ngram_copy_penalty=args.decode_ngram_copy_penalty,
            ngram_copy_weight=args.decode_ngram_copy_weight,
            max_length=args.max_target_length,
            early_stopping=False,
            num_return_sequences=1,
            decoder_start_token_id=model.model.config.pad_token_id)
    elif args.decode_method == 'beam':
        gen_ids, logs = model.model.generate(
            input_ids,
            ngram_copy_penalty=args.decode_ngram_copy_penalty,
            ngram_copy_weight=args.decode_ngram_copy_weight,
            num_beams=args.decode_num_beams,
            max_length=args.max_target_length,
            early_stopping=False,
            num_return_sequences=1,
            decoder_start_token_id=model.model.config.pad_token_id)
    else:
        gen_ids, logs = model.model.generate(
            input_ids,
            do_sample=True,
            ngram_copy_penalty=args.decode_ngram_copy_penalty,
            ngram_copy_weight=args.decode_ngram_copy_weight,
            top_p=args.decode_p,
            max_length=args.max_target_length,
            early_stopping=False,
            num_return_sequences=1,
            decoder_start_token_id=model.model.config.pad_token_id)

    return gen_ids, logs
Example #20
    def _step(self, batch):
        # assert is_frozen(self.teacher) copied_decoder_layers
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch[
            "attention_mask"], batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)

        # noinspection PyCallingNonCallable
        lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            output_attentions=False,
            use_cache=False,
        )

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                                       labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs,
                labels,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)

        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
        if self.different_encoder:  # compute encoder hidden state loss
            with torch.no_grad():
                teacher_enc_hid = self.teacher.get_encoder()(
                    input_ids,
                    attention_mask=src_mask,
                    output_hidden_states=True,
                    return_dict=True).hidden_states

            hid_loss_enc = self.calc_hidden_loss(
                src_mask,
                enc_hidden_state,
                teacher_enc_hid,
                self.e_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

        with torch.no_grad():
            outputs = self.teacher(
                input_ids,
                attention_mask=src_mask,
                encoder_outputs=(enc_outputs, ),
                decoder_input_ids=decoder_input_ids,
                lm_labels=labels,
                output_hidden_states=True,
                return_dict=True,
            )
            tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
        if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask,
                dec_hidden,
                tdec_hidden,
                self.d_matches,
                normalize_hidden=self.hparams.normalize_hidden)

        blended_loss = (self.alpha_ce * loss_ce +
                        self.alpha_mlm * student_lm_loss +
                        self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
Example #21
def convert_examples_to_features_question_generation(
    examples,
    tokenizer,
    max_length=512,
    max_length_label=32,
    bart=False,
):
    """
    This function converts a list of examples into features that can be used
    as inputs for the question generation model.
    INPUTS:
    - examples: list of object <class: InputExample>, examples to convert.
    - tokenizer: torch tokenizer object, tokenize the examples.
    - max_length: int, size of the maximum input sentence (list of tokens).
    - max_length_label: int, size of the maximum label sentence.
    - bart: boolean, whether the model is BART or not.
    OUTPUTS:
    - features: list of object <class: InputFeatures>, list of features for model.
    """
    processor = DataProcessor()
    features = []
    pad_token = tokenizer.pad_token_id
    for (ex_index, example) in enumerate(examples):
        ######## ENCODING INPUT ########
        # This will encode both answer and context with a separator.
        inputs = tokenizer.encode_plus(example.answer,
                                       example.context,
                                       add_special_tokens=True,
                                       max_length=max_length,
                                       truncation='only_second')
        input_ids = inputs["input_ids"]

        token_type_ids = [0] * (len(tokenizer.encode(example.answer)) + 1)
        token_type_ids += [1] * (len(input_ids) - len(token_type_ids))

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1] * len(input_ids)

        ######## PADDING ########
        padding_length = max_length - len(input_ids)
        input_ids = input_ids + [pad_token] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        token_type_ids = token_type_ids + [0] * padding_length

        assert len(input_ids) == max_length, "Error with input ids length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with mask length {} vs {}".format(len(attention_mask), max_length)
        assert len(token_type_ids) == max_length, "Error with token length {} vs {}".format(len(token_type_ids), max_length)

        ######## ENCODING LABEL ########
        if example.question is not None:
            if bart:
                label = tokenizer.encode_plus(
                    example.question,
                    max_length=max_length_label,
                    truncation=True,
                )
                label_ids = label["input_ids"]
                padding_length = max_length_label - len(label_ids)
                label_ids = label_ids + [-100] * padding_length
                decoder_input_ids = shift_tokens_right(
                    torch.tensor(label_ids).unsqueeze(0),
                    -100).squeeze(0).tolist()
                decoder_input_ids = [
                    x if x != -100 else pad_token for x in decoder_input_ids
                ]
            else:
                label_ids = tokenizer.encode(
                    example.question,
                    add_special_tokens=True,
                    max_length=max_length_label,
                    truncation=True,
                )

                decoder_input_ids = label_ids
                padding_length = max_length_label - len(label_ids)
                label_ids = label_ids + [-100] * padding_length
                decoder_input_ids = decoder_input_ids + [pad_token] * padding_length

            decoder_attention_mask = [1] * max_length_label

            assert len(label_ids) == max_length_label, "Error with label length {} vs {}".format(len(label_ids), max_length_label)
            assert len(decoder_input_ids) == max_length_label, "Error with decoder input length {} vs {}".format(len(decoder_input_ids), max_length_label)
            assert len(decoder_attention_mask) == max_length_label, "Error with decoder attention mask length {} vs {}".format(len(decoder_attention_mask), max_length_label)

            features.append(
                InputFeatures(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    label=label_ids,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                ))
        else:
            features.append(
                InputFeatures(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    label=None,
                    decoder_input_ids=None,
                    decoder_attention_mask=None,
                ))
    return features
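
A minimal usage sketch that turns the returned features into tensors (assuming InputExample objects with answer, context, and question attributes, a BART tokenizer, and the BART branch, i.e. bart=True; the variable names are illustrative):

import torch
from torch.utils.data import TensorDataset

features = convert_examples_to_features_question_generation(
    examples, tokenizer, max_length=512, max_length_label=32, bart=True)

input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
decoder_input_ids = torch.tensor([f.decoder_input_ids for f in features], dtype=torch.long)
labels = torch.tensor([f.label for f in features], dtype=torch.long)
dataset = TensorDataset(input_ids, attention_mask, decoder_input_ids, labels)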
Example #22
    def _step(self, batch: dict) -> tuple:
        """Compute the loss for a batch"""
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch[
            "attention_mask"], batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)

        # noinspection PyCallingNonCallable
        student_outputs = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            output_attentions=False,
            use_cache=False,
        )
        lm_logits = student_outputs.logits

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                                       labels.view(-1))
        else:
            lprobs = F.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs,
                labels,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)

        teacher_enc_outputs = student_outputs.encoder_last_hidden_state  # use this unless self.different_base_models
        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
        if self.different_encoder:  # compute encoder hidden state loss
            all_teacher_encoder_outputs = self.teacher.get_encoder()(
                input_ids,
                attention_mask=src_mask,
                output_hidden_states=self.do_calc_hidden_loss,
            )
            if self.different_base_models:
                teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
            elif self.do_calc_hidden_loss:
                hid_loss_enc = self.calc_hidden_loss(
                    src_mask,
                    student_outputs.encoder_hidden_states,
                    all_teacher_encoder_outputs.hidden_states,
                    self.e_matches,
                    normalize_hidden=self.hparams.normalize_hidden,
                )

        teacher_outputs = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=(teacher_enc_outputs, ),
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            use_cache=False,  # since we are not passing labels, never let this default to True
        )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits,
                                    teacher_outputs.logits)
        if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask,
                student_outputs.decoder_hidden_states,
                teacher_outputs.decoder_hidden_states,
                self.d_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

        blended_loss = (self.alpha_ce * loss_ce +
                        self.alpha_mlm * student_lm_loss +
                        self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
Example #23
    def train_epoch(self, batch_size, label_smooth_epsilon, weight, text_type):
        assert 'train' in self._dataset

        random.shuffle(self._dataset['train'])
        print_train_loss_ssl = 0.0
        print_train_loss = 0.0
        print_train_loss_net = 0.0
        num = 0
        for i in trange(0,
                        len(self._dataset['train']),
                        batch_size,
                        desc='BART Training'):
            self._model.split_to_gpus(n_gpus=min(2, torch.cuda.device_count()))
            self._model.train()
            self._ssl_model.train()

            batch = self._dataset['train'][i:i + batch_size]

            self._optimizer.zero_grad()

            for j in range(0, len(batch), LIL_BATCH_SIZE):
                lil_batch = batch[j:j + LIL_BATCH_SIZE]

                src_lengths = torch.tensor(
                    [len(t.src_tokens) for t in lil_batch])
                src_tokens = collate_tokens(
                    [t.src_tokens for t in lil_batch],
                    pad_idx=self._model.dictionary.pad())
                tgt_tokens = collate_tokens(
                    [t.tgt_tokens for t in lil_batch],
                    pad_idx=self._model.dictionary.pad())

                loss_net = self._get_label_smoothed_nll_loss(
                    src_lengths=src_lengths,
                    src_tokens=src_tokens,
                    tgt_tokens=tgt_tokens,
                    epsilon=label_smooth_epsilon)

                # SSL training
                if text_type == 0:
                    texts = [t.texts for t in lil_batch]
                    ssl_batch = tokenizer.batch_encode_plus(
                        texts,
                        padding=True,
                        max_length=SRC_MAX_LEN,
                        truncation=True,
                        return_tensors='pt')

                    ssl_input_ids, ssl_label_ids = mask_tokens(
                        ssl_batch.input_ids, tokenizer, MLM_PROBABILITY)
                    ssl_decoder_input_ids = shift_tokens_right(
                        ssl_batch.input_ids,
                        self._ssl_model.config.pad_token_id)
                elif text_type == 1:
                    ssl_input_ids, ssl_label_ids = mask_tokens(
                        src_tokens, tokenizer, MLM_PROBABILITY)
                    ssl_decoder_input_ids = shift_tokens_right(
                        src_tokens, self._ssl_model.config.pad_token_id)
                else:
                    ssl_input_ids, ssl_label_ids = mask_tokens(
                        tgt_tokens, tokenizer, MLM_PROBABILITY)
                    ssl_decoder_input_ids = shift_tokens_right(
                        tgt_tokens, self._ssl_model.config.pad_token_id)

                ssl_input_ids, ssl_label_ids, ssl_decoder_input_ids = ssl_input_ids.cuda(), ssl_label_ids.cuda(), ssl_decoder_input_ids.cuda()

                ssl_outputs = self._ssl_model(
                    input_ids=ssl_input_ids,
                    decoder_input_ids=ssl_decoder_input_ids,
                    labels=ssl_label_ids,
                )

                loss_ssl = ssl_outputs[0]
                print_train_loss_ssl += loss_ssl
                print_train_loss_net += loss_net
                num += 1

                # Total training loss
                loss = (loss_net * NET_WEIGHT +
                        loss_ssl * weight) * len(lil_batch) / batch_size
                print_train_loss += loss * batch_size

                if torch.isnan(loss):
                    print('warning: nan loss')
                    print(f'tgt_text: {lil_batch[0].tgt_text}')
                else:
                    loss.backward()

            self._optimizer.step()
            self._lr_scheduler.step()

            self._global_step += 1
            if self._global_step % self._eval_steps == 0:
                self.gen_log()

        print("net training loss:", print_train_loss_net.item() / num)
        print("ssl training loss:", print_train_loss_ssl.item() / num)
        print("total training loss:", print_train_loss.item() / num)