def test_epsilon_greedy_policy_gives_deterministic_results_when_an_rng_with_a_fixed_seed_is_supplied(
):
    """With a fixed-seed RNG, the epsilon-greedy policy must reproduce a
    known action sequence exactly on repeated calls."""
    state = 5
    action_values = [17.1, -5.5, 20.2, -20.9, 5.1, 1.0, -1.0, 1.5]
    q_function = {state: dict(enumerate(action_values))}
    policy, rng = create_rngs_with_fixed_seed(8, 0, 0)
    greedy_policy = GreedyPolicy(
        epsilon=0.5,
        random_policy=policy,
        random_generator=rng,
    )
    # The sequence below was captured from a run with this exact seed; any
    # nondeterminism in the policy would break it.
    for expected in (2, 5, 4, 2, 2, 2):
        assert greedy_policy.next_action(q_function, state) == expected
# Example no. 2 (0 votes) — code-sharing-site scraper artifact, converted to a comment so the file parses.
    def __init__(self, epsilon=0.0, random_policy=None, rng=None, q=None):
        """Initializes the core control functionality.

        Defaults are supplied so the optional arguments documented below are
        actually optional (the previous signature required all four).

        Args:
            epsilon: The probability of going off-policy and selecting
                a random action from a random policy.  Behavior ranges from
                no deviation from the policy (epsilon=0.0) to complete
                randomness (epsilon=1.0).
            random_policy: The policy used to generate the random action when
                deviating from the policy.  Required when epsilon > 0.
            rng: (optional) Pseudo-random number generator (either from Python
                standard library or NumPy) used for all random number
                generation within the policy.  When None, Python's built-in
                PRNG will be used.
            q: (optional) An initial q-function to use, represented as a
                two-level nested Iterable, i.e. of form q[state][action].  If
                this argument is not supplied, an empty q-function will be
                initialized.

        Raises:
            None
        """
        # Deep-copy so later mutation of the caller's q cannot silently
        # change this object's learned values.
        self._q = copy.deepcopy(q) if q is not None else {}
        if epsilon > 0.0:
            self._policy = GreedyPolicy(
                epsilon=epsilon,
                random_policy=random_policy,
                random_generator=rng,
            )
        else:
            # Purely greedy: no random policy or RNG is needed.
            self._policy = GreedyPolicy(epsilon=0.0)
def test_epsilon_greedy_policy_picks_unique_greedy_answer_when_there_is_a_unique_greedy_maximum_and_epsilon_is_zero(
):
    """With epsilon=0 and a single best action, that action must be chosen
    on every call, with no randomness."""
    state = 5
    expected = 2
    action_values = [17.1, -5.5, 20.2, -20.9, 5.1, 1.0, -1.0, 1.5]
    q_function = {state: dict(enumerate(action_values))}
    greedy_policy = GreedyPolicy(epsilon=0.0)
    # Action `expected` (index 2, value 20.2) is the unique maximum.
    assert all(
        greedy_policy.next_action(q_function, state) == expected
        for _ in range(100)
    )
def test_greedy_policy_breaks_ties_in_a_pseudorandom_fashion_when_there_is_more_than_one_greedy_maximum(
):
    """When every action has the same value, the policy must tie-break via
    its RNG — reproducibly, given a fixed seed."""
    state = 5
    # All eight actions share the maximum value, so every call is a tie.
    q_function = {state: {action: 20 for action in range(8)}}
    greedy_policy = GreedyPolicy(random_generator=default_rng(seed=0))
    for expected in (6, 5, 4, 2, 2):
        assert greedy_policy.next_action(q_function, state) == expected
def test_epsilon_greedy_policy_gives_identical_results_to_greedy_policy_when_there_is_more_than_one_greedy_maximum_and_epsilon_equals_zero(
):
    """An epsilon=0.0 policy must be indistinguishable from a plain greedy
    policy, even when tie-breaking exercises the RNG."""
    state = 5
    # All-equal values force tie-breaking on every call.
    q_function = {state: {action: 20 for action in range(8)}}
    rng_reference = default_rng(seed=0)
    # Deep-copied RNG gives both policies an identical random stream.
    rng_under_test = copy.deepcopy(rng_reference)
    reference_policy = GreedyPolicy(random_generator=rng_reference)
    policy_under_test = GreedyPolicy(
        epsilon=0.0,
        random_generator=rng_under_test,
    )
    for _ in range(100):
        expected = reference_policy.next_action(q_function, state)
        assert policy_under_test.next_action(q_function, state) == expected
def test_epsilon_greedy_policy_gives_identical_results_to_random_policy_when_epsilon_equals_one_and_same_rng_used(
):
    """With epsilon=1.0 the policy must always defer to the random policy;
    the q-values must have no influence at all."""
    state = 5
    action_values = [17.1, -5.5, 20.2, -20.9, 5.1, 1.0, -1.0, 1.5]
    dummy_q_function = {state: dict(enumerate(action_values))}
    random_policy, _ = create_rngs_with_fixed_seed(8, 0, 0)
    # A deep copy gives the greedy policy the same random stream as the
    # reference policy we compare against.
    greedy_policy = GreedyPolicy(
        epsilon=1.0,
        random_policy=copy.deepcopy(random_policy),
    )
    for _ in range(100):
        expected = random_policy.next_action()
        assert greedy_policy.next_action(dummy_q_function, state) == expected
def test_epsilon_greedy_policy_gives_identical_results_to_greedy_policy_when_there_is_a_unique_greedy_maximum_and_epsilon_equals_zero(
):
    """An epsilon=0.0 policy and a plain greedy policy must agree when the
    maximum is unique (no RNG involvement at all)."""
    state = 5
    action_values = [17.1, -5.5, 20.2, -20.9, 5.1, 1.0, -1.0, 1.5]
    q_function = {state: dict(enumerate(action_values))}
    reference_policy = GreedyPolicy()
    policy_under_test = GreedyPolicy(epsilon=0.0)
    for _ in range(100):
        expected = reference_policy.next_action(q_function, state)
        assert policy_under_test.next_action(q_function, state) == expected
# Example no. 8 (0 votes) — code-sharing-site scraper artifact, converted to a comment so the file parses.
class Control(object):
    """Core functionality for control classes.

    Holds a q-function and delegates action selection to an
    (epsilon-)greedy policy built over it.
    """
    def __init__(self, epsilon=0.0, random_policy=None, rng=None, q=None):
        """Initializes the core control functionality.

        Defaults are supplied so the optional arguments documented below are
        actually optional (the previous signature required all four; existing
        positional callers are unaffected).

        Args:
            epsilon: The probability of going off-policy and selecting
                a random action from a random policy.  Behavior ranges from
                no deviation from the policy (epsilon=0.0) to complete
                randomness (epsilon=1.0).
            random_policy: The policy used to generate the random action when
                deviating from the policy.  Required when epsilon > 0.
            rng: (optional) Pseudo-random number generator (either from Python
                standard library or NumPy) used for all random number
                generation within the policy.  When None, Python's built-in
                PRNG will be used.
            q: (optional) An initial q-function to use, represented as a
                two-level nested Iterable, i.e. of form q[state][action].  If
                this argument is not supplied, an empty q-function will be
                initialized.

        Raises:
            None
        """
        # Deep-copy so later mutation of the caller's q cannot silently
        # change this control's learned values.
        self._q = copy.deepcopy(q) if q is not None else {}
        if epsilon > 0.0:
            self._policy = GreedyPolicy(
                epsilon=epsilon,
                random_policy=random_policy,
                random_generator=rng,
            )
        else:
            # Purely greedy: no random policy or RNG is needed.
            self._policy = GreedyPolicy(epsilon=0.0)

    def next_action(self, state):
        """Selects an action based on the control's current policy.

        For off-policy learning, the action corresponds to the on-policy
        choice, i.e. the action the control will take (modulo
        epsilon-randomness).

        Args:
            state: The state for which an action should be chosen

        Returns:
            Action to take for the given state based on the current policy

        Raises:
            None
        """
        return self._policy.next_action(self._q, state)

    def get_q(self):
        """Returns a defensive deep copy of the current q-function.

        Args:
            None

        Returns:
            The q function represented as a two-level nested Iterable, i.e.
            of form q[state][action].  Mutating the returned structure does
            not affect this control's internal state.

        Raises:
            None
        """
        return copy.deepcopy(self._q)
def test_creating_greedy_policy_throws_exception_when_epsilon_is_greater_than_zero_and_random_policy_not_provided(
):
    """Requesting epsilon > 0 without a random policy must be rejected at
    construction time with an explanatory message."""
    expected_message = (
        "when specifying an epsilon value greater than 0, you must also "
        "provide a random policy!"
    )
    with pytest.raises(Exception) as excinfo:
        _ = GreedyPolicy(epsilon=1.0)
    assert expected_message in str(excinfo.value)