def codomain(self):
    """
    Return the constraint obtained by pushing ``self.domain`` through this
    transform.

    Supported domains are ``constraints.real``, ``constraints.greater_than``,
    ``constraints.less_than`` and ``constraints.interval``. When ``self.scale``
    is concrete and entirely negative, the bounds are swapped / the
    inequality is flipped.

    :return: a constraint built from the transformed bound(s) of the domain.
    :raises NotImplementedError: for any other kind of domain.
    """
    domain = self.domain

    def scale_is_negative():
        # The sign of a tracer cannot be inspected, so
        # we suppose scale > 0 for any tracer
        return not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0))

    if domain is constraints.real:
        return constraints.real

    if isinstance(domain, constraints.greater_than):
        if scale_is_negative():
            return constraints.less_than(self(domain.lower_bound))
        return constraints.greater_than(self(domain.lower_bound))

    if isinstance(domain, constraints.less_than):
        if scale_is_negative():
            return constraints.greater_than(self(domain.upper_bound))
        return constraints.less_than(self(domain.upper_bound))

    if isinstance(domain, constraints.interval):
        if scale_is_negative():
            return constraints.interval(
                self(domain.upper_bound), self(domain.lower_bound)
            )
        return constraints.interval(
            self(domain.lower_bound), self(domain.upper_bound)
        )

    raise NotImplementedError
def _find_valid_params(rng_key, exit_early=False):
    """
    Search for valid initial parameters starting from ``prototype_params``.

    :param rng_key: random key stored in the loop state.
    :param bool exit_early: if True and ``rng_key`` is not a tracer, run
        ``body_fn`` once eagerly and return immediately when that first
        attempt is valid.
    :return: tuple ``((init_params, pe, z_grad), is_valid)``.
    """
    state = (0, rng_key, (prototype_params, 0., prototype_params), False)

    if exit_early and not_jax_tracer(rng_key):
        # Early return if valid params found. This is only helpful for single chain,
        # where we can avoid compiling body_fn in while_loop.
        state = body_fn(state)
        _, _, (params, energy, grad), valid = state
        if not_jax_tracer(valid) and device_get(valid):
            return (params, energy, grad), valid

    # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
    # even if the init_state is a valid result
    final_state = while_loop(cond_fn, body_fn, state)
    _, _, (params, energy, grad), valid = final_state
    return (params, energy, grad), valid
def __init__(self, fn=None, scale=1.):
    """
    Store ``scale`` on the handler after checking that it is positive.

    :param fn: callable forwarded to the parent initializer.
    :param scale: positive scalar or array. The check is skipped when
        ``scale`` is a JAX tracer, since its value cannot be inspected.
    :raises ValueError: if a concrete ``scale`` has any non-positive entry.
    """
    # Local import so the block does not depend on how the enclosing module
    # aliases NumPy.
    import numpy as np

    if not_jax_tracer(scale):
        # `scale <= 0` is ambiguous for array-valued scales; reduce the
        # elementwise comparison explicitly instead.
        if np.any(np.less_equal(scale, 0)):
            raise ValueError(
                "'scale' argument should be a positive number.")
    self.scale = scale
    super().__init__(fn)
def _validate_sample(self, value):
    """
    Evaluate ``self.support`` on ``value`` and warn about out-of-support entries.

    :param value: value whose membership in the support is checked.
    :return: the mask returned by ``self.support(value)``.
    """
    in_support = self.support(value)
    # A traced mask has no concrete values to look at, so the warning is
    # only considered for non-tracer masks.
    if not_jax_tracer(in_support) and not np.all(in_support):
        warnings.warn('Out-of-support values provided to log prob method. '
                      'The value argument should be within the support.')
    return in_support
def body_fn(wrapped_carry, x, prefix=None):
    """
    Apply the enclosing scope's function ``f`` to ``(carry, x)`` for one step
    and package the result together with the recorded trace.

    :param wrapped_carry: tuple ``(i, rng_key, carry)``: step counter, random
        key (may be ``None``) and the carry passed to ``f``.
    :param x: second argument passed to ``f``.
    :param prefix: not used in this body.
    :return: ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
    """
    i, rng_key, carry = wrapped_carry
    # `init` is only True when `i` is a concrete (non-traced) value equal to 0.
    init = True if (not_jax_tracer(i) and i == 0) else False
    # Split off a per-step subkey; without a key, `f` is run unseeded.
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

    seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
    # Re-apply every recorded condition/substitute handler, bound to the
    # current index `i` and `length`.
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == 'condition':
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == 'substitute':
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        # The initial step runs under the "_init" scope and returns an empty trace.
        with handlers.scope(prefix="_init"):
            new_carry, y = seeded_fn(carry, x)
            trace = {}
    else:
        with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

    # store shape of new_carry at a global variable
    nonlocal carry_shape_at_t1
    carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
    # make new_carry have the same shape as carry
    # FIXME: is this rigorous?
    new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                              new_carry, carry)
    return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs=None):
    """
    Build the initial sampler state for one chain or for a batch of chains.

    :param rng_key: random key. A key with ``ndim == 1`` selects the
        non-vectorized path; otherwise the init function is run under ``vmap``.
    :param int num_warmup: number of warmup steps forwarded to the init function.
    :param init_params: initial parameters. Required when a potential
        function is used; when a model is used and this is empty, valid
        parameters are searched for.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model. Defaults to an
        empty dict.
    :return: the initial state produced by ``self._init_fn``.
    :raises ValueError: if a potential function is used without ``init_params``.
    :raises RuntimeError: if no valid initial parameters can be found.
    """
    # Use a fresh dict per call rather than a shared mutable default.
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # non-vectorized
    if rng_key.ndim == 1:
        rng_key, rng_key_init_model = random.split(rng_key)
    # vectorized
    else:
        rng_key, rng_key_init_model = np.swapaxes(
            vmap(random.split)(rng_key), 0, 1)
        # we need only a single key for initializing PE / constraints fn
        rng_key_init_model = rng_key_init_model[0]

    if not self._init_fn:
        self._init_state(rng_key_init_model, model_args, model_kwargs)

    if self._potential_fn and init_params is None:
        raise ValueError(
            'Valid value of `init_params` must be provided with'
            ' `potential_fn`.')

    # Find valid initial params
    if self._model and not init_params:
        init_params, is_valid = find_valid_initial_params(
            rng_key,
            self._model,
            init_strategy=self._init_strategy,
            param_as_improper=True,
            model_args=model_args,
            model_kwargs=model_kwargs)
        # The validity flag can only be inspected when it is concrete.
        if not_jax_tracer(is_valid):
            if device_get(~np.all(is_valid)):
                raise RuntimeError("Cannot find valid initial parameters. "
                                   "Please check your model again.")

    def hmc_init_fn(chain_params, chain_key):
        # Attributes are read at call time, exactly as a lambda would do.
        return self._init_fn(
            chain_params,
            num_warmup=num_warmup,
            step_size=self._step_size,
            adapt_step_size=self._adapt_step_size,
            adapt_mass_matrix=self._adapt_mass_matrix,
            dense_mass=self._dense_mass,
            target_accept_prob=self._target_accept_prob,
            trajectory_length=self._trajectory_length,
            max_tree_depth=self._max_tree_depth,
            rng_key=chain_key,
            model_args=model_args,
            model_kwargs=model_kwargs,
        )

    if rng_key.ndim == 1:
        init_state = hmc_init_fn(init_params, rng_key)
    else:
        # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
        # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
        # wa_steps because those variables do not depend on traced args: init_params, rng_key.
        init_state = vmap(hmc_init_fn)(init_params, rng_key)
        # The sampler is vectorized over the state only.
        sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
        self._sample_fn = sample_fn
    return init_state
def model():
    """
    Model with a 3-dimensional latent ``x`` and a partially observed ``y``.

    When the sampled ``y`` is concrete, checks that its rows equal ``data``
    exactly where ``mask`` is set.
    """
    loc = numpyro.sample("x", dist.MultivariateNormal(np.zeros(3), np.eye(3)))
    with numpyro.plate("plate", len(data)):
        y = numpyro.sample(
            "y",
            dist.MultivariateNormal(loc, np.eye(3)),
            obs=data,
            obs_mask=mask,
        )
        if not_jax_tracer(y):
            # A row counts as observed only if all of its entries match.
            row_matches = (y == data).all(-1)
            assert (row_matches == mask).all()
def _find_valid_params(rng_key_):
    """
    Run ``body_fn`` once eagerly, then fall back to ``while_loop`` if needed.

    :param rng_key_: random key placed in the initial loop state.
    :return: tuple ``(init_params, is_valid)``.
    """
    first_state = body_fn((0, rng_key_, None, None))
    _, _, first_params, first_valid = first_state

    # Early return if valid params found.
    if not_jax_tracer(first_valid) and device_get(first_valid):
        return first_params, first_valid

    final_state = while_loop(cond_fn, body_fn, first_state)
    _, _, found_params, found_valid = final_state
    return found_params, found_valid
def enumerate_support(self, expand=True):
    """
    Enumerate the values between ``low`` and ``high``.

    :param bool expand: whether to broadcast the result over ``self.batch_shape``.
    :return: array of support values with the enumeration on the leading axis.
    :raises NotImplementedError: if ``low`` or ``high`` is a JAX tracer, or
        if either of them is not homogeneous.
    """
    bounds_are_concrete = not_jax_tracer(self.high) and not_jax_tracer(self.low)
    if not bounds_are_concrete:
        raise NotImplementedError(
            "Both `low` and `high` must not be a JAX Tracer.")

    # NB: the error can't be raised if inhomogeneous issue happens when tracing
    low_differs = np.amax(self.low) != self.low
    if np.any(low_differs):
        raise NotImplementedError(
            "Inhomogeneous `low` not supported by `enumerate_support`.")

    # NB: the error can't be raised if inhomogeneous issue happens when tracing
    high_differs = np.amax(self.high) != self.high
    if np.any(high_differs):
        raise NotImplementedError(
            "Inhomogeneous `high` not supported by `enumerate_support`.")

    num_values = np.amax(self.high - self.low) + 1
    column_shape = (-1, ) + (1, ) * len(self.batch_shape)
    values = (self.low + jnp.arange(num_values)).reshape(column_shape)
    if not expand:
        return values
    return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
def model():
    """
    Model with a scalar latent ``x`` and a partially observed ``y``.

    When the sampled ``y`` is concrete, checks that it equals ``data``
    exactly where ``mask`` is set.
    """
    loc = numpyro.sample("x", dist.Normal(0., 1.))
    with numpyro.plate("plate", len(data)):
        y = numpyro.sample("y", dist.Normal(loc, 1.), obs=data, obs_mask=mask)
        if not_jax_tracer(y):
            matches_data = y == data
            assert (matches_data == mask).all()
def _validate_sample(self, value):
    """
    Evaluate ``self.support`` on ``value`` and warn about out-of-support entries.

    The warning is issued with the stack level given by ``find_stack_level()``.

    :param value: value whose membership in the support is checked.
    :return: the mask returned by ``self.support(value)``.
    """
    in_support = self.support(value)
    # Only a concrete mask can be reduced to a Python boolean.
    if not_jax_tracer(in_support) and not np.all(in_support):
        warnings.warn(
            "Out-of-support values provided to log prob method. "
            "The value argument should be within the support.",
            stacklevel=find_stack_level(),
        )
    return in_support
def sample(self, key, sample_shape=()):
    """
    Draw probabilities from the internal Dirichlet, then sample a
    ``Multinomial`` with those probabilities.

    :param key: random key, split into one key per sampling stage.
    :param tuple sample_shape: shape passed to the Dirichlet sampler.
    :return: a sample from ``Multinomial(total_count, probs)``.
    :raises NotImplementedError: if ``total_count`` is concrete and not
        homogeneous.
    """
    dirichlet_key, multinomial_key = random.split(key)
    probs = self._dirichlet.sample(dirichlet_key, sample_shape)

    max_count = jnp.amax(self.total_count)
    # NB: the error can't be raised if inhomogeneous issue happens when tracing
    if not_jax_tracer(max_count) and jnp.amin(self.total_count) != max_count:
        raise NotImplementedError(
            "Inhomogeneous total count not supported"
            " by `sample`.")

    multinomial = Multinomial(max_count, probs)
    return multinomial.sample(multinomial_key)
def enumerate_support(self, expand=True):
    """
    Enumerate the values ``0, ..., total_count``.

    :param bool expand: whether to broadcast the result over ``self.batch_shape``.
    :return: array of support values with the enumeration on the leading axis.
    :raises NotImplementedError: if ``total_count`` is concrete and not
        homogeneous.
    """
    max_count = jnp.amax(self.total_count)
    # NB: the error can't be raised if inhomogeneous issue happens when tracing
    if not_jax_tracer(max_count) and jnp.amin(self.total_count) != max_count:
        raise NotImplementedError("Inhomogeneous total count not supported"
                                  " by `enumerate_support`.")

    column_shape = (-1,) + (1,) * len(self.batch_shape)
    values = jnp.arange(max_count + 1).reshape(column_shape)
    if not expand:
        return values
    return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
def initialize_model(rng_key, model, init_strategy=init_to_uniform(), dynamic_args=False,
                     model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls
    :func:`~numpyro.infer.util.get_potential_fn` and
    :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params`, `potential_fn`, `postprocess_fn`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `postprocess_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        the corresponding callables.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `postprocess_fn`),
        `init_params` are values from the prior used to initiate MCMC,
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters can be found.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # For a batch of keys, only the first one is used to build the potential.
    single_key = rng_key if rng_key.ndim == 1 else rng_key[0]
    potential_fn, postprocess_fn = get_potential_fn(
        single_key,
        model,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs)

    init_params, is_valid = find_valid_initial_params(
        rng_key,
        model,
        init_strategy=init_strategy,
        param_as_improper=True,
        model_args=model_args,
        model_kwargs=model_kwargs)

    # The validity flag can only be inspected when it is concrete.
    if not_jax_tracer(is_valid) and device_get(~np.all(is_valid)):
        raise RuntimeError(
            "Cannot find valid initial parameters. Please check your model again."
        )
    return init_params, potential_fn, postprocess_fn
def body_fn(wrapped_carry, x, prefix=None):
    """
    Apply the enclosing scope's function ``f`` to ``(carry, x)`` for one step
    and package the result together with the recorded trace.

    :param wrapped_carry: tuple ``(i, rng_key, carry)``: step counter, random
        key (may be ``None``) and the carry passed to ``f``.
    :param x: second argument passed to ``f``.
    :param prefix: not used in this body.
    :return: ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
    """
    i, rng_key, carry = wrapped_carry
    # `init` is only True when `i` is concrete (non-traced) and lies in
    # `range(unroll_steps)`.
    init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
    # Split off a per-step subkey; without a key, `f` is run unseeded.
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (
        None, None)

    # we need to tell unconstrained messenger in potential energy computation
    # that only the item at time `i` is needed when transforming
    fn = handlers.infer_config(
        f, config_fn=lambda msg: {"_scan_current_index": i})

    seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
    # Re-apply every recorded condition/substitute handler, bound to the
    # current index `i` and `length`.
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == "condition":
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == "substitute":
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        # handle the site names to match the pattern expected by the
        # Sarkka-Bilmes product: the prefix "_PREV_" is repeated
        # `unroll_steps - i` times
        with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
            new_carry, y = config_enumerate(seeded_fn)(carry, x)
            trace = {}
    else:
        # Like scan_wrapper, we collect the trace of scan's transition function
        # `seeded_fn` here. To put time dimension to the correct position, we need to
        # promote shapes to make `fn` and `value`
        # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
        # and value's batch_shape is (3,), then we promote shape of
        # value so that its batch shape is (1, 3)).
        # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
        # We don't promote `value` shape here because we need to store carry shape
        # at this step. If we reshape the `value` here, output carry might get wrong shape.
        with _promote_fn_shapes(), packed_trace() as trace:
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # store shape of new_carry at a global variable
        # (at most `history + 1` entries are recorded)
        if len(carry_shapes) < (history + 1):
            carry_shapes.append(
                [jnp.shape(x) for x in tree_flatten(new_carry)[0]])
        # make new_carry have the same shape as carry
        # FIXME: is this rigorous?
        new_carry = tree_multimap(
            lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
    return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
def get_transform(self, params):
    """
    Build a ``LowerCholeskyAffine`` transform from the Hessian of the loss
    at the location stored in ``params``.

    :param dict params: parameters containing the ``'<prefix>_loc'`` entry.
    :return: ``LowerCholeskyAffine(loc, scale_tril)`` where NaN entries of
        ``scale_tril`` have been replaced by zero.
    """
    def loss_at(location):
        # Evaluate the loss with only the location entry replaced.
        candidate = params.copy()
        candidate['{}_loc'.format(self.prefix)] = location
        return self._loss_fn(candidate)

    loc = params['{}_loc'.format(self.prefix)]
    scale_tril = cholesky_of_inverse(hessian(loss_at)(loc))

    # Only a concrete matrix can be inspected for NaNs eagerly.
    if not_jax_tracer(scale_tril) and np.any(np.isnan(scale_tril)):
        warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                      " samples from AutoLaplaceApproxmiation will be constant (equal to"
                      " the MAP point).")

    has_nan = jnp.isnan(scale_tril)
    scale_tril = jnp.where(has_nan, 0., scale_tril)
    return LowerCholeskyAffine(loc, scale_tril)
def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
    """
    Record the shapes of the distribution and optionally validate its
    parameters against ``self.arg_constraints``.

    :param tuple batch_shape: stored as ``self._batch_shape``.
    :param tuple event_shape: stored as ``self._event_shape``.
    :param validate_args: when not ``None``, overrides ``self._validate_args``.
    :raises ValueError: if validation is enabled and a parameter has
        concrete values that violate its constraint.
    """
    self._batch_shape = batch_shape
    self._event_shape = event_shape
    if validate_args is not None:
        self._validate_args = validate_args

    if self._validate_args:
        for name, constraint in self.arg_constraints.items():
            # Parameters that exist only as an unevaluated lazy property
            # are not checked.
            is_unset_lazy = name not in self.__dict__ and isinstance(
                getattr(type(self), name), lazy_property)
            # skip constraints that cannot be checked
            if is_unset_lazy or is_dependent(constraint):
                continue
            satisfied = np.all(constraint(getattr(self, name)))
            if not_jax_tracer(satisfied) and not satisfied:
                raise ValueError("The parameter {} has invalid values".format(name))

    super(Distribution, self).__init__()
def _get_transform(self, params):
    """
    Build a ``MultivariateAffineTransform`` from the Hessian of the ELBO loss
    at the location stored in ``params``.

    :param dict params: parameters containing the ``'<prefix>_loc'`` entry.
    :return: ``MultivariateAffineTransform(loc, scale_tril)`` where NaN
        entries of ``scale_tril`` have been replaced by zero.
    """
    def loss_at(location):
        candidate = params.copy()
        candidate['{}_loc'.format(self.prefix)] = location
        # we are doing maximum likelihood, so only require `num_particles=1` and an arbitrary rng_key.
        elbo = AutoContinuousELBO()
        return elbo.loss(random.PRNGKey(0), candidate, self.model, self,
                         *self._args, **self._kwargs)

    loc = params['{}_loc'.format(self.prefix)]
    scale_tril = cholesky_of_inverse(hessian(loss_at)(loc))

    # Only a concrete matrix can be inspected for NaNs eagerly.
    if not_jax_tracer(scale_tril) and np.any(np.isnan(scale_tril)):
        warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                      " samples from AutoLaplaceApproxmiation will be constant (equal to"
                      " the MAP point).")

    has_nan = np.isnan(scale_tril)
    scale_tril = np.where(has_nan, 0., scale_tril)
    return MultivariateAffineTransform(loc, scale_tril)
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    substituted = substitute(model, param_map=params)
    model_trace = trace(substituted).get_trace(*model_args, **model_kwargs)

    log_joint = jnp.array(0.)
    for site in model_trace.values():
        if site['type'] != 'sample':
            continue

        value, intermediates = site['value'], site['intermediates']
        mask, scale = site['mask'], site['scale']

        # Skip this site when all of its elements are masked
        if not_jax_tracer(mask) and mask is not None and not jnp.any(mask):
            continue

        site_fn = site['fn']
        if intermediates:
            log_prob = site_fn.log_prob(value, intermediates)
        else:
            log_prob = site_fn.log_prob(value)

        # Minor optimizations
        # XXX: note that this may not work correctly for dynamic masks, provide
        # explicit jax.DeviceArray for masking.
        if scale is not None:
            log_prob = scale * log_prob
        if mask is not None:
            log_prob = jnp.where(mask, log_prob, 0.)

        log_joint = log_joint + jnp.sum(log_prob)
    return log_joint, model_trace
def _get_codomain(bijector):
    """
    Map a bijector to the constraint describing its codomain, keyed on the
    bijector's class name.

    :param bijector: object whose class name selects the constraint.
    :return: the matching constraint; ``constraints.real`` for any class
        name that is not handled explicitly.
    """
    kind = bijector.__class__.__name__

    if kind == "Sigmoid":
        return constraints.interval(bijector.low, bijector.high)
    if kind == "Identity":
        return constraints.real
    if kind in ["Exp", "SoftPlus"]:
        return constraints.positive
    if kind == "GeneralizedPareto":
        loc = bijector.loc
        scale = bijector.scale
        concentration = bijector.concentration
        if not_jax_tracer(concentration) and jnp.all(concentration < 0):
            upper = loc + scale / jnp.abs(concentration)
            return constraints.interval(loc, upper)
        # XXX: here we suppose concentration > 0
        # which is not true in general, but should cover enough usage cases
        return constraints.greater_than(loc)
    if kind == "SoftmaxCentered":
        return constraints.simplex
    if kind == "Chain":
        # The codomain of a chain is that of its last bijector.
        return _get_codomain(bijector.bijectors[-1])
    return constraints.real
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a `ModelInfo` namedtuple with the fields (`param_info`,
    `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. By default, we use reverse mode but the forward
        mode can be useful in some cases to improve the performance. In addition, some
        control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
        only supports forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters can be found.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    # For a batch of keys, only the first key is used to seed the model here.
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy,
    )
    (
        inv_transforms,
        replay_model,
        has_enumerate_support,
        model_trace,
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace to model so
    # we don't need to generate again parameters of `numpyro.module`
    model = substitute(
        model,
        data={
            k: site["value"]
            for k, site in model_trace.items()
            if site["type"] in ["param"]
        },
    )
    # Values of the unobserved, non-discrete sample sites of the trace.
    constrained_values = {
        k: v["value"]
        for k, v in model_trace.items()
        if v["type"] == "sample" and not v["is_observed"] and not v["fn"].is_discrete
    }

    if has_enumerate_support:
        # Imported here so the funsor-based module is only needed on this path.
        from numpyro.contrib.funsor import config_enumerate, enum

        # Wrap the model in `enum` unless it already is an `enum` instance.
        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    # A strategy that is not already a `partial` is called to obtain one.
    init_strategy = (init_strategy if isinstance(init_strategy, partial) else
                     init_strategy())
    if (init_strategy.func is init_to_value) and not replay_model:
        # User-provided values are mapped through the inverse transforms
        # before being used as an initialization strategy.
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(
            values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    # Values of the plate sites from the trace are substituted into the model
    # passed to the search.
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["plate"]
            },
        ),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    # The validity flag can only be inspected when it is concrete.
    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            # Re-run the model with validation enabled so that per-site
            # warnings are shown before the error is raised.
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site["type"] == "sample":
                        with warnings.catch_warnings(record=True) as ws:
                            site["fn"]._validate_sample(site["value"])
                            if len(ws) > 0:
                                for w in ws:
                                    # add site information to the warning message
                                    w.message.args = ("Site {}: {}".format(
                                        site["name"], w.message.args[0]),
                                    ) + w.message.args[1:]
                                    warnings.showwarning(
                                        w.message,
                                        w.category,
                                        w.filename,
                                        w.lineno,
                                        file=w.file,
                                        line=w.line,
                                    )
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn,
                     postprocess_fn, model_trace)
def __init__(self, fn=None, scale=1.0):
    """
    Store ``scale`` on the handler after checking that it is positive.

    :param fn: callable forwarded to the parent initializer.
    :param scale: positive scalar or array. The check is skipped when
        ``scale`` is a JAX tracer.
    :raises ValueError: if a concrete ``scale`` has any non-positive entry.
    """
    has_nonpositive = not_jax_tracer(scale) and np.any(np.less_equal(scale, 0))
    if has_nonpositive:
        raise ValueError("'scale' argument should be positive.")
    self.scale = scale
    super().__init__(fn)
def initialize_model(rng_key, model, init_strategy=init_to_uniform, dynamic_args=False,
                     model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls
    :func:`~numpyro.infer.util.get_potential_fn` and
    :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a `ModelInfo` namedtuple with the fields (`param_info`,
    `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters can be found.
    """
    if model_kwargs is None:
        model_kwargs = {}

    # For a batch of keys, only the first key is used to seed the model here.
    seed_key = rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]
    substituted_model = substitute(seed(model, seed_key),
                                   substitute_fn=init_strategy)
    inv_transforms, replay_model, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)

    # Collect values of the unobserved, non-discrete sample sites.
    constrained_values = {}
    for name, site in model_trace.items():
        if site['type'] == 'sample' and not site['is_observed'] and not site['fn'].is_discrete:
            constrained_values[name] = site['value']

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs)

    # A strategy that is not already a `partial` is called to obtain one.
    if not isinstance(init_strategy, partial):
        init_strategy = init_strategy()
    if init_strategy.func is init_to_value:
        # User-provided values are mapped through the inverse transforms.
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)

    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        model,
        init_strategy=init_strategy,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params)

    # The validity flag can only be inspected when it is concrete.
    if not_jax_tracer(is_valid) and device_get(~jnp.all(is_valid)):
        raise RuntimeError(
            "Cannot find valid initial parameters. Please check your model again."
        )

    param_info = ParamInfo(init_params, pe, grad)
    return ModelInfo(param_info, potential_fn, postprocess_fn, model_trace)
def __init__(self, fn=None, scale_factor=1.):
    """
    Store ``scale_factor`` as ``self.scale`` after checking that it is positive.

    :param fn: callable forwarded to the parent initializer.
    :param scale_factor: positive scalar or array. The check is skipped when
        it is a JAX tracer, since its value cannot be inspected.
    :raises ValueError: if a concrete ``scale_factor`` has any non-positive entry.
    """
    # Local import so the block does not depend on how the enclosing module
    # aliases NumPy.
    import numpy as np

    if not_jax_tracer(scale_factor):
        # `scale_factor <= 0` is ambiguous for array-valued factors; reduce
        # the elementwise comparison explicitly instead.
        if np.any(np.less_equal(scale_factor, 0)):
            raise ValueError("scale factor should be a positive number.")
    self.scale = scale_factor
    super(scale, self).__init__(fn)
def initialize_model(rng_key, model, *model_args, init_strategy=init_to_uniform(), **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this also returns initial parameters
    sampled from the prior to initiate MCMC sampling and functions to
    transform unconstrained values at sample sites to constrained values
    within their respective support.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        `init_params` are values from the prior used to initiate MCMC,
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    :raises RuntimeError: if no valid initial parameters can be found.
    """
    # For a batch of keys, only the first key is used to seed the model here.
    seeded_model = seed(model, rng_key if rng_key.ndim == 1 else rng_key[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    has_transformed_dist = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                # Sites with intermediates use the first intermediate value
                # and the support of the base distribution.
                constrained_values[k] = v['intermediates'][0][0]
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            # NOTE: `pop` removes 'constraint' from the site's kwargs in the trace.
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                # Only the first part of a composed transform is kept as the
                # inverse transform of this site.
                base_transform = transform.parts[0]
                constrained_values[k] = base_transform(
                    transform.inv(v['value']))
                inv_transforms[k] = base_transform
                has_transformed_dist = True
            else:
                inv_transforms[k] = transform
                constrained_values[k] = v['value']

    prototype_params = transform_fn(
        inv_transforms, {k: v
                         for k, v in constrained_values.items()},
        invert=True)

    # NB: we use model instead of seeded_model to prevent unexpected behaviours (if any)
    potential_fn = jax.partial(potential_energy, model, model_args,
                               model_kwargs, inv_transforms)
    if has_transformed_dist:
        # FIXME: why using seeded_model here triggers an error for funnel reparam example
        # if we use MCMC class (mcmc function works fine)
        constrain_fun = jax.partial(constrain_fn, model, model_args,
                                    model_kwargs, inv_transforms)
    else:
        constrain_fun = jax.partial(transform_fn, inv_transforms)

    def single_chain_init(key):
        # Search for valid initial parameters for one chain.
        return find_valid_initial_params(key,
                                         model,
                                         *model_args,
                                         init_strategy=init_strategy,
                                         param_as_improper=True,
                                         prototype_params=prototype_params,
                                         **model_kwargs)

    # A single key initializes one chain; a batch of keys is mapped over.
    if rng_key.ndim == 1:
        init_params, is_valid = single_chain_init(rng_key)
    else:
        init_params, is_valid = lax.map(single_chain_init, rng_key)

    # The validity flag can only be inspected when it is concrete.
    if not_jax_tracer(is_valid):
        if device_get(~np.all(is_valid)):
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return init_params, potential_fn, constrain_fun
def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to_uniform(),
                              param_as_improper=False, prototype_params=None, **model_kwargs):
    """
    Given a model with Pyro primitives, returns an initial valid unconstrained
    parameters. This function also returns an `is_valid` flag to say whether the
    initial parameters are valid.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param bool param_as_improper: a flag to decide whether to consider sites with
        `param` statement as sites with improper priors.
    :param prototype_params: if not ``None``, used as the parameter entry of
        the initial loop state instead of running one eager attempt first.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `is_valid`).
    """
    init_strategy = jax.partial(init_strategy,
                                skip_param=not param_as_improper)

    def cond_fn(state):
        # Keep looping while fewer than 100 attempts were made and the last
        # one was not valid.
        i, _, _, is_valid = state
        return (i < 100) & (~is_valid)

    def body_fn(state):
        # One attempt: draw values with a fresh subkey, then check that the
        # potential energy and its gradient are finite.
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        # Use `block` to not record sample primitives in `init_loc_fn`.
        seeded_model = substitute(model,
                                  substitute_fn=block(
                                      seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args,
                                                    **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    # Sites with intermediates use the first intermediate
                    # value and the support of the base distribution.
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                # NOTE: `pop` removes 'constraint' from the site's kwargs in the trace.
                constraint = v['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    # Only the first part of a composed transform is kept.
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(
                        transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']
        params = transform_fn(inv_transforms,
                              {k: v
                               for k, v in constrained_values.items()},
                              invert=True)
        potential_fn = jax.partial(potential_energy, model, model_args,
                                   model_kwargs, inv_transforms)
        pe, param_grads = value_and_grad(potential_fn)(params)
        # Flatten the gradient pytree so a single finiteness check covers it.
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid

    if prototype_params is not None:
        init_state = (0, rng_key, prototype_params, False)
    else:
        # Make one eager attempt and return right away when it is valid.
        _, _, prototype_params, is_valid = init_state = body_fn(
            (0, rng_key, None, None))
        if not_jax_tracer(is_valid):
            if device_get(is_valid):
                return prototype_params, is_valid

    _, _, init_params, is_valid = while_loop(cond_fn, body_fn, init_state)
    return init_params, is_valid
def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :param bool skip_dist_transforms: whether to compute log probability of a site
        (if its prior is a transformed distribution) in its base distribution
        domain.
    :return: log of joint density and a corresponding model trace
    """
    # We skip transforms in
    #   + autoguide's model
    #   + hmc's model
    # We apply transforms in
    #   + autoguide's guide
    #   + svi's model + guide
    if skip_dist_transforms:
        model = substitute(model, base_param_map=params)
    else:
        model = substitute(model, param_map=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)

    # Start from a device scalar so the result has the same kind of type
    # even when every sample site is skipped.
    log_joint = jax.device_put(0.)
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            mask = site['mask']
            scale = site['scale']
            # A fully masked site contributes nothing. Skip only this site:
            # returning here would discard the contribution of every other
            # site in the trace.
            if not_jax_tracer(mask) and mask is not None and not np.any(mask):
                continue
            if intermediates:
                if skip_dist_transforms:
                    log_prob = site['fn'].base_dist.log_prob(
                        intermediates[0][0])
                else:
                    log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            # Minor optimizations
            # XXX: note that this may not work correctly for dynamic masks, provide
            # explicit jax.DeviceArray for masking.
            if mask is not None:
                if scale is not None:
                    log_prob = np.where(mask, scale * log_prob, 0.)
                else:
                    log_prob = np.where(mask, log_prob, 0.)
            else:
                if scale is not None:
                    log_prob = scale * log_prob
            log_prob = np.sum(log_prob)
            log_joint = log_joint + log_prob

    return log_joint, model_trace