def _train_with_condition(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    update_frequency,
    batch_size,
    on_update_end_fn=None,
):
    """Run a complete training loop on a freshly seeded ``SimpleModel``.

    Args:
        num_train_data: Number of training samples given to the mock trainer.
        max_updates: Update limit forwarded to ``TrainerTrainingLoopMock``.
        max_epochs: Epoch limit forwarded to ``TrainerTrainingLoopMock``.
        update_frequency: Forwarded as the trainer's ``update_frequency``.
        batch_size: Forwarded as the trainer's ``batch_size``.
        on_update_end_fn: Optional callable; when truthy it replaces the
            trainer's ``on_update_end`` attribute before training starts.

    Returns:
        The ``TrainerTrainingLoopMock`` instance after ``training_loop()``
        has finished.
    """
    # Seed before the model is created so every call starts from the same
    # initial weights.
    torch.random.manual_seed(2)
    net = SimpleModel({"in_dim": 1})
    net.build()
    sgd = torch.optim.SGD(net.parameters(), lr=0.01)

    mock_trainer = TrainerTrainingLoopMock(
        num_train_data,
        max_updates,
        max_epochs,
        optimizer=sgd,
        update_frequency=update_frequency,
        batch_size=batch_size,
    )
    mock_trainer.load_datasets()
    if on_update_end_fn:
        mock_trainer.on_update_end = on_update_end_fn

    # Move the model onto whatever device the trainer selected, then attach it.
    net.to(mock_trainer.device)
    mock_trainer.model = net
    mock_trainer.training_loop()
    return mock_trainer
    def test_batch_size_per_device(self):
        """Check how global and per-device batch sizes reach the train loader.

        With the world size patched to 2, a global ``batch_size`` of 4 is
        expected to become 2 per device, while ``batch_size_per_device=4`` is
        expected to stay 4. The last trainer built is then trained and its
        progress values are checked.
        """

        def build_trainer(**batch_kwargs):
            # Create a mock trainer (100 samples, 2 max updates, no epoch
            # limit), register its config, store the batch size resolved by
            # get_batch_size(), and load the datasets.
            built = TrainerTrainingLoopMock(100, 2, None, **batch_kwargs)
            registry.register("config", built.config)
            built.config.training.batch_size = get_batch_size()
            built.load_datasets()
            return built

        # Patch get_world_size in mmf.utils.general rather than in
        # mmf.utils.distributed, because the general one is what gets used.
        with patch("mmf.utils.general.get_world_size", return_value=2):
            # Global batch size 4 over world size 2 -> 4 // 2 = 2 per device.
            trainer = build_trainer(batch_size=4)
            self.assertEqual(trainer.train_loader.current_loader.batch_size, 2)

            # A per-device batch size is not divided by the world size.
            trainer = build_trainer(batch_size_per_device=4)
            self.assertEqual(trainer.train_loader.current_loader.batch_size, 4)

        self.assertEqual(trainer._calculate_max_updates(), 2)

        self.check_values(trainer, 0, 0, 0)
        trainer.training_loop()
        self.check_values(trainer, 2, 1, 2)
    def test_updates(self):
        """A max_updates limit of 2 over 100 samples should stop after 2 updates."""
        limited = TrainerTrainingLoopMock(100, 2, None)
        limited.load_datasets()
        self.assertEqual(limited._calculate_max_updates(), 2)

        # check_values compares trainer progress state; presumably the
        # iteration / epoch / update counters -- confirm against its definition.
        self.check_values(limited, 0, 0, 0)
        limited.training_loop()
        self.check_values(limited, 2, 1, 2)
    def test_fractional_epoch(self):
        """A fractional max_epochs of 0.04 over 100 samples should give 4 updates."""
        fractional = TrainerTrainingLoopMock(100, None, 0.04)
        fractional.load_datasets()
        self.assertEqual(fractional._calculate_max_updates(), 4)

        # Progress values before and after the loop (see check_values for
        # which trainer attributes the three numbers refer to).
        self.check_values(fractional, 0, 0, 0)
        fractional.training_loop()
        self.check_values(fractional, 4, 1, 4)
    def test_exit_on_nan_losses(self, a=None):
        """Training must abort with ``RuntimeError`` when the loss is NaN.

        Builds a trainer around ``SimpleNaNLossModel`` and expects
        ``training_loop()`` to raise ``RuntimeError``.

        Args:
            a: Mock object injected by a ``patch`` decorator when one is
                applied; the body does not use it. NOTE(review): no decorator
                is visible in this chunk. Without one, unittest would call this
                method with no extra argument and fail with a TypeError, hence
                the ``None`` default.
        """
        config = self._get_config(max_updates=2, max_epochs=None, batch_size=4)
        trainer = TrainerTrainingLoopMock(config=config)
        add_model(trainer, SimpleNaNLossModel({"in_dim": 1}))
        add_optimizer(trainer, config)
        registry.register("config", trainer.config)
        batch_size = get_batch_size()
        trainer.config.training.batch_size = batch_size
        trainer.load_datasets()

        # assertRaises replaces the manual try/except + flag pattern and gives
        # a clear failure message if no RuntimeError is raised.
        with self.assertRaises(RuntimeError):
            trainer.training_loop()
 def test_update_frequency_correct_final_iteration(self):
     """With update_frequency=2 and 2 max updates, training should end at iteration 4."""
     freq_trainer = TrainerTrainingLoopMock(100, 2, None, update_frequency=2)
     freq_trainer.load_datasets()
     freq_trainer.training_loop()

     # The update limit is unchanged, but each update is expected to span two
     # iterations, so the final iteration count is 2 * 2.
     self.assertEqual(freq_trainer.max_updates, 2)
     self.assertEqual(freq_trainer.current_iteration, 4)