# Ejemplo n.º 1 (Example 1)
# 0
def __evaluate(ds, eval_out_path, filenames=None):
    """Super-resolve every LR image in ``ds`` and save the results to disk.

    Loads generator weights from ``<checkpoint_dir>/g.h5`` (``checkpoint_dir``
    is a module-level name) and creates the ``lr``, ``hr``, ``gen``,
    ``bicubic`` and ``combined`` sub-folders of ``eval_out_path``. For every
    element of ``ds`` the generated image, the LR input and a 4x bicubic
    upscale of the LR input are saved.

    Args:
        ds: iterable of ``(lr_image, hr_image)`` pairs; ``lr_image`` must
            provide ``.numpy()`` (e.g. an eager ``tf.data.Dataset``).
            NOTE(review): presumably HWC images already scaled the way G
            expects -- confirm against the caller.
        eval_out_path: root directory for the output sub-folders.
        filenames: optional sequence with one output file name per element of
            ``ds``. When ``None``, indexed names (``valid_gen_<i>.jpg`` ...)
            are used and the HR image is saved too; otherwise ``filenames[i]``
            is used in every folder and the HR image is not saved.

    Raises:
        IndexError: if ``filenames`` has fewer entries than ``ds`` yields.
    """
    G = get_G([1, None, None, 3])
    G.load_weights(os.path.join(checkpoint_dir, 'g.h5'))
    G.eval()
    sample_folders = ['lr', 'hr', 'gen', 'bicubic', 'combined']
    for sample_folder in sample_folders:
        tl.files.exists_or_mkdir(os.path.join(eval_out_path, sample_folder))

    for i, (valid_lr_img, valid_hr_img) in enumerate(ds):
        valid_lr_img = np.asarray(valid_lr_img.numpy(), dtype=np.float32)
        # G was built for a 4-D [1, H, W, 3] input, so add the batch axis.
        valid_lr_img = valid_lr_img[np.newaxis, :, :, :]
        size = [valid_lr_img.shape[1], valid_lr_img.shape[2]]

        out = G(valid_lr_img).numpy()

        print("LR size: %s /  generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) /  gen HR size: (1, 1356, 2040, 3)
        print("[*] save images")
        if filenames is None:
            tl.vis.save_image(out[0], os.path.join(eval_out_path, 'gen', f'valid_gen_{i}.jpg'))
            tl.vis.save_image(valid_lr_img[0], os.path.join(eval_out_path, 'lr', f'valid_lr_{i}.jpg'))
            tl.vis.save_image(valid_hr_img, os.path.join(eval_out_path, 'hr', f'valid_hr_{i}.jpg'))

            out_bicu = scipy.misc.imresize(valid_lr_img[0], [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
            tl.vis.save_image(out_bicu, os.path.join(eval_out_path, 'bicubic', f'valid_bicu_{i}.jpg'))
        else:
            # Previously ``filename`` was never bound here (NameError).
            filename = filenames[i]
            tl.vis.save_image(out[0], os.path.join(eval_out_path, 'gen', filename))
            tl.vis.save_image(valid_lr_img[0], os.path.join(eval_out_path, 'lr', filename))

            out_bicu = scipy.misc.imresize(valid_lr_img[0], [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
            tl.vis.save_image(out_bicu, os.path.join(eval_out_path, 'bicubic', filename))
# Ejemplo n.º 2 (Example 2)
# 0
def evaluate():
    """Super-resolve every validation image and save the results to ``save_dir``.

    LR and HR ``*.png`` files are listed from ``config.VALID.lr_img_path`` and
    ``config.VALID.hr_img_path``, sorted, and paired by position. The
    generator is loaded from ``<checkpoint_dir>/g.h5``. For each pair, the
    generated image, the HR reference and a 4x bicubic upscale of the LR input
    are written as ``<n>_valid_gen.png``, ``<n>_valid_hr.png`` and
    ``<n>_valid_bicubic.png`` (``n`` counts from 1). ``checkpoint_dir`` and
    ``save_dir`` are module-level names.
    """
    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)

    ###========================== DEFINE MODEL ============================###
    G = get_G([1, None, None, 3])
    G.load_weights(os.path.join(checkpoint_dir, 'g.h5'))
    G.eval()

    # Pair LR/HR images by sorted position; zip stops at the shorter list.
    for imid, (valid_lr_img, valid_hr_img) in enumerate(zip(valid_lr_imgs, valid_hr_imgs), start=1):
        # Rescale to [-1, 1] (assumes 8-bit pixel values).
        valid_lr_img = (valid_lr_img / 127.5) - 1
        # Add the batch axis expected by G's [1, H, W, 3] input.
        valid_lr_img = np.asarray(valid_lr_img, dtype=np.float32)[np.newaxis, :, :, :]
        size = [valid_lr_img.shape[1], valid_lr_img.shape[2]]

        out = G(valid_lr_img).numpy()

        print("LR size: %s /  generated HR size: %s" % (size, out.shape))
        print("[*] save images")
        tl.vis.save_image(out[0], os.path.join(save_dir, f'{imid}_valid_gen.png'))
        tl.vis.save_image(valid_hr_img, os.path.join(save_dir, f'{imid}_valid_hr.png'))

        out_bicu = scipy.misc.imresize(valid_lr_img[0], [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
        tl.vis.save_image(out_bicu, os.path.join(save_dir, f'{imid}_valid_bicubic.png'))
# Ejemplo n.º 3 (Example 3)
# 0
def train():
    """Pre-train the generator G with a pixel-wise MSE loss.

    Builds G (``batch_size`` x 96 x 96 x 3 input), D (``batch_size`` x 384 x
    384 x 3 input) and a pretrained VGG19 truncated at ``pool4``. It then runs
    ``n_epoch_init`` epochs of MSE training of G on ``get_train_data()``,
    skipping a trailing partial batch. After the final epoch, the last
    generated batch is saved as a 2x4 grid in ``save_dir1``.

    NOTE(review): D, VGG, ``g_optimizer`` and ``d_optimizer`` are created but
    not used in this (apparently truncated) routine; they are kept so an
    adversarial phase can follow.
    """
    G = get_G((batch_size, 96, 96, 3))
    D = get_D((batch_size, 384, 384, 3))
    VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

    G.train()
    D.train()
    VGG.train()

    train_ds = get_train_data()

    ## initialize learning (G)
    # NOTE(review): this divides the epoch count, not the number of training
    # samples, by batch_size, so the "[step/n_step_epoch]" display is only
    # indicative.
    n_step_epoch = round(n_epoch_init // batch_size)
    for epoch in range(n_epoch_init):
        fake_hr_patchs = None
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            with tf.GradientTape() as tape:
                fake_hr_patchs = G(lr_patchs)
                mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
            grad = tape.gradient(mse_loss, G.trainable_weights)
            g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
                epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss))
        # epoch only reaches n_epoch_init - 1, so the former `== n_epoch_init`
        # test (which was also tab-indented, a TabError) could never fire.
        if epoch == n_epoch_init - 1 and fake_hr_patchs is not None:
            tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir1, 'train_g_init_{}.png'.format(epoch)))
# Ejemplo n.º 4 (Example 4)
# 0
def evaluate():
    """Super-resolve one validation image with the mode-specific generator.

    Results go to ``samples/<tl.global_flag['mode']>``. Weights are read from
    ``checkpoint/g_<mode>.h5``. Saves the generated image, the rescaled LR
    input, the HR reference and a 4x bicubic upscale of the LR input.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine has enough memory, pre-load the whole set.
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list,
                                       path=config.VALID.hr_img_path,
                                       n_threads=32)

    ###========================== DEFINE MODEL ============================###
    imid = 64  # 0: penguin  81: butterfly  53: bird  64: castle
    valid_lr_img = valid_lr_imgs[imid]
    valid_hr_img = valid_hr_imgs[imid]
    # valid_lr_img = get_imgs_fn('test.png', 'data2017/')  # if you want to test your own image
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    G = get_G([1, None, None, 3])
    G.load_weights(checkpoint_dir + '/g_{}.h5'.format(tl.global_flag['mode']))
    G.eval()

    valid_lr_img = np.asarray(valid_lr_img, dtype=np.float32)
    # G was built for a 4-D [1, H, W, 3] input. The original fed a 3-D array
    # and never defined ``size`` (NameError in the print/resize below).
    valid_lr_img = valid_lr_img[np.newaxis, :, :, :]
    size = [valid_lr_img.shape[1], valid_lr_img.shape[2]]

    out = G(valid_lr_img).numpy()

    print("LR size: %s /  generated HR size: %s" % (size, out.shape)
          )  # LR size: (339, 510, 3) /  gen HR size: (1, 1356, 2040, 3)
    print("[*] save images")
    tl.vis.save_image(out[0], save_dir + '/valid_gen.png')
    tl.vis.save_image(valid_lr_img[0], save_dir + '/valid_lr.png')
    tl.vis.save_image(valid_hr_img, save_dir + '/valid_hr.png')

    out_bicu = scipy.misc.imresize(valid_lr_img[0], [size[0] * 4, size[1] * 4],
                                   interp='bicubic',
                                   mode=None)
    tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic.png')
# Ejemplo n.º 5 (Example 5)
# 0
def evaluate():
    """Run the trained generator on validation image #64 and save comparisons.

    Writes ``valid_gen.png``, ``valid_lr.png``, ``valid_hr.png`` and
    ``valid_bicubic.png`` into the module-level ``save_dir``. The generator
    weights come from ``<checkpoint_dir>/g.h5``.
    """
    # Sorted file lists of the HR and LR validation folders.
    hr_names = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    lr_names = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    lr_images = tl.vis.read_images(lr_names, path=config.VALID.lr_img_path, n_threads=32)
    hr_images = tl.vis.read_images(hr_names, path=config.VALID.hr_img_path, n_threads=32)

    picked = 64  # 0: penguin  81: butterfly  53: bird  64: castle
    hr_image = hr_images[picked]
    lr_scaled = (lr_images[picked] / 127.5) - 1  # rescale to [-1, 1]

    generator = get_G([1, None, None, 3])
    generator.load_weights(os.path.join(checkpoint_dir, 'g.h5'))
    generator.eval()

    # Single-image batch for the generator's [1, H, W, 3] input.
    batch = np.asarray(lr_scaled, dtype=np.float32)[np.newaxis, :, :, :]
    height, width = batch.shape[1], batch.shape[2]
    size = [height, width]

    out = generator(batch).numpy()

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))
    print("[*] save images")
    lr_single = batch[0]
    tl.vis.save_image(out[0], os.path.join(save_dir, 'valid_gen.png'))
    tl.vis.save_image(lr_single, os.path.join(save_dir, 'valid_lr.png'))
    tl.vis.save_image(hr_image, os.path.join(save_dir, 'valid_hr.png'))

    bicubic = scipy.misc.imresize(lr_single, [height * 4, width * 4], interp='bicubic', mode=None)
    tl.vis.save_image(bicubic, os.path.join(save_dir, 'valid_bicubic.png'))
# Ejemplo n.º 6 (Example 6)
# 0
def train_L1():
    """Pre-train the generator with an L1 (mean absolute error) pixel loss.

    The training subset is selected by ``config.get_train_data`` with
    ``step=config.train_step_init`` (take one image every ``step`` images;
    ``step=1`` uses all of them) and ``start=0``. If
    ``<config.path_model>/g_init.h5`` exists, training resumes from it.

    After each epoch:

    * the mean train / validation L1 losses are appended to the history, and
      the full history is rewritten to ``train_loss_L1.txt`` /
      ``valid_loss_L1.txt`` in ``config.path_loss``;
    * the weights are saved as ``g_init_epoch_<epoch>.h5``.

    The final weights are written to ``g_init.h5``. An empty dataset yields a
    NaN mean for that epoch instead of a ZeroDivisionError.
    """
    train_loss = []
    val_loss = []
    G = model.get_G((config.batch_size_init, 56, 56, 3))
    # Resume from earlier pre-training if a checkpoint exists.
    init_weights = os.path.join(config.path_model, 'g_init.h5')
    if os.path.exists(init_weights):
        G.load_weights(init_weights)
    # Training dataset
    train_ds, train_len = config.get_train_data(config.batch_size_init, step=config.train_step_init, start=0)
    # Validation dataset
    val_ds = config.get_valid_data(batch_size=config.batch_size_init, valid_size=config.batch_size_init, step=10)
    print('训练集一共有{}张图片'.format(train_len))
    g_optimizer_init = tf.optimizers.Adam(learning_rate=config.lr_init)
    train_loss_file = os.path.join(config.path_loss, 'train_loss_L1.txt')
    valid_loss_file = os.path.join(config.path_loss, 'valid_loss_L1.txt')
    for epoch in range(config.n_epoch_init):
        epoch_start = time.time()
        epoch_loss = 0.0
        n_train_batches = 0
        G.train()
        for lr_patchs, hr_patchs in train_ds:
            with tf.GradientTape() as tape:
                fake_hr_patchs = G(lr_patchs)
                l1_loss = tl.cost.absolute_difference_error(fake_hr_patchs, hr_patchs, is_mean=True)
            grad = tape.gradient(l1_loss, G.trainable_weights)
            g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))

            epoch_loss += float(l1_loss)
            n_train_batches += 1

        G.eval()
        v_loss = 0.0
        n_val_batches = 0
        for lr_patchs, hr_patchs in val_ds:
            val_fr = G(lr_patchs)
            v_loss += float(tl.cost.absolute_difference_error(val_fr, hr_patchs, is_mean=True))
            n_val_batches += 1

        # Guard against empty datasets (previously a ZeroDivisionError).
        mean_train = epoch_loss / n_train_batches if n_train_batches else float('nan')
        mean_valid = v_loss / n_val_batches if n_val_batches else float('nan')
        train_loss.append(mean_train)
        val_loss.append(mean_valid)
        print('Epoch: [{}/{}] time: {:.3f}s : mean train loss of  is {:.5f}, mean valid loss is {:.5f}'.
              format(epoch, config.n_epoch_init, time.time() - epoch_start, mean_train, mean_valid))
        np.savetxt(train_loss_file, train_loss)
        np.savetxt(valid_loss_file, val_loss)
        G.save_weights(os.path.join(config.path_model, 'g_init_epoch_{}.h5'.format(epoch)))
    G.save_weights(os.path.join(config.path_model, 'g_init.h5'))
def evaluate():
    """Super-resolve validation image #64 and store the comparison images.

    Outputs ``valid_gen.png``, ``valid_lr.png``, ``valid_hr.png`` and
    ``valid_bicubic.png`` in the module-level ``save_dir``, using the generator
    weights at ``<checkpoint_dir>/g.h5``.
    """
    def _sorted_pngs(folder):
        # Sorted list of the *.png file names inside ``folder``.
        return sorted(tl.files.load_file_list(path=folder, regx='.*.png', printable=False))

    hr_list = _sorted_pngs(config.VALID.hr_img_path)
    lr_list = _sorted_pngs(config.VALID.lr_img_path)

    lr_pool = tl.vis.read_images(lr_list, path=config.VALID.lr_img_path, n_threads=32)
    hr_pool = tl.vis.read_images(hr_list, path=config.VALID.hr_img_path, n_threads=32)

    index = 64
    target_hr = hr_pool[index]
    source_lr = (lr_pool[index] / 127.5) - 1

    net = get_G([1, None, None, 3])
    net.load_weights(os.path.join(checkpoint_dir, 'g.h5'))
    net.eval()

    source_lr = np.asarray(source_lr, dtype=np.float32)[np.newaxis, :, :, :]
    size = list(source_lr.shape[1:3])

    out = net(source_lr).numpy()

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))
    print("[*] save images")
    for picture, fname in ((out[0], 'valid_gen.png'),
                           (source_lr[0], 'valid_lr.png'),
                           (target_hr, 'valid_hr.png')):
        tl.vis.save_image(picture, os.path.join(save_dir, fname))

    upscaled = scipy.misc.imresize(source_lr[0], [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
    tl.vis.save_image(upscaled, os.path.join(save_dir, 'valid_bicubic.png'))
# Ejemplo n.º 8 (Example 8)
# 0
def evaluate():
    """Super-resolve one validation image with the adversarially trained G.

    Reads the sorted ``*.png`` lists from ``config.path_valid_LR_orin`` /
    ``config.path_valid_HR_orin`` and picks image ``0``. It runs the generator
    from ``<config.path_model>/g_gan.h5`` on that image and saves the
    generated (``fh.png``), LR (``rl.png``), HR (``hr.png``) and
    ``config.resize_img``-upscaled (``bh.png``) images into ``config.path_pic``.
    """
    hr_files = sorted(tl.files.load_file_list(path=config.path_valid_HR_orin, regx='.*.png', printable=False))
    lr_files = sorted(tl.files.load_file_list(path=config.path_valid_LR_orin, regx='.*.png', printable=False))

    lr_all = tl.vis.read_images(lr_files, path=config.path_valid_LR_orin, n_threads=8)
    hr_all = tl.vis.read_images(hr_files, path=config.path_valid_HR_orin, n_threads=8)

    pick = 0  # 0: penguin  81: butterfly  53: bird  64: castle
    hr_image = hr_all[pick]
    # Rescale to [-1, 1], then add the batch axis.
    lr_batch = np.asarray(lr_all[pick] / 127.5 - 1, dtype=np.float32)[np.newaxis, :, :, :]
    # shape[0] is the row count and shape[1] the column count (the original
    # named these W, H). NOTE(review): they are passed to config.resize_img in
    # that order -- confirm which order resize_img expects.
    hr_rows, hr_cols = hr_image.shape[0], hr_image.shape[1]

    generator = model.get_G([1, None, None, 3])
    generator.load_weights(os.path.join(config.path_model, 'g_gan.h5'))
    generator.eval()
    gen_img = generator(lr_batch).numpy()  # network output

    out_bicu = config.resize_img(lr_batch, (hr_rows, hr_cols))  # interpolated upscale

    outputs = (
        (gen_img[0], 'fh.png'),
        (lr_batch[0], 'rl.png'),
        (hr_image, 'hr.png'),
        (out_bicu[0], 'bh.png'),
    )
    for picture, fname in outputs:
        tl.vis.save_image(picture, os.path.join(config.path_pic, fname))

    print('验证图像已保存在{}文件夹中'.format(config.path_pic))
# Ejemplo n.º 9 (Example 9)
# 0
def train():
    """Train InfoGAN networks G, D and Q on the Chairs dataset.

    Per step:

    * G and Q are updated on ``g_loss``: the sigmoid cross-entropy of fake
      logits against "real" labels, plus the weighted mutual-information
      term.
    * D is updated on ``d_tr``: the real/fake sigmoid cross-entropy plus the
      same mutual-information term.

    Losses are recorded on the first counted iteration and every
    ``flags.save_every_it`` iterations after it. After each epoch:

    * the loss curves are plotted to ``<flags.result_dir>/loss.jpg``;
    * G and D are saved as ``.npz`` files under ``flags.checkpoint_dir``;
    * ``flags.n_samples`` sample grids are written.
    """
    images, images_path = get_Chairs(
        flags.output_size, flags.n_epoch, flags.batch_size)
    G = get_G([None, flags.z_dim])
    D = get_D([None, flags.output_size, flags.output_size, flags.n_channel])
    Q = get_Q([None, 1024])

    G.train()
    D.train()
    Q.train()

    g_optimizer = tf.optimizers.Adam(
        learning_rate=flags.G_learning_rate, beta_1=0.5)
    d_optimizer = tf.optimizers.Adam(
        learning_rate=flags.D_learning_rate, beta_1=0.5)

    n_step_epoch = int(len(images_path) // flags.batch_size)
    his_g_loss = []
    his_d_loss = []
    his_mutual = []
    count = 0

    for epoch in range(flags.n_epoch):
        for step, batch_images in enumerate(images):
            count += 1
            if batch_images.shape[0] != flags.batch_size:
                break
            step_time = time.time()
            # Persistent so one forward pass serves both the G/Q and D updates.
            with tf.GradientTape(persistent=True) as tape:
                noise, cat1, cat2, cat3, con = gen_noise()
                fake_logits, mid = D(G(noise))
                real_logits, _ = D(batch_images)
                f_cat1, f_cat2, f_cat3, f_mu = Q(mid)

                d_loss_fake = tl.cost.sigmoid_cross_entropy(
                    output=fake_logits, target=tf.zeros_like(fake_logits), name='d_loss_fake')
                d_loss_real = tl.cost.sigmoid_cross_entropy(
                    output=real_logits, target=tf.ones_like(real_logits), name='d_loss_real')
                d_loss = d_loss_fake + d_loss_real

                g_loss_tmp = tl.cost.sigmoid_cross_entropy(
                    output=fake_logits, target=tf.ones_like(fake_logits), name='g_loss_fake')

                mutual_disc = calc_disc_mutual(
                    f_cat1, f_cat2, f_cat3, cat1, cat2, cat3)
                mutual_cont = calc_cont_mutual(f_mu, con)
                mutual = (flags.disc_lambda * mutual_disc +
                          flags.cont_lambda * mutual_cont)
                g_loss = mutual + g_loss_tmp
                d_tr = d_loss + mutual

            # Every G and Q weight must be differentiable w.r.t. g_loss.
            grads = tape.gradient(
                g_loss, G.trainable_weights + Q.trainable_weights)
            g_optimizer.apply_gradients(
                zip(grads, G.trainable_weights + Q.trainable_weights))
            grads = tape.gradient(
                d_tr, D.trainable_weights)
            d_optimizer.apply_gradients(
                zip(grads, D.trainable_weights))
            del tape

            print("Epoch: [{}/{}] [{}/{}] took: {}, d_loss: {:.5f}, g_loss: {:.5f}, mutual: {:.5f}".format(
                epoch, flags.n_epoch, step, n_step_epoch, time.time() - step_time, d_loss, g_loss, mutual))

            # Record on iterations 1, 1 + k, 1 + 2k, ... The former
            # `count % k == 1` never fired when save_every_it == 1.
            if (count - 1) % flags.save_every_it == 0:
                his_g_loss.append(float(g_loss))
                his_d_loss.append(float(d_loss))
                his_mutual.append(float(mutual))

        plt.plot(his_d_loss)
        plt.plot(his_g_loss)
        plt.plot(his_mutual)
        plt.legend(['D_Loss', 'G_Loss', 'Mutual_Info'])
        plt.xlabel(f'Iterations / {flags.save_every_it}')
        plt.ylabel('Loss')
        plt.savefig(f'{flags.result_dir}/loss.jpg')
        plt.clf()
        plt.close()

        # NOTE(review): Q's weights are not checkpointed here.
        G.save_weights(f'{flags.checkpoint_dir}/G.npz', format='npz')
        D.save_weights(f'{flags.checkpoint_dir}/D.npz', format='npz')
        G.eval()
        for k in range(flags.n_samples):
            z = gen_eval_noise(flags.save_every_epoch, flags.n_samples)
            result = G(z)
            # NOTE(review): samples go to a hard-coded 'result/' folder, not
            # flags.result_dir like the loss plot.
            tl.visualize.save_images(result.numpy(), [
                                     flags.save_every_epoch, flags.n_samples], f'result/train_{epoch}_{k}.png')
        G.train()
# Ejemplo n.º 10 (Example 10)
# 0
def evaluate(img_path="face3.jpg", downscale=False):
    """Super-resolve a single image file, save the results and plot them.

    Args:
        img_path: image file to load (defaults to the previously hard-coded
            ``"face3.jpg"``). It is converted to RGB because G takes
            3-channel input.
        downscale: when True, the image is first shrunk 4x and that copy is
            the LR input, so ``hr`` is a genuine reference. When False (the
            previous behavior), the image itself is fed to G unchanged.

    ``valid_hr.png``, ``valid_lr.png``, ``valid_srgan.png`` and
    ``valid_bicubic.png`` are written to the module-level ``save_dir``, then
    shown in a 2x2 matplotlib figure. The generator weights come from
    ``<checkpoint_dir>/g.h5``.
    """
    G = get_G([1, None, None, 3])
    G.load_weights(os.path.join(checkpoint_dir, 'g.h5'))
    G.eval()

    # ``convert`` loads the pixels into a new image, so the file can be closed
    # immediately. Grayscale/RGBA inputs become the 3 channels G expects.
    with Image.open(img_path) as src:
        hr_pil = src.convert('RGB')
    if downscale:
        lr_pil = hr_pil.resize((hr_pil.size[0] // 4, hr_pil.size[1] // 4))
    else:
        # The original computed a 4x downscale and then discarded it by
        # overwriting it with the full image; that remains the default.
        lr_pil = hr_pil

    imgs = {}
    imgs['hr'] = np.array(hr_pil, dtype=np.float32) / 255.0
    imgs['lr'] = np.array(lr_pil, dtype=np.float32) / 255.0

    size = [imgs['lr'].shape[0], imgs['lr'].shape[1]]

    input_img = imgs['lr'] * 2 - 1  # rescale to [-1, 1]
    input_img = input_img[np.newaxis, :, :, :]
    out = G(input_img).numpy().squeeze()
    imgs['srgan'] = (out + 1) / 2  # back to [0, 1]

    imgs['bicubic'] = scipy.misc.imresize(imgs['lr'],
                                          [size[0] * 4, size[1] * 4],
                                          interp='bicubic',
                                          mode=None)

    print("LR size: %s /  generated HR size: %s" % (size, imgs['hr'].shape[:])
          )  # LR size: (339, 510, 3) /  gen HR size: (1, 1356, 2040, 3)
    print("[*] save images")
    for name, picture in imgs.items():
        tl.vis.save_image(picture,
                          os.path.join(save_dir, 'valid_{}.png'.format(name)))

    fig, axes = plt.subplots(
        2,
        2,
        figsize=(9, 5),
        dpi=500,
        squeeze=False,
    )
    # squeeze=False guarantees a 2-D array of Axes; flatten it to pair with imgs.
    for ax, (name, picture) in zip(axes.ravel(), imgs.items()):
        try:
            ax.imshow(picture)
            ax.set_title(name)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_axis_off()
        except Exception as err:  # was a bare ``except`` that hid the cause
            print(name, err)

    # Hide the axes.
    plt.axis('off')
    plt.show()
    plt.close(fig)
# Ejemplo n.º 11 (Example 11)
# 0
def main(config):
    """Train a pix2pix-style image-to-image GAN (generator G, discriminator D).

    For every batch, D is first updated with ``GANLoss(use_lsgan=True)`` on
    the (detached fake, real) pair. G is then updated with the same criterion
    on D(fake) plus ``config.lambda_L1`` times the L1 distance to the target.

    Every ``config.print_freq`` steps the losses are printed and the fake
    batch is saved under ``<config.sample_dir>/<epoch>/``. Every
    ``config.save_epoch_freq`` epochs, D and G checkpoints (with their
    optimizers) are written to ``config.model_dir``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = get_dataloader(config)
    generator_net, discriminator_net = get_G(), get_D()
    for net in (generator_net, discriminator_net):
        net.to(device)
    G = nn.DataParallel(generator_net)
    D = nn.DataParallel(discriminator_net)

    def _adam(params):
        # Both networks share the same Adam hyper-parameters.
        return torch.optim.Adam(params, lr=config.lr, betas=(0.5, 0.999))

    optimizer_G = _adam(G.parameters())
    optimizer_D = _adam(D.parameters())

    criterionGAN = GANLoss(use_lsgan=True).to(device)
    criterionL1 = nn.L1Loss()

    for folder in (config.result_dir, config.sample_dir, config.model_dir, config.log_dir):
        make_dir(folder)

    total_steps = 0
    last_epoch = config.niter + config.niter_decay
    for epoch in range(config.epoch_count, last_epoch + 1):
        epoch_sample_dir = "{}/{}".format(config.sample_dir, epoch)
        make_dir(epoch_sample_dir)

        batches = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
        for i, (real_A, real_B) in enumerate(batches):
            real_A, real_B = real_A.to(device), real_B.to(device)

            fake_B = G(real_A)

            # ---- discriminator update ----
            set_requires_grad(D, True)
            optimizer_D.zero_grad()
            loss_D_fake = criterionGAN(D(fake_B.detach()), False)
            loss_D_real = criterionGAN(D(real_B.detach()), True)
            loss_D = (loss_D_fake + loss_D_real) * 0.5
            loss_D.backward()
            optimizer_D.step()

            # ---- generator update (D frozen) ----
            set_requires_grad(D, False)
            optimizer_G.zero_grad()
            loss_G_GAN = criterionGAN(D(fake_B), True)
            loss_G_L1 = criterionL1(fake_B, real_B) * config.lambda_L1
            loss_G = loss_G_GAN + loss_G_L1
            loss_G.backward()
            optimizer_G.step()

            if total_steps % config.print_freq == 0:
                print(
                    "Loss D Fake:{:.4f}, D Real:{:.4f}, D Total:{:.4f}, G GAN:{:.4f}, G L1:{:.4f}. G Total:{:.4f}"
                    .format(loss_D_fake, loss_D_real, loss_D, loss_G_GAN,
                            loss_G_L1, loss_G))
                save_image(fake_B, "{}/{}.png".format(epoch_sample_dir, i))
            total_steps += 1

        if epoch % config.save_epoch_freq == 0:
            print("Save models in {} epochs".format(epoch))
            save_checkpoint("{}/D_{}.pth".format(config.model_dir, epoch), D,
                            optimizer_D)
            save_checkpoint("{}/G_{}.pth".format(config.model_dir, epoch), G,
                            optimizer_G)
# Ejemplo n.º 12 (Example 12)
# 0
def train_adv():
    """Fine-tune the generator adversarially against the discriminator D.

    G is initialized from ``<config.path_model>/g_init.h5`` (the pre-trained
    weights).

    * Generator loss: pixel MSE + 1e-4 * MSE between ``model.VGG16_22``
      features of the fake and real patches + 1e-3 * adversarial sigmoid
      cross-entropy.
    * D loss: real/fake sigmoid cross-entropy.

    Per-step losses are printed. After each epoch the mean losses and the
    epoch wall time are printed and G is saved to ``g_adv.h5``. An epoch with
    no batches reports NaN means instead of raising ZeroDivisionError.
    """
    # with tf.device('/cpu:0'):

    # Initialize models.
    G = model.get_G((config.batch_size_adv, 56, 56, 3))
    D = model.get_D((config.batch_size_adv, 224, 224, 3))
    vgg22 = model.VGG16_22((config.batch_size_adv, 224, 224, 3))

    G.load_weights(os.path.join(config.path_model, 'g_init.h5'))
    # Optimizers.
    g_optimizer = tf.optimizers.Adam(learning_rate=0.0001)
    d_optimizer = tf.optimizers.Adam(learning_rate=0.0001)

    G.train()
    D.train()
    vgg22.train()
    train_ds, train_len = config.get_train_data(config.batch_size_adv, step=config.train_step_adv, start=0)
    print('训练集一共有{}张图片'.format(train_len))

    # Train with GAN + VGG16-22 feature loss.
    n_step_epoch = round(train_len // config.batch_size_adv)
    for epoch in range(config.n_epoch_adv):
        epoch_start = time.time()
        # Per-epoch loss sums and batch count.
        mse_ls = vgg_ls = gan_ls = d_ls = 0.0
        n_batches = 0
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            step_time = time.time()
            # Persistent: the same forward pass feeds both the G and D updates.
            with tf.GradientTape(persistent=True) as tape:
                fake_patchs = G(lr_patchs)
                # NOTE(review): the original comment says the pre-trained VGG
                # expects inputs in [0, 1]; confirm the patch range matches.
                feature22_fake = vgg22(fake_patchs)
                feature22_real = vgg22(hr_patchs)
                logits_fake = D(fake_patchs)
                logits_real = D(hr_patchs)

                d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
                d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
                d_loss = d_loss1 + d_loss2

                g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
                mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
                vgg_loss = 1e-4 * tl.cost.mean_squared_error(feature22_fake, feature22_real, is_mean=True)
                g_loss = mse_loss + vgg_loss + g_gan_loss

                # WGAN-GP variant (unfinished, not active):
                # d_loss = tf.reduce_mean(logits_fake) - tf.reduce_mean(logits_real)
                # g_loss = g_vgg_loss + g_gan_loss
                # eps = tf.random.uniform([batch_size, 1, 1, 1], minval=0., maxval=1.)
                # interpolates = eps*hr_patchs + (1. - eps)*fake_patchs
                # grad = tape.gradient(D(interpolates), interpolates)
                # slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1,2,3]))
                # gradient_penalty = 0.1*tf.reduce_mean((slopes-1.)**2)
                # d_loss += gradient_penalty
            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
            grad = tape.gradient(d_loss, D.trainable_weights)
            d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
            del tape

            # Accumulate as Python floats, outside the tape.
            mse_ls += float(mse_loss)
            vgg_ls += float(vgg_loss)
            gan_ls += float(g_gan_loss)
            d_ls += float(d_loss)
            n_batches += 1
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.5f}, vgg:{:.5f}, adv:{:.5f}), d_loss: {:.5f}".format(
                        epoch, config.n_epoch_adv, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss))

        def _mean(total):
            # Epoch mean; NaN when the dataset produced no batches.
            return total / n_batches if n_batches else float('nan')

        print('~~~~~~~~~~~~Epoch {}平均损失~~~~~~~~~~~~~~~~'.format(epoch))
        # Report the whole epoch's wall time (previously only the last step's).
        print("Epoch: [{}/{}] time: {:.3f}s, g_loss(mse:{:.5f}, vgg:{:.5f}, adv:{:.5f}), d_loss: {:.5f}".format(
                        epoch, config.n_epoch_adv, time.time() - epoch_start, _mean(mse_ls), _mean(vgg_ls), _mean(gan_ls), _mean(d_ls)))
        G.save_weights(os.path.join(config.path_model, 'g_adv.h5'))
        print('\n')
def train():
    """Train SRGAN: MSE pre-training of G, then adversarial training of G and D.

    Loads every ``.png`` under ``config.TRAIN.hr_img_path`` into memory and
    builds batches of (LR, HR) pairs. Each HR patch is a random 384x384 crop
    scaled to [-1, 1] and randomly mirrored. Its LR partner is that crop
    resized to 96x96.

    1. G is trained alone on pixel MSE for ``n_epoch_init`` epochs.
    2. G (MSE + 2e-6 * VGG-feature MSE + 1e-3 * adversarial loss) and D
       (sigmoid cross-entropy on real vs. generated patches) are trained for
       ``n_epoch`` epochs. The shared learning rate becomes
       ``lr_init * lr_decay ** (epoch // decay_every)``.

    Every 10 epochs a grid of generated patches is written under ``samples/``.
    In stage 2, G/D weights are also written under ``checkpoint/``.

    Relies on module-level settings: ``config``, ``batch_size``, ``lr_init``,
    ``beta1``, ``n_epoch_init``, ``n_epoch``, ``lr_decay``, ``decay_every``,
    ``shuffle_buffer_size``, and ``ni`` (the side length of the saved image
    grid).
    """
    # create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    # load dataset
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)

    # dataset API and augmentation
    def generator_train():
        for img in train_hr_imgs:
            yield img

    def _map_fn_train(img):
        hr_patch = tf.image.random_crop(img, [384, 384, 3])
        # scale to [-1, 1] (assumes 8-bit images in [0, 255])
        hr_patch = hr_patch / (255. / 2.)
        hr_patch = hr_patch - 1.
        hr_patch = tf.image.random_flip_left_right(hr_patch)
        lr_patch = tf.image.resize(hr_patch, size=[96, 96])
        return lr_patch, hr_patch

    train_ds = tf.data.Dataset.from_generator(generator_train,
                                              output_types=(tf.float32))
    train_ds = train_ds.map(_map_fn_train,
                            num_parallel_calls=multiprocessing.cpu_count())
    train_ds = train_ds.repeat(n_epoch_init + n_epoch)
    train_ds = train_ds.shuffle(shuffle_buffer_size)
    train_ds = train_ds.prefetch(buffer_size=4096)
    train_ds = train_ds.batch(batch_size)

    # obtain models
    G = get_G((batch_size, None, None, 3))  # (None, 96, 96, 3)
    D = get_D((batch_size, None, None, 3))  # (None, 384, 384, 3)
    VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

    print(G)
    print(D)
    print(VGG)

    # G.load_weights(checkpoint_dir + '/g_{}.h5'.format(tl.global_flag['mode'])) # in case you want to restore a training
    # D.load_weights(checkpoint_dir + '/d_{}.h5'.format(tl.global_flag['mode']))

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

    G.train()
    D.train()
    VGG.train()

    # Steps per epoch come from the number of training images. Clamp to at
    # least 1 so `step // n_step_epoch` below can never divide by zero.
    n_step_epoch = max(1, len(train_hr_imgs) // batch_size)

    # initialize learning (G)
    # Every `for` over a tf.data.Dataset starts a fresh iterator, so each
    # stage takes exactly its own number of batches instead of draining the
    # whole repeated dataset.
    for step, (lr_patchs, hr_patchs) in enumerate(
            train_ds.take(n_epoch_init * n_step_epoch)):
        step_time = time.time()
        with tf.GradientTape() as tape:
            fake_hr_patchs = G(lr_patchs)
            mse_loss = tl.cost.mean_squared_error(fake_hr_patchs,
                                                  hr_patchs,
                                                  is_mean=True)
        # Only trainable weights get gradients. This matches the list passed
        # to the optimizer.
        grad = tape.gradient(mse_loss, G.trainable_weights)
        g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
        step += 1
        epoch = step // n_step_epoch
        print("Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {} ".format(
            epoch, n_epoch_init, step, n_step_epoch,
            time.time() - step_time, mse_loss))
        # Save once, when every 10th epoch has just been completed.
        if step % n_step_epoch == 0 and epoch % 10 == 0:
            tl.vis.save_images(
                fake_hr_patchs.numpy(), [ni, ni],
                save_dir_ginit + '/train_g_init_{}.png'.format(epoch))

    # adversarial learning (G, D)
    for step, (lr_patchs, hr_patchs) in enumerate(
            train_ds.take(n_epoch * n_step_epoch)):
        step_time = time.time()
        # persistent=True: the tape is queried twice (for G's and D's gradients)
        with tf.GradientTape(persistent=True) as tape:
            fake_patchs = G(lr_patchs)
            logits_fake = D(fake_patchs)
            logits_real = D(hr_patchs)
            # map [-1, 1] back to [0, 1] before feeding the pre-trained VGG
            feature_fake = VGG((fake_patchs + 1) / 2.)
            feature_real = VGG((hr_patchs + 1) / 2.)
            d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                                    tf.ones_like(logits_real))
            d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                                    tf.zeros_like(logits_fake))
            d_loss = d_loss1 + d_loss2
            g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
                logits_fake, tf.ones_like(logits_fake))
            mse_loss = tl.cost.mean_squared_error(fake_patchs,
                                                  hr_patchs,
                                                  is_mean=True)
            vgg_loss = 2e-6 * tl.cost.mean_squared_error(
                feature_fake, feature_real, is_mean=True)
            g_loss = mse_loss + vgg_loss + g_gan_loss
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        # Gradients must be computed w.r.t. the same variable list they are
        # applied to, otherwise zip() pairs them with the wrong variables.
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        del tape  # release the persistent tape's resources
        step += 1
        epoch = step // n_step_epoch
        print(
            "Epoch: [{}/{}] step: [{}/{}] time: {}s, g_loss(mse:{}, vgg:{}, adv:{}) d_loss: {}"
            .format(epoch, n_epoch, step, n_step_epoch,
                    time.time() - step_time, mse_loss, vgg_loss, g_gan_loss,
                    d_loss))

        # Per-epoch bookkeeping runs once, on the last step of each epoch.
        if step % n_step_epoch != 0:
            continue

        # update learning rate
        if epoch % decay_every == 0:
            new_lr_decay = lr_decay**(epoch // decay_every)
            lr_v.assign(lr_init * new_lr_decay)
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

        if epoch % 10 == 0:
            tl.vis.save_images(fake_patchs.numpy(), [ni, ni],
                               save_dir_gan + '/train_g_{}.png'.format(epoch))
            G.save_weights(checkpoint_dir +
                           '/g_{}.h5'.format(tl.global_flag['mode']))
            D.save_weights(checkpoint_dir +
                           '/d_{}.h5'.format(tl.global_flag['mode']))
Ejemplo n.º 14
0
def train():
	"""Train SRGAN on ``get_train_data()``: MSE warm-up of G, then G/D adversarial training.

	Stage 1 trains the generator alone on pixel MSE for ``n_epoch_init``
	epochs. Stage 2 trains G on MSE + 2e-6 * VGG-feature MSE + 1e-3 *
	adversarial loss, and trains D on real/fake sigmoid cross-entropy, for
	``n_epoch`` epochs. The learning rate is decayed every ``decay_every``
	epochs. A short final batch (fewer than ``batch_size`` images) ends each
	epoch.

	Every 10 epochs, sample images are written under ``save_dir``: the last
	full training batch plus G's output on a fixed test batch and on the
	sample images. In stage 2, weights are also written to ``checkpoint_dir``.

	Relies on module-level settings: ``batch_size``, ``lr_init``, ``beta1``,
	``n_epoch_init``, ``n_epoch``, ``lr_decay``, ``decay_every``,
	``iteration_size``, ``save_dir``, and ``checkpoint_dir``.
	NOTE(review): the [2, 4] image grids hold at most 8 images; confirm this
	matches batch_size.
	"""
	G = get_G((batch_size, 96, 96, 3))
	D = get_D((batch_size, 384, 384, 3))
	VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

	lr_v = tf.Variable(lr_init)
	g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
	g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
	d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

	G.train()
	D.train()
	VGG.train()

	train_ds, test_ds, sample_ds = get_train_data()

	sample_folders = ['train_lr', 'train_hr', 'train_gen', 'test_lr', 'test_hr', 'test_gen', 'sample_lr', 'sample_gen']
	for sample_folder in sample_folders:
		tl.files.exists_or_mkdir(os.path.join(save_dir, sample_folder))

	# Fixed images used to track progress: the first test batch and every
	# sample image. Their inputs are saved once here; later only G's outputs
	# on them are saved.
	test_lr_patchs, test_hr_patchs = next(iter(test_ds))
	valid_lr_imgs = []
	for i, sample_lr in enumerate(sample_ds):
		# add a leading batch axis so each image can be fed to G on its own
		valid_lr_img = np.asarray(sample_lr.numpy(), dtype=np.float32)[np.newaxis, :, :, :]
		valid_lr_imgs.append(valid_lr_img)
		tl.vis.save_images(valid_lr_img, [1, 1], os.path.join(save_dir, 'sample_lr', 'sample_lr_img_{}.jpg'.format(i)))

	tl.vis.save_images(test_lr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'test_lr', 'test_lr.jpg'))
	tl.vis.save_images(test_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'test_hr', 'test_hr.jpg'))

	n_step_epoch = round(iteration_size // batch_size)

	# initialize learning (G)
	for epoch in range(n_epoch_init):
		# Last *full* batch G trained on in this epoch. The loop variables are
		# not used for saving: after the `break` below they hold the short
		# final batch, which never went through G.
		last_lr = last_hr = last_gen = None
		for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
			if lr_patchs.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
				break
			step_time = time.time()
			with tf.GradientTape() as tape:
				fake_hr_patchs = G(lr_patchs)
				mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
			grad = tape.gradient(mse_loss, G.trainable_weights)
			g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
			last_lr, last_hr, last_gen = lr_patchs, hr_patchs, fake_hr_patchs
			print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
				epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss))
		if (epoch != 0) and (epoch % 10 == 0):
			if last_gen is not None:
				# save training result examples (input, target and output of the same batch)
				tl.vis.save_images(last_lr.numpy(), [2, 4], os.path.join(save_dir, 'train_lr', 'train_lr_init_{}.jpg'.format(epoch)))
				tl.vis.save_images(last_hr.numpy(), [2, 4], os.path.join(save_dir, 'train_hr', 'train_hr_init_{}.jpg'.format(epoch)))
				tl.vis.save_images(last_gen.numpy(), [2, 4], os.path.join(save_dir, 'train_gen', 'train_gen_init_{}.jpg'.format(epoch)))
			# save test results (inputs were saved before training)
			test_gen = G(test_lr_patchs)
			tl.vis.save_images(test_gen.numpy(), [2, 4], os.path.join(save_dir, 'test_gen', 'test_gen_init_{}.jpg'.format(epoch)))
			# save sample results (inputs were saved before training)
			for i, sample_lr in enumerate(valid_lr_imgs):
				sample_gen = G(sample_lr)
				tl.vis.save_images(sample_gen.numpy(), [1, 1], os.path.join(save_dir, 'sample_gen', 'sample_gen_init_{}_img_{}.jpg'.format(epoch, i)))

	## adversarial learning (G, D)
	for epoch in range(n_epoch):
		last_lr = last_hr = last_gen = None
		for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
			if lr_patchs.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
				break
			step_time = time.time()
			# persistent=True: the tape is queried twice (for G and for D)
			with tf.GradientTape(persistent=True) as tape:
				fake_patchs = G(lr_patchs)
				logits_fake = D(fake_patchs)
				logits_real = D(hr_patchs)
				feature_fake = VGG((fake_patchs + 1) / 2.)  # the pre-trained VGG uses the input range of [0, 1]
				feature_real = VGG((hr_patchs + 1) / 2.)
				d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
				d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
				d_loss = d_loss1 + d_loss2
				g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
				mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
				vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)
				g_loss = mse_loss + vgg_loss + g_gan_loss
			grad = tape.gradient(g_loss, G.trainable_weights)
			g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
			grad = tape.gradient(d_loss, D.trainable_weights)
			d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
			del tape  # release the persistent tape's resources right away
			last_lr, last_hr, last_gen = lr_patchs, hr_patchs, fake_patchs
			print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format(
				epoch, n_epoch, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss))

		# update the learning rate
		if epoch != 0 and (epoch % decay_every == 0):
			new_lr_decay = lr_decay**(epoch // decay_every)
			lr_v.assign(lr_init * new_lr_decay)
			log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
			print(log)

		if (epoch != 0) and (epoch % 10 == 0):
			if last_gen is not None:
				# save training result examples (input, target and output of the same batch)
				tl.vis.save_images(last_lr.numpy(), [2, 4], os.path.join(save_dir, 'train_lr', 'train_lr_{}.jpg'.format(epoch)))
				tl.vis.save_images(last_hr.numpy(), [2, 4], os.path.join(save_dir, 'train_hr', 'train_hr_{}.jpg'.format(epoch)))
				tl.vis.save_images(last_gen.numpy(), [2, 4], os.path.join(save_dir, 'train_gen', 'train_gen_{}.jpg'.format(epoch)))
			# save test results (inputs were saved before training)
			test_gen = G(test_lr_patchs)
			tl.vis.save_images(test_gen.numpy(), [2, 4], os.path.join(save_dir, 'test_gen', 'test_gen_{}.jpg'.format(epoch)))

			# keep a per-epoch snapshot and overwrite the "latest" weights
			G.save_weights(os.path.join(checkpoint_dir, f'g_epoch_{epoch}.h5'))
			D.save_weights(os.path.join(checkpoint_dir, f'd_epoch_{epoch}.h5'))

			G.save_weights(os.path.join(checkpoint_dir, 'g.h5'))
			D.save_weights(os.path.join(checkpoint_dir, 'd.h5'))
Ejemplo n.º 15
0
def train(n_train_epoch=120):
    """Train the generator G to restore 16-bit images from 4-bit quantized ones.

    Loads 16-bit PNGs from ``config.TRAIN.hr_img_path`` with OpenCV (BGR channel
    order), skipping unreadable and all-zero images. Each training pair is built
    as follows:

    - ``hr``: a random, augmented 96x96 crop.
    - ``lr``: the same crop reduced to its top 4 bits
      (``floor(x / 4096) * 4096``).

    Both are rescaled from [0, 65535] to [-1, 1].

    Stage 1 is a single warm-up pass over the data minimizing pixel MSE.
    Stage 2 runs ``n_train_epoch`` epochs minimizing
    ``vgg_loss + 5e-6 * range_loss``, where ``range_loss`` weights the
    difference ``G(lr) - lr`` by ``in_weight`` when it lies in
    ``[0, range_step]``, and its absolute value by ``out_weight`` otherwise.
    The learning rate is decayed every ``decay_every`` epochs, and G's weights
    are saved to ``checkpoint/g_<mode>.h5`` after every 5th epoch.

    Args:
        n_train_epoch: number of stage-2 epochs (default 120).
    """
    # create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    # load dataset
    train_hr_img_list = [
        f for f in glob.glob(os.path.join(config.TRAIN.hr_img_path, "*"))
        if os.path.splitext(f)[-1] == '.png'
    ]

    # Pre-load the whole training set. OpenCV is used because it keeps the
    # 16-bit depth (and cv2.imshow can display it); it returns BGR channel order.
    train_hr_imgs = []
    for f in train_hr_img_list:
        img = cv2.imread(f, cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None for unreadable files; all-zero images are skipped
        if img is None or np.max(img) == 0:
            continue
        train_hr_imgs.append(img)

    # dataset API and augmentation
    def generator_train():
        for img in train_hr_imgs:
            yield img

    def _map_fn_train(img):
        hr_patch = tf.image.random_crop(img, [96, 96, 3])
        # Randomly flip the image horizontally and vertically.
        hr_patch = tf.image.random_flip_left_right(hr_patch)
        hr_patch = tf.image.random_flip_up_down(hr_patch)
        # Randomly adjust hue, contrast, brightness and saturation.
        # NOTE(review): tf.image colour ops treat float images as RGB in [0, 1],
        # but these patches are BGR in [0, 65535]; confirm the intended
        # augmentation strength.
        hr_patch = tf.image.random_hue(hr_patch, max_delta=0.05)
        hr_patch = tf.image.random_contrast(hr_patch, lower=0.3, upper=1.0)
        hr_patch = tf.image.random_brightness(hr_patch, max_delta=0.2)
        hr_patch = tf.image.random_saturation(hr_patch, lower=0.0, upper=2.0)
        # Rotate by 90, 180 or 270 degrees. The draw must be a TF op: a NumPy
        # draw runs once when `map` traces this function, fixing k forever.
        rot_k = tf.random.uniform([], minval=1, maxval=4, dtype=tf.int32)
        hr_patch = tf.image.rot90(hr_patch, rot_k)

        # Keep only the top 4 bits of each 16-bit value. Use float division:
        # integer division would give all zeros.
        lr_patch = tf.floor(hr_patch / 4096.0) * 4096.0
        # make the pixel values lie in [-1, 1]
        hr_patch = hr_patch / (65535. / 2.) - 1.
        lr_patch = lr_patch / (65535. / 2.) - 1.
        return lr_patch, hr_patch

    # Built once: each `for` over the dataset re-runs generator_train and
    # reshuffles, and the random ops in _map_fn_train draw new values per element.
    train_ds = tf.data.Dataset.from_generator(generator_train,
                                              output_types=(tf.float32))
    train_ds = train_ds.map(_map_fn_train,
                            num_parallel_calls=multiprocessing.cpu_count())
    train_ds = train_ds.shuffle(shuffle_buffer_size)
    train_ds = train_ds.prefetch(buffer_size=4096)
    train_ds = train_ds.batch(batch_size)

    # obtain models
    G = get_G((batch_size, None, None, 3))  # (None, 96, 96, 3)
    print('load VGG')
    VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

    # G.load_weights(checkpoint_dir + '/g_{}.h5'.format(tl.global_flag['mode'])) # in case you want to restore a training

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

    G.train()
    VGG.train()

    # batches per pass over the data (the last batch may be partial)
    n_step_epoch = (len(train_hr_imgs) + batch_size - 1) // batch_size

    # Stage 1: one MSE-only warm-up pass over the training set.
    for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
        step_time = time.time()
        with tf.GradientTape() as tape:
            fake_hr_patchs = G(lr_patchs)
            mse_loss = tl.cost.mean_squared_error(fake_hr_patchs,
                                                  hr_patchs,
                                                  is_mean=True)
        grad = tape.gradient(mse_loss, G.trainable_weights)
        g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
        step += 1
        epoch = step // n_step_epoch
        print("Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {} ".format(
            epoch, 1, step, n_step_epoch,
            time.time() - step_time, mse_loss))

    # Stage 2: VGG-feature + range loss.
    for epoch in range(n_train_epoch):
        # Decay before the epoch's first step, so the whole epoch uses the new rate.
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            lr_v.assign(lr_init * new_lr_decay)
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            step_time = time.time()
            with tf.GradientTape() as tape:
                fake_hr_patchs = G(lr_patchs)
                # MSE is only logged; it is not part of g_loss
                mse_loss = tl.cost.mean_squared_error(fake_hr_patchs,
                                                      hr_patchs,
                                                      is_mean=True)
                # map [-1, 1] back to [0, 1] before feeding the pre-trained VGG
                feature_fake = VGG((fake_hr_patchs + 1) / 2.)
                feature_real = VGG((hr_patchs + 1) / 2.)
                vgg_loss = 0.6 * tl.cost.mean_squared_error(
                    feature_fake, feature_real, is_mean=True)
                # Range loss: a difference inside [0, range_step] costs
                # diff * in_weight; outside that range it costs |diff| * out_weight.
                diff = fake_hr_patchs - lr_patchs
                outside = tf.greater(diff, range_step) | tf.less(diff, 0)
                range_loss = tf.reduce_sum(
                    tf.where(outside, tf.abs(diff) * out_weight,
                             diff * in_weight))
                g_loss = 0.000005 * range_loss + vgg_loss
            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))

            if step % 100 == 0:
                print(
                    "Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {}  vgg_loss: {}"
                    .format(epoch, n_train_epoch, step, n_step_epoch,
                            time.time() - step_time, mse_loss, vgg_loss))

        # save once at the end of every 5th epoch
        if (epoch + 1) % 5 == 0:
            G.save_weights(checkpoint_dir +
                           '/g_{}.h5'.format(tl.global_flag['mode']))
Ejemplo n.º 16
0
def evaluate(n_images=8):
    """Run the trained generator on validation images and save 16-bit results.

    Loads the ``.png`` files under ``config.VALID.hr_img_path``, sorted by path
    so the selection is reproducible, and skips unreadable files. For each of
    the first ``n_images`` images it:

    1. Resizes the HR image to width 500, height 218.
    2. Builds the LR input by keeping only the top 4 bits of each value
       (``floor(x / 4096) * 4096``).
    3. Feeds the LR input, rescaled to [-1, 1], through G loaded from
       ``checkpoint/g_srgan.h5``.
    4. Prints the PSNR of the generated image and of the quantized LR image
       against the HR image.
    5. Writes ``valid_gen_<i>.png``, ``valid_lr_<i>.png`` and
       ``valid_hr_<i>.png`` as uncompressed PNGs under ``samples/<mode>``.

    Args:
        n_images: maximum number of validation images to process (default 8).
            Fewer are processed if fewer are available.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###

    valid_hr_img_list = sorted(
        f for f in glob.glob(os.path.join(config.VALID.hr_img_path, "*"))
        if os.path.splitext(f)[-1] == '.png')
    valid_hr_imgs = []
    for f in valid_hr_img_list:
        # IMREAD_UNCHANGED keeps the 16-bit depth; None means the file is unreadable
        img = cv2.imread(f, cv2.IMREAD_UNCHANGED)
        if img is None:
            print('skipping unreadable image: {}'.format(f))
            continue
        print('max_value:{}'.format(np.max(img)))
        valid_hr_imgs.append(img)

    ###========================== DEFINE MODEL ============================###
    G = get_G([1, None, None, 3])
    G.load_weights(checkpoint_dir + '/g_{}.h5'.format('srgan'))
    G.eval()

    png_params = [int(cv2.IMWRITE_PNG_COMPRESSION), 0]  # no compression
    for imid, valid_hr_img in enumerate(valid_hr_imgs[:n_images]):
        valid_hr_img = cv2.resize(valid_hr_img, (500, 218),
                                  interpolation=cv2.INTER_CUBIC)
        # keep only the top 4 bits of each 16-bit value
        valid_lr_img = tf.floor(valid_hr_img / 4096.0) * 4096.0
        valid_lr_img = (valid_lr_img / 32767.5) - 1  # rescale to [-1, 1]
        valid_lr_img = tf.cast(valid_lr_img, dtype=tf.float32)
        valid_lr_img_input = tf.expand_dims(valid_lr_img, 0)  # add batch axis

        out = G(valid_lr_img_input).numpy()
        print(np.min(out), np.max(out))

        # tl.vis.save_image does not support 16-bit output, so write with OpenCV.
        # Clip before casting: float values outside [0, 65535] do not
        # saturate when cast to uint16.
        out = np.clip((out[0] + 1) * 32767.5, 0, 65535).astype(np.uint16)
        valid_lr_img = tf.cast((valid_lr_img + 1) * 32767.5,
                               dtype=tf.uint16).numpy()
        valid_hr_img = tf.cast(valid_hr_img, dtype=tf.uint16).numpy()

        # NOTE(review): the third psnr argument (2) looks like the peak of the
        # [-1, 1] range, but these arrays are on the 0..65535 scale; confirm
        # against psnr's definition.
        psnr_gen = psnr(np.float32(out), np.float32(valid_hr_img), 2)
        psnr_zp = psnr(np.float32(valid_lr_img), np.float32(valid_hr_img), 2)
        print('psnr_gen:{}   psnr_zp:{}'.format(psnr_gen, psnr_zp))
        cv2.imwrite(save_dir + '/valid_gen_{}.png'.format(str(imid)), out,
                    png_params)
        cv2.imwrite(save_dir + '/valid_lr_{}.png'.format(str(imid)),
                    valid_lr_img, png_params)
        cv2.imwrite(save_dir + '/valid_hr_{}.png'.format(str(imid)),
                    valid_hr_img, png_params)
Ejemplo n.º 17
0
def train(end):
    G = get_G((batch_size, 28, 28, 3))
    D = get_D_Conditional((batch_size, 224, 224, 3))
    VGG = tl.models.VGG16(pretrained=True, end_with=end)

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

    G.train()
    D.train()
    VGG.train()
    train_ds, num_images, valid_ds, num_valid_images, maxes, mins = get_train_data(
    )

    ## initialize learning (G)
    print("Initialize Generator")

    mseloss = []
    PSNR = []
    SSIM = []
    mseloss_valid = []
    PSNR_valid = []
    SSIM_valid = []

    n_step_epoch = round(num_images // batch_size)
    for epoch in range(n_epoch_init):
        mse_avg = 0
        PSNR_avg = 0
        SSIM_avg = 0
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):

            if lr_patchs.shape[
                    0] != batch_size:  # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            with tf.GradientTape() as tape:
                fake_hr_patchs = G(lr_patchs)
                mse_loss = tl.cost.mean_squared_error(fake_hr_patchs,
                                                      hr_patchs,
                                                      is_mean=True)
                mse_avg += mse_loss

                psnr = np.mean(
                    np.array(
                        tf.image.psnr(fake_hr_patchs, hr_patchs, max_val=1)))
                PSNR_avg += psnr

                ssim = np.mean(
                    np.array(
                        tf.image.ssim(fake_hr_patchs, hr_patchs, max_val=1)))
                SSIM_avg += ssim

            grad = tape.gradient(mse_loss, G.trainable_weights)
            g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))

            print(
                "Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f}, psnr: {:.3f}, ssim: {:.3f} "
                .format(epoch + 1, n_epoch_init, step + 1, n_step_epoch,
                        time.time() - step_time, mse_loss, psnr, ssim))
        if (epoch != 0) and (epoch % 10 == 0):
            #restore to original values before normalization
            '''fake_hr_patchs = fake_hr_patchs.numpy()
            hr_patchs = hr_patchs.numpy()
            fake_hr_patchs = np.add(np.multiply(fake_hr_patchs, np.subtract(maxes, mins)), mins);
            hr_patchs = np.add(np.multiply(hr_patchs, np.subtract(maxes, mins)), mins)'''

            save = np.concatenate([fake_hr_patchs, hr_patchs], axis=0)
            tl.vis.save_images(
                save, [4, 4],
                os.path.join(save_dir, 'train_g_init{}.png'.format(epoch)))

        mseloss.append(mse_avg / (step + 1))
        PSNR.append(PSNR_avg / (step + 1))
        SSIM.append(SSIM_avg / (step + 1))

        #validate
        mse_valid_loss = 0
        psnr_valid = 0
        ssim_valid = 0
        for step, (lr_patchs, hr_patchs) in enumerate(valid_ds):
            fake_hr_patchs = G(lr_patchs)
            mse_valid_loss += tl.cost.mean_squared_error(fake_hr_patchs,
                                                         hr_patchs,
                                                         is_mean=True)
            psnr_valid += np.mean(
                np.array(tf.image.psnr(fake_hr_patchs, hr_patchs, max_val=1)))
            ssim_valid += np.mean(
                np.array(tf.image.ssim(fake_hr_patchs, hr_patchs, max_val=1)))

        if (epoch != 0) and (epoch % 10 == 0):
            #restore to original values before normalization
            '''fake_hr_patchs = fake_hr_patchs.numpy()
          hr_patchs = hr_patchs.numpy()
          fake_hr_patchs = np.add(np.multiply(fake_hr_patchs, np.subtract(maxes, mins)), mins);
          hr_patchs = np.add(np.multiply(hr_patchs, np.subtract(maxes, mins)), mins);'''
            save = np.concatenate([fake_hr_patchs, hr_patchs], axis=0)
            tl.vis.save_images(
                save, [4, 4],
                os.path.join(save_dir, 'valid_g_init{}.png'.format(epoch)))

        mse_valid_loss /= (step + 1)
        mseloss_valid.append(mse_valid_loss)
        psnr_valid /= (step + 1)
        PSNR_valid.append(psnr_valid)
        ssim_valid /= (step + 1)
        SSIM_valid.append(ssim_valid)

        print("Validation MSE: ", mse_valid_loss.numpy(), "Validation PSNR: ",
              psnr_valid, "Validation SSIM: ", ssim_valid)
    '''plot stuff'''
    '''plt.figure()
    epochs = np.linspace(1, n_epoch_init*n_step_epoch, num = n_epoch_init*n_step_epoch);
    plt.title("Generator Initialization")
    plt.xlabel("Epoch")
    plt.ylabel("Metric")
    plt.plot(epochs, np.array(mseloss))
    plt.plot(epochs, np.array(PSNR))
    #plt.plot(epochs, SSIM)
    plt.legend(("MSE", " PSNR"))
    plt.show()'''
    np.save(save_dir + '/mse_init_train.npy', mseloss)
    np.save(save_dir + '/psnr_init_train.npy', PSNR)
    np.save(save_dir + '/ssim_init_train.npy', SSIM)
    np.save(save_dir + '/mse_init_valid.npy', mseloss_valid)
    np.save(save_dir + '/psnr_init_valid.npy', PSNR_valid)
    np.save(save_dir + '/ssim_init_valid.npy', SSIM_valid)

    ## adversarial learning (G, D)
    print("Adversarial Learning")

    mseloss = []
    mseloss_valid = []
    PSNR = []
    PSNR_valid = []
    SSIM = []
    SSIM_valid = []
    vggloss = []
    vggloss_valid = []
    advloss = []
    dloss = []
    min_val_mse = 9999
    max_val_psnr = 0
    max_val_ssim = 0
    n_step_epoch = round(num_images // batch_size)
    for epoch in range(n_epoch):
        mse_avg = 0
        PSNR_avg = 0
        vgg_avg = 0
        SSIM_avg = 0
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[
                    0] != batch_size:  # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            with tf.GradientTape(persistent=True) as tape:
                fake_patchs = G(lr_patchs)
                logits_fake = D([lr_patchs, fake_patchs])
                logits_real = D([lr_patchs, hr_patchs])

                d_acc = (
                    np.count_nonzero(tf.nn.sigmoid(logits_fake) < 0.5) +
                    np.count_nonzero(tf.nn.sigmoid(logits_real) > 0.5)) / 16
                feature_fake = VGG(
                    fake_patchs
                )  # the pre-trained VGG uses the input range of [0, 1]
                feature_real = VGG(hr_patchs)
                d_loss1 = tl.cost.sigmoid_cross_entropy(
                    logits_real, tf.ones_like(logits_real))
                d_loss2 = tl.cost.sigmoid_cross_entropy(
                    logits_fake, tf.zeros_like(logits_fake))
                d_loss = d_loss1 + d_loss2
                g_gan_loss = tl.cost.sigmoid_cross_entropy(
                    logits_fake, tf.ones_like(logits_fake))
                mse_loss = tl.cost.mean_squared_error(fake_patchs,
                                                      hr_patchs,
                                                      is_mean=True)
                vgg_loss = 6e-5 * tl.cost.mean_squared_error(
                    feature_fake, feature_real, is_mean=True)
                vgg_avg += vgg_loss
                g_loss = mse_loss + vgg_loss + g_gan_loss

                mse_avg += mse_loss
                advloss.append(g_gan_loss)
                dloss.append(d_loss)
                psnr = np.mean(
                    np.array(tf.image.psnr(fake_patchs, hr_patchs, max_val=1)))
                PSNR_avg += psnr
                ssim = np.mean(
                    np.array(tf.image.ssim(fake_patchs, hr_patchs, max_val=1)))
                SSIM_avg += ssim
            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
            grad = tape.gradient(d_loss, D.trainable_weights)
            d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
            print(
                "Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}, d_acc: {:.3f}, psnr: {:.3f}, ssim: {:.3f}"
                .format(epoch + 1, n_epoch, step + 1, n_step_epoch,
                        time.time() - step_time, mse_loss, vgg_loss,
                        g_gan_loss, d_loss, d_acc, psnr, ssim))

        if (epoch != 0) and (epoch % 10 == 0):
            #restore to original values before normalization
            '''fake_patchs = fake_patchs.numpy()
            hr_patchs = hr_patchs.numpy()
            fake_patchs = np.add(np.multiply(fake_patchs, np.subtract(maxes - mins)), mins);
            hr_patchs = np.add(np.multiply(hr_patchs, np.subtract(maxes - mins)), mins);'''
            save = np.concatenate([fake_patchs, hr_patchs], axis=0)
            tl.vis.save_images(
                save, [4, 4],
                os.path.join(save_dir, 'train_g{}.png'.format(epoch)))

        mseloss.append(mse_avg / (step + 1))
        PSNR.append(PSNR_avg / (step + 1))
        vggloss.append(vgg_avg / (step + 1))
        SSIM.append(SSIM_avg / (step + 1))

        #validate
        mse_valid_loss = 0
        vgg_valid_loss = 0
        psnr_valid = 0
        ssim_valid = 0
        d_acc = 0

        for step, (lr_patchs, hr_patchs) in enumerate(valid_ds):
            if lr_patchs.shape[
                    0] != batch_size:  # if the remaining data in this epoch < batch_size
                break
            fake_patchs = G(lr_patchs)
            logits_fake = D([lr_patchs, fake_patchs])
            logits_real = D([lr_patchs, hr_patchs])
            d_acc += ((np.count_nonzero(tf.nn.sigmoid(logits_fake) < 0.5) +
                       np.count_nonzero(tf.nn.sigmoid(logits_real) > 0.5)))
            feature_fake = VGG(
                fake_patchs
            )  # the pre-trained VGG uses the input range of [0, 1]
            feature_real = VGG(hr_patchs)
            d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                                    tf.ones_like(logits_real))
            d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                                    tf.zeros_like(logits_fake))
            d_loss = d_loss1 + d_loss2
            g_gan_loss = tl.cost.sigmoid_cross_entropy(
                logits_fake, tf.ones_like(logits_fake))
            mse_valid_loss += tl.cost.mean_squared_error(fake_patchs,
                                                         hr_patchs,
                                                         is_mean=True)
            vgg_valid_loss += 6e-5 * tl.cost.mean_squared_error(
                feature_fake, feature_real, is_mean=True)

            g_loss = mse_loss + vgg_valid_loss + g_gan_loss
            psnr_valid += np.mean(
                np.array(tf.image.psnr(fake_patchs, hr_patchs, max_val=1)))
            ssim_valid += np.mean(
                np.array(tf.image.ssim(fake_patchs, hr_patchs, max_val=1)))

        mse_valid_loss /= (step + 1)
        mseloss_valid.append(mse_valid_loss)
        vgg_valid_loss /= (step + 1)
        vggloss_valid.append(vgg_valid_loss)
        psnr_valid /= (step + 1)
        PSNR_valid.append(psnr_valid)
        ssim_valid /= (step + 1)
        SSIM_valid.append(ssim_valid)

        d_acc /= (num_valid_images * 2)
        print("Validation MSE: ", mse_valid_loss.numpy(), "Validation PSNR: ",
              psnr_valid, "Validation SSIM", ssim_valid,
              "Validation Disc Accuracy: ", d_acc)

        if (epoch != 0) and (epoch % 10 == 0):
            #restore to original values before normalization
            '''fake_patchs = fake_patchs.numpy()
            hr_patchs = hr_patchs.numpy()
            fake_patchs = np.add(np.multiply(fake_patchs, np.subtract(maxes, mins)), mins);
            hr_patchs = np.add(np.multiply(hr_patchs, np.subtract(maxes, mins)), mins);'''
            save = np.concatenate([fake_patchs, hr_patchs], axis=0)
            tl.vis.save_images(
                save, [4, 4],
                os.path.join(save_dir, 'valid_g{}.png'.format(epoch)))

        #save models if metrics improve
        if (mse_valid_loss <= min_val_mse):
            print("Val loss improved from ", np.array(min_val_mse), " to ",
                  np.array(mse_valid_loss))
            min_val_mse = mse_valid_loss
            G.save_weights(os.path.join(checkpoint_dir, "g_mse_val.h5"))
            D.save_weights(os.path.join(checkpoint_dir, "d_mse_val.h5"))
        else:
            print("Val loss did not improve from ", np.array(min_val_mse))

        if (psnr_valid >= max_val_psnr):
            print("Val PSNR improved from ", max_val_psnr, " to ", psnr_valid)
            max_val_psnr = psnr_valid
            G.save_weights(os.path.join(checkpoint_dir, "g_psnr_val.h5"))
            D.save_weights(os.path.join(checkpoint_dir, "d_psnr_val.h5"))
        else:
            print("Val PSNR did not improve from ", max_val_psnr)

        if (ssim_valid >= max_val_ssim):
            print("Val SSIM improved from ", max_val_ssim, " to ", ssim_valid)
            max_val_ssim = ssim_valid
            G.save_weights(os.path.join(checkpoint_dir, "g_ssim_val.h5"))
            D.save_weights(os.path.join(checkpoint_dir, "d_ssim_val.h5"))
        else:
            print("Val SSIM did not improve from ", max_val_ssim)

        # update the learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            lr_v.assign(lr_init * new_lr_decay)
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

    np.save(save_dir + '/mse_train.npy', mseloss)
    np.save(save_dir + '/psnr_train.npy', PSNR)
    np.save(save_dir + '/ssim_train.npy', SSIM)
    np.save(save_dir + '/advloss_train.npy', advloss)
    np.save(save_dir + '/discloss_train.npy', dloss)
    np.save(save_dir + '/vgg_train.npy', vggloss)
    np.save(save_dir + '/mse_valid.npy', mseloss_valid)
    np.save(save_dir + '/psnr_valid.npy', PSNR_valid)
    np.save(save_dir + '/ssim_valid.npy', SSIM_valid)
    np.save(save_dir + '/vgg_valid.npy', vggloss_valid)

    return checkpoint_dir, save_dir
Ejemplo n.º 18
0
def evaluate(timestamp, model):
    """Compare a trained generator against bicubic upsampling on held-out samples.

    Loads samples [200:500] of the 1984 channel-2/6/8 ``.npy`` arrays, stacks
    them along axis 2 and min-max normalizes each of the first three channels
    to [0, 1] using extrema computed over the loaded samples. Every normalized
    high-resolution image is resized to 96x144 and compared with:

    * the generator output ``G(lr)``;
    * a bicubic (``cv2.INTER_CUBIC``) upsampling of ``lr``.

    Both are scored against the ground truth using PSNR, SSIM (``max_val=1``)
    and the MSE between VGG16 ``conv2_1`` features. For every sample a
    side-by-side image (bicubic | generated | ground truth) is written to
    ``<timestamp>/samples/triplet<step>.png``. The averaged metrics are printed
    in the order: generator PSNR, SSIM, VGG; bicubic PSNR, SSIM, VGG.

    Args:
        timestamp: Run directory; weights are read from ``<timestamp>/models/``
            and images are written to ``<timestamp>/samples/``.
        model: Checkpoint suffix; the generator is loaded from
            ``<timestamp>/models/g_<model>.h5``.

    Raises:
        ValueError: if no samples are found, or if the three channel
            directories yield different numbers of files in the evaluated slice.
    """
    ###====================== PRE-LOAD DATA ===========================###
    # NOTE(review): absolute Colab / Google-Drive paths; only 1984 data is used.
    train_path_1984_2 = "/content/drive/My Drive/ProjectX 2020/Data/1984/2/"
    train_path_1984_6 = "/content/drive/My Drive/ProjectX 2020/Data/1984/6/"
    train_path_1984_8 = "/content/drive/My Drive/ProjectX 2020/Data/1984/8/"
    # The same slice must be applied to every channel so that index i refers
    # to the same sample in all three directories.
    sample_slice = slice(200, 500)

    def _sorted_npy(path):
        # Sort explicitly instead of relying on the listing order so that the
        # three channel lists line up index-by-index.
        # Assumes the three directories use matching, sortable file names
        # for the same sample -- TODO confirm.
        return sorted(
            tl.files.load_file_list(path, regx='.*.npy', printable=False))

    print("Getting File Paths")
    names_2 = _sorted_npy(train_path_1984_2)[sample_slice]
    names_6 = _sorted_npy(train_path_1984_6)[sample_slice]
    names_8 = _sorted_npy(train_path_1984_8)[sample_slice]
    if not (len(names_2) == len(names_6) == len(names_8)):
        raise ValueError(
            "Channel directories have different numbers of samples: "
            f"{len(names_2)}, {len(names_6)}, {len(names_8)}")
    if not names_2:
        # Without this, the averages at the end would divide by zero.
        raise ValueError("No .npy samples found for evaluation")

    print("Loading Images")
    hr_imgs = []
    min0 = min1 = min2 = 999999999
    max0 = max1 = max2 = -999999999

    print("1984")
    for name_2, name_6, name_8 in tqdm(zip(names_2, names_6, names_8),
                                       total=len(names_2)):
        # Stack the per-channel arrays along axis 2 into one image.
        im = np.concatenate([
            np.load(os.path.join(train_path_1984_2, name_2)),
            np.load(os.path.join(train_path_1984_6, name_6)),
            np.load(os.path.join(train_path_1984_8, name_8)),
        ],
                            axis=2)
        # Running per-channel extrema used for min-max normalization below.
        min0 = min(min0, np.min(im[:, :, 0]))
        min1 = min(min1, np.min(im[:, :, 1]))
        min2 = min(min2, np.min(im[:, :, 2]))
        max0 = max(max0, np.max(im[:, :, 0]))
        max1 = max(max1, np.max(im[:, :, 1]))
        max2 = max(max2, np.max(im[:, :, 2]))
        hr_imgs.append(im)

    maxes = np.array([max0, max1, max2])
    mins = np.array([min0, min1, min2])
    print(maxes)
    print(mins)

    valid_imgs = hr_imgs
    v = len(valid_imgs)
    print("Number of Validation Images: ", v)

    def generator_valid():
        yield from valid_imgs

    def _map_fn_eval(img):
        # Min-max normalize each channel to [0, 1].
        hr_patch = tf.divide(tf.subtract(img, mins),
                             tf.cast(tf.subtract(maxes, mins), tf.float32))
        # Random flips from the training pipeline are deliberately omitted:
        # evaluation must be deterministic and reproducible.
        lr_patch = tf.image.resize(hr_patch, size=[96, 144])
        return lr_patch, hr_patch

    valid_ds = tf.data.Dataset.from_generator(generator_valid,
                                              output_types=(tf.float32))
    valid_ds = valid_ds.map(_map_fn_eval,
                            num_parallel_calls=multiprocessing.cpu_count())
    valid_ds = valid_ds.batch(1)
    valid_ds = valid_ds.prefetch(buffer_size=2)

    ###========================== DEFINE MODEL ============================###
    G = get_G([1, 96, 144, 3])
    G.load_weights(os.path.join(timestamp, "models/g_" + model + ".h5"))
    G.eval()

    VGG = tl.models.VGG16(pretrained=True, end_with='conv2_1')
    VGG.eval()

    pred_psnr = pred_ssim = pred_vgg = 0
    bl_psnr = bl_ssim = bl_vgg = 0

    for step, (lr_patchs, hr_patchs) in enumerate(valid_ds):
        print(step, " out of ", v)
        h, w, c = hr_patchs.shape[1], hr_patchs.shape[2], hr_patchs.shape[3]
        valid_sr_img = G(lr_patchs)
        # cv2.resize takes the target size as (width, height).
        valid_bicu_img = cv2.resize(np.float32(lr_patchs[0]), (w, h),
                                    interpolation=cv2.INTER_CUBIC)
        # Give the bicubic baseline a batch axis so it has the same rank as
        # hr_patchs; the VGG convolutions require a 4-D input.
        bicu_batch = tf.convert_to_tensor(valid_bicu_img[np.newaxis])
        feat_hr = VGG(hr_patchs)

        pred_psnr += np.mean(
            np.array(tf.image.psnr(valid_sr_img, hr_patchs, max_val=1)))
        pred_ssim += np.mean(
            np.array(tf.image.ssim(valid_sr_img, hr_patchs, max_val=1)))
        pred_vgg += tl.cost.mean_squared_error(feat_hr,
                                               VGG(valid_sr_img),
                                               is_mean=True)

        bl_psnr += np.mean(
            np.array(tf.image.psnr(bicu_batch, hr_patchs, max_val=1)))
        bl_ssim += np.mean(
            np.array(tf.image.ssim(bicu_batch, hr_patchs, max_val=1)))
        bl_vgg += tl.cost.mean_squared_error(feat_hr,
                                             VGG(bicu_batch),
                                             is_mean=True)

        # One-pixel black column between panels, sized to the actual image.
        separator = np.zeros((h, 1, c))
        triplet = np.concatenate([
            valid_bicu_img, separator, valid_sr_img[0].numpy(), separator,
            hr_patchs[0].numpy()
        ],
                                 axis=1)
        tl.vis.save_image(
            triplet,
            os.path.join(timestamp, 'samples/triplet' + str(step) + '.png'))
    print(pred_psnr / v)
    print(pred_ssim / v)
    print(pred_vgg / v)
    print(bl_psnr / v)
    print(bl_ssim / v)
    print(bl_vgg / v)
Ejemplo n.º 19
0
    parser.add_argument('--model',
                        type=str,
                        default='g110.h5',
                        help='Model name')
    parser.add_argument('--data',
                        type=str,
                        default='Set14',
                        help='Test dataset')
    parser.add_argument('--loc',
                        type=str,
                        default="resized",
                        help='Location on data')
    args = parser.parse_args()

    if (args.model == 'g_srgan.npz'):
        G = get_G([1, None, None, 3])
        G.load_weights(os.path.join(checkpoint_dir, 'g_srgan.npz'))
    else:
        G = get_espcn([1, None, None, 3])
        G.load_weights(os.path.join(checkpoint_dir, args.model))
    G.eval()
    total_time = 0
    if (args.type == 'multi'):
        path = os.path.join(output_dir,
                            args.data + '_' + args.model.split('.')[0])
        os.mkdir(path)
        data_HR = glob.glob(os.path.join('test', args.data + '/*'))
        c = len(args.data) + 6
        for i in data_HR:
            fileName = i[c::]
            outputName = fileName.split('.')[0] + '.png'
Ejemplo n.º 20
0
def train():
    """Train SRGAN in two phases.

    Phase 1 runs ``n_epoch_init`` epochs that optimize the generator alone
    with a pixel-wise MSE loss. Phase 2 runs ``n_epoch`` epochs that alternate
    generator and discriminator updates. The generator loss is
    ``mse + 2e-6 * vgg_feature_mse + 1e-3 * adversarial``, where the VGG19
    (``pool4``) features are computed on inputs rescaled from [-1, 1] to
    [0, 1].

    All three optimizers share the learning-rate variable ``lr_v``. During the
    adversarial phase it is decayed every ``decay_every`` epochs. Every 10
    epochs (excluding epoch 0) sample images are written to ``save_dir``.
    During the adversarial phase, ``g.h5`` / ``d.h5`` are also written to
    ``checkpoint_dir`` at that interval.
    """
    G = get_G((batch_size, 96, 96, 3))
    D = get_D((batch_size, 384, 384, 3))
    VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

    G.train()
    D.train()
    VGG.train()

    train_ds = get_train_data()

    def _steps_per_epoch(ds):
        # Number of batches per epoch, for progress logging only. tf.data can
        # only report it when it is statically known (not, e.g., for
        # generator-backed pipelines); show '?' in that case instead of a
        # made-up number.
        n = int(tf.data.experimental.cardinality(ds))
        return n if n >= 0 else '?'

    n_step_epoch = _steps_per_epoch(train_ds)

    ## initialize learning (G)
    fake_hr_patchs = None  # stays None if no full batch was ever produced
    for epoch in range(n_epoch_init):
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            with tf.GradientTape() as tape:
                fake_hr_patchs = G(lr_patchs)
                mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
            grad = tape.gradient(mse_loss, G.trainable_weights)
            g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
                epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss))
        if (epoch != 0) and (epoch % 10 == 0) and fake_hr_patchs is not None:
            tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_init_{}.png'.format(epoch)))

    ## adversarial learning (G, D)
    fake_patchs = None  # stays None if no full batch was ever produced
    for epoch in range(n_epoch):
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            # persistent=True: the same tape yields both G and D gradients.
            with tf.GradientTape(persistent=True) as tape:
                fake_patchs = G(lr_patchs)
                logits_fake = D(fake_patchs)
                logits_real = D(hr_patchs)
                feature_fake = VGG((fake_patchs+1)/2.) # the pre-trained VGG uses the input range of [0, 1]
                feature_real = VGG((hr_patchs+1)/2.)
                d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
                d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
                d_loss = d_loss1 + d_loss2
                g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
                mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
                vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)
                g_loss = mse_loss + vgg_loss + g_gan_loss
            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
            grad = tape.gradient(d_loss, D.trainable_weights)
            d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
            # A persistent tape keeps its resources until it is dropped.
            del tape
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format(
                epoch, n_epoch, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss))

        # update the learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            lr_v.assign(lr_init * new_lr_decay)
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            if fake_patchs is not None:
                tl.vis.save_images(fake_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_{}.png'.format(epoch)))
            G.save_weights(os.path.join(checkpoint_dir, 'g.h5'))
            D.save_weights(os.path.join(checkpoint_dir, 'd.h5'))
Ejemplo n.º 21
0
def train():
    """Train an InfoGAN model on CelebA images.

    The generator ``G`` maps noise ``z`` to images. A shared feature
    extractor ``Base`` feeds both the discriminator head ``D`` and the
    auxiliary head ``Q``.

    Each step uses two losses:

    * Generator loss: adversarial loss plus the mutual-information term
      ``calc_mutual(Q(fake), c)``. It updates ``G`` and ``Q``.
    * Discriminator loss: real/fake sigmoid cross-entropy. It updates ``D``
      and ``Base``.

    ``z`` and ``c`` come from ``gen_noise()``; presumably ``c`` is embedded in
    ``z`` (TODO confirm).

    After every epoch this function:

    * plots the recorded losses to ``<result_dir>/loss.jpg``;
    * checkpoints all four networks to ``checkpoint_dir``;
    * writes one sample grid per categorical code to ``result_dir``.
    """
    images, images_path = get_celebA(flags.output_size, flags.n_epoch,
                                     flags.batch_size)
    G = get_G([None, flags.dim_z])
    Base = get_base(
        [None, flags.output_size, flags.output_size, flags.n_channel])
    D = get_D([None, 4096])
    Q = get_Q([None, 4096])

    G.train()
    Base.train()
    D.train()
    Q.train()

    g_optimizer = tf.optimizers.Adam(learning_rate=flags.G_learning_rate,
                                     beta_1=flags.beta_1)
    d_optimizer = tf.optimizers.Adam(learning_rate=flags.D_learning_rate,
                                     beta_1=flags.beta_1)

    n_step_epoch = int(len(images_path) // flags.batch_size)
    his_g_loss = []
    his_d_loss = []
    his_mutual = []
    count = 0  # number of training steps actually executed

    for epoch in range(flags.n_epoch):
        for step, batch_images in enumerate(images):
            if batch_images.shape[0] != flags.batch_size:
                break
            # Count only batches that are really trained on.
            count += 1
            step_time = time.time()
            with tf.GradientTape(persistent=True) as tape:
                z, c = gen_noise()
                fake = Base(G(z))
                fake_logits = D(fake)
                fake_cat = Q(fake)
                real_logits = D(Base(batch_images))

                d_loss_fake = tl.cost.sigmoid_cross_entropy(
                    output=fake_logits,
                    target=tf.zeros_like(fake_logits),
                    name='d_loss_fake')
                d_loss_real = tl.cost.sigmoid_cross_entropy(
                    output=real_logits,
                    target=tf.ones_like(real_logits),
                    name='d_loss_real')
                d_loss = d_loss_fake + d_loss_real

                g_loss = tl.cost.sigmoid_cross_entropy(
                    output=fake_logits,
                    target=tf.ones_like(fake_logits),
                    name='g_loss_fake')

                mutual = calc_mutual(fake_cat, c)
                g_loss += mutual

            grad = tape.gradient(g_loss,
                                 G.trainable_weights + Q.trainable_weights)
            g_optimizer.apply_gradients(
                zip(grad, G.trainable_weights + Q.trainable_weights))
            grad = tape.gradient(d_loss,
                                 D.trainable_weights + Base.trainable_weights)
            d_optimizer.apply_gradients(
                zip(grad, D.trainable_weights + Base.trainable_weights))
            del tape
            print(
                f"Epoch: [{epoch}/{flags.n_epoch}] [{step}/{n_step_epoch}] took: {time.time()-step_time:.3f}, d_loss: {d_loss:.5f}, g_loss: {g_loss:.5f}, mutual: {mutual:.5f}"
            )

            # Record the first step and then every `save_every_it`-th step.
            # `count % save_every_it == 1` would never fire for
            # save_every_it == 1.
            if (count - 1) % flags.save_every_it == 0:
                his_g_loss.append(g_loss)
                his_d_loss.append(d_loss)
                his_mutual.append(mutual)

        plt.plot(his_d_loss)
        plt.plot(his_g_loss)
        plt.plot(his_mutual)
        plt.legend(['D_Loss', 'G_Loss', 'Mutual_Info'])
        plt.xlabel(f'Iterations / {flags.save_every_it}')
        plt.ylabel('Loss')
        plt.savefig(f'{flags.result_dir}/loss.jpg')
        plt.clf()
        plt.close()

        G.save_weights(f'{flags.checkpoint_dir}/G.npz', format='npz')
        D.save_weights(f'{flags.checkpoint_dir}/D.npz', format='npz')
        # Base and Q are trained as well; D operates on Base's output and Q
        # is needed to recover codes, so both must be checkpointed to restore
        # the model.
        Base.save_weights(f'{flags.checkpoint_dir}/Base.npz', format='npz')
        Q.save_weights(f'{flags.checkpoint_dir}/Q.npz', format='npz')
        G.eval()
        for k in range(flags.n_categorical):
            z = gen_eval_noise(k, flags.n_sample)
            result = G(z)
            # Same output directory as the loss plot above.
            tl.visualize.save_images(convert(result.numpy()),
                                     [flags.n_sample, flags.dim_categorical],
                                     f'{flags.result_dir}/train_{epoch}_{k}.png')
        G.train()