예제 #1
0
def get_ious():
    """Sweep binarization thresholds and accumulate IoU terms over tiles.

    Loads the confidence maps (``*.npy``) saved under
    ``<task_dir>/conf_map_<model_name>`` and the first two ground-truth files
    of the ``aemo_pad`` collection (field ``aus50``, extension ``.*gt_d255``).
    Each confidence map is binarized at 1000 evenly spaced thresholds in
    [0, 1]. The two values returned by ``nn_utils.iou_metric(...,
    divide_flag=True)`` are summed over all tiles, per threshold.

    NOTE(review): the two values are presumably intersection and union, so
    IoU = row 1 / row 2; confirm against nn_utils.iou_metric.

    Relies on module-level ``task_dir`` and ``model_name``.

    :return: array of shape (3, 1000) holding thresholds, summed first term,
             and summed second term
    """
    conf_dir = os.path.join(task_dir, 'conf_map_{}'.format(model_name))
    conf_files = sorted(glob(os.path.join(conf_dir, '*.npy')))

    cm = collectionMaker.read_collection('aemo_pad')
    truth_files = cm.load_files(field_name='aus50',
                                field_id='',
                                field_ext='.*gt_d255')
    # only the first two ground-truth tiles are evaluated
    truth_files = [f[0] for f in truth_files[:2]]

    uniq_vals = np.linspace(0, 1, 1000)
    ious_a = np.zeros(len(uniq_vals))
    ious_b = np.zeros(len(uniq_vals))

    for conf, truth in zip(conf_files, truth_files):
        c = ersa_utils.load_file(conf)
        t = ersa_utils.load_file(truth)

        for cnt, th in enumerate(tqdm(uniq_vals)):
            # builtin int: the np.int alias was removed in NumPy 1.24
            c_th = (c > th).astype(int)

            a, b = nn_utils.iou_metric(c_th, t, truth_val=1, divide_flag=True)
            # accumulate over tiles; plain assignment kept only the last tile
            ious_a[cnt] += a
            ious_b[cnt] += b
    return np.stack([uniq_vals, ious_a, ious_b], axis=0)
예제 #2
0
def data_reader(file_list, chan_mean, th=1e-2):
    """Yield centered 224x224 RGB crops whose label coverage exceeds ``th``.

    :param file_list: iterable of (rgb_file, gt_file) path pairs
    :param chan_mean: value subtracted from every yielded RGB crop
    :param th: a crop is yielded only if sum(gt crop) / (224 * 224) > th
    :return: generator of (rgb crop - chan_mean, rgb file path)
    """
    crop_h, crop_w = 224, 224
    for rgb_path, gt_path in file_list:
        rgb_crop = crop_center(ersa_utils.load_file(rgb_path), crop_h, crop_w)
        gt_crop = crop_center(ersa_utils.load_file(gt_path), crop_h, crop_w)
        coverage = np.sum(gt_crop) / (crop_h * crop_w)
        if not coverage > th:
            continue
        yield rgb_crop - chan_mean, rgb_path
def make_dataset(rgb_files, info_dir, store_dir, tf_dir, city_name=''):
    """Cut tiles into patches and write them to train/valid TFRecord files.

    For each RGB tile, the patch layout is loaded from
    ``<info_dir>/<tile name>.npy``. This is an iterable of rows of cells; each
    cell provides 'h', 'w', 'label' and 'box' entries. For every cell, the
    first three channels of the ``PATCH_SIZE`` window are saved as
    ``<store_dir>/<tile name>_<n>.jpg`` and serialized with
    ``create_tf_example``.

    The city id is the last '_'-separated token of the tile name, or the
    second-to-last when the name contains 'NZ'. Tiles with city id <= 3 go to
    the validation record; all others go to the training record.

    :param rgb_files: paths of the RGB tiles
    :param info_dir: directory with the per-tile ``.npy`` patch descriptions
    :param store_dir: directory receiving the patch JPEGs
    :param tf_dir: directory receiving ``train_v2_<city_name>.record`` and
                   ``valid_v2_<city_name>.record``
    :param city_name: suffix of the record file names
    """
    from contextlib import closing

    train_record = os.path.join(tf_dir, 'train_v2_{}.record'.format(city_name))
    valid_record = os.path.join(tf_dir, 'valid_v2_{}.record'.format(city_name))
    # closing() guarantees both writers are closed even if a tile fails midway
    with closing(tf.python_io.TFRecordWriter(train_record)) as writer_train, \
            closing(tf.python_io.TFRecordWriter(valid_record)) as writer_valid:
        for rgb_file_name in rgb_files:
            file_name = os.path.basename(rgb_file_name[:-4])
            if 'NZ' not in file_name:
                city_id = int(file_name.split('_')[-1])
            else:
                city_id = int(file_name.split('_')[-2])
            if city_id <= 3:
                print('Processing file {} in validation set'.format(file_name))
                writer = writer_valid
            else:
                print('Processing file {} in training set'.format(file_name))
                writer = writer_train

            rgb = ersa_utils.load_file(rgb_file_name)
            npy_file_name = os.path.join(info_dir, file_name + '.npy')
            coords = ersa_utils.load_file(npy_file_name)

            patch_cnt = 0
            for line in coords:
                for cell in line:
                    patch_cnt += 1
                    save_name = os.path.join(
                        store_dir, '{}_{}.jpg'.format(file_name, patch_cnt))
                    img = rgb[cell['h']:cell['h'] + PATCH_SIZE[0],
                              cell['w']:cell['w'] + PATCH_SIZE[1], :3]
                    label = cell['label']
                    box = cell['box']

                    ersa_utils.save_file(save_name, img)

                    tf_example = create_tf_example(save_name, label, box)
                    writer.write(tf_example.SerializeToString())
예제 #4
0
def get_spcastats():
    """Average per-channel histogram of panel pixels over the AEMO tiles.

    For each RGB tile in module-level ``aemo_hist_files``, the ground truth
    ``<aemo_dir>/<basename without last 7 chars>gt_d255.tif`` is loaded. For
    each channel, the 255-bin histogram (edges 0..255) of ``rgb * gt`` is
    normalized by ``sum(gt)`` and accumulated. Tiles without ground-truth
    pixels add nothing, but still count in the final average.

    :return: float array of shape (3, 255)
    """
    print('Extracting panel pixels in aemo...')
    aemo_stats = np.zeros((3, 255))
    for rgb_file in tqdm(aemo_hist_files):
        gt_file = os.path.join(aemo_dir, os.path.basename(rgb_file[:-7]) + 'gt_d255.tif')
        rgb = ersa_utils.load_file(rgb_file)
        gt = ersa_utils.load_file(gt_file)

        gt_sum = np.sum(gt)
        if not gt_sum > 0:
            # no panel pixels in this tile: avoid dividing by zero
            continue
        for c in range(3):
            cnt, _ = np.histogram(rgb[:, :, c] * gt, bins=np.arange(256))
            aemo_stats[c, :] += cnt / gt_sum
    # normalize once, after the loop (it used to run on every iteration), by
    # the number of tiles actually iterated
    n_files = len(aemo_hist_files)
    if n_files > 0:
        aemo_stats = aemo_stats / n_files
    return aemo_stats
예제 #5
0
def load_data(dirs, model_name, city_id, tile_id, merge_range=100):
    """Load predictions, imagery and ground truth for one NZ tile.

    :param dirs: dict of directories with keys 'task', 'raw', 'conf' and
                 'line' (also forwarded to ``get_tower_truth_pred``)
    :param model_name: predictions are read from ``<dirs['task']>/<model_name>_v2``
    :param city_id: index into module-level ``city_list``
    :param tile_id: tile number
    :param merge_range: ``th`` passed to ``local_maxima_suppression``
    :return: (preds, raw_rgb, conf_img, line_gt, tower_gt, tower_pred,
             tower_conf); conf_img is resized to ``line_gt.shape``
    """
    city = city_list[city_id]
    # the conf / line-GT file names use the city name without spaces
    city_compact = city.replace(' ', '')
    pred_file_name = os.path.join(dirs['task'], model_name + '_v2',
                                  'NZ_{}_{}_resize.txt'.format(city, tile_id))
    preds = ersa_utils.load_file(pred_file_name)
    raw_rgb = ersa_utils.load_file(
        os.path.join(dirs['raw'], 'NZ_{}_{}_resize.tif'.format(city, tile_id)))
    conf_img = ersa_utils.load_file(
        os.path.join(dirs['conf'], '{}{}.png'.format(city_compact, tile_id)))
    line_gt = ersa_utils.load_file(
        os.path.join(dirs['line'], '{}{}_GT.png'.format(city_compact, tile_id)))
    tower_gt = get_tower_truth_pred(dirs, city_id, tile_id)
    tower_pred, tower_conf, _ = local_maxima_suppression(preds, th=merge_range)
    # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3; this call
    # needs an older SciPy (or a PIL-based replacement) to run.
    conf_img = scipy.misc.imresize(conf_img, line_gt.shape)
    return preds, raw_rgb, conf_img, line_gt, tower_gt, tower_pred, tower_conf
예제 #6
0
def make_dataset(rgb_files, info_dir, store_dir, city_name=''):
    """Cut tiles into patches and write a darknet-style image/label dataset.

    For each RGB tile, the patch layout is loaded from
    ``<info_dir>/<tile name>.npy``. This is an iterable of rows of cells; each
    cell provides 'h', 'w', 'label' and 'box' entries. For every cell:

    - the ``PATCH_SIZE`` window (all channels) is saved as
      ``<store_dir>/build/darknet/x64/data/obj/<tile name>_<n>.jpg``;
    - its label file ``..._<n>.txt`` is written with ``write_lbl``;
    - the image path is appended to the list file of its split.

    Tiles whose city id (last '_'-separated token of the name) is <= 3 are
    listed in ``data/test_<city_name>_T.txt``; the rest go to
    ``data/train_<city_name>_T.txt``.

    :param rgb_files: paths of the RGB tiles
    :param info_dir: directory with the per-tile ``.npy`` patch descriptions
    :param store_dir: darknet root directory
    :param city_name: suffix of the list file names
    """
    data_dir = os.path.join(store_dir, 'data')
    obj_dir = os.path.join(store_dir, 'build/darknet/x64/data/obj')
    train_list = os.path.join(data_dir, 'train_{}_T.txt'.format(city_name))
    valid_list = os.path.join(data_dir, 'test_{}_T.txt'.format(city_name))
    # with-block closes both list files even if a tile fails midway
    with open(train_list, 'w+') as writer_train, \
            open(valid_list, 'w+') as writer_valid:
        for rgb_file_name in rgb_files:
            file_name = os.path.basename(rgb_file_name[:-4])
            city_id = int(file_name.split('_')[-1])
            if city_id <= 3:
                print('Processing file {} in validation set'.format(file_name))
                writer = writer_valid
            else:
                print('Processing file {} in training set'.format(file_name))
                writer = writer_train

            rgb = ersa_utils.load_file(rgb_file_name)
            npy_file_name = os.path.join(info_dir, file_name + '.npy')
            coords = ersa_utils.load_file(npy_file_name)

            patch_cnt = 0
            for line in coords:
                for cell in line:
                    patch_cnt += 1
                    patch_stem = '{}_{}'.format(file_name, patch_cnt)
                    img_name = os.path.join(obj_dir, patch_stem + '.jpg')
                    lbl_name = os.path.join(obj_dir, patch_stem + '.txt')
                    img = rgb[cell['h']:cell['h'] + PATCH_SIZE[0],
                              cell['w']:cell['w'] + PATCH_SIZE[1], :]
                    label = cell['label']
                    box = cell['box']
                    ersa_utils.save_file(img_name, img)
                    write_lbl(lbl_name, label, box, PATCH_SIZE)
                    writer.write('{}\n'.format(img_name))
예제 #7
0
def get_spcastats():
    """Average per-channel histogram of panel pixels over the SPCA tiles.

    For each RGB tile in module-level ``spca_files``, the ground truth
    ``<path without last 7 chars>GT.png`` is loaded. For each channel, the
    255-bin histogram (edges 0..255) of ``rgb * gt`` is normalized by
    ``sum(gt)`` and accumulated. Tiles without ground-truth pixels add
    nothing, but still count in the average over ``len(spca_files)``.

    :return: float array of shape (3, 255); zeros when ``spca_files`` is empty
    """
    print('Extacting panel pixels in spca...')
    spca_stats = np.zeros((3, 255))
    for rgb_file in tqdm(spca_files):
        gt_file = rgb_file[:-7] + 'GT.png'
        rgb = ersa_utils.load_file(rgb_file)
        gt = ersa_utils.load_file(gt_file)

        # computed once per tile instead of twice per channel
        gt_sum = np.sum(gt)
        if not gt_sum > 0:
            continue
        for c in range(3):
            cnt, _ = np.histogram(rgb[:, :, c] * gt, bins=np.arange(256))
            spca_stats[c, :] += cnt / gt_sum
    if len(spca_files) > 0:
        spca_stats = spca_stats / len(spca_files)
    return spca_stats
def make_dataset_all(rgb_files, info_dir, store_dir, tf_dir):
    """Cut tiles of all cities into patches and write cross-city TFRecords.

    For each RGB tile, the patch layout is loaded from
    ``<info_dir>/<tile name>.npy``. This is an iterable of rows of cells; each
    cell provides 'h', 'w', 'label' and 'box' entries. For every cell, the
    first three channels of the ``PATCH_SIZE`` window are saved as
    ``<store_dir>/<tile name>_<n>.jpg`` and serialized with
    ``create_tf_example``.

    The city id is the last '_'-separated token of the tile name, or the
    second-to-last when the name contains 'NZ'. Tiles with city id <= 3 go to
    ``valid_v2_xcity.record``; the rest go to ``train_v2_xcity.record``.

    :param rgb_files: paths of the RGB tiles
    :param info_dir: directory with the per-tile ``.npy`` patch descriptions
    :param store_dir: directory receiving the patch JPEGs
    :param tf_dir: directory receiving the two record files
    """
    from contextlib import closing

    train_record = os.path.join(tf_dir, 'train_v2_xcity.record')
    valid_record = os.path.join(tf_dir, 'valid_v2_xcity.record')
    # closing() guarantees both writers are closed even if a tile fails midway
    with closing(tf.python_io.TFRecordWriter(train_record)) as writer_train, \
            closing(tf.python_io.TFRecordWriter(valid_record)) as writer_valid:
        for rgb_file_name in rgb_files:
            file_name = os.path.basename(rgb_file_name[:-4])
            if 'NZ' not in file_name:
                city_id = int(file_name.split('_')[-1])
            else:
                city_id = int(file_name.split('_')[-2])
            if city_id <= 3:
                print('Processing file {} in validation set'.format(file_name))
                writer = writer_valid
            else:
                print('Processing file {} in training set'.format(file_name))
                writer = writer_train

            rgb = ersa_utils.load_file(rgb_file_name)
            npy_file_name = os.path.join(info_dir, file_name + '.npy')
            coords = ersa_utils.load_file(npy_file_name)

            patch_cnt = 0
            for line in coords:
                for cell in line:
                    patch_cnt += 1
                    save_name = os.path.join(
                        store_dir, '{}_{}.jpg'.format(file_name, patch_cnt))
                    img = rgb[cell['h']:cell['h'] + PATCH_SIZE[0],
                              cell['w']:cell['w'] + PATCH_SIZE[1], :3]
                    label = cell['label']
                    box = cell['box']

                    ersa_utils.save_file(save_name, img)

                    tf_example = create_tf_example(save_name, label, box)
                    writer.write(tf_example.SerializeToString())
예제 #9
0
def compare_lines_rnn(task_dir):
    """Print per-city and overall F1 scores of RNN line-graph predictions.

    For each of four cities and tiles 1-3, two files are read from
    ``task_dir``:

    - the ground-truth connected pairs ``<city>_<tile>_cp.npy``;
    - the predicted connections ``graph_rnn_<city index>_<tile>_gt.npy``.

    They are scored with ``grid_score``. True positives and the precision /
    recall denominators it returns are pooled per city, and over all cities,
    before precision, recall and F1 are computed.

    :param task_dir: directory containing both kinds of ``.npy`` files
    """
    def _f1(tp, n_pred, n_true):
        # F1 from pooled counts; 0 instead of ZeroDivisionError when empty
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true if n_true else 0.0
        if precision + recall == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    city_list = ['AZ_Tucson', 'KS_Colwich_Maize', 'NC_Clyde', 'NC_Wilmington']
    tp_all, p_d_all, r_d_all = 0, 0, 0
    for city_id, city_name in enumerate(city_list):
        tp_city, pd_city, rd_city = 0, 0, 0
        for tile_id in [1, 2, 3]:
            # ground-truth pairs are per tile; formatting city_id here loaded
            # one and the same file for every tile of a city
            cp_file_name = os.path.join(task_dir, '{}_{}_cp.npy'.format(city_name, tile_id))
            connected_pairs = ersa_utils.load_file(cp_file_name)

            # conns_name already contains task_dir; joining it again broke
            # relative task_dir paths
            conns_name = os.path.join(task_dir, 'graph_rnn_{}_{}_gt.npy'.format(city_id, tile_id))
            conns = np.load(conns_name)

            _, tp, p_d, r_d = grid_score(conns, connected_pairs)
            tp_city += tp
            pd_city += p_d
            rd_city += r_d
        print('{}: f1={:.2f}'.format(city_name, _f1(tp_city, pd_city, rd_city)))

        tp_all += tp_city
        p_d_all += pd_city
        r_d_all += rd_city
    print('overall f1={:.2f}'.format(_f1(tp_all, p_d_all, r_d_all)))
예제 #10
0
 def process(self, **kwargs):
     """
     Make the new field by remapping pixel values of the selected files.

     Every pixel whose value is a key of ``self.switch_dict`` is set to the
     mapped value. All masks are computed from the ORIGINAL image before any
     assignment. This means chained or swapping mappings such as {0: 1, 1: 0}
     are applied simultaneously instead of cascading into each other.

     :param kwargs:
         file_list: list of single-file lists; if not given, the collection's
                    files with extension ``self.field_ext_pair[0]`` are used
         file_ext: new file extension, replacing the last '.'-separated part of
                   the output name; if not given, the old one is kept
         d_type: new data type; if not given, the loaded dtype is kept
     :return: None; every written path is appended to ``self.files``
     """
     if 'file_list' not in kwargs:
         file_list = self.clc.load_files(','.join(self.clc.field_name), ','.join(self.clc.field_id),
                                         self.field_ext_pair[0])
     else:
         file_list = kwargs['file_list']
     if len(file_list) > 0:
         assert len(file_list[0]) == 1
     # field extension without '.' and '*' (e.g. '.*gt_d255' -> 'gt_d255')
     old_field_ext = ''.join([a for a in self.field_ext_pair[0] if a != '.' and a != '*'])
     pbar = tqdm(file_list)
     for img_file in pbar:
         save_name = os.path.basename(img_file[0].replace(old_field_ext, self.field_ext_pair[1]))
         if 'file_ext' in kwargs:
             # user specified a new file extension: swap only the trailing one;
             # str.replace() also rewrote earlier occurrences of the same text
             stem, dot, _ = save_name.rpartition('.')
             save_name = (stem if dot else save_name) + '.' + kwargs['file_ext']
         save_name = os.path.join(self.path, save_name)
         pbar.set_description('Making {}'.format(os.path.basename(save_name)))
         img = ersa_utils.load_file(img_file[0])
         masks = [(img == old_val, new_val) for old_val, new_val in self.switch_dict.items()]
         for mask, new_val in masks:
             img[mask] = new_val
         if 'd_type' in kwargs:
             img = img.astype(kwargs['d_type'])
         ersa_utils.save_file(save_name, img)
         self.files.append(save_name)
예제 #11
0
 def process(self, **kwargs):
     """
     Extract patches from every group of files and record them.

     Each file is padded by ``self.pad`` and cut into ``self.patch_size``
     patches with ``self.overlap`` using ``make_grid`` / ``patch_block``.
     Every patch is cast to uint8 and saved in ``self.path`` as
     ``<file name up to first '.'>_y<y>x<x>.<ext>``. The patch lists of one
     group are combined with ``ersa_utils.rotate_list``; each resulting entry
     is written as one space-separated line of ``<self.path>/file_list.txt``.

     :param kwargs:
         file_list: list of lists of files, e.g. from collectionMaker.load_files()
         file_exts: patch extension for each position of the inner lists
     :return: None
     """
     file_exts = kwargs['file_exts']
     file_list = kwargs['file_list']
     assert len(file_exts) == len(file_list[0])
     grid_list = None
     if self.tile_size is not None:
         grid_list = make_grid(self.tile_size + 2*self.pad, self.patch_size, self.overlap)
     pbar = tqdm(file_list)
     # with-block closes the record file even if a file fails midway
     with open(os.path.join(self.path, 'file_list.txt'), 'w') as record_file:
         for files in pbar:
             pbar.set_description('Extracting {}'.format(os.path.basename(files[0])))
             patch_list = []
             for f, ext in zip(files, file_exts):
                 patch_list_ext = []
                 img = ersa_utils.load_file(f)
                 if self.tile_size is None:
                     grid_list = make_grid(np.array(img.shape[:2])+2*self.pad, self.patch_size, self.overlap)
                 stem = os.path.basename(f).split('.')[0]
                 # extract images
                 for patch, y, x in patch_block(img, self.pad, grid_list, self.patch_size, return_coord=True):
                     patch_name = os.path.join(self.path, '{}_y{}x{}.{}'.format(stem, int(y), int(x), ext))
                     ersa_utils.save_file(patch_name, patch.astype(np.uint8))
                     patch_list_ext.append(patch_name)
                 patch_list.append(patch_list_ext)
             for items in ersa_utils.rotate_list(patch_list):
                 record_file.write('{}\n'.format(' '.join(items)))
예제 #12
0
def data_reader(save_dir, chan_mean):
    """Yield mean-subtracted patches listed in ``<save_dir>/file_list.txt``.

    :param save_dir: directory containing ``file_list.txt`` (one patch path per line)
    :param chan_mean: value subtracted from every loaded patch
    :return: generator of (patch - chan_mean, patch file basename)
    """
    list_path = os.path.join(save_dir, 'file_list.txt')
    with open(list_path, 'r') as handle:
        patch_paths = [entry.strip() for entry in handle.readlines()]
    for patch_path in patch_paths:
        yield ersa_utils.load_file(patch_path) - chan_mean, os.path.basename(patch_path)
예제 #13
0
def get_pretrained_weights(weight_dir, model_dir, epoch=95):
    """Return UNet weights as a {variable name: value} dict, cached on disk.

    If ``<weight_dir>/weight.pkl`` exists, it is loaded and returned
    (``model_dir`` and ``epoch`` are then ignored). Otherwise the function:

    - builds a ``UnetModelCrop`` (32 start filters, 2 classes, module-level
      ``input_size``);
    - restores checkpoint ``epoch`` from ``model_dir``;
    - evaluates every trainable variable and pickles the dict to that path.

    In both cases the default TF graph is reset before returning.

    :param weight_dir: directory of the ``weight.pkl`` cache
    :param model_dir: checkpoint directory passed to ``model.load``
    :param epoch: checkpoint epoch to restore (default 95, as before)
    :return: dict mapping variable names to their values
    """
    save_name = os.path.join(weight_dir, 'weight.pkl')

    if not os.path.exists(save_name):
        X = tf.placeholder(tf.float32, shape=[None, input_size[0], input_size[1], 3], name='X')
        y = tf.placeholder(tf.int32, shape=[None, input_size[0], input_size[1], 1], name='y')
        mode = tf.placeholder(tf.bool, name='mode')
        model = uabMakeNetwork_UNet.UnetModelCrop({'X': X, 'Y': y},
                                                  trainable=mode,
                                                  input_size=input_size,
                                                  start_filter_num=32)
        model.create_graph('X', class_num=2)
        train_vars = list(tf.trainable_variables())

        with tf.Session() as sess:
            model.load(model_dir, sess, epoch=epoch)
            # fetch all variables in one run call instead of one call each
            values = sess.run(train_vars)
        weight_dict = {v.name: theta for v, theta in zip(train_vars, values)}
        ersa_utils.save_file(save_name, weight_dict)
    else:
        weight_dict = ersa_utils.load_file(save_name)

    tf.reset_default_graph()
    return weight_dict
예제 #14
0
 def data_reader_helper(self, files):
     """
     Load one row of patch files, stack them channel-wise and augment them.

     :param files: patch file names; they are loaded with ersa_utils.load_file
                   and stacked along the third axis in the given order
     :return: if ``self.has_gt`` and ``self.include_gt``: (features, labels);
              if ``self.has_gt`` only: features; otherwise the whole block.
              Labels are the last ``self.gt_dim`` channels, features are the
              remaining channels minus ``self.chan_mean``. Both augmentation
              lists (``self.global_func`` then ``self.aug_func``) are applied
              to the stacked block first.
     """
     stacked = np.dstack([ersa_utils.load_file(name) for name in files])
     for augment in self.global_func:
         stacked = augment(stacked)
     for augment in self.aug_func:
         stacked = augment(stacked)
     if not self.has_gt:
         return stacked - self.chan_mean
     features = stacked[:, :, :-self.gt_dim] - self.chan_mean
     if not self.include_gt:
         return features
     labels = stacked[:, :, -self.gt_dim:]
     return features, labels
예제 #15
0
def get_pretrained_weights(flags, epoch=95):
    """Return UNet weights as a {variable name: value} dict, cached on disk.

    If ``<flags.weight_dir>/weight.pkl`` exists, it is loaded and returned.
    Otherwise the function:

    - builds a ``UnetModelCrop`` configured from ``flags``;
    - restores checkpoint ``epoch`` from ``flags.model_dir``;
    - evaluates every global variable except the global step and pickles the
      dict to that path.

    In both cases the default TF graph is reset before returning.

    :param flags: object providing weight_dir, input_size, model_name,
                  batch_size, learning_rate, decay_step, decay_rate, epochs,
                  sfn, num_classes and model_dir
    :param epoch: checkpoint epoch to restore (default 95, as before)
    :return: dict mapping variable names to their values
    """
    save_name = os.path.join(flags.weight_dir, 'weight.pkl')

    if not os.path.exists(save_name):
        X = tf.placeholder(tf.float32, shape=[None, flags.input_size[0], flags.input_size[1], 3], name='X')
        y = tf.placeholder(tf.int32, shape=[None, flags.input_size[0], flags.input_size[1], 1], name='y')
        mode = tf.placeholder(tf.bool, name='mode')
        model = uabMakeNetwork_UNet.UnetModelCrop({'X': X, 'Y': y},
                                                  trainable=mode,
                                                  model_name=flags.model_name,
                                                  input_size=flags.input_size,
                                                  batch_size=flags.batch_size,
                                                  learn_rate=flags.learning_rate,
                                                  decay_step=flags.decay_step,
                                                  decay_rate=flags.decay_rate,
                                                  epochs=flags.epochs,
                                                  start_filter_num=flags.sfn)
        model.create_graph('X', class_num=flags.num_classes)
        train_vars = [v for v in tf.global_variables() if 'global_step' not in v.name]

        with tf.Session() as sess:
            model.load(flags.model_dir, sess, epoch=epoch)
            # fetch all variables in one run call instead of one call each
            values = sess.run(train_vars)
        weight_dict = {v.name: theta for v, theta in zip(train_vars, values)}
        ersa_utils.save_file(save_name, weight_dict)
    else:
        weight_dict = ersa_utils.load_file(save_name)

    tf.reset_default_graph()
    return weight_dict
예제 #16
0
 def read_meta_data(self):
     """
     Load the collection's metadata.

     :return: the object stored in ``<self.clc_dir>/meta.pkl``, as returned
              by ``ersa_utils.load_file``
     """
     meta_path = os.path.join(self.clc_dir, 'meta.pkl')
     return ersa_utils.load_file(meta_path)
예제 #17
0
def align_files(data_dir, save_dir, source_dist, target_dist):
    """Histogram-match every ``*.tif`` in ``data_dir`` and save into ``save_dir``.

    Each image is passed through ``cust_hist_match(target_dist, source_dist,
    image)``. The result is saved under the same basename in ``save_dir``.
    """
    for tif_path in glob(os.path.join(data_dir, '*.tif')):
        print('aligning {}'.format(tif_path))
        aligned = cust_hist_match(target_dist, source_dist, ersa_utils.load_file(tif_path))
        out_path = os.path.join(save_dir, os.path.basename(tif_path))
        ersa_utils.save_file(out_path, aligned)
예제 #18
0
def plot_pred_stats(gt_files, pred_files, model_name):
    """Scatter object size against accuracy for all tiles and save the figure.

    ``pred_stats(pred, gt)`` is evaluated for each (gt, prediction) file pair.
    The per-object sizes and accuracies of all tiles are concatenated and
    plotted to ``<img_dir>/panel_size_vs_acc_<model_name>.png``.
    """
    sizes_per_tile, acc_per_tile = [], []
    for gt_path, pred_path in zip(gt_files, pred_files):
        gt = ersa_utils.load_file(gt_path)
        pred = ersa_utils.load_file(pred_path)
        tile_sizes, tile_acc = pred_stats(pred, gt)
        sizes_per_tile.append(tile_sizes)
        acc_per_tile.append(tile_acc)
    all_sizes = np.concatenate(sizes_per_tile)
    all_acc = np.concatenate(acc_per_tile)
    plt.scatter(all_sizes, all_acc, s=8)
    plt.xlabel('Building Size')
    plt.ylabel('Accuracy')
    plt.title(model_name)
    plt.tight_layout()
    out_name = 'panel_size_vs_acc_{}.png'.format(model_name)
    plt.savefig(os.path.join(img_dir, out_name))
    plt.close()
예제 #19
0
def write_dataset(rgb_files, info_dir, img_dir, csv_dir, patch_size,
                  city_name):
    """Cut tiles into PNG patches and write one CSV row per labelled box.

    For each RGB tile, the patch layout is loaded from
    ``<info_dir>/<tile name>.npy``. Each cell provides 'h', 'w', 'label' and
    'box' entries. Every ``patch_size`` window is saved as
    ``<img_dir>/<tile name>_<n>.png``. Every (label, box) pair of the cell
    becomes one row of ``<csv_dir>/labels_<city_name>.csv``. Boxes are read
    as (ymin, xmin, ymax, xmax). Tiles whose city id (last '_'-separated token)
    is <= 3 are marked 'test'; all others are marked 'train'.

    :param patch_size: (height, width) of the patches; the first entry slices
                       rows, so it is the height
    """
    columns = ['filenames', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax',
               'ymax', 'train_test']
    # collect rows first: df.loc[patch_cnt] overwrote rows for patches with
    # several boxes and across tiles (patch_cnt restarts per tile)
    rows = []

    for rgb_file_name in rgb_files:
        file_name = os.path.basename(rgb_file_name[:-4])
        city_id = int(file_name.split('_')[-1])
        if city_id <= 3:
            print('Processing file {} in validation set'.format(file_name))
            train_test = 'test'
        else:
            print('Processing file {} in training set'.format(file_name))
            train_test = 'train'
        rgb = ersa_utils.load_file(rgb_file_name)
        npy_file_name = os.path.join(info_dir, file_name + '.npy')
        coords = ersa_utils.load_file(npy_file_name)

        patch_cnt = 0
        for line in coords:
            for cell in line:
                patch_cnt += 1
                img_name = '{}_{}.png'.format(file_name, patch_cnt)
                save_name = os.path.join(img_dir, img_name)
                img = rgb[cell['h']:cell['h'] + patch_size[0],
                          cell['w']:cell['w'] + patch_size[1], :]
                label = cell['label']
                box = cell['box']
                ersa_utils.save_file(save_name, img)

                for lbl, bbox in zip(label, box):
                    # width is patch_size[1] (columns), height is patch_size[0]
                    rows.append([
                        img_name, patch_size[1], patch_size[0], lbl,
                        bbox[1], bbox[0], bbox[3], bbox[2], train_test
                    ])
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(os.path.join(csv_dir, 'labels_{}.csv'.format(city_name)),
              index=False)
예제 #20
0
def plot_across_model(link_r=20,
                      model_names=('faster_rcnn', 'faster_rcnn_res101',
                                   'faster_rcnn_res50')):
    """Plot per-city precision-recall curves of several detection models.

    Each of four cities gets one subplot. For each model, the predictions of
    tiles 1-3 (``<task_dir>/<model>/USA_<city>_<tile>.txt``) are reduced with
    ``local_maxima_suppression``. They are then matched by ``radius_scoring``
    (radius ``link_r``) against the centers of the ground-truth polygons in
    ``<raw_dir>/USA_<city>_<tile>.csv``. The PR curve and average precision
    of the pooled tiles are drawn. The figure is saved to
    ``<img_dir>/cmp_tile_pr.png`` and shown.

    :param link_r: matching radius passed to ``radius_scoring``
    :param model_names: sub-directories of ``task_dir`` with each model's predictions
    """
    plt.figure(figsize=(10, 8))

    city_list = ['AZ_Tucson', 'KS_Colwich_Maize', 'NC_Clyde', 'NC_Wilmington']
    for city_id, city_name in enumerate(city_list):
        plt.subplot(221 + city_id)
        for model_name in model_names:
            pred_list_all = []
            gt_list_all = []
            cf_list_all = []
            for tile_id in [1, 2, 3]:
                # load data
                pred_file_name = os.path.join(
                    task_dir, model_name,
                    'USA_{}_{}.txt'.format(city_name, tile_id))
                preds = ersa_utils.load_file(pred_file_name)
                csv_file_name = os.path.join(
                    raw_dir, 'USA_{}_{}.csv'.format(city_name, tile_id))

                center_list, conf_list, _ = local_maxima_suppression(preds)
                for center, conf in zip(center_list, conf_list):
                    pred_list_all.append(center.tolist())
                    cf_list_all.append(conf)

                for _, bbox in read_polygon_csv_data(csv_file_name):
                    y, x = get_center_point(*bbox)
                    gt_list_all.append([y, x])

            _, y_true, y_score = radius_scoring(pred_list_all, gt_list_all,
                                                cf_list_all, link_r)
            ap = average_precision_score(y_true, y_score)
            precision, recall, _ = precision_recall_curve(y_true, y_score)
            plt.step(recall[1:],
                     precision[1:],
                     alpha=1,
                     where='post',
                     label='{}, AP={:.2f}'.format(model_name, ap))
        # axis decorations belong to the subplot: set them once after all
        # models are drawn instead of repeating them per model
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.ylim([0.0, 1.05])
        plt.xlim([0.0, 1.0])
        plt.title('{} Performance Comparison'.format(city_name))
        plt.legend(loc='lower left')
    plt.tight_layout()
    plt.savefig(os.path.join(img_dir, 'cmp_tile_pr.png'))
    plt.show()
예제 #21
0
def get_spcastats():
    """Average per-channel intensity histogram over the AEMO tiles.

    Sums the 255-bin histogram (edges 0..255) of each channel of every tile
    in module-level ``aemo_hist_files``, then divides by the number of tiles.

    :return: float array of shape (3, 255); zeros when there are no tiles
    """
    print('Extracting panel pixels in aemo...')
    aemo_stats = np.zeros((3, 255))
    for rgb_file in tqdm(aemo_hist_files):
        rgb = ersa_utils.load_file(rgb_file)

        for c in range(3):
            cnt, _ = np.histogram(rgb[:, :, c], bins=np.arange(256))
            aemo_stats[c, :] += cnt
    # divide once after the loop; dividing inside the loop rescaled the
    # running sum after every tile
    n_files = len(aemo_hist_files)
    if n_files > 0:
        aemo_stats = aemo_stats / n_files
    return aemo_stats
예제 #22
0
def get_aemo_stats(data_dir, suffix='*.tif'):
    """Sum per-channel intensity histograms of all images matching ``suffix``.

    Uses 255 bins with edges 0..255; the last bin is closed, so it counts both
    254 and 255. After summing:

    - bins 0 and 1 are zeroed;
    - the last bin is overwritten with the count of bin 253.

    NOTE(review): presumably this suppresses black no-data and saturated
    pixels; confirm.

    :param data_dir: directory searched with ``glob``
    :param suffix: glob pattern of the image files
    :return: float array of shape (3, 255)
    """
    hist = np.zeros((3, 255))
    edges = np.arange(256)
    for image_path in glob(os.path.join(data_dir, suffix)):
        image = ersa_utils.load_file(image_path)
        for channel in range(3):
            counts, _ = np.histogram(image[:, :, channel], bins=edges)
            hist[channel, :] += counts
    hist[:, :2] = 0
    hist[:, -1] = hist[:, -2]
    return hist
예제 #23
0
def process_files(save_dir, file_list, code_list, n_skip=5):
    """Apply ``makeup_aemo_img`` to each file and save it under a mirrored tree.

    The output directory of ``f`` is ``save_dir`` joined with the directory
    part of ``f`` after its first ``n_skip`` '/'-separated components. For an
    absolute path, the empty string before the leading '/' counts as one
    component. The directory is created if needed, and the processed image
    is saved there under the original basename.

    :param save_dir: root of the output tree
    :param file_list: paths of the input images
    :param code_list: code passed to ``makeup_aemo_img`` for each image
    :param n_skip: number of leading path components to drop (default 5, the
                   previously hard-coded depth)
    """
    for f, c in zip(file_list, code_list):
        print('processing: {} with code {}'.format(f,c))
        sub_dir = os.path.join(save_dir, '/'.join(f.split('/')[n_skip:-1]))
        ersa_utils.make_dir_if_not_exist(sub_dir)
        save_name = os.path.join(sub_dir, os.path.basename(f))

        rgb = ersa_utils.load_file(f)
        rgb_new = makeup_aemo_img(rgb, c)

        ersa_utils.save_file(save_name, rgb_new)
예제 #24
0
def load_data(dirs, model_name, city_id, tile_id, merge_range=100):
    """Load predictions, imagery and ground truth for one USA tile.

    :param dirs: dict of directories with keys 'task', 'raw', 'conf' and
                 'line'; ``dirs['conf']`` is a format string filled with a
                 per-city index (also forwarded to ``get_tower_truth_pred``)
    :param model_name: predictions are read from ``<dirs['task']>/<model_name>``
    :param city_id: index into module-level ``city_list``
    :param tile_id: tile number
    :param merge_range: ``th`` passed to ``local_maxima_suppression``
    :return: (preds, raw_rgb, conf_img, line_gt, tower_gt, tower_pred, tower_conf)
    """
    # city index -> value formatted into dirs['conf']
    conf_dir_ids = {0: 2, 1: 1, 2: 0, 3: 3}
    city = city_list[city_id]
    preds = ersa_utils.load_file(
        os.path.join(dirs['task'], model_name, 'USA_{}_{}.txt'.format(city, tile_id)))
    raw_rgb = ersa_utils.load_file(
        os.path.join(dirs['raw'], 'USA_{}_{}.tif'.format(city, tile_id)))
    conf_dir = dirs['conf'].format(conf_dir_ids[city_id])
    # conf / line-GT files use the second '_'-separated token of the city name
    city_short = city.split('_')[1]
    conf_img = ersa_utils.load_file(
        os.path.join(conf_dir, '{}{}.png'.format(city_short, tile_id)))
    line_gt = ersa_utils.load_file(
        os.path.join(dirs['line'], '{}{}_GT.png'.format(city_short, tile_id)))
    tower_gt = get_tower_truth_pred(dirs, city_id, tile_id)
    tower_pred, tower_conf, _ = local_maxima_suppression(preds, th=merge_range)
    return preds, raw_rgb, conf_img, line_gt, tower_gt, tower_pred, tower_conf
예제 #25
0
def get_spcastats(spca_dir=r'/media/ei-edl01/data/uab_datasets/spca/data/Original_Tiles',
                  n_samples=100):
    """Average per-channel intensity histogram of randomly sampled SPCA tiles.

    Up to ``n_samples`` ``*_RGB.jpg`` tiles of ``spca_dir`` are drawn in random
    order (no fixed seed). For each channel, the 255-bin histogram (edges
    0..255) is accumulated. The sum is divided by the number of tiles actually
    used.

    :param spca_dir: directory of the SPCA RGB tiles
    :param n_samples: maximum number of tiles to sample
    :return: float array of shape (3, 255); all zeros when no tile is found
    """
    spca_files = glob(os.path.join(spca_dir, '*_RGB.jpg'))
    idx = np.random.permutation(len(spca_files))[:n_samples]
    spca = np.zeros((3, 255))
    if len(idx) == 0:
        return spca
    for i in tqdm(idx):
        rgb = ersa_utils.load_file(spca_files[i])
        for c in range(3):
            cnt, _ = np.histogram(rgb[:, :, c], bins=np.arange(256))
            spca[c, :] += cnt
    # divide by the tiles actually sampled; a fixed 100 under-scaled the
    # average whenever fewer tiles were available
    spca = spca / len(idx)
    return spca
예제 #26
0
def plot_gt_panel_stats(gt_files):
    """Save a histogram of object sizes for every ground-truth tile.

    Object sizes come from ``get_objects(gt)``. Each histogram uses the bin
    edges 0..2499 and is written to ``<img_dir>/<tile>_panel_size_stats.png``.
    Here ``<tile>`` is the file basename without its last 9 characters.
    """
    size_bins = np.arange(2500)
    for gt_path in gt_files:
        tile_name = os.path.basename(gt_path)[:-9]
        object_sizes, _ = get_objects(ersa_utils.load_file(gt_path))
        plt.hist(object_sizes, bins=size_bins)
        plt.xlabel('Panel Size')
        plt.ylabel('Cnts')
        plt.title(tile_name)
        plt.tight_layout()
        out_path = os.path.join(img_dir, '{}_panel_size_stats.png'.format(tile_name))
        plt.savefig(out_path)
        plt.close()
예제 #27
0
def read_collection(clc_name=None,
                    clc_dir=None,
                    raw_data_path=None,
                    field_name=None,
                    field_id=None,
                    rgb_ext=None,
                    gt_ext=None,
                    file_ext=None,
                    files=None,
                    force_run=False):
    """
    Read a collection from its directory, or build it if it is not finished.

    :param clc_name: name of the collection; required when ``clc_dir`` is None
    :param clc_dir: directory of the collection; derived from ``clc_name``
                    via ``ersa_utils.get_block_dir`` when None
    :param raw_data_path: path to where the raw data are stored
    :param field_name: could be names of the cities, or another prefix of the images
    :param field_id: could be ids of the tiles, or another suffix of the images
    :param rgb_ext: name extensions marking non-ground-truth images; separate
                    several with ','
    :param gt_ext: name extension marking ground-truth images (at most one)
    :param file_ext: file extensions, separated by ','; one is enough if all
                     files share it
    :param files: files in raw_data_path to use; None finds all automatically
    :param force_run: force the collection maker to run even if it exists
    :return: a CollectionMaker. If the 'collection_maker' process in
             ``clc_dir`` is finished, it is rebuilt from ``<clc_dir>/meta.pkl``
             and every argument except ``force_run`` is ignored.
    """
    if clc_dir is None:
        assert clc_name is not None
        clc_dir = ersa_utils.get_block_dir('data', ['collection', clc_name])

    if not processBlock.BasicProcess('collection_maker', clc_dir).check_finish():
        # not finished yet: try to create the collection from the arguments
        return CollectionMaker(raw_data_path, field_name, field_id, rgb_ext,
                               gt_ext, file_ext, files, clc_name, force_run)

    meta_data = ersa_utils.load_file(os.path.join(clc_dir, 'meta.pkl'))
    stored_keys = ('raw_data_path', 'field_name', 'field_id', 'rgb_ext',
                   'gt_ext', 'file_ext', 'files', 'clc_name')
    return CollectionMaker(*(meta_data[key] for key in stored_keys),
                           force_run=force_run)
예제 #28
0
def check_stats(check_fig=False):
    """Mask blank image regions out of saved AEMO predictions, in place.

    For every directory ``p_dir`` yielded by ``get_file_list(root_dir)``:

    - the prediction maps (``*.png``) and confidence maps (``*.npy``) under
      ``<task_dir>/aemo_all/<p_dir components from index 7>`` are paired in
      sorted order;
    - each pair is multiplied by ``1 - get_blank_regions(rgb)`` of the
      matching RGB tile;
    - both maps are written back over the original files.

    :param check_fig: if True, show rgb / prediction / confidence side by side
                      before saving
    """
    root_dir = r'/home/lab/Documents/bohao/data/aemo_all/align/0584270470{}0_01'
    data_dir = r'/media/ei-edl01/data/aemo/TILES/'
    for _, p_dir in get_file_list(root_dir):
        dir_parts = p_dir.split('/')
        rgb_file_dir = os.path.join(data_dir, '/'.join(dir_parts[8:]))
        pred_save_dir = os.path.join(task_dir, 'aemo_all', '/'.join(dir_parts[7:]))
        pred_files = sorted(glob(os.path.join(pred_save_dir, '*.png')))
        conf_files = sorted(glob(os.path.join(pred_save_dir, '*.npy')))
        for pred_path, conf_path in zip(pred_files, conf_files):
            print('Processing file {}'.format(pred_path))
            pred = ersa_utils.load_file(pred_path)
            conf = ersa_utils.load_file(conf_path)
            # drop a 5-char prefix and the 'png' extension (dot kept), then add 'tif'
            rgb_stem = os.path.basename(pred_path)[5:-3]
            rgb = ersa_utils.load_file(os.path.join(rgb_file_dir, rgb_stem) + 'tif')
            keep_mask = 1 - get_blank_regions(rgb)
            pred = keep_mask * pred
            conf = keep_mask * conf
            if check_fig:
                visualize_utils.compare_three_figure(rgb, pred, conf)
            ersa_utils.save_file(pred_path, pred)
            ersa_utils.save_file(conf_path, conf)
예제 #29
0
def make_patches(files, patch_size, save_dir, overlap=0):
    """Cut every tile into patches and list all patch paths in one file.

    For each tile, the patches produced by ``patchExtractor.patch_block``
    (pad ``overlap // 2``) are saved as ``<save_dir>/<tile>_<nnnn>.jpg``.
    Here ``<tile>`` is the first two '_'-separated tokens of the basename.
    Every patch path is appended to ``<save_dir>/file_list.txt``.

    :param files: paths of the RGB tiles
    :param patch_size: patch size passed to make_grid / patch_block
    :param save_dir: output directory for patches and the file list
    :param overlap: overlap between neighbouring patches
    """
    list_path = os.path.join(save_dir, 'file_list.txt')
    # open the list once: reopening it with 'w+' inside the per-tile loop
    # truncated it, so only the last tile's patches were recorded
    with open(list_path, 'w+') as record:
        for tile_path in tqdm(files):
            tile_name = '_'.join(os.path.basename(tile_path).split('_')[:2])
            rgb = ersa_utils.load_file(tile_path)
            h, w, _ = rgb.shape
            grid = patchExtractor.make_grid((h, w), patch_size, overlap)
            for cnt, patch in enumerate(
                    patchExtractor.patch_block(rgb, overlap // 2, grid,
                                               patch_size)):
                file_name = '{}_{:04d}.jpg'.format(tile_name, cnt)
                patch_path = os.path.join(save_dir, file_name)
                ersa_utils.save_file(patch_path, patch)
                record.write('{}\n'.format(patch_path))
예제 #30
0
def read_data(data_dir, city_name):
    """Print image, tower and line counts and the imaged area of one city.

    Towers come from ``read_tower_truth`` and lines from ``read_lines_truth``
    for every ground-truth file. The area is the sum of height * width over
    the RGB tiles, scaled by (0.3 * 1e-3)^2. NOTE(review): this presumably
    means 0.3 m pixels, giving km^2; confirm.
    """
    gt_files = get_gt_files2(data_dir, city_name)
    print('City: {}, #Images: {}'.format(city_name, len(gt_files)))
    n_tower, n_line = 0, 0
    for gt_path in gt_files:
        towers = read_tower_truth(gt_path)
        lines = read_lines_truth(gt_path, towers)
        n_tower += len(towers)
        n_line += len(lines)
    print('\t #Tower:{}, #Line:{}'.format(n_tower, n_line))
    city_area = 0
    for rgb_path in get_rgb_files(data_dir, city_name):
        shape = ersa_utils.load_file(rgb_path).shape
        n_pixels = shape[0] * shape[1]
        city_area += n_pixels * 0.3 * 1e-3 * 0.3 * 1e-3
    print(city_area)