Exemplo n.º 1
0
def test_criterion_rollout_based_convergence_none():
    """``is_met`` must be falsy when the convergence probability is ``None``."""
    zero_return_rollout = SimpleNamespace(undiscounted_return=lambda: 0)
    wrapped_sampler = RolloutSavingWrapper(MockSampler([zero_return_rollout]))
    wrapped_sampler.sample()

    stopping = ConvergenceStoppingCriterion(
        convergence_probability_threshold=0.5)
    # Stub the probability computation so it reports "no estimate available".
    stopping._compute_convergence_probability = lambda: None

    fake_algo = SimpleNamespace(sampler=wrapped_sampler)
    assert not stopping.is_met(fake_algo)
Exemplo n.º 2
0
def test_criterion_rollout_based_convergence_history_filling():
    """Each ``sample`` + ``is_met`` round should add one entry to the history.

    Round ``step`` feeds a single rollout whose undiscounted return is
    ``step``, so after that round the history is expected to be
    ``[0, 1, ..., step]``.
    """
    mock_sampler = MockSampler()
    sampler = RolloutSavingWrapper(mock_sampler)
    algo = SimpleNamespace(sampler=sampler)
    criterion = ConvergenceStoppingCriterion()

    for step in range(10):
        # Bind ``step`` as a default so each rollout keeps its own value
        # instead of sharing the loop variable (late-binding closure).
        current_rollout = SimpleNamespace(
            undiscounted_return=lambda value=step: value)
        mock_sampler.step_sequences = [current_rollout]
        sampler.sample()
        criterion.is_met(algo)
        expected_history = np.arange(step + 1).tolist()
        assert criterion._return_statistic_history == expected_history
Exemplo n.º 3
0
def test_criterion_rollout_based_convergence_regress_random():
    """Small zero-mean noise should give a convergence probability above 0.9."""
    # Seeded generator keeps the test deterministic.
    generator = np.random.default_rng(seed=5)
    flat_noisy_history = generator.normal(loc=0.0, scale=0.001, size=10000)

    criterion = ConvergenceStoppingCriterion()
    criterion._return_statistic_history = flat_noisy_history

    probability = criterion._compute_convergence_probability()
    assert probability > 0.9
Exemplo n.º 4
0
def test_criterion_rollout_based_convergence_regress_not_constant():
    """A steadily increasing history should have ~0 convergence probability."""
    criterion = ConvergenceStoppingCriterion()
    # Same values as np.arange(10).tolist(): plain Python ints 0..9.
    criterion._return_statistic_history = list(range(10))

    probability = criterion._compute_convergence_probability()
    assert np.isclose(probability, 0.0)
Exemplo n.º 5
0
def test_criterion_rollout_based_convergence_subset(num_iter, expected):
    """Check the history subset selected for a given ``num_iter``.

    NOTE(review): ``num_iter`` and ``expected`` look like pytest-parametrized
    arguments, but no ``@pytest.mark.parametrize`` decorator is visible in
    this chunk. Confirm that one exists, otherwise pytest will fail to
    resolve them as fixtures.
    """
    history = [1, 2, 3]
    criterion = ConvergenceStoppingCriterion(num_iter=num_iter)
    criterion._return_statistic_history = history

    selected = criterion._get_relevant_return_statistic_subset()
    assert selected == expected