def build_importer_task_spec(
    importer_base_name: str,
) -> pipeline_spec_pb2.PipelineTaskSpec:
    """Builds an importer task spec.

    Args:
      importer_base_name: The base name of the importer node.

    Returns:
      An importer node task spec.
    """
    result = pipeline_spec_pb2.PipelineTaskSpec()
    result.task_info.name = dsl_utils.sanitize_task_name(importer_base_name)
    result.component_ref.name = dsl_utils.sanitize_component_name(
        importer_base_name)

    return result
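A rough usage sketch follows. The importer name is made up, and the exact sanitized values depend on dsl_utils; the sanitize_component_name test later on this page shows the 'comp-' prefix convention.

# Hedged usage sketch for build_importer_task_spec; the importer name is
# illustrative and the sanitized outputs are assumptions, not verified values.
task_spec = build_importer_task_spec('importer-My Artifact')
print(task_spec.task_info.name)      # a sanitized task name (lowercased, dash-separated)
print(task_spec.component_ref.name)  # a sanitized component name, e.g. with a 'comp-' prefix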
Example #2
def _get_custom_job_op(
    task_name: str,
    job_spec: Dict[str, Any],
    input_artifacts: Optional[Dict[str, dsl.PipelineParam]] = None,
    input_parameters: Optional[Dict[str, _ValueOrPipelineParam]] = None,
    output_artifacts: Optional[Dict[str, Type[artifact.Artifact]]] = None,
    output_parameters: Optional[Dict[str, Any]] = None,
) -> AiPlatformCustomJobOp:
    """Gets an AiPlatformCustomJobOp from job spec and I/O definition."""
    pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()
    pipeline_component_spec = pipeline_spec_pb2.ComponentSpec()

    pipeline_task_spec.task_info.CopyFrom(
        pipeline_spec_pb2.PipelineTaskInfo(name=task_name))

    # The I/O declarations are optional; default them to empty dicts so the
    # loops below can iterate safely.
    input_parameters = input_parameters or {}
    input_artifacts = input_artifacts or {}
    output_parameters = output_parameters or {}
    output_artifacts = output_artifacts or {}

    # Iterate through the inputs/outputs declarations to get the pipeline
    # component spec.
    for input_name, param in input_parameters.items():
        if isinstance(param, dsl.PipelineParam):
            pipeline_component_spec.input_definitions.parameters[
                input_name].type = type_utils.get_parameter_type(
                    param.param_type)
        else:
            pipeline_component_spec.input_definitions.parameters[
                input_name].type = type_utils.get_parameter_type(type(param))

    for input_name, art in input_artifacts.items():
        if not isinstance(art, dsl.PipelineParam):
            raise RuntimeError(
                'Got unresolved input artifact for input %s. Input '
                'artifacts must be connected to a producer task.' % input_name)
        pipeline_component_spec.input_definitions.artifacts[
            input_name].artifact_type.CopyFrom(
                type_utils.get_artifact_type_schema_message(art.param_type))

    for output_name, param_type in output_parameters.items():
        pipeline_component_spec.output_definitions.parameters[
            output_name].type = type_utils.get_parameter_type(param_type)

    for output_name, artifact_type in output_artifacts.items():
        pipeline_component_spec.output_definitions.artifacts[
            output_name].artifact_type.CopyFrom(artifact_type.get_ir_type())

    pipeline_component_spec.executor_label = dsl_utils.sanitize_executor_label(
        task_name)

    # Iterate through the inputs/outputs specs to get pipeline task spec.
    for input_name, param in input_parameters.items():
        if isinstance(param, dsl.PipelineParam) and param.op_name:
            # If the param has a valid op_name, this should be a pipeline parameter
            # produced by an upstream task.
            pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
                    task_output_parameter=pipeline_spec_pb2.TaskInputsSpec.
                    InputParameterSpec.TaskOutputParameterSpec(
                        producer_task='task-{}'.format(param.op_name),
                        output_parameter_key=param.name)))
        elif isinstance(param, dsl.PipelineParam) and not param.op_name:
            # If a valid op_name is missing, this should be a pipeline parameter.
            pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
                    component_input_parameter=param.name))
        else:
            # If this is not a pipeline param, then it should be a value.
            pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
                    runtime_value=pipeline_spec_pb2.ValueOrRuntimeParameter(
                        constant_value=dsl_utils.get_value(param))))

    for input_name, art in input_artifacts.items():
        if art.op_name:
            # If the artifact has a valid op_name, it should have been produced
            # by an upstream task.
            pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec(
                    task_output_artifact=pipeline_spec_pb2.TaskInputsSpec.
                    InputArtifactSpec.TaskOutputArtifactSpec(
                        producer_task='task-{}'.format(art.op_name),
                        output_artifact_key=art.name)))
        else:
            # Otherwise, this should be from the input of the subdag.
            pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec(
                    component_input_artifact=art.name))

    # TODO: Add task dependencies/trigger policies/caching/iterator
    pipeline_task_spec.component_ref.name = dsl_utils.sanitize_component_name(
        task_name)

    # Construct dummy I/O declaration for the op.
    # TODO: resolve name conflict instead of raising errors.
    dummy_outputs = collections.OrderedDict()
    for output_name, _ in output_artifacts.items():
        dummy_outputs[output_name] = _DUMMY_PATH

    for output_name, _ in output_parameters.items():
        if output_name in dummy_outputs:
            raise KeyError(
                'Got name collision for output key %s. Consider renaming '
                'either output parameters or output '
                'artifacts.' % output_name)
        dummy_outputs[output_name] = _DUMMY_PATH

    dummy_inputs = collections.OrderedDict()
    for input_name, art in input_artifacts.items():
        dummy_inputs[input_name] = _DUMMY_PATH
    for input_name, param in input_parameters.items():
        if input_name in dummy_inputs:
            raise KeyError(
                'Got name collision for input key %s. Consider renaming '
                'either input parameters or input '
                'artifacts.' % input_name)
        dummy_inputs[input_name] = _DUMMY_PATH

    # Construct the AIP (Unified) custom job op.
    return AiPlatformCustomJobOp(
        name=task_name,
        custom_job_spec=job_spec,
        component_spec=pipeline_component_spec,
        task_spec=pipeline_task_spec,
        task_inputs=[
            dsl.InputArgumentPath(
                argument=dummy_inputs[input_name],
                input=input_name,
                path=path,
            ) for input_name, path in dummy_inputs.items()
        ],
        task_outputs=dummy_outputs)
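A hedged, illustrative call follows. Every name and value is made up; constant input parameters take the runtime_value branch above, and this helper would normally be invoked while a pipeline is being defined rather than standalone.

# Illustrative sketch only; 'displayName' and all parameter names are
# assumptions, not a verified custom job spec.
op = _get_custom_job_op(
    task_name='train',
    job_spec={'displayName': 'my-custom-job'},
    input_parameters={'learning_rate': 0.01, 'epochs': 10},
    output_parameters={'accuracy': float},
)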
Example #3
  def _group_to_dag_spec(
      self,
      group: dsl.OpsGroup,
      inputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
      outputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
      dependencies: Dict[str, List[_GroupOrOp]],
      pipeline_spec: pipeline_spec_pb2.PipelineSpec,
      deployment_config: pipeline_spec_pb2.PipelineDeploymentConfig,
      rootgroup_name: str,
  ) -> None:
    """Generate IR spec given an OpsGroup.

    Args:
      group: The OpsGroup to generate spec for.
      inputs: The inputs dictionary. The keys are group/op names and values are
        lists of tuples (param, producing_op_name).
      outputs: The outputs dictionary. The keys are group/op names and values
        are lists of tuples (param, producing_op_name).
      dependencies: The group dependencies dictionary. The keys are group/op
        names, and the values are lists of dependent groups/ops.
      pipeline_spec: The pipeline_spec to update in-place.
      deployment_config: The deployment_config to hold all executors.
      rootgroup_name: The name of the group root. Used to determine whether the
        component spec for the current group should be the root dag.
    """
    group_component_name = dsl_utils.sanitize_component_name(group.name)

    if group.name == rootgroup_name:
      group_component_spec = pipeline_spec.root
    else:
      group_component_spec = pipeline_spec.components[group_component_name]

    # Generate component inputs spec.
    if inputs.get(group.name, None):
      dsl_component_spec.build_component_inputs_spec(
          group_component_spec, [param for param, _ in inputs[group.name]])

    # Generate component outputs spec.
    if outputs.get(group.name, None):
      group_component_spec.output_definitions.CopyFrom(
          dsl_component_spec.build_component_outputs_spec(
              [param for param, _ in outputs[group.name]]))

    # Generate task specs and component specs for the dag.
    subgroups = group.groups + group.ops
    for subgroup in subgroups:
      subgroup_task_spec = getattr(subgroup, 'task_spec',
                                   pipeline_spec_pb2.PipelineTaskSpec())
      subgroup_component_spec = getattr(subgroup, 'component_spec',
                                        pipeline_spec_pb2.ComponentSpec())
      is_loop_subgroup = (isinstance(group, dsl.ParallelFor))
      is_recursive_subgroup = (
          isinstance(subgroup, dsl.OpsGroup) and subgroup.recursive_ref)

      # Special handling for recursive subgroup: use the existing opsgroup name
      if is_recursive_subgroup:
        subgroup_key = subgroup.recursive_ref.name
      else:
        subgroup_key = subgroup.name

      subgroup_task_spec.task_info.name = dsl_utils.sanitize_task_name(
          subgroup_key)
      # human_name exists for ops only, and is used to de-dupe component spec.
      subgroup_component_name = dsl_utils.sanitize_component_name(
          getattr(subgroup, 'human_name', subgroup_key))
      subgroup_task_spec.component_ref.name = subgroup_component_name

      if isinstance(subgroup, dsl.OpsGroup) and subgroup.type == 'condition':
        condition = subgroup.condition
        operand_values = []
        subgroup_inputs = inputs.get(subgroup.name, [])
        subgroup_params = [param for param, _ in subgroup_inputs]
        tasks_in_current_dag = [sub.name for sub in subgroups]

        dsl_component_spec.build_component_inputs_spec(
            subgroup_component_spec,
            subgroup_params,
        )
        dsl_component_spec.build_task_inputs_spec(
            subgroup_task_spec,
            subgroup_params,
            tasks_in_current_dag,
        )

        for operand in [condition.operand1, condition.operand2]:
          operand_values.append(self._resolve_value_or_reference(operand))

        condition_string = '{} {} {}'.format(operand_values[0],
                                             condition.operator,
                                             operand_values[1])

        subgroup_task_spec.trigger_policy.CopyFrom(
            pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy(
                condition=condition_string))

      # Generate dependencies section for this task.
      if dependencies.get(subgroup.name, None):
        group_dependencies = list(dependencies[subgroup.name])
        group_dependencies.sort()
        subgroup_task_spec.dependent_tasks.extend(
            [dsl_utils.sanitize_task_name(dep) for dep in group_dependencies])

      if isinstance(subgroup, dsl.ParallelFor):
        # Remove loop arguments related inputs from parent group component spec.
        input_names = [param.full_name for param, _ in inputs[subgroup.name]]
        for input_name in input_names:
          if _for_loop.LoopArguments.name_is_loop_argument(input_name):
            dsl_component_spec.pop_input_from_component_spec(
                group_component_spec, input_name)

        if subgroup.items_is_pipeline_param:
          # These loop args are a 'withParam' rather than 'withItems'.
          # i.e., rather than a static list, they are either the output of
          # another task or were input as global pipeline parameters.

          pipeline_param = subgroup.loop_args.items_or_pipeline_param
          input_parameter_name = pipeline_param.full_name

          if pipeline_param.op_name:
            subgroup_task_spec.inputs.parameters[
                input_parameter_name].task_output_parameter.producer_task = (
                    dsl_utils.sanitize_task_name(pipeline_param.op_name))
            subgroup_task_spec.inputs.parameters[
                input_parameter_name].task_output_parameter.output_parameter_key = (
                    pipeline_param.name)
          else:
            subgroup_task_spec.inputs.parameters[
                input_parameter_name].component_input_parameter = (
                    input_parameter_name)

          # Correct loop argument input type in the parent component spec.
          # The loop argument was categorized as an artifact due to its missing
          # or non-primitive type annotation. But it should always be String
          # typed, as its value is a serialized JSON string.
          dsl_component_spec.pop_input_from_component_spec(
              group_component_spec, input_parameter_name)
          group_component_spec.input_definitions.parameters[
              input_parameter_name].type = pipeline_spec_pb2.PrimitiveType.STRING

      # Additional spec modifications for dsl.ParallelFor's subgroups.
      if is_loop_subgroup:
        self._update_loop_specs(group, subgroup, group_component_spec,
                                subgroup_component_spec, subgroup_task_spec)

      # Add importer node when applicable
      for input_name in subgroup_task_spec.inputs.artifacts:
        if not subgroup_task_spec.inputs.artifacts[
            input_name].task_output_artifact.producer_task:
          type_schema = type_utils.get_input_artifact_type_schema(
              input_name, subgroup._metadata.inputs)

          importer_name = importer_node.generate_importer_base_name(
              dependent_task_name=subgroup_task_spec.task_info.name,
              input_name=input_name)
          importer_task_spec = importer_node.build_importer_task_spec(
              importer_name)
          importer_comp_spec = importer_node.build_importer_component_spec(
              importer_base_name=importer_name,
              input_name=input_name,
              input_type_schema=type_schema)
          importer_task_name = importer_task_spec.task_info.name
          importer_comp_name = importer_task_spec.component_ref.name
          importer_exec_label = importer_comp_spec.executor_label
          group_component_spec.dag.tasks[importer_task_name].CopyFrom(
              importer_task_spec)
          pipeline_spec.components[importer_comp_name].CopyFrom(
              importer_comp_spec)

          subgroup_task_spec.inputs.artifacts[
              input_name].task_output_artifact.producer_task = (
                  importer_task_name)
          subgroup_task_spec.inputs.artifacts[
              input_name].task_output_artifact.output_artifact_key = (
                  importer_node.OUTPUT_KEY)

          # Retrieve the pre-built importer spec
          importer_spec = subgroup.importer_specs[input_name]
          deployment_config.executors[importer_exec_label].importer.CopyFrom(
              importer_spec)

      # Add component spec if not exists
      if subgroup_component_name not in pipeline_spec.components:
        pipeline_spec.components[subgroup_component_name].CopyFrom(
            subgroup_component_spec)

      # Add task spec
      group_component_spec.dag.tasks[
          subgroup_task_spec.task_info.name].CopyFrom(subgroup_task_spec)

      # Add executor spec, if applicable.
      container_spec = getattr(subgroup, 'container_spec', None)
      if container_spec:
        if compiler_utils.is_v2_component(subgroup):
          compiler_utils.refactor_v2_container_spec(container_spec)
        executor_label = subgroup_component_spec.executor_label

        if executor_label not in deployment_config.executors:
          deployment_config.executors[executor_label].container.CopyFrom(
              container_spec)

      # Add AIPlatformCustomJobSpec, if applicable.
      custom_job_spec = getattr(subgroup, 'custom_job_spec', None)
      if custom_job_spec:
        executor_label = subgroup_component_spec.executor_label
        if executor_label not in deployment_config.executors:
          deployment_config.executors[
            executor_label].custom_job.custom_job.update(custom_job_spec)

    pipeline_spec.deployment_spec.update(
        json_format.MessageToDict(deployment_config))
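The spec-attachment pattern used throughout this method is plain protobuf map manipulation. The minimal standalone sketch below shows the same "add a task spec to the parent dag, then serialize the deployment config" steps with made-up names.

# Minimal protobuf sketch; task and component names are illustrative.
pipeline_spec = pipeline_spec_pb2.PipelineSpec()
task_spec = pipeline_spec_pb2.PipelineTaskSpec()
task_spec.task_info.name = 'task-my-step'
task_spec.component_ref.name = 'comp-my-step'
pipeline_spec.root.dag.tasks[task_spec.task_info.name].CopyFrom(task_spec)

deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
pipeline_spec.deployment_spec.update(
    json_format.MessageToDict(deployment_config))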
Example #4
def create_container_op_from_component_and_arguments(
    component_spec: structures.ComponentSpec,
    arguments: Mapping[str, Any],
    component_ref: Optional[structures.ComponentReference] = None,
) -> container_op.ContainerOp:
    """Instantiates ContainerOp object.

  Args:
    component_spec: The component spec object.
    arguments: The dictionary of component arguments.
    component_ref: (not used in v2)

  Returns:
    A ContainerOp instance.
  """

    pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()

    # Keep track of auto-injected importer spec.
    importer_specs = {}

    # Check types of the reference arguments and serialize PipelineParams
    arguments = arguments.copy()
    for input_name, argument_value in arguments.items():
        if isinstance(argument_value, dsl.PipelineParam):
            input_type = component_spec._inputs_dict[input_name].type
            reference_type = argument_value.param_type
            types.verify_type_compatibility(
                reference_type, input_type,
                'Incompatible argument passed to the input "{}" of component "{}": '
                .format(input_name, component_spec.name))

            arguments[input_name] = str(argument_value)

            if type_utils.is_parameter_type(input_type):
                if argument_value.op_name:
                    pipeline_task_spec.inputs.parameters[
                        input_name].task_output_parameter.producer_task = (
                            dsl_utils.sanitize_task_name(
                                argument_value.op_name))
                    pipeline_task_spec.inputs.parameters[
                        input_name].task_output_parameter.output_parameter_key = (
                            argument_value.name)
                else:
                    pipeline_task_spec.inputs.parameters[
                        input_name].component_input_parameter = argument_value.name
            else:
                if argument_value.op_name:
                    pipeline_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.producer_task = (
                            dsl_utils.sanitize_task_name(
                                argument_value.op_name))
                    pipeline_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.output_artifact_key = (
                            argument_value.name)
                else:
                    # argument_value.op_name could be None, in which case an
                    # importer node will be inserted later.
                    pipeline_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.producer_task = ''
                    type_schema = type_utils.get_input_artifact_type_schema(
                        input_name, component_spec.inputs)
                    importer_specs[
                        input_name] = importer_node.build_importer_spec(
                            input_type_schema=type_schema,
                            pipeline_param_name=argument_value.name)
        elif isinstance(argument_value, str):
            input_type = component_spec._inputs_dict[input_name].type
            if type_utils.is_parameter_type(input_type):
                pipeline_task_spec.inputs.parameters[
                    input_name].runtime_value.constant_value.string_value = (
                        argument_value)
            else:
                # An importer node with constant value artifact_uri will be inserted.
                pipeline_task_spec.inputs.artifacts[
                    input_name].task_output_artifact.producer_task = ''
                type_schema = type_utils.get_input_artifact_type_schema(
                    input_name, component_spec.inputs)
                importer_specs[input_name] = importer_node.build_importer_spec(
                    input_type_schema=type_schema,
                    constant_value=argument_value)
        elif isinstance(argument_value, int):
            pipeline_task_spec.inputs.parameters[
                input_name].runtime_value.constant_value.int_value = argument_value
        elif isinstance(argument_value, float):
            pipeline_task_spec.inputs.parameters[
                input_name].runtime_value.constant_value.double_value = argument_value
        elif isinstance(argument_value, dsl.ContainerOp):
            raise TypeError(
                'ContainerOp object {} was passed to component as an input argument. '
                'Pass a single output instead.'.format(input_name))
        else:
            raise NotImplementedError(
                'Input argument supports only the following types: PipelineParam'
                ', str, int, float. Got: "{}".'.format(argument_value))

    inputs_dict = {
        input_spec.name: input_spec
        for input_spec in component_spec.inputs or []
    }
    outputs_dict = {
        output_spec.name: output_spec
        for output_spec in component_spec.outputs or []
    }

    def _input_artifact_uri_placeholder(input_key: str) -> str:
        if type_utils.is_parameter_type(inputs_dict[input_key].type):
            raise TypeError(
                'Input "{}" with type "{}" cannot be paired with InputUriPlaceholder.'
                .format(input_key, inputs_dict[input_key].type))
        else:
            return "{{{{$.inputs.artifacts['{}'].uri}}}}".format(input_key)

    def _input_artifact_path_placeholder(input_key: str) -> str:
        if type_utils.is_parameter_type(inputs_dict[input_key].type):
            raise TypeError(
                'Input "{}" with type "{}" cannot be paired with InputPathPlaceholder.'
                .format(input_key, inputs_dict[input_key].type))
        elif input_key in importer_specs:
            raise TypeError(
                'Input "{}" with type "{}" is not connected to any upstream output. '
                'However it is used with InputPathPlaceholder. '
                'If you want to import an existing artifact using a system-connected '
                'importer node, use InputUriPlaceholder instead. '
                'Or if you just want to pass a string parameter, use string type and '
                'InputValuePlaceholder instead.'.format(
                    input_key, inputs_dict[input_key].type))
        else:
            return "{{{{$.inputs.artifacts['{}'].path}}}}".format(input_key)

    def _input_parameter_placeholder(input_key: str) -> str:
        if type_utils.is_parameter_type(inputs_dict[input_key].type):
            return "{{{{$.inputs.parameters['{}']}}}}".format(input_key)
        else:
            raise TypeError(
                'Input "{}" with type "{}" cannot be paired with InputValuePlaceholder.'
                .format(input_key, inputs_dict[input_key].type))

    def _output_artifact_uri_placeholder(output_key: str) -> str:
        if type_utils.is_parameter_type(outputs_dict[output_key].type):
            raise TypeError(
                'Output "{}" with type "{}" cannot be paired with OutputUriPlaceholder.'
                .format(output_key, outputs_dict[output_key].type))
        else:
            return "{{{{$.outputs.artifacts['{}'].uri}}}}".format(output_key)

    def _output_artifact_path_placeholder(output_key: str) -> str:
        return "{{{{$.outputs.artifacts['{}'].path}}}}".format(output_key)

    def _output_parameter_path_placeholder(output_key: str) -> str:
        return "{{{{$.outputs.parameters['{}'].output_file}}}}".format(
            output_key)

    def _resolve_output_path_placeholder(output_key: str) -> str:
        if type_utils.is_parameter_type(outputs_dict[output_key].type):
            return _output_parameter_path_placeholder(output_key)
        else:
            return _output_artifact_path_placeholder(output_key)

    resolved_cmd = _resolve_command_line_and_paths(
        component_spec=component_spec,
        arguments=arguments,
        input_value_generator=_input_parameter_placeholder,
        input_uri_generator=_input_artifact_uri_placeholder,
        output_uri_generator=_output_artifact_uri_placeholder,
        input_path_generator=_input_artifact_path_placeholder,
        output_path_generator=_resolve_output_path_placeholder,
    )

    container_spec = component_spec.implementation.container

    output_uris_and_paths = resolved_cmd.output_uris.copy()
    output_uris_and_paths.update(resolved_cmd.output_paths)
    input_uris_and_paths = resolved_cmd.input_uris.copy()
    input_uris_and_paths.update(resolved_cmd.input_paths)

    old_warn_value = dsl.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING
    dsl.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING = True
    task = container_op.ContainerOp(
        name=component_spec.name or _default_component_name,
        image=container_spec.image,
        command=resolved_cmd.command,
        arguments=resolved_cmd.args,
        file_outputs=output_uris_and_paths,
        artifact_argument_paths=[
            dsl.InputArgumentPath(
                argument=arguments[input_name],
                input=input_name,
                path=path,
            ) for input_name, path in input_uris_and_paths.items()
        ],
    )

    # task.name is unique at this point.
    pipeline_task_spec.task_info.name = (dsl_utils.sanitize_task_name(
        task.name))
    pipeline_task_spec.component_ref.name = (dsl_utils.sanitize_component_name(
        component_spec.name))

    task.task_spec = pipeline_task_spec
    task.importer_specs = importer_specs
    task.component_spec = dsl_component_spec.build_component_spec_from_structure(
        component_spec)
    task.container_spec = (
        pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec(
            image=container_spec.image,
            command=resolved_cmd.command,
            args=resolved_cmd.args))

    dsl.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING = old_warn_value

    component_meta = copy.copy(component_spec)
    task._set_metadata(component_meta)

    # Previously, ContainerOp had strict requirements for the output names, so we
    # had to convert all the names before passing them to the ContainerOp
    # constructor. Outputs with non-pythonic names could not be accessed using
    # their original names. Now ContainerOp supports any output names, so we're
    # now using the original output names. However to support legacy pipelines,
    # we're also adding output references with pythonic names.
    # TODO: Add warning when people use the legacy output names.
    output_names = [
        output_spec.name for output_spec in component_spec.outputs or []
    ]  # Stabilizing the ordering
    output_name_to_python = generate_unique_name_conversion_table(
        output_names, _sanitize_python_function_name)
    for output_name in output_names:
        pythonic_output_name = output_name_to_python[output_name]
        # Note: Some component outputs are currently missing from task.outputs
        # (e.g. MLPipeline UI Metadata)
        if pythonic_output_name not in task.outputs and output_name in task.outputs:
            task.outputs[pythonic_output_name] = task.outputs[output_name]

    if component_spec.metadata:
        annotations = component_spec.metadata.annotations or {}
        for key, value in annotations.items():
            task.add_pod_annotation(key, value)
        for key, value in (component_spec.metadata.labels or {}).items():
            task.add_pod_label(key, value)
        # Disable caching for volatile components by default.
        if annotations.get('volatile_component', 'false') == 'true':
            task.execution_options.caching_strategy.max_cache_staleness = 'P0D'

    return task
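As a hedged illustration, the sketch below builds a tiny ComponentSpec by hand and passes it through this function inside a pipeline definition. The component, image, and argument are all made up; in real use this function is wired in as the task factory by the v2-compatible compiler rather than called directly.

# Illustrative only: a hand-built echo component fed through the helper above.
from kfp import dsl
from kfp.components import _structures as structures

echo_component_spec = structures.ComponentSpec(
    name='Echo',
    inputs=[structures.InputSpec(name='msg', type='String')],
    implementation=structures.ContainerImplementation(
        container=structures.ContainerSpec(
            image='alpine',
            command=['echo', structures.InputValuePlaceholder('msg')])))

@dsl.pipeline(name='echo-pipeline')
def echo_pipeline():
    # A string argument with a parameter type takes the runtime_value branch.
    create_container_op_from_component_and_arguments(
        component_spec=echo_component_spec,
        arguments={'msg': 'hello'})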
Example #5
  def test_sanitize_component_name(self):
    self.assertEqual('comp-my-component',
                     dsl_utils.sanitize_component_name('My component'))
Example #6
    def _group_to_dag_spec(
        self,
        group: dsl.OpsGroup,
        inputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
        outputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
        dependencies: Dict[str, List[_GroupOrOp]],
        pipeline_spec: pipeline_spec_pb2.PipelineSpec,
        rootgroup_name: str,
    ) -> None:
        """Generate IR spec given an OpsGroup.

    Args:
      group: The OpsGroup to generate spec for.
      inputs: The inputs dictionary. The keys are group/op names and values are
        lists of tuples (param, producing_op_name).
      outputs: The outputs dictionary. The keys are group/op names and values
        are lists of tuples (param, producing_op_name).
      dependencies: The group dependencies dictionary. The keys are group/op
        names, and the values are lists of dependent groups/ops.
      pipeline_spec: The pipeline_spec to update in-place.
      rootgroup_name: The name of the group root. Used to determine whether the
        component spec for the current group should be the root dag.
    """
        group_component_name = dsl_utils.sanitize_component_name(group.name)

        if group.name == rootgroup_name:
            group_component_spec = pipeline_spec.root
        else:
            group_component_spec = pipeline_spec.components[
                group_component_name]

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()

        # Generate component inputs spec.
        if inputs.get(group.name, None):
            dsl_component_spec.build_component_inputs_spec(
                group_component_spec,
                [param for param, _ in inputs[group.name]])

        # Generate component outputs spec.
        if outputs.get(group.name, None):
            group_component_spec.output_definitions.CopyFrom(
                dsl_component_spec.build_component_outputs_spec(
                    [param for param, _ in outputs[group.name]]))

        # Generate task specs and component specs for the dag.
        subgroups = group.groups + group.ops
        for subgroup in subgroups:
            subgroup_task_spec = getattr(subgroup, 'task_spec',
                                         pipeline_spec_pb2.PipelineTaskSpec())
            subgroup_component_spec = getattr(
                subgroup, 'component_spec', pipeline_spec_pb2.ComponentSpec())
            is_recursive_subgroup = (isinstance(subgroup, dsl.OpsGroup)
                                     and subgroup.recursive_ref)
            # Special handling for recursive subgroup: use the existing opsgroup name
            if is_recursive_subgroup:
                subgroup_key = subgroup.recursive_ref.name
            else:
                subgroup_key = subgroup.name

            subgroup_task_spec.task_info.name = dsl_utils.sanitize_task_name(
                subgroup_key)
            # human_name exists for ops only, and is used to de-dupe component spec.
            subgroup_component_name = dsl_utils.sanitize_component_name(
                getattr(subgroup, 'human_name', subgroup_key))
            subgroup_task_spec.component_ref.name = subgroup_component_name

            if isinstance(subgroup,
                          dsl.OpsGroup) and subgroup.type == 'condition':
                condition = subgroup.condition
                operand_values = []
                subgroup_inputs = inputs.get(subgroup.name, [])
                subgroup_params = [param for param, _ in subgroup_inputs]
                tasks_in_current_dag = [sub.name for sub in subgroups]

                dsl_component_spec.build_component_inputs_spec(
                    subgroup_component_spec,
                    subgroup_params,
                )
                dsl_component_spec.build_task_inputs_spec(
                    subgroup_task_spec,
                    subgroup_params,
                    tasks_in_current_dag,
                )

                for operand in [condition.operand1, condition.operand2]:
                    operand_values.append(
                        self._resolve_value_or_reference(operand))

                condition_string = '{} {} {}'.format(operand_values[0],
                                                     condition.operator,
                                                     operand_values[1])

                subgroup_task_spec.trigger_policy.CopyFrom(
                    pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy(
                        condition=condition_string))

            # Generate dependencies section for this task.
            if dependencies.get(subgroup.name, None):
                group_dependencies = list(dependencies[subgroup.name])
                group_dependencies.sort()
                subgroup_task_spec.dependent_tasks.extend([
                    dsl_utils.sanitize_task_name(dep)
                    for dep in group_dependencies
                ])

            # Add importer node when applicable
            for input_name in subgroup_task_spec.inputs.artifacts:
                if not subgroup_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.producer_task:
                    type_schema = type_utils.get_input_artifact_type_schema(
                        input_name, subgroup._metadata.inputs)

                    importer_name = importer_node.generate_importer_base_name(
                        dependent_task_name=subgroup_task_spec.task_info.name,
                        input_name=input_name)
                    importer_task_spec = importer_node.build_importer_task_spec(
                        importer_name)
                    importer_comp_spec = importer_node.build_importer_component_spec(
                        importer_base_name=importer_name,
                        input_name=input_name,
                        input_type_schema=type_schema)
                    importer_task_name = importer_task_spec.task_info.name
                    importer_comp_name = importer_task_spec.component_ref.name
                    importer_exec_label = importer_comp_spec.executor_label
                    group_component_spec.dag.tasks[
                        importer_task_name].CopyFrom(importer_task_spec)
                    pipeline_spec.components[importer_comp_name].CopyFrom(
                        importer_comp_spec)

                    subgroup_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.producer_task = (
                            importer_task_name)
                    subgroup_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.output_artifact_key = (
                            importer_node.OUTPUT_KEY)

                    # Retrieve the pre-built importer spec
                    importer_spec = subgroup.importer_specs[input_name]
                    deployment_config.executors[
                        importer_exec_label].importer.CopyFrom(importer_spec)

            # Add component spec if not exists
            if subgroup_component_name not in pipeline_spec.components:
                pipeline_spec.components[subgroup_component_name].CopyFrom(
                    subgroup_component_spec)

            # Add task spec
            group_component_spec.dag.tasks[
                subgroup_task_spec.task_info.name].CopyFrom(subgroup_task_spec)

            # Add executor spec
            container_spec = getattr(subgroup, 'container_spec', None)
            if container_spec:
                if compiler_utils.is_v2_component(subgroup):
                    compiler_utils.refactor_v2_container_spec(container_spec)
                executor_label = subgroup_component_spec.executor_label

                if executor_label not in deployment_config.executors:
                    deployment_config.executors[
                        executor_label].container.CopyFrom(container_spec)

        pipeline_spec.deployment_spec.update(
            json_format.MessageToDict(deployment_config))