예제 #1
0
    def _create_pipeline(
        self,
        pipeline_func: Callable[..., Any],
        pipeline_name: Optional[str] = None,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Builds the IR pipeline spec for a @dsl.pipeline-decorated function.

        The function is invoked once inside a ``dsl.Pipeline`` context with
        placeholder ``PipelineParam`` arguments (typed, but without values) so
        that its body registers its tasks. A second list of parameters, built
        from the declared inputs and carrying their default values, is then
        passed with the captured pipeline to ``_create_pipeline_spec``.

        Args:
          pipeline_func: Pipeline function decorated with @dsl.pipeline.
          pipeline_name: Optional pipeline name; when falsy, the name from the
            function's extracted component interface is used instead.

        Returns:
          The IR representation (pipeline spec) of the pipeline.
        """
        interface = python_op._extract_component_interface(pipeline_func)
        name = pipeline_name or interface.name
        declared_inputs = interface.inputs or []

        # Declared input name -> declared type. setdefault keeps the first
        # declaration should a name repeat, i.e. first-match lookup.
        type_by_name = {}
        for input_spec in declared_inputs:
            type_by_name.setdefault(input_spec.name, input_spec.type)

        # One typed placeholder per signature parameter; names without a
        # declared input get param_type=None.
        placeholder_args = [
            dsl.PipelineParam(sanitize_k8s_name(param_name, True),
                              param_type=type_by_name.get(param_name))
            for param_name in inspect.signature(pipeline_func).parameters
        ]

        with dsl.Pipeline(name) as dsl_pipeline:
            pipeline_func(*placeholder_args)

        # Parameters built from the declared inputs, now carrying defaults.
        params_with_defaults = [
            dsl.PipelineParam(sanitize_k8s_name(input_spec.name, True),
                              param_type=input_spec.type,
                              value=input_spec.default)
            for input_spec in declared_inputs
        ]

        return self._create_pipeline_spec(params_with_defaults, dsl_pipeline)
예제 #2
0
  def _create_pipeline(
      self,
      pipeline_func: Callable[..., Any],
      pipeline_root: Optional[str] = None,
      pipeline_name: Optional[str] = None,
      pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
  ) -> pipeline_spec_pb2.PipelineJob:
    """Compiles a @dsl.pipeline function into a PipelineJob proto.

    Args:
      pipeline_func: Pipeline function decorated with @dsl.pipeline.
      pipeline_root: Root of the pipeline outputs. Optional; when falsy, the
        function's ``output_directory`` attribute (if any) is used. If neither
        yields a value, a warning is issued and a root must be supplied at
        job submission.
      pipeline_name: Optional pipeline name; defaults to the name from the
        function's extracted component interface.
      pipeline_parameters_override: Optional mapping from parameter names to
        values; entries replace the declared defaults in the runtime config.

    Returns:
      A PipelineJob proto whose ``runtime_config`` holds the output directory
      and parameter values and whose ``pipeline_spec`` holds the compiled spec.
    """
    interface = _python_op._extract_component_interface(pipeline_func)
    job_name = pipeline_name or interface.name

    output_root = pipeline_root or getattr(
        pipeline_func, 'output_directory', None)
    if not output_root:
      warnings.warn('pipeline_root is None or empty. A valid pipeline_root '
                    'must be provided at job submission.')

    declared_inputs = interface.inputs or []

    # Name -> type for the declared inputs; the first declaration wins if a
    # name repeats, just like a first-match scan.
    input_type_lookup = {}
    for spec in declared_inputs:
      if spec.name not in input_type_lookup:
        input_type_lookup[spec.name] = spec.type

    # Typed placeholders (no values) used to trace the pipeline function.
    call_args = [
        dsl.PipelineParam(
            sanitize_k8s_name(param_name, True),
            param_type=input_type_lookup.get(param_name))
        for param_name in inspect.signature(pipeline_func).parameters
    ]

    with dsl.Pipeline(job_name) as dsl_pipeline:
      pipeline_func(*call_args)

    # The same declared inputs, this time carrying their default values.
    defaulted_params = [
        dsl.PipelineParam(
            sanitize_k8s_name(spec.name, True),
            param_type=spec.type,
            value=spec.default) for spec in declared_inputs
    ]

    # Rename the root group to a random hex string so its name cannot clash
    # with a template name in the compiled spec.
    root_group = dsl_pipeline.groups[0]
    root_group.name = uuid.uuid4().hex

    pipeline_spec = self._create_pipeline_spec(defaulted_params, dsl_pipeline)

    # Start from the declared defaults; caller overrides take precedence.
    parameter_values = {param.name: param.value for param in defaulted_params}
    if pipeline_parameters_override:
      parameter_values.update(pipeline_parameters_override)

    runtime_config = compiler_utils.build_runtime_config_spec(
        output_directory=output_root, pipeline_parameters=parameter_values)
    pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
    pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))
    return pipeline_job
예제 #3
0
    def _create_pipeline(
        self,
        pipeline_func: Callable[..., Any],
        output_directory: str,
        pipeline_name: Optional[str] = None,
        pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
    ) -> pipeline_spec_pb2.PipelineJob:
        """Creates a pipeline instance and constructs the pipeline job from it.

        The pipeline function is called once inside a ``dsl.Pipeline`` context
        with typed placeholder ``PipelineParam`` arguments so that it registers
        its tasks. The captured pipeline and parameters carrying the declared
        defaults are compiled into a pipeline spec, which is then wrapped in a
        ``PipelineJob`` together with a runtime config.

        Args:
          pipeline_func: Pipeline function with @dsl.pipeline decorator.
          output_directory: The root of the pipeline outputs.
          pipeline_name: The name of the pipeline. Optional; when falsy, the
            name from the function's extracted component interface is used.
          pipeline_parameters_override: The mapping from parameter names to
            values. Optional; its entries replace the declared default values
            in the runtime config, and keys not already present are added.

        Returns:
          A PipelineJob proto representing the compiled pipeline.
        """
        pipeline_meta = _python_op._extract_component_interface(pipeline_func)
        pipeline_name = pipeline_name or pipeline_meta.name
        pipeline_inputs = pipeline_meta.inputs or []

        # Build the name -> type lookup once instead of rescanning the inputs
        # for every signature parameter. setdefault keeps the first
        # declaration of a name, matching a first-match scan.
        input_types = {}
        for input_spec in pipeline_inputs:
            input_types.setdefault(input_spec.name, input_spec.type)

        # Placeholder arguments: typed but without values, so the pipeline
        # function can be invoked to register its tasks. Parameters with no
        # declared input get param_type=None.
        args_list = [
            dsl.PipelineParam(sanitize_k8s_name(arg_name, True),
                              param_type=input_types.get(arg_name))
            for arg_name in inspect.signature(pipeline_func).parameters
        ]

        with dsl.Pipeline(pipeline_name) as dsl_pipeline:
            pipeline_func(*args_list)

        # Parameters for the spec, built from the declared inputs and carrying
        # their default values.
        args_list_with_defaults = [
            dsl.PipelineParam(sanitize_k8s_name(input_spec.name, True),
                              param_type=input_spec.type,
                              value=input_spec.default)
            for input_spec in pipeline_inputs
        ]

        pipeline_spec = self._create_pipeline_spec(
            args_list_with_defaults,
            dsl_pipeline,
        )

        # Declared defaults first (keyed by the sanitized parameter name);
        # caller-supplied overrides take precedence.
        pipeline_parameters = {
            arg.name: arg.value
            for arg in args_list_with_defaults
        }
        pipeline_parameters.update(pipeline_parameters_override or {})
        runtime_config = compiler_utils.build_runtime_config_spec(
            output_directory=output_directory,
            pipeline_parameters=pipeline_parameters)
        pipeline_job = pipeline_spec_pb2.PipelineJob(
            runtime_config=runtime_config)
        pipeline_job.pipeline_spec.update(
            json_format.MessageToDict(pipeline_spec))

        return pipeline_job