def _train_with_condition(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    update_frequency,
    batch_size,
    on_update_end_fn=None,
):
    """Build a ``SimpleModel``, run the mock trainer's training loop, return the trainer.

    Args:
        num_train_data: First positional argument given to
            ``TrainerTrainingLoopMock``.
        max_updates: Second positional argument given to
            ``TrainerTrainingLoopMock``.
        max_epochs: Third positional argument given to
            ``TrainerTrainingLoopMock``.
        update_frequency: Forwarded to the trainer as ``update_frequency``.
        batch_size: Forwarded to the trainer as ``batch_size``.
        on_update_end_fn: Optional callable. When truthy it replaces the
            trainer's ``on_update_end`` attribute before training starts.

    Returns:
        The ``TrainerTrainingLoopMock`` instance after ``training_loop()``
        has been called on it.
    """
    # The RNG is seeded before the model is constructed.
    torch.random.manual_seed(2)

    net = SimpleModel({"in_dim": 1})
    net.build()

    # Keyword arguments for the mock trainer; the optimizer is plain SGD
    # over the freshly built model's parameters.
    loop_kwargs = {
        "optimizer": torch.optim.SGD(net.parameters(), lr=0.01),
        "update_frequency": update_frequency,
        "batch_size": batch_size,
    }
    mock_trainer = TrainerTrainingLoopMock(
        num_train_data, max_updates, max_epochs, **loop_kwargs
    )
    mock_trainer.load_datasets()

    # Install the caller's hook only if one was supplied.
    if on_update_end_fn:
        mock_trainer.on_update_end = on_update_end_fn

    # Put the model on the trainer's device, then hand it to the trainer.
    net.to(mock_trainer.device)
    mock_trainer.model = net

    mock_trainer.training_loop()
    return mock_trainer
Beispiel #2
0
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
):
    """Assemble a ``TrainerTrainingLoopMock`` wired to a fresh ``SimpleModel``.

    The trainer is returned without its training loop having been started
    by this function.

    Args:
        model_size: Value used for the model's ``"in_dim"`` config entry.
        num_data_size: First positional argument given to the trainer.
        max_updates: Second positional argument given to the trainer.
        max_epochs: Third positional argument given to the trainer.
        on_update_end_fn: Forwarded to the trainer as ``on_update_end_fn``.
        fp16: Forwarded to the trainer as ``fp16``.
        scheduler_config: Forwarded to the trainer as ``scheduler_config``.
        grad_clipping_config: Forwarded to the trainer as
            ``grad_clipping_config``.

    Returns:
        The constructed trainer, with datasets loaded and ``model`` set.
    """
    # The RNG is seeded before the model is constructed.
    torch.random.manual_seed(2)

    net = SimpleModel({"in_dim": model_size})
    net.build()
    net.train()

    cfg = get_trainer_config()

    # Keyword arguments for the mock trainer, in the order they are passed.
    trainer_kwargs = {
        "config": cfg,
        "optimizer": build_optimizer(net, cfg),
        "on_update_end_fn": on_update_end_fn,
        "fp16": fp16,
        "scheduler_config": scheduler_config,
        "grad_clipping_config": grad_clipping_config,
    }
    mock_trainer = TrainerTrainingLoopMock(
        num_data_size, max_updates, max_epochs, **trainer_kwargs
    )
    mock_trainer.load_datasets()

    # Put the model on the trainer's device, then hand it to the trainer.
    net.to(mock_trainer.device)
    mock_trainer.model = net
    return mock_trainer
Beispiel #3
0
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
    evaluation_interval=4,
    log_interval=1,
    batch_size=1,
    tensorboard=False,
):
    """Assemble a ``TrainerTrainingLoopMock`` wired to a fresh ``SimpleModel``.

    The trainer is returned without its training loop having been started
    by this function.

    Args:
        model_size: Value used for the model's ``"in_dim"`` config entry.
        num_data_size: First positional argument given to the trainer.
        max_updates: Second positional argument given to the trainer.
        max_epochs: Third positional argument given to the trainer.
        on_update_end_fn: Forwarded to the trainer as ``on_update_end_fn``.
        fp16: Forwarded to the trainer as ``fp16``.
        scheduler_config: Forwarded to the trainer as ``scheduler_config``.
        grad_clipping_config: Forwarded to the trainer as
            ``grad_clipping_config``.
        evaluation_interval: Written to
            ``trainer_config.training.evaluation_interval``.
        log_interval: Written to ``trainer_config.training.log_interval``.
        batch_size: Forwarded to the trainer as ``batch_size``.
        tensorboard: Forwarded to the trainer as ``tensorboard``.

    Returns:
        The constructed trainer, with datasets loaded and ``model`` set.
    """
    # The RNG is seeded before the model is constructed.
    torch.random.manual_seed(2)

    net = SimpleModel({"in_dim": model_size})
    net.build()
    net.train()

    # Override the interval settings before the optimizer and trainer
    # are built from this config.
    cfg = get_trainer_config()
    training_cfg = cfg.training
    training_cfg.evaluation_interval = evaluation_interval
    training_cfg.log_interval = log_interval

    # Keyword arguments for the mock trainer, in the order they are passed.
    trainer_kwargs = {
        "config": cfg,
        "optimizer": build_optimizer(net, cfg),
        "fp16": fp16,
        "on_update_end_fn": on_update_end_fn,
        "scheduler_config": scheduler_config,
        "grad_clipping_config": grad_clipping_config,
        "batch_size": batch_size,
        "tensorboard": tensorboard,
    }
    mock_trainer = TrainerTrainingLoopMock(
        num_data_size, max_updates, max_epochs, **trainer_kwargs
    )
    mock_trainer.load_datasets()

    # Put the model on the trainer's device, then hand it to the trainer.
    net.to(mock_trainer.device)
    mock_trainer.model = net
    return mock_trainer