Exemplo n.º 1
0
    def test_revnet_creation(self, model_id):
        """Checks that the factory builds the same RevNet as direct construction.

        A RevNet backbone for `model_id` is built twice: once by instantiating
        `backbones.RevNet` directly, and once through `factory.build_backbone`
        using a 'revnet' backbone config plus a NormActivation config carrying
        the same momentum/epsilon (and `use_sync_bn=False`). The two models'
        `get_config()` dictionaries must be equal.
        """
        momentum = 0.99
        epsilon = 1e-5

        # Path 1: construct the backbone directly.
        direct_backbone = backbones.RevNet(model_id=model_id,
                                           norm_momentum=momentum,
                                           norm_epsilon=epsilon)

        # Path 2: describe the same backbone through configs and the factory.
        revnet_cfg = backbones_cfg.RevNet(model_id=model_id)
        backbone_config = backbones_cfg.Backbone(type='revnet',
                                                 revnet=revnet_cfg)
        norm_cfg = common_cfg.NormActivation(norm_momentum=momentum,
                                             norm_epsilon=epsilon,
                                             use_sync_bn=False)
        specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
        built_backbone = factory.build_backbone(
            input_specs=specs,
            backbone_config=backbone_config,
            norm_activation_config=norm_cfg)

        expected_config = direct_backbone.get_config()
        actual_config = built_backbone.get_config()
        self.assertEqual(expected_config, actual_config)

    def test_revnet_network_creation(self):
        """Builds a RevNet-56 image classifier and checks its size and output.

        Verifies the exact trainable-parameter counts of the bare backbone and
        of the full classification model, then runs a forward pass on a random
        batch of two 224x224 RGB images and checks the logits shape.
        """
        batch_size = 2
        images = np.random.rand(batch_size, 224, 224, 3)

        # The model is built for NHWC inputs.
        tf.keras.backend.set_image_data_format('channels_last')

        revnet56 = backbones.RevNet(model_id=56)
        self.assertEqual(revnet56.count_params(), 19473792)

        num_classes = 1000
        classifier = classification_model.ClassificationModel(
            backbone=revnet56,
            num_classes=num_classes,
            dropout_rate=0.2,
            add_head_batch_norm=True)
        self.assertEqual(classifier.count_params(), 22816104)

        predictions = classifier(images)
        expected_shape = [batch_size, num_classes]
        self.assertAllEqual(expected_shape, predictions.numpy().shape)
Exemplo n.º 3
0
def build_backbone(input_specs: tf.keras.layers.InputSpec,
                   model_config,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Builds a backbone network from a model config.

    Args:
      input_specs: `tf.keras.layers.InputSpec` passed to the backbone as its
        input specification.
      model_config: a model config whose `backbone` attribute is a
        OneOfConfig (`type` selects the architecture, `get()` returns the
        selected sub-config) and whose `norm_activation` attribute supplies
        `activation`, `use_sync_bn`, `norm_momentum` and `norm_epsilon`.
        For SpineNet, `min_level` and `max_level` are also read from it.
      l2_regularizer: optional `tf.keras.regularizers.Regularizer` passed to
        the backbone as `kernel_regularizer`. Defaults to None.

    Returns:
      A `tf.keras.Model` instance of the backbone.

    Raises:
      ValueError: if the backbone type is not one of 'resnet', 'efficientnet',
        'spinenet' or 'revnet', or if a SpineNet `model_id` is not a key of
        `spinenet.SCALING_MAP`.
    """
    backbone_type = model_config.backbone.type
    backbone_cfg = model_config.backbone.get()
    norm_activation_config = model_config.norm_activation

    # Every supported backbone receives the same input spec, regularizer and
    # activation/normalization settings; collect them once so the branches
    # below cannot drift apart.
    common_kwargs = {
        'input_specs': input_specs,
        'activation': norm_activation_config.activation,
        'use_sync_bn': norm_activation_config.use_sync_bn,
        'norm_momentum': norm_activation_config.norm_momentum,
        'norm_epsilon': norm_activation_config.norm_epsilon,
        'kernel_regularizer': l2_regularizer,
    }

    if backbone_type == 'resnet':
        return backbones.ResNet(model_id=backbone_cfg.model_id, **common_kwargs)

    if backbone_type == 'efficientnet':
        return backbones.EfficientNet(
            model_id=backbone_cfg.model_id,
            stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
            se_ratio=backbone_cfg.se_ratio,
            **common_kwargs)

    if backbone_type == 'spinenet':
        model_id = backbone_cfg.model_id
        # SpineNet variants are defined by a scaling table rather than by a
        # model_id argument, so an unknown id must be rejected explicitly.
        if model_id not in spinenet.SCALING_MAP:
            raise ValueError(
                'SpineNet-{} is not a valid architecture.'.format(model_id))
        scaling_params = spinenet.SCALING_MAP[model_id]
        return backbones.SpineNet(
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            endpoints_num_filters=scaling_params['endpoints_num_filters'],
            resample_alpha=scaling_params['resample_alpha'],
            block_repeats=scaling_params['block_repeats'],
            filter_size_scale=scaling_params['filter_size_scale'],
            **common_kwargs)

    if backbone_type == 'revnet':
        return backbones.RevNet(model_id=backbone_cfg.model_id, **common_kwargs)

    raise ValueError('Backbone {!r} not implemented'.format(backbone_type))