Example 1
    def test_build_runtime_config_spec(self):
        """Builds a runtime config and checks parameterValues; None-valued params are dropped."""
        # Note: 'input4' carries value=None, so it must not appear in the result.
        params = {
            'input1':
            _pipeline_param.PipelineParam(
                name='input1', param_type='String', value='test'),
            'input2':
            _pipeline_param.PipelineParam(
                name='input2', param_type='Integer', value=2),
            'input3':
            _pipeline_param.PipelineParam(
                name='input3', param_type='List', value=[1, 2, 3]),
            'input4':
            _pipeline_param.PipelineParam(
                name='input4', param_type='Double', value=None)
        }

        # Expected proto, parsed from its canonical JSON form.
        want = pipeline_spec_pb2.PipelineJob.RuntimeConfig()
        json_format.ParseDict(
            {
                'gcsOutputDirectory': 'gs://path',
                'parameterValues': {
                    'input1': 'test',
                    'input2': 2,
                    'input3': [1, 2, 3]
                }
            }, want)

        got = compiler_utils.build_runtime_config_spec('gs://path', params)
        self.assertEqual(want, got)
Example 2
    def _create_pipeline_job(
        self,
        pipeline_spec: pipeline_spec_pb2.PipelineSpec,
        pipeline_root: str,
        pipeline_parameters: Optional[Mapping[str, Any]] = None,
    ) -> pipeline_spec_pb2.PipelineJob:
        """Assembles a PipelineJob proto around a compiled pipeline spec.

    Args:
      pipeline_spec: The pipeline spec object.
      pipeline_root: The root of the pipeline outputs.
      pipeline_parameters: The mapping from parameter names to values. Optional.

    Returns:
      A PipelineJob proto representing the compiled pipeline.
    """
        # Build the job with its runtime config, then graft the spec in as a
        # Struct via its JSON-dict representation.
        job = pipeline_spec_pb2.PipelineJob(
            runtime_config=compiler_utils.build_runtime_config_spec(
                pipeline_root=pipeline_root,
                pipeline_parameters=pipeline_parameters))
        job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))

        return job
    def test_build_runtime_config_spec(self):
        """A plain string value lands under the legacy 'parameters' field as stringValue."""
        want = pipeline_spec_pb2.PipelineJob.RuntimeConfig()
        json_format.ParseDict(
            {
                'gcsOutputDirectory': 'gs://path',
                'parameters': {
                    'input1': {
                        'stringValue': 'test'
                    }
                }
            }, want)

        got = compiler_utils.build_runtime_config_spec('gs://path',
                                                       {'input1': 'test'})
        self.assertEqual(want, got)
Example 4
  def _create_pipeline_v2(
      self,
      pipeline_func: Callable[..., Any],
      pipeline_root: Optional[str] = None,
      pipeline_name: Optional[str] = None,
      pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
  ) -> pipeline_spec_pb2.PipelineJob:
    """Creates a pipeline instance and constructs the pipeline spec from it.

    Args:
      pipeline_func: Pipeline function with @dsl.pipeline decorator.
      pipeline_root: The root of the pipeline outputs. Optional; falls back to
        the function's `output_directory` attribute if unset.
      pipeline_name: The name of the pipeline. Optional; falls back to the
        name extracted from the component interface.
      pipeline_parameters_override: The mapping from parameter names to values.
        Optional.

    Returns:
      A PipelineJob proto representing the compiled pipeline.

    Raises:
      ValueError: If a key in pipeline_parameters_override does not match any
        declared pipeline argument.
    """

    # Create the arg list with no default values and call pipeline function.
    # Assign type information to the PipelineParam
    pipeline_meta = _python_op._extract_component_interface(pipeline_func)
    pipeline_name = pipeline_name or pipeline_meta.name

    pipeline_root = pipeline_root or getattr(pipeline_func, 'output_directory',
                                             None)
    if not pipeline_root:
      # Only warn here: an empty root is tolerated at compile time but must be
      # supplied when the job is submitted.
      warnings.warn('pipeline_root is None or empty. A valid pipeline_root '
                    'must be provided at job submission.')

    # For each argument in the function signature, look up its declared type
    # from the extracted component interface (None if not declared).
    args_list = []
    signature = inspect.signature(pipeline_func)
    for arg_name in signature.parameters:
      arg_type = None
      for pipeline_input in pipeline_meta.inputs or []:
        if arg_name == pipeline_input.name:
          arg_type = pipeline_input.type
          break
      args_list.append(
          dsl.PipelineParam(
              sanitize_k8s_name(arg_name, True), param_type=arg_type))

    # Invoking the user function inside the Pipeline context records the ops
    # it creates into dsl_pipeline.
    with dsl.Pipeline(pipeline_name) as dsl_pipeline:
      pipeline_func(*args_list)

    self._sanitize_and_inject_artifact(dsl_pipeline)

    # Fill in the default values.
    args_list_with_defaults = []
    if pipeline_meta.inputs:
      args_list_with_defaults = [
          dsl.PipelineParam(
              sanitize_k8s_name(input_spec.name, True),
              param_type=input_spec.type,
              value=input_spec.default) for input_spec in pipeline_meta.inputs
      ]

    # Making the pipeline group name unique to prevent name clashes with templates
    pipeline_group = dsl_pipeline.groups[0]
    temp_pipeline_group_name = uuid.uuid4().hex
    pipeline_group.name = temp_pipeline_group_name

    pipeline_spec = self._create_pipeline_spec(
        args_list_with_defaults,
        dsl_pipeline,
    )

    pipeline_parameters = {
        param.name: param for param in args_list_with_defaults
    }
    # Update pipeline parameters override if there were any.
    pipeline_parameters_override = pipeline_parameters_override or {}
    for k, v in pipeline_parameters_override.items():
      if k not in pipeline_parameters:
        raise ValueError('Pipeline parameter {} does not match any known '
                         'pipeline argument.'.format(k))
      # Mutates the PipelineParam in place so the override reaches the
      # runtime config below.
      pipeline_parameters[k].value = v

    runtime_config = compiler_utils.build_runtime_config_spec(
        output_directory=pipeline_root, pipeline_parameters=pipeline_parameters)
    pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
    # Graft the compiled spec into the job proto via its JSON-dict form.
    pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))

    return pipeline_job
Example 5
    def _create_pipeline(
        self,
        pipeline_func: Callable[..., Any],
        output_directory: str,
        pipeline_name: Optional[str] = None,
        pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
    ) -> pipeline_spec_pb2.PipelineJob:
        """Invokes the user pipeline function and compiles it into a PipelineJob.

    Args:
      pipeline_func: Pipeline function with @dsl.pipeline decorator.
      pipeline_name: The name of the pipeline. Optional.
      output_directory: The root of the pipeline outputs.
      pipeline_parameters_override: The mapping from parameter names to values.
        Optional.

    Returns:
      A PipelineJob proto representing the compiled pipeline.
    """

        # Recover the declared interface (names/types/defaults) of the pipeline.
        pipeline_meta = _python_op._extract_component_interface(pipeline_func)
        pipeline_name = pipeline_name or pipeline_meta.name
        declared_inputs = pipeline_meta.inputs or []

        # One typed, default-less PipelineParam per argument in the signature.
        call_args = []
        for arg_name in inspect.signature(pipeline_func).parameters:
            matched_type = next(
                (spec.type for spec in declared_inputs
                 if spec.name == arg_name), None)
            call_args.append(
                dsl.PipelineParam(sanitize_k8s_name(arg_name, True),
                                  param_type=matched_type))

        # Running the function inside the Pipeline context records its ops.
        with dsl.Pipeline(pipeline_name) as dsl_pipeline:
            pipeline_func(*call_args)

        # Rebuild the params, this time carrying each input's default value.
        defaulted_args = [
            dsl.PipelineParam(sanitize_k8s_name(spec.name, True),
                              param_type=spec.type,
                              value=spec.default) for spec in declared_inputs
        ]

        pipeline_spec = self._create_pipeline_spec(
            defaulted_args,
            dsl_pipeline,
        )

        # name -> default value, then apply any caller overrides on top.
        parameters = {}
        for arg in defaulted_args:
            parameters[arg.name] = arg.value
        if pipeline_parameters_override:
            parameters.update(pipeline_parameters_override)

        runtime_config = compiler_utils.build_runtime_config_spec(
            output_directory=output_directory,
            pipeline_parameters=parameters)
        job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
        # Graft the compiled spec into the job proto via its JSON-dict form.
        job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))

        return job