def interp2d(
    x: jnp.ndarray,
    y: jnp.ndarray,
    xp: jnp.ndarray,
    yp: jnp.ndarray,
    zp: jnp.ndarray,
    fill_value: jnp.ndarray = None,
) -> jnp.ndarray:
    """
    Bilinear interpolation on a grid.

    Args:
        x, y: 1D arrays of points at which to interpolate.
        xp, yp: 1D arrays, sorted in increasing order, giving the grid
            coordinates where function values are provided.
        zp: 2D array of function values. For a function `f(x, y)` this must
            satisfy `zp[i, j] = f(xp[i], yp[j])`.
        fill_value: optional value returned for points outside the grid
            (strictly below `xp[0]`/`yp[0]` or above `xp[-1]`/`yp[-1]`).
            If None (the default), out-of-bounds points are NOT clamped:
            they are linearly extrapolated from the nearest edge cell.

    Returns:
        1D array `z` satisfying `z[i] = f(x[i], y[i])`.

    Raises:
        ValueError: if `xp` or `yp` is not 1D, or `zp.shape` is not
            `xp.shape + yp.shape`.
    """
    if xp.ndim != 1 or yp.ndim != 1:
        raise ValueError("xp and yp must be 1D arrays")
    if zp.shape != (xp.shape + yp.shape):
        raise ValueError("zp must be a 2D array with shape xp.shape + yp.shape")

    # Upper-corner index of the enclosing cell. Clipping to [1, n-1] keeps the
    # cell inside the grid, so out-of-range points reuse the edge cell and
    # are extrapolated by the weights below (which may leave [0, 1]).
    ix = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
    iy = jnp.clip(jnp.searchsorted(yp, y, side="right"), 1, len(yp) - 1)

    # Using Wikipedia's notation (https://en.wikipedia.org/wiki/Bilinear_interpolation)
    z_11 = zp[ix - 1, iy - 1]
    z_21 = zp[ix, iy - 1]
    z_12 = zp[ix - 1, iy]
    z_22 = zp[ix, iy]

    # Linear weights along x (towards the lower / upper grid line).
    dx = xp[ix] - xp[ix - 1]
    wx_lo = (xp[ix] - x) / dx
    wx_hi = (x - xp[ix - 1]) / dx
    z_xy1 = wx_lo * z_11 + wx_hi * z_21
    z_xy2 = wx_lo * z_12 + wx_hi * z_22

    # Linear weights along y.
    dy = yp[iy] - yp[iy - 1]
    wy_lo = (yp[iy] - y) / dy
    wy_hi = (y - yp[iy - 1]) / dy
    z = wy_lo * z_xy1 + wy_hi * z_xy2

    if fill_value is not None:
        oob = (x < xp[0]) | (x > xp[-1]) | (y < yp[0]) | (y > yp[-1])
        z = jnp.where(oob, fill_value, z)
    return z
def sample_pdf(bins, weights, num_samples, rng, det):
    """Draw samples by inverting the piecewise-linear CDF built from `weights`.

    Args:
        bins: per-row bin positions (indexed with the CDF's bracket indices).
        weights: per-row non-negative weights; 1e-5 is added before
            normalizing so no row sums to zero.
        num_samples: number of samples per row.
        rng: PRNG key, used only when `det` is False.
        det: if True use evenly spaced quantiles in [0, 1], otherwise
            uniform random quantiles.

    Returns:
        Array of samples with shape `weights.shape[:-1] + (num_samples,)`.
        The per-row `vmap` means inputs are expected to be 2D.
    """
    padded = weights + 1e-5
    pdf = padded / jnp.sum(padded, axis=-1, keepdims=True)
    running = jnp.cumsum(pdf, -1)
    cdf = jnp.concatenate((jnp.zeros_like(running[..., :1]), running), -1)

    batch_shape = cdf.shape[:-1]
    if not det:
        u = jax.random.uniform(rng, list(batch_shape) + [num_samples])
    else:
        u = jnp.broadcast_to(jnp.linspace(0.0, 1.0, num_samples),
                             tuple(batch_shape) + (num_samples,))

    def _locate(cdf_row, u_row):
        return jnp.searchsorted(cdf_row, u_row, side="right").astype(np.int32)

    inds = vmap(_locate)(cdf, u)
    lower = jnp.maximum(0, inds - 1)
    upper = jnp.minimum(cdf.shape[-1] - 1, inds)
    bracket = jnp.stack((lower, upper), axis=-1)

    gather_rows = vmap(lambda row, ix: row[ix])
    cdf_g = gather_rows(cdf, bracket)
    bins_g = gather_rows(bins, bracket)

    # Both bracket indices are <= cdf.shape[-1] - 1, so the cdf mask never
    # fires. The bins mask only matters if `bins` is shorter than `cdf`.
    # JAX gathers clamp such indices, so those entries are zeroed instead.
    cdf_g = jnp.where(bracket < cdf.shape[-1], cdf_g, 0)
    bins_g = jnp.where(bracket < bins.shape[-1], bins_g, 0)

    # Avoid dividing by a near-zero CDF step; such intervals use t = u - cdf_lo.
    step = cdf_g[..., 1] - cdf_g[..., 0]
    step = jnp.where(step < 1e-5, 1.0, step)
    frac = (u - cdf_g[..., 0]) / step
    return bins_g[..., 0] + frac * (bins_g[..., 1] - bins_g[..., 0])
def predict(x, xp, coef):
    """Evaluate a piecewise quadratic `a*x**2 + b*x + c` at `x`.

    Args:
        x: evaluation points.
        xp: sorted breakpoints; segment `k` is selected for points in
            `(xp[k], xp[k+1]]` (left-side search minus one).
        coef: `(a, b, c)` arrays of per-segment coefficients.

    Returns:
        The quadratic of the selected segment evaluated at `x`. Points at or
        below `xp[0]` use segment 0. No upper bound is applied here.
    """
    a, b, c = coef
    segment = jnp.maximum(jnp.searchsorted(xp, x) - 1, 0)
    return a[segment] * x**2 + b[segment] * x + c[segment]
def __call__(self, x_new): i = np.searchsorted(self.x, x_new) - 1 i = np.where(i == -1, 0, i) i = np.where(i == len(self.x) - 1, -1, i) return self.y[i] + self.slopes[i] * (x_new - self.x[i])
def eval(self, t):
    """Evaluate the stored spline at time `t`.

    Interval `k` covers `(ts[k], ts[k+1]]` (left-side search minus one).
    The index is clipped to `[0, Q.shape[0] - 1]`, so times before the
    first knot or after the last one reuse the first or last interval.
    """
    last_interval = self.Q.shape[0] - 1
    k = jnp.clip(jnp.searchsorted(self.ts, t) - 1, 0, last_interval)
    t_start, t_end = self.ts[k], self.ts[k + 1]
    return eval_spline(t_start, t_end, self.Q[k], self.y_old[k], t)
def branch(key, configs, weights):
    """
    Resample walkers by stochastic reconfiguration (systematic resampling).

    One uniform offset is drawn and `nconfig` equally spaced points
    (offset + k/nconfig, wrapped into [0, 1)) are located in the normalized
    cumulative weights. Each located walker is copied, and every new weight
    is set to the average of the old weights.

    Args:
        key: PRNG key; it is split and only the subkey is consumed.
        configs: (nconfig,nelec,3) walker coordinates
        weights: (nconfig,) walker weights

    Returns:
        configs: resampled walker configurations
        weights: (nconfig,) all equal to sum(weights) / nconfig
    """
    n = configs.shape[0]
    total = jnp.sum(weights)
    cdf = jnp.cumsum(weights / total)
    _, subkey = jax.random.split(key)
    offset = jax.random.uniform(subkey)
    positions = (offset + jnp.arange(n) / n) % 1.0
    picks = jnp.searchsorted(cdf, positions)
    return configs[picks], jnp.ones((n, )) * total / n
def _searchsorted(  # pylint: disable=unused-argument
        sorted_sequence,
        values,
        side='left',
        out_type=tf.int32,
        name=None):
    """NumPy implementation with a `searchsorted(..., out_type, name)` signature.

    `name` is accepted only for signature compatibility and is ignored.
    Returns the insertion indices of `values` in `sorted_sequence`, cast to
    `out_type`.
    """
    insertion_points = np.searchsorted(
        sorted_sequence, values, side=side, sorter=None)
    return insertion_points.astype(out_type)
def interp_np(grid, xnew, return_wnext=True, trim=False):
    """Find grid positions and weights for linear interpolation (NumPy version).

    Args:
        grid: 1D array of at least two grid points, sorted ascending (as
            `np.searchsorted` requires).
        xnew: points to locate on `grid`.
        return_wnext: if True, return the weight on `grid[j + 1]`;
            otherwise return the weight on `grid[j]` (i.e. `1 - wnext`).
        trim: if True, first clamp `xnew` into `[grid[0], grid[-1]]`.

    Returns:
        (j, w): left-edge indices `j` in `[0, grid.size - 2]` and the
        requested weight. Without `trim`, points outside the grid get
        weights outside [0, 1], i.e. linear extrapolation from the edge cell.
    """
    if trim:
        xnew = np.minimum(grid[-1], np.maximum(grid[0], xnew))
    # Clip on both sides so `j` is always a valid left edge. Previously only
    # the upper side was clipped, so xnew <= grid[0] gave j = -1, which wraps
    # around to grid[-1].
    j = np.clip(np.searchsorted(grid, xnew, side='left') - 1, 0, grid.size - 2)
    wnext = (xnew - grid[j]) / (grid[j + 1] - grid[j])
    return j, (wnext if return_wnext else 1 - wnext)
def _find_indices(xi, grid): # find relevant edges between which xi are situated indices = [] # compute distance to lower edge in unity units norm_distances = [] # iterate through dimensions for x, grid in zip(xi, grid): i = jnp.searchsorted(grid, x) - 1 indices.append(i) norm_distances.append((x - grid[i]) / (grid[i + 1] - grid[i])) return indices, norm_distances
def body_fun(state):
    """One iteration of the BDF integration while-loop.

    Advances the solver one step, then fills `y_out` rows `[i, index)`
    with values interpolated from the new step, where `index` is the first
    output time with `t_eval[index] >= stepper.t`.

    Args:
        state: `[stepper, t_eval, i, y_out]`. `i` is the first output row
            not yet written.

    Returns:
        `[stepper, t_eval, index, y_out]` with the advanced stepper.
    """
    stepper, t_eval, i, y_out = state
    stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs)
    # Left-side search: every t_eval[j] with j < index is strictly before stepper.t.
    index = jnp.searchsorted(t_eval, stepper.t)

    def for_body(j, y_out):
        t = t_eval[j]
        # jax.ops.index_update was removed from JAX; `.at[...].set` is the
        # supported functional indexed update.
        return y_out.at[j, :].set(_bdf_interpolate(stepper, t))

    y_out = jax.lax.fori_loop(i, index, for_body, y_out)
    return [stepper, t_eval, index, y_out]
def find_rw(rarr, vf, KzzpL):
    """Find rw from the particle-radius array and terminal-velocity array.

    Picks the entry of `rarr` at the first index where `vf >= KzzpL`
    (left-side search). `vf` must be sorted ascending. If `KzzpL` exceeds
    every `vf`, the index equals `len(vf)`, which JAX's gather clamps to
    the last element.

    Args:
        rarr: particle radius array (cm)
        vf: terminal velocity (cm/s)
        KzzpL: Kzz/L in Ackerman and Marley 2001

    Returns:
        rw in Ackerman and Marley 2001
    """
    return rarr[jnp.searchsorted(vf, KzzpL)]
def lotka_volterra_simulate(initial_prey_pred: jnp.ndarray,
                            times: jnp.ndarray,
                            params: jnp.ndarray,
                            random_key: jnp.ndarray,
                            max_iter: int = 10000) -> jnp.ndarray:
    """Simulate the process until max(times), then sample it at times[1:].

    Steps `lotka_volterra_single_step` (via `_while_loop_stacked`, capped at
    `max_iter` iterations) while the current time is below `max(times)`.
    For each query time, it returns the most recent stored state strictly
    before that time.
    """
    max_time = jnp.max(times)

    def keep_going(carry, extra):
        return carry[1] < max_time

    init = ((initial_prey_pred, times[0]), (params, random_key))
    (states, state_times), _, _ = _while_loop_stacked(
        keep_going, lotka_volterra_single_step, init, max_iter)
    # Zero timestamps are treated as unused slots and pushed to +inf so the
    # search ignores them (presumably unfilled rows are zero; confirm).
    state_times = jnp.where(state_times == 0, jnp.inf, state_times)
    lookup = jnp.searchsorted(state_times, times[1:]) - 1
    return states[lookup]
def get_rw(vfs, Kzz, L, rarr):
    """Compute rw in AM01, implicitly defined by eq. (11).

    rw is the entry of `rarr` at the first index where `vfs >= Kzz / L`
    (left-side search). `vfs` must be sorted ascending.

    Args:
        vfs: terminal velocity (cm/s)
        Kzz: diffusion coefficient (cm2/s)
        L: typical convection scale (cm)
        rarr: condensate scale array

    Returns:
        rw: rw (cm) in AM01, i.e. the condensate size that balances upward
        transport and sedimentation
    """
    balance_velocity = Kzz / L
    scale_index = jnp.searchsorted(vfs, balance_velocity)
    return rarr[scale_index]
def get_conditional_mean_matrices(self, x, t):
    """Build the U*/V* matrices used to predict at `t` given data at `x`.

    The U*/V* matrices from `self.get_celerite_matrices(t, t)` are scaled
    in place by `exp(-c * dx)`, with `c = concat(cr, cc, cc)`. For U*,
    `dx = t - x[next]`, where `next` is the left-side insertion index
    clipped to the last element. For V*, `dx = x[prev] - t`, where `prev`
    is that index minus one, clipped at 0.

    Returns:
        (U_star, V_star, inds), where `inds = np.searchsorted(x, t)`.
    """
    _, cr, _, _, cc, _ = self.get_coefficients()
    inds = np.searchsorted(x, t)
    _, U_star, V_star, _ = self.get_celerite_matrices(t, t)
    c = np.concatenate((cr, cc, cc))
    next_x = x[np.minimum(inds, x.size - 1)]
    prev_x = x[np.maximum(inds - 1, 0)]
    U_star *= np.exp(-c[None, :] * (t - next_x)[:, None])
    V_star *= np.exp(-c[None, :] * (prev_x - t)[:, None])
    return U_star, V_star, inds
def temporal_conditional_infinite_horizon(X, X_test, mean, cov, gain, kernel):
    """
    Predict the state mean at times X_test from the state means at times X.

    The training means are padded with a zero state at both ends. Each test
    point is paired with the index of the last training time strictly before
    it (-1 if none). The covariance is not propagated: every test point
    receives a copy of `cov[0]`. `gain` is unused.
    """
    state_dim = kernel.stationary_covariance().shape[0]
    zero_state = np.zeros([1, state_dim, 1])
    padded_mean = np.concatenate([zero_state, mean, zero_state])
    train_times = X.reshape(-1, )
    test_times = X_test.reshape(-1, )
    ind_test = np.searchsorted(train_times, test_times) - 1
    test_mean = predict_from_state_infinite_horizon(
        X_test, ind_test, X, padded_mean, kernel)
    test_cov = np.tile(cov[0], [test_mean.shape[0], 1, 1])
    return test_mean, test_cov
def temporal_conditional(X, X_test, mean, cov, gain, kernel):
    """
    Predict the state at times X_test from state means/covariances at times X.

    Means are padded with a zero state, and covariances with the kernel's
    stationary covariance, at both ends. Gains get a leading zero block.
    Each test point is paired with the index of the last training time
    strictly before it (-1 if none) before calling `predict_from_state`.
    """
    stationary = kernel.stationary_covariance()[None, ...]
    zero_state = np.zeros([1, stationary.shape[1], 1])
    padded_mean = np.concatenate([zero_state, mean, zero_state])
    padded_cov = np.concatenate([stationary, cov, stationary])
    padded_gain = np.concatenate([np.zeros_like(gain[:1]), gain])
    train_times = X.reshape(-1, )
    test_times = X_test.reshape(-1, )
    ind_test = np.searchsorted(train_times, test_times) - 1
    test_mean, test_cov = predict_from_state(
        X_test, ind_test, X, padded_mean, padded_cov, padded_gain, kernel)
    return test_mean, test_cov
def sample_pdf(bins, weights, num_importance, perturbation, rng):
    """Hierarchical sampler: invert the piecewise-linear CDF of `weights`.

    Args:
        bins: (num_rays, num_samples - 1) bins to sample from
        weights: (num_rays, num_samples - 2) weights assigned to each sampled
            color for the coarse model; floored at 1e-5 to prevent NaNs
        num_importance: the number of samples to draw per ray
        perturbation: if True draw random quantiles, otherwise use evenly
            spaced quantiles in [0, 1]
        rng: random key (used only when `perturbation` is True)

    Returns:
        samples: (num_rays, num_importance) the sampled positions
    """
    weights = jnp.maximum(weights, 1e-5)
    pdf = weights / jnp.sum(weights, axis=-1, keepdims=True)
    cdf = jnp.cumsum(pdf, axis=-1)
    cdf = jnp.concatenate([jnp.zeros_like(cdf[..., :1]), cdf], axis=-1)

    shape = [*cdf.shape[:-1], num_importance]
    if not perturbation:
        u = jnp.broadcast_to(jnp.linspace(0.0, 1.0, num_importance), shape)
    else:
        u = random.uniform(rng, shape=shape)

    # Invert the CDF: bracket each quantile between two CDF knots.
    pos = jax.vmap(lambda row, q: jnp.searchsorted(row, q, side="right"))(cdf, u)
    bracket = jnp.stack(
        [jnp.maximum(0, pos - 1), jnp.minimum(cdf.shape[-1] - 1, pos)], axis=-1)
    cdf_g = jnp.take_along_axis(cdf[..., None], bracket, axis=1)
    bins_g = jnp.take_along_axis(bins[..., None], bracket, axis=1)

    cdf_lo = cdf_g[..., 0]
    span = cdf_g[..., 1] - cdf_lo
    span = lax.select(span < 1e-5, jnp.ones_like(span), span)
    frac = (u - cdf_lo) / span
    return bins_g[..., 0] + frac * (bins_g[..., 1] - bins_g[..., 0])
def rational_quadratic_parameters(x: jnp.ndarray, xk: jnp.ndarray,
                                  yk: jnp.ndarray,
                                  delta: jnp.ndarray) -> tuple:
    """Compute intermediate parameters used by the rational-quadratic spline flow.

    For details on the rational quadratic transformation, consult [1].

    [1] https://arxiv.org/pdf/1906.04032.pdf

    Args:
        x: points to transform. Points outside [xk[0], xk[-1]] reuse the
            first/last bin (presumably masked by the caller; confirm).
        xk, yk: knot coordinates, with `xk` sorted ascending.
        delta: derivatives at the knots.

    Returns:
        Tuple `(idx, xi, xib, xisq, dk, dkp, dp, sk, ym)`, where `idx` is the
        left knot of the bin and `xi` is the relative position in the bin.
    """
    # Right knot of the bin containing x, clipped to [1, K-1]. Unclipped,
    # x == xk[0] gives idxn = 0 and idx = -1 (wrapping to the last knot), and
    # x > xk[-1] gives a zero-width bin once JAX clamps xk[idxn].
    idxn = jnp.clip(jnp.searchsorted(xk, x), 1, xk.shape[0] - 1)
    idx = idxn - 1
    width = xk[idxn] - xk[idx]
    xi = (x - xk[idx]) / width
    ym = yk[idxn] - yk[idx]
    sk = ym / width
    dk = delta[idx]
    dkp = delta[idxn]
    dp = dkp + dk
    xib = xi * (1.0 - xi)
    xisq = jnp.square(xi)
    return (idx, xi, xib, xisq, dk, dkp, dp, sk, ym)
def stratified(weights, key):
    """
    Stratified resampling method

    Parameters
    ----------
    weights: array_like
        Weights to resample from (expected to sum to 1)
    key: PRNGKey
        The random key used

    Returns
    -------
    idx: array_like
        The indices of the resampled particles, always in [0, n - 1]
    """
    n_samples = weights.shape[0]
    cumsum = jnp.cumsum(weights)
    u = uniform(key, (n_samples,))
    # One uniform draw inside each of the n equal-width strata of [0, 1).
    aux = (u + jnp.arange(n_samples)) / n_samples
    idx = jnp.searchsorted(cumsum, aux)
    # If cumsum[-1] < 1 (float rounding or slightly unnormalized weights),
    # searchsorted returns n_samples for the last strata. Clamp to a valid index.
    return jnp.clip(idx, 0, n_samples - 1)
def resample(key, samples, log_weights, S=None):
    """
    Resample `samples` using `log_weights` as unnormalized log-probabilities.

    Args:
        key: PRNG key for the uniform draws.
        samples: container passed to `dict_multimap`; every leaf is indexed
            along its leading axis.
        log_weights: log-weight of each sample.
        S: number of output samples. Defaults to int(ESS), with
            ESS = (sum w)^2 / sum w^2 computed in log space.

    Returns:
        S samples of equal weight
    """
    if S is None:
        log_ess = 2. * logsumexp(log_weights) - logsumexp(2. * log_weights)
        S = int(jnp.exp(log_ess))
    # Cumulative sum in log space because some log_weights can be very small.
    p_cuml = jnp.exp(cumulative_logsumexp(log_weights))
    # 1 - U lies in (0, 1], so targets are in (0, total] and stay in range.
    targets = p_cuml[-1] * (1 - random.uniform(key, (S,)))
    idx = jnp.searchsorted(p_cuml, targets)
    return dict_multimap(lambda leaf: leaf[idx, ...], samples)
def set_z_stats(t, z):
    """Assign each point to the interval of `z` that precedes it and count points per interval.

    The first column of `t` is located in the flattened, sorted `z`. The
    assigned index is the last `z` strictly below it (-1 if none).
    Counts cover intervals `0 .. len(z) - 2` only.

    Returns:
        (ind, num_neighbours)
    """
    z_flat = z.reshape(-1, )
    t_first = t[:, :1].reshape(-1, )
    ind = np.searchsorted(z_flat, t_first) - 1
    counts = [(ind == m).sum() for m in range(z.shape[0] - 1)]
    return ind, np.array(counts)
def get(self, step):
    """Delegate to the schedule active at `step`, with the step made relative.

    The active schedule is `schedules[k]`, where `k` counts the milestones
    `<= step` (right-side search). Its step is offset by the milestone
    where it started (0 for the first schedule).
    """
    k = jnp.searchsorted(self.milestones, step, side='right')
    offset = 0
    if k >= 1:
        offset = self.milestones[k - 1]
    return self.schedules[k].get(step - offset)
def phase(self, t):
    """Return the phase index of `t` within its period.

    `t` is wrapped into one period, and the result is the number of
    entries of `self.xp` that are `<= ` the wrapped time (right-side search).
    """
    wrapped = t % self.period
    return jnp.searchsorted(self.xp, wrapped, side="right")
def searchsorted(a, v, side='left', sorter=None):
    """`jnp.searchsorted` that accepts `JaxArray` wrappers for `a` and `v`.

    Wrapped inputs are unwrapped via `.value`. `sorter` is passed through
    unchanged, and the result is re-wrapped in a `JaxArray`.
    """
    raw_a = a.value if isinstance(a, JaxArray) else a
    raw_v = v.value if isinstance(v, JaxArray) else v
    return JaxArray(jnp.searchsorted(raw_a, raw_v, side=side, sorter=sorter))
def visualize_rays(t_vals,
                   weights,
                   rgbs,
                   t_range,
                   accumulate=False,
                   renormalize=False,
                   resolution=512,
                   oversample=1024,
                   bg_color=0.8):
    """Visualize a bundle of rays.

    Rasterizes each ray's per-interval weights and colors onto a grid of
    `resolution` pixels spanning `t_range`, then mattes the result over a
    constant background.

    Args:
      t_vals: doubly nested iterable; `t_vals[a][b]` is a sorted 1D array of
        positions along one ray. The nesting is zipped with `weights` and
        `rgbs`. Presumably these are interval endpoints with
        len(t) == len(w) + 1 (TODO confirm).
      weights: nested like `t_vals`; each leaf is a 1D weight array.
      rgbs: nested like `t_vals`; each leaf is indexed as `r[idx, :]`,
        presumably (num_intervals, 3) (TODO confirm).
      t_range: (start, stop) passed to `jnp.linspace` for the pixel grid.
      accumulate: if True, plot the running weight and the running
        weight-averaged color instead of per-interval values.
      renormalize: if True, scale alphas so their global max is 1.
      resolution: number of output pixels along the ray axis.
      oversample: sub-samples per pixel. Each pixel keeps the average of its
        highest-weighted sub-sample(s).
      bg_color: background gray level used for matting.

    Returns:
      (vis, vis_alpha): matted RGB and alpha images, with the last row removed.
    """
    # Dense grid: `oversample` sub-samples for each of the `resolution` pixels.
    t_vis = jnp.linspace(*t_range, oversample * resolution)
    vis_rgb, vis_alpha = [], []
    for ts, ws, rs in zip(t_vals, weights, rgbs):
        vis_rs, vis_ws = [], []
        for t, w, r in zip(ts, ws, rs):
            if accumulate:
                # Produce the accumulated color and weight at each point along the ray.
                w_csum = jnp.cumsum(w, axis=0)
                rw_csum = jnp.cumsum((r * w[:, None]), axis=0)
                eps = jnp.finfo(jnp.float32).eps
                r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum
            # idx = i for grid points in (t[i], t[i + 1]]. Points at or below t[0]
            # get -1 and points above t[-1] get len(t) - 1; both are masked below.
            idx = jnp.searchsorted(t, t_vis) - 1
            bounds = 0, len(t) - 2
            mask = (idx >= bounds[0]) & (idx <= bounds[1])
            # Out-of-range grid points get zero color and zero weight.
            r_mat = jnp.where(mask[:, None], r[jnp.clip(idx, *bounds), :], 0)
            w_mat = jnp.where(mask, w[jnp.clip(idx, *bounds)], 0)
            # Grab the highest-weighted value in each oversampled span.
            r_mat = r_mat.reshape(resolution, oversample, -1)
            w_mat = w_mat.reshape(resolution, oversample)
            mask = w_mat == w_mat.max(axis=1, keepdims=True)
            # Average over ties. jnp.maximum(1, ...) keeps the divisor >= 1.
            r_ray = (mask[Ellipsis, None] * r_mat).sum(axis=1) / jnp.maximum(
                1, mask.sum(axis=1))[:, None]
            w_ray = (mask * w_mat).sum(axis=1) / jnp.maximum(
                1, mask.sum(axis=1))
            vis_rs.append(r_ray)
            vis_ws.append(w_ray)
        vis_rgb.append(jnp.stack(vis_rs))
        vis_alpha.append(jnp.stack(vis_ws))
    # The inner nesting level ends up on axis 0 and the outer level on axis 1.
    vis_rgb = jnp.stack(vis_rgb, axis=1)
    vis_alpha = jnp.stack(vis_alpha, axis=1)

    if renormalize:
        # Scale the alphas so that the largest value is 1, for visualization.
        vis_alpha /= jnp.maximum(
            jnp.finfo(jnp.float32).eps, jnp.max(vis_alpha))

    if resolution > vis_rgb.shape[0]:
        # Repeat rows `rep` times, then insert one zero (background) row after
        # every `stride` rows.
        # NOTE(review): `.tile` is not a NumPy ndarray method. Confirm the
        # array type here provides it; otherwise this branch needs jnp.tile.
        rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1)
        stride = rep * vis_rgb.shape[1]
        vis_rgb = vis_rgb.tile(
            (1, 1, rep, 1)).reshape((-1, ) + vis_rgb.shape[2:])
        vis_alpha = vis_alpha.tile(
            (1, 1, rep)).reshape((-1, ) + vis_alpha.shape[2:])
        # Add a strip of background pixels after each set of levels of rays.
        vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:])
        vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:])
        vis_rgb = jnp.concatenate(
            [vis_rgb, jnp.zeros_like(vis_rgb[:, :1])],
            axis=1).reshape((-1, ) + vis_rgb.shape[2:])
        vis_alpha = jnp.concatenate(
            [vis_alpha, jnp.zeros_like(vis_alpha[:, :1])],
            axis=1).reshape((-1, ) + vis_alpha.shape[2:])

    # Matte the RGB image over the background.
    vis = vis_rgb * vis_alpha[Ellipsis, None] + (bg_color * (1 - vis_alpha))[Ellipsis, None]

    # Remove the final row of background pixels.
    vis = vis[:-1]
    vis_alpha = vis_alpha[:-1]
    return vis, vis_alpha
def safe_gaussian_kde(samples, weights):
    """Build a weighted Gaussian KDE, falling back to a histogram density.

    Args:
        samples: data points.
        weights: per-sample weights (may be None).

    Returns:
        A callable density estimate. This is normally
        `gaussian_kde(samples, weights=weights, bw_method='silverman')`.
        If construction fails (e.g. singular covariance when all samples are
        identical), it is a 1D piecewise-constant histogram density that
        returns 0 outside the histogram range.
    """
    try:
        return gaussian_kde(samples, weights=weights, bw_method='silverman')
    except Exception:  # was a bare `except:`, which also caught KeyboardInterrupt
        import numpy as onp

        samples_np = onp.ravel(onp.asarray(samples))
        weights_np = None if weights is None else onp.ravel(onp.asarray(weights))
        # String bin estimators ('auto') reject weights, so estimate the edges
        # from the unweighted samples and then histogram with the weights.
        bin_edges = onp.histogram_bin_edges(samples_np, bins='auto')
        # density=True so the fallback has the same units as the KDE it replaces.
        hist, bin_edges = jnp.histogram(
            samples_np, bins=bin_edges, weights=weights_np, density=True)

        def _hist_pdf(x):
            # Bin i covers [edges[i], edges[i+1]); the last bin is closed.
            idx = jnp.clip(jnp.searchsorted(bin_edges, x, side='right') - 1,
                           0, hist.shape[0] - 1)
            inside = (x >= bin_edges[0]) & (x <= bin_edges[-1])
            return jnp.where(inside, hist[idx], 0.0)

        return _hist_pdf