def test_lr_schedule(self):
    """Training with an LR scheduler must end at different weights than without."""
    # note: some of the relevant logic also lives in the SimpleLightningModel
    scheduled = get_lightning_trainer(max_steps=8, lr_scheduler=True)
    scheduled.trainer.fit(scheduled.model, scheduled.data_module.train_loader)

    unscheduled = get_lightning_trainer(max_steps=8)
    unscheduled.trainer.fit(unscheduled.model, unscheduled.data_module.train_loader)

    # Compare the final parameter tensor of each trained model; a scheduler
    # that actually changed the learning rate must yield different weights.
    final_param_scheduled = list(scheduled.model.parameters())[-1]
    final_param_unscheduled = list(unscheduled.model.parameters())[-1]
    self.assertFalse(torch.allclose(final_param_scheduled, final_param_unscheduled))
def test_lr_schedule(self):
    """With vs. without an LR scheduler must produce different final weights."""
    # note: some of the relevant logic also lives in the SimpleLightningModel
    with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
        sched_config = self._get_config(max_steps=8, lr_scheduler=True)
        sched_trainer = get_lightning_trainer(config=sched_config)
        sched_trainer.trainer.fit(
            sched_trainer.model, sched_trainer.data_module.train_loader
        )

        plain_config = self._get_config(max_steps=8)
        plain_trainer = get_lightning_trainer(config=plain_config)
        plain_trainer.trainer.fit(
            plain_trainer.model, plain_trainer.data_module.train_loader
        )

        # Last parameter tensor of each model is a cheap proxy for "weights differ".
        param_with_sched = list(sched_trainer.model.parameters())[-1]
        param_without_sched = list(plain_trainer.model.parameters())[-1]
        self.assertFalse(torch.allclose(param_with_sched, param_without_sched))
def test_grad_accumulate(self):
    """Accumulating 2 batches of size 3 must equal one batch of size 6."""
    accum = get_lightning_trainer(
        accumulate_grad_batches=2, max_steps=2, batch_size=3
    )
    accum.trainer.fit(accum.model, accum.data_module.train_loader)

    plain = get_lightning_trainer(
        accumulate_grad_batches=1, max_steps=2, batch_size=6
    )
    plain.trainer.fit(plain.model, plain.data_module.train_loader)

    # The two runs see the same total data per update, so every parameter
    # should come out numerically identical.
    for accum_param, plain_param in zip(
        accum.model.parameters(), plain.model.parameters()
    ):
        self.assertTrue(torch.allclose(accum_param, plain_param))
def test_grad_clipping_and_parity_to_mmf(self):
    """Gradients clipped by the mmf trainer should match lightning's built-in
    gradient clipping for the same clip magnitude and number of updates."""
    mmf_trainer = get_mmf_trainer(
        max_updates=5,
        max_epochs=None,
        grad_clipping_config=self.grad_clipping_config,
    )
    # Evaluation is irrelevant here; stub it out so only training runs.
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))

    def _finish_update():
        # Replacement for the trainer's update-finishing step: clip, record
        # every clipped gradient into self.mmf_grads (compared later on the
        # lightning side), then step/update the scaler as the trainer would.
        clip_gradients(
            mmf_trainer.model,
            mmf_trainer.optimizer,
            mmf_trainer.num_updates,
            None,
            mmf_trainer.config,
        )
        for param in mmf_trainer.model.parameters():
            # .item() assumes each parameter's grad is a single-element tensor
            # — presumably the test model is tiny; TODO confirm.
            mmf_grad = torch.clone(param.grad).detach().item()
            self.mmf_grads.append(mmf_grad)
        mmf_trainer.scaler.step(mmf_trainer.optimizer)
        mmf_trainer.scaler.update()
        mmf_trainer.num_updates += 1

    mmf_trainer._finish_update = _finish_update
    mmf_trainer.training_loop()

    # Lightning side: rely on the trainer's gradient_clip_val; `self` is passed
    # as a callback, presumably to capture lightning-side grads for comparison.
    trainer = get_lightning_trainer(
        max_steps=5,
        max_epochs=None,
        gradient_clip_val=self.grad_clip_magnitude,
        callback=self,
    )
    trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)
def _get_lightning_trainer(
    self,
    ckpt_config=None,
    model_config=None,
    seed=2,
    max_steps=6,
    resume_from_checkpoint=None,
):
    """Build a fully prepared lightning trainer for checkpoint tests.

    If `model_config` is given (and non-empty), it overrides the config's
    model section and its first key becomes the active model. Loop,
    checkpoint, and monitor callbacks are wired up before preparation.
    """
    config = self._get_ckpt_config(
        ckpt_config=ckpt_config,
        max_steps=max_steps,
        is_pl=True,
        resume_from_checkpoint=resume_from_checkpoint,
    )

    load_model_from_config = bool(model_config)
    if load_model_from_config:
        config.model_config = model_config
        config.model = next(iter(model_config.keys()))

    lightning = get_lightning_trainer(
        config=config,
        prepare_trainer=False,
        load_model_from_config=load_model_from_config,
        seed=seed,
    )
    lightning.callbacks.append(LightningLoopCallback(lightning))
    lightning.callbacks += lightning.configure_checkpoint_callbacks()
    lightning.callbacks += lightning.configure_monitor_callbacks()
    prepare_lightning_trainer(lightning)
    return lightning
def test_epoch_over_updates(self):
    """When max_epochs implies more updates than max_steps, the epoch bound wins."""
    trainer = get_lightning_trainer(max_steps=2, max_epochs=0.04)
    # 0.04 epochs works out to 4 updates here, overriding max_steps=2.
    self.assertEqual(trainer._max_updates, 4)

    self._check_values(trainer, 0, 0)
    model = trainer.model
    train_loader = trainer.data_module.train_loader
    trainer.trainer.fit(model, train_loader)
    self._check_values(trainer, 4, 0)
def test_tensorboard_logging_parity(
    self,
    summary_writer,
    mmf,
    lightning,
    logistics,
    logistics_logs,
    report_logs,
    trainer_logs,
    mkdirs,
):
    """Tensorboard scalars logged by the mmf trainer and the lightning trainer
    must match entry-for-entry for an identical training configuration.

    Fix: the final comparison loop used `for mmf, lightning in zip(...)`,
    shadowing the `mmf` and `lightning` mock parameters of this method —
    the loop variables are renamed to avoid the shadowing.
    """
    # mmf trainer
    mmf_trainer = get_mmf_trainer(
        max_updates=8,
        batch_size=2,
        max_epochs=None,
        log_interval=3,
        tensorboard=True,
    )

    def _add_scalars_mmf(log_dict, iteration):
        # Capture what would have gone to tensorboard, keyed by iteration.
        self.mmf_tensorboard_logs.append({iteration: log_dict})

    mmf_trainer.load_metrics()
    logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
    logistics_callback.snapshot_timer = MagicMock(return_value=None)
    logistics_callback.train_timer = Timer()
    logistics_callback.tb_writer.add_scalars = _add_scalars_mmf
    mmf_trainer.logistics_callback = logistics_callback
    mmf_trainer.callbacks = [logistics_callback]
    mmf_trainer.early_stop_callback = MagicMock(return_value=None)
    mmf_trainer.on_update_end = logistics_callback.on_update_end
    mmf_trainer.training_loop()

    # lightning_trainer
    trainer = get_lightning_trainer(
        max_steps=8,
        batch_size=2,
        prepare_trainer=False,
        log_every_n_steps=3,
        val_check_interval=9,
        tensorboard=True,
    )

    def _add_scalars_lightning(log_dict, iteration):
        self.lightning_tensorboard_logs.append({iteration: log_dict})

    def _on_fit_start_callback():
        # The tb_writer only exists once fit starts, so hook it in here.
        trainer.tb_writer.add_scalars = _add_scalars_lightning

    callback = LightningLoopCallback(trainer)
    run_lightning_trainer_with_callback(
        trainer, callback, on_fit_start_callback=_on_fit_start_callback
    )

    self.assertEqual(
        len(self.mmf_tensorboard_logs), len(self.lightning_tensorboard_logs)
    )
    # Renamed from `mmf`/`lightning` to avoid shadowing the mock parameters.
    for mmf_log, lightning_log in zip(
        self.mmf_tensorboard_logs, self.lightning_tensorboard_logs
    ):
        self.assertDictEqual(mmf_log, lightning_log)
def test_epoch_over_updates(self):
    """The epoch-derived update cap should override a smaller max_steps."""
    with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=None):
        trainer = get_lightning_trainer(max_steps=2, max_epochs=0.04)
        # 0.04 epochs works out to 4 updates here, overriding max_steps=2.
        self.assertEqual(trainer._max_updates, 4)

        self._check_values(trainer, 0, 0)
        model = trainer.model
        train_loader = trainer.data_module.train_loader
        trainer.trainer.fit(model, train_loader)
        self._check_values(trainer, 4, 0)
def test_validation(self, log_dir, mkdirs):
    """Validation summaries emitted by the lightning loop callback should
    match the precomputed ground-truth values (modulo the known off-by-one
    in num_updates documented below)."""
    config = self._get_config(
        max_steps=8,
        batch_size=2,
        val_check_interval=3,
        log_every_n_steps=9,  # turn it off
        limit_val_batches=1.0,
    )
    trainer = get_lightning_trainer(config=config, prepare_trainer=False)
    callback = LightningLoopCallback(trainer)
    trainer.callbacks.append(callback)
    lightning_values = []

    def log_values(
        current_iteration: int,
        num_updates: int,
        max_updates: int,
        meter: Meter,
        extra: Dict[str, Any],
        tb_writer: TensorboardLogger,
    ):
        # Stand-in for summarize_report: record what would have been logged.
        lightning_values.append({
            "current_iteration": current_iteration,
            "num_updates": num_updates,
            "max_updates": max_updates,
            "avg_loss": meter.loss.avg,
        })

    with patch(
        "mmf.trainers.lightning_core.loop_callback.summarize_report",
        side_effect=log_values,
    ):
        run_lightning_trainer(trainer)

    self.assertEqual(len(self.ground_truths), len(lightning_values))
    for gt, lv in zip(self.ground_truths, lightning_values):
        keys = list(gt.keys())
        self.assertListEqual(keys, list(lv.keys()))
        for key in keys:
            if key == "num_updates" and gt[key] == self.ground_truths[-1][
                key]:
                # After training, in the last evaluation run, mmf's num updates is 8
                # while lightning's num updates is 9, this is due to a hack to
                # assign the lightning num_updates to be the trainer.global_step+1.
                #
                # This is necessary because of a lightning bug: trainer.global_step
                # is 1 off less than the actual step count. When on_train_batch_end
                # is called for the first time, the trainer.global_step should be 1,
                # rather than 0, since 1 update/step has already been done.
                #
                # When lightning fixes its bug, we will update this test to remove
                # the hack.
                # issue: 6997 in pytorch lightning
                self.assertAlmostEqual(gt[key], lv[key] - 1, 1)
            else:
                self.assertAlmostEqual(gt[key], lv[key], 1)
def test_updates(self):
    """max_steps=2 should cap training at exactly 2 updates."""
    with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
        cfg = self._get_config(max_steps=2, max_epochs=None)
        trainer = get_lightning_trainer(config=cfg)
        self.assertEqual(trainer._max_updates, 2)

        # counters start at zero before fit
        self._check_values(trainer, 0, 0)
        model = trainer.model
        train_loader = trainer.data_module.train_loader
        trainer.trainer.fit(model, train_loader)
        self._check_values(trainer, 2, 1)
def test_grad_accumulate(self):
    """Accumulating 2 batches of size 3 must equal one batch of size 6."""
    with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
        accum_config = self._get_config(
            accumulate_grad_batches=2, max_steps=2, batch_size=3
        )
        accum_trainer = get_lightning_trainer(config=accum_config)
        accum_trainer.trainer.fit(
            accum_trainer.model, accum_trainer.data_module.train_loader
        )

        plain_config = self._get_config(
            accumulate_grad_batches=1, max_steps=2, batch_size=6
        )
        plain_trainer = get_lightning_trainer(config=plain_config)
        plain_trainer.trainer.fit(
            plain_trainer.model, plain_trainer.data_module.train_loader
        )

        # Same total data per update, so all parameters should match exactly.
        for accum_param, plain_param in zip(
            accum_trainer.model.parameters(), plain_trainer.model.parameters()
        ):
            self.assertTrue(torch.allclose(accum_param, plain_param))
def test_loss_computation_parity_with_mmf_trainer(self):
    """Per-update training losses must match between mmf and lightning."""

    # capture each mmf training loss as it happens
    def _record_loss(report, meter, should_log):
        self.mmf_losses.append(report["losses"]["loss"].item())

    mmf_trainer = get_mmf_trainer(
        max_updates=5, max_epochs=None, on_update_end_fn=_record_loss
    )
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
    mmf_trainer.training_loop()

    # lightning side: `self` is registered as the callback comparing losses
    pl_trainer = get_lightning_trainer(callback=self, max_steps=5)
    pl_trainer.trainer.fit(pl_trainer.model, pl_trainer.data_module.train_loader)
def test_lr_schedule_compared_to_mmf_is_same(self):
    """Lightning's LR scheduling must reproduce mmf's scheduled training."""
    trainer_config = get_trainer_config()
    mmf_trainer = get_mmf_trainer(
        max_updates=8,
        max_epochs=None,
        scheduler_config=trainer_config.scheduler,
    )
    # evaluation is irrelevant here; stub it out
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
    mmf_trainer.training_loop()

    pl_trainer = get_lightning_trainer(max_steps=8, lr_scheduler=True)
    pl_trainer.trainer.fit(pl_trainer.model, pl_trainer.data_module.train_loader)

    # move mmf's model onto the same device before comparing weights
    mmf_trainer.model.to(pl_trainer.model.device)
    mmf_last_param = list(mmf_trainer.model.parameters())[-1]
    pl_last_param = list(pl_trainer.model.parameters())[-1]
    self.assertTrue(torch.allclose(mmf_last_param, pl_last_param))
def test_loss_computation_parity_with_mmf_trainer(self):
    """Per-update training losses must match between mmf and lightning."""

    # capture each mmf training loss as it happens
    def _record_loss(report, meter, should_log):
        self.mmf_losses.append(report["losses"]["loss"].item())

    mmf_config = get_config_with_defaults(
        {"training": {"max_updates": 5, "max_epochs": None}}
    )
    mmf_trainer = get_mmf_trainer(config=mmf_config)
    mmf_trainer.on_update_end = _record_loss
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
    mmf_trainer.training_loop()

    # lightning side: `self` acts as the callback comparing losses
    with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
        pl_config = get_config_with_defaults(
            {"trainer": {"params": {"max_steps": 5}}}
        )
        pl_trainer = get_lightning_trainer(config=pl_config)
        pl_trainer.callbacks.append(self)
        pl_trainer.trainer.fit(
            pl_trainer.model, pl_trainer.data_module.train_loader
        )
def test_lr_schedule_compared_to_mmf_is_same(self):
    """mmf's LRSchedulerCallback and lightning's scheduler must train to
    identical final weights for the same setup."""
    mmf_config = get_config_with_defaults(
        {"training": {"max_updates": 8, "max_epochs": None, "lr_scheduler": True}}
    )
    mmf_trainer = get_mmf_trainer(config=mmf_config)
    scheduler_callback = LRSchedulerCallback(mmf_config, mmf_trainer)
    mmf_trainer.lr_scheduler_callback = scheduler_callback
    mmf_trainer.callbacks.append(scheduler_callback)
    mmf_trainer.on_update_end = scheduler_callback.on_update_end
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
    mmf_trainer.training_loop()

    with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
        pl_config = self._get_config(max_steps=8, lr_scheduler=True)
        pl_trainer = get_lightning_trainer(config=pl_config)
        pl_trainer.trainer.fit(
            pl_trainer.model, pl_trainer.data_module.train_loader
        )

        # move mmf's model onto the same device before comparing weights
        mmf_trainer.model.to(pl_trainer.model.device)
        mmf_last_param = list(mmf_trainer.model.parameters())[-1]
        pl_last_param = list(pl_trainer.model.parameters())[-1]
        self.assertTrue(torch.allclose(mmf_last_param, pl_last_param))
def test_validation(self, log_dir, mkdirs):
    """Validation summaries must match the precomputed ground-truth values."""
    trainer = get_lightning_trainer(
        max_steps=8,
        batch_size=2,
        prepare_trainer=False,
        val_check_interval=3,
        log_every_n_steps=9,  # effectively turns train logging off
        limit_val_batches=1.0,
    )
    callback = LightningLoopCallback(trainer)

    observed = []

    # Stand-in for summarize_report: records what would have been logged.
    # Parameter names mirror summarize_report's so keyword calls still bind.
    def log_values(
        current_iteration: int,
        num_updates: int,
        max_updates: int,
        meter: Meter,
        extra: Dict[str, Any],
        tb_writer: TensorboardLogger,
    ):
        observed.append(
            {
                "current_iteration": current_iteration,
                "num_updates": num_updates,
                "max_updates": max_updates,
                "avg_loss": meter.loss.avg,
            }
        )

    with patch(
        "mmf.trainers.lightning_core.loop_callback.summarize_report",
        side_effect=log_values,
    ):
        run_lightning_trainer_with_callback(trainer, callback)

    self.assertEqual(len(self.ground_truths), len(observed))
    for expected, actual in zip(self.ground_truths, observed):
        expected_keys = list(expected.keys())
        self.assertListEqual(expected_keys, list(actual.keys()))
        for key in expected_keys:
            self.assertAlmostEqual(expected[key], actual[key], 1)