Example #1
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Placeholder for data and Mnist iterator
    mnist = input_data.read_data_sets(FLAGS.raw_dir, one_hot=True)
    assert FLAGS.data_format == "NCHW", "Scattering only implemented in NCHW"
    X_tensor = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, 1, 28, 28])
    y_tensor = tf.placeholder(tf.int64, shape=[FLAGS.batch_size, 10])

    with tf.device('/cpu:0'):
        X_train = mnist.train.images.astype(np.float32)
        y_train = mnist.train.labels.astype(np.int64)

        X_validation = mnist.validation.images.astype(np.float32)
        y_validation = mnist.validation.labels.astype(np.int64)

        X_train = (X_train - 0.5) / 0.5
        X_train = X_train.reshape((-1, 1, 28, 28))

        X_validation = (X_validation - 0.5) / 0.5
        X_validation = X_validation.reshape((-1, 1, 28, 28))

    # Build model
    class HybridCNN(models.Model):
        def __call__(self, x, reuse=False):
            with tf.variable_scope(self.name) as scope:

                if reuse:
                    scope.reuse_variables()

                M, N = x.get_shape().as_list()[-2:]
                x = scattering.Scattering(M=M, N=N, J=2)(x)
                x = tf.contrib.layers.batch_norm(x,
                                                 data_format=FLAGS.data_format,
                                                 fused=True,
                                                 scope="scat_bn")
                x = layers.conv2d_block("CONV2D",
                                        x,
                                        64,
                                        1,
                                        1,
                                        p="SAME",
                                        data_format=FLAGS.data_format,
                                        bias=True,
                                        bn=False,
                                        activation_fn=tf.nn.relu)

                target_shape = (-1, 64 * 7 * 7)
                x = layers.reshape(x, target_shape)
                x = layers.linear(x, 512, name="dense1")
                x = tf.nn.relu(x)
                x = layers.linear(x, 10, name="dense2")

                return x

    HCNN = HybridCNN("HCNN")
    y_pred = HCNN(X_tensor)

    ###########################
    # Instantiate optimizers
    ###########################
    opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                 name='opt',
                                 beta1=0.5)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_tensor,
                                                logits=y_pred))
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_tensor, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    ###########################
    # Compute gradient updates
    ###########################

    dict_vars = HCNN.get_trainable_variables()
    all_vars = [dict_vars[k] for k in dict_vars.keys()]

    gradvar = opt.compute_gradients(loss,
                                    var_list=all_vars,
                                    colocate_gradients_with_ops=True)
    update = opt.apply_gradients(gradvar, name='loss_minimize')

    ##########################
    # Group training ops
    ##########################
    train_ops = [update]
    loss_ops = [loss, accuracy]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(gradvar)

    # Add scalar summaries
    tf.summary.scalar("loss", loss)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    tu.initialize_session(sess)

    # Start queues
    tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)

    for e in tqdm(range(FLAGS.nb_epoch), desc="Training progress"):

        t = tqdm(range(FLAGS.nb_batch_per_epoch),
                 desc="Epoch %i" % e,
                 mininterval=0.5)
        for batch_counter in t:

            # Get training data
            X_train_batch, y_train_batch = du.sample_batch(
                X_train, y_train, FLAGS.batch_size)

            # Run update and get loss
            output = sess.run(train_ops + loss_ops + [summary_op],
                              feed_dict={
                                  X_tensor: X_train_batch,
                                  y_tensor: y_train_batch
                              })
            train_loss = output[1]
            train_acc = output[2]

            # Write summaries
            if batch_counter % (FLAGS.nb_batch_per_epoch // 20) == 0:
                writer.add_summary(
                    output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)

            # Get validation data
            X_validation_batch, y_validation_batch = du.sample_batch(
                X_validation, y_validation, FLAGS.batch_size)

            # Run update and get loss
            output = sess.run(loss_ops,
                              feed_dict={
                                  X_tensor: X_validation_batch,
                                  y_tensor: y_validation_batch
                              })
            validation_loss = output[0]
            validation_acc = output[1]

            t.set_description(
                'Epoch %i: - train loss: %.2f val loss: %.2f - train acc: %.2f val acc: %.2f'
                % (e, train_loss, validation_loss, train_acc, validation_acc))

    print('Finished training!')
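
All of these examples read their hyperparameters from a shared FLAGS object. A minimal sketch of the flag definitions they assume is shown below; the flag names are taken from the calls above, while the default values and help strings are placeholders, not the repository's actual settings.

# Hypothetical flag definitions; defaults are placeholders.
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer("batch_size", 64, "Mini-batch size")
tf.app.flags.DEFINE_integer("nb_epoch", 50, "Number of training epochs")
tf.app.flags.DEFINE_integer("nb_batch_per_epoch", 500, "Batches per epoch")
tf.app.flags.DEFINE_integer("noise_dim", 100, "Dimension of the GAN noise vector")
tf.app.flags.DEFINE_float("learning_rate", 2E-4, "Adam learning rate")
tf.app.flags.DEFINE_string("data_format", "NCHW", "Tensor layout: NCHW or NHWC")
tf.app.flags.DEFINE_string("raw_dir", "/tmp/mnist", "Directory holding the raw MNIST data")
tf.app.flags.DEFINE_string("model_dir", "/tmp/models", "Directory for model checkpoints")
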
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Placeholder for data and Mnist iterator
    mnist = input_data.read_data_sets(FLAGS.raw_dir, one_hot=True)
    # images = tf.constant(mnist.train.images)
    if FLAGS.data_format == "NHWC":
        X_real = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, 28, 28, 1])
    else:
        X_real = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, 1, 28, 28])

    with tf.device('/cpu:0'):
        imgs = mnist.train.images.astype(np.float32)
        npts = imgs.shape[0]
        if FLAGS.data_format == "NHWC":
            imgs = imgs.reshape((npts, 28, 28, 1))
        else:
            imgs = imgs.reshape((npts, 1, 28, 28))
        imgs = (imgs - 0.5) / 0.5
        # input_images = tf.constant(imgs)

    #     image = tf.train.slice_input_producer([input_images], num_epochs=FLAGS.nb_epoch)
    #     X_real = tf.train.batch(image, batch_size=FLAGS.batch_size, num_threads=8)

    #######################
    # Instantiate generator
    #######################
    list_filters = [256, 1]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    output_shape = X_real.get_shape().as_list()[1:]
    G = models.Generator(list_filters, list_kernel_size, list_strides, list_padding, output_shape,
                         batch_size=FLAGS.batch_size, dset="mnist", data_format=FLAGS.data_format)

    ###########################
    # Instantiate discriminator
    ###########################
    list_filters = [32, 64]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    D = models.Discriminator(list_filters, list_kernel_size, list_strides, list_padding,
                             FLAGS.batch_size, data_format=FLAGS.data_format)

    ###########################
    # Instantiate optimizers
    ###########################
    G_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='G_opt', beta1=0.5)
    D_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='D_opt', beta1=0.5)

    ###########################
    # Instantiate model outputs
    ###########################

    # noise_input = tf.random_normal((FLAGS.batch_size, FLAGS.noise_dim,), stddev=0.1)
    noise_input = tf.random_uniform((FLAGS.batch_size, FLAGS.noise_dim,), minval=-1, maxval=1)
    X_fake = G(noise_input)

    # output images
    X_G_output = du.unnormalize_image(X_fake)
    X_real_output = du.unnormalize_image(X_real)

    D_real = D(X_real)
    D_fake = D(X_fake, reuse=True)

    ###########################
    # Instantiate losses
    ###########################

    G_loss = objectives.binary_cross_entropy_with_logits(D_fake, tf.ones_like(D_fake))
    D_loss_real = objectives.binary_cross_entropy_with_logits(D_real, tf.ones_like(D_real))
    D_loss_fake = objectives.binary_cross_entropy_with_logits(D_fake, tf.zeros_like(D_fake))

    D_loss = D_loss_real + D_loss_fake

    # ######################################################################
    # # Some parameters need to be updated (e.g. BN moving average/variance)
    # ######################################################################
    # from tensorflow.python.ops import control_flow_ops
    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies(update_ops):
    #     barrier = tf.no_op(name='update_barrier')
    # D_loss = control_flow_ops.with_dependencies([barrier], D_loss)
    # G_loss = control_flow_ops.with_dependencies([barrier], G_loss)

    ###########################
    # Compute gradient updates
    ###########################

    dict_G_vars = G.get_trainable_variables()
    G_vars = [dict_G_vars[k] for k in dict_G_vars.keys()]

    dict_D_vars = D.get_trainable_variables()
    D_vars = [dict_D_vars[k] for k in dict_D_vars.keys()]

    G_gradvar = G_opt.compute_gradients(G_loss, var_list=G_vars)
    G_update = G_opt.apply_gradients(G_gradvar, name='G_loss_minimize')

    D_gradvar = D_opt.compute_gradients(D_loss, var_list=D_vars)
    D_update = D_opt.apply_gradients(D_gradvar, name='D_loss_minimize')

    ##########################
    # Group training ops
    ##########################
    train_ops = [G_update, D_update]
    loss_ops = [G_loss, D_loss, D_loss_real, D_loss_fake]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(G_gradvar)
    tu.add_gradient_summary(D_gradvar)

    # Add scalar summaries
    tf.summary.scalar("G loss", G_loss)
    tf.summary.scalar("D loss real", D_loss_real)
    tf.summary.scalar("D loss fake", D_loss_fake)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    saver = tu.initialize_session(sess)

    # Start queues
    coord = tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)

    for e in tqdm(range(FLAGS.nb_epoch), desc="\nTraining progress"):

        t = tqdm(range(FLAGS.nb_batch_per_epoch), desc="Epoch %i" % e, mininterval=0.5)
        for batch_counter in t:

            # X_batch, _ = mnist.train.next_batch(FLAGS.batch_size)
            # if FLAGS.data_format == "NHWC":
            #     X_batch = np.reshape(X_batch, [-1, 28, 28, 1])
            # else:
            #     X_batch = np.reshape(X_batch, [-1, 1, 28, 28])
            # X_batch = (X_batch - 0.5) / 0.5

            X_batch = du.sample_batch(imgs, FLAGS.batch_size)
            output = sess.run(train_ops + loss_ops + [summary_op], feed_dict={X_real: X_batch})

            if batch_counter % (FLAGS.nb_batch_per_epoch // 20) == 0:
                writer.add_summary(output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)
            lossG, lossDreal, lossDfake = [output[2], output[4], output[5]]

            t.set_description('Epoch %i: - G loss: %.2f D loss real: %.2f D loss fake: %.2f' %
                              (e, lossG, lossDreal, lossDfake))

            # variables = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]
            # bmean = [v for v in variables if v.name == "generator/conv2D_1_1/BatchNorm/moving_mean:0"]
            # print sess.run(bmean)
            # raw_input()

        # Plot some generated images
        # output = sess.run([X_G_output, X_real_output])
        output = sess.run([X_G_output, X_real_output], feed_dict={X_real: X_batch})
        vu.save_image(output, FLAGS.data_format, e)

        # Save session
        saver.save(sess, os.path.join(FLAGS.model_dir, "model"), global_step=e)

    print('Finished training!')
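
du.sample_batch is a helper from the surrounding repository and is not shown here. Judging only from how it is called (with labels in the classifier example, without labels in the GAN examples), a plausible minimal sketch is the following; it is an inference from the call sites, not the repository's own code.

import numpy as np

def sample_batch(X, y_or_batch_size, batch_size=None):
    # Hypothetical sketch of du.sample_batch, inferred from its call sites:
    # sample_batch(X, y, batch_size) and sample_batch(X, batch_size).
    if batch_size is None:
        batch_size = y_or_batch_size
        idx = np.random.randint(0, X.shape[0], batch_size)
        return X[idx]
    idx = np.random.randint(0, X.shape[0], batch_size)
    return X[idx], y_or_batch_size[idx]
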
Example #3
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Setup async input queue of real images
    #X_input = du.input_data(sess)
    #X_real, X_fake_in = X_input[0], X_input[1]
    X_real, X_fake_in = du.input_data(sess)
    X_real_name = X_real[1]
    X_real = X_real[0]
    X_fake_name = X_fake_in[1]
    X_fake_in = X_fake_in[0]
    #X, Y = create_datasets(no_dir, yes_dir)
    #X_fake_in = tf.placeholder(tf.float32, (batch_size, 96, 96, 3))
    #X_real = tf.placeholder(tf.float32, (batch_size, 96, 96, 3))

    #######################
    # Instantiate generator
    #######################
    list_filters = [256, 128, 64, 3]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    output_shape = X_real.get_shape().as_list()[1:]
    G = models.Generator(list_filters,
                         list_kernel_size,
                         list_strides,
                         list_padding,
                         output_shape,
                         batch_size=FLAGS.batch_size,
                         data_format=FLAGS.data_format)

    ###########################
    # Instantiate discriminator
    ###########################
    list_filters = [32, 64, 128, 256]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    D = models.Discriminator(list_filters,
                             list_kernel_size,
                             list_strides,
                             list_padding,
                             FLAGS.batch_size,
                             data_format=FLAGS.data_format)

    ###########################
    # Instantiate optimizers
    ###########################
    G_opt = tf.train.AdamOptimizer(learning_rate=1E-4,
                                   name='G_opt',
                                   beta1=0.5,
                                   beta2=0.9)
    D_opt = tf.train.AdamOptimizer(learning_rate=1E-4,
                                   name='D_opt',
                                   beta1=0.5,
                                   beta2=0.9)

    ###########################
    # Instantiate model outputs
    ###########################

    # noise_input = tf.random_normal((FLAGS.batch_size, FLAGS.noise_dim,), stddev=0.1)
    noise_input = tf.random_uniform((FLAGS.batch_size, FLAGS.noise_dim),
                                    minval=-1,
                                    maxval=1)
    X_fake = G(X_fake_in)
    #X_fake = G(noise_input)

    # output images
    #X_G_input = du.unnormalize_image(X_fake_in)
    X_G_output = du.unnormalize_image(X_fake)
    X_real_output = du.unnormalize_image(X_real)

    D_real = D(X_real)
    D_fake = D(X_fake, reuse=True)

    ###########################
    # Instantiate losses
    ###########################

    G_loss = -tf.reduce_mean(D_fake) + (tf.reduce_mean(abs(X_fake - X_real)))
    D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)

    epsilon = tf.random_uniform(shape=[FLAGS.batch_size, 1, 1, 1],
                                minval=0.,
                                maxval=1.)
    X_hat = X_real + epsilon * (X_fake - X_real)
    D_X_hat = D(X_hat, reuse=True)
    grad_D_X_hat = tf.gradients(D_X_hat, [X_hat])[0]
    if FLAGS.data_format == "NCHW":
        red_idx = [1]
    else:
        red_idx = [-1]
    slopes = tf.sqrt(
        tf.reduce_sum(tf.square(grad_D_X_hat), reduction_indices=red_idx))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
    D_loss += 10 * gradient_penalty

    ###########################
    # Compute gradient updates
    ###########################

    dict_G_vars = G.get_trainable_variables()
    G_vars = [dict_G_vars[k] for k in dict_G_vars.keys()]

    dict_D_vars = D.get_trainable_variables()
    D_vars = [dict_D_vars[k] for k in dict_D_vars.keys()]

    G_gradvar = G_opt.compute_gradients(G_loss,
                                        var_list=G_vars,
                                        colocate_gradients_with_ops=True)
    G_update = G_opt.apply_gradients(G_gradvar, name='G_loss_minimize')

    D_gradvar = D_opt.compute_gradients(D_loss,
                                        var_list=D_vars,
                                        colocate_gradients_with_ops=True)
    D_update = D_opt.apply_gradients(D_gradvar, name='D_loss_minimize')

    ##########################
    # Group training ops
    ##########################
    loss_ops = [G_loss, D_loss]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(G_gradvar)
    tu.add_gradient_summary(D_gradvar)

    # Add scalar summaries
    tf.summary.scalar("G loss", G_loss)
    tf.summary.scalar("D loss", D_loss)
    tf.summary.scalar("gradient_penalty", gradient_penalty)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    saver = tu.initialize_session(sess)

    # Start queues
    tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)
    txtfile = open("/home/ubuntu/makeup_removal/WGAN-GP/testfile.txt", "w")
    g_loss_list = []
    d_loss_list = []
    for e in tqdm(range(FLAGS.nb_epoch), desc="Training progress"):

        t = tqdm(range(FLAGS.nb_batch_per_epoch),
                 desc="Epoch %i" % e,
                 mininterval=0.5)
        num = 0
        for batch_counter in t:
            #with_makeup_batch, without_makeup_batch = nextBatch(X, Y, num, batch_size)
            num += 1
            g_loss_total = 0
            d_loss_total = 0
            for di in range(5):
                sess.run([D_update])
            #    #output = sess.run([G_update] + loss_ops + [summary_op])

            #sess.run([D_update])
            output = sess.run([G_update] + loss_ops + [summary_op])
            g_loss, d_loss = sess.run([G_loss, D_loss])
            g_loss_total += g_loss
            d_loss_total += d_loss
            #print (g_loss, d_loss)

            if batch_counter % (FLAGS.nb_batch_per_epoch // 20) == 0:
                writer.add_summary(
                    output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)

            t.set_description('Epoch %i' % e)
        g_loss_list.append(g_loss_total)
        d_loss_list.append(d_loss_total)
        # Plot some generated images
        #output = sess.run([X_G_output, X_G_input, X_real_output])
        output = sess.run(
            [X_G_output, X_real_output, X_fake_name, X_real_name])
        vu.save_image(output[:2], FLAGS.data_format, e)
        name = output[2:]
        print(name)
        with open('/home/ubuntu/makeup_removal/WGAN-GP/testfile.txt',
                  'a') as f:
            f.write(str(e))
            f.write('\n')
            for array in name:
                for n in array:
                    f.write(str(n))
                    f.write('\n')
            f.write('\n\n')
        #txtfile.close()
        #print (name)

        # Save session
        saver.save(sess, os.path.join(FLAGS.model_dir, "model"), global_step=e)

    print('Finished training!')
    plt.plot(range(FLAGS.nb_epoch), g_loss_list, label='g_loss')
    plt.plot(range(FLAGS.nb_epoch), d_loss_list, label='d_loss')
    plt.legend()
    plt.title('G:1E-4, D:1E-4')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.savefig('/home/ubuntu/makeup_removal/WGAN-GP/G_4_D_4.png')
Example #4
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Placeholder for data and Mnist iterator
    mnist = input_data.read_data_sets(FLAGS.raw_dir, one_hot=True)
    # images = tf.constant(mnist.train.images)
    if FLAGS.data_format == "NHWC":
        X_real = tf.placeholder(tf.float32,
                                shape=[FLAGS.batch_size, 28, 28, 1])
    else:
        X_real = tf.placeholder(tf.float32,
                                shape=[FLAGS.batch_size, 1, 28, 28])

    with tf.device('/cpu:0'):
        imgs = mnist.train.images.astype(np.float32)
        npts = imgs.shape[0]
        if FLAGS.data_format == "NHWC":
            imgs = imgs.reshape((npts, 28, 28, 1))
        else:
            imgs = imgs.reshape((npts, 1, 28, 28))
        imgs = (imgs - 0.5) / 0.5
        # input_images = tf.constant(imgs)

    #     image = tf.train.slice_input_producer([input_images], num_epochs=FLAGS.nb_epoch)
    #     X_real = tf.train.batch(image, batch_size=FLAGS.batch_size, num_threads=8)

    #######################
    # Instantiate generator
    #######################
    list_filters = [256, 1]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    output_shape = X_real.get_shape().as_list()[1:]
    G = models.Generator(list_filters,
                         list_kernel_size,
                         list_strides,
                         list_padding,
                         output_shape,
                         batch_size=FLAGS.batch_size,
                         dset="mnist",
                         data_format=FLAGS.data_format)

    ###########################
    # Instantiate discriminator
    ###########################
    list_filters = [32, 64]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    D = models.Discriminator(list_filters,
                             list_kernel_size,
                             list_strides,
                             list_padding,
                             FLAGS.batch_size,
                             data_format=FLAGS.data_format)

    ###########################
    # Instantiate optimizers
    ###########################
    G_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                   name='G_opt',
                                   beta1=0.5)
    D_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                   name='D_opt',
                                   beta1=0.5)

    ###########################
    # Instantiate model outputs
    ###########################

    # noise_input = tf.random_normal((FLAGS.batch_size, FLAGS.noise_dim,), stddev=0.1)
    noise_input = tf.random_uniform((FLAGS.batch_size, FLAGS.noise_dim),
                                    minval=-1,
                                    maxval=1)
    X_fake = G(noise_input)

    # output images
    X_G_output = du.unnormalize_image(X_fake)
    X_real_output = du.unnormalize_image(X_real)

    D_real = D(X_real)
    D_fake = D(X_fake, reuse=True)

    ###########################
    # Instantiate losses
    ###########################

    G_loss = objectives.binary_cross_entropy_with_logits(
        D_fake, tf.ones_like(D_fake))
    D_loss_real = objectives.binary_cross_entropy_with_logits(
        D_real, tf.ones_like(D_real))
    D_loss_fake = objectives.binary_cross_entropy_with_logits(
        D_fake, tf.zeros_like(D_fake))

    D_loss = D_loss_real + D_loss_fake

    # ######################################################################
    # # Some parameters need to be updated (e.g. BN moving average/variance)
    # ######################################################################
    # from tensorflow.python.ops import control_flow_ops
    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies(update_ops):
    #     barrier = tf.no_op(name='update_barrier')
    # D_loss = control_flow_ops.with_dependencies([barrier], D_loss)
    # G_loss = control_flow_ops.with_dependencies([barrier], G_loss)

    ###########################
    # Compute gradient updates
    ###########################

    dict_G_vars = G.get_trainable_variables()
    G_vars = [dict_G_vars[k] for k in dict_G_vars.keys()]

    dict_D_vars = D.get_trainable_variables()
    D_vars = [dict_D_vars[k] for k in dict_D_vars.keys()]

    G_gradvar = G_opt.compute_gradients(G_loss, var_list=G_vars)
    G_update = G_opt.apply_gradients(G_gradvar, name='G_loss_minimize')

    D_gradvar = D_opt.compute_gradients(D_loss, var_list=D_vars)
    D_update = D_opt.apply_gradients(D_gradvar, name='D_loss_minimize')

    ##########################
    # Group training ops
    ##########################
    train_ops = [G_update, D_update]
    loss_ops = [G_loss, D_loss, D_loss_real, D_loss_fake]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(G_gradvar)
    tu.add_gradient_summary(D_gradvar)

    # Add scalar summaries
    tf.summary.scalar("G loss", G_loss)
    tf.summary.scalar("D loss real", D_loss_real)
    tf.summary.scalar("D loss fake", D_loss_fake)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    saver = tu.initialize_session(sess)

    # Start queues
    coord = tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)

    for e in tqdm(range(FLAGS.nb_epoch), desc="\nTraining progress"):

        t = tqdm(range(FLAGS.nb_batch_per_epoch),
                 desc="Epoch %i" % e,
                 mininterval=0.5)
        for batch_counter in t:

            # X_batch, _ = mnist.train.next_batch(FLAGS.batch_size)
            # if FLAGS.data_format == "NHWC":
            #     X_batch = np.reshape(X_batch, [-1, 28, 28, 1])
            # else:
            #     X_batch = np.reshape(X_batch, [-1, 1, 28, 28])
            # X_batch = (X_batch - 0.5) / 0.5

            X_batch = du.sample_batch(imgs, FLAGS.batch_size)
            output = sess.run(train_ops + loss_ops + [summary_op],
                              feed_dict={X_real: X_batch})

            if batch_counter % (FLAGS.nb_batch_per_epoch // 20) == 0:
                writer.add_summary(
                    output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)
            lossG, lossDreal, lossDfake = [output[2], output[4], output[5]]

            t.set_description(
                'Epoch %i: - G loss: %.2f D loss real: %.2f D loss fake: %.2f' %
                (e, lossG, lossDreal, lossDfake))

            # variables = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]
            # bmean = [v for v in variables if v.name == "generator/conv2D_1_1/BatchNorm/moving_mean:0"]
            # print sess.run(bmean)
            # raw_input()

        # Plot some generated images
        # output = sess.run([X_G_output, X_real_output])
        output = sess.run([X_G_output, X_real_output],
                          feed_dict={X_real: X_batch})
        vu.save_image(output, FLAGS.data_format, e)

        # Save session
        saver.save(sess, os.path.join(FLAGS.model_dir, "model"), global_step=e)

    print('Finished training!')
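
du.unnormalize_image is likewise a repository helper. Since the inputs are scaled to [-1, 1] via (x - 0.5) / 0.5, a reasonable sketch simply inverts that scaling; whether the real helper also rescales to [0, 255] for image saving is an assumption.

import tensorflow as tf

def unnormalize_image(X, name=None):
    # Hypothetical inverse of the (x - 0.5) / 0.5 normalisation used above,
    # mapping [-1, 1] back to the [0, 255] pixel range.
    X = (X + 1.) * 127.5
    return tf.identity(X, name=name)
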
Example #5
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Setup async input queue of real images
    X_real = du.input_data(sess)

    #######################
    # Instantiate generator
    #######################
    list_filters = [256, 128, 64, 3]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    output_shape = X_real.get_shape().as_list()[1:]
    G = models.Generator(list_filters,
                         list_kernel_size,
                         list_strides,
                         list_padding,
                         output_shape,
                         batch_size=FLAGS.batch_size,
                         data_format=FLAGS.data_format)

    ###########################
    # Instantiate discriminator
    ###########################
    list_filters = [32, 64, 128, 256]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    D = models.Discriminator(list_filters,
                             list_kernel_size,
                             list_strides,
                             list_padding,
                             FLAGS.batch_size,
                             data_format=FLAGS.data_format)

    ###########################
    # Instantiate optimizers
    ###########################
    G_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                   name='G_opt',
                                   beta1=0.5)
    D_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                   name='D_opt',
                                   beta1=0.5)

    ###########################
    # Instantiate model outputs
    ###########################

    # noise_input = tf.random_normal((FLAGS.batch_size, FLAGS.noise_dim,), stddev=0.1, name="noise_input")
    noise_input = tf.random_uniform((FLAGS.batch_size, FLAGS.noise_dim),
                                    minval=-1,
                                    maxval=1,
                                    name="noise_input")
    X_fake = G(noise_input)

    # output images
    X_G_output = du.unnormalize_image(X_fake, name="X_G_output")
    X_real_output = du.unnormalize_image(X_real, name="X_real_output")

    D_real = D(X_real)
    D_fake = D(X_fake, reuse=True)

    ###########################
    # Instantiate losses
    ###########################

    G_loss = objectives.binary_cross_entropy_with_logits(
        D_fake, tf.ones_like(D_fake))
    D_loss_real = objectives.binary_cross_entropy_with_logits(
        D_real, tf.ones_like(D_real))
    D_loss_fake = objectives.binary_cross_entropy_with_logits(
        D_fake, tf.zeros_like(D_fake))

    # G_loss = objectives.wasserstein(D_fake, -tf.ones_like(D_fake))
    # D_loss_real = objectives.wasserstein(D_real, -tf.ones_like(D_real))
    # D_loss_fake = objectives.wasserstein(D_fake, tf.ones_like(D_fake))

    D_loss = D_loss_real + D_loss_fake

    ###########################
    # Compute gradient updates
    ###########################

    dict_G_vars = G.get_trainable_variables()
    G_vars = [dict_G_vars[k] for k in dict_G_vars.keys()]

    dict_D_vars = D.get_trainable_variables()
    D_vars = [dict_D_vars[k] for k in dict_D_vars.keys()]

    G_gradvar = G_opt.compute_gradients(G_loss, var_list=G_vars)
    G_update = G_opt.apply_gradients(G_gradvar, name='G_loss_minimize')

    D_gradvar = D_opt.compute_gradients(D_loss, var_list=D_vars)
    D_update = D_opt.apply_gradients(D_gradvar, name='D_loss_minimize')

    # clip_op = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in D_vars]

    G_update = G_opt.minimize(G_loss, var_list=G_vars, name='G_loss_minimize')
    D_update = D_opt.minimize(D_loss, var_list=D_vars, name='D_loss_minimize')
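    # NB: the minimize() calls just above rebind G_update/D_update, superseding
    # the apply_gradients() ops defined earlier; G_gradvar/D_gradvar remain in
    # use only for the gradient summaries below.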

    ##########################
    # Group training ops
    ##########################
    train_ops = [G_update, D_update]
    loss_ops = [G_loss, D_loss, D_loss_real, D_loss_fake]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(G_gradvar)
    tu.add_gradient_summary(D_gradvar)

    # Add scalar summaries
    tf.summary.scalar("G loss", G_loss)
    tf.summary.scalar("D loss real", D_loss_real)
    tf.summary.scalar("D loss fake", D_loss_fake)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    saver = tu.initialize_session(sess)

    # Start queues
    coord = tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)

    # Run checks on data dimensions
    list_data = [noise_input, X_real, X_fake, X_G_output, X_real_output]
    output = sess.run([noise_input, X_real, X_fake, X_G_output, X_real_output])
    tu.check_data(output, list_data)

    for e in tqdm(range(FLAGS.nb_epoch), desc="Training progress"):

        list_G_loss = []
        list_D_loss_real = []
        list_D_loss_fake = []

        t = tqdm(range(FLAGS.nb_batch_per_epoch),
                 desc="Epoch %i" % e,
                 mininterval=0.5)
        for batch_counter in t:

            o_D = sess.run([D_update, D_loss_real, D_loss_fake])
            sess.run([G_update, G_loss])
            o_G = sess.run([G_update, G_loss])
            output = sess.run([summary_op])

            list_G_loss.append(o_G[-1])
            list_D_loss_real.append(o_D[-2])
            list_D_loss_fake.append(o_D[-1])

            # output = sess.run(train_ops + loss_ops + [summary_op])
            # list_G_loss.append(output[2])
            # list_D_loss_real.append(output[4])
            # list_D_loss_fake.append(output[5])

            if batch_counter % (FLAGS.nb_batch_per_epoch //
                                (int(0.5 * FLAGS.nb_batch_per_epoch))) == 0:
                writer.add_summary(
                    output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)

        t.set_description(
            'Epoch %i: - G loss: %.3f D loss real: %.3f D loss fake: %.3f' %
            (e, np.mean(list_G_loss), np.mean(list_D_loss_real),
             np.mean(list_D_loss_fake)))

        # Plot some generated images
        output = sess.run([X_G_output, X_real_output])
        vu.save_image(output, FLAGS.data_format, e)

        # Save session
        saver.save(sess, os.path.join(FLAGS.model_dir, "model"), global_step=e)

        if e == 0:
            print(len(list_data))
            output = sess.run(
                [noise_input, X_real, X_fake, X_G_output, X_real_output])
            tu.check_data(output, list_data)

    print('Finished training!')
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Placeholder for data and Mnist iterator
    mnist = input_data.read_data_sets(FLAGS.raw_dir, one_hot=True)
    assert FLAGS.data_format == "NCHW", "Scattering only implemented in NCHW"
    X_tensor = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, 1, 28, 28])
    y_tensor = tf.placeholder(tf.int64, shape=[FLAGS.batch_size, 10])

    with tf.device('/cpu:0'):
        X_train = mnist.train.images.astype(np.float32)
        y_train = mnist.train.labels.astype(np.int64)

        X_validation = mnist.validation.images.astype(np.float32)
        y_validation = mnist.validation.labels.astype(np.int64)

        X_train = (X_train - 0.5) / 0.5
        X_train = X_train.reshape((-1, 1, 28, 28))

        X_validation = (X_validation - 0.5) / 0.5
        X_validation = X_validation.reshape((-1, 1, 28, 28))

    # Build model
    class HybridCNN(models.Model):

        def __call__(self, x, reuse=False):
            with tf.variable_scope(self.name) as scope:

                if reuse:
                    scope.reuse_variables()

                M, N = x.get_shape().as_list()[-2:]
                x = scattering.Scattering(M=M, N=N, J=2)(x)
                x = tf.contrib.layers.batch_norm(x, data_format=FLAGS.data_format, fused=True, scope="scat_bn")
                x = layers.conv2d_block("CONV2D", x, 64, 1, 1, p="SAME", data_format=FLAGS.data_format, bias=True, bn=False, activation_fn=tf.nn.relu)

                target_shape = (-1, 64 * 7 * 7)
                x = layers.reshape(x, target_shape)
                x = layers.linear(x, 512, name="dense1")
                x = tf.nn.relu(x)
                x = layers.linear(x, 10, name="dense2")

                return x

    HCNN = HybridCNN("HCNN")
    y_pred = HCNN(X_tensor)

    ###########################
    # Instantiate optimizers
    ###########################
    opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='opt', beta1=0.5)

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_tensor, logits=y_pred))
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_tensor, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    ###########################
    # Compute gradient updates
    ###########################

    dict_vars = HCNN.get_trainable_variables()
    all_vars = [dict_vars[k] for k in dict_vars.keys()]

    gradvar = opt.compute_gradients(loss, var_list=all_vars, colocate_gradients_with_ops=True)
    update = opt.apply_gradients(gradvar, name='loss_minimize')

    ##########################
    # Group training ops
    ##########################
    train_ops = [update]
    loss_ops = [loss, accuracy]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(gradvar)

    # Add scalar summaries
    tf.summary.scalar("loss", loss)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    tu.initialize_session(sess)

    # Start queues
    tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)

    for e in tqdm(range(FLAGS.nb_epoch), desc="Training progress"):

        t = tqdm(range(FLAGS.nb_batch_per_epoch), desc="Epoch %i" % e, mininterval=0.5)
        for batch_counter in t:

            # Get training data
            X_train_batch, y_train_batch = du.sample_batch(X_train, y_train, FLAGS.batch_size)

            # Run update and get loss
            output = sess.run(train_ops + loss_ops + [summary_op], feed_dict={X_tensor: X_train_batch,
                                                                              y_tensor: y_train_batch})
            train_loss = output[1]
            train_acc = output[2]

            # Write summaries
            if batch_counter % (FLAGS.nb_batch_per_epoch // 20) == 0:
                writer.add_summary(output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)

            # Get validation data
            X_validation_batch, y_validation_batch = du.sample_batch(X_validation, y_validation, FLAGS.batch_size)

            # Run update and get loss
            output = sess.run(loss_ops, feed_dict={X_tensor: X_validation_batch,
                                                   y_tensor: y_validation_batch})
            validation_loss = output[0]
            validation_acc = output[1]

            t.set_description('Epoch %i: - train loss: %.2f val loss: %.2f - train acc: %.2f val acc: %.2f' %
                              (e, train_loss, validation_loss, train_acc, validation_acc))

    print('Finished training!')
Example #7
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Setup async input queue of real images
    X_real = du.input_data(sess)

    #######################
    # Instantiate generator
    #######################
    list_filters = [256, 128, 64, 3]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    output_shape = X_real.get_shape().as_list()[1:]
    G = models.Generator(list_filters,
                         list_kernel_size,
                         list_strides,
                         list_padding,
                         output_shape,
                         batch_size=FLAGS.batch_size,
                         data_format=FLAGS.data_format)

    ###########################
    # Instantiate discriminator
    ###########################
    list_filters = [32, 64, 128, 256]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    D = models.Discriminator(list_filters,
                             list_kernel_size,
                             list_strides,
                             list_padding,
                             FLAGS.batch_size,
                             data_format=FLAGS.data_format)

    ###########################
    # Instantiate optimizers
    ###########################
    G_opt = tf.train.AdamOptimizer(learning_rate=1E-4,
                                   name='G_opt',
                                   beta1=0.5,
                                   beta2=0.9)
    D_opt = tf.train.AdamOptimizer(learning_rate=1E-4,
                                   name='D_opt',
                                   beta1=0.5,
                                   beta2=0.9)

    ###########################
    # Instantiate model outputs
    ###########################

    # noise_input = tf.random_normal((FLAGS.batch_size, FLAGS.noise_dim,), stddev=0.1)
    noise_input = tf.random_uniform((FLAGS.batch_size, FLAGS.noise_dim),
                                    minval=-1,
                                    maxval=1)
    X_fake = G(noise_input)

    # output images
    X_G_output = du.unnormalize_image(X_fake)
    X_real_output = du.unnormalize_image(X_real)

    D_real = D(X_real)
    D_fake = D(X_fake, reuse=True)

    ###########################
    # Instantiate losses
    ###########################

    G_loss = -tf.reduce_mean(D_fake)
    D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)

    epsilon = tf.random_uniform(shape=[FLAGS.batch_size, 1, 1, 1],
                                minval=0.,
                                maxval=1.)
    X_hat = X_real + epsilon * (X_fake - X_real)
    D_X_hat = D(X_hat, reuse=True)
    grad_D_X_hat = tf.gradients(D_X_hat, [X_hat])[0]
    if FLAGS.data_format == "NCHW":
        red_idx = [1]
    else:
        red_idx = [-1]
    slopes = tf.sqrt(
        tf.reduce_sum(tf.square(grad_D_X_hat), reduction_indices=red_idx))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
    D_loss += 10 * gradient_penalty

    ###########################
    # Compute gradient updates
    ###########################

    dict_G_vars = G.get_trainable_variables()
    G_vars = [dict_G_vars[k] for k in dict_G_vars.keys()]

    dict_D_vars = D.get_trainable_variables()
    D_vars = [dict_D_vars[k] for k in dict_D_vars.keys()]

    G_gradvar = G_opt.compute_gradients(G_loss,
                                        var_list=G_vars,
                                        colocate_gradients_with_ops=True)
    G_update = G_opt.apply_gradients(G_gradvar, name='G_loss_minimize')

    D_gradvar = D_opt.compute_gradients(D_loss,
                                        var_list=D_vars,
                                        colocate_gradients_with_ops=True)
    D_update = D_opt.apply_gradients(D_gradvar, name='D_loss_minimize')

    ##########################
    # Group training ops
    ##########################
    loss_ops = [G_loss, D_loss]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(G_gradvar)
    tu.add_gradient_summary(D_gradvar)

    # Add scalar summaries
    tf.summary.scalar("G loss", G_loss)
    tf.summary.scalar("D loss", D_loss)
    tf.summary.scalar("gradient_penalty", gradient_penalty)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    saver = tu.initialize_session(sess)

    # Start queues
    tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)

    for e in tqdm(range(FLAGS.nb_epoch), desc="Training progress"):

        t = tqdm(range(FLAGS.nb_batch_per_epoch),
                 desc="Epoch %i" % e,
                 mininterval=0.5)
        for batch_counter in t:

            for di in range(5):
                sess.run([D_update])

            output = sess.run([G_update] + loss_ops + [summary_op])

            if batch_counter % (FLAGS.nb_batch_per_epoch // 20) == 0:
                writer.add_summary(
                    output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)

            t.set_description('Epoch %i' % e)

        # Plot some generated images
        output = sess.run([X_G_output, X_real_output])
        vu.save_image(output, FLAGS.data_format, e)

        # Save session
        saver.save(sess, os.path.join(FLAGS.model_dir, "model"), global_step=e)

    print('Finished training!')
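
Both WGAN-GP examples inline the gradient-penalty computation. As a self-contained sketch it can be factored out as below; note that this version sums the squared critic gradients over all non-batch axes, which is the standard WGAN-GP formulation, whereas the snippets above reduce over a single axis.

import tensorflow as tf

def gradient_penalty(D, X_real, X_fake, batch_size, lambda_gp=10.0):
    # Evaluate the critic on random interpolates between real and fake samples.
    epsilon = tf.random_uniform([batch_size, 1, 1, 1], minval=0., maxval=1.)
    X_hat = X_real + epsilon * (X_fake - X_real)
    grads = tf.gradients(D(X_hat, reuse=True), [X_hat])[0]
    # Penalise deviation of the per-sample gradient norm from 1.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    return lambda_gp * tf.reduce_mean((slopes - 1.) ** 2)
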
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Setup async input queue of real images
    X_real = du.input_data(sess)

    #######################
    # Instantiate generator
    #######################
    list_filters = [256, 128, 64, 3]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    output_shape = X_real.get_shape().as_list()[1:]
    G = models.Generator(list_filters, list_kernel_size, list_strides, list_padding, output_shape,
                         batch_size=FLAGS.batch_size, data_format=FLAGS.data_format)

    ###########################
    # Instantiate discriminator
    ###########################
    list_filters = [32, 64, 128, 256]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    D = models.Discriminator(list_filters, list_kernel_size, list_strides, list_padding,
                             FLAGS.batch_size, data_format=FLAGS.data_format)

    ###########################
    # Instantiate optimizers
    ###########################
    G_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='G_opt', beta1=0.5)
    D_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, name='D_opt', beta1=0.5)

    ###########################
    # Instantiate model outputs
    ###########################

    # noise_input = tf.random_normal((FLAGS.batch_size, FLAGS.noise_dim,), stddev=0.1, name="noise_input")
    noise_input = tf.random_uniform((FLAGS.batch_size, FLAGS.noise_dim,), minval=-1, maxval=1, name="noise_input")
    X_fake = G(noise_input)

    # output images
    X_G_output = du.unnormalize_image(X_fake, name="X_G_output")
    X_real_output = du.unnormalize_image(X_real, name="X_real_output")

    D_real = D(X_real)
    D_fake = D(X_fake, reuse=True)

    ###########################
    # Instantiate losses
    ###########################

    G_loss = objectives.binary_cross_entropy_with_logits(D_fake, tf.ones_like(D_fake))
    D_loss_real = objectives.binary_cross_entropy_with_logits(D_real, tf.ones_like(D_real))
    D_loss_fake = objectives.binary_cross_entropy_with_logits(D_fake, tf.zeros_like(D_fake))

    # G_loss = objectives.wasserstein(D_fake, -tf.ones_like(D_fake))
    # D_loss_real = objectives.wasserstein(D_real, -tf.ones_like(D_real))
    # D_loss_fake = objectives.wasserstein(D_fake, tf.ones_like(D_fake))

    D_loss = D_loss_real + D_loss_fake

    ###########################
    # Compute gradient updates
    ###########################

    dict_G_vars = G.get_trainable_variables()
    G_vars = [dict_G_vars[k] for k in dict_G_vars.keys()]

    dict_D_vars = D.get_trainable_variables()
    D_vars = [dict_D_vars[k] for k in dict_D_vars.keys()]

    G_gradvar = G_opt.compute_gradients(G_loss, var_list=G_vars)
    G_update = G_opt.apply_gradients(G_gradvar, name='G_loss_minimize')

    D_gradvar = D_opt.compute_gradients(D_loss, var_list=D_vars)
    D_update = D_opt.apply_gradients(D_gradvar, name='D_loss_minimize')

    # clip_op = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in D_vars]

    G_update = G_opt.minimize(G_loss, var_list=G_vars, name='G_loss_minimize')
    D_update = D_opt.minimize(D_loss, var_list=D_vars, name='D_loss_minimize')
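    # NB: the minimize() calls just above rebind G_update/D_update, superseding
    # the apply_gradients() ops defined earlier; G_gradvar/D_gradvar remain in
    # use only for the gradient summaries below.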

    ##########################
    # Group training ops
    ##########################
    train_ops = [G_update, D_update]
    loss_ops = [G_loss, D_loss, D_loss_real, D_loss_fake]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(G_gradvar)
    tu.add_gradient_summary(D_gradvar)

    # Add scalar summaries
    tf.summary.scalar("G loss", G_loss)
    tf.summary.scalar("D loss real", D_loss_real)
    tf.summary.scalar("D loss fake", D_loss_fake)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    saver = tu.initialize_session(sess)

    # Start queues
    coord = tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)

    # Run checks on data dimensions
    list_data = [noise_input, X_real, X_fake, X_G_output, X_real_output]
    output = sess.run([noise_input, X_real, X_fake, X_G_output, X_real_output])
    tu.check_data(output, list_data)

    for e in tqdm(range(FLAGS.nb_epoch), desc="Training progress"):

        list_G_loss = []
        list_D_loss_real = []
        list_D_loss_fake = []

        t = tqdm(range(FLAGS.nb_batch_per_epoch), desc="Epoch %i" % e, mininterval=0.5)
        for batch_counter in t:

            o_D = sess.run([D_update, D_loss_real, D_loss_fake])
            sess.run([G_update, G_loss])
            o_G = sess.run([G_update, G_loss])
            output = sess.run([summary_op])

            list_G_loss.append(o_G[-1])
            list_D_loss_real.append(o_D[-2])
            list_D_loss_fake.append(o_D[-1])

            # output = sess.run(train_ops + loss_ops + [summary_op])
            # list_G_loss.append(output[2])
            # list_D_loss_real.append(output[4])
            # list_D_loss_fake.append(output[5])

            if batch_counter % (FLAGS.nb_batch_per_epoch // (int(0.5 * FLAGS.nb_batch_per_epoch))) == 0:
                writer.add_summary(output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)

            t.set_description('Epoch %i: - G loss: %.3f D loss real: %.3f D loss fake: %.3f' %
                              (e, np.mean(list_G_loss), np.mean(list_D_loss_real), np.mean(list_D_loss_fake)))

        # Plot some generated images
        output = sess.run([X_G_output, X_real_output])
        vu.save_image(output, FLAGS.data_format, e)

        # Save session
        saver.save(sess, os.path.join(FLAGS.model_dir, "model"), global_step=e)

        if e == 0:
            print(len(list_data))
            output = sess.run([noise_input, X_real, X_fake, X_G_output, X_real_output])
            tu.check_data(output, list_data)

    print('Finished training!')
def train_model():

    # Setup session
    sess = tu.setup_session()

    # Setup async input queue of real images
    X_real = du.input_data(sess)

    #######################
    # Instantiate generator
    #######################
    list_filters = [256, 128, 64, 3]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    output_shape = X_real.get_shape().as_list()[1:]
    G = models.Generator(list_filters, list_kernel_size, list_strides, list_padding, output_shape,
                         batch_size=FLAGS.batch_size, data_format=FLAGS.data_format)

    ###########################
    # Instantiate discriminator
    ###########################
    list_filters = [32, 64, 128, 256]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    D = models.Discriminator(list_filters, list_kernel_size, list_strides, list_padding,
                             FLAGS.batch_size, data_format=FLAGS.data_format)

    ###########################
    # Instantiate optimizers
    ###########################
    G_opt = tf.train.AdamOptimizer(learning_rate=1E-4, name='G_opt', beta1=0.5, beta2=0.9)
    D_opt = tf.train.AdamOptimizer(learning_rate=1E-4, name='D_opt', beta1=0.5, beta2=0.9)

    ###########################
    # Instantiate model outputs
    ###########################

    # noise_input = tf.random_normal((FLAGS.batch_size, FLAGS.noise_dim,), stddev=0.1)
    noise_input = tf.random_uniform((FLAGS.batch_size, FLAGS.noise_dim,), minval=-1, maxval=1)
    X_fake = G(noise_input)

    # output images
    X_G_output = du.unnormalize_image(X_fake)
    X_real_output = du.unnormalize_image(X_real)

    D_real = D(X_real)
    D_fake = D(X_fake, reuse=True)

    ###########################
    # Instantiate losses
    ###########################

    G_loss = -tf.reduce_mean(D_fake)
    D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)

    epsilon = tf.random_uniform(
        shape=[FLAGS.batch_size, 1, 1, 1],
        minval=0.,
        maxval=1.
    )
    X_hat = X_real + epsilon * (X_fake - X_real)
    D_X_hat = D(X_hat, reuse=True)
    grad_D_X_hat = tf.gradients(D_X_hat, [X_hat])[0]
    if FLAGS.data_format == "NCHW":
        red_idx = [1]
    else:
        red_idx = [-1]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_X_hat), reduction_indices=red_idx))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
    D_loss += 10 * gradient_penalty

    ###########################
    # Compute gradient updates
    ###########################

    dict_G_vars = G.get_trainable_variables()
    G_vars = [dict_G_vars[k] for k in dict_G_vars.keys()]

    dict_D_vars = D.get_trainable_variables()
    D_vars = [dict_D_vars[k] for k in dict_D_vars.keys()]

    G_gradvar = G_opt.compute_gradients(G_loss, var_list=G_vars, colocate_gradients_with_ops=True)
    G_update = G_opt.apply_gradients(G_gradvar, name='G_loss_minimize')

    D_gradvar = D_opt.compute_gradients(D_loss, var_list=D_vars, colocate_gradients_with_ops=True)
    D_update = D_opt.apply_gradients(D_gradvar, name='D_loss_minimize')

    ##########################
    # Group training ops
    ##########################
    loss_ops = [G_loss, D_loss]

    ##########################
    # Summary ops
    ##########################

    # Add summary for gradients
    tu.add_gradient_summary(G_gradvar)
    tu.add_gradient_summary(D_gradvar)

    # Add scalar summaries
    tf.summary.scalar("G loss", G_loss)
    tf.summary.scalar("D loss", D_loss)
    tf.summary.scalar("gradient_penalty", gradient_penalty)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    saver = tu.initialize_session(sess)

    # Start queues
    tu.manage_queues(sess)

    # Summaries
    writer = tu.manage_summaries(sess)

    for e in tqdm(range(FLAGS.nb_epoch), desc="Training progress"):

        t = tqdm(range(FLAGS.nb_batch_per_epoch), desc="Epoch %i" % e, mininterval=0.5)
        for batch_counter in t:

            for di in range(5):
                sess.run([D_update])

            output = sess.run([G_update] + loss_ops + [summary_op])

            if batch_counter % (FLAGS.nb_batch_per_epoch // 20) == 0:
                writer.add_summary(output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)

            t.set_description('Epoch %i' % e)

        # Plot some generated images
        output = sess.run([X_G_output, X_real_output])
        vu.save_image(output, FLAGS.data_format, e)

        # Save session
        saver.save(sess, os.path.join(FLAGS.model_dir, "model"), global_step=e)

    print('Finished training!')