예제 #1
0
def test_axes_ops():
    """Check Axes subtraction and the lengths of flattened/combined axes."""

    def check_sub(left, right, expected):
        """
        Assert that subtracting one set of axes from another gives a target.

        Arguments:
          left: axes passed to ``ng.make_axes`` to form the minuend.
          right: axes passed to ``ng.make_axes`` to form the subtrahend.
          expected: axes passed to ``ng.make_axes`` that the difference
            must compare equal to.
        """
        lhs = ng.make_axes(left)
        rhs = ng.make_axes(right)
        assert lhs - rhs == ng.make_axes(expected)

    # Subtraction: removing either axis from (A, B) leaves the other one.
    for left, right, expected in (
        ([ax.A, ax.B], [ax.A], [ax.B]),
        ([ax.A, ax.B], [ax.B], [ax.A]),
    ):
        check_sub(left, right, expected)

    # Combined axes length: a flattened axis is as long as the product of
    # its members, and a nested tuple inside make_axes flattens likewise.
    len_a, len_b, len_c = ax.A.length, ax.B.length, ax.C.length
    assert FlattenedAxis([ax.A, ax.B]).length == len_a * len_b
    assert ng.make_axes([ax.A, (ax.B, ax.C)]).lengths == (len_a, len_b * len_c)
    assert FlattenedAxis([ax.A, (ax.B, ax.C)]).length == len_a * len_b * len_c
예제 #2
0
파일: serde.py 프로젝트: ami-GS/ngraph
def pb_to_axis(msg):
    """
    Rebuild an axis object from its protobuf message.

    If an axis with the same uuid bytes is already in
    ``GLOBAL_AXIS_REGISTRY``, that existing object is returned unchanged.
    Otherwise a new axis is built: a ``FlattenedAxis`` over the decoded
    member axes when the message has ``flattened_axes`` set, or a plain
    ``Axis`` from ``msg.name`` / ``msg.length``. The new axis gets its uuid
    restored from the message and is registered before being returned.

    Arguments:
      msg: axis protobuf message carrying ``uuid.uuid`` bytes.

    Returns:
      The (possibly cached) axis object.
    """
    raw_id = msg.uuid.uuid
    if raw_id in GLOBAL_AXIS_REGISTRY:
        # Already deserialized: hand back the very same object.
        return GLOBAL_AXIS_REGISTRY[raw_id]

    if msg.HasField('flattened_axes'):
        result = FlattenedAxis(protobuf_to_axes(msg.flattened_axes))
    else:
        result = Axis(name=msg.name, length=msg.length)

    result.uuid = uuid.UUID(bytes=raw_id)
    GLOBAL_AXIS_REGISTRY[result.uuid.bytes] = result
    return result
예제 #3
0
def test_simple_tensors():
    """
    Exercise re-axing (broadcast, transpose, flatten) of tensor views.

    Variable names carry an integer suffix giving the dimensionality of the
    underlying buffer (``e1`` is 1-d, ``e2`` is 2-d). Views are named
    ``ex_y``: a view onto the x-dimensional buffer ``ex``, where ``y`` is
    just an ordinal distinguishing the views (it is not the view's rank).

    Splitting this into smaller tests was started (see the earlier tests)
    but not finished, so the remaining checks live together here.
    """
    # Base buffers: a vector over A and a matrix over (A, B).
    td1 = TensorDescription(axes=[ax.A])
    e1 = random(td1)

    td2 = TensorDescription(axes=[ax.A, ax.B])
    e2 = random(td2)

    # Re-axed views of the vector.
    e1_1 = tensorview(td1.broadcast([ax.A, ax.B]), e1)
    e1_2 = tensorview(td1.broadcast([ax.B, ax.A]), e1)
    e1_3 = tensorview(td1.broadcast([(ax.B, ax.C), ax.A]), e1)

    # Re-axed views of the matrix: transposed, same order, and flattened.
    e2_1 = tensorview(td2.broadcast([ax.B, ax.A]), e2)
    e2_2 = tensorview(td2.broadcast([ax.A, ax.B]), e2)
    flat_ab = FlattenedAxis((ax.A, ax.B))
    e2_3 = tensorview(td2.flatten((flat_ab,)), e2_2)

    n_a, n_b, n_c = ax.A.length, ax.B.length, ax.C.length

    assert e1_1.shape == (n_a, n_b)
    assert e1_2.shape == (n_b, n_a)

    # Writes through the broadcast view must land in the underlying vector.
    for row in range(n_a):
        e1_1[row] = row

    for row in range(n_a):
        assert e1[row] == row
        for col in range(n_b):
            assert e1_1[row, col] == row
            assert e1_2[col, row] == row
        for col in range(n_b * n_c):
            assert e1_3[col, row] == row

    def expected(i, j):
        return (i + 1) * (j + 2)

    # Fill the matrix, then read it back through every matrix view.
    for row in range(n_a):
        for col in range(n_b):
            e2[row, col] = expected(row, col)

    for row in range(n_a):
        for col in range(n_b):
            want = expected(row, col)
            assert e2_1[col, row] == want
            assert e2_2[row, col] == want
            assert e2_3[row * n_b + col] == want