def test_mnist():
    """Train a variational autoencoder on MNIST.

    Trains a fully-connected VAE with a 2-D latent code on MNIST. The first
    ``n_examples`` test images are written once to ``test_xs.png``. Every 10
    training batches (counted across epochs), two montages are written:

    * ``manifold_%08d.png`` -- decodings of a fixed set of latent points,
      visualising the innermost (latent) layer of the encoder;
    * ``reconstruction_%08d.png`` -- reconstructions of the test images.

    After each epoch the mean training and validation cost are printed.
    """
    # load MNIST
    n_code = 2
    mnist = MNIST(split=[0.8, 0.1, 0.1])
    ae = VAE(input_shape=[None, 784],
             n_filters=[512, 256],
             n_hidden=64,
             n_code=n_code,
             activation=tf.nn.sigmoid,
             convolutional=False,
             variational=True)

    n_examples = 100
    # Fixed latent points built from 4 random anchor codes by
    # utils.make_latent_manifold; decoded periodically to show the latent space.
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    learning_rate = 0.02
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['cost'])

    # Use the session as a context manager so it is always closed (the
    # original leaked it).
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Fit all training data
        t_i = 0      # index of the next image files to write
        batch_i = 0  # global batch counter across all epochs
        batch_size = 200
        n_epochs = 10
        test_xs = mnist.test.images[:n_examples]
        utils.montage(test_xs.reshape((-1, 28, 28)), 'test_xs.png')
        for epoch_i in range(n_epochs):
            train_i = 0
            train_cost = 0
            for batch_xs, _ in mnist.train.next_batch(batch_size):
                train_cost += sess.run([ae['cost'], optimizer],
                                       feed_dict={
                                           ae['x']: batch_xs,
                                           ae['train']: True,
                                           ae['keep_prob']: 1.0
                                       })[0]
                train_i += 1
                if batch_i % 10 == 0:
                    # Plot example reconstructions from latent layer
                    recon = sess.run(ae['y'],
                                     feed_dict={
                                         ae['z']: zs,
                                         ae['train']: False,
                                         ae['keep_prob']: 1.0
                                     })
                    utils.montage(recon.reshape((-1, 28, 28)),
                                  'manifold_%08d.png' % t_i)
                    # Plot example reconstructions
                    recon = sess.run(ae['y'],
                                     feed_dict={
                                         ae['x']: test_xs,
                                         ae['train']: False,
                                         ae['keep_prob']: 1.0
                                     })
                    utils.montage(recon.reshape((-1, 28, 28)),
                                  'reconstruction_%08d.png' % t_i)
                    t_i += 1
                batch_i += 1

            valid_i = 0
            valid_cost = 0
            for batch_xs, _ in mnist.valid.next_batch(batch_size):
                valid_cost += sess.run([ae['cost']],
                                       feed_dict={
                                           ae['x']: batch_xs,
                                           ae['train']: False,
                                           ae['keep_prob']: 1.0
                                       })[0]
                valid_i += 1
            print('train:', train_cost / train_i, 'valid:',
                  valid_cost / valid_i)
# Example #2
# 0
def _count_csv_rows(path):
    """Return the number of CSV records in the file at ``path``."""
    with open(path, "r") as f:
        return sum(1 for _ in csv.reader(f, delimiter=","))


def _plot_code_pca(z_codes, sm_codes, labels, out_file):
    """Save a two-panel 2-D PCA scatter plot of latent and prediction codes.

    The left panel shows ``pca(z_codes, dim_remain=2)``, the right panel
    ``pca(sm_codes, dim_remain=2)``; points are coloured by ``labels[:, 0]``.
    """
    fig = plt.figure()
    for position, codes in ((121, z_codes), (122, sm_codes)):
        viz, _ = pca(codes, dim_remain=2)
        ax = fig.add_subplot(position)
        ax.scatter(viz[:, 0],
                   viz[:, 1],
                   c=labels[:, 0],
                   alpha=0.4,
                   cmap='gist_rainbow')
    fig.savefig(out_file, transparent=True)
    # Close (not merely clear) the figure: this runs every ``img_step`` steps
    # and pyplot keeps every unclosed figure alive.
    plt.close(fig)


def _evaluate_dataset(sess, ae, batch_op, n_rows, batch_size, sample_img,
                      display_name, tag, t_i, crop_shape, output_path):
    """Evaluate the classifier on one held-out image pipeline.

    Pulls ``n_rows // batch_size`` batches from ``batch_op``, prints the mean
    per-batch accuracy, writes a reconstruction montage of ``sample_img`` to
    ``<output_path>/recon_<tag>_<t_i>.png`` and a PCA plot of the collected
    ``ae['z']`` / ``ae['predictions']`` codes to
    ``<output_path>/z_feat_<tag>.png``.

    Parameters
    ----------
    sess : tf.Session
        Session holding the trained graph.
    ae : dict
        Graph handles returned by ``VAE``.
    batch_op : tensors
        Input-pipeline op yielding ``(images, _, labels)``.
    n_rows : int
        Number of rows in the CSV backing ``batch_op``.
    batch_size : int
        Batch size of ``batch_op``.
    sample_img : np.ndarray
        Fixed, already normalised images used for the reconstruction montage.
    display_name : str
        Name used in the printed accuracy line.
    tag : str
        Name used in the output file names.
    t_i : int
        Index of the current image snapshot.
    crop_shape : list
        Shape each reconstruction is reshaped to for the montage.
    output_path : str
        Directory the images are written to.
    """
    num_batches = n_rows // batch_size
    if num_batches == 0:
        # Nothing to average and no codes to plot; the original divided by 0.
        print("Skipping %s evaluation: fewer than %d rows." %
              (display_name, batch_size))
        return

    eval_feed = {
        ae['train']: False,
        ae['keep_prob']: 1.0,
        ae['corrupt_rec']: 0,
        ae['corrupt_cls']: 0
    }
    accumulated_acc = 0
    z_parts, sm_parts, label_parts = [], [], []
    for index_batch in range(num_batches):
        test_image, _, test_label = sess.run(batch_op)
        test_image /= 255.0
        feed = dict(eval_feed)
        feed[ae['x']] = test_image
        feed[ae['label']] = test_label[:, 0]
        acc, z_codes, sm_codes = sess.run(
            [ae['acc'], ae['z'], ae['predictions']], feed_dict=feed)
        accumulated_acc += acc.tolist().count(True) / acc.size
        # Collect per-batch results and concatenate once (np.append in the
        # loop was quadratic).
        z_parts.append(z_codes)
        sm_parts.append(sm_codes)
        label_parts.append(test_label)
        if index_batch == 0:
            # Plot example reconstructions
            recon_feed = dict(eval_feed)
            recon_feed[ae['x']] = sample_img
            recon = sess.run(ae['y'], feed_dict=recon_feed)
            utils.montage(recon.reshape([-1] + crop_shape),
                          output_path + '/recon_%s_%08d.png' % (tag, t_i))
    accumulated_acc /= num_batches
    print("Accuracy of %s images= %.3f" % (display_name, accumulated_acc))

    _plot_code_pca(np.concatenate(z_parts, axis=0),
                   np.concatenate(sm_parts, axis=0),
                   np.concatenate(label_parts, axis=0),
                   output_path + '/z_feat_%s.png' % tag)


def train_vae(files,
              input_shape=[None, 784],
              output_shape=[None, 784],
              learning_rate=0.0001,
              batch_size=128,
              n_epochs=50,
              crop_shape=[64, 64],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              denoising=True,
              convolutional=True,
              variational=True,
              softmax=False,
              classifier='alexnet_v2',
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=1000,
              save_step=2500,
              output_path="result",
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a CSV file describing the training images, and this will do
    everything else. It builds the input pipeline and the VAE, restores an
    existing checkpoint if present, and trains until the input pipeline is
    exhausted. During training it periodically writes montages, plots and
    checkpoints under ``output_path``.

    Parameters
    ----------
    files : str
        Path to a CSV file listing the training data. It is passed to
        ``create_input_pipeline``; its row count is used as the dataset size.
    input_shape : list, optional
        Passed to ``create_input_pipeline``; its last entry is used as the
        channel dimension of the VAE input.
    output_shape : list, optional
        Passed to ``create_input_pipeline``; its last entry is used as the
        channel dimension of the VAE output.
    learning_rate : float, optional
        Learning rate of the Adam optimizer for the reconstruction cost.
    batch_size : int, optional
        Batch size (also used for the held-out evaluation batches).
    n_epochs : int, optional
        Number of epochs, passed to the input pipelines.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    denoising : bool, optional
        Same as VAE's denoising.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    softmax : bool, optional
        Also train the classification network and evaluate it periodically
        on the held-out CSVs ./list_annotated_imagenet.csv,
        ./list_annotated_pascal.csv and ./list_annotated_img_test.csv.
    classifier : str, optional
        Network for classification; also selects the classifier optimizer.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not.
    keep_prob : float, optional
        Keep probability for dropout during training.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often (in steps) to save training images showing the manifold
        and reconstructions, and to run the held-out evaluation.
    save_step : int, optional
        How often (in steps) to save checkpoints.
    output_path : str, optional
        Directory for images, plots and the checkpoint.
    ckpt_name : str, optional
        Checkpoint name inside ``output_path``, e.g. 'model.ckpt'. If a
        checkpoint with this name exists there, the model is restored from it.
    """
    batch_train = create_input_pipeline(files=files,
                                        batch_size=batch_size,
                                        n_epochs=n_epochs,
                                        crop_shape=crop_shape,
                                        crop_factor=crop_factor,
                                        input_shape=input_shape,
                                        output_shape=output_shape)

    # Held-out classification test sets: (printed name, file tag, CSV path).
    eval_specs = [("ImageNet", "imagenet", "./list_annotated_imagenet.csv"),
                  ("PASCAL", "pascal", "./list_annotated_pascal.csv"),
                  ("ShapeNet", "shapenet", "./list_annotated_img_test.csv")]
    eval_batches = []
    if softmax:
        for _, _, csv_path in eval_specs:
            eval_batches.append(
                create_input_pipeline(files=csv_path,
                                      batch_size=batch_size,
                                      n_epochs=n_epochs,
                                      crop_shape=crop_shape,
                                      crop_factor=crop_factor,
                                      input_shape=input_shape,
                                      output_shape=output_shape))

    ae = VAE(input_shape=[None] + crop_shape + [input_shape[-1]],
             output_shape=[None] + crop_shape + [output_shape[-1]],
             denoising=denoising,
             convolutional=convolutional,
             variational=variational,
             softmax=softmax,
             n_filters=n_filters,
             n_hidden=n_hidden,
             n_code=n_code,
             dropout=dropout,
             filter_sizes=filter_sizes,
             activation=activation,
             classifier=classifier)

    n_files = _count_csv_rows(files)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    np.random.seed(1)
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, 6)

    optimizer_vae = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(ae['cost_vae'])
    if softmax:
        # AlexNet for 0.01,
        # Inception v1 for 0.01
        # SqueezeNet for 0.01
        if classifier == 'inception_v3':
            # NOTE(review): global_step is the constant 0, so this schedule
            # never decays and the learning rate stays at 0.1.
            lr = tf.train.exponential_decay(0.1,
                                            0,
                                            n_files / batch_size * 20,
                                            0.16,
                                            staircase=True)
            optimizer_softmax = tf.train.RMSPropOptimizer(
                lr, decay=0.9, momentum=0.9,
                epsilon=0.1).minimize(ae['cost_s'])
        elif classifier == 'inception_v2':
            optimizer_softmax = tf.train.AdamOptimizer(
                learning_rate=0.01).minimize(ae['cost_s'])
        elif classifier == 'inception_v1':
            optimizer_softmax = tf.train.GradientDescentOptimizer(
                learning_rate=0.01).minimize(ae['cost_s'])
        elif classifier in ('squeezenet', 'zigzagnet'):
            optimizer_softmax = tf.train.RMSPropOptimizer(
                0.04, decay=0.9, momentum=0.9,
                epsilon=0.1).minimize(ae['cost_s'])
        elif classifier == 'alexnet_v2':
            optimizer_softmax = tf.train.GradientDescentOptimizer(
                learning_rate=0.01).minimize(ae['cost_s'])
        else:
            optimizer_softmax = tf.train.GradientDescentOptimizer(
                learning_rate=0.001).minimize(ae['cost_s'])

    # We create a session to use the graph together with a GPU declaration.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.4
    sess = tf.Session(config=config)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter('./summary', sess.graph)

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Check and restore the same path: the original also accepted a
    # ``ckpt_name`` found in the CWD, then failed restoring from output_path.
    ckpt_path = output_path + '/' + ckpt_name
    if os.path.exists(ckpt_path + '.index') or os.path.exists(ckpt_path):
        saver.restore(sess, ckpt_path)
        print("Model restored")

    # Fit all training data
    t_i = 0        # index of the next image snapshot
    step_i = 0     # global training step
    batch_i = 0    # batch counter within the current epoch
    epoch_i = 0
    summary_i = 0
    # Test samples of training data from ShapeNet
    test_xs_img, test_xs_obj, test_xs_label = sess.run(batch_train)
    test_xs_img /= 255.0
    test_xs_obj /= 255.0
    utils.montage(test_xs_img, output_path + '/train_img.png')
    utils.montage(test_xs_obj, output_path + '/train_obj.png')

    # Fixed samples and row counts of the held-out sets (ImageNet,
    # PASCAL 2012, ShapeNet test). These pipelines only exist when softmax
    # is enabled; the original used them unconditionally and crashed.
    eval_samples = []
    eval_rows = []
    if softmax:
        for (_, tag, csv_path), batch_op in zip(eval_specs, eval_batches):
            sample_img, _, _ = sess.run(batch_op)
            sample_img /= 255.0
            utils.montage(sample_img, output_path + '/test_%s_img.png' % tag)
            eval_samples.append(sample_img)
            eval_rows.append(_count_csv_rows(csv_path))

    try:
        while not coord.should_stop():
            batch_i += 1
            step_i += 1
            batch_xs_img, batch_xs_obj, batch_xs_label = sess.run(batch_train)
            batch_xs_img /= 255.0
            batch_xs_obj /= 255.0

            # Here we must set corrupt_rec and corrupt_cls as 0 to find a
            # proper ratio of variance to feed for variable var_prob.
            # We use tanh as non-linear function for ratio of Vars from
            # the reconstructed channels and original channels
            var_prob = sess.run(ae['var_prob'],
                                feed_dict={
                                    ae['x']: test_xs_img,
                                    ae['label']: test_xs_label[:, 0],
                                    ae['train']: True,
                                    ae['keep_prob']: 1.0,
                                    ae['corrupt_rec']: 0,
                                    ae['corrupt_cls']: 0
                                })

            # Here is a fast training process
            corrupt_rec = np.tanh(0.25 * var_prob)
            corrupt_cls = np.tanh(1 - np.tanh(2 * var_prob))

            train_feed = {
                ae['x']: batch_xs_img,
                ae['t']: batch_xs_obj,
                ae['label']: batch_xs_label[:, 0],
                ae['train']: True,
                ae['keep_prob']: keep_prob,
                ae['corrupt_rec']: corrupt_rec,
                ae['corrupt_cls']: corrupt_cls
            }

            # Optimizing reconstruction network
            cost_vae = sess.run([ae['cost_vae'], optimizer_vae],
                                feed_dict=train_feed)[0]
            if softmax:
                # Optimizing classification network
                cost_s = sess.run([ae['cost_s'], optimizer_softmax],
                                  feed_dict=train_feed)[0]

            if step_i % img_step == 0:
                if variational:
                    # Plot example reconstructions from latent layer
                    recon = sess.run(ae['y'],
                                     feed_dict={
                                         ae['z']: zs,
                                         ae['train']: False,
                                         ae['keep_prob']: 1.0,
                                         ae['corrupt_rec']: 0,
                                         ae['corrupt_cls']: 0
                                     })
                    utils.montage(recon.reshape([-1] + crop_shape),
                                  output_path + '/manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['x']: test_xs_img,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0,
                                     ae['corrupt_rec']: 0,
                                     ae['corrupt_cls']: 0
                                 })
                utils.montage(recon.reshape([-1] + crop_shape),
                              output_path + '/recon_%08d.png' % t_i)

                if softmax:
                    # Test on ImageNet, PASCAL 2012 and ShapeNet test samples
                    for (display_name, tag, _), batch_op, sample_img, n_rows \
                            in zip(eval_specs, eval_batches, eval_samples,
                                   eval_rows):
                        _evaluate_dataset(sess, ae, batch_op, n_rows,
                                          batch_size, sample_img,
                                          display_name, tag, t_i,
                                          crop_shape, output_path)

                t_i += 1

            if step_i % save_step == 0:
                # Save the variables to disk.
                # We should set global_step=batch_i if we want several ckpt
                saver.save(sess,
                           ckpt_path,
                           global_step=None,
                           write_meta_graph=False)
                if softmax:
                    acc = sess.run(ae['acc'],
                                   feed_dict={
                                       ae['x']: test_xs_img,
                                       ae['label']: test_xs_label[:, 0],
                                       ae['train']: False,
                                       ae['keep_prob']: 1.0,
                                       ae['corrupt_rec']: 0,
                                       ae['corrupt_cls']: 0
                                   })

                    print(
                        "epoch %d: VAE = %d, SM = %.3f, Acc = %.3f, R_Var = %.3f, Cpt_R = %.3f, Cpt_C = %.3f"
                        %
                        (epoch_i, cost_vae, cost_s, acc.tolist().count(True) /
                         acc.size, var_prob, corrupt_rec, corrupt_cls))

                    # Summary recording to Tensorboard
                    summary = sess.run(ae['merged'],
                                       feed_dict={
                                           ae['x']: batch_xs_img,
                                           ae['t']: batch_xs_obj,
                                           ae['label']: batch_xs_label[:, 0],
                                           ae['train']: False,
                                           ae['keep_prob']: keep_prob,
                                           ae['corrupt_rec']: corrupt_rec,
                                           ae['corrupt_cls']: corrupt_cls
                                       })

                    summary_i += 1
                    train_writer.add_summary(summary, summary_i)
                else:
                    print("VAE loss = %d" % cost_vae)

            if batch_i > (n_files / batch_size):
                batch_i = 0
                epoch_i += 1

    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # Tell all pipeline threads to stop, wait for them, and release the
        # writer and session even if training raised another exception.
        coord.request_stop()
        coord.join(threads)
        train_writer.close()
        sess.close()
def train_vae(files,
              input_shape,
              learning_rate=0.0001,
              batch_size=100,
              n_epochs=50,
              n_examples=10,
              crop_shape=[64, 64, 3],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a list of file paths to images, and this will do everything else.

    Parameters
    ----------
    files : list of strings
        List of paths to images.
    input_shape : list
        Must define what the input image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs (passes over ``files``).
    n_examples : int, optional
        Number of examples to use while demonstrating the current training
        iteration's reconstruction.  Creates a square montage, so make
        sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ... 100.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often (in training steps, counted across epochs) to save training
        images showing the manifold and reconstruction.
    save_step : int, optional
        How often (in training steps, counted across epochs) to save
        checkpoints.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt'.  They are
        written as ``<ckpt_name>-<step>``; on start-up either an exact
        ``<ckpt_name>`` checkpoint or the latest ``<ckpt_name>-<step>`` one
        found in the same directory is restored.
    """
    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    ae = VAE(input_shape=[None] + crop_shape,
             convolutional=convolutional,
             variational=variational,
             n_filters=n_filters,
             n_hidden=n_hidden,
             n_code=n_code,
             dropout=dropout,
             filter_sizes=filter_sizes,
             activation=activation)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['cost'])

    # We create a session to use the graph
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Resume if a checkpoint exists.  Checkpoints are saved below with a
    # "-<step>" suffix, so besides the exact name also accept the most recent
    # step-suffixed checkpoint that shares this prefix.
    restore_path = None
    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        restore_path = ckpt_name
    else:
        latest = tf.train.latest_checkpoint(
            os.path.dirname(os.path.abspath(ckpt_name)))
        if latest is not None and os.path.basename(latest).startswith(
                os.path.basename(ckpt_name) + '-'):
            restore_path = latest
    if restore_path is not None:
        saver.restore(sess, restore_path)

    # One epoch is one pass over ``files``, i.e. len(files) // batch_size
    # batches; never let it be zero (fewer files than batch_size).
    batches_per_epoch = max(1, len(files) // batch_size)

    t_i = 0      # index used in saved image filenames
    step_i = 0   # total training steps, never reset (drives img/save steps)
    batch_i = 0  # training steps within the current epoch
    epoch_i = 0
    cost = 0
    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            batch_i += 1
            step_i += 1
            batch_xs = sess.run(batch) / 255.0
            train_cost = sess.run([ae['cost'], optimizer],
                                  feed_dict={
                                      ae['x']: batch_xs,
                                      ae['train']: True,
                                      ae['keep_prob']: keep_prob
                                  })[0]
            print(batch_i, train_cost)
            cost += train_cost

            if step_i % img_step == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['z']: zs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                utils.montage(recon.reshape([-1] + crop_shape),
                              'manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['x']: test_xs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                print('reconstruction (min, max, mean):', recon.min(),
                      recon.max(), recon.mean())
                utils.montage(recon.reshape([-1] + crop_shape),
                              'reconstruction_%08d.png' % t_i)
                t_i += 1

            if step_i % save_step == 0:
                # Save the variables to disk; the global step keeps
                # checkpoints from different epochs from overwriting each
                # other.
                saver.save(sess,
                           "./" + ckpt_name,
                           global_step=step_i,
                           write_meta_graph=False)

            if batch_i == batches_per_epoch:
                print('epoch:', epoch_i)
                print('average cost:', cost / batch_i)
                cost = 0
                batch_i = 0
                epoch_i += 1
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # Tell all pipeline threads to shut down, wait for them, and release
        # the session even if training raised an unexpected error.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()
Exemple #4
0
def train_vaegan(files,
                 learning_rate=0.00001,
                 batch_size=64,
                 n_epochs=250,
                 n_examples=10,
                 input_shape=[218, 178, 3],
                 crop_shape=[64, 64, 3],
                 crop_factor=0.8,
                 n_filters=[100, 100, 100, 100],
                 n_hidden=None,
                 n_code=128,
                 convolutional=True,
                 variational=True,
                 filter_sizes=[3, 3, 3, 3],
                 activation=tf.nn.elu,
                 ckpt_name="vaegan.ckpt"):
    """Train a VAE-GAN on a list of image files.

    The encoder, generator and discriminator variables (selected by their
    name prefixes) are trained by three separate Adam optimizers.  Every step
    runs the encoder update; the negated mean real/fake discriminator losses
    it reports then decide whether the generator and/or the discriminator are
    also updated on that step.

    Parameters
    ----------
    files : list of str
        Paths to the training images.
    learning_rate : float, optional
        Learning rate shared by all three Adam optimizers.
    batch_size : int, optional
        Number of images per training step.
    n_epochs : int, optional
        Number of passes over ``files``; also given to the input pipeline.
    n_examples : int, optional
        Passed to ``utils.make_latent_manifold`` to build the latent grid used
        for the manifold images.
    input_shape : list, optional
        Shape of the images on disk, given to the input pipeline.
    crop_shape : list, optional
        Shape images are cropped to; this is the network's input shape.
    crop_factor : float, optional
        Resize factor applied before cropping.
    n_filters : list, optional
        Passed to VAEGAN.
    n_hidden : int, optional
        Passed to VAEGAN.
    n_code : int, optional
        Size of the latent code; passed to VAEGAN.
    convolutional : bool, optional
        Passed to VAEGAN.
    variational : bool, optional
        Passed to VAEGAN.
    filter_sizes : list, optional
        Passed to VAEGAN.
    activation : callable, optional
        Activation function passed to VAEGAN.
    ckpt_name : str, optional
        Checkpoint prefix.  Checkpoints are written as ``<ckpt_name>-<step>``;
        on start-up either an exact ``<ckpt_name>`` checkpoint or the latest
        ``<ckpt_name>-<step>`` one in the same directory is restored.
    """

    ae = VAEGAN(input_shape=[None] + crop_shape,
                convolutional=convolutional,
                variational=variational,
                n_filters=n_filters,
                n_hidden=n_hidden,
                n_code=n_code,
                filter_sizes=filter_sizes,
                activation=activation)

    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    # Fixed latent grid, reused for every manifold image.
    zs = np.random.randn(4, n_code).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    def _scoped_vars(prefix):
        """Return the trainable variables whose name starts with ``prefix``."""
        return [
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith(prefix)
        ]

    opt_enc = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_enc'], var_list=_scoped_vars('encoder'))

    opt_gen = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_gen'], var_list=_scoped_vars('generator'))

    opt_dis = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_dis'], var_list=_scoped_vars('discriminator'))

    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    tf.get_default_graph().finalize()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Resume if a checkpoint exists.  Checkpoints are saved below with a
    # "-<step>" suffix, so besides the exact name also accept the most recent
    # step-suffixed checkpoint that shares this prefix.
    restore_path = None
    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        restore_path = ckpt_name
    else:
        latest = tf.train.latest_checkpoint(
            os.path.dirname(os.path.abspath(ckpt_name)))
        if latest is not None and os.path.basename(latest).startswith(
                os.path.basename(ckpt_name) + '-'):
            restore_path = latest
    if restore_path is not None:
        saver.restore(sess, restore_path)
        print("VAE model restored.")

    def _to_unit_range(img):
        """Scale ``img`` by its maximum (only when positive) and clip to [0, 1].

        Dividing by a zero or negative maximum would produce NaN/inf or flip
        the sign of the image, so in that case the image is only clipped.
        """
        peak = img.max()
        if peak > 0:
            img = img / peak
        return np.clip(img, 0, 1)

    # Band around which the negated mean real/fake costs are kept.
    # 0.693 ~= ln(2); presumably the loss of a discriminator at chance -- TODO
    # confirm against VAEGAN's loss definitions.
    equilibrium = 0.693
    margin = 0.4

    # One epoch is one pass over ``files``; never let it be zero batches
    # (which would make the epoch check divide by zero).
    batches_per_epoch = max(1, len(files) // batch_size)

    t_i = 0      # index used in saved image filenames
    step_i = 0   # total training steps, never reset (drives img/save steps)
    batch_i = 0  # training steps within the current epoch
    epoch_i = 0

    # NOTE(review): utils.montage is not visible here; creating the output
    # directory up front avoids a failure if it does not create it itself.
    os.makedirs('imgs', exist_ok=True)

    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            if batch_i == 0:
                print('---------- EPOCH:', epoch_i + 1)

            batch_i += 1
            step_i += 1
            batch_xs = sess.run(batch) / 255.0
            batch_zs = np.random.randn(batch_size, n_code).astype(np.float32)
            real_cost, fake_cost, _ = sess.run(
                [ae['loss_real'], ae['loss_fake'], opt_enc],
                feed_dict={
                    ae['x']: batch_xs,
                    ae['gamma']: 0.5
                })
            real_cost = -np.mean(real_cost)
            fake_cost = -np.mean(fake_cost)
            print('real:', real_cost, '/ fake:', fake_cost)

            gen_update = True
            dis_update = True

            # Skip the generator when either cost is above the band, and the
            # discriminator when either is below it.
            if real_cost > (equilibrium + margin) or \
               fake_cost > (equilibrium + margin):
                gen_update = False

            if real_cost < (equilibrium - margin) or \
               fake_cost < (equilibrium - margin):
                dis_update = False

            # Never skip both.
            if not (gen_update or dis_update):
                gen_update = True
                dis_update = True

            if gen_update:
                sess.run(opt_gen,
                         feed_dict={
                             ae['x']: batch_xs,
                             ae['z_samp']: batch_zs,
                             ae['gamma']: 0.5
                         })
            if dis_update:
                sess.run(opt_dis,
                         feed_dict={
                             ae['x']: batch_xs,
                             ae['z_samp']: batch_zs,
                             ae['gamma']: 0.5
                         })

            if step_i % 50 == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(ae['x_tilde'], feed_dict={ae['z']: zs})
                print('recon:', recon.min(), recon.max())
                utils.montage(_to_unit_range(recon).reshape([-1] + crop_shape),
                              'imgs/manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['x_tilde'], feed_dict={ae['x']: test_xs})
                print('recon:', recon.min(), recon.max())
                utils.montage(_to_unit_range(recon).reshape([-1] + crop_shape),
                              'imgs/reconstruction_%08d.png' % t_i)
                t_i += 1

            if step_i % 100 == 0:
                # Save the variables to disk; the global step keeps
                # checkpoints from different epochs from overwriting each
                # other.
                save_path = saver.save(sess,
                                       ckpt_name,
                                       global_step=step_i,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)

            # Close the epoch only after its last batch has been trained, so
            # exactly n_epochs full epochs run.
            if batch_i == batches_per_epoch:
                batch_i = 0
                epoch_i += 1
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # Tell all pipeline threads to shut down, wait for them, and release
        # the session even if training raised an unexpected error.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()
Exemple #5
0
def test_mnist(n_epochs=10):
    """Train an autoencoder on MNIST.

    This function will train an autoencoder on MNIST and also
    save many image files during the training process, demonstrating
    the latent space of the inner most dimension of the encoder,
    as well as reconstructions of the decoder.

    Parameters
    ----------
    n_epochs : int, optional
        Number of passes over the MNIST training split.  After each pass the
        mean training and validation cost per batch are printed.
    """

    def _mean(total, count):
        """Return ``total / count``, or NaN when no batches were seen."""
        return total / count if count else float('nan')

    # load MNIST
    n_code = 2
    mnist = MNIST(split=[0.8, 0.1, 0.1])
    ae = VAE(input_shape=[None, 784], n_filters=[512, 256],
             n_hidden=64, n_code=n_code, activation=tf.nn.sigmoid,
             convolutional=False, variational=True)

    n_examples = 100
    zs = np.random.uniform(
        -1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    learning_rate = 0.02
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(ae['cost'])

    # The context manager closes the session when training finishes or fails.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Fit all training data
        t_i = 0      # index used in saved image filenames
        batch_i = 0  # global batch counter across epochs
        batch_size = 200
        test_xs = mnist.test.images[:n_examples]
        utils.montage(test_xs.reshape((-1, 28, 28)), 'test_xs.png')
        for _epoch_i in range(n_epochs):
            train_i = 0
            train_cost = 0
            for batch_xs, _ in mnist.train.next_batch(batch_size):
                train_cost += sess.run([ae['cost'], optimizer], feed_dict={
                    ae['x']: batch_xs, ae['train']: True,
                    ae['keep_prob']: 1.0})[0]
                train_i += 1
                if batch_i % 10 == 0:
                    # Plot example reconstructions from latent layer
                    recon = sess.run(
                        ae['y'], feed_dict={
                            ae['z']: zs,
                            ae['train']: False,
                            ae['keep_prob']: 1.0})
                    utils.montage(recon.reshape((-1, 28, 28)),
                                  'manifold_%08d.png' % t_i)
                    # Plot example reconstructions
                    recon = sess.run(
                        ae['y'], feed_dict={ae['x']: test_xs,
                                            ae['train']: False,
                                            ae['keep_prob']: 1.0})
                    utils.montage(recon.reshape((-1, 28, 28)),
                                  'reconstruction_%08d.png' % t_i)
                    t_i += 1
                batch_i += 1

            valid_i = 0
            valid_cost = 0
            for batch_xs, _ in mnist.valid.next_batch(batch_size):
                valid_cost += sess.run([ae['cost']], feed_dict={
                    ae['x']: batch_xs, ae['train']: False,
                    ae['keep_prob']: 1.0})[0]
                valid_i += 1
            print('train:', _mean(train_cost, train_i),
                  'valid:', _mean(valid_cost, valid_i))
Exemple #6
0
def train_vae(files,
              input_shape,
              learning_rate=0.0001,
              batch_size=100,
              n_epochs=50,
              n_examples=10,
              crop_shape=[64, 64, 3],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a list of file paths to images, and this will do everything else.

    Parameters
    ----------
    files : list of strings
        List of paths to images.
    input_shape : list
        Must define what the input image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs (passes over ``files``).
    n_examples : int, optional
        Number of examples to use while demonstrating the current training
        iteration's reconstruction.  Creates a square montage, so make
        sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ... 100.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often (in training steps, counted across epochs) to save training
        images showing the manifold and reconstruction.
    save_step : int, optional
        How often (in training steps, counted across epochs) to save
        checkpoints.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt'.  They are
        written as ``<ckpt_name>-<step>``; on start-up either an exact
        ``<ckpt_name>`` checkpoint or the latest ``<ckpt_name>-<step>`` one
        found in the same directory is restored.
    """
    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    ae = VAE(input_shape=[None] + crop_shape,
             convolutional=convolutional,
             variational=variational,
             n_filters=n_filters,
             n_hidden=n_hidden,
             n_code=n_code,
             dropout=dropout,
             filter_sizes=filter_sizes,
             activation=activation)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['cost'])

    # We create a session to use the graph
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Resume if a checkpoint exists.  Checkpoints are saved below with a
    # "-<step>" suffix, so besides the exact name also accept the most recent
    # step-suffixed checkpoint that shares this prefix.
    restore_path = None
    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        restore_path = ckpt_name
    else:
        latest = tf.train.latest_checkpoint(
            os.path.dirname(os.path.abspath(ckpt_name)))
        if latest is not None and os.path.basename(latest).startswith(
                os.path.basename(ckpt_name) + '-'):
            restore_path = latest
    if restore_path is not None:
        saver.restore(sess, restore_path)

    # One epoch is one pass over ``files``, i.e. len(files) // batch_size
    # batches; never let it be zero (fewer files than batch_size).
    batches_per_epoch = max(1, len(files) // batch_size)

    t_i = 0      # index used in saved image filenames
    step_i = 0   # total training steps, never reset (drives img/save steps)
    batch_i = 0  # training steps within the current epoch
    epoch_i = 0
    cost = 0
    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            batch_i += 1
            step_i += 1
            batch_xs = sess.run(batch) / 255.0
            train_cost = sess.run([ae['cost'], optimizer],
                                  feed_dict={
                                      ae['x']: batch_xs,
                                      ae['train']: True,
                                      ae['keep_prob']: keep_prob
                                  })[0]
            print(batch_i, train_cost)
            cost += train_cost

            if step_i % img_step == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['z']: zs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                utils.montage(recon.reshape([-1] + crop_shape),
                              'manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['x']: test_xs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                print('reconstruction (min, max, mean):', recon.min(),
                      recon.max(), recon.mean())
                utils.montage(recon.reshape([-1] + crop_shape),
                              'reconstruction_%08d.png' % t_i)
                t_i += 1

            if step_i % save_step == 0:
                # Save the variables to disk; the global step keeps
                # checkpoints from different epochs from overwriting each
                # other.
                saver.save(sess,
                           "./" + ckpt_name,
                           global_step=step_i,
                           write_meta_graph=False)

            if batch_i == batches_per_epoch:
                print('epoch:', epoch_i)
                print('average cost:', cost / batch_i)
                cost = 0
                batch_i = 0
                epoch_i += 1
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # Tell all pipeline threads to shut down, wait for them, and release
        # the session even if training raised an unexpected error.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()
Exemple #7
0
def train_vaegan(files,
                 learning_rate=0.00001,
                 batch_size=64,
                 n_epochs=250,
                 n_examples=10,
                 input_shape=[218, 178, 3],
                 crop_shape=[64, 64, 3],
                 crop_factor=0.8,
                 n_filters=[100, 100, 100, 100],
                 n_hidden=None,
                 n_code=128,
                 convolutional=True,
                 variational=True,
                 filter_sizes=[3, 3, 3, 3],
                 activation=tf.nn.elu,
                 ckpt_name="vaegan.ckpt"):
    """Train a VAE-GAN on a list of image files.

    The encoder, generator and discriminator variables (selected by their
    name prefixes) are trained by three separate Adam optimizers.  Every step
    runs the encoder update; the negated mean real/fake discriminator losses
    it reports then decide whether the generator and/or the discriminator are
    also updated on that step.

    Parameters
    ----------
    files : list of str
        Paths to the training images.
    learning_rate : float, optional
        Learning rate shared by all three Adam optimizers.
    batch_size : int, optional
        Number of images per training step.
    n_epochs : int, optional
        Number of passes over ``files``; also given to the input pipeline.
    n_examples : int, optional
        Passed to ``utils.make_latent_manifold`` to build the latent grid used
        for the manifold images.
    input_shape : list, optional
        Shape of the images on disk, given to the input pipeline.
    crop_shape : list, optional
        Shape images are cropped to; this is the network's input shape.
    crop_factor : float, optional
        Resize factor applied before cropping.
    n_filters : list, optional
        Passed to VAEGAN.
    n_hidden : int, optional
        Passed to VAEGAN.
    n_code : int, optional
        Size of the latent code; passed to VAEGAN.
    convolutional : bool, optional
        Passed to VAEGAN.
    variational : bool, optional
        Passed to VAEGAN.
    filter_sizes : list, optional
        Passed to VAEGAN.
    activation : callable, optional
        Activation function passed to VAEGAN.
    ckpt_name : str, optional
        Checkpoint prefix.  Checkpoints are written as ``<ckpt_name>-<step>``;
        on start-up either an exact ``<ckpt_name>`` checkpoint or the latest
        ``<ckpt_name>-<step>`` one in the same directory is restored.
    """

    ae = VAEGAN(input_shape=[None] + crop_shape,
                convolutional=convolutional,
                variational=variational,
                n_filters=n_filters,
                n_hidden=n_hidden,
                n_code=n_code,
                filter_sizes=filter_sizes,
                activation=activation)

    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    # Fixed latent grid, reused for every manifold image.
    zs = np.random.randn(4, n_code).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    def _scoped_vars(prefix):
        """Return the trainable variables whose name starts with ``prefix``."""
        return [
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith(prefix)
        ]

    opt_enc = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_enc'], var_list=_scoped_vars('encoder'))

    opt_gen = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_gen'], var_list=_scoped_vars('generator'))

    opt_dis = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_dis'], var_list=_scoped_vars('discriminator'))

    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    tf.get_default_graph().finalize()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Resume if a checkpoint exists.  Checkpoints are saved below with a
    # "-<step>" suffix, so besides the exact name also accept the most recent
    # step-suffixed checkpoint that shares this prefix.
    restore_path = None
    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        restore_path = ckpt_name
    else:
        latest = tf.train.latest_checkpoint(
            os.path.dirname(os.path.abspath(ckpt_name)))
        if latest is not None and os.path.basename(latest).startswith(
                os.path.basename(ckpt_name) + '-'):
            restore_path = latest
    if restore_path is not None:
        saver.restore(sess, restore_path)
        print("VAE model restored.")

    def _to_unit_range(img):
        """Scale ``img`` by its maximum (only when positive) and clip to [0, 1].

        Dividing by a zero or negative maximum would produce NaN/inf or flip
        the sign of the image, so in that case the image is only clipped.
        """
        peak = img.max()
        if peak > 0:
            img = img / peak
        return np.clip(img, 0, 1)

    # Band around which the negated mean real/fake costs are kept.
    # 0.693 ~= ln(2); presumably the loss of a discriminator at chance -- TODO
    # confirm against VAEGAN's loss definitions.
    equilibrium = 0.693
    margin = 0.4

    # One epoch is one pass over ``files``; never let it be zero batches
    # (which would make the epoch check divide by zero).
    batches_per_epoch = max(1, len(files) // batch_size)

    t_i = 0      # index used in saved image filenames
    step_i = 0   # total training steps, never reset (drives img/save steps)
    batch_i = 0  # training steps within the current epoch
    epoch_i = 0

    # NOTE(review): utils.montage is not visible here; creating the output
    # directory up front avoids a failure if it does not create it itself.
    os.makedirs('imgs', exist_ok=True)

    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            if batch_i == 0:
                print('---------- EPOCH:', epoch_i + 1)

            batch_i += 1
            step_i += 1
            batch_xs = sess.run(batch) / 255.0
            batch_zs = np.random.randn(batch_size, n_code).astype(np.float32)
            real_cost, fake_cost, _ = sess.run(
                [ae['loss_real'], ae['loss_fake'], opt_enc],
                feed_dict={ae['x']: batch_xs,
                           ae['gamma']: 0.5})
            real_cost = -np.mean(real_cost)
            fake_cost = -np.mean(fake_cost)
            print('real:', real_cost, '/ fake:', fake_cost)

            gen_update = True
            dis_update = True

            # Skip the generator when either cost is above the band, and the
            # discriminator when either is below it.
            if real_cost > (equilibrium + margin) or \
               fake_cost > (equilibrium + margin):
                gen_update = False

            if real_cost < (equilibrium - margin) or \
               fake_cost < (equilibrium - margin):
                dis_update = False

            # Never skip both.
            if not (gen_update or dis_update):
                gen_update = True
                dis_update = True

            if gen_update:
                sess.run(
                    opt_gen,
                    feed_dict={
                        ae['x']: batch_xs,
                        ae['z_samp']: batch_zs,
                        ae['gamma']: 0.5
                    })
            if dis_update:
                sess.run(
                    opt_dis,
                    feed_dict={
                        ae['x']: batch_xs,
                        ae['z_samp']: batch_zs,
                        ae['gamma']: 0.5
                    })

            if step_i % 50 == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(ae['x_tilde'], feed_dict={ae['z']: zs})
                print('recon:', recon.min(), recon.max())
                utils.montage(
                    _to_unit_range(recon).reshape([-1] + crop_shape),
                    'imgs/manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['x_tilde'], feed_dict={ae['x']: test_xs})
                print('recon:', recon.min(), recon.max())
                utils.montage(
                    _to_unit_range(recon).reshape([-1] + crop_shape),
                    'imgs/reconstruction_%08d.png' % t_i)
                t_i += 1

            if step_i % 100 == 0:
                # Save the variables to disk; the global step keeps
                # checkpoints from different epochs from overwriting each
                # other.
                save_path = saver.save(
                    sess,
                    ckpt_name,
                    global_step=step_i,
                    write_meta_graph=False)
                print("Model saved in file: %s" % save_path)

            # Close the epoch only after its last batch has been trained, so
            # exactly n_epochs full epochs run.
            if batch_i == batches_per_epoch:
                batch_i = 0
                epoch_i += 1
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # Tell all pipeline threads to shut down, wait for them, and release
        # the session even if training raised an unexpected error.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()