Example #1
File: compiler.py Project: kp425/tfx
    def _compile_importer_node_outputs(self, tfx_node: base_node.BaseNode,
                                       node_pb: pipeline_pb2.PipelineNode):
        """Compiles the outputs of an importer node."""
        for key, value in tfx_node.outputs.items():
            output_spec = node_pb.outputs.outputs[key]
            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            output_spec.artifact_spec.type.CopyFrom(artifact_type)

            # Attach additional properties for artifacts produced by importer nodes.
            for property_name, property_value in tfx_node.exec_properties[
                    importer_node.PROPERTIES_KEY].items():
                value_field = output_spec.artifact_spec.additional_properties[
                    property_name].field_value
                try:
                    compiler_utils.set_field_value_pb(value_field,
                                                      property_value)
                except ValueError as e:
                    raise ValueError(
                        "Component {} got unsupported parameter {} with type {}."
                        .format(tfx_node.id, property_name,
                                type(property_value))) from e

            for property_name, property_value in tfx_node.exec_properties[
                    importer_node.CUSTOM_PROPERTIES_KEY].items():
                value_field = output_spec.artifact_spec.additional_custom_properties[
                    property_name].field_value
                try:
                    compiler_utils.set_field_value_pb(value_field,
                                                      property_value)
                except ValueError as e:
                    raise ValueError(
                        "Component {} got unsupported parameter {} with type {}."
                        .format(tfx_node.id, property_name,
                                type(property_value))) from e
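
Both loops above delegate to compiler_utils.set_field_value_pb, which maps a plain Python value onto the matching field of an MLMD metadata_store_pb2.Value proto and raises ValueError for anything it cannot represent. The real implementation lives in tfx.dsl.compiler.compiler_utils; the sketch below is a hypothetical reconstruction of its contract, consistent with the tests in Examples #3 and #4:

from typing import Any

from ml_metadata.proto import metadata_store_pb2


def set_field_value_pb_sketch(pb: metadata_store_pb2.Value,
                              value: Any) -> metadata_store_pb2.Value:
    """Hypothetical sketch of the compiler_utils.set_field_value_pb contract."""
    # bool must be rejected before the int check: isinstance(True, int) is
    # True in Python, and the Value proto has no boolean field to target.
    if isinstance(value, bool):
        raise ValueError("Unsupported value type: {}".format(type(value)))
    elif isinstance(value, int):
        pb.int_value = value
    elif isinstance(value, float):
        pb.double_value = value
    elif isinstance(value, str):
        pb.string_value = value
    else:
        raise ValueError("Unsupported value type: {}".format(type(value)))
    return pb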
Example #2
File: compiler.py Project: kp425/tfx
    def _compile_node(
        self,
        tfx_node: base_node.BaseNode,
        compile_context: _CompilerContext,
        deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
        enable_cache: bool,
    ) -> pipeline_pb2.PipelineNode:
        """Compiles an individual TFX node into a PipelineNode proto.

    Args:
      tfx_node: A TFX node.
      compile_context: Resources needed to compile the node.
      deployment_config: Intermediate deployment config to set. Will include
        related specs for executors, drivers and platform specific configs.
      enable_cache: whether cache is enabled

    Returns:
      A PipelineNode proto that encodes information of the node.
    """
        node = pipeline_pb2.PipelineNode()

        # Step 1: Node info
        node.node_info.type.name = tfx_node.type
        node.node_info.id = tfx_node.id

        # Step 2: Node Context
        # Context for the pipeline, across pipeline runs.
        pipeline_context_pb = node.contexts.contexts.add()
        pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
        pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

        # Context for the current pipeline run.
        if (compile_context.execution_mode ==
                pipeline_pb2.Pipeline.ExecutionMode.SYNC):
            pipeline_run_context_pb = node.contexts.contexts.add()
            pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
            compiler_utils.set_runtime_parameter_pb(
                pipeline_run_context_pb.name.runtime_parameter,
                constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

        # Context for the node, across pipeline runs.
        node_context_pb = node.contexts.contexts.add()
        node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
        node_context_pb.name.field_value.string_value = "{}.{}".format(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id)

        # Step 3: Node inputs
        for key, value in tfx_node.inputs.items():
            input_spec = node.inputs.inputs[key]
            channel = input_spec.channels.add()
            if value.producer_component_id:
                channel.producer_node_query.id = value.producer_component_id

                # Here we rely on pipeline.components to be topologically sorted.
                assert value.producer_component_id in compile_context.node_pbs, (
                    "producer component should have already been compiled.")
                producer_pb = compile_context.node_pbs[
                    value.producer_component_id]
                for producer_context in producer_pb.contexts.contexts:
                    if (not compiler_utils.is_resolver(tfx_node)
                            or producer_context.name.runtime_parameter.name !=
                            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
                        context_query = channel.context_queries.add()
                        context_query.type.CopyFrom(producer_context.type)
                        context_query.name.CopyFrom(producer_context.name)
            else:
                # Caveat: the portable core requires every channel to have at
                # least one Context. But for cases like system nodes and
                # producer-consumer pipelines, a channel may not have contexts
                # at all. In these cases, we use the pipeline-level context as
                # the input channel context.
                context_query = channel.context_queries.add()
                context_query.type.CopyFrom(pipeline_context_pb.type)
                context_query.name.CopyFrom(pipeline_context_pb.name)

            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            channel.artifact_query.type.CopyFrom(artifact_type)
            channel.artifact_query.type.ClearField("properties")

            if value.output_key:
                channel.output_key = value.output_key

            # TODO(b/158712886): Calculate min_count based on if inputs are optional.
            # min_count = 0 stands for optional input and 1 stands for required input.

        # Step 3.1: Special treatment for Resolver node
        if compiler_utils.is_resolver(tfx_node):
            resolver = tfx_node.exec_properties[resolver_node.RESOLVER_CLASS]
            if resolver == latest_artifacts_resolver.LatestArtifactsResolver:
                node.inputs.resolver_config.resolver_policy = (
                    pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_ARTIFACT)
            elif resolver == latest_blessed_model_resolver.LatestBlessedModelResolver:
                node.inputs.resolver_config.resolver_policy = (
                    pipeline_pb2.ResolverConfig.ResolverPolicy.
                    LATEST_BLESSED_MODEL)
            else:
                raise ValueError("Got unsupported resolver policy: {}".format(
                    resolver.type))

        # Step 4: Node outputs
        if isinstance(tfx_node, base_component.BaseComponent):
            for key, value in tfx_node.outputs.items():
                output_spec = node.outputs.outputs[key]
                artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
                output_spec.artifact_spec.type.CopyFrom(artifact_type)

        # TODO(b/170694459): Refactor special nodes as plugins.
        # Step 4.1: Special treatment for Importer node
        if compiler_utils.is_importer(tfx_node):
            self._compile_importer_node_outputs(tfx_node, node)

        # Step 5: Node parameters
        if not compiler_utils.is_resolver(tfx_node):
            for key, value in tfx_node.exec_properties.items():
                if value is None:
                    continue
                # Ignore the following two properties for an importer node,
                # because they are already attached to the artifacts produced
                # by the importer node.
                if compiler_utils.is_importer(tfx_node) and (
                        key == importer_node.PROPERTIES_KEY
                        or key == importer_node.CUSTOM_PROPERTIES_KEY):
                    continue
                parameter_value = node.parameters.parameters[key]

                # Order matters, because a runtime parameter can arrive as a
                # serialized string.
                if isinstance(value, data_types.RuntimeParameter):
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, value.name,
                        value.ptype, value.default)
                elif isinstance(value, str) and re.search(
                        data_types.RUNTIME_PARAMETER_PATTERN, value):
                    runtime_param = json.loads(value)
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, runtime_param.name,
                        runtime_param.ptype, runtime_param.default)
                else:
                    try:
                        compiler_utils.set_field_value_pb(
                            parameter_value.field_value, value)
                    except ValueError as e:
                        raise ValueError(
                            "Component {} got unsupported parameter {} with type {}."
                            .format(tfx_node.id, key, type(value))) from e

        # Step 6: Executor spec and optional driver spec for components
        if isinstance(tfx_node, base_component.BaseComponent):
            executor_spec = tfx_node.executor_spec.encode(
                component_spec=tfx_node.spec)
            deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

            # TODO(b/163433174): Remove specialized logic once generalization of
            # driver spec is done.
            if tfx_node.driver_class != base_driver.BaseDriver:
                driver_class_path = "{}.{}".format(
                    tfx_node.driver_class.__module__,
                    tfx_node.driver_class.__name__)
                driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
                driver_spec.class_path = driver_class_path
                deployment_config.custom_driver_specs[tfx_node.id].Pack(
                    driver_spec)

        # Step 7: Upstream/Downstream nodes
        # Note: the order of tfx_node.upstream_nodes is inconsistent from
        # run to run. We sort them so that compiler generates consistent results.
        node.upstream_nodes.extend(
            sorted([
                upstream_component.id
                for upstream_component in tfx_node.upstream_nodes
            ]))
        node.downstream_nodes.extend(
            sorted([
                downstream_component.id
                for downstream_component in tfx_node.downstream_nodes
            ]))

        # Step 8: Node execution options
        node.execution_options.caching_options.enable_cache = enable_cache

        # Step 9: Per-node platform config
        if isinstance(tfx_node, base_component.BaseComponent):
            tfx_component = cast(base_component.BaseComponent, tfx_node)
            if tfx_component.platform_config:
                deployment_config.node_level_platform_configs[
                    tfx_node.id].Pack(tfx_component.platform_config)

        return node
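
_compile_node is called once per node by the compiler's public entry point, which walks pipeline.components in topological order (the assertion in Step 3 relies on this) and assembles the full pipeline proto. A minimal usage sketch, assuming my_pipeline is an already-constructed tfx.orchestration.pipeline.Pipeline:

from tfx.dsl.compiler import compiler

# my_pipeline is a placeholder for a pipeline you have defined elsewhere.
dsl_compiler = compiler.Compiler()
pipeline_pb = dsl_compiler.compile(my_pipeline)  # pipeline_pb2.Pipeline proto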
Example #3
    def testSetFieldValuePbUnsupportedType(self):
        pb = metadata_store_pb2.Value()
        with self.assertRaises(ValueError):
            compiler_utils.set_field_value_pb(pb, True)
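
This test pins down an edge case of the contract sketched after Example #1: True is rejected even though isinstance(True, int) holds in Python, presumably because the metadata_store_pb2.Value proto targeted here exposes only int_value, double_value, and string_value fields, so silently coercing a bool to an integer would hide the caller's intent.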
Example #4
    def testSetFieldValuePb(self, value, expected_pb):
        pb = metadata_store_pb2.Value()
        compiler_utils.set_field_value_pb(pb, value)
        self.assertEqual(pb, expected_pb)
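
Example #4 is evidently a parameterized test: value and expected_pb arrive as test parameters, but the decorator and class scaffolding were cut from the snippet. A plausible reconstruction using absl's parameterized helper follows; the parameter cases are illustrative assumptions, not the project's actual test data:

from absl.testing import parameterized
from ml_metadata.proto import metadata_store_pb2
from tfx.dsl.compiler import compiler_utils


class CompilerUtilsTest(parameterized.TestCase):

    @parameterized.parameters(
        # (input value, expected proto) pairs; hypothetical cases.
        (42, metadata_store_pb2.Value(int_value=42)),
        (0.5, metadata_store_pb2.Value(double_value=0.5)),
        ("abc", metadata_store_pb2.Value(string_value="abc")),
    )
    def testSetFieldValuePb(self, value, expected_pb):
        pb = metadata_store_pb2.Value()
        compiler_utils.set_field_value_pb(pb, value)
        self.assertEqual(pb, expected_pb)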