def test_transfer_learning(tmpdir):
    """Smoke test for transfer learning.

    Trains a dummy AIRL reward (``cartpole``/``airl``/``fast`` configs), then
    reloads the saved discriminator as the reward for ``expert_demos_ex``.
    """
    train_log_dir = osp.join(tmpdir, "train")
    train_run = train_ex.run(
        named_configs=["cartpole", "airl", "fast"],
        config_updates={
            "rollout_path": "tests/data/expert_models/cartpole_0/rollouts/final.pkl",
            "log_dir": train_log_dir,
        },
    )
    assert train_run.status == "COMPLETED"
    _check_train_ex_result(train_run.result)
    _check_rollout_stats(train_run.result["imit_stats"])

    # Second stage: learn a policy from the discriminator saved by stage one.
    demos_log_dir = osp.join(tmpdir, "expert_demos")
    saved_discrim = osp.join(train_log_dir, "checkpoints", "final", "discrim")
    demos_run = expert_demos_ex.run(
        named_configs=["cartpole", "fast"],
        config_updates={
            "log_dir": demos_log_dir,
            "reward_type": "DiscrimNet",
            "reward_path": saved_discrim,
        },
    )
    assert demos_run.status == "COMPLETED"
    _check_rollout_stats(demos_run.result)
def test_transfer_learning():
    """Smoke test for transfer learning.

    Inside a temporary directory, train a dummy AIRL reward and then load the
    saved discriminator as a ``DiscrimNetAIRL`` reward for ``expert_demos_ex``.
    """
    with tempfile.TemporaryDirectory(prefix='imitation-transfer') as workdir:
        train_log_dir = osp.join(workdir, "train")
        train_updates = dict(
            rollout_glob="tests/data/rollouts/CartPole*.pkl",
            log_dir=train_log_dir,
        )
        train_run = train_ex.run(
            named_configs=['cartpole', 'airl', 'fast'],
            config_updates=train_updates,
        )
        assert train_run.status == 'COMPLETED'

        # Reuse the discriminator checkpoint written by the training run.
        demos_updates = dict(
            log_dir=osp.join(workdir, "expert_demos"),
            reward_type='DiscrimNetAIRL',
            reward_path=osp.join(train_log_dir, "checkpoints", "final", "discrim"),
        )
        demos_run = expert_demos_ex.run(
            named_configs=['cartpole', 'fast'],
            config_updates=demos_updates,
        )
        assert demos_run.status == 'COMPLETED'
def test_train_adversarial(tmpdir):
    """Smoke test for imitation.scripts.train_adversarial.

    Runs the ``cartpole``/``gail``/``fast`` configs with TensorBoard enabled
    and plot / extra-episode-data intervals set to 1, then checks the result.
    """
    overrides = dict(
        log_root=tmpdir,
        rollout_path="tests/data/expert_models/cartpole_0/rollouts/final.pkl",
        init_tensorboard=True,
        plot_interval=1,
        extra_episode_data_interval=1,
    )
    experiment_run = train_ex.run(
        named_configs=["cartpole", "gail", "fast"],
        config_updates=overrides,
    )
    assert experiment_run.status == "COMPLETED"
    _check_train_ex_result(experiment_run.result)
def test_train_adversarial(tmpdir):
    """Smoke test for imitation.scripts.train_adversarial.

    GAIL on the ``cartpole`` config in ``fast`` mode, logging under ``tmpdir``
    with TensorBoard on and plot / extra-episode-data intervals of 1.
    """
    expert_rollouts = "tests/data/expert_models/cartpole_0/rollouts/final.pkl"
    outcome = train_ex.run(
        named_configs=['cartpole', 'gail', 'fast'],
        config_updates={
            'log_root': tmpdir,
            'rollout_path': expert_rollouts,
            'init_tensorboard': True,
            'plot_interval': 1,
            'extra_episode_data_interval': 1,
        },
    )
    assert outcome.status == 'COMPLETED'
    _check_train_ex_result(outcome.result)
def test_train_adversarial():
    """Smoke test for imitation.scripts.train_adversarial.

    Runs GAIL (``cartpole``/``gail``/``fast``) on the CartPole test rollouts,
    logging into a temporary directory that is removed afterwards.
    """
    # Rollouts are small, so decrease the discriminator samples per buffer
    # to avoid a warning.
    trainer_overrides = {'n_disc_samples_per_buffer': 50}
    with tempfile.TemporaryDirectory(prefix='imitation-train') as log_root:
        config_updates = dict(
            init_trainer_kwargs=dict(trainer_kwargs=trainer_overrides),
            log_root=log_root,
            rollout_glob="tests/data/rollouts/CartPole*.pkl",
        )
        outcome = train_ex.run(
            named_configs=['cartpole', 'gail', 'fast'],
            config_updates=config_updates,
        )
        assert outcome.status == 'COMPLETED'
def test_train_adversarial(tmpdir):
    """Smoke test for imitation.scripts.train_adversarial.

    Uses the ``cartpole``/``gail``/``fast``/``plots`` named configs with
    TensorBoard enabled, then validates the experiment's result.
    """
    # Rollouts are small, so decrease the discriminator samples per buffer
    # to avoid a warning.
    trainer_kwargs = dict(n_disc_samples_per_buffer=50)
    overrides = dict(
        init_trainer_kwargs=dict(trainer_kwargs=trainer_kwargs),
        log_root=tmpdir,
        rollout_path="tests/data/expert_models/cartpole_0/rollouts/final.pkl",
        init_tensorboard=True,
    )
    experiment_run = train_ex.run(
        named_configs=['cartpole', 'gail', 'fast', 'plots'],
        config_updates=overrides,
    )
    assert experiment_run.status == 'COMPLETED'
    _check_train_ex_result(experiment_run.result)
def test_analyze_imitation(tmpdir: str, run_names: List[str]) -> None:
    """Check that ``analyze_imitation`` finds the expected number of logs.

    Runs ``train_ex`` once per entry of ``run_names`` (entries may repeat),
    storing the sacred logs under ``tmpdir``. Then runs the
    ``analyze_imitation`` command filtered by each distinct run name, and
    once unfiltered, asserting on the number of result rows each time.

    Args:
        tmpdir: Directory used as sacred file storage and for the CSV output.
        run_names: Sacred ``--name`` for each training run to generate.
    """
    sacred_logs_dir = tmpdir
    # Loop-invariant: every training run imitates the same expert rollouts.
    rollout_path = "tests/data/expert_models/cartpole_0/rollouts/final.pkl"

    # Generate sacred logs. Other training output goes to a separate,
    # per-run temporary directory so it is deleted when the run ends.
    for run_name in run_names:
        with tempfile.TemporaryDirectory(prefix="junk") as junkdir:
            run = train_ex.run(
                named_configs=["fast", "cartpole"],
                config_updates=dict(
                    rollout_path=rollout_path,
                    log_dir=junkdir,
                    # NOTE(review): presumably disables checkpointing — confirm.
                    checkpoint_interval=-1,
                ),
                options={
                    "--name": run_name,
                    "--file_storage": sacred_logs_dir,
                },
            )
            assert run.status == "COMPLETED"

    def check(run_name: Optional[str], count: int) -> None:
        """Assert the analysis (optionally filtered by name) yields ``count`` rows."""
        run = analysis_ex.run(
            command_name="analyze_imitation",
            config_updates=dict(
                source_dir=sacred_logs_dir,
                run_name=run_name,
                csv_output_path=osp.join(tmpdir, "analysis.csv"),
                verbose=True,
            ),
        )
        assert run.status == "COMPLETED"
        df = pd.DataFrame(run.result)
        assert df.shape[0] == count

    # Each distinct name must match exactly as many logs as times it was used.
    for run_name, count in Counter(run_names).items():
        check(run_name, count)
    check(None, len(run_names))  # Check total number of logs.