Example #1
    def postprocessing(self, ray_start, ray_dir, all_results, hits, sizes):
        # trace the background NeRF field and composite it with the voxel results
        S, V, P = sizes  # number of shapes, views, and pixels (rays) per view
        fullsize = S * V * P

        # scatter the per-hit outputs back into dense buffers with one slot per
        # ray; rays that missed every voxel keep the default (last argument)
        vox_colors = fill_in((fullsize, 3), hits, all_results['colors'], 0.0)
        vox_missed = fill_in((fullsize, ), hits, all_results['missed'], 1.0)
        vox_depths = fill_in((fullsize, ), hits, all_results['depths'], 0.0)

        mid_dis = (self.args.near + self.args.far) / 2
        # nearest / farthest voxel intersection depth per ray; rays that missed
        # fall back to the midpoint of the [near, far] range
        n_depth = fill_in((fullsize, ), hits, all_results['min_depths'],
                          mid_dis)[:, None]
        f_depth = fill_in((fullsize, ), hits, all_results['max_depths'],
                          mid_dis)[:, None]

        # front field: sample from the near plane up to the first voxel hit
        nerf_step = getattr(self.args, "nerf_steps", 64)
        max_depth = n_depth
        min_depth = torch.ones_like(max_depth) * self.args.near
        intersection_outputs = {
            "min_depth": min_depth,
            "max_depth": max_depth,
            "probs": torch.ones_like(max_depth),
            "steps": torch.ones_like(max_depth).squeeze(-1) * nerf_step,
            "intersected_voxel_idx": torch.zeros_like(min_depth).int()
        }
        with with_torch_seed(self.unique_seed):
            fg_samples = self.bg_encoder.ray_sample(intersection_outputs)
        fg_results = self.raymarcher(self.bg_encoder, self.bg_field, ray_start,
                                     ray_dir, fg_samples, {})

        # back field: sample from the last voxel hit out to the far plane
        min_depth = f_depth
        max_depth = torch.ones_like(min_depth) * self.args.far
        intersection_outputs = {
            "min_depth": min_depth,
            "max_depth": max_depth,
            "probs": torch.ones_like(max_depth),
            "steps": torch.ones_like(max_depth).squeeze(-1) * nerf_step,
            "intersected_voxel_idx": torch.zeros_like(min_depth).int()
        }
        with with_torch_seed(self.unique_seed):
            bg_samples = self.bg_encoder.ray_sample(intersection_outputs)
        bg_results = self.raymarcher(self.bg_encoder, self.bg_field, ray_start,
                                     ray_dir, bg_samples, {})

        # composite front-to-back: the front field sits in front of the voxel
        # colors, which in turn sit in front of the back field; each 'missed'
        # term is the transmittance left over by the layer before it
        all_results['voxcolors'] = vox_colors.view(S, V, P, 3)
        all_results['colors'] = fg_results['colors'] + fg_results['missed'][:, None] * (
            vox_colors + vox_missed[:, None] * bg_results['colors'])
        all_results['depths'] = fg_results['depths'] + fg_results['missed'] * (
            vox_depths + vox_missed * bg_results['depths'])
        all_results['missed'] = fg_results['missed'] * vox_missed * bg_results['missed']

        # apply the NSVF post-processing
        return super().postprocessing(ray_start, ray_dir, all_results, hits,
                                      sizes)
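
The method above leans on a fill_in helper that expands per-hit tensors back to one entry per ray; rays that missed every voxel keep the supplied default, which is why vox_missed defaults to 1.0 (fully transparent). A minimal sketch of a helper with this signature, assuming hits is a boolean mask over all fullsize rays and values holds one row per hit ray (an assumption, not necessarily the library's exact code):

import torch

def fill_in(shape, hits, values, default):
    # allocate a dense buffer pre-filled with the default value
    out = values.new_full(shape, default)
    mask = hits.view(-1)
    # write the computed per-hit values into the slots where the ray hit
    if len(shape) == 1:
        out[mask] = values.view(-1)
    else:
        out[mask] = values.view(-1, shape[-1])
    return out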
Example #2
    def forward(self, ray_split=1, **kwargs):
        # seed per worker so that different GPUs sample different rays
        with with_torch_seed(self.unique_seed):
            ray_start, ray_dir, uv = self.reader(**kwargs)
        
        kwargs.update({
            'field_fn': self.field.forward,
            'input_fn': self.encoder.forward})

        if ray_split == 1:
            results = self._forward(ray_start, ray_dir, **kwargs)
        else:
            # evaluate in chunks along the ray dimension to bound peak memory
            total_rays = ray_dir.shape[2]
            chunk_size = total_rays // ray_split
            results = [
                self._forward(
                    ray_start, ray_dir[:, :, i: i+chunk_size], **kwargs)
                for i in range(0, total_rays, chunk_size)
            ]
            results = self.merge_outputs(results)

        results['samples'] = {
            'sampled_uv': results.get('sampled_uv', uv),
            'ray_start': ray_start,
            'ray_dir': ray_dir
        }

        # cache the predictions, detaching any tensors from the autograd graph
        # (the original tested isinstance(w, torch.Tensor) on the dict key, so
        # nothing was ever detached; test the value instead)
        self.cache = {
            w: results[w].detach()
            if isinstance(results[w], torch.Tensor)
            else results[w]
            for w in results
        }
        return results
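
The chunked branch relies on a merge_outputs method to stitch the per-chunk result dictionaries back together. As a rough sketch of what such a merge could look like (an assumption, not the library's actual implementation, and assuming outputs keep the (S, V, P, ...) layout of the input rays):

import torch

def merge_outputs(results, ray_dim=2):
    # hypothetical sketch: concatenate per-chunk tensors back along the ray
    # axis; non-tensor entries are carried over from the first chunk
    merged = {}
    for key in results[0]:
        vals = [r[key] for r in results]
        merged[key] = (torch.cat(vals, ray_dim)
                       if isinstance(vals[0], torch.Tensor) else vals[0])
    return merged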
Example #3
    def raymarching(self,
                    ray_start,
                    ray_dir,
                    intersection_outputs,
                    encoder_states,
                    fine=False):
        # sample points along the rays and use the middle-point approximation
        with with_torch_seed(self.unique_seed):  # ensure each GPU samples differently
            samples = self.encoder.ray_sample(intersection_outputs)
        # use the fine field for the second, hierarchical pass when available
        field = self.field_fine if fine and (self.field_fine is not None) else self.field
        all_results = self.raymarcher(self.encoder, field, ray_start, ray_dir,
                                      samples, encoder_states)
        return samples, all_results
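
All three examples wrap their sampling calls in with_torch_seed, which makes the random draws reproducible per worker for the duration of the block. One plausible implementation of such a context manager (an assumption; the library's version may differ, e.g. by advancing the seed between calls):

import contextlib
import torch

@contextlib.contextmanager
def with_torch_seed(seed):
    # save the global RNG state, seed deterministically for the block,
    # then restore the previous state on exit
    state = torch.get_rng_state()
    torch.manual_seed(seed)
    try:
        yield
    finally:
        torch.set_rng_state(state)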