def __init__(self, args_parser=None, config_file=None, config_dict=None, valid_flag=None):
        """Build the config tree, then overlay values taken from ``args_parser``.

        The base configuration comes from the first source that is provided:
        ``config_dict``, then ``config_file``, then ``args_parser.config_file``.
        If none is provided, a warning is logged and an empty config is used.

        Args:
            args_parser: parsed command-line arguments (an object whose
                ``__dict__`` holds the argument values, e.g. an
                ``argparse.Namespace``) or None.
            config_file: path of a config file for ``ConfigFactory.parse_file``.
                A missing path logs an error and exits.
            config_dict: dict for ``ConfigFactory.from_dict``; must not be
                combined with ``config_file``.
            valid_flag: when given, only argument keys whose first
                dot-separated component equals it are merged into the config.
        """
        self.params_root = None
        if config_dict is not None:
            assert config_file is None
            self.params_root = ConfigFactory.from_dict(config_dict)

        elif config_file is not None:
            if not os.path.exists(config_file):
                Log.error('Json Path:{} not exists!'.format(config_file))
                exit(1)

            self.params_root = ConfigFactory.parse_file(config_file)

        # ``args_parser`` defaults to None and ``'x' in None`` raises TypeError,
        # so check for it explicitly and read the attribute defensively.
        elif args_parser is not None and getattr(args_parser, 'config_file', None) is not None:
            if not os.path.exists(args_parser.config_file):
                Log.error('Json Path:{} not exists!'.format(args_parser.config_file))
                exit(1)

            self.params_root = ConfigFactory.parse_file(args_parser.config_file)

        else:
            Log.warn('Base settings not set!')
            self.params_root = ConfigFactory.from_dict({})

        if args_parser is not None:
            for key, value in args_parser.__dict__.items():
                if valid_flag is not None and key.split('.')[0] != valid_flag:
                    continue

                # Unknown keys are always added (even with a None value);
                # existing keys are only overridden by non-None arguments.
                if key not in self.params_root:
                    self.add(key, value)
                elif value is not None:
                    self.update(key, value)
示例#2
0
    def vis_peaks(self, heatmap_in, ori_img_in, name='default', sub_dir='peaks'):
        """Draw heatmap peaks onto the image and save one file per keypoint.

        Args:
            heatmap_in: numpy array, or a tensor-like object exposing ``size()``
                and ``clone()`` that must be 3-D; it is transposed with
                ``(1, 2, 0)``, presumably from channel-first layout — confirm.
            ori_img_in: numpy image, or a normalized tensor that is
                de-normalized, converted to uint8 and reordered RGB -> BGR.
            name: prefix of the output file names (``<name>_<j>.jpg``).
            sub_dir: sub-directory under ``<project_dir>/POSE_DIR``.
        """
        out_dir = os.path.join(self.configer.get('project_dir'), POSE_DIR, sub_dir)
        if not os.path.exists(out_dir):
            Log.error('Dir:{} not exists!'.format(out_dir))
            os.makedirs(out_dir)

        if isinstance(heatmap_in, np.ndarray):
            heatmap = heatmap_in.copy()
        else:
            if len(heatmap_in.size()) != 3:
                Log.error('Heatmap size is not valid.')
                exit(1)
            heatmap = heatmap_in.clone().data.cpu().numpy().transpose(1, 2, 0)

        if isinstance(ori_img_in, np.ndarray):
            canvas = ori_img_in.copy()
        else:
            denormalize = DeNormalize(div_value=self.configer.get('normalize', 'div_value'),
                                      mean=self.configer.get('normalize', 'mean'),
                                      std=self.configer.get('normalize', 'std'))
            restored = denormalize(ori_img_in.clone())
            canvas = restored.data.cpu().squeeze().numpy().transpose(1, 2, 0).astype(np.uint8)
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)

        num_kpts = self.configer.get('data', 'num_kpts')
        for kpt_idx in range(num_kpts):
            for peak in self.__get_peaks(heatmap[:, :, kpt_idx]):
                canvas = cv2.circle(canvas, (peak[0], peak[1]),
                                    self.configer.get('vis', 'circle_radius'),
                                    self.configer.get('details', 'color_list')[kpt_idx],
                                    thickness=-1)

            # The canvas is never reset, so file j shows peaks of keypoints 0..j.
            cv2.imwrite(os.path.join(out_dir, '{}_{}.jpg'.format(name, kpt_idx)), canvas)
示例#3
0
    def vis_default_bboxes(self,
                           ori_img_in,
                           default_bboxes,
                           labels,
                           name='default',
                           sub_dir='encode'):
        """Draw every default box with a non-zero label and save the image.

        Args:
            ori_img_in: numpy image, or a normalized tensor that is
                de-normalized, converted to uint8 and reordered RGB -> BGR.
            default_bboxes: tensor of boxes as (cx, cy, w, h); coordinates are
                scaled by the image width/height, so presumably normalized to
                [0, 1] — confirm.
            labels: per-box label tensor, same length as ``default_bboxes``.
                Label 0 is skipped; label k is drawn as ``name_seq[k - 1]``.
            name: output file stem (``<name>.jpg``).
            sub_dir: sub-directory under ``<project_dir>/DET_DIR``.
        """
        out_dir = os.path.join(self.configer.get('project_dir'), DET_DIR, sub_dir)

        if not os.path.exists(out_dir):
            # NOTE(review): sibling helpers log through ``Log``; confirm that
            # ``log`` is bound in this module.
            log.error('Dir:{} not exists!'.format(out_dir))
            os.makedirs(out_dir)

        if isinstance(ori_img_in, np.ndarray):
            canvas = ori_img_in.copy()
        else:
            denormalize = DeNormalize(
                div_value=self.configer.get('normalize', 'div_value'),
                mean=self.configer.get('normalize', 'mean'),
                std=self.configer.get('normalize', 'std'))
            restored = denormalize(ori_img_in.clone())
            canvas = restored.data.cpu().squeeze().numpy().transpose(1, 2, 0).astype(np.uint8)
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)

        assert labels.size(0) == default_bboxes.size(0)

        # (cx, cy, w, h) -> (x1, y1, x2, y2)
        centers = default_bboxes[:, :2]
        half_sizes = default_bboxes[:, 2:] / 2
        corners = torch.cat([centers - half_sizes, centers + half_sizes], 1)

        img_h, img_w, _ = canvas.shape
        for idx in range(labels.size(0)):
            if labels[idx] == 0:
                continue

            class_name = self.configer.get('details', 'name_seq')[labels[idx] - 1]
            colors = self.configer.get('details', 'color_list')
            color = colors[(labels[idx] - 1) % len(colors)]

            left = int(corners[idx][0] * img_w)
            top = int(corners[idx][1] * img_h)
            right = int(corners[idx][2] * img_w)
            bottom = int(corners[idx][3] * img_h)

            cv2.rectangle(canvas, (left, top), (right, bottom), color=color, thickness=3)
            cv2.putText(canvas, class_name, (left + 5, bottom - 5),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.5,
                        color=color,
                        thickness=2)

        cv2.imwrite(os.path.join(out_dir, '{}.jpg'.format(name)), canvas)
示例#4
0
    def cv2_read_image(image_path, mode='RGB'):
        """Read an image with OpenCV, or with ``ZipReader`` for zip-archive paths.

        Args:
            image_path: image path; paths accepted by
                ``ImageHelper.is_zip_path`` are read with ``ZipReader.imread``.
            mode: ``'RGB'``, ``'BGR'`` or ``'P'``. Any other value logs an
                error and exits.

        Returns:
            The image as an array. For regular files, ``'RGB'``/``'BGR'`` come
            from ``cv2.imread(..., cv2.IMREAD_COLOR)`` and ``'P'`` from PIL's
            palette conversion.
        """
        # Validate first so an unsupported mode never triggers a file decode.
        if mode not in ('RGB', 'BGR', 'P'):
            Log.error('Not support mode {}'.format(mode))
            exit(1)

        if ImageHelper.is_zip_path(image_path):
            img = ZipReader.imread(image_path, mode)
            # NOTE(review): 'RGB' is requested from ZipReader and the result is
            # still passed through bgr2rgb — confirm ZipReader returns BGR here.
            return ImageHelper.bgr2rgb(img) if mode == 'RGB' else img

        if mode == 'P':
            # Palette images are read through PIL, so no cv2 decode is needed.
            return ImageHelper.to_np(Image.open(image_path).convert('P'))

        img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
        return ImageHelper.bgr2rgb(img_bgr) if mode == 'RGB' else img_bgr
    def get_trainloader(self):
        """Build the training DataLoader selected by ``train.loader``.

        ``None`` or ``'default'`` uses ``DefaultLoader``; ``'fasterrcnn'`` uses
        ``FasterRCNNLoader``. Any other value logs an error and exits. Batches
        are shuffled and collated with ``train.data_transformer``.
        """
        loader_name = self.configer.get('train.loader', default=None)
        if loader_name in [None, 'default']:
            loader_cls = DefaultLoader
        elif loader_name == 'fasterrcnn':
            loader_cls = FasterRCNNLoader
        else:
            Log.error('{} train loader is invalid.'.format(loader_name))
            exit(1)

        train_set = loader_cls(root_dir=self.configer.get('data', 'data_dir'),
                               dataset='train',
                               aug_transform=self.aug_train_transform,
                               img_transform=self.img_transform,
                               configer=self.configer)

        return data.DataLoader(
            train_set,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('train', 'data_transformer')))
示例#6
0
    def vis_bboxes(self,
                   image_in,
                   bboxes_list,
                   name='default',
                   sub_dir='bbox'):
        """
          Show the diff bbox of individuals.

          Draws each box in ``bboxes_list`` (indices 0..3 used as the two corner
          points (x1, y1), (x2, y2)) in green and writes the result under
          ``<project_dir>/DET_DIR/<sub_dir>``. A PIL input is converted to a
          numpy array and reordered RGB -> BGR; a numpy input is copied.
          ``name`` is used as-is if it already looks like an image file name,
          otherwise ``.jpg`` is appended.
        """
        out_dir = os.path.join(self.configer.get('project_dir'), DET_DIR, sub_dir)

        if isinstance(image_in, Image.Image):
            canvas = ImageHelper.rgb2bgr(ImageHelper.to_np(image_in))
        else:
            canvas = image_in.copy()

        if not os.path.exists(out_dir):
            # NOTE(review): sibling helpers log through ``Log``; confirm that
            # ``log`` is bound in this module.
            log.error('Dir:{} not exists!'.format(out_dir))
            os.makedirs(out_dir)

        file_name = name if ImageHelper.is_img(name) else '{}.jpg'.format(name)
        img_path = os.path.join(out_dir, file_name)

        for box in bboxes_list:
            top_left = (box[0], box[1])
            bottom_right = (box[2], box[3])
            canvas = cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 2)

        cv2.imwrite(img_path, canvas)
示例#7
0
    def get_trainloader(self):
        """Build the pose training DataLoader selected by the ``dataset`` key.

        ``'default_cpm'`` uses ``DefaultCPMDataset`` and ``'default_openpose'``
        uses ``DefaultOpenPoseDataset``; anything else (including a missing
        key) logs an error and exits.
        """
        dataset_type = self.configer.get('dataset', default=None)
        if dataset_type == 'default_cpm':
            dataset_cls = DefaultCPMDataset
        elif dataset_type == 'default_openpose':
            dataset_cls = DefaultOpenPoseDataset
        else:
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        train_set = dataset_cls(root_dir=self.configer.get('data', 'data_dir'),
                                dataset='train',
                                aug_transform=self.aug_train_transform,
                                img_transform=self.img_transform,
                                configer=self.configer)

        return data.DataLoader(
            train_set,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('train', 'data_transformer')))
    def get_valloader(self, dataset=None):
        """Build the validation DataLoader for split ``dataset`` (``'val'`` by default).

        Only ``val.loader`` of ``None``/``'default'`` is supported; other values
        log an error and exit. When ``network.distributed`` is truthy a
        ``DistributedSampler`` is attached.
        """
        split = 'val' if dataset is None else dataset
        loader_name = self.configer.get('val.loader', default=None)
        if loader_name not in [None, 'default']:
            Log.error('{} val loader is invalid.'.format(loader_name))
            exit(1)

        val_set = DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                                dataset=split,
                                aug_transform=self.aug_val_transform,
                                img_transform=self.img_transform,
                                configer=self.configer)

        if self.configer.get('network.distributed'):
            sampler = torch.utils.data.distributed.DistributedSampler(val_set)
        else:
            sampler = None

        return data.DataLoader(
            val_set,
            sampler=sampler,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')))
示例#9
0
    def __init__(self, configer):
        """Set up train/val augmentation pipelines and the image transform.

        ``data.image_tool`` selects the augmentation backend (``'pil'`` or
        ``'cv2'``); any other value logs an error and exits. The image
        transform converts to a tensor and normalizes with ``data.normalize``.
        """
        self.configer = configer

        image_tool = self.configer.get('data', 'image_tool')
        if image_tool == 'pil':
            compose_cls = pil_aug_trans.PILAugCompose
        elif image_tool == 'cv2':
            compose_cls = cv2_aug_trans.CV2AugCompose
        else:
            Log.error('Not support {} image tool.'.format(image_tool))
            exit(1)

        self.aug_train_transform = compose_cls(self.configer, split='train')
        self.aug_val_transform = compose_cls(self.configer, split='val')

        normalize_kwargs = self.configer.get('data', 'normalize')
        self.img_transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(**normalize_kwargs),
        ])
示例#10
0
    def get_backbone(self, **params):
        """Build the backbone named by ``network.backbone``.

        The name is matched by substring against the families below, in order.
        The matching factory is constructed with the configer and then called
        with ``**params``. An unmatched name logs an error and exits.

        Args:
            **params: keyword arguments forwarded to the backbone factory call.

        Returns:
            Whatever the selected factory call returns.
        """
        backbone = self.configer.get('network', 'backbone')

        model = None
        if 'vgg' in backbone:
            model = VGGBackbone(self.configer)(**params)

        elif 'darknet' in backbone:
            model = DarkNetBackbone(self.configer)(**params)

        elif 'resnet' in backbone:
            model = ResNetBackbone(self.configer)(**params)

        elif 'mobilenet' in backbone:
            # Forward keyword arguments as keywords like every other branch;
            # ``*params`` would pass only the dict's keys, positionally.
            model = MobileNetBackbone(self.configer)(**params)

        elif 'densenet' in backbone:
            model = DenseNetBackbone(self.configer)(**params)

        elif 'squeezenet' in backbone:
            model = SqueezeNetBackbone(self.configer)(**params)

        elif 'shufflenet' in backbone:
            model = ShuffleNetv2Backbone(self.configer)(**params)

        else:
            Log.error('Backbone {} is invalid.'.format(backbone))
            exit(1)

        return model
示例#11
0
    def get_valloader(self, dataset=None):
        """Build the pose validation DataLoader for split ``dataset`` (default ``'val'``).

        The ``dataset`` config key selects the dataset class:
        ``'default_cpm'`` or ``'default_openpose'``. Anything else, including a
        missing key, logs an error and exits.
        """
        dataset = 'val' if dataset is None else dataset
        dataset_type = self.configer.get('dataset', default=None)
        if dataset_type == 'default_cpm':
            # NOTE(review): the pose train loader uses DefaultCPMDataset for
            # this case; confirm DefaultDataset is intended here.
            dataset = DefaultDataset(root_dir=self.configer.get('data', 'data_dir'),
                                     dataset=dataset,
                                     aug_transform=self.aug_val_transform,
                                     img_transform=self.img_transform,
                                     configer=self.configer)

        elif dataset_type == 'default_openpose':
            # No trailing comma: it would have wrapped the dataset in a 1-tuple.
            dataset = DefaultOpenPoseDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset=dataset,
                aug_transform=self.aug_val_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        else:
            # Use the already-read value so a missing key still yields this
            # message instead of the configer's KeyError exit.
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        valloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'val', 'data_transformer')))
        return valloader
    def parse_dir_pose(self, image_dir, json_dir, mask_dir=None):
        """Show each image in ``image_dir`` with its pose annotation drawn on top.

        For every ``<stem>.<ext>`` in ``image_dir``, ``<json_dir>/<stem>.json``
        is loaded and drawn with ``draw_points``, plus ``link_points`` when
        ``details.limb_seq`` exists. If ``mask_dir`` is given,
        ``<mask_dir>/<stem>_vis.png`` is blended in (0.6 image / 0.4 mask).
        Each result is shown in a window until a key is pressed. Missing
        ``image_dir`` or ``json_dir`` logs an error and returns.
        """
        if image_dir is None or not os.path.exists(image_dir):
            Log.error('Image Dir: {} not existed.'.format(image_dir))
            return

        if json_dir is None or not os.path.exists(json_dir):
            Log.error('Json Dir: {} not existed.'.format(json_dir))
            return

        for image_file in os.listdir(image_dir):
            stem = os.path.splitext(image_file)[0]
            Log.info(image_file)
            canvas = cv2.imread(os.path.join(image_dir, image_file))  # B, G, R order.

            with open(os.path.join(json_dir, '{}.json'.format(stem)), 'r') as json_stream:
                info_tree = json.load(json_stream)

            canvas = self.draw_points(canvas, info_tree)
            if self.configer.exists('details', 'limb_seq'):
                canvas = self.link_points(canvas, info_tree)

            if mask_dir is not None:
                mask_canvas = cv2.imread(os.path.join(mask_dir, '{}_vis.png'.format(stem)))
                canvas = cv2.addWeighted(canvas, 0.6, mask_canvas, 0.4, 0)

            cv2.imshow('main', canvas)
            cv2.waitKey()
示例#13
0
    def get_valloader(self):
        """Build the segmentation validation DataLoader selected by ``dataset``.

        A missing key or ``'default'`` uses ``DefaultDataset``;
        ``'cityscapes'`` uses ``CityscapesDataset``. Other values log an error
        and exit.
        """
        dataset_type = self.configer.get('dataset', default=None)
        if dataset_type in [None, 'default']:
            dataset_cls = DefaultDataset
        elif dataset_type == 'cityscapes':
            dataset_cls = CityscapesDataset
        else:
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        val_set = dataset_cls(root_dir=self.configer.get('data', 'data_dir'),
                              dataset='val',
                              aug_transform=self.aug_val_transform,
                              img_transform=self.img_transform,
                              label_transform=self.label_transform,
                              configer=self.configer)

        return data.DataLoader(
            val_set,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')))
    def get_valloader(self):
        """Build the validation DataLoader, logging each construction step.

        Only ``val.loader`` of ``None``/``'default'`` is supported; other values
        log an error and exit. Memory pinning is disabled for this loader.
        """
        loader_name = self.configer.get('val.loader', default=None)
        if loader_name not in [None, 'default']:
            Log.error('{} val loader is invalid.'.format(loader_name))
            exit(1)

        Log.info('Get val dataloader start')
        val_set = DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                                dataset='val',
                                aug_transform=self.aug_val_transform,
                                img_transform=self.img_transform,
                                label_transform=self.label_transform,
                                configer=self.configer)
        Log.info('Get dataloader')
        val_loader = data.DataLoader(
            val_set,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=False,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')))
        Log.info('Get val dataloader end')
        return val_loader
    def get_trainloader(self):
        """Build the training DataLoader, optionally with a distributed sampler.

        Only ``train.loader`` of ``None``/``'default'`` is supported; other
        values log an error and exit. With ``network.distributed`` truthy a
        ``DistributedSampler`` is used and loader-level shuffling is turned off.
        """
        loader_name = self.configer.get('train.loader', default=None)
        if loader_name not in [None, 'default']:
            Log.error('{} train loader is invalid.'.format(loader_name))
            exit(1)

        train_set = DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                                  dataset='train',
                                  aug_transform=self.aug_train_transform,
                                  img_transform=self.img_transform,
                                  configer=self.configer)

        if self.configer.get('network.distributed'):
            sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        else:
            sampler = None

        return data.DataLoader(
            train_set,
            sampler=sampler,
            batch_size=self.configer.get('train', 'batch_size'),
            # DataLoader forbids shuffle=True together with a sampler.
            shuffle=sampler is None,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('train', 'data_transformer')))
示例#16
0
    def get_valloader(self, dataset=None):
        """Build the validation DataLoader for split ``dataset`` (default ``'val'``).

        ``val.loader`` of ``None``/``'default'`` uses ``DefaultLoader``;
        ``'fasterrcnn'`` uses ``FasterRCNNLoader``. Any other value logs an
        error and exits.
        """
        split = 'val' if dataset is None else dataset
        loader_name = self.configer.get('val.loader', default=None)
        if loader_name in [None, 'default']:
            loader_cls = DefaultLoader
        elif loader_name == 'fasterrcnn':
            loader_cls = FasterRCNNLoader
        else:
            Log.error('{} val loader is invalid.'.format(loader_name))
            exit(1)

        val_set = loader_cls(root_dir=self.configer.get('data', 'data_dir'),
                             dataset=split,
                             aug_transform=self.aug_val_transform,
                             img_transform=self.img_transform,
                             configer=self.configer)

        return data.DataLoader(
            val_set,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')))
示例#17
0
 def read_image(image_path, tool='pil', mode='RGB'):
     """Read an image using the requested backend.

     Args:
         image_path: path of the image to read.
         tool: ``'pil'`` or ``'cv2'``; any other value logs an error and exits.
         mode: colour mode forwarded unchanged to the backend reader.

     Returns:
         Whatever the selected ``ImageHelper`` reader returns.
     """
     if tool == 'pil':
         return ImageHelper.pil_read_image(image_path, mode=mode)
     elif tool == 'cv2':
         return ImageHelper.cv2_read_image(image_path, mode=mode)
     else:
         # It is the tool that was rejected here, so report the tool.
         Log.error('Not support tool {}'.format(tool))
         exit(1)
    def get(self, *key, **kwargs):
        """Return the config value at the dotted path formed from ``key``.

        ``get('a', 'b')`` and ``get('a.b')`` address the same entry. Extra
        keyword arguments are forwarded to ``params_root.get``. If the key is
        absent and no ``default`` keyword was supplied, an error naming the
        caller is logged and the process exits.
        """
        dotted_key = '.'.join(key)
        if dotted_key not in self.params_root and 'default' not in kwargs:
            Log.error('{} KeyError: {}.'.format(self._get_caller(), dotted_key))
            exit(1)

        return self.params_root.get(dotted_key, **kwargs)
示例#19
0
文件: dpn.py 项目: wxwoods/torchcv
def get_densenet(name='dpn_26', num_classes=10):
    """Return a ``DPN26`` or ``DPN92`` model built with ``num_classes`` outputs.

    NOTE(review): despite the function name, only the DPN models are built
    here. Unknown names log an error and exit.
    """
    if name == 'dpn_26':
        return DPN26(num_classes)

    if name == 'dpn_92':
        return DPN92(num_classes)

    Log.error('Model: {} not valid!'.format(name))
    exit(1)
    def get_valloader(self, dataset=None):
        """Build the GAN validation DataLoader for split ``dataset`` (default ``'val'``).

        A missing ``val.loader`` or ``'default'`` uses ``DefaultLoader``;
        ``'cyclegan'`` uses ``CycleGANLoader``; ``'facegan'`` uses
        ``FaceGANLoader`` and additionally passes ``data.tag``. Any other value
        logs an error and exits.
        """
        split = 'val' if dataset is None else dataset

        if not self.configer.exists('val', 'loader'):
            loader_name = 'default'
        else:
            loader_name = self.configer.get('val', 'loader')

        if loader_name == 'default':
            loader_cls = DefaultLoader
        elif loader_name == 'cyclegan':
            loader_cls = CycleGANLoader
        elif loader_name == 'facegan':
            loader_cls = FaceGANLoader
        else:
            Log.error('{} val loader is invalid.'.format(loader_name))
            exit(1)

        loader_kwargs = {
            'root_dir': self.configer.get('data', 'data_dir'),
            'dataset': split,
        }
        if loader_name == 'facegan':
            loader_kwargs['tag'] = self.configer.get('data', 'tag')
        loader_kwargs.update(aug_transform=self.aug_val_transform,
                             img_transform=self.img_transform,
                             configer=self.configer)

        return data.DataLoader(
            loader_cls(**loader_kwargs),
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')))
示例#21
0
    def get_cls_model(self):
        """Instantiate the classification model named by ``network.model_name``.

        The name is looked up in ``CLS_MODEL_DICT`` and the entry is called
        with the configer. Unknown names log an error and exit.
        """
        model_name = self.configer.get('network', 'model_name')
        if model_name in CLS_MODEL_DICT:
            return CLS_MODEL_DICT[model_name](self.configer)

        Log.error('Model: {} not valid!'.format(model_name))
        exit(1)
示例#22
0
    def xml2json(xml_file, json_file):
        """Validate ``xml_file`` and make sure the output directory exists.

        Logs an error and exits when ``xml_file`` does not exist. Creates the
        parent directory of ``json_file`` if it is missing.

        NOTE(review): as shown, no conversion is performed — confirm whether
        the rest of the body is elsewhere.
        """
        if not os.path.exists(xml_file):
            Log.error('Xml file: {} not exists.'.format(xml_file))
            exit(1)

        json_dir_name = os.path.dirname(json_file)
        # A bare file name has dirname '' (the current directory), for which
        # os.makedirs('') would raise FileNotFoundError.
        if json_dir_name and not os.path.exists(json_dir_name):
            Log.info('Json Dir: {} not exists.'.format(json_dir_name))
            os.makedirs(json_dir_name)
示例#23
0
    def json2xml(json_file, xml_file):
        """Validate ``json_file`` and make sure the output directory exists.

        Logs an error and exits when ``json_file`` does not exist. Creates the
        parent directory of ``xml_file`` if it is missing.

        NOTE(review): as shown, no conversion is performed — confirm whether
        the rest of the body is elsewhere.
        """
        if not os.path.exists(json_file):
            Log.error('Json file: {} not exists.'.format(json_file))
            exit(1)

        xml_dir_name = os.path.dirname(xml_file)
        # A bare file name has dirname '' (the current directory), for which
        # os.makedirs('') would raise FileNotFoundError.
        if xml_dir_name and not os.path.exists(xml_dir_name):
            Log.info('Xml Dir: {} not exists.'.format(xml_dir_name))
            os.makedirs(xml_dir_name)
示例#24
0
    def load_file(json_file):
        """Parse and return the JSON contents of ``json_file``.

        Logs an error and exits when the file does not exist.
        """
        if os.path.exists(json_file):
            with open(json_file, 'r') as stream:
                return json.load(stream)

        Log.error('Json file: {} not exists.'.format(json_file))
        exit(1)
示例#25
0
    def object_detector(self):
        """Instantiate the detection model named by ``network.model_name``.

        The name is looked up in ``DET_MODEL_DICT`` and the entry is called
        with the configer. Unknown names log an error and exit.
        """
        model_name = self.configer.get('network', 'model_name')
        if model_name in DET_MODEL_DICT:
            return DET_MODEL_DICT[model_name](self.configer)

        Log.error('Model: {} not valid!'.format(model_name))
        exit(1)
示例#26
0
    def seg_runner(self):
        """Instantiate the segmentation runner selected by the ``method`` key.

        When ``phase`` is ``'train'`` the entry from ``SEG_METHOD_DICT`` is
        used, otherwise the entry from ``SEG_TEST_DICT``. Either is called with
        the configer. The method must be registered in both dicts, otherwise an
        error is logged and the process exits.
        """
        key = self.configer.get('method')
        if key not in SEG_METHOD_DICT or key not in SEG_TEST_DICT:
            # Label the error with this factory's task (segmentation).
            Log.error('Seg Method: {} is not valid.'.format(key))
            exit(1)

        if self.configer.get('phase') == 'train':
            return SEG_METHOD_DICT[key](self.configer)
        else:
            return SEG_TEST_DICT[key](self.configer)
    def get_single_pose_model(self):
        """Instantiate the single-pose model named by ``network.model_name``.

        The name is looked up in ``SINGLE_POSE_MODEL_DICT`` and the entry is
        called with the configer. Unknown names log an error and exit.
        """
        model_name = self.configer.get('network', 'model_name')
        if model_name in SINGLE_POSE_MODEL_DICT:
            return SINGLE_POSE_MODEL_DICT[model_name](self.configer)

        Log.error('Model: {} not valid!'.format(model_name))
        exit(1)
示例#28
0
    def test(self, test_dir, out_dir):
        """Run segmentation inference over ``test_dir`` and write results to ``out_dir``.

        For every batch from ``self.test_loader.get_testloader``, logits are
        computed with the strategy named by ``test.mode``: ``ss_test``, or one
        of ``sscrop_test`` / ``ms_test`` / ``mscrop_test``, which also receive
        the ``test.<mode>`` params. An unknown mode logs an error and exits.

        For each image, the argmax over the last logits axis is:
          * colorized over the original BGR image and saved to
            ``<out_dir>/vis/<filename>.png``;
          * optionally relabelled (when ``data.label_list`` is set) and shifted
            by +1 (when ``data.reduce_zero_label`` is truthy), then saved as a
            palette ('P') image to ``<out_dir>/label/<filename>.png``.
        """
        for data_dict in self.test_loader.get_testloader(test_dir=test_dir):
            mode = self.configer.get('test', 'mode')
            if mode == 'ss_test':
                total_logits = self.ss_test(data_dict)
            elif mode in ('sscrop_test', 'ms_test', 'mscrop_test'):
                # The runner method and its params section share the mode name.
                runner = getattr(self, mode)
                total_logits = runner(data_dict,
                                      params_dict=self.configer.get('test', mode))
            else:
                Log.error('Invalid test mode:{}'.format(mode))
                exit(1)

            meta_list = DCHelper.tolist(data_dict['meta'])
            for idx, meta in enumerate(meta_list):
                label_img = np.array(np.argmax(total_logits[idx], axis=-1), dtype=np.uint8)
                ori_img_bgr = ImageHelper.read_image(meta['img_path'], tool='cv2', mode='BGR')
                image_canvas = self.seg_parser.colorize(label_img, image_canvas=ori_img_bgr)
                vis_path = os.path.join(out_dir, 'vis/{}.png'.format(meta['filename']))
                ImageHelper.save(image_canvas, save_path=vis_path)

                if self.configer.get('data.label_list', default=None) is not None:
                    label_img = self.__relabel(label_img)

                if self.configer.get('data.reduce_zero_label', default=False):
                    label_img = (label_img + 1).astype(np.uint8)

                label_pil = Image.fromarray(label_img, 'P')
                label_path = os.path.join(out_dir, 'label/{}.png'.format(meta['filename']))
                Log.info('Label Path: {}'.format(label_path))
                ImageHelper.save(label_pil, label_path)
示例#29
0
    def get_size(img):
        """Return the size of a PIL image or a numpy array.

        PIL images return ``img.size`` unchanged. Numpy arrays return
        ``[width, height]`` taken from ``shape[:2]``. Any other type logs an
        error and exits.
        """
        if isinstance(img, Image.Image):
            return img.size

        if isinstance(img, np.ndarray):
            rows, cols = img.shape[:2]
            return [cols, rows]

        Log.error('Image type is invalid.')
        exit(1)
示例#30
0
    def save(img, save_path):
        """Write ``img`` to ``save_path``.

        ``FileHelper.make_dirs(save_path, is_file=True)`` is called first
        (presumably creating the parent directory — confirm). PIL images are
        written with ``Image.save`` and numpy arrays with ``cv2.imwrite``. Any
        other type logs an error and exits.
        """
        FileHelper.make_dirs(save_path, is_file=True)

        if isinstance(img, Image.Image):
            img.save(save_path)
            return

        if isinstance(img, np.ndarray):
            cv2.imwrite(save_path, img)
            return

        Log.error('Image type is invalid.')
        exit(1)