示例#1
0
def lowlight_train(lowlight_enhance):
    """Train the Decom network and then the Relight network.

    Creates ``args.ckpt_dir`` and ``args.sample_dir`` when they are missing,
    loads the paired low/high training images found under ``./data/our485``
    and ``./data/syn`` plus the low-light evaluation images under
    ``./data/eval/low``, then calls ``lowlight_enhance.train`` twice: first
    with ``train_phase="Decom"``, then with ``train_phase="Relight"``. Each
    phase checkpoints into its own sub-directory of ``args.ckpt_dir``.

    Args:
        lowlight_enhance: model object exposing a ``train(...)`` method.
    """
    for needed_dir in (args.ckpt_dir, args.sample_dir):
        if not os.path.exists(needed_dir):
            os.makedirs(needed_dir)

    # Per-epoch learning rates: start_lr, divided by 10 from epoch 20 onward.
    lr = args.start_lr * np.ones([args.epoch])
    lr[20:] = lr[0] / 10.0

    # Both lists are sorted independently and paired by position, so the
    # low and high directories must contain corresponding files.
    low_paths = sorted(glob('./data/our485/low/*.png') +
                       glob('./data/syn/low/*.png'))
    high_paths = sorted(glob('./data/our485/high/*.png') +
                        glob('./data/syn/high/*.png'))
    assert len(low_paths) == len(high_paths)
    print('[*] Number of training data: %d' % len(low_paths))

    train_low_data, train_high_data = [], []
    for position, low_path in enumerate(low_paths):
        train_low_data.append(load_images(low_path))
        train_high_data.append(load_images(high_paths[position]))

    eval_low_data = [load_images(path) for path in glob('./data/eval/low/*.*')]

    # Same data and schedule for both phases; only the phase name and its
    # checkpoint sub-directory change.
    for phase in ('Decom', 'Relight'):
        lowlight_enhance.train(train_low_data,
                               train_high_data,
                               eval_low_data,
                               batch_size=args.batch_size,
                               patch_size=args.patch_size,
                               epoch=args.epoch,
                               lr=lr,
                               sample_dir=args.sample_dir,
                               ckpt_dir=os.path.join(args.ckpt_dir, phase),
                               eval_every_epoch=args.eval_every_epoch,
                               train_phase=phase)
示例#2
0
def lowlight_train(lowlight_enhance,
                   train_low_pattern='/mnt/hdd/wangwenjing/FGtraining/low/*.png',
                   train_high_pattern='/mnt/hdd/wangwenjing/FGtraining/normal/*.png',
                   eval_low_pattern='./data/eval/low/*.*'):
    """Load paired training images and run ``lowlight_enhance.train`` once.

    The dataset locations were previously hard-coded to one machine's
    absolute paths. They are now keyword arguments whose defaults are exactly
    those paths, so an existing ``lowlight_train(model)`` call behaves as
    before, while other setups can pass e.g.
    ``train_low_pattern='./data/train/low/*.png'``.

    Args:
        lowlight_enhance: model object exposing a ``train(...)`` method.
        train_low_pattern: glob pattern for the low-light training images.
        train_high_pattern: glob pattern for the matching normal-light images.
        eval_low_pattern: glob pattern for the low-light evaluation images.

    Low and high file lists are sorted independently and paired by position,
    so both patterns must yield corresponding files in the same sorted order.
    """
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)
    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)

    train_low_data_names = sorted(glob(train_low_pattern))
    train_high_data_names = sorted(glob(train_high_pattern))
    assert len(train_low_data_names) == len(train_high_data_names)
    print('[*] Number of training data: %d' % len(train_low_data_names))

    train_low_data = []
    train_high_data = []
    for idx, low_name in enumerate(train_low_data_names):
        # Progress report for large datasets: one line every 1000 pairs.
        if (idx + 1) % 1000 == 0:
            print(idx + 1)
        train_low_data.append(load_images(low_name))
        train_high_data.append(load_images(train_high_data_names[idx]))

    eval_low_data = [load_images(name) for name in glob(eval_low_pattern)]

    lowlight_enhance.train(train_low_data,
                           train_high_data,
                           eval_low_data,
                           batch_size=args.batch_size,
                           patch_size=args.patch_size,
                           epoch=args.epoch,
                           sample_dir=args.sample_dir,
                           ckpt_dir=args.ckpt_dir,
                           eval_every_epoch=args.eval_every_epoch)
示例#3
0
文件: main.py 项目: ZRSlayee/IP-UIE
def lowlight_train(lowlight_enhance):
    """Train the Decom network, then the Relight network, on the UW-LL data.

    Creates ``args.ckpt_dir`` / ``args.sample_dir`` when missing, loads the
    paired images from ``../DATA/uw_ll/syn/{low,normal}`` and the evaluation
    images from ``../DATA/uw_ll/val/low``, then runs two ``train`` phases with
    independent epoch counts and learning-rate schedules:

    * Decom: ``args.decom_epoch`` epochs at ``args.start_lr``, divided by 10
      from epoch 40 onward.
    * Relight: ``args.relight_epoch`` epochs at a constant
      ``args.start_lr / 10``.

    Args:
        lowlight_enhance: model object exposing a ``train(...)`` method.
    """
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)
    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)

    # NOTE(review): ``global_step`` is not used below. It was intended for an
    # exponential-decay Relight schedule that is currently disabled. It is
    # kept only because creating it adds a variable to the default TF graph;
    # confirm nothing depends on that before removing it.
    global_step = tf.Variable(0, trainable=False)

    # Cast the epoch settings once and use the same values everywhere.
    # Previously the lr arrays were sized with the raw args values (np.ones
    # rejects non-integer sizes) while only ``train`` received ``int(...)``.
    decom_epoch = int(args.decom_epoch)
    relight_epoch = int(args.relight_epoch)
    eval_every_epoch = int(args.eval_every_epoch)

    lr_Decom = args.start_lr * np.ones([decom_epoch])
    lr_Relight = args.start_lr / 10.0 * np.ones([relight_epoch])
    # Learning-rate decay for Decom: divide by 10 from epoch 40 onward.
    # Computed from start_lr (same value as lr_Decom[0]) so that a
    # zero-length schedule does not raise IndexError.
    lr_Decom[40:] = args.start_lr / 10.0

    train_low_data_names = sorted(glob('../DATA/uw_ll/syn/low/*.jpg'))
    train_high_data_names = sorted(glob('../DATA/uw_ll/syn/normal/*.jpg'))
    # Pairs are matched by sorted position, so counts must agree.
    assert len(train_low_data_names) == len(train_high_data_names)
    print('[*] Number of training data: %d' % len(train_low_data_names))

    train_low_data = []
    train_high_data = []
    for idx, low_name in enumerate(train_low_data_names):
        train_low_data.append(load_images(low_name))
        train_high_data.append(load_images(train_high_data_names[idx]))

    eval_low_data = [load_images(name)
                     for name in glob('../DATA/uw_ll/val/low/*.*')]

    lowlight_enhance.train(train_low_data,
                           train_high_data,
                           eval_low_data,
                           batch_size=args.batch_size,
                           patch_size=args.patch_size,
                           epoch=decom_epoch,
                           lr=lr_Decom,
                           sample_dir=args.sample_dir,
                           ckpt_dir=os.path.join(args.ckpt_dir, 'Decom'),
                           eval_every_epoch=eval_every_epoch,
                           train_phase="Decom")

    lowlight_enhance.train(train_low_data,
                           train_high_data,
                           eval_low_data,
                           batch_size=args.batch_size,
                           patch_size=args.patch_size,
                           epoch=relight_epoch,
                           lr=lr_Relight,
                           sample_dir=args.sample_dir,
                           ckpt_dir=os.path.join(args.ckpt_dir, 'Relight'),
                           eval_every_epoch=eval_every_epoch,
                           train_phase="Relight")
def lowlight_train(lowlight_enhance):
    """Load the ``./data/our485`` training images and run the Decom phase.

    Creates ``args.ckpt_dir`` / ``args.sample_dir`` if missing and builds a
    learning-rate array of length ``args.epoch`` (``args.start_lr``, divided
    by 10 from index 40 on). For every training pair it stores the images and
    a histogram-equalized max-channel map. It then loads the evaluation
    images from ``./data/eval15/low`` and calls ``lowlight_enhance.train``
    with ``train_phase="Decom"``.

    NOTE(review): as written, both the "low" and the "high" file lists are
    globbed from ``./data/our485/low/*.png``, so ``train_high_data`` holds the
    same images as ``train_low_data``. Confirm whether this is intentional
    (e.g. an unpaired setup) or should read ``./data/our485/high/*.png``.
    """
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)
    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)

    # Per-epoch learning rates: start_lr, divided by 10 from epoch 40 onward.
    lr = args.start_lr * np.ones([args.epoch])
    lr[40:] = lr[0] / 10.0

    # Only train_low_data, train_high_data and train_low_data_eq are filled
    # in the code below. NOTE(review): the remaining lists look unused here;
    # confirm whether later code relies on them before removing.
    train_low_data = []
    train_high_data = []
    train_low_data_eq = []
    train_low_data_clahe = []
    train_high_data_eq = []

    train_low_data_eq_guide = []
    train_high_data_eq_guide = []
    train_low_data_eq_guide_weight = []
    train_low_data_eq_clahe_weight = []
    train_high_data_eq_guide_weight = []

    train_low_data_names = glob('./data/our485/low/*.png')
    train_low_data_names.sort()
    # NOTE(review): same directory as the low list above; see docstring.
    train_high_data_names = glob('./data/our485/low/*.png')
    train_high_data_names.sort()
    assert len(train_low_data_names) == len(train_high_data_names)
    print('[*] Number of training data: %d' % len(train_low_data_names))

    for idx in range(len(train_low_data_names)):
        low_im = load_images(train_low_data_names[idx])
        #low_im = white_world(low_im)
        train_low_data.append(low_im)
        high_im = load_images(train_high_data_names[idx])
        # high_im = white_world(high_im)
        train_high_data.append(high_im)
        # train_low_data_max_chan = np.max(meanFilter(low_im,winSize=(5,5)),axis=2,keepdims=True)
        # Maximum over axis 2 (keepdims), taken from high_im despite the
        # "low" name. Because the high list points at the low directory,
        # this is currently the low image's max channel. NOTE(review): confirm.
        train_low_data_max_chan = np.max(high_im, axis=2, keepdims=True)

        # With weight_eq_clahe = 0 the blend below is pure histeq(); the
        # adapthisteq() term is still evaluated but multiplied by 0.
        weight_eq_clahe = 0  #sigmoid(5*(meanFilter(train_low_data_max_chan,(20,20))-0.5))
        train_low_data_max_channel = (1 - weight_eq_clahe) * histeq(
            train_low_data_max_chan) + weight_eq_clahe * adapthisteq(
                train_low_data_max_chan)
        # train_low_data_max_channel = histeq(low_im[:,:,1])

        # Assumes the blended map is 3-D (histeq/adapthisteq preserve the
        # keepdims axis) — TODO confirm against their implementations.
        train_low_data_eq.append(train_low_data_max_channel[:, :, :])

    eval_low_data = []
    eval_high_data = []

    eval_low_data_name = glob('./data/eval15/low/*.*')

    for idx in range(len(eval_low_data_name)):
        eval_low_im = load_images(eval_low_data_name[idx])
        eval_low_data.append(eval_low_im)

    # Positional order here is (low, low_eq, eval_low, high), which differs
    # from the three-positional calls elsewhere. Presumably it matches this
    # variant's train() signature; verify against the model class.
    lowlight_enhance.train(train_low_data,
                           train_low_data_eq,
                           eval_low_data,
                           train_high_data,
                           batch_size=args.batch_size,
                           patch_size=args.patch_size,
                           epoch=args.epoch,
                           lr=lr,
                           sample_dir=args.sample_dir,
                           ckpt_dir=os.path.join(args.ckpt_dir, 'Decom'),
                           eval_every_epoch=args.eval_every_epoch,
                           train_phase="Decom")