Exemplo n.º 1
0
def train_vaegan(files,
                 learning_rate=0.00001,
                 batch_size=64,
                 n_epochs=250,
                 n_examples=10,
                 input_shape=[218, 178, 3],
                 crop_shape=[64, 64, 3],
                 crop_factor=0.8,
                 n_filters=[100, 100, 100, 100],
                 n_hidden=None,
                 n_code=128,
                 convolutional=True,
                 variational=True,
                 filter_sizes=[3, 3, 3, 3],
                 activation=tf.nn.elu,
                 ckpt_name="vaegan.ckpt"):
    """Train a VAEGAN from a queue-based image input pipeline.

    Builds the model and the input pipeline, creates one Adam optimizer per
    sub-network (variables whose names start with 'encoder', 'generator' and
    'discriminator'), and then loops over batches.  The encoder is updated
    on every batch; the generator and discriminator are only updated when
    the discriminator's real/fake costs are within `margin` of
    `equilibrium`.  Montage images are written under ``imgs/`` every 50
    batches and a checkpoint is written every 100 batches.

    Parameters
    ----------
    files : list
        Forwarded to `create_input_pipeline`.  ``len(files)`` is used to
        work out how many batches make up one epoch.
    learning_rate : float, optional
        Learning rate given to each of the three Adam optimizers.
    batch_size : int, optional
        Batch size of the input pipeline, and the number of random latent
        vectors drawn per training step.
    n_epochs : int, optional
        Forwarded to `create_input_pipeline`; also the upper bound on the
        epoch counter of the training loop.
    n_examples : int, optional
        Forwarded to `utils.make_latent_manifold` when building the latent
        vectors used for the manifold montage.
    input_shape : list, optional
        Forwarded to `create_input_pipeline` as `shape`.
    crop_shape : list, optional
        Forwarded to `create_input_pipeline`.  The model input shape is
        ``[None] + crop_shape`` and montages are reshaped to
        ``[-1] + crop_shape``, so this must be a list.
    crop_factor : float, optional
        Forwarded to `create_input_pipeline`.
    n_filters : list, optional
        Forwarded to `VAEGAN`.
    n_hidden : int, optional
        Forwarded to `VAEGAN`.
    n_code : int, optional
        Forwarded to `VAEGAN`; also the width of the sampled latent vectors.
    convolutional : bool, optional
        Forwarded to `VAEGAN`.
    variational : bool, optional
        Forwarded to `VAEGAN`.
    filter_sizes : list, optional
        Forwarded to `VAEGAN`.
    activation : function, optional
        Forwarded to `VAEGAN`.
    ckpt_name : str, optional
        Path used both for restoring and for saving checkpoints.

    Returns
    -------
    None
    """

    ae = VAEGAN(input_shape=[None] + crop_shape,
                convolutional=convolutional,
                variational=variational,
                n_filters=n_filters,
                n_hidden=n_hidden,
                n_code=n_code,
                filter_sizes=filter_sizes,
                activation=activation)

    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    # Fixed latent vectors reused for every manifold montage.
    zs = np.random.randn(4, n_code).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    # One optimizer per sub-network, each restricted to the variables whose
    # name carries that sub-network's prefix.
    opt_enc = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_enc'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('encoder')
        ])

    opt_gen = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_gen'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('generator')
        ])

    opt_dis = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_dis'],
        var_list=[
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith('discriminator')
        ])

    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    # Ensure no more changes to the graph.
    tf.get_default_graph().finalize()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    t_i = 0
    batch_i = 0
    epoch_i = 0

    # 0.693 ~= ln(2).
    # NOTE(review): presumably the loss of a discriminator guessing at
    # chance -- confirm against the VAEGAN loss definition.
    equilibrium = 0.693
    margin = 0.4

    # Clamp to at least one batch per epoch so that having fewer files than
    # `batch_size` does not make the modulo below divide by zero.
    batches_per_epoch = max(len(files) // batch_size, 1)

    try:
        # NOTE(review): saver.save below passes global_step, which
        # TensorFlow appends to the file name, so this check on the bare
        # `ckpt_name` only matches checkpoints written without a step
        # suffix.
        if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)
            print("VAE model restored.")

        # The montages below are written under imgs/; whether utils.montage
        # creates missing directories is not visible here, so make sure it
        # exists.
        if not os.path.exists('imgs'):
            os.makedirs('imgs')

        test_xs = sess.run(batch) / 255.0
        utils.montage(test_xs, 'test_xs.png')

        while not coord.should_stop() and epoch_i < n_epochs:
            if batch_i % batches_per_epoch == 0:
                batch_i = 0
                epoch_i += 1
                print('---------- EPOCH:', epoch_i)

            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            batch_zs = np.random.randn(batch_size, n_code).astype(np.float32)
            real_cost, fake_cost, _ = sess.run(
                [ae['loss_real'], ae['loss_fake'], opt_enc],
                feed_dict={
                    ae['x']: batch_xs,
                    ae['gamma']: 0.5
                })
            real_cost = -np.mean(real_cost)
            fake_cost = -np.mean(fake_cost)
            print('real:', real_cost, '/ fake:', fake_cost)

            # Skip the generator update when either cost is above the
            # equilibrium band, and the discriminator update when either is
            # below it.
            upper = equilibrium + margin
            lower = equilibrium - margin
            gen_update = not (real_cost > upper or fake_cost > upper)
            dis_update = not (real_cost < lower or fake_cost < lower)

            # Never skip both updates in the same step.
            if not (gen_update or dis_update):
                gen_update = True
                dis_update = True

            if gen_update:
                sess.run(opt_gen,
                         feed_dict={
                             ae['x']: batch_xs,
                             ae['z_samp']: batch_zs,
                             ae['gamma']: 0.5
                         })
            if dis_update:
                sess.run(opt_dis,
                         feed_dict={
                             ae['x']: batch_xs,
                             ae['z_samp']: batch_zs,
                             ae['gamma']: 0.5
                         })

            if batch_i % 50 == 0:

                # Plot example reconstructions from latent layer
                recon = sess.run(ae['x_tilde'], feed_dict={ae['z']: zs})
                print('recon:', recon.min(), recon.max())
                recon = np.clip(recon / recon.max(), 0, 1)
                utils.montage(recon.reshape([-1] + crop_shape),
                              'imgs/manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['x_tilde'], feed_dict={ae['x']: test_xs})
                print('recon:', recon.min(), recon.max())
                recon = np.clip(recon / recon.max(), 0, 1)
                utils.montage(recon.reshape([-1] + crop_shape),
                              'imgs/reconstruction_%08d.png' % t_i)
                t_i += 1

            if batch_i % 100 == 0:
                # Save the variables to disk.
                save_path = saver.save(sess,
                                       ckpt_name,
                                       global_step=batch_i,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # Runs on normal completion and on any error: ask the queue threads
        # to stop, wait for them, and always release the session.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()
Exemplo n.º 2
0
def train_vae(files,
              input_shape,
              learning_rate=0.0001,
              batch_size=100,
              n_epochs=50,
              n_examples=10,
              crop_shape=[64, 64, 3],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a list of file paths to images, and this will do everything else.

    Parameters
    ----------
    files : list of strings
        List of paths to images.
    input_shape : list
        Must define what the input image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs.
    n_examples : int, optional
        Number of example to use while demonstrating the current training
        iteration's reconstruction.  Creates a square montage, so make
        sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ... 100.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often to save training images showing the manifold and
        reconstruction.
    save_step : int, optional
        How often to save checkpoints.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt'

    Returns
    -------
    None

    Notes
    -----
    An epoch is counted as ``len(files) // batch_size`` batches (at least
    one).  The batch counter is reset to zero at every epoch boundary, so
    the image and checkpoint steps also trigger at each boundary.
    """
    batch = create_input_pipeline(
        files=files,
        batch_size=batch_size,
        n_epochs=n_epochs,
        crop_shape=crop_shape,
        crop_factor=crop_factor,
        shape=input_shape)

    ae = VAE(
        input_shape=[None] + crop_shape,
        convolutional=convolutional,
        variational=variational,
        n_filters=n_filters,
        n_hidden=n_hidden,
        n_code=n_code,
        dropout=dropout,
        filter_sizes=filter_sizes,
        activation=activation)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(ae['cost'])

    # We create a session to use the graph
    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Fit all training data
    t_i = 0
    batch_i = 0
    epoch_i = 0
    cost = 0

    # `batch_i` counts batches, not files, so an epoch is
    # len(files) // batch_size batches.  Clamp to 1 so that a file list
    # shorter than one batch cannot make the modulo divide by zero.
    batches_per_epoch = max(len(files) // batch_size, 1)

    try:
        # NOTE(review): saver.save below passes global_step, which
        # TensorFlow appends to the file name, so this check on the bare
        # `ckpt_name` only matches checkpoints written without a step
        # suffix.
        if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)

        test_xs = sess.run(batch) / 255.0
        utils.montage(test_xs, 'test_xs.png')

        while not coord.should_stop() and epoch_i < n_epochs:
            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            train_cost = sess.run(
                [ae['cost'], optimizer],
                feed_dict={
                    ae['x']: batch_xs,
                    ae['train']: True,
                    ae['keep_prob']: keep_prob
                })[0]
            print(batch_i, train_cost)
            cost += train_cost
            if batch_i % batches_per_epoch == 0:
                print('epoch:', epoch_i)
                print('average cost:', cost / batch_i)
                cost = 0
                batch_i = 0
                epoch_i += 1

            if batch_i % img_step == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(
                    ae['y'],
                    feed_dict={
                        ae['z']: zs,
                        ae['train']: False,
                        ae['keep_prob']: 1.0
                    })
                utils.montage(
                    recon.reshape([-1] + crop_shape), 'manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(
                    ae['y'],
                    feed_dict={
                        ae['x']: test_xs,
                        ae['train']: False,
                        ae['keep_prob']: 1.0
                    })
                print('reconstruction (min, max, mean):',
                      recon.min(), recon.max(), recon.mean())
                utils.montage(
                    recon.reshape([-1] + crop_shape),
                    'reconstruction_%08d.png' % t_i)
                t_i += 1

            if batch_i % save_step == 0:
                # Save the variables to disk.
                saver.save(
                    sess,
                    ckpt_name,
                    global_step=batch_i,
                    write_meta_graph=False)
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # Runs on normal completion and on any error: ask the queue threads
        # to stop, wait for them, and always release the session.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()
Exemplo n.º 3
0
def test_mnist():
    """Train an autoencoder on MNIST.

    This function will train an autoencoder on MNIST and also
    save many image files during the training process, demonstrating
    the latent space of the inner most dimension of the encoder,
    as well as reconstructions of the decoder.

    After every epoch the mean training and validation cost are printed.
    The session is closed when the function returns or raises.
    """

    # load MNIST
    n_code = 2
    mnist = MNIST(split=[0.8, 0.1, 0.1])
    ae = VAE(
        input_shape=[None, 784],
        n_filters=[512, 256],
        n_hidden=64,
        n_code=n_code,
        activation=tf.nn.sigmoid,
        convolutional=False,
        variational=True)

    # Fixed latent vectors reused for every manifold montage.
    n_examples = 100
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    learning_rate = 0.02
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(ae['cost'])

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Fit all training data
    t_i = 0
    batch_i = 0
    batch_size = 200
    n_epochs = 10
    test_xs = mnist.test.images[:n_examples]
    utils.montage(test_xs.reshape((-1, 28, 28)), 'test_xs.png')

    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
        sess.run(init_op)

        for epoch_i in range(n_epochs):
            train_i = 0
            train_cost = 0
            for batch_xs, _ in mnist.train.next_batch(batch_size):
                train_cost += sess.run(
                    [ae['cost'], optimizer],
                    feed_dict={
                        ae['x']: batch_xs,
                        ae['train']: True,
                        ae['keep_prob']: 1.0
                    })[0]
                train_i += 1
                # `batch_i` keeps counting across epochs, so images are
                # written every 10 training batches overall.
                if batch_i % 10 == 0:
                    # Plot example reconstructions from latent layer
                    recon = sess.run(
                        ae['y'],
                        feed_dict={
                            ae['z']: zs,
                            ae['train']: False,
                            ae['keep_prob']: 1.0
                        })
                    utils.montage(
                        recon.reshape((-1, 28, 28)),
                        'manifold_%08d.png' % t_i)
                    # Plot example reconstructions
                    recon = sess.run(
                        ae['y'],
                        feed_dict={
                            ae['x']: test_xs,
                            ae['train']: False,
                            ae['keep_prob']: 1.0
                        })
                    utils.montage(
                        recon.reshape((-1, 28, 28)),
                        'reconstruction_%08d.png' % t_i)
                    t_i += 1
                batch_i += 1

            valid_i = 0
            valid_cost = 0
            for batch_xs, _ in mnist.valid.next_batch(batch_size):
                valid_cost += sess.run(
                    [ae['cost']],
                    feed_dict={
                        ae['x']: batch_xs,
                        ae['train']: False,
                        ae['keep_prob']: 1.0
                    })[0]
                valid_i += 1
            # max(..., 1) avoids a ZeroDivisionError when a split yields no
            # batches; the accumulated cost is 0 in that case.
            print('train:', train_cost / max(train_i, 1),
                  'valid:', valid_cost / max(valid_i, 1))
Exemplo n.º 4
0
def train_vae(files,
              input_shape,
              learning_rate=0.0001,
              batch_size=100,
              n_epochs=50,
              n_examples=10,
              crop_shape=[64, 64, 3],
              crop_factor=0.8,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt"):
    """General purpose training of a (Variational) (Convolutional) Autoencoder.

    Supply a list of file paths to images, and this will do everything else.

    Parameters
    ----------
    files : list of strings
        List of paths to images.
    input_shape : list
        Must define what the input image's shape is.
    learning_rate : float, optional
        Learning rate.
    batch_size : int, optional
        Batch size.
    n_epochs : int, optional
        Number of epochs.
    n_examples : int, optional
        Number of example to use while demonstrating the current training
        iteration's reconstruction.  Creates a square montage, so make
        sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ... 100.
    crop_shape : list, optional
        Size to centrally crop the image to.
    crop_factor : float, optional
        Resize factor to apply before cropping.
    n_filters : list, optional
        Same as VAE's n_filters.
    n_hidden : int, optional
        Same as VAE's n_hidden.
    n_code : int, optional
        Same as VAE's n_code.
    convolutional : bool, optional
        Use convolution or not.
    variational : bool, optional
        Use variational layer or not.
    filter_sizes : list, optional
        Same as VAE's filter_sizes.
    dropout : bool, optional
        Use dropout or not
    keep_prob : float, optional
        Percent of keep for dropout.
    activation : function, optional
        Which activation function to use.
    img_step : int, optional
        How often to save training images showing the manifold and
        reconstruction.
    save_step : int, optional
        How often to save checkpoints.
    ckpt_name : str, optional
        Checkpoints will be named as this, e.g. 'model.ckpt'

    Returns
    -------
    None

    Notes
    -----
    An epoch is counted as ``len(files) // batch_size`` batches (at least
    one).  The batch counter is reset to zero at every epoch boundary, so
    the image and checkpoint steps also trigger at each boundary.
    """
    batch = create_input_pipeline(files=files,
                                  batch_size=batch_size,
                                  n_epochs=n_epochs,
                                  crop_shape=crop_shape,
                                  crop_factor=crop_factor,
                                  shape=input_shape)

    ae = VAE(input_shape=[None] + crop_shape,
             convolutional=convolutional,
             variational=variational,
             n_filters=n_filters,
             n_hidden=n_hidden,
             n_code=n_code,
             dropout=dropout,
             filter_sizes=filter_sizes,
             activation=activation)

    # Create a manifold of our inner most layer to show
    # example reconstructions.  This is one way to see
    # what the "embedding" or "latent space" of the encoder
    # is capable of encoding, though note that this is just
    # a random hyperplane within the latent space, and does not
    # encompass all possible embeddings.
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['cost'])

    # We create a session to use the graph
    sess = tf.Session()
    saver = tf.train.Saver()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # This will handle our threaded image pipeline
    coord = tf.train.Coordinator()

    # Ensure no more changes to graph
    tf.get_default_graph().finalize()

    # Start up the queues for handling the image pipeline
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Fit all training data
    t_i = 0
    batch_i = 0
    epoch_i = 0
    cost = 0

    # `batch_i` counts batches, not files, so an epoch is
    # len(files) // batch_size batches.  Clamp to 1 so that a file list
    # shorter than one batch cannot make the modulo divide by zero.
    batches_per_epoch = max(len(files) // batch_size, 1)

    try:
        # NOTE(review): saver.save below passes global_step, which
        # TensorFlow appends to the file name, so this check on the bare
        # `ckpt_name` only matches checkpoints written without a step
        # suffix.
        if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)

        test_xs = sess.run(batch) / 255.0
        utils.montage(test_xs, 'test_xs.png')

        while not coord.should_stop() and epoch_i < n_epochs:
            batch_i += 1
            batch_xs = sess.run(batch) / 255.0
            train_cost = sess.run([ae['cost'], optimizer],
                                  feed_dict={
                                      ae['x']: batch_xs,
                                      ae['train']: True,
                                      ae['keep_prob']: keep_prob
                                  })[0]
            print(batch_i, train_cost)
            cost += train_cost
            if batch_i % batches_per_epoch == 0:
                print('epoch:', epoch_i)
                print('average cost:', cost / batch_i)
                cost = 0
                batch_i = 0
                epoch_i += 1

            if batch_i % img_step == 0:
                # Plot example reconstructions from latent layer
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['z']: zs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                utils.montage(recon.reshape([-1] + crop_shape),
                              'manifold_%08d.png' % t_i)

                # Plot example reconstructions
                recon = sess.run(ae['y'],
                                 feed_dict={
                                     ae['x']: test_xs,
                                     ae['train']: False,
                                     ae['keep_prob']: 1.0
                                 })
                print('reconstruction (min, max, mean):', recon.min(),
                      recon.max(), recon.mean())
                utils.montage(recon.reshape([-1] + crop_shape),
                              'reconstruction_%08d.png' % t_i)
                t_i += 1

            if batch_i % save_step == 0:
                # Save the variables to disk.
                saver.save(sess,
                           ckpt_name,
                           global_step=batch_i,
                           write_meta_graph=False)
    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        # Runs on normal completion and on any error: ask the queue threads
        # to stop, wait for them, and always release the session.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()
Exemplo n.º 5
0
def test_mnist():
    """Train an autoencoder on MNIST.

    This function will train an autoencoder on MNIST and also
    save many image files during the training process, demonstrating
    the latent space of the inner most dimension of the encoder,
    as well as reconstructions of the decoder.

    After every epoch the mean training and validation cost are printed.
    The session is closed when the function returns or raises.
    """

    # load MNIST
    n_code = 2
    mnist = MNIST(split=[0.8, 0.1, 0.1])
    ae = VAE(input_shape=[None, 784],
             n_filters=[512, 256],
             n_hidden=64,
             n_code=n_code,
             activation=tf.nn.sigmoid,
             convolutional=False,
             variational=True)

    # Fixed latent vectors reused for every manifold montage.
    n_examples = 100
    zs = np.random.uniform(-1.0, 1.0, [4, n_code]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    learning_rate = 0.02
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['cost'])

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Fit all training data
    t_i = 0
    batch_i = 0
    batch_size = 200
    n_epochs = 10
    test_xs = mnist.test.images[:n_examples]
    utils.montage(test_xs.reshape((-1, 28, 28)), 'test_xs.png')

    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
        sess.run(init_op)

        for epoch_i in range(n_epochs):
            train_i = 0
            train_cost = 0
            for batch_xs, _ in mnist.train.next_batch(batch_size):
                train_cost += sess.run([ae['cost'], optimizer],
                                       feed_dict={
                                           ae['x']: batch_xs,
                                           ae['train']: True,
                                           ae['keep_prob']: 1.0
                                       })[0]
                train_i += 1
                # `batch_i` keeps counting across epochs, so images are
                # written every 10 training batches overall.
                if batch_i % 10 == 0:
                    # Plot example reconstructions from latent layer
                    recon = sess.run(ae['y'],
                                     feed_dict={
                                         ae['z']: zs,
                                         ae['train']: False,
                                         ae['keep_prob']: 1.0
                                     })
                    utils.montage(recon.reshape((-1, 28, 28)),
                                  'manifold_%08d.png' % t_i)
                    # Plot example reconstructions
                    recon = sess.run(ae['y'],
                                     feed_dict={
                                         ae['x']: test_xs,
                                         ae['train']: False,
                                         ae['keep_prob']: 1.0
                                     })
                    utils.montage(recon.reshape((-1, 28, 28)),
                                  'reconstruction_%08d.png' % t_i)
                    t_i += 1
                batch_i += 1

            valid_i = 0
            valid_cost = 0
            for batch_xs, _ in mnist.valid.next_batch(batch_size):
                valid_cost += sess.run([ae['cost']],
                                       feed_dict={
                                           ae['x']: batch_xs,
                                           ae['train']: False,
                                           ae['keep_prob']: 1.0
                                       })[0]
                valid_i += 1
            # max(..., 1) avoids a ZeroDivisionError when a split yields no
            # batches; the accumulated cost is 0 in that case.
            print('train:', train_cost / max(train_i, 1),
                  'valid:', valid_cost / max(valid_i, 1))
Exemplo n.º 6
0
def train_dataset(ds,
                  A,
                  B,
                  C,
                  T=20,
                  n_enc=512,
                  n_z=200,
                  n_dec=512,
                  read_n=12,
                  write_n=12,
                  batch_size=100,
                  n_epochs=100):
    """Train the model built by `create_model` on a dataset.

    Builds the model in a fresh graph, trains it with Adam (learning rate
    0.0001) using gradients clipped to norm 5, and writes TensorBoard
    summaries to ``draw/train`` and ``draw/valid``.  Every 1000 training
    batches a GIF of the canvases for the first test images and a
    checkpoint (``./draw/draw.ckpt``) are saved.

    Parameters
    ----------
    ds : object or None
        Dataset exposing `train`, `valid` and `test` splits, each used via
        `next_batch(batch_size)` or `images`.  If None, `CIFAR10` is loaded
        with an 0.8/0.1/0.1 split and `A`, `B`, `C` are overridden with
        32, 32, 3.
    A, B, C : int
        Forwarded to `create_model` and used to reshape images and canvases
        to ``(-1, A, B, C)``.
        NOTE(review): presumably height, width and channels -- confirm
        against `create_model`.
    T, n_enc, n_z, n_dec, read_n, write_n : int, optional
        Forwarded to `create_model`.  `n_z` is also the width of the noise
        fed on every step.
    batch_size : int, optional
        Batch size for training and validation, and the number of test
        images used for the periodic reconstruction.
    n_epochs : int, optional
        Number of passes over the training split.

    Returns
    -------
    None
    """

    if ds is None:
        ds = CIFAR10(split=[0.8, 0.1, 0.1])
        A, B, C = (32, 32, 3)

    n_examples = batch_size
    # NOTE(review): `zs` is not used anywhere below; it is kept only so the
    # NumPy random stream consumed by this function stays the same.
    zs = np.random.uniform(-1.0, 1.0, [4, n_z]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    # We create a session to use the graph
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        draw = create_model(
            A=A,
            B=B,
            C=C,
            T=T,
            batch_size=batch_size,
            n_enc=n_enc,
            n_z=n_z,
            n_dec=n_dec,
            read_n=read_n,
            write_n=write_n)
        opt = tf.train.AdamOptimizer(learning_rate=0.0001)

        # Clip gradients; variables without a gradient are passed through
        # unchanged.
        grads = opt.compute_gradients(draw['cost'])
        for i, (grad, var) in enumerate(grads):
            if grad is not None:
                grads[i] = (tf.clip_by_norm(grad, 5), var)
        train_op = opt.apply_gradients(grads)

        # Add summary variables
        tf.summary.scalar(name='cost', tensor=draw['cost'])
        tf.summary.scalar(name='loss_z', tensor=draw['loss_z'])
        tf.summary.scalar(name='loss_x', tensor=draw['loss_x'])
        tf.summary.histogram(
            name='recon_t0_histogram', values=draw['canvas'][0])
        tf.summary.histogram(
            name='recon_t-1_histogram', values=draw['canvas'][-1])
        tf.summary.image(
            name='recon_t0_image',
            tensor=tf.reshape(draw['canvas'][0], (-1, A, B, C)),
            max_outputs=2)
        tf.summary.image(
            name='recon_t-1_image',
            tensor=tf.reshape(draw['canvas'][-1], (-1, A, B, C)),
            max_outputs=2)
        sums = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(logdir='draw/train')
        valid_writer = tf.summary.FileWriter(logdir='draw/valid')

        try:
            # Init
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            saver = tf.train.Saver()

            # Fit all training data
            batch_i = 0
            test_xs = ds.test.images[:n_examples] / 255.0
            utils.montage(test_xs.reshape((-1, A, B, C)), 'draw/test_xs.png')
            for epoch_i in range(n_epochs):
                for batch_xs, _ in ds.train.next_batch(batch_size):
                    noise = np.random.randn(batch_size, n_z)
                    # train_op is run for its side effect; only the cost
                    # and the merged summary are kept.
                    cost, summary = sess.run(
                        [draw['cost'], sums, train_op],
                        feed_dict={
                            draw['x']: batch_xs / 255.0,
                            draw['noise']: noise
                        })[0:2]
                    train_writer.add_summary(summary, batch_i)
                    print('train cost:', cost)

                    if batch_i % 1000 == 0:
                        # Plot example reconstructions
                        recon = sess.run(
                            draw['canvas'],
                            feed_dict={draw['x']: test_xs,
                                       draw['noise']: noise})
                        recon = [
                            utils.montage(r.reshape(-1, A, B, C))
                            for r in recon
                        ]
                        gif.build_gif(
                            recon,
                            cmap='gray',
                            saveto='draw/manifold_%08d.gif' % batch_i)
                        saver.save(
                            sess, './draw/draw.ckpt', global_step=batch_i)
                    batch_i += 1

                # Run validation
                # NOTE(review): this only runs when the step counter is an
                # exact multiple of 1000 at the end of an epoch, and the
                # validation batches advance the same counter -- confirm
                # this is intended.
                if batch_i % 1000 == 0:
                    for batch_xs, _ in ds.valid.next_batch(batch_size):
                        noise = np.random.randn(batch_size, n_z)
                        cost, summary = sess.run(
                            [draw['cost'], sums],
                            feed_dict={
                                draw['x']: batch_xs / 255.0,
                                draw['noise']: noise
                            })[0:2]
                        valid_writer.add_summary(summary, batch_i)
                        print('valid cost:', cost)
                        batch_i += 1
        finally:
            # Flush pending summaries and release the event files even if
            # training is interrupted.
            try:
                train_writer.close()
            finally:
                valid_writer.close()
Exemplo n.º 7
0
def train_input_pipeline(files,
                         init_lr_g=1e-4,
                         init_lr_d=1e-4,
                         n_features=10,
                         n_latent=100,
                         n_epochs=1000000,
                         batch_size=200,
                         n_samples=15,
                         input_shape=[218, 178, 3],
                         crop_shape=[64, 64, 3],
                         crop_factor=0.8):
    """Train a GAN on batches read from a queue-based input pipeline.

    A fresh graph and session are created, the input pipeline and the GAN are
    built inside them, and the two networks are then updated in alternation:
    the discriminator on every step where ``step % 3 == 1``, the generator on
    all other steps. Before each step both learning rates are recomputed
    from the ratio of the most recent generator and discriminator losses and
    clamped to ``[1e-6, 1e-2]``.

    Every 100 steps a montage of samples generated from a fixed set of
    latent codes is written to ``imgs/gan_<n>.png``, the running mean losses
    are printed, and a checkpoint is saved to ``./gan.ckpt``. Summaries are
    written to ``./logs``. Training ends when the input pipeline raises
    ``tf.errors.OutOfRangeError`` or the coordinator requests a stop.

    Parameters
    ----------
    files : list
        Passed unchanged to ``create_input_pipeline``; presumably the image
        files to train on (not verifiable from this function).
    init_lr_g : float, optional
        Base generator learning rate, scaled by ``(loss_g / loss_d) ** 2``.
    init_lr_d : float, optional
        Base discriminator learning rate, scaled by
        ``(loss_d / loss_g) ** 2``.
    n_features : int, optional
        Passed to ``GAN`` as ``n_features``.
    n_latent : int, optional
        Size of each latent vector fed to ``gan['z']``; also passed to
        ``GAN`` as ``n_latent``.
    n_epochs : int, optional
        Passed to ``create_input_pipeline`` as ``n_epochs``.
    batch_size : int, optional
        Number of latent vectors drawn per step; also passed to
        ``create_input_pipeline``.
    n_samples : int, optional
        Passed to ``utils.make_latent_manifold`` together with four random
        latent vectors to build the fixed codes used for the sample montage.
    input_shape : list, optional
        Passed to ``create_input_pipeline`` as ``shape``.
    crop_shape : list, optional
        Passed to ``create_input_pipeline``; the GAN input shape is
        ``[None] + crop_shape``.
    crop_factor : float, optional
        Passed to ``create_input_pipeline`` as ``crop_factor``.

    Returns
    -------
    None
    """

    with tf.Graph().as_default(), tf.Session() as sess:
        batch = create_input_pipeline(
            files=files,
            batch_size=batch_size,
            n_epochs=n_epochs,
            crop_shape=crop_shape,
            crop_factor=crop_factor,
            shape=input_shape)

        gan = GAN(
            input_shape=[None] + crop_shape,
            n_features=n_features,
            n_latent=n_latent,
            rgb=True,
            debug=False)

        # Split the trainable variables by the name prefix of each network.
        trainable = tf.trainable_variables()
        vars_d = [v for v in trainable if v.name.startswith('discriminator')]
        print('Training discriminator variables:')
        for v in vars_d:
            print(v.name)

        vars_g = [v for v in trainable if v.name.startswith('generator')]
        print('Training generator variables:')
        for v in vars_g:
            print(v.name)

        # Fixed latent codes so the periodic montages are comparable.
        zs = np.random.uniform(-1.0, 1.0, [4, n_latent]).astype(np.float32)
        zs = utils.make_latent_manifold(zs, n_samples)

        lr_g = tf.placeholder(tf.float32, shape=[], name='learning_rate_g')
        lr_d = tf.placeholder(tf.float32, shape=[], name='learning_rate_d')

        # `tf` is only an alias, so `from tf.contrib.layers import ...` can
        # never succeed; look the functions up as attributes instead and fall
        # back to no regularization only when `tf.contrib` is unavailable.
        try:
            apply_regularization = tf.contrib.layers.apply_regularization
            l2_regularizer = tf.contrib.layers.l2_regularizer
        except (AttributeError, ImportError):
            d_reg, g_reg = 0, 0
        else:
            d_reg = apply_regularization(l2_regularizer(1e-6), vars_d)
            g_reg = apply_regularization(l2_regularizer(1e-6), vars_g)

        opt_g = tf.train.AdamOptimizer(
            lr_g, name='Adam_g').minimize(
                gan['loss_G'] + g_reg, var_list=vars_g)
        opt_d = tf.train.AdamOptimizer(
            lr_d, name='Adam_d').minimize(
                gan['loss_D'] + d_reg, var_list=vars_d)

        saver = tf.train.Saver()
        sums = gan['sums']
        G_sum_op = tf.summary.merge([
            sums['G'], sums['loss_G'], sums['z'], sums['loss_D_fake'],
            sums['D_fake']
        ])
        D_sum_op = tf.summary.merge([
            sums['loss_D'], sums['loss_D_real'], sums['loss_D_fake'], sums['z'],
            sums['x'], sums['D_real'], sums['D_fake']
        ])
        # Pass the graph itself; handing over a GraphDef is deprecated.
        writer = tf.summary.FileWriter("./logs", sess.graph)

        try:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            # No ops may be added past this point.
            tf.get_default_graph().finalize()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # NOTE(review): checkpoints below are saved with `global_step`,
            # which appends "-<step>" to the file name, so this exact-name
            # check may never match -- confirm the intended resume behaviour.
            if os.path.exists("gan.ckpt"):
                saver.restore(sess, "gan.ckpt")
                print("GAN model restored.")

            step_i, t_i = 0, 0
            # Losses start at 1 so the first learning-rate ratio is 1.
            loss_d = 1
            loss_g = 1
            n_loss_d, total_loss_d = 1, 1
            n_loss_g, total_loss_g = 1, 1
            try:
                while not coord.should_stop():
                    batch_xs = sess.run(batch)
                    step_i += 1
                    batch_zs = np.random.uniform(
                        -1.0, 1.0, [batch_size, n_latent]).astype(np.float32)

                    # The network with the larger recent loss gets the
                    # larger learning rate, clamped to [1e-6, 1e-2].
                    this_lr_g = min(
                        1e-2, max(1e-6, init_lr_g * (loss_g / loss_d)**2))
                    this_lr_d = min(
                        1e-2, max(1e-6, init_lr_d * (loss_d / loss_g)**2))

                    if step_i % 3 == 1:
                        # Discriminator update.
                        loss_d, _, sum_d = sess.run(
                            [gan['loss_D'], opt_d, D_sum_op],
                            feed_dict={
                                gan['x']: batch_xs,
                                gan['z']: batch_zs,
                                lr_d: this_lr_d
                            })
                        total_loss_d += loss_d
                        n_loss_d += 1
                        writer.add_summary(sum_d, step_i)
                        print('%04d d* = lr: %0.08f, loss: %08.06f, \t' % (
                            step_i, this_lr_d, loss_d
                        ) + 'g  = lr: %0.08f, loss: %08.06f' % (
                            this_lr_g, loss_g))
                    else:
                        # Generator update.
                        loss_g, _, sum_g = sess.run(
                            [gan['loss_G'], opt_g, G_sum_op],
                            feed_dict={gan['z']: batch_zs,
                                       lr_g: this_lr_g})
                        total_loss_g += loss_g
                        n_loss_g += 1
                        writer.add_summary(sum_g, step_i)
                        print('%04d d  = lr: %0.08f, loss: %08.06f, \t' % (
                            step_i, this_lr_d, loss_d
                        ) + 'g* = lr: %0.08f, loss: %08.06f' % (
                            this_lr_g, loss_g))

                    if step_i % 100 == 0:
                        samples = sess.run(gan['G'], feed_dict={gan['z']: zs})
                        # Map generator output from [-1, 1] to uint8 pixels.
                        utils.montage(
                            np.clip((samples + 1) * 127.5, 0,
                                    255).astype(np.uint8),
                            'imgs/gan_%08d.png' % t_i)
                        t_i += 1

                        print('generator loss:', total_loss_g / n_loss_g)
                        print('discriminator loss:', total_loss_d / n_loss_d)

                        # Save the variables to disk.
                        save_path = saver.save(
                            sess,
                            "./gan.ckpt",
                            global_step=step_i,
                            write_meta_graph=False)
                        print("Model saved in file: %s" % save_path)
            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')
            finally:
                # Whatever ended the loop, tell the queue threads to stop.
                coord.request_stop()

            # Wait until all threads have finished.
            coord.join(threads)
        finally:
            # Flush pending summary events and release the event file.
            writer.close()
Exemplo n.º 8
0
def test_mnist():
    """Fit the model returned by ``create_model`` to MNIST for 100 epochs.

    The ``loss_x`` and ``loss_z`` values are printed after every training
    step. On every 1000th step (step 0 included) the canvases produced for a
    fixed batch of test images are rendered to ``manifold_<step>.gif`` and a
    checkpoint is written to ``./draw.ckpt``.
    """
    img_h, img_w = 28, 28
    n_z = 100
    batch_size = 64
    n_epochs = 100
    mnist = MNIST(split=[0.8, 0.1, 0.1])

    n_examples = batch_size
    # The manifold is not used below; it is still built so the NumPy random
    # stream is consumed exactly as before.
    corners = np.random.uniform(-1.0, 1.0, [4, n_z]).astype(np.float32)
    utils.make_latent_manifold(corners, n_examples)

    # Entering the session makes its graph the default graph, so the model
    # below is built inside this new graph.
    with tf.Session(graph=tf.Graph()) as sess:
        draw = create_model(
            A=img_h,
            B=img_w,
            C=1,
            T=10,
            batch_size=batch_size,
            n_enc=256,
            n_z=n_z,
            n_dec=256,
            read_n=5,
            write_n=5)

        # Adam with every available gradient clipped to an L2 norm of 5.
        optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
        clipped = [
            (grad if grad is None else tf.clip_by_norm(grad, 5), var)
            for grad, var in optimizer.compute_gradients(draw['cost'])
        ]
        train_op = optimizer.apply_gradients(clipped)

        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        saver = tf.train.Saver()

        # Fixed test images reused for every reconstruction snapshot.
        test_xs = mnist.test.images[:n_examples]
        utils.montage(test_xs.reshape((-1, img_h, img_w)), 'test_xs.png')

        fetches = [draw['loss_x'], draw['loss_z'], train_op]
        step = 0
        for _epoch in range(n_epochs):
            for batch_xs, _labels in mnist.train.next_batch(batch_size):
                noise = np.random.randn(batch_size, n_z)
                lx, lz, _ = sess.run(
                    fetches,
                    feed_dict={draw['x']: batch_xs,
                               draw['noise']: noise})
                print('x:', lx, 'z:', lz)

                if step % 1000 == 0:
                    # One montage per entry of the fetched canvas sequence.
                    canvases = sess.run(
                        draw['canvas'],
                        feed_dict={draw['x']: test_xs,
                                   draw['noise']: noise})
                    frames = [
                        utils.montage(canvas.reshape(-1, img_h, img_w))
                        for canvas in canvases
                    ]
                    gif.build_gif(
                        frames,
                        cmap='gray',
                        saveto='manifold_%08d.gif' % step)

                    saver.save(sess, './draw.ckpt', global_step=step)

                step += 1
Exemplo n.º 9
0
Arquivo: gan.py Projeto: zhf459/pycadl
def train_input_pipeline(files,
                         init_lr_g=1e-4,
                         init_lr_d=1e-4,
                         n_features=10,
                         n_latent=100,
                         n_epochs=1000000,
                         batch_size=200,
                         n_samples=15,
                         input_shape=[218, 178, 3],
                         crop_shape=[64, 64, 3],
                         crop_factor=0.8):
    """Train a GAN on batches read from a queue-based input pipeline.

    A fresh graph and session are created, the input pipeline and the GAN are
    built inside them, and the two networks are then updated in alternation:
    the discriminator on every step where ``step % 3 == 1``, the generator on
    all other steps. Before each step both learning rates are recomputed
    from the ratio of the most recent generator and discriminator losses and
    clamped to ``[1e-6, 1e-2]``.

    Every 100 steps a montage of samples generated from a fixed set of
    latent codes is written to ``imgs/gan_<n>.png``, the running mean losses
    are printed, and a checkpoint is saved to ``./gan.ckpt``. Summaries are
    written to ``./logs``. Training ends when the input pipeline raises
    ``tf.errors.OutOfRangeError`` or the coordinator requests a stop.

    Parameters
    ----------
    files : list
        Passed unchanged to ``create_input_pipeline``; presumably the image
        files to train on (not verifiable from this function).
    init_lr_g : float, optional
        Base generator learning rate, scaled by ``(loss_g / loss_d) ** 2``.
    init_lr_d : float, optional
        Base discriminator learning rate, scaled by
        ``(loss_d / loss_g) ** 2``.
    n_features : int, optional
        Passed to ``GAN`` as ``n_features``.
    n_latent : int, optional
        Size of each latent vector fed to ``gan['z']``; also passed to
        ``GAN`` as ``n_latent``.
    n_epochs : int, optional
        Passed to ``create_input_pipeline`` as ``n_epochs``.
    batch_size : int, optional
        Number of latent vectors drawn per step; also passed to
        ``create_input_pipeline``.
    n_samples : int, optional
        Passed to ``utils.make_latent_manifold`` together with four random
        latent vectors to build the fixed codes used for the sample montage.
    input_shape : list, optional
        Passed to ``create_input_pipeline`` as ``shape``.
    crop_shape : list, optional
        Passed to ``create_input_pipeline``; the GAN input shape is
        ``[None] + crop_shape``.
    crop_factor : float, optional
        Passed to ``create_input_pipeline`` as ``crop_factor``.

    Returns
    -------
    None
    """

    with tf.Graph().as_default(), tf.Session() as sess:
        batch = create_input_pipeline(files=files,
                                      batch_size=batch_size,
                                      n_epochs=n_epochs,
                                      crop_shape=crop_shape,
                                      crop_factor=crop_factor,
                                      shape=input_shape)

        gan = GAN(input_shape=[None] + crop_shape,
                  n_features=n_features,
                  n_latent=n_latent,
                  rgb=True,
                  debug=False)

        # Split the trainable variables by the name prefix of each network.
        trainable = tf.trainable_variables()
        vars_d = [v for v in trainable if v.name.startswith('discriminator')]
        print('Training discriminator variables:')
        for v in vars_d:
            print(v.name)

        vars_g = [v for v in trainable if v.name.startswith('generator')]
        print('Training generator variables:')
        for v in vars_g:
            print(v.name)

        # Fixed latent codes so the periodic montages are comparable.
        zs = np.random.uniform(-1.0, 1.0, [4, n_latent]).astype(np.float32)
        zs = utils.make_latent_manifold(zs, n_samples)

        lr_g = tf.placeholder(tf.float32, shape=[], name='learning_rate_g')
        lr_d = tf.placeholder(tf.float32, shape=[], name='learning_rate_d')

        # `tf` is only an alias, so `from tf.contrib.layers import ...` can
        # never succeed; look the functions up as attributes instead and fall
        # back to no regularization only when `tf.contrib` is unavailable.
        try:
            apply_regularization = tf.contrib.layers.apply_regularization
            l2_regularizer = tf.contrib.layers.l2_regularizer
        except (AttributeError, ImportError):
            d_reg, g_reg = 0, 0
        else:
            d_reg = apply_regularization(l2_regularizer(1e-6), vars_d)
            g_reg = apply_regularization(l2_regularizer(1e-6), vars_g)

        opt_g = tf.train.AdamOptimizer(lr_g, name='Adam_g').minimize(
            gan['loss_G'] + g_reg, var_list=vars_g)
        opt_d = tf.train.AdamOptimizer(lr_d, name='Adam_d').minimize(
            gan['loss_D'] + d_reg, var_list=vars_d)

        saver = tf.train.Saver()
        sums = gan['sums']
        G_sum_op = tf.summary.merge([
            sums['G'], sums['loss_G'], sums['z'], sums['loss_D_fake'],
            sums['D_fake']
        ])
        D_sum_op = tf.summary.merge([
            sums['loss_D'], sums['loss_D_real'], sums['loss_D_fake'],
            sums['z'], sums['x'], sums['D_real'], sums['D_fake']
        ])
        # Pass the graph itself; handing over a GraphDef is deprecated.
        writer = tf.summary.FileWriter("./logs", sess.graph)

        try:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            # No ops may be added past this point.
            tf.get_default_graph().finalize()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # NOTE(review): checkpoints below are saved with `global_step`,
            # which appends "-<step>" to the file name, so this exact-name
            # check may never match -- confirm the intended resume behaviour.
            if os.path.exists("gan.ckpt"):
                saver.restore(sess, "gan.ckpt")
                print("GAN model restored.")

            step_i, t_i = 0, 0
            # Losses start at 1 so the first learning-rate ratio is 1.
            loss_d = 1
            loss_g = 1
            n_loss_d, total_loss_d = 1, 1
            n_loss_g, total_loss_g = 1, 1
            try:
                while not coord.should_stop():
                    batch_xs = sess.run(batch)
                    step_i += 1
                    batch_zs = np.random.uniform(
                        -1.0, 1.0, [batch_size, n_latent]).astype(np.float32)

                    # The network with the larger recent loss gets the
                    # larger learning rate, clamped to [1e-6, 1e-2].
                    this_lr_g = min(
                        1e-2, max(1e-6, init_lr_g * (loss_g / loss_d)**2))
                    this_lr_d = min(
                        1e-2, max(1e-6, init_lr_d * (loss_d / loss_g)**2))

                    if step_i % 3 == 1:
                        # Discriminator update.
                        loss_d, _, sum_d = sess.run(
                            [gan['loss_D'], opt_d, D_sum_op],
                            feed_dict={
                                gan['x']: batch_xs,
                                gan['z']: batch_zs,
                                lr_d: this_lr_d
                            })
                        total_loss_d += loss_d
                        n_loss_d += 1
                        writer.add_summary(sum_d, step_i)
                        print('%04d d* = lr: %0.08f, loss: %08.06f, \t' %
                              (step_i, this_lr_d, loss_d) +
                              'g  = lr: %0.08f, loss: %08.06f' %
                              (this_lr_g, loss_g))
                    else:
                        # Generator update.
                        loss_g, _, sum_g = sess.run(
                            [gan['loss_G'], opt_g, G_sum_op],
                            feed_dict={
                                gan['z']: batch_zs,
                                lr_g: this_lr_g
                            })
                        total_loss_g += loss_g
                        n_loss_g += 1
                        writer.add_summary(sum_g, step_i)
                        print('%04d d  = lr: %0.08f, loss: %08.06f, \t' %
                              (step_i, this_lr_d, loss_d) +
                              'g* = lr: %0.08f, loss: %08.06f' %
                              (this_lr_g, loss_g))

                    if step_i % 100 == 0:
                        samples = sess.run(gan['G'], feed_dict={gan['z']: zs})
                        # Map generator output from [-1, 1] to uint8 pixels.
                        utils.montage(
                            np.clip((samples + 1) * 127.5, 0,
                                    255).astype(np.uint8),
                            'imgs/gan_%08d.png' % t_i)
                        t_i += 1

                        print('generator loss:', total_loss_g / n_loss_g)
                        print('discriminator loss:', total_loss_d / n_loss_d)

                        # Save the variables to disk.
                        save_path = saver.save(sess,
                                               "./gan.ckpt",
                                               global_step=step_i,
                                               write_meta_graph=False)
                        print("Model saved in file: %s" % save_path)
            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')
            finally:
                # Whatever ended the loop, tell the queue threads to stop.
                coord.request_stop()

            # Wait until all threads have finished.
            coord.join(threads)
        finally:
            # Flush pending summary events and release the event file.
            writer.close()
Exemplo n.º 10
0
def train_dataset(ds,
                  A,
                  B,
                  C,
                  T=20,
                  n_enc=512,
                  n_z=200,
                  n_dec=512,
                  read_n=12,
                  write_n=12,
                  batch_size=100,
                  n_epochs=100):
    """Train the model returned by ``create_model`` on an image dataset.

    Pixel values are divided by 255 before being fed to the model. Training
    summaries go to ``draw/train`` and validation summaries to
    ``draw/valid``. On every 1000th training step the canvases produced for
    a fixed batch of test images are rendered to
    ``draw/manifold_<step>.gif`` and a checkpoint is written to
    ``./draw/draw.ckpt``.

    Parameters
    ----------
    ds : object or None
        Dataset exposing ``train``, ``valid`` and ``test`` splits. The
        ``train`` and ``valid`` splits need ``next_batch(batch_size)``, and
        ``test`` needs an ``images`` array. When None, ``CIFAR10`` with an
        0.8/0.1/0.1 split is used and ``A, B, C`` are overridden with
        ``(32, 32, 3)``.
    A, B, C : int
        Image dimensions; flat images are reshaped to ``(-1, A, B, C)``.
        Presumably height, width and channel count.
    T, n_enc, n_z, n_dec, read_n, write_n : int, optional
        Passed unchanged to ``create_model``. ``n_z`` is also the width of
        the Gaussian noise fed to ``draw['noise']``.
    batch_size : int, optional
        Mini-batch size; also the number of fixed test images.
    n_epochs : int, optional
        Number of passes over ``ds.train``.

    Returns
    -------
    None
    """

    if ds is None:
        ds = CIFAR10(split=[0.8, 0.1, 0.1])
        A, B, C = (32, 32, 3)

    n_examples = batch_size
    # Not used below; kept so the NumPy random stream is consumed as before.
    zs = np.random.uniform(-1.0, 1.0, [4, n_z]).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    # Entering the session makes its graph the default graph, so everything
    # below is built inside this new graph.
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        draw = create_model(
            A=A,
            B=B,
            C=C,
            T=T,
            batch_size=batch_size,
            n_enc=n_enc,
            n_z=n_z,
            n_dec=n_dec,
            read_n=read_n,
            write_n=write_n)
        opt = tf.train.AdamOptimizer(learning_rate=0.0001)

        # Clip every available gradient to an L2 norm of 5.
        grads = opt.compute_gradients(draw['cost'])
        for i, (grad, var) in enumerate(grads):
            if grad is not None:
                grads[i] = (tf.clip_by_norm(grad, 5), var)
        train_op = opt.apply_gradients(grads)

        # Add summary variables
        tf.summary.scalar(name='cost', tensor=draw['cost'])
        tf.summary.scalar(name='loss_z', tensor=draw['loss_z'])
        tf.summary.scalar(name='loss_x', tensor=draw['loss_x'])
        tf.summary.histogram(
            name='recon_t0_histogram', values=draw['canvas'][0])
        tf.summary.histogram(
            name='recon_t-1_histogram', values=draw['canvas'][-1])
        tf.summary.image(
            name='recon_t0_image',
            tensor=tf.reshape(draw['canvas'][0], (-1, A, B, C)),
            max_outputs=2)
        tf.summary.image(
            name='recon_t-1_image',
            tensor=tf.reshape(draw['canvas'][-1], (-1, A, B, C)),
            max_outputs=2)
        sums = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(logdir='draw/train')
        valid_writer = tf.summary.FileWriter(logdir='draw/valid')

        try:
            # Init
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            saver = tf.train.Saver()

            # Fit all training data
            batch_i = 0
            test_xs = ds.test.images[:n_examples] / 255.0
            utils.montage(test_xs.reshape((-1, A, B, C)), 'draw/test_xs.png')
            for epoch_i in range(n_epochs):
                for batch_xs, _ in ds.train.next_batch(batch_size):
                    noise = np.random.randn(batch_size, n_z)
                    cost, summary = sess.run(
                        [draw['cost'], sums, train_op],
                        feed_dict={
                            draw['x']: batch_xs / 255.0,
                            draw['noise']: noise
                        })[0:2]
                    train_writer.add_summary(summary, batch_i)
                    print('train cost:', cost)

                    if batch_i % 1000 == 0:
                        # Plot example reconstructions
                        recon = sess.run(
                            draw['canvas'],
                            feed_dict={draw['x']: test_xs,
                                       draw['noise']: noise})
                        recon = [
                            utils.montage(r.reshape(-1, A, B, C))
                            for r in recon
                        ]
                        gif.build_gif(
                            recon,
                            cmap='gray',
                            saveto='draw/manifold_%08d.gif' % batch_i)
                        saver.save(
                            sess, './draw/draw.ckpt', global_step=batch_i)
                    batch_i += 1

                # Run validation
                # NOTE(review): this only fires when an epoch happens to end
                # on a multiple of 1000 steps, and the validation loop also
                # advances the training step counter -- confirm intent.
                if batch_i % 1000 == 0:
                    for batch_xs, _ in ds.valid.next_batch(batch_size):
                        noise = np.random.randn(batch_size, n_z)
                        cost, summary = sess.run(
                            [draw['cost'], sums],
                            feed_dict={
                                draw['x']: batch_xs / 255.0,
                                draw['noise']: noise
                            })[0:2]
                        valid_writer.add_summary(summary, batch_i)
                        print('valid cost:', cost)
                        batch_i += 1
        finally:
            # Flush buffered events and release both event files, even when
            # training is interrupted.
            train_writer.close()
            valid_writer.close()
Exemplo n.º 11
0
def test_mnist():
    """Fit the model returned by ``create_model`` to MNIST for 100 epochs.

    The ``loss_x`` and ``loss_z`` values are printed after every training
    step. On every 1000th step (step 0 included) the canvases produced for a
    fixed batch of test images are rendered to ``manifold_<step>.gif`` and a
    checkpoint is written to ``./draw.ckpt``.
    """
    img_h, img_w = 28, 28
    n_z = 100
    batch_size = 64
    n_epochs = 100
    mnist = MNIST(split=[0.8, 0.1, 0.1])

    n_examples = batch_size
    # The manifold is not used below; it is still built so the NumPy random
    # stream is consumed exactly as before.
    corners = np.random.uniform(-1.0, 1.0, [4, n_z]).astype(np.float32)
    utils.make_latent_manifold(corners, n_examples)

    # Entering the session makes its graph the default graph, so the model
    # below is built inside this new graph.
    with tf.Session(graph=tf.Graph()) as sess:
        draw = create_model(
            A=img_h,
            B=img_w,
            C=1,
            T=10,
            batch_size=batch_size,
            n_enc=256,
            n_z=n_z,
            n_dec=256,
            read_n=5,
            write_n=5)

        # Adam with every available gradient clipped to an L2 norm of 5.
        optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
        clipped = [
            (grad if grad is None else tf.clip_by_norm(grad, 5), var)
            for grad, var in optimizer.compute_gradients(draw['cost'])
        ]
        train_op = optimizer.apply_gradients(clipped)

        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        saver = tf.train.Saver()

        # Fixed test images reused for every reconstruction snapshot.
        test_xs = mnist.test.images[:n_examples]
        utils.montage(test_xs.reshape((-1, img_h, img_w)), 'test_xs.png')

        fetches = [draw['loss_x'], draw['loss_z'], train_op]
        step = 0
        for _epoch in range(n_epochs):
            for batch_xs, _labels in mnist.train.next_batch(batch_size):
                noise = np.random.randn(batch_size, n_z)
                lx, lz, _ = sess.run(
                    fetches,
                    feed_dict={draw['x']: batch_xs,
                               draw['noise']: noise})
                print('x:', lx, 'z:', lz)

                if step % 1000 == 0:
                    # One montage per entry of the fetched canvas sequence.
                    canvases = sess.run(
                        draw['canvas'],
                        feed_dict={draw['x']: test_xs,
                                   draw['noise']: noise})
                    frames = [
                        utils.montage(canvas.reshape(-1, img_h, img_w))
                        for canvas in canvases
                    ]
                    gif.build_gif(
                        frames,
                        cmap='gray',
                        saveto='manifold_%08d.gif' % step)

                    saver.save(sess, './draw.ckpt', global_step=step)

                step += 1
Exemplo n.º 12
0
def train_vaegan(files,
                 learning_rate=0.00001,
                 batch_size=64,
                 n_epochs=250,
                 n_examples=10,
                 input_shape=[218, 178, 3],
                 crop_shape=[64, 64, 3],
                 crop_factor=0.8,
                 n_filters=[100, 100, 100, 100],
                 n_hidden=None,
                 n_code=128,
                 convolutional=True,
                 variational=True,
                 filter_sizes=[3, 3, 3, 3],
                 activation=tf.nn.elu,
                 ckpt_name="vaegan.ckpt"):
    """Train a VAE-GAN on images read from a queue-based input pipeline.

    The encoder is updated on every batch. The generator and discriminator
    are each updated only while the measured real/fake costs stay on the
    appropriate side of ``equilibrium +/- margin``; if both would be skipped,
    both are updated. Batches are divided by 255 before being fed.

    Every 50 batches two montages are written to ``imgs/``: one decoded from
    a fixed latent manifold and one reconstructed from a fixed test batch.
    Every 100 batches a checkpoint is saved under ``ckpt_name``.

    Parameters
    ----------
    files : list
        Passed to ``create_input_pipeline``; its length is also used to
        derive the number of batches per epoch.
    learning_rate : float, optional
        Learning rate shared by the three Adam optimizers.
    batch_size : int, optional
        Number of latent samples drawn per step; also passed to
        ``create_input_pipeline``.
    n_epochs : int, optional
        Epoch counter limit for the training loop; also passed to
        ``create_input_pipeline``.
    n_examples : int, optional
        Passed to ``utils.make_latent_manifold`` with four random codes to
        build the fixed latent manifold.
    input_shape : list, optional
        Passed to ``create_input_pipeline`` as ``shape``.
    crop_shape : list, optional
        Passed to ``create_input_pipeline``; the model input shape is
        ``[None] + crop_shape`` and montages are reshaped to it.
    crop_factor : float, optional
        Passed to ``create_input_pipeline``.
    n_filters : list, optional
        Passed to ``VAEGAN``.
    n_hidden : int, optional
        Passed to ``VAEGAN``.
    n_code : int, optional
        Width of the latent code; also passed to ``VAEGAN``.
    convolutional : bool, optional
        Passed to ``VAEGAN``.
    variational : bool, optional
        Passed to ``VAEGAN``.
    filter_sizes : list, optional
        Passed to ``VAEGAN``.
    activation : callable, optional
        Passed to ``VAEGAN``.
    ckpt_name : str, optional
        Checkpoint path used for both restoring and saving.

    Returns
    -------
    None
    """

    ae = VAEGAN(
        input_shape=[None] + crop_shape,
        convolutional=convolutional,
        variational=variational,
        n_filters=n_filters,
        n_hidden=n_hidden,
        n_code=n_code,
        filter_sizes=filter_sizes,
        activation=activation)

    batch = create_input_pipeline(
        files=files,
        batch_size=batch_size,
        n_epochs=n_epochs,
        crop_shape=crop_shape,
        crop_factor=crop_factor,
        shape=input_shape)

    # Fixed latent codes so the periodic manifold montages are comparable.
    zs = np.random.randn(4, n_code).astype(np.float32)
    zs = utils.make_latent_manifold(zs, n_examples)

    def _vars_with_prefix(prefix):
        # Trainable variables whose name starts with the given prefix.
        return [
            var_i for var_i in tf.trainable_variables()
            if var_i.name.startswith(prefix)
        ]

    opt_enc = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_enc'], var_list=_vars_with_prefix('encoder'))

    opt_gen = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_gen'], var_list=_vars_with_prefix('generator'))

    opt_dis = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        ae['loss_dis'], var_list=_vars_with_prefix('discriminator'))

    # The context manager guarantees the session is closed even when
    # training aborts with an unexpected exception.
    with tf.Session() as sess:
        saver = tf.train.Saver()
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        # No ops may be added past this point.
        tf.get_default_graph().finalize()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # NOTE(review): checkpoints below are saved with `global_step`,
        # which appends "-<step>" to the file name, so this exact-name check
        # may never match a checkpoint written by this function -- confirm.
        if os.path.exists(ckpt_name + '.index') or os.path.exists(ckpt_name):
            saver.restore(sess, ckpt_name)
            print("VAE model restored.")

        t_i = 0
        batch_i = 0
        epoch_i = 0

        # 0.693 is approximately ln(2); presumably the cost of a
        # discriminator guessing at chance -- TODO confirm.
        equilibrium = 0.693
        margin = 0.4

        n_files = len(files)
        # At least one batch per epoch, so fewer files than `batch_size`
        # does not cause a division by zero in the epoch bookkeeping.
        batches_per_epoch = max(1, n_files // batch_size)

        test_xs = sess.run(batch) / 255.0
        utils.montage(test_xs, 'test_xs.png')
        try:
            while not coord.should_stop() and epoch_i < n_epochs:
                if batch_i % batches_per_epoch == 0:
                    batch_i = 0
                    epoch_i += 1
                    print('---------- EPOCH:', epoch_i)

                batch_i += 1
                batch_xs = sess.run(batch) / 255.0
                batch_zs = np.random.randn(batch_size,
                                           n_code).astype(np.float32)

                # Encoder step; its costs decide which of the other two
                # networks are updated on this batch.
                real_cost, fake_cost, _ = sess.run(
                    [ae['loss_real'], ae['loss_fake'], opt_enc],
                    feed_dict={ae['x']: batch_xs,
                               ae['gamma']: 0.5})
                real_cost = -np.mean(real_cost)
                fake_cost = -np.mean(fake_cost)
                print('real:', real_cost, '/ fake:', fake_cost)

                upper = equilibrium + margin
                lower = equilibrium - margin
                gen_update = not (real_cost > upper or fake_cost > upper)
                dis_update = not (real_cost < lower or fake_cost < lower)

                # Never let a batch pass without updating either network.
                if not (gen_update or dis_update):
                    gen_update = True
                    dis_update = True

                adversarial_feed = {
                    ae['x']: batch_xs,
                    ae['z_samp']: batch_zs,
                    ae['gamma']: 0.5
                }
                if gen_update:
                    sess.run(opt_gen, feed_dict=adversarial_feed)
                if dis_update:
                    sess.run(opt_dis, feed_dict=adversarial_feed)

                if batch_i % 50 == 0:

                    # Plot example reconstructions from latent layer
                    recon = sess.run(ae['x_tilde'], feed_dict={ae['z']: zs})
                    print('recon:', recon.min(), recon.max())
                    recon = np.clip(recon / recon.max(), 0, 1)
                    utils.montage(
                        recon.reshape([-1] + crop_shape),
                        'imgs/manifold_%08d.png' % t_i)

                    # Plot example reconstructions
                    recon = sess.run(
                        ae['x_tilde'], feed_dict={ae['x']: test_xs})
                    print('recon:', recon.min(), recon.max())
                    recon = np.clip(recon / recon.max(), 0, 1)
                    utils.montage(
                        recon.reshape([-1] + crop_shape),
                        'imgs/reconstruction_%08d.png' % t_i)
                    t_i += 1

                if batch_i % 100 == 0:
                    # Save the variables to disk.
                    save_path = saver.save(
                        sess,
                        ckpt_name,
                        global_step=batch_i,
                        write_meta_graph=False)
                    print("Model saved in file: %s" % save_path)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Whatever ended the loop, tell the queue threads to stop.
            coord.request_stop()

        # Wait until all threads have finished.
        coord.join(threads)