コード例 #1
0
def test_complex_rmult():
    """Left-multiplying a ComplexTensor by a Python complex scalar matches numpy.

    The 4x3 backing tensor is compared, once flattened, against a 2x3
    complex64 reference laid out as [real parts..., imaginary parts...].
    """
    scalar = 4 + 3j

    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    # scalar on the left exercises ComplexTensor's reflected multiply
    actual = (scalar * ones).view(-1).data.numpy()

    # numpy reference for the same computation
    reference = scalar * (np.zeros((2, 3)).astype(np.complex64) + 1)
    flat = reference.flatten()
    expected = np.concatenate((flat.real, flat.imag))

    assert np.array_equal(actual, expected)
コード例 #2
0
def test_complex_scalar_sum():
    """Adding a Python complex scalar to a zero ComplexTensor matches numpy.

    The flattened 4x3 backing tensor is checked against a 2x3 complex64
    reference rearranged into [real parts..., imaginary parts...].
    """
    offset = 4 + 3j

    shifted = ComplexTensor(torch.zeros(4, 3)) + offset
    actual = shifted.view(-1).data.numpy()

    # numpy reference for the same computation
    reference = np.zeros((2, 3)).astype(np.complex64) + offset
    flat = reference.flatten()
    expected = np.concatenate((flat.real, flat.imag))

    assert np.array_equal(actual, expected)
コード例 #3
0
def test_real_matrix_sum():
    """Adding a real torch matrix to a zero ComplexTensor matches numpy.

    The real operand is 2x3, the shape of the complex matrix rather than of
    the 4x3 backing tensor. The numpy reference (complex zeros + real ones)
    expects it to contribute to the real parts only.
    """
    complex_zeros = ComplexTensor(torch.zeros(4, 3))
    real_ones = torch.ones(2, 3)
    actual = (complex_zeros + real_ones).view(-1).data.numpy()

    # numpy reference: complex64 zeros plus float64 ones
    reference = np.zeros((2, 3)).astype(np.complex64) + np.ones((2, 3))
    flat = reference.flatten()
    expected = np.concatenate((flat.real, flat.imag))

    assert np.array_equal(actual, expected)
コード例 #4
0
def test_complex_complex_ele_mult():
    """Complex matrix times complex matrix, elementwise, matches numpy.

    The operand has distinct, non-zero real and imaginary parts, so the
    cross terms of (a + bi)(c + di) = (ac - bd) + (ad + bc)i are exercised.
    An all-ones operand would square to itself and could not distinguish a
    true complex product from a plain real one. Every value involved is a
    small integer, so float32 arithmetic is exact and array_equal is safe.
    """
    # 4x3 backing tensor holding 1..12. Under the layout compared below
    # (first six flattened values real, last six imaginary), this is the
    # 2x3 complex matrix with real parts 1..6 and imaginary parts 7..12.
    c = ComplexTensor(torch.arange(1, 13, dtype=torch.float32).view(4, 3))
    c = c * c
    c = c.view(-1).data.numpy()

    # do the same in numpy
    real = np.arange(1, 7).reshape(2, 3)
    imag = np.arange(7, 13).reshape(2, 3)
    sol = (real + 1j * imag).astype(np.complex64)
    sol = sol * sol
    sol = sol.flatten()
    sol = list(sol.real) + list(sol.imag)

    assert np.array_equal(c, sol)