예제 #1
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Compute the negated ELBO as a differentiable value.

        The guide is traced and the model is replayed against the guide trace.
        Both runs happen inside an ``enum`` handler and, when
        ``self.num_particles > 1``, inside a ``plate`` of size
        ``self.num_particles``. All log-measures and log-factors from the two
        traces are then contracted by one lazily built ``sum_product``
        (``logaddexp`` as the sum op, ``add`` as the product op) that
        eliminates every measure and plate variable. The optimized result is
        negated.

        :param model: model callable, replayed against the guide trace.
        :param guide: guide callable, traced first.
        :param args: positional arguments forwarded to ``guide`` and ``model``.
        :param kwargs: keyword arguments forwarded to ``guide`` and ``model``.
        :returns: ``-elbo`` converted back to backend data via ``to_data``.
        """
        # Enumeration dims start just left of the configured plate dims.
        # A falsy max_plate_nesting passes None and leaves the choice to ``enum``.
        first_enum_dim = (-self.max_plate_nesting - 1) if self.max_plate_nesting else None

        # Nested ``with`` statements are equivalent to one multi-item ``with``.
        # The particle plate (or a no-op ExitStack) is entered before ``enum``.
        with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack():
            with enum(first_available_dim=first_enum_dim):
                guide_trace = trace(guide).get_trace(*args, **kwargs)
                model_trace = trace(replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

        model_terms = terms_from_trace(model_trace)
        guide_terms = terms_from_trace(guide_trace)

        # Model log-factors are kept as-is; guide log-factors enter negated.
        all_measures = guide_terms["log_measures"] + model_terms["log_measures"]
        all_factors = model_terms["log_factors"] + [-q for q in guide_terms["log_factors"]]
        plate_names = model_terms["plate_vars"] | guide_terms["plate_vars"]
        measure_names = model_terms["measure_vars"] | guide_terms["measure_vars"]

        # Build the contraction lazily so the whole expression is handed to
        # apply_optimizer before it is evaluated.
        with funsor.interpreter.interpretation(funsor.terms.lazy):
            lazy_elbo = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                all_measures + all_factors,
                eliminate=measure_names | plate_names,
                plates=plate_names,
            )

        return -to_data(funsor.optimizer.apply_optimizer(lazy_elbo))
예제 #2
0
파일: trace_elbo.py 프로젝트: pyro-ppl/pyro
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Compute the negated ELBO as a differentiable value.

        The guide is wrapped with ``config_enumerate(default="flat")`` and
        traced. The model is then replayed against that trace. Both runs happen
        under ``enum()`` and, when ``self.num_particles > 1``, under a ``plate``
        of size ``self.num_particles``.

        The ELBO is the ``funsor.Integrate`` of the summed log-factors (model
        factors plus negated guide factors) against the summed log-measures,
        over all measure variables. This is followed by an ``add`` reduction
        over all plate variables.

        :param model: model callable, replayed against the guide trace.
        :param guide: guide callable, traced first.
        :param args: positional arguments forwarded to ``guide`` and ``model``.
        :param kwargs: keyword arguments forwarded to ``guide`` and ``model``.
        :returns: ``-elbo`` converted back to backend data via ``to_data``.
        """
        # Nested ``with`` statements are equivalent to one multi-item ``with``.
        # ``enum()`` is entered first, then the particle plate (or a no-op).
        with enum():
            with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack():
                flat_guide = config_enumerate(default="flat")(guide)
                guide_trace = trace(flat_guide).get_trace(*args, **kwargs)
                model_trace = trace(replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

        model_terms = terms_from_trace(model_trace)
        guide_terms = terms_from_trace(guide_trace)

        all_measures = guide_terms["log_measures"] + model_terms["log_measures"]
        all_factors = model_terms["log_factors"] + [-q for q in guide_terms["log_factors"]]
        plate_names = model_terms["plate_vars"] | guide_terms["plate_vars"]
        measure_names = model_terms["measure_vars"] | guide_terms["measure_vars"]

        # Collapse each list into a single funsor, starting from a scalar zero.
        total_measure = sum(all_measures, to_funsor(0.0))
        total_factor = sum(all_factors, to_funsor(0.0))

        per_plate_elbo = funsor.Integrate(total_measure, total_factor, measure_names)
        return -to_data(per_plate_elbo.reduce(funsor.ops.add, plate_names))
예제 #3
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Compute the negated ELBO, contracting model-only measure variables first.

        Steps:

        1. Trace the guide and replay the model against it. Both run under
           ``enum``, and under a particle ``plate`` when
           ``self.num_particles > 1``.
        2. Split the model log-factors into those that mention a model measure
           variable and those that do not. The former are contracted with the
           model log-measures by ``partial_sum_product`` and multiplied by
           ``model_terms["scale"]``.
        3. Collect the cost terms: contracted model costs, uncontracted model
           factors, and negated guide log-factors.
        4. For each cost, use ``sum_product`` to build a guide log-marginal
           that keeps only the cost's inputs. Integrate the cost against it
           over the guide measure variables the cost mentions, then sum over
           the plates the cost mentions.

        The expression is built lazily and evaluated with ``apply_optimizer``
        under ``funsor.memoize.memoize()``.

        :param model: model callable, replayed against the guide trace.
        :param guide: guide callable, traced first.
        :param args: positional arguments forwarded to ``guide`` and ``model``.
        :param kwargs: keyword arguments forwarded to ``guide`` and ``model``.
        :returns: ``-elbo`` converted back to backend data via ``to_data``.
        """
        # Enumeration dims start just left of the configured plate dims.
        # A falsy max_plate_nesting passes None and leaves the choice to ``enum``.
        first_enum_dim = (-self.max_plate_nesting - 1) if self.max_plate_nesting else None

        # Batched, enumerated, to_funsor-ed traces of the guide and the replayed model.
        with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack():
            with enum(first_available_dim=first_enum_dim):
                guide_trace = trace(guide).get_trace(*args, **kwargs)
                model_trace = trace(replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

        # Metadata (factors, measures, variable names, scale) pulled from each trace.
        guide_terms = terms_from_trace(guide_trace)
        model_terms = terms_from_trace(model_trace)

        with funsor.interpreter.interpretation(funsor.terms.lazy):
            model_measure_vars = model_terms["measure_vars"]
            model_factors = model_terms["log_factors"]
            # Order-preserving partition: factors touching a model-only
            # measure variable must be contracted before taking expectations.
            contracted = [f for f in model_factors if model_measure_vars.intersection(f.inputs)]
            uncontracted = [f for f in model_factors if not model_measure_vars.intersection(f.inputs)]

            partial_results = funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                model_terms["log_measures"] + contracted,
                plates=model_terms["plate_vars"],
                eliminate=model_measure_vars)
            # Subsampling / handlers.scale enter through one common scale factor.
            scale = model_terms["scale"]
            costs = [scale * term for term in partial_results]
            costs.extend(uncontracted)  # model costs: logp
            costs.extend(-q for q in guide_terms["log_factors"])  # guide costs: -logq

            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
            guide_measure_vars = guide_terms["measure_vars"]
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
                cost_vars = frozenset(cost.inputs)
                # Guide log-marginal keeping only the variables this cost depends on.
                log_prob = funsor.sum_product.sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    guide_terms["log_measures"],
                    plates=plate_vars,
                    eliminate=(plate_vars | guide_measure_vars) - cost_vars)
                # E_q[cost] over the guide measure variables the cost mentions.
                expected_cost = funsor.Integrate(log_prob, cost, guide_measure_vars & cost_vars)
                elbo += expected_cost.reduce(funsor.ops.add, plate_vars & cost_vars)

        # Memoize so shared sub-expressions are evaluated once.
        with funsor.memoize.memoize():
            return -to_data(funsor.optimizer.apply_optimizer(elbo))
예제 #4
0
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    """Run ``model`` replayed against a MAP (``temperature == 0``) or sampled
    (``temperature == 1``) assignment of its sample sites.

    The model is traced once (blocked, enumerated, replayed against an empty
    ``Trace``). All of its log-factors and log-measures are contracted with
    ``sum_product``, using ``max`` (temperature 0) or ``logaddexp``
    (temperature 1) as the sum op and ``add`` as the product op. The result
    is passed through ``funsor.adjoint.adjoint`` under the corresponding
    approximation. The resulting per-site factors fill a copy of the trace,
    and the model is replayed against that copy.

    :param model: model callable.
    :param first_available_dim: forwarded to ``enum(first_available_dim=...)``.
    :param temperature: 0 for MAP, 1 for sampling; anything else raises.
    :param args: positional arguments forwarded to ``model``.
    :param kwargs: keyword arguments forwarded to ``model``.
    :returns: whatever ``model(*args, **kwargs)`` returns under the replay.
    :raises ValueError: if ``temperature`` is neither 0 nor 1.
    """
    if temperature == 0:
        # map: max-product contraction with the argmax approximation
        sum_op, approx = funsor.ops.max, funsor.approximations.argmax_approximate
    elif temperature == 1:
        # sample: log-sum-exp contraction with Monte Carlo sampling
        sum_op, approx = funsor.ops.logaddexp, funsor.montecarlo.MonteCarlo()
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")
    prod_op = funsor.ops.add

    with block():
        with enum(first_available_dim=first_available_dim):
            # XXX replay against an empty Trace to ensure densities are not double-counted
            model_trace = trace(replay(model, trace=Trace())).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_trace)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        lazy_log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            terms["log_factors"] + terms["log_measures"],
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(lazy_log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # Copy the trace and fill it with the approximated values, then replay the model.
    sample_trace = model_trace.copy()
    substitutions = {}
    for site_name, site in sample_trace.nodes.items():
        # Only non-subsample "sample" sites carry values to replay.
        if site["type"] != "sample" or site_is_subsample(site):
            continue
        observed = site["is_observed"]
        site_funsor = site["funsor"]
        if observed:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            site["funsor"] = {"value": site_funsor["value"](**substitutions)}
        else:
            site_funsor["log_measure"] = approx_factors[site_funsor["log_measure"]]
            site_funsor["value"] = _get_support_value(site_funsor["log_measure"], site_name)
            substitutions[site_name] = site_funsor["value"]

    with replay(trace=sample_trace):
        return model(*args, **kwargs)
예제 #5
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Compute the negated ELBO, using one adjoint pass for all guide marginals.

        Steps:

        1. Trace the guide and replay the model against it. Both run under
           ``enum``, and under a particle ``plate`` when
           ``self.num_particles > 1``.
        2. Contract model log-factors that mention a model measure variable
           with the model log-measures (``partial_sum_product``) and multiply
           the results by ``model_terms["scale"]``.
        3. Collect cost terms: those contracted costs, the uncontracted model
           factors, and the negated guide log-factors.
        4. Create one zero-valued ``funsor.Tensor`` "target" per distinct set
           of cost input names. Contract them together with the guide
           log-measures under an ``AdjointTape``, and read one marginal per
           target from ``tape.adjoint``.
        5. For each cost, normalise its target's marginal over the cost's
           non-plate inputs, integrate the cost against it, and sum over the
           plates the cost mentions.

        The expression is built under ``funsor.terms.lazy`` and evaluated with
        ``apply_optimizer`` under ``funsor.interpretations.memoize()``.

        :param model: model callable, replayed against the guide trace.
        :param guide: guide callable, traced first.
        :param args: positional arguments forwarded to ``guide`` and ``model``.
        :param kwargs: keyword arguments forwarded to ``guide`` and ``model``.
        :returns: ``-elbo`` converted back to backend data via ``to_data``.
        """

        # get batched, enumerated, to_funsor-ed traces from the guide and model
        # (a falsy max_plate_nesting passes first_available_dim=None to enum)
        with plate(
                size=self.num_particles
        ) if self.num_particles > 1 else contextlib.ExitStack(), enum(
                first_available_dim=(-self.max_plate_nesting -
                                     1) if self.max_plate_nesting else None):
            guide_tr = trace(guide).get_trace(*args, **kwargs)
            model_tr = trace(replay(model, trace=guide_tr)).get_trace(
                *args, **kwargs)

        # extract from traces all metadata that we will need to compute the elbo
        guide_terms = terms_from_trace(guide_tr)
        model_terms = terms_from_trace(model_tr)

        # build up a lazy expression for the elbo
        with funsor.terms.lazy:
            # identify and contract out auxiliary variables in the model with partial_sum_product
            contracted_factors, uncontracted_factors = [], []
            for f in model_terms["log_factors"]:
                if model_terms["measure_vars"].intersection(f.inputs):
                    contracted_factors.append(f)
                else:
                    uncontracted_factors.append(f)
            # incorporate the effects of subsampling and handlers.scale through a common scale factor
            contracted_costs = [
                model_terms["scale"] * f
                for f in funsor.sum_product.partial_sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    model_terms["log_measures"] + contracted_factors,
                    plates=model_terms["plate_vars"],
                    eliminate=model_terms["measure_vars"],
                )
            ]

            # accumulate costs from model (logp) and guide (-logq)
            costs = contracted_costs + uncontracted_factors  # model costs: logp
            costs += [-f for f in guide_terms["log_factors"]
                      ]  # guide costs: -logq

            # compute expected cost
            # Cf. pyro.infer.util.Dice.compute_expectation()
            # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
            # TODO Replace this with funsor.Expectation
            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
            # compute the marginal logq in the guide corresponding to each cost term
            # Costs with the same set of input names share one target. Each
            # target is all zeros, and the product op below is ``add``, so
            # including targets leaves the value of ``logzq`` unchanged while
            # adding a factor over exactly each cost's inputs for the adjoint.
            # NOTE(review): shape comes from ``v.size`` of each input domain,
            # so this assumes every cost input has a finite size; confirm.
            # The order of ``targets`` (dict insertion order) is also the
            # order passed to ``tape.adjoint`` below.
            targets = dict()
            for cost in costs:
                input_vars = frozenset(cost.inputs)
                if input_vars not in targets:
                    targets[input_vars] = funsor.Tensor(
                        funsor.ops.new_zeros(
                            funsor.tensor.get_default_prototype(),
                            tuple(v.size for v in cost.inputs.values()),
                        ),
                        cost.inputs,
                        cost.dtype,
                    )
            # Record the full guide contraction (all plates and guide measure
            # variables eliminated) so it can be differentiated w.r.t. targets.
            with AdjointTape() as tape:
                logzq = funsor.sum_product.sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    guide_terms["log_measures"] + list(targets.values()),
                    plates=plate_vars,
                    eliminate=(plate_vars | guide_terms["measure_vars"]),
                )
            # presumably marginals[t] is the guide log-marginal over t's inputs
            # (see funsor.adjoint) -- confirm against the funsor version in use
            marginals = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add,
                                     logzq, tuple(targets.values()))
            # finally, integrate out guide variables in the elbo and all plates
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
                target = targets[frozenset(cost.inputs)]
                # normalise the marginal over the cost's non-plate inputs
                logzq_local = marginals[target].reduce(
                    funsor.ops.logaddexp,
                    frozenset(cost.inputs) - plate_vars)
                log_prob = marginals[target] - logzq_local
                elbo_term = funsor.Integrate(
                    log_prob,
                    cost,
                    guide_terms["measure_vars"] & frozenset(log_prob.inputs),
                )
                elbo += elbo_term.reduce(funsor.ops.add,
                                         plate_vars & frozenset(cost.inputs))

        # evaluate the elbo, using memoize to share tensor computation where possible
        # NOTE(review): ``apply_optimizer`` is referenced unqualified here and
        # must be imported at module level (not visible in this chunk).
        with funsor.interpretations.memoize():
            return -to_data(apply_optimizer(elbo))