def test_reward_overwrite():
    """Test that the reward wrapper actually overwrites base rewards."""
    env_name = "Pendulum-v0"
    num_envs = 3
    env = util.make_vec_env(env_name, num_envs)
    reward_fn = FunkyReward()
    wrapped_env = reward_wrapper.RewardVecEnvWrapper(env, reward_fn)
    policy = RandomPolicy(env.observation_space, env.action_space)
    sample_until = rollout.min_episodes(10)
    default_stats = rollout.rollout_stats(
        rollout.generate_trajectories(policy, env, sample_until))
    wrapped_stats = rollout.rollout_stats(
        rollout.generate_trajectories(policy, wrapped_env, sample_until))

    # Pendulum-v0 always has negative rewards.
    assert default_stats["return_max"] < 0

    # Our reward gives between 1 * traj_len and num_envs * traj_len return
    # (trajectories all have a constant length of 200 in Pendulum).
    steps = wrapped_stats["len_mean"]
    assert wrapped_stats["return_min"] == 1 * steps
    assert wrapped_stats["return_max"] == num_envs * steps

    # Check that the underlying environment reward (stored in the info dict)
    # is negative (all Pendulum rewards are negative), while the overwritten
    # reward is non-negative.
    rand_act, _, _, _ = policy.step(wrapped_env.reset())
    _, rew, _, infos = wrapped_env.step(rand_act)
    assert np.all(rew >= 0)
    assert np.all([info_dict["wrapped_env_rew"] < 0 for info_dict in infos])

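# The test above relies on two helpers defined elsewhere in the test module.
# A minimal sketch of what they might look like, inferred purely from how the
# test uses them (these reconstructions are assumptions, not the actual
# definitions in the test suite):
import numpy as np


class FunkyReward:
    """Reward function returning each environment's (1-indexed) index."""

    def __call__(self, obs, act, next_obs, steps=None):
        # Environment i always receives reward i + 1, so with num_envs
        # environments the per-episode return ranges from 1 * len to
        # num_envs * len, matching the assertions in the test.
        return np.arange(len(obs)) + 1


class RandomPolicy:
    """Policy that samples uniformly from the action space."""

    def __init__(self, observation_space, action_space):
        self.observation_space = observation_space
        self.action_space = action_space

    def step(self, obs):
        actions = np.stack([self.action_space.sample() for _ in obs])
        # Mimic the (actions, values, states, neglogp) interface used above.
        return actions, None, None, None
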
def eval_policy(
    rl_algo: Union[base_class.BaseAlgorithm, policies.BasePolicy],
    venv: vec_env.VecEnv,
    n_episodes_eval: int,
) -> Mapping[str, float]:
    """Evaluation of an imitation-learned policy.

    Args:
        rl_algo: Algorithm to evaluate.
        venv: Environment to evaluate on.
        n_episodes_eval: The number of episodes to average over when
            calculating the average episode reward of the imitation policy
            for return.

    Returns:
        The return value of `rollout_stats()` on rollouts of `rl_algo` in
        `venv` (the ground-truth reward can be recovered from the
        "monitor_return" key).
    """
    sample_until_eval = rollout.make_min_episodes(n_episodes_eval)
    trajs = rollout.generate_trajectories(
        rl_algo,
        venv,
        sample_until=sample_until_eval,
    )
    return rollout.rollout_stats(trajs)

def test_policy(self, *, min_episodes: int = 10) -> dict:
    """Test current imitation policy on environment & give some rollout stats.

    Args:
        min_episodes: Minimum number of rolled-out episodes.

    Returns:
        Rollout statistics collected by
        `imitation.utils.rollout.rollout_stats()`.
    """
    trajs = rollout.generate_trajectories(
        self.policy, self.env,
        sample_until=rollout.min_episodes(min_episodes))
    reward_stats = rollout.rollout_stats(trajs)
    return reward_stats

def test_rollout_stats():
    """Applying `ObsRewHalveWrapper` halves the reward mean.

    `rollout_stats` should reflect this.
    """
    env = gym.make("CartPole-v1")
    env = bench.Monitor(env, None)
    env = ObsRewHalveWrapper(env)
    venv = vec_env.DummyVecEnv([lambda: env])

    with serialize.load_policy("zero", "UNUSED", venv) as policy:
        trajs = rollout.generate_trajectories(policy, venv,
                                              rollout.min_episodes(10))
    s = rollout.rollout_stats(trajs)

    np.testing.assert_allclose(s["return_mean"], s["monitor_return_mean"] / 2)
    np.testing.assert_allclose(s["return_std"], s["monitor_return_std"] / 2)
    np.testing.assert_allclose(s["return_min"], s["monitor_return_min"] / 2)
    np.testing.assert_allclose(s["return_max"], s["monitor_return_max"] / 2)

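# `ObsRewHalveWrapper` is a test helper defined elsewhere in the test module.
# A minimal sketch consistent with how the test uses it (wrapping after
# `bench.Monitor`, so the Monitor records the unmodified reward while the
# generated trajectories see the halved one). This is an assumed
# reconstruction, not the helper's actual definition:
import gym


class ObsRewHalveWrapper(gym.Wrapper):
    """Halve observations and rewards before they reach the caller."""

    def reset(self, **kwargs):
        return self.env.reset(**kwargs) / 2

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        return obs / 2, rew / 2, done, info
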
def test_policy(self, *, n_trajectories=10, true_reward=True):
    """Test current imitation policy on environment & give some rollout stats.

    Args:
        n_trajectories (int): number of rolled-out trajectories.
        true_reward (bool): should this use ground truth reward from
            underlying environment (True), or imitation reward (False)?

    Returns:
        dict: rollout statistics collected by
            `imitation.utils.rollout.rollout_stats()`.
    """
    self.imitation_trainer.set_env(self.venv)
    trajs = rollout.generate_trajectories(
        self.imitation_trainer,
        self.venv if true_reward else self.wrapped_env,
        sample_until=rollout.min_episodes(n_trajectories),
    )
    reward_stats = rollout.rollout_stats(trajs)
    return reward_stats

def rollouts_and_policy(
    _run,
    _seed: int,
    env_name: str,
    total_timesteps: int,
    *,
    log_dir: str,
    num_vec: int,
    parallel: bool,
    max_episode_steps: Optional[int],
    normalize: bool,
    normalize_kwargs: dict,
    init_rl_kwargs: dict,
    n_episodes_eval: int,
    reward_type: Optional[str],
    reward_path: Optional[str],
    rollout_save_interval: int,
    rollout_save_final: bool,
    rollout_save_n_timesteps: Optional[int],
    rollout_save_n_episodes: Optional[int],
    policy_save_interval: int,
    policy_save_final: bool,
    init_tensorboard: bool,
) -> dict:
    """Trains an expert policy from scratch and saves the rollouts and policy.

    Checkpoints:
        At applicable training steps `step` (where step is either an integer
        or "final"):

        - Policies are saved to `{log_dir}/policies/{step}/`.
        - Rollouts are saved to `{log_dir}/rollouts/{step}.pkl`.

    Args:
        env_name: The gym.Env name. Loaded as VecEnv.
        total_timesteps: Number of training timesteps in `model.learn()`.
        log_dir: The root directory to save metrics and checkpoints to.
        num_vec: Number of environments in VecEnv.
        parallel: If True, then use SubprocVecEnv. Otherwise use DummyVecEnv.
        max_episode_steps: If not None, then environments are wrapped by
            TimeLimit so that they have at most `max_episode_steps` steps per
            episode.
        normalize: If True, then rescale observations and reward.
        normalize_kwargs: kwargs for `VecNormalize`.
        init_rl_kwargs: kwargs for `init_rl`.
        n_episodes_eval: The number of episodes to average over when
            calculating the average ground truth reward return of the final
            policy.
        reward_type: If provided, then load the serialized reward of this
            type, wrapping the environment in this reward. This is useful to
            test whether a reward model transfers. For more information, see
            `imitation.rewards.serialize.load_reward`.
        reward_path: A specifier, such as a path to a file on disk, used by
            reward_type to load the reward model. For more information, see
            `imitation.rewards.serialize.load_reward`.
        rollout_save_interval: The number of training updates between
            intermediate rollout saves. If the argument is nonpositive, then
            don't save intermediate updates.
        rollout_save_final: If True, then save rollouts right after training
            is finished.
        rollout_save_n_timesteps: The minimum number of timesteps saved in
            every file. Could be more than `rollout_save_n_timesteps` because
            trajectories are saved by episode rather than by transition. Must
            set exactly one of `rollout_save_n_timesteps` and
            `rollout_save_n_episodes`.
        rollout_save_n_episodes: The number of episodes saved in every file.
            Must set exactly one of `rollout_save_n_timesteps` and
            `rollout_save_n_episodes`.
        policy_save_interval: The number of training updates between saves.
            Has the same semantics as `rollout_save_interval`.
        policy_save_final: If True, then save the policy right after training
            is finished.
        init_tensorboard: If True, then write tensorboard logs to
            `{log_dir}/sb_tb` and "output/summary/...".

    Returns:
        The return value of `rollout_stats()` using the final policy.
    """
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    sample_until = rollout.make_sample_until(rollout_save_n_timesteps,
                                             rollout_save_n_episodes)
    eval_sample_until = rollout.min_episodes(n_episodes_eval)

    with networks.make_session():
        tf.logging.set_verbosity(tf.logging.INFO)
        logger.configure(folder=osp.join(log_dir, "rl"),
                         format_strs=["tensorboard", "stdout"])

        rollout_dir = osp.join(log_dir, "rollouts")
        policy_dir = osp.join(log_dir, "policies")
        os.makedirs(rollout_dir, exist_ok=True)
        os.makedirs(policy_dir, exist_ok=True)

        if init_tensorboard:
            sb_tensorboard_dir = osp.join(log_dir, "sb_tb")
            # Convert sacred's ReadOnlyDict to dict so we can modify on next line.
            init_rl_kwargs = dict(init_rl_kwargs)
            init_rl_kwargs["tensorboard_log"] = sb_tensorboard_dir

        venv = util.make_vec_env(
            env_name,
            num_vec,
            seed=_seed,
            parallel=parallel,
            log_dir=log_dir,
            max_episode_steps=max_episode_steps,
        )

        log_callbacks = []
        with contextlib.ExitStack() as stack:
            if reward_type is not None:
                reward_fn_ctx = load_reward(reward_type, reward_path, venv)
                reward_fn = stack.enter_context(reward_fn_ctx)
                venv = RewardVecEnvWrapper(venv, reward_fn)
                log_callbacks.append(venv.log_callback)
                tf.logging.info(
                    f"Wrapped env in reward {reward_type} from {reward_path}.")

            vec_normalize = None
            if normalize:
                venv = vec_normalize = VecNormalize(venv, **normalize_kwargs)

            policy = util.init_rl(venv, verbose=1, **init_rl_kwargs)

            # Make callback to save intermediate artifacts during training.
            step = 0

            def callback(locals_: dict, _) -> bool:
                nonlocal step
                step += 1
                policy = locals_["self"]

                # TODO(adam): make logging frequency configurable
                for log_callback in log_callbacks:
                    log_callback(sb_logger)

                if rollout_save_interval > 0 and step % rollout_save_interval == 0:
                    save_path = osp.join(rollout_dir, f"{step}.pkl")
                    rollout.rollout_and_save(save_path, policy, venv, sample_until)
                if policy_save_interval > 0 and step % policy_save_interval == 0:
                    output_dir = os.path.join(policy_dir, f"{step:05d}")
                    serialize.save_stable_model(output_dir, policy, vec_normalize)

            policy.learn(total_timesteps, callback=callback)

            # Save final artifacts after training is complete.
            if rollout_save_final:
                save_path = osp.join(rollout_dir, "final.pkl")
                rollout.rollout_and_save(save_path, policy, venv, sample_until)
            if policy_save_final:
                output_dir = os.path.join(policy_dir, "final")
                serialize.save_stable_model(output_dir, policy, vec_normalize)

            # Final evaluation of expert policy.
            trajs = rollout.generate_trajectories(policy, venv, eval_sample_until)
            stats = rollout.rollout_stats(trajs)
            return stats

def eval_policy(
    _run,
    _seed: int,
    env_name: str,
    eval_n_timesteps: Optional[int],
    eval_n_episodes: Optional[int],
    num_vec: int,
    parallel: bool,
    render: bool,
    render_fps: int,
    log_dir: str,
    policy_type: str,
    policy_path: str,
    reward_type: Optional[str] = None,
    reward_path: Optional[str] = None,
    max_episode_steps: Optional[int] = None,
):
    """Rolls a policy out in an environment, collecting statistics.

    Args:
        _seed: generated by Sacred.
        env_name: Gym environment identifier.
        eval_n_timesteps: Minimum number of timesteps to evaluate for. Set
            exactly one of `eval_n_episodes` and `eval_n_timesteps`.
        eval_n_episodes: Minimum number of episodes to evaluate for. Set
            exactly one of `eval_n_episodes` and `eval_n_timesteps`.
        num_vec: Number of environments to run simultaneously.
        parallel: If True, use `SubprocVecEnv` for true parallelism;
            otherwise, uses `DummyVecEnv`.
        max_episode_steps: If not None, then environments are wrapped by
            TimeLimit so that they have at most `max_episode_steps` steps per
            episode.
        render: If True, renders interactively to the screen.
        render_fps: The target number of frames per second to render on
            screen.
        log_dir: The directory to log intermediate output to. (As of
            2019-07-19 this is just episode-by-episode reward from
            bench.Monitor.)
        policy_type: A unique identifier for the saved policy, defined in
            POLICY_CLASSES.
        policy_path: A path to the serialized policy.
        reward_type: If specified, overrides the environment reward with a
            reward of this type.
        reward_path: If reward_type is specified, the path to a serialized
            reward of `reward_type` to override the environment reward with.

    Returns:
        Return value of `imitation.util.rollout.rollout_stats()`.
    """
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("Logging to %s", log_dir)
    sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes)
    venv = util.make_vec_env(
        env_name,
        num_vec,
        seed=_seed,
        parallel=parallel,
        log_dir=log_dir,
        max_episode_steps=max_episode_steps,
    )
    if render:
        venv = InteractiveRender(venv, render_fps)
    # TODO(adam): add support for videos using VideoRecorder?

    with contextlib.ExitStack() as stack:
        if reward_type is not None:
            reward_fn_ctx = load_reward(reward_type, reward_path, venv)
            reward_fn = stack.enter_context(reward_fn_ctx)
            venv = reward_wrapper.RewardVecEnvWrapper(venv, reward_fn)
            tf.logging.info(
                f"Wrapped env in reward {reward_type} from {reward_path}.")

        with serialize.load_policy(policy_type, policy_path, venv) as policy:
            trajs = rollout.generate_trajectories(policy, venv, sample_until)
    return rollout.rollout_stats(trajs)

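# `InteractiveRender` wraps the VecEnv so that each step is rendered to the
# screen, throttled to roughly `render_fps`. It is defined elsewhere in the
# script; a minimal sketch of the idea (an assumed reconstruction, not the
# actual implementation):
import time

from stable_baselines.common.vec_env import VecEnvWrapper


class InteractiveRender(VecEnvWrapper):
    """Render each frame to screen at a target frame rate."""

    def __init__(self, venv, fps):
        super().__init__(venv)
        self.render_fps = fps

    def reset(self):
        obs = self.venv.reset()
        self.venv.render()
        return obs

    def step_wait(self):
        result = self.venv.step_wait()
        if self.render_fps > 0:
            time.sleep(1 / self.render_fps)
        self.venv.render()
        return result
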
def train(
    _run,
    _seed: int,
    algorithm: str,
    env_name: str,
    num_vec: int,
    parallel: bool,
    max_episode_steps: Optional[int],
    rollout_path: str,
    n_expert_demos: Optional[int],
    log_dir: str,
    total_timesteps: int,
    n_episodes_eval: int,
    init_tensorboard: bool,
    checkpoint_interval: int,
    gen_batch_size: int,
    init_rl_kwargs: Mapping,
    algorithm_kwargs: Mapping[str, Mapping],
    discrim_net_kwargs: Mapping[str, Mapping],
) -> dict:
    """Train an adversarial-network-based imitation learning algorithm.

    Checkpoints:
        - DiscrimNets are saved to `f"{log_dir}/checkpoints/{step}/discrim/"`,
          where step is either the training round or "final".
        - Generator policies are saved to
          `f"{log_dir}/checkpoints/{step}/gen_policy/"`.

    Args:
        _seed: Random seed.
        algorithm: A case-insensitive string determining which adversarial
            imitation learning algorithm is executed. Either "airl" or "gail".
        env_name: The environment to train in.
        num_vec: Number of `gym.Env` to vectorize.
        parallel: Whether to use "true" parallelism. If True, then use
            `SubprocVecEnv`. Otherwise, use `DummyVecEnv` which steps through
            environments serially.
        max_episode_steps: If not None, then a TimeLimit wrapper is applied to
            each environment to artificially limit the maximum number of
            timesteps in an episode.
        rollout_path: Path to pickle containing list of Trajectories. Used as
            expert demonstrations.
        n_expert_demos: The number of expert trajectories to actually use
            after loading them from `rollout_path`. If None, then use all
            available trajectories. If `n_expert_demos` is an `int`, then use
            exactly `n_expert_demos` trajectories, erroring if there aren't
            enough trajectories. If there are surplus trajectories, then use
            the first `n_expert_demos` trajectories and drop the rest.
        log_dir: Directory to save models and other logging to.
        total_timesteps: The number of transitions to sample from the
            environment during training.
        n_episodes_eval: The number of episodes to average over when
            calculating the average episode reward of the imitation policy
            for return.
        init_tensorboard: If True, then write tensorboard logs to
            `{log_dir}/sb_tb`.
        checkpoint_interval: Save the discriminator and generator models every
            `checkpoint_interval` rounds and after training is complete. If 0,
            then only save weights after training is complete. If <0, then
            don't save weights at all.
        gen_batch_size: Batch size for generator updates. Sacred automatically
            uses this to calculate `n_steps` in `init_rl_kwargs`. In the
            script body, this is only used in sanity checks.
        init_rl_kwargs: Keyword arguments for `init_rl`, the RL algorithm
            initialization utility function.
        algorithm_kwargs: Keyword arguments for the `GAIL` or `AIRL`
            constructor. Unlike a regular kwargs argument, this argument can
            only have the following keys: "shared", "airl", and "gail".

            `algorithm_kwargs["airl"]`, if it is provided, is a kwargs
            `Mapping` passed to the `AIRL` constructor when
            `algorithm == "airl"`. Likewise `algorithm_kwargs["gail"]` is
            passed to the `GAIL` constructor when `algorithm == "gail"`.

            `algorithm_kwargs["shared"]`, if provided, is passed to both the
            `AIRL` and `GAIL` constructors. Duplicate keyword argument keys
            between `algorithm_kwargs["shared"]` and `algorithm_kwargs["airl"]`
            (or "gail") lead to an error.
        discrim_net_kwargs: Keyword arguments for the `DiscrimNet`
            constructor. Unlike a regular kwargs argument, this argument can
            only have the following keys: "shared", "airl", "gail". These keys
            have the same meaning as they do in `algorithm_kwargs`.

    Returns:
        A dictionary with two keys. "imit_stats" gives the return value of
        `rollout_stats()` on rollouts of the final policy in the
        test-reward-wrapped environment (the ground-truth reward can be
        recovered from the "monitor_return" key). "expert_stats" gives the
        return value of `rollout_stats()` on the expert demonstrations loaded
        from `rollout_path`.
    """
    if gen_batch_size % num_vec != 0:
        raise ValueError(
            f"num_vec={num_vec} must evenly divide gen_batch_size={gen_batch_size}."
        )

    allowed_keys = {"shared", "gail", "airl"}
    if not discrim_net_kwargs.keys() <= allowed_keys:
        raise ValueError(
            f"Invalid discrim_net_kwargs.keys()={discrim_net_kwargs.keys()}. "
            f"Allowed keys: {allowed_keys}"
        )
    if not algorithm_kwargs.keys() <= allowed_keys:
        raise ValueError(
            f"Invalid algorithm_kwargs.keys()={algorithm_kwargs.keys()}. "
            f"Allowed keys: {allowed_keys}"
        )

    if not os.path.exists(rollout_path):
        raise ValueError(f"File at rollout_path={rollout_path} does not exist.")
    expert_trajs = types.load(rollout_path)
    if n_expert_demos is not None:
        if not len(expert_trajs) >= n_expert_demos:
            raise ValueError(
                f"Want to use n_expert_demos={n_expert_demos} trajectories, but only "
                f"{len(expert_trajs)} are available via {rollout_path}."
            )
        expert_trajs = expert_trajs[:n_expert_demos]
    expert_transitions = rollout.flatten_trajectories(expert_trajs)

    total_timesteps = int(total_timesteps)

    logging.info("Logging to %s", log_dir)
    logger.configure(log_dir, ["tensorboard", "stdout"])
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    venv = util.make_vec_env(
        env_name,
        num_vec,
        seed=_seed,
        parallel=parallel,
        log_dir=log_dir,
        max_episode_steps=max_episode_steps,
    )

    # if init_tensorboard:
    #     tensorboard_log = osp.join(log_dir, "sb_tb")
    # else:
    #     tensorboard_log = None

    gen_algo = util.init_rl(
        # FIXME(sam): ignoring tensorboard_log is a hack to prevent SB3 from
        # re-configuring the logger (SB3 issue #109). See init_rl() for details.
        # TODO(shwang): Let's get rid of init_rl after SB3 issue #109 is fixed?
        # Besides sidestepping #109, init_rl is just a stub function.
        venv,
        **init_rl_kwargs,
    )

    discrim_kwargs_shared = discrim_net_kwargs.get("shared", {})
    discrim_kwargs_algo = discrim_net_kwargs.get(algorithm, {})
    final_discrim_kwargs = dict(**discrim_kwargs_shared, **discrim_kwargs_algo)

    algorithm_kwargs_shared = algorithm_kwargs.get("shared", {})
    algorithm_kwargs_algo = algorithm_kwargs.get(algorithm, {})
    final_algorithm_kwargs = dict(
        **algorithm_kwargs_shared,
        **algorithm_kwargs_algo,
    )

    if algorithm.lower() == "gail":
        algo_cls = adversarial.GAIL
    elif algorithm.lower() == "airl":
        algo_cls = adversarial.AIRL
    else:
        raise ValueError(f"Invalid value algorithm={algorithm}.")

    trainer = algo_cls(
        venv=venv,
        expert_data=expert_transitions,
        gen_algo=gen_algo,
        log_dir=log_dir,
        discrim_kwargs=final_discrim_kwargs,
        **final_algorithm_kwargs,
    )

    def callback(round_num):
        if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
            save(trainer, os.path.join(log_dir, "checkpoints", f"{round_num:05d}"))

    trainer.train(total_timesteps, callback)

    # Save final artifacts.
    if checkpoint_interval >= 0:
        save(trainer, os.path.join(log_dir, "checkpoints", "final"))

    # Final evaluation of imitation policy.
    results = {}
    sample_until_eval = rollout.min_episodes(n_episodes_eval)
    trajs = rollout.generate_trajectories(
        trainer.gen_algo, trainer.venv_train_norm, sample_until=sample_until_eval
    )
    results["expert_stats"] = rollout.rollout_stats(expert_trajs)
    results["imit_stats"] = rollout.rollout_stats(trajs)
    return results

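# Illustration of the algorithm_kwargs layout described in the docstring
# above. The keys and the merge behaviour match the function body; the option
# names and values are hypothetical placeholders, not real constructor
# arguments:
algorithm_kwargs = {
    "shared": {"some_shared_option": 1},  # passed to both GAIL and AIRL
    "airl": {"some_airl_option": 2},      # used only when algorithm == "airl"
    "gail": {"some_gail_option": 3},      # used only when algorithm == "gail"
}

# The body merges "shared" with the per-algorithm mapping using
# dict(**shared, **per_algo), so a key appearing in both raises
# "TypeError: ... got multiple values for keyword argument ...",
# which is the duplicate-key error mentioned in the docstring.
final_kwargs = dict(
    **algorithm_kwargs["shared"],
    **algorithm_kwargs["airl"],
)
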
def eval_policy(
    _run,
    _seed: int,
    env_name: str,
    eval_n_timesteps: Optional[int],
    eval_n_episodes: Optional[int],
    num_vec: int,
    parallel: bool,
    render: bool,
    render_fps: int,
    videos: bool,
    video_kwargs: Mapping[str, Any],
    log_dir: str,
    policy_type: str,
    policy_path: str,
    reward_type: Optional[str] = None,
    reward_path: Optional[str] = None,
    max_episode_steps: Optional[int] = None,
):
    """Rolls a policy out in an environment, collecting statistics.

    Args:
        _seed: generated by Sacred.
        env_name: Gym environment identifier.
        eval_n_timesteps: Minimum number of timesteps to evaluate for. Set
            exactly one of `eval_n_episodes` and `eval_n_timesteps`.
        eval_n_episodes: Minimum number of episodes to evaluate for. Set
            exactly one of `eval_n_episodes` and `eval_n_timesteps`.
        num_vec: Number of environments to run simultaneously.
        parallel: If True, use `SubprocVecEnv` for true parallelism;
            otherwise, uses `DummyVecEnv`.
        max_episode_steps: If not None, then environments are wrapped by
            TimeLimit so that they have at most `max_episode_steps` steps per
            episode.
        render: If True, renders interactively to the screen.
        render_fps: The target number of frames per second to render on
            screen.
        videos: If True, saves videos to `log_dir`.
        video_kwargs: Keyword arguments passed through to
            `video_wrapper.VideoWrapper`.
        log_dir: The directory to log intermediate output to, such as episode
            reward.
        policy_type: A unique identifier for the saved policy, defined in
            POLICY_CLASSES.
        policy_path: A path to the serialized policy.
        reward_type: If specified, overrides the environment reward with a
            reward of this type.
        reward_path: If reward_type is specified, the path to a serialized
            reward of `reward_type` to override the environment reward with.

    Returns:
        Return value of `imitation.util.rollout.rollout_stats()`.
    """
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    logging.basicConfig(level=logging.INFO)
    logging.info("Logging to %s", log_dir)
    sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes)
    post_wrappers = ([video_wrapper_factory(log_dir, **video_kwargs)]
                     if videos else None)
    venv = util.make_vec_env(
        env_name,
        num_vec,
        seed=_seed,
        parallel=parallel,
        log_dir=log_dir,
        max_episode_steps=max_episode_steps,
        post_wrappers=post_wrappers,
    )

    try:
        if render:
            # As of July 31, 2020, DummyVecEnv rendering only works with
            # num_vec=1 due to a bug in Stable Baselines 3.
            venv = InteractiveRender(venv, render_fps)

        if reward_type is not None:
            reward_fn = load_reward(reward_type, reward_path, venv)
            venv = reward_wrapper.RewardVecEnvWrapper(venv, reward_fn)
            logging.info(
                f"Wrapped env in reward {reward_type} from {reward_path}.")

        policy = serialize.load_policy(policy_type, policy_path, venv)
        trajs = rollout.generate_trajectories(policy, venv, sample_until)
        return rollout.rollout_stats(trajs)
    finally:
        venv.close()

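# `video_wrapper_factory` produces a post-wrapper that `make_vec_env` applies
# to each underlying environment. A minimal sketch of what such a factory
# might look like; the factory signature and the `VideoWrapper` arguments are
# assumptions inferred from the call sites above, not the actual
# implementation:
import os


def video_wrapper_factory(log_dir: str, **kwargs):
    """Return a function that wraps env number `i` in a video recorder."""

    def f(env, i: int):
        # Save each environment's videos under a separate subdirectory.
        directory = os.path.join(log_dir, "videos", str(i))
        return video_wrapper.VideoWrapper(env, directory=directory, **kwargs)

    return f
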
def train(
    _run,
    _seed: int,
    env_name: str,
    rollout_path: str,
    n_expert_demos: Optional[int],
    log_dir: str,
    init_trainer_kwargs: dict,
    total_timesteps: int,
    n_episodes_eval: int,
    init_tensorboard: bool,
    checkpoint_interval: int,
) -> dict:
    """Train an adversarial-network-based imitation learning algorithm.

    Checkpoints:
        - DiscrimNets are saved to f"{log_dir}/checkpoints/{step}/discrim/",
          where step is either the training epoch or "final".
        - Generator policies are saved to
          f"{log_dir}/checkpoints/{step}/gen_policy/".

    Args:
        _seed: Random seed.
        env_name: The environment to train in.
        rollout_path: Path to pickle containing list of Trajectories. Used as
            expert demonstrations.
        n_expert_demos: The number of expert trajectories to actually use
            after loading them from `rollout_path`. If None, then use all
            available trajectories. If `n_expert_demos` is an `int`, then use
            exactly `n_expert_demos` trajectories, erroring if there aren't
            enough trajectories. If there are surplus trajectories, then use
            the first `n_expert_demos` trajectories and drop the rest.
        log_dir: Directory to save models and other logging to.
        init_trainer_kwargs: Keyword arguments passed to `init_trainer`, used
            to initialize the trainer.
        total_timesteps: The number of transitions to sample from the
            environment during training.
        n_episodes_eval: The number of episodes to average over when
            calculating the average episode reward of the imitation policy
            for return.
        init_tensorboard: If True, then write tensorboard logs to
            `{log_dir}/sb_tb`.
        checkpoint_interval: Save the discriminator and generator models every
            `checkpoint_interval` epochs and after training is complete. If 0,
            then only save weights after training is complete. If <0, then
            don't save weights at all.

    Returns:
        A dictionary with two keys. "imit_stats" gives the return value of
        `rollout_stats()` on rollouts of the final policy in the
        test-reward-wrapped environment (the ground-truth reward can be
        recovered from the "monitor_return" key). "expert_stats" gives the
        return value of `rollout_stats()` on the expert demonstrations loaded
        from `rollout_path`.
    """
    total_timesteps = int(total_timesteps)

    tf.logging.info("Logging to %s", log_dir)
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    # Calculate stats for expert rollouts. Used for the return value.
    expert_trajs = types.load(rollout_path)
    if n_expert_demos is not None:
        assert len(expert_trajs) >= n_expert_demos
        expert_trajs = expert_trajs[:n_expert_demos]
    expert_stats = rollout.rollout_stats(expert_trajs)

    with networks.make_session():
        if init_tensorboard:
            sb_tensorboard_dir = osp.join(log_dir, "sb_tb")
            kwargs = init_trainer_kwargs
            kwargs["init_rl_kwargs"] = kwargs.get("init_rl_kwargs", {})
            kwargs["init_rl_kwargs"]["tensorboard_log"] = sb_tensorboard_dir

        trainer = init_trainer(env_name,
                               expert_trajs,
                               seed=_seed,
                               log_dir=log_dir,
                               **init_trainer_kwargs)

        def callback(epoch):
            if checkpoint_interval > 0 and epoch % checkpoint_interval == 0:
                save(trainer,
                     os.path.join(log_dir, "checkpoints", f"{epoch:05d}"))

        trainer.train(total_timesteps, callback)

        # Save final artifacts.
        if checkpoint_interval >= 0:
            save(trainer, os.path.join(log_dir, "checkpoints", "final"))

        # Final evaluation of imitation policy.
        results = {}
        sample_until_eval = rollout.min_episodes(n_episodes_eval)
        trajs = rollout.generate_trajectories(trainer.gen_policy,
                                              trainer.venv_test,
                                              sample_until=sample_until_eval)
        results["imit_stats"] = rollout.rollout_stats(trajs)
        results["expert_stats"] = expert_stats
        return results

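# `save` is a checkpointing helper defined elsewhere in the script. A minimal
# sketch matching the checkpoint layout described in the docstrings above;
# the trainer attributes and `.save(...)` calls here are assumptions about
# the trainer's serialization interface, not its actual API:
import os


def save(trainer, save_path: str):
    """Save discriminator and generator policy under `save_path`."""
    os.makedirs(save_path, exist_ok=True)
    # Hypothetical serialization calls; the real script delegates to the
    # trainer's own save utilities.
    trainer.discrim.save(os.path.join(save_path, "discrim"))
    trainer.gen_policy.save(os.path.join(save_path, "gen_policy"))
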