예제 #1
0
  def resolve(
      self,
      pipeline_info: data_types.PipelineInfo,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[Text, types.Channel],
  ) -> base_resolver.ResolveResult:
    """Resolves artifacts for every source channel within the pipeline context.

    Args:
      pipeline_info: Pipeline whose context is used as the search space.
      metadata_handler: Handler used to look up the pipeline context and to
        query the qualified artifacts of each channel.
      source_channels: Mapping from key to the channel to resolve.

    Returns:
      A ResolveResult whose per-key result is whatever `self._resolve` returns
      for the candidate artifacts, and whose per-key state is True when that
      key holds at least `self._desired_num_of_artifact` artifacts.

    Raises:
      RuntimeError: If no pipeline context exists for `pipeline_info`.
    """
    pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
    if pipeline_context is None:
      # The context is None on this branch, so report the pipeline that was
      # looked up instead of formatting the (always None) context.
      raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

    # Collect every candidate artifact per key before delegating the actual
    # selection to `self._resolve`.
    candidate_dict = {}
    for k, c in source_channels.items():
      candidate_artifacts = metadata_handler.get_qualified_artifacts(
          contexts=[pipeline_context],
          type_name=c.type_name,
          producer_component_id=c.producer_component_id,
          output_key=c.output_key)
      candidate_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in candidate_artifacts
      ]

    resolved_dict = self._resolve(candidate_dict)
    # A key counts as resolved only when enough artifacts were selected.
    resolve_state_dict = {
        k: len(artifact_list) >= self._desired_num_of_artifact
        for k, artifact_list in resolved_dict.items()
    }

    return base_resolver.ResolveResult(
        per_key_resolve_result=resolved_dict,
        per_key_resolve_state=resolve_state_dict)
예제 #2
0
    def resolve(
        self,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Picks, per channel, the artifacts with the highest ids of its type.

        Args:
            metadata_handler: Handler queried for all artifacts of each
                channel's type name.
            source_channels: Mapping from key to the channel to resolve.

        Returns:
            A ResolveResult holding, per key, up to
            `self._desired_num_of_artifact` artifacts ordered by descending
            id, and a state flag that is True only when at least that many
            artifacts were available.
        """
        chosen_by_key = {}
        state_by_key = {}
        for key, channel in source_channels.items():
            # Order by id, highest first.
            ordered = list(
                metadata_handler.get_artifacts_by_type(channel.type_name))
            ordered.sort(key=lambda m: m.id, reverse=True)

            has_enough = len(ordered) >= self._desired_num_of_artifact
            # With too few artifacts, everything found is still returned.
            picked = (ordered[:self._desired_num_of_artifact]
                      if has_enough else ordered)
            chosen_by_key[key] = [
                _generate_tfx_artifact(item, channel.type) for item in picked
            ]
            state_by_key[key] = has_enough

        return base_resolver.ResolveResult(
            per_key_resolve_result=chosen_by_key,
            per_key_resolve_state=state_by_key)
예제 #3
0
    def resolve(
        self,
        pipeline_info: data_types.PipelineInfo,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Picks, per channel, the highest-id artifacts within the pipeline context.

        Args:
            pipeline_info: Pipeline whose context is used as the search space.
            metadata_handler: Handler used to look up the pipeline context and
                the published artifacts within it.
            source_channels: Mapping from key to the channel to resolve.

        Returns:
            A ResolveResult holding, per key, up to
            `self._desired_num_of_artifact` artifacts ordered by descending
            id, and a state flag that is True only when at least that many
            artifacts were available.

        Raises:
            RuntimeError: If no pipeline context exists for `pipeline_info`.
        """
        pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
        if pipeline_context is None:
            # The context is None on this branch, so report the pipeline that
            # was looked up instead of formatting the (always None) context.
            raise RuntimeError('Pipeline context absent for %s' %
                               pipeline_info)

        # Fetch the artifacts of every requested type with a single call.
        artifacts_in_context = metadata_handler.get_published_artifacts_by_type_within_context(
            [c.type_name for c in source_channels.values()],
            pipeline_context.id)

        artifacts_dict = {}
        resolve_state_dict = {}
        for k, c in source_channels.items():
            previous_artifacts = sorted(artifacts_in_context[c.type_name],
                                        key=lambda m: m.id,
                                        reverse=True)
            # Slicing already yields the whole list when fewer artifacts than
            # desired exist, so no separate branch is needed for that case.
            artifacts_dict[k] = [
                _generate_tfx_artifact(a, c.type)
                for a in previous_artifacts[:self._desired_num_of_artifact]
            ]
            resolve_state_dict[k] = (
                len(previous_artifacts) >= self._desired_num_of_artifact)

        return base_resolver.ResolveResult(
            per_key_resolve_result=artifacts_dict,
            per_key_resolve_state=resolve_state_dict)
예제 #4
0
  def resolve(
      self,
      pipeline_info: data_types.PipelineInfo,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[Text, types.Channel],
  ) -> base_resolver.ResolveResult:
    """Picks, per channel, the highest-id qualified artifacts in the pipeline.

    Args:
      pipeline_info: Pipeline whose context is used as the search space.
      metadata_handler: Handler used to look up the pipeline context and to
        query the qualified artifacts of each channel.
      source_channels: Mapping from key to the channel to resolve.

    Returns:
      A ResolveResult holding, per key, up to `self._desired_num_of_artifact`
      deserialized artifacts ordered by descending artifact id, and a state
      flag that is True only when at least that many artifacts were available.

    Raises:
      RuntimeError: If no pipeline context exists for `pipeline_info`.
    """
    pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
    if pipeline_context is None:
      # The context is None on this branch, so report the pipeline that was
      # looked up instead of formatting the (always None) context.
      raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

    artifacts_dict = {}
    resolve_state_dict = {}
    for k, c in source_channels.items():
      candidate_artifacts = metadata_handler.get_qualified_artifacts(
          context=pipeline_context,
          type_name=c.type_name,
          producer_component_id=c.producer_component_id,
          output_key=c.output_key)
      previous_artifacts = sorted(
          candidate_artifacts, key=lambda a: a.artifact.id, reverse=True)
      # Slicing already yields the whole list when fewer artifacts than
      # desired exist, so no separate branch is needed for that case.
      artifacts_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in previous_artifacts[:self._desired_num_of_artifact]
      ]
      resolve_state_dict[k] = (
          len(previous_artifacts) >= self._desired_num_of_artifact)

    return base_resolver.ResolveResult(
        per_key_resolve_result=artifacts_dict,
        per_key_resolve_state=resolve_state_dict)
예제 #5
0
  def resolve(
      self,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[Text, types.Channel],
  ) -> base_resolver.ResolveResult:
    """Resolves each channel to the artifact of its type with the highest id.

    Args:
      metadata_handler: Handler queried for all artifacts of each channel's
        type name.
      source_channels: Mapping from key to the channel to resolve.

    Returns:
      A ResolveResult whose per-key result maps each key to a tuple of
      (artifact list, resolved flag): a single-element list and True when an
      artifact of the type exists, otherwise an empty list and False.
    """
    outcome_by_key = {}
    for key, channel in source_channels.items():
      found = metadata_handler.get_artifacts_by_type(channel.type_name)
      if not found:
        # Nothing of this type has been recorded.
        outcome_by_key[key] = ([], False)
        continue

      newest = max(found, key=lambda m: m.id)
      wrapped = types.Artifact(type_name=channel.type_name)
      wrapped.set_artifact(newest)
      outcome_by_key[key] = ([wrapped], True)

    return base_resolver.ResolveResult(per_key_resolve_result=outcome_by_key)
예제 #6
0
    def resolve(
        self,
        pipeline_info: data_types.PipelineInfo,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Resolves the highest-id blessed Model and its ModelBlessing.

        Args:
            pipeline_info: Pipeline whose context is used as the search space.
            metadata_handler: Handler used to look up the pipeline context and
                the published artifacts within it.
            source_channels: Mapping from key to channel. Must contain exactly
                one Model channel and one ModelBlessing channel.

        Returns:
            A ResolveResult whose per-key result holds a single Model and its
            ModelBlessing when a blessed model is found (empty lists
            otherwise), with both per-key states set to True only in the
            found case.

        Raises:
            AssertionError: If the channels are not exactly one Model and one
                ModelBlessing channel.
            RuntimeError: If a channel has any other type, or if no pipeline
                context exists for `pipeline_info`.
        """
        # First, checks whether we have exactly Model and ModelBlessing Channels.
        model_channel_key = None
        model_blessing_channel_key = None
        assert len(source_channels) == 2, 'Expecting 2 input Channels'
        for k, c in source_channels.items():
            if issubclass(c.type, standard_artifacts.Model):
                model_channel_key = k
            elif issubclass(c.type, standard_artifacts.ModelBlessing):
                model_blessing_channel_key = k
            else:
                raise RuntimeError(
                    'Only expecting Model or ModelBlessing, got %s' % c.type)
        assert model_channel_key is not None, 'Expecting Model as input'
        assert model_blessing_channel_key is not None, (
            'Expecting ModelBlessing as'
            ' input')

        # Gets the pipeline context as the artifact search space.
        pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
        if pipeline_context is None:
            # The context is None on this branch, so report the pipeline that
            # was looked up instead of formatting the (always None) context.
            raise RuntimeError('Pipeline context absent for %s' %
                               pipeline_info)

        model_type_name = source_channels[model_channel_key].type_name
        blessing_type_name = source_channels[
            model_blessing_channel_key].type_name
        # Gets all artifacts of interest within the context with one call.
        artifacts_in_context = metadata_handler.get_published_artifacts_by_type_within_context(
            [model_type_name, blessing_type_name], pipeline_context.id)
        # Gets all models in the search space, sorted in reverse order by id.
        all_models = sorted(
            artifacts_in_context[model_type_name],
            key=lambda m: m.id,
            reverse=True)
        # Gets all ModelBlessing artifacts in the search space.
        all_model_blessings = artifacts_in_context[blessing_type_name]
        # Makes a dict of {model_id: ModelBlessing artifact} for blessed models.
        all_blessed_model_ids = {
            a.custom_properties[
                model_validator.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY].int_value:
            a
            for a in all_model_blessings
            if a.custom_properties[
                model_validator.ARTIFACT_PROPERTY_BLESSED_KEY].int_value == 1
        }

        artifacts_dict = {
            model_channel_key: [],
            model_blessing_channel_key: []
        }
        resolve_state_dict = {
            model_channel_key: False,
            model_blessing_channel_key: False
        }
        # Iterates over all models and takes the first blessed one. Because the
        # list is sorted by descending id, that is the blessed model with the
        # highest id.
        for model in all_models:
            if model.id in all_blessed_model_ids:
                artifacts_dict[model_channel_key] = [
                    _generate_tfx_artifact(model, standard_artifacts.Model)
                ]
                artifacts_dict[model_blessing_channel_key] = [
                    _generate_tfx_artifact(all_blessed_model_ids[model.id],
                                           standard_artifacts.ModelBlessing)
                ]
                resolve_state_dict[model_channel_key] = True
                resolve_state_dict[model_blessing_channel_key] = True
                break

        return base_resolver.ResolveResult(
            per_key_resolve_result=artifacts_dict,
            per_key_resolve_state=resolve_state_dict)
예제 #7
0
    def resolve(
        self,
        pipeline_info: data_types.PipelineInfo,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Resolves a Model / ModelBlessing pair from the pipeline context.

        Args:
            pipeline_info: Pipeline whose context is used as the search space.
            metadata_handler: Handler used to look up the pipeline context and
                to query the qualified artifacts of both channels.
            source_channels: Mapping from key to channel. Must contain exactly
                one Model channel and one ModelBlessing channel.

        Returns:
            A ResolveResult whose per-key result is whatever `self._resolve`
            returns for the candidates, and whose per-key state is True when
            that key's artifact list is non-empty.

        Raises:
            AssertionError: If the channels are not exactly one Model and one
                ModelBlessing channel.
            RuntimeError: If a channel has any other type, or if no pipeline
                context exists for `pipeline_info`.
        """
        # First, checks whether we have exactly Model and ModelBlessing Channels.
        model_channel_key = None
        model_blessing_channel_key = None
        assert len(source_channels) == 2, 'Expecting 2 input Channels'
        for k, c in source_channels.items():
            if issubclass(c.type, standard_artifacts.Model):
                model_channel_key = k
            elif issubclass(c.type, standard_artifacts.ModelBlessing):
                model_blessing_channel_key = k
            else:
                raise RuntimeError(
                    'Only expecting Model or ModelBlessing, got %s' % c.type)
        assert model_channel_key is not None, 'Expecting Model as input'
        assert model_blessing_channel_key is not None, (
            'Expecting ModelBlessing as'
            ' input')

        # Gets the pipeline context as the artifact search space.
        pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
        if pipeline_context is None:
            # The context is None on this branch, so report the pipeline that
            # was looked up instead of formatting the (always None) context.
            raise RuntimeError('Pipeline context absent for %s' %
                               pipeline_info)

        # Gathers every qualified artifact for the Model channel first, then
        # the ModelBlessing channel. No ordering is applied here; selection is
        # left to `self._resolve`.
        candidate_dict = {}
        for key in (model_channel_key, model_blessing_channel_key):
            channel = source_channels[key]
            qualified_artifacts = metadata_handler.get_qualified_artifacts(
                contexts=[pipeline_context],
                type_name=channel.type_name,
                producer_component_id=channel.producer_component_id,
                output_key=channel.output_key)
            candidate_dict[key] = [
                artifact_utils.deserialize_artifact(a.type, a.artifact)
                for a in qualified_artifacts
            ]

        resolved_dict = self._resolve(candidate_dict, model_channel_key,
                                      model_blessing_channel_key)
        # A key counts as resolved when at least one artifact was selected.
        resolve_state_dict = {
            k: bool(artifact_list)
            for k, artifact_list in resolved_dict.items()
        }

        return base_resolver.ResolveResult(
            per_key_resolve_result=resolved_dict,
            per_key_resolve_state=resolve_state_dict)