def test_build_runtime_parameter_spec_with_unsupported_type_should_fail(self):
    """Verifies that an unsupported parameter type raises a TypeError.

    'Dict' is not a runtime-parameter type supported by
    compiler_utils.build_runtime_parameter_spec, so building the spec
    must fail with a descriptive message naming both the type and the
    argument.
    """
    pipeline_params = [
        dsl.PipelineParam(name='input1', param_type='Dict'),
    ]
    # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
    # use assertRaisesRegex instead.
    with self.assertRaisesRegex(
        TypeError, 'Unsupported type "Dict" for argument "input1"'):
        compiler_utils.build_runtime_parameter_spec(pipeline_params)
def test_build_runtime_parameter_spec(self):
    """Checks that supported parameter types map to the expected proto.

    Covers Integer -> INT, String -> STRING, Float -> DOUBLE, and an
    untyped/valueless parameter, which defaults to STRING with no
    defaultValue.
    """
    params = [
        dsl.PipelineParam(name='input1', param_type='Integer', value=99),
        dsl.PipelineParam(name='input2', param_type='String', value='hello'),
        dsl.PipelineParam(name='input3', param_type='Float', value=3.1415926),
        dsl.PipelineParam(name='input4', param_type=None, value=None),
    ]

    # Expected result expressed as proto JSON, then parsed into a
    # PipelineSpec for a structured comparison.
    expected_json = {
        'runtimeParameters': {
            'input1': {
                'type': 'INT',
                'defaultValue': {
                    'intValue': '99'
                }
            },
            'input2': {
                'type': 'STRING',
                'defaultValue': {
                    'stringValue': 'hello'
                }
            },
            'input3': {
                'type': 'DOUBLE',
                'defaultValue': {
                    'doubleValue': '3.1415926'
                }
            },
            'input4': {
                'type': 'STRING'
            }
        }
    }
    expected = pipeline_spec_pb2.PipelineSpec()
    json_format.ParseDict(expected_json, expected)

    actual = pipeline_spec_pb2.PipelineSpec(
        runtime_parameters=compiler_utils.build_runtime_parameter_spec(
            params))

    self.maxDiff = None  # show the full proto diff on failure
    self.assertEqual(expected, actual)
def _create_pipeline_spec(
    self,
    args: List[dsl.PipelineParam],
    pipeline: dsl.Pipeline,
) -> pipeline_spec_pb2.PipelineSpec:
    """Creates the pipeline spec object.

    Builds a PipelineSpec proto from the instantiated pipeline: copies
    each op's task and container specs, and inserts importer tasks for
    any artifact input that has no producer task within the pipeline.

    Args:
      args: The list of pipeline arguments, used to populate the
        pipeline's runtime parameters.
      pipeline: The instantiated pipeline object.

    Returns:
      A PipelineSpec proto representing the compiled pipeline.

    Raises:
      ValueError: If the pipeline has no name.
      # NOTE(review): build_runtime_parameter_spec appears to raise
      # TypeError for unsupported argument types — confirm against
      # compiler_utils before relying on this.
    """
    if not pipeline.name:
        raise ValueError('Pipeline name is required.')

    pipeline_spec = pipeline_spec_pb2.PipelineSpec(
        runtime_parameters=compiler_utils.build_runtime_parameter_spec(
            args))

    pipeline_spec.pipeline_info.name = pipeline.name
    pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
    pipeline_spec.schema_version = 'v2alpha1'

    deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
    # Importer tasks are collected separately and appended after the
    # loop, so iteration over pipeline_spec.tasks is not perturbed
    # while tasks are still being added.
    importer_tasks = []

    for op in pipeline.ops.values():
        component_spec = op._metadata
        task = pipeline_spec.tasks.add()
        task.CopyFrom(op.task_spec)
        # Register the op's container under its executor label so the
        # task spec and deployment config stay linked.
        deployment_config.executors[
            task.executor_label].container.CopyFrom(op.container_spec)

        # Check if need to insert importer node: an artifact input with
        # no producer_task is fed from outside the pipeline and needs an
        # importer task to bring it in.
        for input_name in task.inputs.artifacts:
            if not task.inputs.artifacts[input_name].producer_task:
                type_schema = type_utils.get_input_artifact_type_schema(
                    input_name, component_spec.inputs)
                importer_task = importer_node.build_importer_task_spec(
                    dependent_task=task,
                    input_name=input_name,
                    input_type_schema=type_schema)
                importer_tasks.append(importer_task)

                # Rewire the consuming task to read from the importer's
                # output instead of an (absent) upstream producer.
                task.inputs.artifacts[
                    input_name].producer_task = importer_task.task_info.name
                task.inputs.artifacts[
                    input_name].output_artifact_key = importer_node.OUTPUT_KEY

                # Retrieve the pre-built importer spec
                importer_spec = op.importer_spec[input_name]
                deployment_config.executors[
                    importer_task.executor_label].importer.CopyFrom(
                        importer_spec)

    # Pack embeds the deployment config as a google.protobuf.Any.
    pipeline_spec.deployment_config.Pack(deployment_config)
    pipeline_spec.tasks.extend(importer_tasks)

    return pipeline_spec