Ejemplo n.º 1
0
    def test_prepare_execution(self):
        """Checks the decision produced when no previous run is reported.

        The metadata mock returns ``None`` for ``previous_run`` and
        ``self._execution_id`` for ``prepare_execution``; the decision handed
        back by the driver must carry that same execution id and pass
        ``self._check_output``.
        """
        metadata = self._mock_metadata
        metadata.previous_run.return_value = None
        metadata.prepare_execution.return_value = self._execution_id

        # Work on private copies so the shared fixtures are never mutated.
        inputs = copy.deepcopy(self._input_dict)
        outputs = copy.deepcopy(self._output_dict)
        properties = copy.deepcopy(self._exec_properties)

        decision = base_driver.BaseDriver(
            logger=self._logger,
            metadata_handler=metadata,
        ).prepare_execution(inputs, outputs, properties, self._driver_options)

        self.assertEqual(self._execution_id, decision.execution_id)
        self._check_output(decision)
Ejemplo n.º 2
0
    def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
        """Checks `pre_execution` when no previous execution is reported.

        The metadata mock reports ``None`` as the previous execution, so the
        decision must not use cached results, must carry the id returned by
        ``register_execution``, and must place the single output artifact
        under
        ``<pipeline_root>/<component_id>/output_a/<execution_id>/split/``.

        Args:
          mock_verify_input_artifacts_fn: Unused here; presumably injected by
            a ``mock.patch`` decorator outside this view -- TODO confirm.
        """
        input_dict = {
            'input_a':
            types.Channel(type_name='input_a',
                          artifacts=[types.Artifact(type_name='input_a')])
        }
        output_dict = {
            'output_a':
            types.Channel(type_name='output_a',
                          artifacts=[
                              types.Artifact(type_name='output_a',
                                             split='split')
                          ])
        }
        execution_id = 1
        context_id = 123
        exec_properties = copy.deepcopy(self._exec_properties)
        driver_args = data_types.DriverArgs(enable_cache=True)
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='my_pipeline_name',
            pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
            run_id='my_run_id')
        component_info = data_types.ComponentInfo(
            component_type='a.b.c', component_id='my_component_id')

        # Each side_effect list is consumed one element per call on the mock.
        self._mock_metadata.get_artifacts_by_info.side_effect = list(
            input_dict['input_a'].get())
        self._mock_metadata.register_execution.side_effect = [execution_id]
        self._mock_metadata.previous_execution.side_effect = [None]
        self._mock_metadata.register_run_context_if_not_exists.side_effect = [
            context_id
        ]

        driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
        execution_decision = driver.pre_execution(
            input_dict=input_dict,
            output_dict=output_dict,
            exec_properties=exec_properties,
            driver_args=driver_args,
            pipeline_info=pipeline_info,
            component_info=component_info)

        self.assertFalse(execution_decision.use_cached_results)
        # Compare against the registered id rather than a repeated literal.
        self.assertEqual(execution_decision.execution_id, execution_id)
        # assertCountEqual replaces the deprecated assertItemsEqual alias.
        # NOTE(review): if both operands are dicts this compares keys only.
        self.assertCountEqual(execution_decision.exec_properties,
                              exec_properties)
        self.assertEqual(
            execution_decision.output_dict['output_a'][0].uri,
            os.path.join(pipeline_info.pipeline_root,
                         component_info.component_id, 'output_a',
                         str(execution_id), 'split', ''))
Ejemplo n.º 3
0
  def test_prepare_execution(self):
    """Checks the decision produced when no previous run is reported.

    The driver is built with a per-test log root and the metadata mock, which
    returns ``None`` for ``previous_run`` and ``self._execution_id`` for
    ``prepare_execution``. The decision must carry that execution id and pass
    ``self._check_output``.
    """
    metadata = self._mock_metadata
    metadata.previous_run.return_value = None
    metadata.prepare_execution.return_value = self._execution_id

    # Work on private copies so the shared fixtures are never mutated.
    inputs = copy.deepcopy(self._input_dict)
    outputs = copy.deepcopy(self._output_dict)
    properties = copy.deepcopy(self._exec_properties)

    tmp_root = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
    log_dir = os.path.join(tmp_root, self._testMethodName, 'log_root')

    decision = base_driver.BaseDriver(log_dir, metadata).prepare_execution(
        inputs, outputs, properties, self._driver_options)

    self.assertEqual(self._execution_id, decision.execution_id)
    self._check_output(decision)
Ejemplo n.º 4
0
    def test_pre_execution_cached(self):
        """Checks `pre_execution` when a previous execution is reported.

        The metadata mock returns a previous execution id and hands back
        ``self._output_dict`` from ``fetch_previous_result_artifacts``. The
        decision must be flagged as using cached results, carry the id
        returned by ``register_execution``, and expose output entries matching
        ``self._output_dict``.
        """
        input_dict = {
            'input_a':
            channel.Channel(type_name='input_a',
                            artifacts=[types.TfxArtifact(type_name='input_a')])
        }
        output_dict = {
            'output_a':
            channel.Channel(type_name='output_a',
                            artifacts=[
                                types.TfxArtifact(type_name='output_a',
                                                  split='split')
                            ])
        }
        execution_id = 1
        # Id the mock reports for the earlier execution; any non-None value
        # presumably signals a cache hit -- TODO confirm against the driver.
        previous_execution_id = 2
        exec_properties = copy.deepcopy(self._exec_properties)
        driver_args = data_types.DriverArgs(worker_name='worker_name',
                                            base_output_dir='base',
                                            enable_cache=True)
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='my_pipeline_name',
            pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
            run_id='my_run_id')
        component_info = data_types.ComponentInfo(
            component_type='a.b.c', component_id='my_component_id')

        # Each side_effect list is consumed one element per call on the mock.
        self._mock_metadata.get_artifacts_by_info.side_effect = list(
            input_dict['input_a'].get())
        self._mock_metadata.register_execution.side_effect = [execution_id]
        self._mock_metadata.previous_execution.side_effect = [
            previous_execution_id
        ]
        self._mock_metadata.fetch_previous_result_artifacts.side_effect = [
            self._output_dict
        ]

        driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
        execution_decision = driver.pre_execution(
            input_dict=input_dict,
            output_dict=output_dict,
            exec_properties=exec_properties,
            driver_args=driver_args,
            pipeline_info=pipeline_info,
            component_info=component_info)

        self.assertTrue(execution_decision.use_cached_results)
        # Compare against the registered id rather than a repeated literal.
        self.assertEqual(execution_decision.execution_id, execution_id)
        # assertCountEqual replaces the deprecated assertItemsEqual alias.
        # NOTE(review): if both operands are dicts this compares keys only.
        self.assertCountEqual(execution_decision.exec_properties,
                              exec_properties)
        self.assertCountEqual(execution_decision.output_dict,
                              self._output_dict)
Ejemplo n.º 5
0
    def test_cached_execution(self):
        """Checks that a usable cached run yields a decision with no new id.

        Every cached output artifact is pointed at a directory that is created
        up front, and the metadata mock reports a previous run together with
        those cached outputs. The resulting decision must have a ``None``
        execution id and pass ``self._check_output``.
        """
        execution_dir = str(self._execution_id)
        cached_outputs = copy.deepcopy(self._output_dict)
        for output_key in cached_outputs:
            for cached_artifact in cached_outputs[output_key]:
                cached_artifact.uri = os.path.join(
                    self._base_output_dir, output_key, execution_dir, '')
                # A cached artifact is only valid when its uri exists.
                tf.gfile.MakeDirs(cached_artifact.uri)

        metadata = self._mock_metadata
        metadata.previous_run.return_value = self._execution_id
        metadata.fetch_previous_result_artifacts.return_value = cached_outputs

        # Work on private copies so the shared fixtures are never mutated.
        inputs = copy.deepcopy(self._input_dict)
        outputs = copy.deepcopy(self._output_dict)
        properties = copy.deepcopy(self._exec_properties)

        decision = base_driver.BaseDriver(
            metadata_handler=metadata,
        ).prepare_execution(inputs, outputs, properties, self._driver_args)

        self.assertIsNone(decision.execution_id)
        self._check_output(decision)
Ejemplo n.º 6
0
    def test_artifact_missing(self):
        """Checks that `prepare_execution` raises for a bad input uri.

        The first ``'input_data'`` artifact is given the uri
        ``'should/not/exist'`` and caching is switched off on a copy of the
        driver args; the call is then expected to raise ``RuntimeError``.
        """
        inputs = copy.deepcopy(self._input_dict)
        inputs['input_data'][0].uri = 'should/not/exist'
        outputs = copy.deepcopy(self._output_dict)
        properties = copy.deepcopy(self._exec_properties)

        no_cache_args = copy.deepcopy(self._driver_args)
        no_cache_args.enable_cache = False

        execution_dir = str(self._execution_id)
        cached_outputs = copy.deepcopy(self._output_dict)
        for output_key in cached_outputs:
            for cached_artifact in cached_outputs[output_key]:
                cached_artifact.uri = os.path.join(
                    self._base_output_dir, output_key, execution_dir, '')
                # A cached artifact is only valid when its uri exists.
                tf.gfile.MakeDirs(cached_artifact.uri)

        metadata = self._mock_metadata
        metadata.previous_run.return_value = self._execution_id
        metadata.fetch_previous_result_artifacts.return_value = cached_outputs

        driver = base_driver.BaseDriver(metadata)
        with self.assertRaises(RuntimeError):
            driver.prepare_execution(inputs, outputs, properties,
                                     no_cache_args)
Ejemplo n.º 7
0
  def test_cached_execution(self):
    """Checks that a usable cached run yields a decision with no new id.

    Every cached output artifact is pointed at a directory that is created up
    front, and the metadata mock reports a previous run together with those
    cached outputs. The driver is built with a per-test log root; the
    resulting decision must have a ``None`` execution id and pass
    ``self._check_output``.
    """
    execution_dir = str(self._execution_id)
    cached_outputs = copy.deepcopy(self._output_dict)
    for output_key in cached_outputs:
      for cached_artifact in cached_outputs[output_key]:
        cached_artifact.uri = os.path.join(
            self._base_output_dir, output_key, execution_dir, '')
        # A cached artifact is only valid when its uri exists.
        tf.gfile.MakeDirs(cached_artifact.uri)

    tmp_root = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
    log_dir = os.path.join(tmp_root, self._testMethodName, 'log_root')

    metadata = self._mock_metadata
    metadata.previous_run.return_value = self._execution_id
    metadata.fetch_previous_result_artifacts.return_value = cached_outputs

    # Work on private copies so the shared fixtures are never mutated.
    inputs = copy.deepcopy(self._input_dict)
    outputs = copy.deepcopy(self._output_dict)
    properties = copy.deepcopy(self._exec_properties)

    decision = base_driver.BaseDriver(log_dir, metadata).prepare_execution(
        inputs, outputs, properties, self._driver_options)

    self.assertIsNone(decision.execution_id)
    self._check_output(decision)
Ejemplo n.º 8
0
    def test_no_cache_on_missing_uri(self):
        """Checks that cached outputs with absent uris are not reused.

        The cached artifacts point at uris that are asserted not to exist, so
        although the metadata mock reports a previous run, the decision is
        expected to carry the fresh id returned by ``prepare_execution``
        (``self._execution_id + 1``).
        """
        execution_dir = str(self._execution_id)
        cached_outputs = copy.deepcopy(self._output_dict)
        for output_key in cached_outputs:
            for cached_artifact in cached_outputs[output_key]:
                cached_artifact.uri = os.path.join(
                    self._base_output_dir, output_key, execution_dir, '')
                # A missing output uri is what forces the cache miss.
                self.assertFalse(tf.gfile.Exists(cached_artifact.uri))

        fresh_execution_id = self._execution_id + 1
        metadata = self._mock_metadata
        metadata.previous_run.return_value = self._execution_id
        metadata.fetch_previous_result_artifacts.return_value = cached_outputs
        metadata.prepare_execution.return_value = fresh_execution_id

        # Work on private copies so the shared fixtures are never mutated.
        inputs = copy.deepcopy(self._input_dict)
        outputs = copy.deepcopy(self._output_dict)
        properties = copy.deepcopy(self._exec_properties)

        decision = base_driver.BaseDriver(
            metadata_handler=metadata,
        ).prepare_execution(inputs, outputs, properties, self._driver_args)

        self.assertEqual(fresh_execution_id, decision.execution_id)
        self._check_output(decision)
Ejemplo n.º 9
0
    def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
        """Checks `pre_execution` when no cached outputs are reported.

        The metadata mock returns ``None`` from ``get_cached_outputs``, so the
        decision must not use cached results, must carry
        ``self._execution_id``, and must place output artifacts under
        ``<pipeline_root>/<component_id>/<output key>/<execution_id>``, with a
        trailing index for each of the two ``'output_multi_data'`` artifacts.

        Args:
          mock_verify_input_artifacts_fn: Unused here; presumably injected by
            a ``mock.patch`` decorator outside this view -- TODO confirm.
        """
        metadata = self._mock_metadata
        metadata.search_artifacts.return_value = list(
            self._input_dict['input_string'].get())
        # Each side_effect list is consumed one element per call on the mock.
        metadata.register_execution.side_effect = [self._execution]
        metadata.get_cached_outputs.side_effect = [None]
        metadata.register_run_context_if_not_exists.side_effect = [
            metadata_store_pb2.Context()
        ]

        decision = base_driver.BaseDriver(
            metadata_handler=metadata,
        ).pre_execution(
            input_dict=self._input_dict,
            output_dict=self._output_dict,
            exec_properties=self._exec_properties,
            driver_args=self._driver_args,
            pipeline_info=self._pipeline_info,
            component_info=self._component_info,
        )

        self.assertFalse(decision.use_cached_results)
        self.assertEqual(decision.execution_id, self._execution_id)
        self.assertCountEqual(decision.exec_properties, self._exec_properties)

        single_output_uri = decision.output_dict['output_data'][0].uri
        component_root = os.path.join(self._pipeline_info.pipeline_root,
                                      self._component_info.component_id)
        execution_dir = str(self._execution_id)
        self.assertEqual(
            single_output_uri,
            os.path.join(component_root, 'output_data', execution_dir))

        self.assertLen(decision.output_dict['output_multi_data'], 2)
        for index in range(2):
            multi_output_uri = (
                decision.output_dict['output_multi_data'][index].uri)
            self.assertEqual(
                multi_output_uri,
                os.path.join(component_root, 'output_multi_data',
                             execution_dir, str(index)))

        resolved_input = decision.input_dict['input_string'][0]
        self.assertEqual(resolved_input.value, _STRING_VALUE)
Ejemplo n.º 10
0
 def testVerifyInputArtifactsNotExists(self):
     """Expects `verify_input_artifacts` to raise for a bare `_InputArtifact`."""
     driver_under_test = base_driver.BaseDriver(
         metadata_handler=self._mock_metadata)
     with self.assertRaises(RuntimeError):
         # Built inside the context so the asserted region matches the call.
         unverifiable_inputs = {'artifact': [_InputArtifact()]}
         driver_under_test.verify_input_artifacts(unverifiable_inputs)
Ejemplo n.º 11
0
 def testVerifyInputArtifactsOk(self):
     """Passes as long as verifying `self._input_artifacts` does not raise."""
     base_driver.BaseDriver(
         metadata_handler=self._mock_metadata
     ).verify_input_artifacts(self._input_artifacts)