def forward_loss(self, gt, hazy, args):
    """Compute teacher/student losses and image metrics for one batch.

    No optimizer step is taken here; the only state change is that both LR
    schedulers are stepped with the reconstruction losses (see note below).

    Args:
        gt: ground-truth (haze-free) image batch, fed to ``self.forward``.
        hazy: hazy input batch, fed to ``self.forward``.
        args: options object; ``args.lambda_p`` weights the perceptual loss
            and ``args.lambda_rm`` weights the representation-mimicking loss.

    Returns:
        dict with keys ``teacher_rec_loss``, ``student_rec_loss``,
        ``perceptual_loss``, ``mimicking_loss``, ``dehazing_loss``,
        ``loss_psnr`` and ``loss_ssim``.
    """
    results_forward = self.forward(gt, hazy)
    rec_gt, rec_hazy_free = results_forward["rec_gt"], results_forward["rec_hazy_free"]
    losses = dict()

    teacher_recons_loss = self.teacher_l1loss(rec_gt, gt)
    student_recons_loss = self.student_l1loss(rec_hazy_free, gt)

    gt_perceptual_features = self.vgg19(gt)
    reconstructed_perceptual_features = self.vgg19(rec_hazy_free)
    perceptual_loss = 0.0
    # Sum up perceptual loss taken from different layers of VGG19
    for gt_feat, rec_feat in zip(gt_perceptual_features, reconstructed_perceptual_features):
        perceptual_loss += self.perceptual_loss(rec_feat, gt_feat)

    # Representation mimicking: the student's features on the hazy input are
    # pulled towards the teacher's features on the clean image.  The teacher
    # features are the target, so they are detached.
    mimicking_loss = 0.0
    for gt_mimicking, rec_mimicking in zip(self.teacher.forward_mimicking_features(gt),
                                           self.student.forward_mimicking_features(hazy)):
        mimicking_loss += self.mimicking_loss(gt_mimicking.detach(), rec_mimicking)

    dehazing_loss = (student_recons_loss
                     + args.lambda_p * perceptual_loss
                     + args.lambda_rm * mimicking_loss)

    # Scale between 0 - 1.  Assumes images are in [-1, 1] (e.g. Tanh
    # outputs) -- TODO confirm against the data pipeline.  Adding 1 alone
    # would give [0, 2], so also halve.
    rec_hazy_free_01 = (rec_hazy_free + 1) / 2
    gt_01 = (gt + 1) / 2
    psnr_loss = psnr(rec_hazy_free_01, gt_01)
    ssim_loss = ssim(rec_hazy_free_01, gt_01)

    losses["teacher_rec_loss"] = teacher_recons_loss
    losses["student_rec_loss"] = student_recons_loss
    losses["perceptual_loss"] = perceptual_loss
    losses["mimicking_loss"] = mimicking_loss
    losses["dehazing_loss"] = dehazing_loss
    losses["loss_psnr"] = psnr_loss
    losses["loss_ssim"] = ssim_loss

    # NOTE(review): stepping the schedulers here means merely evaluating the
    # loss changes the learning rate; the step(metric) form suggests a
    # plateau-style scheduler -- confirm this is intended.
    self.teacher_scheduler.step(teacher_recons_loss)
    self.student_scheduler.step(student_recons_loss)
    return losses
def train_epoch(device, model, data_loader, optimizer, loss_fn, epoch):
    """Train ``model`` for one epoch and log mean loss / SSIM / PSNR.

    Reads the module-level ``args`` (``batch_size``, ``amp``, ``clip``,
    ``type``) and logs to the module-level ``writer``.

    Args:
        device: device the batches are moved to.
        model: the network being trained.
        data_loader: iterable of ``(data, target)`` batches with a ``len``.
        optimizer: optimizer stepping the model (or apex AMP) parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        epoch: epoch index, used for the progress bar and as the log step.

    Returns:
        The epoch-averaged training loss (same quantity logged as ``Loss/train``).
    """
    model.train()
    num_batches = len(data_loader)
    tq = tqdm.tqdm(total=num_batches * args.batch_size)
    tq.set_description(
        f'Train: Epoch {epoch:4}, LR: {optimizer.param_groups[0]["lr"]:0.6f}')
    train_loss, train_ssim, train_psnr = 0, 0, 0
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_fn(prediction, target)
        if args.amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            # Apex AMP steps the optimizer's FP32 master params (distinct from
            # model.parameters() under O2), so those are the ones to clip.
            clip_params = list(amp.master_params(optimizer))
        else:
            loss.backward()
            clip_params = list(model.parameters())
        torch.nn.utils.clip_grad_value_(clip_params, args.clip)
        torch.nn.utils.clip_grad_norm_(clip_params, args.clip)
        optimizer.step()

        with torch.no_grad():
            train_loss += loss.item() * (1 / num_batches)
            if 'temp' in args.type:
                # Metrics use the middle slice along dim 1 (presumably the
                # temporal/frame axis -- confirm).
                prediction = prediction[:, prediction.size(1) // 2].squeeze(1)
                target = target[:, target.size(1) // 2].squeeze(1)
            train_ssim += losses.ssim(prediction, target).item() * (1 / num_batches)
            train_psnr += losses.psnr(prediction, target).item() * (1 / num_batches)

        tq.update(args.batch_size)
        # Running means: the accumulators are pre-divided by num_batches.
        tq.set_postfix(
            loss=f'{train_loss*num_batches/(batch_idx+1):4.6f}',
            ssim=f'{train_ssim*num_batches/(batch_idx+1):.4f}',
            psnr=f'{train_psnr*num_batches/(batch_idx+1):4.4f}')
    tq.close()
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('SSIM/train', train_ssim, epoch)
    writer.add_scalar('PSNR/train', train_psnr, epoch)
    return train_loss
def eval_epoch(device, model, data_loader, loss_fn, epoch):
    """Evaluate ``model`` for one epoch and log mean loss / SSIM / PSNR.

    Reads the module-level ``args`` (``type``) and logs to the module-level
    ``writer``.  Every 10th epoch a side-by-side image of the last batch's
    input, prediction and target is logged as well.

    Args:
        device: device the batches are moved to.
        model: the network being evaluated (switched to eval mode).
        data_loader: iterable of ``(data, target)`` batches with a ``len``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        epoch: epoch index, used for the progress bar and as the log step.

    Returns:
        The epoch-averaged evaluation loss.
    """
    model.eval()
    num_batches = len(data_loader)
    tq = tqdm.tqdm(total=num_batches)
    tq.set_description(f'Test: Epoch {epoch:4}')
    eval_loss, eval_ssim, eval_psnr = 0, 0, 0
    # Pre-bind so an empty loader does not leave these unbound for the
    # image-logging step below.
    data = prediction = target = None
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            data, target = data.to(device), target.to(device)
            prediction = model(data)
            eval_loss += loss_fn(prediction, target).item() * (1 / num_batches)
            if 'temp' in args.type:
                # Metrics use the middle slice along dim 1 (presumably the
                # temporal/frame axis -- confirm).
                prediction = prediction[:, prediction.size(1) // 2].squeeze(1)
                target = target[:, target.size(1) // 2].squeeze(1)
            eval_ssim += losses.ssim(prediction, target).item() * (1 / num_batches)
            eval_psnr += losses.psnr(prediction, target).item() * (1 / num_batches)
            tq.update()
            # Running means: the accumulators are pre-divided by num_batches.
            tq.set_postfix(
                loss=f'{eval_loss*num_batches/(batch_idx+1):4.6f}',
                ssim=f'{eval_ssim*num_batches/(batch_idx+1):.4f}',
                psnr=f'{eval_psnr*num_batches/(batch_idx+1):4.4f}')
    tq.close()
    writer.add_scalar('Loss/test', eval_loss, epoch)
    writer.add_scalar('SSIM/test', eval_ssim, epoch)
    writer.add_scalar('PSNR/test', eval_psnr, epoch)
    if epoch % 10 == 0 and data is not None:
        if 'temp' in args.type:
            data = data[:, data.size(1) // 2].squeeze(1)
        # Last sample of the last batch: first 3 input channels | prediction
        # | target, concatenated along the width axis.
        writer.add_image('Prediction/test', torch.clamp(
            torch.cat((data[-1, 0:3], prediction[-1], target[-1]), dim=-1), 0, 1),
            epoch, dataformats='CHW')
    return eval_loss
def backward(self, gt, hazy, args):
    """Run one optimisation step for the teacher and then the student.

    1. The teacher is updated with ``self.teacher_l1loss(rec_gt, gt)``.
    2. The student is updated with the dehazing loss:
       ``student_rec + lambda_p * perceptual + lambda_rm * mimicking``.
    3. PSNR/SSIM are computed on the student output, and both LR schedulers
       are stepped with the reconstruction losses.

    Args:
        gt: ground-truth (haze-free) image batch.
        hazy: hazy input batch.
        args: options object providing ``lambda_p`` and ``lambda_rm``.

    Returns:
        dict with keys ``teacher_rec_loss``, ``student_rec_loss``,
        ``perceptual_loss``, ``mimicking_loss``, ``dehazing_loss``,
        ``loss_psnr`` and ``loss_ssim``.
    """
    results_forward = self.forward(gt, hazy)
    rec_gt, rec_hazy_free = results_forward["rec_gt"], results_forward["rec_hazy_free"]
    losses = dict()

    # --- teacher update ---
    teacher_recons_loss = self.teacher_l1loss(rec_gt, gt)
    self.teacher_optimizer.zero_grad()
    teacher_recons_loss.backward()
    self.teacher_optimizer.step()

    # --- student losses ---
    student_recons_loss = self.student_l1loss(rec_hazy_free, gt)
    gt_perceptual_features = self.vgg19(gt)
    reconstructed_perceptual_features = self.vgg19(rec_hazy_free)
    perceptual_loss = 0.0
    # Sum up perceptual loss taken from different layers of VGG19
    for gt_feat, rec_feat in zip(gt_perceptual_features, reconstructed_perceptual_features):
        perceptual_loss += self.perceptual_loss(rec_feat, gt_feat)

    # Teacher features are the mimicking *target*: detach them so the student
    # loss does not back-propagate through the teacher (those teacher grads
    # were cleared by the next teacher zero_grad() anyway).
    mimicking_loss = 0.0
    for gt_mimicking, rec_mimicking in zip(self.teacher.forward_mimicking_features(gt),
                                           self.student.forward_mimicking_features(hazy)):
        mimicking_loss += self.mimicking_loss(gt_mimicking.detach(), rec_mimicking)

    # --- student update ---
    self.student_optimizer.zero_grad()
    dehazing_loss = (student_recons_loss
                     + args.lambda_p * perceptual_loss
                     + args.lambda_rm * mimicking_loss)
    dehazing_loss.backward()
    self.student_optimizer.step()

    # Scale between 0 - 1.  Assumes images are in [-1, 1] (e.g. Tanh
    # outputs) -- TODO confirm against the data pipeline.  Adding 1 alone
    # would give [0, 2], so also halve.  Detached: metrics only, backward is done.
    rec_hazy_free_01 = (rec_hazy_free.detach() + 1) / 2
    gt_01 = (gt + 1) / 2
    psnr_loss = psnr(rec_hazy_free_01, gt_01)
    ssim_loss = ssim(rec_hazy_free_01, gt_01)

    losses["teacher_rec_loss"] = teacher_recons_loss
    losses["student_rec_loss"] = student_recons_loss
    losses["perceptual_loss"] = perceptual_loss
    losses["mimicking_loss"] = mimicking_loss
    losses["dehazing_loss"] = dehazing_loss
    losses["loss_psnr"] = psnr_loss
    losses["loss_ssim"] = ssim_loss

    # NOTE(review): schedulers are stepped once per batch with the batch
    # loss; the step(metric) form suggests a plateau-style scheduler -- confirm.
    self.teacher_scheduler.step(teacher_recons_loss)
    self.student_scheduler.step(student_recons_loss)
    return losses
import numpy as np
from keras import backend as K
from keras.losses import mean_absolute_error
from losses import ssim, photometric_consistency_loss

# Smoke checks for the custom Keras losses exported by ``losses``.
#
# Results are compared element-wise: ``a.all() == b.all()`` would only
# compare two booleans (the first check would pass as soon as a single
# element were zero), which verifies almost nothing.
# NOTE(review): the expected constants assume ``ssim`` is a DSSIM-style loss
# (0 for identical images, 1 for sign-inverted ones) -- confirm against
# ``losses``.
ATOL = 1e-5

x = np.ones((1, 10, 10, 1))
y = np.ones((1, 10, 10, 1))

# Identical images -> zero structural dissimilarity.
x_img1 = K.variable(x)
y_img1 = K.variable(y)
ssim1 = K.eval(ssim(x_img1, y_img1))
assert np.allclose(ssim1, 0.0, atol=ATOL), ssim1

# Sign-inverted images -> maximal dissimilarity.
x_img2 = K.variable(255 * x)
y_img2 = K.variable(-255 * y)
ssim2 = K.eval(ssim(x_img2, y_img2))
assert np.allclose(ssim2, 1.0, atol=ATOL), ssim2

# Identical images -> zero photometric consistency loss.
pcl = K.eval(photometric_consistency_loss(x_img1, y_img1))
assert np.allclose(pcl, 0.0, atol=ATOL), pcl
for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() prediction = parallel_model(data) loss = loss_fn(prediction, target) if args.amp: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() torch.nn.utils.clip_grad_value_(parallel_model.parameters(), args.clip) torch.nn.utils.clip_grad_norm_(parallel_model.parameters(), args.clip) optimizer.step() with torch.no_grad(): train_loss += loss.item() * (1/len(train_loader)) train_ssim += losses.ssim(prediction, target).item() * (1/len(train_loader)) train_psnr += losses.psnr(prediction, target).item() * (1/len(train_loader)) tq.update(args.batch_size) tq.set_postfix(loss=f'{train_loss*len(train_loader)/(batch_idx+1):4.6f}', ssim=f'{train_ssim*len(train_loader)/(batch_idx+1):.4f}', psnr=f'{train_psnr*len(train_loader)/(batch_idx+1):4.4f}') tq.close() writer.add_scalar('Loss/train', train_loss, epoch) writer.add_scalar('SSIM/train', train_ssim, epoch) writer.add_scalar('PSNR/train', train_psnr, epoch) scheduler.step(train_loss) # ----------------------------------------- # save checkpoint for best loss if train_loss < best_loss:
def _markdown_cell(value):
    """Format ``value`` as one markdown table cell with ~4 significant decimals.

    The decimal count is ``4 - floor(log10(|value|))``, clamped at 0 so large
    values never yield an invalid negative precision.  Zero, NaN and infinite
    values (where the logarithm is undefined) fall back to 4 decimals.
    """
    magnitude = abs(value)
    if magnitude == 0 or not math.isfinite(magnitude):
        precision = 4
    else:
        precision = max(0, 4 + math.floor(-math.log10(magnitude)))
    return " {:.{}f} |".format(value, precision)


def benchmark(pred, true, check_all=False, check=("mse", "magn", "imcon")):
    """Benchmark suite that prints quality metrics for a prediction.

    Checks include MSE, MAE and SSIM applied directly to the real part of
    the signals, after phase-correlation registration ("phaco") and after
    cross-correlation registration ("croco").  Sharpness and the magnitude
    error can be reported, and simple image constraints are checked.  Each
    scalar metric is also appended to a markdown table row printed at the end.

    Args:
        pred: predicted signal as a numpy array (may be complex).
        true: ground-truth signal as a numpy array (may be complex).
        check_all: run every check regardless of ``check``.
        check: case-insensitive names of checks to run, any of "mse",
            "mae", "ssim", "sharpness", "phaco", "croco", "magn", "imcon".
            Any iterable of strings is accepted (tuple default avoids a
            shared mutable default).

    Returns:
        None; everything is printed to stdout.  Each metric helper is
        expected to return a ``(value, std)`` pair.
    """
    pred_signal = np.real(pred)
    true_signal = np.real(true)
    checks = [e.lower() for e in check]

    def enabled(name):
        return name in checks or check_all

    # Registration is only needed by some checks (none of the defaults), so
    # only align the prediction when a requested check uses it.
    pred_croco = None
    if enabled("sharpness") or enabled("croco"):
        pred_croco = register_croco(pred_signal, true_signal)
    pred_phaco = None
    if enabled("phaco"):
        pred_phaco = register_phaco(pred_signal, true_signal)

    cells = []

    def record(label, result):
        # Append the value to the markdown row and print "value, std".
        cells.append(_markdown_cell(result[0]))
        print(" {}: {}, std: {}".format(label, *result))

    print("Signal error:")
    if enabled("mse"):
        record("MSE", mse(pred_signal, true_signal))
    if enabled("mae"):
        record("MAE", mae(pred_signal, true_signal))
    if enabled("ssim"):
        record("SSIM", ssim(pred_signal, true_signal))
    if enabled("sharpness"):
        record("Sharpness", sharp_dist(pred_croco, true_signal))
    if enabled("phaco"):
        print("=============================PHACO=============================")
        for name, metric in (("MSE", mse), ("MAE", mae), ("SSIM", ssim)):
            record("PhaCo-" + name, metric(pred_phaco, true_signal))
    if enabled("croco"):
        print("=============================CROCO=============================")
        for name, metric in (("MSE", mse), ("MAE", mae), ("SSIM", ssim)):
            record("CroCo-" + name, metric(pred_croco, true_signal))
    if enabled("magn"):
        # Magnitude error uses the raw (possibly complex) inputs.
        _magn = magn_mse(pred, true)
        print()
        print("Magnitude error:")
        record("MSE Magnitude", _magn)
    if enabled("imcon"):
        print()
        print("Image constraints:")
        print(" Imag part =", np.mean(np.imag(pred)), "- should be very close to 0")
        print(" Real part is in [{0:.2f}, {1:.2f}]".format(np.min(np.real(pred)), np.max(np.real(pred))),
              "- should be in [0, 1]")
    print()
    print("Markdown table values:")
    print("".join(cells))
def loop(dataloader, epoch, loss_meter, back=True):
    """Run one pass over ``dataloader``, optionally optimising, and log metrics.

    Relies on module-level ``model``, ``optimizer``, ``loss_func``,
    ``config``, ``writer``, ``device`` and related helpers.

    Args:
        dataloader: iterable of ``(lr, hr)`` batches with a ``len``.
        epoch: epoch index; combined with the batch index into a global ``step``.
        loss_meter: object exposing ``update(value, writer, step, name=...)``
            and a ``name`` used as the metric tag prefix.
        back: when True, back-propagate and step the optimizer each batch.
    """
    num_batches = len(dataloader)
    for i, batch in enumerate(dataloader):
        step = epoch * num_batches + i
        if back:
            optimizer.zero_grad()
        lr, hr = batch
        lr, hr = lr.to(device), hr.to(device)

        # Only the mask and pre-upscale paths build an upscaled input; keep it
        # None otherwise so the "sample" metric can skip it instead of
        # raising UnboundLocalError.
        upscaled = None
        if using_mask:
            with torch.no_grad():
                factor = 4 if config.over_upscale else 1
                upscaled, mask_in = image_mask(lr, config.up_factor * factor)
            pred = model(upscaled.to(device), mask_in.to(device))
        elif config.unsupervised:
            pred = model(lr)
        elif config.pre_upscale:
            with torch.no_grad():
                upscaled = transforms.functional.resize(lr, (hr_size, hr_size))
            pred = model(upscaled)
        else:
            pred = model(lr)

        if config.loss == "VGG16Partial":
            loss, _, _ = loss_func(pred, hr)  # VGG style loss
        elif config.loss == "DISTS":
            loss = loss_func(pred, hr, require_grad=True, batch_average=True)
        else:
            loss = loss_func(pred, hr)

        if back:
            loss.backward()
            optimizer.step()
        # Tracked for every pass, so validation (back=False) is logged too.
        loss_meter.update(loss.item(), writer, step, name=config.loss)

        if not config.metrics:
            continue
        with torch.no_grad():
            for metric in config.metrics:
                tag = loss_meter.name + "/" + metric
                if metric == "PSNR":
                    writer.add_scalar(tag, losses.psnr(pred, hr), step)
                elif metric == "SSIM":
                    writer.add_scalar(tag, losses.ssim(pred, hr), step)
                elif metric == "consistency":
                    # Downscale the prediction back to LR size and compare
                    # with the input.
                    downscaled_pred = transforms.functional.resize(
                        pred, (config.lr_size, config.lr_size))
                    writer.add_scalar(
                        tag,
                        torch.nn.functional.mse_loss(downscaled_pred, lr).item(),
                        step,
                    )
                elif metric == "lr":
                    writer.add_scalar(tag, lr_scheduler.get_last_lr()[0], step)
                elif metric == "sample":
                    # Logging images runs no forward pass, so the model's
                    # train/eval mode is left untouched; toggling eval()/train()
                    # here forced validation passes into train mode.
                    if step % config.sample_step == 0:
                        writer.add_image("sample/hr", hr[0], global_step=step)
                        writer.add_image("sample/lr", lr[0], global_step=step)
                        if upscaled is not None:
                            writer.add_image("sample/bicubic", upscaled[0], global_step=step)
                        writer.add_image("sample/pred", pred[0], global_step=step)
                elif metric == "VGG16Partial":
                    val, _, _ = vgg(pred, hr)
                    writer.add_scalar(tag, val.item(), step)