コード例 #1
0
ファイル: test_axes.py プロジェクト: tensor-tang/ngraph
def test_simple_tensors():
    """
    Tons of tests relating to reaxing tensors.

    Variable names have a postfix integer which represents the
    dimensionality of the underlying buffer.  Views have an x_y postfix:
    x is the dimensionality of the buffer being viewed, and y is only an
    index distinguishing the different views of that buffer.  y is NOT
    the view's dimensionality, e.g. e1_3 is a 2-d view and e2_3 is a 1-d
    view.

    I started refactoring into smaller pieces as seen in tests above, but
    stopped ...
    """
    # A simple vector
    td1 = TensorDescription(axes=[ax.A])
    e1 = random(td1)

    # A simple matrix
    td2 = TensorDescription(axes=[ax.A, ax.B])
    e2 = random(td2)

    # Reaxes of the 1-d buffer e1
    e1_1 = tensorview(td1.broadcast([ax.A, ax.B]), e1)
    e1_2 = tensorview(td1.broadcast([ax.B, ax.A]), e1)
    # (ax.B, ax.C) acts as one combined axis; it is indexed below with
    # j in range(B.length * C.length)
    e1_3 = tensorview(td1.broadcast([(ax.B, ax.C), ax.A]), e1)

    # Reaxes of the 2-d buffer e2: a transpose, an identity reaxe, and a
    # flattening of (A, B) into a single axis
    e2_1 = tensorview(td2.broadcast([ax.B, ax.A]), e2)
    e2_2 = tensorview(td2.broadcast([ax.A, ax.B]), e2)
    e2_3 = tensorview(td2.flatten((
        FlattenedAxis((ax.A, ax.B)),
    )), e2_2)

    assert e1_1.shape == (ax.A.length, ax.B.length)
    assert e1_2.shape == (ax.B.length, ax.A.length)
    assert e2_1.shape == (ax.B.length, ax.A.length)
    assert e2_2.shape == (ax.A.length, ax.B.length)

    # Writing through the broadcast view e1_1 must land in e1 itself, and
    # every other view of e1 must then observe the written values.
    for i in range(ax.A.length):
        e1_1[i] = i

    for i in range(ax.A.length):
        assert e1[i] == i
        for j in range(ax.B.length):
            assert e1_1[i, j] == i
            assert e1_2[j, i] == i
        for j in range(ax.B.length * ax.C.length):
            assert e1_3[j, i] == i

    def val2(i, j):
        # position-dependent fill value for element (i, j) of e2
        return (i + 1) * (j + 2)

    # Fill the 2-d buffer directly, then check that each view reads the
    # values back through its own indexing.
    for i in range(ax.A.length):
        for j in range(ax.B.length):
            e2[i, j] = val2(i, j)

    for i in range(ax.A.length):
        for j in range(ax.B.length):
            assert e2_1[j, i] == val2(i, j)  # transposed
            assert e2_2[i, j] == val2(i, j)  # identity
            # flattened in row-major (A-major) order
            assert e2_3[i * ax.B.length + j] == val2(i, j)
コード例 #2
0
def test_reaxe_0d_to_2d():
    """Broadcast a 0-d tensor to an (A, B) view and check the view reflects a later write to the scalar."""
    scalar_desc = TensorDescription(axes=())
    scalar = random(scalar_desc)

    target_axes = [ax_A, ax_B]
    grid_view = tensorview(scalar_desc.broadcast(target_axes), scalar)

    # Write through the original 0-d buffer after the view exists.
    scalar[()] = 3

    expected_shape = (ax_A.length, ax_B.length)
    assert grid_view.shape == expected_shape
    assert np.all(grid_view == 3)
コード例 #3
0
def test_reaxe_0d_to_1d():
    """Broadcast a 0-d tensor to an (A,) view and check the view reflects a later write to the scalar."""
    # passed by keyword for consistency with the other TensorDescription calls
    td = TensorDescription(axes=())
    x = random(td)

    # create view of x
    x_view = tensorview(td.broadcast([ax_A]), x)

    # set x; since x_view was created from x, this must also show up in x_view
    x[()] = 3

    assert x_view.shape == (ax_A.length, )
    assert np.all(x_view == 3)