Example #1
def test_transfer_learning(tmpdir):
    """Transfer learning smoke test.

    Saves a dummy AIRL test reward, then loads it for transfer learning.
    """
    log_dir_train = osp.join(tmpdir, "train")
    run = train_ex.run(
        named_configs=["cartpole", "airl", "fast"],
        config_updates=dict(
            rollout_path="tests/data/expert_models/cartpole_0/rollouts/final.pkl",
            log_dir=log_dir_train,
        ),
    )
    assert run.status == "COMPLETED"
    _check_train_ex_result(run.result)

    _check_rollout_stats(run.result["imit_stats"])

    log_dir_data = osp.join(tmpdir, "expert_demos")
    discrim_path = osp.join(log_dir_train, "checkpoints", "final", "discrim")
    run = expert_demos_ex.run(
        named_configs=["cartpole", "fast"],
        config_updates=dict(
            log_dir=log_dir_data,
            reward_type="DiscrimNet",
            reward_path=discrim_path,
        ),
    )
    assert run.status == "COMPLETED"
    _check_rollout_stats(run.result)
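
All of the examples on this page are excerpted from the imitation library's test suite, so they omit their imports and the shared helper assertions. A minimal sketch of what they assume follows: the module paths match the imitation versions these tests come from, and the helper bodies are illustrative guesses, not the suite's actual checks.

import os.path as osp
import tempfile
from collections import Counter
from typing import List, Optional

import pandas as pd

# Sacred experiment objects defined by imitation's scripts.
from imitation.scripts.analyze import analysis_ex
from imitation.scripts.expert_demos import expert_demos_ex
from imitation.scripts.train_adversarial import train_ex


def _check_train_ex_result(result: dict) -> None:
    # Illustrative guess: validate the structure of train_ex's result dict.
    assert isinstance(result, dict)
    assert "imit_stats" in result  # accessed as run.result["imit_stats"] above


def _check_rollout_stats(stats: dict) -> None:
    # Illustrative guess: sanity-check the rollout statistics dict.
    assert isinstance(stats, dict)
    assert "return_mean" in stats
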
Example #2
def test_transfer_learning():
  """Transfer learning smoke test.

  Save a dummy AIRL test reward, then load it for transfer learning."""

  with tempfile.TemporaryDirectory(prefix='imitation-transfer') as tmpdir:
    log_dir_train = osp.join(tmpdir, "train")
    run = train_ex.run(
        named_configs=['cartpole', 'airl', 'fast'],
        config_updates=dict(
          rollout_glob="tests/data/rollouts/CartPole*.pkl",
          log_dir=log_dir_train,
        ),
    )
    assert run.status == 'COMPLETED'

    log_dir_data = osp.join(tmpdir, "expert_demos")
    discrim_path = osp.join(log_dir_train, "checkpoints", "final", "discrim")
    run = expert_demos_ex.run(
        named_configs=['cartpole', 'fast'],
        config_updates=dict(
          log_dir=log_dir_data,
          reward_type='DiscrimNetAIRL',
          reward_path=discrim_path,
        ),
    )
    assert run.status == 'COMPLETED'
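
Together, the two runs in this example implement transfer learning: train AIRL, serialize its discriminator, then re-train an RL agent against the recovered reward (the reward_type/reward_path pair) instead of the environment reward. Conceptually, the second run wraps the environment along these lines; the wrapper below is a hypothetical illustration, not the actual plumbing inside expert_demos_ex:

import gym


class LearnedRewardWrapper(gym.Wrapper):
    """Hypothetical: substitute a learned reward_fn for the env reward."""

    def __init__(self, env, reward_fn):
        super().__init__(env)
        self.reward_fn = reward_fn
        self._last_obs = None

    def reset(self, **kwargs):
        self._last_obs = self.env.reset(**kwargs)
        return self._last_obs

    def step(self, action):
        next_obs, _, done, info = self.env.step(action)
        # The learned (e.g. AIRL discriminator) reward replaces the env reward.
        reward = float(self.reward_fn(self._last_obs, action, next_obs))
        self._last_obs = next_obs
        return next_obs, reward, done, info


# Hypothetical usage, given some reward_fn recovered from the checkpoint:
# env = LearnedRewardWrapper(gym.make("CartPole-v1"), reward_fn)
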
Example #3
def test_train_adversarial(tmpdir):
    """Smoke test for imitation.scripts.train_adversarial."""
    named_configs = ["cartpole", "gail", "fast"]
    config_updates = {
        "log_root": tmpdir,
        "rollout_path":
        "tests/data/expert_models/cartpole_0/rollouts/final.pkl",
        "init_tensorboard": True,
        "plot_interval": 1,
        "extra_episode_data_interval": 1,
    }
    run = train_ex.run(
        named_configs=named_configs,
        config_updates=config_updates,
    )
    assert run.status == "COMPLETED"
    _check_train_ex_result(run.result)
Example #4
def test_train_adversarial():
  """Smoke test for imitation.scripts.train_adversarial"""
  with tempfile.TemporaryDirectory(prefix='imitation-train',
                                   ) as tmpdir:
      config_updates = {
          'init_trainer_kwargs': {
              # Rollouts are small, decrease size of buffer to avoid warning
              'trainer_kwargs': {
                  'n_disc_samples_per_buffer': 50,
              },
          },
          'log_root': tmpdir,
          'rollout_glob': "tests/data/rollouts/CartPole*.pkl",
      }
      run = train_ex.run(
          named_configs=['cartpole', 'gail', 'fast'],
          config_updates=config_updates,
      )
      assert run.status == 'COMPLETED'
Example #5
def test_train_adversarial(tmpdir):
    """Smoke test for imitation.scripts.train_adversarial"""
    named_configs = ['cartpole', 'gail', 'fast', 'plots']
    config_updates = {
        'init_trainer_kwargs': {
            # Rollouts are small, decrease size of buffer to avoid warning
            'trainer_kwargs': {
                'n_disc_samples_per_buffer': 50,
            },
        },
        'log_root': tmpdir,
        'rollout_path': "tests/data/expert_models/cartpole_0/rollouts/final.pkl",
        'init_tensorboard': True,
    }
    run = train_ex.run(
        named_configs=named_configs,
        config_updates=config_updates,
    )
    assert run.status == 'COMPLETED'
    _check_train_ex_result(run.result)
Example #6
def test_analyze_imitation(tmpdir: str, run_names: List[str]):
    """Check that analyze_imitation finds the correct number of Sacred logs."""
    sacred_logs_dir = tmpdir

    # Generate sacred logs (other logs are put in separate tmpdir for deletion).
    for i, run_name in enumerate(run_names):
        with tempfile.TemporaryDirectory(prefix="junk") as junkdir:
            rollout_path = "tests/data/expert_models/cartpole_0/rollouts/final.pkl"
            run = train_ex.run(
                named_configs=["fast", "cartpole"],
                config_updates=dict(
                    rollout_path=rollout_path,
                    log_dir=junkdir,
                    checkpoint_interval=-1,
                ),
                options={
                    "--name": run_name,
                    "--file_storage": sacred_logs_dir
                },
            )
            assert run.status == "COMPLETED"

    # Check that analyze script finds the correct number of logs.
    def check(run_name: Optional[str], count: int) -> None:
        run = analysis_ex.run(
            command_name="analyze_imitation",
            config_updates=dict(
                source_dir=sacred_logs_dir,
                run_name=run_name,
                csv_output_path=osp.join(tmpdir, "analysis.csv"),
                verbose=True,
            ),
        )
        assert run.status == "COMPLETED"
        df = pd.DataFrame(run.result)
        assert df.shape[0] == count

    for run_name, count in Counter(run_names).items():
        check(run_name, count)

    check(None, len(run_names))  # Check total number of logs.
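
The run_names argument is presumably injected by pytest parametrization rather than a fixture; a sketch of what that decorator might look like, with invented parameter sets:

import pytest
from typing import List


@pytest.mark.parametrize(
    "run_names",
    [
        [],                     # no runs at all
        ["one"],                # a single named run
        ["two", "two"],         # duplicate names, grouped via Counter(run_names)
        ["one", "two", "two"],  # a mix of unique and repeated names
    ],
)
def test_analyze_imitation(tmpdir: str, run_names: List[str]):
    ...  # body as in Example #6 above
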