Example #1
0
def _generate_test_rollouts(tmpdir: str, env_named_config: str) -> str:
    """Run the expert-demos experiment and return the expected rollout path.

    Args:
        tmpdir: Directory handed to the experiment as its `log_dir`.
        env_named_config: Named config combined with "fast" for the run.

    Returns:
        Absolute path of `rollouts/final.pkl` beneath `tmpdir`. The path is
        only computed here; its existence is not checked.
    """
    overrides = {
        "rollout_save_interval": 0,
        "log_dir": tmpdir,
    }
    expert_demos_ex.run(
        named_configs=[env_named_config, "fast"],
        config_updates=overrides,
    )
    return osp.abspath(f"{tmpdir}/rollouts/final.pkl")
Example #2
0
def test_transfer_learning():
  """Smoke-test transfer learning.

  First runs training with the 'airl' config to save a dummy test reward,
  then runs the expert-demos experiment with that reward loaded from the
  training log directory.
  """
  with tempfile.TemporaryDirectory(prefix='imitation-transfer') as tmpdir:
    train_dir = osp.join(tmpdir, "train")
    train_config = {
        "rollout_glob": "tests/data/rollouts/CartPole*.pkl",
        "log_dir": train_dir,
    }
    train_run = train_ex.run(
        named_configs=['cartpole', 'airl', 'fast'],
        config_updates=train_config,
    )
    assert train_run.status == 'COMPLETED'

    # The second stage reads the reward from the training run's output tree.
    demos_config = {
        "log_dir": osp.join(tmpdir, "expert_demos"),
        "reward_type": 'DiscrimNetAIRL',
        "reward_path": osp.join(train_dir, "checkpoints", "final", "discrim"),
    }
    demos_run = expert_demos_ex.run(
        named_configs=['cartpole', 'fast'],
        config_updates=demos_config,
    )
    assert demos_run.status == 'COMPLETED'
Example #3
0
def test_transfer_learning(tmpdir):
    """Smoke-test transfer learning.

    Runs training with the 'airl' config to save a dummy test reward, checks
    the training result, then runs the expert-demos experiment with that
    reward loaded from the training log directory.

    Args:
        tmpdir: Base directory under which both runs write their logs
            (presumably the pytest `tmpdir` fixture).
    """
    train_dir = osp.join(tmpdir, "train")
    train_config = {
        "rollout_path":
            "tests/data/expert_models/cartpole_0/rollouts/final.pkl",
        "log_dir": train_dir,
    }
    train_run = train_ex.run(
        named_configs=["cartpole", "airl", "fast"],
        config_updates=train_config,
    )
    assert train_run.status == "COMPLETED"

    train_result = train_run.result
    _check_train_ex_result(train_result)
    _check_rollout_stats(train_result["imit_stats"])

    # The second stage reads the reward from the training run's output tree.
    demos_config = {
        "log_dir": osp.join(tmpdir, "expert_demos"),
        "reward_type": "DiscrimNet",
        "reward_path": osp.join(train_dir, "checkpoints", "final", "discrim"),
    }
    demos_run = expert_demos_ex.run(
        named_configs=["cartpole", "fast"],
        config_updates=demos_config,
    )
    assert demos_run.status == "COMPLETED"
    _check_rollout_stats(demos_run.result)
Example #4
0
def test_expert_demos_main(tmpdir):
    """Smoke test for imitation.scripts.expert_demos.rollouts_and_policy.

    Args:
        tmpdir: Directory passed to the experiment as `log_root`
            (presumably the pytest `tmpdir` fixture).
    """
    overrides = {"log_root": tmpdir}
    outcome = expert_demos_ex.run(
        named_configs=["cartpole", "fast"],
        config_updates=overrides,
    )
    assert outcome.status == "COMPLETED"
    # The run is expected to hand back a dict as its result.
    assert isinstance(outcome.result, dict)
Example #5
0
def test_expert_demos_main():
  """Smoke test for imitation.scripts.expert_demos.rollouts_and_policy.

  Logs go to a temporary directory that is removed when the test ends.
  """
  with tempfile.TemporaryDirectory(
      prefix='imitation-data_collect-main') as tmpdir:
    overrides = {"log_root": tmpdir}
    outcome = expert_demos_ex.run(
        named_configs=['cartpole', 'fast'],
        config_updates=overrides,
    )
    assert outcome.status == 'COMPLETED'
Example #6
0
def test_expert_demos_rollouts_from_policy(tmpdir):
    """Smoke test for imitation.scripts.expert_demos.rollouts_from_policy.

    Args:
        tmpdir: Directory used as `log_root` and as the parent of the rollout
            save path (presumably the pytest `tmpdir` fixture).
    """
    overrides = {
        "log_root": tmpdir,
        "rollout_save_path": osp.join(tmpdir, "rollouts", "test.pkl"),
        "policy_path": "tests/data/expert_models/cartpole_0/policies/final/",
    }
    outcome = expert_demos_ex.run(
        command_name="rollouts_from_policy",
        named_configs=["cartpole", "fast"],
        config_updates=overrides,
    )
    assert outcome.status == "COMPLETED"
Example #7
0
def test_expert_demos_rollouts_from_policy():
  """Smoke test for imitation.scripts.expert_demos.rollouts_from_policy.

  Logs and rollouts go to a temporary directory that is removed once the
  run has finished.
  """
  with tempfile.TemporaryDirectory(
      prefix='imitation-data_collect-policy') as tmpdir:
    overrides = {
        "log_root": tmpdir,
        "rollout_save_dir": osp.join(tmpdir, "rollouts"),
        "policy_path": "expert_models/PPO2_CartPole-v1_0",
    }
    outcome = expert_demos_ex.run(
        command_name="rollouts_from_policy",
        named_configs=['cartpole', 'fast'],
        config_updates=overrides,
    )
  # The status is checked after the temporary directory has been cleaned up.
  assert outcome.status == 'COMPLETED'