Пример #1
0
 def test_input_pipeline(self):
     """Run one epoch over 100 files and verify what the pipeline yields.

     Every batch must have shape ``(batch_size, 64, 64, 3)`` and exactly
     ten batches must arrive before the queue reports it is exhausted.
     """
     image_shape = (64, 64, 3)
     batch_size = 10
     seen = 0
     files, _ = dsu.tiny_imagenet_load()
     with tf.Graph().as_default(), tf.Session() as sess:
         batch_op = dsu.create_input_pipeline(
             files[:100],
             batch_size=batch_size,
             n_epochs=1,
             shape=image_shape,
             crop_shape=image_shape)
         sess.run(
             tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
         coord = tf.train.Coordinator()
         tf.get_default_graph().finalize()
         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
         try:
             while not coord.should_stop():
                 images = sess.run(batch_op)
                 assert images.shape == (batch_size,) + image_shape
                 seen += 1
         except tf.errors.OutOfRangeError:
             # Raised once the single requested epoch has been consumed.
             pass
         finally:
             coord.request_stop()
         coord.join(threads)
     assert seen == 10
Пример #2
0
def train_tiny_imagenet(ckpt_path='pixelcnn',
                        n_epochs=1000,
                        save_step=100,
                        write_step=25,
                        B=32,
                        H=64,
                        W=64,
                        C=3):
    """Train the gated PixelCNN model on Tiny ImageNet.

    Builds the model and a threaded input pipeline in a fresh graph, resumes
    from the most recent checkpoint found in `ckpt_path` (if any), and trains
    until the input pipeline is exhausted.

    Parameters
    ----------
    ckpt_path : str, optional
        Directory that receives checkpoints and summary event files.
    n_epochs : int, optional
        Number of epochs requested from the input pipeline.
    save_step : int, optional
        A checkpoint is written every `save_step` batches.
    write_step : int, optional
        Summaries are written every `write_step` batches.
    B : int, optional
        Batch size, used for both the model and the input pipeline.
    H : int, optional
        Height passed to the model and used as the crop height.
    W : int, optional
        Width passed to the model and used as the crop width.
    C : int, optional
        Channel count passed to the model and used for the crop.
    """
    ckpt_name = os.path.join(ckpt_path, 'pixelcnn.ckpt')

    with tf.Graph().as_default(), tf.Session() as sess:
        # Not actually conditioning on anything here just using the gated cnn model
        net = build_conditional_pixel_cnn_model(B=B, H=H, W=W, C=C)

        # build the optimizer (this will take a while!)
        optimizer = tf.train.AdamOptimizer(
            learning_rate=0.001).minimize(net['cost'])

        # Load a list of files for tiny imagenet, downloading if necessary
        imagenet_files = dsu.tiny_imagenet_load()

        # Create a threaded image pipeline which will load/shuffle/crop/resize
        batch = dsu.create_input_pipeline(
            imagenet_files[0],
            batch_size=B,
            n_epochs=n_epochs,
            shape=[64, 64, 3],
            crop_shape=[H, W, C],
            crop_factor=1.0,
            n_threads=8)

        saver = tf.train.Saver()
        writer = tf.summary.FileWriter(ckpt_path)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # This will handle our threaded image pipeline
        coord = tf.train.Coordinator()

        # Ensure no more changes to graph
        tf.get_default_graph().finalize()

        # Start up the queues for handling the image pipeline
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Checkpoints are saved with a global step, so the files on disk are
        # named 'pixelcnn.ckpt-<step>'.  Ask TensorFlow for the newest one
        # instead of probing for the un-suffixed name, and fall back to the
        # bare name only if such a checkpoint was placed there by hand.
        latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
        if latest_ckpt is not None:
            saver.restore(sess, latest_ckpt)
        elif os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)

        # epoch_i is never advanced; the loop normally ends when the input
        # pipeline raises OutOfRangeError after `n_epochs` epochs.
        epoch_i = 0
        batch_i = 0
        try:
            while not coord.should_stop() and epoch_i < n_epochs:
                batch_i += 1
                batch_xs = sess.run(batch)
                train_cost = sess.run(
                    [net['cost'], optimizer], feed_dict={net['X']: batch_xs})[0]

                print(batch_i, train_cost)
                if batch_i % write_step == 0:
                    summary = sess.run(
                        net['summaries'], feed_dict={net['X']: batch_xs})
                    writer.add_summary(summary, batch_i)

                if batch_i % save_step == 0:
                    # Save the variables to disk.  Don't write the meta graph
                    # since we can use the code to create it, and it takes a long
                    # time to create the graph since it is so deep
                    saver.save(
                        sess,
                        ckpt_name,
                        global_step=batch_i,
                        write_meta_graph=False)
        except tf.errors.OutOfRangeError:
            print('Done.')
        finally:
            # Either training finished or something raised: tell all the
            # threads to shut down and flush pending summaries to disk.
            coord.request_stop()
            writer.close()

        # Wait until all threads have finished.
        coord.join(threads)
Пример #3
0
def train_vaegan(files,
                 learning_rate=0.00001,
                 batch_size=64,
                 n_epochs=250,
                 n_examples=10,
                 input_shape=[218, 178, 3],
                 crop_shape=[64, 64, 3],
                 crop_factor=0.8,
                 n_filters=[100, 100, 100, 100],
                 n_hidden=None,
                 n_code=128,
                 convolutional=True,
                 variational=True,
                 filter_sizes=[3, 3, 3, 3],
                 activation=tf.nn.elu,
                 ckpt_name="vaegan.ckpt"):
    """Train a VAEGAN on images read through a threaded input pipeline.

    The encoder, generator and discriminator each get their own Adam
    optimizer.  On every batch the encoder is updated; the generator and
    discriminator are updated only while the real/fake costs stay within
    `margin` of `equilibrium`, so neither network runs away from the other.

    Parameters
    ----------
    files : list
        Input files handed to `create_input_pipeline`; its length is used to
        work out how many batches make up one epoch.
    learning_rate : float, optional
        Learning rate shared by the three Adam optimizers.
    batch_size : int, optional
        Number of images per batch.
    n_epochs : int, optional
        Number of epochs to train for.
    n_examples : int, optional
        Passed to `utils.make_latent_manifold` when building the latent
        samples used for the manifold montage.
    input_shape : list, optional
        Shape of the images on disk, passed to the input pipeline.
    crop_shape : list, optional
        Shape the pipeline crops to; also the model's per-image input shape.
    crop_factor : float, optional
        Passed to the input pipeline.
    n_filters : list, optional
        Passed to `VAEGAN`.
    n_hidden : int, optional
        Passed to `VAEGAN`.
    n_code : int, optional
        Size of the latent code; also the width of the sampled `z` vectors.
    convolutional : bool, optional
        Passed to `VAEGAN`.
    variational : bool, optional
        Passed to `VAEGAN`.
    filter_sizes : list, optional
        Passed to `VAEGAN`.
    activation : function, optional
        Passed to `VAEGAN`.
    ckpt_name : str, optional
        Checkpoint prefix; saved files are named '<ckpt_name>-<step>'.
    """

    ae = VAEGAN(input_shape=[None] + crop_shape,
                convolutional=convolutional,
                variational=variational,
                n_filters=n_filters,
                n_hidden=n_hidden,
                n_code=n_code,
                filter_sizes=filter_sizes,
                activation=activation)

    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    zs = np.random.randn(4, n_code).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    opt_enc = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_enc'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('encoder')
        ])

    opt_gen = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_gen'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('generator')
        ])

    opt_dis = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_dis'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('discriminator')
        ])

    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    tf.get_default_graph().finalize()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Checkpoints are written with a global step, i.e. as
    # '<ckpt_name>-<step>', so the bare name usually does not exist.  Honour
    # it if present, otherwise resume from the newest matching checkpoint.
    restore_path = None
    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        restore_path = ckpt_name
    else:
        latest_ckpt = tf.train.latest_checkpoint(
            os.path.dirname(ckpt_name) or '.')
        if latest_ckpt and os.path.basename(latest_ckpt).startswith(
                os.path.basename(ckpt_name) + '-'):
            restore_path = latest_ckpt
    if restore_path is not None:
        saver.restore(sess, restore_path)
        print("VAE model restored.")

    # The montages below are written into 'imgs/'; make sure it exists.
    if not os.path.isdir('imgs'):
        os.makedirs('imgs')

    t_i = 0
    batch_i = 0
    epoch_i = 0

    equilibrium = 0.693
    margin = 0.4

    n_files = len(files)
    # Guard against fewer files than one batch (would be a modulo by zero).
    batches_per_epoch = max(1, n_files // batch_size)
    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            if batch_i == 0:
                print('---------- EPOCH:', epoch_i + 1)

            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            batch_zs = np.random.randn(batch_size, n_code).astype(np.float32)
            real_cost, fake_cost, _ = sess.run(
                [ae['loss_real'], ae['loss_fake'], opt_enc],
                feed_dict={
                    ae['x']: batch_xs,
                    ae['gamma']: 0.5
                })
            real_cost = -np.mean(real_cost)
            fake_cost = -np.mean(fake_cost)
            print('real:', real_cost, '/ fake:', fake_cost)

            gen_update = True
            dis_update = True

            # Skip the generator update while either cost is too high ...
            if real_cost > (equilibrium + margin) or \
               fake_cost > (equilibrium + margin):
                gen_update = False

            # ... and the discriminator update while either is too low.
            if real_cost < (equilibrium - margin) or \
               fake_cost < (equilibrium - margin):
                dis_update = False

            # Never skip both: that would stall training entirely.
            if not (gen_update or dis_update):
                gen_update = True
                dis_update = True

            if gen_update:
                sess.run(opt_gen,
                         feed_dict={
                             ae['x']: batch_xs,
                             ae['z_samp']: batch_zs,
                             ae['gamma']: 0.5
                         })
            if dis_update:
                sess.run(opt_dis,
                         feed_dict={
                             ae['x']: batch_xs,
                             ae['z_samp']: batch_zs,
                             ae['gamma']: 0.5
                         })

            if batch_i % 50 == 0:

                # Plot example reconstructions from latent layer
                recon = sess.run(ae['x_tilde'], feed_dict={ae['z']: zs})
                print('recon:', recon.min(), recon.max())
                recon = np.clip(recon / recon.max(), 0, 1)
                utils.montage(recon.reshape([-1] + crop_shape),
                              'imgs/manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['x_tilde'], feed_dict={ae['x']: test_xs})
                print('recon:', recon.min(), recon.max())
                recon = np.clip(recon / recon.max(), 0, 1)
                utils.montage(recon.reshape([-1] + crop_shape),
                              'imgs/reconstruction_%08d.png' % t_i)
                t_i += 1

            if batch_i % 100 == 0:
                # Save the variables to disk.
                save_path = saver.save(sess,
                                       ckpt_name,
                                       global_step=batch_i,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)

            # Count the epoch only once all of its batches have been
            # trained, so that `n_epochs` full epochs are run.
            if batch_i >= batches_per_epoch:
                batch_i = 0
                epoch_i += 1
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # Either training finished or something raised: tell all the
        # threads to shut down.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
Пример #4
0
def train_vae(files,
              input_shape,
              learning_rate=0.0001,
              batch_size=100,
              n_epochs=50,
              n_examples=10,
              crop_shape=[64, 64, 3],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a list of file paths to images, and this will do everything else.

    Parameters
    ----------
    files : list of strings
        List of paths to images.
    input_shape : list
        Must define what the input image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs.
    n_examples : int, optional
        Number of example to use while demonstrating the current training
        iteration's reconstruction.  Creates a square montage, so make
        sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ... 100.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often (in training steps) to save training images showing the
        manifold and reconstruction.
    save_step : int, optional
        How often (in training steps) to save checkpoints.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt', with the
        training step appended ('model.ckpt-<step>').
    """
    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    ae = VAE(input_shape=[None] + crop_shape,
             convolutional=convolutional,
             variational=variational,
             n_filters=n_filters,
             n_hidden=n_hidden,
             n_code=n_code,
             dropout=dropout,
             filter_sizes=filter_sizes,
             activation=activation)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['cost'])

    # We create a session to use the graph
    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Checkpoints are written with a global step, i.e. as
    # '<ckpt_name>-<step>', so the bare name usually does not exist.  Honour
    # it if present, otherwise resume from the newest matching checkpoint.
    restore_path = None
    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        restore_path = ckpt_name
    else:
        latest_ckpt = tf.train.latest_checkpoint(
            os.path.dirname(ckpt_name) or '.')
        if latest_ckpt and os.path.basename(latest_ckpt).startswith(
                os.path.basename(ckpt_name) + '-'):
            restore_path = latest_ckpt
    if restore_path is not None:
        saver.restore(sess, restore_path)

    # Fit all training data
    t_i = 0
    step_i = 0  # total batches trained; never reset
    batch_i = 0  # batches trained within the current epoch
    epoch_i = 0
    cost = 0
    n_files = len(files)
    # One epoch is n_files images, i.e. n_files // batch_size batches.  The
    # max() guards against a modulo by zero when there are fewer files than
    # one batch.
    batches_per_epoch = max(1, n_files // batch_size)
    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            step_i += 1
            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            train_cost = sess.run([ae['cost'], optimizer],
                                  feed_dict={
                                      ae['x']: batch_xs,
                                      ae['train']: True,
                                      ae['keep_prob']: keep_prob
                                  })[0]
            print(batch_i, train_cost)
            cost += train_cost
            if batch_i % batches_per_epoch == 0:
                print('epoch:', epoch_i)
                print('average cost:', cost / batch_i)
                cost = 0
                batch_i = 0
                epoch_i += 1

            # Image and checkpoint cadence follow the global step so they do
            # not restart whenever the per-epoch counter is reset.
            if step_i % img_step == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['z']: zs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                utils.montage(recon.reshape([-1] + crop_shape),
                              'manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['x']: test_xs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                print('reconstruction (min, max, mean):', recon.min(),
                      recon.max(), recon.mean())
                utils.montage(recon.reshape([-1] + crop_shape),
                              'reconstruction_%08d.png' % t_i)
                t_i += 1

            if step_i % save_step == 0:
                # Save the variables to disk.
                saver.save(sess,
                           ckpt_name,
                           global_step=step_i,
                           write_meta_graph=False)
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # Either training finished or something raised: tell all the
        # threads to shut down.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
Пример #5
0
def train_vae(files,
              input_shape,
              learning_rate=0.0001,
              batch_size=100,
              n_epochs=50,
              n_examples=10,
              crop_shape=[64, 64, 3],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a list of file paths to images, and this will do everything else.

    Parameters
    ----------
    files : list of strings
        List of paths to images.
    input_shape : list
        Must define what the input image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs.
    n_examples : int, optional
        Number of example to use while demonstrating the current training
        iteration's reconstruction.  Creates a square montage, so make
        sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ... 100.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often (in training steps) to save training images showing the
        manifold and reconstruction.
    save_step : int, optional
        How often (in training steps) to save checkpoints.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt', with the
        training step appended ('model.ckpt-<step>').
    """
    batch = create_input_pipeline(
        files=files,
        batch_size=batch_size,
        n_epochs=n_epochs,
        crop_shape=crop_shape,
        crop_factor=crop_factor,
        shape=input_shape)

    ae = VAE(
        input_shape=[None] + crop_shape,
        convolutional=convolutional,
        variational=variational,
        n_filters=n_filters,
        n_hidden=n_hidden,
        n_code=n_code,
        dropout=dropout,
        filter_sizes=filter_sizes,
        activation=activation)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(ae['cost'])

    # We create a session to use the graph
    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Checkpoints are written with a global step, i.e. as
    # '<ckpt_name>-<step>', so the bare name usually does not exist.  Honour
    # it if present, otherwise resume from the newest matching checkpoint.
    restore_path = None
    if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
        restore_path = ckpt_name
    else:
        latest_ckpt = tf.train.latest_checkpoint(
            os.path.dirname(ckpt_name) or '.')
        if latest_ckpt and os.path.basename(latest_ckpt).startswith(
                os.path.basename(ckpt_name) + '-'):
            restore_path = latest_ckpt
    if restore_path is not None:
        saver.restore(sess, restore_path)

    # Fit all training data
    t_i = 0
    step_i = 0  # total batches trained; never reset
    batch_i = 0  # batches trained within the current epoch
    epoch_i = 0
    cost = 0
    n_files = len(files)
    # One epoch is n_files images, i.e. n_files // batch_size batches.  The
    # max() guards against a modulo by zero when there are fewer files than
    # one batch.
    batches_per_epoch = max(1, n_files // batch_size)
    test_xs = sess.run(batch) / 255.0
    utils.montage(test_xs, 'test_xs.png')
    try:
        while not coord.should_stop() and epoch_i < n_epochs:
            step_i += 1
            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            train_cost = sess.run(
                [ae['cost'], optimizer],
                feed_dict={
                    ae['x']: batch_xs,
                    ae['train']: True,
                    ae['keep_prob']: keep_prob
                })[0]
            print(batch_i, train_cost)
            cost += train_cost
            if batch_i % batches_per_epoch == 0:
                print('epoch:', epoch_i)
                print('average cost:', cost / batch_i)
                cost = 0
                batch_i = 0
                epoch_i += 1

            # Image and checkpoint cadence follow the global step so they do
            # not restart whenever the per-epoch counter is reset.
            if step_i % img_step == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(
                    ae['y'],
                    feed_dict={
                        ae['z']: zs,
                        ae['train']: False,
                        ae['keep_prob']: 1.0
                    })
                utils.montage(
                    recon.reshape([-1] + crop_shape), 'manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(
                    ae['y'],
                    feed_dict={
                        ae['x']: test_xs,
                        ae['train']: False,
                        ae['keep_prob']: 1.0
                    })
                print('reconstruction (min, max, mean):',
                      recon.min(), recon.max(), recon.mean())
                utils.montage(
                    recon.reshape([-1] + crop_shape),
                    'reconstruction_%08d.png' % t_i)
                t_i += 1

            if step_i % save_step == 0:
                # Save the variables to disk.
                saver.save(
                    sess,
                    ckpt_name,
                    global_step=step_i,
                    write_meta_graph=False)
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # Either training finished or something raised: tell all the
        # threads to shut down.
        coord.request_stop()

    # Wait until all threads have finished.
    coord.join(threads)

    # Clean up the session.
    sess.close()
Пример #6
0
def train_input_pipeline(
        files,
        A,  # img_h
        B,  # img_w
        C,
        T=20,
        n_enc=512,
        n_z=256,
        n_dec=512,
        read_n=15,
        write_n=15,
        batch_size=64,
        n_epochs=1e9,
        input_shape=(64, 64, 3)):
    """Train the model built by `create_model` on images from `files`.

    Images are read through a threaded input pipeline, cropped to
    ``(A, B, C)``, scaled to [0, 1] and flattened before being fed to the
    model.  Gradients are clipped to a norm of 5 before being applied.
    Every 1000 batches the model's canvases for a fixed test batch are
    written out as an animated GIF.

    Parameters
    ----------
    files : list
        Input files handed to `create_input_pipeline`; its length is used to
        work out how many batches make up one epoch.
    A : int
        Image height (crop height and model dimension).
    B : int
        Image width (crop width and model dimension).
    C : int
        Number of channels (crop depth and model dimension).
    T : int, optional
        Passed to `create_model`.
    n_enc : int, optional
        Passed to `create_model`.
    n_z : int, optional
        Passed to `create_model`; also the width of the sampled noise.
    n_dec : int, optional
        Passed to `create_model`.
    read_n : int, optional
        Passed to `create_model`.
    write_n : int, optional
        Passed to `create_model`.
    batch_size : int, optional
        Number of images per batch.
    n_epochs : int, optional
        Number of epochs to train for.
    input_shape : tuple, optional
        Shape of the images on disk, passed to the input pipeline.
    """
    # We create a session to use the graph
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        batch = create_input_pipeline(
            files=files,
            batch_size=batch_size,
            n_epochs=n_epochs,
            crop_shape=(A, B, C),
            shape=input_shape)

        draw = create_model(
            A=A,
            B=B,
            C=C,
            T=T,
            batch_size=batch_size,
            n_enc=n_enc,
            n_z=n_z,
            n_dec=n_dec,
            read_n=read_n,
            write_n=write_n)
        opt = tf.train.AdamOptimizer(learning_rate=0.0001)
        grads = opt.compute_gradients(draw['cost'])
        # Clip each gradient's norm; variables without a gradient are left
        # as they are.
        for i, (grad, var) in enumerate(grads):
            if grad is not None:
                grads[i] = (tf.clip_by_norm(grad, 5), var)
        train_op = opt.apply_gradients(grads)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        tf.get_default_graph().finalize()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Fit all training data
        batch_i = 0
        epoch_i = 0
        n_files = len(files)
        # One epoch is n_files images, i.e. n_files // batch_size batches;
        # max() guards against a division by zero for very small inputs.
        batches_per_epoch = max(1, n_files // batch_size)
        test_xs = sess.run(batch).reshape((-1, A * B * C)) / 255.0
        utils.montage(test_xs.reshape((-1, A, B, C)), 'test_xs.png')
        try:
            while not coord.should_stop() and epoch_i < n_epochs:
                # Scale to [0, 1] exactly once, matching test_xs above.
                batch_xs = sess.run(batch) / 255.0
                noise = np.random.randn(batch_size, n_z)
                lx, lz = sess.run(
                    [draw['loss_x'], draw['loss_z'], train_op],
                    feed_dict={
                        draw['x']: batch_xs.reshape((-1, A * B * C)),
                        draw['noise']: noise
                    })[0:2]
                print('x:', lx, 'z:', lz)
                if batch_i % 1000 == 0:
                    # Plot example reconstructions
                    recon = sess.run(
                        draw['canvas'],
                        feed_dict={draw['x']: test_xs,
                                   draw['noise']: noise})
                    recon = [
                        utils.montage(r.reshape(-1, A, B, C)) for r in recon
                    ]
                    gif.build_gif(recon, saveto='manifold_%08d.gif' % batch_i)
                    plt.close('all')
                batch_i += 1
                # Derive the epoch from the number of batches completed so
                # that it only advances after a full epoch has been trained.
                epoch_i = batch_i // batches_per_epoch
        except tf.errors.OutOfRangeError:
            print('Done.')
        finally:
            # Either training finished or something raised: tell all the
            # threads to shut down.
            coord.request_stop()

        # Wait until all threads have finished.  The session itself is
        # closed by the enclosing `with` block.
        coord.join(threads)
Пример #7
0
def train_input_pipeline(files,
                         init_lr_g=1e-4,
                         init_lr_d=1e-4,
                         n_features=10,
                         n_latent=100,
                         n_epochs=1000000,
                         batch_size=200,
                         n_samples=15,
                         input_shape=[218, 178, 3],
                         crop_shape=[64, 64, 3],
                         crop_factor=0.8):
    """Train a GAN on images read through a threaded input pipeline.

    The discriminator is updated on every third step and the generator on
    the other two.  Each network's learning rate is rescaled on every step
    by the squared ratio of the two most recent losses, clamped to
    [1e-6, 1e-2].  Every 100 steps a montage of generated samples is written
    to 'imgs/' and a checkpoint is saved as './gan.ckpt-<step>'.

    Both losses carry an L2 penalty (scale 1e-6) over the respective
    network's variables when `tf.contrib.layers` is available.

    Parameters
    ----------
    files : list
        Input files handed to `create_input_pipeline`.
    init_lr_g : float, optional
        Base learning rate of the generator.
    init_lr_d : float, optional
        Base learning rate of the discriminator.
    n_features : int, optional
        Passed to `GAN`.
    n_latent : int, optional
        Passed to `GAN`; also the width of the sampled latent vectors.
    n_epochs : int, optional
        Number of epochs requested from the input pipeline; training stops
        when the pipeline is exhausted.
    batch_size : int, optional
        Number of images per batch.
    n_samples : int, optional
        Passed to `utils.make_latent_manifold` when building the latent
        samples used for the montage.
    input_shape : list, optional
        Shape of the images on disk, passed to the input pipeline.
    crop_shape : list, optional
        Shape the pipeline crops to; also the model's per-image input shape.
    crop_factor : float, optional
        Passed to the input pipeline.
    """

    with tf.Graph().as_default(), tf.Session() as sess:
        batch = create_input_pipeline(
            files=files,
            batch_size=batch_size,
            n_epochs=n_epochs,
            crop_shape=crop_shape,
            crop_factor=crop_factor,
            shape=input_shape)

        gan = GAN(
            input_shape=[None] + crop_shape,
            n_features=n_features,
            n_latent=n_latent,
            rgb=True,
            debug=False)

        vars_d = [
            v for v in tf.trainable_variables()
            if v.name.startswith('discriminator')
        ]
        print('Training discriminator variables:')
        for v in vars_d:
            print(v.name)

        vars_g = [
            v for v in tf.trainable_variables()
            if v.name.startswith('generator')
        ]
        print('Training generator variables:')
        for v in vars_g:
            print(v.name)

        zs = np.random.uniform(-1.0, 1.0, [4, n_latent]).astype(np.float32)
        zs = utils.make_latent_manifold(zs, n_samples)

        lr_g = tf.placeholder(tf.float32, shape=[], name='learning_rate_g')
        lr_d = tf.placeholder(tf.float32, shape=[], name='learning_rate_d')

        # `tf` is only an alias, so `from tf.contrib.layers import ...` can
        # never succeed; go through the attribute instead.  Fall back to no
        # regularization only when contrib is genuinely unavailable.
        try:
            d_reg = tf.contrib.layers.apply_regularization(
                tf.contrib.layers.l2_regularizer(1e-6), vars_d)
            g_reg = tf.contrib.layers.apply_regularization(
                tf.contrib.layers.l2_regularizer(1e-6), vars_g)
        except (ImportError, AttributeError):
            d_reg, g_reg = 0, 0

        opt_g = tf.train.AdamOptimizer(
            lr_g, name='Adam_g').minimize(
                gan['loss_G'] + g_reg, var_list=vars_g)
        opt_d = tf.train.AdamOptimizer(
            lr_d, name='Adam_d').minimize(
                gan['loss_D'] + d_reg, var_list=vars_d)

        saver = tf.train.Saver()
        sums = gan['sums']
        G_sum_op = tf.summary.merge([
            sums['G'], sums['loss_G'], sums['z'], sums['loss_D_fake'],
            sums['D_fake']
        ])
        D_sum_op = tf.summary.merge([
            sums['loss_D'], sums['loss_D_real'], sums['loss_D_fake'], sums['z'],
            sums['x'], sums['D_real'], sums['D_fake']
        ])
        # Pass the Graph itself; handing over a GraphDef is deprecated.
        writer = tf.summary.FileWriter("./logs", sess.graph)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        tf.get_default_graph().finalize()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Checkpoints are written with a global step, i.e. as
        # 'gan.ckpt-<step>', so the bare name usually does not exist.  Honour
        # it if present, otherwise resume from the newest matching one.
        restore_path = None
        if os.path.exists("gan.ckpt") or os.path.exists("gan.ckpt.index"):
            restore_path = "gan.ckpt"
        else:
            latest_ckpt = tf.train.latest_checkpoint('.')
            if latest_ckpt and os.path.basename(latest_ckpt).startswith(
                    'gan.ckpt-'):
                restore_path = latest_ckpt
        if restore_path is not None:
            saver.restore(sess, restore_path)
            print("GAN model restored.")

        # The sample montages below are written into 'imgs/'.
        if not os.path.isdir('imgs'):
            os.makedirs('imgs')

        # NOTE(review): this figure is not referenced again in this
        # function; kept in case utils.montage relies on an active figure.
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        step_i, t_i = 0, 0
        loss_d = 1
        loss_g = 1
        n_loss_d, total_loss_d = 1, 1
        n_loss_g, total_loss_g = 1, 1
        try:
            while not coord.should_stop():
                batch_xs = sess.run(batch)
                step_i += 1
                batch_zs = np.random.uniform(
                    -1.0, 1.0, [batch_size, n_latent]).astype(np.float32)

                # Give the currently weaker network (higher loss) the larger
                # learning rate, clamped to [1e-6, 1e-2].
                this_lr_g = min(1e-2,
                                max(1e-6, init_lr_g * (loss_g / loss_d)**2))
                this_lr_d = min(1e-2,
                                max(1e-6, init_lr_d * (loss_d / loss_g)**2))

                if step_i % 3 == 1:
                    loss_d, _, sum_d = sess.run(
                        [gan['loss_D'], opt_d, D_sum_op],
                        feed_dict={
                            gan['x']: batch_xs,
                            gan['z']: batch_zs,
                            lr_d: this_lr_d
                        })
                    total_loss_d += loss_d
                    n_loss_d += 1
                    writer.add_summary(sum_d, step_i)
                    print('%04d d* = lr: %0.08f, loss: %08.06f, \t' % (
                        step_i, this_lr_d, loss_d
                    ) + 'g  = lr: %0.08f, loss: %08.06f' % (this_lr_g, loss_g))
                else:
                    loss_g, _, sum_g = sess.run(
                        [gan['loss_G'], opt_g, G_sum_op],
                        feed_dict={gan['z']: batch_zs,
                                   lr_g: this_lr_g})
                    total_loss_g += loss_g
                    n_loss_g += 1
                    writer.add_summary(sum_g, step_i)
                    print('%04d d  = lr: %0.08f, loss: %08.06f, \t' % (
                        step_i, this_lr_d, loss_d
                    ) + 'g* = lr: %0.08f, loss: %08.06f' % (this_lr_g, loss_g))

                if step_i % 100 == 0:
                    samples = sess.run(gan['G'], feed_dict={gan['z']: zs})
                    utils.montage(
                        np.clip((samples + 1) * 127.5, 0, 255).astype(np.uint8),
                        'imgs/gan_%08d.png' % t_i)
                    t_i += 1

                    print('generator loss:', total_loss_g / n_loss_g)
                    print('discriminator loss:', total_loss_d / n_loss_d)

                    # Save the variables to disk.
                    save_path = saver.save(
                        sess,
                        "./gan.ckpt",
                        global_step=step_i,
                        write_meta_graph=False)
                    print("Model saved in file: %s" % save_path)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Either training finished or something raised: tell all the
            # threads to shut down and flush pending summaries to disk.
            coord.request_stop()
            writer.close()

        # Wait until all threads have finished.
        coord.join(threads)
Пример #8
0
def train_input_pipeline(files,
                         init_lr_g=1e-4,
                         init_lr_d=1e-4,
                         n_features=10,
                         n_latent=100,
                         n_epochs=1000000,
                         batch_size=200,
                         n_samples=15,
                         input_shape=[218, 178, 3],
                         crop_shape=[64, 64, 3],
                         crop_factor=0.8):
    """Train a GAN on batches produced by a queued input pipeline.

    The input pipeline and the GAN are built in a fresh graph.  Training then
    alternates updates: the discriminator is stepped when ``step_i % 3 == 1``
    and the generator otherwise.  Before every step each learning rate is
    rescaled by the squared ratio of the two most recent losses and clamped to
    ``[1e-6, 1e-2]``.  Every 100 steps a montage of generated samples is
    written to ``imgs/gan_%08d.png`` and the variables are checkpointed as
    ``./gan.ckpt-<step>``.  Summaries are written to ``./logs``.

    The loop ends when the coordinator requests a stop or the pipeline raises
    ``tf.errors.OutOfRangeError``.

    Parameters
    ----------
    files : list
        Forwarded to `create_input_pipeline` as `files`.
    init_lr_g : float, optional
        Base learning rate of the generator optimizer.
    init_lr_d : float, optional
        Base learning rate of the discriminator optimizer.
    n_features : int, optional
        Forwarded to `GAN` as `n_features`.
    n_latent : int, optional
        Length of each latent vector fed to ``gan['z']``.
    n_epochs : int, optional
        Forwarded to `create_input_pipeline` as `n_epochs`.
    batch_size : int, optional
        Number of examples per batch, both for the pipeline and for the
        sampled latent vectors.
    n_samples : int, optional
        Forwarded to `utils.make_latent_manifold` together with four random
        latent vectors to build the fixed set of preview latents.
    input_shape : list, optional
        Forwarded to `create_input_pipeline` as `shape`.
    crop_shape : list or tuple, optional
        Forwarded to `create_input_pipeline`; also defines the per-example
        input shape of the GAN.
    crop_factor : float, optional
        Forwarded to `create_input_pipeline` as `crop_factor`.
    """

    with tf.Graph().as_default(), tf.Session() as sess:
        batch = create_input_pipeline(files=files,
                                      batch_size=batch_size,
                                      n_epochs=n_epochs,
                                      crop_shape=crop_shape,
                                      crop_factor=crop_factor,
                                      shape=input_shape)

        # list() lets callers pass crop_shape as a tuple as well as a list.
        gan = GAN(input_shape=[None] + list(crop_shape),
                  n_features=n_features,
                  n_latent=n_latent,
                  rgb=True,
                  debug=False)

        # Split the trainable variables by the name prefix of each network so
        # every optimizer only touches its own side of the GAN.
        trainable = tf.trainable_variables()
        vars_d = [v for v in trainable if v.name.startswith('discriminator')]
        print('Training discriminator variables:')
        for v in vars_d:
            print(v.name)

        vars_g = [v for v in trainable if v.name.startswith('generator')]
        print('Training generator variables:')
        for v in vars_g:
            print(v.name)

        # Fixed latent vectors reused for every preview montage.
        zs = np.random.uniform(-1.0, 1.0, [4, n_latent]).astype(np.float32)
        zs = utils.make_latent_manifold(zs, n_samples)

        lr_g = tf.placeholder(tf.float32, shape=[], name='learning_rate_g')
        lr_d = tf.placeholder(tf.float32, shape=[], name='learning_rate_d')

        # `tf` is an import alias, not an importable package, so the former
        # `from tf.contrib.layers import ...` could never succeed and a bare
        # `except` silently turned the L2 penalty off.  Reach the functions
        # through the alias instead, and only fall back to "no penalty" when
        # contrib itself is missing.
        try:
            layers = tf.contrib.layers
            d_reg = layers.apply_regularization(
                layers.l2_regularizer(1e-6), vars_d)
            g_reg = layers.apply_regularization(
                layers.l2_regularizer(1e-6), vars_g)
        except (AttributeError, ImportError):
            d_reg, g_reg = 0, 0

        opt_g = tf.train.AdamOptimizer(lr_g, name='Adam_g').minimize(
            gan['loss_G'] + g_reg, var_list=vars_g)
        opt_d = tf.train.AdamOptimizer(lr_d, name='Adam_d').minimize(
            gan['loss_D'] + d_reg, var_list=vars_d)

        saver = tf.train.Saver()
        sums = gan['sums']
        G_sum_op = tf.summary.merge([
            sums['G'], sums['loss_G'], sums['z'], sums['loss_D_fake'],
            sums['D_fake']
        ])
        D_sum_op = tf.summary.merge([
            sums['loss_D'], sums['loss_D_real'], sums['loss_D_fake'],
            sums['z'], sums['x'], sums['D_real'], sums['D_fake']
        ])
        # Pass the Graph itself; handing a GraphDef to `graph` is deprecated.
        writer = tf.summary.FileWriter("./logs", sess.graph)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        tf.get_default_graph().finalize()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Checkpoints are saved with a global step, i.e. as
        # "gan.ckpt-<step>", so a bare "gan.ckpt" normally does not exist.
        # Resume from the newest checkpoint recorded in the working directory
        # when it belongs to this model; keep the old path test as fallback.
        latest_ckpt = tf.train.latest_checkpoint('./')
        if latest_ckpt is not None and os.path.basename(
                latest_ckpt).startswith('gan.ckpt'):
            saver.restore(sess, latest_ckpt)
            print("GAN model restored.")
        elif os.path.exists("gan.ckpt"):
            saver.restore(sess, "gan.ckpt")
            print("GAN model restored.")

        # NOTE(review): the figure opened here is not used in this function;
        # kept in case utils.montage draws on the current figure -- confirm.
        plt.subplots(1, 1, figsize=(10, 10))
        step_i, t_i = 0, 0
        loss_d = 1
        loss_g = 1
        # Counters start at 1 so the running averages never divide by zero.
        n_loss_d, total_loss_d = 1, 1
        n_loss_g, total_loss_g = 1, 1
        try:
            while not coord.should_stop():
                batch_xs = sess.run(batch)
                step_i += 1
                batch_zs = np.random.uniform(
                    -1.0, 1.0, [batch_size, n_latent]).astype(np.float32)

                # The network with the larger recent loss gets the larger
                # learning rate; both stay inside [1e-6, 1e-2].
                this_lr_g = min(1e-2,
                                max(1e-6,
                                    init_lr_g * (loss_g / loss_d)**2))
                this_lr_d = min(1e-2,
                                max(1e-6,
                                    init_lr_d * (loss_d / loss_g)**2))

                if step_i % 3 == 1:
                    # Discriminator step.
                    loss_d, _, sum_d = sess.run(
                        [gan['loss_D'], opt_d, D_sum_op],
                        feed_dict={
                            gan['x']: batch_xs,
                            gan['z']: batch_zs,
                            lr_d: this_lr_d
                        })
                    total_loss_d += loss_d
                    n_loss_d += 1
                    writer.add_summary(sum_d, step_i)
                    print('%04d d* = lr: %0.08f, loss: %08.06f, \t' %
                          (step_i, this_lr_d, loss_d) +
                          'g  = lr: %0.08f, loss: %08.06f' %
                          (this_lr_g, loss_g))
                else:
                    # Generator step.
                    loss_g, _, sum_g = sess.run(
                        [gan['loss_G'], opt_g, G_sum_op],
                        feed_dict={
                            gan['z']: batch_zs,
                            lr_g: this_lr_g
                        })
                    total_loss_g += loss_g
                    n_loss_g += 1
                    writer.add_summary(sum_g, step_i)
                    print('%04d d  = lr: %0.08f, loss: %08.06f, \t' %
                          (step_i, this_lr_d, loss_d) +
                          'g* = lr: %0.08f, loss: %08.06f' %
                          (this_lr_g, loss_g))

                if step_i % 100 == 0:
                    # Map generator output from [-1, 1] to uint8 pixels.
                    samples = sess.run(gan['G'], feed_dict={gan['z']: zs})
                    utils.montage(
                        np.clip((samples + 1) * 127.5, 0,
                                255).astype(np.uint8),
                        'imgs/gan_%08d.png' % t_i)
                    t_i += 1

                    print('generator loss:', total_loss_g / n_loss_g)
                    print('discriminator loss:', total_loss_d / n_loss_d)

                    # Save the variables to disk.
                    save_path = saver.save(sess,
                                           "./gan.ckpt",
                                           global_step=step_i,
                                           write_meta_graph=False)
                    print("Model saved in file: %s" % save_path)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Whatever ended the loop, tell the queue threads to shut down
            # and flush the pending summaries to disk.
            coord.request_stop()
            writer.close()

        # Wait until all threads have finished.
        coord.join(threads)
Пример #9
0
def train_tiny_imagenet():
    """Train the basic pixel-RNN model on the Tiny ImageNet files.

    Builds the model and an Adam optimizer on ``net['cost']``, streams
    batches from a threaded input pipeline, prints the cost of every batch
    and checkpoints the variables every 100 batches.

    NOTE(review): `B`, `n_epochs` and `ckpt_name` are not defined inside this
    function; presumably they are module-level names -- confirm.
    """
    net = build_pixel_rnn_basic_model()

    # build the optimizer (this will take a while!)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=0.001).minimize(net['cost'])

    # Load a list of files for tiny imagenet, downloading if necessary
    # NOTE(review): the whole return value is handed to the pipeline; verify
    # that tiny_imagenet_load() returns just the file list in this version.
    imagenet_files = dsu.tiny_imagenet_load()

    # Create a threaded image pipeline which will load/shuffle/crop/resize
    batch = dsu.create_input_pipeline(
        imagenet_files,
        batch_size=B,
        n_epochs=n_epochs,
        shape=[64, 64, 3],
        crop_shape=[32, 32, 3],
        crop_factor=0.5,
        n_threads=8)

    sess = tf.Session()
    try:
        saver = tf.train.Saver()
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # This will handle our threaded image pipeline
        coord = tf.train.Coordinator()

        # Ensure no more changes to graph
        tf.get_default_graph().finalize()

        # Start up the queues for handling the image pipeline
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Checkpoints are written as "<ckpt_name>-<step>", so the bare
        # ckpt_name normally does not exist.  Resume from the newest
        # checkpoint recorded next to ckpt_name, and never pass a missing
        # (None) path to restore(); the old path test stays as a fallback.
        ckpt_dir = os.path.dirname(ckpt_name) or './'
        latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
        if latest_ckpt is not None:
            saver.restore(sess, latest_ckpt)
        elif os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)

        # NOTE(review): epoch_i is never advanced, so `epoch_i < n_epochs`
        # does not end training by itself; the loop stops on the coordinator
        # or on OutOfRangeError.
        epoch_i = 0
        batch_i = 0
        save_step = 100
        try:
            while not coord.should_stop() and epoch_i < n_epochs:
                batch_i += 1
                batch_xs = sess.run(batch)
                train_cost = sess.run(
                    [net['cost'], optimizer],
                    feed_dict={net['X']: batch_xs})[0]
                print(batch_i, train_cost)
                if batch_i % save_step == 0:
                    # Save the variables to disk.  Don't write the meta graph
                    # since we can use the code to create it, and it takes a
                    # long time to create the graph since it is so deep
                    saver.save(
                        sess,
                        ckpt_name,
                        global_step=batch_i,
                        write_meta_graph=False)
        except tf.errors.OutOfRangeError:
            print('Done.')
        finally:
            # Whatever ended the loop, tell the queue threads to shut down.
            coord.request_stop()

        # Wait until all threads have finished.
        coord.join(threads)
    finally:
        # Clean up the session, even when training raised.
        sess.close()
Пример #10
0
def train_input_pipeline(
        files,
        A,  # img_h
        B,  # img_w
        C,
        T=20,
        n_enc=512,
        n_z=256,
        n_dec=512,
        read_n=15,
        write_n=15,
        batch_size=64,
        n_epochs=1e9,
        input_shape=(64, 64, 3)):
    """Train the model built by `create_model` on a queued image pipeline.

    Images are cropped to ``(A, B, C)``, scaled to ``[0, 1]`` by dividing by
    255 and flattened to ``A * B * C`` values per example.  The cost is
    minimized with Adam (learning rate 1e-4) with every gradient clipped to
    norm 5.  One batch is drawn up front, saved as ``test_xs.png`` and reused
    for the periodic reconstruction GIFs (``manifold_%08d.gif``).

    Parameters
    ----------
    files : list
        Forwarded to `create_input_pipeline` as `files`; its length is used
        by the epoch counter.
    A : int
        Crop height (first entry of `crop_shape`).
    B : int
        Crop width (second entry of `crop_shape`).
    C : int
        Third entry of `crop_shape`.
    T : int, optional
        Forwarded to `create_model` as `T`.
    n_enc : int, optional
        Forwarded to `create_model` as `n_enc`.
    n_z : int, optional
        Width of the noise matrix fed to ``draw['noise']``; also forwarded to
        `create_model`.
    n_dec : int, optional
        Forwarded to `create_model` as `n_dec`.
    read_n : int, optional
        Forwarded to `create_model` as `read_n`.
    write_n : int, optional
        Forwarded to `create_model` as `write_n`.
    batch_size : int, optional
        Number of examples per batch.
    n_epochs : float, optional
        Forwarded to `create_input_pipeline` and used as the upper bound of
        the epoch counter.
    input_shape : tuple, optional
        Forwarded to `create_input_pipeline` as `shape`.
    """

    # We create a session to use the graph
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        batch = create_input_pipeline(
            files=files,
            batch_size=batch_size,
            n_epochs=n_epochs,
            crop_shape=(A, B, C),
            shape=input_shape)

        draw = create_model(
            A=A,
            B=B,
            C=C,
            T=T,
            batch_size=batch_size,
            n_enc=n_enc,
            n_z=n_z,
            n_dec=n_dec,
            read_n=read_n,
            write_n=write_n)
        opt = tf.train.AdamOptimizer(learning_rate=0.0001)
        # Clip every available gradient to norm 5; variables without a
        # gradient are passed through untouched.
        grads = [(tf.clip_by_norm(grad, 5), var) if grad is not None else
                 (grad, var)
                 for grad, var in opt.compute_gradients(draw['cost'])]
        train_op = opt.apply_gradients(grads)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        tf.get_default_graph().finalize()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Fit all training data
        batch_i = 0
        epoch_i = 0
        n_files = len(files)
        # Held-out batch reused for every reconstruction plot.
        test_xs = sess.run(batch).reshape((-1, A * B * C)) / 255.0
        utils.montage(test_xs.reshape((-1, A, B, C)), 'test_xs.png')
        try:
            while not coord.should_stop() and epoch_i < n_epochs:
                # Scale to [0, 1] exactly once, matching test_xs.  The batch
                # used to be divided by 255 a second time when it was fed,
                # so the model trained on inputs 255x smaller than the ones
                # it was later asked to reconstruct.
                batch_xs = sess.run(batch) / 255.0
                noise = np.random.randn(batch_size, n_z)
                lx, lz = sess.run(
                    [draw['loss_x'], draw['loss_z'], train_op],
                    feed_dict={
                        draw['x']: batch_xs.reshape((-1, A * B * C)),
                        draw['noise']: noise
                    })[0:2]
                print('x:', lx, 'z:', lz)
                # NOTE(review): batch_i counts batches but is compared with
                # the number of files, so epoch_i advances once per n_files
                # batches (and on the very first batch).  Left unchanged;
                # confirm the intended epoch length.
                if batch_i % n_files == 0:
                    batch_i = 0
                    epoch_i += 1
                if batch_i % 1000 == 0:
                    # Plot example reconstructions
                    recon = sess.run(
                        draw['canvas'],
                        feed_dict={draw['x']: test_xs,
                                   draw['noise']: noise})
                    recon = [
                        utils.montage(r.reshape(-1, A, B, C)) for r in recon
                    ]
                    gif.build_gif(recon, saveto='manifold_%08d.gif' % batch_i)
                    plt.close('all')
                batch_i += 1
        except tf.errors.OutOfRangeError:
            print('Done.')
        finally:
            # Whatever ended the loop, tell the queue threads to shut down.
            coord.request_stop()

        # Wait until all threads have finished.  The session itself is
        # closed by the enclosing `with` block.
        coord.join(threads)
Пример #11
0
def train_tiny_imagenet(ckpt_path='pixelcnn',
                        n_epochs=1000,
                        save_step=100,
                        write_step=25,
                        B=32,
                        H=64,
                        W=64,
                        C=3):
    """Train the gated PixelCNN model on the Tiny ImageNet files.

    Builds the model and an Adam optimizer on ``net['cost']`` in a fresh
    graph, streams batches from a threaded input pipeline and prints the cost
    of every batch.  Summaries and checkpoints are both written under
    `ckpt_path`.

    Parameters
    ----------
    ckpt_path : str, optional
        Directory receiving the summaries and the ``pixelcnn.ckpt-<step>``
        checkpoints; also searched for a checkpoint to resume from.
    n_epochs : int, optional
        Forwarded to the input pipeline as `n_epochs`.
    save_step : int, optional
        A checkpoint is written every `save_step` batches.
    write_step : int, optional
        Summaries are written every `write_step` batches.
    B : int, optional
        Batch size, for both the model and the input pipeline.
    H : int, optional
        Crop height fed to the model.
    W : int, optional
        Crop width fed to the model.
    C : int, optional
        Third entry of the crop shape fed to the model.
    """
    ckpt_name = os.path.join(ckpt_path, 'pixelcnn.ckpt')

    with tf.Graph().as_default(), tf.Session() as sess:
        # Not actually conditioning on anything here just using the gated cnn model
        net = build_conditional_pixel_cnn_model(B=B, H=H, W=W, C=C)

        # build the optimizer (this will take a while!)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(
            net['cost'])

        # Load a list of files for tiny imagenet, downloading if necessary
        imagenet_files = dsu.tiny_imagenet_load()

        # Create a threaded image pipeline which will load/shuffle/crop/resize
        batch = dsu.create_input_pipeline(imagenet_files[0],
                                          batch_size=B,
                                          n_epochs=n_epochs,
                                          shape=[64, 64, 3],
                                          crop_shape=[H, W, C],
                                          crop_factor=1.0,
                                          n_threads=8)

        saver = tf.train.Saver()
        writer = tf.summary.FileWriter(ckpt_path)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # This will handle our threaded image pipeline
        coord = tf.train.Coordinator()

        # Ensure no more changes to graph
        tf.get_default_graph().finalize()

        # Start up the queues for handling the image pipeline
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Checkpoints are written as "<ckpt_name>-<step>", so the bare
        # ckpt_name normally does not exist.  Resume from the newest
        # checkpoint recorded in ckpt_path, and never pass a missing (None)
        # path to restore(); the old path test stays as a fallback.
        latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
        if latest_ckpt is not None:
            saver.restore(sess, latest_ckpt)
        elif os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)

        # NOTE(review): epoch_i is never advanced, so `epoch_i < n_epochs`
        # does not end training by itself; the loop stops on the coordinator
        # or on OutOfRangeError.
        epoch_i = 0
        batch_i = 0
        try:
            while not coord.should_stop() and epoch_i < n_epochs:
                batch_i += 1
                batch_xs = sess.run(batch)
                train_cost = sess.run([net['cost'], optimizer],
                                      feed_dict={net['X']: batch_xs})[0]

                print(batch_i, train_cost)
                if batch_i % write_step == 0:
                    summary = sess.run(net['summaries'],
                                       feed_dict={net['X']: batch_xs})
                    writer.add_summary(summary, batch_i)

                if batch_i % save_step == 0:
                    # Save the variables to disk.
                    # NOTE(review): the meta graph is written with every
                    # checkpoint (write_meta_graph=True); confirm this is
                    # wanted, since the graph can be rebuilt from code.
                    saver.save(sess,
                               ckpt_name,
                               global_step=batch_i,
                               write_meta_graph=True)
        except tf.errors.OutOfRangeError:
            print('Done.')
        finally:
            # Whatever ended the loop, tell the queue threads to shut down
            # and flush the pending summaries to disk.
            coord.request_stop()
            writer.close()

        # Wait until all threads have finished.
        coord.join(threads)
Пример #12
0
def train_vaegan(files,
                 learning_rate=0.00001,
                 batch_size=64,
                 n_epochs=250,
                 n_examples=10,
                 input_shape=[218, 178, 3],
                 crop_shape=[64, 64, 3],
                 crop_factor=0.8,
                 n_filters=[100, 100, 100, 100],
                 n_hidden=None,
                 n_code=128,
                 convolutional=True,
                 variational=True,
                 filter_sizes=[3, 3, 3, 3],
                 activation=tf.nn.elu,
                 ckpt_name="vaegan.ckpt"):
    """Train a VAE/GAN on batches produced by a queued input pipeline.

    Three Adam optimizers are built, one each for the variables whose names
    start with ``encoder``, ``generator`` and ``discriminator``.  The encoder
    is updated on every batch; the generator and discriminator updates are
    switched on or off per batch depending on how far the real/fake costs are
    from `equilibrium` (see the loop).  Every 50 batches montages are written
    to ``imgs/manifold_%08d.png`` and ``imgs/reconstruction_%08d.png``; every
    100 batches the variables are checkpointed as ``<ckpt_name>-<batch_i>``.

    Parameters
    ----------
    files : list
        Forwarded to `create_input_pipeline` as `files`; its length sets the
        number of batches per epoch.
    learning_rate : float, optional
        Learning rate shared by the three Adam optimizers.
    batch_size : int, optional
        Number of examples per batch.
    n_epochs : int, optional
        Upper bound of the epoch counter; also forwarded to the pipeline.
    n_examples : int, optional
        Forwarded to `utils.make_latent_manifold` together with four random
        codes to build the fixed set of preview latents.
    input_shape : list, optional
        Forwarded to `create_input_pipeline` as `shape`.
    crop_shape : list or tuple, optional
        Forwarded to `create_input_pipeline`; also defines the per-example
        input shape of the model and of the saved montages.
    crop_factor : float, optional
        Forwarded to `create_input_pipeline` as `crop_factor`.
    n_filters : list, optional
        Forwarded to `VAEGAN` as `n_filters`.
    n_hidden : int, optional
        Forwarded to `VAEGAN` as `n_hidden`.
    n_code : int, optional
        Length of each sampled latent code; also forwarded to `VAEGAN`.
    convolutional : bool, optional
        Forwarded to `VAEGAN` as `convolutional`.
    variational : bool, optional
        Forwarded to `VAEGAN` as `variational`.
    filter_sizes : list, optional
        Forwarded to `VAEGAN` as `filter_sizes`.
    activation : callable, optional
        Forwarded to `VAEGAN` as `activation`.
    ckpt_name : str, optional
        Checkpoint path prefix used for saving and for resuming.
    """

    # list() lets callers pass crop_shape as a tuple as well as a list.
    crop_dims = list(crop_shape)

    ae = VAEGAN(
        input_shape=[None] + crop_dims,
        convolutional=convolutional,
        variational=variational,
        n_filters=n_filters,
        n_hidden=n_hidden,
        n_code=n_code,
        filter_sizes=filter_sizes,
        activation=activation)

    batch = create_input_pipeline(
        files=files,
        batch_size=batch_size,
        n_epochs=n_epochs,
        crop_shape=crop_shape,
        crop_factor=crop_factor,
        shape=input_shape)

    # Fixed latent codes reused for every manifold montage.
    zs = np.random.randn(4, n_code).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    opt_enc = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_enc'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('encoder')
        ])

    opt_gen = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_gen'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('generator')
        ])

    opt_dis = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_dis'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('discriminator')
        ])

    sess = tf.Session()
    try:
        saver = tf.train.Saver()
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        tf.get_default_graph().finalize()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Checkpoints are written as "<ckpt_name>-<step>", so the bare
        # ckpt_name normally does not exist.  Resume from the newest
        # checkpoint recorded next to ckpt_name when it belongs to this
        # model; the old path test stays as a fallback.
        ckpt_dir = os.path.dirname(ckpt_name) or './'
        latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
        if latest_ckpt is not None and os.path.basename(
                latest_ckpt).startswith(os.path.basename(ckpt_name)):
            saver.restore(sess, latest_ckpt)
            print("VAE model restored.")
        elif os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)
            print("VAE model restored.")

        t_i = 0
        batch_i = 0
        epoch_i = 0

        # 0.693 is approximately ln(2).
        # NOTE(review): presumably the cost of a discriminator that cannot
        # tell real from fake -- confirm.
        equilibrium = 0.693
        margin = 0.4

        n_files = len(files)
        # With fewer files than batch_size the floor division is 0 and the
        # modulo below used to raise ZeroDivisionError; treat that case as
        # one batch per epoch.
        batches_per_epoch = max(1, n_files // batch_size)

        # Held-out batch reused for every reconstruction montage.
        test_xs = sess.run(batch) / 255.0
        utils.montage(test_xs, 'test_xs.png')
        try:
            while not coord.should_stop() and epoch_i < n_epochs:
                if batch_i % batches_per_epoch == 0:
                    batch_i = 0
                    epoch_i += 1
                    print('---------- EPOCH:', epoch_i)

                batch_i += 1
                batch_xs = sess.run(batch) / 255.0
                batch_zs = np.random.randn(batch_size,
                                           n_code).astype(np.float32)
                # The encoder is updated on every batch.
                real_cost, fake_cost, _ = sess.run(
                    [ae['loss_real'], ae['loss_fake'], opt_enc],
                    feed_dict={ae['x']: batch_xs,
                               ae['gamma']: 0.5})
                real_cost = -np.mean(real_cost)
                fake_cost = -np.mean(fake_cost)
                print('real:', real_cost, '/ fake:', fake_cost)

                # Skip the generator when either cost is well above
                # equilibrium and the discriminator when either is well
                # below; if that would skip both, update both.
                gen_update = not (real_cost > (equilibrium + margin) or
                                  fake_cost > (equilibrium + margin))
                dis_update = not (real_cost < (equilibrium - margin) or
                                  fake_cost < (equilibrium - margin))
                if not (gen_update or dis_update):
                    gen_update = True
                    dis_update = True

                if gen_update:
                    sess.run(
                        opt_gen,
                        feed_dict={
                            ae['x']: batch_xs,
                            ae['z_samp']: batch_zs,
                            ae['gamma']: 0.5
                        })
                if dis_update:
                    sess.run(
                        opt_dis,
                        feed_dict={
                            ae['x']: batch_xs,
                            ae['z_samp']: batch_zs,
                            ae['gamma']: 0.5
                        })

                if batch_i % 50 == 0:

                    # Plot example reconstructions from latent layer
                    recon = sess.run(ae['x_tilde'], feed_dict={ae['z']: zs})
                    print('recon:', recon.min(), recon.max())
                    recon = np.clip(recon / recon.max(), 0, 1)
                    utils.montage(
                        recon.reshape([-1] + crop_dims),
                        'imgs/manifold_%08d.png' % t_i)

                    # Plot example reconstructions
                    recon = sess.run(
                        ae['x_tilde'], feed_dict={ae['x']: test_xs})
                    print('recon:', recon.min(), recon.max())
                    recon = np.clip(recon / recon.max(), 0, 1)
                    utils.montage(
                        recon.reshape([-1] + crop_dims),
                        'imgs/reconstruction_%08d.png' % t_i)
                    t_i += 1

                if batch_i % 100 == 0:
                    # Save the variables to disk.
                    save_path = saver.save(
                        sess,
                        ckpt_name,
                        global_step=batch_i,
                        write_meta_graph=False)
                    print("Model saved in file: %s" % save_path)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Whatever ended the loop, tell the queue threads to shut down.
            coord.request_stop()

        # Wait until all threads have finished.
        coord.join(threads)
    finally:
        # Clean up the session, even when training raised.
        sess.close()