Example #1
    def testBuildTask(self):
        query = 'SELECT * FROM TABLE'
        bq_example_gen = big_query_example_gen_component.BigQueryExampleGen(
            query=query)
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=bq_example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs,
            enable_cache=True)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_bq_example_gen_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_bq_example_gen_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_bq_example_gen_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
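
Note: every step-builder test in these examples unwraps the single entry returned by StepBuilder.build() (and the single component definition) through a _sole helper that is not shown in the excerpts. A minimal sketch of what it is assumed to do, not the actual TFX test code:

    def _sole(self, d):
        """Asserts that the dict `d` has exactly one entry and returns its value."""
        self.assertLen(d, 1)
        return list(d.values())[0]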
Example #2
    def testBuildFileBasedExampleGen(self):
        example_gen = components.CsvExampleGen(
            input_base='path/to/data/root').with_beam_pipeline_args(
                ['--runner=DataflowRunner'])
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            image_cmds=_TEST_CMDS,
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_csv_example_gen_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_csv_example_gen_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_csv_example_gen_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #3
    def testBuildExitHandler(self):
        task = test_utils.dummy_producer_component(
            param1=decorators.FinalStatusStr('value1'), )
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=task,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs,
            is_exit_handler=True)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_exit_handler_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_exit_handler_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_exit_handler_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #4
    def testBuildLatestArtifactResolverSucceed(self):
        latest_model_resolver = resolver.Resolver(
            strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
            model=channel.Channel(type=standard_artifacts.Model),
            examples=channel.Channel(
                type=standard_artifacts.Examples)).with_id('my_resolver')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        test_pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')
        my_builder = step_builder.StepBuilder(
            node=latest_model_resolver,
            deployment_config=deployment_config,
            pipeline_info=test_pipeline_info,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_artifact_resolver_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_artifact_resolver_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_artifact_resolver_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #5
    def testBuildImporterWithRuntimeParam(self):
        param = data_types.RuntimeParameter(name='runtime_flag', ptype=str)
        impt = importer.Importer(
            source_uri=param,
            artifact_type=standard_artifacts.Examples).with_id('my_importer')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        with parameter_utils.ParameterContext() as pc:
            my_builder = step_builder.StepBuilder(
                node=impt,
                deployment_config=deployment_config,
                component_defs=component_defs)
            actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_component_with_runtime_param.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_task_with_runtime_param.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_executor_with_runtime_param.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
        self.assertListEqual([param], pc.parameters)
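
Example #5 assumes that parameter_utils.ParameterContext records every RuntimeParameter the builder encounters while the with-block is active and exposes them as pc.parameters. The real TFX implementation is not shown here; a rough, hypothetical sketch of that context-manager pattern:

import threading


class ParameterContext:
    """Hypothetical sketch: collects RuntimeParameters seen while active."""

    _local = threading.local()

    def __enter__(self):
        self.parameters = []
        ParameterContext._local.current = self
        return self

    def __exit__(self, *exc_info):
        ParameterContext._local.current = None
        return False


def attach_parameter(param):
    """Assumed hook called by builders when they touch a RuntimeParameter."""
    current = getattr(ParameterContext._local, 'current', None)
    if current is not None:
        current.parameters.append(param)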
Example #6
    def testBuildFileBasedExampleGenWithInputConfig(self):
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='train', pattern='*train.tfr'),
            example_gen_pb2.Input.Split(name='eval', pattern='*test.tfr')
        ])
        example_gen = components.ImportExampleGen(
            input_base='path/to/data/root', input_config=input_config)
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_import_example_gen_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_import_example_gen_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_import_example_gen_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #7
    def testBuildContainerTask(self):
        task = test_utils.DummyProducerComponent(
            output1=channel_utils.as_channel([standard_artifacts.Model()]),
            param1='value1',
        )
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=task,
            image=
            'gcr.io/tensorflow/tfx:latest',  # Note this has no effect here.
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_container_spec_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_container_spec_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_container_spec_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
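
DummyProducerComponent itself is not part of this excerpt; the inline comment notes that StepBuilder's image argument is ignored because container-based components declare their own image. A hedged sketch of how such a component could be declared (illustrative only, not the actual test_utils definition):

from tfx.dsl.component.experimental import container_component
from tfx.dsl.component.experimental import placeholders
from tfx.types import standard_artifacts

# Hypothetical container-based component; its own image wins over the
# image= passed to StepBuilder.
DummyProducerComponent = container_component.create_container_component(
    name='DummyProducerComponent',
    outputs={'output1': standard_artifacts.Model},
    parameters={'param1': str},
    image='gcr.io/example/dummy-producer',
    command=[
        'producer',
        '--output1-uri', placeholders.OutputUriPlaceholder('output1'),
        '--param1', placeholders.InputValuePlaceholder('param1'),
    ],
)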
Example #8
    def testBuildImporter(self):
        impt = importer.Importer(
            source_uri='m/y/u/r/i',
            properties={
                'split_names': '["train", "eval"]',
            },
            custom_properties={
                'str_custom_property': 'abc',
                'int_custom_property': 123,
            },
            artifact_type=standard_artifacts.Examples).with_id('my_importer')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=impt,
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #9
    def testBuildLatestBlessedModelStrategySucceed(self):
        latest_blessed_resolver = resolver.Resolver(
            strategy_class=latest_blessed_model_strategy.
            LatestBlessedModelStrategy,
            model=channel.Channel(type=standard_artifacts.Model),
            model_blessing=channel.Channel(
                type=standard_artifacts.ModelBlessing)).with_id('my_resolver2')
        test_pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=latest_blessed_resolver,
            deployment_config=deployment_config,
            pipeline_info=test_pipeline_info,
            component_defs=component_defs)
        actual_step_specs = my_builder.build()

        model_blessing_resolver_id = 'my_resolver2-model-blessing-resolver'
        model_resolver_id = 'my_resolver2-model-resolver'
        self.assertSameElements(
            actual_step_specs.keys(),
            [model_blessing_resolver_id, model_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_component_1.pbtxt',
                pipeline_pb2.ComponentSpec()),
            component_defs[model_blessing_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_task_1.pbtxt',
                pipeline_pb2.PipelineTaskSpec()),
            actual_step_specs[model_blessing_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_component_2.pbtxt',
                pipeline_pb2.ComponentSpec()),
            component_defs[model_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_task_2.pbtxt',
                pipeline_pb2.PipelineTaskSpec()),
            actual_step_specs[model_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #10
    def testBuildDummyConsumerWithCondition(self):
        producer_task_1 = test_utils.dummy_producer_component(
            output1=channel_utils.as_channel([standard_artifacts.Model()]),
            param1='value1',
        ).with_id('producer_task_1')
        producer_task_2 = test_utils.dummy_producer_component_2(
            output1=channel_utils.as_channel([standard_artifacts.Model()]),
            param1='value2',
        ).with_id('producer_task_2')
        # This test checks two things:
        # 1. Nested conditions. The condition string of consumer_task should contain
        #    both predicates.
        # 2. Implicit channels. consumer_task only consumes producer_task_1's output,
        #    but producer_task_2 is used in the condition, so producer_task_2 should
        #    be added to consumer_task's dependencies.
        # See testdata for details.
        with conditional.Cond(
                producer_task_1.outputs['output1'].future()[0].uri != 'uri'):
            with conditional.Cond(producer_task_2.outputs['output1'].future()
                                  [0].property('property') == 'value1'):
                consumer_task = test_utils.dummy_consumer_component(
                    input1=producer_task_1.outputs['output1'],
                    param1=1,
                )
        # Need to construct a pipeline to set producer_component_id.
        unused_pipeline = tfx.dsl.Pipeline(
            pipeline_name='pipeline-with-condition',
            pipeline_root='',
            components=[producer_task_1, producer_task_2, consumer_task],
        )
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=consumer_task,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_consumer_with_condition_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_consumer_with_condition_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_consumer_with_condition_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #11
  def _create_pipeline_spec(
      self,
      args: List[dsl.PipelineParam],
      pipeline: dsl.Pipeline,
  ) -> pipeline_spec_pb2.PipelineSpec:
    """Creates the pipeline spec object.

    Args:
      args: The list of pipeline arguments.
      pipeline: The instantiated pipeline object.

    Returns:
      A PipelineSpec proto representing the compiled pipeline.

    Raises:
      NotImplementedError: If an argument is of an unsupported type.
    """
    compiler_utils.validate_pipeline_name(pipeline.name)

    deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
    pipeline_spec = pipeline_spec_pb2.PipelineSpec()

    pipeline_spec.pipeline_info.name = pipeline.name
    pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
    # Schema version 2.0.0 is required for kfp-pipeline-spec>0.1.3.1
    pipeline_spec.schema_version = '2.0.0'

    dsl_component_spec.build_component_inputs_spec(
        component_spec=pipeline_spec.root,
        pipeline_params=args,
        is_root_component=True)

    root_group = pipeline.groups[0]
    opsgroups = self._get_groups(root_group)
    op_name_to_parent_groups = self._get_groups_for_ops(root_group)
    opgroup_name_to_parent_groups = self._get_groups_for_opsgroups(root_group)

    condition_params = self._get_condition_params_for_ops(root_group)
    op_name_to_for_loop_op = self._get_for_loop_ops(root_group)
    inputs, outputs = self._get_inputs_outputs(
        pipeline,
        args,
        root_group,
        op_name_to_parent_groups,
        opgroup_name_to_parent_groups,
        condition_params,
        op_name_to_for_loop_op,
    )
    dependencies = self._get_dependencies(
        pipeline,
        root_group,
        op_name_to_parent_groups,
        opgroup_name_to_parent_groups,
        opsgroups,
        condition_params,
    )

    for opsgroup_name in opsgroups.keys():
      self._group_to_dag_spec(
          opsgroups[opsgroup_name],
          inputs,
          outputs,
          dependencies,
          pipeline_spec,
          deployment_config,
          root_group.name,
          op_name_to_parent_groups,
      )

    # Exit Handler
    if pipeline.groups[0].groups:
      first_group = pipeline.groups[0].groups[0]
      if first_group.type == 'exit_handler':
        exit_handler_op = first_group.exit_op

        # Add exit op task spec
        task_name = exit_handler_op.task_spec.task_info.name
        exit_handler_op.task_spec.dependent_tasks.extend(
            pipeline_spec.root.dag.tasks.keys())
        exit_handler_op.task_spec.trigger_policy.strategy = (
            pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
            .ALL_UPSTREAM_TASKS_COMPLETED)
        pipeline_spec.root.dag.tasks[task_name].CopyFrom(
            exit_handler_op.task_spec)

        # Add exit op component spec if it does not exist.
        component_name = exit_handler_op.task_spec.component_ref.name
        if component_name not in pipeline_spec.components:
          pipeline_spec.components[component_name].CopyFrom(
              exit_handler_op.component_spec)

        # Add exit op executor spec if it does not exist.
        executor_label = exit_handler_op.component_spec.executor_label
        if executor_label not in deployment_config.executors:
          deployment_config.executors[executor_label].container.CopyFrom(
              exit_handler_op.container_spec)
          pipeline_spec.deployment_spec.update(
              json_format.MessageToDict(deployment_config))

    return pipeline_spec
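
The exit-handler branch above only triggers when the root group's first child group came from dsl.ExitHandler. For orientation, a pipeline with that shape would look roughly like the sketch below (the components are made up for illustration):

from kfp import dsl
from kfp.components import load_component_from_text

# Hypothetical one-container components, defined inline only for this sketch.
echo_op = load_component_from_text("""
name: echo
implementation:
  container:
    image: alpine
    command: [echo, "workload finished"]
""")

cleanup_op = load_component_from_text("""
name: cleanup
implementation:
  container:
    image: alpine
    command: [echo, "cleaning up"]
""")


@dsl.pipeline(name='pipeline-with-exit-handler')
def my_pipeline():
    # Everything inside the handler becomes the guarded DAG; cleanup_op runs
    # once that DAG finishes, regardless of success or failure.
    with dsl.ExitHandler(cleanup_op()):
        echo_op()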
Example #12
  def _create_pipeline_spec(
      self,
      args: List[dsl.PipelineParam],
      pipeline: dsl.Pipeline,
  ) -> pipeline_spec_pb2.PipelineSpec:
    """Creates the pipeline spec object.

    Args:
      args: The list of pipeline arguments.
      pipeline: The instantiated pipeline object.

    Returns:
      A PipelineSpec proto representing the compiled pipeline.

    Raises:
      NotImplementedError: If an argument is of an unsupported type.
    """
    compiler_utils.validate_pipeline_name(pipeline.name)

    deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
    pipeline_spec = pipeline_spec_pb2.PipelineSpec()

    pipeline_spec.pipeline_info.name = pipeline.name
    pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
    # Schema version 2.0.0 is required for kfp-pipeline-spec>0.1.3.1
    pipeline_spec.schema_version = '2.0.0'

    dsl_component_spec.build_component_inputs_spec(
        component_spec=pipeline_spec.root,
        pipeline_params=args,
        is_root_component=True)

    root_group = pipeline.groups[0]
    opsgroups = self._get_groups(root_group)
    op_name_to_parent_groups = self._get_groups_for_ops(root_group)
    opgroup_name_to_parent_groups = self._get_groups_for_opsgroups(root_group)

    condition_params = self._get_condition_params_for_ops(root_group)
    op_name_to_for_loop_op = self._get_for_loop_ops(root_group)
    inputs, outputs = self._get_inputs_outputs(
        pipeline,
        args,
        root_group,
        op_name_to_parent_groups,
        opgroup_name_to_parent_groups,
        condition_params,
        op_name_to_for_loop_op,
    )
    dependencies = self._get_dependencies(
        pipeline,
        root_group,
        op_name_to_parent_groups,
        opgroup_name_to_parent_groups,
        opsgroups,
        condition_params,
    )

    for opsgroup_name in opsgroups.keys():
      self._group_to_dag_spec(
          opsgroups[opsgroup_name],
          inputs,
          outputs,
          dependencies,
          pipeline_spec,
          deployment_config,
          root_group.name,
      )

    return pipeline_spec
Example #13
    def _create_pipeline_spec(
        self,
        args: List[dsl.PipelineParam],
        pipeline: dsl.Pipeline,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Creates the pipeline spec object.

        Args:
            args: The list of pipeline arguments.
            pipeline: The instantiated pipeline object.

        Returns:
            A PipelineSpec proto representing the compiled pipeline.

        Raises:
            NotImplementedError: If an argument is of an unsupported type.
        """
        compiler_utils.validate_pipeline_name(pipeline.name)

        pipeline_spec = pipeline_spec_pb2.PipelineSpec(
            runtime_parameters=compiler_utils.build_runtime_parameter_spec(
                args))

        pipeline_spec.pipeline_info.name = pipeline.name
        pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
        pipeline_spec.schema_version = 'v2alpha1'

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
        importer_tasks = []

        for op in pipeline.ops.values():
            component_spec = op._metadata
            task = pipeline_spec.tasks.add()
            task.CopyFrom(op.task_spec)
            deployment_config.executors[
                task.executor_label].container.CopyFrom(op.container_spec)

            # A task may have an explicit dependency on other tasks even though
            # they have no inputs/outputs dependency, e.g.: op2.after(op1).
            if op.dependent_names:
                task.dependent_tasks.extend(op.dependent_names)

            # Check whether an importer node needs to be inserted.
            for input_name in task.inputs.artifacts:
                if not task.inputs.artifacts[input_name].producer_task:
                    type_schema = type_utils.get_input_artifact_type_schema(
                        input_name, component_spec.inputs)

                    importer_task = importer_node.build_importer_task_spec(
                        dependent_task=task,
                        input_name=input_name,
                        input_type_schema=type_schema)
                    importer_tasks.append(importer_task)

                    task.inputs.artifacts[
                        input_name].producer_task = importer_task.task_info.name
                    task.inputs.artifacts[
                        input_name].output_artifact_key = importer_node.OUTPUT_KEY

                    # Retrieve the pre-built importer spec
                    importer_spec = op.importer_spec[input_name]
                    deployment_config.executors[
                        importer_task.executor_label].importer.CopyFrom(
                            importer_spec)

        pipeline_spec.deployment_config.Pack(deployment_config)
        pipeline_spec.tasks.extend(importer_tasks)

        return pipeline_spec
Example #14
    def _create_pipeline_spec(
        self,
        args: List[dsl.PipelineParam],
        pipeline: dsl.Pipeline,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Creates the pipeline spec object.

        Args:
            args: The list of pipeline arguments.
            pipeline: The instantiated pipeline object.

        Returns:
            A PipelineSpec proto representing the compiled pipeline.

        Raises:
            NotImplementedError: If an argument is of an unsupported type.
        """
        compiler_utils.validate_pipeline_name(pipeline.name)

        pipeline_spec = pipeline_spec_pb2.PipelineSpec()

        pipeline_spec.pipeline_info.name = pipeline.name
        pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
        # Schema version 2.0.0 is required for kfp-pipeline-spec>0.1.3.1
        pipeline_spec.schema_version = '2.0.0'

        pipeline_spec.root.CopyFrom(
            dsl_component_spec.build_root_spec_from_pipeline_params(args))

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()

        for op in pipeline.ops.values():
            task_name = op.task_spec.task_info.name
            component_name = op.task_spec.component_ref.name
            executor_label = op.component_spec.executor_label

            pipeline_spec.root.dag.tasks[task_name].CopyFrom(op.task_spec)
            pipeline_spec.components[component_name].CopyFrom(
                op.component_spec)
            deployment_config.executors[executor_label].container.CopyFrom(
                op.container_spec)

            task = pipeline_spec.root.dag.tasks[task_name]
            # A task may have an explicit dependency on other tasks even though
            # they have no inputs/outputs dependency, e.g.: op2.after(op1).
            if op.dependent_names:
                op.dependent_names = [
                    dsl_utils.sanitize_task_name(name)
                    for name in op.dependent_names
                ]
                task.dependent_tasks.extend(op.dependent_names)

            # Check whether an importer node needs to be inserted.
            for input_name in task.inputs.artifacts:
                if not task.inputs.artifacts[
                        input_name].task_output_artifact.producer_task:
                    type_schema = type_utils.get_input_artifact_type_schema(
                        input_name, op._metadata.inputs)

                    importer_name = importer_node.generate_importer_base_name(
                        dependent_task_name=task_name, input_name=input_name)
                    importer_task_spec = importer_node.build_importer_task_spec(
                        importer_name)
                    importer_comp_spec = importer_node.build_importer_component_spec(
                        importer_base_name=importer_name,
                        input_name=input_name,
                        input_type_schema=type_schema)
                    importer_task_name = importer_task_spec.task_info.name
                    importer_comp_name = importer_task_spec.component_ref.name
                    importer_exec_label = importer_comp_spec.executor_label
                    pipeline_spec.root.dag.tasks[importer_task_name].CopyFrom(
                        importer_task_spec)
                    pipeline_spec.components[importer_comp_name].CopyFrom(
                        importer_comp_spec)

                    task.inputs.artifacts[
                        input_name].task_output_artifact.producer_task = (
                            importer_task_name)
                    task.inputs.artifacts[
                        input_name].task_output_artifact.output_artifact_key = (
                            importer_node.OUTPUT_KEY)

                    # Retrieve the pre-built importer spec
                    importer_spec = op.importer_specs[input_name]
                    deployment_config.executors[
                        importer_exec_label].importer.CopyFrom(importer_spec)

        pipeline_spec.deployment_spec.update(
            json_format.MessageToDict(deployment_config))

        return pipeline_spec
Example #15
    def build(self) -> pipeline_pb2.PipelineSpec:
        """Build a pipeline PipelineSpec."""

        _check_name(self._pipeline_info.pipeline_name)

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        pipeline_info = pipeline_pb2.PipelineInfo(
            name=self._pipeline_info.pipeline_name)

        tfx_tasks = {}
        component_defs = {}
        # Map from (producer component id, output key) to (new producer component
        # id, output key)
        channel_redirect_map = {}
        with parameter_utils.ParameterContext() as pc:
            for component in self._pipeline.components:
                if self._exit_handler and component.id == compiler_utils.TFX_DAG_NAME:
                    component.with_id(component.id +
                                      _generate_component_name_suffix())
                    logging.warning(
                        '_tfx_dag is a system-reserved name for pipelines with '
                        'an exit handler; added a suffix to your component name: %s',
                        component.id)
                # The topological order of components matters here: if a channel
                # needs to be redirected, the redirect mapping is already available
                # because the upstream node (the cause of the redirection) is
                # processed before its downstream consumers.
                built_tasks = step_builder.StepBuilder(
                    node=component,
                    deployment_config=deployment_config,
                    component_defs=component_defs,
                    image=self._default_image,
                    image_cmds=self._default_commands,
                    beam_pipeline_args=self._pipeline.beam_pipeline_args,
                    enable_cache=self._pipeline.enable_cache,
                    pipeline_info=self._pipeline_info,
                    channel_redirect_map=channel_redirect_map).build()
                tfx_tasks.update(built_tasks)

        result = pipeline_pb2.PipelineSpec(pipeline_info=pipeline_info)

        # If an exit handler is defined, put all the TFX tasks under tfx_dag;
        # the exit handler is a separate component triggered by tfx_dag.
        if self._exit_handler:
            for name, task_spec in tfx_tasks.items():
                result.components[compiler_utils.TFX_DAG_NAME].dag.tasks[
                    name].CopyFrom(task_spec)
            # construct root with exit handler
            exit_handler_task = step_builder.StepBuilder(
                node=self._exit_handler,
                deployment_config=deployment_config,
                component_defs=component_defs,
                image=self._default_image,
                image_cmds=self._default_commands,
                beam_pipeline_args=self._pipeline.beam_pipeline_args,
                enable_cache=False,
                pipeline_info=self._pipeline_info,
                channel_redirect_map=channel_redirect_map,
                is_exit_handler=True).build()
            result.root.dag.tasks[
                compiler_utils.
                TFX_DAG_NAME].component_ref.name = compiler_utils.TFX_DAG_NAME
            result.root.dag.tasks[
                compiler_utils.
                TFX_DAG_NAME].task_info.name = compiler_utils.TFX_DAG_NAME
            result.root.dag.tasks[self._exit_handler.id].CopyFrom(
                exit_handler_task[self._exit_handler.id])
        else:
            for name, task_spec in tfx_tasks.items():
                result.root.dag.tasks[name].CopyFrom(task_spec)

        result.deployment_spec.update(
            json_format.MessageToDict(deployment_config))
        for name, component_def in component_defs.items():
            result.components[name].CopyFrom(component_def)

        # Attach runtime parameters to the root's input parameters.
        for param in pc.parameters:
            result.root.input_definitions.parameters[param.name].CopyFrom(
                compiler_utils.build_parameter_type_spec(param))

        return result
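
This build() method is normally reached through the TFX Kubeflow V2 runner rather than called directly. A hedged usage sketch (the pipeline contents are illustrative):

from tfx import v1 as tfx

# Hypothetical minimal pipeline; any list of TFX components would do.
example_gen = tfx.components.CsvExampleGen(input_base='path/to/data/root')
my_pipeline = tfx.dsl.Pipeline(
    pipeline_name='my-pipeline',
    pipeline_root='gs://bucket/root',
    components=[example_gen],
)

# KubeflowV2DagRunner drives the pipeline/step builders shown above and
# writes the resulting PipelineSpec as a JSON file.
runner = tfx.orchestration.experimental.KubeflowV2DagRunner(
    config=tfx.orchestration.experimental.KubeflowV2DagRunnerConfig(
        default_image='gcr.io/tensorflow/tfx:latest'),
    output_filename='pipeline.json',
)
runner.run(my_pipeline)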
Example #16
    def to_pipeline_spec(self) -> pipeline_spec_pb2.PipelineSpec:
        """Creates a pipeline instance and constructs the pipeline spec for a
        single component.

        Returns:
            A PipelineSpec proto representing the compiled component.
        """
        # Import here to avoid a circular module dependency.
        from kfp.compiler import pipeline_spec_builder as builder
        from kfp.components import pipeline_task
        from kfp.components import tasks_group
        from kfp.components.types import type_utils

        args_dict = {}
        pipeline_inputs = self.inputs or {}

        for arg_name, input_spec in pipeline_inputs.items():
            arg_type = input_spec.type
            if not type_utils.is_parameter_type(
                    arg_type) or type_utils.is_task_final_status_type(
                        arg_type):
                raise TypeError(
                    builder.make_invalid_input_type_error_msg(
                        arg_name, arg_type))
            args_dict[arg_name] = dsl.PipelineParameterChannel(
                name=arg_name, channel_type=arg_type)

        task = pipeline_task.PipelineTask(self, args_dict)

        # instead of constructing a pipeline with pipeline_context.Pipeline,
        # just build the single task group
        group = tasks_group.TasksGroup(
            group_type=tasks_group.TasksGroupType.PIPELINE)
        group.tasks.append(task)

        # Fill in the default values.
        args_list_with_defaults = [
            dsl.PipelineParameterChannel(
                name=input_name,
                channel_type=input_spec.type,
                value=input_spec.default,
            ) for input_name, input_spec in pipeline_inputs.items()
        ]
        group.name = uuid.uuid4().hex

        pipeline_name = self.name
        pipeline_args = args_list_with_defaults
        task_group = group

        builder.validate_pipeline_name(pipeline_name)

        pipeline_spec = pipeline_spec_pb2.PipelineSpec()
        pipeline_spec.pipeline_info.name = pipeline_name
        pipeline_spec.sdk_version = f'kfp-{kfp.__version__}'
        # Schema version 2.1.0 is required for kfp-pipeline-spec>0.1.13
        pipeline_spec.schema_version = '2.1.0'
        pipeline_spec.root.CopyFrom(
            builder.build_component_spec_for_group(
                pipeline_channels=pipeline_args,
                is_root_group=True,
            ))

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
        root_group = task_group

        task_name_to_parent_groups, group_name_to_parent_groups = builder.get_parent_groups(
            root_group)

        def get_inputs(task_group: tasks_group.TasksGroup,
                       task_name_to_parent_groups):
            inputs = collections.defaultdict(set)
            if len(task_group.tasks) != 1:
                raise ValueError(
                    f'Error compiling component. Expected one task in task group, got {len(task_group.tasks)}.'
                )
            only_task = task_group.tasks[0]
            if only_task.channel_inputs:
                for group_name in task_name_to_parent_groups[only_task.name]:
                    inputs[group_name].add(
                        (only_task.channel_inputs[-1], None))
            return inputs

        inputs = get_inputs(task_group, task_name_to_parent_groups)

        builder.build_spec_by_group(
            pipeline_spec=pipeline_spec,
            deployment_config=deployment_config,
            group=root_group,
            inputs=inputs,
            dependencies={},  # no dependencies for single-component pipeline
            rootgroup_name=root_group.name,
            task_name_to_parent_groups=task_name_to_parent_groups,
            group_name_to_parent_groups=group_name_to_parent_groups,
            name_to_for_loop_group=
            {},  # no for loop for single-component pipeline
        )

        return pipeline_spec
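
to_pipeline_spec is what lets a bare component be compiled as though it were a pipeline. A hedged usage sketch (the component is made up for illustration, and routing the compile through to_pipeline_spec is the assumption stated above):

from kfp import compiler
from kfp import dsl


@dsl.component
def add(a: int, b: int) -> int:
    """Hypothetical lightweight component used only for this sketch."""
    return a + b


# Compiling a component (rather than an @dsl.pipeline function) is expected to
# go through the component's ComponentSpec.to_pipeline_spec() shown above.
compiler.Compiler().compile(add, package_path='add_component.yaml')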
Example #17
    def _create_pipeline_spec(
        self,
        pipeline_args: List[dsl.PipelineChannel],
        pipeline: pipeline_context.Pipeline,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Creates a pipeline spec object.

        Args:
            pipeline_args: The list of pipeline input parameters.
            pipeline: The instantiated pipeline object.

        Returns:
            A PipelineSpec proto representing the compiled pipeline.

        Raises:
            ValueError: If an argument is of an unsupported type.
        """
        builder.validate_pipeline_name(pipeline.name)

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
        pipeline_spec = pipeline_spec_pb2.PipelineSpec()

        pipeline_spec.pipeline_info.name = pipeline.name
        pipeline_spec.sdk_version = f'kfp-{kfp.__version__}'
        # Schema version 2.1.0 is required for kfp-pipeline-spec>0.1.13
        pipeline_spec.schema_version = '2.1.0'

        pipeline_spec.root.CopyFrom(
            builder.build_component_spec_for_group(
                pipeline_channels=pipeline_args,
                is_root_group=True,
            ))

        root_group = pipeline.groups[0]

        all_groups = self._get_all_groups(root_group)
        group_name_to_group = {group.name: group for group in all_groups}
        task_name_to_parent_groups, group_name_to_parent_groups = (
            builder.get_parent_groups(root_group))
        condition_channels = self._get_condition_channels_for_tasks(root_group)
        name_to_for_loop_group = {
            group_name: group
            for group_name, group in group_name_to_group.items()
            if isinstance(group, dsl.ParallelFor)
        }
        inputs = self._get_inputs_for_all_groups(
            pipeline=pipeline,
            pipeline_args=pipeline_args,
            root_group=root_group,
            task_name_to_parent_groups=task_name_to_parent_groups,
            group_name_to_parent_groups=group_name_to_parent_groups,
            condition_channels=condition_channels,
            name_to_for_loop_group=name_to_for_loop_group,
        )
        dependencies = self._get_dependencies(
            pipeline=pipeline,
            root_group=root_group,
            task_name_to_parent_groups=task_name_to_parent_groups,
            group_name_to_parent_groups=group_name_to_parent_groups,
            group_name_to_group=group_name_to_group,
            condition_channels=condition_channels,
        )

        for group in all_groups:
            builder.build_spec_by_group(
                pipeline_spec=pipeline_spec,
                deployment_config=deployment_config,
                group=group,
                inputs=inputs,
                dependencies=dependencies,
                rootgroup_name=root_group.name,
                task_name_to_parent_groups=task_name_to_parent_groups,
                group_name_to_parent_groups=group_name_to_parent_groups,
                name_to_for_loop_group=name_to_for_loop_group,
            )

        # TODO: refactor to support multiple exit handlers per pipeline.
        if pipeline.groups[0].groups:
            first_group = pipeline.groups[0].groups[0]
            if isinstance(first_group, dsl.ExitHandler):
                exit_task = first_group.exit_task
                exit_task_name = component_utils.sanitize_task_name(
                    exit_task.name)
                exit_handler_group_task_name = component_utils.sanitize_task_name(
                    first_group.name)
                input_parameters_in_current_dag = [
                    input_name for input_name in
                    pipeline_spec.root.input_definitions.parameters
                ]
                exit_task_task_spec = builder.build_task_spec_for_exit_task(
                    task=exit_task,
                    dependent_task=exit_handler_group_task_name,
                    pipeline_inputs=pipeline_spec.root.input_definitions,
                )

                exit_task_component_spec = builder.build_component_spec_for_exit_task(
                    task=exit_task)

                exit_task_container_spec = builder.build_container_spec_for_task(
                    task=exit_task)

                # Add exit task task spec
                pipeline_spec.root.dag.tasks[exit_task_name].CopyFrom(
                    exit_task_task_spec)

                # Add exit task component spec if it does not exist.
                component_name = exit_task_task_spec.component_ref.name
                if component_name not in pipeline_spec.components:
                    pipeline_spec.components[component_name].CopyFrom(
                        exit_task_component_spec)

                # Add exit task container spec if it does not exist.
                executor_label = exit_task_component_spec.executor_label
                if executor_label not in deployment_config.executors:
                    deployment_config.executors[
                        executor_label].container.CopyFrom(
                            exit_task_container_spec)
                    pipeline_spec.deployment_spec.update(
                        json_format.MessageToDict(deployment_config))

        return pipeline_spec
    def _group_to_dag_spec(
        self,
        group: dsl.OpsGroup,
        inputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
        outputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
        dependencies: Dict[str, List[_GroupOrOp]],
        pipeline_spec: pipeline_spec_pb2.PipelineSpec,
        rootgroup_name: str,
    ) -> None:
        """Generate IR spec given an OpsGroup.

        Args:
            group: The OpsGroup to generate spec for.
            inputs: The inputs dictionary. The keys are group/op names and values
                are lists of tuples (param, producing_op_name).
            outputs: The outputs dictionary. The keys are group/op names and values
                are lists of tuples (param, producing_op_name).
            dependencies: The group dependencies dictionary. The keys are group/op
                names, and the values are lists of dependent groups/ops.
            pipeline_spec: The pipeline_spec to update in-place.
            rootgroup_name: The name of the group root. Used to determine whether the
                component spec for the current group should be the root dag.
        """
        group_component_name = dsl_utils.sanitize_component_name(group.name)

        if group.name == rootgroup_name:
            group_component_spec = pipeline_spec.root
        else:
            group_component_spec = pipeline_spec.components[
                group_component_name]

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()

        # Generate component inputs spec.
        if inputs.get(group.name, None):
            dsl_component_spec.build_component_inputs_spec(
                group_component_spec,
                [param for param, _ in inputs[group.name]])

        # Generate component outputs spec.
        if outputs.get(group.name, None):
            group_component_spec.output_definitions.CopyFrom(
                dsl_component_spec.build_component_outputs_spec(
                    [param for param, _ in outputs[group.name]]))

        # Generate task specs and component specs for the dag.
        subgroups = group.groups + group.ops
        for subgroup in subgroups:
            subgroup_task_spec = getattr(subgroup, 'task_spec',
                                         pipeline_spec_pb2.PipelineTaskSpec())
            subgroup_component_spec = getattr(
                subgroup, 'component_spec', pipeline_spec_pb2.ComponentSpec())
            is_recursive_subgroup = (isinstance(subgroup, dsl.OpsGroup)
                                     and subgroup.recursive_ref)
            # Special handling for recursive subgroup: use the existing opsgroup name
            if is_recursive_subgroup:
                subgroup_key = subgroup.recursive_ref.name
            else:
                subgroup_key = subgroup.name

            subgroup_task_spec.task_info.name = dsl_utils.sanitize_task_name(
                subgroup_key)
            # human_name exists for ops only, and is used to de-dupe component spec.
            subgroup_component_name = dsl_utils.sanitize_component_name(
                getattr(subgroup, 'human_name', subgroup_key))
            subgroup_task_spec.component_ref.name = subgroup_component_name

            if isinstance(subgroup,
                          dsl.OpsGroup) and subgroup.type == 'condition':
                condition = subgroup.condition
                operand_values = []
                subgroup_inputs = inputs.get(subgroup.name, [])
                subgroup_params = [param for param, _ in subgroup_inputs]
                tasks_in_current_dag = [sub.name for sub in subgroups]

                dsl_component_spec.build_component_inputs_spec(
                    subgroup_component_spec,
                    subgroup_params,
                )
                dsl_component_spec.build_task_inputs_spec(
                    subgroup_task_spec,
                    subgroup_params,
                    tasks_in_current_dag,
                )

                for operand in [condition.operand1, condition.operand2]:
                    operand_values.append(
                        self._resolve_value_or_reference(operand))

                condition_string = '{} {} {}'.format(operand_values[0],
                                                     condition.operator,
                                                     operand_values[1])

                subgroup_task_spec.trigger_policy.CopyFrom(
                    pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy(
                        condition=condition_string))

            # Generate dependencies section for this task.
            if dependencies.get(subgroup.name, None):
                group_dependencies = list(dependencies[subgroup.name])
                group_dependencies.sort()
                subgroup_task_spec.dependent_tasks.extend([
                    dsl_utils.sanitize_task_name(dep)
                    for dep in group_dependencies
                ])

            # Add importer node when applicable
            for input_name in subgroup_task_spec.inputs.artifacts:
                if not subgroup_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.producer_task:
                    type_schema = type_utils.get_input_artifact_type_schema(
                        input_name, subgroup._metadata.inputs)

                    importer_name = importer_node.generate_importer_base_name(
                        dependent_task_name=subgroup_task_spec.task_info.name,
                        input_name=input_name)
                    importer_task_spec = importer_node.build_importer_task_spec(
                        importer_name)
                    importer_comp_spec = importer_node.build_importer_component_spec(
                        importer_base_name=importer_name,
                        input_name=input_name,
                        input_type_schema=type_schema)
                    importer_task_name = importer_task_spec.task_info.name
                    importer_comp_name = importer_task_spec.component_ref.name
                    importer_exec_label = importer_comp_spec.executor_label
                    group_component_spec.dag.tasks[
                        importer_task_name].CopyFrom(importer_task_spec)
                    pipeline_spec.components[importer_comp_name].CopyFrom(
                        importer_comp_spec)

                    subgroup_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.producer_task = (
                            importer_task_name)
                    subgroup_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.output_artifact_key = (
                            importer_node.OUTPUT_KEY)

                    # Retrieve the pre-built importer spec
                    importer_spec = subgroup.importer_specs[input_name]
                    deployment_config.executors[
                        importer_exec_label].importer.CopyFrom(importer_spec)

            # Add the component spec if it does not already exist.
            if subgroup_component_name not in pipeline_spec.components:
                pipeline_spec.components[subgroup_component_name].CopyFrom(
                    subgroup_component_spec)

            # Add task spec
            group_component_spec.dag.tasks[
                subgroup_task_spec.task_info.name].CopyFrom(subgroup_task_spec)

            # Add executor spec
            container_spec = getattr(subgroup, 'container_spec', None)
            if container_spec:
                if compiler_utils.is_v2_component(subgroup):
                    compiler_utils.refactor_v2_container_spec(container_spec)
                executor_label = subgroup_component_spec.executor_label

                if executor_label not in deployment_config.executors:
                    deployment_config.executors[
                        executor_label].container.CopyFrom(container_spec)

        pipeline_spec.deployment_spec.update(
            json_format.MessageToDict(deployment_config))
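
dsl_utils.sanitize_task_name and dsl_utils.sanitize_component_name are used throughout Examples #14 and #17 but are not part of this excerpt. The only assumption made here is that they normalize display names into lowercase, dash-separated identifiers, roughly like this hypothetical sketch:

import re


def _sanitize_name(name: str) -> str:
    """Hypothetical sketch: lowercase and collapse runs outside [a-z0-9-] to '-'."""
    return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).strip('-')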