def test_remove_task_name_prefix(self):
    """Checks `task-` prefix stripping and rejection of unprefixed names."""
    stripped_name = dsl_utils.remove_task_name_prefix('task-my-component')
    self.assertEqual('my-component', stripped_name)

    # A name that does not carry the `task-` prefix is expected to trip an
    # assertion rather than be returned unchanged.
    with self.assertRaises(AssertionError):
        dsl_utils.remove_task_name_prefix('my-component')
def update_task_inputs_spec(
    task_spec: pipeline_spec_pb2.PipelineTaskSpec,
    parent_component_inputs: pipeline_spec_pb2.ComponentInputsSpec,
    pipeline_params: List[_pipeline_param.PipelineParam],
    tasks_in_current_dag: List[str],
    input_parameters_in_current_dag: List[str],
    input_artifacts_in_current_dag: List[str],
) -> None:
    """Redirects cross-DAG task inputs to the parent component's inputs.

    A task input may reference an output produced outside its immediate DAG.
    For instance::

      random_num = random_num_op(...)
      with dsl.Condition(random_num.output > 5):
        print_op('%s > 5' % random_num.output)

    Here `dsl.Condition` forms a sub-DAG that contains the single `print_op`
    task, and that task reads the output of `random_num`, which lives outside
    the sub-DAG. Such a cross-DAG reference is disallowed when compiling to
    IR. The fix has two steps: first "punch a hole" in the sub-DAG so the
    value is exposed as one of the sub-DAG component's inputs (if it is not
    already), then call this function so the tasks inside the sub-DAG read
    that component input instead of referencing the producer task directly.

    `task_spec` is modified in place.

    Args:
      task_spec: The task spec whose inputs spec is rewritten.
      parent_component_inputs: The input spec of the task's parent component.
      pipeline_params: The list of pipeline params. Not read by this
        function.
      tasks_in_current_dag: Names of the tasks that live in the same DAG.
      input_parameters_in_current_dag: Names of the input parameters of the
        DAG component.
      input_artifacts_in_current_dag: Names of the input artifacts of the DAG
        component.

    Raises:
      AssertionError: If the component input name computed for a rewritten
        input is not present in `parent_component_inputs`.
    """
    if not hasattr(task_spec, 'inputs'):
        return

    for input_name in getattr(task_spec.inputs, 'parameters', []):
        parameter_spec = task_spec.inputs.parameters[input_name]
        kind = parameter_spec.WhichOneof('kind')

        if kind == 'task_output_parameter':
            producer_task = (
                parameter_spec.task_output_parameter.producer_task)
            if producer_task in tasks_in_current_dag:
                # The producer is a sibling task; the reference is legal.
                continue
            producer_output = _pipeline_param.PipelineParam(
                name=parameter_spec.task_output_parameter
                .output_parameter_key,
                op_name=dsl_utils.remove_task_name_prefix(producer_task))
            hoisted_name = additional_input_name_for_pipelineparam(
                producer_output)
        elif kind == 'component_input_parameter':
            current_name = parameter_spec.component_input_parameter
            if current_name in input_parameters_in_current_dag:
                # Already refers to an input of the current DAG component.
                continue
            hoisted_name = additional_input_name_for_pipelineparam(
                current_name)
        else:
            # Any other kind of input is left untouched.
            continue

        assert hoisted_name in parent_component_inputs.parameters
        parameter_spec.component_input_parameter = hoisted_name

    for input_name in getattr(task_spec.inputs, 'artifacts', []):
        artifact_spec = task_spec.inputs.artifacts[input_name]
        kind = artifact_spec.WhichOneof('kind')

        if kind == 'task_output_artifact':
            producer_task = artifact_spec.task_output_artifact.producer_task
            if producer_task in tasks_in_current_dag:
                # The producer is a sibling task; the reference is legal.
                continue
            producer_output = _pipeline_param.PipelineParam(
                name=artifact_spec.task_output_artifact.output_artifact_key,
                op_name=dsl_utils.remove_task_name_prefix(producer_task))
            hoisted_name = additional_input_name_for_pipelineparam(
                producer_output)
        elif kind == 'component_input_artifact':
            current_name = artifact_spec.component_input_artifact
            if current_name in input_artifacts_in_current_dag:
                # Already refers to an input of the current DAG component.
                continue
            hoisted_name = additional_input_name_for_pipelineparam(
                current_name)
        else:
            # Any other kind of input is left untouched.
            continue

        assert hoisted_name in parent_component_inputs.artifacts
        artifact_spec.component_input_artifact = hoisted_name
def _update_loop_specs(
    self,
    group: dsl.OpsGroup,
    subgroup: _GroupOrOp,
    group_component_spec: pipeline_spec_pb2.ComponentSpec,
    subgroup_component_spec: pipeline_spec_pb2.ComponentSpec,
    subgroup_task_spec: pipeline_spec_pb2.PipelineTaskSpec,
) -> None:
    """Update IR specs for loop.

    Every parameter input of `subgroup_task_spec` is interpreted as a loop
    argument. Depending on the loop argument name, the task's
    `parameter_iterator` is filled in either with a reference to an input
    parameter (dynamic loop arguments) or with a JSON-serialized list of raw
    values (static loop arguments). The original input is then removed from
    the task spec and the loop argument from the parent component spec.

    Args:
      group: The dsl.ParallelFor OpsGroup.
      subgroup: One of the subgroups of dsl.ParallelFor. Not read by this
        method.
      group_component_spec: The component spec of the group to update in
        place.
      subgroup_component_spec: The component spec of the subgroup to update.
      subgroup_task_spec: The task spec of the subgroup to update.

    Raises:
      NotImplementedError: If a subvar is used together with dynamic loop
        arguments.
      AssertionError: If an input is not recognized as a loop argument, or if
        a dynamic loop argument is found while `group.items_is_pipeline_param`
        is falsy.
    """
    # Snapshot the input names up front: the loop body writes to
    # `subgroup_task_spec.inputs.parameters` and pops entries from the task
    # spec, so the live container must not be iterated directly.
    input_names = list(subgroup_task_spec.inputs.parameters)

    for input_name in input_names:
        input_spec = subgroup_task_spec.inputs.parameters[input_name]

        if input_spec.HasField('component_input_parameter'):
            loop_argument_name = input_spec.component_input_parameter
        else:
            producer_task_name = dsl_utils.remove_task_name_prefix(
                input_spec.task_output_parameter.producer_task)
            producer_task_output_key = (
                input_spec.task_output_parameter.output_parameter_key)
            loop_argument_name = '{}-{}'.format(producer_task_name,
                                                producer_task_output_key)

        # Loop arguments are from dynamic input: pipeline param or task output
        if _for_loop.LoopArguments.name_is_withparams_loop_argument(
                loop_argument_name):

            arg_and_var_name = (
                _for_loop.LoopArgumentVariable
                .parse_loop_args_name_and_this_var_name(loop_argument_name))
            # The current IR representation is insufficient for referencing a
            # subvar which is a key in a list of dictionaries.
            if arg_and_var_name:
                raise NotImplementedError(
                    'Use subvar in dsl.ParallelFor with dynamic loop '
                    'arguments is not supported. Got subvar: {}'.format(
                        arg_and_var_name[1]))

            assert group.items_is_pipeline_param
            pipeline_param = group.loop_args.items_or_pipeline_param
            input_parameter_name = pipeline_param.full_name

            # Correct loop argument input type in the parent component spec.
            # The loop argument was categorized as an artifact due to its
            # missing or non-primitive type annotation. But it should always
            # be String typed, as its value is a serialized JSON string.
            dsl_component_spec.pop_input_from_component_spec(
                group_component_spec, input_parameter_name)
            group_component_spec.input_definitions.parameters[
                input_parameter_name].type = (
                    pipeline_spec_pb2.PrimitiveType.STRING)

            subgroup_task_spec.inputs.parameters[
                input_parameter_name].component_input_parameter = (
                    input_parameter_name)
            subgroup_task_spec.parameter_iterator.item_input = input_name
            subgroup_task_spec.parameter_iterator.items.input_parameter = (
                input_parameter_name)

        # Loop arguments come from static raw values known at compile time.
        elif _for_loop.LoopArguments.name_is_withitems_loop_argument(
                loop_argument_name):

            # Prepare the raw values, either the whole list or the sliced
            # list based on subvar_name.
            subvar_name = None
            if _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(
                    loop_argument_name):
                subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(
                    loop_argument_name)

            loop_args = group.loop_args.to_list_for_task_yaml()
            if subvar_name:
                raw_values = [
                    loop_arg.get(subvar_name) for loop_arg in loop_args
                ]
            else:
                raw_values = loop_args

            # If the loop iterator component expects `str` or `int` typed
            # items from the loop argument, make sure the item values are
            # string values. This is because both integers and strings are
            # assigned to protobuf
            # [Value.string_value](https://github.com/protocolbuffers/protobuf/blob/133e5e75263be696c06599ab97614a1e1e6d9c66/src/google/protobuf/struct.proto#L70)
            # Such a conversion is not needed for `float` type, which uses
            # protobuf
            # [Value.number_value](https://github.com/protocolbuffers/protobuf/blob/133e5e75263be696c06599ab97614a1e1e6d9c66/src/google/protobuf/struct.proto#L68)
            item_type = subgroup_component_spec.input_definitions.parameters[
                input_name].type
            if item_type in (pipeline_spec_pb2.PrimitiveType.STRING,
                             pipeline_spec_pb2.PrimitiveType.INT):
                raw_values = [str(v) for v in raw_values]

                if item_type == pipeline_spec_pb2.PrimitiveType.INT:
                    # The trailing space matters: adjacent string literals
                    # are concatenated without any separator.
                    warnings.warn(
                        'The loop iterator component is expecting an `int` '
                        'value. Consider changing the input type to either '
                        '`str` or `float`.')

            subgroup_task_spec.parameter_iterator.item_input = input_name
            subgroup_task_spec.parameter_iterator.items.raw = json.dumps(
                raw_values)
        else:
            raise AssertionError(
                'Unexpected loop argument: {}'.format(loop_argument_name))

        # Clean up unused inputs from task spec and parent component spec.
        dsl_component_spec.pop_input_from_task_spec(subgroup_task_spec,
                                                    input_name)
        dsl_component_spec.pop_input_from_component_spec(
            group_component_spec, loop_argument_name)