def test_prepare_execution(self):
    """Check that a run with no previous run gets the mocked execution id.

    The metadata mock reports ``None`` from ``previous_run`` and returns
    ``self._execution_id`` from ``prepare_execution``.  The decision returned
    by the driver must carry that id; further output checks are delegated to
    ``self._check_output``.
    """
    # Configure the metadata mock that the driver is built around.
    metadata = self._mock_metadata
    metadata.previous_run.return_value = None
    metadata.prepare_execution.return_value = self._execution_id

    # The driver receives deep copies rather than the fixture objects.
    inputs = copy.deepcopy(self._input_dict)
    outputs = copy.deepcopy(self._output_dict)
    properties = copy.deepcopy(self._exec_properties)

    driver = base_driver.BaseDriver(
        logger=self._logger, metadata_handler=metadata)
    decision = driver.prepare_execution(
        inputs, outputs, properties, self._driver_options)

    self.assertEqual(self._execution_id, decision.execution_id)
    self._check_output(decision)
def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
    """Check ``pre_execution`` when the metadata mock has no previous execution.

    ``previous_execution`` is mocked to yield ``None``.  The test then expects
    ``use_cached_results`` to be false, the registered execution id to be
    returned, and the uri of the single ``output_a`` artifact to be
    ``<pipeline_root>/<component_id>/output_a/<execution_id>/split/``.

    Args:
      mock_verify_input_artifacts_fn: Not used in the body.  Presumably
        injected by a ``mock.patch`` decorator that is outside this view --
        TODO confirm.
    """
    input_dict = {
        'input_a': types.Channel(
            type_name='input_a',
            artifacts=[types.Artifact(type_name='input_a')]),
    }
    output_dict = {
        'output_a': types.Channel(
            type_name='output_a',
            artifacts=[types.Artifact(type_name='output_a', split='split')]),
    }
    execution_id = 1
    context_id = 123
    exec_properties = copy.deepcopy(self._exec_properties)
    driver_args = data_types.DriverArgs(enable_cache=True)
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='my_pipeline_name',
        pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
        run_id='my_run_id')
    component_info = data_types.ComponentInfo(
        component_type='a.b.c', component_id='my_component_id')

    # Each mocked call below is primed with a one-shot sequence of results.
    self._mock_metadata.get_artifacts_by_info.side_effect = list(
        input_dict['input_a'].get())
    self._mock_metadata.register_execution.side_effect = [execution_id]
    self._mock_metadata.previous_execution.side_effect = [None]
    self._mock_metadata.register_run_context_if_not_exists.side_effect = [
        context_id
    ]

    driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
    execution_decision = driver.pre_execution(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
        driver_args=driver_args,
        pipeline_info=pipeline_info,
        component_info=component_info)

    self.assertFalse(execution_decision.use_cached_results)
    self.assertEqual(execution_decision.execution_id, execution_id)
    # assertCountEqual is the Python 3 name of assertItemsEqual.  It compares
    # the elements obtained by iterating both arguments (keys, if these are
    # dicts), ignoring order.
    self.assertCountEqual(execution_decision.exec_properties, exec_properties)
    self.assertEqual(
        execution_decision.output_dict['output_a'][0].uri,
        os.path.join(pipeline_info.pipeline_root,
                     component_info.component_id,
                     'output_a', str(execution_id), 'split', ''))
def test_prepare_execution(self):
    """Check that a run with no previous run gets the mocked execution id.

    The driver is constructed with a per-test ``log_root`` directory path and
    the metadata mock.  ``previous_run`` yields ``None`` and
    ``prepare_execution`` yields ``self._execution_id``; the returned decision
    must carry that id.  Further checks are delegated to
    ``self._check_output``.
    """
    # Path is <TEST_TMP_DIR or temp dir>/<test method name>/log_root.
    tmp_root = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
    log_dir = os.path.join(tmp_root, self._testMethodName, 'log_root')

    metadata = self._mock_metadata
    metadata.previous_run.return_value = None
    metadata.prepare_execution.return_value = self._execution_id

    # The driver receives deep copies rather than the fixture objects.
    inputs = copy.deepcopy(self._input_dict)
    outputs = copy.deepcopy(self._output_dict)
    properties = copy.deepcopy(self._exec_properties)

    driver = base_driver.BaseDriver(log_dir, metadata)
    decision = driver.prepare_execution(
        inputs, outputs, properties, self._driver_options)

    self.assertEqual(self._execution_id, decision.execution_id)
    self._check_output(decision)
def test_pre_execution_cached(self):
    """Check ``pre_execution`` when the metadata mock reports a previous execution.

    ``previous_execution`` is mocked to yield ``2`` and
    ``fetch_previous_result_artifacts`` to yield ``self._output_dict``.  The
    test then expects ``use_cached_results`` to be true, the newly registered
    execution id to be returned, and the decision's output dict to match
    ``self._output_dict``.
    """
    input_dict = {
        'input_a': channel.Channel(
            type_name='input_a',
            artifacts=[types.TfxArtifact(type_name='input_a')]),
    }
    output_dict = {
        'output_a': channel.Channel(
            type_name='output_a',
            artifacts=[
                types.TfxArtifact(type_name='output_a', split='split')
            ]),
    }
    execution_id = 1
    exec_properties = copy.deepcopy(self._exec_properties)
    driver_args = data_types.DriverArgs(
        worker_name='worker_name', base_output_dir='base', enable_cache=True)
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='my_pipeline_name',
        pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
        run_id='my_run_id')
    component_info = data_types.ComponentInfo(
        component_type='a.b.c', component_id='my_component_id')

    # Each mocked call below is primed with a one-shot sequence of results.
    self._mock_metadata.get_artifacts_by_info.side_effect = list(
        input_dict['input_a'].get())
    self._mock_metadata.register_execution.side_effect = [execution_id]
    self._mock_metadata.previous_execution.side_effect = [2]
    self._mock_metadata.fetch_previous_result_artifacts.side_effect = [
        self._output_dict
    ]

    driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
    execution_decision = driver.pre_execution(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
        driver_args=driver_args,
        pipeline_info=pipeline_info,
        component_info=component_info)

    self.assertTrue(execution_decision.use_cached_results)
    self.assertEqual(execution_decision.execution_id, execution_id)
    # assertCountEqual is the Python 3 name of assertItemsEqual.  It compares
    # the elements obtained by iterating both arguments (keys, if these are
    # dicts), ignoring order.
    self.assertCountEqual(execution_decision.exec_properties, exec_properties)
    self.assertCountEqual(execution_decision.output_dict, self._output_dict)
def test_cached_execution(self):
    """Check that ``prepare_execution`` returns no execution id on a cache hit.

    A copy of the output fixture is given uris under ``self._base_output_dir``
    and each uri is created on disk.  The metadata mock reports
    ``self._execution_id`` as the previous run and returns those artifacts as
    its previous results.  The decision's ``execution_id`` must be ``None``;
    further checks are delegated to ``self._check_output``.
    """
    cached_outputs = copy.deepcopy(self._output_dict)
    for name, artifacts in cached_outputs.items():
        for cached in artifacts:
            cached.uri = os.path.join(
                self._base_output_dir, name, str(self._execution_id), '')
            # valid cached artifacts must have an existing uri.
            tf.gfile.MakeDirs(cached.uri)

    metadata = self._mock_metadata
    metadata.previous_run.return_value = self._execution_id
    metadata.fetch_previous_result_artifacts.return_value = cached_outputs

    # The driver receives deep copies rather than the fixture objects.
    inputs = copy.deepcopy(self._input_dict)
    outputs = copy.deepcopy(self._output_dict)
    properties = copy.deepcopy(self._exec_properties)

    driver = base_driver.BaseDriver(metadata_handler=metadata)
    decision = driver.prepare_execution(
        inputs, outputs, properties, self._driver_args)

    self.assertIsNone(decision.execution_id)
    self._check_output(decision)
def test_artifact_missing(self):
    """Check that ``prepare_execution`` raises when an input uri is bogus.

    The first ``input_data`` artifact is pointed at ``'should/not/exist'`` and
    caching is switched off on a copy of the driver args.  Cached outputs are
    still prepared on disk and registered on the metadata mock.  The call is
    expected to raise ``RuntimeError``.
    """
    inputs = copy.deepcopy(self._input_dict)
    inputs['input_data'][0].uri = 'should/not/exist'
    outputs = copy.deepcopy(self._output_dict)
    properties = copy.deepcopy(self._exec_properties)

    # Disable caching on a private copy of the driver args fixture.
    no_cache_options = copy.deepcopy(self._driver_args)
    no_cache_options.enable_cache = False

    cached_outputs = copy.deepcopy(self._output_dict)
    for name, artifacts in cached_outputs.items():
        for cached in artifacts:
            cached.uri = os.path.join(
                self._base_output_dir, name, str(self._execution_id), '')
            # valid cached artifacts must have an existing uri.
            tf.gfile.MakeDirs(cached.uri)

    metadata = self._mock_metadata
    metadata.previous_run.return_value = self._execution_id
    metadata.fetch_previous_result_artifacts.return_value = cached_outputs

    driver = base_driver.BaseDriver(metadata)
    with self.assertRaises(RuntimeError):
        driver.prepare_execution(
            inputs, outputs, properties, no_cache_options)
def test_cached_execution(self):
    """Check that ``prepare_execution`` returns no execution id on a cache hit.

    The driver is constructed with a per-test ``log_root`` directory path and
    the metadata mock.  A copy of the output fixture is given uris under
    ``self._base_output_dir`` and each uri is created on disk; the mock
    returns those artifacts as the previous results of ``self._execution_id``.
    The decision's ``execution_id`` must be ``None``; further checks are
    delegated to ``self._check_output``.
    """
    inputs = copy.deepcopy(self._input_dict)
    outputs = copy.deepcopy(self._output_dict)
    properties = copy.deepcopy(self._exec_properties)

    cached_outputs = copy.deepcopy(self._output_dict)
    for name, artifacts in cached_outputs.items():
        for cached in artifacts:
            cached.uri = os.path.join(
                self._base_output_dir, name, str(self._execution_id), '')
            # valid cached artifacts must have an existing uri.
            tf.gfile.MakeDirs(cached.uri)

    # Path is <TEST_TMP_DIR or temp dir>/<test method name>/log_root.
    tmp_root = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
    log_dir = os.path.join(tmp_root, self._testMethodName, 'log_root')

    metadata = self._mock_metadata
    metadata.previous_run.return_value = self._execution_id
    metadata.fetch_previous_result_artifacts.return_value = cached_outputs

    driver = base_driver.BaseDriver(log_dir, metadata)
    decision = driver.prepare_execution(
        inputs, outputs, properties, self._driver_options)

    self.assertIsNone(decision.execution_id)
    self._check_output(decision)
def test_no_cache_on_missing_uri(self):
    """Check that a new execution id is issued when cached uris do not exist.

    The metadata mock returns previous-result artifacts whose uris are
    asserted not to exist on disk.  ``prepare_execution`` on the mock yields
    ``self._execution_id + 1``, and the driver's decision is expected to carry
    that id.  Further checks are delegated to ``self._check_output``.
    """
    inputs = copy.deepcopy(self._input_dict)
    outputs = copy.deepcopy(self._output_dict)
    properties = copy.deepcopy(self._exec_properties)

    cached_outputs = copy.deepcopy(self._output_dict)
    for name, artifacts in cached_outputs.items():
        for cached in artifacts:
            cached.uri = os.path.join(
                self._base_output_dir, name, str(self._execution_id), '')
            # Non existing output uri will force a cache miss.
            self.assertFalse(tf.gfile.Exists(cached.uri))

    metadata = self._mock_metadata
    metadata.previous_run.return_value = self._execution_id
    metadata.fetch_previous_result_artifacts.return_value = cached_outputs
    # Id the mock hands out for the fresh execution.
    fresh_execution_id = self._execution_id + 1
    metadata.prepare_execution.return_value = fresh_execution_id

    driver = base_driver.BaseDriver(metadata_handler=metadata)
    decision = driver.prepare_execution(
        inputs, outputs, properties, self._driver_args)

    self.assertEqual(fresh_execution_id, decision.execution_id)
    self._check_output(decision)
def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
    """Check ``pre_execution`` when the metadata mock has no cached outputs.

    ``get_cached_outputs`` is mocked to yield ``None``.  The test expects
    ``use_cached_results`` to be false, the execution id to equal
    ``self._execution_id``, output uris laid out as
    ``<pipeline_root>/<component_id>/<output key>/<execution id>[/<index>]``,
    and the ``input_string`` input to carry ``_STRING_VALUE``.

    Args:
      mock_verify_input_artifacts_fn: Not used in the body.  Presumably
        injected by a ``mock.patch`` decorator that is outside this view --
        TODO confirm.
    """
    metadata = self._mock_metadata
    metadata.search_artifacts.return_value = list(
        self._input_dict['input_string'].get())
    metadata.register_execution.side_effect = [self._execution]
    metadata.get_cached_outputs.side_effect = [None]
    metadata.register_run_context_if_not_exists.side_effect = [
        metadata_store_pb2.Context()
    ]

    driver = base_driver.BaseDriver(metadata_handler=metadata)
    decision = driver.pre_execution(
        input_dict=self._input_dict,
        output_dict=self._output_dict,
        exec_properties=self._exec_properties,
        driver_args=self._driver_args,
        pipeline_info=self._pipeline_info,
        component_info=self._component_info)

    self.assertFalse(decision.use_cached_results)
    self.assertEqual(decision.execution_id, self._execution_id)
    self.assertCountEqual(decision.exec_properties, self._exec_properties)

    # Single-artifact output: the uri has no per-artifact index component.
    single_output_uri = os.path.join(
        self._pipeline_info.pipeline_root,
        self._component_info.component_id,
        'output_data',
        str(self._execution_id))
    self.assertEqual(
        decision.output_dict['output_data'][0].uri, single_output_uri)

    # Two-artifact output: each uri ends with the artifact's index.
    self.assertLen(decision.output_dict['output_multi_data'], 2)
    for index in range(2):
        indexed_output_uri = os.path.join(
            self._pipeline_info.pipeline_root,
            self._component_info.component_id,
            'output_multi_data',
            str(self._execution_id),
            str(index))
        self.assertEqual(
            decision.output_dict['output_multi_data'][index].uri,
            indexed_output_uri)

    self.assertEqual(
        decision.input_dict['input_string'][0].value, _STRING_VALUE)
def testVerifyInputArtifactsNotExists(self):
    """Check that ``verify_input_artifacts`` raises for a fresh ``_InputArtifact``.

    A dict holding one newly constructed ``_InputArtifact`` under the key
    ``'artifact'`` is passed in, and ``RuntimeError`` is expected.
    """
    driver_under_test = base_driver.BaseDriver(
        metadata_handler=self._mock_metadata)
    with self.assertRaises(RuntimeError):
        unverifiable_inputs = {'artifact': [_InputArtifact()]}
        driver_under_test.verify_input_artifacts(unverifiable_inputs)
def testVerifyInputArtifactsOk(self):
    """Check that ``verify_input_artifacts`` accepts ``self._input_artifacts``.

    The test passes as long as the call completes without raising; no return
    value is inspected.
    """
    driver_under_test = base_driver.BaseDriver(
        metadata_handler=self._mock_metadata)
    known_good_inputs = self._input_artifacts
    driver_under_test.verify_input_artifacts(known_good_inputs)