예제 #1
0
 def test_variables_are_not_enumerated_when_overridden(self):
     """Overriding kernel and bias parameters leaves the layer with no weights."""
     conv = signal_conv.SignalConv2D(1, 1)
     # Supply plain values for both parameters before building; the test
     # expects build() to then track no weights at all.
     conv.kernel_parameter = [[[[1]]]]
     conv.bias_parameter = [0]
     input_shape = (None, None, None, 1)
     conv.build(input_shape)
     self.assertEmpty(conv.weights)
     self.assertEmpty(conv.trainable_weights)
예제 #2
0
 def test_variables_are_enumerated(self):
     """With use_bias=True, build() tracks exactly a kernel and a bias weight."""
     conv = signal_conv.SignalConv2D(3, 1, use_bias=True)
     input_shape = (None, None, None, 2)
     conv.build(input_shape)
     self.assertLen(conv.weights, 2)
     self.assertLen(conv.trainable_weights, 2)
     # Order of the weights is not part of the contract, only their names.
     self.assertSameElements(
         [weight.name for weight in conv.weights],
         ["kernel_rdft:0", "bias:0"],
     )
예제 #3
0
 def test_bias_variable_is_not_unnecessarily_created(self):
     """With use_bias=False, build() tracks only the kernel weight."""
     conv = signal_conv.SignalConv2D(5, 3, use_bias=False)
     input_shape = (None, None, None, 3)
     conv.build(input_shape)
     self.assertLen(conv.weights, 1)
     self.assertLen(conv.trainable_weights, 1)
     self.assertSameElements(
         [weight.name for weight in conv.weights],
         ["kernel_rdft:0"],
     )
예제 #4
0
 def test_dtypes_are_correct_with_mixed_precision(self):
     """Under "mixed_float16", variables are float32 and the output is float16."""
     # Remember the policy that was active so it can be restored exactly;
     # set_global_policy(None) would reset to the default instead, clobbering
     # any policy configured by the surrounding test suite.
     previous_policy = tf.keras.mixed_precision.global_policy()
     tf.keras.mixed_precision.set_global_policy("mixed_float16")
     try:
         x = tf.random.uniform((1, 4, 4, 3), dtype=tf.float16)
         layer = signal_conv.SignalConv2D(2, 3, use_bias=True)
         y = layer(x)
         # Guard against the dtype loop below passing vacuously.
         self.assertNotEmpty(layer.variables)
         for variable in layer.variables:
             self.assertEqual(variable.dtype, tf.float32)
         self.assertEqual(y.dtype, tf.float16)
     finally:
         tf.keras.mixed_precision.set_global_policy(previous_policy)
예제 #5
0
 def test_variables_trainable_state_follows_layer(self):
     """Freezing the layer before build() leaves all of its weights untrainable."""
     frozen_conv = signal_conv.SignalConv2D(1, 1, use_bias=True)
     frozen_conv.trainable = False
     input_shape = (None, None, None, 1)
     frozen_conv.build(input_shape)
     # Both weights are still created; none are reported as trainable.
     self.assertLen(frozen_conv.weights, 2)
     self.assertEmpty(frozen_conv.trainable_weights)