def _build_conv_hyperparams(self, add_batch_norm=True):
  """Returns Keras conv hyperparams with SWISH activation and a forced bias.

  The config uses an L2 regularizer (weight 0.0004) and a truncated-normal
  initializer (stddev 0.03, mean 0.0).

  Args:
    add_batch_norm: When True, a `batch_norm` section (scale, decay 0.99,
      epsilon 0.001) is appended to the text proto before parsing.

  Returns:
    A `hyperparams_builder.KerasLayerHyperparams` built from the parsed proto.
  """
  proto_text = """
    force_use_bias: true
    activation: SWISH
    regularizer {
      l2_regularizer {
        weight: 0.0004
      }
    }
    initializer {
      truncated_normal_initializer {
        stddev: 0.03
        mean: 0.0
      }
    }
  """
  if add_batch_norm:
    proto_text += """
      batch_norm {
        scale: true,
        decay: 0.99,
        epsilon: 0.001,
      }
    """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  return hyperparams_builder.KerasLayerHyperparams(parsed_proto)
def test_force_use_bias_if_batch_norm_center_keras(self):
  """With force_use_bias set, use_bias stays True even for a centered BN."""
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
    batch_norm {
      decay: 0.7
      center: true
      scale: true
      epsilon: 0.03
      train: true
    }
    force_use_bias: true
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  self.assertTrue(config.use_batch_norm())
  bn_params = config.batch_norm_params()
  self.assertTrue(bn_params['center'])
  self.assertTrue(bn_params['scale'])
  layer_params = config.params()
  self.assertTrue(layer_params['use_bias'])
def test_return_non_default_batch_norm_params_keras_override(self):
  """A momentum override passed to batch_norm_params() takes effect.

  The other configured values (epsilon, center, scale) are still reported
  as configured.
  """
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
    batch_norm {
      decay: 0.7
      center: false
      scale: true
      epsilon: 0.03
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  self.assertTrue(config.use_batch_norm())
  bn_params = config.batch_norm_params(momentum=0.4)
  self.assertAlmostEqual(bn_params['momentum'], 0.4)
  self.assertAlmostEqual(bn_params['epsilon'], 0.03)
  self.assertFalse(bn_params['center'])
  self.assertTrue(bn_params['scale'])
def _build_faster_rcnn_keras_feature_extractor(feature_extractor_config, is_training, inplace_batchnorm_update=False): """Builds a faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor from config. Args: feature_extractor_config: A FasterRcnnFeatureExtractor proto config from faster_rcnn.proto. is_training: True if this feature extractor is being built for training. inplace_batchnorm_update: Whether to update batch_norm inplace during training. This is required for batch norm to work correctly on TPUs. When this is false, user must add a control dependency on tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch norm moving average parameters. Returns: faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor based on config. Raises: ValueError: On invalid feature extractor type. """ if inplace_batchnorm_update: raise ValueError('inplace batchnorm updates not supported.') feature_type = feature_extractor_config.type first_stage_features_stride = ( feature_extractor_config.first_stage_features_stride) batch_norm_trainable = feature_extractor_config.batch_norm_trainable if feature_type not in FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP: raise ValueError( 'Unknown Faster R-CNN feature_extractor: {}'.format(feature_type)) feature_extractor_class = FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[ feature_type] kwargs = {} if feature_extractor_config.HasField('conv_hyperparams'): kwargs.update({ 'conv_hyperparams': hyperparams_builder.KerasLayerHyperparams( feature_extractor_config.conv_hyperparams), 'override_base_feature_extractor_hyperparams': feature_extractor_config. 
override_base_feature_extractor_hyperparams }) if feature_extractor_config.HasField('fpn'): kwargs.update({ 'fpn_min_level': feature_extractor_config.fpn.min_level, 'fpn_max_level': feature_extractor_config.fpn.max_level, 'additional_layer_depth': feature_extractor_config.fpn.additional_layer_depth, }) return feature_extractor_class(is_training, first_stage_features_stride, batch_norm_trainable, **kwargs)
def test_return_undefined_regularizer_weight_keras(self):
  """get_regularizer_weight() is None when no regularizer is configured."""
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Parse(
      """
        initializer {
          truncated_normal_initializer {
          }
        }
      """, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  self.assertIsNone(config.get_regularizer_weight())
def _build_conv_hyperparams(self):
  """Returns KerasLayerHyperparams for a minimal conv config.

  The text proto declares an L2 regularizer and a truncated-normal initializer
  with no fields set; no activation or batch_norm section is given.
  """
  minimal_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(minimal_text, parsed_proto)
  return hyperparams_builder.KerasLayerHyperparams(parsed_proto)
def _build_fc_hyperparams(self, op_type=hyperparams_pb2.Hyperparams.FC):
  """Returns KerasLayerHyperparams with activation NONE and the given op.

  Args:
    op_type: Value assigned to the proto's `op` field after parsing the text
      config; defaults to `hyperparams_pb2.Hyperparams.FC`.
  """
  fc_text = """
    activation: NONE
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
  """
  fc_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(fc_text, fc_proto)
  # The op is set programmatically so callers can choose it per test.
  fc_proto.op = op_type
  return hyperparams_builder.KerasLayerHyperparams(fc_proto)
def test_use_relu_6_activation_keras(self):
  """RELU_6 is exposed only when the activation is explicitly requested.

  `params()` leaves 'activation' as None by default. It is populated only
  when `include_activation=True` is passed.
  """
  conv_hyperparams_text_proto = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
    activation: RELU_6
  """
  conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
  keras_config = hyperparams_builder.KerasLayerHyperparams(
      conv_hyperparams_proto)
  # By default the activation is stripped so it can be applied separately.
  self.assertIsNone(keras_config.params()['activation'])
  self.assertEqual(
      keras_config.params(include_activation=True)['activation'], tf.nn.relu6)
def test_do_not_use_batch_norm_if_default_keras(self):
  """Without a batch_norm section, batch norm is off and its params empty."""
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  self.assertFalse(config.use_batch_norm())
  self.assertEqual(config.batch_norm_params(), {})
def _build_conv_hyperparams(self):
  """Returns KerasLayerHyperparams with RELU_6 and an unscaled batch norm.

  The regularizer (L2) and initializer (truncated normal) sections are
  declared with no fields set.
  """
  relu6_bn_text = """
    activation: RELU_6
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
    batch_norm {
      scale: false
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(relu6_bn_text, parsed_proto)
  return hyperparams_builder.KerasLayerHyperparams(parsed_proto)
def test_return_l2_regularizer_weight_keras(self):
  """A configured L2 weight of 0.5 is reported as 0.25."""
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Parse(
      """
        regularizer {
          l2_regularizer {
            weight: 0.5
          }
        }
        initializer {
          truncated_normal_initializer {
          }
        }
      """, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  reported_weight = config.get_regularizer_weight()
  self.assertAlmostEqual(reported_weight, 0.25)
def test_variance_in_range_with_random_normal_initializer_keras(self):
  """A random-normal initializer with stddev 0.8 yields variance ~0.64."""
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      random_normal_initializer {
        mean: 0.0
        stddev: 0.8
      }
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  kernel_init = config.params()['kernel_initializer']
  self._assert_variance_in_range(
      kernel_init, shape=[100, 40], variance=0.64, tol=1e-1)
def test_keras_initializer_by_name(self):
  """A Keras initializer named in the config is passed through as a string.

  A Conv2D built from those params resolves the string to the matching
  initializer type.
  """
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      keras_initializer_by_name: "glorot_uniform"
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Parse(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  configured_init = config.params()['kernel_initializer']
  layer = tf.keras.layers.Conv2D(filters=16, kernel_size=3, **config.params())

  self.assertEqual(configured_init, 'glorot_uniform')
  expected_type = type(tf.keras.initializers.get('glorot_uniform'))
  self.assertIsInstance(layer.kernel_initializer, expected_type)
def test_return_l2_regularizer_weights_keras(self):
  """The L2 penalty for weight 0.42 equals 0.42 * sum(w**2) / 2 (eager)."""
  proto_text = """
    regularizer {
      l2_regularizer {
        weight: 0.42
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  l2_fn = config.params()['kernel_regularizer']
  sample = np.array([1., -1, 4., 2.])
  penalty = l2_fn(tf.constant(sample)).numpy()
  expected = np.power(sample, 2).sum() / 2.0 * 0.42
  self.assertAllClose(expected, penalty)
def test_do_not_use_batch_norm_if_default_keras(self):
  """Without a batch_norm section, batch norm is disabled.

  batch_norm_params() is then empty and build_batch_norm() returns a
  Lambda layer.
  """
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  self.assertFalse(config.use_batch_norm())
  self.assertEqual(config.batch_norm_params(), {})

  # With batch norm disabled, the builder falls back to an identity Lambda.
  fallback_layer = config.build_batch_norm()
  self.assertIsInstance(fallback_layer, tf.keras.layers.Lambda)
def test_return_l1_regularized_weights_keras(self):
  """The L1 penalty for weight 0.5 equals 0.5 * sum(|w|), via a session."""
  proto_text = """
    regularizer {
      l1_regularizer {
        weight: 0.5
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  l1_fn = config.params()['kernel_regularizer']
  sample = np.array([1., -1, 4., 2.])
  with self.test_session() as session:
    penalty = session.run(l1_fn(tf.constant(sample)))
  self.assertAllClose(np.abs(sample).sum() * 0.5, penalty)
def test_use_relu_6_activation_keras(self):
  """RELU_6 is omitted from params() unless include_activation=True.

  build_activation_layer() wraps tf.nn.relu6 in a Lambda layer.
  """
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
    activation: RELU_6
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  self.assertEqual(config.params()['activation'], None)
  with_activation = config.params(include_activation=True)
  self.assertEqual(with_activation['activation'], tf.nn.relu6)

  act_layer = config.build_activation_layer()
  self.assertIsInstance(act_layer, tf.keras.layers.Lambda)
  self.assertEqual(act_layer.function, tf.nn.relu6)
def test_variance_in_range_with_variance_scaling_initializer_uniform_keras(
    self):
  """Uniform FAN_IN variance scaling (factor 2) gives variance ~2/fan_in.

  The sampled shape is [100, 40], so fan_in is 100.
  """
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      variance_scaling_initializer {
        factor: 2.0
        mode: FAN_IN
        uniform: true
      }
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  kernel_init = config.params()['kernel_initializer']
  self._assert_variance_in_range(
      kernel_init, shape=[100, 40], variance=2. / 100.)
def test_use_none_activation_keras(self):
  """With activation NONE, params() reports None with or without the flag.

  The activation layer is an identity Lambda.
  """
  proto_text = """
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
      }
    }
    activation: NONE
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  config = hyperparams_builder.KerasLayerHyperparams(parsed_proto)

  self.assertIsNone(config.params()['activation'])
  with_activation = config.params(include_activation=True)
  self.assertIsNone(with_activation['activation'])

  act_layer = config.build_activation_layer()
  self.assertIsInstance(act_layer, tf.keras.layers.Lambda)
  self.assertEqual(act_layer.function, tf.identity)
def _build_conv_hyperparams(self):
  """Returns KerasLayerHyperparams with RELU_6 and a scaled batch norm.

  Uses L2 weight 0.0004, a truncated-normal initializer (stddev 0.03,
  mean 0.0), and batch norm with decay 0.997 and epsilon 0.001.
  """
  relu6_text = """
    activation: RELU_6,
    regularizer {
      l2_regularizer {
        weight: 0.0004
      }
    }
    initializer {
      truncated_normal_initializer {
        stddev: 0.03
        mean: 0.0
      }
    }
    batch_norm {
      scale: true,
      decay: 0.997,
      epsilon: 0.001,
    }
  """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(relu6_text, parsed_proto)
  return hyperparams_builder.KerasLayerHyperparams(parsed_proto)
def _build_conv_hyperparams(self, add_batch_norm=True):
  """Returns Keras conv hyperparams with RELU_6 activation.

  The initializer is truncated normal (stddev 0.01, mean 0.0).

  Args:
    add_batch_norm: When True, appends a `batch_norm { train: true }` section
      to the text proto before parsing.

  Returns:
    A `hyperparams_builder.KerasLayerHyperparams` built from the parsed proto.
  """
  proto_text = """
    activation: RELU_6
    regularizer {
      l2_regularizer {
      }
    }
    initializer {
      truncated_normal_initializer {
        stddev: 0.01
        mean: 0.0
      }
    }
  """
  if add_batch_norm:
    proto_text += """
      batch_norm {
        train: true,
      }
    """
  parsed_proto = hyperparams_pb2.Hyperparams()
  text_format.Merge(proto_text, parsed_proto)
  return hyperparams_builder.KerasLayerHyperparams(parsed_proto)
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
  """Builds a Faster R-CNN or R-FCN detection model based on the model config.

  Builds an R-FCN model if the box predictor built from
  `frcnn_config.second_stage_box_predictor` is an `RfcnBoxPredictor` or
  `RfcnKerasBoxPredictor`; otherwise builds a Faster R-CNN model. Keras
  components are used when `frcnn_config.feature_extractor.type` is a key of
  FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP.

  Args:
    frcnn_config: A faster_rcnn.proto object containing the config for the
      desired FasterRCNNMetaArch or RFCNMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    An rfcn_meta_arch.RFCNMetaArch or faster_rcnn_meta_arch.FasterRCNNMetaArch
    based on the config.

  Raises:
    ValueError: If `first_stage_nms_iou_threshold` is outside [0, 1.0], if
      `second_stage_batch_size` exceeds `first_stage_max_proposals` while
      training, or if a sub-builder rejects its part of the config (e.g. an
      unsupported feature extractor configuration).
  """
  num_classes = frcnn_config.num_classes
  image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

  # The feature extractor type selects between the Keras and slim code paths
  # for every component built below.
  is_keras = (frcnn_config.feature_extractor.type in
              FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP)

  if is_keras:
    feature_extractor = _build_faster_rcnn_keras_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)
  else:
    feature_extractor = _build_faster_rcnn_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)

  number_of_stages = frcnn_config.number_of_stages
  first_stage_anchor_generator = anchor_generator_builder.build(
      frcnn_config.first_stage_anchor_generator)

  first_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'proposal',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
  if is_keras:
    first_stage_box_predictor_arg_scope_fn = (
        hyperparams_builder.KerasLayerHyperparams(
            frcnn_config.first_stage_box_predictor_conv_hyperparams))
  else:
    first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
        frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
  first_stage_box_predictor_kernel_size = (
      frcnn_config.first_stage_box_predictor_kernel_size)
  first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
  first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
  # Static shapes apply whenever use_static_shapes is set during training;
  # during eval they additionally require use_static_shapes_for_eval.
  use_static_shapes = frcnn_config.use_static_shapes and (
      frcnn_config.use_static_shapes_for_eval or is_training)
  first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
      is_static=(frcnn_config.use_static_balanced_label_sampler and
                 use_static_shapes))
  first_stage_max_proposals = frcnn_config.first_stage_max_proposals
  if (frcnn_config.first_stage_nms_iou_threshold < 0 or
      frcnn_config.first_stage_nms_iou_threshold > 1.0):
    raise ValueError('iou_threshold not in [0, 1.0].')
  if (is_training and
      frcnn_config.second_stage_batch_size > first_stage_max_proposals):
    raise ValueError('second_stage_batch_size should be no greater than '
                     'first_stage_max_proposals.')
  first_stage_non_max_suppression_fn = functools.partial(
      post_processing.batch_multiclass_non_max_suppression,
      score_thresh=frcnn_config.first_stage_nms_score_threshold,
      iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
      max_size_per_class=first_stage_max_proposals,
      max_total_size=first_stage_max_proposals,
      use_static_shapes=use_static_shapes,
      use_partitioned_nms=frcnn_config.use_partitioned_nms_in_first_stage,
      use_combined_nms=frcnn_config.use_combined_nms_in_first_stage)
  first_stage_loc_loss_weight = (
      frcnn_config.first_stage_localization_loss_weight)
  first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

  initial_crop_size = frcnn_config.initial_crop_size
  maxpool_kernel_size = frcnn_config.maxpool_kernel_size
  maxpool_stride = frcnn_config.maxpool_stride

  second_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'detection',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  if is_keras:
    second_stage_box_predictor = box_predictor_builder.build_keras(
        hyperparams_builder.KerasLayerHyperparams,
        freeze_batchnorm=False,
        inplace_batchnorm_update=False,
        num_predictions_per_location_list=[1],
        box_predictor_config=frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
  else:
    second_stage_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
  second_stage_batch_size = frcnn_config.second_stage_batch_size
  second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.second_stage_balance_fraction,
      is_static=(frcnn_config.use_static_balanced_label_sampler and
                 use_static_shapes))
  (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn
  ) = post_processing_builder.build(frcnn_config.second_stage_post_processing)
  second_stage_localization_loss_weight = (
      frcnn_config.second_stage_localization_loss_weight)
  second_stage_classification_loss = (
      losses_builder.build_faster_rcnn_classification_loss(
          frcnn_config.second_stage_classification_loss))
  second_stage_classification_loss_weight = (
      frcnn_config.second_stage_classification_loss_weight)
  second_stage_mask_prediction_loss_weight = (
      frcnn_config.second_stage_mask_prediction_loss_weight)

  hard_example_miner = None
  if frcnn_config.HasField('hard_example_miner'):
    hard_example_miner = losses_builder.build_hard_example_miner(
        frcnn_config.hard_example_miner,
        second_stage_classification_loss_weight,
        second_stage_localization_loss_weight)

  crop_and_resize_fn = (
      ops.matmul_crop_and_resize if frcnn_config.use_matmul_crop_and_resize
      else ops.native_crop_and_resize)
  clip_anchors_to_image = frcnn_config.clip_anchors_to_image

  # Arguments shared by both RFCNMetaArch and FasterRCNNMetaArch.
  common_kwargs = {
      'is_training': is_training,
      'num_classes': num_classes,
      'image_resizer_fn': image_resizer_fn,
      'feature_extractor': feature_extractor,
      'number_of_stages': number_of_stages,
      'first_stage_anchor_generator': first_stage_anchor_generator,
      'first_stage_target_assigner': first_stage_target_assigner,
      'first_stage_atrous_rate': first_stage_atrous_rate,
      'first_stage_box_predictor_arg_scope_fn':
          first_stage_box_predictor_arg_scope_fn,
      'first_stage_box_predictor_kernel_size':
          first_stage_box_predictor_kernel_size,
      'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
      'first_stage_minibatch_size': first_stage_minibatch_size,
      'first_stage_sampler': first_stage_sampler,
      'first_stage_non_max_suppression_fn': first_stage_non_max_suppression_fn,
      'first_stage_max_proposals': first_stage_max_proposals,
      'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
      'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
      'second_stage_target_assigner': second_stage_target_assigner,
      'second_stage_batch_size': second_stage_batch_size,
      'second_stage_sampler': second_stage_sampler,
      'second_stage_non_max_suppression_fn':
          second_stage_non_max_suppression_fn,
      'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
      'second_stage_localization_loss_weight':
          second_stage_localization_loss_weight,
      'second_stage_classification_loss':
          second_stage_classification_loss,
      'second_stage_classification_loss_weight':
          second_stage_classification_loss_weight,
      'hard_example_miner': hard_example_miner,
      'add_summaries': add_summaries,
      'crop_and_resize_fn': crop_and_resize_fn,
      'clip_anchors_to_image': clip_anchors_to_image,
      'use_static_shapes': use_static_shapes,
      'resize_masks': frcnn_config.resize_masks,
      'return_raw_detections_during_predict': (
          frcnn_config.return_raw_detections_during_predict)
  }

  # An R-FCN box predictor (slim or Keras) selects the R-FCN meta-architecture.
  if isinstance(second_stage_box_predictor,
                (rfcn_box_predictor.RfcnBoxPredictor,
                 rfcn_keras_box_predictor.RfcnKerasBoxPredictor)):
    return rfcn_meta_arch.RFCNMetaArch(
        second_stage_rfcn_box_predictor=second_stage_box_predictor,
        **common_kwargs)
  return faster_rcnn_meta_arch.FasterRCNNMetaArch(
      initial_crop_size=initial_crop_size,
      maxpool_kernel_size=maxpool_kernel_size,
      maxpool_stride=maxpool_stride,
      second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
      second_stage_mask_prediction_loss_weight=(
          second_stage_mask_prediction_loss_weight),
      **common_kwargs)
def _build_keras_layer_hyperparams(self, hyperparams_text_proto):
  """Parses a Hyperparams text proto and wraps it for Keras layers.

  Args:
    hyperparams_text_proto: Text-format `hyperparams_pb2.Hyperparams` config.

  Returns:
    A `hyperparams_builder.KerasLayerHyperparams` for the parsed config.
  """
  # text_format.Merge fills in and returns the message it was given.
  parsed = text_format.Merge(hyperparams_text_proto,
                             hyperparams_pb2.Hyperparams())
  return hyperparams_builder.KerasLayerHyperparams(parsed)
def _build_ssd_feature_extractor(feature_extractor_config,
                                 is_training,
                                 freeze_batchnorm,
                                 reuse_weights=None):
  """Builds an SSD feature extractor (slim or Keras) based on config.

  Args:
    feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
    is_training: True if this feature extractor is being built for training.
    freeze_batchnorm: Whether to freeze batch norm parameters during training
      or not. When training with a small batch size (e.g. 1), it is desirable
      to freeze batch norm update and use pretrained batch norm params. Only
      forwarded to Keras feature extractors.
    reuse_weights: if the feature extractor should reuse weights. Only
      forwarded to non-Keras feature extractors.

  Returns:
    An instance of the class registered for `feature_extractor_config.type`.
    SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP is consulted first, then
    SSD_FEATURE_EXTRACTOR_CLASS_MAP.

  Raises:
    ValueError: On invalid feature extractor type. The check runs before any
      hyperparameters are built, so an unknown type is reported directly.
  """
  feature_type = feature_extractor_config.type
  is_keras_extractor = feature_type in SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP

  # Resolve (and validate) the extractor class before doing any other work.
  if is_keras_extractor:
    feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
        feature_type]
  elif feature_type in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
    feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
  else:
    raise ValueError(
        'Unknown ssd feature_extractor: {}'.format(feature_type))

  if is_keras_extractor:
    conv_hyperparams = hyperparams_builder.KerasLayerHyperparams(
        feature_extractor_config.conv_hyperparams)
  else:
    conv_hyperparams = hyperparams_builder.build(
        feature_extractor_config.conv_hyperparams, is_training)

  kwargs = {
      'is_training':
          is_training,
      'depth_multiplier':
          feature_extractor_config.depth_multiplier,
      'min_depth':
          feature_extractor_config.min_depth,
      'pad_to_multiple':
          feature_extractor_config.pad_to_multiple,
      'use_explicit_padding':
          feature_extractor_config.use_explicit_padding,
      'use_depthwise':
          feature_extractor_config.use_depthwise,
      'override_base_feature_extractor_hyperparams':
          feature_extractor_config.override_base_feature_extractor_hyperparams
  }

  # Optional fields are only forwarded when explicitly set in the proto.
  if feature_extractor_config.HasField(
      'replace_preprocessor_with_placeholder'):
    kwargs.update({
        'replace_preprocessor_with_placeholder':
            feature_extractor_config.replace_preprocessor_with_placeholder
    })

  if feature_extractor_config.HasField('num_layers'):
    kwargs.update({'num_layers': feature_extractor_config.num_layers})

  # Keras and slim extractors take differently named hyperparams arguments.
  if is_keras_extractor:
    kwargs.update({
        'conv_hyperparams': conv_hyperparams,
        'inplace_batchnorm_update': False,
        'freeze_batchnorm': freeze_batchnorm
    })
  else:
    kwargs.update({
        'conv_hyperparams_fn': conv_hyperparams,
        'reuse_weights': reuse_weights,
    })

  if feature_extractor_config.HasField('fpn'):
    kwargs.update({
        'fpn_min_level':
            feature_extractor_config.fpn.min_level,
        'fpn_max_level':
            feature_extractor_config.fpn.max_level,
        'additional_layer_depth':
            feature_extractor_config.fpn.additional_layer_depth,
    })

  return feature_extractor_class(**kwargs)
def _build_ssd_feature_extractor(feature_extractor_config,
                                 is_training,
                                 freeze_batchnorm,
                                 reuse_weights=None):
  """Builds an SSD feature extractor (slim or Keras) based on config.

  Args:
    feature_extractor_config: A SSDFeatureExtractor proto config.
    is_training: True if this feature extractor is being built for training.
    freeze_batchnorm: Whether to freeze batch norm parameters during
      training. Only forwarded to Keras feature extractors.
    reuse_weights: Whether the feature extractor should reuse weights. Only
      forwarded to non-Keras feature extractors.

  Returns:
    An instance of the class registered for `feature_extractor_config.type`.
    SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP is consulted first, then
    SSD_FEATURE_EXTRACTOR_CLASS_MAP.

  Raises:
    ValueError: If the feature extractor type is in neither class map. The
      check runs before any hyperparameters are built.
  """
  feature_type = feature_extractor_config.type
  is_keras_extractor = feature_type in SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP

  # Resolve (and validate) the extractor class before doing any other work.
  if is_keras_extractor:
    feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
        feature_type]
  elif feature_type in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
    feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
  else:
    raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))

  if is_keras_extractor:
    conv_hyperparams = hyperparams_builder.KerasLayerHyperparams(
        feature_extractor_config.conv_hyperparams)
  else:
    conv_hyperparams = hyperparams_builder.build(
        feature_extractor_config.conv_hyperparams, is_training)

  kwargs = {
      'is_training':
          is_training,
      'depth_multiplier':
          feature_extractor_config.depth_multiplier,
      'min_depth':
          feature_extractor_config.min_depth,
      'pad_to_multiple':
          feature_extractor_config.pad_to_multiple,
      'use_explicit_padding':
          feature_extractor_config.use_explicit_padding,
      'use_depthwise':
          feature_extractor_config.use_depthwise,
      'override_base_feature_extractor_hyperparams':
          feature_extractor_config.override_base_feature_extractor_hyperparams
  }

  # Optional fields are only forwarded when explicitly set in the proto.
  if feature_extractor_config.HasField('replace_preprocessor_with_placeholder'):
    kwargs.update({
        'replace_preprocessor_with_placeholder':
            feature_extractor_config.replace_preprocessor_with_placeholder
    })

  if feature_extractor_config.HasField('num_layers'):
    kwargs.update({'num_layers': feature_extractor_config.num_layers})

  # Keras and slim extractors take differently named hyperparams arguments.
  if is_keras_extractor:
    kwargs.update({
        'conv_hyperparams': conv_hyperparams,
        'inplace_batchnorm_update': False,
        'freeze_batchnorm': freeze_batchnorm
    })
  else:
    kwargs.update({
        'conv_hyperparams_fn': conv_hyperparams,
        'reuse_weights': reuse_weights,
    })

  if feature_extractor_config.HasField('fpn'):
    kwargs.update({
        'fpn_min_level':
            feature_extractor_config.fpn.min_level,
        'fpn_max_level':
            feature_extractor_config.fpn.max_level,
        'additional_layer_depth':
            feature_extractor_config.fpn.additional_layer_depth,
    })

  return feature_extractor_class(**kwargs)