Example #1
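# Assumed imports for this snippet (the surrounding module is not shown);
# `run_ddpg`, `config`, and `num_episodes` are module-level names defined
# elsewhere in the original source file.
import pickle
from pathlib import Path

from jax import lax, random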
def job(
    random_seed: int,
    base_dir: Path,
    theta_min: float,
    theta_max: float,
    theta_dot_min: float,
    theta_dot_max: float,
):
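    # Train DDPG for num_episodes, terminating an episode early when the state
    # leaves the given (theta, theta_dot) box; pickle final params and metrics.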
    rng = random.PRNGKey(random_seed)

    rng, train_rng = random.split(rng)
    callback_rngs = random.split(rng, num_episodes)

    params = [None]
    tracking_params = [None]

    train_reward_per_episode = []
    policy_value_per_episode = []
    episode_lengths = []
    elapsed_per_episode = []

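    # Per-episode callback: snapshot the latest parameters, evaluate the current
    # policy, and record training metrics.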
    def callback(info):
        episode = info['episode']
        params[0] = info["optimizer"].value
        tracking_params[0] = info["tracking_params"]

        policy_value = run_ddpg.eval_policy(callback_rngs[episode],
                                            info["optimizer"].value[0])

        train_reward_per_episode.append(info['reward'])
        policy_value_per_episode.append(policy_value)
        episode_lengths.append(info["episode_length"])
        elapsed_per_episode.append(info["elapsed"])

    run_ddpg.train(
        train_rng,
        num_episodes,
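        # Terminate an episode once the time limit is reached or the state leaves
        # the [theta_min, theta_max] x [theta_dot_min, theta_dot_max] box.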
        lambda t, s: lax.bitwise_or(
            lax.ge(t, config.episode_length),
            lax.bitwise_or(
                lax.le(s[0], theta_min),
                lax.bitwise_or(
                    lax.ge(s[0], theta_max),
                    lax.bitwise_or(lax.le(s[1], theta_dot_min),
                                   lax.ge(s[1], theta_dot_max))))),
        callback,
    )
    with (base_dir / f"seed={random_seed}.pkl").open(mode="wb") as f:
        pickle.dump(
            {
                "final_params": params[0],
                "final_tracking_params": tracking_params[0],
                "train_reward_per_episode": train_reward_per_episode,
                "policy_value_per_episode": policy_value_per_episode,
                "episode_lengths": episode_lengths,
                "elapsed_per_episode": elapsed_per_episode,
            }, f)
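
# A minimal usage sketch, not part of the original example: one way job() might be
# launched over several seeds. The seed count, output directory, and state bounds
# below are illustrative assumptions, not values taken from the source.
if __name__ == "__main__":
    out_dir = Path("results/pendulum_sweep")  # hypothetical output directory
    out_dir.mkdir(parents=True, exist_ok=True)
    for seed in range(4):  # hypothetical seed sweep
        job(
            random_seed=seed,
            base_dir=out_dir,
            theta_min=-3.2,
            theta_max=3.2,
            theta_dot_min=-10.0,
            theta_dot_max=10.0,
        )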
Example #2
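# Assumed imports for this snippet (the surrounding module is not shown);
# `run_ddpg`, `config`, `experiment_folder`, `load_best_seed`, `build_support_set`,
# and `max_squared_dist` are assumed to come from elsewhere in the original script.
import jax.numpy as jp
from jax import lax, random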
def main():
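  # Reload the best previous seed, build a support set of visited states, and
  # retrain DDPG, ending episodes when the state strays from that support set.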
  rng = random.PRNGKey(0)
  num_episodes = 10000

  print(f"Loading best seed from {experiment_folder}... ", end="")
  best_seed_data = load_best_seed()
  print("done")

  print("Building support set... ", end="")
  rng, ss_rng = random.split(rng)
  actor_params, _ = best_seed_data["final_params"]

  support_set = build_support_set(ss_rng, actor_params)
  support_set_flat = jp.reshape(support_set, (-1, support_set.shape[-1]))
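  # Flatten the support set so each row is one visited state.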

  # theta_min = jp.min(support_set_flat[:, 0]) - epsilon
  # theta_max = jp.max(support_set_flat[:, 0]) + epsilon
  # theta_dot_min = jp.min(support_set_flat[:, 1]) - epsilon
  # theta_dot_max = jp.max(support_set_flat[:, 1]) + epsilon
  print("done")

  rng, train_rng = random.split(rng)
  callback_rngs = random.split(rng, num_episodes)

  train_reward_per_episode = []
  policy_value_per_episode = []
  episode_lengths = []

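  # Per-episode callback: evaluate and log the current policy, and record metrics.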
  def callback(info):
    episode = info['episode']
    reward = info['reward']

    current_actor_params = info["optimizer"].value[0]
    policy_value = run_ddpg.eval_policy(callback_rngs[episode],
                                        current_actor_params)

    print(f"Episode {episode}, "
          f"episode_length = {info['episode_length']}, "
          f"reward = {reward}, "
          f"policy_value = {policy_value}, "
          f"elapsed = {info['elapsed']}")

    train_reward_per_episode.append(reward)
    policy_value_per_episode.append(policy_value)
    episode_lengths.append(info["episode_length"])

    # if episode == num_episodes - 1:
    # if episode % 5000 == 0 or episode == num_episodes - 1:
    #   for rollout in range(5):
    #     states, actions, _ = ddpg.rollout(
    #         random.fold_in(callback_rngs[episode], rollout),
    #         config.env,
    #         policy(current_actor_params),
    #         num_timesteps=250,
    #     )
    #     viz_pendulum_rollout(states, 2 * actions / config.max_torque)

  run_ddpg.train(
      train_rng,
      num_episodes,
      # lambda t, s: lax.bitwise_or(
      #     lax.ge(t, config.episode_length),
      #     lax.bitwise_or(
      #         lax.le(s[0], theta_min),
      #         lax.bitwise_or(
      #             lax.ge(s[0], theta_max),
      #             lax.bitwise_or(lax.le(s[1], theta_dot_min),
      #                            lax.ge(s[1], theta_dot_max))))),
      # lambda t, s: lax.bitwise_or(
      #     lax.ge(t, config.episode_length),
      #     lax.bitwise_or(lax.ge(jp.abs(s[1]), 10.0),
      #                    lax.ge(jp.abs(s[0] - jp.pi), 0.5))),
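      # Terminate an episode once the time limit is reached or the squared distance
      # from the state's first two components to every support-set point exceeds
      # max_squared_dist.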
      lambda loop_state: lax.bitwise_or(
          lax.ge(loop_state.episode_length, config.episode_length),
          lax.ge(
              jp.min(
                  jp.sum((support_set_flat[:, :2] - loop_state.state[:2])**2,
                         axis=1)), max_squared_dist)),
      callback,
  )
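
# A minimal entry-point sketch, not in the original snippet: main() reads like a
# script entry point, so the standard guard below is an assumed addition.
if __name__ == "__main__":
  main()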