def _calculate(self):
  # On TPU we strive to stack tensors together and perform ops once on the
  # entire stack, to save time and HBM memory. We thus stack the
  # batch-of-first-frames and the batch-of-second-frames, for both depth and
  # RGB. The batch dimension of rgb_stack and gt_depth_stack is thus twice
  # the original batch size.

  # Create stacks for features that need to be scaled into pyramids for
  # multi-scale training.
  rgb_stack_ = tf.concat(self._endpoints['rgb'], axis=0)
  flipped_rgb_stack_ = tf.concat(self._endpoints['rgb'][::-1], axis=0)
  predicted_depth_stack_ = tf.concat(
      self._endpoints['predicted_depth'], axis=0)
  flipped_predicted_depth_stack_ = tf.concat(
      self._endpoints['predicted_depth'][::-1], axis=0)
  residual_translation_ = tf.concat(
      self._endpoints['residual_translation'], axis=0)
  flipped_residual_translation_ = tf.concat(
      self._endpoints['residual_translation'][::-1], axis=0)
  intrinsics_mat_ = tf.concat(self._endpoints['intrinsics_mat'], axis=0)

  # Create pyramids from each stack to support multi-scale training.
  num_scales = self._params.num_scales
  rgb_pyramid = _get_pyramid(rgb_stack_, num_scales=num_scales)
  flipped_rgb_pyramid = _get_pyramid(
      flipped_rgb_stack_, num_scales=num_scales)
  predicted_depth_pyramid = _get_pyramid(
      predicted_depth_stack_, num_scales=num_scales)
  flipped_predicted_depth_pyramid = _get_pyramid(
      flipped_predicted_depth_stack_, num_scales=num_scales)
  residual_translation_pyramid = _get_pyramid(
      residual_translation_, num_scales=num_scales)
  flipped_residual_translation_pyramid = _get_pyramid(
      flipped_residual_translation_, num_scales=num_scales)
  intrinsics_mat_pyramid = _get_intrinsics_mat_pyramid(
      intrinsics_mat_, num_scales=num_scales)

  validity_mask_ = self._endpoints.get('validity_mask')
  if validity_mask_ is not None:
    validity_mask_ = tf.concat(validity_mask_, axis=0)
    validity_mask_pyramid = _get_pyramid(
        validity_mask_, num_scales, _min_pool2d)
  else:
    validity_mask_pyramid = [None] * num_scales

  if 'groundtruth_depth' in self._endpoints:
    gt_depth_stack_ = tf.concat(self._endpoints['groundtruth_depth'], axis=0)
    gt_depth_pyramid = _get_pyramid(gt_depth_stack_, num_scales=num_scales)
    if 'groundtruth_depth_weight' in self._endpoints:
      gt_depth_weight_stack_ = tf.concat(
          self._endpoints['groundtruth_depth_weight'], axis=0)
    else:
      gt_depth_weight_stack_ = tf.cast(
          tf.greater(gt_depth_stack_, 0.2), tf.float32)
    gt_depth_weight_pyramid = _get_pyramid(
        gt_depth_weight_stack_, num_scales=num_scales)

    if 'groundtruth_depth_filter' in self._endpoints:
      depth_filter_ = tf.concat(
          self._endpoints['groundtruth_depth_filter'], axis=0)
      depth_filter_ = tf.cast(depth_filter_, tf.float32)
      depth_filter_pyramid = _get_pyramid(
          depth_filter_, num_scales=num_scales)

  # Calculate losses at each scale. Iterate in reverse so that the final
  # output values are set at scale 0.
  for s in reversed(range(self._params.num_scales)):
    # Weight applied to all losses at this scale.
    scale_w = 1.0 / 2**s

    rgb_stack = rgb_pyramid[s]
    predicted_depth_stack = predicted_depth_pyramid[s]
    flipped_predicted_depth_stack = flipped_predicted_depth_pyramid[s]

    if 'groundtruth_depth' in self._endpoints:
      gt_depth_stack = gt_depth_pyramid[s]
      depth_error = tf.abs(gt_depth_stack - predicted_depth_stack)

      # Weigh the spatial loss if a weight map is provided. Otherwise, revert
      # to the original behavior.
      gt_depth_weight_stack = gt_depth_weight_pyramid[s]
      depth_error = depth_error * gt_depth_weight_stack

      # Optionally filter the depth map if a boolean depth filter is
      # provided. We use a TPU-friendly equivalent of tf.boolean_mask.
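      # (For a binary mask m, reduce_mean(x * m) / reduce_mean(m) equals the
      # mean of x over the elements where m == 1. For example,
      # x = [1., 2., 3., 4.] and m = [1., 1., 0., 0.] give
      # 0.75 / 0.5 == 1.5 == mean([1., 2.]), matching
      # tf.reduce_mean(tf.boolean_mask(x, m > 0)) without the data-dependent
      # output shape of tf.boolean_mask, which is problematic on TPU because
      # XLA requires static shapes.)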
      depth_filter = tf.ones_like(depth_error, tf.float32)
      if 'groundtruth_depth_filter' in self._endpoints:
        depth_filter = depth_filter_pyramid[s]
      self._losses['depth_supervision'] += scale_w * tf.reduce_mean(
          depth_error * depth_filter) / tf.reduce_mean(depth_filter)

    # In theory, the training losses should be agnostic to the global scale
    # of the predicted depth. However, in reality second-order effects can
    # lead to diverging modes
    # (https://en.wikipedia.org/wiki/Von_Neumann_stability_analysis). For
    # some reason this happens when training on TPU. Since the scale is
    # immaterial anyway, we normalize it out, and the training stabilizes.
    #
    # Note that the depth supervision term, which is sensitive to the scale,
    # was applied before this normalization. Therefore the scale of the depth
    # is learned.
    mean_depth = tf.reduce_mean(predicted_depth_stack)

    # When training starts, the depth sometimes tends to collapse to a
    # constant value, which seems to be a fixed point where the training can
    # get stuck. To discourage this collapse, we penalize the reciprocal of
    # the variance with a tiny weight. Note that the mean of
    # predicted_depth / mean_depth is one, hence we subtract 1.0.
    depth_var = tf.reduce_mean(
        tf.square(predicted_depth_stack / mean_depth - 1.0))
    self._losses['depth_variance'] = scale_w * 1.0 / depth_var

    if self._params.scale_normalization:
      predicted_depth_stack /= mean_depth
      flipped_predicted_depth_stack /= mean_depth

    disp = 1.0 / predicted_depth_stack
    mean_disp = tf.reduce_mean(disp, axis=[1, 2, 3], keep_dims=True)
    self._losses['depth_smoothing'] += (
        scale_w * regularizers.joint_bilateral_smoothing(
            disp / mean_disp, rgb_stack))
    self._output_endpoints['disparity'] = disp

    flipped_rgb_stack = flipped_rgb_pyramid[s]

    background_translation = tf.concat(
        self._endpoints['background_translation'], axis=0)
    flipped_background_translation = tf.concat(
        self._endpoints['background_translation'][::-1], axis=0)
    residual_translation = residual_translation_pyramid[s]
    flipped_residual_translation = flipped_residual_translation_pyramid[s]
    if self._params.scale_normalization:
      background_translation /= mean_depth
      flipped_background_translation /= mean_depth
      residual_translation /= mean_depth
      flipped_residual_translation /= mean_depth
    translation = residual_translation + background_translation
    flipped_translation = (
        flipped_residual_translation + flipped_background_translation)

    rotation = tf.concat(self._endpoints['rotation'], axis=0)
    flipped_rotation = tf.concat(self._endpoints['rotation'][::-1], axis=0)
    intrinsics_mat = intrinsics_mat_pyramid[s]
    intrinsics_mat_inv = intrinsics_utils.invert_intrinsics_matrix(
        intrinsics_mat)
    validity_mask = validity_mask_pyramid[s]

    transformed_depth = transform_depth_map.using_motion_vector(
        tf.squeeze(predicted_depth_stack, axis=-1), translation, rotation,
        intrinsics_mat, intrinsics_mat_inv)
    flipped_predicted_depth_stack = tf.squeeze(
        flipped_predicted_depth_stack, axis=-1)
    if self._params.target_depth_stop_gradient:
      flipped_predicted_depth_stack = tf.stop_gradient(
          flipped_predicted_depth_stack)
    # The first and second halves of the batch now contain Frame1's and
    # Frame2's depths transformed onto Frame2 and Frame1 respectively. To
    # demand consistency, we need to `flip` `predicted_depth` as well.
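    # `loss_endpoints` carries the per-term consistency losses consumed
    # below: 'depth_error', 'rgb_error', 'ssim_error', 'rotation_error' and
    # 'translation_error', plus the 'depth_proximity_weight' map that is
    # exported for visualization.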
    loss_endpoints = (
        consistency_losses.rgbd_and_motion_consistency_loss(
            transformed_depth,
            rgb_stack,
            flipped_predicted_depth_stack,
            flipped_rgb_stack,
            rotation,
            translation,
            flipped_rotation,
            flipped_translation,
            validity_mask=validity_mask))

    normalized_trans = regularizers.normalize_motion_map(
        residual_translation, translation)
    self._losses['motion_smoothing'] += scale_w * regularizers.l1smoothness(
        normalized_trans, self._weights.motion_drift == 0)
    self._losses['motion_drift'] += scale_w * regularizers.sqrt_sparsity(
        normalized_trans)
    self._losses['depth_consistency'] += (
        scale_w * loss_endpoints['depth_error'])
    self._losses['rgb_consistency'] += scale_w * loss_endpoints['rgb_error']
    self._losses['ssim'] += scale_w * 0.5 * loss_endpoints['ssim_error']
    self._losses['rotation_cycle_consistency'] += (
        scale_w * loss_endpoints['rotation_error'])
    self._losses['translation_cycle_consistency'] += (
        scale_w * loss_endpoints['translation_error'])

    self._output_endpoints['depth_proximity_weight'] = loss_endpoints[
        'depth_proximity_weight']
    self._output_endpoints['trans'] = translation
    self._output_endpoints['inv_trans'] = flipped_translation

  for k, w in self._weights.as_dict().items():
    # Multiply by 2 to match the scale of the old code.
    self._losses[k] *= w * 2

  if tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
    self._losses[tf.GraphKeys.REGULARIZATION_LOSSES] = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
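# For reference, a minimal sketch of what the `_get_pyramid` helper used
# above could look like (the actual implementation lives elsewhere in this
# module and may differ). It assumes each scale halves the spatial
# resolution, with an overridable pooling function to support the
# `_min_pool2d` variant used for validity masks:
#
#   def _get_pyramid(tensor, num_scales, pool_fn=None):
#     if pool_fn is None:
#       pool_fn = lambda t: tf.nn.avg_pool(
#           t, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
#     pyramid = [tensor]
#     for _ in range(num_scales - 1):
#       pyramid.append(pool_fn(pyramid[-1]))
#     return pyramid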
def loss_fn(features, mode, params):
  """Computes the training loss for depth and egomotion training.

  This function is written with TPU-friendliness in mind.

  Args:
    features: A dictionary mapping strings to tuples of (tf.Tensor,
      tf.Tensor), representing pairs of frames. The loss will be calculated
      from these tensors. The expected endpoints are 'rgb', 'depth',
      'intrinsics_mat' and 'intrinsics_mat_inv'.
    mode: One of tf.estimator.ModeKeys: TRAIN, PREDICT or EVAL.
    params: A dictionary with hyperparameters that optionally override
      DEFAULT_PARAMS above.

  Returns:
    A dictionary mapping each loss name (see DEFAULT_PARAMS['loss_weights']'s
    keys) to a scalar tf.Tensor representing the respective loss; their sum
    is the total training loss.

  Raises:
    ValueError: If `features` endpoints don't conform with their expected
      structure.
  """
  params = parameter_container.ParameterContainer.from_defaults_and_overrides(
      DEFAULT_PARAMS, params, is_strict=True, strictness_depth=2)

  if len(features['rgb']) != 2 or ('depth' in features and
                                   len(features['depth']) != 2):
    raise ValueError('RGB and depth endpoints are expected to be a tuple of '
                     'two tensors. Rather, they are %s.' % str(features))

  # On TPU we strive to stack tensors together and perform ops once on the
  # entire stack, to save time and HBM memory. We thus stack the
  # batch-of-first-frames and the batch-of-second-frames, for both depth and
  # RGB. The batch dimension of rgb_stack and gt_depth_stack is thus twice
  # the original batch size.
  rgb_stack = tf.concat(features['rgb'], axis=0)
  depth_predictor = depth_prediction_nets.ResNet18DepthPredictor(
      mode, params.depth_predictor_params.as_dict())
  predicted_depth = depth_predictor.predict_depth(rgb_stack)
  maybe_summary.histogram('PredictedDepth', predicted_depth)

  endpoints = {}
  endpoints['predicted_depth'] = tf.split(predicted_depth, 2, axis=0)
  endpoints['rgb'] = features['rgb']

  # We use the heuristic that depths of less than 0.2 meters are not
  # accurate. This is a rough placeholder for a confidence map that we're
  # going to have in the future.
  if 'depth' in features:
    endpoints['groundtruth_depth'] = features['depth']

  if params.cascade:
    motion_features = [
        tf.concat([features['rgb'][0], endpoints['predicted_depth'][0]],
                  axis=-1),
        tf.concat([features['rgb'][1], endpoints['predicted_depth'][1]],
                  axis=-1)
    ]
  else:
    motion_features = features['rgb']

  motion_features_stack = tf.concat(motion_features, axis=0)
  flipped_motion_features_stack = tf.concat(motion_features[::-1], axis=0)
  # Unlike `rgb_stack`, here we stacked the frames in reverse order along the
  # batch dimension. By concatenating the two stacks below along the channel
  # axis, we create the following tensor:
  #
  #         Channel dimension (3)
  #   _                                 _
  #  |  Frame1-s batch | Frame2-s batch  |____Batch
  #  |_ Frame2-s batch | Frame1-s batch _|    dimension (0)
  #
  # When we send this tensor to the motion prediction network, the first and
  # second halves of the result represent the camera motion from Frame1 to
  # Frame2 and from Frame2 to Frame1 respectively. Further below we impose a
  # loss that drives these two to be the inverses of one another
  # (cycle-consistency).
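  # Concretely, if the original batch size is B and the motion features have
  # C channels, both stacks have shape [2B, H, W, C], and `pairs` below has
  # shape [2B, H, W, 2C]: element b pairs frame1[b] with frame2[b] for b < B,
  # and frame2[b - B] with frame1[b - B] for b >= B.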
  pairs = tf.concat([motion_features_stack, flipped_motion_features_stack],
                    axis=-1)
  rot, trans, residual_translation, intrinsics_mat = (
      object_motion_nets.motion_field_net(
          images=pairs,
          weight_reg=params.motion_prediction_params.weight_reg,
          align_corners=params.motion_prediction_params.align_corners,
          auto_mask=params.motion_prediction_params.auto_mask))

  if params.motion_field_burnin_steps > 0.0:
    step = tf.to_float(tf.train.get_or_create_global_step())
    burnin_steps = tf.to_float(params.motion_field_burnin_steps)
    # The burn-in factor is 0 for the first half of the burn-in period, then
    # ramps up linearly, reaching 1.0 at `burnin_steps`.
    residual_translation *= tf.clip_by_value(
        2 * step / burnin_steps - 1, 0.0, 1.0)

  # If using ground-truth egomotion:
  if not params.learn_egomotion:
    egomotion_mat = tf.concat(features['egomotion_mat'], axis=0)
    rot = transform_utils.angles_from_matrix(egomotion_mat[:, :3, :3])
    trans = egomotion_mat[:, :3, 3]
    trans = tf.expand_dims(trans, 1)
    trans = tf.expand_dims(trans, 1)

  if params.use_mask:
    mask = tf.to_float(tf.concat(features['mask'], axis=0) > 0)
    if params.foreground_dilation > 0:
      pool_size = params.foreground_dilation * 2 + 1
      mask = tf.nn.max_pool(mask, [1, pool_size, pool_size, 1], [1] * 4,
                            'SAME')
    residual_translation *= mask

  maybe_summary.histogram('ResidualTranslation', residual_translation)
  maybe_summary.histogram('BackgroundTranslation', trans)
  maybe_summary.histogram('Rotation', rot)

  endpoints['residual_translation'] = tf.split(
      residual_translation, 2, axis=0)
  endpoints['background_translation'] = tf.split(trans, 2, axis=0)
  endpoints['rotation'] = tf.split(rot, 2, axis=0)

  if not params.learn_intrinsics.enabled:
    endpoints['intrinsics_mat'] = features['intrinsics_mat']
    endpoints['intrinsics_mat_inv'] = features['intrinsics_mat_inv']
  elif params.learn_intrinsics.per_video:
    int_mat = intrinsics_utils.create_and_fetch_intrinsics_per_video_index(
        features['video_index'][0],
        params.image_preprocessing.image_height,
        params.image_preprocessing.image_width,
        max_video_index=params.learn_intrinsics.max_number_of_videos)
    endpoints['intrinsics_mat'] = tf.concat([int_mat] * 2, axis=0)
    endpoints[
        'intrinsics_mat_inv'] = intrinsics_utils.invert_intrinsics_matrix(
            int_mat)
  else:
    # The intrinsics matrix should be the same, no matter the order of the
    # images (mat = inv_mat). It's probably a good idea to enforce this by a
    # loss, but for now we just take the average of the two halves as a
    # prediction for the intrinsics matrix.
    intrinsics_mat = 0.5 * sum(tf.split(intrinsics_mat, 2, axis=0))
    endpoints['intrinsics_mat'] = [intrinsics_mat] * 2
    endpoints['intrinsics_mat_inv'] = [
        intrinsics_utils.invert_intrinsics_matrix(intrinsics_mat)
    ] * 2

  aggregator = loss_aggregator.DepthMotionFieldLossAggregator(
      endpoints, params.loss_weights.as_dict(), params.loss_params.as_dict())

  # Add some more summaries.
  maybe_summary.image('rgb0', features['rgb'][0])
  maybe_summary.image('rgb1', features['rgb'][1])
  disp0, disp1 = tf.split(aggregator.output_endpoints['disparity'], 2, axis=0)
  maybe_summary.image('disparity0/grayscale', disp0)
  maybe_summary.image_with_colormap('disparity0/plasma',
                                    tf.squeeze(disp0, axis=3), 'plasma', 0.0)
  maybe_summary.image('disparity1/grayscale', disp1)
  maybe_summary.image_with_colormap('disparity1/plasma',
                                    tf.squeeze(disp1, axis=3), 'plasma', 0.0)
  if maybe_summary.summaries_enabled():
    if 'depth' in features:
      gt_disp0 = 1.0 / tf.maximum(features['depth'][0], 0.5)
      gt_disp1 = 1.0 / tf.maximum(features['depth'][1], 0.5)
      maybe_summary.image('disparity_gt0', gt_disp0)
      maybe_summary.image('disparity_gt1', gt_disp1)
    depth_proximity_weight0, depth_proximity_weight1 = tf.split(
        aggregator.output_endpoints['depth_proximity_weight'], 2, axis=0)
    maybe_summary.image('consistency_weight0',
                        tf.expand_dims(depth_proximity_weight0, -1))
    maybe_summary.image('consistency_weight1',
                        tf.expand_dims(depth_proximity_weight1, -1))
    maybe_summary.image('trans', aggregator.output_endpoints['trans'])
    maybe_summary.image('trans_inv', aggregator.output_endpoints['inv_trans'])
    maybe_summary.image('trans_res', endpoints['residual_translation'][0])
    maybe_summary.image('trans_res_inv', endpoints['residual_translation'][1])

  return aggregator.losses
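# A hypothetical usage sketch (not part of this module): within a
# tf.estimator model_fn, the loss dictionary returned by `loss_fn` would
# typically be summed into the scalar training loss. `make_optimizer` below
# is an illustrative stand-in for the trainer's optimizer construction:
#
#   def model_fn(features, labels, mode, params):
#     del labels  # Frame pairs arrive through `features`.
#     losses = loss_fn(features, mode, params)
#     total_loss = tf.add_n(list(losses.values()))
#     train_op = make_optimizer(params).minimize(
#         total_loss, global_step=tf.train.get_or_create_global_step())
#     return tf.estimator.EstimatorSpec(
#         mode=mode, loss=total_loss, train_op=train_op)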