# Evaluation session: loads a checkpoint and computes L1, SSIM and PSNR on a test set.
class Session:
    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)
        if len(settings.device_id) > 1:
            self.net = nn.DataParallel(ODE_DerainNet()).cuda()
        else:
            torch.cuda.set_device(settings.device_id[0])
            self.net = ODE_DerainNet().cuda()
        self.l2 = MSELoss().cuda()
        self.l1 = nn.L1Loss().cuda()
        self.ssim = SSIM().cuda()
        self.dataloaders = {}

    def get_dataloader(self, dataset_name):
        dataset = TestDataset(dataset_name)
        if dataset_name not in self.dataloaders:
            self.dataloaders[dataset_name] = \
                DataLoader(dataset, batch_size=1,
                           shuffle=False, num_workers=1, drop_last=False)
        return self.dataloaders[dataset_name]

    def load_checkpoints(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
            logger.info('Load checkpoint %s' % ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])

    def loss_vgg(self, input, groundtruth):
        # Perceptual loss over VGG feature maps; assumes self.vgg is a VGG
        # feature extractor created elsewhere (it is not set in __init__).
        vgg_gt = self.vgg.forward(groundtruth)
        eval = self.vgg.forward(input)
        loss_vgg = [self.l1(eval[m], vgg_gt[m]) for m in range(len(vgg_gt))]
        loss = sum(loss_vgg)
        return loss

    def inf_batch(self, name, batch):
        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
        with torch.no_grad():
            derain = self.net(O)

        l1_loss = self.l1(derain, B)
        ssim = self.ssim(derain, B)
        psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
        losses = {'L1 loss': l1_loss}
        ssimes = {'ssim': ssim}
        losses.update(ssimes)

        return losses, psnr
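# A minimal evaluation sketch using the Session above. The checkpoint name
# 'latest_net' and the dataset split name 'test' are placeholders, not values
# taken from the repository's settings.
def run_eval():
    sess = Session()
    sess.load_checkpoints('latest_net')       # hypothetical checkpoint name
    sess.net.eval()

    all_psnr, all_ssim = [], []
    for batch in sess.get_dataloader('test'): # hypothetical dataset name
        losses, psnr = sess.inf_batch('test', batch)
        all_psnr.append(psnr)
        all_ssim.append(losses['ssim'].item())

    print('avg PSNR: %.4f  avg SSIM: %.4f' %
          (sum(all_psnr) / len(all_psnr), sum(all_ssim) / len(all_ssim)))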
# Training session: runs the optimisation steps, logs scalars to TensorBoard and
# saves/restores checkpoints.
class Session:
    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        self.ssim_loss = settings.ssim_loss
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        ensure_dir('../log_test')
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)
        if len(settings.device_id) > 1:
            self.net = nn.DataParallel(ODE_DerainNet()).cuda()
        else:
            torch.cuda.set_device(settings.device_id[0])
            self.net = ODE_DerainNet().cuda()
        self.l1 = nn.L1Loss().cuda()
        self.mse = nn.MSELoss().cuda()
        self.ssim = SSIM().cuda()
        self.step = 0
        self.save_steps = settings.save_steps
        self.num_workers = settings.num_workers
        self.batch_size = settings.batch_size
        self.writers = {}
        self.dataloaders = {}
        self.opt_net = Adam(self.net.parameters(), lr=settings.lr)
        self.sche_net = MultiStepLR(self.opt_net,
                                    milestones=[settings.l1, settings.l2],
                                    gamma=0.1)

    def tensorboard(self, name):
        self.writers[name] = SummaryWriter(
            os.path.join(self.log_dir, name + '.events'))
        return self.writers[name]

    def write(self, name, out):
        for k, v in out.items():
            self.writers[name].add_scalar(k, v, self.step)
        out['lr'] = self.opt_net.param_groups[0]['lr']
        out['step'] = self.step
        outputs = ["{}:{:.4g}".format(k, v) for k, v in out.items()]
        logger.info(name + '--' + ' '.join(outputs))

    def get_dataloader(self, dataset_name):
        dataset = TrainValDataset(dataset_name)
        if dataset_name not in self.dataloaders:
            self.dataloaders[dataset_name] = \
                DataLoader(dataset, batch_size=self.batch_size,
                           shuffle=True, num_workers=self.num_workers,
                           drop_last=True)
        return iter(self.dataloaders[dataset_name])

    def get_test_dataloader(self, dataset_name):
        dataset = TestDataset(dataset_name)
        if dataset_name not in self.dataloaders:
            self.dataloaders[dataset_name] = \
                DataLoader(dataset, batch_size=1,
                           shuffle=False, num_workers=1, drop_last=False)
        return self.dataloaders[dataset_name]

    def save_checkpoints_net(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        obj = {
            'net': self.net.state_dict(),
            'clock_net': self.step,
            'opt_net': self.opt_net.state_dict(),
        }
        torch.save(obj, ckp_path)

    def load_checkpoints_net(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        try:
            logger.info('Load checkpoint %s' % ckp_path)
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])
        self.opt_net.load_state_dict(obj['opt_net'])
        self.step = obj['clock_net']
        self.sche_net.last_epoch = self.step

    def print_network(self, model):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print("The number of parameters: {}".format(num_params))

    def inf_batch(self, name, batch):
        if name == 'train':
            self.net.zero_grad()
        if self.step == 0:
            self.print_network(self.net)

        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)

        derain = self.net(O)
        l1_loss = self.l1(derain, B)
        mse_loss = self.mse(derain, B)
        ssim = self.ssim(derain, B)
        # Note: both branches assign the MSE loss, so the ssim_loss setting
        # currently has no effect on the training objective.
        if self.ssim_loss:
            loss = mse_loss
        else:
            loss = mse_loss
        if name == 'train':
            loss.backward()
            self.opt_net.step()

        losses = {'L1loss': l1_loss}
        ssimes = {'ssim': ssim}
        losses.update(ssimes)
        self.write(name, losses)

        return derain

    def save_image(self, name, img_lists):
        data, pred, label = img_lists
        pred = pred.cpu().data
        data, label, pred = data * 255, label * 255, pred * 255
        pred = np.clip(pred, 0, 255)
        h, w = pred.shape[-2:]

        gen_num = (1, 1)
        img = np.zeros((gen_num[0] * h, gen_num[1] * 3 * w, 3))
        # Tile input, prediction and ground truth side by side.
        for i in range(gen_num[0]):
            row = i * h
            for j in range(gen_num[1]):
                idx = i * gen_num[1] + j
                tmp_list = [data[idx], pred[idx], label[idx]]
                for k in range(3):
                    col = (j * 3 + k) * w
                    tmp = np.transpose(tmp_list[k], (1, 2, 0))
                    img[row:row + h, col:col + w] = tmp

        img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name))
        cv2.imwrite(img_file, img)

    def inf_batch_test(self, name, batch):
        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)

        with torch.no_grad():
            derain = self.net(O)

        l1_loss = self.l1(derain, B)
        ssim = self.ssim(derain, B)
        psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
        losses = {'L1 loss': l1_loss}
        ssimes = {'ssim': ssim}
        losses.update(ssimes)

        return l1_loss.data.cpu().numpy(), ssim.data.cpu().numpy(), psnr
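# A minimal training-loop sketch driving the training Session above. The
# checkpoint name 'latest_net', the split name 'train' and the total step count
# are placeholders, not values taken from settings.
def run_train():
    sess = Session()
    sess.load_checkpoints_net('latest_net')          # resume if a checkpoint exists
    sess.tensorboard('train')                        # create the writer used by write()

    dt_train = sess.get_dataloader('train')
    while sess.step < 100000:                        # hypothetical total number of steps
        sess.sche_net.step()
        sess.net.train()
        try:
            batch = next(dt_train)
        except StopIteration:
            dt_train = sess.get_dataloader('train')  # restart the exhausted iterator
            batch = next(dt_train)

        sess.inf_batch('train', batch)
        if sess.step % sess.save_steps == 0:
            sess.save_checkpoints_net('latest_net')
        sess.step += 1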
# Show session: runs inference on single test images and writes the derained
# outputs to settings.show_dir.
class Session:
    def __init__(self):
        self.show_dir = settings.show_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.show_dir)
        ensure_dir(settings.model_dir)
        logger.info('set show dir as %s' % settings.show_dir)
        logger.info('set model dir as %s' % settings.model_dir)
        if len(settings.device_id) > 1:
            self.net = nn.DataParallel(ODE_DerainNet()).cuda()
            # self.l2 = nn.DataParallel(MSELoss(), settings.device_id)
            # self.l1 = nn.DataParallel(nn.L1Loss(), settings.device_id)
            # self.ssim = nn.DataParallel(SSIM(), settings.device_id)
            # self.vgg = nn.DataParallel(VGG(), settings.device_id)
        else:
            torch.cuda.set_device(settings.device_id[0])
            self.net = ODE_DerainNet().cuda()
        self.ssim = SSIM().cuda()
        self.dataloaders = {}
        self.a = 0
        self.t = 0

    def get_dataloader(self, dataset_name):
        dataset = ShowDataset(dataset_name)
        self.dataloaders[dataset_name] = \
            DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
        return self.dataloaders[dataset_name]

    def load_checkpoints(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
            logger.info('Load checkpoint %s' % ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])

    def inf_batch(self, name, batch, i):
        O, B, file_name = batch['O'].cuda(), batch['B'].cuda(), batch['file_name']
        file_name = str(file_name[0])
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)

        with torch.no_grad():
            import time
            t0 = time.time()
            derain = self.net(O)
            t1 = time.time()
            comput_time = t1 - t0
            print(comput_time)

        ssim = self.ssim(derain, B).data.cpu().numpy()
        psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
        print('psnr:%.4f-------------ssim:%.4f' % (psnr, ssim))

        return derain, psnr, ssim, file_name

    def save_image(self, No, imgs, name, psnr, ssim, file_name):
        for i, img in enumerate(imgs):
            img = (img.cpu().data * 255).numpy()
            img = np.clip(img, 0, 255)
            img = np.transpose(img, (1, 2, 0))
            img_file = os.path.join(self.show_dir, '%s.png' % file_name)
            print(img_file)
            cv2.imwrite(img_file, img)
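# A minimal sketch of how the show Session might be driven, assuming a
# ShowDataset split named 'test' and a checkpoint named 'latest_net' (both
# placeholders). It saves one derained PNG per test image.
def run_show():
    sess = Session()
    sess.load_checkpoints('latest_net')   # hypothetical checkpoint name
    sess.net.eval()

    dataloader = sess.get_dataloader('test')
    for i, batch in enumerate(dataloader):
        derain, psnr, ssim, file_name = sess.inf_batch('test', batch, i)
        sess.save_image(i, derain, 'test', psnr, ssim, file_name)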