def process(self, message):
    """
    Backward-sampling step for an adjoint einsum contraction.

    ``message`` is either ``None`` or an upstream sample tensor carrying
    ``._pyro_dims`` / ``._pyro_sample_dims`` attributes (assumed set by the
    caller -- TODO confirm). Each operand is first sliced down along the
    upstream sample dims, the operands are then summed into a single packed
    logits tensor, and one categorical sample is drawn jointly over all
    contracted dims. Returns the generator produced by
    ``einsum_backward_sample``, which routes sub-samples to the operands'
    own backward methods.
    """
    result_dims = self.output
    terms = list(self.operands)
    every_dim = set().union(*(term._pyro_dims for term in terms))
    contract_dims = ''.join(sorted(every_dim - set(result_dims)))
    batch_dims = result_dims

    # Slice operands down along upstream sample dims before combining terms.
    sample2 = message
    if sample2 is not None:
        for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
            batch_dims = batch_dims.replace(dim, '')
            for pos, term in enumerate(terms):
                if dim not in term._pyro_dims:
                    continue
                index._pyro_dims = sample2._pyro_dims[1:]
                terms[pos] = packed.gather(term, index, dim)

    # Combine terms into a single packed tensor of logits.
    dims = batch_dims + contract_dims
    logits = reduce(operator.add, packed.broadcast_all(*terms, dims=dims))

    # Sample all contracted dims jointly via one flattened Categorical.
    sample1 = None  # work around lack of pytorch support for zero-sized tensors
    if contract_dims:
        split = len(batch_dims)
        flat_logits = logits.reshape(logits.shape[:split] + (-1,))
        flat_sample = dist.Categorical(logits=flat_logits).sample()
        sample1 = unflatten(flat_sample, batch_dims, contract_dims,
                            logits.shape[split:])

    # Cut down samples to pass on to subsequent steps.
    return einsum_backward_sample(self.operands, sample1, sample2)
def einsum_backward_sample(operands, sample1, sample2):
    """
    Cuts down samples to pass on to subsequent steps.
    This is used in various ``_EinsumBackward.__call__()`` methods.
    This assumes all operands have a ``._pyro_dims`` attribute set.

    Yields ``(backward, sub_sample)`` pairs, one per operand that has a
    ``._pyro_backward`` attribute; ``sub_sample`` is ``None`` when the operand
    shares no sample dims with the combined sample.
    """
    # Merge the sample drawn at this site with the upstream sample.
    if sample1 is None:
        combined = sample2
    elif sample2 is None:
        combined = sample1
    else:
        # The two samples must cover disjoint sample dims.
        assert set(sample1._pyro_sample_dims).isdisjoint(
            sample2._pyro_sample_dims)
        combined_dims = sample1._pyro_sample_dims + sample2._pyro_sample_dims
        # Slice sample1 down based on choices in sample2.
        for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
            if dim in sample1._pyro_dims:
                index._pyro_dims = sample2._pyro_dims[1:]
                sample1 = packed.gather(sample1, index, dim)
        # Concatenate the two samples along the leading (sample-dims) axis.
        pieces = packed.broadcast_all(sample1, sample2)
        combined = torch.cat(pieces)
        combined._pyro_dims = pieces[0]._pyro_dims
        combined._pyro_sample_dims = combined_dims
        assert combined.dim() == len(combined._pyro_dims)
        if not torch._C._get_tracing_state():
            assert combined.size(0) == len(combined._pyro_sample_dims)

    # Select sample dimensions to pass on to downstream sites.
    for operand in operands:
        if not hasattr(operand, '_pyro_backward'):
            continue
        if combined is None:
            yield operand._pyro_backward, None
            continue
        shared = set(operand._pyro_dims) & set(combined._pyro_sample_dims)
        if not shared:
            # Operand is independent of the sampled dims.
            yield operand._pyro_backward, None
        elif shared == set(combined._pyro_sample_dims):
            # Operand depends on every sampled dim; pass the whole sample.
            yield operand._pyro_backward, combined
        else:
            # Slice out just the rows for the dims this operand depends on.
            shared = ''.join(sorted(shared))
            rows = [combined._pyro_sample_dims.index(dim) for dim in shared]
            part = combined[rows]
            part._pyro_dims = combined._pyro_dims
            part._pyro_sample_dims = shared
            assert part.dim() == len(part._pyro_dims)
            if not torch._C._get_tracing_state():
                assert part.size(0) == len(part._pyro_sample_dims)
            yield operand._pyro_backward, part
def _sample_posterior_from_trace(model, enum_trace, temperature, *args, **kwargs):
    """
    Runs a forward-backward pass over the packed log-prob tensors of
    ``enum_trace``, draws samples for the enumerated sites, and replays
    ``model`` against the resulting collapsed trace.

    :param model: the model originally used to produce ``enum_trace``.
    :param enum_trace: a trace whose nodes carry ``node["packed"]`` tensors
        (i.e. ``.pack_tensors()`` has already been applied).
    :param temperature: forwarded to ``_make_ring``; presumably selects MAP
        versus posterior sampling -- TODO confirm against callers.
    :returns: whatever ``model(*args, **kwargs)`` returns under replay.
    """
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # The ordinal is the set of symbols of nontrivial plates enclosing the site.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"], mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            # Skip terms that share no dims with the set of dims to be summed out.
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    log_probs = cost_terms  # NOTE: aliases cost_terms; the loop below mutates it.
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a unique sentinel marking queries not yet reached
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Record the first ordinal at which each query was resolved.
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack: keep only frames whose plate
                # symbol survives in this site's ordinal.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1) or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    """
    For internal use by ``infer_discrete``: enumerates the model's discrete
    sites into a packed trace, runs a forward-backward pass over the packed
    log-probs, and replays ``model`` against the resulting collapsed trace.

    :param model: a stochastic function.
    :param first_available_dim: passed to ``EnumerateMessenger`` to choose
        where enumeration dims start.
    :param temperature: forwarded to ``_make_ring``; presumably selects MAP
        versus posterior sampling -- TODO confirm against callers.
    :returns: whatever ``model(*args, **kwargs)`` returns under replay.
    """
    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()
    sum_dims = set()
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # The ordinal is the set of symbols of vectorized plates enclosing the site.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            sum_dims.update(log_prob._pyro_dims)
            # Plate dims are batch dims, not dims to be summed out.
            # NOTE(review): .remove() assumes each vectorized plate symbol is
            # present in sum_dims here -- confirm; .discard() would be lenient.
            for frame in node["cond_indep_stack"]:
                if frame.vectorized:
                    sum_dims.remove(plate_to_symbol[frame.name])
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a unique sentinel marking queries not yet reached
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Record the first ordinal at which each query was resolved.
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack: keep only frames whose plate
                # symbol survives in this site's ordinal.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not f.vectorized or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)