Example #1
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        Performs backward on the ELBO of each particle.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                continue

            elbo += elbo_particle.item() / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and elbo_particle.requires_grad:
                loss_particle = -elbo_particle
                (loss_particle / self.num_particles).backward(retain_graph=True)

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
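
A minimal, hedged sketch (not one of the indexed examples) of how a ``loss_and_grads`` implementation like the one above is normally driven indirectly through ``pyro.infer.SVI``; the Normal model, guide and data below are illustrative assumptions.

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.))
    scale_q = pyro.param("scale_q", torch.tensor(1.), constraint=constraints.positive)
    pyro.sample("loc", dist.Normal(loc_q, scale_q))

data = torch.randn(100) + 3.0
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(1000):
    loss = svi.step(data)  # internally calls loss_and_grads(model, guide, data)
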
Example #2
def _warn_if_nan(name, value):
    if torch.is_tensor(value):
        value = value.item()
    if torch_isnan(value):
        warnings.warn("Encountered NAN log_prob_sum at site '{}'".format(name))
    if torch_isinf(value) and value > 0:
        warnings.warn("Encountered +inf log_prob_sum at site '{}'".format(name))
Example #3
    def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
        step_size = self.step_size if direction == 1 else -self.step_size
        z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet(
            z, r, self._potential_energy, step_size, z_grads=z_grads)
        energy_new = potential_energy + self._kinetic_energy(r_new)
        sliced_energy = energy_new + log_slice

        # As part of the slice sampling process (see below), along the trajectory
        #     we eliminate states for which p(z, r) < u, i.e. dE > 0.
        # Due to this elimination (and the stop-doubling conditions),
        #     the size of the binary tree might not equal 2^tree_depth.
        tree_size = 1 if sliced_energy <= 0 else 0
        # Special case: Set diverging to True and accept prob to 0 if the
        # diverging trajectory returns `NaN` energy (e.g. in the case of
        # evaluating log prob of a value simulated using a large step size
        # for a constrained sample site).
        if torch_isnan(energy_new):
            diverging = True
            accept_prob = energy_new.new_tensor(0.0)
        else:
            diverging = (sliced_energy >= self._max_sliced_energy)
            delta_energy = energy_new - energy_current
            accept_prob = (-delta_energy).exp().clamp(max=1)
        return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads,
                         z_new, tree_size, False, diverging, accept_prob, 1)
Example #4
def test_independent(base_dist, sample_shape, batch_shape,
                     reinterpreted_batch_ndims):
    if batch_shape:
        base_dist = base_dist.expand_by(batch_shape)
    if reinterpreted_batch_ndims > len(base_dist.batch_shape):
        with pytest.raises(ValueError):
            d = dist.Independent(base_dist, reinterpreted_batch_ndims)
    else:
        d = dist.Independent(base_dist, reinterpreted_batch_ndims)
        assert (d.batch_shape == batch_shape[:len(batch_shape) -
                                             reinterpreted_batch_ndims])
        assert (d.event_shape == batch_shape[len(batch_shape) -
                                             reinterpreted_batch_ndims:] +
                base_dist.event_shape)

        assert d.sample().shape == batch_shape + base_dist.event_shape
        assert d.mean.shape == batch_shape + base_dist.event_shape
        assert d.variance.shape == batch_shape + base_dist.event_shape
        x = d.sample(sample_shape)
        assert x.shape == sample_shape + d.batch_shape + d.event_shape

        log_prob = d.log_prob(x)
        assert (log_prob.shape == sample_shape +
                batch_shape[:len(batch_shape) - reinterpreted_batch_ndims])
        assert not torch_isnan(log_prob)
        log_prob_0 = base_dist.log_prob(x)
        assert_equal(log_prob,
                     _sum_rightmost(log_prob_0, reinterpreted_batch_ndims))
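
The shape bookkeeping exercised by this test can be summarized in a small standalone sketch; the concrete Normal base distribution and shapes below are illustrative assumptions.

import torch
import pyro.distributions as dist

base = dist.Normal(torch.zeros(3, 4), torch.ones(3, 4))  # batch_shape (3, 4), event_shape ()
d = dist.Independent(base, 1)        # reinterpret the rightmost batch dim as an event dim
assert d.batch_shape == (3,)
assert d.event_shape == (4,)
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,)   # log_prob is summed over the reinterpreted dim
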
Example #5
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        Performs backward on the ELBO of each particle.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                continue

            elbo += elbo_particle.item() / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and elbo_particle.requires_grad:
                loss_particle = -elbo_particle
                (loss_particle /
                 self.num_particles).backward(retain_graph=True)

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #6
    def _build_basetree(self, z, r, z_grads, log_slice, direction,
                        energy_current):
        step_size = self.step_size if direction == 1 else -self.step_size
        z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet(
            z, r, self._potential_energy, step_size, z_grads=z_grads)
        energy_new = potential_energy + self._kinetic_energy(r_new)
        sliced_energy = energy_new + log_slice

        # As part of the slice sampling process (see below), along the trajectory
        #     we eliminate states for which p(z, r) < u, i.e. dE > 0.
        # Due to this elimination (and the stop-doubling conditions),
        #     the size of the binary tree might not equal 2^tree_depth.
        tree_size = 1 if sliced_energy <= 0 else 0
        # Special case: Set diverging to True and accept prob to 0 if the
        # diverging trajectory returns `NaN` energy (e.g. in the case of
        # evaluating log prob of a value simulated using a large step size
        # for a constrained sample site).
        if torch_isnan(energy_new):
            diverging = True
            accept_prob = energy_new.new_tensor(0.0)
        else:
            diverging = (sliced_energy >= self._max_sliced_energy)
            delta_energy = energy_new - energy_current
            accept_prob = (-delta_energy).exp().clamp(max=1)
        return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, z_new,
                         tree_size, False, diverging, accept_prob, 1)
Example #7
File: nuts.py Project: zyxue/pyro
    def _build_basetree(self, z, r, z_grads, log_slice, direction,
                        energy_current):
        step_size = self.step_size if direction == 1 else -self.step_size
        z_new, r_new, z_grads, potential_energy = velocity_verlet(
            z,
            r,
            self._potential_energy,
            self.inverse_mass_matrix,
            step_size,
            z_grads=z_grads)
        r_new_flat = torch.cat(
            [r_new[site_name].reshape(-1) for site_name in sorted(r_new)])
        energy_new = potential_energy + self._kinetic_energy(r_new)
        # handle the NaN case
        energy_new = energy_new.new_tensor(
            float("inf")) if torch_isnan(energy_new) else energy_new
        sliced_energy = energy_new + log_slice
        diverging = (sliced_energy > self._max_sliced_energy)
        delta_energy = energy_new - energy_current
        accept_prob = (-delta_energy).exp().clamp(max=1.0)

        if self.use_multinomial_sampling:
            tree_weight = -sliced_energy
        else:
            # As part of the slice sampling process (see below), along the trajectory
            #   we eliminate states for which p(z, r) < u, i.e. dE > 0.
            # Due to this elimination (and the stop-doubling conditions),
            #   the weight of the binary tree might not equal 2^tree_depth.
            tree_weight = (sliced_energy.new_ones(()) if sliced_energy <= 0
                           else sliced_energy.new_zeros(()))

        return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, z_new,
                         potential_energy, z_grads, r_new_flat, tree_weight,
                         False, diverging, accept_prob, 1)
Example #8
def init_to_mean(
    site=None,
    *,
    fallback: Optional[Callable] = init_to_median,
):
    """
    Initialize to the prior mean; fall back to ``fallback`` (defaults to
    :func:`init_to_median`) if the mean is undefined.

    :param callable fallback: Fallback init strategy, used for sites whose
        mean is undefined.
    :raises ValueError: If ``fallback=None`` and the mean is undefined for a
        site.
    """
    if site is None:
        return functools.partial(init_to_mean, fallback=fallback)

    try:
        # Try .mean() method.
        value = site["fn"].mean.detach()
        if torch_isnan(value):
            raise ValueError
        if hasattr(site["fn"], "_validate_sample"):
            site["fn"]._validate_sample(value)
        value._pyro_custom_init = False
        return value
    except (NotImplementedError, ValueError):
        # This may happen for distributions with infinite variance, e.g. Cauchy.
        pass
    if fallback is not None:
        return fallback(site)
    raise ValueError(
        f"No init strategy specified for site {repr(site['name'])}")
Example #9
def test_masked_mixture_multivariate(sample_shape, batch_shape):
    event_shape = torch.Size((8,))
    component0 = dist.MultivariateNormal(
        torch.zeros(event_shape), torch.eye(event_shape[0])
    )
    component1 = dist.Uniform(
        torch.zeros(event_shape), torch.ones(event_shape)
    ).to_event(1)
    if batch_shape:
        component0 = component0.expand_by(batch_shape)
        component1 = component1.expand_by(batch_shape)
    mask = torch.empty(batch_shape).bernoulli_(0.5).bool()
    d = dist.MaskedMixture(mask, component0, component1)
    assert d.batch_shape == batch_shape
    assert d.event_shape == event_shape

    assert d.sample().shape == batch_shape + event_shape
    assert d.mean.shape == batch_shape + event_shape
    assert d.variance.shape == batch_shape + event_shape
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + batch_shape + event_shape

    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + batch_shape
    assert not torch_isnan(log_prob)
    log_prob_0 = component0.log_prob(x)
    log_prob_1 = component1.log_prob(x)
    mask = mask.expand(sample_shape + batch_shape)
    assert_equal(log_prob[mask], log_prob_1[mask])
    assert_equal(log_prob[~mask], log_prob_0[~mask])
Example #10
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. ``num_particles`` many samples are used to form the estimators.
        """
        elbo = 0.0
        # grab a trace from the generator
        for model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0
            log_r = None

            # compute elbo and surrogate elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    elbo_particle = elbo_particle + torch_item(
                        site["log_prob_sum"])
                    surrogate_elbo_particle = surrogate_elbo_particle + site[
                        "log_prob_sum"]

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site[
                        "score_parts"]

                    elbo_particle = elbo_particle - torch_item(
                        site["log_prob_sum"])

                    if not is_identically_zero(entropy_term):
                        surrogate_elbo_particle = (surrogate_elbo_particle
                                                   - entropy_term.sum())

                    if not is_identically_zero(score_function_term):
                        if log_r is None:
                            log_r = _compute_log_r(model_trace, guide_trace)
                        site = log_r.sum_to(site["cond_indep_stack"])
                        surrogate_elbo_particle = surrogate_elbo_particle + (
                            site * score_function_term).sum()

            elbo += elbo_particle / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and getattr(surrogate_elbo_particle,
                                            'requires_grad', False):
                surrogate_loss_particle = -surrogate_elbo_particle / self.num_particles
                surrogate_loss_particle.backward()

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #11
    def fit(
        self,
        x,
        t,
        y,
        num_epochs=100,
        batch_size=100,
        learning_rate=1e-3,
        learning_rate_decay=0.1,
        weight_decay=1e-4,
        log_every=100,
    ):
        """
        Train using :class:`~pyro.infer.svi.SVI` with the
        :class:`TraceCausalEffect_ELBO` loss.

        :param ~torch.Tensor x:
        :param ~torch.Tensor t:
        :param ~torch.Tensor y:
        :param int num_epochs: Number of training epochs. Defaults to 100.
        :param int batch_size: Batch size. Defaults to 100.
        :param float learning_rate: Learning rate. Defaults to 1e-3.
        :param float learning_rate_decay: Learning rate decay over all epochs;
            the per-step decay rate will depend on batch size and number of epochs
            such that the initial learning rate will be ``learning_rate`` and the final
            learning rate will be ``learning_rate * learning_rate_decay``.
            Defaults to 0.1.
        :param float weight_decay: Weight decay. Defaults to 1e-4.
        :param int log_every: Log loss each this-many steps. If zero,
            do not log loss. Defaults to 100.
        :return: list of epoch losses
        """
        assert x.dim() == 2 and x.size(-1) == self.feature_dim
        assert t.shape == x.shape[:1]
        assert y.shape == x.shape[:1]
        self.whiten = PreWhitener(x)

        dataset = TensorDataset(x, t, y)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        logger.info("Training with {} minibatches per epoch".format(
            len(dataloader)))
        num_steps = num_epochs * len(dataloader)
        optim = ClippedAdam({
            "lr": learning_rate,
            "weight_decay": weight_decay,
            "lrd": learning_rate_decay**(1 / num_steps),
        })
        svi = SVI(self.model, self.guide, optim, TraceCausalEffect_ELBO())
        losses = []
        for epoch in range(num_epochs):
            for x, t, y in dataloader:
                x = self.whiten(x)
                loss = svi.step(x, t, y, size=len(dataset)) / len(dataset)
                if log_every and len(losses) % log_every == 0:
                    logger.debug("step {: >5d} loss = {:0.6g}".format(
                        len(losses), loss))
                assert not torch_isnan(loss)
                losses.append(loss)
        return losses
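
A hedged usage sketch of a ``fit()`` method like the one above (``pyro.contrib.cevae.CEVAE``); the synthetic covariates, treatment and outcome below are assumptions for illustration only.

import torch
from pyro.contrib.cevae import CEVAE

x = torch.randn(1000, 5)                                   # covariates
t = (torch.rand(1000) < 0.5).float()                       # binary treatment
y = ((x[:, 0] + t + 0.1 * torch.randn(1000)) > 0).float()  # binary outcome
cevae = CEVAE(feature_dim=5)
losses = cevae.fit(x, t, y, num_epochs=2, batch_size=100)  # returns list of losses
ite = cevae.ite(x)                                         # individual treatment effects
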
Example #12
    def loss_and_grads(self, model, guide, *args, **kwargs):
        if getattr(self, '_loss_and_surrogate_loss', None) is None:
            # build a closure for loss_and_surrogate_loss
            weakself = weakref.ref(self)

            @pyro.ops.jit.compile(nderivs=1)
            def loss_and_surrogate_loss(*args):
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                for model_trace, guide_trace in self._get_traces(
                        model, guide, *args, **kwargs):
                    elbo_particle = 0
                    surrogate_elbo_particle = 0
                    log_r = None

                    # compute elbo and surrogate elbo
                    for name, site in model_trace.nodes.items():
                        if site["type"] == "sample":
                            elbo_particle = elbo_particle + site["log_prob_sum"]
                            surrogate_elbo_particle = surrogate_elbo_particle + site[
                                "log_prob_sum"]

                    for name, site in guide_trace.nodes.items():
                        if site["type"] == "sample":
                            log_prob, score_function_term, entropy_term = site[
                                "score_parts"]

                            elbo_particle = elbo_particle - site["log_prob_sum"]

                            if not is_identically_zero(entropy_term):
                                surrogate_elbo_particle = (surrogate_elbo_particle
                                                           - entropy_term.sum())

                            if not is_identically_zero(score_function_term):
                                if log_r is None:
                                    log_r = _compute_log_r(
                                        model_trace, guide_trace)
                                site = log_r.sum_to(site["cond_indep_stack"])
                                surrogate_elbo_particle = surrogate_elbo_particle + (
                                    site * score_function_term).sum()

                    loss = loss - elbo_particle / self.num_particles
                    surrogate_loss = surrogate_loss - surrogate_elbo_particle / self.num_particles

                return loss, surrogate_loss

            self._loss_and_surrogate_loss = loss_and_surrogate_loss

        # invoke _loss_and_surrogate_loss
        loss, surrogate_loss = self._loss_and_surrogate_loss(*args)
        surrogate_loss.backward()  # this line triggers jit compilation
        loss = loss.item()

        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #13
    def sample(self, params):
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return params
        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
                self.step_size, self.num_steps, z_grads=z_grads)
            # apply Metropolis correction.
            r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
            energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new
        delta_energy = energy_proposal - energy_current
        # handle the NaN case which may be the case for a diverging trajectory
        # when using a large step size.
        delta_energy = scalar_like(delta_energy, float("inf")) if torch_isnan(delta_energy) else delta_energy
        if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
            self._divergences.append(self._t - self._warmup_steps)

        accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(scalar_like(accept_prob, 0.),
                                                                     scalar_like(accept_prob, 1.)))
        accepted = False
        if rand < accept_prob:
            accepted = True
            z = z_new
            z_grads = z_grads_new
            self._cache(z, potential_energy_new, z_grads)

        self._t += 1
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob, z_grads)

        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
        return z.copy()
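
A minimal, hedged sketch of where a kernel's ``sample()`` method is driven from: the MCMC runner calls it once per iteration; the model and settings below are illustrative assumptions.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, HMC

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

data = torch.randn(50) + 2.0
kernel = HMC(model, step_size=0.1, num_steps=10)
mcmc = MCMC(kernel, num_samples=200, warmup_steps=100)
mcmc.run(data)                       # each iteration invokes kernel.sample() as above
loc_samples = mcmc.get_samples()["loc"]
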
Example #14
    def loss_and_grads(self, model, guide, *args, **kwargs):
        if getattr(self, '_loss_and_surrogate_loss', None) is None:
            # build a closure for loss_and_surrogate_loss
            weakself = weakref.ref(self)

            @pyro.ops.jit.compile(nderivs=1)
            def loss_and_surrogate_loss(*args):
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                for weight, model_trace, guide_trace in self._get_traces(
                        model, guide, *args, **kwargs):
                    model_trace.compute_log_prob()
                    guide_trace.compute_score_parts()
                    if is_validation_enabled():
                        for site in model_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site,
                                                 self.max_iarange_nesting)
                        for site in guide_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site,
                                                 self.max_iarange_nesting)

                    # compute elbo for reparameterized nodes
                    non_reparam_nodes = set(
                        guide_trace.nonreparam_stochastic_nodes)
                    elbo, surrogate_elbo = _compute_elbo_reparam(
                        model_trace, guide_trace, non_reparam_nodes)

                    # the following computations are only necessary if we have non-reparameterizable nodes
                    baseline_loss = 0.0
                    if non_reparam_nodes:
                        downstream_costs, _ = _compute_downstream_costs(
                            model_trace, guide_trace, non_reparam_nodes)
                        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                            guide_trace, non_reparam_nodes, downstream_costs)
                        surrogate_elbo += surrogate_elbo_term

                    loss = loss - weight * elbo
                    surrogate_loss = surrogate_loss - weight * surrogate_elbo

                return loss, surrogate_loss

            self._loss_and_surrogate_loss = loss_and_surrogate_loss

        loss, surrogate_loss = self._loss_and_surrogate_loss(*args)
        surrogate_loss.backward()  # this line triggers jit compilation
        loss = loss.item()

        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #15
File: eig.py Project: pyro-ppl/pyro
 def __call__(self, inputs, s, dim=0, keepdim=False):
     """Updates the moving average, and returns :code:`inputs.log()`."""
     self.n += 1
     if torch_isnan(self.ewma) or torch_isinf(self.ewma):
         ewma = inputs
     else:
         ewma = inputs * (1.0 - self.alpha) / (
             1 - self.alpha**self.n) + torch.exp(self.s - s) * self.ewma * (
                 self.alpha - self.alpha**self.n) / (1 - self.alpha**self.n)
     self.ewma = ewma.detach()
     self.s = s.detach()
     return _ewma_log_fn(inputs, ewma)
Example #16
    def _build_basetree(self, z, r, z_grads, log_slice, direction,
                        energy_current):
        step_size = self.step_size if direction == 1 else -self.step_size
        z_new, r_new, z_grads, potential_energy = velocity_verlet(
            z,
            r,
            self.potential_fn,
            self.mass_matrix_adapter.kinetic_grad,
            step_size,
            z_grads=z_grads,
        )
        r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
        energy_new = potential_energy + self._kinetic_energy(r_new_unscaled)
        # handle the NaN case
        energy_new = (scalar_like(energy_new, float("inf"))
                      if torch_isnan(energy_new) else energy_new)
        sliced_energy = energy_new + log_slice
        diverging = sliced_energy > self._max_sliced_energy
        delta_energy = energy_new - energy_current
        accept_prob = (-delta_energy).exp().clamp(max=1.0)

        if self.use_multinomial_sampling:
            tree_weight = -sliced_energy
        else:
            # As part of the slice sampling process (see below), along the trajectory
            #   we eliminate states for which p(z, r) < u, i.e. dE > 0.
            # Due to this elimination (and the stop-doubling conditions),
            #   the weight of the binary tree might not equal 2^tree_depth.
            tree_weight = scalar_like(sliced_energy,
                                      1.0 if sliced_energy <= 0 else 0.0)

        r_sum = r_new_unscaled
        return _TreeInfo(
            z_new,
            r_new,
            r_new_unscaled,
            z_grads,
            z_new,
            r_new,
            r_new_unscaled,
            z_grads,
            z_new,
            potential_energy,
            z_grads,
            r_sum,
            tree_weight,
            False,
            diverging,
            accept_prob,
            1,
        )
Example #17
    def sample(self, trace):
        z = {
            name: node["value"].detach()
            for name, node in self._iter_latent_nodes(trace)
        }
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])

        r, _ = self._sample_r(name="r_t={}".format(self._t))

        potential_energy, z_grads = self._fetch_from_cache()
        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False),
                      self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z,
                r,
                self._potential_energy,
                self.inverse_mass_matrix,
                self.step_size,
                self.num_steps,
                z_grads=z_grads)
            # apply Metropolis correction.
            energy_proposal = self._kinetic_energy(r_new) + potential_energy_new
            energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \
                else self._energy(z, r)
        delta_energy = energy_proposal - energy_current
        # Set accept prob to 0.0 if delta_energy is `NaN` which may be
        # the case for a diverging trajectory when using a large step size.
        if torch_isnan(delta_energy):
            accept_prob = delta_energy.new_tensor(0.0)
        else:
            accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = pyro.sample("rand_t={}".format(self._t),
                           dist.Uniform(torch.zeros(1), torch.ones(1)))
        if rand < accept_prob:
            self._accept_cnt += 1
            z = z_new

        if self._t < self._warmup_steps:
            self._adapter.step(self._t, z, accept_prob)

        self._t += 1

        # get trace with the constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
Example #18
    def loss_and_grads(self, model, guide, *args, **kwargs):
        if getattr(self, '_loss_and_surrogate_loss', None) is None:
            # build a closure for loss_and_surrogate_loss
            weakself = weakref.ref(self)

            @pyro.ops.jit.compile(nderivs=1)
            def loss_and_surrogate_loss(*args):
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
                    model_trace.compute_log_prob()
                    guide_trace.compute_score_parts()
                    if is_validation_enabled():
                        for site in model_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site, self.max_iarange_nesting)
                        for site in guide_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site, self.max_iarange_nesting)

                    # compute elbo for reparameterized nodes
                    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
                    elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

                    # the following computations are only necessary if we have non-reparameterizable nodes
                    baseline_loss = 0.0
                    if non_reparam_nodes:
                        downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
                        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace,
                                                                                       non_reparam_nodes,
                                                                                       downstream_costs)
                        surrogate_elbo += surrogate_elbo_term

                    loss = loss - weight * elbo
                    surrogate_loss = surrogate_loss - weight * surrogate_elbo

                return loss, surrogate_loss

            self._loss_and_surrogate_loss = loss_and_surrogate_loss

        loss, surrogate_loss = self._loss_and_surrogate_loss(*args)
        surrogate_loss.backward()  # this line triggers jit compilation
        loss = loss.item()

        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #19
def init_to_mean(site):
    """
    Initialize to the prior mean; fall back to the median if the mean is undefined.
    """
    try:
        # Try .mean() method.
        value = site["fn"].mean.detach()
        if torch_isnan(value):
            raise ValueError
        if hasattr(site["fn"], "_validate_sample"):
            site["fn"]._validate_sample(value)
        return value
    except (NotImplementedError, ValueError):
        # Fall back to the median.
        # This is required for distributions with infinite variance, e.g. Cauchy.
        return init_to_median(site)
Example #20
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(guide_trace.log_prob_sum())
            elbo += elbo_particle / self.num_particles

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
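
A hedged, self-contained sketch of calling ``loss()`` directly (outside of an SVI step) to obtain a float ELBO estimate for monitoring; the toy model and guide below are illustrative assumptions.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO

def model():
    pyro.sample("z", dist.Normal(0., 1.))

def guide():
    loc = pyro.param("loc", torch.tensor(0.5))
    pyro.sample("z", dist.Normal(loc, 1.))

loss = Trace_ELBO(num_particles=100).loss(model, guide)   # a plain float, -ELBO estimate
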
Example #21
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
        """
        elbo = 0.0
        for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(guide_trace.log_prob_sum())
            elbo += weight * elbo_particle

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #22
    def sample(self, trace):
        z = {
            name: node["value"].detach()
            for name, node in trace.iter_stochastic_nodes()
        }
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])
        r = {
            name: pyro.sample("r_{}_t={}".format(name, self._t),
                              self._r_dist[name])
            for name in self._r_dist
        }

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        dist_arg_check = (False if self._adapt_phase
                          else pyro.distributions.is_validation_enabled())
        with dist.validation_enabled(dist_arg_check):
            z_new, r_new = velocity_verlet(z, r, self._potential_energy,
                                           self.step_size, self.num_steps)
            # apply Metropolis correction.
            energy_proposal = self._energy(z_new, r_new)
            energy_current = self._energy(z, r)
        delta_energy = energy_proposal - energy_current
        rand = pyro.sample("rand_t={}".format(self._t),
                           dist.Uniform(torch.zeros(1), torch.ones(1)))
        if rand < (-delta_energy).exp():
            self._accept_cnt += 1
            z = z_new

        if self._adapt_phase:
            # Set accept prob to 0.0 if delta_energy is `NaN` which may be
            # the case for a diverging trajectory when using a large step size.
            if torch_isnan(delta_energy):
                accept_prob = delta_energy.new_tensor(0.0)
            else:
                accept_prob = (-delta_energy).exp().clamp(max=1).item()
            self._adapt_step_size(accept_prob)

        self._t += 1
        # get trace with the constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
Example #23
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                continue

            elbo += elbo_particle.item() / self.num_particles

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #24
def init_to_median(site, num_samples=15):
    """
    Initialize to the prior median; fall back to a feasible point if the
    median is undefined.
    """
    # The median is undefined for multivariate distributions.
    if _is_multivariate(site["fn"]):
        return init_to_feasible(site)
    try:
        # Try to compute empirical median.
        samples = site["fn"].sample(sample_shape=(num_samples,))
        value = samples.median(dim=0)[0]
        if torch_isnan(value):
            raise ValueError
        if hasattr(site["fn"], "_validate_sample"):
            site["fn"]._validate_sample(value)
        return value
    except (RuntimeError, ValueError):
        # Fall back to feasible point.
        return init_to_feasible(site)
Example #25
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                continue

            elbo += elbo_particle.item() / self.num_particles

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #26
 def initial_trace(self):
     """
     Find a valid trace to initiate the MCMC sampler. This is also used as a
     prototype trace to inter-convert between Pyro's trace object and dict
     object used by the integrator.
     """
     if self._initial_trace:
         return self._initial_trace
     trace = poutine.trace(self.model).get_trace(*self._args,
                                                 **self._kwargs)
     for i in range(self._max_tries_initial_trace):
         trace_log_prob_sum = self._compute_trace_log_prob(trace)
         if not torch_isnan(trace_log_prob_sum) and not torch_isinf(
                 trace_log_prob_sum):
             self._initial_trace = trace
             return trace
         trace = poutine.trace(self.model).get_trace(*self._args,
                                                     **self._kwargs)
     raise ValueError(
         "Model specification seems incorrect - cannot find a valid trace.")
Example #27
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # have the trace compute all the individual (batch) log pdf terms
        # and score function terms (if present) so that they are available below
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace,
                                                     non_reparam_nodes)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(
                model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                guide_trace, non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))

        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return weight * loss
Example #28
    def loss_and_grads(self, model, guide, *args, **kwargs):
        if getattr(self, '_differentiable_loss', None) is None:

            weakself = weakref.ref(self)

            @pyro.ops.jit.compile(nderivs=1)
            def differentiable_loss(*args):
                self = weakself()
                elbo = 0.0
                for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
                    elbo += _compute_dice_elbo(model_trace, guide_trace)
                return elbo * (-1.0 / self.num_particles)

            self._differentiable_loss = differentiable_loss

        differentiable_loss = self._differentiable_loss(*args)
        differentiable_loss.backward()  # this line triggers jit compilation
        loss = differentiable_loss.item()

        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
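
A hedged sketch of how a JIT-compiled ELBO like the one above is usually selected through the public ``pyro.infer.JitTrace_ELBO`` class; the model, guide and data below are illustrative assumptions, and the first step triggers compilation as in ``loss_and_grads`` above.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(loc_q, 1.))

data = torch.randn(32)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=JitTrace_ELBO(ignore_jit_warnings=True))
svi.step(data)   # first call compiles; later calls reuse the traced graph
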
Example #29
def init_to_median(
    site=None,
    num_samples=15,
    *,
    fallback: Optional[Callable] = init_to_feasible,
):
    """
    Initialize to the prior median; fall back to ``fallback`` (defaults to
    :func:`init_to_feasible`) if the median is undefined.

    :param callable fallback: Fallback init strategy, used for sites whose
        median is undefined.
    :raises ValueError: If ``fallback=None`` and the median is undefined for a
        site.
    """
    if site is None:
        return functools.partial(init_to_median,
                                 num_samples=num_samples,
                                 fallback=fallback)

    # The median is undefined for multivariate distributions.
    if _is_multivariate(site["fn"]):
        return init_to_feasible(site)
    try:
        # Try to compute empirical median.
        samples = site["fn"].sample(sample_shape=(num_samples, ))
        value = samples.median(dim=0)[0]
        if torch_isnan(value):
            raise ValueError
        if hasattr(site["fn"], "_validate_sample"):
            site["fn"]._validate_sample(value)
        value._pyro_custom_init = False
        return value
    except (RuntimeError, ValueError):
        pass
    if fallback is not None:
        return fallback(site)
    raise ValueError(
        f"No init strategy specified for site {repr(site['name'])}")
Example #30
File: hmc.py Project: lewisKit/pyro
    def sample(self, trace):
        z = {name: node["value"].detach() for name, node in trace.iter_stochastic_nodes()}
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])
        r = {name: pyro.sample("r_{}_t={}".format(name, self._t), self._r_dist[name])
             for name in self._r_dist}

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        dist_arg_check = False if self._adapt_phase else pyro.distributions.is_validation_enabled()
        with dist.validation_enabled(dist_arg_check):
            z_new, r_new = velocity_verlet(z, r,
                                           self._potential_energy,
                                           self.step_size,
                                           self.num_steps)
            # apply Metropolis correction.
            energy_proposal = self._energy(z_new, r_new)
            energy_current = self._energy(z, r)
        delta_energy = energy_proposal - energy_current
        rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(torch.zeros(1), torch.ones(1)))
        if rand < (-delta_energy).exp():
            self._accept_cnt += 1
            z = z_new

        if self._adapt_phase:
            # Set accept prob to 0.0 if delta_energy is `NaN` which may be
            # the case for a diverging trajectory when using a large step size.
            if torch_isnan(delta_energy):
                accept_prob = delta_energy.new_tensor(0.0)
            else:
                accept_prob = (-delta_energy).exp().clamp(max=1).item()
            self._adapt_step_size(accept_prob)

        self._t += 1
        # get trace with the constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
Example #31
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # have the trace compute all the individual (batch) log pdf terms
        # and score function terms (if present) so that they are available below
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace,
                                                                           non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))

        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return weight * loss
Example #32
def test_masked_mixture_univariate(component0, component1, sample_shape, batch_shape):
    if batch_shape:
        component0 = component0.expand_by(batch_shape)
        component1 = component1.expand_by(batch_shape)
    mask = torch.empty(batch_shape).bernoulli_(0.5).bool()
    d = dist.MaskedMixture(mask, component0, component1)
    assert d.batch_shape == batch_shape
    assert d.event_shape == ()

    assert d.sample().shape == batch_shape
    assert d.mean.shape == batch_shape
    assert d.variance.shape == batch_shape
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + batch_shape

    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + batch_shape
    assert not torch_isnan(log_prob)
    log_prob_0 = component0.log_prob(x)
    log_prob_1 = component1.log_prob(x)
    mask = mask.expand(sample_shape + batch_shape)
    assert_equal(log_prob[mask], log_prob_1[mask])
    assert_equal(log_prob[~mask], log_prob_0[~mask])
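
The masking semantics checked by these tests can be shown in a small standalone sketch; the mask and the two Normal components below are illustrative assumptions.

import torch
import pyro.distributions as dist

mask = torch.tensor([True, False, True])
component0 = dist.Normal(torch.zeros(3), torch.ones(3))         # used where mask is False
component1 = dist.Normal(torch.full((3,), 5.0), torch.ones(3))  # used where mask is True
d = dist.MaskedMixture(mask, component0, component1)
x = d.sample()
assert torch.allclose(d.log_prob(x)[mask], component1.log_prob(x)[mask])
assert torch.allclose(d.log_prob(x)[~mask], component0.log_prob(x)[~mask])
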
Example #33
    def loss_and_grads(self, model, guide, *args, **kwargs):
        if getattr(self, '_differentiable_loss', None) is None:

            weakself = weakref.ref(self)

            @pyro.ops.jit.compile(nderivs=1)
            def differentiable_loss(*args):
                self = weakself()
                elbo = 0.0
                for model_trace, guide_trace in self._get_traces(
                        model, guide, *args, **kwargs):
                    elbo += _compute_dice_elbo(model_trace, guide_trace)
                return elbo * (-1.0 / self.num_particles)

            self._differentiable_loss = differentiable_loss

        differentiable_loss = self._differentiable_loss(*args)
        differentiable_loss.backward()  # this line triggers jit compilation
        loss = differentiable_loss.item()

        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #34
    def fit(self,
            x,
            t,
            y,
            num_epochs=100,
            batch_size=100,
            learning_rate=1e-3,
            learning_rate_decay=0.1,
            weight_decay=1e-4,
            treg_weight=0.5):
        """
        Train using :class:`~pyro.infer.svi.SVI` with the
        :class:`TraceCausalEffect_ELBO` loss.

        :param ~torch.Tensor x:
        :param ~torch.Tensor t:
        :param ~torch.Tensor y:
        :param int num_epochs: Number of training epochs. Defaults to 100.
        :param int batch_size: Batch size. Defaults to 100.
        :param float learning_rate: Learning rate. Defaults to 1e-3.
        :param float learning_rate_decay: Learning rate decay over all epochs;
            the per-step decay rate will depend on batch size and number of epochs
            such that the initial learning rate will be ``learning_rate`` and the final
            learning rate will be ``learning_rate * learning_rate_decay``.
            Defaults to 0.1.
        :param float weight_decay: Weight decay. Defaults to 1e-4.
        :return: list of epoch losses
        """

        assert x.dim() == 2 and x.size(-1) == self.feature_dim
        assert t.shape == x.shape[:1]
        assert y.shape == x.shape[:1]
        # self.whiten = PreWhitener(x)

        self.tboard = None
        if self.config["tb"]:
            config_time_of_run = str(pytz.utc.localize(
                datetime.utcnow())).split(".")[0][-8:]
            self.tboard = SummaryWriter(
                log_dir=os.path.join(self.config["tb_dir"], "TVAE_%s/" %
                                     config_time_of_run))

        dataset = TensorDataset(x, t, y)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        print("Training with {} minibatches per epoch".format(len(dataloader)))
        num_steps = num_epochs * len(dataloader)

        inf_y_params = [
            param for name, param in list(self.guide.named_parameters())
            if param.requires_grad and 'z' not in name and 'y' in name
        ]
        gen_y_eps_params = [
            param for name, param in list(self.model.named_parameters())
            if param.requires_grad and 'z' not in name and 'y' in name
            or 'eps' in name
        ]

        inf_all = [
            param for name, param in list(self.guide.named_parameters())
        ]

        gen_all_bar_eps = [
            param for name, param in list(self.model.named_parameters())
            if param.requires_grad and 'eps' not in name
        ]

        main_params = list(gen_all_bar_eps) + list(inf_all)

        treg_params = list(inf_y_params) + list(gen_y_eps_params)

        optim_main = torch.optim.Adam([{
            "params": main_params,
            "lr": learning_rate,
            "weight_decay": weight_decay,
            "lrd": learning_rate_decay**(1 / num_steps),
        }])
        optim_treg = torch.optim.Adam([{
            "params": treg_params,
            "lr": learning_rate,
            "weight_decay": weight_decay,
            "lrd": learning_rate_decay**(1 / num_steps),
        }])

        loss_fn = TraceCausalEffect_ELBO().differentiable_loss

        total_losses = []
        # torch.autograd.set_detect_anomaly(True)
        for epoch in range(num_epochs):
            print('Epoch:', epoch)
            for x, t, y in dataloader:
                # trace = poutine.trace(self.model).get_trace(x)
                # trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
                # print(trace.format_shapes())

                main_loss = loss_fn(
                    self.model, self.guide, x, t, y,
                    size=len(dataset)) / len(dataset)
                main_loss.backward()
                optim_main.step()
                t_reg_loss = treg_weight * self.tl_reg(x, t, y)
                t_reg_loss.backward()
                optim_treg.step()
                optim_main.zero_grad()
                optim_treg.zero_grad()
                total_loss = (main_loss + t_reg_loss) / x.size(0)
                print("step {: >5d} loss = {:0.6g}".format(
                    len(total_losses), total_loss))
                assert not torch_isnan(total_loss)
                total_losses.append(total_loss)

            if self.config["tb"]:
                self.tboard.add_scalar("total loss", total_loss.item(),
                                       len(total_losses))
                self.tboard.add_scalar("main loss",
                                       main_loss.item() / x.size(0),
                                       len(total_losses))
                self.tboard.add_scalar("treg loss",
                                       t_reg_loss.item() / x.size(0),
                                       len(total_losses))
                self.tboard.add_scalar("epsilon", self.model.epsilon.item(),
                                       len(total_losses))

        return total_losses
Example #35
    def loss_and_grads(self, model, guide, *args, **kwargs):
        # TODO: add argument lambda --> assigns weights to losses
        # TODO: Normalize loss elbo value if not done
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. ``num_particles`` many samples are used to form the estimators.
        """

        elbo = 0.0
        dyn_loss = 0.0
        dim_loss = 0.0

        # grab a trace from the generator
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0
            log_r = None

            ys = []
            # compute elbo and surrogate elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
                    surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site["score_parts"]

                    elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

                    if not is_identically_zero(entropy_term):
                        surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

                    if not is_identically_zero(score_function_term):
                        if log_r is None:
                            log_r = _compute_log_r(model_trace, guide_trace)
                        site = log_r.sum_to(site["cond_indep_stack"])
                        surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum()

                    if site["name"].startswith("y_"):
                        # TODO: check order of y
                        ys.append(site["value"])
            man = torch.stack(ys, dim=1)
            mean_man = man.mean(dim=1, keepdim=True)
            man = man - mean_man
            dyn_loss += self._get_logdet_loss(man, delta=self.delta)  # TODO: Normalize
            dim_loss += self._get_traceK_loss(man)
            elbo += elbo_particle / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False):
                surrogate_loss_particle = (-surrogate_elbo_particle / self.num_particles
                                           + self.lam * dyn_loss
                                           + self.gam * dim_loss)
                surrogate_loss_particle.backward()

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss, dyn_loss.item(), dim_loss.item(), man
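
In the guide-side loop above, sites without a reparameterised sample contribute through the score-function (REINFORCE) term: a detached learning signal (log_r) multiplied by the site's log-probability, so that backpropagation yields the REINFORCE gradient. A minimal standalone sketch of that estimator on a toy categorical distribution (the distribution and reward below are illustrative):

import torch

# Score-function (REINFORCE) gradient: grad E_q[f(z)] = E_q[f(z) * grad log q(z)]
logits = torch.zeros(3, requires_grad=True)
q = torch.distributions.Categorical(logits=logits)
z = q.sample()                                 # non-reparameterised sample
reward = (z == 2).float()                      # stand-in for the learning signal log_r
surrogate = reward.detach() * q.log_prob(z)    # analogue of (log_r * score_function_term)
(-surrogate).backward()                        # backward on the negative surrogate
print(logits.grad)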
Example No. 36
0
File: hmc.py Project: lewisKit/pyro
    def _validate_trace(self, trace):
        trace_log_prob_sum = trace.log_prob_sum()
        if torch_isnan(trace_log_prob_sum) or torch_isinf(trace_log_prob_sum):
            raise ValueError("Model specification incorrect - trace log pdf is NaN or Inf.")
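
The torch_isnan / torch_isinf helpers used here come from pyro.util; roughly, they accept either a Python number or a tensor and report whether any entry is NaN or infinite. A sketch of what such checks can look like (illustrative, not the exact pyro.util code):

import numbers
import torch

def torch_isnan(x):
    """True if x (a Python number or a tensor) contains any NaN."""
    if isinstance(x, numbers.Number):
        return x != x
    return bool(torch.isnan(x).any())

def torch_isinf(x):
    """True if x (a Python number or a tensor) contains any +/-inf."""
    if isinstance(x, numbers.Number):
        return x in (float("inf"), float("-inf"))
    return bool(torch.isinf(x).any())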
Example No. 37
0
    def sample(self, trace):
        z, potential_energy, z_grads = self._fetch_from_cache()
        r, _ = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r) + potential_energy

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(z, r, self._potential_energy,
                                                                              self.inverse_mass_matrix,
                                                                              self.step_size,
                                                                              self.batch_size,
                                                                              self.num_steps,
                                                                              z_grads=z_grads)
            # apply Metropolis correction.
            energy_proposal = self._kinetic_energy(r_new) + potential_energy_new
        delta_energy = energy_proposal - energy_current
        # Set accept prob to 0.0 if delta_energy is `NaN` which may be
        # the case for a diverging trajectory when using a large step size.
        if torch_isnan(delta_energy):
            accept_prob = delta_energy.new_tensor(0.0)
        else:
            accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = torch.rand(self.batch_size)
        accepted = rand < accept_prob
        self._accept_cnt += accepted.sum()/self.batch_size

        # select accepted zs to get z_new
        transitioned_z = {}
        transitioned_grads = {}
        for name in z:
            assert len(z_grads[name].shape) == 2
            assert z_grads[name].shape[0] == self.batch_size
            assert len(z[name].shape) == 2
            assert z[name].shape[0] == self.batch_size
            old_val = z[name]
            old_grad = z_grads[name]
            new_val = z_new[name]
            new_grad = z_grads_new[name]
            val_dim = old_val.shape[1]
            accept_val = accepted.view(self.batch_size, 1).repeat(1, val_dim)
            transitioned_z[name] = torch.where(accept_val,
                                               new_val,
                                               old_val)
            transitioned_grads[name] = torch.where(accept_val,
                                                   new_grad,
                                                   old_grad)

        self._cache(transitioned_z,
                    potential_energy,
                    transitioned_grads)

        if self._t < self._warmup_steps:
            self._adapter.step(self._t, transitioned_z, accept_prob)

        self._t += 1

        # get trace with the constrained values for `z`.
        z = transitioned_z.copy()
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
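
The selection loop above applies one accept/reject decision per chain in the batch and uses torch.where to keep either the proposed or the previous value per row. A minimal sketch of that batched Metropolis selection on toy tensors (shapes and names are illustrative):

import torch

batch_size, val_dim = 4, 2
old_val = torch.zeros(batch_size, val_dim)     # current states, one row per chain
new_val = torch.ones(batch_size, val_dim)      # proposed states

accept_prob = torch.tensor([0.9, 0.1, 0.8, 0.2])
accepted = torch.rand(batch_size) < accept_prob             # one decision per chain
accept_val = accepted.view(batch_size, 1).repeat(1, val_dim)

# keep the proposal where accepted, fall back to the old value otherwise
transitioned = torch.where(accept_val, new_val, old_val)
print(transitioned)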
Example No. 38
0
    def _validate_trace(self, trace):
        trace_log_prob_sum = trace.log_prob_sum()
        if torch_isnan(trace_log_prob_sum) or torch_isinf(trace_log_prob_sum):
            raise ValueError(
                "Model specification incorrect - trace log pdf is NaN or Inf.")
Example No. 39
0
    def train(self):
        print('Training model...')
        self.vdsm_encdec_loss_fn = Trace_ELBO().differentiable_loss
        self.vdsm_seq_loss_fn = TraceEnum_ELBO(
            max_plate_nesting=2).differentiable_loss

        for self.current_epoch in range(self.starting_epoch, self.epochs):
            anneal_t = self.anneals_t[self.current_epoch]  # anneals the i.i.d. pose latent (outer)
            anneal_dynamics = self.anneals_dynamics[self.current_epoch]  # anneals the dynamics latent (inner)
            anneal_id = self.anneals_id[self.current_epoch]  # anneals learning the per-sequence ID KL (outer)
            temp_id = self.temps_id[self.current_epoch]  # anneals the temperature of the per-sequence ID simplex dist (outer)

            print('Ann. z', anneal_t, 'Ann. dyn', anneal_dynamics, 'Ann. id',
                  anneal_id, 'id temp', temp_id)
            epoch_loss = torch.tensor([0.]).to(self.dev)

            if self.train_VDSMEncDec:
                self.VDSM_EncDec.train()
            else:
                self.VDSM_EncDec.eval()

            if self.train_VDSMSeq:
                self.VDSMSeq.train()
            else:
                self.VDSMSeq.eval()

            for b in range(self.bs_per_epoch):

                if self.train_VDSMEncDec:
                    if self.dataset_name == 'MUG-FED':
                        x, _ = next(iter(self.dataloader_test))
                    elif self.dataset_name == 'sprites':
                        x, _ = next(iter(self.dataloader_train))

                    num_individuals, num_timepoints, pixels = x.view(
                        x.shape[0], x.shape[1], self.imsize**2 * self.nc).shape

                    loss = self.vdsm_encdec_loss_fn(
                        model=self.VDSM_EncDec.model,
                        guide=self.VDSM_EncDec.guide,
                        x=x.to(self.dev),
                        temp=torch.tensor(temp_id).to(self.dev),
                        anneal_id=anneal_id,
                        anneal_t=anneal_t)

                    assert not torch_isnan(loss)
                    loss.backward()
                    self.optim_VDSM_EncDec.step()

                    self.optim_VDSM_EncDec.zero_grad()
                    epoch_loss += loss

                elif self.train_VDSMSeq:
                    if self.dataset_name == 'MUG-FED':
                        x, _ = next(iter(self.dataloader_test))
                    elif self.dataset_name == 'sprites':
                        _, x = next(iter(self.dataloader_train))
                        x = (x['sprite'] + 1) / 2

                    num_timepoints = x.shape[1]

                    loss = self.vdsm_seq_loss_fn(
                        model=self.VDSMSeq.model,
                        guide=self.VDSMSeq.guide,
                        anneal_t=torch.tensor(anneal_t),
                        temp_id=torch.tensor(temp_id),
                        x=x.to(self.dev),
                        anneal_dynamics=anneal_dynamics,
                        anneal_id=anneal_id)

                    assert not torch_isnan(loss)

                    loss.backward(retain_graph=False)
                    print(self.current_epoch, loss)
                    self.optim_VDSM_Seq.step()
                    self.optim_VDSM_Seq.zero_grad()

                    epoch_loss += loss
            epoch_loss = epoch_loss / self.bs_per_epoch / (self.imsize ** 2) \
                         / self.nc / x.shape[0]
            # epoch_loss = epoch_loss / self.bs_per_epoch / (num_timepoints - 1) / self.bs / (self.imsize**2*self.nc)

            if self.tboard_log:
                self.tboard.add_scalar("total loss", epoch_loss,
                                       self.current_epoch)
                self.tboard.add_scalar("id anneal", anneal_id,
                                       self.current_epoch)
                self.tboard.add_scalar("zt anneal", anneal_t,
                                       self.current_epoch)
                self.tboard.add_scalar("dyanmics anneal", anneal_dynamics,
                                       self.current_epoch)
                self.tboard.add_scalar("id temp", temp_id, self.current_epoch)

            if ((self.current_epoch > 0) and
                (self.current_epoch % self.model_save_interval == 0)) or (
                    (self.current_epoch + 1) == self.epochs):
                self.save_model_opt_sched()

            if (self.current_epoch % self.model_test_interval
                    == 0) or ((self.current_epoch + 1) == self.epochs):
                self.test(self.current_epoch)
            print('epoch', self.current_epoch, 'loss', epoch_loss.item())
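
The train loop above drives Pyro's differentiable_loss with plain PyTorch optimizers: compute the loss, call backward, then step and zero the optimizer. A minimal standalone sketch of that pattern on a toy model/guide pair (the model, guide, and parameter names are illustrative, not the VDSM networks):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO

pyro.clear_param_store()

def model(x):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=x)

def guide(x):
    q_loc = pyro.param("q_loc", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(q_loc, 0.1))

loss_fn = Trace_ELBO().differentiable_loss
x = torch.randn(16) + 3.

loss_fn(model, guide, x)                       # one call so the guide registers its params
params = [p for _, p in pyro.get_param_store().named_parameters()]
optim = torch.optim.Adam(params, lr=0.05)

for step in range(100):
    loss = loss_fn(model, guide, x)            # differentiable scalar loss
    assert not torch.isnan(loss)
    loss.backward()
    optim.step()
    optim.zero_grad()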