Пример #1
0
def main():
    """Train the FUNIT model, checkpointing and writing sample images.

    Class names are taken from the sub-entries of ``datasets``.  If a
    checkpoint exists under ``experiments/FUNIT/checkpoints`` its weights are
    loaded first; the iteration counter nevertheless starts again at 1.

    Every iteration both learning rates are decayed by a factor of 0.99999
    before the training step.  A checkpoint is written every 5000 iterations
    and a content/class/generated image strip is written to ``sample`` on the
    first iteration and every 100th iteration after that.
    """
    dataset = 'datasets'
    model_name = 'FUNIT'
    checkpoint_dir = os.path.join('experiments', model_name, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)
    log_dir = os.path.join('logs', model_name)
    os.makedirs(log_dir, exist_ok=True)

    validation_output_dir = 'sample'
    os.makedirs(validation_output_dir, exist_ok=True)

    classes = os.listdir(dataset)
    num_classes = len(classes)
    print(num_classes)
    img_size = 128
    btGen = BatchGenerator(img_size=img_size,
                           imgdir=dataset,
                           num_classes=num_classes)

    num_iterations = 100000
    mini_batch_size = 8
    generator_learning_rate = 0.00010
    discriminator_learning_rate = 0.00010
    lambda_fm = 1
    lambda_rec = 0.1

    model = FUNIT(img_size=img_size,
                  num_classes=num_classes,
                  batch_size=mini_batch_size,
                  rec_weight=lambda_rec,
                  feature_weight=lambda_fm,
                  log_dir=log_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt:
        last_model = ckpt.model_checkpoint_path
        print("loading {}".format(last_model))
        model.load(filepath=last_model)
    else:
        print("checkpoints are not found")

    # Row k of the identity matrix is the one-hot vector for class k; build it
    # once instead of twice per sample on every iteration.
    one_hot_rows = np.identity(num_classes)

    iteration = 1
    while iteration <= num_iterations:
        generator_learning_rate *= 0.99999
        discriminator_learning_rate *= 0.99999

        cont_img, cont_label, cls_img, cls_label = btGen.getBatch(
            mini_batch_size)

        # Integer class labels -> one-hot rows.
        cont_labels = np.zeros([mini_batch_size, num_classes])
        cls_labels = np.zeros([mini_batch_size, num_classes])
        for b in range(mini_batch_size):
            cont_labels[b] = one_hot_rows[cont_label[b]]
            cls_labels[b] = one_hot_rows[cls_label[b]]

        gen_loss, dis_loss = model.train(
            content_image=cont_img,
            class_image=cls_img,
            content_label=cont_labels,
            class_label=cls_labels,
            discriminator_learning_rate=discriminator_learning_rate,
            generator_learning_rate=generator_learning_rate)

        print(
            'Iteration: {:07d}, Generator Loss : {:.3f}, Discriminator Loss : {:.3f}'
            .format(iteration, gen_loss, dis_loss))

        if iteration % 5000 == 0:
            print('Checkpointing...')
            model.save(directory=checkpoint_dir,
                       filename='{}_{}.ckpt'.format(model_name, iteration))

        if iteration % 100 == 0 or iteration == 1:
            # model.test() only consumes the images, so the labels returned by
            # the generator are not converted to one-hot here.
            sample_cont, _, sample_cls, _ = btGen.getBatch(mini_batch_size)
            gen_img = np.squeeze(np.array(model.test(sample_cont, sample_cls)))
            print(gen_img.shape)

            # Content | class | generated tiles, placed next to each other
            # along axis 1.
            out = np.concatenate([
                tileImage(sample_cont),
                tileImage(sample_cls),
                tileImage(gen_img),
            ], axis=1)
            out = (out + 1) * 127.5
            print(out.shape)
            cv2.imwrite(
                "{}/{:07}.png".format(validation_output_dir, iteration), out)

        iteration += 1
def main():
    """Train a WGAN-GP generator/discriminator pair on DATASET_DIR images.

    Builds the TF1 graph, restores the newest checkpoint from SAVE_DIR when
    one exists, then runs 100001 steps.  Each step performs ``critic``
    discriminator updates followed by one generator update.  Every 100 steps
    two sample tiles (random latent and a fixed latent) are written to
    SVIM_DIR and the loss history is re-plotted to ``hist.png``; every 1000
    steps a checkpoint is saved.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    img_size = 128
    bs = 32
    z_dim = 64
    critic = 3  # discriminator updates per generator update
    lmd = 10  # weight of the gradient penalty term

    datalen = foloderLength(DATASET_DIR)

    # loading images on training
    batch = BatchGenerator(img_size=img_size, imgdir=DATASET_DIR)

    def sample_indices():
        """Draw ``bs`` dataset indices uniformly with replacement."""
        return np.random.choice(range(datalen), bs)

    def sample_latent():
        """Draw a batch of latent vectors uniformly from [-1, 1)."""
        return np.random.uniform(-1., +1., [bs, z_dim]).astype(np.float32)

    # Save a few real images for reference.
    IN_ = batch.getBatch(bs, sample_indices())[:4]
    IN_ = (IN_ + 1) * 127.5
    IN_ = tileImage(IN_)
    cv2.imwrite("{}/input.png".format(SVIM_DIR), IN_)

    z = tf.placeholder(tf.float32, [bs, z_dim])
    X_real = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])

    X_fake = buildGenerator(z, z_dim=z_dim, img_size=img_size, nBatch=bs)
    fake_y = buildDiscriminator(y=X_fake, nBatch=bs, isTraining=True)
    real_y = buildDiscriminator(y=X_real,
                                nBatch=bs,
                                reuse=True,
                                isTraining=True)
    d_loss_real = -tf.reduce_mean(real_y)
    d_loss_fake = tf.reduce_mean(fake_y)
    g_loss = -tf.reduce_mean(fake_y)

    # Gradient penalty evaluated on random interpolations between real and
    # generated images.
    epsilon = tf.random_uniform(shape=[bs, 1, 1, 1], minval=0., maxval=1.)
    X_hat = X_real + epsilon * (X_fake - X_real)
    D_X_hat = buildDiscriminator(X_hat,
                                 nBatch=bs,
                                 reuse=True,
                                 isTraining=False)
    grad_D_X_hat = tf.gradients(D_X_hat, [X_hat])[0]

    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_X_hat), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)

    wd_g = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                          scope="Generator"))
    wd_d = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                          scope="Discriminator"))

    d_loss = d_loss_real + d_loss_fake + lmd * gradient_penalty + wd_d
    # Small extra penalty on the squared real-sample loss.
    d_loss += 0.001 * tf.reduce_mean(tf.square(d_loss_real - 0.0))
    g_loss = g_loss + wd_g

    g_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(
        g_loss,
        var_list=[
            v for v in tf.trainable_variables() if "Generator" in v.name
        ])
    d_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(
        d_loss,
        var_list=[
            v for v in tf.trainable_variables() if "Discriminator" in v.name
        ])

    printParam(scope="Generator")
    printParam(scope="Discriminator")

    start = time.time()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        per_process_gpu_memory_fraction=0.66))

    # Pass the config: previously it was built but never handed to the
    # session, so the GPU memory cap had no effect.
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)

    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # path of the newest model
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    else:
        # Variables were already initialised right after session creation.
        print("models were not found")

    print("%.4e sec took initializing" % (time.time() - start))
    g_hist = []
    d_hist = []

    start = time.time()
    # Fixed latent batch so the "imgst_*" samples are comparable across steps.
    stable = sample_latent()
    for i in range(100001):
        for _ in range(critic):
            batch_images = batch.getBatch(bs, sample_indices())
            _, dis_loss = sess.run([d_opt, d_loss],
                                   feed_dict={
                                       z: sample_latent(),
                                       X_real: batch_images
                                   })

        # Feed the batch fetched for this step; the original fetched a fresh
        # batch and then fed the last critic batch instead.
        batch_images_x = batch.getBatch(bs, sample_indices())
        _, gen_loss = sess.run([g_opt, g_loss],
                               feed_dict={
                                   z: sample_latent(),
                                   X_real: batch_images_x
                               })

        print("in step %s, dis_loss = %.4e, gen_loss = %.4e" %
              (i, dis_loss, gen_loss))
        g_hist.append(gen_loss)
        d_hist.append(dis_loss)

        if i % 100 == 0:
            g_image = sess.run(X_fake, feed_dict={z: sample_latent()})
            cv2.imwrite(os.path.join(SVIM_DIR, "img_%d_fake.png" % i),
                        tileImage(g_image) * 127. + 127.5)
            g_image = sess.run(X_fake, feed_dict={z: stable})
            cv2.imwrite(os.path.join(SVIM_DIR, "imgst_%d_fake.png" % i),
                        tileImage(g_image) * 127. + 127.5)

            fig = plt.figure(figsize=(8, 6), dpi=128)
            ax = fig.add_subplot(111)
            plt.title("Loss")
            plt.grid(which="both")
            ax.plot(g_hist, label="gen_loss", linewidth=0.25)
            ax.plot(d_hist, label="dis_loss", linewidth=0.25)
            plt.xlabel('step', fontsize=16)
            plt.ylabel('loss', fontsize=16)
            plt.legend(loc='upper right')
            plt.savefig("hist.png")
            plt.close()

            print("%.4e sec took 100steps" % (time.time() - start))
            start = time.time()

        if i % 1000 == 0:
            saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
Пример #3
0
def main():
    """Pre-train and then adversarially train an SRGAN super-resolution model.

    Phases, in order:

    1. Generator pre-training for 50001 steps on ``pre_loss`` (pixel loss +
       VGG feature loss + weight decay); checkpoints go to SAVEPRE_DIR.
    2. Discriminator warm-up (currently 0 steps, i.e. disabled).
    3. GAN training for 100001 steps; each step runs a generator update, a
       discriminator update and a second generator update, each on a fresh
       batch.  Checkpoints go to SAVE_DIR.

    Every 100 steps a low-res / generated / high-res preview strip is written
    to SAVEIM_DIR and the loss curves are re-plotted.

    Raises:
        FileNotFoundError: if no VGG19 checkpoint is found in ``modelvgg``.
    """
    img_size = 96
    bs = 4
    trans_lr = 1e-4

    # cv2.imwrite fails silently and Saver.save raises when the target
    # directory is missing, so make sure every output directory exists.
    for out_dir in (SAVE_DIR, SAVEPRE_DIR, SAVEIM_DIR):
        os.makedirs(out_dir, exist_ok=True)

    start = time.time()

    batchgen = BatchGenerator(img_size=img_size,
                              LRDir=TRAIN_LR_DIR,
                              HRDir=TRAIN_HR_DIR,
                              aug=True)
    valgen = BatchGenerator(img_size=img_size,
                            LRDir=VAL_LR_DIR,
                            HRDir=VAL_HR_DIR,
                            aug=False)

    # Dump one training sample (low-res upscaled next to high-res).
    IN_, OUT_ = batchgen.getBatch(4)
    IN_ = tileImage(IN_)
    IN_ = cv2.resize(IN_, (img_size * 2 * 4, img_size * 2 * 4),
                     interpolation=cv2.INTER_CUBIC)
    IN_ = (IN_ + 1) * 127.5
    OUT_ = tileImage(OUT_)
    OUT_ = cv2.resize(OUT_, (img_size * 4 * 2, img_size * 4 * 2))
    OUT_ = (OUT_ + 1) * 127.5

    Z_ = np.concatenate((IN_, OUT_), axis=1)
    cv2.imwrite("input.png", Z_)
    print("%s sec took sampling" % (time.time() - start))

    start = time.time()

    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size * 4, img_size * 4, 3])
    lr = tf.placeholder(tf.float32)

    y = buildSRGAN_g(x)
    test_y = buildSRGAN_g(x, reuse=True, isTraining=False)
    fake_y = buildSRGAN_d(y)
    real_y = buildSRGAN_d(t, reuse=True)

    vgg_y1, vgg_y2, vgg_y3, vgg_y4, vgg_y5 = vgg19(y)
    vgg_t1, vgg_t2, vgg_t3, vgg_t4, vgg_t5 = vgg19(t, reuse=True)

    # 1e-10 keeps the logarithms finite when a probability reaches 0.
    d_loss_real = tf.log((real_y) + 1e-10)
    d_loss_fake = tf.log(1 - (fake_y) + 1e-10)
    g_loss_fake = tf.reduce_mean(-tf.log((fake_y) + 1e-10)) * 2e-3

    wd_g = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                          scope="Generator"))
    wd_d = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                          scope="Discriminator"))

    # NOTE: despite its name this is a mean *squared* error.
    L1_loss = tf.reduce_mean(tf.square(y - t))
    e_1 = tf.reduce_mean(tf.square(vgg_y1 - vgg_t1)) * 2.8
    e_2 = tf.reduce_mean(tf.square(vgg_y2 - vgg_t2)) * 0.2
    e_3 = tf.reduce_mean(tf.square(vgg_y3 - vgg_t3)) * 0.08
    e_4 = tf.reduce_mean(tf.square(vgg_y4 - vgg_t4)) * 0.2
    e_5 = tf.reduce_mean(tf.square(vgg_y5 - vgg_t5)) * 75.0
    vgg_loss = (e_1 + e_2 + e_3 + e_4 + e_5) * 2e-7

    pre_loss = L1_loss + vgg_loss + wd_g
    g_loss = L1_loss + vgg_loss + g_loss_fake + wd_g
    d_loss = tf.reduce_mean(-(d_loss_fake + d_loss_real)) + wd_d

    # The comprehension variable is deliberately not called ``x`` so it can
    # never be confused with (or, on Python 2, overwrite) the placeholder.
    def vars_in(scope_name):
        """Return the trainable variables whose name contains scope_name."""
        return [v for v in tf.trainable_variables() if scope_name in v.name]

    g_pre = tf.train.AdamOptimizer(1e-4, beta1=0.5).minimize(
        pre_loss, var_list=vars_in("SRGAN_g"))
    g_opt = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
        g_loss, var_list=vars_in("SRGAN_g"))
    d_opt = tf.train.AdamOptimizer(lr / 2, beta1=0.5).minimize(
        d_loss, var_list=vars_in("SRGAN_d"))

    print("%.4f sec took building" % (time.time() - start))
    printParam(scope="SRGAN_g")
    printParam(scope="SRGAN_d")
    printParam(scope="vgg19")

    saver = tf.train.Saver()
    saver_vgg = tf.train.Saver(vars_in("vgg19"))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
    if ckpt:  # resume from the newest full checkpoint
        last_model = ckpt.model_checkpoint_path
        print("load " + last_model)
        saver.restore(sess, last_model)  # read variable data
        print("succeed restore model")
    # Otherwise the variables keep the initial values assigned above.

    # The VGG19 weights are always (re)loaded from their own checkpoint.
    ckpt_vgg = tf.train.get_checkpoint_state('modelvgg')
    if not ckpt_vgg:
        raise FileNotFoundError(
            "no VGG19 checkpoint found in directory 'modelvgg'")
    saver_vgg.restore(sess, ckpt_vgg.model_checkpoint_path)

    print("%.4e sec took initializing" % (time.time() - start))

    def save_preview(path):
        """Write a low-res | generated | high-res strip for a validation batch."""
        batch_x, batch_t = valgen.getBatch(bs)
        generated = sess.run(test_y, feed_dict={x: batch_x})
        lr_tile = cv2.resize(tileImage(batch_x[:4]),
                             (img_size * 2 * 4, img_size * 2 * 4),
                             interpolation=cv2.INTER_CUBIC)
        sr_tile = tileImage(generated[:4])
        hr_tile = tileImage(batch_t[:4])
        strip = np.concatenate(
            [(tile + 1) * 127.5 for tile in (lr_tile, sr_tile, hr_tile)],
            axis=1)
        cv2.imwrite(path, strip)

    def save_loss_plot(path, curves):
        """Plot every (label, values) pair on a log-scale axis and save it."""
        fig = plt.figure(figsize=(8, 6), dpi=128)
        ax = fig.add_subplot(111)
        plt.title("Loss")
        plt.grid(which="both")
        plt.yscale("log")
        for label, values in curves:
            ax.plot(values, label=label, linewidth=0.25)
        plt.xlabel('step', fontsize=16)
        plt.ylabel('loss', fontsize=16)
        plt.legend(loc='upper right')
        plt.savefig(path)
        plt.close()

    def gan_feed(learning_rate):
        """Draw a fresh training batch and build the feed dict for a GAN step."""
        batch_x, batch_t = batchgen.getBatch(bs)
        return {x: batch_x, t: batch_t, lr: learning_rate}

    hist = []
    hist_g = []
    hist_d = []

    start = time.time()
    print("start pretrain")
    for p in range(50001):
        batch_images_x, batch_images_t = batchgen.getBatch(bs)
        _, gen_loss, L1, vgg = sess.run([g_pre, pre_loss, L1_loss, vgg_loss],
                                        feed_dict={
                                            x: batch_images_x,
                                            t: batch_images_t
                                        })

        hist.append(gen_loss)
        print("in step %s, pre_loss =%.4e, L1_loss=%.4e, vgg_loss=%.4e" %
              (p, gen_loss, L1, vgg))

        if p % 100 == 0:
            save_preview("{}/pre_{}.png".format(SAVEIM_DIR, p))
            save_loss_plot("hist_pre.png", [("gen_loss", hist)])

            print("%.4e sec took 100steps" % (time.time() - start))
            start = time.time()
        if p % 5000 == 0 and p != 0:
            saver.save(sess, os.path.join(SAVEPRE_DIR, "model.ckpt"), p)

    print("start Discriminator")
    d_pretrain_steps = 0  # discriminator warm-up is disabled
    for d in range(d_pretrain_steps):
        _, dis_loss = sess.run([d_opt, d_loss], feed_dict=gan_feed(1e-4))
        print("in step %s, dis_loss = %.4e" % (d, dis_loss))

    print("start GAN")
    gen_fetches = [g_opt, g_loss, L1_loss, g_loss_fake, vgg_loss]
    for i in range(100001):
        # Generator, discriminator, generator: each on its own fresh batch.
        _, gen_loss, L1, adv, vgg = sess.run(gen_fetches,
                                             feed_dict=gan_feed(trans_lr))
        _, dis_loss = sess.run([d_opt, d_loss], feed_dict=gan_feed(trans_lr))
        _, gen_loss, L1, adv, vgg = sess.run(gen_fetches,
                                             feed_dict=gan_feed(trans_lr))

        # Exponential decay, applied only while the rate is above 1e-5.
        if trans_lr > 1e-5:
            trans_lr = trans_lr * 0.99998

        print("in step %s, dis_loss = %.4e, gen_loss = %.4e" %
              (i, dis_loss, gen_loss))
        print("L1_loss=%.4e, adv_loss=%.4e, vgg_loss=%.4e" % (L1, adv, vgg))

        hist_g.append(gen_loss)
        hist_d.append(dis_loss)

        if i % 100 == 0:
            save_preview("{}/{}.png".format(SAVEIM_DIR, i))
            save_loss_plot("hist.png", [("gen_loss", hist_g),
                                        ("dis_loss", hist_d)])

            print("%.4f sec took per 100steps, lr = %.4e" %
                  (time.time() - start, trans_lr))
            start = time.time()

        if i % 5000 == 0 and i != 0:
            saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
Пример #4
0
def main():
    """Train a progressively-grown GAN through nine resolution stages.

    Stage ``s`` (0-based) works at resolution ``2**(s + 2)``, i.e. 4x4 up to
    1024x1024, with its own batch size and step count.  For every stage after
    the first, ``alpha`` ramps linearly from 0 to 1 during the first quarter
    of that stage's steps and stays at 1 afterwards; in the first stage it is
    fixed at 1.

    Per stage, sample images and a loss plot are written to
    ``SVIM_DIR/stage<N>`` every 100 steps, and a checkpoint is saved to
    SAVE_DIR every 5000 steps.  The checkpoint step suffix restarts at every
    stage, so later stages reuse the file names of earlier ones.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    num_stages = 9
    img_size = [2**(i + 2) for i in range(num_stages)]
    bs = [64, 64, 32, 32, 32, 16, 8, 4, 4]
    steps = [8000, 10000, 20000, 40000, 50000, 60000, 80000, 90000, 100000]
    z_dim = 512

    # Save a few real images for reference.
    batch = BatchGenerator(img_size=256, datadir=DATASET_DIR)
    IN_ = batch.getBatch(4)
    IN_ = (IN_ + 1) * 127.5
    IN_ = tileImage(IN_)
    cv2.imwrite("{}/input.png".format(SVIM_DIR), IN_)

    z = tf.placeholder(tf.float32, [None, 1, 1, z_dim])
    X_real = [tf.placeholder(tf.float32, [None, r, r, 3]) for r in img_size]
    alpha = tf.placeholder(tf.float32, [])
    X_fake = [
        buildGenerator(z, alpha, stage=s + 1) for s in range(num_stages)
    ]
    fake_y = [
        buildDiscriminator(img, alpha, stage=s + 1, reuse=False)
        for s, img in enumerate(X_fake)
    ]
    real_y = [
        buildDiscriminator(img, alpha, stage=s + 1, reuse=True)
        for s, img in enumerate(X_real)
    ]

    # WGAN-GP: discriminator outputs on random real/fake interpolations.
    xhats = []
    d_xhats = []
    for s, (real, fake) in enumerate(zip(X_real, X_fake)):
        epsilon = tf.random_uniform(shape=[tf.shape(real)[0], 1, 1, 1],
                                    minval=0.0,
                                    maxval=1.0)
        inter = real * epsilon + fake * (1 - epsilon)
        xhats.append(inter)
        d_xhats.append(
            buildDiscriminator(inter, alpha, stage=s + 1, reuse=True))

    g_losses, d_losses = calc_losses(real_y, fake_y, xhats, d_xhats)

    g_var = [v for v in tf.trainable_variables() if "Generator" in v.name]
    d_var = [v for v in tf.trainable_variables() if "Discriminator" in v.name]
    opt = tf.train.AdamOptimizer(learning_rate=1e-3,
                                 beta1=0.0,
                                 beta2=0.99,
                                 epsilon=1e-8)

    # One train op per stage for each network.
    g_opt = [opt.minimize(loss, var_list=g_var) for loss in g_losses]
    d_opt = [opt.minimize(loss, var_list=d_var) for loss in d_losses]

    printParam(scope="Generator")
    printParam(scope="Discriminator")

    start = time.time()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        per_process_gpu_memory_fraction=0.75))

    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)

    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # path of the newest model
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    else:
        # Variables were already initialised right after session creation.
        print("models were not found")

    print("%.4e sec took initializing" % (time.time() - start))

    def sample_latent(count):
        """Draw ``count`` latent vectors from N(0, 0.5) with z_dim channels."""
        return np.random.normal(0, 0.5, [count, 1, 1, z_dim])

    start = time.time()
    for stage in range(num_stages):
        stage_bs = bs[stage]
        stage_steps = steps[stage]
        batch = BatchGenerator(img_size=img_size[stage], datadir=DATASET_DIR)

        # The output directory is fixed for the whole stage; create it once.
        outdir = os.path.join(SVIM_DIR, 'stage{}'.format(stage + 1))
        os.makedirs(outdir, exist_ok=True)

        # Save a real sample at this stage's resolution.
        x_batch = batch.getBatch(stage_bs, alpha=1.0)
        out = tileImage(x_batch)
        out = np.array((out + 1) * 127.5, dtype=np.uint8)
        cv2.imwrite(os.path.join(outdir, 'sample.png'), out)

        g_hist = []
        d_hist = []
        print("starting stage{}".format(stage + 1))
        for i in range(stage_steps + 1):
            if stage == 0:
                alp = 1.0
            else:
                # Reaches 1.0 after a quarter of the stage's steps.
                alp = min(4 * i / stage_steps, 1.0)

            x_batch = batch.getBatch(stage_bs, alpha=alp)

            _, dis_loss = sess.run([d_opt[stage], d_losses[stage]],
                                   feed_dict={
                                       X_real[stage]: x_batch,
                                       z: sample_latent(stage_bs),
                                       alpha: alp
                                   })

            _, gen_loss = sess.run([g_opt[stage], g_losses[stage]],
                                   feed_dict={
                                       z: sample_latent(stage_bs),
                                       alpha: alp
                                   })

            g_hist.append(gen_loss)
            d_hist.append(dis_loss)

            print("in step %s, dis_loss = %.4e, gen_loss = %.4e" %
                  (i, dis_loss, gen_loss))

            if i % 100 == 0:
                # save sample image
                out = sess.run(X_fake[stage],
                               feed_dict={
                                   z: sample_latent(stage_bs),
                                   alpha: alp
                               })
                out = tileImage(out)
                out = np.array((out + 1) * 127.5, dtype=np.uint8)
                dst = os.path.join(outdir,
                                   '{}.png'.format('{0:09d}'.format(i)))
                cv2.imwrite(dst, out)

                # save loss graph
                fig = plt.figure(figsize=(8, 6), dpi=128)
                ax = fig.add_subplot(111)
                plt.title("Loss")
                plt.grid(which="both")
                ax.plot(g_hist, label="gen_loss", linewidth=0.25)
                ax.plot(d_hist, label="dis_loss", linewidth=0.25)
                plt.xlabel('step', fontsize=16)
                plt.ylabel('loss', fontsize=16)
                plt.legend(loc='upper right')
                plt.savefig(os.path.join(outdir, "hist.png"))
                plt.close()

            if i % 5000 == 0 and i != 0:
                saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
Пример #5
0
def main():
    """Train the ThermalSR super-resolution generator.

    Command line:
        --gpu   value assigned to CUDA_VISIBLE_DEVICES (default ``'0'``).

    The network is optimised on ``L1 + (1 - SSIM)`` for 50001 steps, resuming
    from the newest checkpoint in SAVEPRE_DIR when one exists.  Every 100
    steps a low-res / generated / high-res strip from the validation set is
    written to SAVEIM_DIR, every 1000 steps the loss curve is re-plotted to
    ``hist_pre.png`` and every 5000 steps a checkpoint is saved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu',
                        type=str,
                        default='0',
                        help='Which GPU to use')
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    img_size = 64
    bs = 4

    # cv2.imwrite fails silently and Saver.save raises when the target
    # directory is missing, so make sure both output directories exist.
    for out_dir in (SAVEPRE_DIR, SAVEIM_DIR):
        os.makedirs(out_dir, exist_ok=True)

    start = time.time()

    batchgen = BatchGenerator(img_size=img_size,
                              LRDir=TRAIN_LR_DIR,
                              HRDir=TRAIN_HR_DIR,
                              aug=True)
    valgen = BatchGenerator(img_size=img_size,
                            LRDir=VAL_LR_DIR,
                            HRDir=VAL_HR_DIR,
                            aug=False)

    # Save one training sample (low-res upscaled next to high-res).  The
    # former ``[:4]`` slice on the returned pair had no effect and is gone.
    IN_, OUT_ = batchgen.getBatch(4)
    print(IN_.shape)
    IN_ = tileImage(IN_)
    IN_ = cv2.resize(IN_, (img_size * 2 * 4, img_size * 2 * 4),
                     interpolation=cv2.INTER_CUBIC)
    IN_ = (IN_ + 1) * 127.5
    OUT_ = tileImage(OUT_)
    OUT_ = cv2.resize(OUT_, (img_size * 4 * 2, img_size * 4 * 2))
    OUT_ = (OUT_ + 1) * 127.5
    Z_ = np.concatenate((IN_, OUT_), axis=1)
    cv2.imwrite("input.png", Z_)
    print("%s sec took sampling" % (time.time() - start))

    start = time.time()

    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size * 4, img_size * 4, 3])

    generator = Generator()

    y = generator.ThermalSR(x)
    test_y = generator.ThermalSR(x, reuse=True, isTraining=False)

    # L1 loss function
    L1_loss = tf.losses.absolute_difference(y, t)

    # SSIM loss; max_val is 2.0 (assumes images span [-1, 1] - TODO confirm).
    ssim_ = tf.reduce_mean(tf.image.ssim(y, t, 2.0))
    ssim_loss = 1 - ssim_

    # Total loss function (a contextual VGG-feature loss is currently unused).
    Total_loss = L1_loss + ssim_loss

    # ``minimize`` returns the training op, not a loss value.  The
    # comprehension variable is not called ``x`` so it cannot be confused
    # with (or, on Python 2, overwrite) the input placeholder.
    train_op = tf.train.AdamOptimizer(1e-4, beta1=0.5).minimize(
        Total_loss,
        var_list=[
            v for v in tf.trainable_variables() if "ThermalSR" in v.name
        ])

    print("%.4f sec took building" % (time.time() - start))
    printParam(scope="ThermalSR")

    saver = tf.train.Saver(max_to_keep=15)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    ckpt = tf.train.get_checkpoint_state(SAVEPRE_DIR)
    if ckpt:  # resume from the newest checkpoint
        last_model = ckpt.model_checkpoint_path
        print("load " + last_model)
        saver.restore(sess, last_model)  # read variable data
        print("succeed restore model")
    # Otherwise the variables keep the initial values assigned above.

    print("%.4e sec took initializing" % (time.time() - start))

    hist = []

    start = time.time()
    print("start pretrain")
    for p in range(50001):
        batch_images_x, batch_images_t = batchgen.getBatch(bs)
        _, gen_loss, l1, ssim = sess.run(
            [train_op, Total_loss, L1_loss, ssim_loss],
            feed_dict={
                x: batch_images_x,
                t: batch_images_t
            })

        hist.append(gen_loss)
        print("in step %s, pre_loss =%.4e, l1_loss=%.4e, ssim_loss=%.4e" %
              (p, gen_loss, l1, ssim))

        if p % 100 == 0:
            # Preview on validation data using the inference-mode graph.
            val_x, val_t = valgen.getBatch(bs)
            out = sess.run(test_y, feed_dict={x: val_x})

            X_ = tileImage(val_x[:4])
            Y_ = tileImage(out[:4])
            Z_ = tileImage(val_t[:4])

            X_ = cv2.resize(X_, (img_size * 2 * 4, img_size * 2 * 4),
                            interpolation=cv2.INTER_CUBIC)

            ZZ_ = np.concatenate(
                [(tile + 1) * 127.5 for tile in (X_, Y_, Z_)], axis=1)

            cv2.imwrite("{0}/pre_{1:06d}.png".format(SAVEIM_DIR, int(p)), ZZ_)

            print("%.4e sec took 100steps" % (time.time() - start))
            start = time.time()

        if p % 1000 == 0:
            fig = plt.figure(figsize=(8, 6), dpi=128)
            ax = fig.add_subplot(111)
            plt.title("Loss")
            plt.grid(which="both")
            plt.yscale("log")
            ax.plot(hist, label="gen_loss", linewidth=0.25)
            plt.xlabel('step', fontsize=16)
            plt.ylabel('loss', fontsize=16)
            plt.legend(loc='upper right')
            plt.savefig("hist_pre.png")
            plt.close()

        if p % 5000 == 0 and p != 0:
            saver.save(sess, os.path.join(SAVEPRE_DIR, "model.ckpt"), p)
Пример #6
0
def main():
    """Train the progressive GAN one resolution stage at a time.

    Builds one generator output and one discriminator per resolution
    (4x4 up to 1024x1024), restores the latest checkpoint found in
    ``SAVE_DIR`` when there is one, and then, for every stage, alternates
    a discriminator update and a generator update.  Sample images and a
    loss plot are written every 100 steps and a checkpoint every 8000.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and matches
    # the per-stage directory creation further down.
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    num_stages = 9
    img_size = [2 ** (i + 2) for i in range(num_stages)]
    # Per-stage batch sizes; larger values are possible with more VRAM.
    bs = [16, 16, 16, 16, 12, 8, 4, 3, 2]
    # Per-stage step counts.  The first entry is 1, so the first stage only
    # runs two iterations.
    steps = [1, 16000, 24000, 40000, 64000, 96000, 128000, 192000, 320000]
    z_dim = 512
    r1_gamma = 10.0

    # Save a tile of real images for reference.
    # (x + 1) * 127.5 rescales to 0..255, assuming images are in [-1, 1].
    batch = BatchGenerator(img_size=512, datadir=DATASET_DIR)
    IN_ = batch.getBatch(4)
    IN_ = (IN_ + 1) * 127.5
    IN_ = tileImage(IN_)
    cv2.imwrite("{}/input.png".format(SVIM_DIR), IN_)

    # ---- graph ----------------------------------------------------------
    z = tf.placeholder(tf.float32, [None, z_dim])
    X_real = [tf.placeholder(tf.float32, [None, r, r, 3]) for r in img_size]
    alpha = tf.placeholder(tf.float32, [])
    X_fake = [buildGenerator(z, alpha, stage=i + 1)
              for i in range(num_stages)]
    fake_y = [buildDiscriminator(x, alpha, stage=i + 1, reuse=False)
              for i, x in enumerate(X_fake)]
    real_y = [buildDiscriminator(x, alpha, stage=i + 1, reuse=True)
              for i, x in enumerate(X_real)]
    lr = tf.placeholder(tf.float32, [])

    # Softplus (logistic) GAN losses, one pair per stage.
    g_losses = []
    d_losses = []
    for real_images, real_logit, fake_logit in zip(X_real, real_y, fake_y):
        # Discriminator loss plus a penalty on the squared gradient of the
        # real logits with respect to the real images.
        d_loss_gan = tf.nn.softplus(fake_logit) + tf.nn.softplus(-real_logit)
        real_loss = tf.reduce_sum(real_logit)
        real_grads = tf.gradients(real_loss, [real_images])[0]
        r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3])
        d_loss = tf.reduce_mean(d_loss_gan + r1_penalty * (r1_gamma * 0.5))

        # Generator loss: logistic, non-saturating form.
        g_loss = tf.reduce_mean(tf.nn.softplus(-fake_logit))

        g_losses.append(g_loss)
        d_losses.append(d_loss)

    g_var = [v for v in tf.trainable_variables() if "Generator" in v.name]
    d_var = [v for v in tf.trainable_variables() if "Discriminator" in v.name]
    opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.0, beta2=0.99,
                                 epsilon=1e-8)

    g_opt = [opt.minimize(loss, var_list=g_var) for loss in g_losses]
    d_opt = [opt.minimize(loss, var_list=d_var) for loss in d_losses]

    printParam(scope="Generator")
    printParam(scope="Discriminator")

    start = time.time()

    config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.75))

    # Initialise once; a restored checkpoint simply overwrites the values,
    # so there is no need to run the initializer a second time.
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # most recently saved model
        print ("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    else:
        print("models were not found")

    print("%.4f sec took initializing" % (time.time() - start))

    # ---- training -------------------------------------------------------
    for stage in range(num_stages):
        datadir = "ffhq_dataset128" if stage < 6 else "ffhq_dataset"
        batch = BatchGenerator(img_size=img_size[stage], datadir=datadir)

        # Directory for this stage's samples and loss plot; created once
        # here instead of on every logging step.
        outdir = os.path.join(SVIM_DIR, 'stage{}'.format(stage + 1))
        os.makedirs(outdir, exist_ok=True)

        # Save a tile of real samples at this stage's resolution.
        x_batch = batch.getBatch(bs[stage], alpha=1.0)
        out = tileImage(x_batch)
        out = np.array((out + 1) * 127.5, dtype=np.uint8)
        cv2.imwrite(os.path.join(outdir, 'sample.png'), out)

        trans_lr = 1e-3
        g_hist = []
        d_hist = []
        print("starting stage{}".format(stage + 1))
        for i in range(steps[stage] + 1):
            # alpha ramps from 0 to 1 over the first quarter of the stage.
            delta = 4 * i / (steps[stage] + 1)
            # NOTE(review): `stage` is 0-based, so this pins alpha to 1.0 for
            # the 2nd and 3rd stages, while the original comment said the
            # *first* stage needs no interpolation.  Left as is - confirm.
            if stage == 1 or stage == 2:
                alp = 1.0
            else:
                alp = min(delta, 1.0)

            x_batch = batch.getBatch(bs[stage], alpha=alp)

            # Discriminator step.
            z_batch = np.random.normal(0, 0.5, [bs[stage], z_dim])
            _, dis_loss = sess.run(
                [d_opt[stage], d_losses[stage]],
                feed_dict={X_real[stage]: x_batch, z: z_batch,
                           alpha: alp, lr: trans_lr})

            # Generator step with a fresh latent batch.
            z_batch = np.random.normal(0, 0.5, [bs[stage], z_dim])
            _, gen_loss = sess.run(
                [g_opt[stage], g_losses[stage]],
                feed_dict={z: z_batch, alpha: alp, lr: trans_lr})

            g_hist.append(gen_loss)
            d_hist.append(dis_loss)

            print("stage:[%d], in step %s, dis_loss = %.3e, gen_loss = %.3e, alpha = %.3f, lr = %.3e"
                    %(stage+1, i,dis_loss, gen_loss, alp, trans_lr))

            if alp == 1.0:
                # Decay the learning rate once the fade-in is finished.
                # Clamped at 0 so a very short stage (steps <= 2) can never
                # flip the learning rate negative.
                trans_lr *= max(0.0, 1 - 2 / steps[stage])

            if i % 100 == 0:
                z_batch = np.random.normal(0, 0.5, [bs[stage], z_dim])
                out = sess.run(X_fake[stage],
                               feed_dict={z: z_batch, alpha: alp})
                out = tileImage(out)
                out = np.array((out + 1) * 127.5, dtype=np.uint8)
                dst = os.path.join(outdir, '{0:09d}_alp.png'.format(i))
                cv2.imwrite(dst, out)

                fig = plt.figure(figsize=(8, 6), dpi=128)
                ax = fig.add_subplot(111)
                plt.title("Loss")
                plt.grid(which="both")
                plt.yscale("log")
                ax.plot(g_hist, label="gen_loss", linewidth=0.25)
                ax.plot(d_hist, label="dis_loss", linewidth=0.25)
                plt.xlabel('step', fontsize=16)
                plt.ylabel('loss', fontsize=16)
                plt.legend(loc='upper right')
                plt.savefig(os.path.join(outdir, "hist.png"))
                plt.close()

            if i % 8000 == 0 and i != 0:
                # NOTE: `i` restarts at 0 for every stage, so checkpoints of
                # different stages share the same step suffixes.
                saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
Пример #7
0
def main():
    """Train the generator against ``loss_g`` and save progress.

    Restores the latest checkpoint from ``SAVE_DIR`` when one exists, then
    runs up to 100000 training steps.  Every 100 steps a validation sample
    and the loss curve are written; every 1000 steps a checkpoint is saved,
    unless the mean loss of the last 1000 steps is more than twice that of
    the 1000 steps before, in which case training stops.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    img_size = 256
    bs = 16

    datalen = foloderLength(DATASET_DIR)
    vallen = foloderLength(VAL_DIR)

    # Separate names for the two generators (the original reused `val` for
    # both the path and the generator, and shadowed the builtins `dir`/`id`).
    train_gen = BatchGenerator(img_size=img_size, datadir=DATASET_DIR)
    val_gen = BatchGenerator(img_size=img_size, datadir=VAL_DIR)

    # Save a tile of the first four training images for reference.
    # np.random.choice(n, k) samples from arange(n), same as choice(range(n)).
    ids = np.random.choice(datalen, bs)
    IN_ = tileImage(train_gen.getBatch(bs, ids)[:4])

    # Rescale to 0..255, assuming images are in [-1, 1].
    IN_ = (IN_ + 1) * 127.5
    cv2.imwrite("input.png", IN_)

    start = time.time()

    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])

    y = buildGenerator(x, nBatch=bs)

    loss = loss_g(y, t)
    printParam(scope="generator")

    train_step = training(loss)

    # Initialise once; restoring a checkpoint overwrites the values, so the
    # initializer does not need to run again afterwards.
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)

    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # most recently saved model
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    else:
        print("models were not found")

    print("%.4e sec took initializing" % (time.time() - start))

    hist = []

    start = time.time()
    for i in range(100000):
        # Input and target are two separate draws of the same image ids.
        # NOTE(review): both use ocp=0.5 - confirm the target should too.
        ids = np.random.choice(datalen, bs)
        batch_images_x = train_gen.getBatch(bs, ids, ocp=0.5)
        batch_images_t = train_gen.getBatch(bs, ids, ocp=0.5)

        _, yloss = sess.run([train_step, loss],
                            feed_dict={
                                x: batch_images_x,
                                t: batch_images_t
                            })

        print("in step %s loss = %.4e" % (i, yloss))
        hist.append(yloss)

        if i % 100 == 0:
            # Write validation input next to the generator output.
            ids = np.random.choice(vallen, bs)
            batch_images_x = val_gen.getBatch(bs, ids, ocp=0.5)
            out = sess.run(y, feed_dict={x: batch_images_x})
            X_ = tileImage(batch_images_x[:4])
            Y_ = tileImage(out[:4])

            X_ = (X_ + 1) * 127.5
            Y_ = (Y_ + 1) * 127.5
            Z_ = np.concatenate((X_, Y_), axis=1)
            cv2.imwrite("{}/{}.png".format(SVIM_DIR, i), Z_)

            fig = plt.figure()
            ax = fig.add_subplot(111)
            plt.title("Loss")
            plt.grid(which="both")
            plt.yscale("log")
            ax.plot(hist, label="test", linewidth=0.5)
            plt.savefig("hist.png")
            plt.close()

            print("%.4e sec took per 100steps" % (time.time() - start))
            start = time.time()

        if i % 1000 == 0:
            if i > 1900:
                # Stop (without saving) when the loss has more than doubled
                # between the last two 1000-step windows.
                loss_1k_old = np.mean(hist[-2000:-1000])
                loss_1k_new = np.mean(hist[-1000:])
                print("old loss=%.4e , new loss=%.4e" %
                      (loss_1k_old, loss_1k_new))
                if loss_1k_old * 2 < loss_1k_new:
                    break

            saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
def main():
    """Train the conditional GAN (generator + discriminator) with an L1 term.

    The generator minimises sigmoid cross-entropy GAN loss plus
    ``lmd * L1`` plus its regularisation losses; the discriminator minimises
    the real/fake cross-entropy plus its regularisation losses, at one fifth
    of the generator learning rate.  Both the learning rate and ``lmd`` are
    decayed in Python and fed through placeholders.  Every 100 steps a
    validation sample and the loss curves are written; every 5000 steps a
    checkpoint is saved unless the generator loss has more than doubled
    between the last two 1000-step windows, in which case training stops.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    img_size = 256
    bs = 4
    lr = tf.placeholder(tf.float32)
    lmd = tf.placeholder(tf.float32)

    trans_lr = 2e-4
    trans_lmd = 10
    max_step = 100000

    datalen = foloderLength(DATASET_DIR)
    vallen = foloderLength(VAL_DIR)

    # Batch generators for training and validation images.
    batch = BatchGenerator(img_size=img_size, datadir=DATASET_DIR)
    val = BatchGenerator(img_size=img_size, datadir=VAL_DIR)

    # Save a tile of input/target training pairs for reference.  The slice
    # is applied to the image arrays; the original sliced the returned
    # (input, target) pair itself, which had no effect.
    # np.random.choice(n, k) samples from arange(n), same as choice(range(n)).
    ids = np.random.choice(datalen, bs)
    IN_, OUT_ = batch.getBatch(bs, ids)
    # Rescale to 0..255, assuming images are in [-1, 1].
    IN_ = tileImage((IN_[:4] + 1) * 127.5)
    OUT_ = tileImage((OUT_[:4] + 1) * 127.5)
    Z_ = np.concatenate([IN_, OUT_], axis=1)
    cv2.imwrite("input.png", Z_)

    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])

    y = buildGenerator(x)
    fake_y = buildDiscriminator(x, y, isTraining=True, nBatch=bs)
    real_y = buildDiscriminator(x, t, reuse=True, isTraining=True, nBatch=bs)

    # Sigmoid cross-entropy GAN losses.
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=real_y,
                                                labels=tf.ones_like(real_y)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_y,
                                                labels=tf.zeros_like(fake_y)))
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_y,
                                                labels=tf.ones_like(fake_y)))

    # Regularisation losses registered under each network's scope.
    wd_g = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                             scope="Generator")
    wd_d = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                             scope="Discriminator")

    wd_g = tf.reduce_sum(wd_g)
    wd_d = tf.reduce_sum(wd_d)

    L1_loss = tf.reduce_mean(tf.abs(y - t))

    d_loss = d_loss_real + d_loss_fake + wd_d
    g_loss = g_loss + lmd * L1_loss + wd_g

    g_var = [v for v in tf.trainable_variables() if "Generator" in v.name]
    d_var = [v for v in tf.trainable_variables() if "Discriminator" in v.name]
    g_opt = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_loss,
                                                           var_list=g_var)
    d_opt = tf.train.AdamOptimizer(lr / 5, beta1=0.5).minimize(d_loss,
                                                               var_list=d_var)

    printParam(scope="Generator")
    printParam(scope="Discriminator")

    start = time.time()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        per_process_gpu_memory_fraction=0.66))

    # The GPU options above were previously built but never handed to the
    # session, so the memory cap was silently ignored.
    sess = tf.Session(config=config)
    # Initialise once; restoring a checkpoint overwrites the values.
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()

    # Look for checkpoints where this function saves them.  The original
    # searched the hard-coded directory 'model' while saving to SAVE_DIR.
    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)

    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # most recently saved model
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    else:
        print("models were not found")

    print("%.4e sec took initializing" % (time.time() - start))

    g_hist = []
    d_hist = []

    start = time.time()

    for i in range(max_step + 1):
        ids = np.random.choice(datalen, bs)
        batch_images_x, batch_images_t = batch.getBatch(bs, ids)
        feed = {
            x: batch_images_x,
            t: batch_images_t,
            lr: trans_lr,
            lmd: trans_lmd
        }

        # One discriminator update followed by one generator update on the
        # same batch.
        _, dis_loss = sess.run([d_opt, d_loss], feed_dict=feed)
        _, gen_loss, l1 = sess.run([g_opt, g_loss, L1_loss], feed_dict=feed)

        # Weight the logged L1 with the lambda that was used for this step,
        # i.e. before it is decayed below.
        weighted_l1 = l1 * trans_lmd

        # Decay learning rate and L1 weight down to their floors.
        if trans_lr > 5e-5:
            trans_lr = trans_lr * 0.99998
        if trans_lmd > 5:
            trans_lmd = trans_lmd * 0.9998

        print("in step %s, dis_loss = %.4e, gen_loss = %.4e, l1_loss= %.4e" %
              (i, dis_loss, gen_loss, weighted_l1))
        g_hist.append(gen_loss)
        d_hist.append(dis_loss)

        if i % 100 == 0:
            # Write validation input | generator output | target.
            ids = np.random.choice(vallen, bs)
            batch_images_x, batch_images_t = val.getBatch(bs, ids)
            out = sess.run(y, feed_dict={x: batch_images_x})
            X_ = tileImage(batch_images_x[:4])
            Y_ = tileImage(out[:4])
            Z_ = tileImage(batch_images_t[:4])

            X_ = (X_ + 1) * 127.5
            Y_ = (Y_ + 1) * 127.5
            Z_ = (Z_ + 1) * 127.5
            Z_ = np.concatenate((X_, Y_, Z_), axis=1)
            cv2.imwrite("{}/{}.png".format(SVIM_DIR, i), Z_)

            fig = plt.figure(figsize=(8, 6), dpi=128)
            ax = fig.add_subplot(111)
            plt.title("Loss")
            plt.grid(which="both")
            plt.yscale("log")
            ax.plot(g_hist, label="gen_loss", linewidth=0.25)
            ax.plot(d_hist, label="dis_loss", linewidth=0.25)
            plt.xlabel('step', fontsize=16)
            plt.ylabel('loss', fontsize=16)
            plt.legend(loc='upper right')
            plt.savefig("hist.png")
            plt.close()

            print("%.4f sec took per 100steps lmd = %.4e, lr = %.4e" %
                  (time.time() - start, trans_lmd, trans_lr))
            start = time.time()

        if i % 5000 == 0:
            if i > 10000:
                # Stop (without saving) when the generator loss has more
                # than doubled between the last two 1000-step windows.
                loss_1k_old = np.mean(g_hist[-2000:-1000])
                loss_1k_new = np.mean(g_hist[-1000:])
                print("old loss=%.4e , new loss=%.4e" %
                      (loss_1k_old, loss_1k_new))
                if loss_1k_old * 2 < loss_1k_new:
                    break

            saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
def main():
    """Pre-train the ThermalSR generator with L1, SSIM and contextual losses.

    Parses ``--gpu`` to select the visible CUDA device, builds the training
    and inference graphs, restores the latest checkpoint from
    ``SAVEPRE_DIR`` when one exists, and runs 50001 training steps.  Every
    100 steps a validation comparison image is written, every 1000 steps the
    loss curve is plotted, and every 5000 steps (except step 0) a checkpoint
    is saved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu',
                        type=str,
                        default='0',
                        help='Which GPU to use')
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    img_size = 64
    bs = 4
    trans_lr = 1e-4

    batchgen = BatchGenerator(img_size=img_size,
                              LRDir=TRAIN_LR_DIR,
                              HRDir=TRAIN_HR_DIR,
                              aug=True)
    valgen = BatchGenerator(img_size=img_size,
                            LRDir=VAL_LR_DIR,
                            HRDir=VAL_HR_DIR,
                            aug=False)

    start = time.time()

    # Low-resolution input and 4x larger high-resolution target.
    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size * 4, img_size * 4, 3])

    generator = Generator()

    y = generator.ThermalSR(x)
    test_y = generator.ThermalSR(x, reuse=True, isTraining=False)

    # Contextual loss on VGG19 features of target and prediction.
    vgg_real34, vgg_real54 = build_vgg19(t)
    vgg_fake34, vgg_fake54 = build_vgg19(y)

    CX_loss_content_list = CX_loss_helper(vgg_real34, vgg_fake34, config.CX)
    CX_content_loss = tf.reduce_sum(CX_loss_content_list)
    CX_content_loss *= config.W.CX_content

    L1_loss = tf.losses.absolute_difference(y, t)
    # Mean SSIM with max_val=2.0; this is a similarity, turned into a loss
    # by 1 - ssim below.
    ssim_loss = tf.reduce_mean(tf.image.ssim(y, t, 2.0))

    ssim_loss1 = 1 - ssim_loss
    Total_loss = 10 * L1_loss + 10 * ssim_loss1 + 0.1 * CX_content_loss

    # Only the generator's variables are optimised.  The op returned by
    # minimize() is a training op, not a loss, hence the name.
    g_vars = [v for v in tf.trainable_variables() if "ThermalSR" in v.name]
    train_op = tf.train.AdamOptimizer(trans_lr, beta1=0.9).minimize(
        Total_loss, var_list=g_vars)

    print("%.4f sec took building" % (time.time() - start))
    printParam(scope="ThermalSR")

    saver = tf.train.Saver()

    # Initialise once; restoring a checkpoint overwrites the values, so the
    # initializer does not need to run again afterwards.
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    ckpt = tf.train.get_checkpoint_state(SAVEPRE_DIR)

    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")

    print("%.4e sec took initializing" % (time.time() - start))

    hist = []

    start = time.time()
    print("start pretrain")
    for p in range(50001):
        batch_images_x, batch_images_t = batchgen.getBatch(bs)
        _, gen_loss, l1, ssim, cx = sess.run(
            [train_op, Total_loss, L1_loss, ssim_loss, CX_content_loss],
            feed_dict={
                x: batch_images_x,
                t: batch_images_t
            })

        hist.append(gen_loss)
        # NOTE: the value printed as "ssim_loss" is the mean SSIM itself,
        # not 1 - SSIM.
        print(
            "in step %s, pre_loss =%.4e, l1_loss=%.4e, ssim_loss=%.4e, cx_loss=%.4e"
            % (p, gen_loss, l1, ssim, cx))

        if p % 100 == 0:
            batch_images_x, batch_images_t = valgen.getBatch(bs)

            out = sess.run(test_y, feed_dict={x: batch_images_x})
            X_ = tileImage(batch_images_x[:4])
            Y_ = tileImage(out[:4])
            Z_ = tileImage(batch_images_t[:4])

            # Upscale the low-resolution tile so it can be concatenated with
            # the output and target tiles.
            X_ = cv2.resize(X_, (img_size * 2 * 4, img_size * 2 * 4),
                            interpolation=cv2.INTER_CUBIC)

            # Rescale to 0..255, assuming images are in [-1, 1].
            X_ = (X_ + 1) * 127.5
            Y_ = (Y_ + 1) * 127.5
            Z_ = (Z_ + 1) * 127.5
            ZZ_ = np.concatenate((X_, Y_, Z_), axis=1)

            cv2.imwrite("{0}/pre_{1:06d}.png".format(SAVEIM_DIR, int(p)), ZZ_)

            print("%.4e sec took 100steps" % (time.time() - start))
            start = time.time()

        if p % 1000 == 0:
            fig = plt.figure(figsize=(8, 6), dpi=128)
            ax = fig.add_subplot(111)
            plt.title("Loss")
            plt.grid(which="both")
            plt.yscale("log")
            ax.plot(hist, label="gen_loss", linewidth=0.25)
            plt.xlabel('step', fontsize=16)
            plt.ylabel('loss', fontsize=16)
            plt.legend(loc='upper right')
            plt.savefig("hist_pre_ThermalSR_Axis.png")
            plt.close()

        if p % 5000 == 0 and p != 0:
            saver.save(sess, os.path.join(SAVEPRE_DIR, "model.ckpt"), p)