def get_qualified_artifacts(
    metadata_handler: metadata.Metadata,
    contexts: Iterable[metadata_store_pb2.Context],
    artifact_type: metadata_store_pb2.ArtifactType,
    output_key: Optional[str] = None,
) -> List[types.Artifact]:
    """Gets qualified artifacts that have the right producer info.

    An artifact qualifies when it is an output (as judged by
    `event_lib.validate_output_event` against `output_key`) of a successful
    execution associated with every context in `contexts`. It must also be of
    `artifact_type` as registered in MLMD and be in the LIVE state.

    Args:
      metadata_handler: A metadata handler to access MLMD store.
      contexts: Context constraints to filter artifacts. Must be non-empty.
      artifact_type: Type constraint to filter artifacts. Only its `name` is
        used; the full type is re-read from MLMD.
      output_key: Output key constraint to filter artifacts.

    Returns:
      A list of qualified TFX Artifacts. Empty if `artifact_type` is not
      registered in MLMD.
    """
    # Input resolution always needs at least one context to scope the search.
    assert contexts, 'Must have at least one context.'

    type_name = artifact_type.name
    try:
        stored_type = metadata_handler.store.get_artifact_type(type_name)
    except mlmd.errors.NotFoundError:
        logging.warning('Artifact type %s is not found in MLMD.', type_name)
        stored_type = None
    if not stored_type:
        return []

    associated_executions = (
        execution_lib.get_executions_associated_with_all_contexts(
            metadata_handler, contexts))

    # Only executions that finished successfully count as producers.
    producer_ids = [
        execution.id
        for execution in associated_executions
        if execution_lib.is_execution_successful(execution)
    ]

    # Keep just the output events that match the requested output key.
    producer_events = metadata_handler.store.get_events_by_execution_ids(
        producer_ids)
    matching_events = [
        event for event in producer_events
        if event_lib.validate_output_event(event, output_key)
    ]

    # Several events may point at the same artifact; fetch each one once.
    candidate_ids = list(set(event.artifact_id for event in matching_events))
    candidates = metadata_handler.store.get_artifacts_by_id(candidate_ids)

    # Deserialize only LIVE artifacts of the requested type.
    return [
        artifact_utils.deserialize_artifact(stored_type, artifact)
        for artifact in candidates
        if (artifact.type_id == stored_type.id and
            artifact.state == metadata_store_pb2.Artifact.LIVE)
    ]
def _get_successful_executions(
        self, node_id: str, run_id: str) -> List[metadata_store_pb2.Execution]:
    """Gets all successful Executions of a given node in a given pipeline run.

    Only executions associated with all three of the node context, the
    pipeline-run context for `run_id`, and `self._pipeline_context` are
    considered.

    Args:
      node_id: The node whose Executions to query.
      run_id: The pipeline run id to query the Executions from.

    Returns:
      All successful executions for that node at that run_id, ordered by
      `execution_lib.sort_executions_newest_to_oldest`.

    Raises:
      LookupError: If no successful Execution was found.
    """
    scoping_contexts = [
        self._get_node_context(node_id),
        self._get_pipeline_run_context(run_id),
        self._pipeline_context,
    ]
    candidates = execution_lib.get_executions_associated_with_all_contexts(
        self._mlmd, contexts=scoping_contexts)

    successful = list(filter(execution_lib.is_execution_successful, candidates))
    if not successful:
        raise LookupError(
            f'No previous successful executions found for node_id {node_id} in '
            f'pipeline_run {run_id}')
    return execution_lib.sort_executions_newest_to_oldest(successful)
def get_latest_successful_execution(
        executions: Iterable[metadata_store_pb2.Execution]
) -> Optional[metadata_store_pb2.Execution]:
    """Returns the latest successful execution or `None` if no successful executions exist.

    Args:
      executions: Candidate executions; only successful ones are considered.

    Returns:
      The execution that `get_latest_execution` picks from the successful
      subset of `executions`.
    """
    # Narrow to successful executions first, then let the shared helper decide
    # which one is the latest.
    successful = list(filter(execution_lib.is_execution_successful, executions))
    return get_latest_execution(successful)
def test_cached_execution(self):
    """Tests that cached execution is used if one is available.

    Runs ExampleGen and StatsGen once in the original pipeline. It then builds
    a second pipeline with a new pipeline_run_id and fakes a cached ExampleGen
    execution for it. Finally it checks that task generation marks StatsGen as
    CACHED (no ExecNodeTask) and moves on to SchemaGen.
    """
    # Fake ExampleGen run.
    example_gen_exec = otu.fake_example_gen_run(self._mlmd_connection,
                                                self._example_gen, 1, 1)

    # Invoking generator should produce an ExecNodeTask for StatsGen.
    [stats_gen_task] = self._generate_and_test(
        False,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1)
    self.assertEqual('my_statistics_gen', stats_gen_task.node_uid.node_id)

    # Finish StatsGen execution.
    otu.fake_execute_node(self._mlmd_connection, stats_gen_task)

    # Prepare another pipeline with a new pipeline_run_id.
    pipeline_run_id = str(uuid.uuid4())
    new_pipeline = self._make_pipeline(self._pipeline_root, pipeline_run_id)

    with self._mlmd_connection as m:
        contexts = m.store.get_contexts_by_execution(example_gen_exec.id)
    # We use node context as cache context for ease of testing.
    cache_context = [c for c in contexts if c.name == 'my_example_gen'][0]
    # Fake example_gen cached execution.
    otu.fake_cached_execution(self._mlmd_connection, cache_context,
                              otu.get_node(new_pipeline, 'my_example_gen'))

    stats_gen = otu.get_node(new_pipeline, 'my_statistics_gen')

    # Invoking generator for the new pipeline should result in:
    # 1. StatsGen execution succeeds with state "CACHED" but no ExecNodeTask
    #    generated.
    # 2. An ExecNodeTask is generated for SchemaGen (component downstream of
    #    StatsGen) with an active execution in MLMD.
    # Expected counts: num_initial_executions=3 presumably covers the first
    # ExampleGen run, the first StatsGen execution, and the faked cached
    # ExampleGen execution. num_new_executions=2 covers the CACHED StatsGen
    # execution plus the active SchemaGen execution.
    [schema_gen_task] = self._generate_and_test(
        False,
        pipeline=new_pipeline,
        num_initial_executions=3,
        num_tasks_generated=1,
        num_new_executions=2,
        num_active_executions=1)
    self.assertEqual('my_schema_gen', schema_gen_task.node_uid.node_id)

    # Check that StatsGen execution is successful in state "CACHED".
    with self._mlmd_connection as m:
        executions = task_gen_utils.get_executions(m, stats_gen)
        self.assertLen(executions, 1)
        execution = executions[0]
        self.assertTrue(execution_lib.is_execution_successful(execution))
        self.assertEqual(metadata_store_pb2.Execution.CACHED,
                         execution.last_known_state)
def get_latest_successful_execution(
        executions: Iterable[metadata_store_pb2.Execution]
) -> Optional[metadata_store_pb2.Execution]:
    """Returns the latest successful execution or `None` if no successful executions exist.

    "Latest" means the greatest `create_time_since_epoch`. When several
    successful executions share that timestamp, the first of them in
    `executions` order is returned.

    Args:
      executions: Executions to search. Iterated exactly once, so any iterable
        (including a generator) is accepted.

    Returns:
      The successful execution with the most recent `create_time_since_epoch`,
      or `None` if no execution in `executions` is successful.
    """
    # `max` finds the newest execution in a single O(n) pass instead of sorting
    # every successful execution just to take the first one. Like the stable
    # `sorted(..., reverse=True)[0]` it replaces, it returns the first of
    # several equally-new executions.
    return max(
        (e for e in executions if execution_lib.is_execution_successful(e)),
        key=lambda e: e.create_time_since_epoch,
        default=None)
def is_latest_execution_successful(
        executions: Sequence[metadata_store_pb2.Execution]) -> bool:
    """Returns `True` if the latest execution was successful.

    Latest execution will have the most recent `create_time_since_epoch`, as
    selected by `get_latest_execution`.

    Args:
      executions: A sequence of executions.

    Returns:
      `True` if the latest execution (per `create_time_since_epoch`) was
      successful. `False` if `executions` is empty or if the latest execution
      was not successful.
    """
    latest = get_latest_execution(executions)
    if not latest:
        # No execution was selected, so there is nothing that could have
        # succeeded.
        return False
    return execution_lib.is_execution_successful(latest)
def is_latest_execution_successful(
        executions: Sequence[metadata_store_pb2.Execution]) -> bool:
    """Returns `True` if the latest execution was successful.

    Latest execution will have the most recent `create_time_since_epoch`. When
    several executions share that timestamp, the first of them in `executions`
    order is the one inspected.

    Args:
      executions: A sequence of executions.

    Returns:
      `True` if the latest execution (per `create_time_since_epoch`) was
      successful. `False` if `executions` is empty or if the latest execution
      was not successful.
    """
    # A single O(n) `max` scan replaces a full O(n log n) sort whose only use
    # was element [0]. `max` returns the first maximal element, matching the
    # stable reverse sort it replaces.
    latest = max(
        executions, key=lambda e: e.create_time_since_epoch, default=None)
    if latest is None:
        return False
    return execution_lib.is_execution_successful(latest)
def _generate_tasks_for_node(
        self, node: pipeline_pb2.PipelineNode) -> List[task_lib.Task]:
    """Generates list of tasks for the given node.

    Checks are applied in order, and the first one that matches decides the
    result:
      1. Node in state STOPPING or STOPPED: no tasks.
      2. Pure service node: an UpdateNodeStateTask reflecting the service
         status (FAILED / COMPLETE / RUNNING). The list is empty for any other
         status.
      3. Node's ExecNodeTask already tracked: no tasks, unless the associated
         (mixed) service job failed, which yields a FAILED update.
      4. Latest execution successful: a COMPLETE update.
      5. Latest execution inactive and node not STARTING: a FAILED update
         carrying the execution's recorded error message, if any.
      6. An active execution exists: a RUNNING update followed by the
         ExecNodeTask built by `generate_task_from_active_execution`.
      7. Otherwise: whatever `_resolve_inputs_and_generate_tasks_for_node`
         returns.

    Args:
      node: The pipeline node to generate tasks for.

    Returns:
      The (possibly empty) list of tasks for `node`.
    """
    node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
    node_id = node.node_info.id
    result = []

    # NOTE(review): indexing raises KeyError if `node_uid` has no entry in
    # `_node_states_dict`; presumably callers guarantee every node is present.
    node_state = self._node_states_dict[node_uid]
    if node_state.state in (pstate.NodeState.STOPPING,
                            pstate.NodeState.STOPPED):
        logging.info(
            'Ignoring node in state \'%s\' for task generation: %s',
            node_state.state, node_uid)
        return result

    # If this is a pure service node, there is no ExecNodeTask to generate
    # but we ensure node services and check service status.
    # A `None` status means the node is not a pure service node.
    service_status = self._ensure_node_services_if_pure(node_id)
    if service_status is not None:
        if service_status == service_jobs.ServiceStatus.FAILED:
            error_msg = f'service job failed; node uid: {node_uid}'
            result.append(
                task_lib.UpdateNodeStateTask(
                    node_uid=node_uid,
                    state=pstate.NodeState.FAILED,
                    status=status_lib.Status(code=status_lib.Code.ABORTED,
                                             message=error_msg)))
        elif service_status == service_jobs.ServiceStatus.SUCCESS:
            logging.info('Service node successful: %s', node_uid)
            result.append(
                task_lib.UpdateNodeStateTask(
                    node_uid=node_uid, state=pstate.NodeState.COMPLETE))
        elif service_status == service_jobs.ServiceStatus.RUNNING:
            result.append(
                task_lib.UpdateNodeStateTask(
                    node_uid=node_uid, state=pstate.NodeState.RUNNING))
        # Pure service nodes never fall through to execution-based checks.
        return result

    # If a task for the node is already tracked by the task queue, it need
    # not be considered for generation again but we ensure node services
    # in case of a mixed service node.
    if self._is_task_id_tracked_fn(
            task_lib.exec_node_task_id_from_pipeline_node(
                self._pipeline, node)):
        service_status = self._ensure_node_services_if_mixed(node_id)
        if service_status == service_jobs.ServiceStatus.FAILED:
            error_msg = f'associated service job failed; node uid: {node_uid}'
            result.append(
                task_lib.UpdateNodeStateTask(
                    node_uid=node_uid,
                    state=pstate.NodeState.FAILED,
                    status=status_lib.Status(code=status_lib.Code.ABORTED,
                                             message=error_msg)))
        return result

    node_executions = task_gen_utils.get_executions(
        self._mlmd_handle, node)
    latest_execution = task_gen_utils.get_latest_execution(node_executions)

    # If the latest execution is successful, we're done.
    if latest_execution and execution_lib.is_execution_successful(
            latest_execution):
        logging.info('Node successful: %s', node_uid)
        result.append(
            task_lib.UpdateNodeStateTask(node_uid=node_uid,
                                         state=pstate.NodeState.COMPLETE))
        return result

    # If the latest execution is no longer active (and, per the check above,
    # not successful, e.g. failed or cancelled), the node is marked FAILED
    # with status code ABORTED, unless the node is in state STARTING. For
    # nodes in state STARTING we fall through so a new execution can be
    # created.
    if (latest_execution and
            not execution_lib.is_execution_active(latest_execution) and
            node_state.state != pstate.NodeState.STARTING):
        # The error message recorded on the execution is optional.
        error_msg_value = latest_execution.custom_properties.get(
            constants.EXECUTION_ERROR_MSG_KEY)
        error_msg = data_types_utils.get_metadata_value(
            error_msg_value) if error_msg_value else ''
        error_msg = f'node failed; node uid: {node_uid}; error: {error_msg}'
        result.append(
            task_lib.UpdateNodeStateTask(node_uid=node_uid,
                                         state=pstate.NodeState.FAILED,
                                         status=status_lib.Status(
                                             code=status_lib.Code.ABORTED,
                                             message=error_msg)))
        return result

    # Resume an active execution, if there is one, instead of starting anew.
    exec_node_task = task_gen_utils.generate_task_from_active_execution(
        self._mlmd_handle, self._pipeline, node, node_executions)
    if exec_node_task:
        result.append(
            task_lib.UpdateNodeStateTask(node_uid=node_uid,
                                         state=pstate.NodeState.RUNNING))
        result.append(exec_node_task)
        return result

    # Finally, we are ready to generate tasks for the node by resolving inputs.
    result.extend(self._resolve_inputs_and_generate_tasks_for_node(node))
    return result
def generate(self) -> List[task_lib.Task]:
    """Generates tasks for executing the next executable nodes in the pipeline.

    The returned tasks must have `exec_task` populated. List may be empty if no
    nodes are ready for execution.

    Nodes are visited layer by layer, in the order produced by
    `topsort.topsorted_layers`. A node is only considered once all its
    upstream nodes are known to be successful. Instead of the normal task
    list, a single-element list is returned in these cases:
      * an abort task, when a service job fails or a node's latest execution
        is inactive without having succeeded;
      * a FinalizePipelineTask, when `_maybe_generate_task` produces one, or
        when every node of the last layer is successful.

    Returns:
      A `list` of tasks to execute.
    """
    layers = topsort.topsorted_layers(
        [node.pipeline_node for node in self._pipeline.nodes],
        get_node_id_fn=lambda node: node.node_info.id,
        get_parent_nodes=(
            lambda node: [self._node_map[n] for n in node.upstream_nodes]),
        get_child_nodes=(
            lambda node: [self._node_map[n] for n in node.downstream_nodes]))
    result = []
    # Node ids known to be successful in this pass. Consulted when deciding
    # whether downstream nodes are ready.
    successful_node_ids = set()
    for layer_num, layer_nodes in enumerate(layers):
        for node in layer_nodes:
            node_uid = task_lib.NodeUid.from_pipeline_node(
                self._pipeline, node)
            node_id = node.node_info.id

            # Skip MLMD lookups for nodes already recorded as successful.
            if self._in_successful_nodes_cache(node_uid):
                successful_node_ids.add(node_id)
                continue

            if not self._upstream_nodes_successful(node, successful_node_ids):
                continue

            # If this is a pure service node, there is no ExecNodeTask to generate
            # but we ensure node services and check service status.
            service_status = self._ensure_node_services_if_pure(node_id)
            if service_status is not None:
                if service_status == service_jobs.ServiceStatus.FAILED:
                    return [
                        self._abort_task(
                            f'service job failed; node uid: {node_uid}')
                    ]
                if service_status == service_jobs.ServiceStatus.SUCCESS:
                    logging.info('Service node successful: %s', node_uid)
                    successful_node_ids.add(node_id)
                continue

            # If a task for the node is already tracked by the task queue, it need
            # not be considered for generation again but we ensure node services
            # in case of a mixed service node.
            if self._is_task_id_tracked_fn(
                    task_lib.exec_node_task_id_from_pipeline_node(
                        self._pipeline, node)):
                service_status = self._ensure_node_services_if_mixed(
                    node_id)
                if service_status == service_jobs.ServiceStatus.FAILED:
                    return [
                        self._abort_task(
                            f'associated service job failed; node uid: {node_uid}'
                        )
                    ]
                continue

            node_executions = task_gen_utils.get_executions(
                self._mlmd_handle, node)
            latest_execution = task_gen_utils.get_latest_execution(
                node_executions)

            # If the latest execution is successful, we're done.
            if latest_execution and execution_lib.is_execution_successful(
                    latest_execution):
                logging.info('Node successful: %s', node_uid)
                successful_node_ids.add(node_id)
                continue

            # If the latest execution failed, the pipeline should be aborted.
            # (Reached only when it is inactive and, per the check above, not
            # successful.)
            if latest_execution and not execution_lib.is_execution_active(
                    latest_execution):
                error_msg_value = latest_execution.custom_properties.get(
                    constants.EXECUTION_ERROR_MSG_KEY)
                error_msg = data_types_utils.get_metadata_value(
                    error_msg_value) if error_msg_value else ''
                return [
                    self._abort_task(
                        f'node failed; node uid: {node_uid}; error: {error_msg}'
                    )
                ]

            # Finally, we are ready to generate an ExecNodeTask for the node.
            task = self._maybe_generate_task(node, node_executions,
                                             successful_node_ids)
            if task:
                # A finalize task pre-empts everything else generated so far.
                if task_lib.is_finalize_pipeline_task(task):
                    return [task]
                else:
                    result.append(task)

        layer_node_ids = set(node.node_info.id for node in layer_nodes)
        successful_layer_node_ids = layer_node_ids & successful_node_ids
        # NOTE(review): the cache is queried above with a `node_uid` but is
        # updated here with node ids; confirm the two helpers agree on keys.
        self._update_successful_nodes_cache(successful_layer_node_ids)

        # If all nodes in the final layer are completed successfully, the
        # pipeline can be finalized.
        # TODO(goutham): If there are conditional eval nodes, not all nodes may be
        # executed in the final layer. Handle this case when conditionals are
        # supported.
        if (layer_num == len(layers) - 1 and
                successful_layer_node_ids == layer_node_ids):
            return [
                task_lib.FinalizePipelineTask(
                    pipeline_uid=self._pipeline_uid,
                    status=status_lib.Status(code=status_lib.Code.OK))
            ]
    return result