Example #1
# Likely imports for this excerpt (module paths assumed, not verified):
import torch

import funsor
import pyro.ops.contract as pyro_einsum
from funsor.einsum import einsum, naive_plated_einsum
from funsor.testing import assert_close, make_einsum_example


def test_optimized_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    expected = pyro_einsum.einsum(equation,
                                  *operands,
                                  plates=plates,
                                  backend=backend)[0]
    actual = einsum(equation, *funsor_operands, plates=plates, backend=backend)

    if len(equation) < 10:
        actual_naive = naive_plated_einsum(equation,
                                           *funsor_operands,
                                           plates=plates,
                                           backend=backend)
        assert_close(actual, actual_naive)

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
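
For reference, here is a minimal sketch, assuming a recent funsor release with the torch backend, of the API the final assertions rely on: named inputs, align() to reorder them, and integer domains whose dtype equals their size.

from collections import OrderedDict

import torch
import funsor

funsor.set_backend("torch")

# Hypothetical tensor with two named integer inputs "a" and "b".
x = funsor.Tensor(torch.randn(2, 3),
                  OrderedDict(a=funsor.Bint[2], b=funsor.Bint[3]))
assert x.inputs["a"].dtype == 2                  # a Bint's dtype is its size
assert x.align(("b", "a")).data.shape == (3, 2)  # align() reorders named dims
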
Example #2
# Likely imports for this excerpt; make_example and assert_equal are
# helpers defined in the surrounding test file (paths assumed):
import pytest

from pyro.ops.contract import einsum, ubersum


def test_einsum_linear(equation, plates):
    inputs, outputs, log_operands, sizes = make_example(equation)
    operands = [x.exp() for x in log_operands]

    try:
        log_expected = ubersum(equation,
                               *log_operands,
                               plates=plates,
                               modulo_total=True)
        expected = [x.exp() for x in log_expected]
    except NotImplementedError:
        pytest.skip()

    # einsum() is in linear space whereas ubersum() is in log space.
    actual = einsum(equation, *operands, plates=plates, modulo_total=True)
    assert isinstance(actual, tuple)
    assert len(actual) == len(outputs)
    for output, expected_part, actual_part in zip(outputs, expected, actual):
        assert_equal(
            expected_part.log(),
            actual_part.log(),
            msg="For output '{}':\nExpected:\n{}\nActual:\n{}".format(
                output,
                expected_part.detach().cpu(),
                actual_part.detach().cpu()),
        )
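
The equivalence checked here reduces to the standard linear/log-space identity; a self-contained illustration in plain torch, independent of pyro:

import torch

x = torch.randn(5)
# Summing in linear space then taking the log equals logsumexp in log
# space, which is why exp()-ed operands through einsum() match
# the exponentiated ubersum() results above.
assert torch.allclose(x.exp().sum().log(), torch.logsumexp(x, dim=0))
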
Example #3
# Likely import for this excerpt; the _compute_* helpers are private
# functions defined alongside it in the same module (path assumed):
from pyro.ops.contract import einsum


def _compute_tmc_estimate(model_trace, guide_trace):
    """
    Use :func:`~pyro.ops.contract.einsum` to compute the Tensor Monte Carlo
    estimate of the marginal likelihood given parallel-sampled traces.
    """
    # factors
    log_factors = _compute_tmc_factors(model_trace, guide_trace)
    log_factors += _compute_dice_factors(model_trace, guide_trace)

    if not log_factors:
        return 0.

    # loss
    eqn = ",".join([f._pyro_dims for f in log_factors]) + "->"
    plates = "".join(frozenset().union(list(model_trace.plate_to_symbol.values()),
                                       list(guide_trace.plate_to_symbol.values())))
    tmc, = einsum(eqn, *log_factors, plates=plates,
                  backend="pyro.ops.einsum.torch_log",
                  modulo_total=False)
    return tmc
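
Note the equation built above ends in "->", so every symbol is contracted away and a single scalar estimate comes back. A toy illustration with hypothetical factor dims:

# Two hypothetical factors with packed dims "ab" and "bc" yield a full
# contraction to a scalar output.
dims = ["ab", "bc"]
eqn = ",".join(dims) + "->"
assert eqn == "ab,bc->"
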
Example #4
        def _forward_backward(*operands):
            # First we request backward results on each input operand.
            # This is the pyro.ops.adjoint equivalent of torch's .requires_grad_().
            for operand in operands:
                require_backward(operand)

            # Next we run the forward pass.
            results = einsum(equation, *operands, backend=backend, **kwargs)

            # Then we run the backward pass.
            for result in results:
                result._pyro_backward()

            # Finally we retrieve results from the ._pyro_backward_result attribute
            # that has been set on each input operand. If you only want results on a
            # subset of operands, you can call require_backward() on only those.
            results = []
            for x in operands:
                results.append(x._pyro_backward_result)
                x._pyro_backward_result = None

            return tuple(results)
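
The require_backward / _pyro_backward / _pyro_backward_result pattern mirrors ordinary torch autograd; as a rough analogy in plain torch (this is not pyro's adjoint machinery):

import torch

x = torch.randn(3, requires_grad=True)  # analogue of require_backward(x)
y = (x * 2).sum()                       # forward pass
y.backward()                            # analogue of result._pyro_backward()
print(x.grad)                           # analogue of x._pyro_backward_result
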
Example #5
def _einsum(*operands):
    return einsum(equation,
                  *operands,
                  backend='pyro.ops.einsum.torch_log',
                  **kwargs)
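
The torch_log backend evaluates the contraction in log space. A self-contained sketch, using plain torch and a hypothetical "ab,bc->ac" contraction, of the logsumexp computation such a backend performs:

import torch

x, y = torch.randn(2, 3), torch.randn(3, 4)
# "ab,bc->ac" in log space: add broadcast operands, then logsumexp over
# the shared dimension "b".
log_result = torch.logsumexp(x.unsqueeze(-1) + y.unsqueeze(0), dim=1)
# Cross-check against the same contraction done in linear space.
expected = torch.einsum("ab,bc->ac", x.exp(), y.exp()).log()
assert torch.allclose(log_result, expected)
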
Example #6
def _einsum(*operands):
    return einsum(equation, *operands, **kwargs)
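
This wrapper is plain partial application of einsum with the enclosing scope's equation and kwargs; with torch.einsum as a stand-in, an equivalent formulation:

from functools import partial

import torch

matmul = partial(torch.einsum, "ab,bc->ac")  # hypothetical fixed equation
out = matmul(torch.randn(2, 3), torch.randn(3, 4))
assert out.shape == (2, 4)
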