Esempio n. 1
0
    def forward_once(
        self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states,
        early_stop=None, output_types=None
        ):
        """Evaluate the implicit field on the active samples of a batch of rays.

        Args:
            input_fn: callable ``input_fn(samples, encoder_states)`` building the
                field inputs from the masked samples. Its result must contain
                ``'dists'`` whenever ``field_fn`` returns ``'sigma'``.
            field_fn: callable ``field_fn(field_inputs, outputs=...)`` returning a
                dict of per-sample outputs.
            ray_start, ray_dir: per-ray origins / directions (unsqueezed at dim 1
                to broadcast against the per-ray samples).
            samples: dict of per-sample tensors of shape ``(B, K, ...)``. It must
                contain ``'sampled_point_depth'`` and ``'sampled_point_voxel_idx'``,
                where ``-1`` marks a sample that missed. NOTE: this dict is updated
                in place with ``'sampled_point_xyz'`` and
                ``'sampled_point_ray_direction'``.
            encoder_states: passed through unchanged to ``input_fn``.
            early_stop: optional ``(B,)`` boolean tensor; samples of rays marked
                ``True`` are not evaluated.
            output_types: outputs requested from ``field_fn``. Defaults to
                ``['sigma', 'texture']``. A fresh list is built per call so the
                default is never shared with (or mutated by) ``field_fn``.

        Returns:
            ``(None, 0)`` if no sample is active. Otherwise ``(outputs, n)``:
            ``outputs`` holds ``'sample_mask'`` plus every recognised field output
            scattered back to ``(B, K[, C])`` with zeros at inactive samples, and
            ``n`` is the number of evaluated samples.
        """
        if output_types is None:
            output_types = ['sigma', 'texture']

        sampled_depth = samples['sampled_point_depth']
        sampled_idx = samples['sampled_point_voxel_idx'].long()

        # only compute where the ray hits (index != -1) and the ray is still active
        sample_mask = sampled_idx.ne(-1)
        if early_stop is not None:
            sample_mask = sample_mask & (~early_stop.unsqueeze(-1))
        if sample_mask.sum() == 0:  # every sample missed: nothing to evaluate
            return None, 0

        sampled_xyz = ray(ray_start.unsqueeze(1), ray_dir.unsqueeze(1), sampled_depth.unsqueeze(2))
        sampled_dir = ray_dir.unsqueeze(1).expand(*sampled_depth.size(), ray_dir.size()[-1])
        samples['sampled_point_xyz'] = sampled_xyz
        samples['sampled_point_ray_direction'] = sampled_dir

        # keep only the active samples (boolean indexing flattens the (B, K) dims)
        samples = {name: s[sample_mask] for name, s in samples.items()}

        # get encoder features as inputs, then evaluate the implicit field
        field_inputs = input_fn(samples, encoder_states)
        field_outputs = field_fn(field_inputs, outputs=output_types)
        outputs = {'sample_mask': sample_mask}

        def masked_scatter(mask, x):
            """Scatter flat per-sample values ``x`` into a zero-filled (B, K[, C]) tensor."""
            B, K = mask.size()
            if x.dim() == 1:
                return x.new_zeros(B, K).masked_scatter(mask, x)
            return x.new_zeros(B, K, x.size(-1)).masked_scatter(
                mask.unsqueeze(-1).expand(B, K, x.size(-1)), x)

        # post processing
        if 'sigma' in field_outputs:
            sigma, sampled_dists = field_outputs['sigma'], field_inputs['dists']
            # standard-normal noise is added to sigma while training or when discrete_reg is set
            noise = 0 if not self.discrete_reg and (not self.training) else torch.zeros_like(sigma).normal_()
            free_energy = torch.relu(noise + sigma) * sampled_dists
            # NOTE(review): hard-coded x7 scale marked "[debug]" by the author -- confirm it is intended
            free_energy = free_energy * 7.0
            # (optional) free_energy = (F.elu(sigma - 3, alpha=1) + 1) * dists
            # (optional) free_energy = torch.abs(sigma) * sampled_dists
            outputs['free_energy'] = masked_scatter(sample_mask, free_energy)

        # the remaining outputs only need to be scattered back to (B, K[, C])
        for key in ('sdf', 'texture', 'normal', 'feat_n2'):
            if key in field_outputs:
                outputs[key] = masked_scatter(sample_mask, field_outputs[key])

        return outputs, sample_mask.sum()
Esempio n. 2
0
    def _visualize(self, images, sample, output, state, **kwargs):
        img_id, shape, view, width, name = state
        if 'colors' in output and output['colors'] is not None:
            images['{}_color/{}:HWC'.format(name, img_id)] = {
                'img': output['colors'][shape, view],
                'min_val': float(self.args.min_color)
            }
        if 'depths' in output and output['depths'] is not None:
            min_depth, max_depth = output['depths'].min(
            ), output['depths'].max()
            if getattr(self.args, "near", None) is not None:
                min_depth = self.args.near
                max_depth = self.args.far
            images['{}_depth/{}:HWC'.format(name, img_id)] = {
                'img': output['depths'][shape, view],
                'min_val': min_depth,
                'max_val': max_depth
            }
            normals = compute_normal_map(
                sample['ray_start'][shape, view].float(),
                sample['ray_dir'][shape, view].float(),
                output['depths'][shape, view].float(),
                sample['extrinsics'][shape, view].float().inverse(), width)
            images['{}_normal/{}:HWC'.format(name, img_id)] = {
                'img': normals,
                'min_val': -1,
                'max_val': 1
            }

            # generate point clouds from depth
            images['{}_point/{}'.format(name, img_id)] = {
                'img':
                torch.cat([
                    ray(sample['ray_start'][shape, view].float(),
                        sample['ray_dir'][shape, view].float(),
                        output['depths'][shape, view].unsqueeze(-1).float()),
                    (output['colors'][shape, view] - self.args.min_color) /
                    (1 - self.args.min_color)
                ], 1),  # XYZRGB
                'raw':
                True
            }

        if 'z' in output and output['z'] is not None:
            images['{}_z/{}:HWC'.format(name, img_id)] = {
                'img': output['z'][shape, view],
                'min_val': 0,
                'max_val': 1
            }
        if 'normal' in output and output['normal'] is not None:
            images['{}_predn/{}:HWC'.format(name, img_id)] = {
                'img': output['normal'][shape, view],
                'min_val': -1,
                'max_val': 1
            }
        return images
Esempio n. 3
0
    def forward_once(
        self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states,
        early_stop=None, output_types=None
        ):
        """Evaluate the implicit field on the active samples of a batch of rays.

        When ``self.acum_latent`` is set and both ``'sigma'`` and ``'texture'``
        are requested, texture is not predicted per sample. Instead:

        1. a first field pass (without ``'texture'``) gives densities;
        2. ``self.calc_hit_probs`` turns them into per-sample weights;
        3. those weights average the samples' ``'emb'`` along each ray;
        4. a second field pass decodes that per-ray embedding into
           ``'texture'`` / ``'feat_n2'`` with shape ``(B, 1, C)``.

        Args:
            input_fn: callable ``input_fn(samples, encoder_states)`` building the
                field inputs from the masked samples. Its result must contain
                ``'dists'`` when ``'sigma'`` is produced, and ``'emb'`` on the
                accumulation path.
            field_fn: callable ``field_fn(field_inputs, outputs=...)`` returning a
                dict of outputs.
            ray_start, ray_dir: per-ray origins / directions (unsqueezed at dim 1
                to broadcast against the per-ray samples).
            samples: dict of ``(B, K, ...)`` per-sample tensors. It must contain
                ``'sampled_point_depth'`` and ``'sampled_point_voxel_idx'``,
                where ``-1`` marks a miss. NOTE: this dict is updated in place
                with ``'sampled_point_xyz'`` and ``'sampled_point_ray_direction'``.
            encoder_states: passed through unchanged to ``input_fn``.
            early_stop: optional ``(B,)`` boolean tensor; samples of rays marked
                ``True`` are not evaluated.
            output_types: outputs requested from ``field_fn``. Defaults to
                ``['sigma', 'texture']``. It is never mutated.

        Returns:
            ``(None, 0)`` if no sample is active. Otherwise ``(outputs, n)``, where
            per-sample outputs are scattered back to ``(B, K[, C])`` with zeros at
            inactive samples, and ``n`` is the number of evaluated samples.
        """
        if output_types is None:
            output_types = ['sigma', 'texture']

        sampled_depth = samples['sampled_point_depth']
        sampled_idx = samples['sampled_point_voxel_idx'].long()

        # only compute where the ray hits (index != -1) and the ray is still active
        sample_mask = sampled_idx.ne(-1)
        if early_stop is not None:
            sample_mask = sample_mask & (~early_stop.unsqueeze(-1))
        if sample_mask.sum() == 0:  # every sample missed: nothing to evaluate
            return None, 0

        sampled_xyz = ray(ray_start.unsqueeze(1), ray_dir.unsqueeze(1), sampled_depth.unsqueeze(2))
        sampled_dir = ray_dir.unsqueeze(1).expand(*sampled_depth.size(), ray_dir.size()[-1])
        samples['sampled_point_xyz'] = sampled_xyz
        samples['sampled_point_ray_direction'] = sampled_dir

        # keep only the active samples (boolean indexing flattens the (B, K) dims)
        masked_samples = {name: s[sample_mask] for name, s in samples.items()}

        # get encoder features as inputs
        field_inputs = input_fn(masked_samples, encoder_states)

        def masked_scatter(mask, x):
            """Scatter flat per-sample values ``x`` into a zero-filled (B, K[, C]) tensor."""
            B, K = mask.size()
            if x.dim() == 1:
                return x.new_zeros(B, K).masked_scatter(mask, x)
            return x.new_zeros(B, K, x.size(-1)).masked_scatter(
                mask.unsqueeze(-1).expand(B, K, x.size(-1)), x)

        def post_process_sigma(sigma, sampled_dists):
            """Turn raw densities into scattered (B, K) free energy."""
            # standard-normal noise is added to sigma while training or when discrete_reg is set
            noise = 0 if not self.discrete_reg and (not self.training) else torch.zeros_like(sigma).normal_()
            free_energy = torch.relu(noise + sigma) * sampled_dists
            # NOTE(review): hard-coded x7 scale marked "[debug]" by the author -- confirm it is intended
            free_energy = free_energy * 7.0
            # (optional) free_energy = (F.elu(sigma - 3, alpha=1) + 1) * dists
            # (optional) free_energy = torch.abs(sigma) * sampled_dists
            return masked_scatter(sample_mask, free_energy)

        outputs = {'sample_mask': sample_mask}
        # True only when the per-ray latent path actually runs. Post-processing
        # keys off this flag rather than self.acum_latent: otherwise texture
        # would be silently dropped when acum_latent is set but the path is
        # skipped (e.g. output_types=['texture']).
        accumulated = bool(
            self.acum_latent and 'sigma' in output_types and 'texture' in output_types)
        accumulated_free_energy = None

        # forward implicit fields
        if accumulated:
            # first pass: everything except texture (one 'texture' entry removed)
            first_pass_types = list(output_types)
            first_pass_types.remove('texture')
            field_outputs = field_fn(field_inputs, outputs=first_pass_types)
            # computed once and reused below, so the returned free energy uses
            # the same noise draw that produced the hit probabilities
            accumulated_free_energy = post_process_sigma(field_outputs['sigma'], field_inputs['dists'])
            probs = self.calc_hit_probs(sampled_depth, accumulated_free_energy).unsqueeze(-1)

            # density-weighted average of the sample embeddings along each ray
            emb = masked_scatter(sample_mask, field_inputs['emb'])
            ray_emb = (probs * emb).sum(1)

            # second pass: decode the per-ray embedding to texture
            # (sampled_dir[:, 0] is the per-ray direction)
            texture_outputs = field_fn({'emb': ray_emb, 'ray': sampled_dir[:, 0]}, outputs=['texture'])
            outputs['texture'] = texture_outputs['texture'].unsqueeze(1)
            outputs['feat_n2'] = texture_outputs['feat_n2'].unsqueeze(1)
        else:
            field_outputs = field_fn(field_inputs, outputs=output_types)

        # post processing
        if 'sigma' in field_outputs:
            if accumulated:
                outputs['free_energy'] = accumulated_free_energy
            else:
                outputs['free_energy'] = post_process_sigma(field_outputs['sigma'], field_inputs['dists'])
        if 'sdf' in field_outputs:
            outputs['sdf'] = masked_scatter(sample_mask, field_outputs['sdf'])
        if 'texture' in field_outputs and not accumulated:
            outputs['texture'] = masked_scatter(sample_mask, field_outputs['texture'])
        if 'normal' in field_outputs:
            outputs['normal'] = masked_scatter(sample_mask, field_outputs['normal'])
        if 'feat_n2' in field_outputs and not accumulated:
            outputs['feat_n2'] = masked_scatter(sample_mask, field_outputs['feat_n2'])
        return outputs, sample_mask.sum()