Example #1
0
def build(matcher_config):
    """Builds a matcher object based on the matcher config.

    Args:
      matcher_config: A matcher.proto object containing the config for the
        desired Matcher.

    Returns:
      Matcher based on the config.

    Raises:
      ValueError: If `matcher_config` is not a `matcher_pb2.Matcher`, if its
        `matcher_oneof` field is unset (empty matcher proto), or if a
        `bipartite_matcher` is requested under TF 2.X, where it is not
        supported.
    """
    if not isinstance(matcher_config, matcher_pb2.Matcher):
        raise ValueError('matcher_config not of type matcher_pb2.Matcher.')

    # Query the oneof case once rather than once per branch.
    matcher_type = matcher_config.WhichOneof('matcher_oneof')

    if matcher_type == 'argmax_matcher':
        matcher = matcher_config.argmax_matcher
        # When ignore_thresholds is set, both thresholds are passed as None.
        matched_threshold = unmatched_threshold = None
        if not matcher.ignore_thresholds:
            matched_threshold = matcher.matched_threshold
            unmatched_threshold = matcher.unmatched_threshold
        return argmax_matcher.ArgMaxMatcher(
            matched_threshold=matched_threshold,
            unmatched_threshold=unmatched_threshold,
            negatives_lower_than_unmatched=(
                matcher.negatives_lower_than_unmatched),
            force_match_for_each_row=matcher.force_match_for_each_row,
            use_matmul_gather=matcher.use_matmul_gather)

    if matcher_type == 'bipartite_matcher':
        if tf_version.is_tf2():
            raise ValueError('bipartite_matcher is not supported in TF 2.X')
        matcher = matcher_config.bipartite_matcher
        return bipartite_matcher.GreedyBipartiteMatcher(
            matcher.use_matmul_gather)

    raise ValueError('Empty matcher.')
    def _build_feature_map_generator(self,
                                     feature_map_layout,
                                     pool_residual=False):
        """Returns a multi-resolution feature map generator for the TF version.

        Under TF 2.X a `KerasMultiResolutionFeatureMaps` object named
        'FeatureMaps' is returned. Otherwise a closure is returned that calls
        `multi_resolution_feature_maps` on the `image_features` it receives.

        Args:
          feature_map_layout: Layout forwarded to the feature map builder.
          pool_residual: Forwarded only on the TF1 path; the Keras path does
            not receive it.

        Returns:
          A callable that maps image features to feature maps.
        """
        if not tf_version.is_tf2():
            def feature_map_generator(image_features):
                return feature_map_generators.multi_resolution_feature_maps(
                    feature_map_layout=feature_map_layout,
                    depth_multiplier=1,
                    min_depth=32,
                    insert_1x1_conv=True,
                    image_features=image_features,
                    pool_residual=pool_residual)

            return feature_map_generator

        return feature_map_generators.KerasMultiResolutionFeatureMaps(
            feature_map_layout=feature_map_layout,
            depth_multiplier=1,
            min_depth=32,
            insert_1x1_conv=True,
            freeze_batchnorm=False,
            is_training=True,
            conv_hyperparams=self._build_conv_hyperparams(),
            name='FeatureMaps')
    def _build_feature_map_generator(self,
                                     image_features,
                                     depth,
                                     use_bounded_activations=False,
                                     use_native_resize_op=False,
                                     use_explicit_padding=False,
                                     use_depthwise=False):
        """Returns an FPN top-down feature map generator for the TF version.

        Args:
          image_features: Under TF 2.X only its length is used (as
            `num_levels`). The TF1 path ignores it; the returned closure
            takes the features as its own argument instead.
          depth: Forwarded to the feature map builder.
          use_bounded_activations: Forwarded to the feature map builder.
          use_native_resize_op: Forwarded to the feature map builder.
          use_explicit_padding: Forwarded to the feature map builder.
          use_depthwise: Forwarded to the feature map builder.

        Returns:
          A `KerasFpnTopDownFeatureMaps` object named 'FeatureMaps' under
          TF 2.X, or a closure wrapping `fpn_top_down_feature_maps` otherwise.
        """
        if not tf_version.is_tf2():
            def feature_map_generator(image_features):
                return feature_map_generators.fpn_top_down_feature_maps(
                    image_features=image_features,
                    depth=depth,
                    use_depthwise=use_depthwise,
                    use_explicit_padding=use_explicit_padding,
                    use_bounded_activations=use_bounded_activations,
                    use_native_resize_op=use_native_resize_op)

            return feature_map_generator

        return feature_map_generators.KerasFpnTopDownFeatureMaps(
            num_levels=len(image_features),
            depth=depth,
            is_training=True,
            conv_hyperparams=self._build_conv_hyperparams(),
            freeze_batchnorm=False,
            use_depthwise=use_depthwise,
            use_explicit_padding=use_explicit_padding,
            use_bounded_activations=use_bounded_activations,
            use_native_resize_op=use_native_resize_op,
            scope=None,
            name='FeatureMaps')
    def test_get_expected_variable_names_with_inception_v2(self):
        """Checks the variable names created for INCEPTION_V2_LAYOUT.

        Under TF 2.X the generator's `variables` are compared against the
        Keras-style names. Otherwise the graph's trainable variables are
        compared against the slim-style names.
        """
        with test_utils.GraphContextOrNone() as graph:
            image_features = {
                'Mixed_3c': tf.random_uniform([4, 28, 28, 256],
                                              dtype=tf.float32),
                'Mixed_4c': tf.random_uniform([4, 14, 14, 576],
                                              dtype=tf.float32),
                'Mixed_5c': tf.random_uniform([4, 7, 7, 1024],
                                              dtype=tf.float32),
            }
            generator = self._build_feature_map_generator(
                feature_map_layout=INCEPTION_V2_LAYOUT)

        def graph_fn():
            return generator(image_features)

        self.execute(graph_fn, [], graph)

        if tf_version.is_tf2():
            expected_keras_variables = {
                'FeatureMaps/Mixed_5c_1_Conv2d_3_1x1_256_conv/kernel',
                'FeatureMaps/Mixed_5c_1_Conv2d_3_1x1_256_conv/bias',
                'FeatureMaps/Mixed_5c_2_Conv2d_3_3x3_s2_512_conv/kernel',
                'FeatureMaps/Mixed_5c_2_Conv2d_3_3x3_s2_512_conv/bias',
                'FeatureMaps/Mixed_5c_1_Conv2d_4_1x1_128_conv/kernel',
                'FeatureMaps/Mixed_5c_1_Conv2d_4_1x1_128_conv/bias',
                'FeatureMaps/Mixed_5c_2_Conv2d_4_3x3_s2_256_conv/kernel',
                'FeatureMaps/Mixed_5c_2_Conv2d_4_3x3_s2_256_conv/bias',
                'FeatureMaps/Mixed_5c_1_Conv2d_5_1x1_128_conv/kernel',
                'FeatureMaps/Mixed_5c_1_Conv2d_5_1x1_128_conv/bias',
                'FeatureMaps/Mixed_5c_2_Conv2d_5_3x3_s2_256_conv/kernel',
                'FeatureMaps/Mixed_5c_2_Conv2d_5_3x3_s2_256_conv/bias',
            }
            # Strip the ':0'-style output suffix from each variable name.
            actual_variable_set = {
                variable.name.split(':')[0]
                for variable in generator.variables
            }
            self.assertSetEqual(expected_keras_variables, actual_variable_set)
            return

        expected_slim_variables = {
            'Mixed_5c_1_Conv2d_3_1x1_256/weights',
            'Mixed_5c_1_Conv2d_3_1x1_256/biases',
            'Mixed_5c_2_Conv2d_3_3x3_s2_512/weights',
            'Mixed_5c_2_Conv2d_3_3x3_s2_512/biases',
            'Mixed_5c_1_Conv2d_4_1x1_128/weights',
            'Mixed_5c_1_Conv2d_4_1x1_128/biases',
            'Mixed_5c_2_Conv2d_4_3x3_s2_256/weights',
            'Mixed_5c_2_Conv2d_4_3x3_s2_256/biases',
            'Mixed_5c_1_Conv2d_5_1x1_128/weights',
            'Mixed_5c_1_Conv2d_5_1x1_128/biases',
            'Mixed_5c_2_Conv2d_5_3x3_s2_256/weights',
            'Mixed_5c_2_Conv2d_5_3x3_s2_256/biases',
        }
        with graph.as_default():
            actual_variable_set = {
                variable.op.name for variable in tf.trainable_variables()}
        self.assertSetEqual(expected_slim_variables, actual_variable_set)
 def test_build_bipartite_matcher(self):
     """A bipartite_matcher proto builds a GreedyBipartiteMatcher (TF1 only)."""
     if tf_version.is_tf2():
         self.skipTest('BipartiteMatcher unsupported in TF 2.X. Skipping.')
     proto = matcher_pb2.Matcher()
     text_format.Merge("""
   bipartite_matcher {
   }
 """, proto)
     built_matcher = matcher_builder.build(proto)
     self.assertIsInstance(
         built_matcher, bipartite_matcher.GreedyBipartiteMatcher)
Example #6
0
 def __exit__(self, ttype, value, traceback):
   """Leaves the wrapped graph context; a no-op under TF 2.X.

   Returns:
     False under TF 2.X, so exceptions are not suppressed. Otherwise, whatever
     the wrapped context's `__exit__` returns.
   """
   if not tf_version.is_tf2():
     return self.graph.__exit__(ttype, value, traceback)
   return False
Example #7
0
 def __enter__(self):
   """Enters the wrapped graph context, or does nothing under TF 2.X.

   Returns:
     None under TF 2.X; otherwise the value returned by the wrapped
     context's `__enter__`.
   """
   return None if tf_version.is_tf2() else self.graph.__enter__()
Example #8
0
 def __init__(self):
   """Prepares the optional graph context.

   Under TF 2.X `self.graph` is None. Otherwise it holds the not-yet-entered
   context manager returned by `tf.Graph().as_default()`.
   """
   self.graph = None if tf_version.is_tf2() else tf.Graph().as_default()
 def materialize_tensors(self, list_of_tensors):
   """Converts tensors to concrete values.

   Under TF 2.X each tensor's `.numpy()` is called. Otherwise the whole list
   is run in `self.cached_session()`.
   """
   if not tf_version.is_tf2():
     with self.cached_session() as session:
       return session.run(list_of_tensors)
   return [t.numpy() for t in list_of_tensors]
Example #10
0
def _build_ssd_feature_extractor(feature_extractor_config,
                                 is_training,
                                 freeze_batchnorm,
                                 reuse_weights=None):
    """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.

    Args:
      feature_extractor_config: A SSDFeatureExtractor proto config from
        ssd.proto.
      is_training: True if this feature extractor is being built for training.
      freeze_batchnorm: Whether to freeze batch norm parameters during
        training or not. When training with a small batch size (e.g. 1), it is
        desirable to freeze batch norm update and use pretrained batch norm
        params. Only forwarded to the Keras (TF 2.X) extractors.
      reuse_weights: if the feature extractor should reuse weights. Only
        forwarded to the TF1 extractors.

    Returns:
      ssd_meta_arch.SSDFeatureExtractor based on config.

    Raises:
      ValueError: On a feature extractor type missing from the registry for
        the active TF version (Keras map under TF 2.X, TF1 map otherwise).
    """
    feature_type = feature_extractor_config.type
    depth_multiplier = feature_extractor_config.depth_multiplier
    min_depth = feature_extractor_config.min_depth
    pad_to_multiple = feature_extractor_config.pad_to_multiple
    use_explicit_padding = feature_extractor_config.use_explicit_padding
    use_depthwise = feature_extractor_config.use_depthwise

    is_keras = tf_version.is_tf2()
    # Keras extractors take a KerasLayerHyperparams object; TF1 extractors
    # take the result of hyperparams_builder.build as `conv_hyperparams_fn`.
    if is_keras:
        conv_hyperparams = hyperparams_builder.KerasLayerHyperparams(
            feature_extractor_config.conv_hyperparams)
        feature_extractor_class_map = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
    else:
        conv_hyperparams = hyperparams_builder.build(
            feature_extractor_config.conv_hyperparams, is_training)
        feature_extractor_class_map = SSD_FEATURE_EXTRACTOR_CLASS_MAP
    override_base_feature_extractor_hyperparams = (
        feature_extractor_config.override_base_feature_extractor_hyperparams)

    # Validate against the registry for the active TF version, so an unknown
    # type raises the documented ValueError in both modes. Indexing the Keras
    # map directly would raise a bare KeyError instead.
    if feature_type not in feature_extractor_class_map:
        raise ValueError(
            'Unknown ssd feature_extractor: {}'.format(feature_type))
    feature_extractor_class = feature_extractor_class_map[feature_type]

    kwargs = {
        'is_training': is_training,
        'depth_multiplier': depth_multiplier,
        'min_depth': min_depth,
        'pad_to_multiple': pad_to_multiple,
        'use_explicit_padding': use_explicit_padding,
        'use_depthwise': use_depthwise,
        'override_base_feature_extractor_hyperparams':
            override_base_feature_extractor_hyperparams,
    }

    # Optional proto fields are forwarded only when explicitly set; otherwise
    # the keyword is omitted and the extractor class's own default (if any)
    # applies.
    if feature_extractor_config.HasField(
            'replace_preprocessor_with_placeholder'):
        kwargs['replace_preprocessor_with_placeholder'] = (
            feature_extractor_config.replace_preprocessor_with_placeholder)

    if feature_extractor_config.HasField('num_layers'):
        kwargs['num_layers'] = feature_extractor_config.num_layers

    if is_keras:
        kwargs.update({
            'conv_hyperparams': conv_hyperparams,
            'inplace_batchnorm_update': False,
            'freeze_batchnorm': freeze_batchnorm,
        })
    else:
        kwargs.update({
            'conv_hyperparams_fn': conv_hyperparams,
            'reuse_weights': reuse_weights,
        })

    if feature_extractor_config.HasField('fpn'):
        fpn_config = feature_extractor_config.fpn
        kwargs.update({
            'fpn_min_level': fpn_config.min_level,
            'fpn_max_level': fpn_config.max_level,
            'additional_layer_depth': fpn_config.additional_layer_depth,
        })

    if feature_extractor_config.HasField('bifpn'):
        bifpn_config = feature_extractor_config.bifpn
        kwargs.update({
            'bifpn_min_level': bifpn_config.min_level,
            'bifpn_max_level': bifpn_config.max_level,
            'bifpn_num_iterations': bifpn_config.num_iterations,
            'bifpn_num_filters': bifpn_config.num_filters,
            'bifpn_combine_method': bifpn_config.combine_method,
        })

    return feature_extractor_class(**kwargs)
Example #11
0
import numpy as np
import tensorflow.compat.v1 as tf
import tf_slim as slim
from google.protobuf import text_format

from object_detection.builders import hyperparams_builder
from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2
from object_detection.utils import tf_version


def _get_scope_key(op):
    return getattr(op, '_key_op', str(op))


@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only tests.')
class HyperparamsBuilderTest(tf.test.TestCase):
    def test_default_arg_scope_has_conv2d_op(self):
        conv_hyperparams_text_proto = """
      regularizer {
        l1_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
        conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
        text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
        scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
Example #12
0
    def _create_model(
            self,
            model_fn=ssd_meta_arch.SSDMetaArch,
            apply_hard_mining=True,
            normalize_loc_loss_by_codesize=False,
            add_background_class=True,
            random_example_sampling=False,
            expected_loss_weights=model_pb2.DetectionModel().ssd.loss.NONE,
            min_num_negative_samples=1,
            desired_negative_sampling_ratio=3,
            predict_mask=False,
            use_static_shapes=False,
            nms_max_size_per_class=5,
            calibration_mapping_value=None,
            return_raw_detections_during_predict=False):
        """Builds an SSD model wired to mock/fake components for testing.

        Args:
          model_fn: Model class to instantiate; defaults to
            ssd_meta_arch.SSDMetaArch.
          apply_hard_mining: If True, attach a losses.HardExampleMiner.
          normalize_loc_loss_by_codesize: Forwarded to `model_fn`.
          add_background_class: Forwarded to the mock box predictor and to
            `model_fn`.
          random_example_sampling: If True, attach a
            BalancedPositiveNegativeSampler with positive_fraction=0.5.
          expected_loss_weights: Only the NONE value is accepted.
          min_num_negative_samples: Not used by this implementation.
          desired_negative_sampling_ratio: Not used by this implementation.
          predict_mask: If True, pass `MockMaskHead(num_classes=1).predict`
            as `mask_prediction_fn`.
          use_static_shapes: Forwarded to the NMS function.
          nms_max_size_per_class: Used for both the per-class and the total
            NMS limits.
          calibration_mapping_value: If truthy, scores go through a
            function-approximation calibration that maps x=0.0 and x=1.0 to
            this value. Otherwise (including 0.0) scores pass through
            tf.identity.
          return_raw_detections_during_predict: Forwarded to `model_fn`.

        Returns:
          A tuple (model, num_classes, num_anchors, code_size), where
          num_classes is 1 and code_size is 4.

        Raises:
          ValueError: If `expected_loss_weights` is not NONE.
        """
        is_training = False
        num_classes = 1
        mock_anchor_generator = MockAnchorGenerator2x2()
        # Keras-flavoured mocks under TF 2.X, TF1 mocks otherwise. Object
        # creation order is kept as-is on purpose.
        use_keras = tf_version.is_tf2()
        if use_keras:
            mock_box_predictor = test_utils.MockKerasBoxPredictor(
                is_training,
                num_classes,
                add_background_class=add_background_class)
        else:
            mock_box_predictor = test_utils.MockBoxPredictor(
                is_training,
                num_classes,
                add_background_class=add_background_class)
        mock_box_coder = test_utils.MockBoxCoder()
        if use_keras:
            fake_feature_extractor = FakeSSDKerasFeatureExtractor()
        else:
            fake_feature_extractor = FakeSSDFeatureExtractor()
        mock_matcher = test_utils.MockMatcher()
        region_similarity_calculator = sim_calc.IouSimilarity()
        encode_background_as_zeros = False

        # Fake resizer: returns the image unchanged plus its shape.
        def image_resizer_fn(image):
            return [tf.identity(image), tf.shape(image)]

        classification_loss = losses.WeightedSigmoidClassificationLoss()
        localization_loss = losses.WeightedSmoothL1LocalizationLoss()
        # Very permissive NMS settings (score_thresh=-20.0, iou_thresh=1.0).
        non_max_suppression_fn = functools.partial(
            post_processing.batch_multiclass_non_max_suppression,
            score_thresh=-20.0,
            iou_thresh=1.0,
            max_size_per_class=nms_max_size_per_class,
            max_total_size=nms_max_size_per_class,
            use_static_shapes=use_static_shapes)
        score_conversion_fn = tf.identity
        calibration_config = calibration_pb2.CalibrationConfig()
        if calibration_mapping_value:
            # Both calibration points map to the same y value.
            calibration_text_proto = """
      function_approximation {
        x_y_pairs {
            x_y_pair {
              x: 0.0
              y: %f
            }
            x_y_pair {
              x: 1.0
              y: %f
            }}}""" % (calibration_mapping_value, calibration_mapping_value)
            text_format.Merge(calibration_text_proto, calibration_config)
            score_conversion_fn = (
                post_processing_builder._build_calibrated_score_converter(  # pylint: disable=protected-access
                    tf.identity, calibration_config))
        classification_loss_weight = 1.0
        localization_loss_weight = 1.0
        negative_class_weight = 1.0
        normalize_loss_by_num_matches = False

        hard_example_miner = None
        if apply_hard_mining:
            # This hard example miner is expected to be a no-op.
            hard_example_miner = losses.HardExampleMiner(
                num_hard_examples=None, iou_threshold=1.0)

        random_example_sampler = None
        if random_example_sampling:
            random_example_sampler = sampler.BalancedPositiveNegativeSampler(
                positive_fraction=0.5)

        target_assigner_instance = target_assigner.TargetAssigner(
            region_similarity_calculator,
            mock_matcher,
            mock_box_coder,
            negative_class_weight=negative_class_weight)

        # Only NONE is supported here; min_num_negative_samples and
        # desired_negative_sampling_ratio are therefore never consulted.
        model_config = model_pb2.DetectionModel()
        if expected_loss_weights == model_config.ssd.loss.NONE:
            expected_loss_weights_fn = None
        else:
            raise ValueError('Not a valid value for expected_loss_weights.')

        code_size = 4

        # Extra constructor kwargs, only populated when masks are requested.
        kwargs = {}
        if predict_mask:
            kwargs.update({
                'mask_prediction_fn':
                test_utils.MockMaskHead(num_classes=1).predict,
            })

        model = model_fn(
            is_training=is_training,
            anchor_generator=mock_anchor_generator,
            box_predictor=mock_box_predictor,
            box_coder=mock_box_coder,
            feature_extractor=fake_feature_extractor,
            encode_background_as_zeros=encode_background_as_zeros,
            image_resizer_fn=image_resizer_fn,
            non_max_suppression_fn=non_max_suppression_fn,
            score_conversion_fn=score_conversion_fn,
            classification_loss=classification_loss,
            localization_loss=localization_loss,
            classification_loss_weight=classification_loss_weight,
            localization_loss_weight=localization_loss_weight,
            normalize_loss_by_num_matches=normalize_loss_by_num_matches,
            hard_example_miner=hard_example_miner,
            target_assigner_instance=target_assigner_instance,
            add_summaries=False,
            normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
            freeze_batchnorm=False,
            inplace_batchnorm_update=False,
            add_background_class=add_background_class,
            random_example_sampler=random_example_sampler,
            expected_loss_weights_fn=expected_loss_weights_fn,
            return_raw_detections_during_predict=(
                return_raw_detections_during_predict),
            **kwargs)
        return model, num_classes, mock_anchor_generator.num_anchors(
        ), code_size
Example #13
0
class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
    def _get_model(self, box_predictor, **common_kwargs):
        """Constructs a ContextRCNNMetaArch with fixed test settings.

        Args:
          box_predictor: Passed as `second_stage_mask_rcnn_box_predictor`.
          **common_kwargs: Remaining constructor arguments, forwarded as-is.

        Returns:
          A context_rcnn_meta_arch.ContextRCNNMetaArch instance.
        """
        fixed_settings = dict(
            initial_crop_size=3,
            maxpool_kernel_size=1,
            maxpool_stride=1,
            attention_bottleneck_dimension=10,
            attention_temperature=0.2,
        )
        return context_rcnn_meta_arch.ContextRCNNMetaArch(
            second_stage_mask_rcnn_box_predictor=box_predictor,
            **fixed_settings,
            **common_kwargs)

    def _build_arg_scope_with_hyperparams(self, hyperparams_text_proto,
                                          is_training):
        """Parses a Hyperparams text proto and builds the TF1 arg-scope fn."""
        # text_format.Merge fills in and returns the message passed to it.
        parsed = text_format.Merge(hyperparams_text_proto,
                                   hyperparams_pb2.Hyperparams())
        return hyperparams_builder.build(parsed, is_training=is_training)

    def _build_keras_layer_hyperparams(self, hyperparams_text_proto):
        """Parses a Hyperparams text proto into KerasLayerHyperparams."""
        # text_format.Merge fills in and returns the message passed to it.
        parsed = text_format.Merge(hyperparams_text_proto,
                                   hyperparams_pb2.Hyperparams())
        return hyperparams_builder.KerasLayerHyperparams(parsed)

    def _get_second_stage_box_predictor_text_proto(
            self, share_box_across_classes=False):
        share_box_field = 'true' if share_box_across_classes else 'false'
        box_predictor_text_proto = """
      mask_rcnn_box_predictor {{
        fc_hyperparams {{
          op: FC
          activation: NONE
          regularizer {{
            l2_regularizer {{
              weight: 0.0005
            }}
          }}
          initializer {{
            variance_scaling_initializer {{
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }}
          }}
        }}
        share_box_across_classes: {share_box_across_classes}
      }}
    """.format(share_box_across_classes=share_box_field)
        return box_predictor_text_proto

    def _get_box_classifier_features_shape(self, image_size, batch_size,
                                           max_num_proposals,
                                           initial_crop_size, maxpool_stride,
                                           num_features):
        return (batch_size * max_num_proposals,
                initial_crop_size / maxpool_stride,
                initial_crop_size / maxpool_stride, num_features)

    def _get_second_stage_box_predictor(self,
                                        num_classes,
                                        is_training,
                                        predict_masks,
                                        masks_are_class_agnostic,
                                        share_box_across_classes=False,
                                        use_keras=False):
        """Builds the second-stage box predictor from text protos.

        Args:
          num_classes: Number of classes forwarded to the builder.
          is_training: Forwarded to the builder.
          predict_masks: If True, the mask text proto is merged into the
            config as well.
          masks_are_class_agnostic: Forwarded to the mask text-proto helper.
          share_box_across_classes: Forwarded to the base text-proto helper.
          use_keras: Selects `box_predictor_builder.build_keras` over
            `box_predictor_builder.build`.

        Returns:
          The box predictor produced by the selected builder.
        """
        config = box_predictor_pb2.BoxPredictor()
        base_text = self._get_second_stage_box_predictor_text_proto(
            share_box_across_classes)
        text_format.Merge(base_text, config)
        if predict_masks:
            mask_text = self._add_mask_to_second_stage_box_predictor_text_proto(
                masks_are_class_agnostic)
            text_format.Merge(mask_text, config)

        if not use_keras:
            return box_predictor_builder.build(hyperparams_builder.build,
                                               config,
                                               num_classes=num_classes,
                                               is_training=is_training)
        return box_predictor_builder.build_keras(
            hyperparams_builder.KerasLayerHyperparams,
            inplace_batchnorm_update=False,
            freeze_batchnorm=False,
            box_predictor_config=config,
            num_classes=num_classes,
            num_predictions_per_location_list=None,
            is_training=is_training)

    def _build_model(self,
                     is_training,
                     number_of_stages,
                     second_stage_batch_size,
                     first_stage_max_proposals=8,
                     num_classes=2,
                     hard_mining=False,
                     softmax_second_stage_classification_loss=True,
                     predict_masks=False,
                     pad_to_max_dimension=None,
                     masks_are_class_agnostic=False,
                     use_matmul_crop_and_resize=False,
                     clip_anchors_to_image=False,
                     use_matmul_gather_in_matcher=False,
                     use_static_shapes=False,
                     calibration_mapping_value=None,
                     share_box_across_classes=False,
                     return_raw_detections_during_predict=False):
        """Builds a ContextRCNNMetaArch wired to fake/test components.

        Args:
          is_training: Whether the model is built in training mode.
          number_of_stages: Forwarded to the meta-architecture.
          second_stage_batch_size: Forwarded to the meta-architecture.
          first_stage_max_proposals: Used for both the per-class and the total
            first-stage NMS limits, and forwarded to the meta-architecture.
          num_classes: Number of object classes.
          hard_mining: If True, attach a losses.HardExampleMiner.
          softmax_second_stage_classification_loss: Softmax (True) or sigmoid
            (False) second-stage classification loss.
          predict_masks: Whether the second-stage box predictor predicts masks.
          pad_to_max_dimension: If not None, the fake image resizer pads the
            image (and masks) to this square size.
          masks_are_class_agnostic: Forwarded to the mask text-proto helper.
          use_matmul_crop_and_resize: Selects the matmul-based multilevel
            crop-and-resize over the native one.
          clip_anchors_to_image: Forwarded to the meta-architecture.
          use_matmul_gather_in_matcher: Forwarded to both target assigners.
          use_static_shapes: Forwarded to the samplers, first-stage NMS, the
            post-processing config and the meta-architecture.
          calibration_mapping_value: If truthy, a function-approximation
            calibration mapping x=0.0 and x=1.0 to this value is appended to
            the post-processing config.
          share_box_across_classes: Forwarded to the box predictor text proto.
          return_raw_detections_during_predict: Forwarded to the
            meta-architecture.

        Returns:
          The model returned by `self._get_model`.
        """
        use_keras = tf_version.is_tf2()

        def image_resizer_fn(image, masks=None):
            """Fake image resizer function.

            Returns the (optionally padded) image, then the (optionally
            padded) masks if given, then the original image shape.
            """
            resized_inputs = []
            resized_image = tf.identity(image)
            if pad_to_max_dimension is not None:
                resized_image = tf.image.pad_to_bounding_box(
                    image, 0, 0, pad_to_max_dimension, pad_to_max_dimension)
            resized_inputs.append(resized_image)
            if masks is not None:
                resized_masks = tf.identity(masks)
                if pad_to_max_dimension is not None:
                    # Masks are moved to channels-last for padding and back;
                    # presumably they are [num_masks, height, width].
                    resized_masks = tf.image.pad_to_bounding_box(
                        tf.transpose(masks, [1, 2, 0]), 0, 0,
                        pad_to_max_dimension, pad_to_max_dimension)
                    resized_masks = tf.transpose(resized_masks, [2, 0, 1])
                resized_inputs.append(resized_masks)
            resized_inputs.append(tf.shape(image))
            return resized_inputs

        # anchors in this test are designed so that a subset of anchors are inside
        # the image and a subset of anchors are outside.
        first_stage_anchor_scales = (0.001, 0.005, 0.1)
        first_stage_anchor_aspect_ratios = (0.5, 1.0, 2.0)
        first_stage_anchor_strides = (1, 1)
        first_stage_anchor_generator = grid_anchor_generator.GridAnchorGenerator(
            first_stage_anchor_scales,
            first_stage_anchor_aspect_ratios,
            anchor_stride=first_stage_anchor_strides)
        first_stage_target_assigner = target_assigner.create_target_assigner(
            'FasterRCNN',
            'proposal',
            use_matmul_gather=use_matmul_gather_in_matcher)

        if use_keras:
            fake_feature_extractor = FakeFasterRCNNKerasFeatureExtractor()
        else:
            fake_feature_extractor = FakeFasterRCNNFeatureExtractor()

        first_stage_box_predictor_hyperparams_text_proto = """
      op: CONV
      activation: RELU
      regularizer {
        l2_regularizer {
          weight: 0.00004
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.03
        }
      }
    """
        if use_keras:
            first_stage_box_predictor_arg_scope_fn = (
                self._build_keras_layer_hyperparams(
                    first_stage_box_predictor_hyperparams_text_proto))
        else:
            first_stage_box_predictor_arg_scope_fn = (
                self._build_arg_scope_with_hyperparams(
                    first_stage_box_predictor_hyperparams_text_proto,
                    is_training))

        first_stage_box_predictor_kernel_size = 3
        first_stage_atrous_rate = 1
        first_stage_box_predictor_depth = 512
        first_stage_minibatch_size = 3
        first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=0.5, is_static=use_static_shapes)

        # Permissive first-stage NMS: nothing is dropped by score or overlap.
        first_stage_nms_score_threshold = -1.0
        first_stage_nms_iou_threshold = 1.0
        first_stage_non_max_suppression_fn = functools.partial(
            post_processing.batch_multiclass_non_max_suppression,
            score_thresh=first_stage_nms_score_threshold,
            iou_thresh=first_stage_nms_iou_threshold,
            max_size_per_class=first_stage_max_proposals,
            max_total_size=first_stage_max_proposals,
            use_static_shapes=use_static_shapes)

        first_stage_localization_loss_weight = 1.0
        first_stage_objectness_loss_weight = 1.0

        post_processing_config = post_processing_pb2.PostProcessing()
        post_processing_text_proto = """
      score_converter: IDENTITY
      batch_non_max_suppression {
        score_threshold: -20.0
        iou_threshold: 1.0
        max_detections_per_class: 5
        max_total_detections: 5
        use_static_shapes: """ + '{}'.format(use_static_shapes) + """
      }
    """
        if calibration_mapping_value:
            calibration_text_proto = """
      calibration_config {
        function_approximation {
          x_y_pairs {
            x_y_pair {
              x: 0.0
              y: %f
            }
            x_y_pair {
              x: 1.0
              y: %f
              }}}}""" % (calibration_mapping_value, calibration_mapping_value)
            post_processing_text_proto = (post_processing_text_proto + ' ' +
                                          calibration_text_proto)
        text_format.Merge(post_processing_text_proto, post_processing_config)
        second_stage_non_max_suppression_fn, second_stage_score_conversion_fn = (
            post_processing_builder.build(post_processing_config))

        second_stage_target_assigner = target_assigner.create_target_assigner(
            'FasterRCNN',
            'detection',
            use_matmul_gather=use_matmul_gather_in_matcher)
        second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=1.0, is_static=use_static_shapes)

        second_stage_localization_loss_weight = 1.0
        second_stage_classification_loss_weight = 1.0
        if softmax_second_stage_classification_loss:
            second_stage_classification_loss = (
                losses.WeightedSoftmaxClassificationLoss())
        else:
            second_stage_classification_loss = (
                losses.WeightedSigmoidClassificationLoss())

        hard_example_miner = None
        if hard_mining:
            hard_example_miner = losses.HardExampleMiner(
                num_hard_examples=1,
                iou_threshold=0.99,
                loss_type='both',
                cls_loss_weight=second_stage_classification_loss_weight,
                loc_loss_weight=second_stage_localization_loss_weight,
                max_negatives_per_positive=None)

        crop_and_resize_fn = (spatial_ops.multilevel_matmul_crop_and_resize
                              if use_matmul_crop_and_resize else
                              spatial_ops.multilevel_native_crop_and_resize)
        common_kwargs = {
            'is_training': is_training,
            'num_classes': num_classes,
            'image_resizer_fn': image_resizer_fn,
            'feature_extractor': fake_feature_extractor,
            'number_of_stages': number_of_stages,
            'first_stage_anchor_generator': first_stage_anchor_generator,
            'first_stage_target_assigner': first_stage_target_assigner,
            'first_stage_atrous_rate': first_stage_atrous_rate,
            'first_stage_box_predictor_arg_scope_fn':
                first_stage_box_predictor_arg_scope_fn,
            'first_stage_box_predictor_kernel_size':
                first_stage_box_predictor_kernel_size,
            'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
            'first_stage_minibatch_size': first_stage_minibatch_size,
            'first_stage_sampler': first_stage_sampler,
            'first_stage_non_max_suppression_fn':
                first_stage_non_max_suppression_fn,
            'first_stage_max_proposals': first_stage_max_proposals,
            'first_stage_localization_loss_weight':
                first_stage_localization_loss_weight,
            'first_stage_objectness_loss_weight':
                first_stage_objectness_loss_weight,
            'second_stage_target_assigner': second_stage_target_assigner,
            'second_stage_batch_size': second_stage_batch_size,
            'second_stage_sampler': second_stage_sampler,
            'second_stage_non_max_suppression_fn':
                second_stage_non_max_suppression_fn,
            'second_stage_score_conversion_fn':
                second_stage_score_conversion_fn,
            'second_stage_localization_loss_weight':
                second_stage_localization_loss_weight,
            'second_stage_classification_loss_weight':
                second_stage_classification_loss_weight,
            'second_stage_classification_loss':
                second_stage_classification_loss,
            'hard_example_miner': hard_example_miner,
            'crop_and_resize_fn': crop_and_resize_fn,
            'clip_anchors_to_image': clip_anchors_to_image,
            'use_static_shapes': use_static_shapes,
            'resize_masks': True,
            'return_raw_detections_during_predict':
                return_raw_detections_during_predict,
        }

        return self._get_model(
            self._get_second_stage_box_predictor(
                num_classes=num_classes,
                is_training=is_training,
                use_keras=use_keras,
                predict_masks=predict_masks,
                masks_are_class_agnostic=masks_are_class_agnostic,
                share_box_across_classes=share_box_across_classes),
            **common_kwargs)

    @unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
    @mock.patch.object(context_rcnn_meta_arch, 'context_rcnn_lib')
    def test_prediction_mock_tf1(self, mock_context_rcnn_lib_v1):
        """Checks that prediction goes through compute_box_context_attention.

        `context_meta_arch.context_rcnn_lib` is patched with a mock whose
        `compute_box_context_attention` returns a constant tensor of ones. The
        test then runs preprocess + predict on random inputs and asserts that
        the patched function was called exactly once.

        Args:
          mock_context_rcnn_lib_v1: mock injected by `mock.patch.object` in
            place of `context_rcnn_meta_arch.context_rcnn_lib`.
        """
        model = self._build_model(
            is_training=False,
            number_of_stages=2,
            second_stage_batch_size=6,
            num_classes=42)
        attention_fn = mock_context_rcnn_lib_v1.compute_box_context_attention
        attention_fn.return_value = tf.ones([2, 8, 3, 3, 3], tf.float32)

        # Random integer pixel values in [0, 255), cast to float32.
        raw_images = tf.random_uniform(
            (2, 20, 20, 3), minval=0, maxval=255, dtype=tf.int32)
        images = tf.cast(raw_images, dtype=tf.float32)
        preprocessed_inputs, true_image_shapes = model.preprocess(images)

        # Dict literal entries are evaluated in order, so the context features
        # tensor is created before the valid-size tensor.
        features = {
            fields.InputDataFields.context_features:
                tf.random_uniform((2, 20, 10),
                                  minval=0,
                                  maxval=255,
                                  dtype=tf.float32),
            fields.InputDataFields.valid_context_size:
                tf.random_uniform((2, ),
                                  minval=0,
                                  maxval=10,
                                  dtype=tf.int32),
        }
        side_inputs = model.get_side_inputs(features)

        model.predict(preprocessed_inputs, true_image_shapes, **side_inputs)
        attention_fn.assert_called_once()

    @parameterized.named_parameters(
        {
            'testcase_name': 'static_shapes',
            'static_shapes': True
        },
        {
            'testcase_name': 'nostatic_shapes',
            'static_shapes': False
        },
    )
    def test_prediction_end_to_end(self, static_shapes):
        """Runs prediction end to end and test the shape of the results.

        `static_shapes` toggles matmul crop-and-resize, anchor clipping,
        matmul gather in the matcher and static shapes together. Static-shape
        models run through `self.execute`; the others use `self.execute_cpu`.
        """
        with test_utils.GraphContextOrNone() as g:
            model = self._build_model(
                is_training=False,
                number_of_stages=2,
                second_stage_batch_size=6,
                use_matmul_crop_and_resize=static_shapes,
                clip_anchors_to_image=static_shapes,
                use_matmul_gather_in_matcher=static_shapes,
                use_static_shapes=static_shapes,
                num_classes=42)

        # Prediction outputs fetched by graph_fn, in return order.
        output_keys = ('rpn_box_predictor_features', 'rpn_box_encodings',
                       'refined_box_encodings', 'proposal_boxes_normalized',
                       'proposal_boxes')

        def graph_fn():
            pixels = tf.random_uniform((2, 20, 20, 3),
                                       minval=0,
                                       maxval=255,
                                       dtype=tf.int32)
            preprocessed, image_shapes = model.preprocess(
                tf.cast(pixels, dtype=tf.float32))
            context_features = tf.random_uniform((2, 20, 10),
                                                 minval=0,
                                                 maxval=255,
                                                 dtype=tf.float32)
            valid_context_size = tf.random_uniform((2, ),
                                                   minval=0,
                                                   maxval=10,
                                                   dtype=tf.int32)
            side_inputs = model.get_side_inputs({
                fields.InputDataFields.context_features: context_features,
                fields.InputDataFields.valid_context_size: valid_context_size
            })
            predictions = model.predict(preprocessed, image_shapes,
                                        **side_inputs)
            return tuple(predictions[key] for key in output_keys)

        if static_shapes:
            runner = self.execute
        else:
            runner = self.execute_cpu
        (rpn_features, rpn_encodings, refined_encodings, proposals_normalized,
         proposals) = runner(graph_fn, [], graph=g)

        for actual, expected_shape in (
                (rpn_features, [2, 20, 20, 512]),
                (rpn_encodings, [2, 3600, 4]),
                (refined_encodings, [16, 42, 4]),
                (proposals_normalized, [2, 8, 4]),
                (proposals, [2, 8, 4]),
        ):
            self.assertAllEqual(actual.shape, expected_shape)
    def test_get_expected_variable_names_with_depthwise(
            self, use_native_resize_op):
        """Checks the variable names created by a depthwise FPN generator.

        Under TF2 the generator's own `variables` (names without the ':N'
        suffix) must equal the Keras names. Under TF1 the trainable variables
        in the test graph must equal the slim names.
        """
        block_sizes = (('block2', 8), ('block3', 4), ('block4', 2),
                       ('block5', 1))
        with test_utils.GraphContextOrNone() as g:
            image_features = [
                (block_name,
                 tf.random_uniform([4, size, size, 256], dtype=tf.float32))
                for block_name, size in block_sizes
            ]
            feature_map_generator = self._build_feature_map_generator(
                image_features=image_features,
                depth=128,
                use_depthwise=True,
                use_native_resize_op=use_native_resize_op)

        def graph_fn():
            return feature_map_generator(image_features)

        self.execute(graph_fn, [], g)

        # Four projection layers (1..4) and three depthwise smoothing
        # layers (1..3).
        expected_slim_variables = {
            'projection_{}/{}'.format(index, param)
            for index in range(1, 5)
            for param in ('weights', 'biases')
        } | {
            'smoothing_{}/{}'.format(index, param)
            for index in range(1, 4)
            for param in ('depthwise_weights', 'pointwise_weights', 'biases')
        }

        keras_prefix = 'FeatureMaps/top_down/'
        expected_keras_variables = {
            keras_prefix + 'projection_{}/{}'.format(index, param)
            for index in range(1, 5)
            for param in ('kernel', 'bias')
        } | {
            keras_prefix + 'smoothing_{}_depthwise_conv/{}'.format(index, param)
            for index in range(1, 4)
            for param in ('depthwise_kernel', 'pointwise_kernel', 'bias')
        }

        if not tf_version.is_tf2():
            with g.as_default():
                actual_variable_set = {
                    var.op.name for var in tf.trainable_variables()
                }
            self.assertSetEqual(expected_slim_variables, actual_variable_set)
            return

        actual_variable_set = {
            var.name.split(':')[0] for var in feature_map_generator.variables
        }
        self.assertSetEqual(expected_keras_variables, actual_variable_set)
@parameterized.parameters({'use_native_resize_op': True},
                          {'use_native_resize_op': False})
class FPNFeatureMapGeneratorTest(test_case.TestCase, parameterized.TestCase):
    """Tests for the FPN top-down feature map generators.

    Under TF2 the Keras `KerasFpnTopDownFeatureMaps` layer is exercised.
    Under TF1 the slim `fpn_top_down_feature_maps` function is used instead.

    The class-level `parameterized.parameters` decorator runs every `test_*`
    method once with `use_native_resize_op=True` and once with `False`. Each
    test method takes that argument, so none of them can run without the
    decorator.
    """

    # Output shapes expected for the inputs built by
    # _build_random_image_features when the generator has depth=128. Spatial
    # size matches each input block; channels equal the generator depth.
    _EXPECTED_FEATURE_MAP_SHAPES = {
        'top_down_block2': (4, 8, 8, 128),
        'top_down_block3': (4, 4, 4, 128),
        'top_down_block4': (4, 2, 2, 128),
        'top_down_block5': (4, 1, 1, 128),
    }

    # clip_by_value ops expected in the TF1 graph when
    # use_bounded_activations=True.
    _CLIP_BY_VALUE_OPS = (
        'top_down/clip_by_value',
        'top_down/clip_by_value_1',
        'top_down/clip_by_value_2',
        'top_down/clip_by_value_3',
        'top_down/clip_by_value_4',
        'top_down/clip_by_value_5',
        'top_down/clip_by_value_6',
    )

    def _build_conv_hyperparams(self):
        """Returns Keras layer hyperparams parsed from a small text proto.

        The proto sets an `l2_regularizer` and a
        `truncated_normal_initializer`, both with default fields.
        """
        conv_hyperparams = hyperparams_pb2.Hyperparams()
        conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
        text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams)
        return hyperparams_builder.KerasLayerHyperparams(conv_hyperparams)

    def _build_feature_map_generator(self,
                                     image_features,
                                     depth,
                                     use_bounded_activations=False,
                                     use_native_resize_op=False,
                                     use_explicit_padding=False,
                                     use_depthwise=False):
        """Builds an FPN top-down feature map generator.

        Args:
          image_features: list of (name, tensor) pairs. The Keras path uses
            only its length (as `num_levels`). The slim path ignores it here
            and receives the features when the returned callable is invoked.
          depth: output channel depth of the generated feature maps.
          use_bounded_activations: forwarded to the generator.
          use_native_resize_op: forwarded to the generator.
          use_explicit_padding: forwarded to the generator.
          use_depthwise: forwarded to the generator.

        Returns:
          Under TF2, a `KerasFpnTopDownFeatureMaps` layer named 'FeatureMaps'.
          Under TF1, a function wrapping `fpn_top_down_feature_maps`. Either
          one is called with the list of image features.
        """
        if tf_version.is_tf2():
            return feature_map_generators.KerasFpnTopDownFeatureMaps(
                num_levels=len(image_features),
                depth=depth,
                is_training=True,
                conv_hyperparams=self._build_conv_hyperparams(),
                freeze_batchnorm=False,
                use_depthwise=use_depthwise,
                use_explicit_padding=use_explicit_padding,
                use_bounded_activations=use_bounded_activations,
                use_native_resize_op=use_native_resize_op,
                scope=None,
                name='FeatureMaps',
            )

        def feature_map_generator(image_features):
            return feature_map_generators.fpn_top_down_feature_maps(
                image_features=image_features,
                depth=depth,
                use_depthwise=use_depthwise,
                use_explicit_padding=use_explicit_padding,
                use_bounded_activations=use_bounded_activations,
                use_native_resize_op=use_native_resize_op)

        return feature_map_generator

    def _build_random_image_features(self):
        """Returns four (name, tensor) pairs of random float32 features.

        Blocks 'block2'..'block5' have shapes [4, 8, 8, 256], [4, 4, 4, 256],
        [4, 2, 2, 256] and [4, 1, 1, 256]. Tests call this inside their
        graph context so that, under TF1, the ops are created in that graph.
        """
        return [
            ('block2', tf.random_uniform([4, 8, 8, 256], dtype=tf.float32)),
            ('block3', tf.random_uniform([4, 4, 4, 256], dtype=tf.float32)),
            ('block4', tf.random_uniform([4, 2, 2, 256], dtype=tf.float32)),
            ('block5', tf.random_uniform([4, 1, 1, 256], dtype=tf.float32)),
        ]

    def _assert_expected_feature_map_shapes(self, feature_map_generator,
                                            image_features, graph):
        """Runs the generator and compares output shapes to the expected ones.

        Args:
          feature_map_generator: callable from _build_feature_map_generator.
          image_features: the (name, tensor) list passed to the generator.
          graph: value yielded by `test_utils.GraphContextOrNone`.
        """

        def graph_fn():
            return feature_map_generator(image_features)

        out_feature_maps = self.execute(graph_fn, [], graph)
        out_feature_map_shapes = {
            key: value.shape
            for key, value in out_feature_maps.items()
        }
        self.assertDictEqual(self._EXPECTED_FEATURE_MAP_SHAPES,
                             out_feature_map_shapes)

    def _assert_variable_names(self, feature_map_generator, image_features,
                               graph, expected_slim_variables,
                               expected_keras_variables):
        """Runs the generator once, then checks the variables it created.

        Args:
          feature_map_generator: callable from _build_feature_map_generator.
          image_features: the (name, tensor) list passed to the generator.
          graph: value yielded by `test_utils.GraphContextOrNone`. It is only
            used as a TF1 graph (via `as_default`).
          expected_slim_variables: set of op names expected under TF1.
          expected_keras_variables: set of variable names (without the ':N'
            suffix) expected under TF2.
        """

        def graph_fn():
            return feature_map_generator(image_features)

        # Execute first so the generator has been called and its variables
        # exist before they are inspected.
        self.execute(graph_fn, [], graph)

        if tf_version.is_tf2():
            actual_variable_set = {
                var.name.split(':')[0]
                for var in feature_map_generator.variables
            }
            self.assertSetEqual(expected_keras_variables, actual_variable_set)
        else:
            with graph.as_default():
                actual_variable_set = {
                    var.op.name for var in tf.trainable_variables()
                }
            self.assertSetEqual(expected_slim_variables, actual_variable_set)

    def test_get_expected_feature_map_shapes(self, use_native_resize_op):
        with test_utils.GraphContextOrNone() as g:
            image_features = self._build_random_image_features()
            feature_map_generator = self._build_feature_map_generator(
                image_features=image_features,
                depth=128,
                use_native_resize_op=use_native_resize_op)

        self._assert_expected_feature_map_shapes(feature_map_generator,
                                                 image_features, g)

    def test_get_expected_feature_map_shapes_with_explicit_padding(
            self, use_native_resize_op):
        with test_utils.GraphContextOrNone() as g:
            image_features = self._build_random_image_features()
            feature_map_generator = self._build_feature_map_generator(
                image_features=image_features,
                depth=128,
                use_explicit_padding=True,
                use_native_resize_op=use_native_resize_op)

        self._assert_expected_feature_map_shapes(feature_map_generator,
                                                 image_features, g)

    @unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
    def test_use_bounded_activations_add_operations(self,
                                                    use_native_resize_op):
        with test_utils.GraphContextOrNone() as g:
            image_features = self._build_random_image_features()
            feature_map_generator = self._build_feature_map_generator(
                image_features=image_features,
                depth=128,
                use_bounded_activations=True,
                use_native_resize_op=use_native_resize_op)

        def graph_fn():
            return feature_map_generator(image_features)

        self.execute(graph_fn, [], g)
        op_names = {op.name for op in g.get_operations()}
        # assertDictContainsSubset is deprecated and was removed in Python
        # 3.12. A set difference gives the same "every expected op exists"
        # check.
        missing_operations = set(self._CLIP_BY_VALUE_OPS) - op_names
        self.assertEqual(set(), missing_operations)

    @unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
    def test_use_bounded_activations_clip_value(self, use_native_resize_op):
        tf_graph = tf.Graph()
        with tf_graph.as_default():
            image_features = [
                ('block2', 255 * tf.ones([4, 8, 8, 256], dtype=tf.float32)),
                ('block3', 255 * tf.ones([4, 4, 4, 256], dtype=tf.float32)),
                ('block4', 255 * tf.ones([4, 2, 2, 256], dtype=tf.float32)),
                ('block5', 255 * tf.ones([4, 1, 1, 256], dtype=tf.float32))
            ]
            feature_map_generator = self._build_feature_map_generator(
                image_features=image_features,
                depth=128,
                use_bounded_activations=True,
                use_native_resize_op=use_native_resize_op)
            feature_map_generator(image_features)

            # Gathers activation tensors before and after clip_by_value
            # operations. The pre-clip tensor is the first input of each
            # op's 'Minimum' node.
            activations = {}
            for clip_by_value_op in self._CLIP_BY_VALUE_OPS:
                clip_input_tensor = tf_graph.get_operation_by_name(
                    '{}/Minimum'.format(clip_by_value_op)).inputs[0]
                clip_output_tensor = tf_graph.get_tensor_by_name(
                    '{}:0'.format(clip_by_value_op))
                activations['before_{}'.format(clip_by_value_op)] = (
                    clip_input_tensor)
                activations['after_{}'.format(clip_by_value_op)] = (
                    clip_output_tensor)

            expected_lower_bound = -feature_map_generators.ACTIVATION_BOUND
            expected_upper_bound = feature_map_generators.ACTIVATION_BOUND
            init_op = tf.global_variables_initializer()
            with self.test_session() as session:
                session.run(init_op)
                activations_output = session.run(activations)
                for clip_by_value_op in self._CLIP_BY_VALUE_OPS:
                    # Before clipping, activations are beyond the expected
                    # bound because of the large input image_features values.
                    activations_before_clipping = activations_output[
                        'before_{}'.format(clip_by_value_op)]
                    self.assertLessEqual(np.amin(activations_before_clipping),
                                         expected_lower_bound)
                    self.assertGreaterEqual(
                        np.amax(activations_before_clipping),
                        expected_upper_bound)

                    # After clipping, activations lie within the bounds.
                    activations_after_clipping = activations_output[
                        'after_{}'.format(clip_by_value_op)]
                    self.assertGreaterEqual(
                        np.amin(activations_after_clipping),
                        expected_lower_bound)
                    self.assertLessEqual(np.amax(activations_after_clipping),
                                         expected_upper_bound)

    def test_get_expected_feature_map_shapes_with_depthwise(
            self, use_native_resize_op):
        with test_utils.GraphContextOrNone() as g:
            image_features = self._build_random_image_features()
            feature_map_generator = self._build_feature_map_generator(
                image_features=image_features,
                depth=128,
                use_depthwise=True,
                use_native_resize_op=use_native_resize_op)

        self._assert_expected_feature_map_shapes(feature_map_generator,
                                                 image_features, g)

    def test_get_expected_variable_names(self, use_native_resize_op):
        with test_utils.GraphContextOrNone() as g:
            image_features = self._build_random_image_features()
            feature_map_generator = self._build_feature_map_generator(
                image_features=image_features,
                depth=128,
                use_native_resize_op=use_native_resize_op)

        expected_slim_variables = {
            'projection_1/weights',
            'projection_1/biases',
            'projection_2/weights',
            'projection_2/biases',
            'projection_3/weights',
            'projection_3/biases',
            'projection_4/weights',
            'projection_4/biases',
            'smoothing_1/weights',
            'smoothing_1/biases',
            'smoothing_2/weights',
            'smoothing_2/biases',
            'smoothing_3/weights',
            'smoothing_3/biases',
        }

        expected_keras_variables = {
            'FeatureMaps/top_down/projection_1/kernel',
            'FeatureMaps/top_down/projection_1/bias',
            'FeatureMaps/top_down/projection_2/kernel',
            'FeatureMaps/top_down/projection_2/bias',
            'FeatureMaps/top_down/projection_3/kernel',
            'FeatureMaps/top_down/projection_3/bias',
            'FeatureMaps/top_down/projection_4/kernel',
            'FeatureMaps/top_down/projection_4/bias',
            'FeatureMaps/top_down/smoothing_1_conv/kernel',
            'FeatureMaps/top_down/smoothing_1_conv/bias',
            'FeatureMaps/top_down/smoothing_2_conv/kernel',
            'FeatureMaps/top_down/smoothing_2_conv/bias',
            'FeatureMaps/top_down/smoothing_3_conv/kernel',
            'FeatureMaps/top_down/smoothing_3_conv/bias',
        }

        self._assert_variable_names(feature_map_generator, image_features, g,
                                    expected_slim_variables,
                                    expected_keras_variables)

    def test_get_expected_variable_names_with_depthwise(
            self, use_native_resize_op):
        with test_utils.GraphContextOrNone() as g:
            image_features = self._build_random_image_features()
            feature_map_generator = self._build_feature_map_generator(
                image_features=image_features,
                depth=128,
                use_depthwise=True,
                use_native_resize_op=use_native_resize_op)

        expected_slim_variables = {
            'projection_1/weights',
            'projection_1/biases',
            'projection_2/weights',
            'projection_2/biases',
            'projection_3/weights',
            'projection_3/biases',
            'projection_4/weights',
            'projection_4/biases',
            'smoothing_1/depthwise_weights',
            'smoothing_1/pointwise_weights',
            'smoothing_1/biases',
            'smoothing_2/depthwise_weights',
            'smoothing_2/pointwise_weights',
            'smoothing_2/biases',
            'smoothing_3/depthwise_weights',
            'smoothing_3/pointwise_weights',
            'smoothing_3/biases',
        }

        expected_keras_variables = {
            'FeatureMaps/top_down/projection_1/kernel',
            'FeatureMaps/top_down/projection_1/bias',
            'FeatureMaps/top_down/projection_2/kernel',
            'FeatureMaps/top_down/projection_2/bias',
            'FeatureMaps/top_down/projection_3/kernel',
            'FeatureMaps/top_down/projection_3/bias',
            'FeatureMaps/top_down/projection_4/kernel',
            'FeatureMaps/top_down/projection_4/bias',
            'FeatureMaps/top_down/smoothing_1_depthwise_conv/depthwise_kernel',
            'FeatureMaps/top_down/smoothing_1_depthwise_conv/pointwise_kernel',
            'FeatureMaps/top_down/smoothing_1_depthwise_conv/bias',
            'FeatureMaps/top_down/smoothing_2_depthwise_conv/depthwise_kernel',
            'FeatureMaps/top_down/smoothing_2_depthwise_conv/pointwise_kernel',
            'FeatureMaps/top_down/smoothing_2_depthwise_conv/bias',
            'FeatureMaps/top_down/smoothing_3_depthwise_conv/depthwise_kernel',
            'FeatureMaps/top_down/smoothing_3_depthwise_conv/pointwise_kernel',
            'FeatureMaps/top_down/smoothing_3_depthwise_conv/bias',
        }

        self._assert_variable_names(feature_map_generator, image_features, g,
                                    expected_slim_variables,
                                    expected_keras_variables)


@unittest.skipIf(tf_version.is_tf2(),
                 'Eval Metrics ops are supported in TF1.X '
                 'only.')
class ObjectDetectionEvaluatorTest(tf.test.TestCase, parameterized.TestCase):
    def setUp(self):
        self.categories = [{
            'id': 1,
            'name': 'person'
        }, {
            'id': 2,
            'name': 'dog'
        }, {
            'id': 3,
            'name': 'cat'
        }]
        self.od_eval = object_detection_evaluation.ObjectDetectionEvaluator(
Exemple #17
0
 def is_tf2(self):
     """Reports whether the TensorFlow 2.x code path is active.

     Returns:
       Whatever `tf_version.is_tf2()` reports for the running TensorFlow.
     """
     tf2_enabled = tf_version.is_tf2()
     return tf2_enabled
Exemple #18
0
    def _build_model(self,
                     is_training,
                     number_of_stages,
                     second_stage_batch_size,
                     first_stage_max_proposals=8,
                     num_classes=2,
                     hard_mining=False,
                     softmax_second_stage_classification_loss=True,
                     predict_masks=False,
                     pad_to_max_dimension=None,
                     masks_are_class_agnostic=False,
                     use_matmul_crop_and_resize=False,
                     clip_anchors_to_image=False,
                     use_matmul_gather_in_matcher=False,
                     use_static_shapes=False,
                     calibration_mapping_value=None,
                     share_box_across_classes=False,
                     return_raw_detections_during_predict=False):
        """Builds a small Faster R-CNN style test model from fake components.

        When `tf_version.is_tf2()` is true the Keras variants are used (a fake
        Keras feature extractor and Keras layer hyperparams); otherwise the
        slim-based fakes and an arg-scope function are used.

        Args:
          is_training: Forwarded to the model and the second-stage box
            predictor (and, on the non-Keras path, to the first-stage box
            predictor arg scope).
          number_of_stages: Forwarded to the model.
          second_stage_batch_size: Forwarded to the model.
          first_stage_max_proposals: Per-class and total cap for first-stage
            NMS; also forwarded to the model.
          num_classes: Number of classes for the model and box predictor.
          hard_mining: If True, a `losses.HardExampleMiner` is attached.
          softmax_second_stage_classification_loss: If True the second stage
            uses `WeightedSoftmaxClassificationLoss`, otherwise
            `WeightedSigmoidClassificationLoss`.
          predict_masks: Forwarded to `_get_second_stage_box_predictor`.
          pad_to_max_dimension: If not None, the fake image resizer pads the
            image (and masks) into a square canvas of this size, anchored at
            offset (0, 0).
          masks_are_class_agnostic: Forwarded to
            `_get_second_stage_box_predictor`.
          use_matmul_crop_and_resize: Selects `ops.matmul_crop_and_resize`
            instead of `ops.native_crop_and_resize`.
          clip_anchors_to_image: Forwarded to the model.
          use_matmul_gather_in_matcher: Forwarded to both target assigners.
          use_static_shapes: Forwarded to the samplers, both NMS functions and
            the model.
          calibration_mapping_value: If truthy, a function-approximation
            calibration mapping both x=0.0 and x=1.0 to this value is added to
            the second-stage post-processing config.
          share_box_across_classes: Forwarded to
            `_get_second_stage_box_predictor`.
          return_raw_detections_during_predict: Forwarded to the model.

        Returns:
          The model returned by `self._get_model`.
        """
        use_keras = tf_version.is_tf2()

        def image_resizer_fn(image, masks=None):
            """Fake image resizer function."""
            resized_inputs = []
            resized_image = tf.identity(image)
            if pad_to_max_dimension is not None:
                resized_image = tf.image.pad_to_bounding_box(
                    image, 0, 0, pad_to_max_dimension, pad_to_max_dimension)
            resized_inputs.append(resized_image)
            if masks is not None:
                resized_masks = tf.identity(masks)
                if pad_to_max_dimension is not None:
                    # Move the leading axis last so the two spatial axes are
                    # padded, then restore the layout. Assumes masks are
                    # [num_masks, height, width] -- TODO confirm with callers.
                    resized_masks = tf.image.pad_to_bounding_box(
                        tf.transpose(masks, [1, 2, 0]), 0, 0,
                        pad_to_max_dimension, pad_to_max_dimension)
                    resized_masks = tf.transpose(resized_masks, [2, 0, 1])
                resized_inputs.append(resized_masks)
            resized_inputs.append(tf.shape(image))
            return resized_inputs

        # anchors in this test are designed so that a subset of anchors are inside
        # the image and a subset of anchors are outside.
        first_stage_anchor_scales = (0.001, 0.005, 0.1)
        first_stage_anchor_aspect_ratios = (0.5, 1.0, 2.0)
        first_stage_anchor_strides = (1, 1)
        first_stage_anchor_generator = grid_anchor_generator.GridAnchorGenerator(
            first_stage_anchor_scales,
            first_stage_anchor_aspect_ratios,
            anchor_stride=first_stage_anchor_strides)
        first_stage_target_assigner = target_assigner.create_target_assigner(
            'FasterRCNN',
            'proposal',
            use_matmul_gather=use_matmul_gather_in_matcher)

        if use_keras:
            fake_feature_extractor = FakeFasterRCNNKerasFeatureExtractor()
        else:
            fake_feature_extractor = FakeFasterRCNNFeatureExtractor()

        first_stage_box_predictor_hyperparams_text_proto = """
      op: CONV
      activation: RELU
      regularizer {
        l2_regularizer {
          weight: 0.00004
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.03
        }
      }
    """
        if use_keras:
            first_stage_box_predictor_arg_scope_fn = (
                self._build_keras_layer_hyperparams(
                    first_stage_box_predictor_hyperparams_text_proto))
        else:
            first_stage_box_predictor_arg_scope_fn = (
                self._build_arg_scope_with_hyperparams(
                    first_stage_box_predictor_hyperparams_text_proto,
                    is_training))

        first_stage_box_predictor_kernel_size = 3
        first_stage_atrous_rate = 1
        first_stage_box_predictor_depth = 512
        first_stage_minibatch_size = 3
        first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=0.5, is_static=use_static_shapes)

        # A score threshold of -1.0 and an IoU threshold of 1.0 mean first-stage
        # NMS filters nothing beyond the max-proposal caps.
        first_stage_nms_score_threshold = -1.0
        first_stage_nms_iou_threshold = 1.0
        first_stage_non_max_suppression_fn = functools.partial(
            post_processing.batch_multiclass_non_max_suppression,
            score_thresh=first_stage_nms_score_threshold,
            iou_thresh=first_stage_nms_iou_threshold,
            max_size_per_class=first_stage_max_proposals,
            max_total_size=first_stage_max_proposals,
            use_static_shapes=use_static_shapes)

        first_stage_localization_loss_weight = 1.0
        first_stage_objectness_loss_weight = 1.0

        post_processing_config = post_processing_pb2.PostProcessing()
        post_processing_text_proto = """
      score_converter: IDENTITY
      batch_non_max_suppression {
        score_threshold: -20.0
        iou_threshold: 1.0
        max_detections_per_class: 5
        max_total_detections: 5
        use_static_shapes: """ + '{}'.format(use_static_shapes) + """
      }
    """
        # NOTE(review): a value of 0.0 is falsy and therefore skips calibration
        # just like None; use `is not None` if 0.0 must ever be tested.
        if calibration_mapping_value:
            calibration_text_proto = """
      calibration_config {
        function_approximation {
          x_y_pairs {
            x_y_pair {
              x: 0.0
              y: %f
            }
            x_y_pair {
              x: 1.0
              y: %f
              }}}}""" % (calibration_mapping_value, calibration_mapping_value)
            post_processing_text_proto = (post_processing_text_proto + ' ' +
                                          calibration_text_proto)
        text_format.Merge(post_processing_text_proto, post_processing_config)
        second_stage_non_max_suppression_fn, second_stage_score_conversion_fn = (
            post_processing_builder.build(post_processing_config))

        second_stage_target_assigner = target_assigner.create_target_assigner(
            'FasterRCNN',
            'detection',
            use_matmul_gather=use_matmul_gather_in_matcher)
        second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=1.0, is_static=use_static_shapes)

        second_stage_localization_loss_weight = 1.0
        second_stage_classification_loss_weight = 1.0
        if softmax_second_stage_classification_loss:
            second_stage_classification_loss = (
                losses.WeightedSoftmaxClassificationLoss())
        else:
            second_stage_classification_loss = (
                losses.WeightedSigmoidClassificationLoss())

        hard_example_miner = None
        if hard_mining:
            hard_example_miner = losses.HardExampleMiner(
                num_hard_examples=1,
                iou_threshold=0.99,
                loss_type='both',
                cls_loss_weight=second_stage_classification_loss_weight,
                loc_loss_weight=second_stage_localization_loss_weight,
                max_negatives_per_positive=None)

        crop_and_resize_fn = (ops.matmul_crop_and_resize
                              if use_matmul_crop_and_resize else
                              ops.native_crop_and_resize)
        common_kwargs = {
            'is_training': is_training,
            'num_classes': num_classes,
            'image_resizer_fn': image_resizer_fn,
            'feature_extractor': fake_feature_extractor,
            'number_of_stages': number_of_stages,
            'first_stage_anchor_generator': first_stage_anchor_generator,
            'first_stage_target_assigner': first_stage_target_assigner,
            'first_stage_atrous_rate': first_stage_atrous_rate,
            'first_stage_box_predictor_arg_scope_fn':
                first_stage_box_predictor_arg_scope_fn,
            'first_stage_box_predictor_kernel_size':
                first_stage_box_predictor_kernel_size,
            'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
            'first_stage_minibatch_size': first_stage_minibatch_size,
            'first_stage_sampler': first_stage_sampler,
            'first_stage_non_max_suppression_fn':
                first_stage_non_max_suppression_fn,
            'first_stage_max_proposals': first_stage_max_proposals,
            'first_stage_localization_loss_weight':
                first_stage_localization_loss_weight,
            'first_stage_objectness_loss_weight':
                first_stage_objectness_loss_weight,
            'second_stage_target_assigner': second_stage_target_assigner,
            'second_stage_batch_size': second_stage_batch_size,
            'second_stage_sampler': second_stage_sampler,
            'second_stage_non_max_suppression_fn':
                second_stage_non_max_suppression_fn,
            'second_stage_score_conversion_fn':
                second_stage_score_conversion_fn,
            'second_stage_localization_loss_weight':
                second_stage_localization_loss_weight,
            'second_stage_classification_loss_weight':
                second_stage_classification_loss_weight,
            'second_stage_classification_loss':
                second_stage_classification_loss,
            'hard_example_miner': hard_example_miner,
            'crop_and_resize_fn': crop_and_resize_fn,
            'clip_anchors_to_image': clip_anchors_to_image,
            'use_static_shapes': use_static_shapes,
            'resize_masks': True,
            'return_raw_detections_during_predict':
                return_raw_detections_during_predict,
        }

        return self._get_model(
            self._get_second_stage_box_predictor(
                num_classes=num_classes,
                is_training=is_training,
                use_keras=use_keras,
                predict_masks=predict_masks,
                masks_are_class_agnostic=masks_are_class_agnostic,
                share_box_across_classes=share_box_across_classes),
            **common_kwargs)
Exemple #19
0
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.predictors.heads import mask_head
from object_detection.protos import losses_pb2
from object_detection.protos import model_pb2
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import tf_version

## Feature Extractors for TF
## This section conditionally imports different feature extractors based on the
## Tensorflow version.
##
# pylint: disable=g-import-not-at-top
if tf_version.is_tf2():
    from object_detection.models import center_net_hourglass_feature_extractor
    from object_detection.models import center_net_mobilenet_v2_feature_extractor
    from object_detection.models import center_net_resnet_feature_extractor
    from object_detection.models import center_net_resnet_v1_fpn_feature_extractor
    from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
    from object_detection.models import faster_rcnn_resnet_keras_feature_extractor as frcnn_resnet_keras
    from object_detection.models import ssd_resnet_v1_fpn_keras_feature_extractor as ssd_resnet_v1_fpn_keras
    from object_detection.models import faster_rcnn_resnet_v1_fpn_keras_feature_extractor as frcnn_resnet_fpn_keras
    from object_detection.models.ssd_mobilenet_v1_fpn_keras_feature_extractor import SSDMobileNetV1FpnKerasFeatureExtractor
    from object_detection.models.ssd_mobilenet_v1_keras_feature_extractor import SSDMobileNetV1KerasFeatureExtractor
    from object_detection.models.ssd_mobilenet_v2_fpn_keras_feature_extractor import SSDMobileNetV2FpnKerasFeatureExtractor
    from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor
    from object_detection.predictors import rfcn_keras_box_predictor
    if sys.version_info[0] >= 3:
        from object_detection.models import ssd_efficientnet_bifpn_feature_extractor as ssd_efficientnet_bifpn
Exemple #20
0
class EvalUtilTest(test_case.TestCase, parameterized.TestCase):
    """Tests for the `eval_util` metric-op, evaluator and result-dict helpers."""

    # Shared parameterization for the COCO metric-op tests: a single example
    # and a batch of 8, each with and without `scale_to_absolute`.
    # `max_gt_boxes` is only consumed on the batched path (see
    # `_make_evaluation_dict`).
    _COCO_METRIC_PARAMETERS = (
        {'batch_size': 1, 'max_gt_boxes': None, 'scale_to_absolute': True},
        {'batch_size': 8, 'max_gt_boxes': [1], 'scale_to_absolute': True},
        {'batch_size': 1, 'max_gt_boxes': None, 'scale_to_absolute': False},
        {'batch_size': 8, 'max_gt_boxes': [1], 'scale_to_absolute': False},
    )

    def _get_categories_list(self):
        """Returns three categories: person (1), dog (2) and cat (3)."""
        return [
            {'id': 1, 'name': 'person'},
            {'id': 2, 'name': 'dog'},
            {'id': 3, 'name': 'cat'},
        ]

    def _get_categories_list_with_keypoints(self):
        """Returns the same categories, with keypoint name -> id maps.

        Person and dog carry keypoints; cat has none.
        """
        return [
            {
                'id': 1,
                'name': 'person',
                'keypoints': {'left_eye': 0, 'right_eye': 3},
            },
            {
                'id': 2,
                'name': 'dog',
                'keypoints': {'tail_start': 1, 'mouth': 2},
            },
            {'id': 3, 'name': 'cat'},
        ]

    def _make_evaluation_dict(self,
                              resized_groundtruth_masks=False,
                              batch_size=1,
                              max_gt_boxes=None,
                              scale_to_absolute=False):
        """Builds an eval result dict with one full-image detection per image.

        Args:
          resized_groundtruth_masks: If True, groundtruth masks are 10x10
            instead of matching the 20x20 image.
          batch_size: 1 builds a single-example dict via
            `result_dict_for_single_example`; larger values tile the same
            example and use `result_dict_for_batched_example`.
          max_gt_boxes: Only forwarded on the batched path.
          scale_to_absolute: Forwarded to the result-dict helper.

        Returns:
          The result dict produced by `eval_util`.
        """
        input_data_fields = fields.InputDataFields
        detection_fields = fields.DetectionResultFields

        image = tf.zeros(shape=[batch_size, 20, 20, 3], dtype=tf.uint8)
        if batch_size == 1:
            key = tf.constant('image1')
        else:
            key = tf.constant([str(i) for i in range(batch_size)])
        detection_boxes = tf.tile(tf.constant([[[0., 0., 1., 1.]]]),
                                  multiples=[batch_size, 1, 1])
        detection_scores = tf.tile(tf.constant([[0.8]]),
                                   multiples=[batch_size, 1])
        detection_classes = tf.tile(tf.constant([[0]]),
                                    multiples=[batch_size, 1])
        detection_masks = tf.tile(tf.ones(shape=[1, 1, 20, 20],
                                          dtype=tf.float32),
                                  multiples=[batch_size, 1, 1, 1])
        num_detections = tf.ones([batch_size])
        groundtruth_boxes = tf.constant([[0., 0., 1., 1.]])
        groundtruth_classes = tf.constant([1])
        groundtruth_instance_masks = tf.ones(shape=[1, 20, 20], dtype=tf.uint8)
        groundtruth_keypoints = tf.constant([[0.0, 0.0], [0.5, 0.5],
                                             [1.0, 1.0]])
        if resized_groundtruth_masks:
            groundtruth_instance_masks = tf.ones(shape=[1, 10, 10],
                                                 dtype=tf.uint8)

        if batch_size > 1:
            # Add a leading batch axis and repeat the single example.
            groundtruth_boxes = tf.tile(tf.expand_dims(groundtruth_boxes, 0),
                                        multiples=[batch_size, 1, 1])
            groundtruth_classes = tf.tile(
                tf.expand_dims(groundtruth_classes, 0),
                multiples=[batch_size, 1])
            groundtruth_instance_masks = tf.tile(
                tf.expand_dims(groundtruth_instance_masks, 0),
                multiples=[batch_size, 1, 1, 1])
            groundtruth_keypoints = tf.tile(
                tf.expand_dims(groundtruth_keypoints, 0),
                multiples=[batch_size, 1, 1])

        detections = {
            detection_fields.detection_boxes: detection_boxes,
            detection_fields.detection_scores: detection_scores,
            detection_fields.detection_classes: detection_classes,
            detection_fields.detection_masks: detection_masks,
            detection_fields.num_detections: num_detections
        }
        groundtruth = {
            input_data_fields.groundtruth_boxes: groundtruth_boxes,
            input_data_fields.groundtruth_classes: groundtruth_classes,
            input_data_fields.groundtruth_keypoints: groundtruth_keypoints,
            input_data_fields.groundtruth_instance_masks:
                groundtruth_instance_masks
        }
        if batch_size > 1:
            return eval_util.result_dict_for_batched_example(
                image,
                key,
                detections,
                groundtruth,
                scale_to_absolute=scale_to_absolute,
                max_gt_boxes=max_gt_boxes)
        else:
            return eval_util.result_dict_for_single_example(
                image,
                key,
                detections,
                groundtruth,
                scale_to_absolute=scale_to_absolute)

    def _run_coco_metric_ops(self,
                             metrics_set,
                             update_op_keys,
                             batch_size,
                             max_gt_boxes,
                             scale_to_absolute,
                             resized_groundtruth_masks=False):
        """Builds metric ops, runs the chosen update ops, returns all values.

        Args:
          metrics_set: Metric names added to `EvalConfig.metrics_set`.
          update_op_keys: Keys into the metric-ops dict whose update ops are
            run, in order, before the metric values are fetched.
          batch_size: Forwarded to `_make_evaluation_dict`.
          max_gt_boxes: Forwarded to `_make_evaluation_dict`.
          scale_to_absolute: Forwarded to `_make_evaluation_dict`.
          resized_groundtruth_masks: Forwarded to `_make_evaluation_dict`.

        Returns:
          Dict mapping every metric name to its evaluated value.
        """
        eval_config = eval_pb2.EvalConfig()
        eval_config.metrics_set.extend(metrics_set)
        categories = self._get_categories_list()
        eval_dict = self._make_evaluation_dict(
            batch_size=batch_size,
            max_gt_boxes=max_gt_boxes,
            scale_to_absolute=scale_to_absolute,
            resized_groundtruth_masks=resized_groundtruth_masks)
        metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
            eval_config, categories, eval_dict)
        # Each metric op is a (value_op, update_op) pair.
        update_ops = [metric_ops[key][1] for key in update_op_keys]

        with self.test_session() as sess:
            value_ops = {
                key: value_op
                for key, (value_op, _) in six.iteritems(metric_ops)
            }
            for update_op in update_ops:
                sess.run(update_op)
            return sess.run(value_ops)

    @parameterized.parameters(*_COCO_METRIC_PARAMETERS)
    @unittest.skipIf(tf_version.is_tf2(), 'Only compatible with TF1.X')
    def test_get_eval_metric_ops_for_coco_detections(self,
                                                     batch_size=1,
                                                     max_gt_boxes=None,
                                                     scale_to_absolute=False):
        metrics = self._run_coco_metric_ops(
            ['coco_detection_metrics'],
            ['DetectionBoxes_Precision/mAP'],
            batch_size, max_gt_boxes, scale_to_absolute)
        self.assertAlmostEqual(1.0, metrics['DetectionBoxes_Precision/mAP'])
        self.assertNotIn('DetectionMasks_Precision/mAP', metrics)

    @parameterized.parameters(*_COCO_METRIC_PARAMETERS)
    @unittest.skipIf(tf_version.is_tf2(), 'Only compatible with TF1.X')
    def test_get_eval_metric_ops_for_coco_detections_and_masks(
            self, batch_size=1, max_gt_boxes=None, scale_to_absolute=False):
        metrics = self._run_coco_metric_ops(
            ['coco_detection_metrics', 'coco_mask_metrics'],
            ['DetectionBoxes_Precision/mAP', 'DetectionMasks_Precision/mAP'],
            batch_size, max_gt_boxes, scale_to_absolute)
        self.assertAlmostEqual(1.0, metrics['DetectionBoxes_Precision/mAP'])
        self.assertAlmostEqual(1.0, metrics['DetectionMasks_Precision/mAP'])

    @parameterized.parameters(*_COCO_METRIC_PARAMETERS)
    @unittest.skipIf(tf_version.is_tf2(), 'Only compatible with TF1.X')
    def test_get_eval_metric_ops_for_coco_detections_and_resized_masks(
            self, batch_size=1, max_gt_boxes=None, scale_to_absolute=False):
        metrics = self._run_coco_metric_ops(
            ['coco_detection_metrics', 'coco_mask_metrics'],
            ['DetectionBoxes_Precision/mAP', 'DetectionMasks_Precision/mAP'],
            batch_size, max_gt_boxes, scale_to_absolute,
            resized_groundtruth_masks=True)
        self.assertAlmostEqual(1.0, metrics['DetectionBoxes_Precision/mAP'])
        self.assertAlmostEqual(1.0, metrics['DetectionMasks_Precision/mAP'])

    @unittest.skipIf(tf_version.is_tf2(), 'Only compatible with TF1.X')
    def test_get_eval_metric_ops_raises_error_with_unsupported_metric(self):
        eval_config = eval_pb2.EvalConfig()
        eval_config.metrics_set.extend(['unsupported_metric'])
        categories = self._get_categories_list()
        eval_dict = self._make_evaluation_dict()
        with self.assertRaises(ValueError):
            eval_util.get_eval_metric_ops_for_evaluators(
                eval_config, categories, eval_dict)

    def test_get_eval_metric_ops_for_evaluators(self):
        """Checks `evaluator_options_from_eval_config` copies the options."""
        eval_config = eval_pb2.EvalConfig()
        eval_config.metrics_set.extend([
            'coco_detection_metrics', 'coco_mask_metrics',
            'precision_at_recall_detection_metrics'
        ])
        eval_config.include_metrics_per_category = True
        eval_config.recall_lower_bound = 0.2
        eval_config.recall_upper_bound = 0.6

        evaluator_options = eval_util.evaluator_options_from_eval_config(
            eval_config)
        self.assertTrue(evaluator_options['coco_detection_metrics']
                        ['include_metrics_per_category'])
        self.assertTrue(evaluator_options['coco_mask_metrics']
                        ['include_metrics_per_category'])
        self.assertAlmostEqual(
            evaluator_options['precision_at_recall_detection_metrics']
            ['recall_lower_bound'], eval_config.recall_lower_bound)
        self.assertAlmostEqual(
            evaluator_options['precision_at_recall_detection_metrics']
            ['recall_upper_bound'], eval_config.recall_upper_bound)

    def test_get_evaluator_with_evaluator_options(self):
        eval_config = eval_pb2.EvalConfig()
        eval_config.metrics_set.extend([
            'coco_detection_metrics', 'precision_at_recall_detection_metrics'
        ])
        eval_config.include_metrics_per_category = True
        eval_config.recall_lower_bound = 0.2
        eval_config.recall_upper_bound = 0.6
        categories = self._get_categories_list()

        evaluator_options = eval_util.evaluator_options_from_eval_config(
            eval_config)
        evaluator = eval_util.get_evaluators(eval_config, categories,
                                             evaluator_options)

        self.assertTrue(evaluator[0]._include_metrics_per_category)
        self.assertAlmostEqual(evaluator[1]._recall_lower_bound,
                               eval_config.recall_lower_bound)
        self.assertAlmostEqual(evaluator[1]._recall_upper_bound,
                               eval_config.recall_upper_bound)

    def test_get_evaluator_with_no_evaluator_options(self):
        eval_config = eval_pb2.EvalConfig()
        eval_config.metrics_set.extend([
            'coco_detection_metrics', 'precision_at_recall_detection_metrics'
        ])
        eval_config.include_metrics_per_category = True
        eval_config.recall_lower_bound = 0.2
        eval_config.recall_upper_bound = 0.6
        categories = self._get_categories_list()

        evaluator = eval_util.get_evaluators(eval_config,
                                             categories,
                                             evaluator_options=None)

        # Even though we are setting eval_config.include_metrics_per_category = True
        # and bounds on recall, these options are never passed into the
        # DetectionEvaluator constructor (via `evaluator_options`).
        self.assertFalse(evaluator[0]._include_metrics_per_category)
        self.assertAlmostEqual(evaluator[1]._recall_lower_bound, 0.0)
        self.assertAlmostEqual(evaluator[1]._recall_upper_bound, 1.0)

    def test_get_evaluator_with_keypoint_metrics(self):
        eval_config = eval_pb2.EvalConfig()
        person_keypoints_metric = eval_config.parameterized_metric.add()
        person_keypoints_metric.coco_keypoint_metrics.class_label = 'person'
        person_keypoints_metric.coco_keypoint_metrics.keypoint_label_to_sigmas[
            'left_eye'] = 0.1
        person_keypoints_metric.coco_keypoint_metrics.keypoint_label_to_sigmas[
            'right_eye'] = 0.2
        dog_keypoints_metric = eval_config.parameterized_metric.add()
        dog_keypoints_metric.coco_keypoint_metrics.class_label = 'dog'
        dog_keypoints_metric.coco_keypoint_metrics.keypoint_label_to_sigmas[
            'tail_start'] = 0.3
        dog_keypoints_metric.coco_keypoint_metrics.keypoint_label_to_sigmas[
            'mouth'] = 0.4
        categories = self._get_categories_list_with_keypoints()

        evaluator = eval_util.get_evaluators(eval_config,
                                             categories,
                                             evaluator_options=None)

        # Verify keypoint evaluator class variables.
        self.assertLen(evaluator, 3)
        self.assertFalse(evaluator[0]._include_metrics_per_category)
        self.assertEqual(evaluator[1]._category_name, 'person')
        self.assertEqual(evaluator[2]._category_name, 'dog')
        self.assertAllEqual(evaluator[1]._keypoint_ids, [0, 3])
        self.assertAllEqual(evaluator[2]._keypoint_ids, [1, 2])
        self.assertAllClose([0.1, 0.2], evaluator[1]._oks_sigmas)
        self.assertAllClose([0.3, 0.4], evaluator[2]._oks_sigmas)

    def test_get_evaluator_with_unmatched_label(self):
        eval_config = eval_pb2.EvalConfig()
        person_keypoints_metric = eval_config.parameterized_metric.add()
        person_keypoints_metric.coco_keypoint_metrics.class_label = 'unmatched'
        person_keypoints_metric.coco_keypoint_metrics.keypoint_label_to_sigmas[
            'kpt'] = 0.1
        categories = self._get_categories_list_with_keypoints()

        evaluator = eval_util.get_evaluators(eval_config,
                                             categories,
                                             evaluator_options=None)
        self.assertLen(evaluator, 1)
        self.assertNotIsInstance(evaluator[0],
                                 coco_evaluation.CocoKeypointEvaluator)

    def test_padded_image_result_dict(self):
        """Checks rescaling of padded, batched results to original sizes.

        Both images are fed as a 100x100 batch; the second image's true
        extent is only 50x100, and the two original spatial shapes differ.
        """
        input_data_fields = fields.InputDataFields
        detection_fields = fields.DetectionResultFields
        key = tf.constant([str(i) for i in range(2)])

        detection_boxes = np.array(
            [[[0., 0., 1., 1.]], [[0.0, 0.0, 0.5, 0.5]]], dtype=np.float32)
        detection_keypoints = np.array([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]],
                                       dtype=np.float32)

        def graph_fn():

            def tiled_keypoints():
                # One instance with 3 (y, x) keypoints, repeated for 2 images.
                return tf.tile(
                    tf.reshape(tf.constant(detection_keypoints),
                               shape=[1, 1, 3, 2]),
                    multiples=[2, 1, 1, 1])

            detections = {
                detection_fields.detection_boxes:
                    tf.constant(detection_boxes),
                detection_fields.detection_scores:
                    tf.constant([[1.], [1.]]),
                detection_fields.detection_classes:
                    tf.constant([[1], [2]]),
                detection_fields.num_detections:
                    tf.constant([1, 1]),
                detection_fields.detection_keypoints:
                    tiled_keypoints()
            }

            # Groundtruth boxes reuse the detection boxes.
            groundtruth = {
                input_data_fields.groundtruth_boxes:
                    tf.constant(detection_boxes),
                input_data_fields.groundtruth_classes:
                    tf.constant([[1.], [1.]]),
                input_data_fields.groundtruth_keypoints:
                    tiled_keypoints()
            }

            image = tf.zeros((2, 100, 100, 3), dtype=tf.float32)

            true_image_shapes = tf.constant([[100, 100, 3], [50, 100, 3]])
            original_image_spatial_shapes = tf.constant([[200, 200],
                                                         [150, 300]])

            result = eval_util.result_dict_for_batched_example(
                image,
                key,
                detections,
                groundtruth,
                scale_to_absolute=True,
                true_image_shapes=true_image_shapes,
                original_image_spatial_shapes=original_image_spatial_shapes,
                max_gt_boxes=tf.constant(1))
            return (result[input_data_fields.groundtruth_boxes],
                    result[input_data_fields.groundtruth_keypoints],
                    result[detection_fields.detection_boxes],
                    result[detection_fields.detection_keypoints])

        # Distinct names so the numpy inputs above are not shadowed.
        (result_gt_boxes, result_gt_keypoints, result_detection_boxes,
         result_detection_keypoints) = self.execute_cpu(graph_fn, [])
        self.assertAllEqual([[[0., 0., 200., 200.]], [[0.0, 0.0, 150., 150.]]],
                            result_gt_boxes)
        self.assertAllClose([[[[0., 0.], [100., 100.], [200., 200.]]],
                             [[[0., 0.], [150., 150.], [300., 300.]]]],
                            result_gt_keypoints)

        # Unlike groundtruth, detections are not renormalized to the true
        # (unpadded) image extent; they are scaled by the original image
        # shape directly (e.g. 0.5 -> 75 and 150 for a 150x300 original).
        self.assertAllEqual([[[0., 0., 200., 200.]], [[0.0, 0.0, 75., 150.]]],
                            result_detection_boxes)
        self.assertAllClose([[[[0., 0.], [100., 100.], [200., 200.]]],
                             [[[0., 0.], [75., 150.], [150., 300.]]]],
                            result_detection_keypoints)
Exemple #21
0
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
    """Builds a Faster R-CNN, R-FCN or Context R-CNN model from the config.

    The meta-architecture is chosen as follows:
      * `RFCNMetaArch` when the built second stage box predictor is an R-FCN
        predictor (`RfcnBoxPredictor` under TF 1.X,
        `RfcnKerasBoxPredictor` under TF 2.X);
      * `ContextRCNNMetaArch` when `frcnn_config` has a `context_config`;
      * `FasterRCNNMetaArch` otherwise.

    Under TF 2.X the Keras variants of the feature extractor, first stage
    hyperparams and second stage box predictor are built; under TF 1.X the
    non-Keras variants are built.

    Args:
      frcnn_config: A faster_rcnn.proto object containing the config for the
        desired FasterRCNNMetaArch, RFCNMetaArch or ContextRCNNMetaArch.
      is_training: True if this model is being built for training purposes.
      add_summaries: Whether to add tf summaries in the model.

    Returns:
      An RFCNMetaArch, ContextRCNNMetaArch or FasterRCNNMetaArch built from
      the config.

    Raises:
      ValueError: If the feature extractor type is not registered (checked by
        `_check_feature_extractor_exists`), if
        `first_stage_nms_iou_threshold` is outside [0, 1.0], or if
        `is_training` is True and `second_stage_batch_size` is greater than
        `first_stage_max_proposals`.
    """
    num_classes = frcnn_config.num_classes
    image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)
    _check_feature_extractor_exists(frcnn_config.feature_extractor.type)
    # TF 2.X builds Keras components; TF 1.X builds the non-Keras ones.
    is_keras = tf_version.is_tf2()

    if is_keras:
        feature_extractor = _build_faster_rcnn_keras_feature_extractor(
            frcnn_config.feature_extractor,
            is_training,
            inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)
    else:
        feature_extractor = _build_faster_rcnn_feature_extractor(
            frcnn_config.feature_extractor,
            is_training,
            inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)

    number_of_stages = frcnn_config.number_of_stages
    first_stage_anchor_generator = anchor_generator_builder.build(
        frcnn_config.first_stage_anchor_generator)

    first_stage_target_assigner = target_assigner.create_target_assigner(
        'FasterRCNN',
        'proposal',
        use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
    first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
    if is_keras:
        first_stage_box_predictor_arg_scope_fn = (
            hyperparams_builder.KerasLayerHyperparams(
                frcnn_config.first_stage_box_predictor_conv_hyperparams))
    else:
        first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
            frcnn_config.first_stage_box_predictor_conv_hyperparams,
            is_training)
    first_stage_box_predictor_kernel_size = (
        frcnn_config.first_stage_box_predictor_kernel_size)
    first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
    first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
    # Static shapes apply during training, or during eval only when
    # use_static_shapes_for_eval is also set.
    use_static_shapes = frcnn_config.use_static_shapes and (
        frcnn_config.use_static_shapes_for_eval or is_training)
    # Both stage samplers use the same static/dynamic setting.
    use_static_sampler = (frcnn_config.use_static_balanced_label_sampler
                          and use_static_shapes)
    first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
        is_static=use_static_sampler)
    first_stage_max_proposals = frcnn_config.first_stage_max_proposals
    if (frcnn_config.first_stage_nms_iou_threshold < 0
            or frcnn_config.first_stage_nms_iou_threshold > 1.0):
        raise ValueError('iou_threshold not in [0, 1.0].')
    # Only enforced when training: the second stage batch must not exceed
    # the number of first stage proposals.
    if (is_training and
            frcnn_config.second_stage_batch_size > first_stage_max_proposals):
        raise ValueError('second_stage_batch_size should be no greater than '
                         'first_stage_max_proposals.')
    first_stage_non_max_suppression_fn = functools.partial(
        post_processing.batch_multiclass_non_max_suppression,
        score_thresh=frcnn_config.first_stage_nms_score_threshold,
        iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
        max_size_per_class=first_stage_max_proposals,
        max_total_size=first_stage_max_proposals,
        use_static_shapes=use_static_shapes,
        use_partitioned_nms=frcnn_config.use_partitioned_nms_in_first_stage,
        use_combined_nms=frcnn_config.use_combined_nms_in_first_stage)
    first_stage_loc_loss_weight = (
        frcnn_config.first_stage_localization_loss_weight)
    first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

    initial_crop_size = frcnn_config.initial_crop_size
    maxpool_kernel_size = frcnn_config.maxpool_kernel_size
    maxpool_stride = frcnn_config.maxpool_stride

    second_stage_target_assigner = target_assigner.create_target_assigner(
        'FasterRCNN',
        'detection',
        use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
    if is_keras:
        second_stage_box_predictor = box_predictor_builder.build_keras(
            hyperparams_builder.KerasLayerHyperparams,
            freeze_batchnorm=False,
            inplace_batchnorm_update=False,
            num_predictions_per_location_list=[1],
            box_predictor_config=frcnn_config.second_stage_box_predictor,
            is_training=is_training,
            num_classes=num_classes)
    else:
        second_stage_box_predictor = box_predictor_builder.build(
            hyperparams_builder.build,
            frcnn_config.second_stage_box_predictor,
            is_training=is_training,
            num_classes=num_classes)
    second_stage_batch_size = frcnn_config.second_stage_batch_size
    second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=frcnn_config.second_stage_balance_fraction,
        is_static=use_static_sampler)
    (second_stage_non_max_suppression_fn,
     second_stage_score_conversion_fn) = post_processing_builder.build(
         frcnn_config.second_stage_post_processing)
    second_stage_localization_loss_weight = (
        frcnn_config.second_stage_localization_loss_weight)
    second_stage_classification_loss = (
        losses_builder.build_faster_rcnn_classification_loss(
            frcnn_config.second_stage_classification_loss))
    second_stage_classification_loss_weight = (
        frcnn_config.second_stage_classification_loss_weight)
    second_stage_mask_prediction_loss_weight = (
        frcnn_config.second_stage_mask_prediction_loss_weight)

    hard_example_miner = None
    if frcnn_config.HasField('hard_example_miner'):
        hard_example_miner = losses_builder.build_hard_example_miner(
            frcnn_config.hard_example_miner,
            second_stage_classification_loss_weight,
            second_stage_localization_loss_weight)

    crop_and_resize_fn = (ops.matmul_crop_and_resize
                          if frcnn_config.use_matmul_crop_and_resize else
                          ops.native_crop_and_resize)
    clip_anchors_to_image = frcnn_config.clip_anchors_to_image

    # Arguments shared by every meta-architecture built below.
    common_kwargs = {
        'is_training': is_training,
        'num_classes': num_classes,
        'image_resizer_fn': image_resizer_fn,
        'feature_extractor': feature_extractor,
        'number_of_stages': number_of_stages,
        'first_stage_anchor_generator': first_stage_anchor_generator,
        'first_stage_target_assigner': first_stage_target_assigner,
        'first_stage_atrous_rate': first_stage_atrous_rate,
        'first_stage_box_predictor_arg_scope_fn':
        first_stage_box_predictor_arg_scope_fn,
        'first_stage_box_predictor_kernel_size':
        first_stage_box_predictor_kernel_size,
        'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
        'first_stage_minibatch_size': first_stage_minibatch_size,
        'first_stage_sampler': first_stage_sampler,
        'first_stage_non_max_suppression_fn':
        first_stage_non_max_suppression_fn,
        'first_stage_max_proposals': first_stage_max_proposals,
        'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
        'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
        'second_stage_target_assigner': second_stage_target_assigner,
        'second_stage_batch_size': second_stage_batch_size,
        'second_stage_sampler': second_stage_sampler,
        'second_stage_non_max_suppression_fn':
        second_stage_non_max_suppression_fn,
        'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
        'second_stage_localization_loss_weight':
        second_stage_localization_loss_weight,
        'second_stage_classification_loss': second_stage_classification_loss,
        'second_stage_classification_loss_weight':
        second_stage_classification_loss_weight,
        'hard_example_miner': hard_example_miner,
        'add_summaries': add_summaries,
        'crop_and_resize_fn': crop_and_resize_fn,
        'clip_anchors_to_image': clip_anchors_to_image,
        'use_static_shapes': use_static_shapes,
        'resize_masks': frcnn_config.resize_masks,
        'return_raw_detections_during_predict':
        frcnn_config.return_raw_detections_during_predict,
        'output_final_box_features': frcnn_config.output_final_box_features
    }

    # The R-FCN predictor class to check against depends on the TF version.
    if is_keras:
        is_rfcn = isinstance(second_stage_box_predictor,
                             rfcn_keras_box_predictor.RfcnKerasBoxPredictor)
    else:
        is_rfcn = isinstance(second_stage_box_predictor,
                             rfcn_box_predictor.RfcnBoxPredictor)

    if is_rfcn:
        return rfcn_meta_arch.RFCNMetaArch(
            second_stage_rfcn_box_predictor=second_stage_box_predictor,
            **common_kwargs)
    elif frcnn_config.HasField('context_config'):
        context_config = frcnn_config.context_config
        common_kwargs.update({
            'attention_bottleneck_dimension':
            context_config.attention_bottleneck_dimension,
            'attention_temperature':
            context_config.attention_temperature
        })
        return context_rcnn_meta_arch.ContextRCNNMetaArch(
            initial_crop_size=initial_crop_size,
            maxpool_kernel_size=maxpool_kernel_size,
            maxpool_stride=maxpool_stride,
            second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
            second_stage_mask_prediction_loss_weight=(
                second_stage_mask_prediction_loss_weight),
            **common_kwargs)
    else:
        return faster_rcnn_meta_arch.FasterRCNNMetaArch(
            initial_crop_size=initial_crop_size,
            maxpool_kernel_size=maxpool_kernel_size,
            maxpool_stride=maxpool_stride,
            second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
            second_stage_mask_prediction_loss_weight=(
                second_stage_mask_prediction_loss_weight),
            **common_kwargs)
Exemple #22
0
class ConfigUtilTest(tf.test.TestCase):

  def _create_and_load_test_configs(self, pipeline_config):
    """Writes `pipeline_config` to the temp dir and reads it back.

    Args:
      pipeline_config: A `TrainEvalPipelineConfig` proto to serialize.

    Returns:
      The configs dictionary returned by
      `config_util.get_configs_from_pipeline_file` for the written file.
    """
    config_file = os.path.join(self.get_temp_dir(), "pipeline.config")
    _write_config(pipeline_config, config_file)
    loaded_configs = config_util.get_configs_from_pipeline_file(config_file)
    return loaded_configs

  def test_get_configs_from_pipeline_file(self):
    """Checks that each top-level proto survives a pipeline file round trip."""
    config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.model.faster_rcnn.num_classes = 10
    pipeline.train_config.batch_size = 32
    pipeline.train_input_reader.label_map_path = "path/to/label_map"
    pipeline.eval_config.num_examples = 20
    pipeline.eval_input_reader.add().queue_capacity = 100

    _write_config(pipeline, config_file)
    loaded = config_util.get_configs_from_pipeline_file(config_file)

    # Pairs of (expected proto, key in the loaded configs dictionary).
    expectations = (
        (pipeline.model, "model"),
        (pipeline.train_config, "train_config"),
        (pipeline.train_input_reader, "train_input_config"),
        (pipeline.eval_config, "eval_config"),
        (pipeline.eval_input_reader, "eval_input_configs"),
    )
    for expected, key in expectations:
      self.assertProtoEquals(expected, loaded[key])

  def test_create_configs_from_pipeline_proto(self):
    """Checks that an in-memory pipeline proto is split into a configs dict."""
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.model.faster_rcnn.num_classes = 10
    pipeline.train_config.batch_size = 32
    pipeline.train_input_reader.label_map_path = "path/to/label_map"
    pipeline.eval_config.num_examples = 20
    pipeline.eval_input_reader.add().queue_capacity = 100

    split_configs = config_util.create_configs_from_pipeline_proto(pipeline)

    expected_by_key = {
        "model": pipeline.model,
        "train_config": pipeline.train_config,
        "train_input_config": pipeline.train_input_reader,
        "eval_config": pipeline.eval_config,
        "eval_input_configs": pipeline.eval_input_reader,
    }
    for key, expected in expected_by_key.items():
      self.assertProtoEquals(expected, split_configs[key])

  def test_create_pipeline_proto_from_configs(self):
    """Checks that a loaded configs dict reassembles into the same pipeline."""
    original = pipeline_pb2.TrainEvalPipelineConfig()
    original.model.faster_rcnn.num_classes = 10
    original.train_config.batch_size = 32
    original.train_input_reader.label_map_path = "path/to/label_map"
    original.eval_config.num_examples = 20
    original.eval_input_reader.add().queue_capacity = 100

    loaded_configs = self._create_and_load_test_configs(original)
    rebuilt = config_util.create_pipeline_proto_from_configs(loaded_configs)

    self.assertEqual(original, rebuilt)

  def test_save_pipeline_config(self):
    """Checks that save_pipeline_config writes a reloadable pipeline.config."""
    original = pipeline_pb2.TrainEvalPipelineConfig()
    original.model.faster_rcnn.num_classes = 10
    original.train_config.batch_size = 32
    original.train_input_reader.label_map_path = "path/to/label_map"
    original.eval_config.num_examples = 20
    original.eval_input_reader.add().queue_capacity = 100

    output_dir = self.get_temp_dir()
    config_util.save_pipeline_config(original, output_dir)

    saved_file = os.path.join(output_dir, "pipeline.config")
    reloaded = config_util.create_pipeline_proto_from_configs(
        config_util.get_configs_from_pipeline_file(saved_file))

    self.assertEqual(original, reloaded)

  def test_get_configs_from_multiple_files(self):
    """Tests that proto configs can be read from multiple files.

    Each top-level config is written to its own file in the test temp dir,
    and `config_util.get_configs_from_multiple_files` must return each one
    under its corresponding key.
    """
    temp_dir = self.get_temp_dir()

    # Write model config file.
    model_config_path = os.path.join(temp_dir, "model.config")
    model = model_pb2.DetectionModel()
    model.faster_rcnn.num_classes = 10
    _write_config(model, model_config_path)

    # Write train config file.
    train_config_path = os.path.join(temp_dir, "train.config")
    train_config = train_pb2.TrainConfig()
    train_config.batch_size = 32
    _write_config(train_config, train_config_path)

    # Write train input config file.
    train_input_config_path = os.path.join(temp_dir, "train_input.config")
    train_input_config = input_reader_pb2.InputReader()
    train_input_config.label_map_path = "path/to/label_map"
    _write_config(train_input_config, train_input_config_path)

    # Write eval config file.
    eval_config_path = os.path.join(temp_dir, "eval.config")
    eval_config = eval_pb2.EvalConfig()
    eval_config.num_examples = 20
    _write_config(eval_config, eval_config_path)

    # Write eval input config file.
    eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
    eval_input_config = input_reader_pb2.InputReader()
    eval_input_config.label_map_path = "path/to/another/label_map"
    _write_config(eval_input_config, eval_input_config_path)

    configs = config_util.get_configs_from_multiple_files(
        model_config_path=model_config_path,
        train_config_path=train_config_path,
        train_input_config_path=train_input_config_path,
        eval_config_path=eval_config_path,
        eval_input_config_path=eval_input_config_path)
    self.assertProtoEquals(model, configs["model"])
    self.assertProtoEquals(train_config, configs["train_config"])
    self.assertProtoEquals(train_input_config,
                           configs["train_input_config"])
    self.assertProtoEquals(eval_config, configs["eval_config"])
    # "eval_input_configs" is indexed; only one eval input config was written.
    self.assertProtoEquals(eval_input_config, configs["eval_input_configs"][0])

  def _assertOptimizerWithNewLearningRate(self, optimizer_name):
    """Asserts that an hparams learning rate overrides every LR schedule type.

    For each schedule (constant, exponential decay, manual step, cosine
    decay), a fresh pipeline config using `optimizer_name` with that schedule
    is written to disk, reloaded, and merged with
    `HParams(learning_rate=0.15)`. The merged schedule must carry the new
    rate, and secondary rates must be scaled accordingly.

    Args:
      optimizer_name: Name of the field on `train_config.optimizer` to
        configure, e.g. "adam_optimizer".
    """
    base_rate = 0.7
    step_scaling = 0.1
    warmup_rate = 0.07
    hparams = contrib_training.HParams(learning_rate=0.15)
    config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

    def merged_optimizer(set_schedule):
      # Configures a fresh optimizer via `set_schedule`, round-trips the
      # pipeline through `config_file`, applies `hparams`, and returns the
      # merged optimizer proto.
      pipeline = pipeline_pb2.TrainEvalPipelineConfig()
      set_schedule(getattr(pipeline.train_config.optimizer, optimizer_name))
      _write_config(pipeline, config_file)
      loaded = config_util.get_configs_from_pipeline_file(config_file)
      merged = config_util.merge_external_params_with_configs(loaded, hparams)
      return getattr(merged["train_config"].optimizer, optimizer_name)

    # Constant learning rate.
    constant = merged_optimizer(
        lambda opt: _update_optimizer_with_constant_learning_rate(
            opt, base_rate)).learning_rate.constant_learning_rate
    self.assertAlmostEqual(hparams.learning_rate, constant.learning_rate)

    # Exponential decay learning rate.
    exponential = merged_optimizer(
        lambda opt: _update_optimizer_with_exponential_decay_learning_rate(
            opt, base_rate)).learning_rate.exponential_decay_learning_rate
    self.assertAlmostEqual(hparams.learning_rate,
                           exponential.initial_learning_rate)

    # Manual step learning rate: step i is expected at new_rate * scaling**i.
    manual = merged_optimizer(
        lambda opt: _update_optimizer_with_manual_step_learning_rate(
            opt, base_rate, step_scaling)
    ).learning_rate.manual_step_learning_rate
    self.assertAlmostEqual(hparams.learning_rate, manual.initial_learning_rate)
    for step_index, step in enumerate(manual.schedule):
      self.assertAlmostEqual(hparams.learning_rate * step_scaling**step_index,
                             step.learning_rate)

    # Cosine decay learning rate: the warmup keeps its ratio to the base rate.
    cosine = merged_optimizer(
        lambda opt: _update_optimizer_with_cosine_decay_learning_rate(
            opt, base_rate, warmup_rate)
    ).learning_rate.cosine_decay_learning_rate
    self.assertAlmostEqual(hparams.learning_rate, cosine.learning_rate_base)
    self.assertAlmostEqual(
        hparams.learning_rate * (warmup_rate / base_rate),
        cosine.warmup_learning_rate)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testRMSPropWithNewLearingRate(self):
    """Checks learning-rate overrides for every RMSProp LR schedule."""
    optimizer_field = "rms_prop_optimizer"
    self._assertOptimizerWithNewLearningRate(optimizer_field)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testMomentumOptimizerWithNewLearningRate(self):
    """Checks learning-rate overrides for every Momentum LR schedule."""
    optimizer_field = "momentum_optimizer"
    self._assertOptimizerWithNewLearningRate(optimizer_field)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testAdamOptimizerWithNewLearningRate(self):
    """Checks learning-rate overrides for every Adam LR schedule."""
    optimizer_field = "adam_optimizer"
    self._assertOptimizerWithNewLearningRate(optimizer_field)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testGenericConfigOverride(self):
    """Checks that dotted-key hparams override each top-level config."""
    # Seed one field in each top-level config.
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.model.ssd.num_classes = 1
    pipeline.train_config.batch_size = 1
    pipeline.eval_config.num_visualizations = 1
    pipeline.train_input_reader.label_map_path = "/some/path"
    pipeline.eval_input_reader.add().label_map_path = "/some/path"
    pipeline.graph_rewriter.quantization.weight_bits = 1
    loaded = self._create_and_load_test_configs(pipeline)

    # Override those fields by their dotted config paths.
    overrides = {
        "model.ssd.num_classes": 2,
        "train_config.batch_size": 2,
        "train_input_config.label_map_path": "/some/other/path",
        "eval_config.num_visualizations": 2,
        "graph_rewriter_config.quantization.weight_bits": 2
    }
    merged = config_util.merge_external_params_with_configs(
        loaded, contrib_training.HParams(**overrides))

    self.assertEqual(2, merged["model"].ssd.num_classes)
    self.assertEqual(2, merged["train_config"].batch_size)
    self.assertEqual("/some/other/path",
                     merged["train_input_config"].label_map_path)
    self.assertEqual(2, merged["eval_config"].num_visualizations)
    self.assertEqual(2,
                     merged["graph_rewriter_config"].quantization.weight_bits)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewBatchSize(self):
    """Checks that a `batch_size` hparam replaces the configured batch size."""
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_config.batch_size = 2
    loaded = self._create_and_load_test_configs(pipeline)

    merged = config_util.merge_external_params_with_configs(
        loaded, contrib_training.HParams(batch_size=16))

    self.assertEqual(16, merged["train_config"].batch_size)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewBatchSizeWithClipping(self):
    """Checks that a `batch_size` override below 1 is clipped up to 1."""
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_config.batch_size = 2
    loaded = self._create_and_load_test_configs(pipeline)

    merged = config_util.merge_external_params_with_configs(
        loaded, contrib_training.HParams(batch_size=0.5))

    # An override of 0.5 is expected to end up as a batch size of 1.
    self.assertEqual(1, merged["train_config"].batch_size)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testOverwriteBatchSizeWithKeyValue(self):
    """Checks a dotted `train_config.batch_size` key overrides batch size."""
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_config.batch_size = 2
    loaded = self._create_and_load_test_configs(pipeline)

    key_value_hparams = contrib_training.HParams(
        **{"train_config.batch_size": 10})
    merged = config_util.merge_external_params_with_configs(
        loaded, key_value_hparams)

    self.assertEqual(10, merged["train_config"].batch_size)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testOverwriteSampleFromDatasetWeights(self):
    """Checks that sample_from_datasets_weights can be overridden."""
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.sample_from_datasets_weights.extend([1, 2])
    config_file = os.path.join(self.get_temp_dir(), "pipeline.config")
    _write_config(pipeline, config_file)

    loaded = config_util.get_configs_from_pipeline_file(config_file)
    weight_hparams = contrib_training.HParams(
        sample_from_datasets_weights=[0.5, 0.5])
    merged = config_util.merge_external_params_with_configs(
        loaded, weight_hparams)

    merged_weights = list(
        merged["train_input_config"].sample_from_datasets_weights)
    self.assertListEqual([0.5, 0.5], merged_weights)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testOverwriteSampleFromDatasetWeightsWrongLength(self):
    """Tests that a wrong-length sample_from_datasets_weights override fails.

    Two weights are configured but three are supplied as an override, so
    merging must raise a ValueError describing the length mismatch.
    """
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.sample_from_datasets_weights.extend(
        [1, 2])
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    _write_config(pipeline_config, pipeline_config_path)

    # Try to override parameter with too many weights:
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    hparams = contrib_training.HParams(
        sample_from_datasets_weights=[0.5, 0.5, 0.5])
    # `assertRaises(..., msg=...)` uses `msg` only as the failure text when no
    # exception is raised; it never inspects the exception. Match the message
    # prefix so an unrelated ValueError cannot satisfy this test.
    with self.assertRaisesRegex(
        ValueError,
        "sample_from_datasets_weights override has a different number of "
        "values"):
      config_util.merge_external_params_with_configs(configs, hparams)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testKeyValueOverrideBadKey(self):
    """Checks that overriding a nonexistent field raises ValueError."""
    loaded = self._create_and_load_test_configs(
        pipeline_pb2.TrainEvalPipelineConfig())
    bad_key_hparams = contrib_training.HParams(
        **{"train_config.no_such_field": 10})

    with self.assertRaises(ValueError):
      config_util.merge_external_params_with_configs(loaded, bad_key_hparams)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testOverwriteBatchSizeWithBadValueType(self):
    """Tests that overwriting with a bad value type causes a TypeError."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = 2
    configs = self._create_and_load_test_configs(pipeline_config)
    # Type should be an integer, but we're passing a string "10".
    hparams = contrib_training.HParams(**{"train_config.batch_size": "10"})
    with self.assertRaises(TypeError):
      config_util.merge_external_params_with_configs(configs, hparams)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewMomentumOptimizerValue(self):
    """Checks that a momentum override above 1.0 is clipped to 1.0."""
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    rms_prop = pipeline.train_config.optimizer.rms_prop_optimizer
    rms_prop.momentum_optimizer_value = 0.4
    loaded = self._create_and_load_test_configs(pipeline)

    merged = config_util.merge_external_params_with_configs(
        loaded, contrib_training.HParams(momentum_optimizer_value=1.1))
    merged_rms_prop = merged["train_config"].optimizer.rms_prop_optimizer

    # The 1.1 override is expected to be clipped to 1.0.
    self.assertAlmostEqual(1.0, merged_rms_prop.momentum_optimizer_value)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewClassificationLocalizationWeightRatio(self):
    """Checks the classification/localization weight ratio override."""
    ratio = 5.0
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    ssd_loss = pipeline.model.ssd.loss
    ssd_loss.localization_weight = 0.1
    ssd_loss.classification_weight = 0.2
    loaded = self._create_and_load_test_configs(pipeline)

    merged = config_util.merge_external_params_with_configs(
        loaded,
        contrib_training.HParams(
            classification_localization_weight_ratio=ratio))
    merged_loss = merged["model"].ssd.loss

    # Expected: localization weight becomes 1.0 and classification weight
    # becomes the requested ratio.
    self.assertAlmostEqual(1.0, merged_loss.localization_weight)
    self.assertAlmostEqual(ratio, merged_loss.classification_weight)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewFocalLossParameters(self):
    """Tests that focal loss alpha and gamma are overridden by hparams.

    A weighted sigmoid focal classification loss is configured with
    alpha=gamma=1.0, and `focal_loss_alpha`/`focal_loss_gamma` hparams must
    replace both values.
    """
    original_alpha = 1.0
    original_gamma = 1.0
    new_alpha = 0.3
    new_gamma = 2.0
    hparams = contrib_training.HParams(
        focal_loss_alpha=new_alpha, focal_loss_gamma=new_gamma)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    classification_loss = pipeline_config.model.ssd.loss.classification_loss
    classification_loss.weighted_sigmoid_focal.alpha = original_alpha
    classification_loss.weighted_sigmoid_focal.gamma = original_gamma
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    classification_loss = configs["model"].ssd.loss.classification_loss
    self.assertAlmostEqual(new_alpha,
                           classification_loss.weighted_sigmoid_focal.alpha)
    self.assertAlmostEqual(new_gamma,
                           classification_loss.weighted_sigmoid_focal.gamma)

  def testMergingKeywordArguments(self):
    """Checks that `kwargs_dict` overrides are merged like hparams."""
    target_steps = 10
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_config.num_steps = 100
    loaded = self._create_and_load_test_configs(pipeline)

    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"train_steps": target_steps})

    self.assertEqual(target_steps, merged["train_config"].num_steps)

  def testGetNumberOfClasses(self):
    """Checks that get_number_of_classes reads num_classes from the model."""
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.model.faster_rcnn.num_classes = 20
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    self.assertEqual(20, config_util.get_number_of_classes(loaded["model"]))

  def testNewTrainInputPath(self):
    """Checks that a single-string `train_input_path` replaces old paths."""
    old_paths = ["path/to/data"]
    replacement = "another/path/to/data"
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.tf_record_input_reader.input_path.extend(
        old_paths)
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"train_input_path": replacement})
    tf_record_reader = merged["train_input_config"].tf_record_input_reader
    # The lone string is expected to end up as a one-element path list.
    self.assertEqual([replacement], tf_record_reader.input_path)

  def testNewTrainInputPathList(self):
    """Checks that a list `train_input_path` replaces the old input paths."""
    old_paths = ["path/to/data"]
    replacements = ["another/path/to/data", "yet/another/path/to/data"]
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.tf_record_input_reader.input_path.extend(
        old_paths)
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"train_input_path": replacements})
    tf_record_reader = merged["train_input_config"].tf_record_input_reader
    self.assertEqual(replacements, tf_record_reader.input_path)

  def testNewLabelMapPath(self):
    """Checks that `label_map_path` overrides every input reader's label map.

    Both the train reader and each eval reader must pick up the new path.
    """
    old_label_map = "path/to/original/label_map"
    replacement = "path//to/new/label_map"
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.label_map_path = old_label_map
    pipeline.eval_input_reader.add().label_map_path = old_label_map
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"label_map_path": replacement})
    self.assertEqual(replacement, merged["train_input_config"].label_map_path)
    for eval_reader in merged["eval_input_configs"]:
      self.assertEqual(replacement, eval_reader.label_map_path)

  def testDontOverwriteEmptyLabelMapPath(self):
    """Checks that an empty `label_map_path` override is ignored.

    Both the train reader and the single eval reader must keep the label map
    path that was written into the pipeline file.
    """
    kept_label_map = "path/to/original/label_map"
    empty_override = ""
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.label_map_path = kept_label_map
    pipeline.eval_input_reader.add().label_map_path = kept_label_map
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"label_map_path": empty_override})
    self.assertEqual(kept_label_map,
                     merged["train_input_config"].label_map_path)
    self.assertEqual(kept_label_map,
                     merged["eval_input_configs"][0].label_map_path)

  def testNewMaskType(self):
    """Checks that a `mask_type` override reaches train and eval readers."""
    old_mask_type = input_reader_pb2.NUMERICAL_MASKS
    replacement = input_reader_pb2.PNG_MASKS
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.mask_type = old_mask_type
    pipeline.eval_input_reader.add().mask_type = old_mask_type
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"mask_type": replacement})
    self.assertEqual(replacement, merged["train_input_config"].mask_type)
    self.assertEqual(replacement, merged["eval_input_configs"][0].mask_type)

  def testUseMovingAverageForEval(self):
    """Checks that `eval_with_moving_averages` sets use_moving_averages.

    The eval config is written with `use_moving_averages` False; the keyword
    override must turn it on.
    """
    initially_enabled = False
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.eval_config.use_moving_averages = initially_enabled
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"eval_with_moving_averages": True})
    self.assertEqual(True, merged["eval_config"].use_moving_averages)

  def testGetImageResizerConfig(self):
    """Tests that the image resizer config can be retrieved from a model.

    A Faster R-CNN model with a 100x300 fixed-shape resizer must yield a
    resizer config carrying that height and width.
    """
    model_config = model_pb2.DetectionModel()
    model_config.faster_rcnn.image_resizer.fixed_shape_resizer.height = 100
    model_config.faster_rcnn.image_resizer.fixed_shape_resizer.width = 300
    image_resizer_config = config_util.get_image_resizer_config(model_config)
    self.assertEqual(image_resizer_config.fixed_shape_resizer.height, 100)
    self.assertEqual(image_resizer_config.fixed_shape_resizer.width, 300)

  def testGetSpatialImageSizeFromFixedShapeResizerConfig(self):
    """A fixed-shape resizer yields its configured [height, width]."""
    resizer = image_resizer_pb2.ImageResizer()
    fixed = resizer.fixed_shape_resizer
    fixed.height = 100
    fixed.width = 200
    self.assertAllEqual(
        config_util.get_spatial_image_size(resizer), [100, 200])

  def testGetSpatialImageSizeFromAspectPreservingResizerConfig(self):
    """Padding to max_dimension yields [max_dimension, max_dimension]."""
    resizer = image_resizer_pb2.ImageResizer()
    keep_ratio = resizer.keep_aspect_ratio_resizer
    keep_ratio.min_dimension = 100
    keep_ratio.max_dimension = 600
    keep_ratio.pad_to_max_dimension = True
    self.assertAllEqual(
        config_util.get_spatial_image_size(resizer), [600, 600])

  def testGetSpatialImageSizeFromAspectPreservingResizerDynamic(self):
    """Without padding, the aspect-preserving resizer yields [-1, -1]."""
    resizer = image_resizer_pb2.ImageResizer()
    keep_ratio = resizer.keep_aspect_ratio_resizer
    keep_ratio.min_dimension = 100
    keep_ratio.max_dimension = 600
    self.assertAllEqual(
        config_util.get_spatial_image_size(resizer), [-1, -1])

  def testGetSpatialImageSizeFromConditionalShapeResizer(self):
    """The conditional shape resizer yields a dynamic [-1, -1] size."""
    resizer = image_resizer_pb2.ImageResizer()
    resizer.conditional_shape_resizer.size_threshold = 100
    spatial_size = config_util.get_spatial_image_size(resizer)
    self.assertAllEqual(spatial_size, [-1, -1])

  def testGetMaxNumContextFeaturesFromModelConfig(self):
    """Reads max_num_context_features from a Faster R-CNN context config."""
    model_config = model_pb2.DetectionModel()
    context = model_config.faster_rcnn.context_config
    context.max_num_context_features = 10
    self.assertAllEqual(
        config_util.get_max_num_context_features(model_config), 10)

  def testGetContextFeatureLengthFromModelConfig(self):
    """Reads context_feature_length from a Faster R-CNN context config."""
    model_config = model_pb2.DetectionModel()
    context = model_config.faster_rcnn.context_config
    context.context_feature_length = 100
    self.assertAllEqual(
        config_util.get_context_feature_length(model_config), 100)

  def testEvalShuffle(self):
    """Checks that the `eval_shuffle` keyword override is applied."""
    initially_shuffled = True
    wanted_shuffle = False

    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.eval_input_reader.add().shuffle = initially_shuffled
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"eval_shuffle": wanted_shuffle})
    self.assertEqual(wanted_shuffle, merged["eval_input_configs"][0].shuffle)

  def testTrainShuffle(self):
    """Checks that the `train_shuffle` keyword override is applied."""
    initially_shuffled = True
    wanted_shuffle = False

    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.shuffle = initially_shuffled
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"train_shuffle": wanted_shuffle})
    self.assertEqual(wanted_shuffle, merged["train_input_config"].shuffle)

  def testOverWriteRetainOriginalImages(self):
    """Tests that `retain_original_images_in_eval` overrides the eval config.

    `eval_config.retain_original_images` is written as True and the keyword
    override must set it to False.
    """
    original_retain_original_images = True
    desired_retain_original_images = False

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_config.retain_original_images = (
        original_retain_original_images)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {
        "retain_original_images_in_eval": desired_retain_original_images
    }
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    retain_original_images = configs["eval_config"].retain_original_images
    self.assertEqual(desired_retain_original_images, retain_original_images)

  def testOverwriteAllEvalSampling(self):
    """Checks that `sample_1_of_n_eval_examples` updates every eval reader."""
    initial_n = 1
    wanted_n = 10

    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    # Two eval readers, so the override has to reach more than one config.
    for _ in range(2):
      pipeline.eval_input_reader.add().sample_1_of_n_examples = initial_n
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"sample_1_of_n_eval_examples": wanted_n})
    for eval_reader in merged["eval_input_configs"]:
      self.assertEqual(wanted_n, eval_reader.sample_1_of_n_examples)

  def testOverwriteAllEvalNumEpochs(self):
    """Checks that `eval_num_epochs` updates every eval reader."""
    initial_epochs = 10
    wanted_epochs = 1

    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    # Two eval readers, so the override has to reach more than one config.
    for _ in range(2):
      pipeline.eval_input_reader.add().num_epochs = initial_epochs
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"eval_num_epochs": wanted_epochs})
    for eval_reader in merged["eval_input_configs"]:
      self.assertEqual(wanted_epochs, eval_reader.num_epochs)

  def testUpdateMaskTypeForAllInputConfigs(self):
    """Checks that `mask_type` reaches the train and all named eval readers."""
    old_mask_type = input_reader_pb2.NUMERICAL_MASKS
    replacement = input_reader_pb2.PNG_MASKS

    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.mask_type = old_mask_type
    for reader_name in ("eval_1", "eval_2"):
      eval_reader = pipeline.eval_input_reader.add()
      eval_reader.mask_type = old_mask_type
      eval_reader.name = reader_name
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict={"mask_type": replacement})

    self.assertEqual(merged["train_input_config"].mask_type, replacement)
    for eval_reader in merged["eval_input_configs"]:
      self.assertEqual(eval_reader.mask_type, replacement)

  def testErrorOverwritingMultipleInputConfig(self):
    """Tests that a legacy `eval_shuffle` override fails with 2 eval configs.

    The legacy key does not name which eval input config to update, so with
    more than one eval config the merge is expected to raise ValueError.
    """
    original_shuffle = False
    new_shuffle = True
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    eval_1 = pipeline_config.eval_input_reader.add()
    eval_1.shuffle = original_shuffle
    eval_1.name = "eval_1"
    eval_2 = pipeline_config.eval_input_reader.add()
    eval_2.shuffle = original_shuffle
    eval_2.name = "eval_2"
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"eval_shuffle": new_shuffle}
    with self.assertRaises(ValueError):
      # Expected to raise, so there is no result to keep.
      config_util.merge_external_params_with_configs(
          configs, kwargs_dict=override_dict)

  def testCheckAndParseInputConfigKey(self):
    """Tests parsing of specific, legacy, non-input and malformed keys.

    Specific keys have the form "<key_name>:<input_name>:<field_name>".
    Legacy keys such as "eval_shuffle" map to a key/field pair with no input
    name. Keys unrelated to input configs are reported as not valid input
    config keys. Malformed keys must raise ValueError.
    """
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().name = "eval_1"
    pipeline_config.eval_input_reader.add().name = "eval_2"
    _write_config(pipeline_config, pipeline_config_path)
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    specific_shuffle_update_key = "eval_input_configs:eval_2:shuffle"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(
            configs, specific_shuffle_update_key))
    self.assertTrue(is_valid_input_config_key)
    self.assertEqual(key_name, "eval_input_configs")
    self.assertEqual(input_name, "eval_2")
    self.assertEqual(field_name, "shuffle")

    legacy_shuffle_update_key = "eval_shuffle"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(configs,
                                                     legacy_shuffle_update_key))
    self.assertTrue(is_valid_input_config_key)
    self.assertEqual(key_name, "eval_input_configs")
    self.assertEqual(input_name, None)
    self.assertEqual(field_name, "shuffle")

    non_input_config_update_key = "label_map_path"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(
            configs, non_input_config_update_key))
    self.assertFalse(is_valid_input_config_key)
    self.assertEqual(key_name, None)
    self.assertEqual(input_name, None)
    self.assertEqual(field_name, "label_map_path")

    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # deprecated since Python 3.2 and removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                "Invalid key format when overriding configs."):
      config_util.check_and_parse_input_config_key(
          configs, "train_input_config:shuffle")

    with self.assertRaisesRegex(
        ValueError, "Invalid key_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "invalid_key_name:train_name:shuffle")

    with self.assertRaisesRegex(
        ValueError, "Invalid input_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "eval_input_configs:unknown_eval_name:shuffle")

    with self.assertRaisesRegex(
        ValueError, "Invalid field_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "eval_input_configs:eval_2:unknown_field_name")

  def testUpdateInputReaderConfigSuccess(self):
    """Tests updating a singular input config and a named list entry.

    The singular `train_input_config` is addressed with `input_name=None`.
    An entry of the `eval_input_configs` list is addressed by its `name`.
    """
    original_shuffle = False
    new_shuffle = True
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.shuffle = original_shuffle
    eval_input_reader = pipeline_config.eval_input_reader.add()
    eval_input_reader.name = "eval_1"
    eval_input_reader.shuffle = original_shuffle
    _write_config(pipeline_config, pipeline_config_path)
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    config_util.update_input_reader_config(
        configs,
        key_name="train_input_config",
        input_name=None,
        field_name="shuffle",
        value=new_shuffle)
    self.assertEqual(configs["train_input_config"].shuffle, new_shuffle)

    # Entries of the repeated eval input config are selected by their name.
    config_util.update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name="eval_1",
        field_name="shuffle",
        value=new_shuffle)
    self.assertEqual(configs["eval_input_configs"][0].shuffle, new_shuffle)

  def testUpdateInputReaderConfigErrors(self):
    """Tests the error paths of update_input_reader_config.

    Two eval readers share the name "same_eval_name". Updating that name must
    report a duplicate, and updating an unknown name must report it as not
    found. Updating the eval list without an input name must be rejected as
    an unknown override.
    """
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().name = "same_eval_name"
    pipeline_config.eval_input_reader.add().name = "same_eval_name"
    _write_config(pipeline_config, pipeline_config_path)
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # deprecated since Python 3.2 and removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                "Duplicate input name found when overriding."):
      config_util.update_input_reader_config(
          configs,
          key_name="eval_input_configs",
          input_name="same_eval_name",
          field_name="shuffle",
          value=False)

    with self.assertRaisesRegex(
        ValueError, "Input name name_not_exist not found when overriding."):
      config_util.update_input_reader_config(
          configs,
          key_name="eval_input_configs",
          input_name="name_not_exist",
          field_name="shuffle",
          value=False)

    with self.assertRaisesRegex(ValueError,
                                "Unknown input config overriding."):
      config_util.update_input_reader_config(
          configs,
          key_name="eval_input_configs",
          input_name=None,
          field_name="shuffle",
          value=False)

  def testOverWriteRetainOriginalImageAdditionalChannels(self):
    """Checks the `retain_original_image_additional_channels_in_eval` override.

    The eval config is written with the flag set to True and the keyword
    override must set it to False.
    """
    initially_retained = True
    wanted_retained = False

    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.eval_config.retain_original_image_additional_channels = (
        initially_retained)
    _write_config(pipeline, config_path)

    loaded = config_util.get_configs_from_pipeline_file(config_path)
    overrides = {
        "retain_original_image_additional_channels_in_eval": wanted_retained
    }
    merged = config_util.merge_external_params_with_configs(
        loaded, kwargs_dict=overrides)
    self.assertEqual(
        wanted_retained,
        merged["eval_config"].retain_original_image_additional_channels)

  def testUpdateNumClasses(self):
    """Checks that a `num_classes` override is applied to the configs dict.

    The merge's return value is not used: the assertion afterwards reads the
    same `configs` dict that was passed in.
    """
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.model.faster_rcnn.num_classes = 10

    _write_config(pipeline, config_path)

    configs = config_util.get_configs_from_pipeline_file(config_path)

    self.assertEqual(config_util.get_number_of_classes(configs["model"]), 10)

    config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"num_classes": 2})

    # Re-read configs["model"] after the merge rather than caching it.
    self.assertEqual(config_util.get_number_of_classes(configs["model"]), 2)

  def testRemoveUnecessaryEma(self):
    """Checks EMA-suffix stripping for names matching no_ema_collection.

    The two names that contain "/min" or "/max" and end in
    "/ExponentialMovingAverage" must lose that suffix. The remaining entries
    carry no such suffix and must pass through unchanged, values included.
    """
    patterns_without_ema = ["/min", "/max"]
    variables = {
        "expanded_conv_10/project/act_quant/min": 1,
        "FeatureExtractor/MobilenetV2_2/expanded_conv_5/expand/act_quant/min":
            2,
        "expanded_conv_10/expand/BatchNorm/gamma/min/ExponentialMovingAverage":
            3,
        "expanded_conv_3/depthwise/BatchNorm/beta/max/ExponentialMovingAverage":
            4,
        "BoxPredictor_1/ClassPredictor_depthwise/act_quant": 5,
    }
    expected = {
        "expanded_conv_10/project/act_quant/min": 1,
        "FeatureExtractor/MobilenetV2_2/expanded_conv_5/expand/act_quant/min":
            2,
        # The two names that ended in /ExponentialMovingAverage, stripped.
        "expanded_conv_10/expand/BatchNorm/gamma/min": 3,
        "expanded_conv_3/depthwise/BatchNorm/beta/max": 4,
        "BoxPredictor_1/ClassPredictor_depthwise/act_quant": 5,
    }
    actual = config_util.remove_unecessary_ema(variables, patterns_without_ema)
    self.assertEqual(expected, actual)