# Example No. 1
# 0
def init_test_trainer(use_gail: bool, parallel: bool = False):
    """Build a CartPole-v1 trainer from the bundled expert rollouts.

    The expert trajectories are unpickled from a fixed test-data path and
    handed to `init_trainer` together with the two flags below.

    Args:
        use_gail: Forwarded unchanged to `init_trainer` as `use_gail`.
        parallel: Forwarded unchanged to `init_trainer` as `parallel`.

    Returns:
        Whatever `init_trainer` returns for these arguments.
    """
    rollout_file = "tests/data/expert_models/cartpole_0/rollouts/final.pkl"
    # The fixture is a trusted file shipped with the repository, so unpickling
    # it is acceptable here.
    with open(rollout_file, "rb") as rollout_fh:
        expert_trajectories = pickle.load(rollout_fh)
    trainer_options = dict(use_gail=use_gail, parallel=parallel)
    return init_trainer("CartPole-v1", expert_trajectories, **trainer_options)
# Example No. 2
# 0
def train(
    _run,
    _seed: int,
    env_name: str,
    rollout_path: str,
    n_expert_demos: Optional[int],
    log_dir: str,
    init_trainer_kwargs: dict,
    total_timesteps: int,
    n_episodes_eval: int,
    init_tensorboard: bool,
    checkpoint_interval: int,
) -> dict:
    """Train an adversarial-network-based imitation learning algorithm.

    Checkpoints:
      - Checkpoints are written with `save(trainer, path)` to
        f"{log_dir}/checkpoints/{step}", where step is the zero-padded epoch
        number passed to the training callback (e.g. "00005") or "final".
        DiscrimNets are saved under `.../discrim/` and generator policies
        under `.../gen_policy/` (layout determined by `save`).

    Args:
      _run: Run object passed to `sacred_util.build_sacred_symlink` together
        with `log_dir`.
      _seed: Random seed, forwarded to `init_trainer` as `seed`.
      env_name: The environment to train in.
      rollout_path: Path to pickle containing list of Trajectories. Used as
        expert demonstrations. The file is unpickled, so it must come from a
        trusted source.
      n_expert_demos: The number of expert trajectories to actually use
        after loading them from `rollout_path`.
        If None, then use all available trajectories.
        If `n_expert_demos` is an `int`, then use exactly `n_expert_demos`
        trajectories, raising `ValueError` if there aren't enough
        trajectories. If there are surplus trajectories, then use the
        first `n_expert_demos` trajectories and drop the rest.
      log_dir: Directory to save models and other logging to.
      init_trainer_kwargs: Keyword arguments passed to `init_trainer`,
        used to initialize the trainer. This dict is not modified.
      total_timesteps: The number of transitions to sample from the
        environment during training (coerced with `int()`).
      n_episodes_eval: The number of episodes to average over when calculating
        the average episode reward of the imitation policy for return.
      init_tensorboard: If True, then write tensorboard logs to
        `{log_dir}/sb_tb` by setting `init_rl_kwargs["tensorboard_log"]` in a
        copy of `init_trainer_kwargs`.
      checkpoint_interval: Save the discriminator and generator models every
        `checkpoint_interval` epochs and after training is complete. If 0,
        then only save weights after training is complete. If <0, then don't
        save weights at all.

    Returns:
      A dictionary with two keys. "imit_stats" gives the return value of
        `rollout_stats()` on rollouts test-reward-wrapped
        environment, using the final policy (remember that the ground-truth
        reward can be recovered from the "monitor_return" key).
        "expert_stats" gives the return value of `rollout_stats()` on the
        expert demonstrations loaded from `rollout_path`.

    Raises:
      ValueError: If `n_expert_demos` exceeds the number of trajectories
        stored in `rollout_path`.
    """
    total_timesteps = int(total_timesteps)

    tf.logging.info("Logging to %s", log_dir)
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    # Calculate stats for expert rollouts. Used for the return value.
    with open(rollout_path, "rb") as f:
        expert_trajs = pickle.load(f)

    if n_expert_demos is not None:
        # An explicit check instead of `assert` so validation still happens
        # when Python runs with -O.
        if len(expert_trajs) < n_expert_demos:
            raise ValueError(
                f"Requested n_expert_demos={n_expert_demos}, but only "
                f"{len(expert_trajs)} trajectories are in '{rollout_path}'.")
        expert_trajs = expert_trajs[:n_expert_demos]

    expert_stats = util.rollout.rollout_stats(expert_trajs)

    with util.make_session():
        # Work on copies so the caller's dicts (e.g. config values shared
        # across runs) are never mutated in place.
        trainer_kwargs = dict(init_trainer_kwargs)
        if init_tensorboard:
            sb_tensorboard_dir = osp.join(log_dir, "sb_tb")
            init_rl_kwargs = dict(trainer_kwargs.get("init_rl_kwargs") or {})
            init_rl_kwargs["tensorboard_log"] = sb_tensorboard_dir
            trainer_kwargs["init_rl_kwargs"] = init_rl_kwargs

        trainer = init_trainer(env_name,
                               expert_trajs,
                               seed=_seed,
                               log_dir=log_dir,
                               **trainer_kwargs)

        def callback(epoch):
            """Save an intermediate checkpoint every `checkpoint_interval`."""
            if checkpoint_interval > 0 and epoch % checkpoint_interval == 0:
                save(trainer,
                     os.path.join(log_dir, "checkpoints", f"{epoch:05d}"))

        trainer.train(total_timesteps, callback)

        # Save final artifacts (skipped entirely when checkpoint_interval < 0).
        if checkpoint_interval >= 0:
            save(trainer, os.path.join(log_dir, "checkpoints", "final"))

        # Final evaluation of imitation policy.
        sample_until_eval = util.rollout.min_episodes(n_episodes_eval)
        trajs = util.rollout.generate_trajectories(
            trainer.gen_policy,
            trainer.venv_test,
            sample_until=sample_until_eval)
        results = {
            "imit_stats": util.rollout.rollout_stats(trajs),
            "expert_stats": expert_stats,
        }
        return results
# Example No. 3
# 0
def train_and_plot(
    _seed: int,
    env_name: str,
    rollout_glob: str,
    log_dir: str,
    *,
    n_epochs: int = 100,
    n_epochs_per_plot: Optional[float] = None,
    n_disc_steps_per_epoch: int = 10,
    n_gen_steps_per_epoch: int = 10000,
    n_episodes_per_reward_data: int = 5,
    n_episodes_eval: int = 50,
    checkpoint_interval: int = 5,
    interactive: bool = True,
    expert_policy=None,
    init_trainer_kwargs: Optional[dict] = None,
) -> Dict[str, float]:
    """Alternate between training the generator and discriminator.

    Every epoch runs `n_disc_steps_per_epoch` discriminator steps followed by
    `n_gen_steps_per_epoch` generator steps, and records the discriminator
    loss after each phase. Whenever a plot is due (see `n_epochs_per_plot`):
      - Plot discriminator loss during discriminator training steps in green
        and discriminator loss during generator training steps in red.
      - Plot the performance of the generator policy versus the performance of
        a random policy. Also plot the performance of an expert policy if that
        is provided in the arguments.

    Checkpoints:
      - Checkpoints are written with `save(trainer, path)` to
        f"{log_dir}/checkpoints/{step}", where step is the zero-padded epoch
        number (e.g. "00005") or "final". DiscrimNets are saved under
        `.../discrim/` and generator policies under `.../gen_policy/`
        (layout determined by `save`).

    Args:
        _seed: Random seed.
        env_name: The environment to train in.
        rollout_glob: Glob pattern, forwarded to `init_trainer`, selecting the
            expert rollout files.
        log_dir: Directory to save models and other logging to. Created if it
            does not exist.
        n_epochs: The number of epochs to train. Each epoch consists of
            `n_disc_steps_per_epoch` discriminator steps followed by
            `n_gen_steps_per_epoch` generator steps.
        n_epochs_per_plot: An optional number, greater than or equal to 1. The
            (possibly fractional) number of epochs between each plot. Data for
            epoch 0 is collected before any training.
            If `n_epochs_per_plot is None`, then don't make any plots.
        n_disc_steps_per_epoch: The number of discriminator update steps during
            every training epoch.
        n_gen_steps_per_epoch: The number of generator update steps during
            every training epoch.
        n_episodes_per_reward_data: The number of episodes averaged over when
            calculating the average episode reward of a policy for the
            performance plots.
        n_episodes_eval: The number of episodes to average over when
            calculating the average ground truth reward return of the final
            policy.
        checkpoint_interval: Save the discriminator and generator models every
            `checkpoint_interval` epochs and after training is complete. If
            <=0, then only save weights after training is complete.
        interactive: Passed to `_savefig_timestamp` for every figure. Per the
            original documentation, figures are always saved and, if True,
            also shown as they are created.
        expert_policy (BasePolicy or BaseRLModel, optional): If provided, then
            also plot the performance of this expert policy.
        init_trainer_kwargs: Keyword arguments passed to `init_trainer`,
            used to initialize the trainer. Defaults to no extra arguments.

    Returns:
        results: A dictionary with two keys, "mean" and "std_err". The
            corresponding values are the mean and standard error of
            ground truth episode return for the imitation learning algorithm.
    """
    assert n_epochs_per_plot is None or n_epochs_per_plot >= 1
    # `None` default instead of a shared mutable `{}` default.
    if init_trainer_kwargs is None:
        init_trainer_kwargs = {}

    # Make sure the log directory exists before anything (including
    # `init_trainer`) is handed its path.
    tf.logging.info("Logging to %s", log_dir)
    os.makedirs(log_dir, exist_ok=True)

    with util.make_session():
        trainer = init_trainer(env_name,
                               rollout_glob=rollout_glob,
                               seed=_seed,
                               log_dir=log_dir,
                               **init_trainer_kwargs)

        sb_logger.configure(folder=osp.join(log_dir, 'generator'),
                            format_strs=['tensorboard', 'stdout'])

        plot_idx = 0
        gen_data = ([], [])
        disc_data = ([], [])

        def disc_plot_add_data(gen_mode: bool = False):
            """Evaluate and record the discriminator loss for plotting later.

            Args:
                gen_mode: Whether the generator or the discriminator is active.
                    We use this to color the data points.
            """
            nonlocal plot_idx
            mode = "gen" if gen_mode else "dis"
            X, Y = gen_data if gen_mode else disc_data
            # Divide by two since we get two data points (gen and disc) per
            # epoch.
            X.append(plot_idx / 2)
            Y.append(trainer.eval_disc_loss())
            tf.logging.info("plot idx ({}): {} disc loss: {}".format(
                mode, plot_idx, Y[-1]))
            plot_idx += 1

        def disc_plot_show():
            """Render a plot of discriminator loss vs. training epoch number."""
            plt.scatter(disc_data[0],
                        disc_data[1],
                        c='g',
                        alpha=0.7,
                        s=4,
                        label="discriminator loss (dis step)")
            plt.scatter(gen_data[0],
                        gen_data[1],
                        c='r',
                        alpha=0.7,
                        s=4,
                        label="discriminator loss (gen step)")
            plt.title("Discriminator loss")
            plt.legend()
            _savefig_timestamp("plot_fight_loss_disc", interactive)

        gen_ep_reward = defaultdict(list)
        rand_ep_reward = defaultdict(list)
        exp_ep_reward = defaultdict(list)

        def ep_reward_plot_add_data(env, name):
            """Calculate and record average episode returns under `name`."""
            gen_policy = trainer.gen_policy
            gen_ret = util.rollout.mean_return(
                gen_policy, env, n_episodes=n_episodes_per_reward_data)
            gen_ep_reward[name].append(gen_ret)
            tf.logging.info("generator return: {}".format(gen_ret))

            # A freshly initialized (untrained) policy as a baseline.
            rand_policy = util.init_rl(trainer.env)
            rand_ret = util.rollout.mean_return(
                rand_policy, env, n_episodes=n_episodes_per_reward_data)
            rand_ep_reward[name].append(rand_ret)
            tf.logging.info("random return: {}".format(rand_ret))

            if expert_policy is not None:
                exp_ret = util.rollout.mean_return(
                    expert_policy, env, n_episodes=n_episodes_per_reward_data)
                exp_ep_reward[name].append(exp_ret)
                tf.logging.info("exp return: {}".format(exp_ret))

        def ep_reward_plot_show():
            """Render and show average episode reward plots."""
            for name in gen_ep_reward:
                plt.title(name + " Performance")
                plt.xlabel("epochs")
                plt.ylabel("Average reward per episode (n={})".format(
                    n_episodes_per_reward_data))
                plt.plot(gen_ep_reward[name],
                         label="avg gen ep reward",
                         c="red")
                plt.plot(rand_ep_reward[name],
                         label="avg random ep reward",
                         c="black")
                # Only draw the expert curve when expert data exists; `.get`
                # avoids inserting an empty list into the defaultdict.
                if exp_ep_reward.get(name):
                    plt.plot(exp_ep_reward[name],
                             label="avg exp ep reward",
                             c="blue")
                plt.legend()
                _savefig_timestamp("plot_fight_epreward_gen", interactive)

        if n_epochs_per_plot is not None:
            n_plots_per_epoch = 1 / n_epochs_per_plot
        else:
            n_plots_per_epoch = None

        def should_plot_now(epoch) -> bool:
            """For positive epochs, return True if a plot should be rendered now.

            This also controls the frequency at which `ep_reward_plot_add_data`
            is called, because generating those rollouts is too expensive to
            perform every timestep.
            """
            assert epoch >= 1
            if n_plots_per_epoch is None:
                return False
            plot_num = math.floor(n_plots_per_epoch * epoch)
            prev_plot_num = math.floor(n_plots_per_epoch * (epoch - 1))
            # n_epochs_per_plot >= 1 guarantees at most one plot per epoch.
            assert abs(plot_num - prev_plot_num) <= 1
            return plot_num != prev_plot_num

        # Collect data for epoch 0.
        if n_epochs_per_plot is not None:
            disc_plot_add_data(False)
            ep_reward_plot_add_data(trainer.env, "Ground Truth Reward")
            ep_reward_plot_add_data(trainer.env_train, "Train Reward")
            ep_reward_plot_add_data(trainer.env_test, "Test Reward")

        # Main training loop.
        for epoch in tqdm.tqdm(range(1, n_epochs + 1), desc="epoch"):
            trainer.train_disc(n_disc_steps_per_epoch)
            disc_plot_add_data(False)
            trainer.train_gen(n_gen_steps_per_epoch)
            disc_plot_add_data(True)

            if should_plot_now(epoch):
                disc_plot_show()
                ep_reward_plot_add_data(trainer.env, "Ground Truth Reward")
                ep_reward_plot_add_data(trainer.env_train, "Train Reward")
                ep_reward_plot_add_data(trainer.env_test, "Test Reward")
                ep_reward_plot_show()

            if checkpoint_interval > 0 and epoch % checkpoint_interval == 0:
                save(trainer,
                     os.path.join(log_dir, "checkpoints", f"{epoch:05d}"))

        # Save final artifacts.
        save(trainer, os.path.join(log_dir, "checkpoints", "final"))

        # Final evaluation of imitation policy.
        stats = util.rollout.rollout_stats(trainer.gen_policy,
                                           trainer.env,
                                           n_episodes=n_episodes_eval)
        n_traj = stats["n_traj"]
        assert n_traj >= n_episodes_eval
        mean = stats["return_mean"]
        # The statistics are computed over the episodes actually sampled,
        # which may exceed the requested `n_episodes_eval`, so the standard
        # error must use that count.
        std_err = stats["return_std"] / math.sqrt(n_traj)
        print(f"[result] Mean Episode Return: {mean:.4g} ± {std_err:.3g} "
              f"(n={n_traj})")

        return dict(mean=mean, std_err=std_err)
# Example No. 4
# 0
def init_test_trainer(env_id: str, use_gail: bool, parallel: bool = False):
    """Build a trainer for `env_id` from the test rollout files.

    Args:
        env_id: Environment ID. It is also used as the filename prefix of the
            rollouts under `tests/data/rollouts/`.
        use_gail: Forwarded unchanged to `init_trainer` as `use_gail`.
        parallel: Forwarded unchanged to `init_trainer` as `parallel`.

    Returns:
        Whatever `init_trainer` returns for these arguments.
    """
    pattern = "tests/data/rollouts/{}*.pkl".format(env_id)
    return init_trainer(
        env_id=env_id,
        rollout_glob=pattern,
        use_gail=use_gail,
        parallel=parallel,
    )
# Example No. 5
# 0
def train(
    _run,
    _seed: int,
    env_name: str,
    rollout_path: str,
    n_expert_demos: Optional[int],
    log_dir: str,
    *,
    n_epochs: int,
    n_gen_steps_per_epoch: int,
    n_disc_steps_per_epoch: int,
    init_trainer_kwargs: dict,
    n_episodes_eval: int,
    plot_interval: int,
    n_plot_episodes: int,
    show_plots: bool,
    init_tensorboard: bool,
    checkpoint_interval: int = 5,
) -> dict:
    """Train an adversarial-network-based imitation learning algorithm.

    Plots (turn on using `plot_interval > 0`):
      - Plot discriminator loss after the discriminator and after the
        generator training phase of each epoch (via `_TrainVisualizer`).
      - Plot the performance of the generator policy on the ground-truth,
        train and test reward environments. The expert's mean episode return
        from the loaded demonstrations is also given to the visualizer.

    Checkpoints:
      - Checkpoints are written with `save(trainer, path)` to
        f"{log_dir}/checkpoints/{step}", where step is the zero-padded epoch
        number (e.g. "00005") or "final". DiscrimNets are saved under
        `.../discrim/` and generator policies under `.../gen_policy/`
        (layout determined by `save`).

    Args:
      _run: Run object passed to `sacred_util.build_sacred_symlink` together
        with `log_dir`.
      _seed: Random seed, forwarded to `init_trainer` as `seed`.
      env_name: The environment to train in.
      rollout_path: Path to pickle containing list of Trajectories. Used as
        expert demonstrations. The file is unpickled, so it must come from a
        trusted source.
      n_expert_demos: The number of expert trajectories to actually use
        after loading them from `rollout_path`.
        If None, then use all available trajectories.
        If `n_expert_demos` is an `int`, then use exactly `n_expert_demos`
        trajectories, raising `ValueError` if there aren't enough
        trajectories. If there are surplus trajectories, then use the
        first `n_expert_demos` trajectories and drop the rest.
      log_dir: Directory to save models and other logging to.

      n_epochs: The number of epochs to train. Each epoch consists of
        `n_disc_steps_per_epoch` discriminator steps followed by
        `n_gen_steps_per_epoch` generator steps.
      n_gen_steps_per_epoch: The number of generator update steps during every
        training epoch.
      n_disc_steps_per_epoch: The number of discriminator update steps during
        every training epoch.
      init_trainer_kwargs: Keyword arguments passed to `init_trainer`,
        used to initialize the trainer. This dict is not modified.
      n_episodes_eval: The number of episodes to average over when calculating
        the average episode reward of the imitation policy for return.

      plot_interval: The number of epochs between each plot. (If nonpositive,
        then plots are disabled).
      n_plot_episodes: The number of episodes averaged over when
        calculating the average episode reward of a policy for the performance
        plots.
      show_plots: Forwarded to `_TrainVisualizer`. Per the original
        documentation, figures are always saved to `f"{log_dir}/plots/*.png"`
        and, if True, also shown as they are created.
      init_tensorboard: If True, then write tensorboard logs to
        `{log_dir}/sb_tb` by setting `init_rl_kwargs["tensorboard_log"]` in a
        copy of `init_trainer_kwargs`.

      checkpoint_interval: Save the discriminator and generator models every
        `checkpoint_interval` epochs and after training is complete. If <=0,
        then only save weights after training is complete.

    Returns:
      A dictionary with two keys. "imit_stats" gives the return value of
        `rollout_stats()` on rollouts test-reward-wrapped
        environment, using the final policy (remember that the ground-truth
        reward can be recovered from the "monitor_return" key).
        "expert_stats" gives the return value of `rollout_stats()` on the
        expert demonstrations loaded from `rollout_path`.

    Raises:
      ValueError: If `n_expert_demos` exceeds the number of trajectories
        stored in `rollout_path`.
    """
    tf.logging.info("Logging to %s", log_dir)
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    # Calculate stats for expert rollouts. Used for plot and return value.
    with open(rollout_path, "rb") as f:
        expert_trajs = pickle.load(f)

    if n_expert_demos is not None:
        # An explicit check instead of `assert` so validation still happens
        # when Python runs with -O.
        if len(expert_trajs) < n_expert_demos:
            raise ValueError(
                f"Requested n_expert_demos={n_expert_demos}, but only "
                f"{len(expert_trajs)} trajectories are in '{rollout_path}'.")
        expert_trajs = expert_trajs[:n_expert_demos]

    expert_stats = util.rollout.rollout_stats(expert_trajs)

    with util.make_session():
        sb_logger.configure(folder=osp.join(log_dir, 'generator'),
                            format_strs=['tensorboard', 'stdout'])

        # Work on copies so the caller's dicts (e.g. config values shared
        # across runs) are never mutated in place.
        trainer_kwargs = dict(init_trainer_kwargs)
        if init_tensorboard:
            sb_tensorboard_dir = osp.join(log_dir, "sb_tb")
            init_rl_kwargs = dict(trainer_kwargs.get("init_rl_kwargs") or {})
            init_rl_kwargs["tensorboard_log"] = sb_tensorboard_dir
            trainer_kwargs["init_rl_kwargs"] = init_rl_kwargs

        trainer = init_trainer(env_name,
                               expert_trajs,
                               seed=_seed,
                               log_dir=log_dir,
                               **trainer_kwargs)

        if plot_interval > 0:
            visualizer = _TrainVisualizer(
                trainer=trainer,
                show_plots=show_plots,
                n_episodes_per_reward_data=n_plot_episodes,
                log_dir=log_dir,
                expert_mean_ep_reward=expert_stats["return_mean"])
        else:
            visualizer = None

        # Main training loop.
        for epoch in tqdm.tqdm(range(1, n_epochs + 1), desc="epoch"):
            trainer.train_disc(n_disc_steps_per_epoch)
            if visualizer:
                visualizer.add_data_disc_loss(False)

            trainer.train_gen(n_gen_steps_per_epoch)
            if visualizer:
                visualizer.add_data_disc_loss(True)

            if visualizer and epoch % plot_interval == 0:
                visualizer.plot_disc_loss()
                visualizer.add_data_ep_reward(trainer.venv,
                                              "Ground Truth Reward")
                visualizer.add_data_ep_reward(trainer.venv_train,
                                              "Train Reward")
                visualizer.add_data_ep_reward(trainer.venv_test, "Test Reward")
                visualizer.plot_ep_reward()

            if checkpoint_interval > 0 and epoch % checkpoint_interval == 0:
                save(trainer,
                     os.path.join(log_dir, "checkpoints", f"{epoch:05d}"))

        # Save final artifacts.
        save(trainer, os.path.join(log_dir, "checkpoints", "final"))

        # Final evaluation of imitation policy.
        sample_until_eval = util.rollout.min_episodes(n_episodes_eval)
        trajs = util.rollout.generate_trajectories(
            trainer.gen_policy,
            trainer.venv_test,
            sample_until=sample_until_eval)
        results = {
            "imit_stats": util.rollout.rollout_stats(trajs),
            "expert_stats": expert_stats,
        }

        return results
# Example No. 6
# 0
def init_test_trainer(tmpdir: str, use_gail: bool, parallel: bool = False):
    """Build a CartPole-v1 trainer that logs into a temporary directory.

    Args:
        tmpdir: Directory forwarded to `init_trainer` as `log_dir`.
        use_gail: Forwarded unchanged to `init_trainer` as `use_gail`.
        parallel: Forwarded unchanged to `init_trainer` as `parallel`.

    Returns:
        Whatever `init_trainer` returns for these arguments.
    """
    rollout_path = "tests/data/expert_models/cartpole_0/rollouts/final.pkl"
    expert_trajs = types.load(rollout_path)
    trainer_options = {
        "log_dir": tmpdir,
        "use_gail": use_gail,
        "parallel": parallel,
    }
    return init_trainer("CartPole-v1", expert_trajs, **trainer_options)