Example #1
def job(
    random_seed: int,
    base_dir: Path,
    theta_min: float,
    theta_max: float,
    theta_dot_min: float,
    theta_dot_max: float,
):
    rng = random.PRNGKey(random_seed)

    rng, train_rng = random.split(rng)
    callback_rngs = random.split(rng, num_episodes)

    # One-element lists act as mutable cells so the nested callback can write
    # the latest parameters back into this scope.
    params = [None]
    tracking_params = [None]

    train_reward_per_episode = []
    policy_value_per_episode = []
    episode_lengths = []
    elapsed_per_episode = []

    def callback(info):
        episode = info['episode']
        params[0] = info["optimizer"].value
        tracking_params[0] = info["tracking_params"]

        policy_value = run_ddpg.eval_policy(callback_rngs[episode],
                                            info["optimizer"].value[0])

        train_reward_per_episode.append(info['reward'])
        policy_value_per_episode.append(policy_value)
        episode_lengths.append(info["episode_length"])
        elapsed_per_episode.append(info["elapsed"])

    run_ddpg.train(
        train_rng,
        num_episodes,
        lambda t, s: lax.bitwise_or(
            lax.ge(t, config.episode_length),
            lax.bitwise_or(
                lax.le(s[0], theta_min),
                lax.bitwise_or(
                    lax.ge(s[0], theta_max),
                    lax.bitwise_or(lax.le(s[1], theta_dot_min),
                                   lax.ge(s[1], theta_dot_max))))),
        callback,
    )
    with (base_dir / f"seed={random_seed}.pkl").open(mode="wb") as f:
        pickle.dump(
            {
                "final_params": params[0],
                "final_tracking_params": tracking_params[0],
                "train_reward_per_episode": train_reward_per_episode,
                "policy_value_per_episode": policy_value_per_episode,
                "episode_lengths": episode_lengths,
                "elapsed_per_episode": elapsed_per_episode,
            }, f)
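
A more readable sketch of the nested termination predicate above, assuming the same (t, s) signature that run_ddpg.train expects; the operator forms |, <=, and >= lower to the same lax comparisons as the explicit lax.bitwise_or / lax.ge / lax.le calls.

def make_done_fn(episode_length, theta_min, theta_max,
                 theta_dot_min, theta_dot_max):
    # Terminate when the episode is over or the state leaves the
    # (theta, theta_dot) box.
    def done(t, s):
        out_of_bounds = ((s[0] <= theta_min) | (s[0] >= theta_max) |
                         (s[1] <= theta_dot_min) | (s[1] >= theta_dot_max))
        return (t >= episode_length) | out_of_bounds
    return done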
Example #2
def modf(x, out=None):
    _check_arraylike("modf", x)
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.modf is not supported.")
    # The whole part is rounded toward zero: floor for x >= 0, ceil for x < 0.
    whole = _where(lax.ge(x, lax_internal._zero(x)), floor(x), ceil(x))
    return x - whole, whole
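
A minimal usage sketch for the wrapper above (the input values are illustrative only): jnp.modf splits each element into a fractional part and a whole part, both carrying the sign of the input.

import jax.numpy as jnp

frac, whole = jnp.modf(jnp.array([2.5, -1.25]))
# frac  -> [ 0.5 , -0.25]
# whole -> [ 2.  , -1.  ]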
Example #3
def compare(patch_y, patch_x):
    patch_center_y = patch_y + filter_radius
    patch_center_x = patch_x + filter_radius
    # Skip if patch is out of image boundaries or this is the center patch
    skip = (lax.lt(patch_center_y, pad)
            | lax.ge(patch_center_y, _h + pad)
            | lax.lt(patch_center_x, pad)
            | lax.ge(patch_center_x, _w + pad)
            | (lax.eq(patch_center_y, win_center_y)
               & lax.eq(patch_center_x, win_center_x)))
    return lax.cond(skip, lambda _: (0., 0.), _compare,
                    (patch_center_y, patch_center_x))
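
A standalone sketch of the skip-or-compute pattern above; the names safe_lookup and table are hypothetical, not from the original code. The lax comparisons combined with | form the scalar predicate, and lax.cond picks between a zero result and the real computation.

import jax.numpy as jnp
from jax import lax

def safe_lookup(idx, table):
    # Return 0 when idx falls outside [0, len(table)), otherwise table[idx].
    skip = lax.lt(idx, 0) | lax.ge(idx, table.shape[0])
    return lax.cond(skip,
                    lambda _: jnp.zeros((), table.dtype),
                    lambda i: table[i],
                    idx)

table = jnp.array([1.0, 2.0, 3.0])
print(safe_lookup(1, table))  # 2.0
print(safe_lookup(5, table))  # 0.0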
Example #4
def tri(n, m, k=0):
  # Build the mask from broadcasted index comparisons instead of a constant
  # array, so XLA can construct it during computation and fuse it with the
  # attention ops.
  x = jnp.arange(n, dtype=jnp.int32)
  y = jnp.arange(m, dtype=jnp.int32)
  mask = lax.ge(
      (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,))) + k,
      lax.broadcast(y, [n]))
  return mask
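
For comparison, a short sketch using the public API; the shape, offset, and dtype are arbitrary choices for illustration. jnp.tri builds the same kind of lower-triangular mask directly.

import jax.numpy as jnp

mask = jnp.tri(4, 5, k=0, dtype=bool)  # True on and below the main diagonal
print(mask.shape)  # (4, 5)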
Example #5
def main():
    num_episodes = 1000

    rng = random.PRNGKey(0)
    train_rng, rng = random.split(rng)
    callback_rngs = random.split(rng, num_episodes)

    train_reward_per_episode = []
    policy_value_per_episode = []

    def callback(info):
        episode = info['episode']
        reward = info['reward']
        current_actor_params, _ = info["optimizer"].value

        policy_value = eval_policy(callback_rngs[episode],
                                   current_actor_params)

        print(f"Episode {episode}, "
              f"train reward = {reward}, "
              f"policy value = {policy_value}, "
              f"elapsed = {info['elapsed']}")

        train_reward_per_episode.append(reward)
        policy_value_per_episode.append(policy_value)

        if episode == num_episodes - 1:
            # if episode % 500 == 0 or episode == num_episodes - 1:
            # Note: the loop index must not shadow the rollout() function
            # called below.
            for rollout_idx in range(5):
                states, actions, _ = rollout(
                    random.fold_in(callback_rngs[episode], rollout_idx),
                    config.env,
                    policy(current_actor_params),
                    num_timesteps=250,
                )
                viz_pendulum_rollout(states, 2 * actions / config.max_torque)

    train(
        train_rng,
        num_episodes,
        lambda t, _: lax.ge(t, config.episode_length),
        callback,
    )
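
A standalone sketch of the RNG bookkeeping used above; the episode count and indices are illustrative. One key is split off for training, the remainder is pre-split into per-episode keys, and fold_in derives a distinct key per rollout.

from jax import random

rng = random.PRNGKey(0)
train_rng, rng = random.split(rng)
callback_rngs = random.split(rng, 1000)             # one key per episode
rollout_rng = random.fold_in(callback_rngs[42], 3)  # episode 42, rollout 3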
Example #6
def floor_divide(x1, x2):
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)
  if dtypes.issubdtype(dtype, np.integer):
    quotient = lax.div(x1, x2)
    select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(select, quotient - 1, quotient)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    x1r = lax.real(x1)
    x1i = lax.imag(x1)
    x2r = lax.real(x2)
    x2i = lax.imag(x2)
    which = lax.ge(lax.abs(x2r), lax.abs(x2i))
    rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
    rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1))
    out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
                            lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
    return lax.convert_element_type(out, dtype)
  else:
    return _float_divmod(x1, x2)[0]
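
A quick usage sketch with illustrative inputs: the integer branch above follows NumPy's floor-division semantics and rounds the quotient toward negative infinity rather than toward zero.

import jax.numpy as jnp

print(jnp.floor_divide(7, 2))   # 3
print(jnp.floor_divide(-7, 2))  # -4 (plain lax.div would truncate to -3)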
Example #7
def main():
  rng = random.PRNGKey(0)
  num_episodes = 10000

  print(f"Loading best seed from {experiment_folder}... ", end="")
  best_seed_data = load_best_seed()
  print("done")

  print("Building support set... ", end="")
  rng, ss_rng = random.split(rng)
  actor_params, _ = best_seed_data["final_params"]

  support_set = build_support_set(ss_rng, actor_params)
  support_set_flat = jp.reshape(support_set, (-1, support_set.shape[-1]))

  # theta_min = jp.min(support_set_flat[:, 0]) - epsilon
  # theta_max = jp.max(support_set_flat[:, 0]) + epsilon
  # theta_dot_min = jp.min(support_set_flat[:, 1]) - epsilon
  # theta_dot_max = jp.max(support_set_flat[:, 1]) + epsilon
  print("done")

  rng, train_rng = random.split(rng)
  callback_rngs = random.split(rng, num_episodes)

  train_reward_per_episode = []
  policy_value_per_episode = []
  episode_lengths = []

  def callback(info):
    episode = info['episode']
    reward = info['reward']

    current_actor_params = info["optimizer"].value[0]
    policy_value = run_ddpg.eval_policy(callback_rngs[episode],
                                        current_actor_params)

    print(f"Episode {episode}, "
          f"episode_length = {info['episode_length']}, "
          f"reward = {reward}, "
          f"policy_value = {policy_value}, "
          f"elapsed = {info['elapsed']}")

    train_reward_per_episode.append(reward)
    policy_value_per_episode.append(policy_value)
    episode_lengths.append(info["episode_length"])

    # if episode == num_episodes - 1:
    # if episode % 5000 == 0 or episode == num_episodes - 1:
    #   for rollout in range(5):
    #     states, actions, _ = ddpg.rollout(
    #         random.fold_in(callback_rngs[episode], rollout),
    #         config.env,
    #         policy(current_actor_params),
    #         num_timesteps=250,
    #     )
    #     viz_pendulum_rollout(states, 2 * actions / config.max_torque)

  run_ddpg.train(
      train_rng,
      num_episodes,
      # lambda t, s: lax.bitwise_or(
      #     lax.ge(t, config.episode_length),
      #     lax.bitwise_or(
      #         lax.le(s[0], theta_min),
      #         lax.bitwise_or(
      #             lax.ge(s[0], theta_max),
      #             lax.bitwise_or(lax.le(s[1], theta_dot_min),
      #                            lax.ge(s[1], theta_dot_max))))),
      # lambda t, s: lax.bitwise_or(
      #     lax.ge(t, config.episode_length),
      #     lax.bitwise_or(lax.ge(jp.abs(s[1]), 10.0),
      #                    lax.ge(jp.abs(s[0] - jp.pi), 0.5))),
      lambda loop_state: lax.bitwise_or(
          lax.ge(loop_state.episode_length, config.episode_length),
          lax.ge(
              jp.min(
                  jp.sum((support_set_flat[:, :2] - loop_state.state[:2])**2,
                         axis=1)), max_squared_dist)),
      callback,
  )
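
A standalone sketch of the support-set distance test inside the termination predicate above: the episode ends once the current state is farther than max_squared_dist from every support point. The arrays and threshold below are made up for illustration.

import jax.numpy as jnp
from jax import lax

def far_from_support(state, support_set_flat, max_squared_dist):
    # Squared Euclidean distance from the (theta, theta_dot) slice of the
    # state to each support point; compare the nearest one to the threshold.
    sq_dists = jnp.sum((support_set_flat[:, :2] - state[:2]) ** 2, axis=1)
    return lax.ge(jnp.min(sq_dists), max_squared_dist)

support = jnp.array([[0.0, 0.0], [1.0, 1.0]])
print(far_from_support(jnp.array([5.0, 5.0]), support, 4.0))  # True
print(far_from_support(jnp.array([0.5, 0.5]), support, 4.0))  # False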