def _generate_test_rollouts(tmpdir: str, env_named_config: str) -> str:
    """Run ``expert_demos_ex`` for one environment and locate its rollouts.

    Args:
        tmpdir: Directory used as the experiment's ``log_dir``.
        env_named_config: Sacred named config selecting the environment; it is
            combined with the ``"fast"`` named config.

    Returns:
        Absolute path of ``<tmpdir>/rollouts/final.pkl``. The file is expected
        to have been written by the run; this helper does not check that it
        exists, and it does not inspect the run's status.
    """
    # NOTE(review): presumably rollout_save_interval=0 disables intermediate
    # rollout snapshots so that only the final one is written — confirm.
    overrides = {"rollout_save_interval": 0, "log_dir": tmpdir}
    expert_demos_ex.run(
        named_configs=[env_named_config, "fast"],
        config_updates=overrides,
    )
    return osp.abspath(f"{tmpdir}/rollouts/final.pkl")
def test_transfer_learning():
    """Smoke-test reusing a trained AIRL discriminator as a reward.

    Step 1 runs ``train_ex`` with the ``cartpole``/``airl``/``fast`` named
    configs on the bundled CartPole rollouts. Step 2 runs ``expert_demos_ex``
    with ``reward_type='DiscrimNetAIRL'``. Its ``reward_path`` points at the
    ``checkpoints/final/discrim`` directory under the step-1 log dir. Both runs
    must finish with status ``'COMPLETED'``.

    NOTE(review): another ``test_transfer_learning`` (taking the ``tmpdir``
    fixture) is defined immediately after this one. If both live in the same
    module, the later definition rebinds the name and pytest collects only
    that one. Confirm which version is meant to be kept.
    """
    with tempfile.TemporaryDirectory(prefix='imitation-transfer', ) as tmpdir:
        train_log_dir = osp.join(tmpdir, "train")
        train_overrides = {
            "rollout_glob": "tests/data/rollouts/CartPole*.pkl",
            "log_dir": train_log_dir,
        }
        train_run = train_ex.run(
            named_configs=['cartpole', 'airl', 'fast'],
            config_updates=train_overrides,
        )
        assert train_run.status == 'COMPLETED'

        # Reload the discriminator checkpoint written under the training log dir.
        discrim_checkpoint = osp.join(
            train_log_dir, "checkpoints", "final", "discrim")
        demos_overrides = {
            "log_dir": osp.join(tmpdir, "expert_demos"),
            "reward_type": 'DiscrimNetAIRL',
            "reward_path": discrim_checkpoint,
        }
        demos_run = expert_demos_ex.run(
            named_configs=['cartpole', 'fast'],
            config_updates=demos_overrides,
        )
        assert demos_run.status == 'COMPLETED'
def test_transfer_learning(tmpdir):
    """Smoke-test reusing a trained AIRL discriminator as a reward.

    Step 1 trains with ``train_ex`` (``cartpole``/``airl``/``fast``) on the
    stored ``cartpole_0`` expert rollouts. It checks the run result and its
    ``imit_stats``. Step 2 runs ``expert_demos_ex`` with
    ``reward_type="DiscrimNet"``, reading the discriminator from
    ``checkpoints/final/discrim`` under the step-1 log dir. It then checks the
    rollout statistics returned by that run.

    NOTE(review): another ``test_transfer_learning`` (no fixture, using
    ``tempfile``) is defined immediately before this one. If both live in the
    same module, this later definition is the one pytest collects. Confirm the
    earlier one can be removed.
    """
    train_log_dir = osp.join(tmpdir, "train")
    train_run = train_ex.run(
        named_configs=["cartpole", "airl", "fast"],
        config_updates={
            "rollout_path":
                "tests/data/expert_models/cartpole_0/rollouts/final.pkl",
            "log_dir": train_log_dir,
        },
    )
    assert train_run.status == "COMPLETED"
    _check_train_ex_result(train_run.result)
    _check_rollout_stats(train_run.result["imit_stats"])

    # Reload the discriminator checkpoint written under the training log dir.
    discrim_checkpoint = osp.join(
        train_log_dir, "checkpoints", "final", "discrim")
    demos_run = expert_demos_ex.run(
        named_configs=["cartpole", "fast"],
        config_updates={
            "log_dir": osp.join(tmpdir, "expert_demos"),
            "reward_type": "DiscrimNet",
            "reward_path": discrim_checkpoint,
        },
    )
    assert demos_run.status == "COMPLETED"
    _check_rollout_stats(demos_run.result)
def test_expert_demos_main(tmpdir):
    """Smoke test for imitation.scripts.expert_demos.rollouts_and_policy.

    Runs the default command of ``expert_demos_ex`` with the ``cartpole`` and
    ``fast`` named configs, logging under ``tmpdir``. Requires the run to
    complete and to return a ``dict``.

    NOTE(review): another ``test_expert_demos_main`` (no fixture, using
    ``tempfile``) is defined immediately after this one. If both live in the
    same module, the later definition rebinds the name and this version is
    never collected. Confirm which version is intended.
    """
    main_run = expert_demos_ex.run(
        named_configs=["cartpole", "fast"],
        config_updates={"log_root": tmpdir},
    )
    assert main_run.status == "COMPLETED"
    assert isinstance(main_run.result, dict)
def test_expert_demos_main():
    """Smoke test for imitation.scripts.expert_demos.rollouts_and_policy

    Runs the default command of ``expert_demos_ex`` with the ``cartpole`` and
    ``fast`` named configs, logging into a throwaway temporary directory, and
    requires the run to complete.

    NOTE(review): this redefines ``test_expert_demos_main``, which is also
    defined immediately before (taking the ``tmpdir`` fixture). In one module
    this later definition wins and the fixture-based one is never collected.
    Confirm which version is intended.
    """
    configs = ['cartpole', 'fast']
    with tempfile.TemporaryDirectory(
            prefix='imitation-data_collect-main', ) as tmpdir:
        main_run = expert_demos_ex.run(
            named_configs=configs,
            config_updates={"log_root": tmpdir},
        )
        assert main_run.status == 'COMPLETED'
def test_expert_demos_rollouts_from_policy(tmpdir):
    """Smoke test for imitation.scripts.expert_demos.rollouts_from_policy.

    Invokes the ``rollouts_from_policy`` command with the stored
    ``cartpole_0`` policy. Rollouts are saved to ``<tmpdir>/rollouts/test.pkl``,
    and the run must complete.

    NOTE(review): another ``test_expert_demos_rollouts_from_policy`` (no
    fixture, using ``tempfile``) is defined immediately after this one. If both
    live in the same module, the later definition rebinds the name and this
    version is never collected. Confirm which version is intended.
    """
    save_path = osp.join(tmpdir, "rollouts", "test.pkl")
    overrides = {
        "log_root": tmpdir,
        "rollout_save_path": save_path,
        "policy_path": "tests/data/expert_models/cartpole_0/policies/final/",
    }
    policy_run = expert_demos_ex.run(
        command_name="rollouts_from_policy",
        named_configs=["cartpole", "fast"],
        config_updates=overrides,
    )
    assert policy_run.status == "COMPLETED"
def test_expert_demos_rollouts_from_policy():
    """Smoke test for imitation.scripts.expert_demos.rollouts_from_policy

    Invokes the ``rollouts_from_policy`` command with the policy at
    ``expert_models/PPO2_CartPole-v1_0`` (a cwd-relative path). Rollouts are
    saved under ``<tmp>/rollouts`` in a throwaway temporary directory, and the
    run must complete.

    NOTE(review): this redefines ``test_expert_demos_rollouts_from_policy``,
    which is also defined immediately before (taking the ``tmpdir`` fixture and
    a ``rollout_save_path`` config key). In one module this later definition
    wins. Confirm that ``rollout_save_dir`` and this policy path are still
    valid, or drop this version.
    """
    with tempfile.TemporaryDirectory(
            prefix='imitation-data_collect-policy', ) as tmpdir:
        overrides = {
            "log_root": tmpdir,
            "rollout_save_dir": osp.join(tmpdir, "rollouts"),
            "policy_path": "expert_models/PPO2_CartPole-v1_0",
        }
        policy_run = expert_demos_ex.run(
            command_name="rollouts_from_policy",
            named_configs=['cartpole', 'fast'],
            config_updates=overrides,
        )
        assert policy_run.status == 'COMPLETED'