Exemplo n.º 1
0
    def test_build_component_outputs_spec(self):
        # One artifact-typed output followed by three parameter-typed outputs.
        output_types = (
            ('output1', 'Dataset'),
            ('output2', 'Integer'),
            ('output3', 'String'),
            ('output4', 'Float'),
        )
        pipeline_params = [
            _pipeline_param.PipelineParam(name=output_name,
                                          param_type=output_type)
            for output_name, output_type in output_types
        ]

        # The Dataset output is expected under 'artifacts' with an instance
        # schema; the other outputs are expected under 'parameters' with
        # their IR primitive type names.
        dataset_schema = ('title: kfp.Dataset\ntype: object\nproperties:\n  '
                          'payload_format:\n    type: string\n  '
                          'container_format:\n    type: string\n')
        expected_parameter_types = {
            'output2': 'INT',
            'output3': 'STRING',
            'output4': 'DOUBLE',
        }
        expected_dict = {
            'outputDefinitions': {
                'artifacts': {
                    'output1': {
                        'artifactType': {
                            'instanceSchema': dataset_schema,
                        },
                    },
                },
                'parameters': {
                    output_name: {'type': ir_type}
                    for output_name, ir_type in
                    expected_parameter_types.items()
                },
            },
        }
        expected_spec = json_format.ParseDict(
            expected_dict, pipeline_spec_pb2.ComponentSpec())

        # The builder populates the passed-in spec in place.
        actual_spec = pipeline_spec_pb2.ComponentSpec()
        dsl_component_spec.build_component_outputs_spec(actual_spec,
                                                        pipeline_params)

        self.assertEqual(expected_spec, actual_spec)
Exemplo n.º 2
0
    def test_build_component_outputs_spec(self):
        # One artifact-typed output followed by three parameter-typed outputs.
        output_types = (
            ('output1', 'Dataset'),
            ('output2', 'Integer'),
            ('output3', 'String'),
            ('output4', 'Float'),
        )
        pipeline_params = [
            _pipeline_param.PipelineParam(name=output_name,
                                          param_type=output_type)
            for output_name, output_type in output_types
        ]

        # The Dataset output is expected under 'artifacts' with a schema
        # title/version; the other outputs are expected under 'parameters'
        # with their IR parameter type names.
        dataset_artifact_type = {
            'schemaTitle': 'system.Dataset',
            'schemaVersion': '0.0.1',
        }
        expected_parameter_types = {
            'output2': 'NUMBER_INTEGER',
            'output3': 'STRING',
            'output4': 'NUMBER_DOUBLE',
        }
        expected_dict = {
            'outputDefinitions': {
                'artifacts': {
                    'output1': {
                        'artifactType': dataset_artifact_type,
                    },
                },
                'parameters': {
                    output_name: {'parameterType': ir_type}
                    for output_name, ir_type in
                    expected_parameter_types.items()
                },
            },
        }
        expected_spec = json_format.ParseDict(
            expected_dict, pipeline_spec_pb2.ComponentSpec())

        # The builder populates the passed-in spec in place.
        actual_spec = pipeline_spec_pb2.ComponentSpec()
        dsl_component_spec.build_component_outputs_spec(actual_spec,
                                                        pipeline_params)

        self.assertEqual(expected_spec, actual_spec)
Exemplo n.º 3
0
    def _group_to_dag_spec(
        self,
        group: dsl.OpsGroup,
        inputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
        outputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
        dependencies: Dict[str, List[_GroupOrOp]],
        pipeline_spec: pipeline_spec_pb2.PipelineSpec,
        deployment_config: pipeline_spec_pb2.PipelineDeploymentConfig,
        rootgroup_name: str,
    ) -> None:
        """Generate IR spec given an OpsGroup.

        Fills in the component spec of ``group`` (input/output definitions and
        one DAG task per subgroup/op), registers subgroup component specs and
        any importer nodes in ``pipeline_spec``, records executors in
        ``deployment_config``, and finally merges ``deployment_config`` into
        ``pipeline_spec.deployment_spec``.

        Args:
          group: The OpsGroup to generate spec for.
          inputs: The inputs dictionary. The keys are group/op names and values
            are lists of tuples (param, producing_op_name).
          outputs: The outputs dictionary. The keys are group/op names and
            values are lists of tuples (param, producing_op_name).
          dependencies: The group dependencies dictionary. The keys are group/op
            names, and the values are lists of dependent groups/ops.
          pipeline_spec: The pipeline_spec to update in-place.
          deployment_config: The deployment_config to hold all executors.
          rootgroup_name: The name of the group root. Used to determine whether
            the component spec for the current group should be the root dag.

        Raises:
          NotImplementedError: If a subgroup is a graph component or an exit
            handler; neither is supported by the KFP v2 compiler yet.
        """
        group_component_name = dsl_utils.sanitize_component_name(group.name)

        if group.name == rootgroup_name:
            group_component_spec = pipeline_spec.root
        else:
            group_component_spec = pipeline_spec.components[
                group_component_name]

        # Generate component inputs spec.
        if inputs.get(group.name):
            dsl_component_spec.build_component_inputs_spec(
                group_component_spec,
                [param for param, _ in inputs[group.name]])

        # Generate component outputs spec. Like the inputs builder above, the
        # outputs builder fills the given component spec in place, so it takes
        # the spec as its first argument (it does not return a new message).
        if outputs.get(group.name):
            dsl_component_spec.build_component_outputs_spec(
                group_component_spec,
                [param for param, _ in outputs[group.name]])

        # Generate task specs and component specs for the dag.
        subgroups = group.groups + group.ops

        # Loop-invariant facts about the current dag.
        is_loop_subgroup = isinstance(group, dsl.ParallelFor)
        tasks_in_current_dag = [sibling.name for sibling in subgroups]

        for subgroup in subgroups:
            subgroup_task_spec = getattr(subgroup, 'task_spec',
                                         pipeline_spec_pb2.PipelineTaskSpec())
            subgroup_component_spec = getattr(
                subgroup, 'component_spec', pipeline_spec_pb2.ComponentSpec())
            is_recursive_subgroup = (isinstance(subgroup, dsl.OpsGroup)
                                     and subgroup.recursive_ref)

            # Special handling for recursive subgroup: use the existing
            # opsgroup name.
            if is_recursive_subgroup:
                subgroup_key = subgroup.recursive_ref.name
            else:
                subgroup_key = subgroup.name

            subgroup_task_spec.task_info.name = dsl_utils.sanitize_task_name(
                subgroup_key)
            # human_name exists for ops only, and is used to de-dupe component
            # specs.
            subgroup_component_name = dsl_utils.sanitize_component_name(
                getattr(subgroup, 'human_name', subgroup_key))
            subgroup_task_spec.component_ref.name = subgroup_component_name

            if isinstance(subgroup, dsl.OpsGroup) and subgroup.type == 'graph':
                raise NotImplementedError(
                    'dsl.graph_component is not yet supported in KFP v2 compiler.'
                )

            if isinstance(subgroup,
                          dsl.OpsGroup) and subgroup.type == 'exit_handler':
                raise NotImplementedError(
                    'dsl.ExitHandler is not yet supported in KFP v2 compiler.')

            if isinstance(subgroup,
                          dsl.OpsGroup) and subgroup.type == 'condition':
                condition = subgroup.condition
                subgroup_inputs = inputs.get(subgroup.name, [])
                subgroup_params = [param for param, _ in subgroup_inputs]

                dsl_component_spec.build_component_inputs_spec(
                    subgroup_component_spec,
                    subgroup_params,
                )
                dsl_component_spec.build_task_inputs_spec(
                    subgroup_task_spec,
                    subgroup_params,
                    tasks_in_current_dag,
                )

                operand_values = [
                    self._resolve_value_or_reference(operand)
                    for operand in (condition.operand1, condition.operand2)
                ]
                condition_string = '{} {} {}'.format(operand_values[0],
                                                     condition.operator,
                                                     operand_values[1])

                subgroup_task_spec.trigger_policy.CopyFrom(
                    pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy(
                        condition=condition_string))

            # Generate dependencies section for this task (sorted for a
            # deterministic output).
            if dependencies.get(subgroup.name):
                group_dependencies = sorted(dependencies[subgroup.name])
                subgroup_task_spec.dependent_tasks.extend([
                    dsl_utils.sanitize_task_name(dep)
                    for dep in group_dependencies
                ])

            if isinstance(subgroup, dsl.ParallelFor):
                if subgroup.parallelism is not None:
                    warnings.warn(
                        'Setting parallelism in ParallelFor is not supported yet. '
                        'The setting is ignored.')

                # Remove loop arguments related inputs from parent group
                # component spec. A ParallelFor may have no recorded inputs, so
                # use .get() like the condition branch does.
                input_names = [
                    param.full_name
                    for param, _ in inputs.get(subgroup.name, [])
                ]
                for input_name in input_names:
                    if _for_loop.LoopArguments.name_is_loop_argument(
                            input_name):
                        dsl_component_spec.pop_input_from_component_spec(
                            group_component_spec, input_name)

                if subgroup.items_is_pipeline_param:
                    # These loop args are a 'withParam' rather than 'withItems'.
                    # i.e., rather than a static list, they are either the
                    # output of another task or were input as global pipeline
                    # parameters.
                    pipeline_param = subgroup.loop_args.items_or_pipeline_param
                    input_parameter_name = pipeline_param.full_name
                    # Live reference into the task's input-parameter map.
                    task_input_param = subgroup_task_spec.inputs.parameters[
                        input_parameter_name]

                    if pipeline_param.op_name:
                        task_input_param.task_output_parameter.producer_task = (
                            dsl_utils.sanitize_task_name(
                                pipeline_param.op_name))
                        task_input_param.task_output_parameter.output_parameter_key = (
                            pipeline_param.name)
                    else:
                        task_input_param.component_input_parameter = (
                            input_parameter_name)

                    if pipeline_param.op_name is None:
                        # Input parameter is from pipeline func rather than
                        # component output. Correct loop argument input type in
                        # the parent component spec. The loop argument was
                        # categorized as an artifact due to its missing or
                        # non-primitive type annotation. But it should always be
                        # String typed, as its value is a serialized JSON
                        # string.
                        dsl_component_spec.pop_input_from_component_spec(
                            group_component_spec, input_parameter_name)
                        group_component_spec.input_definitions.parameters[
                            input_parameter_name].type = pipeline_spec_pb2.PrimitiveType.STRING

            # Additional spec modifications for dsl.ParallelFor's subgroups.
            if is_loop_subgroup:
                self._update_loop_specs(group, subgroup, group_component_spec,
                                        subgroup_component_spec,
                                        subgroup_task_spec)

            # Add an importer node for every artifact input that has no
            # producer task.
            for input_name in subgroup_task_spec.inputs.artifacts:
                artifact_input = subgroup_task_spec.inputs.artifacts[
                    input_name]
                if artifact_input.task_output_artifact.producer_task:
                    continue

                type_schema = type_utils.get_input_artifact_type_schema(
                    input_name, subgroup._metadata.inputs)

                importer_name = importer_node.generate_importer_base_name(
                    dependent_task_name=subgroup_task_spec.task_info.name,
                    input_name=input_name)
                importer_task_spec = importer_node.build_importer_task_spec(
                    importer_name)
                importer_comp_spec = importer_node.build_importer_component_spec(
                    importer_base_name=importer_name,
                    input_name=input_name,
                    input_type_schema=type_schema)
                importer_task_name = importer_task_spec.task_info.name
                importer_comp_name = importer_task_spec.component_ref.name
                importer_exec_label = importer_comp_spec.executor_label
                group_component_spec.dag.tasks[importer_task_name].CopyFrom(
                    importer_task_spec)
                pipeline_spec.components[importer_comp_name].CopyFrom(
                    importer_comp_spec)

                # Wire the artifact input to the importer's output.
                artifact_input.task_output_artifact.producer_task = (
                    importer_task_name)
                artifact_input.task_output_artifact.output_artifact_key = (
                    importer_node.OUTPUT_KEY)

                # Retrieve the pre-built importer spec.
                importer_spec = subgroup.importer_specs[input_name]
                deployment_config.executors[
                    importer_exec_label].importer.CopyFrom(importer_spec)

            # Add component spec if not exists.
            if subgroup_component_name not in pipeline_spec.components:
                pipeline_spec.components[subgroup_component_name].CopyFrom(
                    subgroup_component_spec)

            # Add task spec.
            group_component_spec.dag.tasks[
                subgroup_task_spec.task_info.name].CopyFrom(subgroup_task_spec)

            # Add executor spec, if applicable.
            container_spec = getattr(subgroup, 'container_spec', None)
            if container_spec:
                if compiler_utils.is_v2_component(subgroup):
                    compiler_utils.refactor_v2_container_spec(container_spec)
                executor_label = subgroup_component_spec.executor_label

                if executor_label not in deployment_config.executors:
                    deployment_config.executors[
                        executor_label].container.CopyFrom(container_spec)

            # Add AIPlatformCustomJobSpec, if applicable.
            custom_job_spec = getattr(subgroup, 'custom_job_spec', None)
            if custom_job_spec:
                executor_label = subgroup_component_spec.executor_label
                if executor_label not in deployment_config.executors:
                    deployment_config.executors[
                        executor_label].custom_job.custom_job.update(
                            custom_job_spec)

        pipeline_spec.deployment_spec.update(
            json_format.MessageToDict(deployment_config))