コード例 #1
0
    def _callImporterDriver(self, reimport: bool):
        """Run the importer driver's pre_execution and verify its outputs.

        The metadata store is first populated with ``self.existing_artifacts``;
        the driver is then invoked with the source URIs and split names held
        on the test instance.

        Args:
            reimport: Value passed as the reimport option. It also selects the
                artifact id expected on the first imported result (3 when
                true, 1 otherwise).
        """
        result_key = importer_node.IMPORT_RESULT_KEY
        with metadata.Metadata(
                connection_config=self.connection_config) as store:
            store.publish_artifacts(self.existing_artifacts)
            importer_driver = importer_node.ImporterDriver(
                metadata_handler=store)

            driver_exec_properties = {
                importer_node.SOURCE_URI_KEY: self.source_uri,
                importer_node.REIMPORT_OPTION_KEY: reimport,
                importer_node.SPLIT_KEY: self.split,
            }
            outcome = importer_driver.pre_execution(
                component_info=self.component_info,
                pipeline_info=self.pipeline_info,
                driver_args=self.driver_args,
                input_dict={},
                output_dict=self.output_dict,
                exec_properties=driver_exec_properties)

            # The importer never reuses cached results and takes no inputs.
            self.assertFalse(outcome.use_cached_results)
            self.assertEmpty(outcome.input_dict)

            first_imported = outcome.output_dict[result_key][0]
            self.assertEqual(first_imported.uri, self.source_uri[0])
            if reimport:
                expected_id = 3
            else:
                expected_id = 1
            self.assertEqual(first_imported.id, expected_id)

            # The output channel handed to the driver must have been filled.
            self.assertNotEmpty(self.output_dict[result_key].get())

            published = self.output_dict[result_key].get()
            expected_pairs = zip(self.source_uri, self.split)
            for artifact, (expected_uri, expected_split) in zip(
                    published, expected_pairs):
                self.assertEqual(artifact.uri, expected_uri)
                first_split = artifact_utils.decode_split_names(
                    artifact.split_names)[0]
                self.assertEqual(first_split, expected_split)
コード例 #2
0
ファイル: importer_node_test.py プロジェクト: zvrr/tfx
    def _callImporterDriver(self, reimport: bool):
        """Run the importer driver's pre_execution and verify its outputs.

        The metadata store is first populated with ``self.existing_artifacts``;
        the driver is then invoked with the source URI, properties and custom
        properties held on the test instance. Exactly one result artifact is
        expected, carrying the source URI and every configured property.

        Args:
            reimport: Value passed to the driver as the reimport option.

        Raises:
            ValueError: If a value in ``self.custom_properties`` is neither an
                int nor a text/bytes value.
        """
        with metadata.Metadata(connection_config=self.connection_config) as m:
            m.publish_artifacts(self.existing_artifacts)
            driver = importer_node.ImporterDriver(metadata_handler=m)
            execution_result = driver.pre_execution(
                component_info=self.component_info,
                pipeline_info=self.pipeline_info,
                driver_args=self.driver_args,
                input_dict={},
                output_dict=self.output_dict,
                exec_properties={
                    importer_node.SOURCE_URI_KEY: self.source_uri,
                    importer_node.REIMPORT_OPTION_KEY: reimport,
                    importer_node.PROPERTIES_KEY: self.properties,
                    importer_node.CUSTOM_PROPERTIES_KEY:
                    self.custom_properties,
                })
            # The importer never reuses cached results and takes no inputs.
            self.assertFalse(execution_result.use_cached_results)
            self.assertEmpty(execution_result.input_dict)
            self.assertEqual(
                1,
                len(execution_result.output_dict[
                    importer_node.IMPORT_RESULT_KEY]))
            self.assertEqual(
                execution_result.output_dict[importer_node.IMPORT_RESULT_KEY]
                [0].uri, self.source_uri)

            # The output channel handed to the driver must have been filled.
            self.assertNotEmpty(
                self.output_dict[importer_node.IMPORT_RESULT_KEY].get())

            results = self.output_dict[importer_node.IMPORT_RESULT_KEY].get()
            self.assertEqual(1, len(results))
            result = results[0]
            # Compare against the configured source URI. Comparing result.uri
            # with itself could never fail and left the URI unchecked.
            self.assertEqual(result.uri, self.source_uri)
            for key, value in self.properties.items():
                self.assertEqual(value, getattr(result, key))
            # Custom properties are read back through the typed accessor that
            # matches the Python type of the expected value.
            for key, value in self.custom_properties.items():
                if isinstance(value, int):
                    self.assertEqual(value,
                                     result.get_int_custom_property(key))
                elif isinstance(value, (Text, bytes)):
                    self.assertEqual(value,
                                     result.get_string_custom_property(key))
                else:
                    raise ValueError('Invalid custom property value: %r.' %
                                     value)
コード例 #3
0
ファイル: importer_node_test.py プロジェクト: zzhmtxxhh/tfx
 def _callImporterDriver(self, reimport: bool):
   """Run the importer driver's pre_execution and verify its first output.

   The metadata store is first populated with ``self.existing_artifact``; the
   driver is then invoked with the source URI held on the test instance.

   Args:
     reimport: Value passed as the reimport option. It also selects the
       artifact id expected on the imported result (2 when true, 1 otherwise).
   """
   with metadata.Metadata(connection_config=self.connection_config) as m:
     m.publish_artifacts([self.existing_artifact])
     driver = importer_node.ImporterDriver(metadata_handler=m)

     importer_exec_properties = {
         importer_node.SOURCE_URI_KEY: self.source_uri,
         importer_node.REIMPORT_OPTION_KEY: reimport
     }
     execution_result = driver.pre_execution(
         component_info=self.component_info,
         pipeline_info=self.pipeline_info,
         driver_args=self.driver_args,
         input_dict={},
         output_dict=self.output_dict,
         exec_properties=importer_exec_properties)

     # The importer never reuses cached results and takes no inputs.
     self.assertFalse(execution_result.use_cached_results)
     self.assertEmpty(execution_result.input_dict)

     first_imported = execution_result.output_dict[
         importer_node.IMPORT_RESULT_KEY][0]
     self.assertEqual(first_imported.uri, self.source_uri)
     if reimport:
       expected_id = 2
     else:
       expected_id = 1
     self.assertEqual(first_imported.id, expected_id)