Example #1
def test_manual_versus_jax_policy_gradient():

    manual_agent_path = tempfile.NamedTemporaryFile(delete=False).name
    run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.ParameterizedMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 10 --policy rlai.policies.parameterized.discrete_action.SoftMaxInActionPreferencesPolicy --policy-feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor --alpha 0.0001 --update-upon-every-visit True --save-agent-path {manual_agent_path} --log DEBUG'
        ))
    with open(manual_agent_path, 'rb') as f:
        manual_agent = pickle.load(f)

    jax_agent_path = tempfile.NamedTemporaryFile(delete=False).name
    run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.ParameterizedMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 10 --policy rlai.policies.parameterized.discrete_action.SoftMaxInActionPreferencesJaxPolicy --policy-feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor --alpha 0.0001 --update-upon-every-visit True --save-agent-path {jax_agent_path} --log DEBUG'
        ))
    with open(jax_agent_path, 'rb') as f:
        jax_agent = pickle.load(f)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_manual_versus_jax_policy_gradient.pickle', 'wb') as file:
    #     pickle.dump(jax_agent, file)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_manual_versus_jax_policy_gradient.pickle',
            'rb') as file:
        fixture_agent = pickle.load(file)

    assert np.allclose(manual_agent.pi.theta, jax_agent.pi.theta)
    assert np.allclose(jax_agent.pi.theta, fixture_agent.pi.theta)
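
The comparison hinges on np.allclose, which checks element-wise closeness within a tolerance rather than exact equality. A minimal, self-contained sketch of the same check, using hypothetical parameter vectors in place of the agents' theta arrays:

import numpy as np

# hypothetical stand-ins for manual_agent.pi.theta and jax_agent.pi.theta
theta_manual = np.array([0.10000001, -0.2, 0.3])
theta_jax = np.array([0.1, -0.2, 0.30000002])

# passes when |a - b| <= atol + rtol * |b| holds element-wise (numpy defaults shown)
assert np.allclose(theta_manual, theta_jax, rtol=1e-05, atol=1e-08)
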
Example #2
def test_unparsed_arguments():

    with pytest.raises(ValueError, match='Unparsed arguments'):
        run(
            shlex.split(
                '--agent rlai.agents.mdp.StochasticMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING --num-improvements 10 --num-episodes-per-improvement 5 --epsilon 0.01 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True --XXXX'
            ))
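
The test relies on pytest.raises with a match pattern, which asserts both the exception type and that its message matches the given regular expression. A self-contained sketch of the same pattern, using a hypothetical argparse-based helper rather than rlai's own argument handling:

import argparse

import pytest


def parse_or_raise(args):
    # hypothetical helper: reject any arguments the parser does not recognize
    parser = argparse.ArgumentParser()
    parser.add_argument('--gamma', type=float)
    parsed, unparsed = parser.parse_known_args(args)
    if unparsed:
        raise ValueError(f'Unparsed arguments:  {unparsed}')
    return parsed


def test_unparsed_arguments_sketch():
    with pytest.raises(ValueError, match='Unparsed arguments'):
        parse_or_raise(['--gamma', '1', '--XXXX'])
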
Example #3
def test_resume():

    checkpoint_path, agent_path = run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.ParameterizedMdpAgent --gamma 1.0 --environment rlai.environments.openai_gym.Gym --gym-id LunarLanderContinuous-v2 --plot-environment --T 500 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 2 --plot-state-value True --v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator --feature-extractor rlai.environments.openai_gym.ContinuousLunarLanderFeatureExtractor --function-approximation-model rlai.models.sklearn.SKLearnSGD --loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.0001 --policy rlai.policies.parameterized.continuous_action.ContinuousActionBetaDistributionPolicy --policy-feature-extractor rlai.environments.openai_gym.ContinuousLunarLanderFeatureExtractor --plot-policy --alpha 0.0001 --update-upon-every-visit True --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --num-episodes-per-checkpoint 1 --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name} --log DEBUG'
        ))

    _, resumed_agent_path = run(
        shlex.split(
            f'--resume --random-seed 12345 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 2 --checkpoint-path {checkpoint_path} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
        ))

    with open(resumed_agent_path, 'rb') as f:
        agent = pickle.load(f)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_resume.pickle', 'wb') as file:
    #     pickle.dump(agent, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_resume.pickle',
              'rb') as file:
        agent_fixture = pickle.load(file)

    # assert that we get the expected result
    assert agent.pi == agent_fixture.pi
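
The commented-out dump followed by a load-and-compare is the fixture-regeneration pattern used throughout these tests: uncomment the dump once to record the current agent as the expected result, then leave only the load active so later runs are checked against it. A stripped-down sketch of that pattern with a hypothetical helper and fixture path:

import pickle


def check_against_fixture(obj, fixture_path, update=False):
    # hypothetical helper: when update is True, record obj as the new expectation
    if update:
        with open(fixture_path, 'wb') as f:
            pickle.dump(obj, f)

    with open(fixture_path, 'rb') as f:
        fixture = pickle.load(f)

    assert obj == fixture
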
Example #4
def test_too_many_coefficients_for_plot_model():

    old_vals = (rlai.q_S_A.function_approximation.models.MAX_PLOT_COEFFICIENTS,
                rlai.q_S_A.function_approximation.models.MAX_PLOT_ACTIONS)

    rlai.q_S_A.function_approximation.models.MAX_PLOT_COEFFICIENTS = 2
    rlai.q_S_A.function_approximation.models.MAX_PLOT_ACTIONS = 2

    run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --T 25 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 10 --num-episodes-per-improvement 50 --epsilon 0.05 --q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator --plot-model --plot-model-bins 10 --function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD --feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor --make-final-policy-greedy True --num-improvements-per-checkpoint 5 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
        ))

    (rlai.q_S_A.function_approximation.models.MAX_PLOT_COEFFICIENTS,
     rlai.q_S_A.function_approximation.models.MAX_PLOT_ACTIONS) = old_vals
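
Saving and restoring the module-level plotting limits by hand works, but the restore at the end is skipped if the run raises. A sketch of an alternative using pytest's monkeypatch fixture, which undoes the change automatically whether the test passes or fails; the module and attribute names mirror the ones above:

import rlai.q_S_A.function_approximation.models


def test_plot_limits_patched(monkeypatch):
    # monkeypatch restores both attributes when the test finishes, pass or fail
    monkeypatch.setattr(rlai.q_S_A.function_approximation.models, 'MAX_PLOT_COEFFICIENTS', 2)
    monkeypatch.setattr(rlai.q_S_A.function_approximation.models, 'MAX_PLOT_ACTIONS', 2)

    assert rlai.q_S_A.function_approximation.models.MAX_PLOT_COEFFICIENTS == 2
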
Example #5
def test_resume_gym_valid_environment():

    start_virtual_display_if_headless()

    def resume_args_mutator(resume_args: Dict):
        print(f'Called mutator:  {len(resume_args)} resume arguments.')

    def train_function_args_callback(args: Dict):
        print(f'Called callback:  {len(args)} training function arguments.')

    run_args = f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --continuous-state-discretization-resolution 0.005 --gamma 0.95 --environment rlai.environments.openai_gym.Gym --gym-id CartPole-v1 --render-every-nth-episode 2 --train-function rlai.gpi.monte_carlo.iteration.iterate_value_q_pi --num-improvements 2 --num-episodes-per-improvement 2 --update-upon-every-visit True --epsilon 0.2 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy False --num-improvements-per-plot 2 --num-improvements-per-checkpoint 2 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    checkpoint_path, agent_path = run(
        args=shlex.split(run_args),
        train_function_args_callback=train_function_args_callback)

    random_state = RandomState(12345)
    resume_environment = Gym(random_state, None, 'CartPole-v1', None)
    agent = resume_from_checkpoint(checkpoint_path,
                                   iterate_value_q_pi,
                                   environment=resume_environment,
                                   num_improvements=2,
                                   resume_args_mutator=resume_args_mutator)

    resume_environment.close()

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle', 'wb') as file:
    #     pickle.dump(agent.pi, file)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle',
            'rb') as file:
        pi_fixture = pickle.load(file)

    assert agent.pi == pi_fixture
Example #6
def test_resume():

    checkpoint_path, agent_path = run(shlex.split(f'--random-seed 12345 --agent rlai.agents.mdp.ActionValueMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING --num-improvements 10 --num-episodes-per-improvement 5 --epsilon 0.01 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True --num-improvements-per-checkpoint 10 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'))

    _, resumed_agent_path = run(shlex.split(f'--resume --random-seed 12345 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --num-improvements 10 --checkpoint-path {checkpoint_path} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'))

    with open(resumed_agent_path, 'rb') as f:
        agent = pickle.load(f)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_resume.pickle', 'wb') as file:
    #     pickle.dump(agent, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_resume.pickle', 'rb') as file:
        agent_fixture = pickle.load(file)

    assert agent.pi == agent_fixture.pi
Example #7
def test_train():

    checkpoint_path_top_level, agent_path_top_level = top_level.run(
        shlex.split(
            f'train --random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING --num-improvements 10 --num-episodes-per-improvement 5 --epsilon 0.01 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True --num-improvements-per-checkpoint 10 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
        ))
    with open(agent_path_top_level, 'rb') as f:
        agent_top_level = pickle.load(f)

    checkpoint_path_train, agent_path_train = trainer.run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING --num-improvements 10 --num-episodes-per-improvement 5 --epsilon 0.01 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True --num-improvements-per-checkpoint 10 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
        ))
    with open(agent_path_train, 'rb') as f:
        agent_train = pickle.load(f)

    assert agent_top_level.pi == agent_train.pi
Example #8
def test_gym_cartpole_function_approximation_plot_model():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(shlex.split(f'--random-seed 12345 --agent rlai.agents.mdp.ActionValueMdpAgent --gamma 0.95 --environment rlai.environments.openai_gym.Gym --gym-id CartPole-v1 --render-every-nth-episode 2 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 2 --num-episodes-per-improvement 2 --num-updates-per-improvement 1 --epsilon 0.2 --q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator --plot-model --plot-model-bins 10 --function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD --loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.001 --feature-extractor rlai.environments.openai_gym.CartpoleFeatureExtractor --make-final-policy-greedy True --num-improvements-per-plot 2 --num-improvements-per-checkpoint 2 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'))

    _, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_gym_cartpole_function_approximation_plot_model.pickle', 'wb') as f:
    #     pickle.dump(agent, f)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_gym_cartpole_function_approximation_plot_model.pickle', 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(
        agent,
        agent_fixture
    )
Example #9
def test_q_learning_with_patsy_formula():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(shlex.split(f'--random-seed 12345 --agent rlai.agents.mdp.ActionValueMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --T 25 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING --num-improvements 5 --num-episodes-per-improvement 5 --epsilon 0.05 --q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator --function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD --verbose 1 --feature-extractor rlai.q_S_A.function_approximation.models.feature_extraction.StateActionIdentityFeatureExtractor --formula "C(s, levels={list(range(16))}):C(a, levels={list(range(4))})" --make-final-policy-greedy True --num-improvements-per-checkpoint 5 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'))

    _, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_q_learning_with_patsy_formula.pickle', 'wb') as f:
    #     pickle.dump(agent, f)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_q_learning_with_patsy_formula.pickle', 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(
        agent,
        agent_fixture
    )
Example #10
def test_prioritized_sweeping_planning_high_threshold():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(shlex.split(f'--random-seed 12345 --agent rlai.agents.mdp.ActionValueMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --planning-environment rlai.environments.mdp.PrioritizedSweepingMdpPlanningEnvironment --num-planning-improvements-per-direct-improvement 10 --priority-theta -10 --T-planning 50 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING --num-improvements 10 --num-episodes-per-improvement 1 --epsilon 0.01 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True --num-improvements-per-checkpoint 10 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'))

    _, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_prioritized_sweeping_planning_high_threshold.pickle', 'wb') as f:
    #     pickle.dump(agent, f)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_prioritized_sweeping_planning_high_threshold.pickle', 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(
        agent,
        agent_fixture
    )
Example #11
def test_policy_gradient_reinforce_softmax_action_preferences_with_baseline():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(shlex.split(f'--random-seed 12345 --agent rlai.agents.mdp.ParameterizedMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --T 100 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 10 --v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator --feature-extractor rlai.environments.gridworld.GridworldStateFeatureExtractor --function-approximation-model rlai.models.sklearn.SKLearnSGD --loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.001 --policy rlai.policies.parameterized.discrete_action.SoftMaxInActionPreferencesPolicy --policy-feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor --alpha 0.001 --update-upon-every-visit False --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'))

    _, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_policy_gradient_reinforce_softmax_action_preferences_with_baseline.pickle', 'wb') as f:
    #     pickle.dump(agent, f)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_policy_gradient_reinforce_softmax_action_preferences_with_baseline.pickle', 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(
        agent,
        agent_fixture
    )
Example #12
def test_policy_gradient_reinforce_normal_with_baseline():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(shlex.split(f'--random-seed 12345 --agent rlai.agents.mdp.ParameterizedMdpAgent --gamma 0.99 --environment rlai.environments.openai_gym.Gym --gym-id LunarLanderContinuous-v2 --render-every-nth-episode 2 --steps-per-second 1000 --plot-environment --T 2000 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 4 --v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator --feature-extractor rlai.environments.openai_gym.ContinuousFeatureExtractor --function-approximation-model rlai.models.sklearn.SKLearnSGD --loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.00001 --policy rlai.policies.parameterized.continuous_action.ContinuousActionNormalDistributionPolicy --policy-feature-extractor rlai.environments.openai_gym.ContinuousFeatureExtractor --plot-policy --alpha 0.00001 --update-upon-every-visit True --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'))

    _, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_policy_gradient_reinforce_normal_with_baseline.pickle', 'wb') as f:
    #     pickle.dump(agent, f)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_policy_gradient_reinforce_normal_with_baseline.pickle', 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(
        agent,
        agent_fixture
    )
Example #13
def test_gym_cartpole_tabular():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(shlex.split(f'--random-seed 12345 --agent rlai.agents.mdp.ActionValueMdpAgent --continuous-state-discretization-resolution 0.005 --gamma 0.95 --environment rlai.environments.openai_gym.Gym --gym-id CartPole-v1 --render-every-nth-episode 2 --train-function rlai.gpi.monte_carlo.iteration.iterate_value_q_pi --num-improvements 2 --num-episodes-per-improvement 2 --update-upon-every-visit True --epsilon 0.2 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True --num-improvements-per-plot 2 --num-improvements-per-checkpoint 2 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'))

    _, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_gym_cartpole_tabular.pickle', 'wb') as f:
    #     pickle.dump(agent, f)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_gym_cartpole_tabular.pickle', 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(
        agent,
        agent_fixture
    )
Example #14
def test_continuous_action_discretization():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --continuous-state-discretization-resolution 0.005 --gamma 0.95 --environment rlai.environments.openai_gym.Gym --gym-id MountainCarContinuous-v0 --T 20 --continuous-action-discretization-resolution 0.1 --render-every-nth-episode 2 --video-directory {tempfile.TemporaryDirectory().name} --force --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 2 --num-episodes-per-improvement 1 --epsilon 0.01 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True --num-improvements-per-plot 2 --num-improvements-per-checkpoint 2 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
        ))

    checkpoint, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_continuous_action_discretization.pickle', 'wb') as f:
    #     pickle.dump((checkpoint, agent), f)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_continuous_action_discretization.pickle',
            'rb') as f:
        checkpoint_fixture, agent_fixture = pickle.load(f)

    assert_run(checkpoint, agent, checkpoint_fixture, agent_fixture)
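
Both throwaway-path idioms above come from the standard library: NamedTemporaryFile(delete=False) creates a file that survives being closed, so only its .name is kept for later writing, while TemporaryDirectory() creates a directory that is removed once the object is finalized, effectively leaving just a fresh path string for --video-directory. A small sketch of the two behaviors:

import os
import tempfile

# the file persists after the handle is discarded because delete=False; only the path is used
file_path = tempfile.NamedTemporaryFile(delete=False).name
assert os.path.exists(file_path)

# the directory is cleaned up when the TemporaryDirectory object is garbage-collected,
# so in practice only a unique path string is retained here
dir_path = tempfile.TemporaryDirectory().name
print(file_path, dir_path)
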
Example #15
def test_scale_learning_rate_with_logging():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --T 25 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING --num-improvements 5 --num-episodes-per-improvement 50 --epsilon 0.05 --q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator --function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD --scale-eta0-for-y --feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor --make-final-policy-greedy True --num-improvements-per-checkpoint 5 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name} --log INFO'
        ))

    checkpoint, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_scale_learning_rate_with_logging.pickle', 'wb') as f:
    #     pickle.dump((checkpoint, agent), f)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_scale_learning_rate_with_logging.pickle',
            'rb') as f:
        checkpoint_fixture, agent_fixture = pickle.load(f)

    assert_run(checkpoint, agent, checkpoint_fixture, agent_fixture)
Example #16
def test_gym_continuous_mountain_car():

    start_virtual_display_if_headless()

    checkpoint_path, agent_path = run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --gamma 0.99 --environment rlai.environments.openai_gym.Gym --gym-id MountainCarContinuous-v0 --plot-environment --T 1000 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 2 --plot-state-value True --v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator --feature-extractor rlai.environments.openai_gym.ContinuousMountainCarFeatureExtractor --function-approximation-model rlai.models.sklearn.SKLearnSGD --loss squared_loss --sgd-alpha 0.0 --learning-rate constant --eta0 0.01 --policy rlai.policies.parameterized.continuous_action.ContinuousActionBetaDistributionPolicy --policy-feature-extractor rlai.environments.openai_gym.ContinuousMountainCarFeatureExtractor --plot-policy --alpha 0.01 --update-upon-every-visit True --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --num-episodes-per-checkpoint 1 --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name} --log DEBUG'
        ))

    checkpoint, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_gym_continuous_mountain_car.pickle', 'wb') as f:
    #     pickle.dump((checkpoint, agent), f)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_gym_continuous_mountain_car.pickle',
            'rb') as f:
        checkpoint_fixture, agent_fixture = pickle.load(f)

    assert_run(checkpoint, agent, checkpoint_fixture, agent_fixture)
Example #17
def test_learn():

    # Set the following to True to update the fixture. If you do this, you'll also need to start the Robocode game and
    # uncomment the relevant code in rlai.environments.network.TcpMdpEnvironment.read_from_client. Run a battle for 10
    # rounds to complete the fixture update.
    update_fixture = False

    robocode_port = 54321

    robocode_mock_thread = None

    if not update_fixture:

        with open(f'{os.path.dirname(__file__)}/fixtures/test_robocode.pickle',
                  'rb') as file:
            state_sequence, fixture_pi, fixture_q_S_A = pickle.load(file)

        # set up a mock Robocode game that sends the recorded state sequence
        def robocode_mock():

            # wait for environment to start up and listen for connections
            time.sleep(5)

            t = 0
            while t < len(state_sequence):

                # start episode by connecting
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.connect(('127.0.0.1', robocode_port))

                    try:

                        while t < len(state_sequence):

                            # send the current game state in the sequence
                            state_dict_json = state_sequence[t]
                            s.sendall(state_dict_json.encode('utf-8'))
                            t += 1

                            # receive next action
                            s.recv(99999999)

                            # if the next state starts a new episode, then break.
                            if t < len(state_sequence):
                                next_state_dict = json.loads(state_sequence[t])
                                if next_state_dict['state']['time'] == 0:
                                    break

                    # if the environment closes the connection during receive, it ends the episode.
                    except Exception:  # pragma: no cover
                        pass

        robocode_mock_thread = Thread(target=robocode_mock)
        robocode_mock_thread.start()

    # run training and load resulting agent
    agent_path = tempfile.NamedTemporaryFile(delete=False).name
    cmd = f'--random-seed 12345 --agent rlai.environments.robocode.RobocodeAgent --gamma 0.95 --environment rlai.environments.robocode.RobocodeEnvironment --port {robocode_port} --bullet-power-decay 0.75 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --n-steps 50 --num-improvements 10 --num-episodes-per-improvement 1 --num-updates-per-improvement 1 --epsilon 0.25 --q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator --function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD --loss squared_loss --sgd-alpha 0.0 --learning-rate constant --eta0 0.0001 --feature-extractor rlai.environments.robocode.RobocodeFeatureExtractor --scanned-robot-decay 0.75 --make-final-policy-greedy True --num-improvements-per-plot 100 --save-agent-path {agent_path} --log DEBUG'
    run(shlex.split(cmd))

    if not update_fixture:
        robocode_mock_thread.join()

    with open(agent_path, 'rb') as f:
        agent = pickle.load(f)

    # if we're updating the test fixture, then save the state sequence and resulting policy to disk.
    if update_fixture:  # pragma: no cover
        with open(os.path.expanduser('~/Desktop/state_sequence.txt'),
                  'r') as f:
            state_sequence = f.readlines()
        with open(f'{os.path.dirname(__file__)}/fixtures/test_robocode.pickle',
                  'wb') as file:
            pickle.dump((state_sequence, agent.pi, agent.pi.estimator), file)
    else:
        assert np.allclose(agent.pi.estimator.model.model.coef_,
                           fixture_q_S_A.model.model.coef_)
        assert np.allclose(agent.pi.estimator.model.model.intercept_,
                           fixture_q_S_A.model.model.intercept_)
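
The final assertions compare the fitted linear model's coefficients and intercept against the fixture. A minimal sketch of where those attributes come from, using a toy scikit-learn SGDRegressor with hypothetical data rather than rlai's SKLearnSGD wrapper:

import numpy as np
from sklearn.linear_model import SGDRegressor

rng = np.random.RandomState(12345)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.01, size=100)

# coef_ and intercept_ are the fitted parameters compared in the assertions above
model = SGDRegressor(random_state=12345)
model.fit(X, y)
print(model.coef_, model.intercept_)
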
Example #18
def test_missing_arguments():

    run(shlex.split('--agent rlai.agents.mdp.ActionValueMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING --num-improvements 10 --num-episodes-per-improvement 5 --epsilon 0.01 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True'))
Example #19
def train_thread_target():
    run(args=args,
        thread_manager=thread_manager,
        train_function_args_callback=train_args_callback)
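
This fragment is a nested thread target that closes over args, thread_manager, and train_args_callback from an enclosing scope. A sketch of how such a target is typically launched and joined with the standard threading module; the wrapper function and its parameters are hypothetical, and run stands for the same trainer entry point called throughout these examples:

from threading import Thread


def launch_training(run, args, thread_manager, train_args_callback):
    # hypothetical wrapper: run training on a background thread and wait for it to finish
    def train_thread_target():
        run(args=args,
            thread_manager=thread_manager,
            train_function_args_callback=train_args_callback)

    training_thread = Thread(target=train_thread_target)
    training_thread.start()
    training_thread.join()
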
Example #20
def test_continuous_learn():

    # Set the following to True to update the fixture. If you do this, you'll also need to start the Robocode game and
    # uncomment the relevant code in rlai.environments.network.TcpMdpEnvironment.read_from_client. Run a battle for 10
    # rounds to complete the fixture update.
    update_fixture = False

    # Updating the fixture uses the game (per above) on port 54321. The test itself must run on a different port to
    # avoid conflicting with the other Robocode test that runs in parallel.
    robocode_port = 54321 if update_fixture else 54322

    robocode_mock_thread = None

    if not update_fixture:

        with open(
                f'{os.path.dirname(__file__)}/fixtures/test_continuous_learn.pickle',
                'rb') as file:
            state_sequence, fixture_pi = pickle.load(file)

        # set up a mock Robocode game that sends the recorded state sequence
        def robocode_mock():

            # wait for environment to start up and listen for connections
            time.sleep(5)

            t = 0
            while t < len(state_sequence):

                # start episode by connecting
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.connect(('127.0.0.1', robocode_port))

                    try:

                        while t < len(state_sequence):

                            # send the current game state in the sequence
                            state_dict_json = state_sequence[t]
                            s.sendall(state_dict_json.encode('utf-8'))
                            t += 1

                            # receive next action
                            s.recv(99999999)

                            # if the next state starts a new episode, then break.
                            if t < len(state_sequence):
                                next_state_dict = json.loads(state_sequence[t])
                                if next_state_dict['state']['time'] == 0:
                                    break

                    # if the environment closes the connection during receive, it ends the episode.
                    except Exception:  # pragma: no cover
                        pass

        robocode_mock_thread = Thread(target=robocode_mock)
        robocode_mock_thread.start()

    # run training and load resulting agent
    agent_path = tempfile.NamedTemporaryFile(delete=False).name
    cmd = f'--random-seed 12345 --agent rlai.environments.robocode_continuous_action.RobocodeAgent --gamma 1.0 --environment rlai.environments.robocode_continuous_action.RobocodeEnvironment --port {robocode_port} --bullet-power-decay 0.75 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 10 --v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator --feature-extractor rlai.environments.robocode_continuous_action.RobocodeFeatureExtractor --function-approximation-model rlai.models.sklearn.SKLearnSGD --loss squared_loss --sgd-alpha 0.0 --learning-rate constant --eta0 0.00001 --policy rlai.policies.parameterized.continuous_action.ContinuousActionBetaDistributionPolicy --policy-feature-extractor rlai.environments.robocode_continuous_action.RobocodeFeatureExtractor --alpha 0.00001 --update-upon-every-visit True --save-agent-path {agent_path} --log DEBUG'

    run(shlex.split(cmd))

    if not update_fixture:
        robocode_mock_thread.join()

    with open(agent_path, 'rb') as f:
        agent = pickle.load(f)

    # if we're updating the test fixture, then save the state sequence and resulting policy to disk.
    if update_fixture:  # pragma: no cover
        with open(os.path.expanduser('~/Desktop/state_sequence.txt'),
                  'r') as f:
            state_sequence = f.readlines()
        with open(
                f'{os.path.dirname(__file__)}/fixtures/test_continuous_learn.pickle',
                'wb') as file:
            pickle.dump((state_sequence, agent.pi), file)
    else:
        assert agent.pi == fixture_pi
Example #21
def test_help():

    with pytest.raises(ValueError,
                       match='No training function specified. Cannot train.'):
        run(shlex.split('--agent rlai.agents.mdp.StochasticMdpAgent --help'))