def __init__(self,
             component: tfx_base_node.BaseNode,
             depends_on: Set[dsl.ContainerOp],
             pipeline: tfx_pipeline.Pipeline,
             pipeline_root: dsl.PipelineParam,
             tfx_image: str,
             kubeflow_metadata_config: kubeflow_pb2.KubeflowMetadataConfig,
             tfx_ir: pipeline_pb2.Pipeline,
             pod_labels_to_attach: Dict[str, str],
             runtime_parameters: List[data_types.RuntimeParameter],
             metadata_ui_path: str = '/mlpipeline-ui-metadata.json'):
  """Creates a new Kubeflow-based component.

  This class essentially wraps a dsl.ContainerOp construct in Kubeflow
  Pipelines.

  Args:
    component: The logical TFX component to wrap.
    depends_on: The set of upstream KFP ContainerOp components that this
      component will depend on.
    pipeline: The logical TFX pipeline to which this component belongs.
    pipeline_root: The pipeline root specified, as a dsl.PipelineParam
    tfx_image: The container image to use for this component.
    kubeflow_metadata_config: Configuration settings for connecting to the
      MLMD store in a Kubeflow cluster.
    tfx_ir: The TFX intermediate representation of the pipeline.
    pod_labels_to_attach: Dict of pod labels to attach to the GKE pod.
    runtime_parameters: Runtime parameters of the pipeline.
    metadata_ui_path: File location for metadata-ui-metadata.json file.
  """
  utils.replace_placeholder(component)

  # Flags consumed by container_entrypoint.py; flag name and value alternate.
  container_args = [
      '--pipeline_root',
      pipeline_root,
      '--kubeflow_metadata_config',
      json_format.MessageToJson(
          message=kubeflow_metadata_config, preserving_proto_field_name=True),
      '--node_id',
      component.id,
      # TODO(b/182220464): write IR to pipeline_root and let
      # container_entrypoint.py read it back to avoid future issue that IR
      # exceeds the flag size limit.
      '--tfx_ir',
      json_format.MessageToJson(tfx_ir),
      '--metadata_ui_path',
      metadata_ui_path,
  ]
  for runtime_param in runtime_parameters:
    container_args.extend(
        ['--runtime_parameter', _encode_runtime_parameter(runtime_param)])

  self.container_op = dsl.ContainerOp(
      name=component.id,
      command=_COMMAND,
      image=tfx_image,
      arguments=container_args,
      output_artifact_paths={'mlpipeline-ui-metadata': metadata_ui_path},
  )

  logging.info('Adding upstream dependencies for component %s',
               self.container_op.name)
  for upstream_op in depends_on:
    logging.info('   -> Component: %s', upstream_op.name)
    self.container_op.after(upstream_op)

  # TODO(b/140172100): Document the use of additional_pipeline_args.
  if _WORKFLOW_ID_KEY in pipeline.additional_pipeline_args:
    # An externally supplied run id takes precedence, primarily for testing.
    workflow_id_env = k8s_client.V1EnvVar(
        name=_WORKFLOW_ID_KEY,
        value=pipeline.additional_pipeline_args[_WORKFLOW_ID_KEY])
  else:
    # Otherwise surface the Argo workflow ID via the Kubernetes downward API
    # so outputs can be uniquely placed under the pipeline_root.
    workflow_label_path = "metadata.labels['workflows.argoproj.io/workflow']"
    workflow_id_env = k8s_client.V1EnvVar(
        name=_WORKFLOW_ID_KEY,
        value_from=k8s_client.V1EnvVarSource(
            field_ref=k8s_client.V1ObjectFieldSelector(
                field_path=workflow_label_path)))
  self.container_op.container.add_env_variable(workflow_id_env)

  if pod_labels_to_attach:
    for label_key, label_value in pod_labels_to_attach.items():
      self.container_op.add_pod_label(label_key, label_value)
def __init__(self,
             component: tfx_base_node.BaseNode,
             component_launcher_class: Type[
                 base_component_launcher.BaseComponentLauncher],
             depends_on: Set[dsl.ContainerOp],
             pipeline: tfx_pipeline.Pipeline,
             pipeline_name: Text,
             pipeline_root: dsl.PipelineParam,
             tfx_image: Text,
             kubeflow_metadata_config: Optional[
                 kubeflow_pb2.KubeflowMetadataConfig],
             component_config: base_component_config.BaseComponentConfig,
             pod_labels_to_attach: Optional[Dict[Text, Text]] = None):
  """Creates a new Kubeflow-based component.

  This class essentially wraps a dsl.ContainerOp construct in Kubeflow
  Pipelines.

  Args:
    component: The logical TFX component to wrap.
    component_launcher_class: the class of the launcher to launch the
      component.
    depends_on: The set of upstream KFP ContainerOp components that this
      component will depend on.
    pipeline: The logical TFX pipeline to which this component belongs.
    pipeline_name: The name of the TFX pipeline.
    pipeline_root: The pipeline root specified, as a dsl.PipelineParam
    tfx_image: The container image to use for this component.
    kubeflow_metadata_config: Configuration settings for connecting to the
      MLMD store in a Kubeflow cluster.
    component_config: Component config to launch the component.
    pod_labels_to_attach: Optional dict of pod labels to attach to the
      GKE pod.
  """
  # Fully-qualified import path of the launcher class, so the container
  # entrypoint can re-import it at runtime.
  component_launcher_class_path = '.'.join([
      component_launcher_class.__module__, component_launcher_class.__name__
  ])

  serialized_component = utils.replace_placeholder(
      json_utils.dumps(node_wrapper.NodeWrapper(component)))

  # NOTE(review): kubeflow_metadata_config is typed Optional but is passed
  # unguarded to MessageToJson, which raises on None — confirm callers
  # always supply a config.
  arguments = [
      '--pipeline_name',
      pipeline_name,
      '--pipeline_root',
      pipeline_root,
      '--kubeflow_metadata_config',
      json_format.MessageToJson(
          message=kubeflow_metadata_config, preserving_proto_field_name=True),
      '--beam_pipeline_args',
      json.dumps(pipeline.beam_pipeline_args),
      '--additional_pipeline_args',
      json.dumps(pipeline.additional_pipeline_args),
      '--component_launcher_class_path',
      component_launcher_class_path,
      '--serialized_component',
      serialized_component,
      '--component_config',
      json_utils.dumps(component_config),
  ]

  # A component-level setting overrides the pipeline-level default; only when
  # the component leaves it unset (None) does the pipeline default apply.
  if component.enable_cache or (component.enable_cache is None and
                                pipeline.enable_cache):
    arguments.append('--enable_cache')

  self.container_op = dsl.ContainerOp(
      # '.' is not a valid KFP op-name character; normalize to '_'.
      name=component.id.replace('.', '_'),
      command=_COMMAND,
      image=tfx_image,
      arguments=arguments,
      output_artifact_paths={
          'mlpipeline-ui-metadata': '/mlpipeline-ui-metadata.json',
      },
  )

  # Use lazy %-style args (not eager str.format) so the message is only
  # rendered when INFO logging is enabled — consistent with the other
  # constructor in this file and absl logging best practice.
  absl.logging.info('Adding upstream dependencies for component %s',
                    self.container_op.name)
  for op in depends_on:
    absl.logging.info('   -> Component: %s', op.name)
    self.container_op.after(op)

  # TODO(b/140172100): Document the use of additional_pipeline_args.
  if _WORKFLOW_ID_KEY in pipeline.additional_pipeline_args:
    # Allow overriding pipeline's run_id externally, primarily for testing.
    self.container_op.container.add_env_variable(
        k8s_client.V1EnvVar(
            name=_WORKFLOW_ID_KEY,
            value=pipeline.additional_pipeline_args[_WORKFLOW_ID_KEY]))
  else:
    # Add the Argo workflow ID to the container's environment variable so it
    # can be used to uniquely place pipeline outputs under the pipeline_root.
    field_path = "metadata.labels['workflows.argoproj.io/workflow']"
    self.container_op.container.add_env_variable(
        k8s_client.V1EnvVar(
            name=_WORKFLOW_ID_KEY,
            value_from=k8s_client.V1EnvVarSource(
                field_ref=k8s_client.V1ObjectFieldSelector(
                    field_path=field_path))))

  if pod_labels_to_attach:
    for k, v in pod_labels_to_attach.items():
      self.container_op.add_pod_label(k, v)