Example #1
 def test_stochastic_rngs(self):
   rng = random.PRNGKey(0)
   with nn.stochastic(rng):
     r1 = nn.make_rng()
     r2 = nn.make_rng()
   self.assertTrue(onp.all(r1 == random.fold_in(rng, 1)))
   self.assertTrue(onp.all(r2 == random.fold_in(rng, 2)))
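For reference, the counter-based derivation that this test checks can be written in plain JAX without the nn.stochastic context. The snippet below is an illustrative sketch; the variable names are not from the original example:

from jax import random

base = random.PRNGKey(0)
keys = [random.fold_in(base, i) for i in (1, 2)]   # one derived key per make_rng() call
samples = [random.normal(k, (3,)) for k in keys]   # independent draws from each key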
Example #2
def _read_keys(key, x1, x2):
    """Read dropout key.

     `key` might be a tuple of two rng keys or a single rng key or None. In
     either case, `key` will be mapped into two rng keys `key1` and `key2` to
     make sure `(x1==x2) == (key1==key2)`.
  """

    if key is None or x2 is None:
        key1 = key2 = key
    elif isinstance(key, tuple) and len(key) == 2:
        key1, key2 = key
        new_key = np.where(utils.x1_is_x2(key1, key2), random.fold_in(key2, 1),
                           key2)
        key2 = np.where(utils.x1_is_x2(x1, x2), key1, new_key)
        warnings.warn(
            'The value of `key[1]` might be replaced by a new value if '
            'key[0] == key[1] and x1 != x2 or key[0] != key[1] and '
            'x1 == x2.')
    elif isinstance(key, np.ndarray):
        key1 = key
        key2 = np.where(utils.x1_is_x2(x1, x2), key1, random.fold_in(key, 1))
    else:
        raise TypeError(type(key))
    return key1, key2
Example #3
  def test_init_by_shape_lifts_stochastic(self):
    class StochasticModule(nn.Module):
      def apply(self):
        return nn.make_rng()

    with nn.stochastic(random.PRNGKey(0)):
      rng, _ = StochasticModule.init_by_shape(random.PRNGKey(1), [])
    expected_rng = random.fold_in(random.PRNGKey(0), 1)
    expected_rng = random.fold_in(expected_rng, 1)
    self.assertTrue(onp.all(rng == expected_rng))
Example #4
 def run_update(batch_idx, opt_state_n_keys):
     opt_state, keys = opt_state_n_keys
     dkey, fkey = keys
     dkey = random.fold_in(dkey, batch_idx)
     fkey = random.fold_in(fkey, batch_idx)
     kl_warmup = kl_warmup_fun(batch_idx)
     x_bxt = train_data_fun(dkey).astype(np.float32)
     opt_state = update_fun(batch_idx, opt_state, hps, opt_hps, fkey, x_bxt,
                            kl_warmup)
     opt_state_n_keys = (opt_state, (dkey, fkey))
     return opt_state_n_keys
Example #5
 def run_update(batch_idx, opt_state_n_keys):
     """Run the optimization one time."""
     opt_state, keys = opt_state_n_keys
     dkey, fkey = keys
     dkey = random.fold_in(dkey, batch_idx)
     fkey = random.fold_in(fkey, batch_idx)
     kl_warmup = kl_warmup_fun(batch_idx)
     x_bxt, class_id_b = train_data_fun(dkey)
     opt_state = update_fun(batch_idx, opt_state, hps, opt_hps, fkey, x_bxt,
                            class_id_b, kl_warmup)
     opt_state_n_keys = (opt_state, (dkey, fkey))
     return opt_state_n_keys
Example #6
def HMC2(U, grad_U, epsilon, L, current_q, rng):
    q = current_q
    # random flick - p is momentum
    p = dist.Normal(0, 1).sample(random.fold_in(rng, 0), (q.shape[0], ))
    current_p = p
    # Make a half step for momentum at the beginning
    p = p - epsilon * grad_U(q) / 2
    # initialize bookkeeping - saves trajectory
    qtraj = jnp.full((L + 1, q.shape[0]), jnp.nan)
    ptraj = qtraj
    qtraj = ops.index_update(qtraj, 0, current_q)
    ptraj = ops.index_update(ptraj, 0, p)

    # Alternate full steps for position and momentum
    for i in range(L):
        q = q + epsilon * p  # Full step for the position
        # Make a full step for the momentum, except at end of trajectory
        if i != (L - 1):
            p = p - epsilon * grad_U(q)
            ptraj = ops.index_update(ptraj, i + 1, p)
        qtraj = ops.index_update(qtraj, i + 1, q)

    # Make a half step for momentum at the end
    p = p - epsilon * grad_U(q) / 2
    ptraj = ops.index_update(ptraj, L, p)
    # Negate momentum at end of trajectory to make the proposal symmetric
    p = -p
    # Evaluate potential and kinetic energies at start and end of trajectory
    current_U = U(current_q)
    current_K = jnp.sum(current_p**2) / 2
    proposed_U = U(q)
    proposed_K = jnp.sum(p**2) / 2
    # Accept or reject the state at end of trajectory, returning either
    # the position at the end of the trajectory or the initial position
    accept = 0
    runif = dist.Uniform().sample(random.fold_in(rng, 1))
    if runif < jnp.exp(current_U - proposed_U + current_K - proposed_K):
        new_q = q  # accept
        accept = 1
    else:
        new_q = current_q  # reject
    return {
        "q": new_q,
        "traj": qtraj,
        "ptraj": ptraj,
        "accept": accept,
        "dH": proposed_U + proposed_K - (current_U + current_K),
    }
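Note that jax.ops.index_update used above has been removed from recent JAX releases; when adapting this example, the equivalent functional update is the .at[...].set(...) syntax. A minimal sketch, not part of the original code:

import jax.numpy as jnp

qtraj = jnp.full((3, 2), jnp.nan)
qtraj = qtraj.at[0].set(jnp.array([0.1, -0.2]))   # same effect as ops.index_update(qtraj, 0, ...)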
Example #7
def sanitize_seed(seed, salt=None):
  """Map various types to a seed `Tensor`."""
  if callable(seed):  # e.g. SeedStream.
    seed = seed()
  if salt is not None and not isinstance(salt, str):
    raise TypeError('`salt` must be a python `str`, got {}'.format(repr(salt)))
  if seed is None or isinstance(seed, six.integer_types):
    if JAX_MODE:
      raise ValueError('TFP-on-JAX requires a `jax.random.PRNGKey` `seed` arg.')
    # TODO(b/147874898): Do we deprecate `int` seeds, migrate ints to stateless?
    if salt is not None:
      # Prefer to incorporate salt as a constant.
      if seed is not None:
        seed = int(hashlib.sha512(
            str((seed, salt)).encode('utf-8')).hexdigest(), 16) % (2**31 - 1)
      salt = None
    # Convert "stateful-indicating" `int`/`None` seed to stateless Tensor seed,
    # by way of a stateful sampler.
    seed = tf.random.uniform([2], seed=seed, minval=np.iinfo(SEED_DTYPE).min,
                             maxval=np.iinfo(SEED_DTYPE).max, dtype=SEED_DTYPE,
                             name='seed')
  if salt is not None:
    salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16)
    if JAX_MODE:
      from jax import random as jaxrand  # pylint: disable=g-import-not-at-top
      seed = jaxrand.fold_in(seed, salt & (2**32 - 1))
    else:
      seed = tf.bitwise.bitwise_xor(
          seed, np.uint64([salt & (2**64 - 1)]).view(np.int32))
  return tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed')
Example #8
 def private_update(rng, i, opt_state, batch):
     params = get_params(opt_state)
     rng = random.fold_in(rng, i)  # get new key for new random numbers
     return opt_update(
         i,
         grad_fn(model, loss, params, batch, rng, args.l2_norm_clip,
                 args.noise_multiplier, args.batch_size), opt_state)
Example #9
 def body_fn(i, val):
     svi_state, loss = val
     binarize_rng = random.fold_in(rng, i)
     batch = train_fetch(i, batchifier_state, binarize_rng)[0]
     svi_state, batch_loss = svi.update(svi_state, batch)
     loss += batch_loss / num_batches
     return svi_state, loss
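The pattern above (folding the loop index into a fixed base key) also works inside jax.lax.fori_loop, where reusing a single key across iterations would silently repeat randomness. An illustrative sketch, assuming a trivial loop body:

from jax import lax, random

base_key = random.PRNGKey(0)

def body_fn(i, total):
    key = random.fold_in(base_key, i)        # fresh, deterministic key for iteration i
    return total + random.normal(key, ())

total = lax.fori_loop(0, 10, body_fn, 0.0)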
Example #10
  def test_graph_network_neighbor_list_moving(self,
                                              spatial_dimension,
                                              dtype,
                                              format):
    if format is partition.OrderedSparse:
      self.skipTest('OrderedSparse format incompatible with GNN '
                    'force field.')

    key = random.PRNGKey(0)

    R = random.uniform(key, (32, spatial_dimension), dtype=dtype)

    d, _ = space.free()

    cutoff = 0.3
    dr_threshold = 0.1

    init_fn, energy_fn = energy.graph_network(d, cutoff)
    params = init_fn(key, R)

    neighbor_fn, _, nl_energy_fn = \
      energy.graph_network_neighbor_list(d, 1.0, cutoff,
                                         dr_threshold, format=format)

    nbrs = neighbor_fn.allocate(R)
    key = random.fold_in(key, 1)
    R = R + random.uniform(key, (32, spatial_dimension),
                           minval=-0.05, maxval=0.05, dtype=dtype)
    if format is partition.Dense:
      self.assertAllClose(energy_fn(params, R), nl_energy_fn(params, R, nbrs))
    else:
      self.assertAllClose(energy_fn(params, R), nl_energy_fn(params, R, nbrs),
                          rtol=2e-4, atol=2e-4)
Example #11
  def testPRNGValues(self):
    # Test to ensure consistent random values between JAX versions
    k = random.PRNGKey(0)

    if config.x64_enabled:
        self.assertAllClose(
            random.randint(k, (3, 3), 0, 8),
            np.array([[7, 2, 6],
                      [2, 1, 0],
                      [6, 7, 7]], dtype='int64'))
    else:
        self.assertAllClose(
            random.randint(k, (3, 3), 0, 8),
            np.array([[2, 1, 3],
                      [6, 1, 5],
                      [6, 3, 4]], dtype='int32'))

    self.assertAllClose(
        random.split(k, 4),
        np.array([[2285895361, 1501764800],
                  [1518642379, 4090693311],
                  [ 433833334, 4221794875],
                  [ 839183663, 3740430601]], dtype='uint32'))

    self.assertAllClose(
        random.fold_in(k, 4),
        np.array([2285895361,  433833334], dtype='uint32'))
Example #12
    def step(loop_state: LoopState):
      t = loop_state.episode_length
      step_out = ddpg_step(
          random.fold_in(rng_steps, t),
          loop_state.optimizer,
          loop_state.tracking_params,
          env,
          gamma,
          tau,
          loop_state.replay_buffer,
          batch_size,
          actor,
          critic,
          loop_state.state,
          noise(t, loop_state.prev_noise),
          lambda s: terminal_criterion(t, s),
      )

      new_discounted_cumulative_reward = loop_state.discounted_cumulative_reward + (
          gamma**t) * step_out.reward
      new_undiscounted_cumulative_reward = (loop_state.undiscounted_cumulative_reward +
                                            step_out.reward)
      return LoopState(
          episode_length=loop_state.episode_length + 1,
          optimizer=step_out.optimizer,
          tracking_params=step_out.tracking_params,
          discounted_cumulative_reward=new_discounted_cumulative_reward,
          undiscounted_cumulative_reward=new_undiscounted_cumulative_reward,
          state=step_out.next_state,
          replay_buffer=step_out.replay_buffer,
          prev_noise=step_out.action_noise,
          done=step_out.done,
      )
Example #13
 def make_rng(self, name: str) -> PRNGKey:
   """Generate A PRNGKey from a PRNGSequence."""
   assert self.has_rng(name), f'Need PRNG for "{name}"'
   self._check_valid()
   self._validate_trace_level()
   self.rng_counters[name] += 1
   return random.fold_in(self.rngs[name], self.rng_counters[name])
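A minimal sketch of how this method is typically reached from user code with Flax's linen API; the NoisyDense module and the 'noise' stream name are illustrative, not part of the original source:

import jax.numpy as jnp
import flax.linen as nn
from jax import random

class NoisyDense(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.features)(x)
    key = self.make_rng('noise')             # folds an incrementing counter into the 'noise' stream
    return x + 0.1 * random.normal(key, x.shape)

model = NoisyDense(features=4)
x = jnp.ones((2, 3))
variables = model.init({'params': random.PRNGKey(0), 'noise': random.PRNGKey(1)}, x)
y = model.apply(variables, x, rngs={'noise': random.PRNGKey(2)})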
Example #14
 def body_fn(i, val):
     loss_sum, svi_state = val
     rng_key_binarize = random.fold_in(rng_key, i)
     batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
     svi_state, loss = svi.update(svi_state, batch)
     loss_sum += loss
     return loss_sum, svi_state
Example #15
 def private_update(rng, i, opt_state, batch):
     params = get_params(opt_state)
     rng = random.fold_in(rng, i)  # get new key for new random numbers
     return opt_update(
         i,
         private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                      FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
Example #16
 def body_fun(i, loss_sum):
     rng_key_binarize = random.fold_in(rng_key, i)
     batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
     # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
     loss = svi.evaluate(svi_state, batch) / len(batch)
     loss_sum += loss
     return loss_sum
Example #17
            def body_fun(i, args):
                opt_state, cv_history = args
                params = get_params(opt_state)
                elbo_rng, data_rng, cv_rng = random.split(random.fold_in(
                    rng, i),
                                                          num=3)
                batch = fetch_batch(i, train_images)
                encoder_grad = grad(objective_fn)(*params, batch, elbo_rng,
                                                  num_samples)
                decoder_grad = grad(elbo_fn, argnums=1)(*params, batch,
                                                        elbo_rng, num_samples)
                control_variate = kwargs['support_fn'](params[0], batch,
                                                       elbo_rng, num_samples)

                # Old computation of CV coeff as exponentially weighted moving average
                cv_coeff = compute_cv_coeff(encoder_grad, control_variate)
                cv_history = update_cv_coeff(cv_history, cv_coeff)
                encoder_grad = jax.tree_multimap(lambda x, y, z: x - y * z,
                                                 encoder_grad, cv_history,
                                                 control_variate)

                # coeff_encoder_grad = grad(objective_fn)(*params, batch, cv_rng, num_samples)
                # coeff_control_variate = kwargs['support_fn'](params[0], batch, cv_rng, num_samples)

                # cv_coeff = compute_cv_coeff(coeff_encoder_grad, coeff_control_variate)

                encoder_grad = jax.tree_multimap(lambda x, y, z: x - y * z,
                                                 encoder_grad, cv_coeff,
                                                 control_variate)

                g = (encoder_grad, decoder_grad)
                return opt_update(i, g, opt_state), cv_history
Example #18
 def generate_sample_images(opt_state):
     params = get_params(opt_state)
     image_rng = random.fold_in(test_rng, 2)
     sampled_images = two_layer_image_sample(image_rng,
                                             (decoder, params[1]), nrow,
                                             ncol)
     return sampled_images
Example #19
 def step(opt_state, it):
     step_rng = random.fold_in(rng, it)
     bij_params, deq_params = get_params(opt_state)
     loss_val, loss_grad = value_and_grad(loss, (1, 3))(step_rng, bij_params, bij_fns, deq_params, deq_fn, num_samples)
     loss_grad = tree_util.tree_map(partial(put.clip_and_zero_nans, clip_value=1.), loss_grad)
     opt_state = opt_update(it, loss_grad, opt_state)
     return opt_state, loss_val
Example #20
def run_trials(batched_run_fun, inputs_targets_h0s_fun, nbatches, batch_size):
    """Run a bunch of trials and save everything in a dictionary."""
    inputs = []
    hiddens = []
    outputs = []
    targets = []
    h0s = []
    key = random.PRNGKey(onp.random.randint(0, MAX_SEED_INT))
    for n in range(nbatches):
        key = random.fold_in(key, n)
        skeys = random.split(key, batch_size)
        input_b, target_b, h0s_b = inputs_targets_h0s_fun(skeys)
        if h0s_b is None:
            h_b, o_b = batched_run_fun(input_b)
        else:
            h_b, o_b = batched_run_fun(input_b, h0s_b)
            h0s.append(h0s_b)

        inputs.append(input_b)
        hiddens.append(h_b)
        outputs.append(o_b)
        targets.append(target_b)

    trial_dict = {
        'inputs': onp.vstack(inputs),
        'hiddens': onp.vstack(hiddens),
        'outputs': onp.vstack(outputs),
        'targets': onp.vstack(targets)
    }
    if h0s_b is not None:
        trial_dict['h0s'] = onp.vstack(h0s)
    else:
        trial_dict['h0s'] = None
    return trial_dict
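A compact sketch of the per-batch keying used above: fold the batch index into one base key, then split the result into one key per trial in the batch (the batch index and batch size here are illustrative):

from jax import random

key = random.PRNGKey(0)
batch_keys = random.split(random.fold_in(key, 3), 8)   # 8 independent keys for batch 3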
Example #21
    def callback(info):
        episode = info['episode']
        reward = info['reward']
        current_actor_params, _ = info["optimizer"].value

        policy_value = eval_policy(callback_rngs[episode],
                                   current_actor_params)

        print(f"Episode {episode}, "
              f"train reward = {reward}, "
              f"policy value = {policy_value}, "
              f"elapsed = {info['elapsed']}")

        train_reward_per_episode.append(reward)
        policy_value_per_episode.append(policy_value)

        if episode == num_episodes - 1:
            # if episode % 500 == 0 or episode == num_episodes - 1:
            for rollout_idx in range(5):  # renamed to avoid shadowing the rollout() function
                states, actions, _ = rollout(
                    random.fold_in(callback_rngs[episode], rollout_idx),
                    config.env,
                    policy(current_actor_params),
                    num_timesteps=250,
                )
                viz_pendulum_rollout(states, 2 * actions / config.max_torque)
Example #22
def _fold_in_str(rng, data):
    """Folds a string into a jax.random.PRNGKey using its SHA-1 hash."""
    m = hashlib.sha1()
    m.update(data.encode('utf-8'))
    d = m.digest()
    hash_int = int.from_bytes(d[:4], byteorder='big', signed=True)
    return random.fold_in(rng, hash_int)
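A hypothetical usage of the helper above, deriving distinct keys from human-readable names (the names are illustrative; it assumes the _fold_in_str definition and hashlib import from this example):

from jax import random

rng = random.PRNGKey(0)
encoder_rng = _fold_in_str(rng, 'encoder')
decoder_rng = _fold_in_str(rng, 'decoder')   # a different string yields an unrelated key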
Example #23
 def make_rng(self, name: str) -> PRNGKey:
   """Generates A PRNGKey from a PRNGSequence with name `name`."""
   if not self.has_rng(name):
     raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
   self._check_valid()
   self._validate_trace_level()
   self.rng_counters[name] += 1
   return random.fold_in(self.rngs[name], self.rng_counters[name])
Example #24
 def step(opt_state, it):
     step_rng = random.fold_in(rng, it)
     omega, params = get_params(opt_state)
     loss_val, loss_grad = value_and_grad(loss,
                                          (1, 2))(step_rng, omega, params,
                                                  fn, num_samples)
     opt_state = opt_update(it, loss_grad, opt_state)
     return opt_state, loss_val
Example #25
 def body_fun(
     i: jnp.ndarray, val: Tuple[jnp.ndarray, SVIState]
 ) -> Tuple[jnp.ndarray, SVIState]:
     loss_sum, svi_state = val
     rng_key_binarize = random.fold_in(rng_key, i)
     batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
     svi_state, loss = svi.update(svi_state, batch)
     loss_sum += loss
     return loss_sum, svi_state
Example #26
 def hitObjCase(_):
     (incidentRay, osurf) = (raymarchResult.ray, raymarchResult.surf)
     (k1, k2) = random.split(random.fold_in(k, i))
     lightRadiance = sampleLightRadiance(scene, osurf, incidentRay, k1)
     outRayHemisphere = sampleReflection(osurf, incidentRay, k2)
     newFilter = surfaceFilter(color_filter, osurf[1])
     newRadiance = total_radiance + applyFilter(newFilter,
                                                lightRadiance)
     return (False, newFilter, newRadiance, outRayHemisphere)
Example #27
def make_plot(step=0.03, L=11):
    Q = {}
    Q["q"] = jnp.array([-0.1, 0.2])
    pr = 0.4  #0.31
    plt.figure()
    plt.subplot(ylabel=r"$\mu_y$",
                xlabel=r"$\mu_x$",
                xlim=(-pr, pr),
                ylim=(-pr, pr))
    n_samples = 4
    path_col = (0, 0, 0, 0.5)
    for r in 0.075 * jnp.arange(2, 6):
        plt.gca().add_artist(plt.Circle((0, 0), r, alpha=0.2, fill=False))
    plt.scatter(Q["q"][0], Q["q"][1], c="k", marker="x", zorder=4)
    for i in range(n_samples):
        Q = HMC2(U, U_gradient, step, L, Q["q"],
                 random.fold_in(random.PRNGKey(0), i))
        if n_samples < 10:
            for j in range(L):
                K0 = jnp.sum(Q["ptraj"][j]**2) / 2
                plt.plot(
                    Q["traj"][j:j + 2, 0],
                    Q["traj"][j:j + 2, 1],
                    c=path_col,
                    lw=1 + 2 * K0,
                )
            plt.scatter(Q["traj"][:, 0],
                        Q["traj"][:, 1],
                        c="white",
                        s=5,
                        zorder=3)
            # for fancy arrows
            dx = Q["traj"][L, 0] - Q["traj"][L - 1, 0]
            dy = Q["traj"][L, 1] - Q["traj"][L - 1, 1]
            d = jnp.sqrt(dx**2 + dy**2)
            plt.annotate(
                "",
                (Q["traj"][L - 1, 0], Q["traj"][L - 1, 1]),
                (Q["traj"][L, 0], Q["traj"][L, 1]),
                arrowprops={"arrowstyle": "<-"},
            )
            plt.annotate(
                str(i + 1),
                (Q["traj"][L, 0], Q["traj"][L, 1]),
                xytext=(3, 3),
                textcoords="offset points",
            )
        plt.scatter(
            Q["traj"][L + 1, 0],
            Q["traj"][L + 1, 1],
            c=("red" if jnp.abs(Q["dH"]) > 0.1 else "black"),
            zorder=4,
        )
        #plt.axis('square')
        plt.title(f'L={L}')
        plt.savefig(f'../figures/hmc2d_L{L}.pdf', dpi=300)
Example #28
 def step(opt_state, it):
     step_rng = random.fold_in(rng, it)
     thetax, thetay, thetad, paramsm = get_params(opt_state)
     loss_val, loss_grad = value_and_grad(loss,
                                          (1, 2, 3, 4))(step_rng, thetax,
                                                        thetay, thetad,
                                                        paramsm, netm,
                                                        num_samples)
     opt_state = opt_update(it, loss_grad, opt_state)
     return opt_state, loss_val
Example #29
 def make_rng(self):
     # When calling make_rng within a jax transformation,
     # the rng could be implicitly reused (e.g. in jit, vmap, scan, ...).
     # We raise an error to avoid silent errors.
     level = utils._trace_level(utils._current_trace())
     if level > self.level:
         raise ValueError('stochastic operations are not allowed when the'
                          ' stochastic context is created outside of the'
                          ' current Jax transformation')
     self.counter += 1
     return random.fold_in(self.base_rng, self.counter)
Example #30
def fold_in(seed, salt):
  """Folds salt into seed to form a new seed."""
  if JAX_MODE:
    from jax import random as jaxrand  # pylint: disable=g-import-not-at-top
    return jaxrand.fold_in(seed, salt & (2**32 - 1))
  if isinstance(salt, (six.integer_types)):
    seed = tf.bitwise.bitwise_xor(
        seed, np.uint64([salt & (2**64 - 1)]).view(np.int32))
  else:
    seed = tf.random.experimental.stateless_fold_in(seed, salt)
  return seed