def test_weights():
    """Layers must build exactly the number of weights their options imply.

    A GroupNormalization with both ``center`` and ``scale`` disabled should
    build no weights at all, whereas a default InstanceNormalization should
    build two weights, both of them trainable.
    """
    shape = (None, 3, 4)

    # Centering and scaling both off: nothing for the layer to learn.
    no_affine = GroupNormalization(groups=1, scale=False, center=False)
    no_affine.build(shape)
    assert len(no_affine.trainable_weights) == 0
    assert len(no_affine.weights) == 0

    # Default instance norm: two weights, all trainable (presumably gamma and
    # beta -- the names are not checked here).
    instance_norm = InstanceNormalization()
    instance_norm.build(shape)
    assert len(instance_norm.trainable_weights) == 2
    assert len(instance_norm.weights) == 2
def test_weights(self):
    """Built layers own exactly the weights their configuration asks for.

    Checks (trainable count, total count) for a GroupNormalization with
    centering and scaling disabled, then for a default InstanceNormalization.
    """
    shape = (None, 3, 4)

    # No center/scale -> no weights of any kind.
    bare_group_norm = GroupNormalization(groups=1, scale=False, center=False)
    bare_group_norm.build(shape)
    self.assertEqual(
        (len(bare_group_norm.trainable_weights), len(bare_group_norm.weights)),
        (0, 0))

    # Default configuration -> two weights, both trainable.
    default_instance_norm = InstanceNormalization()
    default_instance_norm.build(shape)
    self.assertEqual(
        (len(default_instance_norm.trainable_weights),
         len(default_instance_norm.weights)),
        (2, 2))
def test_regularizations(self):
    """Regularizers contribute losses and constraints reach gamma/beta.

    First, L1 regularizers on gamma and beta must yield two entries in
    ``losses`` once the layer is built. Second, the constraint objects passed
    in must be the ones attached to the built ``gamma`` and ``beta`` weights.
    """
    penalized = GroupNormalization(
        gamma_regularizer='l1', beta_regularizer='l1', groups=4, axis=2)
    penalized.build((None, 4, 4))
    # One regularization loss per regularized weight.
    self.assertEqual(len(penalized.losses), 2)

    # NOTE(review): `max_norm` is passed without being called; the test only
    # checks that this same object ends up on the weights -- confirm intended.
    norm_cap = tf.keras.constraints.max_norm
    constrained = GroupNormalization(
        gamma_constraint=norm_cap, beta_constraint=norm_cap)
    constrained.build((None, 3, 4))
    self.assertEqual(constrained.gamma.constraint, norm_cap)
    self.assertEqual(constrained.beta.constraint, norm_cap)
def test_regularizations():
    """Regularizers add losses; constraints are attached to gamma and beta.

    L1 regularization on both gamma and beta must produce exactly two losses
    after building, and constraint objects must be attached unchanged to the
    built ``gamma`` and ``beta`` weights.
    """
    # Part 1: regularization losses.
    l1_layer = GroupNormalization(
        gamma_regularizer="l1", beta_regularizer="l1", groups=4, axis=2
    )
    l1_layer.build((None, 4, 4))
    assert len(l1_layer.losses) == 2

    # Part 2: constraints.
    # NOTE(review): `max_norm` is handed over uncalled; only identity/equality
    # of the attached object is asserted -- confirm that is the intent.
    cap = tf.keras.constraints.max_norm
    capped_layer = GroupNormalization(gamma_constraint=cap, beta_constraint=cap)
    capped_layer.build((None, 3, 4))
    assert capped_layer.gamma.constraint == cap
    assert capped_layer.beta.constraint == cap
def test_group_norm_compute_output_shape(center, scale):
    """Variables are created the same way regardless of what builds the layer.

    A fresh GroupNormalization is triggered three ways -- an explicit
    ``build()``, ``compute_output_shape()``, and a forward call -- and after
    each one it must own one variable per enabled option among ``center`` and
    ``scale``, all of them trainable.
    """
    # Both expectations equal the number of enabled options.
    enabled_options = [center, scale].count(True)
    expected_variables = enabled_options
    expected_trainable = enabled_options

    def fresh_layer():
        # Identical configuration for every build path under test.
        return GroupNormalization(groups=2, center=center, scale=scale)

    def check_variable_counts(layer):
        assert len(layer.variables) == expected_variables
        assert len(layer.trainable_variables) == expected_trainable

    # Path 1: explicit build().
    via_build = fresh_layer()
    via_build.build(input_shape=[8, 28, 28, 16])
    check_variable_counts(via_build)

    # Path 2: compute_output_shape().
    via_shape = fresh_layer()
    via_shape.compute_output_shape(input_shape=[8, 28, 28, 16])
    check_variable_counts(via_shape)

    # Path 3: a forward call on a random batch.
    via_call = fresh_layer()
    via_call(tf.random.normal(shape=[8, 28, 28, 16]))
    check_variable_counts(via_call)