예제 #1
0
 def __init__(self, config: Config, model: torch.nn.Module):
     """Build the optimizer, LR scheduler and sparsifier for training.

     Args:
         config: Trainer config; ``optimizer``, ``scheduler`` and
             ``sparsifier`` are read from it and ``config`` itself is
             stored on ``self.config``.
         model: Model passed to ``create_optimizer`` and
             ``precision.initialize``.
     """
     optimizer: torch.optim.Optimizer = create_optimizer(
         config.optimizer, model)
     # ``torch.optim.lr_scheduler`` is a module, not a type, so it cannot
     # annotate this attribute. ``Scheduler()`` is the fallback used below.
     # NOTE(review): assumes create_scheduler also returns a Scheduler — confirm.
     self.scheduler: Scheduler = (create_scheduler(
         config.scheduler, optimizer) if config.scheduler else Scheduler())
     # Fall back to a default Sparsifier() when none is configured.
     self.sparsifier: Sparsifier = (create_sparsifier(config.sparsifier)
                                    if config.sparsifier else Sparsifier())
     # The scheduler above was built from the optimizer *before* this call.
     # NOTE(review): presumably precision.initialize returns the same (possibly
     # patched) optimizer object, so the scheduler still drives it — confirm.
     # The returned model is only rebound to the local ``model`` and is not
     # stored on self; this assumes initialize adapts the model in place.
     model, self.optimizer = precision.initialize(model, optimizer)
     self.config = config
예제 #2
0
 def __init__(self, config: Config, model: torch.nn.Module):
     """Validate early-stopping settings and build the optimizer, LR
     scheduler and sparsifier for training.

     Args:
         config: Trainer config; ``early_stop_after``, ``do_eval``,
             ``optimizer``, ``scheduler`` and ``sparsifier`` are read from
             it and ``config`` itself is stored on ``self.config``.
         model: Model passed to ``create_optimizer`` and
             ``precision.initialize``.

     Raises:
         AssertionError: if ``early_stop_after > 0`` while ``do_eval`` is
             false. This is an ``assert``, so the check is skipped when
             Python runs with ``-O``.
     """
     if config.early_stop_after > 0:
         assert config.do_eval, "can't do early stopping when not running evaluation"
     optimizer: torch.optim.Optimizer = create_optimizer(
         config.optimizer, model)
     # ``torch.optim.lr_scheduler`` is a module, not a type, so it cannot
     # annotate this attribute. ``Scheduler()`` is the fallback used below.
     # NOTE(review): assumes create_scheduler also returns a Scheduler — confirm.
     self.scheduler: Scheduler = (create_scheduler(
         config.scheduler, optimizer) if config.scheduler else Scheduler())
     # Fall back to a default Sparsifier() when none is configured.
     self.sparsifier: Sparsifier = (create_sparsifier(config.sparsifier)
                                    if config.sparsifier else Sparsifier())
     # The scheduler above was built from the optimizer *before* this call.
     # NOTE(review): presumably precision.initialize returns the same (possibly
     # patched) optimizer object, so the scheduler still drives it — confirm.
     # The returned model is only rebound to the local ``model`` and is not
     # stored on self; this assumes initialize adapts the model in place.
     model, self.optimizer = precision.initialize(model, optimizer)
     self.config = config
예제 #3
0
    def __init__(self, config: Config, model: torch.nn.Module):
        """Validate early-stopping settings and build the optimizer, LR
        scheduler and sparsifier for training.

        Args:
            config: Trainer config; ``early_stop_after``, ``do_eval``,
                ``optimizer``, ``scheduler`` and ``sparsifier`` are read from
                it, plus ``fp16_args`` and ``num_accumulated_batches`` when
                ``precision.FP16_ENABLED`` is set. ``config`` itself is
                stored on ``self.config``.
            model: Model passed to ``create_optimizer``.

        Raises:
            AssertionError: if ``early_stop_after > 0`` while ``do_eval`` is
                false. This is an ``assert``, so the check is skipped when
                Python runs with ``-O``.
        """
        if config.early_stop_after > 0:
            assert config.do_eval, "can't do early stopping when not running evaluation"

        # Pick the create_optimizer call form into a local first so that
        # self.optimizer is declared (and annotated) exactly once instead of
        # once per branch, which type checkers report as a redefinition.
        if precision.FP16_ENABLED:
            # FP16 mode uses the extended call form: fp16 args first, then the
            # model, the base optimizer config and num_accumulated_batches.
            optimizer = create_optimizer(
                config.fp16_args,
                model,
                config.optimizer,
                config.num_accumulated_batches,
            )
        else:
            optimizer = create_optimizer(config.optimizer, model)
        self.optimizer: torch.optim.Optimizer = optimizer

        # ``torch.optim.lr_scheduler`` is a module, not a type, so it cannot
        # annotate this attribute. ``Scheduler()`` is the fallback used below.
        # NOTE(review): assumes create_scheduler also returns a Scheduler — confirm.
        self.scheduler: Scheduler = (create_scheduler(
            config.scheduler, self.optimizer) if config.scheduler else
                                     Scheduler())
        # Fall back to a default Sparsifier() when none is configured.
        self.sparsifier: Sparsifier = (create_sparsifier(config.sparsifier)
                                       if config.sparsifier else Sparsifier())
        self.config = config