Example #1
def test_sarsa_iterate_value_q_pi_with_trajectory_planning():

    random_state = RandomState(12345)
    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)
    q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.05, None)
    mdp_agent = ActionValueMdpAgent('test', random_state, 1, q_S_A)

    planning_environment = TrajectorySamplingMdpPlanningEnvironment(
        'test planning', random_state, StochasticEnvironmentModel(), 10, None)

    iterate_value_q_pi(agent=mdp_agent,
                       environment=mdp_environment,
                       num_improvements=100,
                       num_episodes_per_improvement=1,
                       num_updates_per_improvement=None,
                       alpha=0.1,
                       mode=Mode.SARSA,
                       n_steps=1,
                       planning_environment=planning_environment,
                       make_final_policy_greedy=True)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_td_iteration_of_value_q_pi_planning.pickle', 'wb') as file:
    #     pickle.dump((mdp_agent.pi, q_S_A), file)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_td_iteration_of_value_q_pi_planning.pickle',
            'rb') as file:
        pi_fixture, q_S_A_fixture = pickle.load(file)

    assert tabular_pi_legacy_eq(mdp_agent.pi, pi_fixture)
    assert tabular_estimator_legacy_eq(q_S_A, q_S_A_fixture)
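
Trajectory-sampling planning interleaves simulated experience drawn from a learned model with the direct SARSA updates above. The following is a minimal sketch of such a planning loop, reusing the StochasticEnvironmentModel sampling interface exercised in Example #5 below; the q dictionary, alpha, gamma, the reward.r attribute, and the Q-learning-style target (the test above runs SARSA) are illustrative assumptions, not the library's implementation:

def plan_with_trajectory_sampling(model, q, random_state, num_updates, alpha=0.1, gamma=1.0):
    # Replay simulated transitions drawn from the learned model and apply
    # one-step value updates to q, a hypothetical {(state, action): value} dict.
    for _ in range(num_updates):
        # start each simulated step from a previously observed state
        state = model.sample_state(random_state)
        action = model.sample_action(state, random_state)
        next_state, reward = model.sample_next_state_and_reward(state, action, random_state)

        # greedy bootstrap over the successor's actions (Q-learning-style target);
        # reward.r is assumed to hold the scalar reward value
        target = reward.r + gamma * max(
            (q.get((next_state, a), 0.0) for a in next_state.AA),
            default=0.0
        )
        q[(state, action)] = q.get((state, action), 0.0) + alpha * (target - q.get((state, action), 0.0))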
Example #2
def test_invalid_iterate_value_q_pi():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.0, None)

    # target agent
    mdp_agent = StochasticMdpAgent('test', random_state,
                                   q_S_A.get_initial_policy(), 1)

    # episode generation (behavior) policy
    off_policy_agent = StochasticMdpAgent('test', random_state,
                                          q_S_A.get_initial_policy(), 1)

    with pytest.raises(
            ValueError,
            match='Planning environments are not currently supported for Monte Carlo iteration.'
    ):
        iterate_value_q_pi(
            agent=mdp_agent,
            environment=mdp_environment,
            num_improvements=100,
            num_episodes_per_improvement=1,
            update_upon_every_visit=True,
            planning_environment=TrajectorySamplingMdpPlanningEnvironment(
                'foo', random_state, StochasticEnvironmentModel(), 100, None),
            make_final_policy_greedy=False,
            q_S_A=q_S_A,
            off_policy_agent=off_policy_agent)

    # test the warning raised when there is no off-policy agent and epsilon is 0.0
    q_S_A.epsilon = 0.0
    iterate_value_q_pi(agent=mdp_agent,
                       environment=mdp_environment,
                       num_improvements=100,
                       num_episodes_per_improvement=1,
                       update_upon_every_visit=True,
                       planning_environment=None,
                       make_final_policy_greedy=False,
                       q_S_A=q_S_A,
                       off_policy_agent=None)
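
The first call is expected to fail because Monte Carlo iteration has no planning support. A plausible shape for the guard that produces this error, shown as a hypothetical sketch rather than the library's actual code:

def validate_planning_environment(planning_environment):
    # Monte Carlo iteration learns from complete sampled episodes, so a
    # model-based planning environment cannot be used; reject it up front.
    if planning_environment is not None:
        raise ValueError(
            'Planning environments are not currently supported for Monte Carlo iteration.'
        )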
Example #3
    @classmethod
    def init_from_arguments(
            cls, args: List[str],
            random_state: RandomState) -> Tuple[Environment, List[str]]:
        """
        Initialize an environment from arguments.

        :param args: Arguments.
        :param random_state: Random state.
        :return: 2-tuple of an environment and a list of unparsed arguments.
        """

        parsed_args, unparsed_args = parse_arguments(cls, args)

        planning_environment = cls(name='trajectory planning',
                                   random_state=random_state,
                                   model=StochasticEnvironmentModel(),
                                   **vars(parsed_args))

        return planning_environment, unparsed_args
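
A hedged usage sketch: assuming every argument parse_arguments recognizes for this class has a default, the factory can be called with only downstream flags, which come back in the unparsed list for later consumers (the flag name below is hypothetical):

from numpy.random import RandomState

random_state = RandomState(12345)
planning_environment, unparsed = TrajectorySamplingMdpPlanningEnvironment.init_from_arguments(
    args=['--some-downstream-flag'],  # hypothetical flag, returned unparsed
    random_state=random_state
)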
Example #4
def test_prioritized_planning_environment():

    rng = RandomState(12345)

    planning_environment = PrioritizedSweepingMdpPlanningEnvironment(
        'test', rng, StochasticEnvironmentModel(), 1, 0.3, 10)

    planning_environment.add_state_action_priority(MdpState(1, [], False),
                                                   Action(1), 0.2)
    planning_environment.add_state_action_priority(MdpState(2, [], False),
                                                   Action(2), 0.1)
    planning_environment.add_state_action_priority(MdpState(3, [], False),
                                                   Action(3), 0.3)

    s, a = planning_environment.get_state_action_with_highest_priority()
    assert s.i == 2 and a.i == 2
    s, a = planning_environment.get_state_action_with_highest_priority()
    assert s.i == 1 and a.i == 1
    s, a = planning_environment.get_state_action_with_highest_priority()
    assert s is None and a is None
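
The assertions encode two properties of the priority queue: lower numeric values mean higher priority (0.2 is dequeued after 0.1), and entries whose priority is not strictly below the threshold passed to the constructor (0.3 here) are never queued, which is why the third call returns (None, None). A minimal heap-backed sketch of that behavior, independent of the library's actual implementation:

import heapq

class PrioritySweepQueue:
    # Minimal sketch of prioritized-sweeping bookkeeping: a min-heap keyed on
    # priority, with a threshold that entries must fall below to be queued.

    def __init__(self, priority_theta):
        self.priority_theta = priority_theta
        self.heap = []
        self.counter = 0  # tie-breaker so heapq never compares states/actions

    def add(self, state, action, priority):
        # drop updates whose priority does not beat the threshold
        if priority < self.priority_theta:
            heapq.heappush(self.heap, (priority, self.counter, state, action))
            self.counter += 1

    def pop(self):
        if self.heap:
            _, _, state, action = heapq.heappop(self.heap)
            return state, action
        return None, None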
Example #5
def test_stochastic_environment_model():

    random_state = RandomState(12345)

    model = StochasticEnvironmentModel()

    actions = [
        Action(i)
        for i in range(5)
    ]

    states = [
        State(i, actions)
        for i in range(5)
    ]

    for _ in range(1000):
        state = sample_list_item(states, None, random_state)
        action = sample_list_item(state.AA, None, random_state)
        next_state = sample_list_item(states, None, random_state)
        reward = Reward(None, random_state.randint(10))
        model.update(state, action, next_state, reward)

    environment_sequence = []
    for _ in range(1000):
        state = model.sample_state(random_state)
        action = model.sample_action(state, random_state)
        next_state, reward = model.sample_next_state_and_reward(state, action, random_state)
        environment_sequence.append((next_state, reward))

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_stochastic_environment_model.pickle', 'wb') as file:
    #     pickle.dump(environment_sequence, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_stochastic_environment_model.pickle', 'rb') as file:
        environment_sequence_fixture = pickle.load(file)

    assert environment_sequence == environment_sequence_fixture
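
The exact fixture match works because every draw above is routed through the shared RandomState. One way such a model could be implemented, offered as a sketch under the assumption of hashable states, actions, and rewards rather than as the library's actual code, is to keep visit counts per (state, action) and sample outcomes in proportion to how often they were observed:

from collections import Counter, defaultdict

class CountingEnvironmentModel:
    # Hypothetical count-based model: update() tallies observed transitions,
    # and the sample_* methods draw from those tallies.

    def __init__(self):
        # (state, action) -> Counter of (next_state, reward) observations
        self.transition_counts = defaultdict(Counter)

    def update(self, state, action, next_state, reward):
        self.transition_counts[(state, action)][(next_state, reward)] += 1

    def sample_state(self, random_state):
        # sample uniformly among states observed at least once
        states = list({s for s, _ in self.transition_counts})
        return states[random_state.randint(len(states))]

    def sample_action(self, state, random_state):
        actions = [a for s, a in self.transition_counts if s == state]
        return actions[random_state.randint(len(actions))]

    def sample_next_state_and_reward(self, state, action, random_state):
        # sample an observed (next_state, reward) pair with probability
        # proportional to its visit count
        outcomes = self.transition_counts[(state, action)]
        pairs = list(outcomes)
        total = sum(outcomes.values())
        index = random_state.choice(len(pairs), p=[outcomes[p] / total for p in pairs])
        return pairs[index]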