def testBuildTask(self):
  """BigQueryExampleGen with caching enabled compiles to the golden specs."""
  query = 'SELECT * FROM TABLE'
  bq_example_gen = big_query_example_gen_component.BigQueryExampleGen(
      query=query)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=bq_example_gen,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config,
      component_defs=component_defs,
      enable_cache=True)
  step_spec = self._sole(builder.build())
  component_def = self._sole(component_defs)
  # Compare each emitted proto against its golden testdata file.
  for filename, empty_proto, actual in (
      ('expected_bq_example_gen_component.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_bq_example_gen_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_bq_example_gen_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
def testBuildContainerTask2(self):
  """Function-based dummy producer compiles to the container-spec goldens."""
  task = test_utils.dummy_producer_component(
      output1=channel_utils.as_channel([standard_artifacts.Model()]),
      param1='value1',
  )
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=task,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config,
      component_defs=component_defs)
  step_spec = self._sole(builder.build())
  component_def = self._sole(component_defs)
  # Same golden files as in testBuildContainerTask.
  for filename, empty_proto, actual in (
      ('expected_dummy_container_spec_component.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_dummy_container_spec_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_dummy_container_spec_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
def testBuildFileBasedExampleGenWithInputConfig(self):
  """ImportExampleGen with explicit train/eval splits compiles correctly."""
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='train', pattern='*train.tfr'),
      example_gen_pb2.Input.Split(name='eval', pattern='*test.tfr')
  ])
  example_gen = components.ImportExampleGen(
      input_base='path/to/data/root', input_config=input_config)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=example_gen,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config,
      component_defs=component_defs)
  step_spec = self._sole(builder.build())
  component_def = self._sole(component_defs)
  # Compare each emitted proto against its golden testdata file.
  for filename, empty_proto, actual in (
      ('expected_import_example_gen_component.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_import_example_gen_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_import_example_gen_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
def testBuildLatestArtifactResolverSucceed(self):
  """A LatestArtifactsResolver node compiles to the golden resolver specs."""
  latest_model_resolver = resolver.Resolver(
      instance_name='my_resolver',
      strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      examples=channel.Channel(type=standard_artifacts.Examples))
  test_pipeline_info = data_types.PipelineInfo(
      pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=latest_model_resolver,
      deployment_config=deployment_config,
      pipeline_info=test_pipeline_info,
      component_defs=component_defs)
  step_spec = self._sole(builder.build())
  component_def = self._sole(component_defs)
  # Compare each emitted proto against its golden testdata file.
  for filename, empty_proto, actual in (
      ('expected_latest_artifact_resolver_component.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_latest_artifact_resolver_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_latest_artifact_resolver_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
def testBuildImporter(self):
  """An Importer node with properties and custom properties compiles."""
  importer_properties = {
      'split_names': '["train", "eval"]',
  }
  importer_custom_properties = {
      'str_custom_property': 'abc',
      'int_custom_property': 123,
  }
  impt = importer.Importer(
      instance_name='my_importer',
      source_uri='m/y/u/r/i',
      properties=importer_properties,
      custom_properties=importer_custom_properties,
      artifact_type=standard_artifacts.Examples)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=impt,
      deployment_config=deployment_config,
      component_defs=component_defs)
  step_spec = self._sole(builder.build())
  component_def = self._sole(component_defs)
  # Compare each emitted proto against its golden testdata file.
  for filename, empty_proto, actual in (
      ('expected_importer_component.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_importer_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_importer_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
def testBuildExitHandler(self):
  """A component marked as exit handler compiles to the exit-handler goldens."""
  task = test_utils.dummy_producer_component(
      param1=decorators.FinalStatusStr('value1'),
  )
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=task,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config,
      component_defs=component_defs,
      is_exit_handler=True)
  step_spec = self._sole(builder.build())
  component_def = self._sole(component_defs)
  # Compare each emitted proto against its golden testdata file.
  for filename, empty_proto, actual in (
      ('expected_dummy_exit_handler_component.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_dummy_exit_handler_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_dummy_exit_handler_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
def testBuildFileBasedExampleGen(self):
  """CsvExampleGen with custom image cmds and Beam args compiles correctly."""
  beam_pipeline_args = ['runner=DataflowRunner']
  example_gen = components.CsvExampleGen(input_base='path/to/data/root')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=example_gen,
      image='gcr.io/tensorflow/tfx:latest',
      image_cmds=_TEST_CMDS,
      beam_pipeline_args=beam_pipeline_args,
      deployment_config=deployment_config,
      component_defs=component_defs)
  step_spec = self._sole(builder.build())
  component_def = self._sole(component_defs)
  # Compare each emitted proto against its golden testdata file.
  for filename, empty_proto, actual in (
      ('expected_csv_example_gen_component.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_csv_example_gen_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_csv_example_gen_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
def testBuildImporterWithRuntimeParam(self):
  """An Importer whose source_uri is a RuntimeParameter is collected by pc."""
  param = data_types.RuntimeParameter(name='runtime_flag', ptype=str)
  impt = importer.Importer(
      source_uri=param,
      artifact_type=standard_artifacts.Examples).with_id('my_importer')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  # Build inside a ParameterContext so the runtime parameter is recorded.
  with parameter_utils.ParameterContext() as pc:
    builder = step_builder.StepBuilder(
        node=impt,
        deployment_config=deployment_config,
        component_defs=component_defs)
    step_spec = self._sole(builder.build())
    component_def = self._sole(component_defs)
  # Compare each emitted proto against its golden testdata file.
  for filename, empty_proto, actual in (
      ('expected_importer_component_with_runtime_param.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_importer_task_with_runtime_param.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_importer_executor_with_runtime_param.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
  self.assertListEqual([param], pc.parameters)
def testBuildLatestBlessedModelResolverSucceed(self):
  """A LatestBlessedModelResolver node compiles into two task specs."""
  latest_blessed_resolver = components.ResolverNode(
      instance_name='my_resolver2',
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
  test_pipeline_info = data_types.PipelineInfo(
      pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  builder = step_builder.StepBuilder(
      node=latest_blessed_resolver,
      deployment_config=deployment_config,
      pipeline_info=test_pipeline_info)
  step_specs = builder.build()
  # Expected protos are parsed from module-level textproto constants.
  expected_task_1 = text_format.Parse(
      _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_1,
      pipeline_pb2.PipelineTaskSpec())
  expected_task_2 = text_format.Parse(
      _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_2,
      pipeline_pb2.PipelineTaskSpec())
  expected_executor = text_format.Parse(
      _EXPECTED_LATEST_BLESSED_MODEL_EXECUTOR,
      pipeline_pb2.PipelineDeploymentConfig())
  self.assertProtoEquals(expected_task_1, step_specs[0])
  self.assertProtoEquals(expected_task_2, step_specs[1])
  self.assertProtoEquals(expected_executor, deployment_config)
def testBuildLatestBlessedModelResolverSucceed(self):
  """Resolver with LatestBlessedModelResolver yields two keyed task specs."""
  latest_blessed_resolver = resolver.Resolver(
      instance_name='my_resolver2',
      strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
  test_pipeline_info = data_types.PipelineInfo(
      pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=latest_blessed_resolver,
      deployment_config=deployment_config,
      pipeline_info=test_pipeline_info,
      component_defs=component_defs)
  step_specs = builder.build()
  # The single Resolver node is compiled into two internal resolver steps.
  blessing_resolver_id = 'Resolver.my_resolver2-model-blessing-resolver'
  model_resolver_id = 'Resolver.my_resolver2-model-resolver'
  self.assertSameElements(
      step_specs.keys(), [blessing_resolver_id, model_resolver_id])
  # Check each internal step's component and task against its golden file.
  for resolver_id, component_file, task_file in (
      (blessing_resolver_id,
       'expected_latest_blessed_model_resolver_component_1.pbtxt',
       'expected_latest_blessed_model_resolver_task_1.pbtxt'),
      (model_resolver_id,
       'expected_latest_blessed_model_resolver_component_2.pbtxt',
       'expected_latest_blessed_model_resolver_task_2.pbtxt')):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(component_file,
                                            pipeline_pb2.ComponentSpec()),
        component_defs[resolver_id])
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(task_file,
                                            pipeline_pb2.PipelineTaskSpec()),
        step_specs[resolver_id])
  self.assertProtoEquals(
      test_utils.get_proto_from_test_data(
          'expected_latest_blessed_model_resolver_executor.pbtxt',
          pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
def testBuildDummyConsumerWithCondition(self):
  """Consumer under nested Cond blocks compiles with merged predicates."""
  producer_task_1 = test_utils.dummy_producer_component(
      output1=channel_utils.as_channel([standard_artifacts.Model()]),
      param1='value1',
  ).with_id('producer_task_1')
  producer_task_2 = test_utils.dummy_producer_component_2(
      output1=channel_utils.as_channel([standard_artifacts.Model()]),
      param1='value2',
  ).with_id('producer_task_2')
  # This test tests two things:
  # 1. Nested conditions. The condition string of consumer_task should contain
  #    both predicates.
  # 2. Implicit channels. consumer_task only takes producer_task_1's output.
  #    But producer_task_2 is used in condition, hence producer_task_2 should
  #    be added to the dependency of consumer_task.
  # See testdata for detail.
  with conditional.Cond(
      producer_task_1.outputs['output1'].future()[0].uri != 'uri'):
    with conditional.Cond(producer_task_2.outputs['output1'].future()
                          [0].property('property') == 'value1'):
      # consumer_task is created inside both Cond scopes so it inherits both
      # predicates.
      consumer_task = test_utils.dummy_consumer_component(
          input1=producer_task_1.outputs['output1'],
          param1=1,
      )
  # Need to construct a pipeline to set producer_component_id.
  unused_pipeline = tfx.dsl.Pipeline(
      pipeline_name='pipeline-with-condition',
      pipeline_root='',
      components=[producer_task_1, producer_task_2, consumer_task],
  )
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  my_builder = step_builder.StepBuilder(
      node=consumer_task,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config,
      component_defs=component_defs)
  actual_step_spec = self._sole(my_builder.build())
  actual_component_def = self._sole(component_defs)
  # Compare each emitted proto against its golden testdata file.
  self.assertProtoEquals(
      test_utils.get_proto_from_test_data(
          'expected_dummy_consumer_with_condition_component.pbtxt',
          pipeline_pb2.ComponentSpec()), actual_component_def)
  self.assertProtoEquals(
      test_utils.get_proto_from_test_data(
          'expected_dummy_consumer_with_condition_task.pbtxt',
          pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
  self.assertProtoEquals(
      test_utils.get_proto_from_test_data(
          'expected_dummy_consumer_with_condition_executor.pbtxt',
          pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
def build(self) -> pipeline_pb2.PipelineSpec:
  """Build a pipeline PipelineSpec.

  Compiles every component of the pipeline into task specs, component
  definitions and a deployment config, then assembles them into a single
  PipelineSpec proto with runtime parameters attached to the root.

  Returns:
    The assembled PipelineSpec proto.
  """
  _check_name(self._pipeline_info.pipeline_name)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  pipeline_info = pipeline_pb2.PipelineInfo(
      name=self._pipeline_info.pipeline_name)
  tasks = {}
  component_defs = {}
  # Map from (producer component id, output key) to (new producer component
  # id, output key)
  channel_redirect_map = {}
  # NOTE(review): pc.parameters is read after the loop — presumably the
  # ParameterContext records RuntimeParameters referenced during step
  # building; confirm against parameter_utils.
  with parameter_utils.ParameterContext() as pc:
    for component in self._pipeline.components:
      # Here the topological order of components is required.
      # If a channel redirection is needed, redirect mapping is expected to be
      # available because the upstream node (which is the cause for
      # redirecting) is processed before the downstream consumer nodes.
      # StepBuilder mutates deployment_config, component_defs and
      # channel_redirect_map in place as a side effect of build().
      built_tasks = step_builder.StepBuilder(
          node=component,
          deployment_config=deployment_config,
          component_defs=component_defs,
          image=self._default_image,
          image_cmds=self._default_commands,
          beam_pipeline_args=self._pipeline.beam_pipeline_args,
          enable_cache=self._pipeline.enable_cache,
          pipeline_info=self._pipeline_info,
          channel_redirect_map=channel_redirect_map).build()
      tasks.update(built_tasks)
  result = pipeline_pb2.PipelineSpec(pipeline_info=pipeline_info)
  # deployment_spec is a Struct field, so the proto is embedded as a dict.
  result.deployment_spec.update(
      json_format.MessageToDict(deployment_config))
  for name, component_def in component_defs.items():
    result.components[name].CopyFrom(component_def)
  for name, task_spec in tasks.items():
    result.root.dag.tasks[name].CopyFrom(task_spec)
  # Attach runtime parameter to root's input parameter
  for param in pc.parameters:
    result.root.input_definitions.parameters[param.name].CopyFrom(
        compiler_utils.build_parameter_type_spec(param))
  return result
def testBuildContainerTask(self):
  """Class-based dummy producer compiles to the container-spec goldens."""
  task = test_utils.DummyProducerComponent(
      output1=channel_utils.as_channel([standard_artifacts.Model()]),
      param1='value1',
  )
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  builder = step_builder.StepBuilder(
      node=task,
      # The default image has no effect for this node.
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config)
  step_spec = self._sole(builder.build())
  # Compare each emitted proto against its golden testdata file.
  for filename, empty_proto, actual in (
      ('expected_dummy_container_spec_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_dummy_container_spec_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config)):
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(filename, empty_proto), actual)
def build(self) -> pipeline_pb2.PipelineSpec:
  """Build a pipeline PipelineSpec.

  Compiles every component into task specs and a deployment config, then
  assembles a PipelineSpec carrying the tasks and the runtime parameter
  specs collected while building.

  Returns:
    The assembled PipelineSpec proto.
  """
  _check_name(self._pipeline_info.pipeline_name)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  pipeline_info = pipeline_pb2.PipelineInfo(
      name=self._pipeline_info.pipeline_name)
  tasks = []
  # Map from (producer component id, output key) to (new producer component
  # id, output key)
  channel_redirect_map = {}
  # NOTE(review): pc.parameters is read after the loop — presumably the
  # ParameterContext records RuntimeParameters referenced during step
  # building; confirm against parameter_utils.
  with parameter_utils.ParameterContext() as pc:
    for component in self._pipeline.components:
      # Here the topological order of components is required.
      # If a channel redirection is needed, redirect mapping is expected to be
      # available because the upstream node (which is the cause for
      # redirecting) is processed before the downstream consumer nodes.
      # StepBuilder mutates deployment_config and channel_redirect_map in
      # place as a side effect of build().
      built_tasks = step_builder.StepBuilder(
          node=component,
          deployment_config=deployment_config,
          image=self._default_image,
          image_cmds=self._default_commands,
          beam_pipeline_args=self._pipeline.beam_pipeline_args,
          enable_cache=self._pipeline.enable_cache,
          pipeline_info=self._pipeline_info,
          channel_redirect_map=channel_redirect_map).build()
      tasks.extend(built_tasks)
  result = pipeline_pb2.PipelineSpec(
      pipeline_info=pipeline_info,
      tasks=tasks,
      runtime_parameters=compiler_utils.build_runtime_parameter_spec(
          pc.parameters))
  # deployment_spec is a Struct field, so the proto is embedded as a dict.
  result.deployment_spec.update(
      json_format.MessageToDict(deployment_config))
  return result
def build(self) -> pipeline_pb2.PipelineSpec:
  """Build a pipeline PipelineSpec.

  Compiles every TFX component into task specs and component definitions.
  When an exit handler is configured, all TFX tasks are nested under a
  sub-DAG component named by `compiler_utils.TFX_DAG_NAME` and the exit
  handler is compiled as a sibling root task; otherwise tasks go directly
  under the root DAG.

  Returns:
    The assembled PipelineSpec proto.
  """
  _check_name(self._pipeline_info.pipeline_name)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  pipeline_info = pipeline_pb2.PipelineInfo(
      name=self._pipeline_info.pipeline_name)
  tfx_tasks = {}
  component_defs = {}
  # Map from (producer component id, output key) to (new producer component
  # id, output key)
  channel_redirect_map = {}
  with parameter_utils.ParameterContext() as pc:
    for component in self._pipeline.components:
      if self._exit_handler and component.id == compiler_utils.TFX_DAG_NAME:
        # A user component must not collide with the reserved sub-DAG name;
        # rename it with a random suffix and warn.
        component.with_id(component.id + _generate_component_name_suffix())
        # BUG FIX: the two implicitly-concatenated literals previously read
        # "...pipeline withexit handler..." — a space was missing.
        logging.warning(
            '_tfx_dag is system reserved name for pipeline with '
            'exit handler, added suffix to your component name: %s',
            component.id)
      # Here the topological order of components is required.
      # If a channel redirection is needed, redirect mapping is expected to be
      # available because the upstream node (which is the cause for
      # redirecting) is processed before the downstream consumer nodes.
      built_tasks = step_builder.StepBuilder(
          node=component,
          deployment_config=deployment_config,
          component_defs=component_defs,
          image=self._default_image,
          image_cmds=self._default_commands,
          beam_pipeline_args=self._pipeline.beam_pipeline_args,
          enable_cache=self._pipeline.enable_cache,
          pipeline_info=self._pipeline_info,
          channel_redirect_map=channel_redirect_map).build()
      tfx_tasks.update(built_tasks)

  result = pipeline_pb2.PipelineSpec(pipeline_info=pipeline_info)

  # if exit handler is defined, put all the TFX tasks under tfx_dag,
  # exit handler is a separate component triggered by tfx_dag.
  if self._exit_handler:
    for name, task_spec in tfx_tasks.items():
      result.components[compiler_utils.TFX_DAG_NAME].dag.tasks[
          name].CopyFrom(task_spec)
    # construct root with exit handler
    exit_handler_task = step_builder.StepBuilder(
        node=self._exit_handler,
        deployment_config=deployment_config,
        component_defs=component_defs,
        image=self._default_image,
        image_cmds=self._default_commands,
        beam_pipeline_args=self._pipeline.beam_pipeline_args,
        enable_cache=False,  # the exit handler step is never cached
        pipeline_info=self._pipeline_info,
        channel_redirect_map=channel_redirect_map,
        is_exit_handler=True).build()
    result.root.dag.tasks[
        compiler_utils.
        TFX_DAG_NAME].component_ref.name = compiler_utils.TFX_DAG_NAME
    result.root.dag.tasks[
        compiler_utils.
        TFX_DAG_NAME].task_info.name = compiler_utils.TFX_DAG_NAME
    result.root.dag.tasks[self._exit_handler.id].CopyFrom(
        exit_handler_task[self._exit_handler.id])
  else:
    for name, task_spec in tfx_tasks.items():
      result.root.dag.tasks[name].CopyFrom(task_spec)

  # deployment_spec is a Struct field, so the proto is embedded as a dict.
  result.deployment_spec.update(
      json_format.MessageToDict(deployment_config))
  for name, component_def in component_defs.items():
    result.components[name].CopyFrom(component_def)

  # Attach runtime parameter to root's input parameter
  for param in pc.parameters:
    result.root.input_definitions.parameters[param.name].CopyFrom(
        compiler_utils.build_parameter_type_spec(param))
  return result