Example #1
def forward(self, xs_pad, ilens):
        """Encodermix forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: list of batches of hidden state sequences [num_spkrs x (B, Tmax, eprojs)]
        :rtype: list of torch.Tensor
        """
        # mixture encoder
        for module in self.enc_mix:
            xs_pad, ilens, _ = module(xs_pad, ilens)

        # SD and Rec encoder
        xs_pad_sd = [xs_pad for i in range(self.num_spkrs)]
        ilens_sd = [ilens for i in range(self.num_spkrs)]
        for ns in range(self.num_spkrs):
            # Encoder_SD: speaker differentiate encoder
            for module in self.enc_sd[ns]:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(
                    xs_pad_sd[ns], ilens_sd[ns])
            # Encoder_Rec: recognition encoder
            for module in self.enc_rec:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(
                    xs_pad_sd[ns], ilens_sd[ns])

        # make mask to remove bias value in padded part
        mask = to_device(self, make_pad_mask(ilens_sd[0]).unsqueeze(-1))

        return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd, None
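Nearly every snippet on this page leans on ESPnet's make_pad_mask. As a point of reference, here is a minimal sketch of its core behaviour (True at padded positions), assuming per-entry lengths; make_pad_mask_sketch is our illustrative stand-in, not the library function:

import torch

def make_pad_mask_sketch(lengths):
    """Return a BoolTensor (B, Tmax) that is True at padded positions."""
    lengths = torch.as_tensor(lengths)
    max_len = int(lengths.max())
    # compare position indices (1, Tmax) against lengths (B, 1)
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

# make_pad_mask_sketch([3, 1, 2]) ->
# tensor([[False, False, False],
#         [False,  True,  True],
#         [False, False,  True]])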
Example #2
    def forward(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Label_dim)

        """
        bs = input.size(0)
        max_length = input.size(1)
        label_dim = input.size(2)

        # NOTE(jiatong):
        #   The default behaviour of label aggregation is consistent with
        #   torch.stft in terms of framing and padding.

        # Step1: center padding
        if self.center:
            pad = self.win_length // 2
            max_length = max_length + 2 * pad
            input = torch.nn.functional.pad(input, (0, 0, pad, pad),
                                            "constant", 0)
            input[:, :pad, :] = input[:, pad:(2 * pad), :]
            input[:, (max_length - pad):max_length, :] = \
                input[:, (max_length - 2 * pad):(max_length - pad), :]

        # NOTE: compute nframe outside the branch above so that it is also
        # defined in Step2 when center is False.
        nframe = (max_length - self.win_length) // self.hop_length + 1

        # Step2: framing
        output = input.as_strided(
            (bs, nframe, self.win_length, label_dim),
            (max_length * label_dim, self.hop_length * label_dim, label_dim,
             1),
        )

        # Step3: aggregate label
        output = torch.gt(output.sum(dim=2, keepdim=False),
                          self.win_length // 2)
        output = output.float()

        # Step4: process lengths
        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad

            olens = (ilens - self.win_length) // self.hop_length + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens
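A quick check of the framing arithmetic above, assuming center=True and the same win_length/hop_length conventions as torch.stft:

win_length, hop_length = 4, 2
nsamples = 10

pad = win_length // 2                                  # 2
max_length = nsamples + 2 * pad                        # 14
nframe = (max_length - win_length) // hop_length + 1   # 6
assert nframe == 1 + nsamples // hop_length            # matches torch.stft's frame count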
Example #3
    def forward(self, feat: torch.Tensor, ilens: torch.LongTensor) \
            -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)

        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        logmel_feat = logmel_feat.masked_fill(
            make_pad_mask(ilens, logmel_feat, 1), 0.0)

        # We now create the Scattering1D object that will be used to
        # calculate the scattering coefficients.
        # scattering = Scattering1D(J, T, Q)

        # If we are using CUDA, the scattering transform object must be
        # transferred to the GPU by calling its cuda() method. The data is
        # similarly transferred.
        # if use_cuda:
        #     scattering.cuda()
        #     x_all = x_all.cuda()
        #     y_all = y_all.cuda()

        # Compute the scattering transform for all signals in the dataset.
        # Sx_all = scattering.forward(x_all)

        # Since it does not carry useful information, we remove the
        # zeroth-order scattering coefficients, which are always placed in
        # the first channel of the scattering Tensor.
        # Sx_all = Sx_all[:, 1:, :]

        # To increase discriminability, we take the logarithm of the
        # scattering coefficients (after adding a small constant to make
        # sure nothing blows up when scattering coefficients are close to
        # zero).
        # Sx_all = torch.log(torch.abs(Sx_all) + log_eps)

        # Finally, we average along the last dimension (time) to get a
        # time-shift invariant representation.
        # Sx_all = torch.mean(Sx_all, dim=-1)
        return logmel_feat, ilens
Example #4
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: CTC loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        # lid output layer
        pred_pad = self.lid_lo(hs_pad)

        # compute lid loss
        self.loss = self.criterion(pred_pad, ys_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_pad,
                               ignore_label=self.ignore_id)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(self.acc, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Example #5
    def forward(
            self, input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward."""
        input = self.linear_in(input)

        args = {"return_dict": True}

        mask = (~make_pad_mask(input_lengths)).to(input.device).float()

        if self.extend_attention_mask:
            args["attention_mask"] = _extend_attention_mask(mask)
        else:
            args["attention_mask"] = mask

        if self.use_inputs_embeds:
            args["inputs_embeds"] = input
        else:
            args["hidden_states"] = input

        if self.transformer.config.model_type == "mpnet":
            args["head_mask"] = [None for _ in self.transformer.layer]

        output = self.transformer(**args).last_hidden_state

        return output, input_lengths
Example #6
    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        # print(feat.device, feat.dtype, feat.shape)
        # print("\t",self.melmat.device, self.melmat.dtype, self.melmat.shape)

        mel_feat = torch.matmul(feat, self.melmat.to(feat.device))
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            logmel_feat = mel_feat.log() / torch.log(self.log_base)

        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
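A shape-level usage sketch for this log-mel pattern; the random melmat below is a stand-in for the module's real mel filterbank (e.g. one built with librosa.filters.mel), so the values are meaningless but the shapes are right:

import torch

B, T, D1, D2 = 2, 50, 257, 80          # batch, frames, linear bins, mel bins
feat = torch.rand(B, T, D1)            # stand-in power spectrogram
melmat = torch.rand(D1, D2)            # stand-in mel filterbank
ilens = torch.tensor([50, 30])

mel_feat = torch.clamp(torch.matmul(feat, melmat), min=1e-10)
logmel_feat = mel_feat.log()           # (B, T, D2)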
Example #7
    def store_penultimate_state(self, xs_pad, ilens, ys_pad, moe_coes,
                                moe_coe_lens):
        moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # multi-encoder forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        moe_coes = moe_coes.unsqueeze(-1)
        hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
        self.hs_pad = hs_pad

        # forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask, penultimate_state = self.decoder(
            ys_in_pad,
            ys_mask,
            hs_pad,
            hs_mask,
            moe_coes,
            return_penultimate_state=True)

        # plot penultimate_state, (B,T,att_dim)
        return penultimate_state.squeeze(0).detach().cpu().numpy()
Example #8
    def nll(self, text: torch.Tensor,
            text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = text.size(0)
        # For data parallel
        text = text[:, :text_lengths.max()]

        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # train: (Batch, Length) -> x, y: (Batch, Length + 1)
        x = F.pad(text, [1, 0], "constant", self.eos)
        t = F.pad(text, [0, 1], "constant", self.ignore_id)
        for i, l in enumerate(text_lengths):
            t[i, l] = self.sos
        x_lengths = text_lengths + 1

        # 2. Forward Language model
        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
        y, _ = self.lm(x, None)

        # 3. Calc negative log likelihood
        # nll: (BxL,)
        nll = F.cross_entropy(y.view(-1, y.shape[-1]),
                              t.view(-1),
                              reduction="none")
        # nll: (BxL,) -> (BxL,)
        nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, x_lengths
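A toy trace of step 1 (the token ids below are illustrative, not the model's actual sos/eos/ignore_id values):

import torch
import torch.nn.functional as F

eos, sos, ignore_id = 2, 1, -1                   # illustrative ids
text = torch.tensor([[5, 6, 7], [8, 9, 0]])      # second row padded
text_lengths = torch.tensor([3, 2])

x = F.pad(text, [1, 0], "constant", eos)         # [[2, 5, 6, 7], [2, 8, 9, 0]]
t = F.pad(text, [0, 1], "constant", ignore_id)   # [[5, 6, 7, -1], [8, 9, 0, -1]]
for i, l in enumerate(text_lengths):
    t[i, l] = sos                                # [[5, 6, 7, 1], [8, 9, 1, -1]]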
Example #9
    def store_penultimate_state(self, xs_pad, ilens, ys_pad):
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
        # multi-encoder forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)

        coe = self.aggregation_linear(hs_pad) # (B,T,2)
        coe = F.softmax(self.aggre_scaling * coe, dim=-1).unsqueeze(-1) # (B,T,2,1)
        hs_pad = coe[:,:,0] * cn_hs_pad + coe[:,:,1] * en_hs_pad
        
        # soft assign mode
        # penultimate_state = 1 - coe # for mlme usage, we need to inverse this coe
        
        # hard assign mode
        penultimate_state = 1 - coe
        penultimate_state = penultimate_state.squeeze(0).detach().cpu().numpy()
        # (T, 2) ndarray; threshold at 0.5 for a hard 0/1 assignment
        penultimate_state = (penultimate_state > 0.5).astype(int)

        # forward decoder
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_mask = target_mask(ys_in_pad, self.ignore_id)
        # pred_pad, pred_mask, penultimate_state = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)

        # plot penultimate_state, (B,T,att_dim)
        return penultimate_state
Example #10
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:

    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]

    # Zero padding (in-place; a plain masked_fill here would discard the result)
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens
Example #11
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_length: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Forward Hubert Pretrain Encoder.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            ys_pad: padded target label tensor (B, L')
            ys_pad_length: target label length (B)
            prev_states: Not to be used now.
        Returns:
            enc_outputs: encoder outputs used to compute the pretraining loss
        """
        self.cast_mask_emb()
        masks = make_pad_mask(ilens).to(xs_pad.device)
        ys_pad = ys_pad[:, :min(ys_pad_length)]
        enc_outputs = self.encoder(
            xs_pad,
            padding_mask=masks,
            mask=True,
            target_list=[ys_pad],
            features_only=False,
        )
        return enc_outputs
Example #12
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        xs_pad, masks = self.encoders(xs_pad, masks)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
Example #13
    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function

        Args:
            x: (B, L, ...)
            ilens: (B,)
        """
        if ilens is None:
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        mask = make_pad_mask(ilens, x, 1)

        # feat: (B, T, D)
        if norm_means:
            if x.requires_grad:
                x = x - self.mean
            else:
                x -= self.mean
        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x /= self.std

        return x, ilens
Example #14
    def forward_mt(self, xs_pad, ys_in_pad, ys_out_pad, ys_mask):
        """Forward pass in the auxiliary MT task.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_out_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_mask: batch of input token mask (B, Lmax)
        :return: MT loss value
        :rtype: torch.Tensor
        :return: accuracy in MT decoder
        :rtype: float
        """
        loss, acc = 0.0, None
        if self.mt_weight == 0:
            return loss, acc

        ilens = torch.sum(xs_pad != self.ignore_id, dim=1).cpu().numpy()
        # NOTE: xs_pad is padded with -1
        xs = [x[x != self.ignore_id] for x in xs_pad]  # parse padded xs
        xs_zero_pad = pad_list(xs, self.pad)  # re-pad with zero
        xs_zero_pad = xs_zero_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_zero_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder_mt(xs_zero_pad, src_mask)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        loss = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim),
                          ys_out_pad,
                          ignore_label=self.ignore_id)
        return loss, acc
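The re-padding trick in forward_mt (strip the -1 padding, re-pad with zeros) can be sketched standalone; pad_sequence stands in for ESPnet's pad_list helper here:

import torch
from torch.nn.utils.rnn import pad_sequence

ignore_id = -1
xs_pad = torch.tensor([[4, 5, 6], [7, 8, ignore_id]])

ilens = torch.sum(xs_pad != ignore_id, dim=1)          # tensor([3, 2])
xs = [x[x != ignore_id] for x in xs_pad]               # parse padded xs
xs_zero_pad = pad_sequence(xs, batch_first=True,
                           padding_value=0)            # [[4, 5, 6], [7, 8, 0]]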
Example #15
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (make_pad_mask(ilens)).to(xs_pad.device)

        self.wav2vec.feature_grad_mult = 0  # ensure the conv feature extractor stays frozen
        xs_pad = self.wav2vec.forward(xs_pad,
                                      mask=True,
                                      padding_mask=masks,
                                      features_only=True)['x']
        feats_lens = []
        for lens in ilens:
            feats_lens.append(
                get_output_lens(self.wav2vec.feature_extractor.conv_layers,
                                lens))
        olens = torch.stack(feats_lens)

        # xs_pad = self.projection(xs_pad)

        return xs_pad, olens, None
Example #16
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if (isinstance(self.embed, Conv2dSubsampling)
                or isinstance(self.embed, Conv2dSubsampling6)
                or isinstance(self.embed, Conv2dSubsampling8)):
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
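The final olens recovery in these encoder examples is just a row-sum over the boolean validity mask; a minimal sketch (no subsampling) assuming the ~make_pad_mask convention of True at real frames:

import torch

ilens = torch.tensor([4, 2])
T = int(ilens.max())
# (B, 1, T) validity mask, True at real frames
masks = (torch.arange(T).unsqueeze(0) < ilens.unsqueeze(1)).unsqueeze(1)
olens = masks.squeeze(1).sum(1)  # tensor([4, 2]) -- recovers ilens when nothing is subsampled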
Example #17
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats

        # 5. compute bleu
        if self.training or self.error_calculator is None:
            bleu = 0.0
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_mt
        self.loss = loss

        loss_data = float(self.loss)
        if self.normalize_length:
            self.ppl = np.exp(loss_data)
        else:
            ys_out_pad = ys_out_pad.view(-1)
            ignore = ys_out_pad == self.ignore_id  # (B * Lmax,)
            total = len(ys_out_pad) - ignore.sum().item()
            self.ppl = np.exp(loss_data * ys_out_pad.size(0) / total)
        if not math.isnan(loss_data):
            self.reporter.report(loss_data, self.acc, self.ppl, bleu)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Example #18
    def _forward(
        self,
        xs,
        ilens,
        ys=None,
        olens=None,
        spembs=None,
        ds=None,
        is_inference=False,
        alpha=1.0,
    ):
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)
        if is_inference:
            d_outs = self.duration_predictor.inference(hs,
                                                       d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, ilens,
                                       alpha)  # (B, Lmax, adim)
        else:
            if ds is None:
                with torch.no_grad():
                    ds = self.duration_calculator(xs, ilens, ys, olens,
                                                  spembs)  # (B, Tmax)
            d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None:
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        if is_inference:
            return before_outs, after_outs, d_outs
        else:
            return before_outs, after_outs, ds, d_outs
Example #19
 def forward(self, xs, activation=None):
     ilens = [x.shape[0] for x in xs]
     xs_pad = pad_sequence(xs, batch_first=True, padding_value=-1)
     pad_shape = xs_pad.shape  # xs is a list of tensors; take the padded batch shape
     src_mask = (~make_pad_mask(ilens)).to(xs_pad.device).unsqueeze(-2)
     hs_pad, hs_mask = self.enc.forward(xs_pad, src_mask)
     ys_pad = self.linear(hs_pad)
     return ys_pad
Example #20
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1),
                            device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(memory.device)

        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens
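How the (B, L, L) target mask above combines padding and causality, sketched with torch.tril standing in for subsequent_mask:

import torch

ys_in_lens = torch.tensor([3, 2])
L = int(ys_in_lens.max())

# (B, 1, L): True at non-padded target positions (the ~make_pad_mask convention)
tgt_mask = (torch.arange(L).unsqueeze(0) < ys_in_lens.unsqueeze(1)).unsqueeze(1)
# (1, L, L): lower-triangular causal mask
m = torch.tril(torch.ones(L, L, dtype=torch.bool)).unsqueeze(0)
tgt_mask = tgt_mask & m  # (B, L, L); row i marks what token i may attend to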
Example #21
 def store_penultimate_state(self, xs_pad, ilens, ys_pad):
     self.eval()
     xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
     src_mask = (~make_pad_mask(ilens.tolist())).to(
         xs_pad.device).unsqueeze(-2)
     hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
     # plot penultimate_state, (B,T,att_dim)
     return hs_pad.squeeze(0).detach().cpu().numpy()
Example #22
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 1. encoder
        if self.etype == 'transformer':
            xs_pad = xs_pad[:, :max(ilens)]
            src_mask = (~make_pad_mask(ilens.tolist())).to(
                xs_pad.device).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        else:
            hs_pad, hlens = xs_pad, ilens
            hs_pad, hlens, _ = self.encoder(hs_pad, hlens)
            hs_mask = hlens
        self.hs_pad = hs_pad

        # 1.5. transducer preparation related
        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if self.dtype == 'transformer':
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            if self.rnnt_mode == 'rnnt':
                pred_pad = self.decoder(hs_pad, ys_in_pad)
            else:
                pred_pad = self.decoder(hs_pad, ys_in_pad, pred_len)
        self.pred_pad = pred_pad

        # 3. loss computation
        loss = self.criterion(pred_pad, target, pred_len, target_len)

        self.loss = loss
        loss_data = float(self.loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)

        return self.loss
Example #23
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:

    """
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    ilens_ = ilens.to(x.device, x.dtype).view(-1,
                                              *[1 for _ in range(x.dim() - 1)])
    # Zero padding
    if x.requires_grad:
        x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
    else:
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
    # mean: (B, 1, D)
    mean = x.sum(dim=1, keepdim=True) / ilens_

    if norm_means:
        x -= mean

        if norm_vars:
            var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x = x / std  # std is already the square root of the variance
        return x, ilens
    else:
        if norm_vars:
            y = x - mean
            y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
            var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x /= std
        return x, ilens
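A quick behavioural check of utterance_mvn (the assertion only looks at the valid frames; padded frames are excluded via ilens):

import torch

x = torch.randn(2, 5, 3)
ilens = torch.tensor([5, 3])
x[1, 3:] = 0.0                      # zero the padded frames, as the docstring assumes

y, _ = utterance_mvn(x, ilens, norm_means=True, norm_vars=False)
assert y[1, :3].mean(dim=0).abs().max() < 1e-5   # valid frames are now zero-mean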
Example #24
    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor = None,
        olens: torch.Tensor = None,
        ds: torch.Tensor = None,
        spembs: torch.Tensor = None,
        is_inference: bool = False,
        alpha: float = 1.0,
    ) -> Sequence[torch.Tensor]:
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(ys)
            hs = hs + style_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)
        if is_inference:
            d_outs = self.duration_predictor.inference(hs,
                                                       d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, ilens,
                                       alpha)  # (B, Lmax, adim)
        else:
            d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        return before_outs, after_outs, d_outs
Example #25
    def forward(
        self, xs: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """Mask estimator forward function.

        Args:
            xs: (B, F, C, T)
            ilens: (B,)
        Returns:
            masks: A tuple of the estimated masks, each (B, F, C, T)
            ilens: (B,)
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)

        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
        if is_complex(xs):
            xs = (xs.real**2 + xs.imag**2)**0.5
        # xs: (B, C, T, F) -> xs: (B * C, T, F)
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # ilens: (B,) -> ilens_: (B * C)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

        # xs: (B * C, T, F) -> xs: (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # xs: (B * C, T, D) -> xs: (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

        masks = []
        for linear in self.linears:
            # xs: (B, C, T, D) -> mask:(B, C, T, F)
            mask = linear(xs)

            if self.nonlinear == "sigmoid":
                mask = torch.sigmoid(mask)
            elif self.nonlinear == "relu":
                mask = torch.relu(mask)
            elif self.nonlinear == "tanh":
                mask = torch.tanh(mask)
            elif self.nonlinear == "crelu":
                mask = torch.clamp(mask, min=0, max=1)
            # Zero padding (assign the result; masked_fill is out-of-place)
            mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)

            # Takes care of multi-GPU cases where input_length > max(ilens)
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)

        return tuple(masks), ilens
Example #26
    def forward(self, x: torch.Tensor, ilens: torch.LongTensor) \
            -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
Example #27
    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)

        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        return logmel_feat, ilens
Example #28
    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor = None,
        olens: torch.Tensor = None,
        ds: torch.Tensor = None,
        ps: torch.Tensor = None,
        es: torch.Tensor = None,
    ):
        x_masks = self._source_mask(ilens)  # (B, 1, Tmax)
        y_masks = self._source_mask(olens)  # (B, 1, Lmax)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        d_masks = make_pad_mask(ilens).to(xs.device)
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))

        d_outs = self.duration_predictor(hs, d_masks)
        # use groundtruth in training
        p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)

        hs = hs + e_embs + p_embs
        mu = self.length_regulator(hs, ds)  # (B, Lmax, adim)

        mu, _ = self.pre_decoder(mu, y_masks)  # (B, Lmax, adim)

        mu = self.feat_out(mu)  # (B, Lmax, odim)

        mu = mu.transpose(1, 2)  # (B, odim, Lmax)
        if mu.size(2) % 4 != 0:
            mu = torch.cat([
                mu,
                torch.zeros([mu.size(0), self.odim, 4 - mu.size(2) % 4],
                            dtype=mu.dtype,
                            device=mu.device)
            ],
                           dim=2)
            y_masks = torch.cat([
                y_masks,
                torch.zeros([y_masks.size(0), 1, 4 - y_masks.size(2) % 4],
                            dtype=y_masks.dtype,
                            device=y_masks.device)
            ],
                                dim=2)
        noise_estimation, z = self.decoder(ys, y_masks, mu)

        return noise_estimation, z, d_outs, p_outs, e_outs, mu, y_masks
Example #29
 def expand_to(self, xs, lens):
     """
         xs: (B, D)
         lens: (B,)
     """
     # (B, T, 1)
     mask = to_device(xs, make_pad_mask(lens).unsqueeze(-1))
     # (B, D) -> (B, 1, D) -> (B, T, D)
     xs = xs.unsqueeze(1).expand(-1, mask.size(1),
                                 -1).masked_fill(mask, 0.0)
     return xs
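expand_to usage, sketched standalone (the mask construction replays the make_pad_mask + unsqueeze(-1) pattern from the method):

import torch

xs = torch.tensor([[1., 2.], [3., 4.]])   # (B, D)
lens = torch.tensor([3, 1])
T = int(lens.max())

mask = (torch.arange(T).unsqueeze(0) >= lens.unsqueeze(1)).unsqueeze(-1)  # (B, T, 1)
out = xs.unsqueeze(1).expand(-1, T, -1).masked_fill(mask, 0.0)
# out[0] repeats [1., 2.] for all 3 steps; out[1] is [[3., 4.], [0., 0.], [0., 0.]]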
Example #30
    def forward(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """STFT forward function.

        Args:
            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)

        """
        bs = input.size(0)
        if input.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))
        else:
            multi_channel = False

        # output: (Batch, Freq, Frames, 2=real_imag)
        # or (Batch * Channels, Freq, Frames, 2=real_imag) for multi-channel input
        output = torch.stft(
            input,
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            center=self.center,
            pad_mode=self.pad_mode,
            normalized=self.normalized,
            onesided=self.onesided,
        )
        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(bs, -1, output.size(1), output.size(2),
                                 2).transpose(1, 2)

        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad

            olens = (ilens - self.win_length) // self.hop_length + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens
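And a last arithmetic check for the olens computation above, assuming n_fft == win_length so that torch.stft's center padding (n_fft // 2 per side) matches the win_length // 2 used here:

n_fft = win_length = 400
hop_length = 160
nsamples = 16000

pad = win_length // 2                            # 200 per side
ilens = nsamples + 2 * pad                       # 16400
olens = (ilens - win_length) // hop_length + 1   # 101
assert olens == 1 + nsamples // hop_length       # same frame count as torch.stft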