def test_train_agent_exist_force(self, mock_dp_plan, mock_dp_init,
                                 mock_factory_build, mock_train_operation,
                                 mock_dm):
    """Verify ``train(run_id, True)`` runs the train operation when ``dm.get`` is truthy.

    The data manager mock answers ``get`` with ``True`` and ``get_latest``
    with a run context; the planner mock yields a single train-agent
    operation and the factory mock hands back ``mock_train_operation``.
    The second positional argument to ``train`` is presumably the "force"
    flag (per the test name) -- confirm against ``BaseEngine.train``.
    """
    run_id = 1

    # Arrange: stub the planner, factory and operation collaborators.
    planned_ops = [OperationConfig(TrainAgentOperation.__name__, run_id)]
    mock_dp_plan.return_value = planned_ops
    mock_factory_build.return_value = mock_train_operation
    mock_train_operation.run.return_value = None

    # Arrange: stub the data manager responses.
    run_context = {
        TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
        "available_data": [("model", 0)],
    }
    mock_dm.get.return_value = True
    mock_dm.get_latest.return_value = run_context

    # Arrange: build the engine from mock app/config and inject the mock dm.
    app = MockRLApplication()
    mock_engine_config = MockEngineConfig()
    engine = BaseEngine(app, mock_engine_config)
    engine._dm = mock_dm

    # Act.
    engine.train(run_id, True)

    # Assert: context lookup, planner construction, planning and execution.
    mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
    mock_dp_init.assert_any_call(
        engine._get_available_operators(),
        mock_engine_config,
        [("model", 0)],
    )
    mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
    mock_train_operation.run.assert_any_call(run_id)
def test_train_implicit_run_id(self, mock_dp_plan, mock_dp_init,
                               mock_factory_build, mock_train_operation,
                               mock_dm):
    """Verify ``train()`` without a run id ends up training run ``1``.

    The application's ``first_timestep_dt`` is moved one day earlier before
    the engine is built; the assertions then expect every lookup and the
    train operation to be issued for run id ``1``.
    ``mock_dp_init`` is accepted but not inspected here.
    """
    run_id = 1

    # Arrange: data manager reports no model and supplies a run context.
    mock_dm.get.return_value = False
    mock_dm.get_latest.return_value = {
        TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
        "available_data": [("model", 0)],
    }

    # Arrange: planner yields one train-agent op, factory returns the mock op.
    planned_ops = [OperationConfig(TrainAgentOperation.__name__, run_id)]
    mock_dp_plan.return_value = planned_ops
    mock_factory_build.return_value = mock_train_operation
    mock_train_operation.run.return_value = None

    # Shift the application's first timestep back by one day.
    one_day = timedelta(days=1)
    self._application.first_timestep_dt -= one_day

    # Act.
    engine = BaseEngine(self._application, mock_dm)
    engine.train()

    # Assert: everything was requested for run id 1.
    mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
    mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
    mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
    mock_train_operation.run.assert_any_call(run_id)
def test_train_explicit_run_id(self, mock_dp_plan, mock_dp_init,
                               mock_factory_build, mock_train_operation,
                               mock_dm):
    """Verify ``train(run_id)`` drives lookups, planning and execution for that id.

    ``mock_dm.get`` answers ``False`` and ``get_latest`` returns a run
    context; the planner mock yields a single train-agent operation.
    ``mock_dp_init`` is accepted but not inspected here.
    """
    run_id = 1

    # Arrange: planner / factory / operation stubs.
    mock_train_operation.run.return_value = None
    mock_factory_build.return_value = mock_train_operation
    mock_dp_plan.return_value = [
        OperationConfig(TrainAgentOperation.__name__, run_id),
    ]

    # Arrange: data manager stubs.
    run_context = {
        TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
        "available_data": [("model", 0)],
    }
    mock_dm.get_latest.return_value = run_context
    mock_dm.get.return_value = False

    # Act.
    engine = BaseEngine(self._application, mock_dm)
    engine.train(run_id)

    # Assert.
    mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
    mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
    mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
    mock_train_operation.run.assert_any_call(run_id)
def test_train_agent_exist(self, mock_dm):
    """Verify ``train(run_id)`` raises when ``mock_dm.get`` returns ``True``.

    The engine is built from the shared application fixture and the mock
    data manager, which is additionally assigned to ``engine._dm``.
    """
    mock_dm.get.return_value = True

    engine = BaseEngine(self._application, mock_dm)
    engine._dm = mock_dm

    existing_run_id = 1
    # Any exception type satisfies this check.
    self.assertRaises(Exception, engine.train, existing_run_id)
def test_train_agent_exist(self, mock_dm):
    """Verify ``train(run_id)`` raises when ``mock_dm.get`` returns ``True``.

    The engine is built from a mock application and mock engine config,
    and the mock data manager is injected through ``engine._dm``.
    """
    mock_dm.get.return_value = True

    app = MockRLApplication()
    config = MockEngineConfig()
    engine = BaseEngine(app, config)
    engine._dm = mock_dm

    existing_run_id = 1
    # Any exception type satisfies this check.
    self.assertRaises(Exception, engine.train, existing_run_id)
def run(self):
    """Train over a sequence of run ids, evaluating the agent periodically.

    A ``BaseEngine`` is created from ``self._application`` and ``self._dm``
    and initialised with ``force_run=True``. Each run id is then trained in
    turn; whenever the run id is a multiple of ``self._eval_interval`` the
    agent is evaluated for ``self._num_eval_episodes`` episodes.

    Returns:
        list: the values returned by ``self._evaluate_agent`` in the order
        the evaluations happened (empty if no run id hit the interval).
    """
    engine = BaseEngine(self._application, self._dm)
    engine.init(force_run=True)

    logger.info("Training started")
    eval_avg_rwd = []
    # NOTE(review): range(1, n) yields run ids 1..n-1, i.e. n - 1 training
    # runs. Confirm whether ``_num_runs`` is meant to be inclusive.
    for run_id in range(1, self._num_runs):
        engine.train(run_id)
        if run_id % self._eval_interval == 0:
            avg_rwd = self._evaluate_agent(run_id, self._num_eval_episodes)
            eval_avg_rwd.append(avg_rwd)

    logger.info("Training is done")
    # Pass the value as a logging argument so formatting is deferred until
    # the record is actually emitted; the rendered message is unchanged.
    logger.info("Eval result: %s", eval_avg_rwd)
    return eval_avg_rwd
def test_train_implicit_run_id(self, mock_dp_plan, mock_dp_init,
                               mock_factory_build, mock_train_operation,
                               mock_dm):
    """Verify ``train()`` without a run id ends up training run ``1``.

    The engine config starts on 2019-11-02 with a one-day training
    interval, and ``MockEngineConfig._get_current_datetime`` is overridden
    to report 2019-11-03 10:00. The assertions expect all lookups, the
    planner construction and the train operation to use run id ``1``.
    """
    # Local import keeps this block self-contained regardless of how the
    # module imports its mocking helpers.
    from unittest.mock import patch

    run_id = 1
    mock_train_operation.run.return_value = None
    mock_factory_build.return_value = mock_train_operation
    mock_dp_plan.return_value = [
        OperationConfig(TrainAgentOperation.__name__, run_id)
    ]
    run_context = {
        TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
        "available_data": [("model", 0)],
    }
    mock_dm.get_latest.return_value = run_context
    mock_dm.get.return_value = False

    mock_engine_config = MockEngineConfig()
    mock_engine_config.start_dt = datetime(2019, 11, 2, 0, 0, 0)
    mock_engine_config.training_interval = timedelta(days=1)

    # Override the clock on the class only for the duration of this test.
    # A bare class-attribute assignment would stay in place afterwards and
    # leak the fake "now" into every later test using MockEngineConfig.
    # ``create=True`` keeps the override working even if the attribute is
    # not defined on the class beforehand.
    with patch.object(MockEngineConfig, "_get_current_datetime",
                      lambda _: datetime(2019, 11, 3, 10, 0, 0),
                      create=True):
        mock_rl_app = MockRLApplication()
        engine = BaseEngine(mock_rl_app, mock_engine_config)
        engine._dm = mock_dm

        engine.train()

        mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
        mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
        mock_dp_init.assert_any_call(engine._get_available_operators(),
                                     mock_engine_config, [("model", 0)])
        mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
        mock_train_operation.run.assert_any_call(run_id)