Example #1
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     if len(shape) < 2:
         raise ValueError(
             "Orthogonal initializer requires at least a 2D shape.")
     n_rows = shape[self.axis]
     n_cols = np.prod(shape) // n_rows
     matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols,
                                                              n_rows)
     norm_dst = jax.random.normal(hooks.next_rng_key(), matrix_shape, dtype)
     q_mat, r_mat = jnp.linalg.qr(norm_dst)
     # Multiply by the sign of R's diagonal so that Q is uniformly
     # distributed (Haar measure).
     q_mat *= jnp.sign(jnp.diag(r_mat))
     if n_rows < n_cols:
         q_mat = q_mat.T
     q_mat = jnp.reshape(q_mat,
                         (n_rows, ) + tuple(np.delete(shape, self.axis)))
     q_mat = jnp.moveaxis(q_mat, 0, self.axis)
     return jax.lax.convert_element_type(self.scale, dtype) * q_mat
Example #2
    def expected_value_delta(params: transform.Params,
                             state: CvState) -> float:
        """"Expected value of second order expansion of `function` at dist mean."""
        del state
        mean_dist = params[0]
        var_dist = jnp.square(jnp.exp(params[1]))
        hessians = jax.hessian(function)(mean_dist)

        assert hessians.ndim == 2
        hess_diags = jnp.diag(hessians)
        assert hess_diags.ndim == 1

        # Trace(Hessian @ Sigma), using the fact that Sigma is diagonal.
        expected_second_order_term = jnp.sum(var_dist * hess_diags) / 2.

        expected_control_variate = function(mean_dist)
        expected_control_variate += expected_second_order_term
        return expected_control_variate
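Example #3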
def compute_normal_modes(simulation_parameters):
    """Returns the angular frequencies and eigenvectors for the normal modes."""
    m, k_wall, k_pair = (simulation_parameters["m"],
                         simulation_parameters["k_wall"],
                         simulation_parameters["k_pair"])
    num_trajectories = m.shape[0]

    # Construct coupling matrix.
    coupling_matrix = (-(k_wall + 2 * k_pair) * jnp.eye(num_trajectories) +
                       k_pair * jnp.ones((num_trajectories, num_trajectories)))
    coupling_matrix = jnp.diag(1 / m) @ coupling_matrix

    # Compute eigenvalues and eigenvectors.
    eigvals, eigvecs = jnp.linalg.eig(coupling_matrix)
    w = jnp.sqrt(-eigvals)
    w = jnp.real(w)
    eigvecs = jnp.real(eigvecs)
    return w, eigvecs
Example #4
def get_default_potential_initializer(dimension):
    """
    helper function to return a tuple of defaults

    arguments
        dimension : int
            dimension of the latent variable

    returns
        default_potential
        mean (at lambda=0)
        cov (at lambda=1)
        free_energy (-logZ_T)
    """
    potential = default_potential
    mu, cov = jnp.zeros(dimension), jnp.diag(jnp.ones(dimension)) * 0.5
    dG = 0.
    return potential, (mu, cov), dG
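A one-line usage sketch for the helper above (the dimension value is arbitrary); it assumes default_potential is defined elsewhere in the same module:

potential, (mu, cov), dG = get_default_potential_initializer(3)
# mu has shape (3,), cov is 0.5 * I_3, and dG is the free-energy offset (0.).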
Example #5
 def DCjacobian(self, r):
     """
     DC Jacobian (i.e. zero-frequency linear response) for
     linearization around state-vector v, leading to rate-vector r
     """
     if len(r.shape) < 2:
         Phi = self.gains_from_r(r)
          # Broadcasting so that the gain (Phi) varies with the 2nd
          # (presynaptic) neural index and does not depend on receptor type
          # or the post-synaptic (1st) neural index.
          return (
              -np.eye(self.num_rcpt * self.N) +
              np.tile(self.Wrcpt * Phi[None, :], (1, self.num_rcpt))
          )
     else:
         Phi = lambda rr: np.diag(self.gains_from_r(rr))
         return (np.array([
             np.kron(np.ones(
                 (1, self.num_rcpt)), np.dot(self.Wrcpt, Phi(r[:, cc]))) -
             np.eye(self.N * self.num_rcpt) for cc in range(r.shape[1])
         ]))
Example #6
def get_nondefault_potential_initializer(dimension):
    """
    helper function to return a tuple of nondefaults

    arguments
        dimension : int
            dimension of the latent variable

    returns
        nondefault_potential
        mean (at lambda=0)
        cov (at lambda=1)
        free_energy (-logZ_T)
    """
    potential = nondefault_gaussian_trap_potential
    mu, cov = jnp.zeros(dimension), jnp.diag(jnp.ones(dimension)) * 2.
    dG = -0.3465
    return potential, (mu, cov), dG
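Example #7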
 def plot_gp_test(params,
                  predictors,
                  target,
                  test_predictors,
                  conf_stds=2.0):
     mu, cov = predict(params, predictors, target, test_predictors)
     plt.matshow(cov)
     plt.title('covariance matrix')
     plt.legend()
     plt.show()
     std = np.sqrt(np.diag(cov))
     plt.plot(test_predictors, mu, color='green')
     plt.plot(predictors, target, 'k.')
      plt.fill_between(test_predictors.flatten(),
                       mu.flatten() - conf_stds * std,
                       mu.flatten() + conf_stds * std)
     plt.show()
Example #8
def test_kmeans():
    points = jnp.concatenate([random.normal(random.PRNGKey(0), shape=(30, 2)),
                              3. + random.normal(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)

    cluster_id, centers = kmeans(random.PRNGKey(0), points, jnp.ones(points.shape[0], dtype=jnp.bool_), K=2)

    mu, C = bounding_ellipsoid(points, jnp.ones(points.shape[0]))
    radii, rotation = ellipsoid_params(C)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = mu[:, None] + rotation @ jnp.diag(radii) @ jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    import pylab as plt
    mask = cluster_id == 0
    plt.scatter(points[mask, 0], points[mask, 1])
    mask = cluster_id == 1
    plt.scatter(points[mask, 0], points[mask, 1])
    plt.plot(x[0, :], x[1, :])
    plt.show()
Example #9
def test_initial_inverse_mass_matrix_ndarray(dense_mass):
    def model():
        numpyro.sample("z", dist.Normal(0, 1).expand([2]))
        numpyro.sample("x", dist.Normal(0, 1).expand([3]))

    expected_mm = jnp.arange(1, 6.0)
    kernel = NUTS(
        model,
        dense_mass=dense_mass,
        inverse_mass_matrix=expected_mm,
        adapt_mass_matrix=False,
    )
    mcmc = MCMC(kernel, num_warmup=1, num_samples=1)
    mcmc.run(random.PRNGKey(0))
    inverse_mass_matrix = mcmc.last_state.adapt_state.inverse_mass_matrix
    assert set(inverse_mass_matrix.keys()) == {("x", "z")}
    expected_mm = jnp.diag(expected_mm) if dense_mass else expected_mm
    assert_allclose(inverse_mass_matrix[("x", "z")], expected_mm)
Example #10
    def body(state):
        (count, err, u) = state
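        # One update of (what appears to be) Khachiyan's algorithm for the
        # minimum-volume enclosing ellipsoid: Q is the (N, D+1) matrix of
        # homogeneous points and u the current weight vector, both taken
        # from the enclosing scope.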
        V = Q.T @ jnp.diag(u) @ Q  # D+1, D+1
        # g[i] = Q[i,j].V^-1_jk.Q[i,k]
        # Equivalent to jnp.diag(Q @ jnp.linalg.solve(V, Q.T)) without
        # forming the full N x N matrix.
        g = vmap(lambda q: q @ jnp.linalg.solve(V, q))(Q)
        j = jnp.argmax(g)
        g_max = g[j]

        step_size = (g_max - D - 1) / ((D + 1) * (g_max - 1))
        # Shift weight onto coordinate j and rescale the remaining weights.
        new_u = jnp.where(
            jnp.arange(N) == j, u + step_size * (1. - u), u * (1. - step_size))
        new_err = jnp.linalg.norm(u - new_u)
        return (count + 1, new_err, new_u)
Example #11
 def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
     if len(shape) < 2:
         raise ValueError(
             "orthogonal initializer requires at least a 2D shape")
     n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
     matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows,
                                                              n_cols)
     A = random.normal(key, matrix_shape, dtype)
     Q, R = jnp.linalg.qr(A)
     diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
     Q *= diag_sign  # needed for a uniform distribution
     if n_rows < n_cols: Q = Q.T
     Q = jnp.reshape(
         Q,
         tuple(np.delete(shape, column_axis)) + (shape[column_axis], ))
     Q = jnp.moveaxis(Q, -1, column_axis)
     return scale * Q
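A short usage sketch, assuming this init is the closure returned by jax.nn.initializers.orthogonal (scale and column_axis come from that enclosing factory):

import jax
import jax.numpy as jnp

init = jax.nn.initializers.orthogonal()            # defaults: scale=1.0, column_axis=-1
W = init(jax.random.PRNGKey(0), (64, 32), jnp.float32)
# For a tall matrix the columns are orthonormal, so W.T @ W is close to the identity.
print(jnp.allclose(W.T @ W, jnp.eye(32), atol=1e-4))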
Example #12
def reciprocal_energy(conf, box, charges, alpha, kmax):

    assert kmax > 0
    assert box is not None
    assert alpha > 0

    recipBoxSize = (2 * np.pi) / np.diag(box)

    mg = []
    lowry = 0
    lowrz = 1

    numRx, numRy, numRz = kmax, kmax, kmax

    for rx in range(numRx):
        for ry in range(lowry, numRy):
            for rz in range(lowrz, numRz):
                mg.append([rx, ry, rz])
                lowrz = 1 - numRz
            lowry = 1 - numRy

    mg = np.array(onp.array(mg))

    # lattice vectors
    ki = np.expand_dims(recipBoxSize, axis=0) * mg  # [nk, 3]
    ri = np.expand_dims(conf, axis=0)  # [1, N, 3]
    rik = np.sum(np.multiply(ri, np.expand_dims(ki, axis=1)),
                 axis=-1)  # [nk, N]
    real = np.cos(rik)
    imag = np.sin(rik)
    eikr = real + 1j * imag  # [nk, N]
    qi = charges + 0j
    Sk = np.sum(qi * eikr, axis=-1)  # [nk]
    n2Sk = np.power(np.abs(Sk), 2)
    k2 = np.sum(np.multiply(ki, ki), axis=-1)  # [nk]
    factorEwald = -1 / (4 * alpha * alpha)
    ak = np.exp(k2 * factorEwald) / k2  # [nk]
    nrg = np.sum(ak * n2Sk)
    # the following volume calculation assumes the reduced PBC convention consistent
    # with that of OpenMM
    recipCoeff = (ONE_4PI_EPS0 * 4 * np.pi) / (box[0][0] * box[1][1] *
                                               box[2][2])

    return recipCoeff * nrg
Example #13
  def test_swap_mc_jammed(self, dtype):
    key = random.PRNGKey(0)

    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    space_fn = space.periodic(state.box[0, 0])
    displacement_fn, shift_fn = space_fn

    sigma = np.diag(state.sigma)[state.species]

    energy_fn = lambda dr, sigma: energy.soft_sphere(dr, sigma=sigma)
    neighbor_fn = partition.neighbor_list(displacement_fn,
                                          state.box[0, 0],
                                          np.max(sigma) + 0.1,
                                          dr_threshold=0.5)

    kT = 1e-2
    t_md = 0.1
    N_swap = 10
    init_fn, apply_fn = simulate.hybrid_swap_mc(space_fn,
                                                energy_fn,
                                                neighbor_fn,
                                                1e-3,
                                                kT,
                                                t_md,
                                                N_swap)
    state = init_fn(key, state.real_position, sigma)

    Ts = np.zeros((DYNAMICS_STEPS,))

    def step_fn(i, state_and_temp):
      state, temp = state_and_temp
      state = apply_fn(state)
      temp = temp.at[i].set(quantity.temperature(state.md.velocity))
      return state, temp

    state, Ts = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Ts))

    tol = 5e-4
    self.assertAllClose(Ts[10:],
                        kT * np.ones((DYNAMICS_STEPS - 10)),
                        rtol=5e-1,
                        atol=5e-3)
    self.assertAllClose(np.mean(Ts[10:]), kT, rtol=tol, atol=tol)
    self.assertTrue(not np.all(state.sigma == sigma))
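Example #14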
 def posterior_sample(self, key, sample, X_star, **kwargs):
     # Fetch training data
     batch = kwargs['batch']
     XL, XH = batch['XL'], batch['XH']
     NL, NH = XL.shape[0], XH.shape[0]
     # Fetch params
     var_L = sample['kernel_var_L']
     var_H = sample['kernel_var_H']
     length_L = sample['kernel_length_L']
     length_H = sample['kernel_length_H']
     beta_L = sample['beta_L']
     beta_H = sample['beta_H']
     eta_L = sample['eta_L']
     eta_H = sample['eta_H']
     rho = sample['rho']
     theta_L = np.concatenate([var_L, length_L])
     theta_H = np.concatenate([var_H, length_H])
     beta = np.concatenate([beta_L * np.ones(NL), beta_H * np.ones(NH)])
     eta = np.concatenate([eta_L, eta_H])
     # Compute kernels
     k_pp = rho**2 * self.kernel(X_star, X_star, theta_L) + \
                     self.kernel(X_star, X_star, theta_H) + \
                     np.eye(X_star.shape[0])*1e-8
     psi1 = rho * self.kernel(X_star, XL, theta_L)
     psi2 = rho**2 * self.kernel(X_star, XH, theta_L) + \
                     self.kernel(X_star, XH, theta_H)
     k_pX = np.hstack((psi1, psi2))
     # Compute K_xx
     K_LL = self.kernel(XL, XL, theta_L) + np.eye(NL) * 1e-8
     K_LH = rho * self.kernel(XL, XH, theta_L)
     K_HH = rho**2 * self.kernel(XH, XH, theta_L) + \
                     self.kernel(XH, XH, theta_H) + np.eye(NH)*1e-8
     K_xx = np.vstack((np.hstack((K_LL, K_LH)), np.hstack((K_LH.T, K_HH))))
     L = cholesky(K_xx, lower=True)
     # Sample latent function
     f = np.matmul(L, eta) + beta
     tmp_1 = solve_triangular(L.T, solve_triangular(L, f, lower=True))
     tmp_2 = solve_triangular(L.T, solve_triangular(L, k_pX.T, lower=True))
     # Compute predictive mean
     mu = np.matmul(k_pX, tmp_1)
     cov = k_pp - np.matmul(k_pX, tmp_2)
     std = np.sqrt(np.clip(np.diag(cov), a_min=0.))
     sample = mu + std * random.normal(key, mu.shape)
     return mu, sample
Example #15
    def test_not_pd(self):
        z = jnp.array([4.2, -3.7])
        init_hessian_sqrt_diag = jnp.diag(self.hess_sqrt)

        ps, qs, us, ts = utils.bfgs_sqrt_pqut(self.vals, self.grads,
                                              init_hessian_sqrt_diag)

        # Test transpose prod
        hess_inv_sqrt_t_z = utils.bfgs_sqrt_transpose_prod(
            ps, qs, z, 1 / init_hessian_sqrt_diag)
        z2 = z.copy()
        for i in range(len(ps) - 1, -1, -1):
            z2 = (jnp.eye(2) - jnp.outer(qs[i], ps[i])) @ z2
        z2 = 1 / init_hessian_sqrt_diag * z2
        npt.assert_array_almost_equal(hess_inv_sqrt_t_z, z2, decimal=3)

        # Test prod
        hess_inv_z = utils.bfgs_sqrt_prod(ps, qs, hess_inv_sqrt_t_z,
                                          1 / init_hessian_sqrt_diag)
        z3 = 1 / init_hessian_sqrt_diag * z2.copy()
        for i in range(len(ps)):
            z3 = (jnp.eye(2) - jnp.outer(ps[i], qs[i])) @ z3
        npt.assert_array_almost_equal(hess_inv_z, z3, decimal=3)

        # Test accurate hessian inverse mvp
        npt.assert_array_almost_equal(self.hess_inv @ z, hess_inv_z, decimal=3)

        # Test accurate hessian mvp
        hess_sqrt_t_z = utils.bfgs_sqrt_transpose_prod(us, ts, z,
                                                       init_hessian_sqrt_diag)
        hess_z = utils.bfgs_sqrt_prod(us, ts, hess_sqrt_t_z,
                                      init_hessian_sqrt_diag)
        npt.assert_array_almost_equal(self.hess @ z, hess_z, decimal=3)

        # Test determinant
        hess_inv_det = utils.bfgs_sqrt_det(ps, qs,
                                           1 / init_hessian_sqrt_diag)**2
        npt.assert_almost_equal(hess_inv_det,
                                jnp.linalg.det(self.hess_inv),
                                decimal=3)
        hess_det = utils.bfgs_sqrt_det(us, ts, init_hessian_sqrt_diag)**2
        npt.assert_almost_equal(hess_det, jnp.linalg.det(self.hess), decimal=3)
Example #16
    def bbox_differentiable(self, outputs, target_boxes, box_loss_mask,
                            num_boxes):
        box_loss_mask = box_loss_mask[:, :, jnp.newaxis]
        src_boxes = outputs["pred_boxes"]
        loss_bbox = jnp.abs(src_boxes - target_boxes) * box_loss_mask
        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes
        B, N, D = src_boxes.shape
        src_boxes = jnp.reshape(src_boxes, (B * N, D))
        target_boxes = jnp.reshape(target_boxes, (B * N, D))

        loss_giou = 1 - jnp.diag(
            generalized_box_iou(center_to_corners_format(src_boxes),
                                center_to_corners_format(target_boxes)))
        loss_giou = jnp.reshape(loss_giou, (B, N, 1)) * box_loss_mask

        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses
Example #17
def main(_):
    train_ds, test_ds = get_datasets(random.PRNGKey(123))
    trained_model = train(train_ds)

    if FLAGS.plot:
        import matplotlib.pyplot as plt

        obs_noise_scale = jax.nn.softplus(
            trained_model.params['observation_model']
            ['observation_noise_scale'])

        def learned_kernel_fn(x1, x2):
            return kernels.RBFKernelProvider.call(
                trained_model.params['kernel_fn'], x1)(x1, x2)

        def learned_mean_fn(x):
            return nn.Dense.call(trained_model.params['linear_mean_fn'],
                                 x,
                                 features=1)[:, 0]

        # prior GP model at learned model parameters
        fitted_gp = gaussian_processes.GaussianProcess(
            train_ds['index_points'], learned_mean_fn, learned_kernel_fn, 1e-4)
        posterior_gp = fitted_gp.posterior_gp(train_ds['y'],
                                              test_ds['index_points'],
                                              obs_noise_scale**2)

        pred_f_mean = posterior_gp.mean_function(test_ds['index_points'])
        pred_f_var = jnp.diag(
            posterior_gp.kernel_function(test_ds['index_points'],
                                         test_ds['index_points']))

        fig, ax = plt.subplots()
        ax.fill_between(test_ds['index_points'][:, 0],
                        pred_f_mean - 2 * jnp.sqrt(pred_f_var),
                        pred_f_mean + 2 * jnp.sqrt(pred_f_var),
                        alpha=0.5)

        ax.plot(test_ds['index_points'][:, 0],
                posterior_gp.mean_function(test_ds['index_points']), '-')
        ax.plot(train_ds['index_points'], train_ds['y'], 'ks')

        plt.show()
Example #18
    def predict(self, X, return_std=False):
        # compute kernels between train and test data, etc.
        k_pp = self.kernel(X, X, **self.kernel_params)
        k_pX = self.kernel(X, self.X_train, **self.kernel_params, jitter=0.0)

        # compute posterior covariance
        K = k_pp - k_pX @ linalg.cho_solve(self.L, k_pX.T)

        # compute posterior mean
        mean = k_pX @ self.alpha

        # we return both the mean function and the standard deviation
        if return_std:
            return (
                (mean * self.y_std) + self.y_mean,
                jnp.sqrt(jnp.diag(K * self.y_std**2)),
            )
        else:
            return (mean * self.y_std) + self.y_mean, K * self.y_std**2
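Example #19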
 def _jitchol2_3d(x, jitter):
     """
     :param x: (B,N,N)
     """
     # Scale jitter to the matrices at hand
     # (B,1,1)
     jitter = (
         1.0e-12
         + jitter * np.array([np.mean(np.diag(xi)) for xi in x])[:, None, None]
     )
      lx = cholesky(x)
      # If the factorization produced NaNs, redo it with jitter added to the
      # diagonal; otherwise return the factor computed above.
      return lax.cond(
          np.any(np.isnan(lx)),
          lambda _: cholesky(x + jitter * np.eye(x.shape[1])[None]),
          lambda _: lx,
          None,
      )
Example #20
 def merge_factors(self, X, Xs=None, diag=False):
     factor_list = []
     for factor in self.factor_list:
         # make sure diag=True is handled properly
         if isinstance(factor, Covariance):
             factor_list.append(factor(X, Xs, diag))
         elif isinstance(factor, np.ndarray):
             if np.ndim(factor) == 2 and diag:
                 factor_list.append(np.diag(factor))
             else:
                 factor_list.append(factor)
         elif isinstance(factor, jnp.DeviceArray):
             if factor.ndim == 2 and diag:
                 factor_list.append(jnp.diag(factor))
             else:
                 factor_list.append(factor)
         else:
             factor_list.append(factor)
     return factor_list
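Example #21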
def predict_f(Y_obs, K, uncert):
    """
    Predictive mu and sigma with outliers removed.

    Args:
        Y_obs: [N]
        K: [N,N]
        uncert: [N] outliers encoded with inf

    Returns:
        mu [N]
        sigma [N]
    """
    # (K + sigma.sigma)^-1 = sigma^-1.(sigma^-1.K.sigma^-1 + I)^-1.sigma^-1
    C = K / (uncert[:, None] * uncert[None, :]) + jnp.eye(K.shape[0])
    JT = jnp.linalg.solve(C, K / uncert[:, None])
    mu_star = JT.T @ (Y_obs / uncert)
    sigma2_star = jnp.diag(K - JT.T @ (K / uncert[:, None]))
    return mu_star, sigma2_star
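A toy usage sketch for predict_f (all values hypothetical), marking one observation as an outlier through an infinite uncertainty:

import jax.numpy as jnp

x = jnp.linspace(0., 1., 4)
K = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.1 ** 2)   # toy RBF kernel
Y_obs = jnp.sin(2 * jnp.pi * x)
uncert = jnp.array([0.1, 0.1, jnp.inf, 0.1])                     # third point is an outlier
mu, sigma2 = predict_f(Y_obs, K, uncert)                         # per-point posterior mean/variance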
Example #22
def check_termination(values, params, memory):
    """ Check whether to terminate CMA-ES loop. """
    dC = jnp.diag(memory["C"])
    C, B, D = eigen_decomposition(memory["C"], memory["B"], memory["D"])

    # Stop if generation fct values of recent generation is below thresh.
    if (memory["generation"] > params["min_generations"]
            and jnp.max(values) - jnp.min(values) < params["tol_fun"]):
        print("TERMINATE ----> Convergence/No progress in objective")
        return True

    # Stop if std of normal distrib is smaller than tolx in all coordinates
    # and pc is smaller than tolx in all components.
    if jnp.all(memory["sigma"] * dC < params["tol_x"]) and np.all(
            memory["sigma"] * memory["p_c"] < params["tol_x"]):
        print("TERMINATE ----> Convergence/Search variance too small")
        return True

    # Stop if detecting divergent behavior.
    if memory["sigma"] * jnp.max(D) > params["tol_x_up"]:
        print("TERMINATE ----> Stepsize sigma exploded")
        return True

    # No effect coordinates: stop if adding 0.2-standard deviations
    # in any single coordinate does not change m.
    if jnp.any(memory["mean"] == memory["mean"] +
               (0.2 * memory["sigma"] * jnp.sqrt(dC))):
        print("TERMINATE ----> No effect when adding std to mean")
        return True

    # No effect axis: stop if adding 0.1-standard deviation vector in
    # any principal axis direction of C does not change m.
    if jnp.all(memory["mean"] == memory["mean"] +
               (0.1 * memory["sigma"] * D[0] * B[:, 0])):
        print("TERMINATE ----> No effect when adding std to mean")
        return True

    # Stop if the condition number of the covariance matrix exceeds tol_condition_C.
    condition_cov = jnp.max(D) / jnp.min(D)
    if condition_cov > params["tol_condition_C"]:
        print("TERMINATE ----> C condition number exploded")
        return True
    return False
Example #23
def test_gaussian_subposterior(method, diagonal):
    D = 10
    n_samples = 10000
    n_draws = 9000
    n_subs = 8

    mean = np.arange(D)
    cov = np.ones((D, D)) * 0.9 + np.identity(D) * 0.1
    subcov = n_subs * cov  # subposterior's covariance
    subposteriors = list(dist.MultivariateNormal(mean, subcov).sample(
        random.PRNGKey(1), (n_subs, n_samples)))

    draws = method(subposteriors, n_draws, diagonal=diagonal)
    assert draws.shape == (n_draws, D)
    assert_allclose(np.mean(draws, axis=0), mean, atol=0.03)
    if diagonal:
        assert_allclose(np.var(draws, axis=0), np.diag(cov), atol=0.05)
    else:
        assert_allclose(np.cov(draws.T), cov, atol=0.05)
Example #24
def uqr(A: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """This is the implementation of the unique QR decomposition as proposed in
    [1], modified for JAX compatibility.

    [1] https://github.com/numpy/numpy/issues/15628

    Args:
        A: Matrix for which to compute the unique QR decomposition.

    Returns:
        Q: Orthogonal matrix factor.
        R: Upper triangular matrix with positive elements on the diagonal.

    """
    Q, R = jnp.linalg.qr(A)
    signs = 2 * (jnp.diag(R) >= 0) - 1
    # Flip each column of Q and the matching row of R where diag(R) is negative.
    Q = Q * signs[jnp.newaxis, :]
    R = R * signs[:, jnp.newaxis]
    return Q, R
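A minimal check of uqr on a small matrix (the input values are arbitrary): the factors should reconstruct A and the diagonal of R should be nonnegative:

import jax.numpy as jnp

A = jnp.array([[1., 2.], [3., 4.]])
Q, R = uqr(A)
print(jnp.allclose(Q @ R, A, atol=1e-5))    # reconstruction
print(bool(jnp.all(jnp.diag(R) >= 0)))      # unique sign convention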
Example #25
def test_initial_inverse_mass_matrix(dense_mass):
    def model():
        numpyro.sample("x", dist.Normal(0, 1).expand([3]))
        numpyro.sample("z", dist.Normal(0, 1).expand([2]))

    expected_mm = jnp.arange(1, 4.0)
    kernel = NUTS(
        model,
        dense_mass=dense_mass,
        inverse_mass_matrix={("x",): expected_mm},
        adapt_mass_matrix=False,
    )
    mcmc = MCMC(kernel, num_warmup=1, num_samples=1)
    mcmc.run(random.PRNGKey(0))
    inverse_mass_matrix = mcmc.last_state.adapt_state.inverse_mass_matrix
    assert set(inverse_mass_matrix.keys()) == {("x",), ("z",)}
    expected_mm = jnp.diag(expected_mm) if dense_mass else expected_mm
    assert_allclose(inverse_mass_matrix[("x",)], expected_mm)
    assert_allclose(inverse_mass_matrix[("z",)], jnp.ones(2))
Example #26
    def test_periodic_against_periodic_general(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        tol = 1e-13
        if dtype is f32:
            tol = 1e-5

        for _ in range(STOCHASTIC_SAMPLES):
            key, split1, split2, split3 = random.split(key, 4)

            max_box_size = f32(10.0)
            box_size = max_box_size * random.uniform(split1,
                                                     (spatial_dimension, ),
                                                     dtype=dtype)
            transform = np.diag(box_size)

            R = random.uniform(split2, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            R_scaled = R * box_size

            dR = random.normal(split3, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)

            disp_fn, shift_fn = space.periodic(box_size)
            general_disp_fn, general_shift_fn = space.periodic_general(
                transform)

            disp_fn = space.map_product(disp_fn)
            general_disp_fn = space.map_product(general_disp_fn)

            self.assertAllClose(disp_fn(R_scaled, R_scaled),
                                general_disp_fn(R, R),
                                True,
                                atol=tol,
                                rtol=tol)
            assert disp_fn(R_scaled, R_scaled).dtype == dtype
            self.assertAllClose(shift_fn(R_scaled, dR),
                                general_shift_fn(R, dR) * box_size,
                                True,
                                atol=tol,
                                rtol=tol)
            assert shift_fn(R_scaled, dR).dtype == dtype
Example #27
def model_w_c(T, T_forecast, x, obs=None):
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    W = numpyro.sample(name="W",
                       fn=dist.Normal(loc=jnp.zeros((2, 4)),
                                      scale=jnp.ones((2, 4))))
    beta = numpyro.sample(name="beta",
                          fn=dist.Normal(loc=jnp.array([0.0, 0.0]),
                                         scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(name="z_1",
                            fn=dist.Normal(loc=jnp.zeros(2),
                                           scale=jnp.ones(2)))
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)),
        L_Omega)  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast, ),
    )
    # Propagate the dynamics forward using jax.lax.scan
    carry = (W, beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, (x, noises), T + T_forecast)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    c = numpyro.sample(name="c",
                       fn=dist.Normal(loc=jnp.array([[0.0], [0.0]]),
                                      scale=jnp.ones((2, 1))))
    obs_mean = jnp.dot(z_collection[:T, :], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:, :], c).squeeze()

    # Sample the observed y (y_obs)
    numpyro.sample(name="y_obs",
                   fn=dist.Normal(loc=obs_mean, scale=sigma),
                   obs=obs)
    numpyro.sample(name="y_pred",
                   fn=dist.Normal(loc=pred_mean, scale=sigma),
                   obs=None)
Example #28
    def calculation(lattice, positions, species, shifts, kpts, kwargs, kwargs_diag, kwargs_overlap):
        """ creates function to return parameters and hamiltonian_wo_k
        Args:
            lattice: takes lattice matrix as 2D-Array , e.g.: jnp.diag(jnp.ones(3))
            positions: Array of position vectors of atoms
            species: Array of species, e.g. jnp.array([0, 0]) or jnp.array([0, 1])
            shifts: uses 2D shifts matrix of comupte_shifts function
            kpts: Array of coordinates of the k-points, e.g. for gamma: jnp.array([[0, 0, 0]])
        Returns: 2D- vector of eigenvalues from hamiltonian for each k point
        """
        ham_wo_k, overlap_wo_k = create_hamiltonian_wo_k(positions, species, shifts, kwargs, kwargs_diag, kwargs_overlap)
        hamiltonian = get_ham(ham_wo_k, kpts, shifts, lattice)
        hamiltonian += vmap(set_diagonal_to_inf, 0)(hamiltonian)
        # print("hamiltonian", hamiltonian.shape, jnp.round(jnp.abs(hamiltonian[0, :, :]), decimals=2))  # , jnp.iscomplex(hamiltonian))  # , "\n", hamiltonian[0, :, :])
        overlap_matrix = get_ham(overlap_wo_k, kpts, shifts, lattice)
        # print("overlap", jnp.expand_dims(jnp.diag(jnp.ones(overlap_matrix.shape[1])), 0).shape)
        overlap_matrix += jnp.expand_dims(jnp.diag(jnp.ones(overlap_matrix.shape[1])), 0)

        solution_jaxscipy = scipy.linalg.eigh(hamiltonian, eigvals_only=True)
        # print("Solutions jax scipy", solution_jaxscipy.shape, solution_jaxscipy)

        # to calculate generalized eigenvalue problem
        hamiltonian = jnp.where(jnp.abs(hamiltonian) < 10e-10, 0, hamiltonian)
        overlap_matrix = jnp.where(jnp.abs(overlap_matrix) < 10e-10, 0, overlap_matrix)
        overlap_inverse = vmap(jnp.linalg.inv, 0)(overlap_matrix)
        new_ham = vmap(jnp.dot, 0, 0)(overlap_inverse, hamiltonian)

        # solution_generalized, vectors = eigh_generalized(hamiltonian, overlap_matrix)
        # print("Solutions generalized", solution_generalized.shape, solution_generalized[0, :])

        # solution = sol_ham(new_ham[1, :, :], eig_vectors=False, generalized=False)
        solution_jaxscipy_gen = scipy.linalg.eigh(new_ham, eigvals_only=True)
        # print("Solutions gen jax scipy", solution_jaxscipy_gen.shape, solution_jaxscipy_gen[0, :])

        # solution_jaxscipy_gen_np = scipy_nonjax.linalg.eigh(new_ham[0, :, :], eigvals_only=True)
        # print("Solutions gen numpy scipy", solution_jaxscipy_gen_np.shape, solution_jaxscipy_gen_np)

        # solution_jaxnumpy_gen, _ = jnp.linalg.eigh(new_ham)
        # print("Solutions gen jax numpy", solution_jaxnumpy_gen.shape, solution_jaxnumpy_gen[0, :])
        solution_jaxscipy_gen -= find_fermi(solution_jaxscipy_gen, highest_occupied, plot=False)
        solution_jaxscipy_gen = solution_jaxscipy_gen * 27.211396  # conversion au (atomic unit) to eV
        return solution_jaxscipy_gen  # -[:, -1::-1]  # [:9, :8]
Example #29
    def _log_probs(self, X: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """Compute class log probabilities for X.

        Args:
            X: An array of shape ``(n_samples, n_features)`` containing the training
                examples.
            alpha: The SVM normal vector scales. Normally this should be
                ``self.alpha_``, but we leave it as an argument so we can differentiate
                through this function when fitting.

        Returns:
            An array of shape ``(n_samples, n_classes)`` containing the predicted log
            probabilities for each class.

        """
        n = alpha.shape[0]
        L = jnp.tril(np.ones((n, n)))
        A = jnp.diag(alpha)
        likelihoods = (L @ A @ (self.coefs_ @ X.T + self.b_[:, None])).T
        return logsoftmax(likelihoods)
Example #30
def variance(
    gp: NonConjugatePosterior,
    param: dict,
    test_inputs: Array,
    train_inputs: Array,
    train_outputs: Array,
):
    ell, alpha, nu = param["lengthscale"], param["variance"], param["latent"]
    Kff = gram(gp.prior.kernel, train_inputs, param)
    Kfx = cross_covariance(gp.prior.kernel, train_inputs, test_inputs, param)
    Kxx = gram(gp.prior.kernel, test_inputs, param)
    L = jnp.linalg.cholesky(Kff + jnp.eye(train_inputs.shape[0]) * 1e-6)

    A = solve_triangular(L, Kfx.T, lower=True)
    latent_var = Kxx - jnp.sum(jnp.square(A), -2)
    latent_mean = jnp.matmul(A.T, nu)
    lvar = jnp.diag(latent_var)
    moment_fn = predictive_moments(gp.likelihood)
    pred_rv = moment_fn(latent_mean.ravel(), lvar)
    return pred_rv.variance()