def Run(self):
    """Drive the train/val/test loop from the current epoch to the last one.

    Epochs are run with RunOneEpoch(), starting at self.ctx.i and stopping
    before self.ctx.n. The loop ends early if ShouldExitEarly() returns a
    truthy value for an epoch's validation results. If the --test_on flag is
    "best", the model is then restored from a checkpoint reference and one
    test epoch is run.
    """
    test_policy = FLAGS.test_on
    save_policy = FLAGS.save_on()

    # self.ctx.i is itself the loop target, so the context always holds the
    # number of the epoch that is currently running.
    for self.ctx.i in range(self.ctx.i, self.ctx.n):
        if self.ShouldExitEarly(self.RunOneEpoch(test_policy, save_policy)):
            break

    # Record the final epoch, whether or not the loop stopped early.
    self.ctx.i = self.ctx.n

    if FLAGS.test_on != "best":
        return

    # Testing on the best result: put the model back into the state of the
    # best epoch before running the test set.
    #
    # The logger is flushed first so that every result has been written
    # before the log database is consulted.
    self.logger.Flush()

    # NOTE(review): a reference with neither a tag nor an epoch number is
    # presumably resolved to the best-validation checkpoint - confirm against
    # checkpoints.CheckpointReference / RestoreFrom().
    best_checkpoint = checkpoints.CheckpointReference(
        run_id=self.model.run_id, tag=None, epoch_num=None
    )
    self.ctx.Log(1, "Restoring model to best validation results")
    self.model.RestoreFrom(best_checkpoint)

    test_batches = self.MakeBatchIterator(epoch.Type.TEST)
    self.RunEpoch(epoch.Type.TEST, test_batches)
def CheckpointReference_without_epoch_num():
    """Check that a reference with no epoch number round-trips via a string."""
    expected_run_id = run_id_lib.RunId.GenerateUnique("reftest")

    # Build the reference with an explicit epoch_num=None.
    original = checkpoints.CheckpointReference(expected_run_id, epoch_num=None)
    assert original.run_id == expected_run_id
    assert original.epoch_num is None

    # Convert to a string and parse it back; nothing may be lost on the way.
    parsed = checkpoints.CheckpointReference.FromString(str(original))
    assert parsed.run_id == expected_run_id
    assert parsed.epoch_num is None
    assert original == parsed
def SaveCheckpoint(self) -> checkpoints.CheckpointReference:
    """Save the current model state as a checkpoint through the logger.

    Returns:
      A checkpoint reference built from this model's run ID and epoch number.

    Raises:
      TypeError: If the model has not been initialized.
    """
    if not self._initialized:
        raise TypeError("Cannot save an unitialized model.")

    # Bundle everything that describes the model at this epoch.
    snapshot = checkpoints.Checkpoint(
        run_id=self.run_id,
        epoch_num=self.epoch_num,
        best_results=self.best_results,
        model_data=self.GetModelData(),
    )
    self.logger.Save(snapshot)

    return checkpoints.CheckpointReference(
        run_id=self.run_id, epoch_num=self.epoch_num
    )