def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
             validate_args=None):
    """Build a multivariate normal from ``loc`` and one matrix parametrization.

    The matrix arguments are checked in the order ``covariance_matrix``,
    ``precision_matrix``, ``scale_tril``; the first one that is not ``None`` is
    used, and ``self.scale_tril`` is always populated from it.

    :param loc: mean; a scalar is expanded to a length-1 vector.
    :param covariance_matrix: covariance; its Cholesky factor becomes ``scale_tril``.
    :param precision_matrix: precision; ``cholesky_of_inverse`` of it becomes ``scale_tril``.
    :param scale_tril: scale factor used as-is (assumed lower-triangular).
    :param validate_args: forwarded to the base-class constructor.
    :raises ValueError: if none of the three matrix arguments is given.
    """
    if np.isscalar(loc):
        loc = np.expand_dims(loc, axis=-1)
    # Temporarily append a trailing axis to ``loc``; it is squeezed off again when
    # ``self.loc`` is assigned below.
    loc_col = loc[..., np.newaxis]
    if covariance_matrix is not None:
        loc_col, self.covariance_matrix = promote_shapes(loc_col, covariance_matrix)
        self.scale_tril = np.linalg.cholesky(self.covariance_matrix)
    elif precision_matrix is not None:
        loc_col, self.precision_matrix = promote_shapes(loc_col, precision_matrix)
        self.scale_tril = cholesky_of_inverse(self.precision_matrix)
    elif scale_tril is not None:
        loc_col, self.scale_tril = promote_shapes(loc_col, scale_tril)
    else:
        raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                         ' must be specified.')
    tril_shape = np.shape(self.scale_tril)
    # The last axis of scale_tril is the event dimension; everything before the two
    # trailing matrix axes (of both loc and scale_tril) is broadcast into the batch shape.
    event_shape = tril_shape[-1:]
    batch_shape = lax.broadcast_shapes(np.shape(loc_col)[:-2], tril_shape[:-2])
    self.loc = np.broadcast_to(np.squeeze(loc_col, axis=-1), batch_shape + event_shape)
    super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                             event_shape=event_shape,
                                             validate_args=validate_args)
def init_fn(z_info, rng_key, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
    """
    Build the initial warmup adaptation state (an ``HMCAdaptState``).

    Relies on ``dense_mass``, ``adapt_step_size``, ``find_reasonable_step_size``,
    ``ss_init``, ``mm_init`` and ``HMCAdaptState`` from the enclosing scope.

    :param IntegratorState z_info: The initial integrator state.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
    :param float step_size: Initial step size.
    :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``, an
        identity is used instead (an identity matrix when ``dense_mass`` is set,
        otherwise a vector of ones) whose size is given by `mass_matrix_size`.
    :param int mass_matrix_size: Size of the mass matrix. Required when
        `inverse_mass_matrix` is ``None``.
    :return: initial state of the adapt scheme.
    :raises ValueError: if both `inverse_mass_matrix` and `mass_matrix_size` are ``None``.
    """
    # Validate with a real exception: an ``assert`` is stripped under ``python -O`` and the
    # problem would then surface later as an obscure error from ``jnp.identity(None)``.
    if inverse_mass_matrix is None and mass_matrix_size is None:
        raise ValueError('`mass_matrix_size` must be specified when '
                         '`inverse_mass_matrix` is None.')
    rng_key, rng_key_ss = random.split(rng_key)
    if inverse_mass_matrix is None:
        if dense_mass:
            inverse_mass_matrix = jnp.identity(mass_matrix_size)
        else:
            inverse_mass_matrix = jnp.ones(mass_matrix_size)
        # The identity is its own inverse square root, so it is reused directly.
        mass_matrix_sqrt = inverse_mass_matrix
    else:
        if dense_mass:
            mass_matrix_sqrt = cholesky_of_inverse(inverse_mass_matrix)
        else:
            mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))

    if adapt_step_size:
        step_size = find_reasonable_step_size(step_size, inverse_mass_matrix, z_info, rng_key_ss)
    # The step-size adaptation state is seeded with log(10 * step_size).
    ss_state = ss_init(jnp.log(10 * step_size))
    mm_state = mm_init(inverse_mass_matrix.shape[-1])
    window_idx = 0
    return HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                         ss_state, mm_state, window_idx, rng_key)
def get_transform(self, params):
    """Return the Laplace-approximation transform around the current location.

    The Hessian of ``self._loss_fn`` with respect to the ``'<prefix>_loc'`` entry of
    ``params`` is treated as a precision matrix; ``cholesky_of_inverse`` of it gives the
    ``scale_tril`` of a ``LowerCholeskyAffine`` transform centred at that location.

    :param dict params: parameter values; must contain the key ``'<prefix>_loc'``.
    :return: ``LowerCholeskyAffine(loc, scale_tril)``.
    """
    loc_key = '{}_loc'.format(self.prefix)

    def loss_fn(z):
        # Evaluate the loss with only the location replaced; ``params`` itself is not mutated.
        params1 = params.copy()
        params1[loc_key] = z
        return self._loss_fn(params1)

    loc = params[loc_key]
    precision = hessian(loss_fn)(loc)
    scale_tril = cholesky_of_inverse(precision)
    # Only concrete (non-traced) values can be inspected on the host for the warning.
    if not_jax_tracer(scale_tril):
        if np.any(np.isnan(scale_tril)):
            warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                          " samples from AutoLaplaceApproximation will be constant (equal to"
                          " the MAP point).")
    # Zero out NaNs unconditionally so traced (e.g. jit-compiled) calls get the same fallback.
    scale_tril = jnp.where(jnp.isnan(scale_tril), 0., scale_tril)
    return LowerCholeskyAffine(loc, scale_tril)
def _get_transform(self, params):
    """Return the Laplace-approximation transform around the current location.

    The Hessian of the ``AutoContinuousELBO`` loss with respect to the
    ``'<prefix>_loc'`` entry of ``params`` is treated as a precision matrix;
    ``cholesky_of_inverse`` of it gives the ``scale_tril`` of a
    ``MultivariateAffineTransform`` centred at that location.

    :param dict params: parameter values; must contain the key ``'<prefix>_loc'``.
    :return: ``MultivariateAffineTransform(loc, scale_tril)``.
    """
    loc_key = '{}_loc'.format(self.prefix)

    def loss_fn(z):
        # Evaluate the loss with only the location replaced; ``params`` itself is not mutated.
        params1 = params.copy()
        params1[loc_key] = z
        # we are doing maximum likelihood, so only require `num_particles=1` and an arbitrary rng_key.
        return AutoContinuousELBO().loss(random.PRNGKey(0), params1, self.model, self,
                                         *self._args, **self._kwargs)

    loc = params[loc_key]
    precision = hessian(loss_fn)(loc)
    scale_tril = cholesky_of_inverse(precision)
    # Only concrete (non-traced) values can be inspected on the host for the warning.
    if not_jax_tracer(scale_tril):
        if np.any(np.isnan(scale_tril)):
            warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                          " samples from AutoLaplaceApproximation will be constant (equal to"
                          " the MAP point).")
    # Zero out NaNs unconditionally so traced (e.g. jit-compiled) calls get the same fallback.
    scale_tril = np.where(np.isnan(scale_tril), 0., scale_tril)
    return MultivariateAffineTransform(loc, scale_tril)
def final_fn(state, regularize=False):
    """
    Turn an accumulated ``(mean, m2, n)`` state into a covariance estimate.

    :param state: Current state of the scheme, unpacked as ``(mean, m2, n)``.
    :param bool regularize: Whether to shrink the estimate toward a small multiple of
        the identity for numerical stability (Stan's scheme); uses the enclosing
        scope's ``diagonal`` flag to pick the vector or matrix form.
    :return: a pair of estimated covariance and the square root of precision.
    """
    mean, m2, n = state
    # XXX it is not necessary to check for the case n=1
    cov = m2 / (n - 1)
    if regularize:
        # Regularization from Stan
        n_plus_5 = n + 5
        jitter = 1e-3 * (5 / n_plus_5)
        cov = (n / n_plus_5) * cov
        cov = cov + (jitter if diagonal else jitter * np.identity(mean.shape[0]))
    # A 1-d covariance is a diagonal: the inverse square root is element-wise.
    if np.ndim(cov) != 2:
        return cov, np.sqrt(np.reciprocal(cov))
    return cov, cholesky_of_inverse(cov)