def Overlap(W_i, W_ev, N):
    """
        Computes the overlap between the evolved state and the initial state
        with the Onishi Formula. 
    """
    #N = np.int(np.shape(W_i)[0]/2)

    U_i = W_i[0:N, 0:N]
    V_i = W_i[N:2 * N, 0:N]

    # swap the zero mode between the U and V blocks (the max below picks the nonvanishing overlap)
    P1 = np.diag(np.concatenate((np.array([0]), np.ones((N - 1)))))
    P = np.concatenate((np.concatenate((P1, (np.identity(N) - P1)), axis=1),
                        np.concatenate(((np.identity(N) - P1), P1), axis=1)))
    w_i = np.dot(W_i, P)
    U_i2 = w_i[0:N, 0:N]
    V_i2 = w_i[N:2 * N, 0:N]

    U_ev = W_ev[0:N, 0:N]
    V_ev = W_ev[N:2 * N, 0:N]

    Overlap1 = np.abs(
        np.linalg.det(
            (np.dot(U_i.conj().T, U_ev) + np.dot(V_i.conj().T, V_ev))))
    Overlap2 = np.abs(
        np.linalg.det(
            (np.dot(U_i2.conj().T, U_ev) + np.dot(V_i2.conj().T, V_ev))))

    return np.max([Overlap1, Overlap2])
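
A quick sanity check (a sketch, assuming numpy is imported as np and Overlap is in scope): the overlap of a Bogoliubov state with itself is 1.

import numpy as np

N = 2
W = np.identity(2 * N, dtype=complex)  # trivial transformation: U = I, V = 0
print(Overlap(W, W, N))  # 1.0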
Example #2
 def final_fn(state, regularize=False):
     """
     :param state: Current state of the scheme.
     :param bool regularize: Whether to adjust diagonal for numerical stability.
     :return: a triple of estimated covariance, the square root of precision, and
         the inverse of that square root.
     """
     mean, m2, n = state
     # XXX it is not necessary to check for the case n=1
     cov = m2 / (n - 1)
     if regularize:
         # Regularization from Stan
         scaled_cov = (n / (n + 5)) * cov
         shrinkage = 1e-3 * (5 / (n + 5))
         if diagonal:  # `diagonal` is a free variable from the enclosing scope
             cov = scaled_cov + shrinkage
         else:
             cov = scaled_cov + shrinkage * jnp.identity(mean.shape[0])
     if jnp.ndim(cov) == 2:
         # copy the implementation of distributions.util.cholesky_of_inverse here
         tril_inv = jnp.swapaxes(
             jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2,
             -1)
         identity = jnp.identity(cov.shape[-1])
         cov_inv_sqrt = solve_triangular(tril_inv, identity, lower=True)
     else:
         tril_inv = jnp.sqrt(cov)
         cov_inv_sqrt = jnp.reciprocal(tril_inv)
     return cov, cov_inv_sqrt, tril_inv
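
A minimal usage sketch (assumes jax.numpy as jnp, jax.scipy.linalg.solve_triangular in scope, and diagonal = False in the enclosing scope):

import jax.numpy as jnp

x = jnp.array([[1., 2.], [2., 3.], [4., 1.]])
mean = x.mean(0)
m2 = (x - mean).T @ (x - mean)  # sum of outer products of deviations
cov, cov_inv_sqrt, tril_inv = final_fn((mean, m2, x.shape[0]))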
Example #3
    def initialize(self, n, m, x, T_0, K=None, Q=None, R=None):
        """
        Description: Initialize the dynamics of the model
        Args:
            n (float/numpy.ndarray): dimension of the state
            m (float/numpy.ndarray): dimension of the controls
            x (positive int): initial state
            T_0 (int): system identification time
            K  (float/numpy.ndarray): initial controller (optional)
            Q, R (float/numpy.ndarray): cost matrices (c(x, u) = x^T Q x + u^T R u)
        """
        self.initialized = True

        self.n, self.m = n, m
        self.x = x
        self.T_0 = T_0
        self.T = 0

        Q = np.identity(n) if Q is None else Q
        R = np.identity(m) if R is None else R

        if K is not None:  # the truth value of an array is ambiguous; test against None
            self.K = K
        else:
            # A and B are assumed to be defined in the enclosing scope
            X = scipy.linalg.solve_discrete_are(A, B, Q, R)
            self.K = np.linalg.inv(B.T @ X @ B + R) @ (B.T @ X @ A)
Example #4
    def __init__(self,
                 A: jnp.ndarray,
                 B: jnp.ndarray,
                 Q: jnp.ndarray = None,
                 R: jnp.ndarray = None) -> None:
        """
        Description: Initialize the infinite-time horizon LQR.
        Args:
            A (jnp.ndarray): state transition matrix
            B (jnp.ndarray): control (input) matrix
            Q (jnp.ndarray): state cost matrix (i.e. cost = x^TQx + u^TRu)
            R (jnp.ndarray): control cost matrix (i.e. cost = x^TQx + u^TRu)

        Returns:
            None
        """

        state_size, action_size = B.shape

        if Q is None:
            Q = jnp.identity(state_size, dtype=jnp.float32)

        if R is None:
            R = jnp.identity(action_size, dtype=jnp.float32)

        # solve the discrete-time algebraic Riccati equation
        X = dare(A, B, Q, R)

        # compute LQR gain
        self.K = jnp.linalg.inv(B.T @ X @ B + R) @ (B.T @ X @ A)
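
For reference, the same gain can be computed directly with SciPy's DARE solver (a standalone sketch with a hypothetical double-integrator system; `dare` above is assumed to solve the same equation):

import numpy as np
import scipy.linalg

A = np.array([[1.0, 0.1], [0.0, 1.0]])  # hypothetical state transition matrix
B = np.array([[0.0], [0.1]])            # hypothetical control matrix
Q, R = np.identity(2), np.identity(1)
X = scipy.linalg.solve_discrete_are(A, B, Q, R)
K = np.linalg.inv(B.T @ X @ B + R) @ (B.T @ X @ A)  # same gain formula as above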
Example #5
    def __init__(
        self,
        A: jnp.ndarray,
        B: jnp.ndarray,
        T: jnp.ndarray,
        Q: jnp.ndarray = None,
        R: jnp.ndarray = None,
    ) -> None:
        """
        Description: initializes the Hinf agent

        Args:
            A (jnp.ndarray): state transition matrix
            B (jnp.ndarray): control (input) matrix
            T (jnp.ndarray): time horizon
            Q (jnp.ndarray): state cost matrix
            R (jnp.ndarray): control cost matrix

        Returns:
            None
        """
        d_x, d_u = B.shape

        self.Q = Q if Q is not None else jnp.identity(d_x, dtype=jnp.float32)
        self.R = R if R is not None else jnp.identity(d_u, dtype=jnp.float32)

        # self.K, self.W = solve_hinf(A, B, Q, R, T)
        self.t = 0
Example #6
    def update_C_prior(self, params):
        sigma = params[0]  # unused in this method
        theta = params[1]

        # isotropic Gaussian prior: covariance (1/theta) I, precision theta I
        C_prior = np.identity(self.n_features) / theta
        C_prior_inv = np.identity(self.n_features) * theta
        return C_prior, C_prior_inv
Example #7
def _randaugment_inner_for_loop(_, in_args):
    """
    Loop body for RandAugment.
    Args:
        _: loop iteration (unused)
        in_args: loop body arguments

    Returns:
        updated loop arguments
    """
    (image, geometric_transforms, random_key, available_ops, op_probs,
     magnitude, cutout_const, translate_const, join_transforms,
     default_replace_value) = in_args
    random_keys = random.split(random_key, num=8)
    random_key = random_keys[0]  # keep for next iteration
    op_to_select = random.choice(random_keys[1], available_ops, p=op_probs)
    mask_value = jnp.where(default_replace_value > 0,
                           jnp.ones([image.shape[-1]]) * default_replace_value,
                           random.randint(random_keys[2],
                                          [image.shape[-1]],
                                          minval=-1, maxval=256))
    random_magnitude = random.uniform(random_keys[3], [], minval=0.,
                                      maxval=magnitude)
    cutout_mask = color_transforms.get_random_cutout_mask(
        random_keys[4],
        image.shape,
        cutout_const)

    translate_vals = (random.uniform(random_keys[5], [], minval=0.0,
                                     maxval=1.0) * translate_const,
                      random.uniform(random_keys[6], [], minval=0.0,
                                     maxval=1.0) * translate_const)
    negate = random.randint(random_keys[7], [], minval=0,
                            maxval=2).astype('bool')

    args = level_to_arg(cutout_mask, translate_vals, negate,
                        random_magnitude, mask_value)

    if DEBUG:
        print(op_to_select, args[op_to_select])

    image, geometric_transform = _apply_ops(image, args, op_to_select)

    image, geometric_transform = jax.lax.cond(
        jnp.logical_or(join_transforms, jnp.all(
            jnp.not_equal(geometric_transform, jnp.identity(4)))),
        lambda op: (op[0], op[1]),
        lambda op: (transforms.apply_transform(op[0],
                                               op[1],
                                               mask_value=mask_value),
                    jnp.identity(4)),
        (image, geometric_transform)
    )

    geometric_transforms = jnp.matmul(geometric_transforms, geometric_transform)
    return (image, geometric_transforms, random_key, available_ops, op_probs,
           magnitude, cutout_const, translate_const, join_transforms,
           default_replace_value)
Example #8
def test_mean_var(jax_dist, sp_dist, params):
    n = 20000 if jax_dist in [dist.LKJ, dist.LKJCholesky] else 200000
    d_jax = jax_dist(*params)
    k = random.PRNGKey(0)
    samples = d_jax.sample(k, sample_shape=(n,))
    # check with suitable scipy implementation if available
    if sp_dist and not _is_batched_multivariate(d_jax):
        d_sp = sp_dist(*params)
        try:
            sp_mean = d_sp.mean()
        except TypeError:  # mvn does not have .mean() method
            sp_mean = d_sp.mean
        # for multivariate distns try .cov first
        if d_jax.event_shape:
            try:
                sp_var = np.diag(d_sp.cov())
            except TypeError:  # mvn does not have .cov() method
                sp_var = np.diag(d_sp.cov)
            except AttributeError:
                sp_var = d_sp.var()
        else:
            sp_var = d_sp.var()
        assert_allclose(d_jax.mean, sp_mean, rtol=0.01, atol=1e-7)
        assert_allclose(d_jax.variance, sp_var, rtol=0.01, atol=1e-7)
        if np.all(np.isfinite(sp_mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(sp_var)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)
    elif jax_dist in [dist.LKJ, dist.LKJCholesky]:
        if jax_dist is dist.LKJCholesky:
            corr_samples = np.matmul(samples, np.swapaxes(samples, -2, -1))
        else:
            corr_samples = samples
        dimension, concentration, _ = params
        # marginal of off-diagonal entries
        marginal = dist.Beta(concentration + 0.5 * (dimension - 2),
                             concentration + 0.5 * (dimension - 2))
        # scale statistics due to linear mapping
        marginal_mean = 2 * marginal.mean - 1
        marginal_std = 2 * np.sqrt(marginal.variance)
        expected_mean = np.broadcast_to(np.reshape(marginal_mean, np.shape(marginal_mean) + (1, 1)),
                                        np.shape(marginal_mean) + d_jax.event_shape)
        expected_std = np.broadcast_to(np.reshape(marginal_std, np.shape(marginal_std) + (1, 1)),
                                       np.shape(marginal_std) + d_jax.event_shape)
        # diagonal elements of correlation matrices are 1
        expected_mean = expected_mean * (1 - np.identity(dimension)) + np.identity(dimension)
        expected_std = expected_std * (1 - np.identity(dimension))

        assert_allclose(np.mean(corr_samples, axis=0), expected_mean, atol=0.01)
        assert_allclose(np.std(corr_samples, axis=0), expected_std, atol=0.01)
    else:
        if np.all(np.isfinite(d_jax.mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(d_jax.variance)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)
Example #9
 def norm_project(self, y, A, c):
     """ Project y using norm A on the convex set bounded by c. """
     if np.any(np.isnan(y)) or np.all(np.absolute(y) <= c):
         return y
     y_shape = y.shape
     y_reshaped = np.ravel(y)
     dim_y = y_reshaped.shape[0]
     P = matrix(self.numpyify(A))
     q = matrix(self.numpyify(-np.dot(A, y_reshaped)))
     G = matrix(self.numpyify(np.append(np.identity(dim_y), -np.identity(dim_y), axis=0)), tc='d')
     h = matrix(self.numpyify(np.repeat(c, 2 * dim_y)), tc='d')
     solution = np.array(onp.array(solvers.qp(P, q, G, h)['x'])).squeeze().reshape(y_shape)
     return solution
Example #10
    def init_fn(z_info,
                rng_key,
                step_size=1.0,
                inverse_mass_matrix=None,
                mass_matrix_size=None):
        """
        :param IntegratorState z_info: The initial integrator state.
        :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
        :param float step_size: Initial step size.
        :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
            the inverse mass matrix will be an identity matrix whose size is
            given by the argument `mass_matrix_size`.
        :param int mass_matrix_size: Size of the mass matrix.
        :return: initial state of the adapt scheme.
        """
        rng_key, rng_key_ss = random.split(rng_key)
        if inverse_mass_matrix is None:
            assert mass_matrix_size is not None
            if dense_mass:
                inverse_mass_matrix = jnp.identity(mass_matrix_size)
            else:
                inverse_mass_matrix = jnp.ones(mass_matrix_size)
            mass_matrix_sqrt = mass_matrix_sqrt_inv = inverse_mass_matrix
        else:
            if dense_mass:
                mass_matrix_sqrt_inv = jnp.swapaxes(
                    jnp.linalg.cholesky(
                        inverse_mass_matrix[..., ::-1, ::-1])[..., ::-1, ::-1],
                    -2, -1)
                identity = jnp.identity(inverse_mass_matrix.shape[-1])
                mass_matrix_sqrt = solve_triangular(mass_matrix_sqrt_inv,
                                                    identity,
                                                    lower=True)
            else:
                mass_matrix_sqrt_inv = jnp.sqrt(inverse_mass_matrix)
                mass_matrix_sqrt = jnp.reciprocal(mass_matrix_sqrt_inv)

        if adapt_step_size:
            step_size = find_reasonable_step_size(step_size,
                                                  inverse_mass_matrix, z_info,
                                                  rng_key_ss)
        ss_state = ss_init(jnp.log(10 * step_size))

        mm_state = mm_init(inverse_mass_matrix.shape[-1])

        window_idx = 0
        return HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                             mass_matrix_sqrt_inv, ss_state, mm_state,
                             window_idx, rng_key)
Example #11
def sample(N,
           X=None,
           prior_var=1.,
           censorship_temp=10,
           distance_threshold=.5,
           distance_var=0.1):
  """Samples from the latent position model.

  Args:
    N: The number of latent positions.
    X: Optional positions; if None, they are drawn from a Gaussian prior.
    prior_var: The variance for the position prior if it is Gaussian.
    censorship_temp: The temperature of the sigmoid that defines the censorship
      distribution, is multiplied by the covariates before they are exponentiated.
    distance_threshold: The distance where the probability of censorship is 50%.
    distance_var: The variance of the conditional distribution of distance given
      the positions and censoring indicators.

  Returns:
    X: The positions.
    C: The censorship indicators, a 2D numpy array of 0s and 1s
      of length num_pos*(num_pos-1)/2. Is 0 if a distance was observed and 1 otherwise.
    D: The pairwise distance matrix, flattened. Distances that were censored
      are 0.
  """
  if X is None:
    X = onp.random.multivariate_normal(mean=[0.,0.], cov=prior_var*np.identity(2), size=[N])
  distances = oscipy.spatial.distance.pdist(X)
  censorship_probs = oscipy.special.expit(censorship_temp*(distances - distance_threshold))
  C = onp.random.binomial(1, censorship_probs)
  uncensored_D = onp.random.normal(loc=distances, scale=onp.sqrt(distance_var))
  D = uncensored_D*(1-C)
  return X, C, D
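
A minimal usage sketch (assumes the snippet's aliases: onp = numpy, oscipy = scipy with scipy.spatial and scipy.special imported, np = jax.numpy):

X, C, D = sample(N=10)
print(X.shape, C.shape, D.shape)  # (10, 2), (45,), (45,): 45 = 10*9/2 pairs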
Example #12
    def system_id(self):
        """ returns current estimate of hidden system dynamics """
        assert self.T > 0
        k = self.k if self.k else int(0.15 * self.T)

        # transform eta and x
        eta_np = np.array(self.eta)
        x_np = np.array(self.x_history)

        # prepare vectors and retrieve B
        scan_len = self.T - k - 1  # need extra -1 because we iterate over j=0,..,k
        N_j = np.array([
            np.dot(x_np[j + 1:j + 1 + scan_len].T, eta_np[:scan_len])
            for j in range(k + 1)
        ]) / scan_len
        B = N_j[0]  # np.dot(x_np[1:].T, eta_np[:-1]) / (self.T-1)
        # retrieve A
        C_0, C_1 = N_j[:-1], N_j[1:]
        C_inv = np.linalg.inv(
            np.tensordot(C_0, C_0, axes=([0, 2], [0, 2])) +
            self.gamma * np.identity(self.n))
        A = np.tensordot(C_1, C_0, axes=([0, 2], [0, 2])) @ C_inv + B @ self.K

        return (A, B)
Example #13
    def kspring(self, kspring):
        """
        Set new spring constant.

        Arguments
        ---------
        kspring:
            A scalar, array length `N` or symmetric `N x N` matrix. Restraining spring constant.
        """
        # Ensure array
        kspring = np.asarray(kspring)
        shape = kspring.shape
        N = self._N

        if len(shape) > 2:
            raise RuntimeError(
                f"Wrong kspring shape {shape} (expected scalar, 1D or 2D)")
        elif len(shape) == 2:
            if shape != (N, N):
                raise RuntimeError(
                    f"2D kspring with wrong shape, expected ({N}, {N}), got {shape}"
                )
            if not np.allclose(kspring, kspring.T):
                raise RuntimeError("Spring matrix is not symmetric")

            self._kspring = kspring
        else:  # len(shape) == 0 or len(shape) == 1
            n = kspring.size
            if n != N and n != 1:
                raise RuntimeError(
                    f"Wrong kspring size, expected 1 or {N}, got {n}.")

            self._kspring = np.identity(N) * kspring
        return self._kspring
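
The scalar/1D branch broadcasts kspring onto the diagonal; a standalone sketch of the same construction (assuming numpy as np):

N = 3
print(np.identity(N) * 2.0)                      # scalar -> 2 I
print(np.identity(N) * np.array([1., 2., 3.]))   # 1D -> diag(1, 2, 3)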
Example #14
    def __init__(self,
                 loc=jnp.array([0.]),
                 cov=jnp.identity(1),
                 sl=jnp.array([0.])):
        """
        Check dimensions compatibility and initialize the parameters.
        
        Arguments:
            - loc: location parameter
            - cov: covariance matrix
            - sl: slant parameter

        """
        assert loc.shape[-1] == cov.shape[-1]
        assert cov.shape[-1] == cov.shape[-2]
        assert loc.shape[-1] == sl.shape[-1]

        # assert (jnp.linalg.eigvalsh(cor - jnp.outer(skew, skew))).any() > 0

        self.loc = loc
        self.cov = cov
        self.slant = sl

        self.scale = jnp.sqrt(jnp.diag(cov))
        self.cor = jnp.einsum('i,ij,j->ij', 1. / self.scale, self.cov,
                              1. / self.scale)
        self.k = loc.shape[-1]
Example #15
def kl_div(mu: np.ndarray,
           A_chol: np.ndarray,
           sigma_prior: float
           ) -> float:
    """
    Computes the KL divergence between
    - the approximated posterior distribution N(mu, Sigma)
    - and the prior distribution on the parameters N(0, (sigma_prior ** 2) I)

    Instead of working directly with the covariance matrix Sigma, we will only deal with its Cholesky matrix A:
    It is the lower triangular matrix such that Sigma = A * A.T

    :param mu: mean of the posterior distribution approximated by variational inference
    :param A_chol: Cholesky matrix such that Sigma = A_chol * A_chol.T,
    where Sigma is the covariance of the posterior distribution approximated by variational inference.
    :param sigma_prior: standard deviation of the prior on the parameters. We put the following prior on the parameters:
    N(mean=0, variance=(sigma_prior**2) I)
    :return: the value of the KL divergence
    """
    covariance_post = np.dot(A_chol, A_chol.T)
    mean_post = mu.reshape(-1, 1)  # column vector, so the quadratic form below is (1, 1)
    size = len(mu)
    mean_prior = np.zeros(shape=(A_chol.shape[0], 1))
    covariance_prior = sigma_prior**2 * np.identity(A_chol.shape[0])

    cov_ratio = np.linalg.det(covariance_prior)/np.linalg.det(covariance_post) # get ratio of 2 covariance matrices

    trace_matrices = np.trace(np.dot(np.linalg.inv(covariance_prior), covariance_post))

    last_term = np.dot((mean_prior - mean_post).T, np.dot(np.linalg.inv(covariance_prior), (mean_prior - mean_post)))

    kl = 0.5*(np.log(cov_ratio) - size + trace_matrices + last_term)

    return kl[0][0]
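
Sanity check (assuming numpy is imported as np): when the posterior equals the prior, the divergence is zero.

mu = np.zeros(3)
A_chol = 2.0 * np.identity(3)  # Sigma = 4 I
print(kl_div(mu, A_chol, sigma_prior=2.0))  # ~0.0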
Example #16
 def _get_posterior(self):
     loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
     scale_tril = numpyro.param('{}_scale_tril'.format(self.prefix),
                                jnp.identity(self.latent_dim) *
                                self._init_scale,
                                constraint=constraints.lower_cholesky)
     return dist.MultivariateNormal(loc, scale_tril=scale_tril)
Example #17
    def _RuLSIF(self, x, y, alpha, s_sigma, s_lambda):
        if len(s_sigma) == 1 and len(s_lambda) == 1:
            sigma = s_sigma[0]
            lambda_ = s_lambda[0]
        else:
            optimized_params = self._optimize_sigma_lambda(
                x, y, alpha, s_sigma, s_lambda)
            sigma = optimized_params['sigma']
            lambda_ = optimized_params['lambda']

        phi_x = self.__kernel(r=x, sigma=sigma)
        phi_y = self.__kernel(r=y, sigma=sigma)
        H = (1. -
             alpha) * (np.dot(phi_y.T, phi_y) / self.__y_num_row) + alpha * (
                 np.dot(phi_x.T, phi_x) / self.__x_num_row)  # Phi* Phi
        h = np.average(phi_x, axis=0).T
        weights = np.linalg.solve(H + lambda_ * np.identity(self.__kernel_num),
                                  h).ravel()
        # clip negative weights to zero (JAX arrays are immutable, hence index_update)
        weights = jax.ops.index_update(weights, weights < 0, 0)

        self.__alpha = alpha
        self.__weights = weights
        self.__lambda = lambda_
        self.__sigma = sigma
        self.__phi_x = phi_x
        self.__phi_y = phi_y
Example #18
 def covariance_matrix(self):
     # TODO: find a better solution to create a diagonal matrix
     new_diag = self.cov_diag[..., np.newaxis] * np.identity(
         self.loc.shape[-1])
     covariance_matrix = new_diag + np.matmul(
         self.cov_factor, np.swapaxes(self.cov_factor, -1, -2))
     return covariance_matrix
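
The same diag-plus-low-rank construction, standalone (a sketch with hypothetical shapes, assuming numpy as np):

cov_factor = np.array([[1.0, 0.0], [0.5, 1.0], [0.0, 2.0]])  # (3, 2) low-rank factor
cov_diag = np.array([0.1, 0.2, 0.3])
cov = cov_diag[..., np.newaxis] * np.identity(3) + cov_factor @ cov_factor.T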
Example #19
def test_lqr(steps=10, show_plot=True):

    T = steps

    n = 1  # dimension of the state x
    m = 1  # control dimension
    noise_magnitude = 0.2
    noise_distribution = 'normal'

    environment_id = "LDS"
    environment_params = {
        'n': n,
        'm': m,
        'noise_magnitude': noise_magnitude,
        'noise_distribution': noise_distribution
    }

    C = np.identity(n + m)  # quadratic cost
    LQR_params = {'C': C, 'T': T}

    LQR_results, LQR_norms, LQR_avg_results = get_trajectory((environment_id, environment_params), \
                                                            ('LQR', LQR_params), T = T)
    if (show_plot):
        plt.plot(LQR_norms, label="LQR")
        plt.title("LQR on LDS")

    print("test_lqr passed")
    return
Example #20
 def enumerate_support(self, expand=True):
     n = self.event_shape[-1]
     values = jnp.identity(n, dtype=canonicalize_dtype(self.dtype))
     values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
     if expand:
         values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
     return values
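
For n = 3 and an empty batch shape this builds the one-hot support points (a sketch, assuming jax.numpy as jnp):

n = 3
print(jnp.identity(n))  # each row is one support value of the one-hot categorical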
Example #21
    def _onion(self, key, size):
        key_beta, key_normal = random.split(key)
        # Now we generate w term in Algorithm 3.2 of [1].
        beta_sample = self._beta.sample(key_beta, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        normal_sample = random.normal(key_normal,
                                      shape=size + self.batch_shape +
                                      (self.dimension *
                                       (self.dimension - 1) // 2, ))
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        u_hypershere = normal_sample / np.linalg.norm(
            normal_sample, axis=-1, keepdims=True)
        w = np.expand_dims(np.sqrt(beta_sample), axis=-1) * u_hypershere

        # put w into the off-diagonal triangular part
        cholesky = ops.index_add(
            np.zeros(size + self.batch_shape + self.event_shape),
            ops.index[..., 1:, :-1], w)
        # correct the diagonal
        # NB: we clip due to numerical precision
        diag = np.sqrt(np.clip(1 - np.sum(cholesky**2, axis=-1), a_min=0.))
        cholesky = cholesky + np.expand_dims(diag, axis=-1) * np.identity(
            self.dimension)
        return cholesky
Example #22
 def _get_transform(self):
     loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
     scale_tril = numpyro.param('{}_scale_tril'.format(self.prefix),
                                np.identity(self.latent_size) *
                                self._init_scale,
                                constraint=constraints.lower_cholesky)
     return MultivariateAffineTransform(loc, scale_tril)
Example #23
    def init_fn(z, rng, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
        """
        :param z: Initial position of the integrator.
        :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
        :param float step_size: Initial step size.
        :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
            the inverse mass matrix will be an identity matrix whose size is
            given by the argument `mass_matrix_size`.
        :param int mass_matrix_size: Size of the mass matrix.
        :return: initial state of the adapt scheme.
        """
        rng, rng_ss = random.split(rng)
        if inverse_mass_matrix is None:
            assert mass_matrix_size is not None
            if dense_mass:
                inverse_mass_matrix = np.identity(mass_matrix_size)
            else:
                inverse_mass_matrix = np.ones(mass_matrix_size)
            mass_matrix_sqrt = inverse_mass_matrix
        else:
            if dense_mass:
                mass_matrix_sqrt = cholesky_inverse(inverse_mass_matrix)
            else:
                mass_matrix_sqrt = np.sqrt(np.reciprocal(inverse_mass_matrix))

        if adapt_step_size:
            step_size = find_reasonable_step_size(inverse_mass_matrix, z, rng_ss, step_size)
        ss_state = ss_init(np.log(10 * step_size))

        mm_state = mm_init(inverse_mass_matrix.shape[-1])

        window_idx = 0
        return AdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                          ss_state, mm_state, window_idx, rng)
Example #24
 def inv_vec_transform(y):
     matrix = vec_to_tril_matrix(y, diagonal=-1)
     if constraint is constraints.corr_matrix:
         # fill the upper triangular part
         matrix = matrix + np.swapaxes(
             matrix, -2, -1) + np.identity(matrix.shape[-1])
     return transform.inv(matrix)
Example #25
def test_rot():
    """Tests the rot function and computation of its gradient"""
    ket0 = jnp.array([1, 0], dtype=jnp.complex64)
    evo = jnp.dot(rot([0.5, 0.7, 0.8]), ket0)
    back_evo = jnp.dot(rot([0.5, 0.7, 0.8]), evo)

    assert jnp.all(rot([0, 0, 0]) == jnp.identity(2, dtype="complex64"))
    assert not jnp.all(jnp.equal(evo, back_evo))
Example #26
 def gamma_f(p: jnp.ndarray) -> jnp.ndarray:
     # zeros5, zeros6, zeros51, and nn are defined in the enclosing scope
     gamma_f_arr = jnp.block([[zeros5],
                              [
                                  jnp.identity(nn),
                                  self.objective_object.final_weight(p),
                                  zeros6
                              ], [zeros51]])
     return gamma_f_arr
Example #27
 def _get_posterior(self):
     loc = numpyro.param("{}_loc".format(self.prefix), self._init_latent)
     scale_tril = numpyro.param(
         "{}_scale_tril".format(self.prefix),
         jnp.identity(self.latent_dim) * self._init_scale,
         constraint=self.scale_tril_constraint,
     )
     return dist.MultivariateNormal(loc, scale_tril=scale_tril)
Example #28
def step(t, state, dt, space, t_h, parameters):
    hamiltonian = init_hamiltonian(space, t + dt / 2, parameters, t_h)
    I = np.identity(len(state))
    H = 0.5j * hamiltonian * dt
    # get the new state by solving the Crank-Nicolson linear system (I + H) psi' = (I - H) psi
    state = np.linalg.solve(I + H, (I - H).dot(state))
    t += dt
    return t, state
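
Norm preservation is a quick check: for Hermitian H the Crank-Nicolson update is unitary. A standalone sketch with a fixed 2-level Hamiltonian (hypothetical, bypassing init_hamiltonian):

import numpy as np

H_mat = np.array([[1.0, 0.2], [0.2, -1.0]])  # Hermitian Hamiltonian
state = np.array([1.0 + 0j, 0.0])
dt = 0.01
I = np.identity(2)
H = 0.5j * H_mat * dt
state = np.linalg.solve(I + H, (I - H) @ state)
print(np.linalg.norm(state))  # stays 1.0 up to floating point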
Example #29
def cholesky_of_inverse(matrix):
    # This formulation only takes the inverse of a triangular matrix
    # which is more numerically stable.
    # Refer to:
    # https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
    tril_inv = jnp.swapaxes(jnp.linalg.cholesky(matrix[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
    identity = jnp.broadcast_to(jnp.identity(matrix.shape[-1]), tril_inv.shape)
    return solve_triangular(tril_inv, identity, lower=True)
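
Sanity check (a sketch, assuming jax.numpy as jnp and jax.scipy.linalg.solve_triangular in scope): for Sigma = L @ L.T, passing the precision matrix recovers L.

import jax.numpy as jnp

L = jnp.array([[2.0, 0.0], [1.0, 3.0]])
precision = jnp.linalg.inv(L @ L.T)
print(cholesky_of_inverse(precision))  # ~= L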
Example #30
def fill_cov(S, dim):
    m, _ = S.shape  # S is a square (m, m) block
    S_eye = jnp.identity(dim - m)
    S_fill = jnp.zeros((m, dim - m))
    S_fill_left = jnp.vstack((S_fill, S_eye))
    S_final = jnp.vstack((S, S_fill.T))
    S_final = jnp.hstack((S_final, S_fill_left))
    return S_final
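
fill_cov embeds the (m, m) block S into the top-left corner of a (dim, dim) matrix, padding the remainder with an identity block (assuming jax.numpy as jnp):

S = jnp.array([[2.0, 0.5], [0.5, 1.0]])
print(fill_cov(S, 4))
# [[2.  0.5 0.  0. ]
#  [0.5 1.  0.  0. ]
#  [0.  0.  1.  0. ]
#  [0.  0.  0.  1. ]]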