Exemplo n.º 1
0
def test_transpose_method():
    """Check ``Tensor.transpose`` forward and backward passes for each call form.

    Covers: axes given as a single tuple, axes given as separate integer
    arguments, no axes at all (compared against numpy's no-argument
    ``transpose``), and the ``constant=True`` keyword.
    """
    dat = np.arange(24).reshape(2, 3, 4)

    for axes in permutations(range(3)):
        # Identify the failing permutation in every assertion, not just some.
        msg = "{}".format(axes)

        # passing tuple of integers
        x = Tensor(dat)
        f = x.transpose(axes)
        f.backward(dat.transpose(axes))

        assert_allclose(f.data, dat.transpose(axes), err_msg=msg)
        # The upstream gradient is `dat` permuted by `axes`; mapping it back
        # through the transpose must yield `dat` in x's original layout.
        assert_allclose(x.grad, dat, err_msg=msg)

        # passing integers directly
        x = Tensor(dat)
        f = x.transpose(*axes)
        f.backward(dat.transpose(axes))

        assert_allclose(f.data, dat.transpose(axes), err_msg=msg)
        assert_allclose(x.grad, dat, err_msg=msg)

    # passing no axes: must match numpy's no-argument transpose
    # (which reverses the order of all axes)
    x = Tensor(dat)
    f = x.transpose()
    f.backward(dat.transpose())

    assert_allclose(f.data, dat.transpose())
    assert_allclose(x.grad, dat)

    # check that constant=True works: the result is constant while the
    # input tensor is left non-constant, with and without explicit axes
    x = Tensor(dat)
    f = x.transpose(constant=True)
    assert f.constant and not x.constant

    f = x.transpose(1, 0, 2, constant=True)
    assert f.constant and not x.constant
Exemplo n.º 2
0
def test_transpose():
    """Verify ``Tensor.transpose`` with the ``axes=`` keyword, forward and backward.

    Uses the full-reversal permutation (2, 1, 0) on a 2x3x4 array.
    """
    source = np.arange(24).reshape(2, 3, 4)
    perm = (2, 1, 0)

    tensor = Tensor(source)
    out = tensor.transpose(axes=perm)
    # Seed backprop with the forward result's expected value; undoing the
    # permutation should give back `source` as the input's gradient.
    out.backward(source.transpose(perm))

    assert np.allclose(out.data, source.transpose(perm))
    assert np.allclose(tensor.grad, source)