Example #1
0
    def test_remove_task_name_prefix(self):
        """remove_task_name_prefix strips 'task-' and raises on unprefixed names."""
        stripped_name = dsl_utils.remove_task_name_prefix('task-my-component')
        self.assertEqual('my-component', stripped_name)

        # A name that lacks the 'task-' prefix must trip the helper's assertion.
        with self.assertRaises(AssertionError):
            dsl_utils.remove_task_name_prefix('my-component')
Example #2
0
def update_task_inputs_spec(
    task_spec: pipeline_spec_pb2.PipelineTaskSpec,
    parent_component_inputs: pipeline_spec_pb2.ComponentInputsSpec,
    pipeline_params: List[_pipeline_param.PipelineParam],
    tasks_in_current_dag: List[str],
    input_parameters_in_current_dag: List[str],
    input_artifacts_in_current_dag: List[str],
) -> None:
    """Updates task inputs spec.

    A task input may reference an output outside its immediate DAG.
    For instance::

      random_num = random_num_op(...)
      with dsl.Condition(random_num.output > 5):
        print_op('%s > 5' % random_num.output)

    In this example, `dsl.Condition` forms a sub-DAG with one task from
    `print_op` inside the sub-DAG. The task of `print_op` references output
    from the `random_num` task, which is outside the sub-DAG. When compiling
    to IR, such a cross-DAG reference is disallowed. So we need to "punch a
    hole" in the sub-DAG to make the input available in the sub-DAG component
    inputs if it's not already there. Next, we can call this method to fix the
    tasks inside the sub-DAG to make them reference the component inputs
    instead of directly referencing the original producer task.

    Input parameters and input artifacts follow the same rules:

    * A task-output reference whose producer task is not in the current DAG
      is rewritten to the compiler-injected component input for that output.
    * A component-input reference that is not an input of the current DAG
      component is rewritten to its compiler-injected component input name.

    Args:
      task_spec: The task spec to fill in its inputs spec. Updated in place.
      parent_component_inputs: The input spec of the task's parent component.
      pipeline_params: The list of pipeline params. Not read by this function.
      tasks_in_current_dag: The list of tasks names for tasks in the same dag.
      input_parameters_in_current_dag: The list of input parameters in the DAG
        component.
      input_artifacts_in_current_dag: The list of input artifacts in the DAG
        component.

    Raises:
      AssertionError: If a rewritten input name is not declared in
        `parent_component_inputs` (only when assertions are enabled).
    """
    if not hasattr(task_spec, 'inputs'):
        return

    _redirect_inputs_to_component_inputs(
        input_specs=getattr(task_spec.inputs, 'parameters', {}),
        parent_inputs=parent_component_inputs.parameters,
        tasks_in_current_dag=tasks_in_current_dag,
        inputs_in_current_dag=input_parameters_in_current_dag,
        task_output_field='task_output_parameter',
        output_key_field='output_parameter_key',
        component_input_field='component_input_parameter',
    )
    _redirect_inputs_to_component_inputs(
        input_specs=getattr(task_spec.inputs, 'artifacts', {}),
        parent_inputs=parent_component_inputs.artifacts,
        tasks_in_current_dag=tasks_in_current_dag,
        inputs_in_current_dag=input_artifacts_in_current_dag,
        task_output_field='task_output_artifact',
        output_key_field='output_artifact_key',
        component_input_field='component_input_artifact',
    )


def _redirect_inputs_to_component_inputs(
    input_specs,
    parent_inputs,
    tasks_in_current_dag: List[str],
    inputs_in_current_dag: List[str],
    task_output_field: str,
    output_key_field: str,
    component_input_field: str,
) -> None:
    """Rewrites out-of-DAG references in one task inputs map, in place.

    Args:
      input_specs: Map of input name -> input spec message (the `parameters`
        or `artifacts` map of a task's `inputs`). Values are updated in place.
      parent_inputs: The matching map (`parameters` or `artifacts`) of the
        parent component's inputs spec; every rewritten name must be a key.
      tasks_in_current_dag: Names of the tasks in the same DAG.
      inputs_in_current_dag: Names of the matching inputs (parameters or
        artifacts) of the DAG component.
      task_output_field: Oneof `kind` case naming a task-output reference,
        e.g. 'task_output_parameter'.
      output_key_field: Field of that task-output message that holds the
        output key, e.g. 'output_parameter_key'.
      component_input_field: Oneof `kind` case naming a component-input
        reference, e.g. 'component_input_parameter'.
    """
    for input_name in input_specs:
        # Look the entry up once. The original code assigned through the
        # same map entry, so writes via this binding land in the map.
        input_spec = input_specs[input_name]
        kind = input_spec.WhichOneof('kind')

        if kind == task_output_field:
            task_output = getattr(input_spec, task_output_field)
            if task_output.producer_task in tasks_in_current_dag:
                continue  # Producer is in this DAG: reference is already legal.
            param = _pipeline_param.PipelineParam(
                name=getattr(task_output, output_key_field),
                op_name=dsl_utils.remove_task_name_prefix(
                    task_output.producer_task))
            new_input_name = additional_input_name_for_pipelineparam(param)
        elif kind == component_input_field:
            current_input_name = getattr(input_spec, component_input_field)
            if current_input_name in inputs_in_current_dag:
                continue  # Already an input of this DAG component.
            new_input_name = additional_input_name_for_pipelineparam(
                current_input_name)
        else:
            continue

        assert new_input_name in parent_inputs, (
            'Input {!r} was redirected to {!r}, which is not declared in the '
            'parent component inputs.'.format(input_name, new_input_name))
        setattr(input_spec, component_input_field, new_input_name)
Example #3
0
  def _update_loop_specs(
      self,
      group: dsl.OpsGroup,
      subgroup: _GroupOrOp,
      group_component_spec: pipeline_spec_pb2.ComponentSpec,
      subgroup_component_spec: pipeline_spec_pb2.ComponentSpec,
      subgroup_task_spec: pipeline_spec_pb2.PipelineTaskSpec,
  ) -> None:
    """Update IR specs for loop.

    Every parameter input of the subgroup task is treated as a loop argument
    and turned into the task's `parameter_iterator`:

    * withParams (dynamic) loop arguments iterate over the loop's
      pipeline-param items. That pipeline param is re-declared as a STRING
      input of the group component and wired into the subgroup task.
    * withItems (static) loop arguments iterate over raw values known at
      compile time, serialized with `json.dumps`.

    Afterwards the original input is dropped from the subgroup task spec and
    the loop argument input is dropped from the group component spec.

    Args:
      group: The dsl.ParallelFor OpsGroup.
      subgroup: One of the subgroups of dsl.ParallelFor. Not read by this
        method.
      group_component_spec: The component spec of the group to update in place.
      subgroup_component_spec: The component spec of the subgroup to update.
      subgroup_task_spec: The task spec of the subgroup to update.

    Raises:
      NotImplementedError: If a subvar is used with dynamic loop arguments.
      AssertionError: If an input is neither a withParams nor a withItems loop
        argument.
    """
    # Snapshot the keys up front: the loop body inserts entries into this map
    # and calls pop_input_from_task_spec on the task spec while iterating.
    input_names = list(subgroup_task_spec.inputs.parameters)
    for input_name in input_names:

      # Read-only lookup; the map is not mutated until after this is used.
      param_spec = subgroup_task_spec.inputs.parameters[input_name]
      if param_spec.HasField('component_input_parameter'):
        loop_argument_name = param_spec.component_input_parameter
      else:
        producer_task_name = dsl_utils.remove_task_name_prefix(
            param_spec.task_output_parameter.producer_task)
        producer_task_output_key = (
            param_spec.task_output_parameter.output_parameter_key)
        loop_argument_name = '{}-{}'.format(producer_task_name,
                                            producer_task_output_key)

      # Loop arguments are from dynamic input: pipeline param or task output
      if _for_loop.LoopArguments.name_is_withparams_loop_argument(
          loop_argument_name):

        arg_and_var_name = (
            _for_loop.LoopArgumentVariable
            .parse_loop_args_name_and_this_var_name(loop_argument_name))
        # The current IR representation is insufficient for referencing a subvar
        # which is a key in a list of dictionaries.
        if arg_and_var_name:
          raise NotImplementedError(
              'Use subvar in dsl.ParallelFor with dynamic loop arguments is not '
              'supported. Got subvar: {}'.format(arg_and_var_name[1]))

        assert group.items_is_pipeline_param
        pipeline_param = group.loop_args.items_or_pipeline_param
        input_parameter_name = pipeline_param.full_name

        # Correct loop argument input type in the parent component spec.
        # The loop argument was categorized as an artifact due to its missing
        # or non-primitive type annotation. But it should always be String
        # typed, as its value is a serialized JSON string.
        dsl_component_spec.pop_input_from_component_spec(
            group_component_spec, input_parameter_name)
        group_component_spec.input_definitions.parameters[
            input_parameter_name].type = pipeline_spec_pb2.PrimitiveType.STRING

        subgroup_task_spec.inputs.parameters[
            input_parameter_name].component_input_parameter = (
                input_parameter_name)
        subgroup_task_spec.parameter_iterator.item_input = input_name
        subgroup_task_spec.parameter_iterator.items.input_parameter = (
            input_parameter_name)

      # Loop arguments come from static raw values known at compile time.
      elif _for_loop.LoopArguments.name_is_withitems_loop_argument(
          loop_argument_name):

        # Prepare the raw values, either the whole list or the sliced list based
        # on subvar_name.
        subvar_name = None
        if _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(
            loop_argument_name):
          subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(
              loop_argument_name)

        loop_args = group.loop_args.to_list_for_task_yaml()
        if subvar_name:
          raw_values = [loop_arg.get(subvar_name) for loop_arg in loop_args]
        else:
          raw_values = loop_args

        # If the loop iterator component expects `str` or `int` typed items from
        # the loop argument, make sure the item values are string values.
        # This is because both integers and strings are assigned to protobuf
        # [Value.string_value](https://github.com/protocolbuffers/protobuf/blob/133e5e75263be696c06599ab97614a1e1e6d9c66/src/google/protobuf/struct.proto#L70)
        # Such a conversion is not needed for `float` type, which uses protobuf
        # [Value.number_value](https://github.com/protocolbuffers/protobuf/blob/133e5e75263be696c06599ab97614a1e1e6d9c66/src/google/protobuf/struct.proto#L68)
        item_type = subgroup_component_spec.input_definitions.parameters[
            input_name].type
        if item_type in [
            pipeline_spec_pb2.PrimitiveType.STRING,
            pipeline_spec_pb2.PrimitiveType.INT
        ]:
          raw_values = [str(v) for v in raw_values]
          if item_type == pipeline_spec_pb2.PrimitiveType.INT:
            warnings.warn(
                'The loop iterator component is expecting an `int` value. '
                'Consider changing the input type to either `str` or `float`.')

        subgroup_task_spec.parameter_iterator.item_input = input_name
        subgroup_task_spec.parameter_iterator.items.raw = json.dumps(raw_values)

      else:
        raise AssertionError(
            'Unexpected loop argument: {}'.format(loop_argument_name))

      # Clean up unused inputs from task spec and parent component spec.
      dsl_component_spec.pop_input_from_task_spec(subgroup_task_spec,
                                                  input_name)
      dsl_component_spec.pop_input_from_component_spec(group_component_spec,
                                                       loop_argument_name)