def test_get_expected_feature_map_shapes(self, bifpn_num_iterations):
    """Checks that BiFPN output feature maps have the expected shapes."""
    with test_utils.GraphContextOrNone() as g:
      image_features = [
          ('block3', tf.random_uniform([4, 16, 16, 256], dtype=tf.float32)),
          ('block4', tf.random_uniform([4, 8, 8, 256], dtype=tf.float32)),
          ('block5', tf.random_uniform([4, 4, 4, 256], dtype=tf.float32))
      ]
      bifpn_generator = bifpn_generators.KerasBiFpnFeatureMaps(
          bifpn_num_iterations=bifpn_num_iterations,
          bifpn_num_filters=128,
          fpn_min_level=3,
          fpn_max_level=7,
          input_max_level=5,
          is_training=True,
          conv_hyperparams=self._build_conv_hyperparams(),
          freeze_batchnorm=False)

    def graph_fn():
      return bifpn_generator(image_features)

    # Output keys are prefixed by the final iteration index; level 3 keeps the
    # input resolution (16x16) and each higher level halves it, all at 128
    # channels (the configured bifpn_num_filters).
    expected_feature_map_shapes = {}
    for level, spatial, direction in [(3, 16, 'dn'), (4, 8, 'up'),
                                      (5, 4, 'up'), (6, 2, 'up'),
                                      (7, 1, 'up')]:
      key = '{}_{}_lvl_{}'.format(bifpn_num_iterations, direction, level)
      expected_feature_map_shapes[key] = (4, spatial, spatial, 128)

    out_feature_maps = self.execute(graph_fn, [], g)
    out_feature_map_shapes = {
        key: value.shape for key, value in out_feature_maps.items()}
    self.assertDictEqual(expected_feature_map_shapes, out_feature_map_shapes)
 def build(self, input_shape):
      """Creates the BiFPN stage from stored config and marks the layer built.

      Args:
        input_shape: required by the Keras `build` contract; not used here
          beyond triggering construction.
      """
      bifpn_config = {
          'bifpn_num_iterations': self._bifpn_num_iterations,
          'bifpn_num_filters': self._bifpn_num_filters,
          'fpn_min_level': self._bifpn_min_level,
          'fpn_max_level': self._bifpn_max_level,
          'input_max_level': self._backbone_max_level,
          'is_training': self._is_training,
          'conv_hyperparams': self._conv_hyperparams,
          'freeze_batchnorm': self._freeze_batchnorm,
          'bifpn_node_params': self._bifpn_node_params,
      }
      self._bifpn_stage = bifpn_generators.KerasBiFpnFeatureMaps(
          name='bifpn', **bifpn_config)
      self.built = True
    def __init__(self,
                 is_training,
                 first_stage_features_stride,
                 conv_hyperparams,
                 min_depth,
                 bifpn_min_level,
                 bifpn_max_level,
                 bifpn_num_iterations,
                 bifpn_num_filters,
                 bifpn_combine_method,
                 efficientnet_version,
                 freeze_batchnorm,
                 pad_to_multiple=32,
                 weight_decay=0.0,
                 name=None):
        """Constructor.

    Args:
      is_training: See base class.
      first_stage_features_stride: See base class. Must be 8 or 16.
      conv_hyperparams: a `hyperparams_builder.KerasLayerHyperparams` object
        containing convolution hyperparameters for the layers added on top of
        the base feature extractor.
      min_depth: lower bound on the BiFPN feature map channel depth; the
        effective filter count is max(bifpn_num_filters, min_depth).
      bifpn_min_level: the highest resolution feature map level to use in the
        BiFPN feature maps.
      bifpn_max_level: the smallest resolution feature map level to construct
        or use in the BiFPN. Levels beyond what the EfficientNet backbone
        provides are created by the BiFPN generator.
      bifpn_num_iterations: number of repeated BiFPN top-down/bottom-up
        iterations.
      bifpn_num_filters: number of filters (channels) for the BiFPN feature
        maps (subject to the `min_depth` floor above).
      bifpn_combine_method: method used to combine inputs at BiFPN nodes;
        forwarded as the node-level `combine_method` parameter.
      efficientnet_version: EfficientNet backbone model name, passed to
        `efficientnet_model.EfficientNet.from_name`.
      freeze_batchnorm: See base class.
      pad_to_multiple: An integer multiple to pad input image.
      weight_decay: See base class.
      name: optional name for this feature extractor.
        NOTE(review): currently not forwarded to the base-class constructor —
        confirm whether that is intentional.

    Raises:
      ValueError: If `first_stage_features_stride` is not 8 or 16.
    """
        if first_stage_features_stride != 8 and first_stage_features_stride != 16:
            raise ValueError('`first_stage_features_stride` must be 8 or 16.')

        super(ProbabilisticTwoStageEfficientNetBiFPNKerasFeatureExtractor,
              self).__init__(
                  is_training=is_training,
                  first_stage_features_stride=first_stage_features_stride,
                  freeze_batchnorm=freeze_batchnorm,
                  weight_decay=weight_decay)

        self._bifpn_min_level = bifpn_min_level
        self._bifpn_max_level = bifpn_max_level
        self._bifpn_num_iterations = bifpn_num_iterations
        # Enforce the detector-wide minimum channel depth on the BiFPN.
        self._bifpn_num_filters = max(bifpn_num_filters, min_depth)
        self._bifpn_node_params = {'combine_method': bifpn_combine_method}
        self._efficientnet_version = efficientnet_version
        self._pad_to_multiple = pad_to_multiple
        self._conv_hyperparams = conv_hyperparams
        self._freeze_batchnorm = freeze_batchnorm

        self.classification_backbone = None

        logging.info('EfficientDet EfficientNet backbone version: %s',
                     self._efficientnet_version)
        logging.info('EfficientDet BiFPN num filters: %d',
                     self._bifpn_num_filters)
        logging.info('EfficientDet BiFPN num iterations: %d',
                     self._bifpn_num_iterations)

        # The backbone can only supply endpoints up to its deepest level; any
        # BiFPN levels above that are synthesized by the BiFPN generator.
        self._backbone_max_level = min(
            max(_EFFICIENTNET_LEVEL_ENDPOINTS.keys()), self._bifpn_max_level)
        self._output_layer_names = [
            _EFFICIENTNET_LEVEL_ENDPOINTS[i]
            for i in range(self._bifpn_min_level, self._backbone_max_level + 1)
        ]
        self._output_layer_alias = [
            'level_{}'.format(i)
            for i in range(self._bifpn_min_level, self._backbone_max_level + 1)
        ]

        # Build a multi-output backbone exposing one endpoint per BiFPN input
        # level. Rescaling is disabled; preprocessing happens elsewhere.
        efficientnet_base = efficientnet_model.EfficientNet.from_name(
            model_name=self._efficientnet_version,
            overrides={'rescale_input': False})
        outputs = [
            efficientnet_base.get_layer(output_layer_name).output
            for output_layer_name in self._output_layer_names
        ]
        self.classification_backbone = tf.keras.Model(
            inputs=efficientnet_base.inputs, outputs=outputs)
        self._bifpn_stage = bifpn_generators.KerasBiFpnFeatureMaps(
            bifpn_num_iterations=self._bifpn_num_iterations,
            bifpn_num_filters=self._bifpn_num_filters,
            fpn_min_level=self._bifpn_min_level,
            fpn_max_level=self._bifpn_max_level,
            input_max_level=self._backbone_max_level,
            is_training=self._is_training,
            conv_hyperparams=self._conv_hyperparams,
            freeze_batchnorm=self._freeze_batchnorm,
            bifpn_node_params=self._bifpn_node_params,
            name='bifpn')

        # First-stage (proposal) feature extractor: backbone + BiFPN.
        self.proposal_feature_extractor_model = _EfficientNetBiFPN(
            efficientnet_backbone=self.classification_backbone,
            bifpn_generator=self._bifpn_stage,
            pad_to_multiple=self._pad_to_multiple,
            output_layer_alias=self._output_layer_alias)

        # Second-stage box classifier head: downsampling separable conv
        # followed by dense layers, reshaped back to a 1x1 spatial map.
        self.box_classifier_model_conv = tf.keras.models.Sequential([
            tf.keras.layers.SeparableConv2D(filters=128,
                                            kernel_size=[3, 3],
                                            strides=2,
                                            activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(units=1024, activation='relu'),
            tf.keras.layers.Dense(units=512, activation='relu'),
            tf.keras.layers.Dense(units=256, activation='relu'),
            tf.keras.layers.Reshape((1, 1, 256))
        ])
  def test_get_expected_variable_names(self, bifpn_num_iterations):
    """Checks that the BiFPN generator creates the expected variable names."""
    with test_utils.GraphContextOrNone() as g:
      image_features = [
          ('block3', tf.random_uniform([4, 16, 16, 256], dtype=tf.float32)),
          ('block4', tf.random_uniform([4, 8, 8, 256], dtype=tf.float32)),
          ('block5', tf.random_uniform([4, 4, 4, 256], dtype=tf.float32))
      ]
      bifpn_generator = bifpn_generators.KerasBiFpnFeatureMaps(
          bifpn_num_iterations=bifpn_num_iterations,
          bifpn_num_filters=128,
          fpn_min_level=3,
          fpn_max_level=7,
          input_max_level=5,
          is_training=True,
          conv_hyperparams=self._build_conv_hyperparams(),
          freeze_batchnorm=False,
          name='bifpn')

    def graph_fn():
      feature_maps = bifpn_generator(image_features)
      return feature_maps

    self.execute(graph_fn, [], g)

    # 1x1 pre-sample convolutions that project backbone inputs to the BiFPN
    # filter count; each contributes a bias and a kernel variable.
    expected_variables = []
    for node_idx, node_desc, input_lvl in [
        (0, '0_up_lvl_6', 5), (3, '1_dn_lvl_5', 5), (4, '1_dn_lvl_4', 4),
        (5, '1_dn_lvl_3', 3), (6, '1_up_lvl_4', 4), (7, '1_up_lvl_5', 5)]:
      for weight_name in ('bias', 'kernel'):
        expected_variables.append(
            'bifpn/node_{:02}/{}/input_0_up_lvl_{}/1x1_pre_sample/conv/{}'
            .format(node_idx, node_desc, input_lvl, weight_name))

    # Per-iteration combine/post-combine nodes: a top-down pass over levels
    # 6..3 followed by a bottom-up pass over levels 4..7. Each node owns
    # combine weights plus a separable conv's bias and two kernels.
    node_descriptions = ['dn_lvl_6', 'dn_lvl_5', 'dn_lvl_4', 'dn_lvl_3',
                         'up_lvl_4', 'up_lvl_5', 'up_lvl_6', 'up_lvl_7']
    variable_suffixes = ['combine/bifpn_combine_weights',
                         'post_combine/separable_conv/bias',
                         'post_combine/separable_conv/depthwise_kernel',
                         'post_combine/separable_conv/pointwise_kernel']
    node_index = 2
    for iteration in range(1, bifpn_num_iterations + 1):
      for description in node_descriptions:
        for suffix in variable_suffixes:
          expected_variables.append(
              'bifpn/node_{:02}/{}_{}/{}'.format(
                  node_index, iteration, description, suffix))
        node_index += 1

    actual_variable_set = set(
        var.name.split(':')[0] for var in bifpn_generator.variables)
    self.assertSetEqual(set(expected_variables), actual_variable_set)