def test(epoch, netG, test_data, real_dir='/home/lixiang/dataset/photosketch/sketch/test'):
    """Translate every test image with ``netG`` and log the FID for this epoch.

    Each output is cropped to rows 3..252 and columns 28..227. It is saved to
    ``<opt.output>/<epoch>/<basename of A_path[0]>``. The FID between that
    folder and ``real_dir`` is then appended to ``opt.FID`` as
    ``"<epoch>:<fid>\\n"``.

    Args:
        epoch: Epoch number, used for the output sub-folder and the log line.
        netG: Generator applied to each ``image['A']`` batch.
        test_data: Sized iterable of dicts with keys ``'A'`` (input tensor) and
            ``'A_path'`` (sequence of source paths; only the first is used).
        real_dir: Reference image folder for the FID. Defaults to the
            previously hard-coded sketch test set.
    """
    import torch

    out_dir = '%s/%d' % (opt.output, epoch)
    os.makedirs(out_dir, exist_ok=True)
    print('test_data len = %d' % len(test_data))

    # Inference only: no autograd graph is needed to save images.
    with torch.no_grad():
        for image in test_data:
            real_A = image['A'].cuda()
            fake_B = netG(real_A)
            name = image['A_path'][0].split('/')
            # NOTE(review): the fixed crop presumably removes padding added by the
            # data loader - confirm against the dataset transform.
            vutils.save_image(fake_B[:, :, 3:253, 28:228],
                              '%s/%s' % (out_dir, name[-1]),
                              normalize=True,
                              scale_each=True)

    fid_value = calculate_fid_given_paths(out_dir, real_dir, 8, opt.gpuid != '', 2048)
    print(str(epoch) + " saved")
    # Open the log only when there is a result; `with` closes it even on error.
    with open(opt.FID, 'a') as f:
        f.write(str(epoch) + ':' + str(fid_value) + '\n')
    print('calculate FID:%.4f' % fid_value)
def evaluate_model_samples(args, simulator, x_gen):
    """Score generated samples against the true simulator and save the results.

    Up to three metrics are computed. Each is stored with ``np.save`` under
    ``create_filename("results", <tag>, args)``:

    * ``samples_likelihood``: true log density of ``x_gen``. Skipped when the
      simulator raises ``IntractableLikelihoodError``.
    * ``samples_manifold_distance``: distance of ``x_gen`` from the manifold.
      Skipped on ``NotImplementedError``.
    * ``samples_fid``: FID against the test split. Only for image simulators,
      and only when an FID implementation was found.
    """
    logger.info("Calculating likelihood of generated samples")
    try:
        if simulator.parameter_dim() is not None:
            # Every sample is scored under the same (default) true parameters.
            default_params = simulator.default_parameters(true_param_id=args.trueparam)
            stacked_params = np.asarray([default_params] * args.generate)
            log_likelihood_gen = simulator.log_density(x_gen, parameters=stacked_params)
        else:
            log_likelihood_gen = simulator.log_density(x_gen)
        # NOTE(review): NaNs become -1.0e-12, i.e. a log-likelihood of almost zero.
        # If the intent was a large penalty this should probably be -1.0e12.
        log_likelihood_gen[np.isnan(log_likelihood_gen)] = -1.0e-12
        np.save(create_filename("results", "samples_likelihood", args), log_likelihood_gen)
    except IntractableLikelihoodError:
        logger.info("True simulator likelihood is intractable for dataset %s", args.dataset)

    logger.info("Calculating distance from manifold of generated samples")
    try:
        distances_gen = simulator.distance_from_manifold(x_gen)
        np.save(create_filename("results", "samples_manifold_distance", args), distances_gen)
    except NotImplementedError:
        logger.info("Cannot calculate distance from manifold for dataset %s", args.dataset)

    # Guard clauses: FID applies only to image data with an available implementation.
    if not simulator.is_image():
        return
    if calculate_fid_given_paths is None:
        logger.warning("Cannot compute FID score, did not find FID implementation")
        return

    logger.info("Calculating FID score of generated samples")
    # The FID script needs an image folder
    with tempfile.TemporaryDirectory() as gen_dir:
        logger.debug(f"Storing generated images in temporary folder {gen_dir}")
        array_to_image_folder(x_gen, gen_dir)

        true_dir = create_filename("dataset", None, args) + "/test"
        os.makedirs(os.path.dirname(true_dir), exist_ok=True)
        if not os.path.exists(f"{true_dir}/0.jpg"):
            # Export the test split once; later calls reuse the existing folder.
            test_arrays = simulator.load_dataset(
                train=False,
                numpy=True,
                dataset_dir=create_filename("dataset", None, args),
                true_param_id=args.trueparam,
            )
            array_to_image_folder(test_arrays[0], true_dir)

        logger.debug("Beginning FID calculation with batchsize 50")
        # "" is passed as the cuda flag (falsy).
        fid = calculate_fid_given_paths([gen_dir, true_dir], 50, "", 2048)
        logger.info(f"FID = {fid}")
        np.save(create_filename("results", "samples_fid", args), [fid])
def mytest(self, step):
    """Sample images from ``self.G`` and report the mean FID over repeated subsets.

    1. Generates 200 images from random latent vectors.
    2. Saves them as JPEGs in a fixed folder, resized to 178x218 (width x height).
    3. Repeats ``times`` rounds: copies 40000 files from ``self.path1`` into a
       fresh scratch folder (via ``random_copyfile``) and computes the FID
       against ``self.path2``.
    4. Prints the mean FID and appends it to ``FID.txt``.

    Args:
        step: Unused; kept for interface compatibility.
    """
    num = 200
    sample_dir = '/home/xinzi/dataset_k40/test_celeba/'
    scratch_dir = '/home/xinzi/dataset_k40/celebAtemp'
    times = 5

    # Sampling only: no autograd graph is needed.
    with torch.no_grad():
        z = tensor2var(torch.randn(num, self.z_dim))
        fake_images, _, _ = self.G(z)
    fake_images = fake_images.data

    for n in range(num):
        img_path = '%s%d.jpg' % (sample_dir, n)
        save_image(fake_images[n], img_path)
        with Image.open(img_path) as im:
            resized = im.resize((178, 218))
        resized.save(img_path)

    total = 0  # renamed from `sum`, which shadowed the builtin
    for _ in range(times):
        # Start every round from an empty scratch folder. The original crashed
        # on the first run when the folder did not exist yet.
        if os.path.exists(scratch_dir):
            shutil.rmtree(scratch_dir)
        os.mkdir(scratch_dir)
        random_copyfile(self.path1, scratch_dir, 40000)
        fid_value = calculate_fid_given_paths(scratch_dir, self.path2, 100, self.gpu != '', self.dims)
        total = total + fid_value
        print('FID: ', fid_value)

    mean_fid = float(total / times)
    print(mean_fid)
    with open('FID.txt', 'a') as f:
        f.write('\n')
        f.write(str(mean_fid))
def validate(test_dataloader, netG_A2B, netG_B2A, dataset):
    """Compute FID for both translation directions of a generator pair.

    Every test batch is translated in both directions. Each output is
    transformed with ``0.5 * (x + 1.0)`` (presumably mapping [-1, 1] to
    [0, 1] - confirm) and saved as a numbered PNG into a timestamped
    temporary folder:

    * ``A/`` holds the B->A outputs;
    * ``B/`` holds the A->B outputs.

    Each folder is compared with ``datasets/<dataset>/train/{A,B}``. The
    temporary folder is always removed afterwards.

    Returns:
        Tuple ``(fid_value_A, fid_value_B)``.
    """
    import shutil
    import torch

    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    temp_buffer = "fid_buffer_{}".format(timestamp)
    dir_A = os.path.join(temp_buffer, 'A')
    dir_B = os.path.join(temp_buffer, 'B')
    os.makedirs(dir_A, exist_ok=True)
    os.makedirs(dir_B, exist_ok=True)

    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i, batch in enumerate(test_dataloader):
                # Set model input
                real_A = batch['A'].cuda()
                real_B = batch['B'].cuda()
                # Generate output
                fake_B = 0.5 * (netG_A2B(real_A).data + 1.0)
                fake_A = 0.5 * (netG_B2A(real_B).data + 1.0)
                # Save image files
                file_name = '%04d.png' % (i + 1)
                save_image(fake_A, os.path.join(dir_A, file_name))
                # Bug fix: this previously saved fake_A, so "FID B" compared
                # domain-A outputs with the domain-B reference images.
                save_image(fake_B, os.path.join(dir_B, file_name))

        fid_value_A = calculate_fid_given_paths(
            ['datasets/{}/train/A'.format(dataset), dir_A], 1, True, 2048)
        fid_value_B = calculate_fid_given_paths(
            ['datasets/{}/train/B'.format(dataset), dir_B], 1, True, 2048)
    finally:
        # Replaces `os.system('rm -rf ...')`: no shell, and runs even on failure.
        shutil.rmtree(temp_buffer, ignore_errors=True)

    print("FID A: {}, FID B: {}".format(fid_value_A, fid_value_B))
    return fid_value_A, fid_value_B
# run(config): training entry point that trains one generator/discriminator
# pair while tracking FID.
#
# Visible behaviour, in order:
#   * Fills resolution / n_classes / activations from lookup tables in `utils`.
#   * Builds two model pairs, (G, D) and (G3, D3), each wrapped in model.G_D.
#     An EMA copy (G_ema) exists only for G, and only when config['ema'].
#   * G_fp16 / D_fp16 convert G, G_ema and D to half precision; G3 and D3 are
#     never converted.
#   * Weights: loads the configured experiment's weights when config['resume'].
#     It also loads G3/D3 from
#     '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights' ('C10Ukl5', 'last0')
#     and G/D from
#     '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights' ('C10Ukl5', 'best0').
#     NOTE(review): line structure is lost in this excerpt, so it is unclear
#     whether the hard-coded loads sit inside the resume branch. They override
#     the configured G/D weights - confirm this is intended.
#   * Only GD is wrapped in nn.DataParallel when config['parallel']; GD3 is not.
#   * train_fns.GAN_training_function receives (G3, D3, GD3) twice, then
#     (G, D, GD). NOTE(review): presumably G3/D3 are the models being trained -
#     confirm against train_fns.
#   * FID reference statistics come from 'fid_stats_cifar10_train.npz'. Only the
#     datasets 'C10U' / 'C10' are accepted; anything else prints a message and
#     calls sys.exit().
#   * Per iteration: G.eval() and D.train() are called before train(x, y).
#     - Every config['save_every'] iterations, train_fns.save_and_sample runs.
#     - Every config['test_every'] iterations (from epoch config['start_eval']
#       on), samples are written with utils.sample_inception. The FID against the
#       reference stats is then passed to train_fns.update_FID.
#   * G/D weights are saved under the name 'last0' via utils.save_weights.
# `sample` (a functools.partial over utils.sample) is built but never used in
# this excerpt.
def run(config): config['resolution'] = utils.imsize_dict[config['dataset']] config['n_classes'] = utils.nclass_dict[config['dataset']] config['G_activation'] = utils.activation_dict[config['G_nl']] config['D_activation'] = utils.activation_dict[config['D_nl']] if config['resume']: config['skip_init'] = True config = utils.update_config_roots(config) device = 'cuda' utils.seed_rng(config['seed']) utils.prepare_root(config) torch.backends.cudnn.benchmark = True model = __import__(config['model']) experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config)) G = model.Generator(**config).to(device) D = model.Discriminator(**config).to(device) G3 = model.Generator(**config).to(device) D3 = model.Discriminator(**config).to(device) if config['ema']: G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device) ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start']) else: G_ema, ema = None, None if config['G_fp16']: G = G.half() if config['ema']: G_ema = G_ema.half() if config['D_fp16']: D = D.half() GD = model.G_D(G, D, config['conditional']) GD3 = model.G_D(G3, D3, config['conditional']) state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config} if config['resume']: utils.load_weights(G, D, state_dict, config['weights_root'], experiment_name, config['load_weights'] if config['load_weights'] else None, G_ema if config['ema'] else None) #utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None) #utils.load_weights(G, D, state_dict, '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None) #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None) #utils.load_weights(G3, D3, 
state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None) utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None) utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None) if config['parallel']: GD = nn.DataParallel(GD) if config['cross_replica']: patch_replication_callback(GD) test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name) train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name) test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=(not config['resume'])) train_log = utils.MyLogger(train_metrics_fname, reinitialize=(not config['resume']), logstyle=config['logstyle']) utils.write_metadata(config['logs_root'], experiment_name, config, state_dict) D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations']) # Use: config['abnormal_class'] #print(config['abnormal_class']) abnormal_class = config['abnormal_class'] select_dataset = config['select_dataset'] #print(config['select_dataset']) #print(select_dataset) loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 'start_itr': state_dict['itr'], 'abnormal_class': abnormal_class, 'select_dataset': select_dataset}) # Usage: --select_dataset cifar10 --abnormal_class 0 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5 # 
Use: --select_dataset mnist --abnormal_class 1 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5 G_batch_size = max(config['G_batch_size'], config['batch_size']) z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16']) fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16']) fixed_z.sample_() fixed_y.sample_() if not config['conditional']: fixed_y.zero_() y_.zero_() if config['which_train_fn'] == 'GAN': train = train_fns.GAN_training_function(G3, D3, GD3, G3, D3, GD3, G, D, GD, z_, y_, ema, state_dict, config) else: train = train_fns.dummy_training_function() sample = functools.partial(utils.sample, G=(G_ema if config['ema'] and config['use_ema'] else G), z_=z_, y_=y_, config=config) if config['dataset'] == 'C10U' or config['dataset'] == 'C10': data_moments = 'fid_stats_cifar10_train.npz' #'../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz' #data_moments = '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz' else: print("Cannot find the data set.") sys.exit() for epoch in range(state_dict['epoch'], config['num_epochs']): if config['pbar'] == 'mine': pbar = utils.progress(loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta') else: pbar = tqdm(loaders[0]) for i, (x, y) in enumerate(pbar): state_dict['itr'] += 1 G.eval() D.train() if config['ema']: G_ema.train() if config['D_fp16']: x, y = x.to(device).half(), y.to(device) else: x, y = 
x.to(device), y.to(device) print('') # Random seed #print(config['seed']) if epoch==0 and i==0: print(config['seed']) metrics = train(x, y) # We double the learning rate if we double the batch size. train_log.log(itr=int(state_dict['itr']), **metrics) if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])): train_log.log(itr=int(state_dict['itr']), **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')}) if config['pbar'] == 'mine': print(', '.join(['itr: %d' % state_dict['itr']] + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ') if not (state_dict['itr'] % config['save_every']): if config['G_eval_mode']: G.eval() if config['ema']: G_ema.eval() train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name) experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config)) if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']): if config['G_eval_mode']: G.eval() if config['ema']: G_ema.eval() utils.sample_inception( G_ema if config['ema'] and config['use_ema'] else G, config, str(epoch)) folder_number = str(epoch) sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'], experiment_name, folder_number) FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments], batch_size=50, cuda=True, dims=2048) train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log) state_dict['epoch'] += 1 #utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'be01Bes01Best%d' % state_dict['save_best_num'], G_ema if config['ema'] else None) utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'last%d' % 0, G_ema if config['ema'] else None)
# Test-time rendering + FID excerpt (original line structure lost).
# With autograd disabled:
#   * netE encodes Xa; diffRender.render re-renders those attributes as Xer
#     (reconstruction).
#   * A copy of the attributes (via deep_copy) gets new azimuths, drawn uniformly
#     from [-opt.azi_scope/2, opt.azi_scope/2] and negated. It is rendered as Xir.
#   * The first three channels of Xa, Xer and Xir are saved as JPEG (quality=100)
#     to ori_dir, rec_dir and inter_dir, using each entry of `paths`' basename.
# Then FID(ori_dir, rec_dir) and FID(ori_dir, inter_dir) are printed and logged
# via summary_writer as 'Test/fid_recon' and 'Test/fid_inter'. Finally netE is
# switched back to train mode.
# NOTE(review): the inner loop reuses the name `i`; if this sits inside an outer
# `for i, ...` loop, that index is overwritten - confirm.
with torch.no_grad(): Ae = netE(Xa) Xer, Ae = diffRender.render(**Ae) Ai = deep_copy(Ae) Ai['azimuths'] = - torch.empty((Xa.shape[0]), dtype=torch.float32).uniform_(-opt.azi_scope/2, opt.azi_scope/2).cuda() Xir, Ai = diffRender.render(**Ai) for i in range(len(paths)): path = paths[i] image_name = os.path.basename(path) rec_path = os.path.join(rec_dir, image_name) output_Xer = to_pil_image(Xer[i, :3].detach().cpu()) output_Xer.save(rec_path, 'JPEG', quality=100) inter_path = os.path.join(inter_dir, image_name) output_Xir = to_pil_image(Xir[i, :3].detach().cpu()) output_Xir.save(inter_path, 'JPEG', quality=100) ori_path = os.path.join(ori_dir, image_name) output_Xa = to_pil_image(Xa[i, :3].detach().cpu()) output_Xa.save(ori_path, 'JPEG', quality=100) fid_recon = calculate_fid_given_paths([ori_dir, rec_dir], 32, True) print('Test recon fid: %0.2f' % fid_recon) summary_writer.add_scalar('Test/fid_recon', fid_recon, epoch) fid_inter = calculate_fid_given_paths([ori_dir, inter_dir], 32, True) print('Test rotation fid: %0.2f' % fid_inter) summary_writer.add_scalar('Test/fid_inter', fid_inter, epoch) netE.train()
# Show the last generated grid next to the real images (right-hand panel).
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()

# Output folders. makedirs(exist_ok=True) replaces the exists/mkdir pairs and
# also creates './img' as the parent of the two image folders.
for _out_dir in ('./checkpoint', './dataset', './img/real', './img/fake'):
    os.makedirs(_out_dir, exist_ok=True)

# Save generated and real images, then compare them with FID
# (batch size 50, cuda=True, dims=2048).
fake_dataset = netG(fixed_noise)
fake_image_path_list = save_image_list(fake_dataset, False)
# Slice before the device transfer so only the 1000 kept images are copied.
real_dataset = real_batch[0][:1000].to(device)
real_image_path_list = save_image_list(real_dataset, True)
fid_value = calculate_fid_given_paths(
    [real_image_path_list, fake_image_path_list], 50, True, 2048)
print(f'FID scoreeeeeeeeeeeee: {fid_value}')
# Real-image statistics are computed once and reused for every folder below.
m1, s1 = fid_score.getRealm1s1(paths, batch_size, cuda, dims, model)
print('finish prepare FID')

# %% calc FID per dir
matplotlib.use('Agg')
for aug_type in aug_type_vec:
    fid_vec = []
    for lam in lambda_vec:
        print(aug_type, lam)
        folder_name = aug_type + '_' + str(lam)
        paths = ['output/real_samples', 'weights/' + folder_name + '/fake_samples/']
        fid = fid_score.calculate_fid_given_paths(paths, batch_size, cuda, dims, model, m1, s1)
        with open("output/" + folder_name + "/result_FID.txt", "w") as result_file:
            result_file.write(str(fid))
        fid_vec.append(fid)
    # Bug fix: use a fresh figure per augmentation type. Previously every
    # curve was drawn on one shared figure, so each saved PNG also contained
    # all earlier augmentation types.
    plt.figure()
    p = plt.plot(lambda_vec, np.array(fid_vec))
    plt.title(aug_type)
    plt.xlabel("lambda_aug")
    plt.ylabel("fid score")
    plt.savefig("output/" + aug_type + "_result_FID.png")
    plt.close()
print("FINISH calc FID per dir")

# %% read FID per dir
# matplotlib.use('Agg')
# Excerpt (original line structure lost). Saves the netG / netD state dicts to
# ./checkpoint/net{G,D}_<version>_epoch_<epoch+1>.pth, then runs a quick test:
#   * takes the first batch of TB=1000 images from a shuffled DataLoader over
#     `dataset`;
#   * generates TB images from torch.randn(TB, nz, 1, 1) noise;
#   * writes both sets via save_image_list;
#   * prints the FID between the returned paths (batch size 1000, cuda=False,
#     dims=2048).
# NOTE(review): netG(noise) runs with autograd enabled; torch.no_grad() would
# save memory. The for/break loop could be `next(iter(dataloader))`.
torch.save(netG.state_dict(), f"./checkpoint/netG_{version}_epoch_{epoch+1}.pth") torch.save(netD.state_dict(), f"./checkpoint/netD_{version}_epoch_{epoch+1}.pth") print(f"Save checkpoint at epoch {epoch+1}") # f"./checkpoint/netG_{version}_epoch_{epoch+1}.pth" # TEST print("Start Test") TB = 1000 dataloader = torch.utils.data.DataLoader(dataset, batch_size=TB, shuffle=True, num_workers=workers) for i, (data, _) in enumerate(dataloader): real_dataset = data break noise = torch.randn(TB, nz, 1, 1).cuda() fake_dataset = netG(noise) print("Generate real/fake images") real_image_path_list = save_image_list(real_dataset, True) fake_image_path_list = save_image_list(fake_dataset, False) print("Evaluate FID score") fid_value = calculate_fid_given_paths( [real_image_path_list, fake_image_path_list], 1000, False, 2048) print(f'FID score: {fid_value}')
def fid_score(self):
    """Return the FID between ``self._data_path`` and ``self._output_path``.

    Uses a batch size of 32 and dims=2048; GPU use follows ``self._cuda``.
    """
    folders = [self._data_path, self._output_path]
    batch_size, dims = 32, 2048
    return calculate_fid_given_paths(folders, batch_size, self._cuda, dims)
def get_fid(self):
    """Calculate the FID score per target attribute and overall.

    Steps:
      1. Restore the generator at ``self.test_iters``.
      2. Save every image of ``self.labelled_loader`` to
         ``<result_dir>/real``.
      3. Translate each image into every target domain and save it to
         ``<result_dir>/fake_all`` with the attribute name as file prefix.
      4. Call ``calculate_fid_given_paths`` with the real folder and a list
         containing one file list per attribute, followed by the whole fake
         folder.
      5. Write the per-attribute scores, their mean and the overall score to
         ``<result_dir>/fid_scores.txt``.

    Raises:
        ValueError: if ``self.dataset`` is neither 'CelebA' nor 'RaFD'.
            Previously this fell through and failed later with
            UnboundLocalError.
    """
    # Load the trained generator.
    self.restore_model(self.test_iters)
    data_loader = self.labelled_loader
    if self.dataset == 'CelebA':
        attr_list = self.selected_attrs
    elif self.dataset == 'RaFD':
        attr_list = self.selected_emots
    else:
        raise ValueError('Unsupported dataset for FID: {}'.format(self.dataset))

    real_dir = os.path.join(self.result_dir, 'real')
    fake_dir = os.path.join(self.result_dir, 'fake_all')
    os.makedirs(real_dir, exist_ok=True)
    os.makedirs(fake_dir, exist_ok=True)

    with torch.no_grad():
        r_count = 0
        f_count = 0
        for x_real, c_org in data_loader:
            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
            if self.dataset == 'RaFD':
                c_org = self.label2onehot(c_org, self.c_dim)

            # Save real images.
            for r in range(x_real.size(0)):
                r_count += 1
                c_o = ''.join(c_org[r].numpy().astype('int8').astype('str'))
                real_path = os.path.join(real_dir, 'r_{}_{}.jpg'.format(r_count, c_o))
                save_image(self.denorm(x_real[r].data.cpu()), real_path)

            # Translate real images into fake ones and save them. Numbering
            # restarts at the same offset for every target attribute, so the
            # same number is reused for a given input image across attributes.
            batch_start = f_count
            for c_trg, attr_name in zip(c_trg_list, attr_list):
                x_fake = self.G(x_real, c_trg)
                for f in range(x_fake.size(0)):
                    c_t = ''.join(c_trg[f].cpu().numpy().astype('int8').astype('str'))
                    fake_path = os.path.join(
                        fake_dir, attr_name + '_{}_{}.jpg'.format(batch_start + f + 1, c_t))
                    save_image(self.denorm(x_fake[f].data.cpu()), fake_path)
                f_count = batch_start + x_fake.size(0)

    fid_path = os.path.join(self.result_dir, 'fid_scores.txt')
    # One file list per attribute, then the whole folder for the overall score.
    # Files are named '<attr>_<n>_<label>.jpg'; matching '<attr>_*' avoids
    # picking up another attribute whose name merely starts with this one.
    fake_paths_list = [
        list(pathlib.Path(fake_dir).glob('{}_*'.format(attr_name)))
        for attr_name in attr_list
    ]
    fake_paths_list.append(fake_dir)
    fids = calculate_fid_given_paths([real_dir, fake_paths_list], 128, True, 2048,
                                     device=self.device_num)
    with open(fid_path, 'w') as f:
        for attr_name, fid in zip(attr_list, fids[:-1]):
            f.write('{}: {:.4f}\n'.format(attr_name, fid))
        f.write('Average FID: {:.4f}\n'.format(np.mean(fids[:-1])))
        f.write('Overall FID: {:.4f}\n'.format(fids[-1]))
    print('Overall FID: {:.4f}\n'.format(fids[-1]))
# Export 1000 student samples as 64x64 RGB JPEGs and score them with FID.
p_small_G = "MNIST_DCGAN_results/RGB/small_G_" + str(small_size)
# makedirs also creates MNIST_DCGAN_results/RGB when it is missing.
os.makedirs(p_small_G, exist_ok=True)
p_mnist = "MNIST_DCGAN_results/RGB/mnist"
if os.path.exists(p_mnist + ".npz"):
    # Presumably cached reference statistics - used instead of the folder.
    p_mnist = p_mnist + ".npz"
for i in range(0, 1000):
    # Channel 0, resized to 64x64, round-tripped through a temporary 8-bit
    # JPEG, then expanded to 3 channels.
    src = skimage.transform.resize(images_small_numpy[i][0], (64, 64))
    imageio.imwrite("temp.jpg", skimage.img_as_ubyte(src))
    src = cv2.imread("temp.jpg", 0)
    src_RGB = cv2.cvtColor(src, cv2.COLOR_GRAY2RGB)
    cv2.imwrite(p_small_G + '/' + str(i) + ".jpg", src_RGB)
fid = fid_score.calculate_fid_given_paths([p_mnist, p_small_G], 50, True, 2048)

# Bug fix: the original had two consecutive `if fid < best_FID` blocks. The
# first updated best_FID, so the second (which saved best_state_small_*) could
# never run. Now an improvement saves the best state, and the latest state is
# always saved.
improved = fid < best_FID
if improved:
    best_FID = fid
    print("best_FID:" + str(fid))
state = {'small_G': small_G.state_dict(), 'G_optimizer': G_optimizer.state_dict(),
         'train_hist': train_hist, 'epoch': epoch + 1,
         'total_ptime': total_ptime, 'best_FID': best_FID}
if improved:
    torch.save(state, "MNIST_DCGAN_results/FID/best_state_small_" + str(small_size) + ".pkl")
torch.save(state, "MNIST_DCGAN_results/FID/state_small_" + str(small_size) + ".pkl")
# Standalone FID check between the real and the generated image folders.
# `fid_score` is a local copy of a third-party implementation (downloaded from GitHub).
import torch
from fid_score import calculate_fid_given_paths

fid_value = calculate_fid_given_paths(
    ['./img/real', './img/fake'],  # [reference folder, generated folder]
    100,   # batch_size
    True,  # cuda
    2048,  # dims
)
print(f'FID score: {fid_value}')
# Excerpt (original line structure lost). Creates save_dir if needed, then
# translates every test image with net_G. Each output's crop
# [:, :, 3:253, 28:228] is saved to '<save_dir>/<i>.png', normalized per image.
# Finally prints the FID between save_dir and
# '<opt.dataroot>/CUHK/Testing Images/sketches'. The extra arguments are
# 8, opt.gpuid != '', 2048 - presumably batch size, CUDA flag and feature dims.
# NOTE(review): the progress message prints i+88 and '.jpg', but the file is
# saved as '<i>.png'. The offset is presumably the dataset's original
# numbering - confirm.
# NOTE(review): Variable(...) is a no-op wrapper in current PyTorch, and
# inference runs with autograd enabled; torch.no_grad() would reduce memory use.
if not os.path.exists(save_dir): os.mkdir(save_dir) for i, image in enumerate(test_loader): imgA = image[0] # imgB = image['A'] real_A = Variable(imgA.cuda()) # real_B = Variable(imgB.cuda()) fake_B = net_G(real_A) # fakeB[i, :, :, :] = fake_B.data # realB[i, :, :, :] = real_B.data print("%d.jpg generate completed" % (i+88)) # vutils.save_image(fake_B[:, :, 3:253, 28:228], # '%s/fakeB/%d.png' % (opt.output, i), # normalize=True, # scale_each=True) vutils.save_image(fake_B[:, :, 3:253, 28:228],'%s/%d.png' % (save_dir, i),normalize=True,scale_each=True) fid_value = calculate_fid_given_paths(save_dir, opt.dataroot + '/CUHK/Testing Images/sketches', 8, opt.gpuid != '',2048) print(fid_value) # vutils.save_image(realB, # '%s/realB_8.png' % (opt.outf), # normalize=True, # scale_each=True)
def _export_student_samples(small_G, big_size, small_size, n_samples):
    """Save ``n_samples`` student samples as 64x64 RGB JPEGs for the FID.

    Returns:
        ``(sample_dir, reference_path)``. ``reference_path`` is
        ``MNIST_DCGAN_results/RGB/mnist.npz`` if that file exists, otherwise
        the ``MNIST_DCGAN_results/RGB/mnist`` folder.
    """
    z_ = torch.randn((n_samples, 100)).view(-1, 100, 1, 1)
    with torch.no_grad():
        images_small_numpy = small_G(z_.cuda()).cpu().numpy()
    torch.cuda.empty_cache()

    p_small_G = "MNIST_DCGAN_results/RGB/" + str(big_size) + "_" + str(small_size)
    # makedirs also creates MNIST_DCGAN_results/RGB when it is missing.
    os.makedirs(p_small_G, exist_ok=True)
    p_mnist = "MNIST_DCGAN_results/RGB/mnist"
    if os.path.exists(p_mnist + ".npz"):
        p_mnist = p_mnist + ".npz"

    for i in range(n_samples):
        # Channel 0, resized to 64x64, round-tripped through a temporary 8-bit
        # JPEG, then expanded to 3 channels.
        src = skimage.transform.resize(images_small_numpy[i][0], (64, 64))
        imageio.imwrite("temp.jpg", skimage.img_as_ubyte(src))
        src = cv2.imread("temp.jpg", 0)
        src_RGB = cv2.cvtColor(src, cv2.COLOR_GRAY2RGB)
        cv2.imwrite(p_small_G + '/' + str(i) + ".jpg", src_RGB)
    return p_small_G, p_mnist


def compress(big_size, small_size):
    """Distil the DCGAN generator ``generator(big_size)`` into ``generator(small_size)``.

    Loading:
      * The teacher is loaded from
        ``MNIST_DCGAN_results/iter/best_state_<big_size>.pkl`` when that file
        exists.
      * Training resumes from ``MNIST_DCGAN_results/iter/state_<small_size>.pkl``
        when it exists.

    Training: the student learns to reproduce the teacher's output for the same
    latent vectors, using an L1 loss. The MNIST images only set how many latent
    vectors are drawn per step; their pixels are not used.

    After every epoch:
      * 10000 student samples are exported as RGB JPEGs and scored with FID
        against ``MNIST_DCGAN_results/RGB/mnist[.npz]``;
      * the latest state is saved to ``iter/state_<small_size>.pkl``;
      * the lowest-FID state is saved to ``iter/best_state_<small_size>.pkl``.
    """
    # training parameters
    batch_size = 128
    lr = 0.0002
    train_epoch = 20
    n_fid_samples = 10000

    # data_loader
    img_size = 64
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True)

    G = generator(big_size)
    small_G = generator(small_size)
    G.weight_init(mean=0.0, std=0.02)
    small_G.weight_init(mean=0.0, std=0.02)
    G.cuda()
    small_G.cuda()

    L1_loss = nn.L1Loss()
    G_optimizer = optim.Adam(small_G.parameters(), lr=lr, betas=(0.5, 0.999))

    # Results folders. makedirs creates missing parents. The iter/ folder is
    # needed by the torch.save calls below.
    run_dir = 'MNIST_DCGAN_results/Compress/' + str(big_size) + '_' + str(small_size)
    for folder in (run_dir + '/Training', run_dir + '/Validation',
                   'MNIST_DCGAN_results/FID/', 'MNIST_DCGAN_results/iter'):
        os.makedirs(folder, exist_ok=True)

    teacher_path = "MNIST_DCGAN_results/iter/best_state_" + str(big_size) + ".pkl"
    if os.path.exists(teacher_path):
        print("load")
        checkpoint = torch.load(teacher_path)
        G.load_state_dict(checkpoint['G'])

    student_path = "MNIST_DCGAN_results/iter/state_" + str(small_size) + ".pkl"
    best_student_path = "MNIST_DCGAN_results/iter/best_state_" + str(small_size) + ".pkl"
    if os.path.exists(student_path):
        checkpoint = torch.load(student_path)
        small_G.load_state_dict(checkpoint['G'])
        G_optimizer.load_state_dict(checkpoint['G_optimizer'])
        train_hist = checkpoint['train_hist']
        total_ptime = checkpoint['total_ptime']
        start_epoch = checkpoint['epoch']
        best_FID = checkpoint['best_FID']
        print("start from epoch" + str(start_epoch))
    else:
        start_epoch = 0
        total_ptime = 0
        best_FID = sys.maxsize
        train_hist = {'G_losses': [], 'per_epoch_ptimes': [], 'total_ptime': []}

    print('training start!')
    for epoch in range(start_epoch, train_epoch):
        G_losses = []
        epoch_start_time = time.time()
        for x_, _ in train_loader:
            mini_batch = x_.size()[0]
            z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).cuda()
            # The teacher only provides targets; no gradient is needed.
            with torch.no_grad():
                G_result = G(z_)
            small_G_result = small_G(z_)
            G_train_loss = L1_loss(small_G_result, G_result)
            # Bug fix: the original called G.zero_grad() (the teacher), so the
            # student's gradients were never cleared and accumulated every step.
            G_optimizer.zero_grad()
            G_train_loss.backward()
            G_optimizer.step()
            G_losses.append(G_train_loss.item())

        per_epoch_ptime = time.time() - epoch_start_time
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        print('[%d/%d] - ptime: %.2f, loss_g: %.3f' % (
            (epoch + 1), train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(G_losses))))
        fixed_p = run_dir + '/Validation/MNIST_DCGAN_' + str(epoch + 1)
        show_result(G, small_G, (epoch + 1), save=True, path=fixed_p, isFix=True)
        train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
        train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
        total_ptime += per_epoch_ptime
        torch.cuda.empty_cache()

        p_small_G, p_mnist = _export_student_samples(small_G, big_size, small_size, n_fid_samples)
        fid = fid_score.calculate_fid_given_paths([p_mnist, p_small_G], 50, True, 2048)

        improved = fid < best_FID
        if improved:
            best_FID = fid
            print("FID update:" + str(fid))
        else:
            print("FID:" + str(fid))
        state = {'G': small_G.state_dict(), 'G_optimizer': G_optimizer.state_dict(),
                 'train_hist': train_hist, 'epoch': epoch + 1,
                 'total_ptime': total_ptime, 'best_FID': best_FID}
        if improved:
            torch.save(state, best_student_path)
        torch.save(state, student_path)
        torch.cuda.empty_cache()

    print("G_" + str(big_size) + '_' + str(small_size) + "_best_FID:" + str(best_FID))
    train_hist['total_ptime'].append(total_ptime)
    print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (
        torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), train_epoch, total_ptime))
    print("Training finish!... save training results")
'%s/netG_epoch_%d.pth' % (opt.outf, epoch)) torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch)) if True: batch_size = opt.batchSize netG.load_state_dict( torch.load('%s/netG_epoch_%d.pth' % (opt.outf, opt.niter - 1))) # test netG, and calculate FID score netG.eval() for i in range(fake_images_number): noise = torch.randn(batch_size, nz, 1, 1).cuda() fake = netG(noise) for j in range(fake.shape[0]): vutils.save_image(fake.detach()[j, ...], fake_folder + '/{}.png'.format(j + i * batch_size), normalize=True) netG.train() # calculate FID score fid_value = calculate_fid_given_paths([real_folder, fake_folder], opt.batchSize // 2, True) FIDs.append(fid_value) print('FID: {}'.format(fid_value)) df = pd.DataFrame(FIDs) df.to_csv('FID_{}.csv'.format(opt.outf))
def perform_evaluation(run_name, image_type):
    """Compute FIDs for the averaged generator copies of a run's checkpoints.

    Only checkpoints in ``../output/<run_name>/chkpts`` whose iteration number
    is a multiple of 10000 are evaluated. For each one:
      1. Rebuild the models from ``../configs/fr_default.yaml`` and load the
         checkpoint.
      2. For each of ``generator_test_99``, ``generator_test_999`` and
         ``generator_test_9999`` (presumably EMA copies with those decays -
         confirm): generate images from the fixed latents in ``z_data.npy``
         (10 chunks of 1000), write them to ``fake_data/``, and compute the
         FID against ``<image_type>_real``.

    Results are pickled to ``evaluation_data/<run_name>/eval_fid.p``, keyed by
    ``(iter_num, name)`` with value ``{'FID': fid}``.
    """
    import shutil

    out_dir = os.path.join(os.getcwd(), '..', 'output', run_name)
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, '*')))
    fake_dir = 'fake_data'
    evaluation_dict = {}
    for point in checkpoints:
        model_file = os.path.basename(point)
        # File names look like '<prefix>_<iter>.<ext>'; parse the iteration once.
        iter_num = int(model_file.split('_')[1].split('.')[0])
        if iter_num % 10000 != 0:
            continue

        config = load_config('../configs/fr_default.yaml', None)
        is_cuda = torch.cuda.is_available()
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        device = torch.device("cuda:0" if is_cuda else "cpu")

        generator, discriminator = build_models(config)
        # Put models on gpu if needed, and use multiple GPUs if possible
        generator = nn.DataParallel(generator.to(device))
        discriminator = nn.DataParallel(discriminator.to(device))
        generator_test_9 = copy.deepcopy(generator)
        generator_test_99 = copy.deepcopy(generator)
        generator_test_999 = copy.deepcopy(generator)
        generator_test_9999 = copy.deepcopy(generator)

        # All copies must be registered so the checkpoint's entries can be loaded.
        checkpoint_io.register_modules(
            generator=generator,
            generator_test_9=generator_test_9,
            generator_test_99=generator_test_99,
            generator_test_999=generator_test_999,
            generator_test_9999=generator_test_9999,
            discriminator=discriminator,
        )
        checkpoint_io.load(model_file)

        # Distributions
        ydist = get_ydist(config['data']['nlabels'], device=device)
        zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], device=device)
        z_sample = torch.Tensor(np.load('z_data.npy')).to(device)

        for name, model in zip(['099_', '0999_', '09999_'],
                               [generator_test_99, generator_test_999, generator_test_9999]):
            evaluator = Evaluator(model, zdist, ydist, device=device)
            x_sample = torch.cat([
                evaluator.create_samples(z_sample[i * 1000:(i + 1) * 1000])
                for i in range(10)
            ])
            x_sample = x_sample / 2 + 0.5
            os.makedirs(fake_dir, exist_ok=True)
            for i in range(x_sample.size(0)):
                torchvision.utils.save_image(x_sample[i, :, :, :],
                                             os.path.join(fake_dir, '{}.png'.format(i)))
            fid = calculate_fid_given_paths([fake_dir, image_type + '_real'], 50, True, 2048)
            print(iter_num, name, fid)
            # Replaces os.system("rm -rf fake_data"): no shell involved.
            shutil.rmtree(fake_dir, ignore_errors=True)
            evaluation_dict[(iter_num, name[:-1])] = {'FID': fid}

    result_dir = os.path.join('evaluation_data', run_name)
    os.makedirs(result_dir, exist_ok=True)
    # `with` closes the file; the original left the handle open.
    with open(os.path.join(result_dir, 'eval_fid.p'), 'wb') as result_file:
        pickle.dump(evaluation_dict, result_file)
def run(config):
    """Train a GAN from ``config`` and periodically evaluate its FID.

    Steps:

    * The caller's ``config`` dict is filled with derived settings
      (resolution, class count, activations) and the roots are updated.
    * G and D are built, plus an optional EMA copy of G, and optionally cast
      to fp16. They are wrapped in ``model.G_D`` and, when
      ``config['parallel']``, in ``nn.DataParallel``.
    * Data loaders are created, honouring ``abnormal_class`` and
      ``select_dataset``.
    * Training runs from ``state_dict['epoch']`` up to ``config['num_epochs']``:
      - every ``save_every`` iterations, weights and samples are saved;
      - every ``test_every`` iterations (once ``epoch >= start_eval``),
        samples are drawn, FID is computed against the CIFAR-10 training
        statistics, and the result goes to ``train_fns.update_FID``.
    * After the last epoch, the weights are saved with suffix ``'last0'``.

    Only the ``'C10'`` and ``'C10U'`` datasets have FID statistics
    configured. Any other dataset aborts with exit status 1 before models are
    built or logs are (re)initialised.

    Args:
        config: dict of training options (the keys read below).
    """
    ut = utils_Task1_KLWGAN_Simulation_Experiment

    # Validate the dataset first. The original exited with status 0 (success)
    # and only after building models and re-initialising the log files.
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the dataset.")
        sys.exit(1)

    # Settings derived from the user-specified names.
    config['resolution'] = imsize_dict[config['dataset']]
    config['n_classes'] = nclass_dict[config['dataset']]
    config['G_activation'] = activation_dict[config['G_nl']]
    config['D_activation'] = activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True  # weights are loaded below, so skip init
    config = update_config_roots(config)
    device = 'cuda'

    ut.seed_rng(config['seed'])
    ut.prepare_root(config)
    torch.backends.cudnn.benchmark = True

    # The model definition module is selected dynamically by name.
    model = __import__(config['model'])
    experiment_name = config['experiment_name'] or ut.name_from_config(config)

    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(
            **{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = ut.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()

    GD = model.G_D(G, D, config['conditional'])
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        ut.load_weights(G, D, state_dict, config['weights_root'],
                        experiment_name, config['load_weights'] or None,
                        G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = ut.MetricsLogger(test_metrics_fname,
                                reinitialize=(not config['resume']))
    train_log = ut.MyLogger(train_metrics_fname,
                            reinitialize=(not config['resume']),
                            logstyle=config['logstyle'])
    ut.write_metadata(config['logs_root'], experiment_name, config, state_dict)

    # One loader step supplies the data for a full D update
    # (all D steps and accumulations).
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders = ut.get_data_loaders(**{**config,
                                     'batch_size': D_batch_size,
                                     'start_itr': state_dict['itr'],
                                     'abnormal_class': config['abnormal_class'],
                                     'select_dataset': config['select_dataset']})

    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = ut.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                            device=device, fp16=config['G_fp16'])
    # Fixed noise/labels, used when saving samples.
    fixed_z, fixed_y = ut.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                                      device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()

    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    else:
        train = train_fns.dummy_training_function()

    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = ut.progress(
                loaders[0],
                displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])

        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            # Models may have been put in eval mode for sampling; restore train mode.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            print('')
            if epoch == 0 and i == 0:
                print(config['seed'])  # record the seed once, on the first step
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            if (config['sv_log_interval'] > 0
                    and not state_dict['itr'] % config['sv_log_interval']):
                train_log.log(itr=int(state_dict['itr']),
                              **{**ut.get_SVs(G, 'G'), **ut.get_SVs(D, 'D')})

            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key])
                                   for key in metrics]), end=' ')

            if not state_dict['itr'] % config['save_every']:
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)

            if (not state_dict['itr'] % config['test_every']
                    and epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                ut.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G,
                    config, str(epoch))
                sample_moments = '%s/%s/%s/samples.npz' % (
                    config['samples_root'], experiment_name, str(epoch))
                FID = fid_score.calculate_fid_given_paths(
                    [data_moments, sample_moments], batch_size=50, cuda=True,
                    dims=2048)
                # update_FID tracks the best (lowest) FID.
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID,
                                     experiment_name, test_log)

        state_dict['epoch'] += 1

    # Save the final model.
    ut.save_weights(G, D, state_dict, config['weights_root'], experiment_name,
                    'last0', G_ema if config['ema'] else None)
def run(config):
    """Train G/D from ``config`` and periodically evaluate FID.

    Steps:

    * Derived settings (resolution, class count, activations) are written
      into ``config``.
    * The models (plus an optional EMA copy of G) are built, optionally cast
      to fp16, and optionally loaded from saved weights.
    * Loggers and loaders are prepared.
    * Training runs from ``state_dict['epoch']`` up to ``config['num_epochs']``:
      - every ``save_every`` iterations, weights and samples are saved;
      - every ``test_every`` iterations (once ``epoch >= start_eval``),
        samples are written via ``utils.sample_inception`` and FID against
        ``fid_stats_cifar10_train.npz`` is computed and passed to
        ``train_fns.update_FID``.

    Only the ``'C10'`` and ``'C10U'`` datasets have FID statistics
    configured. Any other dataset aborts with exit status 1 before models are
    built or logs are (re)initialised.

    Args:
        config: dict of training options (the keys read below).
    """
    # Validate the dataset first. The original exited with status 0 (success)
    # and only after building models and re-initialising the log files.
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
    else:
        print("cannot find the dataset")
        sys.exit(1)

    # Settings derived from the user-specified names
    # (image size, number of classes, activation objects).
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True  # weights are loaded below
    config = utils.update_config_roots(config)
    device = 'cuda'

    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True

    # The model definition module is selected dynamically by name.
    model = __import__(config['model'])
    experiment_name = config['experiment_name'] or utils.name_from_config(config)
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        G_ema = model.Generator(
            **{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()

    GD = model.G_D(G, D, config['conditional'])
    print(G)
    print(D)
    n_params = [sum(p.data.nelement() for p in net.parameters()) for net in (G, D)]
    print('Number of params in G: {} D: {}'.format(*n_params))

    # Training progress counters; filled from disk when resuming.
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, config['load_weights'] or None,
                           G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)

    # Only D consumes real data. One loader step supplies a full D update
    # (all D steps and accumulations).
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})

    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Fixed noise/labels, used to follow individual samples through training.
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                                         device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()

    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    else:
        train = train_fns.dummy_training_function()  # debugging path

    # Sampling closure; only referenced by the commented-out train_fns.test call below.
    sample = functools.partial(utils.sample,
                               G=(G_ema if config['ema'] and config['use_ema'] else G),
                               z_=z_, y_=y_, config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    print('Total training epochs ', config['num_epochs'])
    print("the dataset is ", config['dataset'])
    print("the data moments is ", data_moments)

    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils.progress(
                loaders[0],
                displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])

        for x, y in pbar:
            state_dict['itr'] += 1
            # Models may have been put in eval mode for sampling; restore train mode.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            if (config['sv_log_interval'] > 0
                    and not state_dict['itr'] % config['sv_log_interval']):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key])
                                   for key in metrics]), end=' ')

            if not state_dict['itr'] % config['save_every']:
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)

            if (not state_dict['itr'] % config['test_every']
                    and epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                # Writes samples for this epoch under samples_root/experiment/epoch.
                utils.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G,
                    config, str(epoch))
                sample_moments = '%s/%s/%s/samples.npz' % (
                    config['samples_root'], experiment_name, str(epoch))
                FID = fid_score.calculate_fid_given_paths(
                    [data_moments, sample_moments], batch_size=50, cuda=True,
                    dims=2048)
                print("FID calculated")
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID,
                                     experiment_name, test_log)
                # train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                #                get_inception_metrics, experiment_name, test_log)

        state_dict['epoch'] += 1