# Shared imports for the functions below. get_dist_info, get_time_str,
# LogBuffer and _add_file_handler are assumed to come from the surrounding
# mmcv/runner utilities of this project.
import logging
import math
import os
import os.path as osp
import pickle
import time
from collections import OrderedDict

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from terminaltables import AsciiTable
from torch.utils import model_zoo

from mmcv.runner import LogBuffer, get_dist_info, get_time_str


def load_checkpoint(filename, model=None, map_location=None, strict=False,
                    logger=None):
    """Load a checkpoint from a file or URI.

    Args:
        filename (str): Either a filepath, a URL, or a ``modelzoo://xxxxxxx``
            identifier.
        model (Module, optional): Module to load the checkpoint into. If
            None, the state dict is only returned.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to strictly enforce that the keys in the
            checkpoint match the keys of the model.
        logger (:mod:`logging.Logger` or None): The logger for error messages.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    if logger is None:
        logger = logging.getLogger()
    # Load the checkpoint from the model zoo, a local file, or a URL.
    logger.info('Start loading the model from %s', filename)
    if filename.startswith(('http://', 'https://')):
        # Download on rank 0 only; all ranks read the local copy after the
        # barrier.
        url = filename
        filename = '../' + url.split('/')[-1]
        if get_dist_info()[0] == 0:
            if osp.isfile(filename):
                os.remove(filename)
            os.system('wget -N -q -P ../ ' + url)
        dist.barrier()
    elif filename.startswith(('hdfs://', )):
        url = filename
        filename = '../' + url.split('/')[-1]
        if get_dist_info()[0] == 0:
            if osp.isfile(filename):
                os.remove(filename)
            os.system('hdfs dfs -get ' + url + ' ../')
        dist.barrier()
    elif not osp.isfile(filename):
        raise IOError('{} is not a checkpoint file'.format(filename))
    checkpoint = torch.load(filename, map_location=map_location)
    # Get the state_dict from the checkpoint. OrderedDict is a subclass of
    # dict, so a single isinstance check suffices.
    if isinstance(checkpoint, dict):
        state_dict = checkpoint
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # Strip the 'module.' prefix left by (Distributed)DataParallel wrappers.
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # Load the state_dict into the model, unwrapping DataParallel if needed.
    if model is not None:
        if hasattr(model, 'module'):
            model.module.load_state_dict(state_dict, strict=strict)
        else:
            model.load_state_dict(state_dict, strict=strict)
    logger.info('Loading the model finished!')
    return state_dict
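# Minimal usage sketch for load_checkpoint, assuming a torchvision model as a
# stand-in; the checkpoint path is a placeholder, not part of the original
# code.
def _example_load_checkpoint():
    import torchvision
    model = torchvision.models.resnet50()
    # Returns the raw state dict and also copies it into the model.
    state_dict = load_checkpoint('path/to/checkpoint.pth', model=model,
                                 map_location='cpu', strict=False)
    return state_dict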
def __init__(self, dataset, samples_per_gpu=1, num_replicas=None, rank=None):
    # __init__ of a distributed group-aware sampler: every GPU draws
    # samples_per_gpu images that share the same aspect-ratio group flag.
    _rank, _num_replicas = get_dist_info()
    if num_replicas is None:
        num_replicas = _num_replicas
    if rank is None:
        rank = _rank
    self.dataset = dataset
    self.samples_per_gpu = samples_per_gpu
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0

    assert hasattr(self.dataset, 'flag')
    self.flag = self.dataset.flag
    self.group_sizes = np.bincount(self.flag)

    # Round each group size up to a multiple of samples_per_gpu per replica
    # so that every GPU receives full, group-homogeneous batches.
    self.num_samples = 0
    for i, size in enumerate(self.group_sizes):
        self.num_samples += int(
            math.ceil(size * 1.0 / self.samples_per_gpu /
                      self.num_replicas)) * self.samples_per_gpu
    self.total_size = self.num_samples * self.num_replicas
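# Hedged usage sketch for the group sampler above. `DistributedGroupSampler`
# is an assumed name for the class this __init__ belongs to; the toy dataset
# only exists to provide the required `flag` attribute.
def _example_group_sampler():
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        # Ten items split into two aspect-ratio groups via `flag`.
        flag = np.array([0] * 6 + [1] * 4, dtype=np.int64)

        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return torch.tensor([idx])

    dataset = ToyDataset()
    sampler = DistributedGroupSampler(dataset, samples_per_gpu=2)
    # The DataLoader batches consecutive indices yielded by the sampler.
    return DataLoader(dataset, batch_size=2, sampler=sampler)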
def __init__(self, dataset, samples_per_gpu=1, num_replicas=None, rank=None,
             split=1000, mode=None):
    # Variant of the group sampler that carves a fixed split off the dataset:
    # mode='train' keeps the first `split` samples, mode='val' keeps the rest.
    _rank, _num_replicas = get_dist_info()
    if num_replicas is None:
        num_replicas = _num_replicas
    if rank is None:
        rank = _rank
    self.dataset = dataset
    self.samples_per_gpu = samples_per_gpu
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0
    # Round the split down to a whole number of per-replica batches.
    self.split = int(
        np.floor(split / samples_per_gpu /
                 self.num_replicas)) * samples_per_gpu
    self.mode = mode

    assert hasattr(self.dataset, 'flag')
    self.flag = self.dataset.flag
    self.group_sizes = np.bincount(self.flag)

    self.num_samples = 0
    for i, size in enumerate(self.group_sizes):
        self.num_samples += int(
            math.ceil(size * 1.0 / self.samples_per_gpu /
                      self.num_replicas)) * self.samples_per_gpu
    if self.mode == 'train':
        self.num_samples = self.split
    elif self.mode == 'val':
        self.num_samples = self.num_samples - self.split
    self.total_size = self.num_samples * self.num_replicas
    self.split *= self.num_replicas
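# Hedged sketch of the split arithmetic above: with split=1000,
# samples_per_gpu=2 and num_replicas=8, each replica keeps
# floor(1000 / 2 / 8) * 2 = 124 samples in 'train' mode, and the final
# multiplication makes self.split = 124 * 8 = 992 across all replicas.
def _example_split_arithmetic(split=1000, samples_per_gpu=2, num_replicas=8):
    per_replica = int(
        np.floor(split / samples_per_gpu / num_replicas)) * samples_per_gpu
    return per_replica, per_replica * num_replicas  # -> (124, 992)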
def load_url_dist(url):
    """In a distributed setting, download the checkpoint at local rank 0 only.

    The remaining ranks wait at the barrier and then load the file from the
    local model-zoo cache.
    """
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url)
    return checkpoint
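# Hedged usage sketch: fetching torchvision's ResNet-50 weights through the
# rank-aware downloader. The URL is torchvision's published checkpoint; any
# checkpoint URL works the same way.
def _example_load_url_dist():
    url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    return load_url_dist(url)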
def load_checkpoint(self, filename, map_location='cpu', strict=True):
    # Runner-level wrapper: resolve an HTTP(S) URL to a local file on rank 0,
    # then delegate to the module-level load_checkpoint above.
    self.logger.info('load checkpoint from %s', filename)
    if filename.startswith(('http://', 'https://')):
        url = filename
        filename = '../' + url.split('/')[-1]
        if get_dist_info()[0] == 0:
            if osp.isfile(filename):
                os.remove(filename)
            os.system('wget -N -q -P ../ ' + url)
        dist.barrier()
    # Keyword arguments keep the call consistent with the signature of the
    # module-level load_checkpoint, which takes the filename first.
    return load_checkpoint(filename, model=self.model,
                           map_location=map_location, strict=strict,
                           logger=self.logger)
def __init__(self, model, batch_processor, optimizer=None, work_dir=None,
             log_level=logging.INFO, logger=None, mean_teacher=False):
    assert callable(batch_processor)
    self.model = model
    if optimizer is not None:
        self.optimizer = self.init_optimizer(optimizer)
    else:
        self.optimizer = None
    self.batch_processor = batch_processor
    self.teacher_dict = {}
    self.mean_teacher = mean_teacher

    # Create work_dir if needed.
    if mmcv.is_str(work_dir):
        self.work_dir = osp.abspath(work_dir)
        mmcv.mkdir_or_exist(self.work_dir)
    elif work_dir is None:
        self.work_dir = None
    else:
        raise TypeError('"work_dir" must be a str or None')

    # Get the model name from the model class, unwrapping DataParallel.
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self._rank, self._world_size = get_dist_info()
    self.timestamp = get_time_str()
    if logger is None:
        self.logger = self.init_logger(work_dir, log_level)
    else:
        self.logger = logger
    self.log_buffer = LogBuffer()

    self.mode = None
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0
    self._max_epochs = 0
    self._max_iters = 0
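# Hedged sketch of constructing the runner. `Runner` is an assumed name for
# the class this __init__ belongs to, and the dict-style optimizer config is
# an assumption about what init_optimizer accepts; the toy model and batch
# processor are placeholders. A real batch_processor returns a dict with
# 'loss', 'log_vars' and 'num_samples'.
def _example_runner():
    import torch.nn as nn

    model = nn.Linear(10, 2)

    def batch_processor(model, data, train_mode):
        inputs, targets = data
        loss = nn.functional.cross_entropy(model(inputs), targets)
        return dict(loss=loss, log_vars=dict(loss=loss.item()),
                    num_samples=inputs.size(0))

    return Runner(model, batch_processor,
                  optimizer=dict(type='SGD', lr=0.01),
                  work_dir='./work_dir')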
def get_root_logger(log_dir=None, log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(format='%(asctime)s - %(message)s',
                            level=log_level,
                            datefmt='%m/%d %I:%M:%S %p')
    rank, _ = get_dist_info()
    # Silence all but rank 0 so multi-GPU runs do not duplicate log lines.
    if rank != 0:
        logger.setLevel('ERROR')
    if log_dir and rank == 0:
        filename = '{}.log'.format(
            time.strftime('%Y%m%d_%H%M%S', time.localtime()))
        log_file = osp.join(log_dir, filename)
        _add_file_handler(logger, log_file, level=log_level)
    return logger
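# Hedged usage sketch: grab the rank-aware root logger at startup. The
# work-dir path is a placeholder.
def _example_get_root_logger():
    logger = get_root_logger(log_dir='./work_dir', log_level=logging.INFO)
    logger.info('training started')
    return logger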
def init_logger(log_dir=None, level=logging.INFO):
    """Init the logger.

    Args:
        log_dir (str, optional): Log file directory. If not specified, no
            log file will be used.
        level (int or str): See the built-in python logging module.

    Returns:
        :obj:`~logging.Logger`: Python logger.
    """
    rank, _ = get_dist_info()
    logging.basicConfig(format='%(asctime)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)
    # Only rank 0 writes a log file, to avoid clobbering in distributed runs.
    if log_dir and rank == 0:
        filename = '{}.log'.format(
            time.strftime('%Y%m%d_%H%M%S', time.localtime()))
        log_file = osp.join(log_dir, filename)
        _add_file_handler(logger, log_file, level=level)
    return logger
def __init__(self, dataset, samples_per_gpu=1, repeat_t=0.001,
             num_replicas=None, rank=None):
    # __init__ of a repeat-factor sampler (LVIS-style class-balanced
    # resampling): images containing rare categories are repeated so that
    # each category's image frequency is at least repeat_t.
    _rank, _num_replicas = get_dist_info()
    if num_replicas is None:
        num_replicas = _num_replicas
    if rank is None:
        rank = _rank
    self.dataset = dataset
    self.samples_per_gpu = samples_per_gpu
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0
    self.num_to_sample_out = [6000, 17000]

    assert hasattr(self.dataset, 'flag')
    self.flag = self.dataset.flag
    self.group_sizes = np.bincount(self.flag)

    # Map each class's image ids to indices into dataset.img_infos. The key
    # 'isntance_count' (sic) mirrors the spelling used in the pickle file.
    with open('./data/lvis/class_to_imageid_and_inscount.pt', 'rb') as f:
        self.dataset_class_image_info = pickle.load(f)
    img_id_to_new_img_info_id = {
        info['id']: idx
        for idx, info in enumerate(self.dataset.img_infos)
    }
    for k, v in self.dataset_class_image_info.items():
        v_new = v.copy()
        v_new['image_info_id'] = [
            img_id_to_new_img_info_id[id] for id in v_new['img_id']
        ]
        self.dataset_class_image_info[k] = v_new
    self.dataset_abundant_class_image_info = [
        self.dataset_class_image_info[cls_idx]
        for cls_idx in range(len(self.dataset_class_image_info))
        if self.dataset_class_image_info[cls_idx]['isntance_count'] > 0
    ]
    self.dataset_abundant_class_ids = [
        item['category_id']
        for item in self.dataset_abundant_class_image_info
    ]

    # Calculate the repeat factor per image: for each of the 1230 LVIS
    # categories, r(c) = max(1, sqrt(repeat_t / f(c))), where f(c) is the
    # fraction of images containing category c; an image inherits the
    # largest repeat factor among its categories.
    repeat_per_img = {}
    total_img_num = len(self.dataset.img_infos)
    clses_to_repeat = []
    for cls in range(1, 1231):
        fc = len(self.dataset_class_image_info[cls]
                 ['image_info_id']) / float(total_img_num)
        repeat_this_cls = max(1., np.sqrt(repeat_t / fc))
        if repeat_this_cls > 1:
            clses_to_repeat.append(cls)
        for img_info_id in self.dataset_class_image_info[cls][
                'image_info_id']:
            if img_info_id not in repeat_per_img:
                repeat_per_img[img_info_id] = repeat_this_cls
            elif repeat_per_img[img_info_id] < repeat_this_cls:
                repeat_per_img[img_info_id] = repeat_this_cls
    # Ceil the factors, then keep only the images repeated more than once.
    repeat_per_img = {k: math.ceil(v) for k, v in repeat_per_img.items()}
    assert len(repeat_per_img.keys()) == total_img_num
    self.img_info_ids_to_repeat = {
        k: v
        for k, v in repeat_per_img.items() if v > 1
    }

    # Calculate the new group size information with repeats included.
    self.group_sizes_new = self.group_sizes.copy()
    for i, size in enumerate(self.group_sizes):
        indice = np.where(self.flag == i)[0]
        assert len(indice) == size
        for img_info_id, re_count in self.img_info_ids_to_repeat.items():
            if img_info_id in indice:
                self.group_sizes_new[i] += re_count

    self.num_samples_new = 0
    for i, size in enumerate(self.group_sizes_new):
        self.num_samples_new += int(
            math.ceil(size * 1.0 / self.samples_per_gpu /
                      self.num_replicas)) * self.samples_per_gpu
    self.total_size = self.num_samples_new * self.num_replicas
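# Hedged worked example of the repeat-factor rule used above: with
# repeat_t = 0.001, a category appearing in 10 of 100,000 images has
# f(c) = 1e-4, so r(c) = sqrt(0.001 / 1e-4) = sqrt(10) ~ 3.16, and every
# image containing it is repeated ceil(3.16) = 4 times.
def _example_repeat_factor(repeat_t=0.001, images_with_cls=10,
                           total_images=100000):
    fc = images_with_cls / float(total_images)
    r = max(1.0, math.sqrt(repeat_t / fc))
    return math.ceil(r)  # -> 4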
def load_state_dict(module, state_dict, strict=False, logger=None,
                    force_matching=False, show_converted=False, ignores=None):
    rank, _ = get_dist_info()

    unexpected_keys = []
    converted_pairs = []
    shape_mismatch_pairs = []
    shape_casted_pairs = []

    own_state = module.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            unexpected_keys.append(name)
            continue

        if isinstance(param, torch.nn.Parameter):
            # Backwards compatibility for serialized parameters.
            param = param.data

        src_shape = param.size()
        trg_shape = own_state[name].size()
        if src_shape != trg_shape:
            is_valid = False
            if force_matching:
                # Allow loading a larger tensor by slicing it down to the
                # target shape (every source dim must be >= the target dim).
                is_valid = len(src_shape) == len(trg_shape)
                for i in range(len(src_shape)):
                    is_valid &= src_shape[i] >= trg_shape[i]
            if is_valid:
                # Index with a tuple of slices; indexing with a list of
                # slices is deprecated in recent PyTorch.
                ind = tuple(slice(0, d) for d in trg_shape)
                own_state[name].copy_(param[ind])
                shape_casted_pairs.append(
                    [name, list(own_state[name].size()), list(param.size())])
            else:
                shape_mismatch_pairs.append(
                    [name, list(own_state[name].size()), list(param.size())])
        # `ignores` may be a str or a tuple of suffixes, as str.endswith
        # accepts.
        elif ignores is None or not name.endswith(ignores):
            own_state[name].copy_(param)
            if show_converted:
                converted_pairs.append([name, list(own_state[name].size())])

    all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
    # Ignore the "num_batches_tracked" buffers of BatchNorm layers.
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    err_msg = []
    if unexpected_keys:
        err_msg.append('unexpected key in source state_dict: {}\n'.format(
            ', '.join(unexpected_keys)))
    if missing_keys:
        err_msg.append('missing keys in source state_dict: {}\n'.format(
            ', '.join(missing_keys)))
    if shape_mismatch_pairs:
        mismatch_info = 'these keys have mismatched shape:\n'
        header = ['key', 'expected shape', 'loaded shape']
        table_data = [header] + shape_mismatch_pairs
        table = AsciiTable(table_data)
        err_msg.append(mismatch_info + table.table)

    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)

    ok_message = []
    if converted_pairs:
        converted_info = 'These keys have been converted correctly:\n'
        header = ['key', 'shape']
        table_data = [header] + converted_pairs
        table = AsciiTable(table_data)
        ok_message.append(converted_info + table.table)

    if len(ok_message) > 0 and rank == 0:
        ok_msg = '\n'.join(ok_message)
        if logger is not None:
            logger.warning(ok_msg)
        else:
            print(ok_msg)

    warning_msg = []
    if shape_casted_pairs:
        casted_info = 'these keys have been shape casted:\n'
        header = ['key', 'expected shape', 'loaded shape']
        table_data = [header] + shape_casted_pairs
        table = AsciiTable(table_data)
        warning_msg.append(casted_info + table.table)

    if len(warning_msg) > 0 and rank == 0:
        warning_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        warning_msg = '\n'.join(warning_msg)
        if logger is not None:
            logger.warning(warning_msg)
        else:
            print(warning_msg)
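# Hedged usage sketch: copying ImageNet-shaped weights into a model whose
# classification head has fewer classes, relying on force_matching to slice
# the larger tensors. The torchvision models are stand-ins.
def _example_load_state_dict():
    import torchvision
    src = torchvision.models.resnet50()
    dst = torchvision.models.resnet50(num_classes=10)
    # fc.weight is (1000, 2048) in src and (10, 2048) in dst, so
    # force_matching slices the first 10 rows instead of failing.
    load_state_dict(dst, src.state_dict(), strict=False, force_matching=True,
                    show_converted=False)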