    def get_comparisons(self, model_input, prediction, ground_truth=None):
        """Assemble side-by-side (normal map | prediction | ground truth) comparison images."""
        predictions, depth_maps = prediction

        batch_size = predictions.shape[0]

        # Parse model input.
        intrinsics = model_input["intrinsics"].cuda()
        uv = model_input["uv"].cuda().float()

        x_cam = uv[:, :, 0].view(batch_size, -1)
        y_cam = uv[:, :, 1].view(batch_size, -1)
        z_cam = depth_maps.view(batch_size, -1)

        normals = geometry.compute_normal_map(x_img=x_cam,
                                              y_img=y_cam,
                                              z=z_cam,
                                              intrinsics=intrinsics)
        # Pad the (H-2, W-2) normal map back to full image resolution.
        normals = F.pad(normals, pad=(1, 1, 1, 1), mode="constant", value=1.)

        predictions = util.lin2img(predictions)

        if ground_truth is not None:
            trgt_imgs = ground_truth["rgb"]
            trgt_imgs = util.lin2img(trgt_imgs)

            return torch.cat(
                (normals.cpu(), predictions.cpu(), trgt_imgs.cpu()),
                dim=3).numpy()
        else:
            return torch.cat((normals.cpu(), predictions.cpu()), dim=3).numpy()

    def get_psnr(self, prediction, ground_truth):
        """Compute PSNR and SSIM of model image predictions.

        :param prediction: Return value of the forward pass: (pred_imgs, depth_maps).
        :param ground_truth: Dict holding ground-truth images under the 'rgb' key.
        :return: (psnrs, ssims): tuple of per-image float lists.
        """
        pred_imgs, _ = prediction
        trgt_imgs = ground_truth['rgb']

        # No device transfer needed: both tensors are converted to numpy below.
        batch_size = pred_imgs.shape[0]

        if not isinstance(pred_imgs, np.ndarray):
            pred_imgs = util.lin2img(pred_imgs).detach().cpu().numpy()

        if not isinstance(trgt_imgs, np.ndarray):
            trgt_imgs = util.lin2img(trgt_imgs).detach().cpu().numpy()

        psnrs, ssims = list(), list()
        for i in range(batch_size):
            p = pred_imgs[i].squeeze().transpose(1, 2, 0)
            trgt = trgt_imgs[i].squeeze().transpose(1, 2, 0)

            # Un-normalize from [-1, 1] to [0, 1] before computing metrics.
            p = (p / 2.) + 0.5
            p = np.clip(p, a_min=0., a_max=1.)

            trgt = (trgt / 2.) + 0.5

            # skimage.measure.compare_ssim / compare_psnr were removed in
            # scikit-image 0.18; use the skimage.metrics equivalents instead.
            ssim = skimage.metrics.structural_similarity(p,
                                                         trgt,
                                                         channel_axis=-1,
                                                         data_range=1)
            psnr = skimage.metrics.peak_signal_noise_ratio(p, trgt,
                                                           data_range=1)

            psnrs.append(psnr)
            ssims.append(ssim)

        return psnrs, ssims
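
A minimal sketch of how these two hooks might be driven at evaluation time. `model`, `dataloader`, and `renderer` are assumptions standing in for whatever owns these methods; the only contract used is that the forward pass returns a (pred_imgs, depth_maps) tuple:

import numpy as np

# Hypothetical evaluation loop; `renderer` is whatever object defines the
# get_psnr / get_comparisons methods above.
all_psnrs, all_ssims = [], []
for model_input, ground_truth in dataloader:
    prediction = model(model_input)  # -> (pred_imgs, depth_maps)
    psnrs, ssims = renderer.get_psnr(prediction, ground_truth)
    all_psnrs.extend(psnrs)
    all_ssims.extend(ssims)
print("PSNR %.2f dB, SSIM %.3f" % (np.mean(all_psnrs), np.mean(all_ssims)))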
Example 3
def compute_normal_map(x_img, y_img, z, intrinsics):
    """Estimate per-pixel surface normals from a depth map via central differences."""
    # Lift pixel coordinates + depth to 3D camera-space points, then
    # reshape the flat (B, N, 3) point list into a (B, 3, H, W) image.
    cam_coords = lift(x_img, y_img, z, intrinsics)
    cam_coords = util.lin2img(cam_coords)

    # Shifted copies for central differences along image rows and columns.
    shift_left = cam_coords[:, :, 2:, :]
    shift_right = cam_coords[:, :, :-2, :]

    shift_up = cam_coords[:, :, :, 2:]
    shift_down = cam_coords[:, :, :, :-2]

    # Normalized tangent vectors, cropped so both share shape (B, 3, H-2, W-2).
    diff_hor = F.normalize(shift_right - shift_left, dim=1)[:, :, :, 1:-1]
    diff_ver = F.normalize(shift_up - shift_down, dim=1)[:, :, 1:-1, :]

    # The normal is the cross product of the two tangent directions.
    cross = torch.cross(diff_hor, diff_ver, dim=1)
    return cross
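
A self-contained sketch of calling `compute_normal_map` on a synthetic constant-depth map. The 4x4 pinhole intrinsics layout is an assumption about what `lift` expects; with flat depth, every recovered normal should point along the optical axis:

import torch

H = W = 64
# Flattened pixel-grid coordinates of shape (batch, H*W), as in get_comparisons.
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
x_img = xs.reshape(1, -1).float().cuda()
y_img = ys.reshape(1, -1).float().cuda()
z = torch.ones(1, H * W).cuda()  # constant depth

# Assumed 4x4 pinhole intrinsics: focal length 60, principal point at center.
intrinsics = torch.tensor([[60., 0., W / 2., 0.],
                           [0., 60., H / 2., 0.],
                           [0., 0., 1., 0.],
                           [0., 0., 0., 1.]]).unsqueeze(0).cuda()

normals = compute_normal_map(x_img, y_img, z, intrinsics)  # (1, 3, H-2, W-2)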
Example 4
    def forward(self,
                cam2world,  # pose
                phi,
                uv,
                intrinsics):
        """Differentiable LSTM ray marching: advance each camera ray by
        predicted step lengths for self.steps iterations."""
        batch_size, num_samples, _ = uv.shape
        log = list()

        ray_dirs = geometry.get_ray_directions(uv,
                                               cam2world=cam2world,
                                               intrinsics=intrinsics)

        # Initialize the marching depth near zero (small positive mean, tiny noise).
        initial_depth = torch.zeros((batch_size, num_samples, 1)).normal_(mean=0.05, std=5e-4).cuda()
        init_world_coords = geometry.world_from_xy_depth(uv,
                                                         initial_depth,
                                                         intrinsics=intrinsics,
                                                         cam2world=cam2world)

        world_coords = [init_world_coords]
        depths = [initial_depth]
        states = [None]

        for step in range(self.steps):
            # Query the scene representation at the current world coordinates.
            v = phi(world_coords[-1])

            state = self.lstm(v.view(-1, self.n_feature_channels), states[-1])

            # Clamp gradients flowing back through the hidden state for stability.
            if state[0].requires_grad:
                state[0].register_hook(lambda x: x.clamp(min=-20, max=20))

            # Predict a signed step length and march along the ray directions.
            signed_distance = self.out_layer(state[0]).view(batch_size, num_samples, 1)
            new_world_coords = world_coords[-1] + ray_dirs * signed_distance

            states.append(state)
            world_coords.append(new_world_coords)

            depth = geometry.depth_from_world(world_coords[-1], cam2world)

            # if self.training:
            #     print("Raymarch step %d/%d: Min depth %0.6f, max depth %0.6f" %
            #           (step, self.steps, depths[-1].min().detach().cpu().numpy(), depths[-1].max().detach().cpu().numpy()))
            depths.append(depth)

        if not self.counter % 100:
            # Write tensorboard summary for each step of ray-marcher.
            drawing_depths = torch.stack(depths, dim=0)[:, 0, :, :]
            drawing_depths = util.lin2img(drawing_depths).repeat(1, 3, 1, 1)
            log.append(('image', 'raycast_progress',
                        torch.clamp(torchvision.utils.make_grid(drawing_depths, scale_each=False, normalize=True), 0.0,
                                    5),
                        100))

            # Visualize residual step distance (i.e., the size of the final step)
            fig = util.show_images([util.lin2img(signed_distance)[i, :, :, :].detach().cpu().numpy().squeeze()
                                    for i in range(batch_size)])
            log.append(('figure', 'stopping_distances', fig, 100))
        self.counter += 1

        return world_coords[-1], depths[-1], log
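
Stripped of logging, the loop above implements the recurrence x_{t+1} = x_t + d_t * ray_dir, where d_t is the LSTM-predicted signed step. A minimal, self-contained sketch of just that recurrence; the feature and hidden sizes are placeholder choices, not values from this codebase:

import torch
import torch.nn as nn

class MinimalRayMarcher(nn.Module):
    """Sketch of the recurrence x_{t+1} = x_t + d_t * ray_dir; feature and
    hidden sizes are placeholders, not taken from the code above."""
    def __init__(self, n_features=16, steps=10):
        super().__init__()
        self.steps = steps
        self.lstm = nn.LSTMCell(input_size=n_features, hidden_size=16)
        self.out_layer = nn.Linear(16, 1)

    def forward(self, phi, init_coords, ray_dirs):
        b, n, _ = init_coords.shape
        coords, state = init_coords, None
        for _ in range(self.steps):
            v = phi(coords)                                    # (b, n, n_features)
            state = self.lstm(v.view(b * n, -1), state)        # (hidden, cell)
            step_len = self.out_layer(state[0]).view(b, n, 1)  # signed distance
            coords = coords + ray_dirs * step_len              # march along the ray
        return coords

Any module mapping (B, N, 3) world coordinates to (B, N, n_features) features can stand in for phi here, e.g. torch.nn.Linear(3, 16).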
    def write_updates(self,
                      writer,
                      predictions,
                      ground_truth,
                      iter,
                      prefix=""):
        """Writes tensorboard summaries using tensorboardx api.

        :param writer: tensorboardx writer object.
        :param predictions: Output of forward pass.
        :param ground_truth: Ground truth.
        :param iter: Iteration number.
        :param prefix: Every summary will be prefixed with this string.
        """
        predictions, depth_maps = predictions
        trgt_imgs = ground_truth['rgb']

        trgt_imgs = trgt_imgs.cuda()

        batch_size, num_samples, _ = predictions.shape

        # Module"s own log
        for type, name, content, every_n in self.logs:
            name = prefix + name

            if not iter % every_n:
                if type == "image":
                    writer.add_image(name,
                                     content.detach().cpu().numpy(), iter)
                    writer.add_scalar(name + "_min", content.min(), iter)
                    writer.add_scalar(name + "_max", content.max(), iter)
                elif type == "figure":
                    writer.add_figure(name, content, iter, close=True)
                elif type == "histogram":
                    writer.add_histogram(name,
                                         content.detach().cpu().numpy(), iter)
                elif type == "scalar":
                    writer.add_scalar(name,
                                      content.detach().cpu().numpy(), iter)
                elif type == "embedding":
                    writer.add_embedding(mat=content, global_step=iter)

        if not iter % 100:
            output_vs_gt = torch.cat((predictions, trgt_imgs), dim=0)
            output_vs_gt = util.lin2img(output_vs_gt)
            writer.add_image(
                prefix + "Output_vs_gt",
                torchvision.utils.make_grid(
                    output_vs_gt, scale_each=False,
                    normalize=True).cpu().detach().numpy(), iter)

            rgb_loss = ((predictions.float().cuda() -
                         trgt_imgs.float().cuda())**2).mean(dim=2,
                                                            keepdim=True)
            rgb_loss = util.lin2img(rgb_loss)

            fig = util.show_images([
                rgb_loss[i].detach().cpu().numpy().squeeze()
                for i in range(batch_size)
            ])
            writer.add_figure(prefix + "rgb_error_fig", fig, iter, close=True)

            depth_maps_plot = util.lin2img(depth_maps)
            writer.add_image(
                prefix + "pred_depth",
                torchvision.utils.make_grid(
                    depth_maps_plot.repeat(1, 3, 1, 1),
                    scale_each=True,
                    normalize=True).cpu().detach().numpy(), iter)

        writer.add_scalar(prefix + "out_min", predictions.min(), iter)
        writer.add_scalar(prefix + "out_max", predictions.max(), iter)

        writer.add_scalar(prefix + "trgt_min", trgt_imgs.min(), iter)
        writer.add_scalar(prefix + "trgt_max", trgt_imgs.max(), iter)

        if iter:
            writer.add_scalar(prefix + "latent_reg_loss", self.latent_reg_loss,
                              iter)

    def get_output_img(self, prediction):
        pred_imgs, _ = prediction
        return util.lin2img(pred_imgs)
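
A minimal driver for write_updates, assuming the standard tensorboardX writer; `model`, `dataloader`, and the log directory are assumptions:

from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="logs/run0")  # hypothetical log directory
for it, (model_input, ground_truth) in enumerate(dataloader):
    prediction = model(model_input)  # -> (predictions, depth_maps)
    model.write_updates(writer, prediction, ground_truth, it, prefix="train_")
writer.close()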