def testCheckpointExportListenerGC(self):
    """Repeated saves with num_versions=3 prune older exports.

    Five after_save calls are issued; afterwards export directory 5 must
    exist while directory 2 must have been removed.
    """
    gc_listener = checkpoint_hooks.CheckpointExportListener(
        export_fn=self._ExportFn,
        export_dir=self._export_dir,
        num_versions=3)

    # The session argument is unused by the listener in this test, so None
    # is passed; only the global step varies.
    for global_step in range(5):
        gc_listener.after_save(None, global_step)

    newest_dir = _CheckpointDir(self._export_dir, 5)
    self.assertTrue(tf.gfile.Exists(newest_dir))
    pruned_dir = _CheckpointDir(self._export_dir, 2)
    self.assertFalse(tf.gfile.Exists(pruned_dir))
def testCheckpointExportListenerGCRestore(self):
    """Constructing the listener prunes exports already on disk.

    Six saved models (steps 0..5) are created up front. Merely building a
    listener with num_versions=3 should garbage-collect the old ones:
    directory 5 survives and directory 2 is removed.
    """
    step = 0
    while step < 6:
        _MakeSavedModel(self._export_dir, step)
        step += 1

    # No after_save call is made: the initializer itself does GC on old
    # checkpoints, so the returned listener is intentionally discarded.
    checkpoint_hooks.CheckpointExportListener(
        export_fn=self._ExportFn,
        export_dir=self._export_dir,
        num_versions=3)

    newest_dir = _CheckpointDir(self._export_dir, 5)
    self.assertTrue(tf.gfile.Exists(newest_dir))
    pruned_dir = _CheckpointDir(self._export_dir, 2)
    self.assertFalse(tf.gfile.Exists(pruned_dir))
def create_hooks(self, t2r_model, estimator, export_generator):
    """Build the training hooks that save checkpoints and export models.

    Side effect: configures ``export_generator`` from ``t2r_model`` via
    ``set_specification_from_model`` before any hook is created.

    Args:
      t2r_model: Model whose specification is handed to export_generator.
      estimator: Estimator whose ``model_dir`` is used as checkpoint dir.
      export_generator: Generator used (with the model and estimator) to
        build the export function.

    Returns:
      A one-element list holding an AsyncCheckpointSaverHook that saves every
      ``self._save_secs`` seconds and carries a CheckpointExportListener
      writing into ``self._export_dir``, keeping ``self._num_versions``
      versions.
    """
    export_generator.set_specification_from_model(t2r_model)

    export_fn = self._create_export_fn(t2r_model, estimator, export_generator)
    export_listener = checkpoint_hooks.CheckpointExportListener(
        export_fn=export_fn,
        num_versions=self._num_versions,
        export_dir=self._export_dir)

    saver_hook = tf.contrib.tpu.AsyncCheckpointSaverHook(
        save_secs=self._save_secs,
        checkpoint_dir=estimator.model_dir,
        listeners=[export_listener])
    return [saver_hook]
def testCheckpointExportListener(self):
    """A single after_save call produces export directory 1."""
    export_listener = checkpoint_hooks.CheckpointExportListener(
        export_fn=self._ExportFn,
        export_dir=self._export_dir)

    # Global step 10 is arbitrary; the expected directory index is 1, not 10.
    export_listener.after_save(None, 10)

    first_export_dir = _CheckpointDir(self._export_dir, 1)
    self.assertTrue(tf.gfile.Exists(first_export_dir))