def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    version = local_metadata.get('version', None)

    if version is None or version < 2:
        # In version < 2, ModulatedDeformConvPack stored the offset conv
        # under a different key, so rename it here to stay compatible with
        # previous benchmark models.
        if (prefix + 'conv_offset.weight' not in state_dict
                and prefix[:-1] + '_offset.weight' in state_dict):
            state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
                prefix[:-1] + '_offset.weight')
        if (prefix + 'conv_offset.bias' not in state_dict
                and prefix[:-1] + '_offset.bias' in state_dict):
            state_dict[prefix + 'conv_offset.bias'] = state_dict.pop(
                prefix[:-1] + '_offset.bias')

    if version is not None and version > 1:
        Logging.getLogger().info(
            'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
                prefix.rstrip('.')))

    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs)
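# A minimal, self-contained sketch of the renaming performed above, run on a
# plain dict. The module prefix 'head.conv.' and the tensor shape are made up
# for illustration; only the key juggling mirrors the hook.
def _demo_offset_key_migration():
    import torch
    prefix = 'head.conv.'  # hypothetical submodule prefix
    # a version < 2 checkpoint stores the offset conv under '<name>_offset.*'
    state_dict = {'head.conv_offset.weight': torch.zeros(27, 16, 3, 3)}
    if (prefix + 'conv_offset.weight' not in state_dict
            and prefix[:-1] + '_offset.weight' in state_dict):
        state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
            prefix[:-1] + '_offset.weight')
    # the key is now '<name>.conv_offset.*', as the current module expects
    assert list(state_dict) == ['head.conv.conv_offset.weight']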
def train_detector(model, dataset, cfg, validate=False, timestamp=None,
                   meta=None):
    logger = Logging.getLogger()

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [build_dataloader(ds, data=cfg.data) for ds in dataset]

    if torch.cuda.is_available():
        model = model.cuda(cfg.gpu_ids[0])
        model.device = cfg.gpu_ids[0]
        if torch.cuda.device_count() > 1:
            model = DataParallel(model, device_ids=cfg.gpu_ids)
    else:
        model.device = 'cpu'

    # build runner
    optimizer = cfg.optimizer
    ema = cfg.ema if 'ema' in cfg else None
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    logger=logger, meta=meta, ema=ema)
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # register the eval hook before the logging hooks, otherwise its results
    # are never printed
    if validate:
        cfg.data.val.train = False
        val_dataset = build_from_dict(cfg.data.val, DATASET)
        val_dataloader = build_dataloader(val_dataset, shuffle=False,
                                          data=cfg.data)
        eval_cfg = cfg.get('evaluation', {})
        from yolodet.models.hooks.eval_hook import EvalHook
        runner.register_hook(EvalHook(val_dataloader, **eval_cfg))

    # register training hooks
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
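# A checklist of what train_detector reads from cfg, as used in the body
# above (not an exhaustive schema for this repo): data (plus data.val when
# validating), gpu_ids, optimizer, work_dir, lr_config, optimizer_config,
# checkpoint_config, log_config, workflow, total_epochs, and the optional
# ema, evaluation, resume_from and load_from entries.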
def __init__(self, model, batch_processor, optimizer=None, work_dir=None,
             logger=None, meta=None, ema=None):
    assert callable(batch_processor)
    self.model = model
    if meta is not None:
        assert isinstance(meta, dict), '"meta" must be a dict or None'
    self.meta = meta
    self.mode = None
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0
    self._max_epochs = 0
    self._max_iters = 0
    self._warmup_max_iters = 0
    self._momentum = 0
    self.batch_processor = batch_processor

    # create work_dir
    if isinstance(work_dir, str):
        self.work_dir = osp.abspath(work_dir)
        file_utils.mkdir_or_exist(self.work_dir)
    elif work_dir is None:
        self.work_dir = None
    else:
        raise TypeError('"work_dir" must be a str or None')

    # get model name from the model class
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self.logger = logger if logger is not None else Logging.getLogger()
    self.log_buffer = LogBuffer()
    self.optimizer = (self.init_optimizer(optimizer)
                      if optimizer is not None else None)
    self.ema = self.init_ema(ema) if ema is not None else None
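# For reference, train_detector above constructs this class as
# Runner(model, batch_processor, optimizer, cfg.work_dir, logger=logger,
# meta=meta, ema=ema); init_optimizer and init_ema are assumed to be
# defined elsewhere on the class.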
def load_checkpoint(model, filename, strict=False, map_location='cpu'):
    logger = Logging.getLogger()
    logger.info('load checkpoint from %s', filename)
    if not osp.isfile(filename):
        raise IOError('{} is not a checkpoint file'.format(filename))
    checkpoint = torch.load(filename, map_location=map_location)

    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model'].state_dict()
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))

    # strip the 'module.' prefix left by DataParallel wrappers; note this
    # must read from state_dict, not checkpoint['state_dict'], since no
    # branch above produces a 'state_dict' key
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # load state_dict
    if hasattr(model, 'module'):
        load_state_dict(model.module, state_dict, strict)
    else:
        load_state_dict(model, state_dict, strict)
    return checkpoint
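# A minimal usage sketch for load_checkpoint. The path is hypothetical; any
# file holding a bare OrderedDict state_dict, or a dict with a 'model'
# entry, is accepted by the branches above.
def _demo_load_checkpoint():
    import torch.nn as nn
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
    # strict=False tolerates missing/unexpected keys and only logs a warning
    checkpoint = load_checkpoint(model, 'work_dir/latest.pth', strict=False)
    return checkpoint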
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.device = args.device
    device = select_device(cfg.device)

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # create work_dir
    file_utils.mkdir_or_exist(osp.abspath(cfg.work_dir))

    # init the logger before other steps
    logger = Logging.getLogger()

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join(
        '{}: {}'.format(k, v) for k, v in env_info_dict.items())
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['batch_size'] = cfg.data.batch_size
    meta['subdivisions'] = cfg.data.subdivisions
    meta['multi_scale'] = args.multi_scale

    # log some basic info
    logger.info('Config:\n{}'.format(cfg.text))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed

    model = build_from_dict(cfg.model, DETECTORS)
    # use .to() so a CPU device from select_device also works
    model = model.to(device)
    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.device = device

    datasets = [build_from_dict(cfg.data.train, DATASET)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_from_dict(val_dataset, DATASET))
    if cfg.checkpoint_config is not None:
        # save the config file content and class names in checkpoints as
        # meta data
        cfg.checkpoint_config.meta = dict(config=cfg.text,
                                          CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    train_detector(model, datasets, cfg, validate=args.validate,
                   timestamp=timestamp, meta=meta)
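# Worked example for the autoscale_lr rule above (values made up): with a
# base lr of 0.01 tuned for 8 GPUs and len(cfg.gpu_ids) == 4, the scaled
# rate is 0.01 * 4 / 8 = 0.005.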
def load_state_dict(module, state_dict, strict=False):
    """Load state_dict into a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    The default value of ``strict`` is ``False``, and a message reporting
    parameter mismatches is shown through the module logger even when
    ``strict`` is ``False``.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): Whether to strictly enforce that the keys in
            :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
    """
    logger = Logging.getLogger()
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source state_dict: {}\n'.format(
            ', '.join(unexpected_keys)))
    if missing_keys:
        err_msg.append('missing keys in source state_dict: {}\n'.format(
            ', '.join(missing_keys)))

    if len(err_msg) > 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
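# A self-contained sketch of the non-strict behaviour: a state_dict with one
# unexpected and two missing keys is loaded with a warning rather than an
# exception. The toy modules and key names are made up for illustration.
def _demo_load_state_dict_mismatch():
    from collections import OrderedDict
    import torch
    import torch.nn as nn
    src = nn.Linear(4, 2)  # supplies weights for the 'fc' submodule only
    dst = nn.Sequential(OrderedDict(fc=nn.Linear(4, 2), out=nn.Linear(2, 1)))
    state_dict = {'fc.' + k: v for k, v in src.state_dict().items()}
    state_dict['extra.weight'] = torch.zeros(1)  # unexpected key
    # 'out.weight' and 'out.bias' are missing; strict=False only warns
    load_state_dict(dst, state_dict, strict=False)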