def compute_pose_net(self, image, contexts):
    """Compute poses from image and a sequence of context images"""
    pose_vec = self.pose_net(image, contexts)
    return [Pose.from_vec(pose_vec[:, i], self.rotation_mode)
            for i in range(pose_vec.shape[1])]
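# Usage sketch (illustrative only, not part of the original module): assuming
# `batch['rgb']` is a [B,3,H,W] target frame and `batch['rgb_context']` a list
# of context frames, the pose network predicts one 6-DoF vector per context,
# and each vector is converted into a `Pose` object:
#
#   poses = model.compute_pose_net(batch['rgb'], batch['rgb_context'])
#   # len(poses) == len(batch['rgb_context'])
#   # poses[i].item() gives the transformation matrix for context i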
def forward(self, batch, return_logs=False, progress=0.0):
    """
    Processes a batch.

    Parameters
    ----------
    batch : dict
        Input batch
    return_logs : bool
        True if logs are stored
    progress : float
        Training progress percentage

    Returns
    -------
    output : dict
        Dictionary containing a "loss" scalar and different
        metrics and predictions for logging and downstream usage.
    """
    # Calculate predicted depth and pose output
    output = super().forward(batch, return_logs=return_logs)
    if not self.training:
        # If not training, no need for self-supervised loss
        return output
    else:
        # Otherwise, calculate self-supervised loss
        self_sup_output = self.self_supervised_loss(
            batch['rgb_original'], batch['rgb_context_original'],
            output['inv_depths'], output['poses'], batch['intrinsics'],
            return_logs=return_logs, progress=progress)

        if len(batch['rgb_context']) == 2 and self.pose_consistency_loss_weight != 0.:
            # Consistency between the poses predicted w.r.t. each context frame
            # and the pose predicted directly between the two context frames
            pose_contexts = self.compute_poses(batch['rgb_context'][0],
                                               [batch['rgb_context'][1]])
            pose_consistency_output = self.pose_consistency_loss(
                invert_pose(output['poses'][0].item()),
                output['poses'][1].item(),
                pose_contexts[0].item())
            pose_consistency_output['loss'] *= self.pose_consistency_loss_weight
            output = merge_outputs(output, pose_consistency_output)

        if self.identity_pose_loss_weight != 0:
            # Identity pose loss: pose(I_t, I_t+1) = pose(I_t+1, I_t)^-1
            pose_vec_minus1_0 = self.pose_net(batch['rgb_context'][0], batch['rgb'])
            pose_vec_minus1_0 = Pose.from_vec(pose_vec_minus1_0, self.rotation_mode)
            identity_pose_output_minus1_0 = self.identity_pose_loss(
                output['poses'][0].item(), invert_pose(pose_vec_minus1_0.item()))

            pose_vec_01 = self.pose_net(batch['rgb_context'][1], batch['rgb'])
            pose_vec_01 = Pose.from_vec(pose_vec_01, self.rotation_mode)
            identity_pose_output_01 = self.identity_pose_loss(
                output['poses'][1].item(), invert_pose(pose_vec_01.item()))

            identity_pose_output_minus1_0['loss'] *= self.identity_pose_loss_weight
            identity_pose_output_01['loss'] *= self.identity_pose_loss_weight
            output = merge_outputs(output, identity_pose_output_minus1_0,
                                   identity_pose_output_01)

        if self.temporal_consistency_loss_weight != 0:
            # Temporal consistency: D_t = proj_on_t(D_t-1) + z translation of pose_t-1_t
            temporal_consistency_losses = []
            # For each reference (context) image and its predicted pose
            for ref_image, pose in zip(batch['rgb_context'], output['poses']):
                # Warp the predicted inverse depths into the reference frame
                ref_warped_depths = warp_inv_depth(
                    output['inv_depths'], batch['intrinsics'],
                    batch['intrinsics'], pose)
                # Add the z translation of the pose to each warped depth map
                ref_warped_depths = [
                    ref_warped_depth + pose.item()[:, 2, 3].reshape(-1, 1, 1, 1)
                    for ref_warped_depth in ref_warped_depths]
                # Predict depths directly from the reference image
                ref_inv_depths = self.compute_inv_depths(ref_image)
                ref_depths = [inv2depth(ref_inv_depths[i])
                              for i in range(len(ref_inv_depths))]
                # L1 discrepancy between warped and predicted reference depths
                l1_loss = self.temporal_consistency_loss(ref_warped_depths, ref_depths)
                temporal_consistency_losses.extend([loss for loss in l1_loss])
            # Average over all scales and reference images, then weight
            temporal_consistency_output = {
                'loss': sum(temporal_consistency_losses) / len(temporal_consistency_losses)}
            temporal_consistency_output['loss'] *= self.temporal_consistency_loss_weight
            output = merge_outputs(output, temporal_consistency_output)

        # Return loss and metrics
        return {
            'loss': self_sup_output['loss'],
            **merge_outputs(output, self_sup_output),
        }
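# A minimal sketch of the rigid-transform inverse that the pose-consistency and
# identity-pose terms above rely on. `invert_pose` itself is defined elsewhere
# in the codebase; the helper below is hypothetical and only illustrates the
# assumed behaviour for [B,4,4] homogeneous matrices with rotation R and
# translation t, whose inverse is [R^T | -R^T t].
def _invert_rigid_transform_sketch(T):
    """Illustrative only: invert a batch of [B,4,4] rigid transforms."""
    import torch
    R, t = T[:, :3, :3], T[:, :3, 3:]
    R_inv = R.transpose(1, 2)   # R^-1 = R^T for rotation matrices
    t_inv = -R_inv @ t          # translation of the inverse transform
    T_inv = torch.eye(4, dtype=T.dtype, device=T.device).repeat(T.shape[0], 1, 1)
    T_inv[:, :3, :3] = R_inv
    T_inv[:, :3, 3:] = t_inv
    return T_inv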