def test_subsample_selection_no_batch_size_static(self):
   """Expects `subsample` to raise ValueError when batch size is None.

   NOTE(review): `labels` is built as a rank-2 tensor and the sampler is
   created with default arguments (no `is_static` flag is passed), so
   confirm against the sampler implementation which condition actually
   triggers the ValueError.
   """
   label_tensor = tf.constant([[True, False, False]])
   indicator_tensor = tf.constant([True, False, True])
   balanced_sampler = (
       balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
   # Batch size is deliberately None.
   with self.assertRaises(ValueError):
     balanced_sampler.subsample(indicator_tensor, None, label_tensor)
    def _loss_weights(self, labels, is_fg):
        """Compute loss weights, optionally subsampled per row of `is_fg`.

        Foreground entries get weight 1 and all others get
        `self._negative_example_weight`. When `self._loss_minibatch_size` is
        positive, every row of `is_fg` is run through a
        `BalancedPositiveNegativeSampler` and the weights of entries that
        were not sampled are zeroed. The result is scaled by
        `self._loss_weight`.

        Args:
          labels: Not used by this method (the inner helper shadows it).
          is_fg: Tensor marking foreground entries; it is converted to float
            and mapped over with `tf.map_fn`.

        Returns:
          Tensor of loss weights.
        """
        # A non-positive balance fraction means "sample uniformly".
        sample_uniformly = self._loss_positive_balance_fraction <= 0
        positive_fraction = (
            1.0 if sample_uniformly else self._loss_positive_balance_fraction)

        loss_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=positive_fraction)

        def _sample_row(row_is_fg):
            # Every entry of the row is a sampling candidate.
            candidates = tf.ones_like(row_is_fg)
            # For uniform sampling all entries are presented as positives.
            row_labels = candidates if sample_uniformly else row_is_fg
            return loss_sampler.subsample(
                candidates, self._loss_minibatch_size, row_labels)

        fg_as_float = tf.to_float(is_fg)
        weights = (1 - fg_as_float) * self._negative_example_weight + fg_as_float

        # A minibatch size of 0 disables subsampling.
        if self._loss_minibatch_size > 0:
            sampled_mask = tf.map_fn(
                _sample_row,
                is_fg,
                dtype=tf.bool,
                parallel_iterations=self._parallel_iterations,
                back_prop=True)
            weights = weights * tf.to_float(sampled_mask)

        return weights * self._loss_weight
Exemplo n.º 3
0
 def build(self):
     """Create the helper objects used by the first-stage loss.

     Calls the parent `build`, then creates the proposal target assigner,
     a balanced positive/negative sampler (with `is_static=False`) and the
     localization and objectness loss objects.

     Returns:
       self.
     """
     super(FasterRCNNFirstStageLoss, self).build()

     proposal_assigner = target_assigner.create_target_assigner(
         'FasterRCNN', 'proposal')
     self._proposal_target_assigner = proposal_assigner

     sampler_module = balanced_positive_negative_sampler
     self._sampler = sampler_module.BalancedPositiveNegativeSampler(
         positive_fraction=self.positive_balance_fraction, is_static=False)

     self._localization_loss = od_losses.WeightedSmoothL1LocalizationLoss()
     self._objectness_loss = od_losses.WeightedSoftmaxClassificationLoss()
     return self
Exemplo n.º 4
0
def build(loss_config):
  """Build losses based on the config.

  Builds the classification and localization losses and, depending on the
  config, a hard example miner, a random example sampler and a function for
  computing expected loss weights.

  Args:
    loss_config: A loss config proto (queried with `HasField` and
      `WhichOneof`).

  Returns:
    A 7-tuple of:
      classification_loss: Classification loss object.
      localization_loss: Localization loss object.
      classification_weight: Classification loss weight.
      localization_weight: Localization loss weight.
      hard_example_miner: Hard example miner object, or None.
      random_example_sampler: BalancedPositiveNegativeSampler object, or
        None.
      expected_loss_weights_fn: None, or a `functools.partial` of the
        `ops.expected_classification_loss_by_*` function selected by
        `loss_config.expected_loss_weights`.

  Raises:
    ValueError: If hard_example_miner is used with sigmoid_focal_loss.
    ValueError: If random_example_sampler has a non-positive
      positive_sample_fraction.
    ValueError: If loss_config.expected_loss_weights is not one of the
      handled values.
  """
  classification_loss = _build_classification_loss(
      loss_config.classification_loss)
  localization_loss = _build_localization_loss(
      loss_config.localization_loss)
  classification_weight = loss_config.classification_weight
  localization_weight = loss_config.localization_weight

  hard_example_miner = None
  if loss_config.HasField('hard_example_miner'):
    if (loss_config.classification_loss.WhichOneof('classification_loss') ==
        'weighted_sigmoid_focal'):
      raise ValueError('HardExampleMiner should not be used with sigmoid focal '
                       'loss')
    hard_example_miner = build_hard_example_miner(
        loss_config.hard_example_miner,
        classification_weight,
        localization_weight)

  random_example_sampler = None
  if loss_config.HasField('random_example_sampler'):
    if loss_config.random_example_sampler.positive_sample_fraction <= 0:
      # The trailing space matters: adjacent literals are concatenated as-is.
      raise ValueError('RandomExampleSampler should not use non-positive '
                       'value as positive sample fraction.')
    random_example_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=loss_config.random_example_sampler.
        positive_sample_fraction)

  if loss_config.expected_loss_weights == loss_config.NONE:
    expected_loss_weights_fn = None
  elif loss_config.expected_loss_weights == loss_config.EXPECTED_SAMPLING:
    expected_loss_weights_fn = functools.partial(
        ops.expected_classification_loss_by_expected_sampling,
        min_num_negative_samples=loss_config.min_num_negative_samples,
        desired_negative_sampling_ratio=loss_config
        .desired_negative_sampling_ratio)
  elif (loss_config.expected_loss_weights == loss_config
        .REWEIGHTING_UNMATCHED_ANCHORS):
    expected_loss_weights_fn = functools.partial(
        ops.expected_classification_loss_by_reweighting_unmatched_anchors,
        min_num_negative_samples=loss_config.min_num_negative_samples,
        desired_negative_sampling_ratio=loss_config
        .desired_negative_sampling_ratio)
  else:
    raise ValueError('Not a valid value for expected_classification_loss.')

  return (classification_loss, localization_loss, classification_weight,
          localization_weight, hard_example_miner, random_example_sampler,
          expected_loss_weights_fn)
Exemplo n.º 5
0
def build(loss_config):
    """Build losses based on the config.

    Builds classification, localization losses and optionally a hard example
    miner and a random example sampler based on the config.

    Args:
      loss_config: A losses_pb2.Loss object.

    Returns:
      A 6-tuple of:
        classification_loss: Classification loss object.
        localization_loss: Localization loss object.
        classification_weight: Classification loss weight.
        localization_weight: Localization loss weight.
        hard_example_miner: Hard example miner object, or None.
        random_example_sampler: BalancedPositiveNegativeSampler object, or
          None.

    Raises:
      ValueError: If hard_example_miner is used with sigmoid_focal_loss.
      ValueError: If random_example_sampler is getting non-positive value as
        desired positive example fraction.
    """
    classification_loss = _build_classification_loss(
        loss_config.classification_loss)
    localization_loss = _build_localization_loss(loss_config.localization_loss)
    classification_weight = loss_config.classification_weight
    localization_weight = loss_config.localization_weight

    hard_example_miner = None
    if loss_config.HasField('hard_example_miner'):
        if (loss_config.classification_loss.WhichOneof('classification_loss')
                == 'weighted_sigmoid_focal'):
            raise ValueError(
                'HardExampleMiner should not be used with sigmoid focal '
                'loss')
        hard_example_miner = build_hard_example_miner(
            loss_config.hard_example_miner, classification_weight,
            localization_weight)

    random_example_sampler = None
    if loss_config.HasField('random_example_sampler'):
        if loss_config.random_example_sampler.positive_sample_fraction <= 0:
            # The trailing space matters: adjacent literals are concatenated
            # as-is.
            raise ValueError('RandomExampleSampler should not use non-positive '
                             'value as positive sample fraction.')
        random_example_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=loss_config.random_example_sampler.
            positive_sample_fraction)

    return (classification_loss, localization_loss, classification_weight,
            localization_weight, hard_example_miner, random_example_sampler)
  def _test_subsample_all_examples(self, is_static=False):
    """Samples 64 of 300 eligible examples and expects a 32/32 split.

    Args:
      is_static: Forwarded to the sampler's `is_static` argument.
    """
    # Permuting 0..299 and thresholding at > 200 marks 99 randomly placed
    # examples as positive; the indicator makes every example eligible.
    permuted_ids = np.random.permutation(300)
    indicator = tf.constant(np.ones(300) == 1)
    numpy_labels = (permuted_ids - 200) > 0
    labels = tf.constant(numpy_labels)

    sampler_module = balanced_positive_negative_sampler
    balanced_sampler = sampler_module.BalancedPositiveNegativeSampler(
        is_static=is_static)
    sampled_op = balanced_sampler.subsample(indicator, 64, labels)

    with self.test_session() as sess:
      is_sampled = sess.run(sampled_op)
      sampled_positives = np.logical_and(numpy_labels, is_sampled)
      sampled_negatives = np.logical_and(
          np.logical_not(numpy_labels), is_sampled)
      self.assertTrue(sum(is_sampled) == 64)
      self.assertTrue(sum(sampled_positives) == 32)
      self.assertTrue(sum(sampled_negatives) == 32)
Exemplo n.º 7
0
    def _subsample():
        """Sample indices with a balanced positive/negative sampler.

        Uses `balance_fraction`, `indicators`, `k`, `labels` and
        `pad_inds_with_resampling` from the enclosing scope.

        Returns:
          The result of a `tf.cond`: the indices of the sampled entries when
          exactly `k` entries were sampled, otherwise the output of
          `pad_inds_with_resampling` wrapped in a `tf.Print` warning.
        """
        is_sampled = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=balance_fraction).subsample(
                indicators, k, labels)
        n_sampled = tf.count_nonzero(is_sampled, dtype=tf.int32)

        def _exact_indices():
            # Positions of the sampled entries, with the static shape set to k.
            sampled_inds = tf.where(is_sampled)[:, 0]
            sampled_inds.set_shape(k)
            return sampled_inds

        def _padded_indices():
            # Fewer than k were sampled: pad by resampling and log the gap.
            padded_inds = pad_inds_with_resampling(is_sampled, k, n_sampled)
            warning = ['Warning: balance_sampler result padded', k - n_sampled]
            return tf.Print(padded_inds, warning)

        got_full_batch = tf.equal(n_sampled, k)
        return tf.cond(got_full_batch, _exact_indices, _padded_indices)
  def _test_subsample_selection(self, is_static=False):
    """Checks sampling when only some of the examples are eligible.

    Setup: 100 examples, the last 20 are positive, and the last 10 (all
    positive) are excluded by the indicator. With a batch size of 64 the
    test expects 10 positives and 54 negatives, none of them excluded.

    Args:
      is_static: Forwarded to the sampler's `is_static` argument.
    """
    example_ids = np.arange(100)
    numpy_indicator = example_ids < 90
    indicator = tf.constant(numpy_indicator)
    numpy_labels = (example_ids - 80) >= 0
    labels = tf.constant(numpy_labels)

    sampler_module = balanced_positive_negative_sampler
    balanced_sampler = sampler_module.BalancedPositiveNegativeSampler(
        is_static=is_static)
    sampled_op = balanced_sampler.subsample(indicator, 64, labels)

    with self.test_session() as sess:
      is_sampled = sess.run(sampled_op)
      sampled_positives = np.logical_and(numpy_labels, is_sampled)
      sampled_negatives = np.logical_and(
          np.logical_not(numpy_labels), is_sampled)
      self.assertTrue(sum(is_sampled) == 64)
      self.assertTrue(sum(sampled_positives) == 10)
      self.assertTrue(sum(sampled_negatives) == 54)
      # Nothing outside the indicator may be selected.
      self.assertAllEqual(
          is_sampled, np.logical_and(is_sampled, numpy_indicator))
  def _test_subsample_selection_larger_batch_size(self, is_static=False):
    """Checks sampling when fewer examples are eligible than the batch size.

    Setup: 100 examples, the last 50 are positive, and the last 40 (all
    positive) are excluded by the indicator, leaving 60 eligible examples
    for a batch size of 64. The test expects all 60 to be selected:
    10 positives and 50 negatives.

    Args:
      is_static: Forwarded to the sampler's `is_static` argument.
    """
    example_ids = np.arange(100)
    numpy_indicator = example_ids < 60
    indicator = tf.constant(numpy_indicator)
    numpy_labels = (example_ids - 50) >= 0
    labels = tf.constant(numpy_labels)

    sampler_module = balanced_positive_negative_sampler
    balanced_sampler = sampler_module.BalancedPositiveNegativeSampler(
        is_static=is_static)
    sampled_op = balanced_sampler.subsample(indicator, 64, labels)

    with self.test_session() as sess:
      is_sampled = sess.run(sampled_op)
      sampled_positives = np.logical_and(numpy_labels, is_sampled)
      sampled_negatives = np.logical_and(
          np.logical_not(numpy_labels), is_sampled)
      self.assertTrue(sum(is_sampled) == 60)
      self.assertTrue(sum(sampled_positives) == 10)
      self.assertTrue(sum(sampled_negatives) == 50)
      # Nothing outside the indicator may be selected.
      self.assertAllEqual(
          is_sampled, np.logical_and(is_sampled, numpy_indicator))
Exemplo n.º 10
0
def build(loss_config):
    """Build losses based on the config.

    Builds classification, localization losses and optionally a hard example
    miner, a random example sampler and an expected-loss-weights function
    based on the config.

    Args:
      loss_config: A losses_pb2.Loss object.

    Returns:
      A 7-tuple of:
        classification_loss: Classification loss object.
        localization_loss: Localization loss object.
        classification_weight: Classification loss weight.
        localization_weight: Localization loss weight.
        hard_example_miner: Hard example miner object, or None.
        random_example_sampler: BalancedPositiveNegativeSampler object, or
          None.
        expected_loss_weights_fn: None, or a `functools.partial` of the
          `ops.expected_classification_loss_by_*` function selected by
          `loss_config.expected_loss_weights`.

    Raises:
      ValueError: If hard_example_miner is used with sigmoid_focal_loss.
      ValueError: If random_example_sampler is getting non-positive value as
        desired positive example fraction.
      ValueError: If loss_config.expected_loss_weights is not one of the
        handled values.
    """
    classification_loss = _build_classification_loss(
        loss_config.classification_loss)
    localization_loss = _build_localization_loss(loss_config.localization_loss)
    classification_weight = loss_config.classification_weight
    localization_weight = loss_config.localization_weight

    hard_example_miner = None
    if loss_config.HasField('hard_example_miner'):
        if (loss_config.classification_loss.WhichOneof('classification_loss')
                == 'weighted_sigmoid_focal'):
            raise ValueError(
                'HardExampleMiner should not be used with sigmoid focal '
                'loss')
        hard_example_miner = build_hard_example_miner(
            loss_config.hard_example_miner, classification_weight,
            localization_weight)

    random_example_sampler = None
    if loss_config.HasField('random_example_sampler'):
        if loss_config.random_example_sampler.positive_sample_fraction <= 0:
            # The trailing space matters: adjacent literals are concatenated
            # as-is.
            raise ValueError('RandomExampleSampler should not use non-positive '
                             'value as positive sample fraction.')
        random_example_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=loss_config.random_example_sampler.
            positive_sample_fraction)

    if loss_config.expected_loss_weights == loss_config.NONE:
        expected_loss_weights_fn = None
    elif loss_config.expected_loss_weights == loss_config.EXPECTED_SAMPLING:
        expected_loss_weights_fn = functools.partial(
            ops.expected_classification_loss_by_expected_sampling,
            min_num_negative_samples=loss_config.min_num_negative_samples,
            desired_negative_sampling_ratio=loss_config.
            desired_negative_sampling_ratio)
    elif (loss_config.expected_loss_weights ==
          loss_config.REWEIGHTING_UNMATCHED_ANCHORS):
        expected_loss_weights_fn = functools.partial(
            ops.expected_classification_loss_by_reweighting_unmatched_anchors,
            min_num_negative_samples=loss_config.min_num_negative_samples,
            desired_negative_sampling_ratio=loss_config.
            desired_negative_sampling_ratio)
    else:
        raise ValueError('Not a valid value for expected_classification_loss.')

    return (classification_loss, localization_loss, classification_weight,
            localization_weight, hard_example_miner, random_example_sampler,
            expected_loss_weights_fn)
Exemplo n.º 11
0
    def __init__(self, desc):
        """Set up the Faster R-CNN components and hyper-parameters.

        Read from `desc`: the class count, the number of stages, the
        backbone, the first-stage anchor generator, the first-stage box
        predictor conv hyper-parameters, the second-stage box predictor
        (`mask_rcnn_box`), the second-stage post-processing and the image
        resizer. Every other setting is assigned a fixed value below.

        :param desc: config dict
        """
        super(FasterRCNN, self).__init__()

        self.num_classes = int(desc.num_classes)
        self.number_of_stages = int(desc.number_of_stages)

        # Feature-extractor backbone.
        self.feature_extractor = NetworkDesc(desc.backbone).to_model()

        # Anchor generator of the first stage.
        anchor_generator_desc = desc["first_stage_anchor_generator"]
        self.first_stage_anchor_generator = NetworkDesc(
            anchor_generator_desc).to_model()

        # Target assigner of the first stage.
        self.use_matmul_gather_in_matcher = False  # Default
        self.first_stage_target_assigner = (
            target_assigner.create_target_assigner(
                'FasterRCNN',
                'proposal',
                use_matmul_gather=self.use_matmul_gather_in_matcher))

        # Box predictor settings of the first stage.
        self.first_stage_box_predictor_arg_scope_fn = (
            scope_generator.get_hyper_params_scope(
                desc.first_stage_box_predictor_conv_hyperparams))
        self.first_stage_atrous_rate = 1  # Default: 1
        self.first_stage_box_predictor_kernel_size = 3  # Default
        self.first_stage_box_predictor_depth = 512  # Default
        self.first_stage_minibatch_size = 256  # Default

        # Sampler of the first stage.
        self.first_stage_positive_balance_fraction = 0.5  # Default
        self.use_static_balanced_label_sampler = False  # Default
        self.use_static_shapes = False  # Default
        # Static sampling needs both flags; the same value is reused for the
        # second-stage sampler below.
        static_sampling = (self.use_static_balanced_label_sampler
                           and self.use_static_shapes)
        self.first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=self.first_stage_positive_balance_fraction,
            is_static=static_sampling)

        # NMS of the first stage.
        self.first_stage_nms_score_threshold = 0.0
        self.first_stage_nms_iou_threshold = 0.7
        self.first_stage_max_proposals = 300
        self.use_partitioned_nms_in_first_stage = True  # Default
        self.use_combined_nms_in_first_stage = False  # Default
        first_stage_nms_kwargs = {
            'score_thresh': self.first_stage_nms_score_threshold,
            'iou_thresh': self.first_stage_nms_iou_threshold,
            'max_size_per_class': self.first_stage_max_proposals,
            'max_total_size': self.first_stage_max_proposals,
            'use_static_shapes': self.use_static_shapes,
            'use_partitioned_nms': self.use_partitioned_nms_in_first_stage,
            'use_combined_nms': self.use_combined_nms_in_first_stage,
        }
        self.first_stage_non_max_suppression_fn = functools.partial(
            post_processing.batch_multiclass_non_max_suppression,
            **first_stage_nms_kwargs)

        # Loss weights of the first stage (localization, objectness).
        self.first_stage_localization_loss_weight = 2.0
        self.first_stage_objectness_loss_weight = 1.0

        # Target assigner of the second stage.
        self.second_stage_target_assigner = (
            target_assigner.create_target_assigner(
                'FasterRCNN',
                'detection',
                use_matmul_gather=self.use_matmul_gather_in_matcher))

        # Sampler of the second stage.
        self.second_stage_batch_size = 64  # Default
        self.second_stage_balance_fraction = 0.25  # Default
        self.second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=self.second_stage_balance_fraction,
            is_static=static_sampling)

        # Box predictor of the second stage.
        self.second_stage_box_predictor = NetworkDesc(
            desc.mask_rcnn_box).to_model()

        # NMS and score conversion functions of the second stage.
        nms_fn, score_conversion_fn = (
            post_processing_util.get_post_processing_fn(
                desc.second_stage_post_processing))
        self.second_stage_non_max_suppression_fn = nms_fn
        self.second_stage_score_conversion_fn = score_conversion_fn

        # Loss weights of the second stage (mask, localization,
        # classification).
        self.second_stage_mask_prediction_loss_weight = 1.0  # default
        self.second_stage_localization_loss_weight = 2.0
        self.second_stage_classification_loss_weight = 1.0

        # Classification loss of the second stage.
        self.logit_scale = 1.0  # Default
        self.second_stage_classification_loss = (
            losses.WeightedSoftmaxClassificationLoss(
                logit_scale=self.logit_scale))

        self.hard_example_miner = None
        self.add_summaries = True

        # Crop-and-resize implementation.
        self.use_matmul_crop_and_resize = False  # Default
        if self.use_matmul_crop_and_resize:
            self.crop_and_resize_fn = (
                spatial_ops.multilevel_matmul_crop_and_resize)
        else:
            self.crop_and_resize_fn = spatial_ops.native_crop_and_resize

        self.clip_anchors_to_image = False  # Default
        self.resize_masks = True  # Default
        self.return_raw_detections_during_predict = False  # Default
        self.output_final_box_features = False  # Default

        # Image resizer function.
        self.image_resizer_fn = image_resizer_util.get_image_resizer(
            desc.image_resizer)

        self.initial_crop_size = 14
        self.maxpool_kernel_size = 2
        self.maxpool_stride = 2

        # The real model to be called; not built here.
        self.model = None
Exemplo n.º 12
0
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
    """Builds a Faster R-CNN or R-FCN detection model based on the model config.

    An `RFCNMetaArch` is built when the second stage box predictor produced
    from the config is an `RfcnBoxPredictor`; otherwise a
    `FasterRCNNMetaArch` is built.

    Args:
      frcnn_config: A faster_rcnn.proto object containing the config for the
        desired FasterRCNNMetaArch or RFCNMetaArch.
      is_training: True if this model is being built for training purposes.
      add_summaries: Whether to add tf summaries in the model.

    Returns:
      An RFCNMetaArch or FasterRCNNMetaArch based on the config.
    """
    # Arguments shared by both meta architectures. Entries are added in the
    # order in which the config is read and the builders are invoked.
    shared = {'is_training': is_training, 'add_summaries': add_summaries}

    num_classes = frcnn_config.num_classes
    shared['num_classes'] = num_classes
    shared['image_resizer_fn'] = image_resizer_builder.build(
        frcnn_config.image_resizer)
    shared['feature_extractor'] = _build_faster_rcnn_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        frcnn_config.inplace_batchnorm_update)
    shared['number_of_stages'] = frcnn_config.number_of_stages

    # First stage.
    shared['first_stage_anchor_generator'] = anchor_generator_builder.build(
        frcnn_config.first_stage_anchor_generator)
    shared['first_stage_target_assigner'] = (
        target_assigner.create_target_assigner(
            'FasterRCNN',
            'proposal',
            use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher))
    shared['first_stage_atrous_rate'] = frcnn_config.first_stage_atrous_rate
    shared['first_stage_box_predictor_arg_scope_fn'] = (
        hyperparams_builder.build(
            frcnn_config.first_stage_box_predictor_conv_hyperparams,
            is_training))
    shared['first_stage_box_predictor_kernel_size'] = (
        frcnn_config.first_stage_box_predictor_kernel_size)
    shared['first_stage_box_predictor_depth'] = (
        frcnn_config.first_stage_box_predictor_depth)
    shared['first_stage_minibatch_size'] = (
        frcnn_config.first_stage_minibatch_size)
    shared['first_stage_sampler'] = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
        is_static=frcnn_config.use_static_balanced_label_sampler)
    shared['first_stage_nms_score_threshold'] = (
        frcnn_config.first_stage_nms_score_threshold)
    shared['first_stage_nms_iou_threshold'] = (
        frcnn_config.first_stage_nms_iou_threshold)
    shared['first_stage_max_proposals'] = (
        frcnn_config.first_stage_max_proposals)
    shared['first_stage_localization_loss_weight'] = (
        frcnn_config.first_stage_localization_loss_weight)
    shared['first_stage_objectness_loss_weight'] = (
        frcnn_config.first_stage_objectness_loss_weight)

    # Only passed to FasterRCNNMetaArch.
    initial_crop_size = frcnn_config.initial_crop_size
    maxpool_kernel_size = frcnn_config.maxpool_kernel_size
    maxpool_stride = frcnn_config.maxpool_stride

    # Second stage.
    shared['second_stage_target_assigner'] = (
        target_assigner.create_target_assigner(
            'FasterRCNN',
            'detection',
            use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher))
    second_stage_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
    shared['second_stage_batch_size'] = frcnn_config.second_stage_batch_size
    shared['second_stage_sampler'] = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=frcnn_config.second_stage_balance_fraction,
        is_static=frcnn_config.use_static_balanced_label_sampler)

    nms_fn, score_conversion_fn = post_processing_builder.build(
        frcnn_config.second_stage_post_processing)
    shared['second_stage_non_max_suppression_fn'] = nms_fn
    shared['second_stage_score_conversion_fn'] = score_conversion_fn

    localization_loss_weight = (
        frcnn_config.second_stage_localization_loss_weight)
    shared['second_stage_localization_loss_weight'] = localization_loss_weight
    shared['second_stage_classification_loss'] = (
        losses_builder.build_faster_rcnn_classification_loss(
            frcnn_config.second_stage_classification_loss))
    classification_loss_weight = (
        frcnn_config.second_stage_classification_loss_weight)
    shared['second_stage_classification_loss_weight'] = (
        classification_loss_weight)
    # Only passed to FasterRCNNMetaArch.
    mask_prediction_loss_weight = (
        frcnn_config.second_stage_mask_prediction_loss_weight)

    # The hard example miner is optional.
    miner = None
    if frcnn_config.HasField('hard_example_miner'):
        miner = losses_builder.build_hard_example_miner(
            frcnn_config.hard_example_miner,
            classification_loss_weight,
            localization_loss_weight)
    shared['hard_example_miner'] = miner

    shared['use_matmul_crop_and_resize'] = (
        frcnn_config.use_matmul_crop_and_resize)
    shared['clip_anchors_to_image'] = frcnn_config.clip_anchors_to_image

    if isinstance(second_stage_box_predictor,
                  rfcn_box_predictor.RfcnBoxPredictor):
        return rfcn_meta_arch.RFCNMetaArch(
            second_stage_rfcn_box_predictor=second_stage_box_predictor,
            **shared)

    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=initial_crop_size,
        maxpool_kernel_size=maxpool_kernel_size,
        maxpool_stride=maxpool_stride,
        second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
        second_stage_mask_prediction_loss_weight=mask_prediction_loss_weight,
        **shared)
Exemplo n.º 13
0
def build(loss_config):
    """Build losses based on the config.

    Builds classification, localization losses and optionally a hard example
    miner, a random example sampler and a bbox-ignore-background mask based
    on the config.

    Args:
      loss_config: A losses_pb2.Loss object.

    Returns:
      A 7-tuple of:
        classification_loss: Classification loss object.
        localization_loss: Localization loss object.
        classification_weight: Classification loss weight.
        localization_weight: Localization loss weight.
        hard_example_miner: Hard example miner object, or None.
        random_example_sampler: BalancedPositiveNegativeSampler object, or
          None.
        bbox_ignore_background_mask: None, or an object whose
          `overlap_threshold` attribute holds the configured threshold.

    Raises:
      ValueError: If hard_example_miner is used with sigmoid_focal_loss.
      ValueError: If random_example_sampler is getting non-positive value as
        desired positive example fraction.
      ValueError: If the bbox_ignore_background_mask overlap threshold is
        below 0 or above 1.
    """
    classification_loss = _build_classification_loss(
        loss_config.classification_loss)
    localization_loss = _build_localization_loss(loss_config.localization_loss)
    classification_weight = loss_config.classification_weight
    localization_weight = loss_config.localization_weight

    hard_example_miner = None
    if loss_config.HasField('hard_example_miner'):
        if (loss_config.classification_loss.WhichOneof('classification_loss')
                == 'weighted_sigmoid_focal'):
            raise ValueError(
                'HardExampleMiner should not be used with sigmoid focal '
                'loss')
        hard_example_miner = build_hard_example_miner(
            loss_config.hard_example_miner, classification_weight,
            localization_weight)

    random_example_sampler = None
    if loss_config.HasField('random_example_sampler'):
        if loss_config.random_example_sampler.positive_sample_fraction <= 0:
            # The trailing space matters: adjacent literals are concatenated
            # as-is.
            raise ValueError('RandomExampleSampler should not use non-positive '
                             'value as positive sample fraction.')
        random_example_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=loss_config.random_example_sampler.
            positive_sample_fraction)

    bbox_ignore_background_mask = None
    if loss_config.HasField('bbox_ignore_background_mask'):

        class BboxIgnoreBackgroundMask():
            # Plain holder for the configured overlap threshold.
            def __init__(self, overlap_threshold):
                self.overlap_threshold = overlap_threshold

        overlap_threshold = loss_config.bbox_ignore_background_mask.overlap_threshold
        if overlap_threshold < 0 or overlap_threshold > 1:
            raise ValueError(
                'BboxIgnoreBackgroundMask should have threshold between '
                '0 and 1')
        bbox_ignore_background_mask = BboxIgnoreBackgroundMask(
            overlap_threshold)

    return (classification_loss, localization_loss, classification_weight,
            localization_weight, hard_example_miner, random_example_sampler,
            bbox_ignore_background_mask)
Exemplo n.º 14
0
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries, **kwargs):
    """Constructs a Faster R-CNN or R-FCN meta-architecture from its config.

    An `RFCNMetaArch` is returned when the second stage box predictor built
    from the config is an `RfcnBoxPredictor`; in every other case a
    `FasterRCNNMetaArch` is returned.

    Args:
      frcnn_config: A faster_rcnn.proto object containing the config for the
        desired FasterRCNNMetaArch or RFCNMetaArch.
      is_training: True if this model is being built for training purposes.
      add_summaries: Whether to add tf summaries in the model.
      **kwargs: Optional extras. Only the keys below are read, and each one is
        forwarded to the meta-architecture only when its value is truthy:
          'filter_fn_arg': dict of keyword arguments bound onto `filter_bbox`;
            the resulting partial is passed on as `filter_fn`.
          'rpn_type': passed through unchanged. The original notes list
            'cascade_rpn', 'orign_rpn' and 'without_rpn' as the expected
            values -- TODO confirm against the meta-architecture.
          'replace_rpn_arg': passed through unchanged. The original notes
            describe a dict with a 'type' string ('gt' or 'others') and a
            float 'scale', used when rpn_type == 'without_rpn' -- TODO confirm.

    Returns:
      An RFCNMetaArch or FasterRCNNMetaArch instance built from the config.

    Raises:
      ValueError: If first_stage_nms_iou_threshold is outside [0, 1.0], or if
        training is requested and second_stage_batch_size is greater than
        first_stage_max_proposals.
    """
    cfg = frcnn_config
    num_classes = cfg.num_classes
    resizer_fn = image_resizer_builder.build(cfg.image_resizer)

    extractor = _build_faster_rcnn_feature_extractor(
        cfg.feature_extractor,
        is_training,
        inplace_batchnorm_update=cfg.inplace_batchnorm_update)

    # ---- First stage (region proposal) components ----
    rpn_anchor_generator = anchor_generator_builder.build(
        cfg.first_stage_anchor_generator)
    rpn_target_assigner = target_assigner.create_target_assigner(
        'FasterRCNN', 'proposal',
        use_matmul_gather=cfg.use_matmul_gather_in_matcher)
    rpn_arg_scope_fn = hyperparams_builder.build(
        cfg.first_stage_box_predictor_conv_hyperparams, is_training)

    # Static shapes need the config flag plus either training mode or the
    # explicit eval opt-in.
    static_shapes = cfg.use_static_shapes and (
        cfg.use_static_shapes_for_eval or is_training)
    rpn_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=cfg.first_stage_positive_balance_fraction,
        is_static=(cfg.use_static_balanced_label_sampler and static_shapes))

    max_proposals = cfg.first_stage_max_proposals
    rpn_nms_iou = cfg.first_stage_nms_iou_threshold
    if rpn_nms_iou < 0 or rpn_nms_iou > 1.0:
        raise ValueError('iou_threshold not in [0, 1.0].')
    if is_training and cfg.second_stage_batch_size > max_proposals:
        raise ValueError('second_stage_batch_size should be no greater than '
                         'first_stage_max_proposals.')
    rpn_nms_fn = functools.partial(
        post_processing.batch_multiclass_non_max_suppression,
        score_thresh=cfg.first_stage_nms_score_threshold,
        iou_thresh=rpn_nms_iou,
        max_size_per_class=max_proposals,
        max_total_size=max_proposals,
        use_static_shapes=static_shapes)

    # ---- Second stage (box classifier) components ----
    detection_target_assigner = target_assigner.create_target_assigner(
        'FasterRCNN', 'detection',
        use_matmul_gather=cfg.use_matmul_gather_in_matcher)
    box_predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        cfg.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
    detection_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=cfg.second_stage_balance_fraction,
        is_static=(cfg.use_static_balanced_label_sampler and static_shapes))
    detection_nms_fn, score_conversion_fn = post_processing_builder.build(
        cfg.second_stage_post_processing)
    loc_loss_weight = cfg.second_stage_localization_loss_weight
    cls_loss = losses_builder.build_faster_rcnn_classification_loss(
        cfg.second_stage_classification_loss)
    cls_loss_weight = cfg.second_stage_classification_loss_weight

    miner = None
    if cfg.HasField('hard_example_miner'):
        miner = losses_builder.build_hard_example_miner(
            cfg.hard_example_miner, cls_loss_weight, loc_loss_weight)

    if cfg.use_matmul_crop_and_resize:
        crop_fn = ops.matmul_crop_and_resize
    else:
        crop_fn = ops.native_crop_and_resize

    # Arguments accepted by both meta-architectures.
    common_kwargs = dict(
        is_training=is_training,
        num_classes=num_classes,
        image_resizer_fn=resizer_fn,
        feature_extractor=extractor,
        number_of_stages=cfg.number_of_stages,
        first_stage_anchor_generator=rpn_anchor_generator,
        first_stage_target_assigner=rpn_target_assigner,
        first_stage_atrous_rate=cfg.first_stage_atrous_rate,
        first_stage_box_predictor_arg_scope_fn=rpn_arg_scope_fn,
        first_stage_box_predictor_kernel_size=(
            cfg.first_stage_box_predictor_kernel_size),
        first_stage_box_predictor_depth=cfg.first_stage_box_predictor_depth,
        first_stage_minibatch_size=cfg.first_stage_minibatch_size,
        first_stage_sampler=rpn_sampler,
        first_stage_non_max_suppression_fn=rpn_nms_fn,
        first_stage_max_proposals=max_proposals,
        first_stage_localization_loss_weight=(
            cfg.first_stage_localization_loss_weight),
        first_stage_objectness_loss_weight=(
            cfg.first_stage_objectness_loss_weight),
        second_stage_target_assigner=detection_target_assigner,
        second_stage_batch_size=cfg.second_stage_batch_size,
        second_stage_sampler=detection_sampler,
        second_stage_non_max_suppression_fn=detection_nms_fn,
        second_stage_score_conversion_fn=score_conversion_fn,
        second_stage_localization_loss_weight=loc_loss_weight,
        second_stage_classification_loss=cls_loss,
        second_stage_classification_loss_weight=cls_loss_weight,
        hard_example_miner=miner,
        add_summaries=add_summaries,
        crop_and_resize_fn=crop_fn,
        clip_anchors_to_image=cfg.clip_anchors_to_image,
        use_static_shapes=static_shapes,
        resize_masks=cfg.resize_masks)

    # Optional extras are only forwarded when the caller supplied a truthy
    # value for them.
    filter_fn_arg = kwargs.get('filter_fn_arg')
    if filter_fn_arg:
        common_kwargs['filter_fn'] = functools.partial(
            filter_bbox, **filter_fn_arg)
    for passthrough_key in ('rpn_type', 'replace_rpn_arg'):
        passthrough_value = kwargs.get(passthrough_key)
        if passthrough_value:
            common_kwargs[passthrough_key] = passthrough_value

    if isinstance(box_predictor, rfcn_box_predictor.RfcnBoxPredictor):
        return rfcn_meta_arch.RFCNMetaArch(
            second_stage_rfcn_box_predictor=box_predictor,
            **common_kwargs)
    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=cfg.initial_crop_size,
        maxpool_kernel_size=cfg.maxpool_kernel_size,
        maxpool_stride=cfg.maxpool_stride,
        second_stage_mask_rcnn_box_predictor=box_predictor,
        second_stage_mask_prediction_loss_weight=(
            cfg.second_stage_mask_prediction_loss_weight),
        **common_kwargs)
Exemplo n.º 15
0
  def _create_model(
      self,
      model_fn=ssd_meta_arch.SSDMetaArch,
      apply_hard_mining=True,
      normalize_loc_loss_by_codesize=False,
      add_background_class=True,
      random_example_sampling=False,
      expected_loss_weights=model_pb2.DetectionModel().ssd.loss.NONE,
      min_num_negative_samples=1,
      desired_negative_sampling_ratio=3,
      use_keras=False,
      predict_mask=False,
      use_static_shapes=False,
      nms_max_size_per_class=5,
      calibration_mapping_value=None,
      return_raw_detections_during_predict=False):
    """Assembles an SSD model out of mock and fake components for tests.

    The model is always built with is_training=False and a single class.

    Args:
      model_fn: Callable invoked with keyword arguments to create the model.
      apply_hard_mining: If True, a `losses.HardExampleMiner` is attached.
      normalize_loc_loss_by_codesize: Forwarded to `model_fn`.
      add_background_class: Forwarded to the mock box predictor and to
        `model_fn`.
      random_example_sampling: If True, a `BalancedPositiveNegativeSampler`
        with positive_fraction=0.5 is attached.
      expected_loss_weights: Only `ssd.loss.NONE` is accepted.
      min_num_negative_samples: Accepted but not read by this helper.
      desired_negative_sampling_ratio: Accepted but not read by this helper.
      use_keras: Selects the Keras mock box predictor and Keras fake feature
        extractor instead of the non-Keras ones.
      predict_mask: If True, `mask_prediction_fn` from a `MockMaskHead` is
        passed to `model_fn`.
      use_static_shapes: Forwarded to the non-max-suppression function.
      nms_max_size_per_class: Used as both the per-class and the total
        detection limit of the non-max-suppression function.
      calibration_mapping_value: If truthy, a two-point function-approximation
        calibration whose y values both equal this number wraps the score
        converter.
      return_raw_detections_during_predict: Forwarded to `model_fn`.

    Returns:
      A tuple of (model, num_classes, number of anchors reported by the mock
      anchor generator, code_size).

    Raises:
      ValueError: If `expected_loss_weights` is not `ssd.loss.NONE`.
    """
    is_training = False
    num_classes = 1
    anchor_generator = MockAnchorGenerator2x2()
    predictor_cls = (
        test_utils.MockKerasBoxPredictor if use_keras
        else test_utils.MockBoxPredictor)
    box_predictor = predictor_cls(
        is_training, num_classes, add_background_class=add_background_class)
    box_coder = test_utils.MockBoxCoder()
    extractor_cls = (
        FakeSSDKerasFeatureExtractor if use_keras else FakeSSDFeatureExtractor)
    feature_extractor = extractor_cls()
    matcher = test_utils.MockMatcher()
    similarity_calc = sim_calc.IouSimilarity()

    def image_resizer_fn(image):
      """Returns the image unchanged together with its shape tensor."""
      return [tf.identity(image), tf.shape(image)]

    cls_loss = losses.WeightedSigmoidClassificationLoss()
    loc_loss = losses.WeightedSmoothL1LocalizationLoss()
    nms_fn = functools.partial(
        post_processing.batch_multiclass_non_max_suppression,
        score_thresh=-20.0,
        iou_thresh=1.0,
        max_size_per_class=nms_max_size_per_class,
        max_total_size=nms_max_size_per_class,
        use_static_shapes=use_static_shapes)

    # Scores pass through untouched unless a calibration value was requested.
    score_converter = tf.identity
    calibration_config = calibration_pb2.CalibrationConfig()
    if calibration_mapping_value:
      calibration_template = """
      function_approximation {
        x_y_pairs {
            x_y_pair {
              x: 0.0
              y: %f
            }
            x_y_pair {
              x: 1.0
              y: %f
            }}}"""
      text_format.Merge(
          calibration_template % (calibration_mapping_value,
                                  calibration_mapping_value),
          calibration_config)
      score_converter = (
          post_processing_builder._build_calibrated_score_converter(  # pylint: disable=protected-access
              tf.identity, calibration_config))

    negative_class_weight = 1.0

    # This hard example miner is expected to be a no-op.
    miner = (
        losses.HardExampleMiner(num_hard_examples=None, iou_threshold=1.0)
        if apply_hard_mining else None)

    example_sampler = (
        sampler.BalancedPositiveNegativeSampler(positive_fraction=0.5)
        if random_example_sampling else None)

    assigner = target_assigner.TargetAssigner(
        similarity_calc,
        matcher,
        box_coder,
        negative_class_weight=negative_class_weight)

    # Only the NONE setting is supported by this helper.
    model_config = model_pb2.DetectionModel()
    if expected_loss_weights != model_config.ssd.loss.NONE:
      raise ValueError('Not a valid value for expected_loss_weights.')
    loss_weights_fn = None

    code_size = 4

    extra_model_kwargs = {}
    if predict_mask:
      extra_model_kwargs['mask_prediction_fn'] = (
          test_utils.MockMaskHead(num_classes=1).predict)

    model = model_fn(
        is_training=is_training,
        anchor_generator=anchor_generator,
        box_predictor=box_predictor,
        box_coder=box_coder,
        feature_extractor=feature_extractor,
        encode_background_as_zeros=False,
        image_resizer_fn=image_resizer_fn,
        non_max_suppression_fn=nms_fn,
        score_conversion_fn=score_converter,
        classification_loss=cls_loss,
        localization_loss=loc_loss,
        classification_loss_weight=1.0,
        localization_loss_weight=1.0,
        normalize_loss_by_num_matches=False,
        hard_example_miner=miner,
        target_assigner_instance=assigner,
        add_summaries=False,
        normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
        freeze_batchnorm=False,
        inplace_batchnorm_update=False,
        add_background_class=add_background_class,
        random_example_sampler=example_sampler,
        expected_loss_weights_fn=loss_weights_fn,
        return_raw_detections_during_predict=(
            return_raw_detections_during_predict),
        **extra_model_kwargs)
    return model, num_classes, anchor_generator.num_anchors(), code_size
Exemplo n.º 16
0
    def _create_model(self,
                      apply_hard_mining=True,
                      normalize_loc_loss_by_codesize=False,
                      add_background_class=True,
                      random_example_sampling=False):
        """Assembles an `SSDMetaArch` out of mock and fake components.

        The model is always built with is_training=False and a single class.

        Args:
          apply_hard_mining: If True, a `losses.HardExampleMiner` is attached.
          normalize_loc_loss_by_codesize: Forwarded to `SSDMetaArch`.
          add_background_class: Forwarded to `SSDMetaArch`.
          random_example_sampling: If True, a
            `BalancedPositiveNegativeSampler` with positive_fraction=0.5 is
            attached.

        Returns:
          A tuple of (model, num_classes, number of anchors reported by the
          mock anchor generator, code_size).
        """
        is_training = False
        num_classes = 1
        anchor_generator = MockAnchorGenerator2x2()
        box_predictor = test_utils.MockBoxPredictor(is_training, num_classes)
        box_coder = test_utils.MockBoxCoder()
        feature_extractor = FakeSSDFeatureExtractor()
        matcher = test_utils.MockMatcher()
        similarity_calc = sim_calc.IouSimilarity()
        encode_background_as_zeros = False

        def image_resizer_fn(image):
            """Returns the image unchanged together with its shape tensor."""
            return [tf.identity(image), tf.shape(image)]

        cls_loss = losses.WeightedSigmoidClassificationLoss()
        loc_loss = losses.WeightedSmoothL1LocalizationLoss()
        nms_fn = functools.partial(
            post_processing.batch_multiclass_non_max_suppression,
            score_thresh=-20.0,
            iou_thresh=1.0,
            max_size_per_class=5,
            max_total_size=5)
        cls_loss_weight = 1.0
        loc_loss_weight = 1.0
        negative_class_weight = 1.0
        normalize_loss_by_num_matches = False

        # This hard example miner is expected to be a no-op.
        miner = (
            losses.HardExampleMiner(num_hard_examples=None, iou_threshold=1.0)
            if apply_hard_mining else None)

        example_sampler = (
            sampler.BalancedPositiveNegativeSampler(positive_fraction=0.5)
            if random_example_sampling else None)

        code_size = 4
        # The leading arguments are positional, so their order must not change.
        model = ssd_meta_arch.SSDMetaArch(
            is_training,
            anchor_generator,
            box_predictor,
            box_coder,
            feature_extractor,
            matcher,
            similarity_calc,
            encode_background_as_zeros,
            negative_class_weight,
            image_resizer_fn,
            nms_fn,
            tf.identity,
            cls_loss,
            loc_loss,
            cls_loss_weight,
            loc_loss_weight,
            normalize_loss_by_num_matches,
            miner,
            add_summaries=False,
            normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
            freeze_batchnorm=False,
            inplace_batchnorm_update=False,
            add_background_class=add_background_class,
            random_example_sampler=example_sampler)
        num_anchors = anchor_generator.num_anchors()
        return model, num_classes, num_anchors, code_size
Exemplo n.º 17
0
def _build_faster_rcnn_model(frcnn_config,
                             is_training,
                             add_summaries,
                             meta_architecture='faster_rcnn'):
    """Builds a Faster R-CNN or R-FCN detection model based on the model config.

    Builds an R-FCN model if the second_stage_box_predictor in the config is of
    type `rfcn_box_predictor`; this check takes precedence over
    `meta_architecture`. Otherwise the class is selected by
    `meta_architecture`.

    Args:
      frcnn_config: A faster_rcnn.proto object containing the config for the
        desired FasterRCNNMetaArch or RFCNMetaArch.
      is_training: True if this model is being built for training purposes.
      add_summaries: Whether to add tf summaries in the model.
      meta_architecture: Which Faster R-CNN variant to instantiate when the
        box predictor is not an R-FCN predictor. One of 'faster_rcnn',
        'faster_rcnn_override_RPN' or 'faster_rcnn_rpn_blend'.

    Returns:
      The constructed meta-architecture instance.

    Raises:
      ValueError: If first_stage_nms_iou_threshold is outside [0, 1.0], if
        training is requested and second_stage_batch_size is greater than
        first_stage_max_proposals, or if `meta_architecture` is not one of
        the supported values (and the predictor is not an R-FCN predictor).
    """
    num_classes = frcnn_config.num_classes
    image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

    feature_extractor = _build_faster_rcnn_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        frcnn_config.inplace_batchnorm_update)

    number_of_stages = frcnn_config.number_of_stages
    first_stage_anchor_generator = anchor_generator_builder.build(
        frcnn_config.first_stage_anchor_generator)

    first_stage_target_assigner = target_assigner.create_target_assigner(
        'FasterRCNN',
        'proposal',
        use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
    first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
    first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
        frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
    first_stage_box_predictor_kernel_size = (
        frcnn_config.first_stage_box_predictor_kernel_size)
    first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
    first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
    # TODO(bhattad): When eval is supported using static shapes, add separate
    # use_static_shapes_for_trainig and use_static_shapes_for_evaluation.
    # Static shapes (and the static sampler) are only enabled while training.
    use_static_shapes = frcnn_config.use_static_shapes and is_training
    use_static_sampler = (
        frcnn_config.use_static_balanced_label_sampler and is_training)
    first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
        is_static=use_static_sampler)
    first_stage_max_proposals = frcnn_config.first_stage_max_proposals
    first_stage_proposals_path = frcnn_config.first_stage_proposals_path
    if (frcnn_config.first_stage_nms_iou_threshold < 0
            or frcnn_config.first_stage_nms_iou_threshold > 1.0):
        raise ValueError('iou_threshold not in [0, 1.0].')
    if (is_training and
            frcnn_config.second_stage_batch_size > first_stage_max_proposals):
        raise ValueError('second_stage_batch_size should be no greater than '
                         'first_stage_max_proposals.')
    # use_static_shapes already folds in is_training, so it is passed as is.
    first_stage_non_max_suppression_fn = functools.partial(
        post_processing.batch_multiclass_non_max_suppression,
        score_thresh=frcnn_config.first_stage_nms_score_threshold,
        iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
        max_size_per_class=first_stage_max_proposals,
        max_total_size=first_stage_max_proposals,
        use_static_shapes=use_static_shapes)
    first_stage_loc_loss_weight = (
        frcnn_config.first_stage_localization_loss_weight)
    first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

    initial_crop_size = frcnn_config.initial_crop_size
    maxpool_kernel_size = frcnn_config.maxpool_kernel_size
    maxpool_stride = frcnn_config.maxpool_stride

    second_stage_target_assigner = target_assigner.create_target_assigner(
        'FasterRCNN',
        'detection',
        use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher,
        iou_threshold=frcnn_config.second_stage_target_iou_threshold)
    second_stage_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
    second_stage_batch_size = frcnn_config.second_stage_batch_size
    second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=frcnn_config.second_stage_balance_fraction,
        is_static=use_static_sampler)
    (second_stage_non_max_suppression_fn,
     second_stage_score_conversion_fn) = post_processing_builder.build(
         frcnn_config.second_stage_post_processing)
    second_stage_localization_loss_weight = (
        frcnn_config.second_stage_localization_loss_weight)
    second_stage_classification_loss = (
        losses_builder.build_faster_rcnn_classification_loss(
            frcnn_config.second_stage_classification_loss))
    second_stage_classification_loss_weight = (
        frcnn_config.second_stage_classification_loss_weight)
    second_stage_mask_prediction_loss_weight = (
        frcnn_config.second_stage_mask_prediction_loss_weight)

    hard_example_miner = None
    if frcnn_config.HasField('hard_example_miner'):
        hard_example_miner = losses_builder.build_hard_example_miner(
            frcnn_config.hard_example_miner,
            second_stage_classification_loss_weight,
            second_stage_localization_loss_weight)

    crop_and_resize_fn = (ops.matmul_crop_and_resize
                          if frcnn_config.use_matmul_crop_and_resize else
                          ops.native_crop_and_resize)
    clip_anchors_to_image = frcnn_config.clip_anchors_to_image

    # Arguments shared by every meta-architecture constructed below.
    common_kwargs = {
        'is_training': is_training,
        'num_classes': num_classes,
        'image_resizer_fn': image_resizer_fn,
        'feature_extractor': feature_extractor,
        'number_of_stages': number_of_stages,
        'first_stage_anchor_generator': first_stage_anchor_generator,
        'first_stage_target_assigner': first_stage_target_assigner,
        'first_stage_atrous_rate': first_stage_atrous_rate,
        'first_stage_box_predictor_arg_scope_fn':
            first_stage_box_predictor_arg_scope_fn,
        'first_stage_box_predictor_kernel_size':
            first_stage_box_predictor_kernel_size,
        'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
        'first_stage_minibatch_size': first_stage_minibatch_size,
        'first_stage_sampler': first_stage_sampler,
        'first_stage_non_max_suppression_fn':
            first_stage_non_max_suppression_fn,
        'first_stage_max_proposals': first_stage_max_proposals,
        'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
        'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
        'second_stage_target_assigner': second_stage_target_assigner,
        'second_stage_batch_size': second_stage_batch_size,
        'second_stage_sampler': second_stage_sampler,
        'second_stage_non_max_suppression_fn':
            second_stage_non_max_suppression_fn,
        'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
        'second_stage_localization_loss_weight':
            second_stage_localization_loss_weight,
        'second_stage_classification_loss': second_stage_classification_loss,
        'second_stage_classification_loss_weight':
            second_stage_classification_loss_weight,
        'hard_example_miner': hard_example_miner,
        'add_summaries': add_summaries,
        'crop_and_resize_fn': crop_and_resize_fn,
        'clip_anchors_to_image': clip_anchors_to_image,
        'use_static_shapes': use_static_shapes,
        'resize_masks': frcnn_config.resize_masks
    }

    # An R-FCN predictor always selects RFCNMetaArch, whatever
    # meta_architecture says.
    if isinstance(second_stage_box_predictor,
                  rfcn_box_predictor.RfcnBoxPredictor):
        return rfcn_meta_arch.RFCNMetaArch(
            second_stage_rfcn_box_predictor=second_stage_box_predictor,
            **common_kwargs)
    elif meta_architecture == 'faster_rcnn':
        return faster_rcnn_meta_arch.FasterRCNNMetaArch(
            initial_crop_size=initial_crop_size,
            maxpool_kernel_size=maxpool_kernel_size,
            maxpool_stride=maxpool_stride,
            second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
            second_stage_mask_prediction_loss_weight=(
                second_stage_mask_prediction_loss_weight),
            **common_kwargs)
    elif meta_architecture == 'faster_rcnn_override_RPN':
        return faster_rcnn_meta_arch_override_RPN.FasterRCNNMetaArchOverrideRPN(
            initial_crop_size=initial_crop_size,
            maxpool_kernel_size=maxpool_kernel_size,
            maxpool_stride=maxpool_stride,
            first_stage_proposals_path=first_stage_proposals_path,
            second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
            second_stage_mask_prediction_loss_weight=(
                second_stage_mask_prediction_loss_weight),
            **common_kwargs)
    elif meta_architecture == 'faster_rcnn_rpn_blend':
        # The blend variant is given the raw NMS thresholds and a
        # matmul-crop flag in place of the prebuilt callables and the
        # static-shape / mask-resize options, which are removed here.
        common_kwargs['use_matmul_crop_and_resize'] = False
        common_kwargs['first_stage_nms_iou_threshold'] = (
            frcnn_config.first_stage_nms_iou_threshold)
        common_kwargs['first_stage_nms_score_threshold'] = (
            frcnn_config.first_stage_nms_score_threshold)
        common_kwargs.pop('crop_and_resize_fn')
        common_kwargs.pop('first_stage_non_max_suppression_fn')
        common_kwargs.pop('resize_masks')
        common_kwargs.pop('use_static_shapes')
        return faster_rcnn_meta_arch_rpn_blend.FasterRCNNMetaArchRPNBlend(
            initial_crop_size=initial_crop_size,
            maxpool_kernel_size=maxpool_kernel_size,
            maxpool_stride=maxpool_stride,
            first_stage_proposals_path=first_stage_proposals_path,
            second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
            second_stage_mask_prediction_loss_weight=(
                second_stage_mask_prediction_loss_weight),
            **common_kwargs)
    else:
        # Previously an unknown value fell through and returned None, which
        # only surfaced later as an unrelated error on the missing model.
        raise ValueError(
            'Unknown meta_architecture: {!r}. Expected one of '
            "'faster_rcnn', 'faster_rcnn_override_RPN', "
            "'faster_rcnn_rpn_blend'.".format(meta_architecture))
Exemplo n.º 18
0
    def _build_model(self,
                     is_training,
                     number_of_stages,
                     second_stage_batch_size,
                     first_stage_max_proposals=8,
                     num_classes=2,
                     hard_mining=False,
                     softmax_second_stage_classification_loss=True,
                     predict_masks=False,
                     pad_to_max_dimension=None,
                     masks_are_class_agnostic=False,
                     use_matmul_crop_and_resize=False,
                     clip_anchors_to_image=False,
                     use_matmul_gather_in_matcher=False,
                     use_static_shapes=False,
                     calibration_mapping_value=None,
                     share_box_across_classes=False,
                     return_raw_detections_during_predict=False):
        use_keras = tf_version.is_tf2()

        def image_resizer_fn(image, masks=None):
            """Fake image resizer function."""
            resized_inputs = []
            resized_image = tf.identity(image)
            if pad_to_max_dimension is not None:
                resized_image = tf.image.pad_to_bounding_box(
                    image, 0, 0, pad_to_max_dimension, pad_to_max_dimension)
            resized_inputs.append(resized_image)
            if masks is not None:
                resized_masks = tf.identity(masks)
                if pad_to_max_dimension is not None:
                    resized_masks = tf.image.pad_to_bounding_box(
                        tf.transpose(masks, [1, 2, 0]), 0, 0,
                        pad_to_max_dimension, pad_to_max_dimension)
                    resized_masks = tf.transpose(resized_masks, [2, 0, 1])
                resized_inputs.append(resized_masks)
            resized_inputs.append(tf.shape(image))
            return resized_inputs

        # anchors in this test are designed so that a subset of anchors are inside
        # the image and a subset of anchors are outside.
        first_stage_anchor_scales = (0.001, 0.005, 0.1)
        first_stage_anchor_aspect_ratios = (0.5, 1.0, 2.0)
        first_stage_anchor_strides = (1, 1)
        first_stage_anchor_generator = grid_anchor_generator.GridAnchorGenerator(
            first_stage_anchor_scales,
            first_stage_anchor_aspect_ratios,
            anchor_stride=first_stage_anchor_strides)
        first_stage_target_assigner = target_assigner.create_target_assigner(
            'FasterRCNN',
            'proposal',
            use_matmul_gather=use_matmul_gather_in_matcher)

        if use_keras:
            fake_feature_extractor = FakeFasterRCNNKerasFeatureExtractor()
        else:
            fake_feature_extractor = FakeFasterRCNNFeatureExtractor()

        first_stage_box_predictor_hyperparams_text_proto = """
      op: CONV
      activation: RELU
      regularizer {
        l2_regularizer {
          weight: 0.00004
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.03
        }
      }
    """
        if use_keras:
            first_stage_box_predictor_arg_scope_fn = (
                self._build_keras_layer_hyperparams(
                    first_stage_box_predictor_hyperparams_text_proto))
        else:
            first_stage_box_predictor_arg_scope_fn = (
                self._build_arg_scope_with_hyperparams(
                    first_stage_box_predictor_hyperparams_text_proto,
                    is_training))

        first_stage_box_predictor_kernel_size = 3
        first_stage_atrous_rate = 1
        first_stage_box_predictor_depth = 512
        first_stage_minibatch_size = 3
        first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=0.5, is_static=use_static_shapes)

        first_stage_nms_score_threshold = -1.0
        first_stage_nms_iou_threshold = 1.0
        first_stage_max_proposals = first_stage_max_proposals
        first_stage_non_max_suppression_fn = functools.partial(
            post_processing.batch_multiclass_non_max_suppression,
            score_thresh=first_stage_nms_score_threshold,
            iou_thresh=first_stage_nms_iou_threshold,
            max_size_per_class=first_stage_max_proposals,
            max_total_size=first_stage_max_proposals,
            use_static_shapes=use_static_shapes)

        first_stage_localization_loss_weight = 1.0
        first_stage_objectness_loss_weight = 1.0

        post_processing_config = post_processing_pb2.PostProcessing()
        post_processing_text_proto = """
      score_converter: IDENTITY
      batch_non_max_suppression {
        score_threshold: -20.0
        iou_threshold: 1.0
        max_detections_per_class: 5
        max_total_detections: 5
        use_static_shapes: """ + '{}'.format(use_static_shapes) + """
      }
    """
        if calibration_mapping_value:
            calibration_text_proto = """
      calibration_config {
        function_approximation {
          x_y_pairs {
            x_y_pair {
              x: 0.0
              y: %f
            }
            x_y_pair {
              x: 1.0
              y: %f
              }}}}""" % (calibration_mapping_value, calibration_mapping_value)
            post_processing_text_proto = (post_processing_text_proto + ' ' +
                                          calibration_text_proto)
        text_format.Merge(post_processing_text_proto, post_processing_config)
        second_stage_non_max_suppression_fn, second_stage_score_conversion_fn = (
            post_processing_builder.build(post_processing_config))

        second_stage_target_assigner = target_assigner.create_target_assigner(
            'FasterRCNN',
            'detection',
            use_matmul_gather=use_matmul_gather_in_matcher)
        second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
            positive_fraction=1.0, is_static=use_static_shapes)

        second_stage_localization_loss_weight = 1.0
        second_stage_classification_loss_weight = 1.0
        if softmax_second_stage_classification_loss:
            second_stage_classification_loss = (
                losses.WeightedSoftmaxClassificationLoss())
        else:
            second_stage_classification_loss = (
                losses.WeightedSigmoidClassificationLoss())

        hard_example_miner = None
        if hard_mining:
            hard_example_miner = losses.HardExampleMiner(
                num_hard_examples=1,
                iou_threshold=0.99,
                loss_type='both',
                cls_loss_weight=second_stage_classification_loss_weight,
                loc_loss_weight=second_stage_localization_loss_weight,
                max_negatives_per_positive=None)

        crop_and_resize_fn = (ops.matmul_crop_and_resize
                              if use_matmul_crop_and_resize else
                              ops.native_crop_and_resize)
        common_kwargs = {
            'is_training':
            is_training,
            'num_classes':
            num_classes,
            'image_resizer_fn':
            image_resizer_fn,
            'feature_extractor':
            fake_feature_extractor,
            'number_of_stages':
            number_of_stages,
            'first_stage_anchor_generator':
            first_stage_anchor_generator,
            'first_stage_target_assigner':
            first_stage_target_assigner,
            'first_stage_atrous_rate':
            first_stage_atrous_rate,
            'first_stage_box_predictor_arg_scope_fn':
            first_stage_box_predictor_arg_scope_fn,
            'first_stage_box_predictor_kernel_size':
            first_stage_box_predictor_kernel_size,
            'first_stage_box_predictor_depth':
            first_stage_box_predictor_depth,
            'first_stage_minibatch_size':
            first_stage_minibatch_size,
            'first_stage_sampler':
            first_stage_sampler,
            'first_stage_non_max_suppression_fn':
            first_stage_non_max_suppression_fn,
            'first_stage_max_proposals':
            first_stage_max_proposals,
            'first_stage_localization_loss_weight':
            first_stage_localization_loss_weight,
            'first_stage_objectness_loss_weight':
            first_stage_objectness_loss_weight,
            'second_stage_target_assigner':
            second_stage_target_assigner,
            'second_stage_batch_size':
            second_stage_batch_size,
            'second_stage_sampler':
            second_stage_sampler,
            'second_stage_non_max_suppression_fn':
            second_stage_non_max_suppression_fn,
            'second_stage_score_conversion_fn':
            second_stage_score_conversion_fn,
            'second_stage_localization_loss_weight':
            second_stage_localization_loss_weight,
            'second_stage_classification_loss_weight':
            second_stage_classification_loss_weight,
            'second_stage_classification_loss':
            second_stage_classification_loss,
            'hard_example_miner':
            hard_example_miner,
            'crop_and_resize_fn':
            crop_and_resize_fn,
            'clip_anchors_to_image':
            clip_anchors_to_image,
            'use_static_shapes':
            use_static_shapes,
            'resize_masks':
            True,
            'return_raw_detections_during_predict':
            return_raw_detections_during_predict
        }

        return self._get_model(
            self._get_second_stage_box_predictor(
                num_classes=num_classes,
                is_training=is_training,
                use_keras=use_keras,
                predict_masks=predict_masks,
                masks_are_class_agnostic=masks_are_class_agnostic,
                share_box_across_classes=share_box_across_classes),
            **common_kwargs)
Exemplo n.º 19
0
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
    """Constructs a two-stage detector (Faster R-CNN or R-FCN) from its config.

    The meta-architecture is selected by the type of the second-stage box
    predictor built from `frcnn_config.second_stage_box_predictor`: when it is
    an `RfcnBoxPredictor` or `RfcnKerasBoxPredictor` an `RFCNMetaArch` is
    returned, otherwise a `FasterRCNNMetaArch` is returned.

    Args:
      frcnn_config: A faster_rcnn.proto object containing the config for the
        desired FasterRCNNMetaArch or RFCNMetaArch.
      is_training: True if this model is being built for training purposes.
      add_summaries: Whether to add tf summaries in the model.

    Returns:
      An `RFCNMetaArch` or a `FasterRCNNMetaArch`, depending on the
      second-stage box predictor type.

    Raises:
      ValueError: If `frcnn_config.first_stage_nms_iou_threshold` is outside
        [0, 1.0], or if `is_training` is set and
        `frcnn_config.second_stage_batch_size` exceeds
        `frcnn_config.first_stage_max_proposals`.
    """
    num_classes = frcnn_config.num_classes
    image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

    # The feature extractor type decides whether the Keras or the slim code
    # path is used for every component built below.
    extractor_config = frcnn_config.feature_extractor
    is_keras = (
        extractor_config.type in FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP)
    build_extractor = (_build_faster_rcnn_keras_feature_extractor
                       if is_keras else _build_faster_rcnn_feature_extractor)
    feature_extractor = build_extractor(
        extractor_config,
        is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)

    # ---- First stage (region proposal) components ----
    anchor_generator = anchor_generator_builder.build(
        frcnn_config.first_stage_anchor_generator)

    use_matmul_gather = frcnn_config.use_matmul_gather_in_matcher
    proposal_target_assigner = target_assigner.create_target_assigner(
        'FasterRCNN', 'proposal', use_matmul_gather=use_matmul_gather)

    rpn_conv_hyperparams = (
        frcnn_config.first_stage_box_predictor_conv_hyperparams)
    if is_keras:
        rpn_arg_scope_fn = hyperparams_builder.KerasLayerHyperparams(
            rpn_conv_hyperparams)
    else:
        rpn_arg_scope_fn = hyperparams_builder.build(rpn_conv_hyperparams,
                                                     is_training)

    # Static shapes are used at eval time only when explicitly requested.
    use_static_shapes = frcnn_config.use_static_shapes and (
        frcnn_config.use_static_shapes_for_eval or is_training)
    # Both samplers share the same static/dynamic setting.
    static_sampling = (frcnn_config.use_static_balanced_label_sampler
                       and use_static_shapes)
    proposal_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
        is_static=static_sampling)

    max_proposals = frcnn_config.first_stage_max_proposals
    nms_iou_threshold = frcnn_config.first_stage_nms_iou_threshold
    if nms_iou_threshold < 0 or nms_iou_threshold > 1.0:
        raise ValueError('iou_threshold not in [0, 1.0].')
    if is_training and frcnn_config.second_stage_batch_size > max_proposals:
        raise ValueError('second_stage_batch_size should be no greater than '
                         'first_stage_max_proposals.')
    proposal_nms_fn = functools.partial(
        post_processing.batch_multiclass_non_max_suppression,
        score_thresh=frcnn_config.first_stage_nms_score_threshold,
        iou_thresh=nms_iou_threshold,
        max_size_per_class=max_proposals,
        max_total_size=max_proposals,
        use_static_shapes=use_static_shapes,
        use_partitioned_nms=frcnn_config.use_partitioned_nms_in_first_stage,
        use_combined_nms=frcnn_config.use_combined_nms_in_first_stage)

    # ---- Second stage (box classification) components ----
    detection_target_assigner = target_assigner.create_target_assigner(
        'FasterRCNN', 'detection', use_matmul_gather=use_matmul_gather)

    if is_keras:
        box_predictor = box_predictor_builder.build_keras(
            hyperparams_builder.KerasLayerHyperparams,
            freeze_batchnorm=False,
            inplace_batchnorm_update=False,
            num_predictions_per_location_list=[1],
            box_predictor_config=frcnn_config.second_stage_box_predictor,
            is_training=is_training,
            num_classes=num_classes)
    else:
        box_predictor = box_predictor_builder.build(
            hyperparams_builder.build,
            frcnn_config.second_stage_box_predictor,
            is_training=is_training,
            num_classes=num_classes)

    detection_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=frcnn_config.second_stage_balance_fraction,
        is_static=static_sampling)
    detection_nms_fn, score_conversion_fn = post_processing_builder.build(
        frcnn_config.second_stage_post_processing)
    classification_loss = (
        losses_builder.build_faster_rcnn_classification_loss(
            frcnn_config.second_stage_classification_loss))
    loc_loss_weight = frcnn_config.second_stage_localization_loss_weight
    cls_loss_weight = frcnn_config.second_stage_classification_loss_weight

    # Hard example mining is optional; it stays None unless configured.
    hard_example_miner = None
    if frcnn_config.HasField('hard_example_miner'):
        hard_example_miner = losses_builder.build_hard_example_miner(
            frcnn_config.hard_example_miner, cls_loss_weight, loc_loss_weight)

    if frcnn_config.use_matmul_crop_and_resize:
        crop_and_resize_fn = ops.matmul_crop_and_resize
    else:
        crop_and_resize_fn = ops.native_crop_and_resize

    # Constructor arguments shared by both meta-architectures.
    common_kwargs = dict(
        is_training=is_training,
        num_classes=num_classes,
        image_resizer_fn=image_resizer_fn,
        feature_extractor=feature_extractor,
        number_of_stages=frcnn_config.number_of_stages,
        first_stage_anchor_generator=anchor_generator,
        first_stage_target_assigner=proposal_target_assigner,
        first_stage_atrous_rate=frcnn_config.first_stage_atrous_rate,
        first_stage_box_predictor_arg_scope_fn=rpn_arg_scope_fn,
        first_stage_box_predictor_kernel_size=(
            frcnn_config.first_stage_box_predictor_kernel_size),
        first_stage_box_predictor_depth=(
            frcnn_config.first_stage_box_predictor_depth),
        first_stage_minibatch_size=frcnn_config.first_stage_minibatch_size,
        first_stage_sampler=proposal_sampler,
        first_stage_non_max_suppression_fn=proposal_nms_fn,
        first_stage_max_proposals=max_proposals,
        first_stage_localization_loss_weight=(
            frcnn_config.first_stage_localization_loss_weight),
        first_stage_objectness_loss_weight=(
            frcnn_config.first_stage_objectness_loss_weight),
        second_stage_target_assigner=detection_target_assigner,
        second_stage_batch_size=frcnn_config.second_stage_batch_size,
        second_stage_sampler=detection_sampler,
        second_stage_non_max_suppression_fn=detection_nms_fn,
        second_stage_score_conversion_fn=score_conversion_fn,
        second_stage_localization_loss_weight=loc_loss_weight,
        second_stage_classification_loss=classification_loss,
        second_stage_classification_loss_weight=cls_loss_weight,
        hard_example_miner=hard_example_miner,
        add_summaries=add_summaries,
        crop_and_resize_fn=crop_and_resize_fn,
        clip_anchors_to_image=frcnn_config.clip_anchors_to_image,
        use_static_shapes=use_static_shapes,
        resize_masks=frcnn_config.resize_masks,
        return_raw_detections_during_predict=(
            frcnn_config.return_raw_detections_during_predict))

    # An R-FCN style predictor selects the R-FCN meta-architecture.
    rfcn_predictor_types = (rfcn_box_predictor.RfcnBoxPredictor,
                            rfcn_keras_box_predictor.RfcnKerasBoxPredictor)
    if isinstance(box_predictor, rfcn_predictor_types):
        return rfcn_meta_arch.RFCNMetaArch(
            second_stage_rfcn_box_predictor=box_predictor, **common_kwargs)

    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=frcnn_config.initial_crop_size,
        maxpool_kernel_size=frcnn_config.maxpool_kernel_size,
        maxpool_stride=frcnn_config.maxpool_stride,
        second_stage_mask_rcnn_box_predictor=box_predictor,
        second_stage_mask_prediction_loss_weight=(
            frcnn_config.second_stage_mask_prediction_loss_weight),
        **common_kwargs)
    def _create_model(self,
                      model_fn=ssd_meta_arch.SSDMetaArch,
                      apply_hard_mining=True,
                      normalize_loc_loss_by_codesize=False,
                      add_background_class=True,
                      random_example_sampling=False,
                      weight_regression_loss_by_score=False,
                      use_expected_classification_loss_under_sampling=False,
                      min_num_negative_samples=1,
                      desired_negative_sampling_ratio=3,
                      use_keras=False,
                      predict_mask=False,
                      use_static_shapes=False,
                      nms_max_size_per_class=5):
        """Assembles a single-class, non-training model from mocks and fakes.

        Args:
          model_fn: Callable invoked with keyword arguments to construct the
            model; defaults to `ssd_meta_arch.SSDMetaArch`.
          apply_hard_mining: If True, a `HardExampleMiner` configured with
            `num_hard_examples=None` and `iou_threshold=1.0` is passed to the
            model (expected to be a no-op); otherwise None is passed.
          normalize_loc_loss_by_codesize: Forwarded unchanged to `model_fn`.
          add_background_class: Forwarded to the mock box predictor and to
            `model_fn`.
          random_example_sampling: If True, a
            `BalancedPositiveNegativeSampler` with `positive_fraction=0.5` is
            passed as the random example sampler; otherwise None is passed.
          weight_regression_loss_by_score: Forwarded to the `TargetAssigner`.
          use_expected_classification_loss_under_sampling: If True,
            `ops.expected_classification_loss_under_sampling` is partially
            applied with the two sampling arguments below and passed to the
            model; otherwise None is passed.
          min_num_negative_samples: Bound into the expected-loss function.
          desired_negative_sampling_ratio: Bound into the expected-loss
            function.
          use_keras: Selects the Keras variants of the mock box predictor and
            fake feature extractor.
          predict_mask: Forwarded to the mock box predictor.
          use_static_shapes: Forwarded to the non-max suppression function.
          nms_max_size_per_class: Used as both the per-class and the total
            size limit of the non-max suppression function.

        Returns:
          A tuple `(model, num_classes, num_anchors, code_size)`, where
          `num_anchors` is whatever the mock anchor generator's
          `num_anchors()` reports and `code_size` is 4.
        """
        is_training = False
        num_classes = 1
        code_size = 4

        anchor_generator = MockAnchorGenerator2x2()

        predictor_cls = (test_utils.MockKerasBoxPredictor
                         if use_keras else test_utils.MockBoxPredictor)
        box_predictor = predictor_cls(
            is_training,
            num_classes,
            add_background_class=add_background_class,
            predict_mask=predict_mask)

        box_coder = test_utils.MockBoxCoder()

        extractor_cls = (FakeSSDKerasFeatureExtractor
                         if use_keras else FakeSSDFeatureExtractor)
        feature_extractor = extractor_cls()

        matcher = test_utils.MockMatcher()
        similarity_calculator = sim_calc.IouSimilarity()

        def _passthrough_resizer(image):
            # Returns the image untouched together with its shape.
            return [tf.identity(image), tf.shape(image)]

        cls_loss = losses.WeightedSigmoidClassificationLoss()
        loc_loss = losses.WeightedSmoothL1LocalizationLoss()
        nms_fn = functools.partial(
            post_processing.batch_multiclass_non_max_suppression,
            score_thresh=-20.0,
            iou_thresh=1.0,
            max_size_per_class=nms_max_size_per_class,
            max_total_size=nms_max_size_per_class,
            use_static_shapes=use_static_shapes)

        # This hard example miner is expected to be a no-op.
        miner = (losses.HardExampleMiner(num_hard_examples=None,
                                         iou_threshold=1.0)
                 if apply_hard_mining else None)

        example_sampler = (
            sampler.BalancedPositiveNegativeSampler(positive_fraction=0.5)
            if random_example_sampling else None)

        assigner = target_assigner.TargetAssigner(
            similarity_calculator,
            matcher,
            box_coder,
            negative_class_weight=1.0,
            weight_regression_loss_by_score=weight_regression_loss_by_score)

        if use_expected_classification_loss_under_sampling:
            expected_loss_fn = functools.partial(
                ops.expected_classification_loss_under_sampling,
                min_num_negative_samples=min_num_negative_samples,
                desired_negative_sampling_ratio=(
                    desired_negative_sampling_ratio))
        else:
            expected_loss_fn = None

        model = model_fn(
            is_training=is_training,
            anchor_generator=anchor_generator,
            box_predictor=box_predictor,
            box_coder=box_coder,
            feature_extractor=feature_extractor,
            encode_background_as_zeros=False,
            image_resizer_fn=_passthrough_resizer,
            non_max_suppression_fn=nms_fn,
            score_conversion_fn=tf.identity,
            classification_loss=cls_loss,
            localization_loss=loc_loss,
            classification_loss_weight=1.0,
            localization_loss_weight=1.0,
            normalize_loss_by_num_matches=False,
            hard_example_miner=miner,
            target_assigner_instance=assigner,
            add_summaries=False,
            normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
            freeze_batchnorm=False,
            inplace_batchnorm_update=False,
            add_background_class=add_background_class,
            random_example_sampler=example_sampler,
            expected_classification_loss_under_sampling=expected_loss_fn)

        return (model, num_classes, anchor_generator.num_anchors(),
                code_size)