예제 #1
0
    def templates_guide_mvn(self):
        """Multivariate normal guide for template parameters.

        Draws one auxiliary sample ``states_<name_prefix>`` from a
        ``MultivariateNormal`` whose ``loc`` and ``scale_tril`` are read via
        ``_deep_getattr`` from ``mvn.loc`` and ``mvn.scale_tril``.  The sample's
        last dimension is then split coordinate by coordinate:

        * the first ``n_poiss`` coordinates go to ``poiss_priors`` /
          ``poiss_labels``;
        * the remaining ones go to ``ps_priors[i_ps][i_ps_param]`` in row-major
          order (``i_ps`` outer, ``i_ps_param`` inner), with site names
          ``ps_param_labels[i_ps_param] + "_" + ps_labels[i_ps]``.

        Each coordinate is pushed into its prior's support with ``biject_to``
        and registered as a ``Delta`` site carrying the log-Jacobian of that
        transform.

        :return: dict mapping sample-site name to the constrained value.
        """
        loc = _deep_getattr(self, "mvn.loc")
        scale_tril = _deep_getattr(self, "mvn.scale_tril")

        states = pyro.sample("states_" + self.name_prefix,
                             dist.MultivariateNormal(loc, scale_tril=scale_tril),
                             infer={"is_auxiliary": True})

        def sample_constrained(name, prior, unconstrained):
            # Map one unconstrained coordinate into the prior's support and record
            # it as a Delta whose log_density is log|det J| of the inverse
            # (constrained -> unconstrained) map, reduced over the event dims.
            transform = biject_to(prior.support)
            value = transform(unconstrained)
            log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
            log_density = sum_rightmost(
                log_density, log_density.dim() - value.dim() + prior.event_dim)
            return pyro.sample(
                name,
                dist.Delta(value, log_density=log_density, event_dim=prior.event_dim))

        result = {}

        # Index the last (event) dimension: for a batched sample (e.g. vectorized
        # particles) ``states[i]`` would select a particle, not parameter i.
        for i_poiss in range(self.n_poiss):
            name = self.poiss_labels[i_poiss]
            result[name] = sample_constrained(
                name, self.poiss_priors[i_poiss], states[..., i_poiss])

        i_param = self.n_poiss
        for i_ps in range(self.n_ps):
            for i_ps_param in range(self.n_ps_params):
                name = self.ps_param_labels[i_ps_param] + "_" + self.ps_labels[i_ps]
                result[name] = sample_constrained(
                    name, self.ps_priors[i_ps][i_ps_param], states[..., i_param])
                i_param += 1

        return result
예제 #2
0
 def score_parts(self, value):
     """Return the base distribution's score parts reduced over reinterpreted dims.

     Every tensor-valued part from ``self.base_dist.score_parts(value)`` is summed
     over its rightmost ``self.reinterpreted_batch_ndims`` dims and expanded to
     the broadcast of ``self.batch_shape`` with the batch part of ``value.shape``.
     ``score_function`` and ``entropy_term`` are passed through untouched when
     they are plain Python numbers.
     """
     ndims = self.reinterpreted_batch_ndims
     batch_part = value.shape[:value.dim() - self.event_dim]
     target_shape = broadcast_shape(self.batch_shape, batch_part)

     def reduce(term):
         return sum_rightmost(term, ndims).expand(target_shape)

     base_log_prob, base_score, base_entropy = self.base_dist.score_parts(value)
     reduced_log_prob = reduce(base_log_prob)
     reduced_score = base_score if isinstance(base_score, numbers.Number) else reduce(base_score)
     reduced_entropy = base_entropy if isinstance(base_entropy, numbers.Number) else reduce(base_entropy)
     return ScoreParts(reduced_log_prob, reduced_score, reduced_entropy)
예제 #3
0
    def guide(self):
        """Multivariate normal guide over all template parameters.

        Registers the params ``a_locs`` (initial value: zeros of length
        ``n_params``) and ``a_scales`` (initial value:
        ``0.1 * eye_like(a_locs, n_params)``, constrained lower-Cholesky).  It
        draws one auxiliary sample ``states`` from the resulting
        ``MultivariateNormal`` and splits its last dimension coordinate by
        coordinate:

        * the first ``n_poiss`` coordinates go to ``poiss_priors`` /
          ``labels_poiss``;
        * the remaining ones go to ``ps_priors[i_ps][i_ps_param]`` in row-major
          order, with site names ``labels_ps_params[i_ps_param] + "_" +
          labels_ps[i_ps]``.

        Each coordinate is pushed into its prior's support with ``biject_to``
        and registered as a ``Delta`` site carrying the transform's
        log-Jacobian.

        :return: dict mapping sample-site name to the constrained value.
        """
        a_locs = pyro.param("a_locs", torch.full((self.n_params, ), 0.0))
        a_scales_tril = pyro.param(
            "a_scales",
            lambda: 0.1 * eye_like(a_locs, self.n_params),
            constraint=constraints.lower_cholesky)

        states = pyro.sample("states",
                             dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril),
                             infer={"is_auxiliary": True})

        def sample_constrained(name, prior, unconstrained):
            # Map one unconstrained coordinate into the prior's support and record
            # it as a Delta whose log_density is log|det J| of the inverse
            # (constrained -> unconstrained) map, reduced over the event dims.
            transform = biject_to(prior.support)
            value = transform(unconstrained)
            log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
            log_density = sum_rightmost(
                log_density, log_density.dim() - value.dim() + prior.event_dim)
            return pyro.sample(
                name,
                dist.Delta(value, log_density=log_density, event_dim=prior.event_dim))

        result = {}

        # Index the last (event) dimension: for a batched sample (e.g. vectorized
        # particles) ``states[i]`` would select a particle, not parameter i.
        for i_poiss in range(self.n_poiss):
            name = self.labels_poiss[i_poiss]
            result[name] = sample_constrained(
                name, self.poiss_priors[i_poiss], states[..., i_poiss])

        i_param = self.n_poiss
        for i_ps in range(self.n_ps):
            for i_ps_param in range(self.n_ps_params):
                name = self.labels_ps_params[i_ps_param] + "_" + self.labels_ps[i_ps]
                result[name] = sample_constrained(
                    name, self.ps_priors[i_ps][i_ps_param], states[..., i_param])
                i_param += 1

        return result
예제 #4
0
    def guide_horseshoe_plus(self):
        """Variational guide (used, per its name, with the horseshoe-plus model).

        Global part: one ``MultivariateNormal`` over ``2 * npars`` unconstrained
        coordinates (auxiliary site ``hyp``).  The first half becomes ``mu_g``
        as-is; the second half is mapped to the positive reals and becomes
        ``sigma_g``, with the log-Jacobian attached.

        Per-subject part: inside the ``subjects`` plate of size ``self.runs``, an
        independent ``MultivariateNormal`` per subject (auxiliary site ``tmp``)
        is split the same way into ``x`` (as-is) and ``sigma_x`` (positive).

        :return: dict with keys ``'mu_g'``, ``'sigma_g'``, ``'sigma_x'``, ``'x'``.
        """
        npar = self.npars  # number of parameters
        nsub = self.runs  # number of subjects
        trns = biject_to(constraints.positive)

        m_hyp = param('m_hyp', zeros(2 * npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2 * npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        # Slice the last (event) dimension so that a batched sample (e.g. under
        # vectorized particles) is split per parameter rather than per particle.
        unc_mu = hyp[..., :npar]
        unc_sigma = hyp[..., npar:]

        c_sigma = trns(unc_sigma)
        ld_sigma = trns.inv.log_abs_det_jacobian(c_sigma, unc_sigma)
        ld_sigma = sum_rightmost(ld_sigma, ld_sigma.dim() - c_sigma.dim() + 1)

        mu_g = sample("mu_g", dist.Delta(unc_mu, event_dim=1))
        sigma_g = sample("sigma_g", dist.Delta(c_sigma, log_density=ld_sigma, event_dim=1))

        m_tmp = param('m_tmp', zeros(nsub, 2 * npar))
        st_tmp = param('s_tmp', torch.eye(2 * npar).repeat(nsub, 1, 1),
                       constraint=constraints.lower_cholesky)

        with plate('subjects', nsub):
            tmp = sample('tmp',
                         dist.MultivariateNormal(m_tmp, scale_tril=st_tmp),
                         infer={'is_auxiliary': True})

            unc_locs = tmp[..., :npar]
            unc_scale = tmp[..., npar:]

            c_scale = trns(unc_scale)
            ld_scale = trns.inv.log_abs_det_jacobian(c_scale, unc_scale)
            ld_scale = sum_rightmost(ld_scale, ld_scale.dim() - c_scale.dim() + 1)

            x = sample("x", dist.Delta(unc_locs, event_dim=1))
            sigma_x = sample("sigma_x", dist.Delta(c_scale, log_density=ld_scale, event_dim=1))

        return {'mu_g': mu_g, 'sigma_g': sigma_g, 'sigma_x': sigma_x, 'x': x}
예제 #5
0
    def guide(self):
        """Approximate posterior for the horseshoe prior.

        The sites ``mu`` and ``tau`` share one multivariate normal (auxiliary
        site ``hyp``) over ``2 * npar`` unconstrained coordinates.  ``mu`` takes
        the first half as-is; ``tau`` takes the second half mapped to the
        positive reals, with the log-Jacobian attached.  Inside the ``runs``
        plate of size ``self.runs``, ``locs`` gets an independent multivariate
        normal per run.
        """
        n_runs = self.runs
        n_par = self.npar
        to_positive = biject_to(constraints.positive)

        joint_loc = param('m_hyp', zeros(2 * n_par))
        joint_tril = param('scale_tril_hyp',
                           torch.eye(2 * n_par),
                           constraint=constraints.lower_cholesky)
        joint = sample('hyp',
                       dist.MultivariateNormal(joint_loc, scale_tril=joint_tril),
                       infer={'is_auxiliary': True})

        raw_mu, raw_tau = joint[..., :n_par], joint[..., n_par:]
        tau = to_positive(raw_tau)
        tau_log_jac = to_positive.inv.log_abs_det_jacobian(tau, raw_tau)
        tau_log_jac = sum_rightmost(tau_log_jac, tau_log_jac.dim() - tau.dim() + 1)

        sample("mu", dist.Delta(raw_mu, event_dim=1))
        sample("tau", dist.Delta(tau, log_density=tau_log_jac, event_dim=1))

        locs_loc = param('m_locs', zeros(n_runs, n_par))
        locs_tril = param('scale_tril_locs',
                          torch.eye(n_par).repeat(n_runs, 1, 1),
                          constraint=constraints.lower_cholesky)

        with plate('runs', n_runs):
            sample("locs", dist.MultivariateNormal(locs_loc, scale_tril=locs_tril))
예제 #6
0
파일: torch.py 프로젝트: zyxue/pyro
def _kl_independent_independent(p, q):
    """KL divergence between two ``Independent`` distributions.

    Only supported when both reinterpret the same number of batch dims;
    otherwise ``NotImplementedError`` is raised.  The KL of the base
    distributions is summed over those reinterpreted rightmost dims.
    """
    ndims = p.reinterpreted_batch_ndims
    if ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    base_kl = kl_divergence(p.base_dist, q.base_dist)
    return sum_rightmost(base_kl, ndims) if ndims else base_kl
예제 #7
0
파일: guides.py 프로젝트: jamestwebber/pyro
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        On the first call the model is run once (``_setup_prototype``) so its
        structure can be inspected.  One joint latent is drawn with
        ``sample_latent``, unpacked per site, mapped into each site's support,
        and replayed as a ``Delta`` sample inside that site's plates.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        plates = self._create_plates()

        samples = {}
        for site, raw in self._unpack_latent(latent):
            site_name = site["name"]
            site_fn = site["fn"]
            to_support = biject_to(site_fn.support)
            constrained = to_support(raw)
            # log|det J| of the inverse (constrained -> unconstrained) map,
            # reduced over the site's event dims.
            log_jac = to_support.inv.log_abs_det_jacobian(constrained, raw)
            log_jac = sum_rightmost(log_jac, log_jac.dim() - constrained.dim() + site_fn.event_dim)
            delta = dist.Delta(constrained, log_density=log_jac, event_dim=site_fn.event_dim)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[site_name]:
                    stack.enter_context(plates[frame.name])
                samples[site_name] = pyro.sample(site_name, delta)

        return samples
예제 #8
0
파일: easyguide.py 프로젝트: www3cam/pyro
    def sample(self, guide_name, fn, infer=None):
        """
        Wrapper around ``pyro.sample()`` to create a single auxiliary sample
        site and then unpack to multiple sample sites for model replay.

        :param str guide_name: The name of the auxiliary guide site.
        :param callable fn: A distribution with shape ``self.event_shape``.
        :param dict infer: Optional inference configuration dict. It is copied
            before ``"is_auxiliary": True`` is added, so the caller's dict is
            not modified.
        :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the
            single concatenated blob and ``model_zs`` is a dict mapping
            site name to constrained model sample.
        :rtype: tuple
        :raises ValueError: if ``fn.event_shape`` differs from ``self.event_shape``.
        """
        # Sample a packed tensor.
        if fn.event_shape != self.event_shape:
            raise ValueError(
                "Invalid fn.event_shape for group: expected {}, actual {}".
                format(tuple(self.event_shape), tuple(fn.event_shape)))
        # Copy rather than mutate: a caller reusing one infer dict across
        # calls should not have is_auxiliary injected into it.
        infer = {} if infer is None else dict(infer)
        infer["is_auxiliary"] = True
        guide_z = pyro.sample(guide_name, fn, infer=infer)
        common_batch_shape = guide_z.shape[:-1]

        model_zs = {}
        pos = 0
        for site in self.prototype_sites:
            name = site["name"]
            site_fn = site["fn"]

            # Extract this site's slice from the packed sample and restore its shape.
            size = self._site_sizes[name]
            batch_shape = broadcast_shape(common_batch_shape,
                                          self._site_batch_shapes[name])
            unconstrained_z = guide_z[..., pos:pos + size]
            unconstrained_z = unconstrained_z.reshape(batch_shape + site_fn.event_shape)
            pos += size

            # Transform to constrained space.
            transform = biject_to(site_fn.support)
            z = transform(unconstrained_z)
            log_density = transform.inv.log_abs_det_jacobian(z, unconstrained_z)
            log_density = sum_rightmost(
                log_density, log_density.dim() - z.dim() + site_fn.event_dim)
            delta_dist = dist.Delta(z, log_density=log_density, event_dim=site_fn.event_dim)

            # Replay model sample statement, entering only plates that are not
            # already on the Pyro stack.
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    plate = self.guide.plate(frame.name)
                    if plate not in runtime._PYRO_STACK:
                        stack.enter_context(plate)
                model_zs[name] = pyro.sample(name, delta_dist)

        return guide_z, model_zs
예제 #9
0
파일: __init__.py 프로젝트: lewisKit/pyro
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        The model is run once via ``_setup_prototype`` if no prototype trace
        exists yet.  A single latent is drawn with ``sample_latent``, unpacked
        per site, mapped into each site's support, and replayed as a ``Delta``
        sample inside the site's iaranges.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        joint_latent = self.sample_latent(*args, **kwargs)
        iaranges = self._create_iaranges()

        out = {}
        for site, raw in self._unpack_latent(joint_latent):
            site_name, site_fn = site["name"], site["fn"]
            bijection = biject_to(site_fn.support)
            constrained = bijection(raw)
            # log|det J| of the inverse (constrained -> unconstrained) map,
            # reduced over the site's event dims.
            log_jac = bijection.inv.log_abs_det_jacobian(constrained, raw)
            extra_dims = log_jac.dim() - constrained.dim() + site_fn.event_dim
            log_jac = sum_rightmost(log_jac, extra_dims)
            point_mass = dist.Delta(constrained, log_density=log_jac, event_dim=site_fn.event_dim)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[site_name]:
                    stack.enter_context(iaranges[frame.name])
                out[site_name] = pyro.sample(site_name, point_mass)

        return out
예제 #10
0
def test_kl_independent_normal(batch_shape, event_shape):
    """KL between two Independent Normals equals the elementwise KL summed over event dims."""
    full_shape = batch_shape + event_shape
    n_event = len(event_shape)
    p = dist.Normal(torch.randn(full_shape), torch.randn(full_shape).exp())
    q = dist.Normal(torch.randn(full_shape), torch.randn(full_shape).exp())
    independent_kl = kl_divergence(dist.Independent(p, n_event),
                                   dist.Independent(q, n_event))
    summed_kl = sum_rightmost(kl_divergence(p, q), n_event)
    assert_close(independent_kl, summed_kl)
예제 #11
0
def _kl_transformed_transformed(p, q):
    """KL divergence between two transformed distributions with identical transforms.

    Raises ``NotImplementedError`` unless ``p`` and ``q`` have equal transforms
    and equal event shapes.  The KL of the base distributions is summed over the
    base batch dims that are not batch dims of ``p``.
    """
    if p.transforms != q.transforms or p.event_shape != q.event_shape:
        raise NotImplementedError
    n_extra = len(p.base_dist.batch_shape) - len(p.batch_shape)
    return sum_rightmost(kl_divergence(p.base_dist, q.base_dist), n_extra)
예제 #12
0
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        For each stochastic node of the prototype trace, an auxiliary
        ``<name>_unconstrained`` Normal is sampled.  Its loc/scale come from
        ``_get_loc_and_scale`` applied to ``self.encode(*args, **kwargs)``.  The
        sample is then mapped into the site's support and replayed as a
        ``Delta``, inside the site's vectorized plates only.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        hidden = self.encode(*args, **kwargs)
        plates = self._create_plates(*args, **kwargs)

        samples = {}
        for site_name, site in self.prototype_trace.iter_stochastic_nodes():
            site_fn = site["fn"]
            to_support = biject_to(site_fn.support)

            with ExitStack() as stack:
                vectorized_frames = (f for f in site["cond_indep_stack"] if f.vectorized)
                for frame in vectorized_frames:
                    stack.enter_context(plates[frame.name])

                loc, scale = self._get_loc_and_scale(site_name, hidden)
                raw = pyro.sample(
                    site_name + "_unconstrained",
                    dist.Normal(loc, scale).to_event(self._event_dims[site_name]),
                    infer={"is_auxiliary": True},
                )
                constrained = to_support(raw)

                # The Jacobian term is skipped when the active mask is exactly False.
                log_jac = 0.0
                if pyro.poutine.get_mask() is not False:
                    log_jac = to_support.inv.log_abs_det_jacobian(constrained, raw)
                    log_jac = sum_rightmost(
                        log_jac, log_jac.dim() - constrained.dim() + site_fn.event_dim)

                samples[site_name] = pyro.sample(
                    site_name,
                    dist.Delta(constrained, log_density=log_jac, event_dim=site_fn.event_dim),
                )

        return samples
예제 #13
0
 def conjugate_update(self, other):
     """
     EXPERIMENTAL.

     Delegate the conjugate update to ``self.base_dist`` and re-wrap the result.

     ``other.to_event(-n)`` (with ``n = self.reinterpreted_batch_ndims``) is
     passed to the base distribution's ``conjugate_update``.  The updated
     distribution is re-wrapped with ``to_event(n)``, and the log normalizer is
     summed over its rightmost ``n`` dims.

     :return: A pair ``(updated, log_normalizer)``.
     """
     # The attribute is spelled ``reinterpreted_batch_ndims`` (as used elsewhere
     # on Independent distributions); a misspelling raises AttributeError.
     n = self.reinterpreted_batch_ndims
     updated, log_normalizer = self.base_dist.conjugate_update(other.to_event(-n))
     updated = updated.to_event(n)
     log_normalizer = sum_rightmost(log_normalizer, n)
     return updated, log_normalizer
예제 #14
0
파일: kl.py 프로젝트: jamestwebber/pyro
def _kl_independent_independent(p, q):
    """KL divergence between two ``Independent`` distributions.

    The reinterpreted dims common to both (the smaller of the two counts) are
    peeled off.  Any surplus on one side is kept as an ``Independent`` wrapper
    around that side's base distribution.  The KL of the peeled distributions
    is then summed over the shared dims.
    """
    shared = min(p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)

    def peel(d):
        extra = d.reinterpreted_batch_ndims - shared
        return Independent(d.base_dist, extra) if extra else d.base_dist

    kl = kl_divergence(peel(p), peel(q))
    return sum_rightmost(kl, shared) if shared else kl
예제 #15
0
파일: neutra.py 프로젝트: pyro-ppl/pyro
    def apply(self, msg):
        """Replace a latent sample site with a ``Delta`` at a NeuTra-transformed value.

        Sites absent from the guide's prototype trace pass through unchanged.
        Observed sites raise ``NotImplementedError``.  When ``name`` is not among
        the cached entries in ``self.x_unconstrained``, a shared latent is drawn:

        * it is sampled from the guide's base distribution (masked out) with the
          guide's plates blocked;
        * it is pushed through ``self.guide.get_transform()``;
        * the unpacked per-site unconstrained values are cached by site name.

        The site then pops its own entry, maps it into ``fn.support``, and is
        returned as observed.  Its log density is ``fn.log_prob`` plus the
        support log-Jacobian, plus the shared transform's log-Jacobian when the
        draw happened at this site.  Density terms are skipped when
        ``poutine.get_mask()`` is ``False``.
        """
        site_name = msg["name"]
        site_fn = msg["fn"]
        site_value = msg["value"]
        observed = msg["is_observed"]
        if site_name not in self.guide.prototype_trace.nodes:
            return {"fn": site_fn, "value": site_value, "is_observed": observed}
        if observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(site_name)},...), "
                "NeuTraReparam does not support observe statements.")

        want_density = poutine.get_mask() is not False
        log_density = 0.0
        if site_name not in self.x_unconstrained:
            try:
                self.transform = self.guide.get_transform()
            except (NotImplementedError, TypeError) as e:
                raise ValueError(
                    "NeuTraReparam only supports guides that implement "
                    "`get_transform` method that does not depend on the "
                    "model's `*args, **kwargs`") from e

            with ExitStack() as stack:
                for guide_plate in self.guide.plates.values():
                    stack.enter_context(block_plate(dim=guide_plate.dim, strict=False))
                shared = pyro.sample(
                    f"{site_name}_shared_latent",
                    self.guide.get_base_dist().mask(False))

            pushed = self.transform(shared)
            if want_density:
                log_density = self.transform.log_abs_det_jacobian(shared, pushed)
            self.x_unconstrained = {
                entry["name"]: (entry, raw_value)
                for entry, raw_value in self.guide._unpack_latent(pushed)
            }

        _, raw = self.x_unconstrained.pop(site_name)
        to_support = biject_to(site_fn.support)
        constrained = to_support(raw)
        if want_density:
            log_jac = to_support.log_abs_det_jacobian(raw, constrained)
            log_jac = sum_rightmost(log_jac, log_jac.dim() - constrained.dim() + site_fn.event_dim)
            log_density = log_density + site_fn.log_prob(constrained) + log_jac
        return {
            "fn": dist.Delta(constrained, log_density, event_dim=site_fn.event_dim),
            "value": constrained,
            "is_observed": True,
        }
예제 #16
0
파일: test_util.py 프로젝트: zeta1999/pyro
def test_sum_rightmost():
    """A non-negative ``dim`` sums that many rightmost dims.

    A negative ``dim`` keeps ``-dim`` leftmost dims, and ``INF`` sums every dim.
    """
    ones = torch.ones(2, 3, 4)
    cases = [
        (0, (2, 3, 4)),
        (1, (2, 3)),
        (2, (2, )),
        (-1, (2, )),
        (-2, (2, 3)),
        (INF, ()),
    ]
    for dim, expected_shape in cases:
        assert sum_rightmost(ones, dim).shape == expected_shape
예제 #17
0
파일: test_util.py 프로젝트: lewisKit/pyro
def test_sum_rightmost():
    """A non-negative ``dim`` sums that many rightmost dims.

    A negative ``dim`` keeps ``-dim`` leftmost dims, and ``inf`` sums every dim.
    """
    ones = torch.ones(2, 3, 4)
    expected = {
        0: (2, 3, 4),
        1: (2, 3),
        2: (2,),
        -1: (2,),
        -2: (2, 3),
        float('inf'): (),
    }
    for dim, shape in expected.items():
        assert sum_rightmost(ones, dim).shape == shape
예제 #18
0
파일: guides.py 프로젝트: yufengwa/pyro
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        Draws one joint latent with ``sample_latent`` and unpacks it per site.
        Each piece is mapped into its site's support and replayed as a
        ``Delta`` inside the site's plates.  The Jacobian term is omitted
        (``0.0``) when ``poutine.get_mask()`` is ``False``.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        plates = self._create_plates(*args, **kwargs)

        samples = {}
        for site, raw in self._unpack_latent(latent):
            site_name = site["name"]
            site_fn = site["fn"]
            to_support = biject_to(site_fn.support)
            constrained = to_support(raw)

            log_jac = 0.0
            if poutine.get_mask() is not False:
                log_jac = to_support.inv.log_abs_det_jacobian(constrained, raw)
                log_jac = sum_rightmost(
                    log_jac, log_jac.dim() - constrained.dim() + site_fn.event_dim)
            delta = dist.Delta(constrained, log_density=log_jac, event_dim=site_fn.event_dim)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[site_name]:
                    stack.enter_context(plates[frame.name])
                samples[site_name] = pyro.sample(site_name, delta)

        return samples
예제 #19
0
파일: neutra.py 프로젝트: yufengwa/pyro
    def __call__(self, name, fn, obs):
        """Reparameterize sample site ``name`` through the guide's transform.

        Sites not in the guide's prototype trace are returned unchanged, and
        observed sites fail an assertion.  When ``self.x_unconstrained`` is empty:

        * a shared latent is drawn from the guide's base distribution (masked out);
        * it is pushed through ``self.guide.get_transform()``;
        * it is unpacked into a stack of ``(site, unconstrained_value)`` pairs.

        Each call pops the next pair, checks that it matches ``name``, maps it
        into ``fn.support``, and returns ``(Delta, value)``.  Density terms are
        skipped when ``poutine.get_mask()`` is ``False``.
        """
        if name not in self.guide.prototype_trace.nodes:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"

        want_density = poutine.get_mask() is not False
        log_density = 0.0
        if not self.x_unconstrained:
            try:
                self.transform = self.guide.get_transform()
            except (NotImplementedError, TypeError) as e:
                raise ValueError(
                    "NeuTraReparam only supports guides that implement "
                    "`get_transform` method that does not depend on the "
                    "model's `*args, **kwargs`") from e

            shared = pyro.sample("{}_shared_latent".format(name),
                                 self.guide.get_base_dist().mask(False))
            pushed = self.transform(shared)
            if want_density:
                log_density = self.transform.log_abs_det_jacobian(shared, pushed)
            # Stored reversed so that pop() returns sites in unpack order.
            self.x_unconstrained = list(reversed(list(self.guide._unpack_latent(pushed))))

        site, raw = self.x_unconstrained.pop()
        assert name == site["name"], "model structure changed"
        to_support = biject_to(fn.support)
        constrained = to_support(raw)
        if want_density:
            log_jac = to_support.log_abs_det_jacobian(raw, constrained)
            log_jac = sum_rightmost(log_jac, log_jac.dim() - constrained.dim() + fn.event_dim)
            log_density = log_density + fn.log_prob(constrained) + log_jac
        return dist.Delta(constrained, log_density, event_dim=fn.event_dim), constrained
예제 #20
0
파일: neutra.py 프로젝트: www3cam/pyro
    def __call__(self, name, fn, obs):
        """Reparameterize sample site ``name`` through the guide posterior's transforms.

        Sites not in the guide's prototype trace are returned unchanged, and
        observed sites fail an assertion.  When ``self.x_unconstrained`` is empty:

        * the guide posterior must be a ``TransformedDistribution`` (otherwise
          ``ValueError``);
        * a shared latent is drawn from its (masked-out) base distribution;
        * the latent is pushed through the composed transforms and unpacked into
          a stack of per-site values.

        Each call pops the next entry, checks that it matches ``name``, and
        returns ``(Delta, value)`` in ``fn.support``.
        """
        if name not in self.guide.prototype_trace.nodes:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"

        log_density = 0.
        if not self.x_unconstrained:
            # TODO(fehiepsi) Consider adding a method to extract transform from an Auto*Normal(posterior).
            posterior = self.guide.get_posterior()
            if not isinstance(posterior, dist.TransformedDistribution):
                raise ValueError(
                    "NeuTraReparam only supports guides whose posteriors are "
                    "TransformedDistributions but got a posterior of type {}".
                    format(type(posterior)))
            self.transform = dist.transforms.ComposeTransform(posterior.transforms)
            shared = pyro.sample("{}_shared_latent".format(name),
                                 posterior.base_dist.mask(False))
            pushed = self.transform(shared)
            log_density = self.transform.log_abs_det_jacobian(shared, pushed)
            # Stored reversed so that pop() returns sites in unpack order.
            self.x_unconstrained = list(reversed(list(self.guide._unpack_latent(pushed))))

        site, raw = self.x_unconstrained.pop()
        assert name == site["name"], "model structure changed"
        to_support = biject_to(fn.support)
        constrained = to_support(raw)
        log_jac = to_support.log_abs_det_jacobian(raw, constrained)
        log_jac = sum_rightmost(log_jac, log_jac.dim() - constrained.dim() + fn.event_dim)
        total = log_density + fn.log_prob(constrained) + log_jac
        return dist.Delta(constrained, total, event_dim=fn.event_dim), constrained
예제 #21
0
def evaluate_log_posterior_density(model, posterior_samples, baseball_dataset):
    """
    Evaluate the log probability density of observing the unseen data (season hits)
    given a model and posterior distribution over the parameters.

    The value is logged, not returned:
    ``logsumexp_i(log p(new_data | theta_i)) - log(num_samples)``, i.e.
    ``log(1/num_samples * sum_i p(new_data | theta_i))``, where ``theta_i``
    are the posterior samples.
    """
    _, test, _ = train_test_split(baseball_dataset)
    at_bats_season, hits_season = test[:, 0], test[:, 1]
    with ignore_experimental_warning():
        trace = predictive(model, posterior_samples, at_bats_season, hits_season,
                           parallel=True, return_trace=True)
    # Use the LogSumExp trick to evaluate $log(1/num_samples \sum_i p(new_data | \theta^{i}))$,
    # where $\theta^{i}$ are parameter samples from the model's posterior.
    trace.compute_log_prob()
    log_joint = 0.
    for site in trace.nodes.values():
        if site["type"] == "sample" and not site_is_subsample(site):
            # `sum_rightmost(x, -1)` sums every dimension of `x` except the
            # leftmost one, which indexes the posterior samples.
            log_joint += sum_rightmost(site['log_prob'], -1)
    posterior_pred_density = torch.logsumexp(log_joint, dim=0) - math.log(log_joint.shape[0])
    logging.info("\nLog posterior predictive density")
    logging.info("--------------------------------")
    logging.info("{:.4f}\n".format(posterior_pred_density))
예제 #22
0
파일: delta.py 프로젝트: zippeurfou/pyro
 def log_prob(self, x):
     """Log-density of the point mass at ``self.v``, evaluated at ``x``.

     Elementwise, positions where ``x == v`` contribute ``log(1) = 0`` and all
     others ``log(0) = -inf``.  The result is summed over the rightmost
     ``self.event_dim`` dims, and then ``self.log_density`` is added.
     """
     v = self.v.expand(self.shape())
     # Cast the boolean mask directly: ``x.new_tensor(<tensor>)`` copy-constructs
     # from a tensor, which PyTorch warns about and which costs an extra copy.
     log_prob = (x == v).type(x.dtype).log()
     log_prob = sum_rightmost(log_prob, self.event_dim)
     return log_prob + self.log_density
예제 #23
0
파일: delta.py 프로젝트: lewisKit/pyro
 def log_prob(self, x):
     """Log-density of the point mass at ``self.v``, evaluated at ``x``.

     Elementwise, positions where ``x == v`` contribute ``log(1) = 0`` and all
     others ``log(0) = -inf``.  The result is summed over the rightmost
     ``self.event_dim`` dims, and then ``self.log_density`` is added.
     """
     v = self.v.expand(self.shape())
     # Cast the boolean mask directly: ``x.new_tensor(<tensor>)`` copy-constructs
     # from a tensor, which PyTorch warns about and which costs an extra copy.
     log_prob = (x == v).type(x.dtype).log()
     log_prob = sum_rightmost(log_prob, self.event_dim)
     return log_prob + self.log_density
예제 #24
0
 def log_prob(self, x):
     """Point-mass log-density at ``x``.

     Each element is ``0`` where ``x`` equals the (expanded) ``self.v`` and
     ``-inf`` elsewhere.  The values are summed over ``self.event_dim``
     rightmost dims and offset by ``self.log_density``.
     """
     matches = x == self.v.expand(self.shape())
     elementwise = matches.type(x.dtype).log()
     return sum_rightmost(elementwise, self.event_dim) + self.log_density
예제 #25
0
    def templates_guide_iaf(self, indices, gp_sample=None):
        """IAF guide for template parameters.

        Builds a flow from ``self.base_dist`` and ``self.transform`` and draws a
        single auxiliary sample ``states_<name_prefix>`` from it:

        * ``guide_name == "IAF"``: an unconditional ``TransformedDistribution``.
        * ``guide_name == "ConditionalIAF"``: a
          ``ConditionalTransformedDistribution`` conditioned on a context vector
          of ``n_poiss + n_ps + 2`` summary statistics of ``gp_sample``.

        The sample's last dimension is split coordinate by coordinate.  The
        first ``n_poiss`` coordinates go to ``poiss_priors`` / ``poiss_labels``;
        the remaining ones go to ``ps_priors[i_ps][i_ps_param]`` in row-major
        order, with site names ``ps_param_labels[i_ps_param] + "_" +
        ps_labels[i_ps]``.  Each coordinate is pushed into its prior's support
        with ``biject_to`` and registered as a ``Delta`` site carrying the
        transform's log-Jacobian.

        :param indices: column indices applied to ``poiss_temps`` / ``ps_temps``
            (only used for ``ConditionalIAF``).
        :param gp_sample: tensor whose exponential the context is computed from;
            required for ``ConditionalIAF``, unused otherwise.
        :return: dict mapping sample-site name to the constrained value.
        :raises ValueError: if ``self.guide_name`` is neither ``"IAF"`` nor
            ``"ConditionalIAF"``.
        """
        if self.guide_name == "IAF":
            td = dist.TransformedDistribution(self.base_dist, self.transform)
        elif self.guide_name == "ConditionalIAF":
            # Summary stats of the GP draw: dot products of the Poiss / non-Poiss
            # template columns with exp(gp_sample), normalised by n_pix, followed
            # by the mean and the standard deviation of exp(gp_sample).
            gp_exp = gp_sample.exp()
            context_vars = torch.zeros(self.n_poiss + self.n_ps + 2)
            context_vars[:self.n_poiss] = (
                self.poiss_temps[:, indices] @ gp_exp.double()) / self.n_pix
            context_vars[self.n_poiss:self.n_poiss + self.n_ps] = (
                self.ps_temps[:, indices] @ gp_exp.double()) / self.n_pix
            context_vars[-2] = torch.mean(gp_exp)
            context_vars[-1] = torch.var(gp_exp).sqrt()
            td = dist.ConditionalTransformedDistribution(
                self.base_dist, self.transform).condition(context=context_vars)
        else:
            # Previously td stayed None and pyro.sample failed obscurely.
            raise ValueError(
                "Unknown guide_name {!r}; expected 'IAF' or 'ConditionalIAF'"
                .format(self.guide_name))

        states = pyro.sample("states_" + self.name_prefix,
                             td,
                             infer={"is_auxiliary": True})

        def sample_constrained(name, prior, unconstrained):
            # Map one unconstrained coordinate into the prior's support and record
            # it as a Delta whose log_density is log|det J| of the inverse
            # (constrained -> unconstrained) map, reduced over the event dims.
            transform = biject_to(prior.support)
            value = transform(unconstrained)
            log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
            log_density = sum_rightmost(
                log_density, log_density.dim() - value.dim() + prior.event_dim)
            return pyro.sample(
                name,
                dist.Delta(value, log_density=log_density, event_dim=prior.event_dim))

        result = {}

        # Index the last (event) dimension: for a batched sample (e.g. vectorized
        # particles) ``states[i]`` would select a particle, not parameter i.
        for i_poiss in range(self.n_poiss):
            name = self.poiss_labels[i_poiss]
            result[name] = sample_constrained(
                name, self.poiss_priors[i_poiss], states[..., i_poiss])

        i_param = self.n_poiss
        for i_ps in range(self.n_ps):
            for i_ps_param in range(self.n_ps_params):
                name = self.ps_param_labels[i_ps_param] + "_" + self.ps_labels[i_ps]
                result[name] = sample_constrained(
                    name, self.ps_priors[i_ps][i_ps_param], states[..., i_param])
                i_param += 1

        return result
예제 #26
0
 def log_prob(self, value):
     """Log-probability of ``value`` under the reinterpreted base distribution.

     The base log-prob is summed over the rightmost
     ``self.reinterpreted_batch_ndims`` dims.  It is then expanded to the
     broadcast of ``self.batch_shape`` with the batch part of ``value.shape``.
     """
     batch_part = value.shape[:value.dim() - self.event_dim]
     target_shape = broadcast_shape(self.batch_shape, batch_part)
     base_log_prob = self.base_dist.log_prob(value)
     reduced = sum_rightmost(base_log_prob, self.reinterpreted_batch_ndims)
     return reduced.expand(target_shape)