示例#1
0
    def SaveCheckpoint(self) -> checkpoints.CheckpointReference:
        """Construct a checkpoint from the current model state.

        The checkpoint (run ID, epoch number, best results so far and the data
        returned by GetModelData()) is passed to self.logger.Save(), and a
        reference identifying it by run ID and epoch number is returned.

        Returns:
          A checkpoint reference.

        Raises:
          TypeError: If the model has not been initialized.
        """
        # NOTE: TypeError is kept (rather than e.g. RuntimeError) because
        # existing callers may catch it.
        if not self._initialized:
            raise TypeError("Cannot save an uninitialized model.")

        self.logger.Save(
            checkpoints.Checkpoint(
                run_id=self.run_id,
                epoch_num=self.epoch_num,
                best_results=self.best_results,
                model_data=self.GetModelData(),
            ))
        return checkpoints.CheckpointReference(run_id=self.run_id,
                                               epoch_num=self.epoch_num)
示例#2
0
    def Load(
        self, checkpoint_ref: checkpoints.CheckpointReference
    ) -> checkpoints.Checkpoint:
        """Load model data.

        Args:
          checkpoint_ref: The checkpoint to load. Its run_id selects the run. If
            its epoch_num is None, the checkpoint with the highest epoch number
            for that run is used.

        Returns:
          A checkpoint instance.

        Raises:
          ValueError: If no corresponding entry in the checkpoint table exists.
        """
        # A previous Save() call from this logger might still be buffered. Flush the
        # buffer before loading from the database.
        self._writer.Flush()

        with self.db.Session() as session:
            query = session.query(log_database.Checkpoint).filter(
                log_database.Checkpoint.run_id == str(checkpoint_ref.run_id))
            if checkpoint_ref.epoch_num is None:
                # No epoch requested: take the most recent checkpoint of the run.
                # Comparing the column against None would instead render
                # "epoch_num IS NULL" rather than "any epoch".
                query = query.order_by(
                    log_database.Checkpoint.epoch_num.desc())
            else:
                query = query.filter(log_database.Checkpoint.epoch_num ==
                                     checkpoint_ref.epoch_num)
            checkpoint_entry = query.options(
                sql.orm.joinedload(log_database.Checkpoint.data)).first()
            # Check that the requested checkpoint exists.
            if not checkpoint_entry:
                raise ValueError(f"Checkpoint not found: {checkpoint_ref}")

            checkpoint = checkpoints.Checkpoint(
                run_id=run_id_lib.RunId.FromString(checkpoint_entry.run_id),
                epoch_num=checkpoint_entry.epoch_num,
                best_results=self.db.GetBestResults(
                    run_id=checkpoint_ref.run_id, session=session),
                model_data=checkpoint_entry.model_data,
            )

        return checkpoint
示例#3
0
    def Load(
        self, checkpoint_ref: checkpoints.CheckpointReference
    ) -> checkpoints.Checkpoint:
        """Load model data.

        Args:
          checkpoint_ref: A checkpoint to load. If epoch_num is not set, the epoch
            with the best validation accuracy is selected.

        Returns:
          A checkpoint instance.

        Raises:
          ValueError: If epoch_num is not set and no epochs with validation
            results exist, or if no corresponding entry in the checkpoint table
            exists.
        """
        # A previous Save() call from this logger might still be buffered. Flush
        # the buffer before loading from the database.
        self.Flush()

        with self.db.Session() as session:
            epoch_num = checkpoint_ref.epoch_num

            # If no epoch number was provided, select the best epoch from the log
            # database.
            if epoch_num is None:
                # Get the per-epoch summary table of model results.
                tables = dict(
                    self.db.GetTables(run_ids=[checkpoint_ref.run_id]))
                # Select the epoch with the best validation accuracy, ignoring
                # epochs without a validation result.
                epochs = tables["epochs"][tables["epochs"]
                                          ["val_accuracy"].notnull()]
                if epochs.empty:
                    raise ValueError("No epochs found!")
                # idxmax() returns an index *label*, not a row position. Labels
                # need not match positions (the notnull() filter alone can leave
                # gaps), so the row must be looked up with .loc, not .iloc.
                best_epoch = epochs.loc[epochs["val_accuracy"].idxmax()]
                epoch_num = best_epoch["epoch_num"]

            checkpoint_entry = (session.query(log_database.Checkpoint).filter(
                log_database.Checkpoint.run_id == str(checkpoint_ref.run_id),
                log_database.Checkpoint.epoch_num == int(epoch_num),
            ).options(sql.orm.joinedload(
                log_database.Checkpoint.data)).first())
            # Check that the requested checkpoint exists.
            if not checkpoint_entry:
                # Report the checkpoints (with model data) that do exist for this
                # run to make the error actionable.
                available_checkpoints = [
                    f"{checkpoint_ref.run_id}@{row.epoch_num}"
                    for row in session.query(log_database.Checkpoint.epoch_num)
                    .join(log_database.CheckpointModelData).filter(
                        log_database.Checkpoint.run_id == str(
                            checkpoint_ref.run_id)).order_by(
                                log_database.Checkpoint.epoch_num)
                ]
                raise ValueError(
                    f"Checkpoint not found: {checkpoint_ref}. "
                    f"Available checkpoints: {available_checkpoints}")

            checkpoint = checkpoints.Checkpoint(
                run_id=run_id_lib.RunId.FromString(checkpoint_entry.run_id),
                epoch_num=checkpoint_entry.epoch_num,
                best_results=self.db.GetBestResults(
                    run_id=checkpoint_ref.run_id, session=session),
                model_data=checkpoint_entry.model_data,
            )

        return checkpoint
示例#4
0
  def Load(
    self, checkpoint_ref: checkpoints.CheckpointReference
  ) -> checkpoints.Checkpoint:
    """Load model data.

    Args:
      checkpoint_ref: A checkpoint to load. Runs are resolved from its run_id
        and/or tag. If epoch_num is not set, the epoch is selected using the
        best validation accuracy found in the per-epoch summary table of the
        resolved runs.

    Returns:
      A checkpoint instance.

    Raises:
      ValueError: If no runs match the reference, if epoch_num is not set and
        no epochs with validation results are found, if no corresponding entry
        in the checkpoints table is found, or if a specific epoch number was
        requested using a tag which resolves to multiple runs.
    """
    # A previous Save() call from this logger might still be buffered. Flush the
    # buffer before loading from the database.
    self.Flush()

    with self.db.Session() as session:
      epoch_num = checkpoint_ref.epoch_num

      # Resolve the run IDs from the checkpoint reference.
      run_ids = [
        run_id_lib.RunId.FromString(run_id)
        for run_id in self.db.SelectRunIds(
          run_ids=[checkpoint_ref.run_id] if checkpoint_ref.run_id else [],
          tags=[checkpoint_ref.tag] if checkpoint_ref.tag else [],
          session=session,
        )
      ]
      if not run_ids:
        raise ValueError(f"No runs found for checkpoint: {checkpoint_ref}")
      # String form of the run IDs, as stored in the checkpoints table. A list
      # (not a generator) is passed to in_() below.
      run_id_strs = [str(run_id) for run_id in run_ids]

      if epoch_num is None:
        # No epoch number was provided: select the epoch with the best
        # validation accuracy from the per-epoch summary table.
        tables = dict(self.db.GetTables(run_ids=run_ids))
        epochs = tables["epochs"][tables["epochs"]["val_accuracy"].notnull()]
        if epochs.empty:
          raise ValueError(f"No epochs found for checkpoint: {checkpoint_ref}")
        # idxmax() returns an index *label*, not a row position. Labels need not
        # match positions (the notnull() filter alone can leave gaps), so the
        # row must be looked up with .loc, not .iloc.
        best_epoch = epochs.loc[epochs["val_accuracy"].idxmax()]
        epoch_num = best_epoch["epoch_num"]
        # NOTE(review): when a tag resolves to several runs, only the epoch
        # number of the best row is kept, so the checkpoint query below may
        # return another run's checkpoint at that epoch. Confirm whether the
        # epochs table has a run_id column that could narrow the query.
      elif len(run_ids) > 1:
        # A specific epoch was requested (including epoch 0, which is falsy and
        # was previously skipped) but the reference resolved to several runs,
        # e.g. via a tag. A single run cannot be ambiguous.
        tables = dict(self.db.GetTables(run_ids=run_ids))
        epochs = tables["epochs"]
        if len(epochs[epochs["epoch_num"] == epoch_num]) > 1:
          raise ValueError(f"Multiple runs found for tag: {checkpoint_ref}")

      checkpoint_entry = (
        session.query(log_database.Checkpoint)
        .filter(
          log_database.Checkpoint.run_id.in_(run_id_strs),
          log_database.Checkpoint.epoch_num == int(epoch_num),
        )
        .options(sql.orm.joinedload(log_database.Checkpoint.data))
        .first()
      )
      # Check that the requested checkpoint exists.
      if not checkpoint_entry:
        # List checkpoints for every resolved run: checkpoint_ref.run_id is
        # unset when the reference was resolved by tag.
        available_checkpoints = [
          f"{row.run_id}@{row.epoch_num}"
          for row in session.query(
            log_database.Checkpoint.run_id, log_database.Checkpoint.epoch_num
          )
          .join(log_database.CheckpointModelData)
          .filter(log_database.Checkpoint.run_id.in_(run_id_strs))
          .order_by(
            log_database.Checkpoint.run_id, log_database.Checkpoint.epoch_num
          )
        ]
        raise ValueError(
          f"Checkpoint not found: {checkpoint_ref} (runs: {run_ids}). "
          f"Available checkpoints: {available_checkpoints}"
        )

      checkpoint = checkpoints.Checkpoint(
        run_id=run_id_lib.RunId.FromString(checkpoint_entry.run_id),
        epoch_num=checkpoint_entry.epoch_num,
        best_results=self.db.GetBestResults(
          run_id=checkpoint_entry.run_id, session=session
        ),
        model_data=checkpoint_entry.model_data,
      )

    return checkpoint