class SAGAN_test(object):
    """Evaluate a trained SAGAN generator against real images.

    The class builds a ``Generator``, optionally restores its weights from a
    checkpoint, and in :meth:`test` generates label-conditioned images,
    scores them against real batches with the ``FID`` helper and writes the
    last generated batch to ``SAGAN_test.png``.

    The hyper-parameters copied in ``__init__`` (``model``, ``imsize``,
    ``batch_size`` ...) are read from names that must already exist at module
    scope; they are not constructor arguments.
    """

    # Width of the noise tensor passed to the generator in ``test``.
    # NOTE(review): the original code hard-coded 90 next to a "#self.z_dim"
    # hint -- confirm how this relates to ``z_dim`` before changing it.
    NOISE_DIM = 90

    def __init__(self, data_loader):
        """Store configuration, build the generator and load weights if asked.

        Args:
            data_loader: iterable yielding ``(images, labels)`` batches.
        """
        # Data loader
        self.data_loader = data_loader
        # Maps the integer class index of a label to a printable name.
        self.labels_dict = {
            0: 'airplane',
            1: 'automobile',
            2: 'bird',
            3: 'cat',
            4: 'deer',
            5: 'dog',
            6: 'frog',
            7: 'horse',
            8: 'ship',
            9: 'truck',
        }

        # exact model and loss
        self.model = model
        self.adv_loss = adv_loss

        # Model hyper-parameters
        self.imsize = imsize
        self.g_num = g_num
        self.z_dim = z_dim
        self.g_conv_dim = g_conv_dim
        self.d_conv_dim = d_conv_dim
        self.parallel = parallel

        self.d_iters = d_iters
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pretrained_model = pretrained_model

        self.dataset = dataset
        self.model_save_path = model_save_path
        self.test_path = test_path

        # NOTE: the FID log location is a hard-coded string literal here, not
        # a configurable variable.
        self.fid_model = FID("./log_path", self._get_device())

        self.build_model()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    @staticmethod
    def _get_device():
        """Return the first CUDA device when available, otherwise the CPU."""
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def test(self, n_batches=2):
        """Generate images for ``n_batches`` batches and report the mean FID.

        For every batch a fresh noise tensor is sampled, the generator is run
        on it together with the encoded real labels, and the FID between the
        real and generated images is recorded.  The labels of the final batch
        are printed (only when ``batch_size <= 10``) and the generated images
        of that final batch are saved to ``SAGAN_test.png``.

        Args:
            n_batches: number of batches to evaluate (default 2, the value
                that used to be hard-coded).

        Returns:
            The average of the per-batch FID scores.

        Raises:
            ValueError: if ``n_batches`` is smaller than 1.
        """
        if n_batches < 1:
            raise ValueError("n_batches must be at least 1")

        # One iterator for the whole run: re-creating it per batch would
        # restart the loader each time and, without shuffling, keep returning
        # the same first batch.
        data_iter = iter(self.data_loader)

        self.G.eval()
        fid_scores = []

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for i in range(n_batches):
                try:
                    real_images, labels = next(data_iter)
                except StopIteration:
                    # Loader has fewer batches than requested: wrap around.
                    data_iter = iter(self.data_loader)
                    real_images, labels = next(data_iter)

                if i == n_batches - 1:
                    if self.batch_size <= 10:
                        for label in labels:
                            print(self.labels_dict[label.item()])
                    else:
                        print(
                            "Avoiding to print labels since batch size greater than 10"
                        )

                real_images = tensor2var(real_images)
                labels = tensor2var(encode(labels))
                z = tensor2var(
                    torch.randn(real_images.size(0), self.NOISE_DIM))
                # The generator returns two extra outputs that are not needed
                # for evaluation.
                fake_images, _, _ = self.G(z, labels)

                fid_scores.append(
                    self.fid_model.compute_fid(real_images, fake_images))

        # Only the images generated for the last batch are written out.
        save_image(denorm(fake_images.data), 'SAGAN_test.png')
        print("Image saved as SAGAN_test.png")

        avg_fid_score = sum(fid_scores) / len(fid_scores)
        print("Average FID_score for SA GAN, for ", n_batches, " is:",
              avg_fid_score)
        return avg_fid_score

    def build_model(self):
        """Instantiate the generator and move it to the selected device."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).to(self._get_device())

    def build_tensorboard(self):
        """Do nothing; tensorboard logging is currently disabled."""
        return

    def load_pretrained_model(self):
        """Load generator weights from ``<model_save_path>/<step>_G.pth``.

        ``map_location`` is set to the active device so that a checkpoint
        written on a GPU machine can still be restored on a CPU-only one.
        """
        checkpoint_path = os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model))
        state_dict = torch.load(checkpoint_path,
                                map_location=self._get_device())
        self.G.load_state_dict(state_dict)
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def save_sample(self, data_iter):
        """Save the next batch of real images from ``data_iter`` as real.png.

        Args:
            data_iter: iterator yielding ``(images, labels)`` batches.
        """
        real_images, _ = next(data_iter)
        # ``sample_path`` is never assigned in ``__init__``; use it when a
        # caller has set it, otherwise fall back to ``test_path`` instead of
        # failing with AttributeError.
        sample_dir = getattr(self, 'sample_path', None) or self.test_path
        save_image(denorm(real_images), os.path.join(sample_dir, 'real.png'))
fixed_p = output_dir + str(epoch + 1) + '.png' vutils.save_image(G(save_noise).detach(), fixed_p, normalize=True) num_info = { 'Discriminator loss': torch.mean(torch.FloatTensor(D_losses)), 'Generator loss': torch.mean(torch.FloatTensor(G_losses)) } fake_to_show = G(save_noise).detach() #tensorboard logging writer.add_scalars('Loss', num_info, epoch) writer.add_image('Fake Samples', fake_to_show[0].cpu()) train_hist['per_epoch_ptimes'].append(per_epoch_ptime) if epoch % 30 == 0: fid_score = fid_model.compute_fid(real_image, G_result) print("FID score", fid_score) writer.add_scalar('FID Score', fid_score, epoch) end_time = time.time() total_ptime = end_time - start_time train_hist['total_ptime'].append(total_ptime) print("Avg one epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor( train_hist['per_epoch_ptimes'])), train_epoch, total_ptime)) writer.close() with open(report_dir + 'train_hist.pkl', 'wb') as f: pickle.dump(train_hist, f) show_train_hist(train_hist, save=True, path=report_dir + 'train_hist.png')