def nonbonded_v3(
    conf,
    params,
    box,
    lamb,
    charge_rescale_mask,
    lj_rescale_mask,
    beta,
    cutoff,
    lambda_plane_idxs,
    lambda_offset_idxs,
    runtime_validate=True,
):
    """Lennard-Jones + Coulomb, with a few important twists:
    * distances are computed in 4D, controlled by lambda, lambda_plane_idxs, lambda_offset_idxs
    * each pairwise LJ and Coulomb term can be multiplied by an adjustable rescale_mask parameter
    * Coulomb terms are multiplied by erfc(beta * distance)

    Parameters
    ----------
    conf : (N, 3) or (N, 4) np.array
        3D or 4D coordinates
        if 3D, will be converted to 4D using (x,y,z) -> (x,y,z,w)
        where w = cutoff * (lambda_plane_idxs + lambda_offset_idxs * lamb)
    params : (N, 3) np.array
        columns [charges, sigmas, epsilons], one row per particle
    box : Optional 3x3 np.array
    lamb : float
    charge_rescale_mask : (N, N) np.array
        the Coulomb contribution of pair (i,j) will be multiplied by charge_rescale_mask[i,j]
    lj_rescale_mask : (N, N) np.array
        the Lennard-Jones contribution of pair (i,j) will be multiplied by lj_rescale_mask[i,j]
    beta : float
        the charge product q_ij will be multiplied by erfc(beta*d_ij)
    cutoff : Optional float
        a pair of particles (i,j) will be considered non-interacting if the distance d_ij
        between their 4D coordinates exceeds cutoff
    lambda_plane_idxs : Optional (N,) np.array
    lambda_offset_idxs : Optional (N,) np.array
    runtime_validate : bool
        check whether beta is compatible with cutoff
        (if True, this function will currently not play nice with Jax JIT)
        TODO: is there a way to conditionally print a runtime warning inside of a
        Jax JIT-compiled function, without triggering a Jax ConcretizationTypeError?

    Returns
    -------
    energy : float

    References
    ----------
    * Rodinger, Howell, Pomès, 2005, J. Chem. Phys.
      "Absolute free energy calculations by thermodynamic integration in four spatial dimensions"
      https://aip.scitation.org/doi/abs/10.1063/1.1946750
    * Darden, York, Pedersen, 1993, J. Chem. Phys.
      "Particle mesh Ewald: An N log(N) method for Ewald sums in large systems"
      https://aip.scitation.org/doi/abs/10.1063/1.470117
      * Coulomb interactions are treated using the direct-space contribution from eq 2
    """
    if runtime_validate:
        assert (charge_rescale_mask == charge_rescale_mask.T).all()
        assert (lj_rescale_mask == lj_rescale_mask.T).all()

    N = conf.shape[0]

    if conf.shape[-1] == 3:
        conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff)

    # make the 4th dimension of the box large enough that it is roughly aperiodic
    if box is not None:
        if box.shape[-1] == 3:
            box_4d = np.eye(4) * 1000
            box_4d = index_update(box_4d, index[:3, :3], box)
        else:
            box_4d = box
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)
    eps_ij = eps_i * eps_j

    dij = distance(conf, box)

    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        if runtime_validate:
            validate_coulomb_cutoff(cutoff, beta, threshold=1e-2)
        eps_ij = np.where(dij < cutoff, eps_ij, 0)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, 0)
    eps_ij = np.where(keep_mask, eps_ij, 0)

    inv_dij = 1 / dij
    inv_dij = np.where(np.eye(N), 0, inv_dij)

    sig2 = sig_ij * inv_dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, 0)

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    keep_mask = 1 - np.eye(N)
    qij = np.where(keep_mask, qij, 0)
    dij = np.where(keep_mask, dij, 0)

    # funnily enough, lim_{x->0} erfc(x)/x = 0
    eij_charge = np.where(keep_mask, qij * erfc(beta * dij) * inv_dij, 0)  # zero out diagonals

    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, 0, eij_charge)

    eij_total = eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask

    return np.sum(eij_total / 2)

def step_fn(i, state_and_energy):
    state, energy = state_and_energy
    state = apply_fn(state)
    energy = ops.index_update(energy, i, invariant(state, kT))
    return state, energy

def filter(self, init_state, sample_obs, observations=None, Vinit=None):
    """
    Run the Unscented Kalman Filter algorithm over a set of observed samples.

    Parameters
    ----------
    sample_obs: array(nsamples, obs_size)

    Returns
    -------
    * array(nsamples, state_size)
        History of filtered mean terms
    * array(nsamples, state_size, state_size)
        History of filtered covariance terms
    """
    wm_vec = jnp.array([
        1 / (2 * (self.d + self.lmbda)) if i > 0
        else self.lmbda / (self.d + self.lmbda)
        for i in range(2 * self.d + 1)
    ])
    wc_vec = jnp.array([
        1 / (2 * (self.d + self.lmbda)) if i > 0
        else self.lmbda / (self.d + self.lmbda) + (1 - self.alpha**2 + self.beta)
        for i in range(2 * self.d + 1)
    ])
    nsteps, *_ = sample_obs.shape
    mu_t = init_state
    Sigma_t = self.Q if Vinit is None else Vinit

    if observations is None:
        observations = [()] * nsteps
    else:
        observations = [(obs,) for obs in observations]

    mu_hist = jnp.zeros((nsteps, self.d))
    Sigma_hist = jnp.zeros((nsteps, self.d, self.d))

    mu_hist = index_update(mu_hist, 0, mu_t)
    Sigma_hist = index_update(Sigma_hist, 0, Sigma_t)

    for t in range(nsteps):
        # TO-DO: use jax.scipy.linalg.sqrtm when it gets added to the library
        comp1 = mu_t[:, None] + self.gamma * self.sqrtm(Sigma_t)
        comp2 = mu_t[:, None] - self.gamma * self.sqrtm(Sigma_t)
        # sigma_points = jnp.c_[mu_t, comp1, comp2]
        sigma_points = jnp.concatenate((mu_t[:, None], comp1, comp2), axis=1)

        z_bar = self.fz(sigma_points)
        mu_bar = z_bar @ wm_vec
        Sigma_bar = (z_bar - mu_bar[:, None])
        Sigma_bar = jnp.einsum("i,ji,ki->jk", wc_vec, Sigma_bar, Sigma_bar) + self.Q

        Sigma_bar_half = self.sqrtm(Sigma_bar)
        comp1 = mu_bar[:, None] + self.gamma * Sigma_bar_half
        comp2 = mu_bar[:, None] - self.gamma * Sigma_bar_half
        # sigma_points = jnp.c_[mu_bar, comp1, comp2]
        sigma_points = jnp.concatenate((mu_bar[:, None], comp1, comp2), axis=1)

        x_bar = self.fx(sigma_points, *observations[t])
        x_hat = x_bar @ wm_vec
        St = x_bar - x_hat[:, None]
        St = jnp.einsum("i,ji,ki->jk", wc_vec, St, St) + self.R

        mu_hat_component = z_bar - mu_bar[:, None]
        x_hat_component = x_bar - x_hat[:, None]
        Sigma_bar_y = jnp.einsum("i,ji,ki->jk", wc_vec, mu_hat_component, x_hat_component)
        Kt = Sigma_bar_y @ jnp.linalg.inv(St)

        mu_t = mu_bar + Kt @ (sample_obs[t] - x_hat)
        Sigma_t = Sigma_bar - Kt @ St @ Kt.T

        mu_hist = index_update(mu_hist, t, mu_t)
        Sigma_hist = index_update(Sigma_hist, t, Sigma_t)

    return mu_hist, Sigma_hist

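# Illustration of the unscented-transform weights built by the list comprehensions above,
# written out for a concrete case. Assumes d = 2 states and the common parametrisation
# lmbda = alpha**2 * (d + kappa) - d; all values are illustrative only.
import jax.numpy as jnp

d, alpha, beta, kappa = 2, 1e-3, 2.0, 0.0
lmbda = alpha**2 * (d + kappa) - d
wm = jnp.array([lmbda / (d + lmbda)] + [1 / (2 * (d + lmbda))] * (2 * d))
wc = jnp.array([lmbda / (d + lmbda) + (1 - alpha**2 + beta)]
               + [1 / (2 * (d + lmbda))] * (2 * d))
# wm sums to 1; wc differs from wm only in the weight of the central sigma point.
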
def build_cells(R: Array, extra_capacity: int=0, **kwargs) -> CellList:
    N = R.shape[0]
    dim = R.shape[1]

    _cell_capacity = cell_capacity + extra_capacity

    if dim != 2 and dim != 3:
        # NOTE(schsam): Do we want to check this in compute_fn as well?
        raise ValueError(
            'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    neighborhood_tile_count = 3 ** dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(jnp.int64, N)
    # NOTE(schsam): We use the convention that particles that are successfully
    # copied have their true id, whereas particles in empty slots have id = N.
    # Then when we copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    mask_id = jnp.ones((N,), jnp.int64) * N
    cell_R = jnp.zeros((cell_count * _cell_capacity, dim), dtype=R.dtype)
    cell_id = N * jnp.ones((cell_count * _cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10 ** 5
    cell_kwargs = {}
    for k, v in kwargs.items():
        if not util.is_array(v):
            raise ValueError((
                'Data must be specified as an ndarray. Found "{}" with '
                'type {}'.format(k, type(v))))
        if v.shape[0] != R.shape[0]:
            raise ValueError(
                ('Data must be specified per-particle (an ndarray with shape '
                 '(R.shape[0], ...)). Found "{}" with shape {}'.format(k, v.shape)))
        kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
        cell_kwargs[k] = empty_kwarg_value * jnp.ones(
            (cell_count * _cell_capacity,) + kwarg_shape, v.dtype)

    indices = jnp.array(R / cell_size, dtype=i32)
    hashes = jnp.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, ..., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = jnp.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]
    sorted_kwargs = {}
    for k, v in kwargs.items():
        sorted_kwargs[k] = v[sort_map]

    sorted_cell_id = jnp.mod(lax.iota(jnp.int64, N), _cell_capacity)
    sorted_cell_id = sorted_hash * _cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_id = jnp.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)
    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
        if v.ndim == 1:
            v = jnp.reshape(v, v.shape + (1,))
        cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id, v)
        cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k], cells_per_side, dim)

    return CellList(cell_R, cell_id, cell_kwargs)  # pytype: disable=wrong-arg-count

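# Worked toy example of the "hash * capacity + slot" trick described in the comment above.
# The hashes and capacity are illustrative; with cell_capacity = 4, particles sorted by
# hash land on distinct flat destinations as long as no cell holds more than 4 particles.
import jax.numpy as jnp

toy_cell_capacity = 4
toy_sorted_hash = jnp.array([0, 0, 1, 1, 1, 3])             # cell hash per (sorted) particle
toy_slot = jnp.mod(jnp.arange(6), toy_cell_capacity)        # 0, 1, 2, 3, 0, 1
toy_destination = toy_sorted_hash * toy_cell_capacity + toy_slot   # 0, 1, 6, 7, 4, 13 (all unique)
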
def _update_history_scalars(history, new):
    # TODO(Jakob-Unfried) use rolling buffer instead? See #6053
    return ops.index_update(jnp.roll(history, -1, axis=0), ops.index[-1], new)

def count(cell_hash, filling):
    count = np.sum(particle_hash == cell_hash)
    filling = ops.index_update(filling, ops.index[cell_hash], count)
    return filling

def inv(self, y):
    size = self.permutation.size
    permutation_inv = ops.index_update(np.zeros(size, dtype=np.int64),
                                       self.permutation,
                                       np.arange(size))
    return y[..., permutation_inv]

def body_fun(i, k):
    ti = t0 + dt * alpha[i - 1]
    yi = y0 + dt * np.dot(beta[i - 1, :], k)
    ft = func(yi, ti)
    return ops.index_update(k, jax.ops.index[i, :], ft)

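# Sketch of how a stage-filling body like the one above is typically driven with
# lax.fori_loop. The tableau values (alpha, beta, c_sol) here are the classical RK4
# coefficients and, together with the names, are assumptions for illustration only.
import jax.numpy as jnp
from jax import lax, ops

def rk4_step(func, y0, t0, dt):
    # y0 is assumed to be a 1-D state vector
    alpha = jnp.array([1 / 2, 1 / 2, 1.0])                # stage time offsets
    beta = jnp.array([[1 / 2, 0, 0, 0],
                      [0, 1 / 2, 0, 0],
                      [0, 0, 1.0, 0]])                    # stage couplings
    c_sol = jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6])       # output weights

    def body_fun(i, k):
        ti = t0 + dt * alpha[i - 1]
        yi = y0 + dt * jnp.dot(beta[i - 1, :], k)
        return ops.index_update(k, ops.index[i, :], func(yi, ti))

    k = jnp.zeros((4, y0.shape[0]))
    k = ops.index_update(k, ops.index[0, :], func(y0, t0))
    k = lax.fori_loop(1, 4, body_fun, k)
    return y0 + dt * jnp.dot(c_sol, k)
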
def do_subject(subject_id):
    fm = agg_fit_metadata.loc[subject_id]
    name = base_models.loc[subject_id]['name']
    agg_res = aggregation_results.loc[subject_id]
    starting_model = agg_res.model
    o = Optimizer(agg_res, *fm[['psf', 'galaxy_data', 'sigma_image']], oversample_n=5)

    # define the parameters controlling only the brightness of components, and
    # fit them first
    L_keys = get_luminosity_keys(o.model)

    # perform the first fit
    with tqdm(desc='Fitting brightness', leave=False) as bar:
        res = minimize(
            __f,
            onp.array([o.model_[k] for k in L_keys]),
            jac=__j,
            args=(o, L_keys),
            callback=__bar_incrementer(bar),
            bounds=onp.array([o.lims_[k] for k in L_keys]),
        )

    # update the optimizer with the new parameters
    for k, v in zip(L_keys, res['x']):
        o[k] = v

    # perform the full fit
    with tqdm(desc='Fitting everything', leave=False) as bar:
        res_full = minimize(
            __f,
            onp.array([o.model_[k] for k in o.keys]),
            jac=__j,
            args=(o, o.keys),
            callback=__bar_incrementer(bar),
            bounds=onp.array([o.lims_[k0][k1] for k0, k1 in o.keys]),
            options=dict(maxiter=10000),
        )

    final_model = pd.Series({
        **deepcopy(o.model_),
        **{k: v for k, v in zip(o.keys, res_full['x'])}
    })

    # correct the parameters of spirals in this model for the new disk,
    # allowing rendering of the model without needing the rotation of the disk
    # before fitting
    final_model = correct_spirals(final_model, o.base_roll)

    # fix component axis ratios (if > 1, flip major and minor axis)
    final_model = correct_axratio(final_model)

    # remove components with zero brightness
    final_model = remove_zero_brightness_components(final_model)

    # lower the indices of spirals where possible
    final_model = lower_spiral_indices(final_model)

    comps = o.render_comps(final_model.to_dict(), correct_spirals=False)

    d = ops.index_update(
        psf_conv(sum(comps.values()), o.psf) - o.target,
        o.mask,
        np.nan)
    chisq = float(np.sum((d[~o.mask] / o.sigma[~o.mask])**2) / (~o.mask).sum())

    disk_spiral_L = (
        final_model[('disk', 'L')]
        + (comps['spiral'].sum() if 'spiral' in comps else 0)
    )

    # fractions were originally parametrized vs the disk and spirals (bulge
    # had no knowledge of bar and vice versa)
    bulge_frac = final_model.get(('bulge', 'frac'), 0)
    bar_frac = final_model.get(('bar', 'frac'), 0)

    bulge_L = bulge_frac * disk_spiral_L / (1 - bulge_frac)
    bar_L = bar_frac * disk_spiral_L / (1 - bar_frac)

    gal_L = disk_spiral_L + bulge_L + bar_L
    bulge_frac = bulge_L / (disk_spiral_L + bulge_L + bar_L)
    bar_frac = bar_L / (disk_spiral_L + bulge_L + bar_L)

    deparametrized_model = from_reparametrization(final_model, o)

    ftol = 2.220446049250313e-09
    # Also calculate Hessian-errors
    errs = np.sqrt(
        max(1, abs(res_full.fun))
        * ftol
        * np.diag(res_full.hess_inv.todense()))

    os.makedirs('affirmation_subjects_results/tuning_results', exist_ok=True)
    pd.to_pickle(
        dict(
            base_model=starting_model,
            fit_model=final_model,
            deparametrized=deparametrized_model,
            res=res_full,
            chisq=chisq,
            comps=comps,
            r_band_luminosity=float(gal_L),
            bulge_frac=float(bulge_frac),
            bar_frac=float(bar_frac),
            errs=errs,
            keys=o.keys,
        ),
        'affirmation_subjects_results/tuning_results/{}.pickle.gz'.format(name))

def test_input_admin(t, y, r, t_test, y_test, r_test):
    """
    TODO: tidy this function up
    Order the inputs, remove duplicates, and index the train and test input locations.
    :param t: training inputs [N, 1]
    :param y: observations at the training inputs [N, 1]
    :param r: training spatial inputs
    :param t_test: testing inputs [N*, 1]
    :param y_test: observations at the test inputs [N*, 1]
    :param r_test: test spatial inputs
    :return:
        t_all: the combined and sorted training and test inputs [N + N*, 1]
        y_all: an array of observations y augmented with nans at test locations [N + N*, R]
        r_all: spatial inputs with nans at test locations [N + N*, R]
        dt_all: combined training and test step sizes, Δtₙ = tₙ - tₙ₋₁ [N + N*, 1]
        dt_train: training step sizes, Δtₙ = tₙ - tₙ₋₁ [N, 1]
        train_id: an array of indices corresponding to the training inputs [N, 1]
        test_id: an array of indices corresponding to the test inputs [N*, 1]
        mask: boolean array to signify training locations [N + N*, 1]
    """
    assert t.shape[0] == y.shape[0]
    if t.ndim < 2:
        t = np.expand_dims(t, 1)  # make 2-D
    if y.ndim < 2:
        y = np.expand_dims(y, 1)  # make 2-D
    if r is None:
        r = np.nan * t  # np.empty((1,) + x.shape[1:]) * np.nan
    if r.ndim < 2:
        r = np.expand_dims(r, 1)  # make 2-D
    ind = np.argsort(t[:, 0], axis=0)
    t_train = t[ind, ...]
    y_train = y[ind, ...]
    r_train = r[ind, ...]
    if t_test is None:
        t_test = np.empty((1,) + t_train.shape[1:]) * np.nan
        r_test = np.empty((1,) + t_train.shape[1:]) * np.nan
    else:
        if t_test.ndim < 2:
            t_test = np.expand_dims(t_test, 1)  # make 2-D
        test_sort_ind = np.argsort(t_test[:, 0], axis=0)
        t_test = t_test[test_sort_ind, ...]
        if y_test is not None:
            y_test = y_test[test_sort_ind, ...].reshape((-1,) + y.shape[1:])
        if r_test is not None:
            r_test = r_test[test_sort_ind, ...]
        else:
            r_test = np.nan * t_test
    if not (t_test.shape[1] == t_train.shape[1]):
        t_test = np.concatenate([
            t_test[:, 0][:, None],
            np.nan * np.empty([t_test.shape[0], t_train.shape[1] - 1])
        ], axis=1)
    # here we use non-JAX numpy to sort out indexing of these static arrays
    t_train_test = np.concatenate([t_train, t_test])
    keep_ind = ~np.isnan(t_train_test[:, 0])
    t_train_test = t_train_test[keep_ind, ...]
    if r_test.shape[1] != r_train.shape[1]:
        # do spatial test points have different dimensionality to training points?
        r_test_nan = np.nan * np.zeros([r_test.shape[0], r_train.shape[1]])
    else:
        r_test_nan = r_test
    r_train_test = np.concatenate([r_train, r_test_nan])
    r_train_test = r_train_test[keep_ind, ...]
    t_ind = np.argsort(t_train_test[:, 0])
    t_all = t_train_test[t_ind]
    r_all = r_train_test[t_ind]
    reverse_ind = np.argsort(t_ind)
    n_train = t_train.shape[0]
    train_id = reverse_ind[:n_train]  # index the training locations
    test_id = reverse_ind[n_train:]  # index the test locations
    # observation vector with nans at test locations
    y_all = np.nan * np.zeros([t_all.shape[0], y_train.shape[1]])
    # y_all[reverse_ind[:n_train], ...] = y_train
    y_all = index_update(y_all, index[reverse_ind[:n_train]], y_train)  # the data at the train locations
    if y_test is not None:
        # y_all[reverse_ind[n_train:], ...] = y_test
        y_all = index_update(y_all, index[reverse_ind[n_train:]], y_test)  # the data at the test locations
    mask = np.ones_like(y_all, dtype=bool)
    # mask[train_id] = False
    mask = index_update(mask, index[train_id], False)
    dt_all = np.concatenate([np.array([0.0]), np.diff(t_all[:, 0])])
    return (np.array(t_all, dtype=np.float64),
            np.array(y_all, dtype=np.float64),
            np.array(r_all, dtype=np.float64),
            np.array(r_test, dtype=np.float64),
            np.array(dt_all, dtype=np.float64),
            np.array(train_id, dtype=np.int64),
            np.array(test_id, dtype=np.int64),
            np.array(mask, dtype=bool))

def get_g(batch_size, A, B, C, Q, Ru, Rv, K, L, T, baseline=0):
    # mini_batch is a single gradient (log-sum derivative of pi); the average of this
    # is the ordinary gradient, but here it is equivalent to g.
    sigma_K = 5e-1
    sigma_L = 5e-1
    sigma_x = 1e-4
    nx, nu = B.shape
    _, nw = C.shape
    K = K.reshape((nu, nx))
    L = L.reshape((nw, nx))
    Q = np.kron(np.eye(T, dtype=int), Q)
    Rv = np.kron(np.eye(T, dtype=int), Rv)
    Ru = np.kron(np.eye(T, dtype=int), Ru)
    X = np.zeros((nx * (T + 1), batch_size))
    # X[0:nx,:] = 0.2 * random.normal(key, shape=(nx,batch_size))
    X = ops.index_update(X, ops.index[0:nx, :],
                         0.2 * random.normal(key, shape=(nx, batch_size)))
    U = np.zeros((nu * T, batch_size))
    W = np.zeros((nw * T, batch_size))
    Vu = sigma_K * random.normal(key, shape=(nu * T, batch_size))  # noise for U
    Vw = sigma_L * random.normal(key, shape=(nw * T, batch_size))  # noise for W
    for t in range(T):
        # U[t*nu:(t+1)*nu,:] = np.matmul(K, X[nx*t:nx*(t+1),:]) + Vu[t*nu:(t+1)*nu,:]
        U = ops.index_update(
            U, ops.index[t * nu:(t + 1) * nu, :],
            np.matmul(K, X[nx * t:nx * (t + 1), :]) + Vu[t * nu:(t + 1) * nu, :])
        # W[t*nw:(t+1)*nw,:] = np.matmul(L, X[nx*t:nx*(t+1),:]) + Vw[t*nw:(t+1)*nw,:]
        W = ops.index_update(
            W, ops.index[t * nw:(t + 1) * nw, :],
            np.matmul(L, X[nx * t:nx * (t + 1), :]) + Vw[t * nw:(t + 1) * nw, :])
        # X[nx*(t+1):nx*(t+2),:] = A @ X_t + (B @ U_t) + (C @ W_t) + process noise
        X = ops.index_update(
            X, ops.index[nx * (t + 1):nx * (t + 2), :],
            np.matmul(A, X[nx * t:nx * (t + 1), :]) +
            np.matmul(B, U[t * nu:(t + 1) * nu, :]).reshape((nx, batch_size)) +
            np.matmul(C, W[t * nw:(t + 1) * nw, :]).reshape((nx, batch_size)) +
            sigma_x * random.normal(key, shape=(nx, batch_size)))

    X_cost = X[nx:, :]
    reward = (np.diagonal(np.matmul(X_cost.T, Q.dot(X_cost)))
              + np.diagonal(np.matmul(U.T, Ru.dot(U)))
              - np.diagonal(np.matmul(W.T, Rv.dot(W))))
    new_baseline = np.mean(reward)
    reward = reward.reshape((len(reward), 1))

    # DK portion
    X_hat = X[:-nx, :]  # taking only t = 0:T-1 of X for the log-gradient computation
    # shape (a,b,c) means there are a of the (b,c) blocks; access (b,c) blocks via C[0,:,:]
    outer_grad_log_K = np.einsum("ik, jk -> ijk", Vu, X_hat)
    outer_grad_log_L = np.einsum("ik, jk -> ijk", Vw, X_hat)
    sum_grad_log_K = 0
    sum_grad_log_L = 0
    for t in range(T):
        # Summing all diagonal blocks gives p by d by batch_size
        sum_grad_log_K += outer_grad_log_K[nu * t:nu * (t + 1), nx * t:nx * (t + 1), :]
        sum_grad_log_L += outer_grad_log_L[nw * t:nw * (t + 1), nx * t:nx * (t + 1), :]

    # mini_batch is p by d, same size as K
    mini_batch_K = (1 / sigma_K)**2 * ((reward - new_baseline).T * sum_grad_log_K)
    # mini_batch is b by a/d, same size as L
    mini_batch_L = (1 / sigma_L)**2 * ((reward - new_baseline).T * sum_grad_log_L)
    # mini_batch_K = 2 * ((reward - new_baseline).T * sum_grad_log_K)
    # mini_batch_L = 2 * ((reward - new_baseline).T * sum_grad_log_L)
    # print(mini_batch_K[0,0,:])
    temp = np.einsum('mnr,ndr->mdr', sum_grad_log_K.swapaxes(0, 1), sum_grad_log_L)
    batch_mixed_KL = (1 / (sigma_K * sigma_L))**2 * ((reward - new_baseline).T * temp)
    # print('---new---', sum_grad_log_K[:,:,10][0,0])
    return (np.mean(mini_batch_K, axis=2),
            np.mean(mini_batch_L, axis=2),
            np.mean(batch_mixed_KL, axis=2),
            new_baseline)

def loop_body(i, acc_arr):
    arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
    return lax.cond(
        i % 2 == 0,
        arr1, lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.),
        arr1, lambda arr1: arr1)

def _body_fn(i, vals):
    val, collection = vals
    val = body_fun(val)
    collection = ops.index_update(collection, i, ravel_fn(val))
    return val, collection

def _body_fn(i, vals):
    val, collection = vals
    val = body_fun(val)
    i = np.where(i >= lower, i - lower, 0)
    collection = ops.index_update(collection, i, ravel_fn(val))
    return val, collection

def lanczos_alg(matrix_vector_product, dim, order, rng_key):
    """Lanczos algorithm for tridiagonalizing a real symmetric matrix.

    This function applies the Lanczos algorithm of a given order, with full
    reorthogonalization.

    WARNING: This function may take a long time to jit compile (e.g. ~3min for
    order 90 and dim 1e7).

    Args:
        matrix_vector_product: Maps v -> Hv for a real symmetric matrix H.
            Input/Output must be of shape [dim].
        dim: Matrix H is [dim, dim].
        order: An integer corresponding to the number of Lanczos steps to take.
        rng_key: The jax PRNG key.

    Returns:
        tridiag: A tridiagonal matrix of size (order, order).
        vecs: A numpy array of size (order, dim) corresponding to the Lanczos vectors.
    """
    tridiag = np.zeros((order, order))
    vecs = np.zeros((order, dim))

    init_vec = random.normal(rng_key, shape=(dim,))
    init_vec = init_vec / np.linalg.norm(init_vec)
    vecs = ops.index_update(vecs, 0, init_vec)

    beta = 0
    # TODO(gilmer): Better to use lax.fori loop for faster compile?
    for i in range(order):
        v = vecs[i, :].reshape((dim))
        if i == 0:
            v_old = 0
        else:
            v_old = vecs[i - 1, :].reshape((dim))

        w = matrix_vector_product(v)
        assert (w.shape[0] == dim and len(w.shape) == 1), (
            'Output of matrix_vector_product(v) must be of shape [dim].')
        w = w - beta * v_old

        alpha = np.dot(w, v)
        tridiag = ops.index_update(tridiag, (i, i), alpha)
        w = w - alpha * v

        # Full reorthogonalization
        for j in range(i):
            tau = vecs[j, :].reshape((dim))
            coeff = np.dot(w, tau)
            w += -coeff * tau

        beta = np.linalg.norm(w)

        # TODO(gilmer): The tf implementation raises an exception if beta < 1e-6
        # here. However JAX cannot compile a function that has an if statement
        # that depends on a dynamic variable. Should we still handle this case?
        # beta being small indicates that the lanczos vectors are linearly
        # dependent.

        if i + 1 < order:
            tridiag = ops.index_update(tridiag, (i, i + 1), beta)
            tridiag = ops.index_update(tridiag, (i + 1, i), beta)
            vecs = ops.index_update(vecs, i + 1, w / beta)

    return (tridiag, vecs)

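# Usage sketch for the routine above: run it on a small explicit symmetric matrix and
# read approximate (Ritz) eigenvalues off the tridiagonal factor. The matrix H and the
# closure `mvp` are illustrative assumptions.
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
H = random.normal(key, (50, 50))
H = (H + H.T) / 2                       # make the test matrix symmetric
mvp = lambda v: H @ v                   # matrix-vector product closure

tridiag, lanczos_vecs = lanczos_alg(mvp, dim=50, order=20, rng_key=random.PRNGKey(1))
ritz_values = jnp.linalg.eigvalsh(tridiag)   # approximations to the extreme eigenvalues of H
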
ax.set_xlim([0, 500])
plt.show()

# reward feedback plots
ax = plt.subplot(2, 1, 1)
ax.plot(r[1])  # expected reward
# maybe expected reward is shifted by one trial compared to plosone.
# should it be expected reward at the start of the trial?
# or expected reward for the next trial?
ax.plot(r[0])  # actual reward
ax.set_xlim([0, 500])
ax = plt.subplot(2, 1, 2)
ax.plot([0, 500], [0, 0], '--', color=[0, 0, 0])
ax.plot(r[2])  # reward prediction error
ax.set_ylim([-1, 1])
ax.set_xlim([0, 500])
plt.show()

# this simulation produces Healthy BG population activity
# from the plosone paper as in Fig4 and Fig6
key = PRNGKey(time.time_ns())
w_pfc = 0.01 * uniform(key, (2, 3, 2))
w_pfc = index_update(w_pfc, index[0, 0, 0], 0.7)
w_pfc = index_update(w_pfc, index[1, 1, 0], 0.7)
uu = jnp.zeros((nn, 14))
uu = do_trial_for_figure([key, w_pfc])
plot_all(uu)

def compute_surface_fourier_series(self, r_surface):
    """
    Inputs: r_surface is a NZ x NT x 3 array which has x,y,z as a function of zeta and theta.

    Outputs a 3 x 2 x 2 x (WSNFZ + 1) x (WSNFT + 1) array which contains the Fourier
    components of the surface.
    """
    NZ = r_surface.shape[0]
    NT = r_surface.shape[1]
    x_s = r_surface[:, :, 0]
    y_s = r_surface[:, :, 1]
    z_s = r_surface[:, :, 2]
    # xyz x sin/cos(zeta) x sin/cos(theta) x fz x ft
    result = np.zeros((3, 2, 2, self.WSNFZ + 1, self.WSNFT + 1))
    zeta = np.linspace(0, 2 * PI, NZ + 1)[0:NZ]
    theta = np.linspace(0, 2 * PI, NT + 1)[0:NT]
    # X^{cc}_{0,0} terms for x,y,z
    result = index_update(result, index[:, 1, 1, 0, 0],
                          np.mean(r_surface, axis=(0, 1)))
    for m in range(1, self.WSNFZ + 1):
        # X^{cc}_{m,0}
        result = index_update(result, index[:, 1, 1, m, 0], 2.0 *
            np.mean(r_surface[:, :, :] *
                    np.cos(m * zeta)[:, np.newaxis, np.newaxis], axis=(0, 1)))
        # X^{sc}_{m,0}
        result = index_update(result, index[:, 0, 1, m, 0], 2.0 *
            np.mean(r_surface[:, :, :] *
                    np.sin(m * zeta)[:, np.newaxis, np.newaxis], axis=(0, 1)))
    for n in range(1, self.WSNFT + 1):
        # X^{cc}_{0,n}
        result = index_update(result, index[:, 1, 1, 0, n], 2.0 *
            np.mean(r_surface[:, :, :] *
                    np.cos(n * theta)[np.newaxis, :, np.newaxis], axis=(0, 1)))
        # X^{cs}_{0,n}
        result = index_update(result, index[:, 1, 0, 0, n], 2.0 *
            np.mean(r_surface[:, :, :] *
                    np.sin(n * theta)[np.newaxis, :, np.newaxis], axis=(0, 1)))
    for m in range(1, self.WSNFZ + 1):
        for n in range(1, self.WSNFT + 1):
            # X^{ss}_{m,n}
            result = index_update(result, index[:, 0, 0, m, n], 4.0 *
                np.mean(r_surface[:, :, :] *
                        np.sin(n * theta)[np.newaxis, :, np.newaxis] *
                        np.sin(m * zeta)[:, np.newaxis, np.newaxis], axis=(0, 1)))
            # X^{cs}_{m,n}
            result = index_update(result, index[:, 1, 0, m, n], 4.0 *
                np.mean(r_surface[:, :, :] *
                        np.sin(n * theta)[np.newaxis, :, np.newaxis] *
                        np.cos(m * zeta)[:, np.newaxis, np.newaxis], axis=(0, 1)))
            # X^{sc}_{m,n}
            result = index_update(result, index[:, 0, 1, m, n], 4.0 *
                np.mean(r_surface[:, :, :] *
                        np.cos(n * theta)[np.newaxis, :, np.newaxis] *
                        np.sin(m * zeta)[:, np.newaxis, np.newaxis], axis=(0, 1)))
            # X^{cc}_{m,n}
            result = index_update(result, index[:, 1, 1, m, n], 4.0 *
                np.mean(r_surface[:, :, :] *
                        np.cos(n * theta)[np.newaxis, :, np.newaxis] *
                        np.cos(m * zeta)[:, np.newaxis, np.newaxis], axis=(0, 1)))
    return result

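# Reconstruction sketch for the series computed above: evaluate the surface at a single
# (zeta, theta) point from the coefficient array. The index convention follows the comment
# in the function (axis 1 is sin/cos in zeta with 0=sin, 1=cos; axis 2 likewise in theta).
# The helper name `eval_surface_point` is hypothetical.
import jax.numpy as jnp

def eval_surface_point(coeffs, zeta, theta):
    WSNFZ = coeffs.shape[3] - 1
    WSNFT = coeffs.shape[4] - 1
    m = jnp.arange(WSNFZ + 1)
    n = jnp.arange(WSNFT + 1)
    basis_z = jnp.stack([jnp.sin(m * zeta), jnp.cos(m * zeta)])    # (2, WSNFZ + 1)
    basis_t = jnp.stack([jnp.sin(n * theta), jnp.cos(n * theta)])  # (2, WSNFT + 1)
    # sum over the sin/cos choices and mode numbers for each of x, y, z
    return jnp.einsum('xabmn,am,bn->x', coeffs, basis_z, basis_t)
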
def do_trial(kwr, reversal_learning):
    key = kwr[0]
    w_pfc = kwr[1]
    Re = kwr[2][0]
    key, subkey = split(key)  # jax makes us handle prng state ourselves
    vv = 0.1 * uniform(subkey, (7 * 2,))
    # set initial conditions before each simulation
    # i've taken these details from plosone modeldb matlab code:
    vv = index_update(vv, index[4], vv[4] + .6)
    vv = index_update(vv, index[4 + 6], vv[4 + 6] + .6)
    vv = index_update(vv, index[0], 0.)
    vv = index_update(vv, index[1], 0.)
    vk = [vv, key]
    sim_step = partial(simulation_step, w_pfc=w_pfc)
    # for debugging purposes use this loop:
    # (but don't call this do_trial method hundreds of times)
    # for i in range(nn - 1):
    #     vk = sim_step(i + 1, vk)
    # for performance use this "loop":
    vv, key = lax.fori_loop(1, nn, sim_step, vk)
    # rename variables for sanity
    pfc = vv[:2]  # note this is of length two
    d1A = vv[2]
    d2A = vv[3]
    pmcA = vv[7]
    d1B = vv[2 + 6]
    d2B = vv[3 + 6]
    pmcB = vv[7 + 6]
    pmc = [pmcA, pmcB]
    # these jax.lax.cond constructs replace some traditional
    # "if statement" condition logic blocks. this is done for
    # quick and easy jax.jit compatibility. please check out
    # the jax documentation.
    # this picks the rewarded action
    rewardedAction, otherAction = lax.cond(reversal_learning,
                                           pmc, lambda x: [x[1], x[0]],
                                           pmc, lambda x: [x[0], x[1]])
    # this determines reward for this trial
    R_trial = lax.cond(rewardedAction > otherAction + 0.1,
                       None, lambda x: 1,
                       None, lambda x: 0)
    # reward prediction error:
    SNc = R_trial - Re
    # expected reward for next trial:
    a = 0.15
    Re_next = a * R_trial + (1 - a) * Re
    # weight updates:
    # the i,j,k notation below refers to a diagram that i drew and tacked
    # to my cork board. it should end up in the git repo.
    # hopefully these comments explain the array
    # operations that we use to update all 12 weights with just a few commands
    # i notation is elided; that is the dimension of cues {#1,#2}
    # but we are instead just handling the vector pfc (shape = (2,))
    # j notation here indicates dimension of bg loops: {A,B}
    # k notation here indicates dimension of neuronal populations: {d1,d2,pmc}
    # qjk: array with population firing rates in each loop
    # sk: may modify qjk with SNc (which is reward prediction error)
    sk = jnp.array([SNc, SNc, 1])
    qjk = jnp.array([[d1A, d2A, pmcA], [d1B, d2B, pmcB]])
    # sq: product of sk and qjk;
    # this should only modify d1,d2 msns by SNc; pmc is multiplied by 1
    # why? explanation:
    # pfc -> d1,d2 weights are modified by reward prediction error
    # pfc -> pmc weights are updated in a hebbian fashion
    sq = (sk * qjk).reshape((2, 3, 1))
    # we reshape this product from (2,3) to (2,3,1)
    # in preparation for weight update operations
    # weight update rule:
    # pfc * sq
    # (2,) * (2,3,1) -> (2,3,2), which is the same shape as w_pfc
    # lrk: learning rate for each population (3,1)
    # lrk * (the product of pfc and sq):
    # (3,1) * (2,3,2) -> (2,3,2)
    # frk: forgetting rate for each population (3,1)
    # frk * w_pfc:
    # (3,1) * (2,3,2) -> (2,3,2)
    dw_pfc = lrk * pfc * sq - frk * w_pfc
    # update weights; force new weights to be positive:
    w_pfc = jnp.clip(w_pfc + dw_pfc, a_min=0)  # w_pfc should be (2,3,2)
    return key, w_pfc, [Re_next, R_trial, SNc], pmc

def log_marginal_likelihood(self, theta=None, eval_gradient=False, clone_kernel=False):
    """Returns log-marginal likelihood of theta for training data.

    Parameters
    ----------
    theta : array-like of shape (n_kernel_params,) or None
        Kernel hyperparameters for which the log-marginal likelihood is
        evaluated. If None, the precomputed log_marginal_likelihood
        of ``self.kernel_.theta`` is returned.

    eval_gradient : bool, default: False
        If True, the gradient of the log-marginal likelihood with respect
        to the kernel hyperparameters at position theta is returned
        additionally. If True, theta must not be None.

    clone_kernel : bool, default=False
        If True, the kernel attribute is copied. If False, the kernel
        attribute is modified, but may result in a performance improvement.

    Returns
    -------
    log_likelihood : float
        Log-marginal likelihood of theta for training data.

    log_likelihood_gradient : array, shape = (n_kernel_params,), optional
        Gradient of the log-marginal likelihood with respect to the kernel
        hyperparameters at position theta.
        Only returned when eval_gradient is True.
    """
    if theta is None:
        if eval_gradient:
            raise ValueError(
                "Gradient can only be evaluated for theta!=None")
        return self.log_marginal_likelihood_value_

    kernel_matrix_fn = self.kernel_.get_kernel_matrix_fn(eval_gradient)
    if eval_gradient:
        K, K_gradient = kernel_matrix_fn(theta, self.X_train_, None)
    else:
        K = kernel_matrix_fn(theta, self.X_train_, None)

    # Compute log-marginal-likelihood Z and also store some temporaries
    # which can be reused for computing Z's gradient
    Z, (pi, W_sr, L, b, a) = \
        self._posterior_mode(K, return_temporaries=True)

    if not eval_gradient:
        return Z

    # Compute gradient based on Algorithm 5.1 of GPML
    d_Z = np.empty(theta.shape[0])
    # XXX: Get rid of the np.diag() in the next line
    R = W_sr[:, np.newaxis] * cho_solve((L, True), np.diag(W_sr))  # Line 7
    C = solve(L, W_sr[:, np.newaxis] * K)  # Line 8
    # Line 9: (use einsum to compute np.diag(C.T.dot(C)))
    s_2 = -0.5 * (np.diag(K) - np.einsum('ij, ij -> j', C, C)) \
        * (pi * (1 - pi) * (1 - 2 * pi))  # third derivative

    for j in range(d_Z.shape[0]):
        C = K_gradient[:, :, j]  # Line 11
        # Line 12: (R.T.ravel().dot(C.ravel()) = np.trace(R.dot(C)))
        s_1 = .5 * a.T.dot(C).dot(a) - .5 * R.T.ravel().dot(C.ravel())

        b = C.dot(self.y_train_ - pi)  # Line 13
        s_3 = b - K.dot(R.dot(b))  # Line 14

        d_Z = ops.index_update(d_Z, j, s_1 + s_2.T.dot(s_3))  # Line 15

    return (numpy.asarray(Z, dtype=numpy.float64),
            numpy.asarray(d_Z, dtype=numpy.float64))

def filter(self, x_hist, jump_size, dt):
    """
    Compute the online version of the Kalman filter, i.e., the one-step-ahead
    prediction for the hidden state, or the time update step.

    Parameters
    ----------
    x_hist: array(timesteps, observation_size)

    Returns
    -------
    * array(timesteps, state_size):
        Filtered means mut
    * array(timesteps, state_size, state_size)
        Filtered covariances Sigmat
    * array(timesteps, state_size)
        Filtered conditional means mut|t-1
    * array(timesteps, state_size, state_size)
        Filtered conditional covariances Sigmat|t-1
    """
    I = jnp.eye(self.state_size)
    timesteps, *_ = x_hist.shape
    mu_hist = jnp.zeros((timesteps, self.state_size))
    Sigma_hist = jnp.zeros((timesteps, self.state_size, self.state_size))
    Sigma_cond_hist = jnp.zeros((timesteps, self.state_size, self.state_size))
    mu_cond_hist = jnp.zeros((timesteps, self.state_size))

    # Initial configuration
    K1 = self.Sigma0 @ self.C.T @ inv(self.C @ self.Sigma0 @ self.C.T + self.R)
    mu1 = self.mu0 + K1 @ (x_hist[0] - self.C @ self.mu0)
    Sigma1 = (I - K1 @ self.C) @ self.Sigma0

    mu_hist = index_update(mu_hist, 0, mu1)
    Sigma_hist = index_update(Sigma_hist, 0, Sigma1)
    mu_cond_hist = index_update(mu_cond_hist, 0, self.mu0)
    Sigma_cond_hist = index_update(Sigma_cond_hist, 0, self.Sigma0)

    Sigman = Sigma1.copy()
    mun = mu1.copy()
    for n in range(1, timesteps):
        # Runge-Kutta integration step
        for _ in range(jump_size):
            k1 = self.A @ mun
            k2 = self.A @ (mun + dt * k1)
            mun = mun + dt * (k1 + k2) / 2

            k1 = self.A @ Sigman @ self.A.T + self.Q
            k2 = self.A @ (Sigman + dt * k1) @ self.A.T + self.Q
            Sigman = Sigman + dt * (k1 + k2) / 2

        Sigman_cond = Sigman.copy()
        St = self.C @ Sigman_cond @ self.C.T + self.R
        Kn = Sigman_cond @ self.C.T @ inv(St)

        mu_update = mun.copy()
        x_update = self.C @ mun
        mun = mu_update + Kn @ (x_hist[n] - x_update)
        Sigman = (I - Kn @ self.C) @ Sigman_cond

        mu_hist = index_update(mu_hist, n, mun)
        Sigma_hist = index_update(Sigma_hist, n, Sigman)
        mu_cond_hist = index_update(mu_cond_hist, n, mu_update)
        Sigma_cond_hist = index_update(Sigma_cond_hist, n, Sigman_cond)

    return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist

# print(cov)
rng, key = random.split(rng)
comp = random.choice(key, M, shape=(N,), p=pi)
samples = jnp.zeros(shape=(N, D), dtype=float)
rng, *key = random.split(rng, M + 1)
for j in range(M):
    idxs = j == comp
    n_j = idxs.sum()
    if n_j > 0:
        x = random.multivariate_normal(key[j], mean=mu[j], cov=cov[j], shape=(n_j,))
        samples = index_update(samples, index[idxs, :], x)

true_S = jnp.array([
    jnp.append(jnp.append(cov[j] + jnp.outer(mu[j], mu[j]),
                          jnp.array([mu[j]]), axis=0),
               jnp.array([jnp.append(mu[j], 1)]).T, axis=1)
    for j in range(M)
])
true_eta = jnp.array([jnp.log(pi[j] / pi[-1]) for j in range(M - 1)])

piemp = jnp.array([jnp.mean(comp == i) for i in range(M)])
muemp = jnp.array([jnp.mean(samples[comp == i], axis=0) for i in range(M)])
covemp = jnp.array([
    (samples[comp == i].T @ samples[comp == i]) / jnp.sum(comp == i)
    for i in range(M)
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

    Intermediate function used for jvp and vjp of det.
    This function borrows heavily from jax.numpy.linalg.solve and
    jax.numpy.linalg.slogdet to compute the gradient of the determinant
    in a way that is well defined even for low rank matrices.

    This function handles two different cases:
    * rank(a) == n or n-1
    * rank(a) < n-1

    For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
    Rather than computing det(a)*solve(a, b), which would return NaN, we work
    directly with the LU decomposition. If a = p @ l @ u, then
      det(a)*solve(a, b) =
      prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
      prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
    If a is rank n-1, then the lower right corner of u will be zero and the
    triangular_solve will fail.
    Let x = solve(p @ l, b) and y = det(a)*solve(a, b). Then
      y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) = x_{n} * prod_{i=1...n-1}(u_{ii})
    So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
    we can avoid the triangular_solve failing.
    To correctly compute the rest of y_{i} for i != n, we simply multiply
    x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

    For the second case, a check is done on the matrix to see if `solve`
    returns NaN or Inf, and gives a matrix of zeros as a result, as the
    gradient of the determinant of a matrix with rank less than n-1 is 0.
    This will still return the correct value for rank n-1 matrices, as the
    check is applied *after* the lower right corner of u has been updated.

    Args:
      a: A square matrix or batch of matrices, possibly singular.
      b: A matrix, or batch of matrices of the same dimension as a.

    Returns:
      det(a) and cofactor(a)^T*b, aka adjugate(a)*b
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    if a_shape[-1] == 1:
        return a[0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[:, -1] contains the full determinant and
    # partial_det[:, -2] contains det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    lu = ops.index_update(lu, ops.index[..., -1, -1],
                          1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
    # filter out any matrices that are not full rank
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                    unit_diagonal=True)
    x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                         x[..., -1:, :]), axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x

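# Sanity-check sketch for the function above: for a well-conditioned matrix, the first
# return value should match det(a) and the second should match det(a) * solve(a, b).
# The test matrix and tolerance are illustrative assumptions.
import jax.numpy as jnp
from jax import random

check_key = random.PRNGKey(0)
a_check = random.normal(check_key, (4, 4)) + 4.0 * jnp.eye(4)   # comfortably nonsingular
b_check = jnp.eye(4)

det_a, adjugate_b = _cofactor_solve(a_check, b_check)
print(jnp.allclose(det_a, jnp.linalg.det(a_check), atol=1e-4))
print(jnp.allclose(adjugate_b,
                   jnp.linalg.det(a_check) * jnp.linalg.solve(a_check, b_check),
                   atol=1e-4))
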
def copy_values_from_cell(value, cell_value, cell_id):
    scatter_indices = jnp.reshape(cell_id, (-1,))
    cell_value = jnp.reshape(cell_value, (-1,) + cell_value.shape[-2:])
    return ops.index_update(value, scatter_indices, cell_value)

def scan_fn(BB, elems):
    o, g = elems
    BB = index_update(BB, index[:, o], BB[:, o] + g)
    return BB, jnp.zeros((0,))

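# Usage sketch for the scan body above: drive it with lax.scan to scatter-add a sequence
# of gradient vectors into the columns of BB, grouped by an integer index per step.
# All names and values here are illustrative assumptions.
import jax.numpy as jnp
from jax import lax

d, n_cols, n_steps = 3, 4, 6
col_idx = jnp.array([0, 2, 2, 1, 3, 0])   # which column each per-step gradient belongs to
grads = jnp.ones((n_steps, d))            # one length-d gradient per step

BB_final, _ = lax.scan(scan_fn, jnp.zeros((d, n_cols)), (col_idx, grads))
# BB_final[:, 2] now holds the sum of the two gradients whose col_idx == 2.
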
def step_fn(i, state_and_energy):
    state, energy = state_and_energy
    state = apply_fn(state)
    energy = ops.index_update(energy, i, E_T(state))
    return state, energy

def build_cells(R):
    N = R.shape[0]
    dim = R.shape[1]

    if dim != 2 and dim != 3:
        raise ValueError(
            'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    neighborhood_tile_count = 3 ** dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    if species is None:
        _species = np.zeros((N,), dtype=i32)
    else:
        _species = species

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(np.int64, N)
    mask_id = np.ones((N,), np.int64) * N
    cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
    empty_species_index = i32(1000)
    cell_species = empty_species_index * np.ones(
        (cell_count * cell_capacity, 1), dtype=_species.dtype)
    cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

    indices = np.array(R / cell_size, dtype=i32)
    hashes = np.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, ..., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = np.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_species = _species[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_species = np.reshape(sorted_species, (N, 1))
    cell_species = ops.index_update(cell_species, sorted_cell_id, sorted_species)
    sorted_id = np.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)

    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_species = _unflatten_cell_buffer(cell_species, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    return CellList(N, dim, cell_count, cell_R, cell_species, cell_id)

def inv(self, y):
    size = self.permutation.size
    permutation_inv = ops.index_update(
        jnp.zeros(size, dtype=canonicalize_dtype(jnp.int64)),
        self.permutation,
        jnp.arange(size))
    return y[..., permutation_inv]

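# Worked example of the inverse-permutation trick used in both `inv` methods above:
# scattering arange(size) with the permutation as the index yields the inverse permutation.
# (In newer JAX, ops.index_update(x, idx, y) is written x.at[idx].set(y).) Values below
# are illustrative only.
import jax.numpy as jnp
from jax import ops

perm = jnp.array([2, 0, 3, 1])
perm_inv = ops.index_update(jnp.zeros(4, dtype=jnp.int32), perm, jnp.arange(4))
# perm_inv == [1, 3, 0, 2]; applying it to a permuted vector restores the original:
x = jnp.array([10., 20., 30., 40.])
forward = x[perm]                              # [30., 10., 40., 20.]
print(jnp.allclose(forward[perm_inv], x))      # True
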
def copy_values_from_cell(value, cell_value, cell_id):
    scatter_indices = np.reshape(cell_id, (-1,))
    cell_value = np.reshape(cell_value, (-1, output_dimension))
    return ops.index_update(value, scatter_indices, cell_value)

iteration.append(it)
graph = np.stack((iteration, training_loss))
print("total time:", time.time() - tot, "s")

# np.random.seed(42)
t_test, W_test = fetch_minibatch(T, M, N, D)
# X_pred, Y_pred, Y_tilde_pred, Z, DYDT = vXYZpaths(params, t_test, W_test, Xzero)
# X_pred, Y_pred, Y_tilde_pred, Z = vXYZpaths(params, t_test, W_test, Xzero)
X_pred, Y_pred, Y_tilde_pred, Z, DY_pred, DY_tilde_pred = vXYZpaths(
    params, t_test, W_test, Xzero)

Dt = jnp.zeros((M, N + 1, 1))  # M x (N+1) x 1
dt = T / N
new_Dt = index_update(Dt, index[:, 1:, :], dt)
t_plot = jnp.cumsum(new_Dt, axis=1)  # M x (N+1) x 1

Y_test = jnp.reshape(
    u_exact(np.reshape(t_plot[0:M, :, :], [-1, 1]),
            jnp.reshape(X_pred[0:M, :, :], [-1, D])),
    [M, -1, 1])  # fix all these unnecessary reshapes at some point

np.save('t_test.npy', t_test)
np.save('W_test.npy', W_test)
np.save('t_plot.npy', t_plot)
np.save('X_pred.npy', X_pred)
np.save('Y_pred.npy', Y_pred)
np.save('Y_tilde_pred.npy', Y_tilde_pred)
np.save('Y_test.npy', Y_test)
# np.save('DYDT_test.npy', DYDT)

def get_env(ipeps_tensors, chi_ctm, bvar_threshold, max_iter):
    # TODO should we symmetrise the ipeps tensor?
    a, = ipeps_tensors
    chi_peps = a.shape[1]

    # initialise environment
    # (p*,uldr) & (p,uldr) -> (uldr,uldr) -> (uu,ll,dd,rr)
    flat_tens = np.transpose(np.tensordot(np.conj(a), a, [0, 0]),
                             [0, 4, 1, 5, 2, 6, 3, 7])
    u, _u, l, _l, d, _d, r, _r = flat_tens.shape
    # (uu,ll,dd,rr) -> (U,L,d,d',R)
    flat_tens = np.reshape(flat_tens, [u * _u, l * _l, d, _d, r * _r])
    c_init = np.sum(flat_tens, axis=(2, 3, 4))  # (D,R)
    t_init = np.sum(flat_tens, axis=0)  # (L,d,d',R)

    if c_init.shape[0] > chi_ctm:
        c_init = c_init[:chi_ctm, :chi_ctm]
        t_init = t_init[:chi_ctm, :, :, :chi_ctm]

    # enforce c4v symmetry
    c_init, t_init = _c4v_symmetrise(c_init, t_init, normalise=True)

    # expand to full chi_ctm, for traceability
    _chi = c_init.shape[0]
    if _chi < chi_ctm:
        c_init = index_update(
            np.zeros([chi_ctm, chi_ctm], dtype=c_init.dtype),
            index[:_chi, :_chi], c_init)
        t_init = index_update(
            np.zeros([chi_ctm, chi_peps, chi_peps, chi_ctm], dtype=t_init.dtype),
            index[:_chi, :, :, :_chi], t_init)

    env_init = c_init, t_init

    def update(b, env):
        c, t = env

        # C insertion
        c_tilde = ncon([c, t, t, b, np.conj(b)],
                       [[1, 2], [2, 3, 4, -4], [-1, 5, 6, 1],
                        [7, 3, 5, -2, -5], [7, 4, 6, -3, -6]],
                       [1, 2, 3, 5, 4, 6, 7])
        # (D,d,d',R,r,r') -> (D~,R~)
        _D, _d, _d_, _R, _r, _r_ = c_tilde.shape
        c_tilde = np.reshape(c_tilde, [_D * _d * _d_, _R * _r * _r_])

        # T insertion
        t_tilde = ncon([t, b, np.conj(b)],
                       [[-1, 1, 2, -6], [3, 1, -2, -4, -7], [3, 2, -3, -5, -8]])
        # (L,l,l',d,d',R,r,r') -> (L~,d,d',R~)
        _L, _l, _l_, _d, _d_, _R, _r, _r_ = t_tilde.shape
        t_tilde = np.reshape(t_tilde, [_L * _l * _l_, _d, _d_, _R * _r * _r_])

        # enforce symmetry
        c_tilde = _c4v_symmetrise_c(c_tilde)

        # find projector
        P, _, _ = svd_truncated(c_tilde, chi_max=chi_ctm, cutoff=0.)  # (D~,R)

        # renormalise
        c = np.transpose(P) @ c_tilde @ P
        t = ncon([np.conj(P), t_tilde, np.conj(P)],
                 [[1, -1], [1, -2, -3, 2], [2, -4]])

        # enforce symmetry
        env = _c4v_symmetrise(c, t)
        return env

    def convergence_condition(b, env, _):
        c, t = env
        return _variance3(c, t, b) < bvar_threshold

    env_star = fixed_points.fixed_point_novjp(update, a, env_init,
                                              convergence_condition,
                                              max_iter=max_iter)
    return env_star