def stack_batch(batch):
    """
    Stack multi-camera batches (B,N,C,H,W becomes BN,C,H,W)

    Parameters
    ----------
    batch : dict
        Batch

    Returns
    -------
    batch : dict
        Stacked batch
    """
    # If there is multi-camera information
    if len(batch['rgb'].shape) == 5:
        assert batch['rgb'].shape[0] == 1, 'Only batch size 1 is supported for multi-cameras'
        # Loop over all keys
        for key in batch.keys():
            # If list, stack every item
            if is_list(batch[key]):
                if is_tensor(batch[key][0]) or is_numpy(batch[key][0]):
                    batch[key] = [sample[0] for sample in batch[key]]
            # Else, stack single item
            else:
                batch[key] = batch[key][0]
    return batch
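# Usage sketch (hypothetical shapes; `is_list`, `is_tensor` and `is_numpy` are the
# repo's type-checking helpers): a 6-camera batch of size 1 has its camera axis
# folded into the batch axis.
#
#     batch = {'rgb': torch.rand(1, 6, 3, 64, 64),
#              'intrinsics': torch.rand(1, 6, 3, 3)}
#     batch = stack_batch(batch)
#     assert batch['rgb'].shape == (6, 3, 64, 64)
#     assert batch['intrinsics'].shape == (6, 3, 3)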
def stack_sample(sample):
    """Stack a sample from multiple sensors"""
    # If there is only one sensor don't do anything
    if len(sample) == 1:
        return sample[0]
    # Otherwise, stack sample
    stacked_sample = {}
    for key in sample[0]:
        # Global keys (do not stack)
        if key in ['idx', 'dataset_idx', 'sensor_name', 'filename']:
            stacked_sample[key] = sample[0][key]
        else:
            # Stack torch tensors
            if is_tensor(sample[0][key]):
                stacked_sample[key] = torch.stack([s[key] for s in sample], 0)
            # Stack numpy arrays
            elif is_numpy(sample[0][key]):
                stacked_sample[key] = np.stack([s[key] for s in sample], 0)
            # Stack list
            elif is_list(sample[0][key]):
                stacked_sample[key] = []
                # Stack list of torch tensors
                if is_tensor(sample[0][key][0]):
                    for i in range(len(sample[0][key])):
                        stacked_sample[key].append(
                            torch.stack([s[key][i] for s in sample], 0))
                # Stack list of numpy arrays
                if is_numpy(sample[0][key][0]):
                    for i in range(len(sample[0][key])):
                        stacked_sample[key].append(
                            np.stack([s[key][i] for s in sample], 0))
    # Return stacked sample
    return stacked_sample
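# Usage sketch (hypothetical two-sensor sample; values are illustrative only):
# per-sensor tensors/arrays are stacked along a new leading axis, while global
# keys such as 'idx' are taken from the first sensor.
#
#     sample = [{'idx': 0, 'rgb': torch.rand(3, 64, 64), 'intrinsics': np.eye(3)},
#               {'idx': 0, 'rgb': torch.rand(3, 64, 64), 'intrinsics': np.eye(3)}]
#     stacked = stack_sample(sample)
#     assert stacked['rgb'].shape == (2, 3, 64, 64)
#     assert stacked['intrinsics'].shape == (2, 3, 3)
#     assert stacked['idx'] == 0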
def prep_dataset(config):
    """
    Expand dataset configuration to match split length

    Parameters
    ----------
    config : CfgNode
        Dataset configuration

    Returns
    -------
    config : CfgNode
        Updated dataset configuration
    """
    # If there is no dataset, do nothing
    if len(config.path) == 0:
        return config
    # If cameras is not a double list, make it so
    if not config.cameras or not is_list(config.cameras[0]):
        config.cameras = [config.cameras]
    # Get maximum length and expand other arguments to the same length
    n = max(len(config.split), len(config.cameras), len(config.depth_type))
    config.dataset = make_list(config.dataset, n)
    config.path = make_list(config.path, n)
    config.split = make_list(config.split, n)
    config.depth_type = make_list(config.depth_type, n)
    config.cameras = make_list(config.cameras, n)
    if 'repeat' in config:
        config.repeat = make_list(config.repeat, n)
    # Return updated configuration
    return config
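# Illustration (hypothetical configuration values; `config` would typically be a
# yacs CfgNode): with two splits and a single path/camera/depth type, every field
# is expanded to length 2 and `cameras` becomes a list of lists.
#
#     config.path       = ['/data/kitti']
#     config.split      = ['train_part0.txt', 'train_part1.txt']
#     config.dataset    = ['KITTI']
#     config.depth_type = ['velodyne']
#     config.cameras    = ['camera_01']
#     config = prep_dataset(config)
#     # config.path    -> ['/data/kitti', '/data/kitti']
#     # config.cameras -> [['camera_01'], ['camera_01']]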
def flip(tensor, flip_fn):
    """
    Flip tensors or list of tensors based on a function

    Parameters
    ----------
    tensor : torch.Tensor or list[torch.Tensor] or list[list[torch.Tensor]]
        Tensor to be flipped
    flip_fn : Function
        Flip function

    Returns
    -------
    tensor : torch.Tensor or list[torch.Tensor] or list[list[torch.Tensor]]
        Flipped tensor or list of tensors
    """
    if not is_list(tensor):
        return flip_fn(tensor)
    else:
        if not is_list(tensor[0]):
            return [flip_fn(val) for val in tensor]
        else:
            return [[flip_fn(v) for v in val] for val in tensor]
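# Usage sketch (hypothetical tensors): the same flip function is applied to a
# single tensor, a list of tensors, or a list of lists of tensors.
#
#     hflip = lambda t: torch.flip(t, [3])   # horizontal flip of [B,C,H,W] tensors
#     flipped_single = flip(torch.rand(1, 3, 8, 8), hflip)
#     flipped_list   = flip([torch.rand(1, 3, 8, 8), torch.rand(1, 3, 4, 4)], hflip)
#     flipped_nested = flip([[torch.rand(1, 3, 8, 8)], [torch.rand(1, 3, 4, 4)]], hflip)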
def __call__(self, progress):
    """
    Call for an update in the number of scales

    Parameters
    ----------
    progress : float
        Training progress percentage

    Returns
    -------
    num_scales : int
        New number of scales
    """
    if is_list(self.progressive_scaling):
        return int(self.num_scales -
                   np.searchsorted(self.progressive_scaling, progress))
    else:
        return self.num_scales
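# Worked example (hypothetical breakpoints): with self.num_scales = 4 and
# self.progressive_scaling = [0.25, 0.5, 0.75, 1.0], np.searchsorted counts how
# many breakpoints the training progress has passed, and that many scales are
# dropped:
#
#     progress = 0.0  ->  searchsorted = 0  ->  4 scales
#     progress = 0.3  ->  searchsorted = 1  ->  3 scales
#     progress = 0.6  ->  searchsorted = 2  ->  2 scales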
def make_list(var, n=None):
    """
    Wraps the input into a list, and optionally repeats it to be size n

    Parameters
    ----------
    var : Any
        Variable to be wrapped in a list
    n : int
        Number of times the wrapped variable will be repeated

    Returns
    -------
    var_list : list
        List generated from var
    """
    var = var if is_list(var) else [var]
    if n is None:
        return var
    else:
        assert len(var) == 1 or len(var) == n, 'Wrong list length for make_list'
        return var * n if len(var) == 1 else var
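# Usage sketch:
#
#     make_list(7)          # [7]
#     make_list(7, 3)       # [7, 7, 7]
#     make_list([1, 2], 2)  # [1, 2] (already the right length)
#     make_list([1, 2], 3)  # AssertionError: Wrong list length for make_list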
def log_inv_depth(key, prefix, batch, i=0):
    """
    Converts an inverse depth map from a batch for logging

    Parameters
    ----------
    key : str
        Key from data containing the inverse depth map
    prefix : str
        Prefix added to the key for logging
    batch : dict
        Dictionary containing the key
    i : int
        Batch index from which to get the inverse depth map

    Returns
    -------
    image : wandb.Image
        Wandb image ready for logging
    """
    inv_depth = batch[key] if is_dict(batch) else batch
    inv_depth = inv_depth[0] if is_list(inv_depth) else inv_depth
    return prep_image(prefix, key, viz_inv_depth(inv_depth[i]))
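# Usage sketch (hypothetical batch; `viz_inv_depth` and `prep_image` are the
# repo's visualization helpers): sample 0 of a predicted [B,1,H,W] inverse depth
# map is colorized and wrapped as a wandb.Image, presumably logged under a name
# derived from prefix + key.
#
#     batch = {'inv_depth': torch.rand(4, 1, 192, 640)}
#     image = log_inv_depth('inv_depth', 'train-', batch, i=0)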
def forward(self, image, context, inv_depths, poses,
            path_to_ego_mask, path_to_ego_mask_context,
            K, ref_K, extrinsics, ref_extrinsics, context_type,
            return_logs=False, progress=0.0):
    """
    Calculates training photometric loss.

    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Original image
    context : list of torch.Tensor [B,3,H,W]
        Context containing a list of reference images
    inv_depths : list of torch.Tensor [B,1,H,W]
        Predicted depth maps for the original image, in all scales
    poses : list of Pose
        Camera transformations between original and context images
    path_to_ego_mask : list
        Paths to the ego-mask numpy files of the original camera
    path_to_ego_mask_context : list
        Paths to the ego-mask numpy files of the context cameras
    K : torch.Tensor [B,3,3]
        Original camera intrinsics
    ref_K : torch.Tensor [B,N,3,3]
        Reference camera intrinsics (one set per context view)
    extrinsics : torch.Tensor
        Original camera extrinsics
    ref_extrinsics : torch.Tensor
        Reference camera extrinsics (one set per context view)
    context_type : list
        Camera type of each context view
    return_logs : bool
        True if logs are saved for visualization
    progress : float
        Training percentage

    Returns
    -------
    losses_and_metrics : dict
        Output dictionary
    """
    # If using progressive scaling
    self.n = self.progressive_scaling(progress)
    # Loop over all reference images
    photometric_losses = [[] for _ in range(self.n)]
    images = match_scales(image, inv_depths, self.n)

    # NB: the predicted depths are overridden here with a constant depth of 2.0
    # (debug / sanity-check code); the warping below uses inv_depths2
    inv_depths2 = []
    for k in range(len(inv_depths)):
        Btmp, C, H, W = inv_depths[k].shape
        depths2 = torch.zeros_like(inv_depths[k])
        for i in range(H):
            for j in range(W):
                depths2[0, 0, i, j] = 2.0  # (2/H)**4 * (2/W)**4 * i**2 * (H - i)**2 * j**2 * (W - j)**2 * 20
        inv_depths2.append(depth2inv(depths2))

    if is_list(context_type[0][0]):
        n_context = len(context_type[0])
    else:
        n_context = len(context_type)
    # n_context = len(context)

    device = image.get_device()
    B = len(path_to_ego_mask)
    if is_list(path_to_ego_mask[0]):
        H_full, W_full = np.load(path_to_ego_mask[0][0]).shape
    else:
        H_full, W_full = np.load(path_to_ego_mask[0]).shape

    # Getting ego masks for target and source cameras
    # Full-size masks
    ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
    ref_ego_mask_tensor = []
    for i_context in range(n_context):
        ref_ego_mask_tensor.append(torch.ones(B, 1, H_full, W_full).to(device))
    for b in range(B):
        if self.mask_ego:
            if is_list(path_to_ego_mask[b]):
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b][0])).float()
            else:
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b])).float()
            for i_context in range(n_context):
                if is_list(path_to_ego_mask_context[0][0]):
                    paths_context_ego = [p[i_context][0] for p in path_to_ego_mask_context]
                else:
                    paths_context_ego = path_to_ego_mask_context[i_context]
                ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                    np.load(paths_context_ego[b])).float()

    # Resized masks (one per scale)
    ego_mask_tensors = []
    ref_ego_mask_tensors = []
    for i_context in range(n_context):
        ref_ego_mask_tensors.append([])
    for i in range(self.n):
        Btmp, C, H, W = images[i].shape
        if W < W_full:
            # inv_scale_factor = int(W_full / W)
            # print(W_full / W)
            # ego_mask_tensors.append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
            ego_mask_tensors.append(
                interpolate_image(ego_mask_tensor,
                                  shape=(Btmp, 1, H, W),
                                  mode='nearest',
                                  align_corners=None))
            for i_context in range(n_context):
                # ref_ego_mask_tensors[i_context].append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context]))
                ref_ego_mask_tensors[i_context].append(
                    interpolate_image(ref_ego_mask_tensor[i_context],
                                      shape=(Btmp, 1, H, W),
                                      mode='nearest',
                                      align_corners=None))
        else:
            ego_mask_tensors.append(ego_mask_tensor)
            for i_context in range(n_context):
                ref_ego_mask_tensors[i_context].append(ref_ego_mask_tensor[i_context])

    # Resize the full-size context masks to the context image resolution
    for i_context in range(n_context):
        _, C, H, W = context[i_context].shape
        if W < W_full:
            inv_scale_factor = int(W_full / W)
            # ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])
            ref_ego_mask_tensor[i_context] = interpolate_image(
                ref_ego_mask_tensor[i_context],
                shape=(Btmp, 1, H, W),
                mode='nearest',
                align_corners=None)

    print(ref_extrinsics)
    for j, (ref_image, pose) in enumerate(zip(context, poses)):
        print(ref_extrinsics[j])
        ref_context_type = [c[j][0] for c in context_type] \
            if is_list(context_type[0][0]) else context_type[j]
        print(ref_context_type)
        print(pose.mat)
        # Calculate warped images
        ref_warped, ref_ego_mask_tensors_warped = self.warp_ref_image(
            inv_depths2, ref_image, ref_ego_mask_tensor[j],
            K, ref_K[:, j, :, :], pose, ref_extrinsics[j], ref_context_type)
        print(pose.mat)
        # Calculate and store image loss
        photometric_loss = self.calc_photometric_loss(ref_warped, images)

        # Debug dumps of the original, warped and target images
        # (with and without ego masks) to disk
        tt = str(int(time.time() % 10000))
        for i in range(self.n):
            B, C, H, W = images[i].shape
            for b in range(B):
                orig_PIL_0 = torch.transpose(
                    (ref_image[b, :, :, :]).unsqueeze(0).unsqueeze(4),
                    1, 4).squeeze().detach().cpu().numpy()
                orig_PIL = torch.transpose(
                    (ref_image[b, :, :, :] * ref_ego_mask_tensor[j][b, :, :, :]).unsqueeze(0).unsqueeze(4),
                    1, 4).squeeze().detach().cpu().numpy()
                warped_PIL_0 = torch.transpose(
                    (ref_warped[i][b, :, :, :]).unsqueeze(0).unsqueeze(4),
                    1, 4).squeeze().detach().cpu().numpy()
                warped_PIL = torch.transpose(
                    (ref_warped[i][b, :, :, :] * ego_mask_tensors[i][b, :, :, :]).unsqueeze(0).unsqueeze(4),
                    1, 4).squeeze().detach().cpu().numpy()
                target_PIL_0 = torch.transpose(
                    (images[i][b, :, :, :]).unsqueeze(0).unsqueeze(4),
                    1, 4).squeeze().detach().cpu().numpy()
                target_PIL = torch.transpose(
                    (images[i][b, :, :, :] * ego_mask_tensors[i][b, :, :, :]).unsqueeze(0).unsqueeze(4),
                    1, 4).squeeze().detach().cpu().numpy()

                prefix = '/home/users/vbelissen/test' + '_' + str(j) + '_' \
                    + tt + '_' + str(b) + '_' + str(i)
                cv2.imwrite(prefix + '_orig_PIL_0.png', orig_PIL_0 * 255)
                cv2.imwrite(prefix + '_orig_PIL.png', orig_PIL * 255)
                cv2.imwrite(prefix + '_warped_PIL_0.png', warped_PIL_0 * 255)
                cv2.imwrite(prefix + '_warped_PIL.png', warped_PIL * 255)
                cv2.imwrite(prefix + '_target_PIL_0.png', target_PIL_0 * 255)
                cv2.imwrite(prefix + '_target_PIL.png', target_PIL * 255)

        for i in range(self.n):
            photometric_losses[i].append(photometric_loss[i] *
                                         ego_mask_tensors[i] *
                                         ref_ego_mask_tensors_warped[i])
        # If using automask
        if self.automask_loss:
            # Calculate and store unwarped image loss
            ref_images = match_scales(ref_image, inv_depths, self.n)
            unwarped_image_loss = self.calc_photometric_loss(ref_images, images)
            for i in range(self.n):
                photometric_losses[i].append(unwarped_image_loss[i] *
                                             ego_mask_tensors[i] *
                                             ref_ego_mask_tensors[j][i])

    # Calculate reduced photometric loss
    loss = self.nonzero_reduce_photometric_loss(photometric_losses)
    # Include smoothness loss if requested
    if self.smooth_loss_weight > 0.0:
        loss += self.calc_smoothness_loss(
            [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
            [a * b for a, b in zip(images, ego_mask_tensors)])
    # Return losses and metrics
    return {
        'loss': loss.unsqueeze(0),
        'metrics': self.metrics,
    }
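# Note on the reduction step: self.nonzero_reduce_photometric_loss is defined
# elsewhere in this class and is not shown here. As a rough, hypothetical
# illustration only (not the repo's implementation), a per-pixel minimum over
# the stacked per-view losses followed by a mean over each scale could look like:
#
#     def reduce_photometric_loss_sketch(photometric_losses):
#         # photometric_losses[i] is a list of [B,1,H,W] loss maps for scale i
#         per_scale = [torch.cat(scale, 1).min(1, True)[0].mean()
#                      for scale in photometric_losses]
#         return sum(per_scale) / len(per_scale)
#
# The "nonzero" variant presumably averages only over pixels that are not
# zeroed out by the ego masks.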