def test_stochastic_rngs(self):
    """make_rng() should fold an incrementing counter into the base key."""
    base_key = random.PRNGKey(0)
    with nn.stochastic(base_key):
        first = nn.make_rng()
        second = nn.make_rng()
        self.assertTrue(onp.all(first == random.fold_in(base_key, 1)))
        self.assertTrue(onp.all(second == random.fold_in(base_key, 2)))
def _read_keys(key, x1, x2):
    """Read dropout key.

    `key` might be a tuple of two rng keys or a single rng key or None. In
    either case, `key` will be mapped into two rng keys `key1` and `key2` to
    make sure `(x1==x2) == (key1==key2)`.
    """
    if key is None or x2 is None:
        # No key (or no second input): both halves share the same value.
        return key, key
    if isinstance(key, tuple) and len(key) == 2:
        key1, key2 = key
        same_inputs = utils.x1_is_x2(x1, x2)
        # If the two provided keys coincide, derive a distinct fallback so
        # that distinct inputs get distinct keys.
        fallback = np.where(utils.x1_is_x2(key1, key2),
                            random.fold_in(key2, 1), key2)
        key2 = np.where(same_inputs, key1, fallback)
        warnings.warn(
            'The value of `key[1]` might be replaced by a new value if '
            'key[0] == key[1] and x1 != x2 or key[0] != key[1] and '
            'x1 == x2.')
        return key1, key2
    if isinstance(key, np.ndarray):
        # Single key: reuse it for key2 only when the inputs are identical.
        key2 = np.where(utils.x1_is_x2(x1, x2), key, random.fold_in(key, 1))
        return key, key2
    raise TypeError(type(key))
def test_init_by_shape_lifts_stochastic(self):
    """init_by_shape should open a nested rng scope inside nn.stochastic."""

    class StochasticModule(nn.Module):

        def apply(self):
            return nn.make_rng()

    with nn.stochastic(random.PRNGKey(0)):
        rng, _ = StochasticModule.init_by_shape(random.PRNGKey(1), [])
        # Two fold_in levels: one for the stochastic scope, one for the lift.
        expected_rng = random.fold_in(random.fold_in(random.PRNGKey(0), 1), 1)
        self.assertTrue(onp.all(rng == expected_rng))
def run_update(batch_idx, opt_state_n_keys):
    """Run one optimization step, folding batch_idx into both rng keys."""
    opt_state, (dkey, fkey) = opt_state_n_keys
    # Derive per-batch keys so data sampling and fitting stay decorrelated.
    dkey = random.fold_in(dkey, batch_idx)
    fkey = random.fold_in(fkey, batch_idx)
    kl_warmup = kl_warmup_fun(batch_idx)
    x_bxt = train_data_fun(dkey).astype(np.float32)
    new_opt_state = update_fun(batch_idx, opt_state, hps, opt_hps, fkey,
                               x_bxt, kl_warmup)
    return (new_opt_state, (dkey, fkey))
def run_update(batch_idx, opt_state_n_keys):
    """Run the optimization one time."""
    opt_state, (dkey, fkey) = opt_state_n_keys
    # Fold the batch index into both keys for fresh per-step randomness.
    dkey = random.fold_in(dkey, batch_idx)
    fkey = random.fold_in(fkey, batch_idx)
    kl_warmup = kl_warmup_fun(batch_idx)
    x_bxt, class_id_b = train_data_fun(dkey)
    new_opt_state = update_fun(batch_idx, opt_state, hps, opt_hps, fkey,
                               x_bxt, class_id_b, kl_warmup)
    return (new_opt_state, (dkey, fkey))
def HMC2(U, grad_U, epsilon, L, current_q, rng): q = current_q # random flick - p is momentum p = dist.Normal(0, 1).sample(random.fold_in(rng, 0), (q.shape[0], )) current_p = p # Make a half step for momentum at the beginning p = p - epsilon * grad_U(q) / 2 # initialize bookkeeping - saves trajectory qtraj = jnp.full((L + 1, q.shape[0]), jnp.nan) ptraj = qtraj qtraj = ops.index_update(qtraj, 0, current_q) ptraj = ops.index_update(ptraj, 0, p) # Alternate full steps for position and momentum for i in range(L): q = q + epsilon * p # Full step for the position # Make a full step for the momentum, except at end of trajectory if i != (L - 1): p = p - epsilon * grad_U(q) ptraj = ops.index_update(ptraj, i + 1, p) qtraj = ops.index_update(qtraj, i + 1, q) # Make a half step for momentum at the end p = p - epsilon * grad_U(q) / 2 ptraj = ops.index_update(ptraj, L, p) # Negate momentum at end of trajectory to make the proposal symmetric p = -p # Evaluate potential and kinetic energies at start and end of trajectory current_U = U(current_q) current_K = jnp.sum(current_p**2) / 2 proposed_U = U(q) proposed_K = jnp.sum(p**2) / 2 # Accept or reject the state at end of trajectory, returning either # the position at the end of the trajectory or the initial position accept = 0 runif = dist.Uniform().sample(random.fold_in(rng, 1)) if runif < jnp.exp(current_U - proposed_U + current_K - proposed_K): new_q = q # accept accept = 1 else: new_q = current_q # reject return { "q": new_q, "traj": qtraj, "ptraj": ptraj, "accept": accept, "dH": proposed_U + proposed_K - (current_U + current_K), }
def sanitize_seed(seed, salt=None):
  """Map various types to a seed `Tensor`."""
  if callable(seed):  # e.g. SeedStream.
    seed = seed()
  if salt is not None and not isinstance(salt, str):
    raise TypeError('`salt` must be a python `str`, got {}'.format(repr(salt)))
  # `None` / int seeds are stateful-style seeds; only valid in TF mode.
  if seed is None or isinstance(seed, six.integer_types):
    if JAX_MODE:
      raise ValueError('TFP-on-JAX requires a `jax.random.PRNGKey` `seed` arg.')
    # TODO(b/147874898): Do we deprecate `int` seeds, migrate ints to stateless?
    if salt is not None:
      # Prefer to incorporate salt as a constant.
      if seed is not None:
        # Hash (seed, salt) into a bounded int so the salt is folded into the
        # stateful seed deterministically.
        seed = int(hashlib.sha512(
            str((seed, salt)).encode('utf-8')).hexdigest(), 16) % (2**31 - 1)
      salt = None
    # Convert "stateful-indicating" `int`/`None` seed to stateless Tensor seed,
    # by way of a stateful sampler.
    seed = tf.random.uniform([2], seed=seed,
                             minval=np.iinfo(SEED_DTYPE).min,
                             maxval=np.iinfo(SEED_DTYPE).max,
                             dtype=SEED_DTYPE, name='seed')
  # If salt survives to here, the seed was already a stateless Tensor/key;
  # fold the hashed salt into it.
  if salt is not None:
    salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16)
    if JAX_MODE:
      from jax import random as jaxrand  # pylint: disable=g-import-not-at-top
      seed = jaxrand.fold_in(seed, salt & (2**32 - 1))
    else:
      # XOR the salt (viewed as two int32s) into the two-int32 seed tensor.
      seed = tf.bitwise.bitwise_xor(
          seed, np.uint64([salt & (2**64 - 1)]).view(np.int32))
  return tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed')
def private_update(rng, i, opt_state, batch):
    """One DP-SGD step: derive a per-step key, apply clipped noisy gradient."""
    params = get_params(opt_state)
    step_rng = random.fold_in(rng, i)  # get new key for new random numbers
    grads = grad_fn(model, loss, params, batch, step_rng,
                    args.l2_norm_clip, args.noise_multiplier, args.batch_size)
    return opt_update(i, grads, opt_state)
def body_fn(i, val):
    """fori_loop body: run one SVI update on a freshly binarized batch."""
    svi_state, loss = val
    # Per-iteration key so each batch gets independent binarization noise.
    batch = train_fetch(i, batchifier_state, random.fold_in(rng, i))[0]
    svi_state, batch_loss = svi.update(svi_state, batch)
    return svi_state, loss + batch_loss / num_batches
def test_graph_network_neighbor_list_moving(self, spatial_dimension, dtype,
                                            format):
    """Neighbor-list GNN energy must match the full energy after positions
    move (while remaining within the dr_threshold skin)."""
    if format is partition.OrderedSparse:
        self.skipTest('OrderedSparse format incompatible with GNN '
                      'force field.')
    key = random.PRNGKey(0)
    R = random.uniform(key, (32, spatial_dimension), dtype=dtype)
    displacement, _ = space.free()
    cutoff = 0.3
    dr_threshold = 0.1
    init_fn, energy_fn = energy.graph_network(displacement, cutoff)
    params = init_fn(key, R)
    neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
        displacement, 1.0, cutoff, dr_threshold, format=format)
    nbrs = neighbor_fn.allocate(R)
    # Perturb positions with a fresh key; moves are small enough that the
    # allocated neighbor list stays valid.
    key = random.fold_in(key, 1)
    R = R + random.uniform(key, (32, spatial_dimension),
                           minval=-0.05, maxval=0.05, dtype=dtype)
    if format is partition.Dense:
        self.assertAllClose(energy_fn(params, R),
                            nl_energy_fn(params, R, nbrs))
    else:
        self.assertAllClose(energy_fn(params, R),
                            nl_energy_fn(params, R, nbrs),
                            rtol=2e-4, atol=2e-4)
def testPRNGValues(self):
    # Test to ensure consistent random values between JAX versions
    k = random.PRNGKey(0)
    if config.x64_enabled:
        expected = np.array([[7, 2, 6], [2, 1, 0], [6, 7, 7]], dtype='int64')
    else:
        expected = np.array([[2, 1, 3], [6, 1, 5], [6, 3, 4]], dtype='int32')
    self.assertAllClose(random.randint(k, (3, 3), 0, 8), expected)
    self.assertAllClose(
        random.split(k, 4),
        np.array([[2285895361, 1501764800],
                  [1518642379, 4090693311],
                  [433833334, 4221794875],
                  [839183663, 3740430601]], dtype='uint32'))
    self.assertAllClose(
        random.fold_in(k, 4),
        np.array([2285895361, 433833334], dtype='uint32'))
def step(loop_state: LoopState):
    """Advance the DDPG training loop by one environment step.

    Returns a new LoopState carrying the updated optimizer, tracking params,
    cumulative rewards, replay buffer, and exploration-noise state.
    """
    t = loop_state.episode_length
    step_out = ddpg_step(
        # Per-timestep key derived from the episode's base key.
        random.fold_in(rng_steps, t),
        loop_state.optimizer,
        loop_state.tracking_params,
        env,
        gamma,
        tau,
        loop_state.replay_buffer,
        batch_size,
        actor,
        critic,
        loop_state.state,
        noise(t, loop_state.prev_noise),
        lambda s: terminal_criterion(t, s),
    )
    # Accumulate both the discounted and undiscounted returns.
    new_discounted_cumulative_reward = loop_state.discounted_cumulative_reward + (
        gamma**t) * step_out.reward
    new_undiscounted_cumulative_reward = (loop_state.undiscounted_cumulative_reward +
                                          step_out.reward)
    return LoopState(
        episode_length=loop_state.episode_length + 1,
        optimizer=step_out.optimizer,
        tracking_params=step_out.tracking_params,
        discounted_cumulative_reward=new_discounted_cumulative_reward,
        undiscounted_cumulative_reward=new_undiscounted_cumulative_reward,
        state=step_out.next_state,
        replay_buffer=step_out.replay_buffer,
        prev_noise=step_out.action_noise,
        done=step_out.done,
    )
def make_rng(self, name: str) -> PRNGKey:
    """Generate A PRNGKey from a PRNGSequence."""
    assert self.has_rng(name), f'Need PRNG for "{name}"'
    self._check_valid()
    self._validate_trace_level()
    # Bump the per-stream counter so every call yields a distinct key.
    count = self.rng_counters[name] + 1
    self.rng_counters[name] = count
    return random.fold_in(self.rngs[name], count)
def body_fn(i, val):
    """fori_loop body: binarize the i-th batch and take one SVI step."""
    loss_sum, svi_state = val
    batch = binarize(random.fold_in(rng_key, i),
                     train_fetch(i, train_idx)[0])
    svi_state, loss = svi.update(svi_state, batch)
    return loss_sum + loss, svi_state
def private_update(rng, i, opt_state, batch):
    """One DP-SGD step with a per-iteration rng key."""
    params = get_params(opt_state)
    step_rng = random.fold_in(rng, i)  # get new key for new random numbers
    grads = private_grad(params, batch, step_rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size)
    return opt_update(i, grads, opt_state)
def body_fun(i, loss_sum):
    """fori_loop body: accumulate per-example eval loss on batch i."""
    batch = binarize(random.fold_in(rng_key, i),
                     test_fetch(i, test_idx)[0])
    # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
    return loss_sum + svi.evaluate(svi_state, batch) / len(batch)
def body_fun(i, args):
    """One VAE training step with a control-variate-corrected encoder grad.

    Carries (opt_state, cv_history) through the loop; cv_history is an
    exponentially weighted moving average of the control-variate coefficient.
    """
    opt_state, cv_history = args
    params = get_params(opt_state)
    # Three independent keys per step: ELBO sampling, data, and CV coeff.
    elbo_rng, data_rng, cv_rng = random.split(random.fold_in(
        rng, i), num=3)
    batch = fetch_batch(i, train_images)
    encoder_grad = grad(objective_fn)(*params, batch, elbo_rng, num_samples)
    decoder_grad = grad(elbo_fn, argnums=1)(*params, batch, elbo_rng, num_samples)
    control_variate = kwargs['support_fn'](params[0], batch, elbo_rng, num_samples)
    # Old computation of CV coeff as exponentially weighted moving average
    cv_coeff = compute_cv_coeff(encoder_grad, control_variate)
    cv_history = update_cv_coeff(cv_history, cv_coeff)
    encoder_grad = jax.tree_multimap(lambda x, y, z: x - y * z, encoder_grad, cv_history, control_variate)
    # coeff_encoder_grad = grad(objective_fn)(*params, batch, cv_rng, num_samples)
    # coeff_control_variate = kwargs['support_fn'](params[0], batch, cv_rng, num_samples)
    # cv_coeff = compute_cv_coeff(coeff_encoder_grad, coeff_control_variate)
    # NOTE(review): the control variate is subtracted a second time here with
    # cv_coeff (after already subtracting with cv_history above) — confirm
    # this double correction is intentional and not a leftover from the
    # commented-out experiment above.
    encoder_grad = jax.tree_multimap(lambda x, y, z: x - y * z, encoder_grad, cv_coeff, control_variate)
    g = (encoder_grad, decoder_grad)
    return opt_update(i, g, opt_state), cv_history
def generate_sample_images(opt_state):
    """Sample a grid of images from the current decoder parameters."""
    params = get_params(opt_state)
    # Fixed fold_in constant gives a stable sampling key across calls.
    sample_rng = random.fold_in(test_rng, 2)
    return two_layer_image_sample(sample_rng, (decoder, params[1]), nrow, ncol)
def step(opt_state, it):
    """One optimization step over bijector and dequantization parameters."""
    rng_it = random.fold_in(rng, it)
    bij_params, deq_params = get_params(opt_state)
    loss_val, loss_grad = value_and_grad(loss, (1, 3))(
        rng_it, bij_params, bij_fns, deq_params, deq_fn, num_samples)
    # Guard the update against NaNs and exploding gradients.
    loss_grad = tree_util.tree_map(
        partial(put.clip_and_zero_nans, clip_value=1.), loss_grad)
    return opt_update(it, loss_grad, opt_state), loss_val
def run_trials(batched_run_fun, inputs_targets_h0s_fun, nbatches, batch_size):
    """Run a bunch of trials and save everything in a dictionary."""
    inputs, hiddens, outputs, targets, h0s = [], [], [], [], []
    key = random.PRNGKey(onp.random.randint(0, MAX_SEED_INT))
    for batch_idx in range(nbatches):
        # Fold the batch index into the key, then split one subkey per trial.
        key = random.fold_in(key, batch_idx)
        subkeys = random.split(key, batch_size)
        input_b, target_b, h0s_b = inputs_targets_h0s_fun(subkeys)
        if h0s_b is None:
            h_b, o_b = batched_run_fun(input_b)
        else:
            h_b, o_b = batched_run_fun(input_b, h0s_b)
        h0s.append(h0s_b)
        inputs.append(input_b)
        hiddens.append(h_b)
        outputs.append(o_b)
        targets.append(target_b)
    trial_dict = {
        'inputs': onp.vstack(inputs),
        'hiddens': onp.vstack(hiddens),
        'outputs': onp.vstack(outputs),
        'targets': onp.vstack(targets),
    }
    # The last batch's h0s_b decides whether initial states were provided at
    # all (assumed consistent across batches).
    trial_dict['h0s'] = onp.vstack(h0s) if h0s_b is not None else None
    return trial_dict
def callback(info):
    """Per-episode training callback.

    Args:
      info: dict with keys 'episode', 'reward', 'optimizer', and 'elapsed'.

    Side effects: prints a progress line, appends to the module-level
    train_reward_per_episode / policy_value_per_episode lists, and on the
    final episode renders five evaluation rollouts.
    """
    episode = info['episode']
    reward = info['reward']
    current_actor_params, _ = info["optimizer"].value
    policy_value = eval_policy(callback_rngs[episode], current_actor_params)
    print(f"Episode {episode}, "
          f"train reward = {reward}, "
          f"policy value = {policy_value}, "
          f"elapsed = {info['elapsed']}")
    train_reward_per_episode.append(reward)
    policy_value_per_episode.append(policy_value)
    if episode == num_episodes - 1:
        # if episode % 500 == 0 or episode == num_episodes - 1:
        # Bug fix: the loop variable was previously named `rollout`, which
        # shadowed the `rollout` function and raised a TypeError on the
        # first iteration (calling an int). Use a distinct index name.
        for rollout_idx in range(5):
            states, actions, _ = rollout(
                random.fold_in(callback_rngs[episode], rollout_idx),
                config.env,
                policy(current_actor_params),
                num_timesteps=250,
            )
            viz_pendulum_rollout(states, 2 * actions / config.max_torque)
def _fold_in_str(rng, data): """Folds a string into a jax.random.PRNGKey using its SHA-1 hash.""" m = hashlib.sha1() m.update(data.encode('utf-8')) d = m.digest() hash_int = int.from_bytes(d[:4], byteorder='big', signed=True) return random.fold_in(rng, hash_int)
def make_rng(self, name: str) -> PRNGKey:
    """Generates A PRNGKey from a PRNGSequence with name `name`."""
    if not self.has_rng(name):
        raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
    self._check_valid()
    self._validate_trace_level()
    # Advance the per-stream counter so successive calls give fresh keys.
    count = self.rng_counters[name] + 1
    self.rng_counters[name] = count
    return random.fold_in(self.rngs[name], count)
def step(opt_state, it):
    """Single optimization step over (omega, params)."""
    rng_it = random.fold_in(rng, it)
    omega, params = get_params(opt_state)
    loss_val, loss_grad = value_and_grad(loss, (1, 2))(
        rng_it, omega, params, fn, num_samples)
    new_state = opt_update(it, loss_grad, opt_state)
    return new_state, loss_val
def body_fun(
    i: jnp.ndarray, val: Tuple[jnp.ndarray, SVIState]
) -> Tuple[jnp.ndarray, SVIState]:
    """fori_loop body: binarize batch i and apply one SVI update."""
    loss_sum, svi_state = val
    batch = binarize(random.fold_in(rng_key, i),
                     train_fetch(i, train_idx)[0])
    svi_state, loss = svi.update(svi_state, batch)
    return loss_sum + loss, svi_state
def hitObjCase(_):
    """Ray hit a surface: gather direct lighting and sample a bounce ray."""
    incidentRay = raymarchResult.ray
    osurf = raymarchResult.surf
    # Independent keys for light sampling and reflection sampling.
    k1, k2 = random.split(random.fold_in(k, i))
    lightRadiance = sampleLightRadiance(scene, osurf, incidentRay, k1)
    outRayHemisphere = sampleReflection(osurf, incidentRay, k2)
    newFilter = surfaceFilter(color_filter, osurf[1])
    newRadiance = total_radiance + applyFilter(newFilter, lightRadiance)
    return (False, newFilter, newRadiance, outRayHemisphere)
def make_plot(step=0.03, L=11):
    """Draw a 2-D HMC trajectory figure (n_samples transitions) and save it.

    Args:
      step: leapfrog step size passed to HMC2.
      L: number of leapfrog steps per transition.
    """
    Q = {}
    Q["q"] = jnp.array([-0.1, 0.2])  # starting position
    pr = 0.4  #0.31  # plot half-range
    plt.figure()
    plt.subplot(ylabel=r"$\mu_y$", xlabel=r"$\mu_x$", xlim=(-pr, pr), ylim=(-pr, pr))
    n_samples = 4
    path_col = (0, 0, 0, 0.5)
    # Background contour circles of the target.
    for r in 0.075 * jnp.arange(2, 6):
        plt.gca().add_artist(plt.Circle((0, 0), r, alpha=0.2, fill=False))
    plt.scatter(Q["q"][0], Q["q"][1], c="k", marker="x", zorder=4)
    for i in range(n_samples):
        # Deterministic per-sample key derived from a fixed base key.
        Q = HMC2(U, U_gradient, step, L, Q["q"], random.fold_in(random.PRNGKey(0), i))
        if n_samples < 10:
            # Draw leapfrog segments; line width scales with kinetic energy.
            for j in range(L):
                K0 = jnp.sum(Q["ptraj"][j]**2) / 2
                plt.plot(
                    Q["traj"][j:j + 2, 0],
                    Q["traj"][j:j + 2, 1],
                    c=path_col,
                    lw=1 + 2 * K0,
                )
            plt.scatter(Q["traj"][:, 0], Q["traj"][:, 1], c="white", s=5, zorder=3)
            # for fancy arrows
            dx = Q["traj"][L, 0] - Q["traj"][L - 1, 0]
            dy = Q["traj"][L, 1] - Q["traj"][L - 1, 1]
            d = jnp.sqrt(dx**2 + dy**2)  # NOTE(review): d appears unused
            plt.annotate(
                "",
                (Q["traj"][L - 1, 0], Q["traj"][L - 1, 1]),
                (Q["traj"][L, 0], Q["traj"][L, 1]),
                arrowprops={"arrowstyle": "<-"},
            )
            plt.annotate(
                str(i + 1),
                (Q["traj"][L, 0], Q["traj"][L, 1]),
                xytext=(3, 3),
                textcoords="offset points",
            )
        # NOTE(review): traj has rows 0..L, so [L + 1] is out of range and
        # presumably relies on JAX clipping it to the last row — confirm.
        plt.scatter(
            Q["traj"][L + 1, 0],
            Q["traj"][L + 1, 1],
            c=("red" if jnp.abs(Q["dH"]) > 0.1 else "black"),
            zorder=4,
        )
    #plt.axis('square')
    plt.title(f'L={L}')
    plt.savefig(f'../figures/hmc2d_L{L}.pdf', dpi=300)
def step(opt_state, it):
    """One optimizer step over (thetax, thetay, thetad, paramsm)."""
    rng_it = random.fold_in(rng, it)
    thetax, thetay, thetad, paramsm = get_params(opt_state)
    loss_val, loss_grad = value_and_grad(loss, (1, 2, 3, 4))(
        rng_it, thetax, thetay, thetad, paramsm, netm, num_samples)
    new_state = opt_update(it, loss_grad, opt_state)
    return new_state, loss_val
def make_rng(self):
    """Return a fresh PRNGKey derived from the base rng and a counter."""
    # When make_rng is called inside a jax transformation (jit, vmap,
    # scan, ...) the rng could be implicitly reused; raise an explicit
    # error instead of allowing a silent one.
    level = utils._trace_level(utils._current_trace())
    if level > self.level:
        raise ValueError('stochastic operations are not allowed when the'
                         ' stochastic context is created outside of the'
                         ' current Jax transformation')
    self.counter += 1
    return random.fold_in(self.base_rng, self.counter)
def fold_in(seed, salt):
  """Folds salt into seed to form a new seed."""
  if JAX_MODE:
    from jax import random as jaxrand  # pylint: disable=g-import-not-at-top
    return jaxrand.fold_in(seed, salt & (2**32 - 1))
  if isinstance(salt, six.integer_types):
    # Integer salt: XOR it (viewed as two int32s) into the seed tensor.
    return tf.bitwise.bitwise_xor(
        seed, np.uint64([salt & (2**64 - 1)]).view(np.int32))
  # Non-integer salt: let TF's stateless fold_in handle it.
  return tf.random.experimental.stateless_fold_in(seed, salt)