def compute_loss(
    self,
    feat_out,
    feat_out_post,
    eos_out,
    feat_tgt,
    eos_tgt,
    tgt_lens,
    reduction="mean",
):
    """Compute masked feature-regression and end-of-sequence losses.

    Only frames within each sequence's true length (per ``tgt_lens``)
    contribute. The L1 and MSE terms each sum a pre-net and a post-net
    component; the EOS term is BCE-with-logits re-weighted by
    ``self.bce_pos_weight`` to counter the rarity of positive labels.

    Returns:
        (l1_loss, mse_loss, eos_loss) tuple of scalar/elementwise
        tensors depending on ``reduction``.
    """
    # Boolean mask selecting valid (non-padded) frames.
    valid = lengths_to_mask(tgt_lens)
    eos_logits = eos_out[valid].squeeze()
    eos_labels = eos_tgt[valid]
    tgt_feats = feat_tgt[valid]
    pre_feats = feat_out[valid]
    post_feats = feat_out_post[valid]

    l1_loss = (
        F.l1_loss(pre_feats, tgt_feats, reduction=reduction)
        + F.l1_loss(post_feats, tgt_feats, reduction=reduction)
    )
    mse_loss = (
        F.mse_loss(pre_feats, tgt_feats, reduction=reduction)
        + F.mse_loss(post_feats, tgt_feats, reduction=reduction)
    )
    eos_loss = F.binary_cross_entropy_with_logits(
        eos_logits,
        eos_labels,
        pos_weight=torch.tensor(self.bce_pos_weight),
        reduction=reduction,
    )
    return l1_loss, mse_loss, eos_loss
def forward(self, model, sample, reduce=True):
    """Run the model and compute the training loss.

    The base loss comes from ``self.compute_loss`` (label-smoothed
    cross-entropy); when ``self.ctc_weight > 0`` an auxiliary CTC loss
    is added, scaled by that weight.

    Returns:
        (loss, sample_size, logging_output) as expected by fairseq.
    """
    net_output = model(**sample["net_input"])
    loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

    ctc_loss = torch.tensor(0.0).type_as(loss)
    if self.ctc_weight > 0.0:
        ctc_lprobs, ctc_lens = model.get_ctc_output(net_output, sample)
        ctc_tgt, ctc_tgt_lens = model.get_ctc_target(sample)
        # F.ctc_loss wants the targets flattened into one 1-D tensor.
        flat_tgt = ctc_tgt.masked_select(lengths_to_mask(ctc_tgt_lens))
        ctc_reduction = "sum" if reduce else "none"
        ctc_loss = self.ctc_weight * F.ctc_loss(
            ctc_lprobs,
            flat_tgt,
            ctc_lens,
            ctc_tgt_lens,
            reduction=ctc_reduction,
            zero_infinity=True,
        )
        loss += ctc_loss

    sample_size = (
        sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
    )
    logging_output = {
        "loss": utils.item(loss.data),
        "nll_loss": utils.item(nll_loss.data),
        "ctc_loss": utils.item(ctc_loss.data),
        "ntokens": sample["ntokens"],
        "nsentences": sample["target"].size(0),
        "sample_size": sample_size,
    }
    if self.report_accuracy:
        n_correct, total = self.compute_accuracy(model, net_output, sample)
        logging_output["n_correct"] = utils.item(n_correct.data)
        logging_output["total"] = utils.item(total.data)
    return loss, sample_size, logging_output
def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
    """Compute the FastSpeech2-style training loss.

    Combines an L1 spectrogram loss (pre-net plus, when present,
    post-net output) with MSE losses on the log-duration, pitch and
    energy predictors, and an optional CTC term weighted by
    ``self.ctc_weight``.

    Returns:
        (loss, sample_size, logging_output) as expected by fairseq.
    """
    net_input = sample["net_input"]
    src_tokens = net_input["src_tokens"]
    src_lens = net_input["src_lengths"]
    tgt_lens = sample["target_lengths"]

    _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
        src_tokens=src_tokens,
        src_lengths=src_lens,
        prev_output_tokens=net_input["prev_output_tokens"],
        incremental_state=None,
        target_lengths=tgt_lens,
        speaker=sample["speaker"],
        durations=sample["durations"],
        pitches=sample["pitches"],
        energies=sample["energies"],
    )

    src_mask = lengths_to_mask(src_lens)
    tgt_mask = lengths_to_mask(tgt_lens)

    # Variance-predictor targets are aligned to source tokens.
    pitches = sample["pitches"][src_mask]
    energies = sample["energies"][src_mask]
    pitch_out = pitch_out[src_mask]
    energy_out = energy_out[src_mask]

    feat_out = _feat_out[tgt_mask]
    feat = sample["target"][tgt_mask]
    l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
    if _feat_out_post is not None:
        l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)

    pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
    energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)

    log_dur_out = log_dur_out[src_mask]
    dur = sample["durations"].float()
    # Match the predictor's dtype when training in fp16.
    if log_dur_out.type().endswith(".HalfTensor"):
        dur = dur.half()
    log_dur = torch.log(dur + 1)[src_mask]
    dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)

    ctc_loss = torch.tensor(0.0).type_as(l1_loss)
    if self.ctc_weight > 0.0:
        lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
        lprobs = lprobs.transpose(0, 1)  # T x B x C
        # Flatten source tokens: they serve as CTC targets here.
        flat_src_tokens = src_tokens.masked_select(lengths_to_mask(src_lens))
        ctc_loss = self.ctc_weight * F.ctc_loss(
            lprobs,
            flat_src_tokens,
            tgt_lens,
            src_lens,
            reduction=reduction,
            zero_infinity=True,
        )

    loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

    sample_size = sample["nsentences"]
    logging_output = {
        "loss": utils.item(loss.data),
        "ntokens": sample["ntokens"],
        "nsentences": sample["nsentences"],
        "sample_size": sample_size,
        "l1_loss": utils.item(l1_loss.data),
        "dur_loss": utils.item(dur_loss.data),
        "pitch_loss": utils.item(pitch_loss.data),
        "energy_loss": utils.item(energy_loss.data),
        "ctc_loss": utils.item(ctc_loss.data),
    }
    return loss, sample_size, logging_output
def _get_masks(src_lens, tgt_lens):
    """Return a joint validity mask over (output frame, input token) pairs.

    The result is True where both the target position and the source
    position fall within their respective sequence lengths —
    presumably shaped B x T_out x T_in given the unsqueeze axes.
    """
    src_valid = lengths_to_mask(src_lens)
    tgt_valid = lengths_to_mask(tgt_lens)
    return tgt_valid.unsqueeze(2) & src_valid.unsqueeze(1)
def forward(self, model, sample, reduction="mean"):
    """Compute the Tacotron2-style training loss.

    Builds one-hot EOS targets marking each sequence's final frame,
    runs the model, and sums masked L1/MSE/EOS losses with an optional
    guided-attention loss and an optional CTC term (weighted by
    ``self.ctc_weight``).

    Returns:
        (loss, sample_size, logging_output) as expected by fairseq.
    """
    bsz, max_len, _ = sample["target"].size()
    feat_tgt = sample["target"]

    # EOS target: 1.0 exactly at each sequence's last valid frame.
    feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
    positions = torch.arange(max_len).to(sample["target"].device)
    positions = positions.view(1, max_len).expand(bsz, -1)
    eos_tgt = (positions == (feat_len - 1)).float()

    src_tokens = sample["net_input"]["src_tokens"]
    src_lens = sample["net_input"]["src_lengths"]
    tgt_lens = sample["target_lengths"]

    feat_out, eos_out, extra = model(
        src_tokens=src_tokens,
        src_lengths=src_lens,
        prev_output_tokens=sample["net_input"]["prev_output_tokens"],
        incremental_state=None,
        target_lengths=tgt_lens,
        speaker=sample["speaker"],
    )

    l1_loss, mse_loss, eos_loss = self.compute_loss(
        extra["feature_out"],
        feat_out,
        eos_out,
        feat_tgt,
        eos_tgt,
        tgt_lens,
        reduction,
    )

    attn_loss = torch.tensor(0.0).type_as(l1_loss)
    if self.guided_attn is not None:
        attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)

    ctc_loss = torch.tensor(0.0).type_as(l1_loss)
    if self.ctc_weight > 0.0:
        lprobs = model.get_normalized_probs(
            (feat_out, eos_out, extra), log_probs=True
        )
        lprobs = lprobs.transpose(0, 1)  # T x B x C
        # Flatten source tokens: they serve as CTC targets here.
        flat_src_tokens = src_tokens.masked_select(lengths_to_mask(src_lens))
        ctc_loss = self.ctc_weight * F.ctc_loss(
            lprobs,
            flat_src_tokens,
            tgt_lens,
            src_lens,
            reduction=reduction,
            zero_infinity=True,
        )

    loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss

    sample_size = (
        sample["nsentences"] if self.sentence_avg else sample["ntokens"]
    )
    logging_output = {
        "loss": utils.item(loss.data),
        "ntokens": sample["ntokens"],
        "nsentences": sample["nsentences"],
        "sample_size": sample_size,
        "l1_loss": utils.item(l1_loss.data),
        "mse_loss": utils.item(mse_loss.data),
        "eos_loss": utils.item(eos_loss.data),
        "attn_loss": utils.item(attn_loss.data),
        "ctc_loss": utils.item(ctc_loss.data),
    }
    return loss, sample_size, logging_output