Beispiel #1
0
    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Run the inverse STFT on a complex spectrum.

        A 4-D input is treated as multi-channel: the channel axis is folded
        into the batch axis before the inverse transform and unfolded again
        afterwards.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            The ``(wav, wav_lens)`` pair produced by ``self.stft.inverse``.
            For 4-D input, ``wav`` is rearranged from
            (Batch * C, Nsamples) to (Batch, Nsamples, C).

        Raises:
            TypeError: if ``input`` is neither a ``ComplexTensor`` nor (when
                ``is_torch_1_9_plus`` is truthy) a native complex tensor.
        """
        if not isinstance(input, ComplexTensor):
            # The native-complex check only applies when is_torch_1_9_plus
            # is truthy; otherwise the input is passed through unchecked.
            if is_torch_1_9_plus and not torch.is_complex(input):
                raise TypeError(
                    "Only support complex tensors for stft decoder")

        batch_size = input.size(0)
        has_channel_axis = input.dim() == 4

        if has_channel_axis:
            # (Batch, T, C, F) -> (Batch, C, T, F) -> (Batch * C, T, F)
            num_frames = input.size(1)
            num_freqs = input.size(3)
            channel_first = input.transpose(1, 2)
            input = channel_first.reshape(-1, num_frames, num_freqs)

        wav, wav_lens = self.stft.inverse(input, ilens)

        if has_channel_axis:
            # (Batch * C, Nsamples) -> (Batch, C, Nsamples)
            #                       -> (Batch, Nsamples, C)
            num_samples = wav.size(1)
            per_channel = wav.reshape(batch_size, -1, num_samples)
            wav = per_channel.transpose(1, 2)

        return wav, wav_lens
Beispiel #2
0
    def forward(
        self, xs: ComplexTensor, input_lengths: torch.LongTensor
    ) -> torch.Tensor:
        """Estimate a sigmoid mask from a complex multi-channel spectrum.

        The complex input is first converted to real-valued features
        according to ``self.feat_type``, passed through ``self.net``
        according to ``self.model_type``, projected by ``self.linear`` and
        squashed with a sigmoid. Positions selected by ``make_pad_mask`` are
        then set to zero.

        Args:
            xs (ComplexTensor): complex spectrum, (B, C, T, F).
            input_lengths (torch.LongTensor): lengths, (B,). Its first
                dimension must match ``xs.size(0)``.

        Returns:
            torch.Tensor: mask of shape (B, C, T, F) per the inline shape
            comments, with values from ``torch.sigmoid`` and zeros at the
            positions flagged by ``make_pad_mask`` along dim 2.

        Raises:
            AssertionError: if the batch sizes of ``xs`` and
                ``input_lengths`` differ.
            NotImplementedError: if ``self.feat_type`` or
                ``self.model_type`` is not one of the handled values.
        """
        assert xs.size(0) == input_lengths.size(0), (xs.size(0),
                                                     input_lengths.size(0))

        # xs: (B, C, T, D)
        C = xs.size(1)
        if self.feat_type == 'amplitude':
            # xs: (B, C, T, F) -> (B, C, T, F)
            xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5
        elif self.feat_type == 'power':
            # xs: (B, C, T, F) -> (B, C, T, F)
            xs = xs.real ** 2 + xs.imag ** 2
        elif self.feat_type == 'log_power':
            # xs: (B, C, T, F) -> (B, C, T, F)
            xs = torch.log(xs.real ** 2 + xs.imag ** 2)
        elif self.feat_type == 'concat':
            # xs: (B, C, T, F) -> (B, C, T, 2 * F)
            xs = torch.cat([xs.real, xs.imag], -1)
        else:
            raise NotImplementedError(f'Not implemented: {self.feat_type}')

        if self.model_type in ('blstm', 'lstm'):
            # xs: (B, C, T, F) -> xs: (B, C, T, D)
            xs = self.net(xs, input_lengths)

        elif self.model_type == 'cnn':
            if self.channel_independent:
                # xs: (B, C, T, F) -> xs: (B * C, F, T)
                # reshape (not view) so a non-contiguous feature tensor
                # does not raise; identical result when view would work.
                xs = xs.reshape(-1, *xs.size()[2:]).transpose(1, 2)
                # xs: (B * C, F, T) -> xs: (B * C, D, T)
                xs = self.net(xs)
                # xs: (B * C, D, T) -> (B, C, T, D)
                xs = xs.transpose(1, 2).contiguous().view(
                    -1, C, xs.size(2), xs.size(1))
            else:
                # xs: (B, C, T, F) -> xs: (B, C, T, F)
                xs = self.net(xs)
        else:
            raise NotImplementedError(f'Not implemented: {self.model_type}')

        # xs: (B, C, T, D) -> out:(B, C, T, F)
        out = self.linear(xs)
        out = torch.sigmoid(out)
        # Zero padding. masked_fill is out-of-place, so the result must be
        # assigned back; previously it was discarded and padded positions
        # kept their sigmoid values. The in-place masked_fill_ is avoided on
        # purpose: it would overwrite the sigmoid output that autograd needs
        # for the backward pass.
        pad_mask = make_pad_mask(input_lengths, out, length_dim=2)
        out = out.masked_fill(pad_mask, 0)

        return out