def testIsResolver(self):
    """Checks is_resolver is true for a Resolver and false for CsvExampleGen."""
    # A node built through the Resolver API must be recognized.
    blessed_model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver)
    is_resolver_node = compiler_utils.is_resolver(blessed_model_resolver)
    self.assertTrue(is_resolver_node)

    # An ordinary component must not be recognized.
    csv_example_gen = CsvExampleGen(input_base="data_path")
    is_resolver_component = compiler_utils.is_resolver(csv_example_gen)
    self.assertFalse(is_resolver_component)
Exemple #2
0
    def testIsResolver(self):
        """Checks is_resolver is true for a ResolverNode, false for a component."""
        # A ResolverNode must be recognized as a resolver.
        blessed_model_resolver = ResolverNode(
            instance_name="test_resolver_name",
            resolver_class=latest_blessed_model_resolver.
            LatestBlessedModelResolver)
        is_resolver_node = compiler_utils.is_resolver(blessed_model_resolver)
        self.assertTrue(is_resolver_node)

        # An ordinary component must not be recognized.
        csv_example_gen = CsvExampleGen(input=external_input("data_path"))
        is_resolver_component = compiler_utils.is_resolver(csv_example_gen)
        self.assertFalse(is_resolver_component)
Exemple #3
0
    def testIsResolver(self):
        """Checks is_resolver for both resolver APIs and for a plain component."""
        # Node built through resolver.Resolver.
        strategy_based_resolver = resolver.Resolver(
            instance_name="test_resolver_name",
            strategy_class=latest_blessed_model_resolver.
            LatestBlessedModelResolver)
        self.assertTrue(compiler_utils.is_resolver(strategy_based_resolver))

        # Node built through legacy_resolver_node.ResolverNode.
        legacy_resolver = legacy_resolver_node.ResolverNode(
            instance_name="test_resolver_name",
            resolver_class=latest_blessed_model_resolver.
            LatestBlessedModelResolver)
        self.assertTrue(compiler_utils.is_resolver(legacy_resolver))

        # An ordinary component must not be recognized as a resolver.
        csv_example_gen = CsvExampleGen(input_base="data_path")
        self.assertFalse(compiler_utils.is_resolver(csv_example_gen))
Exemple #4
0
def _compile_resolver_node(
    resolver_node: base_node.BaseNode,
) -> List[pipeline_pb2.ResolverConfig.ResolverStep]:
  """Compiles a Resolver node into its ResolverSteps.

  Args:
    resolver_node: The node to compile; asserted to be a resolver.

  Returns:
    The steps produced by `_compile_resolver_function` for the node's resolver
    function, each extended with the node's input keys.
  """
  assert compiler_utils.is_resolver(resolver_node)
  typed_node = cast(resolver.Resolver, resolver_node)
  compiled_steps = _compile_resolver_function(typed_node.resolver_function)
  # Every step is given the full set of input keys declared on the node.
  for resolver_step in compiled_steps:
    resolver_step.input_keys.extend(typed_node.inputs.keys())
  return compiled_steps
Exemple #5
0
 def _find_runtime_downstream_node_ids(self, context: _CompilerContext,
                                       here: base_node.BaseNode) -> List[str]:
   """Collects the ids of the nodes that depend on `here`.

   Both explicit and implicit downstream nodes are considered. In async mode
   a downstream resolver node is not reported itself; its own downstream
   nodes are collected in its place.

   Args:
     context: Compiler context supplying the execution mode and the implicit
       downstream nodes.
     here: The node whose dependents are looked up.

   Returns:
     The downstream node ids, de-duplicated and sorted.
   """
   downstream_ids = set()
   candidates = itertools.chain(here.downstream_nodes,
                                context.implicit_downstream_nodes(here))
   for candidate in candidates:
     if not (context.is_async_mode and compiler_utils.is_resolver(candidate)):
       downstream_ids.add(candidate.id)
       continue
     # Look through the resolver to the nodes that depend on it.
     downstream_ids.update(
         self._find_runtime_downstream_node_ids(context, candidate))
   # Sorting keeps the compiler output stable from run to run.
   return sorted(downstream_ids)
Exemple #6
0
def _convert_to_resolver_steps(resolver_node: base_node.BaseNode):
    """Builds one ResolverStep per resolver class/config pair of the node.

    Args:
      resolver_node: The node to convert; asserted to be a resolver.

    Returns:
      A list of ResolverStep protos, in the order yielded by
      `_iterate_resolver_cls_and_config`.
    """
    assert compiler_utils.is_resolver(resolver_node)
    steps = []
    cls_and_configs = _iterate_resolver_cls_and_config(resolver_node)
    for cls, config in cls_and_configs:
        step = pipeline_pb2.ResolverConfig.ResolverStep()
        # Fully qualified class path: "<module>.<class name>".
        step.class_path = "{}.{}".format(cls.__module__, cls.__name__)
        step.config_json = json_utils.dumps(config)
        step.input_keys.extend(resolver_node.inputs.keys())
        steps.append(step)
    return steps
Exemple #7
0
def _convert_to_resolver_steps(resolver_node: base_node.BaseNode):
    """Builds one ResolverStep per strategy class/config pair of the node.

    Args:
      resolver_node: The node to convert; asserted to be a resolver.

    Returns:
      A list of ResolverStep protos, in the order of the node's
      `strategy_class_and_configs`.
    """
    assert compiler_utils.is_resolver(resolver_node)
    typed_node = cast(resolver.Resolver, resolver_node)
    steps = []
    for cls, strategy_config in typed_node.strategy_class_and_configs:
        resolver_step = pipeline_pb2.ResolverConfig.ResolverStep()
        # Fully qualified class path: "<module>.<class name>".
        resolver_step.class_path = "{}.{}".format(cls.__module__, cls.__name__)
        resolver_step.config_json = json_utils.dumps(strategy_config)
        resolver_step.input_keys.extend(typed_node.inputs.keys())
        steps.append(resolver_step)
    return steps
Exemple #8
0
def _iterate_resolver_cls_and_config(resolver_node: base_node.BaseNode):
  """Yields the (resolver class, config) pairs stored on a resolver node.

  The pairs are read from the node's exec_properties, which must hold either
  a single strategy class with its config, or parallel lists of both.

  Args:
    resolver_node: The node to inspect; asserted to be a resolver.

  Yields:
    Tuples of (resolver class, resolver config).

  Raises:
    ValueError: If exec_properties holds neither the single nor the list form.
  """
  assert compiler_utils.is_resolver(resolver_node)
  props = resolver_node.exec_properties

  # Single strategy form takes precedence over the list form.
  if (resolver.RESOLVER_STRATEGY_CLASS in props and
      resolver.RESOLVER_CONFIG in props):
    yield (props[resolver.RESOLVER_STRATEGY_CLASS],
           props[resolver.RESOLVER_CONFIG])
    return

  if not (resolver.RESOLVER_STRATEGY_CLASS_LIST in props and
          resolver.RESOLVER_CONFIG_LIST in props):
    raise ValueError(f"Invalid ResolverNode exec_properties: {props}")

  # List form: classes and configs are paired positionally.
  yield from zip(props[resolver.RESOLVER_STRATEGY_CLASS_LIST],
                 props[resolver.RESOLVER_CONFIG_LIST])
Exemple #9
0
  def compile(self, tfx_pipeline: pipeline.Pipeline) -> pipeline_pb2.Pipeline:
    """Translates a TFX pipeline into its Pipeline IR proto.

    Args:
      tfx_pipeline: The TFX pipeline to compile.

    Returns:
      A Pipeline proto holding the pipeline info, the runtime spec, one entry
      per compiled node and the packed deployment config.
    """
    _validate_pipeline(tfx_pipeline)
    context = _CompilerContext(tfx_pipeline)

    ir = pipeline_pb2.Pipeline()
    ir.pipeline_info.id = context.pipeline_info.pipeline_name
    ir.execution_mode = context.execution_mode

    # Pipeline root is either an encoded placeholder or a runtime parameter.
    if not isinstance(context.pipeline_info.pipeline_root,
                      placeholder.Placeholder):
      compiler_utils.set_runtime_parameter_pb(
          ir.runtime_spec.pipeline_root.runtime_parameter,
          constants.PIPELINE_ROOT_PARAMETER_NAME, str,
          context.pipeline_info.pipeline_root)
    else:
      ir.runtime_spec.pipeline_root.placeholder.CopyFrom(
          context.pipeline_info.pipeline_root.encode())

    # Only SYNC pipelines get a pipeline run id runtime parameter.
    if ir.execution_mode == pipeline_pb2.Pipeline.ExecutionMode.SYNC:
      compiler_utils.set_runtime_parameter_pb(
          ir.runtime_spec.pipeline_run_id.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

    deployment = pipeline_pb2.IntermediateDeploymentConfig()
    if tfx_pipeline.metadata_connection_config:
      deployment.metadata_connection_config.Pack(
          tfx_pipeline.metadata_connection_config)

    for component in tfx_pipeline.components:
      # In ASYNC mode Resolver nodes are merged into the downstream node as a
      # ResolverConfig, so they produce no node of their own.
      if compiler_utils.is_resolver(component) and context.is_async_mode:
        continue
      compiled_node = self._compile_node(component, context, deployment,
                                         tfx_pipeline.enable_cache)
      entry = ir.PipelineOrNode()
      entry.pipeline_node.CopyFrom(compiled_node)
      # TODO(b/158713812): Support sub-pipeline.
      ir.nodes.append(entry)
      # Later nodes look up their producers' compiled protos here.
      context.node_pbs[component.id] = compiled_node

    if tfx_pipeline.platform_config:
      deployment.pipeline_level_platform_config.Pack(
          tfx_pipeline.platform_config)
    ir.deployment_config.Pack(deployment)
    return ir
Exemple #10
0
 def _get_upstream_resolver_nodes(
         self, tfx_node: base_node.BaseNode) -> List[base_node.BaseNode]:
     """Gets all transitive upstream resolver nodes of `tfx_node`.

     The walk starts from the direct upstream nodes of `tfx_node` and only
     continues through resolver nodes: a non-resolver node is neither returned
     nor expanded.

     Args:
       tfx_node: The node whose chain of upstream resolver nodes is collected.

     Returns:
       The reachable upstream resolver nodes in the order they are visited
       (pending nodes are taken from the end of the list). A node reachable
       through several resolvers is enqueued, and therefore returned, only
       once.
       NOTE(review): callers that rely on a strict topological order should
       confirm this visiting order satisfies them.
     """
     result = []
     visit_queue = list(tfx_node.upstream_nodes)
     seen = {node.id for node in visit_queue}
     while visit_queue:
         node = visit_queue.pop()
         if not compiler_utils.is_resolver(node):
             continue
         result.append(node)
         for upstream_node in node.upstream_nodes:
             if upstream_node.id not in seen:
                 # Mark the node being enqueued. Marking the node currently
                 # being expanded would leave the guard above ineffective and
                 # let shared upstream nodes be visited repeatedly.
                 seen.add(upstream_node.id)
                 visit_queue.append(upstream_node)
     return result
Exemple #11
0
    def _compile_node(
        self,
        tfx_node: base_node.BaseNode,
        compile_context: _CompilerContext,
        deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
        enable_cache: bool,
    ) -> pipeline_pb2.PipelineNode:
        """Compiles an individual TFX node into a PipelineNode proto.

        The proto is filled in numbered steps: node info, contexts, inputs
        (plus resolver steps for resolver nodes), outputs, parameters,
        executor/driver specs, upstream/downstream ids, execution options and
        per-node platform config. Executor specs, custom driver specs and
        node-level platform configs are written into `deployment_config`.

        Args:
          tfx_node: A TFX node.
          compile_context: Resources needed to compile the node.
          deployment_config: Intermediate deployment config to set. Will
            include related specs for executors, drivers and platform specific
            configs. Mutated in place.
          enable_cache: whether cache is enabled

        Raises:
          ValueError: When an exec property value is rejected by
            `data_types_utils.set_metadata_value`.
          TypeError: When supplied tfx_node has values of invalid type
            (presumably raised by `_check_property_value_type`, which is not
            visible here).
          AssertionError: When a producer node has not been compiled yet, or a
            resolver node is compiled outside SYNC mode.

        Returns:
          A PipelineNode proto that encodes information of the node.
        """
        node = pipeline_pb2.PipelineNode()

        # Step 1: Node info
        node.node_info.type.name = tfx_node.type
        node.node_info.id = tfx_node.id

        # Step 2: Node Context
        # Context for the pipeline, across pipeline runs.
        pipeline_context_pb = node.contexts.contexts.add()
        pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
        pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

        # Context for the current pipeline run. Only SYNC pipelines have one;
        # its name is a runtime parameter rather than a fixed value.
        if compile_context.is_sync_mode:
            pipeline_run_context_pb = node.contexts.contexts.add()
            pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
            compiler_utils.set_runtime_parameter_pb(
                pipeline_run_context_pb.name.runtime_parameter,
                constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

        # Context for the node, across pipeline runs. Named
        # "<pipeline context name>.<node id>".
        node_context_pb = node.contexts.contexts.add()
        node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
        node_context_pb.name.field_value.string_value = "{}.{}".format(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id)

        # Pre Step 3: Alter graph topology if needed.
        # In ASYNC mode the inputs come from _compile_resolver_config, which
        # is also handed the node proto being built.
        if compile_context.is_async_mode:
            tfx_node_inputs = self._compile_resolver_config(
                compile_context, tfx_node, node)
        else:
            tfx_node_inputs = tfx_node.inputs

        # Step 3: Node inputs
        # Each input key gets exactly one channel entry.
        for key, value in tfx_node_inputs.items():
            input_spec = node.inputs.inputs[key]
            channel = input_spec.channels.add()
            if value.producer_component_id:
                channel.producer_node_query.id = value.producer_component_id

                # Here we rely on pipeline.components to be topologically sorted.
                assert value.producer_component_id in compile_context.node_pbs, (
                    "producer component should have already been compiled.")
                producer_pb = compile_context.node_pbs[
                    value.producer_component_id]
                # Copy the producer's contexts as context queries. For a
                # resolver node, contexts whose runtime parameter name equals
                # PIPELINE_RUN_CONTEXT_TYPE_NAME are skipped.
                # NOTE(review): the run context above is named with
                # PIPELINE_RUN_ID_PARAMETER_NAME, while this filter compares
                # against PIPELINE_RUN_CONTEXT_TYPE_NAME — confirm the two
                # constants agree, otherwise nothing is ever filtered out.
                for producer_context in producer_pb.contexts.contexts:
                    if (not compiler_utils.is_resolver(tfx_node)
                            or producer_context.name.runtime_parameter.name !=
                            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
                        context_query = channel.context_queries.add()
                        context_query.type.CopyFrom(producer_context.type)
                        context_query.name.CopyFrom(producer_context.name)
            else:
                # Caveat: portable core requires every channel to have at least one
                # Context. But for cases like system nodes and producer-consumer
                # pipelines, a channel may not have contexts at all. In these cases,
                # we want to use the pipeline level context as the input channel
                # context.
                context_query = channel.context_queries.add()
                context_query.type.CopyFrom(pipeline_context_pb.type)
                context_query.name.CopyFrom(pipeline_context_pb.name)

            # The artifact query carries the artifact type with its
            # "properties" field cleared.
            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            channel.artifact_query.type.CopyFrom(artifact_type)
            channel.artifact_query.type.ClearField("properties")

            if value.output_key:
                channel.output_key = value.output_key

            # TODO(b/158712886): Calculate min_count based on if inputs are optional.
            # min_count = 0 stands for optional input and 1 stands for required input.

        # Step 3.1: Special treatment for Resolver node.
        # Resolver nodes are only compiled as standalone nodes in SYNC mode.
        if compiler_utils.is_resolver(tfx_node):
            assert compile_context.is_sync_mode
            node.inputs.resolver_config.resolver_steps.extend(
                _convert_to_resolver_steps(tfx_node))

        # Step 4: Node outputs
        # Only BaseComponent nodes have their outputs compiled here.
        if isinstance(tfx_node, base_component.BaseComponent):
            for key, value in tfx_node.outputs.items():
                output_spec = node.outputs.outputs[key]
                artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
                output_spec.artifact_spec.type.CopyFrom(artifact_type)
                # Additional properties are checked against the artifact type
                # before being stored; custom properties are stored unchecked.
                for prop_key, prop_value in value.additional_properties.items(
                ):
                    _check_property_value_type(prop_key, prop_value,
                                               output_spec.artifact_spec.type)
                    data_types_utils.set_metadata_value(
                        output_spec.artifact_spec.
                        additional_properties[prop_key].field_value,
                        prop_value)
                for prop_key, prop_value in value.additional_custom_properties.items(
                ):
                    data_types_utils.set_metadata_value(
                        output_spec.artifact_spec.
                        additional_custom_properties[prop_key].field_value,
                        prop_value)

        # TODO(b/170694459): Refactor special nodes as plugins.
        # Step 4.1: Special treatment for Importer node
        if compiler_utils.is_importer(tfx_node):
            self._compile_importer_node_outputs(tfx_node, node)

        # Step 5: Node parameters
        # Resolver nodes get no parameters; None-valued exec properties are
        # skipped for every other node.
        if not compiler_utils.is_resolver(tfx_node):
            for key, value in tfx_node.exec_properties.items():
                if value is None:
                    continue
                # Ignore following two properties for a importer node, because they are
                # already attached to the artifacts produced by the importer node.
                if compiler_utils.is_importer(tfx_node) and (
                        key == importer.PROPERTIES_KEY
                        or key == importer.CUSTOM_PROPERTIES_KEY):
                    continue
                parameter_value = node.parameters.parameters[key]

                # Order matters, because runtime parameter can be in serialized string.
                if isinstance(value, data_types.RuntimeParameter):
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, value.name,
                        value.ptype, value.default)
                elif isinstance(value, str) and re.search(
                        data_types.RUNTIME_PARAMETER_PATTERN, value):
                    # NOTE(review): the result of `json.loads` is accessed by
                    # attribute (.name/.ptype/.default). If `json` is the
                    # standard library module this yields a dict and the
                    # attribute access fails — confirm which loader is bound
                    # to `json` in this file.
                    runtime_param = json.loads(value)
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, runtime_param.name,
                        runtime_param.ptype, runtime_param.default)
                else:
                    # Re-raise with the component id, key and value type.
                    # NOTE(review): the original exception is not chained
                    # explicitly (`raise ... from`).
                    try:
                        data_types_utils.set_metadata_value(
                            parameter_value.field_value, value)
                    except ValueError:
                        raise ValueError(
                            "Component {} got unsupported parameter {} with type {}."
                            .format(tfx_node.id, key, type(value)))

        # Step 6: Executor spec and optional driver spec for components
        if isinstance(tfx_node, base_component.BaseComponent):
            executor_spec = tfx_node.executor_spec.encode(
                component_spec=tfx_node.spec)
            deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

            # TODO(b/163433174): Remove specialized logic once generalization of
            # driver spec is done.
            # A driver spec is only recorded when the component does not use
            # BaseDriver itself.
            if tfx_node.driver_class != base_driver.BaseDriver:
                driver_class_path = "{}.{}".format(
                    tfx_node.driver_class.__module__,
                    tfx_node.driver_class.__name__)
                driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
                driver_spec.class_path = driver_class_path
                deployment_config.custom_driver_specs[tfx_node.id].Pack(
                    driver_spec)

        # Step 7: Upstream/Downstream nodes
        # Note: the order of tfx_node.upstream_nodes is inconsistent from
        # run to run. We sort them so that compiler generates consistent results.
        # For ASYNC mode upstream/downstream node information is not set as
        # compiled IR graph topology can be different from that on pipeline
        # authoring time; for example ResolverNode is removed.
        if compile_context.is_sync_mode:
            node.upstream_nodes.extend(
                sorted(node.id for node in tfx_node.upstream_nodes))
            node.downstream_nodes.extend(
                sorted(node.id for node in tfx_node.downstream_nodes))

        # Step 8: Node execution options
        node.execution_options.caching_options.enable_cache = enable_cache

        # Step 9: Per-node platform config
        if isinstance(tfx_node, base_component.BaseComponent):
            tfx_component = cast(base_component.BaseComponent, tfx_node)
            if tfx_component.platform_config:
                deployment_config.node_level_platform_configs[
                    tfx_node.id].Pack(tfx_component.platform_config)

        return node
Exemple #12
0
    def _compile_node(
        self, tfx_node: base_node.BaseNode, compile_context: _CompilerContext,
        deployment_config: pipeline_pb2.IntermediateDeploymentConfig
    ) -> pipeline_pb2.PipelineNode:
        """Compiles an individual TFX node into a PipelineNode proto.

        The proto is filled in numbered steps: node info, contexts, inputs
        (plus a resolver policy for resolver nodes), outputs, parameters,
        executor/driver specs and upstream/downstream ids. Executor specs and
        custom driver specs are written into `deployment_config`.

        Args:
          tfx_node: A TFX node.
          compile_context: Resources needed to compile the node.
          deployment_config: Intermediate deployment config to set. Will
            include related specs for executors, drivers and platform specific
            configs. Mutated in place.

        Raises:
          ValueError: When a resolver node uses an unsupported resolver class,
            or an exec property is not a runtime parameter, str, int or float.
          AssertionError: When a producer node has not been compiled yet.

        Returns:
          A PipelineNode proto that encodes information of the node.
        """
        node = pipeline_pb2.PipelineNode()

        # Step 1: Node info
        node.node_info.type.name = tfx_node.type
        node.node_info.id = tfx_node.id

        # Step 2: Node Context
        # Context for the pipeline, across pipeline runs.
        pipeline_context_pb = node.contexts.contexts.add()
        pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
        pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

        # Context for the current pipeline run. Only SYNC pipelines have one;
        # its name is a runtime parameter rather than a fixed value.
        if (compile_context.execution_mode ==
                pipeline_pb2.Pipeline.ExecutionMode.SYNC):
            pipeline_run_context_pb = node.contexts.contexts.add()
            pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
            compiler_utils.set_runtime_parameter_pb(
                pipeline_run_context_pb.name.runtime_parameter,
                constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, str)

        # Context for the node, across pipeline runs. Named
        # "<pipeline context name>.<node id>".
        node_context_pb = node.contexts.contexts.add()
        node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
        node_context_pb.name.field_value.string_value = "{}.{}".format(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id)

        # Step 3: Node inputs
        # Each input key gets exactly one channel entry. A channel without a
        # producer component id gets no context queries at all.
        for key, value in tfx_node.inputs.items():
            input_spec = node.inputs.inputs[key]
            channel = input_spec.channels.add()
            if value.producer_component_id:
                channel.producer_node_query.id = value.producer_component_id

                # Here we rely on pipeline.components to be topologically sorted.
                assert value.producer_component_id in compile_context.node_pbs, (
                    "producer component should have already been compiled.")
                producer_pb = compile_context.node_pbs[
                    value.producer_component_id]
                # Copy the producer's contexts as context queries. For a
                # resolver node the pipeline run context (identified by its
                # runtime parameter name) is left out.
                for producer_context in producer_pb.contexts.contexts:
                    if (not compiler_utils.is_resolver(tfx_node)
                            or producer_context.name.runtime_parameter.name !=
                            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
                        context_query = channel.context_queries.add()
                        context_query.type.CopyFrom(producer_context.type)
                        context_query.name.CopyFrom(producer_context.name)

            # The artifact query carries the artifact type with its
            # "properties" field cleared.
            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            channel.artifact_query.type.CopyFrom(artifact_type)
            channel.artifact_query.type.ClearField("properties")

            if value.output_key:
                channel.output_key = value.output_key

            # TODO(b/158712886): Calculate min_count based on if inputs are optional.
            # min_count = 0 stands for optional input and 1 stands for required input.

        # Step 3.1: Special treatment for Resolver node
        # The resolver class stored in exec_properties is mapped to a
        # ResolverPolicy enum value; only two classes are supported.
        if compiler_utils.is_resolver(tfx_node):
            resolver = tfx_node.exec_properties[resolver_node.RESOLVER_CLASS]
            if resolver == latest_artifacts_resolver.LatestArtifactsResolver:
                node.inputs.resolver_config.resolver_policy = (
                    pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_ARTIFACT)
            elif resolver == latest_blessed_model_resolver.LatestBlessedModelResolver:
                node.inputs.resolver_config.resolver_policy = (
                    pipeline_pb2.ResolverConfig.ResolverPolicy.
                    LATEST_BLESSED_MODEL)
            else:
                # NOTE(review): the message reads `resolver.type`; confirm the
                # resolver class exposes a `type` attribute, otherwise this
                # raises AttributeError instead of the intended ValueError.
                raise ValueError("Got unsupported resolver policy: {}".format(
                    resolver.type))

        # Step 4: Node outputs
        # Only component nodes have their outputs compiled.
        if compiler_utils.is_component(tfx_node):
            for key, value in tfx_node.outputs.items():
                output_spec = node.outputs.outputs[key]
                artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
                output_spec.artifact_spec.type.CopyFrom(artifact_type)

        # Step 5: Node parameters
        # Resolver nodes get no parameters; None-valued exec properties are
        # skipped for every other node.
        if not compiler_utils.is_resolver(tfx_node):
            for key, value in tfx_node.exec_properties.items():
                if value is None:
                    continue
                parameter_value = node.parameters.parameters[key]

                # Order matters, because runtime parameter can be in serialized string.
                if isinstance(value, data_types.RuntimeParameter):
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, value.name,
                        value.ptype, value.default)
                elif isinstance(value, str) and re.search(
                        data_types.RUNTIME_PARAMETER_PATTERN, value):
                    # NOTE(review): the result of `json.loads` is accessed by
                    # attribute (.name/.ptype/.default). If `json` is the
                    # standard library module this yields a dict and the
                    # attribute access fails — confirm which loader is bound
                    # to `json` in this file.
                    runtime_param = json.loads(value)
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, runtime_param.name,
                        runtime_param.ptype, runtime_param.default)
                elif isinstance(value, str):
                    parameter_value.field_value.string_value = value
                # bool values also take this branch, since bool is a subclass
                # of int.
                elif isinstance(value, int):
                    parameter_value.field_value.int_value = value
                elif isinstance(value, float):
                    parameter_value.field_value.double_value = value
                else:
                    raise ValueError(
                        "Component {} got unsupported parameter {} with type {}."
                        .format(tfx_node.id, key, type(value)))

        # Step 6: Executor spec and optional driver spec for components
        if compiler_utils.is_component(tfx_node):
            executor_spec = tfx_node.executor_spec.encode()
            deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

            # TODO(b/163433174): Remove specialized logic once generalization of
            # driver spec is done.
            # A driver spec is only recorded when the component does not use
            # BaseDriver itself.
            if tfx_node.driver_class != base_driver.BaseDriver:
                driver_class_path = "{}.{}".format(
                    tfx_node.driver_class.__module__,
                    tfx_node.driver_class.__name__)
                driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
                driver_spec.class_path = driver_class_path
                deployment_config.custom_driver_specs[tfx_node.id].Pack(
                    driver_spec)

        # Step 7: Upstream/Downstream nodes
        # Note: the order of tfx_node.upstream_nodes is inconsistent from
        # run to run. We sort them so that compiler generates consistent results.
        node.upstream_nodes.extend(
            sorted([
                upstream_component.id
                for upstream_component in tfx_node.upstream_nodes
            ]))
        node.downstream_nodes.extend(
            sorted([
                downstream_component.id
                for downstream_component in tfx_node.downstream_nodes
            ]))

        # Step 8: Node execution options
        # TODO(kennethyang): Add support for node execution options.

        return node
Exemple #13
0
  def _compile_node(
      self,
      tfx_node: base_node.BaseNode,
      compile_context: _CompilerContext,
      deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
      enable_cache: bool,
  ) -> pipeline_pb2.PipelineNode:
    """Compiles an individual TFX node into a PipelineNode proto.

    Args:
      tfx_node: A TFX node.
      compile_context: Resources needed to compile the node.
      deployment_config: Intermediate deployment config to set. Will include
        related specs for executors, drivers and platform specific configs.
      enable_cache: whether cache is enabled

    Raises:
      TypeError: When supplied tfx_node has values of invalid type.

    Returns:
      A PipelineNode proto that encodes information of the node.
    """
    node = pipeline_pb2.PipelineNode()

    # Step 1: Node info
    node.node_info.type.name = tfx_node.type
    if isinstance(tfx_node,
                  base_component.BaseComponent) and tfx_node.type_annotation:
      node.node_info.type.base_type = (
          tfx_node.type_annotation.MLMD_SYSTEM_BASE_TYPE)
    node.node_info.id = tfx_node.id

    # Step 2: Node Context
    # Context for the pipeline, across pipeline runs.
    pipeline_context_pb = node.contexts.contexts.add()
    pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
    pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

    # Context for the current pipeline run.
    if compile_context.is_sync_mode:
      pipeline_run_context_pb = node.contexts.contexts.add()
      pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
      compiler_utils.set_runtime_parameter_pb(
          pipeline_run_context_pb.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

    # Context for the node, across pipeline runs.
    node_context_pb = node.contexts.contexts.add()
    node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
    node_context_pb.name.field_value.string_value = (
        compiler_utils.node_context_name(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id))

    # Pre Step 3: Alter graph topology if needed.
    if compile_context.is_async_mode:
      tfx_node_inputs = self._embed_upstream_resolver_nodes(
          compile_context, tfx_node, node)
    else:
      tfx_node_inputs = tfx_node.inputs

    # Step 3: Node inputs

    # Step 3.1: Generate implicit input channels
    # Step 3.1.1: Conditionals
    implicit_input_channels = {}
    predicates = conditional.get_predicates(tfx_node)
    if predicates:
      implicit_keys_map = {}
      for key, chnl in tfx_node_inputs.items():
        if not isinstance(chnl, types.Channel):
          raise ValueError(
              "Conditional only support using channel as a predicate.")
        implicit_keys_map[compiler_utils.implicit_channel_key(chnl)] = key
      encoded_predicates = []
      for predicate in predicates:
        for chnl in predicate.dependent_channels():
          implicit_key = compiler_utils.implicit_channel_key(chnl)
          if implicit_key not in implicit_keys_map:
            # Store this channel and add it to the node inputs later.
            implicit_input_channels[implicit_key] = chnl
        encoded_predicates.append(
            predicate.encode_with_keys(
                compiler_utils.build_channel_to_key_fn(implicit_keys_map)))
      # In async pipeline, conditional resolver step should be the last step
      # in all resolver steps of a node.
      resolver_step = node.inputs.resolver_config.resolver_steps.add()
      resolver_step.class_path = constants.CONDITIONAL_RESOLVER_CLASS_PATH
      resolver_step.config_json = json_utils.dumps(
          {"predicates": encoded_predicates})

    # Step 3.1.2: Add placeholder exec props to implicit_input_channels
    for key, value in tfx_node.exec_properties.items():
      if isinstance(value, placeholder.ChannelWrappedPlaceholder):
        if not (inspect.isclass(value.channel.type) and
                issubclass(value.channel.type, value_artifact.ValueArtifact)):
          raise ValueError("output channel to dynamic exec properties is not "
                           "ValueArtifact")
        implicit_key = compiler_utils.implicit_channel_key(value.channel)
        implicit_input_channels[implicit_key] = value.channel

    # Step 3.2: Handle ForEach.
    dsl_contexts = context_manager.get_contexts(tfx_node)
    for dsl_context in dsl_contexts:
      if isinstance(dsl_context, for_each.ForEachContext):
        for input_key, channel in tfx_node_inputs.items():
          if (isinstance(channel, types.LoopVarChannel) and
              channel.wrapped is dsl_context.wrapped_channel):
            node.inputs.resolver_config.resolver_steps.extend(
                _compile_for_each_context(input_key))
            break
        else:
          # Ideally should not reach here as the same check is performed at
          # ForEachContext.will_add_node().
          raise ValueError(
              f"Unable to locate ForEach loop variable {dsl_context.channel} "
              f"from inputs of node {tfx_node.id}.")

    # Check loop variable is used outside the ForEach.
    for input_key, channel in tfx_node_inputs.items():
      if isinstance(channel, types.LoopVarChannel):
        dsl_context_ids = {c.id for c in dsl_contexts}
        if channel.context_id not in dsl_context_ids:
          raise ValueError(
              "Loop variable cannot be used outside the ForEach "
              f"(node_id = {tfx_node.id}, input_key = {input_key}).")

    # Step 3.3: Fill node inputs
    for key, value in itertools.chain(tfx_node_inputs.items(),
                                      implicit_input_channels.items()):
      input_spec = node.inputs.inputs[key]
      for input_channel in channel_utils.get_individual_channels(value):
        chnl = input_spec.channels.add()

        # If the node input comes from another node's output, fill the context
        # queries with the producer node's contexts.
        if input_channel in compile_context.node_outputs:
          chnl.producer_node_query.id = input_channel.producer_component_id

          # Here we rely on pipeline.components to be topologically sorted.
          assert input_channel.producer_component_id in compile_context.node_pbs, (
              "producer component should have already been compiled.")
          producer_pb = compile_context.node_pbs[
              input_channel.producer_component_id]
          for producer_context in producer_pb.contexts.contexts:
            context_query = chnl.context_queries.add()
            context_query.type.CopyFrom(producer_context.type)
            context_query.name.CopyFrom(producer_context.name)

        # If the node input does not come from another node's output, fill the
        # context queries based on Channel info. We require every channel to
        # have a pipeline context and fill it in automatically.
        else:
          # Add pipeline context query.
          context_query = chnl.context_queries.add()
          context_query.type.CopyFrom(pipeline_context_pb.type)
          context_query.name.CopyFrom(pipeline_context_pb.name)

          # Optionally add node context query.
          if input_channel.producer_component_id:
            # Add node context query if `producer_component_id` is present.
            chnl.producer_node_query.id = input_channel.producer_component_id
            node_context_query = chnl.context_queries.add()
            node_context_query.type.name = constants.NODE_CONTEXT_TYPE_NAME
            node_context_query.name.field_value.string_value = "{}.{}".format(
                compile_context.pipeline_info.pipeline_context_name,
                input_channel.producer_component_id)

        artifact_type = input_channel.type._get_artifact_type()  # pylint: disable=protected-access
        chnl.artifact_query.type.CopyFrom(artifact_type)
        chnl.artifact_query.type.ClearField("properties")

        if input_channel.output_key:
          chnl.output_key = input_channel.output_key

        # Set NodeInputs.min_count.
        if isinstance(tfx_node, base_component.BaseComponent):
          if key in implicit_input_channels:
            # Mark all input channels as optional for implicit inputs
            # (e.g. conditionals). This is suboptimal, but still a safe guess to
            # avoid breaking the pipeline run.
            input_spec.min_count = 0
          else:
            try:
              # Calculating min_count from ComponentSpec.INPUTS.
              if tfx_node.spec.is_optional_input(key):
                input_spec.min_count = 0
              else:
                input_spec.min_count = 1
            except KeyError:
              # Currently we can end up here if the upstream resolver node's
              # inputs are embedded into the current node (in async mode). We
              # always regard a resolver's inputs as optional.
              if compile_context.is_async_mode:
                input_spec.min_count = 0
              else:
                raise

    # TODO(b/170694459): Refactor special nodes as plugins.
    # Step 3.4: Special treatment for Resolver node.
    if compiler_utils.is_resolver(tfx_node):
      assert compile_context.is_sync_mode
      node.inputs.resolver_config.resolver_steps.extend(
          _compile_resolver_node(tfx_node))

    # Step 4: Node outputs
    for key, value in tfx_node.outputs.items():
      # Register the output in the context.
      compile_context.node_outputs.add(value)
    if (isinstance(tfx_node, base_component.BaseComponent) or
        compiler_utils.is_importer(tfx_node)):
      self._compile_node_outputs(tfx_node, node)

    # Step 5: Node parameters
    if not compiler_utils.is_resolver(tfx_node):
      for key, value in tfx_node.exec_properties.items():
        if value is None:
          continue
        parameter_value = node.parameters.parameters[key]

        # Order matters, because runtime parameter can be in serialized string.
        if isinstance(value, data_types.RuntimeParameter):
          compiler_utils.set_runtime_parameter_pb(
              parameter_value.runtime_parameter, value.name, value.ptype,
              value.default)
        # RuntimeInfoPlaceholder passes Execution parameters of Facade
        # components.
        elif isinstance(value, placeholder.RuntimeInfoPlaceholder):
          parameter_value.placeholder.CopyFrom(value.encode())
        # ChannelWrappedPlaceholder passes dynamic execution parameter.
        elif isinstance(value, placeholder.ChannelWrappedPlaceholder):
          compiler_utils.validate_dynamic_exec_ph_operator(value)
          parameter_value.placeholder.CopyFrom(
              value.encode_with_keys(compiler_utils.implicit_channel_key))
        else:
          try:
            data_types_utils.set_parameter_value(parameter_value, value)
          except ValueError:
            raise ValueError(
                "Component {} got unsupported parameter {} with type {}."
                .format(tfx_node.id, key, type(value))) from ValueError

    # Step 6: Executor spec and optional driver spec for components
    if isinstance(tfx_node, base_component.BaseComponent):
      executor_spec = tfx_node.executor_spec.encode(
          component_spec=tfx_node.spec)
      deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

      # TODO(b/163433174): Remove specialized logic once generalization of
      # driver spec is done.
      if tfx_node.driver_class != base_driver.BaseDriver:
        driver_class_path = _fully_qualified_name(tfx_node.driver_class)
        driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
        driver_spec.class_path = driver_class_path
        deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

    # Step 7: Upstream/Downstream nodes
    node.upstream_nodes.extend(
        self._find_runtime_upstream_node_ids(compile_context, tfx_node))
    node.downstream_nodes.extend(
        self._find_runtime_downstream_node_ids(compile_context, tfx_node))

    # Step 8: Node execution options
    node.execution_options.caching_options.enable_cache = enable_cache

    # Step 9: Per-node platform config
    if isinstance(tfx_node, base_component.BaseComponent):
      tfx_component = cast(base_component.BaseComponent, tfx_node)
      if tfx_component.platform_config:
        deployment_config.node_level_platform_configs[tfx_node.id].Pack(
            tfx_component.platform_config)

    return node