コード例 #1
0
    def __read_json_file(self, root_dir, dataset):
        """Collect image paths and labels listed in a split's ``label.json``.

        Reads ``<root_dir>/<dataset>/label.json``. When ``dataset`` is
        ``'train'`` and the ``data.include_val`` config entry is truthy, the
        entries of ``<root_dir>/val/label.json`` are appended as well. Entries
        whose image file does not exist on disk are skipped with a warning.

        Args:
            root_dir: Directory containing the split sub-directories.
            dataset: Name of the split sub-directory to read.

        Returns:
            A tuple ``(img_list, label_list)`` of parallel lists: image paths
            and the matching ``item['label']`` values.
        """
        image_paths = []
        labels = []

        def _collect(split):
            # Load one split's label file and keep the entries whose image
            # is actually present.
            with open(os.path.join(root_dir, split, 'label.json'), 'r') as stream:
                entries = json.load(stream)

            for entry in entries:
                candidate = os.path.join(root_dir, split, entry['image_path'])
                if os.path.exists(candidate):
                    image_paths.append(candidate)
                    labels.append(entry['label'])
                else:
                    Log.warn('Image Path: {} not exists.'.format(candidate))

        _collect(dataset)
        if dataset == 'train' and self.configer.get('data', 'include_val'):
            _collect('val')

        return image_paths, labels
コード例 #2
0
    def __list_dirs(self, root_dir, dataset):
        """Pair every label file of a split with its image.

        Label files are enumerated from ``<root_dir>/<dataset>/label``; the
        image for each one is looked up in ``<root_dir>/<dataset>/image`` via
        ``ImageHelper.imgpath`` using the label file name without its last
        extension. Pairs for which the lookup returns ``None`` (or the label
        path does not exist) are skipped with a warning. When ``dataset`` is
        ``'train'`` and the ``data.include_val`` config entry is truthy, the
        ``val`` split is appended too.

        Args:
            root_dir: Directory containing the split sub-directories.
            dataset: Name of the split sub-directory to scan.

        Returns:
            A tuple ``(img_list, label_list)`` of parallel path lists.
        """
        image_paths = []
        label_paths = []

        def _scan(image_dir, label_dir):
            # Walk one label directory and record every usable pair.
            for label_file in os.listdir(label_dir):
                stem = label_file.rpartition('.')[0]
                label_path = os.path.join(label_dir, label_file)
                img_path = ImageHelper.imgpath(image_dir, stem)
                if os.path.exists(label_path) and img_path is not None:
                    image_paths.append(img_path)
                    label_paths.append(label_path)
                else:
                    Log.warn('Label Path: {} not exists.'.format(label_path))

        _scan(os.path.join(root_dir, dataset, 'image'),
              os.path.join(root_dir, dataset, 'label'))

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            _scan(os.path.join(root_dir, 'val/image'),
                  os.path.join(root_dir, 'val/label'))

        return image_paths, label_paths
コード例 #3
0
    def __list_dirs(self, root_dir, dataset):
        """Pair every file in ``imageA`` with its counterpart in ``imageB``.

        File names are enumerated from ``<root_dir>/<dataset>/imageA``; both
        sides are resolved through ``ImageHelper.imgpath`` using the file name
        without its last extension. A pair is kept only when both resolved
        paths exist. When ``dataset`` is ``'train'`` and the
        ``data.include_val`` config entry is truthy, the ``val`` split is
        appended too.

        Args:
            root_dir: Directory containing the split sub-directories.
            dataset: Name of the split sub-directory to scan.

        Returns:
            A tuple ``(imgA_list, imgB_list)`` of parallel path lists.
        """
        imgA_list = list()
        imgB_list = list()

        def _is_present(path):
            # Other loaders in this code base guard against
            # ``ImageHelper.imgpath`` returning None; ``os.path.exists(None)``
            # raises TypeError, so treat None as "missing" here as well.
            return path is not None and os.path.exists(path)

        def _scan(imageA_dir, imageB_dir):
            # Walk one imageA directory and record every complete A/B pair.
            for file_name in os.listdir(imageA_dir):
                image_name = '.'.join(file_name.split('.')[:-1])
                imgA_path = ImageHelper.imgpath(imageA_dir, image_name)
                imgB_path = ImageHelper.imgpath(imageB_dir, image_name)
                if not _is_present(imgA_path) or not _is_present(imgB_path):
                    Log.warn('Img Path: {} not exists.'.format(imgA_path))
                    continue

                imgA_list.append(imgA_path)
                imgB_list.append(imgB_path)

        _scan(os.path.join(root_dir, dataset, 'imageA'),
              os.path.join(root_dir, dataset, 'imageB'))

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            _scan(os.path.join(root_dir, 'val/imageA'),
                  os.path.join(root_dir, 'val/imageB'))

        return imgA_list, imgB_list
コード例 #4
0
    def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
        """Load pretrained weights from a checkpoint file into ``model``.

        Args:
            model: Object exposing ``state_dict()`` and ``load_state_dict()``.
            pretrained: Path of the checkpoint file. ``None`` or a
                non-existent path leaves ``model`` untouched.
            all_match: When true, every checkpoint entry is handed to
                ``model.load_state_dict`` (a key is renamed to
                ``'prefix.<key>'`` if the model has that name). When false,
                only checkpoint entries whose key exists in the model are
                loaded and the remaining model weights are kept.
            map_location: Forwarded to ``torch.load`` in both modes.

        Returns:
            The same ``model`` object.
        """
        if pretrained is None:
            return model

        if not os.path.exists(pretrained):
            Log.warn('{} not exists.'.format(pretrained))
            return model

        Log.info('Loading pretrained model:{}'.format(pretrained))
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        # map_location is applied in both modes so that a checkpoint saved on
        # a GPU can also be partially loaded on a host without that device.
        pretrained_dict = torch.load(pretrained, map_location=map_location)
        model_dict = model.state_dict()

        if all_match:
            load_dict = dict()
            for k, v in pretrained_dict.items():
                prefixed_key = 'prefix.{}'.format(k)
                if prefixed_key in model_dict:
                    load_dict[prefixed_key] = v
                else:
                    load_dict[k] = v

            model.load_state_dict(load_dict)

        else:
            load_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            Log.info('Matched Keys: {}'.format(load_dict.keys()))
            # Overlay the matched entries on the model's current weights so
            # that unmatched parameters keep their existing values.
            model_dict.update(load_dict)
            model.load_state_dict(model_dict)

        return model
コード例 #5
0
    def __init__(self, args_parser=None, config_file=None, config_dict=None, valid_flag=None):
        """Build the configuration tree and overlay command-line values.

        The base configuration is taken from the first available source:
        ``config_dict``, then ``config_file``, then ``args_parser.config_file``.
        If none is given, a warning is logged and an empty configuration is
        used. Afterwards every attribute of ``args_parser`` is merged in: keys
        missing from the configuration are added, existing keys are updated
        when the parsed value is not ``None``.

        Args:
            args_parser: Parsed arguments object (its ``__dict__`` is
                iterated); may be ``None``.
            config_file: Path of a configuration file; the process exits with
                status 1 if it does not exist.
            config_dict: Configuration given as a dict; must not be combined
                with ``config_file``.
            valid_flag: When set, only argument keys whose part before the
                first ``'.'`` equals this value are merged.
        """
        self.params_root = None
        if config_dict is not None:
            assert config_file is None
            self.params_root = ConfigFactory.from_dict(config_dict)

        elif config_file is not None:
            if not os.path.exists(config_file):
                Log.error('Json Path:{} not exists!'.format(config_file))
                exit(1)

            self.params_root = ConfigFactory.parse_file(config_file)

        # args_parser defaults to None; the membership test would raise
        # TypeError on None and make the fallback branch below unreachable.
        elif (args_parser is not None and 'config_file' in args_parser
              and args_parser.config_file is not None):
            if not os.path.exists(args_parser.config_file):
                Log.error('Json Path:{} not exists!'.format(args_parser.config_file))
                exit(1)

            self.params_root = ConfigFactory.parse_file(args_parser.config_file)

        else:
            Log.warn('Base settings not set!')
            self.params_root = ConfigFactory.from_dict({})

        if args_parser is not None:
            for key, value in args_parser.__dict__.items():
                if valid_flag is not None and key.split('.')[0] != valid_flag:
                    continue

                if key not in self.params_root:
                    self.add(key, value)
                elif value is not None:
                    # A None value means the option was not supplied; keep
                    # whatever the base configuration already holds.
                    self.update(key, value)
コード例 #6
0
ファイル: runner_helper.py プロジェクト: zouwen198317/torchcv
    def load_state_dict(module, state_dict, strict=False):
        """Load state_dict to a module.

        This method is modified from :meth:`torch.nn.Module.load_state_dict`.
        Default value for ``strict`` is set to ``False`` and the message for
        param mismatch will be shown even if strict is False.

        Args:
            module (Module): Module that receives the state_dict.
            state_dict (OrderedDict): Weights. If the first key starts with
                ``'module.'``, that prefix is removed from every key that
                carries it.
            strict (bool): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.

        Raises:
            RuntimeError: If ``strict`` is true and a parameter cannot be
                copied, or unexpected/missing keys are found.
        """
        prefix = 'module.'
        # Guard against an empty checkpoint (there is no first key to look
        # at) and strip the prefix only from keys that really have it, so
        # unprefixed keys are not truncated.
        if state_dict and next(iter(state_dict)).startswith(prefix):
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

        unexpected_keys = []
        unmatched_keys = []
        own_state = module.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                unexpected_keys.append(name)
                continue
            if isinstance(param, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            try:
                own_state[name].copy_(param)
            except Exception:
                if strict:
                    raise RuntimeError(
                        'While copying the parameter named {}, '
                        'whose dimensions in the model are {} and '
                        'whose dimensions in the checkpoint are {}.'.format(
                            name, own_state[name].size(), param.size()))
                else:
                    unmatched_keys.append(name)

        # Sorted so the report is reproducible across runs.
        missing_keys = sorted(set(own_state.keys()) - set(state_dict.keys()))

        err_msg = []
        if unexpected_keys:
            err_msg.append('unexpected key in source state_dict: {}\n'.format(
                ', '.join(unexpected_keys)))
        if missing_keys:
            err_msg.append('missing keys in source state_dict: {}\n'.format(
                ', '.join(missing_keys)))
        # Report copy failures based on the list that actually records them.
        if unmatched_keys:
            err_msg.append('unmatched keys in source state_dict: {}\n'.format(
                ', '.join(unmatched_keys)))
        err_msg = '\n'.join(err_msg)
        if err_msg:
            if strict:
                raise RuntimeError(err_msg)
            else:
                Log.warn(err_msg)
コード例 #7
0
ファイル: openpose_loader.py プロジェクト: wxwoods/torchcv
    def __list_dirs(self, root_dir, dataset):
        """Gather image, json and mask paths for a split.

        Json files are enumerated from ``<root_dir>/<dataset>/json``. For each
        one the image is looked up in ``<root_dir>/<dataset>/image`` via
        ``ImageHelper.imgpath`` and the mask path is built as
        ``<root_dir>/<dataset>/mask/<name>.png``, where ``<name>`` is the json
        file name without its last extension. Entries for which the image
        lookup returns ``None`` (or the json path does not exist) are skipped
        with a warning; the mask path is not checked for existence. When
        ``dataset`` is ``'train'`` and the ``data.include_val`` config entry
        is truthy, the ``val`` split is appended too.

        Args:
            root_dir: Directory containing the split sub-directories.
            dataset: Name of the split sub-directory to scan.

        Returns:
            A tuple ``(img_list, json_list, mask_list)`` of parallel path
            lists.
        """
        image_paths = []
        json_paths = []
        mask_paths = []

        def _scan(image_dir, json_dir, mask_dir):
            # Walk one json directory and record every usable triple.
            for json_file in os.listdir(json_dir):
                stem = json_file.rpartition('.')[0]
                mask_path = os.path.join(mask_dir, '{}.png'.format(stem))
                img_path = ImageHelper.imgpath(image_dir, stem)
                json_path = os.path.join(json_dir, json_file)
                if os.path.exists(json_path) and img_path is not None:
                    json_paths.append(json_path)
                    mask_paths.append(mask_path)
                    image_paths.append(img_path)
                else:
                    Log.warn('Json Path: {} not exists.'.format(json_path))

        _scan(os.path.join(root_dir, dataset, 'image'),
              os.path.join(root_dir, dataset, 'json'),
              os.path.join(root_dir, dataset, 'mask'))

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            _scan(os.path.join(root_dir, 'val/image'),
                  os.path.join(root_dir, 'val/json'),
                  os.path.join(root_dir, 'val/mask'))

        return image_paths, json_paths, mask_paths
コード例 #8
0
    def __list_dirs(self, root_dir, dataset):
        """Pair images under ``leftImg8bit`` with label maps under ``gtFine``.

        Image files come from ``FileHelper.list_dir`` applied to
        ``<root_dir>/leftImg8bit/<dataset>``. The label file name is the image
        file name up to its last ``'_'`` followed by
        ``'_gtFine_labelIds.png'``, located in ``<root_dir>/gtFine/<dataset>``.
        Pairs where either path does not exist are skipped with a warning.
        When ``dataset`` is ``'train'`` and the ``data.include_val`` config
        entry is truthy, the ``val`` split is appended too.

        Args:
            root_dir: Directory containing ``leftImg8bit`` and ``gtFine``.
            dataset: Name of the split sub-directory to scan.

        Returns:
            A tuple ``(img_list, label_list)`` of parallel path lists.
        """
        image_paths = []
        label_paths = []

        def _scan(image_dir, label_dir):
            # Walk one image directory and record every complete pair.
            for image_file in FileHelper.list_dir(image_dir):
                prefix = image_file.rpartition('_')[0]
                img_path = os.path.join(image_dir, image_file)
                label_path = os.path.join(
                    label_dir, '{}_gtFine_labelIds.png'.format(prefix))
                if os.path.exists(label_path) and os.path.exists(img_path):
                    image_paths.append(img_path)
                    label_paths.append(label_path)
                else:
                    Log.warn('Image/Label Path: {} not exists.'.format(prefix))

        _scan(os.path.join(root_dir, 'leftImg8bit', dataset),
              os.path.join(root_dir, 'gtFine', dataset))

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            _scan(os.path.join(root_dir, 'leftImg8bit/val'),
                  os.path.join(root_dir, 'gtFine/val'))

        return image_paths, label_paths
コード例 #9
0
    def pypass_imgpath(data_dir):
        """Decide whether per-image existence checks can be skipped.

        The decision is cached in ``ImageHelper.dataset_ext``:

        * ``0``  - not decided yet; ``data_dir`` is listed and inspected.
        * ``-1`` - files with differing extensions were found; never skip.
        * anything else - the single extension shared by all files; skip.

        Args:
            data_dir: Directory, or a path accepted by
                ``ImageHelper.is_zip_path``, whose entries are inspected.

        Returns:
            ``True`` if all listed files share one extension (checks may be
            skipped), ``False`` otherwise.
        """
        if ImageHelper.dataset_ext == -1:
            return False
        if ImageHelper.dataset_ext != 0:
            return True

        if ImageHelper.is_zip_path(data_dir):
            exist_img_list = ZipReader.list_files(data_dir)
        else:
            exist_img_list = os.listdir(data_dir)

        # Nothing to infer an extension from: do not skip the checks, and
        # leave the cached state undecided instead of failing on index 0.
        if len(exist_img_list) == 0:
            return False

        tmp_dataset_ext = '.' + exist_img_list[0].split('.')[-1]
        for exist_img_file in exist_img_list:
            if '.' + exist_img_file.split('.')[-1] != tmp_dataset_ext:
                ImageHelper.dataset_ext = -1
                return False

        ImageHelper.dataset_ext = tmp_dataset_ext
        Log.warn(
            'Pypass img exist check, consistent ext {} in image folder'.
            format(ImageHelper.dataset_ext))
        return True