Exemplo n.º 1
0
def test_marginal(equation):
    """Check the ``torch_marginal`` einsum backend against ``torch_log``.

    The forward result of ``torch_marginal`` must equal the ``torch_log``
    contraction of the same operands. After ``_pyro_backward()`` is called
    on that result, each operand's ``_pyro_backward_result`` must equal the
    ``torch_log`` contraction of all operands down to that operand's own
    dimensions.

    :param str equation: einsum equation such as ``"ab,bc->ac"``; every
        dimension is given size 2.
    """
    lhs, _ = equation.split("->")
    input_dims = lhs.split(",")

    operands = []
    for dims in input_dims:
        tensor = torch.randn(torch.Size((2, ) * len(dims)))
        # The pyro backends read dimension names from this attribute.
        tensor._pyro_dims = dims
        operands.append(tensor)

    # Forward pass: marginal backend must agree with the log-space reference.
    for tensor in operands:
        require_backward(tensor)
    marginal_result = contract(equation,
                               *operands,
                               backend="pyro.ops.einsum.torch_marginal")
    log_result = contract(equation,
                          *operands,
                          backend="pyro.ops.einsum.torch_log")
    assert_equal(log_result, marginal_result)

    # Backward pass: each operand should receive its log-space marginal.
    marginal_result._pyro_backward()
    joined_inputs = ",".join(input_dims)
    for dims, tensor in zip(input_dims, operands):
        expected_marginal = contract(joined_inputs + "->" + dims,
                                     *operands,
                                     backend="pyro.ops.einsum.torch_log")
        assert_equal(expected_marginal, tensor._pyro_backward_result)
Exemplo n.º 2
0
def test_einsum(equation, min_size):
    """Compare the ``torch_log`` backend with an exp -> contract -> log path.

    Each distinct dimension symbol gets its own size, counting up from
    ``min_size`` in sorted symbol order. The log-space contraction of random
    operands must match ``log`` of the plain ``torch`` contraction of their
    exponentials.
    """
    lhs, _ = equation.split('->')
    input_dims = lhs.split(',')
    symbols = sorted(set(equation) - set(',->'))
    size_of = {symbol: size
               for symbol, size in zip(symbols, itertools.count(min_size))}

    operands = [torch.randn(torch.Size(tuple(size_of[d] for d in dims)))
                for dims in input_dims]

    exp_operands = [torch_exp(x) for x in operands]
    expected = contract(equation, *exp_operands, backend='torch').log()
    actual = contract(equation, *operands, backend='pyro.ops.einsum.torch_log')
    assert_equal(actual, expected)
Exemplo n.º 3
0
 def sumproduct(self, terms, dims):
     """Contract ``terms`` together, summing out the dimensions in ``dims``.

     Each term must carry a ``_pyro_dims`` string naming its dimensions.
     The result keeps every remaining dimension, in sorted order, and is
     tagged with a matching ``_pyro_dims`` string. The contraction is
     delegated to ``contract`` with ``self._backend``.
     """
     term_dims = [t._pyro_dims for t in terms]
     kept = set("".join(term_dims)) - set(dims)
     output = "".join(sorted(kept))
     result = contract("{}->{}".format(",".join(term_dims), output),
                       *terms,
                       backend=self._backend)
     result._pyro_dims = output
     return result
Exemplo n.º 4
0
def test_einsum(equation, min_size, infinite):
    """Compare the ``torch_log`` backend with an exp -> contract -> log path.

    Each distinct dimension symbol gets its own size, counting up from
    ``min_size`` in sorted symbol order. When ``infinite`` is true every
    operand is filled with ``-inf``; otherwise operands are random normal.
    The log-space contraction must match ``log`` of the plain ``torch``
    contraction of the exponentiated operands.
    """
    lhs, _ = equation.split("->")
    input_dims = lhs.split(",")
    symbols = sorted(set(equation) - set(",->"))
    size_of = dict(zip(symbols, itertools.count(min_size)))

    operands = []
    for dims in input_dims:
        shape = torch.Size(tuple(size_of[d] for d in dims))
        if infinite:
            operands.append(torch.full(shape, -float("inf")))
        else:
            operands.append(torch.randn(shape))

    exp_operands = [torch_exp(x) for x in operands]
    expected = contract(equation, *exp_operands, backend="torch").log()
    actual = contract(equation, *operands, backend="pyro.ops.einsum.torch_log")
    assert_equal(actual, expected)
Exemplo n.º 5
0
def test_shape(backend, equation):
    """Check forward results and backward results of a pyro einsum backend.

    :param str backend: backend suffix, expanded to
        ``"pyro.ops.einsum.torch_<backend>"``.
    :param str equation: einsum equation; dimension symbols are sized
        2, 3, ... in sorted symbol order.

    Forward: a backend ending in ``"map"`` is only compared with
    ``torch_log`` on dtype and shape; any other backend must match it
    exactly.

    Backward: for a backend ending in ``"marginal"``, each operand's
    ``_pyro_backward_result`` must have the operand's shape. For other
    backends, an operand with no dimensions summed out must get ``None``.
    Otherwise the backward result's leading size equals the number of
    summed-out dimensions. Its remaining ``_pyro_dims`` must be the output
    dimensions, and every sampled index must lie in ``[0, size)``.
    """
    backend = "pyro.ops.einsum.torch_{}".format(backend)
    lhs, output = equation.split("->")
    input_dims = lhs.split(",")
    symbols = sorted(set(equation) - set(",->"))
    size_of = dict(zip(symbols, itertools.count(2)))

    operands = []
    for dims in input_dims:
        tensor = torch.randn(torch.Size(size_of[d] for d in dims))
        # The pyro backends read dimension names from this attribute.
        tensor._pyro_dims = dims
        operands.append(tensor)

    # Forward pass.
    for tensor in operands:
        require_backward(tensor)
    expected = contract(equation,
                        *operands,
                        backend="pyro.ops.einsum.torch_log")
    actual = contract(equation, *operands, backend=backend)
    if not backend.endswith("map"):
        assert_equal(actual, expected)
    else:
        assert actual.dtype == expected.dtype
        assert actual.shape == expected.shape

    # Backward pass.
    actual._pyro_backward()
    is_marginal = backend.endswith("marginal")
    output_dims = set(output)
    for dims, tensor in zip(input_dims, operands):
        result = tensor._pyro_backward_result
        if is_marginal:
            assert result.shape == tensor.shape
            continue
        summed_dims = set(dims) - output_dims
        if not summed_dims:
            assert result is None
            continue
        assert result.size(0) == len(summed_dims)
        assert set(result._pyro_dims[1:]) == output_dims
        for sample, dim in zip(result, result._pyro_sample_dims):
            assert sample.min() >= 0
            assert sample.max() < size_of[dim]