Example 1
    def load_net(self, net):
        if self.configer.get('gpu') is not None:
            net = self._make_parallel(net)

        net = net.to(torch.device('cpu' if self.configer.get('gpu') is None else 'cuda'))
        net.float()
        if self.configer.get('network', 'resume') is not None:
            Log.info('Loading checkpoint from {}...'.format(self.configer.get('network', 'resume')))
            resume_dict = torch.load(self.configer.get('network', 'resume'))
            if 'state_dict' in resume_dict:
                checkpoint_dict = resume_dict['state_dict']

            elif 'model' in resume_dict:
                checkpoint_dict = resume_dict['model']

            elif isinstance(resume_dict, OrderedDict):
                checkpoint_dict = resume_dict

            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(self.configer.get('network', 'resume')))

            if list(checkpoint_dict.keys())[0].startswith('module.'):
                checkpoint_dict = {k[7:]: v for k, v in checkpoint_dict.items()}

            # load state_dict
            if hasattr(net, 'module'):
                self.load_state_dict(net.module, checkpoint_dict, self.configer.get('network', 'resume_strict'))
            else:
                self.load_state_dict(net, checkpoint_dict, self.configer.get('network', 'resume_strict'))

            if self.configer.get('network', 'resume_continue'):
                self.configer.resume(resume_dict['config_dict'])

        return net
Example 2
    def save_net(runner, net, performance=None, val_loss=None, iters=None, epoch=None, postfix='latest'):
        state = {
            'config_dict': runner.configer.to_dict(),
            'state_dict': net.state_dict(),
            'runner_state': runner.runner_state
        }
        checkpoints_dir = os.path.join(runner.configer.get('project_dir'),
                                       runner.configer.get('network', 'checkpoints_dir'))

        if not os.path.exists(checkpoints_dir):
            os.makedirs(checkpoints_dir)

        latest_name = '{}_{}.pth'.format(runner.configer.get('network', 'checkpoints_name'), postfix)
        torch.save(state, os.path.join(checkpoints_dir, latest_name))
        Log.info('save model {}'.format(os.path.join(checkpoints_dir, latest_name)))
        if performance is not None:
            if performance > runner.runner_state['max_performance']:
                latest_name = '{}_max_performance.pth'.format(runner.configer.get('network', 'checkpoints_name'))
                torch.save(state, os.path.join(checkpoints_dir, latest_name))
                runner.runner_state['max_performance'] = performance

        if val_loss is not None:
            if val_loss < runner.runner_state['min_val_loss']:
                latest_name = '{}_min_loss.pth'.format(runner.configer.get('network', 'checkpoints_name'))
                torch.save(state, os.path.join(checkpoints_dir, latest_name))
                runner.runner_state['min_val_loss'] = val_loss

        if iters is not None:
            latest_name = '{}_iters{}.pth'.format(runner.configer.get('network', 'checkpoints_name'), iters)
            torch.save(state, os.path.join(checkpoints_dir, latest_name))

        if epoch is not None:
            latest_name = '{}_epoch{}.pth'.format(runner.configer.get('network', 'checkpoints_name'), epoch)
            torch.save(state, os.path.join(checkpoints_dir, latest_name))
Example 3
    def load_net(runner, net, model_path=None):
        if model_path is not None or runner.configer.get('network', 'resume') is not None:
            resume_path = runner.configer.get('network', 'resume')
            resume_path = model_path if model_path is not None else resume_path

            if not os.path.exists(resume_path):
                Log.warn('Resume path: {} not exists...'.format(resume_path))
                return net

            Log.info('Resuming from {}'.format(resume_path))
            resume_dict = torch.load(resume_path, map_location="cpu")
            if 'state_dict' in resume_dict:
                checkpoint_dict = resume_dict['state_dict']

            elif 'model' in resume_dict:
                checkpoint_dict = resume_dict['model']

            elif isinstance(resume_dict, OrderedDict):
                checkpoint_dict = resume_dict

            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(resume_path))

            # load state_dict
            if hasattr(net, 'module'):
                RunnerHelper.load_state_dict(net.module, checkpoint_dict,
                                             runner.configer.get('network', 'resume_strict'))
            else:
                RunnerHelper.load_state_dict(net, checkpoint_dict, runner.configer.get('network', 'resume_strict'))

            if runner.configer.get('network', 'resume_continue'):
                runner.runner_state = resume_dict['runner_state']

        return net
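Examples 1-3 together implement a checkpoint round-trip: save_net writes the config, weights, and runner state, and load_net restores them, stripping a 'module.' prefix when needed. A minimal usage sketch, assuming a RunnerHelper class exposing these as static methods and a runner carrying configer and runner_state as above (build_model is hypothetical):

    # hedged usage sketch: checkpoint round-trip with the helpers above
    net = build_model()                       # hypothetical model factory
    net = RunnerHelper.load_net(runner, net)  # resumes if 'network'->'resume' is set

    # ... one training epoch ...
    RunnerHelper.save_net(runner, net, performance=0.91, postfix='latest')
    # writes <checkpoints_name>_latest.pth and, when 0.91 beats
    # runner.runner_state['max_performance'], also ..._max_performance.pth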
Example 4
    def get_testloader(self, dataset=None):
        dataset = 'test' if dataset is None else dataset
        if self.configer.exists('data', 'use_sw_offset') or self.configer.exists('data', 'pred_sw_offset'):
            Log.info('use sliding window based offset loader for test ...')
            test_loader = data.DataLoader(
                SWOffsetTestLoader(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                   img_transform=self.img_transform,
                                   configer=self.configer),
                batch_size=self.configer.get('test', 'batch_size'), pin_memory=True,
                num_workers=self.configer.get('data', 'workers'), shuffle=False,
                collate_fn=lambda *args: collate(
                    *args, trans_dict=self.configer.get('test', 'data_transformer')
                )
            )
            return test_loader

        elif self.configer.get('method') == 'fcn_segmentor':
            Log.info('use CSDataTestLoader for test ...')
            test_loader = data.DataLoader(
                CSDataTestLoader(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                 img_transform=self.img_transform,
                                 configer=self.configer),
                batch_size=self.configer.get('test', 'batch_size'), pin_memory=True,
                num_workers=self.configer.get('data', 'workers'), shuffle=False,
                collate_fn=lambda *args: collate(
                    *args, trans_dict=self.configer.get('test', 'data_transformer')
                )
            )
            return test_loader
Example 5
    def __init__(self, args_parser=None, configs=None, config_dict=None):
        if config_dict is not None:
            self.params_root = config_dict

        elif configs is not None:
            if not os.path.exists(configs):
                Log.error('Json Path: {} not exists!'.format(configs))
                exit(1)

            with open(configs, 'r') as json_stream:
                self.params_root = json.load(json_stream)

        elif args_parser is not None:
            self.args_dict = args_parser.__dict__
            self.params_root = None

            if not os.path.exists(args_parser.configs):
                Log.error('Json Path: {} not exists!'.format(args_parser.configs))
                exit(1)

            with open(args_parser.configs, 'r') as json_stream:
                self.params_root = json.load(json_stream)

            for key, value in self.args_dict.items():
                if not self.exists(*key.split(':')):
                    self.add(key.split(':'), value)
                elif value is not None:
                    self.update(key.split(':'), value)
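The constructor resolves parameters from one of three sources: an explicit config_dict, a JSON file path, or an argparse namespace whose configs attribute points at a JSON file, with CLI values keyed like 'network:resume' merged on top of the JSON tree. A hedged sketch of the argparse route (flag names and the path are illustrative):

    # hedged usage sketch for the argparse branch above
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--configs', default='configs/seg.json')  # hypothetical path
    parser.add_argument('--network:resume', dest='network:resume', default=None)
    configer = Configer(args_parser=parser.parse_args())
    # keys containing ':' are split and override or extend the JSON tree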
Example 6
    def __init__(self, configer):
        self.configer = configer

        if self.configer.get('data', 'image_tool') == 'pil':
            self.aug_train_transform = pil_aug_trans.PILAugCompose(
                self.configer, split='train')
            self.aug_val_transform = pil_aug_trans.PILAugCompose(self.configer,
                                                                 split='val')
        elif self.configer.get('data', 'image_tool') == 'cv2':
            self.aug_train_transform = cv2_aug_trans.CV2AugCompose(
                self.configer, split='train')
            self.aug_val_transform = cv2_aug_trans.CV2AugCompose(self.configer,
                                                                 split='val')
        else:
            Log.error('Not support {} image tool.'.format(
                self.configer.get('data', 'image_tool')))
            exit(1)

        self.img_transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(div_value=self.configer.get(
                'normalize', 'div_value'),
                            mean=self.configer.get('normalize', 'mean'),
                            std=self.configer.get('normalize', 'std')),
        ])

        self.label_transform = trans.Compose([
            trans.ToLabel(),
            trans.ReLabel(255, -1),
        ])
Example 7
    def __getitem__(self, index):
        img = None
        valid = True
        while img is None:
            try:
                img = ImageHelper.read_image(self.item_list[index][0],
                                             tool=self.configer.get('data', 'image_tool'),
                                             mode=self.configer.get('data', 'input_mode'))
                assert isinstance(img, (np.ndarray, Image.Image))
            except Exception:
                Log.warn('Invalid image path: {}'.format(self.item_list[index][0]))
                img = None
                valid = False
                index = (index + 1) % len(self.item_list)

        ori_img_size = ImageHelper.get_size(img)
        if self.aug_transform is not None:
            img = self.aug_transform(img)

        border_hw = ImageHelper.get_size(img)[::-1]
        if self.img_transform is not None:
            img = self.img_transform(img)

        meta = dict(
            valid=valid,
            ori_img_size=ori_img_size,
            border_hw=border_hw,
            img_path=self.item_list[index][0],
            filename=self.item_list[index][1],
            label=self.item_list[index][2]
        )
        return dict(
            img=DataContainer(img, stack=True),
            meta=DataContainer(meta, stack=False, cpu_only=True)
        )
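The DataContainer wrappers control batching: following the mmcv-style convention these snippets appear to use, stack=True fields are collated into a single tensor while cpu_only=True fields are kept as per-sample Python objects. A hedged sketch of what a collated batch may then look like (loader is hypothetical):

    # hedged sketch of consuming a batch built from __getitem__ above
    batch = next(iter(loader))
    imgs = batch['img']    # stack=True     -> one (B, C, H, W) tensor
    metas = batch['meta']  # cpu_only=True  -> list of per-sample meta dicts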
Example 8
    def get_valloader(self, loader_type=None, data_dir=None, batch_size=None):
        loader_type = self.configer.get(
            'val', 'loader') if loader_type is None else loader_type
        data_dir = self.configer.get(
            'data', 'data_dir') if data_dir is None else data_dir
        batch_size = self.configer.get(
            'val', 'batch_size') if batch_size is None else batch_size
        if loader_type is None or loader_type == 'default':
            valloader = data.DataLoader(
                DefaultDataset(data_dir=data_dir,
                               dataset='val',
                               aug_transform=self.aug_val_transform,
                               img_transform=self.img_transform,
                               configer=self.configer),
                batch_size=batch_size,
                shuffle=False,
                num_workers=self.configer.get('data', 'workers'),
                pin_memory=True,
                collate_fn=collate)

            return valloader

        else:
            Log.error('{} val loader is invalid.'.format(loader_type))
            exit(1)
Example 9
    def get_valloader(self, dataset=None):
        dataset = 'val' if dataset is None else dataset

        if self.configer.get('method') == 'fcn_segmentor':
            """
            default manner:
            load the ground-truth label.
            """
            Log.info('use DefaultLoader for val ...')
            valloader = data.DataLoader(
                DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                              dataset=dataset,
                              aug_transform=self.aug_val_transform,
                              img_transform=self.img_transform,
                              label_transform=self.label_transform,
                              configer=self.configer),
                batch_size=self.configer.get('val', 'batch_size'),
                pin_memory=True,
                num_workers=self.configer.get('data', 'workers'),
                shuffle=False,
                collate_fn=lambda *args: collate(
                    *args,
                    trans_dict=self.configer.get('val', 'data_transformer')))
            return valloader

        else:
            Log.error('Method: {} loader is invalid.'.format(
                self.configer.get('method')))
            return None
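A minimal sketch of how such a loader is consumed downstream (the data_helper name and the 'labelmap' key are assumptions based on the DefaultLoader name, not verified):

    # hypothetical consumption of the validation loader above
    val_loader = data_helper.get_valloader()
    for data_dict in val_loader:
        inputs = data_dict['img']       # collated per 'val' -> 'data_transformer'
        labels = data_dict['labelmap']  # assumed key emitted by DefaultLoader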
Example 10
    def _hard_anchor_sampling(self, X, y_hat, y):
        batch_size, feat_dim = X.shape[0], X.shape[-1]

        classes = []
        total_classes = 0
        for ii in range(batch_size):
            this_y = y_hat[ii]
            this_classes = torch.unique(this_y)
            this_classes = [x for x in this_classes if x > 0 and x != self.ignore_label]
            this_classes = [x for x in this_classes if (this_y == x).nonzero().shape[0] > self.max_views]

            classes.append(this_classes)
            total_classes += len(this_classes)

        if total_classes == 0:
            return None, None

        n_view = self.max_samples // total_classes
        n_view = min(n_view, self.max_views)

        X_ = torch.zeros((total_classes, n_view, feat_dim), dtype=torch.float).cuda()
        y_ = torch.zeros(total_classes, dtype=torch.float).cuda()

        X_ptr = 0
        for ii in range(batch_size):
            this_y_hat = y_hat[ii]
            this_y = y[ii]
            this_classes = classes[ii]

            for cls_id in this_classes:
                hard_indices = ((this_y_hat == cls_id) & (this_y != cls_id)).nonzero()
                easy_indices = ((this_y_hat == cls_id) & (this_y == cls_id)).nonzero()

                num_hard = hard_indices.shape[0]
                num_easy = easy_indices.shape[0]

                if num_hard >= n_view / 2 and num_easy >= n_view / 2:
                    num_hard_keep = n_view // 2
                    num_easy_keep = n_view - num_hard_keep
                elif num_hard >= n_view / 2:
                    num_easy_keep = num_easy
                    num_hard_keep = n_view - num_easy_keep
                elif num_easy >= n_view / 2:
                    num_hard_keep = num_hard
                    num_easy_keep = n_view - num_hard_keep
                else:
                    Log.info('this should never be touched! {} {} {}'.format(num_hard, num_easy, n_view))
                    raise Exception

                perm = torch.randperm(num_hard)
                hard_indices = hard_indices[perm[:num_hard_keep]]
                perm = torch.randperm(num_easy)
                easy_indices = easy_indices[perm[:num_easy_keep]]
                indices = torch.cat((hard_indices, easy_indices), dim=0)

                X_[X_ptr, :, :] = X[ii, indices, :].squeeze(1)
                y_[X_ptr] = cls_id
                X_ptr += 1

        return X_, y_
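A shape-level sketch of driving the sampler, treating the method as a free function with a SimpleNamespace standing in for self (a CUDA device is required because the function allocates with .cuda(); all sizes are illustrative):

    import types
    import torch

    sampler = types.SimpleNamespace(ignore_label=255, max_samples=1024, max_views=2)
    X = torch.randn(2, 4096, 256).cuda()           # (batch, pixels, feat_dim)
    y_hat = torch.randint(1, 5, (2, 4096)).cuda()  # label map, classes 1..4
    y = torch.randint(1, 5, (2, 4096)).cuda()      # predictions per pixel

    X_, y_ = _hard_anchor_sampling(sampler, X, y_hat, y)
    # X_: (total_classes, n_view, 256) sampled anchors; y_: their class ids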
Example 11
    def Linear(linear_type):
        if linear_type == 'default':
            return Linear

        elif linear_type == 'nobias':
            return functools.partial(Linear, bias=False)

        elif 'arc' in linear_type:
            # example: arc0.5_64, arc0.32_64, easyarc0.5_64
            margin_scale = linear_type.split('arc')[1]
            margin = float(margin_scale.split('_')[0])
            scale = float(margin_scale.split('_')[1])
            easy = True if 'easy' in linear_type else False
            return functools.partial(ArcLinear,
                                     s=scale,
                                     m=margin,
                                     easy_margin=easy)

        elif linear_type == 'cos0.4_30':
            return functools.partial(CosineLinear, s=30, m=0.5)

        elif linear_type == 'cos0.4_64':
            return functools.partial(CosineLinear, s=64, m=0.5)

        elif linear_type == 'sphere4':
            return functools.partial(SphereLinear, m=4)

        else:
            Log.error('Not support linear type: {}.'.format(linear_type))
            exit(1)
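The factory above encodes margin and scale directly in the type tag. A hedged usage sketch (the ArcLinear constructor is assumed to take in/out feature sizes on top of the partially applied s, m, easy_margin):

    # hedged usage sketch for the factory above
    head_cls = Linear('easyarc0.5_64')  # functools.partial(ArcLinear, s=64, m=0.5, easy_margin=True)
    head = head_cls(512, 10572)         # assumed signature: (in_features, num_classes)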
Example 12
    def __list_dirs(self, root_dir, dataset):
        img_list = list()
        offset_h_list = list()
        offset_w_list = list()
        name_list = list()
        image_dir = os.path.join(root_dir, dataset, 'image')
        # assumed layout: ground-truth labels sit next to the image folder
        label_dir = os.path.join(root_dir, dataset, 'label')

        offset_type = self.configer.get('data', 'offset_type')
        assert offset_type is not None
        offset_h_dir = os.path.join(root_dir, dataset, offset_type, 'h')
        offset_w_dir = os.path.join(root_dir, dataset, offset_type, 'w')
        img_extension = os.listdir(image_dir)[0].split('.')[-1]

        for file_name in os.listdir(label_dir):
            image_name = '.'.join(file_name.split('.')[:-1])
            img_path = os.path.join(image_dir,
                                    '{}.{}'.format(image_name, img_extension))
            label_path = os.path.join(label_dir, file_name)
            offset_h_path = os.path.join(offset_h_dir,
                                         self._replace_ext(file_name, 'mat'))
            offset_w_path = os.path.join(offset_w_dir,
                                         self._replace_ext(file_name, 'mat'))

            if not os.path.exists(label_path) or not os.path.exists(img_path):
                Log.error('Label Path: {} not exists.'.format(label_path))
                continue
            img_list.append(img_path)
            offset_h_list.append(offset_h_path)
            offset_w_list.append(offset_w_path)
            name_list.append(image_name)

        return img_list, offset_h_list, offset_w_list, name_list
Example 13
    def Linear(linear_type):
        if linear_type == 'default':
            return Linear

        elif linear_type == 'nobias':
            return functools.partial(Linear, bias=False)

        elif linear_type == 'arc0.5_30':
            return functools.partial(ArcLinear, s=30, m=0.5, easy_margin=False)

        elif linear_type == 'arc0.5_64':
            return functools.partial(ArcLinear, s=64, m=0.5, easy_margin=False)

        elif linear_type == 'easyarc0.5_30':
            return functools.partial(ArcLinear, s=30, m=0.5, easy_margin=True)

        elif linear_type == 'easyarc0.5_64':
            return functools.partial(ArcLinear, s=64, m=0.5, easy_margin=True)

        elif linear_type == 'cos0.4_30':
            return functools.partial(CosineLinear, s=30, m=0.5)

        elif linear_type == 'cos0.4_64':
            return functools.partial(CosineLinear, s=64, m=0.5)

        elif linear_type == 'sphere4':
            return functools.partial(SphereLinear, m=4)

        else:
            Log.error('Not support linear type: {}.'.format(linear_type))
            exit(1)
Example 14
    def prepare_data(self, data_dict, want_reverse=False):
        input_keys, target_keys = self.input_keys(), self.target_keys()

        if self.conditions.use_ground_truth:
            input_keys += target_keys

        Log.info_once('Input keys: {}'.format(input_keys))
        Log.info_once('Target keys: {}'.format(target_keys))

        inputs = [data_dict[k] for k in input_keys]
        batch_size = len(inputs[0])
        targets = [data_dict[k] for k in target_keys]

        sequences = [
            self._prepare_sequence(inputs, force_list=True),
            self._prepare_sequence(targets, force_list=False)
        ]
        if want_reverse:
            rev_data_dict = self._reverse_data_dict(data_dict)
            sequences.extend([
                self._prepare_sequence([rev_data_dict[k] for k in input_keys],
                                       force_list=True),
                self._prepare_sequence([rev_data_dict[k] for k in target_keys],
                                       force_list=False)
            ])

        return sequences, batch_size
Example 15
    def update(self, key, value, append=False):
        if key not in self.params_root:
            Log.error('{} Key: {} not exists!'.format(
                self._get_caller(), key))
            exit(1)

        self.params_root.put(key, value, append)
Example 16
    def __getitem__(self, index):
        img = None
        valid = True
        while img is None:
            try:
                img = ImageHelper.read_image(
                    self.img_list[index],
                    tool=self.configer.get('data', 'image_tool'),
                    mode=self.configer.get('data', 'input_mode'))
                assert isinstance(img, (np.ndarray, Image.Image))
            except Exception:
                Log.warn('Invalid image path: {}'.format(self.img_list[index]))
                img = None
                valid = False
                index = (index + 1) % len(self.img_list)

        label = torch.from_numpy(np.array(self.label_list[index]))
        if self.aug_transform is not None:
            img = self.aug_transform(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(valid=valid,
                    img=DataContainer(img, stack=True),
                    label=DataContainer(label, stack=True))
Example 17
    def __read_file(self, data_dir, dataset):
        img_list = list()
        mlabel_list = list()
        all_img_list = []
        with open(self.configer.get('data.{}_label_path'.format(dataset)),
                  'r') as file_stream:
            all_img_list += file_stream.readlines()

        if dataset == 'train' and self.configer.get('data.include_val',
                                                    default=False):
            with open(self.configer.get('data.val_label_path'),
                      'r') as file_stream:
                all_img_list += file_stream.readlines()

        for line in all_img_list:
            line_items = line.strip().split()
            if len(line_items) == 0:
                continue

            path = line_items[0]
            if not os.path.exists(os.path.join(
                    data_dir, path)) or not ImageHelper.is_img(path):
                Log.warn('Invalid Image Path: {}'.format(
                    os.path.join(data_dir, path)))
                continue

            img_list.append(os.path.join(data_dir, path))
            mlabel_list.append([int(item) for item in line_items[1:]])

        assert len(img_list) > 0
        Log.info('Length of {} imgs is {}...'.format(dataset, len(img_list)))
        return img_list, mlabel_list
Example 18
    def _relabel(self):
        label_id = 0
        label_dict = dict()
        old_label_path = self.configer.get('data', 'label_path')
        new_label_path = '{}_new'.format(self.configer.get('data', 'label_path'))
        self.configer.update('data.label_path', new_label_path)
        fw = open(new_label_path, 'w')
        check_valid_dict = dict()
        with open(old_label_path, 'r') as fr:
            for line in fr.readlines():
                line_items = line.strip().split()
                if not os.path.exists(os.path.join(self.configer.get('data', 'data_dir'), line_items[0])):
                    continue

                if line_items[1] not in label_dict:
                    label_dict[line_items[1]] = label_id
                    label_id += 1

                if line_items[0] in check_valid_dict:
                    Log.error('Duplicate Error: {}'.format(line_items[0]))
                    exit(1)

                check_valid_dict[line_items[0]] = 1
                fw.write('{} {}\n'.format(line_items[0], label_dict[line_items[1]]))

        fw.close()
        # back up the original label file; 'data.label_path' was already updated above
        shutil.copy(old_label_path,
                    os.path.join(self.configer.get('data', 'merge_dir'), 'ori_label.txt'))
        self.configer.update('data.num_classes', [label_id])
        Log.info('Num Classes is {}...'.format(self.configer.get('data', 'num_classes')))
Example 19
    def sscrop_test(self, inputs, crop_size, scale=1):
        """
        Currently, sscrop_test does not support diverse_size testing.
        """
        n, c, ori_h, ori_w = inputs.size()
        scaled_inputs = F.interpolate(inputs, size=(int(ori_h * scale), int(ori_w * scale)), mode="bilinear", align_corners=True)
        n, c, h, w = scaled_inputs.size()
        full_probs = torch.zeros((n, self.configer.get('data', 'num_classes'), h, w), device=inputs.device)
        count_predictions = torch.zeros_like(full_probs)

        crop_counter = 0

        height_starts = self._decide_intersection(h, crop_size[0])
        width_starts = self._decide_intersection(w, crop_size[1])

        for height in height_starts:
            for width in width_starts:
                crop_inputs = scaled_inputs[:, :, height:height+crop_size[0], width:width + crop_size[1]]
                prediction = self.ss_test(crop_inputs)
                count_predictions[:, :, height:height+crop_size[0], width:width + crop_size[1]] += 1
                full_probs[:, :, height:height+crop_size[0], width:width + crop_size[1]] += prediction
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        full_probs /= count_predictions
        full_probs = F.interpolate(full_probs, size=(ori_h, ori_w), mode='bilinear', align_corners=True)
        return full_probs
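sscrop_test accumulates per-crop probabilities and divides by a per-pixel visit count, so overlapping crops are averaged rather than double-counted. A tiny self-contained sketch of that accumulation (shapes illustrative; torch.rand stands in for self.ss_test):

    import torch

    full_probs = torch.zeros(1, 2, 4, 4)       # (n, num_classes, h, w)
    count = torch.zeros_like(full_probs)
    for h0 in (0, 2):                          # crop starts, as from _decide_intersection
        for w0 in (0, 2):
            pred = torch.rand(1, 2, 2, 2)      # stand-in for self.ss_test(crop)
            full_probs[:, :, h0:h0 + 2, w0:w0 + 2] += pred
            count[:, :, h0:h0 + 2, w0:w0 + 2] += 1
    full_probs /= count                        # overlap-aware average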
Example 20
    def __list_dirs(self, root_dir, dataset):
        img_list = list()
        name_list = list()
        image_dir = os.path.join(root_dir, dataset)
        img_extension = os.listdir(image_dir)[0].split('.')[-1]

        if self.configer.get('dataset') == 'cityscapes':
            for item in os.listdir(image_dir):
                sub_image_dir = os.path.join(image_dir, item)
                for file_name in os.listdir(sub_image_dir):
                    image_name = file_name.split('.')[0]
                    img_path = os.path.join(sub_image_dir, file_name)
                    if not os.path.exists(img_path):
                        Log.error(
                            'Image Path: {} not exists.'.format(img_path))
                        continue
                    img_list.append(img_path)
                    name_list.append(image_name)
        else:
            for file_name in os.listdir(image_dir):
                image_name = file_name.split('.')[0]
                img_path = os.path.join(image_dir, file_name)
                if not os.path.exists(img_path):
                    Log.error('Image Path: {} not exists.'.format(img_path))
                    continue
                img_list.append(img_path)
                name_list.append(image_name)

        return img_list, name_list
Example 21
 def read_image(image_path, tool='pil', mode='RGB'):
     if tool == 'pil':
         return ImageHelper.pil_read_image(image_path, mode=mode)
     elif tool == 'cv2':
         return ImageHelper.cv2_read_image(image_path, mode=mode)
     else:
         Log.error('Not support tool: {}'.format(tool))
         exit(1)
Example 22
 def get_seg_loss(self, loss_type=None):
     key = self.configer.get('loss', 'loss_type') if loss_type is None else loss_type
     if key not in SEG_LOSS_DICT:
         Log.error('Loss: {} not valid!'.format(key))
         exit(1)
     Log.info('use loss: {}.'.format(key))
     loss = SEG_LOSS_DICT[key](self.configer)
     return self._parallel(loss)
Example 23
    def save_file(json_dict, json_file):
        dir_name = os.path.dirname(json_file)
        if not os.path.exists(dir_name):
            Log.info('Json Dir: {} not exists.'.format(dir_name))
            os.makedirs(dir_name)

        with open(json_file, 'w') as write_stream:
            write_stream.write(json.dumps(json_dict))
Example 24
 def __init__(self, max_degree, rotate_ratio=0.5, mean=(104, 117, 123)):
     assert isinstance(max_degree, int)
     self.max_degree = max_degree
     self.ratio = rotate_ratio
     self.mean = mean
     Log.warn(
         'Currently `RandomRotate` is only implemented for `img`, `labelmap` and `maskmap`.'
     )
Example 25
    def test(self, data_loader=None):
        """
          Validation function during the train phase.
        """
        self.seg_net.eval()
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        print('Total batches', len(self.test_loader))
        for j, data_dict in enumerate(self.test_loader):
            inputs = [data_dict['img']]
            names = data_dict['name']
            metas = data_dict['meta']

            dest_dir = self.save_dir

            with torch.no_grad():
                offsets, logits = self.extract_offset(inputs)
                print([x.shape for x in logits])
                for k in range(len(inputs[0])):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    offset = offsets[k].squeeze().cpu().numpy()
                    offset = cv2.resize(
                        offset[:border_size[1], :border_size[0]],
                        tuple(ori_img_size),
                        interpolation=cv2.INTER_NEAREST)
                    print(image_id)

                    os.makedirs(dest_dir, exist_ok=True)

                    if names[k].rpartition('.')[0]:
                        dest_name = names[k].rpartition('.')[0] + '.mat'
                    else:
                        dest_name = names[k] + '.mat'
                    dest_name = os.path.join(dest_dir, dest_name)
                    print('Shape:', offset.shape, 'Saving to', dest_name)

                    mat_dict = {'mat': offset}
                    scipy.io.savemat(dest_name, mat_dict, do_compression=True)
                    # verify the file reads back; fall back to uncompressed on failure
                    try:
                        scipy.io.loadmat(dest_name)
                    except Exception as e:
                        print(e)
                        scipy.io.savemat(dest_name,
                                         mat_dict,
                                         do_compression=False)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))
Example 26
    def json2xml(json_file, xml_file):
        if not os.path.exists(json_file):
            Log.error('Json file: {} not exists.'.format(json_file))
            exit(1)

        xml_dir_name = os.path.dirname(xml_file)
        if not os.path.exists(xml_dir_name):
            Log.info('Xml Dir: {} not exists.'.format(xml_dir_name))
            os.makedirs(xml_dir_name)
Example 27
    def xml2json(xml_file, json_file):
        if not os.path.exists(xml_file):
            Log.error('Xml file: {} not exists.'.format(xml_file))
            exit(1)

        json_dir_name = os.path.dirname(json_file)
        if not os.path.exists(json_dir_name):
            Log.info('Json Dir: {} not exists.'.format(json_dir_name))
            os.makedirs(json_dir_name)
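Both converters above stop after validating paths; the conversion step itself is not shown. A minimal sketch of a missing xml-to-json body, assuming the third-party xmltodict package (its parse function returns a nested dict):

    # hedged sketch of the conversion step, not the repo's verified code
    import json
    import xmltodict  # third-party dependency, assumed available

    def xml_to_json_body(xml_file, json_file):
        with open(xml_file, 'r') as read_stream:
            data_dict = xmltodict.parse(read_stream.read())
        with open(json_file, 'w') as write_stream:
            write_stream.write(json.dumps(data_dict))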
Example 28
def get_evaluator(configer, trainer, name=None):
    # honor an explicit name argument before falling back to the env var
    name = name if name is not None else os.environ.get('evaluator', 'standard')

    if name not in evaluators:
        raise RuntimeError('Unknown evaluator name: {}'.format(name))
    klass = evaluators[name]
    Log.info('Using evaluator: {}'.format(klass.__name__))

    return klass(configer, trainer)
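get_evaluator looks the class up in a module-level evaluators registry. A hedged sketch of how such a registry is typically populated (the decorator is illustrative, not the repo's API):

    # hypothetical registry population; the repo's actual mechanism may differ
    evaluators = {}

    def register_evaluator(name):
        def deco(klass):
            evaluators[name] = klass
            return klass
        return deco

    @register_evaluator('standard')
    class StandardEvaluator(object):
        def __init__(self, configer, trainer):
            self.configer, self.trainer = configer, trainer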
Example 29
    def load_file(json_file):
        if not os.path.exists(json_file):
            Log.error('Json file: {} not exists.'.format(json_file))
            exit(1)

        with open(json_file, 'r') as read_stream:
            json_dict = json.load(read_stream)

        return json_dict
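load_file pairs with save_file from Example 23 for a JSON round-trip. A quick sketch (the JsonHelper class name and the path are assumptions):

    # hedged round-trip sketch using the two helpers
    JsonHelper.save_file({'lr': 0.01, 'epochs': 30}, '/tmp/demo/config.json')
    cfg = JsonHelper.load_file('/tmp/demo/config.json')
    assert cfg['lr'] == 0.01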
Example 30
    def forward(self, inputs, targets, **kwargs):
        from lib.utils.helpers.offset_helper import DTOffsetHelper

        pred_mask, pred_direction = inputs

        seg_label_map, distance_map, angle_map = targets[0], targets[1], targets[2]
        gt_mask = DTOffsetHelper.distance_to_mask_label(distance_map,
                                                        seg_label_map,
                                                        return_tensor=True)

        gt_size = gt_mask.shape[1:]
        mask_weights = self.calc_weights(gt_mask, 2)

        pred_direction = F.interpolate(pred_direction,
                                       size=gt_size,
                                       mode="bilinear",
                                       align_corners=True)
        pred_mask = F.interpolate(pred_mask,
                                  size=gt_size,
                                  mode="bilinear",
                                  align_corners=True)
        mask_loss = F.cross_entropy(pred_mask,
                                    gt_mask,
                                    weight=mask_weights,
                                    ignore_index=-1)

        mask_threshold = float(os.environ.get('mask_threshold', 0.5))
        binary_pred_mask = torch.softmax(pred_mask,
                                         dim=1)[:, 1, :, :] > mask_threshold

        gt_direction = DTOffsetHelper.angle_to_direction_label(
            angle_map,
            seg_label_map=seg_label_map,
            extra_ignore_mask=(binary_pred_mask == 0),
            return_tensor=True)

        direction_loss_mask = gt_direction != -1
        direction_weights = self.calc_weights(
            gt_direction[direction_loss_mask], pred_direction.size(1))
        direction_loss = F.cross_entropy(pred_direction,
                                         gt_direction,
                                         weight=direction_weights,
                                         ignore_index=-1)

        if self.training \
           and self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0 \
           and torch.cuda.current_device() == 0:
            Log.info('mask loss: {} direction loss: {}.'.format(
                mask_loss, direction_loss))

        mask_weight = float(os.environ.get('mask_weight', 1))
        direction_weight = float(os.environ.get('direction_weight', 1))

        return mask_weight * mask_loss + direction_weight * direction_loss
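The loss depends on a calc_weights helper that is not shown. A plausible reconstruction, assuming it derives inverse-frequency class weights from the label histogram (a guess for illustration, not the repo's verified code):

    import torch

    def calc_weights(label_map, num_classes):
        # hedged sketch: weight each class by one minus its pixel share
        counts = torch.stack([(label_map == i).sum().float()
                              for i in range(num_classes)])
        return (1.0 - counts / counts.sum()).to(label_map.device)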