def get_executions(
    metadata_handler: metadata.Metadata,
    node: pipeline_pb2.PipelineNode) -> List[metadata_store_pb2.Execution]:
  """Returns all executions for the given pipeline node.

  This finds all executions having the same set of contexts as the pipeline
  node.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to obtain executions.

  Returns:
    List of executions for the given node in MLMD db.
  """
  # Get all the contexts associated with the node.
  contexts = []
  for context_spec in node.contexts.contexts:
    context = metadata_handler.store.get_context_by_type_and_name(
        context_spec.type.name, common_utils.get_value(context_spec.name))
    if context is None:
      # If no context is registered, it's certain that there is no
      # associated execution for the node.
      return []
    contexts.append(context)
  return execution_lib.get_executions_associated_with_all_contexts(
      metadata_handler, contexts)
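
# Usage sketch for get_executions: a minimal, hedged example, not part of the
# original source. The SQLite path and compiled-IR file path are assumptions
# for illustration; any MLMD connection config and pipeline IR would do.
from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import metadata
from tfx.proto.orchestration import pipeline_pb2

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = '/tmp/mlmd.sqlite'  # assumed path

pipeline = pipeline_pb2.Pipeline()
with open('/tmp/pipeline_ir.pb', 'rb') as f:  # assumed compiled-IR location
  pipeline.ParseFromString(f.read())
node = pipeline.nodes[0].pipeline_node  # first node in the pipeline IR

with metadata.Metadata(connection_config=connection_config) as m:
  for execution in get_executions(m, node):
    print(execution.id, execution.last_known_state)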
def get_qualified_artifacts(
    metadata_handler: metadata.Metadata,
    contexts: Iterable[metadata_store_pb2.Context],
    artifact_type: metadata_store_pb2.ArtifactType,
    output_key: Optional[str] = None,
) -> List[types.Artifact]:
  """Gets qualified artifacts that have the right producer info.

  Args:
    metadata_handler: A metadata handler to access MLMD store.
    contexts: Context constraints to filter artifacts.
    artifact_type: Type constraint to filter artifacts.
    output_key: Output key constraint to filter artifacts.

  Returns:
    A list of qualified TFX Artifacts.
  """
  # We expect to have at least one context for input resolution.
  assert contexts, 'Must have at least one context.'

  try:
    artifact_type_name = artifact_type.name
    artifact_type = metadata_handler.store.get_artifact_type(
        artifact_type_name)
  except mlmd.errors.NotFoundError:
    logging.warning('Artifact type %s is not found in MLMD.',
                    artifact_type.name)
    artifact_type = None

  if not artifact_type:
    return []

  executions_within_context = (
      execution_lib.get_executions_associated_with_all_contexts(
          metadata_handler, contexts))
  # Filters out non-success executions.
  qualified_producer_executions = [
      e.id
      for e in executions_within_context
      if execution_lib.is_execution_successful(e)
  ]
  # Gets the output events that have the matched output key.
  qualified_output_events = [
      ev for ev in metadata_handler.store.get_events_by_execution_ids(
          qualified_producer_executions)
      if event_lib.validate_output_event(ev, output_key)
  ]
  # Gets the candidate artifacts from output events.
  candidate_artifacts = metadata_handler.store.get_artifacts_by_id(
      list(set(ev.artifact_id for ev in qualified_output_events)))
  # Filters the artifacts that have the right artifact type and state.
  qualified_artifacts = [
      a for a in candidate_artifacts
      if a.type_id == artifact_type.id and
      a.state == metadata_store_pb2.Artifact.LIVE
  ]
  return [
      artifact_utils.deserialize_artifact(artifact_type, a)
      for a in qualified_artifacts
  ]
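
# Usage sketch for get_qualified_artifacts: a hedged example of resolving
# LIVE 'Examples' artifacts produced under a pipeline context. The context
# type/name and the output key are placeholder assumptions; real values
# depend on how the pipeline registered them. Reuses the assumed
# connection_config from the sketch above.
with metadata.Metadata(connection_config=connection_config) as m:
  pipeline_context = m.store.get_context_by_type_and_name(
      'pipeline', 'my_pipeline')  # assumed context type and name
  assert pipeline_context is not None
  examples = get_qualified_artifacts(
      m,
      contexts=[pipeline_context],
      artifact_type=metadata_store_pb2.ArtifactType(name='Examples'),
      output_key='examples')  # assumed output key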
def _get_successful_executions(
    self, node_id: str, run_id: str) -> List[metadata_store_pb2.Execution]:
  """Gets all successful Executions of a given node in a given pipeline run.

  Args:
    node_id: The node whose Executions to query.
    run_id: The pipeline run id to query the Executions from.

  Returns:
    All successful executions for that node at that run_id.

  Raises:
    LookupError: If no successful Execution was found.
  """
  node_context = self._get_node_context(node_id)
  base_run_context = self._get_pipeline_run_context(run_id)
  all_associated_executions = (
      execution_lib.get_executions_associated_with_all_contexts(
          self._mlmd,
          contexts=[
              node_context, base_run_context, self._pipeline_context
          ]))
  prev_successful_executions = [
      e for e in all_associated_executions
      if execution_lib.is_execution_successful(e)
  ]
  if not prev_successful_executions:
    raise LookupError(
        f'No previous successful executions found for node_id {node_id} in '
        f'pipeline_run {run_id}')

  return execution_lib.sort_executions_newest_to_oldest(
      prev_successful_executions)
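
# Illustration of the ordering contract above: the method returns executions
# newest-first. This standalone sketch assumes execution_lib orders by MLMD's
# create_time_since_epoch field; the real helper may use different tie-breaks.
from ml_metadata.proto import metadata_store_pb2

def _sort_newest_to_oldest(executions):
  return sorted(
      executions, key=lambda e: e.create_time_since_epoch, reverse=True)

older = metadata_store_pb2.Execution(id=1, create_time_since_epoch=100)
newer = metadata_store_pb2.Execution(id=2, create_time_since_epoch=200)
assert [e.id for e in _sort_newest_to_oldest([older, newer])] == [2, 1]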
def testGetExecutionsAssociatedWithAllContexts(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    self.assertLen(contexts, 2)

    # Create 2 executions and associate with one context each.
    execution1 = execution_lib.prepare_execution(
        m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
        metadata_store_pb2.Execution.RUNNING)
    execution1 = execution_lib.put_execution(m, execution1, [contexts[0]])
    execution2 = execution_lib.prepare_execution(
        m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
        metadata_store_pb2.Execution.COMPLETE)
    execution2 = execution_lib.put_execution(m, execution2, [contexts[1]])

    # Create another execution and associate with both contexts.
    execution3 = execution_lib.prepare_execution(
        m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
        metadata_store_pb2.Execution.NEW)
    execution3 = execution_lib.put_execution(m, execution3, contexts)

    # Verify that the right executions are returned.
    with self.subTest(for_contexts=(0,)):
      executions = execution_lib.get_executions_associated_with_all_contexts(
          m, [contexts[0]])
      self.assertCountEqual([execution1.id, execution3.id],
                            [e.id for e in executions])
    with self.subTest(for_contexts=(1,)):
      executions = execution_lib.get_executions_associated_with_all_contexts(
          m, [contexts[1]])
      self.assertCountEqual([execution2.id, execution3.id],
                            [e.id for e in executions])
    with self.subTest(for_contexts=(0, 1)):
      executions = execution_lib.get_executions_associated_with_all_contexts(
          m, contexts)
      self.assertCountEqual([execution3.id], [e.id for e in executions])
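
# What the test above verifies: 'associated with all contexts' means set
# intersection, not union. A hedged sketch of that semantics built only from
# public MetadataStore calls; it is an illustration, not the library's actual
# implementation.
def _executions_for_all_contexts(store, contexts):
  ids = None
  for context in contexts:
    ctx_ids = {e.id for e in store.get_executions_by_context(context.id)}
    ids = ctx_ids if ids is None else ids & ctx_ids
  return store.get_executions_by_id(sorted(ids)) if ids else []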
def _cache_and_publish(self,
                       existing_execution: metadata_store_pb2.Execution):
  """Publishes a cached execution reusing existing_execution's outputs."""
  cached_execution_contexts = self._get_cached_execution_contexts(
      existing_execution)
  # Check if there are any previous attempts to cache and publish.
  prev_cache_executions = (
      execution_lib.get_executions_associated_with_all_contexts(
          self._mlmd, contexts=cached_execution_contexts))
  if not prev_cache_executions:
    new_execution = execution_publish_utils.register_execution(
        self._mlmd,
        execution_type=metadata_store_pb2.ExecutionType(
            id=existing_execution.type_id),
        contexts=cached_execution_contexts)
  else:
    if len(prev_cache_executions) > 1:
      logging.warning(
          'More than one previous cache execution seen when attempting '
          'reuse_node_outputs: %s', prev_cache_executions)

    if (prev_cache_executions[-1].last_known_state ==
        metadata_store_pb2.Execution.CACHED):
      return
    else:
      new_execution = prev_cache_executions[-1]

  output_artifacts = execution_lib.get_artifacts_dict(
      self._mlmd,
      existing_execution.id,
      event_types=list(event_lib.VALID_OUTPUT_EVENT_TYPES))
  execution_publish_utils.publish_cached_execution(
      self._mlmd,
      contexts=cached_execution_contexts,
      execution_id=new_execution.id,
      output_artifacts=output_artifacts)
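
# The early return above makes _cache_and_publish idempotent: a retry that
# finds the newest prior attempt already in CACHED state publishes nothing.
# A standalone, runnable sketch of that guard alone, extracted here for
# illustration.
from ml_metadata.proto import metadata_store_pb2

def _already_published(prev_cache_executions):
  return bool(prev_cache_executions) and (
      prev_cache_executions[-1].last_known_state ==
      metadata_store_pb2.Execution.CACHED)

done = metadata_store_pb2.Execution(
    last_known_state=metadata_store_pb2.Execution.CACHED)
assert _already_published([done])
assert not _already_published([])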