Ejemplo n.º 1
0
    def infer(self, model, pred_dir, patch_size, overlap, ext='_mask', file_ext='png', visualize=False,
              densecrf=False, crf_params=None):
        """
        Run tile-wise inference on every file in self.rgb_files and save one prediction image per tile
        :param model: a single model or a list/tuple of models; for a list/tuple the outputs of self.infer_tile
                      are summed over all models before the class decision is made
        :param pred_dir: directory the predictions are written to, created if it does not exist
        :param patch_size: patch size forwarded to patch_extractor.make_grid and self.infer_tile
        :param overlap: overlap forwarded to patch_extractor.make_grid
        :param ext: suffix appended to the tile name in the output file name
        :param file_ext: extension of the output file
        :param visualize: if True, the rgb tile and the prediction are displayed side by side
        :param densecrf: if True, the prediction is refined by a dense CRF before taking the argmax
        :param crf_params: keyword arguments for DenseCRF2D.addPairwiseBilateral; when None and densecrf is True
                           {'sxy': 3, 'srgb': 3, 'compat': 5} is used
        :return:
        """
        is_ensemble = isinstance(model, (list, tuple))
        lbl_margin = model[0].lbl_margin if is_ensemble else model.lbl_margin
        if crf_params is None and densecrf:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        misc_utils.make_dir_if_not_exist(pred_dir)
        pbar = tqdm(self.rgb_files)
        for rgb_file in pbar:
            # output name is the part of the file stem before the first underscore
            file_name = os.path.splitext(os.path.basename(rgb_file))[0].split('_')[0]
            pbar.set_description('Inferring {}'.format(file_name))
            # read data, only the first three channels are used
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]

            # evaluate on tiles
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size, overlap)

            if is_ensemble:
                # sum() starts from 0, so this accumulates exactly like `tile_preds = 0; tile_preds += ...`
                tile_preds = sum(
                    self.infer_tile(m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)
                    for m in model)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)

            if densecrf:
                n_rows, n_cols, n_class = tile_preds.shape
                # DenseCRF2D takes (width, height, n_labels); unpacking the (rows, cols, classes) shape directly
                # swaps width and height, which only goes unnoticed for square tiles
                d = dcrf.DenseCRF2D(n_cols, n_rows, n_class)
                # NOTE(review): unary_from_softmax expects probabilities and no softmax is applied here --
                # confirm that self.infer_tile already returns normalized scores
                U = unary_from_softmax(np.ascontiguousarray(
                    data_utils.change_channel_order(tile_preds, False)))
                d.setUnaryEnergy(U)
                d.addPairwiseBilateral(rgbim=rgb, **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(n_rows, n_cols)
            else:
                tile_preds = np.argmax(tile_preds, -1)

            pred_img = self.encode_func(tile_preds) if self.encode_func else tile_preds

            if visualize:
                vis_utils.compare_figures([rgb, pred_img], (1, 2), fig_size=(12, 5))

            misc_utils.save_file(os.path.join(pred_dir, '{}{}.{}'.format(file_name, ext, file_ext)), pred_img)
Ejemplo n.º 2
0
def _write_patch_list(tile_files, record_file, patch_size, pad, overlap,
                      visualize):
    """
    Write one '<image patch name> <label patch name>' line to record_file for every patch of every tile
    :param tile_files: iterable of (image file, label file) pairs
    :param record_file: writable text file object receiving the patch names
    :param patch_size: patch size forwarded to data_utils.patch_tile
    :param pad: padding forwarded to data_utils.patch_tile
    :param overlap: overlap forwarded to data_utils.patch_tile
    :param visualize: if True, every image/label patch pair is displayed
    :return:
    """
    if visualize:
        # imported lazily so the plotting utilities are only needed when visualizing
        from mrs_utils import vis_utils
    for img_file, lbl_file in tqdm(tile_files):
        # patch names start with the part of the file stem before the first underscore
        city_name = os.path.splitext(
            os.path.basename(img_file))[0].split('_')[0]
        for rgb_patch, gt_patch, y, x in data_utils.patch_tile(
                img_file, lbl_file, patch_size, pad, overlap):
            if visualize:
                vis_utils.compare_figures([rgb_patch, gt_patch], (1, 2),
                                          fig_size=(12, 5))
            img_patchname = '{}_y{}x{}.jpg'.format(city_name, int(y), int(x))
            lbl_patchname = '{}_y{}x{}.png'.format(city_name, int(y), int(x))
            # NOTE: saving of the patches themselves is disabled, only the file list is written
            # misc_utils.save_file(os.path.join(patch_dir, img_patchname), rgb_patch.astype(np.uint8))
            # misc_utils.save_file(os.path.join(patch_dir, lbl_patchname), gt_patch.astype(np.uint8))
            record_file.write('{} {}\n'.format(img_patchname,
                                               lbl_patchname))


def create_dataset(data_dir,
                   save_dir,
                   patch_size,
                   pad,
                   overlap,
                   valid_percent=0.2,
                   visualize=False):
    """
    Create the training and validation patch file lists of a dataset
    The lists are written to save_dir as file_list_train_<p>_2.txt and file_list_valid_<p>_2.txt where <p> is
    misc_utils.float2str(valid_percent); the 'patches' sub-directory is created but no patch is saved here
    :param data_dir: directory handed to get_images to collect the (image, label) file pairs
    :param save_dir: directory where the file lists and the 'patches' folder are created
    :param patch_size: patch size forwarded to data_utils.patch_tile
    :param pad: padding forwarded to data_utils.patch_tile
    :param overlap: overlap forwarded to data_utils.patch_tile
    :param valid_percent: validation split forwarded to get_images, also part of the list file names
    :param visualize: if True, every image/label patch pair is displayed
    :return:
    """
    # create folders and files
    patch_dir = os.path.join(save_dir, 'patches')
    misc_utils.make_dir_if_not_exist(patch_dir)
    # format only the file name, so braces in save_dir cannot break the path
    percent_str = misc_utils.float2str(valid_percent)
    train_list = os.path.join(
        save_dir, 'file_list_train_{}_2.txt'.format(percent_str))
    valid_list = os.path.join(
        save_dir, 'file_list_valid_{}_2.txt'.format(percent_str))
    train_files, valid_files = get_images(data_dir, valid_percent, split=True)

    # context managers make sure both lists are flushed and closed, even on errors
    with open(train_list, 'w+') as record_file_train:
        _write_patch_list(train_files, record_file_train, patch_size, pad,
                          overlap, visualize)
    with open(valid_list, 'w+') as record_file_valid:
        _write_patch_list(valid_files, record_file_valid, patch_size, pad,
                          overlap, visualize)
Ejemplo n.º 3
0
def display_group(reg_groups, size, img=None, need_return=False):
    """
    Visualize grouped connected components
    Every pixel listed in the coords of a component is set to the index of the group the component belongs to
    NOTE: group indices start at 0, so the first group has the same value as the untouched background
    :param reg_groups: grouped connected components, can get this by calling ObjectScorer._group_pairs
    :param size: the size of the image or gt
    :param img: if given, the image will be displayed together with the visualization
    :param need_return: if True, the rendered image will be returned, otherwise the image will be displayed
    :return: the integer group map if need_return is True, otherwise None
    """
    # np.int was a plain alias of the builtin int and no longer exists in NumPy >= 1.24
    group_map = np.zeros(size, dtype=int)
    for cnt, group in enumerate(reg_groups):
        for g in group:
            # coords are used as (row, col) pairs
            coords = np.array(g.coords)
            group_map[coords[:, 0], coords[:, 1]] = cnt
    if need_return:
        return group_map
    # `if img:` raises ValueError for array images (ambiguous truth value), so test against None
    if img is not None:
        vis_utils.compare_figures([img, group_map], (1, 2), fig_size=(12, 5))
    else:
        vis_utils.compare_figures([group_map], (1, 1), fig_size=(8, 6))
Ejemplo n.º 4
0
        return self.ds_len[0]

    def __iter__(self):
        """
        Yield sample indices batch by batch, mixing the datasets according to self.ratio
        One random permutation is drawn per dataset; each of the self.ds_len[0] // self.batch_size batches
        takes self.ratio[i] entries from the permutation of dataset i, wrapping around when that dataset is
        exhausted, and shifts them by self.offset[i]
        :return: generator of indices
        """
        # draw all permutations up-front, in dataset order
        shuffled = [np.random.permutation(np.arange(length)) for length in self.ds_len]
        n_batches = self.ds_len[0] // self.batch_size
        for batch_idx in range(n_batches):
            for ds_idx, per_batch in enumerate(self.ratio):
                order = shuffled[ds_idx]
                first = batch_idx * per_batch
                for step in range(per_batch):
                    # modulo lets the smaller datasets be revisited within one pass
                    position = (first + step) % self.ds_len[ds_idx]
                    yield self.offset[ds_idx] + order[position]


if __name__ == '__main__':
    # Manual check: iterate over a patch dataset and display every image/mask pair
    from data import data_utils
    from mrs_utils import vis_utils
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    tsfms = [A.RandomCrop(512, 512), ToTensorV2()]
    ds1 = RSDataLoader(r'/hdd/mrs/inria/ps512_pd0_ol0/patches',
                       r'/hdd/mrs/inria/ps512_pd0_ol0/file_list_train_kt.txt',
                       transforms=tsfms,
                       n_class=2)

    for data_dict in ds1:
        rgb, gt, cls = data_dict['image'], data_dict['mask'], data_dict['cls']
        print(cls)
        # change_channel_order is applied to the image only; the mask is shown as it is
        vis_utils.compare_figures([
            data_utils.change_channel_order(rgb.cpu().numpy()),
            gt.cpu().numpy()
        ], (1, 2), fig_size=(12, 5))
Ejemplo n.º 5
0
    def evaluate(self,
                 model,
                 patch_size,
                 overlap,
                 pred_dir=None,
                 report_dir=None,
                 save_conf=False,
                 delta=1e-6,
                 eval_class=(1, ),
                 visualize=False,
                 densecrf=False,
                 crf_params=None,
                 verbose=True):
        """
        Evaluate a model (or an ensemble of models) on every (rgb, label) file pair of this evaluator
        :param model: a single model or a list/tuple of models; for a list/tuple the outputs of self.infer_tile
                      are summed over all models before the class decision is made
        :param patch_size: patch size forwarded to patch_extractor.make_grid and self.infer_tile
        :param overlap: overlap forwarded to patch_extractor.make_grid
        :param pred_dir: if given, predictions are saved as <label file stem>.png in this directory
        :param report_dir: if given, the per-file and overall result strings are saved to result.txt there
        :param save_conf: if True, channel 1 of the softmax of the raw prediction is saved as
                          <label file stem>.npy in pred_dir, which is therefore required
        :param delta: small constant added to the denominator of the returned score, also handed to
                      self.get_result_strings
        :param eval_class: classes forwarded to metric_utils.iou_metric
        :param visualize: if True, rgb, label and prediction are displayed side by side for every file
        :param densecrf: if True, the prediction is refined by a dense CRF before taking the argmax
        :param crf_params: keyword arguments for DenseCRF2D.addPairwiseBilateral; when None and densecrf is True
                           {'sxy': 3, 'srgb': 3, 'compat': 5} is used
        :param verbose: forwarded to verb_print together with the result strings
        :return: mean over eval_class of the accumulated iou_score[0] / (iou_score[1] + delta), times 100
        :raises ValueError: if save_conf is True but pred_dir is None
        """
        # fail before any inference is done instead of crashing in os.path.join after the first tile
        if save_conf and pred_dir is None:
            raise ValueError('pred_dir must be given when save_conf is True')

        is_ensemble = isinstance(model, (list, tuple))
        lbl_margin = model[0].lbl_margin if is_ensemble else model.lbl_margin
        if crf_params is None and densecrf:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}
        # the softmax of the raw prediction is needed by both the confidence dump and the CRF
        need_conf = save_conf or densecrf

        iou_a, iou_b = np.zeros(len(eval_class)), np.zeros(len(eval_class))
        report = []
        if pred_dir:
            misc_utils.make_dir_if_not_exist(pred_dir)
        for rgb_file, lbl_file in zip(self.rgb_files, self.lbl_files):
            file_name = os.path.splitext(os.path.basename(lbl_file))[0]

            # read data, only the first three channels of the image are used
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]
            lbl = misc_utils.load_file(lbl_file)
            if self.decode_func:
                lbl = self.decode_func(lbl)

            # evaluate on tiles
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [
                tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin
            ]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size,
                                                  overlap)

            if is_ensemble:
                # sum() starts from 0, so this accumulates exactly like `tile_preds = 0; tile_preds += ...`
                tile_preds = sum(
                    self.infer_tile(m, rgb, grid_list, patch_size, tile_dim,
                                    tile_dim_pad, lbl_margin) for m in model)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size,
                                             tile_dim, tile_dim_pad,
                                             lbl_margin)

            conf = scipy.special.softmax(tile_preds,
                                         axis=-1) if need_conf else None

            if save_conf:
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.npy'.format(file_name)),
                    conf[:, :, 1])

            if densecrf:
                n_rows, n_cols, n_class = tile_preds.shape
                # DenseCRF2D takes (width, height, n_labels); unpacking the (rows, cols, classes) shape directly
                # swaps width and height, which only goes unnoticed for square tiles
                d = dcrf.DenseCRF2D(n_cols, n_rows, n_class)
                U = unary_from_softmax(
                    np.ascontiguousarray(
                        data_utils.change_channel_order(conf, False)))
                d.setUnaryEnergy(U)
                d.addPairwiseBilateral(rgbim=rgb, **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(n_rows, n_cols)
            else:
                tile_preds = np.argmax(tile_preds, -1)
            # the label is divided by self.truth_val before being compared with the class indices
            iou_score = metric_utils.iou_metric(lbl / self.truth_val,
                                                tile_preds,
                                                eval_class=eval_class)
            pstr, rstr = self.get_result_strings(file_name, iou_score, delta)
            tm.misc_utils.verb_print(pstr, verbose)
            report.append(rstr)
            iou_a += iou_score[0, :]
            iou_b += iou_score[1, :]
            if visualize:
                if self.encode_func:
                    vis_utils.compare_figures([
                        rgb,
                        self.encode_func(lbl),
                        self.encode_func(tile_preds)
                    ], (1, 3),
                                              fig_size=(15, 5))
                else:
                    vis_utils.compare_figures([rgb, lbl, tile_preds], (1, 3),
                                              fig_size=(15, 5))
            if pred_dir:
                pred_img = self.encode_func(
                    tile_preds) if self.encode_func else tile_preds
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.png'.format(file_name)),
                    pred_img)
        pstr, rstr = self.get_result_strings('Overall',
                                             np.stack([iou_a, iou_b], axis=0),
                                             delta)
        tm.misc_utils.verb_print(pstr, verbose)
        report.append(rstr)
        if report_dir:
            misc_utils.make_dir_if_not_exist(report_dir)
            misc_utils.save_file(os.path.join(report_dir, 'result.txt'),
                                 report)
        return np.mean(iou_a / (iou_b + delta)) * 100
Ejemplo n.º 6
0
        if self.use_max:
            pred = np.max(np.stack(fuse_tps, axis=0), axis=0)
        else:
            pred = np.mean(np.stack(fuse_tps, axis=0), axis=0)
        return np.expand_dims(data_utils.change_channel_order(
            pred, to_channel_last=False),
                              axis=0)


if __name__ == '__main__':
    # Manual check: group the objects of a ground truth tile and of a saved confidence map,
    # then show both group maps next to the rgb tile
    rgb_file = r'/media/ei-edl01/data/remote_sensing_data/inria/images/austin1.tif'
    lbl_file = r'/media/ei-edl01/data/remote_sensing_data/inria/gt/austin1.tif'
    conf_file = r'/hdd/Results/mrs/inria/ecresnet50_dcunet_dsinria_lre1e-04_lrd1e-04_ep50_bs7_ds50_dr0p1/austin1.npy'

    rgb = misc_utils.load_file(rgb_file)
    # the ground truth is divided by 255 before grouping, the confidence map is used as loaded
    lbl_img = misc_utils.load_file(lbl_file) / 255
    conf_img = misc_utils.load_file(conf_file)

    osc = ObjectScorer(min_region=5, min_th=0.5, link_r=10, eps=2)
    lbl_groups = osc.get_object_groups(lbl_img)
    conf_groups = osc.get_object_groups(conf_img)
    print(len(lbl_groups), len(conf_groups))

    lbl_size, conf_size = lbl_img.shape[:2], conf_img.shape[:2]
    lbl_group_img = display_group(lbl_groups, lbl_size, need_return=True)
    conf_group_img = display_group(conf_groups, conf_size, need_return=True)
    figures = [rgb, lbl_group_img, conf_group_img]
    vis_utils.compare_figures(figures, (1, 3), fig_size=(15, 5))