Example #1: UNet training with patch extraction
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = unet.UNet(flags.num_classes, flags.patch_size, suffix=flags.model_suffix, learn_rate=flags.learning_rate,
                          decay_step=flags.decay_step, decay_rate=flags.decay_rate, epochs=flags.epochs,
                          batch_size=flags.batch_size)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                         field_name='austin,chicago,kitsap,tyrol-w,vienna',
                                         field_id=','.join(str(i) for i in range(37)),
                                         rgb_ext='RGB',
                                         gt_ext='GT',
                                         file_ext='tif',
                                         force_run=False,
                                         clc_name=flags.ds_name)
    # rescale the ground truth from {0, 255} to {0, 1}
    gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1 / 255, ['GT', 'gt_d255']). \
        run(force_run=False, file_ext='png', d_type=np.uint8)
    cm.replace_channel(gt_d255.files, True, ['GT', 'gt_d255'])
    cm.print_meta_data()
    # split by tile id: tiles 6-36 for training, tiles 0-5 for validation
    file_list_train = cm.load_files(field_id=','.join(str(i) for i in range(6, 37)), field_ext='RGB,gt_d255')
    file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(6)), field_ext='RGB,gt_d255')
    chan_mean = cm.meta_data['chan_mean'][:3]

    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_train',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_valid',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size, chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.compile(feature, label, flags.n_train, flags.n_valid, flags.patch_size, ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir, val_mult=flags.val_mult, loss_type='xent')
    train_hook = hook.ValueSummaryHook(flags.verb_step, [model.loss, model.lr_op],
                                       value_names=['train_loss', 'learning_rate'], print_val=[0])
    model_save_hook = hook.ModelSaveHook(model.get_epoch_step()*flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHook(model.get_epoch_step(), [model.loss, model.loss_iou],
                                            value_names=['valid_loss', 'valid_mIoU'], log_time=True,
                                            run_time=model.n_valid, iou_pos=1)
    image_hook = hook.ImageValidSummaryHook(model.input_size, model.get_epoch_step(), feature, label, model.output,
                                            nn_utils.image_summary, img_mean=cm.meta_data['chan_mean'])
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op, valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time)/3600))
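Each `main(flags)` in these examples expects a fully populated flags object. As a point of reference, here is a minimal argparse sketch that would satisfy Example #1; the flag names mirror the attributes the snippet reads, but every default value below is a placeholder assumption, not a setting taken from the original project.

import argparse

def get_flags():
    # Names follow the attributes read by main(); all defaults are placeholders.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--patch_size', type=int, nargs=2, default=[572, 572])
    parser.add_argument('--tile_size', type=int, nargs=2, default=[5000, 5000])
    parser.add_argument('--batch_size', type=int, default=5)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--decay_step', type=int, default=60)
    parser.add_argument('--decay_rate', type=float, default=0.1)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--n_train', type=int, default=8000)
    parser.add_argument('--n_valid', type=int, default=1000)
    parser.add_argument('--val_mult', type=int, default=5)
    parser.add_argument('--verb_step', type=int, default=50)
    parser.add_argument('--save_epoch', type=int, default=5)
    parser.add_argument('--data_dir', type=str, default='')
    parser.add_argument('--ds_name', type=str, default='dataset')
    parser.add_argument('--model_suffix', type=str, default='')
    parser.add_argument('--model_par_dir', type=str, default='')
    return parser.parse_args()

if __name__ == '__main__':
    main(get_flags())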
Example #2: PSPNet training on Cityscapes
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = pspnet.PSPNet(flags.num_classes, flags.tile_size, suffix=flags.model_suffix, learn_rate=flags.learning_rate,
                          decay_step=flags.decay_step, decay_rate=flags.decay_rate, epochs=flags.epochs,
                          batch_size=flags.batch_size, weight_decay=flags.weight_decay, momentum=flags.momentum)

    cm_train = cityscapes_reader.CollectionMakerCityscapes(flags.data_dir, flags.rgb_type, flags.gt_type, 'train',
                                                           flags.rgb_ext, flags.gt_ext, ['png', 'png'],
                                                           clc_name='{}_train'.format(flags.ds_name),
                                                           force_run=flags.force_run)
    cm_valid = cityscapes_reader.CollectionMakerCityscapes(flags.data_dir, flags.rgb_type, flags.gt_type, 'val',
                                                           flags.rgb_ext, flags.gt_ext, ['png', 'png'],
                                                           clc_name='{}_valid'.format(flags.ds_name),
                                                           force_run=flags.force_run)
    cm_train.print_meta_data()
    cm_valid.print_meta_data()

    resize_func = lambda img: resize_image(img, flags.tile_size)
    train_init_op, valid_init_op, reader_op = dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.tile_size, cm_train.meta_data['file_list'], cm_valid.meta_data['file_list'],
            flags.batch_size, cm_train.meta_data['chan_mean'], aug_func=[reader_utils.image_flipping_hori,
                                                                         reader_utils.image_scaling_with_label],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult, global_func=resize_func)\
        .read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.load_resnet(flags.res_dir)  # initialize the backbone from pre-trained ResNet weights
    model.compile(feature, label, flags.n_train, flags.n_valid, flags.tile_size, ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir, val_mult=flags.val_mult, loss_type='xent')
    train_hook = hook.ValueSummaryHook(flags.verb_step, [model.loss, model.lr_op],
                                       value_names=['train_loss', 'learning_rate'], print_val=[0])
    model_save_hook = hook.ModelSaveHook(model.get_epoch_step()*flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHook(model.get_epoch_step(), [model.loss, model.loss_iou],
                                            value_names=['valid_loss', 'valid_mIoU'], log_time=True,
                                            run_time=model.n_valid, iou_pos=1)
    image_hook = hook.ImageValidSummaryHook(model.input_size, model.get_epoch_step(), feature, label, model.output,
                                            cityscapes_labels.image_summary, img_mean=cm_train.meta_data['chan_mean'])
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op, valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time)/3600))
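`resize_image` is referenced but not defined in this snippet. A minimal sketch of what such a global resize function could look like is below; the signature and the nearest-neighbour choice for label maps are assumptions, not the project's actual implementation.

import tensorflow as tf

def resize_image(img, tile_size, is_label=False):
    # Nearest neighbour keeps label maps integer-valued;
    # bilinear is the usual choice for RGB inputs.
    method = (tf.image.ResizeMethod.NEAREST_NEIGHBOR if is_label
              else tf.image.ResizeMethod.BILINEAR)
    return tf.image.resize_images(img, tile_size, method=method)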
Example #3: fine-tuning a UNet from a checkpoint (fragment)
# Fragment: cm, patch_size, tile_size, ds_name, bs, valid_mult, n_train, n_valid,
# verb_step, save_epoch, model_dir and a UNet instance `unet` are defined above;
# file_list_train is loaded the same way as file_list_valid below.
file_list_valid = cm.load_files(field_name='aus50', field_id='', field_ext='.*rgb_hist,.*gt_d255')
chan_mean = cm.meta_data['chan_mean'][-3:]

patch_list_train = patchExtractor.PatchExtractor(patch_size, tile_size, ds_name+'_train', overlap, overlap//2).\
    run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
patch_list_valid = patchExtractor.PatchExtractor(patch_size, tile_size, ds_name+'_valid', overlap, overlap//2).\
    run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()

train_init_op, valid_init_op, reader_op = \
    dataReaderSegmentation.DataReaderSegmentationTrainValid(
        patch_size, patch_list_train, patch_list_valid, batch_size=bs, chan_mean=chan_mean,
        aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=valid_mult).read_op()
feature, label = reader_op

unet.create_graph(feature)
unet.compile(feature, label, n_train, n_valid, patch_size, ersaPath.PATH['model'], par_dir=ds_name, loss_type='xent')
train_hook = hook.ValueSummaryHook(verb_step, [unet.loss, unet.lr_op], value_names=['train_loss', 'learning_rate'],
                                   print_val=[0])
model_save_hook = hook.ModelSaveHook(unet.get_epoch_step()*save_epoch, unet.ckdir)
valid_loss_hook = hook.ValueSummaryHook(unet.get_epoch_step(), [unet.loss, unet.loss_iou],
                                        value_names=['valid_loss', 'IoU'], log_time=True, run_time=unet.n_valid,
                                        iou_pos=1)
image_hook = hook.ImageValidSummaryHook(unet.input_size, unet.get_epoch_step(), feature, label, unet.pred,
                                        nn_utils.image_summary, img_mean=chan_mean)
start_time = time.time()
unet.load(model_dir)
unet.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
           train_init=train_init_op, valid_init=valid_init_op)
print('Duration: {:.3f}'.format((time.time() - start_time)/3600))
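The `.*rgb_hist` channel name suggests these RGB tiles were histogram-matched to a reference tile before training. That reading is a guess; for illustration only, a plain per-channel histogram-matching sketch in NumPy:

import numpy as np

def match_histograms(source, reference):
    # Map each channel of `source` (HxWxC uint8) so its intensity
    # distribution follows the corresponding channel of `reference`.
    matched = np.empty_like(source)
    for c in range(source.shape[-1]):
        s_vals, s_counts = np.unique(source[..., c], return_counts=True)
        r_vals, r_counts = np.unique(reference[..., c], return_counts=True)
        s_cdf = np.cumsum(s_counts) / float(source[..., c].size)
        r_cdf = np.cumsum(r_counts) / float(reference[..., c].size)
        mapped = np.interp(s_cdf, r_cdf, r_vals)
        matched[..., c] = mapped[np.searchsorted(s_vals, source[..., c])].astype(source.dtype)
    return matched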
Example #4: compile/hook setup with train_var_filter (fragment)
# Fragment: `model`, `feature`, `label` and the plain-name hyperparameters
# (n_train, n_valid, patch_size, par_dir, start_layer, verb_step, save_epoch,
# chan_mean, from_scratch, model_dir) are defined earlier in the original script.
model.compile(feature,
              label,
              n_train,
              n_valid,
              patch_size,
              ersaPath.PATH['model'],
              par_dir=par_dir,
              loss_type='xent',
              train_var_filter=[
                  'layerup{}'.format(i)
                  for i in range(start_layer, 10)
              ])
train_hook = hook.ValueSummaryHook(
    verb_step, [model.loss, model.lr_op],
    value_names=['train_loss', 'learning_rate'],
    print_val=[0])
model_save_hook = hook.ModelSaveHook(
    model.get_epoch_step() * save_epoch, model.ckdir)
valid_loss_hook = hook.ValueSummaryHookIters(
    model.get_epoch_step(), [model.loss_xent, model.loss_iou],
    value_names=['valid_loss', 'IoU'],
    log_time=True,
    run_time=model.n_valid)
image_hook = hook.ImageValidSummaryHook(model.input_size,
                                        model.get_epoch_step(),
                                        feature,
                                        label,
                                        model.pred,
                                        nn_utils.image_summary,
                                        img_mean=chan_mean)
start_time = time.time()
if not from_scratch:
    model.load(model_dir)
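The `train_var_filter` argument above restricts the optimizer to the decoder blocks from `start_layer` upward, which is what turns this into layer-wise fine-tuning. A sketch of how such a filter is typically applied in TF1, assuming the filter entries are substrings of variable-scope names like layerup0..layerup9:

import tensorflow as tf

def filter_trainable(train_var_filter):
    # Keep only variables whose name contains one of the filter entries.
    if train_var_filter is None:
        return tf.trainable_variables()
    return [v for v in tf.trainable_variables()
            if any(f in v.name for f in train_var_filter)]

# Hypothetical use inside compile():
#   train_vars = filter_trainable(['layerup8', 'layerup9'])
#   train_op = optimizer.minimize(loss, var_list=train_vars, global_step=step)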
Example #5: PSPNet initialized from ResNet weights
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = pspnet.PSPNet(flags.num_classes,
                          flags.patch_size,
                          suffix=flags.model_suffix,
                          learn_rate=flags.learning_rate,
                          decay_step=flags.decay_step,
                          decay_rate=flags.decay_rate,
                          epochs=flags.epochs,
                          batch_size=flags.batch_size,
                          weight_decay=flags.weight_decay,
                          momentum=flags.momentum)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                         field_name='Fresno,Modesto,Stockton',
                                         field_id=','.join(
                                             str(i) for i in range(663)),
                                         rgb_ext='RGB',
                                         gt_ext='GT',
                                         file_ext='jpg,png',
                                         force_run=False,
                                         clc_name=flags.ds_name)
    cm.print_meta_data()
    # split by tile id: tiles 0-249 for training, tiles 250-499 for validation
    file_list_train = cm.load_files(field_id=','.join(str(i) for i in range(0, 250)), field_ext='RGB,GT')
    file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(250, 500)), field_ext='RGB,GT')
    chan_mean = cm.meta_data['chan_mean'][:3]

    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_train',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_valid',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size, chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.load_resnet(flags.res_dir, keep_last=False)  # backbone init from pre-trained ResNet weights
    model.compile(feature,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.patch_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir,
                  val_mult=flags.val_mult,
                  loss_type='xent')
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHookIters(
        model.get_epoch_step(), [model.loss_xent, model.loss_iou],
        value_names=['valid_loss', 'valid_mIoU'],
        log_time=True,
        run_time=model.n_valid)
    image_hook = hook.ImageValidSummaryHook(model.input_size,
                                            model.get_epoch_step(),
                                            feature,
                                            label,
                                            model.output,
                                            nn_utils.image_summary,
                                            img_mean=cm.meta_data['chan_mean'])
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op,
                valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
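PatchExtractor cuts each large tile into overlapping training patches. The sliding-window logic might look roughly like the NumPy sketch below; it assumes `overlap` is measured in pixels and the tile is mirror-padded by `pad` on each side, which is a guess at the semantics of the two arguments passed above.

import numpy as np

def extract_patches(tile, patch_size, overlap, pad):
    # Yield patch_size windows over a padded HxWxC tile with the given overlap.
    tile = np.pad(tile, ((pad, pad), (pad, pad), (0, 0)), mode='reflect')
    h, w = tile.shape[:2]
    ph, pw = patch_size
    stride_h, stride_w = ph - overlap, pw - overlap
    for y in range(0, h - ph + 1, stride_h):
        for x in range(0, w - pw + 1, stride_w):
            yield tile[y:y + ph, x:x + pw]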
Example #6: UNet on remapped (tower-only) labels
def main(flags):
    nn_utils.set_gpu(flags.GPU)
    np.random.seed(flags.run_id)
    tf.set_random_seed(flags.run_id)

    # define network
    model = unet.UNet(flags.num_classes,
                      flags.patch_size,
                      suffix=flags.suffix,
                      learn_rate=flags.learn_rate,
                      decay_step=flags.decay_step,
                      decay_rate=flags.decay_rate,
                      epochs=flags.epochs,
                      batch_size=flags.batch_size)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(
        raw_data_path=flags.data_dir,
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id=','.join(str(i) for i in range(1, 16)),
        rgb_ext='RGB',
        gt_ext='GT',
        file_ext='tif,png',
        force_run=False,
        clc_name=flags.ds_name)
    # remap GT classes to a tower-only mask: class 3 -> 1, classes 2 and 4-7 -> 0
    gt_d255 = collectionEditor.SingleChanSwitch(cm.clc_dir, {
        2: 0,
        3: 1,
        4: 0,
        5: 0,
        6: 0,
        7: 0
    }, ['GT', 'GT_switch'], 'tower_only').run(
        force_run=False,
        file_ext='png',
        d_type=np.uint8,
    )
    cm.replace_channel(gt_d255.files, True, ['GT', 'GT_switch'])
    cm.print_meta_data()

    file_list_train = cm.load_files(
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id=','.join(str(i) for i in range(4, 16)),
        field_ext='RGB,GT_switch')
    file_list_valid = cm.load_files(
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id='1,2,3',
        field_ext='RGB,GT_switch')

    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size,
                                                     ds_name=flags.ds_name + '_tower_only',
                                                     overlap=overlap, pad=overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size,
                                                     ds_name=flags.ds_name + '_tower_only',
                                                     overlap=overlap, pad=overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    chan_mean = cm.meta_data['chan_mean']

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
            chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.compile(feature,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.patch_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.par_dir,
                  loss_type='xent',
                  pos_weight=flags.pos_weight)
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHookIters(
        model.get_epoch_step(), [model.loss_iou, model.loss_xent],
        value_names=['IoU', 'valid_loss'],
        log_time=True,
        run_time=model.n_valid)
    image_hook = hook.ImageValidSummaryHook(model.input_size,
                                            model.get_epoch_step(),
                                            feature,
                                            label,
                                            model.pred,
                                            partial(
                                                nn_utils.image_summary,
                                                label_num=flags.num_classes),
                                            img_mean=chan_mean)
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op,
                valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
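SingleChanSwitch rewrites ground-truth pixel values through the dictionary above, producing the binary tower-only mask this example trains on. The remap itself reduces to a lookup table; a minimal NumPy sketch (not the library's implementation):

import numpy as np

def switch_labels(gt, mapping):
    # Remap pixel values of a uint8 label mask via a value -> value dict;
    # values absent from the dict pass through unchanged.
    lut = np.arange(256, dtype=np.uint8)
    for src, dst in mapping.items():
        lut[src] = dst
    return lut[gt]

# tower-only mask as in the example: class 3 -> 1, classes 2 and 4-7 -> 0
# mask = switch_labels(gt, {2: 0, 3: 1, 4: 0, 5: 0, 6: 0, 7: 0})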
Example #7: STN refinement of pre-computed segmentation predictions
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = stn.STN(flags.num_classes,
                    flags.tile_size,
                    suffix=flags.model_suffix,
                    learn_rate=flags.learning_rate,
                    decay_step=flags.decay_step,
                    decay_rate=flags.decay_rate,
                    epochs=flags.epochs,
                    batch_size=flags.batch_size)

    cm_train = cityscapes_reader.CollectionMakerCityscapes(
        flags.data_dir,
        flags.rgb_type,
        flags.gt_type,
        'train',
        flags.rgb_ext,
        flags.gt_ext, ['png', 'png'],
        clc_name='{}_train'.format(flags.ds_name),
        force_run=flags.force_run)

    cm_valid = cityscapes_reader.CollectionMakerCityscapes(
        flags.data_dir,
        flags.rgb_type,
        flags.gt_type,
        'val',
        flags.rgb_ext,
        flags.gt_ext, ['png', 'png'],
        clc_name='{}_valid'.format(flags.ds_name),
        force_run=flags.force_run)
    cm_train.print_meta_data()
    cm_valid.print_meta_data()

    # pre-computed segmentation outputs (DeepLab) that the STN will refine
    train_pred_dir = r'/home/lab/Documents/bohao/data/deeplab_model/vis_train/raw_segmentation_results'
    valid_pred_dir = r'/home/lab/Documents/bohao/data/deeplab_model/vis/raw_segmentation_results'
    file_list_train = get_image_list(flags.data_dir, train_pred_dir, 'train')
    file_list_valid = get_image_list(flags.data_dir, valid_pred_dir, 'val')

    resize_func = lambda img: resize_image(img, flags.tile_size)
    train_init_op, valid_init_op, reader_op = DataReaderSegmentationTrainValid(
            flags.tile_size, file_list_train, file_list_valid,
            flags.batch_size, cm_train.meta_data['chan_mean'], aug_func=[reader_utils.image_flipping_hori],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult, global_func=resize_func)\
        .read_op()
    feature, label, pred = reader_op
    train_init_op_valid, _, reader_op = DataReaderSegmentationTrainValid(
        flags.tile_size, file_list_valid, file_list_train,
        flags.batch_size, cm_valid.meta_data['chan_mean'], aug_func=[reader_utils.image_flipping_hori],
        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult, global_func=resize_func) \
        .read_op()
    _, _, pred_valid = reader_op

    model.create_graph(pred, feature_valid=pred_valid, rgb=feature)

    model.compile(pred,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.tile_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir,
                  val_mult=flags.val_mult,
                  loss_type='xent')
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [
            model.loss, model.g_loss, model.d_loss, model.lr_op[0],
            model.lr_op[1], model.lr_op[2]
        ],
        value_names=['seg_loss', 'g_loss', 'd_loss', 'lr_seg', 'lr_g', 'lr_d'],
        print_val=[0, 1, 2])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHookIters(
        model.get_epoch_step(), [model.loss_xent, model.loss_iou],
        value_names=['valid_loss', 'valid_mIoU'],
        log_time=True,
        run_time=model.n_valid)
    image_hook = hook.ImageValidSummaryHook(
        model.input_size,
        model.get_epoch_step(),
        feature,
        label,
        model.refine,
        cityscapes_labels.image_summary,
        img_mean=cm_train.meta_data['chan_mean'])
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=[train_init_op, train_init_op_valid],
                valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
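Both readers here augment only with horizontal flips (`image_flipping_hori`), since vertical flips and rotations would break the up-down prior of street scenes. The flip must hit image and label together so they stay aligned; a TF1-style sketch of that idea, with the signature assumed rather than taken from reader_utils:

import tensorflow as tf

def image_flipping_hori(img, label):
    # Concatenate on the channel axis so one random flip transforms both
    # tensors identically, then split them apart again.
    n = img.get_shape().as_list()[-1]
    combined = tf.concat([img, tf.cast(label, img.dtype)], axis=-1)
    combined = tf.image.random_flip_left_right(combined)
    return combined[..., :n], tf.cast(combined[..., n:], label.dtype)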
Example #8: UNet fine-tuning sweep over start layer and learning rate
def main(flags):
    nn_utils.set_gpu(flags.GPU)
    for start_layer in flags.start_layer:
        if start_layer >= 10:
            suffix_base = 'aemo_newloss'
        else:
            suffix_base = 'aemo_newloss_up{}'.format(start_layer)
        if flags.from_scratch:
            suffix_base += '_scratch'
        for lr in flags.learn_rate:
            for run_id in range(4):
                suffix = '{}_{}'.format(suffix_base, run_id)
                tf.reset_default_graph()

                np.random.seed(run_id)
                tf.set_random_seed(run_id)

                # define network
                model = unet.UNet(flags.num_classes, flags.patch_size, suffix=suffix, learn_rate=lr,
                                  decay_step=flags.decay_step, decay_rate=flags.decay_rate,
                                  epochs=flags.epochs, batch_size=flags.batch_size)
                overlap = model.get_overlap()

                cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                                     field_name='aus10,aus30,aus50',
                                                     field_id='',
                                                     rgb_ext='.*rgb',
                                                     gt_ext='.*gt',
                                                     file_ext='tif',
                                                     force_run=False,
                                                     clc_name=flags.ds_name)
                cm.print_meta_data()

                file_list_train = cm.load_files(field_name='aus10,aus30', field_id='', field_ext='.*rgb,.*gt')
                file_list_valid = cm.load_files(field_name='aus50', field_id='', field_ext='.*rgb,.*gt')

                patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size,
                                                                 flags.ds_name + '_train_hist',
                                                                 overlap, overlap // 2). \
                    run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
                patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size,
                                                                 flags.ds_name + '_valid_hist',
                                                                 overlap, overlap // 2). \
                    run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()
                chan_mean = cm.meta_data['chan_mean']

                train_init_op, valid_init_op, reader_op = \
                    dataReaderSegmentation.DataReaderSegmentationTrainValid(
                        flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
                        chan_mean=chan_mean,
                        aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
                        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
                feature, label = reader_op

                model.create_graph(feature)
                if start_layer >= 10:
                    model.compile(feature, label, flags.n_train, flags.n_valid, flags.patch_size, ersaPath.PATH['model'],
                                  par_dir=flags.par_dir, loss_type='xent')
                else:
                    model.compile(feature, label, flags.n_train, flags.n_valid, flags.patch_size, ersaPath.PATH['model'],
                                  par_dir=flags.par_dir, loss_type='xent',
                                  train_var_filter=['layerup{}'.format(i) for i in range(start_layer, 10)])
                train_hook = hook.ValueSummaryHook(flags.verb_step, [model.loss, model.lr_op],
                                                   value_names=['train_loss', 'learning_rate'],
                                                   print_val=[0])
                model_save_hook = hook.ModelSaveHook(model.get_epoch_step() * flags.save_epoch, model.ckdir)
                valid_loss_hook = hook.ValueSummaryHookIters(model.get_epoch_step(), [model.loss_xent, model.loss_iou],
                                                             value_names=['valid_loss', 'IoU'], log_time=True,
                                                             run_time=model.n_valid)
                image_hook = hook.ImageValidSummaryHook(model.input_size, model.get_epoch_step(), feature, label,
                                                        model.pred,
                                                        nn_utils.image_summary, img_mean=chan_mean)
                start_time = time.time()
                if not flags.from_scratch:
                    model.load(flags.model_dir)
                model.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
                            train_init=train_init_op, valid_init=valid_init_op)
                print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
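The validation hooks in this sweep report `loss_iou` under the name 'IoU'. For reference, intersection over union for a binary mask reduces to the NumPy sketch below; the hooks above presumably accumulate the same quantity over model.n_valid batches.

import numpy as np

def binary_iou(pred, label):
    # IoU = |pred AND label| / |pred OR label| for boolean masks.
    pred, label = pred.astype(bool), label.astype(bool)
    union = np.logical_or(pred, label).sum()
    if union == 0:
        return 1.0  # both masks empty: count as a perfect match
    return np.logical_and(pred, label).sum() / float(union)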
Example #9: UNet fine-tuning from a precomputed patch list
def main(flags):
    nn_utils.set_gpu(flags.GPU)
    for start_layer in flags.start_layer:
        if start_layer >= 10:
            suffix_base = 'aemo'
        else:
            suffix_base = 'aemo_up{}'.format(start_layer)
        if flags.from_scratch:
            suffix_base += '_scratch'
        for lr in flags.learn_rate:
            for run_id in range(1):
                suffix = '{}_{}'.format(suffix_base, run_id)
                tf.reset_default_graph()

                np.random.seed(run_id)
                tf.set_random_seed(run_id)

                # define network
                model = unet.UNet(flags.num_classes,
                                  flags.patch_size,
                                  suffix=suffix,
                                  learn_rate=lr,
                                  decay_step=flags.decay_step,
                                  decay_rate=flags.decay_rate,
                                  epochs=flags.epochs,
                                  batch_size=flags.batch_size)

                file_list = os.path.join(flags.data_dir, 'file_list.txt')
                lines = ersa_utils.load_file(file_list)

                patch_list_train = []
                patch_list_valid = []
                train_tile_names = ['aus10', 'aus30']
                valid_tile_names = ['aus50']

                # route each patch pair by its tile-name prefix
                # (first underscore-separated token of the file name)
                for line in lines:
                    tile_name = os.path.basename(
                        line.split(' ')[0]).split('_')[0].strip()
                    if tile_name in train_tile_names:
                        patch_list_train.append(line.strip().split(' '))
                    elif tile_name in valid_tile_names:
                        patch_list_valid.append(line.strip().split(' '))
                    else:
                        raise ValueError('unexpected tile name: {}'.format(tile_name))

                cm = collectionMaker.read_collection('aemo_align')
                chan_mean = cm.meta_data['chan_mean']

                train_init_op, valid_init_op, reader_op = \
                    dataReaderSegmentation.DataReaderSegmentationTrainValid(
                        flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
                        chan_mean=chan_mean,
                        aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
                        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
                feature, label = reader_op

                model.create_graph(feature)
                if start_layer >= 10:
                    model.compile(feature,
                                  label,
                                  flags.n_train,
                                  flags.n_valid,
                                  flags.patch_size,
                                  ersaPath.PATH['model'],
                                  par_dir=flags.par_dir,
                                  loss_type='xent')
                else:
                    model.compile(feature,
                                  label,
                                  flags.n_train,
                                  flags.n_valid,
                                  flags.patch_size,
                                  ersaPath.PATH['model'],
                                  par_dir=flags.par_dir,
                                  loss_type='xent',
                                  train_var_filter=[
                                      'layerup{}'.format(i)
                                      for i in range(start_layer, 10)
                                  ])
                train_hook = hook.ValueSummaryHook(
                    flags.verb_step, [model.loss, model.lr_op],
                    value_names=['train_loss', 'learning_rate'],
                    print_val=[0])
                model_save_hook = hook.ModelSaveHook(
                    model.get_epoch_step() * flags.save_epoch, model.ckdir)
                valid_loss_hook = hook.ValueSummaryHookIters(
                    model.get_epoch_step(), [model.loss_xent, model.loss_iou],
                    value_names=['valid_loss', 'IoU'],
                    log_time=True,
                    run_time=model.n_valid)
                image_hook = hook.ImageValidSummaryHook(model.input_size,
                                                        model.get_epoch_step(),
                                                        feature,
                                                        label,
                                                        model.pred,
                                                        nn_utils.image_summary,
                                                        img_mean=chan_mean)
                start_time = time.time()
                if not flags.from_scratch:
                    model.load(flags.model_dir)
                model.train(train_hooks=[train_hook, model_save_hook],
                            valid_hooks=[valid_loss_hook, image_hook],
                            train_init=train_init_op,
                            valid_init=valid_init_op)
                print('Duration: {:.3f}'.format(
                    (time.time() - start_time) / 3600))
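Unlike the earlier examples, this one skips PatchExtractor and reads pre-extracted patches from file_list.txt. The parsing loop implies one space-separated RGB/GT pair per line, with the tile name as the first underscore-separated token of the file name. A hypothetical writer for that format, matching what the loop expects:

import os

def write_file_list(pairs, out_dir):
    # pairs: iterable of (rgb_patch, gt_patch) file names, e.g.
    # ('aus10_0001_rgb.jpg', 'aus10_0001_gt.png') -- example names only
    with open(os.path.join(out_dir, 'file_list.txt'), 'w') as f:
        for rgb_patch, gt_patch in pairs:
            f.write('{} {}\n'.format(rgb_patch, gt_patch))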