Exemplo n.º 1
0
 def test_repetitions_increases_number_parameters(self):
   """Check that more residual-block repetitions produce more parameters."""
   shape = (32, 32, 3)
   classes = 10

   def param_count(reps):
     # Build a model with the given per-stage repetitions and count its params.
     model = resnet_models.create_resnet(shape, classes, repetitions=reps)
     return model.count_params()

   shallow_count = param_count([1, 1])
   deep_count = param_count([2, 2])
   self.assertLess(shallow_count, deep_count)
Exemplo n.º 2
0
  def test_basic_block_has_fewer_parameters_than_bottleneck(self):
    """Check that a basic-block ResNet has fewer params than a bottleneck one."""
    block_type = resnet_models.ResidualBlock
    # Same input shape and class count; only the residual block kind differs.
    models = [
        resnet_models.create_resnet((32, 32, 3), 10, residual_block=block)
        for block in (block_type.basic, block_type.bottleneck)
    ]
    basic_params, bottleneck_params = [m.count_params() for m in models]
    self.assertLess(basic_params, bottleneck_params)
Exemplo n.º 3
0
 def test_initial_kernel_size_with_negative_values_raises(self):
   """Check that a kernel size containing a negative entry raises ValueError."""
   expected = ('initial_kernel_size must be an iterable of length 2 '
               'containing only positive integers')
   # Callable form: keyword arguments are forwarded to create_resnet.
   self.assertRaisesRegex(
       ValueError,
       expected,
       resnet_models.create_resnet,
       input_shape=(32, 32, 3),
       initial_kernel_size=(3, -1))
Exemplo n.º 4
0
 def test_initial_strides_with_length_not_2_raises(self):
   """Check that strides with three entries instead of two raise ValueError."""
   bad_strides = (3, 3, 4)
   message_pattern = ('initial_strides must be an iterable of length 2 '
                      'containing only positive integers')
   with self.assertRaisesRegex(ValueError, message_pattern):
     resnet_models.create_resnet(
         input_shape=(32, 32, 3), initial_strides=bad_strides)
Exemplo n.º 5
0
 def test_negative_initial_filters_raises(self):
   """Check that a negative initial_filters value raises ValueError."""
   kwargs = dict(input_shape=(32, 32, 3), initial_filters=-2)
   self.assertRaisesRegex(ValueError,
                          'initial_filters must be a positive integer',
                          resnet_models.create_resnet, **kwargs)
Exemplo n.º 6
0
 def test_repetitions_with_negative_values_raises(self):
   """Check that a repetitions list with a negative count raises ValueError."""
   pattern = ('repetitions must be None or an iterable containing '
              'positive integers')
   reps = [2, -1]
   with self.assertRaisesRegex(ValueError, pattern):
     resnet_models.create_resnet(input_shape=(32, 32, 3), repetitions=reps)
Exemplo n.º 7
0
 def test_unsupported_residual_block_raises(self):
   """Check that a plain string passed as residual_block raises ValueError."""
   self.assertRaisesRegex(
       ValueError,
       'residual_block must be of type `ResidualBlock`',
       resnet_models.create_resnet,
       input_shape=(32, 32, 3),
       residual_block='bad_block')
Exemplo n.º 8
0
 def test_negative_num_classes_raises(self):
   """Check that a negative class count raises ValueError."""
   expected_error = 'num_classes must be a positive integer'
   with self.assertRaisesRegex(ValueError, expected_error):
     resnet_models.create_resnet(num_classes=-5, input_shape=(32, 32, 3))
Exemplo n.º 9
0
 def test_non_length_3_input_shape_raises(self):
   """Check that an input_shape with only two entries raises ValueError."""
   two_entry_shape = (10, 10)
   self.assertRaisesRegex(
       ValueError,
       'input_shape must be an iterable of length 3 containing '
       'only positive integers',
       resnet_models.create_resnet,
       input_shape=two_entry_shape)
Exemplo n.º 10
0
 def test_resnet_constructs_with_group_norm(self):
   """Check that the `group_norm` NormLayer option still builds a Keras model."""
   norm = resnet_models.NormLayer.group_norm
   model = resnet_models.create_resnet(
       input_shape=(32, 32, 3), num_classes=10, norm_layer=norm)
   self.assertIsInstance(model, tf.keras.Model)
Exemplo n.º 11
0
 def test_unsupported_norm_raises(self):
   """Check that a plain string passed as norm_layer raises ValueError."""
   self.assertRaisesRegex(ValueError,
                          'norm_layer must be of type `NormLayer`',
                          resnet_models.create_resnet,
                          input_shape=(32, 32, 3),
                          norm_layer='bad_norm')