# Imports reconstructed for this excerpt; module paths follow Pyro's test
# suite layout and may need adjusting.
import itertools

import torch
from opt_einsum import contract

from pyro.ops.einsum.adjoint import require_backward
from tests.common import assert_equal


def torch_exp(x):
    # Helper assumed by test_einsum below: map a log-space operand back to
    # the probability domain.
    return x.exp()


def test_marginal(equation):
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # Check forward pass: the marginal backend must agree with plain
    # log-space contraction on the forward value.
    for x in operands:
        require_backward(x)
    actual = contract(equation, *operands, backend="pyro.ops.einsum.torch_marginal")
    expected = contract(equation, *operands, backend="pyro.ops.einsum.torch_log")
    assert_equal(expected, actual)

    # Check backward pass: each operand's backward result must equal the
    # log-space marginal over that operand's dims.
    actual._pyro_backward()
    for input_, operand in zip(inputs, operands):
        marginal_equation = ",".join(inputs) + "->" + input_
        expected = contract(
            marginal_equation, *operands, backend="pyro.ops.einsum.torch_log"
        )
        actual = operand._pyro_backward_result
        assert_equal(expected, actual)
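# A minimal sketch (equation and shapes assumed, not from the test suite) of
# the pattern test_marginal exercises: run the marginal backend forward, then
# trigger the adjoint backward pass to read off a per-operand marginal.
def _marginal_sketch():
    x, y = torch.randn(2, 2), torch.randn(2, 2)
    x._pyro_dims, y._pyro_dims = "ab", "bc"
    for t in (x, y):
        require_backward(t)
    z = contract("ab,bc->c", x, y, backend="pyro.ops.einsum.torch_marginal")
    z._pyro_backward()
    # x._pyro_backward_result now holds the log-marginal over dims "ab",
    # i.e. contract("ab,bc->ab", x, y, backend="pyro.ops.einsum.torch_log").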
def sumproduct(self, terms, dims):
    # Method excerpted from a class whose `_backend` attribute names an
    # opt_einsum backend. Sums out `dims`; all remaining dims, sorted,
    # form the output.
    inputs = [term._pyro_dims for term in terms]
    output = "".join(sorted(set("".join(inputs)) - set(dims)))
    equation = ",".join(inputs) + "->" + output
    term = contract(equation, *terms, backend=self._backend)
    term._pyro_dims = output
    return term
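# Worked example of the equation sumproduct builds (inputs assumed for
# illustration): for terms with dims "ab" and "bc", summing out {"b"}
# keeps "a" and "c".
def _sumproduct_equation_sketch():
    inputs = ["ab", "bc"]
    dims = {"b"}
    output = "".join(sorted(set("".join(inputs)) - set(dims)))
    assert output == "ac"
    assert ",".join(inputs) + "->" + output == "ab,bc->ac"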
def test_einsum(equation, min_size, infinite):
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    symbols = sorted(set(equation) - set(",->"))
    sizes = dict(zip(symbols, itertools.count(min_size)))
    shapes = [torch.Size(tuple(sizes[dim] for dim in dims)) for dims in inputs]
    operands = [
        torch.full(shape, -float("inf")) if infinite else torch.randn(shape)
        for shape in shapes
    ]
    # Log-space einsum must agree with exponentiating, contracting with the
    # stock torch backend, and taking the log of the result.
    expected = contract(
        equation, *(torch_exp(x) for x in operands), backend="torch"
    ).log()
    actual = contract(equation, *operands, backend="pyro.ops.einsum.torch_log")
    assert_equal(actual, expected)
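# Two concrete instances (equation and shapes assumed) of the identity that
# test_einsum checks. For a single-operand reduction, log-space einsum is just
# a logsumexp; and all-(-inf) operands (the `infinite` case) must contract to
# -inf rather than NaN, since log(sum(exp(-inf))) = log(0) = -inf.
def _torch_log_sketch():
    x = torch.randn(3, 4)
    actual = contract("ab->a", x, backend="pyro.ops.einsum.torch_log")
    assert_equal(actual, x.logsumexp(dim=1))

    y = torch.full((3, 4), -float("inf"))
    actual = contract("ab->a", y, backend="pyro.ops.einsum.torch_log")
    assert (actual == -float("inf")).all()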
def test_shape(backend, equation):
    backend = "pyro.ops.einsum.torch_{}".format(backend)
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    symbols = sorted(set(equation) - set(",->"))
    sizes = dict(zip(symbols, itertools.count(2)))
    input_shapes = [torch.Size(sizes[dim] for dim in dims) for dims in inputs]
    operands = [torch.randn(shape) for shape in input_shapes]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    expected = contract(equation, *operands, backend="pyro.ops.einsum.torch_log")
    actual = contract(equation, *operands, backend=backend)
    if backend.endswith("map"):
        assert actual.dtype == expected.dtype
        assert actual.shape == expected.shape
    else:
        assert_equal(actual, expected)

    # check backward pass
    actual._pyro_backward()
    for input_, x in zip(inputs, operands):
        backward_result = x._pyro_backward_result
        if backend.endswith("marginal"):
            assert backward_result.shape == x.shape
        else:
            contract_dims = set(input_) - set(output)
            if contract_dims:
                assert backward_result.size(0) == len(contract_dims)
                assert set(backward_result._pyro_dims[1:]) == set(output)
                for sample, dim in zip(
                    backward_result, backward_result._pyro_sample_dims
                ):
                    assert sample.min() >= 0
                    assert sample.max() < sizes[dim]
            else:
                assert backward_result is None
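# The tests above take parametrized arguments, so they are presumably driven
# by pytest.mark.parametrize decorators that this excerpt omits. An assumed,
# illustrative parametrization is sketched below; the equation list is a
# guess, while the backend names mirror the endswith() branches in test_shape.
#
#     EQUATIONS = ["a->", "ab->a", "ab,bc->ac", "ab,bc,cd->"]
#
#     @pytest.mark.parametrize("equation", EQUATIONS)
#     @pytest.mark.parametrize("backend", ["marginal", "map", "sample"])
#     def test_shape(backend, equation): ...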