def train():
    """Train the single-channel SRResNet generator with an MSE + VGG loss.

    Steps, in order:

    1. Create the sample and checkpoint folders.
    2. Read every training HR image into memory as a grayscale
       ``(H, W, 1)`` array.
    3. Build the generator, the two weight-sharing VGG19 towers, the
       losses, their scalar summaries and the Adam optimizer.
    4. Try to restore the generator from ``checkpoint/g_<mode>.npz`` and
       assign the VGG19 weights read from ``vgg19.npy``.
    5. Run ``n_epoch + 1`` epochs. Each epoch logs the mean losses under
       ``./log/train/generator``, saves the first generated sample image,
       and every third epoch (except epoch 0) saves the generator weights.

    The function takes no arguments and returns nothing. It relies on names
    that are not defined here and must exist at module level: ``config``,
    ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch``, ``lr_decay``,
    ``decay_every``, ``ni``, ``SRGAN_g``, ``Vgg19_simple_api``,
    ``crop_sub_imgs_fn`` and ``downsample_fn``.
    """
    ## create folders to save result images and trained model
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])  #srresnet
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    # NOTE: the three lists below are not read again in this function; the
    # calls are kept so the directory listing still happens as before.
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine has enough memory, please pre-load the whole train set.
    print("reading images")

    train_hr_imgs = []

    for img__ in train_hr_img_list:
        # mode='L' loads the file as a single-channel (grayscale) image.
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path,
                                                      img__),
                                         mode='L')
        # Add an explicit trailing channel axis: (H, W) -> (H, W, 1).
        image_loaded = image_loaded.reshape(
            (image_loaded.shape[0], image_loaded.shape[1], 1))
        train_hr_imgs.append(image_loaded)

    print(type(train_hr_imgs), len(train_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    ## train inference

    t_image = tf.placeholder('float32', [batch_size, 56, 56, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 224, 224, 1],
                                    name='t_target_image')

    print("t_image:", tf.shape(t_image))
    print("t_target_image:", tf.shape(t_target_image))

    net_g = SRGAN_g(t_image, is_train=True,
                    reuse=False)  #SRGAN_g is the SRResNet portion of the GAN

    print("net_g.outputs:", tf.shape(net_g.outputs))

    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)

    print("net_g.outputs:", tf.shape(net_g.outputs))
    print("t_predict_image_224:", tf.shape(t_predict_image_224))

    # NOTE(review): (x + 1) / 2 presumably maps [-1, 1] images to [0, 1]
    # for VGG -- confirm against crop_sub_imgs_fn / the generator output.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss

    # The training loop reads these summaries back using tags with
    # underscores instead of spaces (e.g. 'Generator_MSE_loss');
    # presumably TensorFlow rewrites the spaces in the names -- confirm.
    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## SRResNet
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir +
                                 '/g_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_g)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        # The message used to say "vgg19.npz" although the file that is
        # checked (and loaded below) is vgg19.npy.
        print(
            "Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    # Flatten the {layer_name: [W, b]} mapping into [W, b, W, b, ...] in
    # sorted layer-name order before assigning it to the VGG network.
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]

    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)

    sample_imgs_224 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_224.shape, sample_imgs_224.min(),
          sample_imgs_224.max())
    sample_imgs_56 = tl.prepro.threading_data(sample_imgs_224,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_56.shape, sample_imgs_56.min(),
          sample_imgs_56.max())
    tl.vis.save_images(sample_imgs_56, [ni, ni],
                       save_dir_gan + '/_train_sample_56.png')
    tl.vis.save_images(sample_imgs_224, [ni, ni],
                       save_dir_gan + '/_train_sample_224.png')

    ###========================= train SRResNet  =========================###

    merged_summary_generator = tf.summary.merge(
        [mse_loss_summary, vgg_loss_summary,
         g_loss_summary])  #g_gan_loss_summary
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")

    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=(lr_init * new_lr_decay)),
                ]), (epoch))

        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=lr_init),
                ]), (epoch))

        epoch_time = time.time()
        total_g_loss, n_iter = 0, 0

        ## The whole train set is pre-loaded in memory (train_hr_imgs).
        # Per-batch loss values, averaged into one scalar per epoch below.
        mse_loss_summary_per_epoch = []
        vgg_loss_summary_per_epoch = []
        g_loss_summary_per_epoch = []

        # NOTE(review): the placeholders have a fixed leading dimension of
        # batch_size, so a final shorter slice looks like it would be
        # rejected by the feed -- confirm the train set size is a multiple
        # of batch_size.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_224 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_56 = tl.prepro.threading_data(b_imgs_224, fn=downsample_fn)

            ## update G
            errG, errM, errV, _, generator_summary = sess.run(
                [
                    g_loss, mse_loss, vgg_loss, g_optim,
                    merged_summary_generator
                ], {
                    t_image: b_imgs_56,
                    t_target_image: b_imgs_224
                })  #g_ga_loss

            # Decode the serialized merged summary to get the scalar values.
            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(generator_summary)

            # Assuming all summaries are scalars.
            generator_summaries = {
                val.tag: val.simple_value for val in summary_pb.value
            }

            mse_loss_summary_per_epoch.append(
                generator_summaries['Generator_MSE_loss'])
            vgg_loss_summary_per_epoch.append(
                generator_summaries['Generator_VGG_loss'])
            g_loss_summary_per_epoch.append(
                generator_summaries['Generator_total_loss'])

            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errG, errM,
                   errV))
            total_g_loss += errG
            n_iter += 1

        # Mean wall-clock seconds per iteration. The subtraction must be
        # parenthesized: without it only the start timestamp was divided by
        # n_iter, which printed a meaningless value.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g_loss: %.8f" % (
            epoch, n_epoch, (time.time() - epoch_time) / n_iter,
            total_g_loss / n_iter)
        print(log)

        ## log the per-epoch mean of each generator loss
        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_MSE_loss per epoch",
                                 simple_value=np.mean(
                                     mse_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_VGG_loss per epoch",
                                 simple_value=np.mean(
                                     vgg_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_total_loss per epoch",
                                 simple_value=np.mean(
                                     g_loss_summary_per_epoch)),
            ]), (epoch))

        # Quick visual check: run the inference-mode generator on the fixed
        # sample batch and save only its first output image.
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_56})
        print("[*] save images")
        tl.vis.save_image(out[0], save_dir_gan + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 3 == 0):

            tl.files.save_npz(
                net_g.all_params,
                name=checkpoint_dir +
                '/g_{}_{}.npz'.format(tl.global_flag['mode'], epoch),
                sess=sess)
示例#2
0
def train():
    """Pre-train the SRGAN generator on MSE, then train it adversarially.

    Steps, in order:

    1. Create the sample and checkpoint folders and pre-load every
       training HR image.
    2. Build the generator, the discriminator (on real and generated
       images), the two weight-sharing VGG19 towers, the losses and three
       Adam optimizers that share one learning-rate variable.
    3. Restore the generator from ``checkpoint/g_<mode>.npz``, falling back
       to ``checkpoint/g_<mode>_init.npz``; restore the discriminator from
       ``checkpoint/d_<mode>.npz``; assign VGG19 weights from ``vgg19.npy``.
    4. Run ``n_epoch_init + 1`` generator-only epochs on the MSE loss.
    5. Run ``n_epoch + 1`` GAN epochs with a stepwise-decayed learning rate.

    Every tenth epoch (except epoch 0) of either phase saves a grid of
    generated samples and the current weights.

    The function takes no arguments and returns nothing. It relies on names
    that are not defined here and must exist at module level: ``config``,
    ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch_init``, ``n_epoch``,
    ``lr_decay``, ``decay_every``, ``ni``, ``SRGAN_g``, ``SRGAN_d``,
    ``Vgg19_simple_api``, ``crop_sub_imgs_fn`` and ``downsample_fn``.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    # NOTE: the three lists below are not read again in this function; the
    # calls are kept so the directory listing still happens as before.
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3],
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    # The discriminator is applied twice with shared weights: once to the
    # real HR target and once to the generator output.
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    # NOTE(review): (x + 1) / 2 presumably maps [-1, 1] images to [0, 1]
    # for VGG -- confirm against crop_sub_imgs_fn / the generator output.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator: real images labelled 1, generated images labelled 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator: adversarial term (generated images labelled 1) plus pixel
    # MSE plus VGG-feature MSE, each with its own fixed weight.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Prefer the GAN-stage generator checkpoint and fall back to the
    # pre-training one. TensorLayer documents a ``False`` return when the
    # file does not exist, so testing only ``is None`` never triggered the
    # fallback; both values are treated as "not loaded" here.
    restored_g = tl.files.load_and_assign_npz(
        sess=sess,
        name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
        network=net_g)
    if restored_g is None or restored_g is False:
        tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir +
            '/g_{}_init.npz'.format(tl.global_flag['mode']),
            network=net_g)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir +
                                 '/d_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        # The message used to say "vgg19.npz" although the file that is
        # checked (and loaded below) is vgg19.npy.
        print(
            "Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    # Flatten the {layer_name: [W, b]} mapping into [W, b, W, b, ...] in
    # sorted layer-name order before assigning it to the VGG network.
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    # The same reference grids are written to both sample folders.
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_gan + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## The whole train set is pre-loaded in memory (train_hr_imgs).
        # NOTE(review): the placeholders have a fixed leading dimension of
        # batch_size, so a final shorter slice looks like it would be
        # rejected by the feed -- confirm the train set size is a multiple
        # of batch_size.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_96
            })  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_ginit + '/train_%d.png' % epoch)

            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir +
                              '/g_{}_init.npz'.format(tl.global_flag['mode']),
                              sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## The whole train set is pre-loaded in memory (train_hr_imgs).
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG,
                   errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_96
            })  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_gan + '/train_%d.png' % epoch)

            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir +
                              '/g_{}.npz'.format(tl.global_flag['mode']),
                              sess=sess)
            tl.files.save_npz(net_d.all_params,
                              name=checkpoint_dir +
                              '/d_{}.npz'.format(tl.global_flag['mode']),
                              sess=sess)
示例#3
0
def train():
    """Train the SRGAN generator and discriminator with an MAE + GAN loss.

    Steps, in order:

    1. Create the sample and checkpoint folders and list the validation
       HR images.
    2. Build the generator, the discriminator (on real and generated
       images), the losses and two Adam optimizers that share one
       learning-rate variable.
    3. Try to restore both networks from ``g.npz`` / ``d.npz`` under
       ``checkpoint_path``.
    4. Run ``n_epoch_gan + 1`` epochs. Each epoch re-lists and shuffles the
       training files, pads the list to a multiple of ``batch_size``,
       alternates one discriminator and one generator update per batch,
       then saves a sample grid and both checkpoints.

    The function takes no arguments and returns nothing. It relies on names
    that are not defined here and must exist at module level:
    ``samples_path``, ``checkpoint_path``, ``valid_hr_img_path``,
    ``train_hr_img_path``, ``batch_size``, ``sample_batch_size``,
    ``learning_rate``, ``n_epoch_gan``, ``ni``, ``save_file_format``,
    ``SRGAN_g``, ``SRGAN_d``, ``get_imgs_fn``, ``crop_sub_imgs_fn``,
    ``crop_data_augment_fn``, ``downsample_fn``, ``rescale_m1p1``,
    ``save_images`` and ``load_deep_file_list``.
    """
    ## create folders to save result images and trained model
    save_dir_gan = samples_path + "gan"
    tl.files.exists_or_mkdir(save_dir_gan)
    tl.files.exists_or_mkdir(checkpoint_path)

    # Raw string: the pattern text is unchanged, but '\.' is no longer an
    # invalid escape sequence in a normal string literal.
    image_regx = r'.*\.(bmp|png|webp|jpg)'

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=valid_hr_img_path,
                                regx=image_regx,
                                printable=False))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    sample_t_image = tf.compat.v1.placeholder(
        'float32', [sample_batch_size, 96, 96, 3],
        name='sample_t_image_input_to_SRGAN_generator')
    t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3],
                                       name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.compat.v1.placeholder('float32',
                                              [batch_size, 384, 384, 3],
                                              name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    # The discriminator is applied twice with shared weights: once to the
    # real HR target and once to the generator output.
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## test inference
    # Uses its own placeholder so the sample batch size can differ from the
    # training batch size.
    net_g_test = SRGAN_g(sample_t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###

    # MAE Loss
    # tf.abs is element-wise, so it is applied to the whole batch directly
    # instead of being mapped over the batch axis with tf.map_fn.
    mae_loss = tf.reduce_mean(tf.abs(t_target_image - net_g.outputs))

    # GAN Loss
    # Each logit is compared against the mean logit of the opposite class,
    # with squared error towards +1 / -1; the generator loss uses the same
    # terms with the signs swapped.
    d_loss = 0.5 * (
        tf.reduce_mean(
            tf.square(logits_real - tf.reduce_mean(logits_fake) - 1)) +
        tf.reduce_mean(
            tf.square(logits_fake - tf.reduce_mean(logits_real) + 1)))
    g_gan_loss = 0.5 * (
        tf.reduce_mean(
            tf.square(logits_real - tf.reduce_mean(logits_fake) + 1)) +
        tf.reduce_mean(
            tf.square(logits_fake - tf.reduce_mean(logits_real) - 1)))

    g_loss = 1e-1 * g_gan_loss + mae_loss

    # Mean discriminator logits, fetched only for the progress print.
    d_real = tf.reduce_mean(logits_real)
    d_fake = tf.reduce_mean(logits_fake)

    with tf.variable_scope('learning_rate'):
        learning_rate_var = tf.Variable(learning_rate, trainable=False)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    ## SRGAN
    g_optim = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate_var).minimize(g_loss, var_list=g_vars)
    d_optim = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate_var).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    sess.run(tf.variables_initializer(tf.global_variables()))
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_path + 'g.npz',
                                 network=net_g)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_path + 'd.npz',
                                 network=net_d)

    ###============================= TRAINING ===============================###
    # Fixed sample batch taken from the validation images; it is reused
    # after every epoch for a visual check.
    sample_imgs = tl.prepro.threading_data(
        valid_hr_img_list[0:sample_batch_size],
        fn=get_imgs_fn,
        path=valid_hr_img_path)
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    save_images(sample_imgs_96, [ni, ni], save_file_format,
                save_dir_gan + '/_train_sample_96')
    save_images(sample_imgs_384, [ni, ni], save_file_format,
                save_dir_gan + '/_train_sample_384')

    ###========================= train GAN =========================###
    sess.run(tf.assign(learning_rate_var, learning_rate))
    for epoch in range(0, n_epoch_gan + 1):
        epoch_time = time.time()
        total_d_loss, total_g_loss_mae, total_g_loss_gan, n_iter = 0, 0, 0, 0

        # The training directory is re-listed every epoch.
        train_hr_img_list = load_deep_file_list(path=train_hr_img_path,
                                                regx=image_regx,
                                                recursive=True,
                                                printable=False)
        random.shuffle(train_hr_img_list)

        list_length = len(train_hr_img_list)
        print("Number of images: %d" % (list_length))

        # Pad the list to a multiple of batch_size by repeating entries
        # from its start, because the placeholders have a fixed batch
        # dimension. The loop cycles through the list again when the list
        # itself is shorter than the gap; a single slice could not fill it.
        remainder = list_length % batch_size
        if remainder != 0:
            missing = batch_size - remainder
            padding = []
            while len(padding) < missing:
                padding.extend(train_hr_img_list[:missing - len(padding)])
            train_hr_img_list += padding

        list_length = len(train_hr_img_list)
        print("Length of list: %d" % (list_length))

        for idx in range(0, list_length, batch_size):
            step_time = time.time()
            b_imgs_list = train_hr_img_list[idx:idx + batch_size]
            b_imgs = tl.prepro.threading_data(b_imgs_list,
                                              fn=get_imgs_fn,
                                              path=train_hr_img_path)
            b_imgs_384 = tl.prepro.threading_data(b_imgs,
                                                  fn=crop_data_augment_fn,
                                                  is_random=True)
            # The LR batch is derived from the crop before rescale_m1p1 is
            # applied to the HR batch, so the order of these two lines
            # matters.
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            b_imgs_384 = tl.prepro.threading_data(b_imgs_384, fn=rescale_m1p1)

            ## update D
            errD, d_r, d_f, _ = sess.run([d_loss, d_real, d_fake, d_optim], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            ## update G
            errM, errA, _, _ = sess.run(
                [mae_loss, g_gan_loss, g_loss, g_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            print(
                "Epoch[%2d/%2d] %4d time: %4.2fs d_loss: %.8f g_loss_mae: %.8f g_loss_gan: %.8f d_r: %.8f d_f: %.8f"
                % (epoch, n_epoch_gan, n_iter, time.time() - step_time, errD,
                   errM, errA, d_r, d_f))
            total_d_loss += errD
            total_g_loss_mae += errM
            total_g_loss_gan += errA
            n_iter += 1

        # Avoid ZeroDivisionError when no training file was found.
        n_batches = max(n_iter, 1)
        log = (
            "[*] Epoch[%2d/%2d] time: %4.2fs d_loss: %.8f g_loss_mae: %.8f g_loss_gan: %.8f"
            % (epoch, n_epoch_gan, time.time() - epoch_time, total_d_loss /
               n_batches, total_g_loss_mae / n_batches,
               total_g_loss_gan / n_batches))
        print(log)

        ## quick evaluation on the fixed sample batch
        out = sess.run(net_g_test.outputs, {sample_t_image: sample_imgs_96})
        print("[*] save images")
        save_images(out, [ni, ni], save_file_format,
                    save_dir_gan + '/train_%d' % epoch)

        ## save model
        tl.files.save_npz(net_g.all_params,
                          name=checkpoint_path + 'g.npz',
                          sess=sess)
        tl.files.save_npz(net_d.all_params,
                          name=checkpoint_path + 'd.npz',
                          sess=sess)
示例#4
0
文件: main.py 项目: TowardSun/srgan
def train():
    """Train the SRGAN generator and discriminator on the pre-loaded HR set.

    Two stages run back to back on the same TensorFlow session:

    1. Generator initialisation: ``n_epoch_init + 1`` epochs that minimise
       only the MSE between the generator output and the HR target.
    2. Adversarial training: ``n_epoch + 1`` epochs with one discriminator
       update followed by one generator update per batch; the generator loss
       is MSE + VGG feature loss + adversarial loss.

    Sample grids are written to ``samples/<mode>_ginit`` and
    ``samples/<mode>_gan`` and parameters to ``checkpoint/`` every 10 epochs.
    Hyper-parameters (``batch_size``, ``lr_init``, ``beta1``, ``n_epoch_init``,
    ``n_epoch``, ``lr_decay``, ``decay_every``, ``ni``) and ``config`` are read
    from module scope.

    Raises:
        ValueError: if fewer than ``batch_size`` training images are found.
    """
    mode = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)
    g_init_npz = checkpoint_dir + '/g_{}_init.npz'.format(mode)
    g_npz = checkpoint_dir + '/g_{}.npz'.format(mode)
    d_npz = checkpoint_dir + '/d_{}.npz'.format(mode)

    ###====================== PRE-LOAD DATA ===========================###
    # Only the HR training images are used; the LR inputs are derived from the
    # HR crops with `downsample_fn`. The dot is escaped and the pattern
    # anchored so that only file names ending in ".png" are selected.
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx=r'.*\.png$', printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)

    # The placeholders below have a fixed batch dimension, so a trailing
    # partial batch cannot be fed: iterate over whole batches only.
    n_train = len(train_hr_imgs) - len(train_hr_imgs) % batch_size
    if n_train == 0:
        raise ValueError("need at least %d training images in %s, found %d" %
                         (batch_size, config.TRAIN.hr_img_path, len(train_hr_imgs)))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    # Both the target and the generated image are resized to the 224x224 VGG input.
    t_target_image_224 = tf.image.resize_images(t_target_image, size=[224, 224], method=0, align_corners=False)
    t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)

    # (x + 1) / 2 maps a [-1, 1] range to [0, 1] before VGG
    # (assumes the images are scaled to [-1, 1] by the crop function -- TODO confirm).
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Prefer the GAN-stage generator weights; fall back to the pre-trained ones.
    if tl.files.load_and_assign_npz(sess=sess, name=g_npz, network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess, name=g_init_npz, network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=d_npz, network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        exit()
    # The file stores a pickled dict, which NumPy >= 1.16.3 refuses to load
    # unless allow_pickle=True. NOTE: unpickling executes code from the file,
    # so only use a vgg19.npy obtained from a trusted source.
    npz = np.load(vgg19_npy_path, allow_pickle=True, encoding='latin1').item()

    params = []
    for layer_name, layer_params in sorted(npz.items()):
        W = np.asarray(layer_params[0])
        b = np.asarray(layer_params[1])
        print("  Loading %s: %s, %s" % (layer_name, W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    for save_dir in (save_dir_ginit, save_dir_gan):
        tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir + '/_train_sample_96.png')
        tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        # NOTE: if the training set does not fit in memory, load each batch
        # from `train_hr_img_list` here instead of slicing `train_hr_imgs`.
        for idx in range(0, n_train, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set + save model, every 10th epoch
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=g_init_npz, sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        # NOTE: same remark as above for training sets that do not fit in memory.
        for idx in range(0, n_train, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            feed_dict = {t_image: b_imgs_96, t_target_image: b_imgs_384}
            ## update D
            errD, _ = sess.run([d_loss, d_optim], feed_dict)
            ## update G
            errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed_dict)
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
                  (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
                                                                                total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set + save model, every 10th epoch
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=g_npz, sess=sess)
            tl.files.save_npz(net_d.all_params, name=d_npz, sess=sess)
示例#5
0
def _scalar_summary_values(serialized_summary):
    """Return ``{tag: simple_value}`` parsed from a serialized summary string."""
    summary_pb = tf.summary.Summary()
    summary_pb.ParseFromString(serialized_summary)
    # Assuming all summaries are scalars.
    return {val.tag: val.simple_value for val in summary_pb.value}


def _write_epoch_mean(writer, tag, values, epoch):
    """Write the mean of ``values`` as one scalar event ``tag`` at step ``epoch``."""
    writer.add_summary(
        tf.Summary(value=[
            tf.Summary.Value(tag=tag, simple_value=np.mean(values)),
        ]), epoch)


def train():
    """Resume adversarial SRGAN training on single-channel images.

    Builds the generator / discriminator / VGG graph, restores every variable
    from ``checkpoint/main_10.ckpt`` and trains for ``n_epoch`` further epochs,
    numbered from 11. Each batch performs one discriminator update followed by
    one generator update. Per epoch, the batch losses are averaged and written
    as scalar events under ``./log/train/``; every 10th epoch a checkpoint
    ``checkpoint/main_<epoch>.ckpt`` is saved and the generator output for a
    fixed sample batch is written to ``samples/<mode>_gan``.

    The generator MSE pre-training stage is disabled in this variant (kept
    below as an inert string literal).

    Hyper-parameters (``batch_size``, ``lr_init``, ``beta1``, ``n_epoch``,
    ``lr_decay``, ``decay_every``) and ``config`` are read from module scope.

    Raises:
        ValueError: if fewer than ``batch_size`` training images are found.
    """
    mode = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    # Only the HR training images are used; the LR inputs are derived from the
    # HR crops with `downsample_fn_mod`. The dot is escaped and the pattern
    # anchored so that only file names ending in ".png" are selected.
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx=r'.*\.png$',
                                printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    print("reading images")
    train_hr_imgs = []
    for img_name in train_hr_img_list:
        # NOTE: scipy.misc.imread was removed in SciPy 1.2; this call needs an
        # older SciPy (or a replacement reader) to run.
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path,
                                                      img_name),
                                         mode='L')
        # Add an explicit channel axis: (H, W) -> (H, W, 1).
        train_hr_imgs.append(
            image_loaded.reshape(
                (image_loaded.shape[0], image_loaded.shape[1], 1)))

    print(type(train_hr_imgs), len(train_hr_img_list))

    # The placeholders below have a fixed batch dimension, so a trailing
    # partial batch cannot be fed: iterate over whole batches only.
    n_train = len(train_hr_imgs) - len(train_hr_imgs) % batch_size
    if n_train == 0:
        raise ValueError(
            "need at least %d training images in %s, found %d" %
            (batch_size, config.TRAIN.hr_img_path, len(train_hr_imgs)))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 28, 224, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 224, 224, 1],
                                    name='t_target_image')

    # get_shape() prints the static shape; tf.shape() would only print a
    # symbolic op and add a node to the graph.
    print("t_image:", t_image.get_shape())
    print("t_target_image:", t_target_image.get_shape())

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    print("net_g.outputs:", net_g.outputs.get_shape())

    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(t_target_image,
                                                size=[224, 224],
                                                method=0,
                                                align_corners=False)
    t_predict_image_224 = tf.image.resize_images(net_g.outputs,
                                                 size=[224, 224],
                                                 method=0,
                                                 align_corners=False)

    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)

    print("net_g.outputs:", net_g.outputs.get_shape())
    print("t_predict_image_224:", t_predict_image_224.get_shape())

    # (x + 1) / 2 maps a [-1, 1] range to [0, 1] before VGG
    # (assumes the images are scaled to [-1, 1] by the crop function -- TODO confirm).
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    # The per-batch lookups further down use the underscore form of these
    # names (e.g. 'Disciminator_total_loss'), so the names are kept verbatim.
    d_loss1_summary = tf.summary.scalar('Disciminator logits_real loss',
                                        d_loss1)
    d_loss2_summary = tf.summary.scalar('Disciminator logits_fake loss',
                                        d_loss2)
    d_loss_summary = tf.summary.scalar('Disciminator total loss', d_loss)

    g_gan_loss_summary = tf.summary.scalar('Generator GAN loss', g_gan_loss)
    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    #	UNCOMMENT THE LINE BELOW!!!
    #g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    # The file stores a pickled dict, which NumPy >= 1.16.3 refuses to load
    # unless allow_pickle=True. NOTE: unpickling executes code from the file,
    # so only use a vgg19.npy obtained from a trusted source.
    npz = np.load(vgg19_npy_path, allow_pickle=True, encoding='latin1').item()

    params = []
    for layer_name, layer_params in sorted(npz.items()):
        W = np.asarray(layer_params[0])
        b = np.asarray(layer_params[1])
        print("  Loading %s: %s, %s" % (layer_name, W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]

    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)

    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn_mod)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())

    # Stage 1 (generator MSE pre-training) is disabled in this variant and is
    # kept as an inert string literal. To re-enable it, remove the quotes and
    # build `g_optim_init` above.
    '''
    ###========================= initialize G ====================###

    merged_summary_initial_G = tf.summary.merge([mse_loss_summary])
    summary_intial_G_writer = tf.summary.FileWriter("./log/train/initial_G")

    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    count = 0
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        intial_MSE_G_summary_per_epoch = []

        for idx in range(0, n_train, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn_mod)
            ## update G
            errM, _, mse_summary_initial_G = sess.run([mse_loss, g_optim_init, merged_summary_initial_G], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))

            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(mse_summary_initial_G)

            intial_G_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                intial_G_summaries[val.tag] = val.simple_value

            intial_MSE_G_summary_per_epoch.append(intial_G_summaries['Generator_MSE_loss'])

            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        summary_intial_G_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag="Generator_Initial_MSE_loss per epoch", simple_value=np.mean(intial_MSE_G_summary_per_epoch)),]), (epoch))

        ## quick evaluation on train set
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
        print("[*] save images")
        for im in range(len(out)):
            if(im%4==0 or im==1197):
                tl.vis.save_image(out[im], save_dir_ginit + '/train_%d_%d.png' % (epoch,im))

        ## save model
        saver=tf.train.Saver()
        if (epoch%10==0 and epoch!=0):
            saver.save(sess, 'checkpoint/init_'+str(epoch)+'.ckpt')
    '''
    ###========================= train GAN (SRGAN) =========================###
    saver = tf.train.Saver()
    saver.restore(sess, 'checkpoint/main_10.ckpt')
    print('Restored main_10, begin 11/50')

    merged_summary_discriminator = tf.summary.merge(
        [d_loss1_summary, d_loss2_summary, d_loss_summary])
    merged_summary_generator = tf.summary.merge([
        g_gan_loss_summary, mse_loss_summary, vgg_loss_summary, g_loss_summary
    ])

    # Tags as they appear in the serialized summaries, in logging order.
    d_tags = ('Disciminator_logits_real_loss', 'Disciminator_logits_fake_loss',
              'Disciminator_total_loss')
    g_tags = ('Generator_GAN_loss', 'Generator_MSE_loss',
              'Generator_VGG_loss', 'Generator_total_loss')

    summary_discriminator_writer = tf.summary.FileWriter(
        "./log/train/discriminator")
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")
    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    # Training continues after the restored epoch-10 checkpoint.
    start_epoch = 11
    # Computed once here: it is needed by the epoch log even though it used to
    # be assigned inside the batch loop.
    tot_epoch = n_epoch + 10

    try:
        for epoch in range(start_epoch, n_epoch + start_epoch):
            ## update learning rate
            # Epochs start at 11, so the former `epoch == 0` branch could
            # never run; at epoch 0 this formula would yield lr_init anyway.
            if epoch % decay_every == 0:
                new_lr_decay = lr_decay**(epoch // decay_every)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
                log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                               new_lr_decay)
                print(log)

                learning_rate_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag="Learning_rate per epoch",
                                         simple_value=(lr_init * new_lr_decay)),
                    ]), epoch)

            epoch_time = time.time()
            total_d_loss, total_g_loss, n_iter = 0, 0, 0

            # Per-batch values of every logged loss, averaged at epoch end.
            d_history = {tag: [] for tag in d_tags}
            g_history = {tag: [] for tag in g_tags}

            # NOTE: if the training set does not fit in memory, load each
            # batch from `train_hr_img_list` here instead.
            for idx in range(0, n_train, batch_size):
                step_time = time.time()
                b_imgs_384 = tl.prepro.threading_data(
                    train_hr_imgs[idx:idx + batch_size],
                    fn=crop_sub_imgs_fn,
                    is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384,
                                                     fn=downsample_fn_mod)
                feed_dict = {t_image: b_imgs_96, t_target_image: b_imgs_384}

                ## update D
                errD, _, discriminator_summary = sess.run(
                    [d_loss, d_optim, merged_summary_discriminator], feed_dict)
                discriminator_summaries = _scalar_summary_values(
                    discriminator_summary)
                for tag in d_tags:
                    d_history[tag].append(discriminator_summaries[tag])

                ## update G
                errG, errM, errV, errA, _, generator_summary = sess.run([
                    g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim,
                    merged_summary_generator
                ], feed_dict)
                generator_summaries = _scalar_summary_values(generator_summary)
                for tag in g_tags:
                    g_history[tag].append(generator_summaries[tag])

                print(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                    % (epoch, tot_epoch, n_iter, time.time() - step_time, errD,
                       errG, errM, errV, errA))
                total_d_loss += errD
                total_g_loss += errG
                n_iter += 1

            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
                epoch, tot_epoch, time.time() - epoch_time,
                total_d_loss / n_iter, total_g_loss / n_iter)
            print(log)

            # Log one value per loss and epoch: the mean over this epoch's batches.
            for tag in d_tags:
                _write_epoch_mean(summary_discriminator_writer,
                                  tag + " per epoch", d_history[tag], epoch)
            for tag in g_tags:
                _write_epoch_mean(summary_generator_writer, tag + " per epoch",
                                  g_history[tag], epoch)

            ## save model + quick evaluation on train set, every 10th epoch
            if epoch % 10 == 0 and epoch != 0:
                saver.save(sess, 'checkpoint/main_' + str(epoch) + '.ckpt')

                # Only run the generator on the sample batch when its output
                # is actually written out.
                out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
                print("[*] save images")
                for im, img in enumerate(out):
                    tl.vis.save_image(
                        img, save_dir_gan + '/train_%d_%d.png' % (epoch, im))
    finally:
        # Close the event writers so pending events are flushed to disk even
        # when training is interrupted.
        for writer in (summary_discriminator_writer, summary_generator_writer,
                       learning_rate_writer):
            writer.close()
示例#6
0
def train():
    """Train a single-channel 16x16 -> 48x48 SRGAN in two phases.

    Phase 1 pre-trains the generator with the MSE loss only for
    ``n_epoch_init`` epochs.  Phase 2 alternates discriminator and generator
    updates for ``n_epoch`` epochs, rescaling the learning rate by a power of
    ``lr_decay`` every ``decay_every`` epochs.  Every 10 epochs both phases
    save generator outputs for fixed train/validation samples and write
    ``.npz`` checkpoints.

    Relies on names defined outside this function: ``config``,
    ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch_init``, ``n_epoch``,
    ``lr_decay``, ``decay_every``, ``ni`` and the helpers ``read_csv_data``,
    ``SRGAN_g``, ``SRGAN_d``, ``crop_sub_imgs_fn``, ``downsample_fn`` and
    ``upsample_fn``.
    """
    ## create folders to save result images and trained model
    mode = tl.global_flag['mode']
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    save_dir_valid = "samples/{}_valid".format(mode)
    for folder in (save_dir_ginit, save_dir_gan, save_dir_valid):
        tl.files.exists_or_mkdir(folder)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    train_hr_imgs = read_csv_data(config.TRAIN.hr_img_path,
                                  width=48,
                                  height=48,
                                  channel=1)
    valid_hr_imgs = read_csv_data(config.VALID.hr_img_path,
                                  width=48,
                                  height=48,
                                  channel=1)

    ###========================== DEFINE MODEL ============================###
    ## train inference  ## t = train
    t_image = tf.placeholder('float32', [None, 16, 16, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [None, 48, 48, 1],
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False, nb_block=16)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    # d_loss: for discriminator (real images -> 1, generated images -> 0)
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # g_loss: for generator (pixel MSE plus a down-weighted adversarial term)
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    g_loss = mse_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    # Create the session here: nothing else in this function defines `sess`
    # or initialises the variables built above, and every sess.run() below
    # needs both.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    train_writer = None
    try:
        tl.layers.initialize_global_variables(sess)

        ###============================= TRAINING ===========================###
        ## fixed train / validation samples for a quick test during training
        sample_imgs = train_hr_imgs[0:9]
        valid_imgs = valid_hr_imgs[44:53]
        sample_imgs_48 = tl.prepro.threading_data(sample_imgs,
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=False)
        valid_imgs_48 = tl.prepro.threading_data(valid_imgs,
                                                 fn=crop_sub_imgs_fn,
                                                 is_random=False)
        print('sample HR sub-image:', sample_imgs_48.shape,
              sample_imgs_48.min(), sample_imgs_48.max())
        sample_imgs_16 = tl.prepro.threading_data(sample_imgs_48,
                                                  fn=downsample_fn,
                                                  down_rate=3)
        valid_imgs_16 = tl.prepro.threading_data(valid_imgs_48,
                                                 fn=downsample_fn,
                                                 down_rate=3)
        print('sample LR sub-image:', sample_imgs_16.shape,
              sample_imgs_16.min(), sample_imgs_16.max())
        tl.vis.save_images(sample_imgs_16, [ni, ni],
                           save_dir_ginit + '/_train_sample_16.png')
        tl.vis.save_images(sample_imgs_48, [ni, ni],
                           save_dir_ginit + '/_train_sample_48.png')
        tl.vis.save_images(sample_imgs_16, [ni, ni],
                           save_dir_gan + '/_train_sample_16.png')
        tl.vis.save_images(sample_imgs_48, [ni, ni],
                           save_dir_gan + '/_train_sample_48.png')
        tl.vis.save_images(valid_imgs_48, [ni, ni],
                           save_dir_valid + '/_valid_sample_48.png')
        tl.vis.save_images(valid_imgs_16, [ni, ni],
                           save_dir_valid + '/_valid_sample_16.png')
        sample_hr_imgs_bicubic = tl.prepro.threading_data(sample_imgs_16,
                                                          fn=upsample_fn,
                                                          up_rate=3)
        valid_hr_imgs_bicubic = tl.prepro.threading_data(valid_imgs_16,
                                                         fn=upsample_fn,
                                                         up_rate=3)
        tl.vis.save_images(sample_hr_imgs_bicubic, [ni, ni],
                           save_dir_ginit + '/_sample_bicubic_48.png')
        tl.vis.save_images(valid_hr_imgs_bicubic, [ni, ni],
                           save_dir_valid + '/_valid_sample_bicubic_48.png')

        def save_previews(epoch, train_dir, valid_pattern):
            """Save generator outputs for the fixed train/validation samples."""
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_16})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               train_dir + '/train_%d.png' % epoch)
            out = sess.run(net_g_test.outputs, {t_image: valid_imgs_16})
            tl.vis.save_images(out, [ni, ni],
                               save_dir_valid + valid_pattern % epoch)

        def make_batch(idx):
            """Return (HR, LR) arrays for the batch starting at ``idx``."""
            hr = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size],
                                          fn=crop_sub_imgs_fn,
                                          is_random=True)
            # The LR placeholder is 16x16 and the HR one 48x48, so the factor
            # is passed explicitly instead of relying on downsample_fn's
            # default.
            lr = tl.prepro.threading_data(hr, fn=downsample_fn, down_rate=3)
            return hr, lr

        ###========================= initialize G ====================###
        ## fixed learning rate
        sess.run(tf.assign(lr_v, lr_init))
        train_writer_path = "./log/train"
        tl.files.exists_or_mkdir(train_writer_path)
        # Only the graph is logged through this writer; it is closed in the
        # `finally` clause so the event file is flushed.
        train_writer = tf.summary.FileWriter(train_writer_path,
                                             graph=tf.get_default_graph())
        print(" ** fixed learning rate: %f (for init G)" % lr_init)
        for epoch in range(0, n_epoch_init + 1):
            epoch_time = time.time()
            total_mse_loss, n_iter = 0, 0

            # Full batches only: a trailing partial batch is skipped here.
            for idx in range(0,
                             len(train_hr_imgs) - batch_size + 1, batch_size):
                step_time = time.time()
                b_imgs_48, b_imgs_16 = make_batch(idx)
                ## update G
                errM, _ = sess.run([mse_loss, g_optim_init], {
                    t_image: b_imgs_16,
                    t_target_image: b_imgs_48
                })
                sys.stdout.write(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f \r" %
                    (epoch, n_epoch_init, n_iter, time.time() - step_time,
                     errM))
                sys.stdout.flush()
                total_mse_loss += errM
                n_iter += 1
            # max(..., 1) avoids ZeroDivisionError when there was no full
            # batch to train on.
            log = "\n[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
                epoch, n_epoch_init, time.time() - epoch_time,
                total_mse_loss / max(n_iter, 1))
            print(log)

            ## every 10 epochs: quick evaluation and checkpoint
            if (epoch != 0) and (epoch % 10 == 0):
                save_previews(epoch, save_dir_ginit, '/valid_ganit_%d.png')
                tl.files.save_npz(
                    net_g.all_params,
                    name=checkpoint_dir +
                    '/g_{}_init_{}.npz'.format(mode, epoch),
                    sess=sess)

        ###========================= train GAN (SRGAN) ======================###
        for epoch in range(0, n_epoch + 1):
            ## update learning rate
            if epoch != 0 and (epoch % decay_every == 0):
                new_lr_decay = lr_decay**(epoch // decay_every)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
                log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                               new_lr_decay)
                print(log)
            elif epoch == 0:
                sess.run(tf.assign(lr_v, lr_init))
                log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                    lr_init, decay_every, lr_decay)
                print(log)

            epoch_time = time.time()
            total_d_loss, total_g_loss, n_iter = 0, 0, 0

            # Unlike the pre-training loop, a trailing partial batch is fed
            # here as well (the placeholders have a free batch dimension).
            for idx in range(0, len(train_hr_imgs), batch_size):
                step_time = time.time()
                b_imgs_48, b_imgs_16 = make_batch(idx)
                feed = {t_image: b_imgs_16, t_target_image: b_imgs_48}
                ## update D
                errD, _ = sess.run([d_loss, d_optim], feed)
                ## update G
                errG, errM, errA, _ = sess.run(
                    [g_loss, mse_loss, g_gan_loss, g_optim], feed)
                sys.stdout.write(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f adv: %.6f) \r"
                    % (epoch, n_epoch, n_iter, time.time() - step_time, errD,
                       errG, errM, errA))
                sys.stdout.flush()
                total_d_loss += errD
                total_g_loss += errG
                n_iter += 1

            # max(..., 1) avoids ZeroDivisionError on an empty training set.
            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
                epoch, n_epoch, time.time() - epoch_time,
                total_d_loss / max(n_iter, 1), total_g_loss / max(n_iter, 1))
            print(log)

            ## every 10 epochs: quick evaluation and checkpoints
            if (epoch != 0) and (epoch % 10 == 0):
                save_previews(epoch, save_dir_gan, '/valid_gan_%d.png')
                tl.files.save_npz(net_g.all_params,
                                  name=checkpoint_dir +
                                  '/g_{}_{}.npz'.format(mode, epoch),
                                  sess=sess)
                tl.files.save_npz(net_d.all_params,
                                  name=checkpoint_dir +
                                  '/d_{}_{}.npz'.format(mode, epoch),
                                  sess=sess)
    finally:
        # Release the event file and the session even if training fails.
        if train_writer is not None:
            train_writer.close()
        sess.close()
# Example #7
# 0
def train():
    """Train the SRGAN generator with pixel MSE plus FaceNet-derived losses.

    The generator output and the target image are both passed through the
    network built by ``inception_resnet_v1_new.inference`` (weights restored
    from a pretrained FaceNet checkpoint).  Training runs in two phases:
    ``n_epoch_init`` epochs on the MSE loss only, then ``n_epoch`` epochs on
    ``mse_loss + vgg_loss + softmax_loss`` with a step-decayed learning rate.
    No discriminator is built or trained in this function.

    Relies on names defined outside this function: ``config``,
    ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch_init``, ``n_epoch``,
    ``lr_decay``, ``decay_every``, ``ni``, ``keep_probability``,
    ``embedding_size``, ``weight_decay`` and the helpers ``SRGAN_g``,
    ``prewhitenfacenet``, ``softmax``, ``retain`` and ``facenet``.
    """
    # Create folders for the sample images and the trained models.
    save_dir_ginit = "samples/facenet_pgd_loss_ginit"
    save_dir_gan = "samples/facenet_pgd_loss_gan"
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint/facenet_pgd_loss_12.19"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    # Load the training set (paired HR / LR images, matched by sorted name).
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))

    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)
    train_lr_imgs = tl.vis.read_images(train_lr_img_list,
                                       path=config.TRAIN.lr_img_path,
                                       n_threads=32)

    t_image = tf.placeholder('float32', [batch_size, 160, 160, 3],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 160, 160, 3],
                                    name='t_target_image')
    phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

    # Define the model.
    net_g = SRGAN_g(t_image, is_train=True, reuse=False)

    net_g.print_params(False)
    net_g.print_layers()

    # Pre-whiten the generated and the target images before feeding them to
    # the embedding network.
    out_160 = prewhitenfacenet(net_g.outputs)
    t_target_image_160 = prewhitenfacenet(t_target_image)
    image_batch1 = tf.identity(out_160, 'input')
    image_batch2 = tf.identity(t_target_image_160, 'input')

    model_def = 'inception_resnet_v1_new'
    network = importlib.import_module(model_def)

    # Build the inference graph twice with shared weights (AUTO_REUSE): once
    # for the generated images, once for the targets.
    prelogits1, _, texture_emb1 = network.inference(
        image_batch1,
        keep_probability,
        phase_train=phase_train_placeholder,
        bottleneck_layer_size=embedding_size,
        weight_decay=weight_decay,
        reuse=tf.AUTO_REUSE)
    prelogits2, _, texture_emb2 = network.inference(
        image_batch2,
        keep_probability,
        phase_train=phase_train_placeholder,
        bottleneck_layer_size=embedding_size,
        weight_decay=weight_decay,
        reuse=tf.AUTO_REUSE)

    embeddings1 = tf.nn.l2_normalize(prelogits1, 1, 1e-10, name='embeddings')
    embeddings2 = tf.nn.l2_normalize(prelogits2, 1, 1e-10, name='embeddings')
    # test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    # Despite its name this is computed on the texture embeddings returned by
    # `network.inference`, not on VGG features.
    vgg_loss = 1e4 * tl.cost.mean_squared_error(
        texture_emb1, texture_emb2, is_mean=True)
    # Squared embedding distance generated-vs-target, and target-vs-target
    # (identically zero) which provides the reference label.
    distance1 = tf.reduce_sum(tf.square(embeddings1 - embeddings2), axis=1)
    softmax_output_value1 = tf.transpose(softmax(distance1))
    distance2 = tf.reduce_sum(tf.square(embeddings2 - embeddings2), axis=1)
    softmax_output_value2 = tf.transpose(softmax(distance2))
    # NOTE(review): presumably `softmax` yields two class scores per image
    # (the one-hot depth below is 2) -- confirm against its definition.
    index = tf.argmax(softmax_output_value2, 1)
    label_mask = tf.one_hot(index,
                            2,
                            on_value=1.0,
                            off_value=0.0,
                            dtype=tf.float32)
    softmax_loss = 1e2 * tf.reduce_mean(
        -tf.reduce_sum(label_mask * tf.log(softmax_output_value1), 1))
    # Generator loss.
    g_loss = mse_loss + vgg_loss + softmax_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    inception_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                       scope='InceptionResnetV1')
    saver = tf.train.Saver(inception_vars, max_to_keep=3)
    # The initialisation phase optimises the MSE loss only.
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    # Main phase: full generator loss; only generator variables are updated.
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)

    # Only full batches can be fed (the placeholders have a fixed batch
    # size), and every index must exist in both image lists; a trailing
    # partial batch is therefore skipped instead of raising IndexError.
    n_pairs = min(len(train_hr_imgs), len(train_lr_imgs))
    batch_starts = range(0, n_pairs - batch_size + 1, batch_size)

    def make_batch(idx):
        """Return (HR, LR) arrays for the full batch starting at ``idx``."""
        # HR and LR are stacked on the channel axis so `retain` processes
        # each pair as one array; they are split again afterwards.
        x_imgs = [
            np.concatenate([train_hr_imgs[j], train_lr_imgs[j]], axis=2)
            for j in range(idx, idx + batch_size)
        ]
        b_imgs = tl.prepro.threading_data(x_imgs, fn=retain, is_random=True)
        return b_imgs[:, :, :, 0:3], b_imgs[:, :, :, 3:6]

    # Restore the models.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        pretrained_model = '/home/fan/facenet_adversarial_faces/models/facenet/20170512-110547/'
        if pretrained_model:
            print('Restoring pretrained model: %s' % pretrained_model)

            model_exp = os.path.expanduser(pretrained_model)
            print('Model directory: %s' % model_exp)
            _, ckpt_file = facenet.get_model_filenames(model_exp)

            print('Checkpoint file: %s' % ckpt_file)
            saver.restore(sess, os.path.join(model_exp, ckpt_file))

        # NOTE(review): the file names loaded here ('g_<mode>.npz',
        # 'g_<mode>_init.npz') differ from the ones saved below
        # ('g_<mode>_init_softmax.npz', 'g_srgan_softmax<epoch>.npz') --
        # confirm which checkpoints are meant to be resumed.
        if tl.files.load_and_assign_npz(
                sess=sess,
                name=checkpoint_dir +
                '/g_{}.npz'.format(tl.global_flag['mode']),
                network=net_g) is False:
            tl.files.load_and_assign_npz(
                sess=sess,
                name=checkpoint_dir +
                '/g_{}_init.npz'.format(tl.global_flag['mode']),
                network=net_g)

        for var in tf.trainable_variables():
            print(var.name)

        # Fixed samples used for the periodic previews.
        sample_imgs_h = train_hr_imgs[0:batch_size]
        sample_imgs_l = train_lr_imgs[0:batch_size]
        sample_imgs_h = tl.prepro.threading_data(sample_imgs_h,
                                                 fn=retain,
                                                 is_random=False)
        print('sample HR sub-image:', sample_imgs_h.shape,
              sample_imgs_h.min(), sample_imgs_h.max())
        sample_imgs_l = tl.prepro.threading_data(sample_imgs_l,
                                                 fn=retain,
                                                 is_random=False)
        print('sample LR sub-image:', sample_imgs_l.shape,
              sample_imgs_l.min(), sample_imgs_l.max())
        tl.vis.save_images(sample_imgs_l, [ni, ni],
                           save_dir_ginit + '/_train_sample_l.png')
        tl.vis.save_images(sample_imgs_h, [ni, ni],
                           save_dir_ginit + '/_train_sample_h.png')
        tl.vis.save_images(sample_imgs_l, [ni, ni],
                           save_dir_gan + '/_train_sample_l.png')
        tl.vis.save_images(sample_imgs_h, [ni, ni],
                           save_dir_gan + '/_train_sample_h.png')

        # Initialise the generator (MSE only).
        # fixed learning rate
        sess.run(tf.assign(lr_v, lr_init))
        print(" ** fixed learning rate: %f (for init G)" % lr_init)
        for epoch in range(0, n_epoch_init + 1):
            epoch_time = time.time()
            total_mse_loss, n_iter = 0, 0

            for idx in batch_starts:
                step_time = time.time()
                b_imgs_h, b_imgs_l = make_batch(idx)

                # update G
                errM, _ = sess.run([mse_loss, g_optim_init], {
                    t_image: b_imgs_l,
                    t_target_image: b_imgs_h
                })
                print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                      (epoch, n_epoch_init, n_iter, time.time() - step_time,
                       errM))
                total_mse_loss += errM
                n_iter += 1
            # max(..., 1) avoids ZeroDivisionError when no full batch exists.
            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f\n" % (
                epoch, n_epoch_init, time.time() - epoch_time,
                total_mse_loss / max(n_iter, 1))
            print(log)
            with open('log_init.txt', 'a') as f:
                f.write(log)

            # every 10 epochs: quick evaluation on train set and checkpoint
            if (epoch != 0) and (epoch % 10 == 0):
                out = sess.run(net_g_test.outputs, {t_image: sample_imgs_l})
                print("[*] save images")
                tl.vis.save_images(out, [ni, ni],
                                   save_dir_ginit + '/train_%d.png' % epoch)
                tl.files.save_npz(
                    net_g.all_params,
                    name=checkpoint_dir +
                    '/g_{}_init_softmax.npz'.format(tl.global_flag['mode']),
                    sess=sess)

        # Main training phase with the full generator loss.
        for epoch in range(0, n_epoch + 1):
            # update learning rate
            if epoch != 0 and (epoch % decay_every == 0):
                new_lr_decay = lr_decay**(epoch // decay_every)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
                log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                               new_lr_decay)
                print(log)
            elif epoch == 0:
                sess.run(tf.assign(lr_v, lr_init))
                log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                    lr_init, decay_every, lr_decay)
                print(log)

            epoch_time = time.time()
            total_g_loss, total_mse_loss = 0, 0
            total_vgg_loss, total_facenet_loss, n_iter = 0, 0, 0

            for idx in batch_starts:
                step_time = time.time()
                b_imgs_h, b_imgs_l = make_batch(idx)

                # update G; the embedding network runs in inference mode
                errG, errM, errV, errV2, _ = sess.run(
                    [g_loss, mse_loss, vgg_loss, softmax_loss, g_optim], {
                        t_image: b_imgs_l,
                        t_target_image: b_imgs_h,
                        phase_train_placeholder: False
                    })
                print(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f facenet: %.6f)"
                    % (epoch, n_epoch, n_iter, time.time() - step_time, errG,
                       errM, errV, errV2))
                total_g_loss += errG
                total_mse_loss += errM
                total_vgg_loss += errV
                total_facenet_loss += errV2
                n_iter += 1

            # max(..., 1) avoids ZeroDivisionError when no full batch exists.
            n_done = max(n_iter, 1)
            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f facenet: %.6f)\n" \
                  % (epoch, n_epoch, time.time() - epoch_time,
                     total_g_loss / n_done, total_mse_loss / n_done,
                     total_vgg_loss / n_done, total_facenet_loss / n_done)
            print(log)
            with open('log.txt', 'a') as f:
                f.write(log)

            # quick evaluation on train set
            if (epoch != 0) and (epoch % 10 == 0):
                out = sess.run(net_g_test.outputs, {t_image: sample_imgs_l})
                print("[*] save images")
                tl.vis.save_images(out, [ni, ni],
                                   save_dir_gan + '/train_%d.png' % epoch)

            # save model
            if (epoch != 0) and (epoch % 50 == 0):
                tl.files.save_npz(net_g.all_params,
                                  name=checkpoint_dir +
                                  '/g_srgan_softmax%d.npz' % epoch,
                                  sess=sess)
    finally:
        # Release the session even if restoring or training raises.
        sess.close()