Example No. 1
    def _build_latest_artifact_resolver(
            self) -> List[pipeline_pb2.PipelineTaskSpec]:
        """Builds a resolver spec for a latest artifact resolver.

    Returns:
      A list of two PipelineTaskSpecs. One represents the query for latest valid
      ModelBlessing artifact. Another one represents the query for latest
      blessed Model artifact.
    Raises:
      ValueError: when desired_num_of_artifacts != 1. 1 is the only supported
        value currently.
    """

        task_spec = pipeline_pb2.PipelineTaskSpec()
        task_spec.task_info.CopyFrom(
            pipeline_pb2.PipelineTaskInfo(name=self._name))
        executor_label = _EXECUTOR_LABEL_PATTERN.format(self._name)
        task_spec.executor_label = executor_label

        # Fetch the init kwargs for the resolver.
        resolver_config = self._exec_properties[resolver.RESOLVER_CONFIG]
        if (isinstance(resolver_config, dict)
                and resolver_config.get('desired_num_of_artifacts', 0) > 1):
            raise ValueError(
                'Only desired_num_of_artifacts=1 is supported currently.'
                ' Got {}'.format(
                    resolver_config.get('desired_num_of_artifacts')))
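        # Illustrative only: judging from the check above, resolver_config is
        # expected to look like {'desired_num_of_artifacts': 1}.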

        # Specify the outputs of the task.
        for name, output_channel in self._outputs.items():
            # Currently, we're working under the assumption that for tasks
            # (those generated by BaseComponent), each channel contains a single
            # artifact.
            output_artifact_spec = compiler_utils.build_output_artifact_spec(
                output_channel)
            task_spec.outputs.artifacts[name].CopyFrom(output_artifact_spec)

        # Specify the input parameters of the task.
        for k, v in compiler_utils.build_input_parameter_spec(
                self._exec_properties).items():
            task_spec.inputs.parameters[k].CopyFrom(v)

        artifact_queries = {}
        # Build the artifact query for each channel in the input dict.
        for name, c in self._inputs.items():
            query_filter = ("artifact_type='{type}' and state={state}").format(
                type=compiler_utils.get_artifact_title(c.type),
                state=metadata_store_pb2.Artifact.State.Name(
                    metadata_store_pb2.Artifact.LIVE))
            artifact_queries[name] = ResolverSpec.ArtifactQuerySpec(
                filter=query_filter)
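            # e.g. "artifact_type='Model' and state=LIVE" for a standard Model
            # channel (type name illustrative).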

        resolver_spec = ResolverSpec(output_artifact_queries=artifact_queries)
        executor = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()
        executor.resolver.CopyFrom(resolver_spec)
        self._deployment_config.executors[executor_label].CopyFrom(executor)
        return [task_spec]
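
The filter string assembled above is the heart of each artifact query. Below is a minimal, runnable sketch of how it renders, using only metadata_store_pb2 from ML Metadata (which the snippet itself relies on); the helper name and the 'Model' type are illustrative, not part of the original code.

from ml_metadata.proto import metadata_store_pb2


def build_live_artifact_filter(artifact_type: str) -> str:
    # Renders the same MLMD-style filter embedded in the resolver spec above.
    return "artifact_type='{type}' and state={state}".format(
        type=artifact_type,
        state=metadata_store_pb2.Artifact.State.Name(
            metadata_store_pb2.Artifact.LIVE))


print(build_live_artifact_filter('Model'))
# Prints: artifact_type='Model' and state=LIVE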
Example No. 2
    def build(self) -> List[pipeline_pb2.PipelineTaskSpec]:
        """Builds a pipeline StepSpec given the node information.

    Returns:
      A list of PipelineTaskSpec messages corresponding to the node. For most of
      the cases, the list contains a single element. The only exception is when
      compiling latest blessed model resolver. One DSL node will be
      split to two resolver specs to reflect the two-phased query execution.
    Raises:
      NotImplementedError: When the node being built is an InfraValidator.
    """
        task_spec = pipeline_pb2.PipelineTaskSpec()
        task_spec.task_info.CopyFrom(
            pipeline_pb2.PipelineTaskInfo(name=self._name))
        executor_label = _EXECUTOR_LABEL_PATTERN.format(self._name)
        task_spec.executor_label = executor_label
        executor = pipeline_pb2.PipelineDeploymentConfig.ExecutorSpec()

        # 1. Resolver tasks won't have input artifacts in the API proto. First we
        #    special-case the two resolver types we support.
        if isinstance(self._node, resolver.Resolver):
            return self._build_resolver_spec()

        # 2. Build the node spec.
        # TODO(b/157641727): Tests comparing dictionaries are brittle when comparing
        # lists as ordering matters.
        dependency_ids = [node.id for node in self._node.upstream_nodes]
        # Specify the inputs of the task.
        for name, input_channel in self._inputs.items():
            # If the redirect map is provided (usually for the latest blessed
            # model resolver), we'll need to redirect accordingly. The upstream
            # node list is also updated, replacing the old producer id with
            # the new one.
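            # Illustrative entry: ('my_resolver', 'model') ->
            # ('my_resolver-model-resolver', 'model').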
            producer_id = input_channel.producer_component_id
            output_key = input_channel.output_key
            for k, v in self._channel_redirect_map.items():
                if k[0] == producer_id and producer_id in dependency_ids:
                    dependency_ids.remove(producer_id)
                    dependency_ids.append(v[0])
            # Resolve the redirect with a single lookup so that output_key
            # comes from the same map entry as producer_id; looking the map up
            # again after producer_id has been rewritten would miss the entry.
            producer_id, output_key = self._channel_redirect_map.get(
                (producer_id, output_key), (producer_id, output_key))

            input_artifact_spec = pipeline_pb2.TaskInputsSpec.InputArtifactSpec(
                producer_task=producer_id, output_artifact_key=output_key)
            task_spec.inputs.artifacts[name].CopyFrom(input_artifact_spec)

        # Specify the outputs of the task.
        for name, output_channel in self._outputs.items():
            # Currently, we're working under the assumption that for tasks
            # (those generated by BaseComponent), each channel contains a single
            # artifact.
            output_artifact_spec = compiler_utils.build_output_artifact_spec(
                output_channel)
            task_spec.outputs.artifacts[name].CopyFrom(output_artifact_spec)

        # Specify the input parameters of the task.
        for k, v in compiler_utils.build_input_parameter_spec(
                self._exec_properties).items():
            task_spec.inputs.parameters[k].CopyFrom(v)

        # 3. Build the executor body for other common tasks.
        if isinstance(self._node, importer.Importer):
            executor.importer.CopyFrom(self._build_importer_spec())
        elif isinstance(self._node, components.FileBasedExampleGen):
            executor.container.CopyFrom(
                self._build_file_based_example_gen_spec())
        elif isinstance(self._node, components.InfraValidator):
            raise NotImplementedError(
                'The component type "{}" is not supported'.format(
                    type(self._node)))
        else:
            executor.container.CopyFrom(self._build_container_spec())

        task_spec.dependent_tasks.extend(sorted(dependency_ids))

        task_spec.caching_options.CopyFrom(
            pipeline_pb2.PipelineTaskSpec.CachingOptions(
                enable_cache=self._enable_cache))

        # 4. Attach the built executor spec to the deployment config.
        self._deployment_config.executors[executor_label].CopyFrom(executor)

        return [task_spec]
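
A minimal sketch of the channel-redirect bookkeeping from step 2, using plain strings and dicts in place of TFX channels; every id below is hypothetical:

# (producer id, output key) -> (new producer id, new output key); ids invented
# for illustration.
channel_redirect_map = {
    ('my_resolver', 'model'): ('my_resolver-model-resolver', 'model'),
}
producer_id, output_key = 'my_resolver', 'model'
dependency_ids = ['my_resolver', 'trainer']

# Swap the dependency over to the new producer, mirroring the loop above.
for k, v in channel_redirect_map.items():
    if k[0] == producer_id and producer_id in dependency_ids:
        dependency_ids.remove(producer_id)
        dependency_ids.append(v[0])

# Resolve both fields with a single lookup, as in the code above.
producer_id, output_key = channel_redirect_map.get(
    (producer_id, output_key), (producer_id, output_key))

print(producer_id, output_key, sorted(dependency_ids))
# Prints: my_resolver-model-resolver model ['my_resolver-model-resolver', 'trainer']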