def test_manual_versus_jax_policy_gradient():
    """
    Check that the hand-written and the JAX-based soft-max action-preference policies learn the same parameters.

    Both agents are trained with REINFORCE on gridworld example 4.1 using identical settings (including the random
    seed); only the policy class differs. The JAX agent's parameters are additionally compared with a stored fixture.
    """

    def train_and_load(policy_class_name: str):
        # each run pickles its trained agent into a fresh temporary file that is not deleted on close
        saved_agent_path = tempfile.NamedTemporaryFile(delete=False).name
        cli = ' '.join([
            '--random-seed 12345',
            '--agent rlai.agents.mdp.ParameterizedMdpAgent',
            '--gamma 1',
            '--environment rlai.environments.gridworld.Gridworld',
            '--id example_4_1',
            '--train-function rlai.policy_gradient.monte_carlo.reinforce.improve',
            '--num-episodes 10',
            f'--policy rlai.policies.parameterized.discrete_action.{policy_class_name}',
            '--policy-feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor',
            '--alpha 0.0001',
            '--update-upon-every-visit True',
            f'--save-agent-path {saved_agent_path}',
            '--log DEBUG'
        ])
        run(shlex.split(cli))
        with open(saved_agent_path, 'rb') as agent_file:
            return pickle.load(agent_file)

    manual_agent = train_and_load('SoftMaxInActionPreferencesPolicy')
    jax_agent = train_and_load('SoftMaxInActionPreferencesJaxPolicy')

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_manual_versus_jax_policy_gradient.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as fixture_file:
    #     pickle.dump(jax_agent, fixture_file)

    with open(fixture_path, 'rb') as fixture_file:
        fixture_agent = pickle.load(fixture_file)

    assert np.allclose(manual_agent.pi.theta, jax_agent.pi.theta)
    assert np.allclose(jax_agent.pi.theta, fixture_agent.pi.theta)
def test_unparsed_arguments():
    """An unrecognized command-line flag (``--XXXX``) must make ``run`` raise ``ValueError('Unparsed arguments...')``."""

    argv = shlex.split(' '.join([
        '--agent rlai.agents.mdp.StochasticMdpAgent',
        '--gamma 1',
        '--environment rlai.environments.gridworld.Gridworld --id example_4_1',
        '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING',
        '--num-improvements 10 --num-episodes-per-improvement 5',
        '--epsilon 0.01',
        '--q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator',
        '--make-final-policy-greedy True',
        '--XXXX'
    ]))

    with pytest.raises(ValueError, match='Unparsed arguments'):
        run(argv)
def test_resume():
    """
    Train a REINFORCE agent (beta-distribution policy with an SGD state-value baseline) on LunarLanderContinuous-v2
    for 2 episodes with per-episode checkpoints, resume from the checkpoint for 2 more episodes, and compare the
    resumed policy with the stored fixture.

    NOTE(review): another ``test_resume`` appears in this source; if both are defined in the same module, the later
    definition shadows this one and this test never runs.
    """

    train_cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.ParameterizedMdpAgent',
        '--gamma 1.0',
        '--environment rlai.environments.openai_gym.Gym --gym-id LunarLanderContinuous-v2',
        '--plot-environment --T 500',
        '--train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 2',
        '--plot-state-value True',
        '--v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator',
        '--feature-extractor rlai.environments.openai_gym.ContinuousLunarLanderFeatureExtractor',
        '--function-approximation-model rlai.models.sklearn.SKLearnSGD',
        '--loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.0001',
        '--policy rlai.policies.parameterized.continuous_action.ContinuousActionBetaDistributionPolicy',
        '--policy-feature-extractor rlai.environments.openai_gym.ContinuousLunarLanderFeatureExtractor',
        '--plot-policy',
        '--alpha 0.0001',
        '--update-upon-every-visit True',
        f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
        '--num-episodes-per-checkpoint 1',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}',
        '--log DEBUG'
    ])
    checkpoint_path, _ = run(shlex.split(train_cli))

    resume_cli = ' '.join([
        '--resume',
        '--random-seed 12345',
        '--train-function rlai.policy_gradient.monte_carlo.reinforce.improve',
        '--num-episodes 2',
        f'--checkpoint-path {checkpoint_path}',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    _, resumed_agent_path = run(shlex.split(resume_cli))

    with open(resumed_agent_path, 'rb') as f:
        resumed_agent = pickle.load(f)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_resume.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as file:
    #     pickle.dump(resumed_agent, file)

    with open(fixture_path, 'rb') as file:
        agent_fixture = pickle.load(file)

    # the resumed policy must match the recorded one
    assert resumed_agent.pi == agent_fixture.pi
def test_too_many_coefficients_for_plot_model():
    """
    Run SARSA with model plotting enabled while the plot limits are artificially low.

    ``MAX_PLOT_COEFFICIENTS`` and ``MAX_PLOT_ACTIONS`` in
    ``rlai.q_S_A.function_approximation.models`` are temporarily set to 2 so that the trained model exceeds them.
    The original values are restored in a ``finally`` clause, so a failing run cannot leak the lowered limits into
    tests that execute afterwards in the same process.
    """

    old_vals = (
        rlai.q_S_A.function_approximation.models.MAX_PLOT_COEFFICIENTS,
        rlai.q_S_A.function_approximation.models.MAX_PLOT_ACTIONS
    )
    rlai.q_S_A.function_approximation.models.MAX_PLOT_COEFFICIENTS = 2
    rlai.q_S_A.function_approximation.models.MAX_PLOT_ACTIONS = 2

    try:
        run(
            shlex.split(
                f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --gamma 1 --environment rlai.environments.gridworld.Gridworld --id example_4_1 --T 25 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 10 --num-episodes-per-improvement 50 --epsilon 0.05 --q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator --plot-model --plot-model-bins 10 --function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD --feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor --make-final-policy-greedy True --num-improvements-per-checkpoint 5 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
            )
        )
    finally:
        # always restore the module-level limits, even if the run above raised
        (
            rlai.q_S_A.function_approximation.models.MAX_PLOT_COEFFICIENTS,
            rlai.q_S_A.function_approximation.models.MAX_PLOT_ACTIONS
        ) = old_vals
def test_resume_gym_valid_environment():
    """
    Train tabular Monte Carlo control on CartPole-v1, then resume from the checkpoint with a freshly constructed Gym
    environment and compare the resumed policy with the stored fixture.

    Passes a ``train_function_args_callback`` to ``run`` and a ``resume_args_mutator`` to
    ``resume_from_checkpoint``; both only print how many arguments they received. The resume environment is closed
    in a ``finally`` clause so it is released even when resuming raises.
    """

    start_virtual_display_if_headless()

    def resume_args_mutator(resume_args: Dict):
        print(f'Called mutator: {len(resume_args)} resume arguments.')

    def train_function_args_callback(args: Dict):
        print(f'Called callback: {len(args)} resume arguments.')

    run_args = f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --continuous-state-discretization-resolution 0.005 --gamma 0.95 --environment rlai.environments.openai_gym.Gym --gym-id CartPole-v1 --render-every-nth-episode 2 --train-function rlai.gpi.monte_carlo.iteration.iterate_value_q_pi --num-improvements 2 --num-episodes-per-improvement 2 --update-upon-every-visit True --epsilon 0.2 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy False --num-improvements-per-plot 2 --num-improvements-per-checkpoint 2 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'

    # only the checkpoint is needed to resume; the saved agent path is not used
    checkpoint_path, _ = run(
        args=shlex.split(run_args),
        train_function_args_callback=train_function_args_callback
    )

    random_state = RandomState(12345)
    resume_environment = Gym(random_state, None, 'CartPole-v1', None)
    try:
        agent = resume_from_checkpoint(
            checkpoint_path,
            iterate_value_q_pi,
            environment=resume_environment,
            num_improvements=2,
            resume_args_mutator=resume_args_mutator
        )
    finally:
        # release the environment whether or not resuming succeeded
        resume_environment.close()

    # uncomment the following line and run test to update fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle', 'wb') as file:
    #     pickle.dump(agent.pi, file)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle',
            'rb'
    ) as file:
        pi_fixture = pickle.load(file)

    assert agent.pi == pi_fixture
def test_resume():
    """
    Train tabular Q-learning on gridworld example 4.1 with a checkpoint, resume from that checkpoint for 10 more
    improvements, and compare the resumed policy with the stored fixture.
    """

    train_cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.ActionValueMdpAgent',
        '--gamma 1',
        '--environment rlai.environments.gridworld.Gridworld --id example_4_1',
        '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING',
        '--num-improvements 10 --num-episodes-per-improvement 5',
        '--epsilon 0.01',
        '--q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator',
        '--make-final-policy-greedy True',
        '--num-improvements-per-checkpoint 10',
        f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    checkpoint_path, _ = run(shlex.split(train_cli))

    resume_cli = ' '.join([
        '--resume',
        '--random-seed 12345',
        '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi',
        '--num-improvements 10',
        f'--checkpoint-path {checkpoint_path}',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    _, resumed_agent_path = run(shlex.split(resume_cli))

    with open(resumed_agent_path, 'rb') as f:
        resumed_agent = pickle.load(f)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_resume.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as file:
    #     pickle.dump(resumed_agent, file)

    with open(fixture_path, 'rb') as file:
        agent_fixture = pickle.load(file)

    assert resumed_agent.pi == agent_fixture.pi
def test_train():
    """
    Training through the top-level CLI's ``train`` sub-command and through the trainer's ``run`` directly, with the
    same arguments, must produce equal policies.
    """

    def q_learning_cli() -> str:
        # shared argument string; every call allocates fresh checkpoint and agent temporary files
        return ' '.join([
            '--random-seed 12345',
            '--agent rlai.agents.mdp.StochasticMdpAgent',
            '--gamma 1',
            '--environment rlai.environments.gridworld.Gridworld --id example_4_1',
            '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING',
            '--num-improvements 10 --num-episodes-per-improvement 5',
            '--epsilon 0.01',
            '--q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator',
            '--make-final-policy-greedy True',
            '--num-improvements-per-checkpoint 10',
            f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
            f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
        ])

    _, top_level_agent_path = top_level.run(shlex.split('train ' + q_learning_cli()))
    with open(top_level_agent_path, 'rb') as f:
        top_level_agent = pickle.load(f)

    _, trainer_agent_path = trainer.run(shlex.split(q_learning_cli()))
    with open(trainer_agent_path, 'rb') as f:
        trainer_agent = pickle.load(f)

    assert top_level_agent.pi == trainer_agent.pi
def test_gym_cartpole_function_approximation_plot_model():
    """
    Train SARSA with an SGD-backed approximate state-action value estimator on CartPole-v1, with model plotting
    enabled, and compare the trained agent with the stored fixture.
    """

    start_virtual_display_if_headless()

    cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.ActionValueMdpAgent',
        '--gamma 0.95',
        '--environment rlai.environments.openai_gym.Gym --gym-id CartPole-v1 --render-every-nth-episode 2',
        '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA',
        '--num-improvements 2 --num-episodes-per-improvement 2 --num-updates-per-improvement 1',
        '--epsilon 0.2',
        '--q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator',
        '--plot-model --plot-model-bins 10',
        '--function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD',
        '--loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.001',
        '--feature-extractor rlai.environments.openai_gym.CartpoleFeatureExtractor',
        '--make-final-policy-greedy True',
        '--num-improvements-per-plot 2 --num-improvements-per-checkpoint 2',
        f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    checkpoint_path, agent_path = run(shlex.split(cli))
    _, trained_agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_gym_cartpole_function_approximation_plot_model.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as f:
    #     pickle.dump(trained_agent, f)

    with open(fixture_path, 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(trained_agent, agent_fixture)
def test_q_learning_with_patsy_formula():
    """
    Train Q-learning on gridworld example 4.1 using an identity state-action feature extractor combined with a
    patsy-style ``--formula`` (categorical state levels 0-15 crossed with action levels 0-3), and compare the trained
    agent with the stored fixture.
    """

    start_virtual_display_if_headless()

    state_levels = list(range(16))
    action_levels = list(range(4))

    cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.ActionValueMdpAgent',
        '--gamma 1',
        '--environment rlai.environments.gridworld.Gridworld --id example_4_1 --T 25',
        '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING',
        '--num-improvements 5 --num-episodes-per-improvement 5',
        '--epsilon 0.05',
        '--q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator',
        '--function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD --verbose 1',
        '--feature-extractor rlai.q_S_A.function_approximation.models.feature_extraction.StateActionIdentityFeatureExtractor',
        f'--formula "C(s, levels={state_levels}):C(a, levels={action_levels})"',
        '--make-final-policy-greedy True',
        '--num-improvements-per-checkpoint 5',
        f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    checkpoint_path, agent_path = run(shlex.split(cli))
    _, trained_agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_q_learning_with_patsy_formula.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as f:
    #     pickle.dump(trained_agent, f)

    with open(fixture_path, 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(trained_agent, agent_fixture)
def test_prioritized_sweeping_planning_high_threshold():
    """
    Train tabular Q-learning on gridworld example 4.1 with a prioritized-sweeping planning environment
    (``--priority-theta -10``, 10 planning improvements per direct improvement), and compare the trained agent with
    the stored fixture.
    """

    start_virtual_display_if_headless()

    cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.ActionValueMdpAgent',
        '--gamma 1',
        '--environment rlai.environments.gridworld.Gridworld --id example_4_1',
        '--planning-environment rlai.environments.mdp.PrioritizedSweepingMdpPlanningEnvironment',
        '--num-planning-improvements-per-direct-improvement 10',
        '--priority-theta -10',
        '--T-planning 50',
        '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING',
        '--num-improvements 10 --num-episodes-per-improvement 1',
        '--epsilon 0.01',
        '--q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator',
        '--make-final-policy-greedy True',
        '--num-improvements-per-checkpoint 10',
        f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    checkpoint_path, agent_path = run(shlex.split(cli))
    _, trained_agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_prioritized_sweeping_planning_high_threshold.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as f:
    #     pickle.dump(trained_agent, f)

    with open(fixture_path, 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(trained_agent, agent_fixture)
def test_policy_gradient_reinforce_softmax_action_preferences_with_baseline():
    """
    Train REINFORCE with a soft-max action-preference policy and an SGD state-value baseline on gridworld example 4.1,
    and compare the trained agent with the stored fixture. No ``--checkpoint-path`` is passed to this run.
    """

    start_virtual_display_if_headless()

    cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.ParameterizedMdpAgent',
        '--gamma 1',
        '--environment rlai.environments.gridworld.Gridworld --id example_4_1 --T 100',
        '--train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 10',
        '--v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator',
        '--feature-extractor rlai.environments.gridworld.GridworldStateFeatureExtractor',
        '--function-approximation-model rlai.models.sklearn.SKLearnSGD',
        '--loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.001',
        '--policy rlai.policies.parameterized.discrete_action.SoftMaxInActionPreferencesPolicy',
        '--policy-feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor',
        '--alpha 0.001',
        '--update-upon-every-visit False',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    checkpoint_path, agent_path = run(shlex.split(cli))
    _, trained_agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    fixture_path = (
        f'{os.path.dirname(__file__)}/fixtures/'
        'test_policy_gradient_reinforce_softmax_action_preferences_with_baseline.pickle'
    )

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as f:
    #     pickle.dump(trained_agent, f)

    with open(fixture_path, 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(trained_agent, agent_fixture)
def test_policy_gradient_reinforce_normal_with_baseline():
    """
    Train REINFORCE with a normal-distribution continuous-action policy and an SGD state-value baseline on
    LunarLanderContinuous-v2 (with environment and policy plotting), and compare the trained agent with the stored
    fixture. No ``--checkpoint-path`` is passed to this run.
    """

    start_virtual_display_if_headless()

    cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.ParameterizedMdpAgent',
        '--gamma 0.99',
        '--environment rlai.environments.openai_gym.Gym --gym-id LunarLanderContinuous-v2',
        '--render-every-nth-episode 2 --steps-per-second 1000 --plot-environment --T 2000',
        '--train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 4',
        '--v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator',
        '--feature-extractor rlai.environments.openai_gym.ContinuousFeatureExtractor',
        '--function-approximation-model rlai.models.sklearn.SKLearnSGD',
        '--loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.00001',
        '--policy rlai.policies.parameterized.continuous_action.ContinuousActionNormalDistributionPolicy',
        '--policy-feature-extractor rlai.environments.openai_gym.ContinuousFeatureExtractor',
        '--plot-policy',
        '--alpha 0.00001',
        '--update-upon-every-visit True',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    checkpoint_path, agent_path = run(shlex.split(cli))
    _, trained_agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_policy_gradient_reinforce_normal_with_baseline.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as f:
    #     pickle.dump(trained_agent, f)

    with open(fixture_path, 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(trained_agent, agent_fixture)
def test_gym_cartpole_tabular():
    """
    Train tabular Monte Carlo control on CartPole-v1 with continuous-state discretization (resolution 0.005), and
    compare the trained agent with the stored fixture.
    """

    start_virtual_display_if_headless()

    cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.ActionValueMdpAgent',
        '--continuous-state-discretization-resolution 0.005',
        '--gamma 0.95',
        '--environment rlai.environments.openai_gym.Gym --gym-id CartPole-v1 --render-every-nth-episode 2',
        '--train-function rlai.gpi.monte_carlo.iteration.iterate_value_q_pi',
        '--num-improvements 2 --num-episodes-per-improvement 2',
        '--update-upon-every-visit True',
        '--epsilon 0.2',
        '--q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator',
        '--make-final-policy-greedy True',
        '--num-improvements-per-plot 2 --num-improvements-per-checkpoint 2',
        f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
    ])
    checkpoint_path, agent_path = run(shlex.split(cli))
    _, trained_agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_gym_cartpole_tabular.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as f:
    #     pickle.dump(trained_agent, f)

    with open(fixture_path, 'rb') as f:
        agent_fixture = pickle.load(f)

    assert_run(trained_agent, agent_fixture)
def test_continuous_action_discretization():
    """
    Train tabular SARSA on MountainCarContinuous-v0 with discretized continuous states (resolution 0.005) and actions
    (resolution 0.1) while recording video, then compare the resulting checkpoint and agent with the stored fixture.
    """

    start_virtual_display_if_headless()

    # mkdtemp returns a directory that stays on disk for the whole run (like the delete=False temp files used for the
    # checkpoint/agent paths). The previous ``tempfile.TemporaryDirectory().name`` idiom kept no reference to the
    # TemporaryDirectory object, so in CPython its finalizer ran immediately: the directory was deleted before the
    # run started and a ResourceWarning ("Implicitly cleaning up ...") was emitted.
    video_directory = tempfile.mkdtemp()

    checkpoint_path, agent_path = run(
        shlex.split(
            f'--random-seed 12345 --agent rlai.agents.mdp.StochasticMdpAgent --continuous-state-discretization-resolution 0.005 --gamma 0.95 --environment rlai.environments.openai_gym.Gym --gym-id MountainCarContinuous-v0 --T 20 --continuous-action-discretization-resolution 0.1 --render-every-nth-episode 2 --video-directory {video_directory} --force --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 2 --num-episodes-per-improvement 1 --epsilon 0.01 --q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator --make-final-policy-greedy True --num-improvements-per-plot 2 --num-improvements-per-checkpoint 2 --checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name} --save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}'
        )
    )

    checkpoint, agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    # uncomment the following line and run test to update fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_continuous_action_discretization.pickle', 'wb') as f:
    #     pickle.dump((checkpoint, agent), f)

    with open(
            f'{os.path.dirname(__file__)}/fixtures/test_continuous_action_discretization.pickle',
            'rb'
    ) as f:
        checkpoint_fixture, agent_fixture = pickle.load(f)

    assert_run(checkpoint, agent, checkpoint_fixture, agent_fixture)
def test_scale_learning_rate_with_logging():
    """
    Train Q-learning with an SGD function approximator using ``--scale-eta0-for-y`` and INFO-level logging on
    gridworld example 4.1, then compare the resulting checkpoint and agent with the stored fixture.
    """

    start_virtual_display_if_headless()

    cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.StochasticMdpAgent',
        '--gamma 1',
        '--environment rlai.environments.gridworld.Gridworld --id example_4_1 --T 25',
        '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING',
        '--num-improvements 5 --num-episodes-per-improvement 50',
        '--epsilon 0.05',
        '--q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator',
        '--function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD',
        '--scale-eta0-for-y',
        '--feature-extractor rlai.environments.gridworld.GridworldFeatureExtractor',
        '--make-final-policy-greedy True',
        '--num-improvements-per-checkpoint 5',
        f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}',
        '--log INFO'
    ])
    checkpoint_path, agent_path = run(shlex.split(cli))
    loaded_checkpoint, trained_agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_scale_learning_rate_with_logging.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as f:
    #     pickle.dump((loaded_checkpoint, trained_agent), f)

    with open(fixture_path, 'rb') as f:
        checkpoint_fixture, agent_fixture = pickle.load(f)

    assert_run(loaded_checkpoint, trained_agent, checkpoint_fixture, agent_fixture)
def test_gym_continuous_mountain_car():
    """
    Train REINFORCE with a beta-distribution continuous-action policy and an SGD state-value baseline on
    MountainCarContinuous-v0, checkpointing every episode, and compare the resulting checkpoint and agent with the
    stored fixture.

    NOTE(review): this run passes ``--loss squared_loss`` while several other tests in this source use
    ``squared_error``; scikit-learn deprecated ``squared_loss`` in 1.0 and removed it in 1.2, so confirm the pinned
    scikit-learn version still accepts it (changing it may also require regenerating the fixture).
    """

    start_virtual_display_if_headless()

    cli = ' '.join([
        '--random-seed 12345',
        '--agent rlai.agents.mdp.StochasticMdpAgent',
        '--gamma 0.99',
        '--environment rlai.environments.openai_gym.Gym --gym-id MountainCarContinuous-v0',
        '--plot-environment --T 1000',
        '--train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 2',
        '--plot-state-value True',
        '--v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator',
        '--feature-extractor rlai.environments.openai_gym.ContinuousMountainCarFeatureExtractor',
        '--function-approximation-model rlai.models.sklearn.SKLearnSGD',
        '--loss squared_loss --sgd-alpha 0.0 --learning-rate constant --eta0 0.01',
        '--policy rlai.policies.parameterized.continuous_action.ContinuousActionBetaDistributionPolicy',
        '--policy-feature-extractor rlai.environments.openai_gym.ContinuousMountainCarFeatureExtractor',
        '--plot-policy',
        '--alpha 0.01',
        '--update-upon-every-visit True',
        f'--checkpoint-path {tempfile.NamedTemporaryFile(delete=False).name}',
        '--num-episodes-per-checkpoint 1',
        f'--save-agent-path {tempfile.NamedTemporaryFile(delete=False).name}',
        '--log DEBUG'
    ])
    checkpoint_path, agent_path = run(shlex.split(cli))
    loaded_checkpoint, trained_agent = load_checkpoint_and_agent(checkpoint_path, agent_path)

    fixture_path = f'{os.path.dirname(__file__)}/fixtures/test_gym_continuous_mountain_car.pickle'

    # to regenerate the fixture, uncomment the following and run the test:
    # with open(fixture_path, 'wb') as f:
    #     pickle.dump((loaded_checkpoint, trained_agent), f)

    with open(fixture_path, 'rb') as f:
        checkpoint_fixture, agent_fixture = pickle.load(f)

    assert_run(loaded_checkpoint, trained_agent, checkpoint_fixture, agent_fixture)
def test_learn():
    """
    Train a SARSA robocode agent against a mock robocode game that replays a recorded sequence of game states.

    Normally (``update_fixture`` is False) a background thread replays the JSON states stored in
    ``fixtures/test_robocode.pickle`` over TCP to 127.0.0.1 on ``robocode_port``. After training, the agent's SGD
    model coefficients and intercepts are compared with those of the stored fixture estimator. When
    ``update_fixture`` is True, the fixture is regenerated instead (see the comment below for the manual steps).
    """

    # set the following to True to update the fixture. if you do this, then you'll also need to start the robocode game
    # and uncomment some stuff in rlai.environments.network.TcpMdpEnvironment.read_from_client in order to update the
    # test fixture. run a battle for 10 rounds to complete the fixture update.
    update_fixture = False
    robocode_port = 54321
    robocode_mock_thread = None

    if not update_fixture:

        # the fixture holds the recorded state sequence (JSON strings), the reference policy and the reference
        # state-action value estimator. NOTE(review): fixture_pi is unpacked but not used by the assertions below.
        with open(f'{os.path.dirname(__file__)}/fixtures/test_robocode.pickle', 'rb') as file:
            state_sequence, fixture_pi, fixture_q_S_A = pickle.load(file)

        # set up a mock robocode game that sends state sequence
        def robocode_mock():

            # wait for environment to start up and listen for connections
            time.sleep(5)

            # index of the next state to send; it carries over across connections (episodes)
            t = 0
            while t < len(state_sequence):

                # start episode by connecting. leaving this `with` closes the socket; the outer loop then reconnects
                # for the next episode.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.connect(('127.0.0.1', robocode_port))

                    try:
                        while t < len(state_sequence):

                            # send the current game state in the sequence
                            state_dict_json = state_sequence[t]
                            s.sendall(state_dict_json.encode('utf-8'))
                            t += 1

                            # receive next action. the reply's content is discarded; this only waits for the agent
                            # to respond before the next state is sent.
                            s.recv(99999999)

                            # if the next state starts a new episode, then break.
                            if t < len(state_sequence):
                                next_state_dict = json.loads(state_sequence[t])
                                if next_state_dict['state']['time'] == 0:
                                    break

                    # if environment closes connection during receive, it ends the episode.
                    except Exception:  # pragma no cover
                        pass

        robocode_mock_thread = Thread(target=robocode_mock)
        robocode_mock_thread.start()

    # run training and load resulting agent
    agent_path = tempfile.NamedTemporaryFile(delete=False).name
    cmd = f'--random-seed 12345 --agent rlai.environments.robocode.RobocodeAgent --gamma 0.95 --environment rlai.environments.robocode.RobocodeEnvironment --port {robocode_port} --bullet-power-decay 0.75 --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --n-steps 50 --num-improvements 10 --num-episodes-per-improvement 1 --num-updates-per-improvement 1 --epsilon 0.25 --q-S-A rlai.q_S_A.function_approximation.estimators.ApproximateStateActionValueEstimator --function-approximation-model rlai.q_S_A.function_approximation.models.sklearn.SKLearnSGD --loss squared_loss --sgd-alpha 0.0 --learning-rate constant --eta0 0.0001 --feature-extractor rlai.environments.robocode.RobocodeFeatureExtractor --scanned-robot-decay 0.75 --make-final-policy-greedy True --num-improvements-per-plot 100 --save-agent-path {agent_path} --log DEBUG'
    run(shlex.split(cmd))

    # make sure the mock game has finished replaying before inspecting the result
    if not update_fixture:
        robocode_mock_thread.join()

    with open(agent_path, 'rb') as f:
        agent = pickle.load(f)

    # if we're updating the test fixture, then save the state sequence and resulting policy to disk.
    if update_fixture:  # pragma no cover
        with open(os.path.expanduser('~/Desktop/state_sequence.txt'), 'r') as f:
            state_sequence = f.readlines()
        with open(f'{os.path.dirname(__file__)}/fixtures/test_robocode.pickle', 'wb') as file:
            pickle.dump((state_sequence, agent.pi, agent.pi.estimator), file)
    else:
        # compare the learned SGD model parameters with the fixture estimator's
        assert np.allclose(agent.pi.estimator.model.model.coef_, fixture_q_S_A.model.model.coef_)
        assert np.allclose(agent.pi.estimator.model.model.intercept_, fixture_q_S_A.model.model.intercept_)
def test_missing_arguments():
    """
    Run tabular Q-learning on gridworld example 4.1 without ``--random-seed``, ``--checkpoint-path`` or
    ``--save-agent-path``.

    Nothing is asserted: the test passes as long as ``run`` completes without raising.
    """

    cli = ' '.join([
        '--agent rlai.agents.mdp.ActionValueMdpAgent',
        '--gamma 1',
        '--environment rlai.environments.gridworld.Gridworld --id example_4_1',
        '--train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode Q_LEARNING',
        '--num-improvements 10 --num-episodes-per-improvement 5',
        '--epsilon 0.01',
        '--q-S-A rlai.q_S_A.tabular.TabularStateActionValueEstimator',
        '--make-final-policy-greedy True'
    ])
    run(shlex.split(cli))
def train_thread_target():
    """
    Thread entry point that runs training and discards its return value.

    ``args``, ``thread_manager`` and ``train_args_callback`` are not defined here; they are resolved from the
    enclosing scope (presumably an enclosing test function — confirm against the caller).
    """

    run(
        args=args,
        thread_manager=thread_manager,
        train_function_args_callback=train_args_callback
    )
def test_continuous_learn():
    """
    Train a continuous-action REINFORCE robocode agent against a mock robocode game that replays recorded game states.

    Normally (``update_fixture`` is False) a background thread replays the JSON states stored in
    ``fixtures/test_continuous_learn.pickle`` over TCP to 127.0.0.1 on ``robocode_port``. The trained policy is then
    compared for equality with the stored fixture policy. When ``update_fixture`` is True, the fixture is regenerated
    instead (see the comment below for the manual steps).
    """

    # set the following to True to update the fixture. if you do this, then you'll also need to start the robocode game
    # and uncomment some stuff in rlai.environments.network.TcpMdpEnvironment.read_from_client in order to update the
    # test fixture. run a battle for 10 rounds to complete the fixture update.
    update_fixture = False

    # updating the fixture uses the game (per above) on 54321. running test needs to be on a different port to avoid
    # conflict with the other robocode test that runs in parallel.
    robocode_port = 54321 if update_fixture else 54322

    robocode_mock_thread = None

    if not update_fixture:

        # the fixture holds the recorded state sequence (JSON strings) and the reference policy
        with open(
                f'{os.path.dirname(__file__)}/fixtures/test_continuous_learn.pickle',
                'rb'
        ) as file:
            state_sequence, fixture_pi = pickle.load(file)

        # set up a mock robocode game that sends state sequence
        def robocode_mock():

            # wait for environment to start up and listen for connections
            time.sleep(5)

            # index of the next state to send; it carries over across connections (episodes)
            t = 0
            while t < len(state_sequence):

                # start episode by connecting. leaving this `with` closes the socket; the outer loop then reconnects
                # for the next episode.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.connect(('127.0.0.1', robocode_port))

                    try:
                        while t < len(state_sequence):

                            # send the current game state in the sequence
                            state_dict_json = state_sequence[t]
                            s.sendall(state_dict_json.encode('utf-8'))
                            t += 1

                            # receive next action. the reply's content is discarded; this only waits for the agent
                            # to respond before the next state is sent.
                            s.recv(99999999)

                            # if the next state starts a new episode, then break.
                            if t < len(state_sequence):
                                next_state_dict = json.loads(state_sequence[t])
                                if next_state_dict['state']['time'] == 0:
                                    break

                    # if environment closes connection during receive, it ends the episode.
                    except Exception:  # pragma no cover
                        pass

        robocode_mock_thread = Thread(target=robocode_mock)
        robocode_mock_thread.start()

    # run training and load resulting agent
    agent_path = tempfile.NamedTemporaryFile(delete=False).name
    cmd = f'--random-seed 12345 --agent rlai.environments.robocode_continuous_action.RobocodeAgent --gamma 1.0 --environment rlai.environments.robocode_continuous_action.RobocodeEnvironment --port {robocode_port} --bullet-power-decay 0.75 --train-function rlai.policy_gradient.monte_carlo.reinforce.improve --num-episodes 10 --v-S rlai.v_S.function_approximation.estimators.ApproximateStateValueEstimator --feature-extractor rlai.environments.robocode_continuous_action.RobocodeFeatureExtractor --function-approximation-model rlai.models.sklearn.SKLearnSGD --loss squared_loss --sgd-alpha 0.0 --learning-rate constant --eta0 0.00001 --policy rlai.policies.parameterized.continuous_action.ContinuousActionBetaDistributionPolicy --policy-feature-extractor rlai.environments.robocode_continuous_action.RobocodeFeatureExtractor --alpha 0.00001 --update-upon-every-visit True --save-agent-path {agent_path} --log DEBUG'
    run(shlex.split(cmd))

    # make sure the mock game has finished replaying before inspecting the result
    if not update_fixture:
        robocode_mock_thread.join()

    with open(agent_path, 'rb') as f:
        agent = pickle.load(f)

    # if we're updating the test fixture, then save the state sequence and resulting policy to disk.
    if update_fixture:  # pragma no cover
        with open(os.path.expanduser('~/Desktop/state_sequence.txt'), 'r') as f:
            state_sequence = f.readlines()
        with open(
                f'{os.path.dirname(__file__)}/fixtures/test_continuous_learn.pickle',
                'wb'
        ) as file:
            pickle.dump((state_sequence, agent.pi), file)
    else:
        assert agent.pi == fixture_pi
def test_help():
    """Asking for ``--help`` without a training function must raise ``ValueError`` saying no training function was given."""

    help_argv = shlex.split('--agent rlai.agents.mdp.StochasticMdpAgent --help')

    with pytest.raises(ValueError, match='No training function specified. Cannot train.'):
        run(help_argv)