Example #1
  def GetBestResults(
    self,
    run_id: run_id_lib.RunId,
    session: Optional[sqlutil.Database.SessionType] = None,
  ) -> Dict[epoch.Type, epoch.BestResults]:
    """Get the best results for a given run.

    Returns:
      A mapping from <epoch_type, epoch.Results> for the best accuracy on each
      of the epoch types.
    """
    with self.Session(session=session) as session:
      # Check that the run exists:
      if not session.query(RunId).filter(RunId.run_id == str(run_id)).scalar():
        raise ValueError(f"Run not found: {run_id}")

      best_results: Dict[epoch.Type, epoch.BestResults] = {}
      for epoch_type in list(epoch.Type):
        accuracy_to_epoch_num = {
          row.accuracy: row.epoch_num
          for row in session.query(
            Batch.epoch_num, sql.func.avg(Batch.accuracy).label("accuracy")
          )
          .filter(
            Batch.run_id == str(run_id),
            Batch.epoch_type_num == epoch_type.value,
          )
          .group_by(Batch.epoch_num)
        }
        if accuracy_to_epoch_num:
          epoch_num = accuracy_to_epoch_num[max(accuracy_to_epoch_num.keys())]
          epoch_results = self.GetEpochResults(
            run_id=run_id, epoch_num=epoch_num, epoch_type=epoch_type
          )
          best_results_for_type = epoch.BestResults(
            epoch_num=epoch_num, results=epoch_results
          )
        else:
          best_results_for_type = epoch.BestResults()
        best_results[epoch_type] = best_results_for_type
    return best_results
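
The query above averages per-batch accuracy for each epoch with a SQL GROUP BY and then keeps the epoch whose average is highest. Below is a minimal, self-contained sketch of that selection step in plain Python; BatchRow and pick_best_epoch are hypothetical names used only for illustration, not part of ProGraML.

from collections import defaultdict
from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple


class BatchRow(NamedTuple):
  epoch_num: int
  accuracy: float


def pick_best_epoch(rows: Iterable[BatchRow]) -> Optional[Tuple[int, float]]:
  """Return (epoch_num, mean accuracy) of the best epoch, or None if no rows."""
  totals: Dict[int, List[float]] = defaultdict(lambda: [0.0, 0.0])
  for row in rows:
    totals[row.epoch_num][0] += row.accuracy
    totals[row.epoch_num][1] += 1
  if not totals:
    return None
  averages = {num: acc_sum / count for num, (acc_sum, count) in totals.items()}
  best_epoch = max(averages, key=averages.get)
  return best_epoch, averages[best_epoch]


# Epoch 1 has the higher average accuracy (0.9 vs 0.6), so it is selected.
print(pick_best_epoch([BatchRow(0, 0.5), BatchRow(0, 0.7), BatchRow(1, 0.9)]))
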
Example #2
    def __init__(
        self,
        logger: logging.Logger,
        graph_db: graph_tuple_database.Database,
        run_id: Optional[run_id_lib.RunId] = None,
    ):
        """Constructor.

    This creates an uninitialized model. Initialize the model before use by
    calling Initialize() or RestoreFrom(checkpoint).

    Args:
      logger: A logger to write {batch, epoch, checkpoint} data to.
      graph_db: The graph database which will be used to feed inputs to the
        model.

    Raises:
      NotImplementedError: If both node and graph labels are set.
      TypeError: If neither graph or node labels are set.
    """
        # Sanity check the dimensionality of input graphs.
        if (not graph_db.node_y_dimensionality
                and not graph_db.graph_y_dimensionality):
            raise NotImplementedError(
                "Neither node or graph labels are set. What am I to do?")
        if graph_db.node_y_dimensionality and graph_db.graph_y_dimensionality:
            raise NotImplementedError(
                "Both node and graph labels are set. This is currently not supported. "
                "See <github.com/ChrisCummins/ProGraML/issues/26>")

        # Model properties.
        self.logger: logging.Logger = logger
        self.graph_db: graph_tuple_database.Database = graph_db
        self.run_id: run_id_lib.RunId = (
            run_id or run_id_lib.RunId.GenerateUnique(type(self).__name__))
        self.y_dimensionality: int = (
            self.graph_db.node_y_dimensionality
            or self.graph_db.graph_y_dimensionality)

        # Set by Initialize() and RestoreFrom().
        self._initialized = False
        self.restored_from: Optional[checkpoints.CheckpointReference] = None

        # Progress counters that are saved and loaded from checkpoints.
        self.epoch_num = 0
        self.best_results: Dict[epoch.Type, epoch.BestResults] = {
            epoch.Type.TRAIN: epoch.BestResults(),
            epoch.Type.VAL: epoch.BestResults(),
            epoch.Type.TEST: epoch.BestResults(),
        }

        # If --strict_graph_segmentation is set, check for graphs that we have
        # already seen before by keeping a log of all unique graph IDs of each
        # type.
        self.graph_ids: Dict[epoch.Type, Set[int]] = {
            epoch.Type.TRAIN: set(),
            epoch.Type.VAL: set(),
            epoch.Type.TEST: set(),
        }

        # Register this model with the logger.
        self.logger.OnStartRun(self.run_id, self.graph_db)
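
The constructor accepts exactly one kind of label (node-level or graph-level) and records its dimensionality in y_dimensionality. A minimal, self-contained sketch of that validation pattern follows; FakeGraphDatabase and label_dimensionality are illustrative names, not ProGraML's API.

class FakeGraphDatabase:
    """A stand-in with the two dimensionality fields used by the check."""

    def __init__(self, node_y_dimensionality: int = 0,
                 graph_y_dimensionality: int = 0):
        self.node_y_dimensionality = node_y_dimensionality
        self.graph_y_dimensionality = graph_y_dimensionality


def label_dimensionality(graph_db: FakeGraphDatabase) -> int:
    """Mirror the constructor's sanity checks and return the label size."""
    if not graph_db.node_y_dimensionality and not graph_db.graph_y_dimensionality:
        raise NotImplementedError("Neither node nor graph labels are set.")
    if graph_db.node_y_dimensionality and graph_db.graph_y_dimensionality:
        raise NotImplementedError("Both node and graph labels are set.")
    return graph_db.node_y_dimensionality or graph_db.graph_y_dimensionality


assert label_dimensionality(FakeGraphDatabase(node_y_dimensionality=2)) == 2
assert label_dimensionality(FakeGraphDatabase(graph_y_dimensionality=10)) == 10
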
Example #3
    def __call__(
        self,
        epoch_type: epoch.Type,
        batch_iterator: batches.BatchIterator,
        logger: logging.Logger,
    ) -> epoch.Results:
        """Run the model for over the input batches.

    This is the heart of the model - where you run an epoch of batches through
    the graph and produce results. The interface for training and inference is
    the same, only the epoch_type value should change.

    Side effects of calling a model are:
      * The model bumps its epoch_num counter if on a training epoch.
      * The model updates its best_results dictionary if the accuracy produced
        by this epoch is greater than the previous best.

    Args:
      epoch_type: The type of epoch to run.
      batch_iterator: The batches to process.
      logger: A logger instance to log results to.

    Returns:
      An epoch results instance.
    """
        if not self._initialized:
            raise TypeError(
                "Model called before Initialize() or FromCheckpoint() invoked")

        # Only training epochs bump the epoch count.
        if epoch_type == epoch.Type.TRAIN:
            self.epoch_num += 1

        thread = EpochThread(self, epoch_type, batch_iterator, logger)
        progress.Run(thread)

        # Check that there were batches.
        if not thread.batch_count:
            raise ValueError("No batches")

        # If --strict_graph_segmentation is set, check for graphs that we have
        # already seen before.
        if FLAGS.strict_graph_segmentation:
            with logger.ctx.Profile(4, "Checked strict graph segmentation"):
                for other_epoch_type in set(epoch.Type) - {epoch_type}:
                    duplicate_graph_ids = self.graph_ids[
                        other_epoch_type].intersection(thread.graph_ids)
                    if duplicate_graph_ids:
                        raise ValueError(
                            f"{epoch_type} batch contains {len(duplicate_graph_ids)} graphs "
                            f"from {other_epoch_type}: {list(duplicate_graph_ids)[:100]}"
                        )
                self.graph_ids[epoch_type] = self.graph_ids[epoch_type].union(
                    thread.graph_ids)

        # TODO(github.com/ChrisCummins/ProGraML/issues/38): Explicitly free the
        # thread object to see if that is contributing to climbing memory usage.
        results = copy.deepcopy(thread.results)
        if not results:
            raise OSError("Epoch produced no results. Did the model crash?")
        del thread

        # Update the record of best results.
        if results > self.best_results[epoch_type].results:
            new_best = epoch.BestResults(epoch_num=self.epoch_num,
                                         results=results)
            logger.ctx.Log(
                2,
                "%s results improved from %s",
                epoch_type.name.capitalize(),
                self.best_results[epoch_type],
            )
            self.best_results[epoch_type] = new_best

        return results
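
The tail of __call__ keeps a per-epoch-type record of the best results seen so far and replaces it only when the new results compare greater. A minimal, self-contained sketch of that bookkeeping follows; EpochResults and BestEpoch are hypothetical stand-ins for epoch.Results and epoch.BestResults, ordered here by accuracy alone.

import dataclasses
from typing import Dict


@dataclasses.dataclass(order=True)
class EpochResults:
    accuracy: float = 0.0


@dataclasses.dataclass
class BestEpoch:
    epoch_num: int = 0
    results: EpochResults = dataclasses.field(default_factory=EpochResults)


best_so_far: Dict[str, BestEpoch] = {
    "train": BestEpoch(),
    "val": BestEpoch(),
    "test": BestEpoch(),
}


def maybe_update_best(epoch_type: str, epoch_num: int,
                      results: EpochResults) -> bool:
    """Record `results` as the new best for `epoch_type` if they improved."""
    if results > best_so_far[epoch_type].results:
        best_so_far[epoch_type] = BestEpoch(epoch_num=epoch_num,
                                            results=results)
        return True
    return False


assert maybe_update_best("val", 1, EpochResults(accuracy=0.8))
assert not maybe_update_best("val", 2, EpochResults(accuracy=0.5))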