def compute_expectation(self, costs):
    """
    Returns a differentiable expected cost, summing over costs at given ordinals.

    :param dict costs: A dict mapping ordinals to lists of cost tensors
    :returns: a scalar expected cost
    :rtype: torch.Tensor or float
    """
    # Share computation across all cost terms.
    with shared_intermediates() as cache:
        ring = MarginalRing(cache=cache)
        expected_cost = 0.
        for ordinal, cost_terms in costs.items():
            log_factors = self._get_log_factors(ordinal)
            scale = math.exp(sum(x for x in log_factors if not isinstance(x, torch.Tensor)))
            log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

            # Collect log_prob terms to query for marginal probability.
            queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
            for log_factor in log_factors:
                key = frozenset(log_factor._pyro_dims)
                if queries.get(key, False) is None:
                    queries[key] = log_factor
            # Ensure a query exists for each cost term.
            for cost in cost_terms:
                key = frozenset(cost._pyro_dims)
                if queries[key] is None:
                    query = cost.new_zeros(cost.shape)
                    query._pyro_dims = cost._pyro_dims
                    log_factors.append(query)
                    queries[key] = query

            # Perform sum-product contraction. Note that plates never need to be
            # product-contracted due to our plate-based dependency ordering.
            sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
            for query in queries.values():
                require_backward(query)
            root = ring.sumproduct(log_factors, sum_dims)
            root._pyro_backward()
            probs = {key: query._pyro_backward_result.exp()
                     for key, query in queries.items()}

            # Aggregate prob * cost terms.
            for cost in cost_terms:
                key = frozenset(cost._pyro_dims)
                prob = probs[key]
                prob._pyro_dims = queries[key]._pyro_dims

                mask = prob > 0
                if torch._C._get_tracing_state() or not mask.all():
                    mask._pyro_dims = prob._pyro_dims
                    cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                    prob = prob[mask]
                    cost = cost[mask]
                else:
                    cost, prob = packed.broadcast_all(cost, prob)

                expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())

    LAST_CACHE_SIZE[0] = count_cached_ops(cache)
    return expected_cost
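# Minimal standalone sketch (plain PyTorch, hypothetical shapes; not part of the
# library) of the final aggregation step above: with a probability tensor and a
# cost tensor packed into the same dims, torch.tensordot over all of prob's dims
# is the full inner product sum(prob * cost), i.e. the expected cost.
import torch

prob = torch.tensor([[0.1, 0.4], [0.2, 0.3]])
cost = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
expected = torch.tensordot(prob, cost, prob.dim())  # scalar expected cost
assert torch.isclose(expected, (prob * cost).sum())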
def einsum(equation, *operands):
    """
    Forward-max-sum backward-argmax implementation of einsum.
    This assumes all operands have a ``._pyro_dims`` attribute set.
    """
    equation = packed.rename_equation(equation, *operands)
    inputs, output = equation.split('->')
    any_requires_backward = any(hasattr(x, '_pyro_backward') for x in operands)

    contract_dims = ''.join(
        sorted(set().union(*(x._pyro_dims for x in operands)) - set(output)))
    dims = output + contract_dims
    result = reduce(operator.add, packed.broadcast_all(*operands, dims=dims))
    argmax = None  # work around lack of pytorch support for zero-sized tensors
    if contract_dims:
        output_shape = result.shape[:len(output)]
        contract_shape = result.shape[len(output):]
        result, argmax = result.reshape(output_shape + (-1,)).max(-1)
        if any_requires_backward:
            argmax = unflatten(argmax, output, contract_dims, contract_shape)
    elif result is operands[0]:
        result = result[...]  # create a new object
    result._pyro_dims = output
    assert result.dim() == len(result._pyro_dims)

    if any_requires_backward:
        result._pyro_backward = _EinsumBackward(operands, argmax)
    return result
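# Minimal sketch (plain PyTorch, hypothetical operands; not part of the library)
# of the forward max-sum contraction above, for an equation like "ab,bc->a":
# operands are added in log space after broadcasting, then the contracted dims
# are flattened and maximized over, recording the flat argmax for the backward
# (argmax / Viterbi-style) pass.
import torch

x = torch.randn(2, 3)                            # dims "ab"
y = torch.randn(3, 4)                            # dims "bc"
joint = x.unsqueeze(-1) + y.unsqueeze(0)         # broadcast sum over dims "abc"
result, argmax = joint.reshape(2, -1).max(-1)    # max over contracted dims "bc"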
def compute_site_dice_factor(site):
    log_denom = 0
    log_prob = site["packed"]["score_parts"].score_function  # not scaled by subsampling
    dims = getattr(log_prob, "_pyro_dims", "")
    if site["infer"].get("enumerate"):
        num_samples = site["infer"].get("num_samples")
        if num_samples is not None:  # site was multiply sampled
            if not is_identically_zero(log_prob):
                log_prob = log_prob - log_prob.detach()
            log_prob = log_prob - math.log(num_samples)
            if not isinstance(log_prob, torch.Tensor):
                log_prob = torch.tensor(float(log_prob), device=site["value"].device)
            log_prob._pyro_dims = dims
            # I don't know why the following broadcast is needed, but it makes tests pass:
            log_prob, _ = packed.broadcast_all(log_prob, site["packed"]["log_prob"])
        elif site["infer"]["enumerate"] == "sequential":
            log_denom = math.log(site["infer"].get("_enum_total", num_samples))
    else:  # site was monte carlo sampled
        if not is_identically_zero(log_prob):
            log_prob = log_prob - log_prob.detach()
            log_prob._pyro_dims = dims

    return log_prob, log_denom
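# Minimal sketch (plain PyTorch; not part of the library) of the DiCE-style
# factor built above: log_prob - log_prob.detach() is identically zero in value,
# but its gradient equals the gradient of log_prob, which is exactly what the
# surrogate objective needs.
import torch

logits = torch.randn(3, requires_grad=True)
log_prob = torch.log_softmax(logits, -1)[0]
dice_factor = log_prob - log_prob.detach()  # value 0, gradient of log_prob
assert dice_factor.item() == 0.0
dice_factor.backward()                      # logits.grad == d log_prob / d logits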
def __init__(self, guide_trace, ordering):
    log_denom = defaultdict(float)  # avoids double-counting when sequentially enumerating
    log_probs = defaultdict(list)  # accounts for upstream probabilities
    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample":
            continue

        log_prob = site["packed"]["score_parts"].score_function  # not scaled by subsampling
        dims = getattr(log_prob, "_pyro_dims", "")
        ordinal = ordering[name]
        if site["infer"].get("enumerate"):
            num_samples = site["infer"].get("num_samples")
            if num_samples is not None:  # site was multiply sampled
                if not is_identically_zero(log_prob):
                    log_prob = log_prob - log_prob.detach()
                log_prob = log_prob - math.log(num_samples)
                if not isinstance(log_prob, torch.Tensor):
                    log_prob = site["value"].new_tensor(log_prob)
                log_prob._pyro_dims = dims
                # I don't know why the following broadcast is needed, but it makes tests pass:
                log_prob, _ = packed.broadcast_all(log_prob, site["packed"]["log_prob"])
            elif site["infer"]["enumerate"] == "sequential":
                log_denom[ordinal] += math.log(site["infer"]["_enum_total"])
        else:  # site was monte carlo sampled
            if is_identically_zero(log_prob):
                continue
            log_prob = log_prob - log_prob.detach()
            log_prob._pyro_dims = dims
        log_probs[ordinal].append(log_prob)

    self.log_denom = log_denom
    self.log_probs = log_probs
def process(self, message):
    output = self.output
    operands = list(self.operands)
    contract_dims = ''.join(
        sorted(set().union(*(x._pyro_dims for x in operands)) - set(output)))
    batch_dims = output

    # Slice down operands before combining terms.
    sample2 = message
    if sample2 is not None:
        for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
            batch_dims = batch_dims.replace(dim, '')
            for i, x in enumerate(operands):
                if dim in x._pyro_dims:
                    index._pyro_dims = sample2._pyro_dims[1:]
                    x = packed.gather(x, index, dim)
                    operands[i] = x

    # Combine terms.
    dims = batch_dims + contract_dims
    logits = reduce(operator.add, packed.broadcast_all(*operands, dims=dims))

    # Sample.
    sample1 = None  # work around lack of pytorch support for zero-sized tensors
    if contract_dims:
        output_shape = logits.shape[:len(batch_dims)]
        contract_shape = logits.shape[len(batch_dims):]
        flat_logits = logits.reshape(output_shape + (-1,))
        flat_sample = dist.Categorical(logits=flat_logits).sample()
        sample1 = unflatten(flat_sample, batch_dims, contract_dims, contract_shape)

    # Cut down samples to pass on to subsequent steps.
    return einsum_backward_sample(self.operands, sample1, sample2)
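# Minimal sketch (plain torch.distributions, hypothetical shapes; not part of
# the library) of the sampling step above: flatten the contracted dims into a
# single categorical axis, draw one joint index per batch element, then split
# it back into per-dim indices by hand (the role played by ``unflatten`` above).
import torch
import torch.distributions as dist

logits = torch.randn(2, 3, 4)               # batch dim "a", contract dims "bc"
flat_logits = logits.reshape(2, -1)         # shape (a, b*c)
flat_sample = dist.Categorical(logits=flat_logits).sample()  # shape (a,)
b_idx, c_idx = flat_sample // 4, flat_sample % 4             # joint index -> (b, c)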
def test_broadcast_all(shapes):
    inputs, dim_to_symbol, symbol_to_dim = make_inputs(shapes)
    packed_inputs = [packed.pack(x, dim_to_symbol) for x in inputs]
    packed_outputs = packed.broadcast_all(*packed_inputs)
    actual = tuple(packed.unpack(x, symbol_to_dim) for x in packed_outputs)
    expected = broadcast_all(*inputs) if inputs else []
    assert len(actual) == len(expected)
    for a, e in zip(actual, expected):
        assert_equal(a, e)
def einsum_backward_sample(operands, sample1, sample2):
    """
    Cuts down samples to pass on to subsequent steps.
    This is used in various ``_EinsumBackward.__call__()`` methods.
    This assumes all operands have a ``._pyro_dims`` attribute set.
    """
    # Combine upstream sample with sample at this site.
    if sample1 is None:
        sample = sample2
    elif sample2 is None:
        sample = sample1
    else:
        # Slice sample1 down based on choices in sample2.
        assert set(sample1._pyro_sample_dims).isdisjoint(sample2._pyro_sample_dims)
        sample_dims = sample1._pyro_sample_dims + sample2._pyro_sample_dims
        for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
            if dim in sample1._pyro_dims:
                index._pyro_dims = sample2._pyro_dims[1:]
                sample1 = packed.gather(sample1, index, dim)

        # Concatenate the two samples.
        parts = packed.broadcast_all(sample1, sample2)
        sample = torch.cat(parts)
        sample._pyro_dims = parts[0]._pyro_dims
        sample._pyro_sample_dims = sample_dims
        assert sample.dim() == len(sample._pyro_dims)
        if not torch._C._get_tracing_state():
            assert sample.size(0) == len(sample._pyro_sample_dims)

    # Select sample dimensions to pass on to downstream sites.
    for x in operands:
        if not hasattr(x, '_pyro_backward'):
            continue
        if sample is None:
            yield x._pyro_backward, None
            continue
        x_sample_dims = set(x._pyro_dims) & set(sample._pyro_sample_dims)
        if not x_sample_dims:
            yield x._pyro_backward, None
            continue
        if x_sample_dims == set(sample._pyro_sample_dims):
            yield x._pyro_backward, sample
            continue
        x_sample_dims = ''.join(sorted(x_sample_dims))
        x_sample = sample[[sample._pyro_sample_dims.index(dim) for dim in x_sample_dims]]
        x_sample._pyro_dims = sample._pyro_dims
        x_sample._pyro_sample_dims = x_sample_dims
        assert x_sample.dim() == len(x_sample._pyro_dims)
        if not torch._C._get_tracing_state():
            assert x_sample.size(0) == len(x_sample._pyro_sample_dims)
        yield x._pyro_backward, x_sample
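# Minimal sketch (plain PyTorch, hypothetical dims; not part of the library) of
# the final slicing above: a packed sample stacks one row of indices per
# enumerated dim along dim 0, and an operand that mentions only dim "b"
# receives just that row.
import torch

sample = torch.stack([torch.tensor([0, 2, 1]),    # indices for dim "b"
                      torch.tensor([1, 0, 3])])   # indices for dim "c"
sample_dims = "bc"
x_dims = "ab"                                     # operand mentions only "b"
keep = [sample_dims.index(d) for d in x_dims if d in sample_dims]
x_sample = sample[keep]                           # shape (1, 3): only the "b" row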