def testRunExecutor_with_InprocessExecutor(self):
  executor_spec = text_format.Parse(
      """
      class_path: "tfx.orchestration.portable.python_executor_operator_test.InprocessExecutor"
      """, executable_spec_pb2.PythonClassExecutableSpec())
  operator = python_executor_operator.PythonExecutorOperator(executor_spec)
  input_dict = {'input_key': [standard_artifacts.Examples()]}
  output_dict = {'output_key': [standard_artifacts.Model()]}
  exec_properties = {'key': 'value'}
  stateful_working_dir = os.path.join(self.tmp_dir, 'stateful_working_dir')
  executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
  executor_output = operator.run_executor(
      base_executor_operator.ExecutionInfo(
          input_dict=input_dict,
          output_dict=output_dict,
          exec_properties=exec_properties,
          stateful_working_dir=stateful_working_dir,
          executor_output_uri=executor_output_uri))
  self.assertProtoPartiallyEquals(
      """
      execution_properties {
        key: "key"
        value {
          string_value: "value"
        }
      }
      output_artifacts {
        key: "output_key"
        value {
          artifacts {
          }
        }
      }""", executor_output)
def testRunExecutor_with_InplaceUpdateExecutor(self):
  executor_spec = text_format.Parse(
      """
      class_path: "tfx.orchestration.portable.python_executor_operator_test.InplaceUpdateExecutor"
      """, executable_spec_pb2.PythonClassExecutableSpec())
  operator = python_executor_operator.PythonExecutorOperator(executor_spec)
  input_dict = {'input_key': [standard_artifacts.Examples()]}
  output_dict = {'output_key': [standard_artifacts.Model()]}
  exec_properties = {
      'string': 'value',
      'int': 1,
      'float': 0.0,
      # This should not happen in production and will be dropped.
      'proto': execution_result_pb2.ExecutorOutput()
  }
  executor_output = operator.run_executor(
      self._get_execution_info(input_dict, output_dict, exec_properties))
  self.assertProtoPartiallyEquals(
      """
      output_artifacts {
        key: "output_key"
        value {
          artifacts {
            custom_properties {
              key: "name"
              value {
                string_value: "MyPipeline.MyPythonNode.my_model"
              }
            }
          }
        }
      }""", executor_output)
def testSucceed(self):
  custom_driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
  custom_driver_spec.class_path = (
      'tfx.orchestration.portable.python_driver_operator._FakeNoopDriver')
  driver_operator = python_driver_operator.PythonDriverOperator(
      custom_driver_spec, None, None, None)
  driver_output = driver_operator.run_driver(None, None, None)
  self.assertEqual(driver_output, _DEFAULT_DRIVER_OUTPUT)
def testRunExecutor_with_InprocessExecutor(self):
  executor_spec = text_format.Parse(
      """
      class_path: "tfx.orchestration.portable.python_executor_operator_test.InprocessExecutor"
      """, executable_spec_pb2.PythonClassExecutableSpec())
  operator = python_executor_operator.PythonExecutorOperator(executor_spec)
  input_dict = {'input_key': [standard_artifacts.Examples()]}
  output_dict = {'output_key': [standard_artifacts.Model()]}
  exec_properties = {'key': 'value'}
  executor_output = operator.run_executor(
      self._get_execution_info(input_dict, output_dict, exec_properties))
  self.assertProtoPartiallyEquals(
      """
      execution_properties {
        key: "key"
        value {
          string_value: "value"
        }
      }
      output_artifacts {
        key: "output_key"
        value {
          artifacts {
          }
        }
      }""", executor_output)
def encode(
    self,
    component_spec: Optional[types.ComponentSpec] = None) -> message.Message:
  result = executable_spec_pb2.PythonClassExecutableSpec()
  result.class_path = self.class_path
  result.extra_flags.extend(self.extra_flags)
  return result
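# Hedged usage sketch for the `encode` method above. `_StubExecutorSpec` is a
# hypothetical stand-in for the spec class that owns `encode`; only the
# `class_path` and `extra_flags` attributes are assumed from the snippet.
from google.protobuf import text_format
from tfx.proto.orchestration import executable_spec_pb2


class _StubExecutorSpec:

  def __init__(self, class_path, extra_flags=None):
    self.class_path = class_path
    self.extra_flags = list(extra_flags or [])

  def encode(self, component_spec=None):
    result = executable_spec_pb2.PythonClassExecutableSpec()
    result.class_path = self.class_path
    result.extra_flags.extend(self.extra_flags)
    return result


# The encoded proto can be inspected as textproto; output is roughly:
#   class_path: "my.module.MyExecutor"
#   extra_flags: "--flag=value"
encoded = _StubExecutorSpec('my.module.MyExecutor', ['--flag=value']).encode()
print(text_format.MessageToString(encoded))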
def testExecutableSpecSerialization(self):
  python_executable_spec = text_format.Parse(
      """
      class_path: 'path_to_my_class'
      extra_flags: '--flag=my_flag'
      """, executable_spec_pb2.PythonClassExecutableSpec())
  python_serialized = python_execution_binary_utils.serialize_executable_spec(
      python_executable_spec)
  python_rehydrated = python_execution_binary_utils.deserialize_executable_spec(
      python_serialized)
  self.assertProtoEquals(python_rehydrated, python_executable_spec)

  beam_executable_spec = text_format.Parse(
      """
      python_executor_spec {
        class_path: 'path_to_my_class'
        extra_flags: '--flag1=1'
      }
      beam_pipeline_args: '--arg=my_beam_pipeline_arg'
      """, executable_spec_pb2.BeamExecutableSpec())
  beam_serialized = python_execution_binary_utils.serialize_executable_spec(
      beam_executable_spec)
  beam_rehydrated = python_execution_binary_utils.deserialize_executable_spec(
      beam_serialized, with_beam=True)
  self.assertProtoEquals(beam_rehydrated, beam_executable_spec)
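# Sketch of the underlying round trip (an assumption, not necessarily how
# python_execution_binary_utils implements it): standard proto binary
# serialization is enough to rehydrate a PythonClassExecutableSpec, which is
# the property the test above verifies end to end.
from tfx.proto.orchestration import executable_spec_pb2

spec = executable_spec_pb2.PythonClassExecutableSpec(
    class_path='path_to_my_class', extra_flags=['--flag=my_flag'])
wire_bytes = spec.SerializeToString()
rehydrated = executable_spec_pb2.PythonClassExecutableSpec()
rehydrated.ParseFromString(wire_bytes)
assert rehydrated == spec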
def testRunExecutor_with_InplaceUpdateExecutor(self):
  executor_spec = text_format.Parse(
      """
      class_path: "tfx.orchestration.portable.python_executor_operator_test.InplaceUpdateExecutor"
      """, executable_spec_pb2.PythonClassExecutableSpec())
  operator = python_executor_operator.PythonExecutorOperator(executor_spec)
  input_dict = {'input_key': [standard_artifacts.Examples()]}
  output_dict = {'output_key': [standard_artifacts.Model()]}
  exec_properties = {
      'string': 'value',
      'int': 1,
      'float': 0.0,
      # This should not happen in production and will be dropped.
      'proto': execution_result_pb2.ExecutorOutput()
  }
  stateful_working_dir = os.path.join(self.tmp_dir, 'stateful_working_dir')
  executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
  executor_output = operator.run_executor(
      base_executor_operator.ExecutionInfo(
          input_dict=input_dict,
          output_dict=output_dict,
          exec_properties=exec_properties,
          stateful_working_dir=stateful_working_dir,
          executor_output_uri=executor_output_uri))
  self.assertProtoPartiallyEquals(
      """
      execution_properties {
        key: "float"
        value {
          double_value: 0.0
        }
      }
      execution_properties {
        key: "int"
        value {
          int_value: 1
        }
      }
      execution_properties {
        key: "string"
        value {
          string_value: "value"
        }
      }
      output_artifacts {
        key: "output_key"
        value {
          artifacts {
            custom_properties {
              key: "name"
              value {
                string_value: "my_model"
              }
            }
          }
        }
      }""", executor_output)
def testGetCacheContextTwiceDifferentExecutorSpec(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    self._get_cache_context(m)
    self._get_cache_context(
        m,
        executor_spec=text_format.Parse(
            """
            class_path: "new.class.path"
            """, executable_spec_pb2.PythonClassExecutableSpec()))
    # A different executor spec will result in a new cache context.
    self.assertLen(m.store.get_contexts(), 2)
def setUp(self):
  super(PlaceholderUtilsTest, self).setUp()
  examples = [standard_artifacts.Examples()]
  examples[0].uri = "/tmp"
  examples[0].split_names = artifact_utils.encode_split_names(
      ["train", "eval"])
  self._serving_spec = infra_validator_pb2.ServingSpec()
  self._serving_spec.tensorflow_serving.tags.extend(["latest", "1.15.0-gpu"])
  self._resolution_context = placeholder_utils.ResolutionContext(
      exec_info=data_types.ExecutionInfo(
          input_dict={
              "model": [standard_artifacts.Model()],
              "examples": examples,
          },
          output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
          exec_properties={
              "proto_property":
                  json_format.MessageToJson(
                      message=self._serving_spec,
                      sort_keys=True,
                      preserving_proto_field_name=True,
                      indent=0)
          },
          execution_output_uri="test_executor_output_uri",
          stateful_working_dir="test_stateful_working_dir",
          pipeline_node=pipeline_pb2.PipelineNode(
              node_info=pipeline_pb2.NodeInfo(
                  type=metadata_store_pb2.ExecutionType(
                      name="infra_validator"))),
          pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")),
      executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
          class_path="test_class_path"),
  )
  # Resolution context to simulate missing optional values.
  self._none_resolution_context = placeholder_utils.ResolutionContext(
      exec_info=data_types.ExecutionInfo(
          input_dict={
              "model": [],
              "examples": [],
          },
          output_dict={"blessing": []},
          exec_properties={},
          pipeline_node=pipeline_pb2.PipelineNode(
              node_info=pipeline_pb2.NodeInfo(
                  type=metadata_store_pb2.ExecutionType(
                      name="infra_validator"))),
          pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")),
      executor_spec=None,
      platform_config=None)
def testRunExecutorWithBeamPipelineArgs(self):
  executor_spec = text_format.Parse(
      """
      class_path: "tfx.orchestration.portable.python_executor_operator_test.ValidateBeamPipelineArgsExecutor"
      extra_flags: "--runner=DirectRunner"
      """, executable_spec_pb2.PythonClassExecutableSpec())
  operator = python_executor_operator.PythonExecutorOperator(executor_spec)
  executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
  operator.run_executor(
      data_types.ExecutionInfo(
          input_dict={},
          output_dict={},
          exec_properties={},
          execution_output_uri=executor_output_uri))
def setUp(self):
  super().setUp()
  self._connection_config = metadata_store_pb2.ConnectionConfig()
  self._connection_config.sqlite.SetInParent()
  self._module_file_path = os.path.join(self.tmp_dir, 'module_file')
  self._input_artifacts = {'input_examples': [standard_artifacts.Examples()]}
  self._output_artifacts = {'output_models': [standard_artifacts.Model()]}
  self._parameters = {'module_file': self._module_file_path}
  self._module_file_content = 'module content'
  self._pipeline_node = text_format.Parse(
      """
      node_info {
        id: "my_id"
      }
      """, pipeline_pb2.PipelineNode())
  self._pipeline_info = pipeline_pb2.PipelineInfo(id='pipeline_id')
  self._executor_spec = text_format.Parse(
      """
      class_path: "my.class.path"
      """, executable_spec_pb2.PythonClassExecutableSpec())
def _compile_node(
    self,
    tfx_node: base_node.BaseNode,
    compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
    enable_cache: bool,
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.
    enable_cache: Whether cache is enabled.

  Raises:
    TypeError: When supplied tfx_node has values of invalid type.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = (
      compile_context.pipeline_info.pipeline_context_name)

  # Context for the current pipeline run.
  if compile_context.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = (
        constants.PIPELINE_RUN_CONTEXT_TYPE_NAME)
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = "{}.{}".format(
      compile_context.pipeline_info.pipeline_context_name, node.node_info.id)

  # Pre Step 3: Alter graph topology if needed.
  if compile_context.is_async_mode:
    tfx_node_inputs = self._compile_resolver_config(
        compile_context, tfx_node, node)
  else:
    tfx_node_inputs = tfx_node.inputs

  # Step 3: Node inputs
  for key, value in tfx_node_inputs.items():
    input_spec = node.inputs.inputs[key]
    channel = input_spec.channels.add()
    if value.producer_component_id:
      channel.producer_node_query.id = value.producer_component_id

      # Here we rely on pipeline.components to be topologically sorted.
      assert value.producer_component_id in compile_context.node_pbs, (
          "producer component should have already been compiled.")
      producer_pb = compile_context.node_pbs[value.producer_component_id]
      for producer_context in producer_pb.contexts.contexts:
        if (not compiler_utils.is_resolver(tfx_node) or
            producer_context.name.runtime_parameter.name !=
            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
          context_query = channel.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)
    else:
      # Caveat: portable core requires every channel to have at least one
      # context, but for cases like system nodes and producer-consumer
      # pipelines, a channel may not have contexts at all. In these cases,
      # we want to use the pipeline level context as the input channel
      # context.
      context_query = channel.context_queries.add()
      context_query.type.CopyFrom(pipeline_context_pb.type)
      context_query.name.CopyFrom(pipeline_context_pb.name)

    artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
    channel.artifact_query.type.CopyFrom(artifact_type)
    channel.artifact_query.type.ClearField("properties")

    if value.output_key:
      channel.output_key = value.output_key

    # TODO(b/158712886): Calculate min_count based on if inputs are optional.
    # min_count = 0 stands for optional input and 1 stands for required input.

  # Step 3.1: Special treatment for Resolver node.
  if compiler_utils.is_resolver(tfx_node):
    assert compile_context.is_sync_mode
    node.inputs.resolver_config.resolver_steps.extend(
        _convert_to_resolver_steps(tfx_node))

  # Step 4: Node outputs
  if isinstance(tfx_node, base_component.BaseComponent):
    for key, value in tfx_node.outputs.items():
      output_spec = node.outputs.outputs[key]
      artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
      output_spec.artifact_spec.type.CopyFrom(artifact_type)
      for prop_key, prop_value in value.additional_properties.items():
        _check_property_value_type(prop_key, prop_value,
                                   output_spec.artifact_spec.type)
        data_types_utils.set_metadata_value(
            output_spec.artifact_spec.additional_properties[prop_key]
            .field_value, prop_value)
      for prop_key, prop_value in value.additional_custom_properties.items():
        data_types_utils.set_metadata_value(
            output_spec.artifact_spec.additional_custom_properties[prop_key]
            .field_value, prop_value)

  # TODO(b/170694459): Refactor special nodes as plugins.
  # Step 4.1: Special treatment for Importer node
  if compiler_utils.is_importer(tfx_node):
    self._compile_importer_node_outputs(tfx_node, node)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      # Ignore the following two properties for an importer node, because
      # they are already attached to the artifacts produced by the importer
      # node.
      if compiler_utils.is_importer(tfx_node) and (
          key == importer.PROPERTIES_KEY or
          key == importer.CUSTOM_PROPERTIES_KEY):
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because runtime parameter can be in serialized string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      elif isinstance(value, str) and re.search(
          data_types.RUNTIME_PARAMETER_PATTERN, value):
        runtime_param = json.loads(value)
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, runtime_param.name,
            runtime_param.ptype, runtime_param.default)
      else:
        try:
          data_types_utils.set_metadata_value(parameter_value.field_value,
                                              value)
        except ValueError:
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}."
              .format(tfx_node.id, key, type(value)))

  # Step 6: Executor spec and optional driver spec for components
  if isinstance(tfx_node, base_component.BaseComponent):
    executor_spec = tfx_node.executor_spec.encode(
        component_spec=tfx_node.spec)
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = "{}.{}".format(tfx_node.driver_class.__module__,
                                         tfx_node.driver_class.__name__)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

  # Step 7: Upstream/Downstream nodes
  # Note: the order of tfx_node.upstream_nodes is inconsistent from run to
  # run. We sort them so that the compiler generates consistent results.
  # For ASYNC mode, upstream/downstream node information is not set, as the
  # compiled IR graph topology can be different from that at pipeline
  # authoring time; for example, ResolverNode is removed.
  if compile_context.is_sync_mode:
    node.upstream_nodes.extend(
        sorted(upstream_node.id for upstream_node in tfx_node.upstream_nodes))
    node.downstream_nodes.extend(
        sorted(downstream_node.id
               for downstream_node in tfx_node.downstream_nodes))

  # Step 8: Node execution options
  node.execution_options.caching_options.enable_cache = enable_cache

  # Step 9: Per-node platform config
  if isinstance(tfx_node, base_component.BaseComponent):
    tfx_component = cast(base_component.BaseComponent, tfx_node)
    if tfx_component.platform_config:
      deployment_config.node_level_platform_configs[tfx_node.id].Pack(
          tfx_component.platform_config)

  return node
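# Tiny illustration of the driver class-path string assembled in Step 6 of
# `_compile_node` above; `FakeDriver` is a hypothetical driver class.
class FakeDriver:
  pass


driver_class_path = "{}.{}".format(FakeDriver.__module__, FakeDriver.__name__)
print(driver_class_path)  # e.g. "__main__.FakeDriver" when run as a script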
def _compile_node(
    self, tfx_node: base_node.BaseNode, compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = (
      compile_context.pipeline_info.pipeline_context_name)

  # Context for the current pipeline run.
  if (compile_context.execution_mode ==
      pipeline_pb2.Pipeline.ExecutionMode.SYNC):
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = (
        constants.PIPELINE_RUN_CONTEXT_TYPE_NAME)
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = "{}.{}".format(
      compile_context.pipeline_info.pipeline_context_name, node.node_info.id)

  # Step 3: Node inputs
  for key, value in tfx_node.inputs.items():
    input_spec = node.inputs.inputs[key]
    channel = input_spec.channels.add()
    if value.producer_component_id:
      channel.producer_node_query.id = value.producer_component_id

      # Here we rely on pipeline.components to be topologically sorted.
      assert value.producer_component_id in compile_context.node_pbs, (
          "producer component should have already been compiled.")
      producer_pb = compile_context.node_pbs[value.producer_component_id]
      for producer_context in producer_pb.contexts.contexts:
        if (not compiler_utils.is_resolver(tfx_node) or
            producer_context.name.runtime_parameter.name !=
            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
          context_query = channel.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)

    artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
    channel.artifact_query.type.CopyFrom(artifact_type)
    channel.artifact_query.type.ClearField("properties")

    if value.output_key:
      channel.output_key = value.output_key

    # TODO(b/158712886): Calculate min_count based on if inputs are optional.
    # min_count = 0 stands for optional input and 1 stands for required input.

  # Step 3.1: Special treatment for Resolver node
  if compiler_utils.is_resolver(tfx_node):
    resolver = tfx_node.exec_properties[resolver_node.RESOLVER_CLASS]
    if resolver == latest_artifacts_resolver.LatestArtifactsResolver:
      node.inputs.resolver_config.resolver_policy = (
          pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_ARTIFACT)
    elif resolver == latest_blessed_model_resolver.LatestBlessedModelResolver:
      node.inputs.resolver_config.resolver_policy = (
          pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_BLESSED_MODEL)
    else:
      raise ValueError(
          "Got unsupported resolver policy: {}".format(resolver))

  # Step 4: Node outputs
  if compiler_utils.is_component(tfx_node):
    for key, value in tfx_node.outputs.items():
      output_spec = node.outputs.outputs[key]
      artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
      output_spec.artifact_spec.type.CopyFrom(artifact_type)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because runtime parameter can be in serialized string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      elif isinstance(value, str) and re.search(
          data_types.RUNTIME_PARAMETER_PATTERN, value):
        runtime_param = json.loads(value)
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, runtime_param.name,
            runtime_param.ptype, runtime_param.default)
      elif isinstance(value, str):
        parameter_value.field_value.string_value = value
      elif isinstance(value, int):
        parameter_value.field_value.int_value = value
      elif isinstance(value, float):
        parameter_value.field_value.double_value = value
      else:
        raise ValueError(
            "Component {} got unsupported parameter {} with type {}.".format(
                tfx_node.id, key, type(value)))

  # Step 6: Executor spec and optional driver spec for components
  if compiler_utils.is_component(tfx_node):
    executor_spec = tfx_node.executor_spec.encode()
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = "{}.{}".format(tfx_node.driver_class.__module__,
                                         tfx_node.driver_class.__name__)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

  # Step 7: Upstream/Downstream nodes
  # Note: the order of tfx_node.upstream_nodes is inconsistent from run to
  # run. We sort them so that the compiler generates consistent results.
  node.upstream_nodes.extend(
      sorted([
          upstream_component.id
          for upstream_component in tfx_node.upstream_nodes
      ]))
  node.downstream_nodes.extend(
      sorted([
          downstream_component.id
          for downstream_component in tfx_node.downstream_nodes
      ]))

  # Step 8: Node execution options
  # TODO(kennethyang): Add support for node execution options.

  return node
def encode(self) -> message.Message:
  result = executable_spec_pb2.PythonClassExecutableSpec()
  result.class_path = self.class_path
  result.extra_flags.extend(self.extra_flags)
  return result
def _compile_node(
    self,
    tfx_node: base_node.BaseNode,
    compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
    enable_cache: bool,
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.
    enable_cache: Whether cache is enabled.

  Raises:
    TypeError: When supplied tfx_node has values of invalid type.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  if (isinstance(tfx_node, base_component.BaseComponent) and
      tfx_node.type_annotation):
    node.node_info.type.base_type = (
        tfx_node.type_annotation.MLMD_SYSTEM_BASE_TYPE)
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = (
      compile_context.pipeline_info.pipeline_context_name)

  # Context for the current pipeline run.
  if compile_context.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = (
        constants.PIPELINE_RUN_CONTEXT_TYPE_NAME)
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = (
      compiler_utils.node_context_name(
          compile_context.pipeline_info.pipeline_context_name,
          node.node_info.id))

  # Pre Step 3: Alter graph topology if needed.
  if compile_context.is_async_mode:
    tfx_node_inputs = self._embed_upstream_resolver_nodes(
        compile_context, tfx_node, node)
  else:
    tfx_node_inputs = tfx_node.inputs

  # Step 3: Node inputs
  # Step 3.1: Generate implicit input channels
  # Step 3.1.1: Conditionals
  implicit_input_channels = {}
  predicates = conditional.get_predicates(tfx_node)
  if predicates:
    implicit_keys_map = {}
    for key, chnl in tfx_node_inputs.items():
      if not isinstance(chnl, types.Channel):
        raise ValueError(
            "Conditional only supports using a channel as a predicate.")
      implicit_keys_map[compiler_utils.implicit_channel_key(chnl)] = key
    encoded_predicates = []
    for predicate in predicates:
      for chnl in predicate.dependent_channels():
        implicit_key = compiler_utils.implicit_channel_key(chnl)
        if implicit_key not in implicit_keys_map:
          # Store this channel and add it to the node inputs later.
          implicit_input_channels[implicit_key] = chnl
      encoded_predicates.append(
          predicate.encode_with_keys(
              compiler_utils.build_channel_to_key_fn(implicit_keys_map)))
    # In async pipeline, conditional resolver step should be the last step
    # in all resolver steps of a node.
    resolver_step = node.inputs.resolver_config.resolver_steps.add()
    resolver_step.class_path = constants.CONDITIONAL_RESOLVER_CLASS_PATH
    resolver_step.config_json = json_utils.dumps(
        {"predicates": encoded_predicates})

  # Step 3.1.2: Add placeholder exec props to implicit_input_channels
  for key, value in tfx_node.exec_properties.items():
    if isinstance(value, placeholder.ChannelWrappedPlaceholder):
      if not (inspect.isclass(value.channel.type) and
              issubclass(value.channel.type, value_artifact.ValueArtifact)):
        raise ValueError(
            "Output channel to dynamic exec properties is not ValueArtifact")
      implicit_key = compiler_utils.implicit_channel_key(value.channel)
      implicit_input_channels[implicit_key] = value.channel

  # Step 3.2: Handle ForEach.
  dsl_contexts = context_manager.get_contexts(tfx_node)
  for dsl_context in dsl_contexts:
    if isinstance(dsl_context, for_each.ForEachContext):
      for input_key, channel in tfx_node_inputs.items():
        if (isinstance(channel, types.LoopVarChannel) and
            channel.wrapped is dsl_context.wrapped_channel):
          node.inputs.resolver_config.resolver_steps.extend(
              _compile_for_each_context(input_key))
          break
      else:
        # Ideally we should not reach here, as the same check is performed
        # at ForEachContext.will_add_node().
        raise ValueError(
            f"Unable to locate ForEach loop variable {dsl_context.channel} "
            f"from inputs of node {tfx_node.id}.")

  # Check that no loop variable is used outside its ForEach.
  for input_key, channel in tfx_node_inputs.items():
    if isinstance(channel, types.LoopVarChannel):
      dsl_context_ids = {c.id for c in dsl_contexts}
      if channel.context_id not in dsl_context_ids:
        raise ValueError(
            "Loop variable cannot be used outside the ForEach "
            f"(node_id = {tfx_node.id}, input_key = {input_key}).")

  # Step 3.3: Fill node inputs
  for key, value in itertools.chain(tfx_node_inputs.items(),
                                    implicit_input_channels.items()):
    input_spec = node.inputs.inputs[key]
    for input_channel in channel_utils.get_individual_channels(value):
      chnl = input_spec.channels.add()

      # If the node input comes from another node's output, fill the context
      # queries with the producer node's contexts.
      if input_channel in compile_context.node_outputs:
        chnl.producer_node_query.id = input_channel.producer_component_id

        # Here we rely on pipeline.components to be topologically sorted.
        assert input_channel.producer_component_id in compile_context.node_pbs, (
            "producer component should have already been compiled.")
        producer_pb = compile_context.node_pbs[
            input_channel.producer_component_id]
        for producer_context in producer_pb.contexts.contexts:
          context_query = chnl.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)
      # If the node input does not come from another node's output, fill the
      # context queries based on Channel info. We require every channel to
      # have a pipeline context and will fill it automatically.
      else:
        # Add pipeline context query.
        context_query = chnl.context_queries.add()
        context_query.type.CopyFrom(pipeline_context_pb.type)
        context_query.name.CopyFrom(pipeline_context_pb.name)

        # Optionally add node context query.
        if input_channel.producer_component_id:
          # Add node context query if `producer_component_id` is present.
          chnl.producer_node_query.id = input_channel.producer_component_id
          node_context_query = chnl.context_queries.add()
          node_context_query.type.name = constants.NODE_CONTEXT_TYPE_NAME
          node_context_query.name.field_value.string_value = "{}.{}".format(
              compile_context.pipeline_info.pipeline_context_name,
              input_channel.producer_component_id)

      artifact_type = input_channel.type._get_artifact_type()  # pylint: disable=protected-access
      chnl.artifact_query.type.CopyFrom(artifact_type)
      chnl.artifact_query.type.ClearField("properties")

      if input_channel.output_key:
        chnl.output_key = input_channel.output_key

    # Set NodeInputs.min_count.
    if isinstance(tfx_node, base_component.BaseComponent):
      if key in implicit_input_channels:
        # Mark all input channels as optional for implicit inputs
        # (e.g. conditionals). This is suboptimal, but still a safe guess to
        # avoid breaking the pipeline run.
        input_spec.min_count = 0
      else:
        try:
          # Calculate min_count based on ComponentSpec.INPUTS.
          if tfx_node.spec.is_optional_input(key):
            input_spec.min_count = 0
          else:
            input_spec.min_count = 1
        except KeyError:
          # We can currently fall here if the upstream resolver node inputs
          # are embedded into the current node (in async mode). We always
          # regard a resolver's inputs as optional.
          if compile_context.is_async_mode:
            input_spec.min_count = 0
          else:
            raise

  # TODO(b/170694459): Refactor special nodes as plugins.
  # Step 3.4: Special treatment for Resolver node.
  if compiler_utils.is_resolver(tfx_node):
    assert compile_context.is_sync_mode
    node.inputs.resolver_config.resolver_steps.extend(
        _compile_resolver_node(tfx_node))

  # Step 4: Node outputs
  for key, value in tfx_node.outputs.items():
    # Register the output in the context.
    compile_context.node_outputs.add(value)

  if (isinstance(tfx_node, base_component.BaseComponent) or
      compiler_utils.is_importer(tfx_node)):
    self._compile_node_outputs(tfx_node, node)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because runtime parameter can be in serialized string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      # RuntimeInfoPlaceholder passes Execution parameters of Facade
      # components.
      elif isinstance(value, placeholder.RuntimeInfoPlaceholder):
        parameter_value.placeholder.CopyFrom(value.encode())
      # ChannelWrappedPlaceholder passes a dynamic execution parameter.
      elif isinstance(value, placeholder.ChannelWrappedPlaceholder):
        compiler_utils.validate_dynamic_exec_ph_operator(value)
        parameter_value.placeholder.CopyFrom(
            value.encode_with_keys(compiler_utils.implicit_channel_key))
      else:
        try:
          data_types_utils.set_parameter_value(parameter_value, value)
        except ValueError as e:
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}."
              .format(tfx_node.id, key, type(value))) from e

  # Step 6: Executor spec and optional driver spec for components
  if isinstance(tfx_node, base_component.BaseComponent):
    executor_spec = tfx_node.executor_spec.encode(
        component_spec=tfx_node.spec)
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = _fully_qualified_name(tfx_node.driver_class)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

  # Step 7: Upstream/Downstream nodes
  node.upstream_nodes.extend(
      self._find_runtime_upstream_node_ids(compile_context, tfx_node))
  node.downstream_nodes.extend(
      self._find_runtime_downstream_node_ids(compile_context, tfx_node))

  # Step 8: Node execution options
  node.execution_options.caching_options.enable_cache = enable_cache

  # Step 9: Per-node platform config
  if isinstance(tfx_node, base_component.BaseComponent):
    tfx_component = cast(base_component.BaseComponent, tfx_node)
    if tfx_component.platform_config:
      deployment_config.node_level_platform_configs[tfx_node.id].Pack(
          tfx_component.platform_config)

  return node
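# Illustrative sketch (not TFX source) of how a spec packed into the
# deployment config in Step 6 can be recovered downstream, using the standard
# google.protobuf Any semantics that `Pack` relies on. The node id 'MyNode'
# and class path are made up.
from tfx.proto.orchestration import executable_spec_pb2
from tfx.proto.orchestration import pipeline_pb2

deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
driver_spec = executable_spec_pb2.PythonClassExecutableSpec(
    class_path='my.module.MyDriver')
deployment_config.custom_driver_specs['MyNode'].Pack(driver_spec)

# An orchestrator can later check the payload type and unpack it.
packed = deployment_config.custom_driver_specs['MyNode']
assert packed.Is(executable_spec_pb2.PythonClassExecutableSpec.DESCRIPTOR)
unpacked = executable_spec_pb2.PythonClassExecutableSpec()
packed.Unpack(unpacked)
assert unpacked.class_path == 'my.module.MyDriver'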