def _compute_log_r(model_trace, guide_trace):
    """Accumulate detached per-site log-ratio terms into a MultiFrameTensor.

    For every ``"sample"`` site in ``model_trace`` the contributed term is the
    model's ``log_prob``; for sites that are not observed, the ``log_prob`` of
    the same-named site in ``guide_trace`` is subtracted. Each term is
    detached and added under the stack that ``get_iarange_stacks`` reports
    for that site name.

    :param model_trace: trace whose ``nodes`` mapping is iterated; sample
        sites must carry ``"log_prob"`` and ``"is_observed"`` entries.
    :param guide_trace: trace providing ``nodes[name]["log_prob"]`` for every
        non-observed sample site of the model.
    :returns: the populated ``MultiFrameTensor``.
    """
    result = MultiFrameTensor()
    site_stacks = get_iarange_stacks(model_trace)

    for site_name, site in model_trace.nodes.items():
        # Only sample sites contribute a term.
        if site["type"] != "sample":
            continue

        model_log_prob = site["log_prob"]
        if site["is_observed"]:
            term = model_log_prob
        else:
            # Non-observed sites also appear in the guide: use the difference.
            term = model_log_prob - guide_trace.nodes[site_name]["log_prob"]

        # Detach so the stored term does not take part in autograd.
        result.add((site_stacks[site_name], term.detach()))

    return result
def test_multi_frame_tensor():
    """Check that ``MultiFrameTensor.sum_to`` yields the expected sums.

    Every sample site of a trace of ``xy_model`` has its log-probability
    recomputed from ``site["fn"]`` and added to a ``MultiFrameTensor`` keyed
    by the site's ``cond_indep_stack``. Summing to each recorded stack must
    give a tensor of the listed shape filled with ``log(0.5)`` times the
    listed weight.
    """
    stack_by_site = {}
    accumulated = MultiFrameTensor()
    trace = poutine.trace(xy_model).get_trace()

    for site_name, node in trace.nodes.items():
        if node["type"] != "sample":
            continue
        site_log_prob = node["fn"].log_prob(node["value"])
        frame_stack = node["cond_indep_stack"]
        stack_by_site[site_name] = frame_stack
        accumulated.add((frame_stack, site_log_prob))

    assert len(accumulated) == 4

    logp = math.log(0.5)
    # site name -> (expected shape, weight applied to logp).
    # NOTE(review): the weights presumably count the elements each term
    # contributes once summed down to the target stack -- confirm against
    # xy_model.
    spec = {
        'b': (torch.Size(), 1 + 2 + 3 + 6),
        'bx': (torch.Size((2,)), 1 + 1 + 3 + 3),
        'by': (torch.Size((3, 1)), 1 + 2 + 1 + 2),
        'bxy': (torch.Size((3, 2)), 1 + 1 + 1 + 1),
    }
    expected = {
        site_name: torch.ones(shape) * logp * weight
        for site_name, (shape, weight) in spec.items()
    }

    for site_name, expected_sum in expected.items():
        actual_sum = accumulated.sum_to(stack_by_site[site_name])
        assert_equal(actual_sum, expected_sum, msg=site_name)
def test_multi_frame_tensor():
    """Exercise ``MultiFrameTensor.add`` and ``sum_to`` on a trace of ``xy_model``.

    The log-probability of each sample site is evaluated via
    ``site["fn"].log_prob(site["value"])`` and accumulated under the site's
    ``cond_indep_stack``. The accumulator is expected to have length 4, and
    summing to each named site's stack is compared against a constant tensor
    of ``log(0.5)`` scaled by a per-site factor.
    """
    stacks = {}
    tensor_sum = MultiFrameTensor()
    model_trace = poutine.trace(xy_model).get_trace()

    # Gather the sample sites first; other node types are ignored.
    sample_sites = [
        (key, node)
        for key, node in model_trace.nodes.items()
        if node["type"] == "sample"
    ]
    for key, node in sample_sites:
        value_log_prob = node["fn"].log_prob(node["value"])
        stacks[key] = node["cond_indep_stack"]
        tensor_sum.add((node["cond_indep_stack"], value_log_prob))

    assert len(tensor_sum) == 4

    logp = math.log(0.5)

    def constant(shape, factor):
        # Tensor of the given shape with every entry equal to logp * factor.
        return torch.ones(torch.Size(shape)) * logp * factor

    # (site name, expected result of summing to that site's stack)
    cases = (
        ('b', constant((), 1 + 2 + 3 + 6)),
        ('bx', constant((2,), 1 + 1 + 3 + 3)),
        ('by', constant((3, 1), 1 + 2 + 1 + 2)),
        ('bxy', constant((3, 2), 1 + 1 + 1 + 1)),
    )
    for key, expected_sum in cases:
        assert_equal(tensor_sum.sum_to(stacks[key]), expected_sum, msg=key)