class Session:
    """Evaluation session: runs RESCAN on a test set and reports MSE loss,
    SSIM and PSNR for each recurrent stage output."""

    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        # Wrap in DataParallel only when more than one device is configured.
        if len(settings.device_id) > 1:
            self.net = nn.DataParallel(RESCAN()).cuda()
        else:
            self.net = RESCAN().cuda()

        self.crit = MSELoss().cuda()
        self.ssim = SSIM().cuda()
        self.dataloaders = {}

    def get_dataloader(self, dataset_name):
        """Return a cached single-sample DataLoader for *dataset_name*.

        Bug fix: the original constructed a TestDataset on every call, even
        when a loader for that name was already cached.
        """
        if dataset_name not in self.dataloaders:
            dataset = TestDataset(dataset_name)
            self.dataloaders[dataset_name] = DataLoader(
                dataset, batch_size=1, shuffle=False,
                num_workers=1, drop_last=False)
        return self.dataloaders[dataset_name]

    def load_checkpoints(self, name):
        """Load network weights from *name* under the model dir.

        Best-effort: if the checkpoint file is missing, keep the current
        (randomly initialized) weights and only log the failure.
        """
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
            logger.info('Load checkpoint %s' % ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])

    def inf_batch(self, name, batch):
        """Run one inference batch.

        batch is expected to carry 'O' (rainy input) and 'B' (clean target)
        tensors. Returns (metrics dict with loss%d/ssim%d per stage, PSNR of
        the first stage output).
        """
        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
        # NOTE: the original also computed `R = O - B` here but never used it;
        # the dead local has been removed.

        with torch.no_grad():
            O_Rs = self.net(O)
            loss_list = [self.crit(O_R, B) for O_R in O_Rs]
            ssim_list = [self.ssim(O_R, B) for O_R in O_Rs]

        # PSNR is computed on 8-bit-scaled numpy arrays of the first stage.
        psnr = PSNR(O_Rs[0].data.cpu().numpy() * 255,
                    B.data.cpu().numpy() * 255)
        losses = {
            'loss%d' % i: loss.item()
            for i, loss in enumerate(loss_list)
        }
        ssimes = {
            'ssim%d' % i: ssim.item()
            for i, ssim in enumerate(ssim_list)
        }
        losses.update(ssimes)
        return losses, psnr
class Session:
    """Show session: runs RESCAN on a dataset, reports PSNR/SSIM and saves
    the derained outputs as annotated PNG files."""

    def __init__(self):
        self.show_dir = settings.show_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.show_dir)
        ensure_dir(settings.model_dir)
        logger.info('set show dir as %s' % settings.show_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        # Wrap in DataParallel only when more than one device is configured.
        if len(settings.device_id) > 1:
            self.net = nn.DataParallel(RESCAN()).cuda()
        else:
            self.net = RESCAN().cuda()

        self.dataloaders = {}
        self.ssim = SSIM().cuda()

    def get_dataloader(self, dataset_name):
        """Return a cached single-sample DataLoader for *dataset_name*.

        Bug fix: the original rebuilt the ShowDataset and DataLoader on every
        call, defeating the purpose of the self.dataloaders cache.
        """
        if dataset_name not in self.dataloaders:
            dataset = ShowDataset(dataset_name)
            self.dataloaders[dataset_name] = DataLoader(
                dataset, batch_size=1, shuffle=False, num_workers=1)
        return self.dataloaders[dataset_name]

    def load_checkpoints(self, name):
        """Load network weights from *name*; keep current weights if the
        checkpoint file does not exist."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
            logger.info('Load checkpoint %s' % ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])

    def inf_batch(self, name, batch):
        """Run one inference batch and return
        (last stage output, PSNR, SSIM of the first stage)."""
        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)

        with torch.no_grad():
            O_Rs = self.net(O)
            ssim_list = [self.ssim(O_R, B) for O_R in O_Rs]

        psnr = PSNR(O_Rs[0].data.cpu().numpy() * 255,
                    B.data.cpu().numpy() * 255)
        # Bug fix: '%4f' set a minimum field width of 4, not a precision;
        # '%.4f' (four decimal places) was clearly intended.
        print('psnr:%.4f-------------ssim:%.4f' % (psnr, ssim_list[0]))
        return O_Rs[-1], psnr, ssim_list[0]

    def save_image(self, No, imgs, name, psnr, ssim):
        """Write each image in *imgs* (CHW float tensors in [0, 1]) to the
        show dir, embedding sample index, stage index, PSNR and SSIM in the
        filename."""
        for i, img in enumerate(imgs):
            img = (img.cpu().data * 255).numpy()
            img = np.clip(img, 0, 255)
            img = np.transpose(img, (1, 2, 0))
            h, w, c = img.shape
            # NOTE(review): '%4f' here is a width spec, not '%.4f' precision;
            # left unchanged because downstream tooling may rely on the
            # existing filename format — confirm before changing.
            img_file = os.path.join(
                self.show_dir,
                '%s_%d_%d_%4f_%4f.png' % (name, No, i, psnr, ssim))
            cv2.imwrite(img_file, img)
class Session:
    """Show session: feeds rainy images through RESCAN and writes the
    per-stage derained results (side by side with the original) to disk."""

    def __init__(self):
        self.show_dir = settings.show_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.show_dir)
        ensure_dir(settings.model_dir)
        logger.info('set show dir as %s' % settings.show_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        self.net = RESCAN().cuda()
        self.dataset = None
        self.dataloader = None

    def get_dataloader(self, dataset_name):
        """Build (and remember) a sequential single-sample loader."""
        self.dataset = ShowDataset(dataset_name)
        self.dataloader = DataLoader(
            self.dataset, batch_size=1, shuffle=False, num_workers=1)
        return self.dataloader

    def load_checkpoints(self, name):
        """Restore network weights; a missing file is logged and ignored."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        logger.info('Load checkpoint %s' % ckp_path)
        self.net.load_state_dict(obj['net'])

    def inf_batch(self, name, batch):
        """Return the derained image for every recurrent stage.

        The network predicts the rain layer, so each stage's clean estimate
        is the input minus that stage's prediction.
        """
        rainy = Variable(batch['O'].cuda(), requires_grad=False)
        with torch.no_grad():
            stage_rain = self.net(rainy)
            derained = [rainy - layer for layer in stage_rain]
        return derained

    def save_image(self, No, imgs, ori):
        """Save each stage output next to the original image *ori*."""
        for stage, tensor in enumerate(imgs):
            arr = (tensor.cpu().data * 255).numpy()
            arr = np.clip(arr, 0, 255)[0]        # drop the batch dimension
            arr = np.transpose(arr, (1, 2, 0))   # CHW -> HWC for OpenCV
            side_by_side = np.hstack((ori, arr))
            out_path = os.path.join(self.show_dir, '%s_%d.png' % (No, stage))
            cv2.imwrite(out_path, side_by_side)
class Session:
    """Training session: the network predicts the rain layer R = O - B, so
    training losses compare predictions against R and the derained image is
    recovered as O - prediction. Metrics go to TensorBoard and the logger."""

    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        self.net = RESCAN().cuda()
        self.crit = MSELoss().cuda()
        self.ssim = SSIM().cuda()

        self.step = 0
        self.save_steps = settings.save_steps
        self.num_workers = settings.num_workers
        self.batch_size = settings.batch_size
        self.writers = {}
        self.dataloaders = {}

        self.opt = Adam(self.net.parameters(), lr=settings.lr)
        # LR decays 10x at each milestone step.
        self.sche = MultiStepLR(self.opt, milestones=[15000, 17500], gamma=0.1)

    def tensorboard(self, name):
        """Create and cache a SummaryWriter for phase *name*."""
        self.writers[name] = SummaryWriter(
            os.path.join(self.log_dir, name + '.events'))
        return self.writers[name]

    def write(self, name, out):
        """Push scalar metrics in *out* to TensorBoard and log them."""
        for k, v in out.items():
            self.writers[name].add_scalar(k, v, self.step)
        out['lr'] = self.opt.param_groups[0]['lr']
        out['step'] = self.step
        outputs = ["{}:{:.4g}".format(k, v) for k, v in out.items()]
        logger.info(name + '--' + ' '.join(outputs))

    def get_dataloader(self, dataset_name):
        """Return a fresh iterator over a cached shuffled DataLoader.

        Bug fix: the dataset is now only constructed when no loader is
        cached; the original built a TrainValDataset on every call.
        """
        if dataset_name not in self.dataloaders:
            dataset = TrainValDataset(dataset_name)
            self.dataloaders[dataset_name] = DataLoader(
                dataset, batch_size=self.batch_size, shuffle=True,
                num_workers=self.num_workers, drop_last=True)
        return iter(self.dataloaders[dataset_name])

    def save_checkpoints(self, name):
        """Persist network, optimizer state and the step counter."""
        ckp_path = os.path.join(self.model_dir, name)
        obj = {
            'net': self.net.state_dict(),
            'clock': self.step,
            'opt': self.opt.state_dict(),
        }
        torch.save(obj, ckp_path)

    def load_checkpoints(self, name):
        """Restore a full training checkpoint; missing files are logged and
        training resumes from scratch."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
            logger.info('Load checkpoint %s' % ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])
        self.opt.load_state_dict(obj['opt'])
        self.step = obj['clock']
        # Realign the scheduler with the restored step counter.
        self.sche.last_epoch = self.step

    def inf_batch(self, name, batch):
        """Run one batch; backprop when *name* == 'train'.

        Returns the derained image O - (last stage rain prediction).
        """
        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
        R = O - B  # ground-truth rain layer

        O_Rs = self.net(O)
        loss_list = [self.crit(O_R, R) for O_R in O_Rs]
        # SSIM is measured on the derained images, not the rain layers.
        ssim_list = [self.ssim(O - O_R, O - R) for O_R in O_Rs]

        if name == 'train':
            self.net.zero_grad()
            sum(loss_list).backward()
            self.opt.step()

        losses = {
            'loss%d' % i: loss.item()
            for i, loss in enumerate(loss_list)
        }
        ssimes = {
            'ssim%d' % i: ssim.item()
            for i, ssim in enumerate(ssim_list)
        }
        losses.update(ssimes)
        self.write(name, losses)

        return O - O_Rs[-1]

    def save_image(self, name, img_lists):
        """Tile (input, prediction, target) triplets into one 6x2 grid image
        and write it to the log dir.

        Bug fix: the original wrapped the grid-fill loops in a redundant
        `for img_list in img_lists:` loop whose variable was never used,
        rebuilding and rewriting the identical image three times.
        """
        data, pred, label = img_lists
        pred = pred.cpu().data

        data, label, pred = data * 255, label * 255, pred * 255
        pred = np.clip(pred, 0, 255)

        h, w = pred.shape[-2:]
        gen_num = (6, 2)  # grid rows x columns of triplets
        img = np.zeros((gen_num[0] * h, gen_num[1] * 3 * w, 3))
        for i in range(gen_num[0]):
            row = i * h
            for j in range(gen_num[1]):
                idx = i * gen_num[1] + j
                tmp_list = [data[idx], pred[idx], label[idx]]
                for k in range(3):
                    col = (j * 3 + k) * w
                    # CHW -> HWC; assumes data/label are CPU-convertible
                    # arrays here — TODO confirm at the caller.
                    tmp = np.transpose(tmp_list[k], (1, 2, 0))
                    img[row:row + h, col:col + w] = tmp

        img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name))
        cv2.imwrite(img_file, img)
class Session:
    """Training session variant: the network predicts the clean image
    directly (loss against B), supports DataParallel, and trains only on the
    first stage output. Metrics go to TensorBoard and the logger."""

    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        # Wrap in DataParallel only when more than one device is configured.
        if len(settings.device_id) > 1:
            self.net = nn.DataParallel(RESCAN()).cuda()
        else:
            self.net = RESCAN().cuda()

        self.crit = MSELoss().cuda()
        self.ssim = SSIM().cuda()

        self.step = 0
        self.save_steps = settings.save_steps
        self.num_workers = settings.num_workers
        self.batch_size = settings.batch_size
        self.writers = {}
        self.dataloaders = {}

        self.opt = Adam(self.net.parameters(), lr=settings.lr)
        # LR decays 10x at each milestone step.
        self.sche = MultiStepLR(self.opt, milestones=[240000, 320000],
                                gamma=0.1)

    def tensorboard(self, name):
        """Create and cache a SummaryWriter for phase *name*."""
        self.writers[name] = SummaryWriter(
            os.path.join(self.log_dir, name + '.events'))
        return self.writers[name]

    def write(self, name, out):
        """Push scalar metrics in *out* to TensorBoard and log them."""
        for k, v in out.items():
            self.writers[name].add_scalar(k, v, self.step)
        out['lr'] = self.opt.param_groups[0]['lr']
        out['step'] = self.step
        outputs = ["{}:{:.4g}".format(k, v) for k, v in out.items()]
        logger.info(name + '--' + ' '.join(outputs))

    def get_dataloader(self, dataset_name):
        """Return a fresh iterator over a cached shuffled DataLoader.

        Bug fix: the dataset is now only constructed when no loader is
        cached; the original built a TrainValDataset on every call.
        """
        if dataset_name not in self.dataloaders:
            dataset = TrainValDataset(dataset_name)
            self.dataloaders[dataset_name] = DataLoader(
                dataset, batch_size=self.batch_size, shuffle=True,
                num_workers=self.num_workers, drop_last=True)
        return iter(self.dataloaders[dataset_name])

    def save_checkpoints(self, name):
        """Persist network, optimizer state and the step counter."""
        ckp_path = os.path.join(self.model_dir, name)
        obj = {
            'net': self.net.state_dict(),
            'clock': self.step,
            'opt': self.opt.state_dict(),
        }
        torch.save(obj, ckp_path)

    def load_checkpoints(self, name):
        """Restore a full training checkpoint; missing files are logged and
        training resumes from scratch.

        Bug fix: the original logged 'Load checkpoint' *before* calling
        torch.load, so a failed load first reported success; the log now
        follows the successful load.
        """
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
            logger.info('Load checkpoint %s' % ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])
        self.opt.load_state_dict(obj['opt'])
        self.step = obj['clock']
        # Realign the scheduler with the restored step counter.
        self.sche.last_epoch = self.step

    def print_network(self, model):
        """Print the model structure and its total parameter count."""
        # torch.numel() returns the number of elements in a tensor, i.e. the
        # parameter count of each weight matrix.
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print("The number of parameters: {}".format(num_params))

    def inf_batch(self, name, batch):
        """Run one batch; backprop when *name* == 'train'.

        Only the first stage output carries a training loss; the last stage
        output is returned for visualization.
        """
        if name == 'train':
            self.net.zero_grad()
        if self.step == 0:
            self.print_network(self.net)

        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)

        O_Rs = self.net(O)
        loss_list = [self.crit(O_Rs[0], B)]
        ssim_list = [self.ssim(O_Rs[0], B)]

        if name == 'train':
            loss = loss_list[0]
            loss.backward()
            self.opt.step()

        losses = {
            'loss%d' % i: loss.item()
            for i, loss in enumerate(loss_list)
        }
        ssimes = {
            'ssim%d' % i: ssim.item()
            for i, ssim in enumerate(ssim_list)
        }
        losses.update(ssimes)
        self.write(name, losses)

        return O_Rs[-1]

    def save_image(self, name, img_lists):
        """Tile (input, prediction, target) triplets into one 3x1 grid image
        and write it to the log dir.

        data: rainy input; pred: network output (derained); label: clean
        ground truth.

        Bug fixes: removed the redundant outer `for img_list in img_lists:`
        loop that rebuilt the identical grid three times, and the stray
        debug `print(idx)`.
        """
        data, pred, label = img_lists
        pred = pred.cpu().data

        data, label, pred = data * 255, label * 255, pred * 255
        pred = np.clip(pred, 0, 255)

        h, w = pred.shape[-2:]  # tensors are (N, C, H, W)
        gen_num = (3, 1)  # grid rows x columns of triplets
        img = np.zeros((gen_num[0] * h, gen_num[1] * 3 * w, 3))
        for i in range(gen_num[0]):
            row = i * h
            for j in range(gen_num[1]):
                idx = i * gen_num[1] + j
                tmp_list = [data[idx], pred[idx], label[idx]]
                for k in range(3):
                    col = (j * 3 + k) * w
                    # CHW -> HWC; assumes data/label are CPU-convertible
                    # arrays here — TODO confirm at the caller.
                    tmp = np.transpose(tmp_list[k], (1, 2, 0))
                    img[row:row + h, col:col + w] = tmp

        img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name))
        cv2.imwrite(img_file, img)