import math
from operator import itemgetter

import torch

import pyro.poutine as poutine
from pyro.infer.util import MultiFrameTensor, get_plate_stacks
from tests.common import assert_equal  # pyro test helper used by test_multi_frame_tensor below


def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes):
    # recursively compute downstream cost nodes for all sample sites in model and guide
    # (even though ultimately we only need them for non-reparameterizable sample sites)
    # 1. downstream costs are used for rao-blackwellization
    # 2. model observe sites (as well as terms that arise from the model and guide having
    #    different dependency structures) are taken care of via 'children_in_model' below
    topo_sort_guide_nodes = guide_trace.topological_sort(reverse=True)
    topo_sort_guide_nodes = [
        x for x in topo_sort_guide_nodes if guide_trace.nodes[x]["type"] == "sample"
    ]
    ordered_guide_nodes_dict = {n: i for i, n in enumerate(topo_sort_guide_nodes)}

    downstream_guide_cost_nodes = {}
    downstream_costs = {}
    stacks = get_plate_stacks(model_trace)

    for node in topo_sort_guide_nodes:
        downstream_costs[node] = MultiFrameTensor(
            (stacks[node],
             model_trace.nodes[node]['log_prob'] - guide_trace.nodes[node]['log_prob']))
        nodes_included_in_sum = set([node])
        downstream_guide_cost_nodes[node] = set([node])
        # make more efficient by ordering children appropriately (higher children first)
        children = [(k, -ordered_guide_nodes_dict[k]) for k in guide_trace.successors(node)]
        sorted_children = sorted(children, key=itemgetter(1))
        for child, _ in sorted_children:
            child_cost_nodes = downstream_guide_cost_nodes[child]
            downstream_guide_cost_nodes[node].update(child_cost_nodes)
            if nodes_included_in_sum.isdisjoint(child_cost_nodes):  # avoid duplicates
                downstream_costs[node].add(*downstream_costs[child].items())
                # XXX nodes_included_in_sum logic could be more fine-grained, possibly leading
                # to speed-ups in case there are many duplicates
                nodes_included_in_sum.update(child_cost_nodes)
        missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum
        # include terms we missed because we had to avoid duplicates
        for missing_node in missing_downstream_costs:
            downstream_costs[node].add(
                (stacks[missing_node],
                 model_trace.nodes[missing_node]['log_prob'] -
                 guide_trace.nodes[missing_node]['log_prob']))

    # finish assembling complete downstream costs
    # (the above computation may be missing terms from the model)
    for site in non_reparam_nodes:
        children_in_model = set()
        for node in downstream_guide_cost_nodes[site]:
            children_in_model.update(model_trace.successors(node))
        # remove terms accounted for above
        children_in_model.difference_update(downstream_guide_cost_nodes[site])
        for child in children_in_model:
            assert model_trace.nodes[child]["type"] == "sample"
            downstream_costs[site].add(
                (stacks[child], model_trace.nodes[child]['log_prob']))
            downstream_guide_cost_nodes[site].update([child])

    for k in non_reparam_nodes:
        downstream_costs[k] = downstream_costs[k].sum_to(
            guide_trace.nodes[k]["cond_indep_stack"])

    return downstream_costs, downstream_guide_cost_nodes
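# Usage sketch (hypothetical helper, not part of the original code): how the traces and the
# set of non-reparameterizable sites might be prepared before calling
# _compute_downstream_costs. `model` and `guide` are placeholder callables; the replay/trace
# plumbing below follows standard Pyro idioms, and graph_type="dense" is used so that
# successors()/topological_sort() see the dependency edges.
def _example_compute_downstream_costs(model, guide, *args, **kwargs):
    guide_trace = poutine.trace(guide, graph_type="dense").get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace), graph_type="dense"
    ).get_trace(*args, **kwargs)
    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()  # also populates the guide log_prob terms
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    return _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)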
def _brute_force_compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes):
    guide_nodes = [x for x in guide_trace.nodes
                   if guide_trace.nodes[x]["type"] == "sample"]
    downstream_costs, downstream_guide_cost_nodes = {}, {}
    stacks = get_plate_stacks(model_trace)

    for node in guide_nodes:
        downstream_costs[node] = MultiFrameTensor(
            (stacks[node],
             model_trace.nodes[node]['log_prob'] - guide_trace.nodes[node]['log_prob']))
        downstream_guide_cost_nodes[node] = set([node])

        descendants = guide_trace.successors(node)
        for desc in descendants:
            desc_mft = MultiFrameTensor(
                (stacks[desc],
                 model_trace.nodes[desc]['log_prob'] - guide_trace.nodes[desc]['log_prob']))
            downstream_costs[node].add(*desc_mft.items())
            downstream_guide_cost_nodes[node].update([desc])

    for site in non_reparam_nodes:
        children_in_model = set()
        for node in downstream_guide_cost_nodes[site]:
            children_in_model.update(model_trace.successors(node))
        children_in_model.difference_update(downstream_guide_cost_nodes[site])
        for child in children_in_model:
            assert model_trace.nodes[child]["type"] == "sample"
            child_mft = MultiFrameTensor(
                (stacks[child], model_trace.nodes[child]['log_prob']))
            downstream_costs[site].add(*child_mft.items())
            downstream_guide_cost_nodes[site].update([child])

    for k in non_reparam_nodes:
        downstream_costs[k] = downstream_costs[k].sum_to(
            guide_trace.nodes[k]["cond_indep_stack"])

    return downstream_costs, downstream_guide_cost_nodes
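# Minimal consistency check (hypothetical, assuming `torch` is imported as above): the
# brute-force version can serve as a reference oracle for the recursive
# _compute_downstream_costs, since both should produce matching tensors for every
# non-reparameterizable site once sum_to() has reduced to that site's plate stack.
def _check_downstream_costs_agree(model_trace, guide_trace, non_reparam_nodes):
    dc, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, _ = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    for node in non_reparam_nodes:
        assert torch.allclose(dc[node], dc_brute[node])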
def _compute_log_r(model_trace, guide_trace):
    log_r = MultiFrameTensor()
    stacks = get_plate_stacks(model_trace)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            if not model_site["is_observed"]:
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
            log_r.add((stacks[name], log_r_term.detach()))
    return log_r
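# Illustrative sketch (names are hypothetical, not this module's API) of how the detached
# log_r from _compute_log_r typically weights a score-function (REINFORCE) term for a
# single non-reparameterized guide site: log_r is reduced to the site's plate stack and
# then multiplied by the site's score_function part before summing.
def _example_score_function_term(log_r, guide_site):
    _, score_function_term, _ = guide_site["score_parts"]
    log_r_site = log_r.sum_to(guide_site["cond_indep_stack"])
    return (log_r_site * score_function_term).sum()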
def test_multi_frame_tensor():
    stacks = {}
    actual = MultiFrameTensor()
    tr = poutine.trace(xy_model).get_trace()
    for name, site in tr.nodes.items():
        if site["type"] == "sample":
            log_prob = site["fn"].log_prob(site["value"])
            stacks[name] = site["cond_indep_stack"]
            actual.add((site["cond_indep_stack"], log_prob))

    assert len(actual) == 4

    logp = math.log(0.5)
    expected = {
        'b': torch.ones(torch.Size()) * logp * (1 + 2 + 3 + 6),
        'bx': torch.ones(torch.Size((2,))) * logp * (1 + 1 + 3 + 3),
        'by': torch.ones(torch.Size((3, 1))) * logp * (1 + 2 + 1 + 2),
        'bxy': torch.ones(torch.Size((3, 2))) * logp * (1 + 1 + 1 + 1),
    }
    for name, expected_sum in expected.items():
        actual_sum = actual.sum_to(stacks[name])
        assert_equal(actual_sum, expected_sum, msg=name)
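# The test above relies on an `xy_model` fixture that is not shown in this section. A
# plausible sketch consistent with the expected sums (a Bernoulli(0.5) site outside any
# plate, one in a size-2 plate, one in a size-3 plate, and one in both) is given below;
# the exact fixture is an assumption, as are the pyro/dist imports.
import pyro
import pyro.distributions as dist


def xy_model():
    d = dist.Bernoulli(0.5)
    x_plate = pyro.plate("x_plate", 2, dim=-1)
    y_plate = pyro.plate("y_plate", 3, dim=-2)
    pyro.sample("b", d)
    with x_plate:
        pyro.sample("bx", d.expand_by([2]))
    with y_plate:
        pyro.sample("by", d.expand_by([3, 1]))
    with x_plate, y_plate:
        pyro.sample("bxy", d.expand_by([3, 2]))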