示例#1
0
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
    seed=2,
):
    """Build a ``TrainerTrainingLoopMock`` around a freshly built ``SimpleModel``.

    The torch RNG is seeded before the model is constructed so repeated calls
    with the same ``seed`` build identically initialised models.

    Args:
        model_size: ``in_dim`` passed to ``SimpleModel``.
        num_data_size: first positional argument of ``TrainerTrainingLoopMock``
            (presumably the number of synthetic samples — confirm).
        max_updates: forwarded positionally to the mock trainer.
        max_epochs: forwarded positionally to the mock trainer.
        on_update_end_fn: optional callback forwarded to the mock trainer.
        fp16: forwarded to the mock trainer.
        scheduler_config: forwarded to the mock trainer.
        grad_clipping_config: forwarded to the mock trainer.
        seed: value passed to ``torch.random.manual_seed``. Defaults to 2, the
            value that was previously hard-coded.

    Returns:
        The trainer with datasets loaded. The model has had ``build()`` and
        ``train()`` called, has been moved to ``trainer.device``, and is
        attached as ``trainer.model``.
    """
    torch.random.manual_seed(seed)
    model = SimpleModel({"in_dim": model_size})
    model.build()
    model.train()
    trainer_config = get_trainer_config()
    optimizer = build_optimizer(model, trainer_config)
    trainer = TrainerTrainingLoopMock(
        num_data_size,
        max_updates,
        max_epochs,
        config=trainer_config,
        optimizer=optimizer,
        on_update_end_fn=on_update_end_fn,
        fp16=fp16,
        scheduler_config=scheduler_config,
        grad_clipping_config=grad_clipping_config,
    )
    trainer.load_datasets()
    model.to(trainer.device)
    trainer.model = model
    return trainer
 def _train_with_condition(
     self,
     num_train_data,
     max_updates,
     max_epochs,
     update_frequency,
     batch_size,
     on_update_end_fn=None,
 ):
     """Run a full training loop on a seeded ``SimpleModel`` and return the trainer.

     Args:
         num_train_data: first positional argument of ``TrainerTrainingLoopMock``.
         max_updates: forwarded positionally to the mock trainer.
         max_epochs: forwarded positionally to the mock trainer.
         update_frequency: forwarded as ``update_frequency``.
         batch_size: forwarded as ``batch_size``.
         on_update_end_fn: if truthy, assigned to ``trainer.on_update_end``
             after datasets are loaded.

     Returns:
         The trainer after ``training_loop()`` has completed.
     """
     torch.random.manual_seed(2)
     net = SimpleModel({"in_dim": 1})
     net.build()
     sgd = torch.optim.SGD(net.parameters(), lr=0.01)

     loop_kwargs = {
         "optimizer": sgd,
         "update_frequency": update_frequency,
         "batch_size": batch_size,
     }
     loop = TrainerTrainingLoopMock(
         num_train_data, max_updates, max_epochs, **loop_kwargs
     )
     loop.load_datasets()

     # Only replace the trainer's hook when a callback was actually given.
     if on_update_end_fn:
         loop.on_update_end = on_update_end_fn

     net.to(loop.device)
     loop.model = net
     loop.training_loop()
     return loop
    def test_updates(self):
        """With max_updates=2 and no max_epochs, the loop stops after 2 updates."""
        expected_updates = 2
        trainer = TrainerTrainingLoopMock(100, expected_updates, None)
        trainer.load_datasets()
        self.assertEqual(trainer._calculate_max_updates(), expected_updates)

        # check_values presumably compares (iteration, epoch, updates)
        # counters — confirm against its definition.
        self.check_values(trainer, 0, 0, 0)
        trainer.training_loop()
        self.check_values(trainer, 2, 1, 2)
    def test_fractional_epoch(self):
        """A fractional max_epochs (0.04 of 100 samples) resolves to 4 updates."""
        trainer = TrainerTrainingLoopMock(100, None, 0.04)
        trainer.load_datasets()
        self.assertEqual(trainer._calculate_max_updates(), 4)

        # Counters start at zero and, after the loop, reflect the 4 updates.
        self.check_values(trainer, 0, 0, 0)
        trainer.training_loop()
        self.check_values(trainer, 4, 1, 4)
示例#5
0
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
    evaluation_interval=4,
    log_interval=1,
    batch_size=1,
    tensorboard=False,
    seed=2,
):
    """Build a ``TrainerTrainingLoopMock`` around a freshly built ``SimpleModel``.

    The torch RNG is seeded before the model is constructed so repeated calls
    with the same ``seed`` build identically initialised models.

    Args:
        model_size: ``in_dim`` passed to ``SimpleModel``.
        num_data_size: first positional argument of ``TrainerTrainingLoopMock``
            (presumably the number of synthetic samples — confirm).
        max_updates: forwarded positionally to the mock trainer.
        max_epochs: forwarded positionally to the mock trainer.
        on_update_end_fn: optional callback forwarded to the mock trainer.
        fp16: forwarded to the mock trainer.
        scheduler_config: forwarded to the mock trainer.
        grad_clipping_config: forwarded to the mock trainer.
        evaluation_interval: written to
            ``trainer_config.training.evaluation_interval``.
        log_interval: written to ``trainer_config.training.log_interval``.
        batch_size: forwarded to the mock trainer.
        tensorboard: forwarded to the mock trainer.
        seed: value passed to ``torch.random.manual_seed``. Defaults to 2, the
            value that was previously hard-coded.

    Returns:
        The trainer with datasets loaded. The model has had ``build()`` and
        ``train()`` called, has been moved to ``trainer.device``, and is
        attached as ``trainer.model``.
    """
    torch.random.manual_seed(seed)
    model = SimpleModel({"in_dim": model_size})
    model.build()
    model.train()
    trainer_config = get_trainer_config()
    # Override the intervals before the optimizer/trainer read the config.
    trainer_config.training.evaluation_interval = evaluation_interval
    trainer_config.training.log_interval = log_interval
    optimizer = build_optimizer(model, trainer_config)
    trainer = TrainerTrainingLoopMock(
        num_data_size,
        max_updates,
        max_epochs,
        config=trainer_config,
        optimizer=optimizer,
        fp16=fp16,
        on_update_end_fn=on_update_end_fn,
        scheduler_config=scheduler_config,
        grad_clipping_config=grad_clipping_config,
        batch_size=batch_size,
        tensorboard=tensorboard,
    )
    trainer.load_datasets()
    model.to(trainer.device)
    trainer.model = model
    return trainer
    def test_exit_on_nan_losses(self, a):
        """Training with ``SimpleNaNLossModel`` must raise ``RuntimeError``.

        ``a`` is not used in the body; presumably it is injected by a
        ``mock.patch`` decorator outside this view — confirm.
        """
        config = self._get_config(max_updates=2, max_epochs=None, batch_size=4)
        trainer = TrainerTrainingLoopMock(config=config)
        add_model(trainer, SimpleNaNLossModel({"in_dim": 1}))
        add_optimizer(trainer, config)
        registry.register("config", trainer.config)
        batch_size = get_batch_size()
        trainer.config.training.batch_size = batch_size
        trainer.load_datasets()

        # assertRaises fails the test if training_loop() returns normally and
        # lets any other exception type propagate. This matches the previous
        # try/except RuntimeError + flag check, with a clearer failure message.
        with self.assertRaises(RuntimeError):
            trainer.training_loop()
示例#7
0
def get_mmf_trainer(
    config=None,
    model_size=1,
    num_data_size=100,
    load_model_from_config=False,
    seed=2,
):
    """Build a config-driven ``TrainerTrainingLoopMock`` ready for training.

    Args:
        config: passed to the mock trainer and to ``add_optimizer``.
        model_size: ``in_dim`` for the ``SimpleModel`` used when the model is
            not loaded from config.
        num_data_size: first positional argument of ``TrainerTrainingLoopMock``.
        load_model_from_config: if true, call ``trainer.load_model()``;
            otherwise attach a new ``SimpleModel`` via ``add_model``.
        seed: value passed to ``torch.random.manual_seed`` before anything
            else is built.

    Returns:
        The trainer with a model, an optimizer and loaded datasets.
    """
    torch.random.manual_seed(seed)
    trainer = TrainerTrainingLoopMock(num_data_size, config=config)

    if load_model_from_config:
        trainer.load_model()
    else:
        add_model(trainer, SimpleModel({"in_dim": model_size}))

    add_optimizer(trainer, config)
    trainer.load_datasets()
    return trainer
    def test_batch_size_per_device(self, a):
        """Check how global vs per-device batch sizes reach the train loader.

        Under a patched world size of 2, ``batch_size=4`` is split to 2 per
        device, while ``batch_size_per_device=4`` stays 4. The second trainer
        is then trained outside the patch.
        """

        def build_trainer(cfg):
            # Shared setup sequence used for both configurations below.
            built = TrainerTrainingLoopMock(config=cfg)
            add_model(built, SimpleModel({"in_dim": 1}))
            add_optimizer(built, cfg)
            registry.register("config", built.config)
            built.config.training.batch_size = get_batch_size()
            built.load_datasets()
            return built

        # Patch mmf.utils.general's get_world_size, not mmf.utils.distributed's,
        # because the general one is the lookup that actually gets used.
        with patch("mmf.utils.general.get_world_size", return_value=2):
            # A global batch size of 4 across a world size of 2 should be
            # 4 // 2 = 2 per device on the train loader.
            trainer = build_trainer(
                self._get_config(max_updates=2, max_epochs=None, batch_size=4)
            )
            self.assertEqual(trainer.train_loader.current_loader.batch_size, 2)

            # batch_size_per_device is already per device, so it is unchanged.
            trainer = build_trainer(
                self._get_config(
                    max_updates=2, max_epochs=None, batch_size_per_device=4
                )
            )
            self.assertEqual(trainer.train_loader.current_loader.batch_size, 4)

        # The patch has exited here; the remaining checks use the second
        # trainer (batch_size_per_device=4, max_updates=2).
        self.assertEqual(trainer._calculate_max_updates(), 2)

        self.check_values(trainer, 0, 0, 0)
        trainer.training_loop()
        self.check_values(trainer, 2, 1, 2)
 def test_update_frequency_correct_final_iteration(self):
     """With update_frequency=2 and 2 max updates, the loop ends at iteration 4."""
     expected_updates, update_freq = 2, 2
     trainer = TrainerTrainingLoopMock(
         100, expected_updates, None, update_frequency=update_freq
     )
     trainer.load_datasets()
     trainer.training_loop()

     self.assertEqual(trainer.max_updates, expected_updates)
     # Presumably every update spans update_frequency iterations (2 * 2 = 4).
     self.assertEqual(trainer.current_iteration, expected_updates * update_freq)