def init_components(
    self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
):
    """Inits the run's components."""
    model = model_fn()
    model = self.sync_device(model)

    criterion = criterion_fn()
    criterion = self.sync_device(criterion)

    optimizer = optimizer_fn()
    optimizer = self.sync_device(optimizer)

    # model, optimizer = _wrap_into_data_parallel_with_apex(
    #     model,
    #     optimizer,
    #     distributed_params=dict(
    #         opt_level=self.opt_level,
    #         keep_batchnorm_fp32=self.keep_batchnorm_fp32,
    #         loss_scale=self.loss_scale,
    #     ),
    # )
    # model = APEX_DDP(model, delay_allreduce=self.delay_all_reduce)
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=self.opt_level,
        keep_batchnorm_fp32=self.keep_batchnorm_fp32,
        loss_scale=self.loss_scale,
    )
    model = ApexDistributedDataParallel(model, delay_allreduce=self.delay_all_reduce)

    scheduler = scheduler_fn()
    scheduler = self.sync_device(scheduler)
    return model, criterion, optimizer, scheduler
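# The init_components variants in this listing rely on two Apex names imported at
# module level: `amp` and `ApexDistributedDataParallel`. Below is a hypothetical
# standalone sketch of the same wrapping step (mirroring the commented-out
# `_wrap_into_data_parallel_with_apex` call above); the helper name and default
# arguments are illustrative, not taken from the source.
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDistributedDataParallel


def wrap_into_apex_ddp(model, optimizer, opt_level="O1", keep_batchnorm_fp32=None,
                       loss_scale=None, delay_allreduce=True):
    # amp.initialize patches the model and optimizer for mixed-precision execution;
    # Apex DDP then handles gradient all-reduce across processes.
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=opt_level,
        keep_batchnorm_fp32=keep_batchnorm_fp32,
        loss_scale=loss_scale,
    )
    model = ApexDistributedDataParallel(model, delay_allreduce=delay_allreduce)
    return model, optimizer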
def init_components(
    self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
):
    """Inits the run's components."""
    model = model_fn()
    model = self.sync_device(model)
    if self._sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = criterion_fn()
    criterion = self.sync_device(criterion)

    optimizer = optimizer_fn()
    optimizer = self.sync_device(optimizer)

    model, optimizer = amp.initialize(model, optimizer, **self.apex_kwargs)
    model = ApexDistributedDataParallel(model, **self.ddp_kwargs)

    scheduler = scheduler_fn()
    scheduler = self.sync_device(scheduler)
    return model, criterion, optimizer, scheduler
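# A sketch of the keyword-argument dicts that the variant above forwards to Apex.
# The attribute names `apex_kwargs` and `ddp_kwargs` come from the snippet; the
# concrete values (and the engine constructor in the trailing comment) are assumptions.
apex_kwargs = dict(
    opt_level="O1",            # mixed-precision mode, from O0 (pure fp32) to O3 (pure fp16)
    keep_batchnorm_fp32=None,  # let Apex pick a per-opt_level default
    loss_scale="dynamic",      # dynamic loss scaling to avoid fp16 gradient underflow
)
ddp_kwargs = dict(
    delay_allreduce=True,      # buffer all gradients and all-reduce once per backward pass
)
# e.g. engine = SomeApexDDPEngine(apex_kwargs=apex_kwargs, ddp_kwargs=ddp_kwargs)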
def init_components(
    self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
):
    """Inits the run's components."""
    model = model_fn()
    model = self.sync_device(model)

    criterion = criterion_fn()
    criterion = self.sync_device(criterion)

    optimizer = optimizer_fn()
    optimizer = self.sync_device(optimizer)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=self.opt_level,
        keep_batchnorm_fp32=self.keep_batchnorm_fp32,
        loss_scale=self.loss_scale,
    )
    model = ApexDistributedDataParallel(
        model, delay_allreduce=self.delay_all_reduce
    )

    scheduler = scheduler_fn()
    scheduler = self.sync_device(scheduler)
    return model, criterion, optimizer, scheduler
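# Hypothetical usage of the init_components variants above, assuming `engine` is an
# instance of a class that defines one of them (plus `sync_device`). The toy model,
# optimizer, and scheduler are illustrative; only the method signature comes from
# the snippets. In practice a runner builds these factory closures itself.
import torch.nn as nn
import torch.optim as optim

_model = nn.Linear(28 * 28, 10)
_optimizer = optim.Adam(_model.parameters(), lr=1e-3)

model, criterion, optimizer, scheduler = engine.init_components(
    model_fn=lambda: _model,
    criterion_fn=nn.CrossEntropyLoss,
    optimizer_fn=lambda: _optimizer,
    scheduler_fn=lambda: optim.lr_scheduler.StepLR(_optimizer, step_size=10),
)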
def __init__(self, cfg, build_model):
    """
    Args:
        cfg (config dict):
    """
    self.data_loader = self.build_train_loader(cfg)
    # Assume these objects must be constructed in this order.
    model = build_model(cfg)
    self.model = maybe_convert_module(model)
    logger.info(f"Model: \n{self.model}")

    # Assume these objects must be constructed in this order.
    self.optimizer = self.build_optimizer(cfg, self.model)

    if cfg.TRAINER.FP16.ENABLED:
        self.mixed_precision = True
        if cfg.TRAINER.FP16.TYPE == "APEX":
            from apex import amp
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=cfg.TRAINER.FP16.OPTS.OPT_LEVEL
            )
    else:
        self.mixed_precision = False

    # For training, wrap with DDP. But don't need this for inference.
    if comm.get_world_size() > 1:
        torch.cuda.set_device(comm.get_local_rank())
        if cfg.MODEL.DDP_BACKEND == "torch":
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True,
            )
        elif cfg.MODEL.DDP_BACKEND == "apex":
            from apex.parallel import DistributedDataParallel as ApexDistributedDataParallel

            self.model = ApexDistributedDataParallel(self.model)
        else:
            raise ValueError("non-supported DDP backend: {}".format(cfg.MODEL.DDP_BACKEND))

    super().__init__(
        self.model,
        self.data_loader,
        self.optimizer,
    )

    if not cfg.SOLVER.LR_SCHEDULER.get("EPOCH_WISE", False):
        epoch_iters = -1
    else:
        epoch_iters = cfg.SOLVER.LR_SCHEDULER.get("EPOCH_ITERS")
        logger.warning(f"Setup LR Scheduler in EPOCH mode: {epoch_iters}")

    auto_scale_config(cfg, self.data_loader)
    self.scheduler = self.build_lr_scheduler(cfg, self.optimizer, epoch_iters=epoch_iters)
    # Assume no other objects need to be checkpointed.
    # We can later make it checkpoint the stateful hooks
    self.checkpointer = DefaultCheckpointer(
        # Assume you want to save checkpoints together with logs/statistics
        self.model,
        cfg.OUTPUT_DIR,
        optimizer=self.optimizer,
        scheduler=self.scheduler,
    )
    self.start_iter = 0
    self.start_epoch = 0
    self.max_iter = cfg.SOLVER.LR_SCHEDULER.MAX_ITER
    self.max_epoch = cfg.SOLVER.LR_SCHEDULER.MAX_EPOCH
    self.window_size = cfg.TRAINER.WINDOW_SIZE
    self.cfg = cfg

    self.register_hooks(self.build_hooks())
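# The constructor above reads only a handful of config keys. Below is a minimal
# sketch of that subset as a nested dict for reference; the real project presumably
# uses a config object with attribute access and .get(), and every value shown is an
# illustrative assumption rather than a project default.
cfg_sketch = {
    "OUTPUT_DIR": "./output",
    "MODEL": {"DDP_BACKEND": "torch"},  # or "apex"
    "TRAINER": {
        "WINDOW_SIZE": 20,
        "FP16": {"ENABLED": True, "TYPE": "APEX", "OPTS": {"OPT_LEVEL": "O1"}},
    },
    "SOLVER": {
        "LR_SCHEDULER": {
            "EPOCH_WISE": False,  # when True, EPOCH_ITERS must also be provided
            "MAX_ITER": 90_000,
            "MAX_EPOCH": None,    # iteration-based training in this sketch
        },
    },
}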