def test_advanced_indexing_tensor(output_shape):
    """Substituting tensor-valued indices into ``x`` matches manual indexing.

    ``x`` is built over ``i``, ``j``, ``k`` (sizes 2, 3, 4); the index
    tensors in turn are built over ``u`` (size 5) and ``v`` (size 6):

        #      u   v
        #     / \\ / \\
        #    i   j   k
        #     \\  |  /
        #      \\ | /
        #        x

    The reference result is filled in element by element from the raw
    ``.data`` arrays, then compared against ``x`` applied to ``i``, ``j``,
    ``k`` positionally and by keyword in every grouping and ordering.

    Args:
        output_shape: tuple appended to the batch shape of the reference
            data and passed to ``reals`` to form the output of ``x``.
    """
    output_domain = reals(*output_shape)

    # Creation order (x, i, j, k) is kept fixed so random draws are unchanged.
    x = random_tensor(
        OrderedDict([('i', bint(2)), ('j', bint(3)), ('k', bint(4))]),
        output_domain,
    )
    i = random_tensor(OrderedDict([('u', bint(5))]), bint(2))
    j = random_tensor(OrderedDict([('v', bint(6)), ('u', bint(5))]), bint(3))
    k = random_tensor(OrderedDict([('v', bint(6))]), bint(4))

    # Reference: look up x.data one (u, v) cell at a time.
    reference = empty((5, 6) + output_shape)
    for u_idx in range(5):
        for v_idx in range(6):
            reference[u_idx, v_idx] = x.data[
                i.data[u_idx], j.data[v_idx, u_idx], k.data[v_idx]
            ]
    expected = Tensor(
        reference,
        OrderedDict([('u', bint(5)), ('v', bint(6))]),
    )

    # Positional substitution.
    assert_equiv(expected, x(i, j, k))

    # Keyword substitution. Each plan is a space-separated list of stages;
    # every stage is one call whose keywords are the letters of that stage,
    # in the order written (e.g. 'ki j' means x(k=k, i=i)(j=j)).
    index_of = {'i': i, 'j': j, 'k': k}
    plans = [
        'ijk',
        'ij k', 'jk i', 'ki j',
        'i jk', 'j ki', 'k ij',
        'i j k', 'i k j', 'j i k', 'j k i', 'k i j', 'k j i',
    ]
    for plan in plans:
        actual = x
        for stage in plan.split():
            actual = actual(**{name: index_of[name] for name in stage})
        assert_equiv(expected, actual)
def test_advanced_indexing_lazy(output_shape):
    """Substituting lazily built index expressions into ``x`` matches manual indexing.

    ``x`` wraps random data of shape ``(2, 3, 4) + output_shape`` over
    ``i``, ``j``, ``k``. The indices are expressions in the variables ``u``
    (size 2) and ``v`` (size 3), constructed under the ``lazy``
    interpretation: ``i = Number(1, 2) - u``, ``j = Number(2, 3) - v``,
    ``k = u + v``.

    The reference result is filled in element by element from the
    materialized index data, then compared against ``x`` applied to ``i``,
    ``j``, ``k`` positionally and by keyword in every grouping and ordering.

    Args:
        output_shape: tuple appended to the batch shape of both the data of
            ``x`` and the reference data.
    """
    x = Tensor(
        randn((2, 3, 4) + output_shape),
        OrderedDict([('i', bint(2)), ('j', bint(3)), ('k', bint(4))]),
    )
    u = Variable('u', bint(2))
    v = Variable('v', bint(3))
    with interpretation(lazy):
        i = Number(1, 2) - u
        j = Number(2, 3) - v
        k = u + v

    # Reference: materialize each index expression, then look up x.data one
    # (u, v) cell at a time.
    reference = empty((2, 3) + output_shape)
    i_data = x.materialize(i).data
    j_data = x.materialize(j).data
    k_data = x.materialize(k).data
    for u_idx in range(2):
        for v_idx in range(3):
            reference[u_idx, v_idx] = x.data[
                i_data[u_idx], j_data[v_idx], k_data[u_idx, v_idx]
            ]
    expected = Tensor(
        reference,
        OrderedDict([('u', bint(2)), ('v', bint(3))]),
    )

    # Positional substitution.
    assert_equiv(expected, x(i, j, k))

    # Keyword substitution. Each plan is a space-separated list of stages;
    # every stage is one call whose keywords are the letters of that stage,
    # in the order written (e.g. 'ki j' means x(k=k, i=i)(j=j)).
    index_of = {'i': i, 'j': j, 'k': k}
    plans = [
        'ijk',
        'ij k', 'jk i', 'ki j',
        'i jk', 'j ki', 'k ij',
        'i j k', 'i k j', 'j i k', 'j k i', 'k i j', 'k j i',
    ]
    for plan in plans:
        actual = x
        for stage in plan.split():
            actual = actual(**{name: index_of[name] for name in stage})
        assert_equiv(expected, actual)