def update(self, key, value, append=False):
        """Store ``value`` under a key that is already registered.

        Args:
            key: Key looked up in ``self.params_root``.
            value: Value forwarded to ``self.params_root.put``.
            append: Forwarded unchanged as the third argument of ``put``.

        When ``key`` is not in ``self.params_root`` an error tagged with the
        caller is logged and the process exits with status 1.
        """
        if key in self.params_root:
            self.params_root.put(key, value, append)
            return

        Log.error('{} Key: {} not existed!!!'.format(
            self._get_caller(), key))
        exit(1)
Exemple #2
0
    def get_valloader(self, dataset=None):
        """Build the validation ``DataLoader``.

        Args:
            dataset: Split name handed to ``get_dataloader_sampler``;
                ``None`` selects ``'val'``.

        Returns:
            The configured ``data.DataLoader``, or ``None`` (after logging an
            error) when no loader class matches the configuration.
        """
        if dataset is None:
            dataset = 'val'

        wants_offset = (self.configer.exists('data', 'use_dt_offset')
                        or self.configer.exists('data', 'pred_dt_offset'))
        if wants_offset:
            # dt-offset manner: load both the ground-truth label and the
            # offset (based on distance transform).
            Log.info('use distance transform based offset loader for val ...')
            klass = DTOffsetLoader
        elif self.configer.get('method') == 'fcn_segmentor':
            # default manner: load the ground-truth label.
            Log.info('use DefaultLoader for val ...')
            klass = DefaultLoader
        else:
            Log.error('Method: {} loader is invalid.'.format(self.configer.get('method')))
            return None

        loader, sampler = self.get_dataloader_sampler(klass, 'val', dataset)

        def _collate_batch(*args):
            # The transformer settings are read from the configer on each call.
            return collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer'))

        # The configured batch size is divided across the world size.
        per_process_batch = self.configer.get('val', 'batch_size') // get_world_size()
        worker_count = self.configer.get('data', 'workers')
        return data.DataLoader(
            loader,
            sampler=sampler,
            batch_size=per_process_batch,
            pin_memory=True,
            num_workers=worker_count,
            shuffle=False,
            collate_fn=_collate_batch,
        )
    def __list_dirs(self, root_dir, dataset):
        """Collect image paths and their base names for one dataset split.

        Args:
            root_dir: Root directory of the dataset.
            dataset: Sub-directory of ``root_dir`` to scan.

        Returns:
            ``(img_list, name_list)``: image paths and, index-aligned, the
            part of each file name before its first ``'.'``.
        """
        img_list = list()
        name_list = list()
        image_dir = os.path.join(root_dir, dataset)

        if self.configer.get('dataset') == 'cityscapes':
            # For cityscapes every entry of image_dir is itself a directory
            # that holds the images.
            scan_dirs = (os.path.join(image_dir, item)
                         for item in os.listdir(image_dir))
        else:
            scan_dirs = (image_dir,)

        # The previous version computed an unused image extension via
        # os.listdir(image_dir)[0], which raised IndexError on an empty
        # directory; an empty directory now simply yields empty lists.
        for scan_dir in scan_dirs:
            for file_name in os.listdir(scan_dir):
                image_name = file_name.split('.')[0]
                img_path = os.path.join(scan_dir, file_name)
                # Listed entries can still fail this check (e.g. dangling
                # symlinks), so it is kept.
                if not os.path.exists(img_path):
                    Log.error('Image Path: {} not exists.'.format(img_path))
                    continue
                img_list.append(img_path)
                name_list.append(image_name)

        return img_list, name_list
Exemple #4
0
    def __init__(self, configer):
        """Set up the augmentation, image and label transforms.

        Args:
            configer: Configuration object; ``data.image_tool`` selects the
                PIL or cv2 augmentation pipeline and the ``normalize`` section
                provides ``div_value``, ``mean`` and ``std``.

        An unsupported image tool is logged and the process exits with
        status 1.
        """
        self.configer = configer

        image_tool = self.configer.get('data', 'image_tool')
        if image_tool == 'pil':
            compose_cls = pil_aug_trans.PILAugCompose
        elif image_tool == 'cv2':
            compose_cls = cv2_aug_trans.CV2AugCompose
        else:
            Log.error('Not support {} image tool.'.format(image_tool))
            exit(1)

        # Train and val share the compose class and differ only in the split.
        self.aug_train_transform = compose_cls(self.configer, split='train')
        self.aug_val_transform = compose_cls(self.configer, split='val')

        normalize = trans.Normalize(
            div_value=self.configer.get('normalize', 'div_value'),
            mean=self.configer.get('normalize', 'mean'),
            std=self.configer.get('normalize', 'std'))
        self.img_transform = trans.Compose([trans.ToTensor(), normalize])

        self.label_transform = trans.Compose([
            trans.ToLabel(),
            trans.ReLabel(255, -1),
        ])
Exemple #5
0
    def get_valloader(self, loader_type=None, data_dir=None, batch_size=None):
        """Build the validation ``DataLoader``.

        Args:
            loader_type: Loader kind; ``None`` falls back to the configured
                ``val.loader``. Only ``None`` and ``'default'`` are supported.
            data_dir: Dataset directory; ``None`` falls back to
                ``data.data_dir``.
            batch_size: Batch size; ``None`` falls back to ``val.batch_size``.

        Returns:
            A ``data.DataLoader`` over ``DefaultDataset`` for the ``'val'``
            split. An unsupported loader type is logged and the process exits
            with status 1.
        """
        if loader_type is None:
            loader_type = self.configer.get('val', 'loader')
        if data_dir is None:
            data_dir = self.configer.get('data', 'data_dir')
        if batch_size is None:
            batch_size = self.configer.get('val', 'batch_size')

        if loader_type is not None and loader_type != 'default':
            # Report the value that was actually rejected; it may have been
            # passed by the caller rather than read from the configuration.
            Log.error('{} val loader is invalid.'.format(loader_type))
            exit(1)

        valloader = data.DataLoader(
            DefaultDataset(data_dir=data_dir,
                           dataset='val',
                           aug_transform=self.aug_val_transform,
                           img_transform=self.img_transform,
                           configer=self.configer),
            batch_size=batch_size,
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=collate)

        return valloader
Exemple #6
0
    def get_valloader(self, dataset=None):
        """Build the validation ``DataLoader`` for the ``fcn_segmentor`` method.

        Args:
            dataset: Split name passed to ``DefaultLoader``; ``None`` selects
                ``'val'``.

        Returns:
            The ``data.DataLoader``, or ``None`` (after logging an error) when
            the configured method is not ``'fcn_segmentor'``.
        """
        if dataset is None:
            dataset = 'val'

        if self.configer.get('method') != 'fcn_segmentor':
            Log.error('Method: {} loader is invalid.'.format(
                self.configer.get('method')))
            return None

        # default manner: load the ground-truth label.
        Log.info('use DefaultLoader for val ...')
        val_dataset = DefaultLoader(
            root_dir=self.configer.get('data', 'data_dir'),
            dataset=dataset,
            aug_transform=self.aug_val_transform,
            img_transform=self.img_transform,
            label_transform=self.label_transform,
            configer=self.configer)

        def _collate_batch(*args):
            # The transformer settings are read from the configer on each call.
            return collate(
                *args,
                trans_dict=self.configer.get('val', 'data_transformer'))

        return data.DataLoader(
            val_dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            pin_memory=True,
            num_workers=self.configer.get('data', 'workers'),
            shuffle=False,
            collate_fn=_collate_batch)
    def Linear(linear_type):
        """Return the linear-layer factory selected by ``linear_type``.

        Supported values:
            ``'default'``: ``Linear`` itself.
            ``'nobias'``: ``Linear`` with ``bias=False``.
            ``'[easy]arc<margin>_<scale>'``: ``ArcLinear``, e.g. ``arc0.5_64``,
                ``arc0.32_64``, ``easyarc0.5_64``.
            ``'cos0.4_30'`` / ``'cos0.4_64'``: ``CosineLinear``.
            ``'sphere4'``: ``SphereLinear`` with ``m=4``.

        Any other value, including an ``arc`` tag whose margin/scale cannot be
        parsed, is logged and the process exits with status 1.
        """
        if linear_type == 'default':
            return Linear

        if linear_type == 'nobias':
            return functools.partial(Linear, bias=False)

        if 'arc' in linear_type:
            # example arc0.5_64  arc0.32_64 easyarc0.5_64
            try:
                margin_text, scale_text = linear_type.split('arc')[1].split('_')[:2]
                margin = float(margin_text)
                scale = float(scale_text)
            except (IndexError, ValueError):
                # A malformed tag takes the same error path as an unknown one
                # instead of surfacing a raw parsing exception.
                Log.error('Not support linear type: {}.'.format(linear_type))
                exit(1)

            return functools.partial(ArcLinear,
                                     s=scale,
                                     m=margin,
                                     easy_margin='easy' in linear_type)

        # NOTE(review): these tags previously passed m=0.5 although they name
        # a 0.4 margin; the margin now follows the tag, matching the
        # <margin>_<scale> convention of the arc tags. Confirm no checkpoint
        # relies on the old value.
        if linear_type == 'cos0.4_30':
            return functools.partial(CosineLinear, s=30, m=0.4)

        if linear_type == 'cos0.4_64':
            return functools.partial(CosineLinear, s=64, m=0.4)

        if linear_type == 'sphere4':
            return functools.partial(SphereLinear, m=4)

        Log.error('Not support linear type: {}.'.format(linear_type))
        exit(1)
Exemple #8
0
    def __init__(self, args_parser=None, configs=None, config_dict=None):
        """Initialise the configuration from exactly one source.

        Sources are tried in this order:
            config_dict: Used directly as ``params_root``.
            configs: Path of a JSON file loaded into ``params_root``.
            args_parser: Object whose ``configs`` attribute names the JSON
                file; its attributes are then merged in, with ``':'`` in an
                attribute name separating nested keys. Unknown keys are added,
                known keys are overwritten only by non-``None`` values.

        A missing JSON file is reported and the process exits with status 1.
        When all three arguments are ``None`` nothing is initialised.
        """
        if config_dict is not None:
            self.params_root = config_dict

        elif configs is not None:
            if not os.path.exists(configs):
                Log.error('Json Path:{} not exists!'.format(configs))
                # A missing config is a failure, so exit non-zero (this used
                # to be exit(0), which reported success to the caller).
                exit(1)

            # The context manager closes the file even if parsing fails.
            with open(configs, 'r') as json_stream:
                self.params_root = json.load(json_stream)

        elif args_parser is not None:
            self.args_dict = args_parser.__dict__
            self.params_root = None

            if not os.path.exists(args_parser.configs):
                print('Json Path:{} not exists!'.format(args_parser.configs))
                exit(1)

            with open(args_parser.configs, 'r') as json_stream:
                self.params_root = json.load(json_stream)

            for key, value in self.args_dict.items():
                key_parts = key.split(':')
                if not self.exists(*key_parts):
                    self.add(key_parts, value)
                elif value is not None:
                    self.update(key_parts, value)
    def _relabel(self):
        """Rewrite the label file with consecutive integer class ids.

        Reads ``data.label_path`` line by line as ``<item> <label>``, skips
        items that do not exist under ``data.data_dir``, assigns ids in
        first-seen order starting at 0 and writes ``<item> <id>`` lines to
        ``<label_path>_new``. The configer's ``data.label_path`` is switched
        to the new file and ``data.num_classes`` is set to ``[number of ids]``.

        A duplicated item is logged and the process exits with status 1.
        """
        label_dict = dict()
        seen_items = set()
        old_label_path = self.configer.get('data', 'label_path')
        new_label_path = '{}_new'.format(old_label_path)
        self.configer.update('data.label_path', new_label_path)
        data_dir = self.configer.get('data', 'data_dir')

        # Both files are context-managed so the output is flushed and closed
        # on every path, including the exit on a duplicate.
        with open(old_label_path, 'r') as fr, open(new_label_path, 'w') as fw:
            for line in fr:
                line_items = line.strip().split()
                if len(line_items) < 2:
                    # Blank or truncated lines carry no item/label pair.
                    continue

                item, label = line_items[0], line_items[1]
                if not os.path.exists(os.path.join(data_dir, item)):
                    continue

                # First-seen order defines the new id.
                new_id = label_dict.setdefault(label, len(label_dict))

                if item in seen_items:
                    Log.error('Duplicate Error: {}'.format(item))
                    # Exit non-zero; bare exit() reported success.
                    exit(1)

                seen_items.add(item)
                fw.write('{} {}\n'.format(item, new_id))

        # data.label_path was updated above, so this reads it back from the
        # configer before copying.
        shutil.copy(self.configer.get('data', 'label_path'),
                    os.path.join(self.configer.get('data', 'merge_dir'), 'ori_label.txt'))
        self.configer.update(('data.num_classes'), [len(label_dict)])
        Log.info('Num Classes is {}...'.format(self.configer.get('data', 'num_classes')))
Exemple #10
0
    def __list_dirs(self, root_dir, dataset):
        """Collect image and offset file paths for one dataset split.

        Args:
            root_dir: Root directory of the dataset.
            dataset: Split sub-directory of ``root_dir``.

        Returns:
            ``(img_list, offset_h_list, offset_w_list, name_list)``, all
            index-aligned. Entries whose label or image file is missing are
            logged and skipped.
        """
        img_list = list()
        offset_h_list = list()
        offset_w_list = list()
        name_list = list()
        image_dir = os.path.join(root_dir, dataset, 'image')
        # label_dir / label_path were referenced but never defined, so the
        # method raised NameError. NOTE(review): 'label' is assumed to sit
        # next to 'image' under the split directory - confirm the layout.
        label_dir = os.path.join(root_dir, dataset, 'label')

        offset_type = self.configer.get('data', 'offset_type')
        assert (offset_type is not None)
        offset_h_dir = os.path.join(root_dir, dataset, offset_type, 'h')
        offset_w_dir = os.path.join(root_dir, dataset, offset_type, 'w')
        # Every image is assumed to share the extension of the first listed
        # file.
        img_extension = os.listdir(image_dir)[0].split('.')[-1]

        for file_name in os.listdir(label_dir):
            image_name = '.'.join(file_name.split('.')[:-1])
            img_path = os.path.join(image_dir,
                                    '{}.{}'.format(image_name, img_extension))
            label_path = os.path.join(label_dir, file_name)
            offset_h_path = os.path.join(offset_h_dir,
                                         self._replace_ext(file_name, 'mat'))
            offset_w_path = os.path.join(offset_w_dir,
                                         self._replace_ext(file_name, 'mat'))

            if not os.path.exists(label_path) or not os.path.exists(img_path):
                Log.error('Label Path: {} not exists.'.format(label_path))
                continue
            img_list.append(img_path)
            offset_h_list.append(offset_h_path)
            offset_w_list.append(offset_w_path)
            name_list.append(image_name)

        return img_list, offset_h_list, offset_w_list, name_list
    def Linear(linear_type):
        """Return the linear-layer factory selected by ``linear_type``.

        Supported values: ``'default'``, ``'nobias'``, ``'arc0.5_30'``,
        ``'arc0.5_64'``, ``'easyarc0.5_30'``, ``'easyarc0.5_64'``,
        ``'cos0.4_30'``, ``'cos0.4_64'`` and ``'sphere4'``. Any other value is
        logged and the process exits with status 1.
        """
        # Tags follow <kind><margin>_<scale>; the tables hold the keyword
        # arguments bound to the layer class for each tag.
        arc_settings = {
            'arc0.5_30': dict(s=30, m=0.5, easy_margin=False),
            'arc0.5_64': dict(s=64, m=0.5, easy_margin=False),
            'easyarc0.5_30': dict(s=30, m=0.5, easy_margin=True),
            'easyarc0.5_64': dict(s=64, m=0.5, easy_margin=True),
        }
        # NOTE(review): the cos tags previously passed m=0.5 although they
        # name a 0.4 margin; the margin now follows the tag. Confirm no
        # checkpoint relies on the old value.
        cos_settings = {
            'cos0.4_30': dict(s=30, m=0.4),
            'cos0.4_64': dict(s=64, m=0.4),
        }

        if linear_type == 'default':
            return Linear

        if linear_type == 'nobias':
            return functools.partial(Linear, bias=False)

        if linear_type in arc_settings:
            return functools.partial(ArcLinear, **arc_settings[linear_type])

        if linear_type in cos_settings:
            return functools.partial(CosineLinear, **cos_settings[linear_type])

        if linear_type == 'sphere4':
            return functools.partial(SphereLinear, m=4)

        Log.error('Not support linear type: {}.'.format(linear_type))
        exit(1)
 def get_seg_loss(self, loss_type=None):
     """Instantiate the segmentation loss and pass it through ``_parallel``.

     Args:
         loss_type: Key into ``SEG_LOSS_DICT``; ``None`` falls back to the
             configured ``loss.loss_type``.

     An unknown key is logged and the process exits with status 1.
     """
     key = loss_type
     if key is None:
         key = self.configer.get('loss', 'loss_type')

     if key in SEG_LOSS_DICT:
         Log.info('use loss: {}.'.format(key))
         return self._parallel(SEG_LOSS_DICT[key](self.configer))

     Log.error('Loss: {} not valid!'.format(key))
     exit(1)
Exemple #13
0
 def read_image(image_path, tool='pil', mode='RGB'):
     """Read an image with the requested backend.

     Args:
         image_path: Path of the image file.
         tool: ``'pil'`` or ``'cv2'``; selects the ``ImageHelper`` reader.
         mode: Forwarded to the selected reader.

     An unsupported ``tool`` is logged and the process exits with status 1.
     """
     if tool == 'pil':
         return ImageHelper.pil_read_image(image_path, mode=mode)
     elif tool == 'cv2':
         return ImageHelper.cv2_read_image(image_path, mode=mode)
     else:
         # This branch rejects the tool, so report the tool (the message
         # used to print the mode).
         Log.error('Not support tool {}'.format(tool))
         exit(1)
    def xml2json(xml_file, json_file):
        """Validate the input file and prepare the output directory.

        Args:
            xml_file: Path of the source XML file; must exist, otherwise an
                error is logged and the process exits with status 1.
            json_file: Path of the JSON file to produce; its parent directory
                is created when missing.
        """
        if not os.path.exists(xml_file):
            Log.error('Xml file: {} not exists.'.format(xml_file))
            exit(1)

        json_dir_name = os.path.dirname(json_file)
        # A bare file name has an empty dirname, which means the current
        # directory; os.makedirs('') would raise FileNotFoundError.
        if json_dir_name and not os.path.exists(json_dir_name):
            Log.info('Json Dir: {} not exists.'.format(json_dir_name))
            # exist_ok tolerates the directory appearing after the check.
            os.makedirs(json_dir_name, exist_ok=True)
    def json2xml(json_file, xml_file):
        """Validate the input file and prepare the output directory.

        Args:
            json_file: Path of the source JSON file; must exist, otherwise an
                error is logged and the process exits with status 1.
            xml_file: Path of the XML file to produce; its parent directory is
                created when missing.
        """
        if not os.path.exists(json_file):
            Log.error('Json file: {} not exists.'.format(json_file))
            exit(1)

        xml_dir_name = os.path.dirname(xml_file)
        # A bare file name has an empty dirname, which means the current
        # directory; os.makedirs('') would raise FileNotFoundError.
        if xml_dir_name and not os.path.exists(xml_dir_name):
            Log.info('Xml Dir: {} not exists.'.format(xml_dir_name))
            # exist_ok tolerates the directory appearing after the check.
            os.makedirs(xml_dir_name, exist_ok=True)
    def load_file(json_file):
        """Parse a JSON file and return its content.

        Args:
            json_file: Path of the JSON file.

        Returns:
            The object produced by ``json.load``. A missing file is logged and
            the process exits with status 1.
        """
        if os.path.exists(json_file):
            with open(json_file, 'r') as read_stream:
                return json.load(read_stream)

        Log.error('Json file: {} not exists.'.format(json_file))
        exit(1)
    def save(img, save_path):
        """Write an image to ``save_path``.

        A ``PIL`` image is written with its own ``save``; a numpy array with
        ``cv2.imwrite``. Any other type is logged and the process exits with
        status 1.
        """
        if isinstance(img, Image.Image):
            img.save(save_path)
            return

        if isinstance(img, np.ndarray):
            cv2.imwrite(save_path, img)
            return

        Log.error('Image type is invalid.')
        exit(1)
    def semantic_segmentor(self):
        """Instantiate the segmentation model named by ``network.model_name``.

        Returns:
            ``SEG_MODEL_DICT[model_name](self.configer)``. An unknown model
            name is logged and the process exits with status 1.
        """
        model_name = self.configer.get('network', 'model_name')

        if model_name in SEG_MODEL_DICT:
            return SEG_MODEL_DICT[model_name](self.configer)

        Log.error('Model: {} not valid!'.format(model_name))
        exit(1)
Exemple #19
0
    def get_model(self):
        """Instantiate the classification model named by ``network.model_name``.

        Returns:
            ``CLS_MODEL_DICT[model_name](self.configer)``. An unknown model
            name is logged and the process exits with status 1.
        """
        model_name = self.configer.get('network', 'model_name')

        known = model_name in CLS_MODEL_DICT
        if not known:
            Log.error('Model: {} not valid!'.format(model_name))
            exit(1)

        build_model = CLS_MODEL_DICT[model_name]
        return build_model(self.configer)
Exemple #20
0
        def _set_value(key, value):
            """
            Write `value` at the dotted path `key`, operating directly on
            `params_root`.

            Missing intermediate levels are created as `dict()`. A final
            component ending in `+` appends `value` to the existing list
            instead of assigning. Descending into a non-dict, or appending to
            a non-list, is logged and exits with status 1.
            """
            parts = key.split('.')
            branch_keys, leaf_key = parts[:-1], parts[-1]

            node = self.params_root
            for depth, branch_key in enumerate(branch_keys, start=1):
                # Dotted path walked so far, including the current component.
                visited = '.'.join(parts[:depth])

                if branch_key not in node:
                    node[branch_key] = dict()
                    Log.info('{} not exists, set as `dict()`.'.format(visited))
                elif not isinstance(node[branch_key], dict):
                    Log.error(
                        'Cannot set {child_name} on {root_name}, as {root_name} is `{root_type}`.'
                        .format(root_name=visited,
                                child_name='.'.join(parts[depth:]),
                                root_type=type(node[branch_key])))
                    sys.exit(1)

                node = node[branch_key]

            if leaf_key.endswith('+'):
                # Append mode: drop the trailing '+' from the key for both the
                # lookup and the messages.
                shown_key = key[:-1]
                target = node.get(leaf_key[:-1])

                if not isinstance(target, list):
                    Log.error(
                        'Cannot append to {key}, as its type is {target_type}.'
                        .format(key=shown_key, target_type=type(target)))
                    sys.exit(1)

                target.append(value)
                Log.info('Append {value} to {key}. Current: {target}.'.format(
                    key=shown_key,
                    value=value,
                    target=target,
                ))
                return

            previous = node.get(leaf_key)
            if previous is None:
                Log.info('Set {key} as {value}.'.format(key=key, value=value))
            else:
                Log.warn(
                    'Override {key} using {value}. Previous value: {old_value}.'
                    .format(key=key, value=value, old_value=previous))
            node[leaf_key] = value
Exemple #21
0
    def save(img, save_path):
        """Write an image to ``save_path``.

        ``FileHelper.make_dirs`` is called for ``save_path`` first. A ``PIL``
        image is then written with its own ``save``; a numpy array with
        ``cv2.imwrite``. Any other type is logged and the process exits with
        status 1.
        """
        FileHelper.make_dirs(save_path, is_file=True)

        if isinstance(img, Image.Image):
            img.save(save_path)
            return

        if isinstance(img, np.ndarray):
            cv2.imwrite(save_path, img)
            return

        Log.error('Image type is invalid.')
        exit(1)
Exemple #22
0
    def get_size(img):
        """Return the size of an image, width first.

        Returns:
            ``img.size`` for a ``PIL`` image, ``[width, height]`` taken from
            the first two ``shape`` entries for a numpy array. Any other type
            is logged and the process exits with status 1.
        """
        if isinstance(img, Image.Image):
            return img.size

        if isinstance(img, np.ndarray):
            # shape starts with (height, width); the result is width-first.
            rows, cols = img.shape[:2]
            return [cols, rows]

        Log.error('Image type is invalid.')
        exit(1)
    def get_deploy_model(self, model_type=None):
        """Instantiate a deploy model.

        Args:
            model_type: Key into ``DEPLOY_MODEL_DICT``; ``None`` falls back to
                the configured ``network.model_name``.

        Returns:
            ``DEPLOY_MODEL_DICT[model_name](self.configer)``. An unknown name
            is logged and the process exits with status 1.
        """
        if model_type is None:
            model_name = self.configer.get('network', 'model_name')
        else:
            model_name = model_type

        if model_name in DEPLOY_MODEL_DICT:
            return DEPLOY_MODEL_DICT[model_name](self.configer)

        Log.error('Model: {} not valid!'.format(model_name))
        exit(1)
    def get(self, *key, **kwargs):
        """Look up a value by key parts joined with ``'.'``.

        Args:
            *key: Key components; they are joined with ``'.'`` before lookup.
            **kwargs: Forwarded to ``self.params_root.get``; a ``default``
                entry makes a missing key acceptable.

        A missing key always logs a warning. Without ``default`` it also logs
        an error tagged with the caller and exits with status 1.
        """
        key = '.'.join(key)
        missing = key not in self.params_root
        if missing:
            Log.warn('Key: {} not exists'.format(key))

        if missing and 'default' not in kwargs:
            Log.error('{} KeyError: {}.'.format(self._get_caller(), key))
            exit(1)

        return self.params_root.get(key, **kwargs)
Exemple #25
0
    def save_net(self, net, save_mode='iters'):
        """Save a checkpoint of ``net`` according to ``save_mode``.

        The ``<checkpoints_name>_latest.pth`` file is always written. An
        additional file is written when the mode's condition holds:
            ``'performance'``: ``performance`` exceeds ``max_performance``.
            ``'val_loss'``: ``val_loss`` is below ``min_val_loss``.
            ``'iters'``: at least ``checkpoints.save_iters`` iterations have
                passed since ``last_iters``.
            ``'epoch'``: at least ``checkpoints.save_epoch`` epochs have
                passed since ``last_epoch``.
        The matching tracker in the configer is updated after each such save.

        In distributed runs only rank 0 saves. An unknown ``save_mode`` is
        logged and the process exits with status 1 (after the latest
        checkpoint has been written).
        """
        if is_distributed() and get_rank() != 0:
            return

        state = {
            'config_dict': self.configer.to_dict(),
            'state_dict': net.state_dict(),
        }

        # checkpoints_root takes precedence; project_dir is the fallback.
        checkpoints_root = self.configer.get('checkpoints', 'checkpoints_root')
        if checkpoints_root is None:
            checkpoints_root = self.configer.get('project_dir')
        checkpoints_dir = os.path.join(
            checkpoints_root, self.configer.get('checkpoints', 'checkpoints_dir'))

        # exist_ok replaces the exists()/makedirs() pair, which could fail if
        # the directory appeared between the check and the creation.
        os.makedirs(checkpoints_dir, exist_ok=True)

        checkpoints_name = self.configer.get('checkpoints', 'checkpoints_name')

        def _save_as(file_name):
            # Every checkpoint variant stores the same state dictionary.
            torch.save(state, os.path.join(checkpoints_dir, file_name))

        _save_as('{}_latest.pth'.format(checkpoints_name))

        if save_mode == 'performance':
            if self.configer.get('performance') > self.configer.get('max_performance'):
                _save_as('{}_max_performance.pth'.format(checkpoints_name))
                self.configer.update(['max_performance'], self.configer.get('performance'))

        elif save_mode == 'val_loss':
            if self.configer.get('val_loss') < self.configer.get('min_val_loss'):
                _save_as('{}_min_loss.pth'.format(checkpoints_name))
                self.configer.update(['min_val_loss'], self.configer.get('val_loss'))

        elif save_mode == 'iters':
            elapsed = self.configer.get('iters') - self.configer.get('last_iters')
            if elapsed >= self.configer.get('checkpoints', 'save_iters'):
                _save_as('{}_iters{}.pth'.format(checkpoints_name,
                                                 self.configer.get('iters')))
                self.configer.update(['last_iters'], self.configer.get('iters'))

        elif save_mode == 'epoch':
            elapsed = self.configer.get('epoch') - self.configer.get('last_epoch')
            if elapsed >= self.configer.get('checkpoints', 'save_epoch'):
                _save_as('{}_epoch{}.pth'.format(checkpoints_name,
                                                 self.configer.get('epoch')))
                self.configer.update(['last_epoch'], self.configer.get('epoch'))

        else:
            Log.error('Metric: {} is invalid.'.format(save_mode))
            exit(1)
    def plus_one(self, *key):
        """Increment the value stored under a one- or two-part key.

        Args:
            *key: One component for a top-level, non-dict entry of
                ``params_root``, or two components for a nested entry.

        A key that does not exist, or that has any other shape, is logged
        with the caller's tag and the process exits with status 1.
        """
        if not self.exists(*key):
            Log.error('{} Key: {} not existed!!!'.format(self._get_caller(), key))
            exit(1)

        depth = len(key)
        if depth == 2:
            section, name = key
            self.params_root[section][name] += 1
            return

        if depth == 1 and not isinstance(self.params_root[key[0]], dict):
            self.params_root[key[0]] += 1
            return

        Log.error('{} KeyError: {} !!!'.format(self._get_caller(), key))
        exit(1)
Exemple #27
0
    def cv2_read_image(image_path, mode='RGB'):
        """Read an image file in the requested mode.

        Args:
            image_path: Path of the image file.
            mode: ``'RGB'`` or ``'BGR'`` read the file with ``cv2.imread``
                (``'RGB'`` passes the result through ``ImageHelper.bgr2rgb``);
                ``'P'`` opens it with PIL, converts to ``'P'`` and passes it
                through ``ImageHelper.img2np``.

        An unsupported mode is logged and the process exits with status 1.
        """
        if mode == 'P':
            # The palette path never uses the cv2 result, so the file is no
            # longer decoded with cv2 first.
            return ImageHelper.img2np(Image.open(image_path).convert('P'))

        if mode == 'RGB' or mode == 'BGR':
            img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if mode == 'RGB':
                return ImageHelper.bgr2rgb(img_bgr)
            return img_bgr

        Log.error('Not support mode {}'.format(mode))
        exit(1)
Exemple #28
0
    def resize(img, target_size, interpolation=None):
        """Resize a PIL image or a numpy array.

        Args:
            img: ``PIL`` image or numpy array.
            target_size: List or tuple; converted to a tuple and forwarded to
                the backend resize helper.
            interpolation: String key into ``PIL_INTER_DICT`` or
                ``CV2_INTER_DICT``. The ``None`` default fails the type
                assertion, so callers must pass a value.

        Any other image type is logged and the process exits with status 1.
        """
        assert isinstance(target_size, (list, tuple))
        assert isinstance(interpolation, str)

        size = tuple(target_size)

        if isinstance(img, Image.Image):
            pil_mode = PIL_INTER_DICT[interpolation]
            return ImageHelper.pil_resize(img, size, interpolation=pil_mode)

        if isinstance(img, np.ndarray):
            cv2_mode = CV2_INTER_DICT[interpolation]
            return ImageHelper.cv2_resize(img, size, interpolation=cv2_mode)

        Log.error('Image type is invalid.')
        exit(1)
    def get_scale(self, img_size):
        """Return the scale ratio for an image of size ``img_size``.

        With ``self.method == 'random'`` the ratio is drawn uniformly from
        ``self.scale_range``. With ``'bound'`` it is the smaller of
        ``resize_bound[0] / min(img_size)`` and
        ``resize_bound[1] / max(img_size)``. Any other method is logged and
        the process exits with status 1.
        """
        if self.method == 'bound':
            short_side_ratio = self.resize_bound[0] / min(img_size)
            long_side_ratio = self.resize_bound[1] / max(img_size)
            return min(short_side_ratio, long_side_ratio)

        if self.method == 'random':
            low, high = self.scale_range[0], self.scale_range[1]
            return random.uniform(low, high)

        Log.error('Resize method {} is invalid.'.format(self.method))
        exit(1)
    def update(self, key_tuple, value):
        """Overwrite the value stored under a one- or two-part key.

        Args:
            key_tuple: One component for a top-level, non-dict entry of
                ``params_root``, or two components for a nested entry.
            value: New value.

        A key that does not exist, or that has any other shape, is logged
        with the caller's tag and the process exits with status 1.
        """
        if not self.exists(*key_tuple):
            Log.error('{} Key: {} not existed!!!'.format(self._get_caller(), key_tuple))
            exit(1)

        depth = len(key_tuple)
        if depth == 2:
            section, name = key_tuple
            self.params_root[section][name] = value
            return

        if depth == 1 and not isinstance(self.params_root[key_tuple[0]], dict):
            self.params_root[key_tuple[0]] = value
            return

        Log.error('{} Key: {} not existed!!!'.format(self._get_caller(), key_tuple))
        exit(1)