예제 #1
0
 def test_crashed_env_restarts(self, mock_create_worker):
     """A worker that reports ENV_EXITED is replaced and stepping resumes.

     ``create_worker`` is patched to hand out, in call order, a worker that
     will crash (id 0), a healthy worker (id 1) and the replacement for the
     crashed worker (id 0). The step queue first reports that worker 0
     exited, followed by CLOSED/STEP responses, then STEP responses from
     both ids. After a single ``_step()`` the healthy worker and the
     replacement must each have been sent ENVIRONMENT_PARAMETERS, RESET and
     STEP as a consecutive sequence (``assert_has_calls`` semantics).
     """

     def reset_worker(worker_id):
         # Mock worker whose initial response is a RESET tagged with its id.
         return MockEnvWorker(
             worker_id,
             EnvironmentResponse(EnvironmentCommand.RESET, worker_id, worker_id),
         )

     crashing_worker = reset_worker(0)
     restarting_worker = reset_worker(0)
     healthy_worker = reset_worker(1)
     # Handed out in call order: slot 0, slot 1, then the slot-0 replacement.
     mock_create_worker.side_effect = [
         crashing_worker,
         healthy_worker,
         restarting_worker,
     ]

     env_manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 2)
     env_manager.step_queue = Mock()
     exited_response = EnvironmentResponse(
         EnvironmentCommand.ENV_EXITED,
         0,
         UnityCommunicationException("Test msg"),
     )
     # EmptyQueue instances are raised by get_nowait (Mock raises exception
     # instances found in side_effect), marking the queue as drained.
     env_manager.step_queue.get_nowait.side_effect = [
         exited_response,
         EnvironmentResponse(EnvironmentCommand.CLOSED, 0, None),
         EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(0, None, {})),
         EmptyQueue(),
         EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(1, None, {})),
         EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(2, None, {})),
         EmptyQueue(),
     ]

     action_mock = Mock()
     prior_steps = [Mock() for _ in range(3)]
     assert crashing_worker is env_manager.env_workers[0]
     assert healthy_worker is env_manager.env_workers[1]
     # Both initial workers start mid-step (waiting on a response).
     for worker, prior in zip((crashing_worker, healthy_worker), prior_steps):
         worker.previous_step = prior
         worker.waiting = True
     env_manager._take_step = Mock(return_value=action_mock)

     env_manager._step()

     expected_calls = [
         call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
         call(EnvironmentCommand.RESET, ANY),
         call(EnvironmentCommand.STEP, ANY),
     ]
     for worker in (healthy_worker, restarting_worker):
         worker.send.assert_has_calls(expected_calls)
예제 #2
0
 def test_step_takes_steps_for_all_non_waiting_envs(self,
                                                    mock_create_worker):
     """Check that ``_step`` steps exactly the workers that are not waiting.

     Workers 0 and 1 are idle and each gets a STEP response queued for it.
     Worker 2 is flagged as ``waiting`` and gets no response. After one
     ``_step()`` call:

     * the idle workers were sent STEP with the value returned by the
       mocked ``_take_step``, and their ``previous_step`` reflects the
       queued response;
     * the waiting worker was not sent STEP and kept its ``previous_step``;
     * the returned list holds the idle workers' new ``previous_step``.
     """
     mock_create_worker.side_effect = create_worker_mock
     manager = SubprocessEnvManager(mock_env_factory,
                                    EngineConfig.default_config(), 3)
     manager.step_queue = Mock()
     # get_nowait yields one STEP response each for workers 0 and 1, then
     # raises EmptyQueue (Mock raises exception instances listed in
     # side_effect). Worker 2 never receives a response.
     manager.step_queue.get_nowait.side_effect = [
         EnvironmentResponse(EnvironmentCommand.STEP, 0,
                             StepResponse(0, None, {})),
         EnvironmentResponse(EnvironmentCommand.STEP, 1,
                             StepResponse(1, None, {})),
         EmptyQueue(),
     ]
     step_mock = Mock()
     last_steps = [Mock(), Mock(), Mock()]
     manager.env_workers[0].previous_step = last_steps[0]
     manager.env_workers[1].previous_step = last_steps[1]
     manager.env_workers[2].previous_step = last_steps[2]
     manager.env_workers[2].waiting = True
     manager._take_step = Mock(return_value=step_mock)

     res = manager._step()

     # Loop-invariant: the queue must have been polled at all.
     manager.step_queue.get_nowait.assert_called()
     for worker_id in (0, 1):
         env = manager.env_workers[worker_id]
         env.send.assert_called_with(EnvironmentCommand.STEP, step_mock)
         # Check that the "last steps" are set to the value returned for
         # each step: the queued StepResponse's first argument was the id.
         self.assertEqual(env.previous_step.current_all_step_result,
                          worker_id)

     # The waiting worker must not have been stepped and must keep the
     # previous_step it had before _step() ran.
     waiting_env = manager.env_workers[2]
     step_calls = [
         args for args, _ in waiting_env.send.call_args_list
         if args and args[0] == EnvironmentCommand.STEP
     ]
     self.assertEqual(step_calls, [])
     self.assertIs(waiting_env.previous_step, last_steps[2])

     self.assertEqual(res, [
         manager.env_workers[0].previous_step,
         manager.env_workers[1].previous_step,
     ])