Exemplo n.º 1
0
def get_model_transforms(rng_key, model, model_args=(), model_kwargs=None):
    """Trace ``model`` once and collect, per latent/param site, the bijection
    from the unconstrained space into the site's support.

    :param rng_key: PRNG key (or a batch of keys; only the first is used)
        to seed the model for tracing.
    :param model: Python callable containing Pyro primitives.
    :param tuple model_args: positional arguments for the model.
    :param dict model_kwargs: keyword arguments for the model.
    :return: tuple ``(inv_transforms, replay_model)``; ``replay_model`` is
        True when the model must be replayed (transformed distributions,
        dynamic constraints, or deterministic sites).
    """
    if model_kwargs is None:
        model_kwargs = {}
    # a batch of keys may be provided; tracing only needs a single key
    single_key = rng_key if rng_key.ndim == 1 else rng_key[0]
    tr = trace(seed(model, single_key)).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of dynamic constraints
    # or deterministic sites
    replay_model = False
    for name, site in tr.items():
        site_type = site['type']
        if site_type == 'sample' and not site['is_observed']:
            if site['intermediates']:
                # transformed distribution: work w.r.t. the base support
                inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                replay_model = True
            else:
                inv_transforms[name] = biject_to(site['fn'].support)
        elif site_type == 'param':
            constraint = site['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                # only the first part maps out of the unconstrained space
                inv_transforms[name] = transform.parts[0]
                replay_model = True
            else:
                inv_transforms[name] = transform
        elif site_type == 'deterministic':
            replay_model = True
    return inv_transforms, replay_model
Exemplo n.º 2
0
    def _setup_prototype(self, *args, **kwargs):
        """Run the base setup, then record per-site inverse transforms and the
        flattened initial latent vector for this auto-guide.

        Side effects: sets ``self._inv_transforms``,
        ``self._has_transformed_dist``, ``self._init_latent``,
        ``self._unpack_latent``, ``self.latent_size`` and, when unset,
        ``self.base_dist``.

        :raises RuntimeError: if the model exposes no latent variables.
        """
        super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
        # draw a PRNG key through the numpyro sample machinery so that setup
        # is reproducible under outer `seed` handlers
        rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix),
                                 dist.PRNGIdentity())
        # `block` hides the sample sites used while searching for valid params
        init_params, _ = handlers.block(find_valid_initial_params)(
            rng_key,
            self.model,
            init_strategy=self.init_strategy,
            model_args=args,
            model_kwargs=kwargs)
        self._inv_transforms = {}
        self._has_transformed_dist = False
        unconstrained_sites = {}
        for name, site in self.prototype_trace.items():
            if site['type'] == 'sample' and not site['is_observed']:
                if site['intermediates']:
                    # transformed distribution: invert w.r.t. the base
                    # support; the model must then be replayed
                    transform = biject_to(site['fn'].base_dist.support)
                    self._inv_transforms[name] = transform
                    unconstrained_sites[name] = transform.inv(
                        site['intermediates'][0][0])
                    self._has_transformed_dist = True
                else:
                    transform = biject_to(site['fn'].support)
                    self._inv_transforms[name] = transform
                    unconstrained_sites[name] = transform.inv(site['value'])

        # flatten initial params into one vector plus an unravel function
        self._init_latent, self._unpack_latent = ravel_pytree(init_params)
        self.latent_size = np.size(self._init_latent)
        if self.base_dist is None:
            # default base: standard Normal over the whole latent vector
            self.base_dist = dist.Independent(
                dist.Normal(np.zeros(self.latent_size), 1.), 1)
        if self.latent_size == 0:
            raise RuntimeError(
                '{} found no latent variables; Use an empty guide instead'.
                format(type(self).__name__))
Exemplo n.º 3
0
    def body_fn(state):
        """One attempt at finding initial parameters with finite potential
        energy and gradients.

        ``state`` is ``(i, key, params, is_valid)``; returns the updated
        tuple so the caller can loop until a valid draw is found.
        """
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        # Use `block` to not record sample primitives in `init_loc_fn`.
        seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    # transformed distribution: invert w.r.t. the base support
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                # treat param sites as latent sites with improper priors
                constraint = v['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    # only the first part maps out of the unconstrained space
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']
        # map the constrained draws back to the unconstrained space
        params = transform_fn(inv_transforms,
                              {k: v for k, v in constrained_values.items()},
                              invert=True)
        potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        # a draw is valid iff energy and every gradient entry are finite
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid
Exemplo n.º 4
0
def test_model_with_transformed_distribution():
    """`constrain_fn` and `potential_energy` under ``TransformReparam`` agree
    with a manual change-of-variables computation for a LogNormal site."""
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    # unconstrained values to be pushed into each prior's support
    params = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    model = handlers.seed(model, random.PRNGKey(0))
    inv_transforms = {
        'x': biject_to(x_prior.support),
        'y': biject_to(y_prior.support)
    }
    expected_samples = partial(transform_fn, inv_transforms)(params)
    # negative log joint plus log-det-Jacobian corrections for both sites
    expected_potential_energy = (-x_prior.log_prob(expected_samples['x']) -
                                 y_prior.log_prob(expected_samples['y']) -
                                 inv_transforms['x'].log_abs_det_jacobian(
                                     params['x'], expected_samples['x']) -
                                 inv_transforms['y'].log_abs_det_jacobian(
                                     params['y'], expected_samples['y']))

    # reparameterize 'y' so its base value 'y_base' is sampled directly
    reparam_model = handlers.reparam(model, {'y': TransformReparam()})
    base_params = {'x': params['x'], 'y_base': params['y']}
    actual_samples = constrain_fn(handlers.seed(reparam_model,
                                                random.PRNGKey(0)), (), {},
                                  base_params,
                                  return_deterministic=True)
    actual_potential_energy = potential_energy(reparam_model, (), {},
                                               base_params)

    assert_allclose(expected_samples['x'], actual_samples['x'])
    assert_allclose(expected_samples['y'], actual_samples['y'])
    assert_allclose(actual_potential_energy, expected_potential_energy)
Exemplo n.º 5
0
def get_potential_fn(rng_key, model, dynamic_args=False, model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a
    function to transform unconstrained values at sample sites to constrained
    values within their respective support.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `constrain_fn`). The latter is used
        to constrain unconstrained samples (e.g. those returned by HMC)
        to values that lie within the site's support.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    # a batch of keys may be provided; tracing only needs a single key
    seeded_model = seed(model, rng_key if rng_key.ndim == 1 else rng_key[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    has_transformed_dist = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                # transformed distribution: invert w.r.t. the base support;
                # the model must then be replayed to constrain samples
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                # only the first part maps out of the unconstrained space
                inv_transforms[k] = transform.parts[0]
                has_transformed_dist = True
            else:
                inv_transforms[k] = transform

    if dynamic_args:
        # factories: calling them with model args yields the actual callables
        def potential_fn(*args, **kwargs):
            return jax.partial(potential_energy, model, inv_transforms, args, kwargs)
        if has_transformed_dist:
            def constrain_fun(*args, **kwargs):
                return jax.partial(constrain_fn, model, inv_transforms, args, kwargs)
        else:
            def constrain_fun(*args, **kwargs):
                return jax.partial(transform_fn, inv_transforms)
    else:
        potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        if has_transformed_dist:
            # replaying the model is required to recover base/derived values
            constrain_fun = jax.partial(constrain_fn, model, inv_transforms, model_args, model_kwargs)
        else:
            constrain_fun = jax.partial(transform_fn, inv_transforms)

    return potential_fn, constrain_fun
Exemplo n.º 6
0
def _init_to_uniform(site, radius=2, skip_param=False):
    """Initialize ``site`` by sampling uniformly from ``(-radius, radius)``
    in the unconstrained space, then mapping back into the site's support.

    :param dict site: a trace site (``sample`` or ``param``).
    :param radius: half-width of the uniform proposal in unconstrained space.
    :param bool skip_param: when True, param sites are not initialized.
    :return: constrained initial value, or ``None`` for unhandled sites.
    """
    if site['type'] == 'sample' and not site['is_observed']:
        fn = site['fn']
        if isinstance(fn, dist.TransformedDistribution):
            # initialize w.r.t. the base distribution's support
            fn = fn.base_dist
        value = numpyro.sample('_init',
                               fn,
                               sample_shape=site['kwargs']['sample_shape'])
        base_transform = biject_to(fn.support)
        # a prior draw only determines the unconstrained sample shape
        shape = np.shape(base_transform.inv(value))
        unconstrained = numpyro.sample('_unconstrained_init',
                                       dist.Uniform(-radius, radius),
                                       sample_shape=shape)
        return base_transform(unconstrained)

    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        shape = np.shape(transform.inv(value))
        unconstrained = numpyro.sample('_unconstrained_init',
                                       dist.Uniform(-radius, radius),
                                       sample_shape=shape)
        if isinstance(transform, ComposeTransform):
            # only the first part maps out of the unconstrained space
            base_transform = transform.parts[0]
        else:
            base_transform = transform
        return base_transform(unconstrained)
Exemplo n.º 7
0
 def substitute_fn(site):
     """Look up a substitution value for ``site`` in ``params``; values for
     sample sites are mapped from the unconstrained space into the site's
     support, param values are returned as-is."""
     name = site["name"]
     if name not in params:
         return None
     if site["type"] != "sample":
         return params[name]
     with helpful_support_errors(site):
         return biject_to(site["fn"].support)(params[name])
Exemplo n.º 8
0
def _unconstrain_reparam(params, site):
    """Map the unconstrained value for ``site`` into its support and emit the
    corresponding log-det-Jacobian as a ``numpyro.factor``.

    :param dict params: mapping of site name to unconstrained value.
    :param dict site: the trace site being reparameterized.
    :return: the constrained value, or ``None`` when ``site`` has no entry
        in ``params``.
    """
    name = site["name"]
    if name not in params:
        return
    unconstrained = params[name]
    support = site["fn"].support
    transform = biject_to(support)
    # in scan, we might only want to substitute an item at index i,
    # rather than the whole sequence
    scan_index = site["infer"].get("_scan_current_index", None)
    if scan_index is not None:
        event_dim_shift = transform.codomain.event_dim - transform.domain.event_dim
        expected_unconstrained_dim = len(site["fn"].shape()) - event_dim_shift
        # check if the value carries an additional time dimension
        if jnp.ndim(unconstrained) > expected_unconstrained_dim:
            unconstrained = unconstrained[scan_index]

    if support in [constraints.real, constraints.real_vector]:
        # identity transform: no Jacobian correction needed
        return unconstrained
    value = transform(unconstrained)

    log_det = transform.log_abs_det_jacobian(unconstrained, value)
    # reduce the Jacobian term over the event dimensions
    log_det = sum_rightmost(
        log_det,
        jnp.ndim(log_det) - jnp.ndim(value) + len(site["fn"].event_shape))
    if site["scale"] is not None:
        log_det = site["scale"] * log_det
    numpyro.factor("_{}_log_det".format(name), log_det)
    return value
Exemplo n.º 9
0
def _get_model_transforms(model, model_args=(), model_kwargs=None):
    """Trace ``model`` and collect support bijections for its continuous
    latent sites, flagging discrete sites and replay requirements.

    :return: tuple ``(inv_transforms, replay_model, has_enumerate_support,
        model_trace)``.
    :raises RuntimeError: for discrete sites without enumerate support.
    """
    if model_kwargs is None:
        model_kwargs = {}
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for name, site in model_trace.items():
        if site["type"] == "deterministic":
            replay_model = True
            continue
        if site["type"] != "sample" or site["is_observed"]:
            continue
        if site["fn"].is_discrete:
            has_enumerate_support = True
            if not site["fn"].has_enumerate_support:
                raise RuntimeError(
                    "MCMC only supports continuous sites or discrete sites "
                    f"with enumerate support, but got {type(site['fn']).__name__}."
                )
            continue
        support = site["fn"].support
        inv_transforms[name] = biject_to(support)
        # XXX: the following code filters out most situations with dynamic supports
        if isinstance(support, constraints._GreaterThan):
            bound_attrs = ("lower_bound", )
        elif isinstance(support, constraints._Interval):
            bound_attrs = ("lower_bound", "upper_bound")
        else:
            bound_attrs = ()
        # a non-static (traced) bound means the model must be replayed
        if any(not isinstance(getattr(support, attr), (int, float))
               for attr in bound_attrs):
            replay_model = True
    return inv_transforms, replay_model, has_enumerate_support, model_trace
Exemplo n.º 10
0
        def log_likelihood(params_flat, subsample_indices=None):
            """Compute per-plate log likelihoods of the observed sites given a
            flattened unconstrained parameter vector.

            :param params_flat: flattened unconstrained parameters.
            :param subsample_indices: dict mapping plate name to the indices
                to use; defaults to the full range of each subsample plate.
            :return: dict mapping plate name to log likelihood, reduced over
                all dimensions except the plate dimension.
            """
            if subsample_indices is None:
                # default to every index of every subsample plate
                subsample_indices = {
                    k: jnp.arange(v[0])
                    for k, v in subsample_plate_sizes.items()
                }
            params = unravel_fn(params_flat)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # push unconstrained values into each site's support
                params = {
                    name: biject_to(prototype_trace[name]["fn"].support)(value)
                    for name, value in params.items()
                }
                # `block` hides this trace from outer handlers while the model
                # is replayed with substituted params and subsample indices
                with block(), trace() as tr, substitute(
                        data=subsample_indices), substitute(data=params):
                    model(*model_args, **model_kwargs)

            log_lik = {}
            for site in tr.values():
                if site["type"] == "sample" and site["is_observed"]:
                    for frame in site["cond_indep_stack"]:
                        # accumulate per plate, keeping only the plate dim
                        if frame.name in log_lik:
                            log_lik[frame.name] += _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim)
                        else:
                            log_lik[frame.name] = _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim)
            return log_lik
def test_elbo_dynamic_support():
    """ELBO of an auto guide pinned at a known latent matches the
    hand-computed guide-minus-model log density for a Uniform prior."""
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            # pin the latent to a fixed unconstrained value
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    # manual computation: guide density at the pinned unconstrained value
    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    # model density with a change-of-variables correction into (0, 5)
    transform = transforms.biject_to(constraints.interval(0, 5))
    x_constrained = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x_constrained)
    model_log_prob = x_prior.log_prob(x_constrained) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Exemplo n.º 12
0
    def body_fn(state):
        """One attempt at drawing valid initial parameters.

        ``state`` is ``(i, key, (params, pe, z_grad), is_valid)``; a draw is
        valid when the potential energy and its flattened gradient are finite.
        """
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                # only continuous, unobserved sample sites get init values
                if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            # map the constrained draws back to the unconstrained space
            params = transform_fn(inv_transforms,
                                  {k: v for k, v in constrained_values.items()},
                                  invert=True)
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    # user-specified initial value takes precedence
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius, maxval=radius)
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
        pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        return i + 1, key, (params, pe, z_grad), is_valid
Exemplo n.º 13
0
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # run model to inspect the model structure
            self._setup_prototype(*args, **kwargs)

        latent = self._sample_latent(*args, **kwargs)

        # unpack continuous latent samples
        result = {}

        for name, unconstrained_value in self._unpack_latent(latent).items():
            site = self.prototype_trace[name]
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained_value)
            event_ndim = site["fn"].event_dim
            if numpyro.get_mask() is False:
                # masked out: skip computing the Jacobian term
                log_density = 0.0
            else:
                # the Delta carries -log|det J| so the guide density matches
                # the pushforward of the latent through `transform`
                log_density = -transform.log_abs_det_jacobian(
                    unconstrained_value, value)
                log_density = sum_rightmost(
                    log_density,
                    jnp.ndim(log_density) - jnp.ndim(value) + event_ndim)
            delta_dist = dist.Delta(value,
                                    log_density=log_density,
                                    event_dim=event_ndim)
            result[name] = numpyro.sample(name, delta_dist)

        return result
Exemplo n.º 14
0
def test_log_prob_LKJCholesky(dimension, concentration):
    """LKJCholesky.log_prob matches the density implied by the C-vine
    construction (Beta partial correlations + signed stick-breaking)."""
    # We will test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution with base distribution is a distribution of partial
    # correlations in C-vine method (modulo an affine transform to change domain from (0, 1)
    # to (1, 0)) and transform is a signed stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = np.sum(d._beta.log_prob(beta_sample))
    # affine map y = 2x - 1, so |dy/dx| = 2 per coordinate
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * np.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    inv_tanh = lambda t: np.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = np.sum(np.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained,
        sample,
    )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=1e-5)

    # log_prob must agree with its jit-compiled version
    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7)
Exemplo n.º 15
0
def test_bijective_transforms(transform, event_shape, batch_shape):
    """Round-trip and Jacobian checks for a bijective transform: codomain
    membership, ``inv(transform(x)) == x``, domain membership, and (when the
    input has no batch dims) ``log_abs_det_jacobian`` against autodiff."""
    shape = batch_shape + event_shape
    rng_key = random.PRNGKey(0)
    # draw a point inside the transform's domain
    x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), np.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            # vector event: compare against slogdet of the Jacobian matrix
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            # scalar event: |dy/dx| via grad
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
Exemplo n.º 16
0
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs,
                                                  **self.static_kwargs)
        # replay so the model sees the guide's sampled values
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs)
        params = {}
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                inv_transforms[site['name']] = transform
                # optimize in the unconstrained space
                params[site['name']] = transform.inv(site['value'])

        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)), params)
        return SVIState(self.optim.init(params), rng_key)
Exemplo n.º 17
0
    def body_fn(state):
        """One attempt at drawing valid initial parameters; validity requires
        a finite potential energy and, when ``validate_grad`` is set, finite
        gradients as well."""
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # XXX: we don't want to apply enum to draw latent samples
            model_ = model
            if enum:
                from numpyro.contrib.funsor import enum as enum_handler

                # unwrap the enum handler (possibly nested under substitute)
                if isinstance(model, substitute) and isinstance(model.fn, enum_handler):
                    model_ = substitute(model.fn.fn, data=model.data)
                elif isinstance(model, enum_handler):
                    model_ = model.fn

            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model_, subkey), substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                # only continuous, unobserved sample sites get init values
                if (
                    v["type"] == "sample"
                    and not v["is_observed"]
                    and not v["fn"].is_discrete
                ):
                    constrained_values[k] = v["value"]
                    inv_transforms[k] = biject_to(v["fn"].support)
            # map the constrained draws back to the unconstrained space
            params = transform_fn(
                inv_transforms,
                {k: v for k, v in constrained_values.items()},
                invert=True,
            )
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    # user-specified initial value takes precedence
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(
                        subkey, jnp.shape(v), minval=-radius, maxval=radius
                    )
                    key, subkey = random.split(key)

        potential_fn = partial(
            potential_energy, model, model_args, model_kwargs, enum=enum
        )
        if validate_grad:
            if forward_mode_differentiation:
                # forward mode: energy and Jacobian computed separately
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid
Exemplo n.º 18
0
def test_initialize_model_change_point(init_strategy):
    """`initialize_model` on a change-point model: batched rng keys give the
    same per-key init params as individual calls, and ``init_to_value`` maps
    the requested constrained 'tau' into the unconstrained space."""
    def model(data):
        alpha = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # switch rate at the (scaled) change point tau
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = jnp.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
        5,  14,  13,  22,
    ])

    # batched call: two rng keys at once
    rng_keys = random.split(random.PRNGKey(1), 2)
    init_params, _, _, _ = initialize_model(rng_keys, model,
                                            init_strategy=init_strategy,
                                            model_args=(count_data,))
    if isinstance(init_strategy, partial) and init_strategy.func is init_to_value:
        # init_to_value stores the unconstrained image of the requested value
        expected = biject_to(constraints.unit_interval).inv(init_strategy.keywords.get('values')['tau'])
        assert_allclose(init_params[0]['tau'], jnp.repeat(expected, 2))
    for i in range(2):
        init_params_i, _, _, _ = initialize_model(rng_keys[i], model,
                                                  init_strategy=init_strategy,
                                                  model_args=(count_data,))
        for name, p in init_params[0].items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(p[i], init_params_i[0][name], atol=1e-6)
Exemplo n.º 19
0
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: tuple containing initial :data:`SVIState`, and `get_params`, a callable
            that transforms unconstrained parameter values from the optimizer to the
            specified constrained domain
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
        # replay so the model sees the guide's sampled values
        model_trace = trace(replay(model_init, guide_trace)).get_trace(*args, **kwargs, **self.static_kwargs)
        params = {}
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                inv_transforms[site['name']] = transform
                # optimize in the unconstrained space
                params[site['name']] = transform.inv(site['value'])

        self.constrain_fn = partial(transform_fn, inv_transforms)
        return SVIState(self.optim.init(params), rng_key)
Exemplo n.º 20
0
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # run model to inspect the model structure
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.items():
            # skip non-sample sites, internal PRNG sites, and observed sites
            if site["type"] != "sample" or isinstance(
                    site["fn"], dist.PRNGIdentity) or site["is_observed"]:
                continue

            event_dim = self._event_dims[name]
            init_loc = self._init_locs[name]
            with ExitStack() as stack:
                # re-enter the plates this site was sampled under
                for frame in site["cond_indep_stack"]:
                    stack.enter_context(plates[frame.name])

                site_loc = numpyro.param("{}_{}_loc".format(name, self.prefix),
                                         init_loc,
                                         event_dim=event_dim)
                site_scale = numpyro.param("{}_{}_scale".format(
                    name, self.prefix),
                                           jnp.full(jnp.shape(init_loc),
                                                    self._init_scale),
                                           constraint=constraints.positive,
                                           event_dim=event_dim)

                site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
                if site["fn"].support in [
                        constraints.real, constraints.real_vector
                ]:
                    # unconstrained site: sample the Normal directly
                    result[name] = numpyro.sample(name, site_fn)
                else:
                    # constrained site: sample in the unconstrained space ...
                    unconstrained_value = numpyro.sample(
                        "{}_unconstrained".format(name),
                        site_fn,
                        infer={"is_auxiliary": True})

                    # ... then push through the support bijection with the
                    # matching -log|det J| correction carried by the Delta
                    transform = biject_to(site['fn'].support)
                    value = transform(unconstrained_value)
                    log_density = -transform.log_abs_det_jacobian(
                        unconstrained_value, value)
                    log_density = sum_rightmost(
                        log_density,
                        jnp.ndim(log_density) - jnp.ndim(value) +
                        site["fn"].event_dim)
                    delta_dist = dist.Delta(value,
                                            log_density=log_density,
                                            event_dim=site["fn"].event_dim)
                    result[name] = numpyro.sample(name, delta_dist)

        return result
Exemplo n.º 21
0
 def init(site, skip_param=False):
     """Perturb the value produced by ``init_strategy`` for ``site`` with
     Gaussian noise of scale ``noise_scale`` in the unconstrained space, then
     map back into the site's support."""
     fn = site['fn']
     if isinstance(fn, dist.TransformedDistribution):
         # perturb w.r.t. the base distribution's support
         fn = fn.base_dist
     vals = init_strategy(site, skip_param=skip_param)
     if vals is None:
         return None
     base_transform = biject_to(fn.support)
     noisy = numpyro.sample(
         '_noisy_init',
         dist.Normal(loc=base_transform.inv(vals), scale=noise_scale))
     return base_transform(noisy)
Exemplo n.º 22
0
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        # logistic regression: `dim` coefficients plus an interval-constrained offset
        coefs = numpyro.sample("coefs",
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    # Rebuild the guide's flow stack by hand from the fitted params so we can
    # compare against the values produced by the guide's exposed methods.
    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            # Permutations are interleaved between successive IAF layers,
            # mirroring AutoIAFNormal's construction.
            flows.append(transforms.PermuteTransform(
                jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1,
            [dim + 1, dim + 1],
            permutation=jnp.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        # Apply the autoregressive network with the parameters learned by SVI.
        arn = partial(arn_apply, params["auto_arn__{}$params".format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    # `offset` comes out of the flow on the unconstrained scale; map the manual
    # sample through the interval bijection before comparing.
    assert_allclose(
        actual_sample["offset"],
        transforms.biject_to(constraints.interval(-1, 1))(
            expected_sample["offset"]),
    )
    check_eq(actual_output, expected_output)
Exemplo n.º 23
0
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        # Seed model and guide with distinct keys so their sample statements
        # do not share randomness; the remaining key is carried in the state.
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        # Trace the guide first, then replay its sampled values into the model
        # so both traces agree on latent values.
        guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs
        )
        params = {}
        inv_transforms = {}
        mutable_state = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site["type"] == "param":
                constraint = site["kwargs"].pop("constraint", constraints.real)
                with helpful_support_errors(site):
                    transform = biject_to(constraint)
                inv_transforms[site["name"]] = transform
                # Optimization happens in unconstrained space; store the
                # inverse-transformed initial value.
                params[site["name"]] = transform.inv(site["value"])
            elif site["type"] == "mutable":
                mutable_state[site["name"]] = site["value"]
            elif (
                site["type"] == "sample"
                and (not site["is_observed"])
                and site["fn"].support.is_discrete
                and not self.loss.can_infer_discrete
            ):
                # Warn (rather than fail) when the chosen loss cannot
                # marginalize discrete latent sites.
                s_name = type(self.loss).__name__
                warnings.warn(
                    f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
                )

        if not mutable_state:
            mutable_state = None
        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params, mutable_state = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)),
            (params, mutable_state),
        )
        return SVIState(self.optim.init(params), mutable_state, rng_key)
Exemplo n.º 24
0
 def find_params(self, rng_keys, *args, **kwargs):
     """Populate ``self._param_map`` and ``self._init_params`` with valid
     initial values for every latent (non-observed) sample site.

     ``_param_map`` maps each site name to a ``(param_name, value, support)``
     triple; ``_init_params`` re-keys those triples by ``param_name``.
     """
     initial_values, _ = handlers.block(find_valid_initial_params)(
         rng_keys, self.model,
         init_strategy=self.init_strategy,
         model_args=args,
         model_kwargs=kwargs)
     site_params = {}
     for name, site in self.prototype_trace.items():
         if site['type'] != 'sample' or site['is_observed']:
             continue
         support = site['fn'].support
         # Map the unconstrained initial value back onto the site's support.
         constrained = biject_to(support)(initial_values[name])
         site_params[name] = ("{}_{}".format(self.prefix, name), constrained, support)
     self._param_map = site_params
     self._init_params = {p: (v, c) for p, v, c in self._param_map.values()}
Exemplo n.º 25
0
def get_model_transforms(rng_key, model, model_args=(), model_kwargs=None):
    """Trace ``model`` once and collect a bijection to unconstrained space for
    every latent sample site and every param site.

    :return: pair ``(inv_transforms, has_transformed_dist)`` where the flag is
        True when a transformed distribution or a composed param constraint
        was encountered.
    """
    if model_kwargs is None:
        model_kwargs = {}
    # ``rng_key`` may be batched; seeding needs only a single key.
    seeded_model = seed(model, rng_key if rng_key.ndim == 1 else rng_key[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    has_transformed_dist = False
    for name, site in model_trace.items():
        if site['type'] == 'sample':
            if site['is_observed']:
                continue
            if site['intermediates']:
                # TransformedDistribution: bijection targets the base support.
                inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                inv_transforms[name] = biject_to(site['fn'].support)
        elif site['type'] == 'param':
            transform = biject_to(site['kwargs'].pop('constraint', real))
            if isinstance(transform, ComposeTransform):
                # Keep only the first (unconstraining) part of the composition.
                inv_transforms[name] = transform.parts[0]
                has_transformed_dist = True
            else:
                inv_transforms[name] = transform
    return inv_transforms, has_transformed_dist
Exemplo n.º 26
0
def _unconstrain_reparam(params, site):
    """Map an unconstrained parameter for ``site`` back onto its support and
    record the change-of-variables correction as a ``factor``.

    Returns the constrained value, or ``None`` when ``params`` holds no entry
    for this site.
    """
    name = site['name']
    if name not in params:
        return None
    unconstrained = params[name]
    transform = biject_to(site['fn'].support)
    constrained = transform(unconstrained)

    # Jacobian correction, reduced over the event dimensions of the site.
    log_det = transform.log_abs_det_jacobian(unconstrained, constrained)
    log_det = sum_rightmost(
        log_det,
        jnp.ndim(log_det) - jnp.ndim(constrained) + len(site['fn'].event_shape))
    if site['scale'] is not None:
        log_det = site['scale'] * log_det
    numpyro.factor('_{}_log_det'.format(name), log_det)
    return constrained
Exemplo n.º 27
0
def _get_model_transforms(model, model_args=(), model_kwargs=None):
    """Trace ``model`` and collect per-site bijections to unconstrained space.

    Returns ``(inv_transforms, replay_model, has_enumerate_support,
    model_trace)``: bijections for continuous latent sites, a flag indicating
    the model must be replayed (deterministic sites or dynamic supports), a
    flag for discrete sites needing enumeration, and the raw trace.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for k, v in model_trace.items():
        if v["type"] == "sample" and not v["is_observed"]:
            if v["fn"].support.is_discrete:
                # Discrete latent sites are handled by enumeration, not by a
                # continuous bijection; validate the enumeration settings.
                enum_type = v["infer"].get("enumerate")
                if enum_type is not None and (enum_type != "parallel"):
                    raise RuntimeError(
                        "This algorithm might only work for discrete sites with"
                        f" enumerate marked 'parallel'. But the site {k} is marked"
                        f" as '{enum_type}'.")
                has_enumerate_support = True
                if not v["fn"].has_enumerate_support:
                    dist_name = type(v["fn"]).__name__
                    raise RuntimeError(
                        "This algorithm might only work for discrete sites with"
                        f" enumerate support. But the {dist_name} distribution at"
                        f" site {k} does not have enumerate support.")
                if enum_type is None:
                    warnings.warn(
                        "Some algorithms will automatically enumerate the discrete"
                        f" latent site {k} of your model. In the future,"
                        " enumerated sites need to be marked with"
                        " `infer={'enumerate': 'parallel'}`.",
                        FutureWarning,
                        stacklevel=find_stack_level(),
                    )
            else:
                support = v["fn"].support
                with helpful_support_errors(v, raise_warnings=True):
                    inv_transforms[k] = biject_to(support)
                # XXX: the following code filters out most situations with dynamic supports
                args = ()
                if isinstance(support, constraints._GreaterThan):
                    args = ("lower_bound", )
                elif isinstance(support, constraints._Interval):
                    args = ("lower_bound", "upper_bound")
                for arg in args:
                    # A non-scalar bound means the support depends on traced
                    # values, so the model must be replayed.
                    if not isinstance(getattr(support, arg), (int, float)):
                        replay_model = True
        elif v["type"] == "deterministic":
            replay_model = True
    return inv_transforms, replay_model, has_enumerate_support, model_trace
Exemplo n.º 28
0
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # run model to inspect the model structure
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.items():
            if site["type"] != "sample" or site["is_observed"]:
                continue

            event_dim = self._event_dims[name]
            init_loc = self._init_locs[name]
            with ExitStack() as stack:
                # Re-enter the plates this site was declared under.
                for frame in site["cond_indep_stack"]:
                    stack.enter_context(plates[frame.name])

                site_loc = numpyro.param("{}_{}_loc".format(name, self.prefix),
                                         init_loc,
                                         event_dim=event_dim)
                site_scale = numpyro.param(
                    "{}_{}_scale".format(name, self.prefix),
                    jnp.full(jnp.shape(init_loc), self._init_scale),
                    constraint=self.scale_constraint,
                    event_dim=event_dim,
                )

                site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
                # A site whose support is unconstrained reals (scalar `real` or
                # an `independent` wrapper around `real`, e.g. real_vector) can
                # be sampled from the Normal directly; otherwise push the
                # Normal through the bijection onto the site's support.
                # BUGFIX: the second clause previously re-tested
                # `support is constraints.real`, which is always False for an
                # `independent` instance, so real-vector sites never took the
                # direct path; compare the base_constraint instead.
                if site["fn"].support is constraints.real or (
                        isinstance(site["fn"].support, constraints.independent)
                        and site["fn"].support.base_constraint is constraints.real):
                    result[name] = numpyro.sample(name, site_fn)
                else:
                    transform = biject_to(site["fn"].support)
                    guide_dist = dist.TransformedDistribution(
                        site_fn, transform)
                    result[name] = numpyro.sample(name, guide_dist)

        return result
Exemplo n.º 29
0
def _init_to_median(site, num_samples=15, skip_param=False):
    """Initialize a latent sample site to the median of ``num_samples`` prior
    draws; for a param site (unless ``skip_param``), return the base value of
    its declared initial value."""
    if site['type'] == 'sample' and not site['is_observed']:
        # Draw from the base distribution when the site is transformed.
        base_fn = (site['fn'].base_dist
                   if isinstance(site['fn'], dist.TransformedDistribution)
                   else site['fn'])
        draw_shape = (num_samples,) + site['kwargs']['sample_shape']
        draws = numpyro.sample('_init', base_fn, sample_shape=draw_shape)
        return np.median(draws, axis=0)

    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        transform = biject_to(site['kwargs'].pop('constraint', real))
        init_value = site['args'][0]
        if isinstance(transform, ComposeTransform):
            # Pull the value back to the base of the composed transform.
            init_value = transform.parts[0](transform.inv(init_value))
        return init_value
Exemplo n.º 30
0
def _init_to_value(site, values=None, skip_param=False):
    """Initialize a latent sample site to the user-supplied value in
    ``values`` (falling back to ``_init_to_uniform`` when absent); for a param
    site (unless ``skip_param``), return the base value of its declared
    initial value.

    :param dict values: mapping from site name to a (constrained) initial value.
    """
    # BUGFIX: replace the mutable default argument ``values={}`` with a None
    # sentinel -- a shared dict default persists across calls.
    values = {} if values is None else values
    if site['type'] == 'sample' and not site['is_observed']:
        if site['name'] not in values:
            return _init_to_uniform(site, skip_param=skip_param)

        value = values[site['name']]
        if isinstance(site['fn'], dist.TransformedDistribution):
            # User values live on the transformed support; pull them back to
            # the base distribution's support.
            value = ComposeTransform(site['fn'].transforms).inv(value)
        return value

    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        if isinstance(transform, ComposeTransform):
            base_transform = transform.parts[0]
            value = base_transform(transform.inv(value))
        return value