示例#1
0
def patch_tile(rgb_file, gt_file, patch_size, pad, overlap):
    """
    Cut a matching rgb tile and ground-truth tile into aligned patches.

    The gt file is decoded with ``decode_map(..., DECODER)`` before cropping. The crop grid is
    built by ``data_utils.make_grid`` for the tile size grown by ``2 * pad``; when ``pad > 0``
    both tiles are padded with ``data_utils.pad_image`` before cropping.
    :param rgb_file: path to the rgb tile
    :param gt_file: path to the ground-truth tile
    :param patch_size: patch size as a (height, width) pair
    :param pad: number of pixels padded around each tile. It is used as a scalar here
                (``2 * pad`` is added to the tile size and ``pad > 0`` is tested).
                NOTE(review): a four-element pad does not appear to be supported by this code.
    :param overlap: number of pixels shared by neighbouring patches, vertically and horizontally
    :return: a generator yielding (rgb_patch, gt_patch, y, x) for every grid position
    """
    rgb = misc_utils.load_file(rgb_file)
    gt = decode_map(misc_utils.load_file(gt_file), DECODER)

    # rgb and gt must cover the same spatial extent
    np.testing.assert_array_equal(rgb.shape[:2], gt.shape[:2])
    padded_dim = np.array(rgb.shape[:2]) + 2 * pad
    grid_list = data_utils.make_grid(padded_dim, patch_size, overlap)
    if pad > 0:
        rgb, gt = data_utils.pad_image(rgb, pad), data_utils.pad_image(gt, pad)

    patch_h, patch_w = patch_size[0], patch_size[1]
    for y, x in grid_list:
        yield (data_utils.crop_image(rgb, y, x, patch_h, patch_w),
               data_utils.crop_image(gt, y, x, patch_h, patch_w),
               y, x)
示例#2
0
文件: eval_utils.py 项目: chbinb/mrs
def batch_score(pred_files, lbl_files, min_region=5, min_th=0.5, link_r=20, eps=2, iou_th=0.5):
    """
    Score prediction files against their label files and pool the per-file results.

    Files are paired positionally (``zip`` stops at the shorter list). Each pair is loaded with
    ``misc_utils.load_file`` and passed to ``score`` together with the matching parameters.
    :param pred_files: paths to prediction files
    :param lbl_files: paths to label files
    :param min_region: forwarded to ``score``
    :param min_th: forwarded to ``score``
    :param link_r: forwarded to ``score``
    :param eps: forwarded to ``score``
    :param iou_th: forwarded to ``score``
    :return: (conf, true), the concatenation of every file's two ``score`` outputs, in file order
    """
    all_conf, all_true = [], []
    pairs = zip(pred_files, lbl_files)
    for pred_path, lbl_path in tqdm(pairs, total=len(pred_files)):
        pred = misc_utils.load_file(pred_path)
        lbl = misc_utils.load_file(lbl_path)
        file_conf, file_true = score(pred, lbl, min_region, min_th, link_r, eps, iou_th)
        all_conf += file_conf
        all_true += file_true
    return all_conf, all_true
示例#3
0
 def extract_(file_list, file_exts, patch_size, pad, overlap, save_path):
     """
     Crop every file of every file group into patches and record the patch names.

     ``file_list`` is a sequence of groups; each group holds one file per entry of ``file_exts``.
     Every file is gridded with ``make_grid`` (tile size grown by ``2 * pad``) and cropped with
     ``patch_block``. Each patch is saved (cast to uint8) as
     ``<save_path>/<base>_y<y>x<x>.<ext>``, where ``<base>`` is the file name up to its first dot.
     The patch names of a group are regrouped with ``misc_utils.rotate_list`` and written to
     ``<save_path>/file_list.txt``, one space-separated line per regrouped entry.
     :param file_list: list of file groups, each with ``len(file_exts)`` paths
     :param file_exts: extension used when saving the patches of the corresponding file in a group
     :param patch_size: patch size, passed to make_grid / patch_block
     :param pad: number of pixels padded around each file
     :param overlap: overlap in pixels between neighbouring patches
     :param save_path: directory that receives the patches and file_list.txt
     """
     # one extension per file of a group (checked on the first group); an empty file_list
     # now yields an empty record file instead of an IndexError
     assert not file_list or len(file_exts) == len(file_list[0])
     pbar = tqdm(file_list)
     # ``with`` closes the record file even when loading or saving a patch raises
     with open(os.path.join(save_path, 'file_list.txt'), 'w') as record_file:
         for files in pbar:
             pbar.set_description('Extracting {}'.format(
                 os.path.basename(files[0])))
             patch_list = []
             for f, ext in zip(files, file_exts):
                 patch_list_ext = []
                 img = misc_utils.load_file(f)
                 grid_list = make_grid(
                     np.array(img.shape[:2]) + 2 * pad, patch_size, overlap)
                 # extract images
                 for patch, y, x in patch_block(img,
                                                pad,
                                                grid_list,
                                                patch_size,
                                                return_coord=True):
                     patch_name = '{}_y{}x{}.{}'.format(
                         os.path.basename(f).split('.')[0], int(y), int(x), ext)
                     patch_name = os.path.join(save_path, patch_name)
                     misc_utils.save_file(patch_name, patch.astype(np.uint8))
                     patch_list_ext.append(patch_name)
                 patch_list.append(patch_list_ext)
             # presumably transposes to one entry per patch position (all files' patches for that
             # position) — confirm against misc_utils.rotate_list
             patch_list = misc_utils.rotate_list(patch_list)
             for items in patch_list:
                 record_file.write('{}\n'.format(' '.join(items)))
示例#4
0
def get_images(data_dir):
    """
    List the validation rgb/gt patch paths recorded in ``data_dir/file_list_valid.txt``.

    Each line must hold exactly two names separated by a single space. Both names are resolved
    inside ``data_dir/patches``.
    :param data_dir: dataset root containing file_list_valid.txt and the patches folder
    :return: (rgb_files, gt_files), two lists of paths in file order
    """
    patch_dir = os.path.join(data_dir, 'patches')
    lines = misc_utils.load_file(os.path.join(data_dir, 'file_list_valid.txt'))
    pairs = [line.strip().split(' ') for line in lines]
    rgb_files = [os.path.join(patch_dir, rgb_name) for rgb_name, _ in pairs]
    gt_files = [os.path.join(patch_dir, gt_name) for _, gt_name in pairs]
    return rgb_files, gt_files
示例#5
0
def get_dataset_stats(ds_name,
                      img_dir,
                      load_func=None,
                      file_list=None,
                      mean_val=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])):
    """
    Look up, or compute, the per-channel mean and std statistics of a dataset.

    Resolution order:
      1. Built-in datasets (inria, deepglobe, deepgloberoad, deepglobeland, mnih, spca) use
         ``data.<ds_name>.preprocess.get_stats_pb(img_dir)[0]``.
      2. Otherwise, if ``load_func`` is given, the stats are computed and cached with
         ``process_block.ValueComputeProcess`` under ``../data/stats/custom/<ds_name>.npy``.
      3. Otherwise, the cached ``data/stats/custom/<ds_name>.npy`` next to the package is loaded.
    If step 2 raises ValueError or step 3 cannot read the file, the fallback ``mean_val`` is used.
    :param ds_name: dataset name
    :param img_dir: image directory passed to the statistics routine
    :param load_func: loader used to compute statistics of a custom dataset
    :param file_list: file list forwarded to the custom computation
    :param mean_val: fallback (mean, std) per channel (the defaults appear to be the common
                     ImageNet normalization values)
    :return: (mean, std) taken as rows 0 and 1 of the stats array. NOTE: the fallback paths
             return ``np.array(mean_val)`` instead, a single 2-row array that unpacks into the
             same two rows.
    """
    import importlib

    # datasets that ship a ``data.<name>.preprocess`` module exposing ``get_stats_pb``;
    # replaces six copy-pasted if/elif branches that differed only in the module name
    builtin_datasets = ('inria', 'deepglobe', 'deepgloberoad', 'deepglobeland', 'mnih', 'spca')
    if ds_name in builtin_datasets:
        preprocess = importlib.import_module('data.{}.preprocess'.format(ds_name))
        val = preprocess.get_stats_pb(img_dir)[0]
        print('Use {} mean std stats: {}'.format(ds_name, val))
    elif load_func:
        try:
            val = process_block.ValueComputeProcess(
                ds_name, os.path.join(os.path.dirname(__file__), '../data/stats/custom'),
                os.path.join(os.path.dirname(__file__), '../data/stats/custom/{}.npy'.format(ds_name)),
                func=load_func). \
                run(img_dir=img_dir, file_list=file_list).val
            print('Use {} mean std stats: {}'.format(ds_name, val))
        except ValueError:
            print(
                'Dataset {} is not supported, use default mean stats instead'.
                format(ds_name))
            return np.array(mean_val)
    else:
        try:
            val = misc_utils.load_file(
                os.path.join(
                    os.path.realpath(
                        os.path.join(os.path.dirname(__file__), '..')),
                    'data/stats/custom/{}.npy'.format(ds_name)))
            print('Use {} mean std stats: {}'.format(ds_name, val))
        except (FileNotFoundError, OSError):
            print(
                'Dataset {} is not supported, use default mean stats instead {}'
                .format(ds_name, mean_val))
            return np.array(mean_val)
    return val[0, :], val[1, :]
示例#6
0
 def __getitem__(self, index):
     """
     Load sample ``index`` as a dict.

     Keys: 'image' (always), 'mask' (only when ``self.with_label``) and 'cls' (only when
     ``self.with_aux``). Each callable in ``self.transforms`` is invoked with the current dict as
     keyword arguments, and the items it returns overwrite or extend the dict.
     :param index: position in ``self.img_list`` / ``self.lbl_list``
     :return: the sample dict
     """
     sample = {'image': misc_utils.load_file(self.img_list[index])}
     if self.with_label:
         sample['mask'] = misc_utils.load_file(self.lbl_list[index])
     for transform in (self.transforms or ()):
         sample.update(transform(**sample).items())
     if not self.with_aux:
         return sample

     mask = sample['mask']
     if len(mask.shape) == 2:
         # 2-D mask: one image-level flag (is any pixel positive?), one-hot over n_class
         is_positive = int(torch.mean(mask.type(torch.float)) > 0)
         sample['cls'] = one_hot(self.n_class, is_positive).type(torch.float)
     else:
         # otherwise: presence flag per position of everything but the last axis
         sample['cls'] = (torch.sum(mask, dim=-1) > 0).type(torch.float)
     return sample
示例#7
0
def load_config(model_dir):
    """
    Read ``<model_dir>/config.json`` and normalize it with ``misc_utils.historical_process_flag``.

    :param model_dir: model directory (created by train.py) that contains config.json
    :return: whatever ``misc_utils.historical_process_flag`` returns for the loaded config
    """
    config_path = os.path.join(model_dir, 'config.json')
    return misc_utils.historical_process_flag(misc_utils.load_file(config_path))
示例#8
0
    def vis_transform_pair(self, target_img_files):
        """
        Show a random source image, a random target image and the target after histogram matching.

        The figure has three columns (source, target, matched). The top row shows the images and
        the bottom row shows their per-channel histograms; the matched histogram is smoothed.
        :param target_img_files: list of target image files; one is picked at random for display,
                                 and the target histogram is computed over all of them
        """
        def draw_histogram(hist, smooth=False):
            import scipy.signal
            for channel, color in enumerate(['r', 'g', 'b']):
                curve = hist[channel, :]
                if smooth:
                    # Savitzky-Golay filter, window 11, polynomial order 2
                    curve = scipy.signal.savgol_filter(curve, 11, 2)
                plt.plot(curve, color)

        source_sample = misc_utils.load_file(
            np.random.choice(self.source_imgs, 1)[0])
        target_sample = misc_utils.load_file(
            np.random.choice(target_img_files, 1)[0])
        target_hist = self.get_histogram(target_img_files)
        matched = self.match_image(self.source_hist, target_hist, target_sample)

        # (image, histogram producer, smooth?) per column; producers are called lazily so the
        # histogram of the matched image is computed at the same point as before
        panels = (
            (source_sample, lambda: self.source_hist, False),
            (target_sample, lambda: target_hist, False),
            (matched, lambda: self.get_histogram([matched]), True),
        )
        plt.figure(figsize=(15, 8))
        for column, (image, make_hist, smooth) in enumerate(panels, start=1):
            plt.subplot(2, 3, column)
            plt.imshow(image)
            plt.axis('off')
            plt.subplot(2, 3, column + 3)
            draw_histogram(make_hist(), smooth=smooth)
        plt.tight_layout()
        plt.show()
示例#9
0
 def __init__(self,
              parent_path,
              file_list,
              transforms=None,
              n_class=2,
              with_label=True,
              with_aux=False):
     """
     A data reader for the remote sensing dataset
     The dataset storage structure should be like
     /parent_path
         /patches
             img0.png
             img1.png
         file_list.txt
     Normally the downloaded remote sensing dataset needs to be preprocessed
     :param parent_path: path to a preprocessed remote sensing dataset; or, when ``file_list``
                         cannot be opened as a file, a string holding a Python list of such paths
     :param file_list: a text file where each row contains rgb and gt files separated by space; or
                       a string holding a Python list of such files, paired with ``parent_path``
     :param transforms: albumentation transforms
     :param n_class: if greater than 0, will yield a #classes dimension vector where 1 indicates corresponding class exist
     :param with_label: if True, label files will be read, otherwise label files will be ignored
     :param with_aux: if True, auxiliary classification label will be returned
     """
     self.with_label = with_label
     try:
         records = misc_utils.load_file(file_list)
     except OSError:
         # file_list is not a readable file: treat file_list / parent_path as parallel list literals
         # NOTE(review/security): eval() executes arbitrary code; if these strings can come from an
         # untrusted config or command line, ast.literal_eval is the safe parser for list literals.
         file_list = eval(file_list)
         parent_path = eval(parent_path)
         self.img_list, self.lbl_list = [], []
         for fl, pp in zip(file_list, parent_path):
             # forward with_label as the single-list branch does; it used to be dropped here, so
             # with_label=False still tried to resolve label files for multi-dataset input
             img_list, lbl_list = get_file_paths(pp, misc_utils.load_file(fl),
                                                 self.with_label)
             self.img_list.extend(img_list)
             self.lbl_list.extend(lbl_list)
     else:
         # only the load is guarded, so an OSError raised while resolving the listed paths is
         # reported as-is instead of being misread as "file_list is a list literal"
         self.img_list, self.lbl_list = get_file_paths(
             parent_path, records, self.with_label)
     self.transforms = transforms
     self.n_class = n_class
     self.with_aux = with_aux
示例#10
0
文件: eval_utils.py 项目: chbinb/mrs
    def infer(self, model, pred_dir, patch_size, overlap, ext='_mask', file_ext='png', visualize=False,
              densecrf=False, crf_params=None):
        """
        Predict every tile in ``self.rgb_files`` and save the prediction maps to ``pred_dir``.

        :param model: a model, or a list/tuple of models whose outputs are summed (ensemble); the
                      label margin is read from the (first) model's ``lbl_margin``
        :param pred_dir: output directory, created if missing
        :param patch_size: patch size used to tile each image
        :param overlap: overlap in pixels between neighbouring patches
        :param ext: suffix appended to the output file name
        :param file_ext: extension of the saved prediction
        :param visualize: if True, show the rgb tile next to the prediction
        :param densecrf: if True, refine the prediction with a fully connected CRF
        :param crf_params: kwargs for ``addPairwiseBilateral``; defaults to
                           {'sxy': 3, 'srgb': 3, 'compat': 5} when densecrf is enabled
        """
        is_ensemble = isinstance(model, (list, tuple))
        lbl_margin = model[0].lbl_margin if is_ensemble else model.lbl_margin
        if crf_params is None and densecrf:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        misc_utils.make_dir_if_not_exist(pred_dir)
        pbar = tqdm(self.rgb_files)
        for rgb_file in pbar:
            # output name: base name of the rgb file up to its first underscore
            file_name = os.path.splitext(os.path.basename(rgb_file))[0].split('_')[0]
            pbar.set_description('Inferring {}'.format(file_name))
            # read data (first three channels only)
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]

            # evaluate on tiles padded by the label margin on every side
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size, overlap)

            if is_ensemble:
                # ensemble: sum the outputs of every model
                tile_preds = 0
                for m in model:
                    tile_preds = tile_preds + self.infer_tile(m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad,
                                                              lbl_margin)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)

            if densecrf:
                import scipy.special
                n_rows, n_cols, n_classes = tile_preds.shape
                # pydensecrf's DenseCRF2D takes (width, height, n_labels); unpacking the
                # (rows, cols, classes) shape directly swapped width and height, which breaks the
                # bilateral shape check on non-square tiles
                d = dcrf.DenseCRF2D(n_cols, n_rows, n_classes)
                # unary_from_softmax expects class probabilities; normalize the (possibly summed)
                # scores first, the same way ``evaluate`` does
                probs = scipy.special.softmax(tile_preds, axis=-1)
                U = unary_from_softmax(np.ascontiguousarray(
                    data_utils.change_channel_order(probs, False)))
                d.setUnaryEnergy(U)
                # the bilateral term needs a C-contiguous image; the channel slice above may not be
                d.addPairwiseBilateral(rgbim=np.ascontiguousarray(rgb), **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(n_rows, n_cols)
            else:
                tile_preds = np.argmax(tile_preds, -1)

            if self.encode_func:
                pred_img = self.encode_func(tile_preds)
            else:
                pred_img = tile_preds

            if visualize:
                vis_utils.compare_figures([rgb, pred_img], (1, 2), fig_size=(12, 5))

            misc_utils.save_file(os.path.join(pred_dir, '{}{}.{}'.format(file_name, ext, file_ext)), pred_img)
示例#11
0
def get_class_distribution(img_dir):
    """
    Count ground-truth pixels per class over the masks under ``img_dir/land-train/land-train``.

    Mask files are the second element of each pair returned by ``data_utils.get_img_lbl`` (image
    suffix 'sat.jpg', label suffix 'mask.png'). Each mask is decoded with ``decode_map`` and
    histogrammed with integer edges 0..7: 7 bins, the last one also including value 7.
    :param img_dir: dataset root directory
    :return: numpy array of length 7 with the accumulated pixel count per bin
    """
    sub_dirs = ['land-train/land-train']
    mask_files = [
        pair[1]
        for sub_dir in sub_dirs
        for pair in data_utils.get_img_lbl(os.path.join(img_dir, sub_dir),
                                           'sat.jpg', 'mask.png')
    ]
    counts = np.zeros(7)
    bin_edges = np.arange(8)
    for mask_file in tqdm(mask_files):
        class_map = decode_map(misc_utils.load_file(mask_file))
        counts += np.histogram(class_map, bins=bin_edges)[0]
    return counts
示例#12
0
 def get_histogram(img_files, progress=False):
     """
     Accumulate per-channel intensity histograms over a collection of images.

     Bins are the unit intervals between the integer edges 0..256 (the last bin also includes
     256); values outside [0, 256] are not counted. Only the first three channels are used.
     :param img_files: images, given as file paths (loaded with misc_utils.load_file) or arrays
     :param progress: if True, iterate through a tqdm progress bar
     :return: float array of shape (3, 256); row c holds the summed histogram of channel c
     """
     totals = np.zeros((3, 256))
     bin_edges = np.arange(0, 257)
     iterator = tqdm(img_files) if progress else img_files
     for item in iterator:
         image = misc_utils.load_file(item) if isinstance(item, str) else item
         for ch in range(3):
             counts = np.histogram(image[:, :, ch].flatten(), bins=bin_edges)[0]
             totals[ch, :] += counts
     return totals
示例#13
0
    def match_target_images(self, target_imgs, individual=False):
        """
        Lazily histogram-match each target image to ``self.source_hist``.

        :param target_imgs: list of target images, given as file names or numpy arrays
        :param individual: if True, each image uses its own histogram; otherwise a single histogram
                           computed over all target images is shared
        :return: generator yielding ``self.match_image(self.source_hist, target_hist, img)`` for
                 every target image, in order
        """
        shared_hist = None if individual else self.get_histogram(target_imgs)
        for item in target_imgs:
            target_hist = self.get_histogram([item]) if individual else shared_hist
            img = misc_utils.load_file(item) if isinstance(item, str) else item
            yield self.match_image(self.source_hist, target_hist, img)
示例#14
0
    def run(self, force_run=False, **kwargs):
        """
        Run ``self.func`` and cache its result, or reuse a previously cached result.

        ``self.state_file`` records progress ('Incomplete' while running, 'Finished' once the
        result has been stored), and ``self.save_path`` holds the result.
        :param force_run: if True, run the function even if it completed before
        :param kwargs: keyword arguments forwarded to ``self.func``
        :return: self, with the result in ``self.val``
        """
        if force_run or not os.path.exists(self.state_file):
            print('Start running {}'.format(self.name))
            # mark the run as incomplete until the result is persisted
            with open(self.state_file, 'w') as f:
                f.write('Incomplete\n')

            # run the process
            self.val = self.func(**kwargs)

            # save BEFORE flagging completion: previously 'Finished' was written first, so a failed
            # save left a finished state pointing at a result file that was never written
            misc_utils.save_file(self.save_path, self.val)
            with open(self.state_file, 'w') as f:
                f.write('Finished\n')
        else:
            # a state file exists: re-run only if the previous run did not finish
            # (as judged by self.check_finish())
            if not self.check_finish():
                self.val = self.func(**kwargs)
                misc_utils.save_file(self.save_path, self.val)

            # always take the value from the saved file
            self.val = misc_utils.load_file(self.save_path)

            with open(self.state_file, 'w') as f:
                f.write('Finished\n')
        return self
示例#15
0
文件: eval_utils.py 项目: xyt556/mrs
    def evaluate(self,
                 model,
                 patch_size,
                 overlap,
                 pred_dir=None,
                 report_dir=None,
                 save_conf=False,
                 delta=1e-6,
                 eval_class=(1, ),
                 visualize=False,
                 densecrf=False,
                 crf_params=None,
                 verbose=True):
        """
        Evaluate ``model`` on every (rgb, label) tile pair and report IoU.

        :param model: a model, or a list/tuple of models whose outputs are summed (ensemble); the
                      label margin is read from the (first) model's ``lbl_margin``
        :param patch_size: patch size used to tile each image
        :param overlap: overlap in pixels between neighbouring patches
        :param pred_dir: if given, predictions are saved here as ``<name>.png``
        :param report_dir: if given, the report lines are saved to ``<report_dir>/result.txt``
        :param save_conf: if True, also save the softmax score of class 1 as
                          ``<pred_dir>/<name>.npy``; requires ``pred_dir``
        :param delta: small value added to the denominator to avoid division by zero
        :param eval_class: class indices passed to ``metric_utils.iou_metric``
        :param visualize: if True, show rgb / label / prediction for every tile
        :param densecrf: if True, refine the prediction with a fully connected CRF
        :param crf_params: kwargs for ``addPairwiseBilateral``; defaults to
                           {'sxy': 3, 'srgb': 3, 'compat': 5} when densecrf is enabled
        :param verbose: if True, print the per-tile and overall result strings
        :return: mean over ``eval_class`` of iou_a / (iou_b + delta), accumulated over all tiles and
                 multiplied by 100 (iou_a / iou_b are rows 0 / 1 of ``iou_metric``'s output,
                 presumably intersection and union)
        :raises ValueError: if ``save_conf`` is True but ``pred_dir`` is None
        """
        if save_conf and pred_dir is None:
            # fail up front instead of with a TypeError from os.path.join(None, ...) after inference
            raise ValueError('save_conf requires pred_dir to be set')
        is_ensemble = isinstance(model, (list, tuple))
        lbl_margin = model[0].lbl_margin if is_ensemble else model.lbl_margin
        if crf_params is None and densecrf:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        iou_a, iou_b = np.zeros(len(eval_class)), np.zeros(len(eval_class))
        report = []
        if pred_dir:
            misc_utils.make_dir_if_not_exist(pred_dir)
        for rgb_file, lbl_file in zip(self.rgb_files, self.lbl_files):
            file_name = os.path.splitext(os.path.basename(lbl_file))[0]

            # read data (first three rgb channels only)
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]
            lbl = misc_utils.load_file(lbl_file)
            if self.decode_func:
                lbl = self.decode_func(lbl)

            # evaluate on tiles padded by the label margin on every side
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [
                tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin
            ]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size,
                                                  overlap)

            if is_ensemble:
                # ensemble: sum the outputs of every model
                tile_preds = 0
                for m in model:
                    tile_preds = tile_preds + self.infer_tile(
                        m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad,
                        lbl_margin)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size,
                                             tile_dim, tile_dim_pad,
                                             lbl_margin)

            if save_conf:
                # confidence map of class 1 only
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.npy'.format(file_name)),
                    scipy.special.softmax(tile_preds, axis=-1)[:, :, 1])

            if densecrf:
                n_rows, n_cols, n_classes = tile_preds.shape
                # pydensecrf's DenseCRF2D takes (width, height, n_labels); unpacking the
                # (rows, cols, classes) shape directly swapped width and height, which breaks the
                # bilateral shape check on non-square tiles
                d = dcrf.DenseCRF2D(n_cols, n_rows, n_classes)
                U = unary_from_softmax(
                    np.ascontiguousarray(
                        data_utils.change_channel_order(
                            scipy.special.softmax(tile_preds, axis=-1),
                            False)))
                d.setUnaryEnergy(U)
                # the bilateral term needs a C-contiguous image; the channel slice above may not be
                d.addPairwiseBilateral(rgbim=np.ascontiguousarray(rgb), **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(n_rows, n_cols)
            else:
                tile_preds = np.argmax(tile_preds, -1)
            iou_score = metric_utils.iou_metric(lbl / self.truth_val,
                                                tile_preds,
                                                eval_class=eval_class)
            pstr, rstr = self.get_result_strings(file_name, iou_score, delta)
            tm.misc_utils.verb_print(pstr, verbose)
            report.append(rstr)
            iou_a += iou_score[0, :]
            iou_b += iou_score[1, :]
            if visualize:
                if self.encode_func:
                    vis_utils.compare_figures([
                        rgb,
                        self.encode_func(lbl),
                        self.encode_func(tile_preds)
                    ], (1, 3),
                                              fig_size=(15, 5))
                else:
                    vis_utils.compare_figures([rgb, lbl, tile_preds], (1, 3),
                                              fig_size=(15, 5))
            if pred_dir:
                if self.encode_func:
                    misc_utils.save_file(
                        os.path.join(pred_dir, '{}.png'.format(file_name)),
                        self.encode_func(tile_preds))
                else:
                    misc_utils.save_file(
                        os.path.join(pred_dir, '{}.png'.format(file_name)),
                        tile_preds)
        pstr, rstr = self.get_result_strings('Overall',
                                             np.stack([iou_a, iou_b], axis=0),
                                             delta)
        tm.misc_utils.verb_print(pstr, verbose)
        report.append(rstr)
        if report_dir:
            misc_utils.make_dir_if_not_exist(report_dir)
            misc_utils.save_file(os.path.join(report_dir, 'result.txt'),
                                 report)
        return np.mean(iou_a / (iou_b + delta)) * 100
示例#16
0
def make_dataset(ds_train, ds_valid, save_dir, th=0.5):
    """
    Convert train/valid (image, footprint) pairs into 8-bit images, label masks and record files.

    For each (rgb_file, gt_file) pair:
      - the image is converted with ``convert_gtif_to_8bit`` to ``<save_dir>/patches/<name>.jpg``;
      - the label is rasterized with solaris ``footprint_mask`` against the original image.
    If ``check_blank_region`` of the image exceeds ``th``, the image is deleted and the pair is
    dropped. Otherwise the label is saved as ``<name>.png`` (divided by 255, cast to uint8) and
    the pair of base names is appended to ``<save_dir>/file_list_<phase>.txt``.
    :param ds_train: iterable of (rgb_file, gt_file) pairs for training; must support len()
    :param ds_valid: same, for validation
    :param save_dir: output directory
    :param th: maximum allowed blank-region value (as returned by check_blank_region)
    """
    import solaris as sol

    # create folders and files
    patch_dir = os.path.join(save_dir, 'patches')
    misc_utils.make_dir_if_not_exist(patch_dir)
    # ``with`` closes both record files even if processing raises
    with open(os.path.join(save_dir, 'file_list_train.txt'), 'w+') as record_file_train, \
            open(os.path.join(save_dir, 'file_list_valid.txt'), 'w+') as record_file_valid:
        ds_dict = {
            'train': {'ds': ds_train, 'record': record_file_train, 'remove_cnt': 0},
            'valid': {'ds': ds_valid, 'record': record_file_valid, 'remove_cnt': 0},
        }

        # validation set first, then training set
        for phase in ['valid', 'train']:
            phase_info = ds_dict[phase]
            for rgb_file, gt_file in tqdm(phase_info['ds']):
                base_name = os.path.splitext(os.path.basename(rgb_file))[0]
                img_save_name = os.path.join(patch_dir, '{}.jpg'.format(base_name))
                lbl_save_name = os.path.join(patch_dir, '{}.png'.format(base_name))
                convert_gtif_to_8bit(rgb_file, img_save_name)
                img = misc_utils.load_file(img_save_name)
                lbl = sol.vector.mask.footprint_mask(df=gt_file,
                                                     reference_im=rgb_file)

                if check_blank_region(img) > th:
                    phase_info['remove_cnt'] += 1
                    os.remove(img_save_name)
                    continue

                if img.shape[0] != lbl.shape[0] or img.shape[1] != lbl.shape[1]:
                    # cropping is only safe for an empty label. The former
                    # ``assert np.unique(lbl) == np.array([0])`` raised an ambiguous-truth
                    # ValueError instead of an AssertionError whenever lbl held several values.
                    assert not np.any(lbl), 'image/label size mismatch with a non-empty label'
                    lbl = lbl[:img.shape[0], :img.shape[1]]
                # lbl_save_name already includes patch_dir; joining it with patch_dir again
                # duplicated the directory whenever save_dir was a relative path
                misc_utils.save_file(lbl_save_name, (lbl / 255).astype(np.uint8))
                phase_info['record'].write('{} {}\n'.format(
                    os.path.basename(img_save_name),
                    os.path.basename(lbl_save_name)))
            phase_info['record'].close()

            n_total = len(phase_info['ds'])
            n_removed = phase_info['remove_cnt']
            # the message prints a percentage, so scale the ratio by 100 (and avoid 0/0)
            removed_pct = 100.0 * n_removed / n_total if n_total else 0.0
            print('{} set: {:.2f}% data removed with threshold of {}'.format(
                phase, removed_pct, th))
            print('\t kept patches: {}'.format(n_total - n_removed))

            # remove *.aux.xml side-car files (presumably written by the 8-bit conversion)
            for f in glob(os.path.join(patch_dir, '*.aux.xml')):
                os.remove(f)
    :param conf_img:
    :return:
    """
    coords = []
    for g in reg_group:
        coords.extend(g.coords)
    coords = np.array(coords)
    return coords


city = "austin"
city_name = city + "3"
targ_rgb_dir = "/hdd/inria/train/images/{}.tif".format(city_name)
targ_gt_dir = "/hdd/inria/train/gt/{}.tif".format(city_name)

# load the target tile; the ground truth is scaled by 1/255
targ_rgb = misc_utils.load_file(targ_rgb_dir)
targ_lbl = misc_utils.load_file(targ_gt_dir) / 255

osc = eval_utils.ObjectScorer(min_th=0.5, link_r=10, eps=2)

# Get object groups
print("Extracting the building as objects...")
lbl_groups = osc.get_object_groups(targ_lbl)
# NOTE(review): rgb_groups is computed from targ_lbl, exactly like lbl_groups; confirm whether a
# different input was intended
rgb_groups = osc.get_object_groups(targ_lbl)

print("Number of building 'groups': ", len(lbl_groups))

# Get the colors for everything in every building
grp_id = 0
size_dict = {}
def implant_textures(rgb_input_path, gt_input_path, stl, mtl, ltl,
                     small_threshold, large_threshold,
                     original_out_path=r"/hdd/2019-bass-connections-aatb/mrs_new/object_exctraction/Aneesh/implantation_with_size/original_satellite.png",
                     implanted_out_path=r"/hdd/2019-bass-connections-aatb/mrs_new/object_exctraction/Aneesh/implantation_with_size/austin_implanted.png"):
    """
    Paste a random roof texture, chosen by building size, onto every building of a tile.

    Buildings are grouped from the ground truth with ``eval_utils.ObjectScorer``. The texture list
    depends on the group's area:
      - area <= small_threshold: ``stl``;
      - area >= large_threshold: ``ltl``;
      - otherwise: ``mtl``.
    The texture is anchored at the group's first coordinate (treated as its top-left corner). It
    is applied only when the group area does not exceed the texture area; pixels that fall outside
    the texture are left untouched.
    :param rgb_input_path: path of the rgb tile
    :param gt_input_path: path of the ground truth (divided by 255 before grouping)
    :param stl: list of texture arrays for small buildings
    :param mtl: list of texture arrays for medium buildings
    :param ltl: list of texture arrays for large buildings
    :param small_threshold: area (pixels) at or below which a building is small
    :param large_threshold: area (pixels) at or above which a building is large
    :param original_out_path: where the unmodified tile is written with cv2.imwrite
    :param implanted_out_path: where the textured tile is written with cv2.imwrite
    :return: dict mapping the 1-based group index to the group's area
    """
    # load in file (the former cv2.imread reads were overwritten / never used)
    targ_rgb = misc_utils.load_file(rgb_input_path)
    # NOTE(review): cv2.imwrite expects BGR channel order; confirm what misc_utils.load_file returns
    cv2.imwrite(original_out_path, targ_rgb)
    targ_lbl = misc_utils.load_file(gt_input_path) / 255

    # Prepare to extract objects (using Bohao's ObjectScorer)
    osc = eval_utils.ObjectScorer(min_th=0.5, link_r=10, eps=2)

    # Get object groups
    print("Extracting the building as objects...")
    lbl_groups = osc.get_object_groups(targ_lbl)

    print("Number of building 'groups': ", len(lbl_groups))
    counter = 0
    size_dict = dict()
    for grp_id, g_lbl in enumerate(tqdm(lbl_groups), start=1):
        # pixel coordinates of the group; column 0 indexes rows of targ_rgb, column 1 its columns
        coords_lbl = np.asarray(get_stats_from_group(g_lbl))
        # area = number of pixels in the roof
        group_area = sum(k.area for k in g_lbl)
        size_dict[grp_id] = group_area

        if group_area <= small_threshold:
            roof_tex = random.choice(stl)
        elif group_area >= large_threshold:
            roof_tex = random.choice(ltl)
        else:
            roof_tex = random.choice(mtl)
        texture_area = roof_tex.shape[0] * roof_tex.shape[1]

        # Heuristic: only texture buildings no larger than the texture. Complex shapes can still
        # reach outside the texture box; those pixels are skipped below.
        if group_area > texture_area or len(coords_lbl) == 0:
            continue
        counter += 1

        row0, col0 = coords_lbl[0]
        rows, cols = coords_lbl[:, 0], coords_lbl[:, 1]
        tex_rows, tex_cols = rows - row0, cols - col0
        # Keep only pixels that map inside both the texture and the tile. Negative offsets used to
        # wrap around to the opposite edge of the texture (negative indexing) instead of being
        # skipped, and a per-pixel bare ``except`` hid every other error.
        valid = ((tex_rows >= 0) & (tex_rows < roof_tex.shape[0]) &
                 (tex_cols >= 0) & (tex_cols < roof_tex.shape[1]) &
                 (rows < targ_rgb.shape[0]) & (cols < targ_rgb.shape[1]))
        texels = roof_tex[tex_rows[valid], tex_cols[valid]]
        if roof_tex.ndim < targ_rgb.ndim:
            # single-channel texture: broadcast each value over the tile's channels
            texels = texels[..., None]
        try:
            targ_rgb[rows[valid], cols[valid]] = texels
        except ValueError as err:
            # e.g. texture and tile with different channel counts: leave this building unchanged
            print("could not implant texture on group {}: {}".format(grp_id, err))
    print(counter, " building roofs changed...")
    cv2.imwrite(implanted_out_path, targ_rgb)
    return size_dict
示例#19
0
文件: eval_utils.py 项目: xyt556/mrs
def read_results(result_name,
                 regex=None,
                 sum_results=False,
                 delta=1e-6,
                 class_names=None):
    """
    Read and parse an evaluation results text file.

    Each non-empty line is unpacked as ``name,iou_a,iou_b,<class_a,class_b>...,iou``
    (comma separated).
    :param result_name: path to the results file
    :param regex: if given, lines whose name matches it (``re.search``) are summed and summarized
    :param sum_results: if True (and no regex), summarize the line named 'Overall'
    :param delta: a small value added to denominators to prevent division by zero
    :param class_names: list of strings for class names; if None, they will be class_i
    :return: with ``regex`` or ``sum_results``, a dict with the overall 'iou' and one entry per
             class name, all in percent; otherwise a dict mapping each line name to its parsed fields
    :raises ValueError: if ``regex`` is given but matches no line
    """
    def update_results(res, n, i_res, c_names):
        if c_names is not None:
            assert len(i_res) == 2 * len(c_names)
        else:
            c_names = ['class_{}'.format(i) for i in range(len(i_res) // 2)]
        for cnt, c_name in enumerate(c_names):
            # store numbers, not the raw strings: combine_results adds these values, and adding
            # strings concatenated them (e.g. '12' + '34' -> '1234')
            res[n][c_name + '_a'] = float(i_res[cnt * 2])
            res[n][c_name + '_b'] = float(i_res[cnt * 2 + 1])
        return c_names

    def combine_results(res, i_res):
        # sum every field across lines except the per-line 'iou', which is recomputed from the sums
        if res is None:
            res = dict()
        for k, v in i_res.items():
            if k == 'iou':
                continue
            res[k] = res[k] + v if k in res else v
        return res

    def summarize_results(res):
        sum_res = dict()
        for c_name in class_names or []:
            # delta belongs in the denominator (as in the evaluator); it used to be added to the
            # ratio, and the per-class IoUs were computed but never returned
            sum_res[c_name] = res[c_name + '_a'] / (res[c_name + '_b'] + delta) * 100
        sum_res['iou'] = res['iou_a'] / (res['iou_b'] + delta) * 100
        return sum_res

    results = {}
    result_lines = misc_utils.load_file(result_name)

    for line in result_lines:
        if len(line) <= 1:
            continue
        name, iou_a, iou_b, *ious, iou = line.strip().split(',')
        iou_a, iou_b, iou = float(iou_a), float(iou_b), float(iou)
        results[name] = {'iou': iou, 'iou_a': iou_a, 'iou_b': iou_b}
        class_names = update_results(results, name, ious, class_names)
    if regex:
        comb_res = None
        for key, val in results.items():
            if re.search(regex, key):
                comb_res = combine_results(comb_res, val)
        if comb_res is None:
            raise ValueError('No result line matches regex {!r}'.format(regex))
        return summarize_results(comb_res)
    if sum_results:
        return summarize_results(results['Overall'])
    return results


for city in city_names:
    # Specify paths
    print("\nCity: ", city)
    rgb_file = r"/hdd/inria/train/images/" + city + r"1.tif"
    lbl_file = r"/hdd/inria/train/gt/" + city + r"1.tif"
    conf_file = r"/hdd/2019-bass-connections-aatb/mrs_new/results/ecresnet50_dcunet_dsinria_lre1e-03_lrd1e-02_ep25_bs5_ds50_dr0p1/" + city + r"1.npy"

    #load in file
    rgb = misc_utils.load_file(rgb_file)
    lbl_img, conf_img = misc_utils.load_file(
        lbl_file) / 255, misc_utils.load_file(conf_file)

    #Prepare to extracto objects (Using Bohao's class)
    osc = eval_utils.ObjectScorer(min_th=0.5, link_r=10, eps=2)

    #Get object groups
    print("Extracting the objects...")
    lbl_groups = osc.get_object_groups(lbl_img)
    conf_groups = osc.get_object_groups(conf_img)

    print("Number of object 'groups': ", len(lbl_groups))

    # Get the colors for everything in every building
    i = 0