def test_build_runtime_parameter_spec(self):
  """Tests that DSL pipeline params map to the expected runtime parameter protos."""
  params = [
      dsl.PipelineParam(name='input1', param_type='Integer', value=99),
      dsl.PipelineParam(name='input2', param_type='String', value='hello'),
      dsl.PipelineParam(name='input3', param_type='Float', value=3.1415926),
      dsl.PipelineParam(name='input4', param_type=None, value=None),
  ]
  # Expected proto expressed as its JSON mapping for readability; parsed
  # into a real PipelineSpec below for a proto-level comparison.
  expected_dict = {
      'runtimeParameters': {
          'input1': {
              'type': 'INT',
              'defaultValue': {
                  'intValue': '99'
              }
          },
          'input2': {
              'type': 'STRING',
              'defaultValue': {
                  'stringValue': 'hello'
              }
          },
          'input3': {
              'type': 'DOUBLE',
              'defaultValue': {
                  'doubleValue': '3.1415926'
              }
          },
          'input4': {
              'type': 'STRING'
          }
      }
  }
  expected_spec = pipeline_spec_pb2.PipelineSpec()
  json_format.ParseDict(expected_dict, expected_spec)

  actual_spec = pipeline_spec_pb2.PipelineSpec(
      runtime_parameters=compiler_utils.build_runtime_parameter_spec(params))

  self.maxDiff = None
  self.assertEqual(expected_spec, actual_spec)
def _create_pipeline_spec(
    self,
    args: List[dsl.PipelineParam],
    pipeline: dsl.Pipeline,
) -> pipeline_spec_pb2.PipelineSpec:
  """Creates the pipeline spec object.

  Args:
    args: The list of pipeline arguments.
    pipeline: The instantiated pipeline object.

  Returns:
    A PipelineSpec proto representing the compiled pipeline.

  Raises:
    ValueError: If the pipeline name is empty.
    NotImplementedError: If the argument is of unsupported types.
  """
  if not pipeline.name:
    raise ValueError('Pipeline name is required.')

  pipeline_spec = pipeline_spec_pb2.PipelineSpec(
      runtime_parameters=compiler_utils.build_runtime_parameter_spec(args))

  pipeline_spec.pipeline_info.name = pipeline.name
  pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
  pipeline_spec.schema_version = 'v2alpha1'

  deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
  importer_tasks = []

  for op in pipeline.ops.values():
    component_spec = op._metadata
    task = pipeline_spec.tasks.add()
    task.CopyFrom(op.task_spec)
    deployment_config.executors[task.executor_label].container.CopyFrom(
        op.container_spec)

    # Insert an importer node for every artifact input that has no
    # upstream producer task within this pipeline.
    for input_name in task.inputs.artifacts:
      if not task.inputs.artifacts[input_name].producer_task:
        type_schema = type_utils.get_input_artifact_type_schema(
            input_name, component_spec.inputs)
        importer_task = importer_node.build_importer_task_spec(
            dependent_task=task,
            input_name=input_name,
            input_type_schema=type_schema)
        importer_tasks.append(importer_task)

        # Rewire the input to consume the importer node's output.
        task.inputs.artifacts[
            input_name].producer_task = importer_task.task_info.name
        task.inputs.artifacts[
            input_name].output_artifact_key = importer_node.OUTPUT_KEY

        # Retrieve the pre-built importer spec
        importer_spec = op.importer_spec[input_name]
        deployment_config.executors[
            importer_task.executor_label].importer.CopyFrom(importer_spec)

  pipeline_spec.deployment_config.Pack(deployment_config)
  pipeline_spec.tasks.extend(importer_tasks)

  return pipeline_spec
def _create_pipeline_spec(
    self,
    args: List[dsl.PipelineParam],
    pipeline: dsl.Pipeline,
) -> pipeline_spec_pb2.PipelineSpec:
  """Creates the pipeline spec object.

  Args:
    args: The list of pipeline arguments.
    pipeline: The instantiated pipeline object.

  Returns:
    A PipelineSpec proto representing the compiled pipeline.

  Raises:
    ValueError: If the pipeline name is empty.
    NotImplementedError: If the argument is of unsupported types.
  """
  if not pipeline.name:
    raise ValueError('Pipeline name is required.')

  pipeline_spec = pipeline_spec_pb2.PipelineSpec()
  pipeline_spec.pipeline_info.name = pipeline.name
  pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
  pipeline_spec.schema_version = 'v2alpha1'

  # Pipeline Parameters.
  # NOTE(review): arguments whose value is None are skipped entirely — no
  # type is recorded for them. Confirm this is intended; the helper-based
  # variant of this method emits a STRING-typed parameter for None values.
  for arg in args:
    if arg.value is None:
      continue
    param = pipeline_spec.runtime_parameters[arg.name]
    # NOTE: bool is a subclass of int in Python, so boolean default values
    # fall into the INT branch below.
    if isinstance(arg.value, int):
      param.type = pipeline_spec_pb2.PrimitiveType.INT
      param.default_value.int_value = arg.value
    elif isinstance(arg.value, float):
      param.type = pipeline_spec_pb2.PrimitiveType.DOUBLE
      param.default_value.double_value = arg.value
    elif isinstance(arg.value, str):
      param.type = pipeline_spec_pb2.PrimitiveType.STRING
      param.default_value.string_value = arg.value
    else:
      raise NotImplementedError(
          'Unexpected parameter type with: "{}".'.format(str(arg.value)))

  deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
  importer_tasks = []

  for op in pipeline.ops.values():
    component_spec = op._metadata
    task = pipeline_spec.tasks.add()
    task.CopyFrom(op.task_spec)
    deployment_config.executors[task.executor_label].container.CopyFrom(
        op.container_spec)

    # Insert an importer node for every artifact input that has no
    # upstream producer task within this pipeline.
    for input_name in task.inputs.artifacts:
      if not task.inputs.artifacts[input_name].producer_task:
        artifact_type = type_utils.get_input_artifact_type_schema(
            input_name, component_spec.inputs)
        # Fix: this assignment was split across lines without a
        # continuation in the source, which is a syntax error.
        importer_task, importer_spec = importer_node.build_importer_spec(
            task, input_name, artifact_type)
        importer_tasks.append(importer_task)

        # Rewire the input to consume the importer node's output.
        task.inputs.artifacts[
            input_name].producer_task = importer_task.task_info.name
        task.inputs.artifacts[
            input_name].output_artifact_key = importer_node.OUTPUT_KEY

        deployment_config.executors[
            importer_task.executor_label].importer.CopyFrom(importer_spec)

  pipeline_spec.deployment_config.Pack(deployment_config)
  pipeline_spec.tasks.extend(importer_tasks)

  return pipeline_spec