Example #1
import os
import pathlib

import numpy as np
import scipy.misc  # scipy.misc.imread requires scipy < 1.2
import tensorflow as tf

import fid


def calculate_fid(path, stats_path, inception_path, use_unbatched):
    if not os.path.exists(stats_path):
        raise RuntimeError("Inception statistics file not found: %s" % stats_path)
    inception_path = check_or_download_inception(inception_path)

    path = pathlib.Path(path)
    files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
    x = np.array(
        [scipy.misc.imread(str(fn)).astype(np.float32) for fn in files])

    fid.create_incpetion_graph(str(inception_path))
    sigma, mu = fid.load_stats(stats_path)
    jpeg_tuple = fid.get_jpeg_encoder_tuple()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        query_tensor = fid.get_Fid_query_tensor(sess)
        if use_unbatched:
            fid_value = fid.FID_unbatched(x, query_tensor, mu, sigma,
                                          jpeg_tuple, sess)
        else:
            pred_array = fid.get_predictions(x,
                                             query_tensor,
                                             sess,
                                             batch_size=128)
            fid_value, _, _ = fid.FID(pred_array, mu, sigma, sess)
        return fid_value
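
A minimal usage sketch for the function above. The paths are placeholders, and passing inception_path=None assumes check_or_download_inception falls back to downloading the default model:

# hypothetical paths -- adjust to your setup
fid_value = calculate_fid(path="generated_samples/",
                          stats_path="fid_stats_train.pkl.gz",
                          inception_path=None,
                          use_unbatched=False)
print("FID: %.2f" % fid_value)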
Example #2

init = tf.global_variables_initializer()
sess = tf.Session()
with sess.as_default():
    sess.run(init)
    query_tensor = fid.get_Fid_query_tensor(sess)

    #
    # calculate statistics for batch version
    #
    sigma_b, mu_b = fid.precalc_stats_batched(X.get_data().reshape(-1, 64, 64, 3),
                                              query_tensor,
                                              sess,
                                              batch_size=batch_size,
                                              verbose=True)
    # save statistics of batch version
    fid.save_stats(sigma_b, mu_b, "stats_b.pkl.gz")
    # load saved statistics
    (sigma_b_loaded, mu_b_loaded) = fid.load_stats("stats_b.pkl.gz")


    #
    # calculate statistics for unbatched version
    #
    sigma_u, mu_u = fid.precalc_stats_unbatched(X.get_data().reshape(-1, 64, 64, 3),
                                                query_tensor,
                                                jpeg_tuple,
                                                sess)
    # save statistics of unbatched version
    fid.save_stats(sigma_u, mu_u, "stats_u.pkl.gz")
    # load statistics of unbatched version
    (sigma_u_loaded, mu_u_loaded) = fid.load_stats("stats_u.pkl.gz")
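
For reference, the value fid.FID computes from such (mu, sigma) pairs is the squared Fréchet distance between two Gaussians. A minimal NumPy/SciPy sketch, independent of the fid module:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrtm(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if not np.isfinite(covmean).all():
        # nudge the diagonals if the product is near-singular
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)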
Example #3

# (the top of this snippet is truncated; the loop header and the variable
#  names `files` and `fn` below are reconstructed and are assumptions)
for i, fn in enumerate(files):
    img = get_image(fn,
                    input_height=64,
                    input_width=64,
                    resize_height=64,
                    resize_width=64,
                    is_crop=False,
                    is_grayscale=False)
    X._data[i, :] = img.flatten()
print("done")



# load inference model
fid.create_incpetion_graph(MODEL_PATH)

# load precalculated statistics
sigma_trn, mu_trn = fid.load_stats(STATS_PATH)

# get jpeg encoder
jpeg_tuple = fid.get_jpeg_encoder_tuple()

n_rect = 5
alphas = [0.75, 0.5, 0.25, 0.0]
init = tf.global_variables_initializer()
sess = tf.Session()
with sess.as_default():
    sess.run(init)
    query_tensor = fid.get_Fid_query_tensor(sess)
    for i,a in enumerate(alphas):
        # disturb the images by implanting black rectangles
        X.apply_mult_rect(n_rect, 64, 64, 3, share=a, val=X._data.min())
        # propagate the disturbed images through the Inception net and calculate the FID
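
The snippet breaks off here. Judging from the calls used in the other examples, the loop body would plausibly finish along these lines (a sketch, not the original code; the batch size is an assumption):

        pred_array = fid.get_predictions(X.get_data().reshape(-1, 64, 64, 3),
                                         query_tensor,
                                         sess,
                                         batch_size=128)
        fid_value, _, _ = fid.FID(pred_array, mu_trn, sigma_trn, sess)
        print("alpha = %.2f, FID = %.2f" % (a, fid_value))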
Example #4
    def train(self):

        print("load train stats..", end="")
        sigma_trn, mu_trn = fid.load_stats(self.train_stats_file)
        print("ok")

        z_fixed = np.random.uniform(-1, 1, size=(self.batch_size, self.z_num))

        x_fixed = self.get_image_from_loader()
        save_image(x_fixed, '{}/x_fixed.png'.format(self.model_dir))

        prev_measure = 1
        measure_history = deque([0] * self.lr_update_step, self.lr_update_step)

        # load inference model
        fid.create_incpetion_graph(
            "inception-2015-12-05/classify_image_graph_def.pb")

        query_tensor = fid.get_Fid_query_tensor(self.sess)

        if self.load_checkpoint:
            if self.load(self.model_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        # Preallocate prediction array for the KL / FID inception score
        #print("preallocate %.3f GB for prediction array.." % (self.batch_size_eval * 2048 / (1024**3)), end=" ", flush=True)
        pred_arr = np.ones([self.batch_size_eval, 2048])
        #print("ok")

        for step in trange(self.start_step, self.max_step):

            # Optimize
            self.sess.run([self.d_optim, self.g_optim])

            # Feed dict
            fetch_dict = {"measure": self.measure}

            if self.update_k:
                fetch_dict.update({"k_update": self.k_update})

            if step % self.log_step == 0:
                fetch_dict.update({
                    "summary": self.summary_op,
                    "g_loss": self.g_loss,
                    "d_loss": self.d_loss,
                    "k_t": self.k_t,
                })

            # Get summaries
            result = self.sess.run(fetch_dict)

            measure = result['measure']
            measure_history.append(measure)

            if step % self.log_step == 0:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

                g_loss = result['g_loss']
                d_loss = result['d_loss']
                k_t = result['k_t']

                print("[{}/{}] Loss_D: {:.6f} Loss_G: {:.6f} measure: {:.4f}, k_t: {:.4f}". \
                      format(step, self.max_step, d_loss, g_loss, measure, k_t))

            if step % (self.log_step * 10) == 0:
                x_fake = self.generate(z_fixed, self.model_dir, idx=step)
                self.autoencode(x_fixed,
                                self.model_dir,
                                idx=step,
                                x_fake=x_fake)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            # FID
            if step % self.eval_step == 0:

                eval_batches_num = self.batch_size_eval // self.eval_batch_size

                for eval_batch in range(eval_batches_num):

                    print("\rFID batch %d/%d" %
                          (eval_batch + 1, eval_batches_num),
                          end="",
                          flush=True)

                    sample_z_eval = np.random.uniform(
                        -1, 1, size=(self.eval_batch_size, self.z_num))
                    samples_eval = self.generate(sample_z_eval,
                                                 self.model_dir,
                                                 save=False)

                    pred_array = fid.get_predictions(
                        samples_eval,
                        query_tensor,
                        self.sess,
                        batch_size=self.eval_batch_size,
                        verbose=False)

                    frm = eval_batch * self.eval_batch_size
                    to = frm + self.eval_batch_size
                    pred_arr[frm:to, :] = pred_array

                print()

                # calculate FID
                print("FID", end=" ", flush=True)
                FID, _, _ = fid.FID(pred_arr, mu_trn, sigma_trn, self.sess)

                if FID is None:
                    FID = 500  # Something went wrong
                print(FID)
                self.sess.run(tf.assign(self.fid, FID))
                summary_str = self.sess.run(self.fid_sum)
                self.summary_writer.add_summary(summary_str, step)
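
The self.fid variable and self.fid_sum summary used at the end of train() are defined elsewhere in the class; a minimal sketch of that wiring, with the names taken from the usage above:

# in the model-building code (sketch)
self.fid = tf.Variable(0.0, trainable=False, name="FID")
self.fid_sum = tf.summary.scalar("FID", self.fid)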
Example #5
    def train(self, config):
        """Train DCGAN"""

        print("load train stats")
        sigma_trn, mu_trn = fid.load_stats(self.stats_path)

        # get query tensor
        query_tensor = self.querry_tensor  # fid.get_Fid_query_tensor(self.sess)

        print("scan files", end=" ", flush=True)
        if config.dataset == 'mnist':
            data_X, data_y = self.load_mnist()
            print("%d images found" % len(data_X))
        else:
            data = glob(os.path.join(self.data_path, "*"))
            print("%d images found" % len(data))

        print("build model", end=" ", flush=True)
        # Train optimizers (use the learning-rate variables so the per-epoch
        # decay applied via tf.assign below takes effect)
        opt_d = tf.train.AdamOptimizer(self.learning_rate_d,
                                       beta1=config.beta1)
        opt_g = tf.train.AdamOptimizer(self.learning_rate_g,
                                       beta1=config.beta1)

        # Discriminator
        grads_and_vars = opt_d.compute_gradients(self.d_loss,
                                                 var_list=self.d_vars)
        d_optim = opt_d.apply_gradients(grads_and_vars)

        # Gradient summaries discriminator
        sum_grad_d = []
        for i, (grad, vars_) in enumerate(grads_and_vars):
            grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(grad)))
            sum_grad_d.append(
                tf.summary.scalar("grad_l2_d_%d_%s" % (i, vars_.name),
                                  grad_l2))

        # Generator
        grads_and_vars = opt_g.compute_gradients(self.g_loss,
                                                 var_list=self.g_vars)
        g_optim = opt_g.apply_gradients(grads_and_vars)

        # Gradient summaries generator
        sum_grad_g = []
        for i, (grad, vars_) in enumerate(grads_and_vars):
            grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(grad)))
            sum_grad_g.append(
                tf.summary.scalar("grad_l2_g_%d_%s" % (i, vars_.name),
                                  grad_l2))

        # Init:
        tf.global_variables_initializer().run()

        # Summaries
        self.g_sum = tf.summary.merge([
            self.z_sum, self.d_fake_sum, self.G_sum, self.d_loss_fake_sum,
            self.g_loss_sum, self.lrate_sum_g
        ] + sum_grad_g)
        self.d_sum = tf.summary.merge([
            self.z_sum, self.d_real_sum, self.d_loss_real_sum, self.d_loss_sum,
            self.lrate_sum_d
        ] + sum_grad_d)
        self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

        # Z sample
        sample_z = np.random.normal(0, 1.0, size=(self.sample_num, self.z_dim))

        # Input samples
        sample_files = data[0:self.sample_num]
        sample = [
            get_image(sample_file,
                      input_height=self.input_height,
                      input_width=self.input_width,
                      resize_height=self.output_height,
                      resize_width=self.output_width,
                      is_crop=self.is_crop,
                      is_grayscale=self.is_grayscale)
            for sample_file in sample_files
        ]
        if self.is_grayscale:
            sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_inputs = np.array(sample).astype(np.float32)

        print("ok")

        start_time = time.time()

        if self.load_checkpoint:
            if self.load(self.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        # Batch preparing
        batch_nums = min(len(data), config.train_size) // config.batch_size
        data_idx = list(range(len(data)))

        counter = self.counter_start

        # Loop over epochs
        for epoch in range(config.epoch):

            # Assign learning rates for d and g
            lrate = config.learning_rate_d * (config.lr_decay_rate_d**epoch)
            self.sess.run(tf.assign(self.learning_rate_d, lrate))
            lrate = config.learning_rate_g * (config.lr_decay_rate_g**epoch)
            self.sess.run(tf.assign(self.learning_rate_g, lrate))

            # Shuffle the data indices
            np.random.shuffle(data_idx)

            # Loop over batches
            for batch_idx in range(batch_nums):

                # Prepare batch
                idx = data_idx[batch_idx * config.batch_size:(batch_idx + 1) *
                               config.batch_size]
                batch = [
                    get_image(data[i],
                              input_height=self.input_height,
                              input_width=self.input_width,
                              resize_height=self.output_height,
                              resize_width=self.output_width,
                              is_crop=self.is_crop,
                              is_grayscale=self.is_grayscale) for i in idx
                ]
                if self.is_grayscale:
                    batch_images = np.array(batch).astype(np.float32)[:, :, :,
                                                                      None]
                else:
                    batch_images = np.array(batch).astype(np.float32)

                batch_z = np.random.normal(
                    0, 1.0,
                    size=(config.batch_size, self.z_dim)).astype(np.float32)

                # Update D network
                _, summary_str = self.sess.run([d_optim, self.d_sum],
                                               feed_dict={
                                                   self.inputs: batch_images,
                                                   self.z: batch_z
                                               })
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run([g_optim, self.g_sum],
                                               feed_dict={self.z: batch_z})
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                errD_real = self.d_loss_real.eval({self.inputs: batch_images})
                errG = self.g_loss.eval({self.z: batch_z})

                if np.mod(counter, 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, batch_idx, batch_nums, time.time() - start_time, errD_fake+errD_real, errG))

                if np.mod(counter, 1000) == 0:

                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.inputs: sample_inputs
                            })
                        save_images(
                            samples, [8, 8],
                            '{}/train_{:02d}_{:04d}.png'.format(
                                config.sample_dir, epoch, batch_idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (d_loss, g_loss))
                    except Exception as e:
                        print(e)
                        print("sample image error!")

                    # FID
                    print("samples for incept")
                    sample_z_dist = np.random.normal(
                        0, 1.0, size=(self.batch_size_dist, self.z_dim))
                    samples = self.sess.run(
                        self.sampler_dist,
                        feed_dict={self.z_dist: sample_z_dist})

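                    # map samples from [-1, 1] to [0, 255] before the Inception pass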
                    samples = (samples + 1.) * 127.5
                    predictions = fid.get_predictions(
                        samples,
                        query_tensor,
                        self.sess,
                        batch_size=self.fid_batch_size,
                        verbose=self.fid_verbose)
                    FID = None
                    try:
                        FID, _, _ = fid.FID(predictions, mu_trn, sigma_trn,
                                            self.sess)
                        print("FID = " + str(FID))
                    except Exception as e:
                        print(e)
                        FID = 500

                    self.sess.run(tf.assign(self.fid, FID))
                    summary_str = self.sess.run(self.fid_sum)
                    self.writer.add_summary(summary_str, counter)

                # Save checkpoint
                if (counter != 0) and (np.mod(counter, 2000) == 0):
                    self.save(config.checkpoint_dir, counter)

                counter += 1
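
The learning-rate variables assigned in the epoch loop above are defined elsewhere in the class; a minimal sketch of that wiring, with the names taken from the usage above:

# in the model-building code (sketch)
self.learning_rate_d = tf.Variable(config.learning_rate_d, trainable=False,
                                   name="learning_rate_d")
self.learning_rate_g = tf.Variable(config.learning_rate_g, trainable=False,
                                   name="learning_rate_g")
self.lrate_sum_d = tf.summary.scalar("lrate_d", self.learning_rate_d)
self.lrate_sum_g = tf.summary.scalar("lrate_g", self.learning_rate_g)

# the resulting schedule is geometric, e.g. with learning_rate_d = 2e-4 and
# lr_decay_rate_d = 0.9, epoch 10 runs at 2e-4 * 0.9**10 ≈ 7.0e-5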