Exemplo n.º 1
0
def test_decay_period(env):
    """Compare noisy and constant actions over three consecutive calls.

    The wrapper is built with ``total_timesteps=2``, ``max_sigma=1.0`` and
    ``min_sigma=0.0``. The first two wrapped actions must differ from the
    constant action in every component; the third must equal it in every
    component (presumably because sigma has reached ``min_sigma`` by then --
    the decay schedule itself is defined outside this test).
    """
    base_policy = ConstantPolicy(env.action_space.sample())
    noisy_policy = AddGaussianNoise(
        env,
        base_policy,
        total_timesteps=2,
        max_sigma=1.0,
        min_sigma=0.0,
    )

    # Calls one and two: the wrapped action must not match the constant one.
    for _ in range(2):
        noisy_action = noisy_policy.get_action(None)[0]
        plain_action = base_policy.get_action(None)[0]
        assert (noisy_action != plain_action).all()

    # Call three: the wrapped action must be exactly the constant one.
    final_action = noisy_policy.get_action(None)[0]
    plain_action = base_policy.get_action(None)[0]
    assert (final_action == plain_action).all()
Exemplo n.º 2
0
def test_params(env):
    """Check that copying parameter values aligns two noise wrappers.

    Two constant policies are built from independently sampled actions and
    each is wrapped in ``AddGaussianNoise``. Only the second wrapper takes an
    action, after which the two sigmas are asserted to differ. The second
    wrapper's parameter values are then loaded into the first, and the test
    asserts that both the underlying policies' actions and the wrappers'
    sigmas are equal.
    """
    first_policy = ConstantPolicy(env.action_space.sample())
    second_policy = ConstantPolicy(env.action_space.sample())

    # Precondition: the two sampled constant actions differ everywhere.
    first_action = first_policy.get_action(None)[0]
    second_action = second_policy.get_action(None)[0]
    assert (first_action != second_action).all()

    first_noisy = AddGaussianNoise(env, first_policy, 1)
    second_noisy = AddGaussianNoise(env, second_policy, 1)

    # Step only the second wrapper so the two sigmas are expected to diverge.
    second_noisy.get_action(None)
    assert first_noisy._sigma() != second_noisy._sigma()

    # Transfer the second wrapper's parameters into the first.
    copied_params = second_noisy.get_param_values()
    first_noisy.set_param_values(copied_params)

    first_action = first_policy.get_action(None)[0]
    second_action = second_policy.get_action(None)[0]
    assert (first_action == second_action).all()
    assert first_noisy._sigma() == second_noisy._sigma()
Exemplo n.º 3
0
def test_update(env):
    """Check sigma after two successive ``update`` calls.

    The wrapper is built with ``total_timesteps=10``, ``max_sigma=1.0`` and
    ``min_sigma=0.0``. After updating with episode lengths ``[1, 2]`` sigma is
    expected to be close to 0.7; after a further update with lengths
    ``[1, 1, 2]`` it is expected to be close to 0.3. ``get_action`` is called
    before and between the updates, in the same order as the assertions
    below rely on.
    """
    base_policy = ConstantPolicy(env.action_space.sample())
    noisy_policy = AddGaussianNoise(
        env,
        base_policy,
        total_timesteps=10,
        max_sigma=1.0,
        min_sigma=0.0,
    )

    # Two actions are taken before the first update.
    for _ in range(2):
        noisy_policy.get_action(None)

    # Stand-in batch type carrying only a ``lengths`` field.
    DummyBatch = collections.namedtuple('EpisodeBatch', ['lengths'])

    # Expected sigma: 1 - 0.1 * (1 + 2) = 0.7
    noisy_policy.update(DummyBatch(np.array([1, 2])))
    assert np.isclose(noisy_policy._sigma(), 0.7)

    noisy_policy.get_action(None)

    # Expected sigma: 0.7 - 0.1 * (1 + 1 + 2) = 0.3
    noisy_policy.update(DummyBatch(np.array([1, 1, 2])))
    assert np.isclose(noisy_policy._sigma(), 0.3)