Example #1
def _train_with_condition(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    update_frequency,
    batch_size,
    on_update_end_fn=None,
):
    # Fix the seed so the tiny model trains deterministically across runs.
    torch.random.manual_seed(2)
    model = SimpleModel({"in_dim": 1})
    model.build()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    # The mock trainer runs MMF's training loop under the given stopping
    # conditions (max_updates / max_epochs) and batching settings.
    trainer = TrainerTrainingLoopMock(
        num_train_data,
        max_updates,
        max_epochs,
        optimizer=opt,
        update_frequency=update_frequency,
        batch_size=batch_size,
        on_update_end_fn=on_update_end_fn,
    )
    model.to(trainer.device)
    trainer.model = model
    trainer.training_loop()
    return trainer
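
A hedged usage sketch follows. It assumes this helper sits on a unittest.TestCase subclass and that the mock trainer keeps a num_updates counter; both are assumptions not shown above, and the parameter values are illustrative.

    # Hypothetical call site inside a test method.
    # With 100 samples, batch_size=2, and update_frequency=1, one epoch is
    # 50 updates, so max_updates=2 should stop training after 2 updates.
    trainer = self._train_with_condition(
        num_train_data=100,
        max_updates=2,
        max_epochs=None,
        update_frequency=1,
        batch_size=2,
    )
    # num_updates is assumed to exist on the mock trainer.
    self.assertEqual(trainer.num_updates, 2)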
Example #2
def __init__(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    device="cuda",
    fp16_model=False,
):
    # The training config always enables fp16; fp16_model only controls
    # whether the model itself asserts that it receives fp16 tensors.
    config = get_config_with_defaults({
        "training": {
            "max_updates": max_updates,
            "max_epochs": max_epochs,
            "evaluation_interval": 10000,
            "fp16": True,
        },
        "run_type": "train",
    })
    super().__init__(num_train_data, config=config)
    if fp16_model:
        # The fp16 assertion model only makes sense on GPU.
        assert torch.cuda.is_available(), "MMFTrainerMock fp16 requires cuda enabled"
        model = SimpleModelWithFp16Assert({"in_dim": 1})
        model.build()
        model = model.cuda()
    else:
        model = SimpleModel({"in_dim": 1})
        model.build()
        model.train()
        model.to(self.device)

    self.model = model
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3)
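
A minimal sketch of driving this mock, assuming CUDA is available and that the class inherits load_datasets() and training_loop() from the trainer base; both calls are assumptions borrowed from Examples #4 and #1 in this listing.

    # Hypothetical: construct the fp16 mock and run its training loop on GPU.
    trainer = MMFTrainerMock(
        num_train_data=100,
        max_updates=2,
        max_epochs=None,
        fp16_model=True,
    )
    trainer.load_datasets()   # assumed available, as in Example #4
    trainer.training_loop()   # assumed available, as in Example #1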
Example #3
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
):
    # Fix the seed and build the model the same way the other examples do.
    torch.random.manual_seed(2)
    model = SimpleModel({"in_dim": model_size})
    model.build()
    model.train()
    trainer_config = get_trainer_config()
    optimizer = build_optimizer(model, trainer_config)
    trainer = TrainerTrainingLoopMock(
        num_data_size,
        max_updates,
        max_epochs,
        config=trainer_config,
        optimizer=optimizer,
        on_update_end_fn=on_update_end_fn,
        fp16=fp16,
        scheduler_config=scheduler_config,
        grad_clipping_config=grad_clipping_config,
    )
    model.to(trainer.device)
    trainer.model = model
    return trainer
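
A hedged call sketch: it assumes the mock merges grad_clipping_config into MMF's training config, and the key names mirror MMF's clip_gradients / max_grad_l2_norm training defaults; the exact shape the mock expects is an assumption here.

    # Hypothetical: exercise the gradient-clipping path; values are illustrative.
    grad_clipping_config = {"clip_gradients": True, "max_grad_l2_norm": 1.0}
    trainer = get_mmf_trainer(
        num_data_size=100,
        max_updates=2,
        grad_clipping_config=grad_clipping_config,
    )
    trainer.training_loop()  # assumed available, as in Example #1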
Example #4
def get_mmf_trainer(
    model_size=1,
    num_data_size=100,
    max_updates=5,
    max_epochs=None,
    on_update_end_fn=None,
    fp16=False,
    scheduler_config=None,
    grad_clipping_config=None,
    evaluation_interval=4,
    log_interval=1,
    batch_size=1,
    tensorboard=False,
):
    torch.random.manual_seed(2)
    model = SimpleModel({"in_dim": model_size})
    model.build()
    model.train()
    # Override the logging and evaluation cadence on the default trainer config.
    trainer_config = get_trainer_config()
    trainer_config.training.evaluation_interval = evaluation_interval
    trainer_config.training.log_interval = log_interval
    optimizer = build_optimizer(model, trainer_config)
    trainer = TrainerTrainingLoopMock(
        num_data_size,
        max_updates,
        max_epochs,
        config=trainer_config,
        optimizer=optimizer,
        fp16=fp16,
        on_update_end_fn=on_update_end_fn,
        scheduler_config=scheduler_config,
        grad_clipping_config=grad_clipping_config,
        batch_size=batch_size,
        tensorboard=tensorboard,
    )
    # Prepare datasets, then place the model on the trainer's device.
    trainer.load_datasets()
    model.to(trainer.device)
    trainer.model = model
    return trainer
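
Finally, an illustrative call of this extended helper; the training_loop() call is assumed from Example #1, and all parameter values are illustrative rather than taken from any real test.

    # Hypothetical: log every update, evaluate every 4 updates, write tensorboard logs.
    trainer = get_mmf_trainer(
        model_size=1,
        num_data_size=8,
        max_updates=4,
        evaluation_interval=4,
        log_interval=1,
        batch_size=2,
        tensorboard=True,
    )
    trainer.training_loop()  # assumed available, as in Example #1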