def forward_once(
        self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states,
        early_stop=None, output_types=('sigma', 'texture')):
    """
    chunks: set > 1 if out of memory; it saves memory at the cost of time.
    """
    sampled_depth = samples['sampled_point_depth']
    sampled_idx = samples['sampled_point_voxel_idx'].long()

    # only compute points on rays that hit a voxel
    sample_mask = sampled_idx.ne(-1)
    if early_stop is not None:
        sample_mask = sample_mask & (~early_stop.unsqueeze(-1))
    if sample_mask.sum() == 0:  # every ray misses: skip this chunk
        return None, 0

    sampled_xyz = ray(ray_start.unsqueeze(1), ray_dir.unsqueeze(1), sampled_depth.unsqueeze(2))
    sampled_dir = ray_dir.unsqueeze(1).expand(*sampled_depth.size(), ray_dir.size()[-1])
    samples['sampled_point_xyz'] = sampled_xyz
    samples['sampled_point_ray_direction'] = sampled_dir

    # apply mask: flatten the hit samples into a 1D batch
    samples = {name: s[sample_mask] for name, s in samples.items()}

    # get encoder features as inputs
    field_inputs = input_fn(samples, encoder_states)

    # forward implicit fields
    field_outputs = field_fn(field_inputs, outputs=output_types)
    outputs = {'sample_mask': sample_mask}

    def masked_scatter(mask, x):
        # scatter the flat per-hit tensor x back to a dense (B, K) or
        # (B, K, C) layout, filling missed positions with zeros
        B, K = mask.size()
        if x.dim() == 1:
            return x.new_zeros(B, K).masked_scatter(mask, x)
        return x.new_zeros(B, K, x.size(-1)).masked_scatter(
            mask.unsqueeze(-1).expand(B, K, x.size(-1)), x)

    # post processing
    if 'sigma' in field_outputs:
        sigma, sampled_dists = field_outputs['sigma'], field_inputs['dists']
        # noise is zero only at eval time with discrete regularization off
        noise = 0 if not self.discrete_reg and (not self.training) else torch.zeros_like(sigma).normal_()
        free_energy = torch.relu(noise + sigma) * sampled_dists
        free_energy = free_energy * 7.0  # ? [debug]
        # (optional) free_energy = (F.elu(sigma - 3, alpha=1) + 1) * dists
        # (optional) free_energy = torch.abs(sigma) * sampled_dists  ## ??
        outputs['free_energy'] = masked_scatter(sample_mask, free_energy)
    if 'sdf' in field_outputs:
        outputs['sdf'] = masked_scatter(sample_mask, field_outputs['sdf'])
    if 'texture' in field_outputs:
        outputs['texture'] = masked_scatter(sample_mask, field_outputs['texture'])
    if 'normal' in field_outputs:
        outputs['normal'] = masked_scatter(sample_mask, field_outputs['normal'])
    if 'feat_n2' in field_outputs:
        outputs['feat_n2'] = masked_scatter(sample_mask, field_outputs['feat_n2'])
    # input_fn.set_colouredpoints(field_inputs['originalpoints'], field_outputs['texture'])
    return outputs, sample_mask.sum()
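# A minimal sketch (not part of this file) of how the `free_energy` computed
# above is typically turned into compositing weights in NeRF-style volume
# rendering: w_i = T_i * (1 - exp(-fe_i)), with transmittance
# T_i = exp(-sum_{j<i} fe_j). Names and the (B, K) layout follow the usage
# here; the repo's actual compositing lives elsewhere.
import torch

def free_energy_to_weights(free_energy: torch.Tensor) -> torch.Tensor:
    # cumulative free energy of all samples *before* each position
    shifted = torch.cat(
        [torch.zeros_like(free_energy[:, :1]),
         torch.cumsum(free_energy, dim=-1)[:, :-1]], dim=-1)
    transmittance = torch.exp(-shifted)    # chance the ray survives to sample i
    alpha = 1.0 - torch.exp(-free_energy)  # chance the ray stops at sample i
    return transmittance * alpha           # rows sum to at most 1

# usage: weights = free_energy_to_weights(outputs['free_energy'])
#        color = (weights.unsqueeze(-1) * outputs['texture']).sum(1)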
def _visualize(self, images, sample, output, state, **kwargs):
    img_id, shape, view, width, name = state
    if 'colors' in output and output['colors'] is not None:
        images['{}_color/{}:HWC'.format(name, img_id)] = {
            'img': output['colors'][shape, view],
            'min_val': float(self.args.min_color)
        }
    if 'depths' in output and output['depths'] is not None:
        min_depth, max_depth = output['depths'].min(), output['depths'].max()
        if getattr(self.args, "near", None) is not None:
            min_depth = self.args.near
            max_depth = self.args.far
        images['{}_depth/{}:HWC'.format(name, img_id)] = {
            'img': output['depths'][shape, view],
            'min_val': min_depth,
            'max_val': max_depth
        }
        normals = compute_normal_map(
            sample['ray_start'][shape, view].float(),
            sample['ray_dir'][shape, view].float(),
            output['depths'][shape, view].float(),
            sample['extrinsics'][shape, view].float().inverse(),
            width)
        images['{}_normal/{}:HWC'.format(name, img_id)] = {
            'img': normals, 'min_val': -1, 'max_val': 1
        }
        # generate point clouds from depth
        images['{}_point/{}'.format(name, img_id)] = {
            'img': torch.cat([
                ray(sample['ray_start'][shape, view].float(),
                    sample['ray_dir'][shape, view].float(),
                    output['depths'][shape, view].unsqueeze(-1).float()),
                (output['colors'][shape, view] - self.args.min_color)
                / (1 - self.args.min_color)], 1),  # XYZRGB
            'raw': True
        }
    if 'z' in output and output['z'] is not None:
        images['{}_z/{}:HWC'.format(name, img_id)] = {
            'img': output['z'][shape, view], 'min_val': 0, 'max_val': 1
        }
    if 'normal' in output and output['normal'] is not None:
        images['{}_predn/{}:HWC'.format(name, img_id)] = {
            'img': output['normal'][shape, view], 'min_val': -1, 'max_val': 1
        }
    return images
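# A minimal sketch of the `ray` helper called throughout this file, assuming
# it is the usual origin-plus-scaled-direction evaluation p = o + t * d
# (consistent with every call site above; the real helper lives in the
# repo's geometry utilities).
import torch

def ray(ray_start: torch.Tensor, ray_dir: torch.Tensor, depths: torch.Tensor) -> torch.Tensor:
    # broadcasts (..., 3) origins/directions against (..., 1) depths
    # to produce (..., 3) world-space points
    return ray_start + ray_dir * depths

# e.g. the XYZ half of the XYZRGB point cloud in _visualize is
# ray(ray_start, ray_dir, depths.unsqueeze(-1))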
def forward_once(
        self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states,
        early_stop=None, output_types=('sigma', 'texture')):
    """
    chunks: set > 1 if out of memory; it saves memory at the cost of time.
    """
    sampled_depth = samples['sampled_point_depth']
    sampled_idx = samples['sampled_point_voxel_idx'].long()

    # only compute points on rays that hit a voxel
    sample_mask = sampled_idx.ne(-1)
    if early_stop is not None:
        sample_mask = sample_mask & (~early_stop.unsqueeze(-1))
    if sample_mask.sum() == 0:  # every ray misses: skip this chunk
        return None, 0

    sampled_xyz = ray(ray_start.unsqueeze(1), ray_dir.unsqueeze(1), sampled_depth.unsqueeze(2))
    sampled_dir = ray_dir.unsqueeze(1).expand(*sampled_depth.size(), ray_dir.size()[-1])
    samples['sampled_point_xyz'] = sampled_xyz
    samples['sampled_point_ray_direction'] = sampled_dir

    # apply mask: flatten the hit samples into a 1D batch
    masked_samples = {name: s[sample_mask] for name, s in samples.items()}

    # get encoder features as inputs
    field_inputs = input_fn(masked_samples, encoder_states)

    def masked_scatter(mask, x):
        # scatter the flat per-hit tensor x back to a dense (B, K) or
        # (B, K, C) layout, filling missed positions with zeros
        B, K = mask.size()
        if x.dim() == 1:
            return x.new_zeros(B, K).masked_scatter(mask, x)
        return x.new_zeros(B, K, x.size(-1)).masked_scatter(
            mask.unsqueeze(-1).expand(B, K, x.size(-1)), x)

    def post_process_sigma(sigma, sampled_dists):
        # noise is zero only at eval time with discrete regularization off
        noise = 0 if not self.discrete_reg and (not self.training) else torch.zeros_like(sigma).normal_()
        free_energy = torch.relu(noise + sigma) * sampled_dists
        free_energy = free_energy * 7.0  # ? [debug]
        # (optional) free_energy = (F.elu(sigma - 3, alpha=1) + 1) * dists
        # (optional) free_energy = torch.abs(sigma) * sampled_dists  ## ??
        return masked_scatter(sample_mask, free_energy)

    # forward implicit fields
    outputs = {'sample_mask': sample_mask}
    if self.acum_latent and 'sigma' in output_types and 'texture' in output_types:
        # First pass: per-sample densities only
        output_types = [t for t in output_types if t != 'texture']  # copy so the caller's list is not mutated
        # TODO: investigate which other outputs should be in the first pass
        field_outputs = field_fn(field_inputs, outputs=output_types)
        sigma, sampled_dists = field_outputs['sigma'], field_inputs['dists']
        free_energy = post_process_sigma(sigma, sampled_dists)
        # reuse the first-pass free energy downstream so training noise
        # is not resampled in the post-processing step below
        outputs['free_energy'] = free_energy
        probs = self.calc_hit_probs(sampled_depth, free_energy)

        # density-weighted average embedding along each ray
        probs = probs.unsqueeze(-1)
        emb = masked_scatter(sample_mask, field_inputs['emb'])
        emb = (probs * emb).sum(1)

        # Second pass: decode the per-ray embedding to texture. All samples on
        # a ray share one direction, so take the first sample's direction.
        per_ray_dir = samples['sampled_point_ray_direction'][:, 0]
        new_field_inputs = {'emb': emb, 'ray': per_ray_dir}  # TODO: include rest of inputs?
        # TODO: investigate which outputs should be in the second pass
        new_field_outputs = field_fn(new_field_inputs, outputs=['texture'])
        field_outputs.update({
            'feat_n2': new_field_outputs['feat_n2'].unsqueeze(1),
            'texture': new_field_outputs['texture'].unsqueeze(1)
        })
        outputs['texture'] = new_field_outputs['texture'].unsqueeze(1)
        outputs['feat_n2'] = new_field_outputs['feat_n2'].unsqueeze(1)
    else:
        field_outputs = field_fn(field_inputs, outputs=output_types)

    # post processing
    if 'sigma' in field_outputs and 'free_energy' not in outputs:
        sigma, sampled_dists = field_outputs['sigma'], field_inputs['dists']
        outputs['free_energy'] = post_process_sigma(sigma, sampled_dists)
    if 'sdf' in field_outputs:
        outputs['sdf'] = masked_scatter(sample_mask, field_outputs['sdf'])
    if 'texture' in field_outputs and not self.acum_latent:
        outputs['texture'] = masked_scatter(sample_mask, field_outputs['texture'])
    if 'normal' in field_outputs:
        outputs['normal'] = masked_scatter(sample_mask, field_outputs['normal'])
    if 'feat_n2' in field_outputs and not self.acum_latent:
        outputs['feat_n2'] = masked_scatter(sample_mask, field_outputs['feat_n2'])
    return outputs, sample_mask.sum()
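# `calc_hit_probs` is called above but not defined in this file. A plausible
# sketch, assuming it returns standard volume-rendering weights, optionally
# normalized per ray so the latent accumulation is a true weighted average;
# `sampled_depth` is unused in this minimal version but kept to match the
# call site. Treat this as an illustration, not the repo's implementation.
import torch

def calc_hit_probs(sampled_depth: torch.Tensor,
                   free_energy: torch.Tensor,
                   normalize: bool = True,
                   eps: float = 1e-8) -> torch.Tensor:
    # transmittance up to each sample, times the chance of stopping there
    shifted = torch.cat(
        [torch.zeros_like(free_energy[:, :1]),
         torch.cumsum(free_energy, dim=-1)[:, :-1]], dim=-1)
    probs = torch.exp(-shifted) * (1.0 - torch.exp(-free_energy))
    if normalize:
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
    return probs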