Example #1
    def get_valloader(self, dataset=None):
        dataset = 'val' if dataset is None else dataset
        if self.configer.get('dataset') == 'default_pix2pix':
            dataset = DefaultPix2pixDataset(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                            aug_transform=self.aug_val_transform,
                                            img_transform=self.img_transform,
                                            configer=self.configer)

        elif self.configer.get('dataset') == 'default_cyclegan':
            dataset = DefaultCycleGANDataset(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                             aug_transform=self.aug_val_transform,
                                             img_transform=self.img_transform,
                                             configer=self.configer)

        elif self.configer.get('dataset') == 'default_facegan':
            dataset = DefaultFaceGANDataset(root_dir=self.configer.get('data', 'data_dir'),
                                            dataset=dataset, tag=self.configer.get('data', 'tag'),
                                            aug_transform=self.aug_val_transform,
                                            img_transform=self.img_transform,
                                            configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(self.configer.get('dataset')))
            exit(1)

        valloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('val', 'batch_size'), shuffle=False,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')
            )
        )

        return valloader
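Note: DataLoader calls collate_fn with a single argument (the list of samples in one batch), so the `lambda *args: collate(*args, trans_dict=...)` wrapper above only forwards that list plus a fixed keyword. The project's collate() is not shown in these examples, so the sketch below uses a minimal stand-in around PyTorch's default_collate; the toy dataset and trans_dict value are made up for illustration.

# Minimal sketch of the collate_fn pattern above; collate() here is a
# stand-in, not the project's implementation.
import torch
from torch.utils import data
from torch.utils.data.dataloader import default_collate


def collate(batch, trans_dict=None):
    # DataLoader passes the list of samples as one positional argument,
    # so the lambda only adds the captured trans_dict keyword.
    print('collating {} samples with trans_dict={}'.format(len(batch), trans_dict))
    return default_collate(batch)


dataset = data.TensorDataset(torch.randn(8, 3), torch.arange(8))
loader = data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=lambda *args: collate(*args, trans_dict={'size_mode': 'fix_size'})
)

for inputs, labels in loader:
    print(inputs.shape, labels.shape)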
Example #2
    def get_trainloader(self):
        if self.configer.get('dataset', default=None) in [None, 'default']:
            dataset = DefaultDataset(root_dir=self.configer.get(
                'data', 'data_dir'),
                                     dataset='train',
                                     aug_transform=self.aug_train_transform,
                                     img_transform=self.img_transform,
                                     configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(
                self.configer.get('dataset')))
            exit(1)

        sampler = None
        if self.configer.get('network.distributed'):
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)

        trainloader = data.DataLoader(
            dataset,
            sampler=sampler,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=(sampler is None),
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'train', 'data_transformer')))

        return trainloader
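Note: the `shuffle=(sampler is None)` guard is needed because DataLoader rejects shuffle=True when a sampler is supplied; in the distributed branch the DistributedSampler does the shuffling, and its epoch has to be advanced manually for the order to change between epochs. Below is a minimal single-process sketch of that interaction; the gloo backend, address/port values and toy dataset are assumptions for the sketch, not taken from the example.

# Single-process sketch of the DistributedSampler + DataLoader pattern above.
import os
import torch
import torch.distributed as dist
from torch.utils import data
from torch.utils.data.distributed import DistributedSampler

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

dataset = data.TensorDataset(torch.randn(16, 3), torch.arange(16))
sampler = DistributedSampler(dataset)

# shuffle must stay False once a sampler is given; the sampler shuffles instead.
loader = data.DataLoader(dataset, batch_size=4, shuffle=(sampler is None), sampler=sampler)

for epoch in range(2):
    # Without set_epoch() every epoch reuses the same shuffled order.
    sampler.set_epoch(epoch)
    for inputs, labels in loader:
        pass

dist.destroy_process_group()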
Example #3
    def get_valloader(self):
        if self.configer.get('dataset', default=None) in [None, 'default']:
            dataset = DefaultDataset(root_dir=self.configer.get(
                'data', 'data_dir'),
                                     dataset='val',
                                     aug_transform=self.aug_val_transform,
                                     img_transform=self.img_transform,
                                     label_transform=self.label_transform,
                                     configer=self.configer)

        elif self.configer.get('dataset', default=None) == 'cityscapes':
            dataset = CityscapesDataset(root_dir=self.configer.get(
                'data', 'data_dir'),
                                        dataset='val',
                                        aug_transform=self.aug_val_transform,
                                        img_transform=self.img_transform,
                                        label_transform=self.label_transform,
                                        configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(
                self.configer.get('dataset')))
            exit(1)

        valloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'val', 'data_transformer')))

        return valloader
Example #4
    def update(self, key, value, append=False):
        if key not in self.params_root:
            Log.error('{} Key: {} not existed!!!'.format(
                self._get_caller(), key))
            exit(1)

        self.params_root.put(key, value, append)
Example #5
    def get_trainloader(self):
        if self.configer.get('dataset', default=None) == 'default_pix2pix':
            dataset = DefaultPix2pixDataset(root_dir=self.configer.get('data', 'data_dir'), dataset='train',
                                            aug_transform=self.aug_train_transform,
                                            img_transform=self.img_transform,
                                            configer=self.configer)

        elif self.configer.get('dataset') == 'default_cyclegan':
            dataset = DefaultCycleGANDataset(root_dir=self.configer.get('data', 'data_dir'), dataset='train',
                                             aug_transform=self.aug_train_transform,
                                             img_transform=self.img_transform,
                                             configer=self.configer)

        elif self.configer.get('dataset') == 'default_facegan':
            dataset = DefaultFaceGANDataset(root_dir=self.configer.get('data', 'data_dir'),
                                            dataset='train', tag=self.configer.get('data', 'tag'),
                                            aug_transform=self.aug_train_transform,
                                            img_transform=self.img_transform,
                                            configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(self.configer.get('dataset')))
            exit(1)

        trainloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('train', 'batch_size'), shuffle=True,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('train', 'data_transformer')
            )
        )

        return trainloader
Example #6
    def vis_rois(self, inputs, indices_and_rois, rois_labels=None, name='default', sub_dir='rois'):
        base_dir = os.path.join(self.configer.get('project_dir'), DET_DIR, sub_dir)

        if not os.path.exists(base_dir):
            Log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        for i in range(inputs.size(0)):
            rois = indices_and_rois[indices_and_rois[:, 0] == i][:, 1:]
            ori_img = DeNormalize(div_value=self.configer.get('normalize', 'div_value'),
                                  mean=self.configer.get('normalize', 'mean'),
                                  std=self.configer.get('normalize', 'std'))(inputs[i])
            ori_img = ori_img.data.cpu().squeeze().numpy().transpose(1, 2, 0).astype(np.uint8)
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_RGB2BGR)
            color_num = len(self.configer.get('details', 'color_list'))

            for j in range(len(rois)):
                label = 1 if rois_labels is None else rois_labels[j]
                if label == 0:
                    continue

                class_name = self.configer.get('details', 'name_seq')[label - 1]
                cv2.rectangle(ori_img,
                              (int(rois[j][0]), int(rois[j][1])),
                              (int(rois[j][2]), int(rois[j][3])),
                              color=self.configer.get('details', 'color_list')[(label - 1) % color_num], thickness=3)
                cv2.putText(ori_img, class_name,
                            (int(rois[j][0]) + 5, int(rois[j][3]) - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5,
                            color=self.configer.get('details', 'color_list')[(label - 1) % color_num], thickness=2)

            img_path = os.path.join(base_dir, '{}_{}_{}.jpg'.format(name, i, time.time()))

            cv2.imwrite(img_path, ori_img)
Example #7
    def vis_default_bboxes(self,
                           ori_img_in,
                           default_bboxes,
                           labels,
                           name='default',
                           sub_dir='encode'):
        base_dir = os.path.join(self.configer.get('project_dir'), DET_DIR,
                                sub_dir)

        if not os.path.exists(base_dir):
            Log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        if not isinstance(ori_img_in, np.ndarray):
            ori_img = DeNormalize(
                div_value=self.configer.get('normalize', 'div_value'),
                mean=self.configer.get('normalize', 'mean'),
                std=self.configer.get('normalize', 'std'))(ori_img_in.clone())
            ori_img = ori_img.data.cpu().squeeze().numpy().transpose(
                1, 2, 0).astype(np.uint8)
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_RGB2BGR)
        else:
            ori_img = ori_img_in.copy()

        assert labels.size(0) == default_bboxes.size(0)

        bboxes = torch.cat([
            default_bboxes[:, :2] - default_bboxes[:, 2:] / 2,
            default_bboxes[:, :2] + default_bboxes[:, 2:] / 2
        ], 1)
        height, width, _ = ori_img.shape
        for i in range(labels.size(0)):
            if labels[i] == 0:
                continue

            class_name = self.configer.get('details',
                                           'name_seq')[labels[i] - 1]
            color_num = len(self.configer.get('details', 'color_list'))

            cv2.rectangle(
                ori_img,
                (int(bboxes[i][0] * width), int(bboxes[i][1] * height)),
                (int(bboxes[i][2] * width), int(bboxes[i][3] * height)),
                color=self.configer.get(
                    'details', 'color_list')[(labels[i] - 1) % color_num],
                thickness=3)

            cv2.putText(ori_img,
                        class_name, (int(bboxes[i][0] * width) + 5,
                                     int(bboxes[i][3] * height) - 5),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.5,
                        color=self.configer.get('details',
                                                'color_list')[(labels[i] - 1) %
                                                              color_num],
                        thickness=2)

        img_path = os.path.join(base_dir, '{}.jpg'.format(name))

        cv2.imwrite(img_path, ori_img)
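Note: the torch.cat at the top of vis_default_bboxes converts anchors from center-size form (cx, cy, w, h) to corner form (x1, y1, x2, y2) before scaling by the image size. Below is a small self-contained check of just that conversion; the anchor values and the 320x400 image size are made up.

# (cx, cy, w, h) -> (x1, y1, x2, y2) conversion used in vis_default_bboxes.
import torch

# Rows are (cx, cy, w, h), normalized to [0, 1]; values are made up.
default_bboxes = torch.tensor([[0.50, 0.50, 0.25, 0.50],
                               [0.25, 0.75, 0.25, 0.25]])

bboxes = torch.cat([
    default_bboxes[:, :2] - default_bboxes[:, 2:] / 2,   # x1, y1
    default_bboxes[:, :2] + default_bboxes[:, 2:] / 2    # x2, y2
], 1)

height, width = 320, 400
for box in bboxes:
    print(int(box[0] * width), int(box[1] * height),
          int(box[2] * width), int(box[3] * height))
# -> 150 80 250 240
#    50 200 150 280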
Example #8
    def __init__(self, configer):
        self.configer = configer

        if self.configer.get('data', 'image_tool') == 'pil':
            self.aug_train_transform = pil_aug_trans.PILAugCompose(
                self.configer, split='train')
        elif self.configer.get('data', 'image_tool') == 'cv2':
            self.aug_train_transform = cv2_aug_trans.CV2AugCompose(
                self.configer, split='train')
        else:
            Log.error('Not support {} image tool.'.format(
                self.configer.get('data', 'image_tool')))
            exit(1)

        if self.configer.get('data', 'image_tool') == 'pil':
            self.aug_val_transform = pil_aug_trans.PILAugCompose(self.configer,
                                                                 split='val')
        elif self.configer.get('data', 'image_tool') == 'cv2':
            self.aug_val_transform = cv2_aug_trans.CV2AugCompose(self.configer,
                                                                 split='val')
        else:
            Log.error('Not support {} image tool.'.format(
                self.configer.get('data', 'image_tool')))
            exit(1)

        self.img_transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(**self.configer.get('data', 'normalize')),
        ])

        self.label_transform = trans.Compose([
            trans.ToLabel(),
            trans.ReLabel(255, -1),
        ])
Example #9
    def vis_bboxes(self,
                   image_in,
                   bboxes_list,
                   name='default',
                   sub_dir='bbox'):
        """
          Show the diff bbox of individuals.
        """
        base_dir = os.path.join(self.configer.get('project_dir'), DET_DIR,
                                sub_dir)

        if isinstance(image_in, Image.Image):
            image = ImageHelper.rgb2bgr(ImageHelper.to_np(image_in))

        else:
            image = image_in.copy()

        if not os.path.exists(base_dir):
            Log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        img_path = os.path.join(
            base_dir,
            name if ImageHelper.is_img(name) else '{}.jpg'.format(name))

        for bbox in bboxes_list:
            image = cv2.rectangle(image, (bbox[0], bbox[1]),
                                  (bbox[2], bbox[3]), (0, 255, 0), 2)

        cv2.imwrite(img_path, image)
Example #10
    def vis_peaks(self, heatmap_in, ori_img_in, name='default', sub_dir='peaks'):
        base_dir = os.path.join(self.configer.get('project_dir'), POSE_DIR, sub_dir)
        if not os.path.exists(base_dir):
            Log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        if not isinstance(heatmap_in, np.ndarray):
            if len(heatmap_in.size()) != 3:
                Log.error('Heatmap size is not valid.')
                exit(1)

            heatmap = heatmap_in.clone().data.cpu().numpy().transpose(1, 2, 0)
        else:
            heatmap = heatmap_in.copy()

        if not isinstance(ori_img_in, np.ndarray):
            ori_img = DeNormalize(div_value=self.configer.get('normalize', 'div_value'),
                                  mean=self.configer.get('normalize', 'mean'),
                                  std=self.configer.get('normalize', 'std'))(ori_img_in.clone())
            ori_img = ori_img.data.cpu().squeeze().numpy().transpose(1, 2, 0).astype(np.uint8)
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_RGB2BGR)
        else:
            ori_img = ori_img_in.copy()

        for j in range(self.configer.get('data', 'num_kpts')):
            peaks = self.__get_peaks(heatmap[:, :, j])

            for peak in peaks:
                ori_img = cv2.circle(ori_img, (peak[0], peak[1]),
                                     self.configer.get('vis', 'circle_radius'),
                                     self.configer.get('details', 'color_list')[j], thickness=-1)

            cv2.imwrite(os.path.join(base_dir, '{}_{}.jpg'.format(name, j)), ori_img)
Example #11
    def parse_dir_pose(self, image_dir, json_dir, mask_dir=None):
        if image_dir is None or not os.path.exists(image_dir):
            Log.error('Image Dir: {} not existed.'.format(image_dir))
            return

        if json_dir is None or not os.path.exists(json_dir):
            Log.error('Json Dir: {} not existed.'.format(json_dir))
            return

        for image_file in os.listdir(image_dir):
            shotname, extension = os.path.splitext(image_file)
            Log.info(image_file)
            image_canvas = cv2.imread(os.path.join(
                image_dir, image_file))  # B, G, R order.
            with open(os.path.join(json_dir, '{}.json'.format(shotname)),
                      'r') as json_stream:
                info_tree = json.load(json_stream)
                image_canvas = self.draw_points(image_canvas, info_tree)
                if self.configer.exists('details', 'limb_seq'):
                    image_canvas = self.link_points(image_canvas, info_tree)

            if mask_dir is not None:
                mask_file = os.path.join(mask_dir,
                                         '{}_vis.png'.format(shotname))
                mask_canvas = cv2.imread(mask_file)
                image_canvas = cv2.addWeighted(image_canvas, 0.6, mask_canvas,
                                               0.4, 0)

            cv2.imshow('main', image_canvas)
            cv2.waitKey()
Example #12
    def get_valloader(self, dataset=None):
        dataset = 'val' if dataset is None else dataset
        if self.configer.get('dataset', default=None) == 'default_cpm':
            dataset = DefaultCPMDataset(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                        aug_transform=self.aug_val_transform,
                                        img_transform=self.img_transform,
                                        configer=self.configer)

        elif self.configer.get('dataset', default=None) == 'default_openpose':
            dataset = DefaultOpenPoseDataset(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                             aug_transform=self.aug_val_transform,
                                             img_transform=self.img_transform,
                                             configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(self.configer.get('dataset')))
            exit(1)

        valloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('val', 'batch_size'), shuffle=False,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')
            )
        )
        return valloader
Example #13
    def get(self, *key, **kwargs):
        key = '.'.join(key)
        if key in self.params_root or 'default' in kwargs:
            return self.params_root.get(key, **kwargs)

        else:
            Log.error('{} KeyError: {}.'.format(self._get_caller(), key))
            exit(1)
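Note: this Configer flattens nested config keys into dot-joined strings and only falls back when a default kwarg is passed, which is why `get('test.dataset', default=None)` and `get('test', 'dataset')` in the other examples resolve to the same entry. The ToyConfiger below is a hypothetical stand-in that only mimics that lookup behaviour; it is not the project's class.

# Hypothetical stand-in that mimics the dot-joined key lookup used above.
class ToyConfiger(object):
    def __init__(self, params):
        # Keys are stored flat, e.g. {'test.dataset': 'list', 'data.workers': 4}.
        self.params_root = params

    def get(self, *key, **kwargs):
        key = '.'.join(key)           # ('test', 'dataset') -> 'test.dataset'
        if key in self.params_root:
            return self.params_root[key]
        if 'default' in kwargs:       # only fall back when a default is given
            return kwargs['default']
        raise KeyError(key)


configer = ToyConfiger({'test.dataset': 'list', 'data.workers': 4})
print(configer.get('test', 'dataset'))              # 'list'
print(configer.get('test.dataset'))                 # 'list'
print(configer.get('test.batch_size', default=2))   # 2 (missing key, default given)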
Example #14
    def read_image(image_path, tool='pil', mode='RGB'):
        if tool == 'pil':
            return ImageHelper.pil_read_image(image_path, mode=mode)
        elif tool == 'cv2':
            return ImageHelper.cv2_read_image(image_path, mode=mode)
        else:
            Log.error('Not support tool {}.'.format(tool))
            exit(1)
Example #15
    def json2xml(json_file, xml_file):
        if not os.path.exists(json_file):
            Log.error('Json file: {} not exists.'.format(json_file))
            exit(1)

        xml_dir_name = os.path.dirname(xml_file)
        if not os.path.exists(xml_dir_name):
            Log.info('Xml Dir: {} not exists.'.format(xml_dir_name))
            os.makedirs(xml_dir_name)
Example #16
    def get_testloader(self, test_dir=None, list_path=None, json_path=None):
        if self.configer.get('test.dataset',
                             default=None) in [None, 'default']:
            test_dir = test_dir if test_dir is not None else self.configer.get(
                'test', 'test_dir')
            dataset = DefaultDataset(test_dir=test_dir,
                                     aug_transform=self.aug_test_transform,
                                     img_transform=self.img_transform,
                                     configer=self.configer)

        elif self.configer.get('test.dataset') == 'list':
            list_path = list_path if list_path is not None else self.configer.get(
                'test', 'list_path')
            dataset = ListDataset(root_dir=self.configer.get(
                'test', 'root_dir'),
                                  list_path=list_path,
                                  aug_transform=self.aug_test_transform,
                                  img_transform=self.img_transform,
                                  configer=self.configer)

        elif self.configer.get('test.dataset') == 'json':
            json_path = json_path if json_path is not None else self.configer.get(
                'test', 'json_path')
            dataset = JsonDataset(root_dir=self.configer.get(
                'test', 'root_dir'),
                                  json_path=json_path,
                                  aug_transform=self.aug_test_transform,
                                  img_transform=self.img_transform,
                                  configer=self.configer)

        elif self.configer.get('test.dataset') == 'facegan':
            json_path = json_path if json_path is not None else self.configer.get(
                'test', 'json_path')
            dataset = FaceGANDataset(root_dir=self.configer.get(
                'test', 'root_dir'),
                                     json_path=json_path,
                                     aug_transform=self.aug_test_transform,
                                     img_transform=self.img_transform,
                                     configer=self.configer)

        else:
            Log.error('{} test dataset is invalid.'.format(
                self.configer.get('test.dataset')))
            exit(1)

        testloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('test.batch_size',
                                         default=torch.cuda.device_count()),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'test', 'data_transformer')))

        return testloader
Example #17
    def xml2json(xml_file, json_file):
        if not os.path.exists(xml_file):
            Log.error('Xml file: {} not exists.'.format(xml_file))
            exit(1)

        json_dir_name = os.path.dirname(json_file)
        if not os.path.exists(json_dir_name):
            Log.info('Json Dir: {} not exists.'.format(json_dir_name))
            os.makedirs(json_dir_name)
Example #18
    def get_cls_model(self):
        model_name = self.configer.get('network', 'model_name')

        if model_name not in CLS_MODEL_DICT:
            Log.error('Model: {} not valid!'.format(model_name))
            exit(1)

        model = CLS_MODEL_DICT[model_name](self.configer)
        return model
Example #19
    def load_file(json_file):
        if not os.path.exists(json_file):
            Log.error('Json file: {} not exists.'.format(json_file))
            exit(1)

        with open(json_file, 'r') as read_stream:
            json_dict = json.load(read_stream)

        return json_dict
Example #20
    def object_detector(self):
        model_name = self.configer.get('network', 'model_name')

        if model_name not in DET_MODEL_DICT:
            Log.error('Model: {} not valid!'.format(model_name))
            exit(1)

        model = DET_MODEL_DICT[model_name](self.configer)

        return model
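Note: get_cls_model and object_detector both rely on a plain name-to-constructor dict (CLS_MODEL_DICT / DET_MODEL_DICT). A minimal sketch of that registry pattern is below; ToyNetA/ToyNetB and build_model are made-up placeholders, not models or helpers from these examples.

# Sketch of the name -> constructor registry pattern used above.
import sys


class ToyNetA(object):
    def __init__(self, configer):
        self.configer = configer


class ToyNetB(object):
    def __init__(self, configer):
        self.configer = configer


MODEL_DICT = {'toy_net_a': ToyNetA, 'toy_net_b': ToyNetB}


def build_model(model_name, configer):
    if model_name not in MODEL_DICT:
        print('Model: {} not valid!'.format(model_name))
        sys.exit(1)

    return MODEL_DICT[model_name](configer)


print(type(build_model('toy_net_a', configer=None)).__name__)   # ToyNetA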
Example #21
    def seg_runner(self):
        key = self.configer.get('method')
        if key not in SEG_METHOD_DICT or key not in SEG_TEST_DICT:
            Log.error('Seg Method: {} is not valid.'.format(key))
            exit(1)

        if self.configer.get('phase') == 'train':
            return SEG_METHOD_DICT[key](self.configer)
        else:
            return SEG_TEST_DICT[key](self.configer)
Example #22
    def test(self, test_dir, out_dir):
        for _, data_dict in enumerate(
                self.test_loader.get_testloader(test_dir=test_dir)):
            total_logits = None
            if self.configer.get('test', 'mode') == 'ss_test':
                total_logits = self.ss_test(data_dict)

            elif self.configer.get('test', 'mode') == 'sscrop_test':
                total_logits = self.sscrop_test(data_dict,
                                                params_dict=self.configer.get(
                                                    'test', 'sscrop_test'))

            elif self.configer.get('test', 'mode') == 'ms_test':
                total_logits = self.ms_test(data_dict,
                                            params_dict=self.configer.get(
                                                'test', 'ms_test'))

            elif self.configer.get('test', 'mode') == 'mscrop_test':
                total_logits = self.mscrop_test(data_dict,
                                                params_dict=self.configer.get(
                                                    'test', 'mscrop_test'))

            else:
                Log.error('Invalid test mode:{}'.format(
                    self.configer.get('test', 'mode')))
                exit(1)

            meta_list = DCHelper.tolist(data_dict['meta'])
            for i in range(len(meta_list)):
                label_map = np.argmax(total_logits[i], axis=-1)
                label_img = np.array(label_map, dtype=np.uint8)
                ori_img_bgr = ImageHelper.read_image(meta_list[i]['img_path'],
                                                     tool='cv2',
                                                     mode='BGR')
                image_canvas = self.seg_parser.colorize(
                    label_img, image_canvas=ori_img_bgr)
                ImageHelper.save(image_canvas,
                                 save_path=os.path.join(
                                     out_dir, 'vis/{}.png'.format(
                                         meta_list[i]['filename'])))

                if self.configer.get('data.label_list',
                                     default=None) is not None:
                    label_img = self.__relabel(label_img)

                if self.configer.get('data.reduce_zero_label', default=False):
                    label_img = label_img + 1
                    label_img = label_img.astype(np.uint8)

                label_img = Image.fromarray(label_img, 'P')
                label_path = os.path.join(
                    out_dir, 'label/{}.png'.format(meta_list[i]['filename']))
                Log.info('Label Path: {}'.format(label_path))
                ImageHelper.save(label_img, label_path)
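Note: the per-pixel prediction step above reduces an H x W x C logits array with argmax and writes the result as a palette ('P') PNG. A stand-alone version of just that step is below; the random logits, the class count and the output filename are placeholders, not the real test data.

# logits -> label map -> palette PNG, extracted from the test() loop above.
import numpy as np
from PIL import Image

num_classes = 19
total_logits = np.random.randn(64, 128, num_classes)    # H x W x C scores (random placeholder)

label_map = np.argmax(total_logits, axis=-1)             # H x W class indices
label_img = np.array(label_map, dtype=np.uint8)          # PNG-friendly dtype

label_img = Image.fromarray(label_img, 'P')              # palette-mode label image
label_img.save('label_example.png')
print(label_img.size, label_img.mode)                    # (128, 64) P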
Example #23
    def save(img, save_path):
        FileHelper.make_dirs(save_path, is_file=True)
        if isinstance(img, Image.Image):
            img.save(save_path)

        elif isinstance(img, np.ndarray):
            cv2.imwrite(save_path, img)

        else:
            Log.error('Image type is invalid.')
            exit(1)
Example #24
    def get_size(img):
        if isinstance(img, Image.Image):
            return img.size

        elif isinstance(img, np.ndarray):
            height, width = img.shape[:2]
            return [width, height]

        else:
            Log.error('Image type is invalid.')
            exit(1)
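Note: get_size normalizes over the fact that PIL reports (width, height) while a NumPy image shape is (height, width, channels). The short check below makes that ordering explicit; the 480x640 dimensions are arbitrary.

# Quick check of the (width, height) convention get_size() returns.
import numpy as np
from PIL import Image

np_img = np.zeros((480, 640, 3), dtype=np.uint8)   # shape is (height, width, channels)
pil_img = Image.fromarray(np_img)                  # PIL reports size as (width, height)

print(np_img.shape[:2])   # (480, 640)  -> height, width
print(pil_img.size)       # (640, 480)  -> width, height

height, width = np_img.shape[:2]
assert [width, height] == list(pil_img.size)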
Example #25
    def cv2_read_image(image_path, mode='RGB'):
        img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if mode == 'BGR':
            return img_bgr

        elif mode == 'RGB':
            return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        elif mode == 'GRAY':
            return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

        elif mode == 'P':
            return ImageHelper.to_np(Image.open(image_path).convert('P'))

        else:
            Log.error('Not support mode {}'.format(mode))
            exit(1)
Example #26
    def __read_json(self, root_dir, json_path):
        item_list = []
        for item in JsonHelper.load_file(json_path):
            img_path = os.path.join(root_dir, item['image_path'])
            if not os.path.exists(img_path) or not ImageHelper.is_img(img_path):
                Log.error('Image Path: {} is Invalid.'.format(img_path))
                exit(1)

            item_list.append((img_path, '.'.join(item['image_path'].split('.')[:-1])))

        Log.info('There are {} images..'.format(len(item_list)))
        return item_list
Example #27
    def pil_read_image(image_path, mode='RGB'):
        with open(image_path, 'rb') as f:
            img = Image.open(f)
            if mode == 'RGB':
                return img.convert('RGB')

            elif mode == 'P':
                return img.convert('P')

            else:
                Log.error('Not support mode {}'.format(mode))
                exit(1)
Example #28
    def __read_list(self, root_dir, list_path):
        item_list = []
        with open(list_path, 'r') as f:
            for line in f.readlines():
                filename = line.strip().split()[0]
                img_path = os.path.join(root_dir, filename)
                if not os.path.exists(img_path) or not ImageHelper.is_img(img_path):
                    Log.error('Image Path: {} is Invalid.'.format(img_path))
                    exit(1)

                item_list.append((img_path, '.'.join(filename.split('.')[:-1])))

        Log.info('There are {} images..'.format(len(item_list)))
        return item_list
Example #29
    def __init__(self, configer):
        self.configer = configer

        if self.configer.get('data', 'image_tool') == 'pil':
            self.aug_test_transform = pil_aug_trans.PILAugCompose(self.configer, split='test')
        elif self.configer.get('data', 'image_tool') == 'cv2':
            self.aug_test_transform = cv2_aug_trans.CV2AugCompose(self.configer, split='test')
        else:
            Log.error('Not support {} image tool.'.format(self.configer.get('data', 'image_tool')))
            exit(1)

        self.img_transform = Compose([
            ToTensor(),
            Normalize(**self.configer.get('data', 'normalize')), ])
Example #30
    def BNReLU(num_features, norm_type=None, **kwargs):
        if norm_type == 'batchnorm':
            return nn.Sequential(nn.BatchNorm2d(num_features, **kwargs),
                                 nn.ReLU())
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import SyncBatchNorm
            return nn.Sequential(SyncBatchNorm(num_features, **kwargs),
                                 nn.ReLU())
        elif norm_type == 'instancenorm':
            return nn.Sequential(nn.InstanceNorm2d(num_features, **kwargs),
                                 nn.ReLU())
        else:
            Log.error('Not support BN type: {}.'.format(norm_type))
            exit(1)
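Note: BNReLU just returns a norm + ReLU nn.Sequential keyed by a norm_type string. Below is a short usage sketch for the plain 'batchnorm' branch; it uses a trimmed copy of the helper so the snippet runs on its own (the 'encsync_batchnorm' branch needs the external encoding package and is left out), and the conv layer and tensor sizes are arbitrary.

# Usage sketch for BNReLU('batchnorm'); trimmed copy of the helper above.
import torch
import torch.nn as nn


def BNReLU(num_features, norm_type=None, **kwargs):
    if norm_type == 'batchnorm':
        return nn.Sequential(nn.BatchNorm2d(num_features, **kwargs), nn.ReLU())
    elif norm_type == 'instancenorm':
        return nn.Sequential(nn.InstanceNorm2d(num_features, **kwargs), nn.ReLU())
    raise ValueError('Not support BN type: {}.'.format(norm_type))


block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    BNReLU(64, norm_type='batchnorm'),
)

out = block(torch.randn(2, 3, 32, 32))
print(out.shape)   # torch.Size([2, 64, 32, 32])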