def visualized_image(self, batch, fwd_output):
    """Build a side-by-side visualization for one demosaicking batch.

    The mosaic, the reference image, the network prediction and an amplified
    absolute-error map are concatenated along dimension 2 and clamped to
    [0, 1].

    Args:
        batch: a ``(mosaic, target)`` pair of tensors.
        fwd_output(th.Tensor): the network prediction for ``batch``.

    Returns:
        th.Tensor: the clamped visualization, on the CPU and detached.
    """
    prediction = fwd_output.cpu().detach()
    raw_mosaic, reference = batch

    # crop_like is assumed to crop its first argument to the spatial size of
    # the second (the prediction may be smaller than the inputs) — TODO confirm.
    raw_mosaic = crop_like(raw_mosaic.cpu().detach(), prediction)
    reference = crop_like(reference.cpu().detach(), prediction)

    # Amplify the error 4x so small differences remain visible.
    error = (prediction - reference).abs() * 4

    panels = th.cat([raw_mosaic, reference, prediction, error], 2)
    return th.clamp(panels, 0, 1)
def backward(self, batch, fwd):
    """Run one optimization step on the predicted radiance.

    Args:
        batch(dict): must contain "target_image".
        fwd(dict): forward output; "radiance" holds the prediction.

    Returns:
        (dict) with the scalar "loss" and "rmse" for this step.

    Raises:
        RuntimeError: if the loss is NaN or infinite (likely data outliers).
    """
    self.optimizer.zero_grad()
    out = fwd["radiance"]
    tgt = crop_like(batch["target_image"], out)  # make sure sizes match
    loss = self.loss_fn(out, tgt)

    # Couple checks to pick up on outliers in the data. They run before the
    # backward pass so no work is wasted on an unusable loss. NaN is tested
    # first: np.isfinite(nan) is False, so the infinity test would otherwise
    # also catch NaNs and the NaN branch could never be reached.
    loss_value = loss.item()
    if np.isnan(loss_value):
        LOG.error("NaN in the loss, there might be outliers in the data.")
        raise RuntimeError("NaN loss at train time.")
    if not np.isfinite(loss_value):
        LOG.error("Loss is infinite, there might be outliers in the data.")
        raise RuntimeError("Infinite loss at train time.")

    loss.backward()

    clip = 1000
    # clip_grad_norm_ returns the total norm measured *before* clipping.
    actual = th.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
    if actual > clip:
        LOG.info("Clipped gradients {} -> {}".format(float(actual), clip))
    self.optimizer.step()

    with th.no_grad():
        rmse = self.rmse_fn(out, tgt)

    return {"loss": loss_value, "rmse": rmse.item()}
def backward(self, tgt, fwd):
    """Take one optimizer step matching ``fwd`` against ``tgt``.

    Args:
        tgt(th.Tensor): target image; moved to the GPU when ``self.cuda``.
        fwd(th.Tensor): model output; the target is cropped to its size.

    Returns:
        float: the scalar loss value for this step.
    """
    target = tgt.cuda() if self.cuda else tgt
    self.optimizer.zero_grad()
    matched_target = crop_like(target, fwd)
    loss = self.loss_fn(fwd, matched_target)
    loss.backward()
    self.optimizer.step()
    return loss.item()
def forward(self, data):
    """Forward pass of the model.

    The diffuse and specular branches each predict per-pixel kernels that are
    applied to the corresponding noisy buffer; the filtered components are
    then recombined into the final radiance.

    Args:
        data(dict) with keys:
            "kpcn_diffuse_in": input to the diffuse kernel predictor
                (``self.diffuse``).
            "kpcn_specular_in": input to the specular kernel predictor
                (``self.specular``).
            "kpcn_diffuse_buffer": diffuse buffer filtered by the diffuse
                kernels.
            "kpcn_specular_buffer": specular buffer filtered by the specular
                kernels.
            "kpcn_albedo": albedo multiplied back onto the filtered diffuse.

    Returns:
        (dict) with keys:
            "radiance": albedo * diffuse + exp(specular) - 1.
            "diffuse": filtered diffuse, before albedo re-multiplication.
            "specular": filtered specular, before the exp(x) - 1 inverse
                transform (presumably a log(1 + x) domain — TODO confirm).
    """
    # Each component has its own kernel-predicting network.
    diffuse_kernels = self.diffuse(data["kpcn_diffuse_in"])
    specular_kernels = self.specular(data["kpcn_specular_in"])

    # Bring the noisy buffers to the spatial size of the predicted kernels.
    diffuse_buffer = crop_like(
        data["kpcn_diffuse_buffer"], diffuse_kernels).contiguous()
    specular_buffer = crop_like(
        data["kpcn_specular_buffer"], specular_kernels).contiguous()

    # Filter each buffer with its kernels (second result is not needed).
    diffuse_filtered, _ = self.kernel_apply(diffuse_buffer, diffuse_kernels)
    specular_filtered, _ = self.kernel_apply(specular_buffer, specular_kernels)

    # Undo the per-component transforms and sum the two contributions.
    albedo = crop_like(data["kpcn_albedo"], diffuse_filtered)
    radiance = albedo * diffuse_filtered + (th.exp(specular_filtered) - 1)

    return {
        "radiance": radiance,
        "diffuse": diffuse_filtered,
        "specular": specular_filtered,
    }
def visualized_image(self, batch, fwd_result):
    """Assemble a tonemapped display gallery for one batch.

    The low-spp input, the prediction, the reference and their absolute
    difference are concatenated along dimension -2, then tonemapped with
    x / (1 + x), gamma-encoded with exponent 1 / 2.2 and clamped to [0, 1].

    Args:
        batch(dict): must contain "low_spp" and "target_image".
        fwd_result(dict): forward output; "radiance" is the prediction.

    Returns:
        th.Tensor: the gallery image (detached).
    """
    prediction = fwd_result["radiance"].detach()
    # Crop the inputs so every panel has the prediction's size.
    noisy = crop_like(batch["low_spp"].detach(), prediction)
    reference = crop_like(batch["target_image"].detach(), prediction)

    error = (prediction - reference).abs()
    gallery = th.cat([noisy, prediction, reference, error], -2)

    # Negative values are clipped before the x / (1 + x) tonemap.
    gallery = th.clamp(gallery, 0)
    gallery = gallery / (1 + gallery)
    gallery = gallery.pow(1.0 / 2.2)
    return gallery.clamp(0, 1)
def update_validation(self, batch, fwd_output, running_data):
    """Accumulate batch-size-weighted PSNR for validation.

    Args:
        batch: sequence whose element 1 is the target image.
        fwd_output(th.Tensor): model prediction; clamped to [0, 1] before
            PSNR is computed.
        running_data(dict): running "psnr" sum and sample "count".

    Returns:
        (dict) with the updated "psnr" sum and "count".
    """
    # Crop the target to the (possibly smaller) output.
    reference = crop_like(batch[1].to(self.device), fwd_output)
    with th.no_grad():
        batch_psnr = self.psnr(fwd_output.clamp(0, 1), reference)
    batch_size = reference.shape[0]
    return {
        "psnr": running_data["psnr"] + batch_psnr.item() * batch_size,
        "count": running_data["count"] + batch_size,
    }
def backward(self, batch, fwd_output):
    """Take one optimizer step and report loss and PSNR.

    Args:
        batch: sequence whose element 1 is the target image.
        fwd_output(th.Tensor): model prediction for ``batch``.

    Returns:
        (dict) with the scalar "loss" and "psnr" (PSNR of the clamped output).
    """
    # Crop the target to the (possibly smaller) output.
    reference = crop_like(batch[1].to(self.device), fwd_output)
    loss = self.loss(fwd_output, reference)

    optimizer = self.opt
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    with th.no_grad():
        batch_psnr = self.psnr(fwd_output.clamp(0, 1), reference)
    return {"loss": loss.item(), "psnr": batch_psnr.item()}
def update_validation(self, batch, fwd, running):
    """Updates running statistics for the validation.

    Args:
        batch(dict): must contain "target_image".
        fwd(dict): forward output; "radiance" holds the prediction.
        running(dict): current statistics: running means "loss" and "rmse",
            and "n", the number of samples seen so far.

    Returns:
        (dict) with the updated "loss", "rmse" and "n".
    """
    with th.no_grad():
        out = fwd["radiance"]
        tgt = crop_like(batch["target_image"], out)
        loss = self.loss_fn(out, tgt).item()
        rmse = self.rmse_fn(out, tgt).item()

    # Weight each batch by its size so the statistics stay exact means even
    # when batch sizes vary (e.g. a smaller final batch). Assumes loss_fn and
    # rmse_fn return per-batch means — TODO confirm.
    b = out.shape[0]
    n = running["n"] + b

    # Incremental mean over b new samples: m_new = m_old + (b / n) * (x - m_old)
    new_loss = running["loss"] + (b / n) * (loss - running["loss"])
    new_rmse = running["rmse"] + (b / n) * (rmse - running["rmse"])
    return {"loss": new_loss, "rmse": new_rmse, "n": n}
def forward(self, data, coords):
    """Filter the samples of a single image with predicted per-pixel kernels.

    Args:
        data(th.Tensor): sample values; dim 0 is the batch and must be 1.
            They are compressed with log10(1 + data / 255) before entering
            the network (presumably 8-bit-range values — TODO confirm).
        coords(th.Tensor): per-sample coordinates, concatenated with the
            compressed data along dim 2.

    Returns:
        output(th.Tensor): filtered samples averaged over dim 1 (the sample
            axis), with a (ksize - 1) // 2 border removed on each spatial side.
        kviz(th.Tensor): kernels min-max normalized to [0, 1] and reshaped to
            (-, ksize, ksize, h, w) for display.
    """
    in_ = th.cat([th.log10(1.0 + data / 255.0), coords], 2).contiguous()
    assert in_.shape[0] == 1, "current implementation assumes batch_size = 1"

    kernels = self.net(in_.squeeze(0))
    cdata = crop_like(data.squeeze(0), kernels).contiguous()
    output, _ = self.kernel_update(cdata, kernels)

    # Average over samples
    output = th.unsqueeze(output, 0).mean(1)

    # Remove the kernel-radius border. Explicit end indices keep this correct
    # for ksize == 1, where a "k:-k" slice would return an empty tensor.
    k = (self.ksize - 1) // 2
    out_h, out_w = output.shape[-2:]
    output = output[..., k:out_h - k, k:out_w - k]

    kviz = kernels.detach().clone()
    min_ = kviz.min()
    max_ = kviz.max()
    # Epsilon guards against a zero range (constant kernels); it must be added
    # so the denominator stays strictly positive.
    kviz = (kviz - min_) / (max_ - min_ + 1e-8)
    bs, _, h, w = kviz.shape
    return output, kviz.view(bs, self.ksize, self.ksize, h, w)
def main(args):
    """Compare gather and scatter kernel-predicting models on an AA dataset.

    Three ForwardModel instances are optimized side by side on the same
    batches, for at most ``args.nsteps`` steps:
    - a gather model;
    - a deeper/wider gather model;
    - a scatter model.

    The final images, the per-pixel kernels and the loss curves are written
    to ``args.output``.

    Raises:
        RuntimeError: if the data loader yields no batch.
    """
    dataset = AADataset(args.input, ds=args.ds, spp=args.spp, sigma=args.sigma,
                        size=args.size, outliers=args.outliers,
                        outliers_p=args.outliers_p)
    loader = DataLoader(dataset, batch_size=1, num_workers=4)

    # Kernel optimization -------------
    # initialize everything to a box filter
    th.manual_seed(0)
    width = 32
    depth = 3
    gather_mdl = ForwardModel(dataset.h_lr, dataset.w_lr, ksize=args.ksize,
                              depth=depth, width=width, scatter=False)
    th.manual_seed(0)
    # beefier gather model
    b_gather_mdl = ForwardModel(dataset.h_lr, dataset.w_lr, ksize=args.ksize,
                                depth=depth * args.depth_factor,
                                width=width * args.width_factor, scatter=False)
    th.manual_seed(0)
    scatter_mdl = ForwardModel(dataset.h_lr, dataset.w_lr, ksize=args.ksize,
                               depth=depth, width=width, scatter=True)

    gather_interface = OptimizerInterface(gather_mdl)
    b_gather_interface = OptimizerInterface(b_gather_mdl)
    scatter_interface = OptimizerInterface(scatter_mdl)

    # optimize the kernels
    all_losses = np.zeros((3, args.nsteps))
    step = 0
    for step, batch in enumerate(loader):
        data, subpixel_coords, target, mask = batch
        gather_result, gather_k = gather_interface.forward(
            data, subpixel_coords)
        loss_gather = gather_interface.backward(target, gather_result)
        b_gather_result, b_gather_k = b_gather_interface.forward(
            data, subpixel_coords)
        loss_b_gather = b_gather_interface.backward(target, b_gather_result)
        scatter_result, scatter_k = scatter_interface.forward(
            data, subpixel_coords)
        loss_scatter = scatter_interface.backward(target, scatter_result)

        all_losses[0, step] = loss_gather
        all_losses[1, step] = loss_b_gather
        all_losses[2, step] = loss_scatter

        # From here on `step` counts completed iterations.
        step += 1
        if step == args.nsteps:
            break

        if step % 10 == 0:
            print(
                "{:05d} | Gather = {:.4f}, Gather(big) = {:.4f}, Scatter = {:.4f}, gather/scatter = {:.2f} gather(big)/scatter = {:.2f}"
                .format(step, loss_gather, loss_b_gather, loss_scatter,
                        loss_gather / loss_scatter,
                        loss_b_gather / loss_scatter))

    if step == 0:
        # Nothing below can run without at least one optimized batch.
        raise RuntimeError("The data loader produced no batches.")

    # ---------------------------------
    def _kernel_viz(kernels):
        # Kernel mosaic from kernels2im as a numpy array clipped to [0, 1].
        return np.clip(kernels2im(kernels).detach().cpu().numpy(), 0, 1)

    gather_k_viz = _kernel_viz(gather_k)
    b_gather_k_viz = _kernel_viz(b_gather_k)
    scatter_k_viz = _kernel_viz(scatter_k)

    # Stack the per-sample mosaics vertically. The scatter model shares
    # ksize with the small gather model, so it reuses that model's h and w.
    spp, h, w = gather_k_viz.shape
    gather_k_viz = gather_k_viz.reshape([spp * h, w])
    spp, b_h, b_w = b_gather_k_viz.shape
    b_gather_k_viz = b_gather_k_viz.reshape([spp * b_h, b_w])
    scatter_k_viz = scatter_k_viz.reshape([spp * h, w])

    # Conversion to numpy -----------------
    def _to_uint8_image(tensor):
        # First batch element clamped to [0, 255], as an (h, w, c) uint8 array.
        return th.clamp(tensor.detach()[0], 0, 255).permute(
            1, 2, 0).cpu().numpy().astype(np.uint8)

    # Keep the tensor result: its size is the crop reference for the kernels.
    gres_c = gather_result.clone()
    mask = crop_like(mask.permute(0, 1, 4, 2, 3).max(1)[0], gather_result)
    mask = mask[0, 0]
    im = _to_uint8_image(crop_like(data.mean(1), gather_result))
    target = _to_uint8_image(crop_like(target, gather_result))
    gather_result = _to_uint8_image(gather_result)
    b_gather_result = _to_uint8_image(b_gather_result)
    scatter_result = _to_uint8_image(scatter_result)

    # Output -------------
    os.makedirs(args.output, exist_ok=True)
    for fname, image in [
            ("0_input.png", im),
            ("0b_mask.png", mask),
            ("3_gather.png", gather_result),
            ("3_gather_big.png", b_gather_result),
            ("4_target.png", target),
            ("5_scatter.png", scatter_result),
    ]:
        skio.imsave(os.path.join(args.output, fname), image)

    kdict = {
        "gather_kernels": gather_k,
        "b_gather_kernels": b_gather_k,
        "scatter_kernels": scatter_k,
    }
    h, w = mask.shape
    for name, raw_kernels in kdict.items():
        kernels = crop_like(raw_kernels, gres_c)
        # Shift by the global max before exp for numerical stability; done
        # out-of-place so the tensors in kdict are left untouched.
        kernels = th.exp(kernels - kernels.max())
        for s in range(args.spp):
            d = os.path.join(args.output, name, "spp%02d" % s)
            os.makedirs(d, exist_ok=True)
            for y in range(h):
                for x in range(w):
                    kernel = kernels[s, :, :, y, x].detach().cpu().numpy()
                    suff = "outlier" if mask[y, x] == 1 else ""
                    skio.imsave(
                        os.path.join(d, "y%02d_x%02d%s.png" % (y, x, suff)),
                        kernel)

    for fname, image in [
            ("6_gather_kernels.png", gather_k_viz),
            ("6_b_gather_kernels.png", b_gather_k_viz),
            ("7_scatter_kernels.png", scatter_k_viz),
    ]:
        skio.imsave(os.path.join(args.output, fname), image)

    ax = plt.subplot(111)
    # Only plot the steps that actually ran; unused columns of all_losses are
    # zeros, which a log-scale axis cannot display meaningfully.
    plt.plot(all_losses[:, :step].T)
    plt.legend(["gather", "gather(big)", "scatter"])
    plt.title("loss vs. step")
    plt.xlabel("optim step")
    plt.ylabel("MSE")
    try:
        # matplotlib >= 3.3 spelling; `nonposy` was removed in later releases.
        ax.set_yscale("log", nonpositive="clip")
    except (TypeError, ValueError):
        # Older matplotlib only accepts the per-axis keyword.
        ax.set_yscale("log", nonposy="clip")
    path = os.path.join(args.output, "8_loss.png")
    plt.savefig(path)
    path = os.path.join(args.output, "8_loss.pdf")
    plt.savefig(path)
def main(args):
    """Evaluate a trained demosaicking checkpoint on the test subset.

    Loads the latest checkpoint in ``args.checkpoint_dir``, runs the model
    over every image of the demosaicnet test subset found in ``args.data``
    and logs per-image and average PSNR / MSE.

    Args:
        args: parsed command-line arguments; ``checkpoint_dir`` and ``data``
            are read.
    """
    # Load model parameters from checkpoint, if any
    meta = ttools.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        LOG.warning("No checkpoint found at %s, aborting.", args.checkpoint_dir)
        return

    data = demosaicnet.Dataset(args.data, download=False, mode=meta["mode"],
                               subset=demosaicnet.TEST_SUBSET)
    # A fixed order keeps the logged image indices reproducible across runs.
    dataloader = DataLoader(data, batch_size=1, num_workers=4,
                            pin_memory=True, shuffle=False)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True, pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True, pad=False)
    else:
        # Without this branch `model` would be unbound below.
        LOG.error("Unsupported demosaicking mode %s, aborting.", meta["mode"])
        return

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)
    checkpointer.load_latest()  # Resume from checkpoint, if any.

    # No need for gradients
    for p in model.parameters():
        p.requires_grad = False

    mse_fn = th.nn.MSELoss()
    psnr_fn = PSNR()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")
    # The inputs are moved to `device` below, so the weights must be too.
    model.to(device)
    model.eval()

    count = 0
    mse = 0.0
    psnr = 0.0
    with th.no_grad():
        for idx, batch in enumerate(dataloader):
            mosaic = batch[0].to(device)
            target = batch[1].to(device)
            output = model(mosaic)

            # The unpadded model output may be smaller than the target.
            target = crop_like(target, output)
            output = th.clamp(output, 0, 1)

            psnr_ = psnr_fn(output, target).item()
            mse_ = mse_fn(output, target).item()

            psnr += psnr_
            mse += mse_
            count += 1

            LOG.info("Image %04d, PSNR = %.1f dB, MSE = %.5f", idx, psnr_, mse_)

    if count == 0:
        LOG.warning("No test images found in %s.", args.data)
        return

    mse /= count
    psnr /= count

    LOG.info("-----------------------------------")
    LOG.info("Average, PSNR = %.1f dB, MSE = %.5f", psnr, mse)
def forward(self, samples):
    """Forward pass of the model.

    Args:
        samples(dict) with keys:
            "radiance": (th.Tensor[bs, spp, 3, h, w]) sample radiance.
            "features": (th.Tensor[bs, spp, nf, h, w]) sample features.
            "global_features": (th.Tensor[bs, ngf, 1, 1]) global features;
                they are tiled over the (h, w) grid here, so a singleton
                spatial size is expected — TODO confirm.

    Returns:
        (dict) with keys:
            "radiance": (th.Tensor[bs, 3, h', w']) denoised radiance, with a
                (ksize - 1) // 2 border removed on each spatial side.
    """
    radiance = samples["radiance"]
    features = samples["features"]
    gfeatures = samples["global_features"].cuda()

    if self.pixel:
        # Make the pixel-average look like one sample
        radiance = radiance.mean(1, keepdim=True)
        features = features.mean(1, keepdim=True)

    bs, spp, nf, h, w = features.shape

    modules = {n: m for (n, m) in self.named_modules()}

    # In eval mode, samples are embedded one at a time and the embeddings are
    # stored in a tensor allocated by th.zeros (default device), keeping GPU
    # memory bounded for large images.
    limit_memory_usage = not self.training

    # -- Embed the samples then collapse to pixel-wise summaries ----------
    if limit_memory_usage:
        gf = gfeatures.repeat([1, 1, h, w])
        new_features = th.zeros(bs, spp, self.embedding_width, h, w)
    else:
        # Tile per batch element so the layout matches
        # features.view(bs * spp, ...), which is batch-major. A plain
        # repeat([spp, 1, h, w]) is sample-major and would pair samples with
        # another image's global features whenever bs > 1.
        gf = gfeatures.unsqueeze(1).repeat([1, spp, 1, h, w]).view(
            bs * spp, -1, h, w)

    for step in range(self.nsteps):
        if limit_memory_usage:
            # Go through the samples one by one to preserve memory for
            # large images
            for sp in range(spp):
                f = features[:, sp].cuda()
                if step == 0:  # Global features at first iteration only
                    f = th.cat([f, gf], 1)
                else:
                    f = th.cat([f, propagated], 1)

                f = modules["embedding_{:02d}".format(step)](f)

                new_features[:, sp].copy_(f, non_blocking=True)

                # Running sum of the embeddings; divided by spp below.
                if sp == 0:
                    reduced = f
                else:
                    reduced.add_(f)

                del f
                th.cuda.empty_cache()

            features = new_features
            reduced.div_(spp)
            th.cuda.empty_cache()
        else:
            flat = features.view([bs * spp, nf, h, w])
            if step == 0:  # Global features at first iteration only
                flat = th.cat([flat, gf], 1)
            else:
                flat = th.cat([
                    flat,
                    propagated.unsqueeze(1).repeat([1, spp, 1, 1, 1]).view(
                        spp * bs, self.width, h, w)
                ], 1)
            flat = modules["embedding_{:02d}".format(step)](flat)
            flat = flat.view(bs, spp, self.embedding_width, h, w)
            reduced = flat.mean(1)

            features = flat
            nf = self.embedding_width

        # Propagate spatially the pixel context
        propagated = modules["propagation_{:02d}".format(step)](reduced)

        if limit_memory_usage:
            del reduced
            th.cuda.empty_cache()

    # Predict kernels based on the context information and
    # the current sample's features
    sum_r, sum_w, max_w = None, None, None

    for sp in range(spp):
        f = features[:, sp].cuda()
        f = th.cat([f, propagated], 1)
        r = radiance[:, sp].cuda()
        kernels = self.kernel_regressor(f)
        if limit_memory_usage:
            th.cuda.empty_cache()

        # Update radiance estimate
        sum_r, sum_w, max_w = self.kernel_update(
            crop_like(r, kernels), kernels, sum_r, sum_w, max_w)
        if limit_memory_usage:
            th.cuda.empty_cache()

    # Normalize output with the running sum
    output = sum_r / (sum_w + self.eps)

    # Remove the invalid boundary data. Explicit end indices keep this correct
    # for ksize == 1, where a "crop:-crop" slice would be empty.
    crop = (self.ksize - 1) // 2
    out_h, out_w = output.shape[-2:]
    output = output[..., crop:out_h - crop, crop:out_w - crop]

    return {"radiance": output}
def main(args):
    """Denoise the first scene and dump kernel visualizations for a crop.

    Loads the latest checkpoint from ``args.checkpoint``, runs the model on
    the first image of ``args.data`` while collecting its per-step kernels,
    and writes the low-spp input, target, output, kernel mosaic and kernel
    standard deviation for a fixed 128x128 crop into ``args.output``.
    """
    log.info("Loading model {}".format(args.checkpoint))
    meta_params = ttools.Checkpointer.load_meta(args.checkpoint)

    spp = meta_params["spp"]
    use_p = meta_params["use_p"]
    use_ld = meta_params["use_ld"]
    use_bt = meta_params["use_bt"]

    log.info("Rendering at {} spp".format(spp))

    log.info("Setting up dataloader, p:{} bt:{} ld:{}".format(use_p, use_bt, use_ld))
    data = dset.FullImageDataset(args.data, dset.RenderDataset, spp=spp,
                                 use_p=use_p, use_ld=use_ld, use_bt=use_bt)
    preprocessor = pre.get(meta_params["preprocessor"])(data)
    xforms = transforms.Compose([dset.ToTensor(), preprocessor])
    data.transform = xforms
    dataloader = DataLoader(data, batch_size=1, shuffle=False, num_workers=0,
                            pin_memory=True)

    model = models.get(preprocessor, meta_params["model_params"])
    model.cuda()
    model.train(False)

    checkpointer = ttools.Checkpointer(args.checkpoint, model, None)
    extras, meta = checkpointer.load_latest()
    log.info("Loading latest checkpoint {}".format("failed" if meta is None else "success"))

    # Make sure the destination exists before saving images into it.
    os.makedirs(args.output, exist_ok=True)

    # Crop window: top-left corner (cx, cy) and side length c.
    cx = 70
    cy = 20
    c = 128

    for scene_id, batch in enumerate(dataloader):
        batch_v = make_variable(batch, cuda=True)
        with th.no_grad():
            klist = []  # filled by the model with per-step kernel dicts
            out_ = model(batch_v, kernel_list=klist)
        lowspp = batch["radiance"]
        target = batch["target_image"]
        out = out_["radiance"]

        target = crop_like(target, out)
        lowspp = crop_like(lowspp.squeeze(), out)
        lowspp = lowspp[..., cy:cy+c, cx:cx+c]

        # Lay the samples side by side horizontally (assumes the squeezed
        # radiance is (spp, chan, h, w) — TODO confirm).
        lowspp = lowspp.permute(1, 2, 0, 3)
        chan, h, n_samples, w = lowspp.shape
        lowspp = lowspp.contiguous().view(chan, h, n_samples * w)

        # Normalize each step's kernels by the final running maximum weight.
        maxi = crop_like(klist[-1]["max_w"].unsqueeze(1), out)
        # NOTE(review): k["updated_kernels"] is also reported by the model but
        # only the raw kernels are visualized here.
        kernels = [th.exp(crop_like(k["kernels"], out) - maxi) for k in klist]

        out = out[..., cy:cy+c, cx:cx+c]
        target = target[..., cy:cy+c, cx:cx+c]
        kernels = [k[..., cy:cy+c, cx:cx+c] for k in kernels]

        u_kernels_im = viz.kernels2im(kernels)
        kvar = u_kernels_im.std(0)

        n, h, w = u_kernels_im.shape
        u_kernels_im = u_kernels_im.permute(1, 0, 2).contiguous().view(h, w*n)

        fname = os.path.join(args.output, "lowspp.png")
        save(fname, lowspp)
        fname = os.path.join(args.output, "target.png")
        save(fname, target)
        fname = os.path.join(args.output, "output.png")
        save(fname, out)
        fname = os.path.join(args.output, "kernels_gather.png")
        save(fname, u_kernels_im)
        fname = os.path.join(args.output, "kernels_variance.png")
        log.info("Max kernel std: {}".format(kvar.max().item()))
        save(fname, kvar)

        # Only the first scene is visualized.
        break