Example #1
def testRunExecutor_with_InprocessExecutor(self):
   executor_spec = text_format.Parse(
       """
     class_path: "tfx.orchestration.portable.python_executor_operator_test.InprocessExecutor"
   """, executable_spec_pb2.PythonClassExecutableSpec())
   operator = python_executor_operator.PythonExecutorOperator(executor_spec)
   input_dict = {'input_key': [standard_artifacts.Examples()]}
   output_dict = {'output_key': [standard_artifacts.Model()]}
   exec_properties = {'key': 'value'}
   stateful_working_dir = os.path.join(self.tmp_dir, 'stateful_working_dir')
   executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
   executor_output = operator.run_executor(
       base_executor_operator.ExecutionInfo(
           input_dict=input_dict,
           output_dict=output_dict,
           exec_properties=exec_properties,
           stateful_working_dir=stateful_working_dir,
           executor_output_uri=executor_output_uri))
   self.assertProtoPartiallyEquals("""
         execution_properties {
           key: "key"
           value {
             string_value: "value"
           }
         }
         output_artifacts {
           key: "output_key"
           value {
             artifacts {
             }
           }
         }""", executor_output)
Example #2
 def testRunExecutor_with_InplaceUpdateExecutor(self):
     executor_spec = text_format.Parse(
         """
   class_path: "tfx.orchestration.portable.python_executor_operator_test.InplaceUpdateExecutor"
 """, executable_spec_pb2.PythonClassExecutableSpec())
     operator = python_executor_operator.PythonExecutorOperator(
         executor_spec)
     input_dict = {'input_key': [standard_artifacts.Examples()]}
     output_dict = {'output_key': [standard_artifacts.Model()]}
     exec_properties = {
         'string': 'value',
         'int': 1,
         'float': 0.0,
         # This should not happen in production and will be
         # dropped.
         'proto': execution_result_pb2.ExecutorOutput()
     }
     executor_output = operator.run_executor(
         self._get_execution_info(input_dict, output_dict, exec_properties))
     self.assertProtoPartiallyEquals(
         """
       output_artifacts {
         key: "output_key"
         value {
           artifacts {
             custom_properties {
               key: "name"
               value {
                 string_value: "MyPipeline.MyPythonNode.my_model"
               }
             }
           }
         }
       }""", executor_output)
Example #3
 def succeed(self):
   custom_driver_spec = (executable_spec_pb2.PythonClassExecutableSpec())
   custom_driver_spec.class_path = 'tfx.orchestration.portable.python_driver_operator._FakeNoopDriver'
   driver_operator = python_driver_operator.PythonDriverOperator(
       custom_driver_spec, None, None, None)
   driver_output = driver_operator.run_driver(None, None, None)
   self.assertEqual(driver_output, _DEFAULT_DRIVER_OUTPUT)
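The _FakeNoopDriver and _DEFAULT_DRIVER_OUTPUT fixtures referenced above are
not shown on this page. A plausible sketch, assuming the fake driver ignores
its arguments and returns a canned DriverOutput proto:

from tfx.proto.orchestration import driver_output_pb2

_DEFAULT_DRIVER_OUTPUT = driver_output_pb2.DriverOutput()


class _FakeNoopDriver:
  """Assumed no-op driver fixture."""

  def __init__(self, mlmd_connection):
    del mlmd_connection  # Unused by the fake.

  def run(self, *args):
    del args  # Unused by the fake.
    return _DEFAULT_DRIVER_OUTPUT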
Example #4
 def testRunExecutor_with_InprocessExecutor(self):
     executor_spec = text_format.Parse(
         """
   class_path: "tfx.orchestration.portable.python_executor_operator_test.InprocessExecutor"
 """, executable_spec_pb2.PythonClassExecutableSpec())
     operator = python_executor_operator.PythonExecutorOperator(
         executor_spec)
     input_dict = {'input_key': [standard_artifacts.Examples()]}
     output_dict = {'output_key': [standard_artifacts.Model()]}
     exec_properties = {'key': 'value'}
     executor_output = operator.run_executor(
         self._get_execution_info(input_dict, output_dict, exec_properties))
     self.assertProtoPartiallyEquals(
         """
       execution_properties {
         key: "key"
         value {
           string_value: "value"
         }
       }
       output_artifacts {
         key: "output_key"
         value {
           artifacts {
           }
         }
       }""", executor_output)
Example #5
 def encode(
     self,
     component_spec: Optional[types.ComponentSpec] = None) -> message.Message:
   result = executable_spec_pb2.PythonClassExecutableSpec()
   result.class_path = self.class_path
   result.extra_flags.extend(self.extra_flags)
   return result
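A short usage sketch for the encode() above (the spec object name is
illustrative): the returned PythonClassExecutableSpec proto is what the
compiler packs into a google.protobuf.Any inside the deployment config.

from google.protobuf import any_pb2

encoded = my_executor_spec.encode()  # my_executor_spec: hypothetical spec.
packed = any_pb2.Any()
packed.Pack(encoded)
assert packed.Is(executable_spec_pb2.PythonClassExecutableSpec.DESCRIPTOR)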
Example #6
  def testExecutableSpecSerialization(self):
    python_executable_spec = text_format.Parse(
        """
        class_path: 'path_to_my_class'
        extra_flags: '--flag=my_flag'
        """, executable_spec_pb2.PythonClassExecutableSpec())
    python_serialized = python_execution_binary_utils.serialize_executable_spec(
        python_executable_spec)
    python_rehydrated = python_execution_binary_utils.deserialize_executable_spec(
        python_serialized)
    self.assertProtoEquals(python_rehydrated, python_executable_spec)

    beam_executable_spec = text_format.Parse(
        """
        python_executor_spec {
          class_path: 'path_to_my_class'
          extra_flags: '--flag1=1'
        }
        beam_pipeline_args: '--arg=my_beam_pipeline_arg'
        """, executable_spec_pb2.BeamExecutableSpec())
    beam_serialized = python_execution_binary_utils.serialize_executable_spec(
        beam_executable_spec)
    beam_rehydrated = python_execution_binary_utils.deserialize_executable_spec(
        beam_serialized, with_beam=True)
    self.assertProtoEquals(beam_rehydrated, beam_executable_spec)
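The serialize/deserialize helpers above are opaque here. One plausible
implementation, assuming the spec travels as a base64-encoded binary proto on
a command-line flag (the real python_execution_binary_utils may differ):

import base64

from tfx.proto.orchestration import executable_spec_pb2


def serialize_executable_spec(executable_spec) -> str:
  # Binary-serialize the proto, then base64-encode it for flag passing.
  return base64.b64encode(executable_spec.SerializeToString()).decode('ascii')


def deserialize_executable_spec(serialized: str, with_beam: bool = False):
  # Rehydrate into the proto type indicated by the caller.
  result = (executable_spec_pb2.BeamExecutableSpec()
            if with_beam else executable_spec_pb2.PythonClassExecutableSpec())
  result.ParseFromString(base64.b64decode(serialized))
  return result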
Example #7
 def testRunExecutor_with_InplaceUpdateExecutor(self):
     executor_spec = text_format.Parse(
         """
   class_path: "tfx.orchestration.portable.python_executor_operator_test.InplaceUpdateExecutor"
 """, executable_spec_pb2.PythonClassExecutableSpec())
     operator = python_executor_operator.PythonExecutorOperator(
         executor_spec)
     input_dict = {'input_key': [standard_artifacts.Examples()]}
     output_dict = {'output_key': [standard_artifacts.Model()]}
     exec_properties = {
         'string': 'value',
         'int': 1,
         'float': 0.0,
         # This should not happen in production and will be
         # dropped.
         'proto': execution_result_pb2.ExecutorOutput()
     }
     stateful_working_dir = os.path.join(self.tmp_dir,
                                         'stateful_working_dir')
     executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
     executor_output = operator.run_executor(
         base_executor_operator.ExecutionInfo(
             input_dict=input_dict,
             output_dict=output_dict,
             exec_properties=exec_properties,
             stateful_working_dir=stateful_working_dir,
             executor_output_uri=executor_output_uri))
     self.assertProtoPartiallyEquals(
         """
       execution_properties {
         key: "float"
         value {
           double_value: 0.0
         }
       }
       execution_properties {
         key: "int"
         value {
           int_value: 1
         }
       }
       execution_properties {
         key: "string"
         value {
           string_value: "value"
         }
       }
       output_artifacts {
         key: "output_key"
         value {
           artifacts {
             custom_properties {
               key: "name"
               value {
                 string_value: "my_model"
               }
             }
           }
         }
       }""", executor_output)
Example #8
 def testGetCacheContextTwiceDifferentExecutorSpec(self):
     with metadata.Metadata(connection_config=self._connection_config) as m:
         self._get_cache_context(m)
         self._get_cache_context(m,
                                 executor_spec=text_format.Parse(
                                     """
           class_path: "new.class.path"
           """, executable_spec_pb2.PythonClassExecutableSpec()))
         # A different executor spec results in a new cache context.
         self.assertLen(m.store.get_contexts(), 2)
Example #9
 def setUp(self):
     super(PlaceholderUtilsTest, self).setUp()
     examples = [standard_artifacts.Examples()]
     examples[0].uri = "/tmp"
     examples[0].split_names = artifact_utils.encode_split_names(
         ["train", "eval"])
     self._serving_spec = infra_validator_pb2.ServingSpec()
     self._serving_spec.tensorflow_serving.tags.extend(
         ["latest", "1.15.0-gpu"])
     self._resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [standard_artifacts.Model()],
                 "examples": examples,
             },
             output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
             exec_properties={
                 "proto_property":
                 json_format.MessageToJson(message=self._serving_spec,
                                           sort_keys=True,
                                           preserving_proto_field_name=True,
                                           indent=0)
             },
             execution_output_uri="test_executor_output_uri",
             stateful_working_dir="test_stateful_working_dir",
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
             class_path="test_class_path"),
     )
     # Resolution context to simulate missing optional values.
     self._none_resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [],
                 "examples": [],
             },
             output_dict={"blessing": []},
             exec_properties={},
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=None,
         platform_config=None)
Example #10
 def testRunExecutorWithBeamPipelineArgs(self):
   executor_spec = text_format.Parse(
       """
     class_path: "tfx.orchestration.portable.python_executor_operator_test.ValidateBeamPipelineArgsExecutor"
     extra_flags: "--runner=DirectRunner"
   """, executable_spec_pb2.PythonClassExecutableSpec())
   operator = python_executor_operator.PythonExecutorOperator(executor_spec)
   executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
   operator.run_executor(
       data_types.ExecutionInfo(
           input_dict={},
           output_dict={},
           exec_properties={},
           execution_output_uri=executor_output_uri))
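The ValidateBeamPipelineArgsExecutor fixture is also not shown. A hypothetical
shape, assuming extra_flags are forwarded to the executor as Beam pipeline
args (the base class and attribute are assumptions, not taken from this page):

from tfx.dsl.components.base import base_beam_executor


class ValidateBeamPipelineArgsExecutor(base_beam_executor.BaseBeamExecutor):
  """Assumed fixture: checks that extra_flags reached the Beam args."""

  def Do(self, input_dict, output_dict, exec_properties) -> None:
    assert '--runner=DirectRunner' in self._beam_pipeline_args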
Example #11
 def setUp(self):
     super().setUp()
     self._connection_config = metadata_store_pb2.ConnectionConfig()
     self._connection_config.sqlite.SetInParent()
     self._module_file_path = os.path.join(self.tmp_dir, 'module_file')
     self._input_artifacts = {
         'input_examples': [standard_artifacts.Examples()]
     }
     self._output_artifacts = {
         'output_models': [standard_artifacts.Model()]
     }
     self._parameters = {'module_file': self._module_file_path}
     self._module_file_content = 'module content'
     self._pipeline_node = text_format.Parse(
         """
     node_info {
       id: "my_id"
     }
     """, pipeline_pb2.PipelineNode())
     self._pipeline_info = pipeline_pb2.PipelineInfo(id='pipeline_id')
     self._executor_spec = text_format.Parse(
         """
     class_path: "my.class.path"
     """, executable_spec_pb2.PythonClassExecutableSpec())
Example #12
    def _compile_node(
        self,
        tfx_node: base_node.BaseNode,
        compile_context: _CompilerContext,
        deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
        enable_cache: bool,
    ) -> pipeline_pb2.PipelineNode:
        """Compiles an individual TFX node into a PipelineNode proto.

    Args:
      tfx_node: A TFX node.
      compile_context: Resources needed to compile the node.
      deployment_config: Intermediate deployment config to set. Will include
        related specs for executors, drivers, and platform-specific configs.
      enable_cache: Whether cache is enabled.

    Raises:
      TypeError: When supplied tfx_node has values of invalid type.

    Returns:
      A PipelineNode proto that encodes information of the node.
    """
        node = pipeline_pb2.PipelineNode()

        # Step 1: Node info
        node.node_info.type.name = tfx_node.type
        node.node_info.id = tfx_node.id

        # Step 2: Node Context
        # Context for the pipeline, across pipeline runs.
        pipeline_context_pb = node.contexts.contexts.add()
        pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
        pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

        # Context for the current pipeline run.
        if compile_context.is_sync_mode:
            pipeline_run_context_pb = node.contexts.contexts.add()
            pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
            compiler_utils.set_runtime_parameter_pb(
                pipeline_run_context_pb.name.runtime_parameter,
                constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

        # Context for the node, across pipeline runs.
        node_context_pb = node.contexts.contexts.add()
        node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
        node_context_pb.name.field_value.string_value = "{}.{}".format(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id)

        # Pre Step 3: Alter graph topology if needed.
        if compile_context.is_async_mode:
            tfx_node_inputs = self._compile_resolver_config(
                compile_context, tfx_node, node)
        else:
            tfx_node_inputs = tfx_node.inputs

        # Step 3: Node inputs
        for key, value in tfx_node_inputs.items():
            input_spec = node.inputs.inputs[key]
            channel = input_spec.channels.add()
            if value.producer_component_id:
                channel.producer_node_query.id = value.producer_component_id

                # Here we rely on pipeline.components to be topologically sorted.
                assert value.producer_component_id in compile_context.node_pbs, (
                    "producer component should have already been compiled.")
                producer_pb = compile_context.node_pbs[
                    value.producer_component_id]
                for producer_context in producer_pb.contexts.contexts:
                    if (not compiler_utils.is_resolver(tfx_node)
                            or producer_context.name.runtime_parameter.name !=
                            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
                        context_query = channel.context_queries.add()
                        context_query.type.CopyFrom(producer_context.type)
                        context_query.name.CopyFrom(producer_context.name)
            else:
                # Caveat: the portable core requires every channel to have at
                # least one Context. But for cases like system nodes and
                # producer-consumer pipelines, a channel may not have contexts
                # at all. In these cases, we want to use the pipeline-level
                # context as the input channel context.
                context_query = channel.context_queries.add()
                context_query.type.CopyFrom(pipeline_context_pb.type)
                context_query.name.CopyFrom(pipeline_context_pb.name)

            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            channel.artifact_query.type.CopyFrom(artifact_type)
            channel.artifact_query.type.ClearField("properties")

            if value.output_key:
                channel.output_key = value.output_key

            # TODO(b/158712886): Calculate min_count based on if inputs are optional.
            # min_count = 0 stands for optional input and 1 stands for required input.

        # Step 3.1: Special treatment for Resolver node.
        if compiler_utils.is_resolver(tfx_node):
            assert compile_context.is_sync_mode
            node.inputs.resolver_config.resolver_steps.extend(
                _convert_to_resolver_steps(tfx_node))

        # Step 4: Node outputs
        if isinstance(tfx_node, base_component.BaseComponent):
            for key, value in tfx_node.outputs.items():
                output_spec = node.outputs.outputs[key]
                artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
                output_spec.artifact_spec.type.CopyFrom(artifact_type)
                for prop_key, prop_value in value.additional_properties.items(
                ):
                    _check_property_value_type(prop_key, prop_value,
                                               output_spec.artifact_spec.type)
                    data_types_utils.set_metadata_value(
                        output_spec.artifact_spec.
                        additional_properties[prop_key].field_value,
                        prop_value)
                for prop_key, prop_value in value.additional_custom_properties.items(
                ):
                    data_types_utils.set_metadata_value(
                        output_spec.artifact_spec.
                        additional_custom_properties[prop_key].field_value,
                        prop_value)

        # TODO(b/170694459): Refactor special nodes as plugins.
        # Step 4.1: Special treatment for the Importer node
        if compiler_utils.is_importer(tfx_node):
            self._compile_importer_node_outputs(tfx_node, node)

        # Step 5: Node parameters
        if not compiler_utils.is_resolver(tfx_node):
            for key, value in tfx_node.exec_properties.items():
                if value is None:
                    continue
                # Ignore the following two properties for an importer node,
                # because they are already attached to the artifacts produced
                # by the importer node.
                if compiler_utils.is_importer(tfx_node) and (
                        key == importer.PROPERTIES_KEY
                        or key == importer.CUSTOM_PROPERTIES_KEY):
                    continue
                parameter_value = node.parameters.parameters[key]

                # Order matters, because a runtime parameter can arrive as a serialized string.
                if isinstance(value, data_types.RuntimeParameter):
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, value.name,
                        value.ptype, value.default)
                elif isinstance(value, str) and re.search(
                        data_types.RUNTIME_PARAMETER_PATTERN, value):
                    runtime_param = json.loads(value)
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, runtime_param.name,
                        runtime_param.ptype, runtime_param.default)
                else:
                    try:
                        data_types_utils.set_metadata_value(
                            parameter_value.field_value, value)
                    except ValueError:
                        raise ValueError(
                            "Component {} got unsupported parameter {} with type {}."
                            .format(tfx_node.id, key, type(value)))

        # Step 6: Executor spec and optional driver spec for components
        if isinstance(tfx_node, base_component.BaseComponent):
            executor_spec = tfx_node.executor_spec.encode(
                component_spec=tfx_node.spec)
            deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

            # TODO(b/163433174): Remove specialized logic once generalization of
            # driver spec is done.
            if tfx_node.driver_class != base_driver.BaseDriver:
                driver_class_path = "{}.{}".format(
                    tfx_node.driver_class.__module__,
                    tfx_node.driver_class.__name__)
                driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
                driver_spec.class_path = driver_class_path
                deployment_config.custom_driver_specs[tfx_node.id].Pack(
                    driver_spec)

        # Step 7: Upstream/Downstream nodes
        # Note: the order of tfx_node.upstream_nodes is inconsistent from
        # run to run. We sort them so that the compiler generates consistent
        # results. In ASYNC mode, upstream/downstream node information is not
        # set, because the compiled IR graph topology can differ from the
        # topology at pipeline authoring time; for example, ResolverNode is
        # removed.
        if compile_context.is_sync_mode:
            node.upstream_nodes.extend(
                sorted(node.id for node in tfx_node.upstream_nodes))
            node.downstream_nodes.extend(
                sorted(node.id for node in tfx_node.downstream_nodes))

        # Step 8: Node execution options
        node.execution_options.caching_options.enable_cache = enable_cache

        # Step 9: Per-node platform config
        if isinstance(tfx_node, base_component.BaseComponent):
            tfx_component = cast(base_component.BaseComponent, tfx_node)
            if tfx_component.platform_config:
                deployment_config.node_level_platform_configs[
                    tfx_node.id].Pack(tfx_component.platform_config)

        return node
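For context on Step 6 above: deployment_config.executor_specs maps node id to
a packed google.protobuf.Any, which an orchestrator later unpacks back into a
concrete spec. An illustrative sketch (not the actual launcher code):

def _unpack_python_class_spec(deployment_config, node_id):
  # The map value is a google.protobuf.Any; Is()/Unpack() recover the spec.
  packed = deployment_config.executor_specs[node_id]
  spec = executable_spec_pb2.PythonClassExecutableSpec()
  if packed.Is(spec.DESCRIPTOR):
    packed.Unpack(spec)
  return spec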
Example #13
    def _compile_node(
        self, tfx_node: base_node.BaseNode, compile_context: _CompilerContext,
        deployment_config: pipeline_pb2.IntermediateDeploymentConfig
    ) -> pipeline_pb2.PipelineNode:
        """Compiles an individual TFX node into a PipelineNode proto.

    Args:
      tfx_node: A TFX node.
      compile_context: Resources needed to compile the node.
      deployment_config: Intermediate deployment config to set. Will include
        related specs for executors, drivers, and platform-specific configs.

    Returns:
      A PipelineNode proto that encodes information of the node.
    """
        node = pipeline_pb2.PipelineNode()

        # Step 1: Node info
        node.node_info.type.name = tfx_node.type
        node.node_info.id = tfx_node.id

        # Step 2: Node Context
        # Context for the pipeline, across pipeline runs.
        pipeline_context_pb = node.contexts.contexts.add()
        pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
        pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

        # Context for the current pipeline run.
        if (compile_context.execution_mode ==
                pipeline_pb2.Pipeline.ExecutionMode.SYNC):
            pipeline_run_context_pb = node.contexts.contexts.add()
            pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
            compiler_utils.set_runtime_parameter_pb(
                pipeline_run_context_pb.name.runtime_parameter,
                constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, str)

        # Context for the node, across pipeline runs.
        node_context_pb = node.contexts.contexts.add()
        node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
        node_context_pb.name.field_value.string_value = "{}.{}".format(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id)

        # Step 3: Node inputs
        for key, value in tfx_node.inputs.items():
            input_spec = node.inputs.inputs[key]
            channel = input_spec.channels.add()
            if value.producer_component_id:
                channel.producer_node_query.id = value.producer_component_id

                # Here we rely on pipeline.components to be topologically sorted.
                assert value.producer_component_id in compile_context.node_pbs, (
                    "producer component should have already been compiled.")
                producer_pb = compile_context.node_pbs[
                    value.producer_component_id]
                for producer_context in producer_pb.contexts.contexts:
                    if (not compiler_utils.is_resolver(tfx_node)
                            or producer_context.name.runtime_parameter.name !=
                            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
                        context_query = channel.context_queries.add()
                        context_query.type.CopyFrom(producer_context.type)
                        context_query.name.CopyFrom(producer_context.name)

            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            channel.artifact_query.type.CopyFrom(artifact_type)
            channel.artifact_query.type.ClearField("properties")

            if value.output_key:
                channel.output_key = value.output_key

            # TODO(b/158712886): Calculate min_count based on if inputs are optional.
            # min_count = 0 stands for optional input and 1 stands for required input.

        # Step 3.1: Special treatment for Resolver node
        if compiler_utils.is_resolver(tfx_node):
            resolver = tfx_node.exec_properties[resolver_node.RESOLVER_CLASS]
            if resolver == latest_artifacts_resolver.LatestArtifactsResolver:
                node.inputs.resolver_config.resolver_policy = (
                    pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_ARTIFACT)
            elif resolver == latest_blessed_model_resolver.LatestBlessedModelResolver:
                node.inputs.resolver_config.resolver_policy = (
                    pipeline_pb2.ResolverConfig.ResolverPolicy.
                    LATEST_BLESSED_MODEL)
            else:
                raise ValueError("Got unsupported resolver policy: {}".format(
                    resolver.type))

        # Step 4: Node outputs
        if compiler_utils.is_component(tfx_node):
            for key, value in tfx_node.outputs.items():
                output_spec = node.outputs.outputs[key]
                artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
                output_spec.artifact_spec.type.CopyFrom(artifact_type)

        # Step 5: Node parameters
        if not compiler_utils.is_resolver(tfx_node):
            for key, value in tfx_node.exec_properties.items():
                if value is None:
                    continue
                parameter_value = node.parameters.parameters[key]

                # Order matters, because a runtime parameter can arrive as a serialized string.
                if isinstance(value, data_types.RuntimeParameter):
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, value.name,
                        value.ptype, value.default)
                elif isinstance(value, str) and re.search(
                        data_types.RUNTIME_PARAMETER_PATTERN, value):
                    runtime_param = json.loads(value)
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, runtime_param.name,
                        runtime_param.ptype, runtime_param.default)
                elif isinstance(value, str):
                    parameter_value.field_value.string_value = value
                elif isinstance(value, int):
                    parameter_value.field_value.int_value = value
                elif isinstance(value, float):
                    parameter_value.field_value.double_value = value
                else:
                    raise ValueError(
                        "Component {} got unsupported parameter {} with type {}."
                        .format(tfx_node.id, key, type(value)))

        # Step 6: Executor spec and optional driver spec for components
        if compiler_utils.is_component(tfx_node):
            executor_spec = tfx_node.executor_spec.encode()
            deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

            # TODO(b/163433174): Remove specialized logic once generalization of
            # driver spec is done.
            if tfx_node.driver_class != base_driver.BaseDriver:
                driver_class_path = "{}.{}".format(
                    tfx_node.driver_class.__module__,
                    tfx_node.driver_class.__name__)
                driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
                driver_spec.class_path = driver_class_path
                deployment_config.custom_driver_specs[tfx_node.id].Pack(
                    driver_spec)

        # Step 7: Upstream/Downstream nodes
        # Note: the order of tfx_node.upstream_nodes is inconsistent from
        # run to run. We sort them so that the compiler generates consistent
        # results.
        node.upstream_nodes.extend(
            sorted([
                upstream_component.id
                for upstream_component in tfx_node.upstream_nodes
            ]))
        node.downstream_nodes.extend(
            sorted([
                downstream_component.id
                for downstream_component in tfx_node.downstream_nodes
            ]))

        # Step 8: Node execution options
        # TODO(kennethyang): Add support for node execution options.

        return node
Example #14
 def encode(self) -> message.Message:
   result = executable_spec_pb2.PythonClassExecutableSpec()
   result.class_path = self.class_path
   result.extra_flags.extend(self.extra_flags)
   return result
Example #15
  def _compile_node(
      self,
      tfx_node: base_node.BaseNode,
      compile_context: _CompilerContext,
      deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
      enable_cache: bool,
  ) -> pipeline_pb2.PipelineNode:
    """Compiles an individual TFX node into a PipelineNode proto.

    Args:
      tfx_node: A TFX node.
      compile_context: Resources needed to compile the node.
      deployment_config: Intermediate deployment config to set. Will include
        related specs for executors, drivers, and platform-specific configs.
      enable_cache: Whether cache is enabled.

    Raises:
      TypeError: When supplied tfx_node has values of invalid type.

    Returns:
      A PipelineNode proto that encodes information of the node.
    """
    node = pipeline_pb2.PipelineNode()

    # Step 1: Node info
    node.node_info.type.name = tfx_node.type
    if isinstance(tfx_node,
                  base_component.BaseComponent) and tfx_node.type_annotation:
      node.node_info.type.base_type = (
          tfx_node.type_annotation.MLMD_SYSTEM_BASE_TYPE)
    node.node_info.id = tfx_node.id

    # Step 2: Node Context
    # Context for the pipeline, across pipeline runs.
    pipeline_context_pb = node.contexts.contexts.add()
    pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
    pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

    # Context for the current pipeline run.
    if compile_context.is_sync_mode:
      pipeline_run_context_pb = node.contexts.contexts.add()
      pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
      compiler_utils.set_runtime_parameter_pb(
          pipeline_run_context_pb.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

    # Context for the node, across pipeline runs.
    node_context_pb = node.contexts.contexts.add()
    node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
    node_context_pb.name.field_value.string_value = (
        compiler_utils.node_context_name(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id))

    # Pre Step 3: Alter graph topology if needed.
    if compile_context.is_async_mode:
      tfx_node_inputs = self._embed_upstream_resolver_nodes(
          compile_context, tfx_node, node)
    else:
      tfx_node_inputs = tfx_node.inputs

    # Step 3: Node inputs

    # Step 3.1: Generate implicit input channels
    # Step 3.1.1: Conditionals
    implicit_input_channels = {}
    predicates = conditional.get_predicates(tfx_node)
    if predicates:
      implicit_keys_map = {}
      for key, chnl in tfx_node_inputs.items():
        if not isinstance(chnl, types.Channel):
          raise ValueError(
              "Conditional only support using channel as a predicate.")
        implicit_keys_map[compiler_utils.implicit_channel_key(chnl)] = key
      encoded_predicates = []
      for predicate in predicates:
        for chnl in predicate.dependent_channels():
          implicit_key = compiler_utils.implicit_channel_key(chnl)
          if implicit_key not in implicit_keys_map:
            # Store this channel and add it to the node inputs later.
            implicit_input_channels[implicit_key] = chnl
        encoded_predicates.append(
            predicate.encode_with_keys(
                compiler_utils.build_channel_to_key_fn(implicit_keys_map)))
      # In an async pipeline, the conditional resolver step should be the
      # last of all resolver steps on a node.
      resolver_step = node.inputs.resolver_config.resolver_steps.add()
      resolver_step.class_path = constants.CONDITIONAL_RESOLVER_CLASS_PATH
      resolver_step.config_json = json_utils.dumps(
          {"predicates": encoded_predicates})

    # Step 3.1.2: Add placeholder exec props to implicit_input_channels
    for key, value in tfx_node.exec_properties.items():
      if isinstance(value, placeholder.ChannelWrappedPlaceholder):
        if not (inspect.isclass(value.channel.type) and
                issubclass(value.channel.type, value_artifact.ValueArtifact)):
          raise ValueError("output channel to dynamic exec properties is not "
                           "ValueArtifact")
        implicit_key = compiler_utils.implicit_channel_key(value.channel)
        implicit_input_channels[implicit_key] = value.channel

    # Step 3.2: Handle ForEach.
    dsl_contexts = context_manager.get_contexts(tfx_node)
    for dsl_context in dsl_contexts:
      if isinstance(dsl_context, for_each.ForEachContext):
        for input_key, channel in tfx_node_inputs.items():
          if (isinstance(channel, types.LoopVarChannel) and
              channel.wrapped is dsl_context.wrapped_channel):
            node.inputs.resolver_config.resolver_steps.extend(
                _compile_for_each_context(input_key))
            break
        else:
          # Ideally should not reach here as the same check is performed at
          # ForEachContext.will_add_node().
          raise ValueError(
              f"Unable to locate ForEach loop variable {dsl_context.channel} "
              f"from inputs of node {tfx_node.id}.")

    # Check loop variable is used outside the ForEach.
    for input_key, channel in tfx_node_inputs.items():
      if isinstance(channel, types.LoopVarChannel):
        dsl_context_ids = {c.id for c in dsl_contexts}
        if channel.context_id not in dsl_context_ids:
          raise ValueError(
              "Loop variable cannot be used outside the ForEach "
              f"(node_id = {tfx_node.id}, input_key = {input_key}).")

    # Step 3.3: Fill node inputs
    for key, value in itertools.chain(tfx_node_inputs.items(),
                                      implicit_input_channels.items()):
      input_spec = node.inputs.inputs[key]
      for input_channel in channel_utils.get_individual_channels(value):
        chnl = input_spec.channels.add()

        # If the node input comes from another node's output, fill the context
        # queries with the producer node's contexts.
        if input_channel in compile_context.node_outputs:
          chnl.producer_node_query.id = input_channel.producer_component_id

          # Here we rely on pipeline.components to be topologically sorted.
          assert input_channel.producer_component_id in compile_context.node_pbs, (
              "producer component should have already been compiled.")
          producer_pb = compile_context.node_pbs[
              input_channel.producer_component_id]
          for producer_context in producer_pb.contexts.contexts:
            context_query = chnl.context_queries.add()
            context_query.type.CopyFrom(producer_context.type)
            context_query.name.CopyFrom(producer_context.name)

        # If the node input does not come from another node's output, fill the
        # context queries based on Channel info. We require every channel to
        # have a pipeline context and fill it in automatically.
        else:
          # Add pipeline context query.
          context_query = chnl.context_queries.add()
          context_query.type.CopyFrom(pipeline_context_pb.type)
          context_query.name.CopyFrom(pipeline_context_pb.name)

          # Optionally add node context query.
          if input_channel.producer_component_id:
            # Add node context query if `producer_component_id` is present.
            chnl.producer_node_query.id = input_channel.producer_component_id
            node_context_query = chnl.context_queries.add()
            node_context_query.type.name = constants.NODE_CONTEXT_TYPE_NAME
            node_context_query.name.field_value.string_value = "{}.{}".format(
                compile_context.pipeline_info.pipeline_context_name,
                input_channel.producer_component_id)

        artifact_type = input_channel.type._get_artifact_type()  # pylint: disable=protected-access
        chnl.artifact_query.type.CopyFrom(artifact_type)
        chnl.artifact_query.type.ClearField("properties")

        if input_channel.output_key:
          chnl.output_key = input_channel.output_key

        # Set NodeInputs.min_count.
        if isinstance(tfx_node, base_component.BaseComponent):
          if key in implicit_input_channels:
            # Mark all input channels as optional for implicit inputs
            # (e.g. conditionals). This is suboptimal, but still a safe guess
            # to avoid breaking the pipeline run.
            input_spec.min_count = 0
          else:
            try:
              # Calculating min_count from ComponentSpec.INPUTS.
              if tfx_node.spec.is_optional_input(key):
                input_spec.min_count = 0
              else:
                input_spec.min_count = 1
            except KeyError:
              # Currently we can fall through to here if the upstream
              # resolver node inputs are embedded into the current node (in
              # async mode). We always regard a resolver's inputs as optional.
              if compile_context.is_async_mode:
                input_spec.min_count = 0
              else:
                raise

    # TODO(b/170694459): Refactor special nodes as plugins.
    # Step 3.4: Special treatment for Resolver node.
    if compiler_utils.is_resolver(tfx_node):
      assert compile_context.is_sync_mode
      node.inputs.resolver_config.resolver_steps.extend(
          _compile_resolver_node(tfx_node))

    # Step 4: Node outputs
    for key, value in tfx_node.outputs.items():
      # Register the output in the context.
      compile_context.node_outputs.add(value)
    if (isinstance(tfx_node, base_component.BaseComponent) or
        compiler_utils.is_importer(tfx_node)):
      self._compile_node_outputs(tfx_node, node)

    # Step 5: Node parameters
    if not compiler_utils.is_resolver(tfx_node):
      for key, value in tfx_node.exec_properties.items():
        if value is None:
          continue
        parameter_value = node.parameters.parameters[key]

        # Order matters, because a runtime parameter can arrive as a serialized string.
        if isinstance(value, data_types.RuntimeParameter):
          compiler_utils.set_runtime_parameter_pb(
              parameter_value.runtime_parameter, value.name, value.ptype,
              value.default)
        # RuntimeInfoPlaceholder passes Execution parameters of Facade
        # components.
        elif isinstance(value, placeholder.RuntimeInfoPlaceholder):
          parameter_value.placeholder.CopyFrom(value.encode())
        # ChannelWrappedPlaceholder passes dynamic execution parameter.
        elif isinstance(value, placeholder.ChannelWrappedPlaceholder):
          compiler_utils.validate_dynamic_exec_ph_operator(value)
          parameter_value.placeholder.CopyFrom(
              value.encode_with_keys(compiler_utils.implicit_channel_key))
        else:
          try:
            data_types_utils.set_parameter_value(parameter_value, value)
          except ValueError as e:
            raise ValueError(
                "Component {} got unsupported parameter {} with type {}."
                .format(tfx_node.id, key, type(value))) from e

    # Step 6: Executor spec and optional driver spec for components
    if isinstance(tfx_node, base_component.BaseComponent):
      executor_spec = tfx_node.executor_spec.encode(
          component_spec=tfx_node.spec)
      deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

      # TODO(b/163433174): Remove specialized logic once generalization of
      # driver spec is done.
      if tfx_node.driver_class != base_driver.BaseDriver:
        driver_class_path = _fully_qualified_name(tfx_node.driver_class)
        driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
        driver_spec.class_path = driver_class_path
        deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

    # Step 7: Upstream/Downstream nodes
    node.upstream_nodes.extend(
        self._find_runtime_upstream_node_ids(compile_context, tfx_node))
    node.downstream_nodes.extend(
        self._find_runtime_downstream_node_ids(compile_context, tfx_node))

    # Step 8: Node execution options
    node.execution_options.caching_options.enable_cache = enable_cache

    # Step 9: Per-node platform config
    if isinstance(tfx_node, base_component.BaseComponent):
      tfx_component = cast(base_component.BaseComponent, tfx_node)
      if tfx_component.platform_config:
        deployment_config.node_level_platform_configs[tfx_node.id].Pack(
            tfx_component.platform_config)

    return node
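The _fully_qualified_name helper used in Step 6 is not shown on this page;
given the inline equivalent in Examples #12 and #13, it presumably reduces to:

def _fully_qualified_name(cls) -> str:
  # Mirrors the inline "{}.{}".format(cls.__module__, cls.__name__) that the
  # older compiler versions above spell out for the driver class path.
  return "{}.{}".format(cls.__module__, cls.__name__)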