def cal_fid(real_samples, G_sample):
    """
    Args:
        real_samples: Real samples for the current batch.
        G_sample: Generated samples for the current batch.
    Return:
        [D_G fid (real vs. generated), D_D fid (real vs. real)]
    """
    if real_samples.shape[0] < 4:
        return 0, 0
    n_fid = [[], []]
    l = len(real_samples) // 2
    # print(type(real_samples))
    for i in range(5):
        mu_f, sigma_f = fid.calculate_statistics(G_sample[:l])
        shuffling(G_sample)
        mu_r1, sigma_r1 = fid.calculate_statistics(real_samples[:l])
        mu_r2, sigma_r2 = fid.calculate_statistics(real_samples[l:])
        shuffling(real_samples)
        # print('The %d times samples:' % i)
        # print(real_samples)
        n_fid[0].append(
            fid.calculate_frechet_distance(mu_r1, sigma_r1, mu_f, sigma_f))
        n_fid[1].append(
            fid.calculate_frechet_distance(mu_r1, sigma_r1, mu_r2, sigma_r2))
    # print(n_fid[1])
    return np.mean(n_fid, axis=1)
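For reference, every example in this listing ultimately calls fid.calculate_frechet_distance. Below is a minimal NumPy/SciPy sketch of the quantity that call computes, the squared Fréchet distance between two Gaussians; it is an illustrative re-implementation under those assumptions, not the library's own code.

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # regularize near-singular covariances before retrying the matrix square root
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop negligible imaginary parts from numerical error
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)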
Example #2
def main():
    # Paths
    pic_path = os.path.join('./out/c/', 'checkpoints',
                            'dummy.samples%d.npy' % (NUM_POINTS))
    image_path = 'results_celeba/generated'  # set path to some generated images
    stats_path = 'fid_stats_celeba.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(
        None)  # download inception network

    # load precalculated training set statistics
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()

    #image_list = glob.glob(os.path.join(image_path, '*.png'))
    #images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
    images = np.load(pic_path)

    images_t = images / 2.0 + 0.5
    images_t = 255.0 * images_t

    from PIL import Image
    img = Image.fromarray(np.uint8(images_t[0]), 'RGB')
    img.save('my.png')

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(images, sess)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    print("FID: %s" % fid_value)
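The fid_stats_celeba.npz file loaded above is expected to contain 'mu' and 'sigma' arrays for the training set. A minimal sketch of how such a statistics file could be produced with the same fid helpers this example already uses; the function name and the (N, H, W, 3) image layout are assumptions.

def precalculate_real_stats(real_images, stats_path, inception_path):
    # real_images: array of shape (N, H, W, 3) with values scaled to [0, 255]
    fid.create_inception_graph(inception_path)  # load the Inception graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics(real_images, sess)
    np.savez(stats_path, mu=mu, sigma=sigma)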
Example #3
    def train(self, num_batches=200000):
        self.sess.run(tf.global_variables_initializer())
        mean_list = []
        fid_list = []
        start_time = time.time()
        for t in range(0, num_batches):
            for _ in range(0, self.d_iters):
                bx = self.x_sampler(self.batch_size)
                bz = self.z_sampler(self.batch_size, self.z_dim)
                self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz})

            bx = self.x_sampler(self.batch_size)
            bz = self.z_sampler(self.batch_size, self.z_dim)
            self.sess.run([self.g_adam], feed_dict={self.z: bz, self.x: bx})

            if t % 1000 == 0:
                bx = self.x_sampler(self.batch_size)
                bz = self.z_sampler(self.batch_size, self.z_dim)
                dl, gl, gp, x_ = self.sess.run(
                    [self.d_loss, self.g_loss, self.gp_loss, self.x_],
                    feed_dict={
                        self.x: bx,
                        self.z: bz
                    })

                print('Iter [%8d] Time [%.4f] dl [%.4f] gl [%.4f] gp [%.4f]' %
                      (t, time.time() - start_time, dl, gl, gp))

                x_ = self.x_sampler.data2img(x_)
                x_ = grid_transform(x_, self.x_sampler.shape)
                imsave(self.log_dir + '/wos/{}.png'.format(int(t)), x_)

            if t % 10000 == 0 and t > 0:
                in_list = []
                for _ in range(int(50000 / self.batch_size)):
                    bz = self.z_sampler(self.batch_size, self.z_dim)
                    x_ = self.sess.run(self.x_, feed_dict={self.z: bz})
                    x_ = self.x_sampler.data2img(x_)
                    bx_list = np.split(x_, self.batch_size)
                    in_list = in_list + [np.squeeze(x) for x in bx_list]
                mean, std = self.inception.get_inception_score(in_list,
                                                               splits=10)
                mean_list.append(mean)
                np.save(self.log_dir + '/inception_score_wgan_gp.npy',
                        np.asarray(mean_list))
                print('inception score [%.4f]' % (mean))

            if t % 10000 == 0 and t > 0 and args.fid:
                f = np.load(self.stats_path)
                mu_real, sigma_real = f['mu'][:], f['sigma'][:]
                f.close()

                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    np.array(in_list[:10000]), self.sess, batch_size=100)
                fid_value = fid.calculate_frechet_distance(
                    mu_gen, sigma_gen, mu_real, sigma_real)
                print("FID: %s" % fid_value)
                fid_list.append(fid_value)
                np.save(self.log_dir + '/fid_score_wgan_gp.npy',
                        np.asarray(fid_list))
Example #4
 def compute(images):
     m, s = fid.calculate_activation_statistics(transform_for_fid(images),
                                                inception_sess,
                                                args.batch_size,
                                                verbose=True,
                                                model='lenet')
     return fid.calculate_frechet_distance(m, s, mu0, sig0)
Example #5
def calculate_fid(generator, fid_net, mu_fid, sigma_fid, n_samples, batch_size,
                  latentsize, gpu_id):
    with torch.no_grad():
        mean = torch.Tensor([0.485, 0.456, 0.406])[None, :, None, None]
        std = torch.Tensor([0.229, 0.224, 0.225])[None, :, None, None]
        if gpu_id != -1:
            mean = mean.cuda()
            std = std.cuda()
        mean, std = Variable(mean), Variable(std)

        n_iter = (n_samples // batch_size) + 1
        n = n_iter * batch_size
        zz = torch.FloatTensor(batch_size, latentsize)
        act = torch.FloatTensor(n, 2048)
        for i in range(n_iter):
            zz.normal_()
            z = Variable(zz)
            if gpu_id != -1:
                z = z.cuda()
            x = generator(z)
            x = (torch.clamp(x, -1.0, +1.0) + 1.0) / 2.0
            x -= mean
            x /= std
            x = F.interpolate(x, 299, mode='bilinear')
            a = fid_net(x)
            act[(i * batch_size):(i + 1) * batch_size] = a.data.cpu()
        act = act.numpy()
        mu = np.mean(act, axis=0, dtype=np.float64)
        sigma = np.cov(act, rowvar=False)
        return fid.calculate_frechet_distance(mu, sigma, mu_fid, sigma_fid)
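A hedged usage sketch for calculate_fid above. The original fid_net is not shown; a common stand-in (assumed here) is torchvision's Inception-v3 with its classifier replaced so it returns the 2048-d pooled features, and the toy generator and placeholder reference statistics below are purely illustrative.

import numpy as np
import torch
import torch.nn as nn
from torchvision.models import inception_v3

class ToyGenerator(nn.Module):
    # Illustrative generator: latent vector -> 3x64x64 image in [-1, 1].
    def __init__(self, latentsize=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latentsize, 3 * 64 * 64), nn.Tanh())

    def forward(self, z):
        return self.net(z).view(-1, 3, 64, 64)

generator = ToyGenerator().eval()
fid_net = inception_v3(pretrained=True, aux_logits=True)
fid_net.fc = nn.Identity()  # expose the 2048-d pooled features instead of class logits
fid_net.eval()

# Placeholder reference statistics; in practice these come from real-image activations.
mu_fid = np.zeros(2048)
sigma_fid = np.eye(2048)

score = calculate_fid(generator, fid_net, mu_fid, sigma_fid,
                      n_samples=1000, batch_size=50, latentsize=128, gpu_id=-1)
print('FID: %.4f' % score)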
Example #6
def real_fid(spectral_data):
    n_fid = []
    for i in range(20):
        shuffle(spectral_data)
        l = len(spectral_data) // 2
        real_1, real_2 = spectral_data[:l], spectral_data[l:]
        mu_1, sigma_1 = fid.calculate_statistics(real_1)
        mu_2, sigma_2 = fid.calculate_statistics(real_2)
        n_fid.append(fid.calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2))
    return np.mean(n_fid), mu_1, sigma_1
Example #7
    def validation_epoch_end(self, outputs):
        def get_mu_sig(x):
            return np.mean(x, axis=0), np.cov(x, rowvar=False)

        mu1, sig1 = get_mu_sig(
            torch.cat([x[:, 0].cpu() for x in outputs]).numpy())
        mu2, sig2 = get_mu_sig(
            torch.cat([x[:, 1].cpu() for x in outputs]).numpy())

        fid_ = fid.calculate_frechet_distance(mu1, sig1, mu2, sig2)
        return OrderedDict({'val_loss': fid_, 'log': {'fid_loss': fid_}})
Example #8
def main(model, data_source, noise_method):
    # load precalculated training set statistics
    if data_source == 'MNIST':
        f = np.load("fid_stats_mnist_1w.npz")
    else:
        f = np.load("fid_stats_fashion_1w.npz")
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()
    
    lambdas = [1,5,10,15,20,25,50,70,100,250]
    if noise_method == 'sp':
        noise_factors = [round(i*0.01,2) for i in range(1,52,2)]
    else:
        noise_factors = [round(i*0.1,1) for i in range(1,10)]
    
    fid_scores = []
    if model == 'RVAE':
        for l in lambdas:
            ls = []
            for nr in noise_factors:
                path = "fid_precalc/"+model+"_"+data_source+"_"+noise_method+"/"+'fid_stats_lambda_'+str(l)+'noise_'+str(nr)+".npz"
                f_g = np.load(path)
                mu_gen, sigma_gen = f_g['mu'][:], f_g['sigma'][:]
                fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
                print("FID: %s" % fid_value)
                ls.append(fid_value)
            fid_scores.append(ls)
        fid_scores = np.array(fid_scores)
        np.save("fid_scores_"+model+"_"+data_source+"_"+noise_method+".npy",fid_scores)
        
    else:
        for nr in noise_factors:
            path = "fid_precalc/"+model+"_"+data_source+"_"+noise_method+"/"+'fid_stats_noise_'+str(nr)+".npz"
            f_g = np.load(path)
            mu_gen, sigma_gen = f_g['mu'][:], f_g['sigma'][:]
            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
            print("FID: %s" % fid_value)
            fid_scores.append(fid_value)
        fid_scores = np.array(fid_scores)
        np.save("fid_scores_"+model+"_"+data_source+"_"+noise_method+".npy",fid_scores)
Example #9
def cal_fid(real_samples, G_sample):
    """
    Args:
        real_samples: Real samples for the current batch.
        G_sample: Generated samples for the current batch.
    Return:
        [D_G fid (real vs. generated), D_D fid (real vs. real)]
    """
    n_fid = [[], []]
    for i in range(5):
        l = len(real_samples) // 2
        mu_f, sigma_f = fid.calculate_statistics(G_sample[:l])
        shuffle(G_sample)
        mu_r1, sigma_r1 = fid.calculate_statistics(real_samples[:l])
        mu_r2, sigma_r2 = fid.calculate_statistics(real_samples[l:])
        shuffle(real_samples)
        n_fid[0].append(
            fid.calculate_frechet_distance(mu_r1, sigma_r1, mu_f, sigma_f))
        n_fid[1].append(
            fid.calculate_frechet_distance(mu_r1, sigma_r1, mu_r2, sigma_r2))
    # print(n_fid[1])
    return np.mean(n_fid, axis=1)
Example #10
    def compute_fid_score(self, generator, timestamp):
        """
        Computes the FID of the generator on a fixed noise dataset,
        appends the current score to the list of computed scores,
        and overwrites the JSON file that logs the FID scores.

        :param generator: [nn.Module]
        :param timestamp: [int]
        :return: None
        """
        generator.eval()
        fake_samples = np.empty(
            (self.sample_size_fid, self.imsize, self.imsize, 3))
        for j, noise in enumerate(self.fid_noise_loader):
            noise = noise.cuda()
            i1 = j * 200  # batch_size = 200
            i2 = i1 + noise.size(0)
            samples = generator(noise).cpu().data.add(1).mul(255 / 2.0)
            fake_samples[i1:i2] = samples.permute(0, 2, 3, 1).numpy()
        generator.train()
        mu_g, sigma_g = fid.calculate_activation_statistics(fake_samples,
                                                            self.fid_session,
                                                            batch_size=100)
        fid_score = fid.calculate_frechet_distance(mu_g, sigma_g, self.mu_real,
                                                   self.sigma_real)
        _result = {
            'entry': len(self.fid_scores),
            'iter': timestamp,
            'fid': fid_score
        }

        # if best update the checkpoint in self.best_path
        new_best = True
        for prev_fid in self.fid_scores:
            if prev_fid['fid'] < fid_score:
                new_best = False
                break
        if new_best:
            self.backup(timestamp, dir=self.best_path)

        self.fid_scores.append(_result)
        with open(self.fid_json_file, 'w') as _f_fid:
            json.dump(self.fid_scores,
                      _f_fid,
                      sort_keys=True,
                      indent=4,
                      separators=(',', ': '))
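compute_fid_score above relies on attributes prepared elsewhere in the class (fid_noise_loader, sample_size_fid, imsize, fid_session, mu_real, sigma_real, fid_scores, fid_json_file, best_path). A hedged sketch of what that setup might look like; every path, size, and default below is an assumption rather than the original code.

    def _setup_fid(self, stats_path, inception_path, sample_size_fid=10000,
                   imsize=64, z_dim=128):
        # Fixed noise so every evaluation scores the generator on the same latent vectors.
        self.sample_size_fid = sample_size_fid
        self.imsize = imsize
        fixed_noise = torch.randn(sample_size_fid, z_dim)
        self.fid_noise_loader = torch.utils.data.DataLoader(fixed_noise, batch_size=200)

        # Precomputed real statistics plus a TF session holding the Inception graph.
        stats = np.load(stats_path)
        self.mu_real, self.sigma_real = stats['mu'][:], stats['sigma'][:]
        stats.close()
        fid.create_inception_graph(inception_path)
        self.fid_session = tf.Session()
        self.fid_session.run(tf.global_variables_initializer())

        self.fid_scores = []
        self.fid_json_file = os.path.join(self.best_path, 'fid_scores.json')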
Example #11
def bootstrap_sample(act1, act2, iterations=10, same_index=False):
    l = len(act1)
    output = []
    if l != len(act2):
        raise RuntimeError('only works when length of both arrays is identical')
    for i in range(iterations):
        iter_array = np.random.randint(low=0, high=l, size=l)
        temp1 = act1[iter_array,:]
        if same_index is False:
            iter_array = np.random.randint(low=0, high=l, size=l)
        temp2 = act2[iter_array,:]
        m1, s1 = get_moments(temp1)
        m2, s2 = get_moments(temp2)

        output.append(fid.calculate_frechet_distance(m1, s1, m2, s2))
    output = np.array(output)
    return output
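A hedged usage sketch for bootstrap_sample above, on synthetic activations; get_moments and fid are the same helpers the example already assumes, and the array sizes are arbitrary.

import numpy as np

rng = np.random.RandomState(0)
act1 = rng.randn(2000, 64)          # stand-in for Inception activations of real images
act2 = rng.randn(2000, 64) + 0.05   # stand-in for activations of generated images

scores = bootstrap_sample(act1, act2, iterations=10, same_index=False)
print('bootstrap FID mean %.4f, std %.4f' % (scores.mean(), scores.std()))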
Example #12
def caculate_fid(mu_r1, sigma_r1, Z_dim, y_mb, G_sample):
    """
    Args:
        mu_r1, sigma_r1: Precomputed statistics (mean and covariance) of the real samples.
        Z_dim: Dimension of the noise vector.
        y_mb: One-hot labels for the current batch.
        G_sample: Tensorflow operation that produces generated samples.
    Return:
        Mean FID between the real statistics and 20 batches of generated samples.
    """
    n_fid = []
    for i in range(20):
        Z_sample = sample_Z(y_mb.shape[0], Z_dim)
        g_sample = sess.run(G_sample, feed_dict={Z: Z_sample, y: y_mb})
        mu_f, sigma_f = fid.calculate_statistics(g_sample)
        n_fid.append(fid.calculate_frechet_distance(mu_r1, sigma_r1, mu_f, sigma_f))
    return np.mean(n_fid)
Example #13
def bootstrap_sample(act1, act2, iterations=10, same_index=False):
    l = len(act1)
    output = []
    size = len(act2)
    if l > size:
        raise RuntimeError('subsampling length is bigger than sample')

    m1, s1 = get_moments(act1)
    iter_array = np.arange(size)
    for i in range(iterations):
        np.random.shuffle(iter_array)  # shuffle in place (np.random.shuffle returns None)
        iters, iter_array = iter_array[:l], iter_array[l:]

        temp2 = act2[iters, :]
        m2, s2 = get_moments(temp2)
        output.append(fid.calculate_frechet_distance(m1, s1, m2, s2))
        if len(iter_array) < l:
            iter_array = np.arange(size)
    output = np.array(output)
    return output
Example #14
def main(paths, iterations, same_index):
    check_paths(paths)

    act1 = get_activation(paths[0])
    act2 = get_activation(paths[1])
    #print(act1.shape)
    #print(act2.shape)
    print((paths[0].split('/')[-1]))
    print('Same index:', same_index)
    print('FID score:')
    m1, s1 = get_moments(act1)
    m2, s2 = get_moments(act2)
    print(fid.calculate_frechet_distance(m1,s1, m2,s2))

    fid_scores = bootstrap_sample(act1, act2, iterations=iterations, same_index=same_index)
    print(fid_scores)
    print('Mean')
    print(fid_scores.mean())
    print('std')
    print(fid_scores.std())
Example #15
def _run_fid_calculation(sess, inception_sess, placeholders, batch_size,
                         iteration, generator, mu, sigma, epoch, image_shape,
                         z_input_shape, y_input_shape):
    f = 0.0
    for _ in range(iteration):
        z = util.gen_random_noise(batch_size, z_input_shape)
        y = util.gen_random_label(batch_size, y_input_shape[0])

        images = sess.run(
            tf.reshape(generator, (-1, ) + image_shape), {
                placeholders['z']: z,
                placeholders['y']: y,
                placeholders['mode']: False,
            })
        images = _maybe_grayscale_to_rgb(images)
        images = (images + 1.0) / 2.0 * 255.0

        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            images, inception_sess)
        f += fid.calculate_frechet_distance(mu, sigma, mu_gen, sigma_gen)
    return f / iteration
Example #16
  def train(self, config):
    """Train DCGAN"""

    assert len(self.paths) > 0, 'no data loaded, was model not built?'
    print("load train stats.. ", end="", flush=True)
    # load precalculated training set statistics
    f = np.load(self.stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()
    print("ok")

    if self.load_checkpoint:
      if self.load(self.checkpoint_dir):
        print(" [*] Load SUCCESS")
      else:
        print(" [!] Load failed...")

    # Batch preparing
    batch_nums = min(len(self.paths), config.train_size) // config.batch_size

    counter = self.counter_start
    errD_fake = 0.
    errD_real = 0.
    errG = 0.
    errG_count = 0
    penD_gradient = 0.
    penD_lipschitz = 0.
    esti_slope = 0.
    lipschitz_estimate = 0.

    start_time = time.time()

    try:
      # Loop over epochs
      for epoch in range(config.epoch):

        # Assign learning rates for d and g
        lrate =  config.learning_rate_d # * (config.lr_decay_rate_d ** epoch)
        self.sess.run(tf.assign(self.learning_rate_d, lrate))
        lrate =  config.learning_rate_g # * (config.lr_decay_rate_g ** epoch)
        self.sess.run(tf.assign(self.learning_rate_g, lrate))

        # Loop over batches
        for batch_idx in range(batch_nums):
          # Update D network
          _, errD_fake_, errD_real_, summary_str, penD_gradient_, penD_lipschitz_, esti_slope_, lipschitz_estimate_ = self.sess.run(
              [self.d_optim, self.d_loss_fake, self.d_loss_real, self.d_sum,
              self.d_gradient_penalty_loss, self.d_lipschitz_penalty_loss, self.d_mean_slope_target, self.d_lipschitz_estimate])
          for i in range(self.num_discriminator_updates - 1):
            self.sess.run([self.d_optim, self.d_loss_fake, self.d_loss_real, self.d_sum,
                           self.d_gradient_penalty_loss, self.d_lipschitz_penalty_loss])
          if np.mod(counter, 20) == 0:
            self.writer.add_summary(summary_str, counter)

          # Update G network
          if config.learning_rate_g > 0.: # and (np.mod(counter, 100) == 0 or lipschitz_estimate_ > 1 / (20 * self.lipschitz_penalty)):
            _, errG_, summary_str = self.sess.run([self.g_optim, self.g_loss, self.g_sum])
            if np.mod(counter, 20) == 0:
              self.writer.add_summary(summary_str, counter)
            errG += errG_
            errG_count += 1

          errD_fake += errD_fake_
          errD_real += errD_real_
          penD_gradient += penD_gradient_
          penD_lipschitz += penD_lipschitz_
          esti_slope += esti_slope_
          lipschitz_estimate += lipschitz_estimate_

          # Print
          if np.mod(counter, 100) == 0:
            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, lip_pen: %.8f, gradient_pen: %.8f, g_loss: %.8f, d_tgt_slope: %.6f, d_avg_lip: %.6f, g_updates: %3d"  \
              % (epoch, batch_idx, batch_nums, time.time() - start_time, (errD_fake+errD_real) / 100.,
                 penD_lipschitz / 100., penD_gradient / 100., errG / 100., esti_slope / 100., lipschitz_estimate / 100., errG_count))
            errD_fake = 0.
            errD_real = 0.
            errG = 0.
            errG_count = 0
            penD_gradient = 0.
            penD_lipschitz = 0.
            esti_slope = 0.
            lipschitz_estimate = 0.

          # Save generated samples and FID
          if np.mod(counter, config.fid_eval_steps) == 0:

            # Save
            try:
              samples, d_loss, g_loss = self.sess.run(
                  [self.sampler, self.d_loss, self.g_loss])
              save_images(samples, [8, 8], '{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, batch_idx))
              print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
            except Exception as e:
              print(e)
              print("sample image error!")

            # FID
            print("samples for incept", end="", flush=True)

            samples = np.zeros((self.fid_n_samples, self.output_height, self.output_width, 3))
            n_batches = self.fid_n_samples // self.fid_sample_batchsize
            lo = 0
            for btch in range(n_batches):
              print("\rsamples for incept %d/%d" % (btch + 1, n_batches), end=" ", flush=True)
              samples[lo:(lo+self.fid_sample_batchsize)] = self.sess.run(self.sampler_fid)
              lo += self.fid_sample_batchsize

            samples = (samples + 1.) * 127.5
            print("ok")

            mu_gen, sigma_gen = fid.calculate_activation_statistics(samples,
                                                             self.sess,
                                                             batch_size=self.fid_batch_size,
                                                             verbose=self.fid_verbose)

            print("calculate FID:", end=" ", flush=True)
            try:
                FID = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
            except Exception as e:
                print(e)
                FID=500

            print(FID)

            # Update event log with FID
            self.sess.run(tf.assign(self.fid, FID))
            summary_str = self.sess.run(self.fid_sum)
            self.writer.add_summary(summary_str, counter)

          # Save checkpoint
          if (counter != 0) and (np.mod(counter, 2000) == 0):
            self.save(config.checkpoint_dir, counter)

          counter += 1
    except KeyboardInterrupt as e:
      self.save(config.checkpoint_dir, counter)
    except Exception as e:
      print(e)
    finally:
      # When done, ask the threads to stop.
      self.coord.request_stop()
      self.coord.join(self.threads)
Example #17
    def run(self, model_path, ckpt_ids):
        '''
            Main entry point for evaluation.
            Input:
                model_path: the folder containing all checkpoints
                ckpt_ids: the list of integer checkpoint ids to evaluate
        '''
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=True)
        config.gpu_options.allow_growth = True
        model = self.model
        with tf.Session(config=config) as sess:
            # start evaluating every checkpoint
            for id in ckpt_ids:
                ckpt = os.path.join(model_path, 'saved-model-' + str(id))
                saver = tf.train.Saver()
                saver.restore(sess, ckpt)

                if flags.DATASET in ['grid']:
                    #---------------------------------------------------------------#
                    # visualization real and random generated samples               #
                    #---------------------------------------------------------------#
                    print('\nVisualizing generated samples at iter %d' % (id))
                    gen_data = model.decode(noise_dist.sample(1000))
                    self.visualize(gen_data, id)

                if flags.DATASET in ['color_mnist']:
                    #---------------------------------------------------------------#
                    # visualization real and random generated samples               #
                    #---------------------------------------------------------------#
                    np.random.seed(SEED)
                    print('\nVisualizing generated samples at iter %d' % (id))
                    real_data, _ = data_dist.sample(10)
                    gen_data = model.decode(noise_dist.sample(
                        model.batch_size))
                    self.visualize(gen_data, id, real_data)

                if flags.DATASET == 'cifar_100':
                    #---------------------------------------------------------------#
                    # visualization real and random generated samples               #
                    #---------------------------------------------------------------#
                    np.random.seed(SEED)
                    print('\nVisualizing generated samples at iter %d' % (id))
                    gen_data = model.decode(noise_dist.sample(
                        model.batch_size))
                    self.visualize(gen_data, id)

                if flags.DATASET in [
                        'color_mnist', 'cifar_100'
                ] and args.method not in ['cougan', 'wgan']:
                    #---------------------------------------------------------------#
                    # visualization real and nearest generated samples              #
                    #---------------------------------------------------------------#
                    np.random.seed(SEED)
                    print('\nVisualizing nearest neighbors at iter %d' %
                          (model.global_step.eval() + 1))

                    real_data, _ = data_dist.sample(10)

                    true_input = model.encode(real_data)
                    if args.method == 'vae':
                        [mean, std] = np.split(true_input, 2, axis=0)
                        epsilon = noise_dist.sample(10)
                        z_sampled = mean + std * epsilon
                    else:
                        z_sampled = true_input

                    gen_data = np.zeros((real_data.shape[0], 90))
                    for i in range(10):
                        perturbed_input = np.random.multivariate_normal(
                            z_sampled[:, i],
                            NOISE_PERTURB * np.eye(z_sampled.shape[0]), 9)
                        gen_data[:, i * 9:i * 9 + 9] = model.decode(
                            perturbed_input.T)
                    self.visualize_closest(real_data, gen_data, id)

                log = {}
                if flags.DATASET in ['low_dim_embed']:
                    #---------------------------------------------------------------#
                    # counting the number of modes using pretrained MLP net          #
                    #---------------------------------------------------------------#
                    print('\nClassifying generated data at iter %d' % (id))
                    gen_data = model.decode(noise_dist.sample(10000))
                    predictions = self.classify_low_dim_embed(
                        os.path.join('classification', 'low_dim_embed',
                                     'model'), gen_data)
                    hist, _ = np.histogram(predictions)
                    log['hist'] = hist
                    if not os.path.exists(
                            os.path.join(args.working_dir, 'others')):
                        os.makedirs(os.path.join(args.working_dir, 'others'))
                    path = os.path.join(args.working_dir, 'others',
                                        'hist_modes.pkl')
                    with open(path, 'wb') as f:
                        pickle.dump(log, f)

                if flags.DATASET in ['color_mnist']:
                    #---------------------------------------------------------------#
                    # counting the number of modes using pretrained CNN          #
                    #---------------------------------------------------------------#
                    print('\nClassifying generated data at iter %d' % (id))
                    hist = 0
                    for i in range(100):
                        gen_data = model.decode(noise_dist.sample(10000))
                        predictions = self.classify_cmnist(
                            os.path.join('classification', 'color_mnist',
                                         'model'), gen_data)
                        h, _ = np.histogram(predictions, bins=1000)
                        hist = hist + h
                    print('number of class having 0 samples %d' %
                          np.sum(hist == 0))
                    print('mean #samples per class %d' % np.mean(hist))
                    print('std #samples per class %d' % np.std(hist))
                    log['hist'] = hist
                    if not os.path.exists(
                            os.path.join(args.working_dir, 'others')):
                        os.makedirs(os.path.join(args.working_dir, 'others'))
                    path = os.path.join(args.working_dir, 'others',
                                        'hist_modes.pkl')
                    with open(path, 'wb') as f:
                        pickle.dump(log, f)

                if flags.DATASET in ['grid', 'low_dim_embed']:
                    #---------------------------------------------------------------#
                    # compute log-likelihood                                        #
                    #---------------------------------------------------------------#
                    val_gen_data = model.decode(noise_dist.sample(10000))
                    gen_data = model.decode(noise_dist.sample(100000))
                    print('\nComputing log-likelihood at iter %d'%\
                        (id))
                    self.compute_log_likelihood(val_gen_data, gen_data, id)

                if flags.DATASET in ['color_mnist']:
                    #---------------------------------------------------------------#
                    # compute FID                                                   #
                    #---------------------------------------------------------------#
                    print('\nComputing FID at iter %d' % (id))
                    path = os.path.join('utils', 'inception_statistics',
                                        'color_mnist')
                    if not os.path.isdir(path):
                        os.mkdir(path)
                    f = os.path.join(path, 'real_statistics')
                    if not os.path.isfile(f + '.npz'):
                        x, _ = data_dist.sample(data_dist.num_train_samples)
                        data = x
                        x, _ = data_dist.test_sample()
                        data = np.hstack((data, x))
                        x, _ = data_dist.val_sample()
                        data = np.hstack((data, x))
                        # reshape data before passing to inception net
                        images = self.reshape_cmnist(data)
                        images = np.round(images * 255)
                        real_mu, real_sigma = self.compute_fid(
                            images,
                            os.path.join('utils', 'inception-2015-12-05',
                                         'classify_image_graph_def.pb'))
                        np.savez(f, mu=real_mu, sigma=real_sigma)
                    else:
                        npzfile = np.load(f + '.npz')
                        real_mu, real_sigma = npzfile['mu'], npzfile['sigma']

                    gen_data = model.decode(noise_dist.sample(10000))
                    # reshape data before passing to inception net
                    images = self.reshape_cmnist(gen_data)
                    images = np.round(images * 255)
                    gen_mu, gen_sigma = self.compute_fid(
                        images,
                        os.path.join('utils', 'inception-2015-12-05',
                                     'classify_image_graph_def.pb'))
                    fid_value = fid.calculate_frechet_distance(
                        gen_mu, gen_sigma, real_mu, real_sigma)
                    print("FID: %s" % fid_value)

                if flags.DATASET in ['cifar_100']:
                    #---------------------------------------------------------------#
                    # compute FID                                                   #
                    #---------------------------------------------------------------#
                    print('\nComputing FID at iter %d' % (id))
                    path = os.path.join('utils', 'inception_statistics',
                                        'cifar_100')
                    if not os.path.isdir(path):
                        os.mkdir(path)
                    f = os.path.join(path, 'real_statistics')
                    if not os.path.isfile(f + '.npz'):
                        x, _ = data_dist.sample(data_dist.num_train_samples)
                        data = x
                        x, _ = data_dist.test_sample()
                        data = np.hstack((data, x))
                        # reshape data before passing to inception net
                        images = self.reshape_cifar(data)
                        images = np.round(images * 255)
                        real_mu, real_sigma = self.compute_fid(
                            images,
                            os.path.join('utils', 'inception-2015-12-05',
                                         'classify_image_graph_def.pb'))
                        np.savez(f, mu=real_mu, sigma=real_sigma)
                    else:
                        npzfile = np.load(f + '.npz')
                        real_mu, real_sigma = npzfile['mu'], npzfile['sigma']

                    gen_data = model.decode(noise_dist.sample(10000))
                    # reshape data before passing to inception net
                    images = self.reshape_cifar(gen_data)
                    images = np.round(images * 255)
                    gen_mu, gen_sigma = self.compute_fid(
                        images,
                        os.path.join('utils', 'inception-2015-12-05',
                                     'classify_image_graph_def.pb'))
                    fid_value = fid.calculate_frechet_distance(
                        gen_mu, gen_sigma, real_mu, real_sigma)
                    print("FID: %s" % fid_value)
Example #18
    def train(self, config):
        """Train DCGAN"""

        print("load train stats.. ", end="", flush=True)
        # load precalculated training set statistics
        f = np.load(self.stats_path)
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]
        f.close()
        print("ok")

        if config.dataset == 'mnist':
            print("scan files", end=" ", flush=True)
            data_X, data_y = self.load_mnist()
        else:
            if (config.dataset == "celebA") or (config.dataset == "cifar10"):
                print("scan files", end=" ", flush=True)
                data = glob(
                    os.path.join(self.data_path, self.input_fname_pattern))
            else:
                if config.dataset == "lsun":
                    print("scan files")
                    data = []
                    for i in range(304):
                        print("\r%d" % i, end="", flush=True)
                        data += glob(
                            os.path.join(self.data_path, str(i),
                                         self.input_fname_pattern))
                else:
                    print(
                        "Please specify dataset in run.sh [mnist, celebA, lsun, cifar10]"
                    )
                    raise SystemExit()

        print()
        print("%d images found" % len(data))

        # Z sample
        #sample_z = np.random.normal(0, 1.0, size=(self.sample_num , self.z_dim))
        sample_z = np.random.uniform(-1.0,
                                     1.0,
                                     size=(self.sample_num, self.z_dim))

        # Input samples
        sample_files = data[0:self.sample_num]
        sample = [
            get_image(sample_file,
                      input_height=self.input_height,
                      input_width=self.input_width,
                      resize_height=self.output_height,
                      resize_width=self.output_width,
                      is_crop=self.is_crop,
                      is_grayscale=self.is_grayscale)
            for sample_file in sample_files
        ]
        if (self.is_grayscale):
            sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_inputs = np.array(sample).astype(np.float32)

        if self.load_checkpoint:
            if self.load(self.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        # Batch preparing
        batch_nums = min(len(data), config.train_size) // config.batch_size
        data_idx = list(range(len(data)))

        counter = self.counter_start

        start_time = time.time()

        # Loop over epochs
        for epoch in range(config.epoch):

            # Assign learning rates for d and g
            lrate = config.learning_rate_d  # * (config.lr_decay_rate_d ** epoch)
            self.sess.run(tf.assign(self.learning_rate_d, lrate))
            lrate = config.learning_rate_g  # * (config.lr_decay_rate_g ** epoch)
            self.sess.run(tf.assign(self.learning_rate_g, lrate))

            # Shuffle the data indices
            np.random.shuffle(data_idx)

            # Loop over batches
            for batch_idx in range(batch_nums):

                # Prepare batch
                idx = data_idx[batch_idx * config.batch_size:(batch_idx + 1) *
                               config.batch_size]
                batch = [
                    get_image(data[i],
                              input_height=self.input_height,
                              input_width=self.input_width,
                              resize_height=self.output_height,
                              resize_width=self.output_width,
                              is_crop=self.is_crop,
                              is_grayscale=self.is_grayscale) for i in idx
                ]
                if (self.is_grayscale):
                    batch_images = np.array(batch).astype(np.float32)[:, :, :,
                                                                      None]
                else:
                    batch_images = np.array(batch).astype(np.float32)

                #batch_z = np.random.normal(0, 1.0, size=(config.batch_size , self.z_dim)).astype(np.float32)
                batch_z = np.random.uniform(
                    -1.0, 1.0,
                    size=(config.batch_size, self.z_dim)).astype(np.float32)

                # Update D network
                _, summary_str = self.sess.run([self.d_optim, self.d_sum],
                                               feed_dict={
                                                   self.inputs: batch_images,
                                                   self.z: batch_z
                                               })
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run([self.g_optim, self.g_sum],
                                               feed_dict={self.z: batch_z})
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                errD_real = self.d_loss_real.eval({self.inputs: batch_images})
                errG = self.g_loss.eval({self.z: batch_z})

                # Print
                if np.mod(counter, 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, batch_idx, batch_nums, time.time() - start_time, errD_fake+errD_real, errG))

                # Save generated samples and FID
                if np.mod(counter, config.fid_eval_steps) == 0:

                    # Save
                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.inputs: sample_inputs
                            })
                        save_images(
                            samples, [8, 8],
                            '{}/train_{:02d}_{:04d}.png'.format(
                                config.sample_dir, epoch, batch_idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (d_loss, g_loss))
                    except Exception as e:
                        print(e)
                        print("sample image error!")

                    # FID
                    print("samples for incept", end="", flush=True)

                    samples = np.zeros((self.fid_n_samples, self.output_height,
                                        self.output_width, 3))
                    n_batches = self.fid_n_samples // self.fid_sample_batchsize
                    lo = 0
                    for btch in range(n_batches):
                        print("\rsamples for incept %d/%d" %
                              (btch + 1, n_batches),
                              end=" ",
                              flush=True)
                        #sample_z_fid = np.random.normal(0, 1.0, size=(self.fid_sample_batchsize, self.z_dim))
                        sample_z_fid = np.random.uniform(
                            -1.0,
                            1.0,
                            size=(self.fid_sample_batchsize, self.z_dim))
                        samples[lo:(
                            lo + self.fid_sample_batchsize)] = self.sess.run(
                                self.sampler_fid,
                                feed_dict={self.z_fid: sample_z_fid})
                        lo += self.fid_sample_batchsize

                    samples = (samples + 1.) * 127.5
                    print("ok")

                    mu_gen, sigma_gen = fid.calculate_activation_statistics(
                        samples,
                        self.sess,
                        batch_size=self.fid_batch_size,
                        verbose=self.fid_verbose)

                    print("calculate FID:", end=" ", flush=True)
                    try:
                        FID = fid.calculate_frechet_distance(
                            mu_gen, sigma_gen, mu_real, sigma_real)
                    except Exception as e:
                        print(e)
                        FID = 500

                    print(FID)

                    # Update event log with FID
                    self.sess.run(tf.assign(self.fid, FID))
                    summary_str = self.sess.run(self.fid_sum)
                    self.writer.add_summary(summary_str, counter)

                # Save checkpoint
                if (counter != 0) and (np.mod(counter, 2000) == 0):
                    self.save(config.checkpoint_dir, counter)

                counter += 1
Example #19
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['n_channels'] = utils.nchannels_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    # utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # In some cases we need to load D
    if True or config['get_test_error'] or config['get_train_error'] or config[
            'get_self_error'] or config['get_generator_error']:
        disc_config = config.copy()
        if config['mh_csc_loss'] or config['mh_loss']:
            disc_config['output_dim'] = disc_config['n_classes'] + 1
        D = model.Discriminator(**disc_config).to(device)

        def get_n_correct_from_D(x, y):
            """Gets the "classifications" from D.
      
            y: the correct labels

            In the case of projection discrimination we have to pass in all the labels
            as conditionings to get the class specific affinity.
            """
            x = x.to(device)
            if config['model'] == 'BigGAN':  # projection discrimination case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x, y)
                for i in range(1, config['n_classes']):
                    yhat_ = D(x, ((y + i) % config['n_classes']))
                    yhat = torch.cat([yhat, yhat_], 1)
                preds_ = yhat.data.max(1)[1].cpu()
                return preds_.eq(0).cpu().sum()
            else:  # the mh gan case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x)
                preds_ = yhat[:, :config['n_classes']].data.max(1)[1]
                return preds_.eq(y.data).cpu().sum()

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    utils.load_weights(G if not (config['use_ema']) else None,
                       D,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    brief_expt_name = config['experiment_name'][-30:]

    # load results dict always
    HIST_FNAME = 'scoring_hist.npy'

    def load_or_make_hist(d):
        """Make or load the scoring-history dict stored in directory d."""
        if not os.path.isdir(d):
            raise Exception('%s is not a valid directory' % d)
        f = os.path.join(d, HIST_FNAME)
        if os.path.isfile(f):
            return np.load(f, allow_pickle=True).item()
        else:
            return defaultdict(dict)

    hist_dir = os.path.join(config['weights_root'], config['experiment_name'])
    hist = load_or_make_hist(hist_dir)

    if config['get_test_error'] or config['get_train_error']:
        loaders = utils.get_data_loaders(
            **{
                **config, 'batch_size': config['batch_size'],
                'start_itr': state_dict['itr'],
                'use_test_set': config['get_test_error']
            })
        acc_type = 'Test' if config['get_test_error'] else 'Train'

        pbar = tqdm(loaders[0])
        loader_total = len(loaders[0]) * config['batch_size']
        sample_todo = min(config['sample_num_error'], loader_total)
        print('Getting %s error across %i examples' % (acc_type, sample_todo))
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (x, y) in enumerate(pbar):
                correct += get_n_correct_from_D(x, y)
                total += config['batch_size']
                if loader_total > total and total >= config['sample_num_error']:
                    print('Quitting early...')
                    break

        accuracy = float(correct) / float(total)
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']][acc_type] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], acc_type, accuracy * 100))

    if config['get_self_error']:
        n_used_imgs = config['sample_num_error']
        correct = 0
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                correct += get_n_correct_from_D(images, y)

        accuracy = float(correct) / float(n_used_imgs)
        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], 'Self', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Self'] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['get_generator_error']:

        if config['dataset'] == 'C10':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121()
            compnet = torch.nn.DataParallel(compnet)
            #checkpoint = torch.load(os.path.join('/scratch0/ilya/locDoc/classifiers/densenet121','ckpt_47.t7'))
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar/densenet121',
                    'ckpt_47.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        elif config['dataset'] == 'C100':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121(num_classes=100)
            compnet = torch.nn.DataParallel(compnet)
            checkpoint = torch.load(
                os.path.join(
                    '/scratch0/ilya/locDoc/classifiers/cifar100/densenet121',
                    'ckpt.copy.t7'))
            #checkpoint = torch.load(os.path.join('/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar100/densenet121','ckpt.copy.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.507, 0.487, 0.441),
                                     (0.267, 0.256, 0.276)),
            ])
        elif config['dataset'] == 'STL48':
            from classification.models.wideresnet import WideResNet48
            from torchvision import transforms
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/stl/mixmatch_48',
                    'model_best.pth.tar'))
            compnet = WideResNet48(num_classes=10)
            compnet = compnet.to(device)
            for param in compnet.parameters():
                param.detach_()
            compnet.load_state_dict(checkpoint['ema_state_dict'])
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        else:
            raise ValueError('Dataset %s has no comparison network.' %
                             config['dataset'])

        n_used_imgs = 10000
        correct = 0
        mean_label = np.zeros(config['n_classes'])
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                fake = images.data.cpu().numpy()
                fake = np.floor((fake + 1) * 255 / 2.0).astype(np.uint8)
                fake_input = np.zeros(fake.shape)
                for bi in range(fake.shape[0]):
                    fake_input[bi] = minimal_trans(np.moveaxis(
                        fake[bi], 0, -1))
                images.data.copy_(torch.from_numpy(fake_input))
                lab = compnet(images).max(1)[1]
                mean_label += np.bincount(lab.data.cpu(),
                                          minlength=config['n_classes'])
                correct += int((lab == y).sum().cpu())

        accuracy = float(correct) / float(n_used_imgs)
        mean_label_normalized = mean_label / float(n_used_imgs)

        print(
            '[%s][%06d] %s accuracy: %f.' %
            (brief_expt_name, state_dict['itr'], 'Generator', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Generator'] = accuracy
        hist[state_dict['itr']]['Mean_Label'] = mean_label_normalized
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    if config['official_FID']:
        f = np.load(config['dataset_is_fid'])
        # this is for using the downloaded one from
        # https://github.com/bioinf-jku/TTUR
        #mdata, sdata = f['mu'][:], f['sigma'][:]

        # this one is for my format files
        mdata, sdata = f['mfid'], f['sfid']

    # Sample a number of images and stick them in memory, for use with TF-Inception official_IS and official_FID
    data_gen_necessary = False
    if config['sample_np_mem']:
        is_saved = int('IS' in hist[state_dict['itr']])
        is_todo = int(config['official_IS'])
        fid_saved = int('FID' in hist[state_dict['itr']])
        fid_todo = int(config['official_FID'])
        data_gen_necessary = config['overwrite'] or (is_todo > is_saved) or (
            fid_todo > fid_saved)
    if config['sample_np_mem'] and data_gen_necessary:
        n_used_imgs = 50000
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            start = l * G_batch_size
            end = start + G_batch_size

            with torch.no_grad():
                images, labels = sample()
            fake = np.uint8(255 * (images.cpu().numpy() + 1) / 2.)
            x[start:end] = np.moveaxis(fake, 1, -1)
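            # np.moveaxis converts the batch from NCHW (PyTorch layout) to NHWC,
            # the layout expected by the TF Inception graph used for IS/FID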
            #y += [labels.cpu().numpy()]

    if config['official_IS']:
        if (not ('IS' in hist[state_dict['itr']])) or config['overwrite']:
            mis, sis = iscore.get_inception_score(x)
            print('[%s][%06d] IS mu: %f. IS sigma: %f.' %
                  (brief_expt_name, state_dict['itr'], mis, sis))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['IS'] = [mis, sis]
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            mis, sis = hist[state_dict['itr']]['IS']
            print(
                '[%s][%06d] Already done (skipping...): IS mu: %f. IS sigma: %f.'
                % (brief_expt_name, state_dict['itr'], mis, sis))

    if config['official_FID']:
        import tensorflow as tf

        def fid_ms_for_imgs(images, mem_fraction=0.5):
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=mem_fraction)
            inception_path = fid.check_or_download_inception(None)
            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            with tf.Session(config=tf.ConfigProto(
                    gpu_options=gpu_options)) as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)
            return mu_gen, sigma_gen

        if (not ('FID' in hist[state_dict['itr']])) or config['overwrite']:
            m1, s1 = fid_ms_for_imgs(x)
            fid_value = fid.calculate_frechet_distance(m1, s1, mdata, sdata)
            print('[%s][%06d] FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['FID'] = fid_value
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            fid_value = hist[state_dict['itr']]['FID']
            print('[%s][%06d] Already done (skipping...): FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=folder_number,
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=int(folder_number),
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(
            images.float(),
            '%s/%s/%s.jpg' %
            (config['samples_root'], experiment_name, config['load_weights']),
            nrow=int(G_batch_size**0.5),
            normalize=True)

    # Prepare a simple get_metrics function that we use for truncation curves
    def get_metrics():
        # Get Inception Score and FID
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
        sample = functools.partial(utils.sample,
                                   G=G,
                                   z_=z_,
                                   y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
    def calculate_fid(self):
        import fid
        import tensorflow as tf

        num_of_step = 500
        bs = 100

        sigmas = np.exp(
            np.linspace(np.log(self.config.model.sigma_begin),
                        np.log(self.config.model.sigma_end),
                        self.config.model.num_classes))
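        # log-spaced (geometric) noise schedule from sigma_begin down to sigma_end,
        # one sigma per noise level used by the annealed Langevin sampler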
        stats_path = 'fid_stats_cifar10_train.npz'  # training set statistics
        inception_path = fid.check_or_download_inception(
            None)  # download inception network

        print('Load checkpoint from ' + self.args.log)
        #for epochs in range(140000, 200001, 1000):
        for epochs in [149000]:
            states = torch.load(os.path.join(
                self.args.log, 'checkpoint_' + str(epochs) + '.pth'),
                                map_location=self.config.device)
            #states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
            score = CondRefineNetDilated(self.config).to(self.config.device)
            score = torch.nn.DataParallel(score)

            score.load_state_dict(states[0])

            score.eval()

            if self.config.data.dataset == 'MNIST':
                print("Begin epochs", epochs)
                samples = torch.rand(bs, 1, 28, 28, device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                for j in range(num_of_step - 1):
                    # keep the MNIST sample shape (1, 28, 28) consistent across batches
                    samples = torch.rand(bs,
                                         1,
                                         28,
                                         28,
                                         device=self.config.device)
                    all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                        samples, score, sigmas, 100, 0.00002)
                    images_new = all_samples.mul_(255).add_(0.5).clamp_(
                        0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                    images = np.concatenate((images, images_new), axis=0)

            else:
                print("Begin epochs", epochs)
                samples = torch.rand(bs, 3, 32, 32, device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                for j in range(num_of_step - 1):
                    samples = torch.rand(bs,
                                         3,
                                         32,
                                         32,
                                         device=self.config.device)
                    all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                        samples, score, sigmas, 100, 0.00002)
                    images_new = all_samples.mul_(255).add_(0.5).clamp_(
                        0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                    images = np.concatenate((images, images_new), axis=0)

            # load precalculated training set statistics
            f = np.load(stats_path)
            mu_real, sigma_real = f['mu'][:], f['sigma'][:]
            f.close()

            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)

            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                       mu_real, sigma_real)
            print("FID: %s" % fid_value)
Example #21
def run(dataset,
        generator_type,
        discriminator_type,
        latentsize,
        kernel_dimension,
        epsilon,
        learning_rate,
        batch_size,
        options,
        logdir_base='/tmp'):
    if dataset in ['billion_word']:
        dataset_type = 'text'
    else:
        dataset_type = 'image'
    tf.reset_default_graph()
    dtype = tf.float32

    run_name = '_'.join([
        '%s' % get_timestamp(),
        'g%s' % generator_type,
        'd%s' % discriminator_type,
        'z%d' % latentsize,
        'l%1.0e' % learning_rate,
        'l2p%1.0e' % options.l2_penalty,
        #'d%d' % kernel_dimension,
        #'eps%3.2f' % epsilon,
        'lds%1.e' % options.discriminator_lr_scale,
    ])
    run_name += ("_l2pscale%1.e" %
                 options.gen_l2p_scale) if options.gen_l2p_scale != 1.0 else ''
    run_name += "_M" if options.remember_previous else ''
    run_name += ("_dl%s" %
                 options.disc_loss) if options.disc_loss != 'l2' else ''
    run_name += ("_%s" %
                 options.logdir_suffix) if options.logdir_suffix else ''
    run_name = run_name.replace('+', '')

    if options.verbosity == 0:
        tf.logging.set_verbosity(tf.logging.ERROR)

    subdir = "%s_%s" % (get_timestamp('%y%m%d'), dataset)
    logdir = Path(logdir_base) / subdir / run_name
    print_info("\nLogdir: %s\n" % logdir, options.verbosity > 0)
    if __name__ == "__main__" and options.sample_images is None:
        startup_bookkeeping(logdir, __file__)
        trainlog = open(str(logdir / 'logfile.csv'), 'w')
    else:
        trainlog = None

    dataset_pattern, n_samples, img_shape = get_dataset_path(dataset)
    z = tf.random_normal([batch_size, latentsize], dtype=dtype, name="z")
    if dataset_type == 'text':
        n_samples = options.num_examples
        y, lines_as_ints, charmap, inv_charmap = load_text_dataset(
            dataset_pattern,
            batch_size,
            options.sequence_length,
            options.num_examples,
            options.max_vocab_size,
            shuffle=True,
            num_epochs=None)
        img_shape = [options.sequence_length, len(charmap)]
        true_ngram_model = ngram_language_model.NgramLanguageModel(
            lines_as_ints, options.ngrams, len(charmap))
    else:
        y = load_image_dataset(dataset_pattern,
                               batch_size,
                               img_shape,
                               n_threads=options.threads)

    x = create_generator(z, img_shape,
                         options.l2_penalty * options.gen_l2p_scale,
                         generator_type, batch_size)
    assert x.get_shape().as_list()[1:] == y.get_shape().as_list(
    )[1:], "X and Y have different shapes: %s vs %s" % (
        x.get_shape().as_list(), y.get_shape().as_list())

    disc_x = create_discriminator(x, discriminator_type, options.l2_penalty,
                                  False)
    disc_y = create_discriminator(y, discriminator_type, options.l2_penalty,
                                  True)

    with tf.name_scope('loss'):
        disc_x = tf.reshape(disc_x, [-1])
        disc_y = tf.reshape(disc_y, [-1])
        pot_x, pot_y = get_potentials(x, y, kernel_dimension, epsilon)

        if options.disc_loss == 'l2':
            disc_loss_fn = tf.losses.mean_squared_error
        elif options.disc_loss == 'l1':
            disc_loss_fn = tf.losses.absolute_difference
        else:
            assert False, "Unknown Discriminator Loss: %s" % options.disc_loss

        loss_d_x = disc_loss_fn(pot_x, disc_x)
        loss_d_y = disc_loss_fn(pot_y, disc_y)
        loss_d = loss_d_x + loss_d_y
        loss_g = tf.reduce_mean(disc_x)

        if options.remember_previous:
            x_old = tf.get_variable("x_old",
                                    shape=x.shape,
                                    initializer=tf.zeros_initializer(),
                                    trainable=False)
            disc_x_old = create_discriminator(x_old, discriminator_type,
                                              options.l2_penalty, True)
            disc_x_old = tf.reshape(disc_x_old, [-1])
            pot_x_old = calculate_potential(x, y, x_old, kernel_dimension,
                                            epsilon)
            loss_d_x_old = disc_loss_fn(pot_x_old, disc_x_old)
            loss_d += loss_d_x_old

    vars_d = [
        v for v in tf.global_variables() if v.name.startswith('discriminator')
    ]
    vars_g = [
        v for v in tf.global_variables() if v.name.startswith('generator')
    ]
    optim_d = tf.train.AdamOptimizer(learning_rate *
                                     options.discriminator_lr_scale,
                                     beta1=options.discriminator_beta1,
                                     beta2=options.discriminator_beta2)
    optim_g = tf.train.AdamOptimizer(learning_rate,
                                     beta1=options.generator_beta1,
                                     beta2=options.generator_beta2)

    # We can sum all regularizers into one term; the var_list argument to minimize()
    # ensures each optimizer only regularizes "its own" variables.
    regularizers = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    train_op_d = optim_d.minimize(loss_d + regularizers, var_list=vars_d)
    train_op_g = optim_g.minimize(loss_g + regularizers, var_list=vars_g)
    train_op = tf.group(train_op_d, train_op_g)

    if options.remember_previous:
        with tf.control_dependencies([train_op]):
            assign_x_op = tf.assign(x_old, x)
        train_op = tf.group(train_op, assign_x_op)

    # Tensorboard summaries
    if dataset_type == 'image':
        x_img = (tf.clip_by_value(x, -1.0, 1.0) + 1) / 2.0
        y_img = tf.clip_by_value((y + 1) / 2, 0.0, 1.0)
    with tf.name_scope('potential'):
        tf.summary.histogram('x', pot_x)
        tf.summary.histogram('y', pot_y)
        if options.remember_previous:
            tf.summary.histogram('x_old', pot_x_old)
    if options.create_summaries:
        if dataset_type == 'image':
            with tf.name_scope("distances"):
                tf.summary.histogram("xx", generate_all_distances(x, x))
                tf.summary.histogram("xy", generate_all_distances(x, y))
                tf.summary.histogram("yy", generate_all_distances(y, y))
        with tf.name_scope('discriminator_stats'):
            tf.summary.histogram('output_x', disc_x)
            tf.summary.histogram('output_y', disc_y)
            tf.summary.histogram('pred_error_y', pot_y - disc_y)
            tf.summary.histogram('pred_error_x', pot_x - disc_x)
        if dataset_type == 'image':
            img_smry = tf.summary.image("out_img", x_img, 2)
            img_smry = tf.summary.image("in_img", y_img, 2)
        with tf.name_scope("losses"):
            tf.summary.scalar('loss_d_x', loss_d_x)
            tf.summary.scalar('loss_d_y', loss_d_y)
            tf.summary.scalar('loss_d', loss_d)
            tf.summary.scalar('loss_g', loss_g)

        with tf.name_scope('weightnorm'):
            for v in tf.global_variables():
                if not v.name.endswith('kernel:0'):
                    continue
                tf.summary.scalar("wn_" + v.name[:-8], tf.norm(v))
        with tf.name_scope('mean_activations'):
            for op in tf.get_default_graph().get_operations():
                if not op.name.endswith('Tanh'):
                    continue
                tf.summary.scalar("act_" + op.name,
                                  tf.reduce_mean(op.outputs[0]))
    merged_smry = tf.summary.merge_all()

    if dataset_type == 'image':
        fid_stats_file = options.fid_stats % dataset.lower()
        assert Path(fid_stats_file).exists(
        ), "Can't find training set statistics for FID (%s)" % fid_stats_file
        f = np.load(fid_stats_file)
        mu_fid, sigma_fid = f['mu'][:], f['sigma'][:]
        f.close()
        inception_path = fid.check_or_download_inception(
            options.inception_path)
        fid.create_inception_graph(inception_path)

    maxv = 0.05
    cmap = plt.cm.ScalarMappable(mpl.colors.Normalize(-maxv, maxv),
                                 cmap=plt.cm.RdBu)
    config = tf.ConfigProto(intra_op_parallelism_threads=2,
                            inter_op_parallelism_threads=2,
                            use_per_session_threads=True,
                            gpu_options=tf.GPUOptions(allow_growth=True))

    save_vars = [
        v for v in tf.global_variables() if v.name.startswith('generator')
    ]
    save_vars += [
        v for v in tf.global_variables() if v.name.startswith('discriminator')
    ]

    with tf.Session(config=config) as sess:
        log = tf.summary.FileWriter(str(logdir), sess.graph)
        sess.run(tf.global_variables_initializer())
        if options.resume_checkpoint:
            loader = tf.train.Saver(save_vars)
            loader.restore(sess, options.resume_checkpoint)
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        fd = {}

        if options.sample_images is not None:
            n_batches = (1 + options.sample_images_size // batch_size)
            sample_images(x_img, sess, n_batches, path=options.sample_images)
            coord.request_stop()
            coord.join(threads)
            return

        saver = tf.train.Saver(save_vars, max_to_keep=50)
        max_iter = int(options.iterations * 1000)

        n_epochs = max_iter / (n_samples / batch_size)
        print_info(
            "total iterations: %d (= %3.2f epochs)" % (max_iter, n_epochs),
            options.verbosity > 0)
        t0 = time.time()

        try:
            for cur_iter in range(
                    max_iter + 1
            ):  # +1 so we are more likely to get a model/stats line at the end
                sess.run(train_op)
                if (cur_iter > 0) and (cur_iter % options.checkpoint_every
                                       == 0):
                    saver.save(sess,
                               str(logdir / 'model'),
                               global_step=cur_iter)

                if cur_iter % options.stats_every == 0:
                    if dataset_type == 'image':
                        smry, xx_img = sess.run([merged_smry, x_img])
                        log.add_summary(smry, cur_iter)
                        images = sample_images(
                            x_img, sess,
                            n_batches=5 * 1024 // batch_size) * 255
                        mu_gen, sigma_gen = fid.calculate_activation_statistics(
                            images, sess, batch_size=128)
                        quality_measure = fid.calculate_frechet_distance(
                            mu_gen, sigma_gen, mu_fid, sigma_fid)
                        fig = plot_tiles(xx_img,
                                         10,
                                         10,
                                         local_norm="none",
                                         figsize=(6.6, 6.6))
                        fig.savefig(str(logdir / ('%09d.png' % cur_iter)))
                        plt.close(fig)
                    elif dataset_type == 'text':
                        smry = sess.run(merged_smry)
                        # Note: to compare with WGAN-GP, we can only take 5 samples since our batch size is 2x theirs
                        # and JSD improves a lot with larger sample sizes
                        sample_text_ = sample_text(x, sess, 5, inv_charmap,
                                                   logdir / 'samples',
                                                   cur_iter)
                        gen_ngram_model = ngram_language_model.NgramLanguageModel(
                            sample_text_, options.ngrams, len(charmap))
                        js = []
                        for i in range(options.ngrams):
                            js.append(
                                true_ngram_model.js_with(
                                    gen_ngram_model, i + 1))
                            #print('js%d' % (i+1), quality_measure[i])
                        quality_measure = js[3] if options.ngrams < 6 else (
                            str(js[3]) + '/' + str(js[5]))

                    s = (cur_iter, quality_measure, time.time() - t0, dataset,
                         run_name)
                    print_info("%9d  %s -- %3.2fs %s %s" % s,
                               options.verbosity > 0)
                    if trainlog:
                        print(', '.join([str(ss) for ss in s]),
                              file=trainlog,
                              flush=True)
                    log.add_summary(smry, cur_iter)

        except KeyboardInterrupt:
            saver.save(sess, str(logdir / 'model'), global_step=cur_iter)
        finally:
            if trainlog:
                trainlog.close()
            coord.request_stop()
            coord.join(threads)
        return
    def train(self):

        print("load train stats..", end="")
        # load precalculated training set statistics
        f = np.load(self.train_stats_file)
        mu_trn, sigma_trn = f['mu'][:], f['sigma'][:]
        f.close()
        print("ok")

        z_fixed = np.random.uniform(-1, 1, size=(self.batch_size, self.z_num))

        x_fixed = self.get_image_from_loader()
        save_image(x_fixed, '{}/x_fixed.png'.format(self.model_dir))

        prev_measure = 1
        measure_history = deque([0]*self.lr_update_step, self.lr_update_step)

        # load inference model
        fid.create_inception_graph("inception-2015-12-05/classify_image_graph_def.pb")

        #query_tensor = fid.get_Fid_query_tensor(self.sess)

        if self.load_checkpoint:
          if self.load(self.model_dir):
            print(" [*] Load SUCCESS")
          else:
            print(" [!] Load failed...")

        # Preallocate prediction array for the KL/FID inception score
        #print("preallocate %.3f GB for prediction array.." % (self.eval_num_samples * 2048 / (1024**3)), end=" ", flush=True)
        inception_activations = np.ones([self.eval_num_samples, 2048])
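        # 2048 is the dimensionality of the Inception pool3 features returned by
        # fid.get_activations; batches are written into this buffer and the FID
        # statistics (mean and covariance) are computed from it below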
        #print("ok")

        for step in trange(self.start_step, self.max_step):

            # Optimize
            self.sess.run([self.d_optim, self.g_optim])

            # Feed dict
            fetch_dict = {"measure": self.measure}

            if self.update_k:
              fetch_dict.update({"k_update": self.k_update})

            if step % self.log_step == 0:
                fetch_dict.update({
                    "summary": self.summary_op,
                    "g_loss": self.g_loss,
                    "d_loss": self.d_loss,
                    "k_t": self.k_t,
                })

            # Get summaries
            result = self.sess.run(fetch_dict)

            measure = result['measure']
            measure_history.append(measure)

            if step % self.log_step == 0:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

                g_loss = result['g_loss']
                d_loss = result['d_loss']
                k_t = result['k_t']

                print("[{}/{}] Loss_D: {:.6f} Loss_G: {:.6f} measure: {:.4f}, k_t: {:.4f}". \
                      format(step, self.max_step, d_loss, g_loss, measure, k_t))

            if step % (self.log_step * 10) == 0:
                x_fake = self.generate(z_fixed, self.model_dir, idx=step)
                self.autoencode(x_fixed, self.model_dir, idx=step, x_fake=x_fake)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            # FID
            if step % self.eval_step == 0:

              eval_batches_num = self.eval_num_samples // self.eval_batch_size

              for eval_batch in range(eval_batches_num):

                print("\rFID batch %d/%d" % (eval_batch + 1, eval_batches_num), end="", flush=True)

                sample_z_eval = np.random.uniform(-1, 1, size=(self.eval_batch_size, self.z_num))
                samples_eval = self.generate(sample_z_eval, self.model_dir, save=False)

                activations_batch = fid.get_activations(samples_eval,
                                                self.sess,
                                                batch_size=self.eval_batch_size,
                                                verbose=False)

                frm = eval_batch * self.eval_batch_size
                to = frm + self.eval_batch_size
                inception_activations[frm:to,:] = activations_batch

              print()

              # calculate FID
              print("FID:", end=" ", flush=True)
              try:
                mu_eval = np.mean(inception_activations, axis=0)
                sigma_eval = np.cov(inception_activations, rowvar=False)
                FID = fid.calculate_frechet_distance(mu_eval, sigma_eval, mu_trn, sigma_trn)
              except Exception as e:
                print(e)
                FID = 500
              print(FID)

              self.sess.run(tf.assign(self.fid, FID))
              summary_str = self.sess.run(self.fid_sum)
              self.summary_writer.add_summary(summary_str, step)
Example #23
        mu, sigma = fid.calculate_activation_statistics(images,
                                                        sess,
                                                        batch_size=100)
        np.savez_compressed(args.stats_path, mu=mu, sigma=sigma)
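        # the saved mu/sigma are the precomputed real-data statistics that the
        # evaluation branch below (and the other examples) load back as mu_real/sigma_real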
    print("finished")
else:
    image_list = glob.glob(os.path.join(args.image_path, '*.jpg'))
    images = np.array(
        [imread(str(fn)).astype(np.float32) for fn in image_list])

    f = np.load(args.stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()

    fid.create_inception_graph(inception_path)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(images,
                                                                sess,
                                                                batch_size=100)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    print("FID: %s" % fid_value)
    with open(args.output_file, "a") as f:
        print("\n",
              datetime.datetime.now().isoformat(),
              fid_value,
              end="\n ",
              file=f)
Example #24
def fid_imgs(cfg):
    print("CALCULATING FID/KID scores")
    rnd_seed = 12345
    random.seed(rnd_seed)
    np.random.seed(rnd_seed)
    tf.compat.v2.random.set_seed(rnd_seed)
    tf.random.set_random_seed(rnd_seed)
    inception_path = fid.check_or_download_inception(
        None)  # download inception network

    # load precalculated training set statistics
    print("Loading stats from:", cfg.stats_filename, '  ...', end='')
    f = np.load(cfg.stats_filename)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]

    activations_ref = None
    if 'activations' in f:
        activations_ref = f['activations']
        print(" reference activations #:", activations_ref.shape[0])

    f.close()
    print("done")

    fid_epoch = 0
    epoch_info_file = cfg.exp_path + '/fid-epoch.txt'
    if os.path.isfile(epoch_info_file):
        fid_epoch = open(epoch_info_file, 'rt').read()
    else:
        print("ERROR: couldnot find file:", epoch_info_file)

    best_fid_file = cfg.exp_path + '/fid-best.txt'
    best_fid = 1e10
    if os.path.isfile(best_fid_file):
        best_fid = float(open(best_fid_file, 'rt').read())
        print("Best FID: " + str(best_fid))

    pr = None
    pr_file = cfg.exp_path + '/pr.txt'
    if os.path.isfile(pr_file):
        pr = open(pr_file).read()
        print("PR: " + str(pr))

    rec = []
    rec.append(fid_epoch)
    if activations_ref is not None:
        rec.append('nref:' + str(activations_ref.shape[0]))

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    dirs = cfg.image_path.split(',')
    first_fid = None
    for dir in dirs:
        print("Working on:", dir)
        test_name = dir.split('/')[-1]
        rec.append(test_name)
        # loads all images into memory (this might require a lot of RAM!)
        image_list = glob.glob(os.path.join(dir, '*.jpg'))
        image_list = image_list + glob.glob(os.path.join(dir, '*.png'))
        image_list.sort()
        print("Loading images:", len(image_list), '  ...', end='')
        images = np.array([
            imageio.imread(str(fn), as_gray=False,
                           pilmode="RGB").astype(np.float32)
            for fn in image_list
        ])
        print("done")

        print("Extracting features ", end='')
        os.environ['CUDA_VISIBLE_DEVICES'] = '1'
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            mu_gen, sigma_gen, activations = fid.calculate_activation_statistics(
                images, sess, batch_size=BATCH_SIZE)
        print("Extracted activations:", activations.shape[0])
        rec.append('ntest:' + str(activations.shape[0]))

        if cfg.fid:
            # Calculate FID
            print("Calculating FID.....")
            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                       mu_real, sigma_real)
            rec.append('fid:' + str(fid_value))
            if first_fid is None:
                first_fid = fid_value

            if best_fid > first_fid and fid_epoch != 0:
                epoch = int(fid_epoch.split(' ')[0].split(':')[1])
                print("Storing best FID model. Epoch: " + str(epoch) +
                      "  Current FID: " + str(best_fid) + " new: " +
                      str(first_fid))
                best_fid = first_fid
                # Store best fid & weights
                with open(best_fid_file, 'wt') as f:
                    f.write(str(first_fid))
                model_file = cfg.exp_path + '/models/weights-' + str(
                    epoch) + '.cp'
                backup_model_file = cfg.exp_path + '/models/' + str(
                    epoch) + '.cp'
                os.system('cp ' + model_file + '  ' + backup_model_file)

        if cfg.kid:
            # Calculate KID (a minimal MMD^2 sketch is given after this function)
            # Parameters:
            print("Calculating KID...")
            mmd_degree = 3
            mmd_gamma = None
            mmd_coef0 = 1
            mmd_var = False
            mmd_subsets = 100
            mmd_subset_size = 1000

            ret = polynomial_mmd_averages(activations,
                                          activations_ref,
                                          degree=mmd_degree,
                                          gamma=mmd_gamma,
                                          coef0=mmd_coef0,
                                          ret_var=mmd_var,
                                          n_subsets=mmd_subsets,
                                          subset_size=mmd_subset_size)

            if mmd_var:
                mmd2s, vars = ret
            else:
                mmd2s = ret

            kid_value = mmd2s.mean()
            kid_value_std = mmd2s.std()
            rec.append('kid_mean:' + str(kid_value))
            rec.append('kid_std:' + str(kid_value_std))

        if cfg.psnr and test_name == 'reco':
            image_list = glob.glob(os.path.join(cfg.stats_path, '*.jpg'))
            image_list.sort()
            if len(image_list) == 0:
                print("No images in directory ", cfg.stats_path)
                return

            images_gt = np.array([
                imageio.imread(str(fn), as_gray=False,
                               pilmode="RGB").astype(np.float32)
                for fn in image_list
            ])
            print("%d images found and loaded" % len(images_gt))
            print("Calculating PSNR...")
            psnr_val = psnr(images_gt, images)
            print("Calculating SSIM...")
            ssim_val = ssim(images_gt, images)

            print('PSNR:', psnr_val, 'SSIM:', ssim_val)
            rec.append('psnr:' + str(psnr_val))
            rec.append('ssim:' + str(ssim_val))

        print(' '.join(rec))

    if pr is not None:
        rec.append(pr)

    print(' '.join(rec))

    # Write out results
    with open(cfg.exp_path + '/results.txt', 'a+') as f:
        f.write(' '.join(rec) + '\n')

    return first_fid
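
# For reference, a minimal sketch of the unbiased MMD^2 estimate with the cubic
# polynomial kernel k(x, y) = (x . y / d + coef0)^degree that KID averages over
# random subsets of the Inception activations computed above. This is stated as
# an assumption about what polynomial_mmd_averages does (its source is not shown
# here); the helper name _kid_mmd2_sketch is illustrative only.
def _kid_mmd2_sketch(acts_x, acts_y, degree=3, coef0=1):
    d = acts_x.shape[1]  # default gamma = 1 / feature dimension
    k_xx = (acts_x @ acts_x.T / d + coef0) ** degree
    k_yy = (acts_y @ acts_y.T / d + coef0) ** degree
    k_xy = (acts_x @ acts_y.T / d + coef0) ** degree
    m, n = k_xx.shape[0], k_yy.shape[0]
    # unbiased estimator: exclude the diagonals of the within-set kernel matrices
    return ((k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
            + (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
            - 2.0 * k_xy.mean())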
Example #25
                  pc2pix_filename, "Elapsed :", elapsed_time)
            print(np.array(gt).shape)

    gt = np.array(gt)
    bl = np.array(bl)
    pc = np.array(pc)

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_gt, sigma_gt = fid.calculate_activation_statistics(gt, sess)
        mu_bl, sigma_bl = fid.calculate_activation_statistics(bl, sess)
        mu_pc, sigma_pc = fid.calculate_activation_statistics(pc, sess)

    fid_value = fid.calculate_frechet_distance(mu_bl, sigma_bl, mu_gt,
                                               sigma_gt)
    filename = "fid.log"
    fd = open(filename, "a+")
    fd.write("---| ")
    fd.write(args.split_file)
    fd.write(" |---\n")
    print("Surface FID: %s" % fid_value)
    fd.write("Surface FID: %s\n" % fid_value)
    fid_value = fid.calculate_frechet_distance(mu_pc, sigma_pc, mu_gt,
                                               sigma_gt)
    print("PC2PIX FID: %s" % fid_value)
    fd.write("PC2PIX FID: %s\n" % fid_value)
    fd.write("---\n")
    fd.close()
Example #26
            q_score.append(
                np.mean([
                    np.any(item == dist_list[idx])
                    for item in np.concatenate(tlist)
                ]))

        acc_list.append([np.round(v * 100, 2) for v in q_score])
        np.save(log_dir + '/acc_score.npy', np.vstack(acc_list))
        print(acc_list[-1])

    if ((i) % args.cal_every == 0) and (i != 0) and fid_score_calc:

        fid_score = []
        for idx in range(args.num_agent):
            list_ = []
            for k in range(100):
                #for batch in in_list[idx]:
                list_.append(sess.run(feat, {sample: in_list[idx][k]}))
            list_ = np.reshape(np.concatenate(list_, 0), (10000, -1))
            mu_gen = np.mean(list_, axis=0)
            sigma_gen = np.cov(list_, rowvar=False)
            fid_score.append(
                np.round(
                    fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                   mu_real[idx],
                                                   sigma_real[idx]), 2))
        print(fid_score)
        in_list = []
        FID_list.append(fid_score)
        np.save(log_dir + '/fid_score.npy', np.vstack(FID_list))
def calculate_fid_from_stats(stats_A, stats_B):
    mu_A, sigma_A = stats_A
    mu_B, sigma_B = stats_B
    return fid.calculate_frechet_distance(mu_A, sigma_A, mu_B, sigma_B)
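
# For reference, a minimal NumPy/SciPy sketch of the Frechet distance that
# fid.calculate_frechet_distance computes above; the TTUR implementation
# additionally guards against numerical problems (e.g. adding a small epsilon
# to the covariance diagonals), which this sketch omits.
def frechet_distance_sketch(mu1, sigma1, mu2, sigma2):
    from scipy import linalg
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)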
Example #28
    def train(self, config):
        """Train DCGAN"""

        print("load train stats")
        # load precalculated training set statistics
        f = np.load(self.stats_path)
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]
        f.close()

        # get query tensor
        #query_tensor = self.querry_tensor #fid.get_Fid_query_tensor(self.sess)

        print("scan files", end=" ", flush=True)
        if config.dataset == 'mnist':
            data_X, data_y = self.load_mnist()
            print("%d images found" % len(data_X))
        else:
            data = glob(os.path.join(self.data_path, "*"))
            print("%d images found" % len(data))

        print("build model", end=" ", flush=True)
        # Train optimizers
        opt_d = tf.train.AdamOptimizer(config.learning_rate_d,
                                       beta1=config.beta1)
        opt_g = tf.train.AdamOptimizer(config.learning_rate_g,
                                       beta1=config.beta1)

        # Discriminator
        grads_and_vars = opt_d.compute_gradients(self.d_loss,
                                                 var_list=self.d_vars)
        grads = []
        d_optim = opt_d.apply_gradients(grads_and_vars)

        # Gradient summaries discriminator
        sum_grad_d = []
        for i, (grad, vars_) in enumerate(grads_and_vars):
            grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(grad)))
            sum_grad_d.append(
                tf.summary.scalar("grad_l2_d_%d_%s" % (i, vars_.name),
                                  grad_l2))

        # Generator
        grads_and_vars = opt_g.compute_gradients(self.g_loss,
                                                 var_list=self.g_vars)
        g_optim = opt_g.apply_gradients(grads_and_vars)

        # Gradient summaries generator
        sum_grad_g = []
        for i, (grad, vars_) in enumerate(grads_and_vars):
            grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(grad)))
            sum_grad_g.append(
                tf.summary.scalar("grad_l2_g_%d_%s" % (i, vars_.name),
                                  grad_l2))

        # Init:
        tf.global_variables_initializer().run()

        # Summaries
        self.g_sum = tf.summary.merge([
            self.z_sum, self.d_fake_sum, self.G_sum, self.d_loss_fake_sum,
            self.g_loss_sum, self.lrate_sum_g
        ] + sum_grad_g)
        self.d_sum = tf.summary.merge([
            self.z_sum, self.d_real_sum, self.d_loss_real_sum, self.d_loss_sum,
            self.lrate_sum_d
        ] + sum_grad_d)
        self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

        # Z sample
        sample_z = np.random.normal(0, 1.0, size=(self.sample_num, self.z_dim))

        # Input samples
        sample_files = data[0:self.sample_num]
        sample = [
            get_image(sample_file,
                      input_height=self.input_height,
                      input_width=self.input_width,
                      resize_height=self.output_height,
                      resize_width=self.output_width,
                      is_crop=self.is_crop,
                      is_grayscale=self.is_grayscale)
            for sample_file in sample_files
        ]
        if (self.is_grayscale):
            sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_inputs = np.array(sample).astype(np.float32)

        print("ok")

        start_time = time.time()

        if self.load_checkpoint:
            if self.load(self.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        # Batch preparing
        batch_nums = min(len(data), config.train_size) // config.batch_size
        data_idx = list(range(len(data)))

        counter = self.counter_start

        # Loop over epochs
        for epoch in range(config.epoch):

            # Assign learning rates for d and g
            lrate = config.learning_rate_d * (config.lr_decay_rate_d**epoch)
            self.sess.run(tf.assign(self.learning_rate_d, lrate))
            lrate = config.learning_rate_g * (config.lr_decay_rate_g**epoch)
            self.sess.run(tf.assign(self.learning_rate_g, lrate))

            # Shuffle the data indices
            np.random.shuffle(data_idx)

            # Loop over batches
            for batch_idx in range(batch_nums):

                # Prepare batch
                idx = data_idx[batch_idx * config.batch_size:(batch_idx + 1) *
                               config.batch_size]
                batch = [
                    get_image(data[i],
                              input_height=self.input_height,
                              input_width=self.input_width,
                              resize_height=self.output_height,
                              resize_width=self.output_width,
                              is_crop=self.is_crop,
                              is_grayscale=self.is_grayscale) for i in idx
                ]
                if (self.is_grayscale):
                    batch_images = np.array(batch).astype(np.float32)[:, :, :,
                                                                      None]
                else:
                    batch_images = np.array(batch).astype(np.float32)

                batch_z = np.random.normal(
                    0, 1.0,
                    size=(config.batch_size, self.z_dim)).astype(np.float32)

                # Update D network
                _, summary_str = self.sess.run([d_optim, self.d_sum],
                                               feed_dict={
                                                   self.inputs: batch_images,
                                                   self.z: batch_z
                                               })
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run([g_optim, self.g_sum],
                                               feed_dict={self.z: batch_z})
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                errD_real = self.d_loss_real.eval({self.inputs: batch_images})
                errG = self.g_loss.eval({self.z: batch_z})

                if np.mod(counter, 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, batch_idx, batch_nums, time.time() - start_time, errD_fake+errD_real, errG))

                if np.mod(counter, 1000) == 0:

                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.inputs: sample_inputs
                            })
                        save_images(
                            samples, [8, 8],
                            '{}/train_{:02d}_{:04d}.png'.format(
                                config.sample_dir, epoch, batch_idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (d_loss, g_loss))
                    except Exception as e:
                        print(e)
                        print("sample image error!")

                    # FID
                    print("samples for incept")

                    #self.fid_sample_batchsize=fid_sample_batchsize
                    samples = np.zeros((self.fid_n_samples, 64, 64, 3))
                    n_batches = self.fid_n_samples / self.fid_sample_batchsize
                    lo = 0
                    for btch in range(int(n_batches)):
                        sample_z_dist = np.random.normal(
                            0,
                            1.0,
                            size=(self.fid_sample_batchsize, self.z_dim))
                        samples[lo:(
                            lo + self.fid_sample_batchsize)] = self.sess.run(
                                self.sampler_dist,
                                feed_dict={self.z_dist: sample_z_dist})
                        lo += self.fid_sample_batchsize

                    samples = (samples + 1.) * 127.5
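                    # samples are now in [0, 255], as expected by the Inception network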
                    #predictions = fid.get_predictions( samples,
                    #                                 query_tensor,
                    #                                 self.sess,
                    #                                 batch_size=self.fid_batch_size,
                    #                                 verbose=self.fid_verbose)
                    #FID=None
                    mu_gen, sigma_gen = fid.calculate_activation_statistics(
                        samples,
                        self.sess,
                        batch_size=self.fid_batch_size,
                        verbose=self.fid_verbose)
                    try:
                        #FID,_,_ = fid.FID(mu_trn, sigma_trn, self.sess)
                        FID = fid.calculate_frechet_distance(
                            mu_gen, sigma_gen, mu_real, sigma_real)
                        print("FID = " + str(FID))
                    except Exception as e:
                        print("Exception: " + str(e) + " FID is set to 500.")
                        FID = 500

                    self.sess.run(tf.assign(self.fid, FID))
                    summary_str = self.sess.run(self.fid_sum)
                    self.writer.add_summary(summary_str, counter)

                # Save checkpoint
                if (counter != 0) and (np.mod(counter, 2000) == 0):
                    self.save(config.checkpoint_dir, counter)

                counter += 1
Example #29
 def compute(images):
     m, s = fid.calculate_activation_statistics(
         np.array(images), inception_sess, args.batch_size, verbose=True)
     return fid.calculate_frechet_distance(m, s, mu0, sig0)
    def calculate_fid(self):
        import fid, pickle
        import tensorflow as tf

        stats_path = "fid_stats_cifar10_train.npz"  # training set statistics
        inception_path = fid.check_or_download_inception(
            "./tmp/"
        )  # download inception network

        score = get_model(self.config)
        score = torch.nn.DataParallel(score)

        sigmas_th = get_sigmas(self.config)
        sigmas = sigmas_th.cpu().numpy()

        fids = {}
        for ckpt in tqdm.tqdm(
            range(
                self.config.fast_fid.begin_ckpt, self.config.fast_fid.end_ckpt + 1, 5000
            ),
            desc="processing ckpt",
        ):
            states = torch.load(
                os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
                map_location=self.config.device,
            )

            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(score)
                ema_helper.load_state_dict(states[-1])
                ema_helper.ema(score)
            else:
                score.load_state_dict(states[0])

            score.eval()

            num_iters = (
                self.config.fast_fid.num_samples // self.config.fast_fid.batch_size
            )
            output_path = os.path.join(self.args.image_folder, "ckpt_{}".format(ckpt))
            os.makedirs(output_path, exist_ok=True)
            for i in range(num_iters):
                init_samples = torch.rand(
                    self.config.fast_fid.batch_size,
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                    device=self.config.device,
                )
                init_samples = data_transform(self.config, init_samples)

                all_samples = anneal_Langevin_dynamics(
                    init_samples,
                    score,
                    sigmas,
                    self.config.fast_fid.n_steps_each,
                    self.config.fast_fid.step_lr,
                    verbose=self.config.fast_fid.verbose,
                )

                final_samples = all_samples[-1]
                for id, sample in enumerate(final_samples):
                    sample = sample.view(
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                    )

                    sample = inverse_data_transform(self.config, sample)

                    save_image(
                        sample, os.path.join(output_path, "sample_{}.png".format(id))
                    )

            # load precalculated training set statistics
            f = np.load(stats_path)
            mu_real, sigma_real = f["mu"][:], f["sigma"][:]
            f.close()

            fid.create_inception_graph(
                inception_path
            )  # load the graph into the current TF graph
            final_samples = (
                (final_samples - final_samples.min())
                / (final_samples.max() - final_samples.min())
                * 255
            ).data.cpu().numpy()
            final_samples = np.transpose(final_samples, [0, 2, 3, 1])
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    final_samples, sess, batch_size=100
                )

            fid_value = fid.calculate_frechet_distance(
                mu_gen, sigma_gen, mu_real, sigma_real
            )
            fids[ckpt] = fid_value  # record so the pickle dump below is not empty
            print("FID: %s" % fid_value)

        with open(os.path.join(self.args.image_folder, "fids.pickle"), "wb") as handle:
            pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)