Example No. 1
import os
import pathlib

import numpy as np
import scipy.misc
import tensorflow as tf

import fid


def calculate_fid(path, stats_path, inception_path, use_unbatched):
    if not os.path.exists(stats_path):
        raise RuntimeError("Invalid inception-statistics file")
    # check_or_download_inception is a helper from the surrounding script
    # (not shown here) that resolves a local path to the Inception model.
    inception_path = check_or_download_inception(inception_path)

    # Load every JPEG/PNG image in the directory into one float32 array.
    path = pathlib.Path(path)
    files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
    x = np.array(
        [scipy.misc.imread(str(fn)).astype(np.float32) for fn in files])

    # Build the Inception graph and load the precomputed reference
    # statistics (mu, sigma) of the real data set.
    fid.create_incpetion_graph(str(inception_path))
    sigma, mu = fid.load_stats(stats_path)
    jpeg_tuple = fid.get_jpeg_encoder_tuple()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        query_tensor = fid.get_Fid_query_tensor(sess)
        if use_unbatched:
            # Feed the images one at a time (lower memory footprint).
            fid_value = fid.FID_unbatched(x, query_tensor, mu, sigma,
                                          jpeg_tuple, sess)
        else:
            # Run the images through Inception in batches, then compute
            # the FID from the pooled activations.
            pred_array = fid.get_predictions(x,
                                             query_tensor,
                                             sess,
                                             batch_size=128)
            fid_value, _, _ = fid.FID(pred_array, mu, sigma, sess)
        return fid_value
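
For reference, fid.FID computes the Fréchet distance between two Gaussians fitted to the Inception activations of the two image sets. A minimal NumPy sketch of that formula (the function name is illustrative, and the library's numerical safeguards are omitted):

import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2):
    # d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1 . sigma2))
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # sqrtm may return tiny imaginary parts
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)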
Example No. 2


import tensorflow as tf

import fid

# load inference model (MODEL_PATH: path to the Inception graph def,
# defined elsewhere in the script)
fid.create_incpetion_graph(MODEL_PATH)

# get tuple for jpeg encoding
jpeg_tuple = fid.get_jpeg_encoder_tuple()

# batch size for batched version
batch_size = 500
init = tf.global_variables_initializer()
sess = tf.Session()
with sess.as_default():
    sess.run(init)
    query_tensor = fid.get_Fid_query_tensor(sess)

    #
    # calculate statistics for the batched version
    #
    # X is assumed to be a data set object defined elsewhere in the script.
    sigma_b, mu_b = fid.precalc_stats_batched(X.get_data().reshape(-1, 64, 64, 3),
                                              query_tensor,
                                              sess,
                                              batch_size=batch_size,
                                              verbouse=True)
    # save statistics of the batched version
    fid.save_stats(sigma_b, mu_b, "stats_b.pkl.gz")
    # load the saved statistics back
    (sigma_b_loaded, mu_b_loaded) = fid.load_stats("stats_b.pkl.gz")
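
A quick sanity check of the save/load round trip above, using plain NumPy (nothing library-specific is assumed):

import numpy as np

# The loaded statistics should match the ones just computed.
assert np.allclose(sigma_b, sigma_b_loaded)
assert np.allclose(mu_b, mu_b_loaded)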

Example No. 3
    def train(self):

        print("load train stats..", end="")
        sigma_trn, mu_trn = fid.load_stats(self.train_stats_file)
        print("ok")

        z_fixed = np.random.uniform(-1, 1, size=(self.batch_size, self.z_num))

        x_fixed = self.get_image_from_loader()
        save_image(x_fixed, '{}/x_fixed.png'.format(self.model_dir))

        prev_measure = 1
        measure_history = deque([0] * self.lr_update_step, self.lr_update_step)

        # load inference model
        fid.create_incpetion_graph(
            "inception-2015-12-05/classify_image_graph_def.pb")

        query_tensor = fid.get_Fid_query_tensor(self.sess)

        if self.load_checkpoint:
            if self.load(self.model_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        # Preallocate the prediction array for the FID evaluation
        # (2048 = dimensionality of Inception's pool_3 features).
        pred_arr = np.ones([self.batch_size_eval, 2048])

        for step in trange(self.start_step, self.max_step):

            # Optimize
            self.sess.run([self.d_optim, self.g_optim])

            # Feed dict
            fetch_dict = {"measure": self.measure}

            if self.update_k:
                fetch_dict.update({"k_update": self.k_update})

            if step % self.log_step == 0:
                fetch_dict.update({
                    "summary": self.summary_op,
                    "g_loss": self.g_loss,
                    "d_loss": self.d_loss,
                    "k_t": self.k_t,
                })

            # Get summaries
            result = self.sess.run(fetch_dict)

            measure = result['measure']
            measure_history.append(measure)

            if step % self.log_step == 0:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

                g_loss = result['g_loss']
                d_loss = result['d_loss']
                k_t = result['k_t']

                print("[{}/{}] Loss_D: {:.6f} Loss_G: {:.6f} measure: {:.4f}, k_t: {:.4f}". \
                      format(step, self.max_step, d_loss, g_loss, measure, k_t))

            if step % (self.log_step * 10) == 0:
                x_fake = self.generate(z_fixed, self.model_dir, idx=step)
                self.autoencode(x_fixed,
                                self.model_dir,
                                idx=step,
                                x_fake=x_fake)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            # FID
            if step % self.eval_step == 0:

                eval_batches_num = self.batch_size_eval // self.eval_batch_size

                for eval_batch in range(eval_batches_num):

                    print("\rFID batch %d/%d" %
                          (eval_batch + 1, eval_batches_num),
                          end="",
                          flush=True)

                    sample_z_eval = np.random.uniform(
                        -1, 1, size=(self.eval_batch_size, self.z_num))
                    samples_eval = self.generate(sample_z_eval,
                                                 self.model_dir,
                                                 save=False)

                    pred_array = fid.get_predictions(
                        samples_eval,
                        query_tensor,
                        self.sess,
                        batch_size=self.eval_batch_size,
                        verbose=False)

                    frm = eval_batch * self.eval_batch_size
                    to = frm + self.eval_batch_size
                    pred_arr[frm:to, :] = pred_array

                print()

                # calculate FID
                print("FID", end=" ", flush=True)
                FID, _, _ = fid.FID(pred_arr, mu_trn, sigma_trn, self.sess)

                if FID is None:
                    FID = 500  # Something went wrong
                print(FID)
                self.sess.run(tf.assign(self.fid, FID))
                summary_str = self.sess.run(self.fid_sum)
                self.summary_writer.add_summary(summary_str, step)
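
The last lines above use a common TF1 idiom for getting a metric that is computed outside the graph into TensorBoard: a non-trainable variable plus a scalar summary (both created in build_model, see Example No. 4), updated via tf.assign. A standalone sketch of that idiom, with illustrative names:

import tensorflow as tf

fid_var = tf.Variable(0.0, trainable=False, name="fid")
fid_summary = tf.summary.scalar("FID", fid_var)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter("logs")
    # Push the externally computed value into the graph, then log it.
    sess.run(tf.assign(fid_var, 42.0))
    writer.add_summary(sess.run(fid_summary), global_step=0)
    writer.flush()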
Example No. 4
    def build_model(self):

        # get pool_3 layer
        self.querry_tensor = fid.get_Fid_query_tensor(self.sess)

        # Learning rate
        self.learning_rate_d = tf.Variable(0.0, trainable=False)
        self.learning_rate_g = tf.Variable(0.0, trainable=False)

        # Placeholders

        if self.is_crop:
            image_dims = [self.output_height, self.output_width, self.c_dim]
        else:
            image_dims = [self.input_height, self.input_width, self.c_dim]

        self.inputs = tf.placeholder(tf.float32,
                                     [self.batch_size] + image_dims,
                                     name='real_images')
        self.sample_inputs = tf.placeholder(tf.float32,
                                            [self.sample_num] + image_dims,
                                            name='sample_inputs')

        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
        self.z_sum = tf.summary.histogram("z", self.z)

        self.fid = tf.Variable(0.0, trainable=False)
        self.z_dist = tf.placeholder(tf.float32, [None, self.z_dim],
                                     name='z_dist')

        # Inputs
        inputs = self.inputs
        sample_inputs = self.sample_inputs

        # Discriminator and generator
        if self.y_dim:
            pass  # conditional (class-labeled) variant omitted in this example
        else:
            self.G = self.generator(self.z, batch_size=self.batch_size)
            self.D_real, self.D_logits_real = self.discriminator(inputs)

            self.sampler_dist = self.sampler_func(self.z_dist,
                                                  self.batch_size_dist)
            self.sampler = self.sampler_func(self.z, self.batch_size)
            self.D_fake, self.D_logits_fake = self.discriminator(self.G,
                                                                 reuse=True)

        # Summaries
        self.d_real_sum = tf.summary.histogram("d_real", self.D_real)
        self.d_fake_sum = tf.summary.histogram("d_fake", self.D_fake)
        self.G_sum = tf.summary.image("G", self.G)

        def sigmoid_cross_entropy_with_logits(x, y):
            # Compatibility shim: newer TF releases take `labels`,
            # older ones took `targets`.
            try:
                return tf.nn.sigmoid_cross_entropy_with_logits(logits=x,
                                                               labels=y)
            except TypeError:
                return tf.nn.sigmoid_cross_entropy_with_logits(logits=x,
                                                               targets=y)

        # Discriminator Loss Real
        self.d_loss_real = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits_real,
                                              tf.ones_like(self.D_real)))
        # Discriminator Loss Fake
        self.d_loss_fake = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits_fake,
                                              tf.zeros_like(self.D_fake)))
        # Generator Loss
        self.g_loss = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits_fake,
                                              tf.ones_like(self.D_fake)))

        self.d_loss_real_sum = tf.summary.scalar("d_loss_real",
                                                 self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake",
                                                 self.d_loss_fake)

        # Discriminator Loss Combined
        self.d_loss = self.d_loss_real + self.d_loss_fake

        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)

        self.lrate_sum_d = tf.summary.scalar('learning_rate_d',
                                             self.learning_rate_d)
        self.lrate_sum_g = tf.summary.scalar('learning_rate_g',
                                             self.learning_rate_g)

        self.fid_sum = tf.summary.scalar("FID", self.fid)

        # Placeholder and op for JPEG-encoding a single 64x64 image
        self.image_enc_data = tf.placeholder(tf.uint8, [64, 64, 3])
        self.encode_jpeg = tf.image.encode_jpeg(self.image_enc_data)

        # Variables
        t_vars = tf.trainable_variables()

        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        # Checkpoint saver
        self.saver = tf.train.Saver()
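
The d_/g_ name-prefix filtering at the end is what lets each optimizer update only its own sub-network. A minimal sketch of how these variable lists would typically be consumed (the Adam optimizers are illustrative; the actual training setup lies outside this example):

d_optim = tf.train.AdamOptimizer(self.learning_rate_d).minimize(
    self.d_loss, var_list=self.d_vars)
g_optim = tf.train.AdamOptimizer(self.learning_rate_g).minimize(
    self.g_loss, var_list=self.g_vars)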