示例#1
0
    def test_spinenet_creation(self, model_id):
        """Compare a directly built SpineNet with one built by the factory.

        A SpineNet is constructed straight from ``backbones.SpineNet`` and a
        second one through ``factory.build_backbone`` from config objects;
        the test passes when both report the same ``get_config()``.

        Args:
          model_id: SpineNet model id forwarded to the backbone config that
            drives the factory-built network.
        """
        image_size = 128
        level_range = (3, 7)
        momentum, epsilon = 0.99, 1e-5

        def make_input_specs():
            # Each network receives its own, freshly created InputSpec.
            return tf.keras.layers.InputSpec(
                shape=[None, image_size, image_size, 3])

        # Reference network: instantiated directly from the backbone class.
        direct_net = backbones.SpineNet(
            input_specs=make_input_specs(),
            min_level=level_range[0],
            max_level=level_range[1],
            norm_momentum=momentum,
            norm_epsilon=epsilon)

        # Config objects describing the network the factory should build.
        spinenet_cfg = backbones_cfg.SpineNet(model_id=model_id)
        backbone_cfg = backbones_cfg.Backbone(
            type='spinenet', spinenet=spinenet_cfg)
        norm_cfg = common_cfg.NormActivation(
            norm_momentum=momentum,
            norm_epsilon=epsilon,
            use_sync_bn=False)

        factory_net = factory.build_backbone(
            input_specs=make_input_specs(),
            backbone_config=backbone_cfg,
            norm_activation_config=norm_cfg)

        self.assertEqual(direct_net.get_config(), factory_net.get_config())
示例#2
0
def build_backbone(input_specs: tf.keras.layers.InputSpec,
                   model_config,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Instantiate the backbone network selected by ``model_config``.

    The backbone type is read from ``model_config.backbone.type`` and its
    type-specific settings from ``model_config.backbone.get()``. The
    normalization/activation settings come from
    ``model_config.norm_activation`` and are forwarded to every backbone.

    Args:
      input_specs: tf.keras.layers.InputSpec passed through to the backbone.
      model_config: a OneOfConfig. Model config. For the 'spinenet' type its
        ``min_level`` and ``max_level`` attributes are read as well.
      l2_regularizer: tf.keras.regularizers.Regularizer instance handed to
        the backbone as ``kernel_regularizer``. Default to None.

    Returns:
      tf.keras.Model instance of the backbone.

    Raises:
      ValueError: if the backbone type is not one of 'resnet',
        'efficientnet', 'spinenet' or 'revnet', or if the SpineNet model id
        has no entry in ``spinenet.SCALING_MAP``.
    """
    kind = model_config.backbone.type
    params = model_config.backbone.get()
    norm_cfg = model_config.norm_activation

    def shared_kwargs():
        # Keyword arguments common to every backbone constructor. Built on
        # demand so the norm/activation config is only read once a supported
        # backbone type has actually been selected.
        return dict(
            activation=norm_cfg.activation,
            use_sync_bn=norm_cfg.use_sync_bn,
            norm_momentum=norm_cfg.norm_momentum,
            norm_epsilon=norm_cfg.norm_epsilon,
            kernel_regularizer=l2_regularizer)

    if kind == 'resnet':
        return backbones.ResNet(
            model_id=params.model_id,
            input_specs=input_specs,
            **shared_kwargs())

    if kind == 'efficientnet':
        return backbones.EfficientNet(
            model_id=params.model_id,
            input_specs=input_specs,
            stochastic_depth_drop_rate=params.stochastic_depth_drop_rate,
            se_ratio=params.se_ratio,
            **shared_kwargs())

    if kind == 'spinenet':
        model_id = params.model_id
        # Reject ids without a scaling entry before touching the table.
        if model_id not in spinenet.SCALING_MAP:
            raise ValueError(
                'SpineNet-{} is not a valid architecture.'.format(model_id))
        scaling = spinenet.SCALING_MAP[model_id]
        return backbones.SpineNet(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            endpoints_num_filters=scaling['endpoints_num_filters'],
            resample_alpha=scaling['resample_alpha'],
            block_repeats=scaling['block_repeats'],
            filter_size_scale=scaling['filter_size_scale'],
            **shared_kwargs())

    if kind == 'revnet':
        return backbones.RevNet(
            model_id=params.model_id,
            input_specs=input_specs,
            **shared_kwargs())

    raise ValueError('Backbone {!r} not implement'.format(kind))