示例#1
0
def cosine_decay_scheduler(step, steps_per_cycle, t_mul=1, m_mul=1., alpha=0.):
  """Learning-rate multiplier following cosine decay with warm restarts.

  Each cycle the multiplier decays from its peak to ``alpha`` along a half
  cosine, then restarts. With ``t_mul != 1`` the i-th cycle is ``t_mul**i``
  times as long as the first; the peak of the i-th cycle is scaled by
  ``m_mul**i``.

  Args:
    step: int; Current step.
    steps_per_cycle: int; Length (in steps) of the first cycle.
    t_mul: int; Growth factor of the cycle length between restarts.
    m_mul: float; Scale factor applied to the peak value at each restart.
    alpha: float; Floor of the multiplier, as a fraction of the peak.

  Returns:
    Scaling factor to apply to the learning rate at ``step``.
  """
  frac = step / float(steps_per_cycle)
  if t_mul != 1.0:
    # Geometric cycle lengths: solve for the index of the current cycle and
    # the (normalized) position inside it.
    cycle = jnp.floor(
        jnp.log(1.0 - frac * (1.0 - t_mul)) / jnp.log(t_mul))
    elapsed = (1.0 - t_mul**cycle) / (1.0 - t_mul)
    frac = (frac - elapsed) / t_mul**cycle
  else:
    cycle = jnp.floor(frac)
    frac = frac - cycle
  amplitude = m_mul**cycle
  decayed = 0.5 * amplitude * (1.0 + jnp.cos(jnp.pi * (frac % 1.0)))
  decayed = jnp.maximum(0.0, decayed)
  return alpha + (1 - alpha) * decayed
示例#2
0
def delta_r(ri, rj, box=None):
    """Displacement ``ri - rj``, optionally wrapped to the nearest periodic image.

    Args:
        ri, rj: coordinate arrays whose difference has a trailing axis of
            size 3 (e.g. (N, N, 3) pairwise differences or (B, 3)).
        box: optional 3x3 box matrix. For d = 2, 1, 0 the vector ``box[d]`` is
            subtracted a whole number of times, chosen from component d of the
            displacement divided by ``box[d][d]`` and rounded.

    Returns:
        The displacement array (wrapped when ``box`` is given).
    """
    diff = ri - rj
    if box is None:
        return diff
    # z, then y, then x: each subtraction of a full box vector can change the
    # lower components, so the order matters.
    for axis in (2, 1, 0):
        n_images = np.floor(
            np.expand_dims(diff[..., axis], axis=-1) / box[axis][axis] + 0.5)
        diff -= box[axis] * n_images
    return diff
示例#3
0
    def interpolate_bilinear(  # type: ignore
        im: np.ndarray, rows: np.ndarray, cols: np.ndarray
    ) -> np.ndarray:
        """Sample ``im`` bilinearly at fractional (row, col) positions.

        Adapted from http://stackoverflow.com/a/12729229.

        Args:
            im: image whose last three axes are (rows, cols, channels); any
                leading axes are carried through by ``...`` indexing.
            rows: fractional row coordinates.
            cols: fractional column coordinates (broadcast with ``rows``).

        Returns:
            Interpolated values with the channel axis last. Neighbour indices
            are clamped to the image, so positions outside it reuse edge
            pixels.
        """
        c0 = np.floor(cols).astype(int)
        c1 = c0 + 1
        r0 = np.floor(rows).astype(int)
        r1 = r0 + 1

        nrows, ncols = im.shape[-3:-1]

        # Clamp neighbour indices into the valid pixel range.
        r0_in, r1_in = np.clip(r0, 0, nrows - 1), np.clip(r1, 0, nrows - 1)
        c0_in, c1_in = np.clip(c0, 0, ncols - 1), np.clip(c1, 0, ncols - 1)

        v00 = im[..., r0_in, c0_in, :]
        v10 = im[..., r1_in, c0_in, :]
        v01 = im[..., r0_in, c1_in, :]
        v11 = im[..., r1_in, c1_in, :]

        # Weights come from the unclamped neighbours; add a channel axis.
        w00 = np.expand_dims((c1 - cols) * (r1 - rows), -1)
        w10 = np.expand_dims((c1 - cols) * (rows - r0), -1)
        w01 = np.expand_dims((cols - c0) * (r1 - rows), -1)
        w11 = np.expand_dims((cols - c0) * (rows - r0), -1)

        return w00 * v00 + w10 * v10 + w01 * v01 + w11 * v11
示例#4
0
    def __init__(self, lower, upper):
        """Store the integer-interval support, shapes and floored bounds.

        Args:
            lower: lower bound(s) of the interval.
            upper: upper bound(s) of the interval.
        """
        # The support is built from the bounds as given; the stored
        # ``lower``/``upper`` attributes are floored copies.
        self.support = constraints.integer_interval(lower, upper)

        self.event_shape = ()
        lower_shape = np.shape(lower)
        upper_shape = np.shape(upper)
        self.batch_shape = broadcast_batch_shape(lower_shape, upper_shape)
        self.lower, self.upper = np.floor(lower), np.floor(upper)
示例#5
0
    def __init__(self, lower, upper):
        """Store the integer-interval support and broadcast, floored bounds.

        Args:
            lower: lower bound(s); floored and broadcast to the batch shape.
            upper: upper bound(s); floored and broadcast to the batch shape.
        """
        # Support is created from the bounds as passed in (before promotion).
        self.support = constraints.integer_interval(lower, upper)
        self.event_shape = ()

        lower, upper = promote_shapes(lower, upper)
        batch_shape = lax.broadcast_shapes(jnp.shape(lower), jnp.shape(upper))
        self.batch_shape = batch_shape
        floored = (jnp.floor(lower), jnp.floor(upper))
        self.lower, self.upper = (
            jnp.broadcast_to(bound, batch_shape) for bound in floored)
示例#6
0
def bilinear_resample(x, warp):
    """Bilinearly resample images ``x`` at the pixel coordinates in ``warp``.

    Args:
        x: images of shape ``batch_shape + [height, width, num_feats]``.
        warp: sample coordinates of shape ``batch_shape + [h_out, w_out, 2]``;
            ``warp[..., 0]`` is the column (x) and ``warp[..., 1]`` the row (y)
            position in pixels. Neighbour indices are clamped to the image, so
            out-of-range coordinates reuse edge pixels.

    Returns:
        Array of shape ``batch_shape + [h_out * w_out, num_feats]``.
    """
    batch_shape = list(x.shape[:-3])
    height, width = list(x.shape[-3:-1])
    num_feats = x.shape[-1]
    max_x = width - 1
    max_y = height - 1
    # number of sample points per batch element
    idx_size = _reduce(_mul, warp.shape[-3:-1], 1)
    batch_shape_flat = _reduce(_mul, batch_shape, 1)
    # B: row offset of each batch element inside ``data_flat``. Every element
    # contributes height * width rows, so the stride is the *input* image
    # size. Using ``idx_size`` (the number of samples) here read from the
    # wrong image whenever the output grid differed from height x width.
    batch_offsets = _jnp.arange(batch_shape_flat) * (height * width)
    # B x (num samples) -> (B * num samples): one offset per sample point
    base = _jnp.reshape(
        _jnp.tile(_jnp.expand_dims(batch_offsets, 1), [1, idx_size]), [-1])
    # (B*H*W) x D
    data_flat = _jnp.reshape(x, [batch_shape_flat * height * width, -1])
    # (B * num samples) x 2
    warp_flat = _jnp.reshape(warp, [-1, 2])
    warp_floor = _jnp.floor(warp_flat)
    warp_floored = warp_floor.astype(_jnp.int32)
    bilinear_weights = warp_flat - warp_floor
    # integer neighbours, clamped to the image
    x0 = warp_floored[:, 0]
    x1 = x0 + 1
    y0 = warp_floored[:, 1]
    y1 = y0 + 1
    x0 = _jnp.clip(x0, 0, max_x)
    x1 = _jnp.clip(x1, 0, max_x)
    y0 = _jnp.clip(y0, 0, max_y)
    y1 = _jnp.clip(y1, 0, max_y)
    base_y0 = base + y0 * width
    base_y1 = base + y1 * width
    idx_a = base_y0 + x0
    idx_b = base_y1 + x0
    idx_c = base_y0 + x1
    idx_d = base_y1 + x1
    # (B * num samples) x D
    Ia = _jnp.take(data_flat, idx_a, axis=0)
    Ib = _jnp.take(data_flat, idx_b, axis=0)
    Ic = _jnp.take(data_flat, idx_c, axis=0)
    Id = _jnp.take(data_flat, idx_d, axis=0)
    # fractional offsets along x and y
    xw = bilinear_weights[:, 0]
    yw = bilinear_weights[:, 1]
    # (B * num samples) x 1
    wa = _jnp.expand_dims((1 - xw) * (1 - yw), 1)
    wb = _jnp.expand_dims((1 - xw) * yw, 1)
    wc = _jnp.expand_dims(xw * (1 - yw), 1)
    wd = _jnp.expand_dims(xw * yw, 1)
    # (B * num samples) x D
    resampled_flat = wa * Ia + wb * Ib + wc * Ic + wd * Id
    # batch_shape + [num samples, D]
    return _jnp.reshape(resampled_flat, batch_shape + [-1, num_feats])
示例#7
0
文件: _fasd.py 项目: lschmors/RFEst
def fourierfreq(ncoeff, delta, CONDTHRESH=1e8):
    """Integer frequency vector for a truncated Fourier-domain basis.

    A cutoff ``maxfreq`` is derived from the number of coefficients, the
    length-scale ``delta`` (larger ``delta`` gives a lower cutoff) and the
    conditioning threshold ``CONDTHRESH``.

    Args:
        ncoeff: number of Fourier coefficients (grid length).
        delta: length-scale parameter.
        CONDTHRESH: conditioning threshold used to choose the cutoff.

    Returns:
        1-D array ``[0, 1, ..., maxfreq, -maxfreq, ..., -1]`` when
        ``maxfreq < ncoeff / 2``. Otherwise the full set of ``ncoeff``
        frequencies ``[0, ..., ceil((ncoeff + 1) / 2) - 1,
        -floor((ncoeff - 1) / 2), ..., -1]``.
    """
    maxfreq = np.floor(ncoeff / (np.pi * delta) *
                       np.sqrt(0.5 * np.log(CONDTHRESH))).astype(int)
    if maxfreq < ncoeff / 2:
        # Symmetric band: non-negative frequencies 0..maxfreq and negative
        # frequencies -maxfreq..-1. The previous lower bound ``-maxfreq + 1``
        # silently dropped the -maxfreq term (an off-by-one when translating
        # an inclusive range), unlike the symmetric full-band branch below.
        wvec = np.hstack([np.arange(maxfreq + 1), np.arange(-maxfreq, 0)])
    else:
        ncos = np.ceil((ncoeff + 1) / 2)
        nsin = np.floor((ncoeff - 1) / 2)
        wvec = np.hstack([np.arange(ncos), np.arange(-nsin, 0)])

    return wvec
示例#8
0
def recenter(conf, b):
    """Wrap 2-D coordinates back into the periodic square box of side ``b``.

    Each row is shifted by whole multiples of the box vectors (y first, then
    x), so for ``b > 0`` the result lies in [0, b) x [0, b).

    Args:
        conf: (N, 2) array of coordinates.
        b: box side length.

    Returns:
        Array shaped like ``conf`` with the wrapped coordinates.
    """
    box = jnp.array([
        [b, 0.],
        [0., b],
    ])

    shift = jnp.zeros_like(conf)
    # Number of box heights to remove from each y coordinate.
    n_y = jnp.floor(conf[:, 1] / box[1][1])
    shift = shift + box[1][None, :] * n_y[:, None]
    # Number of box widths to remove from each (already y-shifted) x coordinate.
    n_x = jnp.floor((conf[:, 0] - shift[:, 0]) / box[0][0])
    shift = shift + box[0][None, :] * n_x[:, None]

    return conf - shift
示例#9
0
def recenter(conf, b):
    """Wrap each 2-D point of ``conf`` into the periodic square box of side ``b``.

    Points are processed one at a time: whole box heights are removed from y,
    then whole box widths from x. The wrapped points are collected with
    ``np.array``.

    Args:
        conf: iterable of 2-D points.
        b: box side length.

    Returns:
        ``np.array`` of the wrapped points.
    """
    box = jnp.array([[b, 0.0], [0.0, b]])

    def wrap(atom):
        shift = jnp.array([0.0, 0.0])
        shift += box[1] * jnp.floor(atom[1] / box[1][1])
        shift += box[0] * jnp.floor(
            (atom[0] - shift[0]) / box[0][0])
        return atom - shift

    return np.array([wrap(atom) for atom in conf])
示例#10
0
def rescale_coordinates(conf, indices, box, scales):
    """Translate whole molecules so their centroids are wrapped into the box.

    Each molecule's centroid is shifted by whole box vectors (z, then y, then
    x). The same shift is applied to every atom of that molecule, so molecules
    are never split across the periodic boundary.

    Args:
        conf: (N, 3) atom coordinates.
        indices: length-N integer molecule id of each atom.
        box: 3x3 box matrix; ``box[d]`` is subtracted as a vector and
            ``box[d][d]`` is the corresponding side length.
        scales: currently unused. NOTE(review): despite the function name, no
            rescaling is applied here; confirm what callers expect.

    Returns:
        Array shaped like ``conf`` with the shifted coordinates.
    """
    mol_sizes = np.expand_dims(onp.bincount(indices), axis=1)
    # Per-molecule centroid. The original referenced an undefined ``coords``
    # here; the coordinates are the ``conf`` argument.
    mol_centers = jax.ops.segment_sum(conf, indices) / mol_sizes

    new_centers = mol_centers - box[2] * np.floor(
        np.expand_dims(mol_centers[..., 2], axis=-1) / box[2][2])
    new_centers -= box[1] * np.floor(
        np.expand_dims(new_centers[..., 1], axis=-1) / box[1][1])
    new_centers -= box[0] * np.floor(
        np.expand_dims(new_centers[..., 0], axis=-1) / box[0][0])

    offset = new_centers - mol_centers

    return conf + offset[indices]
示例#11
0
def interp_regular_1d(x: np.ndarray, xmin: float, xmax: float,
                      yp: np.ndarray) -> np.ndarray:
    """Piecewise-linear interpolation on a regular grid.

    The samples ``yp`` are taken to lie on ``len(yp)`` evenly spaced points
    from ``xmin`` to ``xmax`` (inclusive). Queries outside [xmin, xmax] are
    clamped to the grid, i.e. they return ``yp[0]`` below and ``yp[-1]``
    above the range.

    Args:
        x: The x-coordinates at which to evaluate the interpolated values.
        xmin: The lower bound of the regular input x-coordinate grid.
        xmax: The upper bound of the regular input x-coordinate grid.
        yp: The y coordinates of the data points.

    Returns:
        y: The interpolated values, same shape as x.
    """
    last = len(yp) - 1
    # Continuous grid position of each query, clamped to [0, last].
    pos = np.clip((x - xmin) / (xmax - xmin) * last, 0, last)
    # Bracketing knots; at pos == last this picks (last - 1, last) so that
    # both knots stay in range.
    hi = np.minimum(np.floor(pos) + 1, last)
    lo = np.maximum(hi - 1, 0)
    frac = pos - lo
    y_lo = yp[lo.astype(np.int32)]
    y_hi = yp[hi.astype(np.int32)]
    return frac * y_hi + (1 - frac) * y_lo
示例#12
0
def hsv_to_rgb(h, s, v):
    """Convert H, S, V channels to an ``(r, g, b)`` tuple.

    Mirrors the TF kernel
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
    Correct results are only guaranteed for inputs in [0, 1]. According to the
    original notes, outside that range the output matches the TF
    implementation; hue is wrapped with ``h % 1``.

    Args:
        h: A float tensor of arbitrary shape for the hue (0-1 values).
        s: A float tensor of the same shape for the saturation (0-1 values).
        v: A float tensor of the same shape for the value channel (0-1 values).

    Returns:
        An (r, g, b) tuple, each with the same dimension as the inputs.
    """
    chroma = s * v
    offset = v - chroma
    hue6 = (h % 1.) * 6.
    ramp = chroma * (1 - jnp.abs(hue6 % 2. - 1))
    sector = jnp.floor(hue6).astype(jnp.int32)

    def channel(full_sectors, ramp_sectors):
        # Full chroma in ``full_sectors``, the ramp value in ``ramp_sectors``,
        # zero elsewhere; then add the common offset.
        a, b = full_sectors
        c, d = ramp_sectors
        in_full = (sector == a) | (sector == b)
        in_ramp = (sector == c) | (sector == d)
        return jnp.where(in_full, chroma, jnp.where(in_ramp, ramp, 0)) + offset

    red = channel((0, 5), (1, 4))
    green = channel((1, 2), (0, 3))
    blue = channel((3, 4), (2, 5))
    return red, green, blue
示例#13
0
文件: smap.py 项目: VardaHagh/jax-md
def _cell_dimensions(spatial_dimension, box_size, minimum_cell_size):
  """Compute the number of cells-per-side and total number of cells in a box."""
  if isinstance(box_size, int) or isinstance(box_size, float):
    box_size = f32(box_size)

  if (isinstance(box_size, np.ndarray) and
      (box_size.dtype == np.int32 or box_size.dtype == np.int64)):
    box_size = f32(box_size)

  cells_per_side = np.floor(box_size / minimum_cell_size)
  cell_size = box_size / cells_per_side
  cells_per_side = np.array(cells_per_side, dtype=np.int64)

  if isinstance(box_size, np.ndarray):
    flat_cells_per_side = np.reshape(cells_per_side, (-1,))
    for cells in flat_cells_per_side:
      if cells < 3:
        raise ValueError(
            ('Box must be at least 3x the size of the grid spacing in each '
             'dimension.'))

    cell_count = reduce(mul, flat_cells_per_side, 1)
  else:
    cell_count = cells_per_side ** spatial_dimension

  return box_size, cell_size, cells_per_side, int(cell_count)
示例#14
0
def _cell_dimensions(spatial_dimension, box_size, minimum_cell_size):
    """Compute the number of cells-per-side and total number of cells in a box."""
    if isinstance(box_size, int) or isinstance(box_size, float):
        box_size = f32(box_size)

    # NOTE(schsam): Should we auto-cast based on box_size? I can't imagine a case
    # in which the box_size would not be accurately represented by an f32.
    if (isinstance(box_size, np.ndarray)
            and (box_size.dtype == np.int32 or box_size.dtype == np.int64)):
        box_size = f32(box_size)

    cells_per_side = np.floor(box_size / minimum_cell_size)
    cell_size = box_size / cells_per_side
    cells_per_side = np.array(cells_per_side, dtype=np.int64)

    if isinstance(box_size, np.ndarray):
        if box_size.ndim == 1 or box_size.ndim == 2:
            assert box_size.size == spatial_dimension
            flat_cells_per_side = np.reshape(cells_per_side, (-1, ))
            for cells in flat_cells_per_side:
                if cells < 3:
                    raise ValueError((
                        'Box must be at least 3x the size of the grid spacing in each '
                        'dimension.'))
            cell_count = reduce(mul, flat_cells_per_side, 1)
        elif box_size.ndim == 0:
            cell_count = cells_per_side**spatial_dimension
        else:
            raise ValueError('Box must either be a scalar or a vector.')
    else:
        cell_count = cells_per_side**spatial_dimension

    return box_size, cell_size, cells_per_side, int(cell_count)
示例#15
0
    def two_normalize(m):
        """Rescale ``m`` by powers of two to bring each matrix norm into [1, 2).

        The norm is taken over axes (2, 3) of ``m``, and each slice is divided
        by ``2**floor(log2(norm))``.

        Returns:
            ``(stable_m, exponent_sum)``: the rescaled array and the sum of the
            power-of-two exponents over axis 0.
        """
        exponent = np.floor(np.log2(np.linalg.norm(m, axis=(2, 3), keepdims=True)))
        return m / (2**exponent), np.sum(exponent, axis=0)
示例#16
0
def realfftbasis(nx):
    """
    Real sine/cosine basis for an ``nx``-point discrete Fourier transform.

    Non-negative frequencies contribute cosine rows and negative frequencies
    contribute sine rows. All rows are scaled by ``1 / sqrt(nx / 2)``.

    Ported from MatLab code:
    https://github.com/leaduncker/SimpleEvidenceOpt/blob/master/util/realfftbasis.m

    Returns:
        (B, wvec): ``B`` stacks the cosine rows followed by the sine rows, each
        of length ``nx``; ``wvec`` is the float frequency vector
        ``[0, ..., ceil((nx + 1) / 2) - 1, -floor((nx - 1) / 2), ..., -1]``.
    """
    import numpy as np

    n_cos = np.ceil((nx + 1) / 2)
    n_sin = np.floor((nx - 1) / 2)
    wvec = np.concatenate([np.arange(0., n_cos), np.arange(-n_sin, 0.)])

    positions = np.arange(nx)
    cos_rows = np.cos(np.outer(wvec[wvec >= 0] * 2 * np.pi / nx, positions))
    sin_rows = np.sin(np.outer(wvec[wvec < 0] * 2 * np.pi / nx, positions))

    basis = np.vstack([cos_rows, sin_rows]) / np.sqrt(nx * 0.5)
    return basis, wvec
示例#17
0
 def sample(self, key, sample_shape=()):
     """Draw samples by inversion: ``floor(log1p(-u) / log1p(-probs))``.

     ``u`` is uniform on [0, 1) in the result dtype of ``probs``. The output
     shape is ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     p = self.probs
     out_shape = sample_shape + self.batch_shape
     uniform = random.uniform(key, out_shape, jnp.result_type(p))
     return jnp.floor(jnp.log1p(-uniform) / jnp.log1p(-p))
示例#18
0
def test_flip_state_fock_infinite():
    """flip_state on a Fock space without a cutoff keeps values non-negative
    and only changes the selected site of each sample."""
    hi = Fock(N=2)
    rng = nk.jax.PRNGSeq(1)
    N_batches = 20

    states = hi.random_state(rng.next(), N_batches, dtype=jnp.int64)

    # One random site index per sample, uniform over [0, hi.size).
    ids = jnp.asarray(
        jnp.floor(hi.size *
                  jax.random.uniform(rng.next(), shape=(N_batches, ))),
        dtype=int,
    )

    new_states, _ = nk.hilbert.random.flip_state(
        hi, rng.next(), states, ids)

    assert new_states.shape == states.shape

    # The original only checked the *input* states. The output of flip_state
    # is what this test is about, so it must be non-negative as well.
    assert np.all(states >= 0)
    assert np.all(new_states >= 0)

    states_np = np.asarray(states)
    states_new_np = np.array(new_states)

    # Undo the flip at (row, ids[row]); everything else must already match.
    for (row, col) in enumerate(ids):
        states_new_np[row, col] = states_np[row, col]

    np.testing.assert_allclose(states_np, states_new_np)
示例#19
0
def uniform_stochastic_quantize(v: jnp.ndarray,
                                num_levels: int,
                                rng: PRNGKey,
                                v_min: Optional[float] = None,
                                v_max: Optional[float] = None) -> jnp.ndarray:
  """Uniform stochastic quantization (https://arxiv.org/pdf/1611.00429.pdf).

  Values are mapped to [0, 1] using ``v_min``/``v_max``. Each value is then
  rounded to one of its two neighbouring levels on a grid of ``num_levels``
  evenly spaced points, going up with probability equal to its relative
  distance from the lower level. The result is mapped back to the original
  range.

  Args:
    v: vector to be quantized.
    num_levels: Number of levels of quantization.
    rng: jax random key.
    v_min: minimum threshold for quantization. If None, sets it to jnp.amin(v).
    v_max: maximum threshold for quantization. If None, sets it to jnp.amax(v).

  Returns:
    Quantized array.
  """
  lo = jnp.amin(v) if v_min is None else v_min
  hi = jnp.amax(v) if v_max is None else v_max
  span = hi - lo
  # Normalize to [0, 1]; nan_to_num guards the degenerate hi == lo case.
  unit = jnp.clip(jnp.nan_to_num((v - lo) / span), 0., 1.)

  steps = num_levels - 1
  lower = jnp.floor(unit * steps) / steps
  upper = jnp.ceil(unit * steps) / steps
  # Probability of rounding up; 0/0 (value exactly on a level) becomes 0.
  prob_up = jnp.nan_to_num((unit - lower) / (upper - lower))
  draws = jax.random.uniform(key=rng, shape=unit.shape)
  rounded = jnp.where(draws > prob_up, lower, upper)
  return lo + rounded * span
示例#20
0
  def bcL(self, rng=None):
    """Create a random boundary condition for an L-shaped domain.

    Each of the four boundary segments is a random cubic polynomial with no
    constant term in ``x = sin(pi * self.bcmesh)``. The coefficients are drawn
    from a 16-dimensional standard normal; the constant-term slots (indices 3,
    7, 11, 15) are drawn but not used. The segments are written onto the edges
    of an L-shaped region of a square grid. (Original design note: sine
    functions are used so the boundary is periodic and has no
    discontinuities.)

    Args:
      rng: JAX PRNG key. Defaults to ``random.PRNGKey(1)`` so results are
        reproducible.

    Returns:
      onp array of shape ``(len(x), len(x))`` (assuming ``self.bcmesh`` is
      1-D) with the boundary values scaled by ``(n + 1)**2``. Entries off the
      boundary are zero.
    """
    if rng is None:
      rng = random.PRNGKey(1)
    # A former ``else`` branch that derived a key via ``random.randint`` was
    # unreachable (``rng`` is always set by this point) and has been removed.
    n = self.n
    x = onp.sin(self.bcmesh * np.pi)
    # Index of the inner corner of the L; a plain int so it can slice onp/jnp.
    n_y = int(np.floor((n + 1) / 2)) - 1
    coeffs = random.multivariate_normal(rng, np.zeros(16),
                                        np.diag(np.ones(16)))
    left = coeffs[0] * x**3 + coeffs[1] * x**2 + coeffs[2] * x
    right = coeffs[4] * x**3 + coeffs[5] * x**2 + coeffs[6] * x
    lower = coeffs[8] * x**3 + coeffs[9] * x**2 + coeffs[10] * x
    upper = coeffs[12] * x**3 + coeffs[13] * x**2 + coeffs[14] * x
    # Tuple repetition: (len(x), len(x)) for a 1-D x.
    shape = 2 * x.shape
    source = onp.zeros(shape)
    source[0, :] = upper
    source[n_y - 1, n_y - 1:] = lower[:n - n_y + 1]
    source[n_y - 1:, n_y - 1] = right[:n - n_y + 1]
    source[:, 0] = left
    # Reversed slices: this makes the correct order of boundary conditions.
    source[-1, :n_y - 1] = right[n:n - n_y:-1]
    source[:n_y - 1, -1] = lower[n:n - n_y:-1]
    return source * (n + 1)**2
示例#21
0
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: PRNGKey,
           sample: Optional[bool]=False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:
    """Dequantize (``sample == False``) or quantize (otherwise) ``inputs["x"]``.

    When ``sample == False``, noise is drawn from the conditional flow
    (conditioned on ``x``) and added to ``x``. Otherwise ``x`` is split into
    its integer part (returned) and fractional part (scored by the flow). In
    both cases ``log_det`` is decreased by the flow's ``log_det + log_pz``.

    Returns:
      Dict with ``"x"`` (the transformed value) and ``"log_det"``.
    """
    x = inputs["x"]
    # Result unused; the call itself is kept as in the original.
    x_shape = self.get_unbatched_shapes(sample)["x"]

    log_det = -jnp.zeros(self.batch_shape)
    flow = self.default_flow() if self.flow is None else self.flow

    # ``== False`` (not ``not sample``) is deliberate: None takes the else path.
    if sample == False:
      outputs = flow({"x": jnp.zeros(x.shape), "condition": x}, rng, sample=True)
      z = x + outputs["x"]
    else:
      z = jnp.floor(x).astype(jnp.int32)
      outputs = flow({"x": x - z, "condition": x}, rng, sample=False)

    log_det -= outputs["log_det"] + outputs["log_pz"]
    return {"x": z, "log_det": log_det}
示例#22
0
    def __call__(self, x, mask_props=None, is_training=True):
        """Run the MLP and return the output of a final 10-unit linear layer.

        Args:
            x: input batch; flattened with ``hk.Flatten`` first.
            mask_props: optional per-layer fractions. For layer ``i``, only the
                first ``floor(mask_props[i] * width)`` hidden units are kept
                (the rest are zeroed) before the activation.
            is_training: forwarded to ``hk.BatchNorm`` when batch norm is on.

        Activation names other than 'relu', 'sigmoid' and 'tanh' leave the
        hidden layers linear.
        """
        activations = {
            'relu': jax.nn.relu,
            'sigmoid': jax.nn.sigmoid,
            'tanh': jnp.tanh,
        }
        act = activations.get(self.activation, lambda t: t)

        h = hk.Flatten()(x)
        for layer_idx in range(self.nlayers):
            h = hk.Linear(self.nhid, with_bias=self.with_bias)(h)
            if self.batch_norm:
                bn = hk.BatchNorm(create_scale=False,
                                  create_offset=False,
                                  decay_rate=0.9)
                h = bn(h, is_training)

            if mask_props is not None:
                width = h.shape[1]
                keep = jnp.floor(mask_props[layer_idx] * width).astype(jnp.int32)
                h = jnp.where(jnp.arange(width) < keep, h, jnp.zeros(h.shape))

            h = act(h)

        return hk.Linear(10, with_bias=self.with_bias)(h)
示例#23
0
 def _binom_inv_body_fn(val):
     """Loop body: draw one geometric increment and add it to the accumulator.

     ``val`` is ``(i, key, geom_acc)``. Returns ``(i + 1, key', geom_acc +
     geom)`` where ``geom = floor(log1p(-u) / log1_p) + 1`` for a fresh
     uniform ``u``. ``log1_p`` comes from the enclosing scope (presumably
     ``log1p(-p)``; confirm at the definition site).
     """
     step, key, acc = val
     key, draw_key = random.split(key)
     u = random.uniform(draw_key)
     increment = np.floor(np.log1p(-u) / log1_p) + 1
     return step + 1, key, acc + increment
示例#24
0
def interpolate1d(x, values, tangents):
  r"""Cubic Hermite spline interpolation with linear extrapolation.

  Knots sit at the integer positions 0, 1, ..., len(values) - 1. Inside that
  range each query is interpolated from its two neighbouring knots. Outside
  it, the spline is extended linearly using the first/last knot's value and
  tangent. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline; "p" there
  is ``values`` here and "m" is ``tangents``.

  Args:
    x: query positions (any shape).
    values: 1-D array of knot values, same length as ``tangents``.
    tangents: 1-D array of knot derivatives, same length as ``values`` and
      the same type as ``x``.

  Returns:
    Interpolated/extrapolated values with the shape of ``x``.
  """
  assert len(values.shape) == 1
  assert len(tangents.shape) == 1
  assert values.shape[0] == tangents.shape[0]

  # Index of the knot at the start of each query's segment.
  lo = jnp.int32(jnp.floor(jnp.clip(x, 0., values.shape[0] - 2)))
  hi = lo + 1
  # Offset from the lower knot; < 0 or > 1 means extrapolation.
  t = x - lo

  t2 = t**2
  t3 = t * t2
  # Hermite basis polynomials.
  h01 = -2 * t3 + 3 * t2
  h00 = 1 - h01
  h11 = t3 - t2
  h10 = h11 - t2 + t

  interior = (
      jnp.take(values, lo) * h00 + jnp.take(values, hi) * h01 +
      jnp.take(tangents, lo) * h10 + jnp.take(tangents, hi) * h11)
  below = tangents[0] * t + values[0]
  above = tangents[-1] * (t - 1) + values[-1]

  return jnp.where(t < 0., below, jnp.where(t > 1., above, interior))
示例#25
0
 def schedule(count):
   """Exponentially decayed value at step ``count``.

   Reads ``transition_begin``, ``transition_steps``, ``staircase``,
   ``init_value`` and ``decay_rate`` from the enclosing scope. Returns
   ``init_value`` while ``count <= transition_begin``, and otherwise
   ``init_value * decay_rate ** ((count - transition_begin) / transition_steps)``,
   with the exponent floored when ``staircase`` is set.
   """
   # Build a new value rather than ``count -= ...``: for a mutable array
   # argument (e.g. NumPy), in-place subtraction would modify the caller's
   # array.
   count = count - transition_begin
   p = count / transition_steps
   if staircase:
     p = jnp.floor(p)
   return jnp.where(
       count <= 0, init_value, init_value * jnp.power(decay_rate, p))
示例#26
0
def getE(M, e):
    """JAX autograd compatible version of the Solver of Kepler's Equation for
    the "eccentric anomaly", E.

    The mean anomaly is reduced modulo 2*pi. A reduced value in (pi, 2*pi) is
    mirrored to ``2*pi - Mt`` before solving, and the solution is mirrored
    back the same way.

    Args:
       M : Mean anomaly
       e : Eccentricity

    Returns:
       Eccentric anomaly for M reduced modulo 2*pi
    """
    pi = jnp.pi
    Mt = M - (jnp.floor(M / (2. * pi)) * 2. * pi)
    # The mirror test must use the *reduced* anomaly. Testing the raw M (as
    # before) mis-handled inputs outside [0, 2*pi), e.g. negative M or
    # M >= 2*pi. For M in [0, 2*pi) the two tests agree.
    flip = Mt > pi
    Mt = jnp.where(flip, 2. * pi - Mt, Mt)
    # Maps -0.0 to +0.0 (and gives a constant branch at exactly zero).
    Mt = jnp.where(Mt == 0.0, 0.0, Mt)

    alpha = _alpha(e, Mt)
    d = _d(alpha, e)
    r = _r(alpha, d, Mt, e)
    q = _q(alpha, d, e, Mt)
    w = _w(r, q)
    E1 = _E1(d, r, w, q, Mt)
    f = _f01234(e, E1, Mt)
    d3 = _d3(E1, f)
    d4 = _d4(E1, f, d3)
    d5 = _d5(E1, f, d4)
    # Eq. 29
    E5 = E1 + d5
    # Undo the mirroring applied above.
    E5 = jnp.where(flip, 2. * pi - E5, E5)
    return E5
示例#27
0
 def sample(self, key, sample_shape=()):
     """Draw samples by inversion: ``floor(log1p(-u) / -softplus(logits))``.

     ``u`` is uniform on [0, 1) in the result dtype of ``logits``. The output
     shape is ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     lg = self.logits
     out_shape = sample_shape + self.batch_shape
     uniform = random.uniform(key, out_shape, jnp.result_type(lg))
     return jnp.floor(jnp.log1p(-uniform) / -softplus(lg))
示例#28
0
文件: binomial.py 项目: tblazina/mcx
 def logpdf(self, k):
     """Binomial log-probability of ``floor(k)`` successes in ``self.n`` trials.

     Computes ``k*log(p) + (n-k)*log(1-p) + log C(n, k)``, with the binomial
     coefficient expressed through ``gammaln``.
     """
     k = jnp.floor(k)
     n, p = self.n, self.p
     log_binom = gammaln(n + 1) - (gammaln(k + 1) + gammaln(n - k + 1))
     return xlogy(k, p) + xlog1py(n - k, -p) + log_binom
示例#29
0
def test_flip_state_discrete(hi: DiscreteHilbert):
    """flip_state yields only allowed local values and changes only the
    selected site of each sample."""
    rng = nk.jax.PRNGSeq(1)
    N_batches = 20

    states = hi.random_state(rng.next(), N_batches)

    # One random site index per sample, uniform over [0, hi.size).
    ids = jnp.asarray(
        jnp.floor(hi.size *
                  jax.random.uniform(rng.next(), shape=(N_batches, ))),
        dtype=int,
    )

    new_states, _ = nk.hilbert.random.flip_state(
        hi, rng.next(), states, ids)

    assert new_states.shape == states.shape

    # Validate both the sampled inputs and flip_state's outputs; previously
    # only the inputs were checked, which does not exercise flip_state.
    for batch in (states, new_states):
        for state in batch:
            assert all(val in hi.states_at_index(i) for i, val in enumerate(state))

    states_np = np.asarray(states)
    states_new_np = np.array(new_states)

    # Undo the flip at (row, ids[row]); everything else must already match.
    for (row, col) in enumerate(ids):
        states_new_np[row, col] = states_np[row, col]

    np.testing.assert_allclose(states_np, states_new_np)
示例#30
0
def test_flip_state(hi):
    """flip_state on a discrete Hilbert space yields local states and changes
    only the selected site of each sample.

    NOTE(review): for a non-DiscreteHilbert ``hi`` nothing is checked and the
    test passes trivially; confirm that is intended.
    """
    rng = nk.jax.PRNGSeq(1)
    N_batches = 20
    if isinstance(hi, DiscreteHilbert):
        local_states = hi.local_states
        states = hi.random_state(rng.next(), N_batches)

        # One random site index per sample, uniform over [0, hi.size).
        ids = jnp.asarray(
            jnp.floor(hi.size *
                      jax.random.uniform(rng.next(), shape=(N_batches, ))),
            dtype=int,
        )

        new_states, _ = nk.hilbert.random.flip_state(
            hi, rng.next(), states, ids)

        assert new_states.shape == states.shape

        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
        assert np.all(np.isin(new_states.reshape(-1), local_states))

        states_np = np.asarray(states)
        states_new_np = np.array(new_states)

        # Undo the flip at (row, ids[row]); everything else must already match.
        for (row, col) in enumerate(ids):
            states_new_np[row, col] = states_np[row, col]

        np.testing.assert_allclose(states_np, states_new_np)