示例#1
0
    def apply_beamforming(
        self,
        data,
        ilens,
        psd_n,
        psd_speech,
        psd_distortion=None,
        rtf_mat=None,
        spk=0,
    ):
        """Estimate a beamforming filter from the statistics and apply it.

        The filter family is selected by ``self.beamformer_type``. Every
        statistic is passed through ``to_double`` before estimation, and both
        outputs are cast back to ``data.dtype`` on return.

        Args:
            data (torch.complex64/ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (torch.complex64/ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                For "mvdr_tfs" / "mvdr_tfs_souden" this must be a list or
                tuple; one filter is estimated per entry.
            psd_speech (torch.complex64/ComplexTensor):
                Speech covariance matrix (B, F, C, C)
            psd_distortion (torch.complex64/ComplexTensor):
                Noise covariance matrix (B, F, C, C); only read by the
                RTF-based "mvdr" / "mpdr" / "wmpdr" / "mvdr_tfs" / "wpd" types.
            rtf_mat (torch.complex64/ComplexTensor):
                RTF matrix (B, F, C, num_spk); only read by the
                "lcmp" / "wlcmp" / "lcmv" types.
            spk (int): speaker index, forwarded as ``reference_vector`` to the
                LCMV estimator; only read by the "lcmp" / "wlcmp" / "lcmv" types.
        Returns:
            enhanced (torch.complex64/ComplexTensor): (B, F, T)
            ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
                For the "*_tfs*" types the per-entry filters are stacked along
                dim 0.
        Raises:
            ValueError: if ``self.beamformer_type`` is not a supported type.
        """

        # The closures below read attributes lazily, so a branch only touches
        # the settings it actually needs.
        def _loading_opts():
            # Diagonal-loading settings accepted by every estimator.
            return {
                "diagonal_loading": self.diagonal_loading,
                "diag_eps": self.diag_eps,
            }

        def _solver_opts():
            # Loading settings plus the solver switch.
            return {"use_torch_solver": self.use_torch_solver, **_loading_opts()}

        def _rtf_opts():
            # Keyword set shared by the RTF-based estimators.
            return {
                "iterations": self.rtf_iterations,
                "reference_vector": ref_vec,
                "normalize_ref_channel": self.ref_channel,
                **_solver_opts(),
            }

        def _beamform(weights):
            # Apply one filter to the observation.
            return apply_beamforming_vector(weights, to_double(data))

        def _wpd_filter(weights):
            # Apply a WPD filter using the configured delay and tap count.
            return perform_WPD_filtering(
                weights, to_double(data), self.bdelay, self.btaps
            )

        def _pick_min_magnitude(filters):
            # Beamform with every candidate filter, then keep for each element
            # the candidate output with the smallest magnitude. `stack` is
            # called without `dim` here -- presumably dim 0; confirm.
            candidates = stack([_beamform(w) for w in filters])
            with torch.no_grad():
                # The index selection itself is not tracked by autograd.
                index = candidates.abs().argmin(dim=0, keepdims=True)
            return candidates.gather(0, index).squeeze(0), stack(filters, dim=0)

        # Reference: a (B, C) vector, or the plain channel index.
        if self.ref_channel < 0:
            ref_vec, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            ref_vec = ref_vec.double()
        elif self.beamformer_type.endswith("_souden"):
            # (optional) one-hot vector for the fixed reference microphone
            onehot_shape = data.size()[:-3] + (data.size(-2),)
            ref_vec = torch.zeros(
                *onehot_shape, device=data.device, dtype=torch.double
            )
            ref_vec[..., self.ref_channel].fill_(1)
        else:
            # The bare index simplifies computation in RTF-based beamforming.
            ref_vec = self.ref_channel

        bf_type = self.beamformer_type
        if bf_type in ("mvdr", "mpdr", "wmpdr"):
            bf_weights = get_mvdr_vector_with_rtf(
                to_double(psd_n),
                to_double(psd_speech),
                to_double(psd_distortion),
                **_rtf_opts(),
            )
            enhanced = _beamform(bf_weights)
        elif bf_type == "mvdr_tfs":
            assert isinstance(psd_n, (list, tuple))
            enhanced, bf_weights = _pick_min_magnitude(
                [
                    get_mvdr_vector_with_rtf(
                        to_double(noise_psd),
                        to_double(psd_speech),
                        to_double(psd_distortion),
                        **_rtf_opts(),
                    )
                    for noise_psd in psd_n
                ]
            )
        elif bf_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
            bf_weights = get_mvdr_vector(
                to_double(psd_speech),
                to_double(psd_n),
                ref_vec,
                **_solver_opts(),
            )
            enhanced = _beamform(bf_weights)
        elif bf_type == "mvdr_tfs_souden":
            assert isinstance(psd_n, (list, tuple))
            enhanced, bf_weights = _pick_min_magnitude(
                [
                    get_mvdr_vector(
                        to_double(psd_speech),
                        to_double(noise_psd),
                        ref_vec,
                        **_solver_opts(),
                    )
                    for noise_psd in psd_n
                ]
            )
        elif bf_type == "wpd":
            bf_weights = get_WPD_filter_with_rtf(
                to_double(psd_n),
                to_double(psd_speech),
                to_double(psd_distortion),
                **_rtf_opts(),
            )
            enhanced = _wpd_filter(bf_weights)
        elif bf_type == "wpd_souden":
            bf_weights = get_WPD_filter_v2(
                to_double(psd_speech),
                to_double(psd_n),
                ref_vec,
                **_loading_opts(),
            )
            enhanced = _wpd_filter(bf_weights)
        elif bf_type in ("mwf", "wmwf"):
            bf_weights = get_mwf_vector(
                to_double(psd_speech),
                to_double(psd_n),
                ref_vec,
                **_solver_opts(),
            )
            enhanced = _beamform(bf_weights)
        elif bf_type == "sdw_mwf":
            bf_weights = get_sdw_mwf_vector(
                to_double(psd_speech),
                to_double(psd_n),
                ref_vec,
                denoising_weight=self.mwf_mu,
                **_solver_opts(),
            )
            enhanced = _beamform(bf_weights)
        elif bf_type == "r1mwf":
            bf_weights = get_rank1_mwf_vector(
                to_double(psd_speech),
                to_double(psd_n),
                ref_vec,
                denoising_weight=self.mwf_mu,
                **_solver_opts(),
            )
            enhanced = _beamform(bf_weights)
        elif bf_type in ("lcmp", "wlcmp", "lcmv"):
            bf_weights = get_lcmv_vector_with_rtf(
                to_double(psd_n),
                to_double(rtf_mat),
                reference_vector=spk,
                **_solver_opts(),
            )
            enhanced = _beamform(bf_weights)
        elif bf_type.startswith("gev"):
            bf_weights = get_gev_vector(
                to_double(psd_n),
                to_double(psd_speech),
                mode="power",
                **_loading_opts(),
            )
            enhanced = _beamform(bf_weights)
            if bf_type == "gev_ban":
                # Rescale the output with the blind analytic normalization gain.
                gain = blind_analytic_normalization(bf_weights, to_double(psd_n))
                enhanced = enhanced * gain.unsqueeze(-1)
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        out_dtype = data.dtype
        return enhanced.to(dtype=out_dtype), bf_weights.to(dtype=out_dtype)
示例#2
0
        def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
            """Estimate a beamforming filter from the statistics and apply it.

            The filter family is selected by ``self.beamformer_type`` (``self``
            comes from the enclosing scope). Every statistic is passed through
            ``to_double`` before estimation, and both outputs are cast back to
            ``data.dtype`` on return.

            Args:
                data (torch.complex64/ComplexTensor): (B, F, C, T)
                ilens (torch.Tensor): (B,)
                psd_n (torch.complex64/ComplexTensor):
                    Noise covariance matrix for MVDR (B, F, C, C)
                    Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                    Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                psd_speech (torch.complex64/ComplexTensor):
                    Speech covariance matrix (B, F, C, C)
                psd_distortion (torch.complex64/ComplexTensor):
                    Noise covariance matrix (B, F, C, C); only read by the
                    RTF-based "mvdr" / "mpdr" / "wmpdr" / "wpd" types.
            Returns:
                enhanced (torch.complex64/ComplexTensor): (B, F, T)
                ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
            Raises:
                ValueError: if ``self.beamformer_type`` is not a supported type.
            """
            # Reference: a (B, C) vector, or the plain channel index.
            if self.ref_channel < 0:
                reference, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
                reference = reference.double()
            elif self.beamformer_type.endswith("_souden"):
                # (optional) one-hot vector for the fixed reference microphone
                lead_dims = data.size()[:-3]
                reference = torch.zeros(
                    *(lead_dims + (data.size(-2),)),
                    device=data.device,
                    dtype=torch.double
                )
                reference[..., self.ref_channel].fill_(1)
            else:
                # The bare index simplifies computation in RTF-based beamforming.
                reference = self.ref_channel

            kind = self.beamformer_type

            # Stage 1: estimate the filter coefficients.
            if kind in ("mvdr", "mpdr", "wmpdr", "wpd"):
                # The RTF-based MVDR and WPD estimators take identical arguments.
                if kind == "wpd":
                    estimate = get_WPD_filter_with_rtf
                else:
                    estimate = get_mvdr_vector_with_rtf
                ws = estimate(
                    to_double(psd_n),
                    to_double(psd_speech),
                    to_double(psd_distortion),
                    iterations=self.rtf_iterations,
                    reference_vector=reference,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
            elif kind in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
                ws = get_mvdr_vector(
                    to_double(psd_speech),
                    to_double(psd_n),
                    reference,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
            elif kind == "wpd_souden":
                ws = get_WPD_filter_v2(
                    to_double(psd_speech),
                    to_double(psd_n),
                    reference,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            # Stage 2: apply the filter. The WPD variants go through the
            # filtering routine that takes the delay and tap count.
            if kind in ("wpd", "wpd_souden"):
                enhanced = perform_WPD_filtering(
                    ws, to_double(data), self.bdelay, self.btaps
                )
            else:
                enhanced = apply_beamforming_vector(ws, to_double(data))

            out_dtype = data.dtype
            return enhanced.to(dtype=out_dtype), ws.to(dtype=out_dtype)