示例#1
0
    def Run(self):
        """Run the train/val/test loop.

        Runs epochs starting at ``self.ctx.i`` and ending before ``self.ctx.n``,
        stopping early if ``ShouldExitEarly`` returns a truthy value for an
        epoch's validation results. When the ``test_on`` flag is ``"best"``,
        the model is then restored from a checkpoint of this run and a single
        test epoch is run.
        """
        # Read the flags once so that the per-epoch loop and the post-loop
        # "test on best" decision are driven by the same values.
        test_on = FLAGS.test_on
        save_on = FLAGS.save_on()

        # Run the train/val/test loop. The loop variable is bound directly to
        # self.ctx.i so the context always holds the current epoch index.
        for self.ctx.i in range(self.ctx.i, self.ctx.n):
            val_results = self.RunOneEpoch(test_on, save_on)
            if self.ShouldExitEarly(val_results):
                break

        # Record the final epoch. This also applies after an early exit, where
        # the loop stopped before reaching self.ctx.n.
        self.ctx.i = self.ctx.n

        if test_on == "best":
            # Testing on the best result: restore the model to the state of the
            # best epoch before running the test set.

            # Flush the logger before using the log database to make sure that
            # all results have been written.
            self.logger.Flush()

            # Restore the model to the state at the best validation accuracy.
            # NOTE(review): neither a tag nor an epoch number is given, so the
            # choice of checkpoint is left to RestoreFrom / the checkpoint
            # store -- presumably it selects the best epoch; confirm.
            checkpoint = checkpoints.CheckpointReference(
                run_id=self.model.run_id, tag=None, epoch_num=None)
            self.ctx.Log(
                1,
                "Restoring model to best validation results",
            )
            self.model.RestoreFrom(checkpoint)
            self.RunEpoch(epoch.Type.TEST,
                          self.MakeBatchIterator(epoch.Type.TEST))
示例#2
0
def CheckpointReference_without_epoch_num():
    """Round-trip a checkpoint reference that has no epoch number.

    A reference built with ``epoch_num=None`` is rendered with ``str()`` and
    parsed back with ``CheckpointReference.FromString``. Both the original and
    the parsed reference must keep the run ID and a ``None`` epoch number, and
    they must compare equal.
    """
    # NOTE(review): pytest's default discovery only collects functions whose
    # names start with "test_"; if this is meant to run under pytest it may
    # need that prefix -- confirm against the test configuration.
    expected_run_id = run_id_lib.RunId.GenerateUnique("reftest")

    def _check_reference(reference):
        # Shared checks for both the freshly built and the parsed reference.
        assert reference.run_id == expected_run_id
        assert reference.epoch_num is None

    original_ref = checkpoints.CheckpointReference(expected_run_id,
                                                   epoch_num=None)
    _check_reference(original_ref)

    serialized = str(original_ref)
    parsed_ref = checkpoints.CheckpointReference.FromString(serialized)
    _check_reference(parsed_ref)

    assert original_ref == parsed_ref
示例#3
0
    def SaveCheckpoint(self) -> checkpoints.CheckpointReference:
        """Construct a checkpoint from the current model state.

        Builds a ``checkpoints.Checkpoint`` from this model's run ID, current
        epoch number, best results, and the data returned by
        ``GetModelData()``, and passes it to ``self.logger.Save``.

        Returns:
          A checkpoint reference carrying this model's run ID and current
          epoch number.

        Raises:
          TypeError: If the model has not been initialized.
        """
        if not self._initialized:
            # NOTE(review): RuntimeError would be the more conventional type
            # for a state error; TypeError is kept so existing callers that
            # catch it keep working.
            raise TypeError("Cannot save an uninitialized model.")

        checkpoint = checkpoints.Checkpoint(
            run_id=self.run_id,
            epoch_num=self.epoch_num,
            best_results=self.best_results,
            model_data=self.GetModelData(),
        )
        self.logger.Save(checkpoint)
        return checkpoints.CheckpointReference(run_id=self.run_id,
                                               epoch_num=self.epoch_num)