def train(n_channels=3,
          resolution=32,
          z_dim=128,
          n_labels=0,
          lr=1e-3,
          e_drift=1e-3,
          wgp_target=750,
          initial_resolution=4,
          total_kimg=25000,
          training_kimg=500,
          transition_kimg=500,
          iters_per_checkpoint=500,
          n_checkpoint_images=16,
          glob_str='cifar10',
          out_dir='cifar10'):
    """Train a progressively grown GAN with Wasserstein / gradient-penalty losses.

    Runs ``N_CRITIC_ITERS`` discriminator updates per generator update until
    ``total_kimg`` thousand real images have been fed to the discriminator.
    Training is split into phases of ``training_kimg + transition_kimg``;
    during the transition part of every phase after the first, the level of
    detail ``cur_block`` is blended down towards 0, and at each phase switch
    the minibatch size is re-read from ``MINIBATCH_OVERWRITES``.

    Args:
        n_channels: Number of image channels.
        resolution: Final image resolution; ``log2(resolution)`` blocks.
        z_dim: Dimensionality of the latent vector.
        n_labels: Number of conditioning labels passed to the models.
        lr: Learning rate of both Adam optimizers.
        e_drift: Loss weight of the discriminator's third ('mse') output.
        wgp_target: Target value for the discriminator's second output.
        initial_resolution: Resolution at which training starts.
        total_kimg: Training length, in thousands of real images.
        training_kimg: Length (kimg) of the stable part of each phase.
        transition_kimg: Length (kimg) of the fade-in part of each phase.
        iters_per_checkpoint: Iterations between image logs / model saves.
        n_checkpoint_images: Number of images logged at each checkpoint.
        glob_str: Dataset pattern passed to ``iterate_minibatches``.
        out_dir: Directory for the event logs and the saved model.
    """
    # instantiate logger; released in the ``finally`` clause below
    logger = SummaryWriter(out_dir)
    try:
        # load data
        batch_size = MINIBATCH_OVERWRITES[0]
        train_iterator = iterate_minibatches(glob_str, batch_size, resolution)

        # build models
        G = Generator(n_channels, resolution, z_dim, n_labels)
        D = Discriminator(n_channels, resolution, n_labels)

        G_train, D_train = GAN(G, D, z_dim, n_labels, resolution, n_channels)

        D_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
        G_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

        # define loss functions; D_train's three outputs are weighted
        # 1 / GP_WEIGHT / e_drift respectively
        D_loss = [loss_mean, loss_gradient_penalty, 'mse']
        G_loss = [loss_wasserstein]

        # compile graphs used during training; D is marked non-trainable only
        # while G_train is compiled (the usual Keras idiom for freezing D
        # inside the combined model)
        G.compile(G_opt, loss=loss_wasserstein)
        D.trainable = False
        G_train.compile(G_opt, loss=G_loss)
        D.trainable = True
        D_train.compile(D_opt, loss=D_loss,
                        loss_weights=[1, GP_WEIGHT, e_drift])

        # targets for the losses; rebuilt whenever the batch size changes
        ones = np.ones((batch_size, 1), dtype=np.float32)
        zeros = ones * 0.0

        # fix a z vector for training evaluation
        z_fixed = np.random.normal(0, 1, size=(n_checkpoint_images, z_dim))

        # number of blocks between initial_resolution and resolution;
        # cur_block counts down towards 0 as the network grows
        resolution_log2 = int(np.log2(resolution))
        starting_block = resolution_log2
        starting_block -= np.floor(np.log2(initial_resolution))
        cur_block = starting_block
        cur_nimg = 0

        # each phase = stable training followed by a fade-in transition
        phase_kdur = training_kimg + transition_kimg
        phase_idx_prev = 0

        # offset variable for transitioning between blocks
        offset = 0
        i = 0
        while cur_nimg < total_kimg * 1000:
            # block processing
            kimg = cur_nimg / 1000.0
            # phase k starts at k * phase_kdur - transition_kimg, so phase 0
            # lasts only training_kimg
            phase_idx = int(np.floor((kimg + transition_kimg) / phase_kdur))
            # keep an int: phase_idx is used as an index below
            phase_idx = max(phase_idx, 0)
            phase_kimg = phase_idx * phase_kdur

            # update batch size, iterator and targets if we switched phases
            if phase_idx_prev < phase_idx:
                # NOTE(review): assumes MINIBATCH_OVERWRITES has an entry for
                # every phase index reached before total_kimg — confirm.
                batch_size = MINIBATCH_OVERWRITES[phase_idx]
                # same resolution as the initial iterator
                train_iterator = iterate_minibatches(glob_str, batch_size,
                                                     resolution)
                ones = np.ones((batch_size, 1), dtype=np.float32)
                zeros = ones * 0.0
                phase_idx_prev = phase_idx

            # gradually blend to the next level of detail during the first
            # transition_kimg of every phase after phase 0
            if transition_kimg > 0 and phase_idx > 0:
                offset = (kimg + transition_kimg - phase_kimg) / transition_kimg
                offset = min(offset, 1.0)
                offset = offset + phase_idx - 1
                cur_block = max(starting_block - offset, 0.0)

            # update level of detail
            K.set_value(G_train.cur_block, np.float32(cur_block))
            K.set_value(D_train.cur_block, np.float32(cur_block))

            # train D N_CRITIC_ITERS times per G step
            for _ in range(N_CRITIC_ITERS):
                z = np.random.normal(0, 1, size=(batch_size, z_dim))
                real_batch = next(train_iterator)
                fake_batch = G.predict_on_batch([z])
                interpolated_batch = get_interpolated_images(
                    real_batch, fake_batch)
                losses_d = D_train.train_on_batch(
                    [real_batch, fake_batch, interpolated_batch],
                    [ones, ones * wgp_target, zeros])
                cur_nimg += batch_size

            # train G (targets are -1)
            z = np.random.normal(0, 1, size=(batch_size, z_dim))
            loss_g = G_train.train_on_batch(z, -1 * ones)

            logger.add_scalar("cur_block", cur_block, i)
            logger.add_scalar("learning_rate", lr, i)
            logger.add_scalar("batch_size", z.shape[0], i)
            print("iter", i, "cur_block", cur_block, "lr", lr, "kimg", kimg,
                  "losses_d", losses_d, "loss_g", loss_g)
            if (i % iters_per_checkpoint) == 0:
                # G.trainable is deliberately left untouched: predict() does
                # not update weights, and setting it to False here without
                # restoring it would leave G frozen on any later compile.
                fake_images = G.predict(z_fixed)
                grid_size = int(np.sqrt(n_checkpoint_images))
                # log fake images
                log_images(fake_images, 'fake', i, logger, fake_images.shape[1],
                           fake_images.shape[2], grid_size)

                # plot real images for reference
                log_images(real_batch[:n_checkpoint_images], 'real', i, logger,
                           real_batch.shape[1], real_batch.shape[2],
                           grid_size)

                # save the model to eventually resume training or do inference
                save_model(G, out_dir + "/model.json", out_dir + "/model.h5")

            log_losses(losses_d, loss_g, i, logger)
            i += 1
    finally:
        # flush pending events and release the log file even on failure
        logger.close()
# Example #2
# 0
def train():
    """Train a pix2pix-style conditional GAN configured from the command line.

    Command-line options:
        --batchsize/-b  minibatch size (default 1)
        --patchsize/-p  parsed but not used in this function (default 64)
        --epoch/-e      number of training epochs (default 500)
        --out/-o        sub-directory of ./result for all outputs
        --lmd/-l        weight of the 'mae' term of the combined GAN loss
        --dark/-d       forwarded to ``load_dataset``
        --gpu/-g        value written to CUDA_VISIBLE_DEVICES

    Appends one CSV row of train/validation losses per epoch to
    ``./result/<out>/log.txt``. Every 50 epochs it also saves sample images
    and the combined model's weights.
    """
    parser = argparse.ArgumentParser(description="keras pix2pix")
    parser.add_argument('--batchsize', '-b', type=int, default=1)
    parser.add_argument('--patchsize', '-p', type=int, default=64)
    parser.add_argument('--epoch', '-e', type=int, default=500)
    parser.add_argument('--out', '-o', default='result')
    parser.add_argument('--lmd', '-l', type=int, default=100)
    parser.add_argument('--dark', '-d', type=float, default=0.01)
    parser.add_argument('--gpu', '-g', type=int, default=2)
    args = parser.parse_args()
    BATCH_SIZE = args.batchsize
    n_epochs = args.epoch
    lmd = args.lmd

    # set gpu environment
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    K.set_session(sess)

    # make directories to save results: ./result/<out>/model/
    resultDir = "./result/" + args.out
    modelDir = resultDir + "/model/"
    os.makedirs(modelDir, exist_ok=True)
    logPath = resultDir + "/log.txt"

    # make a logfile and add colnames
    with open(logPath, "w") as o:
        o.write("batch:" + str(BATCH_SIZE) + "  lambda:" + str(lmd) + "\n")
        o.write("epoch,dis_loss,gan_mae,gan_entropy,"
                "vdis_loss,vgan_mae,vgan_entropy\n")

    # load data: first 70% of each id range for training, rest for validation
    ds1_first, ds1_last, num_ds1 = 1, 1145, 1145
    ds2_first, ds2_last, num_ds2 = 2000, 6749, 4750
    # small debug split:
    # ds1_first, ds1_last, num_ds1 = 1,    100, 100
    # ds2_first, ds2_last, num_ds2 = 101, 200, 100
    train_data_i = np.concatenate([
        np.arange(ds1_first, ds1_last + 1)[:int(num_ds1 * 0.7)],
        np.arange(ds2_first, ds2_last + 1)[:int(num_ds2 * 0.7)]
    ])
    test_data_i = np.concatenate([
        np.arange(ds1_first, ds1_last + 1)[int(num_ds1 * 0.7):],
        np.arange(ds2_first, ds2_last + 1)[int(num_ds2 * 0.7):]
    ])
    train_gt, _, train_night = load_dataset(data_range=train_data_i,
                                            dark=args.dark)
    test_gt, _, test_night = load_dataset(data_range=test_data_i,
                                          dark=args.dark)

    # Create optimizers
    opt_Gan = Adam(lr=1E-3)
    opt_Discriminator = Adam(lr=1E-3)
    opt_Generator = Adam(lr=1E-3)

    # set the loss of gan
    # NOTE(review): -log(|y_pred - y_true| + eps) is minimized when the
    # prediction is FAR from the target. If D's output lies in [0, 1], D is
    # driven towards ~0 for real pairs and ~1 for generated ones. G (target
    # y_gan = 1) is pushed towards 0 as well, so the labels below are
    # consistent with this inverted convention — check before changing them.
    def dis_entropy(y_true, y_pred):
        return -K.log(K.abs((y_pred - y_true)) + 1e-07)

    gan_loss_fns = ['mae', dis_entropy]
    gan_loss_weights = [lmd, 1]

    # make models; D is non-trainable while the combined model is compiled
    Generator = generator()
    Generator.compile(loss='mae', optimizer=opt_Generator)
    Discriminator = discriminator()
    Discriminator.trainable = False
    Gan = GAN(Generator, Discriminator)
    Gan.compile(loss=gan_loss_fns,
                loss_weights=gan_loss_weights,
                optimizer=opt_Gan)
    Discriminator.trainable = True
    Discriminator.compile(loss=dis_entropy, optimizer=opt_Discriminator)

    # start training
    n_train = train_gt.shape[0]
    n_test = test_gt.shape[0]
    print(n_train, n_test)
    # target labels do not change between epochs
    y_real = np.array([1] * BATCH_SIZE)
    y_fake = np.array([0] * BATCH_SIZE)
    y_gan = np.array([1] * BATCH_SIZE)
    p = ProgressBar()
    for epoch in p(range(n_epochs)):
        p.update(epoch + 1)
        train_ind = np.random.permutation(n_train)
        test_ind = np.random.permutation(n_test)

        # training (a trailing partial batch is skipped)
        dis_losses = []
        gan_losses = []
        for batch_i in range(n_train // BATCH_SIZE):
            idx = train_ind[batch_i * BATCH_SIZE:(batch_i + 1) * BATCH_SIZE]
            gt_batch = train_gt[idx, :, :, :]
            night_batch = train_night[idx, :, :, :]
            generated_batch = Generator.predict(night_batch)
            # train Discriminator on real pairs, then on generated pairs
            dis_real_loss = np.array(
                Discriminator.train_on_batch([night_batch, gt_batch], y_real))
            dis_fake_loss = np.array(
                Discriminator.train_on_batch([night_batch, generated_batch],
                                             y_fake))
            dis_losses.append((dis_real_loss + dis_fake_loss) / 2)
            # train Generator through the combined model
            gan_losses.append(np.array(
                Gan.train_on_batch(night_batch, [gt_batch, y_gan])))
        dis_loss = np.mean(np.array(dis_losses))
        gan_loss = np.mean(np.array(gan_losses), axis=0)

        # validation: evaluate only, no weight updates
        test_dis_losses = []
        test_gan_losses = []
        for batch_i in range(n_test // BATCH_SIZE):
            idx = test_ind[batch_i * BATCH_SIZE:(batch_i + 1) * BATCH_SIZE]
            gt_batch = test_gt[idx, :, :, :]
            night_batch = test_night[idx, :, :, :]
            generated_batch = Generator.predict(night_batch)
            dis_real_loss = np.array(
                Discriminator.test_on_batch([night_batch, gt_batch], y_real))
            dis_fake_loss = np.array(
                Discriminator.test_on_batch([night_batch, generated_batch],
                                            y_fake))
            test_dis_losses.append((dis_real_loss + dis_fake_loss) / 2)
            test_gan_losses.append(np.array(
                Gan.test_on_batch(night_batch, [gt_batch, y_gan])))
        test_dis_loss = np.mean(np.array(test_dis_losses))
        # average the validation GAN losses (not the training ones)
        test_gan_loss = np.mean(np.array(test_gan_losses), axis=0)

        # write log of learning; indices 1/2 are the mae and entropy terms
        row = (epoch, dis_loss, gan_loss[1], gan_loss[2], test_dis_loss,
               test_gan_loss[1], test_gan_loss[2])
        with open(logPath, "a") as out_file:
            out_file.write(",".join(str(v) for v in row) + "\n")

        # visualize
        if epoch % 50 == 0:
            # for training data
            gt_batch = train_gt[train_ind[0:9], :, :, :]
            night_batch = train_night[train_ind[0:9], :, :, :]
            generated_batch = Generator.predict(night_batch)
            save_images(night_batch,
                        resultDir + "/label_" + str(epoch) + "epoch.png")
            save_images(gt_batch,
                        resultDir + "/gt_" + str(epoch) + "epoch.png")
            save_images(generated_batch,
                        resultDir + "/generated_" + str(epoch) + "epoch.png")
            # for validation data
            gt_batch = test_gt[test_ind[0:9], :, :, :]
            night_batch = test_night[test_ind[0:9], :, :, :]
            generated_batch = Generator.predict(night_batch)
            save_images(night_batch,
                        resultDir + "/vlabel_" + str(epoch) + "epoch.png")
            save_images(gt_batch,
                        resultDir + "/vgt_" + str(epoch) + "epoch.png")
            save_images(generated_batch,
                        resultDir + "/vgenerated_" + str(epoch) + "epoch.png")

            Gan.save_weights(modelDir + 'gan_weights' + "_lambda" + str(lmd) +
                             "_epoch" + str(epoch) + '.h5')