def test_complex_rmult():
    """Check scalar-on-the-left complex multiplication ``(4 + 3j) * c``.

    The Python complex is the left operand, so this goes through the
    tensor's reflected multiply. The result is compared with the same
    computation done in NumPy.
    """
    # The (4, 3) backing tensor stands in for a (2, 3) complex matrix. The
    # test expects the flattened result to hold all real parts followed by
    # all imaginary parts, matching the NumPy reference built below.
    c = ComplexTensor(torch.zeros(4, 3)) + 1
    c = (4 + 3j) * c
    c = c.view(-1).data.numpy()

    # do the same in numpy
    sol = np.zeros((2, 3)).astype(np.complex64) + 1
    sol = (4 + 3j) * sol
    sol = sol.flatten()
    sol = list(sol.real) + list(sol.imag)

    # assert_array_equal reports the mismatching elements on failure;
    # a bare ``assert np.array_equal(...)`` would only show ``False``.
    np.testing.assert_array_equal(c, sol)
def test_complex_scalar_sum():
    """Check adding a Python complex scalar, ``c + (4 + 3j)``.

    The result is compared with the same addition done in NumPy.
    """
    # The (4, 3) backing tensor stands in for a (2, 3) complex matrix. The
    # test expects the flattened result to hold all real parts followed by
    # all imaginary parts, matching the NumPy reference built below.
    c = ComplexTensor(torch.zeros(4, 3))
    c = c + (4 + 3j)
    c = c.view(-1).data.numpy()

    # do the same in numpy
    sol = np.zeros((2, 3)).astype(np.complex64)
    sol = sol + (4 + 3j)
    sol = sol.flatten()
    sol = list(sol.real) + list(sol.imag)

    # assert_array_equal reports the mismatching elements on failure;
    # a bare ``assert np.array_equal(...)`` would only show ``False``.
    np.testing.assert_array_equal(c, sol)
def test_real_matrix_sum():
    """Check adding a real ``torch`` matrix to a complex tensor.

    The real (2, 3) operand should contribute only to the real parts. The
    result is compared with the same addition done in NumPy.
    """
    # The (4, 3) backing tensor stands in for a (2, 3) complex matrix. The
    # test expects the flattened result to hold all real parts followed by
    # all imaginary parts, matching the NumPy reference built below.
    c = ComplexTensor(torch.zeros(4, 3))
    r = torch.ones(2, 3)
    c = c + r
    c = c.view(-1).data.numpy()

    # do the same in numpy
    sol = np.zeros((2, 3)).astype(np.complex64)
    sol_r = np.ones((2, 3))
    sol = sol + sol_r
    sol = sol.flatten()
    sol = list(sol.real) + list(sol.imag)

    # assert_array_equal reports the mismatching elements on failure;
    # a bare ``assert np.array_equal(...)`` would only show ``False``.
    np.testing.assert_array_equal(c, sol)
def test_complex_complex_ele_mult():
    """Check element-wise multiplication of two complex matrices, ``c * c``.

    The result is compared with the same element-wise product done in
    NumPy.
    """
    # The (4, 3) backing tensor stands in for a (2, 3) complex matrix. The
    # test expects the flattened result to hold all real parts followed by
    # all imaginary parts, matching the NumPy reference built below.
    c = ComplexTensor(torch.zeros(4, 3)) + 1
    c = c * c
    c = c.view(-1).data.numpy()

    # do the same in numpy
    sol = np.zeros((2, 3)).astype(np.complex64) + 1
    sol = sol * sol
    sol = sol.flatten()
    sol = list(sol.real) + list(sol.imag)

    # assert_array_equal reports the mismatching elements on failure;
    # a bare ``assert np.array_equal(...)`` would only show ``False``.
    np.testing.assert_array_equal(c, sol)