Example #1
    def test_loss_fn(self):
        model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY,
                                                      return_dict=True)
        input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs[
            "attention_mask"]
        target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]],
                                  dtype=torch.long,
                                  device=model.device)
        decoder_input_ids = target_ids[:, :-1].contiguous()  # decoder input: targets without the final token (teacher forcing)
        lm_labels = target_ids[:, 1:].clone()  # labels: targets shifted left by one; clone so the slice gets its own storage
        model_computed_loss = model(input_ids,
                                    attention_mask=mask,
                                    decoder_input_ids=decoder_input_ids,
                                    labels=lm_labels,
                                    use_cache=False).loss

        logits = model(input_ids,
                       attention_mask=mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False).logits

        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
        smoothed_loss, nll_loss = label_smoothed_nll_loss(
            lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id)
        with self.assertRaises(AssertionError):
            # The two values differ: the model's built-in loss is a per-token
            # mean (CrossEntropyLoss default, ignore_index=-100), while nll_loss
            # here is summed over tokens and ignores pad_token_id instead.
            self.assertEqual(nll_loss, model_computed_loss)
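Most of the examples on this page assume a label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=...) helper is in scope, as in the fairseq / transformers seq2seq example scripts. A minimal sketch of such a helper, assuming log-probabilities with a trailing vocabulary dimension and sum reduction; note that Example #8 below relies on a different variant that returns an unreduced per-token loss:

    import torch

    def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
        """Sketch of the helper used in the fairseq / transformers seq2seq examples."""
        if target.dim() == lprobs.dim() - 1:
            target = target.unsqueeze(-1)
        nll_loss = -lprobs.gather(dim=-1, index=target)
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        if ignore_index is not None:
            pad_mask = target.eq(ignore_index)
            nll_loss.masked_fill_(pad_mask, 0.0)
            smooth_loss.masked_fill_(pad_mask, 0.0)
        nll_loss = nll_loss.sum()       # sum over tokens; a mean reduction is another common choice
        smooth_loss = smooth_loss.sum()
        eps_i = epsilon / lprobs.size(-1)
        loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
        return loss, nll_loss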
Example #2
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        if isinstance(self.model, T5ForConditionalGeneration):
            tgt_ids = batch["labels"]
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            #decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
            y = batch["labels"]
            decoder_input_ids = y[:, :-1].contiguous()
            tgt_ids = y[:, 1:].clone()
        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids,
                       attention_mask=src_mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                               tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)
        return (loss, )
Example #3
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids,
                       attention_mask=src_mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                               tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)
        return (loss, )
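Several of the BART-style branches above and below (Examples #3, #11, #13–#16) call a shift_tokens_right(tgt_ids, pad_token_id) helper. A minimal sketch of that helper as it appeared in the older transformers seq2seq utilities, assuming right-padded label tensors; the final non-pad token (EOS) is rotated into position 0 as the decoder start token:

    import torch

    def shift_tokens_right(input_ids, pad_token_id):
        """Shift labels one position to the right, moving the last non-pad
        token (EOS) into position 0 as the decoder start token (sketch)."""
        prev_output_tokens = input_ids.clone()
        index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
        prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
        prev_output_tokens[:, 1:] = input_ids[:, :-1]
        return prev_output_tokens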
Example #4
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, target_ids = (batch["input_ids"],
                                               batch["attention_mask"],
                                               batch["decoder_input_ids"])

        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(target_ids)
            lm_labels = target_ids
        else:
            decoder_input_ids = target_ids[:, :-1].contiguous()  # decoder input: targets without the final token (teacher forcing)
            lm_labels = target_ids[:, 1:].clone()  # labels: targets shifted left by one; clone so the slice gets its own storage

        outputs = self(source_ids,
                       attention_mask=source_mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False)

        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            lm_logits = outputs[0]
            assert lm_logits.shape[-1] == self.model.config.vocab_size
            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                            lm_labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                lm_labels,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)
        return (loss, )
Example #5
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id

        # source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
        source_ids, source_mask, target_ids, topic_p = (batch["input_ids"],
                                                        batch["attention_mask"],
                                                        batch["decoder_input_ids"],
                                                        batch["topic_p"])

        decoder_input_ids = target_ids[:, :-1].contiguous()
        lm_labels = target_ids[:, 1:].clone()

        outputs = self(source_ids,
                       attention_mask=source_mask,
                       decoder_input_ids=decoder_input_ids,
                       topic_p=topic_p,
                       use_cache=False)
        # calculate loss
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            lm_logits = outputs[0]
            assert lm_logits.shape[-1] == self.model.config.vocab_size
            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                            lm_labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                lm_labels,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)
        return (loss, )
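The two target-preparation conventions used above line up the decoder input and the loss targets differently. A toy illustration with assumed token ids (0 = BOS, 2 = EOS, 1 = pad):

    import torch

    labels = torch.tensor([[0, 4, 8, 2]])             # [BOS, w1, w2, EOS]

    # Prefix-slice convention (Examples #2, #4, #5, #12): labels already start
    # with BOS, the decoder sees everything but the last token, and the loss
    # targets are the labels shifted left by one.
    decoder_input_ids = labels[:, :-1].contiguous()   # [[0, 4, 8]]
    lm_labels = labels[:, 1:].clone()                 # [[4, 8, 2]]

    # shift_tokens_right convention (Examples #3, #11): the helper rotates EOS to
    # position 0 as the decoder start token, and the loss is computed against the
    # unshifted labels themselves.
    # shift_tokens_right(labels, pad_token_id=1)  ->  [[2, 0, 4, 8]]
    # loss targets                                ->  [[0, 4, 8, 2]]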
Example #6
    def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
            decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
            use_cache=False, is_training=False):

        if is_training:
            _decoder_input_ids = shift_tokens_right(decoder_input_ids, self.config.pad_token_id)
        else:
            _decoder_input_ids = decoder_input_ids

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
            use_cache=use_cache,
        )
        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
        if is_training:
            # loss_fct = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.config.pad_token_id)
            # loss = loss_fct(lm_logits.view(-1, self.config.vocab_size),
            #                   decoder_input_ids.view(-1))
            lprobs = F.log_softmax(lm_logits, dim=-1)
            loss, _ = label_smoothed_nll_loss(lprobs, decoder_input_ids, epsilon=0.1, ignore_index=self.config.pad_token_id)
            return loss
        return (lm_logits, ) + outputs[1:]
Example #7
    def _compute_loss(self, model, inputs):
        inputs = copy.deepcopy(inputs)
        if self.args.label_smoothing == 0:
            if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                # force training to ignore pad token
                labels = inputs.pop("labels")
                logits = model(**inputs, use_cache=False)[0]

                loss_fct = torch.nn.CrossEntropyLoss(
                    ignore_index=self.config.pad_token_id)
                loss = loss_fct(logits.view(-1, logits.shape[-1]),
                                labels.view(-1))
            else:
                # compute usual loss via models
                loss, logits = model(**inputs, use_cache=False)[:2]
        else:
            # compute label smoothed loss
            labels = inputs.pop("labels")
            logits = model(**inputs, use_cache=False)[0]
            lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
            loss, _ = label_smoothed_nll_loss(
                lprobs,
                labels,
                self.args.label_smoothing,
                ignore_index=self.config.pad_token_id)
        return loss, logits
Example #8
    def forward(self, outs, graph_state, target_rel=None, work=False):
        def get_scores(dep, head):
            head = torch.tanh(self.transfer_head(head))
            dep = torch.tanh(self.transfer_dep(dep))

            head = F.dropout(head, p=self.dropout, training=self.training)
            dep = F.dropout(dep, p=self.dropout, training=self.training)

            dep_num, bsz, _ = dep.size()
            head_num = head.size(0)

            bias_dep = dep.new_ones((dep_num, bsz, 1))
            bias_head = head.new_ones((head_num, bsz, 1))

            # seq_len x bsz x dim
            dep = torch.cat([dep, bias_dep], 2)
            head = torch.cat([head, bias_head], 2)

            #bsz x dep_num x vocab_size x dim
            dep = self.proj(dep).view(dep_num, bsz, self.vocabs['rel'].size,
                                      -1).transpose(0, 1).contiguous()
            #bsz x dim x head_num
            head = head.permute(1, 2, 0)

            #bsz x dep_num x vocab_size x head_num
            scores = torch.bmm(
                dep.view(bsz, dep_num * self.vocabs['rel'].size, -1),
                head).view(bsz, dep_num, self.vocabs['rel'].size, head_num)
            return scores

        scores = get_scores(outs, graph_state).permute(1, 0, 3, 2).contiguous()

        dep_num, bsz, _ = outs.size()
        head_num = graph_state.size(0)
        log_probs = F.log_softmax(scores, dim=-1)
        _, rel = torch.max(log_probs, -1)
        if work:
            #dep_num x bsz x head x vocab
            return log_probs

        rel_mask = torch.eq(target_rel, self.vocabs['rel'].token2idx(NIL)) \
            + torch.eq(target_rel, self.vocabs['rel'].token2idx(PAD))
        rel_acc = torch.eq(rel, target_rel).float().masked_fill_(rel_mask, 0.).sum().item()
        rel_tot = rel_mask.numel() - rel_mask.float().sum().item()
        if not self.training:
            print('rel acc %.3f' % (rel_acc / rel_tot))
        rel_loss = label_smoothed_nll_loss(
            log_probs.view(-1, self.vocabs['rel'].size), target_rel.view(-1),
            0.).view(dep_num, bsz, head_num)
        rel_loss = rel_loss.masked_fill_(rel_mask, 0.).sum((0, 2))
        return rel_loss
Example #9
    def _compute_loss(self, logits, labels, ignore_index):
        if self.args.label_smoothing == 0:
            # Same behavior as modeling_bart.py
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
            assert logits.shape[-1] == self.model.config.vocab_size
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(lprobs,
                                                     labels,
                                                     self.args.label_smoothing,
                                                     ignore_index=ignore_index)
        return loss
Example #10
    def _training_step(self, model: nn.Module,
                       inputs: Dict[str, Union[torch.Tensor, Any]],
                       optimizer: torch.optim.Optimizer) -> float:
        model.train()
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        # Our model outputs do not work with DataParallel, so forcing return tuple.
        if isinstance(model, nn.DataParallel):
            inputs["return_tuple"] = True

        if self.label_smoothing == 0:
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
        else:
            labels = inputs.pop("labels")
            labels[labels == -100] = model.config.pad_token_id
            outputs = model(**inputs)
            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                labels,
                self.label_smoothing,
                ignore_index=model.config.pad_token_id)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()
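The labels[labels == -100] = model.config.pad_token_id line above matters because data collators commonly mark positions to ignore with -100 (the CrossEntropyLoss default), while label_smoothed_nll_loss masks on the ignore_index it is given, here pad_token_id; a -100 index would also break the gather inside the helper. A toy illustration with assumed ids:

    import torch

    pad_token_id = 1
    labels = torch.tensor([[42, 13, 2, -100, -100]])
    labels[labels == -100] = pad_token_id   # -> [[42, 13, 2, 1, 1]]
    # label_smoothed_nll_loss(..., ignore_index=pad_token_id) now skips the padded positions.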
Example #11
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]

        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )

        batch_size = src_ids.shape[0]
        loss_log = {'ce': loss.item()}
 
        if self.unlikelihood_training:
            ul_loss = self.unlikelihood_loss(decoder_input_ids, lm_logits, self.weight_vector, self.unlikelihood_selective_penalty)
            ul_loss_weighted = ul_loss * self.unlikelihood_alpha
            loss_log['ul_logr'] = ul_loss_weighted.item()/batch_size
            loss += ul_loss_weighted

        self.losses.append(loss_log)
        return (loss,)
Example #12
    def step(self, batch):
        source_ids, source_mask, target_ids = (
            batch["input_ids"],
            batch["attention_mask"],
            batch["decoder_input_ids"],
        )

        decoder_input_ids = target_ids[:, :-1].contiguous()  # decoder input: targets without the final token (teacher forcing)
        lm_labels = target_ids[:, 1:].clone()  # labels: targets shifted left by one; clone so the slice gets its own storage

        outputs = self(
            source_ids,
            attention_mask=source_mask,
            decoder_input_ids=decoder_input_ids,
            use_cache=False,
        )
        lprobs = F.log_softmax(outputs[0], dim=-1)
        loss = label_smoothed_nll_loss(
            lprobs,
            lm_labels,
            epsilon=0.1,
            ignore_index=self.tokenizer.pad_token_id)
        return loss[0]
Example #13
    def _step(self, batch):
        # assert is_frozen(self.teacher)
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, tgt_ids = batch["input_ids"], batch[
            "attention_mask"], batch["labels"]
        decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
        # noinspection PyCallingNonCallable
        lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            output_attentions=False,
            use_cache=False,
        )  # TODO(@sshleifer): return_dict=True cleanup

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                                       tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)

        loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
        if self.different_encoder:
            with torch.no_grad():
                teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
                    input_ids,
                    attention_mask=src_mask,
                    output_hidden_states=True)
            if self.hparams.alpha_encoder_loss > 0:
                loss_encoder = self.calc_mse_loss(enc_outputs,
                                                  teacher_enc_outputs,
                                                  src_mask)

            hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state,
                                                 teacher_enc_hid,
                                                 self.hparams.e_layer_to_copy)

        teacher_enc_outputs = (enc_outputs, )
        assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)

        with torch.no_grad():
            tloss, tlogits, tdec_hidden, _ = self.teacher(
                input_ids,
                attention_mask=src_mask,
                encoder_outputs=teacher_enc_outputs,
                decoder_input_ids=decoder_input_ids,
                lm_labels=tgt_ids,
                output_hidden_states=True,
            )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(
            dec_mask, lm_logits, tlogits)
        if self.alpha_hid > 0:
            hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden,
                                                 tdec_hidden,
                                                 self.hparams.d_matches)

        blended_loss = (self.alpha_ce * loss_ce +
                        self.alpha_mlm * student_lm_loss +
                        self.hparams.alpha_encoder_loss * loss_encoder +
                        self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
Example #14
    def _step(self, batch: dict) -> tuple:
        """Compute the loss for a batch"""
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch[
            "attention_mask"], batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)

        # noinspection PyCallingNonCallable
        student_outputs = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            output_attentions=False,
            use_cache=False,
        )
        lm_logits = student_outputs["logits"]

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                                       labels.view(-1))
        else:
            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs,
                labels,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)

        teacher_enc_outputs = student_outputs["encoder_last_hidden_state"]  # use this unless self.different_base_models
        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
        if self.different_encoder:  # compute encoder hidden state loss
            all_teacher_encoder_outputs = self.teacher.get_encoder()(
                input_ids,
                attention_mask=src_mask,
                output_hidden_states=self.do_calc_hidden_loss,
            )
            if self.different_base_models:
                teacher_enc_outputs = all_teacher_encoder_outputs[
                    "last_hidden_state"]
            elif self.do_calc_hidden_loss:
                hid_loss_enc = self.calc_hidden_loss(
                    src_mask,
                    student_outputs["encoder_hidden_states"],
                    all_teacher_encoder_outputs["hidden_states"],
                    self.e_matches,
                    normalize_hidden=self.hparams.normalize_hidden,
                )

        teacher_outputs = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=(teacher_enc_outputs, ),
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            use_cache=False,  # since we are not passing labels, never let this default to True
        )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits,
                                    teacher_outputs["logits"])
        if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask,
                student_outputs["decoder_hidden_states"],
                teacher_outputs["decoder_hidden_states"],
                self.d_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

        blended_loss = (self.alpha_ce * loss_ce +
                        self.alpha_mlm * student_lm_loss +
                        self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
Example #15
    def _step(self, batch):
        # assert is_frozen(self.teacher) copied_decoder_layers
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch[
            "attention_mask"], batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)

        # noinspection PyCallingNonCallable
        lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            output_attentions=False,
            use_cache=False,
        )

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                                       labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs,
                labels,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)

        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
        if self.different_encoder:  # compute encoder hidden state loss
            with torch.no_grad():
                teacher_enc_hid = self.teacher.get_encoder()(
                    input_ids,
                    attention_mask=src_mask,
                    output_hidden_states=True,
                    return_dict=True).hidden_states

            hid_loss_enc = self.calc_hidden_loss(
                src_mask,
                enc_hidden_state,
                teacher_enc_hid,
                self.e_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

        with torch.no_grad():
            outputs = self.teacher(
                input_ids,
                attention_mask=src_mask,
                encoder_outputs=(enc_outputs, ),
                decoder_input_ids=decoder_input_ids,
                lm_labels=labels,
                output_hidden_states=True,
                return_dict=True,
            )
            tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
        if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask,
                dec_hidden,
                tdec_hidden,
                self.d_matches,
                normalize_hidden=self.hparams.normalize_hidden)

        blended_loss = (self.alpha_ce * loss_ce +
                        self.alpha_mlm * student_lm_loss +
                        self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
Example #16
    def _valid_step(self, batch: dict) -> dict:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]

        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids,
                       attention_mask=src_mask,
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False)
        lm_logits = outputs["logits"]

        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            assert lm_logits.shape[-1] == self.vocab_size

            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                               tgt_ids.view(-1))
            p_tensor = torch.all(torch.max(lm_logits, 2)[1] == tgt_ids, 1)
            acc = p_tensor.sum() / sum(p_tensor.shape)
            lm_loss, ti_loss = loss + 0., loss + 0.

            # ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id, reduction="none")
            # assert lm_logits.shape[-1] == self.vocab_size
            # batch_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

            # lm_mask = (batch["loss_labels"].eq(0).to(torch.float)[:, None] * torch.ones_like(tgt_ids) * (1 - tgt_ids.eq(pad_token_id).to(torch.float))).view(-1)
            # ti_mask = (batch["loss_labels"].eq(1).to(torch.float)[:, None] * torch.ones_like(tgt_ids) * (1 - tgt_ids.eq(pad_token_id).to(torch.float))).view(-1)

            # lm_loss = torch.sum(batch_loss * lm_mask) / (torch.sum(lm_mask) + 1e-20)
            # ti_loss = torch.sum(batch_loss * ti_mask) / (torch.sum(ti_mask) + 1e-20)
            # loss = torch.sum(batch_loss * (lm_mask + ti_mask)) / (torch.sum(lm_mask + ti_mask) + 1e-20)
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                tgt_ids,
                self.hparams.label_smoothing,
                ignore_index=pad_token_id)
            # sequence-level exact-match accuracy, mirroring the branch above
            p_tensor = torch.all(torch.max(lm_logits, 2)[1] == tgt_ids, 1)
            acc = p_tensor.sum() / sum(p_tensor.shape)
            lm_loss, ti_loss = loss + 0., loss + 0.

        # print(src_ids)
        # print(batch.keys())
        # print(batch["ids"].cpu().numpy())
        # print(batch["loss_labels"].cpu().numpy())
        # for i in range(3):
        #     print(self.tokenizer.convert_ids_to_tokens(batch["input_ids"].cpu().numpy()[i]))
        #     print(self.tokenizer.convert_ids_to_tokens(decoder_input_ids.cpu().numpy()[i]))
        #     print(self.tokenizer.convert_ids_to_tokens(batch["labels"].cpu().numpy()[i]))
        #     # print(self.tokenizer.unk_token, self.tokenizer.pad_token, self.tokenizer.eos_token, self.tokenizer.bos_token, self.tokenizer.cls_token, self.tokenizer.mask_token)
        #     print("="*10)
        # print("="*30)
        # print(loss, lm_loss, ti_loss)
        # exit()

        return {'loss': loss, 'acc': acc}