def test_axes_ops():
    """Check axes subtraction and the lengths of flattened / grouped axes."""

    def assert_difference(minuend, subtrahend, expected):
        """
        Assert that subtracting one axes list from another yields `expected`.

        Arguments:
            minuend: axes passed to ``ng.make_axes`` on the left of ``-``
            subtrahend: axes passed to ``ng.make_axes`` on the right of ``-``
            expected: axes the difference should equal after ``ng.make_axes``
        """
        difference = ng.make_axes(minuend) - ng.make_axes(subtrahend)
        assert difference == ng.make_axes(expected)

    # Subtraction: removing either axis from [A, B] leaves the other one.
    assert_difference([ax.A, ax.B], [ax.A], [ax.B])
    assert_difference([ax.A, ax.B], [ax.B], [ax.A])

    # Combined axes length: a flattened axis spans the product of its parts,
    # and a tuple inside make_axes contributes one combined length.
    a_len, b_len, c_len = ax.A.length, ax.B.length, ax.C.length
    assert FlattenedAxis([ax.A, ax.B]).length == a_len * b_len
    assert ng.make_axes([ax.A, (ax.B, ax.C)]).lengths == (a_len, b_len * c_len)
    assert FlattenedAxis([ax.A, (ax.B, ax.C)]).length == a_len * b_len * c_len
def pb_to_axis(msg):
    """
    Deserialize a protobuf axis message into an ``Axis`` or ``FlattenedAxis``.

    Results are memoized in ``GLOBAL_AXIS_REGISTRY``, keyed by the raw UUID
    bytes. A message whose UUID was already deserialized returns the
    previously built object instead of constructing a new one.

    Arguments:
        msg: protobuf axis message; reads ``uuid.uuid`` and either
            ``flattened_axes`` (when that field is set) or ``name``/``length``.

    Returns:
        The (possibly cached) axis object, with its ``uuid`` attribute set
        from the message.
    """
    key = msg.uuid.uuid
    if key in GLOBAL_AXIS_REGISTRY:
        # Already deserialized: hand back the shared instance.
        return GLOBAL_AXIS_REGISTRY[key]

    if msg.HasField('flattened_axes'):
        axis = FlattenedAxis(protobuf_to_axes(msg.flattened_axes))
    else:
        axis = Axis(name=msg.name, length=msg.length)

    # Preserve the serialized identity so later lookups hit the cache.
    axis.uuid = uuid.UUID(bytes=key)
    GLOBAL_AXIS_REGISTRY[axis.uuid.bytes] = axis
    return axis
def test_simple_tensors():
    """
    Exercise re-axing of tensors through broadcast and flatten views.

    Two buffers are built: a 1-d vector over A and a 2-d matrix over (A, B).
    Several views are taken over them:
    - broadcast views of the vector;
    - transposed and identity views of the matrix;
    - a 1-d flattened view of the matrix.
    The test checks that writes through a view reach the underlying buffer
    and that every view reads back the expected element.

    This test is a candidate for splitting into smaller, focused tests.
    """
    a_len, b_len, c_len = ax.A.length, ax.B.length, ax.C.length

    # Underlying buffers.
    vec_desc = TensorDescription(axes=[ax.A])
    vec = random(vec_desc)
    mat_desc = TensorDescription(axes=[ax.A, ax.B])
    mat = random(mat_desc)

    # Broadcast views of the vector.
    vec_ab = tensorview(vec_desc.broadcast([ax.A, ax.B]), vec)
    vec_ba = tensorview(vec_desc.broadcast([ax.B, ax.A]), vec)
    vec_bc_a = tensorview(vec_desc.broadcast([(ax.B, ax.C), ax.A]), vec)

    # Transposed and identity views of the matrix, plus a flattened view.
    mat_ba = tensorview(mat_desc.broadcast([ax.B, ax.A]), mat)
    mat_ab = tensorview(mat_desc.broadcast([ax.A, ax.B]), mat)
    flat_axis = FlattenedAxis((ax.A, ax.B))
    mat_flat = tensorview(mat_desc.flatten((flat_axis,)), mat_ab)

    assert vec_ab.shape == (a_len, b_len)
    assert vec_ba.shape == (b_len, a_len)

    # Write through a broadcast view; the vector itself must observe it.
    for row in range(a_len):
        vec_ab[row] = row

    for row in range(a_len):
        assert vec[row] == row
        for col in range(b_len):
            assert vec_ab[row, col] == row
            assert vec_ba[col, row] == row
        for col in range(b_len * c_len):
            assert vec_bc_a[col, row] == row

    def expected(i, j):
        return (i + 1) * (j + 2)

    # Fill the matrix directly, then read back through each view.
    for row in range(a_len):
        for col in range(b_len):
            mat[row, col] = expected(row, col)

    for row in range(a_len):
        for col in range(b_len):
            want = expected(row, col)
            assert mat_ba[col, row] == want
            assert mat_ab[row, col] == want
            # Row-major flattening: (row, col) -> row * len(B) + col.
            assert mat_flat[row * b_len + col] == want