Example #1
    def vcat(cls, cameras, batch_size=-1):
        """
        Concatenate the view dimension of the camera parameters.
        Args:
            cameras: cameras to concatenate
            batch_size: the batch size of the data

        Returns:
            Cameras concatenated in the view dimension then flattened
        """
        z_span = cameras[0].z_span
        height = cameras[0].height
        width = cameras[0].width
        intrinsic = torch.cat(
            [b2bv(o.intrinsic, batch_size=batch_size) for o in cameras], dim=1)
        viewport = torch.cat(
            [b2bv(o.viewport, batch_size=batch_size) for o in cameras], dim=1)
        log_quaternion = torch.cat(
            [b2bv(o.log_quaternion, batch_size=batch_size) for o in cameras],
            dim=1)
        translation = torch.cat(
            [b2bv(o.translation, batch_size=batch_size) for o in cameras],
            dim=1)

        return cls(bv2b(intrinsic),
                   None,
                   z_span,
                   bv2b(viewport),
                   log_quaternion=bv2b(log_quaternion),
                   translation=bv2b(translation),
                   width=width,
                   height=height)
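The bv2b / b2bv helpers used throughout these examples are not shown in the source. A minimal sketch consistent with how they are called here, flattening a leading (batch, view) pair of dimensions into a single batch dimension and back, might look as follows; the exact signatures in the original module are an assumption.

import torch


def bv2b(x):
    # Flatten (B, V, ...) into (B * V, ...). [assumed behavior]
    return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])


def b2bv(x, num_views=-1, batch_size=-1):
    # Unflatten (B * V, ...) into (B, V, ...), given either the view count
    # or the batch size. [assumed behavior]
    if num_views < 0:
        num_views = x.shape[0] // batch_size
    return x.reshape(-1, num_views, *x.shape[1:])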
Example #2
def warp_blend_logits(logits, image_reproj, flow_size):
    """Warps each reprojected input view with a predicted flow field and
    blends the warped views using softmax weights over the view dimension."""
    device = image_reproj.device
    num_input_views = image_reproj.shape[1]
    height, width = image_reproj.shape[-2:]
    blend_logits, flow_x_logits, flow_y_logits = torch.split(logits,
                                                             num_input_views,
                                                             dim=1)
    # Softmax over the input-view dimension gives per-pixel blend weights.
    blend_weights = torch.softmax(blend_logits, dim=1).unsqueeze(2)
    # Bounded flow offsets in normalized grid coordinates.
    flow_dx = flow_size / width * torch.tanh(flow_x_logits)
    flow_dy = flow_size / height * torch.tanh(flow_y_logits)
    # Base identity sampling grid in normalized [-1, 1] coordinates.
    flow_y, flow_x = torch.meshgrid([
        torch.linspace(-1, 1, height, device=device),
        torch.linspace(-1, 1, width, device=device)
    ])
    flow_x = flow_x[None, None, :, :].expand_as(flow_dx) + flow_dx
    flow_y = flow_y[None, None, :, :].expand_as(flow_dy) + flow_dy
    flow_grid = torch.stack((flow_x, flow_y), dim=-1).clamp(-1, 1)

    image_fake = F.grid_sample(bv2b(image_reproj),
                               bv2b(flow_grid),
                               mode='bilinear')
    image_fake = b2bv(image_fake, num_input_views)
    image_fake = (blend_weights * image_fake).sum(dim=1)

    return image_fake, blend_weights, flow_dx, flow_dy
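A rough usage sketch for warp_blend_logits, purely to illustrate the expected shapes; the tensor sizes and the flow_size value are made up, and the bv2b / b2bv helpers (sketched after Example #1, or the originals) plus torch.nn.functional as F are assumed to be in scope.

import torch

num_out, num_in, height, width = 2, 4, 32, 32
# Reprojected input views, one set per output view: (V_o, V_i, C, H, W).
image_reproj = torch.randn(num_out, num_in, 3, height, width)
# Network output with 3 * V_i channels: blend logits plus x/y flow logits.
logits = torch.randn(num_out, 3 * num_in, height, width)

image_fake, blend_weights, flow_dx, flow_dy = warp_blend_logits(
    logits, image_reproj, flow_size=8.0)
# image_fake: (V_o, C, H, W), blend_weights: (V_o, V_i, 1, H, W),
# flow_dx / flow_dy: (V_o, V_i, H, W)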
Example #3
def reproject_views(image_in, depth_in, depth_out, camera_in, camera_out):
    """
    Reprojects pixels from the input view to the output view.

    Args:
        image_in: pixels to copy from (V_i, C, H, W)
        depth_in: input view depth used to compute object coordinates (V_i, C, H, W)
        depth_out: target depth (V_o, C, H, W)
        camera_in: input view cameras
        camera_out: output view cameras

    Returns:
        Each input view image reprojected to output views (V_o, V_i, C, H, W)
        Each input view depth transformed and reprojected to output views (V_o, V_i, C, H, W)

    """
    grid = depth_to_warp_field(camera_in, camera_out, depth_out)

    # Expand and reshape to do batch operations:
    #   batch_dim = output_views
    #   view_dim = input_views
    image_in = bv2b(
        image_in.unsqueeze(0).expand(camera_out.length, -1, -1, -1, -1))

    obj_coords_in = torch.stack(camera_in.depth_object_coords(depth_in),
                                dim=-1)
    obj_coords_in = bv2b(
        obj_coords_in.unsqueeze(0).expand(camera_out.length, -1, -1, -1, -1))

    # Repeat interleaved to expand to input view dimensions.
    camera_out = camera_out.repeat_interleave(camera_in.length)

    # Transform reprojected input coordinates to output view.
    cam_coords_in_tf = three.transform_coord_grid(obj_coords_in,
                                                  camera_out.obj_to_cam)
    depth_in_tf = cam_coords_in_tf[..., 2].unsqueeze(1)
    depth_in_tf = camera_out.normalize_depth(depth_in_tf)

    grid = bv2b(grid)

    image_reproj = F.grid_sample(image_in, grid, mode='bilinear')
    depth_reproj = F.grid_sample(depth_in_tf, grid, mode='bilinear')
    return b2bv(image_reproj, camera_in.length), b2bv(depth_reproj,
                                                      camera_in.length)
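The expand / bv2b reshaping of the inputs and the repeat_interleave on the output cameras together enumerate every (output view, input view) pair along the flattened batch dimension. A toy sketch with plain tensors, whose names and values are purely illustrative, shows the pairing order:

import torch

num_in, num_out = 3, 2
inputs = torch.arange(num_in)         # one entry per input view: [0, 1, 2]
outputs = torch.arange(num_out) * 10  # one entry per output view: [0, 10]

# Input side: tile the V_i input views once per output view, then flatten.
inputs_flat = inputs.unsqueeze(0).expand(num_out, -1).reshape(-1)
# -> [0, 1, 2, 0, 1, 2]

# Output side: repeat each output view V_i times so the entries line up.
outputs_flat = outputs.repeat_interleave(num_in)
# -> [0, 0, 0, 10, 10, 10]

# Each position in the flattened batch is now one (output, input) pair.
pairs = torch.stack((outputs_flat, inputs_flat), dim=-1)
# -> [[0, 0], [0, 1], [0, 2], [10, 0], [10, 1], [10, 2]]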
Example #4
    def compute_blend_weights(self, z_cam, camera):
        """Predicts per-view blend weights from camera-space depth, softmaxed
        along the view dimension."""
        num_views = z_cam.shape[1]

        z_cam = bv2b(z_cam)

        # Concatenate camera-space coordinates to input.
        coords = utils.get_normalized_voxel_depth(z_cam)

        w = torch.cat((z_cam, coords), dim=1)
        w = self.unet(w)
        w = self.transform_block(w, camera)
        w = b2bv(w, num_views)

        # Softmax along view dimension.
        w = torch.softmax(w, dim=1)

        return w
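Because the weights are softmaxed along the view dimension, they form a convex combination over per-view features. A minimal sketch of how such weights could be applied; the feature tensor and its shape are assumptions made only for illustration.

import torch

batch, num_views, channels, height, width = 2, 4, 16, 8, 8
# Hypothetical per-view features to be fused.
features = torch.randn(batch, num_views, channels, height, width)
# Per-view weights that sum to one along the view dimension, as returned above.
weights = torch.softmax(torch.randn(batch, num_views, 1, height, width), dim=1)

# Weighted sum over the view dimension fuses the per-view features.
fused = (weights * features).sum(dim=1)  # (batch, channels, height, width)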
Example #5
    def run_iteration(self, batch, train, is_step):
        if 'hard_' in self.g_depth_recon_loss_type:
            self._g_color_recon_criterion.k = int(
                self._g_color_recon_k_scheduler.get(self.epoch))

        batch = process_batch(batch,
                              self.cube_size,
                              self.camera_dist,
                              self._sculptor.in_size,
                              self.device,
                              random_orientation=False)

        if not self.color_random_background or self.crop_random_background:
            batch['in']['image'] = batch['in']['image'] * batch['in']['mask']

        if not self.depth_random_background or self.crop_random_background:
            batch['in']['depth'] = mask_normalized_depth(
                batch['in']['depth'], batch['in']['mask'])

        data_process_time = self.mark_time()

        generator = self._generator
        if self.data_parallel:
            generator = nn.DataParallel(generator)

        image_reproj, depth_reproj, mask_ibr_out, depth_ibr_out, cam_dists_r, cam_dists_t = \
            self._render_reprojections(batch)
        ibr_time = self.mark_time()

        # Add dists as another channel to image
        x = torch.cat((
            image_reproj,
            depth_reproj,
            cam_dists_r[:, :, :, None, None, None].expand(
                -1, -1, -1, -1, *image_reproj.shape[-2:]),
            cam_dists_t[:, :, :, None, None, None].expand(
                -1, -1, -1, -1, *image_reproj.shape[-2:]),
        ),
                      dim=3)
        # Factor input views into channels and batch/output_views into batch dim.
        x = x.view(x.shape[0] * x.shape[1], x.shape[2] * x.shape[3],
                   x.shape[4], x.shape[5])

        # Add output predicted depth to input.
        x = torch.cat((bv2b(depth_ibr_out), x), dim=1)
        blend_weights = None
        flow_dx = None
        flow_dy = None
        logits = generator(x, z_inject=None)
        if self.ibr_type == 'regress':
            image_ibr_out = torch.tanh(logits)
        elif self.ibr_type == 'blend':
            image_ibr_out, blend_weights = ibr.blend_logits(
                logits, bv2b(image_reproj))
        else:
            image_ibr_out, blend_weights, flow_dx, flow_dy = ibr.warp_blend_logits(
                logits, bv2b(image_reproj), self.flow_size)

        image_ibr_out = b2bv(image_ibr_out, self.num_output_views)

        if not self.no_apply_mask:
            image_ibr_out = image_ibr_out * mask_ibr_out
            depth_ibr_out = mask_normalized_depth(depth_ibr_out, mask_ibr_out)

        if self._discriminator:
            image_real_noise = (
                self.input_noise_weight *
                self._input_noise_dist.sample(batch['out_gt']['image'].size()))
            image_fake_noise = (
                self.input_noise_weight *
                self._input_noise_dist.sample(image_ibr_out.size()))

            discriminator = self._discriminator
            if self.data_parallel:
                discriminator = nn.DataParallel(discriminator)
            # Train discriminator.
            d_real = discriminator(bv2b(batch['out_gt']['image'] +
                                        image_real_noise.to(self.device)),
                                   mask=bv2b(batch['out_gt']['mask']))
            d_fake_d = discriminator(image_ibr_out.detach() +
                                     image_fake_noise.to(self.device),
                                     mask=mask_ibr_out)

            loss_d_real = multiscale_lsgan_loss(d_real, 1)
            loss_d_fake = multiscale_lsgan_loss(d_fake_d, 0)
            loss_d = loss_d_real + loss_d_fake

            d_fake_g = discriminator(image_ibr_out +
                                     image_fake_noise.to(self.device),
                                     mask=mask_ibr_out)
            loss_g_gan = self.g_gan_loss_weight * multiscale_lsgan_loss(
                d_fake_g, 1)

            if train:
                loss_d.backward()
                if is_step:
                    self._optimizers['discriminator'].step()

            self.plotter.put_scalar('loss/discriminator/real', loss_d_real)
            self.plotter.put_scalar('loss/discriminator/fake', loss_d_fake)
            self.plotter.put_scalar('loss/discriminator/total', loss_d)
        else:
            d_real, d_fake_d, d_fake_g = None, None, None
            loss_g_gan = torch.tensor(0.0, device=self.device)

        # Train generator. Must re-evaluate discriminator to propagate gradients down the
        # generator.
        loss_g_color_recon = self.g_color_recon_loss_weight * reduce_loss(
            self._g_color_recon_criterion(image_ibr_out,
                                          batch['out_gt']['image']),
            reduction='mean')
        loss_g_depth_recon = self.g_depth_recon_loss_weight * reduce_loss(
            self._g_depth_recon_criterion(depth_ibr_out,
                                          batch['out_gt']['depth']),
            reduction='mean')
        loss_g_mask_recon = self.g_depth_recon_loss_weight * reduce_loss(
            self._g_depth_recon_criterion(mask_ibr_out,
                                          batch['out_gt']['mask']),
            reduction='mean')
        loss_g_mask_beta = beta_prior_loss(mask_ibr_out,
                                           alpha=self.g_mask_beta_loss_param,
                                           beta=self.g_mask_beta_loss_param)
        loss_g = (loss_g_gan + loss_g_color_recon + loss_g_depth_recon +
                  loss_g_mask_recon + loss_g_mask_beta)
        if train:
            loss_g.backward()
            if is_step:
                self._optimizers['generator'].step()
                if self.train_recon:
                    self._optimizers['sculptor'].step()
                    self._optimizers['photographer'].step()
                    if 'fuser' in self._optimizers:
                        self._optimizers['fuser'].step()

        compute_time = self.mark_time()

        with torch.no_grad():
            self.plotter.put_scalar(
                'error/color/l1',
                F.l1_loss(image_ibr_out, batch['out_gt']['image']))
            self.plotter.put_scalar(
                'error/depth/l1',
                F.l1_loss(depth_ibr_out, batch['out_gt']['depth']))
            self.plotter.put_scalar(
                'error/mask/l1',
                F.l1_loss(mask_ibr_out, batch['out_gt']['mask']))
            self.plotter.put_scalar(
                'error/mask/cross_entropy',
                F.binary_cross_entropy_with_logits(mask_ibr_out,
                                                   batch['out_gt']['mask']))

        self.plotter.put_scalar('loss/generator/gan', loss_g_gan)
        self.plotter.put_scalar('loss/generator/recon/color',
                                loss_g_color_recon)
        self.plotter.put_scalar('loss/generator/recon/depth',
                                loss_g_depth_recon)
        self.plotter.put_scalar('loss/generator/recon/mask', loss_g_mask_recon)
        self.plotter.put_scalar('loss/generator/total', loss_g)

        self.plotter.put_scalar('params/input_noise_weight',
                                self.input_noise_weight)

        self.plotter.put_scalar('time/data_process', data_process_time)
        self.plotter.put_scalar('time/ibr', ibr_time)
        self.plotter.put_scalar('time/compute', compute_time)
        plot_scalar_time = self.mark_time()
        self.plotter.put_scalar('time/plot/scalars', plot_scalar_time)

        if self.plotter.is_it_time_yet('show'):
            self.plotter.put_image(
                'results',
                viz.make_grid([
                    gan_denormalize(batch['out_gt']['image']),
                    gan_denormalize(image_ibr_out),
                    viz.colorize_tensor(batch['out_gt']['depth'] / 2.0 + 0.5),
                    viz.colorize_tensor(depth_ibr_out / 2.0 + 0.5),
                    viz.colorize_tensor(batch['out_gt']['mask']),
                    viz.colorize_tensor(mask_ibr_out),
                ],
                              row_size=1,
                              output_size=128,
                              d_real=d_real,
                              d_fake=d_fake_d))

            n = self.num_output_views
            images = [
                gan_denormalize(image_reproj[0].view(
                    -1, *image_reproj.shape[-3:])),
                viz.colorize_tensor(
                    depth_reproj[0].view(-1, *depth_reproj.shape[-3:]) / 2.0 +
                    0.5)
            ]
            if blend_weights is not None:
                images.append(
                    viz.colorize_tensor(blend_weights[:n].view(
                        -1, *blend_weights.shape[-2:])))
            if flow_dx is not None:
                flow_range = self.flow_size / self.input_size
                images.append(
                    viz.colorize_tensor(flow_dx[:n].reshape(
                        -1, *flow_dx.shape[-2:]),
                                        cmap='coolwarm',
                                        cmin=-flow_range,
                                        cmax=flow_range))
                images.append(
                    viz.colorize_tensor(flow_dy[:n].reshape(
                        -1, *flow_dy.shape[-2:]),
                                        cmap='coolwarm',
                                        cmin=-flow_range,
                                        cmax=flow_range))
            self.plotter.put_image(
                'ibr_components',
                viz.make_grid(images, row_size=4, stride=4, output_size=64))

        if self.plotter.is_it_time_yet('histogram'):
            if flow_dx is not None:
                self.plotter.put_histogram('flow/dx', flow_dx)
                self.plotter.put_histogram('flow/dy', flow_dy)

        plot_images_time = self.mark_time()
        self.plotter.put_scalar('time/plot/images', plot_images_time)
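multiscale_lsgan_loss is not defined in these examples. A minimal sketch consistent with how it is called above, assuming the discriminator returns a list of per-scale outputs and that the standard least-squares GAN objective is averaged per scale and summed across scales:

import torch


def multiscale_lsgan_loss(outputs, target):
    # Least-squares GAN loss: squared distance of each discriminator output
    # from the target label (1 for real, 0 for fake), averaged per scale and
    # summed across scales. [assumed implementation]
    loss = 0.0
    for out in outputs:
        loss = loss + torch.mean((out - target) ** 2)
    return loss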