def main():
    """Compute the FID of saved generated samples against precalculated
    CelebA training-set statistics, saving one sample as a PNG preview."""
    # Paths
    pic_path = os.path.join('./out/c/', 'checkpoints', 'dummy.samples%d.npy' % (NUM_POINTS))
    image_path = 'results_celeba/generated'  # set path to some generated images (only used by the commented-out loader below)
    stats_path = 'fid_stats_celeba.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(None)  # download inception network
    # load precalculated training set statistics
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()
    #image_list = glob.glob(os.path.join(image_path, '*.png'))
    #images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
    images = np.load(pic_path)
    # assumes samples are in [-1, 1]; map to [0, 255] for the preview image — TODO confirm range
    images_t = images / 2.0 + 0.5
    images_t = 255.0 * images_t
    from PIL import Image
    img = Image.fromarray(np.uint8(images_t[0]), 'RGB')  # visual sanity check of the first sample
    img.save('my.png')
    fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # NOTE(review): the raw `images` (not the rescaled images_t) are fed to
        # inception here — verify the expected input range of this fid variant
        mu_gen, sigma_gen = fid.calculate_activation_statistics(images, sess)
        fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
        print("FID: %s" % fid_value)
def compute(images):
    """Return the FID between *images* and the reference statistics (mu0, sig0)."""
    transformed = transform_for_fid(images)
    mu_gen, sigma_gen = fid.calculate_activation_statistics(
        transformed, inception_sess, args.batch_size, verbose=True, model='lenet')
    return fid.calculate_frechet_distance(mu_gen, sigma_gen, mu0, sig0)
def main():
    """Precompute FID activation statistics for the first 10k CelebA images
    and save them as a compressed .npz file."""
    inception_path = None
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(
        inception_path)  # download inception if necessary
    print("ok")
    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)
    data_files = glob.glob(os.path.join("./img_align_celeba", "*.jpg"))
    data_files = sorted(data_files)[:10000]  # deterministic first-10k subset
    data_files = np.array(data_files)
    images = np.array([get_image(data_file, 148) for data_file in data_files]).astype(np.float32)
    # assumes get_image returns values in [0, 1]; rescale to [0, 255] — TODO confirm
    images = images * 255
    output_name = 'fid_stats_face'
    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    print("ok")
    print("calculte FID stats..", end=" ", flush=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
        np.savez_compressed(output_name, mu=mu, sigma=sigma)
    print("finished")
def train(self, num_batches=200000):
    """WGAN-GP training loop.

    Alternates `self.d_iters` discriminator steps with one generator step per
    iteration; every 1000 steps logs losses and saves a sample image grid;
    every 10000 steps computes the inception score (and, when `args.fid` is
    set, the FID against `self.stats_path`) on ~50k generated samples.
    """
    self.sess.run(tf.global_variables_initializer())
    mean_list = []  # inception-score history, saved as .npy each eval
    fid_list = []   # FID history, saved as .npy each eval
    start_time = time.time()
    for t in range(0, num_batches):
        # discriminator updates
        for _ in range(0, self.d_iters):
            bx = self.x_sampler(self.batch_size)
            bz = self.z_sampler(self.batch_size, self.z_dim)
            self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz})
        # single generator update
        bx = self.x_sampler(self.batch_size)
        bz = self.z_sampler(self.batch_size, self.z_dim)
        self.sess.run([self.g_adam], feed_dict={self.z: bz, self.x: bx})
        if t % 1000 == 0:
            # log losses and save a grid of generated samples
            bx = self.x_sampler(self.batch_size)
            bz = self.z_sampler(self.batch_size, self.z_dim)
            dl, gl, gp, x_ = self.sess.run(
                [self.d_loss, self.g_loss, self.gp_loss, self.x_],
                feed_dict={
                    self.x: bx,
                    self.z: bz
                })
            print('Iter [%8d] Time [%.4f] dl [%.4f] gl [%.4f] gp [%.4f]' %
                  (t, time.time() - start_time, dl, gl, gp))
            x_ = self.x_sampler.data2img(x_)
            x_ = grid_transform(x_, self.x_sampler.shape)
            imsave(self.log_dir + '/wos/{}.png'.format(int(t)), x_)
        if t % 10000 == 0 and t > 0:
            # generate ~50000 samples for the inception score
            in_list = []
            for _ in range(int(50000 / self.batch_size)):
                bz = self.z_sampler(self.batch_size, self.z_dim)
                x_ = self.sess.run(self.x_, feed_dict={self.z: bz})
                x_ = self.x_sampler.data2img(x_)
                bx_list = np.split(x_, self.batch_size)
                in_list = in_list + [np.squeeze(x) for x in bx_list]
            mean, std = self.inception.get_inception_score(in_list, splits=10)
            mean_list.append(mean)
            np.save(self.log_dir + '/inception_score_wgan_gp.npy',
                    np.asarray(mean_list))
            print('inception score [%.4f]' % (mean))
        if t % 10000 == 0 and t > 0 and args.fid:
            # FID on the first 10k of the samples generated just above
            # (same trigger condition, so in_list is always populated here)
            f = np.load(self.stats_path)
            mu_real, sigma_real = f['mu'][:], f['sigma'][:]
            f.close()
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                np.array(in_list[:10000]), self.sess, batch_size=100)
            fid_value = fid.calculate_frechet_distance(
                mu_gen, sigma_gen, mu_real, sigma_real)
            print("FID: %s" % fid_value)
            fid_list.append(fid_value)
            np.save(self.log_dir + '/fid_score_wgan_gp.npy',
                    np.asarray(fid_list))
def calculate_stats(imageset, batch_size=default_batchsize, printTime=False):
    """Compute inception activation statistics (mu, sigma) for an image set.

    The inception network expects input of shape (n, w, h, 3); grayscale
    input of shape (n, w, h) or (n, w, h, 1) is expanded by replicating
    the single channel three times.
    """
    if len(imageset.shape) < 4:
        # (n, w, h) -> (n, w, h, 1)
        imageset = np.reshape(imageset, tuple(imageset.shape) + (1,))
    if imageset.shape[3] == 1:
        # (n, w, h, 1) -> (n, w, h, 3) by channel replication
        imageset = np.repeat(imageset, 3, axis=-1)
    starttime = time()
    with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics(imageset, sess,
                                                        batch_size=batch_size)
    if printTime:
        print("calculating rlts took %f seconds" % (time() - starttime))
    return (mu, sigma)
def load_fid(mnist_test_images, args, binarize=True):
    """Build a FID evaluator for MNIST-style images using a LeNet saved model.

    Loads the saved model from ~/lenet/savedmodel into its own graph/session,
    computes reference statistics (mu0, sig0) from `mnist_test_images`, and
    returns (compute, locals()) where compute(images) -> FID score.
    """
    import fid

    def transform_for_fid(im):
        # expects a flat float32 batch of shape (n, 784) — enforced by the assert
        assert len(im.shape) == 2 and im.dtype == np.float32
        if binarize:
            # stochastic binarization: pixel value used as Bernoulli probability
            im = (im > np.random.random(size=im.shape)).astype(np.float32)
        a = np.array(im) - 0.5  # center values around 0
        return a.reshape((-1, 28, 28, 1))

    inception_path = os.path.expanduser('~/lenet/savedmodel')
    # dedicated graph + session so the LeNet model does not pollute the default graph
    inception_graph = tf.Graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    inception_sess = tf.Session(config=config, graph=inception_graph)
    with inception_graph.as_default():
        tf.saved_model.loader.load(inception_sess,
                                   [tf.saved_model.tag_constants.TRAINING],
                                   inception_path)
    # reference statistics from the real test images
    mu0, sig0 = fid.calculate_activation_statistics(
        transform_for_fid(mnist_test_images),
        inception_sess,
        args.batch_size,
        verbose=True,
        model='lenet')

    def compute(images):
        # FID of `images` against the precomputed reference statistics
        m, s = fid.calculate_activation_statistics(transform_for_fid(images),
                                                   inception_sess,
                                                   args.batch_size,
                                                   verbose=True,
                                                   model='lenet')
        return fid.calculate_frechet_distance(m, s, mu0, sig0)

    return compute, locals()
def compute_fid(self, images, inception_path, batch_size=100):
    """Compute inception activation statistics for *images*.

    Despite the name, this returns the (mu, sigma) statistics pair, not a
    final FID value; the caller combines two such pairs into a distance.

    :param images: batch of images accepted by fid.calculate_activation_statistics
    :param inception_path: path to the inception model to load
    :param batch_size: inception forward-pass batch size
    :return: (mu, sigma) activation statistics
    """
    # Build the inception network in its own graph so repeated calls do not
    # keep adding nodes to the caller's default graph.
    g = tf.Graph()
    with g.as_default():
        fid.create_inception_graph(
            inception_path)  # load the graph into graph g
    # Bug fix: the session was previously closed manually, leaking it if
    # calculate_activation_statistics raised; the context manager guarantees
    # the session is closed on every path.
    with tf.Session(graph=g) as sess:
        mu, sigma = fid.calculate_activation_statistics(
            images, sess, batch_size=batch_size)
    return mu, sigma
def fid_ms_for_imgs(images, mem_fraction=0.5):
    """Return (mu, sigma) inception activation statistics for *images*,
    capping GPU memory use at *mem_fraction* of the device."""
    inception_path = fid.check_or_download_inception(None)
    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_fraction)
    session_config = tf.ConfigProto(gpu_options=gpu_options)
    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            images, sess, batch_size=100)
    return mu_gen, sigma_gen
def _get_statistics(stat_root, data, image_shape, inception_sess):
    """Return (mu, sigma) activation statistics, cached under *stat_root*.

    If both cache files exist they are loaded directly; otherwise the
    statistics are computed from *data* and written to the cache.
    """
    os.makedirs(stat_root, exist_ok=True)
    mu_path = os.path.join(stat_root, 'ac_mu.npy')
    sigma_path = os.path.join(stat_root, 'ac_sigma.npy')
    # fast path: both cache files present
    if os.path.exists(mu_path) and os.path.exists(sigma_path):
        print('Using cached activation statistics')
        return np.load(mu_path), np.load(sigma_path)
    # slow path: compute from data, mapping [-1, 1] values to [0, 255] pixels
    pixels = _maybe_grayscale_to_rgb(np.reshape(data, (-1, ) + image_shape))
    pixels = (pixels + 1.0) / 2.0 * 255.0
    mu, sigma = fid.calculate_activation_statistics(pixels, inception_sess)
    np.save(mu_path, mu)
    np.save(sigma_path, sigma)
    return mu, sigma
def main(model, data_source, noise_method, noise_factors, lambdas):
    """Precompute FID statistics for generated samples of a (model, dataset,
    noise-method) combination, one .npz per (lambda, noise-factor) pair.

    model: RVAE or VAE
    data_source: data set of training. Either 'MNIST' or 'FASHION'
    noise_method: method of adding noise. Either 'sp' (represents
        salt-and-pepper) or 'gs' (represents Gaussian)
    noise_factors: noise factors
    lambdas: lambda values (only used for RVAE paths)
    """
    input_path = "../output/"+model+"_"+data_source+"_"+noise_method+"/"
    inception_path = None
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(inception_path)  # download inception if necessary
    print("ok")
    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" " , flush=True)
    output_path = "fid_precalc/"
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    output_path = output_path+model+"_"+data_source+"_"+noise_method+"/"
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    for l in lambdas:
        for nr in noise_factors:
            # RVAE output is organized by lambda and noise; VAE only by noise
            if model == 'RVAE':
                data_path = input_path+'lambda_'+str(l)+'/noise_'+str(nr)+'/generation_fid.npy'
                output_name = 'fid_stats_lambda_'+str(l)+'noise_'+str(nr)
            else:
                data_path = input_path+str(nr)+'/generation_fid.npy'
                output_name = 'fid_stats_noise_'+str(nr)
            images = np.load(data_path)[:10000]
            # grayscale 28x28 -> 3-channel by stacking, scaled to [0, 255]
            images = np.stack((((images*255)).reshape(-1,28,28),)*3,axis=-1)
            print("create inception graph..", end=" ", flush=True)
            fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
            print("ok")
            print("calculte FID stats..", end=" ", flush=True)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
                np.savez_compressed(output_path+output_name, mu=mu, sigma=sigma)
    print("finished")
def compute_fid_score(self, generator, timestamp):
    """
    Computes FID of generator using fixed noise dataset; appends the
    current score to the list of computed scores; and overwrites the json
    file that logs the fid scores. If the new score beats every previous
    one, a checkpoint backup is written to self.best_path.

    :param generator: [nn.Module]
    :param timestamp: [int]
    :return: None
    """
    generator.eval()  # disable dropout/batchnorm updates while sampling
    fake_samples = np.empty(
        (self.sample_size_fid, self.imsize, self.imsize, 3))
    for j, noise in enumerate(self.fid_noise_loader):
        noise = noise.cuda()
        i1 = j * 200  # batch_size = 200
        i2 = i1 + noise.size(0)
        # map generator output from [-1, 1] to [0, 255]
        samples = generator(noise).cpu().data.add(1).mul(255 / 2.0)
        # NCHW -> NHWC for the inception network
        fake_samples[i1:i2] = samples.permute(0, 2, 3, 1).numpy()
    generator.train()  # restore training mode
    mu_g, sigma_g = fid.calculate_activation_statistics(fake_samples,
                                                        self.fid_session,
                                                        batch_size=100)
    fid_score = fid.calculate_frechet_distance(mu_g, sigma_g, self.mu_real,
                                               self.sigma_real)
    _result = {
        'entry': len(self.fid_scores),
        'iter': timestamp,
        'fid': fid_score
    }
    # if best update the checkpoint in self.best_path
    # (lower FID is better: new_best iff no previous score is lower)
    new_best = True
    for prev_fid in self.fid_scores:
        if prev_fid['fid'] < fid_score:
            new_best = False
            break
    if new_best:
        self.backup(timestamp, dir=self.best_path)
    self.fid_scores.append(_result)
    # rewrite the whole log file so it always reflects the full history
    with open(self.fid_json_file, 'w') as _f_fid:
        json.dump(self.fid_scores,
                  _f_fid,
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))
def generate(args):
    """Restore a trained model from args.dir, sample 10k images, report their
    FID against CIFAR-10 test statistics, and save the samples as PNGs.

    CIFAR-10 test-set statistics are computed and cached on first run.
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # one-time: build inception and cache the real-data FID statistics
    if not os.path.exists(CIFAR_STATS_PATH):
        print('Generating FID statistics for test set...')
        print('Building Inception graph')
        with tf.Session(config=config) as sess:
            inception_path = fid.check_or_download_inception(INCEPTION_PATH)
            fid.create_inception_graph(str(inception_path))
            ds = datasets.load_cifar10(True)
            # map images from [-1, 1] to [0, 256)
            all_test_set = (ds.test.images + 1) * 128
            print(all_test_set.shape)
            m, s = fid.calculate_activation_statistics(
                all_test_set, sess, args.batch_size, verbose=True)
            np.savez(CIFAR_STATS_PATH, mu=m, sigma=s)
            print('Done')
    root_dir = os.path.dirname(args.dir)
    # merge saved hyperparameters into args before rebuilding the graph
    args_json = json.load(open(os.path.join(root_dir, 'hps.txt')))
    ckpt_dir = args.dir
    vars(args).update(args_json)
    # rebuild the model in its own graph and restore the checkpoint
    model_graph = tf.Graph()
    with model_graph.as_default():
        x_ph, is_training_ph, model, optimizer, batch_size_sym, z_sample_sym, x_sample_sym = build_graph(args)
        saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=3,
                                         max_to_keep=6)
    model_sess = tf.Session(config=config, graph=model_graph)
    print('RESTORING MODEL FROM', ckpt_dir)
    saver.restore(model_sess, ckpt_dir)
    compute_fid, _ = load_fid(args)
    # sample 100 batches of 100 images
    images = []
    for j in range(100):
        x_samples = model_sess.run(x_sample_sym, {batch_size_sym: 100,
                                                  is_training_ph: False})
        # clip to [-1, 1] then map to [0, 256)
        x_samples = (np.clip(x_samples, -1, 1) + 1) / 2 * 256
        images.extend(x_samples)
    fscore = compute_fid(images)
    print('FID score = {}'.format(fscore))
    dest = os.path.join(root_dir, 'generated')
    if not os.path.exists(dest):
        os.makedirs(dest)
    for j, im in enumerate(images):
        plt.imsave(os.path.join(dest, '{}.png'.format(j)), im/256)
def main(model, noise_factors, lambdas):
    """Precompute FID statistics for generated samples stored under *model*,
    one .npz per (lambda, noise-factor) pair.

    model: RVAE or VAE (used here as the input directory prefix)
    noise_factors: noise factors
    lambdas: lambda values
    """
    input_path = model
    inception_path = None
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(
        inception_path)  # download inception if necessary
    print("ok")
    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)
    output_path = "fid_precalc/"
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    for l in lambdas:
        for nr in noise_factors:
            data_path = input_path + 'lambda_' + str(l) + '/noise_' + str(
                nr) + '/generation_fid.npy'
            output_name = 'fid_stats_lambda_' + str(l) + 'noise_' + str(nr)
            images = np.load(data_path)
            # NCHW -> NHWC, scaled to [0, 255]
            images = np.transpose(images * 255, (0, 2, 3, 1))
            #images = np.stack((((images*255)).reshape(-1,28,28),)*3,axis=-1)
            print("create inception graph..", end=" ", flush=True)
            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            print("ok")
            print("calculte FID stats..", end=" ", flush=True)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu, sigma = fid.calculate_activation_statistics(images, sess,
                                                                batch_size=100)
                np.savez_compressed(output_path + output_name, mu=mu, sigma=sigma)
    print("finished")
def main(args):
    """Precompute FID activation statistics for a video dataset using a
    pretrained ResNext101 (Kinetics) feature extractor and save them as
    ./stats/<dataset>.npz."""
    device = 'cuda'
    print('Loading ResNext101 model...')
    net = resnet101(sample_duration=16).cuda()
    net = nn.DataParallel(net)
    checkpoint = torch.load('resnext-101-kinetics.pth')
    net.load_state_dict(checkpoint['state_dict'])
    print('Loading video paths...')
    if args.dataset != 'uva':
        # only the 'uva' video dataset is supported
        raise NotImplementedError
    video_paths = glob.glob(args.data_path + '/*.mp4')
    modality = 'video'
    mu, sigma = fid.calculate_activation_statistics(
        video_paths, modality, net, args.batch_size, args.size,
        args.length, args.dims, device)
    np.savez_compressed('./stats/'+args.dataset+'.npz', mu=mu, sigma=sigma)
    print('finished')
def load_fid(dtest, args):
    """Build a FID evaluator against the test set *dtest* using the inception
    network; reference statistics are cached per dataset under INCEPTION_PATH.

    Returns (compute, locals()) where compute(images) -> FID score.
    """
    import fid

    def transform_for_fid(im):
        # expects an NHWC float32 batch — enforced by the assert
        assert len(im.shape) == 4 and im.dtype == np.float32
        if im.shape[-1] == 1:
            # grayscale 28x28 -> RGB by channel tiling
            assert im.shape[-2] == 28
            im = np.tile(im, [1, 1, 1, 3])
        # sanity check: values are expected to lie in [-1, 1]
        if not (im.std() < 1. and im.min() > -1.):
            print('WARNING: abnormal image range', im.std(), im.min())
        return (im + 1) * 128  # map [-1, 1] -> [0, 256)

    inception_path = fid.check_or_download_inception(INCEPTION_PATH)
    # dedicated graph + session so inception does not pollute the default graph
    inception_graph = tf.Graph()
    with inception_graph.as_default():
        fid.create_inception_graph(str(inception_path))
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    inception_sess = tf.Session(config=config, graph=inception_graph)
    # reference statistics: compute once per dataset, then reuse the cache
    stats_path = os.path.join(INCEPTION_PATH, f'{args.dataset}-stats.npz')
    if not os.path.exists(stats_path):
        mu0, sig0 = fid.calculate_activation_statistics(
            transform_for_fid(dtest), inception_sess, args.batch_size,
            verbose=True)
        np.savez(stats_path, mu0=mu0, sig0=sig0)
    else:
        sdict = np.load(stats_path)
        mu0, sig0 = sdict['mu0'], sdict['sig0']

    def compute(images):
        # FID of `images` against the cached reference statistics
        m, s = fid.calculate_activation_statistics(transform_for_fid(images),
                                                   inception_sess,
                                                   args.batch_size,
                                                   verbose=True)
        return fid.calculate_frechet_distance(m, s, mu0, sig0)

    return compute, locals()
def _run_fid_calculation(sess, inception_sess, placeholders, batch_size,
                         iteration, generator, mu, sigma, epoch, image_shape,
                         z_input_shape, y_input_shape):
    """Average the FID over *iteration* freshly generated sample batches.

    :param sess: session of the generator graph
    :param inception_sess: session holding the inception network
    :param placeholders: dict with 'z', 'y', 'mode' feed placeholders
    :param batch_size: samples per generated batch
    :param iteration: number of batches to average over
    :param generator: generator output tensor (flattened images)
    :param mu, sigma: reference activation statistics of the real data
    :param epoch: unused; kept for interface compatibility with callers
    :param image_shape: target (h, w, c) shape of one generated image
    :param z_input_shape, y_input_shape: noise / label input shapes
    :return: mean FID across the batches
    """
    # Bug fix: build the reshape op ONCE. Previously tf.reshape was called
    # inside the loop (and on every invocation of this function), adding a new
    # node to the graph each time and bloating it over the course of training.
    generated_images = tf.reshape(generator, (-1, ) + image_shape)
    f = 0.0
    for _ in range(iteration):
        z = util.gen_random_noise(batch_size, z_input_shape)
        y = util.gen_random_label(batch_size, y_input_shape[0])
        images = sess.run(
            generated_images, {
                placeholders['z']: z,
                placeholders['y']: y,
                placeholders['mode']: False,
            })
        images = _maybe_grayscale_to_rgb(images)
        # map generator output from [-1, 1] to [0, 255]
        images = (images + 1.0) / 2.0 * 255.0
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            images, inception_sess)
        f += fid.calculate_frechet_distance(mu, sigma, mu_gen, sigma_gen)
    return f / iteration
def precalc(data_path, output_path): print("CALCULATING THE GT STATS....") # data_path = 'reconstructed_test/eval' # set path to training set images # output_path = data_path+'/fid_stats.npz' # path for where to store the statistics # if you have downloaded and extracted # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz # set this path to the directory where the extracted files are, otherwise # just set it to None and the script will later download the files for you inception_path = None print("check for inception model..", end=" ", flush=True) inception_path = fid.check_or_download_inception( inception_path) # download inception if necessary print("ok") # loads all images into memory (this might require a lot of RAM!) print("load images..", end=" ", flush=True) image_list = glob.glob(os.path.join(data_path, '*.jpg')) if len(image_list) == 0: print("No images in directory ", data_path) return images = np.array([ imageio.imread(str(fn), as_gray=False, pilmode="RGB").astype(np.float32) for fn in image_list ]) print("%d images found and loaded" % len(images)) print("create inception graph..", end=" ", flush=True) fid.create_inception_graph( inception_path) # load the graph into the current TF graph print("ok") print("calculte FID stats..", end=" ", flush=True) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) mu, sigma, acts = fid.calculate_activation_statistics( images, sess, batch_size=BATCH_SIZE) np.savez_compressed(output_path, mu=mu, sigma=sigma, activations=acts) print("finished")
t += 1 elapsed_time = datetime.datetime.now() - start_time print(str(t), "/", test_len, ": ", blender_filename, pc2pix_filename, "Elapsed :", elapsed_time) print(np.array(gt).shape) gt = np.array(gt) bl = np.array(bl) pc = np.array(pc) fid.create_inception_graph( inception_path) # load the graph into the current TF graph with tf.Session() as sess: sess.run(tf.global_variables_initializer()) mu_gt, sigma_gt = fid.calculate_activation_statistics(gt, sess) mu_bl, sigma_bl = fid.calculate_activation_statistics(bl, sess) mu_pc, sigma_pc = fid.calculate_activation_statistics(pc, sess) fid_value = fid.calculate_frechet_distance(mu_bl, sigma_bl, mu_gt, sigma_gt) filename = "fid.log" fd = open(filename, "a+") fd.write("---| ") fd.write(args.split_file) fd.write(" |---\n") print("Surface FID: %s" % fid_value) fd.write("Surface FID: %s\n" % fid_value) fid_value = fid.calculate_frechet_distance(mu_pc, sigma_pc, mu_gt, sigma_gt) print("PC2PIX FID: %s" % fid_value)
# '%s/Epoch_(%d)_(%dof%d)_img_rec.png' % (save_dir, ep, it_in_epoch, it_per_epoch)) im.imwrite( im.immerge(img_intp_opt_sample, n_col=1, padding=0), '%s/Epoch_(%d)_(%dof%d)_img_intp.png' % (save_dir, ep, it_in_epoch, it_per_epoch)) im.imwrite( im.immerge(img_opt_sample), '%s/Epoch_(%d)_(%dof%d)_img_sample.png' % (save_dir, ep, it_in_epoch, it_per_epoch)) if fid_stats_path: try: mu_gen, sigma_gen = fid.calculate_activation_statistics( im.im2uint( np.concatenate([ sess.run(fid_sample).squeeze() for _ in range(5) ], 0)), sess, batch_size=100) fid_value = fid.calculate_frechet_distance( mu_gen, sigma_gen, mu_real, sigma_real) except: fid_value = -1. fid_summary = tf.Summary() fid_summary.value.add(tag='FID', simple_value=fid_value) summary_writer.add_summary(fid_summary, it) print("FID: %s" % fid_value) save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep)) print('Model is saved in file: %s' % save_path) except:
image_list = h5py.File(os.path.join(data_path, 'test_x.h5'), 'r', swmr=True)['x'] # print(image_list[10]) # exit(0) # images = np.array([files[index, ...].astype(np.float32) for index in range(len(files))]) output_path = 'fid_stats_pcam.npz' # path for where to store the statistics else: print("Unsupported dataset") sys.exit(1) print("%d images found" % len(image_list)) # These are not fully read into memory though print("create inception graph..", end=" ", flush=True) fid.create_inception_graph( inception_path) # load the graph into the current TF graph print("ok") print("calculte FID stats..", end=" ", flush=True) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) if config.dataset == "CelebA": mu, sigma = fid.calculate_activation_statistics(image_list, sess, batch_size=100) elif config.dataset == "PCam": mu, sigma = fid.calculate_activation_statistics(image_list, sess, batch_size=100) np.savez_compressed(output_path, mu=mu, sigma=sigma) print("finished")
end="", flush=True) frm = i * FID_SAMPLE_BATCH_SIZE to = frm + FID_SAMPLE_BATCH_SIZE samples[frm:to] = session.run(Generator(FID_SAMPLE_BATCH_SIZE)) # Cast, reshape and transpose (BCHW -> BHWC) samples = ((samples + 1.0) * 127.5).astype('uint8') samples = samples.reshape(FID_EVAL_SIZE, 3, DIM, DIM) samples = samples.transpose(0, 2, 3, 1) print("ok") mu_gen, sigma_gen = fid.calculate_activation_statistics( samples, session, batch_size=FID_BATCH_SIZE, verbose=True) print("calculate FID:", end=" ", flush=True) try: FID = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real) except Exception as e: print(e) FID = 500 print(FID) session.run(tf.assign(fid_tfvar, FID)) summary_str = session.run(fid_sum) writer.add_summary(summary_str, iteration)
def calculate_fid(self):
    """Generate ~50k samples with annealed Langevin dynamics from saved
    checkpoints and report their FID against precalculated CIFAR-10
    training-set statistics."""
    import fid
    import tensorflow as tf
    num_of_step = 500  # number of sampling batches (500 * bs samples total)
    bs = 100           # samples per batch
    # geometric noise schedule from sigma_begin down to sigma_end
    sigmas = np.exp(
        np.linspace(np.log(self.config.model.sigma_begin),
                    np.log(self.config.model.sigma_end),
                    self.config.model.num_classes))
    stats_path = 'fid_stats_cifar10_train.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(
        None)  # download inception network
    print('Load checkpoint from' + self.args.log)
    #for epochs in range(140000, 200001, 1000):
    for epochs in [149000]:
        states = torch.load(os.path.join(
            self.args.log, 'checkpoint_' + str(epochs) + '.pth'),
                            map_location=self.config.device)
        #states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
        score = CondRefineNetDilated(self.config).to(self.config.device)
        score = torch.nn.DataParallel(score)
        score.load_state_dict(states[0])
        score.eval()  # inference only
        if self.config.data.dataset == 'MNIST':
            print("Begin epochs", epochs)
            # NOTE(review): first batch uses 1x28x28 noise but all subsequent
            # batches use 3x32x32 — verify this is intended for MNIST
            samples = torch.rand(bs, 1, 28, 28, device=self.config.device)
            all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                samples, score, sigmas, 100, 0.00002)
            # scale to [0, 255], NCHW -> NHWC, move to host memory
            images = all_samples.mul_(255).add_(0.5).clamp_(
                0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
            for j in range(num_of_step - 1):
                samples = torch.rand(bs, 3, 32, 32,
                                     device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images_new = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                images = np.concatenate((images, images_new), axis=0)
        else:
            print("Begin epochs", epochs)
            samples = torch.rand(bs, 3, 32, 32, device=self.config.device)
            all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                samples, score, sigmas, 100, 0.00002)
            images = all_samples.mul_(255).add_(0.5).clamp_(
                0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
            for j in range(num_of_step - 1):
                samples = torch.rand(bs, 3, 32, 32,
                                     device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images_new = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                images = np.concatenate((images, images_new), axis=0)
        # load precalculated training set statistics
        f = np.load(stats_path)
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]
        f.close()
        fid.create_inception_graph(
            inception_path)  # load the graph into the current TF graph
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                images, sess, batch_size=100)
            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                       mu_real, sigma_real)
            print("FID: %s" % fid_value)
# loads all images into memory (this might require a lot of RAM!) print("load images..", end=" ", flush=True) if args.images_png_path is not None: image_list = [] for path in args.images_png_path: image_list.extend(glob.glob(os.path.join(args.data_path, '*.png'))) images = np.array( [imread(str(fn)).astype(np.float32) for fn in image_list]) elif args.images_npy_path is not None: images = [] for path in args.images_npy_path: with open(path, 'rb') as f: images.append(load_nist_images(np.load(f))) images = np.vstack(images) print(images.shape) print("%d images found and loaded" % len(images)) print("create inception graph..", end=" ", flush=True) fid.create_inception_graph( inception_path) # load the graph into the current TF graph print("ok") print("calculte FID stats..", end=" ", flush=True) with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=256) np.savez_compressed(args.output_path, mu=mu, sigma=sigma) print("finished")
#stats_path = '/home/minje/dev/dataset/stl/fid_stats_stl10.npz' # training set statistics (maybe pre-calculated) inception_path = fid.check_or_download_inception(None) # download inception network # precalculate training set statistics # #image_files = glob.glob(os.path.join('/home/minje/dev/dataset/cifar/cifar-10-images', '*.jpg')) # image_files = glob.glob(os.path.join('/home/minje/dev/dataset/stl/images', '*.jpg')) # fid.create_inception_graph(inception_path) # with tf.Session() as sess: # sess.run(tf.global_variables_initializer()) # mu_real, sigma_real = fid.calculate_activation_statistics_from_files(image_files, sess, # batch_size=100, verbose=True) # np.savez(stats_path, mu=mu_real, sigma=sigma_real) # exit(0) # loads all images into memory (this might require a lot of RAM!) image_files = glob.glob(os.path.join(image_path, '*.jpg')) images = np.array([imread(str(fn)).astype(np.float32) for fn in image_files]) # load precalculated training set statistics f = np.load(stats_path) mu_real, sigma_real = f['mu'][:], f['sigma'][:] f.close() fid.create_inception_graph(inception_path) # load the graph into the current TF graph with tf.Session() as sess: sess.run(tf.global_variables_initializer()) mu_gen, sigma_gen = fid.calculate_activation_statistics(images, sess, batch_size=100, verbose=True) fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real) print("FID: %s" % fid_value)
def train(self, config): """Train DCGAN""" print("load train stats.. ", end="", flush=True) # load precalculated training set statistics f = np.load(self.stats_path) mu_real, sigma_real = f['mu'][:], f['sigma'][:] f.close() print("ok") if config.dataset == 'mnist': print("scan files", end=" ", flush=True) data_X, data_y = self.load_mnist() else: if (config.dataset == "celebA") or (config.dataset == "cifar10"): print("scan files", end=" ", flush=True) data = glob( os.path.join(self.data_path, self.input_fname_pattern)) else: if config.dataset == "lsun": print("scan files") data = [] for i in range(304): print("\r%d" % i, end="", flush=True) data += glob( os.path.join(self.data_path, str(i), self.input_fname_pattern)) else: print( "Please specify dataset in run.sh [mnist, celebA, lsun, cifar10]" ) raise SystemExit() print() print("%d images found" % len(data)) # Z sample #sample_z = np.random.normal(0, 1.0, size=(self.sample_num , self.z_dim)) sample_z = np.random.uniform(-1.0, 1.0, size=(self.sample_num, self.z_dim)) # Input samples sample_files = data[0:self.sample_num] sample = [ get_image(sample_file, input_height=self.input_height, input_width=self.input_width, resize_height=self.output_height, resize_width=self.output_width, is_crop=self.is_crop, is_grayscale=self.is_grayscale) for sample_file in sample_files ] if (self.is_grayscale): sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None] else: sample_inputs = np.array(sample).astype(np.float32) if self.load_checkpoint: if self.load(self.checkpoint_dir): print(" [*] Load SUCCESS") else: print(" [!] 
Load failed...") # Batch preparing batch_nums = min(len(data), config.train_size) // config.batch_size data_idx = list(range(len(data))) counter = self.counter_start start_time = time.time() # Loop over epochs for epoch in range(config.epoch): # Assign learning rates for d and g lrate = config.learning_rate_d # * (config.lr_decay_rate_d ** epoch) self.sess.run(tf.assign(self.learning_rate_d, lrate)) lrate = config.learning_rate_g # * (config.lr_decay_rate_g ** epoch) self.sess.run(tf.assign(self.learning_rate_g, lrate)) # Shuffle the data indices np.random.shuffle(data_idx) # Loop over batches for batch_idx in range(batch_nums): # Prepare batch idx = data_idx[batch_idx * config.batch_size:(batch_idx + 1) * config.batch_size] batch = [ get_image(data[i], input_height=self.input_height, input_width=self.input_width, resize_height=self.output_height, resize_width=self.output_width, is_crop=self.is_crop, is_grayscale=self.is_grayscale) for i in idx ] if (self.is_grayscale): batch_images = np.array(batch).astype(np.float32)[:, :, :, None] else: batch_images = np.array(batch).astype(np.float32) #batch_z = np.random.normal(0, 1.0, size=(config.batch_size , self.z_dim)).astype(np.float32) batch_z = np.random.uniform( -1.0, 1.0, size=(config.batch_size, self.z_dim)).astype(np.float32) # Update D network _, summary_str = self.sess.run([self.d_optim, self.d_sum], feed_dict={ self.inputs: batch_images, self.z: batch_z }) if np.mod(counter, 20) == 0: self.writer.add_summary(summary_str, counter) # Update G network _, summary_str = self.sess.run([self.g_optim, self.g_sum], feed_dict={self.z: batch_z}) if np.mod(counter, 20) == 0: self.writer.add_summary(summary_str, counter) errD_fake = self.d_loss_fake.eval({self.z: batch_z}) errD_real = self.d_loss_real.eval({self.inputs: batch_images}) errG = self.g_loss.eval({self.z: batch_z}) # Print if np.mod(counter, 100) == 0: print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ % (epoch, batch_idx, batch_nums, 
time.time() - start_time, errD_fake+errD_real, errG)) # Save generated samples and FID if np.mod(counter, config.fid_eval_steps) == 0: # Save try: samples, d_loss, g_loss = self.sess.run( [self.sampler, self.d_loss, self.g_loss], feed_dict={ self.z: sample_z, self.inputs: sample_inputs }) save_images( samples, [8, 8], '{}/train_{:02d}_{:04d}.png'.format( config.sample_dir, epoch, batch_idx)) print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) except Exception as e: print(e) print("sample image error!") # FID print("samples for incept", end="", flush=True) samples = np.zeros((self.fid_n_samples, self.output_height, self.output_width, 3)) n_batches = self.fid_n_samples // self.fid_sample_batchsize lo = 0 for btch in range(n_batches): print("\rsamples for incept %d/%d" % (btch + 1, n_batches), end=" ", flush=True) #sample_z_fid = np.random.normal(0, 1.0, size=(self.fid_sample_batchsize, self.z_dim)) sample_z_fid = np.random.uniform( -1.0, 1.0, size=(self.fid_sample_batchsize, self.z_dim)) samples[lo:( lo + self.fid_sample_batchsize)] = self.sess.run( self.sampler_fid, feed_dict={self.z_fid: sample_z_fid}) lo += self.fid_sample_batchsize samples = (samples + 1.) * 127.5 print("ok") mu_gen, sigma_gen = fid.calculate_activation_statistics( samples, self.sess, batch_size=self.fid_batch_size, verbose=self.fid_verbose) print("calculate FID:", end=" ", flush=True) try: FID = fid.calculate_frechet_distance( mu_gen, sigma_gen, mu_real, sigma_real) except Exception as e: print(e) FID = 500 print(FID) # Update event log with FID self.sess.run(tf.assign(self.fid, FID)) summary_str = self.sess.run(self.fid_sum) self.writer.add_summary(summary_str, counter) # Save checkpoint if (counter != 0) and (np.mod(counter, 2000) == 0): self.save(config.checkpoint_dir, counter) counter += 1
print("Check for inception model..", end=" ", flush=True) inception_path = fid.check_or_download_inception( inception_path) # download inception if necessary print("OK") # loads all images into memory (this might require a lot of RAM!) print("Load images..", end=" ", flush=True) (x_train, _), (x_test, _) = tf.keras.datasets.cifar10.load_data() x_train = x_train.astype(np.float32) x_test = x_test.astype(np.float32) print("%d/%d images found and loaded" % (len(x_train), len(x_test))) print("Create inception graph..", end=" ", flush=True) fid.create_inception_graph( inception_path) # load the graph into the current TF graph print("OK") print("Calculte FID stats..", end=" ", flush=True) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) mu_train, sigma_train = fid.calculate_activation_statistics(x_train, sess, batch_size=100) mu_test, sigma_test = fid.calculate_activation_statistics(x_test, sess, batch_size=100) np.savez_compressed(out_path_train, mu=mu_train, sigma=sigma_train) np.savez_compressed(out_path_test, mu=mu_test, sigma=sigma_test) print("Finished")
if args.mode == "pre-calculate": print("load images..") image_list = glob.glob(os.path.join(args.image_path, '*.jpg')) images = np.array( [imread(image).astype(np.float32) for image in image_list]) print("%d images found and loaded" % len(images)) print("create inception graph..", end=" ", flush=True) fid.create_inception_graph(inception_path) print("ok") print("calculate FID stats..", end=" ", flush=True) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100) np.savez_compressed(args.stats_path, mu=mu, sigma=sigma) print("finished") else: image_list = glob.glob(os.path.join(args.image_path, '*.jpg')) images = np.array( [imread(str(fn)).astype(np.float32) for fn in image_list]) f = np.load(args.stats_path) mu_real, sigma_real = f['mu'][:], f['sigma'][:] f.close() fid.create_inception_graph(inception_path) with tf.Session() as sess: sess.run(tf.global_variables_initializer())
def train(self, config):
    """Train DCGAN.

    Runs the alternating D/G optimization loop for `config.epoch` epochs,
    periodically logging TensorBoard summaries, saving sample image grids,
    evaluating FID against precalculated statistics from `self.stats_path`,
    and checkpointing the model. `config` supplies batch_size, train_size,
    learning rates, fid_eval_steps, sample_dir and checkpoint_dir.
    """
    assert len(self.paths) > 0, 'no data loaded, was model not built?'
    print("load train stats.. ", end="", flush=True)
    # load precalculated training set statistics
    f = np.load(self.stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()
    print("ok")

    # Optionally resume from a saved checkpoint.
    if self.load_checkpoint:
        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

    # Batch preparing
    batch_nums = min(len(self.paths), config.train_size) // config.batch_size
    counter = self.counter_start

    # Running sums for the losses/diagnostics printed (and reset) every
    # 100 steps below.
    errD_fake = 0.
    errD_real = 0.
    errG = 0.
    errG_count = 0
    penD_gradient = 0.
    penD_lipschitz = 0.
    esti_slope = 0.
    lipschitz_estimate = 0.
    start_time = time.time()

    try:
        # Loop over epochs
        for epoch in range(config.epoch):
            # Assign learning rates for d and g
            lrate = config.learning_rate_d  # * (config.lr_decay_rate_d ** epoch)
            self.sess.run(tf.assign(self.learning_rate_d, lrate))
            lrate = config.learning_rate_g  # * (config.lr_decay_rate_g ** epoch)
            self.sess.run(tf.assign(self.learning_rate_g, lrate))

            # Loop over batches
            for batch_idx in range(batch_nums):
                # Update D network
                _, errD_fake_, errD_real_, summary_str, penD_gradient_, penD_lipschitz_, esti_slope_, lipschitz_estimate_ = self.sess.run(
                    [self.d_optim, self.d_loss_fake, self.d_loss_real,
                     self.d_sum, self.d_gradient_penalty_loss,
                     self.d_lipschitz_penalty_loss, self.d_mean_slope_target,
                     self.d_lipschitz_estimate])
                # Extra discriminator steps; only the first step's summary
                # and diagnostics are kept.
                for i in range(self.num_discriminator_updates - 1):
                    self.sess.run([self.d_optim, self.d_loss_fake,
                                   self.d_loss_real, self.d_sum,
                                   self.d_gradient_penalty_loss,
                                   self.d_lipschitz_penalty_loss])
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                # Update G network
                if config.learning_rate_g > 0.:  # and (np.mod(counter, 100) == 0 or lipschitz_estimate_ > 1 / (20 * self.lipschitz_penalty)):
                    _, errG_, summary_str = self.sess.run([self.g_optim, self.g_loss, self.g_sum])
                    if np.mod(counter, 20) == 0:
                        self.writer.add_summary(summary_str, counter)
                    errG += errG_
                    errG_count += 1

                # Accumulate per-step diagnostics for the periodic report.
                errD_fake += errD_fake_
                errD_real += errD_real_
                penD_gradient += penD_gradient_
                penD_lipschitz += penD_lipschitz_
                esti_slope += esti_slope_
                lipschitz_estimate += lipschitz_estimate_

                # Print
                # Averages divide by 100 because the sums cover the last
                # 100 steps (reset below after printing).
                if np.mod(counter, 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, lip_pen: %.8f, gradient_pen: %.8f, g_loss: %.8f, d_tgt_slope: %.6f, d_avg_lip: %.6f, g_updates: %3d" \
                          % (epoch, batch_idx, batch_nums,
                             time.time() - start_time,
                             (errD_fake+errD_real) / 100.,
                             penD_lipschitz / 100.,
                             penD_gradient / 100.,
                             errG / 100.,
                             esti_slope / 100.,
                             lipschitz_estimate / 100.,
                             errG_count))
                    errD_fake = 0.
                    errD_real = 0.
                    errG = 0.
                    errG_count = 0
                    penD_gradient = 0.
                    penD_lipschitz = 0.
                    esti_slope = 0.
                    lipschitz_estimate = 0.

                # Save generated samples and FID
                if np.mod(counter, config.fid_eval_steps) == 0:
                    # Save
                    # Best effort: a failed sample render must not abort
                    # training, so errors are reported and swallowed.
                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss])
                        save_images(samples, [8, 8],
                                    '{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, batch_idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
                    except Exception as e:
                        print(e)
                        print("sample image error!")

                    # FID
                    # Generate fid_n_samples images in fixed-size batches
                    # into a preallocated buffer.
                    print("samples for incept", end="", flush=True)
                    samples = np.zeros((self.fid_n_samples, self.output_height, self.output_width, 3))
                    n_batches = self.fid_n_samples // self.fid_sample_batchsize
                    lo = 0
                    for btch in range(n_batches):
                        print("\rsamples for incept %d/%d" % (btch + 1, n_batches), end=" ", flush=True)
                        samples[lo:(lo+self.fid_sample_batchsize)] = self.sess.run(self.sampler_fid)
                        lo += self.fid_sample_batchsize
                    # Rescale generator output from [-1, 1] to [0, 255].
                    samples = (samples + 1.) * 127.5
                    print("ok")
                    mu_gen, sigma_gen = fid.calculate_activation_statistics(samples,
                                                                            self.sess,
                                                                            batch_size=self.fid_batch_size,
                                                                            verbose=self.fid_verbose)
                    print("calculate FID:", end=" ", flush=True)
                    try:
                        FID = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
                    except Exception as e:
                        # Fall back to a large sentinel value so the event
                        # log still records something on numerical failure.
                        print(e)
                        FID=500
                    print(FID)
                    # Update event log with FID
                    self.sess.run(tf.assign(self.fid, FID))
                    summary_str = self.sess.run(self.fid_sum)
                    self.writer.add_summary(summary_str, counter)

                # Save checkpoint
                if (counter != 0) and (np.mod(counter, 2000) == 0):
                    self.save(config.checkpoint_dir, counter)

                counter += 1

    except KeyboardInterrupt as e:
        # Manual stop: persist progress before exiting.
        self.save(config.checkpoint_dir, counter)
    except Exception as e:
        print(e)
    finally:
        # When done, ask the threads to stop.
        self.coord.request_stop()
        self.coord.join(self.threads)
def compute(images):
    """Return the FID between *images* and the reference stats (mu0, sig0).

    Activation statistics are computed with the globally shared
    `inception_sess` using `args.batch_size` images per forward pass.
    """
    batch = np.array(images)
    mu_gen, sigma_gen = fid.calculate_activation_statistics(
        batch,
        inception_sess,
        args.batch_size,
        verbose=True,
    )
    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu0, sig0)
    return fid_value
def calculate_fid(self):
    """Compute FID for a range of checkpoints and pickle the results.

    For every 5000th checkpoint in [fast_fid.begin_ckpt, fast_fid.end_ckpt],
    loads the score network (with optional EMA weights), draws
    fast_fid.num_samples images via annealed Langevin dynamics, saves them
    as PNGs, computes FID against precalculated CIFAR-10 train statistics,
    and writes the {checkpoint: fid} mapping to fids.pickle.
    """
    import fid
    import pickle
    import tensorflow as tf

    stats_path = "fid_stats_cifar10_train.npz"  # training set statistics
    inception_path = fid.check_or_download_inception(
        "./tmp/"
    )  # download inception network

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    fids = {}
    for ckpt in tqdm.tqdm(
        range(
            self.config.fast_fid.begin_ckpt, self.config.fast_fid.end_ckpt + 1, 5000
        ),
        desc="processing ckpt",
    ):
        states = torch.load(
            os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
            map_location=self.config.device,
        )

        # Prefer EMA weights when enabled; otherwise use the raw state dict.
        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        num_iters = (
            self.config.fast_fid.num_samples // self.config.fast_fid.batch_size
        )
        output_path = os.path.join(self.args.image_folder, "ckpt_{}".format(ckpt))
        os.makedirs(output_path, exist_ok=True)
        for i in range(num_iters):
            # Uniform noise as the Langevin starting point.
            init_samples = torch.rand(
                self.config.fast_fid.batch_size,
                self.config.data.channels,
                self.config.data.image_size,
                self.config.data.image_size,
                device=self.config.device,
            )
            init_samples = data_transform(self.config, init_samples)
            all_samples = anneal_Langevin_dynamics(
                init_samples,
                score,
                sigmas,
                self.config.fast_fid.n_steps_each,
                self.config.fast_fid.step_lr,
                verbose=self.config.fast_fid.verbose,
            )
            final_samples = all_samples[-1]
            for id, sample in enumerate(final_samples):
                sample = sample.view(
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                )
                sample = inverse_data_transform(self.config, sample)
                save_image(
                    sample, os.path.join(output_path, "sample_{}.png".format(id))
                )

        # load precalculated training set statistics
        f = np.load(stats_path)
        mu_real, sigma_real = f["mu"][:], f["sigma"][:]
        f.close()

        fid.create_inception_graph(
            inception_path
        )  # load the graph into the current TF graph

        # Min-max normalize to [0, 255] for the Inception network.
        # NOTE(review): `.data.cpu().numpy()` is applied only to the
        # denominator, so `final_samples` stays a torch tensor here and is
        # only converted implicitly downstream — confirm this matches the
        # intended pipeline.
        final_samples = (
            (final_samples - final_samples.min())
            / (final_samples.max() - final_samples.min()).data.cpu().numpy()
            * 255
        )
        # NCHW -> NHWC, as expected by the Inception graph.
        final_samples = np.transpose(final_samples, [0, 2, 3, 1])

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                final_samples, sess, batch_size=100
            )
            fid_value = fid.calculate_frechet_distance(
                mu_gen, sigma_gen, mu_real, sigma_real
            )
            print("FID: %s" % fid_value)
            # Fix: record the result — previously `fids` was never
            # populated, so fids.pickle was always an empty dict.
            fids[ckpt] = fid_value

    with open(os.path.join(self.args.image_folder, "fids.pickle"), "wb") as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)