Example #1
    def templates_guide_mvn(self):
        """ Multivariate normal guide for template parameters
        """

        loc = _deep_getattr(self, "mvn.loc")
        scale_tril = _deep_getattr(self, "mvn.scale_tril")

        dt = dist.MultivariateNormal(loc, scale_tril=scale_tril)
        states = pyro.sample("states_" + self.name_prefix,
                             dt,
                             infer={"is_auxiliary": True})

        result = {}

        for i_poiss in torch.arange(self.n_poiss):
            transform = biject_to(self.poiss_priors[i_poiss].support)
            value = transform(states[i_poiss])
            log_density = transform.inv.log_abs_det_jacobian(
                value, states[i_poiss])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim() +
                self.poiss_priors[i_poiss].event_dim)

            result[self.poiss_labels[i_poiss]] = pyro.sample(
                self.poiss_labels[i_poiss],
                dist.Delta(value,
                           log_density=log_density,
                           event_dim=self.poiss_priors[i_poiss].event_dim))

        i_param = self.n_poiss

        for i_ps in torch.arange(self.n_ps):
            for i_ps_param in torch.arange(self.n_ps_params):

                transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)

                value = transform(states[i_param])

                log_density = transform.inv.log_abs_det_jacobian(
                    value, states[i_param])
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() +
                    self.ps_priors[i_ps][i_ps_param].event_dim)

                result[self.ps_param_labels[i_ps_param] + "_" +
                       self.ps_labels[i_ps]] = pyro.sample(
                           self.ps_param_labels[i_ps_param] + "_" +
                           self.ps_labels[i_ps],
                           dist.Delta(value,
                                      log_density=log_density,
                                      event_dim=self.ps_priors[i_ps]
                                      [i_ps_param].event_dim))
                i_param += 1

        return result
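
The pattern in this example recurs throughout the listing: sample an auxiliary value in unconstrained space, map it into the prior's support with `biject_to`, correct the density by the inverse transform's log-Jacobian summed over event dimensions, and register the result as a `Delta` site. A minimal self-contained sketch of that pattern, with a hypothetical `LogNormal` prior standing in for the template priors:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import biject_to
from pyro.distributions.util import sum_rightmost

def delta_guide():
    # Hypothetical prior whose support (positive reals) must be respected.
    prior = dist.LogNormal(torch.zeros(3), torch.ones(3)).to_event(1)

    # Sample an auxiliary value in unconstrained space.
    loc = pyro.param("loc", torch.zeros(3))
    scale = pyro.param("scale", torch.ones(3),
                       constraint=dist.constraints.positive)
    unconstrained = pyro.sample("aux", dist.Normal(loc, scale).to_event(1),
                                infer={"is_auxiliary": True})

    # Map into the prior's support and correct the density by the
    # log-Jacobian of the inverse transform, summed over event dims.
    transform = biject_to(prior.support)
    value = transform(unconstrained)
    log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
    log_density = sum_rightmost(
        log_density, log_density.dim() - value.dim() + prior.event_dim)

    # Register the constrained value as a Delta so ELBO bookkeeping stays correct.
    return pyro.sample("x", dist.Delta(value, log_density=log_density,
                                       event_dim=prior.event_dim))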
Example #2
    def score_parts(self, value):
        # Broadcast the result shape over the batch shape and any sample dims of `value`.
        shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
        log_prob, score_function, entropy_term = self.base_dist.score_parts(value)
        # Sum each part over the reinterpreted batch dims; scalar parts (e.g. the
        # zero score function of reparameterized distributions) pass through unchanged.
        log_prob = sum_rightmost(log_prob, self.reinterpreted_batch_ndims).expand(shape)
        if not isinstance(score_function, numbers.Number):
            score_function = sum_rightmost(score_function, self.reinterpreted_batch_ndims).expand(shape)
        if not isinstance(entropy_term, numbers.Number):
            entropy_term = sum_rightmost(entropy_term, self.reinterpreted_batch_ndims).expand(shape)
        return ScoreParts(log_prob, score_function, entropy_term)
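
`score_parts` supplies the pieces an ELBO gradient estimator needs: the log-probability, a score-function (REINFORCE) term, and an entropy term. A quick sketch of calling it on a reparameterized distribution, where the score-function part collapses to the scalar zero (hence the `numbers.Number` checks above):

import torch
import pyro.distributions as dist

d = dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1)
x = d.sample()
parts = d.score_parts(x)  # ScoreParts(log_prob, score_function, entropy_term)
assert torch.allclose(parts.log_prob, d.log_prob(x))
assert parts.score_function == 0  # reparameterized, so no REINFORCE term is needed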
Example #3
    def guide(self):

        a_locs = pyro.param("a_locs", torch.full((self.n_params, ), 0.0))
        a_scales_tril = pyro.param(
            "a_scales",
            lambda: 0.1 * eye_like(a_locs, self.n_params),
            constraint=constraints.lower_cholesky)

        dt = dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril)
        states = pyro.sample("states", dt, infer={"is_auxiliary": True})

        result = {}

        for i_poiss in torch.arange(self.n_poiss):
            transform = biject_to(self.poiss_priors[i_poiss].support)
            value = transform(states[i_poiss])
            log_density = transform.inv.log_abs_det_jacobian(
                value, states[i_poiss])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim() +
                self.poiss_priors[i_poiss].event_dim)

            result[self.labels_poiss[i_poiss]] = pyro.sample(
                self.labels_poiss[i_poiss],
                dist.Delta(value,
                           log_density=log_density,
                           event_dim=self.poiss_priors[i_poiss].event_dim))

        i_param = self.n_poiss

        for i_ps in torch.arange(self.n_ps):
            for i_ps_param in torch.arange(self.n_ps_params):

                transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)

                value = transform(states[i_param])

                log_density = transform.inv.log_abs_det_jacobian(
                    value, states[i_param])
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() +
                    self.ps_priors[i_ps][i_ps_param].event_dim)

                result[self.labels_ps_params[i_ps_param] + "_" +
                       self.labels_ps[i_ps]] = pyro.sample(
                           self.labels_ps_params[i_ps_param] + "_" +
                           self.labels_ps[i_ps],
                           dist.Delta(value,
                                      log_density=log_density,
                                      event_dim=self.ps_priors[i_ps]
                                      [i_ps_param].event_dim))
                i_param += 1

        return result
Example #4
    def guide_horseshoe_plus(self):
        npar = self.npars  # number of parameters
        nsub = self.runs   # number of subjects
        trns = biject_to(constraints.positive)

        m_hyp = param('m_hyp', zeros(2 * npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2 * npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[:npar]
        unc_sigma = hyp[npar:]

        c_sigma = trns(unc_sigma)

        ld_sigma = trns.inv.log_abs_det_jacobian(c_sigma, unc_sigma)
        ld_sigma = sum_rightmost(ld_sigma, ld_sigma.dim() - c_sigma.dim() + 1)

        mu_g = sample("mu_g", dist.Delta(unc_mu, event_dim=1))
        sigma_g = sample("sigma_g", dist.Delta(c_sigma, log_density=ld_sigma, event_dim=1))

        m_tmp = param('m_tmp', zeros(nsub, 2 * npar))
        st_tmp = param('s_tmp',
                       torch.eye(2 * npar).repeat(nsub, 1, 1),
                       constraint=constraints.lower_cholesky)

        with plate('subjects', nsub):
            tmp = sample('tmp',
                         dist.MultivariateNormal(m_tmp, scale_tril=st_tmp),
                         infer={'is_auxiliary': True})

            unc_locs = tmp[..., :npar]
            unc_scale = tmp[..., npar:]

            c_scale = trns(unc_scale)

            ld_scale = trns.inv.log_abs_det_jacobian(c_scale, unc_scale)
            ld_scale = sum_rightmost(ld_scale, ld_scale.dim() - c_scale.dim() + 1)

            x = sample("x", dist.Delta(unc_locs, event_dim=1))
            sigma_x = sample("sigma_x", dist.Delta(c_scale, log_density=ld_scale, event_dim=1))

        return {'mu_g': mu_g, 'sigma_g': sigma_g, 'sigma_x': sigma_x, 'x': x}
Example #5
    def guide(self):
        """Approximate posterior for the horseshoe prior. We assume posterior in the form
        of the multivariate normal distriburtion for the global mean and standard deviation
        and multivariate normal distribution for the parameters of each subject independently.
        """
        nsub = self.runs  # number of subjects
        npar = self.npar  # number of parameters
        trns = biject_to(constraints.positive)

        m_hyp = param('m_hyp', zeros(2 * npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2 * npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[..., :npar]
        unc_tau = hyp[..., npar:]

        c_tau = trns(unc_tau)

        ld_tau = trns.inv.log_abs_det_jacobian(c_tau, unc_tau)
        ld_tau = sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)

        sample("mu", dist.Delta(unc_mu, event_dim=1))
        sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))

        m_locs = param('m_locs', zeros(nsub, npar))
        st_locs = param('scale_tril_locs',
                        torch.eye(npar).repeat(nsub, 1, 1),
                        constraint=constraints.lower_cholesky)

        with plate('runs', nsub):
            sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
Example #6
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    kl = kl_divergence(p.base_dist, q.base_dist)
    if p.reinterpreted_batch_ndims:
        kl = sum_rightmost(kl, p.reinterpreted_batch_ndims)
    return kl
Example #7
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        plates = self._create_plates()

        # unpack continuous latent samples
        result = {}
        for site, unconstrained_value in self._unpack_latent(latent):
            name = site["name"]
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained_value)
            log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
            log_density = sum_rightmost(log_density, log_density.dim() - value.dim() + site["fn"].event_dim)
            delta_dist = dist.Delta(value, log_density=log_density, event_dim=site["fn"].event_dim)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[name]:
                    stack.enter_context(plates[frame.name])
                result[name] = pyro.sample(name, delta_dist)

        return result
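
For context, a guide like this is usually consumed by SVI rather than called directly. A minimal sketch with an illustrative model:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, Adam({"lr": 1e-2}), Trace_ELBO())
data = torch.randn(100) + 3.0
for _ in range(100):
    svi.step(data)  # guide.__call__ runs inside, returning the dict above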
Example #8
    def sample(self, guide_name, fn, infer=None):
        """
        Wrapper around ``pyro.sample()`` to create a single auxiliary sample
        site and then unpack to multiple sample sites for model replay.

        :param str guide_name: The name of the auxiliary guide site.
        :param callable fn: A distribution with shape ``self.event_shape``.
        :param dict infer: Optional inference configuration dict.
        :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the
            single concatenated blob and ``model_zs`` is a dict mapping
            site name to constrained model sample.
        :rtype: tuple
        """
        # Sample a packed tensor.
        if fn.event_shape != self.event_shape:
            raise ValueError(
                "Invalid fn.event_shape for group: expected {}, actual {}".
                format(tuple(self.event_shape), tuple(fn.event_shape)))
        if infer is None:
            infer = {}
        infer["is_auxiliary"] = True
        guide_z = pyro.sample(guide_name, fn, infer=infer)
        common_batch_shape = guide_z.shape[:-1]

        model_zs = {}
        pos = 0
        for site in self.prototype_sites:
            name = site["name"]
            fn = site["fn"]

            # Extract slice from packed sample.
            size = self._site_sizes[name]
            batch_shape = broadcast_shape(common_batch_shape,
                                          self._site_batch_shapes[name])
            unconstrained_z = guide_z[..., pos:pos + size]
            unconstrained_z = unconstrained_z.reshape(batch_shape +
                                                      fn.event_shape)
            pos += size

            # Transform to constrained space.
            transform = biject_to(fn.support)
            z = transform(unconstrained_z)
            log_density = transform.inv.log_abs_det_jacobian(
                z, unconstrained_z)
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - z.dim() + fn.event_dim)
            delta_dist = dist.Delta(z,
                                    log_density=log_density,
                                    event_dim=fn.event_dim)

            # Replay model sample statement.
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    plate = self.guide.plate(frame.name)
                    if plate not in runtime._PYRO_STACK:
                        stack.enter_context(plate)
                model_zs[name] = pyro.sample(name, delta_dist)

        return guide_z, model_zs
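
This method belongs to `Group` in `pyro.contrib.easyguide`: one auxiliary distribution covers a whole group of model sites. A sketch of the intended call pattern, assuming a `model` defined elsewhere and a regex matching all of its latent sites:

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.easyguide import easy_guide

@easy_guide(model)  # `model` is assumed to be defined elsewhere
def guide(self, data):
    group = self.group(match=".*")  # pack all latent sites into one group
    loc = pyro.param("z_loc", torch.zeros(group.event_shape))
    scale = pyro.param("z_scale", torch.ones(group.event_shape),
                       constraint=dist.constraints.positive)
    # One auxiliary sample, unpacked to every matched model site.
    guide_z, model_zs = group.sample("z_aux", dist.Normal(loc, scale).to_event(1))
    return model_zs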
Example #9
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        iaranges = self._create_iaranges()

        # unpack continuous latent samples
        result = {}
        for site, unconstrained_value in self._unpack_latent(latent):
            name = site["name"]
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained_value)
            log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
            log_density = sum_rightmost(log_density, log_density.dim() - value.dim() + site["fn"].event_dim)
            delta_dist = dist.Delta(value, log_density=log_density, event_dim=site["fn"].event_dim)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[name]:
                    stack.enter_context(iaranges[frame.name])
                result[name] = pyro.sample(name, delta_dist)

        return result
Example #10
def test_kl_independent_normal(batch_shape, event_shape):
    shape = batch_shape + event_shape
    p = dist.Normal(torch.randn(shape), torch.randn(shape).exp())
    q = dist.Normal(torch.randn(shape), torch.randn(shape).exp())
    actual = kl_divergence(dist.Independent(p, len(event_shape)),
                           dist.Independent(q, len(event_shape)))
    expected = sum_rightmost(kl_divergence(p, q), len(event_shape))
    assert_close(actual, expected)
Example #11
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    extra_event_dim = len(p.base_dist.batch_shape) - len(p.batch_shape)
    base_kl_divergence = kl_divergence(p.base_dist, q.base_dist)
    return sum_rightmost(base_kl_divergence, extra_event_dim)
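
Since the shared bijections cancel, the KL between two `TransformedDistribution`s reduces to the KL between their base distributions (summed over any extra event dims). A quick check, reusing one transform instance so the equality test on `p.transforms` passes:

import torch
from torch.distributions import ExpTransform, Normal, TransformedDistribution, kl_divergence

t = ExpTransform()
p = TransformedDistribution(Normal(0.0, 1.0), [t])  # LogNormal(0, 1)
q = TransformedDistribution(Normal(1.0, 2.0), [t])  # LogNormal(1, 2)
assert torch.allclose(kl_divergence(p, q),
                      kl_divergence(Normal(0.0, 1.0), Normal(1.0, 2.0)))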
Example #12
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        encoded_hidden = self.encode(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            transform = biject_to(site["fn"].support)

            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    if frame.vectorized:
                        stack.enter_context(plates[frame.name])

                site_loc, site_scale = self._get_loc_and_scale(name, encoded_hidden)
                unconstrained_latent = pyro.sample(
                    name + "_unconstrained",
                    dist.Normal(
                        site_loc,
                        site_scale,
                    ).to_event(self._event_dims[name]),
                    infer={"is_auxiliary": True},
                )

                value = transform(unconstrained_latent)
                if pyro.poutine.get_mask() is False:
                    log_density = 0.0
                else:
                    log_density = transform.inv.log_abs_det_jacobian(
                        value,
                        unconstrained_latent,
                    )
                    log_density = sum_rightmost(
                        log_density,
                        log_density.dim() - value.dim() + site["fn"].event_dim,
                    )
                delta_dist = dist.Delta(
                    value,
                    log_density=log_density,
                    event_dim=site["fn"].event_dim,
                )

                result[name] = pyro.sample(name, delta_dist)

        return result
Example #13
    def conjugate_update(self, other):
        """
        EXPERIMENTAL.
        """
        n = self.reinterpreted_batch_ndims
        updated, log_normalizer = self.base_dist.conjugate_update(other.to_event(-n))
        updated = updated.to_event(n)
        log_normalizer = sum_rightmost(log_normalizer, n)
        return updated, log_normalizer
Example #14
def _kl_independent_independent(p, q):
    shared_ndims = min(p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)
    p_ndims = p.reinterpreted_batch_ndims - shared_ndims
    q_ndims = q.reinterpreted_batch_ndims - shared_ndims
    p = Independent(p.base_dist, p_ndims) if p_ndims else p.base_dist
    q = Independent(q.base_dist, q_ndims) if q_ndims else q.base_dist
    kl = kl_divergence(p, q)
    if shared_ndims:
        kl = sum_rightmost(kl, shared_ndims)
    return kl
Example #15
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]
        if name not in self.guide.prototype_trace.nodes:
            return {"fn": fn, "value": value, "is_observed": is_observed}
        if is_observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(name)},...), "
                "NeuTraReparam does not support observe statements.")

        log_density = 0.0
        compute_density = poutine.get_mask() is not False
        if name not in self.x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            try:
                self.transform = self.guide.get_transform()
            except (NotImplementedError, TypeError) as e:
                raise ValueError(
                    "NeuTraReparam only supports guides that implement "
                    "`get_transform` method that does not depend on the "
                    "model's `*args, **kwargs`") from e

            with ExitStack() as stack:
                for plate in self.guide.plates.values():
                    stack.enter_context(
                        block_plate(dim=plate.dim, strict=False))
                z_unconstrained = pyro.sample(
                    f"{name}_shared_latent",
                    self.guide.get_base_dist().mask(False))

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            if compute_density:
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained)
            self.x_unconstrained = {
                site["name"]: (site, unconstrained_value)
                for site, unconstrained_value in self.guide._unpack_latent(
                    x_unconstrained)
            }

        # Extract a single site's value from the shared latent.
        site, unconstrained_value = self.x_unconstrained.pop(name)
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            logdet = sum_rightmost(logdet,
                                   logdet.dim() - value.dim() + fn.event_dim)
            log_density = log_density + fn.log_prob(value) + logdet
        new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}
Example #16
def test_sum_rightmost():
    x = torch.ones(2, 3, 4)
    assert sum_rightmost(x, 0).shape == (2, 3, 4)
    assert sum_rightmost(x, 1).shape == (2, 3)
    assert sum_rightmost(x, 2).shape == (2, )
    assert sum_rightmost(x, -1).shape == (2, )
    assert sum_rightmost(x, -2).shape == (2, 3)
    assert sum_rightmost(x, INF).shape == ()
Example #17
def test_sum_rightmost():
    x = torch.ones(2, 3, 4)
    assert sum_rightmost(x, 0).shape == (2, 3, 4)
    assert sum_rightmost(x, 1).shape == (2, 3)
    assert sum_rightmost(x, 2).shape == (2,)
    assert sum_rightmost(x, -1).shape == (2,)
    assert sum_rightmost(x, -2).shape == (2, 3)
    assert sum_rightmost(x, float('inf')).shape == ()
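
Together these tests pin down `sum_rightmost`'s contract: sum out the rightmost `dim` dimensions; a negative `dim` instead keeps that many leftmost dimensions; infinity sums everything. A plain-PyTorch sketch with equivalent behavior on these cases (`sum_rightmost_sketch` is a hypothetical name, not Pyro's implementation):

import math
import torch

def sum_rightmost_sketch(value: torch.Tensor, dim) -> torch.Tensor:
    # Sum out the `dim` rightmost dimensions of `value`; negative `dim`
    # keeps `-dim` leftmost dimensions; infinity sums everything.
    if isinstance(dim, float) and math.isinf(dim):
        return value.sum()
    if dim < 0:
        dim = value.dim() + dim
    if dim == 0:
        return value
    return value.reshape(value.shape[:value.dim() - dim] + (-1,)).sum(-1)

x = torch.ones(2, 3, 4)
assert sum_rightmost_sketch(x, 1).shape == (2, 3)
assert sum_rightmost_sketch(x, -1).shape == (2,)
assert sum_rightmost_sketch(x, float('inf')).shape == ()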
Example #18
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        plates = self._create_plates(*args, **kwargs)

        # unpack continuous latent samples
        result = {}
        for site, unconstrained_value in self._unpack_latent(latent):
            name = site["name"]
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained_value)
            if poutine.get_mask() is False:
                log_density = 0.0
            else:
                log_density = transform.inv.log_abs_det_jacobian(
                    value,
                    unconstrained_value,
                )
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() + site["fn"].event_dim,
                )
            delta_dist = dist.Delta(
                value,
                log_density=log_density,
                event_dim=site["fn"].event_dim,
            )

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[name]:
                    stack.enter_context(plates[frame.name])
                result[name] = pyro.sample(name, delta_dist)

        return result
Example #19
    def __call__(self, name, fn, obs):
        if name not in self.guide.prototype_trace.nodes:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"
        log_density = 0.0
        compute_density = (poutine.get_mask() is not False)
        if not self.x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            try:
                self.transform = self.guide.get_transform()
            except (NotImplementedError, TypeError) as e:
                raise ValueError(
                    "NeuTraReparam only supports guides that implement "
                    "`get_transform` method that does not depend on the "
                    "model's `*args, **kwargs`") from e

            z_unconstrained = pyro.sample(
                "{}_shared_latent".format(name),
                self.guide.get_base_dist().mask(False))

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            if compute_density:
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained)
            self.x_unconstrained = list(
                reversed(list(self.guide._unpack_latent(x_unconstrained))))

        # Extract a single site's value from the shared latent.
        site, unconstrained_value = self.x_unconstrained.pop()
        assert name == site["name"], "model structure changed"
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            logdet = sum_rightmost(logdet,
                                   logdet.dim() - value.dim() + fn.event_dim)
            log_density = log_density + fn.log_prob(value) + logdet
        new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim)
        return new_fn, value
Example #20
    def __call__(self, name, fn, obs):
        if name not in self.guide.prototype_trace.nodes:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"
        log_density = 0.
        if not self.x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            # TODO(fehiepsi) Consider adding a method to extract transform from an Auto*Normal(posterior).
            posterior = self.guide.get_posterior()
            if not isinstance(posterior, dist.TransformedDistribution):
                raise ValueError(
                    "NeuTraReparam only supports guides whose posteriors are "
                    "TransformedDistributions but got a posterior of type {}".
                    format(type(posterior)))
            self.transform = dist.transforms.ComposeTransform(
                posterior.transforms)
            z_unconstrained = pyro.sample("{}_shared_latent".format(name),
                                          posterior.base_dist.mask(False))

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            log_density = self.transform.log_abs_det_jacobian(
                z_unconstrained, x_unconstrained)
            self.x_unconstrained = list(
                reversed(list(self.guide._unpack_latent(x_unconstrained))))

        # Extract a single site's value from the shared latent.
        site, unconstrained_value = self.x_unconstrained.pop()
        assert name == site["name"], "model structure changed"
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
        logdet = sum_rightmost(logdet,
                               logdet.dim() - value.dim() + fn.event_dim)
        log_density = log_density + fn.log_prob(value) + logdet
        new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim)
        return new_fn, value
Example #21
def evaluate_log_posterior_density(model, posterior_samples, baseball_dataset):
    """
    Evaluate the log probability density of observing the unseen data (season hits)
    given a model and posterior distribution over the parameters.
    """
    _, test, player_names = train_test_split(baseball_dataset)
    at_bats_season, hits_season = test[:, 0], test[:, 1]
    with ignore_experimental_warning():
        trace = predictive(model, posterior_samples, at_bats_season, hits_season,
                           parallel=True, return_trace=True)
    # Use the LogSumExp trick to evaluate $\log(1/S \sum_i p(\mathrm{new\_data} \mid \theta^{i}))$,
    # where $\theta^{i}$ are the $S$ parameter samples drawn from the model's posterior.
    trace.compute_log_prob()
    log_joint = 0.
    for name, site in trace.nodes.items():
        if site["type"] == "sample" and not site_is_subsample(site):
            # Use `sum_rightmost(x, -1)` to sum over all rightmost dimensions of `x`
            # except the first, which corresponds to the number of posterior samples.
            site_log_prob_sum = sum_rightmost(site['log_prob'], -1)
            log_joint += site_log_prob_sum
    posterior_pred_density = torch.logsumexp(log_joint, dim=0) - math.log(log_joint.shape[0])
    logging.info("\nLog posterior predictive density")
    logging.info("--------------------------------")
    logging.info("{:.4f}\n".format(posterior_pred_density))
Example #22
    def log_prob(self, x):
        v = self.v.expand(self.shape())
        # log(1) = 0 where x equals the Delta's value, log(0) = -inf elsewhere.
        log_prob = x.new_tensor(x == v).log()
        log_prob = sum_rightmost(log_prob, self.event_dim)
        return log_prob + self.log_density
Example #23
    def log_prob(self, x):
        v = self.v.expand(self.shape())
        log_prob = (x == v).type(x.dtype).log()
        log_prob = sum_rightmost(log_prob, self.event_dim)
        return log_prob + self.log_density
Example #24
    def templates_guide_iaf(self, indices, gp_sample=None):
        """ IAF guide for template parameters
        """

        # Number of context variables (GP summary statistics) on which to condition the IAF
        context_vars = torch.zeros(self.n_poiss + self.n_ps + 2)
        # context_vars = torch.zeros(2)

        td = None

        # IAF transformation, either with or without conditioning on the GP draw

        if self.guide_name == "IAF":

            td = dist.TransformedDistribution(self.base_dist, self.transform)

        elif self.guide_name == "ConditionalIAF":

            # Summary statistics of the GP draw: dot products of the Poissonian and
            # non-Poissonian templates with the GP, plus the GP mean and standard deviation
            context_vars[:self.n_poiss] = (
                self.poiss_temps[:, indices]
                @ gp_sample.exp().double()) / self.n_pix
            context_vars[self.n_poiss:self.n_poiss +
                         self.n_ps] = (self.ps_temps[:, indices]
                                       @ gp_sample.exp().double()) / self.n_pix
            context_vars[-2] = torch.mean(gp_sample.exp())
            context_vars[-1] = torch.var(gp_sample.exp()).sqrt()

            td = dist.ConditionalTransformedDistribution(
                self.base_dist, self.transform).condition(context=context_vars)

        states = pyro.sample("states_" + self.name_prefix,
                             td,
                             infer={"is_auxiliary": True})

        result = {}

        for i_poiss in torch.arange(self.n_poiss):
            transform = biject_to(self.poiss_priors[i_poiss].support)
            value = transform(states[i_poiss])
            log_density = transform.inv.log_abs_det_jacobian(
                value, states[i_poiss])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim() +
                self.poiss_priors[i_poiss].event_dim)

            result[self.poiss_labels[i_poiss]] = pyro.sample(
                self.poiss_labels[i_poiss],
                dist.Delta(value,
                           log_density=log_density,
                           event_dim=self.poiss_priors[i_poiss].event_dim))

        i_param = self.n_poiss

        for i_ps in torch.arange(self.n_ps):
            for i_ps_param in torch.arange(self.n_ps_params):

                transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)

                value = transform(states[i_param])

                log_density = transform.inv.log_abs_det_jacobian(
                    value, states[i_param])
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() +
                    self.ps_priors[i_ps][i_ps_param].event_dim)

                result[self.ps_param_labels[i_ps_param] + "_" +
                       self.ps_labels[i_ps]] = pyro.sample(
                           self.ps_param_labels[i_ps_param] + "_" +
                           self.ps_labels[i_ps],
                           dist.Delta(value,
                                      log_density=log_density,
                                      event_dim=self.ps_priors[i_ps]
                                      [i_ps_param].event_dim))
                i_param += 1

        return result
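
The `self.base_dist` and `self.transform` consumed above would be assembled along these lines; the dimensions and hyperparameters here are illustrative, and `affine_autoregressive` is Pyro's helper for building a trainable IAF layer (Pyro also provides conditional counterparts for the `ConditionalTransformedDistribution` branch):

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import transforms

n_params = 5  # illustrative stand-in for n_poiss + n_ps * n_ps_params
base_dist = dist.Normal(torch.zeros(n_params), torch.ones(n_params)).to_event(1)
iaf = transforms.affine_autoregressive(n_params, hidden_dims=[32, 32])

def states_guide():
    # Register the flow's parameters with Pyro so the optimizer updates them.
    pyro.module("iaf", iaf)
    td = dist.TransformedDistribution(base_dist, [iaf])
    return pyro.sample("states", td, infer={"is_auxiliary": True})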
Example #25
    def log_prob(self, value):
        shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
        return sum_rightmost(self.base_dist.log_prob(value), self.reinterpreted_batch_ndims).expand(shape)
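
As a quick sanity check of this reduction, using only standard Pyro calls: the log-probability of an `Independent` distribution (via `to_event`) equals the base log-probability summed over the reinterpreted dimensions:

import torch
import pyro.distributions as dist

base = dist.Normal(torch.zeros(3), torch.ones(3))
d = base.to_event(1)  # Independent with reinterpreted_batch_ndims=1
x = torch.randn(3)
assert torch.allclose(d.log_prob(x), base.log_prob(x).sum(-1))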