def wpe_step_v1(Y,
                inverse_power,
                taps=10,
                delay=3,
                statistics_mode='full',
                solver='torch_complex.solve'):
    """

    Args:
        Y: (..., channel, frames)
        inverse_power:
        taps:
        delay:
        statistics_mode:
        solver:

    Returns:

    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        # Intended slice, kept as documentation (unreachable until implemented):
        # s = (Ellipsis, slice(delay + taps - 1, None))
        raise NotImplementedError(statistics_mode)
    else:
        raise ValueError(statistics_mode)

    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    Y_tilde = build_y_tilde(Y, taps, delay)
    # Note: .contiguous() is deliberately not called here, so Y_tilde keeps
    # its memory-saving strides from build_y_tilde.

    Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    R = Y_tilde_inverse_power[s] @ hermite(Y_tilde[s])
    P = Y_tilde_inverse_power[s] @ hermite(Y[s])

    G = _solve(R=R, P=P, solver=solver)

    X = Y - hermite(G) @ Y_tilde

    return X
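

# A minimal usage sketch (added for illustration; not part of the original
# code). It assumes `inverse_power` is the floored inverse of the power
# averaged over channels; the flooring constant 1e-10 is an arbitrary choice.
def _example_wpe_step_v1():
    F, D, T = 257, 8, 100  # frequency bins, channels, frames
    Y = (np.random.normal(size=(F, D, T))
         + 1j * np.random.normal(size=(F, D, T)))
    power = np.mean(np.abs(Y) ** 2, axis=-2)  # (F, T)
    inverse_power = torch.tensor(1 / np.maximum(power, 1e-10))
    X = wpe_step_v1(Y, inverse_power, taps=10, delay=3)
    assert X.real.shape == (F, D, T), X.real.shape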


def wpe_step_v3(Y,
                inverse_power,
                taps=10,
                delay=3,
                statistics_mode='full',
                solver='torch_complex.solve'):
    """

    Tested with 1.7.0.dev20200807

    Properties (Compared to lower versions):
      - faster
      - less memory for backward
      - (less peak memory)? Looks so. Difficult to profile.


    Args:
        Y: (..., channel, frames)
        inverse_power:
        taps:
        delay:
        statistics_mode:
        solver:

    Returns:

    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        # Intended slice, kept as documentation (unreachable until implemented):
        # s = (Ellipsis, slice(delay + taps - 1, None))
        raise NotImplementedError(statistics_mode)
    else:
        raise ValueError(statistics_mode)

    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    Y_tilde = build_y_tilde(Y, taps, delay)

    # Torch does not keep the non-contiguous (strided) property of a tensor
    # under negation (i.e. ComplexTensor.conj changes the sign of imag and
    # thereby creates a new tensor). Therefore Y_tilde_conj is built from the
    # conjugated signal instead of conjugating Y_tilde:
    #     Y_tilde_conj = Y_tilde.conj()  # would lose the strided property
    Y_conj = Y.conj()
    Y_tilde_conj = build_y_tilde(Y_conj, taps, delay)

    # The matmul code below is faster, but with a backward graph the memory
    # consumption is too high (PyTorch is at the moment not smart enough to
    # avoid this):
    # Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    # R = Y_tilde_inverse_power[s] @ transpose(Y_tilde_conj[s])
    # P = Y_tilde_inverse_power[s] @ transpose(Y_conj[s])

    def get_correlation(m, Y1, Y2):
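        # Weighted correlation sum_t m[t] * Y1[..., t] Y2[..., t]^T, expanded
        # into real arithmetic via (a + ib)(c + id) = (ac - bd) + i(ad + bc).
        # Note: no conjugation happens here; the caller passes Y_tilde_conj.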
        real = torch.einsum('...t,...dt,...et->...de', m,
                            Y1.real, Y2.real) - torch.einsum(
                                '...t,...dt,...et->...de', m, Y1.imag, Y2.imag)

        imag = torch.einsum('...t,...dt,...et->...de', m,
                            Y1.real, Y2.imag) + torch.einsum(
                                '...t,...dt,...et->...de', m, Y1.imag, Y2.real)
        return ComplexTensor(real, imag)

    # R_conj = torch_complex.functional.einsum(
    #     '...t,...dt,...et->...de', inverse_power, Y_tilde_conj, Y_tilde)
    R_conj = get_correlation(inverse_power, Y_tilde_conj, Y_tilde)

    # P_conj = torch_complex.functional.einsum(
    #     '...t,...dt,...et->...de',
    #     inverse_power, Y_tilde_conj, Y
    # )
    P_conj = get_correlation(inverse_power, Y_tilde_conj, Y)

    G_conj = _solve(R=R_conj, P=P_conj, solver=solver)

    # Matmul converts the non-contiguous Y_tilde to a contiguous copy, hence
    # use einsum. Einsum does not work on the GPU with non-contiguous inputs,
    # hence v2 uses torch.utils.checkpoint.checkpoint; here the complex einsum
    # is expanded into real-valued einsums instead:
    # X = Y - torch_complex.functional.einsum('...ij,...ik->...jk', G_conj, Y_tilde)
    X = ComplexTensor(
        Y.real -
        torch.einsum('...ij,...ik->...jk', G_conj.real, Y_tilde.real) +
        torch.einsum('...ij,...ik->...jk', G_conj.imag, Y_tilde.imag),
        Y.imag -
        torch.einsum('...ij,...ik->...jk', G_conj.real, Y_tilde.imag) -
        torch.einsum('...ij,...ik->...jk', G_conj.imag, Y_tilde.real),
    )

    return X


def wpe_step_v2(Y,
                inverse_power,
                taps=10,
                delay=3,
                statistics_mode='full',
                solver='torch_complex.solve'):
    """

    Args:
        Y: (..., channel, frames)
        inverse_power:
        taps:
        delay:
        statistics_mode:
        solver:

    Returns:

    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        # Intended slice, kept as documentation (unreachable until implemented):
        # s = (Ellipsis, slice(delay + taps - 1, None))
        raise NotImplementedError(statistics_mode)
    else:
        raise ValueError(statistics_mode)

    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    Y_tilde = build_y_tilde(Y, taps, delay)

    # Torch does not keep the non-contiguous (strided) property of a tensor
    # under negation (i.e. ComplexTensor.conj changes the sign of imag and
    # thereby creates a new tensor). Therefore Y_tilde_conj is built from the
    # conjugated signal instead of conjugating Y_tilde:
    #     Y_tilde_conj = Y_tilde.conj()  # would lose the strided property
    Y_conj = Y.conj()
    Y_tilde_conj = build_y_tilde(Y_conj, taps, delay)

    # The matmul code below is faster, but with a backward graph the memory
    # consumption is too high (PyTorch is at the moment not smart enough to
    # avoid this):
    # Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    # R = Y_tilde_inverse_power[s] @ hermite(Y_tilde[s])
    # P = Y_tilde_inverse_power[s] @ hermite(Y[s])

    import torch.utils.checkpoint

    # Remove this workaround once https://github.com/pytorch/pytorch/issues/42418
    # has a solution. Checkpointing here may be very expensive, because the
    # calculation of R dominates the execution time of WPE.
    def get_R(inverse_power, Y_tilde_real, Y_tilde_imag):
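        # Checkpointing recomputes this block during backward instead of
        # storing the intermediates; real and imag are passed separately
        # because checkpoint only handles plain tensors.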
        Y_tilde_real = Y_tilde_real.contiguous()
        Y_tilde_imag = Y_tilde_imag.contiguous()
        Y_tilde = ComplexTensor(Y_tilde_real, Y_tilde_imag)
        Y_tilde_conj = ComplexTensor(Y_tilde_real, -Y_tilde_imag)
        R = torch_complex.functional.einsum('...t,...dt,...et->...de',
                                            inverse_power, Y_tilde,
                                            Y_tilde_conj)
        return R.real, R.imag

    R = ComplexTensor(*torch.utils.checkpoint.checkpoint(
        get_R, inverse_power, Y_tilde.real, Y_tilde.imag))

    P = torch_complex.functional.einsum('...t,...dt,...et->...de',
                                        inverse_power, Y_tilde, Y_conj)
    G = _solve(R=R, P=P, solver=solver)

    # Remove this workaround once https://github.com/pytorch/pytorch/issues/42418
    # has a solution.
    def contiguous_einsum(equation, *operands):
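        # torch.utils.checkpoint.checkpoint only passes plain tensors through,
        # so each ComplexTensor operand is flattened into a (real, imag) pair
        # and reassembled inside the checkpointed function.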
        def foo(*operands):
            assert len(operands) % 2 == 0, len(operands)
            operands = [
                ComplexTensor(real.contiguous(), imag.contiguous())
                for real, imag in zip(operands[::2], operands[1::2])
            ]
            ret = torch_complex.functional.einsum(equation, operands)
            return ret.real, ret.imag

        operands = [part for o in operands for part in [o.real, o.imag]]

        real, imag = torch.utils.checkpoint.checkpoint(foo, *operands)
        return ComplexTensor(real, imag)

    # Matmul cannot handle the non-contiguous Y_tilde, hence use einsum.
    # Einsum does not work on the GPU with non-contiguous inputs, hence use
    # torch.utils.checkpoint.checkpoint (see contiguous_einsum above).
    X = Y - contiguous_einsum('...ij,...ik->...jk', G.conj(), Y_tilde)

    return X
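

# A hedged consistency check (added; not from the original file): the three
# step variants compute the same filter, so their outputs should agree up to
# numerical precision. Assumes ComplexTensor provides `.numpy()`.
def _test_wpe_step_versions_agree():
    np.random.seed(0)
    F, D, T = 5, 2, 60
    Y = (np.random.normal(size=(F, D, T))
         + 1j * np.random.normal(size=(F, D, T)))
    inverse_power = torch.tensor(
        1 / np.maximum(np.mean(np.abs(Y) ** 2, axis=-2), 1e-10))
    X1 = wpe_step_v1(Y, inverse_power).numpy()
    X2 = wpe_step_v2(Y, inverse_power).numpy()
    X3 = wpe_step_v3(Y, inverse_power).numpy()
    np.testing.assert_allclose(X1, X2, atol=1e-6)
    np.testing.assert_allclose(X1, X3, atol=1e-6)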