Example #1
import torch

from espnet2.layers.stft import Stft  # assumed import path for these examples


def test_inverse():
    layer = Stft()
    x = torch.randn(2, 400, requires_grad=True)
    y, _ = layer(x)
    # inverse STFT with explicit per-utterance sample lengths
    x_lengths = torch.IntTensor([400, 300])
    raw, _ = layer.inverse(y, x_lengths)
    # inverse STFT without lengths
    raw, _ = layer.inverse(y)
Example #2
    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        rnn_type: str = "blstm",
        layer: int = 3,
        unit: int = 512,
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
    ):
        super(TFMaskingNet, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "magnitude", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )

        if utt_mvn:
            self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
        else:
            self.utt_mvn = None

        self.rnn = RNN(
            idim=self.num_bin,
            elayers=layer,
            cdim=unit,
            hdim=unit,
            dropout=dropout,
            typ=rnn_type,
        )

        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(unit, self.num_bin) for _ in range(self.num_spk)])

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]
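
A minimal standalone sketch of the per-speaker mask heads built above: one Linear layer per speaker followed by a shared nonlinearity (plain PyTorch; the shapes are illustrative assumptions).

import torch

num_spk, unit, num_bin = 2, 512, 257  # num_bin = n_fft // 2 + 1 for n_fft = 512
heads = torch.nn.ModuleList(
    [torch.nn.Linear(unit, num_bin) for _ in range(num_spk)]
)
act = torch.nn.Sigmoid()

x = torch.randn(4, 100, unit)       # (Batch, Frames, unit), e.g. RNN output
masks = [act(h(x)) for h in heads]  # each (Batch, Frames, num_bin), values in [0, 1]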
Example #3
def test_forward():
    layer = Stft(win_length=4, hop_length=2, n_fft=4)
    x = torch.randn(2, 30)
    y, _ = layer(x)
    # (batch, frames, freq, real/imag): onesided n_fft=4 gives 3 bins, 30 // 2 + 1 = 16 frames
    assert y.shape == (2, 16, 3, 2)
    y, ylen = layer(x, torch.tensor([30, 15], dtype=torch.long))
    # per-utterance frame counts: 30 // 2 + 1 = 16 and 15 // 2 + 1 = 8
    assert (ylen == torch.tensor((16, 8), dtype=torch.long)).all()
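
The frame counts asserted above follow the usual centered-STFT arithmetic; a minimal sketch of that formula, independent of the Stft class:

def expected_frames(n_samples: int, hop_length: int) -> int:
    # With center=True padding, torch.stft yields 1 + n_samples // hop_length frames.
    return 1 + n_samples // hop_length

assert expected_frames(30, 2) == 16
assert expected_frames(15, 2) == 8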
Example #4
    def __init__(
        self,
        window_sz=[512],
        hop_sz=None,
        eps=1e-8,
        time_domain_weight=0.5,
        name=None,
        only_for_test=False,
    ):
        _name = "TD_L1_loss" if name is None else name
        super(MultiResL1SpecLoss, self).__init__(_name,
                                                 only_for_test=only_for_test)

        assert all([x % 2 == 0 for x in window_sz])
        self.window_sz = window_sz

        if hop_sz is None:
            self.hop_sz = [x // 2 for x in window_sz]
        else:
            self.hop_sz = hop_sz

        self.time_domain_weight = time_domain_weight
        self.eps = eps
        self.stft_encoders = torch.nn.ModuleList([])
        for w, h in zip(self.window_sz, self.hop_sz):
            stft_enc = Stft(
                n_fft=w,
                win_length=w,
                hop_length=h,
                window=None,
                center=True,
                normalized=False,
                onesided=True,
            )
            self.stft_encoders.append(stft_enc)
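
A quick standalone check of the hop-size defaulting in the constructor above: when hop_sz is None, each resolution falls back to a half-window hop (the window sizes here are illustrative assumptions).

window_sz = [256, 512, 1024]
hop_sz = None

assert all(w % 2 == 0 for w in window_sz)
if hop_sz is None:
    hop_sz = [w // 2 for w in window_sz]  # 50% overlap per resolution

assert hop_sz == [128, 256, 512]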
Example #5
    def __init__(
        self,
        fs: Union[int, str] = 22050,
        n_fft: int = 1024,
        win_length: int = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        use_token_averaged_energy: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        self.fs = fs
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.use_token_averaged_energy = use_token_averaged_energy

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )
Example #6
def test_gevd(ch):
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    ilens = torch.LongTensor([16, 12])
    # (B, T, C, F) -> (B, F, C, T)
    X = torch.complex(*torch.unbind(stft(x, ilens)[0], dim=-1)).transpose(
        -1, -3)
    # (B, F, C, C)
    Phi_X = torch.einsum("...ct,...et->...ce", [X, X.conj()])

    is_singular = True
    while is_singular:
        N = torch.randn_like(X)
        Phi_N = torch.einsum("...ct,...et->...ce", [N, N.conj()])
        is_singular = not torch.linalg.matrix_rank(Phi_N).eq(ch).all()
    # Phi_N = torch.eye(ch, dtype=Phi_X.dtype).view(1, 1, ch, ch).expand_as(Phi_X)

    # e_val: (B, F, C)
    # e_vec: (B, F, C, C)
    e_val, e_vec = generalized_eigenvalue_decomposition(Phi_X, Phi_N)
    e_val = e_val.to(dtype=e_vec.dtype)
    assert torch.allclose(
        torch.matmul(Phi_X, e_vec),
        torch.matmul(torch.matmul(Phi_N, e_vec), e_val.diag_embed()),
    )
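
The assertion above checks the generalized eigenvalue identity Phi_X @ V == Phi_N @ V @ diag(e_val). A small self-contained check of the same identity with SciPy (a hypothetical illustration, not ESPnet code):

import numpy as np
from scipy.linalg import eigh  # solves the generalized problem A v = w B v

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
B = rng.standard_normal((4, 4))
A = A @ A.T                   # symmetric, playing the role of Phi_X
B = B @ B.T + 4 * np.eye(4)   # symmetric positive definite, a full-rank Phi_N

w, v = eigh(A, B)
assert np.allclose(A @ v, B @ v @ np.diag(w))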
Example #7
    def __init__(
        self,
        n_fft: int = 1024,
        win_length: int = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
        )
Example #8
    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        apply_stft: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = "default"
Example #9
    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )
Example #10
def test_get_rtf(ch, mode):
    if not is_torch_1_9_plus and mode == "evd":
        # torch 1.9.0+ is required for "evd" mode
        return
    if mode == "evd":
        complex_wrapper = torch.complex
        complex_module = torch
    else:
        complex_wrapper = ComplexTensor
        complex_module = FC
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    ilens = torch.LongTensor([16, 12])
    # (B, T, C, F) -> (B, F, C, T)
    X = complex_wrapper(*torch.unbind(stft(x, ilens)[0], dim=-1)).transpose(
        -1, -3)
    # (B, F, C, C)
    Phi_X = complex_module.einsum("...ct,...et->...ce", [X, X.conj()])

    is_singular = True
    while is_singular:
        N = complex_wrapper(torch.randn_like(X.real), torch.randn_like(X.imag))
        Phi_N = complex_module.einsum("...ct,...et->...ce", [N, N.conj()])
        is_singular = not np.all(np.linalg.matrix_rank(Phi_N.numpy()) == ch)

    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, mode=mode, reference_vector=0, iterations=20)
    if is_torch_1_1_plus:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True).values + 1e-15)
    else:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True)[0] + 1e-15)
    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available in pytorch 1.1.0+
        mat = solve(Phi_X, Phi_N)[0]
        max_eigenvec = solve(rtf, Phi_N)[0]
    else:
        mat = complex_module.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = complex_module.matmul(Phi_N.inverse2(), rtf)
    factor = complex_module.matmul(mat, max_eigenvec)
    assert complex_module.allclose(
        complex_module.matmul(max_eigenvec, factor.transpose(-1, -2)),
        complex_module.matmul(factor, max_eigenvec.transpose(-1, -2)),
    )
Example #11
class STFTDecoder(AbsDecoder):
    """STFT decoder for speech enhancement and separation"""
    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]
            ilens (torch.Tensor): input lengths [Batch]
        """
        if not isinstance(input, ComplexTensor) and (
                is_torch_1_9_plus and not torch.is_complex(input)):
            raise TypeError("Only support complex tensors for stft decoder")

        bs = input.size(0)
        if input.dim() == 4:
            multi_channel = True
            # input: (Batch, T, C, F) -> (Batch * C, T, F)
            input = input.transpose(1, 2).reshape(-1, input.size(1),
                                                  input.size(3))
        else:
            multi_channel = False

        wav, wav_lens = self.stft.inverse(input, ilens)

        if multi_channel:
            # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
            wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

        return wav, wav_lens
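
A hedged round-trip sketch for the decoder above: analyze with Stft, wrap the stacked real/imag output as a ComplexTensor, then synthesize with STFTDecoder. Passing the original sample lengths to the decoder mirrors the forward_rawwav examples elsewhere on this page; the sizes are illustrative assumptions.

import torch
from torch_complex.tensor import ComplexTensor

stft = Stft(n_fft=512, hop_length=128)
decoder = STFTDecoder(n_fft=512, hop_length=128)

wav = torch.randn(2, 16000)
ilens = torch.LongTensor([16000, 12000])

spec, flens = stft(wav, ilens)                   # (Batch, Frames, Freq, 2)
spec = ComplexTensor(spec[..., 0], spec[..., 1])
recon, recon_lens = decoder(spec, ilens)         # (Batch, Nsamples)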
Example #12
    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 1024,
        win_length: int = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: Optional[int] = 80,
        fmax: Optional[int] = 7600,
        htk: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
        )

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
            log_base=10.0,
        )
Example #13
def test_get_rtf(ch):
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    n = torch.rand(2, 16, ch, dtype=torch.double)
    ilens = torch.LongTensor([16, 12])
    # (B, T, C, F) -> (B, F, C, T)
    X = ComplexTensor(*torch.unbind(stft(x, ilens)[0], dim=-1)).transpose(-1, -3)
    N = ComplexTensor(*torch.unbind(stft(n, ilens)[0], dim=-1)).transpose(-1, -3)
    # (B, F, C, C)
    Phi_X = FC.einsum("...ct,...et->...ce", [X, X.conj()])
    Phi_N = FC.einsum("...ct,...et->...ce", [N, N.conj()])
    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, reference_vector=0, iterations=20)
    if is_torch_1_1_plus:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True).values + 1e-15)
    else:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True)[0] + 1e-15)
    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required, which is only available in pytorch 1.1.0+
        mat = FC.solve(Phi_X, Phi_N)[0]
        max_eigenvec = FC.solve(rtf, Phi_N)[0]
    else:
        mat = FC.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = FC.matmul(Phi_N.inverse2(), rtf)
    factor = FC.matmul(mat, max_eigenvec)
    assert FC.allclose(
        FC.matmul(max_eigenvec, factor.transpose(-1, -2)),
        FC.matmul(factor, max_eigenvec.transpose(-1, -2)),
    )
Example #14
class STFTDecoder(AbsDecoder):
    """STFT decoder for speech enhancement and separation"""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]
        """
        if not isinstance(input, ComplexTensor) and (
            is_torch_1_9_plus and not torch.is_complex(input)
        ):
            raise TypeError("Only support complex tensors for stft decoder")

        wav, wav_lens = self.stft.inverse(input, ilens)

        return wav, wav_lens
Example #15
    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        use_builtin_complex: bool = True,
    ):
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

        self._output_dim = n_fft // 2 + 1 if onesided else n_fft
        self.use_builtin_complex = use_builtin_complex
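
The onesided flag above determines the encoder's feature dimension; a one-line check of that arithmetic:

n_fft, onesided = 512, True
output_dim = n_fft // 2 + 1 if onesided else n_fft
assert output_dim == 257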
Example #16
def test_repr():
    print(Stft())
Example #17
class TFMaskingTransformer(AbsEnhancement):
    """TF Masking Speech Separation Net."""

    def __init__(
        self,
        n_fft: int = 256, 
        win_length: int = None,
        hop_length: int = 128,
        dnn_type: str = "transformer",
        #layer: int = 3,
        #unit: int = 512,
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
        d_model: int = 256,
        nhead: int = 4,
        linear_units: int = 2048,
        num_layers: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "linear",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
    ):
        super(TFMaskingTransformer, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "magnitude", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )

        if utt_mvn:
            self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
        else:
            self.utt_mvn = None

        #self.rnn = RNN(
        #    idim=self.num_bin,
        #    elayers=layer,
        #    cdim=unit,
        #    hdim=unit,
        #    dropout=dropout,
        #    typ=rnn_type,
        #)

        self.encoder = TransformerEncoder(
             input_size=self.num_bin,
             output_size=d_model,
             attention_heads=nhead,
             linear_units=linear_units,
             num_blocks=num_layers,
             positional_dropout_rate=positional_dropout_rate,
             attention_dropout_rate=attention_dropout_rate,
             input_layer=input_layer,
             normalize_before=normalize_before,
             concat_after=concat_after,
             positionwise_layer_type=positionwise_layer_type,
             positionwise_conv_kernel_size=positionwise_conv_kernel_size,
             padding_idx=padding_idx,
        )
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(d_model, self.num_bin) for _ in range(self.num_spk)]
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            separated (list[ComplexTensor]): [(B, T, F), ...]
            ilens (torch.Tensor): (B,)
            predicted masks: OrderedDict[
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> stft -> magnitude spectrum
        input_spectrum, flens = self.stft(input, ilens)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        input_magnitude = abs(input_spectrum)
        input_phase = input_spectrum / (input_magnitude + 10e-12)

        # apply utt mvn
        if self.utt_mvn:
            input_magnitude_mvn, fle = self.utt_mvn(input_magnitude, flens)
        else:
            input_magnitude_mvn = input_magnitude

        # predict masks for each speaker
        #x, flens, _ = self.rnn(input_magnitude_mvn, flens)
        x, olens, _ = self.encoder(input_magnitude_mvn, flens)

        masks = []
        for linear in self.linear:
            y = linear(x)
            y = self.nonlinear(y)
            masks.append(y)

        if self.training and self.loss_type.startswith("mask"):
            predicted_spectrums = None
        else:
            # apply mask
            predict_magnitude = [input_magnitude * m for m in masks]
            predicted_spectrums = [input_phase * pm for pm in predict_magnitude]

        masks = OrderedDict(
            zip(["spk{}".format(i + 1) for i in range(len(masks))], masks)
        )
        return predicted_spectrums, flens, masks

    def forward_rawwav(
        self, input: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Output with waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted speech [Batch, num_speaker, sample]
            output lengths
            predicted masks: OrderedDict[
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """

        # predict spectrum for each speaker
        predicted_spectrums, flens, masks = self.forward(input, ilens)

        if predicted_spectrums is None:
            predicted_wavs = None
        elif isinstance(predicted_spectrums, list):
            # multi-speaker output
            predicted_wavs = [
                self.stft.inverse(ps, ilens)[0] for ps in predicted_spectrums
            ]
        else:
            # single-speaker output
            predicted_wavs = self.stft.inverse(predicted_spectrums, ilens)[0]

        return predicted_wavs, ilens, masks
Example #18
def test_backward_not_leaf_in():
    layer = Stft()
    x = torch.randn(2, 400, requires_grad=True)
    x = x + 2  # x is no longer a leaf tensor after this op
    y, _ = layer(x)
    y.sum().backward()
Example #19
    def __init__(
        self,
        n_fft: int = 256, 
        win_length: int = None,
        hop_length: int = 128,
        dnn_type: str = "transformer",
        #layer: int = 3,
        #unit: int = 512,
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
        d_model: int = 256,
        nhead: int = 4,
        linear_units: int = 2048,
        num_layers: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "linear",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
    ):
        super(TFMaskingTransformer, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "magnitude", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )

        if utt_mvn:
            self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
        else:
            self.utt_mvn = None

        #self.rnn = RNN(
        #    idim=self.num_bin,
        #    elayers=layer,
        #    cdim=unit,
        #    hdim=unit,
        #    dropout=dropout,
        #    typ=rnn_type,
        #)

        self.encoder = TransformerEncoder(
             input_size=self.num_bin,
             output_size=d_model,
             attention_heads=nhead,
             linear_units=linear_units,
             num_blocks=num_layers,
             positional_dropout_rate=positional_dropout_rate,
             attention_dropout_rate=attention_dropout_rate,
             input_layer=input_layer,
             normalize_before=normalize_before,
             concat_after=concat_after,
             positionwise_layer_type=positionwise_layer_type,
             positionwise_conv_kernel_size=positionwise_conv_kernel_size,
             padding_idx=padding_idx,
        )
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(d_model, self.num_bin) for _ in range(self.num_spk)]
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]
Example #20
    def __init__(
        self,
        num_spk: int = 1,
        normalize_input: bool = False,
        mask_type: str = "IPM^2",
        # STFT options
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        center: bool = True,
        window: Optional[str] = "hann",
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        beamformer_type="mvdr",
        bdropout_rate=0.0,
    ):
        super(BeamformerNet, self).__init__()

        self.mask_type = mask_type

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            window=window,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
        )

        self.normalize_input = normalize_input
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                iterations = 1
            else:
                # Perform conventional WPE without a DNN estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=self.num_bin,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=bnet_type,
                bidim=self.num_bin,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                btaps=taps,
                bdelay=delay,
            )
        else:
            self.beamformer = None
Example #21
class BeamformerNet(AbsEnhancement):
    """TF Masking based beamformer

    """

    def __init__(
        self,
        num_spk: int = 1,
        normalize_input: bool = False,
        mask_type: str = "IPM^2",
        # STFT options
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        center: bool = True,
        window: Optional[str] = "hann",
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        beamformer_type="mvdr",
        bdropout_rate=0.0,
    ):
        super(BeamformerNet, self).__init__()

        self.mask_type = mask_type

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            window=window,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
        )

        self.normalize_input = normalize_input
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                iterations = 1
            else:
                # Perform conventional WPE without a DNN estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=self.num_bin,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=bnet_type,
                bidim=self.num_bin,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                btaps=taps,
                bdelay=delay,
            )
        else:
            self.beamformer = None

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech  (single-channel):
                torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> stft -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        enhanced = input_spectrum
        masks = OrderedDict()

        if input_spectrum.dim() == 3:
            # single-channel input
            if self.use_wpe:
                # (B, T, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum.unsqueeze(-2), flens)
                enhanced = enhanced.squeeze(-2)
                if mask_w is not None:
                    masks["dereverb"] = mask_w.squeeze(-2)

        elif input_spectrum.dim() == 4:
            # multi-channel input
            # 1. WPE
            if self.use_wpe:
                # (B, T, C, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                if mask_w is not None:
                    masks["dereverb"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                # enhanced: (B, T, C, F) -> (B, T, F)
                enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

        else:
            raise ValueError(
                "Invalid spectrum dimension: {}".format(input_spectrum.shape)
            )

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()
        return enhanced, flens, masks

    def forward_rawwav(self, input: torch.Tensor, ilens: torch.Tensor):
        """Output with wavformes.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted speech wavs (single-channel):
                torch.Tensor(Batch, Nsamples), or List[torch.Tensor(Batch, Nsamples)]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        enhanced, flens, masks = self.forward(input, ilens)
        if isinstance(enhanced, list):
            # multi-speaker output
            predicted_wavs = [self.stft.inverse(ps, ilens)[0] for ps in enhanced]
        else:
            # single-speaker output
            predicted_wavs = self.stft.inverse(enhanced, ilens)[0]

        return predicted_wavs, ilens, masks
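
A hedged usage sketch for BeamformerNet.forward_rawwav on multi-channel input; all sizes and constructor arguments here are illustrative assumptions.

import torch

net = BeamformerNet(num_spk=1, use_beamformer=True, use_wpe=False)
mix = torch.randn(2, 16000, 4)            # (Batch, Nsample, Channel)
ilens = torch.LongTensor([16000, 12000])

wavs, olens, masks = net.forward_rawwav(mix, ilens)
# wavs: (Batch, Nsamples) for one speaker, or a list of tensors per speaker
# masks: OrderedDict with 'spk1', ..., plus 'noise1' when a noise mask is estimated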
Example #22
class TFMaskingNet(AbsEnhancement):
    """TF Masking Speech Separation Net."""
    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        rnn_type: str = "blstm",
        layer: int = 3,
        unit: int = 512,
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
    ):
        super(TFMaskingNet, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "magnitude", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )

        if utt_mvn:
            self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
        else:
            self.utt_mvn = None

        self.rnn = RNN(
            idim=self.num_bin,
            elayers=layer,
            cdim=unit,
            hdim=unit,
            dropout=dropout,
            typ=rnn_type,
        )

        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(unit, self.num_bin) for _ in range(self.num_spk)])

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            separated (list[ComplexTensor]): [(B, T, F), ...]
            ilens (torch.Tensor): (B,)
            predicted masks: OrderedDict[
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """

        # wave -> stft -> magnitude spectrum
        input_spectrum, flens = self.stft(input, ilens)
        input_spectrum = ComplexTensor(input_spectrum[..., 0],
                                       input_spectrum[..., 1])
        input_magnitude = abs(input_spectrum)
        input_phase = input_spectrum / (input_magnitude + 10e-12)

        # apply utt mvn
        if self.utt_mvn:
            input_magnitude_mvn, fle = self.utt_mvn(input_magnitude, flens)
        else:
            input_magnitude_mvn = input_magnitude

        # predict masks for each speaker
        x, flens, _ = self.rnn(input_magnitude_mvn, flens)
        masks = []
        for linear in self.linear:
            y = linear(x)
            y = self.nonlinear(y)
            masks.append(y)

        if self.training and self.loss_type.startswith("mask"):
            predicted_spectrums = None
        else:
            # apply mask
            predict_magnitude = [input_magnitude * m for m in masks]
            predicted_spectrums = [
                input_phase * pm for pm in predict_magnitude
            ]

        masks = OrderedDict(
            zip(["spk{}".format(i + 1) for i in range(len(masks))], masks))
        return predicted_spectrums, flens, masks

    def forward_rawwav(
            self, input: torch.Tensor,
            ilens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Output with waveforms.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted speech [Batch, num_speaker, sample]
            output lengths
            predicted masks: OrderedDict[
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """

        # predict spectrum for each speaker
        predicted_spectrums, flens, masks = self.forward(input, ilens)

        if predicted_spectrums is None:
            predicted_wavs = None
        elif isinstance(predicted_spectrums, list):
            # multi-speaker output
            predicted_wavs = [
                self.stft.inverse(ps, ilens)[0] for ps in predicted_spectrums
            ]
        else:
            # single-speaker output
            predicted_wavs = self.stft.inverse(predicted_spectrums, ilens)[0]

        return predicted_wavs, ilens, masks
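
A hedged usage sketch for TFMaskingNet. Note that eval mode matters: in training with a mask-based loss, forward returns predicted_spectrums=None, so forward_rawwav yields no waveforms. Sizes are illustrative assumptions.

import torch

net = TFMaskingNet(num_spk=2).eval()
mix = torch.randn(2, 16000)               # (Batch, sample)
ilens = torch.LongTensor([16000, 12000])

with torch.no_grad():
    wavs, olens, masks = net.forward_rawwav(mix, ilens)
# wavs: list of (Batch, Nsamples) tensors, one per speaker
# masks: OrderedDict with keys 'spk1', 'spk2'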