Ejemplo n.º 1
0
    def _create_pipeline(
        self,
        pipeline_func: Callable[..., Any],
        pipeline_name: Optional[str] = None,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Builds the pipeline spec (IR) from a decorated pipeline function.

        Args:
          pipeline_func: Pipeline function with @dsl.pipeline decorator.
          pipeline_name: The name of the pipeline. Optional.

        Returns:
          The IR representation (pipeline spec) of the pipeline.
        """
        # Pull the declared component interface off the decorated function.
        pipeline_meta = python_op._extract_component_interface(pipeline_func)
        pipeline_name = pipeline_name or pipeline_meta.name

        # First-wins lookup from declared input name to its type.
        declared_types = {}
        for input_spec in pipeline_meta.inputs or []:
            declared_types.setdefault(input_spec.name, input_spec.type)

        # Typed placeholder params (no default values) used to trace the
        # pipeline by invoking the user's function.
        args_list = [
            dsl.PipelineParam(sanitize_k8s_name(param_name, True),
                              param_type=declared_types.get(param_name))
            for param_name in inspect.signature(pipeline_func).parameters
        ]

        with dsl.Pipeline(pipeline_name) as dsl_pipeline:
            pipeline_func(*args_list)

        # Same params again, now carrying the declared default values.
        args_list_with_defaults = [
            dsl.PipelineParam(sanitize_k8s_name(input_spec.name, True),
                              param_type=input_spec.type,
                              value=input_spec.default)
            for input_spec in pipeline_meta.inputs or []
        ]

        return self._create_pipeline_spec(
            args_list_with_defaults,
            dsl_pipeline,
        )
    def test_build_task_inputs_spec(self):
        """Outputs produced inside the current DAG become task-output inputs;
        outputs from other DAGs become component inputs."""
        param_specs = [
            ('output1', 'Dataset', 'op-1'),
            ('output2', 'Integer', 'op-2'),
            ('output3', 'Model', 'op-3'),
            ('output4', 'Double', 'op-4'),
        ]
        pipeline_params = [
            dsl.PipelineParam(name=name, param_type=ptype, op_name=op)
            for name, ptype, op in param_specs
        ]
        tasks_in_current_dag = ['op-1', 'op-2']
        expected_dict = {
            'inputs': {
                'artifacts': {
                    'op-1-output1': {
                        'taskOutputArtifact': {
                            'producerTask': 'task-op-1',
                            'outputArtifactKey': 'output1',
                        }
                    },
                    'op-3-output3': {
                        'componentInputArtifact': 'op-3-output3'
                    },
                },
                'parameters': {
                    'op-2-output2': {
                        'taskOutputParameter': {
                            'producerTask': 'task-op-2',
                            'outputParameterKey': 'output2',
                        }
                    },
                    'op-4-output4': {
                        'componentInputParameter': 'op-4-output4'
                    },
                },
            }
        }
        # ParseDict returns the message it populated.
        expected_spec = json_format.ParseDict(
            expected_dict, pipeline_spec_pb2.PipelineTaskSpec())

        actual_spec = pipeline_spec_pb2.PipelineTaskSpec()
        dsl_component_spec.build_task_inputs_spec(actual_spec, pipeline_params,
                                                  tasks_in_current_dag)

        self.assertEqual(expected_spec, actual_spec)
Ejemplo n.º 3
0
  def test_build_runtime_parameter_spec_with_unsupported_type_should_fail(self):
    """Compiling a parameter of an unsupported type must raise TypeError."""
    pipeline_params = [
        dsl.PipelineParam(name='input1', param_type='Dict'),
    ]

    # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
    # assertRaisesRegex is the supported spelling with identical behavior.
    with self.assertRaisesRegex(
        TypeError, 'Unsupported type "Dict" for argument "input1"'):
      compiler_utils.build_runtime_parameter_spec(pipeline_params)
    def test_build_runtime_parameter_spec(self):
        """Typed params map to INT/STRING/DOUBLE specs carrying their default
        values; an untyped param falls back to STRING with no default."""
        self.maxDiff = None
        pipeline_params = [
            dsl.PipelineParam(name='input1', param_type='Integer', value=99),
            dsl.PipelineParam(name='input2', param_type='String',
                              value='hello'),
            dsl.PipelineParam(name='input3', param_type='Float',
                              value=3.1415926),
            dsl.PipelineParam(name='input4', param_type=None, value=None),
        ]
        expected_dict = {
            'runtimeParameters': {
                'input1': {
                    'type': 'INT',
                    'defaultValue': {'intValue': '99'},
                },
                'input2': {
                    'type': 'STRING',
                    'defaultValue': {'stringValue': 'hello'},
                },
                'input3': {
                    'type': 'DOUBLE',
                    'defaultValue': {'doubleValue': '3.1415926'},
                },
                'input4': {'type': 'STRING'},
            }
        }
        # ParseDict returns the message it populated.
        expected_spec = json_format.ParseDict(expected_dict,
                                              pipeline_spec_pb2.PipelineSpec())

        actual_spec = pipeline_spec_pb2.PipelineSpec(
            runtime_parameters=compiler_utils.build_runtime_parameter_spec(
                pipeline_params))
        self.assertEqual(expected_spec, actual_spec)
    def test_build_component_inputs_spec(self):
        """Artifact-typed params land under inputDefinitions.artifacts while
        plain-typed params land under inputDefinitions.parameters."""
        pipeline_params = [
            dsl.PipelineParam(name='input1', param_type='Dataset'),
            dsl.PipelineParam(name='input2', param_type='Integer'),
            dsl.PipelineParam(name='input3', param_type='String'),
            dsl.PipelineParam(name='input4', param_type='Float'),
        ]
        expected_dict = {
            'inputDefinitions': {
                'artifacts': {
                    'input1': {
                        'artifactType': {
                            'instanceSchema':
                            'properties:\ntitle: kfp.Dataset\ntype: object\n'
                        }
                    }
                },
                'parameters': {
                    'input2': {'type': 'INT'},
                    'input3': {'type': 'STRING'},
                    'input4': {'type': 'DOUBLE'},
                },
            }
        }
        # ParseDict returns the message it populated.
        expected_spec = json_format.ParseDict(
            expected_dict, pipeline_spec_pb2.ComponentSpec())

        actual_spec = pipeline_spec_pb2.ComponentSpec()
        dsl_component_spec.build_component_inputs_spec(actual_spec,
                                                       pipeline_params)

        self.assertEqual(expected_spec, actual_spec)
Ejemplo n.º 6
0
  def _create_pipeline(
      self,
      pipeline_func: Callable[..., Any],
      pipeline_root: Optional[str] = None,
      pipeline_name: Optional[str] = None,
      pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
  ) -> pipeline_spec_pb2.PipelineJob:
    """Compiles a decorated pipeline function into a PipelineJob proto.

    Args:
      pipeline_func: Pipeline function with @dsl.pipeline decorator.
      pipeline_root: The root of the pipeline outputs. Optional.
      pipeline_name: The name of the pipeline. Optional.
      pipeline_parameters_override: The mapping from parameter names to values.
        Optional.

    Returns:
      A PipelineJob proto representing the compiled pipeline.
    """
    # Pull the declared component interface off the decorated function.
    pipeline_meta = _python_op._extract_component_interface(pipeline_func)
    pipeline_name = pipeline_name or pipeline_meta.name

    # Fall back to the directory baked into the decorator, if any.
    pipeline_root = pipeline_root or getattr(pipeline_func, 'output_directory',
                                             None)
    if not pipeline_root:
      warnings.warn('pipeline_root is None or empty. A valid pipeline_root '
                    'must be provided at job submission.')

    # First-wins lookup from declared input name to its type.
    declared_types = {}
    for input_spec in pipeline_meta.inputs or []:
      declared_types.setdefault(input_spec.name, input_spec.type)

    # Typed placeholder params (no default values) used to trace the pipeline
    # by invoking the user's function.
    args_list = [
        dsl.PipelineParam(
            sanitize_k8s_name(param_name, True),
            param_type=declared_types.get(param_name))
        for param_name in inspect.signature(pipeline_func).parameters
    ]

    with dsl.Pipeline(pipeline_name) as dsl_pipeline:
      pipeline_func(*args_list)

    # Same params again, now carrying the declared default values.
    args_list_with_defaults = [
        dsl.PipelineParam(
            sanitize_k8s_name(input_spec.name, True),
            param_type=input_spec.type,
            value=input_spec.default)
        for input_spec in pipeline_meta.inputs or []
    ]

    # Give the root group a unique name to prevent name clashes with templates.
    dsl_pipeline.groups[0].name = uuid.uuid4().hex

    pipeline_spec = self._create_pipeline_spec(
        args_list_with_defaults,
        dsl_pipeline,
    )

    # Runtime parameter values: declared defaults, then caller overrides.
    pipeline_parameters = {
        arg.name: arg.value for arg in args_list_with_defaults
    }
    pipeline_parameters.update(pipeline_parameters_override or {})

    runtime_config = compiler_utils.build_runtime_config_spec(
        output_directory=pipeline_root, pipeline_parameters=pipeline_parameters)
    pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
    pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))

    return pipeline_job
Ejemplo n.º 7
0
    def _create_pipeline(
        self,
        pipeline_func: Callable[..., Any],
        output_directory: str,
        pipeline_name: Optional[str] = None,
        pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
    ) -> pipeline_spec_pb2.PipelineJob:
        """Compiles a decorated pipeline function into a PipelineJob proto.

        Args:
          pipeline_func: Pipeline function with @dsl.pipeline decorator.
          output_directory: The root of the pipeline outputs.
          pipeline_name: The name of the pipeline. Optional.
          pipeline_parameters_override: The mapping from parameter names to
            values. Optional.

        Returns:
          A PipelineJob proto representing the compiled pipeline.
        """
        # Pull the declared component interface off the decorated function.
        pipeline_meta = _python_op._extract_component_interface(pipeline_func)
        pipeline_name = pipeline_name or pipeline_meta.name

        # First-wins lookup from declared input name to its type.
        declared_types = {}
        for input_spec in pipeline_meta.inputs or []:
            declared_types.setdefault(input_spec.name, input_spec.type)

        # Typed placeholder params (no default values) used to trace the
        # pipeline by invoking the user's function.
        args_list = [
            dsl.PipelineParam(sanitize_k8s_name(param_name, True),
                              param_type=declared_types.get(param_name))
            for param_name in inspect.signature(pipeline_func).parameters
        ]

        with dsl.Pipeline(pipeline_name) as dsl_pipeline:
            pipeline_func(*args_list)

        # Same params again, now carrying the declared default values.
        args_list_with_defaults = [
            dsl.PipelineParam(sanitize_k8s_name(input_spec.name, True),
                              param_type=input_spec.type,
                              value=input_spec.default)
            for input_spec in pipeline_meta.inputs or []
        ]

        pipeline_spec = self._create_pipeline_spec(
            args_list_with_defaults,
            dsl_pipeline,
        )

        # Runtime parameter values: declared defaults, then caller overrides.
        pipeline_parameters = {
            arg.name: arg.value
            for arg in args_list_with_defaults
        }
        pipeline_parameters.update(pipeline_parameters_override or {})

        runtime_config = compiler_utils.build_runtime_config_spec(
            output_directory=output_directory,
            pipeline_parameters=pipeline_parameters)
        pipeline_job = pipeline_spec_pb2.PipelineJob(
            runtime_config=runtime_config)
        pipeline_job.pipeline_spec.update(
            json_format.MessageToDict(pipeline_spec))

        return pipeline_job