Example #1
    def __init__(self,
                 loc=0.,
                 covariance_matrix=None,
                 precision_matrix=None,
                 scale_tril=None,
                 validate_args=None):
        if np.isscalar(loc):
            loc = np.expand_dims(loc, axis=-1)
        # temporarily append a new axis to loc
        loc = loc[..., np.newaxis]
        if covariance_matrix is not None:
            loc, self.covariance_matrix = promote_shapes(
                loc, covariance_matrix)
            self.scale_tril = np.linalg.cholesky(self.covariance_matrix)
        elif precision_matrix is not None:
            loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
            self.scale_tril = cholesky_of_inverse(self.precision_matrix)
        elif scale_tril is not None:
            loc, self.scale_tril = promote_shapes(loc, scale_tril)
        else:
            raise ValueError(
                'One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                ' must be specified.')
        batch_shape = lax.broadcast_shapes(
            np.shape(loc)[:-2],
            np.shape(self.scale_tril)[:-2])
        event_shape = np.shape(self.scale_tril)[-1:]
        self.loc = np.broadcast_to(np.squeeze(loc, axis=-1),
                                   batch_shape + event_shape)
        super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                                 event_shape=event_shape,
                                                 validate_args=validate_args)
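
All of these examples rely on a helper that turns a precision matrix into the Cholesky factor of the corresponding covariance. As a rough reference (an assumption about the helper's contract, not the library's actual implementation), it can be read as the following naive sketch; real implementations typically avoid the explicit inverse for numerical stability.

    import jax.numpy as jnp

    def naive_cholesky_of_inverse(matrix):
        # Cholesky factor of the inverse of a positive-definite matrix,
        # so that L @ L.T == inv(matrix).
        return jnp.linalg.cholesky(jnp.linalg.inv(matrix))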
Example #2
    def init_fn(z_info, rng_key, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
        """
        :param IntegratorState z_info: The initial integrator state.
        :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
        :param float step_size: Initial step size.
        :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
            the inverse mass matrix will be an identity matrix whose size is determined
            by the argument `mass_matrix_size`.
        :param int mass_matrix_size: Size of the mass matrix.
        :return: initial state of the adapt scheme.
        """
        rng_key, rng_key_ss = random.split(rng_key)
        if inverse_mass_matrix is None:
            assert mass_matrix_size is not None
            if dense_mass:
                inverse_mass_matrix = jnp.identity(mass_matrix_size)
            else:
                inverse_mass_matrix = jnp.ones(mass_matrix_size)
            # the identity (or all-ones diagonal) is its own square root
            mass_matrix_sqrt = inverse_mass_matrix
        else:
            if dense_mass:
                mass_matrix_sqrt = cholesky_of_inverse(inverse_mass_matrix)
            else:
                mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))

        if adapt_step_size:
            step_size = find_reasonable_step_size(step_size, inverse_mass_matrix, z_info, rng_key_ss)
        ss_state = ss_init(jnp.log(10 * step_size))

        mm_state = mm_init(inverse_mass_matrix.shape[-1])

        window_idx = 0
        return HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                             ss_state, mm_state, window_idx, rng_key)
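
The two branches above differ only in how the square root of the mass matrix is obtained: a dense inverse mass matrix goes through `cholesky_of_inverse`, while a diagonal one takes an elementwise reciprocal square root (`dense_mass`, `adapt_step_size`, and the `ss_init`/`mm_init` helpers are closed over from the enclosing scope). A tiny standalone check of that relationship, using illustrative values and the naive helper sketched after Example #1:

    import jax.numpy as jnp

    inv_mass_dense = jnp.array([[2.0, 0.3], [0.3, 1.0]])
    # dense branch: chol(inv(M^-1)) = chol(M)
    mass_sqrt_dense = jnp.linalg.cholesky(jnp.linalg.inv(inv_mass_dense))

    inv_mass_diag = jnp.array([2.0, 1.0])
    # diagonal branch: sqrt(1 / m^-1) = sqrt(m)
    mass_sqrt_diag = jnp.sqrt(jnp.reciprocal(inv_mass_diag))

    # In both cases the square root reconstructs the mass matrix,
    # i.e. the inverse of the inverse mass matrix.
    assert jnp.allclose(mass_sqrt_dense @ mass_sqrt_dense.T,
                        jnp.linalg.inv(inv_mass_dense), atol=1e-6)
    assert jnp.allclose(mass_sqrt_diag ** 2, 1.0 / inv_mass_diag, atol=1e-6)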
Example #3
    def get_transform(self, params):
        def loss_fn(z):
            params1 = params.copy()
            params1['{}_loc'.format(self.prefix)] = z
            return self._loss_fn(params1)

        loc = params['{}_loc'.format(self.prefix)]
        precision = hessian(loss_fn)(loc)
        scale_tril = cholesky_of_inverse(precision)
        if not_jax_tracer(scale_tril):
            if np.any(np.isnan(scale_tril)):
                warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                              " samples from AutoLaplaceApproxmiation will be constant (equal to"
                              " the MAP point).")
        scale_tril = jnp.where(jnp.isnan(scale_tril), 0., scale_tril)
        return LowerCholeskyAffine(loc, scale_tril)
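
The returned transform maps standard normal noise to the Laplace approximation of the posterior: a sample looks like `loc + scale_tril @ eps`. A minimal sketch of that sampling step (a hypothetical helper, assuming `loc` and `scale_tril` as produced above, not part of the source):

    import jax
    import jax.numpy as jnp

    def sample_laplace(rng_key, loc, scale_tril, num_samples=1000):
        # Push standard normal noise through the affine map
        # z = loc + scale_tril @ eps, i.e. N(loc, scale_tril @ scale_tril.T).
        eps = jax.random.normal(rng_key, (num_samples, loc.shape[-1]))
        return loc + eps @ scale_tril.T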
Example #4
    def _get_transform(self, params):
        def loss_fn(z):
            params1 = params.copy()
            params1['{}_loc'.format(self.prefix)] = z
            # we are doing maximum likelihood, so we only need `num_particles=1` and an arbitrary rng_key.
            return AutoContinuousELBO().loss(random.PRNGKey(0), params1, self.model, self,
                                             *self._args, **self._kwargs)

        loc = params['{}_loc'.format(self.prefix)]
        precision = hessian(loss_fn)(loc)
        scale_tril = cholesky_of_inverse(precision)
        if not_jax_tracer(scale_tril):
            if np.any(np.isnan(scale_tril)):
                warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                              " samples from AutoLaplaceApproxmiation will be constant (equal to"
                              " the MAP point).")
        scale_tril = np.where(np.isnan(scale_tril), 0., scale_tril)
        return MultivariateAffineTransform(loc, scale_tril)
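
The key step in both Laplace examples is treating the Hessian of the loss at the MAP point as a precision matrix. A toy check of that reading (illustrative values only, not from the source): for a Gaussian negative log density the Hessian equals the precision exactly, so inverting and factorizing it recovers the scale_tril of the approximating normal.

    import jax
    import jax.numpy as jnp

    precision_true = jnp.array([[4.0, 1.0], [1.0, 3.0]])
    neg_log_density = lambda z: 0.5 * z @ precision_true @ z

    # The Hessian of a quadratic is constant and equals the precision matrix.
    precision_est = jax.hessian(neg_log_density)(jnp.zeros(2))
    scale_tril = jnp.linalg.cholesky(jnp.linalg.inv(precision_est))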
Example #5
    def final_fn(state, regularize=False):
        """
        :param state: Current state of the scheme.
        :param bool regularize: Whether to adjust diagonal for numerical stability.
        :return: a pair of estimated covariance and the square root of precision.
        """
        mean, m2, n = state
        # XXX it is not necessary to check for the case n=1
        cov = m2 / (n - 1)
        if regularize:
            # Regularization from Stan
            scaled_cov = (n / (n + 5)) * cov
            shrinkage = 1e-3 * (5 / (n + 5))
            if diagonal:
                cov = scaled_cov + shrinkage
            else:
                cov = scaled_cov + shrinkage * np.identity(mean.shape[0])
        if np.ndim(cov) == 2:
            cov_inv_sqrt = cholesky_of_inverse(cov)
        else:
            cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt
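
For reference, the Stan-style regularization applied above can be restated as a standalone function: the raw covariance estimate is shrunk toward a small multiple of the identity, with a shrinkage weight of 5 / (n + 5) that vanishes as the sample count n grows. A self-contained sketch of the dense case only (an illustration, not the source's implementation):

    import jax.numpy as jnp

    def regularize_cov(cov, n):
        # Shrink the sample covariance toward 1e-3 * I; the weight on the
        # identity decays as the number of samples n increases.
        return (n / (n + 5.0)) * cov + 1e-3 * (5.0 / (n + 5.0)) * jnp.identity(cov.shape[-1])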