def forward_loss(self, gt, hazy, args):

        results_forward = self.forward(gt, hazy)    
        rec_gt, rec_hazy_free = results_forward["rec_gt"], results_forward["rec_hazy_free"]


        losses = dict()

        teacher_recons_loss = self.teacher_l1loss(rec_gt, gt)
        
        student_recons_loss = self.student_l1loss(rec_hazy_free, gt)

        gt_perceptual_features = self.vgg19(gt)
        reconstructed_perceptual_features = self.vgg19(rec_hazy_free)

        perceptual_loss = 0.0

        # Sum the perceptual loss over features taken from different layers of VGG19
        for gt_feat, rec_feat in zip(gt_perceptual_features, reconstructed_perceptual_features):
            perceptual_loss += self.perceptual_loss(rec_feat, gt_feat)


        # TODO ADD MIMICKING LOSS
        dehazing_loss = student_recons_loss + args.lambda_p * perceptual_loss

        
        # Rescale from [-1, 1] to [0, 1] before computing metrics
        rec_hazy_free = (rec_hazy_free + 1) / 2
        gt = (gt + 1) / 2

        psnr_loss = psnr(rec_hazy_free, gt)
        ssim_loss = ssim(rec_hazy_free, gt)



        losses["teacher_rec_loss"] = teacher_recons_loss
        losses["student_rec_loss"] = student_recons_loss

        losses["perceptual_loss"] = perceptual_loss
        losses["dehazing_loss"] = dehazing_loss

        losses["loss_psnr"] = psnr_loss
        losses["loss_ssim"] = ssim_loss



        self.teacher_scheduler.step(teacher_recons_loss)
        self.student_scheduler.step(student_recons_loss)

        return losses
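The psnr and ssim helpers called above are defined elsewhere in the project. As a reference, here is a minimal PSNR sketch consistent with those calls, assuming batched image tensors already rescaled to [0, 1] (the name and signature are assumptions, not the project's code):

import torch

def psnr(pred, target, max_val=1.0, eps=1e-8):
    # Peak signal-to-noise ratio: 10 * log10(max_val^2 / MSE), averaged over the batch
    mse = torch.mean((pred - target) ** 2, dim=(1, 2, 3))
    return (10 * torch.log10(max_val ** 2 / (mse + eps))).mean()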
Example #2
def train_epoch(device, model, data_loader, optimizer, loss_fn, epoch):
    model.train()
    tq = tqdm.tqdm(total=len(data_loader) * args.batch_size)
    tq.set_description(
        f'Train: Epoch {epoch:4}, LR: {optimizer.param_groups[0]["lr"]:0.6f}')
    train_loss, train_ssim, train_psnr = 0, 0, 0
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_fn(prediction, target)
        if args.amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), args.clip)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        with torch.no_grad():
            train_loss += loss.item() * (1 / len(data_loader))
            if 'temp' in args.type:
                # For temporal models, compute metrics on the centre frame only
                prediction = prediction[:, prediction.size(1) // 2].squeeze(1)
                target = target[:, target.size(1) // 2].squeeze(1)
            train_ssim += losses.ssim(prediction, target).item() * (1 / len(data_loader))
            train_psnr += losses.psnr(prediction, target).item() * (1 / len(data_loader))
        tq.update(args.batch_size)
        tq.set_postfix(
            loss=f'{train_loss*len(data_loader)/(batch_idx+1):4.6f}',
            ssim=f'{train_ssim*len(data_loader)/(batch_idx+1):.4f}',
            psnr=f'{train_psnr*len(data_loader)/(batch_idx+1):4.4f}')
    tq.close()
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('SSIM/train', train_ssim, epoch)
    writer.add_scalar('PSNR/train', train_psnr, epoch)
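The args.amp branch in train_epoch follows NVIDIA apex's mixed-precision pattern; a minimal setup sketch under that assumption (apex is a separate install, and the opt_level is a guess):

from apex import amp

# Wrap the model and optimizer once before training; amp.scale_loss in the
# loop above then applies dynamic loss scaling during backward.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')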
Example #3
def eval_epoch(device, model, data_loader, loss_fn, epoch):
    model.eval()
    tq = tqdm.tqdm(total=len(data_loader))
    tq.set_description(f'Test:  Epoch {epoch:4}')
    eval_loss, eval_ssim, eval_psnr = 0, 0, 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            data, target = data.to(device), target.to(device)
            prediction = model(data)
            eval_loss += loss_fn(prediction, target).item() * (1 / len(data_loader))
            if 'temp' in args.type:
                # For temporal models, compute metrics on the centre frame only
                prediction = prediction[:, prediction.size(1) // 2].squeeze(1)
                target = target[:, target.size(1) // 2].squeeze(1)
            eval_ssim += losses.ssim(prediction, target).item() * (1 / len(data_loader))
            eval_psnr += losses.psnr(prediction, target).item() * (1 / len(data_loader))
            tq.update()
            tq.set_postfix(
                loss=f'{eval_loss*len(data_loader)/(batch_idx+1):4.6f}',
                ssim=f'{eval_ssim*len(data_loader)/(batch_idx+1):.4f}',
                psnr=f'{eval_psnr*len(data_loader)/(batch_idx+1):4.4f}')
    tq.close()
    writer.add_scalar('Loss/test', eval_loss, epoch)
    writer.add_scalar('SSIM/test', eval_ssim, epoch)
    writer.add_scalar('PSNR/test', eval_psnr, epoch)
    if epoch % 10 == 0:
        if 'temp' in args.type:
            data = data[:, data.size(1) // 2].squeeze(1)
        # Log a side-by-side strip of input, prediction, and target
        writer.add_image('Prediction/test',
                         torch.clamp(torch.cat((data[-1, 0:3], prediction[-1], target[-1]), dim=-1), 0, 1),
                         epoch,
                         dataformats='CHW')
    return eval_loss
Example #4
def backward(self, gt, hazy, args):

        results_forward = self.forward(gt, hazy)
        rec_gt, rec_hazy_free = results_forward["rec_gt"], results_forward["rec_hazy_free"]


        losses = dict()

        teacher_recons_loss = self.teacher_l1loss(rec_gt, gt)

        self.teacher_optimizer.zero_grad()
        teacher_recons_loss.backward()
        self.teacher_optimizer.step()

        
        student_recons_loss = self.student_l1loss(rec_hazy_free, gt)

        gt_perceptual_features = self.vgg19(gt)
        reconstructed_perceptual_features = self.vgg19(rec_hazy_free)

        perceptual_loss = 0.0

        # Sum the perceptual loss over features taken from different layers of VGG19
        for gt_feat, rec_feat in zip(gt_perceptual_features, reconstructed_perceptual_features):
            perceptual_loss += self.perceptual_loss(rec_feat, gt_feat)

        mimicking_loss = 0.0

        # Match the student's intermediate features (hazy input) to the teacher's (clean input)
        teacher_feats = self.teacher.forward_mimicking_features(gt)
        student_feats = self.student.forward_mimicking_features(hazy)
        for gt_mimicking, rec_mimicking in zip(teacher_feats, student_feats):
            mimicking_loss += self.mimicking_loss(gt_mimicking, rec_mimicking)



        self.student_optimizer.zero_grad()
        dehazing_loss = student_recons_loss + args.lambda_p * perceptual_loss + args.lambda_rm * mimicking_loss
        dehazing_loss.backward()
        self.student_optimizer.step()
        
        # Rescale from [-1, 1] to [0, 1] before computing metrics
        rec_hazy_free = (rec_hazy_free + 1) / 2
        gt = (gt + 1) / 2

        psnr_loss = psnr(rec_hazy_free, gt)
        ssim_loss = ssim(rec_hazy_free, gt)



        losses["teacher_rec_loss"] = teacher_recons_loss
        losses["student_rec_loss"] = student_recons_loss

        losses["perceptual_loss"] = perceptual_loss
        losses["dehazing_loss"] = dehazing_loss

        losses["loss_psnr"] = psnr_loss
        losses["loss_ssim"] = ssim_loss


        self.teacher_scheduler.step(teacher_recons_loss)
        self.student_scheduler.step(student_recons_loss)

        return losses
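Both methods call self.vgg19(...) and iterate over a list of feature maps, i.e. a multi-layer feature extractor. A minimal torchvision-based sketch of such an extractor; the layer indices are assumptions, not necessarily the project's choice:

import torch.nn as nn
from torchvision import models

class VGG19Features(nn.Module):
    def __init__(self, layer_ids=(2, 7, 12, 21, 30)):
        super().__init__()
        self.features = models.vgg19(pretrained=True).features
        self.layer_ids = set(layer_ids)
        # Freeze the weights: VGG19 is used only as a fixed feature extractor
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        outs = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.layer_ids:
                outs.append(x)
        return outs

The scheduler.step(loss) calls likewise suggest torch.optim.lr_scheduler.ReduceLROnPlateau, the standard scheduler that takes a metric argument.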
Example #5
def plot_test_images(model,
                     loader,
                     datapath_test,
                     test_output,
                     epoch,
                     name='SRGAN',
                     channels=3,
                     colorspace='RGB'):

    try:
        # Get the location of test images
        test_images = [
            os.path.join(datapath_test, f) for f in os.listdir(datapath_test)
            if any(filetype in f.lower()
                   for filetype in ['jpeg', 'mp4', '264', 'png', 'jpg'])
        ]

        # Load the images to perform test on images
        imgs_lr, imgs_hr = loader.load_batch(img_paths=test_images,
                                             training=False,
                                             bicubic=True)
        # Create super resolution and bicubic interpolation images
        imgs_sr = []
        imgs_bi = []
        srgan_psnr = []
        bi_psnr = []
        for img_lr in imgs_lr:
            # Run the generator on one image at a time (batch of size 1)
            pre = np.squeeze(model.predict(np.expand_dims(img_lr, 0), batch_size=1), axis=0)
            imgs_sr.append(pre)

        # Unscale colors values
        if channels == 1:
            imgs_lr = [
                loader.unscale_lr_imgs(img[:, :, 0]).astype(np.uint8)
                for img in imgs_lr
            ]
            imgs_hr = [
                loader.unscale_hr_imgs(img[:, :, 0]).astype(np.uint8)
                for img in imgs_hr
            ]
            imgs_sr = [
                loader.unscale_hr_imgs(img[:, :, 0]).astype(np.uint8)
                for img in imgs_sr
            ]
        else:
            if (colorspace == 'YCbCr'):
                imgs_lr = [
                    cv2.cvtColor(
                        loader.unscale_lr_imgs(img).astype(np.uint8),
                        cv2.COLOR_YCrCb2BGR) for img in imgs_lr
                ]
                imgs_hr = [
                    cv2.cvtColor(
                        loader.unscale_hr_imgs(img).astype(np.uint8),
                        cv2.COLOR_YCrCb2BGR) for img in imgs_hr
                ]
                imgs_sr = [
                    cv2.cvtColor(
                        loader.unscale_hr_imgs(img).astype(np.uint8),
                        cv2.COLOR_YCrCb2BGR) for img in imgs_sr
                ]

            else:
                imgs_lr = [
                    loader.unscale_lr_imgs(img).astype(np.uint8)
                    for img in imgs_lr
                ]
                imgs_hr = [
                    loader.unscale_hr_imgs(img).astype(np.uint8)
                    for img in imgs_hr
                ]
                imgs_sr = [
                    loader.unscale_hr_imgs(img).astype(np.uint8)
                    for img in imgs_sr
                ]

        # Loop through images
        for img_hr, img_lr, img_sr, img_path in zip(imgs_hr, imgs_lr, imgs_sr,
                                                    test_images):
            # Get the filename
            filename = os.path.basename(img_path).split(".")[0]

            # Bicubic upscale
            hr_shape = (int(img_hr.shape[1]), int(img_hr.shape[0]))
            img_bi = cv2.resize(img_lr,
                                hr_shape,
                                interpolation=cv2.INTER_CUBIC)

            # Images and titles
            images = {
                'Low Resolution': [img_lr, img_hr],
                'Bicubic': [img_bi, img_hr],
                name: [img_sr, img_hr],
                'Original': [img_hr, img_hr]
            }
            srgan_psnr.append(psnr(img_sr, img_hr, 255.))
            bi_psnr.append(psnr(img_bi, img_hr, 255.))
            # Plot the images. Note: rescaling and using squeeze since we are getting batches of size 1
            fig, axes = plt.subplots(1, 4, figsize=(40, 10))
            for i, (title, img) in enumerate(images.items()):
                axes[i].imshow(img[0])
                axes[i].set_title("{} - {} {}".format(
                    title, img[0].shape,
                    ("- psnr: " + str(round(psnr(img[0], img[1], 255.), 2)) if
                     (title == name or title == 'Bicubic') else " ")))
                #axes[i].set_title("{} - {}".format(title, img.shape))
                axes[i].axis('off')
            plt.suptitle('{} - Epoch: {}'.format(filename, epoch))

            # Save directory
            savefile = os.path.join(test_output,
                                    "{}-Epoch{}.png".format(filename, epoch))
            fig.savefig(savefile)
            plt.close()
            gc.collect()
        print('test {} psnr: {} - test bi psnr: {}'.format(
            name, np.mean(srgan_psnr), np.mean(bi_psnr)))
    except Exception as e:
        print(e)
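plot_test_images calls psnr(a, b, 255.) on uint8 numpy images, a different helper from the tensor version sketched earlier; a minimal numpy version under that assumption:

import numpy as np

def psnr(img1, img2, max_val=255.0):
    # Cast to float first so the uint8 difference cannot overflow
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10 * np.log10(max_val ** 2 / mse)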
Example #6
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            prediction = parallel_model(data)
            loss = loss_fn(prediction, target)
            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            torch.nn.utils.clip_grad_value_(parallel_model.parameters(), args.clip)
            torch.nn.utils.clip_grad_norm_(parallel_model.parameters(), args.clip)
            optimizer.step()
            with torch.no_grad():
                train_loss += loss.item() * (1/len(train_loader))
                train_ssim += losses.ssim(prediction, target).item() * (1/len(train_loader))
                train_psnr += losses.psnr(prediction, target).item() * (1/len(train_loader))
            tq.update(args.batch_size)
            tq.set_postfix(loss=f'{train_loss*len(train_loader)/(batch_idx+1):4.6f}',
                    ssim=f'{train_ssim*len(train_loader)/(batch_idx+1):.4f}',
                    psnr=f'{train_psnr*len(train_loader)/(batch_idx+1):4.4f}')
        tq.close()
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('SSIM/train', train_ssim, epoch)
        writer.add_scalar('PSNR/train', train_psnr, epoch)
        scheduler.step(train_loss)

        # -----------------------------------------
        # save checkpoint for best loss

        if train_loss < best_loss:
            best_loss = train_loss
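The fragment breaks off right after updating best_loss; what typically follows is a checkpoint save, sketched here with an assumed path and assumed dictionary keys:

torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_loss': best_loss},
           'best_checkpoint.pth')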
Example #7
                train_losses.append(sum(train_loss) / len(train_loss))
                iterations.append(i)
                train_loss = []
                saveNet(filename, net, optimizer, iterations, train_losses,
                        val_losses)

                with torch.no_grad():
                    net.eval()
                    criterion_loss = 0.0
                    psnr_score = 0
                    for inputs, labels in validation_data:
                        inputs = inputs.to(device)
                        real_val = labels.to(device)
                        fakes_val = net(inputs)
                        criterion_loss += criterion(real_val, fakes_val).item()
                        psnr_score += psnr(real_val, fakes_val).item()

                    criterion_loss /= validation_size
                    psnr_score /= validation_size
                    validation_loss = criterion_loss
                    val_losses.append(validation_loss)
                    writer.add_scalar("loss/valid", validation_loss, i)
                    writer.add_scalar("psnr/valid", psnr_score, i)

                    speed_mini = read_image(
                        "speed-mini.png",
                        mode=ImageReadMode.RGB).to(device).float() / 255.0
                    writer.add_image("validation image",
                                     net(speed_mini.unsqueeze(0)).squeeze(), i)

                    print("Validation loss:", validation_loss, "Mean PSNR:",
Example #8
    def loop(dataloader, epoch, loss_meter, back=True):
        for i, batch in enumerate(dataloader):
            step = epoch * len(dataloader) + i

            if back:
                optimizer.zero_grad()

            lr, hr = batch
            lr, hr = lr.to(device), hr.to(device)

            if using_mask:
                with torch.no_grad():
                    if config.over_upscale:
                        factor = 4
                    else:
                        factor = 1
                    upscaled, mask_in = image_mask(lr,
                                                   config.up_factor * factor)
                pred = model(upscaled.to(device), mask_in.to(device))
            elif config.unsupervised:
                pred = model(lr)
            elif config.pre_upscale:
                with torch.no_grad():
                    upscaled = transforms.functional.resize(
                        lr, (hr_size, hr_size))
                pred = model(upscaled)
            else:
                pred = model(lr)

            if config.loss == "VGG16Partial":
                loss, _, _ = loss_func(pred, hr)  # VGG style loss
            elif config.loss == "DISTS":
                loss = loss_func(pred,
                                 hr,
                                 require_grad=True,
                                 batch_average=True)
            else:
                loss = loss_func(pred, hr)

            if back:
                loss.backward()
                optimizer.step()

            loss_meter.update(loss.item(), writer, step, name=config.loss)

            if config.metrics:
                with torch.no_grad():
                    for metric in config.metrics:
                        tag = loss_meter.name + "/" + metric
                        if metric == "PSNR":
                            writer.add_scalar(tag, losses.psnr(pred, hr), step)
                        elif metric == "SSIM":
                            writer.add_scalar(tag, losses.ssim(pred, hr), step)
                        elif metric == "consistency":
                            downscaled_pred = transforms.functional.resize(
                                pred, (config.lr_size, config.lr_size))
                            writer.add_scalar(
                                tag,
                                torch.nn.functional.mse_loss(
                                    downscaled_pred, lr).item(),
                                step,
                            )
                        elif metric == "lr":
                            writer.add_scalar(tag,
                                              lr_scheduler.get_last_lr()[0],
                                              step)
                        elif metric == "sample":
                            model.eval()
                            if step % config.sample_step == 0:
                                writer.add_image("sample/hr",
                                                 hr[0],
                                                 global_step=step)
                                writer.add_image("sample/lr",
                                                 lr[0],
                                                 global_step=step)
                                writer.add_image("sample/bicubic",
                                                 upscaled[0],
                                                 global_step=step)
                                writer.add_image("sample/pred",
                                                 pred[0],
                                                 global_step=step)
                            model.train()
                        elif metric == "VGG16Partial":
                            val, _, _ = vgg(pred, hr)
                            writer.add_scalar(tag, val.item(), step)