예제 #1
0
    def get_valloader(self, dataset=None):
        """Create the validation ``DataLoader``.

        Args:
            dataset: Split name handed to ``DefaultLoader``; ``None`` falls
                back to ``'val'``.

        Returns:
            A non-shuffling ``data.DataLoader`` wrapping a ``DefaultLoader``.

        Only the ``'default'`` loader is supported (it is also used when the
        config has no ``val.loader`` entry). Any other value is logged as an
        error and the process exits with status 1.
        """
        split = 'val' if dataset is None else dataset

        loader_is_default = (
            not self.configer.exists('val', 'loader')
            or self.configer.get('val', 'loader') == 'default')
        if not loader_is_default:
            Log.error('{} val loader is invalid.'.format(
                self.configer.get('val', 'loader')))
            exit(1)

        val_set = DefaultLoader(
            root_dir=self.configer.get('data', 'data_dir'),
            dataset=split,
            aug_transform=self.aug_val_transform,
            img_transform=self.img_transform,
            label_transform=self.label_transform,
            configer=self.configer)

        def batch_collate(*args):
            # The transformer settings are looked up each time a batch is built.
            return collate(
                *args,
                trans_dict=self.configer.get('val', 'data_transformer'))

        return data.DataLoader(
            val_set,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=batch_collate)
예제 #2
0
    def get_trainloader(self):
        """Create the training ``DataLoader``.

        Returns:
            A shuffling ``data.DataLoader`` wrapping a ``DefaultLoader`` for
            the ``'train'`` split; ``drop_last`` is taken from the config.

        Only the ``'default'`` loader is supported (it is also used when the
        config has no ``train.loader`` entry). Any other value is logged as
        an error and the process exits with status 1.
        """
        loader_is_default = (
            not self.configer.exists('train', 'loader')
            or self.configer.get('train', 'loader') == 'default')
        if not loader_is_default:
            Log.error('{} train loader is invalid.'.format(
                self.configer.get('train', 'loader')))
            exit(1)

        train_set = DefaultLoader(
            root_dir=self.configer.get('data', 'data_dir'),
            dataset='train',
            aug_transform=self.aug_train_transform,
            img_transform=self.img_transform,
            label_transform=self.label_transform,
            configer=self.configer)

        def batch_collate(*args):
            # The transformer settings are looked up each time a batch is built.
            return collate(
                *args,
                trans_dict=self.configer.get('train', 'data_transformer'))

        return data.DataLoader(
            train_set,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=batch_collate)
예제 #3
0
    def vis_rois(self, inputs, indices_and_rois, rois_labels=None, name='default', sub_dir='rois'):
        """Draw the RoIs of every image in a batch and write one JPEG per image.

        Args:
            inputs: Batch of normalized images; indexed along its first
                dimension (presumably a torch tensor -- confirm with caller).
            indices_and_rois: 2-D array whose first column is the image index
                within the batch and whose remaining columns are the box
                coordinates used as (x1, y1, x2, y2).
            rois_labels: Optional labels indexed by the RoI position within
                each image; label 0 is skipped. ``None`` labels every RoI 1.
            name: Prefix of the output file name.
            sub_dir: Sub-directory (under project_dir/DET_DIR) to write into.
        """
        out_dir = os.path.join(self.configer.get('project_dir'), DET_DIR, sub_dir)
        if not os.path.exists(out_dir):
            log.error('Dir:{} not exists!'.format(out_dir))
            os.makedirs(out_dir)

        for img_idx in range(inputs.size(0)):
            # Keep the rows belonging to this image and drop the index column.
            img_rois = indices_and_rois[indices_and_rois[:, 0] == img_idx][:, 1:]

            denormalize = DeNormalize(div_value=self.configer.get('normalize', 'div_value'),
                                      mean=self.configer.get('normalize', 'mean'),
                                      std=self.configer.get('normalize', 'std'))
            canvas = denormalize(inputs[img_idx])
            # CHW tensor -> HWC uint8 array, then RGB -> BGR for OpenCV.
            canvas = canvas.data.cpu().squeeze().numpy().transpose(1, 2, 0).astype(np.uint8)
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)
            color_num = len(self.configer.get('details', 'color_list'))

            for roi_idx in range(len(img_rois)):
                if rois_labels is None:
                    label = 1
                else:
                    label = rois_labels[roi_idx]

                if label == 0:
                    continue

                roi = img_rois[roi_idx]
                class_name = self.configer.get('details', 'name_seq')[label - 1]
                cv2.rectangle(canvas,
                              (int(roi[0]), int(roi[1])),
                              (int(roi[2]), int(roi[3])),
                              color=self.configer.get('details', 'color_list')[(label - 1) % color_num],
                              thickness=3)
                cv2.putText(canvas, class_name,
                            (int(roi[0]) + 5, int(roi[3]) - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5,
                            color=self.configer.get('details', 'color_list')[(label - 1) % color_num],
                            thickness=2)

            # The timestamp keeps repeated calls from overwriting earlier files.
            out_path = os.path.join(out_dir, '{}_{}_{}.jpg'.format(name, img_idx, time.time()))
            cv2.imwrite(out_path, canvas)
예제 #4
0
    def get_center(self, img_size, bboxes):
        """Choose the crop centre according to ``self.method``.

        Args:
            img_size: Two-element sequence; index 0 pairs with ``self.size[0]``
                and the x coordinates of ``bboxes``, index 1 with the y ones.
            bboxes: Sequence of ``[x1, y1, x2, y2]`` boxes, or ``None``.

        Returns:
            ``[x, y]`` centre of the crop window.

        Behaviour per method:
            * ``'center'``: the image centre.
            * ``'random'``, or any other method when ``bboxes`` is ``None`` or
              empty: a uniformly random centre that keeps the crop inside the
              image.
            * ``'focus'``: the centre of the largest box (the last one wins on
              ties), shifted by one random jitter in [-40, 40] applied to
              both coordinates.
            * ``'grid'``: a random cell of a ``self.grid`` lattice of crop
              positions.

        An unknown method is logged as an error and the process exits with
        status 1.
        """
        max_center = [img_size[0] / 2, img_size[1] / 2]
        max_index = 0
        if self.method == 'center':
            return max_center

        elif bboxes is None or len(bboxes) == 0 or self.method == 'random':
            x = random.randint(self.size[0] // 2, img_size[0] - self.size[0] // 2)
            y = random.randint(self.size[1] // 2, img_size[1] - self.size[1] // 2)
            return [x, y]

        elif self.method == 'focus':
            bboxes = np.array(bboxes)
            border = bboxes[:, 2:] - bboxes[:, 0:2]
            for i in range(len(border)):
                if border[i][0] * border[i][1] >= border[max_index][0] * border[max_index][1]:
                    max_index = i
                    max_center = [(bboxes[i][0] + bboxes[i][2]) / 2, (bboxes[i][1] + bboxes[i][3]) / 2]

            jitter = random.randint(-40, 40)
            max_center[0] += jitter
            max_center[1] += jitter

            return max_center

        elif self.method == 'grid':
            grid_x = random.randint(0, self.grid[0] - 1)
            grid_y = random.randint(0, self.grid[1] - 1)
            # An axis with a single cell has exactly one crop position, so its
            # step is 0; dividing by (cells - 1) would raise ZeroDivisionError.
            step_x = (img_size[0] - self.size[0]) // (self.grid[0] - 1) if self.grid[0] > 1 else 0
            step_y = (img_size[1] - self.size[1]) // (self.grid[1] - 1) if self.grid[1] > 1 else 0
            x = self.size[0] // 2 + grid_x * step_x
            y = self.size[1] // 2 + grid_y * step_y
            return [x, y]

        else:
            Log.error('Crop method {} is invalid.'.format(self.method))
            exit(1)
예제 #5
0
    def __init__(self, configer):
        """Build the augmentation, image and label transforms from the config.

        Args:
            configer: Project configuration object; ``data.image_tool`` picks
                the augmentation backend (``'pil'`` or ``'cv2'``) and
                ``data.normalize`` supplies the keyword arguments for
                ``trans.Normalize``.

        An unsupported image tool is logged as an error and the process exits
        with status 1.
        """
        self.configer = configer

        image_tool = self.configer.get('data', 'image_tool')
        if image_tool == 'pil':
            aug_compose = pil_aug_trans.PILAugCompose
        elif image_tool == 'cv2':
            aug_compose = cv2_aug_trans.CV2AugCompose
        else:
            Log.error('Not support {} image tool.'.format(
                self.configer.get('data', 'image_tool')))
            exit(1)

        # Train and val pipelines share a backend and differ only in split.
        self.aug_train_transform = aug_compose(self.configer, split='train')
        self.aug_val_transform = aug_compose(self.configer, split='val')

        normalize_kwargs = self.configer.get('data', 'normalize')
        self.img_transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(**normalize_kwargs),
        ])

        self.label_transform = trans.Compose([
            trans.ToLabel(),
            trans.ReLabel(255, -1),
        ])
예제 #6
0
    def get_scale(self, img_size, bboxes):
        """Return the resize ratio according to ``self.method``.

        Args:
            img_size: Two-element image size; only its min and max are used.
            bboxes: Sequence of ``[x1, y1, x2, y2]`` boxes, or ``None``.

        Returns:
            * ``'random'``: a uniform sample from ``self.scale_range``.
            * ``'focus'``: the same sample, multiplied by a factor that makes
              the largest box extent cover 0.6 of ``self.input_size``; falls
              back to the plain sample when ``self.input_size`` is ``None``
              or there are no boxes.
            * ``'bound'``: the largest ratio that keeps the short side within
              ``self.resize_bound[0]`` and the long side within
              ``self.resize_bound[1]``.

        An unknown method is logged as an error and the process exits with
        status 1.
        """
        method = self.method

        if method == 'random':
            return random.uniform(self.scale_range[0], self.scale_range[1])

        if method == 'focus':
            can_focus = (self.input_size is not None
                         and bboxes is not None
                         and len(bboxes) > 0)
            if not can_focus:
                return random.uniform(self.scale_range[0], self.scale_range[1])

            boxes = np.array(bboxes)
            extents = boxes[:, 2:] - boxes[:, 0:2]
            # Largest box width/height relative to the network input size.
            rel_w = max(extents[:, 0]) / self.input_size[0]
            rel_h = max(extents[:, 1]) / self.input_size[1]
            scale = 0.6 / max(rel_w, rel_h)
            return random.uniform(self.scale_range[0], self.scale_range[1]) * scale

        if method == 'bound':
            short_side_ratio = self.resize_bound[0] / min(img_size)
            long_side_ratio = self.resize_bound[1] / max(img_size)
            return min(short_side_ratio, long_side_ratio)

        Log.error('Resize method {} is invalid.'.format(self.method))
        exit(1)
예제 #7
0
    def get_valloader(self):
        """Create the validation ``DataLoader`` for the configured detector.

        Returns:
            A non-shuffling ``data.DataLoader`` over ``<data_dir>/val`` using
            ``SSDDataLoader`` for ``'single_shot_detector'`` or
            ``FRDataLoader`` for ``'faster_rcnn'``; ``None`` (after logging
            an error) for any other method.
        """
        method = self.configer.get('method')
        if method == 'single_shot_detector':
            dataset_cls = SSDDataLoader
        elif method == 'faster_rcnn':
            dataset_cls = FRDataLoader
        else:
            Log.error('Method: {} loader is invalid.'.format(
                self.configer.get('method')))
            return None

        val_root = os.path.join(self.configer.get('data', 'data_dir'), 'val')
        val_set = dataset_cls(root_dir=val_root,
                              aug_transform=self.aug_val_transform,
                              img_transform=self.img_transform,
                              configer=self.configer)

        return data.DataLoader(
            val_set,
            batch_size=self.configer.get('data', 'val_batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True)
예제 #8
0
    def __init__(self, configer):
        """Build the augmentation and image transforms from the config.

        Args:
            configer: Project configuration object; ``data.image_tool`` picks
                the augmentation backend (``'pil'`` or ``'cv2'``) and the
                ``normalize`` section supplies ``div_value``, ``mean`` and
                ``std`` for ``trans.Normalize``.

        An unsupported image tool is logged as an error and the process exits
        with status 1.
        """
        self.configer = configer

        image_tool = self.configer.get('data', 'image_tool')
        if image_tool == 'pil':
            aug_compose = pil_aug_trans.PILAugCompose
        elif image_tool == 'cv2':
            aug_compose = cv2_aug_trans.CV2AugCompose
        else:
            Log.error('Not support {} image tool.'.format(
                self.configer.get('data', 'image_tool')))
            exit(1)

        # Train and val pipelines share a backend and differ only in split.
        self.aug_train_transform = aug_compose(self.configer, split='train')
        self.aug_val_transform = aug_compose(self.configer, split='val')

        normalize = trans.Normalize(
            div_value=self.configer.get('normalize', 'div_value'),
            mean=self.configer.get('normalize', 'mean'),
            std=self.configer.get('normalize', 'std'))
        self.img_transform = trans.Compose([
            trans.ToTensor(),
            normalize,
        ])
예제 #9
0
파일: configer.py 프로젝트: zy0851/TorchCV
    def add(self, key_tuple, value):
        """Insert a new value under a key path of depth one to three.

        Args:
            key_tuple: Sequence of 1-3 keys naming the nested location inside
                ``self.params_root``.
            value: Value to store at that location.

        Missing intermediate dictionaries are created. If the key already
        exists, or the path has an unsupported depth, an error is logged and
        the process exits with status 1.
        """
        if self.exists(*key_tuple):
            Log.error('{} Key: {} existed!!!'.format(self._get_caller(),
                                                     key_tuple))
            exit(1)

        if len(key_tuple) not in (1, 2, 3):
            Log.error('{} KeyError: {}.'.format(self._get_caller(), key_tuple))
            exit(1)

        # Descend to the parent of the final key, creating levels on the way.
        node = self.params_root
        for key in key_tuple[:-1]:
            if key not in node:
                node[key] = dict()

            node = node[key]

        node[key_tuple[-1]] = value
예제 #10
0
    def get_valloader(self):
        """Create the validation ``DataLoader`` for the configured pose method.

        Returns:
            A non-shuffling ``data.DataLoader`` over ``<data_dir>/val`` using
            ``CPMDataLoader`` for ``'conv_pose_machine'`` or ``OPDataLoader``
            for ``'open_pose'``; ``None`` (after logging an error) for any
            other method.

        Both branches honour ``data.workers`` and pin memory. The open_pose
        branch previously omitted these two arguments, so it silently ignored
        the configured worker count.
        """
        if self.configer.get('method') == 'conv_pose_machine':
            valloader = data.DataLoader(
                CPMDataLoader(root_dir=os.path.join(
                    self.configer.get('data', 'data_dir'), 'val'),
                              aug_transform=self.aug_val_transform,
                              img_transform=self.img_transform,
                              configer=self.configer),
                batch_size=self.configer.get('val', 'batch_size'),
                shuffle=False,
                num_workers=self.configer.get('data', 'workers'),
                pin_memory=True,
                collate_fn=lambda *args: CollateFunctions.default_collate(
                    *args, data_keys=['img', 'heatmap']))

            return valloader

        elif self.configer.get('method') == 'open_pose':
            valloader = data.DataLoader(
                OPDataLoader(root_dir=os.path.join(
                    self.configer.get('data', 'data_dir'), 'val'),
                             aug_transform=self.aug_val_transform,
                             img_transform=self.img_transform,
                             configer=self.configer),
                batch_size=self.configer.get('val', 'batch_size'),
                shuffle=False,
                # Same loading settings as the conv_pose_machine branch.
                num_workers=self.configer.get('data', 'workers'),
                pin_memory=True,
                collate_fn=lambda *args: CollateFunctions.default_collate(
                    *args, data_keys=['img', 'maskmap', 'heatmap', 'vecmap']))

            return valloader

        else:
            Log.error('Method: {} loader is invalid.'.format(
                self.configer.get('method')))
            return None
예제 #11
0
    def vis_bboxes(self,
                   image_in,
                   bboxes_list,
                   name='default',
                   sub_dir='bbox'):
        """Draw green box outlines on a copy of the image and save it.

        Args:
            image_in: A PIL image (converted to a BGR array) or an object
                with ``.copy()`` that OpenCV can draw on.
            bboxes_list: Iterable of boxes indexed as (x1, y1, x2, y2).
            name: Output file name; ``.jpg`` is appended unless
                ``ImageHelper.is_img`` accepts it as is.
            sub_dir: Sub-directory (under project_dir/DET_DIR) to write into.
        """
        out_dir = os.path.join(self.configer.get('project_dir'), DET_DIR,
                               sub_dir)

        if isinstance(image_in, Image.Image):
            canvas = ImageHelper.rgb2bgr(ImageHelper.to_np(image_in))
        else:
            canvas = image_in.copy()

        if not os.path.exists(out_dir):
            log.error('Dir:{} not exists!'.format(out_dir))
            os.makedirs(out_dir)

        if ImageHelper.is_img(name):
            file_name = name
        else:
            file_name = '{}.jpg'.format(name)

        out_path = os.path.join(out_dir, file_name)

        for box in bboxes_list:
            top_left = (box[0], box[1])
            bottom_right = (box[2], box[3])
            canvas = cv2.rectangle(canvas, top_left, bottom_right,
                                   (0, 255, 0), 2)

        cv2.imwrite(out_path, canvas)
예제 #12
0
    def get_backbone(self, **params):
        """Instantiate the backbone named by ``network.backbone``.

        Args:
            **params: Keyword arguments forwarded unchanged to the selected
                backbone factory call.

        Returns:
            The object returned by calling the matching ``*Backbone`` factory.

        The family is chosen by substring match on the configured name, in
        the order vgg, darknet, resnet, mobilenet, densenet, squeezenet. An
        unrecognized name is logged as an error and the process exits with
        status 1.
        """
        backbone = self.configer.get('network', 'backbone')

        model = None
        if 'vgg' in backbone:
            model = VGGBackbone(self.configer)(**params)

        elif 'darknet' in backbone:
            model = DarkNetBackbone(self.configer)(**params)

        elif 'resnet' in backbone:
            model = ResNetBackbone(self.configer)(**params)

        elif 'mobilenet' in backbone:
            # Must be ** like its siblings: a single * would pass the
            # dictionary's keys as positional arguments and drop the values.
            model = MobileNetBackbone(self.configer)(**params)

        elif 'densenet' in backbone:
            model = DenseNetBackbone(self.configer)(**params)

        elif 'squeezenet' in backbone:
            model = SqueezeNetBackbone(self.configer)(**params)

        else:
            Log.error('Backbone {} is invalid.'.format(backbone))
            exit(1)

        return model
예제 #13
0
    def vis_bboxes(self,
                   image_in,
                   bboxes_list,
                   name='default',
                   vis_dir=BBOX_DIR,
                   scale_factor=1,
                   img_size=None):
        """Draw filled green boxes on a copy of the image and save it.

        Args:
            image_in: Image with ``.copy()`` that OpenCV can draw on.
            bboxes_list: Iterable of boxes indexed as (x1, y1, x2, y2).
            name: Output file name without extension; ``.jpg`` is appended.
            vis_dir: Directory, relative to ``project_dir``, to write into.
            scale_factor: Forwarded to ``self.scale_image``.
            img_size: Forwarded to ``self.scale_image``.
        """
        base_dir = os.path.join(self.configer.get('project_dir'), vis_dir)

        image = image_in.copy()
        # Create the directory the file is actually written to. Checking the
        # bare ``vis_dir`` would test a path relative to the working directory.
        if not os.path.exists(base_dir):
            log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        img_path = os.path.join(base_dir, '{}.jpg'.format(name))

        for bbox in bboxes_list:
            # Thickness -1 asks OpenCV for a filled rectangle.
            image = cv2.rectangle(image, (bbox[0], bbox[1]),
                                  (bbox[2], bbox[3]), (0, 255, 0), -1)

        image = self.scale_image(image, scale_factor, img_size)
        cv2.imwrite(img_path, image)
예제 #14
0
    def vis_default_bboxes(self,
                           ori_img_in,
                           default_bboxes,
                           labels,
                           name='default',
                           sub_dir='encode'):
        """Draw the labelled default boxes on an image and save it as JPEG.

        Args:
            ori_img_in: Either a numpy image (copied as is) or a normalized
                tensor that is de-normalized and converted to a BGR array.
            default_bboxes: 2-D tensor whose first two columns are used as
                box centres and the remaining columns as box sizes, scaled
                by the image width/height when drawn (so presumably
                normalized to [0, 1] -- confirm with caller).
            labels: 1-D tensor with one label per box; label 0 is skipped.
            name: Output file name without extension.
            sub_dir: Sub-directory (under project_dir/DET_DIR) to write into.
        """
        out_dir = os.path.join(self.configer.get('project_dir'), DET_DIR,
                               sub_dir)
        if not os.path.exists(out_dir):
            log.error('Dir:{} not exists!'.format(out_dir))
            os.makedirs(out_dir)

        if isinstance(ori_img_in, np.ndarray):
            canvas = ori_img_in.copy()
        else:
            denormalize = DeNormalize(
                div_value=self.configer.get('normalize', 'div_value'),
                mean=self.configer.get('normalize', 'mean'),
                std=self.configer.get('normalize', 'std'))
            canvas = denormalize(ori_img_in.clone())
            # CHW tensor -> HWC uint8 array, then RGB -> BGR for OpenCV.
            canvas = canvas.data.cpu().squeeze().numpy().transpose(
                1, 2, 0).astype(np.uint8)
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)

        assert labels.size(0) == default_bboxes.size(0)

        # Centre/size boxes -> corner boxes (x1, y1, x2, y2).
        corner_boxes = torch.cat([
            default_bboxes[:, :2] - default_bboxes[:, 2:] / 2,
            default_bboxes[:, :2] + default_bboxes[:, 2:] / 2
        ], 1)
        height, width, _ = canvas.shape

        for box_idx in range(labels.size(0)):
            if labels[box_idx] == 0:
                continue

            class_name = self.configer.get('details',
                                           'name_seq')[labels[box_idx] - 1]
            color_num = len(self.configer.get('details', 'color_list'))

            box = corner_boxes[box_idx]
            cv2.rectangle(
                canvas,
                (int(box[0] * width), int(box[1] * height)),
                (int(box[2] * width), int(box[3] * height)),
                color=self.configer.get(
                    'details', 'color_list')[(labels[box_idx] - 1) % color_num],
                thickness=3)

            text_origin = (int(box[0] * width) + 5, int(box[3] * height) - 5)
            cv2.putText(canvas,
                        class_name,
                        text_origin,
                        cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.5,
                        color=self.configer.get(
                            'details',
                            'color_list')[(labels[box_idx] - 1) % color_num],
                        thickness=2)

        out_path = os.path.join(out_dir, '{}.jpg'.format(name))
        cv2.imwrite(out_path, canvas)
예제 #15
0
    def test(self):
        """Run pose inference on the configured test image or test directory.

        Exactly one of the config entries ``test_img`` and ``test_dir`` must
        be set; otherwise an error is logged and the process exits with
        status 1. Results are written below
        ``<project_dir>/val/results/pose/<dataset>/`` in ``test_img/`` or
        ``test_dir/<directory name>/`` respectively.
        """
        result_root = os.path.join(self.configer.get('project_dir'),
                                   'val/results/pose', self.configer.get('dataset'))

        test_img = self.configer.get('test_img')
        test_dir = self.configer.get('test_dir')

        if test_img is None and test_dir is None:
            Log.error('test_img & test_dir not exists.')
            exit(1)

        if test_img is not None and test_dir is not None:
            Log.error('Either test_img or test_dir.')
            exit(1)

        if test_img is not None:
            # Single-image mode: keep the source file name.
            out_dir = os.path.join(result_root, 'test_img')
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)

            image_name = test_img.rstrip().split('/')[-1]
            self.__test_img(test_img, os.path.join(out_dir, image_name))
            return

        # Directory mode: reproduce the listed relative paths in the output.
        dir_name = test_dir.rstrip('/').split('/')[-1]
        out_dir = os.path.join(result_root, 'test_dir',  dir_name)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        for filename in FileHelper.list_dir(test_dir):
            image_path = os.path.join(test_dir, filename)
            save_path = os.path.join(out_dir, filename)
            save_parent = os.path.dirname(save_path)
            if not os.path.exists(save_parent):
                os.makedirs(save_parent)

            self.__test_img(image_path, save_path)
예제 #16
0
    def __list_dirs(self, root_dir):
        """Collect parallel image/label/json/mask path lists for a dataset root.

        Args:
            root_dir: Directory containing ``image``, ``label``, ``json`` and
                ``mask`` sub-directories.

        Returns:
            ``(img_list, label_list, json_list, mask_list)``, one entry per
            file found in ``<root_dir>/json``. Image paths reuse the
            extension of the first file listed in ``<root_dir>/image``;
            label and mask paths always end in ``.png``.

        A missing json file is logged as an error and the process exits with
        status 1.
        """
        image_dir, label_dir, json_dir, mask_dir = (
            os.path.join(root_dir, sub)
            for sub in ('image', 'label', 'json', 'mask'))

        # Every image is assumed to share the extension of the first one.
        img_extension = os.listdir(image_dir)[0].rpartition('.')[2]

        img_list, label_list, json_list, mask_list = [], [], [], []
        for json_name in os.listdir(json_dir):
            # Stem = everything before the last dot (empty if there is none).
            stem = json_name.rpartition('.')[0]
            json_path = os.path.join(json_dir, json_name)

            img_list.append(
                os.path.join(image_dir, '{}.{}'.format(stem, img_extension)))
            label_list.append(os.path.join(label_dir, '{}.png'.format(stem)))
            mask_list.append(os.path.join(mask_dir, '{}.png'.format(stem)))
            json_list.append(json_path)

            # The path came from listing json_dir, so this only fails if the
            # file disappears in between.
            if not os.path.exists(json_path):
                Log.error('Json Path: {} not exists.'.format(json_path))
                exit(1)

        return img_list, label_list, json_list, mask_list
    def __call__(self, img, labelmap=None, maskmap=None):
        """Apply the configured augmentation sequence to an image and its maps.

        Args:
            img: Image to transform.
            labelmap: Optional label map transformed alongside the image.
            maskmap: Optional mask map transformed alongside the image.

        Returns:
            ``img`` alone, ``(img, labelmap)``, ``(img, maskmap)`` or
            ``(img, labelmap, maskmap)``, depending on which of the maps
            ``self.__check_none`` reports as present.

        For the ``'train'`` split, the optional
        ``train_trans.shuffle_trans_seq`` runs before the fixed
        ``train_trans.trans_seq``. If it is a list of lists, one inner list
        is picked at random; otherwise the sequence is applied in a random
        order. Any other split uses ``val_trans.trans_seq``.
        """
        if self.split == 'train':
            shuffle_trans_seq = []
            if self.configer.exists('train_trans', 'shuffle_trans_seq'):
                configured_seq = self.configer.get('train_trans', 'shuffle_trans_seq')
                if isinstance(configured_seq[0], list):
                    # random.randint(0, len(seq)) includes len(seq) and could
                    # index one past the end; random.choice stays in range.
                    shuffle_trans_seq = random.choice(configured_seq)
                else:
                    # Shuffle a copy so the list owned by the config is not
                    # reordered in place on every call.
                    shuffle_trans_seq = list(configured_seq)
                    random.shuffle(shuffle_trans_seq)

            for trans_key in (shuffle_trans_seq + self.configer.get('train_trans', 'trans_seq')):
                img, labelmap, maskmap = self.transforms[trans_key](img, labelmap, maskmap)

        else:
            for trans_key in self.configer.get('val_trans', 'trans_seq'):
                img, labelmap, maskmap = self.transforms[trans_key](img, labelmap, maskmap)

        if self.__check_none([labelmap, maskmap], ['n', 'n']):
            return img

        if self.__check_none([labelmap, maskmap], ['y', 'n']):
            return img, labelmap

        if self.__check_none([labelmap, maskmap], ['n', 'y']):
            return img, maskmap

        if self.__check_none([labelmap, maskmap], ['y', 'y']):
            return img, labelmap, maskmap

        Log.error('Params is not valid.')
        exit(1)
예제 #18
0
    def __init__(self, args_parser=None, hypes_file=None, config_dict=None):
        """Load the configuration tree from one of three sources.

        Args:
            args_parser: Parsed command-line namespace. Its ``hypes``
                attribute names the JSON file to load; every other attribute
                is merged into the tree, with ``:`` in an attribute name
                separating nesting levels.
            hypes_file: Path to a JSON configuration file.
            config_dict: Ready-made configuration dictionary.

        Precedence is ``config_dict``, then ``hypes_file``, then
        ``args_parser``. A missing JSON file is reported and the process
        exits with status 1.
        """
        if config_dict is not None:
            self.params_root = config_dict

        elif hypes_file is not None:
            if not os.path.exists(hypes_file):
                Log.error('Json Path:{} not exists!'.format(hypes_file))
                # Failure must not report success to the calling shell.
                exit(1)

            # ``with`` closes the file even if the JSON is malformed.
            with open(hypes_file, 'r') as json_stream:
                self.params_root = json.load(json_stream)

        elif args_parser is not None:
            self.args_dict = args_parser.__dict__
            self.params_root = None

            if not os.path.exists(args_parser.hypes):
                print('Json Path:{} not exists!'.format(args_parser.hypes))
                exit(1)

            with open(args_parser.hypes, 'r') as json_stream:
                self.params_root = json.load(json_stream)

            for key, value in self.args_dict.items():
                # New keys are added even when None; existing keys are only
                # overridden by arguments that were actually supplied.
                if not self.exists(*key.split(':')):
                    self.add(key.split(':'), value)
                elif value is not None:
                    self.update(key.split(':'), value)
예제 #19
0
    def parse_dir_det(self, image_dir, json_dir, mask_dir=None):
        """Show every image of a directory with its annotated boxes drawn.

        Args:
            image_dir: Directory of images to display.
            json_dir: Directory holding ``<image stem>.json`` annotations.
            mask_dir: Optional directory holding ``<image stem>_vis.png``
                masks, blended over the image at 40% weight.

        Returns early (after logging an error) when ``image_dir`` or
        ``json_dir`` is ``None`` or missing. Each image is shown in a window
        named ``'main'`` until a key is pressed.
        """
        if image_dir is None or not os.path.exists(image_dir):
            Log.error('Image Dir: {} not existed.'.format(image_dir))
            return

        if json_dir is None or not os.path.exists(json_dir):
            Log.error('Json Dir: {} not existed.'.format(json_dir))
            return

        for image_file in os.listdir(image_dir):
            stem = os.path.splitext(image_file)[0]
            Log.info(image_file)

            # cv2.imread yields channels in B, G, R order.
            canvas = cv2.imread(os.path.join(image_dir, image_file))

            json_path = os.path.join(json_dir, '{}.json'.format(stem))
            with open(json_path, 'r') as json_stream:
                info_tree = json.load(json_stream)
                canvas = self.draw_bboxes(canvas, info_tree)

            if mask_dir is not None:
                mask_path = os.path.join(mask_dir, '{}_vis.png'.format(stem))
                mask_canvas = cv2.imread(mask_path)
                canvas = cv2.addWeighted(canvas, 0.6, mask_canvas, 0.4, 0)

            cv2.imshow('main', canvas)
            cv2.waitKey()
예제 #20
0
    def test(self, test_dir, out_dir):
        """Segment every image of ``test_dir`` and save results in ``out_dir``.

        Args:
            test_dir: Directory handed to ``self.test_loader.get_testloader``.
            out_dir: Output root; colorized overlays go to ``vis/<name>.png``
                and label images to ``label/<name>.png``.

        The inference routine is selected by the ``test.mode`` config entry
        (``ss_test``, ``sscrop_test``, ``ms_test`` or ``mscrop_test``). An
        unknown mode is logged as an error and the process exits with
        status 1.
        """
        for data_dict in self.test_loader.get_testloader(test_dir=test_dir):
            mode = self.configer.get('test', 'mode')
            if mode == 'ss_test':
                total_logits = self.ss_test(data_dict)

            elif mode in ('sscrop_test', 'ms_test', 'mscrop_test'):
                # These methods share their name with the config section
                # that holds their parameters.
                total_logits = getattr(self, mode)(
                    data_dict, params_dict=self.configer.get('test', mode))

            else:
                total_logits = None
                Log.error('Invalid test mode:{}'.format(
                    self.configer.get('test', 'mode')))
                exit(1)

            meta_list = DCHelper.tolist(data_dict['meta'])
            img_list = DCHelper.tolist(data_dict['img'])
            for idx in range(len(meta_list)):
                meta = meta_list[idx]
                filename = meta['img_path'].split('/')[-1].split('.')[0]

                # Per-pixel class = index of the largest logit on the last axis.
                label_img = np.array(np.argmax(total_logits[idx], axis=-1),
                                     dtype=np.uint8)

                ori_img_bgr = self.blob_helper.tensor2bgr(img_list[idx][0])
                ori_img_bgr = ImageHelper.resize(
                    ori_img_bgr,
                    target_size=meta['ori_img_size'],
                    interpolation='linear')
                image_canvas = self.seg_parser.colorize(
                    label_img, image_canvas=ori_img_bgr)
                vis_path = os.path.join(out_dir, 'vis/{}.png'.format(filename))
                ImageHelper.save(image_canvas, save_path=vis_path)

                if self.configer.exists('data', 'label_list'):
                    label_img = self.__relabel(label_img)

                reduce_zero = (
                    self.configer.exists('data', 'reduce_zero_label')
                    and self.configer.get('data', 'reduce_zero_label'))
                if reduce_zero:
                    # Shift predictions up by one class id before saving.
                    label_img = label_img + 1
                    label_img = label_img.astype(np.uint8)

                label_img = Image.fromarray(label_img, 'P')
                label_path = os.path.join(out_dir,
                                          'label/{}.png'.format(filename))
                Log.info('Label Path: {}'.format(label_path))
                ImageHelper.save(label_img, label_path)
예제 #21
0
파일: dpn.py 프로젝트: zy0851/TorchCV
def get_densenet(name='dpn_26', num_classes=10):
    """Build the DPN model selected by ``name``.

    Args:
        name: Model identifier; ``'dpn_26'`` and ``'dpn_92'`` are recognised.
        num_classes: Value forwarded to the model constructor.

    Returns:
        A ``DPN26`` instance for ``'dpn_26'``, a ``DPN92`` instance for
        ``'dpn_92'``.

    Any other ``name`` is logged as an error and the process exits with
    status 1.
    """
    if name not in ('dpn_26', 'dpn_92'):
        Log.error('Model: {} not valid!'.format(name))
        exit(1)

    model_cls = DPN26 if name == 'dpn_26' else DPN92
    return model_cls(num_classes)
예제 #22
0
    def get_cls_loss(self, key):
        """Instantiate the classification loss registered under ``key``.

        Args:
            key: Name looked up in ``CLS_LOSS_DICT``.

        Returns:
            The object produced by calling the registered entry with
            ``self.configer``.

        An unregistered ``key`` is logged as an error and the process exits
        with status 1.
        """
        if key in CLS_LOSS_DICT:
            return CLS_LOSS_DICT[key](self.configer)

        Log.error('Loss: {} not valid!'.format(key))
        exit(1)
예제 #23
0
 def read_image(image_path, tool='pil', mode='RGB'):
     """Read an image with the requested backend.

     Args:
         image_path: Path handed to the backend reader.
         tool: Backend selector; ``'pil'`` or ``'cv2'``.
         mode: Colour mode forwarded unchanged to the backend reader.

     Returns:
         Whatever ``ImageHelper.pil_read_image`` or
         ``ImageHelper.cv2_read_image`` returns for the given arguments.

     Any other ``tool`` is logged as an error and the process exits with
     status 1.
     """
     if tool == 'pil':
         return ImageHelper.pil_read_image(image_path, mode=mode)

     if tool == 'cv2':
         return ImageHelper.cv2_read_image(image_path, mode=mode)

     # The rejected value is the backend name, so report `tool` (the message
     # used to print `mode`, which pointed at the wrong argument).
     Log.error('Not support tool {}'.format(tool))
     exit(1)
예제 #24
0
    def save_net(self, net, save_mode='iters'):
        """Checkpoint ``net`` when the criterion chosen by ``save_mode`` is met.

        The checkpoint holds the configer contents (``config_dict``) and the
        network weights (``state_dict``). It is written below
        ``checkpoints.checkpoints_root`` when that entry is set, otherwise
        below ``project_dir``; the directory is created if it is missing.

        Args:
            net: Object exposing ``state_dict()``.
            save_mode: One of
                ``'performance'`` - save when ``performance`` exceeds
                ``max_performance``;
                ``'val_loss'`` - save when ``val_loss`` is below
                ``min_val_loss``;
                ``'iters'`` - save when at least ``checkpoints.save_iters``
                iterations passed since ``last_iters``;
                ``'epoch'`` - save when at least ``checkpoints.save_epoch``
                epochs passed since ``last_epoch``.

        After each save the matching tracker entry in the configer is updated.
        Any other ``save_mode`` is logged as an error and the process exits
        with status 1.
        """
        cfg = self.configer
        state = {
            'config_dict': cfg.to_dict(),
            'state_dict': net.state_dict(),
        }

        base_dir = cfg.get('checkpoints', 'checkpoints_root')
        if base_dir is None:
            base_dir = cfg.get('project_dir')

        checkpoints_dir = os.path.join(
            base_dir, cfg.get('checkpoints', 'checkpoints_dir'))
        if not os.path.exists(checkpoints_dir):
            os.makedirs(checkpoints_dir)

        def _write(file_name, tracker_key, source_key):
            # Persist the snapshot, then record the value that triggered it.
            torch.save(state, os.path.join(checkpoints_dir, file_name))
            cfg.update_value([tracker_key], cfg.get(source_key))

        if save_mode == 'performance':
            if cfg.get('performance') > cfg.get('max_performance'):
                _write(
                    '{}_max_performance.pth'.format(
                        cfg.get('checkpoints', 'checkpoints_name')),
                    'max_performance', 'performance')

        elif save_mode == 'val_loss':
            if cfg.get('val_loss') < cfg.get('min_val_loss'):
                _write(
                    '{}_min_loss.pth'.format(
                        cfg.get('checkpoints', 'checkpoints_name')),
                    'min_val_loss', 'val_loss')

        elif save_mode == 'iters':
            iters_since_save = cfg.get('iters') - cfg.get('last_iters')
            if iters_since_save >= cfg.get('checkpoints', 'save_iters'):
                _write(
                    '{}_iters{}.pth'.format(
                        cfg.get('checkpoints', 'checkpoints_name'),
                        cfg.get('iters')),
                    'last_iters', 'iters')

        elif save_mode == 'epoch':
            epochs_since_save = cfg.get('epoch') - cfg.get('last_epoch')
            if epochs_since_save >= cfg.get('checkpoints', 'save_epoch'):
                _write(
                    '{}_epoch{}.pth'.format(
                        cfg.get('checkpoints', 'checkpoints_name'),
                        cfg.get('epoch')),
                    'last_epoch', 'epoch')

        else:
            Log.error('Metric: {} is invalid.'.format(save_mode))
            exit(1)
예제 #25
0
    def get_pose_loss(self, loss_type=None):
        """Instantiate the pose loss and pass it through ``self._parallel``.

        Args:
            loss_type: Name looked up in ``POSE_LOSS_DICT``. When ``None`` the
                configer entry ``loss.loss_type`` is used instead.

        Returns:
            The result of ``self._parallel`` applied to the loss built from
            ``self.configer``.

        An unregistered name is logged as an error and the process exits with
        status 1.
        """
        if loss_type is None:
            key = self.configer.get('loss', 'loss_type')
        else:
            key = loss_type

        if key in POSE_LOSS_DICT:
            return self._parallel(POSE_LOSS_DICT[key](self.configer))

        Log.error('Loss: {} not valid!'.format(key))
        exit(1)
예제 #26
0
    def load_file(json_file):
        """Parse a JSON file and return the decoded object.

        Args:
            json_file: Path of the file to read.

        Returns:
            The value produced by ``json.load`` for the file contents.

        A path that does not exist is logged as an error and the process
        exits with status 1.
        """
        if os.path.exists(json_file):
            with open(json_file, 'r') as read_stream:
                return json.load(read_stream)

        Log.error('Json file: {} not exists.'.format(json_file))
        exit(1)
예제 #27
0
    def xml2json(xml_file, json_file):
        """Validate the XML input and make sure the JSON output directory exists.

        As written here, the function only checks the input and prepares the
        destination directory.

        Args:
            xml_file: Path of the source XML file; must exist.
            json_file: Path of the JSON file to be produced. Its parent
                directory is created when missing.

        A missing ``xml_file`` is logged as an error and the process exits
        with status 1.
        """
        if not os.path.exists(xml_file):
            Log.error('Xml file: {} not exists.'.format(xml_file))
            exit(1)

        json_dir_name = os.path.dirname(json_file)
        # A bare file name yields an empty dirname; os.makedirs('') raises,
        # and the current directory needs no creation anyway.
        if json_dir_name and not os.path.exists(json_dir_name):
            Log.info('Json Dir: {} not exists.'.format(json_dir_name))
            # exist_ok tolerates the directory appearing between the check
            # above and this call.
            os.makedirs(json_dir_name, exist_ok=True)
예제 #28
0
    def json2xml(json_file, xml_file):
        """Validate the JSON input and make sure the XML output directory exists.

        As written here, the function only checks the input and prepares the
        destination directory.

        Args:
            json_file: Path of the source JSON file; must exist.
            xml_file: Path of the XML file to be produced. Its parent
                directory is created when missing.

        A missing ``json_file`` is logged as an error and the process exits
        with status 1.
        """
        if not os.path.exists(json_file):
            Log.error('Json file: {} not exists.'.format(json_file))
            exit(1)

        xml_dir_name = os.path.dirname(xml_file)
        # A bare file name yields an empty dirname; os.makedirs('') raises,
        # and the current directory needs no creation anyway.
        if xml_dir_name and not os.path.exists(xml_dir_name):
            Log.info('Xml Dir: {} not exists.'.format(xml_dir_name))
            # exist_ok tolerates the directory appearing between the check
            # above and this call.
            os.makedirs(xml_dir_name, exist_ok=True)
예제 #29
0
    def load_net(self, net):
        """Move ``net`` to the configured device and optionally restore weights.

        When the configer entry ``gpu`` is set, the network is first wrapped by
        ``self._make_parallel`` and moved to ``'cuda'``; otherwise it stays on
        ``'cpu'``. When ``network.resume`` names a checkpoint, its weights are
        merged into the network's state dict:

        * the weights are taken from the checkpoint's ``'state_dict'`` or
          ``'model'`` entry, or from the checkpoint itself if neither exists;
        * each key is normalised to carry (or drop) the ``module.`` prefix
          depending on ``network.parallel``;
        * a tensor is copied only if the network has that key with the same
          size; every other key is collected as "not matched".

        ``network.resume_level`` decides how mismatches are treated:
        ``'full'`` requires none, ``'part'`` logs them, anything else is logged
        as an error and the process exits with status 1. When
        ``network.resume_continue`` is truthy, ``epoch``, ``iters``,
        ``performance`` and ``val_loss`` are restored from the checkpoint's
        ``config_dict``.

        Args:
            net: The network to prepare.

        Returns:
            The (possibly wrapped) network on its target device.

        Raises:
            AssertionError: ``resume_level`` is ``'full'`` and some checkpoint
                keys did not match the network.
        """
        use_gpu = self.configer.get('gpu') is not None
        if use_gpu:
            net = self._make_parallel(net)

        net = net.to(torch.device('cuda' if use_gpu else 'cpu'))

        resume_path = self.configer.get('network', 'resume')
        if resume_path is None:
            return net

        # NOTE(security): torch.load unpickles the file; only resume from
        # checkpoints that come from a trusted source.
        # map_location='cpu' keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict below copies the values
        # onto the network's own device.
        resume_dict = torch.load(resume_path, map_location='cpu')
        if 'state_dict' in resume_dict:
            checkpoint_dict = resume_dict['state_dict']
        elif 'model' in resume_dict:
            checkpoint_dict = resume_dict['model']
        else:
            checkpoint_dict = resume_dict

        net_dict = net.state_dict()
        use_module_prefix = self.configer.get('network', 'parallel')

        not_match_list = list()
        for key, value in checkpoint_dict.items():
            if key.split('.')[0] == 'module':
                module_key = key
                norm_key = '.'.join(key.split('.')[1:])
            else:
                module_key = 'module.{}'.format(key)
                norm_key = key

            key = module_key if use_module_prefix else norm_key

            # Keys missing from the network are reported as unmatched rather
            # than raising KeyError, so a partial resume can skip them.
            if key in net_dict and net_dict[key].size() == value.size():
                net_dict[key] = value
            else:
                not_match_list.append(key)

        resume_level = self.configer.get('network', 'resume_level')
        if resume_level == 'full':
            assert len(not_match_list) == 0, \
                'Not Matched Keys: {}'.format(not_match_list)

        elif resume_level == 'part':
            Log.info('Not Matched Keys: {}'.format(not_match_list))

        else:
            Log.error('Resume Level: {} is invalid.'.format(resume_level))
            exit(1)

        if self.configer.get('network', 'resume_continue'):
            saved_config = resume_dict['config_dict']
            for field in ('epoch', 'iters', 'performance', 'val_loss'):
                self.configer.update_value([field], saved_config[field])

        net.load_state_dict(net_dict)

        return net
예제 #30
0
    def select_seg_method(self):
        """Instantiate the segmentation runner for the configured method.

        The configer entry ``method`` must be registered in both
        ``SEG_METHOD_DICT`` and ``SEG_TEST_DICT``; otherwise an error is
        logged and the process exits with status 1.

        Returns:
            The ``SEG_METHOD_DICT`` entry built with ``self.configer`` when
            the configer entry ``phase`` is ``'train'``, else the
            ``SEG_TEST_DICT`` entry built the same way.
        """
        key = self.configer.get('method')
        if key not in SEG_METHOD_DICT or key not in SEG_TEST_DICT:
            # This is the segmentation selector; the message previously said
            # "Det Method", which mislabelled the failing component.
            Log.error('Seg Method: {} is not valid.'.format(key))
            exit(1)

        if self.configer.get('phase') == 'train':
            return SEG_METHOD_DICT[key](self.configer)

        return SEG_TEST_DICT[key](self.configer)