def __init__(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    device="cuda",
    fp16_model=False,
):
    # The trainer itself always runs with fp16 enabled; `fp16_model` only
    # controls whether the model actively asserts that fp16 is in effect.
    config = get_config_with_defaults(
        {
            "training": {
                "max_updates": max_updates,
                "max_epochs": max_epochs,
                "evaluation_interval": 10000,
                "fp16": True,
            },
            "run_type": "train",
        }
    )
    super().__init__(num_train_data, config=config)
    # Store the requested device so the `.to()` call below honors it.
    self.device = device

    if fp16_model:
        assert (
            torch.cuda.is_available()
        ), "MMFTrainerMock fp16 requires cuda enabled"
        model = SimpleModelWithFp16Assert({"in_dim": 1})
        model.build()
        model = model.cuda()
    else:
        model = SimpleModel({"in_dim": 1})
        model.build()

    model.train()
    model.to(self.device)
    self.model = model
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3)
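# Usage sketch (illustrative, not taken from the source): assumes the
# __init__ above belongs to MMFTrainerMock (the class named in its assert
# message) and that the trainer bases expose load_datasets() and
# training_loop(); both entry points are assumptions here.
def _example_fp16_smoke_test():
    # Cap training at 2 updates over 100 synthetic samples on the default
    # "cuda" device, using the model that asserts fp16 is active.
    trainer = MMFTrainerMock(
        num_train_data=100, max_updates=2, max_epochs=None, fp16_model=True
    )
    trainer.load_datasets()
    trainer.training_loop()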
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
):
    torch.random.manual_seed(2)
    model = SimpleModel(model_size)
    model.train()

    trainer_config = get_trainer_config()
    optimizer = build_optimizer(model, trainer_config)

    trainer = TrainerTrainingLoopMock(
        num_data_size,
        max_updates,
        max_epochs,
        config=trainer_config,
        optimizer=optimizer,
        on_update_end_fn=on_update_end_fn,
        fp16=fp16,
        scheduler_config=scheduler_config,
        grad_clipping_config=grad_clipping_config,
    )
    model.to(trainer.device)
    trainer.model = model
    return trainer
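# Usage sketch (illustrative, not taken from the source): exercises the
# hooks in the signature above. The grad_clipping_config keys are an
# assumption modeled on MMF's training.clip_gradients options, and the
# callback contract (arguments, once-per-update invocation) is likewise
# assumed rather than shown in the source.
def _example_update_callback_and_clipping():
    updates = []

    def on_update_end(*args, **kwargs):
        # Accept any signature, since the callback contract isn't shown.
        updates.append(1)

    trainer = get_mmf_trainer(
        max_updates=5,
        on_update_end_fn=on_update_end,
        grad_clipping_config={"clip_gradients": True, "max_grad_l2_norm": 1.0},
    )
    trainer.training_loop()
    # Holds if the mock fires the callback once per update.
    assert len(updates) == 5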
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
    evaluation_interval=4,
    log_interval=1,
    batch_size=1,
    tensorboard=False,
):
    # Fix the seed so test runs are deterministic.
    torch.random.manual_seed(2)
    model = SimpleModel({"in_dim": model_size})
    model.build()
    model.train()

    # Override the evaluation/logging cadence on the default trainer config.
    trainer_config = get_trainer_config()
    trainer_config.training.evaluation_interval = evaluation_interval
    trainer_config.training.log_interval = log_interval
    optimizer = build_optimizer(model, trainer_config)

    trainer = TrainerTrainingLoopMock(
        num_data_size,
        max_updates,
        max_epochs,
        config=trainer_config,
        optimizer=optimizer,
        fp16=fp16,
        on_update_end_fn=on_update_end_fn,
        scheduler_config=scheduler_config,
        grad_clipping_config=grad_clipping_config,
        batch_size=batch_size,
        tensorboard=tensorboard,
    )
    # Unlike the variant above, this helper also loads datasets itself.
    trainer.load_datasets()
    model.to(trainer.device)
    trainer.model = model
    return trainer
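# Usage sketch (illustrative, not taken from the source): drives the fuller
# helper above end to end. training_loop() as the entry point is an
# assumption based on the trainer mock's base classes; no load_datasets()
# call is needed because the helper performs it internally.
def _example_eval_cadence():
    trainer = get_mmf_trainer(
        model_size=1,
        num_data_size=100,
        max_updates=8,
        evaluation_interval=4,  # evaluate twice across the 8 updates
        log_interval=2,
        batch_size=2,
    )
    trainer.training_loop()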