def test_prioritized_dqn_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    replay_memory = PrioritizedReplayMemory(
        50, 500, alpha=.6,
        beta=LinearParameter(.4, threshold_value=1, n=500 // 5))

    params = dict(batch_size=50, initial_replay_size=50,
                  max_replay_size=500, target_update_frequency=50,
                  replay_memory=replay_memory)
    agent_save = learn(DQN, params)

    agent_save.save(agent_path, full_save=True)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_copdac_q_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    agent_save = learn_copdac_q()

    agent_save.save(agent_path)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_PGPE_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    agent_save = learn(PGPE, optimizer=AdaptiveOptimizer(1.5))

    agent_save.save(agent_path)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_a2c_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    agent_save = learn_a2c()

    agent_save.save(agent_path, full_save=True)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_fqi_boosted_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    params = dict(n_iterations=10)
    agent_save, _ = learn(BoostedFQI, params)

    agent_save.save(agent_path)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_GPOMDP_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    params = dict(optimizer=AdaptiveOptimizer(eps=.01))
    agent_save = learn(GPOMDP, params)

    agent_save.save(agent_path)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_noisy_dqn_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    params = dict(batch_size=50, initial_replay_size=50,
                  max_replay_size=5000, target_update_frequency=50)
    agent_save = learn(NoisyDQN, params)

    agent_save.save(agent_path, full_save=True)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_q_learning_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    pi, mdp, _ = initialize()
    agent_save = QLearning(mdp.info, pi, Parameter(.5))

    core = Core(agent_save, mdp)

    # Train
    core.learn(n_steps=100, n_steps_per_fit=1, quiet=True)

    agent_save.save(agent_path)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_rainbow_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    params = dict(batch_size=50, initial_replay_size=50,
                  max_replay_size=500, target_update_frequency=50,
                  n_steps_return=1, alpha_coeff=.6,
                  beta=LinearParameter(.4, threshold_value=1, n=500 // 5))
    agent_save = learn(Rainbow, params)

    agent_save.save(agent_path, full_save=True)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_TRPO_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    params = dict(ent_coeff=0.0, max_kl=.001, lam=.98,
                  n_epochs_line_search=10, n_epochs_cg=10,
                  cg_damping=1e-2, cg_residual_tol=1e-10)
    agent_save = learn(TRPO, params)

    agent_save.save(agent_path)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_sarsa_lambda_continuous_linear_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    pi, _, mdp_continuous = initialize()
    mdp_continuous.seed(1)

    n_tilings = 1
    tilings = Tiles.generate(n_tilings, [2, 2],
                             mdp_continuous.info.observation_space.low,
                             mdp_continuous.info.observation_space.high)
    features = Features(tilings=tilings)

    approximator_params = dict(
        input_shape=(features.size,),
        output_shape=(mdp_continuous.info.action_space.n,),
        n_actions=mdp_continuous.info.action_space.n)
    agent_save = SARSALambdaContinuous(
        mdp_continuous.info, pi, LinearApproximator, Parameter(.1), .9,
        features=features, approximator_params=approximator_params)

    core = Core(agent_save, mdp_continuous)

    # Train
    core.learn(n_steps=100, n_steps_per_fit=1, quiet=True)

    agent_save.save(agent_path)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)

def test_PPO_save(tmpdir):
    agent_path = tmpdir / 'agent_{}'.format(
        datetime.now().strftime("%H%M%S%f"))

    params = dict(actor_optimizer={'class': optim.Adam,
                                   'params': {'lr': 3e-4}},
                  n_epochs_policy=4, batch_size=64, eps_ppo=.2, lam=.95)
    agent_save = learn(PPO, params)

    agent_save.save(agent_path)
    agent_load = Agent.load(agent_path)

    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)
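

# Every test above ends with the same attribute-comparison loop. A shared
# helper like the sketch below could replace the duplicated code; this is a
# suggested refactor, not part of the original suite (the name
# `assert_save_load_eq` is hypothetical), and it relies only on
# `tu.assert_eq` behaving exactly as in the loops above.
def assert_save_load_eq(agent_save, agent_load):
    # Compare every instance attribute of the saved agent against the
    # corresponding attribute of the reloaded agent.
    for att in vars(agent_save):
        save_attr = getattr(agent_save, att)
        load_attr = getattr(agent_load, att)

        tu.assert_eq(save_attr, load_attr)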