예제 #1
0
    def resolve(
        self,
        pipeline_info: data_types.PipelineInfo,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Resolve the highest-id published artifacts for each source channel.

        For every channel, the published artifacts of the channel's type
        within the pipeline context are sorted by descending ``id`` and at
        most ``self._desired_num_of_artifact`` of them are kept.

        Args:
            pipeline_info: Identifies the pipeline whose context is used as
                the artifact search space.
            metadata_handler: Metadata accessor used to look up the pipeline
                context and the published artifacts within it.
            source_channels: Mapping from key to the channel whose artifact
                type should be resolved.

        Returns:
            A ``ResolveResult`` whose ``per_key_resolve_result`` maps each key
            to the selected artifacts and whose ``per_key_resolve_state`` maps
            each key to True only when at least
            ``self._desired_num_of_artifact`` artifacts were available.

        Raises:
            RuntimeError: If no pipeline context exists for ``pipeline_info``.
        """
        pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
        if pipeline_context is None:
            # Report the pipeline that was looked up; the context is None in
            # this branch and would make the message uninformative.
            raise RuntimeError('Pipeline context absent for %s' %
                               pipeline_info)

        # Fetch the artifacts of all requested types with a single call.
        artifacts_in_context = (
            metadata_handler.get_published_artifacts_by_type_within_context(
                [c.type_name for c in source_channels.values()],
                pipeline_context.id))

        desired = self._desired_num_of_artifact
        artifacts_dict = {}
        resolve_state_dict = {}
        for k, c in source_channels.items():
            previous_artifacts = sorted(artifacts_in_context[c.type_name],
                                        key=lambda m: m.id,
                                        reverse=True)
            # Slicing already caps at the list length, so the same expression
            # covers both the "enough artifacts" and "too few" cases.
            artifacts_dict[k] = [
                _generate_tfx_artifact(a, c.type)
                for a in previous_artifacts[:desired]
            ]
            resolve_state_dict[k] = len(previous_artifacts) >= desired

        return base_resolver.ResolveResult(
            per_key_resolve_result=artifacts_dict,
            per_key_resolve_state=resolve_state_dict)
예제 #2
0
    def resolve(
        self,
        pipeline_info: data_types.PipelineInfo,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Resolve the highest-id blessed Model together with its ModelBlessing.

        Args:
            pipeline_info: Identifies the pipeline whose context is used as
                the artifact search space.
            metadata_handler: Metadata accessor used to look up the pipeline
                context and the published artifacts within it.
            source_channels: Exactly two channels, one whose type is a
                subclass of ``standard_artifacts.Model`` and one whose type is
                a subclass of ``standard_artifacts.ModelBlessing``.

        Returns:
            A ``ResolveResult`` keyed by the two channel keys. When a blessed
            model is found, each key maps to a single-element artifact list
            and a True state; otherwise both lists are empty and both states
            are False.

        Raises:
            AssertionError: If there are not exactly two channels, or if
                either the Model or the ModelBlessing channel is missing.
            RuntimeError: If a channel has an unexpected type, or if no
                pipeline context exists for ``pipeline_info``.
        """
        # First, checks whether we have exactly Model and ModelBlessing Channels.
        # NOTE(review): these asserts are stripped under ``python -O``; they
        # are kept so callers continue to see AssertionError.
        model_channel_key = None
        model_blessing_channel_key = None
        assert len(source_channels) == 2, 'Expecting 2 input Channels'
        for k, c in source_channels.items():
            if issubclass(c.type, standard_artifacts.Model):
                model_channel_key = k
            elif issubclass(c.type, standard_artifacts.ModelBlessing):
                model_blessing_channel_key = k
            else:
                raise RuntimeError(
                    'Only expecting Model or ModelBlessing, got %s' % c.type)
        assert model_channel_key is not None, 'Expecting Model as input'
        assert model_blessing_channel_key is not None, (
            'Expecting ModelBlessing as'
            ' input')

        # Gets the pipeline context as the artifact search space.
        pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
        if pipeline_context is None:
            # Report the pipeline that was looked up; the context is None in
            # this branch and would make the message uninformative.
            raise RuntimeError('Pipeline context absent for %s' %
                               pipeline_info)

        model_type_name = source_channels[model_channel_key].type_name
        blessing_type_name = (
            source_channels[model_blessing_channel_key].type_name)

        # Gets all artifacts of interest within the context with one call.
        artifacts_in_context = (
            metadata_handler.get_published_artifacts_by_type_within_context(
                [model_type_name, blessing_type_name], pipeline_context.id))

        # All models in the search space, highest id first.
        all_models = sorted(artifacts_in_context[model_type_name],
                            key=lambda m: m.id,
                            reverse=True)
        all_model_blessings = artifacts_in_context[blessing_type_name]

        # Maps model id -> ModelBlessing artifact, for blessed models only.
        # If several blessings reference the same model id, the last one in
        # iteration order wins.
        blessed_key = model_validator.ARTIFACT_PROPERTY_BLESSED_KEY
        model_id_key = model_validator.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY
        blessing_by_model_id = {
            a.custom_properties[model_id_key].int_value: a
            for a in all_model_blessings
            if a.custom_properties[blessed_key].int_value == 1
        }

        artifacts_dict = {
            model_channel_key: [],
            model_blessing_channel_key: []
        }
        resolve_state_dict = {
            model_channel_key: False,
            model_blessing_channel_key: False
        }
        # Iterates all models; the first blessed one is the result. Because the
        # list is sorted by descending id, that is the highest-id blessed model.
        for model in all_models:
            if model.id in blessing_by_model_id:
                artifacts_dict[model_channel_key] = [
                    _generate_tfx_artifact(model, standard_artifacts.Model)
                ]
                artifacts_dict[model_blessing_channel_key] = [
                    _generate_tfx_artifact(blessing_by_model_id[model.id],
                                           standard_artifacts.ModelBlessing)
                ]
                resolve_state_dict[model_channel_key] = True
                resolve_state_dict[model_blessing_channel_key] = True
                break

        return base_resolver.ResolveResult(
            per_key_resolve_result=artifacts_dict,
            per_key_resolve_state=resolve_state_dict)