Example #1
def test_value_iteration():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    # run policy iteration on v_pi
    mdp_agent_v_pi_policy_iteration = StochasticMdpAgent(
        'test',
        random_state,
        TabularPolicy(None, mdp_environment.SS),
        1
    )

    iterate_policy_v_pi(
        mdp_agent_v_pi_policy_iteration,
        mdp_environment,
        0.001,
        True
    )

    # run value iteration on v_pi
    mdp_agent_v_pi_value_iteration = StochasticMdpAgent(
        'test',
        random_state,
        TabularPolicy(None, mdp_environment.SS),
        1
    )

    iterate_value_v_pi(
        mdp_agent_v_pi_value_iteration,
        mdp_environment,
        0.001,
        1,
        True
    )

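    # policy iteration and value iteration should arrive at the same policy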
    assert mdp_agent_v_pi_policy_iteration.pi == mdp_agent_v_pi_value_iteration.pi

    # run value iteration on q_pi
    mdp_agent_q_pi_value_iteration = StochasticMdpAgent(
        'test',
        random_state,
        TabularPolicy(None, mdp_environment.SS),
        1
    )

    iterate_value_q_pi(
        mdp_agent_q_pi_value_iteration,
        mdp_environment,
        0.001,
        1,
        True
    )

    assert mdp_agent_q_pi_value_iteration.pi == mdp_agent_v_pi_policy_iteration.pi
Example #2
def test_run():

    random_state = RandomState(12345)

    mdp_environment: GamblersProblem = GamblersProblem(
        'gamblers problem',
        random_state=random_state,
        T=None,
        p_h=0.4
    )

    agent = StochasticMdpAgent(
        'test',
        random_state,
        TabularPolicy(None, mdp_environment.SS),
        1
    )

    monitor = Monitor()
    state = mdp_environment.reset_for_new_run(agent)
    agent.reset_for_new_run(state)
    mdp_environment.run(agent, monitor)

    # uncomment the following lines and run test to update fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_run.pickle', 'wb') as file:
    #     pickle.dump(monitor, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_run.pickle', 'rb') as file:
        fixture = pickle.load(file)

    assert monitor.t_average_reward == fixture.t_average_reward
Example #3
def test_gamblers_problem():

    random_state = RandomState(12345)

    mdp_environment: GamblersProblem = GamblersProblem(
        'gamblers problem',
        random_state=random_state,
        T=None,
        p_h=0.4
    )

    mdp_agent_v_pi_value_iteration = StochasticMdpAgent(
        'test',
        random_state,
        TabularPolicy(None, mdp_environment.SS),
        1
    )

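    # run value iteration on v_pi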
    v_pi = iterate_value_v_pi(
        mdp_agent_v_pi_value_iteration,
        mdp_environment,
        0.001,
        1,
        True
    )

    # uncomment the following lines and run test to update fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_gamblers_problem.pickle', 'wb') as file:
    #     pickle.dump(v_pi, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_gamblers_problem.pickle', 'rb') as file:
        fixture = pickle.load(file)

    assert v_pi == fixture
Example #4
def test_human_player_mutator():

    random = RandomState()
    mancala = Mancala(
        random, None, 5,
        StochasticMdpAgent('foo', random, TabularPolicy(None, []), 1.0))
    Mancala.human_player_mutator(mancala)

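    # the mutator should replace player 2 with a human agent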
    assert isinstance(mancala.player_2, Human)
Example #5
    def __init__(self):
        """
        Initialize the agent.
        """

        super().__init__(name='human', random_state=None)

        # TODO:  This is a hack to make the human agent compatible with tabular methods, which request state
        # identifiers during operation.
        self.pi = TabularPolicy(None, None)
Example #6
def dump_agent() -> str:

    # create dummy mdp agent for runner
    stochastic_mdp_agent = StochasticMdpAgent('foo', RandomState(12345),
                                              TabularPolicy(None, None), 1.0)
    agent_path = tempfile.NamedTemporaryFile(delete=False).name
    with open(agent_path, 'wb') as f:
        pickle.dump(stochastic_mdp_agent, f)

    return agent_path
Example #7
    def get_initial_policy(self) -> TabularPolicy:
        """
        Get the initial policy defined by the estimator.

        :return: Policy.
        """

        return TabularPolicy(
            continuous_state_discretization_resolution=self.continuous_state_discretization_resolution,
            SS=self.SS
        )
Example #8
def test_policy_iteration():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    # state-value policy iteration
    mdp_agent_v_pi = StochasticMdpAgent(
        'test',
        random_state,
        TabularPolicy(None, mdp_environment.SS),
        1
    )

    iterate_policy_v_pi(
        mdp_agent_v_pi,
        mdp_environment,
        0.001,
        True
    )

    # action-value policy iteration
    mdp_agent_q_pi = StochasticMdpAgent(
        'test',
        random_state,
        TabularPolicy(None, mdp_environment.SS),
        1
    )

    iterate_policy_q_pi(
        mdp_agent_q_pi,
        mdp_environment,
        0.001,
        True
    )

    # should get the same policy
    assert mdp_agent_v_pi.pi == mdp_agent_q_pi.pi
Example #9
def test_off_policy_monte_carlo_with_function_approximation():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = ApproximateStateActionValueEstimator(
        mdp_environment, 0.05,
        SKLearnSGD(random_state=random_state, scale_eta0_for_y=False),
        GridworldFeatureExtractor(mdp_environment), None, False, None, None)

    # target agent
    mdp_agent = StochasticMdpAgent('test', random_state,
                                   q_S_A.get_initial_policy(), 1)

    # episode generation (behavior) policy
    off_policy_agent = StochasticMdpAgent('test', random_state,
                                          TabularPolicy(None, None), 1)

    iterate_value_q_pi(agent=mdp_agent,
                       environment=mdp_environment,
                       num_improvements=100,
                       num_episodes_per_improvement=1,
                       update_upon_every_visit=True,
                       planning_environment=None,
                       make_final_policy_greedy=False,
                       q_S_A=q_S_A,
                       off_policy_agent=off_policy_agent)

    # uncomment the following lines and run test to update fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_off_policy_monte_carlo_with_function_approximationo.pickle', 'wb') as file:
    #     pickle.dump((mdp_agent.pi, q_S_A), file)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_off_policy_monte_carlo_with_function_approximationo.pickle',
            'rb') as file:
        pi_fixture, q_S_A_fixture = pickle.load(file)

    assert mdp_agent.pi == pi_fixture and q_S_A == q_S_A_fixture
    assert str(mdp_agent.pi.estimator[mdp_environment.SS[5]][
        mdp_environment.SS[5].AA[1]]).startswith('-1.4524')

    # make greedy
    q_S_A.epsilon = 0.0
    assert q_S_A.improve_policy(
        mdp_agent, None, PolicyImprovementEvent.MAKING_POLICY_GREEDY) == -1
    assert mdp_agent.pi.estimator.epsilon == 0.0
Example #10
def test_agent_invalid_action():

    random = RandomState()
    agent = StochasticMdpAgent('foo', random, TabularPolicy(None, None), 1.0)

    # test None action
    agent.__act__ = lambda t: None

    with pytest.raises(ValueError, match='Agent returned action of None'):
        agent.act(0)

    # test infeasible action
    action = Action(1, 'foo')
    agent.__act__ = lambda t: action
    state = MdpState(1, [], False)
    agent.sense(state, 0)
    with pytest.raises(ValueError, match=f'Action {action} is not feasible in state {state}'):
        agent.act(0)
Example #11
def test_invalid_get_state_i():

    policy = TabularPolicy(None, None)

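    # discretizing a continuous (array) state without a resolution should raise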
    with pytest.raises(
            ValueError,
            match='Attempted to discretize a continuous state without a resolution.'
    ):
        policy.get_state_i(np.array([[1, 2, 3]]))

    with pytest.raises(ValueError,
                       match=f'Unknown state space type:  {type(3)}'):
        policy.get_state_i(3)
Example #12
    def init_from_arguments(
            cls, args: List[str],
            random_state: RandomState) -> Tuple[Environment, List[str]]:
        """
        Initialize an environment from arguments.

        :param args: Arguments.
        :param random_state: Random state.
        :return: 2-tuple of an environment and a list of unparsed arguments.
        """

        parsed_args, unparsed_args = parse_arguments(cls, args)

        mancala = cls(
            random_state=random_state,
            player_2=StochasticMdpAgent(
                'environmental agent',
                random_state,
                TabularPolicy(None, None),
                1
            ),
            **vars(parsed_args)
        )

        return mancala, unparsed_args
Example #13
def test_evaluate_v_pi():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    mdp_agent = StochasticMdpAgent('test', random_state,
                                   TabularPolicy(None, mdp_environment.SS), 1)

    v_pi = evaluate_v_pi(agent=mdp_agent,
                         environment=mdp_environment,
                         num_episodes=1000)

    # uncomment the following lines and run test to update fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_monte_carlo_evaluation_of_state_value.pickle', 'wb') as file:
    #     pickle.dump(v_pi, file)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_monte_carlo_evaluation_of_state_value.pickle',
            'rb') as file:
        fixture = pickle.load(file)

    assert v_pi == fixture
Example #14
def test_evaluate_q_pi():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    mdp_agent = StochasticMdpAgent('test', random_state,
                                   TabularPolicy(None, mdp_environment.SS), 1)

    q_pi, _ = evaluate_q_pi(agent=mdp_agent,
                            environment=mdp_environment,
                            theta=0.001,
                            num_iterations=100,
                            update_in_place=True)

    q_pi_not_in_place, _ = evaluate_q_pi(agent=mdp_agent,
                                         environment=mdp_environment,
                                         theta=0.001,
                                         num_iterations=200,
                                         update_in_place=False)

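    # in-place and non-in-place evaluation should cover the same states and produce similar values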
    assert list(q_pi.keys()) == list(q_pi_not_in_place.keys())

    for s in q_pi:
        for a in q_pi[s]:
            assert np.allclose(q_pi[s][a], q_pi_not_in_place[s][a], atol=0.01)

    # uncomment the following lines and run test to update fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_iterative_policy_evaluation_of_action_value.pickle', 'wb') as file:
    #     pickle.dump(q_pi, file)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_iterative_policy_evaluation_of_action_value.pickle',
            'rb') as file:
        fixture = pickle.load(file)

    assert q_pi == fixture
Example #15
def test_learn():

    random_state = RandomState(12345)

    mancala: Mancala = Mancala(random_state=random_state,
                               T=None,
                               initial_count=4,
                               player_2=StochasticMdpAgent(
                                   'player 2', random_state,
                                   TabularPolicy(None, None), 1))

    q_S_A = TabularStateActionValueEstimator(mancala, 0.05, None)

    p1 = StochasticMdpAgent('player 1', random_state,
                            q_S_A.get_initial_policy(), 1)

    checkpoint_path = iterate_value_q_pi(
        agent=p1,
        environment=mancala,
        num_improvements=3,
        num_episodes_per_improvement=100,
        update_upon_every_visit=False,
        planning_environment=None,
        make_final_policy_greedy=False,
        q_S_A=q_S_A,
        num_improvements_per_checkpoint=3,
        checkpoint_path=tempfile.NamedTemporaryFile(delete=False).name)

    # uncomment the following lines and run test to update fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_mancala.pickle', 'wb') as file:
    #     pickle.dump(p1.pi, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_mancala.pickle',
              'rb') as file:
        fixture = pickle.load(file)

    assert tabular_pi_legacy_eq(p1.pi, fixture)

    resumed_p1 = resume_from_checkpoint(checkpoint_path=checkpoint_path,
                                        resume_function=iterate_value_q_pi,
                                        num_improvements=2)

    # run same number of improvements without checkpoint...result should be the same.
    random_state = RandomState(12345)

    mancala: Mancala = Mancala(random_state=random_state,
                               T=None,
                               initial_count=4,
                               player_2=StochasticMdpAgent(
                                   'player 2', random_state,
                                   TabularPolicy(None, None), 1))

    q_S_A = TabularStateActionValueEstimator(mancala, 0.05, None)

    no_checkpoint_p1 = StochasticMdpAgent('player 1', random_state,
                                          q_S_A.get_initial_policy(), 1)

    iterate_value_q_pi(agent=no_checkpoint_p1,
                       environment=mancala,
                       num_improvements=5,
                       num_episodes_per_improvement=100,
                       update_upon_every_visit=False,
                       planning_environment=None,
                       make_final_policy_greedy=False,
                       q_S_A=q_S_A)

    assert no_checkpoint_p1.pi == resumed_p1.pi
Example #16
def test_policy_not_equal():

    policy_1 = TabularPolicy(None, None)
    policy_2 = TabularPolicy(None, None)

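    # two freshly constructed tabular policies should not compare as unequal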
    assert not (policy_1 != policy_2)
Example #17
def get_initial_policy() -> TabularPolicy:
    return TabularPolicy(None, None)
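Taken together, the examples above construct TabularPolicy from two arguments: a continuous-state discretization resolution (None for discrete state spaces) and a list of states SS (None when the states are not known up front). The sketch below shows both construction patterns; the import path and the 0.1 resolution are assumptions for illustration and may differ across rlai versions.

# Minimal sketch of the two TabularPolicy construction patterns seen above. The
# module path is an assumption -- adjust it to your rlai installation.
from rlai.policies.tabular import TabularPolicy

# discrete state space, with states registered later (the most common pattern above)
policy = TabularPolicy(None, None)

# continuous state space, discretized at an illustrative resolution of 0.1
# (keyword form mirrors get_initial_policy in Example #7)
policy_continuous = TabularPolicy(
    continuous_state_discretization_resolution=0.1,
    SS=None
)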