Exemplo n.º 1
0
 def codomain(self):
     """Return the constraint describing the image of ``self.domain``.

     Only real, greater-than, less-than and interval domains are handled.
     A concrete ``scale`` whose entries are all negative reverses the
     direction of the bounds; when ``scale`` is a JAX tracer its sign cannot
     be inspected, so it is assumed to be positive.

     :raises NotImplementedError: for any other kind of domain.
     """
     domain = self.domain
     if domain is constraints.real:
         return constraints.real

     def _flips_orientation():
         # Tracers have no concrete sign: treat them as positive scales.
         return not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0))

     if isinstance(domain, constraints.greater_than):
         flipped = _flips_orientation()
         image = self(domain.lower_bound)
         return constraints.less_than(image) if flipped else constraints.greater_than(image)

     if isinstance(domain, constraints.less_than):
         flipped = _flips_orientation()
         image = self(domain.upper_bound)
         return constraints.greater_than(image) if flipped else constraints.less_than(image)

     if isinstance(domain, constraints.interval):
         flipped = _flips_orientation()
         first, second = domain.lower_bound, domain.upper_bound
         if flipped:
             # A negative scale maps the upper bound to the new lower bound.
             first, second = second, first
         return constraints.interval(self(first), self(second))

     raise NotImplementedError
Exemplo n.º 2
0
    def _find_valid_params(rng_key, exit_early=False):
        # Loop state: (iteration, rng key, (params, potential energy, grad), is_valid).
        state = (0, rng_key, (prototype_params, 0., prototype_params), False)
        if exit_early and not_jax_tracer(rng_key):
            # Single-chain fast path: take one step eagerly so that, if it already
            # yields valid params, `body_fn` never gets compiled inside `while_loop`.
            state = body_fn(state)
            _, _, (params, energy, grad), valid = state
            if not_jax_tracer(valid) and device_get(valid):
                return (params, energy, grad), valid

        # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
        # even if the init_state is a valid result
        _, _, (params, energy, grad), valid = while_loop(cond_fn, body_fn, state)
        return (params, energy, grad), valid
Exemplo n.º 3
0
 def __init__(self, fn=None, scale=1.):
     """Store ``scale`` after checking that it is positive, then wrap ``fn``.

     :param fn: the callable passed on to the parent initializer (may be ``None``).
     :param scale: a positive scalar or an array of positive values.
     :raises ValueError: if ``scale`` is concrete (not a JAX tracer) and any
         of its entries is <= 0.
     """
     import numpy as np  # local import: the check must work for array-valued scales

     # Tracers carry no concrete value, so validation is only possible eagerly.
     # ``np.any`` avoids the "truth value of an array is ambiguous" error that
     # ``scale <= 0`` raises for array-valued scales.
     if not_jax_tracer(scale) and np.any(np.less_equal(scale, 0)):
         raise ValueError(
             "'scale' argument should be a positive number.")
     self.scale = scale
     super().__init__(fn)
Exemplo n.º 4
0
 def _validate_sample(self, value):
     """Evaluate ``self.support`` on ``value`` and return the resulting mask.

     When the mask is concrete and some entry is false, a warning is issued.
     Under JAX tracing the mask is returned without being inspected.
     """
     in_support = self.support(value)
     if not not_jax_tracer(in_support):
         return in_support
     if not np.all(in_support):
         warnings.warn('Out-of-support values provided to log prob method. '
                       'The value argument should be within the support.')
     return in_support
Exemplo n.º 5
0
    def body_fn(wrapped_carry, x, prefix=None):
        """Run one step of the scanned transition function ``f``.

        :param wrapped_carry: tuple ``(i, rng_key, carry)`` -- the step counter,
            an optional PRNG key (``None`` disables seeding) and the user carry.
        :param x: the per-step input forwarded to ``f``.
        :param prefix: not used by this body.
        :return: ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
        """
        i, rng_key, carry = wrapped_carry
        # The first step (i concretely equal to 0) is run eagerly, outside the trace collection.
        init = True if (not_jax_tracer(i) and i == 0) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        # Re-apply the enclosing condition/substitute handlers, restricted to step `i`.
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            # Sites of the first step are prefixed with "_init"; no trace is collected.
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store the leaf shapes of new_carry in the enclosing scope's `carry_shape_at_t1`
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
Exemplo n.º 6
0
    def init(self,
             rng_key,
             num_warmup,
             init_params=None,
             model_args=(),
             model_kwargs=None):
        """Initialize the HMC sampler state for one chain or a batch of chains.

        :param rng_key: a single PRNG key (``rng_key.ndim == 1``) or a batch of
            keys, in which case initialization is vectorized with ``vmap``.
        :param int num_warmup: number of warmup steps, forwarded to ``self._init_fn``.
        :param init_params: initial parameter values. Required when the sampler
            was built from a ``potential_fn``; with a model, valid initial
            values are searched for when this is not provided.
        :param tuple model_args: positional arguments passed to the model.
        :param dict model_kwargs: keyword arguments passed to the model;
            ``None`` (the default) means no keyword arguments.
        :return: the initial state built by ``self._init_fn`` (batched when
            ``rng_key`` is batched).
        :raises ValueError: if ``potential_fn`` is used without ``init_params``.
        :raises RuntimeError: if no valid initial parameters can be found.
        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        if model_kwargs is None:
            model_kwargs = {}
        # non-vectorized
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        # vectorized
        else:
            rng_key, rng_key_init_model = np.swapaxes(
                vmap(random.split)(rng_key), 0, 1)
            # we need only a single key for initializing PE / constraints fn
            rng_key_init_model = rng_key_init_model[0]
        if not self._init_fn:
            self._init_state(rng_key_init_model, model_args, model_kwargs)
        if self._potential_fn and init_params is None:
            raise ValueError(
                'Valid value of `init_params` must be provided with'
                ' `potential_fn`.')
        # Find valid initial params
        if self._model and not init_params:
            init_params, is_valid = find_valid_initial_params(
                rng_key,
                self._model,
                init_strategy=self._init_strategy,
                param_as_improper=True,
                model_args=model_args,
                model_kwargs=model_kwargs)
            # Validity can only be checked on concrete (non-traced) values.
            if not_jax_tracer(is_valid):
                if device_get(~np.all(is_valid)):
                    raise RuntimeError("Cannot find valid initial parameters. "
                                       "Please check your model again.")

        hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
            init_params,
            num_warmup=num_warmup,
            step_size=self._step_size,
            adapt_step_size=self._adapt_step_size,
            adapt_mass_matrix=self._adapt_mass_matrix,
            dense_mass=self._dense_mass,
            target_accept_prob=self._target_accept_prob,
            trajectory_length=self._trajectory_length,
            max_tree_depth=self._max_tree_depth,
            rng_key=rng_key,
            model_args=model_args,
            model_kwargs=model_kwargs,
        )
        if rng_key.ndim == 1:
            init_state = hmc_init_fn(init_params, rng_key)
        else:
            # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
            # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
            # wa_steps because those variables do not depend on traced args: init_params, rng_key.
            init_state = vmap(hmc_init_fn)(init_params, rng_key)
            # NOTE(review): this rebinds `self._sample_fn` to a vmapped version, so a
            # second call to `init` with batched keys would vmap it again -- confirm
            # `init` is only called once per sampler instance.
            sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
            self._sample_fn = sample_fn
        return init_state
Exemplo n.º 7
0
 def model():
     # Latent 3-d vector with a standard multivariate normal prior; the same
     # identity covariance is reused for the likelihood below.
     identity = np.eye(3)
     x = numpyro.sample("x", dist.MultivariateNormal(np.zeros(3), identity))
     with numpyro.plate("plate", len(data)):
         y = numpyro.sample(
             "y", dist.MultivariateNormal(x, identity), obs=data, obs_mask=mask
         )
         # Eager-only check: a row of `y` equals the observed row exactly
         # where `mask` is True.
         if not_jax_tracer(y):
             row_matches = (y == data).all(-1)
             assert (row_matches == mask).all()
Exemplo n.º 8
0
    def _find_valid_params(rng_key_):
        # Take one step eagerly; the resulting state also seeds the loop below.
        first_state = body_fn((0, rng_key_, None, None))
        _, _, candidate, valid = first_state
        # Early return if valid params found (only decidable on concrete values).
        if not_jax_tracer(valid) and device_get(valid):
            return candidate, valid

        _, _, found, valid = while_loop(cond_fn, body_fn, first_state)
        return found, valid
Exemplo n.º 9
0
 def enumerate_support(self, expand=True):
     """Return the integer support values ``low, low + 1, ..., high``.

     :param bool expand: if True, broadcast the values to
         ``(num_values,) + batch_shape``; otherwise the result has shape
         ``(num_values,) + (1,) * len(batch_shape)``.
     :raises NotImplementedError: if ``low`` or ``high`` is a JAX tracer, or
         if either is not the same across the batch.
     """
     if not not_jax_tracer(self.high) or not not_jax_tracer(self.low):
         raise NotImplementedError(
             "Both `low` and `high` must not be a JAX Tracer.")
     # Bounds are concrete here (tracers were rejected above), so an
     # inhomogeneous batch is always detected.
     low = np.amax(self.low)
     high = np.amax(self.high)
     if np.any(low != self.low):
         raise NotImplementedError(
             "Inhomogeneous `low` not supported by `enumerate_support`.")
     if np.any(high != self.high):
         raise NotImplementedError(
             "Inhomogeneous `high` not supported by `enumerate_support`.")
     # Use the scalar bound: adding the batched `self.low` to the 1-D range
     # would broadcast it against the batch shape (shape errors, or duplicated
     # values after the reshape below).
     values = (low + jnp.arange(high - low + 1)).reshape(
         (-1,) + (1,) * len(self.batch_shape))
     if expand:
         values = jnp.broadcast_to(values,
                                   values.shape[:1] + self.batch_shape)
     return values
Exemplo n.º 10
0
 def model():
     x = numpyro.sample("x", dist.Normal(0., 1.))
     with numpyro.plate("plate", len(data)):
         likelihood = dist.Normal(x, 1.)
         y = numpyro.sample("y", likelihood, obs=data, obs_mask=mask)
         # Eager-only check: observed entries come back unchanged exactly
         # where `mask` is set.
         if not_jax_tracer(y):
             matches = y == data
             assert (matches == mask).all()
Exemplo n.º 11
0
 def _validate_sample(self, value):
     """Evaluate ``self.support`` on ``value`` and return the resulting mask.

     If the mask is concrete (not a JAX tracer) and has any false entry, a
     warning is emitted, attributed via ``find_stack_level()``.
     """
     in_support = self.support(value)
     should_warn = not_jax_tracer(in_support) and not np.all(in_support)
     if should_warn:
         warnings.warn(
             "Out-of-support values provided to log prob method. "
             "The value argument should be within the support.",
             stacklevel=find_stack_level(),
         )
     return in_support
Exemplo n.º 12
0
 def sample(self, key, sample_shape=()):
     """Draw samples by first drawing probabilities, then multinomial counts.

     :raises NotImplementedError: if ``total_count`` is concrete and not the
         same across the batch.
     """
     key_dirichlet, key_multinom = random.split(key)
     probs = self._dirichlet.sample(key_dirichlet, sample_shape)
     n_max = jnp.amax(self.total_count)
     # Inhomogeneity is only detectable on concrete values; under tracing the
     # maximum count is used as-is.
     if not_jax_tracer(n_max) and jnp.amin(self.total_count) != n_max:
         raise NotImplementedError(
             "Inhomogeneous total count not supported"
             " by `sample`.")
     return Multinomial(n_max, probs).sample(key_multinom)
Exemplo n.º 13
0
 def enumerate_support(self, expand=True):
     """Return the counts ``0, 1, ..., total_count``, one per leading index.

     :param bool expand: broadcast to ``(num_values,) + batch_shape`` when True.
     :raises NotImplementedError: if ``total_count`` is concrete and not the
         same across the batch.
     """
     n_max = jnp.amax(self.total_count)
     # Inhomogeneity is only detectable on concrete (non-traced) values.
     if not_jax_tracer(n_max) and jnp.amin(self.total_count) != n_max:
         raise NotImplementedError("Inhomogeneous total count not supported"
                                   " by `enumerate_support`.")
     trailing_ones = (1,) * len(self.batch_shape)
     values = jnp.arange(n_max + 1).reshape((-1,) + trailing_ones)
     if not expand:
         return values
     return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
Exemplo n.º 14
0
def initialize_model(rng_key,
                     model,
                     init_strategy=init_to_uniform(),
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params`, `potential_fn`, `postprocess_fn`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``. When a batch of keys is given,
        only the first key is used to build `potential_fn` and `postprocess_fn`.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `postprocess_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `postprocess_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model; ``None`` means
        no keyword arguments.
    :return: tuple of (`init_params`, `potential_fn`, `postprocess_fn`),
        `init_params` are values from the prior used to initiate MCMC,
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters can be found. This
        is only checked when the validity flag is concrete (not a JAX tracer).
    """
    if model_kwargs is None:
        model_kwargs = {}
    # A single key is enough to build the potential / postprocess functions.
    potential_fn, postprocess_fn = get_potential_fn(
        rng_key if rng_key.ndim == 1 else rng_key[0],
        model,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs)

    init_params, is_valid = find_valid_initial_params(
        rng_key,
        model,
        init_strategy=init_strategy,
        param_as_improper=True,
        model_args=model_args,
        model_kwargs=model_kwargs)

    if not_jax_tracer(is_valid):
        if device_get(~np.all(is_valid)):
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return init_params, potential_fn, postprocess_fn
Exemplo n.º 15
0
    def body_fn(wrapped_carry, x, prefix=None):
        """Run one step of the scanned transition function ``f``.

        :param wrapped_carry: tuple ``(i, rng_key, carry)`` -- the step counter,
            an optional PRNG key (``None`` disables seeding) and the user carry.
        :param x: the per-step input forwarded to ``f``.
        :param prefix: not used by this body.
        :return: ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
        """
        i, rng_key, carry = wrapped_carry
        # The first `unroll_steps` steps (when `i` is concrete) are unrolled
        # eagerly and are not collected into the packed trace.
        init = True if (not_jax_tracer(i)
                        and i in range(unroll_steps)) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        # Re-apply the enclosing condition/substitute handlers, restricted to step `i`.
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        if init:
            # Rename sites by prefixing "_PREV_" (unroll_steps - i) times, with no
            # divider, to match the naming pattern used by the Sarkka-Bilmes product.
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i),
                                divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
            # We don't promote `value` shape here because we need to store carry shape
            # at this step. If we reshape the `value` here, output carry might get wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # record the leaf shapes of new_carry in the enclosing `carry_shapes`
            # list, keeping at most `history + 1` entries
            if len(carry_shapes) < (history + 1):
                carry_shapes.append(
                    [jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
Exemplo n.º 16
0
    def get_transform(self, params):
        """Build the Laplace-approximation transform centred at the MAP point.

        The precision matrix is the Hessian of ``self._loss_fn`` with respect to
        the ``{prefix}_loc`` entry of ``params``; the Cholesky factor of its
        inverse becomes the transform's scale. NaN entries (a singular Hessian)
        are replaced by 0, so posterior samples collapse onto ``loc``.

        :param dict params: current parameter values, including ``{prefix}_loc``.
        :return: ``LowerCholeskyAffine(loc, scale_tril)``.
        """
        loc_name = '{}_loc'.format(self.prefix)

        def loss_fn(z):
            # Evaluate the loss with only the location parameter replaced.
            params1 = params.copy()
            params1[loc_name] = z
            return self._loss_fn(params1)

        loc = params[loc_name]
        precision = hessian(loss_fn)(loc)
        scale_tril = cholesky_of_inverse(precision)
        # NaNs can only be inspected on concrete values, not under tracing.
        if not_jax_tracer(scale_tril):
            if np.any(np.isnan(scale_tril)):
                warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                              " samples from AutoLaplaceApproximation will be constant (equal to"
                              " the MAP point).")
        scale_tril = jnp.where(jnp.isnan(scale_tril), 0., scale_tril)
        return LowerCholeskyAffine(loc, scale_tril)
Exemplo n.º 17
0
 def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
     """Record batch/event shapes and optionally validate the parameters.

     When argument validation is enabled, every parameter in
     ``arg_constraints`` is checked against its constraint, except lazily
     computed parameters that have not been materialized yet and dependent
     constraints. The check is skipped when the result is a JAX tracer.

     :raises ValueError: if a concrete parameter violates its constraint.
     """
     self._batch_shape = batch_shape
     self._event_shape = event_shape
     if validate_args is not None:
         self._validate_args = validate_args
     if self._validate_args:
         for name, constraint in self.arg_constraints.items():
             # Do not force evaluation of a lazy_property that is not yet cached.
             unevaluated_lazy = (name not in self.__dict__
                                 and isinstance(getattr(type(self), name), lazy_property))
             # Dependent constraints cannot be checked in isolation.
             if unevaluated_lazy or is_dependent(constraint):
                 continue
             satisfied = np.all(constraint(getattr(self, name)))
             if not_jax_tracer(satisfied) and not satisfied:
                 raise ValueError("The parameter {} has invalid values".format(name))
     super(Distribution, self).__init__()
Exemplo n.º 18
0
    def _get_transform(self, params):
        """Build the Laplace-approximation transform centred at the MAP point.

        The precision matrix is the Hessian of the ``AutoContinuousELBO`` loss
        with respect to the ``{prefix}_loc`` entry of ``params``; the Cholesky
        factor of its inverse becomes the transform's scale. NaN entries (a
        singular Hessian) are replaced by 0, so posterior samples collapse
        onto ``loc``.

        :param dict params: current parameter values, including ``{prefix}_loc``.
        :return: ``MultivariateAffineTransform(loc, scale_tril)``.
        """
        loc_name = '{}_loc'.format(self.prefix)

        def loss_fn(z):
            params1 = params.copy()
            params1[loc_name] = z
            # we are doing maximum likelihood, so only require `num_particles=1` and an arbitrary rng_key.
            return AutoContinuousELBO().loss(random.PRNGKey(0), params1, self.model, self,
                                             *self._args, **self._kwargs)

        loc = params[loc_name]
        precision = hessian(loss_fn)(loc)
        scale_tril = cholesky_of_inverse(precision)
        # NaNs can only be inspected on concrete values, not under tracing.
        if not_jax_tracer(scale_tril):
            if np.any(np.isnan(scale_tril)):
                warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                              " samples from AutoLaplaceApproximation will be constant (equal to"
                              " the MAP point).")
        scale_tril = np.where(np.isnan(scale_tril), 0., scale_tril)
        return MultivariateAffineTransform(loc, scale_tril)
Exemplo n.º 19
0
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    Each sample site contributes ``sum(scale * log_prob)``, with entries
    outside ``mask`` zeroed. Sites whose concrete mask is entirely false are
    skipped without evaluating ``log_prob``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    conditioned = substitute(model, param_map=params)
    model_trace = trace(conditioned).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.array(0.)
    for site in model_trace.values():
        if site['type'] != 'sample':
            continue
        value, intermediates = site['value'], site['intermediates']
        mask, scale = site['mask'], site['scale']
        # Early exit when all elements are masked
        if not_jax_tracer(mask) and mask is not None and not jnp.any(mask):
            continue
        fn = site['fn']
        log_prob = fn.log_prob(value, intermediates) if intermediates else fn.log_prob(value)

        # XXX: note that this may not work correctly for dynamic masks, provide
        # explicit jax.DeviceArray for masking.
        if scale is not None:
            log_prob = scale * log_prob
        if mask is not None:
            log_prob = jnp.where(mask, log_prob, 0.)
        log_joint = log_joint + jnp.sum(log_prob)
    return log_joint, model_trace
Exemplo n.º 20
0
def _get_codomain(bijector):
    """Map a TFP bijector to the NumPyro constraint describing its codomain.

    Bijectors are identified by class name; unrecognized bijectors fall back
    to ``constraints.real``.

    :param bijector: a TensorFlow Probability bijector instance.
    :return: a NumPyro constraint.
    """
    name = bijector.__class__.__name__
    if name == "Sigmoid":
        low = getattr(bijector, "low", None)
        high = getattr(bijector, "high", None)
        # A Sigmoid built without explicit bounds maps onto the unit interval.
        if low is None or high is None:
            return constraints.unit_interval
        return constraints.interval(low, high)
    elif name == "Identity":
        return constraints.real
    elif name in ("Exp", "Softplus", "SoftPlus"):
        # TFP names the class `Softplus`; "SoftPlus" is kept for compatibility.
        return constraints.positive
    elif name == "GeneralizedPareto":
        loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
        if not_jax_tracer(concentration) and jnp.all(concentration < 0):
            return constraints.interval(loc, loc + scale / jnp.abs(concentration))
        # XXX: here we suppose concentration > 0
        # which is not true in general, but should cover enough usage cases
        else:
            return constraints.greater_than(loc)
    elif name == "SoftmaxCentered":
        return constraints.simplex
    elif name == "Chain":
        # TFP's Chain applies its bijectors starting from the END of the list,
        # so the outermost (last applied) bijector is `bijectors[0]`.
        # An empty Chain is the identity.
        bijectors = bijector.bijectors
        if not bijectors:
            return constraints.real
        return _get_codomain(bijectors[0])
    else:
        return constraints.real
Exemplo n.º 21
0
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. By default, we use reverse mode but the forward
        mode can be useful in some cases to improve the performance. In addition, some
        control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
        only supports forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters can be found. This is
        only checked when `is_valid` is concrete (not a JAX tracer); before
        raising, the model is re-run with validation enabled so per-site
        warnings are shown.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    # Seed with a single key (the first one when `rng_key` is batched) and let
    # `init_strategy` supply the values at sample sites.
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy,
    )
    (
        inv_transforms,
        replay_model,
        has_enumerate_support,
        model_trace,
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace to model so
    # we don't need to generate again parameters of `numpyro.module`
    model = substitute(
        model,
        data={
            k: site["value"]
            for k, site in model_trace.items() if site["type"] in ["param"]
        },
    )
    # Constrained values of latent (non-observed), continuous sample sites.
    constrained_values = {
        k: v["value"]
        for k, v in model_trace.items() if v["type"] == "sample"
        and not v["is_observed"] and not v["fn"].is_discrete
    }

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        # Wrap the model for enumeration of discrete sites, unless already wrapped.
        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    # Normalize to a `functools.partial` so `.func` / `.keywords` can be
    # inspected (presumably calling a bare strategy returns such a partial).
    init_strategy = (init_strategy if isinstance(init_strategy, partial) else
                     init_strategy())
    if (init_strategy.func is init_to_value) and not replay_model:
        # Map user-supplied constrained init values to unconstrained space.
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms,
                                            init_values,
                                            invert=True)
        init_strategy = _init_to_unconstrained_value(
            values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms,
                                    constrained_values,
                                    invert=True)
    # Plate sites are fixed to the values recorded in `model_trace`.
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["plate"]
            },
        ),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            # Re-run the seeded model with validation enabled so parameter and
            # support problems are reported before raising.
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site["type"] == "sample":
                        with warnings.catch_warnings(record=True) as ws:
                            site["fn"]._validate_sample(site["value"])
                        if len(ws) > 0:
                            for w in ws:
                                # add site information to the warning message
                                w.message.args = ("Site {}: {}".format(
                                    site["name"],
                                    w.message.args[0]), ) + w.message.args[1:]
                                warnings.showwarning(
                                    w.message,
                                    w.category,
                                    w.filename,
                                    w.lineno,
                                    file=w.file,
                                    line=w.line,
                                )
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn,
                     postprocess_fn, model_trace)
Exemplo n.º 22
0
 def __init__(self, fn=None, scale=1.0):
     # Positivity can only be checked on concrete values; tracers pass through.
     is_concrete = not_jax_tracer(scale)
     if is_concrete and np.any(np.less_equal(scale, 0)):
         raise ValueError("'scale' argument should be positive.")
     self.scale = scale
     super().__init__(fn)
Exemplo n.º 23
0
def initialize_model(rng_key,
                     model,
                     init_strategy=init_to_uniform,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Build the potential function, the postprocess
    function and valid initial parameters for ``model``, via
    :func:`~numpyro.infer.util.get_potential_fn` and
    :func:`~numpyro.infer.util.find_valid_initial_params`.

    :param jax.random.PRNGKey rng_key: random number generator seed used to
        sample from the prior. The returned `init_params` have batch shape
        ``rng_key.shape[:-1]``; with a batch of keys, the first key is used to
        trace the model.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, `potential_fn` and `postprocess_fn`
        depend on the model arguments: called with `*model_args, **model_kwargs`
        they return the actual callables.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` with fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`);
        `param_info` is a `ParamInfo` namedtuple holding the initial
        unconstrained values, their potential energy and its gradient, and
        `postprocess_fn` maps unconstrained samples back to constrained values
        (also returning `deterministic` sites).
    :raises RuntimeError: if no valid initial parameters can be found (checked
        only when the validity flag is not a JAX tracer).
    """
    if model_kwargs is None:
        model_kwargs = {}
    # A single key is enough to trace the model structure.
    trace_key = rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]
    substituted_model = substitute(seed(model, trace_key),
                                   substitute_fn=init_strategy)
    inv_transforms, replay_model, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)

    # Constrained values of the latent, continuous sample sites.
    constrained_values = {}
    for site_name, site in model_trace.items():
        if (site['type'] == 'sample' and not site['is_observed']
                and not site['fn'].is_discrete):
            constrained_values[site_name] = site['value']

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    if not isinstance(init_strategy, partial):
        init_strategy = init_strategy()
    if init_strategy.func is init_to_value:
        # User-supplied init values are constrained; move them to unconstrained space.
        unconstrained_values = transform_fn(inv_transforms,
                                            init_strategy.keywords.get("values"),
                                            invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)

    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        model,
        init_strategy=init_strategy,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
    )

    if not_jax_tracer(is_valid) and device_get(~jnp.all(is_valid)):
        raise RuntimeError(
            "Cannot find valid initial parameters. Please check your model again."
        )
    param_info = ParamInfo(init_params, pe, grad)
    return ModelInfo(param_info, potential_fn, postprocess_fn, model_trace)
Exemplo n.º 24
0
 def __init__(self, fn=None, scale_factor=1.):
     """Store ``scale_factor`` as ``self.scale`` after checking positivity.

     :param fn: the callable passed on to the parent initializer (may be ``None``).
     :param scale_factor: a positive scalar or an array of positive values.
     :raises ValueError: if ``scale_factor`` is concrete (not a JAX tracer)
         and any of its entries is <= 0.
     """
     import numpy as np  # local import: the check must work for array-valued factors

     # Tracers carry no concrete value, so validation is only possible eagerly.
     # ``np.any`` avoids the "truth value of an array is ambiguous" error that
     # ``scale_factor <= 0`` raises for array-valued factors.
     if not_jax_tracer(scale_factor) and np.any(np.less_equal(scale_factor, 0)):
         raise ValueError("scale factor should be a positive number.")
     self.scale = scale_factor
     super(scale, self).__init__(fn)
Exemplo n.º 25
0
def initialize_model(rng_key,
                     model,
                     *model_args,
                     init_strategy=init_to_uniform(),
                     **model_kwargs):
    """
    Prepare a model with Pyro primitives for MCMC sampling.

    The model is traced once (seeded with ``rng_key``, or with its first key
    when ``rng_key`` is a batch of keys) to collect, for every latent
    ``sample`` site and every ``param`` site, a bijection onto that site's
    support. From these, three things are built: a potential-energy function
    of unconstrained values (negative joint density), a function mapping
    unconstrained values back into each site's support, and valid initial
    unconstrained parameters from which MCMC can start.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        `init_params` are values from the prior used to initiate MCMC,
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    :raises RuntimeError: if the validity flag is concrete and not every
        chain obtained valid initial parameters.
    """
    first_key = rng_key if rng_key.ndim == 1 else rng_key[0]
    prior_trace = trace(seed(model, first_key)).get_trace(*model_args,
                                                          **model_kwargs)

    site_values = {}
    site_transforms = {}
    # Set when some site's value has to be mapped back through the model
    # (transformed distributions / composed param constraints).
    replay_through_model = False
    for name, site in prior_trace.items():
        kind = site['type']
        if kind == 'sample':
            if site['is_observed']:
                continue
            if site['intermediates']:
                # Transformed distribution: work in the base distribution's space.
                site_values[name] = site['intermediates'][0][0]
                site_transforms[name] = biject_to(site['fn'].base_dist.support)
                replay_through_model = True
            else:
                site_values[name] = site['value']
                site_transforms[name] = biject_to(site['fn'].support)
        elif kind == 'param':
            transform = biject_to(site['kwargs'].pop('constraint', real))
            if isinstance(transform, ComposeTransform):
                # Keep only the first part of the composed transform.
                head = transform.parts[0]
                site_values[name] = head(transform.inv(site['value']))
                site_transforms[name] = head
                replay_through_model = True
            else:
                site_transforms[name] = transform
                site_values[name] = site['value']

    prototype_params = transform_fn(site_transforms,
                                    dict(site_values),
                                    invert=True)

    # NB: the unseeded `model` is used rather than the seeded one to prevent
    # unexpected behaviours (if any).
    potential_fn = jax.partial(potential_energy, model, model_args,
                               model_kwargs, site_transforms)
    # FIXME: why using seeded_model here triggers an error for funnel reparam example
    # if we use MCMC class (mcmc function works fine)
    constrain_fun = (
        jax.partial(constrain_fn, model, model_args, model_kwargs,
                    site_transforms)
        if replay_through_model
        else jax.partial(transform_fn, site_transforms)
    )

    def _init_one_chain(chain_key):
        return find_valid_initial_params(chain_key,
                                         model,
                                         *model_args,
                                         init_strategy=init_strategy,
                                         param_as_improper=True,
                                         prototype_params=prototype_params,
                                         **model_kwargs)

    if rng_key.ndim == 1:
        init_params, is_valid = _init_one_chain(rng_key)
    else:
        # One initialization search per key along the leading axis.
        init_params, is_valid = lax.map(_init_one_chain, rng_key)

    if not_jax_tracer(is_valid) and device_get(~np.all(is_valid)):
        raise RuntimeError(
            "Cannot find valid initial parameters. Please check your model again."
        )
    return init_params, potential_fn, constrain_fun
Exemplo n.º 26
0
def find_valid_initial_params(rng_key,
                              model,
                              *model_args,
                              init_strategy=init_to_uniform(),
                              param_as_improper=False,
                              prototype_params=None,
                              **model_kwargs):
    """
    Given a model with Pyro primitives, search for valid initial unconstrained
    parameters. This function also returns an `is_valid` flag to say whether
    the returned parameters are valid.

    Each attempt draws values with ``init_strategy``, maps them to
    unconstrained space, and evaluates the potential energy and its gradient
    there. An attempt is valid when the potential energy and every gradient
    entry are finite. At most 100 attempts are made in total.

    :param jax.random.PRNGKey rng_key: a single random number generator key;
        it is split once per attempt with ``random.split``. To initialize
        several chains, map this function over a batch of keys (e.g. with
        ``lax.map``).
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function; it is
        called with ``skip_param=not param_as_improper``.
    :param bool param_as_improper: a flag to decide whether to consider sites with
        `param` statement as sites with improper priors (if True, they are
        included in the returned parameters).
    :param dict prototype_params: optional unconstrained parameters that only
        provide the structure of the search-loop state (each attempt ignores
        the incoming parameters). When None, one attempt is made eagerly and,
        if its validity flag is concrete and True, it is returned right away
        without entering ``while_loop``.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `is_valid`).
    """
    init_strategy = jax.partial(init_strategy,
                                skip_param=not param_as_improper)

    def cond_fn(state):
        i, _, _, is_valid = state
        # Keep trying until a valid point is found or 100 attempts are used.
        return (i < 100) & (~is_valid)

    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize site values from
        # `init_strategy`. Use `block` to not record the sample primitives run
        # inside `init_strategy`.
        seeded_model = substitute(model,
                                  substitute_fn=block(
                                      seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args,
                                                    **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    # Transformed distribution: use the base distribution's space.
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                constraint = v['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    # Keep only the first part of the composed transform.
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(
                        transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']
        params = transform_fn(inv_transforms,
                              dict(constrained_values),
                              invert=True)
        potential_fn = jax.partial(potential_energy, model, model_args,
                                   model_kwargs, inv_transforms)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid

    if prototype_params is not None:
        init_state = (0, rng_key, prototype_params, False)
    else:
        init_state = body_fn((0, rng_key, None, None))
        _, _, first_params, first_valid = init_state
        # Skip the loop when the eager first attempt is concretely valid.
        if not_jax_tracer(first_valid) and device_get(first_valid):
            return first_params, first_valid

    _, _, init_params, is_valid = while_loop(cond_fn, body_fn, init_state)
    return init_params, is_valid
Exemplo n.º 27
0
def log_density(model,
                model_args,
                model_kwargs,
                params,
                skip_dist_transforms=False):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    Sample sites whose mask is concrete and entirely False contribute nothing
    and are skipped; all remaining sample sites are still accumulated.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :param bool skip_dist_transforms: whether to compute log probability of a site
        (if its prior is a transformed distribution) in its base distribution
        domain.
    :return: log of joint density and a corresponding model trace
    """
    # We skip transforms in
    #   + autoguide's model
    #   + hmc's model
    # We apply transforms in
    #   + autoguide's guide
    #   + svi's model + guide
    if skip_dist_transforms:
        model = substitute(model, base_param_map=params)
    else:
        model = substitute(model, param_map=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    for site in model_trace.values():
        if site['type'] != 'sample':
            continue
        value = site['value']
        intermediates = site['intermediates']
        mask = site['mask']
        scale = site['scale']
        # A concrete, all-False mask means this site contributes nothing: skip
        # only this site. (Returning here would drop every other site's term.)
        if not_jax_tracer(mask) and mask is not None and not np.any(mask):
            continue
        if intermediates:
            if skip_dist_transforms:
                log_prob = site['fn'].base_dist.log_prob(intermediates[0][0])
            else:
                log_prob = site['fn'].log_prob(value, intermediates)
        else:
            log_prob = site['fn'].log_prob(value)

        if scale is not None:
            log_prob = scale * log_prob
        # XXX: note that this may not work correctly for dynamic masks, provide
        # explicit jax.DeviceArray for masking.
        if mask is not None:
            log_prob = np.where(mask, log_prob, 0.)
        log_joint = log_joint + np.sum(log_prob)
    return log_joint, model_trace