def test_einsum(equation):
    """Check that ``Einsum`` on ``Tensor`` funsors matches ``ops.einsum``.

    One random tensor is drawn per comma-separated input term of
    ``equation``; its shape is looked up from a fixed size table keyed by
    the dimension letters ``a``, ``b`` and ``c``.

    :param str equation: An explicit einsum equation of the form
        ``"<in1>,<in2>,...-><out>"`` using only the letters ``a``, ``b``, ``c``.
    """
    dim_sizes = {'a': 2, 'b': 3, 'c': 4}

    # Only the left-hand side is needed to build the operands; the output
    # spec is consumed by the einsum implementations themselves.
    lhs, _ = equation.split('->')

    tensors = []
    for dims in lhs.split(','):
        shape = tuple(dim_sizes[d] for d in dims)
        tensors.append(randn(shape))
    funsors = tuple(Tensor(t) for t in tensors)

    # Reference result comes from the raw backend einsum on the same data.
    expected = Tensor(ops.einsum(equation, *tensors))
    actual = Einsum(equation, funsors)
    assert_close(actual, expected, atol=1e-5, rtol=None)
def test_mvn_affine_einsum():
    """Substitute an einsum-based affine expression into a lazy MVN funsor.

    The value ``Einsum("abc,bc->a", c, x) + y`` is plugged into the funsor
    built from ``random_mvn((), 3)`` under the ``lazy`` interpretation, and
    the result is handed to ``_check_mvn_affine`` together with concrete
    data for the free variables ``x`` and ``y``.
    """
    coeff = Tensor(torch.randn(3, 2, 2))
    x_var = Variable('x', reals(2, 2))
    y_var = Variable('y', reals())
    data = {
        'x': Tensor(torch.randn(2, 2)),
        'y': Tensor(torch.randn(())),
    }

    with interpretation(lazy):
        mvn = dist_to_funsor(random_mvn((), 3))
        # NOTE(review): operands are passed positionally here, whereas the
        # other einsum tests in view pass them as a single tuple -- confirm
        # which call form the Einsum in scope expects.
        affine_value = Einsum("abc,bc->a", coeff, x_var) + y_var
        substituted = mvn(value=affine_value)

    _check_mvn_affine(substituted, data)
def test_batched_einsum(equation, batch1, batch2):
    """Check ``Einsum`` on funsors that carry named batch inputs.

    Each operand is a ``random_tensor`` whose inputs are built from the
    names in ``batch1`` / ``batch2`` and whose output shape comes from the
    corresponding input term of ``equation``. The reference result is
    computed by aligning the operands, expanding each one to the joint
    batch shape, and running ``ops.einsum`` with an ellipsis-prefixed
    version of the equation.

    :param str equation: Einsum equation ``"<in1>,<in2>-><out>"`` over the
        letters found in the size table below.
    :param batch1: Iterable of batch input names for the first operand.
    :param batch2: Iterable of batch input names for the second operand.
    """
    lhs, output = equation.split('->')
    operand_dims = lhs.split(',')
    sizes = dict(a=2, b=3, c=4, i=5, j=6)

    # One ordered name -> bint(size) mapping per operand.
    batch_inputs = [
        OrderedDict((name, bint(sizes[name])) for name in names)
        for names in (batch1, batch2)
    ]

    funsors = []
    for batch_input, dims in zip(batch_inputs, operand_dims):
        output_shape = tuple(sizes[d] for d in dims)
        funsors.append(random_tensor(batch_input, reals(*output_shape)))

    actual = Einsum(equation, tuple(funsors))

    # Build the reference: prefix every term with '...' so the leading
    # (aligned) batch dimensions are carried through by the backend einsum.
    batched_equation = ','.join('...' + dims for dims in operand_dims) + '->...' + output
    aligned_inputs, aligned_tensors = align_tensors(*funsors)
    batch_shape = tuple(v.size for v in aligned_inputs.values())
    expanded = [
        ops.expand(tensor, batch_shape + funsor.shape)
        for tensor, funsor in zip(aligned_tensors, funsors)
    ]
    expected = Tensor(ops.einsum(batched_equation, *expanded), aligned_inputs)

    assert_close(actual, expected, atol=1e-5, rtol=None)
def test_extract_affine(expr):
    """Check that ``extract_affine`` decomposes an affine funsor correctly.

    The expression is evaluated, decomposed into a constant plus one
    ``(coefficient, equation)`` pair per real-valued input, and then
    reassembled with random substitutions; the reassembled value must
    match direct evaluation of the original expression.

    :param str expr: Python source of a funsor expression.
    """
    # eval() is only acceptable here because ``expr`` is expected to be a
    # trusted, test-authored string; never feed it external input.
    term = eval(expr)
    assert is_affine(term)
    assert isinstance(term, (Unary, Contraction, Einsum))

    # Keep only the inputs whose domain dtype is 'real', preserving order.
    real_inputs = OrderedDict()
    for input_name, domain in term.inputs.items():
        if domain.dtype == 'real':
            real_inputs[input_name] = domain

    const, coeffs = extract_affine(term)
    assert isinstance(const, Tensor)
    assert const.shape == term.shape
    assert list(coeffs) == list(real_inputs)
    for name, pair in coeffs.items():
        coeff, eqn = pair
        assert isinstance(name, str)
        assert isinstance(coeff, Tensor)
        assert isinstance(eqn, str)

    # Ground every real input with a random tensor that has no inputs.
    subs = {
        input_name: random_tensor(OrderedDict(), domain)
        for input_name, domain in real_inputs.items()
    }
    expected = term(**subs)
    assert isinstance(expected, Tensor)

    # Reassemble: constant + sum of einsum(coefficient, substituted input).
    linear_terms = (
        Einsum(eqn, (coeff, subs[k]))
        for k, (coeff, eqn) in coeffs.items()
    )
    actual = const + sum(linear_terms)
    assert isinstance(actual, Tensor)
    assert_close(actual, expected)