def SaveCheckpoint(self) -> checkpoints.CheckpointReference:
    """Snapshot the current model state and pass it to the logger.

    The checkpoint bundles the run ID, the current epoch number, the best
    results recorded on this instance, and the value of GetModelData().

    Returns:
        A checkpoint reference carrying this model's run ID and epoch number.

    Raises:
        TypeError: If the model has not been initialized.
    """
    if not self._initialized:
        raise TypeError("Cannot save an unitialized model.")

    snapshot = checkpoints.Checkpoint(
        run_id=self.run_id,
        epoch_num=self.epoch_num,
        best_results=self.best_results,
        model_data=self.GetModelData(),
    )
    self.logger.Save(snapshot)

    return checkpoints.CheckpointReference(
        run_id=self.run_id,
        epoch_num=self.epoch_num,
    )
def Load(
    self, checkpoint_ref: checkpoints.CheckpointReference
) -> checkpoints.Checkpoint:
    """Fetch a stored checkpoint from the log database.

    Args:
        checkpoint_ref: Reference naming the run ID and epoch number of the
            checkpoint to restore. Both values are matched exactly against
            the checkpoint table.

    Returns:
        A checkpoint instance built from the matching database entry.

    Raises:
        ValueError: If no corresponding entry in the checkpoint table exists.
    """
    # An earlier Save() issued through this logger may still be buffered, so
    # flush the writer before reading from the database.
    self._writer.Flush()

    with self.db.Session() as session:
        query = session.query(log_database.Checkpoint)
        query = query.filter(
            log_database.Checkpoint.run_id == str(checkpoint_ref.run_id),
            log_database.Checkpoint.epoch_num == checkpoint_ref.epoch_num,
        )
        # Pull in the `data` relationship as part of the same query.
        query = query.options(sql.orm.joinedload(log_database.Checkpoint.data))
        entry = query.first()

        # Bail out early when the requested checkpoint is absent.
        if not entry:
            raise ValueError(f"Checkpoint not found: {checkpoint_ref}")

        restored_run_id = run_id_lib.RunId.FromString(entry.run_id)
        checkpoint = checkpoints.Checkpoint(
            run_id=restored_run_id,
            epoch_num=entry.epoch_num,
            best_results=self.db.GetBestResults(
                run_id=checkpoint_ref.run_id, session=session
            ),
            model_data=entry.model_data,
        )

    return checkpoint
def Load(
    self, checkpoint_ref: checkpoints.CheckpointReference
) -> checkpoints.Checkpoint:
    """Load model data.

    Args:
        checkpoint_ref: A checkpoint to load. If epoch_num is not set, the
            epoch with the best validation accuracy is selected.

    Returns:
        A checkpoint instance.

    Raises:
        ValueError: If the best epoch must be selected but no epoch has a
            validation accuracy, or if no corresponding entry in the
            checkpoint table exists.
    """
    # A previous Save() call from this logger might still be buffered. Flush
    # the buffer before loading from the database.
    self.Flush()

    with self.db.Session() as session:
        epoch_num = checkpoint_ref.epoch_num

        # If no epoch number was provided, select the best epoch from the log
        # database.
        if epoch_num is None:
            # Get the per-epoch summary table of model results.
            tables = dict(self.db.GetTables(run_ids=[checkpoint_ref.run_id]))
            epochs = tables["epochs"]
            epochs = epochs[epochs["val_accuracy"].notnull()]
            if epochs.empty:
                raise ValueError("No epochs found!")
            # Select the epoch with the best validation accuracy. The filtered
            # frame keeps its original index labels, so the row is located by
            # position: mixing a label from idxmax() with iloc would pick the
            # wrong row (or go out of bounds) once any row has been dropped.
            best_position = epochs["val_accuracy"].values.argmax()
            epoch_num = epochs.iloc[best_position]["epoch_num"]

        checkpoint_entry = (
            session.query(log_database.Checkpoint)
            .filter(
                log_database.Checkpoint.run_id == str(checkpoint_ref.run_id),
                # The epoch may be a numpy integer when read from the table.
                log_database.Checkpoint.epoch_num == int(epoch_num),
            )
            .options(sql.orm.joinedload(log_database.Checkpoint.data))
            .first()
        )

        # Check that the requested checkpoint exists.
        if not checkpoint_entry:
            # List the epochs of this run that do have stored model data so
            # that the error message is actionable.
            available_checkpoints = [
                f"{checkpoint_ref.run_id}@{row.epoch_num}"
                for row in session.query(log_database.Checkpoint.epoch_num)
                .join(log_database.CheckpointModelData)
                .filter(
                    log_database.Checkpoint.run_id == str(checkpoint_ref.run_id)
                )
                .order_by(log_database.Checkpoint.epoch_num)
            ]
            raise ValueError(
                f"Checkpoint not found: {checkpoint_ref}. "
                f"Available checkpoints: {available_checkpoints}"
            )

        checkpoint = checkpoints.Checkpoint(
            run_id=run_id_lib.RunId.FromString(checkpoint_entry.run_id),
            epoch_num=checkpoint_entry.epoch_num,
            best_results=self.db.GetBestResults(
                run_id=checkpoint_ref.run_id, session=session
            ),
            model_data=checkpoint_entry.model_data,
        )

    return checkpoint
def Load(
    self, checkpoint_ref: checkpoints.CheckpointReference
) -> checkpoints.Checkpoint:
    """Load model data.

    Args:
        checkpoint_ref: A checkpoint to load, identified by a run ID and/or a
            tag. If epoch_num is not set, the epoch with the best validation
            accuracy among the resolved runs is selected.

    Returns:
        A checkpoint instance.

    Raises:
        ValueError: If the reference resolves to no runs, if the best epoch
            must be selected but no epoch has a validation accuracy, if a
            specific epoch number was requested and it matches more than one
            row of the epochs table (a tag which resolves to multiple runs),
            or if no corresponding entry in the checkpoints table is found.
    """
    # A previous Save() call from this logger might still be buffered. Flush
    # the buffer before loading from the database.
    self.Flush()

    with self.db.Session() as session:
        epoch_num = checkpoint_ref.epoch_num

        # Resolve the run IDs from the checkpoint reference.
        run_ids = [
            run_id_lib.RunId.FromString(run_id)
            for run_id in self.db.SelectRunIds(
                run_ids=[checkpoint_ref.run_id] if checkpoint_ref.run_id else [],
                tags=[checkpoint_ref.tag] if checkpoint_ref.tag else [],
                session=session,
            )
        ]
        if not run_ids:
            raise ValueError(f"No runs found for checkpoint: {checkpoint_ref}")
        # Materialized as a list so that it can be reused across queries and
        # passed to in_(), which expects a concrete sequence.
        run_id_strs = [str(run_id) for run_id in run_ids]

        if epoch_num is None:
            # No epoch number was provided: select the best epoch from the
            # per-epoch summary table of model results.
            tables = dict(self.db.GetTables(run_ids=run_ids))
            epochs = tables["epochs"]
            epochs = epochs[epochs["val_accuracy"].notnull()]
            if epochs.empty:
                raise ValueError(
                    f"No epochs found for checkpoint: {checkpoint_ref}"
                )
            # The filtered frame keeps its original index labels, so the row is
            # located by position: mixing a label from idxmax() with iloc would
            # pick the wrong row (or go out of bounds) once any row has been
            # dropped.
            best_position = epochs["val_accuracy"].values.argmax()
            epoch_num = epochs.iloc[best_position]["epoch_num"]
        elif epoch_num:
            # A specific epoch was requested: refuse to guess when it matches
            # more than one row of the epochs table.
            # NOTE(review): a requested epoch_num of 0 is falsy and skips this
            # ambiguity check -- confirm whether epoch numbering starts at 1.
            tables = dict(self.db.GetTables(run_ids=run_ids))
            epochs = tables["epochs"]
            if len(epochs[epochs["epoch_num"] == epoch_num]) > 1:
                raise ValueError(
                    f"Multiple runs found for tag: {checkpoint_ref}"
                )

        checkpoint_entry = (
            session.query(log_database.Checkpoint)
            .filter(
                log_database.Checkpoint.run_id.in_(run_id_strs),
                # The epoch may be a numpy integer when read from the table.
                log_database.Checkpoint.epoch_num == int(epoch_num),
            )
            .options(sql.orm.joinedload(log_database.Checkpoint.data))
            .first()
        )

        # Check that the requested checkpoint exists.
        if not checkpoint_entry:
            # Report the checkpoints of every resolved run. The reference's own
            # run_id may be unset when the lookup was made by tag, so it cannot
            # be used to build this list.
            available_checkpoints = [
                f"{row.run_id}@{row.epoch_num}"
                for row in session.query(
                    log_database.Checkpoint.run_id,
                    log_database.Checkpoint.epoch_num,
                )
                .join(log_database.CheckpointModelData)
                .filter(log_database.Checkpoint.run_id.in_(run_id_strs))
                .order_by(
                    log_database.Checkpoint.run_id,
                    log_database.Checkpoint.epoch_num,
                )
            ]
            raise ValueError(
                f"Checkpoint not found: {checkpoint_ref} (runs: {run_ids}). "
                f"Available checkpoints: {available_checkpoints}"
            )

        checkpoint = checkpoints.Checkpoint(
            run_id=run_id_lib.RunId.FromString(checkpoint_entry.run_id),
            epoch_num=checkpoint_entry.epoch_num,
            best_results=self.db.GetBestResults(
                run_id=checkpoint_entry.run_id, session=session
            ),
            model_data=checkpoint_entry.model_data,
        )

    return checkpoint