def _get_mmf_trainer(
    self, ckpt_config=None, model_config=None, seed=2, max_updates=6
):
    config = self._get_ckpt_config(ckpt_config=ckpt_config, max_steps=max_updates)

    # only load the model from config when an explicit model_config is given
    load_model_from_config = False
    if model_config:
        config.model_config = model_config
        config.model = list(model_config.keys())[0]
        load_model_from_config = True

    mmf_trainer = get_mmf_trainer(
        config=config, load_model_from_config=load_model_from_config, seed=seed
    )
    mmf_trainer.load_metrics()

    # wire up checkpointing
    checkpoint_callback = CheckpointCallback(config, mmf_trainer)
    mmf_trainer.on_init_start = checkpoint_callback.on_init_start
    mmf_trainer.on_train_end = checkpoint_callback.on_train_end
    mmf_trainer.callbacks.append(checkpoint_callback)
    mmf_trainer.checkpoint_callback = checkpoint_callback

    mmf_trainer.lr_scheduler_callback = None

    # wire up early stopping
    early_stop_callback = EarlyStoppingCallback(config, mmf_trainer)
    mmf_trainer.early_stop_callback = early_stop_callback
    mmf_trainer.callbacks.append(early_stop_callback)

    return mmf_trainer

def test_grad_clipping_and_parity_to_mmf(self):
    # mmf trainer: clip gradients manually in _finish_update and record them
    mmf_trainer = get_mmf_trainer(
        max_updates=5,
        max_epochs=None,
        grad_clipping_config=self.grad_clipping_config,
    )
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))

    def _finish_update():
        clip_gradients(
            mmf_trainer.model,
            mmf_trainer.optimizer,
            mmf_trainer.num_updates,
            None,
            mmf_trainer.config,
        )
        for param in mmf_trainer.model.parameters():
            mmf_grad = torch.clone(param.grad).detach().item()
            self.mmf_grads.append(mmf_grad)

        mmf_trainer.scaler.step(mmf_trainer.optimizer)
        mmf_trainer.scaler.update()
        mmf_trainer.num_updates += 1

    mmf_trainer._finish_update = _finish_update
    mmf_trainer.training_loop()

    # lightning trainer: rely on gradient_clip_val; gradients are compared
    # against self.mmf_grads in this test class's callback hooks
    trainer = get_lightning_trainer(
        max_steps=5,
        max_epochs=None,
        gradient_clip_val=self.grad_clip_magnitude,
        callback=self,
    )
    trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

def test_eval_loop(self, a, b):
    config = get_config_with_defaults(
        {"training": {"max_updates": 2, "max_epochs": 2}}
    )
    trainer = get_mmf_trainer(config=config)
    combined_report, meter = trainer.evaluation_loop("val")
    self.assertAlmostEqual(combined_report["losses"]["loss"], 493377.5312)
    self.assertAlmostEqual(combined_report["logits"].item(), -0.2379742, 6)

def test_tensorboard_logging_parity(
    self,
    summary_writer,
    mmf,
    lightning,
    logistics,
    logistics_logs,
    report_logs,
    trainer_logs,
    mkdirs,
):
    # mmf trainer
    mmf_trainer = get_mmf_trainer(
        max_updates=8,
        batch_size=2,
        max_epochs=None,
        log_interval=3,
        tensorboard=True,
    )

    def _add_scalars_mmf(log_dict, iteration):
        self.mmf_tensorboard_logs.append({iteration: log_dict})

    mmf_trainer.load_metrics()
    logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
    logistics_callback.snapshot_timer = MagicMock(return_value=None)
    logistics_callback.train_timer = Timer()
    logistics_callback.tb_writer.add_scalars = _add_scalars_mmf
    mmf_trainer.logistics_callback = logistics_callback
    mmf_trainer.callbacks = [logistics_callback]
    mmf_trainer.early_stop_callback = MagicMock(return_value=None)
    mmf_trainer.on_update_end = logistics_callback.on_update_end
    mmf_trainer.training_loop()

    # lightning trainer
    trainer = get_lightning_trainer(
        max_steps=8,
        batch_size=2,
        prepare_trainer=False,
        log_every_n_steps=3,
        val_check_interval=9,
        tensorboard=True,
    )

    def _add_scalars_lightning(log_dict, iteration):
        self.lightning_tensorboard_logs.append({iteration: log_dict})

    def _on_fit_start_callback():
        trainer.tb_writer.add_scalars = _add_scalars_lightning

    callback = LightningLoopCallback(trainer)
    run_lightning_trainer_with_callback(
        trainer, callback, on_fit_start_callback=_on_fit_start_callback
    )
    self.assertEqual(
        len(self.mmf_tensorboard_logs), len(self.lightning_tensorboard_logs)
    )
    for mmf, lightning in zip(
        self.mmf_tensorboard_logs, self.lightning_tensorboard_logs
    ):
        self.assertDictEqual(mmf, lightning)

def test_update_frequency_correct_final_iteration(self, a):
    config = self._get_config(max_updates=2, max_epochs=None, update_frequency=2)
    trainer = get_mmf_trainer(config=config)
    trainer.load_datasets()
    trainer.training_loop()
    # with update_frequency=2, each update consumes two iterations,
    # so 2 updates should end at iteration 4
    self.assertEqual(trainer.max_updates, 2)
    self.assertEqual(trainer.current_iteration, 4)

def test_updates(self, a):
    config = self._get_config(max_updates=2, max_epochs=None)
    trainer = get_mmf_trainer(config=config)
    max_updates = trainer._calculate_max_updates()
    self.assertEqual(max_updates, 2)

    self.check_values(trainer, 0, 0, 0)
    trainer.training_loop()
    self.check_values(trainer, 2, 1, 2)

def test_fractional_epoch(self, a):
    # a fractional max_epochs is converted into a whole number of updates
    config = self._get_config(max_updates=None, max_epochs=0.04)
    trainer = get_mmf_trainer(config=config)
    max_updates = trainer._calculate_max_updates()
    self.assertEqual(max_updates, 4)

    self.check_values(trainer, 0, 0, 0)
    trainer.training_loop()
    self.check_values(trainer, 4, 1, 4)

def test_loss_computation_parity_with_mmf_trainer(self):
    # compute mmf_trainer training losses
    def _on_update_end(report, meter, should_log):
        self.mmf_losses.append(report["losses"]["loss"].item())

    mmf_trainer = get_mmf_trainer(
        max_updates=5, max_epochs=None, on_update_end_fn=_on_update_end
    )
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
    mmf_trainer.training_loop()

    # compute lightning_trainer training losses
    trainer = get_lightning_trainer(callback=self, max_steps=5)
    trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

def test_lr_schedule_compared_to_mmf_is_same(self):
    trainer_config = get_trainer_config()
    mmf_trainer = get_mmf_trainer(
        max_updates=8, max_epochs=None, scheduler_config=trainer_config.scheduler
    )
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
    mmf_trainer.training_loop()

    trainer = get_lightning_trainer(max_steps=8, lr_scheduler=True)
    trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

    # if both trainers followed the same lr schedule, the final
    # parameters should match
    mmf_trainer.model.to(trainer.model.device)
    last_model_param1 = list(mmf_trainer.model.parameters())[-1]
    last_model_param2 = list(trainer.model.parameters())[-1]
    self.assertTrue(torch.allclose(last_model_param1, last_model_param2))

def test_loss_computation_parity_with_mmf_trainer(self):
    # compute mmf_trainer training losses
    def _on_update_end(report, meter, should_log):
        self.mmf_losses.append(report["losses"]["loss"].item())

    config = get_config_with_defaults(
        {"training": {"max_updates": 5, "max_epochs": None}}
    )
    mmf_trainer = get_mmf_trainer(config=config)
    mmf_trainer.on_update_end = _on_update_end
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
    mmf_trainer.training_loop()

    # compute lightning_trainer training losses
    with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
        config = get_config_with_defaults({"trainer": {"params": {"max_steps": 5}}})
        trainer = get_lightning_trainer(config=config)
        trainer.callbacks.append(self)
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

def test_validation_parity(self, summarize_report_fn, test_reporter, sw, mkdirs):
    mmf_trainer = get_mmf_trainer(
        max_updates=8, batch_size=2, max_epochs=None, evaluation_interval=3
    )
    mmf_trainer.load_metrics()
    logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
    logistics_callback.snapshot_timer = Timer()
    logistics_callback.train_timer = Timer()
    mmf_trainer.logistics_callback = logistics_callback
    mmf_trainer.callbacks.append(logistics_callback)
    mmf_trainer.early_stop_callback = MagicMock(return_value=None)
    mmf_trainer.on_validation_end = logistics_callback.on_validation_end
    mmf_trainer.training_loop()

    # with evaluation_interval=3 over 8 updates, validation is summarized 3 times
    calls = summarize_report_fn.call_args_list
    self.assertEqual(3, len(calls))
    self.assertEqual(len(self.ground_truths), len(calls))
    self._check_values(calls)

def _train_with_condition(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    update_frequency,
    batch_size,
    on_update_end_fn=None,
):
    torch.random.manual_seed(2)
    config = self._get_config(
        max_updates=max_updates,
        max_epochs=max_epochs,
        update_frequency=update_frequency,
        batch_size=batch_size,
    )
    trainer = get_mmf_trainer(num_data_size=num_train_data, config=config)
    if on_update_end_fn:
        trainer.on_update_end = on_update_end_fn
    trainer.training_loop()
    return trainer

def test_lr_schedule_compared_to_mmf_is_same(self):
    config = get_config_with_defaults(
        {"training": {"max_updates": 8, "max_epochs": None, "lr_scheduler": True}}
    )
    mmf_trainer = get_mmf_trainer(config=config)
    mmf_trainer.lr_scheduler_callback = LRSchedulerCallback(config, mmf_trainer)
    mmf_trainer.callbacks.append(mmf_trainer.lr_scheduler_callback)
    mmf_trainer.on_update_end = mmf_trainer.lr_scheduler_callback.on_update_end
    mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
    mmf_trainer.training_loop()

    with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
        config = self._get_config(max_steps=8, lr_scheduler=True)
        trainer = get_lightning_trainer(config=config)
        trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

        mmf_trainer.model.to(trainer.model.device)
        last_model_param1 = list(mmf_trainer.model.parameters())[-1]
        last_model_param2 = list(trainer.model.parameters())[-1]
        self.assertTrue(torch.allclose(last_model_param1, last_model_param2))