def test_file(
    self,
    save_final,
    final_filename,
    save_key_metric,
    key_metric_name,
    key_metric_n_saved,
    key_metric_filename,
    key_metric_save_state,
    key_metric_greater_or_equal,
    epoch_level,
    save_interval,
    n_saved,
    filenames,
    multi_devices=False,
):
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    data = [0] * 8

    # set up engine
    def _train_func(engine, batch):
        # expose a monotonically increasing metric for key-metric checkpointing
        engine.state.metrics["val_loss"] = engine.state.iteration

    engine = Engine(_train_func)

    # set up testing handler
    net = torch.nn.PReLU()
    if multi_devices:
        net = torch.nn.DataParallel(net)
    optimizer = optim.SGD(net.parameters(), lr=0.02)
    with tempfile.TemporaryDirectory() as tempdir:
        handler = CheckpointSaver(
            tempdir,
            {"net": net, "opt": optimizer},
            "CheckpointSaver",
            "test",
            save_final,
            final_filename,
            save_key_metric,
            key_metric_name,
            key_metric_n_saved,
            key_metric_filename,
            key_metric_save_state,
            key_metric_greater_or_equal,
            epoch_level,
            save_interval,
            n_saved,
        )
        handler.attach(engine)
        engine.run(data, max_epochs=5)
        for filename in filenames:
            self.assertTrue(os.path.exists(os.path.join(tempdir, filename)))
def test_exception(self):
    net = torch.nn.PReLU()

    # set up engine
    def _train_func(engine, batch):
        raise RuntimeError("test exception.")

    engine = Engine(_train_func)

    # set up testing handler
    with tempfile.TemporaryDirectory() as tempdir:
        stats_handler = CheckpointSaver(tempdir, {"net": net}, save_final=True)
        stats_handler.attach(engine)
        # the exception should propagate, and a final checkpoint should still
        # be saved for the iteration that raised it
        with self.assertRaises(RuntimeError):
            engine.run(range(3), max_epochs=2)
        self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pt")))
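
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original test suite: the same
# CheckpointSaver wiring as test_file, but with keyword arguments spelled out
# so the long positional list above is easier to map to the handler's
# signature. The keyword names follow the positional order used in test_file
# and are assumed to match the installed MONAI version; the function name is
# hypothetical.
def _example_checkpoint_saver_usage():
    net = torch.nn.PReLU()
    optimizer = optim.SGD(net.parameters(), lr=0.02)

    def _train_func(engine, batch):
        # expose a metric that key-metric checkpointing can track
        engine.state.metrics["val_loss"] = engine.state.iteration

    engine = Engine(_train_func)
    with tempfile.TemporaryDirectory() as tempdir:
        handler = CheckpointSaver(
            save_dir=tempdir,
            save_dict={"net": net, "opt": optimizer},
            name="CheckpointSaver",
            file_prefix="test",
            save_final=True,  # write a checkpoint when the run completes
            save_key_metric=True,  # also keep checkpoints with the best metric
            key_metric_name="val_loss",
            key_metric_n_saved=2,  # retain the two best key-metric checkpoints
            epoch_level=True,  # interval is counted in epochs, not iterations
            save_interval=1,  # save at every epoch
            n_saved=2,  # keep only the two most recent interval checkpoints
        )
        handler.attach(engine)
        engine.run([0] * 8, max_epochs=5)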