Beispiel #1
0
    def apply_beamforming(
        self,
        data,
        ilens,
        psd_n,
        psd_speech,
        psd_distortion=None,
        rtf_mat=None,
        spk=0,
    ):
        """Estimate beamforming weights from given statistics and filter ``data``.

        The filter family is chosen by ``self.beamformer_type``. Every statistic
        is passed through ``to_double`` before it reaches a solver, and both
        results are cast back to ``data.dtype`` on return.

        Args:
            data (torch.complex64/ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (torch.complex64/ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                For "mvdr_tfs" / "mvdr_tfs_souden" this must be a list or tuple
                of such matrices, one candidate filter is built per entry.
            psd_speech (torch.complex64/ComplexTensor):
                Speech covariance matrix (B, F, C, C)
            psd_distortion (torch.complex64/ComplexTensor):
                Noise covariance matrix (B, F, C, C)
            rtf_mat (torch.complex64/ComplexTensor):
                RTF matrix (B, F, C, num_spk), only read by LCMV-style types
            spk (int): speaker index, only read by LCMV-style types
        Return:
            enhanced (torch.complex64/ComplexTensor): (B, F, T)
            ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
        Raises:
            ValueError: if ``self.beamformer_type`` is not handled here.
        """
        bf_type = self.beamformer_type

        # Reference, u: (B, C) vector, or a plain channel index.
        if self.ref_channel < 0:
            # negative index: let the reference estimator pick the channel
            ref, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            ref = ref.double()
        elif bf_type.endswith("_souden"):
            # Souden-style solvers get a one-hot vector for the fixed microphone
            onehot_shape = data.size()[:-3] + (data.size(-2),)
            ref = torch.zeros(*onehot_shape, device=data.device, dtype=torch.double)
            ref[..., self.ref_channel].fill_(1)
        else:
            # RTF-based solvers take the channel index directly
            ref = self.ref_channel

        def loading_opts():
            # regularisation keywords accepted by every solver
            return {
                "diagonal_loading": self.diagonal_loading,
                "diag_eps": self.diag_eps,
            }

        def solver_opts():
            # regularisation plus the solver switch
            return dict(loading_opts(), use_torch_solver=self.use_torch_solver)

        def rtf_opts():
            # keywords of the RTF-based solvers
            return dict(
                solver_opts(),
                iterations=self.rtf_iterations,
                reference_vector=ref,
                normalize_ref_channel=self.ref_channel,
            )

        def beamform(weights):
            # apply a (non-convolutional) filter to the observation
            return apply_beamforming_vector(weights, to_double(data))

        def wpd_filter(weights):
            # apply a convolutional WPD filter to the observation
            return perform_WPD_filtering(
                weights, to_double(data), self.bdelay, self.btaps
            )

        def pick_min_magnitude(candidates):
            # Filter with every candidate and keep, per element, the output
            # with the smallest magnitude; index selection carries no gradient.
            outputs = stack([beamform(w) for w in candidates])
            with torch.no_grad():
                index = outputs.abs().argmin(dim=0, keepdims=True)
            return outputs.gather(0, index).squeeze(0), stack(candidates, dim=0)

        if bf_type in ("mvdr", "mpdr", "wmpdr"):
            weights = get_mvdr_vector_with_rtf(
                to_double(psd_n),
                to_double(psd_speech),
                to_double(psd_distortion),
                **rtf_opts()
            )
            output = beamform(weights)
        elif bf_type == "mvdr_tfs":
            assert isinstance(psd_n, (list, tuple))
            candidates = [
                get_mvdr_vector_with_rtf(
                    to_double(noise),
                    to_double(psd_speech),
                    to_double(psd_distortion),
                    **rtf_opts()
                )
                for noise in psd_n
            ]
            output, weights = pick_min_magnitude(candidates)
        elif bf_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
            weights = get_mvdr_vector(
                to_double(psd_speech), to_double(psd_n), ref, **solver_opts()
            )
            output = beamform(weights)
        elif bf_type == "mvdr_tfs_souden":
            assert isinstance(psd_n, (list, tuple))
            candidates = [
                get_mvdr_vector(
                    to_double(psd_speech), to_double(noise), ref, **solver_opts()
                )
                for noise in psd_n
            ]
            output, weights = pick_min_magnitude(candidates)
        elif bf_type == "wpd":
            weights = get_WPD_filter_with_rtf(
                to_double(psd_n),
                to_double(psd_speech),
                to_double(psd_distortion),
                **rtf_opts()
            )
            output = wpd_filter(weights)
        elif bf_type == "wpd_souden":
            weights = get_WPD_filter_v2(
                to_double(psd_speech), to_double(psd_n), ref, **loading_opts()
            )
            output = wpd_filter(weights)
        elif bf_type in ("mwf", "wmwf"):
            weights = get_mwf_vector(
                to_double(psd_speech), to_double(psd_n), ref, **solver_opts()
            )
            output = beamform(weights)
        elif bf_type == "sdw_mwf":
            weights = get_sdw_mwf_vector(
                to_double(psd_speech),
                to_double(psd_n),
                ref,
                denoising_weight=self.mwf_mu,
                **solver_opts()
            )
            output = beamform(weights)
        elif bf_type == "r1mwf":
            weights = get_rank1_mwf_vector(
                to_double(psd_speech),
                to_double(psd_n),
                ref,
                denoising_weight=self.mwf_mu,
                **solver_opts()
            )
            output = beamform(weights)
        elif bf_type in ("lcmp", "wlcmp", "lcmv"):
            # the speaker index plays the role of the reference here
            weights = get_lcmv_vector_with_rtf(
                to_double(psd_n),
                to_double(rtf_mat),
                reference_vector=spk,
                **solver_opts()
            )
            output = beamform(weights)
        elif bf_type.startswith("gev"):
            weights = get_gev_vector(
                to_double(psd_n),
                to_double(psd_speech),
                mode="power",
                **loading_opts()
            )
            output = beamform(weights)
            if bf_type == "gev_ban":
                # rescale the GEV output with the blind analytic normalization gain
                gain = blind_analytic_normalization(weights, to_double(psd_n))
                output = output * gain.unsqueeze(-1)
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        return output.to(dtype=data.dtype), weights.to(dtype=data.dtype)
Beispiel #2
0
    def forward(
        self,
        data: Union[torch.Tensor, ComplexTensor],
        ilens: torch.LongTensor,
        powers: Optional[List[torch.Tensor]] = None,
        oracle_masks: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[
        Union[torch.Tensor, ComplexTensor, List[Union[torch.Tensor, ComplexTensor]]],
        torch.LongTensor,
        List[torch.Tensor],
    ]:
        """DNN_Beamformer forward function.

        Masks are estimated with ``self.mask`` (or taken from ``oracle_masks``),
        turned into beamformer statistics, and one filter is applied per speaker.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (torch.complex64/ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
            oracle_masks (List[torch.Tensor] or None): oracle masks (B, F, C, T)
                if not None, oracle_masks will be used instead of self.mask
        Returns:
            enhanced (torch.complex64/ComplexTensor): (B, T, F) when
                ``self.num_spk == 1``, otherwise a list with one such tensor
                per speaker
            ilens (torch.Tensor): (B,), returned unchanged
            masks (List[torch.Tensor]): list of masks, each (B, T, C, F)
        Raises:
            NotImplementedError: LCMV-style beamformers with a single speaker.
            ValueError: ``self.beamformer_type`` is not handled for the
                current number of speakers.
        """
        bf_type = self.beamformer_type

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        data_d = to_double(data)

        # mask: [(B, F, C, T)]
        if oracle_masks is not None:
            masks = oracle_masks
        else:
            masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)
        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,): everything that is not speech counts as noise
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            if bf_type in ("lcmv", "lcmp", "wlcmp"):
                raise NotImplementedError("Single source is not supported yet")
            stats = prepare_beamformer_stats(
                data_d,
                [mask_speech],
                mask_noise,
                powers=powers,
                beamformer_type=bf_type,
                bdelay=self.bdelay,
                btaps=self.btaps,
                eps=self.eps,
            )

            if bf_type in ("mvdr", "mpdr", "wmpdr", "wpd"):
                # RTF-based filters additionally need the distortion covariance
                extra = {"psd_distortion": stats["psd_distortion"]}
            elif (
                bf_type.endswith("_souden")
                or bf_type in ("mwf", "wmwf", "sdw_mwf", "r1mwf")
                or bf_type.startswith("gev")
            ):
                extra = {}
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            # the filter weights are not part of the return value
            enhanced, _ = self.apply_beamforming(
                data, ilens, stats["psd_n"], stats["psd_speech"], **extra
            )

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            stats = prepare_beamformer_stats(
                data_d,
                mask_speech,
                mask_noise,
                powers=powers,
                beamformer_type=bf_type,
                bdelay=self.bdelay,
                btaps=self.btaps,
                eps=self.eps,
            )
            if bf_type in ("lcmv", "lcmp", "wlcmp"):
                rtf_mat = get_rtf_matrix(
                    stats["psd_speech"],
                    stats["psd_distortion"],
                    diagonal_loading=self.diagonal_loading,
                    ref_channel=self.ref_channel,
                    rtf_iterations=self.rtf_iterations,
                    use_torch_solver=self.use_torch_solver,
                    diag_eps=self.diag_eps,
                )

            enhanced = []
            for i in range(self.num_spk):
                # Decide, per beamformer type, whether "psd_n" is one shared
                # statistic or a per-speaker list, and which optional
                # arguments the filter needs.
                if bf_type in ("mvdr", "mvdr_tfs", "wmpdr", "wpd"):
                    shared_psd_n = False
                    extra = {"psd_distortion": stats["psd_distortion"][i]}
                elif bf_type in (
                    "mvdr_souden",
                    "mvdr_tfs_souden",
                    "wmpdr_souden",
                    "wpd_souden",
                    "wmwf",
                    "sdw_mwf",
                    "r1mwf",
                    "gev",
                    "gev_ban",
                ):
                    shared_psd_n = False
                    extra = {}
                elif bf_type == "mpdr":
                    shared_psd_n = True
                    extra = {"psd_distortion": stats["psd_distortion"][i]}
                elif bf_type in ("mpdr_souden", "mwf"):
                    shared_psd_n = True
                    extra = {}
                elif bf_type == "lcmp":
                    shared_psd_n = True
                    extra = {"rtf_mat": rtf_mat, "spk": i}
                elif bf_type in ("lcmv", "wlcmp"):
                    shared_psd_n = False
                    extra = {"rtf_mat": rtf_mat, "spk": i}
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(self.beamformer_type)
                    )

                psd_n = stats["psd_n"] if shared_psd_n else stats["psd_n"][i]
                # the filter weights are not part of the return value
                enh, _ = self.apply_beamforming(
                    data, ilens, psd_n, stats["psd_speech"][i], **extra
                )

                # (..., F, T) -> (..., T, F)
                enhanced.append(enh.transpose(-1, -2))

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
Beispiel #3
0
    def forward(
        self, data: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor
    ) -> Tuple[
        Union[torch.Tensor, ComplexTensor, list],
        torch.LongTensor,
        Union[torch.Tensor, list, None],
        Union[list, None],
    ]:
        """DNN_WPE forward function.

        Runs ``self.iterations`` rounds of WPE; on the first round the power
        estimate is optionally weighted by masks from ``self.mask_est``.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)
        Returns:
            enhanced (torch.Tensor or List[torch.Tensor]): (B, T, C, F);
                a single tensor when ``self.nmask == 1``, otherwise a list
            ilens: (B,), returned unchanged
            masks (torch.Tensor or List[torch.Tensor] or None): (B, T, C, F);
                None when no mask was estimated, a single tensor when
                ``self.nmask == 1``, otherwise a list
            power (List[torch.Tensor] or None): (B, F, T), the power used in
                the last iteration; None when no iteration was run
        """
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        enhanced = [data for _ in range(self.nmask)]
        masks = None
        power = None

        # NOTE(kamo): Calculate in double precision
        # The observation never changes inside the loop, so it is converted
        # once and shared by every WPE call (skipped when nothing will run).
        data_d = to_double(data.contiguous()) if self.iterations > 0 else None

        for it in range(self.iterations):
            # Calculate power: (..., C, T)
            power = [enh.real**2 + enh.imag**2 for enh in enhanced]
            if it == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                masks, _ = self.mask_est(data, ilens)
                # floor masks to increase numerical stability
                if self.mask_flooring:
                    masks = [m.clamp(min=self.flooring_thres) for m in masks]
                if self.normalization:
                    # Normalize along T
                    masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = [p * masks[k] for k, p in enumerate(power)]

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = [p.mean(dim=-2).clamp(min=self.eps) for p in power]

            # enhanced: (..., C, T) -> (..., C, T)
            enhanced = [
                wpe_one_iteration(
                    data_d,
                    to_double(p),
                    taps=self.taps,
                    delay=self.delay,
                    inverse_power=self.inverse_power,
                )
                for p in power
            ]
            # back to the input dtype, with padded frames zeroed
            enhanced = [
                enh.to(dtype=data.dtype).masked_fill(make_pad_mask(ilens, enh.real), 0)
                for enh in enhanced
            ]

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced]
        if masks is not None:
            if self.nmask > 1:
                masks = [m.transpose(-1, -3) for m in masks]
            else:
                masks = masks[0].transpose(-1, -3)
        if self.nmask == 1:
            enhanced = enhanced[0]

        return enhanced, ilens, masks, power
Beispiel #4
0
        def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
            """Estimate beamforming weights from given statistics and filter ``data``.

            The filter family is chosen by ``self.beamformer_type`` (``self``
            comes from the enclosing scope). Statistics are passed through
            ``to_double`` before solving; results are cast back to ``data.dtype``.

            Args:
                data (torch.complex64/ComplexTensor): (B, F, C, T)
                ilens (torch.Tensor): (B,)
                psd_n (torch.complex64/ComplexTensor):
                    Noise covariance matrix for MVDR (B, F, C, C)
                    Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                    Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                psd_speech (torch.complex64/ComplexTensor):
                    Speech covariance matrix (B, F, C, C)
                psd_distortion (torch.complex64/ComplexTensor):
                    Noise covariance matrix (B, F, C, C)
            Return:
                enhanced (torch.complex64/ComplexTensor): (B, F, T)
                ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
            Raises:
                ValueError: if ``self.beamformer_type`` is not handled here.
            """
            bf_type = self.beamformer_type

            # Reference, u: (B, C) vector, or a plain channel index.
            if self.ref_channel < 0:
                # negative index: let the reference estimator pick the channel
                ref, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
                ref = ref.double()
            elif bf_type.endswith("_souden"):
                # Souden-style solvers get a one-hot vector for the fixed microphone
                onehot_shape = data.size()[:-3] + (data.size(-2),)
                ref = torch.zeros(
                    *onehot_shape, device=data.device, dtype=torch.double
                )
                ref[..., self.ref_channel].fill_(1)
            else:
                # RTF-based solvers take the channel index directly
                ref = self.ref_channel

            def loading_opts():
                # regularisation keywords accepted by every solver
                return {
                    "diagonal_loading": self.diagonal_loading,
                    "diag_eps": self.diag_eps,
                }

            def rtf_opts():
                # keywords of the RTF-based solvers
                return dict(
                    loading_opts(),
                    iterations=self.rtf_iterations,
                    reference_vector=ref,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                )

            def wpd_filter(weights):
                # apply a convolutional WPD filter to the observation
                return perform_WPD_filtering(
                    weights, to_double(data), self.bdelay, self.btaps
                )

            if bf_type in ("mvdr", "mpdr", "wmpdr"):
                weights = get_mvdr_vector_with_rtf(
                    to_double(psd_n),
                    to_double(psd_speech),
                    to_double(psd_distortion),
                    **rtf_opts()
                )
                output = apply_beamforming_vector(weights, to_double(data))
            elif bf_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
                weights = get_mvdr_vector(
                    to_double(psd_speech),
                    to_double(psd_n),
                    ref,
                    use_torch_solver=self.use_torch_solver,
                    **loading_opts()
                )
                output = apply_beamforming_vector(weights, to_double(data))
            elif bf_type == "wpd":
                weights = get_WPD_filter_with_rtf(
                    to_double(psd_n),
                    to_double(psd_speech),
                    to_double(psd_distortion),
                    **rtf_opts()
                )
                output = wpd_filter(weights)
            elif bf_type == "wpd_souden":
                weights = get_WPD_filter_v2(
                    to_double(psd_speech), to_double(psd_n), ref, **loading_opts()
                )
                output = wpd_filter(weights)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            return output.to(dtype=data.dtype), weights.to(dtype=data.dtype)
Beispiel #5
0
    def forward(
        self,
        data: Union[torch.Tensor, ComplexTensor],
        ilens: torch.LongTensor,
        powers: Union[List[torch.Tensor], None] = None,
    ) -> Tuple[Union[torch.Tensor, ComplexTensor], torch.LongTensor, torch.Tensor]:
        """DNN_Beamformer forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (torch.complex64/ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
        Returns:
            enhanced (torch.complex64/ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): (B, T, C, F)
        """

        def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
            """Estimate beamforming weights from given statistics and filter ``data``.

            The filter family is chosen by ``self.beamformer_type`` (``self``
            comes from the enclosing method). Statistics are passed through
            ``to_double`` before solving; results are cast back to ``data.dtype``.

            Args:
                data (torch.complex64/ComplexTensor): (B, F, C, T)
                ilens (torch.Tensor): (B,)
                psd_n (torch.complex64/ComplexTensor):
                    Noise covariance matrix for MVDR (B, F, C, C)
                    Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                    Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                psd_speech (torch.complex64/ComplexTensor):
                    Speech covariance matrix (B, F, C, C)
                psd_distortion (torch.complex64/ComplexTensor):
                    Noise covariance matrix (B, F, C, C)
            Return:
                enhanced (torch.complex64/ComplexTensor): (B, F, T)
                ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
            Raises:
                ValueError: if ``self.beamformer_type`` is not handled here.
            """
            bf_type = self.beamformer_type

            # Reference, u: (B, C) vector, or a plain channel index.
            if self.ref_channel < 0:
                # negative index: let the reference estimator pick the channel
                ref, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
                ref = ref.double()
            elif bf_type.endswith("_souden"):
                # Souden-style solvers get a one-hot vector for the fixed microphone
                onehot_shape = data.size()[:-3] + (data.size(-2),)
                ref = torch.zeros(
                    *onehot_shape, device=data.device, dtype=torch.double
                )
                ref[..., self.ref_channel].fill_(1)
            else:
                # RTF-based solvers take the channel index directly
                ref = self.ref_channel

            def loading_opts():
                # regularisation keywords accepted by every solver
                return {
                    "diagonal_loading": self.diagonal_loading,
                    "diag_eps": self.diag_eps,
                }

            def rtf_opts():
                # keywords of the RTF-based solvers
                return dict(
                    loading_opts(),
                    iterations=self.rtf_iterations,
                    reference_vector=ref,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                )

            def wpd_filter(weights):
                # apply a convolutional WPD filter to the observation
                return perform_WPD_filtering(
                    weights, to_double(data), self.bdelay, self.btaps
                )

            if bf_type in ("mvdr", "mpdr", "wmpdr"):
                weights = get_mvdr_vector_with_rtf(
                    to_double(psd_n),
                    to_double(psd_speech),
                    to_double(psd_distortion),
                    **rtf_opts()
                )
                output = apply_beamforming_vector(weights, to_double(data))
            elif bf_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
                weights = get_mvdr_vector(
                    to_double(psd_speech),
                    to_double(psd_n),
                    ref,
                    use_torch_solver=self.use_torch_solver,
                    **loading_opts()
                )
                output = apply_beamforming_vector(weights, to_double(data))
            elif bf_type == "wpd":
                weights = get_WPD_filter_with_rtf(
                    to_double(psd_n),
                    to_double(psd_speech),
                    to_double(psd_distortion),
                    **rtf_opts()
                )
                output = wpd_filter(weights)
            elif bf_type == "wpd_souden":
                weights = get_WPD_filter_v2(
                    to_double(psd_speech), to_double(psd_n), ref, **loading_opts()
                )
                output = wpd_filter(weights)
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            return output.to(dtype=data.dtype), weights.to(dtype=data.dtype)

        if isinstance(data, ComplexTensor):
            complex_wrapper = FC
        elif is_torch_1_9_plus and torch.is_complex(data):
            complex_wrapper = torch
        else:
            raise ValueError(
                "Please update your PyTorch version to 1.8+ for compelx support."
            )
        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        data_d = to_double(data)

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)
        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            if self.beamformer_type.startswith(
                "wmpdr"
            ) or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real ** 2 + data_d.imag ** 2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = (power_input * mask_speech.double()).mean(dim=-2)
                else:
                    assert len(powers) == 1, len(powers)
                    powers = powers[0]
                inverse_power = 1 / torch.clamp(powers, min=self.eps)

            psd_speech = get_power_spectral_density_matrix(data_d, mask_speech.double())
            if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")
            ):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double()
                )
            if self.beamformer_type == "mvdr":
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_noise, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "mvdr_souden":
                enhanced, ws = apply_beamforming(data, ilens, psd_noise, psd_speech)
            elif self.beamformer_type == "mpdr":
                psd_observed = complex_wrapper.einsum(
                    "...ct,...et->...ce", [data_d, data_d.conj()]
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "mpdr_souden":
                psd_observed = complex_wrapper.einsum(
                    "...ct,...et->...ce", [data_d, data_d.conj()]
                )
                enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
            elif self.beamformer_type == "wmpdr":
                psd_observed = complex_wrapper.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :], data_d.conj()],
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "wmpdr_souden":
                psd_observed = complex_wrapper.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :], data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
            elif self.beamformer_type == "wpd":
                psd_observed_bar = get_covariances(
                    data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed_bar, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "wpd_souden":
                psd_observed_bar = get_covariances(
                    data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed_bar, psd_speech
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            if self.beamformer_type.startswith(
                "wmpdr"
            ) or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real ** 2 + data_d.imag ** 2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = [
                        (power_input * m.double()).mean(dim=-2) for m in mask_speech
                    ]
                else:
                    assert len(powers) == self.num_spk, len(powers)
                inverse_power = [1 / torch.clamp(p, min=self.eps) for p in powers]

            psd_speeches = [
                get_power_spectral_density_matrix(data_d, mask.double())
                for mask in mask_speech
            ]
            if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")
            ):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double()
                )
            if self.beamformer_type in ("mpdr", "mpdr_souden"):
                psd_observed = complex_wrapper.einsum(
                    "...ct,...et->...ce", [data_d, data_d.conj()]
                )
            elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
                psd_observed = [
                    complex_wrapper.einsum(
                        "...ct,...et->...ce",
                        [data_d * inv_p[..., None, :], data_d.conj()],
                    )
                    for inv_p in inverse_power
                ]
            elif self.beamformer_type in ("wpd", "wpd_souden"):
                psd_observed_bar = [
                    get_covariances(
                        data_d, inv_p, self.bdelay, self.btaps, get_vector=False
                    )
                    for inv_p in inverse_power
                ]

            enhanced, ws = [], []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                if (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")
                ):
                    psd_noise_i = (
                        psd_noise + sum(psd_speeches)
                        if mask_noise is not None
                        else sum(psd_speeches)
                    )
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type == "mvdr":
                    enh, w = apply_beamforming(
                        data, ilens, psd_noise_i, psd_speech, psd_distortion=psd_noise_i
                    )
                elif self.beamformer_type == "mvdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_noise_i, psd_speech)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed,
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "mpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed, psd_speech)
                elif self.beamformer_type == "wmpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wmpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed[i], psd_speech)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed_bar[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wpd_souden":
                    enh, w = apply_beamforming(
                        data, ilens, psd_observed_bar[i], psd_speech
                    )
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(self.beamformer_type)
                    )
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)
                ws.append(w)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
Beispiel #6
0
def prepare_beamformer_stats(
    signal,
    masks_speech,
    mask_noise,
    powers=None,
    beamformer_type="mvdr",
    bdelay=3,
    btaps=5,
    eps=1e-6,
):
    """Prepare necessary statistics for constructing the specified beamformer.

    Args:
        signal (torch.complex64/ComplexTensor): (..., F, C, T)
        masks_speech (List[torch.Tensor]): (..., F, C, T) masks for all speech sources
            A single tensor is also accepted and treated as a one-element list.
        mask_noise (torch.Tensor): (..., F, C, T) noise mask
            May be None in the multi-speaker case; it is required whenever the
            noise statistics are computed for a single speaker.
        powers (List[torch.Tensor]): powers for all speech sources (..., F, T)
                                     used for wMPDR or WPD beamformers
            When None, they are estimated from the masked signal power.
        beamformer_type (str): one of the pre-defined beamformer types
        bdelay (int): delay factor, used for WPD beamformers
        btaps (int): number of filter taps, used for WPD beamformers
        eps (float): tiny constant used as the lower bound when clamping the
            powers before they are inverted
    Returns:
        beamformer_stats (dict): a dictionary containing all necessary statistics
            e.g. "psd_n", "psd_speech", "psd_distortion"
            Note:
            * When `masks_speech` is a tensor or a single-element list, all returned
              statistics are tensors;
            * When `masks_speech` is a multi-element list, some returned statistics
              can be a list, e.g., "psd_n" for MVDR, "psd_speech" and "psd_distortion".
    Raises:
        AssertionError: if `beamformer_type` is not in BEAMFORMER_TYPES, if
            `powers` is given but its length differs from the number of speech
            masks, or if the single-speaker noise statistics are needed while
            `mask_noise` is None.

    """
    from espnet2.enh.layers.dnn_beamformer import BEAMFORMER_TYPES

    # Interpolate the offending type into the message so a failure is actionable.
    assert (
        beamformer_type in BEAMFORMER_TYPES
    ), "{} is not supported yet".format(beamformer_type)

    # Normalize `masks_speech` to a list of double-precision masks.
    if isinstance(masks_speech, (list, tuple)):
        masks_speech = [to_double(m) for m in masks_speech]
    else:
        masks_speech = [to_double(masks_speech)]
    num_spk = len(masks_speech)

    # Power-weighted beamformers need the (clamped) inverse source powers.
    if beamformer_type.startswith(("wmpdr", "wpd")) or beamformer_type in (
        "wlcmp",
        "wmwf",
    ):
        if powers is None:
            power_input = signal.real**2 + signal.imag**2
            # Averaging along the channel axis: (..., C, T) -> (..., T)
            powers = [(power_input * m).mean(dim=-2) for m in masks_speech]
        else:
            assert len(powers) == num_spk, (len(powers), num_spk)
        inverse_powers = [1 / torch.clamp(p, min=eps) for p in powers]

    psd_speeches = [
        get_power_spectral_density_matrix(signal, m) for m in masks_speech
    ]

    # Noise statistics are needed by every non-Souden formula plus the
    # MVDR-family Souden variants ("mvdr_souden", "mvdr_tfs_souden").
    needs_noise_psd = (
        not beamformer_type.endswith("_souden")
        or beamformer_type == "mvdr_souden"
        or beamformer_type.startswith("mvdr_tfs")
    )
    if needs_noise_psd:
        # MVDR or other RTF-based formulas
        if mask_noise is not None:
            psd_bg = get_power_spectral_density_matrix(
                signal, to_double(mask_noise)
            )
        if num_spk == 1:
            assert mask_noise is not None
            psd_noise = psd_bg
        else:
            # Treat all other speakers' speech PSDs as noise for speaker i.
            psd_noise = []
            for i in range(num_spk):
                if beamformer_type.startswith("mvdr_tfs"):
                    # NOTE: psd_noise is a list only for this beamformer
                    psd_noise_i = [
                        psd for j, psd in enumerate(psd_speeches) if j != i
                    ]
                else:
                    psd_sum = sum(
                        psd for j, psd in enumerate(psd_speeches) if j != i
                    )
                    psd_noise_i = (
                        psd_bg + psd_sum if mask_noise is not None else psd_sum
                    )
                psd_noise.append(psd_noise_i)

    # Select the matrix that plays the role of "psd_n" for this beamformer.
    if beamformer_type in (
        "mvdr",
        "mvdr_souden",
        "mvdr_tfs_souden",
        "sdw_mwf",
        "r1mwf",
        "lcmv",
        "gev",
        "gev_ban",
    ):
        psd_n = psd_noise
    elif beamformer_type == "mvdr_tfs":
        psd_n = psd_noise
        # Collapse each per-speaker list of interferer PSDs into a single PSD.
        # NOTE(review): with a single speaker `psd_noise` is not a list here, so
        # this comprehension iterates over its leading dimension instead --
        # confirm that "mvdr_tfs" is only used with multiple speakers.
        psd_noise = [sum(psd_noise_i) for psd_noise_i in psd_noise]
    elif beamformer_type in ("mpdr", "mpdr_souden", "lcmp", "mwf"):
        # Observation covariance of the unweighted signal.
        psd_n = einsum("...ct,...et->...ce", signal, signal.conj())
    elif beamformer_type in ("wmpdr", "wmpdr_souden", "wlcmp", "wmwf"):
        # Power-weighted observation covariance, one per speaker.
        psd_n = [
            einsum(
                "...ct,...et->...ce",
                signal * inv_p[..., None, :],
                signal.conj(),
            )
            for inv_p in inverse_powers
        ]
    elif beamformer_type in ("wpd", "wpd_souden"):
        psd_n = [
            get_covariances(signal, inv_p, bdelay, btaps, get_vector=False)
            for inv_p in inverse_powers
        ]

    # Single-speaker case: unwrap the one-element lists.
    if num_spk == 1:
        psd_speeches = psd_speeches[0]
        if isinstance(psd_n, (list, tuple)):
            psd_n = psd_n[0]

    if beamformer_type in (
        "mvdr",
        "mpdr",
        "wmpdr",
        "wpd",
        "lcmp",
        "wlcmp",
        "lcmv",
        "mvdr_tfs",
    ):
        return {
            "psd_n": psd_n,
            "psd_speech": psd_speeches,
            "psd_distortion": psd_noise,
        }
    elif (
        beamformer_type.endswith("_souden")
        or beamformer_type.startswith("gev")
        or beamformer_type in ("mwf", "wmwf", "sdw_mwf", "r1mwf")
    ):
        return {"psd_n": psd_n, "psd_speech": psd_speeches}