def setUp(self):
  super(PlaceholderUtilsTest, self).setUp()
  examples = [standard_artifacts.Examples()]
  examples[0].uri = "/tmp"
  examples[0].split_names = artifact_utils.encode_split_names(
      ["train", "eval"])
  self._serving_spec = infra_validator_pb2.ServingSpec()
  self._serving_spec.tensorflow_serving.tags.extend(
      ["latest", "1.15.0-gpu"])
  self._resolution_context = placeholder_utils.ResolutionContext(
      exec_info=data_types.ExecutionInfo(
          input_dict={
              "model": [standard_artifacts.Model()],
              "examples": examples,
          },
          output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
          exec_properties={
              "proto_property":
                  json_format.MessageToJson(
                      message=self._serving_spec,
                      sort_keys=True,
                      preserving_proto_field_name=True,
                      indent=0)
          },
          execution_output_uri="test_executor_output_uri",
          stateful_working_dir="test_stateful_working_dir",
          pipeline_node=pipeline_pb2.PipelineNode(
              node_info=pipeline_pb2.NodeInfo(
                  type=metadata_store_pb2.ExecutionType(
                      name="infra_validator"))),
          pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")),
      executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
          class_path="test_class_path"),
  )
  # Resolution context to simulate missing optional values.
  self._none_resolution_context = placeholder_utils.ResolutionContext(
      exec_info=data_types.ExecutionInfo(
          input_dict={
              "model": [],
              "examples": [],
          },
          output_dict={"blessing": []},
          exec_properties={},
          pipeline_node=pipeline_pb2.PipelineNode(
              node_info=pipeline_pb2.NodeInfo(
                  type=metadata_store_pb2.ExecutionType(
                      name="infra_validator"))),
          pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")),
      executor_spec=None,
      platform_config=None)
def testDumpUiMetadata(self):
  trainer = pipeline_pb2.PipelineNode()
  trainer.node_info.type.name = 'tfx.components.trainer.component.Trainer'
  model_run_out_spec = pipeline_pb2.OutputSpec(
      artifact_spec=pipeline_pb2.OutputSpec.ArtifactSpec(
          type=metadata_store_pb2.ArtifactType(
              name=standard_artifacts.ModelRun.TYPE_NAME)))
  trainer.outputs.outputs['model_run'].CopyFrom(model_run_out_spec)

  model_run = standard_artifacts.ModelRun()
  model_run.uri = 'model_run_uri'
  exec_info = data_types.ExecutionInfo(
      input_dict={},
      output_dict={'model_run': [model_run]},
      exec_properties={},
      execution_id='id')

  ui_metadata_path = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName, 'json')
  fileio.makedirs(os.path.dirname(ui_metadata_path))

  container_entrypoint._dump_ui_metadata(trainer, exec_info,
                                         ui_metadata_path)

  with open(ui_metadata_path) as f:
    ui_metadata = json.load(f)
    self.assertEqual('tensorboard', ui_metadata['outputs'][-1]['type'])
    self.assertEqual('model_run_uri', ui_metadata['outputs'][-1]['source'])
def test_node_uid_from_pipeline_node(self):
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  self.assertEqual(
      task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
          node_id='Trainer'),
      task_lib.NodeUid.from_pipeline_node(pipeline, node))
def test_node_uid_from_pipeline_node(self):
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  self.assertEqual(
      task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(
              pipeline_id='pipeline', pipeline_run_id='run0'),
          node_id='Trainer'),
      task_lib.NodeUid.from_pipeline_node(pipeline, node))
def testGetCacheContextTwiceDifferentNodeInfo(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    self._get_cache_context(m)
    self._get_cache_context(
        m,
        custom_pipeline_node=text_format.Parse(
            """
            node_info {
              id: "new_node_id"
            }
            """, pipeline_pb2.PipelineNode()))
    # Different node info will result in a new cache context.
    self.assertLen(m.store.get_contexts(), 2)
def testGetCacheContextTwiceDifferentExecutorSpec(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    self._get_cache_context(m)
    self._get_cache_context(
        m,
        custom_pipeline_node=text_format.Parse(
            """
            executor {
              python_class_executor_spec {class_path: 'n.e.w'}
            }
            """, pipeline_pb2.PipelineNode()))
    # Different executor spec will result in a new cache context.
    self.assertLen(m.store.get_contexts(), 2)
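
# A minimal sketch of the complementary case (hypothetical test name, assuming
# the same _get_cache_context helper as the two tests above): identical node
# info and executor spec should reuse the stored cache context rather than
# create a second one.
def testGetCacheContextTwiceSameNodeInfo(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    self._get_cache_context(m)
    self._get_cache_context(m)
    # Identical inputs hash to the same cache context key, so only one
    # context is registered in MLMD.
    self.assertLen(m.store.get_contexts(), 1)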
def test_exec_node_task_create(self):
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  self.assertEqual(
      task_lib.ExecNodeTask(
          node_uid=task_lib.NodeUid(
              pipeline_uid=task_lib.PipelineUid(
                  pipeline_id='pipeline', pipeline_run_id='run0'),
              node_id='Trainer'),
          execution_id=123),
      task_lib.ExecNodeTask.create(pipeline, node, 123))
def _set_up_test_execution_info(self,
                                input_dict=None,
                                output_dict=None,
                                exec_properties=None):
  return data_types.ExecutionInfo(
      input_dict=input_dict or {},
      output_dict=output_dict or {},
      exec_properties=exec_properties or {},
      execution_output_uri='/testing/executor/output/',
      stateful_working_dir='/testing/stateful/dir',
      pipeline_node=pipeline_pb2.PipelineNode(
          node_info=pipeline_pb2.NodeInfo(
              type=metadata_store_pb2.ExecutionType(name='Docker_executor'))),
      pipeline_info=pipeline_pb2.PipelineInfo(id='test_pipeline_id'))
def _remove_dangling_downstream_nodes(
    node: p_pb2.PipelineNode,
    node_ids_to_keep: Collection[str]) -> p_pb2.PipelineNode:
  """Removes node.downstream_nodes that have been filtered out."""
  # Using a loop instead of set intersection to ensure the same order.
  downstream_nodes_to_keep = [
      downstream_node for downstream_node in node.downstream_nodes
      if downstream_node in node_ids_to_keep
  ]
  if len(downstream_nodes_to_keep) == len(node.downstream_nodes):
    return node
  result = p_pb2.PipelineNode()
  result.CopyFrom(node)
  result.downstream_nodes[:] = downstream_nodes_to_keep
  return result
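
# A minimal sketch of the helper above in use (hypothetical node ids, assuming
# `p_pb2` is the usual `pipeline_pb2` alias used in this module): a node whose
# downstream list mentions a filtered-out node gets back a pruned copy.
def _example_remove_dangling_downstream_nodes():
  node = p_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  node.downstream_nodes.extend(['Pusher', 'Evaluator'])
  # Only 'Pusher' survives the filter; 'Evaluator' has been filtered out.
  pruned = _remove_dangling_downstream_nodes(
      node, node_ids_to_keep={'Trainer', 'Pusher'})
  assert list(pruned.downstream_nodes) == ['Pusher']
  # The input node is left untouched; the helper returns a copy.
  assert list(node.downstream_nodes) == ['Pusher', 'Evaluator']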
def _set_up_test_execution_info(self,
                                input_dict=None,
                                output_dict=None,
                                exec_properties=None):
  return data_types.ExecutionInfo(
      execution_id=123,
      input_dict=input_dict or {},
      output_dict=output_dict or {},
      exec_properties=exec_properties or {},
      execution_output_uri='/testing/executor/output/',
      stateful_working_dir='/testing/stateful/dir',
      pipeline_node=pipeline_pb2.PipelineNode(
          node_info=pipeline_pb2.NodeInfo(
              id='fakecomponent-fakecomponent')),
      pipeline_info=pipeline_pb2.PipelineInfo(id='Test'),
      pipeline_run_id='123')
def _get_execution_info(self, input_dict, output_dict, exec_properties):
  pipeline_node = pipeline_pb2.PipelineNode(node_info={'id': 'MyPythonNode'})
  pipeline_info = pipeline_pb2.PipelineInfo(id='MyPipeline')
  stateful_working_dir = os.path.join(self.tmp_dir, 'stateful_working_dir')
  executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
  return data_types.ExecutionInfo(
      execution_id=1,
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
      stateful_working_dir=stateful_working_dir,
      execution_output_uri=executor_output_uri,
      pipeline_node=pipeline_node,
      pipeline_info=pipeline_info,
      pipeline_run_id=99)
def setUp(self):
  super().setUp()
  self._connection_config = metadata_store_pb2.ConnectionConfig()
  self._connection_config.sqlite.SetInParent()
  self._module_file_path = os.path.join(self.tmp_dir, 'module_file')
  self._input_artifacts = {'input_examples': [standard_artifacts.Examples()]}
  self._output_artifacts = {'output_models': [standard_artifacts.Model()]}
  self._parameters = {'module_file': self._module_file_path}
  self._module_file_content = 'module content'
  self._pipeline_node = text_format.Parse(
      """
      executor {
        python_class_executor_spec {class_path: 'a.b.c'}
      }
      """, pipeline_pb2.PipelineNode())
  self._executor_class_path = 'a.b.c'
  self._pipeline_info = pipeline_pb2.PipelineInfo(id='pipeline_id')
def testExecutionInfoSerialization(self):
  my_artifact = _MyArtifact()
  my_artifact.int1 = 111
  execution_output_uri = 'output/uri'
  stateful_working_dir = 'working/dir'
  exec_properties = {
      'property1': 'value1',
      'property2': 'value2',
  }
  pipeline_info = pipeline_pb2.PipelineInfo(id='my_pipeline')
  pipeline_node = text_format.Parse(
      """
      node_info {
        id: 'my_node'
      }
      """, pipeline_pb2.PipelineNode())

  original = data_types.ExecutionInfo(
      input_dict={'input': [my_artifact]},
      output_dict={'output': [my_artifact]},
      exec_properties=exec_properties,
      execution_output_uri=execution_output_uri,
      stateful_working_dir=stateful_working_dir,
      pipeline_info=pipeline_info,
      pipeline_node=pipeline_node)

  serialized = python_execution_binary_utils.serialize_execution_info(
      original)
  rehydrated = python_execution_binary_utils.deserialize_execution_info(
      serialized)

  self.CheckArtifactDict(rehydrated.input_dict, {'input': [my_artifact]})
  self.CheckArtifactDict(rehydrated.output_dict, {'output': [my_artifact]})
  self.assertEqual(rehydrated.exec_properties, exec_properties)
  self.assertEqual(rehydrated.execution_output_uri, execution_output_uri)
  self.assertEqual(rehydrated.stateful_working_dir, stateful_working_dir)
  self.assertProtoEquals(rehydrated.pipeline_info, original.pipeline_info)
  self.assertProtoEquals(rehydrated.pipeline_node, original.pipeline_node)
def _handle_missing_inputs(
    node: p_pb2.PipelineNode,
    node_ids_to_keep: Collection[str],
    pipeline_run_id_fn: Callable[[p_pb2.InputSpec.Channel], str],
) -> p_pb2.PipelineNode:
  """Private helper function to handle missing inputs.

  Args:
    node: The PipelineNode to check for missing inputs.
    node_ids_to_keep: The node_ids that are not filtered out.
    pipeline_run_id_fn: If this node has upstream nodes that are filtered out,
      this function is used to obtain the pipeline_run_id for that input
      channel, which is then provided as the 'pipeline_run_id' in the
      'pipeline_run' ContextQuery.

  Returns:
    A copy of the PipelineNode where all inputs that reference filtered-out
    nodes have their 'pipeline_run' ContextQuery updated.
  """
  upstream_nodes_to_replace = set()
  upstream_nodes_to_keep = []
  for upstream_node in node.upstream_nodes:
    if upstream_node in node_ids_to_keep:
      upstream_nodes_to_keep.append(upstream_node)
    else:
      upstream_nodes_to_replace.add(upstream_node)

  if not upstream_nodes_to_replace:
    return node  # No parent missing, no need to change anything.

  result = p_pb2.PipelineNode()
  result.CopyFrom(node)
  for input_spec in result.inputs.inputs.values():
    for channel in input_spec.channels:
      if channel.producer_node_query.id in upstream_nodes_to_replace:
        pipeline_run_id = pipeline_run_id_fn(channel)
        _replace_pipeline_run_id_in_channel(channel, pipeline_run_id)
  result.upstream_nodes[:] = upstream_nodes_to_keep
  return result
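
# `_replace_pipeline_run_id_in_channel` is referenced above but its body is
# not shown here. A plausible sketch of what it does, assuming the channel's
# 'pipeline_run' context query carries the run id as a string field value (an
# assumption for illustration, not the verbatim TFX implementation):
def _replace_pipeline_run_id_in_channel(channel: p_pb2.InputSpec.Channel,
                                        pipeline_run_id: str) -> None:
  """Points the channel's 'pipeline_run' ContextQuery at the given run id."""
  for context_query in channel.context_queries:
    if context_query.type.name == 'pipeline_run':
      # Rewrite the existing query in place to reference the resolved run.
      context_query.name.field_value.string_value = pipeline_run_id
      return
  # No 'pipeline_run' query yet: add one referencing the resolved run.
  context_query = channel.context_queries.add()
  context_query.type.name = 'pipeline_run'
  context_query.name.field_value.string_value = pipeline_run_id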
def testRunExecutorWithBeamPipelineArgs(self):
  executor_spec = text_format.Parse(
      """
      python_executor_spec: {
        class_path: "tfx.orchestration.portable.beam_executor_operator_test.ValidateBeamPipelineArgsExecutor"
      }
      beam_pipeline_args: "--runner=DirectRunner"
      """, executable_spec_pb2.BeamExecutableSpec())
  operator = beam_executor_operator.BeamExecutorOperator(executor_spec)

  pipeline_node = pipeline_pb2.PipelineNode(node_info={'id': 'MyBeamNode'})
  pipeline_info = pipeline_pb2.PipelineInfo(id='MyPipeline')
  executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
  executor_output = operator.run_executor(
      data_types.ExecutionInfo(
          execution_id=1,
          input_dict={'input_key': [standard_artifacts.Examples()]},
          output_dict={'output_key': [standard_artifacts.Model()]},
          exec_properties={},
          execution_output_uri=executor_output_uri,
          pipeline_node=pipeline_node,
          pipeline_info=pipeline_info,
          pipeline_run_id=99))
  self.assertProtoPartiallyEquals(
      """
      output_artifacts {
        key: "output_key"
        value {
          artifacts {
            custom_properties {
              key: "name"
              value {
                string_value: "MyPipeline.MyBeamNode.my_model"
              }
            }
          }
        }
      }""", executor_output)
def testRun(self):
  # Create input dir.
  self._input_base_path = os.path.join(self._test_dir, 'input_base')
  tf.io.gfile.makedirs(self._input_base_path)

  # Create PipelineInfo and PipelineNode.
  pipeline_info = pipeline_pb2.PipelineInfo()
  pipeline_node = pipeline_pb2.PipelineNode()

  # Fake previous outputs.
  span1_v1_split1 = os.path.join(self._input_base_path, 'span01',
                                 'version01', 'split1', 'data')
  io_utils.write_string_file(span1_v1_split1, 'testing11')
  span1_v1_split2 = os.path.join(self._input_base_path, 'span01',
                                 'version01', 'split2', 'data')
  io_utils.write_string_file(span1_v1_split2, 'testing12')

  ir_driver = driver.Driver(self._mock_metadata, pipeline_info,
                            pipeline_node)
  example = standard_artifacts.Examples()

  # Prepare output_dic.
  example.uri = 'my_uri'  # Will verify that this uri is not changed.
  output_dic = {utils.EXAMPLES_KEY: [example]}

  # Prepare exec_properties.
  exec_properties = {
      utils.INPUT_BASE_KEY:
          self._input_base_path,
      utils.INPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='s1',
                      pattern='span{SPAN}/version{VERSION}/split1/*'),
                  example_gen_pb2.Input.Split(
                      name='s2',
                      pattern='span{SPAN}/version{VERSION}/split2/*')
              ]),
              preserving_proto_field_name=True),
  }
  result = ir_driver.run(None, output_dic, exec_properties)

  # Assert exec_properties' values.
  exec_properties = result.exec_properties
  self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
  self.assertEqual(exec_properties[utils.VERSION_PROPERTY_NAME].int_value, 1)
  updated_input_config = example_gen_pb2.Input()
  json_format.Parse(exec_properties[utils.INPUT_CONFIG_KEY].string_value,
                    updated_input_config)
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span01/version01/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span01/version01/split2/*"
      }""", updated_input_config)
  self.assertRegex(
      exec_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
      r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
  )

  # Assert output_artifacts' values.
  self.assertLen(result.output_artifacts[utils.EXAMPLES_KEY].artifacts, 1)
  output_example = result.output_artifacts[utils.EXAMPLES_KEY].artifacts[0]
  self.assertEqual(output_example.uri, example.uri)
  self.assertEqual(
      output_example.custom_properties[utils.SPAN_PROPERTY_NAME].string_value,
      '1')
  self.assertEqual(
      output_example.custom_properties[
          utils.VERSION_PROPERTY_NAME].string_value, '1')
  self.assertRegex(
      output_example.custom_properties[
          utils.FINGERPRINT_PROPERTY_NAME].string_value,
      r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
  )
def testSuccess(self):
  with self._mlmd_connection as m:
    # Publishes two models which will be consumed by the downstream resolver.
    output_model_1 = types.Artifact(
        self._my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_1.uri = 'my_model_uri_1'
    output_model_2 = types.Artifact(
        self._my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_2.uri = 'my_model_uri_2'
    contexts = context_lib.prepare_contexts(m, self._my_trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, self._my_trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model_1, output_model_2],
        })

  handler = resolver_node_handler.ResolverNodeHandler()
  execution_metadata = handler.run(
      mlmd_connection=self._mlmd_connection,
      pipeline_node=self._resolver_node,
      pipeline_info=self._pipeline_info,
      pipeline_runtime_spec=self._pipeline_runtime_spec)

  with self._mlmd_connection as m:
    # There is no way to directly verify the output artifact of the resolver,
    # so a fake downstream component is created here which listens to the
    # resolver's output, and we verify its input.
    down_stream_node = text_format.Parse(
        """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())
    downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=down_stream_node.inputs)
    downstream_input_model = downstream_input_artifacts['input_models']
    self.assertLen(downstream_input_model, 1)
    self.assertProtoPartiallyEquals(
        """
        id: 2
        type_id: 5
        uri: "my_model_uri_2"
        state: LIVE""",
        downstream_input_model[0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    [execution] = m.store.get_executions_by_id([execution_metadata.id])
    self.assertProtoPartiallyEquals(
        """
        id: 2
        type_id: 6
        last_known_state: COMPLETE
        """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def _compile_node(
    self,
    tfx_node: base_node.BaseNode,
    compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
    enable_cache: bool,
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.
    enable_cache: Whether cache is enabled.

  Raises:
    TypeError: When the supplied tfx_node has values of invalid type.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  if isinstance(tfx_node,
                base_component.BaseComponent) and tfx_node.type_annotation:
    node.node_info.type.base_type = (
        tfx_node.type_annotation.MLMD_SYSTEM_BASE_TYPE)
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

  # Context for the current pipeline run.
  if compile_context.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = (
      compiler_utils.node_context_name(
          compile_context.pipeline_info.pipeline_context_name,
          node.node_info.id))

  # Pre Step 3: Alter graph topology if needed.
  if compile_context.is_async_mode:
    tfx_node_inputs = self._embed_upstream_resolver_nodes(
        compile_context, tfx_node, node)
  else:
    tfx_node_inputs = tfx_node.inputs

  # Step 3: Node inputs

  # Step 3.1: Generate implicit input channels
  # Step 3.1.1: Conditionals
  implicit_input_channels = {}
  predicates = conditional.get_predicates(tfx_node)
  if predicates:
    implicit_keys_map = {}
    for key, chnl in tfx_node_inputs.items():
      if not isinstance(chnl, types.Channel):
        raise ValueError(
            "Conditionals only support using a channel as a predicate.")
      implicit_keys_map[compiler_utils.implicit_channel_key(chnl)] = key
    encoded_predicates = []
    for predicate in predicates:
      for chnl in predicate.dependent_channels():
        implicit_key = compiler_utils.implicit_channel_key(chnl)
        if implicit_key not in implicit_keys_map:
          # Store this channel and add it to the node inputs later.
          implicit_input_channels[implicit_key] = chnl
      encoded_predicates.append(
          predicate.encode_with_keys(
              compiler_utils.build_channel_to_key_fn(implicit_keys_map)))
    # In an async pipeline, the conditional resolver step should be the last
    # step in all resolver steps of a node.
    resolver_step = node.inputs.resolver_config.resolver_steps.add()
    resolver_step.class_path = constants.CONDITIONAL_RESOLVER_CLASS_PATH
    resolver_step.config_json = json_utils.dumps(
        {"predicates": encoded_predicates})

  # Step 3.1.2: Add placeholder exec props to implicit_input_channels
  for key, value in tfx_node.exec_properties.items():
    if isinstance(value, placeholder.ChannelWrappedPlaceholder):
      if not (inspect.isclass(value.channel.type) and
              issubclass(value.channel.type, value_artifact.ValueArtifact)):
        raise ValueError("The output channel of a dynamic exec property "
                         "must be a ValueArtifact.")
      implicit_key = compiler_utils.implicit_channel_key(value.channel)
      implicit_input_channels[implicit_key] = value.channel

  # Step 3.2: Handle ForEach.
  dsl_contexts = context_manager.get_contexts(tfx_node)
  for dsl_context in dsl_contexts:
    if isinstance(dsl_context, for_each.ForEachContext):
      for input_key, channel in tfx_node_inputs.items():
        if (isinstance(channel, types.LoopVarChannel) and
            channel.wrapped is dsl_context.wrapped_channel):
          node.inputs.resolver_config.resolver_steps.extend(
              _compile_for_each_context(input_key))
          break
      else:
        # Ideally we should not reach here, as the same check is performed
        # at ForEachContext.will_add_node().
        raise ValueError(
            f"Unable to locate ForEach loop variable {dsl_context.channel} "
            f"from inputs of node {tfx_node.id}.")

  # Check that no loop variable is used outside the ForEach.
  for input_key, channel in tfx_node_inputs.items():
    if isinstance(channel, types.LoopVarChannel):
      dsl_context_ids = {c.id for c in dsl_contexts}
      if channel.context_id not in dsl_context_ids:
        raise ValueError(
            "Loop variable cannot be used outside the ForEach "
            f"(node_id = {tfx_node.id}, input_key = {input_key}).")

  # Step 3.3: Fill node inputs
  for key, value in itertools.chain(tfx_node_inputs.items(),
                                    implicit_input_channels.items()):
    input_spec = node.inputs.inputs[key]
    for input_channel in channel_utils.get_individual_channels(value):
      chnl = input_spec.channels.add()

      # If the node input comes from another node's output, fill the context
      # queries with the producer node's contexts.
      if input_channel in compile_context.node_outputs:
        chnl.producer_node_query.id = input_channel.producer_component_id

        # Here we rely on pipeline.components to be topologically sorted.
        assert input_channel.producer_component_id in compile_context.node_pbs, (
            "producer component should have already been compiled.")
        producer_pb = compile_context.node_pbs[
            input_channel.producer_component_id]
        for producer_context in producer_pb.contexts.contexts:
          context_query = chnl.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)

      # If the node input does not come from another node's output, fill the
      # context queries based on Channel info. We require every channel to
      # have a pipeline context and fill it in automatically.
      else:
        # Add pipeline context query.
        context_query = chnl.context_queries.add()
        context_query.type.CopyFrom(pipeline_context_pb.type)
        context_query.name.CopyFrom(pipeline_context_pb.name)

        # Optionally add node context query.
        if input_channel.producer_component_id:
          # Add node context query if `producer_component_id` is present.
          chnl.producer_node_query.id = input_channel.producer_component_id
          node_context_query = chnl.context_queries.add()
          node_context_query.type.name = constants.NODE_CONTEXT_TYPE_NAME
          node_context_query.name.field_value.string_value = "{}.{}".format(
              compile_context.pipeline_info.pipeline_context_name,
              input_channel.producer_component_id)

      artifact_type = input_channel.type._get_artifact_type()  # pylint: disable=protected-access
      chnl.artifact_query.type.CopyFrom(artifact_type)
      chnl.artifact_query.type.ClearField("properties")

      if input_channel.output_key:
        chnl.output_key = input_channel.output_key

    # Set NodeInputs.min_count.
    if isinstance(tfx_node, base_component.BaseComponent):
      if key in implicit_input_channels:
        # Mark all implicit input channels (e.g. conditionals) as optional.
        # This is suboptimal, but still a safe guess to avoid breaking the
        # pipeline run.
        input_spec.min_count = 0
      else:
        try:
          # Calculate min_count from ComponentSpec.INPUTS.
          if tfx_node.spec.is_optional_input(key):
            input_spec.min_count = 0
          else:
            input_spec.min_count = 1
        except KeyError:
          # We can currently fall here if the upstream resolver node inputs
          # are embedded into the current node (in async mode). We always
          # regard a resolver's inputs as optional.
          if compile_context.is_async_mode:
            input_spec.min_count = 0
          else:
            raise

  # TODO(b/170694459): Refactor special nodes as plugins.
  # Step 3.4: Special treatment for Resolver node.
  if compiler_utils.is_resolver(tfx_node):
    assert compile_context.is_sync_mode
    node.inputs.resolver_config.resolver_steps.extend(
        _compile_resolver_node(tfx_node))

  # Step 4: Node outputs
  for value in tfx_node.outputs.values():
    # Register the output in the context.
    compile_context.node_outputs.add(value)
  if (isinstance(tfx_node, base_component.BaseComponent) or
      compiler_utils.is_importer(tfx_node)):
    self._compile_node_outputs(tfx_node, node)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because a runtime parameter can be in a serialized
      # string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      # RuntimeInfoPlaceholder passes Execution parameters of Facade
      # components.
      elif isinstance(value, placeholder.RuntimeInfoPlaceholder):
        parameter_value.placeholder.CopyFrom(value.encode())
      # ChannelWrappedPlaceholder passes a dynamic execution parameter.
      elif isinstance(value, placeholder.ChannelWrappedPlaceholder):
        compiler_utils.validate_dynamic_exec_ph_operator(value)
        parameter_value.placeholder.CopyFrom(
            value.encode_with_keys(compiler_utils.implicit_channel_key))
      else:
        try:
          data_types_utils.set_parameter_value(parameter_value, value)
        except ValueError as e:
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}."
              .format(tfx_node.id, key, type(value))) from e

  # Step 6: Executor spec and optional driver spec for components
  if isinstance(tfx_node, base_component.BaseComponent):
    executor_spec = tfx_node.executor_spec.encode(
        component_spec=tfx_node.spec)
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = _fully_qualified_name(tfx_node.driver_class)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

  # Step 7: Upstream/Downstream nodes
  node.upstream_nodes.extend(
      self._find_runtime_upstream_node_ids(compile_context, tfx_node))
  node.downstream_nodes.extend(
      self._find_runtime_downstream_node_ids(compile_context, tfx_node))

  # Step 8: Node execution options
  node.execution_options.caching_options.enable_cache = enable_cache

  # Step 9: Per-node platform config
  if isinstance(tfx_node, base_component.BaseComponent):
    tfx_component = cast(base_component.BaseComponent, tfx_node)
    if tfx_component.platform_config:
      deployment_config.node_level_platform_configs[tfx_node.id].Pack(
          tfx_component.platform_config)

  return node
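
# For orientation, a rough hand-written sketch (not actual compiler output;
# context type names and the "pipeline-run-id" parameter name are assumptions
# about the constants module) of the three context queries Step 2 above adds
# for a node "Trainer" in a pipeline whose context name is "my_pipeline":
_EXAMPLE_COMPILED_NODE_CONTEXTS = """
  contexts {
    type { name: "pipeline" }
    name { field_value { string_value: "my_pipeline" } }
  }
  contexts {
    type { name: "pipeline_run" }
    name { runtime_parameter { name: "pipeline-run-id" type: STRING } }
  }
  contexts {
    type { name: "node" }
    name { field_value { string_value: "my_pipeline.Trainer" } }
  }
"""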
def _compile_node(
    self, tfx_node: base_node.BaseNode, compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

  # Context for the current pipeline run.
  if (compile_context.execution_mode ==
      pipeline_pb2.Pipeline.ExecutionMode.SYNC):
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = "{}.{}".format(
      compile_context.pipeline_info.pipeline_context_name,
      node.node_info.id)

  # Step 3: Node inputs
  for key, value in tfx_node.inputs.items():
    input_spec = node.inputs.inputs[key]
    channel = input_spec.channels.add()
    if value.producer_component_id:
      channel.producer_node_query.id = value.producer_component_id

      # Here we rely on pipeline.components to be topologically sorted.
      assert value.producer_component_id in compile_context.node_pbs, (
          "producer component should have already been compiled.")
      producer_pb = compile_context.node_pbs[value.producer_component_id]
      for producer_context in producer_pb.contexts.contexts:
        if (not compiler_utils.is_resolver(tfx_node) or
            producer_context.name.runtime_parameter.name !=
            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
          context_query = channel.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)

    artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
    channel.artifact_query.type.CopyFrom(artifact_type)
    channel.artifact_query.type.ClearField("properties")

    if value.output_key:
      channel.output_key = value.output_key

    # TODO(b/158712886): Calculate min_count based on whether inputs are
    # optional. min_count = 0 stands for an optional input and 1 stands for
    # a required input.

  # Step 3.1: Special treatment for Resolver node
  if compiler_utils.is_resolver(tfx_node):
    resolver = tfx_node.exec_properties[resolver_node.RESOLVER_CLASS]
    if resolver == latest_artifacts_resolver.LatestArtifactsResolver:
      node.inputs.resolver_config.resolver_policy = (
          pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_ARTIFACT)
    elif resolver == latest_blessed_model_resolver.LatestBlessedModelResolver:
      node.inputs.resolver_config.resolver_policy = (
          pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_BLESSED_MODEL)
    else:
      raise ValueError(
          "Got unsupported resolver policy: {}".format(resolver))

  # Step 4: Node outputs
  if compiler_utils.is_component(tfx_node):
    for key, value in tfx_node.outputs.items():
      output_spec = node.outputs.outputs[key]
      artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
      output_spec.artifact_spec.type.CopyFrom(artifact_type)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because a runtime parameter can be in a serialized
      # string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      elif isinstance(value, str) and re.search(
          data_types.RUNTIME_PARAMETER_PATTERN, value):
        runtime_param = json.loads(value)
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, runtime_param.name,
            runtime_param.ptype, runtime_param.default)
      elif isinstance(value, str):
        parameter_value.field_value.string_value = value
      elif isinstance(value, int):
        parameter_value.field_value.int_value = value
      elif isinstance(value, float):
        parameter_value.field_value.double_value = value
      else:
        raise ValueError(
            "Component {} got unsupported parameter {} with type {}."
            .format(tfx_node.id, key, type(value)))

  # Step 6: Executor spec and optional driver spec for components
  if compiler_utils.is_component(tfx_node):
    executor_spec = tfx_node.executor_spec.encode()
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = "{}.{}".format(tfx_node.driver_class.__module__,
                                         tfx_node.driver_class.__name__)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

  # Step 7: Upstream/Downstream nodes
  # Note: the order of tfx_node.upstream_nodes is inconsistent from run to
  # run. We sort them so that the compiler generates consistent results.
  node.upstream_nodes.extend(
      sorted(upstream_component.id
             for upstream_component in tfx_node.upstream_nodes))
  node.downstream_nodes.extend(
      sorted(downstream_component.id
             for downstream_component in tfx_node.downstream_nodes))

  # Step 8: Node execution options
  # TODO(kennethyang): Add support for node execution options.

  return node
def _compile_node(
    self,
    tfx_node: base_node.BaseNode,
    compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
    enable_cache: bool,
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.
    enable_cache: Whether cache is enabled.

  Raises:
    TypeError: When the supplied tfx_node has values of invalid type.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

  # Context for the current pipeline run.
  if compile_context.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = "{}.{}".format(
      compile_context.pipeline_info.pipeline_context_name,
      node.node_info.id)

  # Pre Step 3: Alter graph topology if needed.
  if compile_context.is_async_mode:
    tfx_node_inputs = self._compile_resolver_config(
        compile_context, tfx_node, node)
  else:
    tfx_node_inputs = tfx_node.inputs

  # Step 3: Node inputs
  for key, value in tfx_node_inputs.items():
    input_spec = node.inputs.inputs[key]
    channel = input_spec.channels.add()
    if value.producer_component_id:
      channel.producer_node_query.id = value.producer_component_id

      # Here we rely on pipeline.components to be topologically sorted.
      assert value.producer_component_id in compile_context.node_pbs, (
          "producer component should have already been compiled.")
      producer_pb = compile_context.node_pbs[value.producer_component_id]
      for producer_context in producer_pb.contexts.contexts:
        if (not compiler_utils.is_resolver(tfx_node) or
            producer_context.name.runtime_parameter.name !=
            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
          context_query = channel.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)
    else:
      # Caveat: portable core requires every channel to have at least one
      # Context. But for cases like system nodes and producer-consumer
      # pipelines, a channel may not have contexts at all. In these cases,
      # we want to use the pipeline level context as the input channel
      # context.
      context_query = channel.context_queries.add()
      context_query.type.CopyFrom(pipeline_context_pb.type)
      context_query.name.CopyFrom(pipeline_context_pb.name)

    artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
    channel.artifact_query.type.CopyFrom(artifact_type)
    channel.artifact_query.type.ClearField("properties")

    if value.output_key:
      channel.output_key = value.output_key

    # TODO(b/158712886): Calculate min_count based on whether inputs are
    # optional. min_count = 0 stands for an optional input and 1 stands for
    # a required input.

  # Step 3.1: Special treatment for Resolver node.
  if compiler_utils.is_resolver(tfx_node):
    assert compile_context.is_sync_mode
    node.inputs.resolver_config.resolver_steps.extend(
        _convert_to_resolver_steps(tfx_node))

  # Step 4: Node outputs
  if isinstance(tfx_node, base_component.BaseComponent):
    for key, value in tfx_node.outputs.items():
      output_spec = node.outputs.outputs[key]
      artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
      output_spec.artifact_spec.type.CopyFrom(artifact_type)

      for prop_key, prop_value in value.additional_properties.items():
        _check_property_value_type(prop_key, prop_value,
                                   output_spec.artifact_spec.type)
        data_types_utils.set_metadata_value(
            output_spec.artifact_spec.additional_properties[prop_key]
            .field_value, prop_value)
      for prop_key, prop_value in value.additional_custom_properties.items():
        data_types_utils.set_metadata_value(
            output_spec.artifact_spec.additional_custom_properties[prop_key]
            .field_value, prop_value)

  # TODO(b/170694459): Refactor special nodes as plugins.
  # Step 4.1: Special treatment for Importer node
  if compiler_utils.is_importer(tfx_node):
    self._compile_importer_node_outputs(tfx_node, node)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      # Ignore the following two properties for an importer node, because
      # they are already attached to the artifacts produced by the importer
      # node.
      if compiler_utils.is_importer(tfx_node) and (
          key == importer.PROPERTIES_KEY or
          key == importer.CUSTOM_PROPERTIES_KEY):
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because a runtime parameter can be in a serialized
      # string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      elif isinstance(value, str) and re.search(
          data_types.RUNTIME_PARAMETER_PATTERN, value):
        runtime_param = json.loads(value)
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, runtime_param.name,
            runtime_param.ptype, runtime_param.default)
      else:
        try:
          data_types_utils.set_metadata_value(parameter_value.field_value,
                                              value)
        except ValueError as e:
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}."
              .format(tfx_node.id, key, type(value))) from e

  # Step 6: Executor spec and optional driver spec for components
  if isinstance(tfx_node, base_component.BaseComponent):
    executor_spec = tfx_node.executor_spec.encode(
        component_spec=tfx_node.spec)
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = "{}.{}".format(tfx_node.driver_class.__module__,
                                         tfx_node.driver_class.__name__)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

  # Step 7: Upstream/Downstream nodes
  # Note: the order of tfx_node.upstream_nodes is inconsistent from run to
  # run. We sort them so that the compiler generates consistent results.
  # For ASYNC mode, upstream/downstream node information is not set, as the
  # compiled IR graph topology can differ from that at pipeline authoring
  # time; for example, ResolverNode is removed.
  if compile_context.is_sync_mode:
    node.upstream_nodes.extend(
        sorted(upstream_node.id for upstream_node in tfx_node.upstream_nodes))
    node.downstream_nodes.extend(
        sorted(downstream_node.id
               for downstream_node in tfx_node.downstream_nodes))

  # Step 8: Node execution options
  node.execution_options.caching_options.enable_cache = enable_cache

  # Step 9: Per-node platform config
  if isinstance(tfx_node, base_component.BaseComponent):
    tfx_component = cast(base_component.BaseComponent, tfx_node)
    if tfx_component.platform_config:
      deployment_config.node_level_platform_configs[tfx_node.id].Pack(
          tfx_component.platform_config)

  return node
      }
    }
  }
  outputs {
    key: "output_3"
    value {
      artifact_spec {
        type {
          id: 3
          name: "String"
        }
      }
    }
  }
}
""", pipeline_pb2.PipelineNode())


class OutputUtilsTest(test_case_utils.TfxTest, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    pipeline_runtime_spec = pipeline_pb2.PipelineRuntimeSpec()
    pipeline_runtime_spec.pipeline_root.field_value.string_value = self.tmp_dir
    pipeline_runtime_spec.pipeline_run_id.field_value.string_value = (
        'test_run_0')
    self._pipeline_runtime_spec = pipeline_runtime_spec

  def _output_resolver(self, execution_mode=pipeline_pb2.Pipeline.SYNC):
    return outputs_utils.OutputsResolver(
        pipeline_node=_PIPELINE_NODE,
        pipeline_info=_PIPELINE_INFO,