def test_mobilenet_network_creation(self, mobilenet_model_id,
                                    filter_size_scale):
    """Builds a MobileNet classifier and checks its size and output shape.

    The trainable+non-trainable parameter count of the full classifier
    (backbone plus a 1001-way head) is compared against a pinned value for
    each (model_id, filter_size_scale) combination, then a forward pass on a
    small random batch is checked for a (batch, num_classes) logits shape.
    """
    # Pinned parameter counts, grouped per model variant and keyed by the
    # filter_size_scale used to build the backbone.
    expected_param_counts = {
        'MobileNetV1': {1.0: 4254889, 0.75: 2602745},
        'MobileNetV2': {1.0: 3540265, 0.75: 2664345},
        'MobileNetV3Large': {1.0: 5508713, 0.75: 4013897},
        'MobileNetV3Small': {1.0: 2555993, 0.75: 2052577},
        'MobileNetV3EdgeTPU': {1.0: 4131593, 0.75: 3019569},
    }

    batch_size = 2
    images = np.random.rand(batch_size, 224, 224, 3)
    # The random batch above is laid out as NHWC.
    tf.keras.backend.set_image_data_format('channels_last')

    mobilenet = backbones.MobileNet(
        model_id=mobilenet_model_id, filter_size_scale=filter_size_scale)
    num_classes = 1001
    classifier = classification_model.ClassificationModel(
        backbone=mobilenet,
        num_classes=num_classes,
        dropout_rate=0.2,
    )

    pinned_count = expected_param_counts[mobilenet_model_id][filter_size_scale]
    self.assertEqual(classifier.count_params(), pinned_count)

    logits = classifier(images)
    self.assertAllEqual([batch_size, num_classes], logits.numpy().shape)
def test_mobilenet_creation(self, model_id, filter_size_scale):
    """Checks that the backbone factory reproduces a hand-built MobileNet.

    One backbone is constructed directly via ``backbones.MobileNet`` and a
    second via ``factory.build_backbone`` from equivalent backbone and
    norm/activation configs; their ``get_config()`` dicts must be equal.
    """
    # Normalization settings shared by both construction paths.
    momentum, epsilon = 0.99, 1e-5

    direct_backbone = backbones.MobileNet(
        model_id=model_id,
        filter_size_scale=filter_size_scale,
        norm_momentum=momentum,
        norm_epsilon=epsilon)

    factory_backbone = factory.build_backbone(
        input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
        backbone_config=backbones_cfg.Backbone(
            type='mobilenet',
            mobilenet=backbones_cfg.MobileNet(
                model_id=model_id, filter_size_scale=filter_size_scale)),
        norm_activation_config=common_cfg.NormActivation(
            norm_momentum=momentum, norm_epsilon=epsilon, use_sync_bn=False))

    direct_config = direct_backbone.get_config()
    built_config = factory_backbone.get_config()
    self.assertEqual(direct_config, built_config)
def test_mobilenet_network_creation(self, mobilenet_model_id,
                                    filter_size_scale):
    """Runs a MobileNet-backed classifier and checks the logits shape.

    A small random NHWC batch is fed through a ClassificationModel built on
    the requested MobileNet variant; the output must have one row of
    ``num_classes`` logits per input image.
    """
    batch_size = 2
    images = np.random.rand(batch_size, 224, 224, 3)
    # Match the NHWC layout of the random batch above.
    tf.keras.backend.set_image_data_format('channels_last')

    mobilenet = backbones.MobileNet(
        model_id=mobilenet_model_id, filter_size_scale=filter_size_scale)
    num_classes = 1001
    classifier = classification_model.ClassificationModel(
        backbone=mobilenet,
        num_classes=num_classes,
        dropout_rate=0.2,
    )

    output_shape = classifier(images).numpy().shape
    self.assertAllEqual([batch_size, num_classes], output_shape)