Example #1
    def postprocessing(self, ray_start, ray_dir, all_results, hits, sizes):
        # trace the background field here, split into a front segment
        # (near plane to voxel entry) and a back segment (voxel exit to far)
        S, V, P = sizes
        fullsize = S * V * P

        vox_colors = fill_in((fullsize, 3), hits, all_results['colors'], 0.0)
        vox_missed = fill_in((fullsize, ), hits, all_results['missed'], 1.0)
        vox_depths = fill_in((fullsize, ), hits, all_results['depths'], 0.0)

        mid_dis = (self.args.near + self.args.far) / 2
        n_depth = fill_in((fullsize, ), hits, all_results['min_depths'],
                          mid_dis)[:, None]
        f_depth = fill_in((fullsize, ), hits, all_results['max_depths'],
                          mid_dis)[:, None]

        # front field: from the near plane to the voxel entry depth
        nerf_step = getattr(self.args, "nerf_steps", 64)
        max_depth = n_depth
        min_depth = torch.ones_like(max_depth) * self.args.near
        intersection_outputs = {
            "min_depth": min_depth,
            "max_depth": max_depth,
            "probs": torch.ones_like(max_depth),
            "steps": torch.ones_like(max_depth).squeeze(-1) * nerf_step,
            "intersected_voxel_idx": torch.zeros_like(min_depth).int()
        }
        with with_torch_seed(self.unique_seed):
            fg_samples = self.bg_encoder.ray_sample(intersection_outputs)
        fg_results = self.raymarcher(self.bg_encoder, self.bg_field, ray_start,
                                     ray_dir, fg_samples, {})

        # back field: from the voxel exit depth to the far plane
        min_depth = f_depth
        max_depth = torch.ones_like(min_depth) * self.args.far
        intersection_outputs = {
            "min_depth": min_depth,
            "max_depth": max_depth,
            "probs": torch.ones_like(max_depth),
            "steps": torch.ones_like(max_depth).squeeze(-1) * nerf_step,
            "intersected_voxel_idx": torch.zeros_like(min_depth).int()
        }
        with with_torch_seed(self.unique_seed):
            bg_samples = self.bg_encoder.ray_sample(intersection_outputs)
        bg_results = self.raymarcher(self.bg_encoder, self.bg_field, ray_start,
                                     ray_dir, bg_samples, {})

        # composite the three layers front-to-back: front field, voxel field,
        # back field; each layer's 'missed' weights whatever lies behind it
        all_results['voxcolors'] = vox_colors.view(S, V, P, 3)
        all_results['colors'] = fg_results['colors'] + fg_results['missed'][:, None] * (
            vox_colors + vox_missed[:, None] * bg_results['colors'])
        all_results['depths'] = fg_results['depths'] + fg_results['missed'] * (
            vox_depths + vox_missed * bg_results['depths'])
        all_results['missed'] = (fg_results['missed'] * vox_missed *
                                 bg_results['missed'])

        # apply the NSVF post-processing
        return super().postprocessing(ray_start, ray_dir, all_results, hits,
                                      sizes)
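
The merge above composites three layers front-to-back: the front background segment, the voxel field, and the back background segment. Each layer's 'missed' value acts as residual transmittance, scaling whatever lies behind it. Below is a minimal sketch of the same recurrence, assuming each layer's colors are already opacity-weighted; composite_layers is an illustrative name, not part of the NSVF codebase.

import torch

def composite_layers(layers):
    # Front-to-back compositing over (color, missed) pairs, where 'missed'
    # is the fraction of the ray that passes through a layer unabsorbed.
    color = torch.zeros_like(layers[0][0])
    transmittance = torch.ones_like(layers[0][1])
    for layer_color, layer_missed in layers:
        color = color + transmittance * layer_color
        transmittance = transmittance * layer_missed
    # the final transmittance is the combined 'missed' of all layers
    return color, transmittance

# Reproduces the merge above for a batch of N rays:
#   color, missed = composite_layers([
#       (fg_results['colors'], fg_results['missed'][:, None]),
#       (vox_colors,           vox_missed[:, None]),
#       (bg_results['colors'], bg_results['missed'][:, None]),
#   ])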
Example #2
    def postprocessing(self, ray_start, ray_dir, all_results, hits, sizes):
        # pad NSVF's sparse per-hit outputs to full size, then composite the background
        S, V, P = sizes
        fullsize = S * V * P

        all_results['missed'] = fill_in(
            (fullsize, ), hits, all_results['missed'], 1.0).view(S, V, P)
        all_results['colors'] = fill_in(
            (fullsize, 3), hits, all_results['colors'], 0.0).view(S, V, P, 3)
        all_results['depths'] = fill_in(
            (fullsize, ), hits, all_results['depths'], 0.0).view(S, V, P)

        BG_DEPTH = self.field.bg_color.depth
        bg_color = self.field.bg_color(all_results['colors'])
        all_results['colors'] += all_results['missed'].unsqueeze(-1) * \
            bg_color.reshape(fullsize, 3).view(S, V, P, 3)
        all_results['depths'] += all_results['missed'] * BG_DEPTH
        if 'normal' in all_results:
            all_results['normal'] = fill_in((fullsize, 3), hits,
                                            all_results['normal'],
                                            0.0).view(S, V, P, 3)
        if 'voxel_depth' in all_results:
            all_results['voxel_depth'] = fill_in((fullsize, ), hits,
                                                 all_results['voxel_depth'],
                                                 BG_DEPTH).view(S, V, P)
        if 'voxel_edges' in all_results:
            all_results['voxel_edges'] = fill_in((fullsize, 3), hits,
                                                 all_results['voxel_edges'],
                                                 1.0).view(S, V, P, 3)
        if 'feat_n2' in all_results:
            all_results['feat_n2'] = fill_in(
                (fullsize, ), hits, all_results['feat_n2'], 0.0).view(S, V, P)
        return all_results
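
All three examples lean on the fill_in helper, which scatters per-hit results back into a dense buffer covering every ray and pads rays that never intersected a voxel with a default value. The following is a plausible minimal implementation, inferred from the call sites here rather than taken from the project:

import torch

def fill_in(shape, hits, values, initial):
    # dense buffer pre-filled with the default (e.g. missed=1.0, depth=0.0)
    out = values.new_full(shape, initial)
    # broadcast the (fullsize,) hit mask to vector outputs such as colors
    if len(shape) > hits.dim():
        hits = hits.unsqueeze(-1).expand(*shape)
    # scatter the compacted per-hit values into the hit positions, in order
    return out.masked_scatter(hits, values)

# Example: two of three rays hit; the middle ray keeps the default color.
hits = torch.tensor([True, False, True])
colors = torch.tensor([[1., 0., 0.], [0., 0., 1.]])
full = fill_in((3, 3), hits, colors, 0.0)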
Example #3
File: nsvf.py Project: yyeboah/NSVF
    def _forward(self, ray_start, ray_dir, **kwargs):
        S, V, P, _ = ray_dir.size()
        assert S == 1, "naive NeRF only supports single object."

        # voxel encoder (precompute for each voxel if needed)
        encoder_states = self.encoder.precompute(**kwargs)

        # ray-voxel intersection
        with GPUTimer() as timer0:
            ray_start, ray_dir, intersection_outputs, hits = \
                self.encoder.ray_intersect(ray_start, ray_dir, encoder_states)

            if self.reader.no_sampling and self.training:  # sample points after ray-voxel intersection
                uv, size = kwargs['uv'], kwargs['size']
                mask = hits.reshape(*uv.size()[:2], uv.size(-1))

                # sample rays based on voxel intersections
                sampled_uv, sampled_masks = self.reader.sample_pixels(
                    uv, size, mask=mask, return_mask=True)
                sampled_masks = sampled_masks.reshape(uv.size(0), -1).bool()
                hits = hits[sampled_masks].reshape(S, -1)
                sampled_masks = sampled_masks.unsqueeze(-1)
                intersection_outputs = {
                    name: outs[sampled_masks.expand_as(outs)].reshape(
                        S, -1, outs.size(-1))
                    for name, outs in intersection_outputs.items()
                }
                ray_start = ray_start[sampled_masks.expand_as(ray_start)].reshape(S, -1, 3)
                ray_dir = ray_dir[sampled_masks.expand_as(ray_dir)].reshape(S, -1, 3)
                P = hits.size(-1) // V  # the number of pixels per image
            else:
                sampled_uv = None

        # neural ray-marching
        fullsize = S * V * P

        BG_DEPTH = self.field.bg_color.depth
        bg_color = self.field.bg_color(ray_dir)

        all_results = defaultdict(lambda: None)
        if hits.sum() > 0:  # skip ray-marching if every ray missed the voxels
            intersection_outputs = {
                name: outs[hits]
                for name, outs in intersection_outputs.items()
            }
            ray_start, ray_dir = ray_start[hits], ray_dir[hits]

            # sample evaluation points along the ray
            samples = self.encoder.ray_sample(intersection_outputs)
            encoder_states = {
                name: s.reshape(-1, s.size(-1)) if s is not None else None
                for name, s in encoder_states.items()
            }

            # rendering
            all_results = self.raymarcher(self.encoder, self.field, ray_start,
                                          ray_dir, samples, encoder_states)
            all_results['depths'] = (all_results['depths'] +
                                     BG_DEPTH * all_results['missed'])
            all_results['voxel_edges'] = self.encoder.get_edge(
                ray_start, ray_dir, samples, encoder_states)
            all_results['voxel_depth'] = samples['sampled_point_depth'][:, 0]

        # scatter the results back out to the full (S, V, P) ray grid
        hits = hits.reshape(fullsize)
        all_results['missed'] = fill_in(
            (fullsize, ), hits, all_results['missed'], 1.0).view(S, V, P)
        all_results['depths'] = fill_in(
            (fullsize, ), hits, all_results['depths'], BG_DEPTH).view(S, V, P)
        all_results['voxel_depth'] = fill_in((fullsize, ), hits,
                                             all_results['voxel_depth'],
                                             BG_DEPTH).view(S, V, P)
        all_results['voxel_edges'] = fill_in((fullsize, 3), hits,
                                             all_results['voxel_edges'],
                                             1.0).view(S, V, P, 3)
        all_results['colors'] = fill_in(
            (fullsize, 3), hits, all_results['colors'], 0.0).view(S, V, P, 3)
        all_results['bg_color'] = bg_color.reshape(fullsize, 3).view(S, V, P, 3)
        all_results['colors'] += all_results['missed'].unsqueeze(-1) * \
            all_results['bg_color']
        if 'normal' in all_results:
            all_results['normal'] = fill_in((fullsize, 3), hits,
                                            all_results['normal'],
                                            0.0).view(S, V, P, 3)

        # other logs
        all_results['other_logs'] = {
            'voxs_log': self.encoder.voxel_size.item(),
            'stps_log': self.encoder.step_size.item(),
            'tvox_log': timer0.sum,
            'asf_log': (all_results['ae'].float() / fullsize).item(),
            'ash_log': (all_results['ae'].float() / hits.sum()).item(),
            'nvox_log': self.encoder.num_voxels,
        }
        all_results['sampled_uv'] = sampled_uv
        return all_results
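
The intersection phase above is wrapped in GPUTimer and its elapsed time is reported as tvox_log. The project's own timer is not shown on this page; here is a minimal sketch of such a context manager, assuming it synchronizes CUDA so that queued kernels are included in the measurement:

import time

import torch

class GPUTimer:
    # accumulating timer: 'sum' holds the total seconds across all uses
    def __init__(self):
        self.sum = 0.0

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # drain queued kernels before starting
        self.start = time.time()
        return self

    def __exit__(self, *exc):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the timed kernels to finish
        self.sum += time.time() - self.start

# Usage mirroring the example above:
#   with GPUTimer() as timer0:
#       ...  # ray-voxel intersection
#   print(timer0.sum)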