Code example #1
    def test_lr_schedule(self):
        # note: some of this logic also lives in SimpleLightningModel
        trainer1 = get_lightning_trainer(max_steps=8, lr_scheduler=True)
        trainer1.trainer.fit(trainer1.model, trainer1.data_module.train_loader)

        trainer2 = get_lightning_trainer(max_steps=8)
        trainer2.trainer.fit(trainer2.model, trainer2.data_module.train_loader)

        last_model_param1 = list(trainer1.model.parameters())[-1]
        last_model_param2 = list(trainer2.model.parameters())[-1]
        self.assertFalse(torch.allclose(last_model_param1, last_model_param2))
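Aside: the assertion holds because enabling an LR scheduler changes the optimization trajectory, and therefore the final weights. A minimal, self-contained sketch of the same idea in plain PyTorch (illustrative only; none of this is MMF API):

    import torch

    def train(use_scheduler: bool) -> torch.nn.Linear:
        torch.manual_seed(2)  # identical init and data for both runs
        model = torch.nn.Linear(4, 1)
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        sched = (torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
                 if use_scheduler else None)
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        for _ in range(8):  # mirrors max_steps=8
            opt.zero_grad()
            torch.nn.functional.mse_loss(model(x), y).backward()
            opt.step()
            if sched is not None:
                sched.step()
        return model

    # the runs differ only in the scheduler, so the final weights differ
    p1 = list(train(True).parameters())[-1]
    p2 = list(train(False).parameters())[-1]
    assert not torch.allclose(p1, p2)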
Code example #2
    def test_lr_schedule(self):
        with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
            # note: some of this logic also lives in SimpleLightningModel
            config = self._get_config(max_steps=8, lr_scheduler=True)
            trainer1 = get_lightning_trainer(config=config)
            trainer1.trainer.fit(trainer1.model, trainer1.data_module.train_loader)

            config = self._get_config(max_steps=8)
            trainer2 = get_lightning_trainer(config=config)
            trainer2.trainer.fit(trainer2.model, trainer2.data_module.train_loader)

            last_model_param1 = list(trainer1.model.parameters())[-1]
            last_model_param2 = list(trainer2.model.parameters())[-1]
            self.assertFalse(torch.allclose(last_model_param1, last_model_param2))
Code example #3
    def test_grad_accumulate(self):
        trainer1 = get_lightning_trainer(accumulate_grad_batches=2,
                                         max_steps=2,
                                         batch_size=3)
        trainer1.trainer.fit(trainer1.model, trainer1.data_module.train_loader)

        trainer2 = get_lightning_trainer(accumulate_grad_batches=1,
                                         max_steps=2,
                                         batch_size=6)
        trainer2.trainer.fit(trainer2.model, trainer2.data_module.train_loader)

        for param1, param2 in zip(trainer1.model.parameters(),
                                  trainer2.model.parameters()):
            self.assertTrue(torch.allclose(param1, param2))
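The parity being tested, two size-3 micro-batches accumulated per update versus one size-6 batch, relies on each micro-batch loss being scaled by 1/accumulation_steps, which is what gradient-accumulation implementations typically do. A self-contained sketch of that arithmetic in plain PyTorch (illustrative, not MMF code):

    import torch

    x, y = torch.randn(6, 4), torch.randn(6, 1)

    def make_model():
        torch.manual_seed(2)  # identical initial weights for both runs
        return torch.nn.Linear(4, 1)

    # run 1: two micro-batches of 3, gradients accumulated, one optimizer step
    m1 = make_model()
    opt1 = torch.optim.SGD(m1.parameters(), lr=0.1)
    for xb, yb in ((x[:3], y[:3]), (x[3:], y[3:])):
        (torch.nn.functional.mse_loss(m1(xb), yb) / 2).backward()
    opt1.step()

    # run 2: one batch of 6, one optimizer step
    m2 = make_model()
    opt2 = torch.optim.SGD(m2.parameters(), lr=0.1)
    torch.nn.functional.mse_loss(m2(x), y).backward()
    opt2.step()

    for p1, p2 in zip(m1.parameters(), m2.parameters()):
        assert torch.allclose(p1, p2)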
Code example #4
    def test_grad_clipping_and_parity_to_mmf(self):
        mmf_trainer = get_mmf_trainer(
            max_updates=5,
            max_epochs=None,
            grad_clipping_config=self.grad_clipping_config,
        )
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))

        def _finish_update():
            clip_gradients(
                mmf_trainer.model,
                mmf_trainer.optimizer,
                mmf_trainer.num_updates,
                None,
                mmf_trainer.config,
            )
            for param in mmf_trainer.model.parameters():
                mmf_grad = torch.clone(param.grad).detach().item()
                self.mmf_grads.append(mmf_grad)

            mmf_trainer.scaler.step(mmf_trainer.optimizer)
            mmf_trainer.scaler.update()
            mmf_trainer.num_updates += 1

        mmf_trainer._finish_update = _finish_update
        mmf_trainer.training_loop()

        trainer = get_lightning_trainer(
            max_steps=5,
            max_epochs=None,
            gradient_clip_val=self.grad_clip_magnitude,
            callback=self,
        )
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)
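The invariant checked on both sides is norm-based clipping: after clipping, the global gradient norm is at most the configured magnitude (self.grad_clip_magnitude above). A minimal sketch of that invariant with torch.nn.utils.clip_grad_norm_, which is presumably what both trainers invoke under the hood (an assumption here):

    import torch

    model = torch.nn.Linear(4, 1)
    # inflate the loss so the raw gradient norm clearly exceeds the clip value
    model(torch.randn(8, 4)).pow(2).sum().mul(100).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # global norm = sqrt(sum of squared per-parameter norms)
    global_norm = torch.norm(
        torch.stack([p.grad.norm() for p in model.parameters()]))
    assert global_norm <= 1.0 + 1e-6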
Code example #5
    def _get_lightning_trainer(
        self,
        ckpt_config=None,
        model_config=None,
        seed=2,
        max_steps=6,
        resume_from_checkpoint=None,
    ):
        config = self._get_ckpt_config(
            ckpt_config=ckpt_config,
            max_steps=max_steps,
            is_pl=True,
            resume_from_checkpoint=resume_from_checkpoint,
        )

        load_model_from_config = False
        if model_config:
            config.model_config = model_config
            config.model = list(model_config.keys())[0]
            load_model_from_config = True

        lightning = get_lightning_trainer(
            config=config,
            prepare_trainer=False,
            load_model_from_config=load_model_from_config,
            seed=seed,
        )
        callback = LightningLoopCallback(lightning)
        lightning.callbacks.append(callback)
        lightning.callbacks += lightning.configure_checkpoint_callbacks()
        lightning.callbacks += lightning.configure_monitor_callbacks()
        prepare_lightning_trainer(lightning)
        return lightning
Code example #6
    def test_epoch_over_updates(self):
        trainer = get_lightning_trainer(max_steps=2, max_epochs=0.04)
        self.assertEqual(trainer._max_updates, 4)

        self._check_values(trainer, 0, 0)
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)
        self._check_values(trainer, 4, 0)
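The expected 4 comes from the epoch length of the test loader: 0.04 epochs times 100 batches per epoch is 4 updates. (The 100-batch epoch is inferred from these numbers, and the exact rounding rule MMF applies is an assumption.) A sketch of the conversion:

    def max_updates_from_epochs(max_epochs: float, batches_per_epoch: int) -> int:
        # round to guard against float error in the fractional-epoch product
        return int(round(max_epochs * batches_per_epoch))

    assert max_updates_from_epochs(0.04, 100) == 4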
Code example #7
File: test_logging.py  Project: tranvinhhung99/mmf
    def test_tensorboard_logging_parity(
        self,
        summary_writer,
        mmf,
        lightning,
        logistics,
        logistics_logs,
        report_logs,
        trainer_logs,
        mkdirs,
    ):
        # mmf trainer
        mmf_trainer = get_mmf_trainer(
            max_updates=8,
            batch_size=2,
            max_epochs=None,
            log_interval=3,
            tensorboard=True,
        )

        def _add_scalars_mmf(log_dict, iteration):
            self.mmf_tensorboard_logs.append({iteration: log_dict})

        mmf_trainer.load_metrics()
        logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
        logistics_callback.snapshot_timer = MagicMock(return_value=None)
        logistics_callback.train_timer = Timer()
        logistics_callback.tb_writer.add_scalars = _add_scalars_mmf
        mmf_trainer.logistics_callback = logistics_callback
        mmf_trainer.callbacks = [logistics_callback]
        mmf_trainer.early_stop_callback = MagicMock(return_value=None)
        mmf_trainer.on_update_end = logistics_callback.on_update_end
        mmf_trainer.training_loop()

        # lightning_trainer
        trainer = get_lightning_trainer(
            max_steps=8,
            batch_size=2,
            prepare_trainer=False,
            log_every_n_steps=3,
            val_check_interval=9,
            tensorboard=True,
        )

        def _add_scalars_lightning(log_dict, iteration):
            self.lightning_tensorboard_logs.append({iteration: log_dict})

        def _on_fit_start_callback():
            trainer.tb_writer.add_scalars = _add_scalars_lightning

        callback = LightningLoopCallback(trainer)
        run_lightning_trainer_with_callback(
            trainer, callback, on_fit_start_callback=_on_fit_start_callback)
        self.assertEqual(len(self.mmf_tensorboard_logs),
                         len(self.lightning_tensorboard_logs))
        for mmf, lightning in zip(self.mmf_tensorboard_logs,
                                  self.lightning_tensorboard_logs):
            self.assertDictEqual(mmf, lightning)
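Both halves of this test use the same trick: swap the writer's add_scalars for a recorder, then diff the two recordings entry by entry. Stripped to its essence (plain Python; the two-argument signature mirrors how the test calls the replaced method, not SummaryWriter itself):

    records = []

    class FakeWriter:
        def add_scalars(self, log_dict, iteration):
            # record instead of writing to tensorboard
            records.append({iteration: log_dict})

    writer = FakeWriter()
    writer.add_scalars({"train/total_loss": 0.7}, 3)
    assert records == [{3: {"train/total_loss": 0.7}}]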
Code example #8
    def test_epoch_over_updates(self):
        with patch("mmf.trainers.lightning_trainer.get_mmf_env",
                   return_value=None):
            trainer = get_lightning_trainer(max_steps=2, max_epochs=0.04)
            self.assertEqual(trainer._max_updates, 4)

            self._check_values(trainer, 0, 0)
            trainer.trainer.fit(trainer.model,
                                trainer.data_module.train_loader)
            self._check_values(trainer, 4, 0)
Code example #9
    def test_validation(self, log_dir, mkdirs):
        config = self._get_config(
            max_steps=8,
            batch_size=2,
            val_check_interval=3,
            log_every_n_steps=9,  # interval > max_steps, so interval logging never fires
            limit_val_batches=1.0,
        )
        trainer = get_lightning_trainer(config=config, prepare_trainer=False)
        callback = LightningLoopCallback(trainer)
        trainer.callbacks.append(callback)
        lightning_values = []

        def log_values(
            current_iteration: int,
            num_updates: int,
            max_updates: int,
            meter: Meter,
            extra: Dict[str, Any],
            tb_writer: TensorboardLogger,
        ):
            lightning_values.append({
                "current_iteration": current_iteration,
                "num_updates": num_updates,
                "max_updates": max_updates,
                "avg_loss": meter.loss.avg,
            })

        with patch(
                "mmf.trainers.lightning_core.loop_callback.summarize_report",
                side_effect=log_values,
        ):
            run_lightning_trainer(trainer)

        self.assertEqual(len(self.ground_truths), len(lightning_values))
        for gt, lv in zip(self.ground_truths, lightning_values):
            keys = list(gt.keys())
            self.assertListEqual(keys, list(lv.keys()))
            for key in keys:
                if key == "num_updates" and gt[key] == self.ground_truths[-1][
                        key]:
                    # After training, in the last evaluation run, mmf's num_updates
                    # is 8 while lightning's is 9. This comes from a workaround
                    # that assigns the lightning num_updates to
                    # trainer.global_step + 1.
                    #
                    # The workaround is needed because of a lightning bug:
                    # trainer.global_step is one less than the actual step count.
                    # When on_train_batch_end is called for the first time,
                    # trainer.global_step should be 1 rather than 0, since one
                    # update/step has already been done.
                    #
                    # Once lightning fixes the bug, this test will be updated to
                    # remove the workaround. See pytorch lightning issue #6997.
                    self.assertAlmostEqual(gt[key], lv[key] - 1, 1)
                else:
                    self.assertAlmostEqual(gt[key], lv[key], 1)
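The workaround described in the comment above can be pictured as a callback that reports trainer.global_step + 1 from on_train_batch_end. A rough sketch (class name is made up, the hook signature varies across lightning versions, and this is not MMF's LightningLoopCallback):

    from pytorch_lightning.callbacks import Callback

    class NumUpdatesCallback(Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # global_step still reads 0 when the first update has finished
            # (pytorch lightning issue #6997), so add 1 to get the true count
            num_updates = trainer.global_step + 1
            print(f"completed updates: {num_updates}")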
Code example #10
    def test_updates(self):
        with patch("mmf.trainers.lightning_trainer.get_mmf_env",
                   return_value=""):
            config = self._get_config(max_steps=2, max_epochs=None)
            trainer = get_lightning_trainer(config=config)
            self.assertEqual(trainer._max_updates, 2)

            self._check_values(trainer, 0, 0)
            trainer.trainer.fit(trainer.model,
                                trainer.data_module.train_loader)
            self._check_values(trainer, 2, 1)
Code example #11
    def test_grad_accumulate(self):
        with patch("mmf.trainers.lightning_trainer.get_mmf_env",
                   return_value=""):
            config = self._get_config(accumulate_grad_batches=2,
                                      max_steps=2,
                                      batch_size=3)
            trainer1 = get_lightning_trainer(config=config)
            trainer1.trainer.fit(trainer1.model,
                                 trainer1.data_module.train_loader)

            config = self._get_config(accumulate_grad_batches=1,
                                      max_steps=2,
                                      batch_size=6)
            trainer2 = get_lightning_trainer(config=config)
            trainer2.trainer.fit(trainer2.model,
                                 trainer2.data_module.train_loader)

            for param1, param2 in zip(trainer1.model.parameters(),
                                      trainer2.model.parameters()):
                self.assertTrue(torch.allclose(param1, param2))
Code example #12
File: test_loss.py  Project: locoalien/mmf
    def test_loss_computation_parity_with_mmf_trainer(self):
        # compute mmf_trainer training losses
        def _on_update_end(report, meter, should_log):
            self.mmf_losses.append(report["losses"]["loss"].item())

        mmf_trainer = get_mmf_trainer(max_updates=5,
                                      max_epochs=None,
                                      on_update_end_fn=_on_update_end)
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        # compute lightning_trainer training losses
        trainer = get_lightning_trainer(callback=self, max_steps=5)
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)
Code example #13
    def test_lr_schedule_compared_to_mmf_is_same(self):
        trainer_config = get_trainer_config()
        mmf_trainer = get_mmf_trainer(
            max_updates=8,
            max_epochs=None,
            scheduler_config=trainer_config.scheduler)
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        trainer = get_lightning_trainer(max_steps=8, lr_scheduler=True)
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

        mmf_trainer.model.to(trainer.model.device)
        last_model_param1 = list(mmf_trainer.model.parameters())[-1]
        last_model_param2 = list(trainer.model.parameters())[-1]
        self.assertTrue(torch.allclose(last_model_param1, last_model_param2))
Code example #14
File: test_loss.py  Project: facebookresearch/mmf
    def test_loss_computation_parity_with_mmf_trainer(self):
        # compute mmf_trainer training losses
        def _on_update_end(report, meter, should_log):
            self.mmf_losses.append(report["losses"]["loss"].item())

        config = get_config_with_defaults(
            {"training": {"max_updates": 5, "max_epochs": None}}
        )
        mmf_trainer = get_mmf_trainer(config=config)
        mmf_trainer.on_update_end = _on_update_end
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        # compute lightning_trainer training losses
        with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
            config = get_config_with_defaults({"trainer": {"params": {"max_steps": 5}}})
            trainer = get_lightning_trainer(config=config)
            trainer.callbacks.append(self)
            trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)
Code example #15
    def test_lr_schedule_compared_to_mmf_is_same(self):
        config = get_config_with_defaults(
            {"training": {"max_updates": 8, "max_epochs": None, "lr_scheduler": True}}
        )

        mmf_trainer = get_mmf_trainer(config=config)
        mmf_trainer.lr_scheduler_callback = LRSchedulerCallback(config, mmf_trainer)
        mmf_trainer.callbacks.append(mmf_trainer.lr_scheduler_callback)
        mmf_trainer.on_update_end = mmf_trainer.lr_scheduler_callback.on_update_end
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
            config = self._get_config(max_steps=8, lr_scheduler=True)
            trainer = get_lightning_trainer(config=config)
            trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

            mmf_trainer.model.to(trainer.model.device)
            last_model_param1 = list(mmf_trainer.model.parameters())[-1]
            last_model_param2 = list(trainer.model.parameters())[-1]
            self.assertTrue(torch.allclose(last_model_param1, last_model_param2))
Code example #16
    def test_validation(self, log_dir, mkdirs):
        trainer = get_lightning_trainer(
            max_steps=8,
            batch_size=2,
            prepare_trainer=False,
            val_check_interval=3,
            log_every_n_steps=9,  # interval > max_steps, so interval logging never fires
            limit_val_batches=1.0,
        )
        callback = LightningLoopCallback(trainer)
        lightning_values = []

        def log_values(
            current_iteration: int,
            num_updates: int,
            max_updates: int,
            meter: Meter,
            extra: Dict[str, Any],
            tb_writer: TensorboardLogger,
        ):
            lightning_values.append({
                "current_iteration": current_iteration,
                "num_updates": num_updates,
                "max_updates": max_updates,
                "avg_loss": meter.loss.avg,
            })

        with patch(
                "mmf.trainers.lightning_core.loop_callback.summarize_report",
                side_effect=log_values,
        ):
            run_lightning_trainer_with_callback(trainer, callback)

        self.assertEqual(len(self.ground_truths), len(lightning_values))
        for gt, lv in zip(self.ground_truths, lightning_values):
            keys = list(gt.keys())
            self.assertListEqual(keys, list(lv.keys()))
            for key in keys:
                self.assertAlmostEqual(gt[key], lv[key], 1)