Example #1
        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device)
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws
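
For context, `get_mvdr_vector` computes the MVDR beamforming weights (Souden formulation) from the speech and noise power spectral density matrices and the reference vector `u`. Below is a minimal sketch of that formula using native complex torch tensors instead of the `ComplexTensor` type in the examples; the helper name and the `eps` regularizer are illustrative assumptions, not the library's actual implementation:

    import torch

    def mvdr_vector_sketch(psd_s, psd_n, u, eps=1e-7):
        # psd_s, psd_n: (B, F, C, C) complex; u: (B, C) real one-hot or soft weights
        # Souden MVDR: w(f) = (Phi_n^{-1} Phi_s) / trace(Phi_n^{-1} Phi_s) @ u
        numerator = torch.linalg.solve(psd_n, psd_s)            # (B, F, C, C)
        trace = numerator.diagonal(dim1=-2, dim2=-1).sum(-1)    # (B, F)
        ws = numerator / (trace[..., None, None] + eps)
        # mix the columns according to the reference vector: (B, F, C)
        return torch.einsum("bfec,bc->bfe", ws, u.to(ws.dtype))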
Example #2
    def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
            -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)

        """
        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        (mask_speech, mask_noise), _ = self.mask(data, ilens)

        psd_speech = get_power_spectral_density_matrix(data, mask_speech)
        psd_noise = get_power_spectral_density_matrix(data, mask_noise)

        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech, ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)

        ws = get_mvdr_vector(psd_speech, psd_noise, u)
        enhanced = apply_beamforming_vector(ws, data)

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
        mask_speech = mask_speech.transpose(-1, -3)

        return enhanced, ilens, mask_speech
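
A hedged usage sketch for this forward pass, assuming the class above has already been instantiated as `beamformer` (that name, and the STFT sizes below, are illustrative assumptions):

    import torch
    from torch_complex.tensor import ComplexTensor

    B, T, C, F = 4, 200, 6, 257                    # batch, frames, channels, freq bins
    data = ComplexTensor(torch.randn(B, T, C, F),  # multichannel STFT, (B, T, C, F)
                         torch.randn(B, T, C, F))
    ilens = torch.full((B,), T, dtype=torch.long)  # every utterance uses all T frames

    enhanced, olens, mask_speech = beamformer(data, ilens)
    # enhanced: (B, T, F) single-channel STFT
    # mask_speech: (B, T, C, F) after the final transpose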
Example #3
        def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.float(), ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device)
                u[..., self.ref_channel].fill_(1)

            if beamformer_type in ("mpdr", "mvdr"):
                ws = get_mvdr_vector(psd_speech, psd_n, u.double())
                enhanced = apply_beamforming_vector(ws, data)
            elif beamformer_type == "wpd":
                ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
                enhanced = perform_WPD_filtering(ws, data, self.bdelay,
                                                 self.btaps)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    beamformer_type))

            return enhanced, ws
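
For reference, `apply_beamforming_vector`, used by the MPDR/MVDR branch above and in the first two examples, reduces the channel axis with the conjugated complex weights. A minimal sketch with native complex torch tensors (the function name is an assumption; the shapes follow the comments in the examples):

    import torch

    def apply_beamforming_vector_sketch(ws, data):
        # ws: (B, F, C) complex weights; data: (B, F, C, T) complex spectrum
        # enhanced[b, f, t] = sum_c conj(ws[b, f, c]) * data[b, f, c, t]
        return torch.einsum("bfc,bfct->bft", ws.conj(), data)

The WPD branch differs in that its filter also spans delayed frames, which is why `perform_WPD_filtering` additionally takes `self.bdelay` and `self.btaps` rather than doing a purely per-frame reduction.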