def main(args):
    """Evaluate a trained GeneratorRRDB super-resolution model.

    Builds the dataset and generator from ``args``. If ``args.manual_image``
    is set, dumps truth/generated HR and LR images (as ``.npy``) for each
    index in the module-level ``manual_image`` list and returns. Otherwise
    iterates the whole dataloader, printing HR/LR L1 losses and an
    energy-distribution (histogram MSE) loss per batch and calling ``show``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (model path, dataset config, factor, ...).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = get_dataset(args.dataset_type, args.input, *args.hw, args.factor,
                          pre=args.pre_factor, threshold=args.E_thres,
                          N=args.n_hardest)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=not args.no_shuffle,
        num_workers=0,
    )
    generator = GeneratorRRDB(1, filters=64,
                              num_res_blocks=args.residual_blocks,
                              num_upsample=int(np.log2(args.factor)),
                              power=args.scaling_power,
                              res_scale=args.res_scale,
                              use_transposed_conv=args.use_transposed_conv,
                              fully_tconv_upsample=args.fully_transposed_conv,
                              num_final_layer_res=args.num_final_res_blocks).to(device).eval()
    generator.thres = args.threshold
    generator.load_state_dict(torch.load(args.model, map_location=device))
    criterion = torch.nn.L1Loss()
    mse = torch.nn.MSELoss()
    sumpool = SumPool2d(args.factor)

    if args.manual_image is not None:
        # NOTE(review): the flag lives on args but the indices come from the
        # module-level ``manual_image`` list — confirm this split is intended.
        global manual_image
        for ii in manual_image:
            # Fixed typo in the user-facing message ("Proccessing").
            print("Processing image at location: " + str(ii))
            truth_imgs = dataset[ii]  # idiomatic indexing over __getitem__
            truth_hr = truth_imgs["hr"].numpy()
            truth_lr = truth_imgs["lr"].unsqueeze(0)
            with torch.no_grad():  # inference only: skip autograd bookkeeping
                gen_hr = generator(truth_lr)
            gen_lr = sumpool(gen_hr).squeeze(0).numpy()
            gen_hr = gen_hr.squeeze(0).numpy()
            truth_lr = truth_lr.squeeze(0).numpy()
            print(truth_hr.shape, truth_lr.shape, gen_hr.shape, gen_lr.shape)
            np.save(args.output+"image_"+str(ii)+"_truth_hr.npy", truth_hr)
            np.save(args.output+"image_"+str(ii)+"_truth_lr.npy", truth_lr)
            np.save(args.output+"image_"+str(ii)+"_gen_hr.npy", gen_hr)
            np.save(args.output+"image_"+str(ii)+"_gen_lr.npy", gen_lr)
        return

    for i, imgs in enumerate(dataloader):
        # Configure model input
        imgs_lr = imgs["lr"].to(device)
        imgs_hr = imgs["hr"].to(device)
        # Generate a high resolution image from low resolution input
        with torch.no_grad():
            gen_hr = generator(imgs_lr)
            gen_lr = sumpool(gen_hr)
        gen_nnz = gen_hr[gen_hr > 0].view(-1)
        en_loss = 0
        if len(gen_nnz) > 0:
            real_nnz = imgs_hr[imgs_hr > 0].view(-1)
            # Shared bin range so the two 10-bin histograms are comparable.
            # (The concatenation was previously computed twice.)
            both = torch.cat((gen_nnz, real_nnz), 0)
            e_min = torch.min(both).item()
            e_max = torch.max(both).item()
            gen_hist = torch.histc(gen_nnz, 10, min=e_min, max=e_max).float()
            real_hist = torch.histc(real_nnz, 10, min=e_min, max=e_max).float()
            en_loss = mse(gen_hist, real_hist)
        print("HR L1Loss: %.3e, LR L1Loss: %.3e, Energy distribution loss %.3e"
              % (criterion(gen_hr, imgs_hr).item(),
                 criterion(gen_lr, imgs_lr).item(), en_loss))
        show(imgs_lr, imgs_hr, gen_hr, gen_lr)
def _set_model(self, device, hr_shape):
    """Build the ESRGAN networks and loss criteria on *device*.

    Relies on the module-level ``opt`` namespace for the channel count and
    the number of residual blocks. The VGG feature extractor is frozen in
    eval mode because it only serves the perceptual (content) loss.
    """
    # Generator / discriminator pair.
    self.generator = GeneratorRRDB(
        opt.channels,
        filters=64,
        num_res_blocks=opt.residual_blocks,
    ).to(device)
    self.discriminator = Discriminator(
        input_shape=(opt.channels, *hr_shape),
    ).to(device)
    # Perceptual-loss backbone, inference only.
    extractor = FeatureExtractor().to(device)
    extractor.eval()
    self.feature_extractor = extractor
    # Loss criteria: BCE for the adversarial term, L1 for content and pixel.
    bce, l1 = torch.nn.BCEWithLogitsLoss, torch.nn.L1Loss
    self.criterion_GAN = bce().to(device)
    self.criterion_content = l1().to(device)
    self.criterion_pixel = l1().to(device)
def SuperResolution(f_name, ori):
    """Super-resolve *ori* with the pretrained ESRGAN generator, OCR the
    result and persist both the image and the evaluation metrics.

    Parameters
    ----------
    f_name : str
        File name used for the saved SR image and the result record.
    ori : image accepted by ``transforms.ToTensor`` (e.g. a PIL image).

    Side effects: writes the SR image to ``./data/sr_img/<f_name>``, stores
    the OCR text in the module-level ``ocr_result`` and calls ``save_result``
    with the OCR text, MSE and PSNR against the preprocessed original.
    """
    pth = "./generator.pth"
    channels = 3
    residual_blocks = 23
    # NOTE(review): GPU index 3 is hard-coded; this raises at runtime on
    # hosts with fewer than four GPUs — consider making it configurable.
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    # Define model and load model checkpoint. map_location keeps loading
    # working on CPU-only hosts even if the checkpoint was saved on GPU.
    generator = GeneratorRRDB(channels, filters=64,
                              num_res_blocks=residual_blocks).to(device)
    generator.load_state_dict(torch.load(pth, map_location=device))
    generator.eval()
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])
    # Prepare input (torch.autograd.Variable is deprecated; plain tensors
    # carry autograd state themselves).
    image_tensor = transform(ori).to(device).unsqueeze(0)
    # Upsample image
    with torch.no_grad():
        sr_image = denormalize(generator(image_tensor)).cpu()
    # Save image
    path = os.path.join("./data/sr_img/", f_name)
    oripath = os.path.join("./data/preprocessed_img/", f_name)
    save_image(sr_image, path)
    result = OCR(path)
    global ocr_result
    ocr_result = result
    mse, psnr = PSNR(oripath, path)
    save_result(f_name, result, mse, psnr)
def upsample_empty(args):
    """Sanity-check the generator on degenerate inputs.

    Feeds an all-zero HR image and a sparse-noise HR image (a fixed number
    of random nonzero pixels) through sum-pooling and the generator,
    reports how many output pixels exceed ``args.threshold``, and shows the
    four images side by side with matplotlib.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options (model path, factor, hw, threshold, ...).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = GeneratorRRDB(1, filters=64,
                              num_res_blocks=args.residual_blocks,
                              num_upsample=int(np.log2(args.factor)),
                              power=args.scaling_power,
                              res_scale=args.res_scale,
                              use_transposed_conv=args.use_transposed_conv,
                              fully_tconv_upsample=args.fully_transposed_conv,
                              num_final_layer_res=args.num_final_res_blocks).to(device).eval()
    generator.thres = args.threshold
    generator.load_state_dict(torch.load(args.model, map_location=device))
    sumpool = SumPool2d(args.factor)

    empty_hr = torch.zeros([1, 1, *args.hw])
    empty_lr = sumpool(empty_hr)

    # Sparse-noise image: |N(0,1)| noise normalized to [0, 1], then all but
    # NUM_NOISE_PIXELS randomly chosen entries are zeroed.
    NUM_NOISE_PIXELS = 150  # was a magic constant inline
    noise = torch.abs(torch.randn(empty_hr.shape))
    noise = noise / torch.max(noise).item()
    n_total = noise.numel()  # cheaper than noise.numpy().flatten().size
    indices = np.random.choice(np.arange(n_total), replace=False,
                               size=n_total - NUM_NOISE_PIXELS)  # choose indices randomly
    noise[np.unravel_index(indices, noise.shape)] = 0  # and set them to zero
    noise_hr = empty_hr + noise
    noise_lr = sumpool(noise_hr)

    empty_sr = generator(empty_lr).detach()
    noise_sr = generator(noise_lr).detach()
    print(empty_hr.shape, empty_lr.shape, empty_sr.shape)

    def _count_above_threshold(t):
        # Vectorized count of entries strictly above args.threshold,
        # replacing the former O(n) Python list-comprehension loops.
        return int((t > args.threshold).sum().item())

    print("upsampled empty picture nnz: {}".format(_count_above_threshold(empty_sr)))
    print("upsampled soft noise picture nnz: {}".format(_count_above_threshold(noise_sr)))
    print("hr soft noise picture nnz: {}".format(_count_above_threshold(noise_hr)))

    global colors, vmax
    plt.figure()
    plt.subplot(221)
    plt.title("empty hr image")
    plt.imshow(toArray(empty_hr).squeeze(), cmap='gray', vmax=vmax)
    plt.subplot(222)
    plt.title("empty sr image")
    plt.imshow(toArray(empty_sr).squeeze(), cmap='gray', vmax=vmax)
    plt.subplot(223)
    plt.title("soft noise hr image")
    plt.imshow(toArray(noise_hr).squeeze(), cmap='gray', vmax=vmax)
    plt.subplot(224)
    plt.title("soft noise sr image")
    plt.imshow(toArray(noise_sr).squeeze(), cmap='gray', vmax=vmax)
    plt.show()
def network_initializers(self, hr_shape, use_LeakyReLU_Mish=False):
    """Construct the three ESRGAN networks on ``self.device``.

    Returns ``(discriminator, feature_extractor, generator)``; the feature
    extractor is switched to eval mode since it is only used to compute the
    perceptual loss.
    """
    dev = self.device
    cfg = self.opt
    generator = GeneratorRRDB(
        cfg.channels,
        filters=64,
        num_res_blocks=cfg.residual_blocks,
        use_LeakyReLU_Mish=use_LeakyReLU_Mish,
    ).to(dev, non_blocking=True)
    discriminator = Discriminator(
        input_shape=(cfg.channels, *hr_shape),
        use_LeakyReLU_Mish=use_LeakyReLU_Mish,
    ).to(dev, non_blocking=True)
    feature_extractor = FeatureExtractor().to(dev, non_blocking=True)
    # The feature extractor is never trained: keep it in inference mode.
    feature_extractor.eval()
    return discriminator, feature_extractor, generator
default=3, help="Number of image channels")
parser.add_argument("--residual_blocks", type=int, default=23,
                    help="Number of residual blocks in G")
opt = parser.parse_args()
print(opt)
os.makedirs("images/outputs", exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define model and load model checkpoint
generator = GeneratorRRDB(opt.channels, filters=64,
                          num_res_blocks=opt.residual_blocks).to(device)
# NOTE(review): torch.load without map_location assumes the checkpoint's
# save device is available here — confirm, or pass map_location=device.
generator.load_state_dict(torch.load(opt.checkpoint_model))
generator.eval()
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)])
# Prepare input
# NOTE(review): torch.autograd.Variable is deprecated; a plain tensor works.
image_tensor = Variable(transform(Image.open(
    opt.image_path))).to(device).unsqueeze(0)
# Upsample image
with torch.no_grad():
    sr_image = denormalize(generator(image_tensor)).cpu()
# - os.makedirs(demo_out_dir, exist_ok=True) weight_path = "/workspace/output/cat_face/weight/generator_3900.pth" # # data device = torch.device("cuda" if torch.cuda.is_available() else "cpu") hr_shape = (opt.hr_height, opt.hr_width) # + # Initialize generator and discriminator generator = GeneratorRRDB(opt.channels, filters=64, num_res_blocks=opt.residual_blocks).to(device) generator.load_state_dict(torch.load(weight_path)) Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor # - demo_dataloader = DataLoader( DemoImageDataset(demo_in_dir), batch_size=opt.batch_size, shuffle=False, num_workers=opt.n_cpu, ) # # generate hr image
# # main # + # def main(opt): # - opt = Opt() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") hr_shape = (opt.hr_height, opt.hr_width) # Initialize generator and discriminator generator = GeneratorRRDB(opt.channels, filters=64, num_res_blocks=opt.residual_blocks).to(device) discriminator = Discriminator(input_shape=(opt.channels, *hr_shape)).to(device) feature_extractor = FeatureExtractor().to(device) # Set feature extractor to inference mode feature_extractor.eval() # Losses criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device) criterion_content = torch.nn.L1Loss().to(device) criterion_pixel = torch.nn.L1Loss().to(device) # Optimizers optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr,
class ESRGAN():
    """ESRGAN training wrapper.

    Builds the generator/discriminator pair, the VGG feature extractor and
    the loss criteria, optionally loads checkpoint weights, and runs the
    relativistic-average GAN training loop with mlflow metric logging.

    Fix over the original: ``train``, ``_set_param`` and ``_load_weigth``
    now consistently use the ``self.`` attributes created in ``__init__`` /
    ``_set_model`` / ``_load_weigth`` instead of relying on same-named
    module-level globals (``generator``, ``optimizer_G``, ``opt``, ...),
    which would otherwise raise NameError or silently train the wrong
    objects. Module-level helpers (``weight_save_dir``,
    ``image_train_save_dir``, ``denormalize``, ``save_image``, ``mlflow``)
    are still expected in scope.
    """

    def __init__(self, opt):
        self.opt = opt
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        hr_shape = (self.opt.hr_height, self.opt.hr_width)
        self._set_model(device, hr_shape)

    def _set_model(self, device, hr_shape):
        # Initialize generator and discriminator (read self.opt, not the
        # module-level opt, so the instance is self-contained).
        self.generator = GeneratorRRDB(
            self.opt.channels, filters=64,
            num_res_blocks=self.opt.residual_blocks).to(device)
        self.discriminator = Discriminator(
            input_shape=(self.opt.channels, *hr_shape)).to(device)
        self.feature_extractor = FeatureExtractor().to(device)
        # Set feature extractor to inference mode
        self.feature_extractor.eval()
        # Losses
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
        self.criterion_content = torch.nn.L1Loss().to(device)
        self.criterion_pixel = torch.nn.L1Loss().to(device)

    def _set_param(self):
        # Record every hyper-parameter with mlflow.
        for key, value in vars(self.opt).items():
            mlflow.log_param(key, value)

    def _load_weigth(self):
        # (Method name kept as-is, typo included, for caller compatibility.)
        if self.opt.epoch != 0:
            # Load pretrained models
            load_g_weight_path = osp.join(weight_save_dir,
                                          "generator_%d.pth" % self.opt.epoch)
            load_d_weight_path = osp.join(weight_save_dir,
                                          "discriminator_%d.pth" % self.opt.epoch)
            self.generator.load_state_dict(torch.load(load_g_weight_path))
            self.discriminator.load_state_dict(torch.load(load_d_weight_path))
        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                            lr=self.opt.lr,
                                            betas=(self.opt.b1, self.opt.b2))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=self.opt.lr,
                                            betas=(self.opt.b1, self.opt.b2))

    # ----------
    #  Training
    # ----------
    def train(self, dataloader, opt):
        """Run the ESRGAN training loop over *dataloader* for the epoch
        range configured in *opt* (``opt.epoch + 1 .. opt.n_epochs``)."""
        # Loop-invariant: the tensor type never changes between batches.
        Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
        for epoch in range(opt.epoch + 1, opt.n_epochs + 1):
            for batch_num, imgs in enumerate(dataloader):
                batches_done = (epoch - 1) * len(dataloader) + batch_num

                # Configure model input
                imgs_lr = Variable(imgs["lr"].type(Tensor))
                imgs_hr = Variable(imgs["hr"].type(Tensor))

                # Adversarial ground truths
                valid = Variable(Tensor(
                    np.ones((imgs_lr.size(0), *self.discriminator.output_shape))),
                    requires_grad=False)
                fake = Variable(Tensor(
                    np.zeros((imgs_lr.size(0), *self.discriminator.output_shape))),
                    requires_grad=False)

                # ------------------
                #  Train Generators
                # ------------------
                self.optimizer_G.zero_grad()

                # Generate a high resolution image from low resolution input
                gen_hr = self.generator(imgs_lr)

                # Measure pixel-wise loss against ground truth
                loss_pixel = self.criterion_pixel(gen_hr, imgs_hr)

                # Warm-up (pixel-wise loss only)
                if batches_done <= opt.warmup_batches:
                    loss_pixel.backward()
                    self.optimizer_G.step()
                    log_info = "[Epoch {}/{}] [Batch {}/{}] [G pixel: {}]".format(
                        epoch, opt.n_epochs, batch_num, len(dataloader),
                        loss_pixel.item())
                    sys.stdout.write("\r{}".format(log_info))
                    sys.stdout.flush()
                    mlflow.log_metric('train_{}'.format('loss_pixel'),
                                      loss_pixel.item(), step=batches_done)
                else:
                    # Extract validity predictions from discriminator
                    pred_real = self.discriminator(imgs_hr).detach()
                    pred_fake = self.discriminator(gen_hr)

                    # Adversarial loss (relativistic average GAN)
                    loss_GAN = self.criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), valid)

                    # Content loss
                    gen_features = self.feature_extractor(gen_hr)
                    real_features = self.feature_extractor(imgs_hr).detach()
                    loss_content = self.criterion_content(gen_features,
                                                          real_features)

                    # Total generator loss
                    loss_G = (loss_content + opt.lambda_adv * loss_GAN
                              + opt.lambda_pixel * loss_pixel)
                    loss_G.backward()
                    self.optimizer_G.step()

                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    self.optimizer_D.zero_grad()
                    pred_real = self.discriminator(imgs_hr)
                    pred_fake = self.discriminator(gen_hr.detach())

                    # Adversarial loss for real and fake images (relativistic average GAN)
                    loss_real = self.criterion_GAN(
                        pred_real - pred_fake.mean(0, keepdim=True), valid)
                    loss_fake = self.criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), fake)

                    # Total loss
                    loss_D = (loss_real + loss_fake) / 2
                    loss_D.backward()
                    self.optimizer_D.step()

                    # --------------
                    #  Log Progress
                    # --------------
                    log_info = "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}, content: {}, adv: {}, pixel: {}]".format(
                        epoch,
                        opt.n_epochs,
                        batch_num,
                        len(dataloader),
                        loss_D.item(),
                        loss_G.item(),
                        loss_content.item(),
                        loss_GAN.item(),
                        loss_pixel.item(),
                    )
                    if batch_num == 1:
                        sys.stdout.write("\n{}".format(log_info))
                    else:
                        sys.stdout.write("\r{}".format(log_info))
                    sys.stdout.flush()

                    if batches_done % opt.sample_interval == 0:
                        # Save image grid with upsampled inputs and ESRGAN outputs
                        imgs_lr = nn.functional.interpolate(imgs_lr,
                                                            scale_factor=4)
                        img_grid = denormalize(torch.cat((imgs_lr, gen_hr), -1))
                        image_batch_save_dir = osp.join(
                            image_train_save_dir, '{:07}'.format(batches_done))
                        os.makedirs(osp.join(image_batch_save_dir, "hr_image"),
                                    exist_ok=True)
                        save_image(img_grid,
                                   osp.join(image_batch_save_dir, "hr_image",
                                            "%d.png" % batches_done),
                                   nrow=1, normalize=False)

                    if batches_done % opt.checkpoint_interval == 0:
                        # Save model checkpoints
                        torch.save(
                            self.generator.state_dict(),
                            osp.join(weight_save_dir,
                                     "generator_%d.pth" % epoch))
                        torch.save(
                            self.discriminator.state_dict(),
                            osp.join(weight_save_dir,
                                     "discriminator_%d.pth" % epoch))

                    mlflow.log_metric('train_{}'.format('loss_D'),
                                      loss_D.item(), step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_G'),
                                      loss_G.item(), step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_content'),
                                      loss_content.item(), step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_GAN'),
                                      loss_GAN.item(), step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_pixel'),
                                      loss_pixel.item(), step=batches_done)