Example #1
0
 def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
     """Stepping only touches idle workers and returns their updated steps.

     Three workers are created; worker 2 is flagged as waiting. The mocked
     step queue hands back one STEP response for worker 0, one for worker 1,
     and then raises ``EmptyQueue``. Workers 0 and 1 must have been sent a
     STEP command, and their ``previous_step`` must be what ``_step`` returns.
     """
     mock_create_worker.side_effect = create_worker_mock
     manager = SubprocessEnvManager(
         mock_env_factory, EngineConfig.default_config(), 3
     )
     queue_mock = Mock()
     manager.step_queue = queue_mock
     queue_mock.get_nowait.side_effect = [
         EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(0, None, {})),
         EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(1, None, {})),
         EmptyQueue(),
     ]
     step_mock = Mock()
     prior_steps = [Mock(), Mock(), Mock()]
     for worker, prior in zip(manager.env_workers, prior_steps):
         worker.previous_step = prior
     manager.env_workers[2].waiting = True
     manager._take_step = Mock(return_value=step_mock)

     res = manager._step()

     idle_workers = manager.env_workers[:2]
     for expected, worker in enumerate(idle_workers):
         worker.send.assert_called_with(EnvironmentCommand.STEP, step_mock)
         # previous_step must now carry the step result queued for this worker.
         self.assertEqual(
             worker.previous_step.current_all_step_result, expected
         )
     queue_mock.get_nowait.assert_called()
     assert res == [worker.previous_step for worker in idle_workers]
 def test_step_takes_steps_for_all_non_waiting_envs(self):
     """Stepping only touches idle workers and returns their updated steps.

     Workers 0 and 1 are idle and must receive a ``"step"`` command; worker 2
     is flagged as waiting. The mocked step queue yields one response for
     each of workers 0 and 1 and then raises ``EmptyQueue``.
     """
     from unittest.mock import patch

     def fake_create_worker(em, worker_id, step_queue, env_factory, engine_c):
         return MockEnvWorker(
             worker_id, EnvironmentResponse("step", worker_id, worker_id)
         )

     # Scope the fake with patch.object so the class attribute is restored
     # afterwards; a bare class-level assignment would leak the fake into
     # every test that runs later in the same process.
     with patch.object(
         SubprocessEnvManager, "create_worker", fake_create_worker
     ):
         manager = SubprocessEnvManager(
             mock_env_factory, EngineConfig.default_config(), 3
         )
     manager.step_queue = Mock()
     manager.step_queue.get_nowait.side_effect = [
         EnvironmentResponse("step", 0, StepResponse(0, None)),
         EnvironmentResponse("step", 1, StepResponse(1, None)),
         EmptyQueue(),
     ]
     step_mock = Mock()
     last_steps = [Mock(), Mock(), Mock()]
     for worker, last_step in zip(manager.env_workers, last_steps):
         worker.previous_step = last_step
     manager.env_workers[2].waiting = True
     manager._take_step = Mock(return_value=step_mock)

     res = manager.step()

     stepped_workers = manager.env_workers[:2]
     manager.step_queue.get_nowait.assert_called()
     for i, env in enumerate(stepped_workers):
         env.send.assert_called_with("step", step_mock)
         # previous_step must hold the step result reported for this worker.
         self.assertEqual(env.previous_step.current_all_brain_info, i)
     assert res == [env.previous_step for env in stepped_workers]
    def test_advance(self, external_brains_mock, step_mock):
        """advance() hands step results to the agent manager and picks up policies.

        ``external_brains_mock`` and ``step_mock`` are injected by decorators
        outside this view (presumably patching ``external_brains`` and
        ``_step`` — ``_step`` is asserted on below; confirm).
        """
        from unittest.mock import patch

        brain_name = "testbrain"
        action_info_dict = {brain_name: MagicMock()}

        def fake_create_worker(em, worker_id, step_queue, env_factory, engine_c):
            return MockEnvWorker(
                worker_id, EnvironmentResponse("step", worker_id, worker_id)
            )

        # Scope the fake with patch.object so the class attribute is restored
        # afterwards instead of leaking into later tests.
        with patch.object(
            SubprocessEnvManager, "create_worker", fake_create_worker
        ):
            env_manager = SubprocessEnvManager(
                mock_env_factory, EngineConfig.default_config(), 3
            )
        external_brains_mock.return_value = [brain_name]
        agent_manager_mock = mock.Mock()
        env_manager.set_agent_manager(brain_name, agent_manager_mock)

        step_info_dict = {brain_name: Mock()}
        step_info = EnvironmentStep(step_info_dict, 0, action_info_dict)
        step_mock.return_value = [step_info]
        env_manager.advance()

        # Test add_experiences
        env_manager._step.assert_called_once()

        agent_manager_mock.add_experiences.assert_called_once_with(
            step_info.current_all_step_result[brain_name],
            0,
            step_info.brain_name_to_action_info[brain_name],
        )

        # Test policy queue
        mock_policy = mock.Mock()
        agent_manager_mock.policy_queue.get_nowait.return_value = mock_policy
        env_manager.advance()
        assert env_manager.policies[brain_name] == mock_policy
        assert agent_manager_mock.policy == mock_policy
Example #4
0
 def create_worker_mock(worker_id, step_queue, env_factor, engine_c):
     """Build a MockEnvWorker whose canned reply is a RESET for ``worker_id``.

     The reply payload is ``{"key<worker_id>": worker_id}``. ``step_queue``,
     ``env_factor`` and ``engine_c`` are not used; they appear to exist only
     so this function can stand in for ``create_worker`` as a mock side effect.
     """
     reset_payload = {f"key{worker_id}": worker_id}
     canned_reply = EnvironmentResponse(
         EnvironmentCommand.RESET, worker_id, reset_payload
     )
     return MockEnvWorker(worker_id, canned_reply)
Example #5
0
 def test_reset_passes_reset_params(self):
     """reset() forwards the given reset parameters to the worker."""
     from unittest.mock import patch

     def fake_create_worker(em, worker_id, step_queue, env_factory, engine_c):
         return MockEnvWorker(
             worker_id, EnvironmentResponse("reset", worker_id, worker_id)
         )

     # Scope the fake with patch.object so the class attribute is restored
     # afterwards instead of leaking into later tests.
     with patch.object(
         SubprocessEnvManager, "create_worker", fake_create_worker
     ):
         manager = SubprocessEnvManager(
             mock_env_factory, EngineConfig.default_config(), 1
         )
     params = {"test": "params"}
     manager.reset(params)
     manager.env_workers[0].send.assert_called_with("reset", params)
Example #6
0
 def test_crashed_env_restarts(self, mock_create_worker):
     """A worker that reports ENV_EXITED is replaced by a freshly created one.

     ``create_worker`` hands out, in order, the crashing worker (id 0), the
     healthy worker (id 1) and the replacement for id 0. After one ``_step``
     both the healthy worker and the replacement must have been sent
     ENVIRONMENT_PARAMETERS, RESET and STEP commands, in that order.
     """

     def make_worker(worker_id):
         # Every worker answers with a RESET whose payload is its own id.
         return MockEnvWorker(
             worker_id,
             EnvironmentResponse(EnvironmentCommand.RESET, worker_id, worker_id),
         )

     crashing_worker = make_worker(0)
     restarting_worker = make_worker(0)
     healthy_worker = make_worker(1)
     mock_create_worker.side_effect = [
         crashing_worker,
         healthy_worker,
         restarting_worker,
     ]
     manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 2)

     # Successive get_nowait() results: ENV_EXITED then CLOSED for worker 0,
     # a STEP for worker 1, an empty poll, STEPs for workers 0 and 1, and a
     # final empty poll.
     polled = [
         EnvironmentResponse(
             EnvironmentCommand.ENV_EXITED,
             0,
             UnityCommunicationException("Test msg"),
         ),
         EnvironmentResponse(EnvironmentCommand.CLOSED, 0, None),
         EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(0, None, {})),
         EmptyQueue(),
         EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(1, None, {})),
         EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(2, None, {})),
         EmptyQueue(),
     ]
     manager.step_queue = Mock()
     manager.step_queue.get_nowait.side_effect = polled

     step_mock = Mock()
     assert crashing_worker is manager.env_workers[0]
     assert healthy_worker is manager.env_workers[1]
     for worker in (crashing_worker, healthy_worker):
         worker.previous_step = Mock()
         worker.waiting = True
     manager._take_step = Mock(return_value=step_mock)

     manager._step()

     expected_calls = [
         call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
         call(EnvironmentCommand.RESET, ANY),
         call(EnvironmentCommand.STEP, ANY),
     ]
     healthy_worker.send.assert_has_calls(expected_calls)
     restarting_worker.send.assert_has_calls(expected_calls)
Example #7
0
    def test_reset_collects_results_from_all_envs(self):
        """reset() sends the params to every worker and returns their steps."""
        from unittest.mock import patch

        def fake_create_worker(em, worker_id, step_queue, env_factory, engine_c):
            return MockEnvWorker(
                worker_id, EnvironmentResponse("reset", worker_id, worker_id)
            )

        # Scope the fake with patch.object so the class attribute is restored
        # afterwards instead of leaking into later tests.
        with patch.object(
            SubprocessEnvManager, "create_worker", fake_create_worker
        ):
            manager = SubprocessEnvManager(
                mock_env_factory, EngineConfig.default_config(), 4
            )

        params = {"test": "params"}
        res = manager.reset(params)
        for i, env in enumerate(manager.env_workers):
            env.send.assert_called_with("reset", params)
            env.recv.assert_called()
            # Each worker's previous_step should reflect its reset reply,
            # whose payload is the worker id.
            self.assertEqual(env.previous_step.current_all_brain_info, i)
        assert res == [env.previous_step for env in manager.env_workers]
 def _send_response(cmd_name, payload):
     """Send ``payload`` to the parent as an EnvironmentResponse for ``cmd_name``.

     ``parent_conn`` and ``worker_id`` are free names here, presumably
     captured from an enclosing worker function outside this view.
     """
     response = EnvironmentResponse(cmd_name, worker_id, payload)
     parent_conn.send(response)
def worker_aai(
    parent_conn: Connection,
    step_queue: Queue,
    pickled_env_factory: str,
    worker_id: int,
    engine_configuration: EngineConfig,
) -> None:
    """Run an AnimalAI environment in this process and serve parent commands.

    Commands are read from ``parent_conn`` until a ``"close"`` command is
    received. ``"step"`` results are pushed onto ``step_queue``; the
    ``"external_brains"``, ``"get_properties"`` and ``"reset"`` commands are
    answered directly over ``parent_conn``. Unrecognised command names are
    ignored.

    If the environment raises ``UnityCommunicationException`` or
    ``UnityTimeOutException`` (including while it is being created), or a
    ``KeyboardInterrupt`` arrives, an ``"env_close"`` response is put on
    ``step_queue``. The queue is always released, and the environment is
    always closed if it was created.

    Args:
        parent_conn: Connection end used to receive commands and send replies.
        step_queue: Queue on which step results and ``"env_close"`` are put.
        pickled_env_factory: cloudpickle-serialised callable
            ``(worker_id, side_channels) -> AnimalAIEnvironment``.
        worker_id: Identifier of this worker, echoed in every response.
        engine_configuration: Settings applied via an
            ``EngineConfigurationChannel`` passed to the environment factory.
    """
    env_factory: Callable[
        [int, List[SideChannel]], AnimalAIEnvironment
    ] = cloudpickle.loads(pickled_env_factory)
    shared_float_properties = FloatPropertiesChannel()
    engine_configuration_channel = EngineConfigurationChannel()
    engine_configuration_channel.set_configuration(engine_configuration)
    # The environment is created inside the try block below. A failure while
    # launching it then still notifies the parent and releases step_queue,
    # instead of escaping before any cleanup runs.
    env = None

    def _send_response(cmd_name, payload):
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    def _generate_all_results() -> AllStepResult:
        # One step result per agent group the environment currently reports.
        return {
            brain_name: env.get_step_result(brain_name)
            for brain_name in env.get_agent_groups()
        }

    def external_brains():
        # Brain parameters derived from each agent group's spec.
        return {
            brain_name: group_spec_to_brain_parameters(
                brain_name, env.get_agent_group_spec(brain_name)
            )
            for brain_name in env.get_agent_groups()
        }

    try:
        env = env_factory(
            worker_id, [shared_float_properties, engine_configuration_channel]
        )
        while True:
            cmd: EnvironmentCommand = parent_conn.recv()
            if cmd.name == "step":
                all_action_info = cmd.payload
                for brain_name, action_info in all_action_info.items():
                    # len() rather than truthiness, presumably because the
                    # action may be an array (confirm against callers).
                    if len(action_info.action) != 0:
                        env.set_actions(brain_name, action_info.action)
                env.step()
                all_step_result = _generate_all_results()
                # The timers in this process are independent from all the processes and the "main" process
                # So after we send back the root timer, we can safely clear them.
                # Note that we could randomly return timers a fraction of the time if we wanted to reduce
                # the data transferred.
                # TODO get gauges from the workers and merge them in the main process too.
                step_response = StepResponse(all_step_result, get_timer_root())
                step_queue.put(EnvironmentResponse("step", worker_id, step_response))
                reset_timers()
            elif cmd.name == "external_brains":
                _send_response("external_brains", external_brains())
            elif cmd.name == "get_properties":
                reset_params = shared_float_properties.get_property_dict_copy()
                _send_response("get_properties", reset_params)
            elif cmd.name == "reset":
                env.reset(arenas_configurations=cmd.payload)
                all_step_result = _generate_all_results()
                _send_response("reset", all_step_result)
            elif cmd.name == "close":
                break
    except (KeyboardInterrupt, UnityCommunicationException, UnityTimeOutException):
        logger.info(f"UnityEnvironment worker {worker_id}: environment stopping.")
        step_queue.put(EnvironmentResponse("env_close", worker_id, None))
    finally:
        # If this worker has put an item in the step queue that hasn't been processed by the EnvManager, the process
        # will hang until the item is processed. We avoid this behavior by using Queue.cancel_join_thread()
        # See https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue.cancel_join_thread for
        # more info.
        logger.debug(f"UnityEnvironment worker {worker_id} closing.")
        step_queue.cancel_join_thread()
        step_queue.close()
        # env stays None when the factory itself failed.
        if env is not None:
            env.close()
        logger.debug(f"UnityEnvironment worker {worker_id} done.")
def create_worker_mock(worker_id, step_queue, env_factor, engine_c):
    """Return a MockEnvWorker whose canned reply is a ``"reset"`` response.

    The reply payload is ``worker_id`` itself. ``step_queue``, ``env_factor``
    and ``engine_c`` are accepted but not used.
    """
    reset_reply = EnvironmentResponse("reset", worker_id, worker_id)
    return MockEnvWorker(worker_id, reset_reply)