예제 #1
0
    def test_train_agent_exist_force(self, mock_dp_plan, mock_dp_init,
                                     mock_factory_build, mock_train_operation,
                                     mock_dm):
        """Forced training still plans and runs when ``dm.get`` returns True.

        ``mock_dm.get`` is stubbed to return True (presumably meaning a model
        is already stored). ``train`` is then called with True as its second
        positional argument. The test checks that the run context is fetched,
        the planner is initialised and invoked, and the planned
        TrainAgentOperation is run for ``run_id``.
        """
        run_id = 1

        # Data-manager stubs: a stored run context plus a truthy ``get``.
        mock_dm.get.return_value = True
        mock_dm.get_latest.return_value = {
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
            "available_data": [("model", 0)]
        }

        # The planner yields one training op; the factory returns our mock op.
        mock_dp_plan.return_value = [
            OperationConfig(TrainAgentOperation.__name__, run_id)
        ]
        mock_factory_build.return_value = mock_train_operation
        mock_train_operation.run.return_value = None

        rl_app = MockRLApplication()
        engine_config = MockEngineConfig()
        engine = BaseEngine(rl_app, engine_config)
        engine._dm = mock_dm
        engine.train(run_id, True)

        mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
        mock_dp_init.assert_any_call(engine._get_available_operators(),
                                     engine_config, [("model", 0)])
        mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
        mock_train_operation.run.assert_any_call(run_id)
예제 #2
0
    def test_train_implicit_run_id(self, mock_dp_plan, mock_dp_init,
                                   mock_factory_build, mock_train_operation,
                                   mock_dm):
        """``train()`` without a run id ends up training run 1.

        The application's ``first_timestep_dt`` is moved one day into the
        past, presumably so the engine infers ``run_id == 1`` on its own. The
        test then checks that the model lookup, the run-context lookup,
        planning and the training operation all used that run id.
        """
        run_id = 1
        mock_train_operation.run.return_value = None
        mock_factory_build.return_value = mock_train_operation
        mock_dp_plan.return_value = [
            OperationConfig(TrainAgentOperation.__name__, run_id)
        ]

        run_context = {
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
            "available_data": [("model", 0)]
        }

        mock_dm.get_latest.return_value = run_context
        mock_dm.get.return_value = False

        # Restore the fixture's timestamp when the test finishes. Without this,
        # the one-day shift would leak into (and accumulate across) any other
        # test that shares self._application. If the fixture is rebuilt per
        # test, the cleanup is harmless.
        self.addCleanup(setattr, self._application, "first_timestep_dt",
                        self._application.first_timestep_dt)
        self._application.first_timestep_dt -= timedelta(days=1)
        engine = BaseEngine(self._application, mock_dm)
        engine.train()

        mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
        mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
        mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
        mock_train_operation.run.assert_any_call(run_id)
예제 #3
0
    def test_train_explicit_run_id(self, mock_dp_plan, mock_dp_init,
                                   mock_factory_build, mock_train_operation,
                                   mock_dm):
        """Training an explicitly supplied run id when ``dm.get`` is falsy.

        With ``mock_dm.get`` returning False (presumably "no model stored
        yet"), ``train(run_id)`` is expected to query the model and the run
        context, plan the work, and run the planned TrainAgentOperation for
        that id.
        """
        run_id = 1

        # Planner returns one training op; the factory hands back the mock op.
        mock_dp_plan.return_value = [
            OperationConfig(TrainAgentOperation.__name__, run_id)
        ]
        mock_factory_build.return_value = mock_train_operation
        mock_train_operation.run.return_value = None

        # Data manager: falsy model lookup, but a run context is available.
        mock_dm.get.return_value = False
        mock_dm.get_latest.return_value = {
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
            "available_data": [("model", 0)]
        }

        engine = BaseEngine(self._application, mock_dm)
        engine.train(run_id)

        mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
        mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
        mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
        mock_train_operation.run.assert_any_call(run_id)
예제 #4
0
    def test_train_agent_exist(self, mock_dm):
        """``train(run_id)`` must raise when ``dm.get`` reports a stored model.

        ``mock_dm.get`` is stubbed to return True. Calling ``train`` without
        the force argument is expected to raise.
        """
        run_id = 1
        mock_dm.get.return_value = True
        engine = BaseEngine(self._application, mock_dm)
        engine._dm = mock_dm

        # NOTE(review): assertRaises(Exception) accepts *any* error, including
        # ones unrelated to the existing-model check (flake8-bugbear B017).
        # Narrow it once the engine's specific exception type is known.
        with self.assertRaises(Exception):
            engine.train(run_id)

        # Guard against a false pass: the raise must come after the engine
        # actually looked up the model for this run.
        mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
예제 #5
0
    def test_train_agent_exist(self, mock_dm):
        """``train(run_id)`` must raise when ``dm.get`` reports a stored model.

        ``mock_dm.get`` is stubbed to return True and injected into an engine
        built from mock application/config objects. Calling ``train`` without
        the force argument is expected to raise.
        """
        run_id = 1
        mock_dm.get.return_value = True

        mock_rl_app = MockRLApplication()
        mock_engine_config = MockEngineConfig()
        engine = BaseEngine(mock_rl_app, mock_engine_config)
        engine._dm = mock_dm

        # NOTE(review): assertRaises(Exception) accepts *any* error, including
        # ones unrelated to the existing-model check (flake8-bugbear B017).
        # Narrow it once the engine's specific exception type is known.
        with self.assertRaises(Exception):
            engine.train(run_id)

        # Guard against a false pass: the raise must come after the engine
        # actually looked up the model for this run.
        mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
예제 #6
0
    def run(self):
        """Train over successive run ids, evaluating at a fixed interval.

        Builds a ``BaseEngine`` from ``self._application`` and ``self._dm``,
        initialises it with ``force_run=True``, and calls ``train`` once per
        run id. Whenever the run id is a multiple of ``self._eval_interval``,
        the value returned by ``self._evaluate_agent(run_id,
        self._num_eval_episodes)`` is collected. That value is presumably an
        average reward.

        Returns:
            list: the collected evaluation results, in the order produced.
        """
        engine = BaseEngine(self._application, self._dm)

        engine.init(force_run=True)

        logger.info("Training started")
        eval_avg_rwd = []
        # Run ids start at 1, and range() excludes self._num_runs, so this
        # performs self._num_runs - 1 training rounds.
        # NOTE(review): confirm the exclusive upper bound is intended.
        for run_id in range(1, self._num_runs):
            engine.train(run_id)

            if run_id % self._eval_interval == 0:
                avg_rwd = self._evaluate_agent(run_id, self._num_eval_episodes)
                eval_avg_rwd.append(avg_rwd)

        logger.info("Training is done")
        # Lazy %-style args: the list is only formatted if INFO is enabled;
        # the rendered message is identical to eager "%s" % str(...).
        logger.info("Eval result: %s", eval_avg_rwd)
        return eval_avg_rwd
예제 #7
0
    def test_train_implicit_run_id(self, mock_dp_plan, mock_dp_init,
                                   mock_factory_build, mock_train_operation,
                                   mock_dm):
        """``train()`` without a run id ends up training run 1.

        The config is set to start at 2019-11-02 00:00 with a one-day training
        interval, and "now" is pinned to 2019-11-03 10:00. Presumably this
        makes the engine infer ``run_id == 1``. The test checks that every
        collaborator was called with that run id.
        """
        from unittest import mock

        run_id = 1
        mock_train_operation.run.return_value = None
        mock_factory_build.return_value = mock_train_operation
        mock_dp_plan.return_value = [
            OperationConfig(TrainAgentOperation.__name__, run_id)
        ]

        run_context = {
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
            "available_data": [("model", 0)]
        }

        mock_dm.get_latest.return_value = run_context
        mock_dm.get.return_value = False

        mock_engine_config = MockEngineConfig()
        mock_engine_config.start_dt = datetime(2019, 11, 2, 0, 0, 0)
        mock_engine_config.training_interval = timedelta(days=1)

        # Pin "now" for this test only. Assigning to the class attribute
        # directly would leak the fake clock into every later test that uses
        # MockEngineConfig; patch.object restores the original on exit.
        # create=True keeps the plain-assignment semantics if the attribute
        # does not already exist.
        with mock.patch.object(MockEngineConfig, "_get_current_datetime",
                               lambda _: datetime(2019, 11, 3, 10, 0, 0),
                               create=True):
            mock_rl_app = MockRLApplication()
            engine = BaseEngine(mock_rl_app, mock_engine_config)
            engine._dm = mock_dm
            engine.train()

            mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
            mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
            mock_dp_init.assert_any_call(engine._get_available_operators(),
                                         mock_engine_config, [("model", 0)])
            mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
            mock_train_operation.run.assert_any_call(run_id)