@classmethod
def vcat(cls, cameras, batch_size=-1):
    """
    Concatenates the view dimension of the camera parameters.

    Args:
        cameras: cameras to concatenate
        batch_size: the batch size of the data

    Returns:
        Cameras concatenated in the view dimension and then flattened
    """
    z_span = cameras[0].z_span
    height = cameras[0].height
    width = cameras[0].width
    # Unflatten each camera to (batch, view, ...) and concatenate along the
    # view dimension.
    intrinsic = torch.cat(
        [b2bv(o.intrinsic, batch_size=batch_size) for o in cameras], dim=1)
    viewport = torch.cat(
        [b2bv(o.viewport, batch_size=batch_size) for o in cameras], dim=1)
    log_quaternion = torch.cat(
        [b2bv(o.log_quaternion, batch_size=batch_size) for o in cameras], dim=1)
    translation = torch.cat(
        [b2bv(o.translation, batch_size=batch_size) for o in cameras], dim=1)
    # Flatten back to (batch * view, ...) before constructing the new camera.
    return cls(bv2b(intrinsic), None, z_span, bv2b(viewport),
               log_quaternion=bv2b(log_quaternion),
               translation=bv2b(translation),
               width=width, height=height)
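
# The bv2b / b2bv helpers used throughout this module are defined elsewhere in
# the codebase. Judging by the call sites above, they flatten a (batch, view, ...)
# tensor into (batch * view, ...) and split it back again. The sketch below is an
# illustrative guess at that convention (including the argument names); it is not
# the project's actual implementation.
def _bv2b_sketch(x):
    # (B, V, ...) -> (B * V, ...)
    return x.reshape(-1, *x.shape[2:])


def _b2bv_sketch(x, view_size=-1, batch_size=-1):
    # (B * V, ...) -> (B, V, ...); pass either the view count or the batch size.
    if batch_size > 0:
        return x.reshape(batch_size, -1, *x.shape[1:])
    return x.reshape(-1, view_size, *x.shape[1:])
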
def warp_blend_logits(logits, image_reproj, flow_size):
    device = image_reproj.device
    num_input_views = image_reproj.shape[1]
    height, width = image_reproj.shape[-2:]

    # Split the logits into per-view blend weights and x/y flow offsets.
    blend_logits, flow_x_logits, flow_y_logits = torch.split(
        logits, num_input_views, dim=1)
    blend_weights = torch.softmax(blend_logits, dim=1).unsqueeze(2)

    # Bound the flow offsets with tanh and scale them relative to the image size.
    flow_dx = flow_size / width * torch.tanh(flow_x_logits)
    flow_dy = flow_size / height * torch.tanh(flow_y_logits)

    # Base sampling grid in normalized [-1, 1] coordinates, offset by the flow.
    flow_y, flow_x = torch.meshgrid([
        torch.linspace(-1, 1, height, device=device),
        torch.linspace(-1, 1, width, device=device)
    ])
    flow_x = flow_x[None, None, :, :].expand_as(flow_dx) + flow_dx
    flow_y = flow_y[None, None, :, :].expand_as(flow_dy) + flow_dy
    flow_grid = torch.stack((flow_x, flow_y), dim=-1).clamp(-1, 1)

    # Warp each reprojected view by its flow field, then blend across views.
    image_fake = F.grid_sample(bv2b(image_reproj), bv2b(flow_grid),
                               mode='bilinear')
    image_fake = b2bv(image_fake, num_input_views)
    image_fake = (blend_weights * image_fake).sum(dim=1)

    return image_fake, blend_weights, flow_dx, flow_dy
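
# Hedged usage sketch for warp_blend_logits. The shapes below are inferred from the
# call sites in this codebase rather than documented anywhere: with V input views,
# the generator must emit 3 * V channels (V blend logits, V x-flow logits, and
# V y-flow logits). The tensor sizes and the flow_size value are made up for
# illustration.
def _example_warp_blend_logits():
    import torch
    n, v, c, h, w = 2, 4, 3, 64, 64
    image_reproj = torch.randn(n, v, c, h, w)  # reprojected input views
    logits = torch.randn(n, 3 * v, h, w)       # generator output
    image_fake, blend_weights, flow_dx, flow_dy = warp_blend_logits(
        logits, image_reproj, flow_size=8)
    assert image_fake.shape == (n, c, h, w)
    assert blend_weights.shape == (n, v, 1, h, w)
    assert flow_dx.shape == (n, v, h, w)
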
def reproject_views(image_in, depth_in, depth_out, camera_in, camera_out):
    """
    Reprojects pixels from the input views to the output views.

    Args:
        image_in: pixels to copy from (V_i, C, H, W)
        depth_in: input view depth (V_i, C, H, W)
        depth_out: target depth (V_o, C, H, W)
        camera_in: input view cameras
        camera_out: output view cameras

    Returns:
        Each input view image reprojected to the output views (V_o, V_i, C, H, W)
        Each input view depth transformed and reprojected to the output views
            (V_o, V_i, C, H, W)
    """
    grid = depth_to_warp_field(camera_in, camera_out, depth_out)

    # Expand and reshape to do batch operations:
    #   batch_dim = output_views
    #   view_dim = input_views
    image_in = bv2b(
        image_in.unsqueeze(0).expand(camera_out.length, -1, -1, -1, -1))
    obj_coords_in = torch.stack(camera_in.depth_object_coords(depth_in), dim=-1)
    obj_coords_in = bv2b(
        obj_coords_in.unsqueeze(0).expand(camera_out.length, -1, -1, -1, -1))

    # Repeat interleaved to expand to input view dimensions.
    camera_out = camera_out.repeat_interleave(camera_in.length)

    # Transform reprojected input coordinates to the output view.
    cam_coords_in_tf = three.transform_coord_grid(obj_coords_in,
                                                  camera_out.obj_to_cam)
    depth_in_tf = cam_coords_in_tf[..., 2].unsqueeze(1)
    depth_in_tf = camera_out.normalize_depth(depth_in_tf)

    grid = bv2b(grid)
    image_reproj = F.grid_sample(image_in, grid, mode='bilinear')
    depth_reproj = F.grid_sample(depth_in_tf, grid, mode='bilinear')

    return (b2bv(image_reproj, camera_in.length),
            b2bv(depth_reproj, camera_in.length))
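
# The expand / repeat_interleave pairing above enumerates every (output view,
# input view) combination: the input data is tiled across output views, while the
# output cameras are repeated per input view so that the flattened batch dimensions
# line up. The self-contained sketch below illustrates that indexing with plain
# index tensors (illustrative only; it does not use the project's camera classes).
def _example_view_pairing():
    import torch
    v_out, v_in = 3, 2
    in_ids = torch.arange(v_in)                                  # (V_i,)
    out_ids = torch.arange(v_out)                                # (V_o,)
    in_flat = in_ids.unsqueeze(0).expand(v_out, -1).reshape(-1)  # like bv2b(expand(...))
    out_flat = out_ids.repeat_interleave(v_in)                   # like camera_out.repeat_interleave
    # Row i of the flattened batch pairs output view out_flat[i] with input view in_flat[i].
    assert out_flat.tolist() == [0, 0, 1, 1, 2, 2]
    assert in_flat.tolist() == [0, 1, 0, 1, 0, 1]
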
def compute_blend_weights(self, z_cam, camera):
    num_views = z_cam.shape[1]
    z_cam = bv2b(z_cam)

    # Concatenate camera-space coordinates to input.
    coords = utils.get_normalized_voxel_depth(z_cam)
    w = torch.cat((z_cam, coords), dim=1)
    w = self.unet(w)
    w = self.transform_block(w, camera)
    w = b2bv(w, num_views)

    # Softmax along view dimension.
    w = torch.softmax(w, dim=1)
    return w
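
# The weights returned by compute_blend_weights sum to one along the view dimension,
# so a caller can fuse per-view features with a weighted sum. The snippet below is a
# self-contained sketch of that kind of fusion; the tensor shapes and the fusion step
# are assumptions for illustration (the actual fusion happens elsewhere in the
# codebase).
def _example_blend_fusion():
    import torch
    b, v, c, d, h, w = 1, 3, 8, 4, 4, 4
    feats = torch.randn(b, v, c, d, h, w)
    weights = torch.softmax(torch.randn(b, v, c, d, h, w), dim=1)
    fused = (weights * feats).sum(dim=1)  # (B, C, D, H, W)
    assert torch.allclose(weights.sum(dim=1), torch.ones(b, c, d, h, w))
    assert fused.shape == (b, c, d, h, w)
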
def run_iteration(self, batch, train, is_step):
    if 'hard_' in self.g_depth_recon_loss_type:
        self._g_color_recon_criterion.k = int(
            self._g_color_recon_k_scheduler.get(self.epoch))

    batch = process_batch(batch, self.cube_size, self.camera_dist,
                          self._sculptor.in_size, self.device,
                          random_orientation=False)

    if not self.color_random_background or self.crop_random_background:
        batch['in']['image'] = batch['in']['image'] * batch['in']['mask']
    if not self.depth_random_background or self.crop_random_background:
        batch['in']['depth'] = mask_normalized_depth(
            batch['in']['depth'], batch['in']['mask'])

    data_process_time = self.mark_time()

    generator = self._generator
    if self.data_parallel:
        generator = nn.DataParallel(generator)

    image_reproj, depth_reproj, mask_ibr_out, depth_ibr_out, cam_dists_r, cam_dists_t = \
        self._render_reprojections(batch)
    ibr_time = self.mark_time()

    # Add camera distances as additional channels to the image.
    x = torch.cat((
        image_reproj,
        depth_reproj,
        cam_dists_r[:, :, :, None, None, None].expand(
            -1, -1, -1, -1, *image_reproj.shape[-2:]),
        cam_dists_t[:, :, :, None, None, None].expand(
            -1, -1, -1, -1, *image_reproj.shape[-2:]),
    ), dim=3)
    # Factor input views into channels and batch/output_views into the batch dim.
    x = x.view(x.shape[0] * x.shape[1], x.shape[2] * x.shape[3],
               x.shape[4], x.shape[5])
    # Add the predicted output depth to the input.
    x = torch.cat((bv2b(depth_ibr_out), x), dim=1)

    blend_weights = None
    flow_dx = None
    flow_dy = None
    logits = generator(x, z_inject=None)
    if self.ibr_type == 'regress':
        image_ibr_out = torch.tanh(logits)
    elif self.ibr_type == 'blend':
        image_ibr_out, blend_weights = ibr.blend_logits(
            logits, bv2b(image_reproj))
    else:
        image_ibr_out, blend_weights, flow_dx, flow_dy = ibr.warp_blend_logits(
            logits, bv2b(image_reproj), self.flow_size)
    image_ibr_out = b2bv(image_ibr_out, self.num_output_views)

    if not self.no_apply_mask:
        image_ibr_out = image_ibr_out * mask_ibr_out
        depth_ibr_out = mask_normalized_depth(depth_ibr_out, mask_ibr_out)

    if self._discriminator:
        # Sample input noise once and move it to the device.
        image_real_noise = (
            self.input_noise_weight
            * self._input_noise_dist.sample(batch['out_gt']['image'].size())
        ).to(self.device)
        image_fake_noise = (
            self.input_noise_weight
            * self._input_noise_dist.sample(image_ibr_out.size())
        ).to(self.device)

        discriminator = self._discriminator
        if self.data_parallel:
            discriminator = nn.DataParallel(discriminator)

        # Train discriminator.
        d_real = discriminator(bv2b(batch['out_gt']['image'] + image_real_noise),
                               mask=bv2b(batch['out_gt']['mask']))
        d_fake_d = discriminator(image_ibr_out.detach() + image_fake_noise,
                                 mask=mask_ibr_out)
        loss_d_real = multiscale_lsgan_loss(d_real, 1)
        loss_d_fake = multiscale_lsgan_loss(d_fake_d, 0)
        loss_d = loss_d_real + loss_d_fake

        d_fake_g = discriminator(image_ibr_out + image_fake_noise,
                                 mask=mask_ibr_out)
        loss_g_gan = self.g_gan_loss_weight * multiscale_lsgan_loss(d_fake_g, 1)

        if train:
            loss_d.backward()
            if is_step:
                self._optimizers['discriminator'].step()

        self.plotter.put_scalar('loss/discriminator/real', loss_d_real)
        self.plotter.put_scalar('loss/discriminator/fake', loss_d_fake)
        self.plotter.put_scalar('loss/discriminator/total', loss_d)
    else:
        d_real, d_fake_d, d_fake_g = None, None, None
        loss_g_gan = torch.tensor(0.0, device=self.device)

    # Train generator. Must re-evaluate the discriminator to propagate gradients
    # down the generator.
    loss_g_color_recon = self.g_color_recon_loss_weight * reduce_loss(
        self._g_color_recon_criterion(image_ibr_out, batch['out_gt']['image']),
        reduction='mean')
    loss_g_depth_recon = self.g_depth_recon_loss_weight * reduce_loss(
        self._g_depth_recon_criterion(depth_ibr_out, batch['out_gt']['depth']),
        reduction='mean')
    loss_g_mask_recon = self.g_depth_recon_loss_weight * reduce_loss(
        self._g_depth_recon_criterion(mask_ibr_out, batch['out_gt']['mask']),
        reduction='mean')
    loss_g_mask_beta = beta_prior_loss(mask_ibr_out,
                                       alpha=self.g_mask_beta_loss_param,
                                       beta=self.g_mask_beta_loss_param)
    loss_g = (loss_g_gan + loss_g_color_recon + loss_g_depth_recon
              + loss_g_mask_recon + loss_g_mask_beta)

    if train:
        loss_g.backward()
        if is_step:
            self._optimizers['generator'].step()
            if self.train_recon:
                self._optimizers['sculptor'].step()
                self._optimizers['photographer'].step()
                if 'fuser' in self._optimizers:
                    self._optimizers['fuser'].step()

    compute_time = self.mark_time()

    with torch.no_grad():
        self.plotter.put_scalar(
            'error/color/l1',
            F.l1_loss(image_ibr_out, batch['out_gt']['image']))
        self.plotter.put_scalar(
            'error/depth/l1',
            F.l1_loss(depth_ibr_out, batch['out_gt']['depth']))
        self.plotter.put_scalar(
            'error/mask/l1',
            F.l1_loss(mask_ibr_out, batch['out_gt']['mask']))
        self.plotter.put_scalar(
            'error/mask/cross_entropy',
            F.binary_cross_entropy_with_logits(mask_ibr_out,
                                               batch['out_gt']['mask']))

        self.plotter.put_scalar('loss/generator/gan', loss_g_gan)
        self.plotter.put_scalar('loss/generator/recon/color', loss_g_color_recon)
        self.plotter.put_scalar('loss/generator/recon/depth', loss_g_depth_recon)
        self.plotter.put_scalar('loss/generator/recon/mask', loss_g_mask_recon)
        self.plotter.put_scalar('loss/generator/total', loss_g)

        self.plotter.put_scalar('params/input_noise_weight',
                                self.input_noise_weight)

        self.plotter.put_scalar('time/data_process', data_process_time)
        self.plotter.put_scalar('time/ibr', ibr_time)
        self.plotter.put_scalar('time/compute', compute_time)

        plot_scalar_time = self.mark_time()
        self.plotter.put_scalar('time/plot/scalars', plot_scalar_time)

        if self.plotter.is_it_time_yet('show'):
            self.plotter.put_image(
                'results',
                viz.make_grid([
                    gan_denormalize(batch['out_gt']['image']),
                    gan_denormalize(image_ibr_out),
                    viz.colorize_tensor(batch['out_gt']['depth'] / 2.0 + 0.5),
                    viz.colorize_tensor(depth_ibr_out / 2.0 + 0.5),
                    viz.colorize_tensor(batch['out_gt']['mask']),
                    viz.colorize_tensor(mask_ibr_out),
                ], row_size=1, output_size=128, d_real=d_real, d_fake=d_fake_d))

            n = self.num_output_views
            images = [
                gan_denormalize(image_reproj[0].view(
                    -1, *image_reproj.shape[-3:])),
                viz.colorize_tensor(
                    depth_reproj[0].view(-1, *depth_reproj.shape[-3:]) / 2.0
                    + 0.5)
            ]
            if blend_weights is not None:
                images.append(
                    viz.colorize_tensor(blend_weights[:n].view(
                        -1, *blend_weights.shape[-2:])))
            if flow_dx is not None:
                flow_range = self.flow_size / self.input_size
                images.append(
                    viz.colorize_tensor(flow_dx[:n].reshape(
                        -1, *flow_dx.shape[-2:]),
                        cmap='coolwarm', cmin=-flow_range, cmax=flow_range))
                images.append(
                    viz.colorize_tensor(flow_dy[:n].reshape(
                        -1, *flow_dy.shape[-2:]),
                        cmap='coolwarm', cmin=-flow_range, cmax=flow_range))
            self.plotter.put_image(
                'ibr_components',
                viz.make_grid(images, row_size=4, stride=4, output_size=64))

        if self.plotter.is_it_time_yet('histogram'):
            if flow_dx is not None:
                self.plotter.put_histogram('flow/dx', flow_dx)
                self.plotter.put_histogram('flow/dy', flow_dy)

    plot_images_time = self.mark_time()
    self.plotter.put_scalar('time/plot/images', plot_images_time)
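
# multiscale_lsgan_loss and beta_prior_loss are imported from elsewhere in the
# project. The sketches below show what such losses typically compute; they are
# readability aids based on the standard formulations, not the project's actual
# implementations.
def _multiscale_lsgan_loss_sketch(d_outputs, target):
    # Least-squares GAN loss averaged over a list of per-scale discriminator outputs;
    # target is 0 (fake) or 1 (real).
    import torch
    import torch.nn.functional as F
    loss = 0.0
    for d_out in d_outputs:
        loss = loss + F.mse_loss(d_out, torch.full_like(d_out, float(target)))
    return loss / len(d_outputs)


def _beta_prior_loss_sketch(x, alpha=2.0, beta=2.0, eps=1e-6):
    # Negative log-density of a Beta(alpha, beta) prior (up to a constant), averaged
    # over all elements; x is expected to lie in [0, 1].
    import torch
    x = x.clamp(eps, 1.0 - eps)
    return -((alpha - 1.0) * torch.log(x)
             + (beta - 1.0) * torch.log(1.0 - x)).mean()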