def testImporterDefinitionWithSingleUri(self):
  """Single-URI importer: exec_properties, empty inputs, typed output."""
  node = importer_node.ImporterNode(
      instance_name='my_importer',
      source_uri='m/y/u/r/i',
      properties={'split_names': '["train", "eval"]'},
      custom_properties={
          'str_custom_property': 'abc',
          'int_custom_property': 123,
      },
      artifact_type=standard_artifacts.Examples)
  # Expected exec_properties spelled out once, then compared wholesale.
  expected_exec_properties = {
      importer_node.SOURCE_URI_KEY: 'm/y/u/r/i',
      importer_node.REIMPORT_OPTION_KEY: False,
      importer_node.PROPERTIES_KEY: {'split_names': '["train", "eval"]'},
      importer_node.CUSTOM_PROPERTIES_KEY: {
          'str_custom_property': 'abc',
          'int_custom_property': 123,
      },
  }
  self.assertDictEqual(node.exec_properties, expected_exec_properties)
  self.assertEmpty(node.inputs.get_all())
  self.assertEqual(node.outputs[importer_node.IMPORT_RESULT_KEY].type,
                   standard_artifacts.Examples)
def testImporterDefinitionWithMultipleUrisBadSplitSpecification(self):
  """A list of source URIs without split information must raise ValueError."""
  with self.assertRaises(ValueError):
    _ = importer_node.ImporterNode(
        instance_name='my_importer',
        source_uri=['m/y/u/r/i/1', 'm/y/u/r/i/2'],
        artifact_type=standard_artifacts.Examples)
def testIsImporter(self):
  """is_importer() accepts new-style and legacy importers, rejects others."""
  new_style = importer.Importer(
      source_uri="uri/to/schema", artifact_type=standard_artifacts.Schema)
  self.assertTrue(compiler_utils.is_importer(new_style))
  legacy = legacy_importer_node.ImporterNode(
      source_uri="uri/to/schema", artifact_type=standard_artifacts.Schema)
  self.assertTrue(compiler_utils.is_importer(legacy))
  # A regular component must not be classified as an importer.
  example_gen = CsvExampleGen(input_base="data_path")
  self.assertFalse(compiler_utils.is_importer(example_gen))
def testImporterNodeDumpsJsonRoundtrip(self):
  """ImporterNode survives a json_utils dumps/loads round trip intact."""
  name = 'my_importer'
  uris = ['m/y/u/r/i']
  node = importer_node.ImporterNode(
      instance_name=name,
      source_uri=uris,
      artifact_type=standard_artifacts.Examples)
  # dumps() raises if the object is not JSON-serializable.
  restored = json_utils.loads(json_utils.dumps(node))
  self.assertEqual(restored._instance_name, name)
  self.assertEqual(restored._source_uri, uris)
def testImporterDefinition(self):
  """Minimal importer: exec_properties, empty inputs, output type name."""
  node = importer_node.ImporterNode(
      instance_name='my_importer',
      source_uri='m/y/u/r/i',
      artifact_type=standard_artifacts.Examples)
  self.assertDictEqual(
      node.exec_properties, {
          importer_node.SOURCE_URI_KEY: 'm/y/u/r/i',
          importer_node.REIMPORT_OPTION_KEY: False,
      })
  self.assertEmpty(node.inputs.get_all())
  result_channel = node.outputs.get_all()[importer_node.IMPORT_RESULT_KEY]
  self.assertEqual(result_channel.type_name,
                   standard_artifacts.Examples.TYPE_NAME)
def testIsImporter(self):
  """Current and legacy ImporterNode are recognized; CsvExampleGen is not."""
  # Local renamed from `importer` to avoid shadowing a module of that name.
  node = importer_node.ImporterNode(
      instance_name="import_schema",
      source_uri="uri/to/schema",
      artifact_type=standard_artifacts.Schema)
  self.assertTrue(compiler_utils.is_importer(node))
  node = legacy_importer_node.ImporterNode(
      instance_name="import_schema",
      source_uri="uri/to/schema",
      artifact_type=standard_artifacts.Schema)
  self.assertTrue(compiler_utils.is_importer(node))
  example_gen = CsvExampleGen(input=external_input("data_path"))
  self.assertFalse(compiler_utils.is_importer(example_gen))
def testImporterDefinitionWithMultipleUris(self):
  """Multiple URIs plus splits yield one output artifact per split."""
  node = importer_node.ImporterNode(
      instance_name='my_importer',
      source_uri=['m/y/u/r/i/1', 'm/y/u/r/i/2'],
      artifact_type=standard_artifacts.Examples,
      split=['train', 'eval'])
  self.assertDictEqual(
      node.exec_properties, {
          importer_node.SOURCE_URI_KEY: ['m/y/u/r/i/1', 'm/y/u/r/i/2'],
          importer_node.REIMPORT_OPTION_KEY: False,
          importer_node.SPLIT_KEY: ['train', 'eval'],
      })
  # Each imported artifact carries exactly the requested split name.
  imported = node.outputs.get_all()[importer_node.IMPORT_RESULT_KEY].get()
  observed_splits = [
      artifact_utils.decode_split_names(art.split_names)[0] for art in imported
  ]
  self.assertEqual(observed_splits, ['train', 'eval'])
"""Integration tests for AI Platform Training component.""" import tensorflow as tf from tfx.components.common_nodes import importer_node from tfx.dsl.component.experimental import placeholders from tfx.orchestration import pipeline from tfx.orchestration.kubeflow.v2 import test_utils from tfx.orchestration.kubeflow.v2.components.experimental import ai_platform_training_component from tfx.types import standard_artifacts from tfx.types.experimental import simple_artifacts _PIPELINE_NAME = 'aip_training_component_pipeline' _EXAMPLE_IMPORTER = importer_node.ImporterNode( instance_name='examples', artifact_type=simple_artifacts.File, reimport=False, source_uri='gs://tfx-oss-testing-bucket/sample-data/mnist') _TRAIN = ai_platform_training_component.create_ai_platform_training( name='simple_aip_training', project_id='tfx-oss-testing', region='us-central1', image_uri='gcr.io/tfx-oss-testing/caip-training:tfx-test', args=[ '--dataset', placeholders.InputUriPlaceholder('examples'), '--model-dir', placeholders.OutputUriPlaceholder('model'), '--lr', placeholders.InputValuePlaceholder('learning_rate'), ],