Example 1
def _serialize_pipeline(pipeline: tfx_pipeline.Pipeline) -> Text:
    """Serializes a TFX pipeline.

  To be replaced with the TFX Intermediate Representation:
  tensorflow/community#271. This serialization procedure extracts from
  the pipeline the properties necessary for reconstructing the pipeline
  instance from within the cluster. For properties such as components and
  metadata config that cannot be directly dumped with JSON, we use
  NodeWrapper and MessageToJson to serialize them beforehand.

  Args:
    pipeline: Logical pipeline containing pipeline args and components.

  Returns:
    Pipeline serialized as JSON string.
  """
    serialized_components = []
    for component in pipeline.components:
        serialized_components.append(
            json_utils.dumps(node_wrapper.NodeWrapper(component)))
    # Extract and pass pipeline graph information which is lost during the
    # serialization process. The orchestrator container uses downstream_ids
    # to reconstruct the pipeline graph.
    downstream_ids = _extract_downstream_ids(pipeline.components)
    return json.dumps({
        'pipeline_name': pipeline.pipeline_info.pipeline_name,
        'pipeline_root': pipeline.pipeline_info.pipeline_root,
        'enable_cache': pipeline.enable_cache,
        'components': serialized_components,
        'downstream_ids': downstream_ids,
        'metadata_connection_config': json_format.MessageToJson(
            message=pipeline.metadata_connection_config,
            preserving_proto_field_name=True,
        ),
        'beam_pipeline_args': pipeline.beam_pipeline_args,
    })
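
For context, here is a minimal sketch (not part of the TFX codebase; the helper name _deserialize_pipeline is hypothetical) of how the orchestrator container could load the JSON payload produced by _serialize_pipeline above:

import json

def _deserialize_pipeline(serialized_pipeline: str) -> dict:
    """Parses the JSON produced by _serialize_pipeline into a plain dict.

    Each entry in payload['components'] is itself a JSON string produced by
    json_utils.dumps(node_wrapper.NodeWrapper(component)), and
    payload['downstream_ids'] carries the edges needed to rebuild the
    pipeline DAG.
    """
    payload = json.loads(serialized_pipeline)
    return payload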
Example 2
    def __init__(self,
                 component: tfx_base_node.BaseNode,
                 component_launcher_class: Type[
                     base_component_launcher.BaseComponentLauncher],
                 depends_on: Set[dsl.ContainerOp],
                 pipeline: tfx_pipeline.Pipeline,
                 pipeline_name: Text,
                 pipeline_root: dsl.PipelineParam,
                 tfx_image: Text,
                 kubeflow_metadata_config: Optional[
                     kubeflow_pb2.KubeflowMetadataConfig],
                 component_config: base_component_config.BaseComponentConfig,
                 pod_labels_to_attach: Optional[Dict[Text, Text]] = None):
        """Creates a new Kubeflow-based component.

    This class essentially wraps a dsl.ContainerOp construct in Kubeflow
    Pipelines.

    Args:
      component: The logical TFX component to wrap.
      component_launcher_class: The class of the launcher to launch the
        component.
      depends_on: The set of upstream KFP ContainerOp components that this
        component will depend on.
      pipeline: The logical TFX pipeline to which this component belongs.
      pipeline_name: The name of the TFX pipeline.
      pipeline_root: The pipeline root specified as a dsl.PipelineParam.
      tfx_image: The container image to use for this component.
      kubeflow_metadata_config: Configuration settings for connecting to the
        MLMD store in a Kubeflow cluster.
      component_config: Component config to launch the component.
      pod_labels_to_attach: Optional dict of pod labels to attach to the
        GKE pod.
    """
        component_launcher_class_path = '.'.join([
            component_launcher_class.__module__,
            component_launcher_class.__name__
        ])

        serialized_component = utils.replace_placeholder(
            json_utils.dumps(node_wrapper.NodeWrapper(component)))

        arguments = [
            '--pipeline_name',
            pipeline_name,
            '--pipeline_root',
            pipeline_root,
            '--kubeflow_metadata_config',
            json_format.MessageToJson(message=kubeflow_metadata_config,
                                      preserving_proto_field_name=True),
            '--beam_pipeline_args',
            json.dumps(pipeline.beam_pipeline_args),
            '--additional_pipeline_args',
            json.dumps(pipeline.additional_pipeline_args),
            '--component_launcher_class_path',
            component_launcher_class_path,
            '--serialized_component',
            serialized_component,
            '--component_config',
            json_utils.dumps(component_config),
        ]

        if component.enable_cache or (component.enable_cache is None
                                      and pipeline.enable_cache):
            arguments.append('--enable_cache')

        self.container_op = dsl.ContainerOp(
            name=component.id.replace('.', '_'),
            command=_COMMAND,
            image=tfx_image,
            arguments=arguments,
            output_artifact_paths={
                'mlpipeline-ui-metadata': '/mlpipeline-ui-metadata.json',
            },
        )

        absl.logging.info(
            'Adding upstream dependencies for component {}'.format(
                self.container_op.name))
        for op in depends_on:
            absl.logging.info('   ->  Component: {}'.format(op.name))
            self.container_op.after(op)

        # TODO(b/140172100): Document the use of additional_pipeline_args.
        if _WORKFLOW_ID_KEY in pipeline.additional_pipeline_args:
            # Allow overriding pipeline's run_id externally, primarily for testing.
            self.container_op.container.add_env_variable(
                k8s_client.V1EnvVar(
                    name=_WORKFLOW_ID_KEY,
                    value=pipeline.additional_pipeline_args[_WORKFLOW_ID_KEY]))
        else:
            # Add the Argo workflow ID as an environment variable on the container
            # so it can be used to uniquely place pipeline outputs under the
            # pipeline_root.
            field_path = "metadata.labels['workflows.argoproj.io/workflow']"
            self.container_op.container.add_env_variable(
                k8s_client.V1EnvVar(
                    name=_WORKFLOW_ID_KEY,
                    value_from=k8s_client.V1EnvVarSource(
                        field_ref=k8s_client.V1ObjectFieldSelector(
                            field_path=field_path))))

        if pod_labels_to_attach:
            for k, v in pod_labels_to_attach.items():
                self.container_op.add_pod_label(k, v)
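
The arguments list above interleaves flag names with their values. Purely as an illustration (this helper is hypothetical and not part of TFX or KFP), the pairing can be made explicit when inspecting a generated ContainerOp:

from typing import Any, Dict, List

def _pair_flags(arguments: List[Any]) -> Dict[str, Any]:
    """Groups a flat [flag, value, flag, value, ...] list into a dict.

    Value-less flags such as '--enable_cache' map to True.
    """
    paired: Dict[str, Any] = {}
    i = 0
    while i < len(arguments):
        flag = str(arguments[i])
        has_value = (i + 1 < len(arguments)
                     and not str(arguments[i + 1]).startswith('--'))
        if has_value:
            paired[flag] = arguments[i + 1]
            i += 2
        else:
            paired[flag] = True
            i += 1
    return paired

For example, _pair_flags(arguments)['--pipeline_name'] would return the pipeline name passed to the constructor above.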
Example 3
    def _wrap_container_component(
        self,
        component: base_node.BaseNode,
        component_launcher_class: Type[
            base_component_launcher.BaseComponentLauncher],
        component_config: Optional[base_component_config.BaseComponentConfig],
        pipeline: tfx_pipeline.Pipeline,
    ) -> base_node.BaseNode:
        """Wrapper for container component.

    Args:
      component: Component to be executed.
      component_launcher_class: The class of the launcher to launch the
        component.
      component_config: Component config to launch the component.
      pipeline: Logical pipeline that contains pipeline related information.

    Returns:
      A container component that runs the wrapped component upon execution.
    """

        component_launcher_class_path = '.'.join([
            component_launcher_class.__module__,
            component_launcher_class.__name__
        ])

        serialized_component = json_utils.dumps(
            node_wrapper.NodeWrapper(component))

        arguments = [
            '--pipeline_name',
            pipeline.pipeline_info.pipeline_name,
            '--pipeline_root',
            pipeline.pipeline_info.pipeline_root,
            '--run_id',
            pipeline.pipeline_info.run_id,
            '--metadata_config',
            json_format.MessageToJson(
                message=get_default_kubernetes_metadata_config(),
                preserving_proto_field_name=True),
            '--beam_pipeline_args',
            json.dumps(pipeline.beam_pipeline_args),
            '--additional_pipeline_args',
            json.dumps(pipeline.additional_pipeline_args),
            '--component_launcher_class_path',
            component_launcher_class_path,
            '--serialized_component',
            serialized_component,
            '--component_config',
            json_utils.dumps(component_config),
        ]

        # Outputs/Parameters fields are not used as they are contained in
        # the serialized component.
        return container_component.create_container_component(
            name=component.__class__.__name__,
            outputs={},
            parameters={},
            image=self._config.tfx_image,
            command=_CONTAINER_COMMAND + arguments)(
                instance_name=component._instance_name + _WRAPPER_SUFFIX)  # pylint: disable=protected-access
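
On the container side, the dotted component_launcher_class_path assembled above has to be resolved back into a class. Here is a minimal sketch of that step, assuming a standard importlib lookup (the helper name _load_class is hypothetical):

import importlib

def _load_class(class_path: str) -> type:
    """Imports 'package.module.ClassName' and returns the class object."""
    module_name, class_name = class_path.rsplit('.', 1)
    return getattr(importlib.import_module(module_name), class_name)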
Example 4
    def __init__(self,
                 component: tfx_base_node.BaseNode,
                 depends_on: Set[dsl.ContainerOp],
                 pipeline: tfx_pipeline.Pipeline,
                 pipeline_root: dsl.PipelineParam,
                 tfx_image: Text,
                 kubeflow_metadata_config: Optional[
                     kubeflow_pb2.KubeflowMetadataConfig],
                 tfx_ir: Optional[pipeline_pb2.Pipeline] = None,
                 pod_labels_to_attach: Optional[Dict[Text, Text]] = None):
        """Creates a new Kubeflow-based component.

    This class essentially wraps a dsl.ContainerOp construct in Kubeflow
    Pipelines.

    Args:
      component: The logical TFX component to wrap.
      depends_on: The set of upstream KFP ContainerOp components that this
        component will depend on.
      pipeline: The logical TFX pipeline to which this component belongs.
      pipeline_root: The pipeline root specified as a dsl.PipelineParam.
      tfx_image: The container image to use for this component.
      kubeflow_metadata_config: Configuration settings for connecting to the
        MLMD store in a Kubeflow cluster.
      tfx_ir: The TFX intermediate representation (IR) of the pipeline.
      pod_labels_to_attach: Optional dict of pod labels to attach to the
        GKE pod.
    """

        serialized_component = utils.replace_placeholder(
            json_utils.dumps(node_wrapper.NodeWrapper(component)))

        arguments = [
            '--pipeline_root',
            pipeline_root,
            '--kubeflow_metadata_config',
            json_format.MessageToJson(message=kubeflow_metadata_config,
                                      preserving_proto_field_name=True),
            '--node_id',
            component.id,
            '--serialized_component',
            serialized_component,
            # TODO(b/182220464): Write the IR to pipeline_root and let
            # container_entrypoint.py read it back, to avoid the IR exceeding
            # the flag size limit in the future.
            '--tfx_ir',
            json_format.MessageToJson(tfx_ir),
        ]

        self.container_op = dsl.ContainerOp(
            name=component.id,
            command=_COMMAND,
            image=tfx_image,
            arguments=arguments,
            output_artifact_paths={
                'mlpipeline-ui-metadata': '/mlpipeline-ui-metadata.json',
            },
        )

        logging.info('Adding upstream dependencies for component %s',
                     self.container_op.name)
        for op in depends_on:
            logging.info('   ->  Component: %s', op.name)
            self.container_op.after(op)

        # TODO(b/140172100): Document the use of additional_pipeline_args.
        if _WORKFLOW_ID_KEY in pipeline.additional_pipeline_args:
            # Allow overriding pipeline's run_id externally, primarily for testing.
            self.container_op.container.add_env_variable(
                k8s_client.V1EnvVar(
                    name=_WORKFLOW_ID_KEY,
                    value=pipeline.additional_pipeline_args[_WORKFLOW_ID_KEY]))
        else:
            # Add the Argo workflow ID as an environment variable on the container
            # so it can be used to uniquely place pipeline outputs under the
            # pipeline_root.
            field_path = "metadata.labels['workflows.argoproj.io/workflow']"
            self.container_op.container.add_env_variable(
                k8s_client.V1EnvVar(
                    name=_WORKFLOW_ID_KEY,
                    value_from=k8s_client.V1EnvVarSource(
                        field_ref=k8s_client.V1ObjectFieldSelector(
                            field_path=field_path))))

        if pod_labels_to_attach:
            for k, v in pod_labels_to_attach.items():
                self.container_op.add_pod_label(k, v)
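
The --tfx_ir flag above carries the pipeline IR serialized to JSON with json_format.MessageToJson. A minimal sketch of the reverse step the entrypoint would need, assuming the IR proto is importable as tfx.proto.orchestration.pipeline_pb2 (an import path not shown in the code above):

from google.protobuf import json_format
from tfx.proto.orchestration import pipeline_pb2  # assumed import path

def _parse_tfx_ir(serialized_ir: str) -> pipeline_pb2.Pipeline:
    """Parses the JSON-serialized IR back into a Pipeline proto."""
    return json_format.Parse(serialized_ir, pipeline_pb2.Pipeline())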