Example #1
    def _run(self, images_path=None):
        idx = 0  # model index
        sess = tf.Session()
        # num_gen and dim are assumed to be module-level settings
        generator_x = sess.run(sample_noise(num_gen, dim))
        # model[idx] appears to pack (session, noise placeholder, ...,
        # output tensor); slot 4 is fetched, slot 1 is fed the noise.
        predict = self.model[idx][0].run(
            self.model[idx][4],
            feed_dict={self.model[idx][1]: generator_x})
        print('predict:', predict)

        return predict.tolist()
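The `sample_noise` helper used above is not shown in this listing. A minimal sketch, assuming it returns uniform noise in [-1, 1], a common choice for GAN generator inputs; the name and signature are taken from the call above:

def sample_noise(batch_size, dim):
    # Assumed helper: uniform noise in [-1, 1] of shape [batch_size, dim].
    return tf.random_uniform([batch_size, dim], minval=-1, maxval=1)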
Example #2
    def _run(self, images_path=None):
        idx = 0  # model index
        sess = tf.Session()
        generator_x = sess.run(sample_noise(num_gen, dim))
        num_classes = config.num_classes

        def sample_label():
            # One-hot label vectors cycling through the first four classes.
            num = num_gen
            label_vector = np.zeros((num, num_classes), dtype=np.float32)
            for i in range(num):
                label_vector[i, i % 4] = 1.0
            return label_vector

        label = sample_label()
        # Conditional variant: feed both noise and one-hot labels;
        # slot 5 is the tensor to fetch.
        predict = self.model[idx][0].run(
            self.model[idx][5],
            feed_dict={self.model[idx][1]: generator_x,
                       self.model[idx][2]: label})
        print('predict:', predict)

        return predict.tolist()
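`train_GANs` below depends on placeholder-construction helpers that are external to this listing. A plausible sketch, assuming 3-channel images and one-hot labels; the shapes are guesses from how the placeholders are fed later in the function:

def generator_input_placeholder(dim, num_classes):
    # Assumed shapes: a noise vector and an (optional) one-hot label.
    X = tf.placeholder(tf.float32, [None, dim], name='G_X')
    Y = tf.placeholder(tf.float32, [None, num_classes], name='G_Y')
    is_train = tf.placeholder(tf.bool, name='G_is_train')
    keep_prob_fc = tf.placeholder(tf.float32, name='G_keep_prob_fc')
    return X, Y, is_train, keep_prob_fc


def discriminator_input_placeholder(height, width, num_classes):
    # Assumed shapes: RGB image batch and an (optional) one-hot label.
    X = tf.placeholder(tf.float32, [None, height, width, 3], name='D_X')
    Y = tf.placeholder(tf.float32, [None, num_classes], name='D_Y')
    is_train = tf.placeholder(tf.bool, name='D_is_train')
    keep_prob_fc = tf.placeholder(tf.float32, name='D_keep_prob_fc')
    return X, Y, is_train, keep_prob_fc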
def train_GANs(train_data,
               train_label,
               valid_data,
               valid_label,
               train_dir,
               num_classes,
               batch_size,
               arch_model,
               learning_r_decay,
               learning_rate_base,
               decay_rate,
               dropout_prob,
               epoch,
               height,
               width,
               checkpoint_exclude_scopes,
               early_stop,
               EARLY_STOP_PATIENCE,
               fine_tune,
               train_all_layers,
               checkpoint_path,
               train_n,
               valid_n,
               g_parameter,
               dim=64):
    # ---------------------------------------------------------------------------------#
    G_X, G_Y, G_is_train, G_keep_prob_fc = generator_input_placeholder(
        dim, num_classes)
    G_net, _ = build_generator(G_X, num_classes, G_keep_prob_fc, G_is_train,
                               arch_model)

    D_X, D_Y, D_is_train, D_keep_prob_fc = discriminator_input_placeholder(
        height, width, num_classes)
    with tf.variable_scope("") as scope:
        logits_real, _ = build_discriminator(D_X, num_classes, D_keep_prob_fc,
                                             D_is_train, arch_model)
        scope.reuse_variables()
        logits_fake, _ = build_discriminator(G_net, num_classes,
                                             D_keep_prob_fc, D_is_train,
                                             arch_model)

    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'Generator')
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'Discriminator')

    D_loss, G_loss = gan_loss(logits_real, logits_fake)

    global_step = tf.Variable(0, trainable=False)
    if learning_r_decay:
        learning_rate = tf.train.exponential_decay(learning_rate_base,
                                                   global_step * batch_size,
                                                   train_n,
                                                   decay_rate,
                                                   staircase=True)
    else:
        learning_rate = learning_rate_base

    G_optimizer = train_op(learning_rate, G_loss, G_vars, global_step)
    D_optimizer = train_op(learning_rate * 0.1, D_loss, D_vars, global_step)
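    # NOTE: if train_op passes global_step to minimize(), the counter
    # advances twice per batch (once for D, once for G), so the decay
    # schedule above fires faster than train_n alone would suggest.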

    #------------------------------------------------------------------------------------#
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    saver2 = tf.train.Saver(G_vars)
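    # saver2 tracks only the generator variables, so the checkpoints
    # written below contain just the generator.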
    if not train_all_layers:
        saver_net = tf.train.Saver(G_vars)
        saver_net.restore(sess, checkpoint_path)

    if fine_tune:
        # saver2.restore(sess, fine_tune_dir)
        latest = tf.train.latest_checkpoint(train_dir)
        if not latest:
            print("No checkpoint to continue from in", train_dir)
            sys.exit(1)
        print("resume", latest)
        saver2.restore(sess, latest)

    # early stopping
    best_valid = np.inf
    best_valid_epoch = 0

    for epoch_i in range(epoch):
        for batch_i in range(int(train_n / batch_size)):
            generator_x = sess.run(sample_noise(batch_size, dim))
            images = get_next_batch_from_path(train_data,
                                              train_label,
                                              batch_i,
                                              height,
                                              width,
                                              batch_size=batch_size,
                                              training=True)
            D_los, _ = sess.run(
                [D_loss, D_optimizer],
                feed_dict={
                    G_X: generator_x,
                    D_X: images,
                    D_is_train: True,
                    D_keep_prob_fc: dropout_prob
                })
            G_los, _ = sess.run([G_loss, G_optimizer],
                                feed_dict={
                                    G_X: generator_x,
                                    G_is_train: True,
                                    G_keep_prob_fc: dropout_prob
                                })
            print('D_los:', D_los)
            print('G_los:', G_los)
            # Save generator variables each batch; since global_step=batch_i
            # restarts every epoch, checkpoints are overwritten across epochs.
            save_path = os.path.join(train_dir, 'model.ckpt')
            saver2.save(sess,
                        save_path,
                        global_step=batch_i,
                        write_meta_graph=False)
            if batch_i % 20 == 0:
                D_loss_ = sess.run(D_loss,
                                   feed_dict={
                                       G_X: generator_x,
                                       D_X: images,
                                       D_is_train: False,
                                       D_keep_prob_fc: 1.0
                                   })
                G_loss_ = sess.run(G_loss,
                                   feed_dict={
                                       G_X: generator_x,
                                       G_is_train: False,
                                       G_keep_prob_fc: 1.0
                                   })
                print('Batch: {:>2}: D_training loss: {:>3.5f}'.format(
                    batch_i, D_loss_))
                print('Batch: {:>2}: G_training loss: {:>3.5f}'.format(
                    batch_i, G_loss_))

            if batch_i % 100 == 0:
                generator_x = sess.run(sample_noise(batch_size, dim))
                images = get_next_batch_from_path(valid_data,
                                                  valid_label,
                                                  batch_i %
                                                  int(valid_n / batch_size),
                                                  height,
                                                  width,
                                                  batch_size=batch_size,
                                                  training=False)
                D_ls = sess.run(D_loss,
                                feed_dict={
                                    G_X: generator_x,
                                    D_X: images,
                                    D_is_train: False,
                                    D_keep_prob_fc: 1.0
                                })
                G_ls = sess.run(G_loss,
                                feed_dict={
                                    G_X: generator_x,
                                    G_is_train: False,
                                    G_keep_prob_fc: 1.0
                                })
                print('Batch: {:>2}: D_validation loss: {:>3.5f}'.format(
                    batch_i, D_ls))
                print('Batch: {:>2}: G_validation loss: {:>3.5f}'.format(
                    batch_i, G_ls))

        print(
            'Epoch===================================>: {:>2}'.format(epoch_i))
        G_valid_ls = 0
        G_samples = 0
        for batch_i in range(int(valid_n / batch_size)):
            generator_x = sess.run(sample_noise(batch_size, dim))
            G_epoch_ls, G_samples = sess.run([G_loss, G_net],
                                             feed_dict={
                                                 G_X: generator_x,
                                                 G_keep_prob_fc: 1.0,
                                                 G_is_train: False
                                             })
            G_valid_ls = G_valid_ls + G_epoch_ls
        show_images(G_samples[:16])
        plt.show()
        print('Epoch: {:>2}: G_validation loss: {:>3.5f}'.format(
            epoch_i, G_valid_ls / int(valid_n / batch_size)))
        # ---------------------------------------------------------------------------------#
        if early_stop:
            loss_valid = G_valid_ls / int(valid_n / batch_size)
            if loss_valid < best_valid:
                best_valid = loss_valid
                best_valid_epoch = epoch_i
            elif best_valid_epoch + EARLY_STOP_PATIENCE < epoch_i:
                print("Early stopping.")
                print("Best valid loss was {:.6f} at epoch {}.".format(
                    best_valid, best_valid_epoch))
                break
        train_data, train_label = shuffle_train_data(train_data, train_label)
    sess.close()
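`gan_loss` and `train_op` are likewise external to this listing. A sketch of what they might look like for a vanilla GAN, assuming single-logit discriminator outputs and an Adam optimizer; the names and signatures are taken from the calls above:

def gan_loss(logits_real, logits_fake):
    # Standard non-saturating GAN loss via sigmoid cross-entropy:
    # D pushes real logits toward 1 and fake logits toward 0;
    # G pushes fake logits toward 1.
    D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(logits_real), logits=logits_real))
    D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(logits_fake), logits=logits_fake))
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(logits_fake), logits=logits_fake))
    return D_loss, G_loss


def train_op(learning_rate, loss, var_list, global_step):
    # One optimizer step restricted to a single network's variables.
    return tf.train.AdamOptimizer(learning_rate, beta1=0.5).minimize(
        loss, var_list=var_list, global_step=global_step)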