Code example #1
# Assumed imports (added; not shown in the original snippet). signal_framing
# is an espnet-internal helper assumed to be importable from the same module.
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor


def perform_WPD_filtering(filter_matrix: ComplexTensor, Y: ComplexTensor,
                          bdelay: int, btaps: int) -> ComplexTensor:
    """Perform WPD filtering.

    Args:
        filter_matrix: Filter matrix (B, F, (btaps + 1) * C)
        Y: Complex STFT signal with shape (B, F, C, T)
        bdelay (int): Prediction delay in frames
        btaps (int): Number of filter taps

    Returns:
        enhanced (ComplexTensor): (B, F, T)
    """
    # (B, F, C, T) --> (B, F, C, T, btaps + 1)
    Ytilde = signal_framing(Y,
                            btaps + 1,
                            1,
                            bdelay,
                            do_padding=True,
                            pad_value=0)
    Ytilde = FC.reverse(Ytilde, dim=-1)

    Bs, Fdim, C, T = Y.shape
    # --> (B, F, T, btaps + 1, C) --> (B, F, T, (btaps + 1) * C)
    Ytilde = Ytilde.permute(0, 1, 3, 4, 2).contiguous().view(Bs, Fdim, T, -1)
    # (B, F, T)
    enhanced = FC.einsum("...tc,...c->...t", [Ytilde, filter_matrix.conj()])
    return enhanced
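
A minimal usage sketch (added here, not part of the source listing): random tensors with the docstring's shapes, assuming perform_WPD_filtering and its espnet helpers are importable. All names below are illustrative.

import torch
from torch_complex.tensor import ComplexTensor

B, F, C, T, btaps, bdelay = 2, 257, 4, 100, 5, 3
Y = ComplexTensor(torch.rand(B, F, C, T), torch.rand(B, F, C, T))
filter_matrix = ComplexTensor(torch.rand(B, F, (btaps + 1) * C),
                              torch.rand(B, F, (btaps + 1) * C))
enhanced = perform_WPD_filtering(filter_matrix, Y, bdelay, btaps)
assert enhanced.shape == (B, F, T)  # one enhanced STFT sequence per (B, F)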
Code example #2
File: beamformer.py  Project: zqs01/espnet
# Assumed imports (added; not shown in the original snippet):
import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor


def get_power_spectral_density_matrix(xs: ComplexTensor,
                                      mask: torch.Tensor,
                                      normalization=True,
                                      eps: float = 1e-15) -> ComplexTensor:
    """Return cross-channel power spectral density (PSD) matrix

    Args:
        xs (ComplexTensor): (..., F, C, T)
        mask (torch.Tensor): (..., F, C, T)
        normalization (bool): If True, normalize the mask along the time axis
        eps (float): Small constant to avoid division by zero
    Returns:
        psd (ComplexTensor): (..., F, C, C)

    """
    # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C_1, C_2)
    psd_Y = FC.einsum('...ct,...et->...tce', [xs, xs.conj()])

    # Averaging mask along C: (..., C, T) -> (..., T)
    mask = mask.mean(dim=-2)

    # Normalize the mask along T: (..., T)
    if normalization:
        # Assuming the tensor is zero-padded, the summation along the time
        # axis is the same regardless of the padding length.
        mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)

    # psd: (..., T, C, C)
    psd = psd_Y * mask[..., None, None]
    # (..., T, C, C) -> (..., C, C)
    psd = psd.sum(dim=-3)

    return psd
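
The function implements the masked PSD estimate Phi = sum_t m(t) y(t) y(t)^H, with m optionally normalized so that sum_t m(t) = 1. A hedged usage sketch (added; not from the source):

import torch
from torch_complex.tensor import ComplexTensor

B, F, C, T = 2, 257, 4, 100
xs = ComplexTensor(torch.rand(B, F, C, T), torch.rand(B, F, C, T))
mask = torch.rand(B, F, C, T)  # e.g. a DNN-estimated speech mask in [0, 1]
psd = get_power_spectral_density_matrix(xs, mask)
assert psd.shape == (B, F, C, C)  # one C x C matrix per batch and frequency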
Code example #3
# Assumed imports (added; not shown in the original snippet). inv is the
# espnet-internal complex matrix inverse under test.
import numpy as np
import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor


def test_inv(ch):
    torch.manual_seed(100)
    X = ComplexTensor(torch.rand(2, 3, ch, ch), torch.rand(2, 3, ch, ch))
    # make X Hermitian (X + X^H) before comparing against numpy's inverse
    X = X + X.conj().transpose(-1, -2)
    assert FC.allclose(ComplexTensor(np.linalg.inv(X.numpy())),
                       inv(X),
                       atol=1e-4)
Code example #4
File: test_enh_layers.py  Project: jumon/espnet-1
# This test relies on espnet-internal names (Stft, get_rtf, random_speech,
# is_torch_1_1_plus) imported in the surrounding test module.
def test_get_rtf(ch):
    stft = Stft(
        n_fft=8,
        win_length=None,
        hop_length=2,
        center=True,
        window="hann",
        normalized=False,
        onesided=True,
    )
    torch.random.manual_seed(0)
    x = random_speech[..., :ch]
    n = torch.rand(2, 16, ch, dtype=torch.double)
    ilens = torch.LongTensor([16, 12])
    # (B, T, C, F) -> (B, F, C, T)
    X = ComplexTensor(*torch.unbind(stft(x, ilens)[0], dim=-1)).transpose(-1, -3)
    N = ComplexTensor(*torch.unbind(stft(n, ilens)[0], dim=-1)).transpose(-1, -3)
    # (B, F, C, C)
    Phi_X = FC.einsum("...ct,...et->...ce", [X, X.conj()])
    Phi_N = FC.einsum("...ct,...et->...ce", [N, N.conj()])
    # (B, F, C, 1)
    rtf = get_rtf(Phi_X, Phi_N, reference_vector=0, iterations=20)
    if is_torch_1_1_plus:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True).values + 1e-15)
    else:
        rtf = rtf / (rtf.abs().max(dim=-2, keepdim=True)[0] + 1e-15)
    # rtf \approx Phi_N MaxEigVec(Phi_N^-1 @ Phi_X)
    if is_torch_1_1_plus:
        # torch.solve is required; it is only available in PyTorch 1.1.0+
        mat = FC.solve(Phi_X, Phi_N)[0]
        max_eigenvec = FC.solve(rtf, Phi_N)[0]
    else:
        mat = FC.matmul(Phi_N.inverse2(), Phi_X)
        max_eigenvec = FC.matmul(Phi_N.inverse2(), rtf)
    factor = FC.matmul(mat, max_eigenvec)
    assert FC.allclose(
        FC.matmul(max_eigenvec, factor.transpose(-1, -2)),
        FC.matmul(factor, max_eigenvec.transpose(-1, -2)),
    )
Code example #5
File: dnn_beamformer.py  Project: yistLin/espnet
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F), double precision
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): (B, T, C, F)
        """
        def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.float(), ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device)
                u[..., self.ref_channel].fill_(1)

            if beamformer_type in ("mpdr", "mvdr"):
                ws = get_mvdr_vector(psd_speech, psd_n, u.double())
                enhanced = apply_beamforming_vector(ws, data)
            elif beamformer_type == "wpd":
                ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
                enhanced = perform_WPD_filtering(ws, data, self.bdelay,
                                                 self.btaps)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    beamformer_type))

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data.float(), ilens)
        assert self.nmask == len(masks)
        # floor masks with self.eps to increase numerical stability
        masks = [torch.clamp(m, min=self.eps) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            psd_speech = get_power_spectral_density_matrix(
                data, mask_speech.double())
            if self.beamformer_type == "mvdr":
                # psd of noise
                psd_n = get_power_spectral_density_matrix(
                    data, mask_noise.double())
            elif self.beamformer_type == "mpdr":
                # psd of observed signal
                psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
            elif self.beamformer_type == "wpd":
                # Calculate power: (..., C, T)
                power_speech = (data.real**2 +
                                data.imag**2) * mask_speech.double()
                # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
                power_speech = power_speech.mean(dim=-2)
                inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
                # covariance of expanded observed speech
                psd_n = get_covariances(data,
                                        inverse_power,
                                        self.bdelay,
                                        self.btaps,
                                        get_vector=False)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_n,
                                             self.beamformer_type)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            # cast masks to double to match `data`, as in the single-speaker
            # branch above
            psd_speeches = [
                get_power_spectral_density_matrix(data, mask.double())
                for mask in mask_speech
            ]
            if self.beamformer_type == "mvdr":
                # psd of noise
                if mask_noise is not None:
                    psd_n = get_power_spectral_density_matrix(
                        data, mask_noise.double())
            elif self.beamformer_type == "mpdr":
                # psd of the observed signal
                psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
            elif self.beamformer_type == "wpd":
                # Calculate power: (..., C, T)
                power = data.real**2 + data.imag**2
                power_speeches = [power * mask for mask in mask_speech]
                # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
                power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
                inverse_powers = [
                    1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
                ]
                # covariance of expanded observed speech
                psd_n = [
                    get_covariances(data,
                                    inv_ps,
                                    self.bdelay,
                                    self.btaps,
                                    get_vector=False)
                    for inv_ps in inverse_powers
                ]
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            enhanced = []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type == "mvdr":
                    psd_noise = sum(psd_speeches)
                    if mask_noise is not None:
                        psd_noise = psd_noise + psd_n

                    enh, w = apply_beamforming(data, ilens, psd_speech,
                                               psd_noise, self.beamformer_type)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(data, ilens, psd_speech, psd_n,
                                               self.beamformer_type)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(data, ilens, psd_speech,
                                               psd_n[i], self.beamformer_type)
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(
                            self.beamformer_type))
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
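
A hypothetical call sketch (added; `beamformer` stands for an already-configured DNN_Beamformer instance, which the excerpt does not construct):

import torch
from torch_complex.tensor import ComplexTensor

B, T, C, F = 2, 100, 4, 257
data = ComplexTensor(torch.rand(B, T, C, F, dtype=torch.double),
                     torch.rand(B, T, C, F, dtype=torch.double))
ilens = torch.LongTensor([100, 98])
enhanced, ilens, masks = beamformer(data, ilens)  # enhanced: (B, T, F)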
Code example #6
File: beamformer.py  Project: zqs01/espnet
# Assumed imports (added; not shown in the original snippet):
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor


def apply_beamforming_vector(beamform_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    # (..., C) x (..., C, T) -> (..., T)
    es = FC.einsum('...c,...ct->...t', [beamform_vector.conj(), mix])
    return es
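
Combined with code example #2 and a weight estimator such as get_mvdr_vector (whose call signature appears in code example #5), a minimal MVDR pipeline could look like this sketch (added; shapes and names are assumptions):

import torch

# mix: ComplexTensor (B, F, C, T); mask_speech, mask_noise: (B, F, C, T)
psd_speech = get_power_spectral_density_matrix(mix, mask_speech)  # (B, F, C, C)
psd_noise = get_power_spectral_density_matrix(mix, mask_noise)    # (B, F, C, C)
B, C = mix.shape[0], mix.shape[-2]
u = torch.zeros(B, C)
u[:, 0] = 1  # one-hot vector selecting reference microphone 0
ws = get_mvdr_vector(psd_speech, psd_noise, u)  # (B, F, C)
enhanced = apply_beamforming_vector(ws, mix)    # (B, F, T)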
Code example #7
# Assumed imports (added; not shown in the original snippet):
import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor


def online_wpe_step(input_buffer: ComplexTensor,
                    power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99,
                    taps: int = 10,
                    delay: int = 3):
    """One step of online dereverberation.

    Args:
        input_buffer: (F, C, taps + delay + 1)
        power: Estimate for the current PSD (F, T)
        inv_cov: Current estimate of R^-1
        filter_taps: Current estimate of filter taps (F, taps * C, taps)
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        Dereverberated frame of shape (F, C)
        Updated estimate of R^-1
        Updated estimate of the filter taps


    >>> frame_length = 512
    >>> frame_shift = 128
    >>> taps = 6
    >>> delay = 3
    >>> alpha = 0.999
    >>> frequency_bins = frame_length // 2 + 1
    >>> Q = None
    >>> G = None
    >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G,
    ...                                    alpha=alpha, taps=taps, delay=delay)

    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    C = input_buffer.size(-2)

    if inv_cov is None:
        inv_cov = ComplexTensor(
            torch.eye(C * taps, dtype=input_buffer.dtype).expand(
                *input_buffer.size()[:-2], C * taps, C * taps))
    if filter_taps is None:
        filter_taps = ComplexTensor(
            torch.zeros(*input_buffer.size()[:-2],
                        C * taps,
                        C,
                        dtype=input_buffer.dtype))

    window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1)
    # (..., C, T) -> (..., C * T)
    window = window.view(*input_buffer.size()[:-2], -1)
    pred = input_buffer[..., -1] - FC.einsum('...id,...i->...d',
                                             (filter_taps.conj(), window))

    numerator = FC.einsum('...ij,...j->...i', (inv_cov, window))
    denominator = \
        FC.einsum('...i,...i->...', (window.conj(), numerator)) + alpha * power
    kalman_gain = numerator / denominator[..., None]

    inv_cov_k = inv_cov - FC.einsum('...j,...jm,...i->...im',
                                    (window.conj(), inv_cov, kalman_gain))
    inv_cov_k /= alpha

    filter_taps_k = \
        filter_taps + FC.einsum('...i,...m->...im', (kalman_gain, pred.conj()))
    return pred, inv_cov_k, filter_taps_k
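
The doctest above leaves the framing loop implicit; a hedged sketch of the intended streaming use (added; buffers_and_powers is a hypothetical iterator yielding STFT buffers of shape (F, C, taps + delay + 1) and per-frame PSD estimates):

Q, G = None, None  # R^-1 and filter taps; initialized inside the first step
for buffer, power in buffers_and_powers:  # hypothetical input iterator
    frame, Q, G = online_wpe_step(buffer, power, inv_cov=Q, filter_taps=G,
                                  alpha=0.99, taps=10, delay=3)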
Code example #8
# Assumed imports (added; not shown in the original snippet). build_y_tilde
# and _solve are helpers from the surrounding module.
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor


def wpe_step_v3(Y,
                inverse_power,
                taps=10,
                delay=3,
                statistics_mode='full',
                solver='torch_complex.solve'):
    """

    Tested with PyTorch 1.7.0.dev20200807.

    Properties (compared to lower PyTorch versions):
      - faster
      - less memory for the backward pass
      - possibly less peak memory (hard to profile reliably)


    Args:
        Y: (..., channel, frames)
        inverse_power:
        taps:
        delay:
        statistics_mode:
        solver:

    Returns:

    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        # 'valid' mode is not implemented; the intended slicing is kept below
        # for reference (unreachable after the raise).
        raise NotImplementedError(statistics_mode)
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    Y_tilde = build_y_tilde(Y, taps, delay)

    # Torch does not keep the non-contiguous property of a tensor under
    # negation (i.e. ComplexTensor.conj changes the sign of imag), so the
    # conjugate of Y_tilde is built directly from Y_conj below.
    Y_conj = Y.conj()
    Y_tilde_conj = build_y_tilde(Y_conj, taps, delay)

    # Y_tilde_conj = Y_tilde.conj()

    # This code is faster, but with the backward graph the memory consumption
    # is too high (PyTorch is at the moment not smart enough to avoid it).
    # Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    # R = Y_tilde_inverse_power[s] @ transpose(Y_tilde_conj[s])
    # P = Y_tilde_inverse_power[s] @ transpose(Y_conj[s])

    def get_correlation(m, Y1, Y2):
        real = torch.einsum('...t,...dt,...et->...de', m,
                            Y1.real, Y2.real) - torch.einsum(
                                '...t,...dt,...et->...de', m, Y1.imag, Y2.imag)

        imag = torch.einsum('...t,...dt,...et->...de', m,
                            Y1.real, Y2.imag) + torch.einsum(
                                '...t,...dt,...et->...de', m, Y1.imag, Y2.real)
        return ComplexTensor(real, imag)

    # R_conj = torch_complex.functional.einsum(
    #     '...t,...dt,...et->...de', inverse_power, Y_tilde_conj, Y_tilde)
    R_conj = get_correlation(inverse_power, Y_tilde_conj, Y_tilde)

    # # print('wpe rss before P', ByteSize(process.memory_info().rss))
    # P_conj = torch_complex.functional.einsum(
    #     '...t,...dt,...et->...de',
    #     inverse_power, Y_tilde_conj, Y
    # )
    P_conj = get_correlation(inverse_power, Y_tilde_conj, Y)

    G_conj = _solve(R=R_conj, P=P_conj, solver=solver)

    # Matmul converts the non-contiguous Y_tilde to contiguous, hence use einsum.
    # Complex einsum does not work on the GPU with non-contiguous inputs, so the
    # computation is expanded into real-valued einsums below.
    # X = Y - torch_complex.functional.einsum('...ij,...ik->...jk', G_conj, Y_tilde)
    X = ComplexTensor(
        Y.real -
        torch.einsum('...ij,...ik->...jk', G_conj.real, Y_tilde.real) +
        torch.einsum('...ij,...ik->...jk', G_conj.imag, Y_tilde.imag),
        Y.imag -
        torch.einsum('...ij,...ik->...jk', G_conj.real, Y_tilde.imag) -
        torch.einsum('...ij,...ik->...jk', G_conj.imag, Y_tilde.real),
    )

    return X
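
A hedged sketch (added) of the usual iterative WPE driver around this step: re-estimate the power from the current dereverberated estimate and repeat a few times.

import torch
from torch_complex.tensor import ComplexTensor

F, C, T = 257, 4, 200
Y = ComplexTensor(torch.rand(F, C, T), torch.rand(F, C, T))
X = Y
for _ in range(3):  # a few fixed-point iterations
    power = (X.real ** 2 + X.imag ** 2).mean(dim=-2)   # (F, T)
    inverse_power = 1 / torch.clamp(power, min=1e-10)
    X = wpe_step_v3(Y, inverse_power, taps=10, delay=3)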
Code example #9
# Assumed imports (added; not shown in the original snippet). build_y_tilde
# and _solve are helpers from the surrounding module.
import numpy as np
import torch
import torch_complex
from torch_complex.tensor import ComplexTensor


def wpe_step_v2(Y,
                inverse_power,
                taps=10,
                delay=3,
                statistics_mode='full',
                solver='torch_complex.solve'):
    """

    Args:
        Y: (..., channel, frames)
        inverse_power:
        taps:
        delay:
        statistics_mode:
        solver:

    Returns:

    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        # 'valid' mode is not implemented; the intended slicing is kept below
        # for reference (unreachable after the raise).
        raise NotImplementedError(statistics_mode)
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    Y_tilde = build_y_tilde(Y, taps, delay)

    # Torch does not keep the non-contiguous property of a tensor under
    # negation (i.e. ComplexTensor.conj changes the sign of imag), so the
    # conjugate of Y_tilde is built directly from Y_conj below.
    Y_conj = Y.conj()
    Y_tilde_conj = build_y_tilde(Y_conj, taps, delay)
    # Y_tilde_conj = Y_tilde.conj()

    # This code is faster, but with the backward graph the memory consumption
    # is too high (PyTorch is at the moment not smart enough to avoid it).
    # Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    # R = Y_tilde_inverse_power[s] @ hermite(Y_tilde[s])
    # P = Y_tilde_inverse_power[s] @ hermite(Y[s])

    import torch.utils.checkpoint

    # remove when https://github.com/pytorch/pytorch/issues/42418
    # has a solution.
    # This may be very expensive, because the calculation of R dominates the
    # execution time of WPE.
    def get_R(inverse_power, Y_tilde_real, Y_tilde_imag):
        Y_tilde_real = Y_tilde_real.contiguous()
        Y_tilde_imag = Y_tilde_imag.contiguous()
        Y_tilde = ComplexTensor(Y_tilde_real, Y_tilde_imag)
        Y_tilde_conj = ComplexTensor(Y_tilde_real, -Y_tilde_imag)
        R = torch_complex.functional.einsum('...t,...dt,...et->...de',
                                            inverse_power, Y_tilde,
                                            Y_tilde_conj)
        return R.real, R.imag

    R = ComplexTensor(*torch.utils.checkpoint.checkpoint(
        get_R, inverse_power, Y_tilde.real, Y_tilde.imag))

    # print('wpe rss before P', ByteSize(process.memory_info().rss))
    P = torch_complex.functional.einsum('...t,...dt,...et->...de',
                                        inverse_power, Y_tilde, Y_conj)
    G = _solve(R=R, P=P, solver=solver)

    # remove when https://github.com/pytorch/pytorch/issues/42418
    # has a solution.
    def contiguous_einsum(equation, *operands):
        def foo(*operands):
            assert len(operands) % 2 == 0, len(operands)
            operands = [
                ComplexTensor(real.contiguous(), imag.contiguous())
                for real, imag in zip(operands[::2], operands[1::2])
            ]
            ret = torch_complex.functional.einsum(equation, operands)
            return ret.real, ret.imag

        operands = [part for o in operands for part in [o.real, o.imag]]

        real, imag = torch.utils.checkpoint.checkpoint(foo, *operands)
        return ComplexTensor(real, imag)

    # Matmul cannot handle the non-contiguous Y_tilde, hence use einsum.
    # Einsum does not work on the GPU with non-contiguous inputs, hence use
    # torch.utils.checkpoint.checkpoint (via contiguous_einsum above).
    X = Y - contiguous_einsum('...ij,...ik->...jk', G.conj(), Y_tilde)

    return X