Code example #1
0
File: utils.py  Project: zhangrj91/FNA
def load_checkpoint(filename,
                    model=None,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        filename (str): Either a filepath, an ``http(s)://`` URL or an
            ``hdfs://`` URI. Remote files are downloaded into ``../`` by
            rank 0 while the other ranks wait at a barrier.
        model (Module, optional): Module to load the state dict into.
            When ``None`` the state dict is only returned.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for messages.

    Returns:
        dict or OrderedDict: The loaded state dict.

    Raises:
        IOError: If ``filename`` is a local path that does not exist.
        RuntimeError: If the checkpoint does not contain a state dict.
    """
    if logger is None:
        logger = logging.getLogger()
    # load checkpoint from modelzoo or file or url
    logger.info('Start loading the model from ' + filename)
    if filename.startswith(('http://', 'https://')):
        url = filename
        filename = '../' + url.split('/')[-1]
        # Only rank 0 downloads; the barrier keeps the other ranks waiting
        # until the file exists.
        if get_dist_info()[0] == 0:
            if osp.isfile(filename):
                # bugfix: remove via os.remove instead of shelling out with
                # an unquoted path ("rm " + filename).
                os.remove(filename)
            os.system('wget -N -q -P ../ ' + url)
        dist.barrier()
    elif filename.startswith(('hdfs://', )):
        url = filename
        filename = '../' + url.split('/')[-1]
        if get_dist_info()[0] == 0:
            if osp.isfile(filename):
                os.remove(filename)
            os.system('hdfs dfs -get ' + url + ' ../')
        dist.barrier()
    else:
        if not osp.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
    checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint; OrderedDict is a dict subclass, so a
    # single isinstance check covers both.
    if isinstance(checkpoint, dict):
        state_dict = checkpoint
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # Strip the 'module.' prefix left by (Distributed)DataParallel. Guard
    # against an empty state dict before peeking at the first key (the
    # original raised IndexError here).
    if state_dict and next(iter(state_dict)).startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # load state_dict
    if model is not None:
        if hasattr(model, 'module'):
            model.module.load_state_dict(state_dict, strict=strict)
        else:
            model.load_state_dict(state_dict, strict=strict)
        logger.info('Loading the model finished!')
    return state_dict
Code example #2
0
File: sampler.py  Project: CarmeloA/DL
    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None):
        """Set up group-aware distributed sampling state.

        ``rank`` and ``num_replicas`` default to the current distributed
        environment when not supplied explicitly.
        """
        env_rank, env_world_size = get_dist_info()
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = env_world_size if num_replicas is None else num_replicas
        self.rank = env_rank if rank is None else rank
        self.epoch = 0

        # The dataset must expose a per-sample group id ('flag').
        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        # Pad each group up to a whole number of per-GPU batches on every
        # replica, then total across groups.
        self.num_samples = 0
        for size in self.group_sizes:
            self.num_samples += int(
                math.ceil(size * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas
Code example #3
0
    def __init__(self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, split=1000, mode=None):
        """Group sampler that carves the sample budget into train/val parts.

        Roughly ``split`` samples (rounded down to whole per-GPU batches)
        are reserved for ``mode == 'train'``; ``mode == 'val'`` receives
        the remainder.
        """
        env_rank, env_world_size = get_dist_info()
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = env_world_size if num_replicas is None else num_replicas
        self.rank = env_rank if rank is None else rank
        self.epoch = 0
        # Round the requested split down to a whole number of per-GPU batches.
        self.split = int(np.floor(split / samples_per_gpu / self.num_replicas)) * samples_per_gpu
        self.mode = mode

        # The dataset must expose a per-sample group id ('flag').
        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        # Pad each group up to a whole number of per-GPU batches per replica.
        self.num_samples = 0
        for size in self.group_sizes:
            self.num_samples += int(
                math.ceil(size * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        if self.mode == 'train':
            self.num_samples = self.split
        elif self.mode == 'val':
            self.num_samples -= self.split
        self.total_size = self.num_samples * self.num_replicas
        # Scale the split point into the global (all-replica) index space.
        self.split *= self.num_replicas
Code example #4
0
def load_url_dist(url):
    """Download a checkpoint via ``model_zoo``, once per node.

    In a distributed setting only local rank 0 performs the initial
    download; the other ranks wait at a barrier and then load from the
    (now warm) model-zoo cache.

    Args:
        url (str): URL of the checkpoint.

    Returns:
        The object returned by :func:`model_zoo.load_url`.
    """
    rank, world_size = get_dist_info()
    # LOCAL_RANK (if set) identifies the rank on this node, which is what
    # matters for a node-local download cache.
    rank = int(os.environ.get('LOCAL_RANK', rank))
    checkpoint = None
    if rank == 0:
        checkpoint = model_zoo.load_url(url)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url)
    # bugfix: the original raised UnboundLocalError when rank > 0 while
    # world_size == 1 (e.g. LOCAL_RANK set in a single-process run).
    if checkpoint is None:
        checkpoint = model_zoo.load_url(url)
    return checkpoint
Code example #5
0
    def load_checkpoint(self, filename, map_location='cpu', strict=True):
        """Load a checkpoint into ``self.model``.

        HTTP(S) URLs are first downloaded into ``../`` by rank 0; every
        other rank waits at a barrier before reading the local copy.
        """
        self.logger.info('load checkpoint from %s', filename)

        if filename.startswith(('http://', 'https://')):
            remote_url = filename
            filename = '../' + remote_url.split('/')[-1]
            # Only rank 0 downloads; the barrier synchronizes the rest.
            if get_dist_info()[0] == 0:
                if osp.isfile(filename):
                    os.system('rm ' + filename)
                os.system('wget -N -q -P ../ ' + remote_url)
            dist.barrier()

        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)
Code example #6
0
    def __init__(self,
                 model,
                 batch_processor,
                 optimizer=None,
                 work_dir=None,
                 log_level=logging.INFO,
                 logger=None,
                 mean_teacher=False):
        """Initialize runner state: model, optimizer, work dir and logging."""
        assert callable(batch_processor)
        self.model = model
        self.optimizer = self.init_optimizer(optimizer) if optimizer is not None else None
        self.batch_processor = batch_processor
        self.teacher_dict = {}
        self.mean_teacher = mean_teacher

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # The model name comes from the wrapped module when the model is a
        # (Distributed)DataParallel-style wrapper.
        inner = self.model.module if hasattr(self.model, 'module') else self.model
        self._model_name = inner.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.logger = self.init_logger(work_dir, log_level) if logger is None else logger
        self.log_buffer = LogBuffer()

        # Bookkeeping for the training loop.
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0
Code example #7
0
File: utils.py  Project: zhangrj91/FNA
def get_root_logger(log_dir=None, log_level=logging.INFO):
    """Return the root logger, configured once per process.

    Non-zero ranks are silenced to ERROR; rank 0 additionally logs to a
    timestamped file under ``log_dir`` when one is given.
    """
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(format='%(asctime)s - %(message)s',
                            level=log_level,
                            datefmt='%m/%d %I:%M:%S %p')
    rank, _ = get_dist_info()
    if rank != 0:
        # Keep worker output quiet; only rank 0 reports at full level.
        logger.setLevel('ERROR')
    elif log_dir:
        stamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        log_file = osp.join(log_dir, '{}.log'.format(stamp))
        _add_file_handler(logger, log_file, level=log_level)
    return logger
Code example #8
0
File: utils.py  Project: zhangrj91/FNA
def init_logger(log_dir=None, level=logging.INFO):
    """Init the logger.

    Args:
        log_dir (str, optional): Log file directory. If not specified, no
            log file will be used.
        level (int or str): See the built-in python logging module.

    Returns:
        :obj:`~logging.Logger`: Python logger.
    """
    rank, _ = get_dist_info()
    logging.basicConfig(format='%(asctime)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)
    # Only rank 0 writes a log file, to avoid clobbering between workers.
    if rank == 0 and log_dir:
        stamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        _add_file_handler(logger, osp.join(log_dir, '{}.log'.format(stamp)),
                          level=level)
    return logger
    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 repeat_t=0.001,
                 num_replicas=None,
                 rank=None):
        """Group sampler with per-image repeat factors.

        Images containing infrequent classes are repeated so that every
        class reaches a minimum image frequency of ``repeat_t``
        (presumably repeat-factor sampling in the style of the LVIS
        paper — TODO confirm against the project docs).

        Args:
            dataset: Dataset exposing ``flag`` (per-image group id) and
                ``img_infos`` (list of image-info dicts with an 'id').
            samples_per_gpu (int): Batch size per GPU.
            repeat_t (float): Frequency threshold in the
                ``max(1, sqrt(repeat_t / f))`` repeat-factor formula.
            num_replicas (int, optional): World size; taken from the
                distributed environment when None.
            rank (int, optional): Process rank; taken from the
                distributed environment when None.
        """
        _rank, _num_replicas = get_dist_info()
        # _rank, _num_replicas = 0, 8
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        # NOTE(review): not referenced anywhere in this constructor;
        # purpose unclear from here — presumably consumed by __iter__.
        self.num_to_sample_out = [6000, 17000]
        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        # NOTE(review): the file handle is never closed, and the pickle
        # schema (keys 'img_id', 'category_id', 'image_info_id',
        # 'isntance_count') is assumed — verify against the producer of
        # this file.
        self.dataset_class_image_info = pickle.load(
            open('./data/lvis/class_to_imageid_and_inscount.pt', 'rb'))
        # Map dataset image ids to their index within dataset.img_infos.
        img_id_to_new_img_info_id = {
            info['id']: idx
            for idx, info in enumerate(self.dataset.img_infos)
        }
        # Re-key each class entry from raw image ids to img_infos indices.
        for idx, (k, v) in enumerate(self.dataset_class_image_info.items()):
            v_new = v.copy()
            v_new['image_info_id'] = [
                img_id_to_new_img_info_id[id] for id in v_new['img_id']
            ]
            self.dataset_class_image_info[k] = v_new
        # self.dataset_abundant_class_image_info = [self.dataset_class_image_info[cls_idx]
        #                                           for cls_idx in range(len(self.dataset_class_image_info))
        #                                           if self.dataset_class_image_info[cls_idx]['isntance_count'] > 1000]
        # 'isntance_count' (sic) matches the key as stored in the pickle;
        # with the > 0 threshold this effectively keeps every class that
        # has at least one instance.
        self.dataset_abundant_class_image_info = [
            self.dataset_class_image_info[cls_idx]
            for cls_idx in range(len(self.dataset_class_image_info))
            if self.dataset_class_image_info[cls_idx]['isntance_count'] > 0
        ]
        self.dataset_abundant_class_ids = [
            item['category_id']
            for item in self.dataset_abundant_class_image_info
        ]

        ## calculate repeating num. per image
        repeat_per_img = {}
        total_img_num = len(self.dataset.img_infos)
        clses_to_repeat = []
        # NOTE(review): class ids are assumed to span 1..1230 (LVIS v0.5
        # category count) — confirm this matches the dataset version.
        for cls in range(1, 1231):
            # fc: fraction of images containing this class.
            fc = len(self.dataset_class_image_info[cls]
                     ['image_info_id']) / float(total_img_num)
            # Repeat factor: max(1, sqrt(repeat_t / fc)); > 1 only for
            # classes rarer than the threshold.
            repeat_this_cls = max(1., np.sqrt(repeat_t / fc))
            if repeat_this_cls > 1:
                clses_to_repeat.append(cls)
            # Each image keeps the largest repeat factor over the classes
            # it contains.
            for img_info_id in self.dataset_class_image_info[cls][
                    'image_info_id']:
                if img_info_id not in repeat_per_img:
                    repeat_per_img[img_info_id] = repeat_this_cls
                else:
                    if repeat_per_img[img_info_id] < repeat_this_cls:
                        repeat_per_img[img_info_id] = repeat_this_cls
                    else:
                        pass
        repeat_per_img = {
            k: math.ceil(v)
            for i, (k, v) in enumerate(repeat_per_img.items())
        }  ## ceiling
        # Every image must have appeared under at least one class.
        assert len(repeat_per_img.keys()) == total_img_num
        img_info_ids_to_repeat = {
            k: v
            for i, (k, v) in enumerate(repeat_per_img.items()) if v > 1
        }  ## repeat larget than 1 imgs
        self.img_info_ids_to_repeat = img_info_ids_to_repeat

        ## calculate new group size infomation
        # Inflate each aspect-ratio group's size by the extra repeats of
        # the images belonging to it.
        self.group_sizes_new = self.group_sizes.copy()
        for i, size in enumerate(self.group_sizes):
            indice = np.where(self.flag == i)[0]
            assert len(indice) == size

            # NOTE(review): `img_info_id in indice` is a linear scan over
            # a numpy array per repeated image — O(groups * repeats * size).
            for idx, (img_info_id, re_count) in enumerate(
                    self.img_info_ids_to_repeat.items()):
                if img_info_id in indice:
                    self.group_sizes_new[i] += re_count

        # Pad each (inflated) group up to a whole number of per-GPU
        # batches per replica, as in the plain group sampler.
        self.num_samples_new = 0
        for i, j in enumerate(self.group_sizes_new):
            self.num_samples_new += int(
                math.ceil(self.group_sizes_new[i] * 1.0 / self.samples_per_gpu
                          / self.num_replicas)) * self.samples_per_gpu
        self.total_size = self.num_samples_new * self.num_replicas
Code example #10
0
def load_state_dict(module, state_dict, strict=False, logger=None, force_matching=False,
                    show_converted=False, ignores=None):
    """Copy parameters from ``state_dict`` into ``module``.

    Unlike :meth:`torch.nn.Module.load_state_dict`, shape mismatches can be
    tolerated: with ``force_matching`` the leading slice of a larger source
    tensor is copied into a smaller target tensor.

    Args:
        module (torch.nn.Module): Destination module.
        state_dict (dict): Source mapping of parameter names to tensors.
        strict (bool): Raise ``RuntimeError`` on unexpected/missing keys or
            shape mismatches (reporting happens on rank 0 only).
        logger (logging.Logger, optional): Sink for warnings; ``print`` is
            used when absent.
        force_matching (bool): Allow copying when ranks match and every
            source dimension is >= the target dimension.
        show_converted (bool): Also report keys that were copied cleanly.
        ignores (str or tuple, optional): Key suffix(es) whose values are
            not copied even when shapes match.
    """
    rank, _ = get_dist_info()

    unexpected_keys = []        # present in state_dict but not in module
    converted_pairs = []        # copied cleanly (reported if requested)
    shape_mismatch_pairs = []   # incompatible shapes, not copied
    shape_casted_pairs = []     # copied via leading-slice cast

    own_state = module.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            unexpected_keys.append(name)
            continue

        if isinstance(param, torch.nn.Parameter):
            # Unwrap serialized Parameters to plain tensors.
            param = param.data

        src_shape = param.size()
        trg_shape = own_state[name].size()
        if src_shape != trg_shape:
            # A cast is only valid when ranks match and the source is at
            # least as large in every dimension.
            is_valid = False
            if force_matching:
                is_valid = len(src_shape) == len(trg_shape)
                for i in range(len(src_shape)):
                    is_valid &= src_shape[i] >= trg_shape[i]

            if is_valid:
                # bugfix: advanced indexing expects a *tuple* of slices;
                # indexing with a Python list is deprecated and rejected
                # by recent PyTorch releases.
                ind = tuple(slice(0, d) for d in trg_shape)
                own_state[name].copy_(param[ind])

                shape_casted_pairs.append([name, list(own_state[name].size()), list(param.size())])
            else:
                shape_mismatch_pairs.append([name, list(own_state[name].size()), list(param.size())])
        elif ignores is None or not name.endswith(ignores):
            own_state[name].copy_(param)

            if show_converted:
                converted_pairs.append([name, list(own_state[name].size())])

    all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
    # num_batches_tracked buffers are BatchNorm bookkeeping — not worth a
    # warning when absent from the source.
    missing_keys = [key for key in all_missing_keys if 'num_batches_tracked' not in key]

    err_msg = []
    if unexpected_keys:
        err_msg.append('unexpected key in source state_dict: {}\n'.format(', '.join(unexpected_keys)))
    if missing_keys:
        err_msg.append('missing keys in source state_dict: {}\n'.format(', '.join(missing_keys)))

    if shape_mismatch_pairs:
        casted_info = 'these keys have mismatched shape:\n'
        header = ['key', 'expected shape', 'loaded shape']
        table_data = [header] + shape_mismatch_pairs
        table = AsciiTable(table_data)
        err_msg.append(casted_info + table.table)

    # All reporting happens on rank 0 only, so non-zero ranks never raise
    # even under strict=True.
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)

    ok_message = []
    if converted_pairs:
        converted_info = 'These keys have been converted correctly:\n'
        header = ['key', 'shape']
        table_data = [header] + converted_pairs
        table = AsciiTable(table_data)
        ok_message.append(converted_info + table.table)

    if len(ok_message) > 0 and rank == 0:
        warning_msg = '\n'.join(ok_message)
        if logger is not None:
            logger.warning(warning_msg)
        else:
            print(warning_msg)

    warning_msg = []
    if shape_casted_pairs:
        casted_info = 'these keys have been shape casted:\n'
        header = ['key', 'expected shape', 'loaded shape']
        table_data = [header] + shape_casted_pairs
        table = AsciiTable(table_data)
        warning_msg.append(casted_info + table.table)

    if len(warning_msg) > 0 and rank == 0:
        warning_msg.insert(0, 'The model and loaded state dict do not match exactly\n')
        warning_msg = '\n'.join(warning_msg)
        if logger is not None:
            logger.warning(warning_msg)
        else:
            print(warning_msg)