Example #1
def trace(a: ComplexTensor) -> ComplexTensor:
    E = torch.eye(a.real.size(-1), dtype=torch.uint8).expand(*a.size())
    return a[E].view(*a.size()[:-1]).sum(-1)
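A minimal usage sketch for the batched trace above (my own illustration, assuming torch_complex is installed and the trace function is in scope); note that recent PyTorch versions prefer a torch.bool mask over uint8 for this kind of indexing:

import torch
from torch_complex.tensor import ComplexTensor

# a batch of 4 random complex 3x3 matrices
a = ComplexTensor(torch.rand(4, 3, 3), torch.rand(4, 3, 3))
t = trace(a)  # (4,)

# cross-check against summing the diagonals of the real/imaginary parts
diag_real = a.real.diagonal(dim1=-2, dim2=-1).sum(-1)
diag_imag = a.imag.diagonal(dim1=-2, dim2=-1).sum(-1)
assert torch.allclose(t.real, diag_real) and torch.allclose(t.imag, diag_imag)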
Example #2
    def _compute_loss(
        self,
        speech_mix,
        speech_lengths,
        speech_ref,
        dereverb_speech_ref=None,
        noise_ref=None,
        cal_loss=True,
    ):
        """Compute loss according to self.loss_type.

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_lengths: (Batch,), default None for chunk iterator,
                            because the chunk-iterator does not return
                            speech_lengths. See
                            espnet2/iterators/chunk_iter_factory.py
            speech_ref: (Batch, num_speaker, samples)
                        or (Batch, num_speaker, samples, channels)
            dereverb_speech_ref: (Batch, N, samples)
                        or (Batch, N, samples, channels)
            noise_ref: (Batch, num_noise_type, samples)
                        or (Batch, num_noise_type, samples, channels)
            cal_loss: whether to calculate enh loss, default is True

        Returns:
            loss: (torch.Tensor) speech enhancement loss
            speech_pre: (List[torch.Tensor] or List[ComplexTensor])
                        enhanced speech or spectrum(s)
            others: (OrderedDict) estimated masks or None
            output_lengths: (Batch,)
            perm: () best permutation
        """
        feature_mix, flens = self.encoder(speech_mix, speech_lengths)
        feature_pre, flens, others = self.separator(feature_mix, flens)

        if self.loss_type not in ["si_snr", "ci_sdr"]:
            spectrum_mix = feature_mix
            spectrum_pre = feature_pre
            # predict separated speech and masks
            if self.stft_consistency:
                # pseudo STFT -> time-domain -> STFT (compute loss)
                tmp_t_domain = [
                    self.decoder(sp, speech_lengths)[0] for sp in spectrum_pre
                ]
                spectrum_pre = [
                    self.encoder(sp, speech_lengths)[0] for sp in tmp_t_domain
                ]

            if spectrum_pre is not None and not isinstance(
                spectrum_pre[0], ComplexTensor
            ):
                spectrum_pre = [
                    ComplexTensor(*torch.unbind(sp, dim=-1)) for sp in spectrum_pre
                ]

            if not cal_loss:
                loss, perm = None, None
                return loss, spectrum_pre, others, flens, perm

            # prepare reference speech and reference spectrum
            speech_ref = torch.unbind(speech_ref, dim=1)
            # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)]
            spectrum_ref = [self.encoder(sr, speech_lengths)[0] for sr in speech_ref]

            # compute TF masking loss
            if self.loss_type == "magnitude":
                # compute loss on magnitude spectrum
                assert spectrum_pre is not None
                magnitude_pre = [abs(ps + 1e-15) for ps in spectrum_pre]
                if spectrum_ref[0].dim() > magnitude_pre[0].dim():
                    # only select one channel as the reference
                    magnitude_ref = [
                        abs(sr[..., self.ref_channel, :]) for sr in spectrum_ref
                    ]
                else:
                    magnitude_ref = [abs(sr) for sr in spectrum_ref]

                tf_loss, perm = self._permutation_loss(
                    magnitude_ref, magnitude_pre, self.tf_mse_loss
                )
            elif self.loss_type.startswith("spectrum"):
                # compute loss on complex spectrum
                if self.loss_type == "spectrum":
                    loss_func = self.tf_mse_loss
                elif self.loss_type == "spectrum_log":
                    loss_func = self.tf_log_mse_loss
                else:
                    raise ValueError("Unsupported loss type: %s" % self.loss_type)

                assert spectrum_pre is not None
                if spectrum_ref[0].dim() > spectrum_pre[0].dim():
                    # only select one channel as the reference
                    spectrum_ref = [sr[..., self.ref_channel, :] for sr in spectrum_ref]

                tf_loss, perm = self._permutation_loss(
                    spectrum_ref, spectrum_pre, loss_func
                )
            elif self.loss_type.startswith("mask"):
                if self.loss_type == "mask_mse":
                    loss_func = self.tf_mse_loss
                else:
                    raise ValueError("Unsupported loss type: %s" % self.loss_type)

                assert others is not None
                mask_pre_ = [
                    others["mask_spk{}".format(spk + 1)] for spk in range(self.num_spk)
                ]

                # prepare ideal masks
                mask_ref = self._create_mask_label(
                    spectrum_mix, spectrum_ref, mask_type=self.mask_type
                )

                # compute TF masking loss
                tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_, loss_func)

                if "mask_dereverb1" in others:
                    if dereverb_speech_ref is None:
                        raise ValueError(
                            "No dereverberated reference for training!\n"
                            'Please specify "--use_dereverb_ref true" in run.sh'
                        )

                    mask_wpe_pre = [
                        others["mask_dereverb{}".format(spk + 1)]
                        for spk in range(self.num_spk)
                        if "mask_dereverb{}".format(spk + 1) in others
                    ]
                    assert len(mask_wpe_pre) == dereverb_speech_ref.size(1), (
                        len(mask_wpe_pre),
                        dereverb_speech_ref.size(1),
                    )
                    dereverb_speech_ref = torch.unbind(dereverb_speech_ref, dim=1)
                    dereverb_spectrum_ref = [
                        self.encoder(dr, speech_lengths)[0]
                        for dr in dereverb_speech_ref
                    ]
                    dereverb_mask_ref = self._create_mask_label(
                        spectrum_mix, dereverb_spectrum_ref, mask_type=self.mask_type
                    )

                    tf_dereverb_loss, perm_d = self._permutation_loss(
                        dereverb_mask_ref, mask_wpe_pre, loss_func
                    )
                    tf_loss = tf_loss + tf_dereverb_loss

                if "mask_noise1" in others:
                    if noise_ref is None:
                        raise ValueError(
                            "No noise reference for training!\n"
                            'Please specify "--use_noise_ref true" in run.sh'
                        )

                    noise_ref = torch.unbind(noise_ref, dim=1)
                    noise_spectrum_ref = [
                        self.encoder(nr, speech_lengths)[0] for nr in noise_ref
                    ]
                    noise_mask_ref = self._create_mask_label(
                        spectrum_mix, noise_spectrum_ref, mask_type=self.mask_type
                    )

                    mask_noise_pre = [
                        others["mask_noise{}".format(n + 1)]
                        for n in range(self.num_noise_type)
                    ]
                    tf_noise_loss, perm_n = self._permutation_loss(
                        noise_mask_ref, mask_noise_pre, loss_func
                    )
                    tf_loss = tf_loss + tf_noise_loss
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

            loss = tf_loss
            return loss, spectrum_pre, others, flens, perm

        else:
            speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre]
            if not cal_loss:
                loss, perm = None, None
                return loss, speech_pre, None, speech_lengths, perm

            # speech_pre: list[(batch, sample)]
            assert speech_pre[0].dim() == 2, speech_pre[0].dim()

            if speech_ref.dim() == 4:
                # For si_snr loss of multi-channel input,
                # only select one channel as the reference
                speech_ref = speech_ref[..., self.ref_channel]
            speech_ref = torch.unbind(speech_ref, dim=1)

            if self.loss_type == "si_snr":
                # compute si-snr loss
                loss, perm = self._permutation_loss(
                    speech_ref, speech_pre, self.si_snr_loss_zeromean
                )
            elif self.loss_type == "ci_sdr":
                # compute ci-sdr loss
                loss, perm = self._permutation_loss(
                    speech_ref, speech_pre, self.ci_sdr_loss
                )
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

            return loss, speech_pre, None, speech_lengths, perm
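For reference, a self-contained sketch of the permutation-invariant (PIT) criterion that self._permutation_loss computes conceptually; this is an illustration only, not the ESPnet implementation:

import itertools

import torch


def pit_loss(ref, inf, criterion):
    """Pick the speaker permutation minimizing the average pairwise loss.

    ref, inf: lists (one entry per speaker) of (Batch, ...) tensors
    criterion: callable returning a (Batch,) loss for a (ref, inf) pair
    """
    num_spk = len(ref)
    losses = []
    for p in itertools.permutations(range(num_spk)):
        pair = [criterion(ref[s], inf[t]) for s, t in enumerate(p)]
        losses.append(torch.stack(pair, dim=0).mean(dim=0))  # (Batch,)
    losses = torch.stack(losses, dim=1)  # (Batch, num_permutations)
    loss, perm_index = losses.min(dim=1)
    return loss.mean(), perm_index


mse = lambda a, b: ((a - b) ** 2).mean(dim=-1)
ref = [torch.rand(2, 100), torch.rand(2, 100)]
inf = [ref[1].clone(), ref[0].clone()]  # estimates swapped on purpose
loss, perm = pit_loss(ref, inf, mse)    # the best permutation undoes the swap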
Example #3
    def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
            -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, F)
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F)

        """
        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device)
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech,
                                             psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask)
                for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           sum(psd_speeches) + psd_noise)
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech
Example #4
                High_noise, Low_noise = Decomposition(y_.squeeze(0), 0.10)

                # dncnn_data = High_noise[:,0].unsqueeze(0)
                decompose_data = torch.cat([y_, High_noise, Low_noise], dim=1)
                # x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1)
                model = model.cpu()
                decom_output = decompose_model(decompose_data.float()).squeeze(
                    0)  # inference

                dncnn_output = model(High_noise[:, 0].unsqueeze(1)).squeeze(0)
                # dncnn_high, dncnn_low = Decomposition(dncnn_output, 0.10)

                # x_ = output[0].cpu().detach().numpy().astype(np.float32)

                output = ComplexTensor(dncnn_output[1] + decom_output[3],
                                       decom_output[2] +
                                       decom_output[4]).abs()
                # output = ComplexTensor(dncnn_high[0,0] + decom_output[3], decom_output[2] + decom_output[4]).abs()
                x_ = output.cpu().detach().numpy().astype(np.float32)

                # x_ = ComplexTensor(output[:, 0] + output[:, 2], Low_noise[:, 0] + Low_noise[:, 1]).abs().squeeze(0)
                # x_ = ComplexTensor(output[:, 0] + output[:, 2], output[:, 1] + output[:, 3]).abs().squeeze(0)
                # x_ = torch.add(output[:,0], output[:,1]).squeeze(0)

                # x_ = x_.cpu().detach().numpy().astype(np.float32)

                # x_ = x_.view(y.shape[0], y.shape[1])
                # x_ = x_.cpu()
                # x_ = x_.detach().numpy().astype(np.float32)
                elapsed_time = time.time() - start_time
Example #5
def test_gev_phase_correction():
    mat = ComplexTensor(torch.rand(2, 3, 4), torch.rand(2, 3, 4))
    mat_th = torch.complex(mat.real, mat.imag)
    norm = gev_phase_correction(mat)
    norm_th = gev_phase_correction(mat_th)
    assert np.allclose(norm.numpy(), norm_th.numpy())
Example #6
                # lowfreq_input = torch.cat([y_, High_noise[:,0].unsqueeze(1), High_noise[:,1].unsqueeze(1), Low_noise[:,0].unsqueeze(1), Low_noise[:,1].unsqueeze(1)], dim=1)
                lowfreq_input = torch.cat([
                    High_noise[:, 1].unsqueeze(1),
                    Low_noise[:, 0].unsqueeze(1), Low_noise[:, 1].unsqueeze(1)
                ],
                                          dim=1)

                output_dncnn = model_dncnn(dncnn_input.float()).squeeze(
                    0)  # inference

                lowfreq_output = model_lowfreq(lowfreq_input)

                # x_ = output[0].cpu().detach().numpy().astype(np.float32)
                # output = ComplexTensor(High_origin[0,0].cuda() + output[3], output[2] + output[4]).abs()
                output = ComplexTensor(
                    output_dncnn[0] + lowfreq_output[0, 1],
                    lowfreq_output[0, 0] + lowfreq_output[0, 2]).abs()
                # output = ComplexTensor(output_dncnn[0] + Low_origin[0, 0],
                #                        High_origin[0, 1] + Low_origin[0, 1]).abs()

                x_ = output.cpu().detach().numpy().astype(np.float32)

                plt.figure()
                plt.imshow(x, cmap='jet')
                plt.show()
                plt.close()

                plt.figure()
                plt.imshow(output.detach().numpy(), cmap='jet')
                plt.show()
                plt.close()
Example #7
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--in-scp', type=str, required=True)
    parser.add_argument('--clean-scp',
                        type=str,
                        help='Decode using oracle clean power')
    parser.add_argument('--out-dir', type=str, required=True)
    parser.add_argument('--model-state', type=str)
    parser.add_argument('--model-config', type=str)
    parser.add_argument('--stft-file', type=str, default='./stft.json')
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--ref-channels', type=str2int_tuple, default=None)
    parser.add_argument('--online', type=strtobool, default=False)
    parser.add_argument('--taps', type=int, default=5)
    parser.add_argument('--delay', type=int, default=3)

    args = parser.parse_args()
    device = 'cuda' if args.ngpu > 1 else 'cpu'

    if args.model_config is not None:
        with open(args.model_config) as f:
            model_config = json.load(f)

        model_config.update(use_dnn=True)
        _ = model_config.pop('width')
        norm_scale = model_config.pop('norm_scale')
        model = DNN_WPE(**model_config)
        if args.model_state is not None:
            model.load_state_dict(torch.load(args.model_state))
    else:
        model = None
        norm_scale = False

    reader = SoundScpReader(args.in_scp)
    writer = SoundScpWriter(args.out_dir, 'wav')

    if args.clean_scp is not None:
        clean_reader = SoundScpReader(args.clean_scp)
    else:
        clean_reader = None
    stft_func = Stft(args.stft_file)

    for key in tqdm(reader):
        # inp: (T, C)
        rate, inp = reader[key]
        if inp.ndim == 1:
            inp = inp[:, None]
        if args.ref_channels is not None:
            inp = inp[:, args.ref_channels]

        # Scaling int to [-1, 1]
        inp = inp.astype(np.float32) / (np.iinfo(inp.dtype).max - 1)
        if norm_scale:
            scale = np.abs(inp).mean()
            inp /= scale
        else:
            scale = 1.

        # inp: (T, C) -> inp_stft: (C, F, T)
        inp_stft = stft_func(inp.T)

        if clean_reader is not None:
            _, clean = clean_reader[key]
            if clean.ndim == 1:
                clean = clean[:, None]
            # clean: (T, C) -> clean_stft: (C, F, T)
            clean_stft = stft_func(clean.T)
            power = (clean_stft.real ** 2 + clean_stft.imag ** 2).mean(0)

        elif model is not None:
            # To torch(C, F, T) -> (1, C, T, F)
            inp_stft_th = ComplexTensor(inp_stft.transpose(
                0, 2, 1)[None]).to(device)
            with torch.no_grad():
                _, power = model(inp_stft_th, return_wpe=False)
            # power: (1, C, T, F) -> (F, C, T)
            power = power[0].permute(2, 0, 1)

            # To numpy: (F, C, T) -> (F, T)
            power = power.cpu().numpy().mean(1)
        else:
            power = None

        # enh_stft: (F, C, T)
        if not args.online:
            enh_stft = wpe(
                inp_stft.transpose(1, 0, 2),
                power=power,
                taps=args.taps,
                delay=args.delay,
                iterations=1 if model is not None else 3,
            )
        else:
            enh_stft = online_wpe(inp_stft.transpose(1, 0, 2),
                                  power=power,
                                  taps=args.taps,
                                  delay=args.delay)
        # enh_stft: (F, C, T) -> (C, F, T)
        enh_stft = enh_stft.transpose(1, 0, 2)
        # keep only the first channel: (C, F, T) -> (F, T)
        enh_stft = enh_stft[0]

        # enh_stft: (F, T) -> enh: (T,)
        enh = stft_func.istft(enh_stft).T
        # Truncate
        enh = enh[:inp.shape[0]]

        if norm_scale:
            enh *= scale
        # Rescaling  [-1, 1] to int16
        enh = (enh * (np.iinfo(np.int16).max - 1)).astype(np.int16)

        writer[key] = (rate, enh)
Example #8
    def forward(self,
                data: ComplexTensor, ilens: torch.LongTensor = None,
                return_wpe: bool = True) -> Tuple[Optional[ComplexTensor],
                                                  torch.Tensor]:
        if ilens is None:
            ilens = torch.full((data.size(0),), data.size(2),
                               dtype=torch.long, device=data.device)

        r = -self.rcontext if self.rcontext != 0 else None
        enhanced = data[:, :, self.lcontext:r, :]

        if self.lcontext != 0 or self.rcontext != 0:
            assert all(ilens[0] == i for i in ilens)

            # Create context window (a.k.a Splicing)
            if self.model_type in ('blstm', 'lstm'):
                width = data.size(2) - self.lcontext - self.rcontext
                # data: (B, C, l + w + r, F)
                indices = [i + j for i in range(width)
                           for j in range(1 + self.lcontext + self.rcontext)]
                _y = data[:, :, indices]
                # data: (B, C, w, (1 + l + r) * F)
                data = _y.view(
                    data.size(0), data.size(1),
                    width, (1 + self.lcontext + self.rcontext) * data.size(3))
                ilens = torch.full((data.size(0),), width,
                                   dtype=torch.long, device=data.device)
                del _y

        for i in range(self.iterations):
            # Calculate power: (B, C, T, F)
            power = enhanced.real ** 2 + enhanced.imag ** 2
            if i == 0 and self.use_dnn:
                # mask: (B, C, T, F)
                mask = self.estimator(data, ilens)
                if mask.size(2) != power.size(2):
                    assert mask.size(2) == (power.size(2) + self.rcontext + self.lcontext)
                    r = -self.rcontext if self.rcontext != 0 else None
                    mask = mask[:, :, self.lcontext:r, :]

                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-2)[..., None]
                if self.out_type == 'mask':
                    power = power * mask
                else:
                    power = mask

                    if self.out_type == 'amplitude':
                        power = power ** 2
                    elif self.out_type == 'log_power':
                        power = power.exp()
                    elif self.out_type == 'power':
                        pass
                    else:
                        raise NotImplementedError(self.out_type)

            if not return_wpe:
                return None, power

            # power: (B, C, T, F) -> _power: (B, F, T)
            _power = power.mean(dim=1).transpose(-1, -2).contiguous()

            # data: (B, C, T, F) -> _data: (B, F, C, T)
            _data = data.permute(0, 3, 1, 2).contiguous()
            # _enhanced: (B, F, C, T)
            _enhanced_real = []
            _enhanced_imag = []
            for d, p, l in zip(_data, _power, ilens):
                # e: (F, C, T) -> (T, C, F)
                e = wpe_one_iteration(
                    d[..., :l], p[..., :l],
                    taps=self.taps, delay=self.delay,
                    inverse_power=self.inverse_power).transpose(0, 2)
                _enhanced_real.append(e.real)
                _enhanced_imag.append(e.imag)
            # _enhanced: B x (T, C, F) -> (B, T, C, F) -> (B, F, C, T)
            _enhanced_real = pad_sequence(_enhanced_real,
                                          batch_first=True).transpose(1, 3)
            _enhanced_imag = pad_sequence(_enhanced_imag,
                                          batch_first=True).transpose(1, 3)
            _enhanced = ComplexTensor(_enhanced_real, _enhanced_imag)

            # enhanced: (B, F, C, T) -> (B, C, T, F)
            enhanced = _enhanced.permute(0, 2, 3, 1)

        # enhanced: (B, C, T, F), power: (B, C, T, F)
        return enhanced, power
Example #9
                # y_ = torch.cat([y_,High_noise, Low_noise], dim=1)
                # x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1)

                dncnn_input = High_noise[:,0].unsqueeze(1).cuda()
                lowfreq_input = torch.cat([High_noise[:,1].unsqueeze(1), Low_noise[:,0].unsqueeze(1), Low_noise[:,1].unsqueeze(1)], dim=1)

                dncnn_input = dncnn_input.cuda()
                lowfreq_input = lowfreq_input.cuda()

                output_dncnn = model_dncnn(dncnn_input.cuda().float()).squeeze(0) # inference
                lowfreq_output = model_lowfreq(lowfreq_input).unsqueeze(0)


                # x_ = output[0].cpu().detach().numpy().astype(np.float32)
                # output = ComplexTensor(High_origin[0,0].cuda() + output[3], output[2] + output[4]).abs()
                output = ComplexTensor(output_dncnn[0] + lowfreq_output[1], lowfreq_output[0] + lowfreq_output[2]).abs()

                x_ = output.cpu().detach().numpy().astype(np.float32)

                # x_ = ComplexTensor(output[:, 0] + output[:, 2], Low_noise[:, 0] + Low_noise[:, 1]).abs().squeeze(0)
                # x_ = ComplexTensor(output[:, 0] + output[:, 2], output[:, 1] + output[:, 3]).abs().squeeze(0)
                # x_ = torch.add(output[:,0], output[:,1]).squeeze(0)

                # x_ = x_.cpu().detach().numpy().astype(np.float32)

                # x_ = x_.view(y.shape[0], y.shape[1])
                # x_ = x_.cpu()
                # x_ = x_.detach().numpy().astype(np.float32)
                torch.cuda.synchronize()
                elapsed_time = time.time() - start_time
Example #10
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector calculated with RTF.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer, computed as follows:

        h = (Rf^-1 @ vbar) / (vbar^H @ Rf^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, (btaps + 1) * C)
    """
    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # (B, F, (K+1)*C, 1)
    rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)
    denominator = FC.einsum("...d,...d->...",
                            [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
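A small shape sketch of the padding step above (illustration only, assuming torch_complex): with C channels and btaps filter taps, the (B, F, C, 1) RTF is zero-padded so it matches the (btaps + 1) * C dimension of the stacked observation covariance:

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

B, F_, C, btaps = 2, 4, 3, 5
rtf = ComplexTensor(torch.rand(B, F_, C, 1), torch.rand(B, F_, C, 1))
# pad zeros along the channel axis: (B, F, C, 1) -> (B, F, (btaps + 1) * C, 1)
rtf_bar = FC.pad(rtf, (0, 0, 0, (btaps + 1) * C - C), "constant", 0)
assert rtf_bar.shape == (B, F_, (btaps + 1) * C, 1)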
Example #11
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """DNN_WPE forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)
        Returns:
            enhanced (torch.Tensor or List[torch.Tensor]): (B, T, C, F)
            ilens: (B,)
            masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F)
            power (List[torch.Tensor]): (B, F, T)
        """
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        enhanced = [data for i in range(self.nmask)]
        masks = None
        power = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = [enh.real**2 + enh.imag**2 for enh in enhanced]
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                masks, _ = self.mask_est(data, ilens)
                # floor masks to increase numerical stability
                if self.mask_flooring:
                    masks = [m.clamp(min=self.flooring_thres) for m in masks]
                if self.normalization:
                    # Normalize along T
                    masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = [p * masks[i] for i, p in enumerate(power)]

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = [p.mean(dim=-2).clamp(min=self.eps) for p in power]

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE(kamo): Calculate in double precision
            enhanced = [
                wpe_one_iteration(
                    data.contiguous().double(),
                    p.double(),
                    taps=self.taps,
                    delay=self.delay,
                    inverse_power=self.inverse_power,
                ) for p in power
            ]
            enhanced = [
                enh.to(dtype=data.dtype).masked_fill(
                    make_pad_mask(ilens, enh.real), 0) for enh in enhanced
            ]

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced]
        if masks is not None:
            masks = ([m.transpose(-1, -3) for m in masks]
                     if self.nmask > 1 else masks[0].transpose(-1, -3))
        if self.nmask == 1:
            enhanced = enhanced[0]

        return enhanced, ilens, masks, power
Example #12
def get_mvdr_vector_with_rtf(
    psd_n: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector
        calculated with RTF:

        h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """  # noqa: H405, D205, D400
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_n)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_n.inverse2(), rtf).squeeze(-1)
    denominator = FC.einsum("...d,...d->...",
                            [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
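A minimal usage sketch (illustration only), assuming the helpers this function relies on (get_rtf, tik_reg, FC) are importable from the same module as above; the PSDs are built as X @ X^H so they are Hermitian and reasonably conditioned:

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

B, F_, C = 2, 4, 3

def random_psd(rank=8):
    # Hermitian positive semi-definite (B, F, C, C) matrix from random snapshots
    x = ComplexTensor(torch.rand(B, F_, C, rank), torch.rand(B, F_, C, rank))
    return FC.matmul(x, x.conj().transpose(-1, -2)) / rank

psd_speech = random_psd()
psd_noise = random_psd()
psd_n = psd_speech + psd_noise  # pretend observation covariance

ws = get_mvdr_vector_with_rtf(psd_n, psd_speech, psd_noise, reference_vector=0)
assert ws.shape == (B, F_, C)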
Example #13
    def forward(
        self, input: ComplexTensor, ilens: torch.Tensor
    ) -> Tuple[List[ComplexTensor], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (ComplexTensor): mixed speech [Batch, Frames, Channel, Freq]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel): List[ComplexTensor]
            output lengths
            other predicted data: OrderedDict[
                'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
                'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input.dim() in (3, 4), input.dim()
        enhanced = input
        others = OrderedDict()

        if (self.training and self.loss_type is not None
                and self.loss_type.startswith("mask")):
            # Only estimating masks during training for saving memory
            if self.use_wpe:
                if input.dim() == 3:
                    mask_w, ilens = self.wpe.predict_mask(
                        input.unsqueeze(-2), ilens)
                    mask_w = mask_w.squeeze(-2)
                elif input.dim() == 4:
                    mask_w, ilens = self.wpe.predict_mask(input, ilens)

                if mask_w is not None:
                    if isinstance(enhanced, list):
                        # single-source WPE
                        for spk in range(self.num_spk):
                            others["mask_dereverb{}".format(spk +
                                                            1)] = mask_w[spk]
                    else:
                        # multi-source WPE
                        others["mask_dereverb1"] = mask_w

            if self.use_beamformer and input.dim() == 4:
                others_b, ilens = self.beamformer.predict_mask(input, ilens)
                for spk in range(self.num_spk):
                    others["mask_spk{}".format(spk + 1)] = others_b[spk]
                if len(others_b) > self.num_spk:
                    others["mask_noise1"] = others_b[self.num_spk]

            return None, ilens, others

        else:
            powers = None
            # Performing both mask estimation and enhancement
            if input.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(
                        input.unsqueeze(-2), ilens)
                    if isinstance(enhanced, list):
                        # single-source WPE
                        enhanced = [enh.squeeze(-2) for enh in enhanced]
                        if mask_w is not None:
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk].squeeze(-2)
                    else:
                        # multi-source WPE
                        enhanced = enhanced.squeeze(-2)
                        if mask_w is not None:
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                    if mask_w is not None:
                        if isinstance(enhanced, list):
                            # single-source WPE
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk]
                        else:
                            # multi-source WPE
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)

                # 2. Beamformer
                if self.use_beamformer:
                    # Reuse the WPE-estimated powers only for wMPDR/WPD
                    # beamformers with shared_power (and per-speaker WPE masks)
                    if (not self.beamformer.beamformer_type.startswith(
                            ("wmpdr", "wpd"))
                            or not self.shared_power
                            or (self.wpe.nmask == 1 and self.num_spk > 1)):
                        powers = None

                    # enhanced: (B, T, C, F) -> (B, T, F)
                    if isinstance(enhanced, list):
                        # outputs of single-source WPE
                        raise NotImplementedError(
                            "Single-source WPE is not supported with beamformer "
                            "in multi-speaker cases.")
                    else:
                        # output of multi-source WPE
                        enhanced, ilens, others_b = self.beamformer(
                            enhanced, ilens, powers=powers)
                    for spk in range(self.num_spk):
                        others["mask_spk{}".format(spk + 1)] = others_b[spk]
                    if len(others_b) > self.num_spk:
                        others["mask_noise1"] = others_b[self.num_spk]

        if not isinstance(enhanced, list):
            enhanced = [enhanced]

        return enhanced, ilens, others
Example #14
                batch_x, batch_y = batch_yx[:, 0].cuda(), batch_yx[:, 1].cuda()
                temp = batch_x
                High_noise, Low_noise = Decomposition(batch_y, 0.125)
                High_origin, Low_origin = Decomposition(batch_x, 0.125)

            # batch_y = batch_y.unsqueeze(1)
            # batch_x = batch_x.unsqueeze(1)
            #
            batch_y = torch.cat([High_noise, Low_noise], dim=1)
            batch_x = torch.cat([High_origin, Low_origin], dim=1)

            output = model(batch_y.cuda())

            loss_hf = criterion(output.cpu(), batch_x)

            output_hf = ComplexTensor(output[:, 0], output[:, 1]).abs()
            output_lf = ComplexTensor(output[:, 2], output[:, 3]).abs()

            final_output = ComplexTensor(output[:, 0] + output[:, 2],
                                         output[:, 1] + output[:, 3]).abs()

            High_noise = ComplexTensor(batch_y[:, 0], batch_y[:, 1]).abs()
            High_origin = ComplexTensor(batch_x[:, 0], batch_x[:, 1]).abs()

            Low_noise = ComplexTensor(batch_y[:, 2], batch_y[:, 3]).abs()
            Low_origin = ComplexTensor(batch_x[:, 2], batch_x[:, 3]).abs()

            fig = plt.figure()
            gs = GridSpec(nrows=2, ncols=4)

            highfreq1 = fig.add_subplot(gs[0, 0])
Example #15
File: beamformer.py Project: zqs01/espnet
def apply_beamforming_vector(beamform_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    # (..., C) x (..., C, T) -> (..., T)
    es = FC.einsum('...c,...ct->...t', [beamform_vector.conj(), mix])
    return es
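A quick shape check for the helper above (illustration only, assuming torch_complex): a per-frequency beamforming vector of shape (B, F, C) applied to a multi-channel STFT of shape (B, F, C, T) yields a single-channel output of shape (B, F, T):

import torch
from torch_complex.tensor import ComplexTensor

B, F_, C, T = 2, 4, 3, 10
ws = ComplexTensor(torch.rand(B, F_, C), torch.rand(B, F_, C))
mix = ComplexTensor(torch.rand(B, F_, C, T), torch.rand(B, F_, C, T))
enhanced = apply_beamforming_vector(ws, mix)
assert enhanced.shape == (B, F_, T)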
Example #16
                torch.cuda.synchronize()
                start_time = time.time()
                #  0.14
                High_noise, Low_noise = Decomposition(y_.squeeze(0), 0.125)
                High_noise = High_noise.cuda()
                Low_noise = Low_noise.cuda()
                y_ = y_.cuda()

                y_ = torch.cat([y_, High_noise, Low_noise], dim=1)
                # x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1)

                output = decompose_model(y_.cuda().float()).squeeze(
                    0)  # inference
                # x_ = output[0].cpu().detach().numpy().astype(np.float32)
                output = ComplexTensor(output[1] + output[3],
                                       output[2] + output[4]).abs()
                x_ = output.cpu().detach().numpy().astype(np.float32)

                # x_ = ComplexTensor(output[:, 0] + output[:, 2], Low_noise[:, 0] + Low_noise[:, 1]).abs().squeeze(0)
                # x_ = ComplexTensor(output[:, 0] + output[:, 2], output[:, 1] + output[:, 3]).abs().squeeze(0)
                # x_ = torch.add(output[:,0], output[:,1]).squeeze(0)

                # x_ = x_.cpu().detach().numpy().astype(np.float32)

                # x_ = x_.view(y.shape[0], y.shape[1])
                # x_ = x_.cpu()
                # x_ = x_.detach().numpy().astype(np.float32)
                torch.cuda.synchronize()
                elapsed_time = time.time() - start_time

                psnr_x_ = compare_psnr(x, x_)
Example #17
def test_conformer_separator_forward_backward_complex(
    input_dim,
    num_spk,
    adim,
    aheads,
    layers,
    linear_units,
    positionwise_layer_type,
    positionwise_conv_kernel_size,
    normalize_before,
    concat_after,
    dropout_rate,
    input_layer,
    positional_dropout_rate,
    attention_dropout_rate,
    nonlinear,
    conformer_pos_enc_layer_type,
    conformer_self_attn_layer_type,
    conformer_activation_type,
    use_macaron_style_in_conformer,
    use_cnn_in_conformer,
    conformer_enc_kernel_size,
    padding_idx,
):
    model = ConformerSeparator(
        input_dim=input_dim,
        num_spk=num_spk,
        adim=adim,
        aheads=aheads,
        layers=layers,
        linear_units=linear_units,
        dropout_rate=dropout_rate,
        positional_dropout_rate=positional_dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        input_layer=input_layer,
        normalize_before=normalize_before,
        concat_after=concat_after,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        use_macaron_style_in_conformer=use_macaron_style_in_conformer,
        nonlinear=nonlinear,
        conformer_pos_enc_layer_type=conformer_pos_enc_layer_type,
        conformer_self_attn_layer_type=conformer_self_attn_layer_type,
        conformer_activation_type=conformer_activation_type,
        use_cnn_in_conformer=use_cnn_in_conformer,
        conformer_enc_kernel_size=conformer_enc_kernel_size,
        padding_idx=padding_idx,
    )
    model.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    x = ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    masked, flens, others = model(x, ilens=x_lens)

    assert isinstance(masked[0], ComplexTensor)
    assert len(masked) == num_spk

    masked[0].abs().mean().backward()
Example #18
    # --> (B, F, T, btaps + 1, C) --> (B, F, T, (btaps + 1) * C)
    Ytilde = Ytilde.permute(0, 1, 3, 4, 2).contiguous().view(Bs, Fdim, T, -1)
    # (B, F, T, 1)
    enhanced = FC.einsum("...tc,...c->...t", [Ytilde, filter_matrix.conj()])
    return enhanced


if __name__ == "__main__":
    ############################################
    #                  Example                 #
    ############################################
    eps = 1e-10
    btaps = 5
    bdelay = 3
    # pretend to be some STFT: (B, F, C, T)
    Z = ComplexTensor(torch.rand(4, 256, 2, 518), torch.rand(4, 256, 2, 518))

    # Calculate power: (B, F, C, T)
    power = Z.real ** 2 + Z.imag ** 2
    # pretend to be some mask
    mask_speech = torch.ones_like(Z.real)
    # (..., C, T) * (..., C, T) -> (..., C, T)
    power = power * mask_speech
    # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
    power = power.mean(dim=-2)
    # (B, F, T) --> (B * F, T)
    power = power.view(-1, power.shape[-1])
    inverse_power = 1 / torch.clamp(power, min=eps)

    B, Fdim, C, T = Z.shape
Example #19
def test_trace():
    t = ComplexTensor(_get_complex_array(10, 10))
    x = numpy.trace(t.numpy())
    y = F.trace(t).numpy()
    numpy.testing.assert_allclose(x, y)
Example #20
def get_WPD_filter_v2(
    Phi: ComplexTensor,
    Rf: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector with filter v2.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer. As follows:

        h = (Rf^-1 @ Phi_{xx}) @ u / tr[(Rf^-1) @ Phi_{xx}]

        This implementation is more efficient than `get_WPD_filter` as
        it skips unnecessary computation with zeros.

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        Phi (ComplexTensor): (B, F, C, C)
            is speech PSD.
        Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
            is the power normalized spatio-temporal covariance matrix.
        reference_vector (torch.Tensor): (B, C)
            is the reference_vector.
        eps (float):

    Returns:
        filter_matrix (ComplexTensor): (B, F, (btaps+1) * C)
    """
    C = reference_vector.shape[-1]
    try:
        inv_Rf = inv(Rf)
    except Exception:
        try:
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real), torch.rand_like(Rf.real)) * 1e-4
            )
            Rf = Rf / 10e4
            Phi = Phi / 10e4
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)
        except Exception:
            reg_coeff_tensor = (
                ComplexTensor(torch.rand_like(Rf.real), torch.rand_like(Rf.real)) * 1e-1
            )
            Rf = Rf / 10e10
            Phi = Phi / 10e10
            Rf += reg_coeff_tensor
            inv_Rf = inv(Rf)
    # (B, F, (btaps+1) * C, (btaps+1) * C) --> (B, F, (btaps+1) * C, C)
    inv_Rf_pruned = inv_Rf[..., :C]
    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [inv_Rf_pruned, Phi])
    # ws: (..., (btaps+1) * C, C) / (...,) -> (..., (btaps+1) * C, C)
    ws = numerator / (FC.trace(numerator[..., :C, :])[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    # (B, F, (btaps+1) * C)
    return beamform_vector
Example #21
    def forward_enh(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: torch.Tensor = None,
        resort_pre: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_ref: (Batch, num_speaker, samples)
                        or (Batch, num_speaker, samples, channels)
            speech_mix_lengths: (Batch,), default None for chunk iterator,
                            because the chunk-iterator does not return
                            speech_lengths. See
                            espnet2/iterators/chunk_iter_factory.py
        """
        # clean speech signal of each speaker
        speech_ref = [
            kwargs["speech_ref{}".format(spk + 1)]
            for spk in range(self.num_spk)
        ]
        # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
        speech_ref = torch.stack(speech_ref, dim=1)

        if "noise_ref1" in kwargs:
            # noise signal (optional, required when using
            # frontend models with beamforming)
            noise_ref = [
                kwargs["noise_ref{}".format(n + 1)]
                for n in range(self.num_noise_type)
            ]
            # (Batch, num_noise_type, samples) or
            # (Batch, num_noise_type, samples, channels)
            noise_ref = torch.stack(noise_ref, dim=1)
        else:
            noise_ref = None

        # dereverberated noisy signal
        # (optional, only used for frontend models with WPE)
        dereverb_speech_ref = kwargs.get("dereverb_ref", None)

        batch_size = speech_mix.shape[0]
        speech_lengths = (speech_mix_lengths if speech_mix_lengths is not None
                          else torch.ones(batch_size).int() *
                          speech_mix.shape[1])
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # Check that batch_size is unified
        assert speech_mix.shape[0] == speech_ref.shape[
            0] == speech_lengths.shape[0], (
                speech_mix.shape,
                speech_ref.shape,
                speech_lengths.shape,
            )

        # for data-parallel
        speech_ref = speech_ref[:, :, :speech_lengths.max()]
        speech_mix = speech_mix[:, :speech_lengths.max()]

        if self.loss_type != "si_snr":
            # prepare reference speech and reference spectrum
            speech_ref = torch.unbind(speech_ref, dim=1)
            spectrum_ref = [self.enh_model.stft(sr)[0] for sr in speech_ref]

            # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)]
            spectrum_ref = [
                ComplexTensor(sr[..., 0], sr[..., 1]) for sr in spectrum_ref
            ]
            spectrum_mix = self.enh_model.stft(speech_mix)[0]
            spectrum_mix = ComplexTensor(spectrum_mix[..., 0],
                                         spectrum_mix[..., 1])

            # predict separated speech and masks
            spectrum_pre, tf_length, mask_pre = self.enh_model(
                speech_mix, speech_lengths)

            # TODO(Chenda), Shall we add options for computing loss on
            #  the masked spectrum?
            # compute TF masking loss
            if self.loss_type == "magnitude":
                # compute loss on magnitude spectrum
                magnitude_pre = [abs(ps) for ps in spectrum_pre]
                magnitude_ref = [abs(sr) for sr in spectrum_ref]
                tf_loss, perm = self._permutation_loss(magnitude_ref,
                                                       magnitude_pre,
                                                       self.tf_mse_loss)
            elif self.loss_type == "spectrum":
                # compute loss on complex spectrum
                tf_loss, perm = self._permutation_loss(spectrum_ref,
                                                       spectrum_pre,
                                                       self.tf_mse_loss)
            elif self.loss_type.startswith("mask"):
                if self.loss_type == "mask_mse":
                    loss_func = self.tf_mse_loss
                else:
                    raise ValueError("Unsupported loss type: %s" %
                                     self.loss_type)

                assert mask_pre is not None
                mask_pre_ = [
                    mask_pre["spk{}".format(spk + 1)]
                    for spk in range(self.num_spk)
                ]

                # prepare ideal masks
                mask_ref = self._create_mask_label(spectrum_mix,
                                                   spectrum_ref,
                                                   mask_type=self.mask_type)

                # compute TF masking loss
                tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_,
                                                       loss_func)

                if "dereverb" in mask_pre:
                    if dereverb_speech_ref is None:
                        raise ValueError(
                            "No dereverberated reference for training!\n"
                            'Please specify "--use_dereverb_ref true" in run.sh'
                        )

                    dereverb_spectrum_ref = self.enh_model.stft(
                        dereverb_speech_ref)[0]
                    dereverb_spectrum_ref = ComplexTensor(
                        dereverb_spectrum_ref[..., 0],
                        dereverb_spectrum_ref[..., 1])
                    # ComplexTensor(B, T, F) or ComplexTensor(B, T, C, F)
                    dereverb_mask_ref = self._create_mask_label(
                        spectrum_mix, [dereverb_spectrum_ref],
                        mask_type=self.mask_type)[0]

                    tf_loss = (tf_loss + loss_func(
                        dereverb_mask_ref, mask_pre["dereverb"]).mean())

                if "noise1" in mask_pre:
                    if noise_ref is None:
                        raise ValueError(
                            "No noise reference for training!\n"
                            'Please specify "--use_noise_ref true" in run.sh')

                    noise_ref = torch.unbind(noise_ref, dim=1)
                    noise_spectrum_ref = [
                        self.enh_model.stft(nr)[0] for nr in noise_ref
                    ]
                    noise_spectrum_ref = [
                        ComplexTensor(nr[..., 0], nr[..., 1])
                        for nr in noise_spectrum_ref
                    ]
                    noise_mask_ref = self._create_mask_label(
                        spectrum_mix,
                        noise_spectrum_ref,
                        mask_type=self.mask_type)

                    mask_noise_pre = [
                        mask_pre["noise{}".format(n + 1)]
                        for n in range(self.num_noise_type)
                    ]
                    tf_noise_loss, perm_n = self._permutation_loss(
                        noise_mask_ref, mask_noise_pre, loss_func)
                    tf_loss = tf_loss + tf_noise_loss
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

            if spectrum_pre is None and self.loss_type == "mask":
                # Need the wav prediction in training
                # TODO(Jing): should coordinate with the enh/nets/***, this is ugly now.
                self.enh_model.training = False
                speech_pre, *__ = self.enh_model.forward_rawwav(
                    speech_mix, speech_lengths)
                self.enh_model.training = self.training
            else:
                speech_pre, *__ = self.enh_model.forward_rawwav(
                    speech_mix, speech_lengths)

            loss = tf_loss

        else:
            if speech_ref.dim() == 4:
                # For si_snr loss of multi-channel input,
                # only select one channel as the reference
                speech_ref = speech_ref[..., self.ref_channel]

            speech_pre, speech_lengths, *__ = self.enh_model.forward_rawwav(
                speech_mix, speech_lengths)
            # speech_pre: list[(batch, sample)]
            assert speech_pre[0].dim() == 2, speech_pre[0].dim()
            speech_ref = torch.unbind(speech_ref, dim=1)

            # compute si-snr loss
            si_snr_loss, perm = self._permutation_loss(
                speech_ref, speech_pre, self.si_snr_loss_zeromean)
            loss = si_snr_loss

        if resort_pre:
            # speech_pre : list[(bs,T)] of spk
            # perm : list[(num_spk)] of batch
            speech_pre_list = []
            for batch_idx, p in enumerate(perm):
                batch_list = []
                for spk_idx in p:
                    batch_list.append(speech_pre[spk_idx][batch_idx])  # spk,T
                speech_pre_list.append(torch.stack(batch_list, dim=0))

            speech_pre = torch.stack(speech_pre_list, dim=0)  # bs,num_spk,T
        else:
            speech_pre = torch.stack(speech_pre, dim=1)  # bs,num_spk,T

        return loss, perm, speech_pre
Example #22
def signal_framing(
    signal: Union[torch.Tensor, ComplexTensor],
    frame_length: int,
    frame_step: int,
    bdelay: int,
    do_padding: bool = False,
    pad_value: int = 0,
    indices: List = None,
) -> Union[torch.Tensor, ComplexTensor]:
    """Expand `signal` into several frames, with each frame of length `frame_length`.

    Args:
        signal : (..., T)
        frame_length:   length of each segment
        frame_step:     step for selecting frames
        bdelay:         delay for WPD
        do_padding:     whether or not to pad the input signal at the beginning
                          of the time dimension
        pad_value:      value to fill in the padding
        indices:        pre-computed frame indices, reused when framing the
                          real and imaginary parts of a ComplexTensor

    Returns:
        torch.Tensor:
            if do_padding: (..., T, frame_length)
            else:          (..., T - bdelay - frame_length + 2, frame_length)
    """
    if indices is None:
        frame_length2 = frame_length - 1
        # pad the beginning of the last (time) dimension of `signal`
        if do_padding:
            # (..., T) --> (..., T + bdelay + frame_length - 2)
            signal = FC.pad(
                signal, (bdelay + frame_length2 - 1, 0), "constant", pad_value
            )

        # indices:
        # [[ 0, 1, ..., frame_length2 - 1,              frame_length2 - 1 + bdelay ],
        #  [ 1, 2, ..., frame_length2,                  frame_length2 + bdelay     ],
        #  [ 2, 3, ..., frame_length2 + 1,              frame_length2 + 1 + bdelay ],
        #  ...
        #  [ T-bdelay-frame_length2, ..., T-1-bdelay,   T-1 ]]
        indices = [
            [*range(i, i + frame_length2), i + frame_length2 + bdelay - 1]
            for i in range(0, signal.shape[-1] - frame_length2 - bdelay + 1, frame_step)
        ]

    if isinstance(signal, ComplexTensor):
        real = signal_framing(
            signal.real,
            frame_length,
            frame_step,
            bdelay,
            do_padding,
            pad_value,
            indices,
        )
        imag = signal_framing(
            signal.imag,
            frame_length,
            frame_step,
            bdelay,
            do_padding,
            pad_value,
            indices,
        )
        return ComplexTensor(real, imag)
    else:
        # (..., T - bdelay - frame_length + 2, frame_length)
        signal = signal[..., indices]
        # signal[..., :-1] = -signal[..., :-1]
        return signal
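A quick way to see the framing geometry of `signal_framing` is to run it on a short real-valued ramp (assuming the function above and its imports are in scope; frame_length=4, frame_step=1, bdelay=3 are arbitrary choices for illustration).

import torch

x = torch.arange(10.0)                                  # (T,) with T = 10
frames = signal_framing(x, frame_length=4, frame_step=1, bdelay=3)
# shape: (T - bdelay - frame_length + 2, frame_length) = (5, 4)
print(frames.shape)   # torch.Size([5, 4])
print(frames[0])      # tensor([0., 1., 2., 5.]) -- the last tap is delayed by bdelay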
Example #23
    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech  (single-channel):
                torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> STFT -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        enhanced = input_spectrum
        masks = OrderedDict()

        if input_spectrum.dim() == 3:
            # single-channel input
            if self.use_wpe:
                # (B, T, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum.unsqueeze(-2), flens)
                enhanced = enhanced.squeeze(-2)
                if mask_w is not None:
                    masks["dereverb"] = mask_w.squeeze(-2)

        elif input_spectrum.dim() == 4:
            # multi-channel input
            # 1. WPE
            if self.use_wpe:
                # (B, T, C, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                if mask_w is not None:
                    masks["dereverb"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                # enhanced: (B, T, C, F) -> (B, T, F)
                enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

        else:
            raise ValueError(
                "Invalid spectrum dimension: {}".format(input_spectrum.shape)
            )

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()
        return enhanced, flens, masks
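The final torch.stack of real and imaginary parts mirrors the (..., 2) layout that the STFT front-end produces, so packing and unpacking a ComplexTensor is a lossless round trip. The check below uses an arbitrary shape purely for illustration.

import torch
from torch_complex.tensor import ComplexTensor

stft_out = torch.randn(2, 100, 257, 2)                    # (B, T, F, 2)
spec = ComplexTensor(stft_out[..., 0], stft_out[..., 1])  # (B, T, F)
back = torch.stack([spec.real, spec.imag], dim=-1)        # (B, T, F, 2)
assert torch.equal(back, stft_out)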
Example #24
def online_wpe_step(input_buffer: ComplexTensor,
                    power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99,
                    taps: int = 10,
                    delay: int = 3):
    """One step of online dereverberation.

    Args:
        input_buffer: (F, C, taps + delay + 1)
        power: Estimate for the current PSD (F, T)
        inv_cov: Current estimate of R^-1
        filter_taps: Current estimate of filter taps (F, taps * C, taps)
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        Dereverberated frame of shape (F, D)
        Updated estimate of R^-1
        Updated estimate of the filter taps


    >>> frame_length = 512
    >>> frame_shift = 128
    >>> taps = 6
    >>> delay = 3
    >>> alpha = 0.999
    >>> frequency_bins = frame_length // 2 + 1
    >>> Q = None
    >>> G = None
    >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G,
    ...                                    alpha=alpha, taps=taps, delay=delay)

    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    C = input_buffer.size(-2)

    if inv_cov is None:
        inv_cov = ComplexTensor(
            torch.eye(C * taps, dtype=input_buffer.dtype).expand(
                *input_buffer.size()[:-2], C * taps, C * taps))
    if filter_taps is None:
        filter_taps = ComplexTensor(
            torch.zeros(*input_buffer.size()[:-2],
                        C * taps,
                        C,
                        dtype=input_buffer.dtype))

    window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1)
    # (..., C, T) -> (..., C * T)
    window = window.view(*input_buffer.size()[:-2], -1)
    pred = input_buffer[..., -1] - FC.einsum('...id,...i->...d',
                                             (filter_taps.conj(), window))

    nominator = FC.einsum('...ij,...j->...i', (inv_cov, window))
    denominator = \
        FC.einsum('...i,...i->...', (window.conj(), nominator)) + alpha * power
    kalman_gain = nominator / denominator[..., None]

    inv_cov_k = inv_cov - FC.einsum('...j,...jm,...i->...im',
                                    (window.conj(), inv_cov, kalman_gain))
    inv_cov_k /= alpha

    filter_taps_k = \
        filter_taps + FC.einsum('...i,...m->...im', (kalman_gain, pred.conj()))
    return pred, inv_cov_k, filter_taps_k
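`online_wpe_step` is meant to be called once per frame while carrying the inverse covariance and filter-tap estimates across calls. Below is a minimal streaming sketch on synthetic data; the buffer slicing and the (F,)-shaped power estimate are assumptions made for illustration, not taken from the source.

import torch
from torch_complex.tensor import ComplexTensor

F_bins, C, T = 129, 2, 50
taps, delay, alpha = 6, 3, 0.999
stft = ComplexTensor(torch.randn(F_bins, C, T), torch.randn(F_bins, C, T))

Q, G = None, None        # inverse covariance and filter taps, updated online
dereverbed = []
for t in range(taps + delay, T):
    buf = stft[..., t - taps - delay:t + 1]      # (F, C, taps + delay + 1)
    power = (buf.real ** 2 + buf.imag ** 2).mean(dim=-2).mean(dim=-1)  # (F,)
    frame, Q, G = online_wpe_step(buf, power, Q, G,
                                  alpha=alpha, taps=taps, delay=delay)
    dereverbed.append(frame)                     # one dereverberated (F, C) frame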
Example #25
    def forward(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_ref: (Batch, num_speaker, samples)
                        or (Batch, num_speaker, samples, channels)
            speech_mix_lengths: (Batch,), default None for chunk iterator,
                            because the chunk-iterator does not have the
                            speech_lengths returned. see in
                            espnet2/iterators/chunk_iter_factory.py
        """
        # clean speech signal of each speaker
        speech_ref = [
            kwargs["speech_ref{}".format(spk + 1)]
            for spk in range(self.num_spk)
        ]
        # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
        speech_ref = torch.stack(speech_ref, dim=1)

        if "noise_ref1" in kwargs:
            # noise signal (optional, required when using
            # frontend models with beamforming)
            noise_ref = [
                kwargs["noise_ref{}".format(n + 1)]
                for n in range(self.num_noise_type)
            ]
            # (Batch, num_noise_type, samples) or
            # (Batch, num_noise_type, samples, channels)
            noise_ref = torch.stack(noise_ref, dim=1)
        else:
            noise_ref = None

        # dereverberated noisy signal
        # (optional, only used for frontend models with WPE)
        dereverb_speech_ref = kwargs.get("dereverb_ref", None)

        batch_size = speech_mix.shape[0]
        speech_lengths = (speech_mix_lengths if speech_mix_lengths is not None
                          else torch.ones(batch_size).int() *
                          speech_mix.shape[1])
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # Check that batch_size is unified
        assert speech_mix.shape[0] == speech_ref.shape[
            0] == speech_lengths.shape[0], (
                speech_mix.shape,
                speech_ref.shape,
                speech_lengths.shape,
            )
        batch_size = speech_mix.shape[0]

        # for data-parallel
        speech_ref = speech_ref[:, :, :speech_lengths.max()]
        speech_mix = speech_mix[:, :speech_lengths.max()]

        if self.loss_type != "si_snr":
            # prepare reference speech and reference spectrum
            speech_ref = torch.unbind(speech_ref, dim=1)
            spectrum_ref = [self.enh_model.stft(sr)[0] for sr in speech_ref]

            # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)]
            spectrum_ref = [
                ComplexTensor(sr[..., 0], sr[..., 1]) for sr in spectrum_ref
            ]
            spectrum_mix = self.enh_model.stft(speech_mix)[0]
            spectrum_mix = ComplexTensor(spectrum_mix[..., 0],
                                         spectrum_mix[..., 1])

            # predict separated speech and masks
            spectrum_pre, tf_length, mask_pre = self.enh_model(
                speech_mix, speech_lengths)

            # compute TF masking loss
            if self.loss_type == "magnitude":
                # compute loss on magnitude spectrum
                magnitude_pre = [abs(ps) for ps in spectrum_pre]
                magnitude_ref = [abs(sr) for sr in spectrum_ref]
                tf_loss, perm = self._permutation_loss(magnitude_ref,
                                                       magnitude_pre,
                                                       self.tf_mse_loss)
            elif self.loss_type == "spectrum":
                # compute loss on complex spectrum
                tf_loss, perm = self._permutation_loss(spectrum_ref,
                                                       spectrum_pre,
                                                       self.tf_mse_loss)
            elif self.loss_type.startswith("mask"):
                if self.loss_type == "mask_mse":
                    loss_func = self.tf_mse_loss
                else:
                    raise ValueError("Unsupported loss type: %s" %
                                     self.loss_type)

                assert mask_pre is not None
                mask_pre_ = [
                    mask_pre["spk{}".format(spk + 1)]
                    for spk in range(self.num_spk)
                ]

                # prepare ideal masks
                mask_ref = self._create_mask_label(spectrum_mix,
                                                   spectrum_ref,
                                                   mask_type=self.mask_type)

                # compute TF masking loss
                tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_,
                                                       loss_func)

                if "dereverb" in mask_pre:
                    if dereverb_speech_ref is None:
                        raise ValueError(
                            "No dereverberated reference for training!\n"
                            'Please specify "--use_dereverb_ref true" in run.sh'
                        )

                    dereverb_spectrum_ref = self.enh_model.stft(
                        dereverb_speech_ref)[0]
                    dereverb_spectrum_ref = ComplexTensor(
                        dereverb_spectrum_ref[..., 0],
                        dereverb_spectrum_ref[..., 1])
                    # ComplexTensor(B, T, F) or ComplexTensor(B, T, C, F)
                    dereverb_mask_ref = self._create_mask_label(
                        spectrum_mix, [dereverb_spectrum_ref],
                        mask_type=self.mask_type)[0]

                    tf_loss = (tf_loss + loss_func(
                        dereverb_mask_ref, mask_pre["dereverb"]).mean())

                if "noise1" in mask_pre:
                    if noise_ref is None:
                        raise ValueError(
                            "No noise reference for training!\n"
                            'Please specify "--use_noise_ref true" in run.sh')

                    noise_ref = torch.unbind(noise_ref, dim=1)
                    noise_spectrum_ref = [
                        self.enh_model.stft(nr)[0] for nr in noise_ref
                    ]
                    noise_spectrum_ref = [
                        ComplexTensor(nr[..., 0], nr[..., 1])
                        for nr in noise_spectrum_ref
                    ]
                    noise_mask_ref = self._create_mask_label(
                        spectrum_mix,
                        noise_spectrum_ref,
                        mask_type=self.mask_type)

                    mask_noise_pre = [
                        mask_pre["noise{}".format(n + 1)]
                        for n in range(self.num_noise_type)
                    ]
                    tf_noise_loss, perm_n = self._permutation_loss(
                        noise_mask_ref, mask_noise_pre, loss_func)
                    tf_loss = tf_loss + tf_noise_loss
            else:
                raise ValueError("Unsupported loss type: %s" % self.loss_type)

            if self.training:
                si_snr = None
            else:
                speech_pre = [
                    self.enh_model.stft.inverse(ps, speech_lengths)[0]
                    for ps in spectrum_pre
                ]
                if speech_ref[0].dim() == 3:
                    # For si_snr loss, only select one channel as the reference
                    speech_ref = [
                        sr[..., self.ref_channel] for sr in speech_ref
                    ]
                # compute si-snr loss
                si_snr_loss, perm = self._permutation_loss(speech_ref,
                                                           speech_pre,
                                                           self.si_snr_loss,
                                                           perm=perm)
                si_snr = -si_snr_loss.detach()

            loss = tf_loss

            stats = dict(
                si_snr=si_snr,
                loss=loss.detach(),
            )
        else:
            if speech_ref.dim() == 4:
                # For si_snr loss of multi-channel input,
                # only select one channel as the reference
                speech_ref = speech_ref[..., self.ref_channel]

            speech_pre, speech_lengths, *__ = self.enh_model.forward_rawwav(
                speech_mix, speech_lengths)
            # speech_pre: list[(batch, sample)]
            assert speech_pre[0].dim() == 2, speech_pre[0].dim()
            speech_ref = torch.unbind(speech_ref, dim=1)

            # compute si-snr loss
            si_snr_loss, perm = self._permutation_loss(
                speech_ref, speech_pre, self.si_snr_loss_zeromean)
            si_snr = -si_snr_loss
            loss = si_snr_loss
            stats = dict(si_snr=si_snr.detach(), loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight
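In the si_snr branch above, the loss is the negative scale-invariant SNR with mean removal. A hypothetical stand-in for `self.si_snr_loss_zeromean` could look like the following; it returns a per-utterance loss so it can be fed to a permutation-loss helper.

import torch


def si_snr_zeromean_loss(ref, est, eps=1e-8):
    # ref, est: (Batch, samples); returns (Batch,) negative SI-SNR
    ref = ref - ref.mean(dim=-1, keepdim=True)
    est = est - est.mean(dim=-1, keepdim=True)
    # project the estimate onto the reference (scale-invariant target)
    scale = (est * ref).sum(dim=-1, keepdim=True) / (
        (ref ** 2).sum(dim=-1, keepdim=True) + eps)
    s_target = scale * ref
    e_noise = est - s_target
    si_snr = 10 * torch.log10(
        ((s_target ** 2).sum(dim=-1) + eps)
        / ((e_noise ** 2).sum(dim=-1) + eps))
    return -si_snr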
Example #26
                High_origin, Low_origin = Decomposition(batch_x, 0.80)
                High_noise, Low_noise = Decomposition(batch_y, 0.80)

            # batch_y = batch_y.unsqueeze(1)
            # batch_x = batch_x.unsqueeze(1)
            #
            # batch_y = torch.cat([High_noise.unsqueeze(1), Low_noise.unsqueeze(1)], dim=1)
            # batch_x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1)

            output_hf = model_hf(High_noise.cuda())
            output_lf = model_lf(Low_noise.cuda())

            loss_hf = criterion(output_hf.cpu(), High_origin.unsqueeze(1))
            loss_lf = criterion(output_lf.cpu(), Low_origin.unsqueeze(1))

            output = ComplexTensor(output_hf[:, 0] + output_lf[:, 0],
                                   output_hf[:, 1] + output_lf[:, 1]).abs()

            output_hf = ComplexTensor(output_hf[:, 0], output_hf[:, 1]).abs()
            output_lf = ComplexTensor(output_lf[:, 0], output_lf[:, 1]).abs()

            High_noise = ComplexTensor(
                High_noise[:, 0], High_noise[:, 1]).abs()
            High_origin = ComplexTensor(
                High_origin[:, 0], High_origin[:, 1]).abs()

            Low_noise = ComplexTensor(Low_noise[:, 0], Low_noise[:, 1]).abs()
            Low_origin = ComplexTensor(
                Low_origin[:, 0], Low_origin[:, 1]).abs()

            fig = plt.figure()
            gs = GridSpec(nrows=2, ncols=4)
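The high/low-band recombination in this fragment just sums the real and imaginary parts of the two band outputs before taking the magnitude, which is the same as adding the two ComplexTensors directly. The check below uses random data of an arbitrary shape.

import torch
from torch_complex.tensor import ComplexTensor

hf = ComplexTensor(torch.randn(4, 257, 100), torch.randn(4, 257, 100))
lf = ComplexTensor(torch.randn(4, 257, 100), torch.randn(4, 257, 100))
combined = ComplexTensor(hf.real + lf.real, hf.imag + lf.imag).abs()
assert torch.allclose(combined, (hf + lf).abs())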
Example #27
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F), double precision
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): (B, T, C, F)
        """
        def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.float(), ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device)
                u[..., self.ref_channel].fill_(1)

            if beamformer_type in ("mpdr", "mvdr"):
                ws = get_mvdr_vector(psd_speech, psd_n, u.double())
                enhanced = apply_beamforming_vector(ws, data)
            elif beamformer_type == "wpd":
                ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
                enhanced = perform_WPD_filtering(ws, data, self.bdelay,
                                                 self.btaps)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    beamformer_type))

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data.float(), ilens)
        assert self.nmask == len(masks)
        # floor masks with self.eps to increase numerical stability
        masks = [torch.clamp(m, min=self.eps) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            psd_speech = get_power_spectral_density_matrix(
                data, mask_speech.double())
            if self.beamformer_type == "mvdr":
                # psd of noise
                psd_n = get_power_spectral_density_matrix(
                    data, mask_noise.double())
            elif self.beamformer_type == "mpdr":
                # psd of observed signal
                psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
            elif self.beamformer_type == "wpd":
                # Calculate power: (..., C, T)
                power_speech = (data.real**2 +
                                data.imag**2) * mask_speech.double()
                # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
                power_speech = power_speech.mean(dim=-2)
                inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
                # covariance of expanded observed speech
                psd_n = get_covariances(data,
                                        inverse_power,
                                        self.bdelay,
                                        self.btaps,
                                        get_vector=False)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_n,
                                             self.beamformer_type)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask)
                for mask in mask_speech
            ]
            if self.beamformer_type == "mvdr":
                # psd of noise
                if mask_noise is not None:
                    psd_n = get_power_spectral_density_matrix(data, mask_noise)
            elif self.beamformer_type == "mpdr":
                # psd of observed speech
                psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
            elif self.beamformer_type == "wpd":
                # Calculate power: (..., C, T)
                power = data.real**2 + data.imag**2
                power_speeches = [power * mask for mask in mask_speech]
                # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
                power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
                inverse_poweres = [
                    1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
                ]
                # covariance of expanded observed speech
                psd_n = [
                    get_covariances(data,
                                    inv_ps,
                                    self.bdelay,
                                    self.btaps,
                                    get_vector=False)
                    for inv_ps in inverse_poweres
                ]
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            enhanced = []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type == "mvdr":
                    psd_noise = sum(psd_speeches)
                    if mask_noise is not None:
                        psd_noise = psd_noise + psd_n

                    enh, w = apply_beamforming(data, ilens, psd_speech,
                                               psd_noise, self.beamformer_type)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(data, ilens, psd_speech, psd_n,
                                               self.beamformer_type)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(data, ilens, psd_speech,
                                               psd_n[i], self.beamformer_type)
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(
                            self.beamformer_type))
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
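`get_power_spectral_density_matrix`, used throughout this forward pass, accumulates a mask-weighted spatial covariance per frequency bin. The function below is only a rough sketch of that computation (the mask averaging and normalization details may differ from the actual ESPnet implementation).

import torch
from torch_complex.tensor import ComplexTensor
import torch_complex.functional as FC


def psd_matrix_sketch(data: ComplexTensor, mask: torch.Tensor) -> ComplexTensor:
    # data: (B, F, C, T) observation, mask: (B, F, C, T) with values in [0, 1]
    # returns: (B, F, C, C) mask-weighted spatial covariance
    mask = mask.mean(dim=-2)                                  # (B, F, T)
    psd = FC.einsum("...ct,...et->...ce",
                    [data * mask[..., None, :], data.conj()])
    norm = mask.sum(dim=-1).clamp(min=1e-8)[..., None, None]  # (B, F, 1, 1)
    return psd / norm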
Example #28
def composition(high, low, end=0.2, start=0.0):
    output = ComplexTensor(high[:, 1] + low[:, 1], low[:, 0] + low[:, 2]).abs()
    return output
Example #29
def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag;
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
        # build the ComplexTensor from its real and imaginary parts
        return ComplexTensor(x["real"], x["imag"])

    # If torch.Tensor, as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
                 "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
                 "but got {}".format(type(x)))
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)
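A few quick calls illustrate the three accepted input types (assuming the `to_torch_tensor` above is importable; expected outputs are shown as comments).

import numpy as np
import torch

x_complex = (np.arange(3) + 1j).astype(np.complex64)
spec = to_torch_tensor(x_complex)       # -> ComplexTensor
print(spec.real)                        # tensor([0., 1., 2.])

x_dict = {"real": torch.zeros(2), "imag": torch.ones(2)}
print(to_torch_tensor(x_dict).imag)     # tensor([1., 1.])

x_real = np.zeros(4, dtype=np.float32)
print(to_torch_tensor(x_real))          # tensor([0., 0., 0., 0.])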
Example #30
    def forward(
        self,
        data: ComplexTensor,
        ilens: torch.LongTensor,
        powers: Union[List[torch.Tensor], None] = None,
    ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
        """DNN_Beamformer forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): (B, T, C, F)
        """
        def apply_beamforming(data,
                              ilens,
                              psd_n,
                              psd_speech,
                              psd_distortion=None):
            """Beamforming with the provided statistics.

            Args:
                data (ComplexTensor): (B, F, C, T)
                ilens (torch.Tensor): (B,)
                psd_n (ComplexTensor):
                    Noise covariance matrix for MVDR (B, F, C, C)
                    Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                    Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C)
                psd_distortion (ComplexTensor): Noise covariance matrix (B, F, C, C)
            Return:
                enhanced (ComplexTensor): (B, F, T)
                ws (ComplexTensor): (B, F, C) or (B, F, (btaps + 1) * C)
            """
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
                u = u.double()
            else:
                if self.beamformer_type.endswith("_souden"):
                    # (optional) Create onehot vector for fixed reference microphone
                    u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                    device=data.device,
                                    dtype=torch.double)
                    u[..., self.ref_channel].fill_(1)
                else:
                    # for simplifying computation in RTF-based beamforming
                    u = self.ref_channel

            if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
                ws = get_mvdr_vector_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type in ("mpdr_souden", "mvdr_souden",
                                          "wmpdr_souden"):
                ws = get_mvdr_vector(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type == "wpd":
                ws = get_WPD_filter_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(ws, data.double(),
                                                 self.bdelay, self.btaps)
            elif self.beamformer_type == "wpd_souden":
                ws = get_WPD_filter_v2(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(ws, data.double(),
                                                 self.bdelay, self.btaps)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        data_d = data.double()

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)
        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            if self.beamformer_type.startswith(
                    "wmpdr") or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real**2 + data_d.imag**2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = (power_input * mask_speech.double()).mean(dim=-2)
                else:
                    assert len(powers) == 1, len(powers)
                    powers = powers[0]
                inverse_power = 1 / torch.clamp(powers, min=self.eps)

            psd_speech = get_power_spectral_density_matrix(
                data_d, mask_speech.double())
            if mask_noise is not None and (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double())
            if self.beamformer_type == "mvdr":
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_noise,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "mvdr_souden":
                enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                                 psd_speech)
            elif self.beamformer_type == "mpdr":
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "mpdr_souden":
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                                 psd_speech)
            elif self.beamformer_type == "wmpdr":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :],
                     data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "wmpdr_souden":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :],
                     data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                                 psd_speech)
            elif self.beamformer_type == "wpd":
                psd_observed_bar = get_covariances(data_d,
                                                   inverse_power,
                                                   self.bdelay,
                                                   self.btaps,
                                                   get_vector=False)
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed_bar,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "wpd_souden":
                psd_observed_bar = get_covariances(data_d,
                                                   inverse_power,
                                                   self.bdelay,
                                                   self.btaps,
                                                   get_vector=False)
                enhanced, ws = apply_beamforming(data, ilens, psd_observed_bar,
                                                 psd_speech)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            if self.beamformer_type.startswith(
                    "wmpdr") or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real**2 + data_d.imag**2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = [(power_input * m.double()).mean(dim=-2)
                              for m in mask_speech]
                else:
                    assert len(powers) == self.num_spk, len(powers)
                inverse_power = [
                    1 / torch.clamp(p, min=self.eps) for p in powers
                ]

            psd_speeches = [
                get_power_spectral_density_matrix(data_d, mask.double())
                for mask in mask_speech
            ]
            if mask_noise is not None and (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double())
            if self.beamformer_type in ("mpdr", "mpdr_souden"):
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
            elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
                psd_observed = [
                    FC.einsum(
                        "...ct,...et->...ce",
                        [data_d * inv_p[..., None, :],
                         data_d.conj()],
                    ) for inv_p in inverse_power
                ]
            elif self.beamformer_type in ("wpd", "wpd_souden"):
                psd_observed_bar = [
                    get_covariances(data_d,
                                    inv_p,
                                    self.bdelay,
                                    self.btaps,
                                    get_vector=False)
                    for inv_p in inverse_power
                ]

            enhanced, ws = [], []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                if (self.beamformer_type == "mvdr_souden"
                        or not self.beamformer_type.endswith("_souden")):
                    psd_noise_i = (psd_noise + sum(psd_speeches) if mask_noise
                                   is not None else sum(psd_speeches))
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type == "mvdr":
                    enh, w = apply_beamforming(data,
                                               ilens,
                                               psd_noise_i,
                                               psd_speech,
                                               psd_distortion=psd_noise_i)
                elif self.beamformer_type == "mvdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                               psd_speech)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed,
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "mpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed,
                                               psd_speech)
                elif self.beamformer_type == "wmpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wmpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed[i],
                                               psd_speech)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed_bar[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wpd_souden":
                    enh, w = apply_beamforming(data, ilens,
                                               psd_observed_bar[i], psd_speech)
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(
                            self.beamformer_type))
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)
                ws.append(w)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
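For reference, the Souden-style MVDR solution that `get_mvdr_vector` computes is w = (Phi_n^{-1} Phi_s) u / trace(Phi_n^{-1} Phi_s). The sketch below writes it with native complex torch tensors rather than ComplexTensor, purely for brevity; it is not the ESPnet implementation (no diagonal loading, no solver fallback).

import torch


def mvdr_vector_souden_sketch(psd_s: torch.Tensor,
                              psd_n: torch.Tensor,
                              u: torch.Tensor,
                              eps: float = 1e-15) -> torch.Tensor:
    # psd_s, psd_n: (B, F, C, C) complex covariances; u: (B, C) one-hot reference
    num = torch.linalg.solve(psd_n, psd_s)              # Phi_n^{-1} Phi_s
    trace = num.diagonal(dim1=-2, dim2=-1).sum(-1)      # (B, F)
    ws = num / (trace[..., None, None] + eps)
    # pick the reference-channel column: (B, F, C, C) x (B, C) -> (B, F, C)
    return torch.einsum("bfce,be->bfc", ws, u.to(ws.dtype))


# applying the filter afterwards: (B, F, C) x (B, F, C, T) -> (B, F, T)
# enhanced = torch.einsum("bfc,bfct->bft", ws.conj(), data)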