def test_add_get_policy(ppo_optimizer, dummy_config):
    """Check that PPOTrainer stores an added policy and adopts its step count.

    The test adds a mocked ``NNPolicy`` reporting a current step of 2000 and
    verifies that the trainer returns the same policy, reports the same step,
    and schedules its next summary after that step. It then verifies that a
    policy mock created without the ``NNPolicy`` spec is rejected.

    Args:
        ppo_optimizer: Object whose ``return_value`` is replaced with an
            optimizer stub (presumably a patched optimizer class injected by
            a decorator outside this block -- confirm).
        dummy_config: Mutable trainer configuration mapping; its
            ``summary_path`` and ``model_path`` entries are overwritten here.
    """
    brain_params = make_brain_parameters(
        discrete_action=False, visual_inputs=0, vec_obs_size=6
    )

    # Stub optimizer with an empty reward-signal mapping.
    mock_optimizer = mock.Mock()
    mock_optimizer.reward_signals = {}
    ppo_optimizer.return_value = mock_optimizer

    dummy_config["summary_path"] = "./summaries/test_trainer_summary"
    dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
    trainer = PPOTrainer(brain_params, 0, dummy_config, True, False, 0, "0")

    policy = mock.Mock(spec=NNPolicy)
    policy.get_current_step.return_value = 2000

    trainer.add_policy(brain_params.brain_name, policy)
    assert trainer.get_policy(brain_params.brain_name) == policy

    # Make sure the summary steps were loaded properly
    assert trainer.get_step == 2000
    assert trainer.next_summary_step > 2000

    # Test incorrect class of policy. The behavior name is passed exactly as
    # in the successful call above, so the only difference is the policy type.
    policy = mock.Mock()
    with pytest.raises(RuntimeError):
        trainer.add_policy(brain_params.brain_name, policy)
def test_add_get_policy(ppo_optimizer, mock_create_model_saver, dummy_config):
    """Verify that a policy added to PPOTrainer can be fetched back by name.

    A mocked ``TFPolicy`` reporting a current step of 2000 is registered
    under the behavior identifiers derived from the trainer's brain name.
    The trainer must return that policy for ``"test_policy"`` and report the
    same step.

    Args:
        ppo_optimizer: Object whose ``return_value`` is replaced with an
            optimizer stub (presumably a patched optimizer class -- confirm).
        mock_create_model_saver: Not referenced in the body (presumably
            injected by a patch decorator outside this block -- confirm).
        dummy_config: Trainer configuration passed straight to ``PPOTrainer``.
    """
    behavior_name = "test_policy"
    loaded_step = 2000

    # Optimizer stub exposing an empty reward-signal mapping.
    ppo_optimizer.return_value = mock.Mock(reward_signals={})

    trainer = PPOTrainer(behavior_name, 0, dummy_config, True, False, 0, "0")

    fake_policy = mock.Mock(spec=TFPolicy)
    fake_policy.get_current_step.return_value = loaded_step

    parsed_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
    trainer.add_policy(parsed_id, fake_policy)

    assert trainer.get_policy(behavior_name) == fake_policy

    # The trainer's step counter must match the step reported by the policy.
    assert trainer.get_step == loaded_step
def test_add_get_policy(ppo_optimizer, dummy_config):
    """Verify PPOTrainer's add/get policy round trip and its type check.

    A mocked ``NNPolicy`` reporting a current step of 2000 is added under
    ``"test_policy"``. The trainer must return it and report the same step.
    Adding a mock created without the ``NNPolicy`` spec must raise
    ``RuntimeError``.

    Args:
        ppo_optimizer: Object whose ``return_value`` is replaced with an
            optimizer stub (presumably a patched optimizer class -- confirm).
        dummy_config: Trainer configuration passed straight to ``PPOTrainer``.
    """
    behavior_name = "test_policy"
    loaded_step = 2000

    # Optimizer stub exposing an empty reward-signal mapping.
    ppo_optimizer.return_value = mock.Mock(reward_signals={})

    trainer = PPOTrainer(behavior_name, 0, dummy_config, True, False, 0, "0")

    valid_policy = mock.Mock(spec=NNPolicy)
    valid_policy.get_current_step.return_value = loaded_step
    trainer.add_policy(behavior_name, valid_policy)

    assert trainer.get_policy(behavior_name) == valid_policy

    # The trainer's step counter must match the step reported by the policy.
    assert trainer.get_step == loaded_step

    # A mock created without the NNPolicy spec must be refused.
    unspecced_policy = mock.Mock()
    with pytest.raises(RuntimeError):
        trainer.add_policy(behavior_name, unspecced_policy)