Ejemplo n.º 1
0
    def test_tensorboard_logging_parity(
        self,
        summary_writer,
        mmf,
        lightning,
        logistics,
        logistics_logs,
        report_logs,
        trainer_logs,
        mkdirs,
    ):
        """Check that MMF and Lightning trainers log identical TensorBoard scalars.

        Both trainers run for 8 updates with batch size 2 and log every 3
        steps. Each trainer's ``tb_writer.add_scalars`` is replaced by a
        recorder that appends ``{iteration: log_dict}`` to a list on ``self``.
        The two recordings must match in length and entry by entry.

        The arguments after ``self`` are presumably injected by ``mock.patch``
        decorators not visible here. None of them is referenced by name below.
        """
        # NOTE(review): assumes setUp initialises self.mmf_tensorboard_logs and
        # self.lightning_tensorboard_logs as empty lists — confirm.

        # mmf trainer
        mmf_trainer = get_mmf_trainer(
            max_updates=8,
            batch_size=2,
            max_epochs=None,
            log_interval=3,
            tensorboard=True,
        )

        def _add_scalars_mmf(log_dict, iteration):
            self.mmf_tensorboard_logs.append({iteration: log_dict})

        mmf_trainer.load_metrics()
        logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
        logistics_callback.snapshot_timer = MagicMock(return_value=None)
        logistics_callback.train_timer = Timer()
        # Record scalars instead of writing them to TensorBoard.
        logistics_callback.tb_writer.add_scalars = _add_scalars_mmf
        mmf_trainer.logistics_callback = logistics_callback
        mmf_trainer.callbacks = [logistics_callback]
        mmf_trainer.early_stop_callback = MagicMock(return_value=None)
        mmf_trainer.on_update_end = logistics_callback.on_update_end
        mmf_trainer.training_loop()

        # lightning_trainer
        trainer = get_lightning_trainer(
            max_steps=8,
            batch_size=2,
            prepare_trainer=False,
            log_every_n_steps=3,
            val_check_interval=9,
            tensorboard=True,
        )

        def _add_scalars_lightning(log_dict, iteration):
            self.lightning_tensorboard_logs.append({iteration: log_dict})

        # The writer is patched at fit start, presumably because
        # trainer.tb_writer only exists once fitting begins — confirm.
        def _on_fit_start_callback():
            trainer.tb_writer.add_scalars = _add_scalars_lightning

        callback = LightningLoopCallback(trainer)
        run_lightning_trainer_with_callback(
            trainer, callback, on_fit_start_callback=_on_fit_start_callback
        )

        self.assertEqual(
            len(self.mmf_tensorboard_logs), len(self.lightning_tensorboard_logs)
        )
        # Distinct loop names so the injected ``mmf``/``lightning`` arguments
        # are not shadowed.
        for mmf_log, lightning_log in zip(
            self.mmf_tensorboard_logs, self.lightning_tensorboard_logs
        ):
            self.assertDictEqual(mmf_log, lightning_log)
Ejemplo n.º 2
0
    def setUp(self):
        """Build a merged config, a fake trainer namespace and the callback under test.

        The repository defaults from ``configs/defaults.yaml`` are merged with
        test-specific overrides. A deep copy goes to the trainer and is
        registered in the registry, so ``self.config`` stays untouched for
        comparisons.
        """
        self.tmpdir = tempfile.mkdtemp()
        self.trainer = argparse.Namespace()

        # Test-specific settings layered on top of the repository defaults.
        overrides = {
            "model": "simple",
            "model_config": {},
            "training": {
                "checkpoint_interval": 1,
                "evaluation_interval": 10,
                "early_stop": {"criteria": "val/total_loss"},
                "batch_size": 16,
                "log_interval": 10,
                "logger_level": "info",
            },
            "env": {"save_dir": self.tmpdir},
        }
        base_config = load_yaml(os.path.join("configs", "defaults.yaml"))
        self.config = OmegaConf.merge(base_config, overrides)

        # The trainer works on its own copy; self.config keeps the original.
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)
        setup_logger()

        fake_report = Mock(spec=Report)
        fake_report.dataset_name = "abcd"
        fake_report.dataset_type = "test"
        self.report = fake_report

        fake_trainer = self.trainer
        fake_trainer.model = SimpleModule()
        fake_trainer.val_loader = torch.utils.data.DataLoader(
            NumbersDataset(), batch_size=self.config.training.batch_size
        )
        fake_trainer.optimizer = torch.optim.Adam(
            fake_trainer.model.parameters(), lr=1e-01
        )
        fake_trainer.device = "cpu"
        for counter in (
            "num_updates",
            "current_iteration",
            "current_epoch",
            "max_updates",
        ):
            setattr(fake_trainer, counter, 0)
        fake_trainer.meter = Meter()

        self.cb = LogisticsCallback(self.config, fake_trainer)
Ejemplo n.º 3
0
    def configure_callbacks(self):
        """Instantiate the built-in trainer callbacks and register them.

        Creates checkpoint, early-stopping, logistics and LR-scheduler
        callbacks as attributes on ``self``. The checkpoint, logistics and
        LR-scheduler callbacks are placed in ``self.callbacks`` for event
        dispatch. The early-stopping callback is only stored as an attribute
        (presumably invoked directly by the trainer — confirm).
        """
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Start from a fresh list. If ``callbacks`` is a class variable,
        # appending to it would share (and duplicate) callbacks across runs.
        self.callbacks = []
        # Add callbacks for execution during events.
        # lr_scheduler_callback must run before checkpoint_callback so that
        # lr_scheduler_callback._scheduler.step() happens before a checkpoint
        # is saved; otherwise the scheduler's saved last_epoch would be wrong.
        self.callbacks.append(self.lr_scheduler_callback)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
Ejemplo n.º 4
0
    def test_validation_parity(self, summarize_report_fn, test_reporter, sw,
                               mkdirs):
        """Validation summaries from the MMF training loop match ground truth.

        Runs 8 updates with an evaluation interval of 3. It then expects
        exactly 3 recorded calls on ``summarize_report_fn`` (presumably a
        patched mock) and checks them against ``self.ground_truths`` via
        ``self._check_values``.
        """
        trainer = get_mmf_trainer(
            max_updates=8,
            batch_size=2,
            max_epochs=None,
            evaluation_interval=3,
        )
        trainer.load_metrics()

        logistics = LogisticsCallback(trainer.config, trainer)
        logistics.snapshot_timer = Timer()
        logistics.train_timer = Timer()

        trainer.logistics_callback = logistics
        trainer.callbacks.append(logistics)
        # Early stopping is replaced with a no-op mock for this run.
        trainer.early_stop_callback = MagicMock(return_value=None)
        trainer.on_validation_end = logistics.on_validation_end
        trainer.training_loop()

        recorded_calls = summarize_report_fn.call_args_list
        self.assertEqual(3, len(recorded_calls))
        self.assertEqual(len(self.ground_truths), len(recorded_calls))
        self._check_values(recorded_calls)

    def configure_callbacks(self):
        """Instantiate the built-in trainer callbacks and register them.

        Creates checkpoint, early-stopping, logistics and LR-scheduler
        callbacks as attributes on ``self``. The LR-scheduler, checkpoint and
        logistics callbacks (in that order) form ``self.callbacks``. The
        early-stopping callback is only stored as an attribute (presumably
        invoked directly by the trainer — confirm).
        """
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Reset callbacks: if ``callbacks`` is a class variable, appending to
        # it would share (and duplicate) callbacks between repeated runs.
        self.callbacks = []
        # Add callbacks for execution during events
        self.callbacks.append(self.lr_scheduler_callback)
        # checkpoint_callback needs to be called after lr_scheduler_callback so that
        # lr_scheduler_callback._scheduler.step() happens before saving checkpoints
        # (otherwise the saved last_epoch in scheduler would be wrong)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
Ejemplo n.º 6
0
    def configure_callbacks(self):
        """Create the built-in callbacks and rebuild ``self.callbacks``.

        The checkpoint, early-stopping, logistics and LR-scheduler callbacks
        are stored as attributes. ``self.callbacks`` is then replaced with a
        new list holding the LR-scheduler, checkpoint and logistics callbacks,
        in that order.
        """
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # A brand-new list is assigned rather than appended to, because
        # ``callbacks`` is a class variable. Appending would share it between
        # multiple interactive shell calls to ``run``.
        # Order matters: the LR scheduler must step before the checkpoint
        # callback saves, otherwise the saved scheduler last_epoch is wrong.
        self.callbacks = [
            self.lr_scheduler_callback,
            self.checkpoint_callback,
            self.logistics_callback,
        ]
Ejemplo n.º 7
0
    def configure_callbacks(self):
        """Create the built-in callbacks plus any user-configured ones.

        Built-in callbacks are stored as attributes. ``self.callbacks`` is
        rebuilt as LR-scheduler, checkpoint and logistics callbacks, followed
        by one instance per entry in ``config.training.callbacks``.

        Each user entry must provide ``type`` (a name registered via the
        callback registry) and may provide ``params`` (keyword arguments
        forwarded to the callback's constructor).

        Raises:
            ValueError: if a user entry names a callback type for which the
                registry returns no class.
        """
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Reset callbacks as they are class variables and would be shared between
        # multiple interactive shell calls to `run`
        self.callbacks = []
        # Add callbacks for execution during events
        self.callbacks.append(self.lr_scheduler_callback)
        # checkpoint_callback needs to be called after lr_scheduler_callback so that
        # lr_scheduler_callback._scheduler.step() happens before saving checkpoints
        # (otherwise the saved last_epoch in scheduler would be wrong)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
        # Add all customized callbacks defined by users
        for callback in self.config.training.get("callbacks", []):
            callback_type = callback.type
            # ``params`` is optional: a missing or null value means
            # "no extra constructor arguments" instead of crashing on **None.
            callback_param = callback.get("params", None) or {}
            callback_cls = registry.get_callback_class(callback_type)
            if callback_cls is None:
                raise ValueError(
                    f"No callback class registered under the name "
                    f"'{callback_type}'"
                )
            self.callbacks.append(
                callback_cls(self.config, self, **callback_param))
class TestLogisticsCallback(unittest.TestCase):
    """Unit tests for ``LogisticsCallback`` hooks against a minimal fake trainer."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        self.trainer = argparse.Namespace()
        self.config = OmegaConf.create(
            {
                "model": "simple",
                "model_config": {},
                "training": {
                    "checkpoint_interval": 1,
                    "evaluation_interval": 10,
                    "early_stop": {"criteria": "val/total_loss"},
                    "batch_size": 16,
                    "log_interval": 10,
                    "logger_level": "info",
                },
                "env": {"save_dir": self.tmpdir},
            }
        )
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)
        # setup_logger is cached; clear it so this test's save_dir is used.
        setup_logger.cache_clear()
        setup_logger()
        self.report = Mock(spec=Report)
        self.report.dataset_name = "abcd"
        self.report.dataset_type = "test"

        self.trainer.model = SimpleModule()
        self.trainer.val_dataset = NumbersDataset()

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01
        )
        self.trainer.device = "cpu"
        self.trainer.num_updates = 0
        self.trainer.current_iteration = 0
        self.trainer.current_epoch = 0
        self.trainer.max_updates = 0
        self.trainer.meter = Meter()
        self.cb = LogisticsCallback(self.config, self.trainer)

    def tearDown(self):
        import shutil

        registry.unregister("config")
        # mkdtemp() leaves deletion to the caller. Remove the per-test save
        # directory; ignore errors (e.g. a log file still held open).
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _train_log_contains(self, needle):
        """Return True if any line of ``<tmpdir>/train.log`` contains ``needle``."""
        with PathManager.open(os.path.join(self.tmpdir, "train.log")) as f:
            return any(needle in line for line in f)

    def test_on_train_start(self):
        self.cb.on_train_start()
        expected = 0
        self.assertEqual(
            int(self.cb.train_timer.get_time_since_start().split("ms")[0]), expected
        )

    def test_on_update_end(self):
        self.cb.on_train_start()
        self.cb.on_update_end(meter=self.trainer.meter, should_log=False)
        self.assertFalse(self._train_log_contains("time_since_start"))
        self.cb.on_update_end(meter=self.trainer.meter, should_log=True)
        self.assertTrue(self._train_log_contains("time_since_start"))

    def test_on_validation_start(self):
        self.cb.on_train_start()
        self.cb.on_validation_start()
        expected = 0
        self.assertEqual(
            int(self.cb.snapshot_timer.get_time_since_start().split("ms")[0]), expected
        )

    def test_on_test_end(self):
        self.cb.on_test_end(report=self.report, meter=self.trainer.meter)
        self.assertTrue(self._train_log_contains("Finished run in"))