Ejemplo n.º 1
0
def get_ious():
    """Sweep binarisation thresholds over confidence maps and accumulate IoU terms.

    Loads every ``*.npy`` confidence map from ``<task_dir>/conf_map_<model_name>``
    (``task_dir`` and ``model_name`` are module-level globals) and pairs them, in
    sorted order, with the ground-truth files of field ``aus50`` in the
    ``aemo_pad`` collection. Only the first two ground-truth files are used, and
    ``zip`` stops at the shorter of the two lists.

    For each of 1000 evenly spaced thresholds in [0, 1], the confidence map is
    binarised with ``c > th``. ``nn_utils.iou_metric(..., divide_flag=True)`` is
    then called, and its two outputs are summed over all tile pairs.

    :return: array of shape (3, 1000). Row 0 holds the thresholds; rows 1 and 2
             hold the accumulated first and second outputs of
             ``nn_utils.iou_metric`` (presumably intersection and union pixel
             counts, so IoU = row1 / row2 -- confirm in nn_utils).
    """
    conf_dir = os.path.join(task_dir, 'conf_map_{}'.format(model_name))
    conf_files = sorted(glob(os.path.join(conf_dir, '*.npy')))

    cm = collectionMaker.read_collection('aemo_pad')
    truth_files = cm.load_files(field_name='aus50',
                                field_id='',
                                field_ext='.*gt_d255')
    truth_files = [f[0] for f in truth_files[:2]]

    uniq_vals = np.linspace(0, 1, 1000)
    ious_a = np.zeros(len(uniq_vals))
    ious_b = np.zeros(len(uniq_vals))

    for conf, truth in zip(conf_files, truth_files):
        c = ersa_utils.load_file(conf)
        t = ersa_utils.load_file(truth)

        for cnt, th in enumerate(tqdm(uniq_vals)):
            # np.int was removed in NumPy 1.24; the builtin int is the equivalent dtype.
            c_th = (c > th).astype(int)

            a, b = nn_utils.iou_metric(c_th, t, truth_val=1, divide_flag=True)
            # Sum over tiles: plain assignment would keep only the last tile's terms.
            ious_a[cnt] += a
            ious_b[cnt] += b
    return np.stack([uniq_vals, ious_a, ious_b], axis=0)
Ejemplo n.º 2
0
def main(flags):
    """Train a U-Net on the Inria collection configured by *flags*.

    The collection is built from ``flags.data_dir`` (cities austin, chicago,
    kitsap, tyrol-w and vienna; tile ids 0-36). A ``gt_d255`` channel is derived
    from ``GT`` with factor 1/255 and registered in its place via
    ``replace_channel``. Tile ids 6-36 are used for training and ids 0-5 for
    validation. The training duration is printed in hours.

    :param flags: parsed options. Fields read here: num_classes, patch_size,
                  tile_size, model_suffix, learning_rate, decay_step, decay_rate,
                  epochs, batch_size, data_dir, ds_name, val_mult, n_train,
                  n_valid, model_par_dir, verb_step and save_epoch.
    """
    # The GPU id comes from the module-level GPU constant, not from flags.
    nn_utils.set_gpu(GPU)

    # define network
    model = unet.UNet(flags.num_classes, flags.patch_size, suffix=flags.model_suffix, learn_rate=flags.learning_rate,
                      decay_step=flags.decay_step, decay_rate=flags.decay_rate, epochs=flags.epochs,
                      batch_size=flags.batch_size)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                         field_name='austin,chicago,kitsap,tyrol-w,vienna',
                                         field_id=','.join(str(i) for i in range(37)),
                                         rgb_ext='RGB',
                                         gt_ext='GT',
                                         file_ext='tif',
                                         force_run=False,
                                         clc_name=flags.ds_name)
    gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1 / 255, ['GT', 'gt_d255']). \
        run(force_run=False, file_ext='png', d_type=np.uint8, )
    cm.replace_channel(gt_d255.files, True, ['GT', 'gt_d255'])
    cm.print_meta_data()
    file_list_train = cm.load_files(field_id=','.join(str(i) for i in range(6, 37)), field_ext='RGB,gt_d255')
    file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(6)), field_ext='RGB,gt_d255')
    # Only the first three channel means are used to normalise the network input.
    chan_mean = cm.meta_data['chan_mean'][:3]

    # Overlapping patches (padding overlap // 2) under separate train/valid names.
    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_train',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_valid',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size, chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.compile(feature, label, flags.n_train, flags.n_valid, flags.patch_size, ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir, val_mult=flags.val_mult, loss_type='xent')
    train_hook = hook.ValueSummaryHook(flags.verb_step, [model.loss, model.lr_op],
                                       value_names=['train_loss', 'learning_rate'], print_val=[0])
    model_save_hook = hook.ModelSaveHook(model.get_epoch_step()*flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHook(model.get_epoch_step(), [model.loss, model.loss_iou],
                                            value_names=['valid_loss', 'valid_mIoU'], log_time=True,
                                            run_time=model.n_valid, iou_pos=1)
    # Add back exactly the mean that was subtracted from the input images. The full
    # meta-data mean may hold more than three entries (it is sliced above), and then
    # it would not match the three-channel images.
    image_hook = hook.ImageValidSummaryHook(model.input_size, model.get_epoch_step(), feature, label, model.output,
                                            nn_utils.image_summary, img_mean=chan_mean)
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op, valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time)/3600))
Ejemplo n.º 3
0
 def __init__(self, path, switch_dict, field_ext_pair, name):
     """
     :param path: directory of the collection
     :param switch_dict: value-replacement mapping used by self.process; presumably maps an
                         original channel value to its new value (e.g. {3: 1, 2: 0}) -- confirm
     :param field_ext_pair: a list where the first element is the field extension to be operated on and the second
                            field extension is the name of the new field
     :param name: name of this process block; also the last component of its output directory
     """
     func = self.process
     self.switch_dict = switch_dict
     self.field_ext_pair = field_ext_pair
     self.clc = collectionMaker.read_collection(clc_dir=path)
     # Outputs go to data/preprocess/<collection dir name>/<name>.
     path = ersa_utils.get_block_dir('data', ['preprocess', os.path.basename(path), name])
     self.files = []
     super().__init__(name, path, func)
Ejemplo n.º 4
0
 def __init__(self, path, mult_factor, field_ext_pair):
     """
     Set up a block that scales one channel of a collection by a constant factor.

     :param path: directory of the collection
     :param mult_factor: constant number the channel is multiplied by
     :param field_ext_pair: a list where the first element is the field extension to be operated on and the second
                            field extension is the name of the new field
     """
     # The block name encodes the factor in a filesystem-friendly way:
     # factors >= 1 use 5-decimal fixed point with '.' -> 'p' (2 -> 'chan_mult_2p00000');
     # all others use a reduced fraction with '/' -> 'd' (1/255 -> 'chan_mult_1d255').
     if mult_factor >= 1:
         factor_tag = '{:.5f}'.format(mult_factor).replace('.', 'p')
     else:
         factor_tag = str(Fraction(mult_factor).limit_denominator()).replace('/', 'd')
     name = 'chan_mult_{}'.format(factor_tag)

     self.mult_factor = mult_factor
     self.field_ext_pair = field_ext_pair
     self.clc = collectionMaker.read_collection(clc_dir=path)
     block_dir = ersa_utils.get_block_dir('data', ['preprocess', os.path.basename(path), name])
     self.files = []
     super().__init__(name, block_dir, self.process)
import uab_collectionFunctions
from nn import unet, nn_utils
from collection import collectionMaker

# Settings for evaluating a trained U-Net checkpoint on validation tiles of the 'inria' collection.
class_num = 2
patch_size = (572, 572)
tile_size = (5000, 5000)
batch_size = 1
gpu = 0  # NOTE(review): not referenced in the visible part of this script
# Checkpoint directory handed to model.evaluate below.
model_dir = r'/hdd6/Models/Inria_decay/UnetCrop_inria_decay_0_PS(572, 572)_BS5_EP100_LR0.0001_DS60.0_DR0.1_SFN32'

cm = collectionMaker.read_collection('inria')
cm.print_meta_data()

# Evaluation tiles: field ids 0-4 with the RGB and gt_d255 channels.
# NOTE(review): range(5) yields ids 0-4, while the Inria training code elsewhere validates
# on range(6) (ids 0-5) -- confirm which split is intended.
file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(5)),
                                field_ext='RGB,gt_d255')
chan_mean = cm.meta_data['chan_mean']

# Legacy uab collection: its metadata is read, but the channel means in use come from cm above.
blCol = uab_collectionFunctions.uabCollection('inria')
blCol.readMetadata()
#chan_mean = blCol.getChannelMeans([0, 1, 2])

# Presumably lowers TensorFlow's log verbosity -- confirm in nn_utils.
nn_utils.tf_warn_level(3)
model = unet.UNet(class_num, patch_size)
model.evaluate(file_list_valid,
               patch_size,
               tile_size,
               batch_size,
               chan_mean,
               model_dir,
Ejemplo n.º 6
0
        return tile_dict, field_dict, overall


# settings
class_num = 2
patch_size = (572, 572)
tile_size = (5000, 5000)
suffix = 'aemo_hist'
bs = 5
gpu = 0  # NOTE(review): not referenced in the visible part of this script

# define network
# NOTE(review): this rebinds the module name `unet` to a UNet instance, so `unet.UNet`
# is unreachable afterwards. The estimator call below relies on the new binding, so a
# rename must update both places together.
unet = unet.UNet(class_num, patch_size, suffix=suffix, batch_size=bs)
overlap = unet.get_overlap()

cm = collectionMaker.read_collection('aemo_hist2')
cm.print_meta_data()

# Fields aus10/aus30 for training, aus50 for validation; each entry pairs an '*rgb' and a '*gt' file.
file_list_train = cm.load_files(field_name='aus10,aus30',
                                field_id='',
                                field_ext='.*rgb,.*gt')
file_list_valid = cm.load_files(field_name='aus50',
                                field_id='',
                                field_ext='.*rgb,.*gt')
chan_mean = cm.meta_data['chan_mean']

nn_utils.tf_warn_level(3)
estimator = myNNEstimatorSegment(unet,
                                 file_list_valid,
                                 patch_size,
                                 tile_size,
Ejemplo n.º 7
0
import cv2
import imageio
import numpy as np
import matplotlib.pyplot as plt
from collection import collectionMaker as cm


def get_contour(image, contour_length=5):
    """Return a mask with the outlines of the non-zero regions of *image* drawn on it.

    :param image: single-channel image handed to ``cv2.findContours`` (OpenCV expects
                  8-bit single-channel input for this retrieval mode).
    :param contour_length: line thickness, in pixels, of the drawn outlines
                           (forwarded as the ``thickness`` argument of ``cv2.drawContours``).
    :return: ``np.uint8`` array with ``image.shape``; outline pixels are 1, all others 0.
    """
    # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3.x but
    # (contours, hierarchy) in 2.x and 4.x. The contours are always second-to-last.
    contours = cv2.findContours(image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    mask = np.zeros(image.shape, np.uint8)
    cv2.drawContours(mask, contours, -1, 1, contour_length)
    return mask


if __name__ == '__main__':
    # Show the outline mask of the ground truth of tile austin 1 in the Inria collection.
    ds = cm.read_collection('Inria')
    ds.print_meta_data()
    # load_files returns one list of paths per matching tile; take the first tile's
    # only requested extension so imread receives a path, not a list.
    gt_file = ds.load_files(field_name='austin', field_id='1', field_ext='gt_d255')[0][0]

    gt = imageio.imread(gt_file)

    mask = get_contour(gt)
    plt.imshow(mask)
    plt.show()
Ejemplo n.º 8
0
            block_size = 100
            tf.reset_default_graph()
            record_matrix = []

            img_dir, task_dir = sis_utils.get_task_img_folder()
            save_file_name = os.path.join(
                task_dir,
                'corr_{}{}_ps{}_bs{}.npy'.format(field_name, field_id,
                                                 patch_size, block_size))

            blCol = uab_collectionFunctions.uabCollection('inria')
            img_mean = blCol.getChannelMeans([0, 1, 2])

            # load data
            if not os.path.exists(save_file_name) or force_run:
                clc = cm.read_collection(clc_name='Inria')
                clc.print_meta_data()
                files = clc.load_files(field_name=field_name,
                                       field_id=field_id,
                                       field_ext='RGB,gt_d255')
                rgb = ersa_utils.load_file(files[0][0])
                gt = ersa_utils.load_file(files[0][1])
                reader = overlap_reader(rgb, gt, y, x, patch_size, stride,
                                        block_size)

                # make the model
                # define place holder
                X = tf.placeholder(tf.float32,
                                   shape=[None, patch_size, patch_size, 3],
                                   name='X')
                y = tf.placeholder(tf.int32,
Ejemplo n.º 9
0
                '=', 50, 'Overall={:.2f}'.format(overall * 100)))
        return tile_dict, field_dict, overall


# define network
# NOTE(review): class_num, suffix and bs are presumably defined earlier in the script (not visible here).
patch_size = (572, 572)
model = unet.UNet(class_num, patch_size, suffix=suffix, batch_size=bs)
overlap = model.get_overlap()
data_dir = r'/media/ei-edl01/data/uab_datasets/infrastructure/data/Original_Tiles'
model_dir = r'/hdd6/Models/infrastructure/unet_5objs_weight100_PS(572, 572)_BS5_EP100_LR0.0001_DS60_DR0.1'

# Collection over four cities, tile ids 1-15, RGB + GT channels, named after suffix.
cm = collectionMaker.read_collection(
    raw_data_path=data_dir,
    field_name='Tucson,Colwich,Clyde,Wilmington',
    field_id=','.join(str(i) for i in range(1, 16)),
    rgb_ext='RGB',
    gt_ext='GT',
    file_ext='tif,png',
    force_run=False,
    clc_name=suffix)
gt_d255 = collectionEditor.SingleChanSwitch(cm.clc_dir, {
    2: 0,
    3: 1,
    4: 0,
    5: 0,
    6: 0,
    7: 0
}, ['GT', 'GT_switch'], 'tower_only').run(
    force_run=False,
    file_ext='png',
    d_type=np.uint8,
Ejemplo n.º 10
0
        cdfsrc = im_hist_s.cumsum()
        cdfsrc = (255 * cdfsrc / cdfsrc[-1]).astype(np.uint8)
        cdftint = im_hist_t.cumsum()
        cdftint = (255 * cdftint / cdftint[-1]).astype(np.uint8)

        im2 = np.interp(img_s[:, :, d].flatten(), bins[:-1], cdfsrc)
        im3 = np.interp(im2, cdftint, bins[:-1])
        im_res[:, :, d] = im3.reshape((img_s.shape[0], img_s.shape[1]))
    return im_res


if __name__ == '__main__':
    img_dir, task_dir = sis_utils.get_task_img_folder()

    # get aemo stats: per-channel pixel-value histograms averaged over the first 6 RGB tiles
    cm = collectionMaker.read_collection('aemo_pad')
    cm.print_meta_data()

    aemo_files = cm.load_files(field_name='aus10,aus30,aus50', field_id='', field_ext='.*rgb')
    rgb_files = aemo_files[:6]

    # 255 bins per channel. The edges 0..255 give bins [0,1), ..., [253,254) and a closed
    # final bin [254,255], so values 254 and 255 are counted together in the last bin.
    aemo = np.zeros((3, 255))
    for rgb_file in rgb_files:
        rgb = ersa_utils.load_file(rgb_file[0])
        for c in range(3):
            rgb_cnt, _ = np.histogram(rgb[:, :, c], bins=np.arange(256))
            aemo[c, :] += rgb_cnt
    aemo = aemo / len(rgb_files)  # mean histogram per tile
    # Zero the counts for values 0 and 1, and copy the neighbouring bin over the combined
    # 254/255 bin. NOTE(review): presumably this suppresses padding / saturated pixels -- confirm.
    aemo[:, :2] = 0
    aemo[:, -1] = aemo[:, -2]
Ejemplo n.º 11
0
]

# define network
for model_dir in model_dirs:
    tf.reset_default_graph()
    model = pspnet.PSPNet(class_num,
                          patch_size,
                          suffix=suffix,
                          batch_size=batch_size)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(raw_data_path=data_dir,
                                         field_name='Fresno,Modesto,Stockton',
                                         field_id=','.join(
                                             str(i) for i in range(663)),
                                         rgb_ext='RGB',
                                         gt_ext='GT',
                                         file_ext='jpg,png',
                                         force_run=False,
                                         clc_name=ds_name)
    cm.print_meta_data()
    file_list_train = cm.load_files(field_id=','.join(
        str(i) for i in range(0, 250)),
                                    field_ext='RGB,GT')
    file_list_valid = cm.load_files(field_id=','.join(
        str(i) for i in range(250, 500)),
                                    field_ext='RGB,GT')
    chan_mean = cm.meta_data['chan_mean'][:3]

    nn_utils.tf_warn_level(3)
    model.evaluate(file_list_valid,
Ejemplo n.º 12
0
# define network
# NOTE(review): despite its name, `unet` holds a DeepLab model. class_num, patch_size,
# suffix, lr, ds, dr, epochs, bs and ds_name are presumably defined earlier (not visible here).
unet = deeplab.DeepLab(class_num,
                       patch_size,
                       suffix=suffix,
                       learn_rate=lr,
                       decay_step=ds,
                       decay_rate=dr,
                       epochs=epochs,
                       batch_size=bs)
overlap = unet.get_overlap()

# Inria collection from the raw tiles: five cities, tile ids 0-36, RGB + GT channels.
cm = collectionMaker.read_collection(
    raw_data_path=
    r'/media/ei-edl01/data/uab_datasets/inria/data/Original_Tiles',
    field_name='austin,chicago,kitsap,tyrol-w,vienna',
    field_id=','.join(str(i) for i in range(37)),
    rgb_ext='RGB',
    gt_ext='GT',
    file_ext='tif',
    force_run=False,
    clc_name=ds_name)
# Scale the GT channel by 1/255 into a new 'gt_d255' channel saved as uint8 PNGs, then
# register it with the collection in place of 'GT'.
gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1/255, ['GT', 'gt_d255']).\
    run(force_run=False, file_ext='png', d_type=np.uint8,)
cm.replace_channel(gt_d255.files, True, ['GT', 'gt_d255'])
cm.print_meta_data()
# Tile ids 6-36 for training, ids 0-5 for validation.
file_list_train = cm.load_files(field_id=','.join(
    str(i) for i in range(6, 37)),
                                field_ext='RGB,gt_d255')
file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(6)),
                                field_ext='RGB,gt_d255')
chan_mean = cm.meta_data['chan_mean'][:3]  # first three channel means only
Ejemplo n.º 13
0
lr = 1e-4

nn_utils.set_gpu(gpu)

# define network
# NOTE(review): class_num, patch_size, tile_size, suffix, ds, dr, epochs, bs, gpu and
# ds_name are presumably defined earlier in the script (not visible here).
model = unet.UNet(class_num,
                  patch_size,
                  suffix=suffix,
                  learn_rate=lr,
                  decay_step=ds,
                  decay_rate=dr,
                  epochs=epochs,
                  batch_size=bs)
overlap = model.get_overlap()

cm = collectionMaker.read_collection(ds_name)
cm.print_meta_data()

# Tile ids 6-36 for training, ids 0-4 for validation.
# NOTE(review): id 5 falls in neither split, while the other Inria training code
# validates on range(6) -- confirm range(5) is intended.
file_list_train = cm.load_files(field_id=','.join(
    str(i) for i in range(6, 37)),
                                field_ext='RGB,gt_d255')
file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(5)),
                                field_ext='RGB,gt_d255')

# Cut the tiles into overlapping patches (padding overlap // 2) under separate
# '<ds_name>_train' / '<ds_name>_valid' names; force_run=False presumably reuses earlier output.
patch_list_train = patchExtractor.PatchExtractor(patch_size, tile_size, ds_name+'_train',
                                                 overlap, overlap//2).\
    run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
patch_list_valid = patchExtractor.PatchExtractor(patch_size, tile_size, ds_name+'_valid',
                                                 overlap, overlap//2).\
    run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()
chan_mean = cm.meta_data['chan_mean']
Ejemplo n.º 14
0
def main(flags):
    """Run a sweep of U-Net trainings on the AEMO collection.

    For each ``start_layer`` in ``flags.start_layer``, each learning rate in
    ``flags.learn_rate`` and run ids 0-3, a fresh TensorFlow graph is built.
    NumPy and TensorFlow are seeded with the run id, and the model is trained on
    fields aus10/aus30 and validated on aus50.

    When ``start_layer < 10``, ``compile`` receives
    ``train_var_filter=['layerup<start_layer>', ..., 'layerup9']``; otherwise no
    filter is passed. Unless ``flags.from_scratch`` is set, weights are loaded
    from ``flags.model_dir`` before training. Each run prints its duration in hours.

    :param flags: parsed options. Fields read here: GPU, start_layer, learn_rate,
                  from_scratch, num_classes, patch_size, tile_size, decay_step,
                  decay_rate, epochs, batch_size, data_dir, ds_name, val_mult,
                  n_train, n_valid, par_dir, verb_step, save_epoch, model_dir.
    """
    nn_utils.set_gpu(flags.GPU)

    # The collection and its field split do not depend on any sweep parameter,
    # so they are read once instead of once per run.
    cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                         field_name='aus10,aus30,aus50',
                                         field_id='',
                                         rgb_ext='.*rgb',
                                         gt_ext='.*gt',
                                         file_ext='tif',
                                         force_run=False,
                                         clc_name=flags.ds_name)
    cm.print_meta_data()

    file_list_train = cm.load_files(field_name='aus10,aus30', field_id='', field_ext='.*rgb,.*gt')
    file_list_valid = cm.load_files(field_name='aus50', field_id='', field_ext='.*rgb,.*gt')
    chan_mean = cm.meta_data['chan_mean']

    for start_layer in flags.start_layer:
        if start_layer >= 10:
            suffix_base = 'aemo_newloss'
        else:
            suffix_base = 'aemo_newloss_up{}'.format(start_layer)
        if flags.from_scratch:
            suffix_base += '_scratch'
        for lr in flags.learn_rate:
            for run_id in range(4):
                suffix = '{}_{}'.format(suffix_base, run_id)
                tf.reset_default_graph()

                # Seed after resetting so the graph-level seed applies to the new graph.
                np.random.seed(run_id)
                tf.set_random_seed(run_id)

                # define network
                model = unet.UNet(flags.num_classes, flags.patch_size, suffix=suffix, learn_rate=lr,
                                  decay_step=flags.decay_step, decay_rate=flags.decay_rate,
                                  epochs=flags.epochs, batch_size=flags.batch_size)
                overlap = model.get_overlap()

                # Patch extraction depends on the network overlap, so it stays per run;
                # with force_run=False earlier output is presumably reused.
                patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size,
                                                                 flags.ds_name + '_train_hist',
                                                                 overlap, overlap // 2). \
                    run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
                patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size,
                                                                 flags.ds_name + '_valid_hist',
                                                                 overlap, overlap // 2). \
                    run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()

                train_init_op, valid_init_op, reader_op = \
                    dataReaderSegmentation.DataReaderSegmentationTrainValid(
                        flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
                        chan_mean=chan_mean,
                        aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
                        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
                feature, label = reader_op

                model.create_graph(feature)
                # Below layer 10, restrict training to the 'layerup<i>' variables for i in
                # [start_layer, 10); otherwise compile with its default variable set.
                compile_kwargs = {}
                if start_layer < 10:
                    compile_kwargs['train_var_filter'] = ['layerup{}'.format(i) for i in range(start_layer, 10)]
                model.compile(feature, label, flags.n_train, flags.n_valid, flags.patch_size, ersaPath.PATH['model'],
                              par_dir=flags.par_dir, loss_type='xent', **compile_kwargs)
                train_hook = hook.ValueSummaryHook(flags.verb_step, [model.loss, model.lr_op],
                                                   value_names=['train_loss', 'learning_rate'],
                                                   print_val=[0])
                model_save_hook = hook.ModelSaveHook(model.get_epoch_step() * flags.save_epoch, model.ckdir)
                valid_loss_hook = hook.ValueSummaryHookIters(model.get_epoch_step(), [model.loss_xent, model.loss_iou],
                                                             value_names=['valid_loss', 'IoU'], log_time=True,
                                                             run_time=model.n_valid)
                image_hook = hook.ImageValidSummaryHook(model.input_size, model.get_epoch_step(), feature, label,
                                                        model.pred,
                                                        nn_utils.image_summary, img_mean=chan_mean)
                start_time = time.time()
                if not flags.from_scratch:
                    model.load(flags.model_dir)
                model.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
                            train_init=train_init_op, valid_init=valid_init_op)
                print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
Ejemplo n.º 15
0
                patch = np.expand_dims(patch, axis=0)
                fc1000 = fc2048.predict(patch).reshape((-1, )).tolist()
                writer = csv.writer(f, lineterminator='\n')
                writer.writerow(['{}'.format(x) for x in fc1000])
                f2.write('{}\n'.format(patch_name))


if __name__ == '__main__':
    # settings
    patch_size = (572, 572)
    tile_size = (5000, 5000)
    np.random.seed(1004)
    gpu = 0
    img_dir, task_dir = sis_utils.get_task_img_folder()
    use_hist = False
    cm = collectionMaker.read_collection('spca')
    cm.print_meta_data()
    chan_mean = cm.meta_data['chan_mean'][:3]

    file_list = cm.load_files(field_id=','.join(str(i) for i in range(0, 663)),
                              field_ext='RGB,GT')
    patch_list_train = patchExtractor.PatchExtractor(patch_size, tile_size, 'spca_all',
                                                     184, 184 // 2). \
        run(file_list=file_list, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    feature_dir = os.path.join(task_dir, 'spca_patches')
    dr = data_reader(patch_list_train, chan_mean)
    ersa_utils.make_dir_if_not_exist(feature_dir)
    processBlock.BasicProcess(
        'make_feature',
        feature_dir,
Ejemplo n.º 16
0
def main(flags):
    """Train a U-Net to segment a single object class ('tower_only') on the infrastructure tiles.

    The GT channel of the four-city collection is remapped by ``SingleChanSwitch``
    (value 3 -> 1, values 2 and 4-7 -> 0) into a ``GT_switch`` channel. Tile ids
    4-15 are used for training and ids 1-3 for validation. NumPy and TensorFlow
    are seeded with ``flags.run_id``. The training duration is printed in hours.

    :param flags: parsed options. Fields read here: GPU, run_id, num_classes,
                  patch_size, suffix, learn_rate, decay_step, decay_rate, epochs,
                  batch_size, data_dir, ds_name, val_mult, n_train, n_valid,
                  par_dir, pos_weight, verb_step and save_epoch.
    """
    nn_utils.set_gpu(flags.GPU)
    np.random.seed(flags.run_id)
    tf.set_random_seed(flags.run_id)

    # define network
    model = unet.UNet(flags.num_classes,
                      flags.patch_size,
                      suffix=flags.suffix,
                      learn_rate=flags.learn_rate,
                      decay_step=flags.decay_step,
                      decay_rate=flags.decay_rate,
                      epochs=flags.epochs,
                      batch_size=flags.batch_size)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(
        raw_data_path=flags.data_dir,
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id=','.join(str(i) for i in range(1, 16)),
        rgb_ext='RGB',
        gt_ext='GT',
        file_ext='tif,png',
        force_run=False,
        clc_name=flags.ds_name)
    gt_d255 = collectionEditor.SingleChanSwitch(cm.clc_dir, {
        2: 0,
        3: 1,
        4: 0,
        5: 0,
        6: 0,
        7: 0
    }, ['GT', 'GT_switch'], 'tower_only').run(
        force_run=False,
        file_ext='png',
        d_type=np.uint8,
    )
    cm.replace_channel(gt_d255.files, True, ['GT', 'GT_switch'])
    cm.print_meta_data()

    file_list_train = cm.load_files(
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id=','.join(str(i) for i in range(4, 16)),
        field_ext='RGB,GT_switch')
    file_list_valid = cm.load_files(
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id='1,2,3',
        field_ext='RGB,GT_switch')

    # Train and validation patches need distinct ds_names. PatchExtractor presumably
    # keeps its output under a directory derived from ds_name, so with force_run=False
    # a shared name would make the validation run return the cached training patches.
    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size,
                                                     ds_name=flags.ds_name + '_tower_only_train',
                                                     overlap=overlap, pad=overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size,
                                                     ds_name=flags.ds_name + '_tower_only_valid',
                                                     overlap=overlap, pad=overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    chan_mean = cm.meta_data['chan_mean']

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
            chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.compile(feature,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.patch_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.par_dir,
                  loss_type='xent',
                  pos_weight=flags.pos_weight)
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHookIters(
        model.get_epoch_step(), [model.loss_iou, model.loss_xent],
        value_names=['IoU', 'valid_loss'],
        log_time=True,
        run_time=model.n_valid)
    image_hook = hook.ImageValidSummaryHook(model.input_size,
                                            model.get_epoch_step(),
                                            feature,
                                            label,
                                            model.pred,
                                            partial(
                                                nn_utils.image_summary,
                                                label_num=flags.num_classes),
                                            img_mean=chan_mean)
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op,
                valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
Ejemplo n.º 17
0
def main(flags):
    """Train a PSPNet, initialised from ResNet weights, on the 'spca' tiles configured by *flags*.

    The collection covers Fresno, Modesto and Stockton (tile ids 0-662, jpg RGB + png GT).
    Tile ids 0-249 are used for training and ids 250-499 for validation. The training
    duration is printed in hours.

    :param flags: parsed options. Fields read here: num_classes, patch_size, tile_size,
                  model_suffix, learning_rate, decay_step, decay_rate, epochs, batch_size,
                  weight_decay, momentum, data_dir, ds_name, val_mult, res_dir, n_train,
                  n_valid, model_par_dir, verb_step and save_epoch.
    """
    # The GPU id comes from the module-level GPU constant, not from flags.
    nn_utils.set_gpu(GPU)

    # define network
    model = pspnet.PSPNet(flags.num_classes,
                          flags.patch_size,
                          suffix=flags.model_suffix,
                          learn_rate=flags.learning_rate,
                          decay_step=flags.decay_step,
                          decay_rate=flags.decay_rate,
                          epochs=flags.epochs,
                          batch_size=flags.batch_size,
                          weight_decay=flags.weight_decay,
                          momentum=flags.momentum)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                         field_name='Fresno,Modesto,Stockton',
                                         field_id=','.join(
                                             str(i) for i in range(663)),
                                         rgb_ext='RGB',
                                         gt_ext='GT',
                                         file_ext='jpg,png',
                                         force_run=False,
                                         clc_name=flags.ds_name)
    cm.print_meta_data()
    file_list_train = cm.load_files(field_id=','.join(
        str(i) for i in range(0, 250)),
                                    field_ext='RGB,GT')
    file_list_valid = cm.load_files(field_id=','.join(
        str(i) for i in range(250, 500)),
                                    field_ext='RGB,GT')
    # Only the first three channel means are used to normalise the network input.
    chan_mean = cm.meta_data['chan_mean'][:3]

    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_train',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_valid',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size, chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    # Initialise from the ResNet weights in flags.res_dir (keep_last=False).
    model.load_resnet(flags.res_dir, keep_last=False)
    model.compile(feature,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.patch_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir,
                  val_mult=flags.val_mult,
                  loss_type='xent')
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHookIters(
        model.get_epoch_step(), [model.loss_xent, model.loss_iou],
        value_names=['valid_loss', 'valid_mIoU'],
        log_time=True,
        run_time=model.n_valid)
    # Add back exactly the mean that was subtracted from the input images. The full
    # meta-data mean may hold more than three entries (it is sliced above), and then
    # it would not match the three-channel images.
    image_hook = hook.ImageValidSummaryHook(model.input_size,
                                            model.get_epoch_step(),
                                            feature,
                                            label,
                                            model.output,
                                            nn_utils.image_summary,
                                            img_mean=chan_mean)
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op,
                valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
Ejemplo n.º 18
0
from nn import unet, nn_utils
from collection import collectionMaker

# ---- settings (some of these are not used within this snippet) ----
class_num, tile_size = 2, (5000, 5000)
suffix = 'aemo_hist'
bs, gpu = 5, 0
temp_name = r'/work/bh163/misc/aemo_temp'

# Read the collection spanning the three AEMO tiles (force_run=False is passed,
# presumably so an existing collection is reused), then print its metadata.
cm = collectionMaker.read_collection(
    raw_data_path=r'/work/bh163/data/aemo_hist',
    field_name='aus10,aus30,aus50',
    field_id='',
    rgb_ext='.*rgb',
    gt_ext='.*gt',
    file_ext='tif',
    force_run=False,
    clc_name=suffix,
)
cm.print_meta_data()

# Training uses tiles aus10/aus30, validation uses aus50; both request the
# '.*rgb' and '.*gt' fields (presumably one rgb/gt pair per entry).
file_list_train, file_list_valid = (
    cm.load_files(field_name=tiles, field_id='', field_ext='.*rgb,.*gt')
    for tiles in ('aus10,aus30', 'aus50')
)
chan_mean = cm.meta_data['chan_mean']

nn_utils.tf_warn_level(3)
Ejemplo n.º 19
0
# NOTE(review): patch_size, lr, ds, dr, epochs, class_num, suffix, bs, gpu and the
# pspnet / collectionEditor / histMatching / np modules are not bound in this
# snippet; they must be defined earlier in the script before it can run.
n_valid = 500
verb_step = 50
save_epoch = 5
model_dir = r'/hdd6/Models/spca/psp101/pspnet_spca_PS(384, 384)_BS5_EP100_LR0.001_DS40_DR0.1'

nn_utils.set_gpu(gpu)

# define network
# The PSPNet is bound to ``model`` rather than ``unet``: rebinding ``unet`` at
# module level would shadow the ``unet`` module imported by this file
# (``from nn import unet``), breaking any later ``unet.UNet(...)`` call.
model = pspnet.PSPNet(class_num, patch_size, suffix=suffix, learn_rate=lr, decay_step=ds, decay_rate=dr,
                      epochs=epochs, batch_size=bs, weight_decay=1e-3)
overlap = model.get_overlap()

cm = collectionMaker.read_collection(raw_data_path=r'/home/lab/Documents/bohao/data/aemo',
                                     field_name='aus10,aus30,aus50',
                                     field_id='',
                                     rgb_ext='.*rgb',
                                     gt_ext='.*gt',
                                     file_ext='tif',
                                     force_run=False,
                                     clc_name='aemo')
# Multiply the '.*gt' channel by 1/255 and store it as a 'gt_d255' channel
# (presumably mapping {0, 255} masks to {0, 1} -- confirm). ``1.0 / 255`` keeps
# the factor a float even under Python 2, where ``1/255`` would be 0.
gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1.0 / 255, ['.*gt', 'gt_d255']).\
    run(force_run=False, file_ext='tif', d_type=np.uint8,)
# Replace the collection's 'gt' channel with the scaled 'gt_d255' files.
cm.replace_channel(gt_d255.files, True, ['gt', 'gt_d255'])
# hist matching: run HistMatching over every rgb tile, using the Fresno tile
# below as the reference image, and register the outputs as '.*rgb_hist'.
ref_file = r'/media/ei-edl01/data/uab_datasets/spca/data/Original_Tiles/Fresno1_RGB.jpg'
ga = histMatching.HistMatching(ref_file, color_space='RGB', ds_name=suffix)
file_list = [f[0] for f in cm.meta_data['rgb_files']]
hist_match = ga.run(force_run=False, file_list=file_list)
cm.add_channel(hist_match.get_files(), '.*rgb_hist')
cm.print_meta_data()

file_list_train = cm.load_files(field_name='aus10,aus30', field_id='', field_ext='.*rgb_hist,.*gt_d255')
Ejemplo n.º 20
0
def _split_patch_list(lines, train_tile_names, valid_tile_names):
    """Partition patch-list lines into training and validation records.

    Each line is split on single spaces. The tile name is the part of the first
    field's basename before the first underscore.

    :param lines: iterable of text lines read from ``file_list.txt``
    :param train_tile_names: tile names whose patches go to the training list
    :param valid_tile_names: tile names whose patches go to the validation list
    :return: ``(patch_list_train, patch_list_valid)``; each is a new list holding
        the space-split fields of the matching lines
    :raises ValueError: if a non-blank line belongs to neither tile set
    """
    patch_list_train = []
    patch_list_valid = []
    for line in lines:
        # Skip blank lines (e.g. a trailing newline); they would otherwise give
        # an empty tile name and abort the whole run.
        if not line.strip():
            continue
        tile_name = os.path.basename(line.split(' ')[0]).split('_')[0].strip()
        if tile_name in train_tile_names:
            patch_list_train.append(line.strip().split(' '))
        elif tile_name in valid_tile_names:
            patch_list_valid.append(line.strip().split(' '))
        else:
            raise ValueError(
                'Tile "{}" (line {!r}) is in neither the training tiles {} '
                'nor the validation tiles {}'.format(
                    tile_name, line, train_tile_names, valid_tile_names))
    return patch_list_train, patch_list_valid


def main(flags):
    """Train one UNet per (start_layer, learning rate) configuration.

    For every ``start_layer`` in ``flags.start_layer`` and every learning rate
    in ``flags.learn_rate``, the default graph is reset and a fresh
    ``unet.UNet`` is trained. Training uses patches of tiles ``aus10``/``aus30``
    and validation uses ``aus50``, all listed in
    ``flags.data_dir/file_list.txt``. When ``start_layer < 10``,
    ``train_var_filter`` is set to ``layerup{start_layer}`` .. ``layerup9``;
    otherwise no filter is passed to ``model.compile``. Unless
    ``flags.from_scratch`` is set, weights are loaded from ``flags.model_dir``
    before training. The duration of each run is printed in hours.

    :param flags: parsed options; reads GPU, start_layer, learn_rate,
        from_scratch, num_classes, patch_size, decay_step, decay_rate, epochs,
        batch_size, data_dir, val_mult, n_train, n_valid, par_dir, verb_step,
        save_epoch and model_dir.
    :raises ValueError: if a line of ``file_list.txt`` names a tile that is in
        neither the training nor the validation set.
    """
    nn_utils.set_gpu(flags.GPU)

    train_tile_names = ['aus10', 'aus30']
    valid_tile_names = ['aus50']
    # Neither input depends on start layer, learning rate or run id, so read
    # them once rather than on every iteration.
    lines = list(ersa_utils.load_file(os.path.join(flags.data_dir, 'file_list.txt')))
    chan_mean = collectionMaker.read_collection('aemo_align').meta_data['chan_mean']

    for start_layer in flags.start_layer:
        if start_layer >= 10:
            suffix_base = 'aemo'
        else:
            suffix_base = 'aemo_up{}'.format(start_layer)
        if flags.from_scratch:
            suffix_base += '_scratch'
        for lr in flags.learn_rate:
            for run_id in range(1):
                suffix = '{}_{}'.format(suffix_base, run_id)
                tf.reset_default_graph()

                # Seed numpy and TensorFlow with the run id.
                np.random.seed(run_id)
                tf.set_random_seed(run_id)

                # define network
                model = unet.UNet(flags.num_classes,
                                  flags.patch_size,
                                  suffix=suffix,
                                  learn_rate=lr,
                                  decay_step=flags.decay_step,
                                  decay_rate=flags.decay_rate,
                                  epochs=flags.epochs,
                                  batch_size=flags.batch_size)

                # Split again on every iteration so each data reader receives
                # its own fresh lists, as when the file was re-read each time.
                patch_list_train, patch_list_valid = _split_patch_list(
                    lines, train_tile_names, valid_tile_names)

                train_init_op, valid_init_op, reader_op = \
                    dataReaderSegmentation.DataReaderSegmentationTrainValid(
                        flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
                        chan_mean=chan_mean,
                        aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
                        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
                feature, label = reader_op

                model.create_graph(feature)
                # Pass train_var_filter only in the partial case, so the
                # start_layer >= 10 case still relies on compile's own default.
                compile_kwargs = {}
                if start_layer < 10:
                    compile_kwargs['train_var_filter'] = [
                        'layerup{}'.format(i) for i in range(start_layer, 10)
                    ]
                model.compile(feature,
                              label,
                              flags.n_train,
                              flags.n_valid,
                              flags.patch_size,
                              ersaPath.PATH['model'],
                              par_dir=flags.par_dir,
                              loss_type='xent',
                              **compile_kwargs)
                train_hook = hook.ValueSummaryHook(
                    flags.verb_step, [model.loss, model.lr_op],
                    value_names=['train_loss', 'learning_rate'],
                    print_val=[0])
                model_save_hook = hook.ModelSaveHook(
                    model.get_epoch_step() * flags.save_epoch, model.ckdir)
                valid_loss_hook = hook.ValueSummaryHookIters(
                    model.get_epoch_step(), [model.loss_xent, model.loss_iou],
                    value_names=['valid_loss', 'IoU'],
                    log_time=True,
                    run_time=model.n_valid)
                image_hook = hook.ImageValidSummaryHook(model.input_size,
                                                        model.get_epoch_step(),
                                                        feature,
                                                        label,
                                                        model.pred,
                                                        nn_utils.image_summary,
                                                        img_mean=chan_mean)
                start_time = time.time()
                if not flags.from_scratch:
                    model.load(flags.model_dir)
                model.train(train_hooks=[train_hook, model_save_hook],
                            valid_hooks=[valid_loss_hook, image_hook],
                            train_init=train_init_op,
                            valid_init=valid_init_op)
                print('Duration: {:.3f}'.format(
                    (time.time() - start_time) / 3600))