def test_step_takes_steps_for_all_non_waiting_envs(self):
    """Check that ``step()`` sends a step to every worker that is not waiting.

    Three mock workers are created; worker 2 is flagged as ``waiting`` before
    the step. The step queue is scripted to return a "step" response for
    workers 0 and 1 and then an ``EmptyQueue``. The test then verifies that:

    * workers 0 and 1 were each sent ``("step", <value from _take_step>)``,
    * their ``previous_step.current_all_brain_info`` matches the payload
      scripted for them on the queue (0 and 1 respectively),
    * ``step()`` returns the ``previous_step`` of workers 0 and 1, in order.
    """
    # Local import so this block does not depend on which names the file
    # imports from unittest.mock at module level.
    from unittest.mock import patch

    # Stub out worker creation for the duration of this test only. Assigning
    # the lambda directly on the class would leave the stub installed for
    # every test that runs afterwards; the patcher restores the original
    # attribute during cleanup, even if an assertion below fails.
    create_worker_patcher = patch.object(
        SubprocessEnvManager,
        "create_worker",
        new=lambda em, worker_id, step_queue, env_factory, engine_c: MockEnvWorker(
            worker_id, EnvironmentResponse("step", worker_id, worker_id)
        ),
    )
    create_worker_patcher.start()
    self.addCleanup(create_worker_patcher.stop)

    manager = SubprocessEnvManager(
        mock_env_factory, EngineConfig.default_config(), 3
    )

    # Script the queue: one response per active worker, then EmptyQueue
    # (presumably the signal that nothing more is pending -- confirm against
    # the manager implementation).
    manager.step_queue = Mock()
    manager.step_queue.get_nowait.side_effect = [
        EnvironmentResponse("step", 0, StepResponse(0, None)),
        EnvironmentResponse("step", 1, StepResponse(1, None)),
        EmptyQueue(),
    ]

    step_mock = Mock()
    last_steps = [Mock(), Mock(), Mock()]
    manager.env_workers[0].previous_step = last_steps[0]
    manager.env_workers[1].previous_step = last_steps[1]
    manager.env_workers[2].previous_step = last_steps[2]
    # Worker 2 is the one that must be skipped by step().
    manager.env_workers[2].waiting = True
    manager._take_step = Mock(return_value=step_mock)

    res = manager.step()

    for i, env in enumerate(manager.env_workers):
        if i < 2:
            env.send.assert_called_with("step", step_mock)
            manager.step_queue.get_nowait.assert_called()
            # Check that the "last steps" are set to the value returned for each step
            self.assertEqual(
                manager.env_workers[i].previous_step.current_all_brain_info, i
            )
    assert res == [
        manager.env_workers[0].previous_step,
        manager.env_workers[1].previous_step,
    ]
Ejemplo n.º 2
0
 def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
     """Check that ``_step()`` sends a step to every worker that is not waiting.

     ``mock_create_worker`` is the patched worker factory supplied to this
     test; it is redirected to ``create_worker_mock``. Three workers are
     built, worker 2 is flagged as ``waiting``, and the step queue is
     scripted with a STEP response for workers 0 and 1 followed by an
     ``EmptyQueue``. Workers 0 and 1 must have been sent the value produced
     by ``_take_step``, must carry the scripted payload in
     ``previous_step.current_all_step_result``, and their previous steps
     must be what ``_step()`` returns, in order.
     """
     mock_create_worker.side_effect = create_worker_mock
     env_manager = SubprocessEnvManager(
         mock_env_factory, EngineConfig.default_config(), 3
     )

     # Scripted queue contents: one STEP response per active worker (the
     # worker id doubles as the payload), then the EmptyQueue marker.
     env_manager.step_queue = Mock()
     scripted_responses = [
         EnvironmentResponse(
             EnvironmentCommand.STEP, worker_id, StepResponse(worker_id, None, {})
         )
         for worker_id in range(2)
     ]
     scripted_responses.append(EmptyQueue())
     env_manager.step_queue.get_nowait.side_effect = scripted_responses

     action_mock = Mock()
     # Seed every worker with a placeholder previous step.
     for worker, placeholder in zip(
         env_manager.env_workers, [Mock(), Mock(), Mock()]
     ):
         worker.previous_step = placeholder
     # Worker 2 is the one that must be skipped by _step().
     env_manager.env_workers[2].waiting = True
     env_manager._take_step = Mock(return_value=action_mock)

     stepped = env_manager._step()

     active_workers = [env_manager.env_workers[idx] for idx in range(2)]
     for idx, worker in enumerate(active_workers):
         worker.send.assert_called_with(EnvironmentCommand.STEP, action_mock)
         env_manager.step_queue.get_nowait.assert_called()
         # Check that the "last steps" are set to the value returned for each step
         self.assertEqual(worker.previous_step.current_all_step_result, idx)
     assert stepped == [worker.previous_step for worker in active_workers]
Ejemplo n.º 3
0
 def test_crashed_env_restarts(self, mock_create_worker):
     """Check that a worker is replaced and stepped after its env exits.

     The patched factory hands out, in order, a worker for env 0, a worker
     for env 1, and a second worker for env 0 (presumably requested when the
     manager restarts the crashed env -- confirm against the manager). The
     step queue is scripted so that env 0 reports ENV_EXITED with a
     ``UnityCommunicationException`` and then CLOSED, env 1 reports a STEP,
     and after an ``EmptyQueue`` both envs report a STEP. After ``_step()``
     both the healthy worker and the replacement worker must have received
     ENVIRONMENT_PARAMETERS, RESET and STEP commands.
     """

     def _reset_worker(worker_id):
         # Mock worker tagged with a RESET response that carries its own id.
         return MockEnvWorker(
             worker_id,
             EnvironmentResponse(EnvironmentCommand.RESET, worker_id, worker_id),
         )

     crashing_worker = _reset_worker(0)
     restarting_worker = _reset_worker(0)
     healthy_worker = _reset_worker(1)
     mock_create_worker.side_effect = [
         crashing_worker,
         healthy_worker,
         restarting_worker,
     ]

     manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 2)

     crash_report = EnvironmentResponse(
         EnvironmentCommand.ENV_EXITED, 0, UnityCommunicationException("Test msg")
     )
     closed_report = EnvironmentResponse(EnvironmentCommand.CLOSED, 0, None)

     def _step_report(worker_id, payload):
         # STEP response from ``worker_id`` carrying ``payload`` as its result.
         return EnvironmentResponse(
             EnvironmentCommand.STEP, worker_id, StepResponse(payload, None, {})
         )

     manager.step_queue = Mock()
     manager.step_queue.get_nowait.side_effect = [
         crash_report,
         closed_report,
         _step_report(1, 0),
         EmptyQueue(),
         _step_report(0, 1),
         _step_report(1, 2),
         EmptyQueue(),
     ]

     # The first two factory results must be the workers the manager holds.
     assert crashing_worker is manager.env_workers[0]
     assert healthy_worker is manager.env_workers[1]

     for worker in (crashing_worker, healthy_worker):
         worker.previous_step = Mock()
         worker.waiting = True
     manager._take_step = Mock(return_value=Mock())

     manager._step()

     expected_sends = [
         call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
         call(EnvironmentCommand.RESET, ANY),
         call(EnvironmentCommand.STEP, ANY),
     ]
     for worker in (healthy_worker, restarting_worker):
         worker.send.assert_has_calls(expected_sends)
def worker_aai(
    parent_conn: Connection,
    step_queue: Queue,
    pickled_env_factory: str,
    worker_id: int,
    engine_configuration: EngineConfig,
) -> None:
    """Run one environment and serve commands received on ``parent_conn``.

    The environment is built by unpickling ``pickled_env_factory`` and calling
    it with ``worker_id`` and two side channels (float properties and engine
    configuration, the latter initialised from ``engine_configuration``).

    Commands are then read in a loop and dispatched on their ``name``:

    * ``"step"``: apply the non-empty actions in the payload, step the env,
      and put the results plus this process's timer tree on ``step_queue``.
    * ``"external_brains"``: reply on ``parent_conn`` with brain parameters
      for every agent group.
    * ``"get_properties"``: reply with a copy of the float properties.
    * ``"reset"``: reset the env with the payload as arena configuration and
      reply with the resulting step results.
    * ``"close"``: leave the loop.

    Any other command name is ignored. On ``KeyboardInterrupt``,
    ``UnityCommunicationException`` or ``UnityTimeOutException`` an
    ``"env_close"`` response is put on ``step_queue``. In every case the
    queue and the environment are closed before returning.
    """
    env_factory: Callable[
        [int, List[SideChannel]], AnimalAIEnvironment
    ] = cloudpickle.loads(pickled_env_factory)
    float_channel = FloatPropertiesChannel()
    engine_channel = EngineConfigurationChannel()
    engine_channel.set_configuration(engine_configuration)
    env: AnimalAIEnvironment = env_factory(worker_id, [float_channel, engine_channel])

    def _reply(cmd_name, payload):
        # Answer the parent directly over the pipe, tagged with this worker's id.
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    def _collect_step_results() -> AllStepResult:
        # Latest step result for every agent group, keyed by group name.
        return {
            group: env.get_step_result(group) for group in env.get_agent_groups()
        }

    def _collect_brain_parameters():
        # Brain parameters derived from each agent group's spec, keyed by group name.
        return {
            group: group_spec_to_brain_parameters(
                group, env.get_agent_group_spec(group)
            )
            for group in env.get_agent_groups()
        }

    try:
        while True:
            request: EnvironmentCommand = parent_conn.recv()
            command = request.name
            if command == "close":
                break
            if command == "step":
                for group, action_info in request.payload.items():
                    # Groups with an empty action are left untouched.
                    if len(action_info.action) == 0:
                        continue
                    env.set_actions(group, action_info.action)
                env.step()
                results = _collect_step_results()
                # The timers in this process are independent from all the processes and the "main" process
                # So after we send back the root timer, we can safely clear them.
                # Note that we could randomly return timers a fraction of the time if we wanted to reduce
                # the data transferred.
                # TODO get gauges from the workers and merge them in the main process too.
                response = StepResponse(results, get_timer_root())
                step_queue.put(EnvironmentResponse("step", worker_id, response))
                reset_timers()
            elif command == "external_brains":
                _reply("external_brains", _collect_brain_parameters())
            elif command == "get_properties":
                _reply("get_properties", float_channel.get_property_dict_copy())
            elif command == "reset":
                env.reset(arenas_configurations=request.payload)
                _reply("reset", _collect_step_results())
    except (KeyboardInterrupt, UnityCommunicationException, UnityTimeOutException):
        logger.info(f"UnityEnvironment worker {worker_id}: environment stopping.")
        step_queue.put(EnvironmentResponse("env_close", worker_id, None))
    finally:
        # If this worker has put an item in the step queue that hasn't been processed by the EnvManager, the process
        # will hang until the item is processed. We avoid this behavior by using Queue.cancel_join_thread()
        # See https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue.cancel_join_thread for
        # more info.
        logger.debug(f"UnityEnvironment worker {worker_id} closing.")
        step_queue.cancel_join_thread()
        step_queue.close()
        env.close()
        logger.debug(f"UnityEnvironment worker {worker_id} done.")