Example #1
def test_model_with_transformed_distribution():
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # a TransformedDistribution (exp of a Normal)

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    params = {'x': np.array(-5.), 'y': np.array(7.)}
    model = handlers.seed(model, random.PRNGKey(0))
    inv_transforms = {'x': biject_to(x_prior.support), 'y': biject_to(y_prior.support)}
    expected_samples = partial(transform_fn, inv_transforms)(params)
    expected_potential_energy = (
        - x_prior.log_prob(expected_samples['x'])
        - y_prior.log_prob(expected_samples['y'])
        - inv_transforms['x'].log_abs_det_jacobian(params['x'], expected_samples['x'])
        - inv_transforms['y'].log_abs_det_jacobian(params['y'], expected_samples['y'])
    )

    base_inv_transforms = {'x': biject_to(x_prior.support), 'y': biject_to(y_prior.base_dist.support)}
    actual_samples = constrain_fn(
        handlers.seed(model, random.PRNGKey(0)), (), {}, base_inv_transforms, params)
    actual_potential_energy = potential_energy(model, (), {}, base_inv_transforms, params)

    assert_allclose(expected_samples['x'], actual_samples['x'])
    assert_allclose(expected_samples['y'], actual_samples['y'])
    assert_allclose(actual_potential_energy, expected_potential_energy)
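
The test above leans on the central contract of `biject_to`: given a constraint, it returns a bijection from the unconstrained real line onto the corresponding support, which is why `params` may hold arbitrary real values. A minimal standalone sketch of that contract (recent NumPyro exports `biject_to` from `numpyro.distributions.transforms`; the import path has moved between releases, so treat it as an assumption):

import jax.numpy as jnp
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

t = biject_to(constraints.positive)  # support of HalfNormal(2)
x = jnp.array(-5.)                   # unconstrained value, as in `params`
y = t(x)                             # constrained value, here exp(-5) > 0
assert jnp.allclose(t.inv(y), x)     # the transform is invertible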
Example #2
def init_to_uniform(site, radius=2, skip_param=False):
    """
    Initialize to an arbitrary feasible point, ignoring distribution
    parameters.
    """
    if site['type'] == 'sample' and not site['is_observed']:
        if isinstance(site['fn'], dist.TransformedDistribution):
            fn = site['fn'].base_dist
        else:
            fn = site['fn']
        value = numpyro.sample('_init',
                               fn,
                               sample_shape=site['kwargs']['sample_shape'])
        base_transform = biject_to(fn.support)
        unconstrained_value = numpyro.sample('_unconstrained_init',
                                             dist.Uniform(-radius, radius),
                                             sample_shape=np.shape(
                                                 base_transform.inv(value)))
        return base_transform(unconstrained_value)

    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        unconstrained_value = numpyro.sample('_unconstrained_init',
                                             dist.Uniform(-radius, radius),
                                             sample_shape=np.shape(
                                                 transform.inv(value)))
        if isinstance(transform, ComposeTransform):
            base_transform = transform.parts[0]
        else:
            base_transform = transform
        return base_transform(unconstrained_value)
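
A hedged usage sketch: per-site strategies such as `init_to_uniform` are handed to initialization code like the `initialize_model` of Example #19 through its `init_strategy` argument. Wrapping the strategy with `functools.partial` to override `radius` is an assumption consistent with the signature above, not code from the original project:

from functools import partial

# Shrink the initialization box from the default (-2, 2) to (-1, 1).
init_params, potential_fn, constrain_fun = initialize_model(
    random.PRNGKey(0), model, init_strategy=partial(init_to_uniform, radius=1))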
Example #3
    def _setup_prototype(self, *args, **kwargs):
        super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
        rng = numpyro.sample("_{}_rng_init".format(self.prefix),
                             dist.PRNGIdentity())
        # FIXME: without block statement, get AssertionError: all sites must have unique names
        init_params, is_valid = handlers.block(find_valid_initial_params)(
            rng, self.model, *args, init_strategy=self.init_strategy, **kwargs)
        self._inv_transforms = {}
        self._has_transformed_dist = False
        unconstrained_sites = {}
        for name, site in self.prototype_trace.items():
            if site['type'] == 'sample' and not site['is_observed']:
                if site['intermediates']:
                    transform = biject_to(site['fn'].base_dist.support)
                    self._inv_transforms[name] = transform
                    unconstrained_sites[name] = transform.inv(
                        site['intermediates'][0][0])
                    self._has_transformed_dist = True
                else:
                    transform = biject_to(site['fn'].support)
                    self._inv_transforms[name] = transform
                    unconstrained_sites[name] = transform.inv(site['value'])

        self._init_latent, self.unpack_latent = ravel_pytree(init_params)
        self.latent_size = np.size(self._init_latent)
        if self.base_dist is None:
            self.base_dist = _Normal(np.zeros(self.latent_size), 1.)
        if self.latent_size == 0:
            raise RuntimeError(
                '{} found no latent variables; Use an empty guide instead'.
                format(type(self).__name__))
Example #4
def initialize_model(rng,
                     model,
                     *model_args,
                     init_strategy='uniform',
                     **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this also returns initial parameters
    sampled from the prior to initiate MCMC sampling and functions to
    transform unconstrained values at sample sites to constrained values
    within their respective support.

    :param jax.random.PRNGKey rng: random number generator seed to
        sample from the prior.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param str init_strategy: initialization strategy - `uniform`
        initializes the unconstrained parameters by drawing from
        a `Uniform(-2, 2)` distribution (as used by Stan), whereas
        `prior` initializes the parameters by sampling from the prior
        for each of the sample sites.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        where `init_params` are values from the prior used to initiate MCMC,
        `potential_fn` evaluates the potential energy of unconstrained
        parameters, and `constrain_fn` is a callable that uses inverse
        transforms to convert unconstrained HMC samples to constrained
        values that lie within the site's support.
    """
    model = seed(model, rng)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    sample_sites = {
        k: v
        for k, v in model_trace.items()
        if v['type'] == 'sample' and not v['is_observed']
    }
    inv_transforms = {
        k: biject_to(v['fn'].support)
        for k, v in sample_sites.items()
    }
    prior_params = constrain_fn(
        inv_transforms, {k: v['value']
                         for k, v in sample_sites.items()},
        invert=True)
    if init_strategy == 'uniform':
        init_params = {}
        for k, v in prior_params.items():
            rng, = random.split(rng, 1)
            init_params[k] = random.uniform(rng,
                                            shape=np.shape(v),
                                            minval=-2,
                                            maxval=2)
    elif init_strategy == 'prior':
        init_params = prior_params
    else:
        raise ValueError(
            'initialize={} is not a valid initialization strategy.'.format(
                init_strategy))
    return init_params, potential_energy(model, model_args, model_kwargs, inv_transforms), \
        jax.partial(constrain_fn, inv_transforms)
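
A hedged usage sketch for the `initialize_model` above, which selects the strategy by name; `toy_model` is hypothetical and only illustrates the calling convention:

import jax.random as random
import numpyro
import numpyro.distributions as dist

def toy_model():
    numpyro.sample('theta', dist.Beta(2., 2.))

# 'uniform' draws unconstrained values from Uniform(-2, 2); 'prior' reuses
# the values sampled while tracing the model.
init_params, potential_fn, constrain = initialize_model(
    random.PRNGKey(0), toy_model, init_strategy='prior')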
Example #5
def test_log_prob_LKJCholesky(dimension, concentration):
    # We test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution whose base distribution is a distribution of
    # partial correlations in the C-vine method (modulo an affine transform to
    # change the domain from (0, 1) to (-1, 1)) and whose transform is a
    # signed stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = np.sum(d._beta.log_prob(beta_sample))
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * np.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    inv_tanh = lambda t: np.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = np.sum(np.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(
        constraints.corr_cholesky).log_abs_det_jacobian(
            unconstrained,
            sample,
        )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=1e-5)

    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7)
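
The expected value in this test is the change-of-variables formula, log p_Y(y) = log p_X(x) - log|det J_T(x)| for y = T(x), applied twice (affine, then signed stick-breaking). A minimal 1-D sketch of the same identity, using standard NumPyro names but written for illustration rather than taken from the test suite:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.distributions.transforms import ExpTransform

base = dist.Normal(0., 1.)
t = ExpTransform()
x = base.sample(random.PRNGKey(0))
y = t(x)
# Density of y = exp(x) by change of variables; matches LogNormal exactly.
log_prob_y = base.log_prob(x) - t.log_abs_det_jacobian(x, y)
assert jnp.allclose(dist.LogNormal(0., 1.).log_prob(y), log_prob_y, atol=1e-6)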
Example #6
def test_bijective_transforms(transform, event_shape, batch_shape):
    shape = batch_shape + event_shape
    rng = random.PRNGKey(0)
    x = biject_to(transform.domain)(random.normal(rng, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), np.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if isinstance(transform, PermuteTransform):
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(
                transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
Example #7
def test_elbo_dynamic_support():
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(random.PRNGKey(0), (), ())
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(
        guide._init_latent).log_prob(x_unconstrained).sum()
    transform = constraints.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss)
Example #8
def test_biject_to(constraint, shape):
    transform = biject_to(constraint)
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    if len(shape) < transform.event_dim:
        return
    rng = random.PRNGKey(0)
    x = random.normal(rng, shape)
    y = transform(x)

    # test codomain
    batch_shape = shape if transform.event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y),
                       np.ones(batch_shape, dtype=np.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain; currently every domain is constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if constraint is constraints.simplex:
            expected = onp.linalg.slogdet(
                jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint is constraints.corr_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x), diagonal=-1)  # noqa: E731
            y_tril = matrix_to_tril_vec(y, diagonal=-1)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x, diagonal=-1))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint is constraints.lower_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
Example #9
def init_to_feasible(site):
    """
    Initialize to an arbitrary feasible point, ignoring distribution
    parameters.
    """
    if site['is_observed']:
        return None
    # Draw once only to discover the shape of the unconstrained value, then
    # map the unconstrained origin into the support: its image is always a
    # feasible point, whatever the distribution's parameters.
    value = sample('_init', site['fn'])
    t = biject_to(site['fn'].support)
    return t(np.zeros(np.shape(t.inv(value))))
Example #10
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        # Use `block` to not record sample primitives in `init_loc_fn`.
        seeded_model = substitute(model,
                                  substitute_fn=block(
                                      seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args,
                                                    **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                constraint = v['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(
                        transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']
        params = transform_fn(inv_transforms, constrained_values, invert=True)
        potential_fn = jax.partial(potential_energy, model, model_args,
                                   model_kwargs, inv_transforms)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid
Example #11
def test_iaf():
    # test substitute logic for the exposed methods `sample_posterior` and `get_transform`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim),
                                                    np.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(rng_init,
                         model_args=(data, labels),
                         guide_args=(data, labels))
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng, params)
    actual_output = guide.get_transform(params)(x)

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(constraints.PermuteTransform(
                np.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1, [dim + 1, dim + 1],
            permutation=np.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))

    transform = constraints.ComposeTransform(flows)
    rng_seed, rng_sample = random.split(rng)
    expected_sample = guide.unpack_latent(
        transform(dist.Normal(np.zeros(dim + 1), 1).sample(rng_sample)))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(
        actual_sample['offset'],
        constraints.biject_to(constraints.interval(-1, 1))(
            expected_sample['offset']))
    assert_allclose(actual_output, expected_output)
Example #12
    def single_chain_init(key, only_params=False):
        seeded_model = seed(model, key)
        model_trace = trace(seeded_model).get_trace(*model_args,
                                                    **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param':
                constrained_values[k] = v['value']
                constraint = v['kwargs'].pop('constraint', real)
                inv_transforms[k] = biject_to(constraint)
        prior_params = transform_fn(inv_transforms, constrained_values, invert=True)
        if init_strategy == 'uniform':
            init_params = {}
            for k, v in prior_params.items():
                key, = random.split(key, 1)
                init_params[k] = random.uniform(key,
                                                shape=np.shape(v),
                                                minval=-2,
                                                maxval=2)
        elif init_strategy == 'prior':
            init_params = prior_params
        else:
            raise ValueError(
                'initialize={} is not a valid initialization strategy.'.format(
                    init_strategy))

        if only_params:
            return init_params
        else:
            return (init_params,
                    potential_energy(seeded_model, model_args, model_kwargs,
                                     inv_transforms),
                    jax.partial(transform_fn, inv_transforms))
Example #13
    def process_message(self, msg):
        if self.param_map is not None:
            if msg['name'] in self.param_map:
                msg['value'] = self.param_map[msg['name']]
        elif self.base_param_map is not None:
            if msg['name'] in self.base_param_map:
                if msg['type'] == 'sample':
                    msg['value'], msg['intermediates'] = msg[
                        'fn'].transform_with_intermediates(
                            self.base_param_map[msg['name']])
                else:
                    base_value = self.base_param_map[msg['name']]
                    constraint = msg['kwargs'].pop('constraint', real)
                    transform = biject_to(constraint)
                    if isinstance(transform, ComposeTransform):
                        msg['value'] = ComposeTransform(
                            transform.parts[1:])(base_value)
                    else:
                        msg['value'] = base_value
        elif self.substitute_fn is not None:
            base_value = self.substitute_fn(msg)
            if base_value is not None:
                if msg['type'] == 'sample':
                    msg['value'], msg['intermediates'] = msg[
                        'fn'].transform_with_intermediates(base_value)
                else:
                    constraint = msg['kwargs'].pop('constraint', real)
                    transform = biject_to(constraint)
                    if isinstance(transform, ComposeTransform):
                        msg['value'] = ComposeTransform(
                            transform.parts[1:])(base_value)
                    else:
                        msg['value'] = base_value
        else:
            raise ValueError(
                "Neither `param_map`, `base_param_map`, nor `substitute_fn` "
                "provided to substitute handler.")
Example #14
    def init_fn(rng, model_args=(), guide_args=(), params=None):
        """

        :param jax.random.PRNGKey rng: random number generator seed.
        :param tuple model_args: arguments to the model (these can possibly vary during
            the course of fitting).
        :param tuple guide_args: arguments to the guide (these can possibly vary during
            the course of fitting).
        :param dict params: initial parameter values to condition on. This can be
            useful for initializing neural networks using more specialized methods
            rather than sampling from the prior.
        :return: tuple containing initial optimizer state, and `constrain_fn`, a callable
            that transforms unconstrained parameter values from the optimizer to the
            specified constrained domain
        """
        assert isinstance(model_args, tuple)
        assert isinstance(guide_args, tuple)
        model_init, guide_init = _seed(model, guide, rng)
        if params is None:
            params = {}
        else:
            model_init = substitute(model_init, params)
            guide_init = substitute(guide_init, params)
        guide_trace = trace(guide_init).get_trace(*guide_args, **kwargs)
        model_trace = trace(model_init).get_trace(*model_args, **kwargs)
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    base_transform = transform.parts[0]
                    inv_transforms[site['name']] = base_transform
                    params[site['name']] = base_transform(
                        transform.inv(site['value']))
                else:
                    inv_transforms[site['name']] = transform
                    params[site['name']] = site['value']

        nonlocal constrain_fn
        constrain_fn = jax.partial(transform_fn, inv_transforms)
        return optim_init(params), constrain_fn
Example #15
    def _setup_prototype(self, *args, **kwargs):
        super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
        self._inv_transforms = {}
        unconstrained_sites = {}
        for name, site in self.prototype_trace.items():
            if site['type'] == 'sample' and not site['is_observed']:
                # Collect the shapes of unconstrained values.
                # These may differ from the shapes of constrained values.
                transform = biject_to(site['fn'].support)
                unconstrained_val = transform.inv(site['value'])
                self._inv_transforms[name] = transform
                unconstrained_sites[name] = unconstrained_val

        latent_size = sum(np.size(x) for x in unconstrained_sites.values())
        if latent_size == 0:
            raise RuntimeError(
                '{} found no latent variables; Use an empty guide instead'.
                format(type(self).__name__))
        self._init_latent, self._unravel_fn = ravel_pytree(unconstrained_sites)
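
A minimal sketch of the `ravel_pytree` step, which flattens the per-site unconstrained values into the single latent vector the guide operates on (current JAX exports it from `jax.flatten_util`; older releases used a different path):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

sites = {'a': jnp.zeros(3), 'b': jnp.zeros((2, 2))}
flat, unravel = ravel_pytree(sites)        # flat.shape == (7,)
assert unravel(flat)['b'].shape == (2, 2)  # round-trips the dict structure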
Example #16
    def process_message(self, msg):
        if self.param_map is not None:
            if msg['name'] in self.param_map:
                msg['value'] = self.param_map[msg['name']]
        else:
            base_value = self.substitute_fn(msg) if self.substitute_fn \
                else self.base_param_map.get(msg['name'], None)
            if base_value is not None:
                if msg['type'] == 'sample':
                    msg['value'], msg['intermediates'] = msg[
                        'fn'].transform_with_intermediates(base_value)
                else:
                    constraint = msg['kwargs'].pop('constraint', real)
                    transform = biject_to(constraint)
                    if isinstance(transform, ComposeTransform):
                        # No need to apply the first transform since the base
                        # value should have the same support as the first
                        # part's co-domain.
                        msg['value'] = ComposeTransform(
                            transform.parts[1:])(base_value)
                    else:
                        msg['value'] = base_value
Example #17
def init_to_median(site, num_samples=15, skip_param=False):
    """
    Initialize to the prior median.
    """
    if site['type'] == 'sample' and not site['is_observed']:
        if isinstance(site['fn'], dist.TransformedDistribution):
            fn = site['fn'].base_dist
        else:
            fn = site['fn']
        samples = sample('_init', fn, sample_shape=(num_samples,))
        return np.median(samples, axis=0)

    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        if isinstance(transform, ComposeTransform):
            base_transform = transform.parts[0]
            value = base_transform(transform.inv(value))
        return value
Example #18
    def _loc_scale(self, opt_state):
        params = self.get_params(opt_state)
        loc = params['{}_loc'.format(self.prefix)]
        scale = biject_to(constraints.positive)(
            params['{}_scale'.format(self.prefix)])
        return loc, scale
Example #19
def initialize_model(rng, model, *model_args, init_strategy=init_to_uniform, **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this also returns initial parameters
    sampled from the prior to initiate MCMC sampling and functions to
    transform unconstrained values at sample sites to constrained values
    within their respective support.

    :param jax.random.PRNGKey rng: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        `init_params` are values from the prior used to initiate MCMC,
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    """
    seeded_model = seed(model, rng if rng.ndim == 1 else rng[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    has_transformed_dist = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                constrained_values[k] = v['intermediates'][0][0]
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                base_transform = transform.parts[0]
                constrained_values[k] = base_transform(transform.inv(v['value']))
                inv_transforms[k] = base_transform
                has_transformed_dist = True
            else:
                inv_transforms[k] = transform
                constrained_values[k] = v['value']

    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)

    # NB: we use model instead of seeded_model to prevent unexpected behaviours (if any)
    potential_fn = jax.partial(potential_energy, model, model_args, model_kwargs, inv_transforms)
    if has_transformed_dist:
        # FIXME: why using seeded_model here triggers an error for funnel reparam example
        # if we use MCMC class (mcmc function works fine)
        constrain_fun = jax.partial(constrain_fn, model, model_args, model_kwargs, inv_transforms)
    else:
        constrain_fun = jax.partial(transform_fn, inv_transforms)

    def single_chain_init(key):
        return find_valid_initial_params(key, model, *model_args, init_strategy=init_strategy,
                                         param_as_improper=True, prototype_params=prototype_params,
                                         **model_kwargs)

    if rng.ndim == 1:
        init_params, is_valid = single_chain_init(rng)
    else:
        init_params, is_valid = lax.map(single_chain_init, rng)

    if isinstance(is_valid, jax.interpreters.xla.DeviceArray):
        if device_get(~np.all(is_valid)):
            raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
    return init_params, potential_fn, constrain_fun
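
A hedged sketch of the batched-key behaviour documented above: passing a stacked PRNG key initializes several chains at once, and every leaf of `init_params` gains the batch shape `rng.shape[:-1]` (here `(4,)`). `toy_model` stands in for any model:

import jax.random as random

keys = random.split(random.PRNGKey(0), 4)  # shape (4, 2): one key per chain
init_params, potential_fn, constrain_fun = initialize_model(keys, toy_model)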