def test_optimized_plated_einsum(equation, plates, backend):
    """Check the funsor plated ``einsum`` against ``pyro_einsum.einsum``.

    For equations shorter than 10 characters the result is additionally
    cross-checked against ``naive_plated_einsum``. The test then verifies
    the result's shape, values, and the dtype recorded for each output dim.
    """
    _, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum.einsum(equation, *operands, plates=plates, backend=backend)[0]
    actual = einsum(equation, *funsor_operands, plates=plates, backend=backend)

    # Only cross-check against the naive implementation for short equations.
    if len(equation) < 10:
        actual_naive = naive_plated_einsum(equation, *funsor_operands,
                                           plates=plates, backend=backend)
        assert_close(actual, actual_naive)

    # Separate asserts so a failure pinpoints which expectation broke.
    assert isinstance(actual, funsor.Tensor)
    assert len(outputs) == 1
    if len(outputs[0]) > 0:
        # Put the result's dims in the same order as the expected output.
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for output_dim in output:
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_einsum_linear(equation, plates):
    """Check that linear-space ``einsum`` agrees with log-space ``ubersum``.

    The test is skipped when ``ubersum`` raises ``NotImplementedError``.
    Each output is compared in log space via ``assert_equal``.
    """
    _, outputs, log_operands, _ = make_example(equation)
    linear_operands = [log_term.exp() for log_term in log_operands]
    try:
        log_expected = ubersum(equation, *log_operands, plates=plates, modulo_total=True)
        expected = [log_part.exp() for log_part in log_expected]
    except NotImplementedError:
        pytest.skip()

    # einsum() is in linear space whereas ubersum() is in log space.
    actual = einsum(equation, *linear_operands, plates=plates, modulo_total=True)
    assert isinstance(actual, tuple)
    assert len(actual) == len(outputs)
    for output, expected_part, actual_part in zip(outputs, expected, actual):
        expected_log = expected_part.log()
        actual_log = actual_part.log()
        message = "For output '{}':\nExpected:\n{}\nActual:\n{}".format(
            output, expected_part.detach().cpu(), actual_part.detach().cpu())
        assert_equal(expected_log, actual_log, msg=message)
def _compute_tmc_estimate(model_trace, guide_trace):
    """
    Use :func:`~pyro.ops.contract.einsum` to compute the Tensor Monte Carlo
    estimate of the marginal likelihood given parallel-sampled traces.

    :param model_trace: model trace; its ``plate_to_symbol`` values are used
        as plate symbols.
    :param guide_trace: guide trace; its ``plate_to_symbol`` values are used
        as plate symbols.
    :returns: the single contracted result of ``einsum``, or ``0.`` when
        there are no factors to contract.
    """
    # factors
    log_factors = _compute_tmc_factors(model_trace, guide_trace)
    log_factors += _compute_dice_factors(model_trace, guide_trace)
    if not log_factors:
        return 0.

    # loss: contract every factor down to a scalar (empty output side).
    eqn = ",".join([f._pyro_dims for f in log_factors]) + "->"

    # Union of plate symbols from both traces. Sort them so the resulting
    # string is deterministic: iteration order of a set of str depends on
    # hash randomization (PYTHONHASHSEED) and can differ between runs.
    plate_symbols = set(model_trace.plate_to_symbol.values())
    plate_symbols.update(guide_trace.plate_to_symbol.values())
    plates = "".join(sorted(plate_symbols))

    tmc, = einsum(eqn, *log_factors, plates=plates,
                  backend="pyro.ops.einsum.torch_log",
                  modulo_total=False)
    return tmc
def _forward_backward(*operands):
    """Run ``einsum`` forward, then a ``pyro.ops.adjoint`` backward pass.

    Returns a tuple with one entry per input operand: the value of that
    operand's ``_pyro_backward_result`` after the backward pass. Each
    operand's attribute is reset to ``None`` once it has been read.
    """
    # Request backward results on every input operand. This is the
    # pyro.ops.adjoint equivalent of torch's .requires_grad_(). If you only
    # want results on a subset of operands, call require_backward() on
    # only those.
    for operand in operands:
        require_backward(operand)

    # Forward pass.
    forward_results = einsum(equation, *operands, backend=backend, **kwargs)

    # Then we run a backward pass from each forward result.
    for forward_result in forward_results:
        forward_result._pyro_backward()

    # Read each operand's ._pyro_backward_result in order, clearing it
    # immediately after it is read.
    collected = []
    for operand in operands:
        collected.append(operand._pyro_backward_result)
        operand._pyro_backward_result = None
    return tuple(collected)
def _einsum(*operands):
    """Evaluate the captured ``equation`` on ``operands``.

    Always uses the ``pyro.ops.einsum.torch_log`` backend and forwards the
    captured ``kwargs`` unchanged.
    """
    log_backend = 'pyro.ops.einsum.torch_log'
    return einsum(equation, *operands, backend=log_backend, **kwargs)
def _einsum(*operands):
    """Evaluate the captured ``equation`` on ``operands``.

    The captured ``kwargs`` are forwarded to ``einsum`` unchanged.
    """
    result = einsum(equation, *operands, **kwargs)
    return result