Beispiel #1
0
 def test_non_iterable_initial_kernel_size_raises(self):
     """Checks that a plain int `initial_kernel_size` raises ValueError."""
     expected_message = ('initial_kernel_size must be an iterable of length 2 '
                         'containing only positive integers')
     self.assertRaisesRegex(ValueError, expected_message,
                            resnet_models.create_resnet,
                            input_shape=(32, 32, 3),
                            initial_kernel_size=10)
Beispiel #2
0
 def test_repetitions_with_negative_values_raises(self):
     """Checks that a negative entry in `repetitions` raises ValueError."""
     bad_repetitions = [2, -1]
     error_pattern = ('repetitions must be None or an iterable containing '
                      'positive integers')
     with self.assertRaisesRegex(ValueError, error_pattern):
         resnet_models.create_resnet(
             input_shape=(32, 32, 3), repetitions=bad_repetitions)
Beispiel #3
0
 def test_initial_strides_with_length_not_2_raises(self):
     """Checks that a 3-element `initial_strides` raises ValueError."""
     call_kwargs = dict(input_shape=(32, 32, 3), initial_strides=(3, 3, 4))
     self.assertRaisesRegex(
         ValueError,
         'initial_strides must be an iterable of length 2 containing'
         ' only positive integers',
         resnet_models.create_resnet,
         **call_kwargs)
Beispiel #4
0
 def test_repetitions_increases_number_parameters(self):
     """Checks that more repetitions yield a model with more parameters.

     Builds one model with repetitions [1, 1] and one with [2, 2], keeping
     the input shape and class count fixed, and compares their parameter
     counts.
     """
     shape, classes = (32, 32, 3), 10
     small_resnet, big_resnet = (
         resnet_models.create_resnet(shape, classes, repetitions=reps)
         for reps in ([1, 1], [2, 2]))
     small_count = small_resnet.count_params()
     big_count = big_resnet.count_params()
     self.assertLess(small_count, big_count)
Beispiel #5
0
    def test_basic_block_has_fewer_parameters_than_bottleneck(self):
        """Checks BASIC residual blocks give fewer parameters than BOTTLENECK.

        Both models share the same input shape and class count, so only the
        residual block type differs between them.
        """
        shape, classes = (32, 32, 3), 10
        block_types = (resnet_models.ResidualBlock.BASIC,
                       resnet_models.ResidualBlock.BOTTLENECK)
        # Build both models first, then count, mirroring the original order.
        models = [
            resnet_models.create_resnet(shape, classes, residual_block=block)
            for block in block_types
        ]
        basic_count, bottleneck_count = [m.count_params() for m in models]
        self.assertLess(basic_count, bottleneck_count)
Beispiel #6
0
 def test_negative_initial_filters_raises(self):
     """Checks that a negative `initial_filters` raises ValueError."""
     self.assertRaisesRegex(ValueError,
                            'initial_filters must be a positive integer',
                            resnet_models.create_resnet,
                            input_shape=(32, 32, 3),
                            initial_filters=-2)
Beispiel #7
0
 def test_unsupported_residual_block_raises(self):
     """Checks that a string `residual_block` raises ValueError."""
     unsupported_block = 'bad_block'
     with self.assertRaisesRegex(
             ValueError, 'residual_block must be of type `ResidualBlock`'):
         resnet_models.create_resnet(
             input_shape=(32, 32, 3), residual_block=unsupported_block)
Beispiel #8
0
 def test_negative_num_classes_raises(self):
     """Checks that a negative `num_classes` raises ValueError."""
     call_kwargs = {'input_shape': (32, 32, 3), 'num_classes': -5}
     with self.assertRaisesRegex(ValueError,
                                 'num_classes must be a positive integer'):
         resnet_models.create_resnet(**call_kwargs)
Beispiel #9
0
 def test_non_length_3_input_shape_raises(self):
     """Checks that a 2-element `input_shape` raises ValueError."""
     two_dim_shape = (10, 10)
     self.assertRaisesRegex(
         ValueError,
         'input_shape must be an iterable of length 3 containing '
         'only positive integers',
         resnet_models.create_resnet,
         input_shape=two_dim_shape)
Beispiel #10
0
 def test_resnet_constructs_with_group_norm(self):
     """Checks that group normalization yields a `tf.keras.Model`."""
     model = resnet_models.create_resnet(
         num_classes=10,
         input_shape=(32, 32, 3),
         norm_layer=resnet_models.NormLayer.GROUP_NORM)
     self.assertIsInstance(model, tf.keras.Model)
Beispiel #11
0
 def test_resnet_constructs_with_batch_norm(self):
     """Checks that batch normalization yields a `tf.keras.Model`."""
     norm = resnet_models.NormLayer.BATCH_NORM
     self.assertIsInstance(
         resnet_models.create_resnet(input_shape=(32, 32, 3),
                                     num_classes=10,
                                     norm_layer=norm),
         tf.keras.Model)
Beispiel #12
0
 def test_unsupported_norm_raises(self):
     """Checks that a string `norm_layer` raises ValueError."""
     self.assertRaisesRegex(ValueError,
                            'norm_layer must be of type `NormLayer`',
                            resnet_models.create_resnet,
                            input_shape=(32, 32, 3),
                            norm_layer='bad_norm')