Example #1
    def __init__(self,
                 gnetA,
                 gnetB,
                 dnetA,
                 dnetB,
                 train_set_A,
                 train_set_B,
                 val_set_A,
                 val_set_B,
                 G_optimizer,
                 D_A_optimizer,
                 D_B_optimizer,
                 stats_manager,
                 output_dir=None,
                 batch_size=4,
                 perform_validation_during_training=False):

        # Define data loaders
        train_loader_A = td.DataLoader(train_set_A,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       drop_last=True,
                                       pin_memory=True)
        train_loader_B = td.DataLoader(train_set_B,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       drop_last=True,
                                       pin_memory=True)
        val_loader_A = td.DataLoader(val_set_A,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     drop_last=True,
                                     pin_memory=True)
        val_loader_B = td.DataLoader(val_set_B,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     drop_last=True,
                                     pin_memory=True)

        # Initialize history
        history = []

        # Pools of previously generated fake images
        fakeA_store = util.ImagePool(50)
        fakeB_store = util.ImagePool(50)

        # Define checkpoint paths
        if output_dir is None:
            output_dir = 'experiment_{}'.format(time.time())
        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, "checkpoint.pth.tar")
        config_path = os.path.join(output_dir, "config.txt")

        # Transfer all local arguments/variables into attributes
        locs = {k: v for k, v in locals().items() if k != 'self'}
        self.__dict__.update(locs)

        # Load checkpoint and check compatibility
        if os.path.isfile(config_path):
            with open(config_path, 'r') as f:
                if f.read()[:-1] != repr(self):
                    raise ValueError(
                        "Cannot create this experiment: "
                        "I found a checkpoint conflicting with the current setting."
                    )
            self.load()
        else:
            self.save()
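
All of these examples lean on util.ImagePool, the history buffer used in the CycleGAN paper (Zhu et al., 2017; the trick originates with Shrivastava et al., 2017): the discriminator is fed a mix of the current fakes and fakes replayed from a small pool, which damps oscillations during adversarial training. The util module itself is not shown on this page, so here is a minimal PyTorch sketch of what such a pool typically looks like; treat it as an illustration rather than the exact class used above.

import random
import torch

class ImagePool:
    """History buffer of previously generated images.

    query() returns a batch in which each incoming fake is, with
    probability 0.5, swapped for an older fake stored in the pool.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images                      # pool disabled: pass-through
        out = []
        for image in images:                   # iterate over the batch
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)      # still filling the pool
                out.append(image)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = image       # replay an old fake, store the new one
            else:
                out.append(image)              # keep the current fake
        return torch.cat(out, 0)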
Example #2
# Adam optimizer
G_optimizer = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()),
                         lr=opt.lrG,
                         betas=(opt.beta1, opt.beta2))
D_A_optimizer = optim.Adam(D_A.parameters(),
                           lr=opt.lrD,
                           betas=(opt.beta1, opt.beta2))
D_B_optimizer = optim.Adam(D_B.parameters(),
                           lr=opt.lrD,
                           betas=(opt.beta1, opt.beta2))
H_A_optimizer = optim.Adam(H_A.parameters(),
                           lr=opt.lrH,
                           betas=(opt.beta1, opt.beta2))

# image store
fakeA_store = util.ImagePool(50)
fakeB_store = util.ImagePool(50)

train_hist = {
    'D_A_losses': [],
    'D_B_losses': [],
    'G_A_losses': [],
    'G_B_losses': [],
    'A_cycle_losses': [],
    'B_cycle_losses': [],
    'H_A_Hash_losses': [],
    'H_B_Hash_losses': [],
    'A_Cons_losses': [],
    'per_epoch_ptimes': [],
    'total_ptime': [],
}
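
One detail worth noting in the snippet above: both generators share a single Adam optimizer by chaining their parameters with itertools.chain, so one backward pass and one step() update G_A and G_B together. A minimal self-contained illustration of that pattern (the two Linear modules are hypothetical stand-ins for the real generators):

import itertools
import torch
import torch.nn as nn
import torch.optim as optim

G_A, G_B = nn.Linear(8, 8), nn.Linear(8, 8)       # stand-ins for the generators
opt = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()),
                 lr=2e-4, betas=(0.5, 0.999))

x = torch.randn(4, 8)
loss = (G_B(G_A(x)) - x).abs().mean()             # toy cycle-style loss
opt.zero_grad()
loss.backward()
opt.step()                                        # one step updates both modules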
Example #3
def training_loop(dataloader_X, dataloader_Y, test_dataloader_X,
                  test_dataloader_Y, opts):

    # Initialize generators, discriminators, and optimizers
    G_XtoY, G_YtoX, D_X, D_Y, g_optimizer, dx_optimizer, dy_optimizer = load_checkpoint(
        opts)

    iter_X = iter(dataloader_X)
    iter_Y = iter(dataloader_Y)

    test_iter_X = iter(test_dataloader_X)
    test_iter_Y = iter(test_dataloader_Y)

    # Fix a batch of data from domains X and Y for sampling. These images are
    # held constant throughout training so we can inspect the model's progress
    # on the same inputs.

    fixed_X = utils.to_var(next(test_iter_X)[0])
    fixed_Y = utils.to_var(next(test_iter_Y)[0])

    iter_per_epoch = min(len(iter_X), len(iter_Y))

    # loss terms
    MSE_loss = nn.MSELoss().cuda()
    L1_loss = nn.L1Loss().cuda()

    # image store (used to stabilize discriminator training)
    fake_X_store = util.ImagePool(50)
    fake_Y_store = util.ImagePool(50)

    for iteration in range(1, opts.train_iters + 1):
        # Reset data_iter for each epoch
        if iteration % iter_per_epoch == 0:
            iter_X = iter(dataloader_X)
            iter_Y = iter(dataloader_Y)

        images_X, labels_X = next(iter_X)
        images_X, labels_X = utils.to_var(images_X), utils.to_var(
            labels_X).long().squeeze()

        images_Y, labels_Y = next(iter_Y)
        images_Y, labels_Y = utils.to_var(images_Y), utils.to_var(
            labels_Y).long().squeeze()

        #### GENERATOR TRAINING ####
        g_optimizer.zero_grad()

        # 1. GAN loss term
        fake_X = G_YtoX(images_Y)
        fake_Y = G_XtoY(images_X)

        d_x_pred = D_X(fake_X)
        d_y_pred = D_Y(fake_Y)

        # We want d_x_pred and d_y_pred to be as close to 1 (real) as possible
        gan_loss = MSE_loss(
            d_x_pred, Variable(torch.ones(d_x_pred.size()).cuda())) + MSE_loss(
                d_y_pred, Variable(torch.ones(d_y_pred.size()).cuda()))

        # 2. Identity loss term
        identity_X = G_YtoX(images_X)
        identity_Y = G_XtoY(images_Y)

        identity_loss = L1_loss(images_X, identity_X) + L1_loss(
            images_Y, identity_Y)

        # 3. Cycle consistency loss term
        reconstructed_Y = G_XtoY(fake_X)
        reconstructed_X = G_YtoX(fake_Y)

        cycle_consistency_loss = L1_loss(images_X, reconstructed_X) + L1_loss(
            images_Y, reconstructed_Y)

        # Total generator loss
        g_loss = (gan_loss
                  + opts.identity_lambda * identity_loss
                  + opts.cycle_consistency_lambda * cycle_consistency_loss)

        g_loss.backward()
        g_optimizer.step()

        #### DISCRIMINATOR TRAINING ####

        # 1. Compute the discriminator x loss
        dx_optimizer.zero_grad()

        d_x_pred = D_X(images_X)
        D_X_real_loss = MSE_loss(d_x_pred,
                                 Variable(torch.ones(d_x_pred.size()).cuda()))

        fake_X = fake_X_store.query(fake_X)
        d_x_pred = D_X(fake_X)
        D_X_fake_loss = MSE_loss(d_x_pred,
                                 Variable(torch.zeros(d_x_pred.size()).cuda()))

        D_X_loss = (D_X_real_loss + D_X_fake_loss) * .5
        D_X_loss.backward()
        dx_optimizer.step()

        # 2. Compute the discriminator y loss
        dy_optimizer.zero_grad()

        d_y_pred = D_Y(images_Y)
        D_Y_real_loss = MSE_loss(d_y_pred,
                                 Variable(torch.ones(d_y_pred.size()).cuda()))

        fake_Y = fake_Y_store.query(fake_Y)
        d_y_pred = D_Y(fake_Y)
        D_Y_fake_loss = MSE_loss(d_y_pred,
                                 Variable(torch.zeros(d_y_pred.size()).cuda()))

        D_Y_loss = (D_Y_real_loss + D_Y_fake_loss) * .5
        D_Y_loss.backward()
        dy_optimizer.step()

        # Print the log info
        if iteration % opts.log_step == 0:
            print(
                'Iteration [{:5d}/{:5d}] | d_Y_loss: {:6.4f} | d_X_loss: {:6.4f} | g_loss: {:6.4f}'
                .format(iteration, opts.train_iters, D_Y_loss.item(),
                        D_X_loss.item(), g_loss.item()))

        # Save the generated samples
        if iteration % opts.sample_every == 0:
            save_samples(iteration, fixed_Y, fixed_X, G_YtoX, G_XtoY, opts)

        # Save the model parameters
        if iteration % opts.checkpoint_every == 0:
            checkpoint(iteration, G_XtoY, G_YtoX, D_X, D_Y, g_optimizer,
                       dx_optimizer, dy_optimizer, opts)
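
Note that the pools in this loop are queried only on the discriminator side: the generator loss is computed on the fresh fake_X and fake_Y, while D_X and D_Y see a mixed batch drawn from the pool. (The Variable wrapper used throughout is a deprecated no-op in modern PyTorch; plain tensors behave the same way.) A short self-contained usage sketch, assuming the minimal ImagePool sketched after Example #1:

import torch

pool = ImagePool(50)                    # the minimal pool sketched after Example #1
fake = torch.randn(4, 3, 128, 128)      # stand-in for a batch of generator outputs
mixed = pool.query(fake.detach())       # roughly half fresh fakes, half replayed ones
assert mixed.shape == fake.shape        # the batch shape is preserved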
Example #4
def train_one2one(model_path):
    print('Please wait. It takes several minutes. Do not quit!')

    G_scope = 'translator_X_to_Y'
    DX_scope = 'discriminator_X'
    DY_scope = 'discriminator_Y'

    with tf.device('/device:CPU:0'):
        X_IN = tf.placeholder(tf.float32, [batch_size, input_height, input_width, num_channel])
        Y_IN = tf.placeholder(tf.float32, [batch_size, input_height, input_width, num_channel])
        X_FAKE_IN = tf.placeholder(tf.float32, [batch_size, input_height, input_width, num_channel])
        Y_FAKE_IN = tf.placeholder(tf.float32, [batch_size, input_height, input_width, num_channel])
        LR = tf.placeholder(tf.float32, None)
        b_train = tf.placeholder(tf.bool)

    # Launch the graph in a session
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    with tf.device('/device:GPU:1'):
        fake_Y = translator(X_IN, activation='swish', norm='instance', b_train=b_train, scope=G_scope,
                            use_upsample=False)
        recon_X = translator(fake_Y, activation='swish', norm='instance', b_train=b_train, scope=G_scope,
                             use_upsample=False)
        fake_X = translator(Y_IN, activation='swish', norm='instance', b_train=b_train, scope=G_scope,
                            use_upsample=False)
        recon_Y = translator(fake_X, activation='swish', norm='instance', b_train=b_train, scope=G_scope,
                             use_upsample=False)

        if use_identity_loss:
            id_X = translator(X_IN, activation='swish', norm='instance', b_train=b_train, scope=G_scope,
                              use_upsample=False)
            id_Y = translator(Y_IN, activation='swish', norm='instance', b_train=b_train, scope=G_scope,
                              use_upsample=False)

    with tf.device('/device:GPU:0'):
        _, X_FAKE_IN_logit = discriminator(X_FAKE_IN, activation='swish', norm='instance', b_train=b_train,
                                           scope=DX_scope, use_patch=True)
        _, Y_FAKE_IN_logit = discriminator(Y_FAKE_IN, activation='swish', norm='instance', b_train=b_train,
                                           scope=DY_scope, use_patch=True)
        _, real_X_logit = discriminator(X_IN, activation='swish', norm='instance', b_train=b_train,
                                        scope=DX_scope, use_patch=True)
        _, fake_X_logit = discriminator(fake_X, activation='swish', norm='instance', b_train=b_train,
                                        scope=DX_scope, use_patch=True)
        _, real_Y_logit = discriminator(Y_IN, activation='swish', norm='instance', b_train=b_train,
                                        scope=DY_scope, use_patch=True)
        _, fake_Y_logit = discriminator(fake_Y, activation='swish', norm='instance', b_train=b_train,
                                        scope=DY_scope, use_patch=True)

    reconstruction_loss_Y = get_residual_loss(Y_IN, recon_Y, type='l1') + get_gradient_loss(Y_IN, recon_Y)
    reconstruction_loss_X = get_residual_loss(X_IN, recon_X, type='l1') + get_gradient_loss(X_IN, recon_X)
    cyclic_loss = reconstruction_loss_Y + reconstruction_loss_X
    alpha = 10.0
    cyclic_loss = alpha * cyclic_loss

    if gan_mode == 'ls':
        # LS GAN
        trans_loss_X2Y = tf.reduce_mean((fake_Y_logit - tf.ones_like(fake_Y_logit)) ** 2) + alpha * reconstruction_loss_X
        trans_loss_Y2X = tf.reduce_mean((fake_X_logit - tf.ones_like(fake_X_logit)) ** 2) + alpha * reconstruction_loss_Y
        disc_loss_Y = get_discriminator_loss(real_Y_logit, tf.ones_like(real_Y_logit), type='ls') + \
                      get_discriminator_loss(Y_FAKE_IN_logit, tf.zeros_like(Y_FAKE_IN_logit), type='ls')
        disc_loss_X = get_discriminator_loss(real_X_logit, tf.ones_like(real_X_logit), type='ls') + \
                      get_discriminator_loss(X_FAKE_IN_logit, tf.zeros_like(X_FAKE_IN_logit), type='ls')
    else:
        # for WGAN
        trans_loss_X2Y = -tf.reduce_mean(fake_Y_logit)
        trans_loss_Y2X = -tf.reduce_mean(fake_X_logit)
        disc_loss_Y, _, _ = get_discriminator_loss(real_Y_logit, Y_FAKE_IN_logit, type='wgan')
        disc_loss_X, _, _ = get_discriminator_loss(real_X_logit, X_FAKE_IN_logit, type='wgan')

    if use_identity_loss:
        identity_loss_Y = alpha * (get_residual_loss(Y_IN, id_Y, type='l1') + get_gradient_loss(Y_IN, id_Y))
        identity_loss_X = alpha * (get_residual_loss(X_IN, id_X, type='l1') + get_gradient_loss(X_IN, id_X))
        identity_loss = 0.5 * (identity_loss_X + identity_loss_Y)
        total_trans_loss = trans_loss_X2Y + trans_loss_Y2X + cyclic_loss + identity_loss
    else:
        total_trans_loss = trans_loss_X2Y + trans_loss_Y2X + cyclic_loss

    beta = 0.5
    total_disc_loss = beta * disc_loss_Y + beta * disc_loss_X

    disc_Y_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=DY_scope)
    disc_X_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=DX_scope)
    disc_vars = disc_Y_vars + disc_X_vars

    trans_X2Y_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=G_scope)
    trans_vars = trans_X2Y_vars

    if gan_mode == 'wgan':
        # Alert: Clip range is critical to WGAN.
        disc_weight_clip = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in disc_vars]

    with tf.device('/device:GPU:0'):
        disc_optimizer_X = tf.train.AdamOptimizer(learning_rate=LR).minimize(disc_loss_X, var_list=disc_X_vars)
        disc_optimizer_Y = tf.train.AdamOptimizer(learning_rate=LR).minimize(disc_loss_Y, var_list=disc_Y_vars)

    with tf.device('/device:GPU:1'):
        trans_optimizer_X2Y = tf.train.AdamOptimizer(learning_rate=LR).minimize(trans_loss_X2Y, var_list=trans_vars)
        trans_optimizer_Y2X = tf.train.AdamOptimizer(learning_rate=LR).minimize(trans_loss_Y2X, var_list=trans_vars)

    # Launch the graph in a session
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        try:
            saver = tf.train.Saver()
            saver.restore(sess, model_path)
            print('Model restored')
        except Exception:
            print('No checkpoint found. Starting new training ...')

        trX_dir = os.path.join(train_data, 'X').replace("\\", "/")
        trY_dir = os.path.join(train_data, 'Y').replace("\\", "/")
        trX = os.listdir(trX_dir)
        trY = os.listdir(trY_dir)
        total_input_size = min(len(trX), len(trY))

        num_augmentations = 1  # how many augmentations per sample
        file_batch_size = batch_size // num_augmentations

        if file_batch_size == 0:
            file_batch_size = 1

        num_critic = 1
        if gan_mode == 'wgan':
            num_critic = 5

        image_pool = util.ImagePool(maxsize=50)
        learning_rate = 2e-4
        lr_decay_step = 100

        for e in range(num_epoch):
            trX = shuffle(trX)
            trY = shuffle(trY)

            training_batch = zip(range(0, total_input_size, file_batch_size),
                                 range(file_batch_size, total_input_size + 1, file_batch_size))
            itr = 0
            if e > lr_decay_step:
                learning_rate = learning_rate * (num_epoch - e)/(num_epoch - lr_decay_step)

            for start, end in training_batch:
                imgs_X = load_images(trX[start:end], base_dir=trX_dir, use_augmentation=False)
                if len(imgs_X[0].shape) != 3:
                    imgs_X = np.expand_dims(imgs_X, axis=3)
                imgs_Y = load_images(trY[start:end], base_dir=trY_dir, use_augmentation=False)
                if len(imgs_Y[0].shape) != 3:
                    imgs_Y = np.expand_dims(imgs_Y, axis=3)

                trans_X2Y, trans_Y2X = sess.run([fake_Y, fake_X], feed_dict={X_IN: imgs_X, Y_IN: imgs_Y, b_train: True})
                pool_X2Y, pool_Y2X = image_pool([trans_X2Y, trans_Y2X])

                _, dx_loss = sess.run([disc_optimizer_X, disc_loss_X],
                                     feed_dict={X_IN: imgs_X, X_FAKE_IN: pool_Y2X, b_train: True, LR: learning_rate})
                _, dy_loss = sess.run([disc_optimizer_Y, disc_loss_Y],
                                     feed_dict={Y_IN: imgs_Y, Y_FAKE_IN: pool_X2Y, b_train: True, LR: learning_rate})

                if gan_mode == 'wgan':
                    _ = sess.run([disc_weight_clip])

                if itr % num_critic == 0:
                    _, tx2y_loss = sess.run([trans_optimizer_X2Y, trans_loss_X2Y],
                                         feed_dict={X_IN: imgs_X, b_train: True, LR: learning_rate})
                    _, ty2x_loss = sess.run([trans_optimizer_Y2X, trans_loss_Y2X],
                                         feed_dict={Y_IN: imgs_Y, b_train: True, LR: learning_rate})

                    print('epoch: ' + str(e) + ', d_loss: ' + str(dx_loss + dy_loss) +
                          ', t_loss: ' + str(tx2y_loss + ty2x_loss))
                    decoded_images_Y2X = np.squeeze(trans_Y2X)
                    decoded_images_X2Y = np.squeeze(trans_X2Y)
                    cv2.imwrite('imgs/Y2X_' + trY[start], (decoded_images_Y2X * 128.0) + 128.0)
                    cv2.imwrite('imgs/X2Y_' + trX[start], (decoded_images_X2Y * 128.0) + 128.0)
                itr += 1

                if itr % 200 == 0:
                    try:
                        print('Saving model...')
                        saver.save(sess, model_path)
                        print('Saved.')
                    except Exception:
                        print('Save failed')
            try:
                print('Saving model...')
                saver.save(sess, model_path)
                print('Saved.')
            except Exception:
                print('Save failed')
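
Unlike the PyTorch examples, this one drives the pool as a callable over a list of numpy batches: image_pool([trans_X2Y, trans_Y2X]) returns one (possibly replayed) batch per input. Below is a minimal numpy sketch of such a callable pool, assuming whole-batch replacement semantics; the real util.ImagePool may instead swap individual samples, as in the query() variant above.

import copy
import numpy as np

class ImagePool(object):
    """Callable variant: pool(list_of_batches) -> list of mixed batches."""

    def __init__(self, maxsize=50):
        self.maxsize = maxsize
        self.images = []

    def __call__(self, batches):
        if self.maxsize <= 0:
            return batches                             # pool disabled: pass-through
        out = []
        for batch in batches:
            if len(self.images) < self.maxsize:
                self.images.append(copy.copy(batch))   # still filling the pool
                out.append(batch)
            elif np.random.rand() > 0.5:
                idx = np.random.randint(len(self.images))
                old = self.images[idx]
                self.images[idx] = copy.copy(batch)    # store the new batch
                out.append(old)                        # replay an old one
            else:
                out.append(batch)                      # keep the current batch
        return out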