def fori_loop(lower, upper, body_fun, init_val):
    if _DISABLE_CONTROL_FLOW_PRIM:
        val = init_val
        for i in range(int(lower), int(upper)):
            val = body_fun(i, val)
        return val
    else:
        return lax.fori_loop(lower, upper, body_fun, init_val)
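# Hedged usage sketch (added; not from the original source): a quick check that
# the wrapper above matches a plain Python loop. `_DISABLE_CONTROL_FLOW_PRIM`
# is assumed to be a module-level flag defined elsewhere.
from jax import lax

def _fori_loop_demo():
    body = lambda i, acc: acc + i
    assert int(lax.fori_loop(0, 5, body, 0)) == sum(range(5))  # both equal 10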
def epoch_train(opt_state, rng):
    def body_fn(i, val):
        loss_sum, opt_state, rng = val
        rng, batch = binarize(rng, train_fetch(i, train_idx)[0])
        loss, opt_state, rng = svi_update(i, rng, opt_state, (batch,), (batch,))
        loss_sum += loss
        return loss_sum, opt_state, rng

    return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))
def sum_first_n(arr, num):
    def body_fun(i, state):
        arr, total = state['arr'], state['total']
        arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
        return {'arr': arr, 'total': lax.add(total, arr_i)}

    init_val = {'arr': arr, 'total': 0.}
    out_val = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
    return out_val['total']
def sum_first_n(arr, num):
    def body_fun(i, state):
        arr, total, _ = state
        arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
        return (arr, lax.add(total, arr_i), ())

    init_val = (arr, 0., ())
    _, tot, _ = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
    return tot
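# Hedged comparison (added): the dict-carry and tuple-carry versions of
# sum_first_n above are interchangeable -- fori_loop accepts any pytree as the
# loop carry. Illustrative values only:
#   sum_first_n(jnp.arange(5.), 3)   # -> 3.0  (= 0. + 1. + 2.)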
def simulate(t, state, dt, space, t_h, parameters, goal, N_step, dx):
    # Equivalent Python loop:
    #   for i in range(N_step):
    #       t, state = step(t, state, dt, space, t_h, parameters)
    body_fun = lambda i, val: step(*val, dt, space, t_h, parameters)
    t, state = fori_loop(0, N_step, body_fun, (t, state))
    occ = occupation(state, goal, dx)
    occ = (occ * np.conj(occ)).real
    print(occ)
    return occ
def iterate_leapfrogs(theta, phi, eps, M, L, grad_fun):
    init_val = {"theta": theta, "phi": phi, "eps": eps, "M": M}
    to_iterate = lambda i, val: body_fun_leapfrog(i, val, grad_fun)
    final_res = fori_loop(0, L, to_iterate, init_val)
    return final_res["theta"], final_res["phi"]
def nloglik(X):
    y = jnp.concatenate([data.T, jnp.ones(shape=(1, N))], axis=0).T

    def body(i, ll):
        Si = jnp.outer(y[i], y[i])
        return ll + jnp.log(1 + jnp.trace(jnp.linalg.solve(X, Si)))

    llik = -(df + p) * 0.5 * fori_loop(0, N, body, 0.)
    return llik - 0.5 * N * jnp.linalg.slogdet(X)[1]
def eval_test(svi_state, batchifier_state, num_batch):
    def body_fn(i, loss_sum):
        batch = test_fetch(i, batchifier_state)
        loss = svi.evaluate(svi_state, *batch)
        loss_sum += loss / (args.num_samples * num_batch)
        return loss_sum

    return lax.fori_loop(0, num_batch, body_fn, 0.)
def epoch_train(svi_state, batchifier_state, num_batch):
    def body_fn(i, val):
        svi_state, loss = val
        batch = train_fetch(i, batchifier_state)
        svi_state, batch_loss = svi.update(svi_state, *batch)
        loss += batch_loss / (args.num_samples * num_batch)
        return svi_state, loss

    return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))
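# Hedged pattern sketch (added): the epoch loops above all thread an
# (accumulated loss, mutable state) pair through lax.fori_loop. A
# self-contained toy version with made-up names:
import jax.numpy as jnp
from jax import lax

def _toy_epoch(state, num_batch):
    def body_fn(i, val):
        state, loss = val
        state = state + 1                  # stand-in for svi.update / opt_update
        loss = loss + 1.0 / num_batch      # stand-in for a per-batch loss
        return state, loss
    return lax.fori_loop(0, num_batch, body_fn, (state, 0.))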
def make_distanceMatrix(points, idx, distance, n):
    distmatrix = jnp.zeros(shape=(n, n))

    def bodyfun(i, dists):
        j, k = idx[i]
        return dists.at[j, k].set(distance(points[j], points[k]))

    distmatrix = fori_loop(0, len(idx), bodyfun, distmatrix)
    return distmatrix
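# Hedged usage sketch (added): a tiny call to the helper above; `idx` must be
# an integer array of (row, col) pairs so that idx[i] is valid under a traced
# loop index. Illustrative values:
#   pts = jnp.array([[0., 0.], [3., 4.]])
#   pairs = jnp.array([[0, 1]])
#   euclid = lambda a, b: jnp.sqrt(jnp.sum((a - b) ** 2))
#   make_distanceMatrix(pts, pairs, euclid, 2)   # -> 5.0 at entry [0, 1]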
def test_mnist_data_load():
    def mean_pixels(i, mean_pix):
        batch, _ = fetch(i, idx)
        return mean_pix + np.sum(batch) / batch.size

    init, fetch = load_dataset(MNIST, batch_size=128, split='train')
    num_batches, idx = init()
    assert lax.fori_loop(0, num_batches, mean_pixels, np.float32(0.)) / num_batches < 0.15
def run_epoch(rng, opt_state):
    def body_fun(i, opt_state):
        elbo_rng, data_rng = random.split(random.fold_in(rng, i))
        batch = binarize_batch(data_rng, i, train_images)
        loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
        g = grad(loss)(get_params(opt_state))
        return opt_update(i, g, opt_state)

    return lax.fori_loop(0, num_batches, body_fun, opt_state)
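# Hedged sketch (added): the fold_in pattern above derives a fresh,
# deterministic PRNG key per iteration without threading a key through the
# loop carry. Self-contained toy version:
from jax import lax, random

def _sum_of_noises(rng_key, n):
    def body(i, acc):
        key_i = random.fold_in(rng_key, i)   # per-step key; rng_key is reused
        return acc + random.normal(key_i)
    return lax.fori_loop(0, n, body, 0.)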
def p_sample_loop(self, model_fn, *, shape, rng, num_timesteps=None,
                  return_x_init=False):
    """Ancestral sampling."""
    init_rng, body_rng = jax.random.split(rng)
    del rng
    noise_shape = shape + (self.num_pixel_vals,)

    def body_fun(i, x):
        t = jnp.full([shape[0]], self.num_timesteps - 1 - i)
        x, _ = self.p_sample(
            model_fn=model_fn, x=x, t=t,
            noise=jax.random.uniform(jax.random.fold_in(body_rng, i),
                                     shape=noise_shape))
        return x

    if self.transition_mat_type in ['gaussian', 'uniform']:
        # Stationary distribution is a uniform distribution over all pixel values.
        x_init = jax.random.randint(init_rng, shape=shape, minval=0,
                                    maxval=self.num_pixel_vals)
    elif self.transition_mat_type == 'absorbing':
        # Stationary distribution is a Kronecker delta distribution with all its
        # mass on the absorbing state, located at RGB values (128, 128, 128).
        x_init = jnp.full(shape=shape, fill_value=self.num_pixel_vals // 2,
                          dtype=jnp.int32)
    else:
        raise ValueError(
            f"transition_mat_type must be 'gaussian', 'uniform' or 'absorbing', "
            f"but is {self.transition_mat_type}")
    del init_rng

    if num_timesteps is None:
        num_timesteps = self.num_timesteps

    final_x = lax.fori_loop(lower=0, upper=num_timesteps,
                            body_fun=body_fun, init_val=x_init)
    assert final_x.shape == shape
    if return_x_init:
        return x_init, final_x
    else:
        return final_x
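# Hedged note (added): iterating i = 0..num_timesteps-1 while indexing
# t = num_timesteps - 1 - i, as p_sample_loop does above, runs the chain in
# reverse time without needing a negative-step loop. Minimal illustration:
def _reverse_time_indices(T=5):
    return lax.fori_loop(0, T, lambda i, ts: ts.at[i].set(T - 1 - i),
                         jnp.zeros(T, jnp.int32))   # -> [4, 3, 2, 1, 0]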
def _outer_fun(n, vals):
    vals["cur_total"] = jnp.zeros_like(vals["es"][0])
    vals["cur_n"] = n
    result = fori_loop(1, n + 1, _inner_fun, vals)
    vals["es"] = index_update(vals["es"], n, result["cur_total"] / n)
    return vals
def epoch_train(svi_state, rng_key):
    def body_fn(i, val):
        loss_sum, svi_state = val
        rng_key_binarize = random.fold_in(rng_key, i)
        batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
        svi_state, loss = svi.update(svi_state, batch)
        loss_sum += loss
        return loss_sum, svi_state

    return lax.fori_loop(0, num_train, body_fn, (0., svi_state))
def run_epoch(rng, opt_state):
    def body_fun(i, rng__opt_state__images):
        (rng, opt_state, images) = rng__opt_state__images
        rng, elbo_rng, data_rng = random.split(rng, 3)
        batch = binarize_batch(data_rng, i, images)
        loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
        g = grad(loss)(minmax.get_params(opt_state))
        return rng, opt_update(i, g, opt_state), images

    init_val = rng, opt_state, train_images
    _, opt_state, _ = lax.fori_loop(0, num_batches, body_fun, init_val)
    return opt_state
def backward_pass(x_trj, u_trj, regu, target):
    k_trj = np.empty_like(u_trj)
    K_trj = np.empty((TIME_STEPS - 1, N_U, N_X))
    expected_cost_redu = 0.
    V_x, V_xx = derivative_final(x_trj[-1], target)
    V_x, V_xx, k_trj, K_trj, x_trj, u_trj, expected_cost_redu, regu, target = lax.fori_loop(
        0, TIME_STEPS - 1, backward_pass_looper,
        [V_x, V_xx, k_trj, K_trj, x_trj, u_trj, expected_cost_redu, regu, target])
    return k_trj, K_trj, expected_cost_redu
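# Hedged sketch (added): fori_loop carries can be arbitrary pytrees, including
# the 9-element list used by backward_pass above. Toy two-element list carry:
def _list_carry_demo():
    body = lambda i, val: [val[0] + 1, val[1] * 2.]
    count, scale = lax.fori_loop(0, 3, body, [0, 1.])
    return count, scale   # (3, 8.0)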
def _sample(self, key, n_samps, factors, bits_to_fix=-1, values_to_fix=-1):
    """Generate samples from a distribution with a given set of factors.

    Parameters
    ----------
    key : jax.random.PRNGKey
        jax random number generator
    n_samps : int
        number of samples to generate
    factors : array_like
        factors of the distribution

    Returns
    -------
    array_like
        samples from the model
    """
    state = random.randint(key, minval=0, maxval=2, shape=(self.N,))
    unifs = random.uniform(key, shape=(n_samps * self.N,))
    all_states = np.zeros((n_samps, self.N))
    if bits_to_fix != -1:
        condition = True
        bits_to_keep = np.array([x for x in range(self.N) if x not in bits_to_fix])
        N = bits_to_keep.size
        values_to_fix = np.array(values_to_fix)
    else:
        condition = False
        bits_to_keep = np.arange(self.N)
        N = self.N

    @jit
    def run_mh(j, loop_carry):
        state, all_states = loop_carry
        if condition:
            state = index_update(state, bits_to_fix, values_to_fix)
        all_states = index_update(all_states, j // N, state)  # a bit wasteful
        state_flipped = index_update(state, bits_to_keep[j % N],
                                     1 - state[bits_to_keep[j % N]])
        dE = self.calc_e(factors, state_flipped) - self.calc_e(factors, state)
        accept = (dE < 0) | (unifs[j] < np.exp(-dE))
        state = np.where(accept, state_flipped, state)
        return state, all_states

    state, all_states = fori_loop(0, n_samps * N, run_mh, (state, all_states))
    return all_states
def f_op_jax():
    arr = jnp.zeros(5)

    def loop_body(i, acc_arr):
        arr1 = acc_arr.at[i].set(acc_arr[i] + 2.)
        # Legacy lax.cond signature:
        # (pred, true_operand, true_fun, false_operand, false_fun).
        return lax.cond(i % 2 == 0,
                        arr1, lambda arr1: arr1.at[i].set(arr1[i] + 1.),
                        arr1, lambda arr1: arr1)

    arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
    return arr
def f_op_jax():
    # Same function as above, written with the deprecated jax.ops update API.
    arr = jnp.zeros(5)

    def loop_body(i, acc_arr):
        arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
        return lax.cond(i % 2 == 0,
                        arr1, lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.),
                        arr1, lambda arr1: arr1)

    arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
    return arr
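# Hedged check (added): both f_op_jax variants above apply +2 to every element
# and an extra +1 at even indices, so on jnp.zeros(5) the expected result is
# [3., 2., 3., 2., 3.]:
#   assert jnp.allclose(f_op_jax(), jnp.array([3., 2., 3., 2., 3.]))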
def eval_test(svi_state: SVIState, rng_key: np.ndarray) -> jnp.ndarray:
    def body_fun(i: jnp.ndarray, loss_sum: jnp.ndarray) -> jnp.ndarray:
        rng_key_binarize = random.fold_in(rng_key, i)
        batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
        loss = svi.evaluate(svi_state, batch) / len(batch)
        loss_sum += loss
        return loss_sum

    loss = lax.fori_loop(0, num_test, body_fun, 0.0)
    loss = loss / num_test
    return loss
def eval_test(opt_state, rng):
    def body_fun(i, val):
        loss_sum, rng = val
        rng, = random.split(rng, 1)
        rng, batch = binarize(rng, test_fetch(i, test_idx)[0])
        loss = svi_eval(rng, opt_state, (batch,), (batch,)) / len(batch)
        loss_sum += loss
        return loss_sum, rng

    loss, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
    loss = loss / num_test
    return loss
def forward_pass(x_trj, u_trj, k_trj, K_trj):
    # arcsin(sin(u)) folds the controls into [-pi/2, pi/2].
    u_trj = np.arcsin(np.sin(u_trj))
    x_trj_new = np.empty_like(x_trj)
    x_trj_new = jax.ops.index_update(x_trj_new, jax.ops.index[0], x_trj[0])
    u_trj_new = np.empty_like(u_trj)
    x_trj, u_trj, k_trj, K_trj, x_trj_new, u_trj_new = lax.fori_loop(
        0, TIME_STEPS - 1, forward_pass_looper,
        [x_trj, u_trj, k_trj, K_trj, x_trj_new, u_trj_new])
    return x_trj_new, u_trj_new
def eval_test(svi_state, rng_key):
    def body_fun(i, loss_sum):
        rng_key_binarize = random.fold_in(rng_key, i)
        batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
        # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
        loss = svi.evaluate(svi_state, batch) / len(batch)
        loss_sum += loss
        return loss_sum

    loss = lax.fori_loop(0, num_test, body_fun, 0.)
    loss = loss / num_test
    return loss
def epoch_train(svi_state: SVIState,
                rng_key: np.ndarray) -> Tuple[jnp.ndarray, SVIState]:
    def body_fun(
        i: jnp.ndarray, val: Tuple[jnp.ndarray, SVIState]
    ) -> Tuple[jnp.ndarray, SVIState]:
        loss_sum, svi_state = val
        rng_key_binarize = random.fold_in(rng_key, i)
        batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
        svi_state, loss = svi.update(svi_state, batch)
        loss_sum += loss
        return loss_sum, svi_state

    return lax.fori_loop(0, num_train, body_fun, (0.0, svi_state))
def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
    """Apply the Threefry 2x32 hash.

    Args:
      key1, key2: the two 32-bit unsigned integer words of the key.
      x1, x2: uint32 arrays holding the two words of each count.
      use_rolled_loops: whether to run the rounds with a rolled fori_loop.

    Returns:
      A tuple of two uint32 arrays with the same shape as `x1` and `x2`.
    """
    x = [x1, x2]
    rotations = [
        np.array([13, 15, 26, 6], dtype=np.uint32),
        np.array([17, 29, 16, 24], dtype=np.uint32)
    ]
    ks = [key1, key2, key1 ^ key2 ^ np.uint32(0x1BD11BDA)]

    x[0] = x[0] + ks[0]
    x[1] = x[1] + ks[1]

    if use_rolled_loops:
        x, _, _ = lax.fori_loop(0, 5, rolled_loop_step,
                                (x, rotate_list(ks), rotations))
    else:
        for r in rotations[0]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[1]
        x[1] = x[1] + ks[2] + np.uint32(1)

        for r in rotations[1]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[2]
        x[1] = x[1] + ks[0] + np.uint32(2)

        for r in rotations[0]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[0]
        x[1] = x[1] + ks[1] + np.uint32(3)

        for r in rotations[1]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[1]
        x[1] = x[1] + ks[2] + np.uint32(4)

        for r in rotations[0]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[2]
        x[1] = x[1] + ks[0] + np.uint32(5)

    return tuple(x)
def test_rans_lax_fori_loop():
    size = 3
    tail_capacity = 100
    precision = 3
    n_data = 100
    # x ~ Categorical(1/8, 2/8, 3/8, 2/8)
    data = jnp.array(rng.integers(0, 4, size=(n_data, size)))
    m = m_init = rans.base_message(size, tail_capacity)

    choose = partial(jnp.choose, mode='clip')

    def enc_fun(x):
        return (choose(x, jnp.array([0, 1, 3, 6])),
                choose(x, jnp.array([1, 2, 3, 2])))

    def dec_fun(cf):
        return choose(cf, jnp.array([0, 1, 1, 2, 2, 2, 3, 3]))

    codec_push, codec_pop = rans.NonUniform(enc_fun, dec_fun, precision)
    _, freqs = enc_fun(data)

    # Encode
    def push_body(i, carry):
        m = carry
        m = codec_push(m, data[n_data - i - 1])
        return m

    m = lax.fori_loop(0, n_data, push_body, m)
    coded_arr = rans.flatten(m)
    assert coded_arr.dtype == np.uint8

    # Decode
    def pop_body(i, carry):
        m, xs = carry
        m, x = codec_pop(m)
        return m, lax.dynamic_update_index_in_dim(xs, x, i, 0)

    m = rans.unflatten(coded_arr, size, tail_capacity)
    m, data_decoded = lax.fori_loop(0, n_data, pop_body,
                                    (m, jnp.zeros((n_data, size), 'int32')))
    assert rans.message_equal(m, m_init)
def trace(state, fn, num_steps, **_):
    """Implementation of `trace` operator, without the calling convention."""
    # We need the shapes and dtypes of the outputs of `fn`.
    _, untraced_spec, traced_spec = jax.eval_shape(
        fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
    untraced_init = map_tree(lambda spec: np.zeros(spec.shape, spec.dtype),
                             untraced_spec)

    try:
        num_steps = int(num_steps)
        use_scan = True
    except TypeError:
        use_scan = False
        if flatten_tree(traced_spec):
            raise ValueError(
                'Cannot trace values when `num_steps` is not statically known. Pass '
                'False to `trace_mask` or return an empty structure (e.g. `()`) as '
                'the extra output.')

    if use_scan:
        def wrapper(state_untraced, _):
            state, _ = state_untraced
            state, untraced, traced = fn(state)
            return (state, untraced), traced

        (state, untraced), traced = lax.scan(
            wrapper,
            (state, untraced_init),
            xs=None,
            length=num_steps,
        )
    else:
        trace_arrays = map_tree(
            lambda spec: np.zeros((num_steps,) + spec.shape, spec.dtype),
            traced_spec)

        def wrapper(i, state_untraced_traced):
            state, _, trace_arrays = state_untraced_traced
            state, untraced, traced = fn(state)
            trace_arrays = map_tree(lambda a, e: jax.ops.index_update(a, i, e),
                                    trace_arrays, traced)
            return (state, untraced, trace_arrays)

        state, untraced, traced = lax.fori_loop(
            np.asarray(0, num_steps.dtype),
            num_steps,
            wrapper,
            (state, untraced_init, trace_arrays),
        )
    return state, untraced, traced
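# Hedged sketch (added): the static/dynamic split above reflects the general
# equivalence between lax.scan with a dummy xs and lax.fori_loop writing into a
# preallocated buffer. Self-contained toy version:
import jax.numpy as jnp
from jax import lax

def _trace_squares(n=4):
    # scan path (requires a static length)
    _, ys_scan = lax.scan(lambda c, _: (c + 1, c ** 2), 0, xs=None, length=n)

    # fori_loop path (writes each step into a preallocated buffer)
    def body(i, carry):
        c, ys = carry
        return c + 1, ys.at[i].set(c ** 2)

    _, ys_fori = lax.fori_loop(0, n, body, (0, jnp.zeros(n, jnp.int32)))
    return ys_scan, ys_fori   # both are [0, 1, 4, 9]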
def sample(model, key):
    def body(i, loop_carry):
        key, config = loop_carry
        out = model(config)
        probs = prob(out)
        key, subkey = random.split(key)
        sample = random.bernoulli(subkey, probs[:, i, 1]) * 2 - 1.0
        sample = sample[..., jnp.newaxis]
        config = jax.ops.index_update(config, jax.ops.index[:, i], sample)
        return key, config

    key, config = fori_loop(0, init_config.shape[1], body, (key, init_config))
    return key, config
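# Hedged note (added): `random.bernoulli(...) * 2 - 1.0` maps {False, True} to
# spins {-1., +1.}, which is what the autoregressive update above writes into
# column i of config:
#   jnp.array([False, True]) * 2 - 1.0   # -> [-1., 1.]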
def test_nve_neighbor_list(self, spatial_dimension, dtype):
    Nx = particles_per_side = 8
    spacing = f32(1.25)

    tol = 5e-12 if dtype == np.float64 else 5e-3

    L = Nx * spacing
    if spatial_dimension == 2:
        R = np.stack([np.array(r) for r in onp.ndindex(Nx, Nx)]) * spacing
    elif spatial_dimension == 3:
        R = np.stack([np.array(r) for r in onp.ndindex(Nx, Nx, Nx)]) * spacing
    R = np.array(R, dtype)

    displacement, shift = space.periodic(L)

    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement, L)
    exact_energy_fn = energy.lennard_jones_pair(displacement)

    init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
    exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)

    nbrs = neighbor_fn(R)
    state = init_fn(random.PRNGKey(0), R, neighbor=nbrs)
    exact_state = exact_init_fn(random.PRNGKey(0), R)

    def body_fn(i, state):
        state, nbrs, exact_state = state
        nbrs = neighbor_fn(state.position, nbrs)
        state = apply_fn(state, neighbor=nbrs)
        return state, nbrs, exact_apply_fn(exact_state)

    step = 0
    for i in range(20):
        new_state, nbrs, new_exact_state = lax.fori_loop(
            0, 100, body_fn, (state, nbrs, exact_state))
        if nbrs.did_buffer_overflow:
            nbrs = neighbor_fn(state.position)
        else:
            state = new_state
            exact_state = new_exact_state
            step += 1

    assert state.position.dtype == dtype
    self.assertAllClose(state.position, exact_state.position, atol=tol, rtol=tol)