Esempio n. 1
0
def test_parallel(config_updates):
    """Smoke-test the hyperparameter tuning experiment.

    Runs `parallel_ex` with the "debug_log_root" named config and the given
    `config_updates`, then checks that the run finished with status
    "COMPLETED".

    Args:
        config_updates: Config overrides forwarded unchanged to
            `parallel_ex.run`.
    """
    # A TemporaryDirectory is unnecessary here: the hyperparameter tuning
    # script produces no artifacts of its own, and the "debug_log_root" named
    # config points the inner experiment at log_root="/tmp/parallel_debug/".
    named = ["debug_log_root"]
    result = parallel_ex.run(
        named_configs=named,
        config_updates=config_updates,
    )
    assert result.status == "COMPLETED"
Esempio n. 2
0
def test_parallel(config_updates):
    """Smoke-test the hyperparameter tuning experiment with reduced resources.

    A copy of `config_updates` is overlaid with `PARALLEL_CONFIG_LOW_RESOURCE`
    (so the caller's object is left untouched), `parallel_ex` is run with the
    "debug_log_root" named config, and the run must end with status
    "COMPLETED".

    Args:
        config_updates: Base config overrides; entries from
            `PARALLEL_CONFIG_LOW_RESOURCE` take precedence on key clashes.
    """
    # The CI server only has 2 cores, so apply the low-resource settings on
    # top of a private copy of the incoming overrides.
    low_resource_updates = dict(config_updates)
    low_resource_updates.update(PARALLEL_CONFIG_LOW_RESOURCE)

    # A TemporaryDirectory is unnecessary here: the hyperparameter tuning
    # script produces no artifacts of its own, and the "debug_log_root" named
    # config points the inner experiment at log_root="/tmp/parallel_debug/".
    result = parallel_ex.run(
        named_configs=["debug_log_root"],
        config_updates=low_resource_updates,
    )
    assert result.status == "COMPLETED"
Esempio n. 3
0
def test_analyze_gather_tb(tmpdir: str):
    """Check the "gather_tb_directories" analysis command end to end.

    First runs `parallel_ex` with the "generate_test_data" named config,
    writing under `tmpdir` with the low-resource settings applied. Then runs
    the `gather_tb_directories` command of `analysis_ex` against the same
    directory and expects a dict result whose "n_tb_dirs" entry equals 4.

    Args:
        tmpdir: Directory used as `local_dir` for data generation and as
            `source_dir` for the analysis step.
    """
    # Step 1: generate the data the analysis command will scan.
    generation_updates = {"local_dir": tmpdir, "run_name": "test"}
    generation_updates.update(PARALLEL_CONFIG_LOW_RESOURCE)
    generation_run = parallel_ex.run(
        named_configs=["generate_test_data"],
        config_updates=generation_updates,
    )
    assert generation_run.status == "COMPLETED"

    # Step 2: gather the TensorBoard directories from that same location.
    analysis_run = analysis_ex.run(
        command_name="gather_tb_directories",
        config_updates={"source_dir": tmpdir},
    )
    assert analysis_run.status == "COMPLETED"

    summary = analysis_run.result
    assert isinstance(summary, dict)
    assert summary["n_tb_dirs"] == 4
Esempio n. 4
0
def test_analyze_gather_tb(tmpdir: str):
    """Check the "gather_tb_directories" analysis command end to end.

    First runs `parallel_ex` with the "generate_test_data" named config,
    writing under `tmpdir`. Then runs the `gather_tb_directories` command of
    `analysis_ex` against the same directory and expects a dict result whose
    "n_tb_dirs" entry equals 4.

    Args:
        tmpdir: Directory used as `local_dir` for data generation and as
            `source_dir` for the analysis step.
    """
    # Step 1: generate the data the analysis command will scan.
    generation_updates = {"local_dir": tmpdir, "run_name": "test"}
    generation_run = parallel_ex.run(
        named_configs=["generate_test_data"],
        config_updates=generation_updates,
    )
    assert generation_run.status == "COMPLETED"

    # Step 2: gather the TensorBoard directories from that same location.
    analysis_run = analysis_ex.run(
        command_name="gather_tb_directories",
        config_updates={"source_dir": tmpdir},
    )
    assert analysis_run.status == "COMPLETED"

    summary = analysis_run.result
    assert isinstance(summary, dict)
    assert summary["n_tb_dirs"] == 4
Esempio n. 5
0
def test_parallel_train_adversarial_custom_env(tmpdir):
    """Run `parallel_ex` over "train_adversarial" with the "custom_ant" config.

    Test rollouts are produced by `_generate_test_rollouts(tmpdir, ...)` and
    their path is handed to the inner experiment through
    `base_config_updates`. The outer config is overlaid with
    `PARALLEL_CONFIG_LOW_RESOURCE`, and the run must end with status
    "COMPLETED".

    Args:
        tmpdir: Directory passed to `_generate_test_rollouts`.
    """
    env_config_name = "custom_ant"
    rollouts = _generate_test_rollouts(tmpdir, env_config_name)

    # Overrides destined for the inner experiment.
    trainer_kwargs = {"parallel": True, "num_vec": 2}
    inner_updates = {
        "init_trainer_kwargs": trainer_kwargs,
        "rollout_path": rollouts,
    }

    # Overrides for the outer (parallel) experiment, with the low-resource
    # settings applied last so they win on any key clash.
    outer_updates = {
        "sacred_ex_name": "train_adversarial",
        "n_seeds": 1,
        "base_named_configs": [env_config_name, "fast"],
        "base_config_updates": inner_updates,
    }
    outer_updates.update(PARALLEL_CONFIG_LOW_RESOURCE)

    result = parallel_ex.run(
        named_configs=["debug_log_root"],
        config_updates=outer_updates,
    )
    assert result.status == "COMPLETED"