def _differentiable_loss_particle(self, model_trace, guide_trace):
        elbo_particle = 0
        surrogate_elbo_particle = 0
        log_r = None

        # compute elbo and surrogate elbo
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                elbo_particle = elbo_particle + torch_item(
                    site["log_prob_sum"])
                surrogate_elbo_particle = surrogate_elbo_particle + site[
                    "log_prob_sum"]

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site[
                    "score_parts"]

                elbo_particle = elbo_particle - torch_item(
                    site["log_prob_sum"])

                if not is_identically_zero(entropy_term):
                    surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum(
                    )

                if not is_identically_zero(score_function_term):
                    if log_r is None:
                        log_r = _compute_log_r(model_trace, guide_trace)
                    site = log_r.sum_to(site["cond_indep_stack"])
                    surrogate_elbo_particle = surrogate_elbo_particle + (
                        site * score_function_term).sum()

        return -elbo_particle, -surrogate_elbo_particle
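
The surrogate term `(site * score_function_term).sum()` is the score-function (REINFORCE) correction needed when a guide site is not reparameterized. A minimal, self-contained sketch of that idea (not Pyro internals), using a hypothetical downstream cost in place of `log_r`:

import torch

theta = torch.tensor(0.5, requires_grad=True)
q = torch.distributions.Normal(theta, 1.0)
z = q.sample()                        # sampling itself is non-differentiable
downstream_cost = (z ** 2).detach()   # stands in for log_r in the code above
surrogate = downstream_cost * q.log_prob(z)
surrogate.backward()                  # theta.grad is a one-sample score-function gradient estimate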
Example #2
    def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False):
        pyro.clear_param_store()

        def model():
            with pyro.plate('samples', self.sample_batch_size):
                pyro.sample(
                    "loc_latent",
                    dist.Normal(
                        torch.stack([self.loc0] * self.sample_batch_size,
                                    dim=0),
                        torch.stack([torch.pow(self.lam0, -0.5)] *
                                    self.sample_batch_size,
                                    dim=0)).to_event(1))

        def guide():
            loc_q = pyro.param("loc_q", self.loc0.detach() + 0.134)
            log_sig_q = pyro.param(
                "log_sig_q", -0.5 * torch.log(self.lam0).detach() - 0.14)
            sig_q = torch.exp(log_sig_q)
            Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
            with pyro.plate('samples', self.sample_batch_size):
                pyro.sample(
                    "loc_latent",
                    Normal(
                        torch.stack([loc_q] * self.sample_batch_size, dim=0),
                        torch.stack([sig_q] * self.sample_batch_size,
                                    dim=0)).to_event(1))

        adam = optim.Adam({"lr": .001})
        svi = SVI(model, guide, adam, loss=loss)

        alpha = 0.99
        for k in range(n_steps):
            svi.step()
            if debug:
                loc_error = param_mse("loc_q", self.loc0)
                log_sig_error = param_mse("log_sig_q",
                                          -0.5 * torch.log(self.lam0))
                with torch.no_grad():
                    if k == 0:
                        avg_loglikelihood, avg_penalty = loss._differentiable_loss_parts(
                            model, guide)
                        avg_loglikelihood = torch_item(avg_loglikelihood)
                        avg_penalty = torch_item(avg_penalty)
                    loglikelihood, penalty = loss._differentiable_loss_parts(
                        model, guide)
                    avg_loglikelihood = alpha * avg_loglikelihood + (
                        1 - alpha) * torch_item(loglikelihood)
                    avg_penalty = alpha * avg_penalty + (
                        1 - alpha) * torch_item(penalty)
                if k % 100 == 0:
                    print(loc_error, log_sig_error)
                    print(avg_loglikelihood, avg_penalty)
                    print()

        loc_error = param_mse("loc_q", self.loc0)
        log_sig_error = param_mse("log_sig_q", -0.5 * torch.log(self.lam0))
        assert_equal(0.0, loc_error, prec=0.05)
        assert_equal(0.0, log_sig_error, prec=0.05)
Example #3
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. num_particles many samples are used to form the estimators.
        """
        elbo = 0.0
        # grab a trace from the generator
        for model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0
            log_r = None

            # compute elbo and surrogate elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    elbo_particle = elbo_particle + torch_item(
                        site["log_prob_sum"])
                    surrogate_elbo_particle = surrogate_elbo_particle + site[
                        "log_prob_sum"]

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site[
                        "score_parts"]

                    elbo_particle = elbo_particle - torch_item(
                        site["log_prob_sum"])

                    if not is_identically_zero(entropy_term):
                        surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum(
                        )

                    if not is_identically_zero(score_function_term):
                        if log_r is None:
                            log_r = _compute_log_r(model_trace, guide_trace)
                        site = log_r.sum_to(site["cond_indep_stack"])
                        surrogate_elbo_particle = surrogate_elbo_particle + (
                            site * score_function_term).sum()

            elbo += elbo_particle / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and getattr(surrogate_elbo_particle,
                                            'requires_grad', False):
                surrogate_loss_particle = -surrogate_elbo_particle / self.num_particles
                surrogate_loss_particle.backward()

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
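
Each particle's surrogate loss is backpropagated separately and scaled by `1 / num_particles`; because gradients accumulate in `.grad`, this is equivalent to backpropagating the particle-averaged loss. A small sketch with hypothetical per-particle losses:

import torch

w = torch.tensor(1.0, requires_grad=True)
num_particles = 4
particle_losses = [w * k for k in range(1, num_particles + 1)]  # hypothetical losses sharing parameter w
for particle_loss in particle_losses:
    (particle_loss / num_particles).backward()                  # gradients accumulate across calls
# w.grad == (1 + 2 + 3 + 4) / 4 == 2.5, the gradient of the averaged loss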
Example #4
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
        """
        elbo_particles = []
        is_vectorized = self.vectorize_particles and self.num_particles > 1

        # grab a vectorized trace from the generator
        for model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = 0.

            # compute elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    if is_vectorized:
                        log_prob_sum = site["log_prob"].detach().reshape(
                            self.num_particles, -1).sum(-1)
                    else:
                        log_prob_sum = torch_item(site["log_prob_sum"])

                    elbo_particle = elbo_particle + log_prob_sum

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site[
                        "score_parts"]
                    if is_vectorized:
                        log_prob_sum = log_prob.detach().reshape(
                            self.num_particles, -1).sum(-1)
                    else:
                        log_prob_sum = torch_item(site["log_prob_sum"])

                    elbo_particle = elbo_particle - log_prob_sum

            elbo_particles.append(elbo_particle)

        if is_vectorized:
            elbo_particles = elbo_particles[0]
        else:
            elbo_particles = torch.tensor(
                elbo_particles)  # no need to use .new*() here

        log_weights = (1. - self.alpha) * elbo_particles
        log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(
            self.num_particles)
        elbo = log_mean_weight.sum().item() / (1. - self.alpha)

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
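
The final lines compute a Renyi-style bound: the log-mean-exp of `(1 - alpha) * elbo_k`, divided by `(1 - alpha)`. A short sketch of the same arithmetic on hypothetical per-particle ELBO values:

import math
import torch

alpha = 0.5
elbo_particles = torch.tensor([-1.2, -0.8, -1.0])   # hypothetical per-particle ELBO estimates
log_weights = (1.0 - alpha) * elbo_particles
log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(len(elbo_particles))
renyi_bound = log_mean_weight.item() / (1.0 - alpha)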
Example #5
File: ess.py  Project: pyro-ppl/pyro-models
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
        """
        elbo_particles = []
        is_vectorized = self.vectorize_particles and self.num_particles > 1

        # grab a vectorized trace from the generator
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = 0.

            # compute elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    if is_vectorized:
                        log_prob_sum = site["log_prob"].detach().reshape(self.num_particles, -1).sum(-1)
                    else:
                        log_prob_sum = torch_item(site["log_prob_sum"])

                    elbo_particle = elbo_particle + log_prob_sum

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site["score_parts"]
                    if is_vectorized:
                        log_prob_sum = log_prob.detach().reshape(self.num_particles, -1).sum(-1)
                    else:
                        log_prob_sum = torch_item(site["log_prob_sum"])

                    elbo_particle = elbo_particle - log_prob_sum

            elbo_particles.append(elbo_particle)

        if is_vectorized:
            elbo_particles = elbo_particles[0]
        else:
            elbo_particles = torch.tensor(elbo_particles)  # no need to use .new*() here

        elbo_particles = elbo_particles.view(self.num_outer, self.num_inner)

        # normalized importance weights (in log space) and per-group effective sample size
        log_w_norm = elbo_particles - torch.logsumexp(elbo_particles, dim=1, keepdim=True)
        ess_val = torch.exp(-torch.logsumexp(2 * log_w_norm, dim=1))

        loss = ess_val.mean()
        warn_if_nan(loss, "loss")
        return loss.item()
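
The loss here is an effective sample size (ESS) of the normalized importance weights within each group of `num_inner` particles: ESS = 1 / sum_k w_k^2, computed stably in log space. A quick check of that identity on hypothetical weights:

import torch

log_w = torch.tensor([[-1.0, -2.0, -0.5]])                        # hypothetical unnormalized log weights
log_w_norm = log_w - torch.logsumexp(log_w, dim=1, keepdim=True)  # normalize in log space
ess = torch.exp(-torch.logsumexp(2 * log_w_norm, dim=1))
w = log_w_norm.exp()
assert torch.allclose(ess, 1.0 / (w ** 2).sum(dim=1))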
Example #6
    def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False):
        pyro.clear_param_store()
        Beta = dist.Beta if reparameterized else fakes.NonreparameterizedBeta

        def model():
            with pyro.plate('samples', self.sample_batch_size):
                pyro.sample(
                    "p_latent", Beta(
                        torch.stack([torch.stack([self.alpha0])]*self.sample_batch_size),
                        torch.stack([torch.stack([self.beta0])]*self.sample_batch_size)
                    ).to_event(1)
                )

        def guide():
            alpha_q_log = pyro.param("alpha_q_log",
                                     torch.log(self.alpha0) + 0.17)
            beta_q_log = pyro.param("beta_q_log",
                                    torch.log(self.beta0) - 0.143)
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            with pyro.plate('samples', self.sample_batch_size):
                pyro.sample(
                    "p_latent", Beta(
                        torch.stack([torch.stack([alpha_q])]*self.sample_batch_size),
                        torch.stack([torch.stack([beta_q])]*self.sample_batch_size)
                    ).to_event(1)
                )

        adam = optim.Adam({"lr": .001, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss=loss)

        alpha = 0.99
        for k in range(n_steps):
            svi.step()
            if debug:
                alpha_error = param_abs_error("alpha_q_log", torch.log(self.alpha0))
                beta_error = param_abs_error("beta_q_log", torch.log(self.beta0))
                with torch.no_grad():
                    if k == 0:
                        avg_loglikelihood, avg_penalty = loss._differentiable_loss_parts(model, guide)
                        avg_loglikelihood = torch_item(avg_loglikelihood)
                        avg_penalty = torch_item(avg_penalty)
                    loglikelihood, penalty = loss._differentiable_loss_parts(model, guide)
                    avg_loglikelihood = alpha * avg_loglikelihood + (1-alpha) * torch_item(loglikelihood)
                    avg_penalty = alpha * avg_penalty + (1-alpha) * torch_item(penalty)
                if k % 100 == 0:
                    print(alpha_error, beta_error)
                    print(avg_loglikelihood, avg_penalty)
                    print()

        alpha_error = param_abs_error("alpha_q_log", torch.log(self.alpha0))
        beta_error = param_abs_error("beta_q_log", torch.log(self.beta0))
        assert_equal(0.0, alpha_error, prec=0.08)
        assert_equal(0.0, beta_error, prec=0.08)
Example #7
    def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False, lr=0.0002):
        pyro.clear_param_store()
        Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma

        def model():
            with pyro.plate('samples', self.sample_batch_size):
                pyro.sample(
                    "lambda_latent", Gamma(
                        torch.stack([torch.stack([self.alpha0])]*self.sample_batch_size),
                        torch.stack([torch.stack([self.beta0])]*self.sample_batch_size)
                    ).to_event(1)
                )

        def guide():
            alpha_q = pyro.param("alpha_q", self.alpha0.detach() + math.exp(0.17),
                                 constraint=constraints.positive)
            beta_q = pyro.param("beta_q", self.beta0.detach() / math.exp(0.143),
                                constraint=constraints.positive)
            with pyro.plate('samples', self.sample_batch_size):
                pyro.sample(
                    "lambda_latent", Gamma(
                        torch.stack([torch.stack([alpha_q])]*self.sample_batch_size),
                        torch.stack([torch.stack([beta_q])]*self.sample_batch_size)
                    ).to_event(1)
                )

        adam = optim.Adam({"lr": lr, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss)

        alpha = 0.99
        for k in range(n_steps):
            svi.step()
            if debug:
                alpha_error = param_mse("alpha_q", self.alpha0)
                beta_error = param_mse("beta_q", self.beta0)
                with torch.no_grad():
                    if k == 0:
                        avg_loglikelihood, avg_penalty = loss._differentiable_loss_parts(model, guide, (), {})
                        avg_loglikelihood = torch_item(avg_loglikelihood)
                        avg_penalty = torch_item(avg_penalty)
                    loglikelihood, penalty = loss._differentiable_loss_parts(model, guide, (), {})
                    avg_loglikelihood = alpha * avg_loglikelihood + (1-alpha) * torch_item(loglikelihood)
                    avg_penalty = alpha * avg_penalty + (1-alpha) * torch_item(penalty)
                if k % 100 == 0:
                    print(alpha_error, beta_error)
                    print(avg_loglikelihood, avg_penalty)
                    print()

        assert_equal(pyro.param("alpha_q"), self.alpha0, prec=0.2, msg='{} vs {}'.format(
            pyro.param("alpha_q").detach().cpu().numpy(), self.alpha0.detach().cpu().numpy()))
        assert_equal(pyro.param("beta_q"), self.beta0, prec=0.15, msg='{} vs {}'.format(
            pyro.param("beta_q").detach().cpu().numpy(), self.beta0.detach().cpu().numpy()))
Example #8
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(guide_trace.log_prob_sum())
            elbo += elbo_particle / float(self.num_particles)

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
Example #9
    def _differentiable_loss_particle(self, model_trace, guide_trace):
        elbo_particle = 0

        for name, model_site in model_trace.nodes.items():
            if model_site["type"] == "sample":
                if model_site["is_observed"]:
                    elbo_particle = elbo_particle + model_site["log_prob_sum"]
                else:
                    guide_site = guide_trace.nodes[name]
                    if is_validation_enabled():
                        check_fully_reparametrized(guide_site)

                    # use kl divergence if available, else fall back on sampling
                    try:
                        kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
                        kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"], mask=guide_site["mask"])
                        assert kl_qp.shape == guide_site["fn"].batch_shape
                        elbo_particle = elbo_particle - kl_qp.sum()
                    except NotImplementedError:
                        entropy_term = guide_site["score_parts"].entropy_term
                        elbo_particle = elbo_particle + model_site["log_prob_sum"] - entropy_term.sum()

        # handle auxiliary sites in the guide
        for name, guide_site in guide_trace.nodes.items():
            if guide_site["type"] == "sample" and name not in model_trace.nodes:
                assert guide_site["infer"].get("is_auxiliary")
                if is_validation_enabled():
                    check_fully_reparametrized(guide_site)
                entropy_term = guide_site["score_parts"].entropy_term
                elbo_particle = elbo_particle - entropy_term.sum()

        loss = -(elbo_particle.detach() if torch._C._get_tracing_state() else torch_item(elbo_particle))
        surrogate_loss = -elbo_particle
        return loss, surrogate_loss
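
When the model and guide distributions for a latent site admit a closed-form KL, this variant uses the analytic term instead of a sampled `log q - log p`, removing that source of Monte Carlo noise. A minimal sketch of the analytic branch with hypothetical Normal sites:

import torch
from torch.distributions import Normal, kl_divergence

q = Normal(torch.tensor(0.3), torch.tensor(0.9))  # hypothetical guide site
p = Normal(torch.tensor(0.0), torch.tensor(1.0))  # hypothetical model prior
kl_qp = kl_divergence(q, p)                       # closed form; shape equals the batch shape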
Example #10
    def loss_and_grads(self, model, guide, *args, **kwargs):
        loss = self.differentiable_loss(model, guide, *args, **kwargs)
        loss.backward()
        loss = torch_item(loss)

        warn_if_nan(loss, "loss")
        return loss
Example #11
    def step(self, *args, **kwargs):
        """
        :returns: estimate of the loss
        :rtype: float

        Take a gradient step on the loss function (and any auxiliary loss functions
        generated under the hood by `loss_and_grads`).
        Any args or kwargs are passed to the model and guide
        """
        # get loss and compute gradients
        with poutine.trace(param_only=True) as param_capture:
            loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

        params = set(site["value"].unconstrained()
                     for site in param_capture.trace.nodes.values())

        # actually perform gradient steps
        # torch.optim objects get instantiated for any params that haven't been seen yet
        self.optim(params)

        # zero gradients
        pyro.infer.util.zero_grads(params)

        if isinstance(loss, tuple):
            # Support losses that return a tuple, e.g. ReweightedWakeSleep.
            return type(loss)(map(torch_item, loss))
        else:
            return torch_item(loss)
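
A hedged usage sketch of this `step()` loop with a hypothetical model and guide (not taken from the surrounding project):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(loc_q, 0.1))

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
data = torch.randn(10) + 3.0
for _ in range(100):
    loss = svi.step(data)  # each call returns a float via torch_item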
Example #12
def train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num_steps=1000):
    """
    A helper to optimize parameters for a GP module.

    :param ~pyro.contrib.gp.models.GPModel gpmodule: A GP module.
    :param ~torch.optim.Optimizer optimizer: A PyTorch optimizer instance.
        By default, we use Adam with ``lr=0.01``.
    :param callable loss_fn: A loss function which takes inputs are
        ``gpmodule.model``, ``gpmodule.guide``, and returns ELBO loss.
        By default, ``loss_fn=TraceMeanField_ELBO().differentiable_loss``.
    :param bool retain_graph: An optional flag of ``torch.autograd.backward``.
    :param int num_steps: Number of steps to run SVI.
    :returns: a list of losses during the training procedure
    :rtype: list
    """
    optimizer = (
        torch.optim.Adam(gpmodule.parameters(), lr=0.01)
        if optimizer is None
        else optimizer
    )
    # TODO: add support for JIT loss
    loss_fn = TraceMeanField_ELBO().differentiable_loss if loss_fn is None else loss_fn

    def closure():
        optimizer.zero_grad()
        loss = loss_fn(gpmodule.model, gpmodule.guide)
        torch_backward(loss, retain_graph)
        return loss

    losses = []
    for i in range(num_steps):
        loss = optimizer.step(closure)
        losses.append(torch_item(loss))
    return losses
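
A hedged usage sketch of the `train()` helper above on synthetic 1-D regression data (the data and kernel choices here are illustrative only):

import torch
import pyro.contrib.gp as gp

X = torch.linspace(0., 5., 20)
y = torch.sin(X) + 0.1 * torch.randn(20)
kernel = gp.kernels.RBF(input_dim=1)
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(0.1))
losses = train(gpr, num_steps=200)   # one float per optimizer step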
Example #13
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(guide_trace.log_prob_sum())
            elbo += elbo_particle / self.num_particles

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Example #14
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
        """
        elbo = 0.0
        for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(guide_trace.log_prob_sum())
            elbo += weight * elbo_particle

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
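
Here `_get_traces` yields a weight for each trace, and the ELBO is the weighted sum of per-trace estimates. The arithmetic, on hypothetical weights that sum to one:

weights = [0.25, 0.25, 0.5]                                   # hypothetical trace weights
elbo_particles = [-1.0, -2.0, -1.5]                           # hypothetical per-trace ELBO values
elbo = sum(w * e for w, e in zip(weights, elbo_particles))    # == -1.5
loss = -elbo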
Example #15
File: svi.py  Project: lewisKit/pyro
    def evaluate_loss(self, *args, **kwargs):
        """
        :returns: estimate of the loss
        :rtype: float

        Evaluate the loss function. Any args or kwargs are passed to the model and guide.
        """
        with torch.no_grad():
            return torch_item(self.loss(self.model, self.guide, *args, **kwargs))
Example #17
    def step(self, *args, **kwargs):
        with poutine.trace(param_only=True) as param_capture:
            loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
        params = []
        for site in param_capture.trace.nodes.values():
            param = site["value"].unconstrained()
            if site.get('free') is not None:
                param.grad = site['free'] * param.grad
            params.append(param)
        self.optim(params)
        pyro.infer.util.zero_grads(params)
        return torch_item(loss)
Example #18
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """
        Computes the surrogate loss that can be differentiated with autograd
        to produce gradient estimates for the model and guide parameters
        """
        loss = 0.
        surrogate_loss = 0.
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
            surrogate_loss += surrogate_loss_particle / self.num_particles
            loss += loss_particle / self.num_particles
        warn_if_nan(surrogate_loss, "loss")
        return loss + (surrogate_loss - torch_item(surrogate_loss))
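
The return value splices the two quantities together: its numeric value equals `loss`, while its gradient comes entirely from `surrogate_loss`. A small sketch of that trick:

import torch

theta = torch.tensor(2.0, requires_grad=True)
surrogate = 3.0 * theta                                 # value 6.0, gradient 3.0 w.r.t. theta
loss_value = 1.5                                        # detached loss estimate (a plain float)
spliced = loss_value + (surrogate - surrogate.item())   # value 1.5, gradient of surrogate
spliced.backward()
assert spliced.item() == 1.5 and theta.grad.item() == 3.0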
Example #19
    def evaluate_loss(self, *args, **kwargs):
        """
        :returns: estimate of the loss
        :rtype: float

        Evaluate the loss function. Any args or kwargs are passed to the model and guide.
        """
        with torch.no_grad():
            loss = self.loss(self.model, self.guide, *args, **kwargs)
            if isinstance(loss, tuple):
                # Support losses that return a tuple, e.g. ReweightedWakeSleep.
                return type(loss)(map(torch_item, loss))
            else:
                return torch_item(loss)
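
Losses such as ReweightedWakeSleep return a tuple of losses, so `torch_item` is mapped over the elements. A tiny sketch with hypothetical tensors, using `.item()` in place of `torch_item`:

import torch

loss = (torch.tensor(1.25), torch.tensor(0.75))         # hypothetical tuple of loss tensors
as_floats = type(loss)(map(lambda t: t.item(), loss))   # -> (1.25, 0.75)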
Example #20
    def custom_step(self, *batch):

        with poutine.trace(param_only=True) as param_capture:
            loss = self.loss.differentiable_loss(self.model, self.guide,
                                                 *batch) + self.beta_l1_loss()

            params = set(site["value"].unconstrained()
                         for site in param_capture.trace.nodes.values())

        loss.backward()
        self.optimizer(params)
        pyro.infer.util.zero_grads(params)

        return torch_item(loss)
Example #21
    def loss_and_grads(self, grads, batch, *args, **kwargs):
        """
        :returns: an estimate of the loss (expectation over p(x, y) of
            -log q(x, y) ) - where p is the model and q is the guide
        :rtype: float

        If a batch is provided, the loss is estimated using these traces
        Otherwise, a fresh batch is generated from the model.

        If grads is True, will also call `backward` on loss.

        `args` and `kwargs` are passed to the model and guide.
        """
        if batch is None:
            batch = (
                self._sample_from_joint(*args, **kwargs)
                for _ in range(self.training_batch_size)
            )
            batch_size = self.training_batch_size
        else:
            batch_size = len(batch)

        loss = 0
        for model_trace in batch:
            with poutine.trace(param_only=True) as particle_param_capture:
                guide_trace = self._get_matched_trace(model_trace, *args, **kwargs)
            particle_loss = self._differentiable_loss_particle(guide_trace)
            particle_loss /= batch_size

            if grads:
                guide_params = set(
                    site["value"].unconstrained()
                    for site in particle_param_capture.trace.nodes.values()
                )
                guide_grads = torch.autograd.grad(
                    particle_loss, guide_params, allow_unused=True
                )
                for guide_grad, guide_param in zip(guide_grads, guide_params):
                    if guide_grad is None:
                        continue
                    guide_param.grad = (
                        guide_grad
                        if guide_param.grad is None
                        else guide_param.grad + guide_grad
                    )

            loss += torch_item(particle_loss)

        warn_if_nan(loss, "loss")
        return loss
Example #22
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the MMD-VAE-type loss [1]
        :rtype: float

        Computes the MMD-VAE-type loss with an estimator that uses num_particles many samples/particles.

        References

        [1] `A Tutorial on Information Maximizing Variational Autoencoders (InfoVAE)`
            Shengjia Zhao
            https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
        """
        loss = self.differentiable_loss(model, guide, *args, **kwargs)
        return torch_item(loss)
Example #23
    def evaluate(self, raw_expr, encoded_expr, read_depth):

        batch_logp = []
        for batch in self.epoch_batch(raw_expr,
                                      encoded_expr,
                                      read_depth,
                                      batch_size=512,
                                      bar=False):

            with torch.no_grad():
                log_prob = torch_item(
                    self.loss.loss(self.model, self.guide, *batch))

            batch_logp.append(log_prob)

        return np.array(batch_logp).sum()  #loss is negative log-likelihood
Example #24
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. num_particles many samples are used to form the estimators.
        If baselines are present, a baseline loss is also constructed and differentiated.
        """
        elbo, surrogate_loss = self._loss_and_surrogate_loss(model, guide, args, kwargs)

        torch_backward(surrogate_loss, retain_graph=self.retain_graph)

        elbo = torch_item(elbo)
        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
Example #25
    def _differentiable_loss_particle(self, model_trace, guide_trace):
        # Construct -ELBO part.
        blocked_names = [name for name, site in guide_trace.nodes.items()
                         if site["type"] == "sample" and site["is_observed"]]
        blocked_guide_trace = guide_trace.copy()
        for name in blocked_names:
            del blocked_guide_trace.nodes[name]
        loss, surrogate_loss = super()._differentiable_loss_particle(
            model_trace, blocked_guide_trace)

        # Add log q terms.
        for name in blocked_names:
            log_q = guide_trace.nodes[name]["log_prob_sum"]
            loss = loss - 100 * torch_item(log_q)
            surrogate_loss = surrogate_loss - 100 * log_q

        return loss, surrogate_loss
Example #26
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the MMD-VAE-type loss [1]
        :rtype: float

        Computes the MMD-VAE-type loss and performs backward on it.
        Leads to valid gradient estimates as long as latent variables
        in both the guide and the model are reparameterizable.
        Num_particles many samples are used to form the estimators.

        References

        [1] `A Tutorial on Information Maximizing Variational Autoencoders (InfoVAE)`
            Shengjia Zhao
            https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
        """
        loss = self.differentiable_loss(model, guide, *args, **kwargs)
        loss.backward(retain_graph=self.retain_graph)
        return torch_item(loss)
Example #27
File: csis.py  Project: youngshingjun/pyro
    def step(self, *args, **kwargs):
        """
        :returns: estimate of the loss
        :rtype: float

        Take a gradient step on the loss function. Arguments are passed to the
        model and guide.
        """
        with poutine.trace(param_only=True) as param_capture:
            loss = self.loss_and_grads(True, None, *args, **kwargs)

        params = set(site["value"].unconstrained()
                     for site in param_capture.trace.nodes.values()
                     if site["value"].grad is not None)

        self.optim(params)

        pyro.infer.util.zero_grads(params)

        return torch_item(loss)
Example #28
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # have the trace compute all the individual (batch) log pdf terms
        # and score function terms (if present) so that they are available below
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace,
                                                     non_reparam_nodes)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(
                model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                guide_trace, non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))

        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return weight * loss
Example #29
File: holder.py  Project: henrishi/bm_model
    def step(self, *args, **kwargs):

        # compute loss
        loss = self.loss_fn(self.model, self.guide, *args,
                            **kwargs) + self.compute_l2_term(*args, **kwargs)
        loss.backward()

        tr = poutine.trace(self.guide).get_trace(*args, **kwargs)
        params = [
            site['value'].unconstrained() for name, site in tr.nodes.items()
            if site['type'] == 'param'
        ]

        # Copied from Pyro SVI source code
        # actually perform gradient steps
        # torch.optim objects get instantiated for any params that haven't been seen yet
        self.optimizer(params)

        # zero gradients
        zero_grads(params)

        return torch_item(loss)
Example #30
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # have the trace compute all the individual (batch) log pdf terms
        # and score function terms (if present) so that they are available below
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace,
                                                                           non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))

        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return weight * loss
Example #31
File: svi.py  Project: lewisKit/pyro
    def step(self, *args, **kwargs):
        """
        :returns: estimate of the loss
        :rtype: float

        Take a gradient step on the loss function (and any auxiliary loss functions
        generated under the hood by `loss_and_grads`).
        Any args or kwargs are passed to the model and guide
        """
        # get loss and compute gradients
        with poutine.trace(param_only=True) as param_capture:
            loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

        params = set(site["value"].unconstrained()
                     for site in param_capture.trace.nodes.values())

        # actually perform gradient steps
        # torch.optim objects get instantiated for any params that haven't been seen yet
        self.optim(params)

        # zero gradients
        pyro.infer.util.zero_grads(params)

        return torch_item(loss)
Example #32
File: __init__.py  Project: pyro-ppl/pyro
    def loss(self, model, guide, *args, **kwargs):
        return torch_item(
            self.differentiable_loss(model, guide, *args, **kwargs))
Example #33
    def loss_and_grads(self, model, guide, *args, **kwargs):
        # TODO: add argument lambda --> assigns weights to losses
        # TODO: Normalize loss elbo value if not done
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. num_particles many samples are used to form the estimators.
        """

        elbo = 0.0
        dyn_loss = 0.0
        dim_loss = 0.0

        # grab a trace from the generator
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0
            log_r = None

            ys = []
            # compute elbo and surrogate elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
                    surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site["score_parts"]

                    elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

                    if not is_identically_zero(entropy_term):
                        surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

                    if not is_identically_zero(score_function_term):
                        if log_r is None:
                            log_r = _compute_log_r(model_trace, guide_trace)
                        site = log_r.sum_to(site["cond_indep_stack"])
                        surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum()

                    if site["name"].startswith("y_"):
                        # TODO: check order of y
                        ys.append(site["value"])
            man = torch.stack(ys, dim=1)
            mean_man = man.mean(dim=1, keepdim=True)
            man = man - mean_man
            dyn_loss += self._get_logdet_loss(man, delta=self.delta)  # TODO: Normalize
            dim_loss += self._get_traceK_loss(man)
            elbo += elbo_particle / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False):
                surrogate_loss_particle = (-surrogate_elbo_particle / self.num_particles
                                           + self.lam * dyn_loss
                                           + self.gam * dim_loss)
                surrogate_loss_particle.backward()

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss, dyn_loss.item(), dim_loss.item(), man
Example #34
    def loss_and_grads(self, model, guide, *args, **kwargs):
        loss = self._loss(model, guide, args, kwargs)
        torch_backward(loss, retain_graph=self.retain_graph)
        loss = torch_item(loss)
        warn_if_nan(loss, "loss")
        return loss