def _source_mask(self, ilens): """Make masks for self-attention. Args: ilens (LongTensor or List): Batch of lengths (B,). Returns: Tensor: Mask tensor for self-attention. dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (including 1.2) Examples: >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2)
def _source_mask(self, ilens): """Make masks for self-attention. Examples: >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)
def forward(self, xs_pad, ilens, ys_pad): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loss value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder xs_pad = xs_pad[:, : max(ilens)] # for data parallel src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) self.hs_pad = hs_pad hs_mask = hs_mask.transpose(1,2) hs_mask = hs_mask.repeat(1,1,256).type(torch.FloatTensor).to(hs_pad.device) hs_pad_masked = hs_pad * hs_mask logging.warning("hs_pad_masked.size()==>" + str(hs_pad_masked.size())) att_vec = self.att(hs_pad_masked) # att_vec = self.att(hs_pad) pred_pad = self.output(att_vec).unsqueeze(1) logging.warning("att_vec.size()==>" + str(att_vec.size())) logging.warning("pred_pad.size()==>" + str(pred_pad.size())) # compute loss self.loss = self.criterion(pred_pad, ys_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_pad, ignore_label=self.ignore_id) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(self.acc, loss_data) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss
def _target_mask(self, olens):
    """Make masks for masked self-attention.

    Examples:
        >>> olens = [5, 3]
        >>> self._target_mask(olens)
        tensor([[[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1]],
                [[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
    s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
    return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)
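# A self-contained sketch of the target mask above: subsequent_mask is the
# lower-triangular causal mask, intersected here with the padding mask so a
# decoder position can attend only to earlier, non-padded positions.
import torch

olens = torch.tensor([5, 3])
y_masks = torch.arange(5).unsqueeze(0) < olens.unsqueeze(1)             # (2, 5) padding mask
s_masks = torch.tril(torch.ones(5, 5, dtype=torch.bool)).unsqueeze(0)   # (1, 5, 5) causal
t_mask = y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)        # (2, 5, 5)
# row t of t_mask[b] allows attention to positions <= t that are not padding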
def forward(
    self,
    x: torch.Tensor,
    x_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Calculate forward propagation.

    Args:
        x (Tensor): Input index tensor (B, T_text).
        x_lengths (Tensor): Length tensor (B,).

    Returns:
        Tensor: Encoded hidden representation (B, attention_dim, T_text).
        Tensor: Projected mean tensor (B, attention_dim, T_text).
        Tensor: Projected scale tensor (B, attention_dim, T_text).
        Tensor: Mask tensor for input tensor (B, 1, T_text).

    """
    x = self.emb(x) * math.sqrt(self.attention_dim)
    x_mask = (
        make_non_pad_mask(x_lengths)
        .to(
            device=x.device,
            dtype=x.dtype,
        )
        .unsqueeze(1)
    )
    # the encoder assumes the channel-last layout (B, T_text, attention_dim)
    # but the mask shape should be (B, 1, T_text)
    x, _ = self.encoder(x, x_mask)

    # convert to the channel-first layout (B, attention_dim, T_text)
    x = x.transpose(1, 2)
    stats = self.proj(x) * x_mask
    m, logs = stats.split(stats.size(1) // 2, dim=1)

    return x, m, logs, x_mask
def forward(self, after_outs, before_outs, logits, ys, labels, olens):
    """Calculate forward propagation.

    Args:
        after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
        before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
        logits (Tensor): Batch of stop logits (B, Lmax).
        ys (Tensor): Batch of padded target features (B, Lmax, odim).
        labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
        olens (LongTensor): Batch of the lengths of each target (B,).

    Returns:
        Tensor: L1 loss value.
        Tensor: Mean square error loss value.
        Tensor: Binary cross entropy loss value.

    """
    # perform masking for padded values
    if self.use_masking:
        mask = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
        ys = ys.masked_select(mask)
        after_outs = after_outs.masked_select(mask)
        before_outs = before_outs.masked_select(mask)
        labels = labels.masked_select(mask[:, :, 0])
        logits = logits.masked_select(mask[:, :, 0])

    # calculate loss
    l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
    mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
    bce_loss = F.binary_cross_entropy_with_logits(
        logits, labels, pos_weight=torch.tensor(self.bce_pos_weight, device=ys.device)
    )

    return l1_loss, mse_loss, bce_loss
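# A toy illustration (assumed shapes, not the real model outputs) of the
# masked_select pattern used by the loss above: the (B, Lmax, 1) mask broadcasts
# over odim and returns a flat 1-D tensor of valid elements, so padded frames
# never contribute to the loss.
import torch
import torch.nn.functional as F

B, Lmax, odim = 2, 4, 3
olens = torch.tensor([4, 2])
ys = torch.randn(B, Lmax, odim)
outs = torch.randn(B, Lmax, odim)
mask = (torch.arange(Lmax).unsqueeze(0) < olens.unsqueeze(1)).unsqueeze(-1)
l1 = F.l1_loss(outs.masked_select(mask), ys.masked_select(mask))
print(outs.masked_select(mask).shape)  # torch.Size([18]) = (4 + 2) * 3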
def forward(self, student, teacher, ilens, typ=None):
    """Calculate forward propagation.

    Args:
        student (list of Tensor): outputs from student
        teacher (list of Tensor): outputs from teacher
        ilens (list): input sequence lengths
        typ (str): type of loss; if given, L1 is used instead of MSE

    Returns:
        Tensor: loss value.

    """
    # apply mask to remove padded part
    loss = 0.0
    for s, t in zip(student, teacher):
        masks = make_non_pad_mask(ilens).unsqueeze(-1).to(s.device)
        s = s.masked_select(masks)
        t = t.masked_select(masks)
        # calculate loss
        if typ is not None:
            tloss = torch.nn.L1Loss(reduction="mean")(s, t)
        else:
            tloss = self.mse_criterion(s, t)
        loss = loss + tloss
    return loss
def forward(self, cbhg_outs, spcs, olens):
    """Calculate forward propagation.

    Args:
        cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
        spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
        olens (LongTensor): Batch of the lengths of each sequence (B,).

    Returns:
        Tensor: L1 loss value.
        Tensor: Mean square error loss value.

    """
    # perform masking for padded values
    if self.use_masking:
        mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
        spcs = spcs.masked_select(mask)
        cbhg_outs = cbhg_outs.masked_select(mask)

    # calculate loss
    cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
    cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)

    return cbhg_l1_loss, cbhg_mse_loss
def __call__(self, batch, device=torch.device("cpu")): """Convert a given batch. Args: batch (list): List of ndarrays. device (torch.device): The device to be send. Returns: dict: Dict of converted tensors. """ # batch should be located in list assert len(batch) == 1 xs, ys, spembs, extras, f0, energy = batch[0] # get list of lengths (must be tensor for DataParallel) ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).long().to(device) olens = torch.from_numpy(np.array([y.shape[0] for y in ys])).long().to(device) # reorganize ys # print(ilens, ilens.shape) if extras is not None: new_ys = [] non_zero_lens_mask = [] ds_nonzeros = [] if self.append_position: position = [] for ib in range(ilens.shape[0]): # reorganize ys: divide ys with different phn/char, remove the phn/char with zero length ys_ib = ys[ib] ds_ib = extras[ib] # durations for new_ys_ib = [] non_zero_lens_mask_ib = [] for it in range(ilens[ib]): start = int(sum(ds_ib[:it]))*self.reduction_factor end = int(sum(ds_ib[:it+1]))*self.reduction_factor if start != end: ys_split = torch.from_numpy(ys_ib[start:end]).float() new_ys_ib.append(ys_split) # l x odim non_zero_lens_mask_ib.append(1) # if length > 0, then mask=1 ds_nonzeros.append(int(ds_ib[it]*self.reduction_factor)) if self.append_position: position.append(torch.FloatTensor(list(range(end-start)))/(end-start)) else: non_zero_lens_mask_ib.append(0) # if length = 0, then mask=0 new_ys.extend(new_ys_ib) non_zero_lens_mask.append(torch.tensor(non_zero_lens_mask_ib)) new_ys = pad_list(new_ys,0).to(device) # #-of-phn x Lmax x odim non_zero_lens_mask = pad_list(non_zero_lens_mask, 0) xs = pad_list([torch.from_numpy(x).long() for x in xs], 0).to(device) ys = pad_list([torch.from_numpy(y).float() for y in ys], 0).to(device) if self.use_fe_condition: new_f0 = pad_list([torch.from_numpy(f00).float() for f00 in f0], 0) # B x Imax x 1 new_en = pad_list([torch.from_numpy(enn).float() for enn in energy], 0) # B x Imax x 1 # prepare dict new_batch = { "xs": xs, "ilens": ilens, "ys": ys, "olens": olens, } # load speaker embedding if spembs is not None: spembs = torch.from_numpy(np.array(spembs)).float() new_batch["spembs"] = spembs.to(device) # load second target if extras is not None: extras = pad_list([torch.from_numpy(extra).float() for extra in extras], 0) new_batch["extras"] = extras.to(device) new_batch["new_ys"] = new_ys new_batch["non_zero_lens_mask"] = non_zero_lens_mask new_batch["ds_nonzeros"] = torch.tensor(ds_nonzeros).to(device) new_batch["output_masks"] = make_non_pad_mask(new_batch["ds_nonzeros"]).to(device) # #-of-phn x new_Lmax assert new_batch["new_ys"].shape[1] == new_batch["output_masks"].shape[1] if self.append_position: position = pad_list(position, 0) new_batch['position'] = position assert position.shape[0]==new_ys.shape[0] if self.use_fe_condition: new_batch['f0'] = new_f0 new_batch['energy'] = new_en return new_batch
def forward(self, xs, ilens, ys, olens, spembs=None, extras=None,
            new_ys=None, non_zero_lens_mask=None, ds_nonzeros=None,
            output_masks=None, position=None, f0=None, energy=None,
            *args, **kwargs):
    """Calculate forward propagation.

    Args:
        xs (Tensor): Batch of padded character ids (B, Tmax).
        ilens (LongTensor): Batch of lengths of each input batch (B,).
        ys (Tensor): Batch of padded target features (B, Lmax, odim).
        olens (LongTensor): Batch of the lengths of each target (B,).
        spembs (Tensor, optional): Batch of speaker embedding vectors
            (B, spk_embed_dim).
        extras (Tensor, optional): Batch of ground-truth durations (B, Tmax, 1).
        new_ys (Tensor): Reorganized mel-spectrograms.
        non_zero_lens_mask (Tensor): Mask of phonemes with non-zero length.
        ds_nonzeros (Tensor): Non-zero durations.
        output_masks (Tensor): Output masks for the reorganized targets.
        position (Tensor): Position values for each phoneme.
        f0 (Tensor): Pitch values.
        energy (Tensor): Energy values.

    Returns:
        Tensor: Loss value.

    """
    # remove unnecessary padded part (for multi-gpus)
    max_in = max(ilens)
    max_out = max(olens)
    if max_in != xs.shape[1]:
        xs = xs[:, :max_in]
    if max_out != ys.shape[1]:
        ys = ys[:, :max_out]

    # calculate FCL-taco2-enc outputs
    hs, hlens = self.enc(xs, ilens)
    if self.spk_embed_dim is not None:
        spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
        hs = torch.cat([hs, spembs], dim=-1)

    # duration predictor loss calculation
    ds = extras.squeeze(-1)
    d_masks = make_pad_mask(ilens).to(xs.device)
    d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
    duration_masks = make_non_pad_mask(ilens).to(ys.device)
    d_outs = d_outs.masked_select(duration_masks)
    duration_loss = self.duration_criterion(d_outs, ds.masked_select(duration_masks))

    if self.use_fe_condition:
        expand_hs = hs
        fe_masks = d_masks
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(expand_hs.detach(), fe_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(expand_hs, fe_masks.unsqueeze(-1))  # B x Tmax x 1
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(expand_hs.detach(), fe_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(expand_hs, fe_masks.unsqueeze(-1))  # B x Tmax x 1
        pitch_loss = self.prosody_criterion(p_outs, f0, ilens)
        energy_loss = self.prosody_criterion(e_outs, energy, ilens)
        p_embs = self.pitch_embed(f0.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(energy.transpose(1, 2)).transpose(1, 2)
    else:
        p_embs = None
        e_embs = None

    ylens = olens
    after_outs, before_outs = self.dec(hs, hlens, ds, ys, ylens, new_ys,
                                       non_zero_lens_mask, ds_nonzeros,
                                       output_masks, position, f0, energy,
                                       p_embs, e_embs)

    # modify mod part of groundtruth
    if self.reduction_factor > 1:
        olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
        max_out = max(olens)
        ys = ys[:, :max_out]

    # calculate taco2 loss
    l1_loss, mse_loss = self.taco2_loss(after_outs, before_outs, ys, olens)
    loss = l1_loss + mse_loss + duration_loss

    report_keys = [
        {"l1_loss": l1_loss.item()},
        {"mse_loss": mse_loss.item()},
        {"dur_loss": duration_loss.item()},
    ]
    if self.use_fe_condition:
        prosody_weight = 1.0
        loss = loss + prosody_weight * (pitch_loss + energy_loss)
        report_keys += [
            {"pitch_loss": pitch_loss.item()},
            {"energy_loss": energy_loss.item()},
        ]
    report_keys += [{"loss": loss.item()}]
    self.reporter.report(report_keys)

    return loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value

    """
    # 1. encoder
    xs_pad = xs_pad[:, : max(ilens)]

    if "transformer" in self.etype:
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)

        batchsize = xs_pad.size(0)
        inputs = xs_pad.unsqueeze(1)
        logging.info("inputs:{}".format(inputs.shape))
        logging.info("src_mask:{}".format(src_mask.shape))

        inputs_length = []
        if src_mask is not None:
            for mask in src_mask.tolist():
                inputs_length.append(mask[0].count(True))
            for i in range(batchsize):
                inputs_s = inputs[i].unsqueeze(0)[:, :, 0 : inputs_length[i], :]
                core_out = self.conv(inputs_s)
                inputs_length[i] = core_out.size(2)
            inputs_length = torch.as_tensor(inputs_length)
        else:
            core_out = self.conv(inputs)
            inputs_length = core_out.size(2)
            inputs_length = torch.as_tensor(inputs_length)
        logging.info("inputs_length:{}".format(inputs_length))

        # block 1
        # the input shape of Conv2d is 4-dim: (bsz * c * l * w)
        # the input shape of Conv1d is 3-dim: (bsz * c * l)
        # the input shape of transformer is 3-dim: (l * bsz * c)
        # conv output format: (bsz * c * t * d)
        inputs = self.conv(inputs)

        # the conv frontend yields a batch of 16-channel feature maps over all
        # time steps; merge the 16 channels of each time step into one
        # self-attention input of shape (batch, 16, dim)
        inputs = inputs.permute(2, 0, 1, 3)
        logging.info("inputs:{}".format(inputs.shape))
        # allocate on the input device to avoid a CPU/GPU mismatch
        merge = torch.zeros(inputs.size(0), batchsize, 512, device=xs_pad.device)
        for t in range(inputs.size(0)):  # max_length
            merge[t] = self.clayers(inputs[t], None)[0].reshape(batchsize, 512)
        xs = merge.permute(1, 0, 2)

        if inputs_length.dim() == 0:
            masks = make_non_pad_mask([inputs_length]).unsqueeze(-2)
        else:
            masks = make_non_pad_mask(inputs_length.tolist()).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs, masks)
    else:
        hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)
    self.hs_pad = hs_pad

    # 1.5. transducer preparation related
    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(ys_pad, hs_mask)

    # 2. decoder
    if "transformer" in self.dtype:
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
    else:
        if self.rnnt_mode == "rnnt":
            pred_pad = self.dec(hs_pad, ys_in_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
    self.pred_pad = pred_pad

    # 3. loss computation
    loss = self.criterion(pred_pad, target, pred_len, target_len)
    self.loss = loss
    loss_data = float(self.loss)

    # 4. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss
def forward(
    self,
    after_outs: torch.Tensor,
    before_outs: torch.Tensor,
    d_outs: torch.Tensor,
    p_outs: torch.Tensor,
    e_outs: torch.Tensor,
    ys: torch.Tensor,
    ds: torch.Tensor,
    ps: torch.Tensor,
    es: torch.Tensor,
    ilens: torch.Tensor,
    olens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Calculate forward propagation.

    Args:
        after_outs (Tensor): Batch of outputs after postnets (B, T_feats, odim).
        before_outs (Tensor): Batch of outputs before postnets (B, T_feats, odim).
        d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
        p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
        e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
        ys (Tensor): Batch of target features (B, T_feats, odim).
        ds (LongTensor): Batch of durations (B, T_text).
        ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
        es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
        ilens (LongTensor): Batch of the lengths of each input (B,).
        olens (LongTensor): Batch of the lengths of each target (B,).

    Returns:
        Tensor: L1 loss value.
        Tensor: Duration predictor loss value.
        Tensor: Pitch predictor loss value.
        Tensor: Energy predictor loss value.

    """
    # apply mask to remove padded part
    if self.use_masking:
        out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
        before_outs = before_outs.masked_select(out_masks)
        if after_outs is not None:
            after_outs = after_outs.masked_select(out_masks)
        ys = ys.masked_select(out_masks)
        duration_masks = make_non_pad_mask(ilens).to(ys.device)
        d_outs = d_outs.masked_select(duration_masks)
        ds = ds.masked_select(duration_masks)
        pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ys.device)
        p_outs = p_outs.masked_select(pitch_masks)
        e_outs = e_outs.masked_select(pitch_masks)
        ps = ps.masked_select(pitch_masks)
        es = es.masked_select(pitch_masks)

    # calculate loss
    l1_loss = self.l1_criterion(before_outs, ys)
    if after_outs is not None:
        l1_loss += self.l1_criterion(after_outs, ys)
    duration_loss = self.duration_criterion(d_outs, ds)
    pitch_loss = self.mse_criterion(p_outs, ps)
    energy_loss = self.mse_criterion(e_outs, es)

    # make weighted mask and apply it
    if self.use_weighted_masking:
        out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
        out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
        out_weights /= ys.size(0) * ys.size(2)
        duration_masks = make_non_pad_mask(ilens).to(ys.device)
        duration_weights = (
            duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
        )
        duration_weights /= ds.size(0)

        # apply weight
        l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
        duration_loss = (
            duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
        )
        pitch_masks = duration_masks.unsqueeze(-1)
        pitch_weights = duration_weights.unsqueeze(-1)
        pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
        energy_loss = energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()

    return l1_loss, duration_loss, pitch_loss, energy_loss
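# A small worked example (toy shapes) of the weighted-masking branch above:
# each utterance's frames are weighted by 1 / olen, so short and long
# utterances contribute equally; this only makes sense when the criterion was
# built with reduction="none" so the elementwise loss can be re-reduced here.
import torch

B, T, odim = 2, 4, 3
olens = torch.tensor([4, 2])
out_masks = (torch.arange(T).unsqueeze(0) < olens.unsqueeze(1)).unsqueeze(-1)
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
out_weights /= B * odim
loss_elem = torch.rand(B, T, odim)  # stand-in for an unreduced L1/MSE loss
loss = loss_elem.mul(out_weights).masked_select(out_masks).sum()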
def inference(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: Optional[torch.Tensor] = None,
    feats_lengths: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    dur: Optional[torch.Tensor] = None,
    noise_scale: float = 0.667,
    noise_scale_dur: float = 0.8,
    alpha: float = 1.0,
    max_len: Optional[int] = None,
    use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run inference.

    Args:
        text (Tensor): Input text index tensor (B, T_text,).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
        feats_lengths (Tensor): Feature length tensor (B,).
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
        dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
            skip the prediction of durations (i.e., teacher forcing).
        noise_scale (float): Noise scale parameter for flow.
        noise_scale_dur (float): Noise scale parameter for duration predictor.
        alpha (float): Alpha parameter to control the speed of generated speech.
        max_len (Optional[int]): Maximum length of acoustic feature sequence.
        use_teacher_forcing (bool): Whether to use teacher forcing.

    Returns:
        Tensor: Generated waveform tensor (B, T_wav).
        Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
        Tensor: Duration tensor (B, T_text).

    """
    # encoder
    x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
    g = None
    if self.spks is not None:
        # (B, global_channels, 1)
        g = self.global_emb(sids.view(-1)).unsqueeze(-1)
    if self.spk_embed_dim is not None:
        # (B, global_channels, 1)
        g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
        if g is None:
            g = g_
        else:
            g = g + g_
    if self.langs is not None:
        # (B, global_channels, 1)
        g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
        if g is None:
            g = g_
        else:
            g = g + g_

    if use_teacher_forcing:
        # forward posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)

        # forward flow
        z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)

        # monotonic alignment search
        s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
        # (B, 1, T_text)
        neg_x_ent_1 = torch.sum(
            -0.5 * math.log(2 * math.pi) - logs_p,
            [1],
            keepdim=True,
        )
        # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
        neg_x_ent_2 = torch.matmul(
            -0.5 * (z_p**2).transpose(1, 2),
            s_p_sq_r,
        )
        # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
        neg_x_ent_3 = torch.matmul(
            z_p.transpose(1, 2),
            (m_p * s_p_sq_r),
        )
        # (B, 1, T_text)
        neg_x_ent_4 = torch.sum(
            -0.5 * (m_p**2) * s_p_sq_r,
            [1],
            keepdim=True,
        )
        # (B, T_feats, T_text)
        neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
        # (B, 1, T_feats, T_text)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        # monotonic attention weight: (B, 1, T_feats, T_text)
        attn = self.maximum_path(
            neg_x_ent,
            attn_mask.squeeze(1),
        ).unsqueeze(1)
        dur = attn.sum(2)  # (B, 1, T_text)

        # forward decoder with random segments
        wav = self.decoder(z * y_mask, g=g)
    else:
        # duration
        if dur is None:
            logw = self.duration_predictor(
                x,
                x_mask,
                g=g,
                inverse=True,
                noise_scale=noise_scale_dur,
            )
            w = torch.exp(logw) * x_mask * alpha
            dur = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
        y_mask = make_non_pad_mask(y_lengths).unsqueeze(1).to(text.device)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = self._generate_path(dur, attn_mask)

        # expand the length to match with the feature sequence
        # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
        m_p = torch.matmul(
            attn.squeeze(1),
            m_p.transpose(1, 2),
        ).transpose(1, 2)
        # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
        logs_p = torch.matmul(
            attn.squeeze(1),
            logs_p.transpose(1, 2),
        ).transpose(1, 2)

        # decoder
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, inverse=True)
        wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)

    return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
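# A hedged sketch of what the duration-to-attention expansion above computes:
# with integer durations, multiplying the 0/1 path matrix by the token-level
# statistics equals repeating each token's vector `dur` times along time.
import torch

dur = torch.tensor([2, 1, 3])                   # durations for T_text = 3 tokens
m_p = torch.randn(1, 4, 3)                      # (B, H, T_text) token-level means
T_feats = int(dur.sum())
cum = torch.cumsum(dur, dim=0)                  # [2, 3, 6]
frame = torch.arange(T_feats).unsqueeze(1)      # (T_feats, 1)
path = ((frame >= (cum - dur).unsqueeze(0)) & (frame < cum.unsqueeze(0))).float()
expanded = torch.matmul(path.unsqueeze(0), m_p.transpose(1, 2)).transpose(1, 2)
assert torch.allclose(expanded, torch.repeat_interleave(m_p, dur, dim=2))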
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # CTC forward
    ys = [y[y != self.ignore_id] for y in ys_pad]
    y_len = max([len(y) for y in ys])
    ys_pad = ys_pad[:, :y_len]
    self.hs_pad = hs_pad
    cer_ctc = None
    batch_size = xs_pad.size(0)
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)

    # trigger mask
    start_time = time.time()

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    return self.loss, loss_ctc_data, loss_att_data, self.acc
def inference(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: Optional[torch.Tensor] = None,
    feats_lengths: Optional[torch.Tensor] = None,
    pitch: Optional[torch.Tensor] = None,
    energy: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run inference.

    Args:
        text (Tensor): Input text index tensor (B, T_text,).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        pitch (Tensor): Pitch tensor (B, T_feats, 1).
        energy (Tensor): Energy tensor (B, T_feats, 1).
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
        use_teacher_forcing (bool): Whether to use teacher forcing.

    Returns:
        Tensor: Generated waveform tensor (B, T_wav).
        Tensor: Duration tensor (B, T_text).

    """
    # forward encoder
    x_masks = self._source_mask(text_lengths)
    hs, _ = self.encoder(text, x_masks)  # (B, T_text, adim)

    # integrate with GST
    if self.use_gst:
        style_embs = self.gst(feats)
        hs = hs + style_embs.unsqueeze(1)

    # integrate with SID and LID embeddings
    if self.spks is not None:
        sid_embs = self.sid_emb(sids.view(-1))
        hs = hs + sid_embs.unsqueeze(1)
    if self.langs is not None:
        lid_embs = self.lid_emb(lids.view(-1))
        hs = hs + lid_embs.unsqueeze(1)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        hs = self._integrate_with_spk_embed(hs, spembs)

    h_masks = make_pad_mask(text_lengths).to(hs.device)
    if use_teacher_forcing:
        # forward alignment module and obtain duration, averaged pitch, energy
        log_p_attn = self.alignment_module(hs, feats, h_masks)
        d_outs, _ = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
        p_outs = average_by_duration(
            d_outs, pitch.squeeze(-1), text_lengths, feats_lengths
        ).unsqueeze(-1)
        e_outs = average_by_duration(
            d_outs, energy.squeeze(-1), text_lengths, feats_lengths
        ).unsqueeze(-1)
    else:
        # forward duration predictor and variance predictors
        p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1))
        e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1))
        d_outs = self.duration_predictor.inference(hs, h_masks)

    p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
    e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
    hs = hs + e_embs + p_embs

    # upsampling
    if feats_lengths is not None:
        h_masks = make_non_pad_mask(feats_lengths).to(hs.device)
    else:
        h_masks = None
    d_masks = make_non_pad_mask(text_lengths).to(d_outs.device)
    hs = self.length_regulator(hs, d_outs, h_masks, d_masks)  # (B, T_feats, adim)

    # forward decoder
    if feats_lengths is not None:
        h_masks = self._source_mask(feats_lengths)
    else:
        h_masks = None
    zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)

    # forward generator
    wav = self.generator(zs.transpose(1, 2))

    return wav.squeeze(1), d_outs
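# For intuition, a minimal hard length regulator (a sketch, not the module used
# above, which may perform soft/Gaussian upsampling): every token's hidden
# vector is repeated according to its duration; zero-duration tokens drop out.
import torch


def length_regulator_sketch(hs, ds):
    # hs: (B, T_text, adim), ds: (B, T_text) integer durations
    return [torch.repeat_interleave(h, d, dim=0) for h, d in zip(hs, ds)]


hs = torch.randn(1, 3, 8)
ds = torch.tensor([[1, 0, 2]])
out = length_regulator_sketch(hs, ds)[0]
print(out.shape)  # torch.Size([3, 8]): 1 + 0 + 2 frames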
def forward(
    self,
    d_outs: torch.Tensor,
    ds: torch.Tensor,
    p_outs: torch.Tensor,
    ps: torch.Tensor,
    e_outs: torch.Tensor,
    es: torch.Tensor,
    ilens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Calculate forward propagation.

    Args:
        d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
        ds (LongTensor): Batch of durations (B, T_text).
        p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
        ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
        e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
        es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
        ilens (LongTensor): Batch of the lengths of each input (B,).

    Returns:
        Tensor: Duration predictor loss value.
        Tensor: Pitch predictor loss value.
        Tensor: Energy predictor loss value.

    """
    # apply mask to remove padded part
    if self.use_masking:
        duration_masks = make_non_pad_mask(ilens).to(ds.device)
        d_outs = d_outs.masked_select(duration_masks)
        ds = ds.masked_select(duration_masks)
        pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device)
        p_outs = p_outs.masked_select(pitch_masks)
        e_outs = e_outs.masked_select(pitch_masks)
        ps = ps.masked_select(pitch_masks)
        es = es.masked_select(pitch_masks)

    # calculate loss
    duration_loss = self.duration_criterion(d_outs, ds)
    pitch_loss = self.mse_criterion(p_outs, ps)
    energy_loss = self.mse_criterion(e_outs, es)

    # make weighted mask and apply it
    if self.use_weighted_masking:
        duration_masks = make_non_pad_mask(ilens).to(ds.device)
        duration_weights = (
            duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
        )
        duration_weights /= ds.size(0)

        # apply weight
        duration_loss = (
            duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
        )
        pitch_masks = duration_masks.unsqueeze(-1)
        pitch_weights = duration_weights.unsqueeze(-1)
        pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
        energy_loss = energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()

    return duration_loss, pitch_loss, energy_loss
def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
    """Make masks for self-attention.

    Args:
        ilens (LongTensor): Batch of lengths (B,).

    Returns:
        Tensor: Mask tensor for self-attention (B, 1, Tmax).

    """
    x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
    return x_masks.unsqueeze(-2)
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    if self.attention_enc_type in [
        "self_attn_dynamic_span",
        "self_attn_adaptive_span",
        "self_attn_adaptive_span2",
        "self_attn_fixed_span2",
        "self_attn_dynamic_span2",
    ]:
        for layer in self.encoder.encoders:
            layer.self_attn.clamp_param()
    if self.attention_dec_type in [
        "self_attn_dynamic_span",
        "self_attn_adaptive_span",
        "self_attn_adaptive_span2",
        "self_attn_fixed_span2",
        "self_attn_dynamic_span2",
    ]:
        for layer in self.decoder.decoders:
            layer.self_attn.clamp_param()

    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    # xkc09 Span attention size loss computation
    loss_span = 0
    if self.attention_enc_type in [
        "self_attn_dynamic_span",
        "self_attn_adaptive_span",
        "self_attn_adaptive_span2",
        "self_attn_dynamic_span2",
    ]:
        loss_span += sum(
            [layer.self_attn.get_mean_span() for layer in self.encoder.encoders]
        )
    if self.attention_dec_type in [
        "self_attn_dynamic_span",
        "self_attn_adaptive_span",
        "self_attn_adaptive_span2",
        "self_attn_dynamic_span2",
    ]:
        loss_span += sum(
            [layer.self_attn.get_mean_span() for layer in self.decoder.decoders]
        )

    # xkc09 Span attention ratio loss computation
    loss_ratio = 0
    if self.ratio_adaptive:
        # target_ratio = 0.5
        if self.attention_enc_type in [
            "self_attn_adaptive_span2",
            "self_attn_fixed_span2",
            "self_attn_dynamic_span2",
        ]:
            loss_ratio += sum(
                [1 - layer.self_attn.get_mean_ratio() for layer in self.encoder.encoders]
            )
        if self.attention_dec_type in [
            "self_attn_adaptive_span2",
            "self_attn_fixed_span2",
            "self_attn_dynamic_span2",
        ]:
            loss_ratio += sum(
                [1 - layer.self_attn.get_mean_ratio() for layer in self.decoder.decoders]
            )

    if (
        self.attention_enc_type
        in [
            "self_attn_dynamic_span",
            "self_attn_adaptive_span",
            "self_attn_adaptive_span2",
            "self_attn_fixed_span2",
            "self_attn_dynamic_span2",
        ]
        or self.attention_dec_type
        in [
            "self_attn_dynamic_span",
            "self_attn_adaptive_span",
            "self_attn_adaptive_span2",
            "self_attn_fixed_span2",
            "self_attn_dynamic_span2",
        ]
    ):
        if getattr(self, "span_loss_coef", None):
            self.loss += (loss_span + loss_ratio) * self.span_loss_coef

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
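# The mtlalpha interpolation used above (and in the other E2E forwards), in
# isolation: alpha = 0 trains pure attention, alpha = 1 pure CTC. Toy numbers:
import torch

alpha = 0.3
loss_ctc, loss_att = torch.tensor(42.0), torch.tensor(10.0)
loss = alpha * loss_ctc + (1 - alpha) * loss_att  # 0.3 * 42 + 0.7 * 10 = 19.6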
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    pitch: torch.Tensor,
    pitch_lengths: torch.Tensor,
    energy: torch.Tensor,
    energy_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Calculate forward propagation.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        pitch (Tensor): Batch of padded token-averaged pitch (B, T_text, 1).
        pitch_lengths (LongTensor): Batch of pitch lengths (B, T_text).
        energy (Tensor): Batch of padded token-averaged energy (B, T_text, 1).
        energy_lengths (LongTensor): Batch of energy lengths (B, T_text).
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

    Returns:
        Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
        Tensor: Binarization loss ().
        Tensor: Log probability attention matrix (B, T_feats, T_text).
        Tensor: Segments start index tensor (B,).
        Tensor: Predicted duration (B, T_text).
        Tensor: Ground-truth duration obtained from an alignment module (B, T_text).
        Tensor: Predicted pitch (B, T_text, 1).
        Tensor: Ground-truth averaged pitch (B, T_text, 1).
        Tensor: Predicted energy (B, T_text, 1).
        Tensor: Ground-truth averaged energy (B, T_text, 1).

    """
    text = text[:, : text_lengths.max()]  # for data-parallel
    feats = feats[:, : feats_lengths.max()]  # for data-parallel
    pitch = pitch[:, : pitch_lengths.max()]  # for data-parallel
    energy = energy[:, : energy_lengths.max()]  # for data-parallel

    # forward encoder
    x_masks = self._source_mask(text_lengths)
    hs, _ = self.encoder(text, x_masks)  # (B, T_text, adim)

    # integrate with GST
    if self.use_gst:
        style_embs = self.gst(feats)
        hs = hs + style_embs.unsqueeze(1)

    # integrate with SID and LID embeddings
    if self.spks is not None:
        sid_embs = self.sid_emb(sids.view(-1))
        hs = hs + sid_embs.unsqueeze(1)
    if self.langs is not None:
        lid_embs = self.lid_emb(lids.view(-1))
        hs = hs + lid_embs.unsqueeze(1)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        hs = self._integrate_with_spk_embed(hs, spembs)

    # forward alignment module and obtain duration, averaged pitch, energy
    h_masks = make_pad_mask(text_lengths).to(hs.device)
    log_p_attn = self.alignment_module(hs, feats, h_masks)
    ds, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
    ps = average_by_duration(
        ds, pitch.squeeze(-1), text_lengths, feats_lengths
    ).unsqueeze(-1)
    es = average_by_duration(
        ds, energy.squeeze(-1), text_lengths, feats_lengths
    ).unsqueeze(-1)

    # forward duration predictor and variance predictors
    if self.stop_gradient_from_pitch_predictor:
        p_outs = self.pitch_predictor(hs.detach(), h_masks.unsqueeze(-1))
    else:
        p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1))
    if self.stop_gradient_from_energy_predictor:
        e_outs = self.energy_predictor(hs.detach(), h_masks.unsqueeze(-1))
    else:
        e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1))
    d_outs = self.duration_predictor(hs, h_masks)

    # use groundtruth in training
    p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
    e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
    hs = hs + e_embs + p_embs

    # upsampling
    h_masks = make_non_pad_mask(feats_lengths).to(hs.device)
    d_masks = make_non_pad_mask(text_lengths).to(ds.device)
    hs = self.length_regulator(hs, ds, h_masks, d_masks)  # (B, T_feats, adim)

    # forward decoder
    h_masks = self._source_mask(feats_lengths)
    zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)

    # get random segments
    z_segments, z_start_idxs = get_random_segments(
        zs.transpose(1, 2),
        feats_lengths,
        self.segment_size,
    )

    # forward generator
    wav = self.generator(z_segments)

    return (
        wav,
        bin_loss,
        log_p_attn,
        z_start_idxs,
        d_outs,
        ds,
        p_outs,
        ps,
        e_outs,
        es,
    )
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value

    """
    # 1. encoder
    xs_pad = xs_pad[:, : max(ilens)]

    if "custom" in self.etype:
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        _hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        _hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)

    if self.use_aux_task:
        hs_pad, aux_hs_pad = _hs_pad[0], _hs_pad[1]
    else:
        hs_pad, aux_hs_pad = _hs_pad, None

    # 1.5. transducer preparation related
    ys_in_pad, ys_out_pad, target, pred_len, target_len = prepare_loss_inputs(
        ys_pad, hs_mask
    )

    # 2. decoder
    if "custom" in self.dtype:
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
    else:
        pred_pad = self.dec(hs_pad, ys_in_pad)

    z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

    # 3. loss computation
    loss_trans = self.criterion(z, target, pred_len, target_len)

    if self.use_aux_task and aux_hs_pad is not None:
        loss_trans += self.auxiliary_task(
            aux_hs_pad, pred_pad, z, target, pred_len, target_len
        )

    if self.use_aux_ctc:
        if "custom" in self.etype:
            hs_mask = torch.IntTensor([h.size(1) for h in hs_mask]).to(hs_mask.device)
        loss_ctc = self.aux_ctc(hs_pad, hs_mask, ys_pad)
    else:
        loss_ctc = 0

    if self.use_aux_cross_entropy:
        loss_ce = self.aux_cross_entropy(self.aux_decoder_output(pred_pad), ys_out_pad)
    else:
        loss_ce = 0

    loss = (
        self.transducer_weight * loss_trans
        + self.aux_ctc_weight * loss_ctc
        + self.aux_cross_entropy_weight * loss_ce
    )

    self.loss = loss
    loss_data = float(loss)

    # 4. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, num_spkrs, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)  # list: speaker differentiate
    self.hs_pad = hs_pad

    # 2. ctc
    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    assert self.mtlalpha > 0.0
    batch_size = xs_pad.size(0)
    ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
    hs_len = [hs_mask[i].view(batch_size, -1).sum(1) for i in range(self.num_spkrs)]
    loss_ctc_perm = torch.stack(
        [
            self.ctc(
                hs_pad[i // self.num_spkrs].view(batch_size, -1, self.adim),
                hs_len[i // self.num_spkrs],
                ys_pad[i % self.num_spkrs],
            )
            for i in range(self.num_spkrs**2)
        ],
        dim=1,
    )  # (B, num_spkrs^2)
    loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
    logging.info("ctc loss:" + str(float(loss_ctc)))

    # Permute the labels according to loss
    for b in range(batch_size):  # B
        ys_pad[:, b] = ys_pad[min_perm[b], b]  # (num_spkrs, B, Lmax)
    ys_out_len = [
        float(torch.sum(ys_pad[i] != self.ignore_id)) for i in range(self.num_spkrs)
    ]

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    if self.error_calculator is not None:
        cer_ctc = []
        for i in range(self.num_spkrs):
            ys_hat = self.ctc.argmax(hs_pad[i].view(batch_size, -1, self.adim)).data
            cer_ctc.append(
                self.error_calculator(ys_hat.cpu(), ys_pad[i].cpu(), is_ctc=True)
            )
        cer_ctc = sum(map(lambda x: x[0] * x[1], zip(cer_ctc, ys_out_len))) / sum(
            ys_out_len
        )
    else:
        cer_ctc = None

    # 3. forward decoder
    if self.mtlalpha == 1.0:
        loss_att, self.acc, cer, wer = None, None, None, None
    else:
        pred_pad, pred_mask = [None] * self.num_spkrs, [None] * self.num_spkrs
        loss_att, acc = [None] * self.num_spkrs, [None] * self.num_spkrs
        for i in range(self.num_spkrs):
            (
                pred_pad[i],
                pred_mask[i],
                loss_att[i],
                acc[i],
            ) = self.decoder_and_attention(hs_pad[i], hs_mask[i], ys_pad[i], batch_size)

        # 4. compute attention loss
        # The following is just an approximation
        loss_att = sum(map(lambda x: x[0] * x[1], zip(loss_att, ys_out_len))) / sum(
            ys_out_len
        )
        self.acc = sum(map(lambda x: x[0] * x[1], zip(acc, ys_out_len))) / sum(
            ys_out_len
        )

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value

    """
    # 1. encoder
    # if etype is transformer, handle the padding, e.g.:
    #   xs_pad: [8, 393, 83], ilens: [393 * 8]
    #   src_mask: [8, 1, 393], hs_mask: [8, 1, 65]
    if self.etype == "transformer":
        xs_pad = xs_pad[:, : max(ilens)]
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        hs_pad, hs_mask, _ = self.encoder(xs_pad, ilens)
    self.hs_pad = hs_pad

    # 1.5. transducer preparation related
    # ys_in_pad: sos, 1, 2, ..., 0   e.g. [8, 14]
    # target:    1, 2, ...           e.g. [8, 13]
    # pred_len: [8], target_len: [8]
    # ys_out_pad: 1, 2, ..., eos, -1
    ys = [y[y != self.ignore_id] for y in ys_pad]
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(ys_pad, hs_mask)

    # 2. decoder
    # ys_mask: [8, 16, 16]
    if self.dtype == "transformer":
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, pred_att, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    else:
        if self.rnnt_mode == "rnnt":
            pred_pad = self.dec(hs_pad, ys_in_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
    self.pred_pad = pred_pad

    # 3. loss computation
    loss_att = F.cross_entropy(
        pred_att,
        ys_out_pad.view(-1),  # batch x olength
        ignore_index=self.ignore_id,
    )
    # -1: eos, which is removed in the loss computation
    loss_att *= np.mean([len(x) for x in ys_in]) - 1

    loss_rnnt = self.criterion(pred_pad, target, pred_len, target_len)

    beta = self.mtlbeta
    gamma = self.mtlgamma
    self.loss_rnnt = loss_rnnt
    self.loss_att = loss_att
    self.loss = beta * self.loss_rnnt + gamma * self.loss_att
    loss_data = float(self.loss)
    loss_att_data = float(self.loss_att)
    loss_rnnt_data = float(self.loss_rnnt)

    # 4. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, loss_rnnt_data, loss_att_data, cer, wer)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss
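# A hedged sketch of the transducer input preparation above: the decoder input
# is the label sequence with a leading blank/sos, targets keep the raw labels
# right-padded with blank, and the length vectors feed the transducer loss.
import torch

blank = 0
ys = [torch.tensor([3, 5, 2]), torch.tensor([4, 1])]
ys_in = [torch.cat([y.new_tensor([blank]), y]) for y in ys]  # decoder inputs
target_len = torch.tensor([len(y) for y in ys])              # [3, 2]
Lmax = int(target_len.max())
target = torch.stack(
    [torch.cat([y, y.new_full((Lmax - len(y),), blank)]) for y in ys]
)  # (B, Lmax) padded targets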
def forward(self, xs, ilens, ys, olens, spembs=None, *args, **kwargs):
    """Calculate forward propagation.

    Args:
        xs (Tensor): Batch of padded character ids (B, Tmax).
        ilens (LongTensor): Batch of lengths of each input batch (B,).
        ys (Tensor): Batch of padded target features (B, Lmax, odim).
        olens (LongTensor): Batch of the lengths of each target (B,).
        spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

    Returns:
        Tensor: Loss value.

    """
    # remove unnecessary padded part (for multi-gpus)
    xs = xs[:, : max(ilens)]
    ys = ys[:, : max(olens)]

    # forward propagation
    outs, ds, d_outs = self._forward(
        xs, ilens, ys, olens, spembs=spembs, is_inference=False
    )

    # apply mask to remove padded part
    if self.use_masking:
        in_masks = make_non_pad_mask(ilens).to(xs.device)
        d_outs = d_outs.masked_select(in_masks)
        ds = ds.masked_select(in_masks)
        out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
        outs = outs.masked_select(out_masks)
        ys = ys.masked_select(out_masks)

    # calculate loss
    l1_loss = self.criterion(outs, ys)
    duration_loss = self.duration_criterion(d_outs, ds)
    loss = l1_loss + duration_loss
    report_keys = [
        {"l1_loss": l1_loss.item()},
        {"duration_loss": duration_loss.item()},
        {"loss": loss.item()},
    ]

    # report extra information
    if self.use_scaled_pos_enc:
        report_keys += [
            {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
            {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
        ]
    self.reporter.report(report_keys)

    return loss
def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    batch_size = xs_pad.shape[0]
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    if isinstance(self.encoder.embed, EncoderConv2d):
        xs, hs_mask = self.encoder.embed(xs_pad, torch.sum(src_mask, 2).squeeze())
        hs_mask = hs_mask.unsqueeze(1)
    else:
        xs, hs_mask = self.encoder.embed(xs_pad, src_mask)
    if enc_mask is not None:
        enc_mask = enc_mask[:, : hs_mask.shape[2], : hs_mask.shape[2]]
    enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask
    hs_pad, _ = self.encoder.encoders(xs, enc_mask)
    if self.encoder.normalize_before:
        hs_pad = self.encoder.after_norm(hs_pad)

    # CTC forward
    ys = [y[y != self.ignore_id] for y in ys_pad]
    y_len = max([len(y) for y in ys])
    ys_pad = ys_pad[:, :y_len]
    if dec_mask is not None:
        dec_mask = dec_mask[:, : y_len + 1, : hs_pad.shape[1]]
    self.hs_pad = hs_pad
    batch_size = xs_pad.size(0)
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)

    # trigger mask
    hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    return self.loss, loss_ctc_data, loss_att_data, self.acc
def forward(
    self, feats: torch.Tensor, feats_len: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
    """E2E forward.

    Args:
        feats: Feature sequences. (B, F, D_feats)
        feats_len: Feature sequences lengths. (B,)
        labels: Label ID sequences. (B, L)

    Returns:
        loss: Transducer loss value

    """
    # 1. encoder
    feats = feats[:, : max(feats_len)]

    if self.etype == "custom":
        feats_mask = (
            make_non_pad_mask(feats_len.tolist()).to(feats.device).unsqueeze(-2)
        )
        _enc_out, _enc_out_len = self.encoder(feats, feats_mask)
    else:
        _enc_out, _enc_out_len, _ = self.enc(feats, feats_len)

    if self.use_auxiliary_enc_outputs:
        enc_out, aux_enc_out = _enc_out[0], _enc_out[1]
        enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1]
    else:
        enc_out, aux_enc_out = _enc_out, None
        enc_out_len, aux_enc_out_len = _enc_out_len, None

    # 2. decoder
    dec_in = get_decoder_input(labels, self.blank_id, self.ignore_id)

    if self.dtype == "custom":
        self.decoder.set_device(enc_out.device)

        dec_in_mask = target_mask(dec_in, self.blank_id)
        dec_out, _ = self.decoder(dec_in, dec_in_mask)
    else:
        self.dec.set_device(enc_out.device)

        dec_out = self.dec(dec_in)

    # 3. transducer tasks computation
    losses = self.transducer_tasks(
        enc_out,
        aux_enc_out,
        dec_out,
        labels,
        enc_out_len,
        aux_enc_out_len,
    )

    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(enc_out, self.transducer_tasks.get_target())

    self.loss = sum(losses)
    loss_data = float(self.loss)

    if not math.isnan(loss_data):
        self.reporter.report(
            loss_data,
            *[float(loss) for loss in losses],
            cer,
            wer,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)

    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward Transformer encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    if xs_pad.size(1) != ys_pad.size(1):
        if xs_pad.size(1) < ys_pad.size(1):
            ys_pad = ys_pad[:, : xs_pad.size(1)].contiguous()
        else:
            raise ValueError(
                "target size {} is smaller than input size {}".format(
                    ys_pad.size(1), xs_pad.size(1)
                )
            )
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. post-processing layer for target dimension
    if self.outer:
        post_pad = self.poster(hs_pad)
        post_pad = post_pad.view(post_pad.size(0), -1, self.odim)
        if post_pad.size(1) != xs_pad.size(1):
            if post_pad.size(1) < xs_pad.size(1):
                xs_pad = xs_pad[:, : post_pad.size(1)].contiguous()
            else:
                raise ValueError(
                    "target size {} and pred size {} are mismatched".format(
                        xs_pad.size(1), post_pad.size(1)
                    )
                )
        if self.residual:
            post_pad = post_pad + self.matcher_res(xs_pad)
        else:
            post_pad = torch.cat([post_pad, xs_pad], dim=-1)
        pred_pad = self.matcher(post_pad)
    else:
        pred_pad = self.poster(hs_pad)
        pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
    self.pred_pad = pred_pad
    if pred_pad.size(1) != ys_pad.size(1):
        if pred_pad.size(1) < ys_pad.size(1):
            ys_pad = ys_pad[:, : pred_pad.size(1)].contiguous()
        else:
            raise ValueError(
                "target size {} and pred size {} are mismatched".format(
                    ys_pad.size(1), pred_pad.size(1)
                )
            )

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_pad, ignore_label=self.ignore_id
    )

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(pred_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(pred_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

    # 4. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        pass
        # logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
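# A toy check (assumed dimensions) of the poster reshape above: a (B, T, adim)
# encoder output is projected to r * odim features per frame and then viewed as
# (B, T * r, odim), trading feature width for extra output frames.
import torch

B, T, adim, odim, r = 2, 5, 16, 4, 2
hs = torch.randn(B, T, adim)
poster = torch.nn.Linear(adim, odim * r)  # hypothetical stand-in for self.poster
pred = poster(hs).view(B, -1, odim)
print(pred.shape)  # torch.Size([2, 10, 4])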
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Extract target language ID
    tgt_lang_ids = None
    if self.multilingual:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning

    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    # replace <sos> with target language ID
    if self.replace_sos:
        ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

    # 3. compute ST loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # 4. compute corpus-level bleu in a mini-batch
    if self.training:
        self.bleu = None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # 5. compute auxiliary ASR loss
    loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr(
        hs_pad, hs_mask, ys_pad_src
    )

    # 6. compute auxiliary MT loss
    loss_mt, acc_mt = 0.0, None
    if self.mt_weight > 0:
        loss_mt, acc_mt = self.forward_mt(ys_pad_src, ys_in_pad, ys_out_pad, ys_mask)

    asr_ctc_weight = self.mtlalpha
    self.loss = (
        (1 - self.asr_weight - self.mt_weight) * loss_att
        + self.asr_weight
        * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att)
        + self.mt_weight * loss_mt
    )
    loss_asr_data = float(
        asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
    )
    loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
    loss_st_data = float(loss_att)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_asr_data,
            loss_mt_data,
            loss_st_data,
            acc_asr,
            acc_mt,
            self.acc,
            cer_ctc,
            cer,
            wer,
            self.bleu,
            loss_data,
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = mask_uniform(
        ys_pad, self.mask_token, self.eos, self.ignore_id
    )
    ys_mask = square_mask(ys_in_pad, self.eos)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # 4. compute ctc loss
    loss_ctc, cer_ctc = None, None
    loss_intermediate_ctc = 0.0
    if self.mtlalpha > 0:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        # for visualization
        if not self.training:
            self.ctc.softmax(hs_pad)

        if self.intermediate_ctc_weight > 0 and self.intermediate_ctc_layers:
            for hs_intermediate in hs_intermediates:
                # assuming hs_intermediates and hs_pad have the same length / padding
                loss_inter = self.ctc(
                    hs_intermediate.view(batch_size, -1, self.adim), hs_len, ys_pad
                )
                loss_intermediate_ctc += loss_inter
            loss_intermediate_ctc /= len(self.intermediate_ctc_layers)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None or self.decoder is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    else:
        self.loss = (
            alpha * loss_ctc
            + self.intermediate_ctc_weight * loss_intermediate_ctc
            + (1 - alpha - self.intermediate_ctc_weight) * loss_att
        )
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
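# A minimal sketch of how the intermediate CTC term above is formed: one CTC
# loss per tapped encoder layer, averaged over the tapped layers, then
# interpolated with the final CTC and attention losses. `ctc_fn` and the
# arguments are placeholders standing in for `self.ctc` and its inputs.
def intermediate_ctc_total(ctc_fn, hs_intermediates, hs_len, ys_pad):
    losses = [ctc_fn(h, hs_len, ys_pad) for h in hs_intermediates]
    return sum(losses) / len(losses)  # average over tapped layers

# With weights alpha and w = intermediate_ctc_weight, the final objective is
# alpha * loss_ctc + w * loss_intermediate_ctc + (1 - alpha - w) * loss_att,
# so alpha + w should not exceed 1.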
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    if self.decoder is not None:
        if self.decoder_mode == "maskctc":
            ys_in_pad, ys_out_pad = mask_uniform(
                ys_pad, self.mask_token, self.eos, self.ignore_id
            )
            ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
        else:
            ys_in_pad, ys_out_pad = add_sos_eos(
                ys_pad, self.sos, self.eos, self.ignore_id
            )
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )
    else:
        loss_att = None
        self.acc = None

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats
    # 4. compute ctc loss
    cer_ctc = None
    if self.mtlalpha == 0.0:
        loss_ctc = None
    else:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        # for visualization
        if not self.training:
            self.ctc.softmax(hs_pad)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None or self.decoder is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
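# An illustrative sketch (not ESPnet's implementation) of the Mask-CTC style
# `mask_uniform` used in the "maskctc" branch of step 2: a uniformly sampled
# number of target tokens is replaced with a <mask> token in the decoder
# input, and the loss target keeps only those masked positions (everything
# else becomes ignore_id). The eos handling of the real helper is omitted.
import torch

def mask_uniform_sketch(ys_pad, mask_token, ignore_id):
    ys_in = ys_pad.clone()
    ys_out = ys_pad.new_full(ys_pad.size(), ignore_id)
    for b in range(ys_pad.size(0)):
        valid = (ys_pad[b] != ignore_id).nonzero(as_tuple=True)[0]
        # sample how many tokens to mask, uniformly in [1, len(valid)]
        n_mask = torch.randint(1, len(valid) + 1, (1,)).item()
        picked = valid[torch.randperm(len(valid))[:n_mask]]
        ys_in[b, picked] = mask_token      # decoder sees <mask> here
        ys_out[b, picked] = ys_pad[b, picked]  # loss computed only here
    return ys_in, ys_out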
def _target_mask(self, olens):
    """Make masks for masked self-attention.

    Combines the non-padding mask with a lower-triangular subsequent mask so
    that each position can attend only to non-padded, non-future positions.
    """
    y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
    s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
    return y_masks.unsqueeze(-2) & s_masks
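# A self-contained sketch of how the padding mask and the causal (subsequent)
# mask combine in `_target_mask`. The two helpers below are simplified
# stand-ins for the ESPnet versions, written only to make the broadcasting
# visible.
import torch

def make_non_pad_mask_sketch(lengths):
    max_len = max(lengths)
    seq = torch.arange(max_len).unsqueeze(0)          # (1, T)
    return seq < torch.tensor(lengths).unsqueeze(1)   # (B, T), True = real token

def subsequent_mask_sketch(size):
    # lower-triangular causal mask, True = position may be attended to
    return torch.tril(torch.ones(size, size, dtype=torch.bool))  # (T, T)

y_masks = make_non_pad_mask_sketch([5, 3])         # (2, 5)
s_masks = subsequent_mask_sketch(5).unsqueeze(0)   # (1, 5, 5)
combined = y_masks.unsqueeze(-2) & s_masks         # broadcasts to (2, 5, 5)
# combined[1] keeps the causal pattern but zeros out columns 3 and 4,
# so the shorter sequence never attends to its padding.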