def test_efficientnet_creation(self, model_id, se_ratio):
  """Checks the factory-built EfficientNet matches a directly built one.

  A reference EfficientNet is constructed from the backbone class. The same
  settings are then expressed as config objects and passed to
  `factory.build_backbone`. The two resulting `get_config()` dicts must be
  equal.
  """
  momentum, epsilon = 0.99, 1e-5

  # Reference network built straight from the backbone class.
  expected = backbones.EfficientNet(
      model_id=model_id,
      se_ratio=se_ratio,
      norm_momentum=momentum,
      norm_epsilon=epsilon)

  # The same settings expressed as config objects for the factory.
  efficientnet_backbone = backbones_cfg.Backbone(
      type='efficientnet',
      efficientnet=backbones_cfg.EfficientNet(
          model_id=model_id, se_ratio=se_ratio))
  norm_act = common_cfg.NormActivation(
      norm_momentum=momentum, norm_epsilon=epsilon, use_sync_bn=False)

  built = factory.build_backbone(
      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
      backbone_config=efficientnet_backbone,
      norm_activation_config=norm_act)

  self.assertEqual(expected.get_config(), built.get_config())
def build_backbone(input_specs: tf.keras.layers.InputSpec,
                   model_config,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds a backbone network from a model config.

  `model_config.backbone.type` selects the backbone class. Every backbone
  also receives the same normalization/activation settings, taken from
  `model_config.norm_activation`, and the same kernel regularizer,
  `l2_regularizer`.

  Args:
    input_specs: tf.keras.layers.InputSpec for the backbone input.
    model_config: Model config exposing `backbone` (a OneOfConfig, read via
      `.type` and `.get()`) and `norm_activation`. For the 'spinenet'
      backbone it must also expose `min_level` and `max_level`.
    l2_regularizer: tf.keras.regularizers.Regularizer instance used as the
      kernel regularizer. Defaults to None.

  Returns:
    tf.keras.Model instance of the backbone.

  Raises:
    ValueError: If the SpineNet `model_id` is not a key of
      `spinenet.SCALING_MAP`, or if the backbone type is not one of
      'resnet', 'efficientnet', 'spinenet', 'revnet'.
  """
  backbone_type = model_config.backbone.type
  backbone_cfg = model_config.backbone.get()
  norm_activation_config = model_config.norm_activation

  # Keyword arguments passed identically to every supported backbone
  # constructor; kept in one place so the branches cannot drift apart.
  common_kwargs = dict(
      input_specs=input_specs,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  if backbone_type == 'resnet':
    return backbones.ResNet(model_id=backbone_cfg.model_id, **common_kwargs)

  if backbone_type == 'efficientnet':
    return backbones.EfficientNet(
        model_id=backbone_cfg.model_id,
        stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
        se_ratio=backbone_cfg.se_ratio,
        **common_kwargs)

  if backbone_type == 'spinenet':
    model_id = backbone_cfg.model_id
    if model_id not in spinenet.SCALING_MAP:
      raise ValueError(
          'SpineNet-{} is not a valid architecture.'.format(model_id))
    # Per-variant scaling parameters, looked up by model_id.
    scaling_params = spinenet.SCALING_MAP[model_id]
    return backbones.SpineNet(
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        endpoints_num_filters=scaling_params['endpoints_num_filters'],
        resample_alpha=scaling_params['resample_alpha'],
        block_repeats=scaling_params['block_repeats'],
        filter_size_scale=scaling_params['filter_size_scale'],
        **common_kwargs)

  if backbone_type == 'revnet':
    return backbones.RevNet(model_id=backbone_cfg.model_id, **common_kwargs)

  raise ValueError('Backbone {!r} not implemented'.format(backbone_type))