コード例 #1
0
    def test_build_runtime_parameter_spec(self):
        """Checks build_runtime_parameter_spec against a hand-written proto.

        Covers an int, a str and a float default, plus a parameter that has
        neither a declared type nor a value; the latter is expected to be
        declared as STRING with no default value.
        """
        # (name, param_type, value) for each pipeline parameter, in order.
        param_table = [
            ('input1', 'Integer', 99),
            ('input2', 'String', 'hello'),
            ('input3', 'Float', 3.1415926),
            ('input4', None, None),
        ]
        pipeline_params = [
            dsl.PipelineParam(name=p_name, param_type=p_type, value=p_value)
            for p_name, p_type, p_value in param_table
        ]

        runtime_parameters = {
            'input1': {
                'type': 'INT',
                'defaultValue': {'intValue': '99'},
            },
            'input2': {
                'type': 'STRING',
                'defaultValue': {'stringValue': 'hello'},
            },
            'input3': {
                'type': 'DOUBLE',
                'defaultValue': {'doubleValue': '3.1415926'},
            },
            'input4': {
                'type': 'STRING',
            },
        }
        # ParseDict fills the given message in place and also returns it.
        expected_spec = json_format.ParseDict(
            {'runtimeParameters': runtime_parameters},
            pipeline_spec_pb2.PipelineSpec())

        actual_spec = pipeline_spec_pb2.PipelineSpec(
            runtime_parameters=compiler_utils.build_runtime_parameter_spec(
                pipeline_params))

        # Show the complete proto diff if the assertion fails.
        self.maxDiff = None
        self.assertEqual(expected_spec, actual_spec)
コード例 #2
0
ファイル: compiler.py プロジェクト: tanguycdls/pipelines
    def _create_pipeline_spec(
        self,
        args: List[dsl.PipelineParam],
        pipeline: dsl.Pipeline,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Creates the pipeline spec object.

        Args:
          args: The list of pipeline arguments; converted to runtime
            parameters by compiler_utils.build_runtime_parameter_spec.
          pipeline: The instantiated pipeline object.

        Returns:
          A PipelineSpec proto representing the compiled pipeline.

        Raises:
          ValueError: if the pipeline has no name.
          NotImplementedError: if an argument is of an unsupported type.
            NOTE(review): this method does not raise it directly; it is
            expected to propagate from
            compiler_utils.build_runtime_parameter_spec. Confirm which
            exception type that helper actually raises.
        """
        if not pipeline.name:
            raise ValueError('Pipeline name is required.')

        pipeline_spec = pipeline_spec_pb2.PipelineSpec(
            runtime_parameters=compiler_utils.build_runtime_parameter_spec(
                args))

        pipeline_spec.pipeline_info.name = pipeline.name
        pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
        pipeline_spec.schema_version = 'v2alpha1'

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
        # Importer tasks are collected here and appended after all op tasks
        # at the end of this method.
        importer_tasks = []

        for op in pipeline.ops.values():
            component_spec = op._metadata
            task = pipeline_spec.tasks.add()
            task.CopyFrom(op.task_spec)
            deployment_config.executors[
                task.executor_label].container.CopyFrom(op.container_spec)

            # Artifact inputs with no producer task get an importer node.
            for input_name in task.inputs.artifacts:
                if task.inputs.artifacts[input_name].producer_task:
                    continue

                type_schema = type_utils.get_input_artifact_type_schema(
                    input_name, component_spec.inputs)
                importer_task = importer_node.build_importer_task_spec(
                    dependent_task=task,
                    input_name=input_name,
                    input_type_schema=type_schema)
                importer_tasks.append(importer_task)

                # Rewire the input to consume the importer task's output.
                # Looked up after build_importer_task_spec, which receives
                # `task`, so the reference is taken after that call.
                artifact_input = task.inputs.artifacts[input_name]
                artifact_input.producer_task = importer_task.task_info.name
                artifact_input.output_artifact_key = importer_node.OUTPUT_KEY

                # Retrieve the pre-built importer spec from the op.
                importer_spec = op.importer_spec[input_name]
                deployment_config.executors[
                    importer_task.executor_label].importer.CopyFrom(
                        importer_spec)

        pipeline_spec.deployment_config.Pack(deployment_config)
        pipeline_spec.tasks.extend(importer_tasks)

        return pipeline_spec
コード例 #3
0
    def _create_pipeline_spec(
        self,
        args: List[dsl.PipelineParam],
        pipeline: dsl.Pipeline,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Creates the pipeline spec object.

        Args:
          args: The list of pipeline arguments. Each argument whose value is
            not None becomes a runtime parameter whose type is chosen from the
            Python type of that value (int -> INT, float -> DOUBLE,
            str -> STRING), with the value recorded as its default.
          pipeline: The instantiated pipeline object.

        Returns:
          A PipelineSpec proto representing the compiled pipeline.

        Raises:
          ValueError: if the pipeline has no name.
          NotImplementedError: if an argument value is of an unsupported type.
            bool is rejected explicitly even though it is a subclass of int.
        """
        if not pipeline.name:
            raise ValueError('Pipeline name is required.')

        pipeline_spec = pipeline_spec_pb2.PipelineSpec()
        pipeline_spec.pipeline_info.name = pipeline.name
        pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
        pipeline_spec.schema_version = 'v2alpha1'

        # Pipeline parameters.
        # NOTE(review): arguments whose value is None are skipped, so
        # parameters without a default are not declared in
        # runtime_parameters at all. Confirm this is intended.
        for arg in args:
            value = arg.value
            if value is None:
                continue
            # bool is a subclass of int; without the explicit check, True and
            # False would silently be encoded as INT 1/0. Validate before
            # indexing runtime_parameters, because indexing the map inserts
            # an entry for the key.
            if (isinstance(value, bool)
                    or not isinstance(value, (int, float, str))):
                raise NotImplementedError(
                    'Unexpected parameter type with: "{}".'.format(
                        str(value)))

            param_spec = pipeline_spec.runtime_parameters[arg.name]
            if isinstance(value, int):
                param_spec.type = pipeline_spec_pb2.PrimitiveType.INT
                param_spec.default_value.int_value = value
            elif isinstance(value, float):
                param_spec.type = pipeline_spec_pb2.PrimitiveType.DOUBLE
                param_spec.default_value.double_value = value
            else:  # str, guaranteed by the check above
                param_spec.type = pipeline_spec_pb2.PrimitiveType.STRING
                param_spec.default_value.string_value = value

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
        # Importer tasks are collected here and appended after all op tasks
        # at the end of this method.
        importer_tasks = []

        for op in pipeline.ops.values():
            component_spec = op._metadata
            task = pipeline_spec.tasks.add()
            task.CopyFrom(op.task_spec)
            deployment_config.executors[
                task.executor_label].container.CopyFrom(op.container_spec)

            # Artifact inputs with no producer task get an importer node.
            for input_name in task.inputs.artifacts:
                if task.inputs.artifacts[input_name].producer_task:
                    continue

                artifact_type = type_utils.get_input_artifact_type_schema(
                    input_name, component_spec.inputs)
                importer_task, importer_spec = (
                    importer_node.build_importer_spec(
                        task, input_name, artifact_type))
                importer_tasks.append(importer_task)

                # Rewire the input to consume the importer task's output.
                # Looked up after build_importer_spec, which receives `task`.
                artifact_input = task.inputs.artifacts[input_name]
                artifact_input.producer_task = importer_task.task_info.name
                artifact_input.output_artifact_key = importer_node.OUTPUT_KEY

                deployment_config.executors[
                    importer_task.executor_label].importer.CopyFrom(
                        importer_spec)

        pipeline_spec.deployment_config.Pack(deployment_config)
        pipeline_spec.tasks.extend(importer_tasks)

        return pipeline_spec