コード例 #1
0
    def autoguide(self, name, dist_constructor):
        """
        Attach an autoguide to the already-registered parameter ``name``,
        mirroring what :mod:`pyro.infer.autoguide` does for whole models.

        .. note:: ``dist_constructor`` must be one of
            :class:`~pyro.distributions.Delta`,
            :class:`~pyro.distributions.Normal` or
            :class:`~pyro.distributions.MultivariateNormal`; more distribution
            constructors may be supported in the future if needed.

        Any guide previously attached to ``name`` is dropped first by deleting
        its ``<name>_<arg>`` attributes. The new guide's parameters are
        initialized from one draw of the prior (init_to_sample strategy) and
        stored as ``<name>_<arg>`` attributes; ``self._guides[name]`` records
        ``(dist_constructor, arg_names)``.

        :param str name: Name of the parameter.
        :param dist_constructor: A
            :class:`~pyro.distributions.distribution.Distribution` constructor.
        :raises ValueError: if no prior is registered for ``name``.
        :raises NotImplementedError: if ``dist_constructor`` is unsupported.
        """
        if name not in self._priors:
            raise ValueError(
                "There is no prior for parameter: {}".format(name))

        supported = [dist.Delta, dist.Normal, dist.MultivariateNormal]
        if dist_constructor not in supported:
            raise NotImplementedError(
                "Unsupported distribution type: {}".format(dist_constructor))

        # Remove the attributes that backed the previous guide, if any.
        if name in self._guides:
            for old_arg in self._guides[name][1]:
                delattr(self, "{}_{}".format(name, old_arg))

        prior = self._priors[name]
        init_value = prior()  # init_to_sample strategy

        if dist_constructor is dist.Delta:
            if _is_real_support(prior.support):
                map_param = Parameter(init_value.detach())
            else:
                map_param = PyroParam(init_value.detach(), prior.support)
            new_params = {"map": map_param}
        elif dist_constructor is dist.Normal:
            # Normal guides live in unconstrained space.
            loc = Parameter(biject_to(prior.support).inv(init_value).detach())
            new_params = {
                "loc": loc,
                "scale": PyroParam(loc.new_ones(loc.shape),
                                   constraints.positive),
            }
        elif dist_constructor is dist.MultivariateNormal:
            loc = Parameter(biject_to(prior.support).inv(init_value).detach())
            # Identity Cholesky factor, repeated over any leading batch dims.
            eye = eye_like(loc, loc.size(-1))
            new_params = {
                "loc": loc,
                "scale_tril": PyroParam(eye.repeat(loc.shape[:-1] + (1, 1)),
                                        constraints.lower_cholesky),
            }
        else:
            raise NotImplementedError

        for arg_name, value in new_params.items():
            setattr(self, "{}_{}".format(name, arg_name), value)
        dist_args = tuple(new_params)

        self._guides[name] = (dist_constructor, dist_args)
コード例 #2
0
    def templates_guide_mvn(self):
        """Multivariate normal guide for template parameters.

        A single auxiliary ``MultivariateNormal`` site named
        ``"states_" + self.name_prefix`` is sampled in unconstrained space,
        using the ``mvn.loc`` / ``mvn.scale_tril`` attributes of this module.
        Its last dimension is split into one coordinate per model parameter:

        * first the ``self.n_poiss`` Poissonian template parameters;
        * then, for each of the ``self.n_ps`` point-source templates, its
          ``self.n_ps_params`` parameters.

        Each coordinate is mapped into the support of the corresponding prior
        and replayed as a ``Delta`` site whose log density is the
        log-abs-det-Jacobian of that mapping.

        :return: dict mapping site name to the constrained sampled value.
        """
        loc = _deep_getattr(self, "mvn.loc")
        scale_tril = _deep_getattr(self, "mvn.scale_tril")

        dt = dist.MultivariateNormal(loc, scale_tril=scale_tril)
        states = pyro.sample("states_" + self.name_prefix,
                             dt,
                             infer={"is_auxiliary": True})

        def sample_constrained(site_name, prior, unconstrained):
            # Map one unconstrained coordinate into the prior's support and
            # register it as a Delta site carrying the Jacobian correction.
            transform = biject_to(prior.support)
            value = transform(unconstrained)
            log_density = transform.inv.log_abs_det_jacobian(
                value, unconstrained)
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim() + prior.event_dim)
            return pyro.sample(
                site_name,
                dist.Delta(value,
                           log_density=log_density,
                           event_dim=prior.event_dim))

        result = {}

        # Plain ints index the Python lists of priors/labels; ``[..., i]``
        # selects the i-th coordinate along the MVN event dimension, so any
        # leading batch dimensions on ``states`` are preserved.
        for i_poiss in range(self.n_poiss):
            label = self.poiss_labels[i_poiss]
            result[label] = sample_constrained(
                label, self.poiss_priors[i_poiss], states[..., i_poiss])

        i_param = self.n_poiss

        for i_ps in range(self.n_ps):
            for i_ps_param in range(self.n_ps_params):
                label = (self.ps_param_labels[i_ps_param] + "_" +
                         self.ps_labels[i_ps])
                result[label] = sample_constrained(
                    label, self.ps_priors[i_ps][i_ps_param],
                    states[..., i_param])
                i_param += 1

        return result
コード例 #3
0
    def guide(self):
        """Multivariate normal guide over all template parameters.

        Learns a location ``"a_locs"`` (initialized to zeros) and a
        lower-Cholesky scale ``"a_scales"`` (initialized to ``0.1 * I``). These
        parameterize a single auxiliary ``MultivariateNormal`` site
        ``"states"`` in unconstrained space. Its last dimension is split into
        one coordinate per model parameter:

        * first the ``self.n_poiss`` Poissonian template parameters;
        * then, for each of the ``self.n_ps`` point-source templates, its
          ``self.n_ps_params`` parameters.

        Each coordinate is mapped into the support of the corresponding prior
        and replayed as a ``Delta`` site whose log density is the
        log-abs-det-Jacobian of that mapping.

        :return: dict mapping site name to the constrained sampled value.
        """
        a_locs = pyro.param("a_locs", torch.full((self.n_params, ), 0.0))
        a_scales_tril = pyro.param(
            "a_scales",
            lambda: 0.1 * eye_like(a_locs, self.n_params),
            constraint=constraints.lower_cholesky)

        dt = dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril)
        states = pyro.sample("states", dt, infer={"is_auxiliary": True})

        def sample_constrained(site_name, prior, unconstrained):
            # Map one unconstrained coordinate into the prior's support and
            # register it as a Delta site carrying the Jacobian correction.
            transform = biject_to(prior.support)
            value = transform(unconstrained)
            log_density = transform.inv.log_abs_det_jacobian(
                value, unconstrained)
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim() + prior.event_dim)
            return pyro.sample(
                site_name,
                dist.Delta(value,
                           log_density=log_density,
                           event_dim=prior.event_dim))

        result = {}

        # Plain ints index the Python lists of priors/labels; ``[..., i]``
        # selects the i-th coordinate along the MVN event dimension, so any
        # leading batch dimensions on ``states`` are preserved.
        for i_poiss in range(self.n_poiss):
            label = self.labels_poiss[i_poiss]
            result[label] = sample_constrained(
                label, self.poiss_priors[i_poiss], states[..., i_poiss])

        i_param = self.n_poiss

        for i_ps in range(self.n_ps):
            for i_ps_param in range(self.n_ps_params):
                label = (self.labels_ps_params[i_ps_param] + "_" +
                         self.labels_ps[i_ps])
                result[label] = sample_constrained(
                    label, self.ps_priors[i_ps][i_ps_param],
                    states[..., i_param])
                i_param += 1

        return result
コード例 #4
0
ファイル: hmc.py プロジェクト: lewisKit/pyro
    def setup(self, *args, **kwargs):
        """Prepare the kernel to sample from ``self.model(*args, **kwargs)``.

        Records one model trace as the prototype used to convert between
        trace objects and the plain dicts the integrator works on. When
        automatic transforms are enabled, it rebuilds ``self.transforms``
        (inverse ``biject_to`` maps for sites whose support is not real). For
        every latent site it stores a standard ``Normal`` of matching shape in
        ``self._r_dist``. If ``self.adapt_step_size`` is set, it also picks an
        initial step size and sets up the Dual Averaging scheme.
        """
        self._args, self._kwargs = args, kwargs
        trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        self._prototype_trace = trace
        if self._automatic_transform_enabled:
            self.transforms = {}

        by_name = sorted(trace.iter_stochastic_nodes(), key=lambda pair: pair[0])
        for site_name, site in by_name:
            value = site["value"]
            support = site["fn"].support
            if support is not constraints.real and self._automatic_transform_enabled:
                self.transforms[site_name] = biject_to(support).inv
                value = self.transforms[site_name](site["value"])
            # Standard normal with the (possibly transformed) site's shape;
            # presumably the momentum distribution for this site.
            self._r_dist[site_name] = dist.Normal(loc=value.new_zeros(value.shape),
                                                  scale=value.new_ones(value.shape))
        self._validate_trace(trace)

        if not self.adapt_step_size:
            return
        self._adapt_phase = True
        z = {site_name: site["value"] for site_name, site in trace.iter_stochastic_nodes()}
        for site_name, transform in self.transforms.items():
            z[site_name] = transform(z[site_name])
        self.step_size = self._find_reasonable_step_size(z)
        self.num_steps = max(1, int(self.trajectory_length / self.step_size))
        # Prox-center for the Dual Averaging step-size scheme.
        self._adapted_scheme = DualAveraging(prox_center=math.log(10 * self.step_size))
コード例 #5
0
    def get_posterior(
        self,
        name: str,
        prior: Distribution,
    ) -> Union[Distribution, torch.Tensor]:
        """Return the guide's posterior for site ``name`` with prior ``prior``.

        While computing medians, quantiles or mutual information, the
        corresponding summary is returned instead of a distribution. Otherwise
        the result is a ``Normal`` in unconstrained space pushed through
        ``biject_to(prior.support)``. For hierarchical sites, the location is
        additionally offset by ``weight * transform.inv(prior.mean)``.
        """
        if self._computing_median:
            return self._get_posterior_median(name, prior)
        if self._computing_quantiles:
            return self._get_posterior_quantiles(name, prior)
        if self._computing_mi:
            # the messenger autoguide needs the output to fit certain dimensions
            # this is hack which saves MI to self.mi but returns cheap to compute medians
            self.mi[name] = self._get_mutual_information(name, prior)
            return self._get_posterior_median(name, prior)

        with helpful_support_errors({"name": name, "fn": prior}):
            transform = biject_to(prior.support)

        # With no hierarchical_sites given, every site is treated as hierarchical.
        hierarchical = (self._hierarchical_sites is None
                        or name in self._hierarchical_sites)
        if hierarchical:
            loc, scale, weight = self._get_params(name, prior)
            loc = loc + transform.inv(prior.mean) * weight
        else:
            # Mean-field fallback for sites outside the hierarchical list.
            loc, scale = self._get_params(name, prior)

        base = dist.Normal(loc, scale).to_event(transform.domain.event_dim)
        return dist.TransformedDistribution(base, transform.with_cache())
コード例 #6
0
 def _initialize_model_properties(self):
     """Inspect one model trace to configure model-dependent kernel state.

     Guesses ``max_plate_nesting`` when unset and wraps ``self.model`` in
     ``poutine.enum`` so discrete latent sites are enumerated. For every
     continuous latent site it records the shape and numel, measured in
     unconstrained space when automatic transforms apply. It then builds the
     trace log-prob evaluator and initializes the adapter's inverse mass
     matrix to identity (dense or diagonal depending on ``self.full_mass``).

     :raises ValueError: if the model has no continuous latent sites, since
         there is then nothing to build a mass matrix for.
     """
     if self.max_plate_nesting is None:
         self._guess_max_plate_nesting()
     # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
     # No-op if model does not have any discrete latents.
     self.model = poutine.enum(config_enumerate(self.model),
                               first_available_dim=-1 -
                               self.max_plate_nesting)
     if self._automatic_transform_enabled:
         self.transforms = {}
     trace = poutine.trace(self.model).get_trace(*self._args,
                                                 **self._kwargs)
     # Last continuous site value seen; only used to give the initial mass
     # matrix a matching dtype and device.
     reference_value = None
     for name, node in trace.iter_stochastic_nodes():
         if isinstance(node["fn"], _Subsample):
             continue
         if node["fn"].has_enumerate_support:
             self._has_enumerable_sites = True
             continue
         site_value = node["value"]
         if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
             self.transforms[name] = biject_to(node["fn"].support).inv
             site_value = self.transforms[name](node["value"])
         self._r_shapes[name] = site_value.shape
         self._r_numels[name] = site_value.numel()
         reference_value = site_value
     self._trace_prob_evaluator = TraceEinsumEvaluator(
         trace, self._has_enumerable_sites, self.max_plate_nesting)
     if reference_value is None:
         # Every site was subsampled or enumerated: no tensor to take the
         # mass matrix's dtype/device from, and nothing to integrate over.
         raise ValueError(
             "Model has no continuous latent sites; cannot initialize the "
             "mass matrix.")
     mass_matrix_size = sum(self._r_numels.values())
     if self.full_mass:
         initial_mass_matrix = eye_like(reference_value, mass_matrix_size)
     else:
         initial_mass_matrix = reference_value.new_ones(mass_matrix_size)
     self._adapter.inverse_mass_matrix = initial_mass_matrix
コード例 #7
0
def test_biject_to(constraint_fn, args, is_cuda):
    """Check that ``biject_to(constraint)`` is a bijection onto the constraint:

    1. its outputs satisfy the constraint;
    2. its inverse round-trips;
    3. its log-abs-det-Jacobian has the batch shape of the input.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        transform = biject_to(constraint)
    except NotImplementedError:
        pytest.skip('`biject_to` not implemented.')
    assert transform.bijective, "biject_to({}) is not bijective".format(constraint)

    # corr_cholesky with D = 4 needs D * (D - 1) / 2 = 6 entries in the last dim.
    size = 6 if constraint_fn is constraints.corr_cholesky else 5
    x = torch.randn(size, size, dtype=torch.double)
    if is_cuda:
        x = x.cuda()

    y = transform(x)
    assert constraint.check(y).all(), '\n'.join([
        "Failed to biject_to({})".format(constraint),
        "x = {}".format(x),
        "biject_to(...)(x) = {}".format(y),
    ])

    x_roundtrip = transform.inv(y)
    assert torch.allclose(x, x_roundtrip), \
        "Error in biject_to({}) inverse".format(constraint)

    log_det = transform.log_abs_det_jacobian(x, y)
    assert log_det.shape == x.shape[:x.dim() - transform.input_event_dim]
コード例 #8
0
ファイル: __init__.py プロジェクト: lewisKit/pyro
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        Draws one joint latent sample and splits it per site. Each piece is
        mapped into its site's support and replayed as a ``Delta`` sample site
        inside that site's ``iarange`` contexts.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # First call: run the model once so its structure can be inspected.
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        iaranges = self._create_iaranges()

        samples = {}
        for site, unconstrained_value in self._unpack_latent(latent):
            site_name = site["name"]
            site_fn = site["fn"]
            transform = biject_to(site_fn.support)
            value = transform(unconstrained_value)
            # Jacobian correction so the Delta site accounts for the transform.
            log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
            extra_dims = log_density.dim() - value.dim() + site_fn.event_dim
            log_density = sum_rightmost(log_density, extra_dims)
            delta_dist = dist.Delta(value, log_density=log_density, event_dim=site_fn.event_dim)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[site_name]:
                    stack.enter_context(iaranges[frame.name])
                samples[site_name] = pyro.sample(site_name, delta_dist)

        return samples
コード例 #9
0
ファイル: guides.py プロジェクト: jamestwebber/pyro
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        Draws one joint latent sample and splits it per site. Each piece is
        mapped into its site's support and replayed as a ``Delta`` sample site
        inside that site's plates.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # First call: run the model once so its structure can be inspected.
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        plates = self._create_plates()

        samples = {}
        for site, unconstrained_value in self._unpack_latent(latent):
            site_name = site["name"]
            site_fn = site["fn"]
            transform = biject_to(site_fn.support)
            value = transform(unconstrained_value)
            # Jacobian correction so the Delta site accounts for the transform.
            log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
            extra_dims = log_density.dim() - value.dim() + site_fn.event_dim
            log_density = sum_rightmost(log_density, extra_dims)
            delta_dist = dist.Delta(value, log_density=log_density, event_dim=site_fn.event_dim)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[site_name]:
                    stack.enter_context(plates[frame.name])
                samples[site_name] = pyro.sample(site_name, delta_dist)

        return samples
コード例 #10
0
ファイル: pyroutils.py プロジェクト: tomMoral/sbi
def get_transforms(model: Callable, *model_args: Any, **model_kwargs: Any):
    """Get automatic transforms to unbounded space

    The model is traced once with the given arguments. For every site
    yielded by the trace's ``iter_stochastic_nodes()``, the inverse of
    ``biject_to(site_distribution.support)`` is returned: a map from the
    site's constrained support to unbounded space.

    Args:
        model: Pyro model
        model_args: Arguments passed to model
        model_kwargs: Keyword arguments passed to model

    Returns:
        Dict mapping sample site name to the transform from that site's
        support to unbounded space.

    Example:
        ```python
        def prior():
            return pyro.sample("theta", pyro.distributions.Uniform(0., 1.))

        transform_to_unbounded = get_transforms(prior)["theta"]
        ```
    """
    model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

    return {
        name: biject_to(node["fn"].support).inv
        for name, node in model_trace.iter_stochastic_nodes()
    }
コード例 #11
0
ファイル: test_inference.py プロジェクト: pyro-ppl/pyro
def test_auto_dirichlet(auto_class, Elbo):
    """Fit ``auto_class`` to a Dirichlet-Categorical model with SVI and check
    that its location, mapped onto the simplex, is near the exact posterior
    mean."""
    num_steps = 2000
    prior = torch.tensor([0.5, 1.0, 1.5, 3.0])
    data = torch.tensor([0] * 4 + [1] * 2 + [2] * 5).long()
    # Conjugate update: prior concentration plus per-category counts of data.
    posterior = torch.tensor([4.5, 3.0, 6.5, 3.0])

    def model(data):
        p = pyro.sample("p", dist.Dirichlet(prior))
        with pyro.plate("data_plate"):
            pyro.sample("data", dist.Categorical(p).expand_by(data.shape), obs=data)

    guide = auto_class(model)
    optimizer = optim.Adam({"lr": 0.003})
    svi = SVI(model, guide, optimizer, loss=Elbo())

    for _ in range(num_steps):
        loss = svi.step(data)
        assert np.isfinite(loss), loss

    expected_mean = posterior / posterior.sum()
    is_flow_guide = isinstance(guide, (AutoIAFNormal, AutoNormalizingFlow))
    # Flow guides: push the zero point of the latent space through the flow.
    loc = guide.transform(torch.zeros(guide.latent_dim)) if is_flow_guide else guide.loc
    actual_mean = biject_to(constraints.simplex)(loc)
    message = "\nexpected {}\n  actual {}".format(
        expected_mean.detach().cpu().numpy(),
        actual_mean.detach().cpu().numpy(),
    )
    assert_equal(actual_mean, expected_mean, prec=0.2, msg=message)
コード例 #12
0
ファイル: guides.py プロジェクト: ucals/pyro
    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns posterior quantiles for each latent variable. Example::

            print(guide.quantiles([0.05, 0.5, 0.95]))

        Each quantile is computed from the site's ``Normal`` guide in
        unconstrained space and then mapped to the site's support.

        :param quantiles: A list of requested quantiles between 0 and 1.
        :type quantiles: torch.Tensor or list
        :return: A dict mapping sample site name to a tensor of quantile
            values. The requested quantiles are stacked along a new leading
            dimension, so every quantile covers every element of the site.
        :rtype: dict
        """
        results = {}

        for name, site in self.prototype_trace.iter_stochastic_nodes():
            site_loc, site_scale = self._get_loc_and_scale(name)

            # as_tensor also accepts an existing Tensor (no copy-construct warning).
            site_quantiles = torch.as_tensor(quantiles,
                                             dtype=site_loc.dtype,
                                             device=site_loc.device)
            # Put the quantiles on a new leading dim so they broadcast across
            # all elements of a non-scalar site instead of elementwise.
            site_quantiles = site_quantiles.reshape((-1,) + (1,) * site_loc.dim())
            site_quantiles_values = dist.Normal(
                site_loc, site_scale).icdf(site_quantiles)
            constrained_site_quantiles = biject_to(
                site["fn"].support)(site_quantiles_values)
            results[name] = constrained_site_quantiles

        return results
コード例 #13
0
def _get_sample_fn(module, name):
    """Return the distribution from which parameter ``name`` of ``module`` is sampled.

    * In ``"model"`` mode this is the registered prior.
    * Otherwise it is derived from ``module._guides[name]``:

      - a ``Delta`` at ``<name>_map`` for ``Delta`` guides;
      - the guide itself, with ``.to_event()``, when the prior's support is
        real;
      - otherwise, an auxiliary ``<name>_latent`` site sampled in
        unconstrained space. The returned ``Delta`` sits at that sample's
        constrained image, with the summed log-abs-det-Jacobian as its log
        density.
    """
    if module.mode == "model":
        return module._priors[name]

    dist_constructor, arg_names = module._guides[name]

    if dist_constructor is dist.Delta:
        map_value = getattr(module, "{}_map".format(name))
        return dist.Delta(map_value, event_dim=map_value.dim())

    # create guide
    guide_kwargs = {
        arg: getattr(module, "{}_{}".format(name, arg)) for arg in arg_names
    }
    guide = dist_constructor(**guide_kwargs)

    # no need to do transforms when support is real (for mean field ELBO)
    support = module._priors[name].support
    if _is_real_support(support):
        return guide.to_event()

    # Infer in unconstrained space, then map the value back to the support.
    # TODO: move this logic to infer.autoguide or somewhere else
    latent_site = module._pyro_get_fullname("{}_latent".format(name))
    unconstrained_value = pyro.sample(latent_site, guide.to_event(),
                                      infer={"is_auxiliary": True})
    transform = biject_to(support)
    value = transform(unconstrained_value)
    log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
    return dist.Delta(value, log_density.sum(), event_dim=value.dim())
コード例 #14
0
ファイル: easyguide.py プロジェクト: www3cam/pyro
    def sample(self, guide_name, fn, infer=None):
        """
        Wrapper around ``pyro.sample()`` to create a single auxiliary sample
        site and then unpack to multiple sample sites for model replay.

        :param str guide_name: The name of the auxiliary guide site.
        :param callable fn: A distribution with shape ``self.event_shape``.
        :param dict infer: Optional inference configuration dict. A copy is
            used (with ``"is_auxiliary"`` set to ``True``), so the caller's
            dict is not modified.
        :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the
            single concatenated blob and ``model_zs`` is a dict mapping
            site name to constrained model sample.
        :rtype: tuple
        :raises ValueError: if ``fn.event_shape`` differs from
            ``self.event_shape``.
        """
        # Sample a packed tensor.
        if fn.event_shape != self.event_shape:
            raise ValueError(
                "Invalid fn.event_shape for group: expected {}, actual {}".
                format(tuple(self.event_shape), tuple(fn.event_shape)))
        # Copy rather than mutate: the caller may share or reuse this dict.
        infer = {} if infer is None else dict(infer)
        infer["is_auxiliary"] = True
        guide_z = pyro.sample(guide_name, fn, infer=infer)
        common_batch_shape = guide_z.shape[:-1]

        model_zs = {}
        pos = 0
        for site in self.prototype_sites:
            name = site["name"]
            site_fn = site["fn"]

            # Extract this site's slice from the packed sample and reshape it.
            size = self._site_sizes[name]
            batch_shape = broadcast_shape(common_batch_shape,
                                          self._site_batch_shapes[name])
            unconstrained_z = guide_z[..., pos:pos + size]
            unconstrained_z = unconstrained_z.reshape(batch_shape +
                                                      site_fn.event_shape)
            pos += size

            # Transform to constrained space; the Jacobian term becomes the
            # Delta site's log density.
            transform = biject_to(site_fn.support)
            z = transform(unconstrained_z)
            log_density = transform.inv.log_abs_det_jacobian(
                z, unconstrained_z)
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - z.dim() + site_fn.event_dim)
            delta_dist = dist.Delta(z,
                                    log_density=log_density,
                                    event_dim=site_fn.event_dim)

            # Replay model sample statement, entering only plates that are
            # not already on the Pyro stack.
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    plate = self.guide.plate(frame.name)
                    if plate not in runtime._PYRO_STACK:
                        stack.enter_context(plate)
                model_zs[name] = pyro.sample(name, delta_dist)

        return guide_z, model_zs
コード例 #15
0
ファイル: guides.py プロジェクト: cweniger/pyrofit-core
    def _get_group(self, match='.*'):
        """Return group and unconstrained initial values.

        The group comes from ``self.group(match=match)``. ``match`` is
        presumably a regex over site names, given the ``'.*'`` default.
        ``z_init`` concatenates every site's initial value (from
        ``self.init``), mapped to unconstrained space and flattened.

        Also sets ``group.event_mask``: a flat boolean tensor with one entry
        per element of ``batch_shape + event_shape`` of each site. An entry is
        taken from the site's mask for ``MaskedDistribution`` sites, and is
        ``True`` otherwise.

        NOTE(review): ``event_mask`` is sized from the constrained shape while
        ``z_init`` uses the unconstrained one; these differ for supports such
        as the simplex -- confirm callers never rely on them lining up there.
        """
        group = self.group(match=match)
        z = []
        for site in group.prototype_sites:
            constrained_z = self.init(site)
            transform = biject_to(site['fn'].support)
            z.append(transform.inv(constrained_z).reshape(-1))
        z_init = torch.cat(z, 0)

        event_mask = []
        for site in group.prototype_sites:
            site_fn = site['fn']
            site_shape = site_fn.batch_shape + site_fn.event_shape
            if isinstance(
                    site_fn,
                    pyro.distributions.torch_distribution.MaskedDistribution):
                site_mask = site_fn._mask
                if isinstance(site_mask, bool):
                    # MaskedDistribution may store a plain Python bool.
                    site_mask = torch.tensor(site_mask, device=z_init.device)
                # The mask covers batch dims only: append singleton event dims
                # so expand() broadcasts it across each event.
                site_mask = site_mask.reshape(
                    site_mask.shape + (1,) * len(site_fn.event_shape))
                mask = site_mask.expand(site_shape).flatten()
            else:
                mask = torch.full([site_shape.numel()],
                                  True,
                                  dtype=torch.bool,
                                  device=z_init.device)
            event_mask.append(mask)
        group.event_mask = torch.cat(event_mask)

        return group, z_init
コード例 #16
0
ファイル: gaussian.py プロジェクト: pyro-ppl/pyro
    def _transform_values(
        self,
        aux_values: Dict[str, torch.Tensor],
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Union[float, torch.Tensor]]]:
        """Map auxiliary values to constrained, user-facing values.

        For each non-observed site in ``self._factors``, the auxiliary value
        is scaled and shifted by the learned ``scale``/``loc``, then pushed
        through ``biject_to(site["fn"].support)``.

        :param aux_values: Auxiliary (standardized) values keyed by site name.
        :returns: A pair ``(values, log_densities)``:

            * ``values`` maps site name to constrained value.
            * ``log_densities`` is a ``defaultdict(float)`` mapping site name
              to the log-abs-det-Jacobian of the constraining transform minus
              the summed ``log(scale)`` of the affine step. It is only
              populated when ``am_i_wrapped()`` is true and
              ``poutine.get_mask()`` is not ``False``; missing keys read as
              ``0.0``.
        """
        # Learnably transform auxiliary values to user-facing values.
        values = {}
        log_densities = defaultdict(float)
        compute_density = am_i_wrapped() and poutine.get_mask() is not False
        for name, site in self._factors.items():
            if site["is_observed"]:
                continue
            loc = deep_getattr(self.locs, name)
            scale = deep_getattr(self.scales, name)
            unconstrained = aux_values[name] * scale + loc

            # Transform to constrained space.
            transform = biject_to(site["fn"].support)
            values[name] = transform(unconstrained)
            if compute_density:
                assert transform.codomain.event_dim == site["fn"].event_dim
                # Change of variables for the constraining transform and for
                # the affine map (its Jacobian is the per-event product of scale).
                log_densities[name] = transform.inv.log_abs_det_jacobian(
                    values[name], unconstrained
                ) - scale.log().reshape(site["fn"].batch_shape + (-1,)).sum(-1)

        return values, log_densities
コード例 #17
0
ファイル: hmc.py プロジェクト: zippeurfou/pyro
    def setup(self, *args, **kwargs):
        """Prepare the kernel to sample from ``self.model(*args, **kwargs)``.

        Stores the model arguments and records one trace as the prototype for
        converting between trace objects and the integrator's plain dicts.
        When automatic transforms are enabled, it rebuilds ``self.transforms``
        (inverse ``biject_to`` maps for non-real supports). For each latent
        site it stores a standard ``Normal`` of matching shape in
        ``self._r_dist``. With ``self.adapt_step_size``, it also chooses an
        initial step size and creates the Dual Averaging scheme.
        """
        self._args = args
        self._kwargs = kwargs
        trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        self._prototype_trace = trace
        if self._automatic_transform_enabled:
            self.transforms = {}

        ordered_sites = sorted(trace.iter_stochastic_nodes(),
                               key=lambda entry: entry[0])
        for site_name, node in ordered_sites:
            fn = node["fn"]
            value = node["value"]
            if fn.support is not constraints.real and self._automatic_transform_enabled:
                to_unconstrained = biject_to(fn.support).inv
                self.transforms[site_name] = to_unconstrained
                value = to_unconstrained(node["value"])
            zeros = value.new_zeros(value.shape)
            ones = value.new_ones(value.shape)
            self._r_dist[site_name] = dist.Normal(loc=zeros, scale=ones)
        self._validate_trace(trace)

        if self.adapt_step_size:
            self._adapt_phase = True
            z = dict((site_name, node["value"])
                     for site_name, node in trace.iter_stochastic_nodes())
            for site_name, to_unconstrained in self.transforms.items():
                z[site_name] = to_unconstrained(z[site_name])
            self.step_size = self._find_reasonable_step_size(z)
            steps = int(self.trajectory_length / self.step_size)
            self.num_steps = max(1, steps)
            # Dual Averaging is centered on log(10 * initial step size).
            prox_center = math.log(10 * self.step_size)
            self._adapted_scheme = DualAveraging(prox_center=prox_center)
コード例 #18
0
ファイル: effect.py プロジェクト: pyro-ppl/pyro
    def _get_params(self, name: str, prior: Distribution):
        """Return the ``(loc, scale)`` guide parameters for site ``name``.

        On first use they are created under ``self.locs.<name>`` and
        ``self.scales.<name>``. They are initialized in unconstrained space as
        an affine mix of the prior and the ``init_loc_fn`` value, and are then
        looked up again.
        """
        try:
            return deep_getattr(self.locs, name), deep_getattr(self.scales, name)
        except AttributeError:
            pass  # not created yet; initialize below

        with torch.no_grad():
            transform = biject_to(prior.support)
            event_dim = transform.domain.event_dim
            site = {"name": name, "fn": prior}
            unconstrained = transform.inv(self.init_loc_fn(site).detach())
            # Initialize the distribution to be an affine combination:
            #   init_scale * prior + (1 - init_scale) * init_loc
            init_loc = self._adjust_plates(unconstrained, event_dim) * (1 - self._init_scale)
            init_scale = torch.full_like(init_loc, self._init_scale)

        deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
        scale_param = PyroParam(init_scale,
                                constraint=constraints.positive,
                                event_dim=event_dim)
        deep_setattr(self, "scales." + name, scale_param)
        return self._get_params(name, prior)
コード例 #19
0
    def guide(self):
        """Approximate posterior for the horseshoe prior.

        The posterior is assumed to factor into two parts:

        * a multivariate normal distribution over the global mean ``mu`` and
          the unconstrained global scale ``tau`` (``2 * npar`` dimensions
          jointly; ``tau`` is mapped to positive values);
        * an independent multivariate normal distribution over the parameters
          of each subject.
        """
        nsub = self.runs  # number of subjects
        npar = self.npar  # number of parameters
        # Maps unconstrained reals to the positive values used for ``tau``.
        trns = biject_to(constraints.positive)

        m_hyp = param('m_hyp', zeros(2 * npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2 * npar),
                       constraint=constraints.lower_cholesky)
        # Auxiliary joint sample, split below into mu and unconstrained tau.
        hyp = sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[..., :npar]
        unc_tau = hyp[..., npar:]

        c_tau = trns(unc_tau)

        # Change-of-variables correction for the Delta site of tau.
        ld_tau = trns.inv.log_abs_det_jacobian(c_tau, unc_tau)
        ld_tau = sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)

        # mu is used untransformed, so its Delta needs no Jacobian term.
        sample("mu", dist.Delta(unc_mu, event_dim=1))
        sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))

        m_locs = param('m_locs', zeros(nsub, npar))
        st_locs = param('scale_tril_locs',
                        torch.eye(npar).repeat(nsub, 1, 1),
                        constraint=constraints.lower_cholesky)

        with plate('runs', nsub):
            sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
コード例 #20
0
ファイル: guides.py プロジェクト: yufengwa/pyro
    def _setup_prototype(self, *args, **kwargs):
        """Create per-site ``locs``/``scales`` guide parameters from the prototype trace.

        For each latent site this initializes:

        * ``loc`` to the site's value mapped to unconstrained space, tiled to
          the full plate size when the prototype was subsampled;
        * ``scale`` to ``self._init_scale``.

        It also records each site's unconstrained event_dim and its
        ``cond_indep_stack``.
        """
        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.locs = PyroModule()
        self.scales = PyroModule()

        for name, site in self.prototype_trace.iter_stochastic_nodes():
            with helpful_support_errors(site):
                to_unconstrained = biject_to(site["fn"].support).inv
                unconstrained = to_unconstrained(site["value"].detach()).detach()
            # The unconstrained event_dim can differ from the constrained one.
            event_dim = site["fn"].event_dim + unconstrained.dim() - site["value"].dim()
            self._event_dims[name] = event_dim
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            # Tile subsampled init values up to the full plate size.
            for frame in site["cond_indep_stack"]:
                full_size = getattr(frame, "full_size", frame.size)
                if full_size == frame.size:
                    continue
                unconstrained = periodic_repeat(
                    unconstrained, full_size, frame.dim - event_dim).contiguous()
            init_scale = torch.full_like(unconstrained, self._init_scale)

            loc_param = PyroParam(unconstrained, constraints.real, event_dim)
            scale_param = PyroParam(init_scale, self.scale_constraint, event_dim)
            _deep_setattr(self.locs, name, loc_param)
            _deep_setattr(self.scales, name, scale_param)
コード例 #21
0
ファイル: test_inference.py プロジェクト: zyxue/pyro
def test_auto_dirichlet(auto_class, Elbo):
    """
    Check that ``auto_class`` recovers the posterior mean of a Dirichlet prior
    updated by Categorical observations.

    The observed per-category counts are ``[4, 2, 5, 0]``, so the expected
    posterior concentration is ``prior + counts = [4.5, 3.0, 6.5, 3.0]``.
    After ``num_steps`` SVI steps, the guide's ``auto_loc`` parameter mapped
    onto the simplex must match that posterior's mean within ``prec=0.2``.
    """
    num_steps = 2000
    prior = torch.tensor([0.5, 1.0, 1.5, 3.0])
    data = torch.tensor([0] * 4 + [1] * 2 + [2] * 5).long()
    posterior = torch.tensor([4.5, 3.0, 6.5, 3.0])

    def model(data):
        probs = pyro.sample("p", dist.Dirichlet(prior))
        likelihood = dist.Categorical(probs).expand_by(data.shape)
        with pyro.plate("data_plate"):
            pyro.sample("data", likelihood, obs=data)

    guide = auto_class(model)
    optimizer = optim.Adam({"lr": .003})
    svi = SVI(model, guide, optimizer, loss=Elbo())

    for _step in range(num_steps):
        loss = svi.step(data)
        assert np.isfinite(loss), loss

    expected_mean = posterior / posterior.sum()
    actual_mean = biject_to(constraints.simplex)(pyro.param("auto_loc"))
    failure_msg = ('\nexpected {}'.format(expected_mean.detach().cpu().numpy())
                   + '\n  actual {}'.format(actual_mean.detach().cpu().numpy()))
    assert_equal(actual_mean, expected_mean, prec=0.2, msg=failure_msg)
コード例 #22
0
    def bijection(self) -> Transform:
        """
        Return the bijection mapping unconstrained space onto the prior's support.

        :raises ValueError: if ``self.trainable`` is falsy.
        """
        if self.trainable:
            return biject_to(self._prior.support)
        raise ValueError('Is not of `Distribution` instance!')
コード例 #23
0
ファイル: structured.py プロジェクト: pyro-ppl/pyro
 def median(self, *args, **kwargs):
     """
     Return each latent site's median, computed as its stored location
     parameter mapped into the site's support with ``biject_to``.

     ``*args`` and ``**kwargs`` are accepted for interface compatibility and
     are not used.

     :return: A dict mapping site name to the constrained median tensor.
     :rtype: dict
     """
     medians = {}
     for site_name, site in self._sorted_sites:
         flat_loc = deep_getattr(self.locs, site_name).detach()
         # Locations are stored flat; restore batch + unconstrained event shape.
         target_shape = (self._batch_shapes[site_name]
                         + self._unconstrained_event_shapes[site_name])
         to_support = biject_to(site["fn"].support)
         medians[site_name] = to_support(flat_loc.reshape(target_shape))
     return medians
コード例 #24
0
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        For each stochastic site of the prototype trace, an auxiliary Normal is
        sampled in unconstrained space. Its loc/scale come from
        ``self._get_loc_and_scale`` applied to the output of ``self.encode``.
        The sample is mapped into the site's support and recorded under the
        site's own name as a ``Delta`` that carries the Jacobian term as its
        ``log_density``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        # Encode the model inputs once; the result is reused for every site below.
        encoded_hidden = self.encode(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Maps unconstrained reals into this site's support.
            transform = biject_to(site["fn"].support)

            with ExitStack() as stack:
                # Re-enter every vectorized plate the site was under in the prototype trace.
                for frame in site["cond_indep_stack"]:
                    if frame.vectorized:
                        stack.enter_context(plates[frame.name])

                site_loc, site_scale = self._get_loc_and_scale(name, encoded_hidden)
                # Sample in unconstrained space under an auxiliary "<name>_unconstrained" site.
                unconstrained_latent = pyro.sample(
                    name + "_unconstrained",
                    dist.Normal(
                        site_loc,
                        site_scale,
                    ).to_event(self._event_dims[name]),
                    infer={"is_auxiliary": True},
                )

                value = transform(unconstrained_latent)
                # Skip the Jacobian computation when poutine's mask is explicitly False.
                if pyro.poutine.get_mask() is False:
                    log_density = 0.0
                else:
                    # log|det J| of the inverse map (constrained -> unconstrained) ...
                    log_density = transform.inv.log_abs_det_jacobian(
                        value,
                        unconstrained_latent,
                    )
                    # ... summed over trailing dims so only value.dim() - event_dim dims remain.
                    log_density = sum_rightmost(
                        log_density,
                        log_density.dim() - value.dim() + site["fn"].event_dim,
                    )
                delta_dist = dist.Delta(
                    value,
                    log_density=log_density,
                    event_dim=site["fn"].event_dim,
                )

                result[name] = pyro.sample(name, delta_dist)

        return result
コード例 #25
0
ファイル: effect.py プロジェクト: pyro-ppl/pyro
 def _get_posterior_median(self, name, prior):
     """
     Return the posterior median of site ``name``, mapped into ``prior``'s support.

     For hierarchical sites (every site when ``self._hierarchical_sites`` is
     ``None``), ``self._get_params`` returns ``(loc, scale, weight)``. The
     location is then shifted by ``weight`` times the prior mean mapped into
     unconstrained space. For other sites it returns ``(loc, scale)``. The
     scale is not used either way.
     """
     to_support = biject_to(prior.support)
     is_hierarchical = (self._hierarchical_sites is None
                        or name in self._hierarchical_sites)
     if not is_hierarchical:
         loc, _ = self._get_params(name, prior)
         return to_support(loc)
     loc, _, weight = self._get_params(name, prior)
     return to_support(loc + to_support.inv(prior.mean) * weight)
コード例 #26
0
    def _get_mutual_information(self, name, prior):
        """Approximate the mutual information between data x and latent variable z

            I(x, z) = E_x E_{q(z|x)} log q(z|x) - E_x E_{q(z|x)} log q(z)

        ``q(z|x)`` is the diagonal Normal given by this site's ``loc`` / ``scale``.
        Both are unpacked as 2-D ``(x_batch, nz)``. ``q(z)`` is the aggregate
        posterior: the average of ``q(z|x)`` over the ``x_batch`` rows. It is
        evaluated at the samples stored in ``self.samples_for_mi[name]``.

        :param str name: Name of the sample site.
        :param prior: Prior distribution of the site. Its support selects the
            bijection used to shift ``loc`` for hierarchical sites.
        :return: The mutual-information estimate, or ``0.0`` when ``name`` is
            not an amortised site.
        :rtype: float
        """
        import math

        # Posterior location and scale (unconstrained space).
        transform = biject_to(prior.support)
        if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
            loc, scale, weight = self._get_params(name, prior)
            loc = loc + transform.inv(prior.mean) * weight
        else:
            loc, scale = self._get_params(name, prior)

        # NOTE(review): _get_params is intentionally still called before this
        # check, in case it has side effects (e.g. lazily creating parameters).
        if name not in self.amortised_plate_sites["sites"].keys():
            # if amortisation is not used for a particular site return MI=0
            return 0.0

        z_samples = self.samples_for_mi[name]
        x_batch, nz = loc.size()
        log_two_pi = math.log(2.0 * math.pi)
        # log(sigma^2) for each (datapoint, latent dim): [x_batch, nz]
        log_var = (scale ** 2).log()

        # E_{q(z|x)} log q(z|x) = -0.5 * nz * log(2*pi) - 0.5 * sum_d (log(sigma_d^2) + 1),
        # averaged over the x batch.
        neg_entropy = (-0.5 * nz * log_two_pi - 0.5 * (log_var + 1.0).sum(-1)).mean()

        # Broadcast against z_samples: [1, x_batch, nz]
        loc = loc.unsqueeze(0)
        var = (scale ** 2).unsqueeze(0)
        log_var = log_var.unsqueeze(0)

        # log q(z_i | x_j) for every sample / datapoint pair: [z_batch, x_batch]
        dev = z_samples - loc
        log_density = (-0.5 * ((dev ** 2) / var).sum(dim=-1)
                       - 0.5 * (nz * log_two_pi + log_var.sum(-1)))

        # log q(z_i) under the aggregate posterior, i.e. log of the mean over j
        # of q(z_i | x_j): [z_batch]
        log_qz = torch.logsumexp(log_density, dim=1) - math.log(x_batch)

        return (neg_entropy - log_qz.mean(-1)).item()
コード例 #27
0
 def get_named_particles(self):
     """
     Return ``{site_name: tensor}`` holding every site's particle values.

     The packed ``"svgd_particles"`` parameter is split per site by
     ``self.guide._unpack_latent``, and each piece is mapped from unconstrained
     space into the site's support. The leading dimension of each tensor
     indexes particles, i.e. this is a struct of arrays.
     """
     packed_particles = pyro.param("svgd_particles")
     named = {}
     for site, unconstrained_value in self.guide._unpack_latent(packed_particles):
         site_name = site["name"]
         named[site_name] = biject_to(site["fn"].support)(unconstrained_value)
     return named
コード例 #28
0
ファイル: guides.py プロジェクト: jamestwebber/pyro
    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        The guide's location vector (first item returned by
        ``self._loc_scale``) is split per site with ``self._unpack_latent``.
        Each piece is then mapped into the site's support via ``biject_to``.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        loc, _unused_scale = self._loc_scale(*args, **kwargs)
        medians = {}
        for site, unconstrained_value in self._unpack_latent(loc):
            site_name = site["name"]
            to_support = biject_to(site["fn"].support)
            medians[site_name] = to_support(unconstrained_value)
        return medians
コード例 #29
0
ファイル: __init__.py プロジェクト: lewisKit/pyro
    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        Only the location half of ``self._loc_scale(...)`` is used. It is
        unpacked per site and mapped into each site's support.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        loc, _scale = self._loc_scale(*args, **kwargs)
        result = {}
        for site, raw_value in self._unpack_latent(loc):
            key = site["name"]
            result[key] = biject_to(site["fn"].support)(raw_value)
        return result
コード例 #30
0
    def __init__(self, base_dist, **parameters):
        """
        Store ``base_dist`` and register its parameters as buffers.

        :param base_dist: Distribution constructor stored on the module.
            Presumably this module's ``forward`` (not visible here) builds it
            from the registered buffers. TODO confirm.
        :param parameters: Named distribution parameters. Tensors are registered
            as buffers as-is. Other values are converted with ``torch.tensor``.

        After construction, ``self.bijection`` maps unconstrained space onto the
        distribution's support, and ``self.shape`` is its event shape.
        """
        super().__init__()

        self.base_dist = base_dist

        for k, v in parameters.items():
            self.register_buffer(
                k, v if isinstance(v, torch.Tensor) else torch.tensor(v))

        # Build the distribution once and read both properties from it,
        # instead of calling self() twice.
        prototype = self()
        self.bijection = biject_to(prototype.support)
        self.shape = prototype.event_shape
コード例 #31
0
ファイル: neutra.py プロジェクト: pyro-ppl/pyro
    def apply(self, msg):
        """
        Reparameterize one sample site using the guide's shared latent.

        ``msg`` must provide ``"name"``, ``"fn"``, ``"value"`` and
        ``"is_observed"``. Sites absent from ``self.guide.prototype_trace`` are
        returned unchanged. For other sites:

        * the site's unconstrained value is taken from a shared latent, sampled
          from ``self.guide.get_base_dist()`` and pushed through
          ``self.guide.get_transform()``;
        * that value is mapped into ``fn.support``;
        * the site is replaced by an observed ``Delta`` at the resulting value.

        :raises NotImplementedError: if the site is observed.
        :raises ValueError: if ``self.guide.get_transform()`` raises
            ``NotImplementedError`` or ``TypeError``.
        :return: dict with keys ``"fn"``, ``"value"`` and ``"is_observed"``.
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]
        # Sites the guide does not know about are passed through unchanged.
        if name not in self.guide.prototype_trace.nodes:
            return {"fn": fn, "value": value, "is_observed": is_observed}
        if is_observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(name)},...), "
                "NeuTraReparam does not support observe statements.")

        log_density = 0.0
        # Skip density terms when poutine's mask is explicitly False.
        compute_density = poutine.get_mask() is not False
        # Per-site values are popped from self.x_unconstrained below, so a name that
        # is missing means a fresh shared latent must be drawn. (Assumes
        # self.x_unconstrained is initialised elsewhere, e.g. to {}; TODO confirm.)
        if name not in self.x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            try:
                self.transform = self.guide.get_transform()
            except (NotImplementedError, TypeError) as e:
                raise ValueError(
                    "NeuTraReparam only supports guides that implement "
                    "`get_transform` method that does not depend on the "
                    "model's `*args, **kwargs`") from e

            # Enter block_plate(dim=..., strict=False) for each of the guide's
            # plates; presumably this keeps the shared latent from being batched
            # by them. TODO confirm.
            with ExitStack() as stack:
                for plate in self.guide.plates.values():
                    stack.enter_context(
                        block_plate(dim=plate.dim, strict=False))
                z_unconstrained = pyro.sample(
                    f"{name}_shared_latent",
                    self.guide.get_base_dist().mask(False))

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            # The flow's log|det J| is added only to the density of the site
            # that triggered this shared sample.
            if compute_density:
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained)
            # Split the flat unconstrained vector into per-site values keyed by site name.
            self.x_unconstrained = {
                site["name"]: (site, unconstrained_value)
                for site, unconstrained_value in self.guide._unpack_latent(
                    x_unconstrained)
            }

        # Extract a single site's value from the shared latent.
        site, unconstrained_value = self.x_unconstrained.pop(name)
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            # Sum trailing dims so only value.dim() - fn.event_dim dims remain.
            logdet = sum_rightmost(logdet,
                                   logdet.dim() - value.dim() + fn.event_dim)
            # Model log-prob of the value plus the Jacobian terms.
            log_density = log_density + fn.log_prob(value) + logdet
        new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}
コード例 #32
0
ファイル: guides.py プロジェクト: jamestwebber/pyro
 def _init_loc(self):
     """
     Build the initial flat latent vector from the prototype trace.

     Each stochastic site's detached prototype value is mapped into
     unconstrained space with ``biject_to(support).inv`` and flattened. The
     pieces are concatenated in trace order.

     :return: 1-D tensor of length ``self.latent_dim``.
     """
     flat_parts = [
         biject_to(site["fn"].support).inv(site["value"].detach()).reshape(-1)
         for _site_name, site in self.prototype_trace.iter_stochastic_nodes()
     ]
     latent = torch.cat(flat_parts)
     assert latent.size() == (self.latent_dim,)
     return latent
コード例 #33
0
ファイル: effect.py プロジェクト: pyro-ppl/pyro
 def get_posterior(
         self, name: str,
         prior: Distribution) -> Union[Distribution, torch.Tensor]:
     """
     Build the posterior of site ``name`` as a transformed ``prior``.

     A prior sample is mapped into unconstrained space, passed through the
     affine map ``x -> loc + scale * x`` (``loc, scale`` from
     ``self._get_params``), and mapped back into the prior's support.
     """
     with helpful_support_errors({"name": name, "fn": prior}):
         to_support = biject_to(prior.support)
     loc, scale = self._get_params(name, prior)
     shift_and_scale = dist.transforms.AffineTransform(
         loc, scale, event_dim=to_support.domain.event_dim, cache_size=1)
     transform_chain = [
         to_support.inv.with_cache(),
         shift_and_scale,
         to_support.with_cache(),
     ]
     return dist.TransformedDistribution(prior, transform_chain)
コード例 #34
0
ファイル: __init__.py プロジェクト: lewisKit/pyro
    def _setup_prototype(self, *args, **kwargs):
        """
        Record unconstrained shapes and independence contexts of latent sites.

        After the base setup, each non-observed ``sample`` site of the prototype
        trace contributes two things:

        * the shape of its value mapped into unconstrained space, stored in
          ``self._unconstrained_shapes``;
        * its ``cond_indep_stack``, stored in ``self._cond_indep_stacks``.

        ``self.latent_dim`` is set to the total element count of those shapes.

        :raises RuntimeError: if no latent sample sites are found.
        """
        super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
        self._unconstrained_shapes = {}
        self._cond_indep_stacks = {}
        for name, site in self.prototype_trace.nodes.items():
            is_latent = site["type"] == "sample" and not site["is_observed"]
            if not is_latent:
                continue

            # Unconstrained shapes may differ from constrained ones.
            to_unconstrained = biject_to(site["fn"].support).inv
            self._unconstrained_shapes[name] = to_unconstrained(site["value"]).shape
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

        self.latent_dim = sum(_product(shape) for shape in self._unconstrained_shapes.values())
        if self.latent_dim == 0:
            raise RuntimeError('{} found no latent variables; Use an empty guide instead'.format(type(self).__name__))
コード例 #35
0
ファイル: __init__.py プロジェクト: lewisKit/pyro
    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns posterior quantiles each latent variable. Example::

            print(guide.quantiles([0.05, 0.5, 0.95]))

        Quantiles are computed from the Normal ``(loc, scale)`` returned by
        ``self._loc_scale`` in unconstrained space. They are then unpacked per
        site and mapped into each site's support.

        :param quantiles: A list of requested quantiles between 0 and 1.
        :type quantiles: torch.Tensor or list
        :return: A dict mapping sample site name to a list of quantile values,
            in the order the quantiles were requested.
        :rtype: dict
        """
        loc, scale = self._loc_scale(*args, **kwargs)
        # Shape (Q, 1) so icdf broadcasts one row per requested quantile.
        probs = loc.new_tensor(quantiles).unsqueeze(-1)
        per_quantile = dist.Normal(loc, scale).icdf(probs)
        result = {}
        for row in per_quantile:
            for site, unconstrained_value in self._unpack_latent(row):
                bucket = result.setdefault(site["name"], [])
                bucket.append(biject_to(site["fn"].support)(unconstrained_value))
        return result