Beispiel #1
0
def setup_experiment(
    env_name: str,
    algo_name: str,
    extra_info: str = None,
    base_dir: str = pyrado.TEMP_DIR,
    include_slurm_id: bool = True,
):
    """
    Create a new experiment, make sure its folder exists, and register it as the log prefix directory.

    :param env_name: name of the environment that is trained on
    :param algo_name: name of the algorithm trained with, usually also includes the policy type, e.g. 'a2c_fnn'
    :param extra_info: free-form additional information on the experiment
    :param base_dir: base directory under which the experiment is stored
    :param include_slurm_id: if a SLURM ID is present in the environment variables, include it in the experiment ID
    :return: the newly created experiment object
    """
    experiment = Experiment(
        env_name, algo_name, extra_info, base_dir=base_dir, include_slurm_id=include_slurm_id
    )

    # The experiment object itself is used as the path; an already existing folder is not an error
    os.makedirs(experiment, exist_ok=True)

    # Point the global logging setup to this experiment
    set_log_prefix_dir(experiment)

    return experiment
Beispiel #2
0
def setup_experiment(env_name: str,
                     algo_name: str,
                     extra_info: str = None,
                     base_dir: str = pyrado.TEMP_DIR):
    """
    Create a new experiment, make sure its folder exists, and register it as the log prefix directory.

    :param env_name: name of the environment that is trained on
    :param algo_name: name of the algorithm that is trained with
    :param extra_info: free-form additional information on the experiment
    :param base_dir: base directory under which the experiment is stored
    :return: the newly created experiment object
    """
    experiment = Experiment(
        env_name,
        algo_name,
        extra_info,
        base_dir=base_dir,
    )

    # The experiment object itself is used as the path; an already existing folder is not an error
    os.makedirs(experiment, exist_ok=True)

    # Point the global logging setup to this experiment
    set_log_prefix_dir(experiment)

    return experiment
Beispiel #3
0
def test_parallel_sampling_deterministic_smoke_test_w_min_steps(
        tmpdir_factory, env: SimEnv, policy: Policy, algo, min_rollouts: int,
        min_steps: int):
    """
    Smoke test for the determinism of parallel sampling when a minimum number of steps is requested.

    For every combination of seed and number of workers, the given algorithm is trained for 3 iterations in a fresh
    temporary directory while all sampled rollouts are recorded. Afterwards it is checked that, for a fixed seed,
    the recorded observations and the logged `progress.csv` are identical for all numbers of workers, and that
    the logs differ between different seeds.

    :param tmpdir_factory: factory used to create one temporary directory per run
    :param env: environment to sample from, its `max_steps` is set to 20
    :param policy: policy to train, re-initialized before every run
    :param algo: callable that constructs the algorithm under test
    :param min_rollouts: minimum number of rollouts passed to the algorithm
    :param min_steps: multiplied by `env.max_steps` to obtain the minimum number of steps passed to the algorithm
    """
    env.max_steps = 20

    seeds = (0, 1)
    nums_workers = (1, 2, 4)

    logging_results = []
    rollout_results: List[List[List[List[StepSequence]]]] = []
    for seed in seeds:
        # Collect the results of this seed in dedicated lists which are shared with the overall containers
        logs_of_seed = []
        rollouts_of_seed = []
        logging_results.append((seed, logs_of_seed))
        rollout_results.append(rollouts_of_seed)

        for num_workers in nums_workers:
            # Reset the random number generators before anything random is created for this run
            pyrado.set_seed(seed)
            policy.init_param(None)

            run_name = f"seed={seed}-num_workers={num_workers}"
            ex_dir = str(tmpdir_factory.mktemp(run_name))
            set_log_prefix_dir(ex_dir)

            # Value function and critic
            vfcn = FNN(
                input_size=env.obs_space.flat_dim,
                output_size=1,
                hidden_sizes=[16, 16],
                hidden_nonlin=to.tanh,
            )
            critic_hparam = dict(
                gamma=0.98,
                lamda=0.95,
                batch_size=32,
                lr=1e-3,
                standardize_adv=False,
            )
            critic = GAE(vfcn, **critic_hparam)

            # Algorithm
            algo_hparam = dict(
                max_iter=3,
                min_rollouts=min_rollouts,
                min_steps=min_steps * env.max_steps,
                num_workers=num_workers,
            )
            alg = algo(ex_dir, env, policy, critic, **algo_hparam)

            # Record every rollout that the sampler produces during training
            alg.sampler = RolloutSavingWrapper(alg.sampler)
            alg.train()

            with open(f"{ex_dir}/progress.csv") as progress_file:
                logs_of_seed.append(progress_file.read())
            rollouts_of_seed.append(alg.sampler.rollouts)

    # Test that the observations for all number of workers are equal.
    for rollouts_of_seed in rollout_results:
        for ros_a in rollouts_of_seed:
            for ros_b in rollouts_of_seed:
                assert len(ros_a) == len(ros_b)
                for ro_a, ro_b in zip(ros_a, ros_b):
                    assert len(ro_a) == len(ro_b)
                    for r_a, r_b in zip(ro_a, ro_b):
                        assert r_a.observations == pytest.approx(r_b.observations)

    # Test that different seeds actually produce different results.
    for seed_a, results_a in logging_results:
        for seed_b, results_b in logging_results:
            if seed_a == seed_b:
                continue
            for result_a in results_a:
                for result_b in results_b:
                    if result_a is result_b:
                        continue
                    assert result_a != result_b

    # Test that same seeds produce same results.
    for _, results in logging_results:
        for result_a in results:
            for result_b in results:
                assert result_a == result_b
Beispiel #4
0
def ex_dir(tmpdir):
    """
    Fixture providing an experiment directory which is also registered as the log prefix directory.

    :param tmpdir: temporary directory that serves as the experiment directory
    :return: the very same `tmpdir` that was passed in
    """
    experiment_dir = tmpdir
    set_log_prefix_dir(experiment_dir)
    return experiment_dir
Beispiel #5
0
def ex_dir(tmpdir):
    """
    Register the given directory as the log prefix directory and hand it back.

    Presumably used as a pytest fixture that provides an experiment directory (no decorator is visible here).

    :param tmpdir: temporary directory that serves as the experiment directory
    :return: the very same `tmpdir` that was passed in
    """
    experiment_dir = tmpdir
    set_log_prefix_dir(experiment_dir)
    return experiment_dir