Example #1
def get_qualified_artifacts(
    metadata_handler: metadata.Metadata,
    contexts: Iterable[metadata_store_pb2.Context],
    artifact_type: metadata_store_pb2.ArtifactType,
    output_key: Optional[str] = None,
) -> List[types.Artifact]:
    """Gets qualified artifacts that have the right producer info.

  Args:
    metadata_handler: A metadata handler to access MLMD store.
    contexts: Context constraints to filter artifacts
    artifact_type: Type constraint to filter artifacts
    output_key: Output key constraint to filter artifacts

  Returns:
    A list of qualified TFX Artifacts.
  """
    # We expect to have at least one context for input resolution.
    assert contexts, 'Must have at least one context.'

    try:
        artifact_type_name = artifact_type.name
        artifact_type = metadata_handler.store.get_artifact_type(
            artifact_type_name)
    except mlmd.errors.NotFoundError:
        logging.warning('Artifact type %s is not found in MLMD.',
                        artifact_type.name)
        artifact_type = None

    if not artifact_type:
        return []

    executions_within_context = (
        execution_lib.get_executions_associated_with_all_contexts(
            metadata_handler, contexts))

    # Filters out non-success executions.
    qualified_producer_executions = [
        e.id for e in executions_within_context
        if execution_lib.is_execution_successful(e)
    ]
    # Gets the output events that have the matching output key.
    qualified_output_events = [
        ev for ev in metadata_handler.store.get_events_by_execution_ids(
            qualified_producer_executions)
        if event_lib.validate_output_event(ev, output_key)
    ]

    # Gets the candidate artifacts from output events.
    candidate_artifacts = metadata_handler.store.get_artifacts_by_id(
        list(set(ev.artifact_id for ev in qualified_output_events)))
    # Filters the artifacts that have the right artifact type and state.
    qualified_artifacts = [
        a for a in candidate_artifacts if a.type_id == artifact_type.id
        and a.state == metadata_store_pb2.Artifact.LIVE
    ]
    return [
        artifact_utils.deserialize_artifact(artifact_type, a)
        for a in qualified_artifacts
    ]
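
For orientation, here is a hypothetical usage sketch. Only
get_qualified_artifacts itself comes from the example above; the
connection_config value and the 'pipeline' context type name are
illustrative assumptions.

# Hypothetical usage sketch: connection_config and the 'pipeline' context
# type name are illustrative assumptions, not part of the example above.
from tfx.orchestration import metadata
from ml_metadata.proto import metadata_store_pb2

with metadata.Metadata(connection_config) as m:
    # Constrain input resolution to all pipeline contexts known to MLMD.
    contexts = m.store.get_contexts_by_type('pipeline')
    examples_type = metadata_store_pb2.ArtifactType(name='Examples')
    artifacts = get_qualified_artifacts(
        m, contexts, examples_type, output_key='examples')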
Example #2
    def _get_successful_executions(
            self, node_id: str,
            run_id: str) -> List[metadata_store_pb2.Execution]:
        """Gets all successful Executions of a given node in a given pipeline run.

    Args:
      node_id: The node whose Executions to query.
      run_id: The pipeline run id to query the Executions from.

    Returns:
      All successful executions for that node at that run_id.

    Raises:
      LookupError: If no successful Execution was found.
    """
        node_context = self._get_node_context(node_id)
        base_run_context = self._get_pipeline_run_context(run_id)
        all_associated_executions = (
            execution_lib.get_executions_associated_with_all_contexts(
                self._mlmd,
                contexts=[
                    node_context, base_run_context, self._pipeline_context
                ]))
        prev_successful_executions = [
            e for e in all_associated_executions
            if execution_lib.is_execution_successful(e)
        ]
        if not prev_successful_executions:
            raise LookupError(
                f'No previous successful executions found for node_id {node_id} in '
                f'pipeline_run {run_id}')

        return execution_lib.sort_executions_newest_to_oldest(
            prev_successful_executions)
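
A hypothetical caller sketch follows; the node id 'my_trainer' is an
illustrative assumption, while the LookupError contract and the
newest-to-oldest ordering come from the method above.

# Hypothetical caller sketch, run from within the same class as above.
try:
    executions = self._get_successful_executions('my_trainer', run_id)
    latest = executions[0]  # The list is sorted newest to oldest.
except LookupError:
    latest = None  # No successful execution for this node in this run.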
Example #3
def get_latest_successful_execution(
    executions: Iterable[metadata_store_pb2.Execution]
) -> Optional[metadata_store_pb2.Execution]:
    """Returns the latest successful execution or `None` if no successful executions exist."""
    successful_executions = [
        e for e in executions if execution_lib.is_execution_successful(e)
    ]
    return get_latest_execution(successful_executions)
Example #4
  def test_cached_execution(self):
    """Tests that cached execution is used if one is available."""

    # Fake ExampleGen run.
    example_gen_exec = otu.fake_example_gen_run(self._mlmd_connection,
                                                self._example_gen, 1, 1)

    # Invoking generator should produce an ExecNodeTask for StatsGen.
    [stats_gen_task] = self._generate_and_test(
        False,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1)
    self.assertEqual('my_statistics_gen', stats_gen_task.node_uid.node_id)

    # Finish StatsGen execution.
    otu.fake_execute_node(self._mlmd_connection, stats_gen_task)

    # Prepare another pipeline with a new pipeline_run_id.
    pipeline_run_id = str(uuid.uuid4())
    new_pipeline = self._make_pipeline(self._pipeline_root, pipeline_run_id)

    with self._mlmd_connection as m:
      contexts = m.store.get_contexts_by_execution(example_gen_exec.id)
      # We use node context as cache context for ease of testing.
      cache_context = [c for c in contexts if c.name == 'my_example_gen'][0]
    # Fake example_gen cached execution.
    otu.fake_cached_execution(self._mlmd_connection, cache_context,
                              otu.get_node(new_pipeline, 'my_example_gen'))

    stats_gen = otu.get_node(new_pipeline, 'my_statistics_gen')

    # Invoking generator for the new pipeline should result in:
    # 1. StatsGen execution succeeds with state "CACHED" but no ExecNodeTask
    #    generated.
    # 2. An ExecNodeTask is generated for SchemaGen (component downstream of
    #    StatsGen) with an active execution in MLMD.
    [schema_gen_task] = self._generate_and_test(
        False,
        pipeline=new_pipeline,
        num_initial_executions=3,
        num_tasks_generated=1,
        num_new_executions=2,
        num_active_executions=1)
    self.assertEqual('my_schema_gen', schema_gen_task.node_uid.node_id)

    # Check that StatsGen execution is successful in state "CACHED".
    with self._mlmd_connection as m:
      executions = task_gen_utils.get_executions(m, stats_gen)
      self.assertLen(executions, 1)
      execution = executions[0]
      self.assertTrue(execution_lib.is_execution_successful(execution))
      self.assertEqual(metadata_store_pb2.Execution.CACHED,
                       execution.last_known_state)
Example #5
def get_latest_successful_execution(
    executions: Iterable[metadata_store_pb2.Execution]
) -> Optional[metadata_store_pb2.Execution]:
    """Returns the latest successful execution or `None` if no successful executions exist."""
    successful_executions = [
        e for e in executions if execution_lib.is_execution_successful(e)
    ]
    if successful_executions:
        return sorted(successful_executions,
                      key=lambda e: e.create_time_since_epoch,
                      reverse=True)[0]
    return None
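
Since only the newest element is needed, the full sort can be avoided. A
minimal equivalent sketch, reusing the imports and helpers of the example
above, replaces the sort with max() and a default:

# Equivalent sketch: max() over create_time_since_epoch replaces the sort;
# default=None preserves the behavior for an empty list.
def get_latest_successful_execution(
    executions: Iterable[metadata_store_pb2.Execution]
) -> Optional[metadata_store_pb2.Execution]:
    successful_executions = [
        e for e in executions if execution_lib.is_execution_successful(e)
    ]
    return max(successful_executions,
               key=lambda e: e.create_time_since_epoch,
               default=None)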
Example #6
def is_latest_execution_successful(
        executions: Sequence[metadata_store_pb2.Execution]) -> bool:
    """Returns `True` if the latest execution was successful.

  Latest execution will have the most recent `create_time_since_epoch`.

  Args:
    executions: A sequence of executions.

  Returns:
    `True` if latest execution (per `create_time_since_epoch` was successful.
    `False` if `executions` is empty or if latest execution was not successful.
  """
    execution = get_latest_execution(executions)
    return execution_lib.is_execution_successful(
        execution) if execution else False
Example #7
def is_latest_execution_successful(
    executions: Sequence[metadata_store_pb2.Execution]) -> bool:
  """Returns `True` if the latest execution was successful.

  The latest execution has the most recent `create_time_since_epoch`.

  Args:
    executions: A sequence of executions.

  Returns:
    `True` if the latest execution (per `create_time_since_epoch`) was
    successful. `False` if `executions` is empty or if the latest execution
    was not successful.
  """
  sorted_executions = sorted(
      executions, key=lambda e: e.create_time_since_epoch, reverse=True)
  return (execution_lib.is_execution_successful(sorted_executions[0])
          if sorted_executions else False)
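
A small usage sketch, assuming (as the surrounding examples suggest) that
execution_lib.is_execution_successful treats the COMPLETE and CACHED states
as successful:

# Sketch: two bare Execution protos exercise the newest-first logic.
from ml_metadata.proto import metadata_store_pb2

older = metadata_store_pb2.Execution(
    create_time_since_epoch=1000,
    last_known_state=metadata_store_pb2.Execution.FAILED)
newer = metadata_store_pb2.Execution(
    create_time_since_epoch=2000,
    last_known_state=metadata_store_pb2.Execution.COMPLETE)

assert is_latest_execution_successful([older, newer])  # newest is COMPLETE
assert not is_latest_execution_successful([older])     # newest is FAILED
assert not is_latest_execution_successful([])          # empty -> False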
Example #8
    def _generate_tasks_for_node(
            self, node: pipeline_pb2.PipelineNode) -> List[task_lib.Task]:
        """Generates list of tasks for the given node."""
        node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
        node_id = node.node_info.id
        result = []

        node_state = self._node_states_dict[node_uid]
        if node_state.state in (pstate.NodeState.STOPPING,
                                pstate.NodeState.STOPPED):
            logging.info(
                'Ignoring node in state \'%s\' for task generation: %s',
                node_state.state, node_uid)
            return result

        # If this is a pure service node, there is no ExecNodeTask to generate
        # but we ensure node services and check service status.
        service_status = self._ensure_node_services_if_pure(node_id)
        if service_status is not None:
            if service_status == service_jobs.ServiceStatus.FAILED:
                error_msg = f'service job failed; node uid: {node_uid}'
                result.append(
                    task_lib.UpdateNodeStateTask(
                        node_uid=node_uid,
                        state=pstate.NodeState.FAILED,
                        status=status_lib.Status(code=status_lib.Code.ABORTED,
                                                 message=error_msg)))
            elif service_status == service_jobs.ServiceStatus.SUCCESS:
                logging.info('Service node successful: %s', node_uid)
                result.append(
                    task_lib.UpdateNodeStateTask(
                        node_uid=node_uid, state=pstate.NodeState.COMPLETE))
            elif service_status == service_jobs.ServiceStatus.RUNNING:
                result.append(
                    task_lib.UpdateNodeStateTask(
                        node_uid=node_uid, state=pstate.NodeState.RUNNING))
            return result

        # If a task for the node is already tracked by the task queue, it need
        # not be considered for generation again but we ensure node services
        # in case of a mixed service node.
        if self._is_task_id_tracked_fn(
                task_lib.exec_node_task_id_from_pipeline_node(
                    self._pipeline, node)):
            service_status = self._ensure_node_services_if_mixed(node_id)
            if service_status == service_jobs.ServiceStatus.FAILED:
                error_msg = f'associated service job failed; node uid: {node_uid}'
                result.append(
                    task_lib.UpdateNodeStateTask(
                        node_uid=node_uid,
                        state=pstate.NodeState.FAILED,
                        status=status_lib.Status(code=status_lib.Code.ABORTED,
                                                 message=error_msg)))
            return result

        node_executions = task_gen_utils.get_executions(
            self._mlmd_handle, node)
        latest_execution = task_gen_utils.get_latest_execution(node_executions)

        # If the latest execution is successful, we're done.
        if latest_execution and execution_lib.is_execution_successful(
                latest_execution):
            logging.info('Node successful: %s', node_uid)
            result.append(
                task_lib.UpdateNodeStateTask(node_uid=node_uid,
                                             state=pstate.NodeState.COMPLETE))
            return result

        # If the latest execution failed or was cancelled, the node is marked
        # failed unless it is in state STARTING. For nodes in state STARTING,
        # a new execution is created.
        if (latest_execution
                and not execution_lib.is_execution_active(latest_execution)
                and node_state.state != pstate.NodeState.STARTING):
            error_msg_value = latest_execution.custom_properties.get(
                constants.EXECUTION_ERROR_MSG_KEY)
            error_msg = data_types_utils.get_metadata_value(
                error_msg_value) if error_msg_value else ''
            error_msg = f'node failed; node uid: {node_uid}; error: {error_msg}'
            result.append(
                task_lib.UpdateNodeStateTask(node_uid=node_uid,
                                             state=pstate.NodeState.FAILED,
                                             status=status_lib.Status(
                                                 code=status_lib.Code.ABORTED,
                                                 message=error_msg)))
            return result

        exec_node_task = task_gen_utils.generate_task_from_active_execution(
            self._mlmd_handle, self._pipeline, node, node_executions)
        if exec_node_task:
            result.append(
                task_lib.UpdateNodeStateTask(node_uid=node_uid,
                                             state=pstate.NodeState.RUNNING))
            result.append(exec_node_task)
            return result

        # Finally, we are ready to generate tasks for the node by resolving inputs.
        result.extend(self._resolve_inputs_and_generate_tasks_for_node(node))
        return result
Example #9
    def generate(self) -> List[task_lib.Task]:
        """Generates tasks for executing the next executable nodes in the pipeline.

    The returned tasks must have `exec_task` populated. The list may be empty
    if no nodes are ready for execution.

    Returns:
      A `list` of tasks to execute.
    """
        layers = topsort.topsorted_layers(
            [node.pipeline_node for node in self._pipeline.nodes],
            get_node_id_fn=lambda node: node.node_info.id,
            get_parent_nodes=(
                lambda node: [self._node_map[n] for n in node.upstream_nodes]),
            get_child_nodes=(
                lambda node:
                [self._node_map[n] for n in node.downstream_nodes]))
        result = []
        successful_node_ids = set()
        for layer_num, layer_nodes in enumerate(layers):
            for node in layer_nodes:
                node_uid = task_lib.NodeUid.from_pipeline_node(
                    self._pipeline, node)
                node_id = node.node_info.id

                if self._in_successful_nodes_cache(node_uid):
                    successful_node_ids.add(node_id)
                    continue

                if not self._upstream_nodes_successful(node,
                                                       successful_node_ids):
                    continue

                # If this is a pure service node, there is no ExecNodeTask to generate
                # but we ensure node services and check service status.
                service_status = self._ensure_node_services_if_pure(node_id)
                if service_status is not None:
                    if service_status == service_jobs.ServiceStatus.FAILED:
                        return [
                            self._abort_task(
                                f'service job failed; node uid: {node_uid}')
                        ]
                    if service_status == service_jobs.ServiceStatus.SUCCESS:
                        logging.info('Service node successful: %s', node_uid)
                        successful_node_ids.add(node_id)
                    continue

                # If a task for the node is already tracked by the task queue, it need
                # not be considered for generation again but we ensure node services
                # in case of a mixed service node.
                if self._is_task_id_tracked_fn(
                        task_lib.exec_node_task_id_from_pipeline_node(
                            self._pipeline, node)):
                    service_status = self._ensure_node_services_if_mixed(
                        node_id)
                    if service_status == service_jobs.ServiceStatus.FAILED:
                        return [
                            self._abort_task(
                                f'associated service job failed; node uid: {node_uid}'
                            )
                        ]
                    continue

                node_executions = task_gen_utils.get_executions(
                    self._mlmd_handle, node)
                latest_execution = task_gen_utils.get_latest_execution(
                    node_executions)

                # If the latest execution is successful, we're done.
                if latest_execution and execution_lib.is_execution_successful(
                        latest_execution):
                    logging.info('Node successful: %s', node_uid)
                    successful_node_ids.add(node_id)
                    continue

                # If the latest execution failed, the pipeline should be aborted.
                if latest_execution and not execution_lib.is_execution_active(
                        latest_execution):
                    error_msg_value = latest_execution.custom_properties.get(
                        constants.EXECUTION_ERROR_MSG_KEY)
                    error_msg = data_types_utils.get_metadata_value(
                        error_msg_value) if error_msg_value else ''
                    return [
                        self._abort_task(
                            f'node failed; node uid: {node_uid}; error: {error_msg}'
                        )
                    ]

                # Finally, we are ready to generate an ExecNodeTask for the node.
                task = self._maybe_generate_task(node, node_executions,
                                                 successful_node_ids)
                if task:
                    if task_lib.is_finalize_pipeline_task(task):
                        return [task]
                    else:
                        result.append(task)

            layer_node_ids = set(node.node_info.id for node in layer_nodes)
            successful_layer_node_ids = layer_node_ids & successful_node_ids
            self._update_successful_nodes_cache(successful_layer_node_ids)

            # If all nodes in the final layer completed successfully, the
            # pipeline can be finalized.
            # TODO(goutham): If there are conditional eval nodes, not all nodes may be
            # executed in the final layer. Handle this case when conditionals are
            # supported.
            if (layer_num == len(layers) - 1
                    and successful_layer_node_ids == layer_node_ids):
                return [
                    task_lib.FinalizePipelineTask(
                        pipeline_uid=self._pipeline_uid,
                        status=status_lib.Status(code=status_lib.Code.OK))
                ]
        return result
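
The layered topological sort driving the traversal above can be exercised in
isolation. In this toy sketch only topsort.topsorted_layers and its keyword
arguments come from the example; the Node class and the three-node graph are
illustrative:

# Toy sketch of topsort.topsorted_layers with a linear three-node graph.
from tfx.utils import topsort

class Node:
    def __init__(self, name, upstream, downstream):
        self.name = name
        self.upstream = upstream      # names of parent nodes
        self.downstream = downstream  # names of child nodes

graph = {
    'example_gen': Node('example_gen', [], ['stats_gen']),
    'stats_gen': Node('stats_gen', ['example_gen'], ['schema_gen']),
    'schema_gen': Node('schema_gen', ['stats_gen'], []),
}

layers = topsort.topsorted_layers(
    list(graph.values()),
    get_node_id_fn=lambda n: n.name,
    get_parent_nodes=lambda n: [graph[p] for p in n.upstream],
    get_child_nodes=lambda n: [graph[c] for c in n.downstream])

# Every node's parents appear in an earlier layer:
# [['example_gen'], ['stats_gen'], ['schema_gen']]
print([[n.name for n in layer] for layer in layers])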