import numpy as np
import torch
import torch_complex.functional
from torch_complex import ComplexTensor

# Note: build_y_tilde, hermite, transpose and _solve are helpers that are
# assumed to be defined elsewhere in this module.


def wpe_step_v1(Y, inverse_power, taps=10, delay=3, statistics_mode='full',
                solver='torch_complex.solve'):
    """
    Args:
        Y: Observed (reverberant) signal, shape (..., channel, frames)
        inverse_power: Inverse of the estimated signal power,
            shape (..., frames)
        taps: Number of filter taps
        delay: Prediction delay
        statistics_mode: 'full' uses all frames for the correlation
            statistics; 'valid' (not implemented) would use only frames with
            a complete filter context
        solver: Identifier of the linear solver used by _solve

    Returns:
        The dereverberated signal, same shape as Y
    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        raise NotImplementedError(statistics_mode)
        # Unreachable; documents the intended slice for 'valid':
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    Y_tilde = build_y_tilde(Y, taps, delay)
    Y_tilde = Y_tilde  # .contiguous()

    Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    R = Y_tilde_inverse_power[s] @ hermite(Y_tilde[s])
    P = Y_tilde_inverse_power[s] @ hermite(Y[s])
    G = _solve(R=R, P=P, solver=solver)
    X = Y - hermite(G) @ Y_tilde
    return X
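# A minimal usage sketch for wpe_step_v1 (not part of the original module).
# The STFT shapes and the power flooring are illustrative assumptions; the
# channel-averaged power of the observation serves as a stand-in for the
# iteratively updated power estimate of the WPE algorithm.
def _example_wpe_step_v1():
    F, D, T = 257, 6, 100  # frequency bins, channels, frames
    Y = ComplexTensor(torch.randn(F, D, T), torch.randn(F, D, T))
    # Inverse of a floored power estimate, averaged over the channels.
    power = (Y.real ** 2 + Y.imag ** 2).mean(dim=-2).clamp(min=1e-10)
    X = wpe_step_v1(Y, 1 / power, taps=10, delay=3)
    assert X.real.shape == (F, D, T), X.real.shape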
def wpe_step_v3(Y, inverse_power, taps=10, delay=3, statistics_mode='full',
                solver='torch_complex.solve'):
    """
    Tested with torch 1.7.0.dev20200807.

    Properties (compared to the lower versions):
     - faster
     - less memory for the backward pass
     - (less peak memory?) It looks like it, but this is difficult to
       profile.

    Args:
        Y: Observed (reverberant) signal, shape (..., channel, frames)
        inverse_power: Inverse of the estimated signal power,
            shape (..., frames)
        taps: Number of filter taps
        delay: Prediction delay
        statistics_mode: 'full' uses all frames for the correlation
            statistics; 'valid' (not implemented) would use only frames with
            a complete filter context
        solver: Identifier of the linear solver used by _solve

    Returns:
        The dereverberated signal, same shape as Y
    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        raise NotImplementedError(statistics_mode)
        # Unreachable; documents the intended slice for 'valid':
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    Y_tilde = build_y_tilde(Y, taps, delay)

    # Torch does not keep the non-contiguous property of a tensor under
    # negation (i.e. ComplexTensor.conj changes the sign of imag), hence
    # build the conjugated Y_tilde from the conjugated input.
    Y_conj = Y.conj()
    Y_tilde_conj = build_y_tilde(Y_conj, taps, delay)
    # Y_tilde_conj = Y_tilde.conj()

    # This code is faster, but with a backward graph the memory consumption
    # is too high. (PyTorch is at the moment not intelligent enough.)
    # Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    # R = Y_tilde_inverse_power[s] @ transpose(Y_tilde_conj[s])
    # P = Y_tilde_inverse_power[s] @ transpose(Y_conj[s])

    def get_correlation(m, Y1, Y2):
        # Complex contraction sum_t m[t] * Y1[..., t] * Y2[..., t] split
        # into four real einsums: (a + ib)(c + id) = (ac - bd) + i(ad + bc).
        real = (
            torch.einsum('...t,...dt,...et->...de', m, Y1.real, Y2.real)
            - torch.einsum('...t,...dt,...et->...de', m, Y1.imag, Y2.imag)
        )
        imag = (
            torch.einsum('...t,...dt,...et->...de', m, Y1.real, Y2.imag)
            + torch.einsum('...t,...dt,...et->...de', m, Y1.imag, Y2.real)
        )
        return ComplexTensor(real, imag)

    # R_conj = torch_complex.functional.einsum(
    #     '...t,...dt,...et->...de', inverse_power, Y_tilde_conj, Y_tilde)
    R_conj = get_correlation(inverse_power, Y_tilde_conj, Y_tilde)

    # P_conj = torch_complex.functional.einsum(
    #     '...t,...dt,...et->...de', inverse_power, Y_tilde_conj, Y)
    P_conj = get_correlation(inverse_power, Y_tilde_conj, Y)

    G_conj = _solve(R=R_conj, P=P_conj, solver=solver)

    # Matmul converts the non-contiguous Y_tilde to contiguous, hence apply
    # einsum to the real and imaginary parts directly.
    # X = Y - torch_complex.functional.einsum(
    #     '...ij,...ik->...jk', G_conj, Y_tilde)
    X = ComplexTensor(
        Y.real
        - torch.einsum('...ij,...ik->...jk', G_conj.real, Y_tilde.real)
        + torch.einsum('...ij,...ik->...jk', G_conj.imag, Y_tilde.imag),
        Y.imag
        - torch.einsum('...ij,...ik->...jk', G_conj.real, Y_tilde.imag)
        - torch.einsum('...ij,...ik->...jk', G_conj.imag, Y_tilde.real),
    )
    return X
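# A small numeric check (not part of the original module) that the four real
# einsums in get_correlation implement the complex contraction
# sum_t m[t] * Y1[d, t] * Y2[e, t], via (a+ib)(c+id) = (ac-bd) + i(ad+bc).
# The shapes are illustrative.
def _check_get_correlation_identity():
    m = np.random.rand(8)
    Y1 = np.random.randn(4, 8) + 1j * np.random.randn(4, 8)
    Y2 = np.random.randn(5, 8) + 1j * np.random.randn(5, 8)
    direct = np.einsum('t,dt,et->de', m, Y1, Y2)
    real = (np.einsum('t,dt,et->de', m, Y1.real, Y2.real)
            - np.einsum('t,dt,et->de', m, Y1.imag, Y2.imag))
    imag = (np.einsum('t,dt,et->de', m, Y1.real, Y2.imag)
            + np.einsum('t,dt,et->de', m, Y1.imag, Y2.real))
    assert np.allclose(direct, real + 1j * imag)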
def wpe_step_v2(Y, inverse_power, taps=10, delay=3, statistics_mode='full',
                solver='torch_complex.solve'):
    """
    Args:
        Y: Observed (reverberant) signal, shape (..., channel, frames)
        inverse_power: Inverse of the estimated signal power,
            shape (..., frames)
        taps: Number of filter taps
        delay: Prediction delay
        statistics_mode: 'full' uses all frames for the correlation
            statistics; 'valid' (not implemented) would use only frames with
            a complete filter context
        solver: Identifier of the linear solver used by _solve

    Returns:
        The dereverberated signal, same shape as Y
    """
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        raise NotImplementedError(statistics_mode)
        # Unreachable; documents the intended slice for 'valid':
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    if isinstance(Y, np.ndarray):
        Y = ComplexTensor(Y)
        Y = Y.to(inverse_power.device)

    Y_tilde = build_y_tilde(Y, taps, delay)

    # Torch does not keep the non-contiguous property of a tensor under
    # negation (i.e. ComplexTensor.conj changes the sign of imag), hence
    # build the conjugated Y_tilde from the conjugated input.
    Y_conj = Y.conj()
    Y_tilde_conj = build_y_tilde(Y_conj, taps, delay)
    # Y_tilde_conj = Y_tilde.conj()

    # This code is faster, but with a backward graph the memory consumption
    # is too high. (PyTorch is at the moment not intelligent enough.)
    # Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
    # R = Y_tilde_inverse_power[s] @ hermite(Y_tilde[s])
    # P = Y_tilde_inverse_power[s] @ hermite(Y[s])

    import torch.utils.checkpoint

    # Remove the checkpoint workaround when
    # https://github.com/pytorch/pytorch/issues/42418 has a solution.
    # This may be very expensive, because the calculation of R dominates the
    # execution time of WPE.
    def get_R(inverse_power, Y_tilde_real, Y_tilde_imag):
        Y_tilde_real = Y_tilde_real.contiguous()
        Y_tilde_imag = Y_tilde_imag.contiguous()
        Y_tilde = ComplexTensor(Y_tilde_real, Y_tilde_imag)
        Y_tilde_conj = ComplexTensor(Y_tilde_real, -Y_tilde_imag)
        R = torch_complex.functional.einsum(
            '...t,...dt,...et->...de', inverse_power, Y_tilde, Y_tilde_conj)
        return R.real, R.imag

    R = ComplexTensor(*torch.utils.checkpoint.checkpoint(
        get_R, inverse_power, Y_tilde.real, Y_tilde.imag))

    # print('wpe rss before P', ByteSize(process.memory_info().rss))
    P = torch_complex.functional.einsum(
        '...t,...dt,...et->...de', inverse_power, Y_tilde, Y_conj)
    G = _solve(R=R, P=P, solver=solver)

    # Remove the checkpoint workaround when
    # https://github.com/pytorch/pytorch/issues/42418 has a solution.
    def contiguous_einsum(equation, *operands):
        def foo(*operands):
            assert len(operands) % 2 == 0, len(operands)
            operands = [
                ComplexTensor(real.contiguous(), imag.contiguous())
                for real, imag in zip(operands[::2], operands[1::2])
            ]
            ret = torch_complex.functional.einsum(equation, operands)
            return ret.real, ret.imag

        # Flatten the ComplexTensor operands to real tensors, because
        # checkpoint only supports plain tensors as arguments.
        operands = [part for o in operands for part in [o.real, o.imag]]
        real, imag = torch.utils.checkpoint.checkpoint(foo, *operands)
        return ComplexTensor(real, imag)

    # Matmul cannot handle the non-contiguous Y_tilde, hence use einsum.
    # Einsum does not work on the GPU with non-contiguous tensors, hence use
    # torch.utils.checkpoint.checkpoint, whose wrapped function makes the
    # inputs contiguous.
    X = Y - contiguous_einsum('...ij,...ik->...jk', G.conj(), Y_tilde)
    return X
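# A minimal sketch (not part of the original module) of the
# torch.utils.checkpoint workaround used in wpe_step_v2: the checkpointed
# function receives plain real tensors and makes them contiguous, and its
# intermediates are recomputed during backward instead of being stored.
# The names and shapes here are illustrative.
def _example_checkpointed_einsum():
    import torch.utils.checkpoint

    def contiguous_matmul(a, b):
        # The einsum inside the checkpoint only sees contiguous inputs.
        return torch.einsum('ij,jk->ik', a.contiguous(), b.contiguous())

    a = torch.randn(32, 16, requires_grad=True)
    b = torch.randn(32, 8, requires_grad=True)
    # a.t() is a non-contiguous view, similar to Y_tilde in wpe_step_v2.
    c = torch.utils.checkpoint.checkpoint(contiguous_matmul, a.t(), b)
    # Backward recomputes contiguous_matmul instead of keeping its
    # intermediate results alive.
    c.sum().backward()
    assert a.grad is not None and b.grad is not None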