예제 #1
0
    def testBuildFileBasedExampleGen(self):
        """Checks the step built for a CsvExampleGen against golden protos.

        The `component_defs` dict and deployment config handed to StepBuilder
        are inspected after `build()` and compared, together with the
        returned task spec, to the expected .pbtxt test data.
        """
        csv_example_gen = components.CsvExampleGen(
            input_base='path/to/data/root')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=csv_example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            image_cmds=_TEST_CMDS,
            beam_pipeline_args=['runner=DataflowRunner'],
            deployment_config=deployment_config,
            component_defs=component_defs)
        task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        expected_component = test_utils.get_proto_from_test_data(
            'expected_csv_example_gen_component.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_csv_example_gen_task.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_csv_example_gen_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #2
0
    def testBuildImporterWithRuntimeParam(self):
        """Checks an Importer whose source URI is a RuntimeParameter.

        The step is built inside a ParameterContext; besides the golden proto
        comparisons, the context must have recorded exactly that parameter.
        """
        runtime_param = data_types.RuntimeParameter(
            name='runtime_flag', ptype=str)
        importer_node = importer.Importer(
            source_uri=runtime_param,
            artifact_type=standard_artifacts.Examples).with_id('my_importer')
        component_defs = {}
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        # The builder is created and run inside the context so that the
        # context's `parameters` can be checked at the end of the test.
        with parameter_utils.ParameterContext() as param_context:
            builder = step_builder.StepBuilder(
                node=importer_node,
                deployment_config=deployment_config,
                component_defs=component_defs)
            task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        expected_component = test_utils.get_proto_from_test_data(
            'expected_importer_component_with_runtime_param.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_importer_task_with_runtime_param.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_importer_executor_with_runtime_param.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)

        self.assertListEqual([runtime_param], param_context.parameters)
예제 #3
0
    def testBuildImporter(self):
        """Checks the step built for an Importer with static properties.

        The importer carries one property and two custom properties (a string
        and an int); the build results are compared to golden .pbtxt files.
        """
        artifact_properties = {
            'split_names': '["train", "eval"]',
        }
        artifact_custom_properties = {
            'str_custom_property': 'abc',
            'int_custom_property': 123,
        }
        importer_node = importer.Importer(
            instance_name='my_importer',
            source_uri='m/y/u/r/i',
            properties=artifact_properties,
            custom_properties=artifact_custom_properties,
            artifact_type=standard_artifacts.Examples)
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=importer_node,
            deployment_config=deployment_config,
            component_defs=component_defs)
        task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        expected_component = test_utils.get_proto_from_test_data(
            'expected_importer_component.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_importer_task.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_importer_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #4
0
    def testBuildFileBasedExampleGenWithInputConfig(self):
        """Checks an ImportExampleGen step built with an explicit input config.

        The input config declares a 'train' and an 'eval' split by file
        pattern; the build results are compared to golden .pbtxt files.
        """
        train_split = example_gen_pb2.Input.Split(
            name='train', pattern='*train.tfr')
        eval_split = example_gen_pb2.Input.Split(
            name='eval', pattern='*test.tfr')
        splits_config = example_gen_pb2.Input(splits=[train_split, eval_split])
        import_example_gen = components.ImportExampleGen(
            input_base='path/to/data/root', input_config=splits_config)
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=import_example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs)
        task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        expected_component = test_utils.get_proto_from_test_data(
            'expected_import_example_gen_component.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_import_example_gen_task.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_import_example_gen_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #5
0
    def testBuildLatestArtifactResolverSucceed(self):
        """Checks the step built for a LatestArtifactsResolver node.

        The resolver is given a Model channel and an Examples channel, and the
        builder receives explicit pipeline info; the build results are
        compared to golden .pbtxt files.
        """
        model_channel = channel.Channel(type=standard_artifacts.Model)
        examples_channel = channel.Channel(type=standard_artifacts.Examples)
        resolver_node = resolver.Resolver(
            instance_name='my_resolver',
            strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
            model=model_channel,
            examples=examples_channel)
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')
        builder = step_builder.StepBuilder(
            node=resolver_node,
            deployment_config=deployment_config,
            pipeline_info=pipeline_info,
            component_defs=component_defs)
        task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        expected_component = test_utils.get_proto_from_test_data(
            'expected_latest_artifact_resolver_component.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_latest_artifact_resolver_task.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_latest_artifact_resolver_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #6
0
    def testBuildTask(self):
        """Checks the step built for a BigQueryExampleGen with caching enabled.

        The builder is created with `enable_cache=True`; the build results are
        compared to golden .pbtxt files.
        """
        bq_example_gen = big_query_example_gen_component.BigQueryExampleGen(
            query='SELECT * FROM TABLE')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=bq_example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs,
            enable_cache=True)
        task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        expected_component = test_utils.get_proto_from_test_data(
            'expected_bq_example_gen_component.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_bq_example_gen_task.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_bq_example_gen_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #7
0
    def testBuildExitHandler(self):
        """Checks the step built for a node flagged as an exit handler.

        A dummy producer component takes a FinalStatusStr parameter and the
        builder is created with `is_exit_handler=True`; the build results are
        compared to golden .pbtxt files.
        """
        final_status = decorators.FinalStatusStr('value1')
        exit_task = test_utils.dummy_producer_component(param1=final_status)
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=exit_task,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs,
            is_exit_handler=True)
        task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        expected_component = test_utils.get_proto_from_test_data(
            'expected_dummy_exit_handler_component.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_dummy_exit_handler_task.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_dummy_exit_handler_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #8
0
    def testBuildContainerTask2(self):
        """Checks the step built for `test_utils.dummy_producer_component`.

        The golden files are the `expected_dummy_container_spec_*` ones, i.e.
        the same expectations as in testBuildContainerTask.
        """
        model_channel = channel_utils.as_channel([standard_artifacts.Model()])
        producer_task = test_utils.dummy_producer_component(
            output1=model_channel,
            param1='value1',
        )
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=producer_task,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs)
        task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        # Expectations are shared with testBuildContainerTask.
        expected_component = test_utils.get_proto_from_test_data(
            'expected_dummy_container_spec_component.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_dummy_container_spec_task.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_dummy_container_spec_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #9
0
    def testBuildLatestBlessedModelResolverSucceed(self):
        """Checks the steps built for a LatestBlessedModelResolver node.

        The build is expected to yield two task specs, keyed by a
        model-blessing resolver id and a model resolver id. Each task spec and
        its component definition is compared to its own golden file, followed
        by the shared deployment config.
        """
        strategy = latest_blessed_model_resolver.LatestBlessedModelResolver
        blessed_resolver_node = resolver.Resolver(
            instance_name='my_resolver2',
            strategy_class=strategy,
            model=channel.Channel(type=standard_artifacts.Model),
            model_blessing=channel.Channel(
                type=standard_artifacts.ModelBlessing))
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=blessed_resolver_node,
            deployment_config=deployment_config,
            pipeline_info=pipeline_info,
            component_defs=component_defs)
        task_specs = builder.build()

        blessing_id = 'Resolver.my_resolver2-model-blessing-resolver'
        model_id = 'Resolver.my_resolver2-model-resolver'
        self.assertSameElements(task_specs.keys(), [blessing_id, model_id])

        # First generated step: the model-blessing resolver.
        expected_component_1 = test_utils.get_proto_from_test_data(
            'expected_latest_blessed_model_resolver_component_1.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component_1,
                               component_defs[blessing_id])

        expected_task_1 = test_utils.get_proto_from_test_data(
            'expected_latest_blessed_model_resolver_task_1.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task_1, task_specs[blessing_id])

        # Second generated step: the model resolver.
        expected_component_2 = test_utils.get_proto_from_test_data(
            'expected_latest_blessed_model_resolver_component_2.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component_2, component_defs[model_id])

        expected_task_2 = test_utils.get_proto_from_test_data(
            'expected_latest_blessed_model_resolver_task_2.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task_2, task_specs[model_id])

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_latest_blessed_model_resolver_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #10
0
    def testBuildDummyConsumerWithCondition(self):
        """Checks the step built for a consumer nested in two conditions.

        Two things are covered (see the testdata for the details):
          1. Nested conditions: the condition string of the consumer task
             should contain both predicates.
          2. Implicit channels: the consumer only takes producer_task_1's
             output, but producer_task_2 is used in a condition, hence
             producer_task_2 should be added to the consumer's dependencies.
        """
        first_producer = test_utils.dummy_producer_component(
            output1=channel_utils.as_channel([standard_artifacts.Model()]),
            param1='value1',
        ).with_id('producer_task_1')
        second_producer = test_utils.dummy_producer_component_2(
            output1=channel_utils.as_channel([standard_artifacts.Model()]),
            param1='value2',
        ).with_id('producer_task_2')

        outer_predicate = (
            first_producer.outputs['output1'].future()[0].uri != 'uri')
        with conditional.Cond(outer_predicate):
            # The inner predicate is created inside the outer condition, as
            # in a literally nested `with` statement.
            inner_predicate = (
                second_producer.outputs['output1'].future()[0].property(
                    'property') == 'value1')
            with conditional.Cond(inner_predicate):
                consumer = test_utils.dummy_consumer_component(
                    input1=first_producer.outputs['output1'],
                    param1=1,
                )

        # A pipeline has to be constructed so that producer_component_id is
        # set; the pipeline object itself is not used afterwards.
        unused_pipeline = tfx.dsl.Pipeline(
            pipeline_name='pipeline-with-condition',
            pipeline_root='',
            components=[first_producer, second_producer, consumer],
        )
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=consumer,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs)
        task_spec = self._sole(builder.build())
        component_spec = self._sole(component_defs)

        expected_component = test_utils.get_proto_from_test_data(
            'expected_dummy_consumer_with_condition_component.pbtxt',
            pipeline_pb2.ComponentSpec())
        self.assertProtoEquals(expected_component, component_spec)

        expected_task = test_utils.get_proto_from_test_data(
            'expected_dummy_consumer_with_condition_task.pbtxt',
            pipeline_pb2.PipelineTaskSpec())
        self.assertProtoEquals(expected_task, task_spec)

        expected_executor = test_utils.get_proto_from_test_data(
            'expected_dummy_consumer_with_condition_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_executor, deployment_config)
예제 #11
0
 def testBuildTwoStepPipeline(self):
   """Compares the spec built for the two-step pipeline to its golden file."""
   two_step = test_utils.two_step_pipeline()
   builder = pipeline_builder.PipelineBuilder(
       tfx_pipeline=two_step, default_image='gcr.io/my-tfx:latest')
   built_spec = builder.build()

   expected_spec = test_utils.get_proto_from_test_data(
       'expected_two_step_pipeline.pbtxt', pipeline_pb2.PipelineSpec())
   self.assertProtoEquals(expected_spec, built_spec)
예제 #12
0
 def testBuildPipelineWithPrimitiveValuePassing(self):
   """Compares the consume-primitive-artifacts-by-value spec to its golden file."""
   by_value_pipeline = (
       test_utils.consume_primitive_artifacts_by_value_pipeline())
   builder = pipeline_builder.PipelineBuilder(
       tfx_pipeline=by_value_pipeline,
       default_image='gcr.io/my-tfx:latest')
   built_spec = builder.build()

   expected_spec = test_utils.get_proto_from_test_data(
       'expected_consume_primitive_artifacts_by_value_pipeline.pbtxt',
       pipeline_pb2.PipelineSpec())
   self.assertProtoEquals(expected_spec, built_spec)
예제 #13
0
 def testBuildPipelineWithRuntimeParameter(self):
   """Compares the runtime-parameter pipeline spec to its golden file."""
   parameterized_pipeline = test_utils.pipeline_with_runtime_parameter()
   builder = pipeline_builder.PipelineBuilder(
       tfx_pipeline=parameterized_pipeline,
       default_image='gcr.io/my-tfx:latest')
   built_spec = builder.build()

   expected_spec = test_utils.get_proto_from_test_data(
       'expected_pipeline_with_runtime_parameter.pbtxt',
       pipeline_pb2.PipelineSpec())
   self.assertProtoEquals(expected_spec, built_spec)
예제 #14
0
  def testTwoStepPipelineWithTaskOnlyDependency(self):
    """Compares the task-only-dependency pipeline spec to its golden file."""
    task_only_pipeline = (
        test_utils.two_step_pipeline_with_task_only_dependency())
    spec_builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=task_only_pipeline, default_image='unused-image')
    built_spec = spec_builder.build()

    expected_spec = test_utils.get_proto_from_test_data(
        'expected_two_step_pipeline_with_task_only_dependency.pbtxt',
        pipeline_pb2.PipelineSpec())
    self.assertProtoEquals(expected_spec, built_spec)
예제 #15
0
  def testBuildPipelineWithTwoContainerSpecComponents(self):
    """Compares the two-container-spec-components spec to its golden file."""
    container_pipeline = (
        test_utils.pipeline_with_two_container_spec_components())
    builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=container_pipeline,
        default_image='gcr.io/my-tfx:latest')
    built_spec = builder.build()

    expected_spec = test_utils.get_proto_from_test_data(
        'expected_pipeline_with_two_container_spec_components.pbtxt',
        pipeline_pb2.PipelineSpec())
    self.assertProtoEquals(expected_spec, built_spec)
예제 #16
0
  def testBuildContainerTask(self):
    """Checks the step built for `test_utils.DummyProducerComponent`.

    Only the task spec and the deployment config are compared to golden
    files; no `component_defs` dict is passed to the builder here.
    """
    model_channel = channel_utils.as_channel([standard_artifacts.Model()])
    producer_task = test_utils.DummyProducerComponent(
        output1=model_channel,
        param1='value1',
    )
    deployment_config = pipeline_pb2.PipelineDeploymentConfig()
    builder = step_builder.StepBuilder(
        node=producer_task,
        # Note: the image has no effect here.
        image='gcr.io/tensorflow/tfx:latest',
        deployment_config=deployment_config)
    task_spec = self._sole(builder.build())

    expected_task = test_utils.get_proto_from_test_data(
        'expected_dummy_container_spec_task.pbtxt',
        pipeline_pb2.PipelineTaskSpec())
    self.assertProtoEquals(expected_task, task_spec)

    expected_executor = test_utils.get_proto_from_test_data(
        'expected_dummy_container_spec_executor.pbtxt',
        pipeline_pb2.PipelineDeploymentConfig())
    self.assertProtoEquals(expected_executor, deployment_config)
예제 #17
0
  def testBuildTwoStepPipelineWithCacheEnabled(self):
    """Compares the two-step pipeline spec, cache enabled, to its golden file."""
    cached_pipeline = test_utils.two_step_pipeline()
    # Caching is switched on on the pipeline object before it is built.
    cached_pipeline.enable_cache = True

    spec_builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=cached_pipeline,
        default_image='gcr.io/my-tfx:latest')
    built_spec = spec_builder.build()

    expected_spec = test_utils.get_proto_from_test_data(
        'expected_two_step_pipeline_with_cache_enabled.pbtxt',
        pipeline_pb2.PipelineSpec())
    self.assertProtoEquals(expected_spec, built_spec)
예제 #18
0
  def testPipelineWithExitHandler(self):
    """Compares the two-step pipeline spec with an exit handler to its golden file."""
    two_step = test_utils.two_step_pipeline()
    # The exit handler takes the final status as its only parameter.
    final_status = decorators.FinalStatusStr()
    handler = test_utils.dummy_exit_handler(param1=final_status)

    spec_builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=two_step,
        default_image='gcr.io/my-tfx:latest',
        exit_handler=handler)
    built_spec = spec_builder.build()

    expected_spec = test_utils.get_proto_from_test_data(
        'expected_two_step_pipeline_with_exit_handler.pbtxt',
        pipeline_pb2.PipelineSpec())
    self.assertProtoEquals(expected_spec, built_spec)