Example #1
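First-visit Monte Carlo iteration of the action-value function and policy (update_upon_every_visit=False) on the Gridworld of Sutton and Barto's Example 4.1, using a tabular estimator with epsilon=0.1. The learned policy and estimator are checked against a pickled fixture.

The examples on this page share some setup. The standard-library and NumPy imports below are exact; the rlai-specific names (Gridworld, StochasticMdpAgent, TabularStateActionValueEstimator, iterate_value_q_pi, and so on) must be imported from the rlai package, whose module layout varies across versions, so those import lines are omitted here rather than guessed.

import os
import pickle
import tempfile

import pytest
from numpy.random import RandomState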
def test_iterate_value_q_pi():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.1, None)

    mdp_agent = StochasticMdpAgent('test', random_state,
                                   q_S_A.get_initial_policy(), 1)

    iterate_value_q_pi(agent=mdp_agent,
                       environment=mdp_environment,
                       num_improvements=3000,
                       num_episodes_per_improvement=1,
                       update_upon_every_visit=False,
                       planning_environment=None,
                       make_final_policy_greedy=False,
                       q_S_A=q_S_A)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_monte_carlo_iteration_of_value_q_pi.pickle', 'wb') as file:
    #     pickle.dump((mdp_agent.pi, q_S_A), file)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_monte_carlo_iteration_of_value_q_pi.pickle',
            'rb') as file:
        pi_fixture, q_S_A_fixture = pickle.load(file)

    assert tabular_pi_legacy_eq(mdp_agent.pi, pi_fixture)
    assert tabular_estimator_legacy_eq(q_S_A, q_S_A_fixture)
Example #2
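As in Example #1, but progress plots are rendered to a temporary PDF every 1500 improvements. The test then exercises epsilon handling in improve_policy: a negative epsilon raises a ValueError, and a greedy improvement with epsilon=0.0 is asserted to return 14.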
def test_iterate_value_q_pi_with_pdf():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.1, None)

    mdp_agent = StochasticMdpAgent('test', random_state,
                                   q_S_A.get_initial_policy(), 1)

    iterate_value_q_pi(
        agent=mdp_agent,
        environment=mdp_environment,
        num_improvements=3000,
        num_episodes_per_improvement=1,
        update_upon_every_visit=False,
        planning_environment=None,
        make_final_policy_greedy=False,
        q_S_A=q_S_A,
        num_improvements_per_plot=1500,
        pdf_save_path=tempfile.NamedTemporaryFile(delete=False).name)

    with pytest.raises(ValueError, match='Epsilon must be >= 0'):
        q_S_A.epsilon = -1.0
        q_S_A.improve_policy(mdp_agent,
                             states=None,
                             event=PolicyImprovementEvent.MAKING_POLICY_GREEDY)

    q_S_A.epsilon = 0.0
    assert q_S_A.improve_policy(
        mdp_agent, None, PolicyImprovementEvent.MAKING_POLICY_GREEDY) == 14
Example #3
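Off-policy Monte Carlo with function approximation: the target agent estimates action values with SKLearnSGD over features from GridworldFeatureExtractor, while a separate tabular agent serves as the episode-generating (behavior) policy. Every-visit updates are used, the results are compared against a pickled fixture, and the policy is finally made greedy.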
def test_off_policy_monte_carlo_with_function_approximation():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = ApproximateStateActionValueEstimator(
        mdp_environment,
        0.05,
        SKLearnSGD(random_state=random_state, scale_eta0_for_y=False),
        GridworldFeatureExtractor(mdp_environment),
        None,
        False,
        None,
        None
    )

    # target agent
    mdp_agent = ActionValueMdpAgent(
        'test',
        random_state,
        1,
        q_S_A
    )

    # episode generation (behavior) policy
    off_policy_agent = ActionValueMdpAgent(
        'test',
        random_state,
        1,
        TabularStateActionValueEstimator(mdp_environment, None, None)
    )

    iterate_value_q_pi(
        agent=mdp_agent,
        environment=mdp_environment,
        num_improvements=100,
        num_episodes_per_improvement=1,
        update_upon_every_visit=True,
        planning_environment=None,
        make_final_policy_greedy=False,
        off_policy_agent=off_policy_agent
    )

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_off_policy_monte_carlo_with_function_approximationo.pickle', 'wb') as file:
    #     pickle.dump((mdp_agent.pi, q_S_A), file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_off_policy_monte_carlo_with_function_approximationo.pickle', 'rb') as file:
        pi_fixture, q_S_A_fixture = pickle.load(file)

    assert mdp_agent.pi == pi_fixture and q_S_A == q_S_A_fixture
    assert str(mdp_agent.pi.estimator[mdp_environment.SS[5]][mdp_environment.SS[5].AA[1]]).startswith('-2.4305')

    # make greedy
    q_S_A.epsilon = 0.0
    assert q_S_A.improve_policy(mdp_agent, None, PolicyImprovementEvent.MAKING_POLICY_GREEDY) == -1
    assert mdp_agent.pi.estimator.epsilon == 0.0
Example #4
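Trains a tabular agent on Mancala, checkpointing after 3 improvements, then resumes from the checkpoint for 2 more. A second, uninterrupted 5-improvement run with the same random seed is verified to produce the same policy.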
def test_learn():

    random_state = RandomState(12345)

    mancala: Mancala = Mancala(random_state=random_state,
                               T=None,
                               initial_count=4,
                               player_2=None)

    p1 = ActionValueMdpAgent(
        'player 1', random_state, 1,
        TabularStateActionValueEstimator(mancala, 0.05, None))

    checkpoint_path = iterate_value_q_pi(
        agent=p1,
        environment=mancala,
        num_improvements=3,
        num_episodes_per_improvement=100,
        update_upon_every_visit=False,
        planning_environment=None,
        make_final_policy_greedy=False,
        num_improvements_per_checkpoint=3,
        checkpoint_path=tempfile.NamedTemporaryFile(delete=False).name)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_mancala.pickle', 'wb') as file:
    #     pickle.dump(p1.pi, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_mancala.pickle',
              'rb') as file:
        fixture = pickle.load(file)

    assert tabular_pi_legacy_eq(p1.pi, fixture)

    resumed_p1 = resume_from_checkpoint(checkpoint_path=checkpoint_path,
                                        resume_function=iterate_value_q_pi,
                                        num_improvements=2)

    # run the same total number of improvements (3 checkpointed + 2 resumed = 5)
    # without checkpointing...the result should be the same.
    random_state = RandomState(12345)
    mancala: Mancala = Mancala(random_state=random_state,
                               T=None,
                               initial_count=4,
                               player_2=None)
    no_checkpoint_p1 = ActionValueMdpAgent(
        'player 1', random_state, 1,
        TabularStateActionValueEstimator(mancala, 0.05, None))

    iterate_value_q_pi(agent=no_checkpoint_p1,
                       environment=mancala,
                       num_improvements=5,
                       num_episodes_per_improvement=100,
                       update_upon_every_visit=False,
                       planning_environment=None,
                       make_final_policy_greedy=False)

    assert no_checkpoint_p1.pi == resumed_p1.pi
Example #5
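Exercises argument validation: passing a planning environment to Monte Carlo iteration raises a ValueError, and running with epsilon=0.0 and no off-policy agent triggers a warning rather than an error.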
def test_invalid_iterate_value_q_pi():

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.0, None)

    # target agent
    mdp_agent = StochasticMdpAgent('test', random_state,
                                   q_S_A.get_initial_policy(), 1)

    # episode generation (behavior) policy
    off_policy_agent = StochasticMdpAgent('test', random_state,
                                          q_S_A.get_initial_policy(), 1)

    with pytest.raises(
            ValueError,
            match='Planning environments are not currently supported for Monte Carlo iteration.'
    ):
        iterate_value_q_pi(
            agent=mdp_agent,
            environment=mdp_environment,
            num_improvements=100,
            num_episodes_per_improvement=1,
            update_upon_every_visit=True,
            planning_environment=TrajectorySamplingMdpPlanningEnvironment(
                'foo', random_state, StochasticEnvironmentModel(), 100, None),
            make_final_policy_greedy=False,
            q_S_A=q_S_A,
            off_policy_agent=off_policy_agent)

    # test the warning that is issued when epsilon=0.0 and no off-policy agent is supplied
    q_S_A.epsilon = 0.0
    iterate_value_q_pi(agent=mdp_agent,
                       environment=mdp_environment,
                       num_improvements=100,
                       num_episodes_per_improvement=1,
                       update_upon_every_visit=True,
                       planning_environment=None,
                       make_final_policy_greedy=False,
                       q_S_A=q_S_A,
                       off_policy_agent=None)
Example #6
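An inner thread target for testing interruptible training: thread_manager is captured from the enclosing test's scope, and the very large num_improvements keeps the run alive so that the enclosing test can pause or stop it through the manager. A hypothetical driver sketch follows the code.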
    def train_thread_target():
        random_state = RandomState(12345)

        mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

        q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.1, None)

        mdp_agent = StochasticMdpAgent('test', random_state,
                                       q_S_A.get_initial_policy(), 1)

        iterate_value_q_pi(agent=mdp_agent,
                           environment=mdp_environment,
                           num_improvements=1000000,
                           num_episodes_per_improvement=10,
                           update_upon_every_visit=False,
                           planning_environment=None,
                           make_final_policy_greedy=False,
                           q_S_A=q_S_A,
                           thread_manager=thread_manager,
                           num_improvements_per_plot=10)
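For context, here is a minimal sketch of how an enclosing test might drive train_thread_target. The names are taken from the example above, but thread_manager's Event-like pause/resume behavior is an assumption about the rlai API rather than a confirmed interface.

import threading
import time

# hypothetical driver: thread_manager is assumed to exist in the enclosing
# scope and to behave like a threading.Event, with clear() pausing and set()
# resuming the training loop.
training_thread = threading.Thread(target=train_thread_target)
training_thread.start()

time.sleep(1.0)         # let a few improvements run
thread_manager.clear()  # pause training (assumed Event-like behavior)
time.sleep(1.0)
thread_manager.set()    # resume training

# Stopping and joining the thread depends on the manager's version-specific
# abort mechanism, which is omitted here.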