Example #1
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

# Assumed import path for the complex-dispatching ops under test:
from espnet2.enh.layers.complex_utils import einsum, inverse, matmul, solve

# `mat_np` and `is_torch_1_9_plus` are module-level fixtures not shown in
# this snippet; a sketch of them follows after the test body.


def test_complex_impl_consistency():
    if not is_torch_1_9_plus:
        # native torch.complex requires torch>=1.9
        return
    mat_th = torch.complex(torch.from_numpy(mat_np.real),
                           torch.from_numpy(mat_np.imag))
    mat_ct = ComplexTensor(torch.from_numpy(mat_np.real),
                           torch.from_numpy(mat_np.imag))
    bs = mat_th.shape[0]
    rank = mat_th.shape[-1]
    vec_th = torch.complex(torch.rand(bs, rank),
                           torch.rand(bs, rank)).type_as(mat_th)
    vec_ct = ComplexTensor(vec_th.real, vec_th.imag)

    for result_th, result_ct in (
        (abs(mat_th), abs(mat_ct)),
        (inverse(mat_th), inverse(mat_ct)),
        (matmul(mat_th, vec_th.unsqueeze(-1)),
         matmul(mat_ct, vec_ct.unsqueeze(-1))),
        (solve(vec_th.unsqueeze(-1), mat_th),
         solve(vec_ct.unsqueeze(-1), mat_ct)),
        (einsum("bec,bc->be", mat_th, vec_th),
         einsum("bec,bc->be", mat_ct, vec_ct)),
    ):
        np.testing.assert_allclose(result_th.numpy(),
                                   result_ct.numpy(),
                                   atol=1e-6)
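
The test above relies on two module-level names the snippet does not show.
A minimal sketch of what they could look like (the shapes, the seed, and the
Hermitian-plus-ridge conditioning are illustrative assumptions, not the
original fixture):

import numpy as np
from packaging.version import parse as V
import torch

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

_rng = np.random.default_rng(0)
_a = _rng.standard_normal((2, 4, 4)) + 1j * _rng.standard_normal((2, 4, 4))
# A @ A^H + eps*I is Hermitian positive definite, hence safely invertible,
# which keeps the inverse/solve comparisons numerically stable
mat_np = (_a @ _a.conj().swapaxes(-1, -2) + 1e-2 * np.eye(4)).astype("complex128")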
Example #2
    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel):
                torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> STFT -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        enhanced = input_spectrum
        masks = OrderedDict()

        if input_spectrum.dim() == 3:
            # single-channel input
            if self.use_wpe:
                # (B, T, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum.unsqueeze(-2), flens)
                enhanced = enhanced.squeeze(-2)
                if mask_w is not None:
                    masks["dereverb"] = mask_w.squeeze(-2)

        elif input_spectrum.dim() == 4:
            # multi-channel input
            # 1. WPE
            if self.use_wpe:
                # (B, T, C, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                if mask_w is not None:
                    masks["dereverb"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                # enhanced: (B, T, C, F) -> (B, T, F)
                enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

        else:
            raise ValueError(
                "Invalid spectrum dimension: {}".format(input_spectrum.shape)
            )

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()
        return enhanced, flens, masks
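
The final stacking step fixes the tensor convention at this module's
boundary: complex spectra travel as float tensors with a trailing real/imag
axis of size 2. A minimal round-trip sketch of that convention (the helper
names are illustrative, not part of the model):

import torch
from torch_complex.tensor import ComplexTensor

def complex_to_float(x: ComplexTensor) -> torch.Tensor:
    # (B, T, F) complex -> (B, T, F, 2) float
    return torch.stack([x.real, x.imag], dim=-1)

def float_to_complex(x: torch.Tensor) -> ComplexTensor:
    # (B, T, F, 2) float -> (B, T, F) complex
    return ComplexTensor(x[..., 0], x[..., 1])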
Example #3
    def forward(
        self, input: ComplexTensor, ilens: torch.Tensor
    ) -> Tuple[List[ComplexTensor], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (ComplexTensor): mixed speech [Batch, Frames, Channel, Freq]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel): List[ComplexTensor]
            output lengths
            other predicted data: OrderedDict[
                'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
                'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input.dim() in (3, 4), input.dim()
        enhanced = input
        others = OrderedDict()

        if (self.training and self.loss_type is not None
                and self.loss_type.startswith("mask")):
            # Only estimate masks during training to save memory
            if self.use_wpe:
                if input.dim() == 3:
                    mask_w, ilens = self.wpe.predict_mask(
                        input.unsqueeze(-2), ilens)
                    mask_w = mask_w.squeeze(-2)
                elif input.dim() == 4:
                    mask_w, ilens = self.wpe.predict_mask(input, ilens)

                if mask_w is not None:
                    # `enhanced` is never a list in this branch, so the
                    # mask itself must be checked for the per-speaker case
                    if isinstance(mask_w, list):
                        # single-source WPE: one mask per speaker
                        for spk in range(self.num_spk):
                            others["mask_dereverb{}".format(spk + 1)] = mask_w[spk]
                    else:
                        # multi-source WPE: one joint mask
                        others["mask_dereverb1"] = mask_w

            if self.use_beamformer and input.dim() == 4:
                others_b, ilens = self.beamformer.predict_mask(input, ilens)
                for spk in range(self.num_spk):
                    others["mask_spk{}".format(spk + 1)] = others_b[spk]
                if len(others_b) > self.num_spk:
                    others["mask_noise1"] = others_b[self.num_spk]

            return None, ilens, others

        else:
            powers = None
            # Performing both mask estimation and enhancement
            if input.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(
                        input.unsqueeze(-2), ilens)
                    if isinstance(enhanced, list):
                        # single-source WPE
                        enhanced = [enh.squeeze(-2) for enh in enhanced]
                        if mask_w is not None:
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk].squeeze(-2)
                    else:
                        # multi-source WPE
                        enhanced = enhanced.squeeze(-2)
                        if mask_w is not None:
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                    if mask_w is not None:
                        if isinstance(enhanced, list):
                            # single-source WPE
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk]
                        else:
                            # multi-source WPE
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)

                # 2. Beamformer
                if self.use_beamformer:
                    if (not self.beamformer.beamformer_type.startswith(
                            ("wmpdr", "wpd"))
                            or not self.shared_power
                            or (self.wpe.nmask == 1 and self.num_spk > 1)):
                        # WPE power estimates are only reusable by wMPDR/WPD
                        # beamformers with power sharing enabled
                        powers = None

                    # enhanced: (B, T, C, F) -> (B, T, F)
                    if isinstance(enhanced, list):
                        # outputs of single-source WPE
                        raise NotImplementedError(
                            "Single-source WPE is not supported with beamformer "
                            "in multi-speaker cases.")
                    else:
                        # output of multi-source WPE
                        enhanced, ilens, others_b = self.beamformer(
                            enhanced, ilens, powers=powers)
                    for spk in range(self.num_spk):
                        others["mask_spk{}".format(spk + 1)] = others_b[spk]
                    if len(others_b) > self.num_spk:
                        others["mask_noise1"] = others_b[self.num_spk]

        if not isinstance(enhanced, list):
            enhanced = [enhanced]

        return enhanced, ilens, others
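
The power-sharing gate in step 2 reduces to a small predicate: WPE's power
estimates are only worth passing on when the beamformer can actually consume
them. A standalone sketch of that logic (function and argument names mirror
the attributes used above but are illustrative):

def may_share_powers(beamformer_type: str, shared_power: bool,
                     wpe_nmask: int, num_spk: int) -> bool:
    # Only wMPDR/WPD beamformers consume power spectra, and a single WPE
    # mask cannot be attributed to one speaker when there are several.
    uses_powers = beamformer_type.startswith(("wmpdr", "wpd"))
    ambiguous = wpe_nmask == 1 and num_spk > 1
    return uses_powers and shared_power and not ambiguous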
Example #4
    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel):
                torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> STFT -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0],
                                       input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input_spectrum.dim() in (3, 4), input_spectrum.dim()
        enhanced = input_spectrum
        masks = OrderedDict()

        if (self.training and self.loss_type is not None
                and self.loss_type.startswith("mask")):
            # Only estimate masks during training to save memory
            if self.use_wpe:
                if input_spectrum.dim() == 3:
                    mask_w, flens = self.wpe.predict_mask(
                        input_spectrum.unsqueeze(-2), flens)
                    mask_w = mask_w.squeeze(-2)
                elif input_spectrum.dim() == 4:
                    if self.use_beamformer:
                        enhanced, flens, mask_w = self.wpe(
                            input_spectrum, flens)
                    else:
                        mask_w, flens = self.wpe.predict_mask(
                            input_spectrum, flens)

                if mask_w is not None:
                    masks["dereverb"] = mask_w

            if self.use_beamformer and input_spectrum.dim() == 4:
                masks_b, flens = self.beamformer.predict_mask(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

            return None, flens, masks

        else:
            # Performing both mask estimation and enhancement
            if input_spectrum.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, flens, mask_w = self.wpe(
                        input_spectrum.unsqueeze(-2), flens)
                    enhanced = enhanced.squeeze(-2)
                    if mask_w is not None:
                        masks["dereverb"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                    if mask_w is not None:
                        masks["dereverb"] = mask_w

                # 2. Beamformer
                if self.use_beamformer:
                    # enhanced: (B, T, C, F) -> (B, T, F)
                    enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                    for spk in range(self.num_spk):
                        masks["spk{}".format(spk + 1)] = masks_b[spk]
                    if len(masks_b) > self.num_spk:
                        masks["noise1"] = masks_b[self.num_spk]

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [
                torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced
            ]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag],
                                   dim=-1).float()
        return enhanced, flens, masks
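
The mask dictionary returned above follows a fixed key convention
('dereverb', 'spk1' ... 'spkN', plus an optional 'noise1'). A small
self-contained helper sketch for unpacking it (the function name is
illustrative, not part of the model):

from collections import OrderedDict
from typing import List, Optional, Tuple

import torch

def unpack_masks(
    masks: OrderedDict, num_spk: int
) -> Tuple[List[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Split the frontend's mask dict into speaker, noise and dereverb parts
    spk_masks = [masks["spk{}".format(i + 1)] for i in range(num_spk)]
    return spk_masks, masks.get("noise1"), masks.get("dereverb")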