Beispiel #1
0
    def __call__(self, model, guide, *args, **kwargs):
        """
        Build a differentiable surrogate loss from one pair of traces.

        The loss is an energy-distance term computed over every observed
        sample site of the model trace, minus ``self.prior_scale`` times the
        summed log-probability of the unobserved sample sites.

        :param model: the model callable, forwarded to ``self._get_traces``
        :param guide: the guide callable, forwarded to ``self._get_traces``
        :returns: the surrogate loss
        :raises ValueError: if the model trace has no observed sample sites
        """
        guide_trace, model_trace = self._get_traces(model, guide, args, kwargs)

        # Pair each observed site's recorded observation with its sampled value.
        data = OrderedDict()
        samples = OrderedDict()
        for name, site in model_trace.nodes.items():
            if site["type"] != "sample" or not site["is_observed"]:
                continue
            data[name] = site["infer"]["obs"]
            samples[name] = site["value"]
        assert list(data.keys()) == list(samples.keys())
        if not data:
            raise ValueError("Found no observations")

        # Index pairs (i, j) with i > j, i.e. every distinct pair of particles.
        prototype = next(iter(data.values()))
        lower_triangle = prototype.new_ones(self.num_particles,
                                            self.num_particles).tril(-1)
        pairs = lower_triangle.nonzero()

        error_terms = []  # per-site E[ (X - x)^2 ]
        entropy_terms = []  # per-site E[ (X - X')^2 ]
        for name, obs in data.items():
            node = model_trace.nodes[name]
            sample = samples[name]
            scale = node["scale"]
            mask = node["mask"]

            # Split the observation shape into batch and event parts, then
            # flatten to (num_particles, batch_size, event_size).
            event_dim = node["fn"].event_dim
            split = obs.dim() - event_dim
            batch_shape = obs.shape[:split]
            event_shape = obs.shape[split:]
            # Only tensor-like (non-scalar) scales and masks need flattening.
            if getattr(scale, 'shape', ()):
                scale = scale.expand(batch_shape).reshape(-1)
            if getattr(mask, 'shape', ()):
                mask = mask.expand(batch_shape).reshape(-1)
            batch_size = batch_shape.numel()
            event_size = event_shape.numel()
            obs = obs.reshape(batch_size, event_size)
            sample = sample.reshape(self.num_particles, batch_size, event_size)

            error_terms.append(_squared_error(sample, obs, scale, mask))
            first, second = sample[pairs].unbind(1)
            entropy_terms.append(_squared_error(first, second, scale, mask))

        # Sum across sites, then combine into the energy distance.
        squared_error = reduce(operator.add, error_terms)
        squared_entropy = reduce(operator.add, entropy_terms)
        error = self._pow(squared_error).mean()  # E[ ||X-x||^beta ]
        entropy = self._pow(squared_entropy).mean()  # E[ ||X-X'||^beta ]
        energy = error - 0.5 * entropy

        # Accumulate the prior term only when it carries positive weight.
        log_prior = 0
        if self.prior_scale > 0:
            for site in model_trace.nodes.values():
                if site["type"] != "sample" or site["is_observed"]:
                    continue
                log_prior = log_prior + site["log_prob_sum"]

        loss = energy - self.prior_scale * log_prior
        warn_if_nan(loss, "loss")
        return loss
Beispiel #2
0
    def loss_and_grads(self, grads, batch, *args, **kwargs):
        """
        :param bool grads: if True, gradients of the loss with respect to the
            captured guide parameters are accumulated into their ``.grad``
        :param batch: a sequence of model traces, or None to draw
            ``self.training_batch_size`` traces from ``self.simulations``
            without replacement
        :returns: an estimate of the loss (expectation over p(x, y) of
            -log q(x, y) ) - where p is the model and q is the guide
        :rtype: float

        If a batch is provided, the loss is estimated using these traces
        Otherwise, a fresh batch is generated from the model.

        `args` and `kwargs` are passed to the model and guide.
        """
        if batch is None:
            indices = np.random.choice(len(self.simulations),
                                       size=self.training_batch_size,
                                       replace=False)
            batch = [self.simulations[i] for i in indices]
            batch_size = self.training_batch_size
        else:
            batch_size = len(batch)

        # Collect all cross matched guide traces: entry [i][j] is the guide
        # evaluated for model_x = batch[i] (the "true" model) against
        # model_z = batch[j] (the contrasting parameters).
        with poutine.trace(param_only=True) as particle_param_capture:
            guide_traces = [
                [self._get_matched_cross_trace(
                    batch[i], batch[j], *args, **kwargs)
                 for j in range(batch_size)]
                for i in range(batch_size)
            ]

        loss = torch.tensor(0.)

        # Calculate losses per site
        for site_name in self.site_names:
            for i in range(batch_size):
                model_x_fn = batch[i].nodes[site_name]['fn']

                # Log-probability, under model_x's distribution at this site,
                # of every value in the batch.
                log_prob_priors = torch.cat(
                    [model_x_fn.log_prob(
                        batch[j].nodes[site_name]['value']).unsqueeze(0)
                     for j in range(batch_size)], 0)

                guide_losses = torch.cat(
                    [self._differentiable_loss_particle(
                        guide_trace, site_name=site_name).unsqueeze(0)
                     for guide_trace in guide_traces[i]], 0)

                f_phis = guide_losses + log_prob_priors
                r = -torch.log_softmax(-f_phis, 0)
                # Only the matched entry (j == i) contributes to the loss.
                particle_loss = r[i].sum() / batch_size

                loss += particle_loss

        warn_if_nan(loss, "loss")

        if grads:
            guide_params = list({
                site["value"].unconstrained()
                for site in particle_param_capture.trace.nodes.values()
            })
            # Scope anomaly detection to this gradient computation. Calling
            # set_detect_anomaly(True) as a plain function would leave the
            # process-wide flag enabled for all later autograd work.
            with torch.autograd.set_detect_anomaly(True):
                guide_grads = torch.autograd.grad(
                    loss, guide_params, allow_unused=True, retain_graph=True)
            for guide_grad, guide_param in zip(guide_grads, guide_params):
                if guide_grad is None:
                    # Parameter did not take part in the loss.
                    continue
                if guide_param.grad is None:
                    guide_param.grad = guide_grad
                else:
                    guide_param.grad = guide_param.grad + guide_grad

        return torch_item(loss)
Beispiel #3
0
def compute_marginals_persistent_bp(exists_logits,
                                    assign_logits,
                                    bp_iters,
                                    bp_momentum=0.5):
    """
    Approximate pairwise marginals with damped loopy belief propagation,
    adapting the approach of [1], [2].

    See :class:`MarginalAssignmentPersistent` for args and problem description.

    :returns: a pair ``(exists, assign)`` of updated logits

    [1] Jason L. Williams, Roslyn A. Lau (2014)
        Approximate evaluation of marginal association probabilities with
        belief propagation
        https://arxiv.org/abs/1209.6299
    [2] Ryan Turner, Steven Bottone, Bhargav Avasarala (2014)
        A Complete Variational Tracker
        https://papers.nips.cc/paper/5572-a-complete-variational-tracker.pdf
    """
    # Messages are passed among three sets of variables:
    #
    #   a[t,j] ~ Categorical(num_objects + 1), detection -> object assignment
    #   b[t,i] ~ Categorical(num_detections + 1), object -> detection assignment
    #     e[i] ~ Bernoulli, whether each object exists
    #
    # Only assign = a and exists = e are returned.
    assert 0 <= bp_momentum < 1, bp_momentum
    # Each update is a convex mix of the previous message and the new value.
    old = bp_momentum
    new = 1 - bp_momentum
    num_frames, num_detections, num_objects = assign_logits.shape
    pairwise_shape = (num_frames, num_detections, num_objects)
    message_b_to_a = assign_logits.new_zeros(*pairwise_shape)
    message_a_to_b = assign_logits.new_zeros(*pairwise_shape)
    message_b_to_e = assign_logits.new_zeros(num_frames, num_objects)
    message_e_to_b = assign_logits.new_zeros(num_frames, num_objects)

    for step in range(bp_iters):
        odds_a = (assign_logits + message_b_to_a).exp()
        # Sum over the other objects (dim 2), excluding the current one.
        others_a = odds_a.sum(2, True) - odds_a
        fresh_a_to_b = assign_logits - others_a.log1p()
        message_a_to_b = old * message_a_to_b + new * fresh_a_to_b

        fresh_b_to_e = message_a_to_b.exp().sum(1).log1p()
        message_b_to_e = old * message_b_to_e + new * fresh_b_to_e

        # Sum over frames (dim 0), excluding the current frame's message.
        fresh_e_to_b = exists_logits + message_b_to_e.sum(0) - message_b_to_e
        message_e_to_b = old * message_e_to_b + new * fresh_e_to_b

        odds_b = message_a_to_b.exp()
        normalizer = ((-message_e_to_b).exp().unsqueeze(1) +
                      (1 + odds_b.sum(1, True) - odds_b))
        message_b_to_a = old * message_b_to_a - new * normalizer.log()

        warn_if_nan(message_a_to_b, 'message_a_to_b iter {}'.format(step))
        warn_if_nan(message_b_to_e, 'message_b_to_e iter {}'.format(step))
        warn_if_nan(message_e_to_b, 'message_e_to_b iter {}'.format(step))
        warn_if_nan(message_b_to_a, 'message_b_to_a iter {}'.format(step))

    # Combine the input logits with the final incoming messages.
    exists = exists_logits + message_b_to_e.sum(0)
    assign = assign_logits + message_b_to_a
    warn_if_nan(exists, 'exists')
    warn_if_nan(assign, 'assign')
    return exists, assign
Beispiel #4
0
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: the negated ELBO estimate (``0.`` if no particle produced a
            tensor-valued ELBO)
        :rtype: float

        Builds per-particle ELBO values and matching surrogate values from the
        traces yielded by ``self._get_traces``, combines them through a
        log-mean-exp over particles, and calls ``backward`` on the weighted
        surrogate loss when there are parameters to train.

        :raises NotImplementedError: if a guide site has both a non-zero
            entropy term and a non-zero score function term
        """
        elbo_particles = []
        surrogate_elbo_particles = []
        is_vectorized = self.vectorize_particles and self.num_particles > 1
        tensor_holder = None

        # grab a vectorized trace from the generator
        for model_trace, guide_trace in self._get_traces(
                model, guide, args, kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0

            # Model sample sites add their log-probability to both estimates.
            for site in model_trace.nodes.values():
                if site["type"] != "sample":
                    continue
                if is_vectorized:
                    log_prob_sum = site["log_prob"].reshape(
                        self.num_particles, -1).sum(-1)
                else:
                    log_prob_sum = site["log_prob_sum"]
                elbo_particle = elbo_particle + log_prob_sum.detach()
                surrogate_elbo_particle = surrogate_elbo_particle + log_prob_sum

            # Guide sample sites subtract their log-probability.
            for site in guide_trace.nodes.values():
                if site["type"] != "sample":
                    continue
                log_prob, score_function_term, entropy_term = site[
                    "score_parts"]
                if is_vectorized:
                    log_prob_sum = log_prob.reshape(
                        self.num_particles, -1).sum(-1)
                else:
                    log_prob_sum = site["log_prob_sum"]

                elbo_particle = elbo_particle - log_prob_sum.detach()

                if is_identically_zero(entropy_term):
                    continue
                surrogate_elbo_particle = surrogate_elbo_particle - log_prob_sum
                if not is_identically_zero(score_function_term):
                    # link to the issue: https://github.com/uber/pyro/issues/1222
                    raise NotImplementedError

            # Keep all stored particles as tensors of one shape: the first
            # tensor-valued particle defines the template, and identically
            # zero particles are replaced by zero tensors of that shape.
            if is_identically_zero(elbo_particle):
                if tensor_holder is not None:
                    elbo_particle = tensor_holder.new_zeros(
                        tensor_holder.shape)
                    surrogate_elbo_particle = tensor_holder.new_zeros(
                        tensor_holder.shape)
            elif tensor_holder is None:
                tensor_holder = elbo_particle.new_empty(elbo_particle.shape)
                # change types of previous `elbo_particle`s
                for idx in range(len(elbo_particles)):
                    elbo_particles[idx] = tensor_holder.new_zeros(
                        tensor_holder.shape)
                    surrogate_elbo_particles[idx] = tensor_holder.new_zeros(
                        tensor_holder.shape)

            elbo_particles.append(elbo_particle)
            surrogate_elbo_particles.append(surrogate_elbo_particle)

        if tensor_holder is None:
            return 0.

        if is_vectorized:
            elbo_particles = elbo_particles[0]
            surrogate_elbo_particles = surrogate_elbo_particles[0]
        else:
            elbo_particles = torch.stack(elbo_particles)
            surrogate_elbo_particles = torch.stack(surrogate_elbo_particles)

        log_weights = elbo_particles
        log_total_weight = log_sum_exp(elbo_particles, dim=0)
        log_mean_weight = log_total_weight - math.log(self.num_particles)
        elbo = log_mean_weight.sum().item()

        # collect parameters to train from model and guide
        # (uses the traces from the last loop iteration)
        trainable_params = any(
            site["type"] == "param"
            for trace in (model_trace, guide_trace)
            for site in trace.nodes.values())
        has_grad = getattr(surrogate_elbo_particles, 'requires_grad', False)

        if trainable_params and has_grad:
            normalized_weights = (log_weights - log_mean_weight).exp()
            weighted = normalized_weights * surrogate_elbo_particles
            surrogate_elbo = weighted.sum() / self.num_particles
            surrogate_loss = -surrogate_elbo
            surrogate_loss.backward()

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
Beispiel #5
0
    def _loss(self, model, guide, args, kwargs):
        """
        :returns: returns model loss and guide loss
        :rtype: float, float

        Computes the re-weighted wake-sleep estimators for the model (wake-theta) and the
          guide (insomnia * wake-phi + (1 - insomnia) * sleep-phi).
        The wake quantities are computed only when the model has parameters or
        ``insomnia > 0``; the sleep quantity only when ``insomnia < 1``.
        """

        def trace_log_prob(trace):
            # Sum of sample-site log-probabilities of one trace; per particle
            # when particles are vectorized.
            total = 0.
            for site in trace.nodes.values():
                if site["type"] != "sample":
                    continue
                if self.vectorize_particles:
                    term = site["log_prob"].reshape(
                        self.num_particles, -1).sum(-1)
                else:
                    term = site["log_prob_sum"]
                total = total + term
            return total

        wake_theta_loss = torch.tensor(100.)
        if self.model_has_params or self.insomnia > 0.:
            # compute quantities for wake theta and wake phi
            log_joints = []
            log_qs = []
            for model_trace, guide_trace in self._get_traces(
                    model, guide, args, kwargs):
                log_joints.append(trace_log_prob(model_trace))
                log_qs.append(trace_log_prob(guide_trace))

            if self.vectorize_particles:
                log_joints = log_joints[0]
                log_qs = log_qs[0]
            else:
                log_joints = torch.stack(log_joints)
                log_qs = torch.stack(log_qs)
            log_weights = log_joints - log_qs.detach()

            # compute wake theta loss
            log_sum_weight = torch.logsumexp(log_weights, dim=0)
            log_mean_weight = log_sum_weight - math.log(self.num_particles)
            wake_theta_loss = -log_mean_weight.sum()
            warn_if_nan(wake_theta_loss, "wake theta loss")

        if self.insomnia > 0:
            # compute wake phi loss
            normalised_weights = (log_weights - log_sum_weight).exp().detach()
            wake_phi_loss = -(normalised_weights * log_qs).sum()
            warn_if_nan(wake_phi_loss, "wake phi loss")

        if self.insomnia < 1:
            # compute sleep phi loss
            _model = pyro.poutine.uncondition(model)
            _guide = guide
            _log_q = 0.

            if self.vectorize_particles:
                if self.max_plate_nesting == float('inf'):
                    self._guess_max_plate_nesting(_model, _guide, args, kwargs)
                _model = self._vectorized_num_sleep_particles(_model)
                _guide = self._vectorized_num_sleep_particles(guide)
                num_rounds = 1
            else:
                num_rounds = self.num_sleep_particles

            for _ in range(num_rounds):
                _model_trace = poutine.trace(_model).get_trace(*args, **kwargs)
                _model_trace.detach_()
                _guide_trace = self._get_matched_trace(_model_trace, _guide,
                                                       args, kwargs)
                _log_q += _guide_trace.log_prob_sum()

            sleep_phi_loss = -_log_q / self.num_sleep_particles
            warn_if_nan(sleep_phi_loss, "sleep phi loss")

        # compute phi loss
        if self.insomnia == 0:
            phi_loss = sleep_phi_loss
        elif self.insomnia == 1:
            phi_loss = wake_phi_loss
        else:
            phi_loss = (self.insomnia * wake_phi_loss +
                        (1. - self.insomnia) * sleep_phi_loss)

        return wake_theta_loss, phi_loss
Beispiel #6
0
    def _differentiable_loss_parts(self, model, guide, *args, **kwargs):
        """
        Compute the two parts of the loss: a log-likelihood and an MMD penalty.

        For each model sample site that also appears in the guide trace and is
        not observed, samples from an independently traced model run and from
        the guide are collected and compared with ``_compute_mmd``. The
        ``log_prob_sum`` of every other model sample site is accumulated into
        the log-likelihood, averaged over ``self.num_particles``.

        :returns: a pair ``(loglikelihood, penalty)``
        :raises ValueError: if a compared model or guide site is not
            reparameterizable
        """

        def flatten(samples, particle_dim):
            # Reshape to 2-D with one row per particle (one row when particles
            # are not vectorized).
            if self.vectorize_particles:
                samples = samples.transpose(-samples.dim(), particle_dim)
                return samples.view(samples.shape[0], -1)
            return samples.view(1, -1)

        all_model_samples = defaultdict(list)
        all_guide_samples = defaultdict(list)

        loglikelihood = 0.0
        penalty = 0.0
        for model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):
            # A separate run of the model supplies the model-side samples.
            if self.vectorize_particles:
                independent_model = poutine.trace(
                    self._vectorized_num_particles(model))
            else:
                independent_model = poutine.trace(model, graph_type='flat')
            model_trace_independent = independent_model.get_trace(
                *args, **kwargs)

            loglikelihood_particle = 0.0
            for name, model_site in model_trace.nodes.items():
                if model_site['type'] != 'sample':
                    continue
                if name not in guide_trace or model_site['is_observed']:
                    loglikelihood_particle = (loglikelihood_particle +
                                              model_site['log_prob_sum'])
                    continue

                guide_site = guide_trace.nodes[name]
                independent_model_site = model_trace_independent.nodes[name]
                if not independent_model_site["fn"].has_rsample:
                    raise ValueError(
                        "Model site {} is not reparameterizable".format(name))
                if not guide_site["fn"].has_rsample:
                    raise ValueError(
                        "Guide site {} is not reparameterizable".format(name))

                event_dim = independent_model_site["fn"].event_dim
                particle_dim = -self.max_plate_nesting - event_dim

                model_samples = independent_model_site['value']
                guide_samples = guide_site['value']
                model_samples = flatten(model_samples, particle_dim)
                guide_samples = flatten(guide_samples, particle_dim)

                all_model_samples[name].append(model_samples)
                all_guide_samples[name].append(guide_samples)

            loglikelihood = (loglikelihood_particle / self.num_particles +
                             loglikelihood)

        # One MMD term per compared site, scaled by that site's weight.
        for name in all_model_samples.keys():
            all_model_samples[name] = torch.cat(all_model_samples[name])
            all_guide_samples[name] = torch.cat(all_guide_samples[name])
            divergence = _compute_mmd(all_model_samples[name],
                                      all_guide_samples[name],
                                      kernel=self._kernel[name])
            penalty = self._mmd_scale[name] * divergence + penalty

        warn_if_nan(loglikelihood, "loglikelihood")
        warn_if_nan(penalty, "penalty")
        return loglikelihood, penalty
Beispiel #7
0
    def _quantized_model(self):
        """
        Quantized vectorized model used for parallel-scan enumerated inference.
        This method is called only outside particle_plate.

        Samples global parameters and auxiliary variables, quantizes the
        auxiliary compartment values into ``self.num_quant_bins`` bins,
        records the transition factors from ``self._transition_bwd``, and
        registers the eliminated result through ``pyro.factor``. Returns
        nothing.
        """
        C = len(self.compartments)
        T = self.duration
        Q = self.num_quant_bins
        R_shape = getattr(self.population, "shape", ())  # Region shape.

        # Sample global parameters and auxiliary variables.
        params = self.global_model()
        auxiliary, non_compartmental = self._sample_auxiliary()

        # Manually enumerate.
        curr, logp = quantize_enumerate(auxiliary,
                                        min=0,
                                        max=self.population,
                                        num_quant_bins=self.num_quant_bins)
        curr = OrderedDict(zip(self.compartments, curr.unbind(0)))
        logp = OrderedDict(zip(self.compartments, logp.unbind(0)))
        curr.update(non_compartmental)

        # Truncate final value from the right then pad initial value onto the left.
        init = self.initialize(params)
        prev = {}
        for name, value in init.items():
            if name in self.compartments:
                if isinstance(value, torch.Tensor):
                    value = value[
                        ..., None]  # Because curr is enumerated on the right.
                prev[name] = cat2(value,
                                  curr[name][:-1],
                                  dim=-3 if self.is_regional else -2)
            else:  # non-compartmental
                prev[name] = cat2(init[name],
                                  curr[name][:-1],
                                  dim=-curr[name].dim())

        # Reshape to support broadcasting, similar to EnumMessenger.
        def enum_reshape(tensor, position):
            # Move the rightmost (quantization) dim to the front and pad with
            # singleton dims so each variable gets its own enumeration dim.
            assert tensor.size(-1) == Q
            assert tensor.dim() <= self.max_plate_nesting + 2
            tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1))
            shape = [Q] + [1] * (position + self.max_plate_nesting -
                                 (tensor.dim() - 2))
            shape.extend(tensor.shape[1:])
            return tensor.reshape(shape)

        for e, name in enumerate(self.compartments):
            curr[name] = enum_reshape(curr[name], e)
            logp[name] = enum_reshape(logp[name], e)
            prev[name] = enum_reshape(prev[name], e + C)

        # Enable approximate inference by using aux as a non-enumerated proxy
        # for enumerated compartment values.
        for name in self.approximate:
            aux = auxiliary[self.compartments.index(name)]
            curr[name + "_approx"] = aux
            prev[name + "_approx"] = cat2(init[name],
                                          aux[:-1],
                                          dim=-2 if self.is_regional else -1)

        # Record transition factors.
        with poutine.block(), poutine.trace() as tr:
            with self.time_plate:
                t = slice(0, T, 1)  # Used to slice data tensors.
                self._transition_bwd(params, prev, curr, t)
        tr.trace.compute_log_prob()
        for name, site in tr.trace.nodes.items():
            if site["type"] == "sample":
                log_prob = site["log_prob"]
                if log_prob.dim() <= self.max_plate_nesting:  # Not enumerated.
                    pyro.factor("transition_" + name, site["log_prob_sum"])
                    continue
                if self.is_regional and log_prob.shape[-1:] != R_shape:
                    # Poor man's tensor variable elimination: a factor that is
                    # not per-region is spread evenly across regions so that
                    # the later sum over the region dim does not count it once
                    # per region.
                    log_prob = log_prob.expand(log_prob.shape[:-1] +
                                               R_shape) / R_shape[0]
                # Store the (possibly region-corrected) factor. Storing
                # site["log_prob"] here would discard the correction above.
                logp[name] = log_prob

        # Manually perform variable elimination.
        logp = reduce(operator.add, logp.values())
        logp = logp.reshape(Q**C, Q**C, T, -1)  # prev, curr, T, batch
        logp = logp.permute(3, 2, 0, 1).squeeze(0)  # batch, T, prev, curr
        logp = pyro.distributions.hmm._sequential_logmatmulexp(
            logp)  # batch, prev, curr
        logp = logp.reshape(-1, Q**C * Q**C).logsumexp(-1).sum()
        warn_if_nan(logp)
        pyro.factor("transition", logp)

        self._clear_plates()
Beispiel #8
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """
        Return the loss with the surrogate loss's gradient attached.

        The returned value is ``loss`` plus ``surrogate_loss`` minus a detached
        copy of ``surrogate_loss``, so the added term contributes gradients
        but no value.
        """
        loss, surrogate_loss = self.loss_and_surrogate_loss(
            model, guide, *args, **kwargs)
        warn_if_nan(loss, "loss")

        gradient_only_term = surrogate_loss - surrogate_loss.detach()
        return loss + gradient_only_term