Example #1
    def visualized_image(self, batch, fwd_output):
        fwd_output = fwd_output.cpu().detach()
        mosaic, target = batch

        # Crop inputs to match the (smaller) network output
        mosaic = crop_like(mosaic.cpu().detach(), fwd_output)
        target = crop_like(target.cpu().detach(), fwd_output)

        # Amplify the error signal for easier inspection
        diff = 4 * (fwd_output - target).abs()

        # Stack the images into a single gallery along the height axis
        vizdata = [mosaic, target, fwd_output, diff]
        viz = th.clamp(th.cat(vizdata, 2), 0, 1)
        return viz
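
Every example on this page uses crop_like to center-crop one tensor to the spatial size of another: models without padding return outputs smaller than their inputs, so buffers must be trimmed before comparison. The helper itself is not shown on this page; a minimal sketch of what it might look like, assuming it symmetrically crops the trailing height/width dimensions:

import torch as th

def crop_like(src, tgt):
    # Hypothetical implementation: center-crop `src` to `tgt`'s spatial
    # size, assuming the size difference along each axis is even.
    src_h, src_w = src.shape[-2:]
    tgt_h, tgt_w = tgt.shape[-2:]
    dh, dw = (src_h - tgt_h) // 2, (src_w - tgt_w) // 2
    if dh == 0 and dw == 0:
        return src
    return src[..., dh:src_h - dh, dw:src_w - dw]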
Example #2
    def backward(self, batch, fwd):
        self.optimizer.zero_grad()

        out = fwd["radiance"]
        tgt = crop_like(batch["target_image"], out)  # make sure sizes match

        loss = self.loss_fn(out, tgt)
        loss.backward()

        # A couple of checks to pick up on outliers in the data. NaN is
        # checked first: np.isfinite() is also False for NaN, so the order
        # matters if we want distinct error messages.
        if np.isnan(loss.item()):
            LOG.error("NaN in the loss, there might be outliers in the data.")
            raise RuntimeError("NaN loss at train time.")

        if not np.isfinite(loss.item()):
            LOG.error("Loss is infinite, there might be outliers in the data.")
            raise RuntimeError("Infinite loss at train time.")

        clip = 1000
        actual = th.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
        if actual > clip:
            LOG.info("Clipped gradients {} -> {}".format(actual, clip))

        self.optimizer.step()

        with th.no_grad():
            rmse = self.rmse_fn(out, tgt)

        return {"loss": loss.item(), "rmse": rmse.item()}
Example #3
    def backward(self, tgt, fwd):
        if self.cuda:
            tgt = tgt.cuda()

        self.optimizer.zero_grad()
        loss = self.loss_fn(fwd, crop_like(tgt, fwd))
        loss.backward()
        self.optimizer.step()

        return loss.item()
Example #4
    def forward(self, data):
        """Forward pass of the model.

        Args:
            data(dict) with keys:
                "kpcn_diffuse_in":
                "kpcn_specular_in":
                "kpcn_diffuse_buffer":
                "kpcn_specular_buffer":
                "kpcn_albedo":

        Returns:
            (dict) with keys:
                "radiance":
                "diffuse":
                "specular":
        """
        # Process the diffuse and specular channels independently
        k_diffuse = self.diffuse(data["kpcn_diffuse_in"])
        k_specular = self.specular(data["kpcn_specular_in"])

        # Match dimensions
        b_diffuse = crop_like(data["kpcn_diffuse_buffer"],
                              k_diffuse).contiguous()
        b_specular = crop_like(data["kpcn_specular_buffer"],
                               k_specular).contiguous()

        # Kernel reconstruction
        r_diffuse, _ = self.kernel_apply(b_diffuse, k_diffuse)
        r_specular, _ = self.kernel_apply(b_specular, k_specular)

        # Combine diffuse/specular/albedo
        albedo = crop_like(data["kpcn_albedo"], r_diffuse)
        final_specular = th.exp(r_specular) - 1
        final_diffuse = albedo * r_diffuse
        final_radiance = final_diffuse + final_specular

        output = dict(radiance=final_radiance,
                      diffuse=r_diffuse,
                      specular=r_specular)

        return output
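
The final combination step undoes the input preprocessing that KPCN-style pipelines typically apply before denoising: dividing the diffuse buffer by albedo and log-transforming the specular buffer. A hedged sketch of that assumed preprocessing, which `albedo * r_diffuse` and `th.exp(r_specular) - 1` above would invert:

# Assumed preprocessing (not shown in this snippet); `eps` is a
# hypothetical stabilizer for the albedo division.
eps = 1e-2
diffuse_in = diffuse_radiance / (albedo + eps)  # inverted by albedo * r_diffuse
specular_in = th.log(1 + specular_radiance)     # inverted by th.exp(r_specular) - 1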
Example #5
    def visualized_image(self, batch, fwd_result):
        lowspp = batch["low_spp"].detach()
        target = batch["target_image"].detach()
        output = fwd_result["radiance"].detach()

        # Make sure images have the same size
        lowspp = crop_like(lowspp, output)
        target = crop_like(target, output)

        # Assemble a display gallery
        diff = (output - target).abs()
        data = th.cat([lowspp, output, target, diff], -2)

        # Clip negatives, tonemap (Reinhard), then gamma-correct for display
        data = th.clamp(data, 0)
        data /= 1 + data
        data = th.pow(data, 1.0 / 2.2)
        data = th.clamp(data, 0, 1)

        return data
Example #6
    def update_validation(self, batch, fwd_output, running_data):
        target = batch[1].to(self.device)

        # remove boundaries to match output size
        target = crop_like(target, fwd_output)

        with th.no_grad():
            psnr = self.psnr(th.clamp(fwd_output, 0, 1), target)
            n = target.shape[0]

        return {
            "psnr": running_data["psnr"] + psnr.item() * n,
            "count": running_data["count"] + n
        }
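
The PSNR is accumulated weighted by the batch size n, so the final validation score is the weighted sum divided by the total count. A hypothetical finalization step, assuming this is how the trainer consumes the running dict:

# Run once validation is over; max() guards against an empty loader.
avg_psnr = running_data["psnr"] / max(running_data["count"], 1)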
Example #7
    def backward(self, batch, fwd_output):
        target = batch[1].to(self.device)

        # remove boundaries to match output size
        target = crop_like(target, fwd_output)

        loss = self.loss(fwd_output, target)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        with th.no_grad():
            psnr = self.psnr(th.clamp(fwd_output, 0, 1), target)

        return {"loss": loss.item(), "psnr": psnr.item()}
Example #8
    def update_validation(self, batch, fwd, running):
        """Updates running statistics for the validation."""
        with th.no_grad():
            out = fwd["radiance"]
            tgt = crop_like(batch["target_image"], out)
            loss = self.loss_fn(out, tgt).item()
            rmse = self.rmse_fn(out, tgt).item()

        # Make sure our statistics account for potentially varying batch size
        b = out.shape[0]

        # Update the running means, weighting each batch by its size:
        # new = old + (b / n) * (batch_mean - old)
        n = running["n"] + b
        new_loss = running["loss"] + (b / n) * (loss - running["loss"])
        new_rmse = running["rmse"] + (b / n) * (rmse - running["rmse"])

        return {"loss": new_loss, "rmse": new_rmse, "n": n}
Example #9
    def forward(self, data, coords):
        # in_ = coords.contiguous()
        in_ = th.cat([th.log10(1.0 + data / 255.0), coords], 2).contiguous()
        assert in_.shape[0] == 1, "current implementation assumes batch_size = 1"
        kernels = self.net(in_.squeeze(0))
        cdata = crop_like(data.squeeze(0), kernels).contiguous()
        output, _ = self.kernel_update(cdata, kernels)

        # Average over samples
        output = th.unsqueeze(output, 0).mean(1)

        # crop output
        k = (self.ksize - 1) // 2
        output = output[..., k:-k, k:-k]

        # Normalize kernels to [0, 1] for visualization
        kviz = kernels.detach().clone()
        min_ = kviz.min()
        max_ = kviz.max()
        kviz = (kviz - min_) / (max_ - min_ + 1e-8)  # + eps avoids div by zero
        bs, _, h, w = kviz.shape
        return output, kviz.view(bs, self.ksize, self.ksize, h, w)
Example #10
def main(args):
    dataset = AADataset(args.input,
                        ds=args.ds,
                        spp=args.spp,
                        sigma=args.sigma,
                        size=args.size,
                        outliers=args.outliers,
                        outliers_p=args.outliers_p)
    loader = DataLoader(dataset, batch_size=1, num_workers=4)

    # Kernel optimization -------------
    # initialize everything to a box filter
    th.manual_seed(0)
    width = 32
    depth = 3
    gather_mdl = ForwardModel(dataset.h_lr,
                              dataset.w_lr,
                              ksize=args.ksize,
                              depth=depth,
                              width=width,
                              scatter=False)
    th.manual_seed(0)
    # beefier gather model
    b_gather_mdl = ForwardModel(dataset.h_lr,
                                dataset.w_lr,
                                ksize=args.ksize,
                                depth=depth * args.depth_factor,
                                width=width * args.width_factor,
                                scatter=False)
    th.manual_seed(0)
    scatter_mdl = ForwardModel(dataset.h_lr,
                               dataset.w_lr,
                               ksize=args.ksize,
                               depth=depth,
                               width=width,
                               scatter=True)

    gather_interface = OptimizerInterface(gather_mdl)
    b_gather_interface = OptimizerInterface(b_gather_mdl)
    scatter_interface = OptimizerInterface(scatter_mdl)

    # optimize the kernels
    all_losses = np.zeros((3, args.nsteps))
    for step, batch in enumerate(loader):
        data, subpixel_coords, target, mask = batch
        gather_result, gather_k = gather_interface.forward(
            data, subpixel_coords)
        loss_gather = gather_interface.backward(target, gather_result)

        b_gather_result, b_gather_k = b_gather_interface.forward(
            data, subpixel_coords)
        loss_b_gather = b_gather_interface.backward(target, b_gather_result)

        scatter_result, scatter_k = scatter_interface.forward(
            data, subpixel_coords)
        loss_scatter = scatter_interface.backward(target, scatter_result)

        all_losses[0, step] = loss_gather
        all_losses[1, step] = loss_b_gather
        all_losses[2, step] = loss_scatter

        step += 1
        if step == args.nsteps:
            break

        if step % 10 == 0:
            print(
                "{:05d} | Gather = {:.4f}, Gather(big) = {:.4f}, Scatter = {:.4f}, gather/scatter = {:.2f} gather(big)/scatter = {:.2f}"
                .format(step, loss_gather, loss_b_gather, loss_scatter,
                        loss_gather / loss_scatter,
                        loss_b_gather / loss_scatter))

    # ---------------------------------

    gather_k_viz = np.clip(kernels2im(gather_k).detach().cpu().numpy(), 0, 1)
    b_gather_k_viz = np.clip(
        kernels2im(b_gather_k).detach().cpu().numpy(), 0, 1)
    scatter_k_viz = np.clip(kernels2im(scatter_k).detach().cpu().numpy(), 0, 1)

    spp, h, w = gather_k_viz.shape
    gather_k_viz = gather_k_viz.reshape([spp * h, w])
    spp, b_h, b_w = b_gather_k_viz.shape
    b_gather_k_viz = b_gather_k_viz.reshape([spp * b_h, b_w])
    scatter_k_viz = scatter_k_viz.reshape([spp * h, w])

    # Conversion to numpy -----------------
    gres_c = gather_result.clone()
    mask = crop_like(mask.permute(0, 1, 4, 2, 3).max(1)[0], gather_result)
    mask = mask[0, 0]

    def to_uint8_image(t):
        # Clamp to [0, 255] and convert to an HWC uint8 numpy image
        t = th.clamp(t.detach()[0], 0, 255)
        return t.permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    im = to_uint8_image(crop_like(data.mean(1), gather_result))
    target = to_uint8_image(crop_like(target, gather_result))
    gather_result = to_uint8_image(gather_result)
    b_gather_result = to_uint8_image(b_gather_result)
    scatter_result = to_uint8_image(scatter_result)

    # Output -------------
    os.makedirs(args.output, exist_ok=True)

    path = os.path.join(args.output, "0_input.png")
    skio.imsave(path, im)

    path = os.path.join(args.output, "0b_mask.png")
    skio.imsave(path, mask)

    # path = os.path.join(args.output, "1_lowpass.png")
    # skio.imsave(path, lp)
    #
    # path = os.path.join(args.output, "2_subsampled_jitter.png")
    # skio.imsave(path, jitter_sampled)
    #
    path = os.path.join(args.output, "3_gather.png")
    skio.imsave(path, gather_result)

    path = os.path.join(args.output, "3_gather_big.png")
    skio.imsave(path, b_gather_result)

    path = os.path.join(args.output, "4_target.png")
    skio.imsave(path, target)

    path = os.path.join(args.output, "5_scatter.png")
    skio.imsave(path, scatter_result)

    kdict = {
        "gather_kernels": gather_k,
        "b_gather_kernels": b_gather_k,
        "scatter_kernels": scatter_k,
    }

    h, w = mask.shape
    for k in kdict:
        # Normalize the log-kernels to (0, 1] for display; detach so the
        # subtraction does not mutate tensors still tracked by autograd.
        kernels = crop_like(kdict[k], gres_c).detach()
        kernels = th.exp(kernels - kernels.max())

        for s in range(args.spp):
            d = os.path.join(args.output, k, "spp%02d" % s)
            os.makedirs(d, exist_ok=True)

            for y in range(h):
                for x in range(w):
                    kernel = kernels[s, :, :, y, x].detach().cpu().numpy()
                    if mask[y, x] == 1:
                        suff = "outlier"
                    else:
                        suff = ""
                    skio.imsave(
                        os.path.join(d, "y%02d_x%02d%s.png" % (y, x, suff)),
                        kernel)

    path = os.path.join(args.output, "6_gather_kernels.png")
    skio.imsave(path, gather_k_viz)

    path = os.path.join(args.output, "6_b_gather_kernels.png")
    skio.imsave(path, b_gather_k_viz)

    path = os.path.join(args.output, "7_scatter_kernels.png")
    skio.imsave(path, scatter_k_viz)

    ax = plt.subplot(111)
    plt.plot(all_losses.T)
    plt.legend(["gather", "gather(big)", "scatter"])
    plt.title("loss vs. step")
    plt.xlabel("optim step")
    plt.ylabel("MSE")
    ax.set_yscale("log", nonpositive="clip")  # `nonposy` was renamed in Matplotlib 3.3
    path = os.path.join(args.output, "8_loss.png")
    plt.savefig(path)
    path = os.path.join(args.output, "8_loss.pdf")
    plt.savefig(path)
Example #11
def main(args):
    """Entrypoint to the training."""

    # Load model parameters from checkpoint, if any
    meta = ttools.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        LOG.warning("No checkpoint found at %s, aborting.",
                    args.checkpoint_dir)
        return

    data = demosaicnet.Dataset(args.data,
                               download=False,
                               mode=meta["mode"],
                               subset=demosaicnet.TEST_SUBSET)
    dataloader = DataLoader(data,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=True,
                            shuffle=True)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)
    checkpointer.load_latest()  # Resume from checkpoint, if any.

    # No need for gradients
    for p in model.parameters():
        p.requires_grad = False

    mse_fn = th.nn.MSELoss()
    psnr_fn = PSNR()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    count = 0
    mse = 0.0
    psnr = 0.0
    for idx, batch in enumerate(dataloader):
        mosaic = batch[0].to(device)
        target = batch[1].to(device)
        output = model(mosaic)

        target = crop_like(target, output)

        output = th.clamp(output, 0, 1)

        psnr_ = psnr_fn(output, target).item()
        mse_ = mse_fn(output, target).item()

        psnr += psnr_
        mse += mse_
        count += 1

        LOG.info("Image %04d, PSNR = %.1f dB, MSE = %.5f", idx, psnr_, mse_)

    mse /= count
    psnr /= count

    LOG.info("-----------------------------------")
    LOG.info("Average, PSNR = %.1f dB, MSE = %.5f", psnr, mse)
Example #12
    def forward(self, samples):
        """Forward pass of the model.

        Args:
            data(dict) with keys:
                "radiance": (th.Tensor[bs, spp, 3, h, w]) sample radiance.
                "features": (th.Tensor[bs, spp, nf, h, w]) sample features.
                "global_features": (th.Tensor[bs, ngf, h, w]) global features.

        Returns:
            (dict) with keys:
                "radiance": (th.Tensor[bs, 3, h, w]) denoised radiance
        """
        radiance = samples["radiance"]
        features = samples["features"]
        gfeatures = samples["global_features"].cuda()

        if self.pixel:
            # Make the pixel-average look like one sample
            radiance = radiance.mean(1, keepdim=True)
            features = features.mean(1, keepdim=True)

        bs, spp, nf, h, w = features.shape

        modules = {n: m for (n, m) in self.named_modules()}

        limit_memory_usage = not self.training

        # -- Embed the samples then collapse to pixel-wise summaries ----------
        if limit_memory_usage:
            gf = gfeatures.repeat([1, 1, h, w])
            new_features = th.zeros(bs, spp, self.embedding_width, h, w)
        else:
            gf = gfeatures.repeat([spp, 1, h, w])

        for step in range(self.nsteps):
            if limit_memory_usage:
                # Go through the samples one by one to preserve memory for
                # large images
                for sp in range(spp):
                    f = features[:, sp].cuda()
                    if step == 0:  # Global features at first iteration only
                        f = th.cat([f, gf], 1)
                    else:
                        f = th.cat([f, propagated], 1)

                    f = modules["embedding_{:02d}".format(step)](f)

                    new_features[:, sp].copy_(f, non_blocking=True)

                    if sp == 0:
                        reduced = f
                    else:
                        reduced.add_(f)

                    del f
                    th.cuda.empty_cache()

                features = new_features
                reduced.div_(spp)
                th.cuda.empty_cache()
            else:
                flat = features.view([bs * spp, nf, h, w])
                if step == 0:  # Global features at first iteration only
                    flat = th.cat([flat, gf], 1)
                else:
                    flat = th.cat([
                        flat,
                        propagated.unsqueeze(1).repeat([1, spp, 1, 1, 1]).view(
                            spp * bs, self.width, h, w)
                    ], 1)
                flat = modules["embedding_{:02d}".format(step)](flat)
                flat = flat.view(bs, spp, self.embedding_width, h, w)
                reduced = flat.mean(1)
                features = flat
                nf = self.embedding_width

            # Propagate spatially the pixel context
            propagated = modules["propagation_{:02d}".format(step)](reduced)

            if limit_memory_usage:
                del reduced
                th.cuda.empty_cache()

        # Predict kernels based on the context information and
        # the current sample's features
        sum_r, sum_w, max_w = None, None, None

        for sp in range(spp):
            f = features[:, sp].cuda()
            f = th.cat([f, propagated], 1)
            r = radiance[:, sp].cuda()
            kernels = self.kernel_regressor(f)
            if limit_memory_usage:
                th.cuda.empty_cache()

            # Update radiance estimate
            sum_r, sum_w, max_w = self.kernel_update(crop_like(r, kernels),
                                                     kernels, sum_r, sum_w,
                                                     max_w)
            if limit_memory_usage:
                th.cuda.empty_cache()

        # Normalize output with the running sum
        output = sum_r / (sum_w + self.eps)

        # Remove the invalid boundary data
        crop = (self.ksize - 1) // 2
        output = output[..., crop:-crop, crop:-crop]

        return {"radiance": output}
Example #13
def main(args):
  log.info("Loading model {}".format(args.checkpoint))
  meta_params = ttools.Checkpointer.load_meta(args.checkpoint)

  spp = meta_params["spp"]
  use_p = meta_params["use_p"]
  use_ld = meta_params["use_ld"]
  use_bt = meta_params["use_bt"]
  # use_coc = meta_params["use_coc"]

  mode = "sample"
  if "DisneyPreprocessor" == meta_params["preprocessor"]:
    mode = "disney_pixel"
  elif "SampleDisneyPreprocessor" == meta_params["preprocessor"]:
    mode = "disney_sample"

  log.info("Rendering at {} spp".format(spp))

  log.info("Setting up dataloader, p:{} bt:{} ld:{}".format(use_p, use_bt, use_ld))
  data = dset.FullImageDataset(args.data, dset.RenderDataset, spp=spp, use_p=use_p, use_ld=use_ld, use_bt=use_bt)
  preprocessor = pre.get(meta_params["preprocessor"])(data)
  xforms = transforms.Compose([dset.ToTensor(), preprocessor])
  data.transform = xforms
  dataloader = DataLoader(data, batch_size=1,
                          shuffle=False, num_workers=0,
                          pin_memory=True)

  model = models.get(preprocessor, meta_params["model_params"])
  model.cuda()
  model.train(False)

  checkpointer = ttools.Checkpointer(args.checkpoint, model, None)
  extras, meta = checkpointer.load_latest()
  log.info("Loading latest checkpoint {}".format("failed" if meta is None else "success"))

  for scene_id, batch in enumerate(dataloader):
    batch_v = make_variable(batch, cuda=True)
    with th.no_grad():
      klist = []
      out_ = model(batch_v, kernel_list=klist)
    lowspp = batch["radiance"]
    target = batch["target_image"]
    out = out_["radiance"]

    cx = 70
    cy = 20
    c = 128

    target = crop_like(target, out)
    lowspp = crop_like(lowspp.squeeze(), out)
    lowspp = lowspp[..., cy:cy+c, cx:cx+c]

    lowspp = lowspp.permute(1, 2, 0, 3)
    chan, h, w, s = lowspp.shape
    lowspp = lowspp.contiguous().view(chan, h, w*s)

    sum_r = []
    sum_w = []
    max_w = []
    maxi = crop_like(klist[-1]["max_w"].unsqueeze(1), out)
    kernels = []
    updated_kernels = []
    for k in klist:
      kernels.append(th.exp(crop_like(k["kernels"], out) - maxi))
      updated_kernels.append(th.exp(crop_like(k["updated_kernels"], out) - maxi))

    out = out[..., cy:cy+c, cx:cx+c]
    target = target[..., cy:cy+c, cx:cx+c]
    updated_kernels = [k[..., cy:cy+c, cx:cx+c] for k in updated_kernels]
    kernels = [k[..., cy:cy+c, cx:cx+c] for k in kernels]

    u_kernels_im = viz.kernels2im(kernels)
    kmean = u_kernels_im.mean(0)
    kvar = u_kernels_im.std(0)

    n, h, w = u_kernels_im.shape
    u_kernels_im = u_kernels_im.permute(1, 0, 2).contiguous().view(h, w*n)

    fname = os.path.join(args.output, "lowspp.png")
    save(fname, lowspp)
    fname = os.path.join(args.output, "target.png")
    save(fname, target)
    fname = os.path.join(args.output, "output.png")
    save(fname, out)
    fname = os.path.join(args.output, "kernels_gather.png")
    save(fname, u_kernels_im)
    fname = os.path.join(args.output, "kernels_variance.png")
    print(kvar.max())
    save(fname, kvar)
    break  # only process the first scene