Example #1
    def _calculate(self):
        # On TPU we strive to stack tensors together and perform ops once on the
        # entire stack, to save time and HBM memory. We thus stack the batch-of-
        # first-frames and the batch-of-second-frames, for both depth and RGB.
        # The batch dimension of rgb_stack and gt_depth_stack is thus twice the
        # original batch size.

        # Create stacks for features that need to be scaled into pyramids for
        # multi-scale training.
        rgb_stack_ = tf.concat(self._endpoints['rgb'], axis=0)
        flipped_rgb_stack_ = tf.concat(self._endpoints['rgb'][::-1], axis=0)
        predicted_depth_stack_ = tf.concat(self._endpoints['predicted_depth'],
                                           axis=0)
        flipped_predicted_depth_stack_ = tf.concat(
            self._endpoints['predicted_depth'][::-1], axis=0)
        residual_translation_ = tf.concat(
            self._endpoints['residual_translation'], axis=0)
        flipped_residual_translation_ = tf.concat(
            self._endpoints['residual_translation'][::-1], axis=0)
        intrinsics_mat_ = tf.concat(self._endpoints['intrinsics_mat'], axis=0)

        # Create pyramids from each stack to support multi-scale training.
        num_scales = self._params.num_scales
        rgb_pyramid = _get_pyramid(rgb_stack_, num_scales=num_scales)
        flipped_rgb_pyramid = _get_pyramid(flipped_rgb_stack_,
                                           num_scales=num_scales)
        predicted_depth_pyramid = _get_pyramid(predicted_depth_stack_,
                                               num_scales=num_scales)
        flipped_predicted_depth_pyramid = _get_pyramid(
            flipped_predicted_depth_stack_, num_scales=num_scales)
        residual_translation_pyramid = _get_pyramid(residual_translation_,
                                                    num_scales=num_scales)
        flipped_residual_translation_pyramid = _get_pyramid(
            flipped_residual_translation_, num_scales=num_scales)
        intrinsics_mat_pyramid = _get_intrinsics_mat_pyramid(
            intrinsics_mat_, num_scales=num_scales)
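        # Note (assumption, for illustration): _get_intrinsics_mat_pyramid is
        # expected to scale the focal lengths and the principal point by 1/2 per
        # level, keeping the intrinsics consistent with the downsampled images.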
        validity_mask_ = self._endpoints.get('validity_mask')
        if validity_mask_ is not None:
            validity_mask_ = tf.concat(validity_mask_, axis=0)
            validity_mask_pyramid = _get_pyramid(validity_mask_, num_scales,
                                                 _min_pool2d)
        else:
            validity_mask_pyramid = [None] * num_scales

        if 'groundtruth_depth' in self._endpoints:
            gt_depth_stack_ = tf.concat(self._endpoints['groundtruth_depth'],
                                        axis=0)
            gt_depth_pyramid = _get_pyramid(gt_depth_stack_,
                                            num_scales=num_scales)
            if 'groundtruth_depth_weight' in self._endpoints:
                gt_depth_weight_stack_ = tf.concat(
                    self._endpoints['groundtruth_depth_weight'], axis=0)
            else:
                gt_depth_weight_stack_ = tf.cast(
                    tf.greater(gt_depth_stack_, 0.2), tf.float32)
            gt_depth_weight_pyramid = _get_pyramid(gt_depth_weight_stack_,
                                                   num_scales=num_scales)

            if 'groundtruth_depth_filter' in self._endpoints:
                depth_filter_ = tf.concat(
                    self._endpoints['groundtruth_depth_filter'], axis=0)
                depth_filter_ = tf.cast(depth_filter_, tf.float32)
                depth_filter_pyramid = _get_pyramid(depth_filter_,
                                                    num_scales=num_scales)

        # Calculate losses at each scale.  Iterate in reverse so that the final
        # output values are set at scale 0.
        for s in reversed(range(self._params.num_scales)):
            # Weight applied to all losses at this scale.
            scale_w = 1.0 / 2**s
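            # For example, with num_scales=4 the loop visits s = 3, 2, 1, 0 with
            # per-scale weights 1/8, 1/4, 1/2 and 1.0, so the full-resolution
            # scale contributes the most.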

            rgb_stack = rgb_pyramid[s]
            predicted_depth_stack = predicted_depth_pyramid[s]
            flipped_predicted_depth_stack = flipped_predicted_depth_pyramid[s]

            if 'groundtruth_depth' in self._endpoints:
                gt_depth_stack = gt_depth_pyramid[s]
                depth_error = tf.abs(gt_depth_stack - predicted_depth_stack)

                # Weigh the spatial loss if a weight map is provided. Otherwise, revert
                # to original behavior.
                gt_depth_weight_stack = gt_depth_weight_pyramid[s]
                depth_error = depth_error * gt_depth_weight_stack

                # Optionally filter the depth map if a boolean depth filter is provided.
                # We use a TPU-friendly equivalent of tf.boolean_mask.
                depth_filter = tf.ones_like(depth_error, tf.float32)
                if 'groundtruth_depth_filter' in self._endpoints:
                    depth_filter = depth_filter_pyramid[s]

                self._losses['depth_supervision'] += scale_w * tf.reduce_mean(
                    depth_error * depth_filter) / tf.reduce_mean(depth_filter)
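                # For a binary depth_filter, mean(error * filter) / mean(filter)
                # is the mean error over the kept pixels only, i.e. the value
                # tf.boolean_mask would yield, but with static shapes as TPUs
                # require.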

            # In theory, the training losses should be agnostic to the global
            # scale of the predicted depth. However, in reality second-order
            # effects can lead to diverging modes
            # (https://en.wikipedia.org/wiki/Von_Neumann_stability_analysis). For
            # some reason this happens when training on TPU. Since the scale is
            # immaterial anyway, we normalize it out, and the training stabilizes.
            #
            # Note that the depth supervision term, which is sensitive to the scale,
            # was applied before this normalization. Therefore the scale of the depth
            # is learned.
            mean_depth = tf.reduce_mean(predicted_depth_stack)

            # When training starts, the depth sometimes tends to collapse to a
            # constant value, which seems to be a fixed point where the training
            # can get stuck. To discourage this collapse, we penalize the
            # reciprocal of the variance with a tiny weight. Note that the mean
            # of predicted_depth / mean_depth is one, hence we subtract 1.0.
            depth_var = tf.reduce_mean(
                tf.square(predicted_depth_stack / mean_depth - 1.0))
            self._losses['depth_variance'] = scale_w * 1.0 / depth_var
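            # Equivalently, this penalty is scale_w / Var(d / mean(d)) with
            # d = predicted_depth_stack; it shrinks as the depth map develops
            # spatial variation.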

            if self._params.scale_normalization:
                predicted_depth_stack /= mean_depth
                flipped_predicted_depth_stack /= mean_depth

            disp = 1.0 / predicted_depth_stack

            mean_disp = tf.reduce_mean(disp, axis=[1, 2, 3], keep_dims=True)
            self._losses['depth_smoothing'] += (
                scale_w * regularizers.joint_bilateral_smoothing(
                    disp / mean_disp, rgb_stack))
            self._output_endpoints['disparity'] = disp

            flipped_rgb_stack = flipped_rgb_pyramid[s]

            background_translation = tf.concat(
                self._endpoints['background_translation'], axis=0)
            flipped_background_translation = tf.concat(
                self._endpoints['background_translation'][::-1], axis=0)
            residual_translation = residual_translation_pyramid[s]
            flipped_residual_translation = flipped_residual_translation_pyramid[
                s]
            if self._params.scale_normalization:
                background_translation /= mean_depth
                flipped_background_translation /= mean_depth
                residual_translation /= mean_depth
                flipped_residual_translation /= mean_depth
            translation = residual_translation + background_translation
            flipped_translation = (flipped_residual_translation +
                                   flipped_background_translation)

            rotation = tf.concat(self._endpoints['rotation'], axis=0)
            flipped_rotation = tf.concat(self._endpoints['rotation'][::-1],
                                         axis=0)
            intrinsics_mat = intrinsics_mat_pyramid[s]
            intrinsics_mat_inv = intrinsics_utils.invert_intrinsics_matrix(
                intrinsics_mat)
            validity_mask = validity_mask_pyramid[s]

            transformed_depth = transform_depth_map.using_motion_vector(
                tf.squeeze(predicted_depth_stack, axis=-1), translation,
                rotation, intrinsics_mat, intrinsics_mat_inv)
            flipped_predicted_depth_stack = tf.squeeze(
                flipped_predicted_depth_stack, axis=-1)
            if self._params.target_depth_stop_gradient:
                flipped_predicted_depth_stack = tf.stop_gradient(
                    flipped_predicted_depth_stack)
            # The first and second halves of the batch now contain Frame1's and
            # Frame2's depths transformed onto Frame2 and Frame1 respectively. To
            # demand consistency, we need to `flip` `predicted_depth` as well.
            loss_endpoints = (
                consistency_losses.rgbd_and_motion_consistency_loss(
                    transformed_depth,
                    rgb_stack,
                    flipped_predicted_depth_stack,
                    flipped_rgb_stack,
                    rotation,
                    translation,
                    flipped_rotation,
                    flipped_translation,
                    validity_mask=validity_mask))

            normalized_trans = regularizers.normalize_motion_map(
                residual_translation, translation)
            self._losses[
                'motion_smoothing'] += scale_w * regularizers.l1smoothness(
                    normalized_trans, self._weights.motion_drift == 0)
            self._losses[
                'motion_drift'] += scale_w * regularizers.sqrt_sparsity(
                    normalized_trans)
            self._losses['depth_consistency'] += (
                scale_w * loss_endpoints['depth_error'])
            self._losses[
                'rgb_consistency'] += scale_w * loss_endpoints['rgb_error']
            self._losses[
                'ssim'] += scale_w * 0.5 * loss_endpoints['ssim_error']

            self._losses['rotation_cycle_consistency'] += (
                scale_w * loss_endpoints['rotation_error'])
            self._losses['translation_cycle_consistency'] += (
                scale_w * loss_endpoints['translation_error'])

            self._output_endpoints['depth_proximity_weight'] = loss_endpoints[
                'depth_proximity_weight']
            self._output_endpoints['trans'] = translation
            self._output_endpoints['inv_trans'] = flipped_translation

        for k, w in self._weights.as_dict().items():
            # multiply by 2 to match the scale of the old code.
            self._losses[k] *= w * 2

        if tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
            self._losses[tf.GraphKeys.REGULARIZATION_LOSSES] = tf.add_n(
                tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
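
The helpers `_get_pyramid`, `_get_intrinsics_mat_pyramid` and `_min_pool2d` are
defined elsewhere in the module and are not shown here. As a rough, hypothetical
sketch of the behavior the loop above relies on (scale 0 is full resolution, each
further scale halves the spatial size), an image-pyramid builder could look like
the following; the real helper may differ in its pooling and padding choices:

import tensorflow as tf  # TF1-style API, matching the snippets in this listing.

def _get_pyramid(image, num_scales, pooling_fn=None):
    """Returns `num_scales` tensors; entry s is `image` downsampled by 2**s."""
    if pooling_fn is None:
        # Default to 2x2 average pooling with stride 2 (NHWC layout assumed).
        pooling_fn = lambda x: tf.nn.avg_pool(
            x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
    pyramid = [image]
    for _ in range(num_scales - 1):
        pyramid.append(pooling_fn(pyramid[-1]))
    return pyramid
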
Example #2
def loss_fn(features, mode, params):
    """Computes the training loss for depth and egomotion training.

  This function is written with TPU-friendliness in mind.

  Args:
    features: A dictionary mapping strings to tuples of (tf.Tensor, tf.Tensor),
      representing pairs of frames. The loss will be calculated from these
      tensors. The expected endpoints are 'rgb', 'depth', 'intrinsics_mat'
      and 'intrinsics_mat_inv'.
    mode: One of tf.estimator.ModeKeys: TRAIN, PREDICT or EVAL.
    params: A dictionary with hyperparameters that optionally override
      DEFAULT_PARAMS above.

  Returns:
    A dictionary mapping each loss name (see DEFAULT_PARAMS['loss_weights']'s
    keys) to a scalar tf.Tensor representing the respective loss. The sum of
    these terms is the total training loss.

  Raises:
    ValueError: If the `features` endpoints do not conform to their expected
      structure.
  """
    params = parameter_container.ParameterContainer.from_defaults_and_overrides(
        DEFAULT_PARAMS, params, is_strict=True, strictness_depth=2)

    if len(features['rgb']) != 2 or 'depth' in features and len(
            features['depth']) != 2:
        raise ValueError(
            'RGB and depth endpoints are expected to be a tuple of two'
            ' tensors. Rather, they are %s.' % str(features))

    # On TPU we strive to stack tensors together and perform ops once on the
    # entire stack, to save time and HBM memory. We thus stack the batch-of-
    # first-frames and the batch-of-second-frames, for both depth and RGB. The
    # batch dimension of rgb_stack and gt_depth_stack is thus twice the original
    # batch size.
    rgb_stack = tf.concat(features['rgb'], axis=0)

    depth_predictor = depth_prediction_nets.ResNet18DepthPredictor(
        mode, params.depth_predictor_params.as_dict())
    predicted_depth = depth_predictor.predict_depth(rgb_stack)
    maybe_summary.histogram('PredictedDepth', predicted_depth)

    endpoints = {}
    endpoints['predicted_depth'] = tf.split(predicted_depth, 2, axis=0)
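    # The two halves of the split are the depth predictions for the batch of
    # first frames and the batch of second frames, respectively.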
    endpoints['rgb'] = features['rgb']

    # We use the heuristic that depths of less than 0.2 meters are not accurate.
    # This is a rough placeholder for a confidence map that we're going to have
    # in the future.
    if 'depth' in features:
        endpoints['groundtruth_depth'] = features['depth']

    if params.cascade:
        motion_features = [
            tf.concat([features['rgb'][0], endpoints['predicted_depth'][0]],
                      axis=-1),
            tf.concat([features['rgb'][1], endpoints['predicted_depth'][1]],
                      axis=-1)
        ]
    else:
        motion_features = features['rgb']

    motion_features_stack = tf.concat(motion_features, axis=0)
    flipped_motion_features_stack = tf.concat(motion_features[::-1], axis=0)
    # Unlike `rgb_stack`, here we stacked the frames in reverse order along the
    # batch dimension. By concatenating the two stacks below along the channel
    # axis, we create the following tensor:
    #
    #         Channel dimension (3)
    #   _                                 _
    #  |  Frame1-s batch | Frame2-s batch  |____Batch
    #  |_ Frame2-s batch | Frame1-s batch _|    dimension (0)
    #
    # When we send this tensor to the motion prediction network, the first and
    # second halves of the result represent the camera motion from Frame1 to
    # Frame2 and from Frame2 to Frame1 respectively. Further below we impose a
    # loss that drives these two to be the inverses of one another
    # (cycle-consistency).
    pairs = tf.concat([motion_features_stack, flipped_motion_features_stack],
                      axis=-1)
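    # As an illustration, `pairs` has twice the channel count of
    # `motion_features_stack`: 6 channels for RGB-only inputs, or 8 when
    # `params.cascade` appends the predicted depth.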

    rot, trans, residual_translation, intrinsics_mat = (
        object_motion_nets.motion_field_net(
            images=pairs,
            weight_reg=params.motion_prediction_params.weight_reg,
            align_corners=params.motion_prediction_params.align_corners,
            auto_mask=params.motion_prediction_params.auto_mask))

    # Optionally fade in the residual translation field: it is zeroed for the
    # first half of `motion_field_burnin_steps` and then ramps up linearly to
    # its full value.
    if params.motion_field_burnin_steps > 0.0:
        step = tf.to_float(tf.train.get_or_create_global_step())
        burnin_steps = tf.to_float(params.motion_field_burnin_steps)
        residual_translation *= tf.clip_by_value(2 * step / burnin_steps - 1,
                                                 0.0, 1.0)

    # If using ground-truth egomotion, override the predicted rotation and
    # translation.
    if not params.learn_egomotion:
        egomotion_mat = tf.concat(features['egomotion_mat'], axis=0)
        rot = transform_utils.angles_from_matrix(egomotion_mat[:, :3, :3])
        trans = egomotion_mat[:, :3, 3]
        trans = tf.expand_dims(trans, 1)
        trans = tf.expand_dims(trans, 1)

    if params.use_mask:
        mask = tf.to_float(tf.concat(features['mask'], axis=0) > 0)
        if params.foreground_dilation > 0:
            pool_size = params.foreground_dilation * 2 + 1
            mask = tf.nn.max_pool(mask, [1, pool_size, pool_size, 1], [1] * 4,
                                  'SAME')
        residual_translation *= mask

    maybe_summary.histogram('ResidualTranslation', residual_translation)
    maybe_summary.histogram('BackgroundTranslation', trans)
    maybe_summary.histogram('Rotation', rot)
    endpoints['residual_translation'] = tf.split(residual_translation,
                                                 2,
                                                 axis=0)
    endpoints['background_translation'] = tf.split(trans, 2, axis=0)
    endpoints['rotation'] = tf.split(rot, 2, axis=0)

    if not params.learn_intrinsics.enabled:
        endpoints['intrinsics_mat'] = features['intrinsics_mat']
        endpoints['intrinsics_mat_inv'] = features['intrinsics_mat_inv']
    elif params.learn_intrinsics.per_video:
        int_mat = intrinsics_utils.create_and_fetch_intrinsics_per_video_index(
            features['video_index'][0],
            params.image_preprocessing.image_height,
            params.image_preprocessing.image_width,
            max_video_index=params.learn_intrinsics.max_number_of_videos)
        endpoints['intrinsics_mat'] = tf.concat([int_mat] * 2, axis=0)
        endpoints[
            'intrinsics_mat_inv'] = intrinsics_utils.invert_intrinsics_matrix(
                int_mat)
    else:
        # The intrinsic matrix should be the same, no matter the order of the
        # images in a pair (the prediction from the forward pair should equal the
        # prediction from the flipped pair). It's probably a good idea to enforce
        # this by a loss, but for now we just take their average as a prediction
        # for the intrinsic matrix.
        intrinsics_mat = 0.5 * sum(tf.split(intrinsics_mat, 2, axis=0))
        endpoints['intrinsics_mat'] = [intrinsics_mat] * 2
        endpoints['intrinsics_mat_inv'] = [
            intrinsics_utils.invert_intrinsics_matrix(intrinsics_mat)
        ] * 2

    aggregator = loss_aggregator.DepthMotionFieldLossAggregator(
        endpoints, params.loss_weights.as_dict(), params.loss_params.as_dict())

    # Add some more summaries.
    maybe_summary.image('rgb0', features['rgb'][0])
    maybe_summary.image('rgb1', features['rgb'][1])
    disp0, disp1 = tf.split(aggregator.output_endpoints['disparity'],
                            2,
                            axis=0)
    maybe_summary.image('disparity0/grayscale', disp0)
    maybe_summary.image_with_colormap('disparity0/plasma',
                                      tf.squeeze(disp0, axis=3), 'plasma', 0.0)
    maybe_summary.image('disparity1/grayscale', disp1)
    maybe_summary.image_with_colormap('disparity1/plasma',
                                      tf.squeeze(disp1, axis=3), 'plasma', 0.0)
    if maybe_summary.summaries_enabled():
        if 'depth' in features:
            gt_disp0 = 1.0 / tf.maximum(features['depth'][0], 0.5)
            gt_disp1 = 1.0 / tf.maximum(features['depth'][1], 0.5)
            maybe_summary.image('disparity_gt0', gt_disp0)
            maybe_summary.image('disparity_gt1', gt_disp1)

        depth_proximity_weight0, depth_proximity_weight1 = tf.split(
            aggregator.output_endpoints['depth_proximity_weight'], 2, axis=0)
        maybe_summary.image('consistency_weight0',
                            tf.expand_dims(depth_proximity_weight0, -1))
        maybe_summary.image('consistency_weight1',
                            tf.expand_dims(depth_proximity_weight1, -1))
        maybe_summary.image('trans', aggregator.output_endpoints['trans'])
        maybe_summary.image('trans_inv',
                            aggregator.output_endpoints['inv_trans'])
        maybe_summary.image('trans_res', endpoints['residual_translation'][0])
        maybe_summary.image('trans_res_inv',
                            endpoints['residual_translation'][1])

    return aggregator.losses
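
For context, here is a minimal, hypothetical sketch of how the per-loss
dictionary returned by `loss_fn` could be reduced to a single training loss
inside a `tf.estimator` model function. The function name, the optimizer and the
learning rate below are illustrative assumptions, not part of the original code:

import tensorflow as tf  # TF1-style API, matching the snippet above.

def model_fn(features, labels, mode, params):
    """Hypothetical Estimator model_fn wrapping loss_fn (sketch only)."""
    del labels  # Unused: training is self-supervised.
    losses = loss_fn(features, mode, params)
    # Each loss is already multiplied by its weight, so the training objective
    # is simply their sum.
    total_loss = tf.add_n(list(losses.values()))
    optimizer = tf.train.AdamOptimizer(learning_rate=2e-4)  # Assumed value.
    train_op = optimizer.minimize(
        total_loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=total_loss,
                                      train_op=train_op)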