Ejemplo n.º 1
0
 def testImporterDefinitionWithSingleUri(self):
     """Single-URI importer forwards properties and custom properties."""
     impt = importer_node.ImporterNode(
         artifact_type=standard_artifacts.Examples,
         instance_name='my_importer',
         source_uri='m/y/u/r/i',
         properties={'split_names': '["train", "eval"]'},
         custom_properties={
             'str_custom_property': 'abc',
             'int_custom_property': 123,
         })

     # Both property maps are expected to appear unchanged in exec_properties,
     # next to the source URI and the reimport flag (False when not passed).
     expected_exec_properties = {
         importer_node.SOURCE_URI_KEY: 'm/y/u/r/i',
         importer_node.REIMPORT_OPTION_KEY: False,
         importer_node.PROPERTIES_KEY: {
             'split_names': '["train", "eval"]',
         },
         importer_node.CUSTOM_PROPERTIES_KEY: {
             'str_custom_property': 'abc',
             'int_custom_property': 123,
         },
     }
     self.assertDictEqual(impt.exec_properties, expected_exec_properties)

     self.assertEmpty(impt.inputs.get_all())
     result_type = impt.outputs[importer_node.IMPORT_RESULT_KEY].type
     self.assertEqual(result_type, standard_artifacts.Examples)
Ejemplo n.º 2
0
 def testImporterDefinitionWithMultipleUrisBadSplitSpecification(self):
     """Several source URIs without a `split` argument raise ValueError."""
     uris = ['m/y/u/r/i/1', 'm/y/u/r/i/2']
     with self.assertRaises(ValueError):
         importer_node.ImporterNode(
             artifact_type=standard_artifacts.Examples,
             instance_name='my_importer',
             source_uri=uris)
Ejemplo n.º 3
0
    def testIsImporter(self):
        """is_importer() accepts both importer flavours and rejects ExampleGen."""
        # Both classes take the same arguments; build and check each in turn.
        for node_cls in (importer.Importer, legacy_importer_node.ImporterNode):
            node = node_cls(source_uri="uri/to/schema",
                            artifact_type=standard_artifacts.Schema)
            self.assertTrue(compiler_utils.is_importer(node))

        # An ordinary component must not be classified as an importer.
        csv_gen = CsvExampleGen(input_base="data_path")
        self.assertFalse(compiler_utils.is_importer(csv_gen))
Ejemplo n.º 4
0
    def testImporterNodeDumpsJsonRoundtrip(self):
        """An ImporterNode survives a json_utils dumps/loads round trip."""
        name = 'my_importer'
        uris = ['m/y/u/r/i']
        original = importer_node.ImporterNode(
            artifact_type=standard_artifacts.Examples,
            instance_name=name,
            source_uri=uris)

        # dumps() raises an assertion if the node is not JSON-serializable.
        serialized = json_utils.dumps(original)
        restored = json_utils.loads(serialized)

        # Private attributes are inspected directly to confirm the round trip.
        self.assertEqual(restored._instance_name, name)
        self.assertEqual(restored._source_uri, uris)
Ejemplo n.º 5
0
 def testImporterDefinition(self):
     """A single-URI ImporterNode exposes the expected exec properties."""
     impt = importer_node.ImporterNode(
         artifact_type=standard_artifacts.Examples,
         instance_name='my_importer',
         source_uri='m/y/u/r/i')

     expected_exec_properties = {
         importer_node.SOURCE_URI_KEY: 'm/y/u/r/i',
         importer_node.REIMPORT_OPTION_KEY: False,
     }
     self.assertDictEqual(impt.exec_properties, expected_exec_properties)

     # The importer declares no input channels.
     self.assertEmpty(impt.inputs.get_all())

     result_channel = impt.outputs.get_all()[importer_node.IMPORT_RESULT_KEY]
     self.assertEqual(result_channel.type_name,
                      standard_artifacts.Examples.TYPE_NAME)
Ejemplo n.º 6
0
    def testIsImporter(self):
        """is_importer() accepts both ImporterNode flavours, rejects ExampleGen."""
        # The locals are deliberately not named `importer`: that name would
        # shadow the `importer` module used elsewhere (`importer.Importer`).
        for importer_cls in (importer_node.ImporterNode,
                             legacy_importer_node.ImporterNode):
            node = importer_cls(
                instance_name="import_schema",
                source_uri="uri/to/schema",
                artifact_type=standard_artifacts.Schema)
            self.assertTrue(compiler_utils.is_importer(node))

        # An ordinary component must not be classified as an importer.
        example_gen = CsvExampleGen(input=external_input("data_path"))
        self.assertFalse(compiler_utils.is_importer(example_gen))
Ejemplo n.º 7
0
 def testImporterDefinitionWithMultipleUris(self):
     """Multiple URIs plus `split` yield one output artifact per split, in order."""
     uris = ('m/y/u/r/i/1', 'm/y/u/r/i/2')
     splits = ('train', 'eval')
     # list(...) builds a fresh list at every use, matching the independent
     # list literals the checks need.
     impt = importer_node.ImporterNode(
         instance_name='my_importer',
         source_uri=list(uris),
         artifact_type=standard_artifacts.Examples,
         split=list(splits))

     self.assertDictEqual(
         impt.exec_properties, {
             importer_node.SOURCE_URI_KEY: list(uris),
             importer_node.REIMPORT_OPTION_KEY: False,
             importer_node.SPLIT_KEY: list(splits),
         })

     result_artifacts = impt.outputs.get_all()[
         importer_node.IMPORT_RESULT_KEY].get()
     first_split_of_each = [
         artifact_utils.decode_split_names(artifact.split_names)[0]
         for artifact in result_artifacts
     ]
     self.assertEqual(first_split_of_each, list(splits))
"""Integration tests for AI Platform Training component."""

import tensorflow as tf
from tfx.components.common_nodes import importer_node
from tfx.dsl.component.experimental import placeholders
from tfx.orchestration import pipeline
from tfx.orchestration.kubeflow.v2 import test_utils
from tfx.orchestration.kubeflow.v2.components.experimental import ai_platform_training_component
from tfx.types import standard_artifacts
from tfx.types.experimental import simple_artifacts

# Pipeline name used by this AI Platform training integration test.
_PIPELINE_NAME = 'aip_training_component_pipeline'

# Importer that brings the sample MNIST data from GCS into the pipeline as a
# generic File artifact (reimport disabled).
_EXAMPLE_IMPORTER = importer_node.ImporterNode(
    instance_name='examples',
    source_uri='gs://tfx-oss-testing-bucket/sample-data/mnist',
    artifact_type=simple_artifacts.File,
    reimport=False)
_TRAIN = ai_platform_training_component.create_ai_platform_training(
    name='simple_aip_training',
    project_id='tfx-oss-testing',
    region='us-central1',
    image_uri='gcr.io/tfx-oss-testing/caip-training:tfx-test',
    args=[
        '--dataset',
        placeholders.InputUriPlaceholder('examples'),
        '--model-dir',
        placeholders.OutputUriPlaceholder('model'),
        '--lr',
        placeholders.InputValuePlaceholder('learning_rate'),
    ],