Exemple #1
0
    def __init__(self,
                 component: tfx_base_node.BaseNode,
                 depends_on: Set[dsl.ContainerOp],
                 pipeline: tfx_pipeline.Pipeline,
                 pipeline_root: dsl.PipelineParam,
                 tfx_image: str,
                 kubeflow_metadata_config: kubeflow_pb2.KubeflowMetadataConfig,
                 tfx_ir: pipeline_pb2.Pipeline,
                 pod_labels_to_attach: Dict[str, str],
                 runtime_parameters: List[data_types.RuntimeParameter],
                 metadata_ui_path: str = '/mlpipeline-ui-metadata.json'):
        """Wraps a logical TFX component in a Kubeflow Pipelines ContainerOp.

        The constructed dsl.ContainerOp is stored on `self.container_op`.

        Args:
          component: The logical TFX component to wrap. Its `id` is used both
            as the `--node_id` flag value and as the ContainerOp name.
          depends_on: The set of upstream KFP ContainerOp components that this
            component must run after.
          pipeline: The logical TFX pipeline to which this component belongs.
            Only its `additional_pipeline_args` are read here.
          pipeline_root: The pipeline root specified, as a dsl.PipelineParam.
          tfx_image: The container image to use for this component.
          kubeflow_metadata_config: Configuration settings for connecting to
            the MLMD store in a Kubeflow cluster; passed as JSON.
          tfx_ir: The TFX intermediate representation of the pipeline; passed
            as JSON.
          pod_labels_to_attach: Dict of pod labels to attach to the pod. A
            falsy value attaches nothing.
          runtime_parameters: Runtime parameters of the pipeline; each one
            becomes a `--runtime_parameter` flag.
          metadata_ui_path: File location passed as `--metadata_ui_path` and
            registered as the 'mlpipeline-ui-metadata' output artifact path.
        """
        utils.replace_placeholder(component)

        # Proto messages are handed to the container as JSON flag values.
        metadata_config_json = json_format.MessageToJson(
            message=kubeflow_metadata_config,
            preserving_proto_field_name=True)
        node_id = component.id
        # TODO(b/182220464): write IR to pipeline_root and let
        # container_entrypoint.py read it back to avoid future issue that IR
        # exceeds the flag size limit.
        ir_json = json_format.MessageToJson(tfx_ir)

        arguments = [
            '--pipeline_root', pipeline_root,
            '--kubeflow_metadata_config', metadata_config_json,
            '--node_id', node_id,
            '--tfx_ir', ir_json,
            '--metadata_ui_path', metadata_ui_path,
        ]
        for runtime_parameter in runtime_parameters:
            arguments.extend([
                '--runtime_parameter',
                _encode_runtime_parameter(runtime_parameter),
            ])

        ui_metadata_artifacts = {'mlpipeline-ui-metadata': metadata_ui_path}
        self.container_op = dsl.ContainerOp(
            name=component.id,
            command=_COMMAND,
            image=tfx_image,
            arguments=arguments,
            output_artifact_paths=ui_metadata_artifacts,
        )

        logging.info('Adding upstream dependencies for component %s',
                     self.container_op.name)
        for upstream_op in depends_on:
            logging.info('   ->  Component: %s', upstream_op.name)
            self.container_op.after(upstream_op)

        # TODO(b/140172100): Document the use of additional_pipeline_args.
        extra_pipeline_args = pipeline.additional_pipeline_args
        if _WORKFLOW_ID_KEY in extra_pipeline_args:
            # Allow overriding pipeline's run_id externally, primarily for
            # testing.
            workflow_id_env = k8s_client.V1EnvVar(
                name=_WORKFLOW_ID_KEY,
                value=extra_pipeline_args[_WORKFLOW_ID_KEY])
        else:
            # Expose the Argo workflow ID (read from the pod's labels) as an
            # environment variable so it can be used to uniquely place pipeline
            # outputs under the pipeline_root.
            workflow_label_ref = k8s_client.V1ObjectFieldSelector(
                field_path="metadata.labels['workflows.argoproj.io/workflow']")
            workflow_id_env = k8s_client.V1EnvVar(
                name=_WORKFLOW_ID_KEY,
                value_from=k8s_client.V1EnvVarSource(
                    field_ref=workflow_label_ref))
        self.container_op.container.add_env_variable(workflow_id_env)

        # `or {}` keeps a None / empty mapping a no-op.
        for label_key, label_value in (pod_labels_to_attach or {}).items():
            self.container_op.add_pod_label(label_key, label_value)
Exemple #2
0
    def __init__(self,
                 component: tfx_base_node.BaseNode,
                 component_launcher_class: Type[
                     base_component_launcher.BaseComponentLauncher],
                 depends_on: Set[dsl.ContainerOp],
                 pipeline: tfx_pipeline.Pipeline,
                 pipeline_name: Text,
                 pipeline_root: dsl.PipelineParam,
                 tfx_image: Text,
                 kubeflow_metadata_config: Optional[
                     kubeflow_pb2.KubeflowMetadataConfig],
                 component_config: base_component_config.BaseComponentConfig,
                 pod_labels_to_attach: Optional[Dict[Text, Text]] = None):
        """Creates a new Kubeflow-based component.

        This class essentially wraps a dsl.ContainerOp construct in Kubeflow
        Pipelines. The constructed op is stored on `self.container_op`.

        Args:
          component: The logical TFX component to wrap.
          component_launcher_class: The class of the launcher to launch the
            component; passed to the container as a dotted
            `<module>.<class name>` path.
          depends_on: The set of upstream KFP ContainerOp components that this
            component will depend on.
          pipeline: The logical TFX pipeline to which this component belongs.
          pipeline_name: The name of the TFX pipeline.
          pipeline_root: The pipeline root specified, as a dsl.PipelineParam.
          tfx_image: The container image to use for this component.
          kubeflow_metadata_config: Configuration settings for connecting to
            the MLMD store in a Kubeflow cluster.
          component_config: Component config to launch the component.
          pod_labels_to_attach: Optional dict of pod labels to attach to the
            GKE pod.
        """
        component_launcher_class_path = '.'.join([
            component_launcher_class.__module__,
            component_launcher_class.__name__
        ])

        # The component is serialized to JSON first; placeholder replacement
        # is then applied to the serialized string.
        serialized_component = utils.replace_placeholder(
            json_utils.dumps(node_wrapper.NodeWrapper(component)))

        # NOTE(review): kubeflow_metadata_config is annotated Optional but is
        # passed straight to MessageToJson below -- confirm callers never pass
        # None.
        arguments = [
            '--pipeline_name',
            pipeline_name,
            '--pipeline_root',
            pipeline_root,
            '--kubeflow_metadata_config',
            json_format.MessageToJson(message=kubeflow_metadata_config,
                                      preserving_proto_field_name=True),
            '--beam_pipeline_args',
            json.dumps(pipeline.beam_pipeline_args),
            '--additional_pipeline_args',
            json.dumps(pipeline.additional_pipeline_args),
            '--component_launcher_class_path',
            component_launcher_class_path,
            '--serialized_component',
            serialized_component,
            '--component_config',
            json_utils.dumps(component_config),
        ]

        # The component-level cache setting wins; the pipeline-level setting
        # only applies when the component leaves enable_cache as None.
        if component.enable_cache or (component.enable_cache is None
                                      and pipeline.enable_cache):
            arguments.append('--enable_cache')

        self.container_op = dsl.ContainerOp(
            # Dots in the component id are replaced with underscores for the
            # op name.
            name=component.id.replace('.', '_'),
            command=_COMMAND,
            image=tfx_image,
            arguments=arguments,
            output_artifact_paths={
                'mlpipeline-ui-metadata': '/mlpipeline-ui-metadata.json',
            },
        )

        # Pass values as lazy %-style arguments so the message is only
        # formatted when the record is actually emitted.
        absl.logging.info('Adding upstream dependencies for component %s',
                          self.container_op.name)
        for op in depends_on:
            absl.logging.info('   ->  Component: %s', op.name)
            self.container_op.after(op)

        # TODO(b/140172100): Document the use of additional_pipeline_args.
        if _WORKFLOW_ID_KEY in pipeline.additional_pipeline_args:
            # Allow overriding pipeline's run_id externally, primarily for
            # testing.
            self.container_op.container.add_env_variable(
                k8s_client.V1EnvVar(
                    name=_WORKFLOW_ID_KEY,
                    value=pipeline.additional_pipeline_args[_WORKFLOW_ID_KEY]))
        else:
            # Add the Argo workflow ID to the container's environment variable
            # so it can be used to uniquely place pipeline outputs under the
            # pipeline_root.
            field_path = "metadata.labels['workflows.argoproj.io/workflow']"
            self.container_op.container.add_env_variable(
                k8s_client.V1EnvVar(
                    name=_WORKFLOW_ID_KEY,
                    value_from=k8s_client.V1EnvVarSource(
                        field_ref=k8s_client.V1ObjectFieldSelector(
                            field_path=field_path))))

        if pod_labels_to_attach:
            for k, v in pod_labels_to_attach.items():
                self.container_op.add_pod_label(k, v)