def testIsImporter(self):
  """is_importer() is true for Importer nodes and false for components."""
  schema_importer = importer.Importer(
      source_uri="uri/to/schema", artifact_type=standard_artifacts.Schema)
  self.assertTrue(compiler_utils.is_importer(schema_importer))
  # A regular component such as CsvExampleGen must not be classified as an
  # importer.
  csv_example_gen = CsvExampleGen(input_base="data_path")
  self.assertFalse(compiler_utils.is_importer(csv_example_gen))
def testIsImporter(self):
  """is_importer() is true for ImporterNode and false for components."""
  # Renamed local (was `importer`) so it cannot shadow the importer module.
  schema_importer = ImporterNode(
      instance_name="import_schema",
      source_uri="uri/to/schema",
      artifact_type=standard_artifacts.Schema)
  self.assertTrue(compiler_utils.is_importer(schema_importer))
  # CsvExampleGen is an ordinary component, not an importer.
  csv_example_gen = CsvExampleGen(input=external_input("data_path"))
  self.assertFalse(compiler_utils.is_importer(csv_example_gen))
def testCompileImporterAdditionalPropertyTypeError(self):
  """Compilation fails when an importer property value has a wrong type."""
  test_pipeline = self._get_test_pipeline_definition(iris_pipeline_async)
  # Grab the (first) importer node from the pipeline definition.
  importer_component = next(
      c for c in test_pipeline.components if compiler_utils.is_importer(c))
  # "split_names" is declared as a STRING property; a float must be rejected.
  importer_component.exec_properties[
      importer.PROPERTIES_KEY]["split_names"] = 2.1
  dsl_compiler = compiler.Compiler()
  with self.assertRaisesRegex(TypeError, "Expected STRING but given DOUBLE"):
    dsl_compiler.compile(test_pipeline)
def _compile_node(
    self,
    tfx_node: base_node.BaseNode,
    compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
    enable_cache: bool,
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Builds the node's MLMD contexts, input channels (with context/artifact
  queries), outputs, exec-property parameters, executor/driver specs, and
  upstream/downstream edges.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.
    enable_cache: whether cache is enabled

  Raises:
    TypeError: When supplied tfx_node has values of invalid type.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

  # Context for the current pipeline run. Only sync pipelines carry a
  # per-run context; its name is late-bound through a runtime parameter.
  if compile_context.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = "{}.{}".format(
      compile_context.pipeline_info.pipeline_context_name, node.node_info.id)

  # Pre Step 3: Alter graph topology if needed.
  if compile_context.is_async_mode:
    tfx_node_inputs = self._compile_resolver_config(
        compile_context, tfx_node, node)
  else:
    tfx_node_inputs = tfx_node.inputs

  # Step 3: Node inputs
  for key, value in tfx_node_inputs.items():
    input_spec = node.inputs.inputs[key]
    channel = input_spec.channels.add()
    if value.producer_component_id:
      channel.producer_node_query.id = value.producer_component_id

      # Here we rely on pipeline.components to be topologically sorted.
      assert value.producer_component_id in compile_context.node_pbs, (
          "producer component should have already been compiled.")
      producer_pb = compile_context.node_pbs[
          value.producer_component_id]
      for producer_context in producer_pb.contexts.contexts:
        # NOTE(review): this compares a runtime-parameter *name* against the
        # pipeline-run context *type* name constant. Presumably the intent is
        # to skip the pipeline-run context for resolver node inputs — confirm
        # whether the constant should be PIPELINE_RUN_ID_PARAMETER_NAME.
        if (not compiler_utils.is_resolver(tfx_node) or
            producer_context.name.runtime_parameter.name !=
            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
          context_query = channel.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)
    else:
      # Caveat: portable core requires every channel to have at least one
      # Context. But for cases like system nodes and producer-consumer
      # pipelines, a channel may not have contexts at all. In these cases,
      # we want to use the pipeline level context as the input channel
      # context.
      context_query = channel.context_queries.add()
      context_query.type.CopyFrom(pipeline_context_pb.type)
      context_query.name.CopyFrom(pipeline_context_pb.name)

    artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
    channel.artifact_query.type.CopyFrom(artifact_type)
    # Only query by the artifact type name; property schemas are not part of
    # the query.
    channel.artifact_query.type.ClearField("properties")

    if value.output_key:
      channel.output_key = value.output_key

    # TODO(b/158712886): Calculate min_count based on if inputs are optional.
    # min_count = 0 stands for optional input and 1 stands for required input.

  # Step 3.1: Special treatment for Resolver node.
  if compiler_utils.is_resolver(tfx_node):
    assert compile_context.is_sync_mode
    node.inputs.resolver_config.resolver_steps.extend(
        _convert_to_resolver_steps(tfx_node))

  # Step 4: Node outputs
  if isinstance(tfx_node, base_component.BaseComponent):
    for key, value in tfx_node.outputs.items():
      output_spec = node.outputs.outputs[key]
      artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
      output_spec.artifact_spec.type.CopyFrom(artifact_type)
      # Declared additional properties must match the artifact type's
      # property schema; _check_property_value_type raises TypeError if not.
      for prop_key, prop_value in value.additional_properties.items():
        _check_property_value_type(prop_key, prop_value,
                                   output_spec.artifact_spec.type)
        data_types_utils.set_metadata_value(
            output_spec.artifact_spec.additional_properties[prop_key]
            .field_value, prop_value)
      # Custom properties are free-form, so no type check is performed.
      for prop_key, prop_value in value.additional_custom_properties.items():
        data_types_utils.set_metadata_value(
            output_spec.artifact_spec.additional_custom_properties[prop_key]
            .field_value, prop_value)

  # TODO(b/170694459): Refactor special nodes as plugins.
  # Step 4.1: Special treatment for Importer node
  if compiler_utils.is_importer(tfx_node):
    self._compile_importer_node_outputs(tfx_node, node)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      # Ignore the following two properties for an importer node, because
      # they are already attached to the artifacts produced by the importer
      # node.
      if compiler_utils.is_importer(tfx_node) and (
          key == importer.PROPERTIES_KEY or
          key == importer.CUSTOM_PROPERTIES_KEY):
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because runtime parameter can be in serialized string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      elif isinstance(value, str) and re.search(
          data_types.RUNTIME_PARAMETER_PATTERN, value):
        # A serialized RuntimeParameter embedded in a string.
        runtime_param = json.loads(value)
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, runtime_param.name,
            runtime_param.ptype, runtime_param.default)
      else:
        try:
          data_types_utils.set_metadata_value(
              parameter_value.field_value, value)
        except ValueError:
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}."
              .format(tfx_node.id, key, type(value)))

  # Step 6: Executor spec and optional driver spec for components
  if isinstance(tfx_node, base_component.BaseComponent):
    executor_spec = tfx_node.executor_spec.encode(
        component_spec=tfx_node.spec)
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = "{}.{}".format(
          tfx_node.driver_class.__module__, tfx_node.driver_class.__name__)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(
          driver_spec)

  # Step 7: Upstream/Downstream nodes
  # Note: the order of tfx_node.upstream_nodes is inconsistent from
  # run to run. We sort them so that compiler generates consistent results.
  # For ASYNC mode upstream/downstream node information is not set as
  # compiled IR graph topology can be different from that on pipeline
  # authoring time; for example ResolverNode is removed.
  if compile_context.is_sync_mode:
    # The generator-expression `node` is scoped to the generator itself and
    # does not shadow the outer PipelineNode proto `node`.
    node.upstream_nodes.extend(
        sorted(node.id for node in tfx_node.upstream_nodes))
    node.downstream_nodes.extend(
        sorted(node.id for node in tfx_node.downstream_nodes))

  # Step 8: Node execution options
  node.execution_options.caching_options.enable_cache = enable_cache

  # Step 9: Per-node platform config
  if isinstance(tfx_node, base_component.BaseComponent):
    tfx_component = cast(base_component.BaseComponent, tfx_node)
    if tfx_component.platform_config:
      deployment_config.node_level_platform_configs[
          tfx_node.id].Pack(tfx_component.platform_config)

  return node
def _compile_node(
    self, tfx_node: base_node.BaseNode, compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Builds the node's MLMD contexts, input channels, outputs, exec-property
  parameters, executor/driver specs, and upstream/downstream edges.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

  # Context for the current pipeline run. Only SYNC pipelines carry a
  # per-run context; its name is late-bound through a runtime parameter.
  if (compile_context.execution_mode ==
      pipeline_pb2.Pipeline.ExecutionMode.SYNC):
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = "{}.{}".format(
      compile_context.pipeline_info.pipeline_context_name, node.node_info.id)

  # Step 3: Node inputs
  for key, value in tfx_node.inputs.items():
    input_spec = node.inputs.inputs[key]
    channel = input_spec.channels.add()
    if value.producer_component_id:
      channel.producer_node_query.id = value.producer_component_id

      # Here we rely on pipeline.components to be topologically sorted.
      assert value.producer_component_id in compile_context.node_pbs, (
          "producer component should have already been compiled.")
      producer_pb = compile_context.node_pbs[
          value.producer_component_id]
      for producer_context in producer_pb.contexts.contexts:
        # NOTE(review): this compares a runtime-parameter *name* against the
        # pipeline-run context *type* name constant. Presumably the intent is
        # to skip the pipeline-run context for resolver node inputs — confirm
        # whether the constant should be PIPELINE_RUN_ID_PARAMETER_NAME.
        if (not compiler_utils.is_resolver(tfx_node) or
            producer_context.name.runtime_parameter.name !=
            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
          context_query = channel.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)

    artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
    channel.artifact_query.type.CopyFrom(artifact_type)
    # Only query by the artifact type name; property schemas are not part of
    # the query.
    channel.artifact_query.type.ClearField("properties")

    if value.output_key:
      channel.output_key = value.output_key

    # TODO(b/158712886): Calculate min_count based on if inputs are optional.
    # min_count = 0 stands for optional input and 1 stands for required input.

  # Step 3.1: Special treatment for Resolver node
  if compiler_utils.is_resolver(tfx_node):
    resolver = tfx_node.exec_properties[resolver_node.RESOLVER_CLASS]
    if resolver == latest_artifacts_resolver.LatestArtifactsResolver:
      node.inputs.resolver_config.resolver_policy = (
          pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_ARTIFACT)
    elif resolver == latest_blessed_model_resolver.LatestBlessedModelResolver:
      node.inputs.resolver_config.resolver_policy = (
          pipeline_pb2.ResolverConfig.ResolverPolicy.
          LATEST_BLESSED_MODEL)
    else:
      # NOTE(review): `resolver` appears to be a class here (it is compared
      # against classes above); `resolver.type` looks like it would raise
      # AttributeError while formatting — confirm the intended attribute.
      raise ValueError("Got unsupported resolver policy: {}".format(
          resolver.type))

  # Step 4: Node outputs
  if isinstance(tfx_node, base_component.BaseComponent):
    for key, value in tfx_node.outputs.items():
      output_spec = node.outputs.outputs[key]
      artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
      output_spec.artifact_spec.type.CopyFrom(artifact_type)

  # TODO(b/170694459): Refactor special nodes as plugins.
  # Step 4.1: Special treatment for Importer node
  if compiler_utils.is_importer(tfx_node):
    self._compile_importer_node_outputs(tfx_node, node)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      # Ignore the following two properties for an importer node, because
      # they are already attached to the artifacts produced by the importer
      # node.
      if compiler_utils.is_importer(tfx_node) and (
          key == importer_node.PROPERTIES_KEY or
          key == importer_node.CUSTOM_PROPERTIES_KEY):
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because runtime parameter can be in serialized string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      elif isinstance(value, str) and re.search(
          data_types.RUNTIME_PARAMETER_PATTERN, value):
        # A serialized RuntimeParameter embedded in a string.
        runtime_param = json.loads(value)
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, runtime_param.name,
            runtime_param.ptype, runtime_param.default)
      else:
        try:
          compiler_utils.set_field_value_pb(
              parameter_value.field_value, value)
        except ValueError:
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}."
              .format(tfx_node.id, key, type(value)))

  # Step 6: Executor spec and optional driver spec for components
  if isinstance(tfx_node, base_component.BaseComponent):
    executor_spec = tfx_node.executor_spec.encode()
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = "{}.{}".format(
          tfx_node.driver_class.__module__, tfx_node.driver_class.__name__)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(
          driver_spec)

  # Step 7: Upstream/Downstream nodes
  # Note: the order of tfx_node.upstream_nodes is inconsistent from
  # run to run. We sort them so that compiler generates consistent results.
  node.upstream_nodes.extend(
      sorted([
          upstream_component.id
          for upstream_component in tfx_node.upstream_nodes
      ]))
  node.downstream_nodes.extend(
      sorted([
          downstream_component.id
          for downstream_component in tfx_node.downstream_nodes
      ]))

  # Step 8: Node execution options
  # TODO(kennethyang): Add support for node execution options.

  # Step 9: Per-node platform config
  if isinstance(tfx_node, base_component.BaseComponent):
    tfx_component = cast(base_component.BaseComponent, tfx_node)
    if tfx_component.platform_config:
      deployment_config.node_level_platform_configs[
          tfx_node.id].Pack(tfx_component.platform_config)

  return node
def _compile_node(
    self,
    tfx_node: base_node.BaseNode,
    compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
    enable_cache: bool,
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  Builds the node's MLMD contexts, input channels (including implicit inputs
  from conditionals and placeholder exec properties), ForEach resolver steps,
  outputs, exec-property parameters, executor/driver specs, and
  upstream/downstream edges.

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.
    enable_cache: whether cache is enabled

  Raises:
    TypeError: When supplied tfx_node has values of invalid type.
    ValueError: When a conditional predicate is not channel-based, a dynamic
      exec property is not a ValueArtifact channel, a ForEach loop variable
      cannot be resolved or escapes its ForEach, or a parameter value has an
      unsupported type.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  if isinstance(tfx_node,
                base_component.BaseComponent) and tfx_node.type_annotation:
    # Propagate the MLMD system base type when the component declares one.
    node.node_info.type.base_type = (
        tfx_node.type_annotation.MLMD_SYSTEM_BASE_TYPE)
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

  # Context for the current pipeline run. Only sync pipelines carry a
  # per-run context; its name is late-bound through a runtime parameter.
  if compile_context.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = (
      compiler_utils.node_context_name(
          compile_context.pipeline_info.pipeline_context_name,
          node.node_info.id))

  # Pre Step 3: Alter graph topology if needed.
  if compile_context.is_async_mode:
    # In async mode upstream resolver nodes are folded into this node's
    # resolver config instead of being compiled as standalone nodes.
    tfx_node_inputs = self._embed_upstream_resolver_nodes(
        compile_context, tfx_node, node)
  else:
    tfx_node_inputs = tfx_node.inputs

  # Step 3: Node inputs

  # Step 3.1: Generate implicit input channels

  # Step 3.1.1: Conditionals
  implicit_input_channels = {}
  predicates = conditional.get_predicates(tfx_node)
  if predicates:
    implicit_keys_map = {}
    for key, chnl in tfx_node_inputs.items():
      if not isinstance(chnl, types.Channel):
        raise ValueError(
            "Conditional only support using channel as a predicate.")
      implicit_keys_map[compiler_utils.implicit_channel_key(chnl)] = key
    encoded_predicates = []
    for predicate in predicates:
      for chnl in predicate.dependent_channels():
        implicit_key = compiler_utils.implicit_channel_key(chnl)
        if implicit_key not in implicit_keys_map:
          # Store this channel and add it to the node inputs later.
          implicit_input_channels[implicit_key] = chnl
      encoded_predicates.append(
          predicate.encode_with_keys(
              compiler_utils.build_channel_to_key_fn(implicit_keys_map)))
    # In async pipeline, conditional resolver step should be the last step
    # in all resolver steps of a node.
    resolver_step = node.inputs.resolver_config.resolver_steps.add()
    resolver_step.class_path = constants.CONDITIONAL_RESOLVER_CLASS_PATH
    resolver_step.config_json = json_utils.dumps(
        {"predicates": encoded_predicates})

  # Step 3.1.2: Add placeholder exec props to implicit_input_channels
  # (keys are not needed here, only the placeholder values).
  for value in tfx_node.exec_properties.values():
    if isinstance(value, placeholder.ChannelWrappedPlaceholder):
      if not (inspect.isclass(value.channel.type) and
              issubclass(value.channel.type, value_artifact.ValueArtifact)):
        raise ValueError("output channel to dynamic exec properties is not "
                         "ValueArtifact")
      implicit_key = compiler_utils.implicit_channel_key(value.channel)
      implicit_input_channels[implicit_key] = value.channel

  # Step 3.2: Handle ForEach.
  dsl_contexts = context_manager.get_contexts(tfx_node)
  for dsl_context in dsl_contexts:
    if isinstance(dsl_context, for_each.ForEachContext):
      for input_key, channel in tfx_node_inputs.items():
        if (isinstance(channel, types.LoopVarChannel) and
            channel.wrapped is dsl_context.wrapped_channel):
          node.inputs.resolver_config.resolver_steps.extend(
              _compile_for_each_context(input_key))
          break
      else:
        # Ideally should not reach here as the same check is performed at
        # ForEachContext.will_add_node().
        raise ValueError(
            f"Unable to locate ForEach loop variable {dsl_context.channel} "
            f"from inputs of node {tfx_node.id}.")

  # Check loop variable is used outside the ForEach.
  # The id set is loop-invariant, so build it once (was rebuilt per channel).
  dsl_context_ids = {c.id for c in dsl_contexts}
  for input_key, channel in tfx_node_inputs.items():
    if isinstance(channel, types.LoopVarChannel):
      if channel.context_id not in dsl_context_ids:
        raise ValueError(
            "Loop variable cannot be used outside the ForEach "
            f"(node_id = {tfx_node.id}, input_key = {input_key}).")

  # Step 3.3: Fill node inputs
  for key, value in itertools.chain(tfx_node_inputs.items(),
                                    implicit_input_channels.items()):
    input_spec = node.inputs.inputs[key]
    for input_channel in channel_utils.get_individual_channels(value):
      chnl = input_spec.channels.add()
      # If the node input comes from another node's output, fill the context
      # queries with the producer node's contexts.
      if input_channel in compile_context.node_outputs:
        chnl.producer_node_query.id = input_channel.producer_component_id

        # Here we rely on pipeline.components to be topologically sorted.
        assert input_channel.producer_component_id in compile_context.node_pbs, (
            "producer component should have already been compiled.")
        producer_pb = compile_context.node_pbs[
            input_channel.producer_component_id]
        for producer_context in producer_pb.contexts.contexts:
          context_query = chnl.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)
      # If the node input does not come from another node's output, fill the
      # context queries based on Channel info. We require every channel to
      # have pipeline context and will fill it automatically.
      else:
        # Add pipeline context query.
        context_query = chnl.context_queries.add()
        context_query.type.CopyFrom(pipeline_context_pb.type)
        context_query.name.CopyFrom(pipeline_context_pb.name)

        # Optionally add node context query.
        if input_channel.producer_component_id:
          # Add node context query if `producer_component_id` is present.
          chnl.producer_node_query.id = input_channel.producer_component_id
          node_context_query = chnl.context_queries.add()
          node_context_query.type.name = constants.NODE_CONTEXT_TYPE_NAME
          node_context_query.name.field_value.string_value = "{}.{}".format(
              compile_context.pipeline_info.pipeline_context_name,
              input_channel.producer_component_id)

      artifact_type = input_channel.type._get_artifact_type()  # pylint: disable=protected-access
      chnl.artifact_query.type.CopyFrom(artifact_type)
      # Only query by the artifact type name; property schemas are not part
      # of the query.
      chnl.artifact_query.type.ClearField("properties")

      if input_channel.output_key:
        chnl.output_key = input_channel.output_key

    # Set NodeInputs.min_count.
    if isinstance(tfx_node, base_component.BaseComponent):
      if key in implicit_input_channels:
        # Mark all input channel as optional for implicit inputs
        # (e.g. conditionals). This is suboptimal, but still a safe guess to
        # avoid breaking the pipeline run.
        input_spec.min_count = 0
      else:
        try:
          # Calculating min_count from ComponentSpec.INPUTS.
          if tfx_node.spec.is_optional_input(key):
            input_spec.min_count = 0
          else:
            input_spec.min_count = 1
        except KeyError:
          # Currently we can fall here if the upstream resolver node inputs
          # are embedded into the current node (in async mode). We always
          # regard resolver's inputs as optional.
          if compile_context.is_async_mode:
            input_spec.min_count = 0
          else:
            raise

  # TODO(b/170694459): Refactor special nodes as plugins.
  # Step 3.4: Special treatment for Resolver node.
  if compiler_utils.is_resolver(tfx_node):
    assert compile_context.is_sync_mode
    node.inputs.resolver_config.resolver_steps.extend(
        _compile_resolver_node(tfx_node))

  # Step 4: Node outputs
  for key, value in tfx_node.outputs.items():
    # Register the output in the context so downstream nodes can recognize
    # it as a producer-backed channel in their Step 3.3.
    compile_context.node_outputs.add(value)

  if (isinstance(tfx_node, base_component.BaseComponent) or
      compiler_utils.is_importer(tfx_node)):
    self._compile_node_outputs(tfx_node, node)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      if value is None:
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because runtime parameter can be in serialized string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      # RuntimeInfoPlaceholder passes Execution parameters of Facade
      # components.
      elif isinstance(value, placeholder.RuntimeInfoPlaceholder):
        parameter_value.placeholder.CopyFrom(value.encode())
      # ChannelWrappedPlaceholder passes dynamic execution parameter.
      elif isinstance(value, placeholder.ChannelWrappedPlaceholder):
        compiler_utils.validate_dynamic_exec_ph_operator(value)
        parameter_value.placeholder.CopyFrom(
            value.encode_with_keys(compiler_utils.implicit_channel_key))
      else:
        try:
          data_types_utils.set_parameter_value(parameter_value, value)
        except ValueError as e:
          # Chain from the caught exception instance. The previous code
          # wrote `from ValueError`, which attaches the exception *class*
          # as __cause__ and discards the original error's details.
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}."
              .format(tfx_node.id, key, type(value))) from e

  # Step 6: Executor spec and optional driver spec for components
  if isinstance(tfx_node, base_component.BaseComponent):
    executor_spec = tfx_node.executor_spec.encode(
        component_spec=tfx_node.spec)
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = _fully_qualified_name(tfx_node.driver_class)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

  # Step 7: Upstream/Downstream nodes
  node.upstream_nodes.extend(
      self._find_runtime_upstream_node_ids(compile_context, tfx_node))
  node.downstream_nodes.extend(
      self._find_runtime_downstream_node_ids(compile_context, tfx_node))

  # Step 8: Node execution options
  node.execution_options.caching_options.enable_cache = enable_cache

  # Step 9: Per-node platform config
  if isinstance(tfx_node, base_component.BaseComponent):
    tfx_component = cast(base_component.BaseComponent, tfx_node)
    if tfx_component.platform_config:
      deployment_config.node_level_platform_configs[tfx_node.id].Pack(
          tfx_component.platform_config)

  return node