Exemplo n.º 1
0
def main(flags):
    """Build a UNet, prepare the tile collection and train with validation.

    The steps run in a fixed order: hand ``GPU`` to ``nn_utils.set_gpu``,
    construct the network, read the image collection and derive the
    ``gt_d255`` ground-truth channel, extract train/valid patches, create the
    data reader, build and compile the graph, attach the summary and
    checkpoint hooks, and finally call ``model.train``.

    Args:
        flags: Configuration object. The attributes read here are
            ``num_classes``, ``patch_size``, ``tile_size``, ``model_suffix``,
            ``learning_rate``, ``decay_step``, ``decay_rate``, ``epochs``,
            ``batch_size``, ``data_dir``, ``ds_name``, ``val_mult``,
            ``n_train``, ``n_valid``, ``model_par_dir``, ``verb_step`` and
            ``save_epoch``.

    Returns:
        None. The elapsed wall-clock time divided by 3600 is printed at the end.
    """
    nn_utils.set_gpu(GPU)

    # ---- network -----------------------------------------------------------
    model = unet.UNet(
        flags.num_classes,
        flags.patch_size,
        suffix=flags.model_suffix,
        learn_rate=flags.learning_rate,
        decay_step=flags.decay_step,
        decay_rate=flags.decay_rate,
        epochs=flags.epochs,
        batch_size=flags.batch_size,
    )
    overlap = model.get_overlap()

    # ---- collection --------------------------------------------------------
    # Comma-separated tile ids: 0..36 for the whole collection, 0..5 for
    # validation and 6..36 for training.
    all_tile_ids = ','.join(map(str, range(37)))
    valid_tile_ids = ','.join(map(str, range(6)))
    train_tile_ids = ','.join(map(str, range(6, 37)))

    collection = collectionMaker.read_collection(
        raw_data_path=flags.data_dir,
        field_name='austin,chicago,kitsap,tyrol-w,vienna',
        field_id=all_tile_ids,
        rgb_ext='RGB',
        gt_ext='GT',
        file_ext='tif',
        force_run=False,
        clc_name=flags.ds_name,
    )

    # Derive the 'gt_d255' channel from 'GT' with a 1/255 multiplier
    # (presumably mapping 0/255 masks to 0/1 labels -- TODO confirm).
    gt_scaler = collectionEditor.SingleChanMult(collection.clc_dir, 1 / 255, ['GT', 'gt_d255'])
    gt_scaled = gt_scaler.run(force_run=False, file_ext='png', d_type=np.uint8)
    collection.replace_channel(gt_scaled.files, True, ['GT', 'gt_d255'])
    collection.print_meta_data()

    train_files = collection.load_files(field_id=train_tile_ids, field_ext='RGB,gt_d255')
    valid_files = collection.load_files(field_id=valid_tile_ids, field_ext='RGB,gt_d255')
    rgb_mean = collection.meta_data['chan_mean'][:3]

    # ---- patches -----------------------------------------------------------
    def extract_patches(split_suffix, files):
        # Same extractor settings for both splits; only the dataset name
        # suffix and the input file list differ.
        extractor = patchExtractor.PatchExtractor(
            flags.patch_size, flags.tile_size, flags.ds_name + split_suffix, overlap, overlap // 2)
        return extractor.run(file_list=files, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    train_patches = extract_patches('_train', train_files)
    valid_patches = extract_patches('_valid', valid_files)

    # ---- data reader -------------------------------------------------------
    reader = dataReaderSegmentation.DataReaderSegmentationTrainValid(
        flags.patch_size,
        train_patches,
        valid_patches,
        batch_size=flags.batch_size,
        chan_mean=rgb_mean,
        aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
        random=True,
        has_gt=True,
        gt_dim=1,
        include_gt=True,
        valid_mult=flags.val_mult,
    )
    train_init_op, valid_init_op, (feature, label) = reader.read_op()

    # ---- graph and hooks ---------------------------------------------------
    model.create_graph(feature)
    model.compile(
        feature,
        label,
        flags.n_train,
        flags.n_valid,
        flags.patch_size,
        ersaPath.PATH['model'],
        par_dir=flags.model_par_dir,
        val_mult=flags.val_mult,
        loss_type='xent',
    )

    train_hook = hook.ValueSummaryHook(
        flags.verb_step,
        [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0],
    )
    save_interval = model.get_epoch_step() * flags.save_epoch
    model_save_hook = hook.ModelSaveHook(save_interval, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHook(
        model.get_epoch_step(),
        [model.loss, model.loss_iou],
        value_names=['valid_loss', 'valid_mIoU'],
        log_time=True,
        run_time=model.n_valid,
        iou_pos=1,
    )
    # NOTE(review): this hook receives the full 'chan_mean' list, whereas the
    # reader above was given only its first three entries -- confirm intended.
    image_hook = hook.ImageValidSummaryHook(
        model.input_size,
        model.get_epoch_step(),
        feature,
        label,
        model.output,
        nn_utils.image_summary,
        img_mean=collection.meta_data['chan_mean'],
    )

    # ---- training ----------------------------------------------------------
    started_at = time.time()
    model.train(
        train_hooks=[train_hook, model_save_hook],
        valid_hooks=[valid_loss_hook, image_hook],
        train_init=train_init_op,
        valid_init=valid_init_op,
    )
    elapsed_hours = (time.time() - started_at) / 3600
    print('Duration: {:.3f}'.format(elapsed_hours))
Exemplo n.º 2
0
# Hand the requested device to the project's GPU helper before anything is built.
nn_utils.set_gpu(gpu)

# Define the network.
# NOTE(review): the variable is named `unet`, but it holds a pspnet.PSPNet instance.
unet = pspnet.PSPNet(class_num, patch_size, suffix=suffix, learn_rate=lr, decay_step=ds, decay_rate=dr,
                     epochs=epochs, batch_size=bs, weight_decay=1e-3)
# Overlap value reported by the model; it is forwarded to the patch extractor below.
overlap = unet.get_overlap()

# Read the 'aemo' collection (fields aus10, aus30, aus50) from a hard-coded local path.
# force_run=False presumably reuses previously generated results -- confirm in collectionMaker.
cm = collectionMaker.read_collection(raw_data_path=r'/home/lab/Documents/bohao/data/aemo',
                                     field_name='aus10,aus30,aus50',
                                     field_id='',
                                     rgb_ext='.*rgb',
                                     gt_ext='.*gt',
                                     file_ext='tif',
                                     force_run=False,
                                     clc_name='aemo')
# Derive a 'gt_d255' channel from the ground truth with a 1/255 multiplier, requested as
# uint8 'tif' files (presumably turning 0/255 masks into 0/1 labels -- TODO confirm).
gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1/255, ['.*gt', 'gt_d255']).\
    run(force_run=False, file_ext='tif', d_type=np.uint8,)
# Pass the rescaled files to the collection through replace_channel.
# NOTE(review): the channel is addressed as '.*gt' in the call above but as 'gt' here --
# confirm that both spellings refer to the same channel.
cm.replace_channel(gt_d255.files, True, ['gt', 'gt_d255'])
# Histogram matching against a single reference image.
# NOTE(review): ref_file is a hard-coded absolute path on a different mount than the
# collection data; it must exist on the machine that runs this script.
ref_file = r'/media/ei-edl01/data/uab_datasets/spca/data/Original_Tiles/Fresno1_RGB.jpg'
ga = histMatching.HistMatching(ref_file, color_space='RGB', ds_name=suffix)
# Take the first element of every entry in meta_data['rgb_files']
# (presumably the file path -- confirm against collectionMaker).
file_list = [f[0] for f in cm.meta_data['rgb_files']]
hist_match = ga.run(force_run=False, file_list=file_list)
# Add the files returned by the matching step to the collection under the name '.*rgb_hist'.
cm.add_channel(hist_match.get_files(), '.*rgb_hist')
cm.print_meta_data()

# Split by field name: aus10 and aus30 for training, aus50 for validation.
# Both splits load the '.*rgb_hist' and '.*gt_d255' channels.
file_list_train = cm.load_files(field_name='aus10,aus30', field_id='', field_ext='.*rgb_hist,.*gt_d255')
file_list_valid = cm.load_files(field_name='aus50', field_id='', field_ext='.*rgb_hist,.*gt_d255')
# Keep only the last three channel means; presumably these belong to the
# channel added above -- TODO confirm against the printed meta data.
chan_mean = cm.meta_data['chan_mean'][-3:]

# Extract the training patches; `overlap` and `overlap//2` are passed as the
# fourth and fifth positional arguments (their meaning is defined in patchExtractor).
patch_list_train = patchExtractor.PatchExtractor(patch_size, tile_size, ds_name+'_train', overlap, overlap//2).\
    run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()