def test_marginal(equation):
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    actual = contract(equation, *operands,
                      backend="pyro.ops.einsum.torch_marginal")
    expected = contract(equation, *operands,
                        backend="pyro.ops.einsum.torch_log")
    assert_equal(expected, actual)

    # check backward pass
    actual._pyro_backward()
    for input_, operand in zip(inputs, operands):
        marginal_equation = ",".join(inputs) + "->" + input_
        expected = contract(marginal_equation, *operands,
                            backend="pyro.ops.einsum.torch_log")
        actual = operand._pyro_backward_result
        assert_equal(expected, actual)

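# Illustrative sketch (not part of the original tests): what the adjoint
# marginal backend computes for one concrete equation. Assumes the same
# pyro.ops API exercised by test_marginal above.
def _example_marginal_backward():
    import torch
    from pyro.ops.adjoint import require_backward
    from pyro.ops.einsum import contract

    x, y = torch.randn(2, 2), torch.randn(2, 2)
    x._pyro_dims, y._pyro_dims = "ab", "bc"
    require_backward(x)
    z = contract("ab,bc->", x, y, backend="pyro.ops.einsum.torch_marginal")
    z._pyro_backward()
    # The backward pass deposits the (unnormalized) log-marginal over x's
    # dims "ab", i.e. the value of contract("ab,bc->ab", ...) in log space.
    assert x._pyro_backward_result.shape == x.shape
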
def test_adjoint_marginal(equation, plates):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')
    operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    actual, = ubersum(equation, *operands, plates=plates, modulo_total=True,
                      backend='pyro.ops.einsum.torch_marginal')
    expected, = ubersum(equation, *operands, plates=plates, modulo_total=True,
                        backend='pyro.ops.einsum.torch_log')
    assert_equal(expected, actual)

    # check backward pass
    actual._pyro_backward()
    for input_, operand in zip(inputs, operands):
        marginal_equation = ','.join(inputs) + '->' + input_
        expected, = ubersum(marginal_equation, *operands, plates=plates,
                            modulo_total=True,
                            backend='pyro.ops.einsum.torch_log')
        actual = operand._pyro_backward_result
        assert_equal(expected, actual)

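# Illustrative sketch (not part of the original tests) of the ubersum call
# pattern used above with a plate dim: dims named in `plates` are treated as
# batch dims rather than contracted, and modulo_total=True permits results
# that are correct only up to a normalizing constant.
def _example_ubersum_marginal():
    import torch
    from pyro.ops.adjoint import require_backward
    from pyro.ops.contract import ubersum

    x = torch.randn(3, 2)
    x._pyro_dims = "ia"  # "i" is declared as a plate below
    require_backward(x)
    z, = ubersum("ia->", x, plates="i", modulo_total=True,
                 backend="pyro.ops.einsum.torch_marginal")
    z._pyro_backward()
    assert x._pyro_backward_result.shape == x.shape
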
def test_adjoint_shape(backend, equation, plates):
    backend = "pyro.ops.einsum.torch_{}".format(backend)
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # run forward-backward algorithm
    for x in operands:
        require_backward(x)
    (result,) = ubersum(equation, *operands, plates=plates, modulo_total=True,
                        backend=backend)
    result._pyro_backward()
    for input_, x in zip(inputs, operands):
        backward_result = x._pyro_backward_result
        contract_dims = set(input_) - set(output) - set(plates)
        if contract_dims:
            assert backward_result is not None
        else:
            assert backward_result is None

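# Illustrative sketch (not part of the original tests) of the None contract
# asserted above: an operand whose dims all survive to the output (or are
# plates) has nothing contracted away, so its backward result stays None.
def _example_no_contracted_dims():
    import torch
    from pyro.ops.adjoint import require_backward
    from pyro.ops.contract import ubersum

    x = torch.randn(2)
    x._pyro_dims = "a"
    require_backward(x)
    z, = ubersum("a->a", x, plates="", modulo_total=True,
                 backend="pyro.ops.einsum.torch_sample")
    z._pyro_backward()
    assert x._pyro_backward_result is None
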
def compute_expectation(self, costs):
    """
    Returns a differentiable expected cost, summing over costs at given ordinals.

    :param dict costs: A dict mapping ordinals to lists of cost tensors
    :returns: a scalar expected cost
    :rtype: torch.Tensor or float
    """
    # Share computation across all cost terms.
    with shared_intermediates() as cache:
        ring = MarginalRing(cache=cache)
        expected_cost = 0.
        for ordinal, cost_terms in costs.items():
            log_factors = self._get_log_factors(ordinal)
            scale = math.exp(sum(x for x in log_factors
                                 if not isinstance(x, torch.Tensor)))
            log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

            # Collect log_prob terms to query for marginal probability.
            queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
            for log_factor in log_factors:
                key = frozenset(log_factor._pyro_dims)
                if queries.get(key, False) is None:
                    queries[key] = log_factor
            # Ensure a query exists for each cost term.
            for cost in cost_terms:
                key = frozenset(cost._pyro_dims)
                if queries[key] is None:
                    query = cost.new_zeros(cost.shape)
                    query._pyro_dims = cost._pyro_dims
                    log_factors.append(query)
                    queries[key] = query

            # Perform sum-product contraction. Note that plates never need to be
            # product-contracted due to our plate-based dependency ordering.
            sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
            for query in queries.values():
                require_backward(query)
            root = ring.sumproduct(log_factors, sum_dims)
            root._pyro_backward()
            probs = {key: query._pyro_backward_result.exp()
                     for key, query in queries.items()}

            # Aggregate prob * cost terms.
            for cost in cost_terms:
                key = frozenset(cost._pyro_dims)
                prob = probs[key]
                prob._pyro_dims = queries[key]._pyro_dims

                mask = prob > 0
                if torch._C._get_tracing_state() or not mask.all():
                    mask._pyro_dims = prob._pyro_dims
                    cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                    prob = prob[mask]
                    cost = cost[mask]
                else:
                    cost, prob = packed.broadcast_all(cost, prob)
                expected_cost = expected_cost + scale * torch.tensordot(
                    prob, cost, prob.dim())

    LAST_CACHE_SIZE[0] = count_cached_ops(cache)
    return expected_cost

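# Illustrative sketch (not part of the original source) of the arithmetic this
# method performs for a single cost term: turn log factors into a marginal
# probability, then contract against the cost. Plain torch only; `log_factor`
# and `cost` are hypothetical stand-ins for the packed tensors above.
def _example_expected_cost():
    import torch

    log_factor = torch.randn(3)  # unnormalized log p(z) over 3 states of z
    cost = torch.randn(3)        # a cost term indexed by z
    prob = (log_factor - log_factor.logsumexp(0)).exp()
    # sum_z p(z) * cost(z), as in the torch.tensordot(prob, cost, prob.dim())
    # call above (there, prob comes from the marginal ring's backward pass).
    return torch.tensordot(prob, cost, 1)
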
def _get_trace(self, model, guide, args, kwargs):
    model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs)

    # Mark all sample sites with require_backward to gather enumerated
    # sites and adjust cond_indep_stack of all sample sites.
    for node in model_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            log_prob = node["packed"]["unscaled_log_prob"]
            require_backward(log_prob)

    self._saved_state = model, model_trace, guide_trace, args, kwargs
    return model_trace, guide_trace

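# Illustrative sketch (not part of the original source) of the marker that
# require_backward() installs: a leaf object under ._pyro_backward that a
# later adjoint backward pass uses to deposit ._pyro_backward_result.
def _example_require_backward_marker():
    import torch
    from pyro.ops.adjoint import require_backward

    log_prob = torch.zeros(2)
    log_prob._pyro_dims = "a"
    require_backward(log_prob)
    assert hasattr(log_prob, "_pyro_backward")  # leaf marker installed
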
def test_shape(backend, equation):
    backend = "pyro.ops.einsum.torch_{}".format(backend)
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    symbols = sorted(set(equation) - set(",->"))
    sizes = dict(zip(symbols, itertools.count(2)))
    input_shapes = [torch.Size(sizes[dim] for dim in dims) for dims in inputs]
    operands = [torch.randn(shape) for shape in input_shapes]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    expected = contract(equation, *operands,
                        backend="pyro.ops.einsum.torch_log")
    actual = contract(equation, *operands, backend=backend)
    if backend.endswith("map"):
        assert actual.dtype == expected.dtype
        assert actual.shape == expected.shape
    else:
        assert_equal(actual, expected)

    # check backward pass
    actual._pyro_backward()
    for input_, x in zip(inputs, operands):
        backward_result = x._pyro_backward_result
        if backend.endswith("marginal"):
            assert backward_result.shape == x.shape
        else:
            contract_dims = set(input_) - set(output)
            if contract_dims:
                assert backward_result.size(0) == len(contract_dims)
                assert set(backward_result._pyro_dims[1:]) == set(output)
                for sample, dim in zip(backward_result,
                                       backward_result._pyro_sample_dims):
                    assert sample.min() >= 0
                    assert sample.max() < sizes[dim]
            else:
                assert backward_result is None

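# Illustrative sketch (not part of the original tests) of the backward-result
# format for the map/sample backends checked above: a stack of one index
# tensor per contracted dim, with the dim names kept in ._pyro_sample_dims.
def _example_map_backward():
    import torch
    from pyro.ops.adjoint import require_backward
    from pyro.ops.einsum import contract

    x = torch.randn(2, 3)
    x._pyro_dims = "ab"
    require_backward(x)
    z = contract("ab->b", x, backend="pyro.ops.einsum.torch_map")
    z._pyro_backward()
    sample = x._pyro_backward_result
    assert sample.size(0) == 1    # one contracted dim: "a"
    assert int(sample.max()) < 2  # indices into dim "a", which has size 2
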
def _forward_backward(*operands):
    # First we request backward results on each input operand.
    # This is the pyro.ops.adjoint equivalent of torch's .requires_grad_().
    for operand in operands:
        require_backward(operand)

    # Next we run the forward pass.
    results = einsum(equation, *operands, backend=backend, **kwargs)

    # Then we run a backward pass.
    for result in results:
        result._pyro_backward()

    # Finally we retrieve results from the ._pyro_backward_result attribute
    # that has been set on each input operand. If you only want results on a
    # subset of operands, you can call require_backward() on only those.
    results = []
    for x in operands:
        results.append(x._pyro_backward_result)
        x._pyro_backward_result = None
    return tuple(results)

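# Self-contained walk-through (not part of the original source) of the same
# four steps for one concrete case. The equation, backend, and operands here
# are illustrative only; the helper above closes over them instead.
def _example_forward_backward():
    import torch
    from pyro.ops.adjoint import require_backward
    from pyro.ops.contract import einsum

    x, y = torch.randn(2, 2), torch.randn(2, 2)
    x._pyro_dims, y._pyro_dims = "ab", "bc"
    for operand in (x, y):
        require_backward(operand)                # 1. request backward results
    results = einsum("ab,bc->", x, y, modulo_total=True,
                     backend="pyro.ops.einsum.torch_marginal")  # 2. forward
    for result in results:
        result._pyro_backward()                  # 3. backward
    return x._pyro_backward_result, y._pyro_backward_result  # 4. read results
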
def _sample_posterior_from_trace(model, enum_trace, temperature, *args, **kwargs):
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"], mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1)
                    or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)

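# Hedged usage sketch (not part of the original source) of the public entry
# point that drives this machinery, pyro.infer.infer_discrete; the model
# below is a hypothetical example.
def _example_infer_discrete():
    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import config_enumerate, infer_discrete

    def model(data):
        z = pyro.sample("z", dist.Bernoulli(0.3))
        pyro.sample("x", dist.Normal(z, 1.), obs=data)
        return z

    # temperature=1 draws z from p(z | x); temperature=0 gives a MAP estimate.
    sampled_model = infer_discrete(config_enumerate(model),
                                   first_available_dim=-1, temperature=1)
    return sampled_model(torch.tensor(0.5))
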
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    # For internal use by infer_discrete.

    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()
    sum_dims = set()
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            sum_dims.update(log_prob._pyro_dims)
            for frame in node["cond_indep_stack"]:
                if frame.vectorized:
                    sum_dims.remove(plate_to_symbol[frame.name])
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not f.vectorized or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)

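# Hedged sketch (a hypothetical mirror, not the original _make_ring) of the
# temperature dispatch assumed above: 0 selects a max-product ring for MAP
# decoding, 1 a sum-product ring whose backward pass draws posterior samples.
# The pyro.ops.rings import locations are assumptions.
def _make_ring_sketch(temperature):
    from pyro.ops.rings import MapRing, SampleRing  # assumed locations
    if temperature == 0:
        return MapRing()      # Viterbi-like MAP decoding
    elif temperature == 1:
        return SampleRing()   # forward-filter backward-sample
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample)")
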