示例#1
0
    def test_init_success(self, mock_dm):
        """init() at run 0 stores the initial model and the updated run context.

        The clock is pinned to 2019-11-02 10:00, ten hours after ``start_dt``
        with a one-day training interval, so the current run id is 0. The
        test expects the engine to read the run context for run 0, then store
        ``mock_rl_app.test_agent`` as ``DATANAME.MODEL`` for run 0, followed
        by the run context with ``(DATANAME.MODEL, 0)`` in ``available_data``.
        """
        from unittest import mock

        mock_rl_app = MockRLApplication()
        mock_engine_config = MockEngineConfig()
        mock_engine_config.start_dt = datetime(2019, 11, 2, 0, 0, 0)
        mock_engine_config.training_interval = timedelta(days=1)
        # Override the clock through a patcher that is undone in cleanup.
        # Assigning the class attribute directly would leak the fake clock
        # into every later test that uses MockEngineConfig.
        clock_patch = mock.patch.object(
            MockEngineConfig, "_get_current_datetime",
            lambda _: datetime(2019, 11, 2, 10, 0, 0), create=True)
        clock_patch.start()
        self.addCleanup(clock_patch.stop)

        run_context = {
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
            "available_data": []
        }
        mock_dm.get.return_value = run_context

        engine = BaseEngine(mock_rl_app, mock_engine_config)
        engine._dm = mock_dm
        engine.init()

        expected_run_context = {
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
            "available_data": [(DATANAME.MODEL, 0)]
        }
        mock_dm.get.assert_any_call(DATANAME.RUN_CONTEXT, 0)

        store_calls = [
            call(mock_rl_app.test_agent, DATANAME.MODEL, 0),
            call(expected_run_context, DATANAME.RUN_CONTEXT, 0)
        ]
        # Order matters: the model is stored before the run context.
        mock_dm.store.assert_has_calls(store_calls, any_order=False)
示例#2
0
    def test_get_run_id_1(self):
        """get_current_run_id() is 0 while the clock is inside the first interval.

        ``start_dt`` is 2019-11-02 00:00 with a one-day interval and the clock
        is pinned to 2019-11-02 10:00, i.e. before the first interval ends.
        """
        from unittest import mock

        mock_engine_config = MockEngineConfig()
        mock_engine_config.training_interval = timedelta(days=1)
        mock_engine_config.start_dt = datetime(2019, 11, 2, 0, 0, 0)
        # Patch the clock via a patcher undone in cleanup so the override does
        # not leak into other tests that use MockEngineConfig.
        clock_patch = mock.patch.object(
            MockEngineConfig, "_get_current_datetime",
            lambda _: datetime(2019, 11, 2, 10, 0, 0), create=True)
        clock_patch.start()
        self.addCleanup(clock_patch.stop)

        res = mock_engine_config.get_current_run_id()

        expected_res = 0
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(res, expected_res)
    def test_timestep_operation(self, mock_abstract_app, mock_data_manager):
        """BuildTimestepOperation builds timesteps for a run and records them.

        The test expects the operation to:

        * read the latest run context for ``run_id``;
        * call ``build_time_steps(start_dt, start_dt + training_interval)``;
        * store the resulting DataFrame as ``DATANAME.TIMESTEP``, then the run
          context with ``(DATANAME.TIMESTEP, run_id)`` appended to
          ``available_data``, in that order.
        """
        start = datetime.now()
        interval = timedelta(days=1)

        engine_config = MockEngineConfig()
        engine_config.start_dt = start
        engine_config.training_interval = interval

        timestep_row = dict(
            env_id_1=1,
            env_id_2=2,
            ts_1=1,
            obs_1=1,
            obs_2=2,
            action=1,
            reward=0.0,
            step_type=0,
        )
        timestep_df = self.spark.createDataFrame([timestep_row])

        mock_abstract_app.build_time_steps.return_value = timestep_df
        mock_data_manager.get_latest.return_value = {
            "available_data": [("test_data", 0)]
        }

        run_id = 1
        BuildTimestepOperation(mock_abstract_app, engine_config,
                               mock_data_manager).run(run_id)

        mock_data_manager.get_latest.assert_any_call(
            DATANAME.RUN_CONTEXT, run_id)
        mock_abstract_app.build_time_steps.assert_called_with(
            start, start + interval)

        # Built as a separate literal so in-place mutation of the stored
        # context cannot make this comparison trivially true.
        updated_context = {
            "available_data": [("test_data", 0), (DATANAME.TIMESTEP, run_id)]
        }
        mock_data_manager.store.assert_has_calls(
            [
                call(timestep_df, DATANAME.TIMESTEP, run_id),
                call(updated_context, DATANAME.RUN_CONTEXT, run_id),
            ],
            any_order=False,
        )
示例#4
0
    def test_train_agent_exist_force(self, mock_dp_plan, mock_dp_init,
                                     mock_factory_build, mock_train_operation,
                                     mock_dm):
        """train(run_id, True) plans and runs training despite an existing model.

        ``mock_dm.get`` returns True (an existing model, per the test name),
        and ``train`` is called with ``True`` as its second argument. The test
        expects the run context to be read, the planner to be built and asked
        for ``DATANAME.MODEL``, and the planned train operation to run.
        """
        target_run = 1

        mock_dm.get_latest.return_value = {
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
            "available_data": [("model", 0)]
        }
        mock_dm.get.return_value = True

        mock_train_operation.run.return_value = None
        mock_factory_build.return_value = mock_train_operation
        mock_dp_plan.return_value = [
            OperationConfig(TrainAgentOperation.__name__, target_run)
        ]

        rl_app = MockRLApplication()
        engine_config = MockEngineConfig()
        engine = BaseEngine(rl_app, engine_config)
        engine._dm = mock_dm
        engine.train(target_run, True)

        mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, target_run)
        # Expected available_data is a fresh literal, independent of the dict
        # handed to get_latest above.
        mock_dp_init.assert_any_call(engine._get_available_operators(),
                                     engine_config, [("model", 0)])
        mock_dp_plan.assert_any_call(DATANAME.MODEL, target_run)
        mock_train_operation.run.assert_any_call(target_run)
    def test_error(self):
        """plan() raises for Operation2's output when no data is available.

        The planner is built over Operation1 and Operation2 with an empty
        ``available_data`` list. Planning run 2 of Operation2's output is
        expected to raise.
        """
        planner = DependencyPlanner(
            {"op_1": Operation1, "op_2": Operation2},
            MockEngineConfig(),
            [],
        )

        with self.assertRaises(Exception):
            planner.plan(Operation2.output_dataname(), 2)
    def test_no_dependencies(self):
        """Planning Operation1's output with no data yields only Operation1.

        With an empty ``available_data`` list, planning run 2 of Operation1's
        output is expected to produce a single ``OperationConfig`` for
        Operation1 at run 2 (presumably Operation1 has no upstream inputs).
        """
        operation_list = {"op_1": Operation1, "op_2": Operation2}
        mock_rl_engine = MockEngineConfig()
        available_data = []
        planner = DependencyPlanner(operation_list, mock_rl_engine,
                                    available_data)

        actual_plan = planner.plan(Operation1.output_dataname(), 2)
        expected_plan = [OperationConfig(Operation1.__name__, 2)]
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(expected_plan, actual_plan)
示例#7
0
    def test_train_agent_exist(self, mock_dm):
        """train(run_id) raises when the data manager reports an existing model.

        ``mock_dm.get`` returns True for every lookup, and ``train`` is called
        with only a run id, so an exception is expected.
        """
        mock_dm.get.return_value = True
        target_run = 1

        engine = BaseEngine(MockRLApplication(), MockEngineConfig())
        engine._dm = mock_dm

        with self.assertRaises(Exception):
            engine.train(target_run)
    def test_back_filling_multiple_runs_and_skip(self):
        """plan() back-fills missing runs and skips outputs that already exist.

        Available data holds Operation2's output for run 0 and Operation1's
        output for run 1. Planning Operation2's output for run 2 is expected
        to schedule Operation1@2, Operation2@1 and Operation2@2, in that
        order. Operation1@1 is skipped because its output is already present.
        """
        operation_list = {"op_1": Operation1, "op_2": Operation2}
        mock_rl_engine = MockEngineConfig()
        available_data = [(Operation2.output_dataname(), 0),
                          (Operation1.output_dataname(), 1)]
        planner = DependencyPlanner(operation_list, mock_rl_engine,
                                    available_data)

        actual_plan = planner.plan(Operation2.output_dataname(), 2)
        expected_plan = [
            OperationConfig(Operation1.__name__, 2),
            OperationConfig(Operation2.__name__, 1),
            OperationConfig(Operation2.__name__, 2),
        ]
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(expected_plan, actual_plan)
示例#9
0
    def test_train_implicit_run_id(self, mock_dp_plan, mock_dp_init,
                                   mock_factory_build, mock_train_operation,
                                   mock_dm):
        """train() without a run id works on the run derived from the clock.

        With ``start_dt`` at 2019-11-02 00:00, a one-day interval and the
        clock pinned to 2019-11-03 10:00, the test expects the engine to use
        run 1. For that run it should query ``DATANAME.MODEL``, read the latest
        run context, build the planner, plan ``DATANAME.MODEL``, and run the
        planned train operation.
        """
        from unittest import mock

        run_id = 1
        mock_train_operation.run.return_value = None
        mock_factory_build.return_value = mock_train_operation
        mock_dp_plan.return_value = [
            OperationConfig(TrainAgentOperation.__name__, run_id)
        ]

        run_context = {
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0,
            "available_data": [("model", 0)]
        }

        mock_dm.get_latest.return_value = run_context
        # get() -> False: presumably no model is stored yet for this run.
        mock_dm.get.return_value = False

        mock_engine_config = MockEngineConfig()
        mock_engine_config.start_dt = datetime(2019, 11, 2, 0, 0, 0)
        mock_engine_config.training_interval = timedelta(days=1)
        # Pin the clock via a patcher undone in cleanup so the override does
        # not leak into other tests that use MockEngineConfig.
        clock_patch = mock.patch.object(
            MockEngineConfig, "_get_current_datetime",
            lambda _: datetime(2019, 11, 3, 10, 0, 0), create=True)
        clock_patch.start()
        self.addCleanup(clock_patch.stop)

        mock_rl_app = MockRLApplication()
        engine = BaseEngine(mock_rl_app, mock_engine_config)
        engine._dm = mock_dm
        engine.train()

        mock_dm.get.assert_any_call(DATANAME.MODEL, run_id)
        mock_dm.get_latest.assert_any_call(DATANAME.RUN_CONTEXT, run_id)
        mock_dp_init.assert_any_call(engine._get_available_operators(),
                                     mock_engine_config, [("model", 0)])
        mock_dp_plan.assert_any_call(DATANAME.MODEL, run_id)
        mock_train_operation.run.assert_any_call(run_id)
示例#10
0
    def test_init_run_id_not_0(self):
        """init() is expected to raise when the current run id is not 0.

        ``start_dt`` is 2019-11-02 00:00 with a one-day interval, and the clock
        is pinned to 2019-11-03 10:00, i.e. inside the second interval (run 1).
        """
        from unittest import mock

        mock_rl_app = MockRLApplication()
        mock_engine_config = MockEngineConfig()
        mock_engine_config.start_dt = datetime(2019, 11, 2, 0, 0, 0)
        mock_engine_config.training_interval = timedelta(days=1)
        # Pin the clock explicitly so the run id does not depend on test order
        # or the real current time. The patch is undone in cleanup.
        clock_patch = mock.patch.object(
            MockEngineConfig, "_get_current_datetime",
            lambda _: datetime(2019, 11, 3, 10, 0, 0), create=True)
        clock_patch.start()
        self.addCleanup(clock_patch.stop)

        engine = BaseEngine(mock_rl_app, mock_engine_config)

        # NOTE(review): engine._dm is not replaced by a mock here, and
        # assertRaises(Exception) would also accept a failure from the real
        # data manager. Consider injecting a mock and narrowing the type.
        with self.assertRaises(Exception):
            engine.init()
示例#11
0
    def test_run(self, mock_rb, mock_tb, mock_data_manager, mock_agent):
        """TrainAgentOperation.run trains on run_id's timesteps and stores results.

        Expectations pinned here:

        * the data manager is asked for the model of ``run_id - 1`` and the
          timesteps of ``run_id``;
        * the collected timestep rows go through the trajectory builder into
          the replay buffer. The buffer is pre-processed, sampled with
          ``mini_batch_size``, and post-processed with the loss returned by
          ``agent.train``;
        * the agent is stored as ``DATANAME.MODEL`` for ``run_id``, and the run
          context is stored with ``(DATANAME.MODEL, run_id)`` appended and the
          global step advanced by ``num_iterations``.
        """
        from unittest import mock

        mock_engine_config = MockEngineConfig()
        mock_rl_app = MockRLApplication()
        # Class-level overrides go through patchers undone in cleanup so they
        # do not leak into other tests that use MockRLApplication.
        for attr_name, attr_value in (("_env_id_cols", ["env_id_1"]),
                                      ("_ts_id_col", "ts_1")):
            patcher = mock.patch.object(MockRLApplication, attr_name,
                                        attr_value, create=True)
            patcher.start()
            self.addCleanup(patcher.stop)

        # NOTE(review): the assertions below read ``training_config``, not
        # ``agent_config``. Confirm the operation actually consumes
        # ``agent_config``; otherwise this assignment has no effect.
        mock_rl_app.agent_config = {"num_iterations": 1, "mini_batch_size": 32}

        run_context_dict = {
            "available_data": [("test_data", 0)],
            TrainAgentOperation.TRAINING_GLOBAL_STEP: 0
        }
        mock_data_manager.get_latest.return_value = run_context_dict

        mock_timestep = [{
            "env_id_1": 1,
            "env_id_2": 2,
            "ts_1": 1,
            "obs_1": 1,
            "obs_2": 2,
            "action": 1,
            "reward": 0.0,
            "step_type": 0
        }]
        mock_timestep_df = self.spark.createDataFrame(mock_timestep)

        def get_side_effect(data_name, _):
            # Timestep lookups get the DataFrame; anything else (the model)
            # gets the mocked agent.
            if data_name == DATANAME.TIMESTEP:
                return mock_timestep_df
            else:
                return mock_agent

        mock_data_manager.get.side_effect = get_side_effect
        fake_mini_batch = "fake_mini_batch"

        class MockMeta(object):
            def __init__(self, prob):
                self.probabilities = prob

        fake_meta = MockMeta(0.1)
        mock_rb.get_batch.return_value = fake_mini_batch, fake_meta
        mock_traj_dict = {"observations": [1, 2, 3]}
        mock_tb.run.return_value = mock_traj_dict

        class MockLoss(object):
            def __init__(self, loss):
                self.loss = loss

        mock_loss = MockLoss("mock_loss")
        mock_agent.train.return_value = mock_loss

        run_id = 5
        operation = TrainAgentOperation(mock_rl_app, mock_engine_config,
                                        mock_data_manager)
        operation._trajectory_builder = mock_tb
        operation._replay_buffer = mock_rb
        operation.run(run_id)

        get_calls = [
            call(DATANAME.MODEL, run_id - 1),
            call(DATANAME.TIMESTEP, run_id)
        ]
        mock_data_manager.get.assert_has_calls(get_calls, any_order=True)

        mock_tb.run.assert_called_with(mock_timestep_df.collect())
        mock_rb.add_batch.assert_called_with(mock_traj_dict)
        mock_rb.pre_process.assert_called_with(0)
        mock_rb.get_batch.assert_called_with(
            mock_rl_app.training_config["mini_batch_size"])
        mock_agent.train.assert_called_with(fake_mini_batch,
                                            fake_meta.probabilities)
        mock_rb.post_process.assert_called_with(fake_meta, mock_loss, 0)

        expected_metadata = {
            "available_data": [("test_data", 0), (DATANAME.MODEL, run_id)],
            TrainAgentOperation.TRAINING_GLOBAL_STEP:
            0 + mock_rl_app.training_config["num_iterations"]
        }
        store_calls = [
            call(mock_agent, DATANAME.MODEL, run_id),
            call(expected_metadata, DATANAME.RUN_CONTEXT, run_id)
        ]
        mock_data_manager.store.assert_has_calls(store_calls, any_order=True)