def forward_loss(self, gt, hazy, args):

        results_forward = self.forward(gt, hazy)    
        rec_gt, rec_hazy_free = results_forward["rec_gt"], results_forward["rec_hazy_free"]


        losses = dict()

        teacher_recons_loss = self.teacher_l1loss(rec_gt, gt)
        
        student_recons_loss = self.student_l1loss(rec_hazy_free, gt)

        gt_perceptual_features = self.vgg19(gt)
        reconstructed_perceptual_features = self.vgg19(rec_hazy_free)

        perceptual_loss = 0.0

        # Sum the perceptual loss over features taken from several layers of VGG19.
        for gt_feat, rec_feat in zip(gt_perceptual_features, reconstructed_perceptual_features):
            perceptual_loss += self.perceptual_loss(rec_feat, gt_feat)


        # TODO ADD MIMICKING LOSS
        dehazing_loss = student_recons_loss + args.lambda_p * perceptual_loss

        
        # Map the network outputs from [-1, 1] to [0, 1] before computing image metrics.
        rec_hazy_free = (rec_hazy_free + 1) / 2
        gt = (gt + 1) / 2

        psnr_loss = psnr(rec_hazy_free, gt)
        ssim_loss = ssim(rec_hazy_free, gt)



        losses["teacher_rec_loss"] = teacher_recons_loss
        losses["student_rec_loss"] = student_recons_loss

        losses["perceptual_loss"] = perceptual_loss
        losses["dehazing_loss"] = dehazing_loss

        losses["loss_psnr"] = psnr_loss
        losses["loss_ssim"] = ssim_loss



        self.teacher_scheduler.step(teacher_recons_loss)
        self.student_scheduler.step(student_recons_loss)

        return losses
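The loop over `self.vgg19(...)` above assumes a module that returns activations from several intermediate VGG19 layers. A minimal sketch of such an extractor, assuming torchvision and an arbitrary choice of layer indices (not the repository's configuration):

import torch.nn as nn
from torchvision import models

class VGG19Features(nn.Module):
    """Collects activations from a few intermediate VGG19 layers.
    The layer indices are an assumed choice, not the original setup."""
    def __init__(self, layer_ids=(3, 8, 17, 26)):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features.eval()
        for p in vgg.parameters():
            p.requires_grad = False  # perceptual features stay fixed
        self.vgg = vgg
        self.layer_ids = set(layer_ids)

    def forward(self, x):
        feats = []
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in self.layer_ids:
                feats.append(x)
        return feats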
Example #2
def train_epoch(device, model, data_loader, optimizer, loss_fn, epoch):
    model.train()
    tq = tqdm.tqdm(total=len(data_loader) * args.batch_size)
    tq.set_description(
        f'Train: Epoch {epoch:4}, LR: {optimizer.param_groups[0]["lr"]:0.6f}')
    train_loss, train_ssim, train_psnr = 0, 0, 0
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_fn(prediction, target)
        if args.amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), args.clip)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        with torch.no_grad():
            train_loss += loss.item() * (1 / len(data_loader))
            if 'temp' in args.type:
                # For temporal models, score only the centre frame of the sequence.
                prediction = prediction[:, prediction.size(1) // 2].squeeze(1)
                target = target[:, target.size(1) // 2].squeeze(1)
            train_ssim += losses.ssim(prediction, target).item() * (1 / len(data_loader))
            train_psnr += losses.psnr(prediction, target).item() * (1 / len(data_loader))
        tq.update(args.batch_size)
        tq.set_postfix(
            loss=f'{train_loss*len(data_loader)/(batch_idx+1):4.6f}',
            ssim=f'{train_ssim*len(data_loader)/(batch_idx+1):.4f}',
            psnr=f'{train_psnr*len(data_loader)/(batch_idx+1):4.4f}')
    tq.close()
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('SSIM/train', train_ssim, epoch)
    writer.add_scalar('PSNR/train', train_psnr, epoch)
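The `args.amp` branch uses NVIDIA Apex mixed precision; `amp.scale_loss` only works if the model and optimizer were wrapped beforehand. A sketch of the assumed one-time setup (the `opt_level` choice is an assumption):

from apex import amp

# Run once before the training loop so amp.scale_loss can rescale
# gradients; opt_level "O1" (automatic mixed precision) is an assumed choice.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")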
Example #3
def eval_epoch(device, model, data_loader, loss_fn, epoch):
    model.eval()
    tq = tqdm.tqdm(total=len(data_loader))
    tq.set_description(f'Test:  Epoch {epoch:4}')
    eval_loss, eval_ssim, eval_psnr = 0, 0, 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            data, target = data.to(device), target.to(device)
            prediction = model(data)
            eval_loss += loss_fn(prediction,
                                 target).item() * (1 / len(data_loader))
            if 'temp' in args.type:
                # For temporal models, score only the centre frame of the sequence.
                prediction = prediction[:, prediction.size(1) // 2].squeeze(1)
                target = target[:, target.size(1) // 2].squeeze(1)
            eval_ssim += losses.ssim(prediction, target).item() * (1 / len(data_loader))
            eval_psnr += losses.psnr(prediction, target).item() * (1 / len(data_loader))
            tq.update()
            tq.set_postfix(
                loss=f'{eval_loss*len(data_loader)/(batch_idx+1):4.6f}',
                ssim=f'{eval_ssim*len(data_loader)/(batch_idx+1):.4f}',
                psnr=f'{eval_psnr*len(data_loader)/(batch_idx+1):4.4f}')
    tq.close()
    writer.add_scalar('Loss/test', eval_loss, epoch)
    writer.add_scalar('SSIM/test', eval_ssim, epoch)
    writer.add_scalar('PSNR/test', eval_psnr, epoch)
    if epoch % 10 == 0:
        if 'temp' in args.type: data = data[:, data.size(1) // 2].squeeze(1)
        writer.add_image('Prediction/test',
                         torch.clamp(
                             torch.cat(
                                 (data[-1, 0:3], prediction[-1], target[-1]),
                                 dim=-1), 0, 1),
                         epoch,
                         dataformats='CHW')
    return eval_loss
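Both loops above rely on `losses.psnr` and `losses.ssim` from a module that is not shown. A sketch of what the PSNR helper presumably looks like, assuming inputs scaled to [0, 1] (an assumption; SSIM is more involved and omitted here):

import torch

def psnr(prediction, target, max_val=1.0, eps=1e-8):
    # Peak signal-to-noise ratio in dB; eps guards against log(0)
    # on identical images.
    mse = torch.mean((prediction - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / (mse + eps))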
Example #4
def backward(self, gt, hazy, args):

        results_forward = self.forward(gt, hazy)
        rec_gt, rec_hazy_free = results_forward["rec_gt"], results_forward["rec_hazy_free"]


        losses = dict()

        teacher_recons_loss = self.teacher_l1loss(rec_gt, gt)

        self.teacher_optimizer.zero_grad()
        teacher_recons_loss.backward()
        self.teacher_optimizer.step()

        
        student_recons_loss = self.student_l1loss(rec_hazy_free, gt)

        gt_perceptual_features = self.vgg19(gt)
        reconstructed_perceptual_features = self.vgg19(rec_hazy_free)

        perceptual_loss = 0.0

        # Sum the perceptual loss over features taken from several layers of VGG19.
        for gt_feat, rec_feat in zip(gt_perceptual_features, reconstructed_perceptual_features):
            perceptual_loss += self.perceptual_loss(rec_feat, gt_feat)

        mimicking_loss = 0.0

        # Match intermediate teacher features (clean input) against
        # student features (hazy input).
        for gt_mimicking, rec_mimicking in zip(self.teacher.forward_mimicking_features(gt),
                                               self.student.forward_mimicking_features(hazy)):
            mimicking_loss += self.mimicking_loss(gt_mimicking, rec_mimicking)



        self.student_optimizer.zero_grad()
        dehazing_loss = student_recons_loss + args.lambda_p * perceptual_loss + args.lambda_rm * mimicking_loss
        dehazing_loss.backward()
        self.student_optimizer.step()
        
        # Map the network outputs from [-1, 1] to [0, 1] before computing image metrics.
        rec_hazy_free = (rec_hazy_free + 1) / 2
        gt = (gt + 1) / 2

        psnr_loss = psnr(rec_hazy_free, gt)
        ssim_loss = ssim(rec_hazy_free, gt)



        losses["teacher_rec_loss"] = teacher_recons_loss
        losses["student_rec_loss"] = student_recons_loss

        losses["perceptual_loss"] = perceptual_loss
        losses["dehazing_loss"] = dehazing_loss

        losses["loss_psnr"] = psnr_loss
        losses["loss_ssim"] = ssim_loss


        self.teacher_scheduler.step(teacher_recons_loss)
        self.student_scheduler.step(student_recons_loss)

        return losses
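A sketch of one plausible form of the mimicking criterion used above: an L1 distance between paired teacher and student feature maps, with the teacher features detached so the distillation signal only trains the student (the actual `self.mimicking_loss` is not shown, so this is an assumption):

import torch.nn.functional as F

def feature_mimicking_loss(teacher_feats, student_feats):
    # Sum of L1 distances over paired feature maps; detach() keeps
    # gradients from flowing back into the teacher.
    loss = 0.0
    for t_feat, s_feat in zip(teacher_feats, student_feats):
        loss = loss + F.l1_loss(s_feat, t_feat.detach())
    return loss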
Example #5
import numpy as np
from keras import backend as K

from losses import ssim, photometric_consistency_loss

x = np.ones((1, 10, 10, 1))
y = np.ones((1, 10, 10, 1))

x_img1 = K.variable(x)
y_img1 = K.variable(y)

ssim1 = ssim(x_img1, y_img1)

assert np.allclose(K.eval(ssim1), np.zeros((1, 10, 10)))  # identical images -> loss of 0

x_img2 = K.variable(255 * x)
y_img2 = K.variable(-255 * y)

ssim2 = ssim(x_img2, y_img2)

assert np.allclose(K.eval(ssim2), np.ones((1, 10, 10)))  # maximally different images -> loss of 1

pcl = photometric_consistency_loss(x_img1, y_img1)

assert np.allclose(K.eval(pcl), np.zeros((1, 10, 10)))
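The assertions pin down the convention: identical images give a loss of 0 and maximally different images give 1, i.e. a DSSIM-style loss. A sketch of a loss with the same convention, using tf.image.ssim rather than the repository's implementation (an assumption; inputs are expected in [0, max_val]):

import tensorflow as tf

def dssim_loss(y_true, y_pred, max_val=1.0):
    # DSSIM = (1 - SSIM) / 2 maps SSIM in [-1, 1] to a loss in [0, 1]:
    # 0 for identical images, 1 for maximally different ones.
    return (1.0 - tf.image.ssim(y_true, y_pred, max_val=max_val)) / 2.0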
Example #6
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            prediction = parallel_model(data)
            loss = loss_fn(prediction, target)
            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            torch.nn.utils.clip_grad_value_(parallel_model.parameters(), args.clip)
            torch.nn.utils.clip_grad_norm_(parallel_model.parameters(), args.clip)
            optimizer.step()
            with torch.no_grad():
                train_loss += loss.item() * (1/len(train_loader))
                train_ssim += losses.ssim(prediction, target).item() * (1/len(train_loader))
                train_psnr += losses.psnr(prediction, target).item() * (1/len(train_loader))
            tq.update(args.batch_size)
            tq.set_postfix(loss=f'{train_loss*len(train_loader)/(batch_idx+1):4.6f}',
                    ssim=f'{train_ssim*len(train_loader)/(batch_idx+1):.4f}',
                    psnr=f'{train_psnr*len(train_loader)/(batch_idx+1):4.4f}')
        tq.close()
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('SSIM/train', train_ssim, epoch)
        writer.add_scalar('PSNR/train', train_psnr, epoch)
        scheduler.step(train_loss)

        # -----------------------------------------
        # save checkpoint for best loss

        if train_loss < best_loss:
            best_loss = train_loss
            # Hypothetical continuation: the listing breaks off here, so the
            # checkpoint filename and save call are assumptions.
            torch.save(parallel_model.state_dict(), 'best_model.pth')
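`scheduler.step(train_loss)` above suggests a ReduceLROnPlateau schedule, which steps on a monitored metric rather than once per epoch. A sketch of the assumed setup (factor and patience values are assumptions):

import torch

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5)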
Example #7
import math

import numpy as np

# mse, mae, ssim, sharp_dist, magn_mse, register_croco and register_phaco
# are assumed to be provided elsewhere in the repository.

def benchmark(pred, true, check_all=False, check=("mse", "magn", "imcon")):
    '''
    Benchmark suite to evaluate the quality of predictions.
    Checks include MSE, MAE and SSIM applied directly and after
    phase-correlation and cross-correlation registration.
    Sharpness and the magnitude error are tested,
    and image constraints are checked.
    The method expects numpy arrays.
    '''
    pred_signal = np.real(pred)
    true_signal = np.real(true)
    
    checks = [e.lower() for e in check]
    
    pred_croco = register_croco(pred_signal, true_signal)
    pred_phaco = register_phaco(pred_signal, true_signal)
    
    markdown = ""
    
    print("Signal error:")
    if "mse" in checks or check_all:
        _mse = mse(pred_signal, true_signal)
        markdown = markdown + " {:.{}f} |".format(_mse[0], 4 + math.floor(-math.log10(_mse[0])))
        print("  MSE:        {}, std: {}".format(*_mse))
    if "mae" in checks or check_all:
        _mae = mae(pred_signal, true_signal)
        markdown = markdown + " {:.{}f} |".format(_mae[0], 4 + math.floor(-math.log10(_mae[0])))
        print("  MAE:        {}, std: {}".format(*_mae))
    if "ssim" in checks or check_all:
        _ssim = ssim(pred_signal, true_signal)
        markdown = markdown + " {:.{}f} |".format(_ssim[0], 4 + math.floor(-math.log10(_ssim[0])))
        print("  SSIM:       {}, std: {}".format(*_ssim))
    if "sharpness" in checks or check_all:
        _sharpness = sharp_dist(pred_croco, true_signal)
        markdown = markdown + " {:.{}f} |".format(_sharpness[0], 4 + math.floor(-math.log10(_sharpness[0])))
        print("  Sharpness:  {}, std: {}".format(*_sharpness))
    if "phaco" in checks or check_all:
        print("=============================PHACO=============================")
        _fasimse = mse(pred_phaco, true_signal)
        markdown = markdown + " {:.{}f} |".format(_fasimse[0], 4 + math.floor(-math.log10(_fasimse[0])))
        print("  PhaCo-MSE:  {}, std: {}".format(*_fasimse))
        _fasimae = mae(pred_phaco, true_signal)
        markdown = markdown + " {:.{}f} |".format(_fasimae[0], 4 + math.floor(-math.log10(_fasimae[0])))
        print("  PhaCo-MAE:  {}, std: {}".format(*_fasimae))
        _fasissim = ssim(pred_phaco, true_signal)
        markdown = markdown + " {:.{}f} |".format(_fasissim[0], 4 + math.floor(-math.log10(_fasissim[0])))
        print("  PhaCo-SSIM: {}, std: {}".format(*_fasissim))
    if "croco" in checks or check_all:
        print("=============================CROCO=============================")
        _crocomse = mse(pred_croco, true_signal)
        markdown = markdown + " {:.{}f} |".format(_crocomse[0], 4 + math.floor(-math.log10(_crocomse[0])))
        print("  CroCo-MSE:  {}, std: {}".format(*_crocomse))
        _crocomae = mae(pred_croco, true_signal)
        markdown = markdown + " {:.{}f} |".format(_crocomae[0], 4 + math.floor(-math.log10(_crocomae[0])))
        print("  CroCo-MAE:  {}, std: {}".format(*_crocomae))
        _crocossim = ssim(pred_croco, true_signal)
        markdown = markdown + " {:.{}f} |".format(_crocossim[0], 4 + math.floor(-math.log10(_crocossim[0])))
        print("  CroCo-SSIM: {}, std: {}".format(*_crocossim))
    if "magn" in checks or check_all:
        _magn = magn_mse(pred, true)
        markdown = markdown + " {:.{}f} |".format(_magn[0], 4 + math.floor(-math.log10(_magn[0])))
        print()
        print("Magnitude error:")
        print("  MSE Magnitude: {}, std: {}".format(*_magn))
    if "imcon" in checks or check_all:
        print()
        print("Image constraints:")
        print("  Imag part =", np.mean(np.imag(pred)), "- should be very close to 0")
        print("  Real part is in [{0:.2f}, {1:.2f}]".format(np.min(np.real(pred)), np.max(np.real(pred))),
              "- should be in [0, 1]")
    
    print()
    print("Markdown table values:")
    print(markdown)
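A hypothetical invocation, assuming complex-valued image arrays from the surrounding pipeline (shapes and values are illustrative only):

rng = np.random.default_rng(0)
pred = rng.random((64, 64)).astype(np.complex128)
true = rng.random((64, 64)).astype(np.complex128)
benchmark(pred, true, check_all=True)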
Example #8
    def loop(dataloader, epoch, loss_meter, back=True):
        for i, batch in enumerate(dataloader):
            step = epoch * len(dataloader) + i

            if back:
                optimizer.zero_grad()

            lr, hr = batch
            lr, hr = lr.to(device), hr.to(device)

            if using_mask:
                with torch.no_grad():
                    if config.over_upscale:
                        factor = 4
                    else:
                        factor = 1
                    upscaled, mask_in = image_mask(lr,
                                                   config.up_factor * factor)
                pred = model(upscaled.to(device), mask_in.to(device))
            elif config.unsupervised:
                pred = model(lr)
            elif config.pre_upscale:
                with torch.no_grad():
                    upscaled = transforms.functional.resize(
                        lr, (hr_size, hr_size))
                pred = model(upscaled)
            else:
                pred = model(lr)

            if config.loss == "VGG16Partial":
                loss, _, _ = loss_func(pred, hr)  # VGG style loss
            elif config.loss == "DISTS":
                loss = loss_func(pred,
                                 hr,
                                 require_grad=True,
                                 batch_average=True)
            else:
                loss = loss_func(pred, hr)

            if back:
                loss.backward()
                optimizer.step()

            loss_meter.update(loss.item(), writer, step, name=config.loss)

            if config.metrics:
                with torch.no_grad():
                    for metric in config.metrics:
                        tag = loss_meter.name + "/" + metric
                        if metric == "PSNR":
                            writer.add_scalar(tag, losses.psnr(pred, hr), step)
                        elif metric == "SSIM":
                            writer.add_scalar(tag, losses.ssim(pred, hr), step)
                        elif metric == "consistency":
                            downscaled_pred = transforms.functional.resize(
                                pred, (config.lr_size, config.lr_size))
                            writer.add_scalar(
                                tag,
                                torch.nn.functional.mse_loss(
                                    downscaled_pred, lr).item(),
                                step,
                            )
                        elif metric == "lr":
                            writer.add_scalar(tag,
                                              lr_scheduler.get_last_lr()[0],
                                              step)
                        elif metric == "sample":
                            model.eval()
                            if step % config.sample_step == 0:
                                writer.add_image("sample/hr",
                                                 hr[0],
                                                 global_step=step)
                                writer.add_image("sample/lr",
                                                 lr[0],
                                                 global_step=step)
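                                # `upscaled` is defined only on the mask /
                                # pre-upscale paths above; this branch assumes
                                # one of them produced it.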
                                writer.add_image("sample/bicubic",
                                                 upscaled[0],
                                                 global_step=step)
                                writer.add_image("sample/pred",
                                                 pred[0],
                                                 global_step=step)
                            model.train()
                        elif metric == "VGG16Partial":
                            val, _, _ = vgg(pred, hr)
                            writer.add_scalar(tag, val.item(), step)
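`loss_meter.update(...)` and `loss_meter.name` above imply a small running-average helper that mirrors each value to TensorBoard. A sketch consistent with those call sites (an assumption, not the repository's class):

class LossMeter:
    def __init__(self, name):
        self.name = name          # used as the TensorBoard tag prefix
        self.sum = 0.0
        self.count = 0

    def update(self, value, writer, step, name="loss"):
        # Accumulate the running average and log the raw value.
        self.sum += value
        self.count += 1
        writer.add_scalar(self.name + "/" + name, value, step)

    @property
    def average(self):
        return self.sum / max(self.count, 1)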