Example #1
def train_step_P(net, x, y, optimizerP, args):
    alpha = args['alpha']
    batch_size = x.shape[0]
    # zero the gradient
    net['P'].zero_grad()
    # real data
    real_data = torch.cat([x, y], 1)
    real_loss = net['P'](real_data).mean()
    # fake noisy data from the generator
    with torch.autograd.no_grad():
        fake_y = sample_generator(net['G'], x)
        fake_y_data = torch.cat([x, fake_y], 1)
    fake_y_loss = net['P'](fake_y_data).mean()
    grad_y_loss = gradient_penalty(real_data, fake_y_data, net['P'],
                                   args['lambda_gp'])
    loss_y = alpha * (fake_y_loss - real_loss)
    loss_yg = alpha * grad_y_loss
    # fake clean data from the denoiser
    with torch.autograd.no_grad():
        fake_x = y - net['D'](y)
        fake_x_data = torch.cat([fake_x, y], 1)
    fake_x_loss = net['P'](fake_x_data).mean()
    grad_x_loss = gradient_penalty(real_data, fake_x_data, net['P'],
                                   args['lambda_gp'])
    loss_x = (1 - alpha) * (fake_x_loss - real_loss)
    loss_xg = (1 - alpha) * grad_x_loss
    loss = loss_x + loss_xg + loss_y + loss_yg
    # backward
    loss.backward()
    optimizerP.step()

    return loss, loss_x, loss_xg, loss_y, loss_yg
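
Both discriminator steps on this page (Examples #1 and #4) rely on a `gradient_penalty` helper that is not reproduced here. For reference, below is a minimal sketch of a standard WGAN-GP penalty matching that call signature; the interpolation scheme and the reduction are assumptions, not necessarily what the original repository implements.

import torch
import torch.autograd as autograd

def gradient_penalty(real_data, fake_data, netP, lambda_gp):
    # a sketch of the usual WGAN-GP term, matching the call sites above
    batch_size = real_data.shape[0]
    # sample a random point on the segment between each real/fake pair
    eps = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    interp = (eps * real_data + (1 - eps) * fake_data).requires_grad_(True)
    out = netP(interp)
    grads = autograd.grad(outputs=out, inputs=interp,
                          grad_outputs=torch.ones_like(out),
                          create_graph=True, retain_graph=True)[0]
    grads = grads.view(batch_size, -1)
    # penalize deviation of the gradient norm from 1
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()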
Example #2
def train_step_G(net, x, y, optimizerG, args):
    alpha = args['alpha']
    batch_size = x.shape[0]
    # zero the gradient
    net['G'].zero_grad()
    fake_y = sample_generator(net['G'], x)
    # NOTE: `kernel` (smoothing kernel) and `_C` (channel count) are
    # module-level globals in the original code
    loss_mean = args['tau_G'] * mean_match(x, y, fake_y, kernel.to(x.device),
                                           _C)
    fake_y_data = torch.cat([x, fake_y], 1)
    fake_y_loss = net['P'](fake_y_data).mean()
    loss_y = -alpha * fake_y_loss
    loss = loss_y + loss_mean
    # backward
    loss.backward()
    optimizerG.step()

    return loss, loss_y, loss_mean, fake_y.detach()
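
For context, `train_step_P` and `train_step_G` are meant to be alternated inside a GAN training loop. The sketch below shows one plausible driver; the `n_critic` schedule and the surrounding names (`train_loader`, `optimizerP`, `optimizerG`) are illustrative assumptions.

# hypothetical alternation of the two step functions
n_critic = 3  # assumed number of P updates per G update
for ii, (x, y) in enumerate(train_loader):
    x, y = x.cuda(), y.cuda()
    loss_P, *_ = train_step_P(net, x, y, optimizerP, args)
    if (ii + 1) % n_critic == 0:
        loss_G, loss_y, loss_mean, fake_y = train_step_G(net, x, y,
                                                         optimizerG, args)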
Example #3
def train_step_G(net, x, y, optimizerG, args): # Noise generator - residual only
    alpha = args['alpha']
    batch_size = x.shape[0]
    # zero the gradient
    net['G'].zero_grad()
    fake_y = sample_generator(net['G'], x)
    # extract the second channel (index 1) as single-channel tensors
    x_ = x[:, 1, :, :].unsqueeze(1)
    y_ = y[:, 1, :, :].unsqueeze(1)
    fake_y_ = fake_y[:, 1, :, :].unsqueeze(1)
    # mean_match expects 3-channel inputs, so replicate the single channel
    x3 = x_.repeat(1, 3, 1, 1)
    y3 = y_.repeat(1, 3, 1, 1)
    fake_y3 = fake_y_.repeat(1, 3, 1, 1)
    loss_mean = args['tau_G'] * mean_match(x3, y3, fake_y3,
                                           kernel.to(x3.device), _C)
    fake_y_data = torch.cat([x_, fake_y_], 1)
    fake_y_loss = net['P'](fake_y_data).mean()
    loss_y = -alpha * fake_y_loss
    loss = loss_y + loss_mean
    # backward
    loss.backward()
    optimizerG.step()

    return loss, loss_y, loss_mean, fake_y.detach()
Example #4
def train_step_P(net, x, y, optimizerP, args): # Discriminator
    # extract the second channel (index 1) as single-channel tensors
    x_ = x[:, 1, :, :].unsqueeze(1)
    y_ = y[:, 1, :, :].unsqueeze(1)
    alpha = args['alpha']
    batch_size = x.shape[0]
    # zero the gradient
    net['P'].zero_grad()
    # real data
    real_data = torch.cat([x_, y_], 1)  # real pair built from the single-channel slices
    real_loss = net['P'](real_data).mean()
    # fake noisy data from the generator
    with torch.autograd.no_grad():
        fake_y = sample_generator(net['G'], x)  # the generator sees the full 3-channel input
        fake_y = fake_y[:, 1, :, :].unsqueeze(1)
        fake_y_data = torch.cat([x_, fake_y], 1)
    fake_y_loss = net['P'](fake_y_data).mean()
    grad_y_loss = gradient_penalty(real_data, fake_y_data, net['P'], args['lambda_gp'])
    loss_y = alpha * (fake_y_loss - real_loss)
    loss_yg = alpha * grad_y_loss
    # fake clean data from the denoiser
    with torch.autograd.no_grad():
        fake_x = y - net['D'](y)  # the denoiser output serves as a fake clean image
        fake_x = fake_x[:, 1, :, :].unsqueeze(1)
        fake_x_data = torch.cat([fake_x, y_], 1)
    fake_x_loss = net['P'](fake_x_data).mean()
    grad_x_loss = gradient_penalty(real_data, fake_x_data, net['P'], args['lambda_gp'])
    loss_x = (1-alpha) * (fake_x_loss - real_loss)
    loss_xg = (1-alpha) * grad_x_loss
    loss = loss_x + loss_xg + loss_y + loss_yg
    # backward
    loss.backward()
    optimizerP.step()

    return loss, loss_x, loss_xg, loss_y, loss_yg
Example #5
import torch
import matplotlib.pyplot as plt
from scipy.io import loadmat
from skimage import img_as_float32, img_as_ubyte

# build the generator (UNetG, as in Example #7) and load the pretrained model
net = UNetG(3, wf=32, depth=5).cuda()
net.load_state_dict(
    torch.load('./model_states/DANet.pt', map_location='cpu')['G'])

# read the images
im_noisy_real = loadmat('./test_data/SIDD/noisy.mat')['im_noisy']
im_gt = loadmat('./test_data/SIDD/gt.mat')['im_gt']

# denoising
inputs = torch.from_numpy(img_as_float32(im_gt).transpose(
    [2, 0, 1])).unsqueeze(0).cuda()
with torch.autograd.no_grad():
    padunet = PadUNet(inputs, dep_U=5)
    inputs_pad = padunet.pad()
    outputs_pad = sample_generator(net, inputs_pad)
    outputs = padunet.pad_inverse(outputs_pad)
    outputs.clamp_(0.0, 1.0)

im_noisy_fake = img_as_ubyte(outputs.cpu().numpy()[0, ].transpose([1, 2, 0]))

plt.subplot(1, 3, 1)
plt.imshow(im_gt)
plt.title('Gt Image')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(im_noisy_real)
plt.title('Real Noisy Image')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(im_noisy_fake)
plt.title('Fake Noisy Image')
plt.axis('off')
plt.show()
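
Examples #5 and #7 wrap the input in a `PadUNet` helper so that its spatial size fits the U-Net. A minimal sketch of such a pad/unpad pair is shown below, assuming reflection padding up to the next multiple of 2**(dep_U - 1); the original class may differ in detail.

import torch.nn.functional as F

class PadUNet:
    # pad H and W to multiples of 2**(dep_U - 1); a sketch, not the
    # repository's exact implementation
    def __init__(self, x, dep_U=5):
        self.x = x
        self.mod = 2 ** (dep_U - 1)

    def pad(self):
        h, w = self.x.shape[-2:]
        pad_h = (self.mod - h % self.mod) % self.mod
        pad_w = (self.mod - w % self.mod) % self.mod
        return F.pad(self.x, (0, pad_w, 0, pad_h), mode='reflect')

    def pad_inverse(self, y):
        # crop the network output back to the original size
        h, w = self.x.shape[-2:]
        return y[..., :h, :w]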
Example #6
def train_model(net, netG, datasets, optimizer, lr_scheduler, args):
    # _modes is assumed to be a module-level list such as ['train', 'val']
    NReal = ceil(args['batch_size'] / (1 + args['fake_ratio']))
    batch_size = {'train': args['batch_size'], 'val': 4}
    data_loader = {
        phase: uData.DataLoader(datasets[phase],
                                batch_size=batch_size[phase],
                                shuffle=True,
                                num_workers=args['num_workers'],
                                pin_memory=True)
        for phase in _modes
    }
    num_data = {phase: len(datasets[phase]) for phase in _modes}
    num_iter_epoch = {
        phase: ceil(num_data[phase] / batch_size[phase])
        for phase in _modes
    }
    step = args['step'] if args['resume'] else 0
    step_img = args['step_img'] if args['resume'] else {x: 0 for x in _modes}
    writer = SummaryWriter(str(Path(args['log_dir'])))
    clip_grad = args['clip_normD']
    for epoch in range(args['epoch_start'], args['epochs']):
        mae_per_epoch = {x: 0 for x in _modes}
        tic = time.time()
        # train stage
        net.train()
        lr = optimizer.param_groups[0]['lr']
        grad_mean = 0
        phase = 'train'
        for ii, data in enumerate(data_loader[phase]):
            im_noisy, im_gt = [x.cuda() for x in data]
            with torch.autograd.no_grad():
                im_noisy[NReal:] = sample_generator(netG, im_gt[NReal:])
                im_noisy[NReal:].clamp_(0.0, 1.0)
            optimizer.zero_grad()
            im_denoise = im_noisy - net(im_noisy)
            loss = F.l1_loss(im_denoise, im_gt,
                             reduction='sum') / im_gt.shape[0]

            # backpropagation
            loss.backward()
            # clip the grad
            total_grad = nn.utils.clip_grad_norm_(net.parameters(), clip_grad)
            grad_mean = grad_mean * ii / (ii + 1) + total_grad / (ii + 1)
            optimizer.step()

            mae_iter = loss.item() / (im_gt.shape[1] * im_gt.shape[2] *
                                      im_gt.shape[3])
            mae_per_epoch[phase] += mae_iter
            if (ii + 1) % args['print_freq'] == 0:
                template = '[Epoch:{:>2d}/{:<2d}] {:s}:{:0>5d}/{:0>5d}, ' + \
                           'Loss={:5.2e}, Grad:{:.2e}/{:.2e}, lr={:.2e}'
                print(
                    template.format(epoch + 1, args['epochs'], phase, ii + 1,
                                    num_iter_epoch[phase], mae_iter, clip_grad,
                                    total_grad, lr))
                writer.add_scalar('Train Loss Iter', mae_iter, step)
                step += 1
            if (ii + 1) % (20 * args['print_freq']) == 0:
                x1 = vutils.make_grid(im_denoise,
                                      normalize=True,
                                      scale_each=True)
                writer.add_image(phase + ' Denoised images', x1,
                                 step_img[phase])
                x2 = vutils.make_grid(im_gt, normalize=True, scale_each=True)
                writer.add_image(phase + ' GroundTruth', x2, step_img[phase])
                x3 = vutils.make_grid(im_noisy,
                                      normalize=True,
                                      scale_each=True)
                writer.add_image(phase + ' Noisy Image', x3, step_img[phase])
                step_img[phase] += 1

        mae_per_epoch[phase] /= (ii + 1)
        clip_grad = min(grad_mean, clip_grad)
        print('{:s}: Loss={:+.2e}, grad_mean={:.2e}'.format(
            phase, mae_per_epoch[phase], grad_mean))
        print('-' * 100)

        # test stage
        net.eval()
        psnr_per_epoch = ssim_per_epoch = 0
        phase = 'val'
        for ii, data in enumerate(data_loader[phase]):
            im_noisy, im_gt = [x.cuda() for x in data]
            with torch.set_grad_enabled(False):
                im_denoise = im_noisy - net(im_noisy)

            im_denoise.clamp_(0.0, 1.0)
            mae_iter = F.l1_loss(im_denoise, im_gt).item()
            mae_per_epoch[phase] += mae_iter
            psnr_iter = batch_PSNR(im_denoise, im_gt)
            psnr_per_epoch += psnr_iter
            ssim_iter = batch_SSIM(im_denoise, im_gt)
            ssim_per_epoch += ssim_iter
            # print statistics every log_interval mini_batches
            if (ii + 1) % 50 == 0:
                log_str = '[Epoch:{:>2d}/{:<2d}] {:s}:{:0>3d}/{:0>3d}, ' + \
                          'mae={:.2e}, psnr={:4.2f}, ssim={:5.4f}'
                print(
                    log_str.format(epoch + 1, args['epochs'], phase, ii + 1,
                                   num_iter_epoch[phase], mae_iter, psnr_iter,
                                   ssim_iter))
                # tensorboard summary
                x1 = vutils.make_grid(im_denoise,
                                      normalize=True,
                                      scale_each=True)
                writer.add_image(phase + ' Denoised images', x1,
                                 step_img[phase])
                x2 = vutils.make_grid(im_gt, normalize=True, scale_each=True)
                writer.add_image(phase + ' GroundTruth', x2, step_img[phase])
                x5 = vutils.make_grid(im_noisy,
                                      normalize=True,
                                      scale_each=True)
                writer.add_image(phase + ' Noisy Image', x5, step_img[phase])
                step_img[phase] += 1

        psnr_per_epoch /= (ii + 1)
        ssim_per_epoch /= (ii + 1)
        mae_per_epoch[phase] /= (ii + 1)
        print('{:s}: mae={:.3e}, PSNR={:4.2f}, SSIM={:5.4f}'.format(
            phase, mae_per_epoch[phase], psnr_per_epoch, ssim_per_epoch))
        print('-' * 100)

        # adjust the learning rate
        lr_scheduler.step()
        # save model
        save_path_model = str(
            Path(args['model_dir']) / ('model_' + str(epoch + 1)))
        torch.save(
            {
                'epoch': epoch + 1,
                'step': step + 1,
                'step_img': {x: step_img[x] + 1
                             for x in _modes},
                'clip_grad': clip_grad,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'lr_scheduler_state_dict': lr_scheduler.state_dict()
            }, save_path_model)
        save_path_model = str(
            Path(args['model_dir']) /
            ('model_state_' + str(epoch + 1) + '.pt'))
        torch.save(net.state_dict(), save_path_model)

        writer.add_scalars('MAE_epoch', mae_per_epoch, epoch)
        writer.add_scalar('Val PSNR epoch', psnr_per_epoch, epoch)
        writer.add_scalar('Val SSIM epoch', ssim_per_epoch, epoch)
        toc = time.time()
        print('This epoch took {:.2f} seconds'.format(toc - tic))
    writer.close()
    print('Reached the maximum number of epochs; training finished.')
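
The validation loop in Example #6 calls `batch_PSNR` and `batch_SSIM` from the surrounding codebase. As a reference for what the PSNR helper computes, here is a minimal sketch that averages per-image PSNR over the batch, assuming intensities in [0, 1]; the original helper may differ in its reduction or data range, and `batch_SSIM` would follow the same batch-averaging pattern.

import torch

def batch_PSNR(img, imclean, data_range=1.0):
    # per-image MSE over all channels and pixels, then mean PSNR
    mse = ((img - imclean) ** 2).flatten(1).mean(dim=1)
    psnr = 10.0 * torch.log10(data_range ** 2 / mse)
    return psnr.mean().item()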
Example #7
import torch
from scipy.io import loadmat
from skimage import img_as_float32

# UNetG, PadUNet, sample_generator, estimate_sigma_gauss and
# kl_gauss_zero_center come from the DANet codebase
net = UNetG(3, wf=32, depth=5).cuda()

# load the pretrained model
net.load_state_dict(
    torch.load('./model_states/DANet.pt', map_location='cpu')['G'])

# read the images
im_noisy_real = loadmat('./test_data/SIDD/noisy.mat')['im_noisy']
im_gt = loadmat('./test_data/SIDD/gt.mat')['im_gt']

L = 50
AKLD = 0
im_noisy_real = torch.from_numpy(
    img_as_float32(im_noisy_real).transpose([2, 0, 1])).unsqueeze(0).cuda()
im_gt = torch.from_numpy(img_as_float32(im_gt).transpose(
    [2, 0, 1])).unsqueeze(0).cuda()
sigma_real = estimate_sigma_gauss(im_noisy_real, im_gt)
with torch.autograd.no_grad():
    padunet = PadUNet(im_gt, dep_U=5)
    im_gt_pad = padunet.pad()
    for _ in range(L):
        outputs_pad = sample_generator(net, im_gt_pad)
        im_noisy_fake = padunet.pad_inverse(outputs_pad)
        im_noisy_fake.clamp_(0.0, 1.0)
        sigma_fake = estimate_sigma_gauss(im_noisy_fake, im_gt)
        kl_dis = kl_gauss_zero_center(sigma_fake, sigma_real)
        AKLD += kl_dis

AKLD /= L
print("AKLD value: {:.3f}".format(AKLD.item()))