示例#1
0
  def test_build_runtime_parameter_spec_with_unsupported_type_should_fail(self):
    """An unsupported parameter type is rejected with a descriptive TypeError."""
    pipeline_params = [
        dsl.PipelineParam(name='input1', param_type='Dict'),
    ]

    # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(
        TypeError, 'Unsupported type "Dict" for argument "input1"'):
      compiler_utils.build_runtime_parameter_spec(pipeline_params)

  def test_build_runtime_parameter_spec(self):
    """Supported parameter types map to the expected runtime parameter spec.

    Covers Integer, String and Float parameters with default values, plus an
    untyped parameter with no value, which is expected to be emitted as a
    STRING parameter without a defaultValue.
    """
    pipeline_params = [
        dsl.PipelineParam(name='input1', param_type='Integer', value=99),
        dsl.PipelineParam(name='input2', param_type='String', value='hello'),
        dsl.PipelineParam(
            name='input3', param_type='Float', value=3.1415926),
        dsl.PipelineParam(name='input4', param_type=None, value=None),
    ]
    # Numeric defaults are written as strings here; json_format.ParseDict
    # accepts string-encoded numbers for numeric proto fields.
    expected_dict = {
        'runtimeParameters': {
            'input1': {
                'type': 'INT',
                'defaultValue': {
                    'intValue': '99'
                }
            },
            'input2': {
                'type': 'STRING',
                'defaultValue': {
                    'stringValue': 'hello'
                }
            },
            'input3': {
                'type': 'DOUBLE',
                'defaultValue': {
                    'doubleValue': '3.1415926'
                }
            },
            'input4': {
                'type': 'STRING'
            }
        }
    }
    expected_spec = pipeline_spec_pb2.PipelineSpec()
    json_format.ParseDict(expected_dict, expected_spec)

    pipeline_spec = pipeline_spec_pb2.PipelineSpec(
        runtime_parameters=compiler_utils.build_runtime_parameter_spec(
            pipeline_params))
    # Show the full proto diff on failure instead of a truncated one.
    self.maxDiff = None
    self.assertEqual(expected_spec, pipeline_spec)
示例#3
0
    def _create_pipeline_spec(
        self,
        args: List[dsl.PipelineParam],
        pipeline: dsl.Pipeline,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Creates the pipeline spec object.

        Builds the runtime-parameter spec from ``args`` and copies every op's
        task spec and container executor into the result. For each task input
        artifact that has no producer task, it inserts an importer task and
        its importer executor, then rewires the input to read from that
        importer.

        Args:
          args: The list of pipeline arguments.
          pipeline: The instantiated pipeline object.

        Returns:
          A PipelineSpec proto representing the compiled pipeline.

        Raises:
          ValueError: If the pipeline has no name.
          TypeError: If an argument is of an unsupported type (raised by
            ``compiler_utils.build_runtime_parameter_spec``).
        """
        if not pipeline.name:
            raise ValueError('Pipeline name is required.')

        pipeline_spec = pipeline_spec_pb2.PipelineSpec(
            runtime_parameters=compiler_utils.build_runtime_parameter_spec(
                args))

        pipeline_spec.pipeline_info.name = pipeline.name
        pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
        pipeline_spec.schema_version = 'v2alpha1'

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
        # Importer tasks are collected here and appended to
        # pipeline_spec.tasks only after every op task has been added.
        importer_tasks = []

        for op in pipeline.ops.values():
            component_spec = op._metadata
            task = pipeline_spec.tasks.add()
            task.CopyFrom(op.task_spec)
            deployment_config.executors[
                task.executor_label].container.CopyFrom(op.container_spec)

            # An input artifact without a producer task has to be fed by an
            # importer node; insert one for each such input.
            for input_name in task.inputs.artifacts:
                artifact_spec = task.inputs.artifacts[input_name]
                if artifact_spec.producer_task:
                    continue

                type_schema = type_utils.get_input_artifact_type_schema(
                    input_name, component_spec.inputs)

                importer_task = importer_node.build_importer_task_spec(
                    dependent_task=task,
                    input_name=input_name,
                    input_type_schema=type_schema)
                importer_tasks.append(importer_task)

                # Point the input at the importer task's output artifact.
                artifact_spec.producer_task = importer_task.task_info.name
                artifact_spec.output_artifact_key = importer_node.OUTPUT_KEY

                # Retrieve the pre-built importer spec for this input.
                importer_spec = op.importer_spec[input_name]
                deployment_config.executors[
                    importer_task.executor_label].importer.CopyFrom(
                        importer_spec)

        pipeline_spec.deployment_config.Pack(deployment_config)
        pipeline_spec.tasks.extend(importer_tasks)

        return pipeline_spec