コード例 #1
0
def test_parallel_sampling_deterministic_smoke_test_w_min_steps(
        tmpdir_factory, env: SimEnv, policy: Policy, algo, min_rollouts: int,
        min_steps: int):
    """Check that training results depend on the seed but not on the number of workers.

    For every seed in ``(0, 1)`` and every worker count in ``(1, 2, 4)``:

    * the policy is re-seeded and re-initialized;
    * a fresh GAE critic and algorithm are built;
    * the algorithm is trained for three iterations with its sampler wrapped in a
      ``RolloutSavingWrapper``.

    The recorded rollouts and the run's ``progress.csv`` are collected. The test then
    asserts that:

    * for a fixed seed, all runs sampled the same observations (up to ``pytest.approx``);
    * logs produced with different seeds differ;
    * logs produced with the same seed are identical.

    :param tmpdir_factory: pytest factory used to create one experiment directory per run
    :param env: environment to train on; its ``max_steps`` is overwritten with 20
    :param policy: policy to train; re-initialized via ``init_param(None)`` for each run
    :param algo: algorithm class, called as ``algo(ex_dir, env, policy, critic, **kwargs)``
    :param min_rollouts: forwarded to the algorithm as ``min_rollouts``
    :param min_steps: multiplied by ``env.max_steps`` and forwarded as ``min_steps``
    """
    import itertools

    env.max_steps = 20

    seeds = (0, 1)
    nums_workers = (1, 2, 4)

    # One (seed, [progress.csv contents, one per worker count]) entry per seed
    logging_results = []
    # Indexed as [seed][worker count] -> the rollouts recorded by RolloutSavingWrapper
    rollout_results: List[List[List[List[StepSequence]]]] = []
    for seed in seeds:
        seed_logs = []
        seed_rollouts = []
        for num_workers in nums_workers:
            # Seed before (re-)initializing the policy and building the critic, so
            # every run of this seed starts from the same state
            pyrado.set_seed(seed)
            policy.init_param(None)
            ex_dir = str(
                tmpdir_factory.mktemp(
                    f"seed={seed}-num_workers={num_workers}"))
            set_log_prefix_dir(ex_dir)
            # A fresh critic per run so that no learned state carries over between runs
            vfcn = FNN(input_size=env.obs_space.flat_dim,
                       output_size=1,
                       hidden_sizes=[16, 16],
                       hidden_nonlin=to.tanh)
            critic = GAE(vfcn,
                         gamma=0.98,
                         lamda=0.95,
                         batch_size=32,
                         lr=1e-3,
                         standardize_adv=False)
            alg = algo(
                ex_dir,
                env,
                policy,
                critic,
                max_iter=3,
                min_rollouts=min_rollouts,
                min_steps=min_steps * env.max_steps,
                num_workers=num_workers,
            )
            alg.sampler = RolloutSavingWrapper(alg.sampler)
            alg.train()
            with open(f"{ex_dir}/progress.csv") as f:
                seed_logs.append(f.read())
            seed_rollouts.append(alg.sampler.rollouts)
        logging_results.append((seed, seed_logs))
        rollout_results.append(seed_rollouts)

    # For a fixed seed, the number of workers must not influence the sampled data.
    # Ordered pairs are used because pytest.approx takes its tolerance relative to the
    # expected (right-hand) value, so the comparison is checked in both directions.
    for rollouts in rollout_results:
        for ros_a, ros_b in itertools.permutations(rollouts, 2):
            assert len(ros_a) == len(ros_b)
            for ro_a, ro_b in zip(ros_a, ros_b):
                assert len(ro_a) == len(ro_b)
                for r_a, r_b in zip(ro_a, ro_b):
                    assert r_a.observations == pytest.approx(r_b.observations)

    # Different seeds must yield different logs: compare every log of one seed with
    # every log of the other. Compare strictly by value. Skipping pairs via an identity
    # test (`is not`) would be wrong, because CPython may share string objects (e.g.
    # the empty string), which would hide exactly the broken-log case.
    for (seed_a, results_a), (seed_b, results_b) in itertools.combinations(
            logging_results, 2):
        if seed_a == seed_b:
            continue
        for result_a, result_b in itertools.product(results_a, results_b):
            assert result_a != result_b

    # The same seed must yield identical logs regardless of the number of workers
    for _, results in logging_results:
        for result_a, result_b in itertools.combinations(results, 2):
            assert result_a == result_b
コード例 #2
0
ファイル: test_policies.py プロジェクト: fdamken/SimuRLacra
def test_parameterized_policies_init_param(env: Env, policy: Policy):
    """Verify that ``init_param`` adopts an explicitly passed parameter vector.

    The policy is initialized with an all-ones tensor shaped like its current
    parameters. Afterwards its parameters must match that tensor (within the
    tolerances of ``torch.testing.assert_allclose``).

    :param env: not used in the body; presumably requested because the ``policy``
                fixture depends on it (TODO confirm)
    :param policy: policy under test
    """
    expected = to.ones_like(policy.param_values)
    policy.init_param(expected)
    actual = policy.param_values
    # NOTE(review): torch.testing.assert_allclose is deprecated since PyTorch 1.12 in
    # favor of torch.testing.assert_close; switch once the minimum torch version allows.
    to.testing.assert_allclose(actual, expected)