def test_complex_complex_mm():
    """
    Check complex-matrix x complex-matrix multiplication against NumPy.

    An all-ones complex matrix is multiplied by its own transpose, and the
    flattened ComplexTensor result is compared with NumPy's complex matmul.
    """
    # Presumably the (4, 3) real backing tensor encodes a 2x3 complex matrix,
    # and adding the scalar 1 yields all-ones, matching np.ones((2, 3)) below.
    complex_mtx = ComplexTensor(torch.zeros(4, 3)) + 1
    flat_product = complex_mtx.mm(complex_mtx.t()).view(-1).data.numpy()

    # Reference result computed with NumPy complex arithmetic.
    ref_mtx = np.ones((2, 3)).astype(np.complex64)
    ref_product = np.matmul(ref_mtx, ref_mtx.T).flatten()

    # The test expects the flattened ComplexTensor to hold all real parts
    # followed by all imaginary parts, so lay the reference out the same way.
    expected = np.concatenate([ref_product.real, ref_product.imag])
    assert np.array_equal(expected, flat_product)
def test_complex_real_mm():
    """
    Check complex-matrix x real-matrix multiplication against NumPy.

    An all-ones complex matrix is multiplied by the transpose of a real
    matrix filled with 5s, and the flattened result is compared with NumPy.
    """
    # Presumably the (4, 3) real backing tensor encodes a 2x3 complex matrix,
    # and adding the scalar 1 yields all-ones, matching np.ones((2, 3)) below.
    complex_mtx = ComplexTensor(torch.zeros(4, 3)) + 1
    real_mtx = torch.ones(2, 3) * 2 + 3  # every entry is 1 * 2 + 3 == 5
    flat_product = complex_mtx.mm(real_mtx.t()).view(-1).data.numpy()

    # Reference result: complex (2, 3) times real (3, 2) in NumPy.
    ref_complex = np.ones((2, 3)).astype(np.complex64)
    ref_real = np.ones((2, 3)) * 2 + 3
    ref_product = np.matmul(ref_complex, ref_real.T).flatten()

    # The test expects the flattened ComplexTensor to hold all real parts
    # followed by all imaginary parts, so lay the reference out the same way.
    expected = np.concatenate([ref_product.real, ref_product.imag])
    assert np.array_equal(expected, flat_product)