def setUp(self):
  super().setUp()
  self.artifact_struct_dict = {
      'a1':
          text_format.Parse(
              """
              elements {
                artifact {
                  artifact {
                    id: 123
                  }
                  type {
                    name: 't1'
                  }
                }
              }
              """, metadata_store_service_pb2.ArtifactStructList()),
      'a2':
          text_format.Parse(
              """
              elements {
                artifact {
                  artifact {
                    id: 456
                  }
                  type {
                    name: 't2'
                  }
                }
              }
              """, metadata_store_service_pb2.ArtifactStructList()),
  }
  self.artifact_dict = {
      'a1': [
          artifact_utils.deserialize_artifact(
              metadata_store_pb2.ArtifactType(name='t1'),
              metadata_store_pb2.Artifact(id=123))
      ],
      'a2': [
          artifact_utils.deserialize_artifact(
              metadata_store_pb2.ArtifactType(name='t2'),
              metadata_store_pb2.Artifact(id=456))
      ],
  }
  self.metadata_value_dict = {
      'p0': metadata_store_pb2.Value(int_value=0),
      'p1': metadata_store_pb2.Value(int_value=1),
      'p2': metadata_store_pb2.Value(string_value='hello'),
      'p3': metadata_store_pb2.Value(string_value=''),
  }
  self.value_dict = {'p0': 0, 'p1': 1, 'p2': 'hello', 'p3': ''}
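# A minimal sketch of the round trip the fixtures above exercise, assuming
# the standard ml-metadata protos and TFX artifact_utils imports:
# deserialize_artifact turns an (ArtifactType, Artifact) proto pair into a
# typed tfx.types.Artifact whose underlying proto payload is preserved.
from ml_metadata.proto import metadata_store_pb2
from tfx.types import artifact_utils

artifact_type = metadata_store_pb2.ArtifactType(name='t1')
artifact_pb = metadata_store_pb2.Artifact(id=123)
tfx_artifact = artifact_utils.deserialize_artifact(artifact_type, artifact_pb)
assert tfx_artifact.id == 123
assert tfx_artifact.type_name == 't1'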
def generate_output_artifacts(
    self, execution_id: int) -> Dict[Text, List[types.Artifact]]:
  """Generates output artifacts given execution_id."""
  output_artifacts = collections.defaultdict(list)
  for key, output_spec in self._pipeline_node.outputs.outputs.items():
    artifact = artifact_utils.deserialize_artifact(
        output_spec.artifact_spec.type)
    artifact.uri = os.path.join(self._node_dir,
                                _EXECUTION_PREFIX + str(execution_id), key)
    if isinstance(artifact, ValueArtifact):
      artifact.uri = os.path.join(artifact.uri, _VALUE_ARTIFACT_FILE_NAME)
    # artifact.name will contain the set of information to track its creation
    # and is guaranteed to be idempotent across retries of a node.
    artifact.name = '{0}:{1}:{2}:{3}:{4}'.format(
        self._pipeline_info.id,
        self._pipeline_run_id,
        self._pipeline_node.node_info.id,
        key,
        # The index of this artifact. Since we only have one artifact per
        # output for now, it is always 0.
        # TODO(b/162331170): Update the "0" to the actual index.
        0)
    logging.debug('Creating output artifact uri %s as directory',
                  artifact.uri)
    output_artifacts[key].append(artifact)
  return output_artifacts
def _get_outputs_of_execution(
    self, execution_id: int, events: List[metadata_store_pb2.Event]
) -> Optional[Dict[Text, List[Artifact]]]:
  """Fetches outputs produced by a historical execution.

  Args:
    execution_id: the id of the execution that produced the outputs.
    events: events related to the execution id.

  Returns:
    A dict of key -> List[Artifact] as the result.
  """
  absl.logging.debug('Execution %s matches all inputs', execution_id)
  result = collections.defaultdict(list)

  output_events = [
      event for event in events
      if event.type == metadata_store_pb2.Event.OUTPUT
  ]
  output_events.sort(key=lambda e: e.path.steps[1].index)
  cached_output_artifacts = self.store.get_artifacts_by_id(
      [e.artifact_id for e in output_events])
  artifact_types = self.store.get_artifact_types_by_id(
      [a.type_id for a in cached_output_artifacts])

  for event, mlmd_artifact, artifact_type in zip(output_events,
                                                 cached_output_artifacts,
                                                 artifact_types):
    key = event.path.steps[0].key
    tfx_artifact = artifact_utils.deserialize_artifact(
        artifact_type, mlmd_artifact)
    result[key].append(tfx_artifact)
  return result
def get_qualified_artifacts(
    metadata_handler: metadata.Metadata,
    contexts: Iterable[metadata_store_pb2.Context],
    artifact_type: metadata_store_pb2.ArtifactType,
    output_key: Optional[str] = None,
) -> List[types.Artifact]:
  """Gets qualified artifacts that have the right producer info.

  Args:
    metadata_handler: A metadata handler to access MLMD store.
    contexts: Context constraints to filter artifacts.
    artifact_type: Type constraint to filter artifacts.
    output_key: Output key constraint to filter artifacts.

  Returns:
    A list of qualified TFX Artifacts.
  """
  # We expect to have at least one context for input resolution.
  assert contexts, 'Must have at least one context.'

  try:
    artifact_type_name = artifact_type.name
    artifact_type = metadata_handler.store.get_artifact_type(
        artifact_type_name)
  except mlmd.errors.NotFoundError:
    logging.warning('Artifact type %s is not found in MLMD.',
                    artifact_type.name)
    artifact_type = None

  if not artifact_type:
    return []

  executions_within_context = (
      execution_lib.get_executions_associated_with_all_contexts(
          metadata_handler, contexts))

  # Filters out non-success executions.
  qualified_producer_executions = [
      e.id
      for e in executions_within_context
      if execution_lib.is_execution_successful(e)
  ]
  # Gets the output events that have the matched output key.
  qualified_output_events = [
      ev for ev in metadata_handler.store.get_events_by_execution_ids(
          qualified_producer_executions)
      if event_lib.validate_output_event(ev, output_key)
  ]

  # Gets the candidate artifacts from output events.
  candidate_artifacts = metadata_handler.store.get_artifacts_by_id(
      list(set(ev.artifact_id for ev in qualified_output_events)))
  # Filters the artifacts that have the right artifact type and state.
  qualified_artifacts = [
      a for a in candidate_artifacts
      if a.type_id == artifact_type.id and
      a.state == metadata_store_pb2.Artifact.LIVE
  ]
  return [
      artifact_utils.deserialize_artifact(artifact_type, a)
      for a in qualified_artifacts
  ]
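# A hypothetical call site for get_qualified_artifacts above; the sqlite
# path, context type name and artifact type are illustrative assumptions,
# not taken from the original code.
connection_config = metadata.sqlite_metadata_connection_config('/tmp/mlmd.db')
with metadata.Metadata(connection_config) as metadata_handler:
  pipeline_contexts = metadata_handler.store.get_contexts_by_type('pipeline')
  examples = get_qualified_artifacts(
      metadata_handler=metadata_handler,
      contexts=pipeline_contexts,
      artifact_type=metadata_store_pb2.ArtifactType(name='Examples'),
      output_key='examples')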
def generate_output_artifacts(
    self, execution_id: int) -> Dict[Text, List[types.Artifact]]:
  """Generates output artifacts given execution_id."""
  output_artifacts = collections.defaultdict(list)
  for key, output_spec in self._pipeline_node.outputs.outputs.items():
    artifact = artifact_utils.deserialize_artifact(
        output_spec.artifact_spec.type)
    artifact.uri = os.path.join(self._node_dir, key, str(execution_id))
    if isinstance(artifact, ValueArtifact):
      artifact.uri = os.path.join(artifact.uri, _VALUE_ARTIFACT_FILE_NAME)

    # artifact.name will contain the set of information to track its creation
    # and is guaranteed to be idempotent across retries of a node.
    artifact_name = f'{self._pipeline_info.id}'
    if self._execution_mode == pipeline_pb2.Pipeline.SYNC:
      artifact_name = f'{artifact_name}:{self._pipeline_run_id}'
    # The index of this artifact. Since we only have one artifact per output
    # for now, it is always 0.
    # TODO(b/162331170): Update the "0" to the actual index.
    artifact_name = (
        f'{artifact_name}:{self._pipeline_node.node_info.id}:{key}:0')
    artifact.name = artifact_name
    _attach_artifact_properties(output_spec.artifact_spec, artifact)

    logging.debug('Creating output artifact uri %s', artifact.uri)
    output_artifacts[key].append(artifact)
  return output_artifacts
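# For illustration (hypothetical values): with pipeline id 'my_pipeline',
# run id 'run-001', node id 'trainer' and output key 'model', the generated
# artifact name in SYNC mode would be 'my_pipeline:run-001:trainer:model:0';
# in ASYNC mode the run id segment is omitted: 'my_pipeline:trainer:model:0'.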
def _build_input_dict(
    self,
    pipeline_info: data_types.PipelineInfo,
    input_channels: Mapping[str, types.BaseChannel],
) -> Dict[str, List[types.Artifact]]:
  pipeline_context = self._metadata_handler.get_pipeline_context(
      pipeline_info)
  if pipeline_context is None:
    raise RuntimeError(f'Pipeline context absent for {pipeline_info}.')

  result = {}
  for key, c in input_channels.items():
    artifacts_by_id = {}  # Deduplicate by ID.
    for channel in channel_utils.get_individual_channels(c):
      artifact_and_types = self._metadata_handler.get_qualified_artifacts(
          contexts=[pipeline_context],
          type_name=channel.type_name,
          producer_component_id=channel.producer_component_id,
          output_key=channel.output_key)
      artifacts = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in artifact_and_types
      ]
      artifacts_by_id.update({a.id: a for a in artifacts})
    result[key] = list(artifacts_by_id.values())
  return result
def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

  candidate_dict = {}
  for k, c in source_channels.items():
    candidate_artifacts = metadata_handler.get_qualified_artifacts(
        contexts=[pipeline_context],
        type_name=c.type_name,
        producer_component_id=c.producer_component_id,
        output_key=c.output_key)
    candidate_dict[k] = [
        artifact_utils.deserialize_artifact(a.type, a.artifact)
        for a in candidate_artifacts
    ]

  resolved_dict = self._resolve(candidate_dict)
  resolve_state_dict = {
      k: len(artifact_list) >= self._desired_num_of_artifact
      for k, artifact_list in resolved_dict.items()
  }
  return base_resolver.ResolveResult(
      per_key_resolve_result=resolved_dict,
      per_key_resolve_state=resolve_state_dict)
def search_artifacts(self, artifact_name: Text,
                     pipeline_info: data_types.PipelineInfo,
                     producer_component_id: Text) -> List[Artifact]:
  """Searches for artifacts that match the given info.

  Args:
    artifact_name: the name of the artifact set by the producer component.
      The name is logged in both the artifact and the events when the
      execution is published.
    pipeline_info: the information of the current pipeline.
    producer_component_id: the id of the component that produces the
      artifact.

  Returns:
    A list of Artifacts that match the given info.

  Raises:
    RuntimeError: when no matching execution is found given producer info.
  """
  producer_execution = None
  matching_artifact_ids = set()
  # TODO(ruoyu): We need to revisit this when adding support for async
  # execution.
  context = self.get_pipeline_run_context(pipeline_info)
  if context is None:
    raise RuntimeError('Pipeline run context for %s does not exist' %
                       pipeline_info)
  for execution in self.store.get_executions_by_context(context.id):
    if execution.properties[
        'component_id'].string_value == producer_component_id:
      producer_execution = execution
      break
  if not producer_execution:
    raise RuntimeError(
        'Cannot find matching execution with pipeline name %s, '
        'run id %s and component id %s' %
        (pipeline_info.pipeline_name, pipeline_info.run_id,
         producer_component_id))
  for event in self.store.get_events_by_execution_ids(
      [producer_execution.id]):
    if (event.type == metadata_store_pb2.Event.OUTPUT and
        event.path.steps[0].key == artifact_name):
      matching_artifact_ids.add(event.artifact_id)

  # Get relevant artifacts along with their types.
  artifacts_by_id = self.store.get_artifacts_by_id(
      list(matching_artifact_ids))
  matching_artifact_type_ids = list(set(a.type_id for a in artifacts_by_id))
  matching_artifact_types = self.store.get_artifact_types_by_id(
      matching_artifact_type_ids)
  artifact_types = dict(
      zip(matching_artifact_type_ids, matching_artifact_types))

  result_artifacts = []
  for a in artifacts_by_id:
    tfx_artifact = artifact_utils.deserialize_artifact(
        artifact_types[a.type_id], a)
    result_artifacts.append(tfx_artifact)
  return result_artifacts
def _build_artifact_dict(proto_dict):
  """Builds ExecutionInfo input/output artifact dicts."""
  artifact_dict = {}
  for k, v in proto_dict.items():
    artifact_dict[k] = []
    for artifact_struct in v.elements:
      if not artifact_struct.HasField('artifact'):
        raise RuntimeError('Only the artifact oneof field is supported.')
      artifact_and_type = artifact_struct.artifact
      artifact_dict[k].append(
          artifact_utils.deserialize_artifact(artifact_and_type.type,
                                              artifact_and_type.artifact))
  return artifact_dict
def get_pipeline_outputs(
    metadata_connection_config: Optional[metadata_store_pb2.ConnectionConfig],
    pipeline_info: data_types.PipelineInfo
) -> Dict[Text, Dict[Text, Dict[int, types.Artifact]]]:
  """Returns a dictionary of pipeline output artifacts for every component.

  Args:
    metadata_connection_config: connection configuration to MLMD.
    pipeline_info: pipeline info from orchestration.

  Returns:
    A dictionary holding the output artifacts of each component, keyed by
    component id.
  """
  output_map = {}
  with metadata.Metadata(metadata_connection_config) as m:
    context = m.get_pipeline_run_context(pipeline_info)
    if context is None:
      raise ValueError(
          'No context found with pipeline_info: {}'.format(pipeline_info))
    executions = m.store.get_executions_by_context(context.id)
    for execution in executions:
      component_id = execution.properties['component_id'].string_value
      output_dict = {}
      for event in m.store.get_events_by_execution_ids([execution.id]):
        if event.type == metadata_store_pb2.Event.OUTPUT:
          steps = event.path.steps
          if not steps or not steps[0].HasField('key'):
            raise ValueError('Artifact key is not recorded in the MLMD.')
          key = steps[0].key
          artifacts = m.store.get_artifacts_by_id([event.artifact_id])
          if key not in output_dict:
            output_dict[key] = {}
          for pb_artifact in artifacts:
            if len(steps) < 2 or not steps[1].HasField('index'):
              raise ValueError('Artifact index is not recorded in the MLMD.')
            artifact_index = steps[1].index
            if artifact_index in output_dict[key]:
              raise ValueError('Artifact already in output_dict')
            [artifact_type] = m.store.get_artifact_types_by_id(
                [pb_artifact.type_id])
            artifact = artifact_utils.deserialize_artifact(
                artifact_type, pb_artifact)
            output_dict[key][artifact_index] = artifact
      output_map[component_id] = output_dict
  return output_map
def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  artifacts_dict = {}
  resolve_state_dict = {}
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)
  for k, c in source_channels.items():
    candidate_artifacts = metadata_handler.get_qualified_artifacts(
        contexts=[pipeline_context],
        type_name=c.type_name,
        producer_component_id=c.producer_component_id,
        output_key=c.output_key)
    previous_artifacts = sorted(
        candidate_artifacts, key=lambda a: a.artifact.id, reverse=True)
    if len(previous_artifacts) >= self._desired_num_of_artifact:
      artifacts_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in previous_artifacts[:self._desired_num_of_artifact]
      ]
      resolve_state_dict[k] = True
    else:
      artifacts_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in previous_artifacts
      ]
      resolve_state_dict[k] = False
  return base_resolver.ResolveResult(
      per_key_resolve_result=artifacts_dict,
      per_key_resolve_state=resolve_state_dict)
def _build_artifact_dict(
    proto_dict: Mapping[str, metadata_store_service_pb2.ArtifactStructList]
) -> Dict[str, List[types.Artifact]]:
  """Builds ExecutionInfo input/output artifact dicts."""
  result = {}
  for k, v in proto_dict.items():
    result[k] = []
    for artifact_struct in v.elements:
      if not artifact_struct.HasField('artifact'):
        raise RuntimeError('Only the artifact oneof field is supported.')
      artifact_and_type = artifact_struct.artifact
      result[k].append(
          artifact_utils.deserialize_artifact(artifact_and_type.type,
                                              artifact_and_type.artifact))
  return result
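# A minimal sketch of the inverse direction (an assumed helper, not part of
# the original module): packing a list of TFX artifacts back into the
# ArtifactStructList form that _build_artifact_dict above consumes.
def _build_artifact_struct_list(
    artifacts: List[types.Artifact]
) -> metadata_store_service_pb2.ArtifactStructList:
  struct_list = metadata_store_service_pb2.ArtifactStructList()
  for a in artifacts:
    element = struct_list.elements.add()
    # Populate the artifact oneof with both the artifact payload and its
    # type, mirroring what deserialize_artifact expects on the way back.
    element.artifact.artifact.CopyFrom(a.mlmd_artifact)
    element.artifact.type.CopyFrom(a.artifact_type)
  return struct_list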
def get_pipeline_outputs(
    metadata_connection_config: Optional[metadata_store_pb2.ConnectionConfig],
    pipeline_name: str
) -> Dict[Text, Dict[Text, Dict[int, types.Artifact]]]:
  """Returns a dictionary of pipeline output artifacts for every component.

  Args:
    metadata_connection_config: connection configuration to MLMD.
    pipeline_name: Name of the pipeline.

  Returns:
    A dictionary holding the output artifacts of each component, keyed by
    component id.
  """
  output_map = {}
  with metadata.Metadata(metadata_connection_config) as m:
    executions = pipeline_recorder_utils.get_latest_executions(
        m, pipeline_name)
    for execution in executions:
      component_id = pipeline_recorder_utils.get_component_id_from_execution(
          m, execution)
      output_dict = {}
      for event in m.store.get_events_by_execution_ids([execution.id]):
        if event.type == metadata_store_pb2.Event.OUTPUT:
          steps = event.path.steps
          if not steps or not steps[0].HasField('key'):
            raise ValueError('Artifact key is not recorded in the MLMD.')
          key = steps[0].key
          artifacts = m.store.get_artifacts_by_id([event.artifact_id])
          if key not in output_dict:
            output_dict[key] = {}
          for pb_artifact in artifacts:
            if len(steps) < 2 or not steps[1].HasField('index'):
              raise ValueError('Artifact index is not recorded in the MLMD.')
            artifact_index = steps[1].index
            if artifact_index in output_dict[key]:
              raise ValueError('Artifact already in output_dict')
            [artifact_type] = m.store.get_artifact_types_by_id(
                [pb_artifact.type_id])
            artifact = artifact_utils.deserialize_artifact(
                artifact_type, pb_artifact)
            output_dict[key][artifact_index] = artifact
      output_map[component_id] = output_dict
  return output_map
def _get_outputs_of_execution(
    metadata_handler: metadata.Metadata,
    execution_id: int) -> Optional[Dict[Text, List[types.Artifact]]]:
  """Fetches outputs produced by a historical execution.

  Args:
    metadata_handler: A handler to access MLMD store.
    execution_id: The id of the execution that produced the outputs.

  Returns:
    A dict of key -> List[Artifact] as the result if qualified outputs are
    found. Otherwise returns None.
  """
  result = collections.defaultdict(list)

  output_events = [
      event for event in metadata_handler.store.get_events_by_execution_ids(
          [execution_id])
      if event.type == metadata_store_pb2.Event.OUTPUT
  ]
  cached_output_artifacts = metadata_handler.store.get_artifacts_by_id(
      [e.artifact_id for e in output_events])
  for artifact in cached_output_artifacts:
    # A non-live artifact means a partial result and will not be used.
    if artifact.state != metadata_store_pb2.Artifact.LIVE:
      return None

  artifact_types = metadata_handler.store.get_artifact_types_by_id(
      [a.type_id for a in cached_output_artifacts])

  for event, mlmd_artifact, artifact_type in zip(output_events,
                                                 cached_output_artifacts,
                                                 artifact_types):
    key = event.path.steps[0].key
    tfx_artifact = artifact_utils.deserialize_artifact(
        artifact_type, mlmd_artifact)
    result[key].append(tfx_artifact)
  return result
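# A hedged sketch of how _get_outputs_of_execution above might be composed
# with a cache lookup; `metadata_handler` and `cached_execution` are assumed
# to come from the surrounding orchestration code and are not shown here.
cached_outputs = _get_outputs_of_execution(
    metadata_handler=metadata_handler, execution_id=cached_execution.id)
if cached_outputs is None:
  # Partial (non-LIVE) outputs: the cached result cannot be reused, so the
  # caller falls back to actually executing the node.
  ...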
def _parse_raw_artifact(dict_data: Dict[Text, Any]) -> artifact.Artifact:
  """Parses a json-serialized artifact without an artifact_type."""
  # This parser can only preserve what's inside the artifact pb message.
  artifact_pb = metadata_store_pb2.Artifact()
  # TODO(b/152444458): For compatibility, current TFX serialization assumes
  # there is no type field in the Artifact pb message.
  type_name = dict_data.pop('type')
  json_format.Parse(json.dumps(dict_data), artifact_pb)

  # Make an ArtifactType pb according to artifact_pb.
  type_pb = metadata_store_pb2.ArtifactType()
  type_pb.name = type_name
  for k, v in artifact_pb.properties.items():
    if v.HasField('int_value'):
      type_pb.properties[k] = metadata_store_pb2.PropertyType.INT
    elif v.HasField('string_value'):
      type_pb.properties[k] = metadata_store_pb2.PropertyType.STRING
    elif v.HasField('double_value'):
      type_pb.properties[k] = metadata_store_pb2.PropertyType.DOUBLE
    else:
      raise ValueError('Unrecognized type encountered at field %s' % k)

  result = artifact_utils.deserialize_artifact(type_pb, artifact_pb)
  return result
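# The dict shape _parse_raw_artifact above expects: a json-decoded Artifact
# proto plus an extra 'type' key holding the type name. All values here are
# illustrative.
raw = {
    'type': 'MyCustomType',
    'id': 123,
    'uri': '/tmp/my_artifact',
    'properties': {
        'span': {'intValue': '42'},
    },
}
parsed = _parse_raw_artifact(raw)
assert parsed.type_name == 'MyCustomType'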
def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  # First, checks whether we have exactly Model and ModelBlessing Channels.
  model_channel_key = None
  model_blessing_channel_key = None
  assert len(source_channels) == 2, 'Expecting 2 input Channels'
  for k, c in source_channels.items():
    if issubclass(c.type, standard_artifacts.Model):
      model_channel_key = k
    elif issubclass(c.type, standard_artifacts.ModelBlessing):
      model_blessing_channel_key = k
    else:
      raise RuntimeError('Only expecting Model or ModelBlessing, got %s' %
                         c.type)
  assert model_channel_key is not None, 'Expecting Model as input'
  assert model_blessing_channel_key is not None, (
      'Expecting ModelBlessing as input')
  model_channel = source_channels[model_channel_key]
  model_blessing_channel = source_channels[model_blessing_channel_key]

  # Gets the pipeline context as the artifact search space.
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

  # Gets all models in the search space and sorts them in reverse order by
  # id.
  all_models = metadata_handler.get_qualified_artifacts(
      contexts=[pipeline_context],
      type_name=model_channel.type_name,
      producer_component_id=model_channel.producer_component_id,
      output_key=model_channel.output_key)
  all_models.sort(key=lambda a: a.artifact.id, reverse=True)
  # Gets all ModelBlessing artifacts in the search space.
  all_model_blessings = metadata_handler.get_qualified_artifacts(
      contexts=[pipeline_context],
      type_name=model_blessing_channel.type_name,
      producer_component_id=model_blessing_channel.producer_component_id,
      output_key=model_blessing_channel.output_key)
  # Makes a dict of {model_id: ModelBlessing artifact} for blessed models.
  all_blessed_model_ids = dict(  # pylint: disable=g-complex-comprehension
      (a.artifact.custom_properties[
          evaluator.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY].int_value, a)
      for a in all_model_blessings
      if a.artifact.custom_properties[
          evaluator.ARTIFACT_PROPERTY_BLESSED_KEY].int_value == 1)

  artifacts_dict = {model_channel_key: [], model_blessing_channel_key: []}
  resolve_state_dict = {
      model_channel_key: False,
      model_blessing_channel_key: False
  }
  # Iterates over all models; if a model is blessed, sets it as the result.
  # As the model list was sorted, this is guaranteed to pick the latest
  # blessed model.
  for model in all_models:
    if model.artifact.id in all_blessed_model_ids:
      artifacts_dict[model_channel_key] = [
          artifact_utils.deserialize_artifact(model.type, model.artifact)
      ]
      model_blessing = all_blessed_model_ids[model.artifact.id]
      artifacts_dict[model_blessing_channel_key] = [
          artifact_utils.deserialize_artifact(model_blessing.type,
                                              model_blessing.artifact)
      ]
      resolve_state_dict[model_channel_key] = True
      resolve_state_dict[model_blessing_channel_key] = True
      break

  return base_resolver.ResolveResult(
      per_key_resolve_result=artifacts_dict,
      per_key_resolve_state=resolve_state_dict)
def get_artifacts_dict(
    metadata_handler: metadata.Metadata, execution_id: int,
    event_type: 'metadata_store_pb2.Event.Type'
) -> Dict[Text, List[types.Artifact]]:
  """Returns a map from key to an ordered list of artifacts for the given execution id.

  The dict is constructed purely from information stored in MLMD for the
  execution given by `execution_id`. The "key" is the tag associated with the
  `InputSpec` or `OutputSpec` in the pipeline IR.

  Args:
    metadata_handler: A handler to access MLMD.
    execution_id: Id of the execution for which to get artifacts.
    event_type: Event type to filter by.

  Returns:
    A dict mapping key to an ordered list of artifacts.

  Raises:
    ValueError: If the events are badly formed and the correct ordering of
      artifacts cannot be determined, or if not all the artifacts could be
      fetched from MLMD.
  """
  events = metadata_handler.store.get_events_by_execution_ids([execution_id])

  # Create a map from "key" to a list of (index, artifact_id) tuples.
  indexed_artifact_ids_dict = collections.defaultdict(list)
  for event in events:
    if event.type != event_type:
      continue
    key, index = event_lib.get_artifact_path(event)
    artifact_id = event.artifact_id
    indexed_artifact_ids_dict[key].append((index, artifact_id))

  # Create a map from "key" to an ordered list of artifact ids.
  artifact_ids_dict = {}
  for key, indexed_artifact_ids in indexed_artifact_ids_dict.items():
    ordered_artifact_ids = sorted(indexed_artifact_ids, key=lambda x: x[0])
    # There shouldn't be any missing or duplicate indices.
    indices = [idx for idx, _ in ordered_artifact_ids]
    if indices != list(range(0, len(indices))):
      raise ValueError(
          f'Cannot construct artifact ids dict due to missing or duplicate '
          f'indices: {indexed_artifact_ids_dict}')
    artifact_ids_dict[key] = [aid for _, aid in ordered_artifact_ids]

  # Fetch all the relevant artifacts.
  all_artifact_ids = list(itertools.chain(*artifact_ids_dict.values()))
  mlmd_artifacts = metadata_handler.store.get_artifacts_by_id(
      all_artifact_ids)
  if len(all_artifact_ids) != len(mlmd_artifacts):
    raise ValueError('Could not find all mlmd artifacts for ids: {}'.format(
        ', '.join(str(i) for i in all_artifact_ids)))

  # Fetch artifact types and create a map keyed by artifact type id.
  artifact_type_ids = set(a.type_id for a in mlmd_artifacts)
  artifact_types = metadata_handler.store.get_artifact_types_by_id(
      artifact_type_ids)
  artifact_types_by_id = {a.id: a for a in artifact_types}

  # Create a map from artifact id to `types.Artifact` instances.
  artifacts_by_id = {
      aid: artifact_utils.deserialize_artifact(
          artifact_types_by_id[a.type_id], a)
      for aid, a in zip(all_artifact_ids, mlmd_artifacts)
  }

  # Create a map from "key" to an ordered list of `types.Artifact` to be
  # returned. The ordering of artifacts is in accordance with their "index"
  # derived from the events above.
  result = collections.defaultdict(list)
  for key, artifact_ids in artifact_ids_dict.items():
    for artifact_id in artifact_ids:
      result[key].append(artifacts_by_id[artifact_id])
  return result
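# A hypothetical call retrieving the input artifacts recorded for an
# execution via get_artifacts_dict above; `metadata_handler` and `execution`
# are assumed to come from the surrounding orchestration code.
input_artifacts = get_artifacts_dict(
    metadata_handler=metadata_handler,
    execution_id=execution.id,
    event_type=metadata_store_pb2.Event.INPUT)
for key, artifact_list in input_artifacts.items():
  logging.info('Input %r resolved to %s', key,
               [a.uri for a in artifact_list])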
def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[str, types.Channel],
) -> resolver.ResolveResult:
  # First, checks whether we have exactly Model and ModelBlessing Channels.
  model_channel_key = None
  model_blessing_channel_key = None
  assert len(source_channels) == 2, 'Expecting 2 input Channels'
  for k, c in source_channels.items():
    if issubclass(c.type, standard_artifacts.Model):
      model_channel_key = k
    elif issubclass(c.type, standard_artifacts.ModelBlessing):
      model_blessing_channel_key = k
    else:
      raise RuntimeError('Only expecting Model or ModelBlessing, got %s' %
                         c.type)
  assert model_channel_key is not None, 'Expecting Model as input'
  assert model_blessing_channel_key is not None, (
      'Expecting ModelBlessing as input')
  model_channel = source_channels[model_channel_key]
  model_blessing_channel = source_channels[model_blessing_channel_key]

  # Gets the pipeline context as the artifact search space.
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

  candidate_dict = {}
  # Gets all models in the search space.
  all_models = metadata_handler.get_qualified_artifacts(
      contexts=[pipeline_context],
      type_name=model_channel.type_name,
      producer_component_id=model_channel.producer_component_id,
      output_key=model_channel.output_key)
  candidate_dict[model_channel_key] = [
      artifact_utils.deserialize_artifact(a.type, a.artifact)
      for a in all_models
  ]
  # Gets all ModelBlessing artifacts in the search space.
  all_model_blessings = metadata_handler.get_qualified_artifacts(
      contexts=[pipeline_context],
      type_name=model_blessing_channel.type_name,
      producer_component_id=model_blessing_channel.producer_component_id,
      output_key=model_blessing_channel.output_key)
  candidate_dict[model_blessing_channel_key] = [
      artifact_utils.deserialize_artifact(a.type, a.artifact)
      for a in all_model_blessings
  ]

  resolved_dict = self._resolve(candidate_dict, model_channel_key,
                                model_blessing_channel_key)
  resolve_state_dict = {
      k: bool(artifact_list) for k, artifact_list in resolved_dict.items()
  }
  return resolver.ResolveResult(
      per_key_resolve_result=resolved_dict,
      per_key_resolve_state=resolve_state_dict)