def stimulate(t, X, stimuli):
    stimulated = np.zeros_like(X)
    for stimulus in stimuli:
        active = np.greater_equal(t, stimulus["start"])
        active &= (np.mod(stimulus["start"] - t + 1, stimulus["period"])
                   < stimulus["duration"])
        stimulated = np.where(stimulus["field"] * (active), stimulus["field"],
                              stimulated)
    return np.where(stimulated != 0, stimulated, X)
def update_periodically(steps: jnp.ndarray, target_update_period: int,
                        params: types.NestedArray,
                        target_params: types.NestedArray) -> types.NestedArray:
    """Checks whether to update the params and returns the correct params."""
    return jax.lax.cond(
        jnp.mod(steps, target_update_period) == 0,
        lambda _: params,
        lambda _: target_params,
        None)
def test_sorted_piecewise_constant_pdf_sparse_delta(self):
    """Test sampling when given a large distribution with a big delta in it."""
    num_samples = 100
    num_bins = 100000
    key = random.PRNGKey(0)
    bins = jnp.arange(num_bins)
    weights = np.ones(len(bins) - 1)
    delta_idx = len(weights) // 2
    weights[delta_idx] = len(weights) - 1
    samples = math.sorted_piecewise_constant_pdf(
        key,
        bins[None],
        weights[None],
        num_samples,
        True,
    )[0]

    # All samples should be within the range of the bins.
    self.assertTrue(jnp.all(samples >= bins[0]))
    self.assertTrue(jnp.all(samples <= bins[-1]))

    # Samples modded by their bin index should resemble a uniform distribution.
    samples_mod = jnp.mod(samples, 1)
    self.assertLessEqual(
        sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic, 0.2)

    # The delta function bin should contain ~half of the samples.
    in_delta = (samples >= bins[delta_idx]) & (samples <= bins[delta_idx + 1])
    self.assertAllClose(jnp.mean(in_delta), 0.5, atol=0.05)
def test_sorted_piecewise_constant_pdf_large_flat(self):
    """Test sampling when given a large flat distribution."""
    num_samples = 100
    num_bins = 100000
    key = random.PRNGKey(0)
    bins = jnp.arange(num_bins)
    weights = np.ones(len(bins) - 1)
    samples = math.sorted_piecewise_constant_pdf(
        key,
        bins[None],
        weights[None],
        num_samples,
        True,
    )[0]

    # All samples should be within the range of the bins.
    self.assertTrue(jnp.all(samples >= bins[0]))
    self.assertTrue(jnp.all(samples <= bins[-1]))

    # Samples modded by their bin index should resemble a uniform distribution.
    samples_mod = jnp.mod(samples, 1)
    self.assertLessEqual(
        sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic, 0.2)

    # All samples should collectively resemble a uniform distribution.
    self.assertLessEqual(
        sp.stats.kstest(samples, 'uniform', (bins[0], bins[-1])).statistic, 0.2)
def periodic_update(
    new_tensors: base.Params,
    old_tensors: base.Params,
    steps: chex.Array,
    update_period: int
) -> base.Params:
    """Periodically update all parameters with new values.

    A slow copy of a model's parameters, updated every K actual updates, can be
    used to implement forms of self-supervision (in supervised learning), or to
    stabilise temporal difference learning updates (in reinforcement learning).

    References:
      [Grill et al., 2020](https://arxiv.org/abs/2006.07733)
      [Mnih et al., 2015](https://arxiv.org/abs/1312.5602)

    Args:
      new_tensors: the latest value of the tensors.
      old_tensors: a slow copy of the model's parameters.
      steps: number of update steps on the "online" network.
      update_period: every how many steps to update the "target" network.

    Returns:
      a slow copy of the model's parameters, updated every `update_period` steps.
    """
    return jax.lax.cond(
        jnp.mod(steps, update_period) == 0,
        lambda _: new_tensors,
        lambda _: old_tensors,
        None)
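# A minimal usage sketch of the periodic update above (hypothetical parameter
# trees; assumes jax and jax.numpy as jnp are imported): the slow copy is
# refreshed only on steps where `steps % update_period == 0`.
online_params = {'w': jnp.ones(3)}
target_params = {'w': jnp.zeros(3)}
target_params = periodic_update(online_params, target_params,
                                steps=jnp.asarray(200), update_period=100)
# 200 % 100 == 0, so target_params now equals online_params.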
def value(self, count: JTensor) -> JTensor:
    relative_step = jnp.mod(count, self._period)
    output = self._schedules[0].value(count)
    for boundary, schedule in zip(self._boundaries, self._schedules[1:]):
        output = jnp.where(relative_step < boundary, output,
                           schedule.value(count))
    return output
def stimulate(t, X, stimuli):
    for stimulus in stimuli:
        active = t > stimulus["start"]
        active &= t < stimulus["start"] + stimulus["duration"]
        # For some reason the window check above does not handle cyclic
        # stimuli, so recompute `active` with a modulo test.
        active = (np.mod(t - stimulus["start"], stimulus["period"])
                  < stimulus["duration"])  # cyclic
        X = np.where(stimulus["field"] * (active), stimulus["field"], X)
    return X
def body_fun(i_acc):
    i, acc = i_acc
    return (i + 1, (jnp.cos(acc) +
                    lax.cond(jnp.mod(i, 2) == 0,
                             lambda acc: jnp.sin(acc),
                             lambda acc: acc,
                             acc)))
def kepler(mean_anom, ecc):
    # We're going to apply array broadcasting here since the logic of our op
    # is much simpler if we require the inputs to all have the same shapes
    mean_anom_, ecc_ = jnp.broadcast_arrays(mean_anom, ecc)

    # Then we need to wrap into the range [0, 2*pi)
    M_mod = jnp.mod(mean_anom_, 2 * np.pi)

    return _kepler_prim.bind(M_mod, ecc_)
def current_stimulate(t, X, stimuli):
    stimulated = np.zeros_like(X)
    for stimulus in stimuli:
        # active = np.greater_equal(t, stimulus["start"])
        # active &= (np.mod(stimulus["start"] - t + 1, stimulus["period"]) < stimulus["duration"])
        active = np.greater_equal(t, stimulus["start"])
        # active &= np.greater_equal(stimulus["start"] + stimulus["duration"], t)
        active &= (np.mod(t - stimulus["start"], stimulus["period"])
                   < stimulus["duration"])  # this works for cyclics
        stimulated = np.where(stimulus["field"] * (active), stimulus["field"],
                              stimulated)
    return np.where(stimulated != 0, stimulated, X)
def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
    """Visualize a 1D image and a 1D weighting according to some colormap.

    Args:
      value: A 1D image.
      weight: A weight map, in [0, 1].
      colormap: A colormap function.
      lo: The lower bound to use when rendering, if None then use a percentile.
      hi: The upper bound to use when rendering, if None then use a percentile.
      percentile: What percentile of the value map to crop to when automatically
        generating `lo` and `hi`. Depends on `weight` as well as `value`.
      curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`
        before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
      modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If
        `modulus` is not None, `lo`, `hi` and `percentile` will have no effect.
      matte_background: If True, matte the image over a checkerboard.

    Returns:
      A colormap rendering.
    """
    # Identify the values that bound the middle of `value` according to `weight`.
    lo_auto, hi_auto = math.weighted_percentile(
        value, weight, [50 - percentile / 2, 50 + percentile / 2])

    # If `lo` or `hi` are None, use the automatically-computed bounds above.
    eps = jnp.finfo(jnp.float32).eps
    lo = lo or (lo_auto - eps)
    hi = hi or (hi_auto + eps)

    # Curve all values.
    value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

    # Wrap the values around if requested.
    if modulus:
        value = jnp.mod(value, modulus) / modulus
    else:
        # Otherwise, just scale to [0, 1].
        value = jnp.nan_to_num(
            jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))

    if colormap:
        colorized = colormap(value)[:, :, :3]
    else:
        assert len(value.shape) == 3 and value.shape[-1] == 3
        colorized = value

    return matte(colorized, weight) if matte_background else colorized
def getn(l):
    """
    IN:  l = (n^2 - n) / 2
    OUT: n (positive integer solution)
    """
    n = (1 + np.sqrt(1 + 8 * l)) / 2
    assert np.equal(np.mod(n, 1), 0)  # make sure n is an integer
    n = int(n)
    assert l == (n**2 - n) / 2
    return n
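# Quick sanity check of the inverse triangular-number relation above
# (hypothetical values): l = n * (n - 1) / 2.
assert getn(6) == 4    # 4 * 3 / 2 == 6
assert getn(10) == 5   # 5 * 4 / 2 == 10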
def learning_rate_fn(step):
    d_step = step - wait_steps
    pfraction = jnp.mod(d_step, halfwavelength_steps) / halfwavelength_steps
    scale_factor = min_value + (1.0 - min_value) * 0.5 * (
        jnp.cos(jnp.pi * pfraction) + 1.0)
    scale_factor = lax.cond(d_step < 0., d_step, lambda d_step: 1.0,
                            d_step, lambda d_step: scale_factor)
    lr = base_learning_rate * scale_factor
    if warmup_length > 0.0:
        lr = lr * jnp.minimum(1., step / float(warmup_length) / steps_per_epoch)
    return lr
def periodic_displacement(side, dR):
    """Wraps displacement vectors into a hypercube.

    Args:
      side: Specification of hypercube size. Either,
        (a) float if all sides have equal length.
        (b) ndarray(spatial_dim) if sides have different lengths.
      dR: Matrix of displacements; ndarray(shape=[..., spatial_dim]).

    Returns:
      Matrix of wrapped displacements; ndarray(shape=[..., spatial_dim]).
    """
    return np.mod(dR + side * f32(0.5), side) - f32(0.5) * side
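# Minimal sketch of the minimum-image behaviour, assuming a box of side 10.0
# and that `np` and `f32` here are the module's jax.numpy alias and float32
# helper: a displacement of 7.0 wraps to the nearer periodic image -3.0.
dR = np.array([[7.0, -6.0]])
print(periodic_displacement(10.0, dR))  # approximately [[-3.  4.]]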
def periodic_shift(L, x, dx):
    """Do a periodic shift.

    Arguments:
      L: Array, length of dimension
      x: Array, positions
      dx: Array, position increment
    """
    return jnp.mod(x + dx, L)
def stimulate(t, X, stimuli):
    stimulated = jnp.zeros_like(X)
    for stimulus in stimuli:
        # check if stimulus is in the past
        active = jnp.greater_equal(t, stimulus.protocol.start)
        # check if stimulus is active at the current time
        active &= (jnp.mod(stimulus.protocol.start - t + 1,
                           stimulus.protocol.period)
                   < stimulus.protocol.duration)
        # build the stimulus field
        stimulated = jnp.where(stimulus.field * (active), stimulus.field,
                               stimulated)
    # set the field to the stimulus
    return jnp.where(stimulated != 0, stimulated, X)
def split_top_k(split_queries: Array) -> Tuple[Array, Array, Array]:
    # Find most similar clusters
    prototype_scores = jnp.einsum('qd,pd->qp', split_queries, prototypes)
    top_indices = jax.lax.top_k(prototype_scores, self.n_search)[1]

    # Perform approximate top-k similarity search over most similar clusters.
    selected_data = table[top_indices]
    split_scores = jnp.einsum('qd,qcrvd->qcrv', split_queries, selected_data)

    # Find highest scoring vector for each row.
    top_id_by_row = jnp.argmax(split_scores, axis=-1)
    top_score_by_row = jnp.max(split_scores, axis=-1)
    top_id_by_row = top_id_by_row.reshape(
        queries_per_split, self.n_search * rows_per_cluster)
    top_score_by_row = top_score_by_row.reshape(
        queries_per_split, self.n_search * rows_per_cluster)

    # Take k highest scores among all rows.
    top_row_idx = jnp.argsort(top_score_by_row, axis=-1)[:, :-self.k_top - 1:-1]

    # Sub-select best indices for k best rows.
    ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

    # Gather highest scoring vectors for k best rows.
    query_index = jnp.arange(queries_per_split).reshape(-1, 1).tile(
        [1, self.k_top])
    top_cluster_idx, top_cluster_row_idx = jnp.divmod(top_row_idx,
                                                      rows_per_cluster)
    split_topk_values = selected_data[query_index, top_cluster_idx,
                                      top_cluster_row_idx, ids_by_topk_row]

    row_offset = jnp.mod(
        jnp.arange(0, self.n_search * values_per_cluster, values_per_row),
        values_per_cluster)
    cluster_offset = jnp.arange(0, table_size, values_per_cluster)

    # Convert row indices to indices into flattened table.
    top_table_id_by_row = top_id_by_row + row_offset.reshape(
        1, -1) + cluster_offset[top_indices].repeat(rows_per_cluster, axis=-1)

    # Get best ids into flattened table.
    split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)
    split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

    return split_topk_values, split_topk_scores, split_topk_ids
def advect(f, vx, vy):
    """Move field f according to x and y velocities (u and v)
    using an implicit Euler integrator."""
    rows, cols = f.shape
    cell_ys, cell_xs = np.meshgrid(np.arange(rows), np.arange(cols))
    cell_xs = cell_xs.astype(dtype)
    cell_ys = cell_ys.astype(dtype)
    center_xs = (cell_xs - vx).ravel()
    center_ys = (cell_ys - vy).ravel()

    # Compute indices of source cells.
    left_ix = np.floor(center_xs).astype(int)
    top_ix = np.floor(center_ys).astype(int)
    rw = center_xs - left_ix               # Relative weight of right-hand cells.
    bw = center_ys - top_ix                # Relative weight of bottom cells.
    left_ix = np.mod(left_ix, rows)        # Wrap around edges of simulation.
    right_ix = np.mod(left_ix + 1, rows)
    top_ix = np.mod(top_ix, cols)
    bot_ix = np.mod(top_ix + 1, cols)

    # A linearly-weighted sum of the 4 surrounding cells.
    flat_f = (1 - rw) * ((1 - bw) * f[left_ix, top_ix] + bw * f[left_ix, bot_ix]) \
             + rw * ((1 - bw) * f[right_ix, top_ix] + bw * f[right_ix, bot_ix])

    return np.reshape(flat_f, (rows, cols))
def visualize_depth(depth,
                    acc=None,
                    near=None,
                    far=None,
                    curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps),
                    modulus=0,
                    colormap=None):
    """Visualize a depth map.

    Args:
      depth: A depth map.
      acc: An accumulation map, in [0, 1].
      near: The depth of the near plane, if None then just use the min().
      far: The depth of the far plane, if None then just use the max().
      curve_fn: A curve function that gets applied to `depth`, `near`, and `far`
        before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
      modulus: If > 0, mod the normalized depth by `modulus`. Use (0, 1].
      colormap: A colormap function. If None (default), will be set to
        matplotlib's viridis if modulus==0, sinebow otherwise.

    Returns:
      An RGB visualization of `depth`.
    """
    # If `near` or `far` are None, identify the min/max non-NaN values.
    eps = jnp.finfo(jnp.float32).eps
    near = near or jnp.min(jnp.nan_to_num(depth, jnp.inf)) - eps
    far = far or jnp.max(jnp.nan_to_num(depth, -jnp.inf)) + eps

    # Curve all values.
    depth, near, far = [curve_fn(x) for x in [depth, near, far]]

    # Wrap the values around if requested.
    if modulus > 0:
        value = jnp.mod(depth, modulus) / modulus
        colormap = colormap or sinebow
    else:
        # Scale to [0, 1].
        value = jnp.nan_to_num(jnp.clip((depth - near) / (far - near), 0, 1))
        colormap = colormap or cm.get_cmap('viridis')

    vis = colormap(value)[:, :, :3]

    # Set non-accumulated pixels to white.
    if acc is not None:
        vis = vis * acc[:, :, None] + (1 - acc)[:, :, None]

    return vis
def circdist(x: jnp.ndarray, y: jnp.ndarray,
             circumference: float) -> jnp.ndarray:
    """Calculate the signed circular distance between two arrays.

    Returns positive numbers if y is clockwise compared to x, negative if y is
    counterclockwise compared to x.

    Args:
      x: The first array.
      y: The second array.
      circumference: The circumference of the circle.

    Returns:
      An array of the same shape as x and y, containing the signed circular
      distances.
    """
    assert y.shape == x.shape
    return -jnp.mod(x - y - circumference / 2,
                    circumference) + circumference / 2
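# Sketch of the sign convention on a circle of circumference 360 (hypothetical
# values): each point is 20 units from the other, with opposite signs.
x = jnp.array([10.0, 350.0])
y = jnp.array([350.0, 10.0])
print(circdist(x, y, 360.0))  # approximately [-20.  20.]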
def mobius_flow(theta: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    """Following [1] the Mobius flow must have the property that the endpoints
    of the interval [0, 2pi] are mapped to themselves. This transformation
    enforces this property by applying a backwards rotation to bring the zero
    angle back to itself.

    Args:
      theta: The angle parameterizing a point on the circle.
      w: Center points of the Mobius transformation.

    Returns:
      out: The Mobius flow as a transformation on angles leaving the endpoints
        of the interval invariant.
    """
    theta = jnp.hstack((0., theta))
    omega = mobius_angle(theta, w)
    omega -= omega[..., [0]]
    return jnp.squeeze(jnp.mod(omega[..., 1:], 2. * jnp.pi))
def mutate_population(key: jnp.ndarray,
                      population: Population,
                      mutation_rate: float = 0.15,
                      alphabet_size: int = 4) -> Population:
    """Mutate a polymer population.

    Here we convert the total mutation rate into mutation rates for each member
    of an alphabet and use this, combined with the non-mutation rate, to sample
    mutations. These mutations are simply members of the alphabet that are
    added to the existing values followed by a modulo operation for the size of
    the alphabet.

    Example: alphabet_size = 4, pop = [[0,1,2,3]], mutation = [[0,0,1,2]],
      new_population = [[0,1,3,1]]

    Args:
      key: A random number generator key e.g. produced via jax.random.PRNGKey.
      population: A polymer population.
      mutation_rate: The combined rate of mutation over all forms of mutation.
        E.g. 15% that it would be one of 1->2 or 2->3.
      alphabet_size: The size of the alphabet from which polymer elements are
        sampled.

    Returns:
      Population: A population of polymers.
    """
    individual_p_mutation = mutation_rate / alphabet_size
    # Lazily double-counts self-transitions as a type of mutation
    # in the interest of prototyping
    p_no_mutation = (1 - mutation_rate)
    mutation_probs = [individual_p_mutation for _ in range(alphabet_size)]
    mutation_probs = [p_no_mutation] + mutation_probs
    mutation = random.choice(key,
                             a=jnp.array(range(alphabet_size + 1)),
                             shape=population.shape,
                             p=jnp.array(mutation_probs))
    # Wrap back into the alphabet; modding by `alphabet_size` keeps values in
    # [0, alphabet_size) and matches the docstring example.
    return jnp.mod(population + mutation, alphabet_size)
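# The arithmetic behind the docstring example, assuming alphabet_size = 4:
# additive mutations wrap around the alphabet via the modulo.
pop = jnp.array([[0, 1, 2, 3]])
mutation = jnp.array([[0, 0, 1, 2]])
print(jnp.mod(pop + mutation, 4))  # [[0 1 3 1]]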
def segment_max(data,
                segment_ids,
                num_segments=None,
                indices_are_sorted=False,
                unique_indices=False):
    """Computes the max within segments of an array.

    Similar to TensorFlow's segment_max:
    https://www.tensorflow.org/api_docs/python/tf/math/segment_max

    Args:
      data: an array with the values to be maxed over.
      segment_ids: an array with integer dtype that indicates the segments of
        `data` (along its leading axis) to be maxed over. Values can be repeated
        and need not be sorted. Values outside of the range [0, num_segments)
        are wrapped into that range by applying jnp.mod.
      num_segments: optional, an int with positive value indicating the number
        of segments. The default is ``jnp.maximum(jnp.max(segment_ids) + 1,
        jnp.max(-segment_ids))`` but since `num_segments` determines the size of
        the output, a static value must be provided to use ``segment_max`` in a
        ``jit``-compiled function.
      indices_are_sorted: whether ``segment_ids`` is known to be sorted
      unique_indices: whether ``segment_ids`` is known to be free of duplicates

    Returns:
      An array with shape ``(num_segments,) + data.shape[1:]`` representing the
      segment maxs.
    """
    if num_segments is None:
        num_segments = jnp.maximum(jnp.max(segment_ids) + 1,
                                   jnp.max(-segment_ids))
    num_segments = int(num_segments)

    min_value = dtype_min_value(data.dtype)
    out = jnp.full((num_segments,) + data.shape[1:], min_value,
                   dtype=data.dtype)
    segment_ids = jnp.mod(segment_ids, num_segments)
    return jax.ops.index_max(out, segment_ids, data, indices_are_sorted,
                             unique_indices)
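# Minimal usage sketch (hypothetical data; relies on the deprecated
# jax.ops.index_max API and the dtype_min_value helper this function was
# written against).
data = jnp.array([1.0, 5.0, 2.0, 7.0, 3.0])
segment_ids = jnp.array([0, 0, 1, 1, 2])
print(segment_max(data, segment_ids, num_segments=3))  # [5. 7. 3.]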
def __call__(self, inputs, prev_state):
    """Writes a new memory into the episodic memory.

    Args:
      inputs: A Tensor of shape ``[batch_size, memory_size]``.
      prev_state: The previous state of the episodic memory, which is a tuple
        with a (i) counter of shape ``[batch_size, 1]`` indicating how many
        memories have been written so far, and (ii) a tensor of shape
        ``[batch_size, capacity, memory_size]`` with the full content of the
        episodic memory.

    Returns:
      A tuple with (i) a tensor of shape ``[batch_size, capacity, memory_size]``
      with the full content of the episodic memory, including the newly written
      memory, and (ii) the new state of the episodic memory.
    """
    inputs = jax.lax.stop_gradient(inputs)
    counter, memories = prev_state
    counter_mod = jnp.mod(counter, self._capacity)
    slot_selector = jnp.expand_dims(
        jax.nn.one_hot(counter_mod, self._capacity), axis=2)
    memories = memories * (1 - slot_selector) + (
        slot_selector * jnp.expand_dims(inputs, 1))
    counter = counter + 1
    return memories, (counter, memories)
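# Standalone sketch of the masked ring-buffer write above, without the module
# boilerplate (assumed capacity 3 and memory_size 2): the 5th write (counter 4)
# lands in slot 4 % 3 == 1.
capacity, counter = 3, 4
memories = jnp.zeros((capacity, 2))
new_memory = jnp.array([7.0, 8.0])
slot = jax.nn.one_hot(jnp.mod(counter, capacity), capacity)[:, None]  # shape [3, 1]
memories = memories * (1 - slot) + slot * new_memory  # row 1 becomes [7. 8.]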
def onnx_mod(a, b, fmod=0):
    if fmod:
        return jnp.fmod(a, b)
    else:
        return jnp.mod(a, b)
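# The two branches differ for negative operands: jnp.fmod keeps the sign of
# the dividend (C/ONNX fmod semantics), jnp.mod keeps the sign of the divisor
# (Python semantics).
print(onnx_mod(-5, 3, fmod=1))  # -2
print(onnx_mod(-5, 3, fmod=0))  #  1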
def cell_list_fn(position: Array,
                 capacity_overflow_update: Optional[Tuple[
                     int, bool, Callable[..., CellList]]] = None,
                 extra_capacity: int = 0,
                 **kwargs) -> CellList:
    N = position.shape[0]
    dim = position.shape[1]

    if dim != 2 and dim != 3:
        # NOTE(schsam): Do we want to check this in compute_fn as well?
        raise ValueError(
            f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    if capacity_overflow_update is None:
        cell_capacity = _estimate_cell_capacity(position, box_size, cell_size,
                                                buffer_size_multiplier)
        cell_capacity += extra_capacity
        overflow = False
        update_fn = cell_list_fn
    else:
        cell_capacity, overflow, update_fn = capacity_overflow_update

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(i32, N)
    # NOTE(schsam): We use the convention that particles that are successfully
    # copied have their true id, whereas empty slots have id = N. Then when we
    # copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                              dtype=position.dtype)
    cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10**5
    cell_kwargs = {}
    # pytype: disable=attribute-error
    for k, v in kwargs.items():
        if not util.is_array(v):
            raise ValueError(
                (f'Data must be specified as an ndarray. Found "{k}" '
                 f'with type {type(v)}.'))
        if v.shape[0] != position.shape[0]:
            raise ValueError(
                ('Data must be specified per-particle (an ndarray '
                 f'with shape ({N}, ...)). Found "{k}" with '
                 f'shape {v.shape}.'))
        kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
        cell_kwargs[k] = empty_kwarg_value * jnp.ones(
            (cell_count * cell_capacity,) + kwarg_shape, v.dtype)
    # pytype: enable=attribute-error

    indices = jnp.array(position / cell_size, dtype=i32)
    hashes = jnp.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where
    # grid_id is a flat list that repeats 0, .., cell_capacity. So long as
    # there are fewer than cell_capacity particles per cell, each particle is
    # guaranteed to get a cell id that is unique.
    sort_map = jnp.argsort(hashes)
    sorted_position = position[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_kwargs = {}
    for k, v in kwargs.items():
        sorted_kwargs[k] = v[sort_map]

    sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
    sorted_id = jnp.reshape(sorted_id, (N, 1))
    cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
    cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
        if v.ndim == 1:
            v = jnp.reshape(v, v.shape + (1,))
        cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
        cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k], cells_per_side,
                                                dim)

    occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
    max_occupancy = jnp.max(occupancy)
    overflow = overflow | (max_occupancy >= cell_capacity)

    return CellList(cell_position, cell_id, cell_kwargs, overflow,
                    cell_capacity, update_fn)  # pytype: disable=wrong-arg-count
def _sample_next(sampler, machine, parameters: PyTree,
                 state: MetropolisPtSamplerState):
    new_rng, rng = jax.random.split(state.rng)
    # def cbr(data):
    #     new_rng, rng = data
    #     print("sample_next newrng:\n", new_rng, "\nand rng:\n", rng)
    #     return new_rng
    # new_rng = hcb.call(
    #     cbr,
    #     (new_rng, rng),
    #     result_shape=jax.ShapeDtypeStruct(new_rng.shape, new_rng.dtype),
    # )

    with loops.Scope() as s:
        s.key = rng
        s.σ = state.σ
        s.log_prob = sampler.machine_pow * machine(parameters, state.σ).real
        s.beta = state.beta  # for logging
        s.beta_0_index = state.beta_0_index
        s.n_accepted_per_beta = state.n_accepted_per_beta
        s.beta_position = state.beta_position
        s.beta_diffusion = state.beta_diffusion

        for i in s.range(sampler.n_sweeps):
            # 1 to propagate for next iteration, 1 for uniform rng and
            # n_chains for transition kernel
            s.key, key1, key2, key3, key4 = jax.random.split(s.key, 5)

            # def cbi(data):
            #     i, beta = data
            #     print("sweep #", i, " for beta=\n", beta)
            #     return beta
            # beta = hcb.call(
            #     cbi,
            #     (i, s.beta),
            #     result_shape=jax.ShapeDtypeStruct(s.beta.shape, s.beta.dtype),
            # )
            beta = s.beta

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s.σ)
            proposal_log_prob = sampler.machine_pow * machine(parameters, σp).real

            uniform = jax.random.uniform(key2, shape=(sampler.n_batches,))
            if log_prob_correction is not None:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1,)) *
                    (proposal_log_prob - s.log_prob + log_prob_correction))
            else:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1,)) * (proposal_log_prob - s.log_prob))

            # do_accept must match ndim of proposal and state (which is 2)
            s.σ = jnp.where(do_accept.reshape(-1, 1), σp, s.σ)
            n_accepted_per_beta = s.n_accepted_per_beta + do_accept.reshape(
                (sampler.n_chains, sampler.n_replicas))
            s.log_prob = jax.numpy.where(do_accept.reshape(-1),
                                         proposal_log_prob, s.log_prob)

            # exchange betas

            # randomly decide if every set of replicas should be swapped in
            # even or odd order
            swap_order = jax.random.randint(
                key3,
                minval=0,
                maxval=2,
                shape=(sampler.n_chains,),
            )  # 0 or 1
            iswap_order = jnp.mod(swap_order + 1, 2)  # 1 or 0

            # indices of even swapped elements (per-row)
            idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                (1, -1)) + swap_order.reshape((-1, 1))
            # indices of odd swapped elements (per-row)
            inn = (idxs + 1) % sampler.n_replicas

            # for every row of the input, swap elements at idxs with elements at inn
            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def swap_rows(beta_row, idxs, inn):
                proposed_beta = jax.ops.index_update(
                    beta_row,
                    idxs,
                    beta_row[inn],
                    unique_indices=True,
                    indices_are_sorted=True,
                )
                proposed_beta = jax.ops.index_update(
                    proposed_beta,
                    inn,
                    beta_row[idxs],
                    unique_indices=True,
                    indices_are_sorted=False,
                )
                return proposed_beta

            proposed_beta = swap_rows(beta, idxs, inn)

            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def compute_proposed_prob(prob, idxs, inn):
                prob_rescaled = prob[idxs] + prob[inn]
                return prob_rescaled

            # compute the probability of the swaps
            log_prob = (proposed_beta - state.beta) * s.log_prob.reshape(
                (sampler.n_chains, sampler.n_replicas))
            prob_rescaled = jnp.exp(compute_proposed_prob(log_prob, idxs, inn))

            uniform = jax.random.uniform(
                key4, shape=(sampler.n_chains, sampler.n_replicas // 2))
            do_swap = uniform < prob_rescaled

            do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                (-1, sampler.n_replicas))  # concat along last dimension

            # roll if swap_order is odd
            @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
            def fix_swap(do_swap, swap_order):
                return jax.lax.cond(swap_order == 0, lambda x: x,
                                    lambda x: jnp.roll(x, 1), do_swap)

            do_swap = fix_swap(do_swap, swap_order)
            # jax.experimental.host_callback.id_print(state.beta)
            # jax.experimental.host_callback.id_print(proposed_beta)

            new_beta = jax.numpy.where(do_swap, proposed_beta, beta)

            def cb(data):
                _bt, _pbt, new_beta, so, do_swap, log_prob, prob = data
                print("--------.---------.---------.--------")
                print("     cur beta:\n", _bt)
                print("proposed beta:\n", _pbt)
                print("     new beta:\n", new_beta)
                print("swaporder :", so)
                print("do_swap :\n", do_swap)
                print("log_prob;\n", log_prob)
                print("prob_rescaled;\n", prob)
                return new_beta

            # new_beta = hcb.call(
            #     cb,
            #     (
            #         beta,
            #         proposed_beta,
            #         new_beta,
            #         swap_order,
            #         do_swap,
            #         log_prob,
            #         prob_rescaled,
            #     ),
            #     result_shape=jax.ShapeDtypeStruct(new_beta.shape, new_beta.dtype),
            # )
            # s.beta = new_beta

            swap_order = swap_order.reshape(-1)

            beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                                    in_axes=(0, 0),
                                    out_axes=0)(do_swap, state.beta_0_index)
            proposed_beta_0_index = jnp.mod(
                state.beta_0_index +
                (-jnp.mod(swap_order, 2) * 2 + 1) *
                (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                sampler.n_replicas,
            )
            s.beta_0_index = jnp.where(beta_0_moved, proposed_beta_0_index,
                                       s.beta_0_index)

            # swap acceptances
            swapped_n_accepted_per_beta = swap_rows(n_accepted_per_beta, idxs, inn)
            s.n_accepted_per_beta = jax.numpy.where(
                do_swap,
                swapped_n_accepted_per_beta,
                n_accepted_per_beta,
            )

            # Update statistics to compute diffusion coefficient of replicas
            # Total exchange steps performed
            delta = s.beta_0_index - s.beta_position
            s.beta_position = s.beta_position + delta / (state.exchange_steps + i)
            delta2 = s.beta_0_index - s.beta_position
            s.beta_diffusion = s.beta_diffusion + delta * delta2

        new_state = state.replace(
            rng=new_rng,
            σ=s.σ,
            # n_accepted=s.accepted,
            n_samples=state.n_samples + sampler.n_sweeps * sampler.n_chains,
            beta=s.beta,
            beta_0_index=s.beta_0_index,
            beta_position=s.beta_position,
            beta_diffusion=s.beta_diffusion,
            exchange_steps=state.exchange_steps + sampler.n_sweeps,
            n_accepted_per_beta=s.n_accepted_per_beta,
        )

    offsets = jnp.arange(0, sampler.n_chains * sampler.n_replicas,
                         sampler.n_replicas)

    return new_state, new_state.σ[new_state.beta_0_index + offsets, :]
    tree_util.tree_map(count, tree_util.tree_flatten(bij_params)[0])).sum()
print('number of parameters: {}'.format(num_params))
bij_params = tree_util.tree_map(lambda x: x / 2., bij_params)

# Direct estimation on the torus.
bij_params, trace = train(rng_train, bij_params, bij_fns, args.num_steps,
                          args.lr, 100)

# Sample from the learned distribution.
num_samples = 100000
num_dims = 2
xamb = random.normal(rng_xamb, [num_samples, num_dims])
xamb = forward(bij_params, bij_fns, xamb)
xtor = jnp.mod(xamb, 2.0 * jnp.pi)
lp = induced_torus_log_density(bij_params, bij_fns, xtor)
xobs = rejection_sampling(rng_xobs, len(xtor), torus_density, args.beta)

# Compute comparison statistics.
mean_mse = jnp.square(jnp.linalg.norm(xtor.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xtor.T) - jnp.cov(xobs.T)))
approx = jnp.exp(lp)
target = torus_density(xtor)
w = target / approx
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
def periodic_shift(side, R, dR):
    """Shifts positions, wrapping them back within a periodic hypercube."""
    return np.mod(R + dR, side)
def build_cells(R, **kwargs):
    N = R.shape[0]
    dim = R.shape[1]

    if dim != 2 and dim != 3:
        # NOTE(schsam): Do we want to check this in compute_fn as well?
        raise ValueError(
            'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    neighborhood_tile_count = 3 ** dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(np.int64, N)
    # NOTE(schsam): We use the convention that particles that are successfully
    # copied have their true id, whereas empty slots have id = N. Then when we
    # copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    mask_id = np.ones((N,), np.int64) * N
    cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
    cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10 ** 5
    cell_kwargs = {}
    for k, v in kwargs.items():
        if not isinstance(v, np.ndarray):
            raise ValueError((
                'Data must be specified as an ndarray. Found "{}" with '
                'type {}'.format(k, type(v))))
        if v.shape[0] != R.shape[0]:
            raise ValueError(
                ('Data must be specified per-particle (an ndarray with shape '
                 '(R.shape[0], ...)). Found "{}" with shape {}'.format(
                     k, v.shape)))
        kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
        cell_kwargs[k] = empty_kwarg_value * np.ones(
            (cell_count * cell_capacity,) + kwarg_shape, v.dtype)

    indices = np.array(R / cell_size, dtype=i32)
    hashes = np.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, .., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = np.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]
    sorted_kwargs = {}
    for k, v in kwargs.items():
        sorted_kwargs[k] = v[sort_map]

    sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_id = np.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)
    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
        if v.ndim == 1:
            v = np.reshape(v, v.shape + (1,))
        cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id, v)
        cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k], cells_per_side,
                                                dim)

    return CellList(cell_R, cell_id, cell_kwargs)