Example #1
    def _get_mmf_trainer(self,
                         ckpt_config=None,
                         model_config=None,
                         seed=2,
                         max_updates=6):
        config = self._get_ckpt_config(ckpt_config=ckpt_config,
                                       max_steps=max_updates)

        load_model_from_config = False
        if model_config:
            # pick up the single model key from the supplied model_config
            config.model_config = model_config
            config.model = list(model_config.keys())[0]
            load_model_from_config = True

        mmf_trainer = get_mmf_trainer(
            config=config,
            load_model_from_config=load_model_from_config,
            seed=seed)
        mmf_trainer.load_metrics()

        # route the trainer's lifecycle hooks through the checkpoint callback
        checkpoint_callback = CheckpointCallback(config, mmf_trainer)
        mmf_trainer.on_init_start = checkpoint_callback.on_init_start
        mmf_trainer.on_train_end = checkpoint_callback.on_train_end
        mmf_trainer.callbacks.append(checkpoint_callback)
        mmf_trainer.checkpoint_callback = checkpoint_callback

        mmf_trainer.lr_scheduler_callback = None

        early_stop_callback = EarlyStoppingCallback(config, mmf_trainer)
        mmf_trainer.early_stop_callback = early_stop_callback
        mmf_trainer.callbacks.append(early_stop_callback)

        return mmf_trainer
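A hypothetical call site for the helper above (the ckpt_config key shown is an assumption, not from the source):

    trainer = self._get_mmf_trainer(
        ckpt_config={"max_to_keep": 5},  # assumed checkpoint option
        seed=2,
        max_updates=6,
    )
    trainer.training_loop()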
Example #2
    def test_grad_clipping_and_parity_to_mmf(self):
        mmf_trainer = get_mmf_trainer(
            max_updates=5,
            max_epochs=None,
            grad_clipping_config=self.grad_clipping_config,
        )
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))

        def _finish_update():
            clip_gradients(
                mmf_trainer.model,
                mmf_trainer.optimizer,
                mmf_trainer.num_updates,
                None,
                mmf_trainer.config,
            )
            for param in mmf_trainer.model.parameters():
                # .item() assumes each parameter in the test model is a scalar
                mmf_grad = torch.clone(param.grad).detach().item()
                self.mmf_grads.append(mmf_grad)

            mmf_trainer.scaler.step(mmf_trainer.optimizer)
            mmf_trainer.scaler.update()
            mmf_trainer.num_updates += 1

        mmf_trainer._finish_update = _finish_update
        mmf_trainer.training_loop()

        trainer = get_lightning_trainer(
            max_steps=5,
            max_epochs=None,
            gradient_clip_val=self.grad_clip_magnitude,
            callback=self,
        )
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)
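Since the test passes itself as `callback`, it presumably implements a Lightning hook that records per-parameter gradients after clipping, mirroring how self.mmf_grads is filled on the MMF side above.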
Example #3
    def test_eval_loop(self, a, b):
        config = get_config_with_defaults(
            {"training": {"max_updates": 2, "max_epochs": 2}}
        )
        trainer = get_mmf_trainer(config=config)
        combined_report, meter = trainer.evaluation_loop("val")
        self.assertAlmostEqual(combined_report["losses"]["loss"], 493377.5312)
        self.assertAlmostEqual(combined_report["logits"].item(), -0.2379742, 6)
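The unused `a` and `b` arguments indicate the method was wrapped with unittest.mock.patch decorators that were stripped from this snippet. A plausible reconstruction (both patch targets are assumptions, not from the source):

    @patch("mmf.common.test_reporter.TestReporter")  # assumed target
    @patch("torch.utils.tensorboard.SummaryWriter")  # assumed target
    def test_eval_loop(self, a, b):
        ...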
Example #4
    def test_tensorboard_logging_parity(
        self,
        summary_writer,
        mmf,
        lightning,
        logistics,
        logistics_logs,
        report_logs,
        trainer_logs,
        mkdirs,
    ):
        # mmf trainer
        mmf_trainer = get_mmf_trainer(
            max_updates=8,
            batch_size=2,
            max_epochs=None,
            log_interval=3,
            tensorboard=True,
        )

        def _add_scalars_mmf(log_dict, iteration):
            self.mmf_tensorboard_logs.append({iteration: log_dict})

        mmf_trainer.load_metrics()
        logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
        logistics_callback.snapshot_timer = MagicMock(return_value=None)
        logistics_callback.train_timer = Timer()
        logistics_callback.tb_writer.add_scalars = _add_scalars_mmf
        mmf_trainer.logistics_callback = logistics_callback
        mmf_trainer.callbacks = [logistics_callback]
        mmf_trainer.early_stop_callback = MagicMock(return_value=None)
        mmf_trainer.on_update_end = logistics_callback.on_update_end
        mmf_trainer.training_loop()

        # lightning_trainer
        trainer = get_lightning_trainer(
            max_steps=8,
            batch_size=2,
            prepare_trainer=False,
            log_every_n_steps=3,
            val_check_interval=9,
            tensorboard=True,
        )

        def _add_scalars_lightning(log_dict, iteration):
            self.lightning_tensorboard_logs.append({iteration: log_dict})

        def _on_fit_start_callback():
            trainer.tb_writer.add_scalars = _add_scalars_lightning

        callback = LightningLoopCallback(trainer)
        run_lightning_trainer_with_callback(
            trainer, callback, on_fit_start_callback=_on_fit_start_callback)
        self.assertEqual(len(self.mmf_tensorboard_logs),
                         len(self.lightning_tensorboard_logs))
        # renamed loop variables so they don't shadow the mmf/lightning
        # mock parameters of this test method
        for mmf_log, lightning_log in zip(self.mmf_tensorboard_logs,
                                          self.lightning_tensorboard_logs):
            self.assertDictEqual(mmf_log, lightning_log)
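The parity hinges on matching cadence: MMF logs every 3 updates (log_interval=3) while Lightning logs every 3 steps (log_every_n_steps=3); val_check_interval=9 is presumably chosen so validation never fires within the 8 training steps.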
Example #5
    def test_update_frequency_correct_final_iteration(self, a):
        config = self._get_config(max_updates=2,
                                  max_epochs=None,
                                  update_frequency=2)
        trainer = get_mmf_trainer(config=config)
        trainer.load_datasets()
        trainer.training_loop()
        self.assertEqual(trainer.max_updates, 2)
        self.assertEqual(trainer.current_iteration, 4)
Example #6
    def test_updates(self, a):
        config = self._get_config(max_updates=2, max_epochs=None)
        trainer = get_mmf_trainer(config=config)
        max_updates = trainer._calculate_max_updates()
        self.assertEqual(max_updates, 2)

        self.check_values(trainer, 0, 0, 0)
        trainer.training_loop()
        self.check_values(trainer, 2, 1, 2)
Example #7
    def test_fractional_epoch(self, a):
        config = self._get_config(max_updates=None, max_epochs=0.04)
        trainer = get_mmf_trainer(config=config)
        max_updates = trainer._calculate_max_updates()
        self.assertEqual(max_updates, 4)

        self.check_values(trainer, 0, 0, 0)
        trainer.training_loop()
        self.check_values(trainer, 4, 1, 4)
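A plausible shape for the check_values helper used in the two tests above (assumed, not shown in the source):

    def check_values(self, trainer, current_iteration, current_epoch, num_updates):
        self.assertEqual(trainer.current_iteration, current_iteration)
        self.assertEqual(trainer.current_epoch, current_epoch)
        self.assertEqual(trainer.num_updates, num_updates)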
Example #8
    def test_loss_computation_parity_with_mmf_trainer(self):
        # compute mmf_trainer training losses
        def _on_update_end(report, meter, should_log):
            self.mmf_losses.append(report["losses"]["loss"].item())

        mmf_trainer = get_mmf_trainer(max_updates=5,
                                      max_epochs=None,
                                      on_update_end_fn=_on_update_end)
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        # compute lightning_trainer training losses
        trainer = get_lightning_trainer(callback=self, max_steps=5)
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)
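As in Example #2, the test passes itself as `callback`, so the Lightning run presumably records its per-step losses through a hook for comparison against self.mmf_losses.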
Example #9
    def test_lr_schedule_compared_to_mmf_is_same(self):
        trainer_config = get_trainer_config()
        mmf_trainer = get_mmf_trainer(
            max_updates=8,
            max_epochs=None,
            scheduler_config=trainer_config.scheduler)
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        trainer = get_lightning_trainer(max_steps=8, lr_scheduler=True)
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

        mmf_trainer.model.to(trainer.model.device)
        last_model_param1 = list(mmf_trainer.model.parameters())[-1]
        last_model_param2 = list(trainer.model.parameters())[-1]
        self.assertTrue(torch.allclose(last_model_param1, last_model_param2))
Example #10
    def test_loss_computation_parity_with_mmf_trainer(self):
        # compute mmf_trainer training losses
        def _on_update_end(report, meter, should_log):
            self.mmf_losses.append(report["losses"]["loss"].item())

        config = get_config_with_defaults(
            {"training": {"max_updates": 5, "max_epochs": None}}
        )
        mmf_trainer = get_mmf_trainer(config=config)
        mmf_trainer.on_update_end = _on_update_end
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        # compute lightning_trainer training losses
        with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
            config = get_config_with_defaults({"trainer": {"params": {"max_steps": 5}}})
            trainer = get_lightning_trainer(config=config)
            trainer.callbacks.append(self)
            trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)
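Patching mmf.trainers.lightning_trainer.get_mmf_env to return an empty string presumably keeps the Lightning trainer from resolving a real save directory during the test.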
Example #11
    def test_validation_parity(self, summarize_report_fn, test_reporter, sw,
                               mkdirs):
        mmf_trainer = get_mmf_trainer(max_updates=8,
                                      batch_size=2,
                                      max_epochs=None,
                                      evaluation_interval=3)
        mmf_trainer.load_metrics()
        logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
        logistics_callback.snapshot_timer = Timer()
        logistics_callback.train_timer = Timer()
        mmf_trainer.logistics_callback = logistics_callback
        mmf_trainer.callbacks.append(logistics_callback)
        mmf_trainer.early_stop_callback = MagicMock(return_value=None)
        mmf_trainer.on_validation_end = logistics_callback.on_validation_end
        mmf_trainer.training_loop()

        calls = summarize_report_fn.call_args_list
        self.assertEqual(3, len(calls))
        self.assertEqual(len(self.ground_truths), len(calls))
        self._check_values(calls)
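With max_updates=8 and evaluation_interval=3, validation presumably runs at updates 3 and 6 and once more when training ends, which is why exactly 3 summarize_report calls are expected.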
Example #12
    def _train_with_condition(
        self,
        num_train_data,
        max_updates,
        max_epochs,
        update_frequency,
        batch_size,
        on_update_end_fn=None,
    ):
        torch.random.manual_seed(2)
        config = self._get_config(
            max_updates=max_updates,
            max_epochs=max_epochs,
            update_frequency=update_frequency,
            batch_size=batch_size,
        )
        trainer = get_mmf_trainer(num_data_size=num_train_data, config=config)
        if on_update_end_fn:
            trainer.on_update_end = on_update_end_fn
        trainer.training_loop()
        return trainer
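Hypothetical usage of the helper above (the argument values are assumptions, not from the source):

    trainer = self._train_with_condition(
        num_train_data=100,
        max_updates=2,
        max_epochs=None,
        update_frequency=2,  # two forward/backward passes per optimizer step
        batch_size=2,
    )
    self.assertEqual(trainer.num_updates, 2)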
Example #13
    def test_lr_schedule_compared_to_mmf_is_same(self):
        config = get_config_with_defaults(
            {"training": {"max_updates": 8, "max_epochs": None, "lr_scheduler": True}}
        )

        mmf_trainer = get_mmf_trainer(config=config)
        mmf_trainer.lr_scheduler_callback = LRSchedulerCallback(config, mmf_trainer)
        mmf_trainer.callbacks.append(mmf_trainer.lr_scheduler_callback)
        mmf_trainer.on_update_end = mmf_trainer.lr_scheduler_callback.on_update_end
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
            config = self._get_config(max_steps=8, lr_scheduler=True)
            trainer = get_lightning_trainer(config=config)
            trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

            mmf_trainer.model.to(trainer.model.device)
            last_model_param1 = list(mmf_trainer.model.parameters())[-1]
            last_model_param2 = list(trainer.model.parameters())[-1]
            self.assertTrue(torch.allclose(last_model_param1, last_model_param2))