Ejemplo n.º 1
0
def main(flags):
    """Train a DeepLab segmentation model on a tiled aerial-imagery collection.

    The collection is read from ``flags.data_dir`` (fields austin, chicago,
    kitsap, tyrol-w, vienna; tile ids 0-36). The GT channel is rescaled by
    1/255 into a new ``gt_d255`` channel. Tiles 6-36 are used for training
    and tiles 0-5 for validation. Both sets are cut into overlapping patches
    and fed to the model's training loop together with summary, checkpoint
    and image hooks.

    Args:
        flags: parsed option namespace. Attributes read here: num_classes,
            patch_size, tile_size, model_suffix, learning_rate, decay_step,
            decay_rate, epochs, batch_size, data_dir, ds_name, val_mult,
            res_dir, n_train, n_valid, model_par_dir, verb_step, save_epoch.
    """

    def _id_range(start, stop):
        # Comma-separated ids in [start, stop), the string form the collection API takes.
        return ','.join(str(i) for i in range(start, stop))

    def _extract_patches(file_list, name_suffix):
        # Cut the tiles in file_list into patches and return the patch file list.
        return patchExtractor.PatchExtractor(
            flags.patch_size, flags.tile_size, flags.ds_name + name_suffix,
            overlap, overlap // 2,
        ).run(file_list=file_list, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    nn_utils.set_gpu(GPU)

    # define network
    model = deeplab.DeepLab(flags.num_classes,
                            flags.patch_size,
                            suffix=flags.model_suffix,
                            learn_rate=flags.learning_rate,
                            decay_step=flags.decay_step,
                            decay_rate=flags.decay_rate,
                            epochs=flags.epochs,
                            batch_size=flags.batch_size)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(
        raw_data_path=flags.data_dir,
        field_name='austin,chicago,kitsap,tyrol-w,vienna',
        field_id=_id_range(0, 37),
        rgb_ext='RGB',
        gt_ext='GT',
        file_ext='tif',
        force_run=False,
        clc_name=flags.ds_name)
    # Multiply the GT channel by 1/255, save it as uint8 png, and swap it into
    # the collection under the name 'gt_d255'.
    gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1 / 255, ['GT', 'gt_d255']).run(
        force_run=False, file_ext='png', d_type=np.uint8)
    cm.replace_channel(gt_d255.files, True, ['GT', 'gt_d255'])
    cm.print_meta_data()

    # Tiles 6-36 train the model; tiles 0-5 are held out for validation.
    file_list_train = cm.load_files(field_id=_id_range(6, 37), field_ext='RGB,gt_d255')
    file_list_valid = cm.load_files(field_id=_id_range(0, 6), field_ext='RGB,gt_d255')
    # First three channel means (presumably the RGB bands); the reader subtracts
    # these from the input, and the image hook below must add the same ones back.
    chan_mean = cm.meta_data['chan_mean'][:3]

    patch_list_train = _extract_patches(file_list_train, '_train')
    patch_list_valid = _extract_patches(file_list_valid, '_valid')

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size, chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.load_resnet(flags.res_dir)
    model.compile(feature,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.patch_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir,
                  val_mult=flags.val_mult,
                  loss_type='xent')
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHook(
        model.get_epoch_step(), [model.loss, model.loss_iou],
        value_names=['valid_loss', 'valid_mIoU'],
        log_time=True,
        run_time=model.n_valid,
        iou_pos=1)
    # Use the same (sliced) mean the reader subtracted, so the de-normalised
    # summary image matches the network input channel for channel.
    image_hook = hook.ImageValidSummaryHook(model.input_size,
                                            model.get_epoch_step(),
                                            feature,
                                            label,
                                            model.output,
                                            nn_utils.image_summary,
                                            img_mean=chan_mean)
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op,
                valid_init=valid_init_op)
    # Elapsed wall-clock time, reported in hours.
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
Ejemplo n.º 2
0
def main(flags):
    """Train DeepLab on Cityscapes-style data using its 'train' and 'val' splits.

    Both splits are read through ``cityscapes_reader.CollectionMakerCityscapes``.
    Every image is resized to ``flags.tile_size`` by the reader's global
    function. The training split's channel mean normalises both splits. The
    model is then trained with summary, checkpoint and image hooks.

    Args:
        flags: parsed option namespace (num_classes, tile_size, model_suffix,
            learning_rate, decay_step, decay_rate, epochs, batch_size,
            data_dir, rgb_type, gt_type, rgb_ext, gt_ext, ds_name, force_run,
            val_mult, res_dir, n_train, n_valid, model_par_dir, verb_step,
            save_epoch).
    """
    nn_utils.set_gpu(GPU)

    # define network
    model = deeplab.DeepLab(
        flags.num_classes,
        flags.tile_size,
        suffix=flags.model_suffix,
        learn_rate=flags.learning_rate,
        decay_step=flags.decay_step,
        decay_rate=flags.decay_rate,
        epochs=flags.epochs,
        batch_size=flags.batch_size,
    )

    def make_collection(split, collection_name):
        # The two splits differ only in split name and collection name.
        return cityscapes_reader.CollectionMakerCityscapes(
            flags.data_dir, flags.rgb_type, flags.gt_type, split,
            flags.rgb_ext, flags.gt_ext, ['png', 'png'],
            clc_name=collection_name,
            force_run=flags.force_run,
        )

    cm_train = make_collection('train', '{}_train'.format(flags.ds_name))
    cm_valid = make_collection('val', '{}_valid'.format(flags.ds_name))
    for collection in (cm_train, cm_valid):
        collection.print_meta_data()

    def resize_func(img):
        # Bring every image to the tile size the network was built for.
        return resize_image(img, flags.tile_size)

    reader = dataReaderSegmentation.DataReaderSegmentationTrainValid(
        flags.tile_size,
        cm_train.meta_data['file_list'],
        cm_valid.meta_data['file_list'],
        flags.batch_size,
        cm_train.meta_data['chan_mean'],
        aug_func=[reader_utils.image_flipping_hori, reader_utils.image_scaling_with_label],
        random=True,
        has_gt=True,
        gt_dim=1,
        include_gt=True,
        valid_mult=flags.val_mult,
        global_func=resize_func,
    )
    train_init_op, valid_init_op, (feature, label) = reader.read_op()

    model.create_graph(feature)
    model.load_resnet(flags.res_dir)
    model.compile(
        feature, label, flags.n_train, flags.n_valid, flags.tile_size,
        ersaPath.PATH['model'],
        par_dir=flags.model_par_dir,
        val_mult=flags.val_mult,
        loss_type='xent',
    )

    train_hook = hook.ValueSummaryHook(
        flags.verb_step,
        [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0],
    )
    model_save_hook = hook.ModelSaveHook(model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHook(
        model.get_epoch_step(),
        [model.loss, model.loss_iou],
        value_names=['valid_loss', 'valid_mIoU'],
        log_time=True,
        run_time=model.n_valid,
        iou_pos=1,
    )
    image_hook = hook.ImageValidSummaryHook(
        model.input_size, model.get_epoch_step(), feature, label, model.pred,
        cityscapes_labels.image_summary,
        img_mean=cm_train.meta_data['chan_mean'],
    )
    train_hooks = [train_hook, model_save_hook]
    valid_hooks = [valid_loss_hook, image_hook]

    start_time = time.time()
    model.train(train_hooks=train_hooks, valid_hooks=valid_hooks,
                train_init=train_init_op, valid_init=valid_init_op)
    elapsed_hours = (time.time() - start_time) / 3600
    print('Duration: {:.3f}'.format(elapsed_hours))
Ejemplo n.º 3
0
# settings
class_num = 2            # number of output classes passed to the network
tile_size = (5000, 5000)
suffix = 'aemo_hist'     # model suffix; also used as the collection name below
bs = 5                   # batch size
gpu = 0
model_name = 'unet'      # 'unet' selects UNet, anything else selects DeepLab

# define network
# NOTE(review): `unet = unet.UNet(...)` rebinds the name `unet` (presumably the
# imported module) to the model instance, so the module is unreachable after
# this point. The code below uses `unet` as the model, so the name is kept.
if model_name != 'unet':
    patch_size = (321, 321)
    unet = deeplab.DeepLab(class_num, patch_size, suffix=suffix, batch_size=bs)
else:
    patch_size = (572, 572)
    unet = unet.UNet(class_num, patch_size, suffix=suffix, batch_size=bs)
overlap = unet.get_overlap()

cm = collectionMaker.read_collection(
    raw_data_path=r'/home/lab/Documents/bohao/data/aemo/aemo_hist',
    field_name='aus10,aus30,aus50',
    field_id='',
    rgb_ext='.*rgb',
    gt_ext='.*gt',
    file_ext='tif',
    force_run=False,
    clc_name=suffix,
)
cm.print_meta_data()
file_list_train = cm.load_files(field_name='aus10,aus30',
                                field_id='',
Ejemplo n.º 4
0
def main(flags):
    """Evaluate a trained DeepLab model on the Cityscapes-style 'val' split.

    The network is built at ``flags.tile_size``. The 'train' collection is
    loaded only for its channel mean, which normalises the validation images.
    Results are scored and saved by
    ``nn_processor.NNEstimatorSegmentScene``.

    Args:
        flags: parsed option namespace (num_classes, tile_size, batch_size,
            data_dir, rgb_type, gt_type, rgb_ext, gt_ext, ds_name, res_dir,
            GPU, force_run).
    """
    nn_utils.set_gpu(GPU)

    # define network
    model = deeplab.DeepLab(flags.num_classes, flags.tile_size, batch_size=flags.batch_size)

    def make_collection(split, collection_name):
        # Both collections share every argument except split and name.
        return cityscapes_reader.CollectionMakerCityscapes(
            flags.data_dir, flags.rgb_type, flags.gt_type, split,
            flags.rgb_ext, flags.gt_ext, ['png', 'png'],
            clc_name=collection_name,
            force_run=False,
        )

    cm_train = make_collection('train', '{}_train'.format(flags.ds_name))
    cm_test = make_collection('val', '{}_valid'.format(flags.ds_name))
    cm_test.print_meta_data()

    def resize_func_train(img):
        # Resize reader inputs to the tile size the network was built with.
        # NOTE(review): without preserve_range, skimage converts integer
        # images to floats in [0, 1]; confirm this matches the training-time
        # preprocessing before touching it.
        return skimage.transform.resize(img, flags.tile_size, mode='reflect')

    def resize_func_test(img):
        # Passed to the estimator as post_func. Resizes to the validation
        # tile_dim with order=0 (nearest neighbour) and the value range kept.
        return skimage.transform.resize(
            img,
            cm_test.meta_data['tile_dim'],
            order=0,
            preserve_range=True,
            mode='reflect',
        )

    reader = dataReaderSegmentation.DataReaderSegmentation(
        flags.tile_size,
        cm_test.meta_data['file_list'],
        batch_size=flags.batch_size,
        random=False,
        chan_mean=cm_train.meta_data['chan_mean'],
        is_train=False,
        has_gt=True,
        gt_dim=1,
        include_gt=True,
        global_func=resize_func_train,
    )
    init_op, reader_op = reader.read_op()

    # NOTE(review): the device is chosen from the global GPU above, while the
    # estimator receives flags.GPU; confirm the two always agree.
    estimator = nn_processor.NNEstimatorSegmentScene(
        model, cm_test.meta_data['file_list'], flags.res_dir, init_op, reader_op,
        ds_name='city_scapes',
        save_result_parent_dir='Cityscapes',
        gpu=flags.GPU,
        score_result=True,
        split_char='.',
        post_func=resize_func_test,
        save_func=make_general_id_map,
        ignore_label=(-1, 255),
    )
    estimator.run(force_run=flags.force_run)
Ejemplo n.º 5
0
# Step/epoch intervals and model directory (presumably consumed further down).
verb_step = 50
save_epoch = 50
model_dir = r'/hdd6/Models/DeepLab_rand_grid/DeeplabV3_spca_aug_grid_0_PS(321, 321)_BS5_EP60_LR0.0001_DS40_DR0.1_SFN32'

# Select the device and fix the NumPy and TensorFlow seeds.
nn_utils.set_gpu(gpu)
np.random.seed(1004)
tf.set_random_seed(1004)

# Tag the model suffix when histogram-matched data is used.
suffix = suffix + '_hist' if use_hist else suffix

# define network
unet = deeplab.DeepLab(
    class_num, patch_size,
    batch_size=bs, epochs=epochs,
    learn_rate=lr, decay_step=ds, decay_rate=dr,
    suffix=suffix,
)
overlap = unet.get_overlap()

cm = collectionMaker.read_collection(
    clc_name=ds_name,
    raw_data_path=r'/home/lab/Documents/bohao/data/aemo/aemo_pad',
    field_name='aus10,aus30,aus50',
    field_id='',
    rgb_ext='.*rgb',
    gt_ext='.*gt',
    file_ext='tif',
    force_run=False,
)
gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1/255, ['.*gt', 'gt_d255']).\