def getIOU(self, lr, hr):
    """Run the agent and every SR model over the patches of one image pair.

    NOTE(review): despite its name this method does not return an IOU. It
    works out, per patch, which SR model has the lowest mean absolute error
    and which model the agent scores highest, but both results are thrown
    away. Only the patch-grid ``info`` from ``util.getTrainingPatches`` is
    returned. ``getGroundTruthIOU`` is the method that returns a score.

    Args:
        lr: low-resolution image; divided by 255 when ``self.model`` is
            'ESRGAN'.
        hr: matching high-resolution image (must unpack as three dims).

    Returns:
        The ``info`` object produced by ``util.getTrainingPatches`` (it must
        unpack into two values).
    """
    if self.model == 'ESRGAN':
        lr, hr = lr / 255.0, hr / 255.0

    # Cut the pair into patches without augmentation; the two unpackings
    # double as shape checks on `hr` and `info`.
    _img_h, _img_w, _img_d = hr.shape
    lr_patches, hr_patches, info = util.getTrainingPatches(
        lr, hr, self.args, transform=False)
    _grid_h, _grid_w = info
    lr_patches = lr_patches.to(self.device)
    hr_patches = hr_patches.to(self.device)

    # Agent scores, then each SR model's per-patch mean absolute error.
    scores = self.agent.model(lr_patches)
    per_model_error = [
        torch.abs(model(lr_patches) - hr_patches).mean(dim=1).mean(dim=1).mean(dim=1)
        for model in self.SRmodels
    ]
    errors = torch.stack(per_model_error, dim=1)
    _best_model = errors.min(dim=1)[1]
    _agent_pick = scores.max(dim=1)[1]
    return info
def getTrainingIndices(self):
    """List, per training image pair, the range of its patch indices.

    Every pair (``self.TRAINING_LRPATH[i]``, ``self.TRAINING_HRPATH[i]``) is
    read with ``imageio.imread``, the LR image first, and then cut into
    patches with ``util.getTrainingPatches`` using the module-level ``args``.
    The result holds ``range(number_of_LR_patches)`` for each pair, in path
    order.

    NOTE(review): the elements are ``range`` objects, not image indices.
    Confirm this matches whatever consumes the returned list.
    """
    patch_ranges = []
    for position, hr_path in enumerate(self.TRAINING_HRPATH):
        lr_path = self.TRAINING_LRPATH[position]
        lr_image = imageio.imread(lr_path)
        hr_image = imageio.imread(hr_path)
        lr_patches, _hr_patches, _ = util.getTrainingPatches(lr_image, hr_image, args)
        patch_ranges.append(range(len(lr_patches)))
    return patch_ranges
def getGroundTruthIOU(self, lr, hr, samplesize=0.01):
    """Estimate how often the agent picks the lowest-error SR model.

    A random subset of the image's patches is super-resolved by every model
    in ``self.SRmodels``. For each sampled patch, the model with the lowest
    mean squared error against the HR patch is the "optimal" choice. The
    returned score is the fraction of sampled patches on which the agent's
    argmax choice equals that optimal choice. Despite the IOU name, this is
    a selection accuracy.

    Args:
        lr: low-resolution image; divided by 255 when ``self.model`` is
            'ESRGAN' or 'basic'.
        hr: matching high-resolution image, scaled the same way.
        samplesize: fraction of the image's patches to evaluate. The sample
            is clamped to at least one patch and at most all of them.

    Returns:
        A float in [0, 1].

    Raises:
        ValueError: if ``util.getTrainingPatches`` yields no patches.
    """
    if self.model == 'ESRGAN' or self.model == 'basic':
        lr = lr / 255.0
        hr = hr / 255.0

    # FORMAT THE INPUT (no augmentation during evaluation)
    lr, hr, _ = util.getTrainingPatches(lr, hr, args, transform=False)
    n_patches = len(hr)
    if n_patches == 0:
        raise ValueError('getGroundTruthIOU: image produced no patches')

    # int() truncates, so a small image or small samplesize would otherwise
    # give an empty sample and a ZeroDivisionError below. A samplesize > 1
    # would make random.sample raise. Clamp to [1, n_patches].
    sample_count = min(n_patches, max(1, int(n_patches * samplesize)))
    batch_ids = torch.tensor(random.sample(range(n_patches), sample_count),
                             dtype=torch.long)

    # Evaluation only: the result is a Python float, so no graph is needed.
    with torch.no_grad():
        LR = lr[batch_ids].to(self.device)
        HR = hr[batch_ids].to(self.device)

        # GET EACH SR RESULT
        choices = self.agent.model(LR)
        errors = []
        for sisr in self.SRmodels:
            # Per-patch MSE, averaged over the three non-batch axes.
            errors.append((sisr(LR) - HR).pow(2).mean(dim=1).mean(dim=1).mean(dim=1))
        _, optimal_idx = torch.stack(errors, dim=1).min(dim=1)
        _, predicted_idx = choices.max(dim=1)
        hits = torch.sum((optimal_idx == predicted_idx).float()).item()

    return hits / sample_count
def evaluateBounds(self, lr, hr):
    """Super-resolve a whole image with every model and build comparison views.

    The image pair is cut into patches without augmentation. The agent and
    every model in ``self.SRmodels`` are run on all patches, and the
    per-patch results are stitched back into full images with
    ``util.recombine``:

    * ``optimal`` / ``worst``: per patch, the SR output with the lowest /
      highest mean absolute error against the HR patch (oracle bounds).
    * ``max`` / ``min``: per patch, the SR output of the model the agent
      scores highest / lowest.
    * ``optimalchoice`` / ``worstchoice`` / ``maxchoice`` / ``minchoice``:
      colour-coded maps of those selections (model k lights channel k at
      255).
    * ``mask``: the agent's raw outputs written into the leading channels,
      then multiplied by 255.

    Args:
        lr: low-resolution image; divided by 255 when ``self.model`` is
            'ESRGAN'.
        hr: matching high-resolution image (must unpack as three dims).

    Returns:
        A dict with the keys above plus 'HR' (the ``hr`` argument after the
        optional /255; NOTE(review): for 'ESRGAN' this is in [0, 1] while
        the SR images are rescaled to [0, 255]) and 'choices' (the agent's
        raw output tensor).
    """
    if self.model == 'ESRGAN':
        lr = lr / 255.0
        hr = hr / 255.0

    # FORMAT THE INPUT
    h1, w1, _ = hr.shape
    LR, HR, info = util.getTrainingPatches(lr, hr, self.args, transform=False)
    h2, w2 = info

    # Evaluation only: nothing here is back-propagated, so skip building
    # autograd graphs through every model for every patch.
    with torch.no_grad():
        LR = LR.to(self.device)
        HR = HR.to(self.device)

        # GET EACH SR RESULT
        choices = self.agent.model(LR)
        _, maxarg = choices.max(dim=1)
        _, minarg = choices.min(dim=1)
        sisrs = []
        l1 = []
        for sisr in self.SRmodels:
            sr = sisr(LR)
            sisrs.append(sr)
            l1.append(torch.abs(sr - HR).mean(dim=1).mean(dim=1).mean(dim=1))
        l1diff = torch.stack(l1, dim=1)
        sisrs = torch.stack(sisrs, dim=1)  # models stacked on dim 1
        rows = torch.arange(sisrs.size(0))

        # Colour code per model: model k lights channel k at 255.
        # NOTE(review): hard-codes exactly three models / three channels.
        mask = torch.zeros(sisrs.shape, device=self.device)
        mask[:, 0, 0] += 255
        mask[:, 1, 1] += 255
        mask[:, 2, 2] += 255

        _, optimal_idx = l1diff.min(dim=1)
        _, worst_idx = l1diff.max(dim=1)

        # GATHER SUPER RESOLUTION BASED ON CHOICES
        optimal = sisrs[rows, optimal_idx]
        worst = sisrs[rows, worst_idx]
        sr = sisrs[rows, maxarg]
        minsisr = sisrs[rows, minarg]

        # Write the agent's per-model outputs into the leading channels.
        # Trailing singleton axes let a per-patch (B, n) output broadcast
        # over the spatial axes; this replaces a per-element Python loop.
        weight = torch.zeros(sr.shape, device=self.device)
        scores = choices
        while scores.dim() < weight.dim():
            scores = scores.unsqueeze(-1)
        weight[:, :scores.size(1)] = scores

        # GET DECISION MASKS
        optimalchoice = mask[rows, optimal_idx]
        worstchoice = mask[rows, worst_idx]
        maxchoice = mask[rows, maxarg]
        minchoice = mask[rows, minarg]

        # RECOMBINE RESULTS
        optimal = util.recombine(optimal, h1, w1, h2, w2)
        worst = util.recombine(worst, h1, w1, h2, w2)
        sr = util.recombine(sr, h1, w1, h2, w2)
        minsisr = util.recombine(minsisr, h1, w1, h2, w2)
        optimalchoice = util.recombine(optimalchoice, h1, w1, h2, w2)
        worstchoice = util.recombine(worstchoice, h1, w1, h2, w2)
        maxchoice = util.recombine(maxchoice, h1, w1, h2, w2)
        minchoice = util.recombine(minchoice, h1, w1, h2, w2)
        mask = util.recombine(weight, h1, w1, h2, w2)

    # FORMAT THE OUTPUT: SR images back to the [0, 255] range
    if self.model == 'ESRGAN':
        optimal = optimal.clip(0, 1) * 255.0
        worst = worst.clip(0, 1) * 255.0
        sr = sr.clip(0, 1) * 255.0
        minsisr = minsisr.clip(0, 1) * 255.0
    else:
        optimal = optimal.clip(0, 255)
        worst = worst.clip(0, 255)
        sr = sr.clip(0, 255)
        minsisr = minsisr.clip(0, 255)
    maxchoice = maxchoice.clip(0, 255)
    minchoice = minchoice.clip(0, 255)
    worstchoice = worstchoice.clip(0, 255)
    optimalchoice = optimalchoice.clip(0, 255)
    mask = mask * 255.0

    info = {
        'HR': hr,
        'min': minsisr,
        'max': sr,
        'worst': worst,
        'optimal': optimal,
        'optimalchoice': optimalchoice,
        'worstchoice': worstchoice,
        'maxchoice': maxchoice,
        'minchoice': minchoice,
        'choices': choices,
        'mask': mask
    }
    return info
def train(self, maxepoch=100, start=.01, end=0.0001):
    """Jointly train the SR models and the selection agent, indefinitely.

    Each step picks a random training image and cuts it into patches. It
    then draws a random mini-batch of patches and minimises
    ``sum_j mean(L1_j * probs[:, j])``. Here ``L1_j`` is the per-patch mean
    absolute error of SR model j and ``probs`` is the agent's output, so the
    loss back-propagates into both the SR models and the agent. Every
    optimizer then steps.

    Every 100 logger steps, Set5 validation runs and its metrics and images
    are logged. The models are then saved and put back in training mode.

    NOTE(review): ``maxepoch``, ``start`` and ``end`` are accepted but never
    used, and the outer ``count()`` loop has no exit condition.
    """
    # A CyclicLR scheduler was considered here but needs PyTorch 1.1.0+,
    # which is not available on the training server.

    indices = list(range(len(self.TRAINING_HRPATH)))
    n_pixels = (self.PATCH_SIZE * self.UPSIZE) ** 2 * 3  # normaliser for per-patch L1

    for c in count():
        # One "epoch" = len(indices) steps; images are sampled with replacement.
        for n in range(len(indices)):
            idx = random.sample(indices, 1)[0]

            # GET INPUT FROM CURRENT IMAGE
            HRpath = self.TRAINING_HRPATH[idx]
            LRpath = self.TRAINING_LRPATH[idx]
            LR = imageio.imread(LRpath)
            HR = imageio.imread(HRpath)
            LR, HR, _ = util.getTrainingPatches(LR, HR, args)

            # Random mini-batch of patches. Clamp to the number available so
            # images with fewer patches than batch_size don't make
            # random.sample raise.
            patch_ids = list(range(len(LR)))
            random.shuffle(patch_ids)
            batch_ids = random.sample(patch_ids, min(self.batch_size, len(patch_ids)))
            labels = torch.tensor(batch_ids, dtype=torch.long)
            lrbatch = LR[labels].to(self.device)
            hrbatch = HR[labels].to(self.device)
            if args.model == 'ESRGAN':
                lrbatch = lrbatch / 255.0
                hrbatch = hrbatch / 255.0

            # GET SISR RESULTS FROM EACH MODEL
            probs = self.agent.model(lrbatch)
            sisrs = [sisr(lrbatch) for sisr in self.SRmodels]

            # UPDATE BOTH THE SISR MODELS AND THE SELECTION MODEL ACCORDING TO THEIR LOSS
            sisrloss = []
            l1loss = []
            for j, sr in enumerate(sisrs):
                self.SRoptimizers[j].zero_grad()
                l1 = torch.abs(sr - hrbatch).sum(dim=1).sum(dim=1).sum(dim=1) / n_pixels
                l1loss.append(l1)
                # Each model's error is weighted by the agent's output for it.
                sisrloss.append(torch.mean(l1 * probs[:, j]))
            l1loss = torch.stack(l1loss, dim=1)
            self.agent.opt.zero_grad()
            sisrloss_total = sum(sisrloss)
            sisrloss_total.backward()
            for opt in self.SRoptimizers:
                opt.step()
            self.agent.opt.step()

            # CONSOLE OUTPUT FOR QUICK AND DIRTY DEBUGGING
            sr_lr = self.SRoptimizers[-1].param_groups[0]['lr']
            agent_lr = self.agent.opt.param_groups[0]['lr']
            _, maxarg = probs[0].max(0)
            sampleSR = sisrs[maxarg.item()][0]
            sampleHR = hrbatch[0]
            if args.model != 'ESRGAN':
                sampleSR = sampleSR / 255.0
                sampleHR = sampleHR / 255.0
            choice = probs.max(dim=1)[1]
            c1 = (choice == 0).float().mean()
            c2 = (choice == 1).float().mean()
            c3 = (choice == 2).float().mean()
            s1 = torch.mean(l1loss[:, 0]).item()
            s2 = torch.mean(l1loss[:, 1]).item()
            s3 = torch.mean(l1loss[:, 2]).item()
            # L1 of the model the agent would pick for each patch.
            agentloss = torch.mean(l1loss.gather(1, choice.unsqueeze(1)))
            print('\rEpoch/img: {}/{} | LR sr/ag: {:.8f}/{:.8f} | Agent Loss: {:.4f}, SISR Loss: {:.4f} | s1: {:.4f} | s2: {:.4f} | s3: {:.4f}'
                  .format(c, n, sr_lr, agent_lr, agentloss.item(), sisrloss_total.item(), s1, s2, s3), end="\n")

            # LOG AND SAVE THE INFORMATION
            scalar_summaries = {'Loss/AgentLoss': agentloss, 'Loss/SISRLoss': sisrloss_total,
                                "choice/c1": c1, "choice/c2": c2, "choice/c3": c3,
                                "sisr/s1": s1, "sisr/s2": s2, "sisr/s3": s3}
            hist_summaries = {'actions': probs.view(-1), "choices": choice.view(-1)}
            img_summaries = {'sr/HR': sampleHR.clamp(0, 1), 'sr/SR': sampleSR.clamp(0, 1)}
            self.logger.hist_summary(hist_summaries)
            self.logger.scalar_summary(scalar_summaries)
            self.logger.image_summary(img_summaries)

            if self.logger.step % 100 == 0:
                with torch.no_grad():
                    psnr, ssim, info = self.test.validateSet5(save=False, quick=False)
                if self.logger:
                    self.logger.scalar_summary({'Testing_PSNR': psnr, 'Testing_SSIM': ssim})
                    weightedmask = torch.from_numpy(info['mask']).permute(2, 0, 1) / 255.0
                    mask = torch.from_numpy(info['maxchoice']).permute(2, 0, 1) / 255.0
                    optimal_mask = torch.from_numpy(info['optimalchoice']).permute(2, 0, 1) / 255.0
                    # NOTE(review): if info['HR'] is already in [0, 1] (the
                    # ESRGAN case), this divides by 255 a second time.
                    hrimg = torch.Tensor(info['HR']).permute(2, 0, 1) / 255.0
                    srimg = torch.from_numpy(info['max']).permute(2, 0, 1) / 255.0
                    self.logger.image_summary({'Testing/Test Assignment': mask[:3],
                                               'Testing/Weight': weightedmask[:3],
                                               'Testing/SR': srimg,
                                               'Testing/HR': hrimg,
                                               'Testing/optimalmask': optimal_mask})
                self.savemodels()
                # Validation may switch models to eval mode; restore training mode.
                self.agent.model.train()
                for model in self.SRmodels:
                    model.train()
            self.logger.incstep()
def optimize(self, data, iou_threshold=0.5):
    """Train SR models and agent until the agent reliably picks the best model.

    Each iteration does the following:

    1. Sample an image index from ``data`` and draw a random mini-batch of
       its patches.
    2. Compute every SR model's per-patch MSE. The lowest-MSE model is the
       oracle choice for that patch.
    3. Minimise the mean of those minimum MSEs (only the winning model
       receives gradient for a patch) plus the negative log of the agent's
       output at the oracle index (assumes ``probs`` are probabilities —
       confirm).

    The loop stops once the rolling mean IOU over the last 100 batches
    exceeds ``iou_threshold``, or after 10000 iterations. IOU here is the
    fraction of patches where the agent's argmax equals the oracle.

    NOTE(review): a second method named ``optimize`` is defined later in
    this file. If both live in the same class, the later definition replaces
    this one.

    Args:
        data: sequence of indices into ``self.TRAINING_HRPATH`` /
            ``self.TRAINING_LRPATH``.
        iou_threshold: rolling-IOU level at which training stops.
    """
    self.agent.model.train()
    for model in self.SRmodels:
        model.train()
    agent_iou = deque(maxlen=100)

    # while the agent iou is not good enough
    for c in count():
        # get an image
        idx = random.sample(data, 1)[0]
        hr_path = self.TRAINING_HRPATH[idx]
        lr_path = self.TRAINING_LRPATH[idx]
        lr = imageio.imread(lr_path)
        hr = imageio.imread(hr_path)
        lr, hr, _ = util.getTrainingPatches(lr, hr, args, transform=False)

        # get the mini batch (clamped so small images don't make random.sample raise)
        patch_ids = list(range(len(lr)))
        random.shuffle(patch_ids)
        batch_ids = random.sample(patch_ids, min(self.batch_size, len(patch_ids)))
        labels = torch.tensor(batch_ids, dtype=torch.long)
        lr_batch = lr[labels].to(self.device)
        hr_batch = hr[labels].to(self.device)
        if args.model == 'ESRGAN' or args.model == 'basic':
            lr_batch = lr_batch / 255.0
            hr_batch = hr_batch / 255.0

        # UPDATE THE SISR MODELS AND THE AGENT
        self.agent.opt.zero_grad()
        probs = self.agent.model(lr_batch)
        sisrs = []
        pred_diff = []
        for j, sisr in enumerate(self.SRmodels):
            self.SRoptimizers[j].zero_grad()
            hr_pred = sisr(lr_batch)
            # Per-patch MSE: mean squared error over C x H x W.
            pred_diff.append((hr_pred - hr_batch).pow(2).mean(dim=1).mean(dim=1).mean(dim=1))
            sisrs.append(hr_pred)
        pred_diff = torch.stack(pred_diff, dim=1)
        minval, optimalidx = pred_diff.min(dim=1)
        # Negative log of the agent's output at the oracle model; the clamp
        # keeps log() finite.
        selectionloss = -torch.mean(
            probs.gather(1, optimalidx.unsqueeze(1)).clamp(1e-16, 1).log())
        sisrloss = minval.mean()
        sisr_loss_total = sisrloss + selectionloss
        sisr_loss_total.backward()
        for opt in self.SRoptimizers:
            opt.step()
        self.agent.opt.step()

        # CONSOLE OUTPUT FOR QUICK AND DIRTY DEBUGGING
        lr2 = self.agent.opt.param_groups[0]['lr']
        _, maxarg = probs[0].max(0)
        sample_sr = sisrs[maxarg.item()][0]
        sample_hr = hr_batch[0]
        if args.model != 'ESRGAN' and args.model != 'basic':
            sample_sr = sample_sr / 255.0
            sample_hr = sample_hr / 255.0
        choice = probs.max(dim=1)[1]
        iou = (choice == optimalidx).float().mean()
        c1 = (choice == 0).float().mean()
        c2 = (choice == 1).float().mean()
        c3 = (choice == 2).float().mean()
        s1 = torch.mean(pred_diff[:, 0]).item()
        s2 = torch.mean(pred_diff[:, 1]).item()
        s3 = torch.mean(pred_diff[:, 2]).item()
        agent_iou.append(iou.item())
        print('\rEpoch/img: {}/{} | LR: {:.8f} | Agent Loss: {:.4f}, SISR Loss: {:.4f} | IOU: {:.4f} | c1: {:.4f}, c2: {:.4f}, c3: {:.4f}'
              .format(c, self.logger.step, lr2, selectionloss.item(), sisrloss.item(),
                      np.mean(agent_iou), c1.item(), c2.item(), c3.item()), end="\n")

        # LOG AND SAVE THE INFORMATION
        # Log the same agent / SR losses the console shows (previously both
        # keys recorded the combined total).
        scalar_summaries = {
            'Loss/AgentLoss': selectionloss,
            'Loss/SISRLoss': sisrloss,
            "Loss/IOU": np.mean(agent_iou),
            "choice/c1": c1,
            "choice/c2": c2,
            "choice/c3": c3,
            "sisr/s1": s1,
            "sisr/s2": s2,
            "sisr/s3": s3
        }
        img_summaries = {
            'sr/HR': sample_hr.clamp(0, 1),
            'sr/SR': sample_sr.clamp(0, 1)
        }
        self.logger.scalar_summary(scalar_summaries)
        self.logger.image_summary(img_summaries)

        if self.logger.step % 100 == 0:
            with torch.no_grad():
                psnr, ssim, info = self.test.validateSet5(save=False, quick=False)
            # Validation may switch models to eval mode; restore training mode.
            self.agent.model.train()
            for model in self.SRmodels:
                model.train()
            if self.logger:
                self.logger.scalar_summary({
                    'Testing_PSNR': psnr,
                    'Testing_SSIM': ssim
                })
                weightedmask = torch.from_numpy(info['mask']).permute(2, 0, 1) / 255.0
                mask = torch.from_numpy(info['maxchoice']).permute(2, 0, 1) / 255.0
                optimal_mask = torch.from_numpy(info['optimalchoice']).permute(2, 0, 1) / 255.0
                hrimg = torch.Tensor(info["HR"]).permute(2, 0, 1)
                srimg = torch.from_numpy(info['max']).permute(2, 0, 1) / 255.0
                self.logger.image_summary({
                    'Testing/Test Assignment': mask[:3],
                    'Testing/Weight': weightedmask[:3],
                    'Testing/SR': srimg,
                    'Testing/HR': hrimg,
                    'Testing/optimalmask': optimal_mask[:3]
                })
            self.savemodels()
        self.logger.incstep()

        # Parenthesised: `c + 1 % 10000` parses as `c + 1`, so the
        # 10000-iteration cap previously never triggered.
        if np.mean(agent_iou) > iou_threshold or (c + 1) % 10000 == 0:
            break
def optimize(self, data, iou_threshold=0.8):
    """Train SR models and agent with a probability-weighted per-pixel L1 loss.

    ``data`` is shuffled in place (a side effect visible to the caller). For
    each entry, a fresh rolling-IOU window is started and training repeats
    on randomly sampled images from ``data`` until the rolling mean IOU
    exceeds ``iou_threshold``.

    The loss is ``sum_j mean(probs[:, j] * L1_j)``. ``L1_j`` is SR model j's
    absolute error averaged over channels, and ``probs[:, j]`` is the
    agent's output for model j. This assumes the agent emits a per-pixel
    map, as implied by the 3-D ``choice`` used for IOU — confirm. IOU is the
    fraction of pixels where the agent's argmax equals the lowest-L1 model.

    NOTE(review): the validation block reads info keys 'choices',
    'upperboundmask', 'weighted' and 'advantage'. Confirm that
    ``self.test.validateSet5`` provides them.

    Args:
        data: sequence of indices into ``self.TRAINING_HRPATH`` /
            ``self.TRAINING_LRPATH``.
        iou_threshold: rolling-IOU level at which each inner loop stops.
    """
    random.shuffle(data)
    for c, _ in enumerate(data):
        agent_iou = deque(maxlen=100)
        # while the agent iou is not good enough
        while True:
            # get an image
            idx = random.sample(data, 1)[0]
            hr_path = self.TRAINING_HRPATH[idx]
            lr_path = self.TRAINING_LRPATH[idx]
            lr = imageio.imread(lr_path)
            hr = imageio.imread(hr_path)
            # getTrainingPatches returns (lr, hr, info) everywhere else in
            # this file; unpacking two values raised ValueError here.
            lr, hr, _ = util.getTrainingPatches(lr, hr, args)

            # get the mini batch (clamped so small images don't make random.sample raise)
            patch_ids = list(range(len(lr)))
            batch_ids = random.sample(patch_ids, min(self.batch_size, len(patch_ids)))
            labels = torch.tensor(batch_ids, dtype=torch.long)
            lr_batch = lr[labels].to(self.device)
            hr_batch = hr[labels].to(self.device)
            if self.model == 'ESRGAN':
                lr_batch = lr_batch / 255.0
                hr_batch = hr_batch / 255.0

            # GET SISR RESULTS FROM EACH MODEL
            probs = self.agent.model(lr_batch)
            sisrs = [sisr(lr_batch) for sisr in self.SRmodels]

            # update the sisr models and the selection model
            # sr_result is for visualisation only, so it is accumulated
            # detached and keeps no autograd graph alive.
            sr_result = torch.zeros_like(sisrs[0])
            l1diff = []
            sisr_loss = 0
            for j, sr in enumerate(sisrs):
                self.SRoptimizers[j].zero_grad()
                l1 = torch.abs(sr - hr_batch).mean(dim=1)  # error averaged over channels
                sisr_loss += torch.mean(probs[:, j] * l1)
                l1diff.append(l1)
                sr_result += (sr * probs[:, j].unsqueeze(1)).detach()
            self.agent.opt.zero_grad()
            sisr_loss_total = sisr_loss
            sisr_loss_total.backward()
            for opt in self.SRoptimizers:
                opt.step()
            self.agent.opt.step()

            # visualizations
            if self.model != 'ESRGAN':
                sr_result = sr_result / 255.0
            l1diff = torch.stack(l1diff, dim=1)
            diffmap = torch.nn.functional.softmax(
                -255 * (l1diff - torch.mean(l1diff)), dim=1)
            _, minidx = l1diff.min(dim=1)
            target = torch.nn.functional.one_hot(minidx, len(sisrs)).permute(0, 3, 1, 2)

            lr1 = self.SRoptimizers[-1].param_groups[0]['lr']
            lr2 = self.agent.opt.param_groups[0]['lr']
            choice = probs.max(dim=1)[1]
            iou = (choice == minidx).float().mean()
            c1 = (choice == 0).float().mean()
            c2 = (choice == 1).float().mean()
            c3 = (choice == 2).float().mean()
            s1 = torch.mean(l1diff[:, 0]).item()
            s2 = torch.mean(l1diff[:, 1]).item()
            s3 = torch.mean(l1diff[:, 2]).item()

            # SAVE THE IOU FOR THIS BATCH
            agent_iou.append(iou.item())
            print('\rEpoch/img: {}/{} | LR sr/ag: {:.8f}/{:.8f} | Agent Loss {:.4f} | SISR Loss: {:.4f} | IOU: {:.4f} | s1: {:.4f}, s2: {:.4f}, s3: {:.4f}'
                  .format(c, self.logger.step, lr1, lr2, sisr_loss_total.item(), sisr_loss_total.item(),
                          np.mean(agent_iou), s1, s2, s3), end="\n")

            # LOG AND SAVE THE INFORMATION
            scalar_summaries = {
                'Loss/IOU': np.mean(agent_iou),
                'Loss/SISRLoss': sisr_loss_total,
                "choice/c1": c1,
                "choice/c2": c2,
                "choice/c3": c3
            }
            hist_summaries = {
                'actions': probs[0].view(-1),
                "choices": choice[0].view(-1)
            }
            img_summaries = {
                'sr/mask': probs[0][:3],
                'sr/sr': sr_result[0].clamp(0, 1),
                'sr/targetmask': target[0][:3],
                'sr/diffmap': diffmap[0][:3]
            }
            self.logger.scalar_summary(scalar_summaries)
            self.logger.hist_summary(hist_summaries)
            self.logger.image_summary(img_summaries)

            if self.logger.step % 100 == 0:
                with torch.no_grad():
                    psnr, ssim, info = self.test.validateSet5(save=False, quick=False)
                # Validation may switch models to eval mode; restore training mode.
                self.agent.model.train()
                for model in self.SRmodels:
                    model.train()
                if self.logger:
                    self.logger.scalar_summary({
                        'Testing_PSNR': psnr,
                        'Testing_SSIM': ssim
                    })
                    mask = torch.from_numpy(info['choices']).float().permute(2, 0, 1)
                    best_mask = info['upperboundmask'].squeeze()
                    hrimg = info['HR'].squeeze()
                    srimg = torch.from_numpy(info['weighted']).permute(2, 0, 1).clamp(0, 1)
                    advantage = info['advantage']
                    self.logger.image_summary({
                        'Testing/Test Assignment': mask[:3],
                        'Testing/SR': srimg,
                        'Testing/HR': hrimg,
                        'Testing/upperboundmask': best_mask,
                        'Testing/advantage': advantage
                    })
                self.savemodels()
            self.logger.incstep()

            # breaking condition
            if np.mean(agent_iou) > iou_threshold:
                print(agent_iou)
                break