def test_file(
        self,
        save_final,
        final_filename,
        save_key_metric,
        key_metric_name,
        key_metric_n_saved,
        key_metric_filename,
        key_metric_save_state,
        key_metric_greater_or_equal,
        epoch_level,
        save_interval,
        n_saved,
        filenames,
        multi_devices=False,
    ):
        """Train with a ``CheckpointSaver`` attached and check the expected files exist.

        The same ``Engine`` is run twice over 8 dummy batches, first with
        ``max_epochs=2`` and then with ``max_epochs=5``. Each step stores the
        current iteration number in ``engine.state.metrics["val_loss"]``,
        presumably so that key-metric checkpointing (``key_metric_name``) has a
        value to track. After both runs, every name in ``filenames`` must exist
        in the checkpoint directory.

        Args:
            save_final, final_filename, save_key_metric, key_metric_name,
            key_metric_n_saved, key_metric_filename, key_metric_save_state,
            key_metric_greater_or_equal, epoch_level, save_interval, n_saved:
                forwarded unchanged to ``CheckpointSaver`` as keyword arguments
                of the same name.
            filenames: checkpoint file names, relative to the save directory,
                that must exist after training.
            multi_devices: if True, wrap the network in ``torch.nn.DataParallel``.
        """
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        data = [0] * 8

        # set up engine: the step does no real training, it only records a metric
        def _train_func(engine, batch):
            engine.state.metrics["val_loss"] = engine.state.iteration

        engine = Engine(_train_func)

        # set up testing handler
        net = torch.nn.PReLU()
        if multi_devices:
            net = torch.nn.DataParallel(net)
        optimizer = optim.SGD(net.parameters(), lr=0.02)
        with tempfile.TemporaryDirectory() as tempdir:
            # Pass everything after the first two arguments by keyword. With a long
            # positional list, any added or reordered constructor parameter would
            # silently shift values into the wrong slots instead of failing loudly.
            handler = CheckpointSaver(
                tempdir,
                {"net": net, "opt": optimizer},
                name="CheckpointSaver",
                file_prefix="test",
                save_final=save_final,
                final_filename=final_filename,
                save_key_metric=save_key_metric,
                key_metric_name=key_metric_name,
                key_metric_n_saved=key_metric_n_saved,
                key_metric_filename=key_metric_filename,
                key_metric_save_state=key_metric_save_state,
                key_metric_greater_or_equal=key_metric_greater_or_equal,
                epoch_level=epoch_level,
                save_interval=save_interval,
                n_saved=n_saved,
            )
            handler.attach(engine)
            # NOTE(review): the second call reuses the same engine with a larger
            # max_epochs; presumably this resumes from epoch 2 rather than restarting.
            # Confirm against ignite's Engine.run semantics.
            engine.run(data, max_epochs=2)
            engine.run(data, max_epochs=5)
            # Snapshot the directory once so a failure reports what was actually saved.
            saved_files = sorted(os.listdir(tempdir))
            for filename in filenames:
                self.assertTrue(
                    os.path.exists(os.path.join(tempdir, filename)),
                    f"expected checkpoint {filename!r} not found; saved files: {saved_files}",
                )
# Example #2
# 0
    def test_exception(self):
        """Check that a crash during training still writes the final checkpoint.

        The training step always raises ``RuntimeError``, so the run fails on its
        first iteration. With ``save_final=True`` the saver is expected to have
        written ``net_final_iteration=1.pt`` even though ``engine.run`` propagates
        the error.
        """
        model = torch.nn.PReLU()

        def _failing_step(engine, batch):
            raise RuntimeError("test exception.")

        trainer = Engine(_failing_step)

        with tempfile.TemporaryDirectory() as tempdir:
            # The saver must be attached before the run so it can react to the failure.
            handler = CheckpointSaver(tempdir, {"net": model}, save_final=True)
            handler.attach(trainer)

            with self.assertRaises(RuntimeError):
                trainer.run(range(3), max_epochs=2)

            final_checkpoint = os.path.join(tempdir, "net_final_iteration=1.pt")
            self.assertTrue(os.path.exists(final_checkpoint))