示例#1
0
 def __init__(self):
     """Build a training session from the global ``settings`` object.

     Calls ``ensure_dir`` on the log dir, the model dir and ``../log_test``,
     instantiates ``ODE_DerainNet`` on the GPU(s), the L1/MSE/SSIM modules,
     an Adam optimizer and a MultiStepLR schedule (gamma 0.1).
     """
     cfg = settings
     self.log_dir = cfg.log_dir
     self.model_dir = cfg.model_dir
     self.ssim_loss = cfg.ssim_loss
     for directory in (cfg.log_dir, cfg.model_dir, '../log_test'):
         ensure_dir(directory)
     logger.info('set log dir as %s' % cfg.log_dir)
     logger.info('set model dir as %s' % cfg.model_dir)

     gpu_ids = cfg.device_id
     if len(gpu_ids) <= 1:
         # Select the single GPU first so every later .cuda() lands on it.
         torch.cuda.set_device(gpu_ids[0])
         self.net = ODE_DerainNet().cuda()
     else:
         # Several GPUs: replicate the network with DataParallel.
         self.net = nn.DataParallel(ODE_DerainNet()).cuda()

     # Loss / metric modules.
     self.l1 = nn.L1Loss().cuda()
     self.mse = nn.MSELoss().cuda()
     self.ssim = SSIM().cuda()

     # Bookkeeping and data-loading parameters.
     self.step = 0
     self.save_steps, self.num_workers, self.batch_size = (
         cfg.save_steps, cfg.num_workers, cfg.batch_size)
     self.writers = {}
     self.dataloaders = {}

     self.opt_net = Adam(self.net.parameters(), lr=cfg.lr)
     # settings.l1 / settings.l2 are scheduler milestones here, unrelated to
     # the self.l1 loss module; the LR is multiplied by 0.1 at each one.
     self.sche_net = MultiStepLR(self.opt_net, gamma=0.1,
                                 milestones=[cfg.l1, cfg.l2])
示例#2
0
class Session:
    """Evaluation session: loads a trained ODE_DerainNet and scores test batches.

    All paths and device ids are read from the module-level ``settings`` object.
    """

    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)
        if len(settings.device_id) > 1:
            # Several GPUs: replicate the network with DataParallel.
            self.net = nn.DataParallel(ODE_DerainNet()).cuda()
        else:
            # Single GPU: select it first so later .cuda() calls land on it.
            torch.cuda.set_device(settings.device_id[0])
            self.net = ODE_DerainNet().cuda()
        self.l2 = MSELoss().cuda()
        self.l1 = nn.L1Loss().cuda()
        self.ssim = SSIM().cuda()
        # DataLoader cache keyed by dataset name (see get_dataloader).
        self.dataloaders = {}

    def get_dataloader(self, dataset_name):
        """Return the cached sequential test DataLoader for ``dataset_name``.

        The ``TestDataset`` is only constructed on the first request for a
        given name; later calls return the same cached loader.
        """
        if dataset_name not in self.dataloaders:
            dataset = TestDataset(dataset_name)
            self.dataloaders[dataset_name] = DataLoader(
                dataset, batch_size=1, shuffle=False, num_workers=1,
                drop_last=False)
        return self.dataloaders[dataset_name]

    def load_checkpoints(self, name):
        """Load network weights from ``<model_dir>/<name>``.

        A missing file is logged and ignored, leaving the current weights in
        place. Only the ``'net'`` entry of the checkpoint dict is used.
        """
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        logger.info('Load checkpoint %s' % ckp_path)
        self.net.load_state_dict(obj['net'])

    def loss_vgg(self, input, groundtruth):
        """Sum of L1 distances between corresponding VGG feature maps.

        NOTE(review): ``self.vgg`` is never assigned in ``__init__``; a caller
        must attach an object whose ``forward`` returns an indexable sequence
        of feature maps before using this, otherwise AttributeError is raised.
        """
        gt_feats = self.vgg.forward(groundtruth)
        pred_feats = self.vgg.forward(input)
        return sum(self.l1(pred_feats[m], gt_feats[m])
                   for m in range(len(gt_feats)))

    def inf_batch(self, name, batch):
        """Run the network on one test batch and score it against the target.

        Args:
            name: unused here; kept for interface compatibility.
            batch: mapping with ``'O'`` (network input) and ``'B'`` (target)
                tensors. Values are scaled by 255 before PSNR, so presumably
                in [0, 1] -- TODO confirm.

        Returns:
            ``(losses, psnr)``, where ``losses`` maps ``'L1 loss'`` and
            ``'ssim'`` to the corresponding tensors.
        """
        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B,
                                                          requires_grad=False)

        # Evaluation only: no autograd graph is needed for output or metrics.
        with torch.no_grad():
            derain = self.net(O)
            l1_loss = self.l1(derain, B)
            ssim = self.ssim(derain, B)
        psnr = PSNR(derain.data.cpu().numpy() * 255,
                    B.data.cpu().numpy() * 255)
        losses = {'L1 loss': l1_loss, 'ssim': ssim}

        return losses, psnr
示例#3
0
 def __init__(self):
     """Prepare an evaluation session from the global ``settings`` object.

     Calls ``ensure_dir`` on the log and model directories, then builds
     ``ODE_DerainNet`` on the GPU(s) and the MSE / L1 / SSIM modules.
     """
     log_dir, model_dir = settings.log_dir, settings.model_dir
     self.log_dir = log_dir
     self.model_dir = model_dir
     for path in (log_dir, model_dir):
         ensure_dir(path)
     logger.info('set log dir as %s' % log_dir)
     logger.info('set model dir as %s' % model_dir)

     devices = settings.device_id
     if len(devices) <= 1:
         # Pin the single GPU before anything is moved with .cuda().
         torch.cuda.set_device(devices[0])
         network = ODE_DerainNet().cuda()
     else:
         network = nn.DataParallel(ODE_DerainNet()).cuda()
     self.net = network

     self.l2 = MSELoss().cuda()
     self.l1 = nn.L1Loss().cuda()
     self.ssim = SSIM().cuda()
     self.dataloaders = {}
示例#4
0
    def __init__(self):
        """Prepare a visualisation session from the global ``settings`` object.

        Calls ``ensure_dir`` on ``settings.show_dir`` and
        ``settings.model_dir``, builds ``ODE_DerainNet`` on the GPU(s) and a
        single SSIM module.
        """
        self.show_dir = settings.show_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.show_dir)
        ensure_dir(settings.model_dir)
        logger.info('set show dir as %s' % settings.show_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        if len(settings.device_id) > 1:
            # Several GPUs: only the network is wrapped in DataParallel.
            self.net = nn.DataParallel(ODE_DerainNet()).cuda()
        else:
            # Single GPU: select it before any .cuda() call.
            torch.cuda.set_device(settings.device_id[0])
            self.net = ODE_DerainNet().cuda()
        # One SSIM instance is enough; building a second one is wasted work.
        self.ssim = SSIM().cuda()
        self.dataloaders = {}
        # Counters whose use lies outside this block -- presumably timing /
        # statistics accumulators, TODO confirm.
        self.a = 0
        self.t = 0
示例#5
0
class Session:
    """Training/validation session for ODE_DerainNet.

    Owns the network, loss modules, Adam optimizer, MultiStepLR scheduler,
    TensorBoard writers and cached DataLoaders. All hyper-parameters are read
    from the module-level ``settings`` object.
    """

    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        self.ssim_loss = settings.ssim_loss
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        ensure_dir('../log_test')
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)
        if len(settings.device_id) > 1:
            self.net = nn.DataParallel(ODE_DerainNet()).cuda()
        else:
            # Select the single GPU before any .cuda() call below.
            torch.cuda.set_device(settings.device_id[0])
            self.net = ODE_DerainNet().cuda()
        self.l1 = nn.L1Loss().cuda()
        self.mse = nn.MSELoss().cuda()
        self.ssim = SSIM().cuda()
        self.step = 0
        self.save_steps = settings.save_steps
        self.num_workers = settings.num_workers
        self.batch_size = settings.batch_size
        self.writers = {}      # SummaryWriter per run name (see tensorboard)
        self.dataloaders = {}  # DataLoader per dataset name, shared by train/test getters
        self.opt_net = Adam(self.net.parameters(), lr=settings.lr)
        # settings.l1 / settings.l2 are LR milestones (unrelated to self.l1);
        # the learning rate is multiplied by 0.1 at each one.
        self.sche_net = MultiStepLR(self.opt_net,
                                    milestones=[settings.l1, settings.l2],
                                    gamma=0.1)

    def tensorboard(self, name):
        """Create (or replace) the SummaryWriter for ``name`` and return it.

        The writer's log directory is ``<log_dir>/<name>.events``.
        """
        self.writers[name] = SummaryWriter(
            os.path.join(self.log_dir, name + '.events'))
        return self.writers[name]

    def write(self, name, out):
        """Log the scalars in ``out`` to the ``name`` writer and to ``logger``.

        ``tensorboard(name)`` must have been called first. ``out`` is mutated:
        ``'lr'`` and ``'step'`` are added after the TensorBoard write, so they
        appear in the text log only.
        """
        for k, v in out.items():
            self.writers[name].add_scalar(k, v, self.step)
        out['lr'] = self.opt_net.param_groups[0]['lr']
        out['step'] = self.step
        outputs = ["{}:{:.4g}".format(k, v) for k, v in out.items()]
        logger.info(name + '--' + ' '.join(outputs))

    def get_dataloader(self, dataset_name):
        """Return a new iterator over the cached training DataLoader.

        The ``TrainValDataset`` and its DataLoader are built only on the first
        call for ``dataset_name``; every call returns a fresh iterator, i.e. a
        new pass over the data.
        """
        if dataset_name not in self.dataloaders:
            dataset = TrainValDataset(dataset_name)
            self.dataloaders[dataset_name] = DataLoader(
                dataset, batch_size=self.batch_size, shuffle=True,
                num_workers=self.num_workers, drop_last=True)
        return iter(self.dataloaders[dataset_name])

    def get_test_dataloader(self, dataset_name):
        """Return the cached sequential (batch size 1) test DataLoader.

        The ``TestDataset`` is built only on the first call for
        ``dataset_name``. The cache is shared with ``get_dataloader``, so the
        same name must not be used for both.
        """
        if dataset_name not in self.dataloaders:
            dataset = TestDataset(dataset_name)
            self.dataloaders[dataset_name] = DataLoader(
                dataset, batch_size=1, shuffle=False, num_workers=1,
                drop_last=False)
        return self.dataloaders[dataset_name]

    def save_checkpoints_net(self, name):
        """Save network, optimizer state and step counter to ``<model_dir>/<name>``."""
        ckp_path = os.path.join(self.model_dir, name)
        obj = {
            'net': self.net.state_dict(),
            'clock_net': self.step,
            'opt_net': self.opt_net.state_dict(),
        }
        torch.save(obj, ckp_path)

    def load_checkpoints_net(self, name):
        """Restore a checkpoint written by ``save_checkpoints_net``.

        A missing file is logged and ignored. On success the network and
        optimizer states, the step counter and the scheduler position are
        restored.
        """
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        logger.info('Load checkpoint %s' % ckp_path)
        self.net.load_state_dict(obj['net'])
        self.opt_net.load_state_dict(obj['opt_net'])
        self.step = obj['clock_net']
        self.sche_net.last_epoch = self.step

    def print_network(self, model):
        """Print ``model`` followed by its total number of parameters."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print("The number of parameters: {}".format(num_params))

    def inf_batch(self, name, batch):
        """Forward one batch and log L1/SSIM; when ``name == 'train'`` also step.

        Returns the network output for the batch.
        """
        if name == 'train':
            self.net.zero_grad()
        if self.step == 0:
            self.print_network(self.net)

        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B,
                                                          requires_grad=False)
        derain = self.net(O)
        l1_loss = self.l1(derain, B)
        mse_loss = self.mse(derain, B)
        ssim = self.ssim(derain, B)
        # NOTE(review): ``self.ssim_loss`` currently has no effect -- MSE is
        # always the training objective. TODO confirm the intended SSIM-based
        # loss before wiring the flag in.
        loss = mse_loss
        if name == 'train':
            loss.backward()
            self.opt_net.step()
        self.write(name, {'L1loss': l1_loss, 'ssim': ssim})

        return derain

    def save_image(self, name, img_lists):
        """Save ``data | pred | label`` side by side as ``<step>_<name>.jpg``.

        Args:
            name: tag used in the output filename (under ``log_dir``).
            img_lists: ``(data, pred, label)`` channel-first batches; the
                per-sample ``(1, 2, 0)`` transpose into a 3-channel canvas
                implies ``(N, 3, H, W)``. Values are scaled by 255, so
                presumably in [0, 1]. Only sample 0 is drawn. ``pred`` is moved
                to CPU here; ``data``/``label`` are assumed already on CPU --
                TODO confirm.
        """
        data, pred, label = img_lists
        pred = pred.cpu().data

        data, label, pred = data * 255, label * 255, pred * 255
        pred = np.clip(pred, 0, 255)
        h, w = pred.shape[-2:]
        gen_num = (1, 1)  # samples drawn: rows x columns
        img = np.zeros((gen_num[0] * h, gen_num[1] * 3 * w, 3))
        # A single pass fills the canvas: three panels per sample.
        for i in range(gen_num[0]):
            row = i * h
            for j in range(gen_num[1]):
                idx = i * gen_num[1] + j
                panels = [data[idx], pred[idx], label[idx]]
                for k, panel in enumerate(panels):
                    col = (j * 3 + k) * w
                    img[row:row + h, col:col + w] = np.transpose(panel, (1, 2, 0))
        img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name))
        cv2.imwrite(img_file, img)

    def inf_batch_test(self, name, batch):
        """Evaluate one batch without gradients.

        Returns ``(l1, ssim, psnr)``: L1 and SSIM converted to numpy arrays,
        and whatever ``PSNR`` returns for the 255-scaled output and target.
        """
        O, B = batch['O'].cuda(), batch['B'].cuda()
        O, B = Variable(O, requires_grad=False), Variable(B,
                                                          requires_grad=False)

        with torch.no_grad():
            derain = self.net(O)

        l1_loss = self.l1(derain, B)
        ssim = self.ssim(derain, B)
        psnr = PSNR(derain.data.cpu().numpy() * 255,
                    B.data.cpu().numpy() * 255)

        return l1_loss.data.cpu().numpy(), ssim.data.cpu().numpy(), psnr
示例#6
0
class Session:
    """Visualisation session: run a trained ODE_DerainNet and save its outputs.

    Paths and device ids are read from the module-level ``settings`` object.
    """

    def __init__(self):
        self.show_dir = settings.show_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.show_dir)
        ensure_dir(settings.model_dir)
        logger.info('set show dir as %s' % settings.show_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        if len(settings.device_id) > 1:
            # Several GPUs: only the network is wrapped in DataParallel.
            self.net = nn.DataParallel(ODE_DerainNet()).cuda()
        else:
            # Single GPU: select it before any .cuda() call.
            torch.cuda.set_device(settings.device_id[0])
            self.net = ODE_DerainNet().cuda()
        # One SSIM instance is enough; building a second one is wasted work.
        self.ssim = SSIM().cuda()
        self.dataloaders = {}
        # Counters whose use lies outside this block -- TODO confirm purpose.
        self.a = 0
        self.t = 0

    def get_dataloader(self, dataset_name):
        """Build a sequential (batch size 1) DataLoader over ``ShowDataset``.

        A new loader is created and stored on every call; no cached loader is
        reused.
        """
        dataset = ShowDataset(dataset_name)
        self.dataloaders[dataset_name] = DataLoader(
            dataset, batch_size=1, shuffle=False, num_workers=1)
        return self.dataloaders[dataset_name]

    def load_checkpoints(self, name):
        """Load network weights from ``<model_dir>/<name>``.

        A missing file is logged and ignored. Only the ``'net'`` entry of the
        checkpoint dict is used.
        """
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        logger.info('Load checkpoint %s' % ckp_path)
        self.net.load_state_dict(obj['net'])

    def inf_batch(self, name, batch, i):
        """Run the network on one batch, print timing and metrics, return them.

        Args:
            name, i: unused; kept for interface compatibility.
            batch: mapping with ``'O'`` (network input), ``'B'`` (target) and
                ``'file_name'`` (sequence whose first entry names the sample).

        Returns:
            ``(derain, psnr, ssim, file_name)``.
        """
        import time

        O, B = batch['O'].cuda(), batch['B'].cuda()
        file_name = str(batch['file_name'][0])
        O, B = Variable(O, requires_grad=False), Variable(B,
                                                          requires_grad=False)
        with torch.no_grad():
            # CUDA kernels launch asynchronously: synchronize on both sides of
            # the forward pass so the timing covers the real computation.
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            derain = self.net(O)
            torch.cuda.synchronize()
            comput_time = time.perf_counter() - t0
            print(comput_time)
            ssim = self.ssim(derain, B).data.cpu().numpy()
            psnr = PSNR(derain.data.cpu().numpy() * 255,
                        B.data.cpu().numpy() * 255)
            print('psnr:%.4f-------------ssim:%.4f' % (psnr, ssim))
        return derain, psnr, ssim, file_name

    def save_image(self, No, imgs, name, psnr, ssim, file_name):
        """Write each image in ``imgs`` as a PNG under ``show_dir``.

        A single image is saved as ``<file_name>.png``. When ``imgs`` holds
        several images, each one gets an index suffix
        (``<file_name>_<i>.png``) so they do not overwrite one another.
        Images are channel-first (transposed with ``(1, 2, 0)``) and scaled
        by 255, then clipped to [0, 255]. ``No``, ``name``, ``psnr`` and
        ``ssim`` are unused; they are kept for interface compatibility.
        """
        multiple = len(imgs) > 1
        for i, img in enumerate(imgs):
            img = (img.cpu().data * 255).numpy()
            img = np.clip(img, 0, 255)
            img = np.transpose(img, (1, 2, 0))
            if multiple:
                img_file = os.path.join(self.show_dir, '%s_%d.png' % (file_name, i))
            else:
                img_file = os.path.join(self.show_dir, '%s.png' % (file_name))
            print(img_file)
            cv2.imwrite(img_file, img)