Example #1
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(
                tgt_ids, pad_token_id,
                self.model.config.decoder_start_token_id)
        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids,
                       attention_mask=src_mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False)
        lm_logits = outputs["logits"]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                               tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)
        return (loss, )
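For reference, here is a minimal sketch (not part of the original snippet) of what the three-argument shift_tokens_right used above does in recent transformers versions, paraphrased from modeling_bart.py; the exact library code may differ between releases. The labels are shifted one position to the right, decoder_start_token_id is prepended, and any -100 placeholders are replaced with the pad id.

import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
                       decoder_start_token_id: int) -> torch.Tensor:
    """Shift label ids one position to the right and prepend the decoder start token."""
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # ignored label positions (-100) must become real pad ids before feeding the decoder
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted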
    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            batch = self._encode(batch)
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])

            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(
                input_ids, self.pad_token_id, attention_mask=attention_mask)

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch
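The trim_batch calls above drop padding columns that contain only pad tokens across the whole batch. A sketch of that utility as it appears in the transformers seq2seq example utilities (paraphrased; check the version you run against):

def trim_batch(input_ids, pad_token_id, attention_mask=None):
    """Remove columns that are populated exclusively by pad_token_id."""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]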
Example #3
def preprocess_data_mbart(data):
    input_text, target_text, tokenizer, args = data

    tokenized_example = tokenizer.prepare_seq2seq_batch(
        src_texts=[input_text],
        tgt_texts=[target_text],
        src_lang=args.src_lang,
        tgt_lang=args.tgt_lang,
        max_length=args.max_seq_length,
        padding="max_length",  # pad_to_max_length=True won't work in this case
        return_tensors="pt",
        truncation=True,
    )

    decoder_input_ids = tokenized_example["labels"].clone()
    decoder_input_ids = shift_tokens_right(decoder_input_ids,
                                           tokenizer.pad_token_id)

    labels = tokenized_example["labels"]
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "input_ids": tokenized_example["input_ids"].squeeze(),
        "attention_mask": tokenized_example["attention_mask"].squeeze(),
        "decoder_input_ids": decoder_input_ids.squeeze(),
        "labels": labels.squeeze(),
    }
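Unlike the three-argument variant, the two-argument shift_tokens_right used here (from older BART code) wraps the last non-pad token, usually the eos or the mBART language code, around to position 0. A sketch paraphrased from older transformers modeling_bart.py follows. Note also that the -100 masking of labels above pairs with CrossEntropyLoss's default ignore_index=-100, and is applied only after decoder_input_ids have been built from the unmasked copy.

import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Shift ids one token to the right, wrapping the last non-pad token (usually <eos>) to position 0."""
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens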
    def test_seq2seq_max_target_length(self):
        batch = self.tokenizer.prepare_seq2seq_batch(self.src_text,
                                                     tgt_texts=self.tgt_text,
                                                     max_length=3,
                                                     max_target_length=10)
        batch["decoder_input_ids"] = shift_tokens_right(
            batch.labels, self.tokenizer.pad_token_id)
        self.assertEqual(batch.input_ids.shape[1], 3)
        self.assertEqual(batch.decoder_input_ids.shape[1], 10)
        # max_target_length will default to max_length if not specified
        batch = self.tokenizer.prepare_seq2seq_batch(self.src_text,
                                                     tgt_texts=self.tgt_text,
                                                     max_length=3)
        batch["decoder_input_ids"] = shift_tokens_right(
            batch.labels, self.tokenizer.pad_token_id)
        self.assertEqual(batch.input_ids.shape[1], 3)
        self.assertEqual(batch.decoder_input_ids.shape[1], 3)

    def test_batch_fairseq_parity(self):
        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
            self.src_text, tgt_texts=self.tgt_text, return_tensors="pt")
        batch["decoder_input_ids"] = shift_tokens_right(
            batch.labels, self.tokenizer.pad_token_id)
        for k in batch:
            batch[k] = batch[k].tolist()
        # batch = {k: v.tolist() for k,v in batch.items()}
        # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
        # batch.decoder_inputs_ids[0][0] ==
        assert batch.input_ids[1][-2:] == [2, EN_CODE]
        assert batch.decoder_input_ids[1][0] == RO_CODE
        assert batch.decoder_input_ids[1][-1] == 2
        assert batch.labels[1][-2:] == [2, RO_CODE]
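To make those assertions concrete, here is a toy illustration of the mBART label layout and the effect of the two-argument shift sketched earlier; the token id values, including RO_CODE, are made up for illustration (the real codes come from tokenizer.lang_code_to_id).

import torch

RO_CODE = 250020  # assumed value, for illustration only
labels = torch.tensor([[47711, 7404, 2, RO_CODE]])      # "... <eos> ro_RO"
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1)
print(decoder_input_ids)  # [[RO_CODE, 47711, 7404, 2]]: the decoder starts with the language code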
Example #6
    def test_seq2seq_dataset_truncation(self, tok_name):
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        max_src_len = 4
        max_tgt_len = 8
        assert max_len_target > max_src_len  # Will be truncated
        assert max_len_source > max_src_len  # Will be truncated
        src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=max_src_len,
            max_target_length=max_tgt_len,  # ignored
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        dataloader = DataLoader(train_dataset,
                                batch_size=2,
                                collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert isinstance(batch, dict)
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_src_len
            # show that targets are the same len
            assert batch["labels"].shape[1] == max_tgt_len
            if tok_name != MBART_TINY:
                continue
            # check language codes in correct place
            batch["decoder_input_ids"] = shift_tokens_right(
                batch["labels"], tokenizer.pad_token_id)
            assert batch["decoder_input_ids"][
                0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
            assert batch["decoder_input_ids"][
                0, -1].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
            assert batch["input_ids"][
                0, -1].item() == tokenizer.lang_code_to_id[src_lang]

            break  # No need to test every batch
    def test_enro_tokenizer_prepare_seq2seq_batch(self):
        batch = self.tokenizer.prepare_seq2seq_batch(
            self.src_text,
            tgt_texts=self.tgt_text,
            max_length=len(self.expected_src_tokens),
            return_tensors="pt")
        batch["decoder_input_ids"] = shift_tokens_right(
            batch.labels, self.tokenizer.pad_token_id)
        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 14), batch.input_ids.shape)
        self.assertEqual((2, 14), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
        # Test that special tokens are reset
        self.assertEqual(self.tokenizer.prefix_tokens, [])
        self.assertEqual(self.tokenizer.suffix_tokens,
                         [self.tokenizer.eos_token_id, EN_CODE])
    def test_forward(self):
        src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
        expected_ids = [38, 121, 14, 697, 38848, 0]

        model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt, return_tensors="pt").to(
            torch_device
        )

        self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())

        desired_keys = {
            "input_ids",
            "attention_mask",
            "labels",
        }
        self.assertSetEqual(desired_keys, set(model_inputs.keys()))
        model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id)
        model_inputs["return_dict"] = True
        model_inputs["use_cache"] = False
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        max_indices = outputs.logits.argmax(-1)
        self.tokenizer.batch_decode(max_indices)
Example #9
    def _step(self, batch: dict) -> tuple:
        """Compute the loss for a batch"""
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch[
            "attention_mask"], batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)

        # noinspection PyCallingNonCallable
        student_outputs = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            output_attentions=False,
            use_cache=False,
        )
        lm_logits = student_outputs["logits"]

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                                       labels.view(-1))
        else:
            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs,
                labels,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)

        teacher_enc_outputs = student_outputs[
            "encoder_last_hidden_state"]  # use this unless self.different_base_models
        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
        if self.different_encoder:  # compute encoder hidden state loss
            all_teacher_encoder_outputs = self.teacher.get_encoder()(
                input_ids,
                attention_mask=src_mask,
                output_hidden_states=self.do_calc_hidden_loss,
            )
            if self.different_base_models:
                teacher_enc_outputs = all_teacher_encoder_outputs[
                    "last_hidden_state"]
            elif self.do_calc_hidden_loss:
                hid_loss_enc = self.calc_hidden_loss(
                    src_mask,
                    student_outputs["encoder_hidden_states"],
                    all_teacher_encoder_outputs["hidden_states"],
                    self.e_matches,
                    normalize_hidden=self.hparams.normalize_hidden,
                )

        teacher_outputs = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=(teacher_enc_outputs, ),
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            use_cache=False,  # since we are not passing labels, never let this default to True
        )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits,
                                    teacher_outputs["logits"])
        if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask,
                student_outputs["decoder_hidden_states"],
                teacher_outputs["decoder_hidden_states"],
                self.d_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

        blended_loss = (self.alpha_ce * loss_ce +
                        self.alpha_mlm * student_lm_loss +
                        self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
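For completeness, a sketch of label_smoothed_nll_loss as it appears in the transformers seq2seq example utilities (paraphrased; the exact code may differ between versions): the negative log-likelihood of the target is blended with a uniform smoothing term, and positions equal to ignore_index are zeroed out.

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Blend NLL of the target with a uniform smoothing term; mask out ignore_index positions."""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss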
Example #10
        pad_token_id = tokenizer.pad_token_id
        mask_token_id = tokenizer.mask_token_id

        # model = BartModel.from_pretrained('./bart-base', return_dict=True)
        model = BartForConditionalGeneration.from_pretrained(tmp_model_name_path, return_dict=True).to(device)
        model.eval()
        st, ed = 0, 0
        all_loss = []
        while ed < len(ipt):
            st, ed = ed, (ed + batch_size) if (ed + batch_size < len(ipt)) else len(ipt)
            input_ids = tokenizer(ipt[st:ed], return_tensors="pt", padding=True, truncation=True, max_length=1000).input_ids.to(device)
            with torch.no_grad():
                src_ids = input_ids
                tgt_ids = tokenizer(opt[st:ed], return_tensors="pt", padding=True, truncation=True, max_length=1000).input_ids.to(device)
                # tgt_ids = torch.cat([torch.zeros([batch_size, 1], dtype=tgt_ids.dtype), tgt_ids], 1)
                decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
                # print(src_ids, tokenizer.decode(src_ids[0], skip_special_tokens=False))
                # print(tgt_ids, tokenizer.decode(tgt_ids[0], skip_special_tokens=False))
                outputs = model(src_ids, decoder_input_ids=decoder_input_ids, use_cache=False)
                lm_logits = outputs["logits"]
                # print(src_ids.size(), lm_logits.size(), decoder_input_ids.size())


                tmp_batch_size = lm_logits.size()[0]
                pad_pos = torch.eq(tgt_ids, pad_token_id).to(torch.float)
                sen_pos = torch.eq(tgt_ids, mask_token_id).to(torch.float)
                dis_pos = torch.cat([torch.zeros([tmp_batch_size, 1]).to(sen_pos.device), sen_pos[:, :-1]], 1)
                loss_mask = 1 - (pad_pos + sen_pos + dis_pos)
                # Same behavior as modeling_bart.py, besides ignoring pad_token_id
                ce_loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
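                # --- Hypothetical continuation (the original snippet ends here) ---
                # Assuming loss_mask is meant to weight the unreduced cross-entropy,
                # a per-sentence loss could be accumulated roughly like this:
                token_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                                         tgt_ids.view(-1)).view(tmp_batch_size, -1)
                sent_loss = (token_loss * loss_mask).sum(dim=1) / loss_mask.sum(dim=1).clamp(min=1.0)
                all_loss.extend(sent_loss.tolist())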
Example #11
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        vis_inputs=None,
        vis_attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id,
                    self.config.decoder_start_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            vis_inputs=vis_inputs,
            vis_attention_mask=vis_attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            # loss_fct = CrossEntropyLoss()
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100,
                                            reduction='none')
            masked_lm_loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits, ) + outputs[1:]
            return ((masked_lm_loss, ) +
                    output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
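A hedged usage sketch for the forward above (the model variable and input tensors are assumptions, not part of the snippet): when only labels are supplied, decoder_input_ids are derived internally via shift_tokens_right, and reduce_loss toggles between a mean loss and an unreduced per-token loss.

# hypothetical call; model, input_ids, attention_mask, vis_inputs, labels assumed to exist
outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                vis_inputs=vis_inputs,
                labels=labels,          # decoder_input_ids omitted on purpose
                reduce_loss=True,       # CrossEntropyLoss with mean reduction
                return_dict=True)
loss, logits = outputs.loss, outputs.logits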
Example #12
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        vis_inputs=None,
        vis_attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):

        # different to other models, Bart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id,
                self.config.decoder_start_token_id)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                vis_inputs=vis_inputs,
                vis_attention_mask=vis_attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=torch.float, device=input_ids.device)
        if vis_attention_mask is None:
            B, L = attention_mask.size()
            V_L = encoder_outputs[0].size(1) - L
            vis_attention_mask = attention_mask.new_ones(B, V_L)
        encoder_attention_mask = torch.cat(
            [attention_mask, vis_attention_mask], dim=1)

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            # encoder_attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
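As a quick sanity check on the mask handling above (shapes made up for illustration): the text attention mask and the visual attention mask are concatenated so their joint length matches the encoder output, which stacks text and visual positions.

import torch

attention_mask = torch.ones(2, 5)                   # B=2, text length L=5
vis_attention_mask = attention_mask.new_ones(2, 3)  # visual length V_L=3
encoder_attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)
assert encoder_attention_mask.shape == (2, 8)       # L + V_L, matching encoder_outputs[0].size(1)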
Example #13
    def _valid_step(self, batch: dict) -> dict:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]

        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids,
                       attention_mask=src_mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False)
        lm_logits = outputs["logits"]

        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            assert lm_logits.shape[-1] == self.vocab_size

            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                               tgt_ids.view(-1))
            p_tensor = torch.all(torch.max(lm_logits, 2)[1] == tgt_ids, 1)
            acc = p_tensor.sum() / sum(p_tensor.shape)
            lm_loss, ti_loss = loss + 0., loss + 0.

            # ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id, reduction="none")
            # assert lm_logits.shape[-1] == self.vocab_size
            # batch_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

            # lm_mask = (batch["loss_labels"].eq(0).to(torch.float)[:, None] * torch.ones_like(tgt_ids) * (1 - tgt_ids.eq(pad_token_id).to(torch.float))).view(-1)
            # ti_mask = (batch["loss_labels"].eq(1).to(torch.float)[:, None] * torch.ones_like(tgt_ids) * (1 - tgt_ids.eq(pad_token_id).to(torch.float))).view(-1)

            # lm_loss = torch.sum(batch_loss * lm_mask) / (torch.sum(lm_mask) + 1e-20)
            # ti_loss = torch.sum(batch_loss * ti_mask) / (torch.sum(ti_mask) + 1e-20)
            # loss = torch.sum(batch_loss * (lm_mask + ti_mask)) / (torch.sum(lm_mask + ti_mask) + 1e-20)
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)
            # compute acc here as well so it is always defined for the return below
            p_tensor = torch.all(torch.max(lm_logits, 2)[1] == tgt_ids, 1)
            acc = p_tensor.sum() / sum(p_tensor.shape)
            lm_loss, ti_loss = loss + 0., loss + 0.

        # print(src_ids)
        # print(batch.keys())
        # print(batch["ids"].cpu().numpy())
        # print(batch["loss_labels"].cpu().numpy())
        # for i in range(3):
        #     print(self.tokenizer.convert_ids_to_tokens(batch["input_ids"].cpu().numpy()[i]))
        #     print(self.tokenizer.convert_ids_to_tokens(decoder_input_ids.cpu().numpy()[i]))
        #     print(self.tokenizer.convert_ids_to_tokens(batch["labels"].cpu().numpy()[i]))
        #     # print(self.tokenizer.unk_token, self.tokenizer.pad_token, self.tokenizer.eos_token, self.tokenizer.bos_token, self.tokenizer.cls_token, self.tokenizer.mask_token)
        #     print("="*10)
        # print("="*30)
        # print(loss, lm_loss, ti_loss)
        # exit()

        return {'loss': loss, 'acc': acc}
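A toy illustration of the exact-match accuracy term computed above (values assumed): a sequence only counts as correct if the argmax prediction matches the target at every position.

import torch

lm_logits = torch.randn(2, 4, 10)            # batch of 2, length 4, vocab 10
tgt_ids = lm_logits.argmax(-1).clone()       # start from perfect predictions
tgt_ids[1, 0] += 1                           # corrupt one token of the second sequence
p_tensor = torch.all(torch.max(lm_logits, 2)[1] == tgt_ids, 1)
acc = p_tensor.sum() / sum(p_tensor.shape)   # -> tensor(0.5)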
Example #14
    def forward(
            self,
            input_ids=None,
            input_intensity_labels=None,
            attention_mask=None,
            decoder_input_ids=None,
            decoder_attention_mask=None,
            head_mask=None,
            decoder_head_mask=None,
            encoder_outputs=None,
            past_key_values=None,
            inputs_embeds=None,
            decoder_inputs_embeds=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        # decoder_input is set up as an all-zeros sequence of length 1, so this branch should not execute either
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            # input_ids = input_ids
        elif inputs_embeds is not None:
            inputs_embeds = self.resnet(inputs_embeds)
            # inputs_embeds = self.batchnorm(inputs_embeds)
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if input_intensity_labels is not None:
            inputs_intensity_embeds = self.shared(input_intensity_labels)
            # inputs_intensity_embeds = self.batchnorm(inputs_intensity_embeds)
            inputs_embeds = torch.cat((inputs_embeds, inputs_intensity_embeds), dim=-1)
            inputs_embeds = self.linear2(inputs_embeds)
        attention_mask = torch.ones(input_shape[0], input_shape[1]).to(self.args.device)

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )


        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        # return_dict is set to False in the config, so this branch should not execute
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder
        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = ()
        for i in range(4):
            if decoder_input_ids is not None:
                dec_size = decoder_input_ids.size()[:2]
                decoder_attention_mask = torch.ones(dec_size[0], dec_size[1]).to(self.args.device)
            elif decoder_inputs_embeds is not None:
                dec_size = decoder_inputs_embeds.size()[:2]
                decoder_attention_mask = torch.ones(dec_size[0], dec_size[1]).to(self.args.device)
            else:
                decoder_attention_mask = None

            # make masks if user doesn't supply
            # if not use_cache:
            #     if decoder_input_ids is not None:
            #         decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
            #             self.config,
            #             input_ids,
            #             decoder_input_ids=decoder_input_ids,
            #             decoder_padding_mask=decoder_attention_mask,
            #             causal_mask_dtype=self.shared.weight.dtype,
            #         )
            #     else:
            #         decoder_input_embeds, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
            #             self.config,
            #             input_ids,
            #             decoder_input_ids=decoder_inputs_embeds,
            #             decoder_padding_mask=decoder_attention_mask,
            #             causal_mask_dtype=self.shared.weight.dtype,
            #         )
            # else:
            #     decoder_padding_mask, causal_mask = None, None

            assert decoder_input_ids is not None or decoder_inputs_embeds is not None

            # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=encoder_outputs[0],
                encoder_attention_mask=attention_mask,
                head_mask=decoder_head_mask,
                encoder_head_mask=head_mask,
                past_key_values=past_key_values,
                inputs_embeds=decoder_inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            decoder_inputs_embeds = torch.cat((decoder_inputs_embeds, decoder_outputs[0][:, -1:, :]), dim=1)
            decoder_input_ids = None

        # return
        if not return_dict:
            out = decoder_outputs + encoder_outputs
            out = (self.linear(out[0]),) + out
        else:
            out = Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )
        return out
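A hedged sketch of how this forward appears to be driven (all names and shapes are assumptions inferred from the comments above): the decoder is seeded with a single all-zero embedding and the internal loop then extends it for four steps.

# hypothetical driver; model, features, intensity_ids, device assumed to exist
decoder_inputs_embeds = torch.zeros(features.size(0), 1, model.config.d_model, device=device)
out = model(inputs_embeds=features,
            input_intensity_labels=intensity_ids,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=False,
            return_dict=False)
predictions = out[0]  # linear projection of the decoder hidden states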