def test_non_iterable_initial_kernel_size_raises(self):
    """A scalar `initial_kernel_size` must be rejected with a ValueError."""
    expected_message = ('initial_kernel_size must be an iterable of length 2 '
                        'containing only positive integers')
    with self.assertRaisesRegex(ValueError, expected_message):
        resnet_models.create_resnet(
            input_shape=(32, 32, 3), initial_kernel_size=10)
def test_repetitions_with_negative_values_raises(self):
    """`repetitions` containing a negative entry must raise ValueError."""
    bad_repetitions = [2, -1]
    error_pattern = ('repetitions must be None or an iterable containing '
                     'positive integers')
    with self.assertRaisesRegex(ValueError, error_pattern):
        resnet_models.create_resnet(
            input_shape=(32, 32, 3), repetitions=bad_repetitions)
def test_initial_strides_with_length_not_2_raises(self):
    """A three-element `initial_strides` must raise ValueError."""
    three_strides = (3, 3, 4)
    pattern = ('initial_strides must be an iterable of length 2 containing'
               ' only positive integers')
    with self.assertRaisesRegex(ValueError, pattern):
        resnet_models.create_resnet(
            input_shape=(32, 32, 3), initial_strides=three_strides)
def test_repetitions_increases_number_parameters(self):
    """Going from repetitions [1, 1] to [2, 2] must increase the parameter count."""
    shape = (32, 32, 3)
    classes = 10
    # Both models are constructed (in order) before either is inspected.
    small_model, big_model = (
        resnet_models.create_resnet(
            input_shape=shape, num_classes=classes, repetitions=reps)
        for reps in ([1, 1], [2, 2]))
    self.assertLess(small_model.count_params(), big_model.count_params())
def test_basic_block_has_fewer_parameters_than_bottleneck(self):
    """A BASIC-block ResNet must have fewer parameters than a BOTTLENECK one."""
    block_types = resnet_models.ResidualBlock
    # Build the BASIC model first, then the BOTTLENECK model.
    basic_model, bottleneck_model = (
        resnet_models.create_resnet(
            input_shape=(32, 32, 3), num_classes=10, residual_block=block)
        for block in (block_types.BASIC, block_types.BOTTLENECK))
    self.assertLess(basic_model.count_params(),
                    bottleneck_model.count_params())
def test_negative_initial_filters_raises(self):
    """A negative `initial_filters` value must raise ValueError."""
    negative_filters = -2
    with self.assertRaisesRegex(ValueError,
                                'initial_filters must be a positive integer'):
        resnet_models.create_resnet(
            input_shape=(32, 32, 3), initial_filters=negative_filters)
def test_unsupported_residual_block_raises(self):
    """A plain string passed as `residual_block` must raise ValueError."""
    message = 'residual_block must be of type `ResidualBlock`'
    with self.assertRaisesRegex(ValueError, message):
        resnet_models.create_resnet(
            input_shape=(32, 32, 3), residual_block='bad_block')
def test_negative_num_classes_raises(self):
    """A negative `num_classes` must raise ValueError."""
    call_kwargs = dict(input_shape=(32, 32, 3), num_classes=-5)
    with self.assertRaisesRegex(ValueError,
                                'num_classes must be a positive integer'):
        resnet_models.create_resnet(**call_kwargs)
def test_non_length_3_input_shape_raises(self):
    """A two-element `input_shape` must raise ValueError."""
    two_dim_shape = (10, 10)
    expected = ('input_shape must be an iterable of length 3 containing '
                'only positive integers')
    with self.assertRaisesRegex(ValueError, expected):
        resnet_models.create_resnet(input_shape=two_dim_shape)
def test_resnet_constructs_with_group_norm(self):
    """Requesting GROUP_NORM must still produce a `tf.keras.Model`."""
    norm = resnet_models.NormLayer.GROUP_NORM
    model = resnet_models.create_resnet(
        input_shape=(32, 32, 3),
        num_classes=10,
        norm_layer=norm,
    )
    self.assertIsInstance(model, tf.keras.Model)
def test_resnet_constructs_with_batch_norm(self):
    """Requesting BATCH_NORM must still produce a `tf.keras.Model`."""
    norm = resnet_models.NormLayer.BATCH_NORM
    model = resnet_models.create_resnet(
        input_shape=(32, 32, 3),
        num_classes=10,
        norm_layer=norm,
    )
    self.assertIsInstance(model, tf.keras.Model)
def test_unsupported_norm_raises(self):
    """A plain string passed as `norm_layer` must raise ValueError."""
    regex = 'norm_layer must be of type `NormLayer`'
    with self.assertRaisesRegex(ValueError, regex):
        resnet_models.create_resnet(
            input_shape=(32, 32, 3), norm_layer='bad_norm')