Example #1
    def testBuildLatestBlessedModelStrategySucceed(self):
        latest_blessed_resolver = resolver.Resolver(
            strategy_class=(
                latest_blessed_model_strategy.LatestBlessedModelStrategy),
            model=channel.Channel(type=standard_artifacts.Model),
            model_blessing=channel.Channel(
                type=standard_artifacts.ModelBlessing)).with_id('my_resolver2')
        test_pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=latest_blessed_resolver,
            deployment_config=deployment_config,
            pipeline_info=test_pipeline_info,
            component_defs=component_defs)
        actual_step_specs = my_builder.build()

        model_blessing_resolver_id = 'my_resolver2-model-blessing-resolver'
        model_resolver_id = 'my_resolver2-model-resolver'
        self.assertSameElements(
            actual_step_specs.keys(),
            [model_blessing_resolver_id, model_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_component_1.pbtxt',
                pipeline_pb2.ComponentSpec()),
            component_defs[model_blessing_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_task_1.pbtxt',
                pipeline_pb2.PipelineTaskSpec()),
            actual_step_specs[model_blessing_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_component_2.pbtxt',
                pipeline_pb2.ComponentSpec()),
            component_defs[model_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_task_2.pbtxt',
                pipeline_pb2.PipelineTaskSpec()),
            actual_step_specs[model_resolver_id])

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_blessed_model_resolver_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #2
    def test_build_task_inputs_spec(self, is_parent_component_root,
                                    expected_result):
        pipeline_params = self.TEST_PIPELINE_PARAMS
        tasks_in_current_dag = ['op-1', 'op-2']
        expected_spec = pipeline_spec_pb2.PipelineTaskSpec()
        json_format.ParseDict(expected_result, expected_spec)

        task_spec = pipeline_spec_pb2.PipelineTaskSpec()
        dsl_component_spec.build_task_inputs_spec(task_spec, pipeline_params,
                                                  tasks_in_current_dag,
                                                  is_parent_component_root)

        self.assertEqual(expected_spec, task_spec)
    def test_build_task_inputs_spec(self):
        pipeline_params = [
            dsl.PipelineParam(name='output1',
                              param_type='Dataset',
                              op_name='op-1'),
            dsl.PipelineParam(name='output2',
                              param_type='Integer',
                              op_name='op-2'),
            dsl.PipelineParam(name='output3',
                              param_type='Model',
                              op_name='op-3'),
            dsl.PipelineParam(name='output4',
                              param_type='Double',
                              op_name='op-4'),
        ]
        tasks_in_current_dag = ['op-1', 'op-2']
        expected_dict = {
            'inputs': {
                'artifacts': {
                    'op-1-output1': {
                        'taskOutputArtifact': {
                            'producerTask': 'task-op-1',
                            'outputArtifactKey': 'output1'
                        }
                    },
                    'op-3-output3': {
                        'componentInputArtifact': 'op-3-output3'
                    }
                },
                'parameters': {
                    'op-2-output2': {
                        'taskOutputParameter': {
                            'producerTask': 'task-op-2',
                            'outputParameterKey': 'output2'
                        }
                    },
                    'op-4-output4': {
                        'componentInputParameter': 'op-4-output4'
                    }
                }
            }
        }
        expected_spec = pipeline_spec_pb2.PipelineTaskSpec()
        json_format.ParseDict(expected_dict, expected_spec)

        task_spec = pipeline_spec_pb2.PipelineTaskSpec()
        dsl_component_spec.build_task_inputs_spec(task_spec, pipeline_params,
                                                  tasks_in_current_dag)

        self.assertEqual(expected_spec, task_spec)
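The tests above build the expected proto from a Python dict with json_format.ParseDict; the reverse direction, json_format.MessageToDict, is handy for inspecting a spec that build_task_inputs_spec has just populated. A minimal sketch, assuming the kfp-pipeline-spec protos and protobuf are installed; the task and parameter names below are placeholders, not values produced by the function:

from google.protobuf import json_format
from kfp.pipeline_spec import pipeline_spec_pb2

# Build a tiny task spec by hand and dump it back to a plain dict.
task_spec = pipeline_spec_pb2.PipelineTaskSpec()
task_spec.task_info.name = 'task-op-1'  # placeholder name
task_spec.inputs.parameters[
    'op-2-output2'].task_output_parameter.producer_task = 'task-op-2'
print(json_format.MessageToDict(task_spec))
# {'taskInfo': {'name': 'task-op-1'},
#  'inputs': {'parameters': {'op-2-output2': {
#      'taskOutputParameter': {'producerTask': 'task-op-2'}}}}}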
Example #4
  def _populate_metrics_in_dag_outputs(
      self,
      ops: List[dsl.ContainerOp],
      op_to_parent_groups: Dict[str, List[str]],
      pipeline_spec: pipeline_spec_pb2.PipelineSpec,
  ) -> None:
    """Populates metrics artifacts in dag outputs.

    Args:
      ops: The list of ops that may produce metrics outputs.
      op_to_parent_groups: The dict of op name to parent groups. Key is the op's
        name. Value is a list of ancestor groups including the op itself. The
        list for a given op is ordered from the farthest ancestor group first to
        the op itself last.
      pipeline_spec: The pipeline_spec to update in-place.
    """
    for op in ops:
      op_task_spec = getattr(op, 'task_spec',
                             pipeline_spec_pb2.PipelineTaskSpec())
      op_component_spec = getattr(op, 'component_spec',
                                  pipeline_spec_pb2.ComponentSpec())

      # Get the tuple of (component_name, task_name) of all its parent groups.
      parent_components_and_tasks = [('_root', '')]
      # Skip the op itself and the root group, which cannot be retrieved by name.
      for group_name in op_to_parent_groups[op.name][1:-1]:
        parent_components_and_tasks.append(
            (dsl_utils.sanitize_component_name(group_name),
             dsl_utils.sanitize_task_name(group_name)))
      # Reverse the order so that the farthest group comes last.
      parent_components_and_tasks.reverse()

      for output_name, artifact_spec in \
          op_component_spec.output_definitions.artifacts.items():

        if artifact_spec.artifact_type.WhichOneof(
            'kind'
        ) == 'schema_title' and artifact_spec.artifact_type.schema_title in [
            io_types.Metrics.TYPE_NAME,
            io_types.ClassificationMetrics.TYPE_NAME,
        ]:
          unique_output_name = '{}-{}'.format(op_task_spec.task_info.name,
                                              output_name)

          sub_task_name = op_task_spec.task_info.name
          sub_task_output = output_name
          for component_name, task_name in parent_components_and_tasks:
            group_component_spec = (
                pipeline_spec.root if component_name == '_root' else
                pipeline_spec.components[component_name])
            group_component_spec.output_definitions.artifacts[
                unique_output_name].CopyFrom(artifact_spec)
            group_component_spec.dag.outputs.artifacts[
                unique_output_name].artifact_selectors.append(
                    pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec(
                        producer_subtask=sub_task_name,
                        output_artifact_key=sub_task_output,
                    ))
            sub_task_name = task_name
            sub_task_output = unique_output_name
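The loop above amounts to copying each metrics ArtifactSpec into every ancestor component's output_definitions and appending an ArtifactSelectorSpec to that component's DAG outputs, one hop at a time. Below is a stripped-down sketch of a single hop, assuming the kfp-pipeline-spec protos are available; 'comp-inner', 'task-trainer', and the schema title are made-up illustrative values, and this is not a substitute for calling _populate_metrics_in_dag_outputs:

from kfp.pipeline_spec import pipeline_spec_pb2

pipeline_spec = pipeline_spec_pb2.PipelineSpec()

# A metrics artifact spec as it would appear on the producing component.
metrics_spec = pipeline_spec_pb2.ComponentOutputsSpec.ArtifactSpec()
metrics_spec.artifact_type.schema_title = 'system.Metrics'  # illustrative

# One hop of the propagation: surface the artifact on a parent component and
# point a DAG output selector back at the producing sub-task.
parent = pipeline_spec.components['comp-inner']
parent.output_definitions.artifacts['task-trainer-metrics'].CopyFrom(metrics_spec)
parent.dag.outputs.artifacts['task-trainer-metrics'].artifact_selectors.append(
    pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec(
        producer_subtask='task-trainer',
        output_artifact_key='metrics',
    ))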
Example #5
    def testBuildExitHandler(self):
        task = test_utils.dummy_producer_component(
            param1=decorators.FinalStatusStr('value1'), )
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=task,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs,
            is_exit_handler=True)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_exit_handler_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_exit_handler_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_exit_handler_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #6
    def testBuildImporter(self):
        impt = importer.Importer(
            source_uri='m/y/u/r/i',
            properties={
                'split_names': '["train", "eval"]',
            },
            custom_properties={
                'str_custom_property': 'abc',
                'int_custom_property': 123,
            },
            artifact_type=standard_artifacts.Examples).with_id('my_importer')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=impt,
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #7
    def testBuildImporterWithRuntimeParam(self):
        param = data_types.RuntimeParameter(name='runtime_flag', ptype=str)
        impt = importer.Importer(
            source_uri=param,
            artifact_type=standard_artifacts.Examples).with_id('my_importer')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        with parameter_utils.ParameterContext() as pc:
            my_builder = step_builder.StepBuilder(
                node=impt,
                deployment_config=deployment_config,
                component_defs=component_defs)
            actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_component_with_runtime_param.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_task_with_runtime_param.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_importer_executor_with_runtime_param.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
        self.assertListEqual([param], pc.parameters)
Example #8
    def testBuildLatestArtifactResolverSucceed(self):
        latest_model_resolver = resolver.Resolver(
            strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
            model=channel.Channel(type=standard_artifacts.Model),
            examples=channel.Channel(
                type=standard_artifacts.Examples)).with_id('my_resolver')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        test_pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')
        my_builder = step_builder.StepBuilder(
            node=latest_model_resolver,
            deployment_config=deployment_config,
            pipeline_info=test_pipeline_info,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_artifact_resolver_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_artifact_resolver_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_latest_artifact_resolver_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #9
def build_importer_task_spec(
    dependent_task: pipeline_spec_pb2.PipelineTaskSpec,
    input_name: str,
    input_type_schema: str,
) -> pipeline_spec_pb2.PipelineTaskSpec:
    """Builds an importer task spec.

  Args:
    dependent_task: The task that requires the importer node.
    input_name: The name of the input artifact to be imported.
    input_type_schema: The type of the input artifact.

  Returns:
    An importer node task spec.
  """
    dependent_task_name = dependent_task.task_info.name

    task_spec = pipeline_spec_pb2.PipelineTaskSpec()
    task_spec.task_info.name = '{}_{}_importer'.format(dependent_task_name,
                                                       input_name)
    task_spec.outputs.artifacts[OUTPUT_KEY].artifact_type.instance_schema = (
        input_type_schema)
    task_spec.executor_label = task_spec.task_info.name

    return task_spec
Example #10
    def testBuildFileBasedExampleGen(self):
        example_gen = components.CsvExampleGen(
            input_base='path/to/data/root').with_beam_pipeline_args(
                ['--runner=DataflowRunner'])
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            image_cmds=_TEST_CMDS,
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_csv_example_gen_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_csv_example_gen_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_csv_example_gen_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #11
    def testBuildFileBasedExampleGenWithInputConfig(self):
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='train', pattern='*train.tfr'),
            example_gen_pb2.Input.Split(name='eval', pattern='*test.tfr')
        ])
        example_gen = components.ImportExampleGen(
            input_base='path/to/data/root', input_config=input_config)
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_import_example_gen_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_import_example_gen_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_import_example_gen_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #12
    def testBuildTask(self):
        query = 'SELECT * FROM TABLE'
        bq_example_gen = big_query_example_gen_component.BigQueryExampleGen(
            query=query)
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=bq_example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs,
            enable_cache=True)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_bq_example_gen_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_bq_example_gen_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_bq_example_gen_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #13
    def testBuildContainerTask(self):
        task = test_utils.DummyProducerComponent(
            output1=channel_utils.as_channel([standard_artifacts.Model()]),
            param1='value1',
        )
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=task,
            image='gcr.io/tensorflow/tfx:latest',  # Note this has no effect here.
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_container_spec_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_container_spec_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_container_spec_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #14
def _build_importer_task_spec(
    importer_base_name: str,
    artifact_uri: Union[_pipeline_param.PipelineParam, str],
) -> pipeline_spec_pb2.PipelineTaskSpec:
    """Builds an importer task spec.

  Args:
    importer_base_name: The base name of the importer node.
    artifact_uri: The artifact uri to import from.

  Returns:
    An importer node task spec.
  """
    result = pipeline_spec_pb2.PipelineTaskSpec()
    result.task_info.name = dsl_utils.sanitize_task_name(importer_base_name)
    result.component_ref.name = dsl_utils.sanitize_component_name(
        importer_base_name)

    if isinstance(artifact_uri, _pipeline_param.PipelineParam):
        result.inputs.parameters[
            INPUT_KEY].component_input_parameter = artifact_uri.full_name
    elif isinstance(artifact_uri, str):
        result.inputs.parameters[
            INPUT_KEY].runtime_value.constant_value.string_value = artifact_uri

    return result
    def test_build_importer_task_spec(self, importer_name, input_uri,
                                      expected_result):
        expected_task_spec = pb.PipelineTaskSpec()
        json_format.ParseDict(expected_result, expected_task_spec)

        task_spec = importer_node._build_importer_task_spec(
            importer_base_name=importer_name, artifact_uri=input_uri)

        self.assertEqual(expected_task_spec, task_spec)
    def test_update_task_inputs_spec(self, original_task_spec,
                                     parent_component_inputs,
                                     tasks_in_current_dag, expected_result):
        pipeline_params = self.TEST_PIPELINE_PARAMS

        expected_spec = pipeline_spec_pb2.PipelineTaskSpec()
        json_format.ParseDict(expected_result, expected_spec)

        task_spec = pipeline_spec_pb2.PipelineTaskSpec()
        json_format.ParseDict(original_task_spec, task_spec)
        parent_component_inputs_spec = pipeline_spec_pb2.ComponentInputsSpec()
        json_format.ParseDict(parent_component_inputs,
                              parent_component_inputs_spec)
        dsl_component_spec.update_task_inputs_spec(
            task_spec, parent_component_inputs_spec, pipeline_params,
            tasks_in_current_dag)

        self.assertEqual(expected_spec, task_spec)
  def test_build_importer_task(self):
    dependent_task = {
        'taskInfo': {
            'name': 'task1'
        },
        'inputs': {
            'artifacts': {
                'input1': {
                    'producerTask': '',
                }
            }
        },
        'executorLabel': 'task1_input1_importer'
    }
    dependent_task_spec = pb.PipelineTaskSpec()
    json_format.ParseDict(dependent_task, dependent_task_spec)

    expected_task = {
        'taskInfo': {
            'name': 'task1_input1_importer'
        },
        'outputs': {
            'artifacts': {
                'result': {
                    'artifactType': {
                        'instanceSchema': 'title: kfp.Artifact'
                    }
                }
            }
        },
        'executorLabel': 'task1_input1_importer'
    }
    expected_task_spec = pb.PipelineTaskSpec()
    json_format.ParseDict(expected_task, expected_task_spec)

    task_spec = importer_node.build_importer_task_spec(
        dependent_task=dependent_task_spec,
        input_name='input1',
        input_type_schema='title: kfp.Artifact')

    self.maxDiff = None
    self.assertEqual(expected_task_spec, task_spec)
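The `_build_importer_task_spec` helper at the top of this example wires the artifact URI differently depending on its type: a plain string ends up as a runtime constant, while a PipelineParam becomes a reference to the parent component's input. A hand-built sketch of the two resulting shapes; the field paths are copied from the helper, while INPUT_KEY and the concrete values are stand-ins rather than the module's real constants:

from kfp.pipeline_spec import pipeline_spec_pb2

INPUT_KEY = 'uri'  # stand-in for importer_node's INPUT_KEY constant

# Constant string URI -> runtime constant value.
const_spec = pipeline_spec_pb2.PipelineTaskSpec()
const_spec.inputs.parameters[
    INPUT_KEY].runtime_value.constant_value.string_value = 'gs://bucket/data'

# PipelineParam URI -> reference to the parent component's input parameter.
param_spec = pipeline_spec_pb2.PipelineTaskSpec()
param_spec.inputs.parameters[
    INPUT_KEY].component_input_parameter = 'uri_param'  # stands in for param.full_name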
Example #18
    def testBuildDummyConsumerWithCondition(self):
        producer_task_1 = test_utils.dummy_producer_component(
            output1=channel_utils.as_channel([standard_artifacts.Model()]),
            param1='value1',
        ).with_id('producer_task_1')
        producer_task_2 = test_utils.dummy_producer_component_2(
            output1=channel_utils.as_channel([standard_artifacts.Model()]),
            param1='value2',
        ).with_id('producer_task_2')
        # This test checks two things:
        # 1. Nested conditions. The condition string of consumer_task should contain
        #    both predicates.
        # 2. Implicit channels. consumer_task only takes producer_task_1's output,
        #    but producer_task_2 is used in the condition, so producer_task_2 should
        #    be added to the dependencies of consumer_task.
        # See testdata for detail.
        with conditional.Cond(
                producer_task_1.outputs['output1'].future()[0].uri != 'uri'):
            with conditional.Cond(producer_task_2.outputs['output1'].future()
                                  [0].property('property') == 'value1'):
                consumer_task = test_utils.dummy_consumer_component(
                    input1=producer_task_1.outputs['output1'],
                    param1=1,
                )
        # Need to construct a pipeline to set producer_component_id.
        unused_pipeline = tfx.dsl.Pipeline(
            pipeline_name='pipeline-with-condition',
            pipeline_root='',
            components=[producer_task_1, producer_task_2, consumer_task],
        )
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        my_builder = step_builder.StepBuilder(
            node=consumer_task,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(my_builder.build())
        actual_component_def = self._sole(component_defs)

        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_consumer_with_condition_component.pbtxt',
                pipeline_pb2.ComponentSpec()), actual_component_def)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_consumer_with_condition_task.pbtxt',
                pipeline_pb2.PipelineTaskSpec()), actual_step_spec)
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(
                'expected_dummy_consumer_with_condition_executor.pbtxt',
                pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
Example #19
  def _build_resolver_for_latest_blessed_model(
      self, model_channel_key: str, model_blessing_resolver_name: str,
      model_blessing_channel_key: str) -> pipeline_pb2.PipelineTaskSpec:
    """Builds the resolver spec for latest blessed Model artifact."""
    name = '{}{}'.format(self._name, _MODEL_RESOLVER_SUFFIX)

    # Component def.
    component_def = pipeline_pb2.ComponentSpec()
    executor_label = _EXECUTOR_LABEL_PATTERN.format(name)
    component_def.executor_label = executor_label
    input_artifact_spec = compiler_utils.build_input_artifact_spec(
        self._outputs[model_blessing_channel_key])
    component_def.input_definitions.artifacts[
        _MODEL_RESOLVER_INPUT_KEY].CopyFrom(input_artifact_spec)
    output_artifact_spec = compiler_utils.build_output_artifact_spec(
        self._outputs[model_channel_key])
    component_def.output_definitions.artifacts[model_channel_key].CopyFrom(
        output_artifact_spec)
    self._component_defs[name] = component_def

    # Task spec.
    task_spec = pipeline_pb2.PipelineTaskSpec()
    task_spec.task_info.name = name
    task_spec.component_ref.name = name
    input_artifact_spec = pipeline_pb2.TaskInputsSpec.InputArtifactSpec()
    input_artifact_spec.task_output_artifact.producer_task = model_blessing_resolver_name
    input_artifact_spec.task_output_artifact.output_artifact_key = model_blessing_channel_key
    task_spec.inputs.artifacts[_MODEL_RESOLVER_INPUT_KEY].CopyFrom(
        input_artifact_spec)

    # Resolver executor spec.
    executor = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
    artifact_queries = {}
    query_filter = (
        'schema_title="{type}" AND '
        'state={state} AND '
        'name="{{{{$.inputs.artifacts[\'{input_key}\']'
        '.metadata[\'{property_key}\']}}}}"').format(
            type=compiler_utils.get_artifact_title(standard_artifacts.Model),
            state=metadata_store_pb2.Artifact.State.Name(
                metadata_store_pb2.Artifact.LIVE),
            input_key=_MODEL_RESOLVER_INPUT_KEY,
            property_key=constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY)
    artifact_queries[model_channel_key] = ResolverSpec.ArtifactQuerySpec(
        filter=query_filter)
    executor.resolver.CopyFrom(
        ResolverSpec(output_artifact_queries=artifact_queries))
    self._deployment_config.executors[executor_label].CopyFrom(executor)

    return task_spec
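The quadruple braces in the filter template above are just str.format escapes; after formatting, a literal '{{...}}' placeholder remains for the driver to resolve at runtime. A quick sketch of how the template renders, using illustrative values in place of the real constants (the actual type title, input key, and property key come from compiler_utils and constants):

# Illustrative values only; the real ones come from compiler_utils.get_artifact_title,
# _MODEL_RESOLVER_INPUT_KEY, and constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY.
query_filter = (
    'schema_title="{type}" AND '
    'state={state} AND '
    'name="{{{{$.inputs.artifacts[\'{input_key}\']'
    '.metadata[\'{property_key}\']}}}}"').format(
        type='tfx.Model',
        state='LIVE',
        input_key='input',
        property_key='current_model_id')
print(query_filter)
# schema_title="tfx.Model" AND state=LIVE AND
#   name="{{$.inputs.artifacts['input'].metadata['current_model_id']}}"
# (printed as a single line; wrapped here for readability)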
Example #20
def build_importer_task_spec(
    importer_base_name: str, ) -> pipeline_spec_pb2.PipelineTaskSpec:
    """Builds an importer task spec.

  Args:
    importer_base_name: The base name of the importer node.

  Returns:
    An importer node task spec.
  """
    result = pipeline_spec_pb2.PipelineTaskSpec()
    result.task_info.name = dsl_utils.sanitize_task_name(importer_base_name)
    result.component_ref.name = dsl_utils.sanitize_component_name(
        importer_base_name)

    return result
    def test_build_importer_task_spec(self):
        expected_task = {
            'taskInfo': {
                'name': 'task-importer-task0-input1'
            },
            'componentRef': {
                'name': 'comp-importer-task0-input1'
            },
        }
        expected_task_spec = pb.PipelineTaskSpec()
        json_format.ParseDict(expected_task, expected_task_spec)

        task_spec = importer_node.build_importer_task_spec(
            importer_base_name='importer-task0-input1')

        self.maxDiff = None
        self.assertEqual(expected_task_spec, task_spec)
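The expected_task dict in this test pins down the naming convention used by the sanitize helpers: task names get a 'task-' prefix and component names a 'comp-' prefix. The stand-ins below are deliberately simplified hypotheticals for illustration; the real dsl_utils helpers also normalize characters and length:

# Hypothetical simplifications of dsl_utils.sanitize_task_name /
# sanitize_component_name, shown only to illustrate the prefixes asserted above.
def fake_sanitize_task_name(name: str) -> str:
    return 'task-' + name

def fake_sanitize_component_name(name: str) -> str:
    return 'comp-' + name

assert fake_sanitize_task_name('importer-task0-input1') == 'task-importer-task0-input1'
assert fake_sanitize_component_name('importer-task0-input1') == 'comp-importer-task0-input1'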
Example #22
  def _build_resolver_for_latest_model_blessing(
      self, model_blessing_channel_key: str) -> pipeline_pb2.PipelineTaskSpec:
    """Builds the resolver spec for latest valid ModelBlessing artifact."""
    name = '{}{}'.format(self._name, _MODEL_BLESSING_RESOLVER_SUFFIX)

    # Component def.
    component_def = pipeline_pb2.ComponentSpec()
    executor_label = _EXECUTOR_LABEL_PATTERN.format(name)
    component_def.executor_label = executor_label
    output_artifact_spec = compiler_utils.build_output_artifact_spec(
        self._outputs[model_blessing_channel_key])
    component_def.output_definitions.artifacts[
        model_blessing_channel_key].CopyFrom(output_artifact_spec)
    self._component_defs[name] = component_def

    # Task spec.
    task_spec = pipeline_pb2.PipelineTaskSpec()
    task_spec.task_info.name = name
    task_spec.component_ref.name = name

    # Builds the resolver executor spec for latest valid ModelBlessing.
    executor = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
    artifact_queries = {}
    query_filter = ('artifact_type="{type}" and state={state}'
                    ' and metadata.{key}.number_value={value}').format(
                        type=compiler_utils.get_artifact_title(
                            standard_artifacts.ModelBlessing),
                        state=metadata_store_pb2.Artifact.State.Name(
                            metadata_store_pb2.Artifact.LIVE),
                        key=constants.ARTIFACT_PROPERTY_BLESSED_KEY,
                        value=constants.BLESSED_VALUE)
    artifact_queries[
        model_blessing_channel_key] = ResolverSpec.ArtifactQuerySpec(
            filter=query_filter)
    executor.resolver.CopyFrom(
        ResolverSpec(output_artifact_queries=artifact_queries))
    self._deployment_config.executors[executor_label].CopyFrom(executor)

    return task_spec
Example #23
def build_task_spec_for_group(
    group: tasks_group.TasksGroup,
    pipeline_channels: List[pipeline_channel.PipelineChannel],
    tasks_in_current_dag: List[str],
    is_parent_component_root: bool,
) -> pipeline_spec_pb2.PipelineTaskSpec:
    """Builds PipelineTaskSpec for a group.

    Args:
        group: The group to build PipelineTaskSpec for.
        pipeline_channels: The list of pipeline channels referenced by the group.
        tasks_in_current_dag: The list of task names for tasks in the same DAG.
        is_parent_component_root: Whether the parent component is the pipeline's
            root dag.

    Returns:
        A PipelineTaskSpec object representing the group.
    """
    pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()
    pipeline_task_spec.task_info.name = group.display_name or group.name
    pipeline_task_spec.component_ref.name = (
        component_utils.sanitize_component_name(group.name))

    for channel in pipeline_channels:

        channel_full_name = channel.full_name
        subvar_name = None
        if isinstance(channel, for_loop.LoopArgumentVariable):
            channel_full_name = channel.loop_argument.full_name
            subvar_name = channel.subvar_name

        input_name = _additional_input_name_for_pipeline_channel(channel)

        channel_name = channel.name
        if subvar_name:
            pipeline_task_spec.inputs.parameters[
                input_name].parameter_expression_selector = (
                    'parseJson(string_value)["{}"]'.format(subvar_name))
            if not channel.is_with_items_loop_argument:
                channel_name = channel.items_or_pipeline_channel.name

        if isinstance(channel, pipeline_channel.PipelineArtifactChannel):
            if channel.task_name and channel.task_name in tasks_in_current_dag:
                pipeline_task_spec.inputs.artifacts[
                    input_name].task_output_artifact.producer_task = (
                        component_utils.sanitize_task_name(channel.task_name))
                pipeline_task_spec.inputs.artifacts[
                    input_name].task_output_artifact.output_artifact_key = (
                        channel_name)
            else:
                pipeline_task_spec.inputs.artifacts[
                    input_name].component_input_artifact = (
                        channel_full_name
                        if is_parent_component_root else input_name)
        else:
            # channel is one of PipelineParameterChannel, LoopArgument, or
            # LoopArgumentVariable
            if channel.task_name and channel.task_name in tasks_in_current_dag:
                pipeline_task_spec.inputs.parameters[
                    input_name].task_output_parameter.producer_task = (
                        component_utils.sanitize_task_name(channel.task_name))
                pipeline_task_spec.inputs.parameters[
                    input_name].task_output_parameter.output_parameter_key = (
                        channel_name)
            else:
                pipeline_task_spec.inputs.parameters[
                    input_name].component_input_parameter = (
                        channel_full_name if is_parent_component_root else
                        _additional_input_name_for_pipeline_channel(
                            channel_full_name))

    if isinstance(group, tasks_group.ParallelFor):
        _update_task_spec_for_loop_group(
            group=group,
            pipeline_task_spec=pipeline_task_spec,
        )
    elif isinstance(group, tasks_group.Condition):
        _update_task_spec_for_condition_group(
            group=group,
            pipeline_task_spec=pipeline_task_spec,
        )

    return pipeline_task_spec
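For orientation, the spec this function produces for a condition group is small: a task_info/component_ref pair plus one wired input per referenced channel. A hand-built approximation under stated assumptions (the 'comp-' prefix and the 'pipelinechannel--' input prefix follow the conventions visible elsewhere in these examples; all concrete names are illustrative, not output of this function):

from kfp.pipeline_spec import pipeline_spec_pb2

spec = pipeline_spec_pb2.PipelineTaskSpec()
spec.task_info.name = 'condition-1'            # group name (illustrative)
spec.component_ref.name = 'comp-condition-1'   # sanitized component name

# A parameter produced by 'op-1' inside the same DAG, wired as a task output.
param = spec.inputs.parameters['pipelinechannel--op-1-output']
param.task_output_parameter.producer_task = 'task-op-1'
param.task_output_parameter.output_parameter_key = 'output'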
Example #24
def build_task_spec_for_task(
    task: pipeline_task.PipelineTask,
    parent_component_inputs: pipeline_spec_pb2.ComponentInputsSpec,
    tasks_in_current_dag: List[str],
    input_parameters_in_current_dag: List[str],
    input_artifacts_in_current_dag: List[str],
) -> pipeline_spec_pb2.PipelineTaskSpec:
    """Builds PipelineTaskSpec for a pipeline task.

    A task input may reference an output outside its immediate DAG.
    For instance::

        random_num = random_num_op(...)
        with dsl.Condition(random_num.output > 5):
            print_op('%s > 5' % random_num.output)

    In this example, `dsl.Condition` forms a sub-DAG with one task created by
    `print_op` inside it. The `print_op` task references output from the
    `random_num` task, which lives outside the sub-DAG. When compiling to IR,
    such cross-DAG references are disallowed. So we "punch a hole" in the
    sub-DAG to make the input available in the sub-DAG component inputs if it is
    not already there. Then this method is called to fix the tasks inside the
    sub-DAG so that they reference the component inputs instead of directly
    referencing the original producer task.

    Args:
        task: The task to build a PipelineTaskSpec for.
        parent_component_inputs: The task's parent component's input specs.
        tasks_in_current_dag: The list of task names for tasks in the same DAG.
        input_parameters_in_current_dag: The list of input parameters in the DAG
            component.
        input_artifacts_in_current_dag: The list of input artifacts in the DAG
            component.

    Returns:
        A PipelineTaskSpec object representing the task.
    """
    pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()
    pipeline_task_spec.task_info.name = (task.task_spec.display_name
                                         or task.name)
    # Use task.name for component_ref.name because we may customize component
    # spec for individual tasks to work around the lack of optional inputs
    # support in IR.
    pipeline_task_spec.component_ref.name = (
        component_utils.sanitize_component_name(task.name))
    pipeline_task_spec.caching_options.enable_cache = (
        task.task_spec.enable_caching)

    for input_name, input_value in task.inputs.items():
        if isinstance(input_value, pipeline_channel.PipelineArtifactChannel):

            if input_value.task_name:
                # Value is produced by an upstream task.
                if input_value.task_name in tasks_in_current_dag:
                    # Dependent task within the same DAG.
                    pipeline_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.producer_task = (
                            component_utils.sanitize_task_name(
                                input_value.task_name))
                    pipeline_task_spec.inputs.artifacts[
                        input_name].task_output_artifact.output_artifact_key = (
                            input_value.name)
                else:
                    # Dependent task not from the same DAG.
                    component_input_artifact = (
                        _additional_input_name_for_pipeline_channel(
                            input_value))
                    assert component_input_artifact in parent_component_inputs.artifacts, \
                        'component_input_artifact: {} not found. All inputs: {}'.format(
                            component_input_artifact, parent_component_inputs)
                    pipeline_task_spec.inputs.artifacts[
                        input_name].component_input_artifact = (
                            component_input_artifact)
            else:
                raise RuntimeError(
                    f'Artifacts must be produced by a task. Got {input_value}.'
                )

        elif isinstance(input_value,
                        pipeline_channel.PipelineParameterChannel):

            if input_value.task_name:
                # Value is produced by an upstream task.
                if input_value.task_name in tasks_in_current_dag:
                    # Dependent task within the same DAG.
                    pipeline_task_spec.inputs.parameters[
                        input_name].task_output_parameter.producer_task = (
                            component_utils.sanitize_task_name(
                                input_value.task_name))
                    pipeline_task_spec.inputs.parameters[
                        input_name].task_output_parameter.output_parameter_key = (
                            input_value.name)
                else:
                    # Dependent task not from the same DAG.
                    component_input_parameter = (
                        _additional_input_name_for_pipeline_channel(
                            input_value))
                    assert component_input_parameter in parent_component_inputs.parameters, \
                        'component_input_parameter: {} not found. All inputs: {}'.format(
                            component_input_parameter, parent_component_inputs)
                    pipeline_task_spec.inputs.parameters[
                        input_name].component_input_parameter = (
                            component_input_parameter)
            else:
                # Value is from pipeline input.
                component_input_parameter = input_value.full_name
                if component_input_parameter not in parent_component_inputs.parameters:
                    component_input_parameter = (
                        _additional_input_name_for_pipeline_channel(
                            input_value))
                pipeline_task_spec.inputs.parameters[
                    input_name].component_input_parameter = (
                        component_input_parameter)

        elif isinstance(input_value, for_loop.LoopArgument):

            component_input_parameter = (
                _additional_input_name_for_pipeline_channel(input_value))
            assert component_input_parameter in parent_component_inputs.parameters, \
                'component_input_parameter: {} not found. All inputs: {}'.format(
                    component_input_parameter, parent_component_inputs)
            pipeline_task_spec.inputs.parameters[
                input_name].component_input_parameter = (
                    component_input_parameter)

        elif isinstance(input_value, for_loop.LoopArgumentVariable):

            component_input_parameter = (
                _additional_input_name_for_pipeline_channel(
                    input_value.loop_argument))
            assert component_input_parameter in parent_component_inputs.parameters, \
                'component_input_parameter: {} not found. All inputs: {}'.format(
                    component_input_parameter, parent_component_inputs)
            pipeline_task_spec.inputs.parameters[
                input_name].component_input_parameter = (
                    component_input_parameter)
            pipeline_task_spec.inputs.parameters[
                input_name].parameter_expression_selector = (
                    'parseJson(string_value)["{}"]'.format(
                        input_value.subvar_name))

        elif isinstance(input_value, str):

            # Handle extra input due to string concat
            pipeline_channels = (
                pipeline_channel.extract_pipeline_channels_from_any(
                    input_value))
            for channel in pipeline_channels:
                # value contains PipelineChannel placeholders which needs to be
                # replaced. And the input needs to be added to the task spec.

                # Form the name for the compiler injected input, and make sure it
                # doesn't collide with any existing input names.
                additional_input_name = (
                    _additional_input_name_for_pipeline_channel(channel))

                # We don't expect a collision to happen because we prefix the
                # name of the additional input with 'pipelinechannel--'. But just
                # in case a collision does happen, throw a RuntimeError so that
                # we don't get surprised at runtime.
                for existing_input_name, _ in task.inputs.items():
                    if existing_input_name == additional_input_name:
                        raise RuntimeError(
                            'Name collision between existing input name '
                            '{} and compiler injected input name {}'.format(
                                existing_input_name, additional_input_name))

                additional_input_placeholder = (
                    placeholders.input_parameter_placeholder(
                        additional_input_name))
                input_value = input_value.replace(
                    channel.pattern, additional_input_placeholder)

                if channel.task_name:
                    # Value is produced by an upstream task.
                    if channel.task_name in tasks_in_current_dag:
                        # Dependent task within the same DAG.
                        pipeline_task_spec.inputs.parameters[
                            additional_input_name].task_output_parameter.producer_task = (
                                component_utils.sanitize_task_name(
                                    channel.task_name))
                        pipeline_task_spec.inputs.parameters[
                            input_name].task_output_parameter.output_parameter_key = (
                                channel.name)
                    else:
                        # Dependent task not from the same DAG.
                        component_input_parameter = (
                            _additional_input_name_for_pipeline_channel(
                                channel))
                        assert component_input_parameter in parent_component_inputs.parameters, \
                            'component_input_parameter: {} not found. All inputs: {}'.format(
                                component_input_parameter, parent_component_inputs)
                        pipeline_task_spec.inputs.parameters[
                            additional_input_name].component_input_parameter = (
                                component_input_parameter)
                else:
                    # Value is from pipeline input. (or loop?)
                    component_input_parameter = channel.full_name
                    if component_input_parameter not in parent_component_inputs.parameters:
                        component_input_parameter = (
                            _additional_input_name_for_pipeline_channel(
                                channel))
                    pipeline_task_spec.inputs.parameters[
                        additional_input_name].component_input_parameter = (
                            component_input_parameter)

            pipeline_task_spec.inputs.parameters[
                input_name].runtime_value.constant.string_value = input_value

        elif isinstance(input_value, (str, int, float, bool, dict, list)):

            pipeline_task_spec.inputs.parameters[
                input_name].runtime_value.constant.CopyFrom(
                    _to_protobuf_value(input_value))

        else:
            raise ValueError(
                'Input argument supports only the following types: '
                'str, int, float, bool, dict, and list. '
                f'Got {input_value} of type {type(input_value)}.')

    return pipeline_task_spec
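The string-concat branch above relies on compiler-injected input names being prefixed so they cannot collide with user-declared inputs; the comments in the code name the prefix 'pipelinechannel--'. A tiny stand-in for _additional_input_name_for_pipeline_channel, purely to illustrate that convention (the real helper works on channel objects, not bare strings):

# Hypothetical simplification; the real helper derives the suffix from the
# channel's full name rather than taking a plain string.
def injected_input_name(channel_full_name: str) -> str:
    return 'pipelinechannel--' + channel_full_name

print(injected_input_name('op-1-output1'))  # pipelinechannel--op-1-output1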
Example #25
def _get_custom_job_op(
    task_name: str,
    job_spec: Dict[str, Any],
    input_artifacts: Optional[Dict[str, dsl.PipelineParam]] = None,
    input_parameters: Optional[Dict[str, _ValueOrPipelineParam]] = None,
    output_artifacts: Optional[Dict[str, Type[artifact.Artifact]]] = None,
    output_parameters: Optional[Dict[str, Any]] = None,
) -> AiPlatformCustomJobOp:
    """Gets an AiPlatformCustomJobOp from job spec and I/O definition."""
    pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()
    pipeline_component_spec = pipeline_spec_pb2.ComponentSpec()

    pipeline_task_spec.task_info.CopyFrom(
        pipeline_spec_pb2.PipelineTaskInfo(name=task_name))

    # Iterate through the inputs/outputs declaration to get pipeline component
    # spec.
    for input_name, param in input_parameters.items():
        if isinstance(param, dsl.PipelineParam):
            pipeline_component_spec.input_definitions.parameters[
                input_name].type = type_utils.get_parameter_type(
                    param.param_type)
        else:
            pipeline_component_spec.input_definitions.parameters[
                input_name].type = type_utils.get_parameter_type(type(param))

    for input_name, art in input_artifacts.items():
        if not isinstance(art, dsl.PipelineParam):
            raise RuntimeError(
                'Get unresolved input artifact for input %s. Input '
                'artifacts must be connected to a producer task.' % input_name)
        pipeline_component_spec.input_definitions.artifacts[
            input_name].artifact_type.CopyFrom(
                type_utils.get_artifact_type_schema_message(art.param_type))

    for output_name, param_type in output_parameters.items():
        pipeline_component_spec.output_definitions.parameters[
            output_name].type = type_utils.get_parameter_type(param_type)

    for output_name, artifact_type in output_artifacts.items():
        pipeline_component_spec.output_definitions.artifacts[
            output_name].artifact_type.CopyFrom(artifact_type.get_ir_type())

    pipeline_component_spec.executor_label = dsl_utils.sanitize_executor_label(
        task_name)

    # Iterate through the inputs/outputs specs to get pipeline task spec.
    for input_name, param in input_parameters.items():
        if isinstance(param, dsl.PipelineParam) and param.op_name:
            # If the param has a valid op_name, this should be a pipeline parameter
            # produced by an upstream task.
            pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
                    task_output_parameter=pipeline_spec_pb2.TaskInputsSpec.
                    InputParameterSpec.TaskOutputParameterSpec(
                        producer_task='task-{}'.format(param.op_name),
                        output_parameter_key=param.name)))
        elif isinstance(param, dsl.PipelineParam) and not param.op_name:
            # If a valid op_name is missing, this should be a pipeline parameter.
            pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
                    component_input_parameter=param.name))
        else:
            # If this is not a pipeline param, then it should be a value.
            pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
                    runtime_value=pipeline_spec_pb2.ValueOrRuntimeParameter(
                        constant_value=dsl_utils.get_value(param))))

    for input_name, art in input_artifacts.items():
        if art.op_name:
            # If the artifact has a valid op_name, it should have been produced
            # by an upstream task.
            pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec(
                    task_output_artifact=pipeline_spec_pb2.TaskInputsSpec.
                    InputArtifactSpec.TaskOutputArtifactSpec(
                        producer_task='task-{}'.format(art.op_name),
                        output_artifact_key=art.name)))
        else:
            # Otherwise, this should be from the input of the subdag.
            pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(
                pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec(
                    component_input_artifact=art.name))

    # TODO: Add task dependencies/trigger policies/caching/iterator
    pipeline_task_spec.component_ref.name = dsl_utils.sanitize_component_name(
        task_name)

    # Construct dummy I/O declaration for the op.
    # TODO: resolve name conflict instead of raising errors.
    dummy_outputs = collections.OrderedDict()
    for output_name, _ in output_artifacts.items():
        dummy_outputs[output_name] = _DUMMY_PATH

    for output_name, _ in output_parameters.items():
        if output_name in dummy_outputs:
            raise KeyError(
                'Got name collision for output key %s. Consider renaming '
                'either output parameters or output '
                'artifacts.' % output_name)
        dummy_outputs[output_name] = _DUMMY_PATH

    dummy_inputs = collections.OrderedDict()
    for input_name, art in input_artifacts.items():
        dummy_inputs[input_name] = _DUMMY_PATH
    for input_name, param in input_parameters.items():
        if input_name in dummy_inputs:
            raise KeyError(
                'Got name collision for input key %s. Consider renaming '
                'either input parameters or input '
                'artifacts.' % input_name)
        dummy_inputs[input_name] = _DUMMY_PATH

    # Construct the AIP (Unified) custom job op.
    return AiPlatformCustomJobOp(
        name=task_name,
        custom_job_spec=job_spec,
        component_spec=pipeline_component_spec,
        task_spec=pipeline_task_spec,
        task_inputs=[
            dsl.InputArgumentPath(
                argument=dummy_inputs[input_name],
                input=input_name,
                path=path,
            ) for input_name, path in dummy_inputs.items()
        ],
        task_outputs=dummy_outputs)
Example #26
  def _build_latest_artifact_resolver(
      self) -> Dict[str, pipeline_pb2.PipelineTaskSpec]:
    """Builds a resolver spec for a latest artifact resolver.

    Returns:
      A dict with a single entry, mapping the resolver node's name to its
      PipelineTaskSpec, which encodes a latest-artifact query for each input
      channel.
    Raises:
      ValueError: When desired_num_of_artifacts != 1; 1 is the only supported
        value currently.
    """
    # Fetch the init kwargs for the resolver.
    resolver_config = self._exec_properties[resolver.RESOLVER_CONFIG]
    if (isinstance(resolver_config, dict) and
        resolver_config.get('desired_num_of_artifacts', 0) > 1):
      raise ValueError('Only desired_num_of_artifacts=1 is supported currently.'
                       ' Got {}'.format(
                           resolver_config.get('desired_num_of_artifacts')))

    component_def = pipeline_pb2.ComponentSpec()
    executor_label = _EXECUTOR_LABEL_PATTERN.format(self._name)
    component_def.executor_label = executor_label
    task_spec = pipeline_pb2.PipelineTaskSpec()
    task_spec.task_info.name = self._name

    for name, output_channel in self._outputs.items():
      output_artifact_spec = compiler_utils.build_output_artifact_spec(
          output_channel)
      component_def.output_definitions.artifacts[name].CopyFrom(
          output_artifact_spec)
    for name, value in self._exec_properties.items():
      if value is None:
        continue
      parameter_type_spec = compiler_utils.build_parameter_type_spec(value)
      component_def.input_definitions.parameters[name].CopyFrom(
          parameter_type_spec)
      if isinstance(value, data_types.RuntimeParameter):
        parameter_utils.attach_parameter(value)
        task_spec.inputs.parameters[name].component_input_parameter = value.name
      else:
        task_spec.inputs.parameters[name].CopyFrom(
            pipeline_pb2.TaskInputsSpec.InputParameterSpec(
                runtime_value=compiler_utils.value_converter(value)))
    self._component_defs[self._name] = component_def
    task_spec.component_ref.name = self._name

    artifact_queries = {}
    # Build the artifact query for each channel in the input dict.
    for name, c in self._inputs.items():
      query_filter = ('artifact_type="{type}" and state={state}').format(
          type=compiler_utils.get_artifact_title(c.type),
          state=metadata_store_pb2.Artifact.State.Name(
              metadata_store_pb2.Artifact.LIVE))
      # Resolver's output dict has the same set of keys as its input dict.
      artifact_queries[name] = ResolverSpec.ArtifactQuerySpec(
          filter=query_filter)

    resolver_spec = ResolverSpec(output_artifact_queries=artifact_queries)
    executor = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
    executor.resolver.CopyFrom(resolver_spec)
    self._deployment_config.executors[executor_label].CopyFrom(executor)
    return {self._name: task_spec}
Example #27
  def build(self) -> Dict[str, pipeline_pb2.PipelineTaskSpec]:
    """Builds a pipeline PipelineTaskSpec given the node information.

    Each TFX node maps to one task spec and, usually, one component definition
    and one executor spec (resolver nodes are an exception; see the explanation
    in the Returns section).

     - Component definition includes the interfaces of a node, for example, name
    and type information of inputs/outputs/execution_properties.
     - Task spec contains the topology around the node, for example, the
    dependency nodes and where to read the inputs and exec_properties (from
    another task, from the parent component, or from a constant value). The task
    spec carries the name of the component definition it references; a task spec
    may reference an existing component definition that was built previously.
     - Executor spec encodes how the node is actually executed, for example,
    args to start a container, or query strings for resolvers. All executor
    specs will be packed into the deployment config proto.

    During the build, all three parts mentioned above will be updated.

    Returns:
      A Dict mapping from node id to the PipelineTaskSpec messages
      corresponding to the node. In most cases the dict contains a single
      element. The only exception is when compiling the latest blessed model
      resolver, where one DSL node is split into two resolver specs to reflect
      the two-phased query execution.

    Raises:
      NotImplementedError: When the node being built is an InfraValidator.
    """
    # 1. Resolver tasks won't have input artifacts in the API proto. First we
    #    special-case the two resolver types we support.
    if isinstance(self._node, resolver.Resolver):
      return self._build_resolver_spec()

    # 2. Build component spec.
    component_def = pipeline_pb2.ComponentSpec()
    executor_label = _EXECUTOR_LABEL_PATTERN.format(self._name)
    component_def.executor_label = executor_label
    # Inputs
    for name, input_channel in self._inputs.items():
      input_artifact_spec = compiler_utils.build_input_artifact_spec(
          input_channel)
      component_def.input_definitions.artifacts[name].CopyFrom(
          input_artifact_spec)
    # Outputs
    for name, output_channel in self._outputs.items():
      # Currently, we're working under the assumption that for tasks
      # (those generated by BaseComponent), each channel contains a single
      # artifact.
      output_artifact_spec = compiler_utils.build_output_artifact_spec(
          output_channel)
      component_def.output_definitions.artifacts[name].CopyFrom(
          output_artifact_spec)
    # Exec properties
    for name, value in self._exec_properties.items():
      # value can be None for unprovided optional exec properties.
      if value is None:
        continue
      parameter_type_spec = compiler_utils.build_parameter_type_spec(value)
      component_def.input_definitions.parameters[name].CopyFrom(
          parameter_type_spec)
    if self._name not in self._component_defs:
      self._component_defs[self._name] = component_def
    else:
      raise ValueError(f'Found duplicate component ids {self._name} while '
                       'building component definitions.')

    # 3. Build task spec.
    task_spec = pipeline_pb2.PipelineTaskSpec()
    task_spec.task_info.name = self._name
    dependency_ids = [node.id for node in self._node.upstream_nodes]
    for name, input_channel in self._inputs.items():
      # If the redirecting map is provided (usually for the latest blessed
      # model resolver), we'll need to redirect accordingly. Also, the upstream
      # node list will be updated and replaced by the new producer id.
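      # Hypothetical redirect entry (not from the original code):
      #   {('my_resolver', 'model'): ('my_resolver-model-resolver', 'model')}
      # i.e. reads of ('my_resolver', 'model') are rerouted to the generated
      # model-resolver task, and the dependency list is updated to match.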
      producer_id = input_channel.producer_component_id
      output_key = input_channel.output_key
      for k, v in self._channel_redirect_map.items():
        if k[0] == producer_id and producer_id in dependency_ids:
          dependency_ids.remove(producer_id)
          dependency_ids.append(v[0])
      producer_id = self._channel_redirect_map.get((producer_id, output_key),
                                                   (producer_id, output_key))[0]
      output_key = self._channel_redirect_map.get((producer_id, output_key),
                                                  (producer_id, output_key))[1]
      input_artifact_spec = pipeline_pb2.TaskInputsSpec.InputArtifactSpec()
      input_artifact_spec.task_output_artifact.producer_task = producer_id
      input_artifact_spec.task_output_artifact.output_artifact_key = output_key
      task_spec.inputs.artifacts[name].CopyFrom(input_artifact_spec)
    for name, value in self._exec_properties.items():
      if value is None:
        continue
      if isinstance(value, data_types.RuntimeParameter):
        parameter_utils.attach_parameter(value)
        task_spec.inputs.parameters[name].component_input_parameter = value.name
      else:
        task_spec.inputs.parameters[name].CopyFrom(
            pipeline_pb2.TaskInputsSpec.InputParameterSpec(
                runtime_value=compiler_utils.value_converter(value)))

    task_spec.component_ref.name = self._name

    dependency_ids = sorted(dependency_ids)
    for dependency in dependency_ids:
      task_spec.dependent_tasks.append(dependency)

    if self._enable_cache:
      task_spec.caching_options.CopyFrom(
          pipeline_pb2.PipelineTaskSpec.CachingOptions(
              enable_cache=self._enable_cache))

    # 4. Build the executor body for other common tasks.
    executor = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
    if isinstance(self._node, importer.Importer):
      executor.importer.CopyFrom(self._build_importer_spec())
    elif isinstance(self._node, components.FileBasedExampleGen):
      executor.container.CopyFrom(self._build_file_based_example_gen_spec())
    elif isinstance(self._node, components.InfraValidator):
      raise NotImplementedError(
          'The component type "{}" is not supported'.format(type(self._node)))
    else:
      executor.container.CopyFrom(self._build_container_spec())
    self._deployment_config.executors[executor_label].CopyFrom(executor)

    return {self._name: task_spec}
Beispiel #28
0
def _attach_v2_specs(
    task: _container_op.ContainerOp,
    component_spec: _structures.ComponentSpec,
    arguments: Mapping[str, Any],
) -> None:
  """Attaches v2 specs to a ContainerOp object.

  Attaches v2_specs to the ContainerOp object regardless of whether the
  pipeline is being compiled to v1 (Argo yaml) or v2 (IR json).
  However, the two cases behave differently, for example in how commands and
  arguments are resolved and how errors are handled.
  Regarding the difference in error handling, v2 has a stricter requirement on
  input type annotations. For instance, an input without any type annotation is
  viewed as an artifact, and if it's paired with InputValuePlaceholder, an
  error will be thrown at compile time. However, we cannot raise such an error
  in v1, as doing so would break existing pipelines.

  Args:
    task: The ContainerOp object to attach IR specs.
    component_spec: The component spec object.
    arguments: The dictionary of component arguments.
  """

  def _resolve_commands_and_args_v2(
      component_spec: _structures.ComponentSpec,
      arguments: Mapping[str, Any],
  ) -> _components._ResolvedCommandLineAndPaths:
    """Resolves the command line argument placeholders for v2 (IR).

    Args:
      component_spec: The component spec object.
      arguments: The dictionary of component arguments.

    Returns:
      A named tuple: _components._ResolvedCommandLineAndPaths.
    """
    inputs_dict = {
        input_spec.name: input_spec
        for input_spec in component_spec.inputs or []
    }
    outputs_dict = {
        output_spec.name: output_spec
        for output_spec in component_spec.outputs or []
    }

    def _input_artifact_uri_placeholder(input_key: str) -> str:
      if kfp.COMPILING_FOR_V2 and type_utils.is_parameter_type(
          inputs_dict[input_key].type):
        raise TypeError('Input "{}" with type "{}" cannot be paired with '
                        'InputUriPlaceholder.'.format(
                            input_key, inputs_dict[input_key].type))
      else:
        return _generate_input_uri_placeholder(input_key)

    def _input_artifact_path_placeholder(input_key: str) -> str:
      if kfp.COMPILING_FOR_V2 and type_utils.is_parameter_type(
          inputs_dict[input_key].type):
        raise TypeError('Input "{}" with type "{}" cannot be paired with '
                        'InputPathPlaceholder.'.format(
                            input_key, inputs_dict[input_key].type))
      else:
        return "{{{{$.inputs.artifacts['{}'].path}}}}".format(input_key)

    def _input_parameter_placeholder(input_key: str) -> str:
      if kfp.COMPILING_FOR_V2 and not type_utils.is_parameter_type(
          inputs_dict[input_key].type):
        raise TypeError('Input "{}" with type "{}" cannot be paired with '
                        'InputValuePlaceholder.'.format(
                            input_key, inputs_dict[input_key].type))
      else:
        return "{{{{$.inputs.parameters['{}']}}}}".format(input_key)

    def _output_artifact_uri_placeholder(output_key: str) -> str:
      if kfp.COMPILING_FOR_V2 and type_utils.is_parameter_type(
          outputs_dict[output_key].type):
        raise TypeError('Output "{}" with type "{}" cannot be paired with '
                        'OutputUriPlaceholder.'.format(
                            output_key, outputs_dict[output_key].type))
      else:
        return _generate_output_uri_placeholder(output_key)

    def _output_artifact_path_placeholder(output_key: str) -> str:
      return "{{{{$.outputs.artifacts['{}'].path}}}}".format(output_key)

    def _output_parameter_path_placeholder(output_key: str) -> str:
      return "{{{{$.outputs.parameters['{}'].output_file}}}}".format(output_key)

    def _resolve_output_path_placeholder(output_key: str) -> str:
      if type_utils.is_parameter_type(outputs_dict[output_key].type):
        return _output_parameter_path_placeholder(output_key)
      else:
        return _output_artifact_path_placeholder(output_key)
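    # For illustration (hypothetical keys, not from the original snippet), the
    # helpers above expand to IR placeholder strings such as:
    #   _input_parameter_placeholder('text')      -> "{{$.inputs.parameters['text']}}"
    #   _output_artifact_path_placeholder('out')  -> "{{$.outputs.artifacts['out'].path}}"
    #   _output_parameter_path_placeholder('out') -> "{{$.outputs.parameters['out'].output_file}}"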

    placeholder_resolver = ExtraPlaceholderResolver()
    def _resolve_ir_placeholders_v2(
        arg,
        component_spec: _structures.ComponentSpec,
        arguments: dict,
    ) -> str:
      inputs_dict = {input_spec.name: input_spec for input_spec in component_spec.inputs or []}
      if isinstance(arg, _structures.InputValuePlaceholder):
        input_name = arg.input_name
        input_value = arguments.get(input_name, None)
        if input_value is not None:
          return _input_parameter_placeholder(input_name)
        else:
          input_spec = inputs_dict[input_name]
          if input_spec.optional:
            return None
          else:
            raise ValueError('No value provided for input {}'.format(input_name))

      elif isinstance(arg, _structures.InputUriPlaceholder):
        input_name = arg.input_name
        if input_name in arguments:
          input_uri = _input_artifact_uri_placeholder(input_name)
          return input_uri
        else:
          input_spec = inputs_dict[input_name]
          if input_spec.optional:
            return None
          else:
            raise ValueError('No value provided for input {}'.format(input_name))

      elif isinstance(arg, _structures.OutputUriPlaceholder):
        output_name = arg.output_name
        output_uri = _output_artifact_uri_placeholder(output_name)
        return output_uri

      return placeholder_resolver.resolve_placeholder(
        arg=arg,
        component_spec=component_spec,
        arguments=arguments,
      )

    resolved_cmd = _components._resolve_command_line_and_paths(
        component_spec=component_spec,
        arguments=arguments,
        input_path_generator=_input_artifact_path_placeholder,
        output_path_generator=_resolve_output_path_placeholder,
        placeholder_resolver=_resolve_ir_placeholders_v2,
    )
    return resolved_cmd

  pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()

  # Check types of the reference arguments and serialize PipelineParams
  arguments = arguments.copy()

  # Preserve input params for ContainerOp.inputs
  input_params_set = set([
      param for param in arguments.values()
      if isinstance(param, _pipeline_param.PipelineParam)
  ])

  for input_name, argument_value in arguments.items():
    input_type = component_spec._inputs_dict[input_name].type
    argument_type = None

    if isinstance(argument_value, _pipeline_param.PipelineParam):
      argument_type = argument_value.param_type

      types.verify_type_compatibility(
          argument_type, input_type,
          'Incompatible argument passed to the input "{}" of component "{}": '
          .format(input_name, component_spec.name))

      # Loop arguments default to the 'String' type if the type is unknown.
      # This has to be done after the type compatibility check.
      if argument_type is None and isinstance(
          argument_value, (_for_loop.LoopArguments,
                           _for_loop.LoopArgumentVariable)):
        argument_type = 'String'

      arguments[input_name] = str(argument_value)

      if type_utils.is_parameter_type(input_type):
        if argument_value.op_name:
          pipeline_task_spec.inputs.parameters[
              input_name].task_output_parameter.producer_task = (
                  dsl_utils.sanitize_task_name(argument_value.op_name))
          pipeline_task_spec.inputs.parameters[
              input_name].task_output_parameter.output_parameter_key = (
                  argument_value.name)
        else:
          pipeline_task_spec.inputs.parameters[
              input_name].component_input_parameter = argument_value.name
      else:
        if argument_value.op_name:
          pipeline_task_spec.inputs.artifacts[
              input_name].task_output_artifact.producer_task = (
                  dsl_utils.sanitize_task_name(argument_value.op_name))
          pipeline_task_spec.inputs.artifacts[
              input_name].task_output_artifact.output_artifact_key = (
                  argument_value.name)
    elif isinstance(argument_value, str):
      argument_type = 'String'
      pipeline_params = _pipeline_param.extract_pipelineparams_from_any(
          argument_value)
      if pipeline_params and kfp.COMPILING_FOR_V2:
        # argument_value contains PipelineParam placeholders which need to be
        # replaced, and the corresponding inputs need to be added to the task spec.
        for param in pipeline_params:
          # Form the name for the compiler injected input, and make sure it
          # doesn't collide with any existing input names.
          additional_input_name = (
              dsl_component_spec.additional_input_name_for_pipelineparam(param))
          for existing_input_name, _ in arguments.items():
            if existing_input_name == additional_input_name:
              raise ValueError('Name collision between existing input name '
                               '{} and compiler injected input name {}'.format(
                                   existing_input_name, additional_input_name))

          # Add the additional param to the input params set; otherwise it
          # would not be included when the params set is not empty.
          input_params_set.add(param)
          additional_input_placeholder = (
              "{{{{$.inputs.parameters['{}']}}}}".format(additional_input_name))
          argument_value = argument_value.replace(param.pattern,
                                                  additional_input_placeholder)

          # The output references are subject to change -- the producer task may
          # not be within the same DAG.
          if param.op_name:
            pipeline_task_spec.inputs.parameters[
                additional_input_name].task_output_parameter.producer_task = (
                    dsl_utils.sanitize_task_name(param.op_name))
            pipeline_task_spec.inputs.parameters[
                additional_input_name].task_output_parameter.output_parameter_key = param.name
          else:
            pipeline_task_spec.inputs.parameters[
                additional_input_name].component_input_parameter = param.full_name
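          # Hypothetical example (not in the original snippet): a placeholder
          # referencing output 'msg' of task 'op-1' embedded in a string
          # argument gets a compiler-injected input (name chosen by
          # additional_input_name_for_pipelineparam), and the matching text is
          # rewritten to "{{$.inputs.parameters['<injected name>']}}".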

      input_type = component_spec._inputs_dict[input_name].type
      if type_utils.is_parameter_type(input_type):
        pipeline_task_spec.inputs.parameters[
            input_name].runtime_value.constant_value.string_value = (
                argument_value)
    elif isinstance(argument_value, int):
      argument_type = 'Integer'
      pipeline_task_spec.inputs.parameters[
          input_name].runtime_value.constant_value.int_value = argument_value
    elif isinstance(argument_value, float):
      argument_type = 'Float'
      pipeline_task_spec.inputs.parameters[
          input_name].runtime_value.constant_value.double_value = argument_value
    elif isinstance(argument_value, _container_op.ContainerOp):
      raise TypeError(
          'ContainerOp object {} was passed to component as an input argument. '
          'Pass a single output instead.'.format(input_name))
    else:
      if kfp.COMPILING_FOR_V2:
        raise NotImplementedError(
            'Input argument supports only the following types: PipelineParam'
            ', str, int, float. Got: "{}".'.format(argument_value))

    argument_is_parameter_type = type_utils.is_parameter_type(argument_type)
    input_is_parameter_type = type_utils.is_parameter_type(input_type)
    if kfp.COMPILING_FOR_V2 and (argument_is_parameter_type !=
                                input_is_parameter_type):
      if isinstance(argument_value, dsl.PipelineParam):
        param_or_value_msg = 'PipelineParam "{}"'.format(
            argument_value.full_name)
      else:
        param_or_value_msg = 'value "{}"'.format(argument_value)

      raise TypeError(
          'Passing '
          '{param_or_value} with type "{arg_type}" (as "{arg_category}") to '
          'component input '
          '"{input_name}" with type "{input_type}" (as "{input_category}") is '
          'incompatible. Please fix the type of the component input.'.format(
              param_or_value=param_or_value_msg,
              arg_type=argument_type,
              arg_category='Parameter'
              if argument_is_parameter_type else 'Artifact',
              input_name=input_name,
              input_type=input_type,
              input_category='Parameter'
              if input_is_parameter_type else 'Artifact',
          ))

  if not component_spec.name:
    component_spec.name = _components._default_component_name

  # task.name is unique at this point.
  pipeline_task_spec.task_info.name = (dsl_utils.sanitize_task_name(task.name))

  resolved_cmd = _resolve_commands_and_args_v2(
      component_spec=component_spec, arguments=arguments)

  task.container_spec = (
      pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec(
          image=component_spec.implementation.container.image,
          command=resolved_cmd.command,
          args=resolved_cmd.args))

  # TODO(chensun): dedupe IR component_spec and container_spec
  pipeline_task_spec.component_ref.name = (
      dsl_utils.sanitize_component_name(task.name))
  executor_label = dsl_utils.sanitize_executor_label(task.name)

  task.component_spec = dsl_component_spec.build_component_spec_from_structure(
      component_spec, executor_label, arguments.keys())

  task.task_spec = pipeline_task_spec

  # Override command and arguments if compiling to v2.
  if kfp.COMPILING_FOR_V2:
    task.command = resolved_cmd.command
    task.arguments = resolved_cmd.args

    # Limit this to v2 compilation only to avoid possible behavior changes in v1.
    task.inputs = list(input_params_set)
Beispiel #29
0
def create_container_op_from_component_and_arguments(
    component_spec: structures.ComponentSpec,
    arguments: Mapping[str, Any],
    component_ref: structures.ComponentReference = None,
) -> container_op.ContainerOp:
    """Instantiates ContainerOp object.

  Args:
    component_spec: The component spec object.
    arguments: The dictionary of component arguments.
    component_ref: The component reference. Optional.

  Returns:
    A ContainerOp instance.
  """

    pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()
    pipeline_task_spec.task_info.name = component_spec.name
    # Might need to append a suffix to executor_label to ensure its uniqueness?
    pipeline_task_spec.executor_label = component_spec.name

    # Keep track of auto-injected importer spec.
    importer_spec = {}

    # Check types of the reference arguments and serialize PipelineParams
    arguments = arguments.copy()
    for input_name, argument_value in arguments.items():
        if isinstance(argument_value, dsl.PipelineParam):
            input_type = component_spec._inputs_dict[input_name].type
            reference_type = argument_value.param_type
            types.verify_type_compatibility(
                reference_type, input_type,
                'Incompatible argument passed to the input "{}" of component "{}": '
                .format(input_name, component_spec.name))

            arguments[input_name] = str(argument_value)

            if type_utils.is_parameter_type(input_type):
                if argument_value.op_name:
                    pipeline_task_spec.inputs.parameters[
                        input_name].task_output_parameter.producer_task = (
                            argument_value.op_name)
                    pipeline_task_spec.inputs.parameters[
                        input_name].task_output_parameter.output_parameter_key = (
                            argument_value.name)
                else:
                    pipeline_task_spec.inputs.parameters[
                        input_name].runtime_value.runtime_parameter = argument_value.name
            else:
                if argument_value.op_name:
                    pipeline_task_spec.inputs.artifacts[
                        input_name].producer_task = (argument_value.op_name)
                    pipeline_task_spec.inputs.artifacts[
                        input_name].output_artifact_key = (argument_value.name)
                else:
                    # argument_value.op_name could be None, in which case an importer node
                    # will be inserted later.
                    pipeline_task_spec.inputs.artifacts[
                        input_name].producer_task = ''
                    type_schema = type_utils.get_input_artifact_type_schema(
                        input_name, component_spec.inputs)
                    importer_spec[
                        input_name] = importer_node.build_importer_spec(
                            input_type_schema=type_schema,
                            pipeline_param_name=argument_value.name)
        elif isinstance(argument_value, str):
            input_type = component_spec._inputs_dict[input_name].type
            if type_utils.is_parameter_type(input_type):
                pipeline_task_spec.inputs.parameters[
                    input_name].runtime_value.constant_value.string_value = (
                        argument_value)
            else:
                # An importer node with constant value artifact_uri will be inserted.
                pipeline_task_spec.inputs.artifacts[
                    input_name].producer_task = ''
                type_schema = type_utils.get_input_artifact_type_schema(
                    input_name, component_spec.inputs)
                importer_spec[input_name] = importer_node.build_importer_spec(
                    input_type_schema=type_schema,
                    constant_value=argument_value)
        elif isinstance(argument_value, int):
            pipeline_task_spec.inputs.parameters[
                input_name].runtime_value.constant_value.int_value = argument_value
        elif isinstance(argument_value, float):
            pipeline_task_spec.inputs.parameters[
                input_name].runtime_value.constant_value.double_value = argument_value
        elif isinstance(argument_value, dsl.ContainerOp):
            raise TypeError(
                'ContainerOp object {} was passed to component as an input argument. '
                'Pass a single output instead.'.format(input_name))
        else:
            raise NotImplementedError(
                'Input argument supports only the following types: PipelineParam'
                ', str, int, float. Got: "{}".'.format(argument_value))

    for output in component_spec.outputs or []:
        if type_utils.is_parameter_type(output.type):
            pipeline_task_spec.outputs.parameters[
                output.name].type = type_utils.get_parameter_type(output.type)
        else:
            pipeline_task_spec.outputs.artifacts[
                output.name].artifact_type.instance_schema = (
                    type_utils.get_artifact_type_schema(output.type))

    inputs_dict = {
        input_spec.name: input_spec
        for input_spec in component_spec.inputs or []
    }
    outputs_dict = {
        output_spec.name: output_spec
        for output_spec in component_spec.outputs or []
    }

    def _input_artifact_uri_placeholder(input_key: str) -> str:
        if type_utils.is_parameter_type(inputs_dict[input_key].type):
            raise TypeError(
                'Input "{}" with type "{}" cannot be paired with InputUriPlaceholder.'
                .format(input_key, inputs_dict[input_key].type))
        else:
            return "{{{{$.inputs.artifacts['{}'].uri}}}}".format(input_key)

    def _input_artifact_path_placeholder(input_key: str) -> str:
        if type_utils.is_parameter_type(inputs_dict[input_key].type):
            raise TypeError(
                'Input "{}" with type "{}" cannot be paired with InputPathPlaceholder.'
                .format(input_key, inputs_dict[input_key].type))
        else:
            return "{{{{$.inputs.artifacts['{}'].path}}}}".format(input_key)

    def _input_parameter_placeholder(input_key: str) -> str:
        if type_utils.is_parameter_type(inputs_dict[input_key].type):
            return "{{{{$.inputs.parameters['{}']}}}}".format(input_key)
        else:
            raise TypeError(
                'Input "{}" with type "{}" cannot be paired with InputValuePlaceholder.'
                .format(input_key, inputs_dict[input_key].type))

    def _output_artifact_uri_placeholder(output_key: str) -> str:
        if type_utils.is_parameter_type(outputs_dict[output_key].type):
            raise TypeError(
                'Output "{}" with type "{}" cannot be paired with OutputUriPlaceholder.'
                .format(output_key, outputs_dict[output_key].type))
        else:
            return "{{{{$.outputs.artifacts['{}'].uri}}}}".format(output_key)

    def _output_artifact_path_placeholder(output_key: str) -> str:
        return "{{{{$.outputs.artifacts['{}'].path}}}}".format(output_key)

    def _output_parameter_path_placeholder(output_key: str) -> str:
        return "{{{{$.outputs.parameters['{}'].output_file}}}}".format(
            output_key)

    def _resolve_output_path_placeholder(output_key: str) -> str:
        if type_utils.is_parameter_type(outputs_dict[output_key].type):
            return _output_parameter_path_placeholder(output_key)
        else:
            return _output_artifact_path_placeholder(output_key)

    resolved_cmd = _resolve_command_line_and_paths(
        component_spec=component_spec,
        arguments=arguments,
        input_value_generator=_input_parameter_placeholder,
        input_uri_generator=_input_artifact_uri_placeholder,
        output_uri_generator=_output_artifact_uri_placeholder,
        input_path_generator=_input_artifact_path_placeholder,
        output_path_generator=_resolve_output_path_placeholder,
    )

    container_spec = component_spec.implementation.container

    pipeline_container_spec = (
        pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec())
    pipeline_container_spec.image = container_spec.image
    pipeline_container_spec.command.extend(resolved_cmd.command)
    pipeline_container_spec.args.extend(resolved_cmd.args)

    output_uris_and_paths = resolved_cmd.output_uris.copy()
    output_uris_and_paths.update(resolved_cmd.output_paths)
    input_uris_and_paths = resolved_cmd.input_uris.copy()
    input_uris_and_paths.update(resolved_cmd.input_paths)

    old_warn_value = dsl.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING
    dsl.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING = True
    task = container_op.ContainerOp(
        name=component_spec.name or _default_component_name,
        image=container_spec.image,
        command=resolved_cmd.command,
        arguments=resolved_cmd.args,
        file_outputs=output_uris_and_paths,
        artifact_argument_paths=[
            dsl.InputArgumentPath(
                argument=arguments[input_name],
                input=input_name,
                path=path,
            ) for input_name, path in input_uris_and_paths.items()
        ],
    )

    task.task_spec = pipeline_task_spec
    task.importer_spec = importer_spec
    task.container_spec = pipeline_container_spec
    dsl.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING = old_warn_value

    component_meta = copy.copy(component_spec)
    task._set_metadata(component_meta)
    component_ref_without_spec = copy.copy(component_ref)
    component_ref_without_spec.spec = None
    task._component_ref = component_ref_without_spec

    # Previously, ContainerOp had strict requirements for the output names, so we
    # had to convert all the names before passing them to the ContainerOp
    # constructor. Outputs with non-pythonic names could not be accessed using
    # their original names. Now ContainerOp supports any output names, so we're
    # now using the original output names. However, to support legacy pipelines,
    # we're also adding output references with pythonic names.
    # TODO: Add warning when people use the legacy output names.
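    # Hypothetical example (not from the original snippet): an output declared
    # as 'Output Data' would also become reachable as
    # task.outputs['output_data'] after the aliasing below, assuming
    # _sanitize_python_function_name maps 'Output Data' to 'output_data'.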
    output_names = [
        output_spec.name for output_spec in component_spec.outputs or []
    ]  # Stabilizing the ordering
    output_name_to_python = generate_unique_name_conversion_table(
        output_names, _sanitize_python_function_name)
    for output_name in output_names:
        pythonic_output_name = output_name_to_python[output_name]
        # Note: Some component outputs are currently missing from task.outputs
        # (e.g. MLPipeline UI Metadata)
        if pythonic_output_name not in task.outputs and output_name in task.outputs:
            task.outputs[pythonic_output_name] = task.outputs[output_name]

    if component_spec.metadata:
        annotations = component_spec.metadata.annotations or {}
        for key, value in annotations.items():
            task.add_pod_annotation(key, value)
        for key, value in (component_spec.metadata.labels or {}).items():
            task.add_pod_label(key, value)
        # Disable caching for volatile components by default.
        if annotations.get('volatile_component', 'false') == 'true':
            task.execution_options.caching_strategy.max_cache_staleness = 'P0D'

    return task
Beispiel #30
0
    def test_pop_input_from_task_spec(self):
        task_spec = pipeline_spec_pb2.PipelineTaskSpec()
        task_spec.component_ref.name = 'comp-component1'
        task_spec.inputs.artifacts[
            'input1'].task_output_artifact.producer_task = 'op-1'
        task_spec.inputs.artifacts[
            'input1'].task_output_artifact.output_artifact_key = 'output1'
        task_spec.inputs.parameters[
            'input2'].task_output_parameter.producer_task = 'op-2'
        task_spec.inputs.parameters[
            'input2'].task_output_parameter.output_parameter_key = 'output2'
        task_spec.inputs.parameters[
            'input3'].component_input_parameter = 'op3-output3'

        # pop a parameter, and there are other inputs left
        dsl_component_spec.pop_input_from_task_spec(task_spec, 'input3')
        expected_dict = {
            'inputs': {
                'artifacts': {
                    'input1': {
                        'taskOutputArtifact': {
                            'producerTask': 'op-1',
                            'outputArtifactKey': 'output1'
                        }
                    }
                },
                'parameters': {
                    'input2': {
                        'taskOutputParameter': {
                            'producerTask': 'op-2',
                            'outputParameterKey': 'output2'
                        }
                    }
                }
            },
            'component_ref': {
                'name': 'comp-component1'
            }
        }
        expected_spec = pipeline_spec_pb2.PipelineTaskSpec()
        json_format.ParseDict(expected_dict, expected_spec)
        self.assertEqual(expected_spec, task_spec)

        # pop an artifact, and there are other inputs left
        dsl_component_spec.pop_input_from_task_spec(task_spec, 'input1')
        expected_dict = {
            'inputs': {
                'parameters': {
                    'input2': {
                        'taskOutputParameter': {
                            'producerTask': 'op-2',
                            'outputParameterKey': 'output2'
                        }
                    }
                }
            },
            'component_ref': {
                'name': 'comp-component1'
            }
        }
        expected_spec = pipeline_spec_pb2.PipelineTaskSpec()
        json_format.ParseDict(expected_dict, expected_spec)
        self.assertEqual(expected_spec, task_spec)

        # pop the last input, expect the inputs field to be cleared
        dsl_component_spec.pop_input_from_task_spec(task_spec, 'input2')
        expected_dict = {'component_ref': {'name': 'comp-component1'}}
        expected_spec = pipeline_spec_pb2.PipelineTaskSpec()
        json_format.ParseDict(expected_dict, expected_spec)
        self.assertEqual(expected_spec, task_spec)

        # pop an input that doesn't exist, expect no-op.
        dsl_component_spec.pop_input_from_task_spec(task_spec, 'input4')
        self.assertEqual(expected_spec, task_spec)