Exemplo n.º 1
0
 def unpack_single_latent(latent):
     """Turn one flat latent vector into a dict of site values.

     The latent is first unpacked per site. Without transformed
     distributions the stored inverse transforms are applied directly;
     otherwise the current ``params`` are merged with the unpacked values
     and ``constrain_fn`` runs the model to produce the result.
     """
     samples = self.unpack_latent(latent)
     if not self._has_transformed_dist:
         return transform_fn(self._inv_transforms, samples)
     # Unpacked latent values take precedence over same-named params.
     merged = {**params, **samples}
     return constrain_fn(self.model, model_args, model_kwargs,
                         self._inv_transforms, merged)
Exemplo n.º 2
0
 def unpack_single_latent(latent):
     """Turn one flat latent vector into a dict of site values.

     Without transformed distributions the stored inverse transforms are
     applied directly. Otherwise the model's `param` statements are
     substituted with the current ``params`` and ``constrain_fn`` runs
     that substituted model on the unpacked values.
     """
     samples = self.unpack_latent(latent)
     if not self._has_transformed_dist:
         return transform_fn(self._inv_transforms, samples)
     # Substitute the current values into the model's `param` statements.
     substituted_model = handlers.substitute(self.model, params)
     return constrain_fn(substituted_model, model_args, model_kwargs,
                         self._inv_transforms, samples)
Exemplo n.º 3
0
    def median(self, opt_state):
        """
        Compute the posterior median of every latent variable.

        The location part of the guide (read from the optimizer state) is
        unravelled into per-site values, which are then passed through the
        stored inverse transforms.

        :param opt_state: Current state of the optimizer.
        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        location, _unused_scale = self._loc_scale(opt_state)
        per_site = self._unravel_fn(location)
        return transform_fn(self._inv_transforms, per_site)
Exemplo n.º 4
0
 def _potential_energy(params):
     """Return the potential energy (negative log joint) at ``params``.

     ``params`` holds per-site values in unconstrained space. They are
     mapped through ``inv_transforms``, the model's log density is evaluated
     at the resulting values, and each transform's log-abs-det-Jacobian is
     added (scaled by the site's ``scale`` when one is set).

     :param dict params: unconstrained values keyed by site name.
     :return: ``-log_joint`` including the Jacobian corrections.
     """
     params_constrained = transform_fn(inv_transforms, params)
     log_joint, model_trace = log_density(model, model_args, model_kwargs,
                                          params_constrained)
     for name, t in inv_transforms.items():
         t_log_det = np.sum(
             t.log_abs_det_jacobian(params[name], params_constrained[name]))
         # A missing `scale` entry and an explicit `None` both mean
         # "unscaled"; a bare membership test would multiply by None.
         scale = model_trace[name].get('scale')
         if scale is not None:
             t_log_det = scale * t_log_det
         log_joint = log_joint + t_log_det
     return -log_joint
Exemplo n.º 5
0
 def sample_posterior(self, rng, opt_state, *args, **kwargs):
     """Draw latent samples from the guide and map them to site values.

     :param rng: random key passed to the underlying Normal sampler.
     :param opt_state: current optimizer state, used to read loc/scale.
     :param kwargs: may contain ``sample_shape`` (default ``()``); any
         sequence of ints is accepted and normalised to a tuple.
     :return: dict mapping site names to samples whose leading shape is
         ``sample_shape``.
     """
     # Normalise to a tuple so a list such as [10] can be concatenated with
     # the tuple shapes below (list + tuple raises TypeError).
     sample_shape = tuple(kwargs.pop('sample_shape', ()))
     loc, scale = self._loc_scale(opt_state)
     latent_sample = dist.Normal(loc, scale).sample(rng, sample_shape)
     if not sample_shape:
         unpacked_samples = self._unravel_fn(latent_sample)
     else:
         # `_unravel_fn` is applied to one unbatched latent at a time, so
         # collapse all sample dims into one axis, vmap, then restore them.
         num_samples = int(np.prod(sample_shape))
         latent_sample = np.reshape(
             latent_sample,
             (num_samples,) + np.shape(latent_sample)[len(sample_shape):])
         unpacked_samples = vmap(self._unravel_fn)(latent_sample)
         unpacked_samples = {
             name: np.reshape(val, sample_shape + np.shape(val)[1:])
             for name, val in unpacked_samples.items()
         }
     return transform_fn(self._inv_transforms, unpacked_samples)
Exemplo n.º 6
0
    def single_chain_init(key, only_params=False):
        """Compute initial unconstrained parameters for one chain.

        The model is run under ``seed`` and ``trace`` to record its latent
        ``sample`` sites and ``param`` sites, together with a bijection onto
        each site's support. The recorded values are mapped to unconstrained
        space. They are then used as-is (``init_strategy == 'prior'``) or
        replaced by Uniform(-2, 2) draws of the same shape
        (``init_strategy == 'uniform'``).

        :param key: JAX PRNG key for this chain.
        :param bool only_params: if True, return only the initial params.
        :return: ``init_params`` if ``only_params``; otherwise a tuple of
            (``init_params``, potential energy built from the seeded model,
            ``transform_fn`` partially applied to the inverse transforms).
        :raises ValueError: if ``init_strategy`` is not recognised.
        """
        # Give the model and the uniform initializer independent keys: a JAX
        # key must not be split or reused after it has been consumed.
        key_model, key_init = random.split(key)
        seeded_model = seed(model, key_model)
        model_trace = trace(seeded_model).get_trace(*model_args,
                                                    **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param':
                constrained_values[k] = v['value']
                # Read (rather than pop) the constraint so the trace is left
                # unmodified.
                constraint = v['kwargs'].get('constraint', real)
                inv_transforms[k] = biject_to(constraint)
        prior_params = transform_fn(inv_transforms, constrained_values,
                                    invert=True)
        if init_strategy == 'uniform':
            init_params = {}
            for k, v in prior_params.items():
                # Fresh subkey per site; `key_init` is only ever split.
                key_init, key_site = random.split(key_init)
                init_params[k] = random.uniform(key_site,
                                                shape=np.shape(v),
                                                minval=-2,
                                                maxval=2)
        elif init_strategy == 'prior':
            init_params = prior_params
        else:
            raise ValueError(
                'initialize={} is not a valid initialization strategy.'.format(
                    init_strategy))

        if only_params:
            return init_params
        return (init_params,
                potential_energy(seeded_model, model_args, model_kwargs,
                                 inv_transforms),
                jax.partial(transform_fn, inv_transforms))
Exemplo n.º 7
0
def initialize_model(rng, model, *model_args, init_strategy=init_to_uniform, **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this also returns initial parameters
    produced by ``init_strategy`` to initiate MCMC sampling and a function
    to transform unconstrained values at sample sites to constrained values
    within their respective support.

    :param jax.random.PRNGKey rng: random number generator seed used to
        trace the model and to find initial parameters. A single key
        (``rng.ndim == 1``) initializes one chain; otherwise one chain is
        initialized per key along the leading axis, and the returned
        `init_params` will have the batch shape ``rng.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        `init_params` are the initial values used to initiate MCMC,
        `potential_fn` is a callable that maps unconstrained values to the
        potential energy (negative joint density), and
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    :raises RuntimeError: if the validity flags returned by the initializer
        are concrete device arrays and at least one of them is False.
    """
    seeded_model = seed(model, rng if rng.ndim == 1 else rng[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    # When set, constraining goes through `constrain_fn` (re-running the
    # model) instead of only applying `inv_transforms`.
    has_transformed_dist = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                # Transformed distribution: use the base distribution's
                # support and the first recorded intermediate value
                # (presumably the base sample -- confirm against the tracer).
                constrained_values[k] = v['intermediates'][0][0]
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            # Read (rather than pop) the constraint so the recorded trace is
            # left untouched.
            constraint = v['kwargs'].get('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                # Keep only the first part of the composed bijection and
                # express the value in that part's codomain; the flag below
                # routes constraining through the model.
                base_transform = transform.parts[0]
                constrained_values[k] = base_transform(transform.inv(v['value']))
                inv_transforms[k] = base_transform
                has_transformed_dist = True
            else:
                inv_transforms[k] = transform
                constrained_values[k] = v['value']

    # Unconstrained versions of the traced values, handed to the per-chain
    # initializer as `prototype_params`.
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)

    # NB: we use model instead of seeded_model to prevent unexpected behaviours (if any)
    potential_fn = jax.partial(potential_energy, model, model_args, model_kwargs, inv_transforms)
    if has_transformed_dist:
        # FIXME: why using seeded_model here triggers an error for funnel reparam example
        # if we use MCMC class (mcmc function works fine)
        constrain_fun = jax.partial(constrain_fn, model, model_args, model_kwargs, inv_transforms)
    else:
        constrain_fun = jax.partial(transform_fn, inv_transforms)

    def single_chain_init(key):
        return find_valid_initial_params(key, model, *model_args, init_strategy=init_strategy,
                                         param_as_improper=True, prototype_params=prototype_params,
                                         **model_kwargs)

    if rng.ndim == 1:
        init_params, is_valid = single_chain_init(rng)
    else:
        init_params, is_valid = lax.map(single_chain_init, rng)

    # Only check concrete results; other values (e.g. under tracing) are
    # skipped by this isinstance test.
    if isinstance(is_valid, jax.interpreters.xla.DeviceArray):
        if device_get(~np.all(is_valid)):
            raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
    return init_params, potential_fn, constrain_fun