Example #1
def nonbonded_v3(
    conf,
    params,
    box,
    lamb,
    charge_rescale_mask,
    lj_rescale_mask,
    beta,
    cutoff,
    lambda_plane_idxs,
    lambda_offset_idxs,
    runtime_validate=True,
):
    """Lennard-Jones + Coulomb, with a few important twists:
    * distances are computed in 4D, controlled by lambda, lambda_plane_idxs, lambda_offset_idxs
    * each pairwise LJ and Coulomb term can be multiplied by an adjustable rescale_mask parameter
    * Coulomb terms are multiplied by erfc(beta * distance)

    Parameters
    ----------
    conf : (N, 3) or (N, 4) np.array
        3D or 4D coordinates
        if 3D, will be converted to 4D using (x,y,z) -> (x,y,z,w)
            where w = cutoff * (lambda_plane_idxs + lambda_offset_idxs * lamb)
    params : (N, 3) np.array
        columns [charges, sigmas, epsilons], one row per particle
    box : Optional 3x3 np.array
    lamb : float
    charge_rescale_mask : (N, N) np.array
        the Coulomb contribution of pair (i,j) will be multiplied by charge_rescale_mask[i,j]
    lj_rescale_mask : (N, N) np.array
        the Lennard-Jones contribution of pair (i,j) will be multiplied by lj_rescale_mask[i,j]
    beta : float
        the charge product q_ij will be multiplied by erfc(beta*d_ij)
    cutoff : Optional float
        a pair of particles (i,j) will be considered non-interacting if the distance d_ij
        between their 4D coordinates exceeds cutoff
    lambda_plane_idxs : Optional (N,) np.array
    lambda_offset_idxs : Optional (N,) np.array
    runtime_validate: bool
        check whether beta is compatible with cutoff
        (if True, this function will currently not play nice with Jax JIT)
        TODO: is there a way to conditionally print a runtime warning inside
            of a Jax JIT-compiled function, without triggering a Jax ConcretizationTypeError?

    Returns
    -------
    energy : float

    References
    ----------
    * Rodinger, Howell, Pomès, 2005, J. Chem. Phys. "Absolute free energy calculations by thermodynamic integration in four spatial
        dimensions" https://aip.scitation.org/doi/abs/10.1063/1.1946750
    * Darden, York, Pedersen, 1993, J. Chem. Phys. "Particle mesh Ewald: An N log(N) method for Ewald sums in large
        systems" https://aip.scitation.org/doi/abs/10.1063/1.470117
        * Coulomb interactions are treated using the direct-space contribution from eq 2
    """
    if runtime_validate:
        assert (charge_rescale_mask == charge_rescale_mask.T).all()
        assert (lj_rescale_mask == lj_rescale_mask.T).all()

    N = conf.shape[0]

    if conf.shape[-1] == 3:
        conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs,
                             cutoff)

    # make the 4th dimension of the box large enough so it's roughly aperiodic
    if box is not None:
        if box.shape[-1] == 3:
            box_4d = np.eye(4) * 1000
            box_4d = index_update(box_4d, index[:3, :3], box)
        else:
            box_4d = box
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = eps_i * eps_j

    dij = distance(conf, box)

    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        if runtime_validate:
            validate_coulomb_cutoff(cutoff, beta, threshold=1e-2)
        eps_ij = np.where(dij < cutoff, eps_ij, 0)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, 0)
    eps_ij = np.where(keep_mask, eps_ij, 0)

    inv_dij = 1 / dij
    inv_dij = np.where(np.eye(N), 0, inv_dij)

    sig2 = sig_ij * inv_dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, 0)

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    keep_mask = 1 - np.eye(N)
    qij = np.where(keep_mask, qij, 0)
    dij = np.where(keep_mask, dij, 0)

    # note: the diagonal has dij = 0 and inv_dij = 0, so these terms are zeroed explicitly
    eij_charge = np.where(keep_mask,
                          qij * erfc(beta * dij) * inv_dij,
                          0)  # zero out diagonals
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, 0, eij_charge)

    eij_total = eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask

    return np.sum(eij_total / 2)
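The 4D coordinate lift described in the docstring above is simple enough to sketch on its own. The helper below is a hypothetical stand-in for the convert_to_4d call in the example (the name lift_to_4d is not from the original code), written against current jax.numpy:

import jax.numpy as jnp

def lift_to_4d(conf_3d, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff):
    # conf_3d: (N, 3) coordinates; append the lambda-controlled w column to get (N, 4)
    w = cutoff * (lambda_plane_idxs + lambda_offset_idxs * lamb)
    return jnp.concatenate([conf_3d, w[:, None]], axis=1)

conf3 = jnp.zeros((5, 3))
print(lift_to_4d(conf3, 0.5, jnp.zeros(5), jnp.ones(5), cutoff=1.2).shape)  # (5, 4)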
Example #2
 def step_fn(i, state_and_energy):
   state, energy = state_and_energy
   state = apply_fn(state)
   energy = ops.index_update(energy, i, invariant(state, kT))
   return state, energy
Example #3
    def filter(self, init_state, sample_obs, observations=None, Vinit=None):
        """
        Run the Unscented Kalman Filter algorithm over a set of observed samples.

        Parameters
        ----------
        init_state: array(state_size)
            Initial state estimate
        sample_obs: array(nsamples, obs_size)
            Observed samples to filter
        observations: list or array of length nsamples, optional
            Additional per-step inputs passed to the observation function fx
        Vinit: array(state_size, state_size), optional
            Initial state covariance; defaults to self.Q

        Returns
        -------
        * array(nsamples, state_size)
            History of filtered mean terms
        * array(nsamples, state_size, state_size)
            History of filtered covariance terms
        """
        wm_vec = jnp.array([
            1 / (2 * (self.d + self.lmbda)) if i > 0 else self.lmbda /
            (self.d + self.lmbda) for i in range(2 * self.d + 1)
        ])
        wc_vec = jnp.array([
            1 / (2 * (self.d + self.lmbda)) if i > 0 else self.lmbda /
            (self.d + self.lmbda) + (1 - self.alpha**2 + self.beta)
            for i in range(2 * self.d + 1)
        ])
        nsteps, *_ = sample_obs.shape
        mu_t = init_state
        Sigma_t = self.Q if Vinit is None else Vinit
        if observations is None:
            observations = [()] * nsteps
        else:
            observations = [(obs, ) for obs in observations]

        mu_hist = jnp.zeros((nsteps, self.d))
        Sigma_hist = jnp.zeros((nsteps, self.d, self.d))

        mu_hist = index_update(mu_hist, 0, mu_t)
        Sigma_hist = index_update(Sigma_hist, 0, Sigma_t)

        for t in range(nsteps):
            # TO-DO: use jax.scipy.linalg.sqrtm when it gets added to lib
            comp1 = mu_t[:, None] + self.gamma * self.sqrtm(Sigma_t)
            comp2 = mu_t[:, None] - self.gamma * self.sqrtm(Sigma_t)
            #sigma_points = jnp.c_[mu_t, comp1, comp2]
            sigma_points = jnp.concatenate((mu_t[:, None], comp1, comp2),
                                           axis=1)

            z_bar = self.fz(sigma_points)
            mu_bar = z_bar @ wm_vec
            Sigma_bar = (z_bar - mu_bar[:, None])
            Sigma_bar = jnp.einsum("i,ji,ki->jk", wc_vec, Sigma_bar,
                                   Sigma_bar) + self.Q

            Sigma_bar_half = self.sqrtm(Sigma_bar)
            comp1 = mu_bar[:, None] + self.gamma * Sigma_bar_half
            comp2 = mu_bar[:, None] - self.gamma * Sigma_bar_half
            #sigma_points = jnp.c_[mu_bar, comp1, comp2]
            sigma_points = jnp.concatenate((mu_bar[:, None], comp1, comp2),
                                           axis=1)

            x_bar = self.fx(sigma_points, *observations[t])
            x_hat = x_bar @ wm_vec
            St = x_bar - x_hat[:, None]
            St = jnp.einsum("i,ji,ki->jk", wc_vec, St, St) + self.R

            mu_hat_component = z_bar - mu_bar[:, None]
            x_hat_component = x_bar - x_hat[:, None]
            Sigma_bar_y = jnp.einsum("i,ji,ki->jk", wc_vec, mu_hat_component,
                                     x_hat_component)
            Kt = Sigma_bar_y @ jnp.linalg.inv(St)

            mu_t = mu_bar + Kt @ (sample_obs[t] - x_hat)
            Sigma_t = Sigma_bar - Kt @ St @ Kt.T

            mu_hist = index_update(mu_hist, t, mu_t)
            Sigma_hist = index_update(Sigma_hist, t, Sigma_t)

        return mu_hist, Sigma_hist
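As a quick sanity check on the sigma-point weights built above: the mean weights sum to one, since lmbda / (d + lmbda) + 2d * 1 / (2 * (d + lmbda)) = 1. A standalone sketch with toy values for d and lmbda:

import jax.numpy as jnp

d, lmbda = 4, 1.0
wm_vec = jnp.array([
    1 / (2 * (d + lmbda)) if i > 0 else lmbda / (d + lmbda)
    for i in range(2 * d + 1)
])
print(wm_vec.sum())  # 1.0 (up to float rounding)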
Example #4
  def build_cells(R: Array, extra_capacity: int=0, **kwargs) -> CellList:
    N = R.shape[0]
    dim = R.shape[1]

    _cell_capacity = cell_capacity + extra_capacity

    if dim != 2 and dim != 3:
      # NOTE(schsam): Do we want to check this in compute_fn as well?
      raise ValueError(
          'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    neighborhood_tile_count = 3 ** dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(jnp.int64, N)
    # NOTE(schsam): We use the convention that particles that are successfully
    # copied have their true id whereas empty slots have id = N.
    # Then when we copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    mask_id = jnp.ones((N,), jnp.int64) * N
    cell_R = jnp.zeros((cell_count * _cell_capacity, dim), dtype=R.dtype)
    cell_id = N * jnp.ones((cell_count * _cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10 ** 5
    cell_kwargs = {}
    for k, v in kwargs.items():
      if not util.is_array(v):
        raise ValueError((
          'Data must be specified as an ndarray. Found "{}" with '
          'type {}'.format(k, type(v))))
      if v.shape[0] != R.shape[0]:
        raise ValueError(
          ('Data must be specified per-particle (an ndarray with shape '
           '(R.shape[0], ...)). Found "{}" with shape {}'.format(k, v.shape)))
      kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
      cell_kwargs[k] = empty_kwarg_value * jnp.ones(
        (cell_count * _cell_capacity,) + kwarg_shape, v.dtype)

    indices = jnp.array(R / cell_size, dtype=i32)
    hashes = jnp.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, .., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = jnp.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_kwargs = {}
    for k, v in kwargs.items():
      sorted_kwargs[k] = v[sort_map]

    sorted_cell_id = jnp.mod(lax.iota(jnp.int64, N), _cell_capacity)
    sorted_cell_id = sorted_hash * _cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_id = jnp.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(
        cell_id, sorted_cell_id, sorted_id)
    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
      if v.ndim == 1:
        v = jnp.reshape(v, v.shape + (1,))
      cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id, v)
      cell_kwargs[k] = _unflatten_cell_buffer(
        cell_kwargs[k], cells_per_side, dim)

    return CellList(cell_R, cell_id, cell_kwargs)  # pytype: disable=wrong-arg-count
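The scatter trick explained in the comment above (sort by hash, then cell id = hash * cell_capacity + slot) can be illustrated on a toy 1-D example. The sketch below uses the .at[].set() syntax that replaced ops.index_update in later JAX releases; the numbers are made up:

import jax.numpy as jnp

cell_capacity = 3
hashes = jnp.array([2, 0, 1, 0, 2, 0])   # cell hash of each of 6 particles
R = jnp.arange(6.0) + 10.0               # toy 1-D "positions", 10..15

sort_map = jnp.argsort(hashes)           # group particles of the same cell together
slot = jnp.arange(6) % cell_capacity     # 0, 1, 2, 0, 1, 2 down the sorted list
sorted_cell_id = hashes[sort_map] * cell_capacity + slot

cell_R = jnp.zeros(3 * cell_capacity)    # 3 cells x 3 slots, flattened
cell_R = cell_R.at[sorted_cell_id].set(R[sort_map])
print(cell_R.reshape(3, cell_capacity))
# [[11. 13. 15.]    cell 0: particles 1, 3, 5
#  [12.  0.  0.]    cell 1: particle 2; unused slots keep the fill value
#  [ 0. 10. 14.]]   cell 2: particles 0 and 4

Because no cell holds more than cell_capacity particles, every sorted_cell_id is unique and a single scatter fills the whole buffer.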
Example #5
def _update_history_scalars(history, new):
    # TODO(Jakob-Unfried) use rolling buffer instead? See #6053
    return ops.index_update(jnp.roll(history, -1, axis=0), ops.index[-1], new)
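jax.ops.index_update has since been deprecated and removed from JAX; on current releases the same rolling-buffer update is written with the .at[] property:

import jax.numpy as jnp

def _update_history_scalars(history, new):
    # shift everything one step towards the front, then write the newest value in the last slot
    return jnp.roll(history, -1, axis=0).at[-1].set(new)

print(_update_history_scalars(jnp.array([1.0, 2.0, 3.0]), 9.0))  # [2. 3. 9.]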
Example #6
 def count(cell_hash, filling):
     count = np.sum(particle_hash == cell_hash)
     filling = ops.index_update(filling, ops.index[cell_hash], count)
     return filling
Example #7
 def inv(self, y):
     size = self.permutation.size
     permutation_inv = ops.index_update(np.zeros(size, dtype=np.int64),
                                        self.permutation, np.arange(size))
     return y[..., permutation_inv]
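The scatter here builds the inverse permutation: writing arange(size) at the positions given by self.permutation yields an array p_inv with p[p_inv] == arange(size). A plain-NumPy check of that identity on a small example:

import numpy as np

permutation = np.array([2, 0, 3, 1])
permutation_inv = np.zeros(permutation.size, dtype=np.int64)
permutation_inv[permutation] = np.arange(permutation.size)
print(permutation_inv)               # [1 3 0 2]
print(permutation[permutation_inv])  # [0 1 2 3]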
Example #8
 def body_fun(i, k):
     ti = t0 + dt * alpha[i - 1]
     yi = y0 + dt * np.dot(beta[i - 1, :], k)
     ft = func(yi, ti)
     return ops.index_update(k, jax.ops.index[i, :], ft)
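For context, a body_fun of this shape is usually driven by lax.fori_loop to fill the Runge-Kutta stage array k one row at a time. Below is a self-contained sketch for the classical RK4 tableau on the scalar ODE dy/dt = -y (so k is 1-D here, whereas the snippet above indexes k[i, :]); the tableau arrays and names are illustrative, not taken from the original code:

import jax
import jax.numpy as np

alpha = np.array([0.5, 0.5, 1.0])               # stage times c_2..c_4
beta = np.array([[0.5, 0.0, 0.0, 0.0],          # tableau coefficients a_ij
                 [0.0, 0.5, 0.0, 0.0],
                 [0.0, 0.0, 1.0, 0.0]])
c_sol = np.array([1 / 6, 1 / 3, 1 / 3, 1 / 6])  # output weights b_i

def func(y, t):
    return -y

def rk4_step(y0, t0, dt):
    k = np.zeros(4).at[0].set(func(y0, t0))
    def body_fun(i, k):
        ti = t0 + dt * alpha[i - 1]
        yi = y0 + dt * np.dot(beta[i - 1, :], k)
        return k.at[i].set(func(yi, ti))
    k = jax.lax.fori_loop(1, 4, body_fun, k)
    return y0 + dt * np.dot(c_sol, k)

print(rk4_step(1.0, 0.0, 0.1))  # ~0.9048, i.e. exp(-0.1)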
Example #9
def do_subject(subject_id):
    fm = agg_fit_metadata.loc[subject_id]
    name = base_models.loc[subject_id]['name']
    agg_res = aggregation_results.loc[subject_id]
    starting_model = agg_res.model

    o = Optimizer(agg_res,
                  *fm[['psf', 'galaxy_data', 'sigma_image']],
                  oversample_n=5)
    # define the parameters controlling only the brightness of components, and
    # fit them first
    L_keys = get_luminosity_keys(o.model)

    # perform the first fit
    with tqdm(desc='Fitting brightness', leave=False) as bar:
        res = minimize(
            __f,
            onp.array([o.model_[k] for k in L_keys]),
            jac=__j,
            args=(o, L_keys),
            callback=__bar_incrementer(bar),
            bounds=onp.array([o.lims_[k] for k in L_keys]),
        )

    # update the optimizer with the new parameters
    for k, v in zip(L_keys, res['x']):
        o[k] = v

    # perform the full fit
    with tqdm(desc='Fitting everything', leave=False) as bar:
        res_full = minimize(__f,
                            onp.array([o.model_[k] for k in o.keys]),
                            jac=__j,
                            args=(o, o.keys),
                            callback=__bar_incrementer(bar),
                            bounds=onp.array(
                                [o.lims_[k0][k1] for k0, k1 in o.keys]),
                            options=dict(maxiter=10000))

    final_model = pd.Series({
        **deepcopy(o.model_),
        **{k: v
           for k, v in zip(o.keys, res_full['x'])}
    })

    # correct the parameters of spirals in this model for the new disk,
    # allowing rendering of the model without needing the rotation of the disk
    # before fitting
    final_model = correct_spirals(final_model, o.base_roll)

    # fix component axis ratios (if > 1, flip major and minor axis)
    final_model = correct_axratio(final_model)

    # remove components with zero brightness
    final_model = remove_zero_brightness_components(final_model)

    # lower the indices of spirals where possible
    final_model = lower_spiral_indices(final_model)

    comps = o.render_comps(final_model.to_dict(), correct_spirals=False)

    d = ops.index_update(
        psf_conv(sum(comps.values()), o.psf) - o.target, o.mask, np.nan)
    chisq = float(np.sum((d[~o.mask] / o.sigma[~o.mask])**2) / (~o.mask).sum())
    disk_spiral_L = (final_model[('disk', 'L')] +
                     (comps['spiral'].sum() if 'spiral' in comps else 0))
    # fractions were originally parametrized vs the disk and spirals (bulge
    # had no knowledge of bar and vice versa)
    bulge_frac = final_model.get(('bulge', 'frac'), 0)
    bar_frac = final_model.get(('bar', 'frac'), 0)

    bulge_L = bulge_frac * disk_spiral_L / (1 - bulge_frac)
    bar_L = bar_frac * disk_spiral_L / (1 - bar_frac)
    gal_L = disk_spiral_L + bulge_L + bar_L

    bulge_frac = bulge_L / (disk_spiral_L + bulge_L + bar_L)
    bar_frac = bar_L / (disk_spiral_L + bulge_L + bar_L)

    deparametrized_model = from_reparametrization(final_model, o)

    ftol = 2.220446049250313e-09

    # Also calculate Hessian-errors
    errs = np.sqrt(
        max(1, abs(res_full.fun)) * ftol *
        np.diag(res_full.hess_inv.todense()))

    os.makedirs('affirmation_subjects_results/tuning_results', exist_ok=True)
    pd.to_pickle(
        dict(
            base_model=starting_model,
            fit_model=final_model,
            deparametrized=deparametrized_model,
            res=res_full,
            chisq=chisq,
            comps=comps,
            r_band_luminosity=float(gal_L),
            bulge_frac=float(bulge_frac),
            bar_frac=float(bar_frac),
            errs=errs,
            keys=o.keys,
        ), 'affirmation_subjects_results/tuning_results/{}.pickle.gz'.format(
            name))
Example #10
def test_input_admin(t, y, r, t_test, y_test, r_test):
    """
    TODO: tidy this function up
    Order the inputs, remove duplicates, and index the train and test input locations.
    :param t: training inputs [N, 1]
    :param y: observations at the training inputs [N, 1]
    :param r: training spatial inputs
    :param t_test: testing inputs [N*, 1]
    :param y_test: observations at the test inputs [N*, 1]
    :param r_test: test spatial inputs
    :return:
        t_all: the combined and sorted training and test inputs [N + N*, 1]
        y_all: an array of observations y augmented with nans at test locations [N + N*, R]
        r_all: spatial inputs with nans at test locations [N + N*, R]
        dt_all: combined training and test step sizes, Δtₙ = tₙ - tₙ₋₁ [N + N*, 1]
        dt_train: training step sizes, Δtₙ = tₙ - tₙ₋₁ [N, 1]
        train_id: an array of indices corresponding to the training inputs [N, 1]
        test_id: an array of indices corresponding to the test inputs [N*, 1]
        mask: boolean array to signify training locations [N + N*, 1]
    """
    assert t.shape[0] == y.shape[0]
    if t.ndim < 2:
        t = np.expand_dims(t, 1)  # make 2-D
    if y.ndim < 2:
        y = np.expand_dims(y, 1)  # make 2-D
    if r is None:
        r = np.nan * t  # np.empty((1,) + x.shape[1:]) * np.nan
    if r.ndim < 2:
        r = np.expand_dims(r, 1)  # make 2-D
    ind = np.argsort(t[:, 0], axis=0)
    t_train = t[ind, ...]
    y_train = y[ind, ...]
    r_train = r[ind, ...]
    if t_test is None:
        t_test = np.empty((1, ) + t_train.shape[1:]) * np.nan
        r_test = np.empty((1, ) + t_train.shape[1:]) * np.nan
    else:
        if t_test.ndim < 2:
            t_test = np.expand_dims(t_test, 1)  # make 2-D
        test_sort_ind = np.argsort(t_test[:, 0], axis=0)
        t_test = t_test[test_sort_ind, ...]
        if y_test is not None:
            y_test = y_test[test_sort_ind, ...].reshape((-1, ) + y.shape[1:])
        if r_test is not None:
            r_test = r_test[test_sort_ind, ...]
        else:
            r_test = np.nan * t_test
    if not (t_test.shape[1] == t_train.shape[1]):
        t_test = np.concatenate([
            t_test[:, 0][:, None],
            np.nan * np.empty([t_test.shape[0], t_train.shape[1] - 1])
        ],
                                axis=1)
    # here we use non-JAX numpy to sort out indexing of these static arrays
    t_train_test = np.concatenate([t_train, t_test])
    keep_ind = ~np.isnan(t_train_test[:, 0])
    t_train_test = t_train_test[keep_ind, ...]
    # do the spatial test points have different dimensionality to the training points?
    if r_test.shape[1] != r_train.shape[1]:
        r_test_nan = np.nan * np.zeros([r_test.shape[0], r_train.shape[1]])
    else:
        r_test_nan = r_test
    r_train_test = np.concatenate([r_train, r_test_nan])
    r_train_test = r_train_test[keep_ind, ...]
    t_ind = np.argsort(t_train_test[:, 0])
    t_all = t_train_test[t_ind]
    r_all = r_train_test[t_ind]
    reverse_ind = np.argsort(t_ind)
    n_train = t_train.shape[0]
    train_id = reverse_ind[:n_train]  # index the training locations
    test_id = reverse_ind[n_train:]  # index the test locations
    y_all = np.nan * np.zeros([
        t_all.shape[0], y_train.shape[1]
    ])  # observation vector with nans at test locations
    # y_all[reverse_ind[:n_train], ...] = y_train  # and the data at the train locations
    y_all = index_update(y_all, index[reverse_ind[:n_train]],
                         y_train)  # and the data at the train locations
    if y_test is not None:
        # y_all[reverse_ind[n_train:], ...] = y_test  # and the data at the train locations
        y_all = index_update(y_all, index[reverse_ind[n_train:]],
                             y_test)  # and the data at the train locations
    mask = np.ones_like(y_all, dtype=bool)
    # mask[train_id] = False
    mask = index_update(mask, index[train_id], False)
    dt_all = np.concatenate([np.array([0.0]), np.diff(t_all[:, 0])])
    return (np.array(t_all, dtype=np.float64), np.array(y_all,
                                                        dtype=np.float64),
            np.array(r_all,
                     dtype=np.float64), np.array(r_test, dtype=np.float64),
            np.array(dt_all,
                     dtype=np.float64), np.array(train_id, dtype=np.int64),
            np.array(test_id, dtype=np.int64), np.array(mask, dtype=bool))
Example #11
def get_g(batch_size, A, B, C, Q, Ru, Rv, K, L, T, baseline=0):
    # mini_batch is a single gradient (the log-sum derivative of pi); the average of
    # this is the ordinary gradient, but here it is equivalent to g.
    sigma_K = 5e-1
    sigma_L = 5e-1
    sigma_x = 1e-4
    nx, nu = B.shape
    _, nw = C.shape
    K = K.reshape((nu, nx))
    L = L.reshape((nw, nx))
    Q = np.kron(np.eye(T, dtype=int), Q)
    Rv = np.kron(np.eye(T, dtype=int), Rv)
    Ru = np.kron(np.eye(T, dtype=int), Ru)

    X = np.zeros((nx * (T + 1), batch_size))
    # X[0:nx,:] = 0.2 * random.normal(key, shape=(nx,batch_size))
    X = ops.index_update(X, ops.index[0:nx, :],
                         0.2 * random.normal(key, shape=(nx, batch_size)))

    U = np.zeros((nu * T, batch_size))
    W = np.zeros((nw * T, batch_size))
    Vu = sigma_K * random.normal(key,
                                 shape=(nu * T, batch_size))  # noise for U
    Vw = sigma_L * random.normal(key,
                                 shape=(nw * T, batch_size))  # noise for W

    for t in range(T):
        # U[t*nu:(t+1)*nu,:] = np.matmul(K,X[nx*t:nx*(t+1),:]) + Vu[t*nu:(t+1)*nu,:]
        U = ops.index_update(
            U, ops.index[t * nu:(t + 1) * nu, :],
            np.matmul(K, X[nx * t:nx * (t + 1), :]) +
            Vu[t * nu:(t + 1) * nu, :])
        # W[t*nw:(t + 1) * nw, :] = np.matmul(L, X[nx * t:nx * (t + 1), :]) + Vw[t * nw:(t + 1) * nw, :]
        W = ops.index_update(
            W, ops.index[t * nw:(t + 1) * nw, :],
            np.matmul(L, X[nx * t:nx * (t + 1), :]) +
            Vw[t * nw:(t + 1) * nw, :])
        # X[nx*(t+1):nx*(t+2),:] = np.matmul(A,X[nx*t:nx*(t+1),:]) + np.matmul(B,U[t*nu:(t+1)*nu,:]).reshape((nx,batch_size)) +\
        #                        + np.matmul(C,W[t*nw:(t+1)*nw,:]).reshape((nx,batch_size)) + sigma_x * random.normal(key, shape=(nx, batch_size))
        X = ops.index_update(
            X, ops.index[nx * (t + 1):nx * (t + 2), :],
            np.matmul(A, X[nx * t:nx * (t + 1), :]) +
            np.matmul(B, U[t * nu:(t + 1) * nu, :]).reshape((nx, batch_size)) +
            np.matmul(C, W[t * nw:(t + 1) * nw, :]).reshape((nx, batch_size)) +
            sigma_x * random.normal(key, shape=(nx, batch_size)))

    X_cost = X[nx:, :]
    reward = np.diagonal(np.matmul(X_cost.T, Q.dot(X_cost))) + np.diagonal(
        np.matmul(U.T, Ru.dot(U))) - np.diagonal(np.matmul(W.T, Rv.dot(W)))
    new_baseline = np.mean(reward)
    reward = reward.reshape((len(reward), 1))

    #DK portion
    X_hat = X[:-nx, :]  # taking only T = 0:T-1 for X for log gradient computation
    outer_grad_log_K = np.einsum(
        "ik, jk -> ijk", Vu, X_hat
    )  # shape (a,b,c) means there are a of the (b,c) blocks. access (b,c) blocks via C[0,:,:]
    outer_grad_log_L = np.einsum("ik, jk -> ijk", Vw, X_hat)
    sum_grad_log_K = 0
    sum_grad_log_L = 0
    for t in range(T):
        # Summing all diagonal blocks gives p by d by batch_size
        sum_grad_log_K += outer_grad_log_K[nu * t:nu * (t + 1), nx * t:nx * (t + 1), :]
        sum_grad_log_L += outer_grad_log_L[nw * t:nw * (t + 1), nx * t:nx * (t + 1), :]

    mini_batch_K = (1 / sigma_K)**2 * (
        (reward - new_baseline).T * sum_grad_log_K
    )  #mini_batch is p by d, same size as K
    mini_batch_L = (1 / sigma_L)**2 * (
        (reward - new_baseline).T * sum_grad_log_L
    )  # mini_batch is b by a/d, same size as K
    # mini_batch_K = 2 * ((reward-new_baseline).T*sum_grad_log_K) #mini_batch is p by d, same size as K
    # mini_batch_L =  2 * ((reward - new_baseline).T * sum_grad_log_L)  # mini_batch is b by a/d, same size as K
    # print(mini_batch_K[0,0,:])

    temp = np.einsum('mnr,ndr->mdr', sum_grad_log_K.swapaxes(0, 1),
                     sum_grad_log_L)
    batch_mixed_KL = (1 / (sigma_K * sigma_L))**2 * (
        (reward - new_baseline).T * temp)
    # print('---new---',sum_grad_log_K[:,:,10][0,0])

    return np.mean(mini_batch_K,
                   axis=2), np.mean(mini_batch_L,
                                    axis=2), np.mean(batch_mixed_KL,
                                                     axis=2), new_baseline
Example #12
 def loop_body(i, acc_arr):
     arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
     return lax.cond(
         i % 2 == 0, arr1,
         lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.), arr1,
         lambda arr1: arr1)
Example #13
 def _body_fn(i, vals):
     val, collection = vals
     val = body_fun(val)
     collection = ops.index_update(collection, i, ravel_fn(val))
     return val, collection
Example #14
 def _body_fn(i, vals):
     val, collection = vals
     val = body_fun(val)
     i = np.where(i >= lower, i - lower, 0)
     collection = ops.index_update(collection, i, ravel_fn(val))
     return val, collection
Example #15
def lanczos_alg(matrix_vector_product, dim, order, rng_key):
  """Lanczos algorithm for tridiagonalizing a real symmetric matrix.

  This function applies Lanczos algorithm of a given order.  This function
  does full reorthogonalization.

  WARNING: This function may take a long time to jit compile (e.g. ~3min for
  order 90 and dim 1e7).

  Args:
    matrix_vector_product: Maps v -> Hv for a real symmetric matrix H.
      Input/Output must be of shape [dim].
    dim: Matrix H is [dim, dim].
    order: An integer corresponding to the number of Lanczos steps to take.
    rng_key: The jax PRNG key.

  Returns:
    tridiag: A tridiagonal matrix of size (order, order).
    vecs: A numpy array of size (order, dim) corresponding to the Lanczos
      vectors.
  """

  tridiag = np.zeros((order, order))
  vecs = np.zeros((order, dim))

  init_vec = random.normal(rng_key, shape=(dim,))
  init_vec = init_vec / np.linalg.norm(init_vec)
  vecs = ops.index_update(vecs, 0, init_vec)

  beta = 0
  # TODO(gilmer): Better to use lax.fori loop for faster compile?
  for i in range(order):
    v = vecs[i, :].reshape((dim))
    if i == 0:
      v_old = 0
    else:
      v_old = vecs[i - 1, :].reshape((dim))

    w = matrix_vector_product(v)
    assert (w.shape[0] == dim and len(w.shape) == 1), (
        'Output of matrix_vector_product(v) must be of shape [dim].')
    w = w - beta * v_old

    alpha = np.dot(w, v)
    tridiag = ops.index_update(tridiag, (i, i), alpha)
    w = w - alpha * v

    # Full Reorthogonalization
    for j in range(i):
      tau = vecs[j, :].reshape((dim))
      coeff = np.dot(w, tau)
      w += -coeff * tau

    beta = np.linalg.norm(w)

    # TODO(gilmer): The tf implementation raises an exception if beta < 1e-6
    # here. However JAX cannot compile a function that has an if statement
    # that depends on a dynamic variable. Should we still handle this case?
    # beta being small indicates that the lanczos vectors are linearly
    # dependent.

    if i + 1 < order:
      tridiag = ops.index_update(tridiag, (i, i+1), beta)
      tridiag = ops.index_update(tridiag, (i+1, i), beta)
      vecs = ops.index_update(vecs, i+1, w/beta)
  return (tridiag, vecs)
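A minimal usage sketch, assuming lanczos_alg and its imports (np = jax.numpy, random, ops) are available as above: tridiagonalize a small random symmetric matrix and compare the leading Ritz value against the exact top eigenvalue.

import jax.numpy as np
from jax import random

dim, order = 50, 20
key = random.PRNGKey(0)
A = random.normal(key, (dim, dim))
H = (A + A.T) / 2                        # real symmetric test matrix

matrix_vector_product = lambda v: H @ v  # maps v -> Hv, shape [dim] -> [dim]
tridiag, vecs = lanczos_alg(matrix_vector_product, dim, order, random.PRNGKey(1))

print(np.linalg.eigvalsh(tridiag)[-1])   # Ritz estimate of the largest eigenvalue
print(np.linalg.eigvalsh(H)[-1])         # exact value; the two should be close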
Example #16
    ax.set_xlim([0,500])
    plt.show()

    # reward feedback plots
    ax = plt.subplot(2,1,1)
    ax.plot( r[1] ) # expected reward 
    # maybe expected reward is shifted by one trial compared to plosone.
    # should it be expected reward at the start of the trial?
    # or expected reward for next trial?
    ax.plot( r[0] ) # actual reward
    ax.set_xlim([0,500])
    
    ax = plt.subplot(2,1,2)
    ax.plot( [0,500],[0,0],'--', color = [0,0,0])    
    ax.plot( r[2] ) # reward prediction error
    ax.set_ylim([-1,1])
    ax.set_xlim([0,500])    
    plt.show()


    
    # this simulation produces Healthy BG population activity
    # from the plosone paper as in Fig4 and Fig6
    key = PRNGKey( time.time_ns() )
    w_pfc = 0.01*uniform(key,(2,3,2))
    w_pfc = index_update( w_pfc, index[0,0,0], 0.7 )
    w_pfc = index_update( w_pfc, index[1,1,0], 0.7 )
    uu = jnp.zeros((nn,14))
    uu = do_trial_for_figure( [key, w_pfc] ) 
    plot_all(uu)
Example #17
    def compute_surface_fourier_series(self, r_surface):
        """
		Inputs: r_surface is a NZ x NT x 3 array which has x,y,z as a function of zeta and theta.

		Outputs a 3 x 2 x 2 x WSNFZ + 1 x WSNFT + 1 array which contains the Fourier components of the surface.

		"""
        NZ = r_surface.shape[0]
        NT = r_surface.shape[1]
        x_s = r_surface[:, :, 0]
        y_s = r_surface[:, :, 1]
        z_s = r_surface[:, :, 2]

        # xyz x sin/cos(zeta) x sin/cos(theta) x fz x ft
        result = np.zeros((3, 2, 2, self.WSNFZ + 1, self.WSNFT + 1))

        zeta = np.linspace(0, 2 * PI, NZ + 1)[0:NZ]
        theta = np.linspace(0, 2 * PI, NT + 1)[0:NT]

        # X^{cc}_{0,0} terms for x,y,z
        result = index_update(result, index[:, 1, 1, 0, 0],
                              np.mean(r_surface, axis=(0, 1)))

        for m in range(1, self.WSNFZ + 1):
            # X_{cc}_{m,0}
            result = index_update(result, index[:,1,1,m,0], 2.0 * \
             np.mean(r_surface[:,:,:] * \
              np.cos(m * zeta)[:,np.newaxis,np.newaxis], axis=(0,1)))
            # X_{sc}_{m,0}
            result = index_update(result, index[:,0,1,m,0], 2.0 * \
             np.mean(r_surface[:,:,:] * \
              np.sin(m * zeta)[:,np.newaxis,np.newaxis], axis=(0,1)))

        for n in range(1, self.WSNFT + 1):
            # X_{cc}_{0,n}
            result = index_update(result, index[:,1,1,0,n], 2.0 * \
             np.mean(r_surface[:,:,:] * \
              np.cos(n * theta)[np.newaxis,:,np.newaxis], axis=(0,1)))
            # X_{cs}_{0,n}
            result = index_update(result, index[:,1,0,0,n], 2.0 * \
             np.mean(r_surface[:,:,:] * \
              np.sin(n * theta)[np.newaxis,:,np.newaxis], axis=(0,1)))

        for m in range(1, self.WSNFZ + 1):
            for n in range(1, self.WSNFT + 1):
                # X_{ss}_{m,n}
                result = index_update(result, index[:,0,0,m,n], 4.0 * \
                 np.mean(r_surface[:,:,:] * \
                  np.sin(n * theta)[np.newaxis,:,np.newaxis] * \
                  np.sin(m * zeta)[:,np.newaxis,np.newaxis], axis=(0,1)))
                # X_{cs}_{m,n}
                result = index_update(result, index[:,1,0,m,n], 4.0 * \
                 np.mean(r_surface[:,:,:] * \
                  np.sin(n * theta)[np.newaxis,:,np.newaxis] * \
                  np.cos(m * zeta)[:,np.newaxis,np.newaxis], axis=(0,1)))
                # X_{sc}_{m,n}
                result = index_update(result, index[:,0,1,m,n], 4.0 * \
                 np.mean(r_surface[:,:,:] * \
                  np.cos(n * theta)[np.newaxis,:,np.newaxis] * \
                  np.sin(m * zeta)[:,np.newaxis,np.newaxis], axis=(0,1)))
                # X_{cc}_{m,n}
                result = index_update(result, index[:,1,1,m,n], 4.0 * \
                 np.mean(r_surface[:,:,:] * \
                  np.cos(n * theta)[np.newaxis,:,np.newaxis] * \
                  np.cos(m * zeta)[:,np.newaxis,np.newaxis], axis=(0,1)))
        return result
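A quick numerical check of the coefficient formula used above: on a uniform grid over [0, 2*pi), 2 * mean(f * cos(m * zeta)) recovers the cosine amplitude of mode m (and likewise for sine), which is where the factors of 2 and 4 in the m, n > 0 terms come from. Toy values, independent of the class above:

import jax.numpy as np

PI = np.pi
NZ = 64
zeta = np.linspace(0, 2 * PI, NZ + 1)[0:NZ]
f = 0.3 * np.cos(2 * zeta) + 0.1 * np.sin(5 * zeta)
print(2.0 * np.mean(f * np.cos(2 * zeta)))  # ~0.3
print(2.0 * np.mean(f * np.sin(5 * zeta)))  # ~0.1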
Example #18
def do_trial(kwr, reversal_learning):

    key = kwr[0]
    w_pfc = kwr[1]
    Re = kwr[2][0]
    key, subkey = split( key ) # jax makes us handle prng state ourselves    
    vv =  0.1 * uniform(subkey,(7*2,) )
    # set initial conditions before each simulation
    # i've taken these details from plosone modeldb matlab code:
    vv = index_update( vv, index[4] ,   vv[4]+.6   )
    vv = index_update( vv, index[4+6] , vv[4+6]+.6 )
    vv = index_update( vv, index[0] , 0. )
    vv = index_update( vv, index[1] , 0. )
    vk = [vv,key]

    sim_step = partial( simulation_step, w_pfc = w_pfc ) 
    
    # for debugging purposes use this loop:
    #    (but don't call this do_trial method hundreds of times)
    #for i in range(nn-1):
    #    vk = sim_step(i+1,vk)
    
    # for performance use this "loop":
    vv,key = lax.fori_loop( 1, nn, sim_step, vk )

    # rename variables for sanity
    pfc = vv[:2] # note this is of length two
    d1A = vv[2]
    d2A = vv[3]
    pmcA = vv[7]
    d1B = vv[2+6]
    d2B = vv[3+6]    
    pmcB = vv[7+6]
    pmc = [pmcA,pmcB]

    # these jax.lax.cond constructs replace some traditional
    # "if statement" condition logic blocks. this is done for
    # quick and easy jax.jit compatibility. please check out
    # the jax documentation.

    # this picks the rewarded action
    rewardedAction, otherAction = lax.cond( reversal_learning,
                                            pmc,
                                            lambda x: [x[1],x[0]],
                                            pmc,
                                            lambda x: [x[0],x[1]] )

    # this determines reward for this trial
    R_trial = lax.cond( rewardedAction > otherAction + 0.1,
                        None,
                        lambda x: 1,
                        None,
                        lambda x: 0 )
    # reward prediction error:
    SNc = R_trial - Re

    # expected reward for next trial:
    a = 0.15
    Re_next = a * R_trial + (1 - a)*Re

    # weight updates:
    # the i,j,k notation below refers to a diagram that i drew and tacked
    # to my cork board. it should end up in git repo.
    # hopefully these comments explain the array
    # operations that we use to update all 12 weights with just few commands
    
    # i notation is elided; that is the dimension of cues {#1,#2}
    # but we are instead just handling the vector pfc (shape = (2,))
    # j notation here indicates dimension of bg loops: {A,B}
    # k notation here indicates dimension of neuronal populations: {d1,d2,pmc}
    # qjk: array with population firing rates in each loop
    # sk: may modify qjk with SNc (which is reward prediction error)
    sk = jnp.array( [SNc,SNc,1] )
    qjk = jnp.array( [[d1A,d2A,pmcA],[d1B,d2B,pmcB]] )
    # sq: product of sk and qjk;
    # this should only modify d1,d2 msns by SNc; pmc is multiplied by 1
    # why? explanation:
    # pfc -> d1,d2 weights are modified by reward prediction error
    # pfc -> pmc weights are updated in a hebbian fashion
    sq = (sk * qjk).reshape((2,3,1))
    # we reshape this product from (2,3) to (2,3,1)
    # in preparation for weight update operations
    
    # weight update rule:    
    # pfc * sq 
    # (2,) * (2,3,1) -> (2,3,2), which is same shape as w_pfc
    # lrk: learning rate for each population (3,1)
    # lrk * ( the product of pfc and sq):
    # (3,1) * (2,3,2) -> (2,3,2)
    # frk: forgetting rate for each population (3,1)
    # frk * w_pfc:
    # (3,1) * (2,3,2) -> (2,3,2)
    dw_pfc = lrk * pfc * sq - frk * w_pfc
    
    # update weights; force new weights to be positive:
    w_pfc = jnp.clip( w_pfc + dw_pfc, a_min = 0)
    # w_pfc should be (2,3,2)

    return key, w_pfc, [Re_next,R_trial,SNc] , pmc
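The shape bookkeeping in the comments above can be checked directly. In this sketch lrk and frk are placeholder arrays of shape (3, 1), as stated in the comments; the values are made up:

import jax.numpy as jnp

pfc = jnp.ones((2,))
sq = jnp.ones((2, 3, 1))
lrk = jnp.full((3, 1), 0.1)   # placeholder learning rates
frk = jnp.full((3, 1), 0.01)  # placeholder forgetting rates
w_pfc = jnp.ones((2, 3, 2))

print((pfc * sq).shape)                      # (2, 3, 2)
print((lrk * pfc * sq - frk * w_pfc).shape)  # (2, 3, 2), same shape as w_pfc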
Example #19
    def log_marginal_likelihood(self,
                                theta=None,
                                eval_gradient=False,
                                clone_kernel=False):
        """Returns log-marginal likelihood of theta for training data.

        Parameters
        ----------
        theta : array-like of shape (n_kernel_params,) or None
            Kernel hyperparameters for which the log-marginal likelihood is
            evaluated. If None, the precomputed log_marginal_likelihood
            of ``self.kernel_.theta`` is returned.
        eval_gradient : bool, default: False
            If True, the gradient of the log-marginal likelihood with respect
            to the kernel hyperparameters at position theta is returned
            additionally. If True, theta must not be None.
        clone_kernel : bool, default=False
            If True, the kernel attribute is copied. If False, the kernel
            attribute is modified, but may result in a performance improvement.
        Returns
        -------
        log_likelihood : float
            Log-marginal likelihood of theta for training data.
        log_likelihood_gradient : array, shape = (n_kernel_params,), optional
            Gradient of the log-marginal likelihood with respect to the kernel
            hyperparameters at position theta.
            Only returned when eval_gradient is True.
        """

        if theta is None:
            if eval_gradient:
                raise ValueError(
                    "Gradient can only be evaluated for theta!=None")
            return self.log_marginal_likelihood_value_

        kernel_matrix_fn = self.kernel_.get_kernel_matrix_fn(eval_gradient)

        if eval_gradient:
            K, K_gradient = kernel_matrix_fn(theta, self.X_train_, None)
        else:
            K = kernel_matrix_fn(theta, self.X_train_, None)

        # Compute log-marginal-likelihood Z and also store some temporaries
        # which can be reused for computing Z's gradient
        Z, (pi, W_sr, L, b, a) = \
            self._posterior_mode(K, return_temporaries=True)

        if not eval_gradient:
            return Z

        # Compute gradient based on Algorithm 5.1 of GPML

        d_Z = np.empty(theta.shape[0])
        # XXX: Get rid of the np.diag() in the next line
        R = W_sr[:, np.newaxis] * cho_solve((L, True), np.diag(W_sr))  # Line 7
        C = solve(L, W_sr[:, np.newaxis] * K)  # Line 8
        # Line 9: (use einsum to compute np.diag(C.T.dot(C)))
        s_2 = -0.5 * (np.diag(K) - np.einsum('ij, ij -> j', C, C)) \
            * (pi * (1 - pi) * (1 - 2 * pi))  # third derivative

        for j in range(d_Z.shape[0]):
            C = K_gradient[:, :, j]  # Line 11
            # Line 12: (R.T.ravel().dot(C.ravel()) = np.trace(R.dot(C)))
            s_1 = .5 * a.T.dot(C).dot(a) - .5 * R.T.ravel().dot(C.ravel())

            b = C.dot(self.y_train_ - pi)  # Line 13
            s_3 = b - K.dot(R.dot(b))  # Line 14

            d_Z = ops.index_update(d_Z, j, s_1 + s_2.T.dot(s_3))  # Line 15

        return (numpy.asarray(Z, dtype=numpy.float64),
                numpy.asarray(d_Z, dtype=numpy.float64))
Example #20
    def filter(self, x_hist, jump_size, dt):
        """
        Compute the online version of the Kalman filter, i.e.,
        the one-step-ahead prediction for the hidden state, or the
        time-update step.
        
        Parameters
        ----------
        x_hist: array(timesteps, observation_size)
            
        Returns
        -------
        * array(timesteps, state_size):
            Filtered means mut
        * array(timesteps, state_size, state_size)
            Filtered covariances Sigmat
        * array(timesteps, state_size)
            Filtered conditional means mut|t-1
        * array(timesteps, state_size, state_size)
            Filtered conditional covariances Sigmat|t-1
        """
        I = jnp.eye(self.state_size)
        timesteps, *_ = x_hist.shape
        mu_hist = jnp.zeros((timesteps, self.state_size))
        Sigma_hist = jnp.zeros((timesteps, self.state_size, self.state_size))
        Sigma_cond_hist = jnp.zeros((timesteps, self.state_size, self.state_size))
        mu_cond_hist = jnp.zeros((timesteps, self.state_size))
        
        # Initial configuration
        K1 = self.Sigma0 @ self.C.T @ inv(self.C @ self.Sigma0 @ self.C.T + self.R)
        mu1 = self.mu0 + K1 @ (x_hist[0] - self.C @ self.mu0)
        Sigma1 = (I - K1 @ self.C) @ self.Sigma0

        mu_hist = index_update(mu_hist, 0, mu1)
        Sigma_hist = index_update(Sigma_hist, 0, Sigma1)
        mu_cond_hist = index_update(mu_cond_hist, 0, self.mu0)
        Sigma_cond_hist = index_update(Sigma_cond_hist, 0, self.Sigma0)
        
        Sigman = Sigma1.copy()
        mun = mu1.copy()
        for n in range(1, timesteps):
            # Runge-kutta integration step
            for _ in range(jump_size):
                k1 = self.A @ mun
                k2 = self.A @ (mun + dt * k1)
                mun = mun + dt * (k1 + k2) / 2

                k1 = self.A @ Sigman @ self.A.T + self.Q
                k2 = self.A @ (Sigman + dt * k1) @ self.A.T + self.Q
                Sigman = Sigman + dt * (k1 + k2) / 2

            Sigman_cond = Sigman.copy()
            St = self.C @ Sigman_cond @ self.C.T + self.R
            Kn = Sigman_cond @ self.C.T @ inv(St)

            mu_update = mun.copy()
            x_update = self.C @ mun
            mun = mu_update + Kn @ (x_hist[n] - x_update)
            Sigman = (I - Kn @ self.C) @ Sigman_cond

            mu_hist = index_update(mu_hist, n, mun)
            Sigma_hist = index_update(Sigma_hist, n, Sigman)
            mu_cond_hist = index_update(mu_cond_hist, n, mu_update)
            Sigma_cond_hist = index_update(Sigma_cond_hist, n, Sigman_cond)
        
        return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist
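The inner "Runge-kutta integration step" above is Heun's method (explicit RK2). A standalone sketch of one predicted-mean step for a linear drift mu' = A @ mu, with a made-up A:

import jax.numpy as jnp

def heun_step(A, mu, dt):
    k1 = A @ mu
    k2 = A @ (mu + dt * k1)
    return mu + dt * (k1 + k2) / 2

A = jnp.array([[0.0, 1.0], [-1.0, 0.0]])           # harmonic oscillator
print(heun_step(A, jnp.array([1.0, 0.0]), 0.01))   # ~[0.99995, -0.01]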
Example #21
# print(cov)

rng, key = random.split(rng)
comp = random.choice(key, M, shape=(N, ), p=pi)

samples = jnp.zeros(shape=(N, D), dtype=float)
rng, *key = random.split(rng, M + 1)
for j in range(M):
    idxs = j == comp
    n_j = idxs.sum()
    if n_j > 0:
        x = random.multivariate_normal(key[j],
                                       mean=mu[j],
                                       cov=cov[j],
                                       shape=(n_j, ))
        samples = index_update(samples, index[idxs, :], x)

true_S = jnp.array([
    jnp.append(jnp.append(cov[j] + jnp.outer(mu[j], mu[j]),
                          jnp.array([mu[j]]),
                          axis=0),
               jnp.array([jnp.append(mu[j], 1)]).T,
               axis=1) for j in range(M)
])
true_eta = jnp.array([jnp.log(pi[j] / pi[-1]) for j in range(M - 1)])

piemp = jnp.array([jnp.mean(comp == i) for i in range(M)])
muemp = jnp.array([jnp.mean(samples[comp == i], axis=0) for i in range(M)])
covemp = jnp.array([
    (samples[comp == i].T @ samples[comp == i]) / jnp.sum(comp == i)
    for i in range(M)
])
Example #22
def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) = x_{n} * prod_{i=1...n-1}(u_{ii}).
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  b = _promote_arg_dtypes(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
    and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    return a[0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
  sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                      x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
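A small check of the identity in the docstring, assuming _cofactor_solve and its imports are available as above: for a nonsingular a, the first return value should equal det(a) and the second should equal det(a) * solve(a, b).

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
a = random.normal(key, (4, 4))
b = jnp.eye(4)

det_a, adj_b = _cofactor_solve(a, b)
print(det_a, jnp.linalg.det(a))  # the two determinants should agree
print(jnp.max(jnp.abs(adj_b - jnp.linalg.det(a) * jnp.linalg.solve(a, b))))  # ~0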
Example #23
 def copy_values_from_cell(value, cell_value, cell_id):
   scatter_indices = jnp.reshape(cell_id, (-1,))
   cell_value = jnp.reshape(cell_value, (-1,) + cell_value.shape[-2:])
   return ops.index_update(value, scatter_indices, cell_value)
Example #24
 def scan_fn(BB, elems):
     o, g = elems
     BB = index_update(BB, index[:, o], BB[:, o] + g)
     return BB, jnp.zeros((0,))
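Reading BB[:, o], adding g, and scattering the result back is just a column accumulation; with the .at[] syntax that replaced jax.ops.index_update the same scan_fn reads:

import jax.numpy as jnp

def scan_fn(BB, elems):
    o, g = elems
    BB = BB.at[:, o].add(g)  # accumulate g into column o
    return BB, jnp.zeros((0,))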
Example #25
 def step_fn(i, state_and_energy):
   state, energy = state_and_energy
   state = apply_fn(state)
   energy = ops.index_update(energy, i, E_T(state))
   return state, energy
Example #26
  def build_cells(R):
    N = R.shape[0]
    dim = R.shape[1]

    if dim != 2 and dim != 3:
      raise ValueError(
          'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    neighborhood_tile_count = 3 ** dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    if species is None:
      _species = np.zeros((N,), dtype=i32)
    else:
      _species = species

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(np.int64, N)
    mask_id = np.ones((N,), np.int64) * N
    cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
    empty_species_index = i32(1000)
    cell_species = empty_species_index * np.ones(
        (cell_count * cell_capacity, 1), dtype=_species.dtype)
    cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

    indices = np.array(R / cell_size, dtype=i32)
    hashes = np.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, .., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = np.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_species = _species[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_species = np.reshape(sorted_species, (N, 1))
    cell_species = ops.index_update(
        cell_species, sorted_cell_id, sorted_species)
    sorted_id = np.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(
        cell_id, sorted_cell_id, sorted_id)
    
    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_species = _unflatten_cell_buffer(cell_species, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    return CellList(N, dim, cell_count, cell_R, cell_species, cell_id)
Example #27
 def inv(self, y):
     size = self.permutation.size
     permutation_inv = ops.index_update(jnp.zeros(size, dtype=canonicalize_dtype(jnp.int64)),
                                        self.permutation,
                                        jnp.arange(size))
     return y[..., permutation_inv]
Example #28
 def copy_values_from_cell(value, cell_value, cell_id):
   scatter_indices = np.reshape(cell_id, (-1,))
   cell_value = np.reshape(cell_value, (-1, output_dimension))
   return ops.index_update(value, scatter_indices, cell_value)
Example #29
                iteration.append(it)

    graph = np.stack((iteration, training_loss))

    print("total time:", time.time() - tot, "s")

    # np.random.seed(42)
    t_test, W_test = fetch_minibatch(T, M, N, D)
    # X_pred, Y_pred, Y_tilde_pred, Z, DYDT = vXYZpaths(params, t_test, W_test, Xzero)
    # X_pred, Y_pred, Y_tilde_pred, Z = vXYZpaths(params, t_test, W_test, Xzero)
    X_pred, Y_pred, Y_tilde_pred, Z, DY_pred, DY_tilde_pred = vXYZpaths(
        params, t_test, W_test, Xzero)

    Dt = jnp.zeros((M, N + 1, 1))  # M x (N+1) x 1
    dt = T / N
    new_Dt = index_update(Dt, index[:, 1:, :], dt)
    t_plot = jnp.cumsum(new_Dt, axis=1)  # M x (N+1) x 1

    Y_test = jnp.reshape(
        u_exact(np.reshape(t_plot[0:M, :, :], [-1, 1]),
                jnp.reshape(X_pred[0:M, :, :], [-1, D])),
        [M, -1, 1])  # fix all these unnecessary reshapes at some point

    np.save('t_test.npy', t_test)
    np.save('W_test.npy', W_test)
    np.save('t_plot.npy', t_plot)
    np.save('X_pred.npy', X_pred)
    np.save('Y_pred.npy', Y_pred)
    np.save('Y_tilde_pred.npy', Y_tilde_pred)
    np.save('Y_test.npy', Y_test)
    # np.save('DYDT_test.npy', DYDT)
Example #30
def get_env(ipeps_tensors, chi_ctm, bvar_threshold, max_iter):
    # TODO should we symmetrise the ipeps tensor?

    a, = ipeps_tensors
    chi_peps = a.shape[1]

    # initialise environment
    # (p*,uldr) & (p,uldr) -> (uldr,uldr) -> (uu,ll,dd,rr)
    flat_tens = np.transpose(np.tensordot(np.conj(a), a, [0, 0]),
                             [0, 4, 1, 5, 2, 6, 3, 7])
    u, _u, l, _l, d, _d, r, _r = flat_tens.shape
    # (uu,ll,dd,rr) -> (U,L,d,d',R)
    flat_tens = np.reshape(flat_tens, [u * _u, l * _l, d, _d, r * _r])
    c_init = np.sum(flat_tens, axis=(2, 3, 4))  # (D,R)
    t_init = np.sum(flat_tens, axis=0)  # (L,d,d',R)

    if c_init.shape[0] > chi_ctm:
        c_init = c_init[:chi_ctm, :chi_ctm]
        t_init = t_init[:chi_ctm, :, :, :chi_ctm]

    # enforce c4v symmetry
    c_init, t_init = _c4v_symmetrise(c_init, t_init, normalise=True)
    # expand to full chi_ctm, for traceability
    _chi = c_init.shape[0]
    if _chi < chi_ctm:
        c_init = index_update(np.zeros([chi_ctm, chi_ctm], dtype=c_init.dtype),
                              index[:_chi, :_chi], c_init)
        t_init = index_update(
            np.zeros([chi_ctm, chi_peps, chi_peps, chi_ctm],
                     dtype=t_init.dtype), index[:_chi, :, :, :_chi], t_init)

    env_init = c_init, t_init

    def update(b, env):
        c, t = env

        # C insertion
        c_tilde = ncon([c, t, t, b, np.conj(b)],
                       [[1, 2], [2, 3, 4, -4], [-1, 5, 6, 1],
                        [7, 3, 5, -2, -5], [7, 4, 6, -3, -6]],
                       [1, 2, 3, 5, 4, 6, 7])
        # (D,d,d',R,r,r') -> (D~,R~)
        _D, _d, _d_, _R, _r, _r_ = c_tilde.shape
        c_tilde = np.reshape(c_tilde, [_D * _d * _d_, _R * _r * _r_])
        # T insertion
        t_tilde = ncon(
            [t, b, np.conj(b)],
            [[-1, 1, 2, -6], [3, 1, -2, -4, -7], [3, 2, -3, -5, -8]])
        # (L,l,l',d,d',R,r,r') -> (L~,d,d',R~)
        _L, _l, _l_, _d, _d_, _R, _r, _r_ = t_tilde.shape
        t_tilde = np.reshape(t_tilde, [_L * _l * _l_, _d, _d_, _R * _r * _r_])

        # enforce symmetry
        c_tilde = _c4v_symmetrise_c(c_tilde)

        # find projector
        P, _, _ = svd_truncated(c_tilde, chi_max=chi_ctm, cutoff=0.)  # (D~,R)

        # renormalise
        c = np.transpose(P) @ c_tilde @ P
        t = ncon([np.conj(P), t_tilde, np.conj(P)],
                 [[1, -1], [1, -2, -3, 2], [2, -4]])

        # enforce symmetry
        env = _c4v_symmetrise(c, t)

        return env

    def convergence_condition(b, env, _):
        c, t = env
        return _variance3(c, t, b) < bvar_threshold

    env_star = fixed_points.fixed_point_novjp(update,
                                              a,
                                              env_init,
                                              convergence_condition,
                                              max_iter=max_iter)
    return env_star