def test_build_runtime_config_spec(self):
  expected_dict = {
      'gcsOutputDirectory': 'gs://path',
      'parameterValues': {
          'input1': 'test',
          'input2': 2,
          'input3': [1, 2, 3]
      }
  }
  expected_spec = pipeline_spec_pb2.PipelineJob.RuntimeConfig()
  json_format.ParseDict(expected_dict, expected_spec)

  runtime_config = compiler_utils.build_runtime_config_spec(
      'gs://path', {
          'input1':
              _pipeline_param.PipelineParam(
                  name='input1', param_type='String', value='test'),
          'input2':
              _pipeline_param.PipelineParam(
                  name='input2', param_type='Integer', value=2),
          'input3':
              _pipeline_param.PipelineParam(
                  name='input3', param_type='List', value=[1, 2, 3]),
          'input4':
              _pipeline_param.PipelineParam(
                  name='input4', param_type='Double', value=None)
      })
  self.assertEqual(expected_spec, runtime_config)
def _create_pipeline_job(
    self,
    pipeline_spec: pipeline_spec_pb2.PipelineSpec,
    pipeline_root: str,
    pipeline_parameters: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob:
  """Creates the pipeline job spec object.

  Args:
    pipeline_spec: The pipeline spec object.
    pipeline_root: The root of the pipeline outputs.
    pipeline_parameters: The mapping from parameter names to values. Optional.

  Returns:
    A PipelineJob proto representing the compiled pipeline.
  """
  runtime_config = compiler_utils.build_runtime_config_spec(
      pipeline_root=pipeline_root, pipeline_parameters=pipeline_parameters)
  pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
  pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))
  return pipeline_job
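
# A minimal sketch (not part of the compiler code above) of how the PipelineJob
# proto returned by _create_pipeline_job might be persisted as a JSON job spec.
# The helper name and the 'pipeline_job.json' path are assumptions for
# illustration; only json_format.MessageToDict and json.dump are used.
import json

from google.protobuf import json_format


def _write_pipeline_job(pipeline_job, output_path='pipeline_job.json'):
  # Convert the PipelineJob proto (runtime_config plus the embedded
  # pipeline_spec struct) into a plain dict, then dump it as JSON.
  job_dict = json_format.MessageToDict(pipeline_job)
  with open(output_path, 'w') as f:
    json.dump(job_dict, f, indent=2, sort_keys=True)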
def test_build_runtime_config_spec(self):
  expected_dict = {
      'gcsOutputDirectory': 'gs://path',
      'parameters': {
          'input1': {
              'stringValue': 'test'
          }
      }
  }
  expected_spec = pipeline_spec_pb2.PipelineJob.RuntimeConfig()
  json_format.ParseDict(expected_dict, expected_spec)

  runtime_config = compiler_utils.build_runtime_config_spec(
      'gs://path', {'input1': 'test'})
  self.assertEqual(expected_spec, runtime_config)
def _create_pipeline_v2(
    self,
    pipeline_func: Callable[..., Any],
    pipeline_root: Optional[str] = None,
    pipeline_name: Optional[str] = None,
    pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob:
  """Creates a pipeline instance and constructs the pipeline spec from it.

  Args:
    pipeline_func: Pipeline function with @dsl.pipeline decorator.
    pipeline_root: The root of the pipeline outputs. Optional.
    pipeline_name: The name of the pipeline. Optional.
    pipeline_parameters_override: The mapping from parameter names to values.
      Optional.

  Returns:
    A PipelineJob proto representing the compiled pipeline.
  """

  # Create the arg list with no default values and call pipeline function.
  # Assign type information to the PipelineParam.
  pipeline_meta = _python_op._extract_component_interface(pipeline_func)
  pipeline_name = pipeline_name or pipeline_meta.name

  pipeline_root = pipeline_root or getattr(pipeline_func, 'output_directory',
                                           None)
  if not pipeline_root:
    warnings.warn('pipeline_root is None or empty. A valid pipeline_root '
                  'must be provided at job submission.')

  args_list = []
  signature = inspect.signature(pipeline_func)
  for arg_name in signature.parameters:
    arg_type = None
    for pipeline_input in pipeline_meta.inputs or []:
      if arg_name == pipeline_input.name:
        arg_type = pipeline_input.type
        break
    args_list.append(
        dsl.PipelineParam(
            sanitize_k8s_name(arg_name, True), param_type=arg_type))

  with dsl.Pipeline(pipeline_name) as dsl_pipeline:
    pipeline_func(*args_list)

  self._sanitize_and_inject_artifact(dsl_pipeline)

  # Fill in the default values.
  args_list_with_defaults = []
  if pipeline_meta.inputs:
    args_list_with_defaults = [
        dsl.PipelineParam(
            sanitize_k8s_name(input_spec.name, True),
            param_type=input_spec.type,
            value=input_spec.default) for input_spec in pipeline_meta.inputs
    ]

  # Make the pipeline group name unique to prevent name clashes with templates.
  pipeline_group = dsl_pipeline.groups[0]
  temp_pipeline_group_name = uuid.uuid4().hex
  pipeline_group.name = temp_pipeline_group_name

  pipeline_spec = self._create_pipeline_spec(
      args_list_with_defaults,
      dsl_pipeline,
  )

  pipeline_parameters = {
      param.name: param for param in args_list_with_defaults
  }
  # Apply the pipeline parameter overrides, if there were any.
  pipeline_parameters_override = pipeline_parameters_override or {}
  for k, v in pipeline_parameters_override.items():
    if k not in pipeline_parameters:
      raise ValueError('Pipeline parameter {} does not match any known '
                       'pipeline argument.'.format(k))
    pipeline_parameters[k].value = v

  runtime_config = compiler_utils.build_runtime_config_spec(
      output_directory=pipeline_root,
      pipeline_parameters=pipeline_parameters)
  pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
  pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))

  return pipeline_job
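
# A hedged usage sketch of how a pipeline definition typically reaches
# _create_pipeline_v2 through the public compile entry point. The
# @dsl.pipeline decorator, components.load_component_from_text, and
# kfp.v2.compiler.Compiler are existing KFP APIs; the print_op component, the
# output path, and the assumption that compile() forwards a pipeline_parameters
# mapping to pipeline_parameters_override are illustrative only.
from kfp import components, dsl
from kfp.v2 import compiler

print_op = components.load_component_from_text("""
name: Print text
inputs:
- {name: text, type: String}
implementation:
  container:
    image: alpine
    command: [echo, {inputValue: text}]
""")


@dsl.pipeline(name='hello-pipeline')
def hello_pipeline(text: str = 'hello world'):
  print_op(text=text)


if __name__ == '__main__':
  # Compiles the pipeline function into a v2 pipeline job spec; the override
  # replaces the default value of the 'text' parameter in the runtime config.
  compiler.Compiler().compile(
      pipeline_func=hello_pipeline,
      package_path='hello_pipeline.json',
      pipeline_parameters={'text': 'overridden at compile time'})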
def _create_pipeline(
    self,
    pipeline_func: Callable[..., Any],
    output_directory: str,
    pipeline_name: Optional[str] = None,
    pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob:
  """Creates a pipeline instance and constructs the pipeline spec from it.

  Args:
    pipeline_func: Pipeline function with @dsl.pipeline decorator.
    output_directory: The root of the pipeline outputs.
    pipeline_name: The name of the pipeline. Optional.
    pipeline_parameters_override: The mapping from parameter names to values.
      Optional.

  Returns:
    A PipelineJob proto representing the compiled pipeline.
  """

  # Create the arg list with no default values and call pipeline function.
  # Assign type information to the PipelineParam.
  pipeline_meta = _python_op._extract_component_interface(pipeline_func)
  pipeline_name = pipeline_name or pipeline_meta.name

  args_list = []
  signature = inspect.signature(pipeline_func)
  for arg_name in signature.parameters:
    arg_type = None
    for pipeline_input in pipeline_meta.inputs or []:
      if arg_name == pipeline_input.name:
        arg_type = pipeline_input.type
        break
    args_list.append(
        dsl.PipelineParam(
            sanitize_k8s_name(arg_name, True), param_type=arg_type))

  with dsl.Pipeline(pipeline_name) as dsl_pipeline:
    pipeline_func(*args_list)

  # Fill in the default values.
  args_list_with_defaults = []
  if pipeline_meta.inputs:
    args_list_with_defaults = [
        dsl.PipelineParam(
            sanitize_k8s_name(input_spec.name, True),
            param_type=input_spec.type,
            value=input_spec.default) for input_spec in pipeline_meta.inputs
    ]

  pipeline_spec = self._create_pipeline_spec(
      args_list_with_defaults,
      dsl_pipeline,
  )

  pipeline_parameters = {
      arg.name: arg.value for arg in args_list_with_defaults
  }
  # Apply the pipeline parameter overrides, if there were any.
  pipeline_parameters.update(pipeline_parameters_override or {})

  runtime_config = compiler_utils.build_runtime_config_spec(
      output_directory=output_directory,
      pipeline_parameters=pipeline_parameters)
  pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
  pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))

  return pipeline_job