Esempio n. 1
0
 def test_variables_receive_gradients(self):
     """Checks that each trainable weight receives a gradient of its own shape."""
     x = tf.random.uniform((1, 5, 2), dtype=tf.float32)
     layer = signal_conv.SignalConv1D(2, 3, use_bias=True)
     with tf.GradientTape() as tape:
         y = layer(x)
     weights = layer.trainable_weights
     grads = tape.gradient(y, weights)
     # With use_bias=True two trainable weights are expected (presumably the
     # kernel parameter and the bias); none may be disconnected from `y`.
     self.assertLen(grads, 2)
     self.assertNotIn(None, grads)
     # `tape.gradient` returns gradients in the same order as `weights`, so
     # compare pairwise: each gradient must match the shape of its own weight.
     # This is stricter than comparing the two shape collections as sets.
     for grad, weight in zip(grads, weights):
         self.assertEqual(tuple(grad.shape), tuple(weight.shape))
Esempio n. 2
0
    def test_can_be_saved_within_functional_model(self, build):
        """Round-trips a functional model containing SignalConv1D via save/load.

        Args:
          build: If true, the model is called on concrete data before saving,
            so that its outputs and weight names can be compared with those of
            the reloaded model.
        """
        inputs = tf.keras.Input(shape=(None, 2))
        conv = signal_conv.SignalConv1D(1,
                                        3,
                                        use_bias=True,
                                        activation=tf.nn.relu)
        outputs = conv(inputs)
        model = tf.keras.Model(inputs=inputs, outputs=outputs)
        # Look the layer up by its actual name instead of hard-coding
        # "signal_conv1d": Keras may uniquify auto-generated names (e.g.
        # "signal_conv1d_1") if another SignalConv1D was created earlier in the
        # same process. The name is part of the layer config, so it is reused
        # for the lookup after reloading.
        layer_name = conv.name
        layer = model.get_layer(layer_name)

        def assert_layer_structure(candidate):
            # Structural checks shared by the original and reloaded layer.
            self.assertIsInstance(candidate, signal_conv.SignalConv1D)
            self.assertIsInstance(candidate.kernel_parameter,
                                  parameters.RDFTParameter)
            self.assertIsInstance(candidate.bias_parameter, tf.Variable)

        with self.subTest(name="layer_created_as_expected"):
            assert_layer_structure(layer)

        if build:
            x = tf.random.uniform((1, 5, 2), dtype=tf.float32)
            y = model(x)
            weight_names = [w.name for w in model.weights]

        tempdir = self.create_tempdir()
        model_path = os.path.join(tempdir, "model")
        # This should force the model to be reconstructed via configs.
        model.save(model_path, save_traces=False)

        model = tf.keras.models.load_model(model_path)

        layer = model.get_layer(layer_name)
        with self.subTest(name="layer_recreated_as_expected"):
            assert_layer_structure(layer)

        if build:
            with self.subTest(name="model_outputs_identical"):
                self.assertAllEqual(model(x), y)

            with self.subTest(name="model_weights_identical"):
                self.assertSameElements(weight_names,
                                        [w.name for w in model.weights])
Esempio n. 3
0
 def test_attributes_cannot_be_set_after_build(self):
     """Assigning any configuration attribute on a built layer raises."""
     layer = signal_conv.SignalConv1D(2, 1)
     layer.build((None, None, 2))
     # Each entry is (attribute name, replacement value); every assignment
     # is expected to raise RuntimeError once the layer has been built.
     frozen_attributes = (
         ("filters", 3),
         ("kernel_support", 3),
         ("corr", True),
         ("strides_down", 2),
         ("strides_up", 2),
         ("padding", "valid"),
         ("extra_pad_end", True),
         ("channel_separable", True),
         ("data_format", "channels_first"),
         ("activation", tf.nn.relu),
         ("use_bias", False),
         ("use_explicit", False),
         ("kernel_parameter", tf.ones((1, 2, 3))),
         ("bias_parameter", tf.ones((3,))),
         ("kernel_initializer", tf.keras.initializers.Ones()),
         ("bias_initializer", tf.keras.initializers.Ones()),
         ("kernel_regularizer", tf.keras.regularizers.L2()),
         ("bias_regularizer", tf.keras.regularizers.L2()),
     )
     for attr_name, new_value in frozen_attributes:
         # msg names the offending attribute if an assignment does not raise.
         with self.assertRaises(RuntimeError, msg=attr_name):
             setattr(layer, attr_name, new_value)
Esempio n. 4
0
 def test_invalid_data_format_raises_error(self):
     """Constructing the layer with an unrecognized data_format raises."""
     bad_format = "NHWC"
     self.assertRaises(
         ValueError, signal_conv.SignalConv1D, 2, 1, data_format=bad_format)