def main(args):
    """Evaluate a trained LearnableDemosaick checkpoint on a dataset.

    Restores the model through ``Checkpointer.load_latest()`` on ``args.chkpt``.
    Runs the model over ``args.dataset`` in batches of 16 and, when
    ``args.save`` is set, writes every prediction to ``args.save`` as
    ``0000.png``, ``0001.png``, ... It then reports the dataset-averaged
    combined loss, PSNR, (1 - MS-SSIM) and L1 through the progress bar and
    ``DemosaickCallback.on_epoch_end(0, logs)``.

    Args:
        args: parsed command-line namespace; reads ``nfilters``, ``fsize``,
            ``dataset``, ``cuda``, ``chkpt`` and ``save``.
    """
    import torch as th  # local import: needed for the plain L1 loss below

    model = models.LearnableDemosaick(num_filters=args.nfilters, fsize=args.fsize)
    reference_model = models.NaiveDemosaick()

    dset = datasets.DemosaickingDataset(args.dataset, transform=datasets.ToTensor())
    log.info("Validating on {} with {} images".format(args.dataset, len(dset)))

    loader = DataLoader(dset, batch_size=16, num_workers=4, shuffle=False)

    if args.cuda:
        model = model.cuda()

    # Borders are cropped explicitly in the loop below, so the L1 term must not
    # crop a second time. A plain L1Loss matches the training script's metric.
    l1_fn = th.nn.L1Loss()
    msssim_fn = metrics.MSSSIM()
    alpha = 0.84  # weight of the (1 - MS-SSIM) term in the combined loss
    crop = args.fsize // 2
    # NOTE(review): psnr_fn also receives already-cropped tensors. If
    # metrics.PSNR crops by `crop` itself, PSNR borders are trimmed twice.
    # Kept as-is so the value stays comparable with training-time PSNR; confirm.
    psnr_fn = metrics.PSNR(crop=args.fsize // 2)

    env = os.path.basename(args.chkpt) + "_eval"
    checkpointer = utils.Checkpointer(args.chkpt, model, None, verbose=False)
    chkpt_name, _ = checkpointer.load_latest()
    log.info("Loading checkpoint {}.".format(chkpt_name))

    # This is pure evaluation: put the model in inference mode (affects layers
    # such as dropout / batch-norm whose behaviour depends on training mode).
    model.train(False)

    callback = demosaick.DemosaickCallback(model, reference_model, len(loader), loader, env=env)

    # Create the output directory once, not on every batch.
    if args.save is not None and not os.path.exists(args.save):
        os.makedirs(args.save)

    idx = 0  # running image index across batches, used for output filenames
    with tqdm(total=len(loader), unit=' batches') as pbar:
        pbar.set_description("Validation")
        avg = utils.Averager(["loss", "psnr", "ssim", "l1"])
        for batch in loader:
            mosaick, reference = batch
            mosaick = Variable(mosaick, requires_grad=False)
            reference = Variable(reference, requires_grad=False)
            if args.cuda:
                mosaick = mosaick.cuda()
                reference = reference.cuda()

            output = model(mosaick)

            if args.save is not None:
                for i in range(output.shape[0]):
                    # Move axis 0 (presumably channels) last and clip to [0, 1]
                    # before handing the array to skimage.
                    im = output[i].cpu().data.numpy()
                    im = np.transpose(im, [1, 2, 0])
                    im = np.clip(im, 0, 1)
                    fname = os.path.join(args.save, "{:04d}.png".format(idx))
                    idx += 1
                    skimage.io.imsave(fname, im)

            # Metrics ignore a border of fsize // 2 pixels.
            if crop > 0:
                output = output[:, :, crop:-crop, crop:-crop]
                reference = reference[:, :, crop:-crop, crop:-crop]

            ssim_ = 1 - msssim_fn(output, reference)
            l1_ = l1_fn(output, reference)
            loss = ssim_ * alpha + (1 - alpha) * l1_
            psnr = psnr_fn(output, reference)

            # Weight each batch by its size so the averages are per image.
            count = mosaick.shape[0]
            avg.update("loss", loss.data[0], count=count)
            avg.update("psnr", psnr.data[0], count=count)
            avg.update("ssim", ssim_.data[0], count=count)
            avg.update("l1", l1_.data[0], count=count)
            pbar.update(1)

        logs = {
            "loss": avg["loss"],
            "psnr": avg["psnr"],
            "ssim": avg["ssim"],
            "l1": avg["l1"],
        }
        pbar.set_postfix(logs)
        callback.on_epoch_end(0, logs)
def main(args):
    """Train a LearnableDemosaick model, validating and checkpointing every epoch.

    Each epoch runs one pass over ``args.dataset``. The training loss is
    ``alpha * (1 - MS-SSIM) + (1 - alpha) * L1``, optionally plus an L1 weight
    penalty, and exponential moving averages of the metrics are shown. The
    epoch then runs one validation pass over ``args.val_dataset`` with
    per-image averaged metrics, reports through ``DemosaickCallback`` and
    calls ``Checkpointer.on_epoch_end``.

    Args:
        args: parsed command-line namespace; reads ``nfilters``, ``fsize``,
            ``output``, ``dataset``, ``val_dataset``, ``batch_size``,
            ``cuda``, ``lr``, ``alpha``, ``regularize``, ``chkpt``,
            ``num_epochs`` and ``viz_step``.
    """
    model = models.LearnableDemosaick(num_filters=args.nfilters, fsize=args.fsize)
    reference_model = models.NaiveDemosaick()

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    dset = datasets.DemosaickingDataset(args.dataset, transform=datasets.ToTensor())
    val_dset = datasets.DemosaickingDataset(args.val_dataset, transform=datasets.ToTensor())
    log.info("Training on {} with {} images".format(args.dataset, len(dset)))
    log.info("Validating on {} with {} images".format(args.val_dataset, len(val_dset)))

    loader = DataLoader(dset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    val_loader = DataLoader(val_dset, batch_size=args.batch_size)

    if args.cuda:
        model = model.cuda()

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    l1_fn = th.nn.L1Loss()
    msssim_fn = metrics.MSSSIM()
    alpha = args.alpha  # weight of the (1 - MS-SSIM) term in the combined loss
    crop = args.fsize // 2  # border (in pixels) excluded from the losses
    psnr_fn = metrics.PSNR(crop=args.fsize // 2)
    reg_w = 1e-6  # weight of the optional L1 penalty on all model parameters

    env = os.path.basename(args.output)
    checkpointer = utils.Checkpointer(args.output, model, optimizer, verbose=False, interval=600)
    callback = demosaick.DemosaickCallback(model, reference_model, len(loader), val_loader, env=env)

    if args.regularize:
        log.info("Using L1 weight regularization")

    # An explicit checkpoint restores weights only (optimizer state ignored);
    # otherwise resume from whatever the checkpointer finds in args.output.
    if args.chkpt is not None:
        log.info("Loading checkpoint {}".format(args.chkpt))
        checkpointer.load_checkpoint(args.chkpt, ignore_optim=True)
    else:
        chkpt_name, _ = checkpointer.load_latest()
        log.info("Resuming from latest checkpoint {}.".format(chkpt_name))

    def forward(batch):
        """Run the model on one (mosaick, reference) batch.

        Returns ``(batch_size, output, reference, ssim_, l1_, loss)``, where
        ``output`` and ``reference`` have their ``crop`` border removed and
        ``ssim_`` is ``1 - MS-SSIM``.
        """
        mosaick, reference = batch
        mosaick = Variable(mosaick, requires_grad=False)
        reference = Variable(reference, requires_grad=False)
        if args.cuda:
            mosaick = mosaick.cuda()
            reference = reference.cuda()

        output = model(mosaick)

        if crop > 0:
            output = output[:, :, crop:-crop, crop:-crop]
            reference = reference[:, :, crop:-crop, crop:-crop]

        ssim_ = 1 - msssim_fn(output, reference)
        l1_ = l1_fn(output, reference)
        loss = ssim_ * alpha + (1 - alpha) * l1_
        return mosaick.shape[0], output, reference, ssim_, l1_, loss

    ema = utils.ExponentialMovingAverage(["loss", "psnr", "ssim", "l1", "psnr_g"])
    for epoch in range(args.num_epochs):
        # Training
        model.train(True)
        with tqdm(total=len(loader), unit=' batches') as pbar:
            pbar.set_description("Epoch {}/{}".format(epoch + 1, args.num_epochs))
            for batch_id, batch in enumerate(loader):
                optimizer.zero_grad()
                _, output, reference, ssim_, l1_, loss = forward(batch)

                if args.regularize:
                    # sum() starts from 0, so a parameterless model adds nothing
                    # instead of failing on a None accumulator.
                    loss = loss + reg_w * sum(p.norm(1) for p in model.parameters())

                loss.backward()
                optimizer.step()

                psnr = psnr_fn(output, reference)
                # Channel index 1 is reported as "psnr_g" (presumably green).
                psnr_green = psnr_fn(output[:, 1, ...], reference[:, 1, ...])
                ema.update("loss", loss.data[0])
                ema.update("psnr", psnr.data[0])
                ema.update("psnr_g", psnr_green.data[0])
                ema.update("ssim", ssim_.data[0])
                ema.update("l1", l1_.data[0])
                logs = {
                    "loss": ema["loss"],
                    "psnr": ema["psnr"],
                    "psnr_g": ema["psnr_g"],
                    "ssim": ema["ssim"],
                    "l1": ema["l1"],
                }
                pbar.set_postfix(logs)
                pbar.update(1)
                if pbar.n % args.viz_step == 0:
                    callback.on_batch_end(batch_id, logs)
                    callback.show_val_batch()
                # Checkpointer was built with interval=600; it decides when to save.
                checkpointer.periodic_checkpoint(epoch)

        # Validation
        model.train(False)
        with tqdm(total=len(val_loader), unit=' batches') as pbar:
            pbar.set_description("Epoch {}/{} (val)".format(
                epoch + 1, args.num_epochs))
            avg = utils.Averager(["loss", "psnr", "ssim", "l1"])
            for batch in val_loader:
                count, output, reference, ssim_, l1_, loss = forward(batch)
                psnr = psnr_fn(output, reference)
                # Weight by batch size so the averages are per image.
                avg.update("loss", loss.data[0], count=count)
                avg.update("psnr", psnr.data[0], count=count)
                avg.update("ssim", ssim_.data[0], count=count)
                avg.update("l1", l1_.data[0], count=count)
                pbar.update(1)
            logs = {
                "loss": avg["loss"],
                "psnr": avg["psnr"],
                "ssim": avg["ssim"],
                "l1": avg["l1"],
            }
            pbar.set_postfix(logs)
            callback.on_epoch_end(epoch, logs)

        # save
        checkpointer.on_epoch_end(epoch)
def main(args):
    """Train DeconvNonlinearCG indefinitely, validating every 20 iterations.

    Each iteration performs one optimizer step on one training batch, restarting
    the loader whenever it is exhausted, and logs exponentially smoothed
    loss/PSNR. Every 20 iterations the model and a reference
    ``DeconvNonlinearCG(ref=True)`` are evaluated on the whole validation set.
    The results are reported through ``DeconvCallback.on_validation_end``, the
    checkpointer is invoked, and a "best" checkpoint is written whenever
    validation PSNR improves. The loop never terminates on its own.

    Args:
        args: parsed command-line namespace; reads ``output``, ``dataset``,
            ``val_dataset``, ``cuda``, ``lr`` and ``batch_size``.
    """
    model = models.DeconvNonlinearCG(num_stages=3)
    ref_model = models.DeconvNonlinearCG(ref=True)
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    dset = datasets.DeconvDataset(args.dataset)
    val_dset = datasets.DeconvDataset(args.val_dataset, is_validate=True)
    log.info("Training on {} with {} images".format(args.dataset, len(dset)))
    log.info("Validating on {} with {} images".format(args.val_dataset, len(val_dset)))

    if args.cuda:
        model = model.cuda()
        ref_model = ref_model.cuda()

    print("Training parameters:")
    params_to_train = []
    for n, p in model.named_parameters():
        print(" -", n)
        params_to_train.append(p)
    optimizer = th.optim.Adam(params_to_train, lr=args.lr)

    loss_fn = metrics.CroppedL1Loss(crop=16)
    psnr_fn = metrics.PSNR(crop=16)

    loader = DataLoader(dset, batch_size=args.batch_size, num_workers=1, shuffle=True)
    val_loader = DataLoader(val_dset, batch_size=8)

    checkpointer = utils.Checkpointer(args.output, model, optimizer, verbose=True)

    callback = DeconvCallback(model, ref_model, val_loader, args.cuda, env="gapps_deconv")

    smooth_loss = 0
    smooth_psnr = 0
    ema = 0.9  # smoothing factor for the displayed training loss/PSNR
    chkpt_name, iteration = checkpointer.load_latest()
    log.info("Resuming from latest checkpoint {}.".format(chkpt_name))
    train_iterator = iter(loader)
    best_psnr = 0.0
    first = True

    def to_variables(batch):
        """Wrap a (blurred, reference, kernel) batch as Variables on the target device."""
        blurred, reference, kernel = batch
        blurred = Variable(blurred, requires_grad=False)
        reference = Variable(reference, requires_grad=False)
        kernel = Variable(kernel, requires_grad=False)
        if args.cuda:
            blurred = blurred.cuda()
            reference = reference.cuda()
            kernel = kernel.cuda()
        return blurred, reference, kernel

    # NOTE(review): cg_iter and ref_cg_iter are not defined in this function;
    # presumably module-level constants elsewhere in the file. Confirm.
    while True:
        # Training: fetch the next batch, restarting the loader when exhausted
        # (the loader reshuffles on each new iterator since shuffle=True).
        try:
            batch = next(train_iterator)
        except StopIteration:
            train_iterator = iter(loader)
            batch = next(train_iterator)

        model.train(True)
        blurred, reference, kernel = to_variables(batch)

        output = model(blurred, kernel, cg_iter)

        optimizer.zero_grad()
        loss = loss_fn(output, reference)
        loss.backward()
        optimizer.step()
        # Overwrites reg_powers after every step, so the optimizer's update to
        # it is discarded (presumably to keep it frozen at 2.0).
        model.reg_powers.data.fill_(2.0)

        psnr = psnr_fn(output, reference)

        # Exponential smoothing of the training curves, seeded by the first batch.
        if first:
            smooth_loss = loss.data[0]
            smooth_psnr = psnr.data[0]
            first = False
        else:
            smooth_loss = ema * smooth_loss + (1 - ema) * loss.data[0]
            smooth_psnr = ema * smooth_psnr + (1 - ema) * psnr.data[0]

        print('loss: {}, psnr: {}'.format(smooth_loss, smooth_psnr))

        model.train(False)
        ref_model.train(False)
        logs = {"loss": smooth_loss, "psnr": smooth_psnr}
        callback.on_iteration_end(iteration, logs)

        if iteration % 20 == 0:
            # Validation over the whole validation set.
            total_loss = 0
            total_psnr = 0
            total_ref_loss = 0
            total_ref_psnr = 0
            n_seen = 0
            for batch in val_loader:
                blurred, reference, kernel = to_variables(batch)
                # Weight by the real batch size: val_loader uses batch_size=8
                # (not args.batch_size) and its last batch may be smaller.
                count = reference.size(0)

                output = model(blurred, kernel, ref_cg_iter)
                batch_loss = loss_fn(output, reference)
                batch_psnr = psnr_fn(output, reference)
                ref_output = ref_model(blurred, kernel, ref_cg_iter)
                batch_ref_loss = loss_fn(ref_output, reference)
                batch_ref_psnr = psnr_fn(ref_output, reference)

                total_loss += batch_loss.data[0] * count
                total_psnr += batch_psnr.data[0] * count
                total_ref_loss += batch_ref_loss.data[0] * count
                total_ref_psnr += batch_ref_psnr.data[0] * count
                n_seen += count

            val_loss = total_loss / n_seen
            val_psnr = total_psnr / n_seen
            ref_loss = total_ref_loss / n_seen
            ref_psnr = total_ref_psnr / n_seen
            logs = {
                "val_loss": val_loss,
                "val_psnr": val_psnr,
                "ref_loss": ref_loss,
                "ref_psnr": ref_psnr,
            }
            callback.on_validation_end(iteration, logs)

            # save
            checkpointer.on_epoch_end(iteration)

            # save best
            if val_psnr > best_psnr:
                filename = 'epoch_{:03d}_best.pth.tar'.format(iteration + 1)
                checkpointer.save_checkpoint(iteration, filename)
                best_psnr = val_psnr

        iteration += 1