예제 #1
0
    def __call__(self, model, guide, features, trip_counts):
        """Evaluate the negative ELBO of ``guide`` against ``model``.

        Both callables are invoked as ``fn(features, trip_counts)``.
        ``guide`` returns ``q``; ``model`` is called under the ``normalize``
        interpretation and returns the pair ``(p_prior, p_likelihood)``.
        The ELBO integrates over the ``"gate_rate_t"`` variable.

        If ``self.args.analytic_kl`` is truthy, the ``p_prior - q`` term is
        integrated outside the ``monte_carlo`` interpretation and only the
        likelihood term is integrated under it; otherwise the whole
        integrand is integrated under ``monte_carlo``.

        When ``self.args.debug`` is truthy the quoted forms of ``q``,
        ``p_prior`` and ``p_likelihood`` are printed.

        Returns:
            ``loss.data`` where ``loss = -elbo``; ``loss`` is asserted to
            have no free inputs and to be a ``funsor.Tensor``.
        """
        q = guide(features, trip_counts)
        if self.args.debug:
            print(f"q = {q.quote()}")

        with interpretation(normalize):
            p_prior, p_likelihood = model(features, trip_counts)
            if self.args.debug:
                print(f"p_prior = {p_prior.quote()}")
                print(f"p_likelihood = {p_likelihood.quote()}")

        # The single variable integrated out in every branch below.
        latent = frozenset(["gate_rate_t"])

        if not self.args.analytic_kl:
            # Build the full integrand first, then Monte Carlo approximate
            # everything in one go.
            with interpretation(normalize):
                integrand = p_prior + p_likelihood - q
            with interpretation(monte_carlo):
                elbo = funsor.Integrate(q, integrand, latent)
        else:
            # The KL part can be computed exactly ...
            kl_term = funsor.Integrate(q, p_prior - q, latent)
            # ... but the likelihood part needs a Monte Carlo approximation.
            with interpretation(monte_carlo):
                likelihood_term = funsor.Integrate(q, p_likelihood, latent)
            elbo = kl_term + likelihood_term

        neg_elbo = -elbo
        assert not neg_elbo.inputs, neg_elbo.inputs
        assert isinstance(neg_elbo, funsor.Tensor), neg_elbo.pretty()
        return neg_elbo.data
예제 #2
0
파일: trace_elbo.py 프로젝트: pyro-ppl/pyro
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Return the negated ELBO built from one guide/model trace pair.

        The guide is wrapped with ``config_enumerate(default="flat")`` and
        traced; the model is then replayed against that guide trace. Both
        runs happen inside ``enum()`` and, when ``self.num_particles > 1``,
        inside ``plate(size=self.num_particles)``.

        ``*args`` / ``**kwargs`` are forwarded unchanged to both ``guide``
        and ``model``.

        Returns:
            ``-to_data(elbo)``, where ``elbo`` is the ``funsor.Integrate`` of
            the summed log-factors against the summed log-measures, reduced
            with ``funsor.ops.add`` over all plate variables.
        """
        with enum(), (plate(size=self.num_particles)
                      if self.num_particles > 1 else contextlib.ExitStack()):
            enumerated_guide = config_enumerate(default="flat")(guide)
            guide_trace = trace(enumerated_guide).get_trace(*args, **kwargs)
            replayed_model = replay(model, trace=guide_trace)
            model_trace = trace(replayed_model).get_trace(*args, **kwargs)

        model_terms = terms_from_trace(model_trace)
        guide_terms = terms_from_trace(guide_trace)

        # Measures from both traces are pooled; guide factors enter negated
        # (the -log q contribution).
        measures = guide_terms["log_measures"] + model_terms["log_measures"]
        negated_guide_factors = [
            -factor for factor in guide_terms["log_factors"]
        ]
        integrand_terms = model_terms["log_factors"] + negated_guide_factors
        reduced_plates = model_terms["plate_vars"] | guide_terms["plate_vars"]
        integrated_vars = (model_terms["measure_vars"]
                           | guide_terms["measure_vars"])

        total_measure = sum(measures, to_funsor(0.0))
        total_integrand = sum(integrand_terms, to_funsor(0.0))
        elbo = funsor.Integrate(total_measure, total_integrand,
                                integrated_vars)
        summed_elbo = elbo.reduce(funsor.ops.add, reduced_plates)

        return -to_data(summed_elbo)
예제 #3
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Return the negated ELBO, assembled lazily and then optimized.

        The guide is traced and the model replayed against it, inside an
        optional ``plate(size=self.num_particles)`` (only when
        ``self.num_particles > 1``) and an ``enum`` whose
        ``first_available_dim`` is ``-self.max_plate_nesting - 1`` when
        ``self.max_plate_nesting`` is truthy, else ``None``.

        ``*args`` / ``**kwargs`` are forwarded unchanged to both ``guide``
        and ``model``.

        Returns:
            ``-to_data(funsor.optimizer.apply_optimizer(elbo))``, evaluated
            under ``funsor.memoize.memoize()``.
        """
        first_dim = (-self.max_plate_nesting - 1) if self.max_plate_nesting else None

        # Collect batched, enumerated, to_funsor-ed traces of guide and model.
        with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \
                enum(first_available_dim=first_dim):
            guide_trace = trace(guide).get_trace(*args, **kwargs)
            replayed_model = replay(model, trace=guide_trace)
            model_trace = trace(replayed_model).get_trace(*args, **kwargs)

        # Pull out of the traces every piece of metadata the ELBO needs.
        guide_terms = terms_from_trace(guide_trace)
        model_terms = terms_from_trace(model_trace)

        # Everything in this block only builds a lazy expression.
        with funsor.interpreter.interpretation(funsor.terms.lazy):
            # Split model factors by whether they mention a model measure
            # variable; those that do are contracted via partial_sum_product.
            to_contract, untouched = [], []
            for factor in model_terms["log_factors"]:
                if model_terms["measure_vars"].intersection(factor.inputs):
                    bucket = to_contract
                else:
                    bucket = untouched
                bucket.append(factor)

            contracted = funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                model_terms["log_measures"] + to_contract,
                plates=model_terms["plate_vars"],
                eliminate=model_terms["measure_vars"])

            # A common scale factor carries the effects of subsampling and
            # handlers.scale.
            costs = [model_terms["scale"] * term for term in contracted]
            costs.extend(untouched)  # model costs: logp
            costs.extend(-factor for factor in guide_terms["log_factors"])  # guide costs: -logq

            # Integrate out guide variables and sum over all plates.
            all_plates = guide_terms["plate_vars"] | model_terms["plate_vars"]
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
                cost_vars = frozenset(cost.inputs)
                # Marginal logq in the guide matching this cost term.
                marginal_logq = funsor.sum_product.sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    guide_terms["log_measures"],
                    plates=all_plates,
                    eliminate=(all_plates | guide_terms["measure_vars"]) - cost_vars)
                # Expected cost E_q[logp] or E_q[-logq] under that marginal.
                expected_cost = funsor.Integrate(
                    marginal_logq, cost,
                    guide_terms["measure_vars"] & cost_vars)
                elbo += expected_cost.reduce(funsor.ops.add,
                                             all_plates & cost_vars)

        # Evaluate the ELBO; memoize shares tensor computation where possible.
        with funsor.memoize.memoize():
            optimized = funsor.optimizer.apply_optimizer(elbo)
            return -to_data(optimized)
예제 #4
0
def Expectation(log_probs, costs, sum_vars, prod_vars):
    """Accumulate the expected value of each cost under ``log_probs``.

    For every ``cost`` in ``costs``:

    1. ``log_probs`` is contracted with ``sum_product`` (``logaddexp`` /
       ``add``), eliminating every variable in ``prod_vars | sum_vars``
       that is not an input of the cost.
    2. The cost is integrated against that result over the ``sum_vars``
       the cost depends on.
    3. The integral is reduced with ``funsor.ops.add`` over the
       ``prod_vars`` the cost depends on.

    Args:
        log_probs: factors passed to ``sum_product`` as ``factors``.
        costs: iterable of objects exposing ``.inputs``.
        sum_vars: set of variable names to integrate over.
        prod_vars: set of plate variable names.

    Returns:
        The sum of all per-cost terms, starting from the integer ``0``
        (so ``0`` is returned when ``costs`` is empty).
    """
    total = 0
    for cost in costs:
        cost_vars = frozenset(cost.inputs)
        marginal = funsor.sum_product.sum_product(
            sum_op=funsor.ops.logaddexp,
            prod_op=funsor.ops.add,
            factors=log_probs,
            plates=prod_vars,
            eliminate=(prod_vars | sum_vars) - cost_vars,
        )
        expected = funsor.Integrate(marginal, cost, sum_vars & cost_vars)
        total += expected.reduce(funsor.ops.add, prod_vars & cost_vars)
    return total
예제 #5
0
    def loss_function(data, subsample_scale):
        """Return the negative ELBO for ``data``, scaled by ``subsample_scale``.

        ``encode(data)`` supplies the ``(loc, scale)`` of a Normal guide over
        ``'z'``; ``decode('z')`` supplies the Bernoulli probabilities used to
        score ``data``. The ELBO is integrated over ``'z'``, summed over the
        ``'batch'`` input and multiplied by ``subsample_scale`` before being
        negated.
        """
        # Lazily sample from the guide.
        loc, scale = encode(data)
        per_dim_normal = dist.Normal(loc['i'], scale['i'], value='z_i')
        q = funsor.Independent(per_dim_normal, 'z', 'i', 'z_i')

        # Evaluate the model likelihood at the lazy value z.
        probs = decode('z')
        pixel_log_prob = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y'])
        p = pixel_log_prob.reduce(ops.add, {'x', 'y'})

        # Construct an elbo. This is where sampling happens.
        per_batch_elbo = funsor.Integrate(q, p - q, frozenset(['z']))
        scaled_elbo = per_batch_elbo.reduce(ops.add, 'batch') * subsample_scale
        return -scaled_elbo
예제 #6
0
def test_bart(analytic_kl):
    """Check that a fixed guide/model pair integrates to a ``Tensor`` ELBO.

    A guide ``q``, a prior ``p_prior`` and a likelihood ``p_likelihood`` are
    written out as literal funsor expressions under the ``reflect``
    interpretation, then combined into an ELBO by integrating over
    ``"gate_rate_t"``.

    Args:
        analytic_kl: when truthy, ``Integrate(q, p_prior - q, ...)`` is
            evaluated outside the ``monte_carlo`` interpretation and only the
            likelihood term is evaluated under it; otherwise the whole
            integrand ``p - q`` is evaluated under ``monte_carlo``.

    The test asserts that the ELBO is a ``Tensor`` and that the module-level
    ``call_count`` equals 1 at the end. ``call_count`` is reset to 0 here and
    is not modified anywhere else in this function.
    NOTE(review): it is presumably incremented by ``unpack_gate_rate_0`` /
    ``unpack_gate_rate_1``, which are defined outside this block -- confirm.
    """
    global call_count
    call_count = 0

    # NOTE(review): `reflect` presumably keeps the expressions below as
    # unevaluated terms while they are constructed -- confirm against funsor.
    with interpretation(reflect):
        # Guide: a Tensor plus a Gaussian over 'value_b1', wrapped in two
        # nested Independent terms that produce 'gate_rate_t'.
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ],
                                 [
                                     -0.6987162828445435, -1.0875964164733887,
                                     -0.7337473630905151, -0.4713417589664459,
                                     -0.6674002408981323, -1.2478348016738892,
                                     -0.8939017057418823, -0.5238542556762695
                                 ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ],
                                         dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]],
                                 [[2.642792224884033]], [[1.7301604747772217]],
                                 [[0.5203893780708313]], [[1.055436372756958]],
                                 [[2.8370864391326904]]]
                            ],
                                         dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        # Prior: a MarkovProduct over 'time_b9' combined with a
        # MultivariateNormal substituted at 'state_b10'; the outer Contraction
        # reduces 'state(time=1)_b11' and 'state_b10' with logaddexp.
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736,
                                             dtype=torch.float32), (), 'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0],
                                             dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836,
                                    -0.0
                                ],
                                              [
                                                  0.0, 98.01002502441406, -0.0,
                                                  -99.0000228881836
                                              ],
                                              [
                                                  -99.0000228881836, -0.0,
                                                  100.0000228881836, 0.0
                                              ],
                                              [
                                                  -0.0, -99.0000228881836, 0.0,
                                                  100.0000228881836
                                              ]],
                                             dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824,
                                                0.07356668263673782,
                                                0.19946961104869843,
                                                0.5386509299278259,
                                                -0.708323061466217,
                                                0.24411526322364807,
                                                -0.20855577290058136,
                                                -0.2421337217092514
                                            ],
                                             [
                                                 0.41762110590934753,
                                                 0.5272183418273926,
                                                 -0.49835553765296936,
                                                 -0.0363837406039238,
                                                 -0.0005282597267068923,
                                                 0.2704298794269562,
                                                 -0.155222088098526,
                                                 -0.44802337884902954
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495,
                                                -0.2848514914512634,
                                                0.037103548645973206,
                                                0.12648648023605347,
                                                -0.18501518666744232,
                                                -0.20899859070777893,
                                                0.04121830314397812,
                                                0.0054807960987091064
                                            ],
                                             [
                                                 0.0021788496524095535,
                                                 -0.18700894713401794,
                                                 0.08187370002269745,
                                                 0.13554862141609192,
                                                 -0.10477752983570099,
                                                 -0.20848378539085388,
                                                 -0.01393645629286766,
                                                 0.011670656502246857
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664,
                                                0.864071786403656,
                                                1.0236268043518066,
                                                0.7147538065910339,
                                                0.7423890233039856,
                                                0.9462157487869263,
                                                1.2132389545440674,
                                                1.0596832036972046
                                            ],
                                             [
                                                 0.5787821412086487,
                                                 0.9178534150123596,
                                                 0.9074794054031372,
                                                 0.6600189208984375,
                                                 0.8473222255706787,
                                                 0.8426999449729919,
                                                 1.194266438484192,
                                                 1.0471148490905762
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                         dtype=torch.float32),
                            (), 'real'), Variable('value_b5', reals(2, ))), ((
                                'value_b5',
                                Variable('state_b10', reals(2, )),
                            ), )),
            ))
        # Likelihood: summed (ops.add) over time/destin/origin; inside, a
        # Categorical over 'gated_b14' is combined with a Stack of a Poisson
        # and a Delta that share the same observed count tensor.
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0, reals(2, 2, 2),
                                            (Variable('gate_rate_b12',
                                                      reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t', reals(2,
                                                                         8)),
                                                Variable('time_b17', bint(2))),
                                        ), )), Variable('origin_b15',
                                                        bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2), (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )), ))

    if analytic_kl:
        # KL-like term is integrated outside monte_carlo; only the likelihood
        # term is integrated under the monte_carlo interpretation.
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        # Whole integrand is integrated under monte_carlo.
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            # NOTE(review): bare `Integrate` here vs `funsor.Integrate` above
            # -- presumably the same object; confirm against the imports.
            elbo = Integrate(q, p - q, "gate_rate_t")

    # The ELBO must have been fully evaluated to a Tensor.
    assert isinstance(elbo, Tensor), elbo.pretty()
    # call_count was reset to 0 at the top of this test.
    assert call_count == 1
예제 #7
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Return the negated ELBO, using an adjoint pass for guide marginals.

        The guide is traced and the model replayed against it, inside an
        optional ``plate(size=self.num_particles)`` (only when
        ``self.num_particles > 1``) and an ``enum`` whose
        ``first_available_dim`` is ``-self.max_plate_nesting - 1`` when
        ``self.max_plate_nesting`` is truthy, else ``None``.

        One zero-valued ``funsor.Tensor`` target is created per distinct set
        of cost inputs. The targets are appended to the guide log-measures in
        a single ``sum_product`` recorded by ``AdjointTape``, and
        ``tape.adjoint`` then yields one marginal per target.

        ``*args`` / ``**kwargs`` are forwarded unchanged to both ``guide``
        and ``model``.

        Returns:
            ``-to_data(apply_optimizer(elbo))``, evaluated under
            ``funsor.interpretations.memoize()``.
        """
        first_dim = (-self.max_plate_nesting - 1) if self.max_plate_nesting else None

        # Collect batched, enumerated, to_funsor-ed traces of guide and model.
        with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \
                enum(first_available_dim=first_dim):
            guide_trace = trace(guide).get_trace(*args, **kwargs)
            replayed_model = replay(model, trace=guide_trace)
            model_trace = trace(replayed_model).get_trace(*args, **kwargs)

        # Pull out of the traces every piece of metadata the ELBO needs.
        guide_terms = terms_from_trace(guide_trace)
        model_terms = terms_from_trace(model_trace)

        # Everything in this block only builds a lazy expression.
        with funsor.terms.lazy:
            # Split model factors by whether they mention a model measure
            # variable; those that do are contracted via partial_sum_product.
            to_contract, untouched = [], []
            for factor in model_terms["log_factors"]:
                if model_terms["measure_vars"].intersection(factor.inputs):
                    bucket = to_contract
                else:
                    bucket = untouched
                bucket.append(factor)

            contracted = funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                model_terms["log_measures"] + to_contract,
                plates=model_terms["plate_vars"],
                eliminate=model_terms["measure_vars"],
            )

            # A common scale factor carries the effects of subsampling and
            # handlers.scale. Costs come from the model (logp) and the
            # guide (-logq).
            costs = [model_terms["scale"] * term for term in contracted]
            costs.extend(untouched)  # model costs: logp
            costs.extend(-factor for factor in guide_terms["log_factors"])  # guide costs: -logq

            # compute expected cost
            # Cf. pyro.infer.util.Dice.compute_expectation()
            # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
            # TODO Replace this with funsor.Expectation
            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]

            # One zero target per distinct set of cost inputs; its adjoint is
            # the guide marginal used for every cost with those inputs.
            targets = {}
            for cost in costs:
                key = frozenset(cost.inputs)
                if key in targets:
                    continue
                zeros = funsor.ops.new_zeros(
                    funsor.tensor.get_default_prototype(),
                    tuple(v.size for v in cost.inputs.values()),
                )
                targets[key] = funsor.Tensor(zeros, cost.inputs, cost.dtype)

            with AdjointTape() as tape:
                logzq = funsor.sum_product.sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    guide_terms["log_measures"] + list(targets.values()),
                    plates=plate_vars,
                    eliminate=(plate_vars | guide_terms["measure_vars"]),
                )
            marginals = tape.adjoint(
                funsor.ops.logaddexp,
                funsor.ops.add,
                logzq,
                tuple(targets.values()),
            )

            # Integrate out guide variables and sum over all plates.
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
                cost_vars = frozenset(cost.inputs)
                marginal = marginals[targets[cost_vars]]
                # Normalize the marginal over its non-plate inputs.
                local_logzq = marginal.reduce(funsor.ops.logaddexp,
                                              cost_vars - plate_vars)
                log_prob = marginal - local_logzq
                expected_cost = funsor.Integrate(
                    log_prob,
                    cost,
                    guide_terms["measure_vars"] & frozenset(log_prob.inputs),
                )
                elbo += expected_cost.reduce(funsor.ops.add,
                                             plate_vars & cost_vars)

        # Evaluate the ELBO; memoize shares tensor computation where possible.
        with funsor.interpretations.memoize():
            optimized = apply_optimizer(elbo)
            return -to_data(optimized)