Example #1
def stimulate(t, X, stimuli):
    stimulated = np.zeros_like(X)
    for stimulus in stimuli:
        active = np.greater_equal(t, stimulus["start"])
        active &= (np.mod(stimulus["start"] - t + 1, stimulus["period"]) < stimulus["duration"])
        stimulated = np.where(stimulus["field"] * (active), stimulus["field"], stimulated)
    return np.where(stimulated != 0, stimulated, X)
Example #2
def update_periodically(steps: jnp.ndarray, target_update_period: int,
                        params: types.NestedArray,
                        target_params: types.NestedArray) -> types.NestedArray:
    """Checks whether to update the params and returns the correct params."""
    return jax.lax.cond(
        jnp.mod(steps, target_update_period) == 0, lambda _: params,
        lambda _: target_params, None)
Example #3
    def test_sorted_piecewise_constant_pdf_sparse_delta(self):
        """Test sampling when given a large distribution with a big delta in it."""
        num_samples = 100
        num_bins = 100000
        key = random.PRNGKey(0)
        bins = jnp.arange(num_bins)
        weights = np.ones(len(bins) - 1)
        delta_idx = len(weights) // 2
        weights[delta_idx] = len(weights) - 1
        samples = math.sorted_piecewise_constant_pdf(
            key,
            bins[None],
            weights[None],
            num_samples,
            True,
        )[0]

        # All samples should be within the range of the bins.
        self.assertTrue(jnp.all(samples >= bins[0]))
        self.assertTrue(jnp.all(samples <= bins[-1]))

        # Samples modded by their bin index should resemble a uniform distribution.
        samples_mod = jnp.mod(samples, 1)
        self.assertLessEqual(
            sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic, 0.2)

        # The delta function bin should contain ~half of the samples.
        in_delta = (samples >= bins[delta_idx]) & (samples <=
                                                   bins[delta_idx + 1])
        self.assertAllClose(jnp.mean(in_delta), 0.5, atol=0.05)
Example #4
    def test_sorted_piecewise_constant_pdf_large_flat(self):
        """Test sampling when given a large flat distribution."""
        num_samples = 100
        num_bins = 100000
        key = random.PRNGKey(0)
        bins = jnp.arange(num_bins)
        weights = np.ones(len(bins) - 1)
        samples = math.sorted_piecewise_constant_pdf(
            key,
            bins[None],
            weights[None],
            num_samples,
            True,
        )[0]
        # All samples should be within the range of the bins.
        self.assertTrue(jnp.all(samples >= bins[0]))
        self.assertTrue(jnp.all(samples <= bins[-1]))

        # Samples modded by their bin index should resemble a uniform distribution.
        samples_mod = jnp.mod(samples, 1)
        self.assertLessEqual(
            sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic, 0.2)

        # All samples should collectively resemble a uniform distribution.
        self.assertLessEqual(
            sp.stats.kstest(samples, 'uniform', (bins[0], bins[-1])).statistic,
            0.2)
Example #5
def periodic_update(
    new_tensors: base.Params,
    old_tensors: base.Params,
    steps: chex.Array,
    update_period: int
) -> base.Params:
  """Periodically update all parameters with new values.

  A slow copy of a model's parameters, updated every K actual updates, can be
  used to implement forms of self-supervision (in supervised learning), or to
  stabilise temporal difference learning updates (in reinforcement learning).

  References:
    [Grill et al., 2020](https://arxiv.org/abs/2006.07733)
    [Mnih et al., 2015](https://arxiv.org/abs/1312.5602)

  Args:
    new_tensors: the latest value of the tensors.
    old_tensors: a slow copy of the model's parameters.
    steps: number of update steps on the "online" network.
    update_period: every how many steps to update the "target" network.

  Returns:
    a slow copy of the model's parameters, updated every `update_period` steps.
  """
  return jax.lax.cond(
      jnp.mod(steps, update_period) == 0,
      lambda _: new_tensors,
      lambda _: old_tensors,
      None)
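A minimal usage sketch (not from the source repository; the parameter pytrees below are made up) showing that the slow copy returned by `periodic_update` only picks up the online values on steps where `jnp.mod(steps, update_period) == 0`:

import jax.numpy as jnp

# Hypothetical stand-in parameter pytrees; in practice these are model params.
online_params = {"w": jnp.array([1.0, 2.0])}
target_params = {"w": jnp.array([0.0, 0.0])}

update_period = 4
for step in range(8):
    target_params = periodic_update(online_params, target_params,
                                    jnp.asarray(step), update_period)
    # On steps 0 and 4 the call returns online_params; on the other steps it
    # returns the previous target_params unchanged.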
Example #6
 def value(self, count: JTensor) -> JTensor:
     relative_step = jnp.mod(count, self._period)
     output = self._schedules[0].value(count)
     for boundary, schedule in zip(self._boundaries, self._schedules[1:]):
         output = jnp.where(relative_step < boundary, output,
                            schedule.value(count))
     return output
Example #7
def stimulate(t, X, stimuli):
    for stimulus in stimuli:
        active = t > stimulus["start"]
        active &= t < stimulus["start"] + stimulus["duration"]
        # for some weird reason the check for cyclic stimuli does not work
        active = (np.mod(t - stimulus["start"], stimulus["period"]) < stimulus["duration"])  # cyclic
        X = np.where(stimulus["field"] * (active), stimulus["field"], X)
    return X
Example #8
 def body_fun(i_acc):
   i, acc = i_acc
   return (i + 1,
           (jnp.cos(acc) +
            lax.cond(jnp.mod(i, 2) == 0,
                     lambda acc: jnp.sin(acc),
                     lambda acc: acc,
                     acc)))
Example #9
def kepler(mean_anom, ecc):
    # We're going to apply array broadcasting here since the logic of our op
    # is much simpler if we require the inputs to all have the same shapes
    mean_anom_, ecc_ = jnp.broadcast_arrays(mean_anom, ecc)

    # Then we need to wrap into the range [0, 2*pi)
    M_mod = jnp.mod(mean_anom_, 2 * np.pi)

    return _kepler_prim.bind(M_mod, ecc_)
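A small sketch (plain jnp, no custom primitive) of the wrapping step used above: `jnp.mod` maps any real angle, including negative ones, into [0, 2*pi).

import numpy as np
import jax.numpy as jnp

angles = jnp.array([-0.5, 0.0, 3.5 * np.pi])
wrapped = jnp.mod(angles, 2 * np.pi)
# -> approximately [5.783, 0.0, 4.712]; every entry lies in [0, 2*pi)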
Example #10
def current_stimulate(t, X, stimuli):
    stimulated = np.zeros_like(X)
    for stimulus in stimuli:
        # active = np.greater_equal(t, stimulus["start"])
        # active &= (np.mod(stimulus["start"] - t + 1, stimulus["period"]) < stimulus["duration"])
        active = np.greater_equal(t, stimulus["start"])
        # active &= np.greater_equal(stimulus["start"] + stimulus["duration"],t)
        active &= (np.mod(t - stimulus["start"], stimulus["period"]) < stimulus["duration"]) # this works for cyclics
        stimulated = np.where(stimulus["field"] * (active), stimulus["field"], stimulated)
    return np.where(stimulated != 0, stimulated, X)
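A worked check (hypothetical numbers, not from the repo) of the cyclic condition used above: with start=10, period=100 and duration=2, `np.mod(t - start, period) < duration` is True only during the first 2 time units of each cycle.

import numpy as np

start, period, duration = 10, 100, 2
t = np.array([9, 10, 11, 12, 110, 111, 112, 150])
active = np.greater_equal(t, start) & (np.mod(t - start, period) < duration)
# -> [False, True, True, False, True, True, False, False]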
Example #11
def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
    """Visualize a 1D image and a 1D weighting according to some colormap.

  Args:
    value: A 1D image.
    weight: A weight map, in [0, 1].
    colormap: A colormap function.
    lo: The lower bound to use when rendering, if None then use a percentile.
    hi: The upper bound to use when rendering, if None then use a percentile.
    percentile: What percentile of the value map to crop to when automatically
      generating `lo` and `hi`. Depends on `weight` as well as `value`.
    curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`
      before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
    modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If
      `modulus` is not None, `lo`, `hi` and `percentile` will have no effect.
    matte_background: If True, matte the image over a checkerboard.

  Returns:
    A colormap rendering.
  """
    # Identify the values that bound the middle of `value` according to `weight`.
    lo_auto, hi_auto = math.weighted_percentile(
        value, weight, [50 - percentile / 2, 50 + percentile / 2])

    # If `lo` or `hi` are None, use the automatically-computed bounds above.
    eps = jnp.finfo(jnp.float32).eps
    lo = lo or (lo_auto - eps)
    hi = hi or (hi_auto + eps)

    # Curve all values.
    value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

    # Wrap the values around if requested.
    if modulus:
        value = jnp.mod(value, modulus) / modulus
    else:
        # Otherwise, just scale to [0, 1].
        value = jnp.nan_to_num(
            jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))

    if colormap:
        colorized = colormap(value)[:, :, :3]
    else:
        assert len(value.shape) == 3 and value.shape[-1] == 3
        colorized = value

    return matte(colorized, weight) if matte_background else colorized
Example #12
def getn(l):
    """
    IN: l = (n^2 - n) / 2
    OUT: n (positive integer solution)
    """
    n = (1 + np.sqrt(1 + 8 * l)) / 2
    assert np.equal(np.mod(n, 1), 0)  # make sure n is an integer

    n = int(n)
    assert l == (n**2 - n) / 2
    return n
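A quick sanity check (illustrative values only): for n = 5 the triangular count is l = (n^2 - n) / 2 = 10, and the closed-form root recovers n.

import numpy as np

n = 5
l = (n**2 - n) / 2          # 10.0
n_recovered = (1 + np.sqrt(1 + 8 * l)) / 2
assert np.equal(np.mod(n_recovered, 1), 0) and int(n_recovered) == n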
Example #13
  def learning_rate_fn(step):
    d_step = step - wait_steps
    pfraction = jnp.mod(d_step, halfwavelength_steps) / halfwavelength_steps
    scale_factor = min_value + (1.0 - min_value) * 0.5 * (
        jnp.cos(jnp.pi * pfraction) + 1.0)

    scale_factor = lax.cond(d_step < 0., d_step, lambda d_step: 1.0, d_step,
                            lambda d_step: scale_factor)
    lr = base_learning_rate * scale_factor
    if warmup_length > 0.0:
      lr = lr * jnp.minimum(1., step / float(warmup_length) / steps_per_epoch)
    return lr
Example #14
def periodic_displacement(side, dR):
  """Wraps displacement vectors into a hypercube.

  Args:
    side: Specification of hypercube size. Either,
      (a) float if all sides have equal length.
      (b) ndarray(spatial_dim) if sides have different lengths.
    dR: Matrix of displacements; ndarray(shape=[..., spatial_dim]).
  Returns:
    Matrix of wrapped displacements; ndarray(shape=[..., spatial_dim]).
  """
  return np.mod(dR + side * f32(0.5), side) - f32(0.5) * side
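A numeric sketch (plain numpy, illustrative values): the formula above maps each displacement component into [-side/2, side/2), i.e. the minimum-image convention.

import numpy as np

side = 10.0
dR = np.array([[9.0, -7.0, 0.5]])
wrapped = np.mod(dR + side * 0.5, side) - 0.5 * side
# -> [[-1.0, 3.0, 0.5]]; each component now lies in [-5.0, 5.0)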
Example #15
def periodic_shift(L, x, dx):
    """
    do a periodic shift

    arguments
        L : Array
            length of dimension
        x : Array
            positions
        dx : Array
            position increment
    """
    return jnp.mod(x + dx, L)
Example #16
def stimulate(t, X, stimuli):
    stimulated = jnp.zeros_like(X)
    for stimulus in stimuli:
        # check if stimulus is in the past
        active = jnp.greater_equal(t, stimulus.protocol.start)
        # check if stimulus is active at the current time
        active &= (jnp.mod(stimulus.protocol.start - t + 1,
                           stimulus.protocol.period) <
                   stimulus.protocol.duration)
        # build the stimulus field
        stimulated = jnp.where(stimulus.field * (active), stimulus.field,
                               stimulated)
    # set the field to the stimulus
    return jnp.where(stimulated != 0, stimulated, X)
Example #17
        def split_top_k(split_queries: Array) -> Tuple[Array, Array, Array]:
            # Find most similar clusters
            prototype_scores = jnp.einsum('qd,pd->qp', split_queries,
                                          prototypes)
            top_indices = jax.lax.top_k(prototype_scores, self.n_search)[1]
            # Perform approximate top-k similarity search over most similar clusters.
            selected_data = table[top_indices]
            split_scores = jnp.einsum('qd,qcrvd->qcrv', split_queries,
                                      selected_data)

            # Find highest scoring vector for each row.
            top_id_by_row = jnp.argmax(split_scores, axis=-1)
            top_score_by_row = jnp.max(split_scores, axis=-1)

            top_id_by_row = top_id_by_row.reshape(
                queries_per_split, self.n_search * rows_per_cluster)
            top_score_by_row = top_score_by_row.reshape(
                queries_per_split, self.n_search * rows_per_cluster)

            # Take k highest scores among all rows.
            top_row_idx = jnp.argsort(top_score_by_row,
                                      axis=-1)[:, :-self.k_top - 1:-1]

            # Sub-select best indices for k best rows.
            ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

            # Gather highest scoring vectors for k best rows.
            query_index = jnp.arange(queries_per_split).reshape(-1, 1).tile(
                [1, self.k_top])
            top_cluster_idx, top_cluster_row_idx = jnp.divmod(
                top_row_idx, rows_per_cluster)
            split_topk_values = selected_data[query_index, top_cluster_idx,
                                              top_cluster_row_idx,
                                              ids_by_topk_row]

            row_offset = jnp.mod(
                jnp.arange(0, self.n_search * values_per_cluster,
                           values_per_row), values_per_cluster)
            cluster_offset = jnp.arange(0, table_size, values_per_cluster)

            # Convert row indices to indices into flattened table.
            top_table_id_by_row = top_id_by_row + row_offset.reshape(
                1, -1) + cluster_offset[top_indices].repeat(rows_per_cluster,
                                                            axis=-1)
            # Get best ids into flattened table.
            split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)

            split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

            return split_topk_values, split_topk_scores, split_topk_ids
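A small sketch (made-up sizes) of the index arithmetic used above: `jnp.divmod(top_row_idx, rows_per_cluster)` splits a flat row index back into a cluster index and the row offset inside that cluster; `jnp.mod(..., values_per_cluster)` plays the analogous role for value offsets.

import jax.numpy as jnp

rows_per_cluster = 4
flat_row_idx = jnp.array([0, 3, 4, 9])
cluster_idx, row_in_cluster = jnp.divmod(flat_row_idx, rows_per_cluster)
# cluster_idx    -> [0, 0, 1, 2]
# row_in_cluster -> [0, 3, 0, 1]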
Example #18
def advect(f, vx, vy):
    """Move field f according to x and y velocities (u and v)
       using an implicit Euler integrator."""
    rows, cols = f.shape
    cell_ys, cell_xs = np.meshgrid(np.arange(rows), np.arange(cols))
    cell_xs = cell_xs.astype(dtype)
    cell_ys = cell_ys.astype(dtype)
    center_xs = (cell_xs - vx).ravel()
    center_ys = (cell_ys - vy).ravel()

    # Compute indices of source cells.
    left_ix = np.floor(center_xs).astype(int)
    top_ix = np.floor(center_ys).astype(int)
    rw = center_xs - left_ix  # Relative weight of right-hand cells.
    bw = center_ys - top_ix  # Relative weight of bottom cells.
    left_ix = np.mod(left_ix, rows)  # Wrap around edges of simulation.
    right_ix = np.mod(left_ix + 1, rows)
    top_ix = np.mod(top_ix, cols)
    bot_ix = np.mod(top_ix + 1, cols)

    # A linearly-weighted sum of the 4 surrounding cells.
    flat_f = (1 - rw) * ((1 - bw)*f[left_ix,  top_ix] + bw*f[left_ix,  bot_ix]) \
                 + rw * ((1 - bw)*f[right_ix, top_ix] + bw*f[right_ix, bot_ix])
    return np.reshape(flat_f, (rows, cols))
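A minimal sketch (toy grid size) of the wrap-around step above: after `np.floor`, source indices can fall outside the grid, and `np.mod` maps them back onto it, which makes the advection periodic.

import numpy as np

rows = 5
left_ix = np.array([-1, 0, 4, 5])
wrapped = np.mod(left_ix, rows)
# -> [4, 0, 4, 0]; out-of-range cells wrap to the opposite edge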
Example #19
def visualize_depth(depth,
                    acc=None,
                    near=None,
                    far=None,
                    curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps),
                    modulus=0,
                    colormap=None):
    """Visualize a depth map.

  Args:
    depth: A depth map.
    acc: An accumulation map, in [0, 1].
    near: The depth of the near plane, if None then just use the min().
    far: The depth of the far plane, if None then just use the max().
    curve_fn: A curve function that gets applied to `depth`, `near`, and `far`
      before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
    modulus: If > 0, mod the normalized depth by `modulus`. Use (0, 1].
    colormap: A colormap function. If None (default), will be set to
      matplotlib's viridis if modulus==0, sinebow otherwise.

  Returns:
    An RGB visualization of `depth`.
  """
    # If `near` or `far` are None, identify the min/max non-NaN values.
    eps = jnp.finfo(jnp.float32).eps
    near = near or jnp.min(jnp.nan_to_num(depth, jnp.inf)) - eps
    far = far or jnp.max(jnp.nan_to_num(depth, -jnp.inf)) + eps

    # Curve all values.
    depth, near, far = [curve_fn(x) for x in [depth, near, far]]

    # Wrap the values around if requested.
    if modulus > 0:
        value = jnp.mod(depth, modulus) / modulus
        colormap = colormap or sinebow
    else:
        # Scale to [0, 1].
        value = jnp.nan_to_num(jnp.clip((depth - near) / (far - near), 0, 1))
        colormap = colormap or cm.get_cmap('viridis')

    vis = colormap(value)[:, :, :3]

    # Set non-accumulated pixels to white.
    if acc is not None:
        vis = vis * acc[:, :, None] + (1 - acc)[:, :, None]

    return vis
Example #20
def circdist(x: jnp.ndarray, y: jnp.ndarray, circumference: float) -> jnp.ndarray:
    """Calculate the signed circular distance between two arrays.

    Returns positive numbers if y is clockwise compared to x, negative if y is counter-
    clockwise compared to x.

    Args:
        x: The first array.
        y: The second array.
        circumference: The circumference of the circle.

    Returns:
        An array of the same shape as x and y, containing the signed circular distances.

    """
    assert y.shape == x.shape
    return -jnp.mod(x - y - circumference / 2, circumference) + circumference / 2
Example #21
def mobius_flow(theta: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    """Following [1] the Mobius flow must have the property that the endpoints of
    the interval [0, 2pi] are mapped to themselves. This transformation
    enforces this property by applying a backwards rotation to bring the zero
    angle back to itself.

    Args:
        theta: The angle parameterizing a point on the circle.
        w: Center points of the Mobius transformation.

    Returns:
        out: The Mobius flow as a transformation on angles leaving the endpoints
            of the interval invariant.

    """
    theta = jnp.hstack((0., theta))
    omega = mobius_angle(theta, w)
    omega -= omega[..., [0]]
    return jnp.squeeze(jnp.mod(omega[..., 1:], 2.*jnp.pi))
Example #22
def mutate_population(key: jnp.ndarray,
                      population: Population,
                      mutation_rate: float = 0.15,
                      alphabet_size: int = 4) -> Population:
  """Mutate a polymer population.

  Here we convert the total mutation rate into mutation rates for each
  member of an alphabet and use this, combined with the non-mutation
  rate, to sample mutations. These mutations are simply members of the
  alphabet that are added to the existing values followed by a modulo
  operation for the size of the alphabet.
  
  Example:
    alphabet_size = 4, pop = [[0,1,2,3]], mutation = [[0,0,1,2]],
    new_population = [[0,1,3,1]]

  Args:
    key: A random number generator key e.g. produced via
      jax.random.PRNGKey.
    population: A polymer population.
    mutation_rate: The combined rate of mutation over all forms
      of mutation. E.g. 15% that it would be one of 1->2 or 2->3.
    alphabet_size: The size of the alphabet from which polymer
      elements are sampled.

  Returns:
    Population: A population of polymers.

  """
  individual_p_mutation = mutation_rate / alphabet_size
  # Lazily double-counts self-transitions as a type of mutation
  # in the interest of prototyping
  p_no_mutation = (1 - mutation_rate)
  mutation_probs = [individual_p_mutation for _ in range(alphabet_size)]
  mutation_probs = [p_no_mutation] + mutation_probs

  mutation = random.choice(key,
                           a=jnp.array(range(alphabet_size + 1)),
                           shape=population.shape,
                           p=jnp.array(mutation_probs))

  return jnp.mod(population + mutation, alphabet_size)
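A check of the docstring example above (a sketch assuming the wrap is by the full alphabet size): adding the mutation and taking the result modulo 4 reproduces the stated new population.

import jax.numpy as jnp

alphabet_size = 4
pop = jnp.array([[0, 1, 2, 3]])
mutation = jnp.array([[0, 0, 1, 2]])
new_population = jnp.mod(pop + mutation, alphabet_size)
# -> [[0, 1, 3, 1]], matching the example in the docstring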
Example #23
def segment_max(data,
                segment_ids,
                num_segments=None,
                indices_are_sorted=False,
                unique_indices=False):
    """Computes the max within segments of an array.

  Similar to TensorFlow's segment_max:
  https://www.tensorflow.org/api_docs/python/tf/math/segment_max

  Args:
    data: an array with the values to be maxed over.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be maxed over. Values can be repeated
      and need not be sorted. Values outside of the range [0, num_segments) are
      wrapped into that range by applying jnp.mod.
    num_segments: optional, an int with positive value indicating the number of
      segments. The default is ``jnp.maximum(jnp.max(segment_ids) + 1,
      jnp.max(-segment_ids))`` but since `num_segments` determines the size of
      the output, a static value must be provided to use ``segment_max`` in a
      ``jit``-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted
    unique_indices: whether ``segment_ids`` is known to be free of duplicates

  Returns:
    An array with shape ``(num_segments,) + data.shape[1:]`` representing
    the segment maxs.
  """
    if num_segments is None:
        num_segments = jnp.maximum(
            jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
    num_segments = int(num_segments)

    min_value = dtype_min_value(data.dtype)
    out = jnp.full((num_segments, ) + data.shape[1:],
                   min_value,
                   dtype=data.dtype)
    segment_ids = jnp.mod(segment_ids, num_segments)
    return jax.ops.index_max(out, segment_ids, data, indices_are_sorted,
                             unique_indices)
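A small sketch (toy values) of the wrapping behaviour documented above: out-of-range segment ids, including negative ones, are mapped into [0, num_segments) with `jnp.mod` before the scatter.

import jax.numpy as jnp

num_segments = 3
segment_ids = jnp.array([0, -1, 4, 2])
wrapped = jnp.mod(segment_ids, num_segments)
# -> [0, 2, 1, 2]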
Example #24
  def __call__(self, inputs, prev_state):
    """Writes a new memory into the episodic memory.

    Args:
      inputs: A Tensor of shape ``[batch_size, memory_size]``.
      prev_state: The previous state of the episodic memory, which is a tuple
         with a (i) counter of shape ``[batch_size, 1]`` indicating how many
         memories have been written so far, and (ii) a tensor of shape
         ``[batch_size, capacity, memory_size]`` with the full content of the
         episodic memory.
    Returns:
      A tuple with (i) a tensor of shape ``[batch_size, capacity, memory_size]``
          with the full content of the episodic memory, including the newly
          written memory, and (ii) the new state of the episodic memory.
    """
    inputs = jax.lax.stop_gradient(inputs)
    counter, memories = prev_state
    counter_mod = jnp.mod(counter, self._capacity)
    slot_selector = jnp.expand_dims(
        jax.nn.one_hot(counter_mod, self._capacity), axis=2)
    memories = memories * (1 - slot_selector) + (
        slot_selector * jnp.expand_dims(inputs, 1))
    counter = counter + 1
    return memories, (counter, memories)
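A sketch (made-up capacity and counters) of the slot selection above: `jnp.mod(counter, capacity)` turns the ever-increasing write counter into a ring-buffer slot, and the one-hot of that slot picks the row to overwrite.

import jax
import jax.numpy as jnp

capacity = 4
counter = jnp.array([0, 5, 9])            # per-batch write counters
slot = jnp.mod(counter, capacity)         # -> [0, 1, 1]
slot_selector = jax.nn.one_hot(slot, capacity)
# Row i of slot_selector is 1 only at the slot that will be overwritten next.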
Example #25
def onnx_mod(a, b, fmod=0):
    if fmod:
        return jnp.fmod(a, b)
    else:
        return jnp.mod(a, b)
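A short comparison (standard numpy semantics) of the two branches above: `jnp.mod` takes the sign of the divisor, while `jnp.fmod` keeps the sign of the dividend, which is the distinction the `fmod` flag selects.

import jax.numpy as jnp

a, b = jnp.array([-7.0, 7.0]), jnp.array([3.0, -3.0])
print(jnp.mod(a, b))   # [ 2. -2.]  sign follows the divisor b
print(jnp.fmod(a, b))  # [-1.  1.]  sign follows the dividend a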
Example #26
    def cell_list_fn(position: Array,
                     capacity_overflow_update: Optional[Tuple[
                         int, bool, Callable[..., CellList]]] = None,
                     extra_capacity: int = 0,
                     **kwargs) -> CellList:
        N = position.shape[0]
        dim = position.shape[1]

        if dim != 2 and dim != 3:
            # NOTE(schsam): Do we want to check this in compute_fn as well?
            raise ValueError(
                f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        if capacity_overflow_update is None:
            cell_capacity = _estimate_cell_capacity(position, box_size,
                                                    cell_size,
                                                    buffer_size_multiplier)
            cell_capacity += extra_capacity
            overflow = False
            update_fn = cell_list_fn
        else:
            cell_capacity, overflow, update_fn = capacity_overflow_update

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create cell list data.
        particle_id = lax.iota(i32, N)
        # NOTE(schsam): We use the convention that particles that are successfully
        # copied have their true id, whereas empty slots have id = N.
        # Then when we copy data back from the grid, copy it to an array of shape
        # [N + 1, output_dimension] and then truncate it to an array of shape
        # [N, output_dimension] which ignores the empty slots.
        cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                                  dtype=position.dtype)
        cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

        # It might be worth adding an occupied mask. However, that will involve
        # more compute since often we will do a mask for species that will include
        # an occupancy test. It seems easier to design around this empty_data_value
        # for now and revisit the issue if it comes up later.
        empty_kwarg_value = 10**5
        cell_kwargs = {}
        #  pytype: disable=attribute-error
        for k, v in kwargs.items():
            if not util.is_array(v):
                raise ValueError(
                    (f'Data must be specified as an ndarray. Found "{k}" '
                     f'with type {type(v)}.'))
            if v.shape[0] != position.shape[0]:
                raise ValueError(
                    ('Data must be specified per-particle (an ndarray '
                     f'with shape ({N}, ...)). Found "{k}" with '
                     f'shape {v.shape}.'))
            kwarg_shape = v.shape[1:] if v.ndim > 1 else (1, )
            cell_kwargs[k] = empty_kwarg_value * jnp.ones(
                (cell_count * cell_capacity, ) + kwarg_shape, v.dtype)
        #  pytype: enable=attribute-error
        indices = jnp.array(position / cell_size, dtype=i32)
        hashes = jnp.sum(indices * hash_multipliers, axis=1)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where
        # grid_id is a flat list that repeats 0, .., cell_capacity. So long as
        # there are fewer than cell_capacity particles per cell, each particle is
        # guaranteed to get a cell id that is unique.
        sort_map = jnp.argsort(hashes)
        sorted_position = position[sort_map]
        sorted_hash = hashes[sort_map]
        sorted_id = particle_id[sort_map]

        sorted_kwargs = {}
        for k, v in kwargs.items():
            sorted_kwargs[k] = v[sort_map]

        sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
        sorted_id = jnp.reshape(sorted_id, (N, 1))
        cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
        cell_position = _unflatten_cell_buffer(cell_position, cells_per_side,
                                               dim)
        cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

        for k, v in sorted_kwargs.items():
            if v.ndim == 1:
                v = jnp.reshape(v, v.shape + (1, ))
            cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
            cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k],
                                                    cells_per_side, dim)

        occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
        max_occupancy = jnp.max(occupancy)
        overflow = overflow | (max_occupancy >= cell_capacity)

        return CellList(cell_position, cell_id, cell_kwargs, overflow,
                        cell_capacity, update_fn)  # pytype: disable=wrong-arg-count
Example #27
    def _sample_next(sampler, machine, parameters: PyTree,
                     state: MetropolisPtSamplerState):
        new_rng, rng = jax.random.split(state.rng)
        # def cbr(data):
        #    new_rng, rng = data
        #    print("sample_next newrng:\n", new_rng,  "\nand rng:\n", rng)
        #    return new_rng
        # new_rng = hcb.call(
        #   cbr,
        #   (new_rng, rng),
        #   result_shape=jax.ShapeDtypeStruct(new_rng.shape, new_rng.dtype),
        # )

        with loops.Scope() as s:
            s.key = rng
            s.σ = state.σ
            s.log_prob = sampler.machine_pow * machine(parameters,
                                                       state.σ).real
            s.beta = state.beta

            # for logging
            s.beta_0_index = state.beta_0_index
            s.n_accepted_per_beta = state.n_accepted_per_beta
            s.beta_position = state.beta_position
            s.beta_diffusion = state.beta_diffusion

            for i in s.range(sampler.n_sweeps):
                # 1 to propagate for next iteration, 1 for uniform rng and n_chains for transition kernel
                s.key, key1, key2, key3, key4 = jax.random.split(s.key, 5)

                # def cbi(data):
                #    i, beta = data
                #    print("sweep #", i, " for beta=\n", beta)
                #    return beta
                # beta = hcb.call(
                #   cbi,
                #   (i, s.beta),
                #   result_shape=jax.ShapeDtypeStruct(s.beta.shape, s.beta.dtype),
                # )
                beta = s.beta

                σp, log_prob_correction = sampler.rule.transition(
                    sampler, machine, parameters, state, key1, s.σ)
                proposal_log_prob = sampler.machine_pow * machine(
                    parameters, σp).real

                uniform = jax.random.uniform(key2, shape=(sampler.n_batches, ))
                if log_prob_correction is not None:
                    do_accept = uniform < jnp.exp(
                        beta.reshape((-1, )) *
                        (proposal_log_prob - s.log_prob + log_prob_correction))
                else:
                    do_accept = uniform < jnp.exp(
                        beta.reshape(
                            (-1, )) * (proposal_log_prob - s.log_prob))

                # do_accept must match ndim of proposal and state (which is 2)
                s.σ = jnp.where(do_accept.reshape(-1, 1), σp, s.σ)
                n_accepted_per_beta = s.n_accepted_per_beta + do_accept.reshape(
                    (sampler.n_chains, sampler.n_replicas))

                s.log_prob = jax.numpy.where(do_accept.reshape(-1),
                                             proposal_log_prob, s.log_prob)

                # exchange betas

                # randomly decide if every set of replicas should be swapped in even or odd order
                swap_order = jax.random.randint(
                    key3,
                    minval=0,
                    maxval=2,
                    shape=(sampler.n_chains, ),
                )  # 0 or 1
                iswap_order = jnp.mod(swap_order + 1, 2)  #  1 or 0

                # indices of even swapped elements (per-row)
                idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                    (1, -1)) + swap_order.reshape((-1, 1))
                # indices of odd swapped elements (per-row)
                inn = (idxs + 1) % sampler.n_replicas

                # for every rows of the input, swap elements at idxs with elements at inn
                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def swap_rows(beta_row, idxs, inn):
                    proposed_beta = jax.ops.index_update(
                        beta_row,
                        idxs,
                        beta_row[inn],
                        unique_indices=True,
                        indices_are_sorted=True,
                    )
                    proposed_beta = jax.ops.index_update(
                        proposed_beta,
                        inn,
                        beta_row[idxs],
                        unique_indices=True,
                        indices_are_sorted=False,
                    )
                    return proposed_beta

                proposed_beta = swap_rows(beta, idxs, inn)

                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def compute_proposed_prob(prob, idxs, inn):
                    prob_rescaled = prob[idxs] + prob[inn]
                    return prob_rescaled

                # compute the probability of the swaps
                log_prob = (proposed_beta - state.beta) * s.log_prob.reshape(
                    (sampler.n_chains, sampler.n_replicas))

                prob_rescaled = jnp.exp(
                    compute_proposed_prob(log_prob, idxs, inn))

                uniform = jax.random.uniform(key4,
                                             shape=(sampler.n_chains,
                                                    sampler.n_replicas // 2))

                do_swap = uniform < prob_rescaled

                do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                    (-1, sampler.n_replicas))  #  concat along last dimension
                # roll if swap_order is odd
                @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
                def fix_swap(do_swap, swap_order):
                    return jax.lax.cond(swap_order == 0, lambda x: x,
                                        lambda x: jnp.roll(x, 1), do_swap)

                do_swap = fix_swap(do_swap, swap_order)
                # jax.experimental.host_callback.id_print(state.beta)
                # jax.experimental.host_callback.id_print(proposed_beta)

                new_beta = jax.numpy.where(do_swap, proposed_beta, beta)

                def cb(data):
                    _bt, _pbt, new_beta, so, do_swap, log_prob, prob = data
                    print("--------.---------.---------.--------")
                    print("     cur beta:\n", _bt)
                    print("proposed beta:\n", _pbt)
                    print("     new beta:\n", new_beta)
                    print("swaporder :", so)
                    print("do_swap :\n", do_swap)
                    print("log_prob;\n", log_prob)
                    print("prob_rescaled;\n", prob)
                    return new_beta

                # new_beta = hcb.call(
                #    cb,
                #    (
                #        beta,
                #        proposed_beta,
                #        new_beta,
                #        swap_order,
                #        do_swap,
                #        log_prob,
                #        prob_rescaled,
                #    ),
                #    result_shape=jax.ShapeDtypeStruct(new_beta.shape, new_beta.dtype),
                # )
                # s.beta = new_beta

                swap_order = swap_order.reshape(-1)

                beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                                        in_axes=(0, 0),
                                        out_axes=0)(do_swap,
                                                    state.beta_0_index)
                proposed_beta_0_index = jnp.mod(
                    state.beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
                    (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                    sampler.n_replicas,
                )

                s.beta_0_index = jnp.where(beta_0_moved, proposed_beta_0_index,
                                           s.beta_0_index)

                # swap acceptances
                swapped_n_accepted_per_beta = swap_rows(
                    n_accepted_per_beta, idxs, inn)
                s.n_accepted_per_beta = jax.numpy.where(
                    do_swap,
                    swapped_n_accepted_per_beta,
                    n_accepted_per_beta,
                )

                # Update statistics to compute diffusion coefficient of replicas
                # Total exchange steps performed
                delta = s.beta_0_index - s.beta_position
                s.beta_position = s.beta_position + delta / (
                    state.exchange_steps + i)
                delta2 = s.beta_0_index - s.beta_position
                s.beta_diffusion = s.beta_diffusion + delta * delta2

            new_state = state.replace(
                rng=new_rng,
                σ=s.σ,
                # n_accepted=s.accepted,
                n_samples=state.n_samples +
                sampler.n_sweeps * sampler.n_chains,
                beta=s.beta,
                beta_0_index=s.beta_0_index,
                beta_position=s.beta_position,
                beta_diffusion=s.beta_diffusion,
                exchange_steps=state.exchange_steps + sampler.n_sweeps,
                n_accepted_per_beta=s.n_accepted_per_beta,
            )

        offsets = jnp.arange(0, sampler.n_chains * sampler.n_replicas,
                             sampler.n_replicas)

        return new_state, new_state.σ[new_state.beta_0_index + offsets, :]
Example #28
num_params = jnp.array(
    tree_util.tree_map(count,
                       tree_util.tree_flatten(bij_params)[0])).sum()
print('number of parameters: {}'.format(num_params))

bij_params = tree_util.tree_map(lambda x: x / 2., bij_params)

# Direct estimation on the torus.
bij_params, trace = train(rng_train, bij_params, bij_fns, args.num_steps,
                          args.lr, 100)

# Sample from the learned distribution.
num_samples = 100000
num_dims = 2
xamb = random.normal(rng_xamb, [num_samples, num_dims])
xamb = forward(bij_params, bij_fns, xamb)
xtor = jnp.mod(xamb, 2.0 * jnp.pi)
lp = induced_torus_log_density(bij_params, bij_fns, xtor)
xobs = rejection_sampling(rng_xobs, len(xtor), torus_density, args.beta)

# Compute comparison statistics.
mean_mse = jnp.square(jnp.linalg.norm(xtor.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xtor.T) - jnp.cov(xobs.T)))
approx = jnp.exp(lp)
target = torus_density(xtor)
w = target / approx
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
Example #29
def periodic_shift(side, R, dR):
  """Shifts positions, wrapping them back within a periodic hypercube."""
  return np.mod(R + dR, side)
Example #30
  def build_cells(R, **kwargs):
    N = R.shape[0]
    dim = R.shape[1]

    if dim != 2 and dim != 3:
      # NOTE(schsam): Do we want to check this in compute_fn as well?
      raise ValueError(
          'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    neighborhood_tile_count = 3 ** dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(np.int64, N)
    # NOTE(schsam): We use the convention that particles that are successfully
    # copied have their true id, whereas empty slots have id = N.
    # Then when we copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    mask_id = np.ones((N,), np.int64) * N
    cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
    cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10 ** 5
    cell_kwargs = {}
    for k, v in kwargs.items():
      if not isinstance(v, np.ndarray):
        raise ValueError((
          'Data must be specified as an ndarray. Found "{}" with '
          'type {}'.format(k, type(v))))
      if v.shape[0] != R.shape[0]:
        raise ValueError(
          ('Data must be specified per-particle (an ndarray with shape '
           '(R.shape[0], ...)). Found "{}" with shape {}'.format(k, v.shape)))
      kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
      cell_kwargs[k] = empty_kwarg_value * np.ones(
        (cell_count * cell_capacity,) + kwarg_shape, v.dtype)

    indices = np.array(R / cell_size, dtype=i32)
    hashes = np.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, .., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = np.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_kwargs = {}
    for k, v in kwargs.items():
      sorted_kwargs[k] = v[sort_map]

    sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_id = np.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(
        cell_id, sorted_cell_id, sorted_id)
    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
      if v.ndim == 1:
        v = np.reshape(v, v.shape + (1,))
      cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id, v)
      cell_kwargs[k] = _unflatten_cell_buffer(
        cell_kwargs[k], cells_per_side, dim)

    return CellList(cell_R, cell_id, cell_kwargs)