def test_checkpointing(self, tmp_dir):
        """Tests saving and loading from checkpoint.

        Trains once with the test config and checks that exactly
        ``last.ckpt`` and ``FINAL_MODEL_CKPT`` are the ``.ckpt`` files in
        ``tmp_dir``. It then re-runs in eval-only mode from ``last.ckpt``,
        writing outputs to a separate temporary directory. Finally it checks
        that every accuracy metric of the first run is reproduced exactly.
        """
        cfg = self._get_cfg(tmp_dir)

        # accelerator=None disables the distributed backend so the run stays
        # in this process (a spawned child would not inherit the temporary
        # dataset).
        out = main(cfg, accelerator=None)
        ckpts = [
            name for name in os.listdir(tmp_dir) if name.endswith(".ckpt")
        ]
        self.assertCountEqual(
            [
                "last.ckpt",
                FINAL_MODEL_CKPT,
            ],
            ckpts,
        )

        with tempfile.TemporaryDirectory() as tmp_dir2:
            cfg2 = cfg.clone()
            cfg2.defrost()
            cfg2.OUTPUT_DIR = tmp_dir2
            # Load the last checkpoint from the previous training run.
            cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")

            out2 = main(cfg2, accelerator=None, eval_only=True)
            accuracy = flatten_config_dict(out.accuracy)
            accuracy2 = flatten_config_dict(out2.accuracy)
            for k in accuracy:
                # A metric missing from the reloaded run is reported as a
                # test failure naming the key, rather than a bare KeyError.
                self.assertIn(k, accuracy2)
                np.testing.assert_equal(
                    accuracy[k],
                    accuracy2[k],
                    err_msg="metric {!r} differs after reloading".format(k),
                )
    def test_checkpointing(self, tmp_dir):
        """Tests saving and loading from checkpoint.

        Trains once with the test config and checks that ``last.ckpt`` and
        ``FINAL_MODEL_CKPT`` were written to ``tmp_dir`` (other ``.ckpt``
        files may also be present). It then re-runs in eval-only mode from
        ``last.ckpt``, writing outputs to ``tmp_dir/output``. Finally it
        checks that every accuracy metric of the first run is reproduced
        exactly.
        """
        cfg = self._get_cfg(tmp_dir)

        out = main(cfg)
        ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
        expected_ckpts = ("last.ckpt", FINAL_MODEL_CKPT)
        for ckpt in expected_ckpts:
            self.assertIn(ckpt, ckpts)

        cfg2 = cfg.clone()
        cfg2.defrost()
        # The eval-only run writes into its own sub-directory of tmp_dir.
        cfg2.OUTPUT_DIR = os.path.join(tmp_dir, "output")
        # Load the last checkpoint from the previous training run.
        cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")

        out2 = main(cfg2, eval_only=True)
        accuracy = flatten_config_dict(out.accuracy)
        accuracy2 = flatten_config_dict(out2.accuracy)
        for k in accuracy:
            # A metric missing from the reloaded run is reported as a test
            # failure naming the key, rather than a bare KeyError.
            self.assertIn(k, accuracy2)
            np.testing.assert_equal(
                accuracy[k],
                accuracy2[k],
                err_msg="metric {!r} differs after reloading".format(k),
            )
 def test_train_net_main(self, root_dir):
     """Smoke-tests the main training entry point with the test config."""
     # Passing accelerator=None disables the distributed backend so no child
     # process is spawned; a child would not inherit the temporary dataset.
     main(self._get_cfg(root_dir), accelerator=None)