def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    """Check downstream costs when one iarange is reused by several sites.

    ``iarange_reuse_model_guide`` is traced once as the guide (without
    observations). It is then replayed against that guide trace as the model
    (with observations). ``_compute_downstream_costs`` must agree with the
    brute-force reference and with a hand-built expected cost for ``c1``.
    """
    sizes = dict(dim1=dim1, dim2=dim2)
    guide_trace = poutine.trace(
        iarange_reuse_model_guide, graph_type="dense"
    ).get_trace(include_obs=False, **sizes)
    model_trace = poutine.trace(
        poutine.replay(iarange_reuse_model_guide, trace=guide_trace),
        graph_type="dense"
    ).get_trace(include_obs=True, **sizes)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    for trace in (model_trace, guide_trace):
        trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    # Every cost must have its guide site's log_prob shape and match brute force.
    for site, cost in dc.items():
        assert guide_trace.nodes[site]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site])

    def log_ratio(site):
        """Return model log_prob minus guide log_prob at ``site``."""
        return model_trace.nodes[site]['log_prob'] - guide_trace.nodes[site]['log_prob']

    expected_c1 = log_ratio('c1')
    expected_c1 += log_ratio('b1').sum()
    expected_c1 += log_ratio('c2')
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
def test_compute_downstream_costs_plate_reuse(dim1, dim2):
    """Check downstream costs when one plate is reused by several sites.

    ``plate_reuse_model_guide`` is traced once as the guide (no observations).
    It is then replayed against that guide trace as the model (observations
    included). The efficient and brute-force downstream-cost computations
    must agree, and ``c1`` must match a hand-built expectation.
    """
    sizes = dict(dim1=dim1, dim2=dim2)
    guide_trace = poutine.trace(
        plate_reuse_model_guide, graph_type="dense"
    ).get_trace(include_obs=False, **sizes)
    model_trace = poutine.trace(
        poutine.replay(plate_reuse_model_guide, trace=guide_trace),
        graph_type="dense"
    ).get_trace(include_obs=True, **sizes)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    for trace in (model_trace, guide_trace):
        trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    # Shape and value agreement with the brute-force reference, site by site.
    for site, cost in dc.items():
        assert guide_trace.nodes[site]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site])

    def log_ratio(site):
        """Return model log_prob minus guide log_prob at ``site``."""
        return model_trace.nodes[site]['log_prob'] - guide_trace.nodes[site]['log_prob']

    expected_c1 = log_ratio('c1')
    expected_c1 += log_ratio('b1').sum()
    expected_c1 += log_ratio('c2')
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
# Example no. 3
# 0
def test_compute_downstream_costs_irange_in_iarange(dim1, dim2):
    """Check downstream costs for ``nested_model_guide2``.

    The guide trace omits observations. The model trace replays the same
    function against the guide trace with observations included. Besides
    agreeing with the brute-force reference, the costs of ``b1``, ``c`` and
    ``a1`` must match expectations assembled by hand from per-site
    log-probabilities.
    """
    sizes = dict(dim1=dim1, dim2=dim2)
    guide_trace = poutine.trace(
        nested_model_guide2, graph_type="dense"
    ).get_trace(include_obs=False, **sizes)
    model_trace = poutine.trace(
        poutine.replay(nested_model_guide2, trace=guide_trace),
        graph_type="dense"
    ).get_trace(include_obs=True, **sizes)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    for trace in (model_trace, guide_trace):
        trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    for site, cost in dc.items():
        assert guide_trace.nodes[site]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site])

    def model_lp(site):
        """Return the model-trace log_prob at ``site``."""
        return model_trace.nodes[site]['log_prob']

    def log_ratio(site):
        """Return model log_prob minus guide log_prob at ``site``."""
        return model_lp(site) - guide_trace.nodes[site]['log_prob']

    expected_b1 = log_ratio('b1')
    expected_b1 += model_lp('obs1')
    assert_equal(expected_b1, dc['b1'])

    # c's cost gathers every b{i} ratio and obs{i} log_prob for i < dim2.
    expected_c = log_ratio('c')
    for idx in range(dim2):
        expected_c += log_ratio('b{}'.format(idx))
        expected_c += model_lp('obs{}'.format(idx))
    assert_equal(expected_c, dc['c'])

    expected_a1 = log_ratio('a1')
    expected_a1 += expected_c.sum()
    assert_equal(expected_a1, dc['a1'])
def test_compute_downstream_costs_plate_in_iplate(dim1):
    """Check downstream costs for ``nested_model_guide`` (plate inside iplate).

    The guide is traced without observations. The model is the same function
    replayed against the guide trace with observations. For each index 1 and
    0, the costs of ``c{i}`` and ``b{i}`` are rebuilt by hand and compared.
    Every cost is also compared against the brute-force reference.
    """
    guide_trace = poutine.trace(
        nested_model_guide, graph_type="dense"
    ).get_trace(include_obs=False, dim1=dim1)
    model_trace = poutine.trace(
        poutine.replay(nested_model_guide, trace=guide_trace),
        graph_type="dense"
    ).get_trace(include_obs=True, dim1=dim1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    for trace in (model_trace, guide_trace):
        trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    def log_ratio(site):
        """Return model log_prob minus guide log_prob at ``site``."""
        return model_trace.nodes[site]['log_prob'] - guide_trace.nodes[site]['log_prob']

    # Build the expectations in the same order as before: c1, b1, c0, b0.
    expected = {}
    for idx in (1, 0):
        c_site = 'c{}'.format(idx)
        b_site = 'b{}'.format(idx)
        obs_lp = model_trace.nodes['obs{}'.format(idx)]['log_prob']

        expected_c = log_ratio(c_site)
        expected_c += obs_lp

        expected_b = log_ratio(b_site)
        expected_b += log_ratio(c_site).sum()
        expected_b += obs_lp.sum()

        expected[c_site] = expected_c
        expected[b_site] = expected_b

    for site in ('c1', 'b1', 'c0', 'b0'):
        assert_equal(expected[site], dc[site], prec=1.0e-6)

    for site, cost in dc.items():
        assert guide_trace.nodes[site]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site])
def test_compute_downstream_costs_duplicates(dim):
    """Check downstream costs for the diamond model/guide pair.

    ``diamond_guide`` is traced, and ``diamond_model`` is replayed against
    that trace. The costs of ``a1``, ``b1`` and ``c1`` are rebuilt by hand
    from per-site log-probabilities and also compared with the brute-force
    reference.

    NOTE(review): a second function with this exact name is defined right
    after this one in the file. Python keeps only the last binding, so this
    definition is shadowed and pytest never collects it.
    """
    guide_trace = poutine.trace(
        diamond_guide, graph_type="dense"
    ).get_trace(dim=dim)
    model_trace = poutine.trace(
        poutine.replay(diamond_model, trace=guide_trace),
        graph_type="dense"
    ).get_trace(dim=dim)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    for trace in (model_trace, guide_trace):
        trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    def model_lp(site):
        """Return the model-trace log_prob at ``site``."""
        return model_trace.nodes[site]['log_prob']

    def guide_lp(site):
        """Return the guide-trace log_prob at ``site``."""
        return guide_trace.nodes[site]['log_prob']

    b_sites = ['b{}'.format(d) for d in range(dim)]

    expected_a1 = model_lp('a1') - guide_lp('a1')
    for b_site in b_sites:
        expected_a1 += model_lp(b_site)
        expected_a1 -= guide_lp(b_site)
    expected_a1 += model_lp('c1') - guide_lp('c1')
    expected_a1 += model_lp('obs')

    # b1's own guide term is subtracted once; the model terms of every b{d} are added.
    expected_b1 = -guide_lp('b1')
    for b_site in b_sites:
        expected_b1 += model_lp(b_site)
    expected_b1 += model_lp('c1') - guide_lp('c1')
    expected_b1 += model_lp('obs')

    expected_c1 = model_lp('c1') - guide_lp('c1')
    for b_site in b_sites:
        expected_c1 += model_lp(b_site)
    expected_c1 += model_lp('obs')

    assert_equal(expected_a1, dc['a1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)

    for site, cost in dc.items():
        assert guide_lp(site).size() == cost.size()
        assert_equal(cost, dc_brute[site])
def test_compute_downstream_costs_duplicates(dim):
    """Check downstream costs for the diamond model/guide pair.

    The guide trace comes from ``diamond_guide``. The model trace replays
    ``diamond_model`` against it. Expected costs for ``a1``, ``b1`` and
    ``c1`` are assembled by hand, and every computed cost is cross-checked
    against the brute-force reference.

    NOTE(review): this name is also defined earlier in the file; this later
    definition is the one Python binds.
    """
    guide_trace = poutine.trace(
        diamond_guide, graph_type="dense"
    ).get_trace(dim=dim)
    model_trace = poutine.trace(
        poutine.replay(diamond_model, trace=guide_trace),
        graph_type="dense"
    ).get_trace(dim=dim)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    for trace in (model_trace, guide_trace):
        trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    model_nodes = model_trace.nodes
    guide_nodes = guide_trace.nodes
    b_sites = ['b{}'.format(d) for d in range(dim)]
    c1_ratio = lambda: model_nodes['c1']['log_prob'] - guide_nodes['c1']['log_prob']  # noqa: E731

    expected_a1 = model_nodes['a1']['log_prob'] - guide_nodes['a1']['log_prob']
    for b_site in b_sites:
        expected_a1 += model_nodes[b_site]['log_prob']
        expected_a1 -= guide_nodes[b_site]['log_prob']
    expected_a1 += c1_ratio()
    expected_a1 += model_nodes['obs']['log_prob']

    expected_b1 = -guide_nodes['b1']['log_prob']
    for b_site in b_sites:
        expected_b1 += model_nodes[b_site]['log_prob']
    expected_b1 += c1_ratio()
    expected_b1 += model_nodes['obs']['log_prob']

    expected_c1 = c1_ratio()
    for b_site in b_sites:
        expected_c1 += model_nodes[b_site]['log_prob']
    expected_c1 += model_nodes['obs']['log_prob']

    assert_equal(expected_a1, dc['a1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)

    for site, cost in dc.items():
        assert guide_nodes[site]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site])
def test_compute_downstream_costs_iarange_in_irange(dim1):
    """Check downstream costs for ``nested_model_guide`` (iarange inside irange).

    The guide is traced without observations. The model replays the same
    function against the guide trace with observations. For indices 1 and 0,
    the costs of ``c{i}`` and ``b{i}`` are assembled by hand and compared,
    then every cost is checked against the brute-force reference.
    """
    guide_trace = poutine.trace(
        nested_model_guide, graph_type="dense"
    ).get_trace(include_obs=False, dim1=dim1)
    model_trace = poutine.trace(
        poutine.replay(nested_model_guide, trace=guide_trace),
        graph_type="dense"
    ).get_trace(include_obs=True, dim1=dim1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    for trace in (model_trace, guide_trace):
        trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    def log_ratio(site):
        """Return model log_prob minus guide log_prob at ``site``."""
        return model_trace.nodes[site]['log_prob'] - guide_trace.nodes[site]['log_prob']

    # Expectations are built in the order c1, b1, c0, b0.
    expected = {}
    for idx in (1, 0):
        c_site = 'c{}'.format(idx)
        b_site = 'b{}'.format(idx)
        obs_lp = model_trace.nodes['obs{}'.format(idx)]['log_prob']

        expected_c = log_ratio(c_site)
        expected_c += obs_lp

        expected_b = log_ratio(b_site)
        expected_b += log_ratio(c_site).sum()
        expected_b += obs_lp.sum()

        expected[c_site] = expected_c
        expected[b_site] = expected_b

    for site in ('c1', 'b1', 'c0', 'b0'):
        assert_equal(expected[site], dc[site], prec=1.0e-6)

    for site, cost in dc.items():
        assert guide_trace.nodes[site]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site])
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1,
                                                       include_single,
                                                       flip_c23,
                                                       include_triple,
                                                       include_z1):
    """Check downstream costs for ``big_model_guide`` under several options.

    The same function is traced as the guide (``include_obs=False``). It is
    then replayed against that guide trace as the model (``include_obs=True``).
    All boolean options are forwarded unchanged to both calls. The result of
    ``_compute_downstream_costs`` is compared with
    ``_brute_force_compute_downstream_costs`` and with per-site expected
    costs assembled by hand below.

    NOTE(review): another function with this exact name is defined right
    after this one in the file. Python keeps only the last binding, so this
    definition is shadowed and pytest never collects it.
    """
    # Guide pass: trace without the observed site.
    guide_trace = poutine.trace(big_model_guide, graph_type="dense").get_trace(
        include_obs=False,
        include_inner_1=include_inner_1,
        include_single=include_single,
        flip_c23=flip_c23,
        include_triple=include_triple,
        include_z1=include_z1)
    # Model pass: replay the same function against the guide trace, this time
    # including the observed site.
    model_trace = poutine.trace(poutine.replay(big_model_guide,
                                               trace=guide_trace),
                                graph_type="dense").get_trace(
                                    include_obs=True,
                                    include_inner_1=include_inner_1,
                                    include_single=include_single,
                                    flip_c23=flip_c23,
                                    include_triple=include_triple,
                                    include_z1=include_z1)

    # Drop subsample sites, then compute the per-site 'log_prob' entries that
    # are read below.
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    # Downstream costs are requested for the guide's non-reparameterized
    # stochastic sites.
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    # Implementation under test.
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)

    # Brute-force reference, used as an independent check.
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    # Expected downstream-node sets. They are only checked for the single
    # configuration selected by the `if` below: no triple, inner_1 and single
    # present, c2/c3 not flipped.
    expected_nodes_full_model = {
        'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'},
        'd2': {'obs', 'd2'},
        'd1': {'obs', 'd1', 'd2'},
        'c3': {'d2', 'obs', 'd1', 'c3'},
        'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'},
        'b1': {'obs', 'b1', 'd1', 'd2', 'c3', 'c1', 'c2'},
        'c1': {'d1', 'c1', 'obs', 'd2', 'c3', 'c2'},
        'c2': {'obs', 'd1', 'c3', 'd2', 'c2'}
    }
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert (dc_nodes == expected_nodes_full_model)

    # b1: its own model-minus-guide term plus downstream terms reduced with
    # .sum(0). NOTE(review): this assumes dim 0 of the downstream log_probs is
    # a plate b1 does not carry — confirm against big_model_guide.
    expected_b1 = (model_trace.nodes['b1']['log_prob'] -
                   guide_trace.nodes['b1']['log_prob'])
    expected_b1 += (model_trace.nodes['d2']['log_prob'] -
                    guide_trace.nodes['d2']['log_prob']).sum(0)
    expected_b1 += (model_trace.nodes['d1']['log_prob'] -
                    guide_trace.nodes['d1']['log_prob']).sum(0)
    expected_b1 += model_trace.nodes['obs']['log_prob'].sum(0, keepdim=False)
    if include_inner_1:
        expected_b1 += (model_trace.nodes['c1']['log_prob'] -
                        guide_trace.nodes['c1']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c2']['log_prob'] -
                        guide_trace.nodes['c2']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c3']['log_prob'] -
                        guide_trace.nodes['c3']['log_prob']).sum(0)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)

    if include_single:
        # b0: its own term plus every downstream term fully reduced with .sum().
        expected_b0 = (model_trace.nodes['b0']['log_prob'] -
                       guide_trace.nodes['b0']['log_prob'])
        expected_b0 += (model_trace.nodes['b1']['log_prob'] -
                        guide_trace.nodes['b1']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum()
        expected_b0 += model_trace.nodes['obs']['log_prob'].sum()
        if include_inner_1:
            expected_b0 += (model_trace.nodes['c1']['log_prob'] -
                            guide_trace.nodes['c1']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c2']['log_prob'] -
                            guide_trace.nodes['c2']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c3']['log_prob'] -
                            guide_trace.nodes['c3']['log_prob']).sum()
        assert_equal(expected_b0, dc['b0'], prec=1.0e-6)
        assert dc['b0'].size() == (5, )

    if include_inner_1:
        # c3 and c2 share the d1/d2/obs contributions, reduced over dim 0.
        expected_c3 = (model_trace.nodes['c3']['log_prob'] -
                       guide_trace.nodes['c3']['log_prob'])
        expected_c3 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c3 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c3 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c2 = (model_trace.nodes['c2']['log_prob'] -
                       guide_trace.nodes['c2']['log_prob'])
        expected_c2 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c2 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c2 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c1 = (model_trace.nodes['c1']['log_prob'] -
                       guide_trace.nodes['c1']['log_prob'])

        if flip_c23:
            # Flipped case: c2's model-minus-guide term is added to c3's cost,
            # while only c3's model log_prob (no guide term) is added to c2's.
            expected_c3 += model_trace.nodes['c2'][
                'log_prob'] - guide_trace.nodes['c2']['log_prob']
            expected_c2 += model_trace.nodes['c3']['log_prob']
        else:
            expected_c2 += model_trace.nodes['c3'][
                'log_prob'] - guide_trace.nodes['c3']['log_prob']
            # NOTE(review): expected_c2 was already initialized with this same
            # c2 model-minus-guide term above, so this adds it a second time.
            # That only agrees with dc['c2'] if the term is zero in this
            # configuration — confirm intent.
            expected_c2 += model_trace.nodes['c2'][
                'log_prob'] - guide_trace.nodes['c2']['log_prob']
        # c1's expected cost reuses the accumulated c3 cost.
        # NOTE(review): in the unflipped case expected_c3 contains no c2 term,
        # although expected_nodes_full_model lists c2 as downstream of c1.
        # Presumably that term vanishes here — confirm.
        expected_c1 += expected_c3

        assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
        assert_equal(expected_c2, dc['c2'], prec=1.0e-6)
        assert_equal(expected_c3, dc['c3'], prec=1.0e-6)

    expected_d1 = model_trace.nodes['d1']['log_prob'] - guide_trace.nodes[
        'd1']['log_prob']
    expected_d1 += model_trace.nodes['d2']['log_prob'] - guide_trace.nodes[
        'd2']['log_prob']
    expected_d1 += model_trace.nodes['obs']['log_prob']

    expected_d2 = (model_trace.nodes['d2']['log_prob'] -
                   guide_trace.nodes['d2']['log_prob'])
    expected_d2 += model_trace.nodes['obs']['log_prob']

    if include_triple:
        # z0's expectation is built from the computed dc['a1'] plus z0's own
        # model-minus-guide term.
        expected_z0 = dc['a1'] + model_trace.nodes['z0'][
            'log_prob'] - guide_trace.nodes['z0']['log_prob']
        assert_equal(expected_z0, dc['z0'], prec=1.0e-6)
    assert_equal(expected_d2, dc['d2'], prec=1.0e-6)
    assert_equal(expected_d1, dc['d1'], prec=1.0e-6)

    assert dc['b1'].size() == (2, )
    assert dc['d2'].size() == (4, 2)

    # Every cost must have its guide site's log_prob shape and match brute force.
    for k in dc:
        assert (guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_single, flip_c23,
                                                       include_triple, include_z1):
    """Check downstream costs for ``big_model_guide`` under several options.

    The same function is traced as the guide (``include_obs=False``). It is
    then replayed against that guide trace as the model (``include_obs=True``).
    All boolean options are forwarded unchanged to both calls. The result of
    ``_compute_downstream_costs`` is compared with
    ``_brute_force_compute_downstream_costs`` and with per-site expected
    costs assembled by hand below.

    NOTE(review): this name is also defined earlier in the file; this later
    definition is the one Python binds.
    """
    # Guide pass (no observed site), then model pass replayed on the guide trace.
    guide_trace = poutine.trace(big_model_guide,
                                graph_type="dense").get_trace(include_obs=False, include_inner_1=include_inner_1,
                                                              include_single=include_single, flip_c23=flip_c23,
                                                              include_triple=include_triple, include_z1=include_z1)
    model_trace = poutine.trace(poutine.replay(big_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, include_inner_1=include_inner_1,
                                                              include_single=include_single, flip_c23=flip_c23,
                                                              include_triple=include_triple, include_z1=include_z1)

    # Drop subsample sites, then compute the per-site 'log_prob' entries read below.
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    # Costs are requested for the guide's non-reparameterized stochastic sites.
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    # Implementation under test.
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)

    # Brute-force reference, used as an independent check.
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
                                                                     non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    # Expected downstream-node sets, checked only for the configuration
    # selected by the `if` below.
    expected_nodes_full_model = {'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'}, 'd2': {'obs', 'd2'},
                                 'd1': {'obs', 'd1', 'd2'}, 'c3': {'d2', 'obs', 'd1', 'c3'},
                                 'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'},
                                 'b1': {'obs', 'b1', 'd1', 'd2', 'c3', 'c1', 'c2'},
                                 'c1': {'d1', 'c1', 'obs', 'd2', 'c3', 'c2'},
                                 'c2': {'obs', 'd1', 'c3', 'd2', 'c2'}}
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert(dc_nodes == expected_nodes_full_model)

    # b1: its own term plus downstream terms reduced with .sum(0).
    # NOTE(review): assumes dim 0 is a plate b1 does not carry — confirm.
    expected_b1 = (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob'])
    expected_b1 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
    expected_b1 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
    expected_b1 += model_trace.nodes['obs']['log_prob'].sum(0, keepdim=False)
    if include_inner_1:
        expected_b1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']).sum(0)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)

    if include_single:
        # b0: its own term plus every downstream term fully reduced with .sum().
        expected_b0 = (model_trace.nodes['b0']['log_prob'] - guide_trace.nodes['b0']['log_prob'])
        expected_b0 += (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum()
        expected_b0 += model_trace.nodes['obs']['log_prob'].sum()
        if include_inner_1:
            expected_b0 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']).sum()
        assert_equal(expected_b0, dc['b0'], prec=1.0e-6)
        assert dc['b0'].size() == (5,)

    if include_inner_1:
        # c3 and c2 share the d1/d2/obs contributions, reduced over dim 0.
        expected_c3 = (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob'])
        expected_c3 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c3 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c3 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c2 = (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob'])
        expected_c2 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c2 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c2 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c1 = (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob'])

        if flip_c23:
            # Flipped case: c2's model-minus-guide term goes into c3's cost;
            # only c3's model log_prob (no guide term) goes into c2's.
            expected_c3 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
            expected_c2 += model_trace.nodes['c3']['log_prob']
        else:
            expected_c2 += model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']
            # NOTE(review): expected_c2 already started from this same c2
            # model-minus-guide term, so it is counted twice here. That only
            # agrees with dc['c2'] if the term is zero in this configuration —
            # confirm intent.
            expected_c2 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
        # c1's expected cost reuses the accumulated c3 cost.
        # NOTE(review): unflipped, expected_c3 has no c2 term, yet
        # expected_nodes_full_model lists c2 under c1. Presumably that term
        # vanishes here — confirm.
        expected_c1 += expected_c3

        assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
        assert_equal(expected_c2, dc['c2'], prec=1.0e-6)
        assert_equal(expected_c3, dc['c3'], prec=1.0e-6)

    expected_d1 = model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']
    expected_d1 += model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']
    expected_d1 += model_trace.nodes['obs']['log_prob']

    expected_d2 = (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob'])
    expected_d2 += model_trace.nodes['obs']['log_prob']

    if include_triple:
        # z0's expectation is the computed dc['a1'] plus z0's own model-minus-guide term.
        expected_z0 = dc['a1'] + model_trace.nodes['z0']['log_prob'] - guide_trace.nodes['z0']['log_prob']
        assert_equal(expected_z0, dc['z0'], prec=1.0e-6)
    assert_equal(expected_d2, dc['d2'], prec=1.0e-6)
    assert_equal(expected_d1, dc['d1'], prec=1.0e-6)

    assert dc['b1'].size() == (2,)
    assert dc['d2'].size() == (4, 2)

    # Every cost must have its guide site's log_prob shape and match brute force.
    for k in dc:
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])