Example #1
def interp2d(
    x: jnp.ndarray,
    y: jnp.ndarray,
    xp: jnp.ndarray,
    yp: jnp.ndarray,
    zp: jnp.ndarray,
    fill_value: jnp.ndarray = None,
) -> jnp.ndarray:
    """
    Bilinear interpolation on a grid.

    Args:
        x, y: 1D arrays of points at which to interpolate. Any out-of-bounds
            coordinates will be clamped to lie in-bounds.
        xp, yp: 1D arrays of points specifying grid points where function values
            are provided.
        zp: 2D array of function values. For a function `f(x, y)` this must
            satisfy `zp[i, j] = f(xp[i], yp[j])`.
        fill_value: if given, out-of-bounds points take this value instead of
            being clamped.

    Returns:
        1D array `z` satisfying `z[i] = f(x[i], y[i])`.
    """
    if xp.ndim != 1 or yp.ndim != 1:
        raise ValueError("xp and yp must be 1D arrays")
    if zp.shape != (xp.shape + yp.shape):
        raise ValueError("zp must be a 2D array with shape xp.shape + yp.shape")

    ix = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
    iy = jnp.clip(jnp.searchsorted(yp, y, side="right"), 1, len(yp) - 1)

    # Using Wikipedia's notation (https://en.wikipedia.org/wiki/Bilinear_interpolation)
    z_11 = zp[ix - 1, iy - 1]
    z_21 = zp[ix, iy - 1]
    z_12 = zp[ix - 1, iy]
    z_22 = zp[ix, iy]

    z_xy1 = (xp[ix] - x) / (xp[ix] - xp[ix - 1]) * z_11 + (x - xp[ix - 1]) / (
        xp[ix] - xp[ix - 1]
    ) * z_21
    z_xy2 = (xp[ix] - x) / (xp[ix] - xp[ix - 1]) * z_12 + (x - xp[ix - 1]) / (
        xp[ix] - xp[ix - 1]
    ) * z_22

    z = (yp[iy] - y) / (yp[iy] - yp[iy - 1]) * z_xy1 + (y - yp[iy - 1]) / (
        yp[iy] - yp[iy - 1]
    ) * z_xy2

    if fill_value is not None:
        oob = (x < xp[0]) | (x > xp[-1]) | (y < yp[0]) | (y > yp[-1])
        z = jnp.where(oob, fill_value, z)

    return z
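A minimal usage sketch for the function above; the grid (f(x, y) = x * y) and query points are made up for illustration:

import jax.numpy as jnp

xp = jnp.linspace(0.0, 1.0, 5)
yp = jnp.linspace(0.0, 2.0, 7)
zp = xp[:, None] * yp[None, :]   # zp[i, j] = f(xp[i], yp[j])
x = jnp.array([0.25, 0.5, 0.9])
y = jnp.array([1.0, 0.3, 1.7])
z = interp2d(x, y, xp, yp, zp)   # approximately x * y at each query point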
Example #2
def sample_pdf(bins, weights, num_samples, rng, det):
    weights = weights + 1e-5
    pdf = weights / jnp.sum(weights, axis=-1, keepdims=True)
    cdf = jnp.cumsum(pdf, -1)
    cdf = jnp.concatenate((jnp.zeros_like(cdf[..., :1]), cdf), -1)

    if det:
        u = jnp.linspace(0.0, 1.0, num_samples)
        u = jnp.broadcast_to(u, cdf.shape[:-1] + (num_samples,))
    else:
        u = jax.random.uniform(rng, list(cdf.shape[:-1]) + [num_samples])

    inds = vmap(
        lambda cdf_i, u_i: jnp.searchsorted(cdf_i, u_i, side="right").astype(jnp.int32)
    )(cdf, u)

    below = jnp.maximum(0, inds - 1)
    above = jnp.minimum(cdf.shape[-1] - 1, inds)
    inds_g = jnp.stack((below, above), axis=-1)

    cdf_g = vmap(lambda cdf_i, inds_gi: cdf_i[inds_gi])(cdf, inds_g)
    bins_g = vmap(lambda bins_i, inds_gi: bins_i[inds_gi])(bins, inds_g)

    # It is unclear why the outliers need to be zeroed out here.
    clean_inds = lambda arr, cutoff: jnp.where(inds_g < cutoff, arr, 0)
    cdf_g = clean_inds(cdf_g, cdf.shape[-1])
    bins_g = clean_inds(bins_g, bins.shape[-1])

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = jnp.where(denom < 1e-5, 1.0, denom)

    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
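A usage sketch with made-up data, assuming `vmap` is in scope as the snippet implies; `bins` and `cdf` share their trailing length after the concatenate above:

import jax
import jax.numpy as jnp
from jax import vmap

key = jax.random.PRNGKey(0)
bins = jnp.broadcast_to(jnp.linspace(0.0, 1.0, 6), (4, 6))  # (num_rays, n_bins)
weights = jax.random.uniform(key, (4, 5))                   # (num_rays, n_bins - 1)
samples = sample_pdf(bins, weights, num_samples=8, rng=key, det=False)  # (4, 8)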
Example #3
def predict(x, xp, coef):

    a, b, c = coef
    idx = jnp.clip(jnp.searchsorted(xp, x) - 1, 0)
    y = a[idx] * x**2 + b[idx] * x + c[idx]

    return y
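A usage sketch with made-up piecewise-quadratic coefficients, one (a, b, c) triple per segment of `xp`:

import jax.numpy as jnp

xp = jnp.array([0.0, 1.0, 2.0])   # segment boundaries
a = jnp.array([1.0, -1.0])
b = jnp.array([0.0, 4.0])
c = jnp.array([0.0, -2.0])
y = predict(jnp.array([0.5, 1.5]), xp, (a, b, c))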
Example #4
    def __call__(self, x_new):

        i = np.searchsorted(self.x, x_new) - 1
        i = np.where(i == -1, 0, i)
        i = np.where(i == len(self.x) - 1, -1, i)

        return self.y[i] + self.slopes[i] * (x_new - self.x[i])
Example #5
 def eval(self, t):
     # searchsorted returns 1 for t inside the first interval, and so on, hence the -1.
     ix = jnp.searchsorted(self.ts, t) - 1
     # In case t is before the beginning or after the end.
     ix = jnp.clip(ix, 0, self.Q.shape[0] - 1)
     return eval_spline(self.ts[ix], self.ts[ix + 1], self.Q[ix],
                        self.y_old[ix], t)
Example #6
def branch(key, configs, weights):
    """
    Perform branching on a set of walkers by stochastic reconfiguration.

    Walkers are resampled with probability proportional to the weights, and the new weights are all set to be equal to the average weight.
    
    Args:
      key: jax PRNG key

      configs: (nconfig,nelec,3) walker coordinates

      weights: (nconfig,) walker weights

    Returns:
      configs: resampled walker configurations

      weights: (nconfig,) all weights are equal to average weight
    """
    nconfig = configs.shape[0]
    wtot = jnp.sum(weights)
    probability = jnp.cumsum(weights / wtot)
    key, subkey = jax.random.split(key)
    base = jax.random.uniform(subkey)
    newinds = jnp.searchsorted(probability,
                               (base + jnp.arange(nconfig) / nconfig) % 1.0)
    configs = configs[newinds]
    weights = jnp.ones((nconfig, )) * wtot / nconfig
    return configs, weights
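A usage sketch with made-up walkers; the function splits the key internally, following its own jax.random.split convention:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
configs = jax.random.normal(key, (6, 2, 3))            # (nconfig, nelec, 3)
weights = jnp.array([0.1, 0.4, 0.05, 1.2, 0.8, 0.3])
configs, weights = branch(key, configs, weights)       # weights are now all equal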
Example #7
def _searchsorted(  # pylint: disable=unused-argument
        sorted_sequence,
        values,
        side='left',
        out_type=tf.int32,
        name=None):
    return np.searchsorted(sorted_sequence, values, side=side,
                           sorter=None).astype(out_type)
Example #8
def interp_np(grid, xnew, return_wnext=True, trim=False):
    # this finds grid positions and weights for performing linear interpolation
    # this implementation uses numpy

    if trim:
        xnew = np.minimum(grid[-1], np.maximum(grid[0], xnew))

    j = np.minimum(np.searchsorted(grid, xnew, side='left') - 1, grid.size - 2)
    wnext = (xnew - grid[j]) / (grid[j + 1] - grid[j])

    return j, (wnext if return_wnext else 1 - wnext)
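A usage sketch: the returned index/weight pair can evaluate any function tabulated on the grid (the data here is made up):

import numpy as np

grid = np.array([0.0, 1.0, 3.0, 6.0])
xnew = np.array([0.5, 2.0, 7.0])
j, wnext = interp_np(grid, xnew, trim=True)
vals = np.sin(grid)                                   # any tabulated function
approx = vals[j] * (1 - wnext) + vals[j + 1] * wnext  # linear interpolation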
Example #9
def _find_indices(xi, grid):
    # find relevant edges between which xi are situated
    indices = []
    # compute distance to lower edge in unity units
    norm_distances = []
    # iterate through dimensions
    for x, g in zip(xi, grid):
        i = jnp.searchsorted(g, x) - 1
        indices.append(i)
        norm_distances.append((x - g[i]) / (g[i + 1] - g[i]))
    return indices, norm_distances
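A usage sketch for a made-up 2D grid; each dimension gets its own index array and normalized distance:

import jax.numpy as jnp

grids = (jnp.linspace(0.0, 1.0, 5), jnp.linspace(0.0, 2.0, 9))
xi = (jnp.array([0.3, 0.7]), jnp.array([0.5, 1.9]))
indices, norm_distances = _find_indices(xi, grids)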
Example #10
    def body_fun(state):
        stepper, t_eval, i, y_out = state
        stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs)
        index = jnp.searchsorted(t_eval, stepper.t)

        def for_body(j, y_out):
            t = t_eval[j]
            y_out = y_out.at[j, :].set(_bdf_interpolate(stepper, t))
            return y_out

        y_out = jax.lax.fori_loop(i, index, for_body, y_out)
        return [stepper, t_eval, index, y_out]
Example #11
def find_rw(rarr, vf, KzzpL):
    """finding rw from rarr and terminal velocity array.

    Args:
        rarr: particle radius array (cm)
        vf: terminal velocity (cm/s)
        KzzpL: Kzz/L in Ackerman and Marley 2001

    Returns:
        rw in Ackerman and Marley 2001
    """
    iscale = jnp.searchsorted(vf, KzzpL)
    rw = rarr[iscale]
    return rw
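A usage sketch with a made-up, monotonically increasing terminal-velocity profile:

import jax.numpy as jnp

rarr = jnp.logspace(-5, -2, 50)            # particle radii (cm)
vf = 10.0 ** jnp.linspace(-2.0, 2.0, 50)   # terminal velocities (cm/s), increasing
rw = find_rw(rarr, vf, KzzpL=1.0)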
Example #12
def lotka_volterra_simulate(initial_prey_pred: jnp.ndarray,
                            times: jnp.ndarray,
                            params: jnp.ndarray,
                            random_key: jnp.ndarray,
                            max_iter: int = 10000) -> jnp.ndarray:
    max_time = jnp.max(times)

    (simulated_pred_prey, simulated_times), _, _ = _while_loop_stacked(
        lambda carry, extra: carry[1] < max_time, lotka_volterra_single_step,
        ((initial_prey_pred, times[0]), (params, random_key)), max_iter)

    simulated_times = jnp.where(simulated_times == 0, jnp.inf, simulated_times)
    return simulated_pred_prey[jnp.searchsorted(simulated_times, times[1:]) - 1]
Example #13
def get_rw(vfs, Kzz, L, rarr):
    """compute rw in AM01 implicitly defined by (11)

    Args:
       vfs: terminal velocity (cm/s)
       Kzz: diffusion coefficient (cm2/s)
       L: typical convection scale (cm)
       rarr: condensate scale array

    Returns:
       rw: rw (cm) in AM01, i.e. the condensate size that balances upward transport and sedimentation
    """
    iscale = jnp.searchsorted(vfs, Kzz / L)
    rw = rarr[iscale]
    return rw
Example #14
    def get_conditional_mean_matrices(self, x, t):
        ar, cr, ac, bc, cc, dc = self.get_coefficients()

        inds = np.searchsorted(x, t)
        _, U_star, V_star, _ = self.get_celerite_matrices(t, t)

        c = np.concatenate((cr, cc, cc))

        dx = t - x[np.minimum(inds, x.size - 1)]
        U_star *= np.exp(-c[None, :] * dx[:, None])

        dx = x[np.maximum(inds - 1, 0)] - t
        V_star *= np.exp(-c[None, :] * dx[:, None])

        return U_star, V_star, inds
Example #15
def temporal_conditional_infinite_horizon(X, X_test, mean, cov, gain, kernel):
    """
    predict from time X to time X_test given the state mean and covariance at X
    """
    Pinf = kernel.stationary_covariance()[None, ...]
    minf = np.zeros([1, Pinf.shape[1], 1])
    mean_aug = np.concatenate([minf, mean, minf])

    # figure out which two training states each test point is located between
    ind_test = np.searchsorted(X.reshape(-1, ), X_test.reshape(-1, )) - 1

    # project from training states to test locations
    test_mean = predict_from_state_infinite_horizon(X_test, ind_test, X,
                                                    mean_aug, kernel)

    return test_mean, np.tile(cov[0], [test_mean.shape[0], 1, 1])
Example #16
def temporal_conditional(X, X_test, mean, cov, gain, kernel):
    """
    predict from time X to time X_test given the state mean and covariance at X
    """
    Pinf = kernel.stationary_covariance()[None, ...]
    minf = np.zeros([1, Pinf.shape[1], 1])
    mean_aug = np.concatenate([minf, mean, minf])
    cov_aug = np.concatenate([Pinf, cov, Pinf])
    gain = np.concatenate([np.zeros_like(gain[:1]), gain])

    # figure out which two training states each test point is located between
    ind_test = np.searchsorted(X.reshape(-1, ), X_test.reshape(-1, )) - 1

    # project from training states to test locations
    test_mean, test_cov = predict_from_state(X_test, ind_test, X, mean_aug,
                                             cov_aug, gain, kernel)

    return test_mean, test_cov
Example #17
def sample_pdf(bins, weights, num_importance, perturbation, rng):
    """Hierarchical sampler.
    Sample `num_importance` positions from `bins` with the distribution defined by `weights`.
    Args:
        bins: (num_rays, num_samples - 1) bins to sample from
        weights: (num_rays, num_samples - 2) weights assigned to each sampled color for the coarse model
        num_importance: the number of samples to draw from the distribution
        perturbation: whether to apply jitter on each ray or not
        rng: random key
    Returns:
        samples: (num_rays, num_importance) the sampled positions along each ray
    """
    # get pdf
    weights = jnp.clip(weights, 1e-5)  # prevent NaNs
    pdf = weights / jnp.sum(weights, axis=-1, keepdims=True)
    cdf = jnp.cumsum(pdf, axis=-1)
    cdf = jnp.concatenate([jnp.zeros_like(cdf[..., :1]), cdf], axis=-1)

    # take uniform samples
    samples_shape = [*cdf.shape[:-1], num_importance]
    if perturbation:
        uni_samples = random.uniform(rng, shape=samples_shape)
    else:
        uni_samples = jnp.linspace(0.0, 1.0, num_importance)
        uni_samples = jnp.broadcast_to(uni_samples, samples_shape)

    # invert CDF
    idx = jax.vmap(lambda x, y: jnp.searchsorted(x, y, side="right"))(
        cdf, uni_samples)

    below = jnp.maximum(0, idx - 1)
    above = jnp.minimum(cdf.shape[-1] - 1, idx)
    inds_g = jnp.stack([below, above], axis=-1)

    cdf_g = jnp.take_along_axis(cdf[..., None], inds_g, axis=1)
    bins_g = jnp.take_along_axis(bins[..., None], inds_g, axis=1)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = lax.select(denom < 1e-5, jnp.ones_like(denom), denom)
    t = (uni_samples - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples
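A usage sketch under the shapes stated in the docstring, with made-up data; it assumes `jax.random` and `jax.lax` are imported as `random` and `lax`, as the snippet implies:

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
bins = jnp.broadcast_to(jnp.linspace(0.0, 1.0, 8), (4, 8))  # (num_rays, num_samples - 1)
weights = jax.random.uniform(rng, (4, 7))                   # (num_rays, num_samples - 2)
samples = sample_pdf(bins, weights, num_importance=16, perturbation=True, rng=rng)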
Example #18
def rational_quadratic_parameters(x: jnp.ndarray, xk: jnp.ndarray,
                                  yk: jnp.ndarray,
                                  delta: jnp.ndarray) -> tuple:
    """Compute necessary intermediate parameters used for computing the spline
    flow. For details on the rational quadratic transformation, consult [1].

    [1] https://arxiv.org/pdf/1906.04032.pdf

    """
    idxn = jnp.searchsorted(xk, x)
    idx = idxn - 1
    xi = (x - xk[idx]) / (xk[idxn] - xk[idx])
    ym = yk[idxn] - yk[idx]
    sk = ym / (xk[idxn] - xk[idx])
    dk = delta[idx]
    dkp = delta[idxn]
    dp = dkp + dk
    xib = xi * (1.0 - xi)
    xisq = jnp.square(xi)
    return (idx, xi, xib, xisq, dk, dkp, dp, sk, ym)
Example #19
def stratified(weights, key):
    """
    Stratified resampling method

    Parameters
    ----------
    weights: array_like
        Weights to resample from
    key: PRNGKey
        The random key used

    Returns
    -------
    idx: array_like
        The indices of the resampled particles
    """
    n_samples = weights.shape[0]
    cumsum = jnp.cumsum(weights)
    u = uniform(key, (n_samples,))
    aux = (u + jnp.arange(n_samples)) / n_samples
    return jnp.searchsorted(cumsum, aux)
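A usage sketch with made-up normalized weights; it assumes `uniform` is jax.random.uniform, as the function body implies:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
weights = jnp.array([0.1, 0.2, 0.3, 0.4])
idx = stratified(weights, key)   # (4,) resampled particle indices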
Example #20
File: utils.py Project: fehiepsi/jaxns
def resample(key, samples, log_weights, S=None):
    """
    resample the samples with weights which are interpreted as log_probabilities.
    Args:
        samples:
        weights:

    Returns: S samples of equal weight

    """
    if S is None:
        # ESS = (sum w)^2 / sum w^2
        S = int(jnp.exp(2. * logsumexp(log_weights) - logsumexp(2. * log_weights)))

    # use cumulative_logsumexp because some log_weights could be really small
    log_p_cuml = cumulative_logsumexp(log_weights)
    p_cuml = jnp.exp(log_p_cuml)
    r = p_cuml[-1] * (1 - random.uniform(key, (S,)))
    idx = jnp.searchsorted(p_cuml, r)
    return dict_multimap(lambda s: s[idx, ...], samples)
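A usage sketch with made-up data; it assumes the jaxns helpers used above (logsumexp, cumulative_logsumexp, dict_multimap) are in scope:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
samples = {'x': jnp.arange(10.0)}        # mapping of sample arrays
log_weights = -0.5 * jnp.arange(10.0)    # unnormalized log-weights
equal = resample(key, samples, log_weights, S=5)  # 5 equally weighted samples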
Example #21
def set_z_stats(t, z):
    ind = (np.searchsorted(z.reshape(-1, ), t[:, :1].reshape(-1, )) - 1)
    num_neighbours = np.array(
        [np.sum(ind == m) for m in range(z.shape[0] - 1)])
    return ind, num_neighbours
Example #22
 def get(self, step):
     idx = jnp.searchsorted(self.milestones, step, side='right')
     schedule = self.schedules[idx]
     base_idx = self.milestones[idx - 1] if idx >= 1 else 0
     return schedule.get(step - base_idx)
Example #23
 def phase(self, t):
     return jnp.searchsorted(self.xp, t % self.period, side="right")
Example #24
def searchsorted(a, v, side='left', sorter=None):
  if isinstance(a, JaxArray): a = a.value
  if isinstance(v, JaxArray): v = v.value
  return JaxArray(jnp.searchsorted(a, v, side=side, sorter=sorter))
Example #25
def visualize_rays(t_vals,
                   weights,
                   rgbs,
                   t_range,
                   accumulate=False,
                   renormalize=False,
                   resolution=512,
                   oversample=1024,
                   bg_color=0.8):
    """Visualize a bundle of rays."""
    t_vis = jnp.linspace(*t_range, oversample * resolution)
    vis_rgb, vis_alpha = [], []
    for ts, ws, rs in zip(t_vals, weights, rgbs):
        vis_rs, vis_ws = [], []
        for t, w, r in zip(ts, ws, rs):
            if accumulate:
                # Produce the accumulated color and weight at each point along the ray.
                w_csum = jnp.cumsum(w, axis=0)
                rw_csum = jnp.cumsum((r * w[:, None]), axis=0)
                eps = jnp.finfo(jnp.float32).eps
                r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum
            idx = jnp.searchsorted(t, t_vis) - 1
            bounds = 0, len(t) - 2
            mask = (idx >= bounds[0]) & (idx <= bounds[1])
            r_mat = jnp.where(mask[:, None], r[jnp.clip(idx, *bounds), :], 0)
            w_mat = jnp.where(mask, w[jnp.clip(idx, *bounds)], 0)
            # Grab the highest-weighted value in each oversampled span.
            r_mat = r_mat.reshape(resolution, oversample, -1)
            w_mat = w_mat.reshape(resolution, oversample)
            mask = w_mat == w_mat.max(axis=1, keepdims=True)
            r_ray = (mask[Ellipsis, None] * r_mat).sum(axis=1) / jnp.maximum(
                1, mask.sum(axis=1))[:, None]
            w_ray = (mask * w_mat).sum(axis=1) / jnp.maximum(
                1, mask.sum(axis=1))
            vis_rs.append(r_ray)
            vis_ws.append(w_ray)
        vis_rgb.append(jnp.stack(vis_rs))
        vis_alpha.append(jnp.stack(vis_ws))
    vis_rgb = jnp.stack(vis_rgb, axis=1)
    vis_alpha = jnp.stack(vis_alpha, axis=1)

    if renormalize:
        # Scale the alphas so that the largest value is 1, for visualization.
        vis_alpha /= jnp.maximum(
            jnp.finfo(jnp.float32).eps, jnp.max(vis_alpha))

    if resolution > vis_rgb.shape[0]:
        rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1)
        stride = rep * vis_rgb.shape[1]

        vis_rgb = jnp.tile(
            vis_rgb, (1, 1, rep, 1)).reshape((-1, ) + vis_rgb.shape[2:])
        vis_alpha = jnp.tile(
            vis_alpha, (1, 1, rep)).reshape((-1, ) + vis_alpha.shape[2:])

        # Add a strip of background pixels after each set of levels of rays.
        vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:])
        vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:])
        vis_rgb = jnp.concatenate(
            [vis_rgb, jnp.zeros_like(vis_rgb[:, :1])],
            axis=1).reshape((-1, ) + vis_rgb.shape[2:])
        vis_alpha = jnp.concatenate(
            [vis_alpha, jnp.zeros_like(vis_alpha[:, :1])],
            axis=1).reshape((-1, ) + vis_alpha.shape[2:])

    # Matte the RGB image over the background.
    vis = vis_rgb * vis_alpha[Ellipsis, None] + (bg_color *
                                                 (1 - vis_alpha))[Ellipsis,
                                                                  None]

    # Remove the final row of background pixels.
    vis = vis[:-1]
    vis_alpha = vis_alpha[:-1]
    return vis, vis_alpha
Example #26
File: utils.py Project: fehiepsi/jaxns
def safe_gaussian_kde(samples, weights):
    try:
        return gaussian_kde(samples, weights=weights, bw_method='silverman')
    except Exception:
        hist, bin_edges = jnp.histogram(samples, weights=weights, bins='auto')
        return lambda x: hist[jnp.searchsorted(bin_edges, x)]