Exemplo n.º 1
0
def srgan_model(features, labels, mode, params):
    """Estimator model_fn for SRGAN.

    Args:
        features: LR input image batch fed to the generator.
        labels: HR target image batch (ignored in PREDICT mode).
        mode: a tf.estimator.ModeKeys value.
        params: unused; present only to satisfy the model_fn signature.

    Returns:
        tf.estimator.EstimatorSpec — predictions for PREDICT, otherwise a
        train spec whose loss is the generator loss and whose train_op
        jointly updates generator and discriminator.
    """
    del params  # unused, but required by the Estimator model_fn contract

    if mode == tf.estimator.ModeKeys.PREDICT:
        net_g_test = SRGAN_g(features, is_train=False)

        predictions = {'generated_images': net_g_test.outputs}

        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # NOTE(review): SRGAN_d is entered twice without an explicit reuse flag;
    # other variants of this code pass reuse=False/reuse=True on the two
    # calls — confirm this repo's SRGAN_d handles variable reuse internally.
    net_g = SRGAN_g(features, is_train=True)
    net_d, logits_real = SRGAN_d(labels, is_train=True)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True)

    # method=0 is bilinear resize; VGG19 expects 224x224 inputs.
    t_target_image_224 = tf.image.resize_images(labels,
                                                size=[224, 224],
                                                method=0,
                                                align_corners=False)
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    # Images are in [-1, 1]; VGG feature extractor expects [0, 1].
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2)

    # Discriminator loss: real -> 1, fake -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator loss: adversarial + pixel MSE + VGG perceptual terms.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, labels, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(config.TRAIN.lr_init, trainable=False)

    # SRGAN: the two minimize ops run together as one train_op, so only ONE
    # of them may advance the global step — passing global_step to both
    # would increment it twice per training step and skew step-based
    # schedules/checkpointing.
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=config.TRAIN.beta1) \
        .minimize(g_loss, var_list=g_vars, global_step=tf.train.get_global_step())
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=config.TRAIN.beta1) \
        .minimize(d_loss, var_list=d_vars)

    joint_op = tf.group([g_optim, d_optim])

    load_vgg(net_vgg)

    return tf.estimator.EstimatorSpec(mode, loss=g_loss, train_op=joint_op)
Exemplo n.º 2
0
def train():
    """Two-phase SRGAN training on 384x384 RGB crops (LR input 96x96).

    Phase 1 pre-trains the generator with MSE only for n_epoch_init epochs;
    phase 2 trains generator + discriminator adversarially with
    MSE + VGG-perceptual + GAN losses for n_epoch epochs.

    Relies on module-level globals: config, batch_size, lr_init, beta1,
    n_epoch_init, n_epoch, lr_decay, decay_every, ni, and the builders
    SRGAN_g, SRGAN_d, Vgg19_simple_api. Side effects: writes sample PNGs
    under samples/ and .npz checkpoints under checkpoint/.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)
    # for im in train_hr_imgs:
    #     print(im.shape)
    # valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    # for im in valid_lr_imgs:
    #     print(im.shape)
    # valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
    # for im in valid_hr_imgs:
    #     print(im.shape)
    # exit()

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3],
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    # Images are in [-1, 1]; VGG expects [0, 1].
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator: real -> 1, fake -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator: adversarial + pixel MSE + VGG perceptual terms.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Resume from the GAN checkpoint if present, else from the init one.
    # NOTE(review): some tensorlayer versions return False (not None) on a
    # failed load — verify this check against the installed version.
    if tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
            network=net_g) is None:
        tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir +
            '/g_{}_init.npz'.format(tl.global_flag['mode']),
            network=net_g)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir +
                                 '/d_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    # The .npy holds a pickled dict of layer params; numpy >= 1.16.3 defaults
    # allow_pickle=False, which would make this load raise a ValueError.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_gan + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        # NOTE(review): a final batch shorter than batch_size would not fit
        # the fixed-size placeholders; this assumes len(train_hr_imgs) is a
        # multiple of batch_size — confirm.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_96
            })  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_ginit + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir +
                              '/g_{}_init.npz'.format(tl.global_flag['mode']),
                              sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG,
                   errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_96
            })  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_gan + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir +
                              '/g_{}.npz'.format(tl.global_flag['mode']),
                              sess=sess)
            tl.files.save_npz(net_d.all_params,
                              name=checkpoint_dir +
                              '/d_{}.npz'.format(tl.global_flag['mode']),
                              sess=sess)
Exemplo n.º 3
0
def train(train_lr_imgs, train_hr_imgs):
    """Train SRGAN on pre-loaded 512x512 single-channel LR/HR image pairs.

    Phase 1 pre-trains the generator with MSE for n_epoch_init epochs;
    phase 2 trains the full GAN (MSE + VGG + adversarial losses) for
    n_epoch epochs, periodically checkpointing to models_checkpoints/ and
    saving per-iteration/per-epoch loss curves at the end.

    Args:
        train_lr_imgs: LR training images, reshapeable to
            (batch_size, 512, 512, 1) per batch.
        train_hr_imgs: matching HR training images, same layout.

    Relies on module-level globals: batch_size, lr_init, beta1,
    n_epoch_init, n_epoch, lr_decay, decay_every, the builders SRGAN_g,
    SRGAN_d, Vgg19_simple_api, and plot_total_losses /
    plot_iterative_losses.
    """
    ## create folders to save result images and trained model
    checkpoint_dir = "models_checkpoints"
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder(dtype='float32',
                             shape=(batch_size, 512, 512, 1),
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(dtype='float32',
                                    shape=(batch_size, 512, 512, 1),
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg
    # Images are in [-1, 1]; VGG expects [0, 1].
    net_vgg, vgg_target_emb = Vgg19_simple_api(input=(t_target_image_224 + 1) /
                                               2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api(input=(t_predict_image_224 + 1) / 2,
                                          reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator: real -> 1, fake -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator: adversarial + pixel MSE + VGG perceptual terms.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Resume from the GAN checkpoint if present, else from the init one.
    if tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
            network=net_g) is False:
        tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir +
            '/g_{}_init.npz'.format(tl.global_flag['mode']),
            network=net_g)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir +
                                 '/d_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    # The .npy holds a pickled dict of layer params; numpy >= 1.16.3 defaults
    # allow_pickle=False, which would make this load raise a ValueError.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        if val[0] == 'conv1_1':
            # Adapt the RGB first layer to single-channel input by averaging
            # the kernel over the input-channel axis.
            W = np.mean(W, axis=2)
            W = W.reshape((3, 3, 1, 64))
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])

    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    start_time = time.time()
    for epoch in range(0, n_epoch_init):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        step_time = None
        # NOTE(review): a final batch shorter than batch_size would not fit
        # the fixed-size placeholders after reshape; this assumes
        # len(train_hr_imgs) is a multiple of batch_size — confirm.
        for idx in range(0, len(train_hr_imgs), batch_size):
            if idx % 1000 == 0: step_time = time.time()
            b_imgs_hr = train_hr_imgs[idx:idx + batch_size]
            b_imgs_lr = train_lr_imgs[idx:idx + batch_size]
            b_imgs_hr = np.asarray(b_imgs_hr).reshape(
                (batch_size, 512, 512, 1))
            b_imgs_lr = np.asarray(b_imgs_lr).reshape(
                (batch_size, 512, 512, 1))

            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_lr,
                t_target_image: b_imgs_hr
            })

            if idx % 1000 == 0:
                print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                      (epoch, n_epoch_init, n_iter, time.time() - step_time,
                       errM))
                tl.files.save_npz(
                    net_g.all_params,
                    name=checkpoint_dir +
                    '/g_{}_init.npz'.format(tl.global_flag['mode']),
                    sess=sess)

            total_mse_loss += errM
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / n_iter)
        print(log)

        ## save model
        tl.files.save_npz(net_g.all_params,
                          name=checkpoint_dir +
                          '/g_{}_init.npz'.format(tl.global_flag['mode']),
                          sess=sess)
    print("G init took: %4.4fs" % (time.time() - start_time))

    ###========================= train GAN (SRGAN) =========================###
    start_time = time.time()
    epoch_losses = defaultdict(list)
    iter_losses = defaultdict(list)

    for epoch in range(0, n_epoch):
        ## update learning rate
        if epoch != 0 and decay_every != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        step_time = None
        for idx in range(0, len(train_hr_imgs), batch_size):
            if idx % 1000 == 0: step_time = time.time()
            b_imgs_hr = train_hr_imgs[idx:idx + batch_size]
            b_imgs_lr = train_lr_imgs[idx:idx + batch_size]
            b_imgs_hr = np.asarray(b_imgs_hr).reshape(
                (batch_size, 512, 512, 1))
            b_imgs_lr = np.asarray(b_imgs_lr).reshape(
                (batch_size, 512, 512, 1))

            ## update D
            errD, _ = sess.run([d_loss, d_optim], {
                t_image: b_imgs_lr,
                t_target_image: b_imgs_hr
            })
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {
                    t_image: b_imgs_lr,
                    t_target_image: b_imgs_hr
                })

            if idx % 1000 == 0:
                print(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                    % (epoch, n_epoch, n_iter, time.time() - step_time, errD,
                       errG, errM, errV, errA))
                tl.files.save_npz(net_g.all_params,
                                  name=checkpoint_dir +
                                  '/g_{}.npz'.format(tl.global_flag['mode']),
                                  sess=sess)
                tl.files.save_npz(net_d.all_params,
                                  name=checkpoint_dir +
                                  '/d_{}.npz'.format(tl.global_flag['mode']),
                                  sess=sess)

            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

            iter_losses['d_loss'].append(errD)
            iter_losses['g_loss'].append(errG)
            iter_losses['mse_loss'].append(errM)
            iter_losses['vgg_loss'].append(errV)
            iter_losses['adv_loss'].append(errA)

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)
        # NOTE: these store per-epoch loss TOTALS (not averages), unlike the
        # per-iteration means printed above.
        epoch_losses['d_loss'].append(total_d_loss)
        epoch_losses['g_loss'].append(total_g_loss)

        ## save model
        tl.files.save_npz(net_g.all_params,
                          name=checkpoint_dir +
                          '/g_{}.npz'.format(tl.global_flag['mode']),
                          sess=sess)
        tl.files.save_npz(net_d.all_params,
                          name=checkpoint_dir +
                          '/d_{}.npz'.format(tl.global_flag['mode']),
                          sess=sess)
    print("G train took: %4.4fs" % (time.time() - start_time))

    ## create visualizations for losses from training
    plot_total_losses(epoch_losses)
    plot_iterative_losses(iter_losses)
    for loss, values in epoch_losses.items():
        np.save(checkpoint_dir + "/epoch_" + loss + '.npy', np.asarray(values))
    for loss, values in iter_losses.items():
        np.save(checkpoint_dir + "/iter_" + loss + '.npy', np.asarray(values))
    print("[*] saved losses")
Exemplo n.º 4
0
def network_new2(
        top_dir="sr_tanh/",
        svg_dir="dataset/Test/",  #test_data
        pxl_dir="dataset/Train/",  #train_data
        output_dir="pic_smooth/",
        test_output_dir='test_output/',
        checkpoint_dir="save_model",
        checkpoint_dir1="save_model",
        model_name="model4",
        big_loop=1,
        scale_num=2,
        epoch_init=5000,
        strides=20,
        batch_size=4,
        max_idx=92,
        data_size=92,
        lr_init=1e-3,
        learning_rate=1e-5,
        vgg_weight_list=[1, 1, 5e-1, 1e-1],
        use_vgg=False,
        use_L1_loss=False,
        wgan=False,
        init_g=True,
        init_d=True,
        init_b=False,
        method=0,
        lowest_resolution_log2=4,
        train_net=True,
        generate_pics=True,
        resume_network=False):
    """Build and train a progressive, multi-resolution super-resolution GAN
    (TensorFlow 1.x graph mode + TensorLayer).

    The function constructs one generator (``my_GAN_G2``) whose outputs are
    supervised at every resolution level from ``2**lowest_resolution_log2``
    up to the training-image resolution, with one discriminator
    (``my_GAN_D1``) and one blending network (``net_CT_blend``) per level.
    It then (optionally) pre-trains G with MSE, pre-trains D, runs full
    adversarial training level by level, and finally evaluates on
    Set5/Set14, writing sample images and ``.npz`` checkpoints as it goes.

    Parameters (all keyword-style with defaults; only the most important):
        top_dir:    root directory prefixed onto all output/checkpoint dirs.
        pxl_dir:    training-image directory read via ``read_data``.
        svg_dir:    test-image root (expects ``Set5/`` and ``Set14/`` under it).
        scale_num:  downscale factor between target and LR input.
        epoch_init: epochs for each of the init-G / init-D / GAN phases.
        lr_init / learning_rate: Adam LR (decayed) / RMSProp LR (wgan mode).
        use_vgg:    add VGG19 perceptual loss terms (weights in
                    ``vgg_weight_list``).
        wgan:       switch to WGAN losses, RMSProp, and weight clipping.
        init_g / init_d / init_b: enable the respective pre-training phases.
        train_net / generate_pics / resume_network: phase toggles.

    Returns:
        None. All results are side effects: files under ``top_dir`` and
        log output via ``logging`` / ``print``.

    NOTE(review): ``vgg_weight_list=[...]`` is a mutable default argument —
    shared across calls. It is only read here, so it is harmless today, but
    it should be ``None``-defaulted.
    NOTE(review): ``model_name`` is accepted but never used.
    """

    # File logger under top_dir; INFO level both on logger and handler.
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    handler = logging.FileHandler(top_dir + "log.txt")
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    # Log every default parameter value. Pairing __defaults__ with
    # co_varnames[idx] is only valid because *every* parameter has a
    # default; co_varnames also contains locals further along.
    for idx, val in enumerate(network_new2.__defaults__):
        logger.info(
            str(network_new2.__code__.co_varnames[idx]) + ' == ' + str(val))

    # All output paths are rooted at top_dir.
    output_dir = top_dir + output_dir
    test_output_dir = top_dir + test_output_dir
    checkpoint_dir = top_dir + checkpoint_dir
    checkpoint_dir1 = top_dir + checkpoint_dir1

    logger.info("start building the net")

    if use_vgg:
        print('use vgg')
    # Load HR training images (+ per-image padding bookkeeping used later
    # by clip_pics/clip_pic to undo padding before metric computation).
    t_target_image_data, image_padding_nums = read_data(pxl_dir, data_size)
    resolution = t_target_image_data.shape[1] / scale_num
    target_resolution = t_target_image_data.shape[1]
    # log2 of the LR and HR resolutions; these bound the progressive levels.
    resolution_log2 = int(np.floor(np.log2(resolution)))
    target_resolution_log2 = int(np.floor(np.log2(target_resolution)))

    #image = tf.image.resize_images(t_image, size=[64, 64], method=2)
    # HR target placeholder; the LR input t_image is derived from it by
    # down-then-up resizing (degradation model), so only HR data is fed.
    t_image_target = tf.placeholder(
        'float32', [None, target_resolution, target_resolution, 3],
        name='t_image_target')
    t_image_ = tf.image.resize_images(
        t_image_target,
        size=[target_resolution // scale_num, target_resolution // scale_num],
        method=method)
    t_image = tf.image.resize_images(
        t_image_, size=[target_resolution, target_resolution], method=method)

    # Per-level resized targets / inputs, filled in the level loop below.
    t_image_target_list = []
    t_image_list = []

    #generate list of pics from 2 ** 2 resolution to t_image_size resolution
    # Training-mode generator (returns one output net per level plus the
    # per-level mix/pic rate variables) and a weight-sharing eval copy.
    net_Gs, mix_rates = my_GAN_G2(t_image, is_train=True, reuse=False)
    print("init Gs")
    net_Gs[-1].print_params(False)
    net_g_test, _ = my_GAN_G2(t_image, is_train=False, reuse=True)
    print("init g_test")

    if use_vgg:
        # 224x224 placeholders exist only to name the tensors; the actual
        # VGG inputs below come from resize ops, not these placeholders.
        t_target_image_224 = tf.placeholder('float32', [None, 224, 224, 3],
                                            name='t_image_224')
        t_predict_image_224 = tf.placeholder('float32', [None, 224, 224, 3],
                                             name='t_target_224')
        # First VGG instantiation creates the variables (reuse=False);
        # inputs are rescaled from [-1, 1] to [0, 1].
        net_vgg, vgg_target_emb = Vgg19_simple_api(
            (t_target_image_224 + 1) / 2, reuse=False)

    #initialize the list to store different level net
    net_ds = []
    b_outputs = []
    logits_reals = []

    logits_fakes = []
    logits_fakes2 = []

    d_loss_list = []
    b_loss_list = []
    d_loss3_list = []
    mse_loss_list = []
    g_gan_loss_list = []
    g_loss_list = []

    g_init_optimizer_list = []
    d_init_optimizer_list = []

    g_optimizer_list = []
    d_optimizer_list = []
    b_optimizer_list = []

    w_clip_list = []

    # Shared, assignable learning rate for all Adam optimizers.
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    print("init Ds")
    # ---- Per-resolution-level graph construction -------------------------
    # One discriminator, one blend net, and one full set of losses and
    # optimizers per level i (resolution 2**i).
    for i in range(lowest_resolution_log2, target_resolution_log2 + 1):
        idx = i - lowest_resolution_log2
        cur_resolution = 2**i
        size = [cur_resolution, cur_resolution]

        # Level-i views of the HR target and the degraded input.
        target_i = tf.image.resize_images(t_image_target,
                                          size=size,
                                          method=method)
        image_i = tf.image.resize_images(t_image, size=size, method=method)
        t_image_target_list += [target_i]
        t_image_list += [image_i]

        if use_vgg:
            t_target_image_224 = tf.image.resize_images(
                t_image_target, size=[224, 224], method=1, align_corners=False
            )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
            add_dimens = tf.zeros_like(t_target_image_224)
            print(add_dimens.dtype)
            print(t_target_image_224.dtype)
            t_predict_image_224 = tf.image.resize_images(
                net_Gs[idx].outputs,
                size=[224, 224],
                method=1,
                align_corners=False)  # resize_generate_image_for_vgg
            net_vgg, vgg_target_emb = Vgg19_simple_api(
                (t_target_image_224 + 1) / 2, reuse=True)
            _, vgg_predict_emb = Vgg19_simple_api(
                (t_predict_image_224 + 1) / 2, reuse=True)

        #initialize the D_reals and D_fake
        # D on: real target, G output (fake), and the raw degraded input
        # (fake2). use_sigmoid is disabled in wgan mode (raw critic scores).
        net_d, logits_real = my_GAN_D1(target_i,
                                       is_train=True,
                                       reuse=False,
                                       use_sigmoid=not wgan)
        _, logits_fake = my_GAN_D1(net_Gs[idx].outputs,
                                   is_train=True,
                                   reuse=True,
                                   use_sigmoid=not wgan)
        _, logits_fake2 = my_GAN_D1(image_i,
                                    is_train=True,
                                    reuse=True,
                                    use_sigmoid=not wgan)

        # Blend net combines the degraded input with G's output.
        blend_output = net_CT_blend(image_i, net_Gs[idx].outputs)
        b_outputs += [blend_output]

        net_ds += [net_d]
        logits_reals += [logits_real]
        logits_fakes += [logits_fake]
        logits_fakes2 += [logits_fake2]

        # Random interpolation between fake and real for the WGAN-GP-style
        # gradient term below.
        # NOTE(review): mix_factors is sampled ONCE at graph-construction
        # time (numpy), so the same per-channel mix is reused for every
        # training step; a gradient-penalty mix is normally resampled per
        # batch via tf.random_uniform — confirm this is intended.
        mix_factors = np.random.uniform(size=[1, 1, 1, int(target_i.shape[3])])
        print(mix_factors.shape)
        mix_pic = net_Gs[idx].outputs * mix_factors + target_i * (1 -
                                                                  mix_factors)

        # Only consumed in the wgan branch (mix_grads below).
        _, logits_mix = my_GAN_D1(mix_pic, is_train=True, reuse=True)

        # Standard GAN discriminator losses: real->1, fake->0, and the
        # degraded input also treated as fake (d_loss3).
        d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                                tf.ones_like(logits_real),
                                                name='d1_%d' % cur_resolution)

        d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                                tf.zeros_like(logits_fake),
                                                name='d2_%d' % cur_resolution)

        d_loss3 = tl.cost.sigmoid_cross_entropy(logits_fake2,
                                                tf.zeros_like(logits_fake2),
                                                name='d3_%d' % cur_resolution)

        # NOTE(review): d_loss4 is computed but never added to any loss
        # (dead experiment, see the commented-out terms on d_loss).
        d_loss4 = (tf.reduce_mean(logits_fake2)) - (
            tf.reduce_mean(logits_real))

        d_loss4 = tf.nn.sigmoid(d_loss4)  #make sure in [0, 1]

        d_loss = 1 * (d_loss1 + d_loss2)  #+ d_loss3  + d_loss4
        d_loss += 0.

        # d_loss3 is reused as the D *pre-training* objective: input-as-fake
        # plus real-as-real.
        d_loss3 += d_loss1

        # NOTE(review): unused flag left over from an experiment.
        use_vgg22 = True

        vgg_loss = 0
        if use_vgg:
            # Weighted MSE over each VGG feature level.
            # NOTE(review): this loop rebinds the *outer* level variable
            # ``i``; after it runs, the ``if i >= 7`` check below compares
            # the VGG-feature index, not the resolution level. Harmless only
            # when use_vgg is False.
            for i, vgg_target in enumerate(vgg_target_emb):
                vgg_loss += vgg_weight_list[i] * tl.cost.mean_squared_error(
                    vgg_predict_emb[i].outputs,
                    vgg_target.outputs,
                    is_mean=True)
        # Generator adversarial loss: make D classify fakes as real.
        g_gan_loss1 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                                    tf.ones_like(logits_fake),
                                                    name='g_%d' %
                                                    cur_resolution)
        # NOTE(review): g_gan_loss2 is computed but excluded from g_gan_loss
        # (see the commented-out term).
        g_gan_loss2 = (tf.reduce_mean(logits_fake2)) - (
            tf.reduce_mean(logits_fake))
        g_gan_loss2 = tf.nn.sigmoid(g_gan_loss2)  #make sure in [0, 1]

        g_gan_loss = g_gan_loss1  # + g_gan_loss2

        # Pixel reconstruction loss at this level.
        mse_loss = tl.cost.mean_squared_error(net_Gs[idx].outputs,
                                              target_i,
                                              is_mean=True)
        if use_L1_loss:
            # NOTE(review): BUG — assigns to the misspelled name
            # ``mes_loss``, so the L1 loss is computed and discarded and
            # ``use_L1_loss`` has no effect. Should assign to ``mse_loss``.
            mes_loss = tf.reduce_mean(
                tf.reduce_mean(tf.abs(net_Gs[idx].outputs - target_i)))
        g_gan_loss_list += [g_gan_loss]
        mse_loss_list += [mse_loss]

        # Combined generator loss (SRGAN-style weighting of the GAN term).
        g_loss = 1e-3 * g_gan_loss + mse_loss

        # NOTE(review): L1_norm is never used.
        L1_norm = tf.reduce_mean(tf.reduce_mean(net_Gs[idx].outputs))

        def TV_loss(x):
            # Total-variation-style penalty on neighbouring pixels.
            # NOTE(review): BUG — ``**`` binds tighter than ``-``, so this
            # computes x[...] - (x[...]**2), not (difference)**2. The result
            # (tV_loss) is never added to any objective, so the bug is
            # currently inert.
            loss1 = x[:, :, 1:, :] - x[:, :, :-1, :]**2
            loss2 = x[:, 1:, :, :] - x[:, :-1, :, :]**2
            return tf.reduce_sum(tf.reduce_sum(loss1)) + tf.reduce_sum(
                tf.reduce_sum(loss2))

        tV_loss = TV_loss(net_Gs[idx].outputs)

        # Blend-net reconstruction loss.
        b_loss = tl.cost.mean_squared_error(blend_output.outputs,
                                            target_i,
                                            is_mean=True)

        # Add perceptual loss only at high resolution levels (>= 128px).
        # NOTE(review): when use_vgg is True, ``i`` was clobbered by the VGG
        # enumerate loop above, so this condition does not test the level.
        if i >= 7: g_loss += vgg_loss

        #g_loss += vgg_loss

        # G variables are shared across levels; D/blend vars are per-level.
        g_vars = tl.layers.get_variables_with_name('my_GAN_G', True, True)
        d_vars = tl.layers.get_variables_with_name(
            'my_GAN_D_%d' % cur_resolution, True, True)
        b_vars = tl.layers.get_variables_with_name(
            'my_CT_blend_%d' % cur_resolution, True, True)

        # Adam(lr_v, beta1=0.9) optimizers: MSE-only G pre-training,
        # D pre-training (d_loss3), and the full GAN/blend objectives.
        g_optim_init = tf.train.AdamOptimizer(lr_v,
                                              0.9).minimize(mse_loss,
                                                            var_list=g_vars)
        g_init_optimizer_list += [g_optim_init]

        d_optim_init = tf.train.AdamOptimizer(lr_v,
                                              0.9).minimize(d_loss3,
                                                            var_list=d_vars)

        g_optim = tf.train.AdamOptimizer(lr_v, 0.9).minimize(g_loss,
                                                             var_list=g_vars)
        d_optim = tf.train.AdamOptimizer(lr_v, 0.9).minimize(d_loss,
                                                             var_list=d_vars)

        b_optim = tf.train.AdamOptimizer(lr_v, 0.9).minimize(b_loss,
                                                             var_list=b_vars)

        #WGAN
        if wgan:
            # WGAN mode: critic losses replace the cross-entropy ones, and
            # RMSProp optimizers (fixed learning_rate) replace the Adam ones
            # built above. The Adam ops still exist in the graph but are not
            # queued for these levels.
            print('mode is wgan')
            g_loss = -(tf.reduce_mean(logits_fake)) + vgg_loss
            d_loss = (tf.reduce_mean(logits_fake)) - (
                tf.reduce_mean(logits_real))
            d_loss3 = (tf.reduce_mean(logits_fake2)) - (
                tf.reduce_mean(logits_real))

            # Gradient-penalty term on the (fixed) fake/real interpolation.
            mix_grads = tf.gradients(tf.reduce_sum(logits_mix), mix_pic)
            mix_norms = tf.sqrt(
                tf.reduce_sum(tf.square(mix_grads), axis=[1, 2, 3]))

            # NOTE(review): ``addtion`` (sic) is computed but excluded from
            # d_loss (see commented-out line) — GP is effectively disabled;
            # weight clipping below is what constrains the critic.
            addtion = tf.reduce_mean(tf.square(mix_norms - 1.)) * 5.0
            #d_loss = d_loss + d_loss3 + addtion + tl.cost.mean_squared_error(logits_real, tf.zeros_like(logits_real)) * 1e-3
            d_loss = d_loss + d_loss3

            g_optim = tf.train.RMSPropOptimizer(learning_rate).minimize(
                g_loss, var_list=g_vars)
            d_optim = tf.train.RMSPropOptimizer(learning_rate).minimize(
                d_loss, var_list=d_vars)

            d_optim_init = tf.train.RMSPropOptimizer(learning_rate).minimize(
                d_loss3, var_list=d_vars)
            # Clip every critic weight into [-1, 1] (classic WGAN).
            clip_ops = []
            for var in d_vars:
                clip_bound = [-1.0, 1.0]
                clip_ops.append(
                    tf.assign(
                        var, tf.clip_by_value(var, clip_bound[0],
                                              clip_bound[1])))
            clip_disc_weights = tf.group(*clip_ops)
            w_clip_list += [clip_disc_weights]

        d_loss_list += [d_loss]
        d_loss3_list += [d_loss3]
        g_loss_list += [g_loss]
        b_loss_list += [b_loss]
        g_optimizer_list += [g_optim]
        d_optimizer_list += [d_optim]
        b_optimizer_list += [b_optim]
        d_init_optimizer_list += [d_optim_init]

        print("init Res : %d D" % cur_resolution)

    #Restore Model
    # Session capped at 80% of GPU memory; soft placement for CPU fallback.
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    sess = tf.Session(config=config)
    tl.layers.initialize_global_variables(sess)

    #......code for restore model
    if use_vgg:
        # Load pretrained VGG19 weights (conv/fc pairs, sorted by name)
        # into net_vgg; stop once net_vgg's parameter count is reached.
        vgg19_npy_path = "vgg19.npy"
        if not os.path.isfile(vgg19_npy_path):
            print(
                "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
            )
            exit()
        # NOTE(review): on NumPy >= 1.16.3 this requires
        # np.load(..., allow_pickle=True) — confirm the pinned NumPy version.
        npz = np.load(vgg19_npy_path, encoding='latin1').item()

        params = []
        for val in sorted(npz.items()):
            W = np.asarray(val[1][0])
            b = np.asarray(val[1][1])
            print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
            params.extend([W, b])
            if (len(params) == len(net_vgg.all_params)): break
        tl.files.assign_params(sess, params, net_vgg)
        # net_vgg.print_params(False)
        # net_vgg.print_layers()

    #Read Data
    #t_image_data, t_target_image_data = split_pic(read_data(svg_dir, data_size))

    #initialize G

    # ---- Training: big_loop outer repeats, then one pass per level R -----
    for temp_i in range(big_loop):

        # Step-decay schedule: multiply LR by 0.1 halfway through each phase.
        decay_every = epoch_init // 2
        lr_decay = 0.1

        logger.info("start training the net")

        for R in range(lowest_resolution_log2, target_resolution_log2 + 1):
            idx = R - lowest_resolution_log2
            if resume_network or not train_net:
                # Restore G and blend nets for this level from checkpoint_dir1.
                tl.files.load_and_assign_npz_dict(sess=sess,
                                                  name=checkpoint_dir1 +
                                                  '/g_%d_level_my_gan.npz' % R,
                                                  network=net_Gs[idx])
                tl.files.load_and_assign_npz_dict(sess=sess,
                                                  name=checkpoint_dir1 +
                                                  '/b_%d_level_my_gan.npz' % R,
                                                  network=b_outputs[idx])
                #tl.files.load_and_assign_npz_dict(sess = sess, name = checkpoint_dir1 + '/d_%d_level_my_gan.npz' % R, network = net_ds[idx])

            # Select this level's loss tensors / optimizer ops.
            total_mse_loss = 0
            mse_loss = mse_loss_list[idx]
            g_optim_init = g_init_optimizer_list[idx]

            total_d3_loss = 0
            d_loss3 = d_loss3_list[idx]
            d_optim = d_optimizer_list[idx]
            d_loss = d_loss_list[idx]

            d_optim_init = d_init_optimizer_list[idx]

            # Save a reference grid of the first batch (LR input and target)
            # for visual comparison; ni x ni image tiling.
            ni = int(np.sqrt(batch_size))
            out_svg = sess.run(
                t_image_list[idx],
                {t_image_target: t_target_image_data[0:batch_size]})
            out_pxl = sess.run(
                t_image_target_list[idx],
                {t_image_target: t_target_image_data[0:batch_size]})
            print(out_pxl[0])
            print(out_pxl.dtype)
            tl.vis.save_images(out_svg, [ni, ni],
                               output_dir + "R_%d_svg.png" % (R))
            tl.vis.save_images(out_pxl, [ni, ni],
                               output_dir + "R_%d_pxl.png" % (R))

            # Per-level text log of raw logits during training.
            # NOTE(review): f is only closed inside the ``if train_net:``
            # block below — the handle leaks when train_net is False.
            f = open('log%d.txt' % R, 'w')
            pre_loss_list = []
            now_loss_list = []
            # ---- Phase 1: MSE-only generator pre-training ----------------
            if init_g and train_net:
                #fix lr_v
                print('init g')
                sess.run(tf.assign(lr_v, lr_init))

                for epoch in range(epoch_init + 1):
                    iters, data, padding_nums = batch_data(
                        t_target_image_data, image_padding_nums, max_idx,
                        batch_size)
                    total_mse_loss = 0
                    # Running [metric1, metric2] sums from cal_loss (shape
                    # inferred from np.zeros([2]) — presumably PSNR/SSIM;
                    # TODO confirm against cal_loss).
                    total_pre_loss = np.zeros([2])
                    total_now_loss = np.zeros([2])
                    for i in range(iters):
                        errM, _ = sess.run([mse_loss, g_optim_init],
                                           {t_image_target: data[i]})
                        total_mse_loss += errM
                        if R == target_resolution_log2:  #final steps
                            # At the final level, compare metrics of the
                            # plain upscale (pre) vs the generator (now),
                            # after stripping padding.
                            lowR_pics, output_pics, GT_pics = sess.run(
                                [
                                    t_image_list[idx], net_g_test[idx].outputs,
                                    t_image_target_list[idx]
                                ], {t_image_target: data[i]})
                            pre_lowR_pics = clip_pics(lowR_pics,
                                                      padding_nums[i])
                            pre_output_pics = clip_pics(
                                output_pics, padding_nums[i])
                            pre_GT_pics = clip_pics(GT_pics, padding_nums[i])
                            for ii in range(data[i].shape[0]):
                                pre_loss = cal_loss(pre_lowR_pics[ii],
                                                    pre_GT_pics[ii])
                                now_loss = cal_loss(pre_output_pics[ii],
                                                    pre_GT_pics[ii])
                                total_pre_loss += pre_loss
                                total_now_loss += now_loss
                    pre_loss_list += [total_pre_loss / max_idx]
                    now_loss_list += [total_now_loss / max_idx]
                    print("[%d/%d] total_mse_loss = %f errM = %f" %
                          (epoch, epoch_init, total_mse_loss, errM))
                    ## save model
                    if (epoch % strides == 0):
                        # Periodically dump a sample grid and the current
                        # (sigmoid-ed) D outputs for monitoring.
                        print("save img %d" % R)
                        out, logits_real, logits_fake, logits_fake2 = sess.run(
                            [
                                net_g_test[idx].outputs,
                                tf.nn.sigmoid(logits_reals[idx]),
                                tf.nn.sigmoid(logits_fakes[idx]),
                                tf.nn.sigmoid(logits_fakes2[idx])
                            ], {
                                t_image_target:
                                t_target_image_data[0:batch_size]
                            })
                        print(out[0])
                        print(out.dtype)
                        tl.vis.save_images(
                            out, [ni, ni],
                            output_dir + "R_%d_init_%d.png" % (R, epoch))
                        if epoch % 10 == 0:
                            tl.files.save_npz_dict(
                                net_Gs[idx].all_params,
                                name=checkpoint_dir +
                                ('/g_%d_level_{}_init.npz' % R).format(
                                    tl.global_flag['mode']),
                                sess=sess)
                print("R %d total_mse_loss = %f" % (2**R, total_mse_loss))

                save_list(top_dir + 'init_g_pre', pre_loss_list)
                save_list(top_dir + 'init_g_now', now_loss_list)
                pre_loss_list = []
                now_loss_list = []

            # ---- Phase 2: discriminator pre-training (d_loss3) -----------
            if init_d and train_net:
                #fix lr_v
                print('init d')
                sess.run(tf.assign(lr_v, lr_init))

                for epoch in range(epoch_init + 1):
                    iters, data, padding_nums = batch_data(
                        t_target_image_data, image_padding_nums, max_idx,
                        batch_size)
                    for i in range(iters):
                        errD3, errD, _ = sess.run(
                            [d_loss3, d_loss, d_optim_init],
                            {t_image_target: data[i]})
                        total_d3_loss += errD3
                    print("[%d/%d] d_loss = %f, errD3 = %f" %
                          (epoch, epoch_init, errD, errD3))
                    ## save model
                    if (epoch != 0) and (epoch % 5 == 0):
                        tl.files.save_npz_dict(
                            net_ds[idx].all_params,
                            name=checkpoint_dir +
                            '/d_{}_init.npz'.format(tl.global_flag['mode']),
                            sess=sess)
                    if epoch % 10 == 0:
                        # Log raw D responses on the fixed first batch.
                        out, logits_real, logits_fake, logits_fake2 = sess.run(
                            [
                                net_g_test[idx].outputs,
                                tf.nn.sigmoid(logits_reals[idx]),
                                tf.nn.sigmoid(logits_fakes[idx]),
                                tf.nn.sigmoid(logits_fakes2[idx])
                            ], {
                                t_image_target:
                                t_target_image_data[0:batch_size]
                            })
                        print("logits_real", file=f)
                        print(logits_real, file=f)
                        print("logits_fake", file=f)
                        print(logits_fake, file=f)
                        print("logits_fake2", file=f)
                        print(logits_fake2, file=f)

                print("R %d total_d3_loss = %f" % (2**R, total_d3_loss))
                print("init g or d end", file=f)
            #train GAN
            # ---- Phase 3: full adversarial training ----------------------
            g_optim = g_optimizer_list[idx]
            d_optim = d_optimizer_list[idx]
            d_loss = d_loss_list[idx]
            g_loss = g_loss_list[idx]
            mse_loss = mse_loss_list[idx]
            g_gan_loss = g_gan_loss_list[idx]
            mix_rate, pic_rate = mix_rates[idx]

            # NOTE(review): mix_rate_vals is computed but never used; the
            # progressive fade-in it suggests is not wired up here.
            increas = 2. / epoch_init
            mix_rate_vals = np.arange(0., 1. + increas, increas)

            last_errD = 0.
            last_errG = 0.
            if train_net:
                for epoch in range(epoch_init + 1):
                    # LR step decay; at epoch 0 also reset LR and the
                    # level's mix/pic rate variables to 0.
                    if epoch != 0 and (epoch % decay_every == 0):
                        new_lr_decay = lr_decay**(epoch // decay_every)
                        sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
                    elif epoch == 0:
                        sess.run(tf.assign(lr_v, lr_init))
                        #mix_mat = np.zeros([t_image_list[idx].shape[i] for i in range(1, 4)], dtype = 'float32')
                        sess.run(tf.assign(mix_rate, 0))
                        sess.run(tf.assign(pic_rate, 0))

                    total_d_loss = 0
                    total_g_loss = 0
                    total_mse_loss = 0
                    iters, data, padding_nums = batch_data(
                        t_target_image_data, image_padding_nums, max_idx,
                        batch_size)
                    total_pre_loss = np.zeros([2])
                    total_now_loss = np.zeros([2])
                    for i in range(iters):
                        #update G
                        # In wgan mode G steps before D; otherwise D first
                        # then G (see the ``if not wgan`` branch below).
                        if wgan:
                            errG, errM, errA, _ = sess.run(
                                [g_loss, mse_loss, g_gan_loss, g_optim],
                                {t_image_target: data[i]})
                        #update D
                        if True:  #last_errG * 1e3 <= last_errD * 10: # D learning too fast
                            flag = 1
                            errD, _ = sess.run([d_loss, d_optim],
                                               {t_image_target: data[i]})
                        #print("[%d/%d] epoch %d times d_loss : %f" % (epoch, epoch_init, i, errD))
                        #update G
                        if not wgan:
                            #print("train G")
                            errG, errM, errA, _ = sess.run(
                                [g_loss, mse_loss, g_gan_loss, g_optim],
                                {t_image_target: data[i]})
                        #print("[%d/%d] epoch %d times, g_loss : %f, mse_loss : %f, g_gan_loss : %f"
                        #        % (epoch, epoch_init, i, errG, errM, errA))
                        #clip var_val
                        if wgan:
                            # Classic WGAN weight clipping after each D step.
                            _ = sess.run(w_clip_list[idx])
                        last_errD = errD
                        last_errG = errA
                        total_d_loss += errD
                        total_g_loss += errG
                        total_mse_loss += errM
                        if R == target_resolution_log2:  #final steps
                            # Same pre/now metric bookkeeping as in init-G.
                            lowR_pics, output_pics, GT_pics = sess.run(
                                [
                                    t_image_list[idx], net_g_test[idx].outputs,
                                    t_image_target_list[idx]
                                ], {t_image_target: data[i]})
                            pre_lowR_pics = clip_pics(lowR_pics,
                                                      padding_nums[i])
                            pre_output_pics = clip_pics(
                                output_pics, padding_nums[i])
                            pre_GT_pics = clip_pics(GT_pics, padding_nums[i])
                            for ii in range(data[i].shape[0]):
                                pre_loss = cal_loss(pre_lowR_pics[ii],
                                                    pre_GT_pics[ii])
                                now_loss = cal_loss(pre_output_pics[ii],
                                                    pre_GT_pics[ii])
                                total_pre_loss += pre_loss
                                total_now_loss += now_loss
                    pre_loss_list += [total_pre_loss / max_idx]
                    now_loss_list += [total_now_loss / max_idx]

                    print("lastD = %f, lastG = %f" % (last_errD, last_errG))
                    print("[%d/%d] epoch %d times d_loss : %f" %
                          (epoch, epoch_init, i, errD))
                    print(
                        "[%d/%d] epoch %d times, errM = %f, mse_loss : %f, g_gan_loss : %f"
                        % (epoch, epoch_init, i, errM, total_mse_loss, errA))

                    #save genate pic
                    if (epoch % strides == 0):
                        print("save img %d" % R)
                        out, logits_real, logits_fake, logits_fake2 = sess.run(
                            [
                                net_g_test[idx].outputs,
                                tf.nn.sigmoid(logits_reals[idx]),
                                tf.nn.sigmoid(logits_fakes[idx]),
                                tf.nn.sigmoid(logits_fakes2[idx])
                            ], {
                                t_image_target:
                                t_target_image_data[0:batch_size]
                            })
                        print(out[0])
                        # Clamp to displayable pixel range before saving.
                        out = out.clip(0, 255)
                        print(out.dtype)
                        tl.vis.save_images(
                            out, [ni, ni],
                            output_dir + "R_%d_train_%d.png" % (R, epoch))
                        #increase the mix_rate from 0 to 1 linearly
                        mix_rate_val = tf.nn.sigmoid(mix_rate).eval(
                            session=sess)
                        mix_pic_val = tf.nn.sigmoid(pic_rate).eval(
                            session=sess)
                        print("logits_real")
                        print(logits_real)
                        print("logits_fake")
                        print(logits_fake)
                        print("logits_fake2")
                        print(logits_fake2)
                        print("logits_real", file=f)
                        print(logits_real, file=f)
                        print("logits_fake", file=f)
                        print(logits_fake, file=f)
                        print("logits_fake2", file=f)
                        print(logits_fake2, file=f)
                        # D is "fooled" when it scores real and fake alike.
                        if (logits_real == logits_fake).all():
                            print("optimize well")
                            print("optimize well", file=f)
                        print("mix_rate, pic_rate")
                        print(mix_rate_val, mix_pic_val)
                        print("mix_rate, pic_rate", file=f)
                        print(mix_rate_val, mix_pic_val, file=f)
                    ## save model
                    if (epoch != 0) and (epoch % 10 == 0):
                        tl.files.save_npz_dict(
                            net_Gs[idx].all_params,
                            name=checkpoint_dir +
                            ('/g_%d_level_{}.npz' % R).format(
                                tl.global_flag['mode']),
                            sess=sess)
                        # NOTE(review): net_d is the *last* level's D (leaked
                        # from the construction loop), not net_ds[idx] —
                        # confirm which D this checkpoint is meant to hold.
                        tl.files.save_npz_dict(
                            net_d.all_params,
                            name=checkpoint_dir +
                            ('/d_%d_level_{}.npz' % R).format(
                                tl.global_flag['mode']),
                            sess=sess)
                save_list(top_dir + 'g_pre', pre_loss_list)
                save_list(top_dir + 'g_now', now_loss_list)
                pre_loss_list = []
                now_loss_list = []
                f.close()

                # ---- Phase 4 (disabled): blend-net training --------------
                blend_output = b_outputs[idx]
                b_loss = b_loss_list[idx]
                b_optim = b_optimizer_list[idx]
                # NOTE(review): ``if not True:`` makes this entire block
                # dead code; the live copy lives under ``init_b`` below.
                if not True:
                    #fix lr_v
                    sess.run(tf.assign(lr_v, lr_init))

                    for epoch in range(epoch_init * 3 + 1):
                        iters, data, padding_nums = batch_data(
                            t_target_image_data, image_padding_nums, max_idx,
                            batch_size)
                        for i in range(iters):
                            errM, _ = sess.run([b_loss, b_optim],
                                               {t_image_target: data[i]})
                            total_mse_loss += errM
                        print("[%d/%d] total_mse_loss = %f errM = %f" %
                              (epoch, epoch_init, total_mse_loss, errM))
                        ## save model
                        if (epoch % (strides * 3) == 0):
                            print("save img %d" % R)
                            out = sess.run(blend_output.outputs, {
                                t_image_target:
                                t_target_image_data[0:batch_size]
                            })
                            out = out.clip(0, 255)
                            #print(out[0])
                            print(out.dtype)
                            tl.vis.save_images(
                                out, [ni, ni],
                                output_dir + "b_%d_output_%d.png" % (R, epoch))
                            if epoch % 100 == 0:
                                tl.files.save_npz_dict(
                                    blend_output.all_params,
                                    name=checkpoint_dir +
                                    ('/b_%d_level_{}.npz' % R).format(
                                        tl.global_flag['mode']),
                                    sess=sess)

        logger.info("end training the net")

        # ---- Evaluation / blend training on the final level --------------
        # NOTE(review): everything below relies on R, idx, blend_output,
        # b_loss and b_optim *leaking* out of the loops above. If
        # train_net is False, b_loss/b_optim were never rebound in the
        # phase-4 block, so ``init_b`` here raises NameError unless the
        # graph-construction leftovers happen to refer to the last level.
        if not train_net or generate_pics:
            if init_b:
                # Live copy of the blend-net training loop (3x epochs).
                sess.run(tf.assign(lr_v, lr_init))
                for epoch in range(epoch_init * 3 + 1):
                    iters, data, padding_nums = batch_data(
                        t_target_image_data, image_padding_nums, max_idx,
                        batch_size)
                    for i in range(iters):
                        errM, _ = sess.run([b_loss, b_optim],
                                           {t_image_target: data[i]})
                        total_mse_loss += errM
                    print("[%d/%d] total_mse_loss = %f errM = %f" %
                          (epoch, epoch_init, total_mse_loss, errM))
                    ## save model
                    if (epoch % (strides * 3) == 0):
                        print("save img %d" % R)
                        out = sess.run(blend_output.outputs, {
                            t_image_target:
                            t_target_image_data[0:batch_size]
                        })
                        out = out.clip(0, 255)
                        #print(out[0])
                        print(out.dtype)
                        tl.vis.save_images(
                            out, [ni, ni],
                            output_dir + "b_%d_output_%d.png" % (R, epoch))
                        if epoch % 100 == 0:
                            tl.files.save_npz_dict(
                                blend_output.all_params,
                                name=checkpoint_dir +
                                ('/b_%d_level_{}.npz' % R).format(
                                    tl.global_flag['mode']),
                                sess=sess)

            logger.info("load params")

            # Restore the final-level G and blend nets before evaluation.
            tl.files.load_and_assign_npz_dict(sess=sess,
                                              name=checkpoint_dir1 +
                                              '/g_%d_level_my_gan.npz' % R,
                                              network=net_Gs[-1])
            tl.files.load_and_assign_npz_dict(sess=sess,
                                              name=checkpoint_dir1 +
                                              '/b_%d_level_my_gan.npz' % R,
                                              network=b_outputs[-1])

            logger.info("read pics")
            # Evaluate on Set5 and Set14: for each image save G output,
            # blend output, the degraded input, and the ground truth.
            test_set_dir = ["Set5/", "Set14/"]
            test_no = [5, 13]
            for j in range(2):
                data_pxl, pic_pad_nums = read_data(svg_dir + test_set_dir[j],
                                                   num=test_no[j])
                iters = data_pxl.shape[0]
                # One image per "batch" (split along axis 0).
                data_pxl = np.split(data_pxl, iters)
                #iters, data = batch_data((t_image_data, t_target_image_data), 100, batch_size)
                logger.info('start evaluating pics')
                for i in range(iters):
                    print("save img %d" % R)
                    out = sess.run(net_g_test[idx].outputs,
                                   {t_image_target: data_pxl[i]})
                    out = out.clip(0, 255)
                    out = np.array([clip_pic(out[0], pic_pad_nums[i])])
                    tl.vis.save_images(
                        out, [1, 1], test_output_dir + test_set_dir[j] +
                        "g_%d_output_%d.png" % (R, i))
                    out = sess.run(b_outputs[idx].outputs,
                                   {t_image_target: data_pxl[i]})
                    out = out.clip(0, 255)
                    out = np.array([clip_pic(out[0], pic_pad_nums[i])])
                    tl.vis.save_images(
                        out, [1, 1], test_output_dir + test_set_dir[j] +
                        "b_%d_output_%d.png" % (R, i))
                    out = sess.run(t_image, {t_image_target: data_pxl[i]})
                    out = np.array([clip_pic(out[0], pic_pad_nums[i])])
                    tl.vis.save_images(
                        out, [1, 1], test_output_dir + test_set_dir[j] +
                        "svg_%d_%d.png" % (R, i))
                    out = sess.run(t_image_target,
                                   {t_image_target: data_pxl[i]})
                    out = np.array([clip_pic(out[0], pic_pad_nums[i])])
                    tl.vis.save_images(
                        out, [1, 1], test_output_dir + test_set_dir[j] +
                        "pxl_%d_%d.png" % (R, i))
            logger.info('end evaluating pics')
def train():
    """Train the SRResNet generator (SRGAN_g) alone with MSE + VGG losses.

    Reads grayscale HR training images from ``config.TRAIN.hr_img_path``,
    crops 224x224 sub-images, downsamples them to 56x56 LR inputs, and
    optimizes the generator to reconstruct the HR patch.  Per-epoch mean
    losses and the learning rate are written to TensorBoard, and sample
    outputs / npz checkpoints are saved periodically.

    Relies on module-level globals defined elsewhere in this file:
    ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch``, ``decay_every``,
    ``lr_decay``, ``ni``, plus the model builders ``SRGAN_g`` and
    ``Vgg19_simple_api`` and the preprocessing helpers
    ``crop_sub_imgs_fn`` / ``downsample_fn``.
    """
    ## create folders to save result images and trained model
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])  #srresnet
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    # NOTE(review): only the HR training list is actually consumed below;
    # the LR/VALID lists are kept for parity with the other train()
    # variants in this file.
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## Pre-load the whole train set as single-channel (H, W, 1) arrays.
    print("reading images")

    train_hr_imgs = []

    for img__ in train_hr_img_list:
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path,
                                                      img__),
                                         mode='L')
        # Add an explicit channel axis so shapes match the placeholders.
        image_loaded = image_loaded.reshape(
            (image_loaded.shape[0], image_loaded.shape[1], 1))
        train_hr_imgs.append(image_loaded)

    print(type(train_hr_imgs), len(train_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    ## train inference: 56x56 LR input -> 224x224 HR target (grayscale)
    t_image = tf.placeholder('float32', [batch_size, 56, 56, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 224, 224, 1],
                                    name='t_target_image')

    print("t_image:", tf.shape(t_image))
    print("t_target_image:", tf.shape(t_target_image))

    net_g = SRGAN_g(t_image, is_train=True,
                    reuse=False)  #SRGAN_g is the SRResNet portion of the GAN

    print("net_g.outputs:", tf.shape(net_g.outputs))

    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. method=0 is BILINEAR (0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA)
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)

    print("net_g.outputs:", tf.shape(net_g.outputs))
    print("t_predict_image_224:", tf.shape(t_predict_image_224))

    # (x + 1) / 2 remaps images from [-1, 1] into [0, 1] for the VGG net.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference (shares weights with the training generator)
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss

    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## SRResNet optimizer
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir +
                                 '/g_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_g)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    params = []
    # npz maps layer name -> [W, b]; sorted() yields the deterministic
    # layer order expected by tl.files.assign_params.
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]

    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)

    sample_imgs_224 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_224.shape, sample_imgs_224.min(),
          sample_imgs_224.max())
    sample_imgs_56 = tl.prepro.threading_data(sample_imgs_224,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_56.shape, sample_imgs_56.min(),
          sample_imgs_56.max())
    tl.vis.save_images(sample_imgs_56, [ni, ni],
                       save_dir_gan + '/_train_sample_56.png')
    tl.vis.save_images(sample_imgs_224, [ni, ni],
                       save_dir_gan + '/_train_sample_224.png')

    ###========================= train SRResNet  =========================###

    merged_summary_generator = tf.summary.merge(
        [mse_loss_summary, vgg_loss_summary,
         g_loss_summary])
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")

    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    for epoch in range(0, n_epoch + 1):
        ## update (or, on the first epoch, initialize) the learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=(lr_init * new_lr_decay)),
                ]), (epoch))

        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=lr_init),
                ]), (epoch))

        epoch_time = time.time()
        total_g_loss, n_iter = 0, 0

        # Per-step scalar values collected here and averaged per epoch.
        mse_loss_summary_per_epoch = []
        vgg_loss_summary_per_epoch = []
        g_loss_summary_per_epoch = []

        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_224 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_56 = tl.prepro.threading_data(b_imgs_224, fn=downsample_fn)

            ## update G
            errG, errM, errV, _, generator_summary = sess.run(
                [
                    g_loss, mse_loss, vgg_loss, g_optim,
                    merged_summary_generator
                ], {
                    t_image: b_imgs_56,
                    t_target_image: b_imgs_224
                })

            # Decode the serialized summary so the scalar values can be
            # averaged per epoch.  tf.summary sanitizes tag names
            # (spaces -> underscores), hence the underscore keys below.
            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(generator_summary)

            generator_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                generator_summaries[val.tag] = val.simple_value

            mse_loss_summary_per_epoch.append(
                generator_summaries['Generator_MSE_loss'])
            vgg_loss_summary_per_epoch.append(
                generator_summaries['Generator_VGG_loss'])
            g_loss_summary_per_epoch.append(
                generator_summaries['Generator_total_loss'])

            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errG, errM,
                   errV))
            total_g_loss += errG
            n_iter += 1

        # BUG FIX: was `time.time() - epoch_time / n_iter`, which (by
        # operator precedence) divided the start *timestamp* by n_iter
        # instead of reporting the epoch duration.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time,
            total_g_loss / n_iter)
        print(log)

        #####
        #
        # logging per-epoch mean generator losses
        #
        ######

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_MSE_loss per epoch",
                                 simple_value=np.mean(
                                     mse_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_VGG_loss per epoch",
                                 simple_value=np.mean(
                                     vgg_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_total_loss per epoch",
                                 simple_value=np.mean(
                                     g_loss_summary_per_epoch)),
            ]), (epoch))

        ## quick evaluation on the fixed sample batch
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_56})
        print("[*] save images")
        tl.vis.save_image(out[0], save_dir_gan + '/train_%d.png' % epoch)

        ## save model every 3 epochs
        if (epoch != 0) and (epoch % 3 == 0):

            tl.files.save_npz(
                net_g.all_params,
                name=checkpoint_dir +
                '/g_{}_{}.npz'.format(tl.global_flag['mode'], epoch),
                sess=sess)
Exemplo n.º 6
0
def train():
    """Pretrain the ``unet`` generator on paired (input, label) images.

    The input folder variant is selected by ``config.TRAIN.input_type``
    ("quan" or "clip").  The objective is MSE and/or a VGG19 perceptual
    loss, controlled by the module-level ``with_mse`` / ``with_vgg``
    flags.  Sample outputs and an npz checkpoint are saved every 10
    epochs, and the per-epoch mean loss is appended to ``loss.txt``.

    Relies on module-level globals defined elsewhere in this file:
    ``batch_size``, ``sub_img_size``, ``lr_init``, ``lr_decay``,
    ``beta1``, ``n_epoch_init``, ``ni``, ``with_mse``, ``with_vgg``,
    plus ``unet``, ``Vgg19_simple_api``, ``crop_sub_imgs`` and
    ``list_sub_imgs``.
    """
    import os, time
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)
    num_sub_imgs = config.Size.num_sub_imgs

    # Select the input-image folder variant; unknown types abort early.
    input_img_path = config.TRAIN.input_img_path
    if config.TRAIN.input_type == "quan":
        input_img_path = config.TRAIN.input_img_path + "_quan"
    elif config.TRAIN.input_type == "clip":
        input_img_path = config.TRAIN.input_img_path + "_clip"
    else:
        print("input_type error")
        return

    label_img_path = config.TRAIN.label_img_path

    # BUG FIX (fail fast): if both flags are off, `mse_loss` was never
    # defined and the optimizer definition died with a NameError.
    if not with_mse and not with_vgg:
        raise ValueError(
            "train(): at least one of with_mse / with_vgg must be enabled")

    ###====================== PRE-LOAD DATA ===========================###
    train_input_img_list = sorted(
        tl.files.load_file_list(path=input_img_path,
                                regx='.*.png',
                                printable=False))
    train_label_img_list = sorted(
        tl.files.load_file_list(path=label_img_path,
                                regx='.*.png',
                                printable=False))
    print('train_input_img_list : ', train_input_img_list)
    print('train_label_img_list : ', train_label_img_list)

    ## If your machine have enough memory, please pre-load the whole train set.
    train_input_imgs = tl.vis.read_images(train_input_img_list,
                                          path=input_img_path,
                                          n_threads=32)
    train_label_imgs = tl.vis.read_images(train_label_img_list,
                                          path=label_img_path,
                                          n_threads=32)

    ###========================== DEFINE MODEL ============================###
    ## train inference: each batch element contributes num_sub_imgs crops
    t_image = tf.placeholder(
        'float32', [batch_size * num_sub_imgs, sub_img_size, sub_img_size, 3],
        name='t_image_input_to_generator')
    t_target_image = tf.placeholder(
        'float32', [batch_size * num_sub_imgs, sub_img_size, sub_img_size, 3],
        name='t_target_image')

    net_g = unet(t_image, reuse=False)

    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. method=0 is BILINEAR (0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA)
    if with_vgg:
        t_target_image_224 = tf.image.resize_images(
            t_target_image, size=[224, 224], method=0, align_corners=False
        )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
        t_predict_image_224 = tf.image.resize_images(
            net_g.outputs, size=[224, 224], method=0,
            align_corners=False)  # resize_generate_image_for_vgg

        # (x + 1) / 2 remaps [-1, 1] images into [0, 1] for the VGG net.
        net_vgg, vgg_target_emb = Vgg19_simple_api(
            (t_target_image_224 + 1) / 2, reuse=False)
        _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                              reuse=True)

    ## test inference (shares weights with the training generator)
    net_g_test = unet(t_image, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================##
    # `mse_loss` accumulates every enabled loss term (guaranteed defined
    # by the with_mse/with_vgg check above).
    if with_mse:
        mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                              t_target_image,
                                              is_mean=True)
        if with_vgg:
            vgg_loss = 5e-6 * tl.cost.mean_squared_error(
                vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
            mse_loss = mse_loss + vgg_loss
    else:
        if with_vgg:
            mse_loss = 2e-6 * tl.cost.mean_squared_error(
                vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    # NOTE(review): these scope names come from the SRGAN variants in this
    # file; confirm unet's variables actually live under 'SRGAN_g',
    # otherwise g_vars is empty and minimize() below fails.
    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/checkpoint.npz',
                                 network=net_g)

    ###============================= LOAD VGG ===============================###
    # BUG FIX: the vgg19.npy existence check and np.load previously ran
    # unconditionally, so training aborted via exit() even when with_vgg
    # was False and the weights were never used.
    if with_vgg:
        vgg19_npy_path = "vgg19.npy"

        if not os.path.isfile(vgg19_npy_path):
            print(
                "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
            )
            exit()

        npz = np.load(vgg19_npy_path, encoding='latin1').item()

        params = []
        # npz maps layer name -> [W, b]; sorted() yields the layer order
        # expected by tl.files.assign_params.
        for val in sorted(npz.items()):
            W = np.asarray(val[1][0])
            b = np.asarray(val[1][1])
            print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
            params.extend([W, b])
        tl.files.assign_params(sess, params, net_vgg)
        net_vgg.print_params(False)
        net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    _sample_imgs_input = train_input_imgs[0:batch_size]
    _sample_imgs_label = train_label_imgs[0:batch_size]
    sample_imgs_input, sample_imgs_label = crop_sub_imgs(
        _sample_imgs_input, _sample_imgs_label)

    tl.vis.save_images(sample_imgs_input, [ni, ni],
                       save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_label, [ni, ni],
                       save_dir_ginit + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    lr_g = config.TRAIN.lr_g  # NOTE(review): read but never used below

    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)

    # truncate the loss log left over from any previous run
    with open("loss.txt", 'w') as f:
        f.write("")
    decay_g = config.TRAIN.decay_g
    init_time = time.time()
    for epoch in range(0, n_epoch_init + 1):
        ## decay the learning rate every decay_g epochs
        if epoch != 0 and (epoch % decay_g == 0):
            new_lr_decay = lr_decay**(epoch // decay_g)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        for idx in range(0, len(train_input_imgs), batch_size):
            step_time = time.time()
            b_imgs_input, b_imgs_label = list_sub_imgs(
                train_input_imgs[idx:idx + batch_size],
                train_label_imgs[idx:idx + batch_size])
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_input,
                t_target_image: b_imgs_label
            })
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / n_iter)
        print(log)
        with open("loss.txt", 'a') as f:
            f.write(str(total_mse_loss / n_iter) + "\n")

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_input
            })  # ; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_ginit + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir +
                              '/g_{}_init.npz'.format(tl.global_flag['mode']),
                              sess=sess)

    # BUG FIX: was `init_time - time.time()`, which always printed a
    # negative duration; the operands are now in the right order.
    print("init complete %dsec" % (time.time() - init_time))
Exemplo n.º 7
0
def train():
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.

    print("reading images")
    train_hr_imgs = []  #[None] * len(train_hr_img_list)

    #sess = tf.Session()
    for img__ in train_hr_img_list:

        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path,
                                                      img__),
                                         mode='L')
        image_loaded = image_loaded.reshape(
            (image_loaded.shape[0], image_loaded.shape[1], 1))

        train_hr_imgs.append(image_loaded)

    print(type(train_hr_imgs), len(train_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    ## train inference

    #t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    #t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    t_image = tf.placeholder('float32', [batch_size, 28, 224, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(
        'float32', [batch_size, 224, 224, 1], name='t_target_image'
    )  # may have to convert 224x224x1 into 224x224x3, with channel 1 & 2 as 0. May have to have separate place-holder ?

    print("t_image:", tf.shape(t_image))
    print("t_target_image:", tf.shape(t_target_image))

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    print("net_g.outputs:", tf.shape(net_g.outputs))

    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)

    print("net_g.outputs:", tf.shape(net_g.outputs))
    print("t_predict_image_224:", tf.shape(t_predict_image_224))

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    d_loss1_summary = tf.summary.scalar('Disciminator logits_real loss',
                                        d_loss1)
    d_loss2_summary = tf.summary.scalar('Disciminator logits_fake loss',
                                        d_loss2)
    d_loss_summary = tf.summary.scalar('Disciminator total loss', d_loss)

    g_gan_loss_summary = tf.summary.scalar('Generator GAN loss', g_gan_loss)
    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    #	UNCOMMENT THE LINE BELOW!!!
    #g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    #if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False:
    #   tl.fites.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)
    #tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set

    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)

    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn_mod)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    #tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    #tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    #tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    #tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')
    '''
    ###========================= initialize G ====================###
    
    merged_summary_initial_G = tf.summary.merge([mse_loss_summary])
    summary_intial_G_writer = tf.summary.FileWriter("./log/train/initial_G")
    
    

    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    count = 0
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        
        
        intial_MSE_G_summary_per_epoch = []
        
        
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn_mod)
            ## update G
            errM, _, mse_summary_initial_G = sess.run([mse_loss, g_optim_init, merged_summary_initial_G], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))

            
            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(mse_summary_initial_G)
            
            intial_G_summaries = {}
            for val in summary_pb.value:
            # Assuming all summaries are scalars.
                intial_G_summaries[val.tag] = val.simple_value
            #print("intial_G_summaries:", intial_G_summaries)
            
            
            intial_MSE_G_summary_per_epoch.append(intial_G_summaries['Generator_MSE_loss'])
            
            
            #summary_intial_G_writer.add_summary(mse_summary_initial_G, (count + 1)) #(epoch + 1)*(n_iter+1))
            #count += 1


            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        
        summary_intial_G_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag="Generator_Initial_MSE_loss per epoch", simple_value=np.mean(intial_MSE_G_summary_per_epoch)),]), (epoch))


        ## quick evaluation on train set
        #if (epoch != 0) and (epoch % 10 == 0):
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
        print("[*] save images")
        for im in range(len(out)):
            if(im%4==0 or im==1197):
                tl.vis.save_image(out[im], save_dir_ginit + '/train_%d_%d.png' % (epoch,im))

        ## save model
        saver=tf.train.Saver()
        if (epoch%10==0 and epoch!=0):
            saver.save(sess, 'checkpoint/init_'+str(epoch)+'.ckpt')      

   #if (epoch != 0) and (epoch % 10 == 0):
        #tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_{}_init.npz'.format(tl.global_flag['mode'], epoch), sess=sess)
    '''
    ###========================= train GAN (SRGAN) =========================###
    saver = tf.train.Saver()
    saver.restore(sess, 'checkpoint/main_10.ckpt')
    print('Restored main_10, begin 11/50')
    merged_summary_discriminator = tf.summary.merge(
        [d_loss1_summary, d_loss2_summary, d_loss_summary])
    summary_discriminator_writer = tf.summary.FileWriter(
        "./log/train/discriminator")

    merged_summary_generator = tf.summary.merge([
        g_gan_loss_summary, mse_loss_summary, vgg_loss_summary, g_loss_summary
    ])
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")

    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    count = 0
    for epoch in range(11, n_epoch + 11):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=(lr_init * new_lr_decay)),
                ]), (epoch))

        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=lr_init),
                ]), (epoch))

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.

        loss_per_batch = []

        d_loss1_summary_per_epoch = []
        d_loss2_summary_per_epoch = []
        d_loss_summary_per_epoch = []

        g_gan_loss_summary_per_epoch = []
        mse_loss_summary_per_epoch = []
        vgg_loss_summary_per_epoch = []
        g_loss_summary_per_epoch = []

        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384,
                                                 fn=downsample_fn_mod)
            ## update D
            errD, _, discriminator_summary = sess.run(
                [d_loss, d_optim, merged_summary_discriminator], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })

            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(discriminator_summary)
            #print("discriminator_summary", summary_pb, type(summary_pb))

            discriminator_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                discriminator_summaries[val.tag] = val.simple_value

            d_loss1_summary_per_epoch.append(
                discriminator_summaries['Disciminator_logits_real_loss'])
            d_loss2_summary_per_epoch.append(
                discriminator_summaries['Disciminator_logits_fake_loss'])
            d_loss_summary_per_epoch.append(
                discriminator_summaries['Disciminator_total_loss'])

            ## update G
            errG, errM, errV, errA, _, generator_summary = sess.run(
                [
                    g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim,
                    merged_summary_generator
                ], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })

            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(generator_summary)
            #print("generator_summary", summary_pb, type(summary_pb))

            generator_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                generator_summaries[val.tag] = val.simple_value

            #print("generator_summaries:", generator_summaries)

            g_gan_loss_summary_per_epoch.append(
                generator_summaries['Generator_GAN_loss'])
            mse_loss_summary_per_epoch.append(
                generator_summaries['Generator_MSE_loss'])
            vgg_loss_summary_per_epoch.append(
                generator_summaries['Generator_VGG_loss'])
            g_loss_summary_per_epoch.append(
                generator_summaries['Generator_total_loss'])

            #summary_generator_writer.add_summary(generator_summary, (count + 1))

            #summary_total = sess.run(summary_total_merged, {t_image: b_imgs_96, t_target_image: b_imgs_384})
            #summary_total_merged_writer.add_summary(summary_total, (count + 1))

            #count += 1

            tot_epoch = n_epoch + 10
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                % (epoch, tot_epoch, n_iter, time.time() - step_time, errD,
                   errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1
            #remove this for normal running:

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, tot_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)

        #####
        #
        # logging discriminator summary
        #
        ######

        # logging per epcoch summary of logit_real_loss per epoch. Value logged is averaged across batches used per epoch.
        summary_discriminator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Disciminator_logits_real_loss per epoch",
                                 simple_value=np.mean(
                                     d_loss1_summary_per_epoch)),
            ]), (epoch))

        # logging per epcoch summary of logit_fake_loss per epoch. Value logged is averaged across batches used per epoch.
        summary_discriminator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Disciminator_logits_fake_loss per epoch",
                                 simple_value=np.mean(
                                     d_loss2_summary_per_epoch)),
            ]), (epoch))

        # logging per epcoch summary of total_loss per epoch. Value logged is averaged across batches used per epoch.
        summary_discriminator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Disciminator_total_loss per epoch",
                                 simple_value=np.mean(
                                     d_loss_summary_per_epoch)),
            ]), (epoch))

        #####
        #
        # logging generator summary
        #
        ######

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_GAN_loss per epoch",
                                 simple_value=np.mean(
                                     g_gan_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_MSE_loss per epoch",
                                 simple_value=np.mean(
                                     mse_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_VGG_loss per epoch",
                                 simple_value=np.mean(
                                     vgg_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_total_loss per epoch",
                                 simple_value=np.mean(
                                     g_loss_summary_per_epoch)),
            ]), (epoch))

        ## quick evaluation on train set
        #if (epoch != 0) and (epoch % 10 == 0):
        out = sess.run(
            net_g_test.outputs,
            {t_image: sample_imgs_96
             })  #; print('gen sub-image:', out.shape, out.min(), out.max())
        ## save model
        if (epoch % 10 == 0 and epoch != 0):
            saver.save(sess, 'checkpoint/main_' + str(epoch) + '.ckpt')

            print("[*] save images")
            for im in range(len(out)):
                tl.vis.save_image(
                    out[im], save_dir_gan + '/train_%d_%d.png' % (epoch, im))
Exemplo n.º 8
0
def train_distil():
    """Train the student SRGAN by knowledge distillation from a teacher.

    Builds the student generator/discriminator plus auxiliary "predict"
    networks that map student feature maps onto the teacher's, then jointly
    optimizes the standard SRGAN losses (pixel MSE + VGG perceptual +
    adversarial) and the distillation MSE terms (g0/d1/d2).  Periodically
    saves npz checkpoints, loss curves (via write_losses) and sample images.

    Relies on module-level globals: batch_size, lr_init, lr_decay,
    decay_every, beta1, n_epoch, ni, small_teacher, train_all_nine,
    config, and the SRGAN_* network constructors.  No parameters; no
    return value (side effects only: files under samples/ and checkpoint/).
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/student_ginit"
    save_dir_gan = "samples/student_gan"
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    # Per-epoch averaged loss histories, flushed to disk via write_losses().
    d_losses, g_losses, m_losses, v_losses, a_losses = [], [], [], [], []
    g0losses, d1losses, d2losses = [], [], []
    ###====================== PRE-LOAD IMAGE DATA ===========================###

    print("loading images")
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))

    # train_hr_img_list = train_hr_img_list[0:16]
    # train_lr_img_list = train_lr_img_list[0:16]

    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)
    train_lr_imgs = tl.vis.read_images(train_lr_img_list,
                                       path=config.TRAIN.lr_img_path,
                                       n_threads=32)

    print("images loaded")

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3],
                             name='t_image')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3],
                                    name='t_target_image')
    # t_distilled_d = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')
    # t_distilled_g = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    #nets
    net_g_student, net_g_student_distil = SRGAN_g_student(t_image,
                                                          is_train=True,
                                                          reuse=False)
    net_d_student, logits_real_student, net_d_student_distil = SRGAN_d_student(
        t_target_image, is_train=True, reuse=False)
    net_d_student_fake, logits_fake_student, net_d_student_distil_fake = SRGAN_d_student(
        net_g_student.outputs, is_train=True, reuse=True)

    # FIX: was `small_techer` (undefined name / typo) — the checkpoint
    # restore and save paths below both test `small_teacher`.
    if small_teacher is True and train_all_nine is False:
        net_g_teacher_distil = SRGAN_g_teacher_small(t_image,
                                                     is_train=False,
                                                     reuse=False)
        net_d_teacher_distil = SRGAN_d_teacher_small(t_target_image,
                                                     is_train=False,
                                                     reuse=False)
        net_d_teacher_distil_fake = SRGAN_d_teacher_small(
            net_g_student.outputs, is_train=False, reuse=True)
    else:
        net_g_teacher, net_g_teacher_distil = SRGAN_g_teacher(t_image,
                                                              is_train=False,
                                                              reuse=False)
        net_d_teacher, _, net_d_teacher_distil = SRGAN_d_teacher(
            t_target_image, is_train=False, reuse=False)
        net_d_teacher_fake, _, net_d_teacher_distil_fake = SRGAN_d_teacher(
            net_g_student.outputs, is_train=False, reuse=True)

    # "Predict" heads project student intermediate activations into the
    # teacher's feature space so they can be matched with an MSE loss.
    if train_all_nine is not True:
        net_g0_predict, _ = SRGAN_g0_predict(net_g_student_distil.outputs,
                                             is_train=True,
                                             reuse=False)
        net_d1_predict, _ = SRGAN_d1_predict(net_d_student_distil.outputs,
                                             is_train=True,
                                             reuse=False)
        net_d2_predict, _ = SRGAN_d2_predict(net_d_student_distil_fake.outputs,
                                             is_train=True,
                                             reuse=False)
    else:
        # All nine cross mappings: each student feature (g0/d1/d2 source)
        # gets a head towards each teacher feature target.
        net_g0_predict, _ = SRGAN_g0_predict(net_g_student_distil.outputs,
                                             is_train=True,
                                             reuse=False)
        net_d1_predict, _ = SRGAN_d1_predict(net_d_student_distil.outputs,
                                             is_train=True,
                                             reuse=False)
        net_d2_predict, _ = SRGAN_d2_predict(net_d_student_distil_fake.outputs,
                                             is_train=True,
                                             reuse=False)

        net_d1d2_predict, _ = SRGAN_d1d2_predict(net_d_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)
        net_d2d1_predict, _ = SRGAN_d2d1_predict(
            net_d_student_distil_fake.outputs, is_train=True, reuse=False)

        net_g0d1_predict, _ = SRGAN_g0d1_predict(net_g_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)
        net_g0d2_predict, _ = SRGAN_g0d2_predict(net_g_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)

        net_d1g0_predict, _ = SRGAN_d1g0_predict(net_d_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)
        net_d2g0_predict, _ = SRGAN_d2g0_predict(net_d_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)

    net_g_student.print_params(False)
    net_g_student.print_layers()
    net_d_student.print_params(False)
    net_d_student.print_layers()
    # net_g_student_distil_fake.print_params(False)
    # net_g_student_distil_fake.print_layers()
    # net_d_student_distil_fake.print_params(False)
    # net_d_student_distil_fake.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g_student.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    # VGG expects inputs in [0, 1]; generator/targets are in [-1, 1].
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)
    ## test inference
    # NOTE(review): is_train=True for the eval net looks intentional here
    # (matches reuse of the training graph) but is unusual — confirm that
    # batch-norm in train mode is wanted for the sample images.
    net_g_test, _ = SRGAN_g_student(t_image, is_train=True, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real_student,
                                            tf.ones_like(logits_real_student),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake_student,
                                            tf.zeros_like(logits_fake_student),
                                            name='d2')
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake_student, tf.ones_like(logits_fake_student), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g_student.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    # Distillation losses: student "predict" heads vs. frozen teacher
    # features.  In the nine-head case each target is matched against the
    # average of the three heads that map onto it.
    if train_all_nine is not True:
        g0_loss = 4e-3 * tl.cost.mean_squared_error(
            net_g0_predict.outputs, net_g_teacher_distil.outputs, is_mean=True)
        d1_loss = 1 / 5 * tl.cost.mean_squared_error(
            net_d1_predict.outputs, net_d_teacher_distil.outputs, is_mean=True)
        d2_loss = 1 / 5 * tl.cost.mean_squared_error(
            net_d2_predict.outputs,
            net_d_teacher_distil_fake.outputs,
            is_mean=True)
    else:
        g0_loss = 4e-3 * tl.cost.mean_squared_error(
            (net_g0_predict.outputs + net_d1g0_predict.outputs +
             net_d2g0_predict.outputs) / 3,
            net_g_teacher_distil.outputs,
            is_mean=True)
        d1_loss = 1 / 5 * tl.cost.mean_squared_error(
            (net_g0d1_predict.outputs + net_d1_predict.outputs +
             net_d2d1_predict.outputs) / 3,
            net_d_teacher_distil.outputs,
            is_mean=True)
        d2_loss = 1 / 5 * tl.cost.mean_squared_error(
            (net_g0d2_predict.outputs + net_d1d2_predict.outputs +
             net_d2_predict.outputs) / 3,
            net_d_teacher_distil_fake.outputs,
            is_mean=True)

    d_loss = d_loss1 + d_loss2 + d1_loss + d2_loss
    g_loss = mse_loss + vgg_loss + g_gan_loss + g0_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g_student', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d_student', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    # ## Pretrain
    # g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)

    # Best-effort restore of previous student weights; training starts
    # from the fresh initialization when no checkpoint exists.
    if tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir + '/g_srgan_student.npz',
            network=net_g_student) is False:
        pass
        # tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan_init.npz', network=net_g_student)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/d_srgan_student.npz',
                                 network=net_d_student)

    if small_teacher is True:
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir +
                                     '/g_small_teacher_bicube.npz',
                                     network=net_g_teacher_distil)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir +
                                     '/d_small_teacher_bicube.npz',
                                     network=net_d_teacher_distil)
    else:
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir +
                                     '/g_srgan_teacher.npz',
                                     network=net_g_teacher)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir +
                                     '/d_srgan_teacher.npz',
                                     network=net_d_teacher)
    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    # allow_pickle=True is required on NumPy >= 1.16.3 to load this
    # pickled dict-of-arrays checkpoint (behavior unchanged on older NumPy).
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= PreSample ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    # sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = train_lr_imgs[0:batch_size]
    # # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
    sample_imgs_384, sample_imgs_96 = threading_data_2(
        (train_hr_imgs[0:batch_size], train_lr_imgs[0:batch_size]),
        fn=crop2,
        is_random=True)
    # sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=True)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    # sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_gan + '/_train_sample_384.png')
    ###========================= train GAN (SRGAN) =========================###
    print("starting")
    for epoch in range(0, n_epoch + 1):
        ## update learning rate: step decay every `decay_every` epochs
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        total_errM, total_errV, total_errA = 0, 0, 0
        total_erg0, total_erd1, total_erd2 = 0, 0, 0

        ## Actual training
        for idx in range(0, len(train_hr_imgs), batch_size):
            #using 4x lowresolution images
            b_imgs_384, b_imgs_96 = threading_data_2(
                (train_hr_imgs[idx:idx + batch_size],
                 train_lr_imgs[idx:idx + batch_size]),
                fn=crop2,
                is_random=True)
            #for 4x high resolution images
            # if n_iter==0 and epoch==0: tl.vis.save_images(b_imgs_384, [ni, ni], save_dir_gan + '/original_train_384_%d.png' % epoch)
            # if n_iter==0 and epoch==0: tl.vis.save_images(b_imgs_96, [ni, ni], save_dir_gan + '/original_train_96_%d.png' % epoch)
            step_time = time.time()
            ## update D (distillation losses fetched here for logging only)
            errD, erg0, erd1, erd2, _ = sess.run(
                [d_loss, g0_loss, d1_loss, d2_loss, d_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })

            total_d_loss += errD
            total_g_loss += errG
            total_errM += errM
            total_errV += errV
            total_errA += errA
            total_erg0 += erg0
            total_erd1 += erd1
            total_erd2 += erd2
            n_iter += 1
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)(erg0: %.6f erd1: %.6f erd2: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG,
                   errM, errV, errA, erg0, erd1, erd2))
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)

        ##write losses (per-epoch batch averages)
        d_losses += [total_d_loss / n_iter]
        g_losses += [total_g_loss / n_iter]
        m_losses += [total_errM / n_iter]
        v_losses += [total_errV / n_iter]
        a_losses += [total_errA / n_iter]
        g0losses += [total_erg0 / n_iter]
        d1losses += [total_erd1 / n_iter]
        d2losses += [total_erd2 / n_iter]

        # Every 20 epochs: snapshot epoch-tagged weights and dump loss curves.
        if epoch % 20 == 0:
            tl.files.save_npz(net_g_student.all_params,
                              name=checkpoint_dir +
                              '/g_srgan_student_%d.npz' % epoch,
                              sess=sess)
            tl.files.save_npz(net_d_student.all_params,
                              name=checkpoint_dir +
                              '/d_srgan_student_%d.npz' % epoch,
                              sess=sess)
            write_losses("d_losses", d_losses)
            write_losses("g_losses", g_losses)
            write_losses("m_losses", m_losses)
            write_losses("v_losses", v_losses)
            write_losses("a_losses", a_losses)
            write_losses("g0losses", g0losses)
            write_losses("d1losses", d1losses)
            write_losses("d2losses", d2losses)

        # Every 10 epochs: refresh the "latest" checkpoints and run evaluate().
        if epoch % 10 == 0:
            # out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
            tl.files.save_npz(net_g_student.all_params,
                              name=checkpoint_dir + '/g_srgan_student.npz',
                              sess=sess)
            tl.files.save_npz(net_d_student.all_params,
                              name=checkpoint_dir + '/d_srgan_student.npz',
                              sess=sess)
            if small_teacher is not True:
                tl.files.save_npz(net_g_teacher_distil.all_params,
                                  name=checkpoint_dir +
                                  '/g_small_teacher_bicube.npz',
                                  sess=sess)
                tl.files.save_npz(net_d_teacher_distil.all_params,
                                  name=checkpoint_dir +
                                  '/d_small_teacher_bicube.npz',
                                  sess=sess)
            evaluate()

        ## quick evaluation on train set
        if (epoch % 5 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_96
            })  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_gan + '/train_%d.png' % epoch)
Exemplo n.º 9
0
def train():
    """Pre-train the UMSR generator (no adversarial stage in this function).

    Optimizes MSE + VGG-feature + Gram-matrix losses on 30x30 -> 120x120
    crops (4x super-resolution).  Relies on module-level configuration:
    ``config``, ``batch_size``, ``lr_init``, ``beta1``, ``beta2``,
    ``n_epoch_init``, ``decay_every_init``, ``lr_decay_init``, ``ni``,
    ``save_dir_gan`` and ``tl.global_flag['mode']``.

    Side effects: creates sample/checkpoint folders, writes sample images,
    appends average losses to ``testG1.text`` and saves ``.npz`` checkpoints.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ## pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 30, 30, 3], name='t_image_input_to_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 120, 120, 3], name='t_target_image')

    net_g = UMSR_g(t_image, is_train=True, reuse=False)

    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    # VGG19 expects 224x224 inputs, so both target and prediction are resized.
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0,
        align_corners=False)  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

    # Images are rescaled from [-1, 1] to [0, 1] before entering VGG19.
    net_vgg, vgg_target_emb1, vgg_target_emb2, vgg_target_emb3, vgg_target_emb4, vgg_target_emb5 = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb1, vgg_predict_emb2, vgg_predict_emb3, vgg_predict_emb4, vgg_predict_emb5 = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = UMSR_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###

    # Pixel loss + deep VGG feature loss + style (Gram matrix) losses on two
    # shallower VGG feature maps.
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb5.outputs, vgg_target_emb5.outputs, is_mean=True)
    gram_loss1 = 1e-6 * gram_scale_loss1(vgg_target_emb1.outputs, vgg_predict_emb1.outputs)
    gram_loss2 = 1e-6 * gram_scale_loss2(vgg_target_emb3.outputs, vgg_predict_emb3.outputs)
    gram_loss = gram_loss1 + gram_loss2
    #tf.summary.scalar('loss', mse_loss)
    g1_loss = mse_loss + vgg_loss + gram_loss

    g_vars = tl.layers.get_variables_with_name('UMSR_g', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1, beta2=beta2).minimize(g1_loss, var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg")
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    # vgg19.npy maps layer name -> [W, b]; sorted order matches the layer
    # order expected by Vgg19_simple_api.
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    #sample_imgs = train_hr_imgs[0:batch_size]
    sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
    sample_imgs_120 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    #print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_30 = tl.prepro.threading_data(sample_imgs_120, fn=downsample_fn)
    #print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_30, [ni, ni], save_dir_ginit + '/_train_sample_30.png')
    tl.vis.save_images(sample_imgs_120, [ni, ni], save_dir_ginit + '/_train_sample_120.png')

    ###========================= initialize G ====================###
    ## fixed learning rate

    #sess.run(tf.assign(lr_v, lr_init))
    for epoch in range(0, n_epoch_init + 1):
        # Step-decay schedule: reset lr at epoch 0, decay every
        # `decay_every_init` epochs thereafter.
        if epoch != 0 and (epoch % decay_every_init == 0):
            new_lr_decay_init = lr_decay_init**(epoch // decay_every_init)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay_init))
            log = " ** new learning rate: %f (for Generator)" % (lr_init * new_lr_decay_init)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for Generator)" % (lr_init, decay_every_init, lr_decay_init)
            print(log)

        epoch_time = time.time()
        total_g1_loss, n_iter = 0, 0

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_120 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)  #in order to get the fix size of inputs to be suitable for the network.
            b_imgs_30 = tl.prepro.threading_data(b_imgs_120, fn=downsample_fn)
            ## update G
            errG1, _ = sess.run([g1_loss, g_optim_init], {t_image: b_imgs_30, t_target_image: b_imgs_120})

            print("Epoch [%2d/%2d] %4d time: %4.4fs, g1: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errG1))
            total_g1_loss += errG1
            n_iter += 1
#            tf.summary.scalar('loss', mse_loss)
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g1: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_g1_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 50 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_30})  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)

        if (epoch != 0) and (epoch % 20 == 0):
            average_lossG1 = total_g1_loss / n_iter
            # Context manager guarantees the log file is closed even if the
            # write fails.
            with open('testG1.text', 'a') as f:
                f.write(str(average_lossG1) + '\n')


        ## save model
        if (epoch != 0) and (epoch % 500 == 0):
            # BUGFIX: the original '/g_%{}_init.npz'.format(mode) % epoch mixed
            # str.format with printf-style '%' — for mode 'srgan' the leading
            # '%s' consumed part of the mode name (producing e.g.
            # 'g_500rgan_init.npz'); for other modes it raised ValueError.
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init_{}.npz'.format(tl.global_flag['mode'], epoch), sess=sess)
Exemplo n.º 10
0
def train(train_lr_path,
          train_hr_path,
          save_path,
          save_every_epoch=2,
          validation=True,
          ratio=0.9,
          batch_size=16,
          lr_init=1e-4,
          beta1=0.9,
          n_epoch_init=10,
          n_epoch=20,
          lr_decay=0.1):
    '''
    Train SRGAN in two stages: (1) MSE-only generator initialization,
    (2) adversarial training of generator + discriminator with
    MSE + VGG-feature + GAN losses.

    Parameters:
    data:
        train_lr_path/train_hr_path: path of data
        save_path: the parent folder to save model result
        validation: whether to split data into train set and validation set
        ratio: fraction of images kept for training (rest is validation)
        save_every_epoch: how frequent to save the checkpoints and sample images
    Adam:
        batch_size
        lr_init
        beta1
    Generator Initialization
        n_epoch_init
    Adversarial Net
        n_epoch
        lr_decay
    '''

    ## Folders to save results
    save_dir_ginit = os.path.join(save_path, 'srgan_ginit')
    save_dir_gan = os.path.join(save_path, 'srgan_gan')
    checkpoint_dir = os.path.join(save_path, 'checkpoint')
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###======LOAD DATA======###
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=train_lr_path,
                                regx='.*.jpg',
                                printable=False))
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=train_hr_path,
                                regx='.*.jpg',
                                printable=False))

    if validation:
        # Randomly pick `ratio` of the indices for training; the complement
        # becomes the validation split.  Use a set for O(1) membership tests
        # (an ndarray `in` test is a linear scan, making the split O(n^2)).
        train_idx = set(
            np.random.choice(len(train_lr_img_list),
                             int(len(train_lr_img_list) * ratio),
                             replace=False).tolist())
        valid_lr_img_list = sorted(
            [x for i, x in enumerate(train_lr_img_list) if i not in train_idx])
        valid_hr_img_list = sorted(
            [x for i, x in enumerate(train_hr_img_list) if i not in train_idx])
        train_lr_img_list = sorted(
            [x for i, x in enumerate(train_lr_img_list) if i in train_idx])
        train_hr_img_list = sorted(
            [x for i, x in enumerate(train_hr_img_list) if i in train_idx])

        # NOTE(review): valid_lr_imgs is loaded but never used — validation
        # LR batches are produced by downsampling valid_hr_imgs instead.
        valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                           path=train_lr_path,
                                           n_threads=32)
        valid_hr_imgs = tl.vis.read_images(valid_hr_img_list,
                                           path=train_hr_path,
                                           n_threads=32)

    train_lr_imgs = tl.vis.read_images(train_lr_img_list,
                                       path=train_lr_path,
                                       n_threads=32)
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=train_hr_path,
                                       n_threads=32)

    ###======DEFINE MODEL======###
    ## train inference: 96x96 LR -> 192x192 HR (2x super-resolution)
    lr_image = tf.placeholder('float32', [None, 96, 96, 3], name='lr_image')
    hr_image = tf.placeholder('float32', [None, 192, 192, 3], name='hr_image')

    net_g = SRGAN_g(lr_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(hr_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    # net_g.print_params(False)
    # net_g.print_layers()
    # net_d.print_params(False)
    # net_d.print_layers()

    ## resize original hr images for VGG19
    hr_image_224 = tf.image.resize_images(
        hr_image,
        size=[224, 224],
        method=0,  # BICUBIC
        align_corners=False)

    ## generated hr image for VGG19
    generated_image_224 = tf.image.resize_images(
        net_g.outputs,
        size=[224, 224],
        method=0,  #BICUBIC
        align_corners=False)

    ## scale image to [0,1] and get conv characteristics
    net_vgg, vgg_target_emb = Vgg19_simple_api((hr_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((generated_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(lr_image, is_train=False, reuse=True)

    ###======DEFINE TRAIN PROCESS======###
    # Discriminator: real images labelled 1, generated images labelled 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # Discriminator accuracy: fraction of real logits > 0.5 plus fraction of
    # fake logits < 0.5 (sum of both, so /2 is applied where it is reported).
    prediction1 = tf.greater(logits_real, tf.fill(tf.shape(logits_real), 0.5))
    acc_metric1 = tf.reduce_mean(tf.cast(prediction1, tf.float32))
    prediction2 = tf.less(logits_fake, tf.fill(tf.shape(logits_fake), 0.5))
    acc_metric2 = tf.reduce_mean(tf.cast(prediction2, tf.float32))
    acc_metric = acc_metric1 + acc_metric2

    # Generator: fool the discriminator + pixel MSE + VGG feature loss.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          hr_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + g_gan_loss + vgg_loss

    # max_val=2.0 because images are in [-1, 1].
    psnr_metric = tf.image.psnr(net_g.outputs,
                                hr_image,
                                max_val=2.0,
                                name='psnr')

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)

    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    sess.run(tf.global_variables_initializer())

    # Resume the GAN-stage generator if one exists, otherwise fall back to
    # the init-stage checkpoint.
    if tl.files.file_exists(os.path.join(checkpoint_dir, 'g_srgan.npz')):
        tl.files.load_and_assign_npz(sess=sess,
                                     name=os.path.join(checkpoint_dir,
                                                       'g_srgan.npz'),
                                     network=net_g)
    else:
        tl.files.load_and_assign_npz(sess=sess,
                                     name=os.path.join(checkpoint_dir,
                                                       'g_srgan_init.npz'),
                                     network=net_g)

    tl.files.load_and_assign_npz(sess=sess,
                                 name=os.path.join(checkpoint_dir,
                                                   'd_srgan.npz'),
                                 network=net_d)

    ###======LOAD VGG======###
    vgg19_npy_path = '../lib/SRGAN/vgg19.npy'
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    # vgg19.npy maps layer name -> [W, b]; sorted order matches the layer
    # order expected by Vgg19_simple_api.
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###======TRAINING======###
    ## use train set to have a quick test during training
    ni = 4
    num_sample = ni * ni
    sample_idx = set(
        np.random.choice(len(train_lr_imgs), num_sample,
                         replace=False).tolist())
    sample_imgs_lr = tl.prepro.threading_data(
        [img for i, img in enumerate(train_lr_imgs) if i in sample_idx],
        fn=crop_sub_imgs_fn,
        size=(96, 96),
        is_random=False)
    sample_imgs_hr = tl.prepro.threading_data(
        [img for i, img in enumerate(train_hr_imgs) if i in sample_idx],
        fn=crop_sub_imgs_fn,
        size=(192, 192),
        is_random=False)

    print('sample LR sub-image:', sample_imgs_lr.shape, sample_imgs_lr.min(),
          sample_imgs_lr.max())
    print('sample HR sub-image:', sample_imgs_hr.shape, sample_imgs_hr.min(),
          sample_imgs_hr.max())

    ## save the images
    tl.vis.save_images(sample_imgs_lr, [ni, ni],
                       os.path.join(save_dir_ginit, '_train_sample_96.jpg'))
    tl.vis.save_images(sample_imgs_hr, [ni, ni],
                       os.path.join(save_dir_ginit, '_train_sample_192.jpg'))
    tl.vis.save_images(sample_imgs_lr, [ni, ni],
                       os.path.join(save_dir_gan, '_train_sample_96.jpg'))
    tl.vis.save_images(sample_imgs_hr, [ni, ni],
                       os.path.join(save_dir_gan, '_train_sample_192.jpg'))
    print('finish saving sample images')

    ###====== initialize G ======###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, total_psnr, n_iter = 0, 0, 0

        # random shuffle the train set for each epoch (LR batches are derived
        # from the HR batch by downsampling, so only the HR list is shuffled)
        random.shuffle(train_hr_imgs)

        for idx in range(0, len(train_lr_imgs), batch_size):
            step_time = time.time()
            b_imgs_192 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  size=(192, 192),
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_192,
                                                 fn=downsample_fn,
                                                 size=(96, 96))
            ## update G
            errM, metricP, _ = sess.run([mse_loss, psnr_metric, g_optim_init],
                                        {
                                            lr_image: b_imgs_96,
                                            hr_image: b_imgs_192
                                        })
            print("Epoch [%2d/%2d] %4d time: %4.2fs, mse: %.4f, psnr: %.4f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time, errM,
                   metricP.mean()))
            total_mse_loss += errM
            total_psnr += metricP.mean()
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.2fs, mse: %.4f, psnr: %.4f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / n_iter, total_psnr / n_iter)
        print(log)
        if validation:
            b_imgs_192_V = tl.prepro.threading_data(valid_hr_imgs,
                                                    fn=crop_sub_imgs_fn,
                                                    size=(192, 192),
                                                    is_random=True)
            b_imgs_96_V = tl.prepro.threading_data(b_imgs_192_V,
                                                   fn=downsample_fn,
                                                   size=(96, 96))
            # BUGFIX: do NOT run the optimizer here — the original also
            # fetched g_optim_init, which trained the model on the
            # validation set.  Evaluate losses/metrics only.
            errM_V, metricP_V = sess.run([mse_loss, psnr_metric], {
                lr_image: b_imgs_96_V,
                hr_image: b_imgs_192_V
            })
            print("Validation | mse: %.4f, psnr: %.4f" %
                  (errM_V, metricP_V.mean()))

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % save_every_epoch == 0):
            out = sess.run(net_g_test.outputs, {lr_image: sample_imgs_lr})
            print("[*] save sample images")
            tl.vis.save_images(
                out, [ni, ni],
                os.path.join(save_dir_ginit, 'train_{}.jpg'.format(epoch)))

        ## save model
        if (epoch != 0) and (epoch % save_every_epoch == 0):
            tl.files.save_npz(net_g.all_params,
                              name=os.path.join(checkpoint_dir,
                                                'g_srgan_init.npz'),
                              sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    ## Learning rate decay: halve-point decay, at least every epoch
    decay_every = int(n_epoch / 2) if int(n_epoch / 2) > 0 else 1

    for epoch in range(0, n_epoch + 1):

        # random shuffle the train set for each epoch
        random.shuffle(train_hr_imgs)

        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, total_mse_loss, total_psnr, total_acc, n_iter = 0, 0, 0, 0, 0, 0

        for idx in range(0, len(train_lr_imgs), batch_size):
            step_time = time.time()
            b_imgs_192 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  size=(192, 192),
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_192,
                                                 fn=downsample_fn,
                                                 size=(96, 96))
            ## update D
            errD, metricA, _ = sess.run([d_loss, acc_metric, d_optim], {
                lr_image: b_imgs_96,
                hr_image: b_imgs_192
            })
            ## update G
            errG, errM, metricP, _ = sess.run(
                [g_loss, mse_loss, psnr_metric, g_optim], {
                    lr_image: b_imgs_96,
                    hr_image: b_imgs_192
                })
            print(
                "Epoch [%2d/%2d] %4d time: %4.2fs, d_loss: %.4f g_loss: %.4f (mse: %.4f, psnr: %.4f, accuracy: %.4f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG,
                   errM, metricP.mean(), metricA / 2))
            total_d_loss += errD
            total_g_loss += errG
            total_mse_loss += errM
            total_psnr += metricP.mean()
            total_acc += metricA / 2
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.2fs, d_loss: %.4f g_loss: %.4f (mse: %4f, psnr: %.4f, accuracy: %.4f)" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter, total_mse_loss / n_iter,
            total_psnr / n_iter, total_acc / n_iter)
        print(log)

        if validation:
            b_imgs_192_V = tl.prepro.threading_data(valid_hr_imgs,
                                                    fn=crop_sub_imgs_fn,
                                                    size=(192, 192),
                                                    is_random=True)
            b_imgs_96_V = tl.prepro.threading_data(b_imgs_192_V,
                                                   fn=downsample_fn,
                                                   size=(96, 96))
            # BUGFIX: the original also fetched g_optim here, training the
            # generator on validation data.  Evaluate only.
            errM_V, metricP_V = sess.run([mse_loss, psnr_metric], {
                lr_image: b_imgs_96_V,
                hr_image: b_imgs_192_V
            })
            print("Validation | mse: %.4f, psnr: %.4f" %
                  (errM_V, metricP_V.mean()))

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % save_every_epoch == 0):
            out = sess.run(net_g_test.outputs, {lr_image: sample_imgs_lr})
            print("[*] save images")
            tl.vis.save_images(
                out, [ni, ni],
                os.path.join(save_dir_gan, 'train_{}.jpg'.format(epoch)))

        ## save model
        if (epoch != 0) and (epoch % save_every_epoch == 0):
            tl.files.save_npz(net_g.all_params,
                              name=os.path.join(checkpoint_dir, 'g_srgan.npz'),
                              sess=sess)
            tl.files.save_npz(net_d.all_params,
                              name=os.path.join(checkpoint_dir, 'd_srgan.npz'),
                              sess=sess)