def testBuildLatestBlessedModelResolverSucceed(self):
  """Compares the built specs for a LatestBlessedModelResolver to expectations.

  The first two task specs returned by the builder and the resulting
  deployment config are compared against module-level text protos.
  """
  resolver_node = components.ResolverNode(
      instance_name='my_resolver2',
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()

  builder = step_builder.StepBuilder(
      node=resolver_node,
      deployment_config=deployment_config,
      pipeline_info=pipeline_info)
  built_specs = builder.build()

  # Task specs are checked positionally, in the order the builder returns them.
  expected_task_texts = (
      _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_1,
      _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_2,
  )
  for position, expected_text in enumerate(expected_task_texts):
    self.assertProtoEquals(
        text_format.Parse(expected_text, pipeline_pb2.PipelineTaskSpec()),
        built_specs[position])

  self.assertProtoEquals(
      text_format.Parse(_EXPECTED_LATEST_BLESSED_MODEL_EXECUTOR,
                        pipeline_pb2.PipelineDeploymentConfig()),
      deployment_config)
def testBuildLatestBlessedModelResolverSucceed(self):
  """Compares the built specs for a LatestBlessedModelResolver to test data.

  The builder is expected to return exactly two task specs, keyed by the
  model-blessing resolver id and the model resolver id. For each id the
  component definition and the task spec are checked, then the deployment
  config.
  """
  resolver_node = resolver.Resolver(
      instance_name='my_resolver2',
      strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}

  builder = step_builder.StepBuilder(
      node=resolver_node,
      deployment_config=deployment_config,
      pipeline_info=pipeline_info,
      component_defs=component_defs)
  built_specs = builder.build()

  blessing_resolver_id = 'Resolver.my_resolver2-model-blessing-resolver'
  model_resolver_id = 'Resolver.my_resolver2-model-resolver'
  self.assertSameElements(
      built_specs.keys(), [blessing_resolver_id, model_resolver_id])

  # (task id, expected component definition file, expected task spec file)
  per_task_expectations = (
      (blessing_resolver_id,
       'expected_latest_blessed_model_resolver_component_1.pbtxt',
       'expected_latest_blessed_model_resolver_task_1.pbtxt'),
      (model_resolver_id,
       'expected_latest_blessed_model_resolver_component_2.pbtxt',
       'expected_latest_blessed_model_resolver_task_2.pbtxt'),
  )
  for task_id, component_file, task_file in per_task_expectations:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(component_file,
                                            pipeline_pb2.ComponentSpec()),
        component_defs[task_id])
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(task_file,
                                            pipeline_pb2.PipelineTaskSpec()),
        built_specs[task_id])

  self.assertProtoEquals(
      test_utils.get_proto_from_test_data(
          'expected_latest_blessed_model_resolver_executor.pbtxt',
          pipeline_pb2.PipelineDeploymentConfig()),
      deployment_config)
def _build_resolver_for_latest_model_blessing(
    self, model_blessing_channel_key: str) -> pipeline_pb2.PipelineTaskSpec:
  """Builds the resolver spec for latest valid ModelBlessing artifact.

  Besides returning the task spec, this stores the matching resolver executor
  spec in `self._deployment_config` under the task's executor label.

  Args:
    model_blessing_channel_key: Key of the ModelBlessing channel in
      `self._outputs`. It is reused as the task's output artifact key and as
      the key of the artifact query.

  Returns:
    The PipelineTaskSpec for the ModelBlessing resolver task.
  """
  task_name = '{}{}'.format(self._name, _MODEL_BLESSING_RESOLVER_SUFFIX)
  executor_label = _EXECUTOR_LABEL_PATTERN.format(task_name)

  # Task info and the single ModelBlessing output.
  task_spec = pipeline_pb2.PipelineTaskSpec()
  task_spec.task_info.CopyFrom(pipeline_pb2.PipelineTaskInfo(name=task_name))
  task_spec.executor_label = executor_label
  blessing_output_spec = compiler_utils.build_output_artifact_spec(
      self._outputs[model_blessing_channel_key])
  task_spec.outputs.artifacts[model_blessing_channel_key].CopyFrom(
      blessing_output_spec)

  # Query filter: ModelBlessing artifacts in LIVE state carrying the blessed
  # custom property value.
  blessing_type_title = compiler_utils.get_artifact_title(
      standard_artifacts.ModelBlessing)
  live_state_name = metadata_store_pb2.Artifact.State.Name(
      metadata_store_pb2.Artifact.LIVE)
  blessing_filter = ("artifact_type='{type}' and state={state}"
                     " and custom_properties['{key}']='{value}'").format(
                         type=blessing_type_title,
                         state=live_state_name,
                         key=constants.ARTIFACT_PROPERTY_BLESSED_KEY,
                         value=constants.BLESSED_VALUE)
  resolver_spec = ResolverSpec(
      output_artifact_queries={
          model_blessing_channel_key:
              ResolverSpec.ArtifactQuerySpec(filter=blessing_filter),
      })

  executor_spec = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
  executor_spec.resolver.CopyFrom(resolver_spec)
  self._deployment_config.executors[executor_label].CopyFrom(executor_spec)
  return task_spec
def testBuildContainerTask2(self):
  """Builds `dummy_producer_component` and checks it against test data.

  The expected files are the same ones used by testBuildContainerTask.
  """
  producer = test_utils.dummy_producer_component(
      output1=channel_utils.as_channel([standard_artifacts.Model()]),
      param1='value1',
  )
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=producer,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config,
      component_defs=component_defs)
  built_task_spec = self._sole(builder.build())
  built_component_def = self._sole(component_defs)

  # Same as in testBuildContainerTask.
  # (expected file, message instance handed to the loader, built proto)
  expectations = (
      ('expected_dummy_container_spec_component.pbtxt',
       pipeline_pb2.ComponentSpec(), built_component_def),
      ('expected_dummy_container_spec_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), built_task_spec),
      ('expected_dummy_container_spec_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
  )
  for expected_file, message, built_proto in expectations:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(expected_file, message),
        built_proto)
def testBuildTask(self):
  """Builds a BigQueryExampleGen step with caching on and checks the protos."""
  example_gen = big_query_example_gen_component.BigQueryExampleGen(
      query='SELECT * FROM TABLE')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=example_gen,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config,
      component_defs=component_defs,
      enable_cache=True)
  built_task_spec = self._sole(builder.build())
  built_component_def = self._sole(component_defs)

  # (expected file, message instance handed to the loader, built proto)
  expectations = (
      ('expected_bq_example_gen_component.pbtxt',
       pipeline_pb2.ComponentSpec(), built_component_def),
      ('expected_bq_example_gen_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), built_task_spec),
      ('expected_bq_example_gen_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
  )
  for expected_file, message, built_proto in expectations:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(expected_file, message),
        built_proto)
def testBuildFileBasedExampleGenWithInputConfig(self):
  """Builds an ImportExampleGen with an explicit input config and checks it."""
  train_split = example_gen_pb2.Input.Split(name='train', pattern='*train.tfr')
  eval_split = example_gen_pb2.Input.Split(name='eval', pattern='*test.tfr')
  split_config = example_gen_pb2.Input(splits=[train_split, eval_split])
  import_example_gen = components.ImportExampleGen(
      input_base='path/to/data/root', input_config=split_config)

  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=import_example_gen,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config,
      component_defs=component_defs)
  built_task_spec = self._sole(builder.build())
  built_component_def = self._sole(component_defs)

  # (expected file, message instance handed to the loader, built proto)
  expectations = (
      ('expected_import_example_gen_component.pbtxt',
       pipeline_pb2.ComponentSpec(), built_component_def),
      ('expected_import_example_gen_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), built_task_spec),
      ('expected_import_example_gen_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
  )
  for expected_file, message, built_proto in expectations:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(expected_file, message),
        built_proto)
def testBuildLatestArtifactResolverSucceed(self):
  """Builds a LatestArtifactsResolver node and checks the resulting protos."""
  resolver_node = resolver.Resolver(
      instance_name='my_resolver',
      strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      examples=channel.Channel(type=standard_artifacts.Examples))
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')

  builder = step_builder.StepBuilder(
      node=resolver_node,
      deployment_config=deployment_config,
      pipeline_info=pipeline_info,
      component_defs=component_defs)
  built_task_spec = self._sole(builder.build())
  built_component_def = self._sole(component_defs)

  # (expected file, message instance handed to the loader, built proto)
  expectations = (
      ('expected_latest_artifact_resolver_component.pbtxt',
       pipeline_pb2.ComponentSpec(), built_component_def),
      ('expected_latest_artifact_resolver_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), built_task_spec),
      ('expected_latest_artifact_resolver_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
  )
  for expected_file, message, built_proto in expectations:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(expected_file, message),
        built_proto)
def testBuildImporter(self):
  """Builds an Importer node with properties and checks the resulting protos."""
  artifact_properties = {
      'split_names': '["train", "eval"]',
  }
  artifact_custom_properties = {
      'str_custom_property': 'abc',
      'int_custom_property': 123,
  }
  importer_node = importer.Importer(
      instance_name='my_importer',
      source_uri='m/y/u/r/i',
      properties=artifact_properties,
      custom_properties=artifact_custom_properties,
      artifact_type=standard_artifacts.Examples)

  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=importer_node,
      deployment_config=deployment_config,
      component_defs=component_defs)
  built_task_spec = self._sole(builder.build())
  built_component_def = self._sole(component_defs)

  # (expected file, message instance handed to the loader, built proto)
  expectations = (
      ('expected_importer_component.pbtxt',
       pipeline_pb2.ComponentSpec(), built_component_def),
      ('expected_importer_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), built_task_spec),
      ('expected_importer_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
  )
  for expected_file, message, built_proto in expectations:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(expected_file, message),
        built_proto)
def testBuildFileBasedExampleGen(self):
  """Builds a CsvExampleGen with image commands and Beam args and checks it."""
  csv_example_gen = components.CsvExampleGen(input_base='path/to/data/root')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  builder = step_builder.StepBuilder(
      node=csv_example_gen,
      image='gcr.io/tensorflow/tfx:latest',
      image_cmds=_TEST_CMDS,
      beam_pipeline_args=['runner=DataflowRunner'],
      deployment_config=deployment_config,
      component_defs=component_defs)
  built_task_spec = self._sole(builder.build())
  built_component_def = self._sole(component_defs)

  # (expected file, message instance handed to the loader, built proto)
  expectations = (
      ('expected_csv_example_gen_component.pbtxt',
       pipeline_pb2.ComponentSpec(), built_component_def),
      ('expected_csv_example_gen_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), built_task_spec),
      ('expected_csv_example_gen_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
  )
  for expected_file, message, built_proto in expectations:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(expected_file, message),
        built_proto)
def _build_latest_artifact_resolver(
    self) -> List[pipeline_pb2.PipelineTaskSpec]:
  """Builds a resolver spec for a latest artifact resolver.

  One artifact query is generated per input channel, filtering on the
  channel's artifact type and the LIVE state. The resolver executor spec is
  stored in `self._deployment_config` under this node's executor label.

  Returns:
    A single-element list holding the PipelineTaskSpec of the resolver task.

  Raises:
    ValueError: when the resolver config is a dict whose
      desired_num_of_artifacts is greater than 1. 1 is the only supported
      value currently.
  """
  # Fetch the init kwargs for the resolver and reject unsupported settings.
  resolver_config = self._exec_properties[resolver.RESOLVER_CONFIG]
  if isinstance(resolver_config, dict):
    requested_count = resolver_config.get('desired_num_of_artifacts', 0)
    if requested_count > 1:
      raise ValueError(
          'Only desired_num_of_artifacts=1 is supported currently.'
          ' Got {}'.format(requested_count))

  executor_label = _EXECUTOR_LABEL_PATTERN.format(self._name)
  resolver_task = pipeline_pb2.PipelineTaskSpec()
  resolver_task.task_info.CopyFrom(
      pipeline_pb2.PipelineTaskInfo(name=self._name))
  resolver_task.executor_label = executor_label

  # Outputs of the task. Currently, we're working under the assumption that
  # for tasks (those generated by BaseComponent), each channel contains a
  # single artifact.
  for output_key, output_channel in self._outputs.items():
    resolver_task.outputs.artifacts[output_key].CopyFrom(
        compiler_utils.build_output_artifact_spec(output_channel))

  # Input parameters of the task.
  parameter_specs = compiler_utils.build_input_parameter_spec(
      self._exec_properties)
  for parameter_key, parameter_spec in parameter_specs.items():
    resolver_task.inputs.parameters[parameter_key].CopyFrom(parameter_spec)

  # Build the artifact query for each channel in the input dict.
  live_state_name = metadata_store_pb2.Artifact.State.Name(
      metadata_store_pb2.Artifact.LIVE)
  artifact_queries = {
      input_key: ResolverSpec.ArtifactQuerySpec(
          filter=("artifact_type='{type}' and state={state}").format(
              type=compiler_utils.get_artifact_title(input_channel.type),
              state=live_state_name))
      for input_key, input_channel in self._inputs.items()
  }

  executor_spec = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
  executor_spec.resolver.CopyFrom(
      ResolverSpec(output_artifact_queries=artifact_queries))
  self._deployment_config.executors[executor_label].CopyFrom(executor_spec)
  return [resolver_task]
def _build_resolver_for_latest_blessed_model(
    self, model_channel_key: str, model_blessing_resolver_name: str,
    model_blessing_channel_key: str) -> pipeline_pb2.PipelineTaskSpec:
  """Builds the resolver spec for latest blessed Model artifact.

  Registers a component definition in `self._component_defs` and a resolver
  executor spec in `self._deployment_config`, then returns the task spec that
  references both.

  Args:
    model_channel_key: Key of the Model channel in `self._outputs`; used as
      the output artifact key and as the key of the artifact query.
    model_blessing_resolver_name: Name of the task that produces the
      ModelBlessing artifact consumed by this task.
    model_blessing_channel_key: Key of the ModelBlessing channel in
      `self._outputs`; also the producer task's output artifact key.

  Returns:
    The PipelineTaskSpec for the Model resolver task.
  """
  task_name = '{}{}'.format(self._name, _MODEL_RESOLVER_SUFFIX)
  executor_label = _EXECUTOR_LABEL_PATTERN.format(task_name)

  # Component def: one ModelBlessing input, one Model output.
  component_spec = pipeline_pb2.ComponentSpec()
  component_spec.executor_label = executor_label
  blessing_input_definition = compiler_utils.build_input_artifact_spec(
      self._outputs[model_blessing_channel_key])
  component_spec.input_definitions.artifacts[
      _MODEL_RESOLVER_INPUT_KEY].CopyFrom(blessing_input_definition)
  model_output_definition = compiler_utils.build_output_artifact_spec(
      self._outputs[model_channel_key])
  component_spec.output_definitions.artifacts[model_channel_key].CopyFrom(
      model_output_definition)
  self._component_defs[task_name] = component_spec

  # Task spec: the input is wired to the ModelBlessing resolver's output.
  resolver_task = pipeline_pb2.PipelineTaskSpec()
  resolver_task.task_info.name = task_name
  resolver_task.component_ref.name = task_name
  blessing_input = pipeline_pb2.TaskInputsSpec.InputArtifactSpec()
  blessing_input.task_output_artifact.producer_task = (
      model_blessing_resolver_name)
  blessing_input.task_output_artifact.output_artifact_key = (
      model_blessing_channel_key)
  resolver_task.inputs.artifacts[_MODEL_RESOLVER_INPUT_KEY].CopyFrom(
      blessing_input)

  # Resolver executor spec. The doubled braces in the template emit a literal
  # `{$.inputs...}` placeholder into the filter string.
  model_type_title = compiler_utils.get_artifact_title(
      standard_artifacts.Model)
  live_state_name = metadata_store_pb2.Artifact.State.Name(
      metadata_store_pb2.Artifact.LIVE)
  model_filter = (
      'artifact_type="{type}" and '
      'state={state} and name={{$.inputs.artifacts["{input_key}"]'
      '.metadata.{property_key}.number_value}}').format(
          type=model_type_title,
          state=live_state_name,
          input_key=_MODEL_RESOLVER_INPUT_KEY,
          property_key=constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY)
  resolver_spec = ResolverSpec(
      output_artifact_queries={
          model_channel_key:
              ResolverSpec.ArtifactQuerySpec(filter=model_filter),
      })

  executor_spec = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
  executor_spec.resolver.CopyFrom(resolver_spec)
  self._deployment_config.executors[executor_label].CopyFrom(executor_spec)
  return resolver_task
def _build_resolver_for_latest_blessed_model(
    self, model_channel_key: str, model_blessing_resolver_name: str,
    model_blessing_channel_key: str) -> pipeline_pb2.PipelineTaskSpec:
  """Builds the resolver spec for latest blessed Model artifact.

  Besides returning the task spec, this stores the matching resolver executor
  spec in `self._deployment_config` under the task's executor label.

  Args:
    model_channel_key: Key of the Model channel in `self._outputs`; used as
      the output artifact key and as the key of the artifact query.
    model_blessing_resolver_name: Name of the task whose output is consumed
      as this task's input.
    model_blessing_channel_key: Output artifact key of the producer task.

  Returns:
    The PipelineTaskSpec for the Model resolver task.
  """
  task_name = '{}{}'.format(self._name, _MODEL_RESOLVER_SUFFIX)
  executor_label = _EXECUTOR_LABEL_PATTERN.format(task_name)

  # Task info.
  resolver_task = pipeline_pb2.PipelineTaskSpec()
  resolver_task.task_info.CopyFrom(
      pipeline_pb2.PipelineTaskInfo(name=task_name))
  resolver_task.executor_label = executor_label

  # Input: the output from model_blessing_resolver will be used as the input.
  blessing_input = pipeline_pb2.TaskInputsSpec.InputArtifactSpec(
      producer_task=model_blessing_resolver_name,
      output_artifact_key=model_blessing_channel_key)
  resolver_task.inputs.artifacts[_MODEL_RESOLVER_INPUT_KEY].CopyFrom(
      blessing_input)

  # Output: model_resolver has one output for the latest blessed model.
  model_output_spec = compiler_utils.build_output_artifact_spec(
      self._outputs[model_channel_key])
  resolver_task.outputs.artifacts[model_channel_key].CopyFrom(
      model_output_spec)

  # Resolver executor spec. The doubled braces in the template emit a literal
  # `{$.inputs...}` placeholder into the filter string.
  model_type_title = compiler_utils.get_artifact_title(
      standard_artifacts.Model)
  live_state_name = metadata_store_pb2.Artifact.State.Name(
      metadata_store_pb2.Artifact.LIVE)
  model_filter = (
      "artifact_type='{type}' and "
      "state={state} and name={{$.inputs.artifacts['{input_key}']"
      ".custom_properties['{property_key}']}}").format(
          type=model_type_title,
          state=live_state_name,
          input_key=_MODEL_RESOLVER_INPUT_KEY,
          property_key=constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY)
  resolver_spec = ResolverSpec(
      output_artifact_queries={
          model_channel_key:
              ResolverSpec.ArtifactQuerySpec(filter=model_filter),
      })

  executor_spec = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
  executor_spec.resolver.CopyFrom(resolver_spec)
  self._deployment_config.executors[executor_label].CopyFrom(executor_spec)
  return resolver_task
def testBuildContainerTask(self):
  """Builds `DummyProducerComponent` and checks task spec and executor."""
  producer = test_utils.DummyProducerComponent(
      output1=channel_utils.as_channel([standard_artifacts.Model()]),
      param1='value1',
  )
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  builder = step_builder.StepBuilder(
      node=producer,
      image='gcr.io/tensorflow/tfx:latest',  # Note this has no effect here.
      deployment_config=deployment_config)
  built_task_spec = self._sole(builder.build())

  # (expected file, message instance handed to the loader, built proto)
  expectations = (
      ('expected_dummy_container_spec_task.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), built_task_spec),
      ('expected_dummy_container_spec_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
  )
  for expected_file, message, built_proto in expectations:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(expected_file, message),
        built_proto)
def _build_resolver_for_latest_model_blessing(
    self, model_blessing_channel_key: str) -> pipeline_pb2.PipelineTaskSpec:
  """Builds the resolver spec for latest valid ModelBlessing artifact.

  Registers a component definition in `self._component_defs` and a resolver
  executor spec in `self._deployment_config`, then returns the task spec that
  references both.

  Args:
    model_blessing_channel_key: Key of the ModelBlessing channel in
      `self._outputs`; used as the output artifact key and as the key of the
      artifact query.

  Returns:
    The PipelineTaskSpec for the ModelBlessing resolver task.
  """
  task_name = '{}{}'.format(self._name, _MODEL_BLESSING_RESOLVER_SUFFIX)
  executor_label = _EXECUTOR_LABEL_PATTERN.format(task_name)

  # Component def: a single ModelBlessing output, no inputs.
  component_spec = pipeline_pb2.ComponentSpec()
  component_spec.executor_label = executor_label
  blessing_output_definition = compiler_utils.build_output_artifact_spec(
      self._outputs[model_blessing_channel_key])
  component_spec.output_definitions.artifacts[
      model_blessing_channel_key].CopyFrom(blessing_output_definition)
  self._component_defs[task_name] = component_spec

  # Task spec.
  resolver_task = pipeline_pb2.PipelineTaskSpec()
  resolver_task.task_info.name = task_name
  resolver_task.component_ref.name = task_name

  # Resolver executor spec for latest valid ModelBlessing.
  blessing_type_title = compiler_utils.get_artifact_title(
      standard_artifacts.ModelBlessing)
  live_state_name = metadata_store_pb2.Artifact.State.Name(
      metadata_store_pb2.Artifact.LIVE)
  blessing_filter = ('artifact_type="{type}" and state={state}'
                     ' and metadata.{key}.number_value={value}').format(
                         type=blessing_type_title,
                         state=live_state_name,
                         key=constants.ARTIFACT_PROPERTY_BLESSED_KEY,
                         value=constants.BLESSED_VALUE)
  resolver_spec = ResolverSpec(
      output_artifact_queries={
          model_blessing_channel_key:
              ResolverSpec.ArtifactQuerySpec(filter=blessing_filter),
      })

  executor_spec = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
  executor_spec.resolver.CopyFrom(resolver_spec)
  self._deployment_config.executors[executor_label].CopyFrom(executor_spec)
  return resolver_task
def build(self) -> Dict[str, pipeline_pb2.PipelineTaskSpec]:
  """Builds a pipeline PipelineTaskSpec given the node information.

  Each TFX node maps one task spec and usually one component definition and
  one executor spec. (with resolver node as an exception. See explanation
  in the Returns section).

   - Component definition includes interfaces of a node. For example, name
     and type information of inputs/outputs/execution_properties.
   - Task spec contains the topologies around the node. For example, the
     dependency nodes, where to read the inputs and exec_properties (from
     another task, from parent component or from a constant value). The task
     spec has the name of the component definition it references. It is
     possible that a task spec references an existing component definition
     that's built previously.
   - Executor spec encodes how the node is actually executed. For example,
     args to start a container, or query strings for resolvers. All executor
     spec will be packed into deployment config proto.

  During the build, all three parts mentioned above will be updated.

  Returns:
    A Dict mapping from node id to PipelineTaskSpec messages corresponding to
    the node. For most of the cases, the dict contains a single element.
    The only exception is when compiling latest blessed model resolver.
    One DSL node will be split to two resolver specs to reflect the
    two-phased query execution.

  Raises:
    NotImplementedError: When the node being built is an InfraValidator.
    ValueError: When a component definition with this node's name has
      already been registered in the component definitions dict.
  """
  # 1. Resolver tasks won't have input artifacts in the API proto. First we
  #    specialcase two resolver types we support.
  if isinstance(self._node, resolver.Resolver):
    return self._build_resolver_spec()

  # 2. Build component spec.
  component_def = pipeline_pb2.ComponentSpec()
  executor_label = _EXECUTOR_LABEL_PATTERN.format(self._name)
  component_def.executor_label = executor_label
  # Inputs
  for name, input_channel in self._inputs.items():
    input_artifact_spec = compiler_utils.build_input_artifact_spec(
        input_channel)
    component_def.input_definitions.artifacts[name].CopyFrom(
        input_artifact_spec)
  # Outputs
  for name, output_channel in self._outputs.items():
    # Currently, we're working under the assumption that for tasks
    # (those generated by BaseComponent), each channel contains a single
    # artifact.
    output_artifact_spec = compiler_utils.build_output_artifact_spec(
        output_channel)
    component_def.output_definitions.artifacts[name].CopyFrom(
        output_artifact_spec)
  # Exec properties
  for name, value in self._exec_properties.items():
    # value can be None for unprovided optional exec properties.
    if value is None:
      continue
    parameter_type_spec = compiler_utils.build_parameter_type_spec(value)
    component_def.input_definitions.parameters[name].CopyFrom(
        parameter_type_spec)
  if self._name in self._component_defs:
    raise ValueError(
        f'Found duplicate component ids {self._name} while '
        'building component definitions.')
  self._component_defs[self._name] = component_def

  # 3. Build task spec.
  task_spec = pipeline_pb2.PipelineTaskSpec()
  task_spec.task_info.name = self._name
  dependency_ids = [node.id for node in self._node.upstream_nodes]
  for name, input_channel in self._inputs.items():
    # If the redirecting map is provided (usually for latest blessed model
    # resolver, we'll need to redirect accordingly. Also, the upstream node
    # list will be updated and replaced by the new producer id.
    original_producer_id = input_channel.producer_component_id
    for k, v in self._channel_redirect_map.items():
      if (k[0] == original_producer_id and
          original_producer_id in dependency_ids):
        dependency_ids.remove(original_producer_id)
        dependency_ids.append(v[0])
    # Resolve producer and output key with ONE lookup on the original
    # (producer, key) pair. Looking the key up a second time with the already
    # redirected producer id would miss the map entry and leave the output key
    # unredirected.
    channel_key = (original_producer_id, input_channel.output_key)
    producer_id, output_key = self._channel_redirect_map.get(
        channel_key, channel_key)

    input_artifact_spec = pipeline_pb2.TaskInputsSpec.InputArtifactSpec()
    input_artifact_spec.task_output_artifact.producer_task = producer_id
    input_artifact_spec.task_output_artifact.output_artifact_key = output_key
    task_spec.inputs.artifacts[name].CopyFrom(input_artifact_spec)
  for name, value in self._exec_properties.items():
    if value is None:
      continue
    if isinstance(value, data_types.RuntimeParameter):
      # Runtime parameters are read from the parent component's inputs.
      parameter_utils.attach_parameter(value)
      task_spec.inputs.parameters[name].component_input_parameter = value.name
    else:
      task_spec.inputs.parameters[name].CopyFrom(
          pipeline_pb2.TaskInputsSpec.InputParameterSpec(
              runtime_value=compiler_utils.value_converter(value)))

  task_spec.component_ref.name = self._name

  # Sorted so that the emitted dependency list is deterministic.
  task_spec.dependent_tasks.extend(sorted(dependency_ids))

  if self._enable_cache:
    task_spec.caching_options.CopyFrom(
        pipeline_pb2.PipelineTaskSpec.CachingOptions(
            enable_cache=self._enable_cache))

  # 4. Build the executor body for other common tasks.
  executor = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
  if isinstance(self._node, importer.Importer):
    executor.importer.CopyFrom(self._build_importer_spec())
  elif isinstance(self._node, components.FileBasedExampleGen):
    executor.container.CopyFrom(self._build_file_based_example_gen_spec())
  elif isinstance(self._node, components.InfraValidator):
    raise NotImplementedError(
        'The componet type "{}" is not supported'.format(type(self._node)))
  else:
    executor.container.CopyFrom(self._build_container_spec())
  self._deployment_config.executors[executor_label].CopyFrom(executor)

  return {self._name: task_spec}
def _build_latest_artifact_resolver(
    self) -> Dict[str, pipeline_pb2.PipelineTaskSpec]:
  """Builds a resolver spec for a latest artifact resolver.

  Registers a component definition in `self._component_defs` and a resolver
  executor spec in `self._deployment_config`. One artifact query is generated
  per input channel, filtering on the channel's artifact type and the LIVE
  state.

  Returns:
    A dict with a single entry mapping this node's name to the
    PipelineTaskSpec of the resolver task.

  Raises:
    ValueError: when the resolver config is a dict whose
      desired_num_of_artifacts is greater than 1. 1 is the only supported
      value currently.
  """
  # Fetch the init kwargs for the resolver and reject unsupported settings.
  resolver_config = self._exec_properties[resolver.RESOLVER_CONFIG]
  if isinstance(resolver_config, dict):
    requested_count = resolver_config.get('desired_num_of_artifacts', 0)
    if requested_count > 1:
      raise ValueError(
          'Only desired_num_of_artifacts=1 is supported currently.'
          ' Got {}'.format(requested_count))

  executor_label = _EXECUTOR_LABEL_PATTERN.format(self._name)
  component_spec = pipeline_pb2.ComponentSpec()
  component_spec.executor_label = executor_label
  resolver_task = pipeline_pb2.PipelineTaskSpec()
  resolver_task.task_info.name = self._name

  # Output definitions of the component.
  for output_key, output_channel in self._outputs.items():
    component_spec.output_definitions.artifacts[output_key].CopyFrom(
        compiler_utils.build_output_artifact_spec(output_channel))

  # Each provided exec property becomes a parameter definition on the
  # component and a parameter binding on the task.
  for property_key, property_value in self._exec_properties.items():
    if property_value is None:
      continue
    component_spec.input_definitions.parameters[property_key].CopyFrom(
        compiler_utils.build_parameter_type_spec(property_value))
    if isinstance(property_value, data_types.RuntimeParameter):
      parameter_utils.attach_parameter(property_value)
      resolver_task.inputs.parameters[
          property_key].component_input_parameter = property_value.name
    else:
      resolver_task.inputs.parameters[property_key].CopyFrom(
          pipeline_pb2.TaskInputsSpec.InputParameterSpec(
              runtime_value=compiler_utils.value_converter(property_value)))
  self._component_defs[self._name] = component_spec
  resolver_task.component_ref.name = self._name

  # Build the artifact query for each channel in the input dict. Resolver's
  # output dict has the same set of keys as its input dict.
  live_state_name = metadata_store_pb2.Artifact.State.Name(
      metadata_store_pb2.Artifact.LIVE)
  artifact_queries = {
      input_key: ResolverSpec.ArtifactQuerySpec(
          filter=('artifact_type="{type}" and state={state}').format(
              type=compiler_utils.get_artifact_title(input_channel.type),
              state=live_state_name))
      for input_key, input_channel in self._inputs.items()
  }

  executor_spec = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
  executor_spec.resolver.CopyFrom(
      ResolverSpec(output_artifact_queries=artifact_queries))
  self._deployment_config.executors[executor_label].CopyFrom(executor_spec)
  return {self._name: resolver_task}
def build(self) -> List[pipeline_pb2.PipelineTaskSpec]:
  """Builds a pipeline StepSpec given the node information.

  Besides returning the task spec(s), the executor spec built for the node is
  stored in `self._deployment_config` under the node's executor label.

  Returns:
    A list of PipelineTaskSpec messages corresponding to the node. For most of
    the cases, the list contains a single element. The only exception is when
    compiling latest blessed model resolver. One DSL node will be split to two
    resolver specs to reflect the two-phased query execution.

  Raises:
    NotImplementedError: When the node being built is an InfraValidator.
  """
  # 1. Resolver tasks won't have input artifacts in the API proto. First we
  #    specialcase two resolver types we support. This is checked before any
  #    proto is allocated because the resolver path builds its own specs.
  if isinstance(self._node, resolver.Resolver):
    return self._build_resolver_spec()

  task_spec = pipeline_pb2.PipelineTaskSpec()
  task_spec.task_info.CopyFrom(pipeline_pb2.PipelineTaskInfo(name=self._name))
  executor_label = _EXECUTOR_LABEL_PATTERN.format(self._name)
  task_spec.executor_label = executor_label
  executor = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()

  # 2. Build the node spec.
  # TODO(b/157641727): Tests comparing dictionaries are brittle when comparing
  # lists as ordering matters.
  dependency_ids = [node.id for node in self._node.upstream_nodes]
  # Specify the inputs of the task.
  for name, input_channel in self._inputs.items():
    # If the redirecting map is provided (usually for latest blessed model
    # resolver, we'll need to redirect accordingly. Also, the upstream node
    # list will be updated and replaced by the new producer id.
    original_producer_id = input_channel.producer_component_id
    for k, v in self._channel_redirect_map.items():
      if (k[0] == original_producer_id and
          original_producer_id in dependency_ids):
        dependency_ids.remove(original_producer_id)
        dependency_ids.append(v[0])
    # Resolve producer and output key with ONE lookup on the original
    # (producer, key) pair. Looking the key up a second time with the already
    # redirected producer id would miss the map entry and leave the output key
    # unredirected.
    channel_key = (original_producer_id, input_channel.output_key)
    producer_id, output_key = self._channel_redirect_map.get(
        channel_key, channel_key)

    input_artifact_spec = pipeline_pb2.TaskInputsSpec.InputArtifactSpec(
        producer_task=producer_id, output_artifact_key=output_key)
    task_spec.inputs.artifacts[name].CopyFrom(input_artifact_spec)

  # Specify the outputs of the task.
  for name, output_channel in self._outputs.items():
    # Currently, we're working under the assumption that for tasks
    # (those generated by BaseComponent), each channel contains a single
    # artifact.
    output_artifact_spec = compiler_utils.build_output_artifact_spec(
        output_channel)
    task_spec.outputs.artifacts[name].CopyFrom(output_artifact_spec)

  # Specify the input parameters of the task.
  for k, v in compiler_utils.build_input_parameter_spec(
      self._exec_properties).items():
    task_spec.inputs.parameters[k].CopyFrom(v)

  # 3. Build the executor body for other common tasks.
  if isinstance(self._node, importer.Importer):
    executor.importer.CopyFrom(self._build_importer_spec())
  elif isinstance(self._node, components.FileBasedExampleGen):
    executor.container.CopyFrom(self._build_file_based_example_gen_spec())
  elif isinstance(self._node, components.InfraValidator):
    raise NotImplementedError(
        'The componet type "{}" is not supported'.format(type(self._node)))
  else:
    executor.container.CopyFrom(self._build_container_spec())

  # Sorted so that the emitted dependency list is deterministic.
  task_spec.dependent_tasks.extend(sorted(dependency_ids))

  task_spec.caching_options.CopyFrom(
      pipeline_pb2.PipelineTaskSpec.CachingOptions(
          enable_cache=self._enable_cache))

  # 4. Attach the built executor spec to the deployment config.
  self._deployment_config.executors[executor_label].CopyFrom(executor)

  return [task_spec]