Example #1
    def predict_deraining(self):
        # Deraining
        self.DNet.eval()

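        # process the test clip in temporal chunks of `truncate_test` frames
        # to bound GPU memory; derained chunks are stitched together below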
        current_data_list = [self.test_data, self.test_data_semi] \
            if self.train_path_semi else [self.test_data]
        for kk, current_data in enumerate(current_data_list):
            num_frame = current_data.shape[1]
            test_data_derain = torch.zeros(current_data.shape)  # c x n x p x p
            for ii in range(ceil(num_frame / self.truncate_test)):
                start_ind = ii * self.truncate_test
                end_ind = min((ii + 1) * self.truncate_test, num_frame)
                inputs = current_data[:, start_ind:end_ind, ].cuda()  # c x truncate x p x p
                with torch.set_grad_enabled(False):
                    out = self.DNet(inputs.unsqueeze(0)).clamp_(0.0, 1.0).squeeze(0)
                test_data_derain[:, start_ind:end_ind, ] = out.cpu()

                # the semi-supervised test set (kk == 1) is logged on every
                # chunk; the paired test set only on ~10% of chunks
                if (len(current_data_list) == 2 and kk == 1) \
                        or random.randint(1, 10) == 1:
                    x1 = vutils.make_grid(inputs.permute([1, 0, 2, 3]),
                                          normalize=True,
                                          scale_each=True)
                    self.writer.add_image('Test Rainy Image', x1,
                                          self.log_im_step['test'])
                    x2 = vutils.make_grid(out.permute([1, 0, 2, 3]),
                                          normalize=True,
                                          scale_each=True)
                    self.writer.add_image('Test Derained Image', x2,
                                          self.log_im_step['test'])
                    self.log_im_step['test'] += 1

            if kk == 0:
                self.psnrm = batch_PSNR(
                    test_data_derain[:, 2:-2, ].permute([1, 0, 2, 3]),
                    self.test_gt[:, 2:-2, ].permute([1, 0, 2, 3]),
                    ycbcr=False)
                self.ssimm = batch_SSIM(
                    test_data_derain[:, 2:-2, ].permute([1, 0, 2, 3]),
                    self.test_gt[:, 2:-2, ].permute([1, 0, 2, 3]),
                    ycbcr=False)
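
# `batch_PSNR` / `batch_SSIM` above are project helpers not shown on this
# page. For orientation only, a batched PSNR might be sketched like this
# (an assumption, not the project's implementation; the `ycbcr` luminance
# conversion is omitted):
import torch

def batch_psnr_sketch(img, imclean):
    # img, imclean: n x c x h x w tensors with values in [0, 1]
    mse = ((img - imclean) ** 2).flatten(1).mean(dim=1)  # per-image MSE
    return (10.0 * torch.log10(1.0 / mse)).mean().item()  # mean PSNR in dB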
Example #2
def train_model(net, datasets, optimizer, lr_scheduler, criterion):
    clip_grad_D = args.clip_grad_D
    clip_grad_S = args.clip_grad_S
    batch_size = {'train':args.batch_size, 'test_SIDD':4}
    data_loader = {phase:torch.utils.data.DataLoader(datasets[phase], batch_size=batch_size[phase],
         shuffle=True, num_workers=args.num_workers, pin_memory=True) for phase in _modes}
    num_data = {phase:len(datasets[phase]) for phase in _modes}
    num_iter_epoch = {phase: ceil(num_data[phase] / batch_size[phase]) for phase in _modes}
    if args.resume:
        step = args.step
        step_img = args.step_img
    else:
        step = 0
        step_img = {x:0 for x in _modes}
    param_D = [x for name, x in net.named_parameters() if 'dnet' in name.lower()]
    param_S = [x for name, x in net.named_parameters() if 'snet' in name.lower()]
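    # split the parameters into D-Net (denoiser) and S-Net (sigma estimator)
    # groups so that their gradient norms can be clipped separately below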
    writer = SummaryWriter(args.log_dir)
    for epoch in range(args.epoch_start, args.epochs):
        loss_per_epoch = {x:0 for x in ['Loss', 'lh', 'KLG', 'KLIG']}
        mse_per_epoch = {x:0 for x in _modes}
        grad_norm_D = grad_norm_S = 0
        tic = time.time()
        # train stage
        net.train()
        lr = optimizer.param_groups[0]['lr']
        if lr < _lr_min:
            sys.exit('Reached the minimal learning rate')
        phase = 'train'
        for ii, data in enumerate(data_loader[phase]):
            im_noisy, im_gt, sigmaMap, eps2 = [x.cuda() for x in data]
            optimizer.zero_grad()
            phi_Z, phi_sigma = net(im_noisy, 'train')
            loss, g_lh, kl_g, kl_Igam = criterion(phi_Z, phi_sigma, im_noisy, im_gt, sigmaMap,
                                                                           eps2, radius=args.radius)

            loss.backward()
            # clip the gradient norm of D-Net
            total_norm_D = nn.utils.clip_grad_norm_(param_D, clip_grad_D)
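            # running mean of the gradient norm over the epoch:
            # mean_{i+1} = mean_i * i/(i+1) + norm_i/(i+1)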
            grad_norm_D = (grad_norm_D*(ii/(ii+1)) + total_norm_D/(ii+1))
            # clip the gradient norm of S-Net
            total_norm_S = nn.utils.clip_grad_norm_(param_S, clip_grad_S)
            grad_norm_S = (grad_norm_S*(ii/(ii+1)) + total_norm_S/(ii+1))
            optimizer.step()

            loss_per_epoch['Loss'] += loss.item() / num_iter_epoch[phase]
            loss_per_epoch['lh'] += g_lh.item() / num_iter_epoch[phase]
            loss_per_epoch['KLG'] += kl_g.item() / num_iter_epoch[phase]
            loss_per_epoch['KLIG'] += kl_Igam.item() / num_iter_epoch[phase]
            im_denoise = im_noisy - phi_Z[:, :_C, ].detach()
            im_denoise.clamp_(0.0, 1.0)
            mse = F.mse_loss(im_denoise, im_gt)
            mse_per_epoch[phase] += mse.item()
            if (ii+1) % args.print_freq == 0:
                log_str = '[Epoch:{:>2d}/{:<2d}] {:s}:{:0>4d}/{:0>4d}, lh={:+4.2f}, ' + \
                                 'KLG={:+>7.2f}, KLIG={:+>6.2f}, mse={:.2e}, GD:{:.1e}/{:.1e}, ' + \
                                                                       'GS:{:.1e}/{:.1e}, lr={:.1e}'
                print(log_str.format(epoch+1, args.epochs, phase, ii+1, num_iter_epoch[phase],
                                         g_lh.item(), kl_g.item(), kl_Igam.item(), mse, clip_grad_D,
                                                       total_norm_D, clip_grad_S, total_norm_S, lr))
                writer.add_scalar('Train Loss Iter', loss.item(), step)
                writer.add_scalar('Train MSE Iter', mse, step)
                step += 1
            if (ii+1) % (20*args.print_freq) == 0:
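                # phi_sigma stacks log(alpha) and log(beta) of the predicted
                # inverse-gamma posterior; its mean beta/(alpha - 1) serves as
                # the estimated noise variance map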
                alpha = torch.exp(phi_sigma[:, :_C,])
                beta = torch.exp(phi_sigma[:, _C:,])
                sigmaMap_pred = beta / (alpha-1)
                x1 = vutils.make_grid(im_denoise, normalize=True, scale_each=True)
                writer.add_image(phase+' Denoised images', x1, step_img[phase])
                x2 = vutils.make_grid(im_gt, normalize=True, scale_each=True)
                writer.add_image(phase+' GroundTruth', x2, step_img[phase])
                x3 = vutils.make_grid(sigmaMap_pred, normalize=True, scale_each=True)
                writer.add_image(phase+' Predict Sigma', x3, step_img[phase])
                x4 = vutils.make_grid(sigmaMap, normalize=True, scale_each=True)
                writer.add_image(phase+' Groundtruth Sigma', x4, step_img[phase])
                x5 = vutils.make_grid(im_noisy, normalize=True, scale_each=True)
                writer.add_image(phase+' Noisy Image', x5, step_img[phase])
                step_img[phase] += 1

        mse_per_epoch[phase] /= (ii+1)
        log_str = '{:s}: Loss={:+.2e}, lh={:+.2e}, KL_Gauss={:+.2e}, KLIG={:+.2e}, mse={:.3e}, ' + \
                                                      'GNorm_D={:.1e}/{:.1e}, GNorm_S={:.1e}/{:.1e}'
        print(log_str.format(phase, loss_per_epoch['Loss'], loss_per_epoch['lh'],
                                loss_per_epoch['KLG'], loss_per_epoch['KLIG'], mse_per_epoch[phase],
                                                clip_grad_D, grad_norm_D, clip_grad_S, grad_norm_S))
        writer.add_scalar('Loss_epoch', loss_per_epoch['Loss'], epoch)
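        # adaptively tighten the clipping thresholds toward the mean gradient
        # norms observed in this epoch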
        clip_grad_D = min(clip_grad_D, grad_norm_D)
        clip_grad_S = min(clip_grad_S, grad_norm_S)
        print('-'*150)

        # test stage
        net.eval()
        psnr_per_epoch = {x:0 for x in _modes[1:]}
        ssim_per_epoch = {x:0 for x in _modes[1:]}
        for phase in _modes[1:]:
            for ii, data in enumerate(data_loader[phase]):
                im_noisy, im_gt = [x.cuda() for x in data]
                with torch.set_grad_enabled(False):
                    phi_Z, phi_sigma = net(im_noisy, 'train')

                im_denoise = im_noisy - phi_Z[:, :_C, ]
                im_denoise.clamp_(0.0, 1.0)
                mse = F.mse_loss(im_denoise, im_gt)
                mse_per_epoch[phase] += mse.item()
                psnr_iter = batch_PSNR(im_denoise, im_gt)
                ssim_iter = batch_SSIM(im_denoise, im_gt)
                psnr_per_epoch[phase] += psnr_iter
                ssim_per_epoch[phase] += ssim_iter
                # print statistics every log_interval mini_batches
                if (ii+1) % 20 == 0:
                    log_str = '[Epoch:{:>2d}/{:<2d}] {:s}:{:0>3d}/{:0>3d}, mse={:.2e}, ' + \
                                                                        'psnr={:4.2f}, ssim={:5.4f}'
                    print(log_str.format(epoch+1, args.epochs, phase, ii+1, num_iter_epoch[phase],
                                                                         mse, psnr_iter, ssim_iter))
                    # tensorboardX summary
                    alpha = torch.exp(phi_sigma[:, :_C,])
                    beta = torch.exp(phi_sigma[:, _C:,])
                    sigmaMap_pred = beta / (alpha-1)
                    x1 = vutils.make_grid(im_denoise, normalize=True, scale_each=True)
                    writer.add_image(phase+' Denoised images', x1, step_img[phase])
                    x2 = vutils.make_grid(im_gt, normalize=True, scale_each=True)
                    writer.add_image(phase+' GroundTruth', x2, step_img[phase])
                    x3 = vutils.make_grid(sigmaMap_pred, normalize=True, scale_each=True)
                    writer.add_image(phase+' Predict Sigma', x3, step_img[phase])
                    x5 = vutils.make_grid(im_noisy, normalize=True, scale_each=True)
                    writer.add_image(phase+' Noisy Image', x5, step_img[phase])
                    step_img[phase] += 1

            psnr_per_epoch[phase] /= (ii+1)
            ssim_per_epoch[phase] /= (ii+1)
            mse_per_epoch[phase] /= (ii+1)
            log_str = '{:s}: mse={:.3e}, PSNR={:4.2f}, SSIM={:5.4f}'
            print(log_str.format(phase, mse_per_epoch[phase], psnr_per_epoch[phase],
                                                                            ssim_per_epoch[phase]))
            print('-'*90)

        # adjust the learning rate
        lr_scheduler.step()
        # save model
        if (epoch+1) % args.save_model_freq == 0 or epoch+1==args.epochs:
            model_prefix = 'model_'
            save_path_model = os.path.join(args.model_dir, model_prefix+str(epoch+1))
            torch.save({
                'epoch': epoch+1,
                'step': step+1,
                'step_img': {x:step_img[x] for x in _modes},
                'grad_norm_D': clip_grad_D,
                'grad_norm_S': clip_grad_S,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'lr_scheduler_state_dict': lr_scheduler.state_dict()
                }, save_path_model)
            model_state_prefix = 'model_state_'
            save_path_model_state = os.path.join(args.model_dir, model_state_prefix+str(epoch+1))
            torch.save(net.state_dict(), save_path_model_state)

        writer.add_scalars('MSE_epoch', mse_per_epoch, epoch)
        writer.add_scalars('PSNR_epoch_test', psnr_per_epoch, epoch)
        writer.add_scalars('SSIM_epoch_test', ssim_per_epoch, epoch)
        toc = time.time()
        print('This epoch took {:.2f} seconds'.format(toc-tic))
    writer.close()
    print('Reached the maximal number of epochs. Training finished.')
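
# For context, one plausible way to drive this function (VDN, load_datasets
# and loss_fn are placeholders, not names from this page):
#
#     net = VDN(in_channels=3).cuda()
#     optimizer = torch.optim.Adam(net.parameters(), lr=2e-4)
#     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
#                                                      milestones=[20, 40],
#                                                      gamma=0.5)
#     train_model(net, load_datasets(), optimizer, scheduler, loss_fn)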
Example #3
        # (the opening of this snippet is truncated on the source page; the
        # enclosing loop is reconstructed from the indentation and the index
        # lists used below)
        for ii in range(len(inds_start)):
            start_ext, end_ext, start, end = [
                x[ii]
                for x in [inds_ext_start, inds_ext_end, inds_start, inds_end]
            ]
            inputs = rain_data[:, :, start_ext:end_ext, :, :].cuda()
            out_temp = model(inputs)
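            # adjacent windows overlap by two frames on each side; the
            # overlapping border frames are discarded when stitching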
            if ii == 0:
                derain_data[0, :, start:end, ] = \
                    out_temp[:, :, :-2, ].cpu().clamp_(0.0, 1.0)
            elif (ii + 1) == len(inds_start):
                derain_data[0, :, start:end, ] = \
                    out_temp[:, :, 2:, ].cpu().clamp_(0.0, 1.0)
            else:
                derain_data[0, :, start:end, ] = \
                    out_temp[:, :, 2:-2, ].cpu().clamp_(0.0, 1.0)

    derain_data = derain_data[:, :, 2:-2, ].squeeze(0).permute([1, 0, 2, 3])
    gt_data = gt_data[2:-2, ]
    psnrm_y = batch_PSNR(derain_data, gt_data, ycbcr=True)
    psnr_all_y.append(psnrm_y)
    ssimm_y = batch_SSIM(derain_data, gt_data, ycbcr=True)
    ssim_all_y.append(ssimm_y)
    print('Type:{:s}, PSNR:{:5.2f}, SSIM:{:6.4f}'.format(
        current_type, psnrm_y, ssimm_y))

mean_psnr_y = sum(psnr_all_y) / len(rain_types)
mean_ssim_y = sum(ssim_all_y) / len(rain_types)
print('MPSNR:{:5.2f}, MSSIM:{:6.4f}'.format(mean_psnr_y, mean_ssim_y))
Example #4
def train_model(net, netG, datasets, optimizer, lr_scheduler, args):
    NReal = ceil(args['batch_size'] / (1 + args['fake_ratio']))
    batch_size = {'train': args['batch_size'], 'val': 4}
    data_loader = {
        phase: uData.DataLoader(datasets[phase],
                                batch_size=batch_size[phase],
                                shuffle=True,
                                num_workers=args['num_workers'],
                                pin_memory=True)
        for phase in _modes
    }
    num_data = {phase: len(datasets[phase]) for phase in _modes}
    num_iter_epoch = {
        phase: ceil(num_data[phase] / batch_size[phase])
        for phase in _modes
    }
    step = args['step'] if args['resume'] else 0
    step_img = args['step_img'] if args['resume'] else {x: 0 for x in _modes}
    writer = SummaryWriter(str(Path(args['log_dir'])))
    clip_grad = args['clip_normD']
    for epoch in range(args['epoch_start'], args['epochs']):
        mae_per_epoch = {x: 0 for x in _modes}
        tic = time.time()
        # train stage
        net.train()
        lr = optimizer.param_groups[0]['lr']
        grad_mean = 0
        phase = 'train'
        for ii, data in enumerate(data_loader[phase]):
            im_noisy, im_gt = [x.cuda() for x in data]
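            # the first NReal images in the batch keep their real noise; the
            # remainder are re-synthesized from the clean images by netG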
            with torch.no_grad():
                im_noisy[NReal:, ] = sample_generator(netG, im_gt[NReal:, ])
                im_noisy[NReal:, ].clamp_(0.0, 1.0)
            optimizer.zero_grad()
            im_denoise = im_noisy - net(im_noisy)
            loss = F.l1_loss(im_denoise, im_gt,
                             reduction='sum') / im_gt.shape[0]

            # backpropagation
            loss.backward()
            # clip the grad
            total_grad = nn.utils.clip_grad_norm_(net.parameters(), clip_grad)
            grad_mean = grad_mean * ii / (ii + 1) + total_grad / (ii + 1)
            optimizer.step()

            mae_iter = loss.item() / (im_gt.shape[1] * im_gt.shape[2] *
                                      im_gt.shape[3])
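            # loss is the per-image L1 sum; dividing by c*h*w yields the
            # per-pixel mean absolute error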
            mae_per_epoch[phase] += mae_iter
            if (ii + 1) % args['print_freq'] == 0:
                template = '[Epoch:{:>2d}/{:<2d}] {:s}:{:0>5d}/{:0>5d}, Loss={:5.2e}, ' + \
                                                                     'Grad:{:.2e}/{:.2e}, lr={:.2e}'
                print(
                    template.format(epoch + 1, args['epochs'], phase, ii + 1,
                                    num_iter_epoch[phase], mae_iter, clip_grad,
                                    total_grad, lr))
                writer.add_scalar('Train Loss Iter', mae_iter, step)
                step += 1
            if (ii + 1) % (20 * args['print_freq']) == 0:
                x1 = vutils.make_grid(im_denoise,
                                      normalize=True,
                                      scale_each=True)
                writer.add_image(phase + ' Denoised images', x1,
                                 step_img[phase])
                x2 = vutils.make_grid(im_gt, normalize=True, scale_each=True)
                writer.add_image(phase + ' GroundTruth', x2, step_img[phase])
                x3 = vutils.make_grid(im_noisy,
                                      normalize=True,
                                      scale_each=True)
                writer.add_image(phase + ' Noisy Image', x3, step_img[phase])
                step_img[phase] += 1

        mae_per_epoch[phase] /= (ii + 1)
        clip_grad = min(grad_mean, clip_grad)
        print('{:s}: Loss={:+.2e}, grad_mean={:.2e}'.format(
            phase, mae_per_epoch[phase], grad_mean))
        print('-' * 100)

        # test stage
        net.eval()
        psnr_per_epoch = ssim_per_epoch = 0
        phase = 'val'
        for ii, data in enumerate(data_loader[phase]):
            im_noisy, im_gt = [x.cuda() for x in data]
            with torch.set_grad_enabled(False):
                im_denoise = im_noisy - net(im_noisy)

            im_denoise.clamp_(0.0, 1.0)
            mae_iter = F.l1_loss(im_denoise, im_gt).item()
            mae_per_epoch[phase] += mae_iter
            psnr_iter = batch_PSNR(im_denoise, im_gt)
            psnr_per_epoch += psnr_iter
            ssim_iter = batch_SSIM(im_denoise, im_gt)
            ssim_per_epoch += ssim_iter
            # print statistics every log_interval mini_batches
            if (ii + 1) % 50 == 0:
                log_str = '[Epoch:{:>2d}/{:<2d}] {:s}:{:0>3d}/{:0>3d}, mae={:.2e}, ' + \
                                                                    'psnr={:4.2f}, ssim={:5.4f}'
                print(
                    log_str.format(epoch + 1, args['epochs'], phase, ii + 1,
                                   num_iter_epoch[phase], mae_iter, psnr_iter,
                                   ssim_iter))
                # tensorboard summary
                x1 = vutils.make_grid(im_denoise,
                                      normalize=True,
                                      scale_each=True)
                writer.add_image(phase + ' Denoised images', x1,
                                 step_img[phase])
                x2 = vutils.make_grid(im_gt, normalize=True, scale_each=True)
                writer.add_image(phase + ' GroundTruth', x2, step_img[phase])
                x5 = vutils.make_grid(im_noisy,
                                      normalize=True,
                                      scale_each=True)
                writer.add_image(phase + ' Noisy Image', x5, step_img[phase])
                step_img[phase] += 1

        psnr_per_epoch /= (ii + 1)
        ssim_per_epoch /= (ii + 1)
        mae_per_epoch[phase] /= (ii + 1)
        print('{:s}: mae={:.3e}, PSNR={:4.2f}, SSIM={:5.4f}'.format(
            phase, mae_per_epoch[phase], psnr_per_epoch, ssim_per_epoch))
        print('-' * 100)

        # adjust the learning rate
        lr_scheduler.step()
        # save model
        save_path_model = str(
            Path(args['model_dir']) / ('model_' + str(epoch + 1)))
        torch.save(
            {
                'epoch': epoch + 1,
                'step': step + 1,
                'step_img': {x: step_img[x] + 1
                             for x in _modes},
                'clip_grad': clip_grad,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'lr_scheduler_state_dict': lr_scheduler.state_dict()
            }, save_path_model)
        save_path_model = str(
            Path(args['model_dir']) /
            ('model_state_' + str(epoch + 1) + '.pt'))
        torch.save(net.state_dict(), save_path_model)

        writer.add_scalars('MAE_epoch', mae_per_epoch, epoch)
        writer.add_scalar('Val PSNR epoch', psnr_per_epoch, epoch)
        writer.add_scalar('Val SSIM epoch', ssim_per_epoch, epoch)
        toc = time.time()
        print('This epoch took {:.2f} seconds'.format(toc - tic))
    writer.close()
    print('Reached the maximal number of epochs. Training finished.')