def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
):
    torch.random.manual_seed(2)
    model = SimpleModel({"in_dim": model_size})
    model.build()
    model.train()
    trainer_config = get_trainer_config()
    optimizer = build_optimizer(model, trainer_config)
    trainer = TrainerTrainingLoopMock(
        num_data_size,
        max_updates,
        max_epochs,
        config=trainer_config,
        optimizer=optimizer,
        on_update_end_fn=on_update_end_fn,
        fp16=fp16,
        scheduler_config=scheduler_config,
        grad_clipping_config=grad_clipping_config,
    )
    trainer.load_datasets()
    model.to(trainer.device)
    trainer.model = model
    return trainer
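# Usage sketch (an illustrative assumption, not part of the original suite):
# the helper wires model, optimizer, and datasets together, so a caller only
# has to drive the training loop. The values below are arbitrary.
def example_get_mmf_trainer_usage():
    trainer = get_mmf_trainer(num_data_size=10, max_updates=2)
    trainer.training_loop()
    return trainer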
def _train_with_condition(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    update_frequency,
    batch_size,
    on_update_end_fn=None,
):
    torch.random.manual_seed(2)
    model = SimpleModel({"in_dim": 1})
    model.build()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    trainer = TrainerTrainingLoopMock(
        num_train_data,
        max_updates,
        max_epochs,
        optimizer=opt,
        update_frequency=update_frequency,
        batch_size=batch_size,
    )
    trainer.load_datasets()
    if on_update_end_fn:
        trainer.on_update_end = on_update_end_fn
    model.to(trainer.device)
    trainer.model = model
    trainer.training_loop()
    return trainer
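# Usage sketch (an assumption, mirroring how condition-driven tests might call
# the helper above from inside a TestCase method):
#
#     trainer = self._train_with_condition(
#         num_train_data=100,
#         max_updates=2,
#         max_epochs=None,
#         update_frequency=2,
#         batch_size=2,
#     )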
def test_updates(self):
    trainer = TrainerTrainingLoopMock(100, 2, None)
    trainer.load_datasets()
    max_updates = trainer._calculate_max_updates()
    self.assertEqual(max_updates, 2)
    self.check_values(trainer, 0, 0, 0)
    trainer.training_loop()
    self.check_values(trainer, 2, 1, 2)
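# check_values is assumed (from the assertions it backs) to compare
# trainer.current_iteration, trainer.current_epoch, and trainer.num_updates,
# in that order: all zero before the loop, then 2 iterations / 1 epoch /
# 2 updates once the max_updates=2 cap is hit.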
def test_fractional_epoch(self):
    trainer = TrainerTrainingLoopMock(100, None, 0.04)
    trainer.load_datasets()
    max_updates = trainer._calculate_max_updates()
    self.assertEqual(max_updates, 4)
    self.check_values(trainer, 0, 0, 0)
    trainer.training_loop()
    self.check_values(trainer, 4, 1, 4)
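# Why 4 updates above (a reasoning sketch, assuming the mock's default batch
# size of 1): one pass over 100 samples is 100 updates, so max_epochs=0.04
# caps training at 100 * 0.04 = 4 updates.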
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
    evaluation_interval=4,
    log_interval=1,
    batch_size=1,
    tensorboard=False,
):
    torch.random.manual_seed(2)
    model = SimpleModel({"in_dim": model_size})
    model.build()
    model.train()
    trainer_config = get_trainer_config()
    trainer_config.training.evaluation_interval = evaluation_interval
    trainer_config.training.log_interval = log_interval
    optimizer = build_optimizer(model, trainer_config)
    trainer = TrainerTrainingLoopMock(
        num_data_size,
        max_updates,
        max_epochs,
        config=trainer_config,
        optimizer=optimizer,
        fp16=fp16,
        on_update_end_fn=on_update_end_fn,
        scheduler_config=scheduler_config,
        grad_clipping_config=grad_clipping_config,
        batch_size=batch_size,
        tensorboard=tensorboard,
    )
    trainer.load_datasets()
    model.to(trainer.device)
    trainer.model = model
    return trainer
def test_exit_on_nan_losses(self, a):
    config = self._get_config(max_updates=2, max_epochs=None, batch_size=4)
    trainer = TrainerTrainingLoopMock(config=config)
    add_model(trainer, SimpleNaNLossModel({"in_dim": 1}))
    add_optimizer(trainer, config)
    registry.register("config", trainer.config)
    batch_size = get_batch_size()
    trainer.config.training.batch_size = batch_size
    trainer.load_datasets()

    exception_raised = False
    try:
        trainer.training_loop()
    except RuntimeError:
        exception_raised = True
    self.assertTrue(exception_raised)
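# Sketch of a NaN-loss stub (an assumption; the suite's real SimpleNaNLossModel
# lives in its test utilities and may differ). It assumes SimpleModel's forward
# returns a dict carrying a "losses" mapping, which the training loop inspects
# for NaN before raising RuntimeError.
class ExampleNaNLossModel(SimpleModel):
    def forward(self, sample_list):
        report = super().forward(sample_list)
        # Force a NaN loss so the loop's early-exit check fires.
        report["losses"] = {"example_loss": torch.tensor(float("nan"))}
        return report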
def get_mmf_trainer(
    config=None, model_size=1, num_data_size=100, load_model_from_config=False, seed=2
):
    torch.random.manual_seed(seed)
    trainer = TrainerTrainingLoopMock(num_data_size, config=config)
    if not load_model_from_config:
        add_model(trainer, SimpleModel({"in_dim": model_size}))
    else:
        trainer.load_model()
    add_optimizer(trainer, config)
    trainer.load_datasets()
    return trainer
def test_batch_size_per_device(self, a):
    # Need to patch mmf.utils.general's world size, not mmf.utils.distributed's,
    # as the former is what gets used here.
    with patch("mmf.utils.general.get_world_size", return_value=2):
        config = self._get_config(max_updates=2, max_epochs=None, batch_size=4)
        trainer = TrainerTrainingLoopMock(config=config)
        add_model(trainer, SimpleModel({"in_dim": 1}))
        add_optimizer(trainer, config)
        registry.register("config", trainer.config)
        batch_size = get_batch_size()
        trainer.config.training.batch_size = batch_size
        trainer.load_datasets()
        # The train loader uses the per-device batch size: for a global batch
        # size of 4 with world size 2, each device should be given 4 // 2 = 2.
        self.assertEqual(trainer.train_loader.current_loader.batch_size, 2)

        # batch_size_per_device is already per device, so it should stay the same.
        config = self._get_config(
            max_updates=2, max_epochs=None, batch_size_per_device=4
        )
        trainer = TrainerTrainingLoopMock(config=config)
        add_model(trainer, SimpleModel({"in_dim": 1}))
        add_optimizer(trainer, config)
        registry.register("config", trainer.config)
        batch_size = get_batch_size()
        trainer.config.training.batch_size = batch_size
        trainer.load_datasets()
        self.assertEqual(trainer.train_loader.current_loader.batch_size, 4)

    max_updates = trainer._calculate_max_updates()
    self.assertEqual(max_updates, 2)
    self.check_values(trainer, 0, 0, 0)
    trainer.training_loop()
    self.check_values(trainer, 2, 1, 2)
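# Sketch of the per-device arithmetic get_batch_size is expected to perform
# (an assumption about its internals, consistent with the assertions above):
# a global training.batch_size is split evenly across devices, while
# training.batch_size_per_device is used as-is.
def example_batch_size_per_device(training_config, world_size):
    per_device = training_config.get("batch_size_per_device")
    if per_device is not None:
        # Already per device; world size does not change it.
        return per_device
    # Global batch size: divide evenly across devices.
    return training_config["batch_size"] // world_size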
def test_update_frequency_correct_final_iteration(self):
    trainer = TrainerTrainingLoopMock(100, 2, None, update_frequency=2)
    trainer.load_datasets()
    trainer.training_loop()
    self.assertEqual(trainer.max_updates, 2)
    self.assertEqual(trainer.current_iteration, 4)
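# Iteration bookkeeping behind the final assertion (sketch): with
# update_frequency=2, each optimizer update accumulates gradients over two
# forward/backward iterations, so max_updates=2 finishes at
# current_iteration = 2 * 2 = 4.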