def test_sequential_sum_product(impl, sum_op, prod_op, batch_inputs, state_domain, num_steps):
    """Check ``impl`` against an explicit chain contraction.

    A random transition factor over ``prev``/``curr`` (plus ``time`` when
    ``num_steps`` is given) is eliminated along ``time`` by ``impl``. The
    reference splits the factor into one operand per step, links consecutive
    steps through fresh ``t_<k>`` variables, and contracts the interior ones
    with ``sum_product``.
    """
    trans_inputs = OrderedDict(batch_inputs)
    trans_inputs.update(prev=state_domain, curr=state_domain)
    if num_steps is not None:
        trans_inputs["time"] = bint(num_steps)
    else:
        num_steps = 1  # no time input at all: behaves like a single step
    make_factor = random_gaussian if state_domain.dtype == "real" else random_tensor
    trans = make_factor(trans_inputs)
    time = Variable("time", bint(num_steps))

    actual = impl(sum_op, prod_op, trans, time, {"prev": "curr"})
    expected_inputs = batch_inputs.copy()
    expected_inputs.update(prev=state_domain, curr=state_domain)
    assert dict(actual.inputs) == expected_inputs

    # Check against contract: step t connects t_<t> to t_<t+1>, and only the
    # interior chain variables t_1 .. t_<num_steps-1> are summed out.
    names = ["t_{}".format(t) for t in range(num_steps + 1)]
    operands = tuple(trans(time=t, prev=names[t], curr=names[t + 1])
                     for t in range(num_steps))
    interior = frozenset(names[1:-1])
    with interpretation(reflect):
        contracted = sum_product(sum_op, prod_op, operands, interior)
    expected = apply_optimizer(contracted)
    expected = expected(**{names[0]: "prev", names[-1]: "curr"})
    expected = expected.align(tuple(actual.inputs))
    assert_close(actual, expected, rtol=5e-4 * num_steps)
def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim):
    """Eliminate time from a model whose observations use a per-sensor bias.

    At each time step exactly one sensor observes ``x``, and the sensors take
    turns in round-robin order. The per-step bias is selected by indexing the
    ``bias`` variable with a time-indexed sensor id. The result must keep
    ``bias`` as a free input.
    """
    time = Variable("time", bint(num_steps))
    bias = Variable("bias", reals(num_sensors, dim))
    bias_dist = random_gaussian(OrderedDict([
        ("bias", reals(num_sensors, dim)),
    ]))
    trans = random_gaussian(OrderedDict([
        ("time", bint(num_steps)),
        ("x_prev", reals(dim)),
        ("x_curr", reals(dim)),
    ]))
    obs = random_gaussian(OrderedDict([
        ("time", bint(num_steps)),
        ("x_curr", reals(dim)),
        ("bias", reals(dim)),
    ]))

    # Each time step only a single sensor observes x,
    # and each sensor has a different bias.
    # The sensor index must range over bint(num_sensors); it was previously
    # hard-coded to 2, which is wrong for any other number of sensors.
    sensor_id = Tensor(torch.arange(num_steps) % num_sensors,
                       OrderedDict(time=bint(num_steps)), dtype=num_sensors)

    with interpretation(eager_or_die):
        factor = trans + obs(bias=bias[sensor_id]) + bias_dist
    assert set(factor.inputs) == {"time", "bias", "x_prev", "x_curr"}

    result = sequential_sum_product(ops.logaddexp, ops.add, factor, time, {"x_prev": "x_curr"})
    assert set(result.inputs) == {"bias", "x_prev", "x_curr"}
def __call__(self):
    """Assemble the model's log-density as a funsor.

    ``initialize_params`` is called first; per the original comment this
    invokes ``pyro.param`` so parameters are exposed and constraints applied.
    Continuous group-level (``eps_g``) and individual-level (``eps_i``)
    random-effect terms are added when enabled in ``self.config``.

    NOTE(review): the original comment says no new torch.Tensors should be
    created after ``__init__``, yet the zero tensors below are allocated on
    every call. Confirm the intent.
    """
    self.initialize_params()

    sizes = self.config["sizes"]
    num_groups = sizes["group"]
    num_individuals = sizes["individual"]

    log_prob = Tensor(torch.tensor(0.), OrderedDict())
    # Zero-valued tensors whose only role is to add plate inputs "g" / "i".
    plate_g = Tensor(torch.zeros(num_groups), OrderedDict([("g", bint(num_groups))]))
    plate_i = Tensor(torch.zeros(num_individuals), OrderedDict([("i", bint(num_individuals))]))

    # group-level random effects
    if self.config["group"]["random"] == "continuous":
        log_prob += plate_g + dist.Normal(**self.params["eps_g"])(value="eps_g")

    # individual-level random effects
    if self.config["individual"]["random"] == "continuous":
        log_prob += plate_g + plate_i + dist.Normal(**self.params["eps_i"])(value="eps_i")

    return log_prob
def test_reduce_subset(dims, reduced_vars, op):
    """Reducing a Tensor over a subset of its inputs matches the backend reduction.

    Variables in ``reduced_vars`` that are not among ``dims`` must be ignored.
    Reducing over none of ``dims`` must return the very same object.
    """
    reduced_vars = frozenset(reduced_vars)
    sizes = {'a': 3, 'b': 4, 'c': 5}
    shape = tuple(sizes[d] for d in dims)
    inputs = OrderedDict((d, bint(sizes[d])) for d in dims)
    data = rand(shape) + 0.5

    is_boolean = op in [ops.and_, ops.or_]
    if is_boolean:
        data = astype(data, 'uint8')
    dtype = 2 if is_boolean else 'real'

    x = Tensor(data, inputs, dtype)
    actual = x.reduce(op, reduced_vars)
    expected_inputs = OrderedDict(
        (d, bint(sizes[d])) for d in dims if d not in reduced_vars)

    reduced_vars &= frozenset(dims)
    if not reduced_vars:
        assert actual is x
        return

    numeric_reduce = REDUCE_OP_TO_NUMERIC[op]
    if reduced_vars == frozenset(dims):
        data = numeric_reduce(data, None)
    else:
        # Reduce the highest axes first so earlier positions stay valid.
        for pos in sorted(map(dims.index, reduced_vars), reverse=True):
            data = numeric_reduce(data, pos)
    check_funsor(actual, expected_inputs, Domain((), dtype))
    assert_close(actual, Tensor(data, expected_inputs, dtype), atol=1e-5, rtol=1e-5)
def test_getitem_tensor():
    """Substituting a tensor index for a free index variable equals direct indexing."""
    data = randn((5, 4, 3, 2))
    x = Tensor(data)
    i = Variable('i', bint(5))
    j = Variable('j', bint(4))
    k = Variable('k', bint(3))
    m = Variable('m', bint(2))

    def random_index(size, *batch_vars):
        # A random integer tensor in [0, size), batched over the given variables.
        batch = OrderedDict([(v.name, v.output) for v in batch_vars])
        return random_tensor(batch, bint(size))

    # One index position at a time, with no batch inputs.
    idx = random_index(5)
    assert_close(x[i](i=idx), x[idx])
    idx = random_index(4)
    assert_close(x[:, j](j=idx), x[:, idx])
    idx = random_index(3)
    assert_close(x[:, :, k](k=idx), x[:, :, idx])
    idx = random_index(2)
    assert_close(x[:, :, :, m](m=idx), x[:, :, :, idx])

    # Indices that themselves depend on earlier index variables.
    idx = random_index(j.dtype, i)
    assert_close(x[i, j](j=idx), x[i, idx])
    idx = random_index(k.dtype, i, j)
    assert_close(x[i, j, k](k=idx), x[i, j, idx])
def test_cons_hash():
    """Structurally equal terms are interned to a single object."""
    interned = [
        (Variable('x', bint(3)), Variable('x', bint(3))),
        (Variable('x', reals()), Variable('x', reals())),
        (Number(0, 3), Number(0, 3)),
        (Number(0.), Number(0.)),
    ]
    for first, second in interned:
        assert first is second

    # Same name/value but a different domain must yield a distinct term.
    assert Variable('x', reals()) is not Variable('x', bint(3))
    assert Number(0.) is not Number(0, 3)
def test_lambda_getitem():
    """Indexing by a free variable and Lambda-abstracting over it are inverses."""
    values = torch.randn(2)
    unindexed = Tensor(values)
    indexed = Tensor(values, OrderedDict(i=bint(2)))
    index = Variable('i', bint(2))

    assert unindexed[index] is indexed
    assert Lambda(index, indexed) is unindexed
def test_binomial_density(batch_shape, eager):
    """funsor's Binomial log-density matches torch's, both eagerly and via substitution."""
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(zip(batch_dims, map(bint, batch_shape)))
    max_count = 10

    @funsor.torch.function(reals(), reals(), reals(), reals())
    def binomial(total_count, probs, value):
        return torch.distributions.Binomial(total_count, probs).log_prob(value)

    check_funsor(binomial, {'total_count': reals(), 'probs': reals(), 'value': reals()}, reals())

    # total_count >= value by construction, so the density is well defined.
    value_data = random_tensor(inputs, bint(max_count)).data.float()
    extra_counts = random_tensor(inputs, bint(max_count)).data.float()
    total_count_data = value_data + extra_counts
    value = Tensor(value_data, inputs)
    total_count = Tensor(total_count_data, inputs)
    probs = Tensor(torch.rand(batch_shape), inputs)

    expected = binomial(total_count, probs, value)
    check_funsor(expected, inputs, reals())

    if eager:
        actual = dist.Binomial(total_count, probs, value)
    else:
        actual = dist.Binomial(total_count, probs, Variable('value', reals()))(value=value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl):
    """Chain two funsor einsums and compare each stage with pyro's einsum.

    The output of ``eqn1`` is fed as the first operand of a second einsum
    (``outputs1[0] + "," + eqn2``). Each stage is optionally passed through
    ``apply_optimizer`` (controlled by ``optimize1`` / ``optimize2``) before
    being evaluated with ``reinterpret``.
    """
    inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,))
    inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))

    # normalize the probs for ground-truth comparison
    # (each operand is made non-negative and sums to 1 along its last dim, so it
    # can be used as the ``probs`` of a Categorical below)
    operands1 = [operand.abs() / operand.abs().sum(-1, keepdim=True)
                 for operand in operands1]

    expected1 = pyro_einsum(eqn1, *operands1, backend=backend1, modulo_total=True)[0]
    expected2 = pyro_einsum(outputs1[0] + "," + eqn2, *([expected1] + operands2),
                            backend=backend2, modulo_total=True)[0]

    with interpretation(normalize):
        # Each operand becomes exp(Categorical log-prob) over its last index,
        # batched over its remaining indices.
        funsor_operands1 = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes1[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes1[inp[-1]]))).exp()
            for inp, operand in zip(inputs1, operands1)
        ]

        output1_naive = einsum_impl(eqn1, *funsor_operands1, backend=backend1)
        with interpretation(reflect):
            output1 = apply_optimizer(output1_naive) if optimize1 else output1_naive
        output2_naive = einsum_impl(outputs1[0] + "," + eqn2, *([output1] + funsor_operands2),
                                    backend=backend2)
        with interpretation(reflect):
            output2 = apply_optimizer(output2_naive) if optimize2 else output2_naive

    # Evaluate the (possibly lazy) outputs under the default interpretation.
    actual1 = reinterpret(output1)
    actual2 = reinterpret(output2)

    assert torch.allclose(expected1, actual1.data)
    assert torch.allclose(expected2, actual2.data)
def test_reduce_subset(dims, reduced_vars, op):
    """Reducing a torch-backed Tensor over a subset of inputs matches torch.

    Variables in ``reduced_vars`` that are not among ``dims`` are ignored.
    Reducing over none of ``dims`` must return the very same object.
    """
    reduced_vars = frozenset(reduced_vars)
    sizes = {'a': 3, 'b': 4, 'c': 5}
    shape = tuple(sizes[d] for d in dims)
    inputs = OrderedDict((d, bint(sizes[d])) for d in dims)
    data = torch.rand(shape) + 0.5

    is_boolean = op in [ops.and_, ops.or_]
    if is_boolean:
        data = data.byte()
    dtype = 2 if is_boolean else 'real'

    x = Tensor(data, inputs, dtype)
    actual = x.reduce(op, reduced_vars)
    expected_inputs = OrderedDict(
        (d, bint(sizes[d])) for d in dims if d not in reduced_vars)

    reduced_vars &= frozenset(dims)
    if not reduced_vars:
        assert actual is x
        return

    torch_reduce = REDUCE_OP_TO_TORCH[op]
    if reduced_vars != frozenset(dims):
        # Reduce the highest axes first so earlier positions stay valid.
        for pos in sorted(map(dims.index, reduced_vars), reverse=True):
            reduced = torch_reduce(data, pos)
            # torch.min / torch.max along a dim return (values, indices).
            data = reduced[0] if op in (ops.min, ops.max) else reduced
    elif op is ops.logaddexp:
        # work around missing torch.Tensor.logsumexp()
        data = data.reshape(-1).logsumexp(0)
    else:
        data = torch_reduce(data)
    check_funsor(actual, expected_inputs, Domain((), dtype))
    assert_close(actual, Tensor(data, expected_inputs, dtype), atol=1e-5, rtol=1e-5)
def test_smoke(expr, expected_type):
    """Evaluate ``expr`` against a fixed set of funsors and check the result type.

    ``expr`` is evaluated with ``eval`` in this function's scope, so it may
    refer to the locals ``g1``, ``g2``, ``shift``, ``i0``, ``x0`` and ``y0``.
    Those names must therefore not change.
    """
    shared_precision3 = [[1.0, 0.1, 0.2], [0.1, 1.0, 0.3], [0.2, 0.3, 1.0]]
    g1 = Gaussian(
        info_vec=numeric_array([[0.0, 0.1, 0.2], [2.0, 3.0, 4.0]]),
        precision=numeric_array([shared_precision3, shared_precision3]),
        inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    assert isinstance(g1, Gaussian)

    shared_precision2 = [[1.0, 0.2], [0.2, 1.0]]
    g2 = Gaussian(
        info_vec=numeric_array([[0.0, 0.1], [2.0, 3.0]]),
        precision=numeric_array([shared_precision2, shared_precision2]),
        inputs=OrderedDict([('i', bint(2)), ('y', reals(2))]))
    assert isinstance(g2, Gaussian)

    shift = Tensor(numeric_array([-1., 1.]), OrderedDict([('i', bint(2))]))
    assert isinstance(shift, Tensor)

    i0 = Number(1, 2)
    assert isinstance(i0, Number)

    x0 = Tensor(numeric_array([0.5, 0.6, 0.7]))
    assert isinstance(x0, Tensor)

    y0 = Tensor(numeric_array([[0.2, 0.3], [0.8, 0.9]]),
                inputs=OrderedDict([('i', bint(2))]))
    assert isinstance(y0, Tensor)

    result = eval(expr)
    assert isinstance(result, expected_type)
def test_smoke(expr, expected_type):
    """Evaluate ``expr`` against a fixed set of funsors and check the result type.

    ``expr`` is evaluated with ``eval`` in this function's scope, so it may
    refer to the locals ``dx``, ``dy``, ``t``, ``g``, ``i0`` and ``x0``.
    Those names must therefore not change.
    """
    dx = Delta('x', Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2))])))
    assert isinstance(dx, Delta)

    dy = Delta('y', Tensor(torch.randn(3, 4), OrderedDict([('j', bint(3))])))
    assert isinstance(dy, Delta)

    t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))]))
    assert isinstance(t, Tensor)

    shared_precision = [[1.0, 0.1, 0.2], [0.1, 1.0, 0.3], [0.2, 0.3, 1.0]]
    g = Gaussian(
        info_vec=torch.tensor([[0.0, 0.1, 0.2], [2.0, 3.0, 4.0]]),
        precision=torch.tensor([shared_precision, shared_precision]),
        inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    assert isinstance(g, Gaussian)

    i0 = Number(1, 2)
    assert isinstance(i0, Number)

    x0 = Tensor(torch.tensor([0.5, 0.6, 0.7]))
    assert isinstance(x0, Tensor)

    result = eval(expr)
    assert isinstance(result, expected_type)
def test_independent():
    """Independent stacks a per-``i`` funsor of ``x_i`` into one funsor of ``x``."""
    x_i = Variable('x_i', reals(4, 5))
    f = x_i + random_tensor(OrderedDict(i=bint(3)))
    assert f.inputs['x_i'] == reals(4, 5)
    assert f.inputs['i'] == bint(3)

    actual = Independent(f, 'x', 'i', 'x_i')
    assert actual.inputs['x'] == reals(3, 4, 5)
    assert 'i' not in actual.inputs

    # Reference: substitute the i-th slice of x for x_i, then sum over i.
    x = Variable('x', reals(3, 4, 5))
    expected = f(x_i=x['i']).reduce(ops.add, 'i')
    assert actual.inputs == expected.inputs
    assert actual.output == expected.output

    data = random_tensor(OrderedDict(), x.output)
    assert_close(actual(data), expected(data), atol=1e-5, rtol=1e-5)

    # Renaming 'x' must keep the Independent form. The second name checks
    # that it's ok for .reals_var and .diag_var to be the same.
    for new_name in ('y', 'x_i'):
        renamed = actual(x=new_name)
        assert isinstance(renamed, Independent)
        assert_close(renamed(**{new_name: data}), expected(x=data), atol=1e-5, rtol=1e-5)
def test_stack_subs():
    """Substitution into a Stack indexes, renames, or distributes over its parts."""
    x = Variable('x', reals())
    y = Variable('y', reals())
    z = Variable('z', reals())
    j = Variable('j', bint(3))
    yz = y * z
    parts = (Number(0), x, yz)

    f = Stack('i', parts)
    check_funsor(f, {'i': bint(3), 'x': reals(), 'y': reals(), 'z': reals()}, reals())

    # Substituting a concrete index selects the corresponding part.
    for pos, part in enumerate(parts):
        assert f(i=Number(pos, 3)) is part

    # Substituting a variable or a name renames the stacked dimension.
    assert f(i=j) is Stack('j', parts)
    assert f(i='j') is Stack('j', parts)
    assert f.reduce(ops.add, 'i') is Number(0) + x + yz

    # Substitutions for free variables distribute into the parts.
    assert f(x=0) is Stack('i', (Number(0), Number(0), yz))
    assert f(y=x) is Stack('i', (Number(0), x, x * z))
    assert f(x=0, y=x) is Stack('i', (Number(0), Number(0), x * z))
    assert f(x=0, y=x, i=Number(2, 3)) is x * z
    assert f(x=0, i=j) is Stack('j', (Number(0), Number(0), yz))
    assert f(x=0, i='j') is Stack('j', (Number(0), Number(0), yz))
def __getitem__(self, other):
    """Index into this funsor's output dimensions, numpy style.

    ``other`` is either a single index or a tuple mixing funsor-convertible
    indices, full slices ``:`` and at most one ``Ellipsis``. A single
    non-tuple index becomes ``Binary(ops.getitem, self, index)``. In a tuple,
    each non-slice index becomes a ``Binary(GetitemOp(offset), ...)`` that
    removes one output dimension, and each ``:`` skips one dimension.

    Raises:
        IndexError: if more than one ``Ellipsis`` is given, or if the indices
            around an ``Ellipsis`` outnumber the output dimensions.
        NotImplementedError: for any slice other than ``:``.
    """
    # A bare ``...`` or ``:`` is a full-dimension selector, not an index value;
    # route it through the tuple path instead of trying to convert it.
    if other is Ellipsis or isinstance(other, slice):
        other = (other,)
    if type(other) is not tuple:
        other = to_funsor(other, bint(self.output.shape[0]))
        return Binary(ops.getitem, self, other)

    # Handle Ellipsis slicing. Use identity checks only: funsor parts
    # overload ``==``, so list.index/count would build terms instead of bools.
    ellipsis_positions = [pos for pos, part in enumerate(other) if part is Ellipsis]
    if len(ellipsis_positions) > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")
    if ellipsis_positions:
        pos = ellipsis_positions[0]
        left, right = other[:pos], other[pos + 1:]
        missing = len(self.output.shape) - len(left) - len(right)
        if missing < 0:
            raise IndexError("too many indices for funsor output of shape {}"
                             .format(self.output.shape))
        other = left + (slice(None),) * missing + right

    # Handle each slice separately.
    result = self
    offset = 0
    for part in other:
        if isinstance(part, slice):
            if part != slice(None):
                raise NotImplementedError('TODO support nontrivial slicing')
            offset += 1
        else:
            # The indexed dimension disappears, so ``offset`` stays put and the
            # next part addresses the dimension that slid into this position.
            part = to_funsor(part, bint(result.output.shape[offset]))
            result = Binary(GetitemOp(offset), result, part)
    return result
def test_cat_simple(output):
    """Cat concatenates along 'i' and takes the union of the other inputs."""
    x = random_tensor(OrderedDict([
        ('i', bint(2)),
    ]), output)
    y = random_tensor(OrderedDict([
        ('i', bint(3)),
        ('j', bint(4)),
    ]), output)
    z = random_tensor(OrderedDict([
        ('i', bint(5)),
        ('k', bint(6)),
    ]), output)

    # Concatenating a single part is the identity.
    assert Cat('i', (x, )) is x
    assert Cat('i', (y, )) is y
    assert Cat('i', (z, )) is z

    xy = Cat('i', (x, y))
    assert isinstance(xy, Tensor)
    assert xy.inputs == OrderedDict([
        ('i', bint(2 + 3)),
        ('j', bint(4)),
    ])
    assert xy.output == output

    xyz = Cat('i', (x, y, z))
    assert isinstance(xyz, Tensor)
    assert xyz.inputs == OrderedDict([
        ('i', bint(2 + 3 + 5)),
        ('j', bint(4)),
        ('k', bint(6)),
    ])
    # Previously re-checked ``xy.output`` by mistake; check the 3-way result.
    assert xyz.output == output
def test_subs_lambda():
    """Substituting into a Lambda body must not capture the bound variable.

    ``ix`` has a free input named 'i', which is also the Lambda's bound
    name. The reference therefore rebinds the Lambda over 'j' instead.
    """
    z = Variable('z', reals())
    i = Variable('i', bint(5))
    ix = random_tensor(OrderedDict([('i', bint(5))]), reals())

    actual = Lambda(i, z)(z=ix)

    renamed_binder = i(i='j')
    substituted_body = z(z=ix)
    expected = Lambda(renamed_binder, substituted_body)
    check_funsor(actual, expected.inputs, expected.output)
    assert_close(actual, expected)
def test_getitem_variable():
    """Indexing a Tensor by free variables turns those dims into named inputs."""
    data = randn((5, 4, 3, 2))
    x = Tensor(data)
    i = Variable('i', bint(5))
    j = Variable('j', bint(4))

    by_i = Tensor(data, OrderedDict([('i', bint(5))]))
    by_ij = Tensor(data, OrderedDict([('i', bint(5)), ('j', bint(4))]))
    assert x[i] is by_i
    assert x[i, j] is by_ij
def mixed_sequential_sum_product(sum_op, prod_op, trans, time, step, num_segments=None):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
           .reduce(sum_op, "drop")

    by mixing parallel and serial scan algorithms over ``num_segments`` segments.

    The time axis is cut into ``num_segments`` equal segments. A first stage
    eliminates time inside every segment with ``naive_sequential_sum_product``,
    batched over the segments. A second stage eliminates across segments with
    ``sequential_sum_product``. ``num_segments=1`` is therefore purely
    serial, and ``num_segments >= duration`` is purely ``sequential_sum_product``.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    :param int num_segments: number of segments for the first stage. Defaults
        to the full duration, which means one step per segment.
    :return: the result of eliminating ``time`` from ``trans`` as described above.
    """
    # Keep the Variable as ``time_var``; from here on ``time`` is its name string
    # and ``duration`` is the number of time steps.
    time_var, time, duration = time, time.name, time.output.size
    num_segments = duration if num_segments is None else num_segments
    assert num_segments > 0 and duration > 0

    # handle unevenly sized segments by chopping off the final segment and calling mixed_sequential_sum_product again
    if duration % num_segments and duration - duration % num_segments > 0:
        # Trailing ``duration % num_segments`` steps that do not fill a segment.
        remainder = trans(**{time: Slice(time, duration - duration % num_segments, duration, 1, duration)})
        # Leading steps, whose count is a multiple of num_segments.
        initial = trans(**{time: Slice(time, 0, duration - duration % num_segments, 1, duration)})
        initial_eliminated = mixed_sequential_sum_product(
            sum_op, prod_op, initial, Variable(time, bint(duration - duration % num_segments)), step,
            num_segments=num_segments)
        # Treat the eliminated prefix as one extra leading step in front of the
        # remainder, then eliminate the short sequence serially.
        final = Cat(time, (Stack(time, (initial_eliminated,)), remainder))
        final_eliminated = naive_sequential_sum_product(
            sum_op, prod_op, final, Variable(time, bint(1 + duration % num_segments)), step)
        return final_eliminated

    # handle degenerate cases that reduce to a single stage
    if num_segments == 1:
        return naive_sequential_sum_product(sum_op, prod_op, trans, time_var, step)
    if num_segments >= duration:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, step)

    # break trans into num_segments segments of equal length
    segment_length = duration // num_segments
    segments = [trans(**{time: Slice(time, i * segment_length, (i + 1) * segment_length, 1, duration)})
                for i in range(num_segments)]

    # Stage 1: serial elimination within each segment. The segments are
    # stacked along a new ``<time>__SEGMENTED`` input and handled as a batch.
    first_stage_result = naive_sequential_sum_product(
        sum_op, prod_op, Stack(time + "__SEGMENTED", tuple(segments)),
        Variable(time, bint(segment_length)), step)

    # Stage 2: eliminate the segment dimension with sequential_sum_product.
    second_stage_result = sequential_sum_product(
        sum_op, prod_op, first_stage_result,
        Variable(time + "__SEGMENTED", bint(num_segments)), step)

    return second_stage_result
def test_slice_lambda():
    """Indexing a doubly-abstracted Lambda by its inner variable leaves the outer abstraction."""
    z = Variable('z', reals())
    i = Variable('i', bint(5))
    j = Variable('j', bint(7))

    over_j = Lambda(j, z)
    over_j_then_i = Lambda(j, Lambda(i, z))
    sliced = over_j_then_i[:, i]
    check_funsor(sliced, over_j.inputs, over_j.output)
def test_subs_reduce():
    """Substituting into a lazy Reduce must avoid capturing the reduced variable.

    ``ix`` depends on 'i', which is also summed out by the Reduce. The
    reference renames it to 'i2', reduces, then renames back.
    """
    x = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())
    ix = random_tensor(OrderedDict([('i', bint(3))]), bint(2))
    ix_renamed = ix(i='i2')
    summed_vars = frozenset({"i"})

    with interpretation(reflect):
        lazy_sum = x.reduce(ops.add, summed_vars)
    actual = lazy_sum(j=ix)

    expected = x(j=ix_renamed).reduce(ops.add, summed_vars)(i2='i')
    assert_close(actual, expected)
def test_sample_subs_smoke():
    """Sampling from a lazily substituted funsor yields the expected inputs."""
    x = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())
    with interpretation(reflect):
        z = x(i=1)

    # Only the non-torch backends take an explicit PRNG key.
    if get_backend() == "torch":
        rng_key = None
    else:
        rng_key = np.array([0, 1], dtype=np.uint32)

    actual = z.sample(frozenset({"j"}), OrderedDict({"i": bint(4)}), rng_key=rng_key)
    check_funsor(actual, {"j": bint(2), "i": bint(4)}, reals())
def test_reduce_logaddexp_gaussian_lazy():
    """Partially reducing a Gaussian+Tensor contraction then finishing agrees with a full reduce."""
    gaussian = random_gaussian(OrderedDict(i=bint(3), a=reals(2)))
    discrete = random_tensor(OrderedDict(i=bint(3), b=bint(2)))

    joint = gaussian + discrete
    assert isinstance(joint, Contraction)
    assert set(joint.inputs) == {'a', 'b', 'i'}

    # NOTE: the type of the partial reduction is deliberately not asserted
    # (an isinstance(..., Reduce) check was left disabled).
    partial = joint.reduce(ops.logaddexp, 'i')
    assert set(partial.inputs) == {'a', 'b'}
    assert_close(joint.reduce(ops.logaddexp), partial.reduce(ops.logaddexp))
def test_reduce_logaddexp_deltas_lazy():
    """Partially reducing a sum of Deltas then finishing agrees with a full reduce."""
    point_a = Tensor(torch.randn(3, 2), OrderedDict(i=bint(3)))
    delta_a = Delta('a', point_a)
    point_b = Tensor(torch.randn(3), OrderedDict(i=bint(3)))
    delta_b = Delta('b', point_b)

    joint = delta_a + delta_b
    assert isinstance(joint, Delta)
    assert set(joint.inputs) == {'a', 'b', 'i'}

    # NOTE: the type of the partial reduction is deliberately not asserted
    # (an isinstance(..., Reduce) check was left disabled).
    partial = joint.reduce(ops.logaddexp, 'i')
    assert set(partial.inputs) == {'a', 'b'}
    assert_close(joint.reduce(ops.logaddexp), partial.reduce(ops.logaddexp))
def test_cons_hash():
    """Structurally equal terms, including normalized Slice forms, are interned."""
    interned = [
        (Variable('x', bint(3)), Variable('x', bint(3))),
        (Variable('x', reals()), Variable('x', reals())),
        (Number(0, 3), Number(0, 3)),
        (Number(0.), Number(0.)),
        (Slice('x', 10), Slice('x', 10)),
        # Defaults are normalized: start=0, step=1.
        (Slice('x', 10), Slice('x', 0, 10)),
        (Slice('x', 2, 10, 1), Slice('x', 2, 10)),
    ]
    for first, second in interned:
        assert first is second

    distinct = [
        (Variable('x', reals()), Variable('x', bint(3))),
        (Number(0.), Number(0, 3)),
        (Slice('x', 10, 10), Slice('x', 0, 10)),
    ]
    for first, second in distinct:
        assert first is not second
def test_syntactic_sugar():
    """Integrate accepts reduced vars as a name, a Variable, or a set/frozenset of either."""
    i = Variable("i", bint(3))
    log_measure = random_tensor(OrderedDict(i=bint(3)))
    integrand = random_tensor(OrderedDict(i=bint(3)))
    expected = (log_measure.exp() * integrand).reduce(ops.add, "i")

    spellings = ("i", {"i"}, frozenset(["i"]), i, {i}, frozenset([i]))
    for reduced in spellings:
        assert_close(Integrate(log_measure, integrand, reduced), expected)
def test_reduce_logaddexp_deltas_discrete_lazy():
    """Partially reducing Deltas plus a discrete Tensor, then finishing, agrees with a full reduce."""
    batch = OrderedDict(i=bint(3))
    delta_a = Delta('a', Tensor(randn(3, 2), batch))
    delta_b = Delta('b', Tensor(randn(3), batch))
    weights = Tensor(randn(3), batch)

    joint = delta_a + delta_b + weights
    assert isinstance(joint, Contraction)
    assert set(joint.inputs) == {'a', 'b', 'i'}

    # NOTE: the type of the partial reduction is deliberately not asserted
    # (an isinstance(..., Reduce) check was left disabled).
    partial = joint.reduce(ops.logaddexp, 'i')
    assert set(partial.inputs) == {'a', 'b'}
    assert_close(joint.reduce(ops.logaddexp), partial.reduce(ops.logaddexp))
def test_sarkka_bilmes_example_0(duration):
    """Simplest Sarkka-Bilmes case: a time-indexed factor over one discrete variable 'a'."""
    trans = random_tensor(OrderedDict([
        ("time", bint(duration)),
        ("a", bint(3)),
    ]))
    _check_sarkka_bilmes(trans, {"a": bint(3)}, frozenset())
def test_binomial_sample(with_lazy, batch_shape, sample_inputs):
    """Sampling from a funsor Binomial, built lazily or eagerly, passes the sample checks."""
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(zip(batch_dims, map(bint, batch_shape)))
    max_count = 10

    total_count = Tensor(random_tensor(inputs, bint(max_count)).data.float(), inputs)
    probs = Tensor(torch.rand(batch_shape), inputs)

    chosen = lazy if with_lazy else eager
    with interpretation(chosen):
        funsor_dist = dist.Binomial(total_count, probs)

    _check_sample(funsor_dist, sample_inputs, inputs, skip_grad=True)
def test_getitem_number_2_inputs():
    """Output indexing of a Tensor with two batch inputs skips the leading input dims of ``data``."""
    data = randn((3, 4, 5, 4, 3, 2))
    inputs = OrderedDict([('i', bint(3)), ('j', bint(4))])
    x = Tensor(data, inputs)
    batch_prefix = (slice(None), slice(None))  # the 'i' and 'j' input dims

    # A plain (non-tuple) integer index goes through its own code path.
    assert_close(x[2], Tensor(data[batch_prefix + (2,)], inputs))

    tuple_indices = [
        (slice(None), 1),
        (2, 1),
        (2, slice(None), 1),
        (3, Ellipsis),
        (3, 2, Ellipsis),
        (Ellipsis, 1),
        (Ellipsis, 2, 1),
        (3, Ellipsis, 1),
    ]
    for index in tuple_indices:
        assert_close(x[index], Tensor(data[batch_prefix + index], inputs))