Ejemplo n.º 1
0
def test_complex_complex_mm():
    """
    Complex mtx x complex mtx matrix multiply.

    The comparisons below rely on this layout: a ComplexTensor built from a
    real (2k, m) tensor holds a (k, m) complex matrix. The first k rows are
    the real part and the last k rows are the imaginary part, so
    ``view(-1)`` yields the flattened real part followed by the flattened
    imaginary part.
    """
    # 2x3 complex matrix of ones (imaginary part zero) times its transpose.
    c = ComplexTensor(torch.zeros(4, 3)) + 1
    cc = c.mm(c.t())
    cc = cc.view(-1).data.numpy()

    # do the same in numpy
    np_c = np.ones((2, 3)).astype(np.complex64)
    np_cc = np.matmul(np_c, np_c.T).flatten()

    # compare (real part followed by imaginary part)
    np.testing.assert_allclose(cc, np.concatenate([np_cc.real, np_cc.imag]))

    # The check above has a zero imaginary part, so it cannot catch errors in
    # the cross terms of the complex product (re*im, im*re, and the -im*im
    # contribution to the real part). Repeat with non-zero, mixed-sign
    # imaginary parts, built directly in the real-over-imag layout.
    a_re = np.arange(6, dtype=np.float32).reshape(2, 3)
    a_im = np.arange(6, 12, dtype=np.float32).reshape(2, 3) - 8
    b_re = np.arange(6, dtype=np.float32).reshape(3, 2) - 2
    b_im = np.arange(6, 12, dtype=np.float32).reshape(3, 2)
    a = ComplexTensor(torch.from_numpy(np.concatenate([a_re, a_im])))
    b = ComplexTensor(torch.from_numpy(np.concatenate([b_re, b_im])))
    ab = a.mm(b).view(-1).data.numpy()

    np_ab = np.matmul(a_re + 1j * a_im, b_re + 1j * b_im).flatten()
    np.testing.assert_allclose(ab, np.concatenate([np_ab.real, np_ab.imag]))
Ejemplo n.º 2
0
def test_complex_real_mm():
    """
    Complex mtx x real mtx matrix multiply.

    The comparisons below rely on this layout: a ComplexTensor built from a
    real (2k, m) tensor holds a (k, m) complex matrix. The first k rows are
    the real part and the last k rows are the imaginary part, so
    ``view(-1)`` yields the flattened real part followed by the flattened
    imaginary part.
    """
    # 2x3 complex matrix of ones (imaginary part zero) times a constant
    # real 3x2 matrix (the transpose of a 2x3 matrix of fives).
    c = ComplexTensor(torch.zeros(4, 3)) + 1
    r = torch.ones(2, 3) * 2 + 3
    cr = c.mm(r.t())
    cr = cr.view(-1).data.numpy()

    # do the same in numpy
    np_c = np.ones((2, 3)).astype(np.complex64)
    np_r = np.ones((2, 3)) * 2 + 3
    np_cr = np.matmul(np_c, np_r.T).flatten()

    # compare (real part followed by imaginary part)
    np.testing.assert_allclose(cr, np.concatenate([np_cr.real, np_cr.imag]))

    # With a zero imaginary part and a constant real factor, the check above
    # cannot detect errors in how the imaginary rows are multiplied or in the
    # ordering of result elements. Repeat with non-zero imaginary parts and a
    # non-uniform real matrix.
    a_re = np.arange(6, dtype=np.float32).reshape(2, 3)
    a_im = np.arange(6, 12, dtype=np.float32).reshape(2, 3) - 8
    r2 = np.arange(6, dtype=np.float32).reshape(2, 3) - 2
    a = ComplexTensor(torch.from_numpy(np.concatenate([a_re, a_im])))
    ar = a.mm(torch.from_numpy(r2).t()).view(-1).data.numpy()

    np_ar = np.matmul(a_re + 1j * a_im, r2.T).flatten()
    np.testing.assert_allclose(ar, np.concatenate([np_ar.real, np_ar.imag]))