Example #1
def train_vae(files,
              input_shape=[None, 784],
              output_shape=[None, 784],
              learning_rate=0.0001,
              batch_size=128,
              n_epochs=50,
              crop_shape=[64, 64],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              denoising=True,
              convolutional=True,
              variational=True,
              softmax=False,
              classifier='alexnet_v2',
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=1000,
              save_step=2500,
              output_path="result",
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a CSV file listing annotated image paths, and this will do
    everything else.

    Parameters
    ----------
    files : str
        Path to a CSV file listing the annotated training images
        (this path is both fed to the input pipeline and opened with
        csv.reader below).
    input_shape : list
        Must define what the input image's shape is.
    output_shape : list
        Must define what the output image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    denoising : bool, optional
        Use a denoising autoencoder or not.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    softmax : bool, optional
        Use the classification network or not.
    classifier : str, optional
        Network for classification.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not.
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often to save training images showing the manifold and
        reconstruction.
    save_step : int, optional
        How often to save checkpoints.
    output_path : str, optional
        Directory in which images and checkpoints are saved.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt'.
    """

    batch_train = create_input_pipeline(files=files,
                                        batch_size=batch_size,
                                        n_epochs=n_epochs,
                                        crop_shape=crop_shape,
                                        crop_factor=crop_factor,
                                        input_shape=input_shape,
                                        output_shape=output_shape)

    if softmax:
        batch_imagenet = create_input_pipeline(
            files="./list_annotated_imagenet.csv",
            batch_size=batch_size,
            n_epochs=n_epochs,
            crop_shape=crop_shape,
            crop_factor=crop_factor,
            input_shape=input_shape,
            output_shape=output_shape)
        batch_pascal = create_input_pipeline(
            files="./list_annotated_pascal.csv",
            batch_size=batch_size,
            n_epochs=n_epochs,
            crop_shape=crop_shape,
            crop_factor=crop_factor,
            input_shape=input_shape,
            output_shape=output_shape)
        batch_shapenet = create_input_pipeline(
            files="./list_annotated_img_test.csv",
            batch_size=batch_size,
            n_epochs=n_epochs,
            crop_shape=crop_shape,
            crop_factor=crop_factor,
            input_shape=input_shape,
            output_shape=output_shape)

    ae = VAE(input_shape=[None] + crop_shape + [input_shape[-1]],
             output_shape=[None] + crop_shape + [output_shape[-1]],
             denoising=denoising,
             convolutional=convolutional,
             variational=variational,
             softmax=softmax,
             n_filters=n_filters,
             n_hidden=n_hidden,
             n_code=n_code,
             dropout=dropout,
             filter_sizes=filter_sizes,
             activation=activation,
             classifier=classifier)

    with open(files, "r") as f:
        reader = csv.reader(f, delimiter=",")
        data = list(reader)
        n_files = len(data)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    np.random.seed(1)
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, 6)

    optimizer_vae = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(ae['cost_vae'])
    if softmax:
        # Reference learning rates:
        # AlexNet 0.01,
        # Inception v1 0.01,
        # SqueezeNet 0.04
        if classifier == 'inception_v3':
            lr = tf.train.exponential_decay(0.1,
                                            0,
                                            n_files / batch_size * 20,
                                            0.16,
                                            staircase=True)
            optimizer_softmax = tf.train.RMSPropOptimizer(
                lr, decay=0.9, momentum=0.9,
                epsilon=0.1).minimize(ae['cost_s'])
        elif classifier == 'inception_v2':
            optimizer_softmax = tf.train.AdamOptimizer(
                learning_rate=0.01).minimize(ae['cost_s'])
        elif classifier == 'inception_v1':
            optimizer_softmax = tf.train.GradientDescentOptimizer(
                learning_rate=0.01).minimize(ae['cost_s'])
        elif (classifier == 'squeezenet') or (classifier == 'zigzagnet'):
            optimizer_softmax = tf.train.RMSPropOptimizer(
                0.04, decay=0.9, momentum=0.9,
                epsilon=0.1).minimize(ae['cost_s'])
        elif classifier == 'alexnet_v2':
            optimizer_softmax = tf.train.GradientDescentOptimizer(
                learning_rate=0.01).minimize(ae['cost_s'])
        else:
            optimizer_softmax = tf.train.GradientDescentOptimizer(
                learning_rate=0.001).minimize(ae['cost_s'])

    # We create a session to use the graph together with a GPU declaration.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.4
    sess = tf.Session(config=config)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter('./summary', sess.graph)

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if (os.path.exists(output_path + '/' + ckpt_name + '.index')
            or os.path.exists(ckpt_name)):
        saver.restore(sess, output_path + '/' + ckpt_name)
        print("Model restored")

    # Fit all training data
    t_i = 0
    step_i = 0
    batch_i = 0
    epoch_i = 0
    summary_i = 0
    cost = 0
    # Test samples of training data from ShapeNet
    test_xs_img, test_xs_obj, test_xs_label = sess.run(batch_train)
    test_xs_img /= 255.0
    test_xs_obj /= 255.0
    utils.montage(test_xs_img, output_path + '/train_img.png')
    utils.montage(test_xs_obj, output_path + '/train_obj.png')

    # The test pipelines below only exist when softmax=True
    if softmax:
        # Test samples of testing data from ImageNet
        test_imagenet_img, _, test_imagenet_label = sess.run(batch_imagenet)
        test_imagenet_img /= 255.0
        utils.montage(test_imagenet_img,
                      output_path + '/test_imagenet_img.png')

        # Test samples of testing data from PASCAL 2012
        test_pascal_img, _, test_pascal_label = sess.run(batch_pascal)
        test_pascal_img /= 255.0
        utils.montage(test_pascal_img, output_path + '/test_pascal_img.png')

        # Test samples of testing data from ShapeNet test data
        test_shapenet_img, _, test_shapenet_label = sess.run(batch_shapenet)
        test_shapenet_img /= 255.0
        utils.montage(test_shapenet_img,
                      output_path + '/test_shapenet_img.png')
    try:
        while not coord.should_stop():
            batch_i += 1
            step_i += 1
            batch_xs_img, batch_xs_obj, batch_xs_label = sess.run(batch_train)
            batch_xs_img /= 255.0
            batch_xs_obj /= 255.0

            # Here we must set corrupt_rec and corrupt_cls to 0 to find a
            # proper ratio of variance to feed to the variable var_prob.
            # We use tanh as the non-linearity on the ratio of variances
            # between the reconstructed and the original channels.
            var_prob = sess.run(ae['var_prob'],
                                feed_dict={
                                    ae['x']: test_xs_img,
                                    ae['label']: test_xs_label[:, 0],
                                    ae['train']: True,
                                    ae['keep_prob']: 1.0,
                                    ae['corrupt_rec']: 0,
                                    ae['corrupt_cls']: 0
                                })

            # Compute the corruption levels for this training step
            corrupt_rec = np.tanh(0.25 * var_prob)
            corrupt_cls = np.tanh(1 - np.tanh(2 * var_prob))
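            # Both corruption levels are squashed by tanh: for non-negative
            # var_prob, corrupt_rec rises with the variance ratio while
            # corrupt_cls falls, shifting the noise between the two branches.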

            # Optimizing reconstruction network
            cost_vae = sess.run(
                [ae['cost_vae'], optimizer_vae],
                feed_dict={
                    ae['x']: batch_xs_img,
                    ae['t']: batch_xs_obj,
                    ae['label']: batch_xs_label[:, 0],
                    ae['train']: True,
                    ae['keep_prob']: keep_prob,
                    ae['corrupt_rec']: corrupt_rec,
                    ae['corrupt_cls']: corrupt_cls
                })[0]
            cost += cost_vae
            if softmax:

                # Optimizing classification network
                cost_s = sess.run(
                    [ae['cost_s'], optimizer_softmax],
                    feed_dict={
                        ae['x']: batch_xs_img,
                        ae['t']: batch_xs_obj,
                        ae['label']: batch_xs_label[:, 0],
                        ae['train']: True,
                        ae['keep_prob']: keep_prob,
                        ae['corrupt_rec']: corrupt_rec,
                        ae['corrupt_cls']: corrupt_cls
                    })[0]
                cost += cost_s

            if step_i % img_step == 0:
                if variational:
                    # Plot example reconstructions from latent layer
                    recon = sess.run(ae['y'],
                                     feed_dict={
                                         ae['z']: zs,
                                         ae['train']: False,
                                         ae['keep_prob']: 1.0,
                                         ae['corrupt_rec']: 0,
                                         ae['corrupt_cls']: 0
                                     })
                    utils.montage(recon.reshape([-1] + crop_shape),
                                  output_path + '/manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['x']: test_xs_img,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0,
                                     ae['corrupt_rec']: 0,
                                     ae['corrupt_cls']: 0
                                 })
                utils.montage(recon.reshape([-1] + crop_shape),
                              output_path + '/recon_%08d.png' % t_i)
                """
                filters = sess.run(
                  ae['Ws'], feed_dict={
                              ae['x']: test_xs_img,
                              ae['train']: False,
                              ae['keep_prob']: 1.0,
                              ae['corrupt_rec']: 0,
                              ae['corrupt_cls']: 0})
                #for filter_element in filters:
                utils.montage_filters(filters[-1],
                                output_path + '/filter_%08d.png' % t_i)
                """

                # Test on ImageNet samples (assumes softmax=True, so that
                # the test pipelines and ae['acc'] / ae['predictions'] exist)
                with open('./list_annotated_imagenet.csv', 'r') as csvfile:
                    spamreader = csv.reader(csvfile)
                    rows = list(spamreader)
                    totalrows = len(rows)
                num_batches = np.int_(np.floor(totalrows / batch_size))
                accumulated_acc = 0
                for index_batch in range(1, num_batches + 1):
                    test_image, _, test_label = sess.run(batch_imagenet)
                    test_image /= 255.0
                    acc, z_codes, sm_codes = sess.run(
                        [ae['acc'], ae['z'], ae['predictions']],
                        feed_dict={
                            ae['x']: test_image,
                            ae['label']: test_label[:, 0],
                            ae['train']: False,
                            ae['keep_prob']: 1.0,
                            ae['corrupt_rec']: 0,
                            ae['corrupt_cls']: 0
                        })
                    accumulated_acc += acc.tolist().count(True) / acc.size
                    if index_batch == 1:
                        z_imagenet = z_codes
                        sm_imagenet = sm_codes
                        labels_imagenet = test_label
                        # Plot example reconstructions
                        recon = sess.run(ae['y'],
                                         feed_dict={
                                             ae['x']: test_imagenet_img,
                                             ae['train']: False,
                                             ae['keep_prob']: 1.0,
                                             ae['corrupt_rec']: 0,
                                             ae['corrupt_cls']: 0
                                         })
                        utils.montage(
                            recon.reshape([-1] + crop_shape),
                            output_path + '/recon_imagenet_%08d.png' % t_i)
                    else:
                        z_imagenet = np.append(z_imagenet, z_codes, axis=0)
                        sm_imagenet = np.append(sm_imagenet, sm_codes, axis=0)
                        labels_imagenet = np.append(labels_imagenet,
                                                    test_label,
                                                    axis=0)
                accumulated_acc /= num_batches
                print("Accuracy of ImageNet images= %.3f" % (accumulated_acc))

                fig = plt.figure()
                z_viz, V = pca(z_imagenet, dim_remain=2)
                ax = fig.add_subplot(121)
                # ax.set_aspect('equal')
                ax.scatter(z_viz[:, 0],
                           z_viz[:, 1],
                           c=labels_imagenet[:, 0],
                           alpha=0.4,
                           cmap='gist_rainbow')
                sm_viz, V = pca(sm_imagenet, dim_remain=2)
                ax = fig.add_subplot(122)
                # ax.set_aspect('equal')
                ax.scatter(sm_viz[:, 0],
                           sm_viz[:, 1],
                           c=labels_imagenet[:, 0],
                           alpha=0.4,
                           cmap='gist_rainbow')

                fig.savefig(output_path + '/z_feat_imagenet.png',
                            transparent=True)
                plt.clf()

                # Test on PASCAL 2012 samples
                with open('./list_annotated_pascal.csv', 'r') as csvfile:
                    spamreader = csv.reader(csvfile)
                    rows = list(spamreader)
                    totalrows = len(rows)
                num_batches = np.int_(np.floor(totalrows / batch_size))
                accumulated_acc = 0
                for index_batch in range(1, num_batches + 1):
                    test_image, _, test_label = sess.run(batch_pascal)
                    test_image /= 255.0
                    acc, z_codes, sm_codes = sess.run(
                        [ae['acc'], ae['z'], ae['predictions']],
                        feed_dict={
                            ae['x']: test_image,
                            ae['label']: test_label[:, 0],
                            ae['train']: False,
                            ae['keep_prob']: 1.0,
                            ae['corrupt_rec']: 0,
                            ae['corrupt_cls']: 0
                        })
                    accumulated_acc += acc.tolist().count(True) / acc.size
                    if index_batch == 1:
                        z_pascal = z_codes
                        sm_pascal = sm_codes
                        labels_pascal = test_label
                        # Plot example reconstructions
                        recon = sess.run(ae['y'],
                                         feed_dict={
                                             ae['x']: test_pascal_img,
                                             ae['train']: False,
                                             ae['keep_prob']: 1.0,
                                             ae['corrupt_rec']: 0,
                                             ae['corrupt_cls']: 0
                                         })
                        utils.montage(
                            recon.reshape([-1] + crop_shape),
                            output_path + '/recon_pascal_%08d.png' % t_i)
                    else:
                        z_pascal = np.append(z_pascal, z_codes, axis=0)
                        sm_pascal = np.append(sm_pascal, sm_codes, axis=0)
                        labels_pascal = np.append(labels_pascal,
                                                  test_label,
                                                  axis=0)
                accumulated_acc /= num_batches
                print("Accuracy of PASCAL images= %.3f" % (accumulated_acc))

                fig = plt.figure()
                z_viz, V = pca(z_pascal, dim_remain=2)
                ax = fig.add_subplot(121)
                # ax.set_aspect('equal')
                ax.scatter(z_viz[:, 0],
                           z_viz[:, 1],
                           c=labels_pascal[:, 0],
                           alpha=0.4,
                           cmap='gist_rainbow')
                sm_viz, V = pca(sm_pascal, dim_remain=2)
                ax = fig.add_subplot(122)
                # ax.set_aspect('equal')
                ax.scatter(sm_viz[:, 0],
                           sm_viz[:, 1],
                           c=labels_pascal[:, 0],
                           alpha=0.4,
                           cmap='gist_rainbow')

                fig.savefig(output_path + '/z_feat_pascal.png',
                            transparent=True)
                plt.clf()

                # Test on ShapeNet test samples
                with open('./list_annotated_img_test.csv', 'r') as csvfile:
                    spamreader = csv.reader(csvfile)
                    rows = list(spamreader)
                    totalrows = len(rows)
                num_batches = np.int_(np.floor(totalrows / batch_size))
                accumulated_acc = 0
                for index_batch in range(1, num_batches + 1):
                    test_image, _, test_label = sess.run(batch_shapenet)
                    test_image /= 255.0
                    acc, z_codes, sm_codes = sess.run(
                        [ae['acc'], ae['z'], ae['predictions']],
                        feed_dict={
                            ae['x']: test_image,
                            ae['label']: test_label[:, 0],
                            ae['train']: False,
                            ae['keep_prob']: 1.0,
                            ae['corrupt_rec']: 0,
                            ae['corrupt_cls']: 0
                        })
                    accumulated_acc += acc.tolist().count(True) / acc.size
                    if index_batch == 1:
                        z_shapenet = z_codes
                        sm_shapenet = sm_codes
                        labels_shapenet = test_label
                        # Plot example reconstructions
                        recon = sess.run(ae['y'],
                                         feed_dict={
                                             ae['x']: test_shapenet_img,
                                             ae['train']: False,
                                             ae['keep_prob']: 1.0,
                                             ae['corrupt_rec']: 0,
                                             ae['corrupt_cls']: 0
                                         })
                        utils.montage(
                            recon.reshape([-1] + crop_shape),
                            output_path + '/recon_shapenet_%08d.png' % t_i)
                    else:
                        z_shapenet = np.append(z_shapenet, z_codes, axis=0)
                        sm_shapenet = np.append(sm_shapenet, sm_codes, axis=0)
                        labels_shapenet = np.append(labels_shapenet,
                                                    test_label,
                                                    axis=0)
                accumulated_acc /= num_batches
                print("Accuracy of ShapeNet images= %.3f" % (accumulated_acc))

                fig = plt.figure()
                z_viz, V = pca(z_shapenet, dim_remain=2)
                ax = fig.add_subplot(121)
                # ax.set_aspect('equal')
                ax.scatter(z_viz[:, 0],
                           z_viz[:, 1],
                           c=labels_shapenet[:, 0],
                           alpha=0.4,
                           cmap='gist_rainbow')
                sm_viz, V = pca(sm_shapenet, dim_remain=2)
                ax = fig.add_subplot(122)
                # ax.set_aspect('equal')
                ax.scatter(sm_viz[:, 0],
                           sm_viz[:, 1],
                           c=labels_shapenet[:, 0],
                           alpha=0.4,
                           cmap='gist_rainbow')

                fig.savefig(output_path + '/z_feat_shapenet.png',
                            transparent=True)
                plt.clf()

                t_i += 1

            if step_i % save_step == 0:

                # Save the variables to disk.
                # Set global_step=batch_i here if several checkpoints
                # should be kept around
                saver.save(sess,
                           output_path + "/" + ckpt_name,
                           global_step=None,
                           write_meta_graph=False)
                if softmax:
                    acc = sess.run(ae['acc'],
                                   feed_dict={
                                       ae['x']: test_xs_img,
                                       ae['label']: test_xs_label[:, 0],
                                       ae['train']: False,
                                       ae['keep_prob']: 1.0,
                                       ae['corrupt_rec']: 0,
                                       ae['corrupt_cls']: 0
                                   })

                    print(
                        "epoch %d: VAE = %d, SM = %.3f, Acc = %.3f, R_Var = %.3f, Cpt_R = %.3f, Cpt_C = %.3f"
                        %
                        (epoch_i, cost_vae, cost_s, acc.tolist().count(True) /
                         acc.size, var_prob, corrupt_rec, corrupt_cls))

                    # Summary recording to Tensorboard
                    summary = sess.run(ae['merged'],
                                       feed_dict={
                                           ae['x']: batch_xs_img,
                                           ae['t']: batch_xs_obj,
                                           ae['label']: batch_xs_label[:, 0],
                                           ae['train']: False,
                                           ae['keep_prob']: keep_prob,
                                           ae['corrupt_rec']: corrupt_rec,
                                           ae['corrupt_cls']: corrupt_cls
                                       })

                    summary_i += 1
                    train_writer.add_summary(summary, summary_i)
                else:
                    print("VAE loss = %d" % cost_vae)

            if batch_i > (n_files / batch_size):
                batch_i = 0
                epoch_i += 1

    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # One of the threads has issued an exception.  So let's tell all the
        # threads to shut down.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
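
A hypothetical call to the trainer above. The CSV path and the shapes here are assumptions, patterned on the annotated CSV lists the function already reads:

train_vae(files="./list_annotated_train.csv",
          input_shape=[None, 64, 64, 3],
          output_shape=[None, 64, 64, 3],
          crop_shape=[64, 64],
          softmax=True,
          classifier='alexnet_v2')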
Example #2
def train_ds():

    init_lr_g = 1e-4  # learning rates
    init_lr_d = 1e-4

    n_latent = 100  # still need to dig into this idea of a latent variable

    n_epochs = 1000000
    batch_size = 200
    n_samples = 15

    # Image sizes, crop etc
    input_shape = [218, 178, 3]
    crop_shape = [64, 64, 3]
    crop_factor = 0.8

    from libs.dataset_utils import create_input_pipeline
    from libs.datasets import CELEB

    files = CELEB()

    # Feed the network batch by batch; images are cropped and resized on the fly
    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    # [None] + crop_shape: batch dimension (number of samples) + shape of the cropped images
    gan = GAN(input_shape=[None] + crop_shape,
              n_features=10,
              n_latent=n_latent,
              rgb=True,
              debug=False)

    # List all the variables
    # Discriminator
    vars_d = [
        v for v in tf.trainable_variables()
        if v.name.startswith('discriminator')
    ]
    print('Training discriminator variables:')
    for v in vars_d:
        print(v.name)

    # Generator
    vars_g = [
        v for v in tf.trainable_variables() if v.name.startswith('generator')
    ]
    print('Training generator variables:')
    for v in vars_g:
        print(v.name)

    #########

    zs = np.random.uniform(-1.0, 1.0, [4, n_latent]).astype(np.float32)
    zs = make_latent_manifold(zs, n_samples)

    # Even the learning rates adapt on the fly! They are passed
    # to opt_g & opt_d below, which use the Adam algorithm
    lr_g = tf.placeholder(tf.float32, shape=[], name='learning_rate_g')
    lr_d = tf.placeholder(tf.float32, shape=[], name='learning_rate_d')

    # L2 regularization (see the notes on regularization above),
    # applied to the discriminator and generator variables
    try:
        from tensorflow.contrib.layers import apply_regularization
        d_reg = apply_regularization(tf.contrib.layers.l2_regularizer(1e-6),
                                     vars_d)
        g_reg = apply_regularization(tf.contrib.layers.l2_regularizer(1e-6),
                                     vars_g)
    except ImportError:
        d_reg, g_reg = 0, 0

    # These two ops are run for the Generator & Discriminator
    # respectively through sess.run below
    # (both networks are trained alternately)
    opt_g = tf.train.AdamOptimizer(lr_g, name='Adam_g').minimize(
        gan['loss_G'] + g_reg, var_list=vars_g)
    opt_d = tf.train.AdamOptimizer(lr_d, name='Adam_d').minimize(
        gan['loss_D'] + d_reg, var_list=vars_d)

    #########

    sess = tf.Session()
    init_op = tf.global_variables_initializer()

    #########

    # More Tensorboard summaries
    saver = tf.train.Saver()
    sums = gan['sums']

    G_sum_op = tf.summary.merge([
        sums['G'], sums['loss_G'], sums['z'], sums['loss_D_fake'],
        sums['D_fake']
    ])
    D_sum_op = tf.summary.merge([
        sums['loss_D'], sums['loss_D_real'], sums['loss_D_fake'], sums['z'],
        sums['x'], sums['D_real'], sums['D_fake']
    ])

    writer = tf.summary.FileWriter("./logs", sess.graph)

    #########

    # Multithreaded input pipeline via queue runners
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(init_op)

    # g = tf.get_default_graph()
    # [print(op.name) for op in g.get_operations()]

    #########

    # Checkpoint savery
    if os.path.exists("gan.ckpt"):
        saver.restore(sess, "gan.ckpt")
        print("GAN model restored.")

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    step_i, t_i = 0, 0
    loss_d = 1
    loss_g = 1
    n_loss_d, total_loss_d = 1, 1
    n_loss_g, total_loss_g = 1, 1

    try:
        while not coord.should_stop():

            batch_xs = sess.run(batch)
            step_i += 1
            batch_zs = np.random.uniform(
                -1.0, 1.0, [batch_size, n_latent]).astype(np.float32)

            this_lr_g = min(1e-2, max(1e-6, init_lr_g * (loss_g / loss_d)**2))
            this_lr_d = min(1e-2, max(1e-6, init_lr_d * (loss_d / loss_g)**2))
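            # Adaptive learning rates: the squared loss ratio speeds up
            # whichever network currently has the higher loss, with both
            # rates clamped to [1e-6, 1e-2]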

            # this_lr_d *= ((1.0 - (step_i / 100000)) ** 2)
            # this_lr_g *= ((1.0 - (step_i / 100000)) ** 2)

            # On 2 out of 3 steps (or according to another random-based
            # criterion, cf. the commented-out if below), we train the
            # Discriminator, and on the remaining step the Generator instead

            # if np.random.random() > (loss_g / (loss_d + loss_g)):
            if step_i % 3 == 1:
                loss_d, _, sum_d = sess.run(
                    [gan['loss_D'], opt_d, D_sum_op],
                    feed_dict={
                        gan['x']: batch_xs,
                        gan['z']: batch_zs,
                        gan['train']: True,
                        lr_d: this_lr_d
                    })
                total_loss_d += loss_d
                n_loss_d += 1

                writer.add_summary(sum_d, step_i)  # Tensorboard

                print('%04d d* = lr: %0.08f, loss: %08.06f, \t' %
                      (step_i, this_lr_d, loss_d) +
                      'g  = lr: %0.08f, loss: %08.06f' % (this_lr_g, loss_g))

            else:

                loss_g, _, sum_g = sess.run([gan['loss_G'], opt_g, G_sum_op],
                                            feed_dict={
                                                gan['z']: batch_zs,
                                                gan['train']: True,
                                                lr_g: this_lr_g
                                            })
                total_loss_g += loss_g
                n_loss_g += 1

                writer.add_summary(sum_g, step_i)  # Tensorboard

                print('%04d d  = lr: %0.08f, loss: %08.06f, \t' %
                      (step_i, this_lr_d, loss_d) +
                      'g* = lr: %0.08f, loss: %08.06f' % (this_lr_g, loss_g))

            if step_i % 100 == 0:

                samples = sess.run(gan['G'],
                                   feed_dict={
                                       gan['z']: zs,
                                       gan['train']: False
                                   })

                # Create a wall of images of the latent space (what the
                # network learns)
                montage(
                    np.clip((samples + 1) * 127.5, 0, 255).astype(np.uint8),
                    'imgs/gan_%08d.png' % t_i)
                t_i += 1

                print('generator loss:', total_loss_g / n_loss_g)
                print('discriminator loss:', total_loss_d / n_loss_d)

                # Save variable to disk
                save_path = saver.save(sess,
                                       "./gan.ckpt",
                                       global_step=step_i,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # One thread issued an exception, let them all shut down
        coord.request_stop()

    # Wait for them to finish
    coord.join(threads)

    # Clean up
    sess.close()
Example #3

def train_vae(files,
              input_shape,
              learning_rate=0.0001,
              batch_size=100,
              n_epochs=50,
              n_examples=10,
              crop_shape=[64, 64, 3],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.
    Supply a list of file paths to images, and this will do everything else.
    Parameters
    ----------
    files : list of strings
        List of paths to images.
    input_shape : list
        Must define what the input image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs.
    n_examples : int, optional
        Number of examples to use while demonstrating the current training
        iteration's reconstruction.  Creates a square montage, so make
        sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ... 100.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not.
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often to save training images showing the manifold and
        reconstruction.
    save_step : int, optional
        How often to save checkpoints.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt'
    """
    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    ae = VAE(input_shape=[None] + crop_shape,
             convolutional=convolutional,
             variational=variational,
             n_filters=n_filters,
             n_hidden=n_hidden,
             n_code=n_code,
             dropout=dropout,
             filter_sizes=filter_sizes,
             activation=activation)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['cost'])

    # We create a session to use the graph
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        saver.restore(sess, ckpt_name)

    # Fit all training data
    t_i = 0
    batch_i = 0
    epoch_i = 0
    cost = 0
    n_files = len(files)
    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            train_cost = sess.run([ae['cost'], optimizer],
                                  feed_dict={
                                      ae['x']: batch_xs,
                                      ae['train']: True,
                                      ae['keep_prob']: keep_prob
                                  })[0]
            print(batch_i, train_cost)
            cost += train_cost
            if batch_i % n_files == 0:
                print('epoch:', epoch_i)
                print('average cost:', cost / batch_i)
                cost = 0
                batch_i = 0
                epoch_i += 1

            if batch_i % img_step == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['z']: zs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                utils.montage(recon.reshape([-1] + crop_shape),
                              'manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['x']: test_xs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                print('reconstruction (min, max, mean):', recon.min(),
                      recon.max(), recon.mean())
                utils.montage(recon.reshape([-1] + crop_shape),
                              'reconstruction_%08d.png' % t_i)
                t_i += 1

            if batch_i % save_step == 0:
                # Save the variables to disk.
                saver.save(sess,
                           "./" + ckpt_name,
                           global_step=batch_i,
                           write_meta_graph=False)
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # One of the threads has issued an exception.  So let's tell all the
        # threads to shut down.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
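
A minimal invocation sketch for this simpler trainer. The CELEB file list and its 218x178x3 frames are borrowed from the GAN examples below and are assumptions here:

from libs.datasets import CELEB

files = CELEB()
train_vae(files=files,
          input_shape=[218, 178, 3],
          n_epochs=10)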
Example #4
def train_vaegan(files,
                 learning_rate=0.00001,
                 batch_size=64,
                 n_epochs=250,
                 n_examples=10,
                 input_shape=[218, 178, 3],
                 crop_shape=[64, 64, 3],
                 crop_factor=0.8,
                 n_filters=[100, 100, 100, 100],
                 n_hidden=None,
                 n_code=128,
                 convolutional=True,
                 variational=True,
                 filter_sizes=[3, 3, 3, 3],
                 activation=tf.nn.elu,
                 ckpt_name="vaegan.ckpt"):
    """Summary

    Parameters
    ----------
    files : TYPE
        Description
    learning_rate : float, optional
        Description
    batch_size : int, optional
        Description
    n_epochs : int, optional
        Description
    n_examples : int, optional
        Description
    input_shape : list, optional
        Description
    crop_shape : list, optional
        Description
    crop_factor : float, optional
        Description
    n_filters : list, optional
        Description
    n_hidden : int, optional
        Description
    n_code : int, optional
        Description
    convolutional : bool, optional
        Description
    variational : bool, optional
        Description
    filter_sizes : list, optional
        Description
    activation : TYPE, optional
        Description
    ckpt_name : str, optional
        Description

    Returns
    -------
    name : TYPE
        Description
    """

    ae = VAEGAN(input_shape=[None] + crop_shape,
                convolutional=convolutional,
                variational=variational,
                n_filters=n_filters,
                n_hidden=n_hidden,
                n_code=n_code,
                filter_sizes=filter_sizes,
                activation=activation)

    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    zs = np.random.randn(4, n_code).astype(np.float32)
    zs = make_latent_manifold(zs, n_examples)

    opt_enc = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_enc'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('encoder')
        ])

    opt_gen = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_gen'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('generator')
        ])

    opt_dis = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_dis'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('discriminator')
        ])

    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    tf.get_default_graph().finalize()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        saver.restore(sess, ckpt_name)
        print("VAE model restored.")

    t_i = 0
    batch_i = 0
    epoch_i = 0

    # 0.693 ~= ln(2): the loss of a discriminator that outputs 0.5 for
    # everything, i.e. the balance point of the adversarial game.  Updates
    # are skipped when the real/fake costs stray outside equilibrium +/- margin.
    equilibrium = 0.693
    margin = 0.4

    n_files = len(files)
    test_xs = sess.run(batch) / 255.0
    montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            if batch_i % (n_files // batch_size) == 0:
                batch_i = 0
                epoch_i += 1
                print('---------- EPOCH:', epoch_i)

            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            batch_zs = np.random.randn(batch_size, n_code).astype(np.float32)
            real_cost, fake_cost, _ = sess.run(
                [ae['loss_real'], ae['loss_fake'], opt_enc],
                feed_dict={
                    ae['x']: batch_xs,
                    ae['gamma']: 0.5
                })
            real_cost = -np.mean(real_cost)
            fake_cost = -np.mean(fake_cost)
            print('real:', real_cost, '/ fake:', fake_cost)

            gen_update = True
            dis_update = True

            if real_cost > (equilibrium + margin) or \
                            fake_cost > (equilibrium + margin):
                gen_update = False

            if real_cost < (equilibrium - margin) or \
                            fake_cost < (equilibrium - margin):
                dis_update = False

            if not (gen_update or dis_update):
                gen_update = True
                dis_update = True

            if gen_update:
                sess.run(opt_gen,
                         feed_dict={
                             ae['x']: batch_xs,
                             ae['z_samp']: batch_zs,
                             ae['gamma']: 0.5
                         })
            if dis_update:
                sess.run(opt_dis,
                         feed_dict={
                             ae['x']: batch_xs,
                             ae['z_samp']: batch_zs,
                             ae['gamma']: 0.5
                         })

            if batch_i % 50 == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(ae['x_tilde'], feed_dict={ae['z']: zs})
                print('recon:', recon.min(), recon.max())
                recon = np.clip(recon / recon.max(), 0, 1)
                montage(recon.reshape([-1] + crop_shape),
                        'imgs/manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['x_tilde'], feed_dict={ae['x']: test_xs})
                print('recon:', recon.min(), recon.max())
                recon = np.clip(recon / recon.max(), 0, 1)
                montage(recon.reshape([-1] + crop_shape),
                        'imgs/reconstruction_%08d.png' % t_i)
                t_i += 1

            if batch_i % 100 == 0:
                # Save the variables to disk.
                save_path = saver.save(sess,
                                       ckpt_name,
                                       global_step=batch_i,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # One of the threads has issued an exception.  So let's tell all the
        # threads to shut down.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
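
As above, a hypothetical invocation using the CelebA helper from the GAN examples (an assumption, not part of this snippet):

from libs.datasets import CELEB

files = CELEB()
train_vaegan(files=files, n_epochs=100)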
Example #5

def build_net(graph, training=True, validation=False):
    """Helper for creating a 2D convolution model.

    Parameters
    ----------
    graph : tf.Graph
        default graph to build model
    training : bool, optional
        if true, use training dataset
    validation : bool, optional
        if true, use validation dataset

    Returns
    -------
    batch : list
        list of images
    batch_labels : list
        list of labels for images
    batch_image_paths : list
        list of paths to image files
    init : tf.group
        initializer functions
    x : tf.Tensor (tf.float32)
        input image placeholder
    y : tf.Tensor (tf.int32)
        labels placeholder
    phase_train : tf.Tensor (tf.bool)
        whether the network is training
    keep_prob : tf.Tensor (tf.float32)
        keep probability for conv2d layers
    keep_prob_fc1 : tf.Tensor (tf.float32)
        keep probability for the fully connected layer
    learning_rate : tf.Tensor (tf.float32)
        learning rate placeholder
    h : tf.Tensor
        output of the sigmoid
    loss : tf.Tensor
        mean sigmoid cross-entropy loss
    optimizer : tf.Operation
        training op that minimizes the loss
    saver : tf.train.Saver
        saver for writing checkpoints

    """

    with graph.as_default():
        x = tf.placeholder(tf.float32, [None] + resize_shape, 'x')
        # TODO: use len(labels_map)
        y = tf.placeholder(tf.int32, [None, 17], 'y')
        phase_train = tf.placeholder(tf.bool, name='phase_train')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        keep_prob_fc1 = tf.placeholder(tf.float32, name='keep_prob_fc1')
        learning_rate = tf.placeholder(tf.float32, name='learning_rate')

        # Create Input Pipeline for Train, Validation and Test Sets
        if training:
            batch, batch_labels, batch_image_paths = dsutils.create_input_pipeline(
                image_paths=image_paths[:index_split_train_val],
                labels=labels_onehot_list[:index_split_train_val],
                batch_size=batch_size,
                n_epochs=n_epochs,
                shape=input_shape,
                crop_factor=resize_factor,
                training=training,
                randomize=True)
        elif validation:
            batch, batch_labels, batch_image_paths = dsutils.create_input_pipeline(
                image_paths=image_paths[index_split_train_val:],
                labels=labels_onehot_list[index_split_train_val:],
                batch_size=batch_size,
                # only one epoch for test output
                n_epochs=1,
                shape=input_shape,
                crop_factor=resize_factor,
                training=training)
        else:
            batch, batch_labels, batch_image_paths = dsutils.create_input_pipeline(
                image_paths=test_image_paths,
                labels=test_onehot_list,
                batch_size=batch_size,
                # only one epoch for test output
                n_epochs=1,
                shape=input_shape,
                crop_factor=resize_factor,
                training=training)

        Ws = []

        current_input = x

        for layer_i, n_output in enumerate(n_filters):
            with tf.variable_scope('layer{}'.format(layer_i)):
                # 2D Convolutional Layer with batch normalization and relu
                h, W = utils.conv2d(x=current_input,
                                    n_output=n_output,
                                    k_h=filter_sizes[layer_i],
                                    k_w=filter_sizes[layer_i])
                h = tf.layers.batch_normalization(h, training=phase_train)
                h = tf.nn.relu(h, 'relu' + str(layer_i))

                # Apply Max Pooling Every 2nd Layer
                if layer_i % 2 == 0:
                    h = tf.nn.max_pool(value=h,
                                       ksize=[1, 2, 2, 1],
                                       strides=[1, 2, 2, 1],
                                       padding='SAME')

                # Apply Dropout Every 2nd Layer
                if layer_i % 2 == 0:
                    h = tf.nn.dropout(h, keep_prob)

                Ws.append(W)
                current_input = h

        h = utils.linear(current_input, fc_size, name='fc_t')[0]
        h = tf.layers.batch_normalization(h, training=phase_train)
        h = tf.nn.relu(h, name='fc_t/relu')
        h = tf.nn.dropout(h, keep_prob_fc1)

        logits = utils.linear(h, len(labels_map), name='fc_t2')[0]
        h = tf.nn.sigmoid(logits, 'fc_t2')

        # must be the same type as logits
        y_float = tf.cast(y, tf.float32)

        cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                labels=y_float)
        loss = tf.reduce_mean(cross_entropy)

        if training:
            # update moving_mean and moving_variance so it will be available at inference time
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=learning_rate).minimize(loss)
        else:
            optimizer = tf.train.AdamOptimizer(
                learning_rate=learning_rate).minimize(loss)

        saver = tf.train.Saver()
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        return batch, batch_labels, batch_image_paths, init, x, y, phase_train, keep_prob, keep_prob_fc1, learning_rate, h, loss, optimizer, saver
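
A sketch of how the values returned by build_net might be consumed for one training step. It is hypothetical and relies on the same module-level globals (batch_size, image_paths, labels) that build_net itself references:

graph = tf.Graph()
(batch, batch_labels, batch_paths, init, x, y, phase_train,
 keep_prob, keep_prob_fc1, learning_rate, h, loss,
 optimizer, saver) = build_net(graph, training=True)

with tf.Session(graph=graph) as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # One optimization step on one batch from the input pipeline
    batch_xs, batch_ys = sess.run([batch, batch_labels])
    _, batch_loss = sess.run(
        [optimizer, loss],
        feed_dict={x: batch_xs,
                   y: batch_ys,
                   phase_train: True,
                   keep_prob: 0.5,
                   keep_prob_fc1: 0.5,
                   learning_rate: 1e-4})
    coord.request_stop()
    coord.join(threads)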
Example #6
def train_ds():
    """Summary
    
    Returns
    -------
    name : TYPE
        Description
    """
    init_lr_g = 1e-4
    init_lr_d = 1e-4
    n_latent = 100
    n_epochs = 1000000
    batch_size = 200
    n_samples = 15
    input_shape = [218, 178, 3]
    crop_shape = [64, 64, 3]
    crop_factor = 0.8

    from libs.dataset_utils import create_input_pipeline
    from libs.datasets import CELEB

    files = CELEB()
    batch = create_input_pipeline(
        files=files,
        batch_size=batch_size,
        n_epochs=n_epochs,
        crop_shape=crop_shape,
        crop_factor=crop_factor,
        shape=input_shape)

    gan = GAN(input_shape=[None] + crop_shape, n_features=10,
              n_latent=n_latent, rgb=True, debug=False)

    vars_d = [v for v in tf.trainable_variables()
              if v.name.startswith('discriminator')]
    print('Training discriminator variables:')
    for v in vars_d:
        print(v.name)

    vars_g = [v for v in tf.trainable_variables()
              if v.name.startswith('generator')]
    print('Training generator variables:')
    for v in vars_g:
        print(v.name)
    zs = np.random.uniform(
        -1.0, 1.0, [4, n_latent]).astype(np.float32)
    zs = make_latent_manifold(zs, n_samples)

    lr_g = tf.placeholder(tf.float32, shape=[], name='learning_rate_g')
    lr_d = tf.placeholder(tf.float32, shape=[], name='learning_rate_d')

    try:
        from tensorflow.contrib.layers import apply_regularization
        d_reg = apply_regularization(
            tf.contrib.layers.l2_regularizer(1e-6), vars_d)
        g_reg = apply_regularization(
            tf.contrib.layers.l2_regularizer(1e-6), vars_g)
    except ImportError:
        d_reg, g_reg = 0, 0

    opt_g = tf.train.AdamOptimizer(lr_g, name='Adam_g').minimize(
        gan['loss_G'] + g_reg, var_list=vars_g)
    opt_d = tf.train.AdamOptimizer(lr_d, name='Adam_d').minimize(
        gan['loss_D'] + d_reg, var_list=vars_d)

    # %%
    # We create a session to use the graph
    # Hide the GPU so this variant runs entirely on the CPU
    config = tf.ConfigProto(device_count={'GPU': 0})
    sess = tf.Session(config=config)
    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver()
    sums = gan['sums']
    G_sum_op = tf.summary.merge([
        sums['G'], sums['loss_G'], sums['z'],
        sums['loss_D_fake'], sums['D_fake']])
    D_sum_op = tf.summary.merge([
        sums['loss_D'], sums['loss_D_real'], sums['loss_D_fake'],
        sums['z'], sums['x'], sums['D_real'], sums['D_fake']])
    writer = tf.summary.FileWriter("./logs", sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(init_op)
    # g = tf.get_default_graph()
    # [print(op.name) for op in g.get_operations()]
    
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    step_i, t_i = 0, 0
    loss_d = 1
    loss_g = 1
    n_loss_d, total_loss_d = 1, 1
    n_loss_g, total_loss_g = 1, 1
    
    ckpt = glob('gan.ckpt*')
    if ckpt:
        ckpt = sorted(ckpt)[-1]
        saver.restore(sess, ckpt)
        step_i = int(ckpt.split('-')[-1])
        t_i = int(sorted(glob('imgs/*.png'))[-1].split('_')[-1][:-4]) + 1
        print("GAN model restored.")

    try:
        while not coord.should_stop():
            batch_xs = sess.run(batch)
            step_i += 1
            batch_zs = np.random.uniform(
                -1.0, 1.0, [batch_size, n_latent]).astype(np.float32)

            this_lr_g = min(1e-2, max(1e-6, init_lr_g * (loss_g / loss_d)**2))
            this_lr_d = min(1e-2, max(1e-6, init_lr_d * (loss_d / loss_g)**2))
            # this_lr_d *= ((1.0 - (step_i / 100000)) ** 2)
            # this_lr_g *= ((1.0 - (step_i / 100000)) ** 2)

            # if np.random.random() > (loss_g / (loss_d + loss_g)):
            if step_i % 3 == 1:
                loss_d, _, sum_d = sess.run([gan['loss_D'], opt_d, D_sum_op],
                                            feed_dict={gan['x']: batch_xs,
                                                       gan['z']: batch_zs,
                                                       gan['train']: True,
                                                       lr_d: this_lr_d})
                total_loss_d += loss_d
                n_loss_d += 1
                writer.add_summary(sum_d, step_i)
                print('%04d d* = lr: %0.08f, loss: %08.06f, \t' %
                      (step_i, this_lr_d, loss_d) +
                      'g  = lr: %0.08f, loss: %08.06f' % (this_lr_g, loss_g))
            else:
                loss_g, _, sum_g = sess.run([gan['loss_G'], opt_g, G_sum_op],
                                            feed_dict={gan['z']: batch_zs,
                                                       gan['train']: True,
                                                       lr_g: this_lr_g})
                total_loss_g += loss_g
                n_loss_g += 1
                writer.add_summary(sum_g, step_i)
                print('%04d d  = lr: %0.08f, loss: %08.06f, \t' %
                      (step_i, this_lr_d, loss_d) +
                      'g* = lr: %0.08f, loss: %08.06f' % (this_lr_g, loss_g))

            if step_i % 100 == 0:
                samples = sess.run(gan['G'], feed_dict={
                    gan['z']: zs,
                    gan['train']: False})
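                # Generator samples are presumably tanh-activated, i.e.
                # in [-1, 1]; map them to [0, 255] for display.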
                montage(np.clip((samples + 1) * 127.5, 0, 255).astype(np.uint8),
                        'imgs/gan_%08d.png' % t_i)
                t_i += 1

                print('generator loss:', total_loss_g / n_loss_g)
                print('discriminator loss:', total_loss_d / n_loss_d)

                # Save the variables to disk.
                save_path = saver.save(sess, "./gan.ckpt",
                                       global_step=step_i,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # One of the threads has issued an exception.  So let's tell all the
        # threads to shutdown.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
Example #7
def train_ds():
    """Summary
    
    Returns
    -------
    name : TYPE
        Description
    """
    init_lr_g = 1e-4
    init_lr_d = 1e-4
    n_latent = 100
    n_epochs = 1000000
    batch_size = 200
    n_samples = 15
    input_shape = [218, 178, 3]
    crop_shape = [64, 64, 3]
    crop_factor = 0.8

    from libs.dataset_utils import create_input_pipeline
    from libs.datasets import CELEB
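    # CELEB presumably returns the list of file paths to the aligned
    # CelebA face images (218 x 178 RGB, matching input_shape above).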

    files = CELEB()
    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    gan = GAN(input_shape=[None] + crop_shape,
              n_features=10,
              n_latent=n_latent,
              rgb=True,
              debug=False)

    vars_d = [
        v for v in tf.trainable_variables()
        if v.name.startswith('discriminator')
    ]
    print('Training discriminator variables:')
    for v in vars_d:
        print(v.name)

    vars_g = [
        v for v in tf.trainable_variables() if v.name.startswith('generator')
    ]
    print('Training generator variables:')
    for v in vars_g:
        print(v.name)
    zs = np.random.uniform(-1.0, 1.0, [4, n_latent]).astype(np.float32)
    zs = make_latent_manifold(zs, n_samples)

    lr_g = tf.placeholder(tf.float32, shape=[], name='learning_rate_g')
    lr_d = tf.placeholder(tf.float32, shape=[], name='learning_rate_d')

    try:
        from tensorflow.contrib.layers import apply_regularization
        d_reg = apply_regularization(tf.contrib.layers.l2_regularizer(1e-6),
                                     vars_d)
        g_reg = apply_regularization(tf.contrib.layers.l2_regularizer(1e-6),
                                     vars_g)
    except ImportError:
        d_reg, g_reg = 0, 0

    opt_g = tf.train.AdamOptimizer(lr_g, name='Adam_g').minimize(
        gan['loss_G'] + g_reg, var_list=vars_g)
    opt_d = tf.train.AdamOptimizer(lr_d, name='Adam_d').minimize(
        gan['loss_D'] + d_reg, var_list=vars_d)

    # %%
    # We create a CPU-only session (device_count={'GPU': 0} disables the GPU)
    config = tf.ConfigProto(device_count={'GPU': 0})
    sess = tf.Session(config=config)
    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver()
    sums = gan['sums']
    G_sum_op = tf.summary.merge([
        sums['G'], sums['loss_G'], sums['z'], sums['loss_D_fake'],
        sums['D_fake']
    ])
    D_sum_op = tf.summary.merge([
        sums['loss_D'], sums['loss_D_real'], sums['loss_D_fake'], sums['z'],
        sums['x'], sums['D_real'], sums['D_fake']
    ])
    writer = tf.summary.FileWriter("./logs", sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(init_op)
    # g = tf.get_default_graph()
    # [print(op.name) for op in g.get_operations()]

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    step_i, t_i = 0, 0
    loss_d = 1
    loss_g = 1
    n_loss_d, total_loss_d = 1, 1
    n_loss_g, total_loss_g = 1, 1

    ckpt = glob('gan.ckpt*')
    if ckpt:
        ckpt = sorted(ckpt)[-1]
        saver.restore(sess, ckpt)
        step_i = int(ckpt.split('-')[-1])
        t_i = int(sorted(glob('imgs/*.png'))[-1].split('_')[-1][:-4]) + 1
        print("GAN model restored.")

    try:
        while not coord.should_stop():
            batch_xs = sess.run(batch)
            step_i += 1
            batch_zs = np.random.uniform(
                -1.0, 1.0, [batch_size, n_latent]).astype(np.float32)

            this_lr_g = min(1e-2, max(1e-6, init_lr_g * (loss_g / loss_d)**2))
            this_lr_d = min(1e-2, max(1e-6, init_lr_d * (loss_d / loss_g)**2))
            # this_lr_d *= ((1.0 - (step_i / 100000)) ** 2)
            # this_lr_g *= ((1.0 - (step_i / 100000)) ** 2)

            # if np.random.random() > (loss_g / (loss_d + loss_g)):
            if step_i % 3 == 1:
                loss_d, _, sum_d = sess.run(
                    [gan['loss_D'], opt_d, D_sum_op],
                    feed_dict={
                        gan['x']: batch_xs,
                        gan['z']: batch_zs,
                        gan['train']: True,
                        lr_d: this_lr_d
                    })
                total_loss_d += loss_d
                n_loss_d += 1
                writer.add_summary(sum_d, step_i)
                print('%04d d* = lr: %0.08f, loss: %08.06f, \t' %
                      (step_i, this_lr_d, loss_d) +
                      'g  = lr: %0.08f, loss: %08.06f' % (this_lr_g, loss_g))
            else:
                loss_g, _, sum_g = sess.run([gan['loss_G'], opt_g, G_sum_op],
                                            feed_dict={
                                                gan['z']: batch_zs,
                                                gan['train']: True,
                                                lr_g: this_lr_g
                                            })
                total_loss_g += loss_g
                n_loss_g += 1
                writer.add_summary(sum_g, step_i)
                print('%04d d  = lr: %0.08f, loss: %08.06f, \t' %
                      (step_i, this_lr_d, loss_d) +
                      'g* = lr: %0.08f, loss: %08.06f' % (this_lr_g, loss_g))

            if step_i % 100 == 0:
                samples = sess.run(gan['G'],
                                   feed_dict={
                                       gan['z']: zs,
                                       gan['train']: False
                                   })
                montage(
                    np.clip((samples + 1) * 127.5, 0, 255).astype(np.uint8),
                    'imgs/gan_%08d.png' % t_i)
                t_i += 1

                print('generator loss:', total_loss_g / n_loss_g)
                print('discriminator loss:', total_loss_d / n_loss_d)

                # Save the variables to disk.
                save_path = saver.save(sess,
                                       "./gan.ckpt",
                                       global_step=step_i,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # One of the threads has issued an exception.  So let's tell all the
        # threads to shutdown.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
Example #8
def train_vae(files,
              input_shape,
              learning_rate=0.0001,
              batch_size=100,
              n_epochs=50,
              n_examples=10,
              crop_shape=[64, 64, 3],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a list of file paths to images, and this will do everything else.

    Parameters
    ----------
    files : list of strings
        List of paths to images.
    input_shape : list
        Must define what the input image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs.
    n_examples : int, optional
        Number of examples to use while demonstrating the current training
        iteration's reconstruction.  Creates a square montage, so make
        sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ... 100.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often to save training images showing the manifold and
        reconstruction.
    save_step : int, optional
        How often to save checkpoints.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt'
    """
    batch = create_input_pipeline(
        files=files,
        batch_size=batch_size,
        n_epochs=n_epochs,
        crop_shape=crop_shape,
        crop_factor=crop_factor,
        shape=input_shape)

    ae = VAE(input_shape=[None] + crop_shape,
             convolutional=convolutional,
             variational=variational,
             n_filters=n_filters,
             n_hidden=n_hidden,
             n_code=n_code,
             dropout=dropout,
             filter_sizes=filter_sizes,
             activation=activation)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    zs = np.random.uniform(
        -1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)
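    # A hedged sketch of what utils.make_latent_manifold plausibly does
    # (the real helper is defined in the accompanying utils module and
    # not shown here): bilinearly interpolate the four corner codes over
    # an n_examples x n_examples grid.
    #
    #   def make_latent_manifold(corners, n):
    #       tl, tr, bl, br = corners
    #       grid = []
    #       for u in np.linspace(0.0, 1.0, n):
    #           left = tl * (1 - u) + bl * u
    #           right = tr * (1 - u) + br * u
    #           for v in np.linspace(0.0, 1.0, n):
    #               grid.append(left * (1 - v) + right * v)
    #       return np.array(grid, dtype=np.float32)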

    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(ae['cost'])

    # We create a session to use the graph
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        saver.restore(sess, ckpt_name)

    # Fit all training data
    t_i = 0
    batch_i = 0
    epoch_i = 0
    cost = 0
    n_files = len(files)
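    # The input pipeline yields pixel values in [0, 255]; scale to
    # [0, 1], the range the autoencoder reconstructs.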
    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            train_cost = sess.run([ae['cost'], optimizer], feed_dict={
                ae['x']: batch_xs, ae['train']: True,
                ae['keep_prob']: keep_prob})[0]
            print(batch_i, train_cost)
            cost += train_cost
            if batch_i % n_files == 0:
                print('epoch:', epoch_i)
                print('average cost:', cost / batch_i)
                cost = 0
                batch_i = 0
                epoch_i += 1

            if batch_i % img_step == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(
                    ae['y'], feed_dict={
                        ae['z']: zs,
                        ae['train']: False,
                        ae['keep_prob']: 1.0})
                utils.montage(recon.reshape([-1] + crop_shape),
                              'manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(
                    ae['y'], feed_dict={ae['x']: test_xs,
                                        ae['train']: False,
                                        ae['keep_prob']: 1.0})
                print('reconstruction (min, max, mean):',
                      recon.min(), recon.max(), recon.mean())
                utils.montage(recon.reshape([-1] + crop_shape),
                              'reconstruction_%08d.png' % t_i)
                t_i += 1

            if batch_i % save_step == 0:
                # Save the variables to disk.
                saver.save(sess, "./" + ckpt_name,
                           global_step=batch_i,
                           write_meta_graph=False)
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # One of the threads has issued an exception.  So let's tell all the
        # threads to shutdown.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
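# Usage sketch (hypothetical values; CELEB as in the GAN example above):
#
#   files = CELEB()
#   train_vae(files=files, input_shape=[218, 178, 3],
#             n_epochs=10, crop_shape=[64, 64, 3])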
Example #9
def train_vaegan(files,
                 learning_rate=0.00001,
                 batch_size=64,
                 n_epochs=250,
                 n_examples=10,
                 input_shape=[218, 178, 3],
                 crop_shape=[64, 64, 3],
                 crop_factor=0.8,
                 n_filters=[100, 100, 100, 100],
                 n_hidden=None,
                 n_code=128,
                 convolutional=True,
                 variational=True,
                 filter_sizes=[3, 3, 3, 3],
                 activation=tf.nn.elu,
                 ckpt_name="vaegan.ckpt"):
    """Summary

    Parameters
    ----------
    files : TYPE
        Description
    learning_rate : float, optional
        Description
    batch_size : int, optional
        Description
    n_epochs : int, optional
        Description
    n_examples : int, optional
        Description
    input_shape : list, optional
        Description
    crop_shape : list, optional
        Description
    crop_factor : float, optional
        Description
    n_filters : list, optional
        Description
    n_hidden : int, optional
        Description
    n_code : int, optional
        Description
    convolutional : bool, optional
        Description
    variational : bool, optional
        Description
    filter_sizes : list, optional
        Description
    activation : TYPE, optional
        Description
    ckpt_name : str, optional
        Description

    Returns
    -------
    name : TYPE
        Description
    """

    ae = VAEGAN(input_shape=[None] + crop_shape,
                convolutional=convolutional,
                variational=variational,
                n_filters=n_filters,
                n_hidden=n_hidden,
                n_code=n_code,
                filter_sizes=filter_sizes,
                activation=activation)

    batch = create_input_pipeline(
        files=files,
        batch_size=batch_size,
        n_epochs=n_epochs,
        crop_shape=crop_shape,
        crop_factor=crop_factor,
        shape=input_shape)

    zs = np.random.randn(4, n_code).astype(np.float32)
    zs = make_latent_manifold(zs, n_examples)
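    # Unlike the GAN examples above, which sample z uniformly, codes
    # here are drawn from a standard normal, matching the Gaussian prior
    # the variational encoder is regularized toward.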

    opt_enc = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(
        ae['loss_enc'],
        var_list=[var_i for var_i in tf.trainable_variables()
                  if var_i.name.startswith('encoder')])

    opt_gen = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(
        ae['loss_gen'],
        var_list=[var_i for var_i in tf.trainable_variables()
                  if var_i.name.startswith('generator')])

    opt_dis = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(
        ae['loss_dis'],
        var_list=[var_i for var_i in tf.trainable_variables()
                  if var_i.name.startswith('discriminator')])

    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    tf.get_default_graph().finalize()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        saver.restore(sess, ckpt_name)
        print("VAE model restored.")

    t_i = 0
    batch_i = 0
    epoch_i = 0

    equilibrium = 0.693
    margin = 0.4
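    # 0.693 is ln(2), the binary cross-entropy of a discriminator that
    # outputs 0.5, i.e. one that cannot tell real from fake.  The gating
    # below freezes the generator while either cost is more than
    # `margin` above this equilibrium (discriminator too weak), and
    # freezes the discriminator while either cost is more than `margin`
    # below it (discriminator too strong).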

    n_files = len(files)
    test_xs = sess.run(batch) / 255.0
    montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            if batch_i % (n_files // batch_size) == 0:
                batch_i = 0
                epoch_i += 1
                print('---------- EPOCH:', epoch_i)

            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            batch_zs = np.random.randn(batch_size, n_code).astype(np.float32)
            real_cost, fake_cost, _ = sess.run([
                ae['loss_real'], ae['loss_fake'], opt_enc],
                feed_dict={
                    ae['x']: batch_xs,
                    ae['gamma']: 0.5})
            real_cost = -np.mean(real_cost)
            fake_cost = -np.mean(fake_cost)
            print('real:', real_cost, '/ fake:', fake_cost)

            gen_update = True
            dis_update = True

            if real_cost > (equilibrium + margin) or \
               fake_cost > (equilibrium + margin):
                gen_update = False

            if real_cost < (equilibrium - margin) or \
               fake_cost < (equilibrium - margin):
                dis_update = False

            if not (gen_update or dis_update):
                gen_update = True
                dis_update = True

            if gen_update:
                sess.run(opt_gen, feed_dict={
                    ae['x']: batch_xs,
                    ae['z_samp']: batch_zs,
                    ae['gamma']: 0.5})
            if dis_update:
                sess.run(opt_dis, feed_dict={
                    ae['x']: batch_xs,
                    ae['z_samp']: batch_zs,
                    ae['gamma']: 0.5})

            if batch_i % 50 == 0:

                # Plot example reconstructions from latent layer
                recon = sess.run(
                    ae['x_tilde'], feed_dict={
                        ae['z']: zs})
                print('recon:', recon.min(), recon.max())
                recon = np.clip(recon / recon.max(), 0, 1)
                montage(recon.reshape([-1] + crop_shape),
                        'imgs/manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(
                    ae['x_tilde'], feed_dict={
                        ae['x']: test_xs})
                print('recon:', recon.min(), recon.max())
                recon = np.clip(recon / recon.max(), 0, 1)
                montage(recon.reshape([-1] + crop_shape),
                        'imgs/reconstruction_%08d.png' % t_i)
                t_i += 1

            if batch_i % 100 == 0:
                # Save the variables to disk.
                save_path = saver.save(sess, ckpt_name,
                                       global_step=batch_i,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # One of the threads has issued an exception.  So let's tell all the
        # threads to shutdown.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()