    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                    dtype=torch.uint8 in PyTorch < 1.2
                    dtype=torch.bool in PyTorch >= 1.2

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)
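For reference, the mask above comes from espnet's make_non_pad_mask helper; below is a minimal, dependency-free sketch of what it computes (the _sketch name is ours, and we assume ilens is a list of Python ints):

import torch

def make_non_pad_mask_sketch(ilens):
    # True inside each sequence, False at padded positions: (B, Tmax)
    seq_range = torch.arange(max(ilens)).unsqueeze(0)  # (1, Tmax)
    seq_lens = torch.tensor(ilens).unsqueeze(1)        # (B, 1)
    return seq_range < seq_lens

print(make_non_pad_mask_sketch([5, 3]).unsqueeze(-2).shape)  # torch.Size([2, 1, 5])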
Example #2
    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)
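A usage sketch for this square variant: the (B, Tmax, Tmax) mask lines up with an attention score matrix, so padded query/key pairs can be suppressed before the softmax (the tensors below are illustrative):

import torch

x_masks = torch.tensor([[1, 1, 1, 1, 1],
                        [1, 1, 1, 0, 0]], dtype=torch.bool)  # (B, Tmax)
square = x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)       # (B, Tmax, Tmax)
scores = torch.randn(2, 5, 5)
# a large negative value instead of -inf keeps fully padded rows finite
attn = torch.softmax(scores.masked_fill(~square, -1e9), dim=-1)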
Example #3
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        hs_mask = hs_mask.transpose(1, 2)
        # expand the (B, T, 1) mask across the hidden dimension (256 here)
        hs_mask = hs_mask.repeat(1, 1, 256).float()
        hs_pad_masked = hs_pad * hs_mask
        logging.debug("hs_pad_masked.size()==>" + str(hs_pad_masked.size()))
        att_vec = self.att(hs_pad_masked)

        pred_pad = self.output(att_vec).unsqueeze(1)
        logging.debug("att_vec.size()==>" + str(att_vec.size()))
        logging.debug("pred_pad.size()==>" + str(pred_pad.size()))
        # compute loss
        self.loss = self.criterion(pred_pad, ys_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_pad,
                               ignore_label=self.ignore_id)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(self.acc, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Example #4
    def _target_mask(self, olens):
        """Make masks for masked self-attention.

        Examples:
            >>> olens = [5, 3]
            >>> self._target_mask(olens)
            tensor([[[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1]],
                    [[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
        return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)
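espnet's subsequent_mask is essentially a lower-triangular matrix; a standalone sketch of the combination above (our own minimal re-implementation, not the library API):

import torch

def subsequent_mask_sketch(size):
    # (size, size) lower-triangular causal mask
    return torch.tril(torch.ones(size, size)).bool()

y_masks = torch.tensor([[1, 1, 1, 1, 1],
                        [1, 1, 1, 0, 0]], dtype=torch.bool)  # (B, Lmax)
s_masks = subsequent_mask_sketch(5).unsqueeze(0)             # (1, Lmax, Lmax)
tgt_mask = y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)
print(tgt_mask.int())  # matches the docstring example above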
Example #5
    def forward(
        self,
        x: torch.Tensor,
        x_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input index tensor (B, T_text).
            x_lengths (Tensor): Length tensor (B,).

        Returns:
            Tensor: Encoded hidden representation (B, attention_dim, T_text).
            Tensor: Projected mean tensor (B, attention_dim, T_text).
            Tensor: Projected scale tensor (B, attention_dim, T_text).
            Tensor: Mask tensor for input tensor (B, 1, T_text).

        """
        x = self.emb(x) * math.sqrt(self.attention_dim)
        x_mask = (
            make_non_pad_mask(x_lengths)
            .to(
                device=x.device,
                dtype=x.dtype,
            )
            .unsqueeze(1)
        )
        # the encoder assumes channel-last input (B, T_text, attention_dim),
        # but the mask shape should be (B, 1, T_text)
        x, _ = self.encoder(x, x_mask)

        # convert the channel first (B, attention_dim, T_text)
        x = x.transpose(1, 2)
        stats = self.proj(x) * x_mask
        m, logs = stats.split(stats.size(1) // 2, dim=1)

        return x, m, logs, x_mask
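Downstream, the (m, logs) pair typically parameterizes a diagonal Gaussian prior, as in VITS-style models; a hedged sketch of drawing a sample from these statistics, with made-up sizes:

import torch

B, H, T = 2, 192, 50  # illustrative sizes only
m, logs = torch.randn(B, H, T), torch.randn(B, H, T)
x_mask = torch.ones(B, 1, T)
# z ~ N(m, exp(logs)^2), zeroed outside the valid text region
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask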
Example #6
    def forward(self, after_outs, before_outs, logits, ys, labels, olens):
        """Calculate forward propagation.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits (Tensor): Batch of stop logits (B, Lmax).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.
            Tensor: Binary cross entropy loss value.

        """
        # perform masking for padded values
        if self.use_masking:
            mask = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            ys = ys.masked_select(mask)
            after_outs = after_outs.masked_select(mask)
            before_outs = before_outs.masked_select(mask)
            labels = labels.masked_select(mask[:, :, 0])
            logits = logits.masked_select(mask[:, :, 0])

        # calculate loss
        l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
        mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
        bce_loss = F.binary_cross_entropy_with_logits(logits,
                                                      labels,
                                                      pos_weight=torch.tensor(
                                                          self.bce_pos_weight,
                                                          device=ys.device))

        return l1_loss, mse_loss, bce_loss
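Note that masked_select flattens its input, so the losses above are averaged over non-padded elements only. A toy check of that behavior:

import torch

ys = torch.randn(2, 4, 3)  # (B, Lmax, odim)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool).unsqueeze(-1)
print(ys.masked_select(mask).shape)  # torch.Size([15]) = (3 + 2) * 3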
Example #7
    def forward(self, student, teacher, ilens, typ=None):
        """Calculate forward propagation.
        Args:
            student (list of Tensor): outputs from student
            teacher (list of Tensor): outputs from teacher
            ilens (list): input sequence lengths
            typ (str): loss type; any non-None value selects L1, otherwise MSE (L2) is used
        Returns:
            Tensor: loss value.
        """
        # apply mask to remove padded part
        loss = 0.0
        for s, t in zip(student, teacher):
            masks = make_non_pad_mask(ilens).unsqueeze(-1).to(s.device)
            s = s.masked_select(masks)
            t = t.masked_select(masks)
            # calculate loss
            if typ is not None:
                tloss = torch.nn.L1Loss(reduction='mean')(s, t)
            else:
                tloss = self.mse_criterion(s, t)
            loss = loss + tloss

        return loss
Example #8
    def forward(self, cbhg_outs, spcs, olens):
        """Calculate forward propagation.

        Args:
            cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
            spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
            olens (LongTensor): Batch of the lengths of each sequence (B,).

        Returns:
            Tensor: L1 loss value
            Tensor: Mean square error loss value.

        """
        # perform masking for padded values
        if self.use_masking:
            mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
            spcs = spcs.masked_select(mask)
            cbhg_outs = cbhg_outs.masked_select(mask)

        # calculate loss
        cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
        cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)

        return cbhg_l1_loss, cbhg_mse_loss
Example #9
    def __call__(self, batch, device=torch.device("cpu")):
        """Convert a given batch.

        Args:
            batch (list): List of ndarrays.
            device (torch.device): The device to be send.

        Returns:
            dict: Dict of converted tensors.

        """
        # batch should be located in list
        assert len(batch) == 1
        xs, ys, spembs, extras, f0, energy = batch[0]

        # get list of lengths (must be tensor for DataParallel)
        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).long().to(device)
        olens = torch.from_numpy(np.array([y.shape[0] for y in ys])).long().to(device)

        # reorganize ys
        if extras is not None:
            new_ys = []
            non_zero_lens_mask = []
            ds_nonzeros = []
            if self.append_position:
                position = []
            for ib in range(ilens.shape[0]):
                # reorganize ys: divide ys with different phn/char, remove the phn/char with zero length
                ys_ib = ys[ib]
                ds_ib = extras[ib]  # durations for each phn/char
                new_ys_ib = []
                non_zero_lens_mask_ib = []
                for it in range(ilens[ib]):
                    start = int(sum(ds_ib[:it]))*self.reduction_factor
                    end = int(sum(ds_ib[:it+1]))*self.reduction_factor
                    if start != end:
                        ys_split = torch.from_numpy(ys_ib[start:end]).float()
                        new_ys_ib.append(ys_split) # l x odim
                        non_zero_lens_mask_ib.append(1) # if length > 0, then mask=1
                        ds_nonzeros.append(int(ds_ib[it]*self.reduction_factor))
                        if self.append_position:
                            position.append(torch.FloatTensor(list(range(end-start)))/(end-start))
                    else:
                        non_zero_lens_mask_ib.append(0) # if length = 0, then mask=0
                
                new_ys.extend(new_ys_ib)
                non_zero_lens_mask.append(torch.tensor(non_zero_lens_mask_ib))


            new_ys = pad_list(new_ys, 0).to(device)  # #-of-phn x Lmax x odim
            non_zero_lens_mask = pad_list(non_zero_lens_mask, 0)

        xs = pad_list([torch.from_numpy(x).long() for x in xs], 0).to(device)
        ys = pad_list([torch.from_numpy(y).float() for y in ys], 0).to(device)
        if self.use_fe_condition:
            new_f0 = pad_list([torch.from_numpy(f00).float() for f00 in f0], 0) # B x Imax x 1
            new_en = pad_list([torch.from_numpy(enn).float() for enn in energy], 0) # B x Imax x 1

        # prepare dict
        new_batch = {
            "xs": xs,
            "ilens": ilens,
            "ys": ys,
            "olens": olens,
        }

        # load speaker embedding
        if spembs is not None:
            spembs = torch.from_numpy(np.array(spembs)).float()
            new_batch["spembs"] = spembs.to(device)

        # load second target
        if extras is not None:
            extras = pad_list([torch.from_numpy(extra).float() for extra in extras], 0)
            new_batch["extras"] = extras.to(device)
            new_batch["new_ys"] = new_ys
            new_batch["non_zero_lens_mask"] = non_zero_lens_mask
            new_batch["ds_nonzeros"] = torch.tensor(ds_nonzeros).to(device) 
            new_batch["output_masks"] = make_non_pad_mask(new_batch["ds_nonzeros"]).to(device) # #-of-phn x new_Lmax
            assert new_batch["new_ys"].shape[1] == new_batch["output_masks"].shape[1]
            if self.append_position:
                position = pad_list(position, 0)
                new_batch['position'] = position
                assert position.shape[0]==new_ys.shape[0]
            if self.use_fe_condition:
                new_batch['f0'] = new_f0
                new_batch['energy'] = new_en
                
        return new_batch
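pad_list, used throughout this converter, stacks variable-length tensors into one padded batch; a minimal stand-alone equivalent (ours, for illustration only):

import torch

def pad_list_sketch(xs, pad_value):
    n_batch, max_len = len(xs), max(x.size(0) for x in xs)
    pad = xs[0].new_full((n_batch, max_len, *xs[0].size()[1:]), pad_value)
    for i, x in enumerate(xs):
        pad[i, :x.size(0)] = x
    return pad

print(pad_list_sketch([torch.ones(3, 2), torch.ones(5, 2)], 0).shape)  # torch.Size([2, 5, 2])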
Example #10
    def forward(self,
                xs,
                ilens,
                ys,
                olens,
                spembs=None,
                extras=None,
                new_ys=None,
                non_zero_lens_mask=None,
                ds_nonzeros=None,
                output_masks=None,
                position=None,
                f0=None,
                energy=None,
                *args,
                **kwargs):
        """Calculate forward propagation.
        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional):
                Batch of groundtruth spectrograms (B, Lmax, spc_dim).
            new_ys (Tensor): reorganized mel-spectrograms
            non_zero_lens_mask (Tensor): mask of phonemes with non-zero duration
            ds_nonzeros (Tensor): durations of the non-zero-length phonemes
            output_masks (Tensor): output masks for the reorganized targets
            position (Tensor): position values for each phoneme
            f0 (Tensor): pitch
            energy (Tensor): energy
        Returns:
            Tensor: Loss value.
        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]

        # calculate FCL-taco2-enc outputs
        hs, hlens = self.enc(xs, ilens)

        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)

        # duration predictor loss cal
        ds = extras.squeeze(-1)
        d_masks = make_pad_mask(ilens).to(xs.device)
        d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
        duration_masks = make_non_pad_mask(ilens).to(ys.device)
        d_outs = d_outs.masked_select(duration_masks)
        duration_loss = self.duration_criterion(
            d_outs, ds.masked_select(duration_masks))

        if self.use_fe_condition:
            expand_hs = hs
            fe_masks = d_masks
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(expand_hs.detach(),
                                              fe_masks.unsqueeze(-1))
            else:
                p_outs = self.pitch_predictor(
                    expand_hs, fe_masks.unsqueeze(-1))  # B x Tmax x 1

            if self.stop_gradient_from_energy_predictor:
                e_outs = self.energy_predictor(expand_hs.detach(),
                                               fe_masks.unsqueeze(-1))
            else:
                e_outs = self.energy_predictor(
                    expand_hs, fe_masks.unsqueeze(-1))  # B x Tmax x 1

            pitch_loss = self.prosody_criterion(p_outs, f0, ilens)
            energy_loss = self.prosody_criterion(e_outs, energy, ilens)
            p_embs = self.pitch_embed(f0.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(energy.transpose(1, 2)).transpose(1, 2)
        else:
            p_embs = None
            e_embs = None

        ylens = olens
        after_outs, before_outs = self.dec(hs, hlens, ds, ys, ylens, new_ys,
                                           non_zero_lens_mask, ds_nonzeros,
                                           output_masks, position, f0, energy,
                                           p_embs, e_embs)

        # modify the mod part of the groundtruth (trim to a multiple of the reduction factor)
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]

        # calculate taco2 loss
        l1_loss, mse_loss = self.taco2_loss(after_outs, before_outs, ys, olens)
        loss = l1_loss + mse_loss + duration_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"mse_loss": mse_loss.item()},
            {"dur_loss": duration_loss.item()},
        ]

        if self.use_fe_condition:
            prosody_weight = 1.0
            loss = loss + prosody_weight * (pitch_loss + energy_loss)
            report_keys += [
                {"pitch_loss": pitch_loss.item()},
                {"energy_loss": energy_loss.item()},
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss
Example #11
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]

        if "transformer" in self.etype:

            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            batchsize = xs_pad.size(0)
            inputs = xs_pad.unsqueeze(1)
            logging.info("inputs:{}".format(inputs.shape))
            logging.info("src_mask:{}".format(src_mask.shape))

            inputs_length = []
            if src_mask is not None:
                for mask in src_mask.tolist():
                    inputs_length.append(mask[0].count(True))

                for i in range(batchsize):
                    inputs_s = inputs[i].unsqueeze(0)[:, :,
                                                      0:inputs_length[i], :]
                    core_out = self.conv(inputs_s)
                    inputs_length[i] = core_out.size(2)

                inputs_length = torch.as_tensor(inputs_length)

            else:
                core_out = self.conv(inputs)
                inputs_length = core_out.size(2)
                inputs_length = torch.as_tensor(inputs_length)
                logging.info("inputs_length:{}".format(inputs_length))

            # block 1
            # the inputs shape of Conv2d is 4-dim of (bsz * c * l * w)
            # the inputs shape of Conv1d is 3-dim of (bsz * c * l)
            # the inputs shape of transformer is 3-dim of (l * bsz * c)
            # conv output format: (bsz * c * t * d)
            inputs = self.conv(inputs)

            # the conv front-end yields 16-channel feature maps at every time step;
            # merge the 16 channels of each timestep into one self-attention input (batch, 16, dim)
            inputs = inputs.permute(2, 0, 1, 3)
            logging.info("inputs:{}".format(inputs.shape))
            merge = torch.zeros(inputs.size(0), batchsize, 512, device=xs_pad.device)

            for t in range(inputs.size(0)):  # max_length
                merge[t] = self.clayers(inputs[t],
                                        None)[0].reshape(batchsize, 512)

            xs = merge.permute(1, 0, 2)

            if inputs_length.dim() == 0:
                masks = make_non_pad_mask([inputs_length]).unsqueeze(-2)
            else:
                masks = make_non_pad_mask(inputs_length.tolist()).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs, masks)
        else:
            hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)
        self.hs_pad = hs_pad

        # 1.5. transducer preparation related
        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if "transformer" in self.dtype:
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            if self.rnnt_mode == "rnnt":
                pred_pad = self.dec(hs_pad, ys_in_pad)
            else:
                pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
        self.pred_pad = pred_pad

        # 3. loss computation
        loss = self.criterion(pred_pad, target, pred_len, target_len)

        self.loss = loss
        loss_data = float(self.loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
Example #12
    def forward(
        self,
        after_outs: torch.Tensor,
        before_outs: torch.Tensor,
        d_outs: torch.Tensor,
        p_outs: torch.Tensor,
        e_outs: torch.Tensor,
        ys: torch.Tensor,
        ds: torch.Tensor,
        ps: torch.Tensor,
        es: torch.Tensor,
        ilens: torch.Tensor,
        olens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, T_feats, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, T_feats, odim).
            d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
            p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
            e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
            ys (Tensor): Batch of target features (B, T_feats, odim).
            ds (LongTensor): Batch of durations (B, T_text).
            ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
            es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
            ilens (LongTensor): Batch of the lengths of each input (B,).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Duration predictor loss value.
            Tensor: Pitch predictor loss value.
            Tensor: Energy predictor loss value.

        """
        # apply mask to remove padded part
        if self.use_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            before_outs = before_outs.masked_select(out_masks)
            if after_outs is not None:
                after_outs = after_outs.masked_select(out_masks)
            ys = ys.masked_select(out_masks)
            duration_masks = make_non_pad_mask(ilens).to(ys.device)
            d_outs = d_outs.masked_select(duration_masks)
            ds = ds.masked_select(duration_masks)
            pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ys.device)
            p_outs = p_outs.masked_select(pitch_masks)
            e_outs = e_outs.masked_select(pitch_masks)
            ps = ps.masked_select(pitch_masks)
            es = es.masked_select(pitch_masks)

        # calculate loss
        l1_loss = self.l1_criterion(before_outs, ys)
        if after_outs is not None:
            l1_loss += self.l1_criterion(after_outs, ys)
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            out_weights = out_masks.float() / out_masks.sum(
                dim=1, keepdim=True).float()
            out_weights /= ys.size(0) * ys.size(2)
            duration_masks = make_non_pad_mask(ilens).to(ys.device)
            duration_weights = (
                duration_masks.float() /
                duration_masks.sum(dim=1, keepdim=True).float())
            duration_weights /= ds.size(0)

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
            duration_loss = (duration_loss.mul(duration_weights).masked_select(
                duration_masks).sum())
            pitch_masks = duration_masks.unsqueeze(-1)
            pitch_weights = duration_weights.unsqueeze(-1)
            pitch_loss = pitch_loss.mul(pitch_weights).masked_select(
                pitch_masks).sum()
            energy_loss = (energy_loss.mul(pitch_weights).masked_select(
                pitch_masks).sum())

        return l1_loss, duration_loss, pitch_loss, energy_loss
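With use_weighted_masking, the criteria are built with reduction="none" so the per-element losses can be reweighted: each frame's weight is 1 / (olen * B * odim), giving every utterance equal influence regardless of length. A self-contained sketch of the L1 branch under that assumption:

import torch

l1 = torch.nn.L1Loss(reduction="none")
ys, outs = torch.randn(2, 4, 3), torch.randn(2, 4, 3)  # (B, Lmax, odim)
olens = torch.tensor([4, 2])
masks = (torch.arange(4).unsqueeze(0) < olens.unsqueeze(1)).unsqueeze(-1)
weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
weights = weights / (ys.size(0) * ys.size(2))  # normalize by batch size and odim
loss = l1(outs, ys).mul(weights).masked_select(masks).sum()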
Example #13
    def inference(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        feats_lengths: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        dur: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (B, T_text,).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
                skip the prediction of durations (i.e., teacher forcing).
            noise_scale (float): Noise scale parameter for flow.
            noise_scale_dur (float): Noise scale parameter for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length of acoustic feature sequence.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Tensor: Generated waveform tensor (B, T_wav).
            Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
            Tensor: Duration tensor (B, T_text).

        """
        # encoder
        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
        g = None
        if self.spks is not None:
            # (B, global_channels, 1)
            g = self.global_emb(sids.view(-1)).unsqueeze(-1)
        if self.spk_embed_dim is not None:
            # (B, global_channels, 1)
            g_ = self.spemb_proj(F.normalize(
                spembs.unsqueeze(0))).unsqueeze(-1)
            if g is None:
                g = g_
            else:
                g = g + g_
        if self.langs is not None:
            # (B, global_channels, 1)
            g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
            if g is None:
                g = g_
            else:
                g = g + g_

        if use_teacher_forcing:
            # forward posterior encoder
            z, m_q, logs_q, y_mask = self.posterior_encoder(feats,
                                                            feats_lengths,
                                                            g=g)

            # forward flow
            z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)

            # monotonic alignment search
            s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
            # (B, 1, T_text)
            neg_x_ent_1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p,
                [1],
                keepdim=True,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2),
                s_p_sq_r,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_3 = torch.matmul(
                z_p.transpose(1, 2),
                (m_p * s_p_sq_r),
            )
            # (B, 1, T_text)
            neg_x_ent_4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r,
                [1],
                keepdim=True,
            )
            # (B, T_feats, T_text)
            neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
            # (B, 1, T_feats, T_text)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(
                y_mask, -1)
            # monotonic attention weight: (B, 1, T_feats, T_text)
            attn = self.maximum_path(
                neg_x_ent,
                attn_mask.squeeze(1),
            ).unsqueeze(1)
            dur = attn.sum(2)  # (B, 1, T_text)

            # forward decoder with random segments
            wav = self.decoder(z * y_mask, g=g)
        else:
            # duration
            if dur is None:
                logw = self.duration_predictor(
                    x,
                    x_mask,
                    g=g,
                    inverse=True,
                    noise_scale=noise_scale_dur,
                )
                w = torch.exp(logw) * x_mask * alpha
                dur = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
            y_mask = make_non_pad_mask(y_lengths).unsqueeze(1).to(text.device)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(
                y_mask, -1)
            attn = self._generate_path(dur, attn_mask)

            # expand the length to match with the feature sequence
            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
            m_p = torch.matmul(
                attn.squeeze(1),
                m_p.transpose(1, 2),
            ).transpose(1, 2)
            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
            logs_p = torch.matmul(
                attn.squeeze(1),
                logs_p.transpose(1, 2),
            ).transpose(1, 2)

            # decoder
            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
            z = self.flow(z_p, y_mask, g=g, inverse=True)
            wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)

        return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
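In the non-teacher-forcing branch, self._generate_path turns the predicted durations into a hard monotonic alignment. A standalone sketch of that idea (illustrative, not the model's exact API): token t owns the frames between the cumulative durations of its neighbors.

import torch

dur = torch.tensor([[2, 1, 3]])               # (B, T_text)
t_feats = int(dur.sum())
cum = torch.cumsum(dur, dim=-1)               # (B, T_text)
frame = torch.arange(t_feats).view(1, -1, 1)  # (1, T_feats, 1)
path = (frame >= (cum - dur).unsqueeze(1)) & (frame < cum.unsqueeze(1))
print(path.squeeze(0).int())                  # (T_feats, T_text), one-hot per frame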
Example #14
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        self.hs_pad = hs_pad
        cer_ctc = None
        batch_size = xs_pad.size(0)
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)
        return self.loss, loss_ctc_data, loss_att_data, self.acc
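The blending above is the standard hybrid CTC/attention objective, loss = alpha * loss_ctc + (1 - alpha) * loss_att; with illustrative numbers:

alpha, loss_ctc, loss_att = 0.3, 50.0, 2.0  # illustrative values only
loss = alpha * loss_ctc + (1 - alpha) * loss_att
print(loss)  # 16.4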
Example #15
    def inference(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        feats_lengths: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (B, T_text,).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            pitch (Tensor): Pitch tensor (B, T_feats, 1)
            energy (Tensor): Energy tensor (B, T_feats, 1)
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Tensor: Generated waveform tensor (B, T_wav).
            Tensor: Duration tensor (B, T_text).

        """
        # forward encoder
        x_masks = self._source_mask(text_lengths)
        hs, _ = self.encoder(text, x_masks)  # (B, T_text, adim)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(feats)
            hs = hs + style_embs.unsqueeze(1)

        # integrate with SID and LID embeddings
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        h_masks = make_pad_mask(text_lengths).to(hs.device)
        if use_teacher_forcing:
            # forward alignment module and obtain duration, averaged pitch, energy
            log_p_attn = self.alignment_module(hs, feats, h_masks)
            d_outs, _ = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
            p_outs = average_by_duration(d_outs, pitch.squeeze(-1),
                                         text_lengths,
                                         feats_lengths).unsqueeze(-1)
            e_outs = average_by_duration(d_outs, energy.squeeze(-1),
                                         text_lengths,
                                         feats_lengths).unsqueeze(-1)
        else:
            # forward duration predictor and variance predictors
            p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1))
            e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1))
            d_outs = self.duration_predictor.inference(hs, h_masks)

        p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs

        # upsampling
        if feats_lengths is not None:
            h_masks = make_non_pad_mask(feats_lengths).to(hs.device)
        else:
            h_masks = None
        d_masks = make_non_pad_mask(text_lengths).to(d_outs.device)
        hs = self.length_regulator(hs, d_outs, h_masks,
                                   d_masks)  # (B, T_feats, adim)

        # forward decoder
        if feats_lengths is not None:
            h_masks = self._source_mask(feats_lengths)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)

        # forward generator
        wav = self.generator(zs.transpose(1, 2))

        return wav.squeeze(1), d_outs
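The length_regulator call above expands each token's hidden vector by its predicted duration; for a single utterance, torch.repeat_interleave has the same effect (a sketch, not espnet's implementation):

import torch

hs = torch.randn(1, 3, 8)           # (B, T_text, adim)
d_outs = torch.tensor([[2, 0, 3]])  # zero-duration tokens are simply dropped
expanded = torch.repeat_interleave(hs[0], d_outs[0], dim=0).unsqueeze(0)
print(expanded.shape)               # torch.Size([1, 5, 8])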
Example #16
    def forward(
        self,
        d_outs: torch.Tensor,
        ds: torch.Tensor,
        p_outs: torch.Tensor,
        ps: torch.Tensor,
        e_outs: torch.Tensor,
        es: torch.Tensor,
        ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
            ds (LongTensor): Batch of durations (B, T_text).
            p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
            ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
            e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
            es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
            ilens (LongTensor): Batch of the lengths of each input (B,).

        Returns:
            Tensor: Duration predictor loss value.
            Tensor: Pitch predictor loss value.
            Tensor: Energy predictor loss value.

        """
        # apply mask to remove padded part
        if self.use_masking:
            duration_masks = make_non_pad_mask(ilens).to(ds.device)
            d_outs = d_outs.masked_select(duration_masks)
            ds = ds.masked_select(duration_masks)
            pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device)
            p_outs = p_outs.masked_select(pitch_masks)
            e_outs = e_outs.masked_select(pitch_masks)
            ps = ps.masked_select(pitch_masks)
            es = es.masked_select(pitch_masks)

        # calculate loss
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            duration_masks = make_non_pad_mask(ilens).to(ds.device)
            duration_weights = (
                duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
            )
            duration_weights /= ds.size(0)

            # apply weight
            duration_loss = (
                duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
            )
            pitch_masks = duration_masks.unsqueeze(-1)
            pitch_weights = duration_weights.unsqueeze(-1)
            pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
            energy_loss = (
                energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
            )

        return duration_loss, pitch_loss, energy_loss
Example #17
    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make a (B, 1, Tmax) mask for self-attention from input lengths."""
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)
Example #18
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        if self.attention_enc_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]:
            for layer in self.encoder.encoders:
                layer.self_attn.clamp_param()
        if self.attention_dec_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]:
            for layer in self.decoder.decoders:
                layer.self_attn.clamp_param()

        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)
        # xkc09 Span attention loss computation
        # xkc09 Span attention size loss computation
        loss_span = 0
        if self.attention_enc_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_dynamic_span2'
        ]:
            loss_span += sum([
                layer.self_attn.get_mean_span()
                for layer in self.encoder.encoders
            ])
        if self.attention_dec_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_dynamic_span2'
        ]:
            loss_span += sum([
                layer.self_attn.get_mean_span()
                for layer in self.decoder.decoders
            ])
        # xkc09 Span attention ratio loss computation
        loss_ratio = 0
        if self.ratio_adaptive:
            # target_ratio = 0.5
            if self.attention_enc_type in [
                    'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                    'self_attn_dynamic_span2'
            ]:
                loss_ratio += sum([
                    1 - layer.self_attn.get_mean_ratio()
                    for layer in self.encoder.encoders
                ])
            if self.attention_dec_type in [
                    'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                    'self_attn_dynamic_span2'
            ]:
                loss_ratio += sum([
                    1 - layer.self_attn.get_mean_ratio()
                    for layer in self.decoder.decoders
                ])
        if (self.attention_enc_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ] or self.attention_dec_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]):
            if getattr(self, 'span_loss_coef', None):
                self.loss += (loss_span + loss_ratio) * self.span_loss_coef

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Example #19
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        pitch: torch.Tensor,
        pitch_lengths: torch.Tensor,
        energy: torch.Tensor,
        energy_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            pitch (Tensor): Batch of padded token-averaged pitch (B, T_text, 1).
            pitch_lengths (LongTensor): Batch of pitch lengths (B, T_text).
            energy (Tensor): Batch of padded token-averaged energy (B, T_text, 1).
            energy_lengths (LongTensor): Batch of energy lengths (B, T_text).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
            Tensor: Binarization loss (scalar).
            Tensor: Log probability attention matrix (B, T_feats, T_text).
            Tensor: Segments start index tensor (B,).
            Tensor: Predicted duration (B, T_text).
            Tensor: Ground-truth duration obtained from an alignment module (B, T_text).
            Tensor: Predicted pitch (B, T_text, 1).
            Tensor: Ground-truth averaged pitch (B, T_text, 1).
            Tensor: Predicted energy (B, T_text, 1).
            Tensor: Ground-truth averaged energy (B, T_text, 1).

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        feats = feats[:, :feats_lengths.max()]  # for data-parallel
        pitch = pitch[:, :pitch_lengths.max()]  # for data-parallel
        energy = energy[:, :energy_lengths.max()]  # for data-parallel

        # forward encoder
        x_masks = self._source_mask(text_lengths)
        hs, _ = self.encoder(text, x_masks)  # (B, T_text, adim)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(feats)
            hs = hs + style_embs.unsqueeze(1)

        # integrate with SID and LID embeddings
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward alignment module and obtain duration, averaged pitch, energy
        h_masks = make_pad_mask(text_lengths).to(hs.device)
        log_p_attn = self.alignment_module(hs, feats, h_masks)
        ds, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
        ps = average_by_duration(ds, pitch.squeeze(-1), text_lengths,
                                 feats_lengths).unsqueeze(-1)
        es = average_by_duration(ds, energy.squeeze(-1), text_lengths,
                                 feats_lengths).unsqueeze(-1)

        # forward duration predictor and variance predictors
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), h_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), h_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1))
        d_outs = self.duration_predictor(hs, h_masks)

        # use groundtruth in training
        p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs

        # upsampling
        h_masks = make_non_pad_mask(feats_lengths).to(hs.device)
        d_masks = make_non_pad_mask(text_lengths).to(ds.device)
        hs = self.length_regulator(hs, ds, h_masks,
                                   d_masks)  # (B, T_feats, adim)

        # forward decoder
        h_masks = self._source_mask(feats_lengths)
        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)

        # get random segments
        z_segments, z_start_idxs = get_random_segments(
            zs.transpose(1, 2),
            feats_lengths,
            self.segment_size,
        )
        # forward generator
        wav = self.generator(z_segments)

        return (
            wav,
            bin_loss,
            log_p_attn,
            z_start_idxs,
            d_outs,
            ds,
            p_outs,
            ps,
            e_outs,
            es,
        )
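get_random_segments trains the waveform generator on short random windows rather than full utterances; a hedged sketch of the segment extraction (names and sizes are illustrative):

import torch

zs = torch.randn(2, 192, 100)  # (B, H, T_feats)
lengths = torch.tensor([100, 60])
segment_size = 32
max_start = (lengths - segment_size).clamp(min=0)
starts = (torch.rand(2) * max_start.float()).long()
segments = torch.stack(
    [z[:, int(s):int(s) + segment_size] for z, s in zip(zs, starts)])
print(segments.shape)  # torch.Size([2, 192, 32])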
Example #20
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]

        if "custom" in self.etype:
            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            _hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        else:
            _hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)

        if self.use_aux_task:
            hs_pad, aux_hs_pad = _hs_pad[0], _hs_pad[1]
        else:
            hs_pad, aux_hs_pad = _hs_pad, None

        # 1.5. transducer preparation related
        ys_in_pad, ys_out_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if "custom" in self.dtype:
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad)

        z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

        # 3. loss computation
        loss_trans = self.criterion(z, target, pred_len, target_len)

        if self.use_aux_task and aux_hs_pad is not None:
            loss_trans += self.auxiliary_task(aux_hs_pad, pred_pad, z, target,
                                              pred_len, target_len)

        if self.use_aux_ctc:
            if "custom" in self.etype:
                hs_mask = torch.IntTensor(
                    [h.size(1) for h in hs_mask]).to(hs_mask.device)

            loss_ctc = self.aux_ctc(hs_pad, hs_mask, ys_pad)
        else:
            loss_ctc = 0

        if self.use_aux_cross_entropy:
            loss_ce = self.aux_cross_entropy(self.aux_decoder_output(pred_pad),
                                             ys_out_pad)
        else:
            loss_ce = 0

        loss = (self.transducer_weight * loss_trans +
                self.aux_ctc_weight * loss_ctc +
                self.aux_cross_entropy_weight * loss_ce)

        self.loss = loss
        loss_data = float(loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
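The joint_network above broadcasts encoder frames against prediction-network steps to form the transducer lattice; a minimal additive joint sketch (real implementations may concatenate the inputs or add extra projections):

import torch

B, T, U, H, V = 2, 50, 10, 320, 100           # illustrative sizes
hs = torch.randn(B, T, H).unsqueeze(2)        # (B, T, 1, H)
pred = torch.randn(B, U + 1, H).unsqueeze(1)  # (B, 1, U+1, H)
joint = torch.tanh(hs + pred)                 # (B, T, U+1, H) by broadcasting
z = torch.nn.Linear(H, V)(joint)              # (B, T, U+1, V), fed to the loss
print(z.shape)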
Example #21
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences
                                    (B, num_spkrs, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad,
                                       src_mask)  # list: speaker differentiate
        self.hs_pad = hs_pad

        # 2. ctc
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        assert self.mtlalpha > 0.0
        batch_size = xs_pad.size(0)
        ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
        hs_len = [
            hs_mask[i].view(batch_size, -1).sum(1)
            for i in range(self.num_spkrs)
        ]
        loss_ctc_perm = torch.stack(
            [
                self.ctc(
                    hs_pad[i // self.num_spkrs].view(batch_size, -1,
                                                     self.adim),
                    hs_len[i // self.num_spkrs],
                    ys_pad[i % self.num_spkrs],
                ) for i in range(self.num_spkrs**2)
            ],
            dim=1,
        )  # (B, num_spkrs^2)
        loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
        logging.info("ctc loss:" + str(float(loss_ctc)))

        # Permute the labels according to loss
        for b in range(batch_size):  # B
            ys_pad[:, b] = ys_pad[min_perm[b], b]  # (num_spkrs, B, Lmax)
        ys_out_len = [
            float(torch.sum(ys_pad[i] != self.ignore_id))
            for i in range(self.num_spkrs)
        ]

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        if self.error_calculator is not None:
            cer_ctc = []
            for i in range(self.num_spkrs):
                ys_hat = self.ctc.argmax(hs_pad[i].view(
                    batch_size, -1, self.adim)).data
                cer_ctc.append(
                    self.error_calculator(ys_hat.cpu(),
                                          ys_pad[i].cpu(),
                                          is_ctc=True))
            cer_ctc = sum(map(lambda x: x[0] * x[1], zip(
                cer_ctc, ys_out_len))) / sum(ys_out_len)
        else:
            cer_ctc = None

        # 3. forward decoder
        if self.mtlalpha == 1.0:
            loss_att, self.acc, cer, wer = None, None, None, None
        else:
            pred_pad = [None] * self.num_spkrs
            pred_mask = [None] * self.num_spkrs
            loss_att, acc = [None] * self.num_spkrs, [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                (
                    pred_pad[i],
                    pred_mask[i],
                    loss_att[i],
                    acc[i],
                ) = self.decoder_and_attention(hs_pad[i], hs_mask[i],
                                               ys_pad[i], batch_size)

            # 4. compute attention loss
            # The following is just an approximation
            loss_att = sum(
                map(lambda x: x[0] * x[1], zip(loss_att,
                                               ys_out_len))) / sum(ys_out_len)
            self.acc = sum(map(lambda x: x[0] * x[1], zip(
                acc, ys_out_len))) / sum(ys_out_len)

            # 5. compute cer/wer per speaker, then length-weighted average
            # (pred_pad and ys_pad are per-speaker lists here)
            if self.training or self.error_calculator is None:
                cer, wer = None, None
            else:
                cer_wer = [self.error_calculator(
                    pred_pad[i].argmax(dim=-1).cpu(), ys_pad[i].cpu())
                    for i in range(self.num_spkrs)]
                cer = sum(c * n for (c, _), n in zip(cer_wer, ys_out_len)) / sum(ys_out_len)
                wer = sum(w * n for (_, w), n in zip(cer_wer, ys_out_len)) / sum(ys_out_len)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
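
self.pit.pit_process is not shown in this example. A hedged sketch of what such a permutation-invariant reduction over the (B, num_spkrs^2) matrix could look like, assuming the stacking order used above (entry i pairs hypothesis i // num_spkrs with reference i % num_spkrs):

import itertools
import torch

def pit_process(loss_perm, num_spkrs=2):
    """Pick, per utterance, the speaker permutation with the lowest mean loss."""
    perms = list(itertools.permutations(range(num_spkrs)))
    # loss of assigning hypothesis h to reference p[h], averaged over speakers
    perm_losses = torch.stack(
        [sum(loss_perm[:, h * num_spkrs + p[h]] for h in range(num_spkrs))
         / num_spkrs for p in perms],
        dim=1)  # (B, num_perms)
    min_loss, min_idx = perm_losses.min(dim=1)
    min_perm = torch.tensor(perms)[min_idx]  # (B, num_spkrs)
    return min_loss.mean(), min_perm

loss, perm = pit_process(torch.rand(4, 4))  # B=4, num_spkrs=2
print(loss.item(), perm)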
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """

        # 1. encoder
        # if etype is transformer, handle the padding
        # xs_pad:[8, 393, 83]
        # ilens:[393 * 8]
        # src_mask:[8, 1, 393]
        # hs_mask:[8, 1, 65]
        if self.etype == "transformer":
            xs_pad = xs_pad[:, :max(ilens)]
            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        else:
            logging.info("enc!!!")
            hs_pad, hs_mask, _ = self.encoder(xs_pad, ilens)
        self.hs_pad = hs_pad

        # 1.5. transducer preparation related
        # ys_in_pad: sos,1,2,...,0 [8, 14]
        # target: 1,2,... [8, 13]
        # pred_len: [8]
        # target_len: [8]
        # ys_out_pad:1,2,...,eos,-1

        ys = [y[y != self.ignore_id] for y in ys_pad]

        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])

        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        ys_out_pad = pad_list(ys_out, self.ignore_id)

        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)
        # 2. decoder
        # ys_mask:[8, 16, 16]

        if self.dtype == "transformer":
            ys_mask = target_mask(ys_in_pad, self.blank_id)

            pred_pad, pred_att, _ = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                                 hs_mask)
        else:
            if self.rnnt_mode == "rnnt":
                pred_pad = self.dec(hs_pad, ys_in_pad)
            else:
                pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
        self.pred_pad = pred_pad

        # 3. loss computation
        loss_att = F.cross_entropy(
            pred_att,
            ys_out_pad.view(-1),  # batch x olength
            ignore_index=self.ignore_id,
        )
        # compute perplexity
        # ppl = math.exp(loss_att.item())
        # -1: eos, which is removed in the loss computation
        loss_att *= np.mean([len(x) for x in ys_in]) - 1

        loss_rnnt = self.criterion(pred_pad, target, pred_len, target_len)

        # loss_ctc = self.ctc(hs_pad, pred_len, ys_pad)

        alpha = self.mtlalpha
        beta = self.mtlbeta
        gamma = self.mtlgamma

        self.loss_rnnt = loss_rnnt
        self.loss_att = loss_att
        # self.loss_ctc = loss_ctc

        # self.loss = alpha * self.loss_ctc + beta * self.loss_rnnt + gamma * self.loss_att
        self.loss = beta * self.loss_rnnt + gamma * self.loss_att
        # self.loss = alpha * self.loss_ctc
        loss_data = float(self.loss)
        # loss_ctc_data = float(self.loss_ctc)
        loss_att_data = float(self.loss_att)
        loss_rnnt_data = float(self.loss_rnnt)

        # 4. compute cer/wer

        if self.training or self.error_calculator is None:
            logging.info("ALL none!!!!!")
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)


        if not math.isnan(loss_data):
            self.reporter.report(loss_data, loss_rnnt_data, loss_att_data, cer,
                                 wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
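
prepare_loss_inputs is imported from espnet and not shown here. Based on the shape comments above, a plausible sketch of the four values it returns (assuming blank_id both prefixes the decoder input and pads the targets):

import torch
from torch.nn.utils.rnn import pad_sequence

def prepare_loss_inputs(ys_pad, hs_mask, blank_id=0, ignore_id=-1):
    ys = [y[y != ignore_id] for y in ys_pad]                  # strip padding
    blank = ys_pad.new_full((1,), blank_id)
    ys_in_pad = pad_sequence([torch.cat([blank, y]) for y in ys],
                             batch_first=True, padding_value=blank_id)
    target = pad_sequence(ys, batch_first=True,
                          padding_value=blank_id).int()        # (B, U)
    pred_len = hs_mask.view(hs_mask.size(0), -1).sum(1).int()  # (B,)
    target_len = torch.tensor([y.size(0) for y in ys]).int()   # (B,)
    return ys_in_pad, target, pred_len, target_len

ys_pad = torch.tensor([[3, 4, -1], [5, -1, -1]])
hs_mask = torch.ones(2, 1, 6, dtype=torch.bool)
print(prepare_loss_inputs(ys_pad, hs_mask))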
Example #23
    def forward(self, xs, ilens, ys, olens, spembs=None, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        xs = xs[:, :max(ilens)]
        ys = ys[:, :max(olens)]

        # forward propagation
        outs, ds, d_outs = self._forward(xs,
                                         ilens,
                                         ys,
                                         olens,
                                         spembs=spembs,
                                         is_inference=False)

        # apply mask to remove padded part
        if self.use_masking:
            in_masks = make_non_pad_mask(ilens).to(xs.device)
            d_outs = d_outs.masked_select(in_masks)
            ds = ds.masked_select(in_masks)
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            outs = outs.masked_select(out_masks)
            ys = ys.masked_select(out_masks)

        # calculate loss
        l1_loss = self.criterion(outs, ys)
        duration_loss = self.duration_criterion(d_outs, ds)
        loss = l1_loss + duration_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"duration_loss": duration_loss.item()},
            {"loss": loss.item()},
        ]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
                {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
            ]
        self.reporter.report(report_keys)

        return loss
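
The masked_select pattern above ensures padded frames never contribute to the loss or its gradient. A toy illustration, with make_non_pad_mask re-created inline:

import torch
import torch.nn.functional as F

outs = torch.randn(2, 5, 3)  # (B, Lmax, odim)
ys = torch.randn(2, 5, 3)
olens = torch.tensor([5, 3])

# equivalent of make_non_pad_mask(olens).unsqueeze(-1)
mask = (torch.arange(5)[None, :] < olens[:, None]).unsqueeze(-1)
loss = F.l1_loss(outs.masked_select(mask), ys.masked_select(mask))
print(loss.item())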
Example #24
    def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        batch_size = xs_pad.shape[0]
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        if isinstance(self.encoder.embed, EncoderConv2d):
            xs, hs_mask = self.encoder.embed(xs_pad,
                                             torch.sum(src_mask, 2).squeeze())
            hs_mask = hs_mask.unsqueeze(1)
        else:
            xs, hs_mask = self.encoder.embed(xs_pad, src_mask)

        if enc_mask is not None:
            enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]]
        enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask
        hs_pad, _ = self.encoder.encoders(xs, enc_mask)
        if self.encoder.normalize_before:
            hs_pad = self.encoder.after_norm(hs_pad)

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        if dec_mask is not None:
            dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]]
        self.hs_pad = hs_pad
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        # trigger mask
        hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_ctc_data, loss_att_data, self.acc
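
The enc_mask & hs_mask and hs_mask & dec_mask combinations above narrow the padding mask with an externally supplied trigger mask by logical AND. A small shape-level illustration (values are made up):

import torch

hs_mask = torch.tensor([[[1, 1, 1, 0, 0]]], dtype=torch.bool)   # (B, 1, T) padding
dec_mask = torch.tensor([[[1, 1, 0, 0, 0],
                          [1, 1, 1, 0, 0]]], dtype=torch.bool)  # (B, L, T) trigger
print((hs_mask & dec_mask).int())  # broadcasts to (B, L, T)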
    def forward(self, feats: torch.Tensor, feats_len: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """E2E forward.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_len: Feature sequences lengths. (B,)
            labels: Label ID sequences. (B, L)

        Returns:
            loss: Transducer loss value

        """
        # 1. encoder
        feats = feats[:, :max(feats_len)]

        if self.etype == "custom":
            feats_mask = (make_non_pad_mask(feats_len.tolist()).to(
                feats.device).unsqueeze(-2))

            _enc_out, _enc_out_len = self.encoder(feats, feats_mask)
        else:
            _enc_out, _enc_out_len, _ = self.enc(feats, feats_len)

        if self.use_auxiliary_enc_outputs:
            enc_out, aux_enc_out = _enc_out[0], _enc_out[1]
            enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1]
        else:
            enc_out, aux_enc_out = _enc_out, None
            enc_out_len, aux_enc_out_len = _enc_out_len, None

        # 2. decoder
        dec_in = get_decoder_input(labels, self.blank_id, self.ignore_id)

        if self.dtype == "custom":
            self.decoder.set_device(enc_out.device)

            dec_in_mask = target_mask(dec_in, self.blank_id)
            dec_out, _ = self.decoder(dec_in, dec_in_mask)
        else:
            self.dec.set_device(enc_out.device)

            dec_out = self.dec(dec_in)

        # 3. transducer tasks computation
        losses = self.transducer_tasks(
            enc_out,
            aux_enc_out,
            dec_out,
            labels,
            enc_out_len,
            aux_enc_out_len,
        )

        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(
                enc_out, self.transducer_tasks.get_target())

        self.loss = sum(losses)
        loss_data = float(self.loss)

        if not math.isnan(loss_data):
            self.reporter.report(
                loss_data,
                *[float(loss) for loss in losses],
                cer,
                wer,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
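
get_decoder_input is an espnet helper not shown here. A hedged re-implementation consistent with its use above (strip ignore_id padding, then prefix each sequence with the blank symbol):

import torch
from torch.nn.utils.rnn import pad_sequence

def get_decoder_input(labels, blank_id, ignore_id):
    ys = [y[y != ignore_id] for y in labels]
    blank = labels.new_full((1,), blank_id)
    return pad_sequence([torch.cat([blank, y]) for y in ys],
                        batch_first=True, padding_value=blank_id)

labels = torch.tensor([[3, 4, 5], [6, 7, -1]])
print(get_decoder_input(labels, blank_id=0, ignore_id=-1))
# tensor([[0, 3, 4, 5],
#         [0, 6, 7, 0]])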
Example #26
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward Transformer encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        if xs_pad.size(1) != ys_pad.size(1):
            if xs_pad.size(1) < ys_pad.size(1):
                ys_pad = ys_pad[:, :xs_pad.size(1)].contiguous()
            else:
                raise ValueError("target size {} is smaller than input size {}".format(ys_pad.size(1), xs_pad.size(1)))
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. post-processing layer for target dimension
        if self.outer:
            post_pad = self.poster(hs_pad)
            post_pad = post_pad.view(post_pad.size(0), -1, self.odim)
            if post_pad.size(1) != xs_pad.size(1):
                if post_pad.size(1) < xs_pad.size(1):
                    xs_pad = xs_pad[:, :post_pad.size(1)].contiguous()
                else:
                    raise ValueError("target size {} and pred size {} is mismatch".format(xs_pad.size(1), post_pad.size(1)))
            if self.residual:
                post_pad = post_pad + self.matcher_res(xs_pad)
            else:
                post_pad = torch.cat([post_pad, xs_pad], dim=-1)
            pred_pad = self.matcher(post_pad)
        else:
            pred_pad = self.poster(hs_pad)
            pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
        self.pred_pad = pred_pad
        if pred_pad.size(1) != ys_pad.size(1):
            if pred_pad.size(1) < ys_pad.size(1):
                ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
            else:
                raise ValueError("target size {} and pred size {} is mismatch".format(ys_pad.size(1), pred_pad.size(1)))

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_pad, ignore_label=self.ignore_id
        )

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(pred_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(pred_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
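
The truncate-or-raise pattern appears three times in the example above. A hypothetical helper (not part of the original code) capturing the rule "prediction shorter: truncate the reference; prediction longer: fail":

def match_to_pred(pred, target, what="target"):
    """Truncate target to pred's length along dim 1, or fail if pred is longer."""
    if pred.size(1) == target.size(1):
        return target
    if pred.size(1) < target.size(1):
        return target[:, :pred.size(1)].contiguous()
    raise ValueError("{} size {} and pred size {} mismatch".format(
        what, target.size(1), pred.size(1)))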
Example #27
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Extract target language ID
        tgt_lang_ids = None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # replace <sos> with target language ID
        if self.replace_sos:
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute ST loss
        loss_att = self.criterion(pred_pad, ys_out_pad)

        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )

        # 4. compute corpus-level bleu in a mini-batch
        if self.training:
            self.bleu = None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 5. compute auxiliary ASR loss
        loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr(
            hs_pad, hs_mask, ys_pad_src
        )

        # 6. compute auxiliary MT loss
        loss_mt, acc_mt = 0.0, None
        if self.mt_weight > 0:
            loss_mt, acc_mt = self.forward_mt(
                ys_pad_src, ys_in_pad, ys_out_pad, ys_mask
            )

        asr_ctc_weight = self.mtlalpha
        self.loss = (
            (1 - self.asr_weight - self.mt_weight) * loss_att
            + self.asr_weight
            * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att)
            + self.mt_weight * loss_mt
        )
        loss_asr_data = float(
            asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
        )
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
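
A worked numeric example of the loss interpolation above, with assumed weights asr_weight=0.3, mt_weight=0.1 and mtlalpha=0.5:

loss_att, loss_asr_ctc, loss_asr_att, loss_mt = 2.0, 5.0, 4.0, 3.0
asr_weight, mt_weight, mtlalpha = 0.3, 0.1, 0.5  # assumed values

loss = ((1 - asr_weight - mt_weight) * loss_att
        + asr_weight * (mtlalpha * loss_asr_ctc
                        + (1 - mtlalpha) * loss_asr_att)
        + mt_weight * loss_mt)
print(loss)  # 0.6 * 2.0 + 0.3 * 4.5 + 0.1 * 3.0 = 2.85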
Example #28
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = mask_uniform(ys_pad, self.mask_token, self.eos,
                                             self.ignore_id)
        ys_mask = square_mask(ys_in_pad, self.eos)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute ctc loss
        loss_ctc, cer_ctc = None, None
        loss_intermediate_ctc = 0.0
        if self.mtlalpha > 0:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)
            # for visualization
            if not self.training:
                self.ctc.softmax(hs_pad)
            if self.intermediate_ctc_weight > 0 and self.intermediate_ctc_layers:
                for hs_intermediate in hs_intermediates:
                    # assuming hs_intermediates and hs_pad have the same length / padding
                    loss_inter = self.ctc(
                        hs_intermediate.view(batch_size, -1, self.adim),
                        hs_len, ys_pad)
                    loss_intermediate_ctc += loss_inter

                loss_intermediate_ctc /= len(self.intermediate_ctc_layers)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        else:
            self.loss = (alpha * loss_ctc +
                         self.intermediate_ctc_weight * loss_intermediate_ctc +
                         (1 - alpha - self.intermediate_ctc_weight) * loss_att)
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
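
mask_uniform implements the Mask-CTC style target corruption used above. A hedged sketch (eos handling omitted; assumes every sequence has at least one valid label):

import torch

def mask_uniform(ys_pad, mask_token, eos, ignore_id):
    """Replace a uniformly drawn number of positions with mask_token;
    only those positions remain as prediction targets."""
    ys_in = ys_pad.clone()
    ys_out = ys_pad.new_full(ys_pad.size(), ignore_id)
    for b in range(ys_pad.size(0)):
        valid = (ys_pad[b] != ignore_id).nonzero(as_tuple=True)[0]
        n_mask = torch.randint(1, valid.numel() + 1, (1,)).item()
        idx = valid[torch.randperm(valid.numel())[:n_mask]]
        ys_out[b, idx] = ys_pad[b, idx]  # predict only the masked positions
        ys_in[b, idx] = mask_token
    return ys_in, ys_out

ys_pad = torch.tensor([[3, 4, 5], [6, 7, -1]])
print(mask_uniform(ys_pad, mask_token=1, eos=2, ignore_id=-1))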
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        if self.decoder is not None:
            if self.decoder_mode == "maskctc":
                ys_in_pad, ys_out_pad = mask_uniform(
                    ys_pad, self.mask_token, self.eos, self.ignore_id
                )
                ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
            else:
                ys_in_pad, ys_out_pad = add_sos_eos(
                    ys_pad, self.sos, self.eos, self.ignore_id
                )
                ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            self.acc = th_accuracy(
                pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
            )
        else:
            loss_att = None
            self.acc = None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if not self.training and self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
            # for visualization
            if not self.training:
                self.ctc.softmax(hs_pad)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
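
add_sos_eos builds the teacher-forcing input/target pair. A hedged sketch matching its use above (decoder input padded with eos, target padded with ignore_id, as in espnet):

import torch
from torch.nn.utils.rnn import pad_sequence

def add_sos_eos(ys_pad, sos, eos, ignore_id):
    ys = [y[y != ignore_id] for y in ys_pad]
    _sos, _eos = ys_pad.new_full((1,), sos), ys_pad.new_full((1,), eos)
    ys_in = pad_sequence([torch.cat([_sos, y]) for y in ys],
                         batch_first=True, padding_value=eos)
    ys_out = pad_sequence([torch.cat([y, _eos]) for y in ys],
                          batch_first=True, padding_value=ignore_id)
    return ys_in, ys_out

ys_pad = torch.tensor([[3, 4, -1], [5, 6, 7]])
print(add_sos_eos(ys_pad, sos=8, eos=9, ignore_id=-1))
# ys_in: [[8, 3, 4, 9], [8, 5, 6, 7]]; ys_out: [[3, 4, 9, -1], [5, 6, 7, 9]]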
    def _target_mask(self, olens):
        """Combine non-pad and subsequent masks for decoder self-attention."""
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        s_masks = subsequent_mask(y_masks.size(-1),
                                  device=y_masks.device).unsqueeze(0)
        return y_masks.unsqueeze(-2) & s_masks
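
For clarity, a standalone re-creation of the mask _target_mask returns: the non-pad mask ANDed with a lower-triangular subsequent mask:

import torch

olens = [5, 3]
L = max(olens)
non_pad = torch.arange(L)[None, :] < torch.tensor(olens)[:, None]  # (B, L)
sub = torch.tril(torch.ones(L, L, dtype=torch.bool))[None]         # (1, L, L)
print((non_pad.unsqueeze(-2) & sub).int())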