Example #1
def train():
    G = get_G((batch_size, 96, 96, 3))
    D = get_D((batch_size, 384, 384, 3))
    VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

    G.train()
    D.train()
    VGG.train()

    train_ds = get_train_data()

    ## initialize learning (G)
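    # pre-train G with a pixel-wise MSE loss (g_optimizer / d_optimizer are used in a later adversarial phase)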
    n_step_epoch = round(n_epoch_init // batch_size)
    for epoch in range(n_epoch_init):
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            with tf.GradientTape() as tape:
                fake_hr_patchs = G(lr_patchs)
                mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
            grad = tape.gradient(mse_loss, G.trainable_weights)
            g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
                epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss))
        if epoch == n_epoch_init - 1:
            tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir1, 'train_g_init_{}.png'.format(epoch)))
Example #2
def train():
    G = get_G((batch_size, 96, 96, 3))
    D = get_D((batch_size, 384, 384, 3))
    VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

    G.train()
    D.train()
    VGG.train()

    train_ds = get_train_data()

    ## initialize learning (G)
    n_step_epoch = round(n_epoch_init // batch_size)
    for epoch in range(n_epoch_init):
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            with tf.GradientTape() as tape:
                fake_hr_patchs = G(lr_patchs)
                mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
            grad = tape.gradient(mse_loss, G.trainable_weights)
            g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
                epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss))
        if (epoch != 0) and (epoch % 10 == 0):
            tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_init_{}.png'.format(epoch)))

    ## adversarial learning (G, D)
    n_step_epoch = round(n_epoch // batch_size)
    for epoch in range(n_epoch):
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
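            # persistent=True so tape.gradient() can be called twice below (once for G, once for D)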
            with tf.GradientTape(persistent=True) as tape:
                fake_patchs = G(lr_patchs)
                logits_fake = D(fake_patchs)
                logits_real = D(hr_patchs)
                feature_fake = VGG((fake_patchs+1)/2.) # the pre-trained VGG uses the input range of [0, 1]
                feature_real = VGG((hr_patchs+1)/2.)
                d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
                d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
                d_loss = d_loss1 + d_loss2
                g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
                mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
                vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)
                g_loss = mse_loss + vgg_loss + g_gan_loss
            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
            grad = tape.gradient(d_loss, D.trainable_weights)
            d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format(
                epoch, n_epoch, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss))

        # update the learning rate
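        # exponential step decay: lr = lr_init * lr_decay ** (epoch // decay_every)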
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            lr_v.assign(lr_init * new_lr_decay)
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            tl.vis.save_images(fake_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_{}.png'.format(epoch)))
            G.save_weights(os.path.join(checkpoint_dir, 'g.h5'))
            D.save_weights(os.path.join(checkpoint_dir, 'd.h5'))
Example #3
def train():
    # create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    # load dataset
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    # valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    # valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ## If your machine has enough memory, pre-load the whole training set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)

    # dataset API and augmentation
    def generator_train():
        for img in train_hr_imgs:
            yield img

    def _map_fn_train(img):
        hr_patch = tf.image.random_crop(img, [384, 384, 3])
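        # rescale pixel values from [0, 255] to [-1, 1]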
        hr_patch = hr_patch / (255. / 2.)
        hr_patch = hr_patch - 1.
        hr_patch = tf.image.random_flip_left_right(hr_patch)
        lr_patch = tf.image.resize(hr_patch, size=[96, 96])
        return lr_patch, hr_patch

    train_ds = tf.data.Dataset.from_generator(generator_train,
                                              output_types=(tf.float32))
    train_ds = train_ds.map(_map_fn_train,
                            num_parallel_calls=multiprocessing.cpu_count())
    train_ds = train_ds.repeat(n_epoch_init + n_epoch)
    train_ds = train_ds.shuffle(shuffle_buffer_size)
    train_ds = train_ds.prefetch(buffer_size=4096)
    train_ds = train_ds.batch(batch_size)
    # value = train_ds.make_one_shot_iterator().get_next()

    # obtain models
    G = get_G((batch_size, None, None, 3))  # (None, 96, 96, 3)
    D = get_D((batch_size, None, None, 3))  # (None, 384, 384, 3)
    VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

    print(G)
    print(D)
    print(VGG)

    # G.load_weights(checkpoint_dir + '/g_{}.h5'.format(tl.global_flag['mode'])) # in case you want to restore a training
    # D.load_weights(checkpoint_dir + '/d_{}.h5'.format(tl.global_flag['mode']))

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(
        lr_v, beta_1=beta1)  #.minimize(mse_loss, var_list=g_vars)
    g_optimizer = tf.optimizers.Adam(
        lr_v, beta_1=beta1)  #.minimize(g_loss, var_list=g_vars)
    d_optimizer = tf.optimizers.Adam(
        lr_v, beta_1=beta1)  #.minimize(d_loss, var_list=d_vars)

    G.train()
    D.train()
    VGG.train()

    # initialize learning (G)
    n_step_epoch = round(n_epoch_init // batch_size)
    for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
        step_time = time.time()
        with tf.GradientTape() as tape:
            fake_hr_patchs = G(lr_patchs)
            mse_loss = tl.cost.mean_squared_error(fake_hr_patchs,
                                                  hr_patchs,
                                                  is_mean=True)
        grad = tape.gradient(mse_loss, G.trainable_weights)
        g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
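        # train_ds is repeated, so derive the epoch index from the global step count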
        step += 1
        epoch = step // n_step_epoch
        print("Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {} ".format(
            epoch, n_epoch_init, step, n_step_epoch,
            time.time() - step_time, mse_loss))
        if (epoch != 0) and (epoch % 10 == 0):
            tl.vis.save_images(
                fake_hr_patchs.numpy(), [ni, ni],
                save_dir_ginit + '/train_g_init_{}.png'.format(epoch))

    # adversarial learning (G, D)
    n_step_epoch = round(n_epoch // batch_size)
    for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
        step_time = time.time()
        with tf.GradientTape(persistent=True) as tape:
            fake_patchs = G(lr_patchs)
            logits_fake = D(fake_patchs)
            logits_real = D(hr_patchs)
            feature_fake = VGG((fake_patchs + 1) / 2.)
            feature_real = VGG((hr_patchs + 1) / 2.)
            d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                                    tf.ones_like(logits_real))
            d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                                    tf.zeros_like(logits_fake))
            d_loss = d_loss1 + d_loss2
            g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
                logits_fake, tf.ones_like(logits_fake))
            mse_loss = tl.cost.mean_squared_error(fake_patchs,
                                                  hr_patchs,
                                                  is_mean=True)
            vgg_loss = 2e-6 * tl.cost.mean_squared_error(
                feature_fake, feature_real, is_mean=True)
            g_loss = mse_loss + vgg_loss + g_gan_loss
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        del tape  # release the persistent tape
        step += 1
        epoch = step // n_step_epoch
        print(
            "Epoch: [{}/{}] step: [{}/{}] time: {}s, g_loss(mse:{}, vgg:{}, adv:{}) d_loss: {}"
            .format(epoch, n_epoch, step, n_step_epoch,
                    time.time() - step_time, mse_loss, vgg_loss, g_gan_loss,
                    d_loss))

        # update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            lr_v.assign(lr_init * new_lr_decay)
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            tl.vis.save_images(fake_patchs.numpy(), [ni, ni],
                               save_dir_gan + '/train_g_{}.png'.format(epoch))
            G.save_weights(checkpoint_dir +
                           '/g_{}.h5'.format(tl.global_flag['mode']))
            D.save_weights(checkpoint_dir +
                           '/d_{}.h5'.format(tl.global_flag['mode']))
Example #4
def train():
    images, images_path = get_Chairs(
        flags.output_size, flags.n_epoch, flags.batch_size)
    G = get_G([None, flags.z_dim])
    D = get_D([None, flags.output_size, flags.output_size, flags.n_channel])
    Q = get_Q([None, 1024])

    G.train()
    D.train()
    Q.train()

    g_optimizer = tf.optimizers.Adam(
        learning_rate=flags.G_learning_rate, beta_1=0.5)
    d_optimizer = tf.optimizers.Adam(
        learning_rate=flags.D_learning_rate, beta_1=0.5)

    n_step_epoch = int(len(images_path) // flags.batch_size)
    his_g_loss = []
    his_d_loss = []
    his_mutual = []
    count = 0

    for epoch in range(flags.n_epoch):
        for step, batch_images in enumerate(images):
            count += 1
            if batch_images.shape[0] != flags.batch_size:
                break
            step_time = time.time()
            with tf.GradientTape(persistent=True) as tape:
                noise, cat1, cat2, cat3, con = gen_noise()
                fake_logits, mid = D(G(noise))
                real_logits, _ = D(batch_images)
                f_cat1, f_cat2, f_cat3, f_mu = Q(mid)
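                # InfoGAN-style: Q predicts the latent codes back from D's intermediate features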

                # base = tf.random.normal(shape=f_mu.shape)
                # f_con = f_mu + base * tf.exp(f_sigma)
                d_loss_fake = tl.cost.sigmoid_cross_entropy(
                    output=fake_logits, target=tf.zeros_like(fake_logits), name='d_loss_fake')
                d_loss_real = tl.cost.sigmoid_cross_entropy(
                    output=real_logits, target=tf.ones_like(real_logits), name='d_loss_real')
                d_loss = d_loss_fake + d_loss_real

                g_loss_tmp = tl.cost.sigmoid_cross_entropy(
                    output=fake_logits, target=tf.ones_like(fake_logits), name='g_loss_fake')

                mutual_disc = calc_disc_mutual(
                    f_cat1, f_cat2, f_cat3, cat1, cat2, cat3)
                mutual_cont = calc_cont_mutual(f_mu, con)
                mutual = (flags.disc_lambda*mutual_disc +
                          flags.cont_lambda*mutual_cont)
                g_loss = mutual + g_loss_tmp
                d_tr = d_loss + mutual

            grads = tape.gradient(
                g_loss, G.trainable_weights + Q.trainable_weights)  # must be differentiable w.r.t. these weights
            g_optimizer.apply_gradients(
                zip(grads, G.trainable_weights + Q.trainable_weights))
            grads = tape.gradient(
                d_tr, D.trainable_weights)
            d_optimizer.apply_gradients(
                zip(grads, D.trainable_weights))
            del tape

            print("Epoch: [{}/{}] [{}/{}] took: {}, d_loss: {:.5f}, g_loss: {:.5f}, mutual: {:.5f}".format(
                epoch, flags.n_epoch, step, n_step_epoch, time.time()-step_time, d_loss, g_loss, mutual))

            if count % flags.save_every_it == 1:
                his_g_loss.append(g_loss)
                his_d_loss.append(d_loss)
                his_mutual.append(mutual)

        plt.plot(his_d_loss)
        plt.plot(his_g_loss)
        plt.plot(his_mutual)
        plt.legend(['D_Loss', 'G_Loss', 'Mutual_Info'])
        plt.xlabel(f'Iterations / {flags.save_every_it}')
        plt.ylabel('Loss')
        plt.savefig(f'{flags.result_dir}/loss.jpg')
        plt.clf()
        plt.close()

        G.save_weights(f'{flags.checkpoint_dir}/G.npz', format='npz')
        D.save_weights(f'{flags.checkpoint_dir}/D.npz', format='npz')
        G.eval()
        for k in range(flags.n_samples):
            z = gen_eval_noise(flags.save_every_epoch, flags.n_samples)
            result = G(z)
            tl.visualize.save_images(result.numpy(), [
                                     flags.save_every_epoch, flags.n_samples], f'result/train_{epoch}_{k}.png')
        G.train()
Example #5
def main(config):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_loader = get_dataloader(config)
    G = get_G()
    D = get_D()
    G.to(device)
    D.to(device)
    G = nn.DataParallel(G)
    D = nn.DataParallel(D)
    optimizer_G = torch.optim.Adam(G.parameters(),
                                   lr=config.lr,
                                   betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(D.parameters(),
                                   lr=config.lr,
                                   betas=(0.5, 0.999))

    criterionGAN = GANLoss(use_lsgan=True).to(device)
    criterionL1 = nn.L1Loss()

    make_dir(config.result_dir)
    make_dir(config.sample_dir)
    make_dir(config.model_dir)
    make_dir(config.log_dir)
    total_steps = 0

    for epoch in range(config.epoch_count,
                       config.niter + config.niter_decay + 1):
        SAVE_IMAGE_DIR = "{}/{}".format(config.sample_dir, epoch)
        make_dir(SAVE_IMAGE_DIR)

        for i, (real_A, real_B) in enumerate(
                DataLoader(data_loader,
                           batch_size=config.batch_size,
                           shuffle=True)):

            real_A = real_A.to(device)
            real_B = real_B.to(device)

            ### Making fake B image
            fake_B = G(real_A)

            ### Update D
            ## Set gradients
            set_requires_grad(D, True)
            ## Optimizer D
            optimizer_D.zero_grad()
            ## Backward
            # Fake
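            # detach() blocks gradients from flowing into G during the D update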
            pred_fake = D(fake_B.detach())
            loss_D_fake = criterionGAN(pred_fake, False)
            # Real
            pred_real = D(real_B.detach())
            loss_D_real = criterionGAN(pred_real, True)
            # Combined loss
            loss_D = (loss_D_fake + loss_D_real) * 0.5
            loss_D.backward()
            ## Optimizer step
            optimizer_D.step()

            ### Update G
            ## Set gradients
            set_requires_grad(D, False)
            ## Optimizer G
            optimizer_G.zero_grad()
            ## Backward
            pred_fake = D(fake_B)
            loss_G_GAN = criterionGAN(pred_fake, True)
            loss_G_L1 = criterionL1(fake_B, real_B) * config.lambda_L1
            loss_G = loss_G_GAN + loss_G_L1
            loss_G.backward()
            ## Optimizer step
            optimizer_G.step()

            if total_steps % config.print_freq == 0:
                # if total_steps % 1 == 0:
                # Print
                print(
                    "Loss D Fake:{:.4f}, D Real:{:.4f}, D Total:{:.4f}, G GAN:{:.4f}, G L1:{:.4f}. G Total:{:.4f}"
                    .format(loss_D_fake, loss_D_real, loss_D, loss_G_GAN,
                            loss_G_L1, loss_G))
                # Save image
                save_image(fake_B, "{}/{}.png".format(SAVE_IMAGE_DIR, i))
            total_steps += 1

        if epoch % config.save_epoch_freq == 0:
            # Save model
            print("Save models in {} epochs".format(epoch))
            save_checkpoint("{}/D_{}.pth".format(config.model_dir, epoch), D,
                            optimizer_D)
            save_checkpoint("{}/G_{}.pth".format(config.model_dir, epoch), G,
                            optimizer_G)
Example #6
def train_adv():
    # with tf.device('/cpu:0'):

    '''initialize model'''
    G = model.get_G((config.batch_size_adv, 56, 56, 3))
    D = model.get_D((config.batch_size_adv, 224, 224, 3))
    vgg22 = model.VGG16_22((config.batch_size_adv, 224, 224, 3))
    
    G.load_weights(os.path.join(config.path_model, 'g_init.h5'))
    '''optimizer'''
    #g_optimizer_init = tf.optimizers.Adam(learning_rate=0.001)
    g_optimizer = tf.optimizers.Adam(learning_rate=0.0001)
    d_optimizer = tf.optimizers.Adam(learning_rate=0.0001)

    G.train()
    D.train()
    vgg22.train()
    train_ds, train_len = config.get_train_data(config.batch_size_adv, step=config.train_step_adv, start=0)
    print('The training set contains {} images in total'.format(train_len))
    '''initialize generator with L1 loss in pixel space'''
    
    '''train with GAN and vgg16-22 loss'''
    n_step_epoch = round(train_len // config.batch_size_adv)
    for epoch in range(config.n_epoch_adv):
        # accumulated losses over one epoch, initialized to 0
        mse_ls = 0; vgg_ls = 0; gan_ls = 0; d_ls = 0
        # step counter
        i = 0
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            step_time = time.time()
            with tf.GradientTape(persistent=True) as tape:
                  
                fake_patchs = G(lr_patchs)
                feature22_fake = vgg22(fake_patchs) # the pre-trained VGG uses the input range of [0, 1]
                feature22_real = vgg22(hr_patchs)
                logits_fake = D(fake_patchs)
                logits_real = D(hr_patchs)
                
                #g_vgg_loss = 2e-3 * tl.cost.mean_squared_error(feature22_fake, feature22_real, is_mean=True)
                #g_gan_loss = -tf.reduce_mean(logits_fake)  
                d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
                d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
                d_loss = d_loss1 + d_loss2
                
                g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
                mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
                vgg_loss = 1e-4 * tl.cost.mean_squared_error(feature22_fake, feature22_real, is_mean=True)
                g_loss = mse_loss + vgg_loss + g_gan_loss
                
                mse_ls+=mse_loss
                vgg_ls+=vgg_loss
                gan_ls+=g_gan_loss
                d_ls+=d_loss
                i+=1
                
                ''' WGAN-GP (unfinished)
                d_loss = tf.reduce_mean(logits_fake) - tf.reduce_mean(logits_real)
                g_loss = g_vgg_loss + g_gan_loss 
                eps = tf.random.uniform([batch_size, 1, 1, 1], minval=0., maxval=1.)
                interpolates  = eps*hr_patchs + (1. - eps)*fake_patchs
                grad = tape.gradient(D(interpolates), interpolates)
                slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1,2,3]))
                gradient_penalty = 0.1*tf.reduce_mean((slopes-1.)**2)
                
                d_loss += gradient_penalty
                '''
            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
            grad = tape.gradient(d_loss, D.trainable_weights)
            d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
            del tape
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.5f}, vgg:{:.5f}, adv:{:.5f}), d_loss: {:.5f}".format(
                        epoch, config.n_epoch_adv, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss))
        print('~~~~~~~~~~~~ Epoch {} average losses ~~~~~~~~~~~~'.format(epoch))
        print("Epoch: [{}/{}] time: {:.3f}s, g_loss(mse:{:.5f}, vgg:{:.5f}, adv:{:.5f}), d_loss: {:.5f}".format(
                        epoch, config.n_epoch_adv, time.time() - step_time, mse_ls/i, vgg_ls/i, gan_ls/i, d_ls/i))
        G.save_weights(os.path.join(config.path_model, 'g_adv.h5')) 
        print('\n')
Example #7
def train():
    images, images_path = get_celebA(flags.output_size, flags.n_epoch,
                                     flags.batch_size)
    G = get_G([None, flags.dim_z])
    Base = get_base(
        [None, flags.output_size, flags.output_size, flags.n_channel])
    D = get_D([None, 4096])
    Q = get_Q([None, 4096])

    G.train()
    Base.train()
    D.train()
    Q.train()

    g_optimizer = tf.optimizers.Adam(learning_rate=flags.G_learning_rate,
                                     beta_1=flags.beta_1)
    d_optimizer = tf.optimizers.Adam(learning_rate=flags.D_learning_rate,
                                     beta_1=flags.beta_1)

    n_step_epoch = int(len(images_path) // flags.batch_size)
    his_g_loss = []
    his_d_loss = []
    his_mutual = []
    count = 0

    for epoch in range(flags.n_epoch):
        for step, batch_images in enumerate(images):
            count += 1
            if batch_images.shape[0] != flags.batch_size:
                break
            step_time = time.time()
            with tf.GradientTape(persistent=True) as tape:
                z, c = gen_noise()
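                # Base extracts shared features that are fed to both D and Q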
                fake = Base(G(z))
                fake_logits = D(fake)
                fake_cat = Q(fake)
                real_logits = D(Base(batch_images))

                d_loss_fake = tl.cost.sigmoid_cross_entropy(
                    output=fake_logits,
                    target=tf.zeros_like(fake_logits),
                    name='d_loss_fake')
                d_loss_real = tl.cost.sigmoid_cross_entropy(
                    output=real_logits,
                    target=tf.ones_like(real_logits),
                    name='d_loss_real')
                d_loss = d_loss_fake + d_loss_real

                g_loss = tl.cost.sigmoid_cross_entropy(
                    output=fake_logits,
                    target=tf.ones_like(fake_logits),
                    name='g_loss_fake')

                mutual = calc_mutual(fake_cat, c)
                g_loss += mutual

            grad = tape.gradient(g_loss,
                                 G.trainable_weights + Q.trainable_weights)
            g_optimizer.apply_gradients(
                zip(grad, G.trainable_weights + Q.trainable_weights))
            grad = tape.gradient(d_loss,
                                 D.trainable_weights + Base.trainable_weights)
            d_optimizer.apply_gradients(
                zip(grad, D.trainable_weights + Base.trainable_weights))
            del tape
            print(
                f"Epoch: [{epoch}/{flags.n_epoch}] [{step}/{n_step_epoch}] took: {time.time()-step_time:.3f}, d_loss: {d_loss:.5f}, g_loss: {g_loss:.5f}, mutual: {mutual:.5f}"
            )

            if count % flags.save_every_it == 1:
                his_g_loss.append(g_loss)
                his_d_loss.append(d_loss)
                his_mutual.append(mutual)

        plt.plot(his_d_loss)
        plt.plot(his_g_loss)
        plt.plot(his_mutual)
        plt.legend(['D_Loss', 'G_Loss', 'Mutual_Info'])
        plt.xlabel(f'Iterations / {flags.save_every_it}')
        plt.ylabel('Loss')
        plt.savefig(f'{flags.result_dir}/loss.jpg')
        plt.clf()
        plt.close()

        G.save_weights(f'{flags.checkpoint_dir}/G.npz', format='npz')
        D.save_weights(f'{flags.checkpoint_dir}/D.npz', format='npz')
        G.eval()
        for k in range(flags.n_categorical):
            z = gen_eval_noise(k, flags.n_sample)
            result = G(z)
            tl.visualize.save_images(convert(result.numpy()),
                                     [flags.n_sample, flags.dim_categorical],
                                     f'result/train_{epoch}_{k}.png')
        G.train()
Example #8
def train():
	size = [1080, 1920]
	aspect_ratio = size[1] / size[0]

	G = get_G((batch_size, 96, 96, 3))
	D = get_D((batch_size, 384, 384, 3))
	VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

	lr_v = tf.Variable(lr_init)
	g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
	g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
	d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

	G.train()
	D.train()
	VGG.train()

	train_ds, test_ds, sample_ds = get_train_data()

	sample_folders = ['train_lr', 'train_hr', 'train_gen', 'test_lr', 'test_hr', 'test_gen', 'sample_lr', 'sample_gen']
	for sample_folder in sample_folders:
		tl.files.exists_or_mkdir(os.path.join(save_dir, sample_folder))
	
	# only take a certain amount of images to save
	test_lr_patchs, test_hr_patchs = next(iter(test_ds))
	valid_lr_imgs = []
	for i,lr_patchs in enumerate(sample_ds):
		valid_lr_img = lr_patchs.numpy()
		valid_lr_img = np.asarray(valid_lr_img, dtype=np.float32)
		valid_lr_img = valid_lr_img[np.newaxis,:,:,:]
		valid_lr_imgs.append(valid_lr_img)
		tl.vis.save_images(valid_lr_img, [1,1], os.path.join(save_dir, 'sample_lr', 'sample_lr_img_{}.jpg'.format(i)))

	tl.vis.save_images(test_lr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'test_lr', 'test_lr.jpg'))
	tl.vis.save_images(test_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'test_hr', 'test_hr.jpg'))

	# initialize learning (G)
	n_step_epoch = round(iteration_size // batch_size)
	for epoch in range(n_epoch_init):
		for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
			if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
				break
			step_time = time.time()
			with tf.GradientTape() as tape:
				fake_hr_patchs = G(lr_patchs)
				mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
			grad = tape.gradient(mse_loss, G.trainable_weights)
			g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
			print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
				epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss))
		if (epoch != 0) and (epoch % 10 == 0):
			# save training result examples
			tl.vis.save_images(lr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_lr', 'train_lr_init_{}.jpg'.format(epoch)))
			tl.vis.save_images(hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_hr', 'train_hr_init_{}.jpg'.format(epoch)))
			tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_gen', 'train_gen_init_{}.jpg'.format(epoch)))
			# save test results (only save generated, since it's always the same images. Inputs are saved before the training loop)
			fake_hr_patchs = G(test_lr_patchs)
			tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'test_gen', 'test_gen_init_{}.jpg'.format(epoch)))
			# save sample results (only save generated, since it's always the same images. Inputs are saved before the training loop)
			for i,lr_patchs in enumerate(valid_lr_imgs):
				fake_hr_patchs = G(lr_patchs)
				tl.vis.save_images(fake_hr_patchs.numpy(), [1,1], os.path.join(save_dir, 'sample_gen', 'sample_gen_init_{}_img_{}.jpg'.format(epoch, i)))

	## adversarial learning (G, D)
	n_step_epoch = round(iteration_size // batch_size)
	for epoch in range(n_epoch):
		for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
			if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
				break
			step_time = time.time()
			with tf.GradientTape(persistent=True) as tape:
				fake_patchs = G(lr_patchs)
				logits_fake = D(fake_patchs)
				logits_real = D(hr_patchs)
				feature_fake = VGG((fake_patchs+1)/2.) # the pre-trained VGG uses the input range of [0, 1]
				feature_real = VGG((hr_patchs+1)/2.)
				d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
				d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
				d_loss = d_loss1 + d_loss2
				g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
				mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
				vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)
				g_loss = mse_loss + vgg_loss + g_gan_loss
			grad = tape.gradient(g_loss, G.trainable_weights)
			g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
			grad = tape.gradient(d_loss, D.trainable_weights)
			d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
			print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format(
				epoch, n_epoch, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss))

		# update the learning rate
		if epoch != 0 and (epoch % decay_every == 0):
			new_lr_decay = lr_decay**(epoch // decay_every)
			lr_v.assign(lr_init * new_lr_decay)
			log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
			print(log)

		if (epoch != 0) and (epoch % 10 == 0):
			# save training result examples
			tl.vis.save_images(lr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_lr', 'train_lr_{}.jpg'.format(epoch)))
			tl.vis.save_images(hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_hr', 'train_hr_{}.jpg'.format(epoch)))
			tl.vis.save_images(fake_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_gen', 'train_gen_{}.jpg'.format(epoch)))
			# save test results (only save generated, since it's always the same images. Inputs are saved before the training loop)
			fake_hr_patchs = G(test_lr_patchs)
			tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'test_gen', 'test_gen_{}.jpg'.format(epoch)))
			# save sample results (only save generated, since it's always the same images. Inputs are saved before the training loop)
			# for i,lr_patchs in enumerate(valid_lr_imgs):
			# 	fake_hr_patchs = G(lr_patchs)
			# 	tl.vis.save_images(fake_hr_patchs.numpy(), [1,1], os.path.join(save_dir, 'sample_gen', 'sample_gen_init_{}_img_{}.jpg'.format(epoch, i)))


			G.save_weights(os.path.join(checkpoint_dir, f'g_epoch_{epoch}.h5'))
			D.save_weights(os.path.join(checkpoint_dir, f'd_epoch_{epoch}.h5'))

			G.save_weights(os.path.join(checkpoint_dir, 'g.h5'))
			D.save_weights(os.path.join(checkpoint_dir, 'd.h5'))