Ejemplo n.º 1
0
 def _generate_all_brain_info(self) -> AllBrainInfo:
     """Collect a BrainInfo for each agent group of ``self.env``.

     Returns:
         Dict keyed by the names from ``self.env.get_agent_groups()``; each
         value is ``step_result_to_brain_info`` applied to that group's
         ``get_step_result`` and ``get_agent_group_spec``.
     """
     env = self.env
     return {
         group_name: step_result_to_brain_info(
             env.get_step_result(group_name),
             env.get_agent_group_spec(group_name),
         )
         for group_name in env.get_agent_groups()
     }
Ejemplo n.º 2
0
 def _generate_all_brain_info() -> AllBrainInfo:
     """Collect a BrainInfo for each agent group of ``env``.

     ``env`` and ``worker_id`` are not defined in this function; presumably
     they come from the enclosing scope (not visible here). ``worker_id`` is
     passed through as the third argument of ``step_result_to_brain_info``.

     Returns:
         Dict keyed by the names from ``env.get_agent_groups()``.
     """
     return {
         group_name: step_result_to_brain_info(
             env.get_step_result(group_name),
             env.get_agent_group_spec(group_name),
             worker_id,
         )
         for group_name in env.get_agent_groups()
     }
Ejemplo n.º 3
0
def test_ppo_policy_evaluate(mock_communicator, mock_launcher, dummy_config):
    """PPOPolicy.evaluate on a mocked continuous-action env yields a (3, 2) action.

    ``mock_communicator`` and ``mock_launcher`` are presumably injected by
    ``mock.patch`` decorators that are not visible here. ``mock_launcher`` is
    unused in the body but must stay in the signature.
    """
    tf.reset_default_graph()
    mock_communicator.return_value = MockCommunicator(discrete_action=False,
                                                      visual_inputs=0)
    env = UnityEnvironment(" ")
    # Always close the environment, even when an assertion or call below
    # fails, so a failing test does not leave it open for later tests.
    try:
        env.reset()
        brain_name = env.get_agent_groups()[0]
        brain_info = step_result_to_brain_info(
            env.get_step_result(brain_name),
            env.get_agent_group_spec(brain_name))
        brain_params = group_spec_to_brain_parameters(
            brain_name, env.get_agent_group_spec(brain_name))

        trainer_parameters = dummy_config
        trainer_parameters["model_path"] = brain_name
        trainer_parameters["keep_checkpoints"] = 3
        policy = PPOPolicy(0, brain_params, trainer_parameters, False, False)
        run_out = policy.evaluate(brain_info)
        assert run_out["action"].shape == (3, 2)
    finally:
        env.close()
Ejemplo n.º 4
0
def test_ppo_get_value_estimates(mock_communicator, mock_launcher,
                                 dummy_config):
    """PPOPolicy.get_value_estimates behaves correctly for terminal/non-terminal steps.

    Against a mocked continuous-action environment it checks three cases:
    a non-terminal step gives str keys with exact ``float`` values; a
    terminal step gives 0.0 for every entry; and after disabling
    ``use_terminal_states`` on the "extrinsic" reward signal, a terminal
    step gives non-zero values.

    ``mock_communicator`` and ``mock_launcher`` are presumably injected by
    ``mock.patch`` decorators that are not visible here.
    """
    tf.reset_default_graph()
    mock_communicator.return_value = MockCommunicator(discrete_action=False,
                                                      visual_inputs=0)
    env = UnityEnvironment(" ")
    # Always close the environment, even when an assertion or call below
    # fails, so a failing test does not leave it open for later tests.
    try:
        env.reset()
        brain_name = env.get_agent_groups()[0]
        brain_info = step_result_to_brain_info(
            env.get_step_result(brain_name),
            env.get_agent_group_spec(brain_name))
        brain_params = group_spec_to_brain_parameters(
            brain_name, env.get_agent_group_spec(brain_name))

        trainer_parameters = dummy_config
        trainer_parameters["model_path"] = brain_name
        trainer_parameters["keep_checkpoints"] = 3
        policy = PPOPolicy(0, brain_params, trainer_parameters, False, False)

        # Non-terminal step: expect str keys and exact float values.
        run_out = policy.get_value_estimates(brain_info, 0, done=False)
        # An empty result would make the per-item checks pass vacuously.
        assert run_out
        for key, val in run_out.items():
            assert type(key) is str
            assert type(val) is float

        # Terminal step: expect every estimate to be zeroed.
        run_out = policy.get_value_estimates(brain_info, 0, done=True)
        assert run_out
        for key, val in run_out.items():
            assert type(key) is str
            assert val == 0.0

        # Check if we ignore terminal states properly
        policy.reward_signals["extrinsic"].use_terminal_states = False
        run_out = policy.get_value_estimates(brain_info, 0, done=True)
        assert run_out
        for key, val in run_out.items():
            assert type(key) is str
            assert val != 0.0
    finally:
        env.close()