Example #1
0
    def compute_loss(
        self,
        feat_out,
        feat_out_post,
        eos_out,
        feat_tgt,
        eos_tgt,
        tgt_lens,
        reduction="mean",
    ):
        """Compute L1, MSE and stop-token (EOS) losses over non-padded frames.

        Args:
            feat_out: decoder feature output before the postnet; indexed as
                (B, T, ...) by the mask below -- shape assumed, TODO confirm.
            feat_out_post: feature output after the postnet, same shape.
            eos_out: stop-token logits, presumably (B, T, 1).
            feat_tgt: target features, same shape as ``feat_out``.
            eos_tgt: binary stop-token targets, (B, T).
            tgt_lens: (B,) valid target lengths used to mask out padding.
            reduction: passed straight through to the functional losses.

        Returns:
            Tuple ``(l1_loss, mse_loss, eos_loss)``.
        """
        mask = lengths_to_mask(tgt_lens)
        # Squeeze only the trailing singleton dim: a bare .squeeze() collapses
        # to a 0-dim tensor when exactly one frame survives the mask, breaking
        # the shape match with _eos_tgt in the BCE call below.
        _eos_out = eos_out[mask].squeeze(dim=-1)
        _eos_tgt = eos_tgt[mask]
        _feat_tgt = feat_tgt[mask]
        _feat_out = feat_out[mask]
        _feat_out_post = feat_out_post[mask]

        # Both the pre- and post-postnet outputs are trained against the
        # same target with L1 and MSE terms.
        l1_loss = F.l1_loss(_feat_out, _feat_tgt,
                            reduction=reduction) + F.l1_loss(
                                _feat_out_post, _feat_tgt, reduction=reduction)
        mse_loss = F.mse_loss(
            _feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
                _feat_out_post, _feat_tgt, reduction=reduction)
        eos_loss = F.binary_cross_entropy_with_logits(
            _eos_out,
            _eos_tgt,
            # Build pos_weight on the logits' device: a CPU tensor here
            # errors against CUDA inputs during GPU training.
            pos_weight=torch.tensor(self.bce_pos_weight,
                                    device=_eos_out.device),
            reduction=reduction,
        )
        return l1_loss, mse_loss, eos_loss
    def forward(self, model, sample, reduce=True):
        """Compute the label-smoothed loss plus an optional CTC auxiliary loss.

        Returns ``(loss, sample_size, logging_output)`` where ``sample_size``
        is the quantity the trainer normalizes gradients by.
        """
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(
            model, net_output, sample, reduce=reduce)

        ctc_loss = torch.tensor(0.0).type_as(loss)
        if self.ctc_weight > 0.0:
            # F.ctc_loss expects targets concatenated into a single 1-D
            # tensor, so drop the padded positions via the length mask.
            lprobs, input_lens = model.get_ctc_output(net_output, sample)
            targets, target_lens = model.get_ctc_target(sample)
            flat_targets = targets.masked_select(lengths_to_mask(target_lens))
            ctc_loss = self.ctc_weight * F.ctc_loss(
                lprobs,
                flat_targets,
                input_lens,
                target_lens,
                reduction="sum" if reduce else "none",
                zero_infinity=True,
            )
        loss += ctc_loss

        nsentences = sample["target"].size(0)
        sample_size = nsentences if self.sentence_avg else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data),
            "nll_loss": utils.item(nll_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output
Example #3
0
    def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
        """Compute the FastSpeech2-style training loss for one batch.

        Combines an L1 feature loss (pre- and optional post-net outputs),
        MSE losses on the pitch/energy/log-duration predictions, and an
        optional CTC loss weighted by ``self.ctc_weight``.

        Returns ``(loss, sample_size, logging_output)``.
        """
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]
        _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
            durations=sample["durations"],
            pitches=sample["pitches"],
            energies=sample["energies"],
        )

        # Masks selecting the non-padded positions on each side; reuse the
        # locals instead of re-indexing sample (they are the same tensors).
        src_mask = lengths_to_mask(src_lens)
        tgt_mask = lengths_to_mask(tgt_lens)

        pitches, energies = sample["pitches"], sample["energies"]
        pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
        energy_out, energies = energy_out[src_mask], energies[src_mask]

        feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
        l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
        if _feat_out_post is not None:
            # The post-net output (when present) trains against the same target.
            l1_loss += F.l1_loss(_feat_out_post[tgt_mask],
                                 feat,
                                 reduction=reduction)

        pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
        energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)

        log_dur_out = log_dur_out[src_mask]
        dur = sample["durations"].float()
        # Match the model's half precision; dtype comparison replaces the
        # dated string-based ``.type().endswith(".HalfTensor")`` check.
        dur = dur.half() if log_dur_out.dtype == torch.half else dur
        # +1 keeps log() finite for zero durations.
        log_dur = torch.log(dur + 1)[src_mask]
        dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)

        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.ctc_weight > 0.0:
            lprobs = model.get_normalized_probs((_feat_out, ), log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            # src_mask from above is reused; the original rebuilt it
            # identically here.
            src_tokens_flat = src_tokens.masked_select(src_mask)
            ctc_loss = (F.ctc_loss(
                lprobs,
                src_tokens_flat,
                tgt_lens,
                src_lens,
                reduction=reduction,
                zero_infinity=True,
            ) * self.ctc_weight)

        loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

        sample_size = sample["nsentences"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "dur_loss": utils.item(dur_loss.data),
            "pitch_loss": utils.item(pitch_loss.data),
            "energy_loss": utils.item(energy_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output
Example #4
0
 def _get_masks(src_lens, tgt_lens):
     """Combine source and target length masks into one joint validity mask.

     Presumably (B, T_out, T_in): position (b, t, s) is True iff both t and
     s are within their respective lengths -- depends on lengths_to_mask
     returning (B, max_len) masks; verify against its definition.
     """
     source_mask = lengths_to_mask(src_lens)
     target_mask = lengths_to_mask(tgt_lens)
     # Broadcast the two 2-D masks against each other with an elementwise AND.
     return target_mask.unsqueeze(2) & source_mask.unsqueeze(1)
Example #5
0
    def forward(self, model, sample, reduction="mean"):
        """Compute the Tacotron2-style training loss for one batch.

        Sums the L1 + MSE feature losses, the stop-token (EOS) BCE loss,
        an optional guided-attention loss, and an optional CTC loss.

        Returns ``(loss, sample_size, logging_output)``.
        """
        feat_tgt = sample["target"]
        bsz, max_len, _ = feat_tgt.size()
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]

        # EOS target: 1.0 exactly at the last valid frame of each utterance,
        # 0.0 everywhere else.
        feat_len = tgt_lens.view(bsz, 1).expand(-1, max_len)
        eos_tgt = torch.arange(max_len).to(feat_tgt.device)
        eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
        eos_tgt = (eos_tgt == (feat_len - 1)).float()

        feat_out, eos_out, extra = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
        )

        # extra["feature_out"] is passed as the pre-postnet output and
        # feat_out as the post-postnet output.
        l1_loss, mse_loss, eos_loss = self.compute_loss(
            extra["feature_out"], feat_out, eos_out,
            feat_tgt, eos_tgt, tgt_lens, reduction,
        )

        attn_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.guided_attn is not None:
            attn_loss = self.guided_attn(
                extra["attn"], src_lens, tgt_lens, reduction)

        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.ctc_weight > 0.0:
            lprobs = model.get_normalized_probs(
                (feat_out, eos_out, extra), log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            # CTC targets must be a flat 1-D tensor without padding.
            src_tokens_flat = src_tokens.masked_select(
                lengths_to_mask(src_lens))
            ctc_loss = self.ctc_weight * F.ctc_loss(
                lprobs,
                src_tokens_flat,
                tgt_lens,
                src_lens,
                reduction=reduction,
                zero_infinity=True,
            )

        loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss

        sample_size = (sample["nsentences"]
                       if self.sentence_avg else sample["ntokens"])
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "mse_loss": utils.item(mse_loss.data),
            "eos_loss": utils.item(eos_loss.data),
            "attn_loss": utils.item(attn_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output