def before_run(self, runner):
    """Construct the averaged model which will keep track of the running
    averages of the parameters of the model."""
    model = runner.model
    self.model = AveragedModel(model)
    self.log_buffer = LogBuffer()
    self.meta = runner.meta
    if self.meta is None:
        self.meta = dict()
    # make sure 'hook_msgs' exists even when a meta dict was provided
    self.meta.setdefault('hook_msgs', dict())
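# Minimal sketch of what the averaged model does, assuming `AveragedModel`
# follows the torch.optim.swa_utils.AveragedModel API: it keeps an
# equal-weight running average of the wrapped model's parameters, updated
# explicitly via update_parameters().
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel

net = nn.Linear(4, 2)          # stand-in for runner.model
swa_net = AveragedModel(net)   # tracks the running average of net's weights
for _ in range(3):
    # ... one optimisation step on `net` would normally happen here ...
    swa_net.update_parameters(net)  # fold current weights into the average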
def train_model(model,
                optimizer,
                train_loader,
                lr_scheduler,
                optim_cfg,
                start_epoch,
                total_epochs,
                start_iter,
                rank,
                logger,
                ckpt_save_dir,
                lr_warmup_scheduler=None,
                ckpt_save_interval=1,
                max_ckpt_save_num=50,
                log_interval=20):
    accumulated_iter = start_iter
    log_buffer = LogBuffer()
    for cur_epoch in range(start_epoch, total_epochs):
        trained_epoch = cur_epoch + 1
        accumulated_iter = train_one_epoch(
            model,
            optimizer,
            train_loader,
            lr_scheduler=lr_scheduler,
            lr_warmup_scheduler=lr_warmup_scheduler,
            accumulated_iter=accumulated_iter,
            train_epoch=trained_epoch,
            optim_cfg=optim_cfg,
            rank=rank,
            logger=logger,
            log_buffer=log_buffer,
            log_interval=log_interval)

        # save trained model
        if trained_epoch % ckpt_save_interval == 0 and rank == 0:
            ckpt_list = glob.glob(
                os.path.join(ckpt_save_dir, 'checkpoint_epoch_*.pth'))
            ckpt_list.sort(key=os.path.getmtime)

            if len(ckpt_list) >= max_ckpt_save_num:
                for cur_file_idx in range(
                        0, len(ckpt_list) - max_ckpt_save_num + 1):
                    os.remove(ckpt_list[cur_file_idx])

            ckpt_name = os.path.join(ckpt_save_dir,
                                     'checkpoint_epoch_%d' % trained_epoch)
            save_checkpoint(
                checkpoint_state(model, optimizer, trained_epoch,
                                 accumulated_iter),
                filename=ckpt_name,
            )
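# Standalone sketch of the checkpoint-rotation rule used above: checkpoints
# are sorted oldest-first by modification time, and once their count reaches
# `max_ckpt_save_num` the oldest files are removed so that at most
# `max_ckpt_save_num - 1` remain before the new checkpoint is written. The
# helper name is illustrative only.
import glob
import os


def rotate_checkpoints(ckpt_save_dir, max_ckpt_save_num):
    ckpt_list = glob.glob(
        os.path.join(ckpt_save_dir, 'checkpoint_epoch_*.pth'))
    ckpt_list.sort(key=os.path.getmtime)  # oldest first
    if len(ckpt_list) >= max_ckpt_save_num:
        num_to_remove = len(ckpt_list) - max_ckpt_save_num + 1
        for path in ckpt_list[:num_to_remove]:
            os.remove(path)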
def __init__(self,
             model,
             batch_processor,
             optimizer=None,
             work_dir=None,
             log_level=logging.INFO,
             logger=None):
    # build one optimizer per config; tolerate the default optimizer=None
    self.optimizer = []
    for opt in optimizer or []:
        self.optimizer.append(self.init_optimizer(opt))
    self.model = model
    self.batch_processor = batch_processor

    # create work_dir
    if mmcv.is_str(work_dir):
        self.work_dir = osp.abspath(work_dir)
        mmcv.mkdir_or_exist(self.work_dir)
    elif work_dir is None:
        self.work_dir = None
    else:
        raise TypeError('"work_dir" must be a str or None')

    # get model name from the model class
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self._rank, self._world_size = get_dist_info()
    self.timestamp = get_time_str()
    if logger is None:
        self.logger = self.init_logger(work_dir, log_level)
    else:
        self.logger = logger
    self.log_buffer = LogBuffer()

    self.mode = None
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0
    self._max_epochs = 0
    self._max_iters = 0
class MyRunner(runner.Runner):

    def __init__(self,
                 model,
                 batch_processor,
                 optimizer=None,
                 work_dir=None,
                 log_level=logging.INFO,
                 logger=None):
        # build one optimizer per config; tolerate the default optimizer=None
        self.optimizer = []
        for opt in optimizer or []:
            self.optimizer.append(self.init_optimizer(opt))
        self.model = model
        self.batch_processor = batch_processor

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        if logger is None:
            self.logger = self.init_logger(work_dir, log_level)
        else:
            self.logger = logger
        self.log_buffer = LogBuffer()

        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list: Current learning rates of all param groups, one sub-list
                per optimizer.
        """
        if self.optimizer is None:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        lr_list = []
        for opt in self.optimizer:
            lr_list.append([group['lr'] for group in opt.param_groups])
        return lr_list

    def register_logger_hooks(self, log_config):
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = obj_from_dict(
                info, hooks, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_lr_hooks(self, lr_config):
        if isinstance(lr_config, MyLrUpdaterHook):
            self.register_hook(lr_config)
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            # from .hooks import lr_updater
            hook_name = lr_config['policy'].title() + 'LrUpdaterHook'
            if not hasattr(lr_updater, hook_name):
                raise ValueError('"{}" does not exist'.format(hook_name))
            hook_cls = getattr(lr_updater, hook_name)
            self.register_hook(hook_cls(**lr_config))
        else:
            raise TypeError('"lr_config" must be either a LrUpdaterHook object'
                            ' or dict, not {}'.format(type(lr_config)))

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None):
        """Register default hooks for training.

        Default hooks include:

        - LrUpdaterHook
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
        - LoggerHook(s)
        """
        if optimizer_config is None:
            optimizer_config = {}
        if checkpoint_config is None:
            checkpoint_config = {}
        self.register_lr_hooks(lr_config)
        self.register_hook(self.build_hook(optimizer_config, MyOptimizerHook))
        self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
        self.register_hook(IterTimerHook())
        if log_config is not None:
            self.register_logger_hooks(log_config)

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None):
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        linkpath = osp.join(out_dir, 'latest.pth')
        # the optimizer list is not written into the checkpoint here
        optimizer = None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # use relative symlink
        mmcv.symlink(filename, linkpath)

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            outputs = self.batch_processor(
                self.model,
                data_batch,
                epoch=self.epoch,
                train_mode=True,
                **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)]
                means running 2 epochs for training and 1 epoch for
                validation, iteratively.
            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # e.g. self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            'runner has no method named "{}" to run an '
                            'epoch'.format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    'callable function, not {}'.format(
                                        type(mode)))
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
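# Hedged usage sketch for the workflow-driven run() above: a batch processor
# returning the dict contract that train()/val() expect, and a workflow that
# alternates two training epochs with one validation epoch. The model, the
# dataloaders and the optimizer config format (which depends on
# init_optimizer) are placeholders, not taken from the original code.
def batch_processor(model, data_batch, train_mode, **kwargs):
    losses = model(**data_batch)                   # forward pass
    loss = sum(v.mean() for v in losses.values())  # total loss
    return dict(loss=loss,
                log_vars={k: v.item() for k, v in losses.items()},
                num_samples=len(data_batch['img']))

runner = MyRunner(model, batch_processor,
                  optimizer=[dict(type='SGD', lr=0.02, momentum=0.9)],
                  work_dir='./work_dir')
runner.run([train_loader, val_loader],
           [('train', 2), ('val', 1)],
           max_epochs=12)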
class SWAHook(Hook):
    r"""SWA Object Detection Hook.

    This hook works together with SWA training config files to train
    SWA object detectors <https://arxiv.org/abs/2012.12645>.

    Args:
        swa_eval (bool): Whether to evaluate the swa model.
            Defaults to True.
        eval_hook (Hook): Hook class that contains evaluation functions.
            Defaults to None.
        swa_interval (int): The epoch interval to perform swa.
    """

    def __init__(self, swa_eval=True, eval_hook=None, swa_interval=1):
        if not isinstance(swa_eval, bool):
            raise TypeError('swa_eval must be a bool, but got '
                            f'{type(swa_eval)}')
        if swa_eval:
            if not isinstance(eval_hook, (EvalHook, DistEvalHook)):
                raise TypeError('eval_hook must be either an EvalHook or a '
                                'DistEvalHook when swa_eval = True, but got '
                                f'{type(eval_hook)}')
        self.swa_eval = swa_eval
        self.eval_hook = eval_hook
        self.swa_interval = swa_interval

    def before_run(self, runner):
        """Construct the averaged model which will keep track of the running
        averages of the parameters of the model."""
        model = runner.model
        self.model = AveragedModel(model)
        self.meta = runner.meta
        if self.meta is None:
            self.meta = dict()
        # make sure 'hook_msgs' exists even when a meta dict was provided
        self.meta.setdefault('hook_msgs', dict())
        self.log_buffer = LogBuffer()

    def after_train_epoch(self, runner):
        """Update the parameters of the averaged model, save and evaluate the
        updated averaged model."""
        model = runner.model
        # whether to perform swa at this epoch
        swa_flag = (runner.epoch + 1) % self.swa_interval == 0

        # update the parameters of the averaged model
        if swa_flag:
            self.model.update_parameters(model)

            # save the swa model
            runner.logger.info(
                f'Saving swa model at swa-training {runner.epoch + 1} epoch')
            filename = 'swa_model_{}.pth'.format(runner.epoch + 1)
            filepath = osp.join(runner.work_dir, filename)
            optimizer = runner.optimizer
            self.meta['hook_msgs']['last_ckpt'] = filepath
            save_checkpoint(
                self.model.module,
                filepath,
                optimizer=optimizer,
                meta=self.meta)

        # evaluate the swa model
        if self.swa_eval and swa_flag:
            self.work_dir = runner.work_dir
            self.rank = runner.rank
            self.epoch = runner.epoch
            self.logger = runner.logger
            self.meta['hook_msgs']['last_ckpt'] = filename
            self.eval_hook.after_train_epoch(self)
            for name, val in self.log_buffer.output.items():
                name = 'swa_' + name
                runner.log_buffer.output[name] = val
            runner.log_buffer.ready = True
            self.log_buffer.clear()

    def after_run(self, runner):
        # since BN layers in the backbone are frozen,
        # we do not need to update the BN for the swa model
        pass

    def before_epoch(self, runner):
        pass
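# Hedged sketch of how SWAHook is typically attached to a runner: build the
# evaluation hook first and hand it to SWAHook so the averaged model can be
# evaluated after each SWA epoch. `runner`, `val_dataloader` and the
# evaluation keyword arguments are placeholders supplied by the surrounding
# training script.
eval_hook = EvalHook(val_dataloader, interval=1, metric='bbox')
swa_hook = SWAHook(swa_eval=True, eval_hook=eval_hook, swa_interval=1)
runner.register_hook(swa_hook, priority='LOW')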
def __init__(self,
             model,
             batch_processor=None,
             optimizer=None,
             work_dir=None,
             logger=None,
             meta=None,
             max_iters=None,
             max_epochs=None):
    if batch_processor is not None:
        if not callable(batch_processor):
            raise TypeError("batch_processor must be callable, "
                            f"but got {type(batch_processor)}")
        warnings.warn("batch_processor is deprecated, please implement "
                      "train_step() and val_step() in the model instead.")
        # raise an error if `batch_processor` is not None and
        # `model.train_step()` exists.
        if is_module_wrapper(model):
            _model = model.module
        else:
            _model = model
        if hasattr(_model, "train_step") or hasattr(_model, "val_step"):
            raise RuntimeError(
                "batch_processor and model.train_step()/model.val_step() "
                "cannot be both available.")
    else:
        assert hasattr(model, "train_step")

    # check the type of `optimizer`
    if isinstance(optimizer, dict):
        for name, optim in optimizer.items():
            if not isinstance(optim, Optimizer):
                raise TypeError(
                    f"optimizer must be a dict of torch.optim.Optimizers, "
                    f'but optimizer["{name}"] is a {type(optim)}')
    elif not isinstance(optimizer, Optimizer) and optimizer is not None:
        # the strict optimizer type check is intentionally disabled here
        pass
        # raise TypeError(
        #     f'optimizer must be a torch.optim.Optimizer object '
        #     f'or dict or None, but got {type(optimizer)}')

    # check the type of `logger`
    if not isinstance(logger, logging.Logger):
        raise TypeError(f"logger must be a logging.Logger object, "
                        f"but got {type(logger)}")

    # check the type of `meta`
    if meta is not None and not isinstance(meta, dict):
        raise TypeError(
            f"meta must be a dict or None, but got {type(meta)}")

    self.model = model
    self.batch_processor = batch_processor
    self.optimizer = optimizer
    self.logger = logger
    self.meta = meta

    # create work_dir
    if mmcv.is_str(work_dir):
        self.work_dir = osp.abspath(work_dir)
        mmcv.mkdir_or_exist(self.work_dir)
    elif work_dir is None:
        self.work_dir = None
    else:
        raise TypeError('"work_dir" must be a str or None')

    # get model name from the model class
    if hasattr(self.model, "module"):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self._rank, self._world_size = get_dist_info()
    self.timestamp = get_time_str()
    self.mode = None
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0

    if max_epochs is not None and max_iters is not None:
        raise ValueError(
            "Only one of `max_epochs` or `max_iters` can be set.")

    self._max_epochs = max_epochs
    self._max_iters = max_iters
    # TODO: Redesign LogBuffer, it is not flexible and elegant enough
    self.log_buffer = LogBuffer()
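# Hedged sketch of the argument shapes the checks above accept: either a
# single torch.optim.Optimizer or a dict mapping names to optimizers, plus a
# standard logging.Logger. `MyModel` (a module exposing train_step()) and the
# concrete runner class name used here are placeholders.
import logging
import torch

model = MyModel()
optimizers = dict(
    generator=torch.optim.Adam(model.generator.parameters(), lr=1e-4),
    discriminator=torch.optim.Adam(model.discriminator.parameters(), lr=4e-4))
logger = logging.getLogger(__name__)
runner = EpochBasedRunner(
    model,
    optimizer=optimizers,
    work_dir='./work_dir',
    logger=logger,
    max_epochs=12)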
def main():
    args = parse_args()
    cfg = Config.fromfile(args.cfg)
    work_dir = cfg.work_dir
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(device_id) for device_id in cfg.device_ids)

    log_dir = os.path.join(work_dir, 'logs')
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logger = init_logger(log_dir)

    seed = cfg.seed
    logger.info('Set random seed to {}'.format(seed))
    set_random_seed(seed)

    train_dataset = get_dataset(cfg.data.train)
    train_data_loader = build_dataloader(
        train_dataset,
        cfg.data.imgs_per_gpu,
        cfg.data.workers_per_gpu,
        len(cfg.device_ids),
        dist=False,
    )
    val_dataset = get_dataset(cfg.data.val)
    val_data_loader = build_dataloader(
        val_dataset,
        1,
        cfg.data.workers_per_gpu,
        1,
        dist=False,
        shuffle=False)

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    model = MMDataParallel(model).cuda()

    optimizer = obj_from_dict(cfg.optimizer, torch.optim,
                              dict(params=model.parameters()))
    lr_scheduler = obj_from_dict(cfg.lr_scedule, LRschedule,
                                 dict(optimizer=optimizer))

    checkpoint_dir = os.path.join(cfg.work_dir, 'checkpoint_dir')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    start_epoch = cfg.start_epoch
    if cfg.resume_from:
        checkpoint = load_checkpoint(model, cfg.resume_from)
        start_epoch = 0
        logger.info('resumed epoch {}, from {}'.format(start_epoch,
                                                       cfg.resume_from))

    log_buffer = LogBuffer()
    for epoch in range(start_epoch, cfg.end_epoch):
        train(train_data_loader, model, optimizer, epoch, lr_scheduler,
              log_buffer, cfg, logger)

        tmp_checkpoint_file = os.path.join(checkpoint_dir, 'tmp_val.pth')
        meta_dict = cfg._cfg_dict
        logger.info('save tmp checkpoint to {}'.format(tmp_checkpoint_file))
        save_checkpoint(model, tmp_checkpoint_file, optimizer, meta=meta_dict)

        if len(cfg.device_ids) == 1:
            sensitivity = val(val_data_loader, model, cfg, logger, epoch)
        else:
            model_args = cfg.model.copy()
            model_args.update(train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
            model_type = getattr(detectors, model_args.pop('type'))
            results = parallel_test(
                cfg,
                model_type,
                model_args,
                tmp_checkpoint_file,
                val_dataset,
                np.arange(len(cfg.device_ids)).tolist(),
                workers_per_gpu=1,
            )
            sensitivity = evaluate_deep_lesion(results, val_dataset,
                                               cfg.cfg_3dce, logger)

        save_file = os.path.join(
            checkpoint_dir, 'epoch_{}_sens@4FP_{:.5f}_{}.pth'.format(
                epoch + 1, sensitivity,
                time.strftime('%m-%d-%H-%M', time.localtime(time.time()))))
        os.rename(tmp_checkpoint_file, save_file)
        logger.info('save checkpoint to {}'.format(save_file))
        if epoch > cfg.lr_scedule.T_max:
            os.remove(save_file)
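# Minimal sketch of the config-driven construction pattern used above,
# assuming obj_from_dict's usual semantics (the dict's 'type' key names a
# class in the given module; the remaining keys plus default_args become
# constructor arguments). The optimizer settings below are illustrative, not
# taken from the original cfg.
import torch
import torch.nn as nn
from mmcv.runner import obj_from_dict

optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4)
model = nn.Linear(8, 2)
optimizer = obj_from_dict(optimizer_cfg, torch.optim,
                          dict(params=model.parameters()))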