コード例 #1
0
    def __list_dirs(self, root_dir, dataset):
        """Collect index-aligned image paths from the ``imageA``/``imageB`` folders.

        Every file name listed in ``<root_dir>/<dataset>/imageA`` is stripped of
        its last extension and resolved through ``ImageHelper.imgpath`` against
        both folders. A pair is kept only when both resolved paths exist on
        disk; otherwise a warning naming the ``imageA`` path is logged.

        When ``dataset`` is ``'train'`` and the ``('data', 'include_val')``
        config entry is truthy, the pairs found under ``<root_dir>/val`` are
        appended after the training pairs.

        Args:
            root_dir: Dataset root directory.
            dataset: Name of the split sub-directory to scan.

        Returns:
            Tuple ``(imgA_list, imgB_list)`` of equally long path lists.
        """
        paths_a, paths_b = [], []

        def gather(dir_a, dir_b):
            # Scan one split; the listing of dir_a drives the pairing.
            for entry in os.listdir(dir_a):
                # Everything before the last dot ('' when there is no dot).
                stem = entry.rpartition('.')[0]
                path_a = ImageHelper.imgpath(dir_a, stem)
                path_b = ImageHelper.imgpath(dir_b, stem)
                if os.path.exists(path_a) and os.path.exists(path_b):
                    paths_a.append(path_a)
                    paths_b.append(path_b)
                else:
                    Log.warn('Img Path: {} not exists.'.format(path_a))

        gather(os.path.join(root_dir, dataset, 'imageA'),
               os.path.join(root_dir, dataset, 'imageB'))

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            gather(os.path.join(root_dir, 'val/imageA'),
                   os.path.join(root_dir, 'val/imageB'))

        return paths_a, paths_b
コード例 #2
0
ファイル: fasterrcnn_loader.py プロジェクト: yobcmst/torchcv
    def __list_dirs(self, root_dir, dataset):
        """Collect index-aligned image and JSON annotation paths for a split.

        The ``json`` folder drives the scan: each listed file name is stripped
        of its last extension and the matching image is looked up in the
        ``image`` folder through ``ImageHelper.imgpath``. An entry is skipped
        (with a warning naming the JSON path) when the JSON file does not exist
        or the image lookup returns ``None``.

        When ``dataset`` is ``'train'`` and the ``('data', 'include_val')``
        config entry is truthy, entries from ``<root_dir>/val`` are appended.

        Args:
            root_dir: Dataset root directory.
            dataset: Name of the split sub-directory to scan.

        Returns:
            Tuple ``(img_list, json_list)`` of equally long path lists.
        """
        images, annotations = [], []

        def gather(image_dir, json_dir):
            # Scan one split; the listing of json_dir drives the pairing.
            for entry in os.listdir(json_dir):
                # Everything before the last dot ('' when there is no dot).
                stem = entry.rpartition('.')[0]
                found_image = ImageHelper.imgpath(image_dir, stem)
                annotation = os.path.join(json_dir, entry)
                if os.path.exists(annotation) and found_image is not None:
                    annotations.append(annotation)
                    images.append(found_image)
                else:
                    Log.warn('Json Path: {} not exists.'.format(annotation))

        gather(os.path.join(root_dir, dataset, 'image'),
               os.path.join(root_dir, dataset, 'json'))

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            gather(os.path.join(root_dir, 'val/image'),
                   os.path.join(root_dir, 'val/json'))

        return images, annotations
コード例 #3
0
    def load_model(model, pretrained=None, all_match=True):
        """Load pretrained weights from a checkpoint file into ``model``.

        Args:
            model: Object exposing ``state_dict()`` and ``load_state_dict()``.
            pretrained: Path of the checkpoint file. ``None`` or a path that
                does not exist leaves ``model`` untouched (the latter logs a
                warning).
            all_match: When truthy, every checkpoint entry is passed to
                ``model.load_state_dict``; a key ``k`` is renamed to
                ``'prefix.k'`` if that prefixed name exists in the model.
                When falsy, only checkpoint keys that already exist in the
                model are merged into the model's current state.

        Returns:
            The same ``model`` object that was passed in.
        """
        if pretrained is None:
            return model

        if not os.path.exists(pretrained):
            Log.warn('{} not exists.'.format(pretrained))
            return model

        Log.info('Loading pretrained model:{}'.format(pretrained))
        checkpoint = torch.load(pretrained)
        current_state = model.state_dict()

        if not all_match:
            # Partial load: keep the model's own values for keys absent
            # from the checkpoint.
            matched = {
                key: tensor
                for key, tensor in checkpoint.items() if key in current_state
            }
            Log.info('Matched Keys: {}'.format(matched.keys()))
            current_state.update(matched)
            model.load_state_dict(current_state)
            return model

        remapped = dict()
        for key, tensor in checkpoint.items():
            prefixed = 'prefix.{}'.format(key)
            target = prefixed if prefixed in current_state else key
            remapped[target] = tensor

        model.load_state_dict(remapped)
        return model
コード例 #4
0
    def __read_json_file(self, root_dir, dataset):
        """Read image paths and labels from a split's ``label.json`` file.

        ``<root_dir>/<dataset>/label.json`` is parsed as JSON and iterated;
        each item supplies an ``'image_path'`` (joined onto the split
        directory) and a ``'label'``. Items whose image file does not exist
        are skipped with a warning.

        When ``dataset`` is ``'train'`` and the ``('data', 'include_val')``
        config entry is truthy, the items of ``<root_dir>/val/label.json`` are
        appended after the training items.

        Args:
            root_dir: Dataset root directory.
            dataset: Name of the split sub-directory to read.

        Returns:
            Tuple ``(img_list, label_list)`` of equally long lists.
        """
        images, labels = [], []

        def ingest(split):
            # Append every item of one split whose image file is present.
            with open(os.path.join(root_dir, split, 'label.json'), 'r') as stream:
                for record in json.load(stream):
                    candidate = os.path.join(root_dir, split, record['image_path'])
                    if os.path.exists(candidate):
                        images.append(candidate)
                        labels.append(record['label'])
                    else:
                        Log.warn('Image Path: {} not exists.'.format(candidate))

        ingest(dataset)

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            ingest('val')

        return images, labels
コード例 #5
0
ファイル: runner_helper.py プロジェクト: zy0851/TorchCV
    def load_state_dict(module, state_dict, strict=False):
        """Load state_dict to a module.

        This method is modified from :meth:`torch.nn.Module.load_state_dict`.
        Default value for ``strict`` is set to ``False`` and the message for
        param mismatch will be shown even if strict is False.

        If the first key of ``state_dict`` starts with ``'module.'``, that
        prefix is removed from every key that carries it before matching.

        Args:
            module (Module): Module that receives the state_dict.
            state_dict (OrderedDict): Weights.
            strict (bool): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.

        Raises:
            RuntimeError: Only when ``strict`` is true: if copying a tensor
                fails, or if there are unexpected or missing keys.
        """
        prefix = 'module.'
        # next(..., '') keeps an empty state_dict from raising IndexError.
        if next(iter(state_dict), '').startswith(prefix):
            # Strip the prefix only where present so that keys without it
            # are not truncated.
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

        unexpected_keys = []
        unmatched_keys = []
        own_state = module.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                unexpected_keys.append(name)
                continue
            if isinstance(param, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            try:
                own_state[name].copy_(param)
            except Exception as err:
                if strict:
                    raise RuntimeError(
                        'While copying the parameter named {}, '
                        'whose dimensions in the model are {} and '
                        'whose dimensions in the checkpoint are {}.'.format(
                            name, own_state[name].size(),
                            param.size())) from err
                unmatched_keys.append(name)

        # Sorted so the report is deterministic across runs.
        missing_keys = sorted(set(own_state.keys()) - set(state_dict.keys()))

        err_msg = []
        if unexpected_keys:
            err_msg.append('unexpected key in source state_dict: {}\n'.format(
                ', '.join(unexpected_keys)))
        if missing_keys:
            err_msg.append('missing keys in source state_dict: {}\n'.format(
                ', '.join(missing_keys)))
        # Report copy failures based on unmatched_keys itself (previously this
        # was gated on unexpected_keys, hiding shape mismatches).
        if unmatched_keys:
            err_msg.append('unmatched keys in source state_dict: {}\n'.format(
                ', '.join(unmatched_keys)))
        err_msg = '\n'.join(err_msg)
        if err_msg:
            if strict:
                raise RuntimeError(err_msg)
            else:
                Log.warn(err_msg)
コード例 #6
0
    def __list_dirs(self, root_dir, dataset):
        """Collect index-aligned image, JSON and mask paths for a split.

        The ``json`` folder drives the scan. For each listed file the image
        path is built from the file's stem plus the extension of the first
        entry returned by ``os.listdir`` on ``<root_dir>/<dataset>/image``, and
        the mask path is ``<stem>.png`` in the ``mask`` folder. An entry is
        kept only when both the JSON file and the image file exist; the mask
        path is not checked for existence.

        When ``dataset`` is ``'train'`` and the ``('data', 'include_val')``
        config entry is truthy, entries from ``<root_dir>/val`` are appended,
        using the same image extension.

        Args:
            root_dir: Dataset root directory.
            dataset: Name of the split sub-directory to scan.

        Returns:
            Tuple ``(img_list, json_list, mask_list)`` of equally long lists.

        Raises:
            IndexError: If the split's ``image`` folder is empty.
        """
        images, annotations, masks = [], [], []

        primary_image_dir = os.path.join(root_dir, dataset, 'image')
        # The first listed image decides the extension used for every lookup.
        extension = os.listdir(primary_image_dir)[0].split('.')[-1]

        def gather(image_dir, json_dir, mask_dir):
            # Scan one split; the listing of json_dir drives the pairing.
            for entry in os.listdir(json_dir):
                # Everything before the last dot ('' when there is no dot).
                stem = entry.rpartition('.')[0]
                image = os.path.join(image_dir, '{}.{}'.format(stem, extension))
                mask = os.path.join(mask_dir, '{}.png'.format(stem))
                annotation = os.path.join(json_dir, entry)
                if os.path.exists(annotation) and os.path.exists(image):
                    annotations.append(annotation)
                    masks.append(mask)
                    images.append(image)
                else:
                    Log.warn('Json Path: {} not exists.'.format(annotation))

        gather(primary_image_dir,
               os.path.join(root_dir, dataset, 'json'),
               os.path.join(root_dir, dataset, 'mask'))

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            gather(os.path.join(root_dir, 'val/image'),
                   os.path.join(root_dir, 'val/json'),
                   os.path.join(root_dir, 'val/mask'))

        return images, annotations, masks
コード例 #7
0
ファイル: default_loader.py プロジェクト: shmm91/torchcv
    def __list_dirs(self, root_dir, dataset):
        """Collect index-aligned image and label file paths for a split.

        The ``label`` folder drives the scan. For each listed file the image
        path is built from the file's stem plus the extension of the first
        entry returned by ``os.listdir`` on ``<root_dir>/<dataset>/image``. An
        entry is kept only when both the label file and the image file exist;
        otherwise a warning naming the label path is logged.

        When ``dataset`` is ``'train'`` and the ``('data', 'include_val')``
        config entry is truthy, entries from ``<root_dir>/val`` are appended,
        using the same image extension.

        Args:
            root_dir: Dataset root directory.
            dataset: Name of the split sub-directory to scan.

        Returns:
            Tuple ``(img_list, label_list)`` of equally long path lists.

        Raises:
            IndexError: If the split's ``image`` folder is empty.
        """
        images, labels = [], []

        primary_image_dir = os.path.join(root_dir, dataset, 'image')
        primary_label_dir = os.path.join(root_dir, dataset, 'label')
        # The first listed image decides the extension used for every lookup.
        extension = os.listdir(primary_image_dir)[0].split('.')[-1]

        def gather(image_dir, label_dir):
            # Scan one split; the listing of label_dir drives the pairing.
            for entry in os.listdir(label_dir):
                # Everything before the last dot ('' when there is no dot).
                stem = entry.rpartition('.')[0]
                image = os.path.join(image_dir, '{}.{}'.format(stem, extension))
                label = os.path.join(label_dir, entry)
                if os.path.exists(label) and os.path.exists(image):
                    images.append(image)
                    labels.append(label)
                else:
                    Log.warn('Label Path: {} not exists.'.format(label))

        gather(primary_image_dir, primary_label_dir)

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            gather(os.path.join(root_dir, 'val/image'),
                   os.path.join(root_dir, 'val/label'))

        return images, labels