Example 1
    def __init__(
        self,
        num_spk: int = 1,
        normalize_input: bool = False,
        mask_type: str = "IPM^2",
        # STFT options
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        center: bool = True,
        window: Optional[str] = "hann",
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        beamformer_type: str = "mvdr",
        bdropout_rate: float = 0.0,
    ):
        super(BeamformerNet, self).__init__()

        self.mask_type = mask_type

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            window=window,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
        )

        self.normalize_input = normalize_input
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                iterations = 1
            else:
                # Perform conventional WPE without a DNN mask estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=self.num_bin,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=bnet_type,
                bidim=self.num_bin,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                btaps=taps,
                bdelay=delay,
            )
        else:
            self.beamformer = None
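
For reference, a minimal instantiation sketch follows (not from the source; the commented import path is illustrative and assumes an ESPnet-style checkout where the class and its dependencies are importable):

# Hypothetical usage sketch for the constructor above.
# from espnet2.enh.nets.beamformer_net import BeamformerNet  # illustrative path

net = BeamformerNet(
    num_spk=2,
    n_fft=512,
    hop_length=128,
    use_wpe=True,          # builds self.wpe = DNN_WPE(...)
    use_beamformer=True,   # builds self.beamformer = DNN_Beamformer(...)
)
assert net.num_bin == 512 // 2 + 1  # 257 one-sided frequency bins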
Example 2
class NeuralBeamformer(AbsSeparator):
    def __init__(
        self,
        input_dim: int,
        num_spk: int = 1,
        loss_type: str = "mask_mse",
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        wnonlinear: str = "crelu",
        multi_source_wpe: bool = True,
        wnormalization: bool = False,
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        bnonlinear: str = "sigmoid",
        beamformer_type: str = "mvdr_souden",
        rtf_iterations: int = 2,
        bdropout_rate: float = 0.0,
        shared_power: bool = True,
        # For numerical stability
        diagonal_loading: bool = True,
        diag_eps_wpe: float = 1e-7,
        diag_eps_bf: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres_wpe: float = 1e-6,
        flooring_thres_bf: float = 1e-6,
        use_torch_solver: bool = True,
    ):
        super().__init__()

        self._num_spk = num_spk
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "spectrum", "spectrum_log", "magnitude"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                iterations = 1
            else:
                # Perform conventional WPE without a DNN mask estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=input_dim,
                wlayers=wlayers,
                wunits=wunits,
                wprojs=wprojs,
                dropout_rate=wdropout_rate,
                taps=taps,
                delay=delay,
                use_dnn_mask=use_dnn_mask_for_wpe,
                nmask=1 if multi_source_wpe else num_spk,
                nonlinear=wnonlinear,
                iterations=iterations,
                normalization=wnormalization,
                diagonal_loading=diagonal_loading,
                diag_eps=diag_eps_wpe,
                mask_flooring=mask_flooring,
                flooring_thres=flooring_thres_wpe,
                use_torch_solver=use_torch_solver,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                bidim=input_dim,
                btype=bnet_type,
                blayers=blayers,
                bunits=bunits,
                bprojs=bprojs,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                nonlinear=bnonlinear,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                rtf_iterations=rtf_iterations,
                btaps=taps,
                bdelay=delay,
                diagonal_loading=diagonal_loading,
                diag_eps=diag_eps_bf,
                mask_flooring=mask_flooring,
                flooring_thres=flooring_thres_bf,
                use_torch_solver=use_torch_solver,
            )
        else:
            self.beamformer = None

        # share speech powers between WPE and beamforming (wMPDR/WPD)
        self.shared_power = shared_power and use_wpe

    def forward(
        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.complex64/ComplexTensor):
                mixed speech [Batch, Frames, Channel, Freq]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel): List[torch.complex64/ComplexTensor]
            output lengths
            other predicted data: OrderedDict[
                'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
                'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input.dim() in (3, 4), input.dim()
        enhanced = input
        others = OrderedDict()

        if (
            self.training
            and self.loss_type is not None
            and self.loss_type.startswith("mask")
        ):
            # Only estimate masks during training to save memory
            if self.use_wpe:
                if input.dim() == 3:
                    mask_w, ilens = self.wpe.predict_mask(input.unsqueeze(-2), ilens)
                    mask_w = mask_w.squeeze(-2)
                elif input.dim() == 4:
                    mask_w, ilens = self.wpe.predict_mask(input, ilens)

                if mask_w is not None:
                    if isinstance(mask_w, list):
                        # single-source WPE
                        for spk in range(self.num_spk):
                            others["mask_dereverb{}".format(spk + 1)] = mask_w[spk]
                    else:
                        # multi-source WPE
                        others["mask_dereverb1"] = mask_w

            if self.use_beamformer and input.dim() == 4:
                others_b, ilens = self.beamformer.predict_mask(input, ilens)
                for spk in range(self.num_spk):
                    others["mask_spk{}".format(spk + 1)] = others_b[spk]
                if len(others_b) > self.num_spk:
                    others["mask_noise1"] = others_b[self.num_spk]

            return None, ilens, others

        else:
            powers = None
            # Performing both mask estimation and enhancement
            if input.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(
                        input.unsqueeze(-2), ilens
                    )
                    if isinstance(enhanced, list):
                        # single-source WPE
                        enhanced = [enh.squeeze(-2) for enh in enhanced]
                        if mask_w is not None:
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk].squeeze(-2)
                    else:
                        # multi-source WPE
                        enhanced = enhanced.squeeze(-2)
                        if mask_w is not None:
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                    if mask_w is not None:
                        if isinstance(enhanced, list):
                            # single-source WPE
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk]
                        else:
                            # multi-source WPE
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)

                # 2. Beamformer
                if self.use_beamformer:
                    # Reuse the speech powers from WPE only for wMPDR/WPD
                    # beamformers with power sharing enabled.
                    if (
                        not (
                            self.beamformer.beamformer_type.startswith("wmpdr")
                            or self.beamformer.beamformer_type.startswith("wpd")
                        )
                        or not self.shared_power
                        or (self.wpe.nmask == 1 and self.num_spk > 1)
                    ):
                        powers = None

                    # enhanced: (B, T, C, F) -> (B, T, F)
                    if isinstance(enhanced, list):
                        # outputs of single-source WPE
                        raise NotImplementedError(
                            "Single-source WPE is not supported with beamformer "
                            "in multi-speaker cases."
                        )
                    else:
                        # output of multi-source WPE
                        enhanced, ilens, others_b = self.beamformer(
                            enhanced, ilens, powers=powers
                        )
                    for spk in range(self.num_spk):
                        others["mask_spk{}".format(spk + 1)] = others_b[spk]
                    if len(others_b) > self.num_spk:
                        others["mask_noise1"] = others_b[self.num_spk]

        if not isinstance(enhanced, list):
            enhanced = [enhanced]

        return enhanced, ilens, others

    @property
    def num_spk(self):
        return self._num_spk
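
A minimal usage sketch for the forward pass above (not from the source; shapes follow the docstring, and the commented import path is illustrative). It assumes torch and the torch_complex package are installed:

import torch
from torch_complex.tensor import ComplexTensor
# from espnet2.enh.separator.neural_beamformer import NeuralBeamformer  # illustrative

B, T, C, F = 2, 100, 4, 257  # batch, frames, channels, frequency bins
spec = ComplexTensor(torch.randn(B, T, C, F), torch.randn(B, T, C, F))
ilens = torch.LongTensor([T, T - 2])

model = NeuralBeamformer(input_dim=F, num_spk=1)
# In training mode with a "mask_*" loss, forward returns (None, ilens, masks);
# switch to eval mode to obtain the enhanced spectra.
model.eval()
enhanced, olens, others = model(spec, ilens)
# enhanced: list of per-speaker (B, T, F) complex spectra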
Example 3
class BeamformerNet(AbsEnhancement):
    """TF Masking based beamformer"""
    def __init__(
        self,
        num_spk: int = 1,
        normalize_input: bool = False,
        mask_type: str = "IPM^2",
        loss_type: str = "mask_mse",
        # STFT options
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        center: bool = True,
        window: Optional[str] = "hann",
        normalized: bool = False,
        onesided: bool = True,
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        wnonlinear: str = "crelu",
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        bnonlinear: str = "sigmoid",
        beamformer_type: str = "mvdr",
        bdropout_rate: float = 0.0,
    ):
        super(BeamformerNet, self).__init__()

        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            window=window,
            normalized=normalized,
            onesided=onesided,
        )

        self.normalize_input = normalize_input
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                iterations = 1
            else:
                # Perform conventional WPE without a DNN mask estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=self.num_bin,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
                nonlinear=wnonlinear,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=bnet_type,
                bidim=self.num_bin,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                nonlinear=bnonlinear,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                btaps=taps,
                bdelay=delay,
            )
        else:
            self.beamformer = None

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel):
                torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> STFT -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0],
                                       input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input_spectrum.dim() in (3, 4), input_spectrum.dim()
        enhanced = input_spectrum
        masks = OrderedDict()

        if self.training and self.loss_type.startswith("mask"):
            # Only estimate masks during training to save memory
            if self.use_wpe:
                if input_spectrum.dim() == 3:
                    mask_w, flens = self.wpe.predict_mask(
                        input_spectrum.unsqueeze(-2), flens)
                    mask_w = mask_w.squeeze(-2)
                elif input_spectrum.dim() == 4:
                    if self.use_beamformer:
                        enhanced, flens, mask_w = self.wpe(
                            input_spectrum, flens)
                    else:
                        mask_w, flens = self.wpe.predict_mask(
                            input_spectrum, flens)

                if mask_w is not None:
                    masks["dereverb"] = mask_w

            if self.use_beamformer and input_spectrum.dim() == 4:
                masks_b, flens = self.beamformer.predict_mask(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

            return None, flens, masks

        else:
            # Performing both mask estimation and enhancement
            if input_spectrum.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, flens, mask_w = self.wpe(
                        input_spectrum.unsqueeze(-2), flens)
                    enhanced = enhanced.squeeze(-2)
                    if mask_w is not None:
                        masks["dereverb"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                    if mask_w is not None:
                        masks["dereverb"] = mask_w

                # 2. Beamformer
                if self.use_beamformer:
                    # enhanced: (B, T, C, F) -> (B, T, F)
                    enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                    for spk in range(self.num_spk):
                        masks["spk{}".format(spk + 1)] = masks_b[spk]
                    if len(masks_b) > self.num_spk:
                        masks["noise1"] = masks_b[self.num_spk]

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [
                torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced
            ]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag],
                                   dim=-1).float()
        return enhanced, flens, masks

    def forward_rawwav(self, input: torch.Tensor, ilens: torch.Tensor):
        """Output with wavformes.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            predicted speech waveforms (single-channel):
                torch.Tensor(Batch, Nsamples), or List[torch.Tensor(Batch, Nsamples)]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """

        # predict spectrum for each speaker
        predicted_spectrums, flens, masks = self.forward(input, ilens)

        if predicted_spectrums is None:
            predicted_wavs = None
        elif isinstance(predicted_spectrums, list):
            # multi-speaker output
            predicted_wavs = [
                self.stft.inverse(ps, ilens)[0] for ps in predicted_spectrums
            ]
        else:
            # single-speaker output
            predicted_wavs = self.stft.inverse(predicted_spectrums, ilens)[0]

        return predicted_wavs, ilens, masks
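
A minimal end-to-end sketch for forward_rawwav (not from the source; the 4-channel mixture and all names are illustrative):

import torch

B, Nsample, C = 2, 16000, 4          # batch, samples, microphone channels
mixture = torch.randn(B, Nsample, C)
ilens = torch.LongTensor([Nsample, Nsample - 160])

net = BeamformerNet(num_spk=1, use_beamformer=True)
# In training mode with loss_type="mask_mse", forward returns only masks,
# so forward_rawwav would yield None; use eval mode for waveforms.
net.eval()
wavs, olens, masks = net.forward_rawwav(mixture, ilens)
# wavs: (B, Nsample) enhanced waveform (a list of them when num_spk > 1)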