def test_get_run_id_1(self):
    """Run id is 0 while the current time is still inside the first interval.

    The start is 2019-11-02 00:00 with a one-day training interval, and the
    current datetime is stubbed to 2019-11-02 10:00.
    """
    td = TimingData(start_dt=datetime(2019, 11, 2, 0, 0, 0),
                    training_interval=timedelta(days=1))
    # Pin "now" so the result does not depend on the wall clock.
    td._get_current_datetime = lambda: datetime(2019, 11, 2, 10, 0, 0)

    res = td.get_current_run_id()

    expected_res = 0
    # assertEqual instead of the deprecated assertEquals alias, which was
    # removed from unittest.TestCase in Python 3.12.
    self.assertEqual(res, expected_res)
def test_timestep_operation(self, mock_data_manager):
    """Exercise BuildTimestepOperation.run against mocked app/env/data manager.

    Verifies that the operation:
      * reads the run context for the given run id,
      * asks the env to build time steps for [start_dt, start_dt + interval],
      * stores the time-step frame and then the updated run context, in
        that order.
    """
    run_id = 1
    interval = timedelta(days=1)
    window_start = datetime.now()

    # Environment stub exposing the column names set up for this test.
    env = mock.MagicMock()
    env.env_id_cols = ["env_id_1", "env_id_2"]
    env.ts_id_col = "ts_1"
    env.obs_cols = ["obs_1", "obs_2"]

    # Single-row time-step frame returned by the env stub.
    timestep_rows = [{
        "env_id_1": 1,
        "env_id_2": 2,
        "ts_1": 1,
        "discount": 1.0,
        "obs_1": 1,
        "obs_2": 2,
        "action": 1,
        "reward": 0.0,
        "step_type": 0
    }]
    timestep_df = self.spark.createDataFrame(timestep_rows)
    env.build_time_steps = mock.MagicMock(return_value=timestep_df)

    # Application stub wiring together timing data, env and config.
    app = mock.MagicMock()
    app.timing_data = TimingData(start_dt=window_start,
                                 training_interval=interval)
    app.env = env
    app.config = OmegaConf.from_dotlist(
        ["project.tensorboard_path=/tmp/test_tb/"])

    # Run context as it exists before the operation runs.
    mock_data_manager.get_latest.return_value = {
        "available_data": [("test_data", 0)]
    }

    BuildTimestepOperation(app, mock_data_manager).run(run_id)

    mock_data_manager.get_latest.assert_any_call(DATANAME.RUN_CONTEXT,
                                                 run_id)
    env.build_time_steps.assert_called_with(window_start,
                                            window_start + interval)

    # The run context is expected to gain an entry for the new time steps.
    updated_context = {
        "available_data": [("test_data", 0), (DATANAME.TIMESTEP, run_id)]
    }
    mock_data_manager.store.assert_has_calls(
        [
            call(timestep_df, DATANAME.TIMESTEP, run_id),
            call(updated_context, DATANAME.RUN_CONTEXT, run_id),
        ],
        any_order=False)
def test_error(self):
    """Planning Operation2's output for run 2 with no available data raises."""
    operation_list = {"op_1": Operation1, "op_2": Operation2}
    available_data = []

    # Read the date once: separate date.today() calls for year, month and
    # day could straddle midnight and produce an inconsistent datetime.
    current_date = date.today()
    today = datetime(current_date.year, current_date.month, current_date.day)
    timing_data = TimingData(start_dt=today,
                             training_interval=timedelta(days=1))

    planner = DependencyPlanner(operation_list, timing_data, available_data)

    # NOTE(review): Exception is very broad; narrow it to the planner's
    # specific error type once that is confirmed.
    with self.assertRaises(Exception):
        planner.plan(Operation2.output_dataname(), 2)
def test_no_dependencies(self):
    """Planning Operation1's output for run 2 yields just Operation1 at run 2."""
    operation_list = {"op_1": Operation1, "op_2": Operation2}
    available_data = []

    # Read the date once: separate date.today() calls for year, month and
    # day could straddle midnight and produce an inconsistent datetime.
    current_date = date.today()
    today = datetime(current_date.year, current_date.month, current_date.day)
    timing_data = TimingData(start_dt=today,
                             training_interval=timedelta(days=1))

    planner = DependencyPlanner(operation_list, timing_data, available_data)
    actual_plan = planner.plan(Operation1.output_dataname(), 2)

    expected_plan = [OperationConfig(Operation1.__name__, 2)]
    # assertEqual instead of the deprecated assertEquals alias, which was
    # removed from unittest.TestCase in Python 3.12.
    self.assertEqual(expected_plan, actual_plan)
def test_back_filling_multiple_runs_and_skip(self):
    """Plan for Operation2 at run 2 when some outputs already exist.

    Operation2's output for run 0 and Operation1's output for run 1 are
    already available. The expected plan is Operation1 at run 2, then
    Operation2 at runs 1 and 2.
    """
    operation_list = {"op_1": Operation1, "op_2": Operation2}
    available_data = [(Operation2.output_dataname(), 0),
                      (Operation1.output_dataname(), 1)]

    # Read the date once: separate date.today() calls for year, month and
    # day could straddle midnight and produce an inconsistent datetime.
    current_date = date.today()
    today = datetime(current_date.year, current_date.month, current_date.day)
    timing_data = TimingData(start_dt=today,
                             training_interval=timedelta(days=1))

    planner = DependencyPlanner(operation_list, timing_data, available_data)
    actual_plan = planner.plan(Operation2.output_dataname(), 2)

    expected_plan = [
        OperationConfig(Operation1.__name__, 2),
        OperationConfig(Operation2.__name__, 1),
        OperationConfig(Operation2.__name__, 2),
    ]
    # assertEqual instead of the deprecated assertEquals alias, which was
    # removed from unittest.TestCase in Python 3.12.
    self.assertEqual(expected_plan, actual_plan)
def timing_data(self):
    """Return a new TimingData built from this object's timing settings.

    The start datetime, training interval and time-step lag are read from
    attributes on ``self``; the training window is read from
    ``self.config.trajectory.trajectory_training_window``.
    """
    # Resolve each argument first, in the order they are passed below.
    start = self.first_timestep_dt
    interval = self.training_interval
    lag = self.training_timestep_lag
    window = self.config.trajectory.trajectory_training_window

    return TimingData(
        start_dt=start,
        training_interval=interval,
        training_timestep_lag=lag,
        trajectory_training_window=window,
    )