def test_variables_are_not_enumerated_when_overridden(self):
    """When parameters are overridden with plain values, no variables exist."""
    conv = signal_conv.SignalConv2D(1, 1)
    # Assigning concrete Python values before build() should suppress the
    # creation of the corresponding tf.Variables entirely.
    conv.kernel_parameter = [[[[1]]]]
    conv.bias_parameter = [0]
    conv.build((None, None, None, 1))
    self.assertEmpty(conv.weights)
    self.assertEmpty(conv.trainable_weights)
def test_variables_are_enumerated(self):
    """A built layer with a bias exposes both kernel and bias variables."""
    conv = signal_conv.SignalConv2D(3, 1, use_bias=True)
    conv.build((None, None, None, 2))
    # Exactly two variables: the RDFT-parameterized kernel and the bias.
    self.assertLen(conv.weights, 2)
    self.assertLen(conv.trainable_weights, 2)
    names = [w.name for w in conv.weights]
    self.assertSameElements(names, ["kernel_rdft:0", "bias:0"])
def test_bias_variable_is_not_unnecessarily_created(self):
    """With use_bias=False, only the kernel variable is created."""
    conv = signal_conv.SignalConv2D(5, 3, use_bias=False)
    conv.build((None, None, None, 3))
    # A single variable is expected: the RDFT-parameterized kernel.
    self.assertLen(conv.weights, 1)
    self.assertLen(conv.trainable_weights, 1)
    names = [w.name for w in conv.weights]
    self.assertSameElements(names, ["kernel_rdft:0"])
def test_dtypes_are_correct_with_mixed_precision(self):
    """Under mixed_float16, variables stay float32 while outputs are float16."""
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    try:
        inputs = tf.random.uniform((1, 4, 4, 3), dtype=tf.float16)
        conv = signal_conv.SignalConv2D(2, 3, use_bias=True)
        outputs = conv(inputs)
        # Mixed precision keeps master weights in float32 for numerical
        # stability, while computation/output happen in float16.
        for v in conv.variables:
            self.assertEqual(v.dtype, tf.float32)
        self.assertEqual(outputs.dtype, tf.float16)
    finally:
        # Restore the default policy so other tests are unaffected.
        tf.keras.mixed_precision.set_global_policy(None)
def test_variables_trainable_state_follows_layer(self):
    """Setting trainable=False before build() makes all weights non-trainable."""
    conv = signal_conv.SignalConv2D(1, 1, use_bias=True)
    conv.trainable = False
    conv.build((None, None, None, 1))
    # Both variables still exist, but neither should be reported trainable.
    self.assertLen(conv.weights, 2)
    self.assertEmpty(conv.trainable_weights)