Example #1
def main():
    # Parameters
    train_data = './datasets/facades/train/'
    display_data = './datasets/facades/val/'
    start = 0
    stop = 400
    save_samples = False
    shuffle_ = True
    use_h5py = 0
    batchSize = 4
    loadSize = 286
    fineSize = 256
    flip = True
    ngf = 64
    ndf = 64
    input_nc = 3
    output_nc = 3
    num_epoch = 1001
    training_method = 'adam'
    lr_G = 0.0002
    lr_D = 0.0002
    beta1 = 0.5

    task = 'facades'
    name = 'pan'
    which_direction = 'BtoA'
    preprocess = 'regular'
    begin_save = 700
    save_freq = 100
    show_freq = 20
    continue_train = 0
    use_PercepGAN = 1
    use_Pix = 'No'
    which_netG = 'unet_nodrop'
    which_netD = 'basic'
    lam_pix = 25.
    lam_p1 = 5.
    lam_p2 = 1.5
    lam_p3 = 1.5
    lam_p4 = 1.
    lam_gan_d = 1.
    lam_gan_g = 1.
    m = 3.0
    test_deterministic = True

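    # kD / kG set how many discriminator and generator updates are taken per
    # alternation cycle of the training loop below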
    kD = 1
    kG = 1
    save_model_D = False
    # Load the dataset
    print("Loading data...")
    if which_direction == 'AtoB':
        tra_input, tra_output, _ = pix2pix(
            data_path=train_data,
            img_shape=[input_nc, loadSize, loadSize],
            save=save_samples,
            start=start,
            stop=stop)
        dis_input, dis_output, _ = pix2pix(
            data_path=display_data,
            img_shape=[input_nc, fineSize, fineSize],
            save=False,
            start=0,
            stop=4)
        dis_input = processing_img(dis_input,
                                   center=True,
                                   scale=True,
                                   convert=False)
    elif which_direction == 'BtoA':
        tra_output, tra_input, _ = pix2pix(
            data_path=train_data,
            img_shape=[input_nc, loadSize, loadSize],
            save=save_samples,
            start=start,
            stop=stop)
        dis_output, dis_input, _ = pix2pix(
            data_path=display_data,
            img_shape=[input_nc, fineSize, fineSize],
            save=False,
            start=0,
            stop=4)
        dis_input = processing_img(dis_input,
                                   center=True,
                                   scale=True,
                                   convert=False)
    ids = list(range(0, stop - start))

    ntrain = len(ids)

    # Prepare Theano variables for inputs and targets
    input_x = T.tensor4('input_x')
    input_y = T.tensor4('input_y')

    # Create neural network model
    print("Building model and compiling functions...")
    if which_netG == 'unet':
        generator = models.build_generator_unet(input_x, ngf=ngf)
    elif which_netG == 'unet_nodrop':
        generator = models.build_generator_unet_nodrop(input_x, ngf=ngf)
    elif which_netG == 'unet_1.0':
        generator = models.build_generator_unet_1(input_x, ngf=ngf)
    elif which_netG == 'unet_facades':
        generator = models.build_generator_facades(input_x, ngf=ngf)
    else:
        print('waiting to fill')

    if use_PercepGAN == 1:
        if which_netD == 'basic':
            discriminator = models.build_discriminator(ndf=ndf)
        else:
            print('waiting to fill')

    # Create expression for passing generator
    gen_imgs = lasagne.layers.get_output(generator)

    if use_PercepGAN == 1:
        # Create expression for passing real data through the discriminator
        dis1_f, dis2_f, dis3_f, dis4_f, disout_f = lasagne.layers.get_output(
            discriminator, input_y)
        # Create expression for passing fake data through the discriminator
        dis1_ff, dis2_ff, dis3_ff, dis4_ff, disout_ff = lasagne.layers.get_output(
            discriminator, gen_imgs)

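        # Perceptual adversarial terms: weighted L1 distances between the
        # discriminator's intermediate feature maps for real images (dis*_f)
        # and generated images (dis*_ff)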
        p1 = lam_p1 * T.mean(T.abs_(dis1_f - dis1_ff))
        p2 = lam_p2 * T.mean(T.abs_(dis2_f - dis2_ff))
        p3 = lam_p3 * T.mean(T.abs_(dis3_f - dis3_ff))
        p4 = lam_p4 * T.mean(T.abs_(dis4_f - dis4_ff))

        l2_norm = p1 + p2 + p3 + p4

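        # Discriminator loss: GAN loss with one-sided label smoothing (0.9 for
        # real samples) plus a hinge term that pushes the perceptual distance
        # l2_norm above the margin m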
        percepgan_dis_loss = lam_gan_d * (
            lasagne.objectives.binary_crossentropy(disout_f, 0.9) + lasagne.
            objectives.binary_crossentropy(disout_ff, 0)).mean() + T.maximum(
                (T.constant(m) - l2_norm), T.constant(0.))
        percepgan_gen_loss = -lam_gan_g * (
            lasagne.objectives.binary_crossentropy(disout_ff,
                                                   0)).mean() + l2_norm
    else:
        l2_norm = T.constant(0)
        percepgan_dis_loss = T.constant(0)
        percepgan_gen_loss = T.constant(0)

    if use_Pix == 'L1':
        pixel_loss = lam_pix * T.mean(abs(gen_imgs - input_y))
    elif use_Pix == 'L2':
        pixel_loss = lam_pix * T.mean(T.sqr(gen_imgs - input_y))
    else:
        pixel_loss = T.constant(0)

    # Create loss expressions
    generator_loss = percepgan_gen_loss + pixel_loss
    discriminator_loss = percepgan_dis_loss

    # Create update expressions for training
    generator_params = lasagne.layers.get_all_params(generator, trainable=True)
    if training_method == 'adam':
        g_updates = lasagne.updates.adam(generator_loss,
                                         generator_params,
                                         learning_rate=lr_G,
                                         beta1=beta1)
    elif training_method == 'nm':
        g_updates = lasagne.updates.nesterov_momentum(generator_loss,
                                                      generator_params,
                                                      learning_rate=lr_G,
                                                      momentum=beta1)

    # Compile a function performing a training step on a mini-batch (by giving
    # the updates dictionary) and returning the corresponding training loss:
    train_g = theano.function(
        [input_x, input_y],
        [p1, p2, p3, p4, l2_norm, generator_loss, pixel_loss],
        updates=g_updates)

    if use_PercepGAN == 1:
        discriminator_params = lasagne.layers.get_all_params(discriminator,
                                                             trainable=True)
        if training_method == 'adam':
            d_updates = lasagne.updates.adam(discriminator_loss,
                                             discriminator_params,
                                             learning_rate=lr_D,
                                             beta1=beta1)
        elif training_method == 'nm':
            d_updates = lasagne.updates.nesterov_momentum(discriminator_loss,
                                                          discriminator_params,
                                                          learning_rate=lr_D,
                                                          momentum=beta1)
        train_d = theano.function([input_x, input_y],
                                  [l2_norm, discriminator_loss],
                                  updates=d_updates)
        dis_fn = theano.function([input_x, input_y], [(disout_f > .5).mean(),
                                                      (disout_ff < .5).mean()])
    # Compile another function generating some data
    gen_fn = theano.function([input_x],
                             lasagne.layers.get_output(
                                 generator, deterministic=test_deterministic))

    # Finally, launch the training loop.
    print("Starting training...")

    desc = task + '_' + name
    print(desc)

    f_log = open('logs/%s.ndjson' % desc, 'w')
    log_fields = [
        'NE',
        'sec',
        'px',
        '1',
        '2',
        '3',
        '4',
        'pd',
        'cd',
        'pg',
        'cg',
        'fr',
        'tr',
    ]

    if not os.path.isdir('generated_imgs/' + desc):
        os.mkdir(os.path.join('generated_imgs/', desc))
    if not os.path.isdir('models/' + desc):
        os.mkdir(os.path.join('models/', desc))

    t = time()
    # We iterate over epochs:
    for epoch in range(num_epoch):
        if shuffle_ is True:
            ids = shuffle_data(ids)
        n_updates_g = 0
        n_updates_d = 0
        percep_d = 0
        percep_g = 0
        cost_g = 0
        cost_d = 0
        pixel = 0
        train_batches = 0
        k = 0
        p1 = 0
        p2 = 0
        p3 = 0
        p4 = 0
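        # iter_data is a project helper assumed to yield successive chunks of
        # `ids` containing at most batchSize indices each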
        for index_ in iter_data(ids, size=batchSize):
            index = sorted(index_)
            xmb = tra_input[index, :, :, :]
            ymb = tra_output[index, :, :, :]

            if preprocess == 'regular':
                xmb, ymb = pix2pixBatch(xmb,
                                        ymb,
                                        fineSize,
                                        input_nc,
                                        flip=flip)
            elif task == 'inpainting':
                print('waiting to fill')
            elif task == 'cartoon':
                print('waiting to fill')

            if n_updates_g == 0:
                imsave('other/%s_input' % desc,
                       convert_img_back(xmb[0, :, :, :]),
                       format='png')
                imsave('other/%s_GT' % desc,
                       convert_img_back(ymb[0, :, :, :]),
                       format='png')

            xmb = processing_img(xmb, center=True, scale=True, convert=False)
            ymb = processing_img(ymb, center=True, scale=True, convert=False)

            if use_PercepGAN == 1:
                if k < kD:
                    percep, cost = train_d(xmb, ymb)
                    percep_d += percep
                    cost_d += cost
                    n_updates_d += 1
                    k += 1
                elif k < kD + kG:
                    pp1, pp2, pp3, pp4, percep, cost, pix = train_g(xmb, ymb)
                    p1 += pp1
                    p2 += pp2
                    p3 += pp3
                    p4 += pp4
                    percep_g += percep
                    cost_g += cost
                    pixel += pix
                    n_updates_g += 1
                    k += 1
                elif k == kD + kG:
                    percep, cost = train_d(xmb, ymb)
                    percep_d += percep
                    cost_d += cost
                    n_updates_d += 1
                    pp1, pp2, pp3, pp4, percep, cost, pix = train_g(xmb, ymb)
                    p1 += pp1
                    p2 += pp2
                    p3 += pp3
                    p4 += pp4
                    percep_g += percep
                    cost_g += cost
                    pixel += pix
                    n_updates_g += 1
                if k == kD + kG:
                    k = 0
            else:
                pp1, pp2, pp3, pp4, percep, cost, pix = train_g(xmb, ymb)
                p1 += pp1
                p2 += pp2
                p3 += pp3
                p4 += pp4
                percep_g += percep
                cost_g += cost
                pixel += pix
                n_updates_g += 1

        if epoch % show_freq == 0:
            p1 = p1 / n_updates_g
            p2 = p2 / n_updates_g
            p3 = p3 / n_updates_g
            p4 = p4 / n_updates_g

            percep_g = percep_g / n_updates_g
            percep_d = percep_d / (n_updates_d + 0.0001)

            cost_g = cost_g / n_updates_g
            cost_d = cost_d / (n_updates_d + 0.0001)

            pixel = pixel / n_updates_g

            true_rate = -1
            fake_rate = -1
            if use_PercepGAN == 1:
                true_rate, fake_rate = dis_fn(xmb, ymb)

            log = [
                epoch,
                round(time() - t, 2),
                round(pixel, 2),
                round(p1, 2),
                round(p2, 2),
                round(p3, 2),
                round(p4, 2),
                round(percep_d, 2),
                round(cost_d, 2),
                round(percep_g, 2),
                round(cost_g, 2),
                round(float(fake_rate), 2),
                round(float(true_rate), 2)
            ]
            print('%.0f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f'
                  % (epoch, p1, p2, p3, p4, percep_d, cost_d, pixel, percep_g,
                     cost_g, fake_rate, true_rate))

            t = time()
            f_log.write(json.dumps(dict(zip(log_fields, log))) + '\n')
            f_log.flush()

            gen_imgs = gen_fn(dis_input)

            blank_image = Image.new("RGB",
                                    (fineSize * 4 + 5, fineSize * 2 + 3))
            pc = 0
            for i in range(2):
                for ii in range(4):
                    if i == 0:
                        img = dis_input[ii, :, :, :]
                        img = ImgRescale(img,
                                         center=True,
                                         scale=True,
                                         convert_back=True)
                        blank_image.paste(Image.fromarray(img),
                                          (ii * fineSize + ii + 1, 1))
                    elif i == 1:
                        img = gen_imgs[ii, :, :, :]
                        img = ImgRescale(img,
                                         center=True,
                                         scale=True,
                                         convert_back=True)
                        blank_image.paste(
                            Image.fromarray(img),
                            (ii * fineSize + ii + 1, 2 + fineSize))
            blank_image.save('generated_imgs/%s/%s_%d.png' %
                             (desc, desc, epoch))

        #pv = PatchViewer(grid_shape=(2, 4),
        #                 patch_shape=(256,256), is_color=True)
        #for i in range(2):
        #    for ii in range(4):
        #        if i == 0:
        #            img = dis_input[ii,:,:,:]
        #        elif i == 1:
        #            img = gen_imgs[ii,:,:,:]
        #        img = convert_img_back(img)
        #        pv.add_patch(img, rescale=False, activation=0)

        #pv.save('generated_imgs/%s/%s_%d.png'%(desc,desc,epoch))

        if epoch % save_freq == 0 and epoch >= begin_save:
            # Optionally, you could now dump the network weights to a file like this:
            np.savez('models/%s/gen_%d.npz' % (desc, epoch),
                     *lasagne.layers.get_all_param_values(generator))
            if use_PercepGAN == 1 and save_model_D is True:
                np.savez('models/%s/dis_%d.npz' % (desc, epoch),
                         *lasagne.layers.get_all_param_values(discriminator))
Example #2
def train(data_filepath='data/flowers.hdf5',
          ndf=64,
          ngf=128,
          z_dim=128,
          emb_dim=128,
          lr_d=5e-5,
          lr_g=5e-5,
          n_iterations=int(1e6),
          batch_size=64,
          iters_per_checkpoint=100,
          n_checkpoint_samples=16,
          out_dir='wgan_gp_lr5e-5'):
    global BATCH_SIZE

    BATCH_SIZE = batch_size
    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    train_data = get_data(data_filepath, 'train')
    val_data = get_data(data_filepath, 'valid')
    data_iterator = iterate_minibatches(train_data, batch_size)
    val_data_iterator = iterate_minibatches(val_data, n_checkpoint_samples)
    val_data = next(val_data_iterator)
    img_fixed = images_from_bytes(val_data[0])
    emb_fixed = val_data[1]
    txt_fixed = val_data[2]

    img_shape = img_fixed[0].shape
    emb_shape = emb_fixed[0].shape
    print("img shape {}".format(img_shape))
    print("emb shape {}".format(emb_shape))
    z_shape = (z_dim, )

    # plot real text for reference
    log_images(img_fixed, 'real', '0', logger)
    log_text(txt_fixed, 'real', '0', logger)

    # build models
    D = build_discriminator(img_shape, emb_shape, emb_dim, ndf)
    G = build_generator(z_shape, emb_shape, emb_dim, ngf)

    # build model outputs
    real_inputs = Input(shape=img_shape)
    txt_inputs = Input(shape=emb_shape)
    z_inputs = Input(shape=(z_dim, ))

    fake_samples = G([z_inputs, txt_inputs])
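    # RandomWeightedAverage mixes real and fake samples with per-sample random
    # weights; the gradient penalty below is evaluated on these interpolations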
    averaged_samples = RandomWeightedAverage()([real_inputs, fake_samples])
    D_real = D([real_inputs, txt_inputs])
    D_fake = D([fake_samples, txt_inputs])
    D_averaged = D([averaged_samples, txt_inputs])

    # The gradient penalty loss function requires the input averaged samples to
    # get gradients. However, Keras loss functions can only have two arguments,
    # y_true and y_pred. We get around this by making a partial() of the
    # function with the averaged samples here.
    loss_gp = partial(loss_gradient_penalty,
                      averaged_samples=averaged_samples,
                      gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    # Functions need names or Keras will throw an error
    loss_gp.__name__ = 'loss_gradient_penalty'
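
    # A minimal sketch (hypothetical name, not the project's implementation) of
    # what such a gradient-penalty loss typically computes, following WGAN-GP;
    # assumes `from keras import backend as K` and `import numpy as np`:
    def loss_gradient_penalty_sketch(y_true, y_pred, averaged_samples,
                                     gradient_penalty_weight):
        # gradient of the critic output w.r.t. the interpolated samples
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # per-sample L2 norm of that gradient
        gradients_sqr = K.square(gradients)
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # penalize deviation of the gradient norm from 1
        return K.mean(gradient_penalty_weight * K.square(1 - gradient_l2_norm))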

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(inputs=[real_inputs, txt_inputs, z_inputs],
                    outputs=[D_real, D_fake, D_averaged])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.9),
                    loss=[loss_wasserstein, loss_wasserstein, loss_gp])

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[z_inputs, txt_inputs], outputs=D_fake)
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.9),
                    loss=loss_wasserstein)

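    # Wasserstein targets: +1 for real and -1 for fake outputs
    # (loss_wasserstein is typically K.mean(y_true * y_pred)); `dummy` is an
    # unused target for the gradient-penalty output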
    ones = np.ones((batch_size, 1), dtype=np.float32)
    minus_ones = -ones
    dummy = np.zeros((batch_size, 1), dtype=np.float32)

    # fix a z vector for training evaluation
    z_fixed = np.random.uniform(-1, 1, size=(n_checkpoint_samples, z_dim))

    for i in range(n_iterations):
        D.trainable = True
        G.trainable = False
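        # train the critic N_CRITIC_ITERS times for every generator update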
        for j in range(N_CRITIC_ITERS):
            z = np.random.normal(0, 1, size=(batch_size, z_dim))
            real_batch = next(data_iterator)
            losses_d = D_model.train_on_batch(
                [images_from_bytes(real_batch[0]), real_batch[1], z],
                [ones, minus_ones, dummy])

        D.trainable = False
        G.trainable = True
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_g = G_model.train_on_batch([z, real_batch[1]], ones)

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([z_fixed, emb_fixed])
            log_images(fake_image, 'val_fake', i, logger)
            log_images(img_fixed, 'val_real', i, logger)
            log_text(txt_fixed, 'val_fake', i, logger)

        log_losses(losses_d, loss_g, i, logger)
Example #3
BUFFER_SIZE = 60000

print(f"Will generate {GENERATE_SQUARE}px square images.")

print(f"Images being loaded from {TRAINING_DATA_PATH}")

train_dataset = get_dataset(TRAINING_DATA_PATH, BUFFER_SIZE, BATCH_SIZE)
print(f"Images loaded from {TRAINING_DATA_PATH}")

# Check whether to initialize new models or continue training models saved on disk
if INITIAL_TRAINING:
    print("Initializing Generator and Discriminator")
    generator = build_generator(image_shape=(GENERATE_SQUARE, GENERATE_SQUARE,
                                             1))
    discriminator = build_discriminator(image_shape=(GENERATE_SQUARE,
                                                     GENERATE_SQUARE, 2))
    print("Generator and Discriminator initialized")
else:
    print("Loading model from memory")
    if os.path.isfile(GENERATOR_PATH_PRE):
        generator = tf.keras.models.load_model(GENERATOR_PATH_PRE)
        print("Generator loaded")
    else:
        print("No generator file found")
    if os.path.isfile(DISCRIMINATOR_PATH_PRE):
        discriminator = tf.keras.models.load_model(DISCRIMINATOR_PATH_PRE)
        print("Discriminator loaded")
    else:
        print("No discriminator file found")
Example #4
def train(data_folderpath='data/edges2shoes', image_size=256, ndf=64, ngf=64,
          lr_d=2e-4, lr_g=2e-4, n_iterations=int(1e6),
          batch_size=64, iters_per_checkpoint=100, n_checkpoint_samples=16,
          reconstruction_weight=100, out_dir='gan'):

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)

    data_iterator = iterate_minibatches(
        data_folderpath + "/train/*.jpg", batch_size, image_size)
    val_data_iterator = iterate_minibatches(
        data_folderpath + "/val/*.jpg", n_checkpoint_samples, image_size)
    img_ab_fixed, _ = next(val_data_iterator)
    img_a_fixed, img_b_fixed = img_ab_fixed[:, 0], img_ab_fixed[:, 1]

    img_a_shape = img_a_fixed.shape[1:]
    img_b_shape = img_b_fixed.shape[1:]
    patch = int(img_a_shape[0] / 2**4)  # n_layers
    disc_patch = (patch, patch, 1)
    print("img a shape ", img_a_shape)
    print("img b shape ", img_b_shape)
    print("disc_patch ", disc_patch)

    # log the fixed real input and target images for reference
    log_images(img_a_fixed, 'real_a', '0', logger)
    log_images(img_b_fixed, 'real_b', '0', logger)

    # build models
    D = build_discriminator(
        img_a_shape, img_b_shape, ndf, activation='sigmoid')
    G = build_generator(img_a_shape, ngf)

    # build model outputs
    img_a_input = Input(shape=img_a_shape)
    img_b_input = Input(shape=img_b_shape)

    fake_samples = G(img_a_input)
    D_real = D([img_a_input, img_b_input])
    D_fake = D([img_a_input, fake_samples])

    loss_reconstruction = partial(mean_absolute_error,
                                  real_samples=img_b_input,
                                  fake_samples=fake_samples)
    loss_reconstruction.__name__ = 'loss_reconstruction'
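
    # A minimal sketch (hypothetical name) of what such a reconstruction loss
    # typically computes: an L1 distance between the real target image and the
    # generator output, ignoring the y_true/y_pred that Keras passes in.
    # Assumes `from keras import backend as K`:
    def mean_absolute_error_sketch(y_true, y_pred, real_samples=None,
                                   fake_samples=None):
        return K.mean(K.abs(real_samples - fake_samples))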

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[D_real, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                    loss='binary_crossentropy')

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[D_fake, fake_samples])
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.999),
                    loss=['binary_crossentropy', loss_reconstruction],
                    loss_weights=[1, reconstruction_weight])

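    # PatchGAN-shaped targets; `dummy` is a placeholder for the fake_samples
    # output, whose loss compares fake_samples against img_b_input directly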
    ones = np.ones((batch_size, ) + disc_patch, dtype=np.float32)
    zeros = np.zeros((batch_size, ) + disc_patch, dtype=np.float32)
    dummy = zeros

    for i in range(n_iterations):
        D.trainable = True
        G.trainable = False

        image_ab_batch, _ = next(data_iterator)
        loss_d = D_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones, zeros])

        D.trainable = False
        G.trainable = True
        image_ab_batch, _ = next(data_iterator)
        loss_g = G_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones, dummy])

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict(img_a_fixed)
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, out_dir)

        log_losses(loss_d, loss_g, i, logger)
Example #5
def train(data_filepath='data/flowers.hdf5',
          ndf=64,
          ngf=128,
          z_dim=128,
          emb_dim=128,
          lr_d=2e-4,
          lr_g=2e-4,
          n_iterations=int(1e6),
          batch_size=64,
          iters_per_checkpoint=500,
          n_checkpoint_samples=16,
          out_dir='gan'):

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    train_data = get_data(data_filepath, 'train')
    val_data = get_data(data_filepath, 'valid')
    data_iterator = iterate_minibatches(train_data, batch_size)
    val_data_iterator = iterate_minibatches(val_data, n_checkpoint_samples)
    val_data = next(val_data_iterator)
    img_fixed = images_from_bytes(val_data[0])
    emb_fixed = val_data[1]
    txt_fixed = val_data[2]

    img_shape = img_fixed[0].shape
    emb_shape = emb_fixed[0].shape
    print("img shape {}".format(img_shape))
    print("emb shape {}".format(emb_shape))
    z_shape = (z_dim, )

    # plot real text for reference
    log_images(img_fixed, 'real', '0', logger)
    log_text(txt_fixed, 'real', '0', logger)

    # build models
    D = build_discriminator(img_shape,
                            emb_shape,
                            emb_dim,
                            ndf,
                            activation='sigmoid')
    G = build_generator(z_shape, emb_shape, emb_dim, ngf)

    # build model outputs
    real_inputs = Input(shape=img_shape)
    txt_inputs = Input(shape=emb_shape)
    txt_shuf_inputs = Input(shape=emb_shape)
    z_inputs = Input(shape=(z_dim, ))

    fake_samples = G([z_inputs, txt_inputs])
    D_real = D([real_inputs, txt_inputs])
    D_wrong = D([real_inputs, txt_shuf_inputs])
    D_fake = D([fake_samples, txt_inputs])
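    # Matching-aware discriminator heads: real image + matching text (D_real),
    # real image + mismatched text (D_wrong), generated image + matching text
    # (D_fake)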

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(
        inputs=[real_inputs, txt_inputs, txt_shuf_inputs, z_inputs],
        outputs=[D_real, D_wrong, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.9),
                    loss='binary_crossentropy',
                    loss_weights=[1, 0.5, 0.5])

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[z_inputs, txt_inputs], outputs=D_fake)
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.9),
                    loss='binary_crossentropy')

    ones = np.ones((batch_size, 1, 1, 1), dtype=np.float32)
    zeros = np.zeros((batch_size, 1, 1, 1), dtype=np.float32)

    # fix a z vector for training evaluation
    z_fixed = np.random.uniform(-1, 1, size=(n_checkpoint_samples, z_dim))

    for i in range(n_iterations):
        start = clock()
        D.trainable = True
        G.trainable = False
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        images_batch = images_from_bytes(real_batch[0])
        emb_text_batch = real_batch[1]
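        # shuffle the text embeddings to build mismatched (image, text) pairs
        # for the D_wrong head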
        ids = np.arange(len(emb_text_batch))
        np.random.shuffle(ids)
        emb_text_batch_shuffle = emb_text_batch[ids]
        loss_d = D_model.train_on_batch(
            [images_batch, emb_text_batch, emb_text_batch_shuffle, z],
            [ones, zeros, zeros])

        D.trainable = False
        G.trainable = True
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_g = G_model.train_on_batch([z, real_batch[1]], ones)

        print("iter", i, "time", clock() - start)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([z_fixed, emb_fixed])
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, out_dir)

        log_losses(loss_d, loss_g, i, logger)
Example #6
    def test_discriminator(self):
        discriminator = build_discriminator()

        self.assertIsNotNone(discriminator)
        self.assertEqual((None, 1), discriminator.output_shape)
Example #7
File: train.py Project: rahhul/GANs
        d2_hist.append(d_loss2)
        g_hist.append(g_loss)
        # evaluate
        if (i+1) % (batch_per_epoch * 1) == 0:
            log_performance(i, g_model, latent_dim)
    # plot
    plot_history(d1_hist, d2_hist, g_hist)



# EXAMPLE

latent_dim = 100

# discriminator model
discriminator = build_discriminator(in_shape=(28, 28, 1))

# generator model
generator = build_generator(latent_dim=latent_dim)

# gan model
gan_model = build_gan(generator, discriminator)

# image dataset
dataset = load_mnist()
print(dataset.shape)

# train

train(generator, discriminator, gan_model, dataset, latent_dim)
Example #8
    loss_func = utils.modified_binary_crossentropy if args.network_type == 'wgan' else 'binary_crossentropy'
    fake_label = -1 if args.network_type == 'wgan' else 0
    output_activation = 'linear' if args.network_type == 'wgan' else 'sigmoid'
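    # These switches configure the WGAN variant: a linear critic output, -1
    # labels for fake samples, and a custom loss (for WGAN this is typically
    # the Wasserstein objective, K.mean(y_true * y_pred))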

    # Load the data
    (X_train,
     y_train), (X_test,
                y_test) = utils.load_data(join(args.datadir, 'embed/CV0'))
    if not args.normalized:
        print('Now normalize the input data')
        X_train = (X_train.astype(np.float32) - 0.5) / 0.5
        X_test = (X_test.astype(np.float32) - 0.5) / 0.5
    num_train, num_test = X_train.shape[0], X_test.shape[0]

    # build the discriminator
    discriminator = models.build_discriminator(
        args.seqlen, args.nchannel, output_activation=output_activation)

    # build the generator
    generator = models.build_generator(latent_size, args.seqlen, args.nchannel)

    # only the generator should be trainable in the combined model
    latent = Input(shape=(latent_size, ))
    utils.set_trainability(discriminator, False)
    fake = generator(latent)
    fake = discriminator(fake)
    combined = Model(latent, fake)
    combined.compile(optimizer=g_optim, loss=loss_func)

    # The actual discriminator model
    utils.set_trainability(discriminator, True)
    real_samples = Input(shape=X_train.shape[1:])
Example #9
def train():
    # Set main parameters
    start_time = time.time()
    dataset_dir = "data/*.*"
    batch_size = 64
    z_shape = 100
    epochs = 10000
    dis_learning_rate = 0.005
    gen_learning_rate = 0.005
    dis_momentum = 0.5
    gen_momentum = 0.5
    dis_nesterov = True
    gen_nesterov = True

    # Define optimizers (can change to Adam later)
    #dis_optimizer = SGD(lr=dis_learning_rate, momentum=dis_momentum, nesterov=dis_nesterov)
    #gen_optimizer = SGD(lr=gen_learning_rate, momentum=gen_momentum, nesterov=gen_nesterov)
    dis_optimizer = Adam()
    gen_optimizer = Adam()

    # Load images
    all_images = []
    for index, filename in enumerate(glob.glob(dataset_dir)):
        all_images.append(imread(filename, flatten=False, mode='RGB'))

    # Stack images into an array and normalize them
    X = np.array(all_images)
    X = normalize(X)
    X = X.astype(np.float32)

    # Build the GAN models
    dis_model = build_discriminator()
    dis_model.compile(loss='binary_crossentropy', optimizer=dis_optimizer)

    gen_model = build_generator()
    gen_model.compile(loss='mse', optimizer=gen_optimizer)

    adversarial_model = build_adversarial_model(gen_model, dis_model)
    adversarial_model.compile(loss='binary_crossentropy',
                              optimizer=gen_optimizer)

    # Log training data to TensorBoard
    tensorboard = TensorBoard(log_dir="results/logs/{}".format(time.time()),
                              write_images=True,
                              write_grads=True,
                              write_graph=True)
    tensorboard.set_model(gen_model)
    tensorboard.set_model(dis_model)

    for epoch in range(epochs):
        print("--------------------------")
        print("Epoch:{}".format(epoch))

        dis_losses = []
        gen_losses = []

        num_batches = int(X.shape[0] / batch_size)

        print("Number of batches:{}".format(num_batches))
        for index in range(num_batches):
            print("Batch:{}".format(index))

            z_noise = np.random.normal(0, 1, size=(batch_size, z_shape))
            # z_noise = np.random.uniform(-1, 1, size=(batch_size, 100))

            generated_images = gen_model.predict_on_batch(z_noise)

            # visualize_rgb(generated_images[0])
            """
            Train the discriminator model
            """

            dis_model.trainable = True

            image_batch = X[index * batch_size:(index + 1) * batch_size]

            # Flip the real/fake labels every third epoch, using noisy
            # (smoothed) labels to keep the discriminator from reaching
            # zero loss too quickly
            if epoch % 3 == 0:
                y_fake = np.random.uniform(low=0.7,
                                           high=1.2,
                                           size=(batch_size, ))
                y_real = np.random.uniform(low=0,
                                           high=0.3,
                                           size=(batch_size, ))
            else:
                y_real = np.random.uniform(low=0.7,
                                           high=1.2,
                                           size=(batch_size, ))
                y_fake = np.random.uniform(low=0,
                                           high=0.3,
                                           size=(batch_size, ))

            # Real labels to train generator
            y_real_gen = np.random.uniform(low=0.7,
                                           high=1.0,
                                           size=(batch_size, ))

            dis_loss_real = dis_model.train_on_batch(image_batch, y_real)
            dis_loss_fake = dis_model.train_on_batch(generated_images, y_fake)

            d_loss = (dis_loss_real + dis_loss_fake) / 2
            print("d_loss:", d_loss)

            dis_model.trainable = False
            """
            Train the generator model(adversarial model)
            """
            z_noise = np.random.normal(0, 1, size=(batch_size, z_shape))
            # z_noise = np.random.uniform(-1, 1, size=(batch_size, 100))

            g_loss = adversarial_model.train_on_batch(z_noise, y_real_gen)
            print("g_loss:", g_loss)

            dis_losses.append(d_loss)
            gen_losses.append(g_loss)
        """
        Sample some images and save them
        """
        # Sample images every 20 epochs
        if epoch % 20 == 0:
            z_noise = np.random.normal(0, 1, size=(batch_size, z_shape))
            gen_images1 = gen_model.predict_on_batch(z_noise)

            for n, img in enumerate(gen_images1[:2]):
                save_rgb_img(denormalize(img),
                             "results/img/gen_{}_{}.png".format(epoch, n))

        print("Epoch:{}, dis_loss:{}".format(epoch, np.mean(dis_losses)))
        print("Epoch:{}, gen_loss: {}".format(epoch, np.mean(gen_losses)))
        """
        Save losses to Tensorboard after each epoch
        """
        write_log(tensorboard, 'discriminator_loss', np.mean(dis_losses),
                  epoch)
        write_log(tensorboard, 'generator_loss', np.mean(gen_losses), epoch)
    """
    Save models
    """
    gen_model.save("results/models/generator_model.h5")
    dis_model.save("results/models/discriminator_model.h5")

    print("Time:", (time.time() - start_time))
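
# `write_log` above is a project helper; a typical TF1-style implementation
# (hypothetical sketch, assuming `import tensorflow as tf`) writes a scalar
# summary through the TensorBoard callback's writer:
def write_log_sketch(callback, name, value, step):
    summary = tf.Summary()
    summary_value = summary.value.add()
    summary_value.simple_value = value
    summary_value.tag = name
    callback.writer.add_summary(summary, step)
    callback.writer.flush()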
Example #10
def train(data_filepath='data/flowers.hdf5',
          ndf=64,
          ngf=128,
          z_dim=128,
          emb_dim=128,
          lr_d=1e-4,
          lr_g=1e-4,
          n_iterations=int(1e6),
          batch_size=64,
          iters_per_checkpoint=100,
          n_checkpoint_samples=16,
          out_dir='rgan'):

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    train_data = get_data(data_filepath, 'train')
    val_data = get_data(data_filepath, 'valid')
    data_iterator = iterate_minibatches(train_data, batch_size)
    val_data_iterator = iterate_minibatches(val_data, n_checkpoint_samples)
    val_data = next(val_data_iterator)
    img_fixed = images_from_bytes(val_data[0])
    emb_fixed = val_data[1]
    txt_fixed = val_data[2]

    img_shape = img_fixed[0].shape
    emb_shape = emb_fixed[0].shape
    print("img shape {}".format(img_shape))
    print("emb shape {}".format(emb_shape))
    z_shape = (z_dim, )

    # plot real text for reference
    log_images(img_fixed, 'real', '0', logger)
    log_text(txt_fixed, 'real', '0', logger)

    # build models
    D = build_discriminator(img_shape, emb_shape, emb_dim, ndf)
    G = build_generator(z_shape, emb_shape, emb_dim, ngf)

    # build model outputs
    real_inputs = Input(shape=img_shape)
    txt_inputs = Input(shape=emb_shape)
    z_inputs = Input(shape=(z_dim, ))
    fake_samples = G([z_inputs, txt_inputs])
    D_real = D([real_inputs, txt_inputs])
    D_fake = D([fake_samples, txt_inputs])

    # build losses
    loss_d_fn = partial(rel_disc_loss, disc_r=D_real, disc_f=D_fake)
    loss_g_fn = partial(rel_gen_loss, disc_r=D_real, disc_f=D_fake)
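
    # A minimal sketch of one common relativistic formulation (RSGAN), in which
    # the critic is asked to score real samples as more realistic than fakes;
    # the project's rel_disc_loss / rel_gen_loss may differ, the names below
    # are hypothetical, and `K` is assumed to be keras.backend. Because these
    # losses read D_real and D_fake through their closures, each compiled model
    # attaches a single loss (None for the other output) and the dummy_y target
    # below is ignored.
    def rel_disc_loss_sketch(y_true, y_pred, disc_r=None, disc_f=None):
        return -K.mean(K.log(K.sigmoid(disc_r - disc_f) + 1e-9))

    def rel_gen_loss_sketch(y_true, y_pred, disc_r=None, disc_f=None):
        return -K.mean(K.log(K.sigmoid(disc_f - disc_r) + 1e-9))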

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(inputs=[real_inputs, txt_inputs, z_inputs],
                    outputs=[D_real, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                    loss=[loss_d_fn, None])

    # define G graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[real_inputs, z_inputs, txt_inputs],
                    outputs=[D_real, D_fake])
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.999),
                    loss=[loss_g_fn, None])

    # dummy loss
    dummy_y = np.zeros((batch_size, 1), dtype=np.float32)

    # fix a z vector for training evaluation
    z_fixed = np.random.uniform(-1, 1, size=(n_checkpoint_samples, z_dim))

    for i in range(n_iterations):
        D.trainable = True
        G.trainable = False
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_d = D_model.train_on_batch(
            [images_from_bytes(real_batch[0]), real_batch[1], z], dummy_y)[0]

        D.trainable = False
        G.trainable = True
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_g = G_model.train_on_batch(
            [images_from_bytes(real_batch[0]), z, real_batch[1]], dummy_y)[0]

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([z_fixed, emb_fixed])
            log_images(fake_image, 'val_fake', i, logger)
            log_images(img_fixed, 'val_real', i, logger)
            log_text(txt_fixed, 'val_fake', i, logger)

        log_losses(loss_d, loss_g, i, logger)
Example #11
# Optimizer algorithm

from optimizers import get_optimizer
optimizer = get_optimizer(args)

##############################################################################################################################
# Building model

from models import build_generator, build_discriminator
import keras.backend as K

logger.info('  Building model')

generator = build_generator(args, overall_maxlen, vocab)

discriminator_C1 = build_discriminator(args, name='classifier1')
discriminator_C2 = build_discriminator(args, name='classifier2')

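# Each combined model routes the shared generator's features through one of
# the two classifier heads and is compiled with a categorical cross-entropy
# loss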
z1 = Input(shape=(overall_maxlen, ))
feature_g1 = generator(z1)
prob1 = discriminator_C1(feature_g1)
combined_g_c1 = Model(z1, prob1)
combined_g_c1.compile(optimizer=optimizer,
                      loss='categorical_crossentropy',
                      metrics=['categorical_accuracy'])
combined_g_c1.summary()

z2 = Input(shape=(overall_maxlen, ))
feature_g2 = generator(z2)
prob2 = discriminator_C2(feature_g2)
combined_g_c2 = Model(z2, prob2)