Exemplo n.º 1
0
  def _create_model(self,
                    interleaved=False,
                    apply_hard_mining=True,
                    normalize_loc_loss_by_codesize=False,
                    add_background_class=True,
                    random_example_sampling=False,
                    use_expected_classification_loss_under_sampling=False,
                    min_num_negative_samples=1,
                    desired_negative_sampling_ratio=3,
                    unroll_length=1):
    """Assembles an LSTMSSDMetaArch out of mock and fake components.

    The model is always constructed with `is_training=False` and with
    `add_summaries=False`.

    Args:
      interleaved: If truthy, the feature extractor is a
        FakeLSTMInterleavedFeatureExtractor; otherwise a
        FakeLSTMFeatureExtractor.
      apply_hard_mining: If truthy, a losses.HardExampleMiner is passed to the
        model; otherwise the model receives `hard_example_miner=None`.
      normalize_loc_loss_by_codesize: Accepted but not read by this method.
      add_background_class: Accepted but not read by this method.
      random_example_sampling: Accepted but not read by this method.
      use_expected_classification_loss_under_sampling: Accepted but not read
        by this method.
      min_num_negative_samples: Accepted but not read by this method.
      desired_negative_sampling_ratio: Accepted but not read by this method.
      unroll_length: Forwarded unchanged to LSTMSSDMetaArch.

    Returns:
      A tuple `(model, num_classes, num_anchors, code_size)`, where
      `num_classes` is NUM_CLASSES, `num_anchors` is the result of
      `num_anchors()` on the mock anchor generator, and `code_size` is 4.
    """
    class_count = NUM_CLASSES
    training_mode = False

    # Collaborators are created in a fixed order, one after another.
    anchor_gen = MockAnchorGenerator2x2()
    box_predictor = test_utils.MockBoxPredictor(training_mode, class_count)
    box_coder = test_utils.MockBoxCoder()
    extractor_cls = (
        FakeLSTMInterleavedFeatureExtractor
        if interleaved else FakeLSTMFeatureExtractor)
    feature_extractor = extractor_cls()
    matcher = test_utils.MockMatcher()
    similarity_calc = sim_calc.IouSimilarity()

    def _passthrough_resizer(image):
      """Returns the image via tf.identity together with tf.shape(image)."""
      return [tf.identity(image), tf.shape(image)]

    cls_loss = losses.WeightedSigmoidClassificationLoss()
    loc_loss = losses.WeightedSmoothL1LocalizationLoss()
    nms_fn = functools.partial(
        post_processing.batch_multiclass_non_max_suppression,
        score_thresh=-20.0,
        iou_thresh=1.0,
        max_size_per_class=5,
        max_total_size=MAX_TOTAL_NUM_BOXES)

    # With these settings the miner is expected to act as a no-op.
    miner = (
        losses.HardExampleMiner(num_hard_examples=None, iou_threshold=1.0)
        if apply_hard_mining else None)

    assigner = target_assigner.TargetAssigner(
        similarity_calc,
        matcher,
        box_coder,
        negative_class_weight=1.0)

    box_code_size = 4
    detection_model = lstm_ssd_meta_arch.LSTMSSDMetaArch(
        is_training=training_mode,
        anchor_generator=anchor_gen,
        box_predictor=box_predictor,
        box_coder=box_coder,
        feature_extractor=feature_extractor,
        encode_background_as_zeros=False,
        image_resizer_fn=_passthrough_resizer,
        non_max_suppression_fn=nms_fn,
        score_conversion_fn=tf.identity,
        classification_loss=cls_loss,
        localization_loss=loc_loss,
        classification_loss_weight=1.0,
        localization_loss_weight=1.0,
        normalize_loss_by_num_matches=False,
        hard_example_miner=miner,
        unroll_length=unroll_length,
        target_assigner_instance=assigner,
        add_summaries=False)
    return (detection_model, class_count, anchor_gen.num_anchors(),
            box_code_size)
Exemplo n.º 2
0
def _build_lstm_model(ssd_config, lstm_config, is_training):
  """Constructs an LSTMSSDMetaArch from SSD and LSTM configuration objects.

  Each sub-component (feature extractor, box coder, matcher, similarity
  calculator, box predictor, anchor generator, image resizer, post-processing
  and losses) is produced by its builder from the matching field of
  `ssd_config`, then everything is handed to the LSTMSSDMetaArch constructor.

  Args:
    ssd_config: Config object whose fields (`feature_extractor`, `box_coder`,
      `matcher`, `similarity_calculator`, `num_classes`, `box_predictor`,
      `anchor_generator`, `image_resizer`, `post_processing`, `loss`,
      `normalize_loss_by_num_matches`, `encode_background_as_zeros`,
      `negative_class_weight`) drive the build.
    lstm_config: Config object providing `train_unroll_length` and
      `eval_unroll_length`; it is also passed to the feature extractor
      builder.
    is_training: Selects `train_unroll_length` when truthy and
      `eval_unroll_length` otherwise; also forwarded to the feature extractor
      builder, the box predictor builder and the model.

  Returns:
    The constructed LSTMSSDMetaArch instance.

  Raises:
    ValueError: If no unroll length could be determined, i.e. the feature
      extractor type name does not contain 'lstm' or the selected unroll
      length is None. The individual builders may raise their own errors.
  """
  extractor = _build_lstm_feature_extractor(
      ssd_config.feature_extractor, is_training, lstm_config)

  coder = box_coder_builder.build(ssd_config.box_coder)
  box_matcher = matcher_builder.build(ssd_config.matcher)
  similarity = sim_calc.build(ssd_config.similarity_calculator)

  class_count = ssd_config.num_classes
  predictor = box_predictor_builder.build(
      hyperparams_builder.build,
      ssd_config.box_predictor,
      is_training,
      class_count)
  anchors = anchor_generator_builder.build(ssd_config.anchor_generator)
  resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
  nms_fn, score_fn = post_processing_builder.build(ssd_config.post_processing)

  # The loss builder yields seven values; the last two are not needed here.
  (cls_loss, loc_loss, cls_weight, loc_weight, example_miner, _,
   _) = losses_builder.build(ssd_config.loss)

  normalize_by_matches = ssd_config.normalize_loss_by_num_matches
  background_as_zeros = ssd_config.encode_background_as_zeros
  background_weight = ssd_config.negative_class_weight

  # An unroll length is only looked up for extractor types naming 'lstm'.
  unroll_length = None
  if 'lstm' in ssd_config.feature_extractor.type:
    unroll_length = (
        lstm_config.train_unroll_length
        if is_training else lstm_config.eval_unroll_length)
  if unroll_length is None:
    raise ValueError('No unroll length found in the config file')

  assigner = target_assigner.TargetAssigner(
      similarity,
      box_matcher,
      coder,
      negative_class_weight=background_weight)

  return lstm_ssd_meta_arch.LSTMSSDMetaArch(
      is_training=is_training,
      anchor_generator=anchors,
      box_predictor=predictor,
      box_coder=coder,
      feature_extractor=extractor,
      encode_background_as_zeros=background_as_zeros,
      image_resizer_fn=resizer_fn,
      non_max_suppression_fn=nms_fn,
      score_conversion_fn=score_fn,
      classification_loss=cls_loss,
      localization_loss=loc_loss,
      classification_loss_weight=cls_weight,
      localization_loss_weight=loc_weight,
      normalize_loss_by_num_matches=normalize_by_matches,
      hard_example_miner=example_miner,
      unroll_length=unroll_length,
      target_assigner_instance=assigner)