Example No. 1
def test_forward_pass(base_layer, input_shape):
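    # Wrapping a layer with data_init=False should leave the forward pass unchanged.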
    sample_data = np.ones([1] + input_shape, dtype=np.float32)
    base_layer = base_layer()
    base_output = base_layer(sample_data)
    wn_layer = wrappers.WeightNormalization(base_layer, False)
    wn_output = wn_layer(sample_data)
    np.testing.assert_allclose(base_output, wn_output, rtol=1e-6, atol=1e-6)
Example No. 2
def test_save_file_h5(base_layer, input_shape):
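    # Builds a model around the wrapped layer and checks that its weights can be saved in HDF5 format.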
    base_layer = base_layer()
    wn_conv = wrappers.WeightNormalization(base_layer)
    model = tf.keras.Sequential(layers=[wn_conv])
    model.build([None] + input_shape)
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_weights(os.path.join(tmp_dir, "wrapper_test_model.h5"))
Example No. 3
 def test_weightnorm_applylayer(self):
     images = tf.random.uniform((2, 4, 4, 3))
     wn_wrapper = wrappers.WeightNormalization(tf.keras.layers.Conv2D(
         32, [2, 2]),
                                               input_shape=(4, 4, 3))
     # Layer.apply is deprecated; calling the wrapper directly is equivalent.
     wn_wrapper(images)
     self.assertTrue(hasattr(wn_wrapper, 'g'))
Example No. 4
 def test_model_build(self, base_layer_fn, input_shape):
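     # The wrapped layer should build inside a Sequential model for both data_init settings.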
     inputs = tf.keras.layers.Input(shape=input_shape)
     for data_init in [True, False]:
         base_layer = base_layer_fn()
         wt_layer = wrappers.WeightNormalization(base_layer, data_init)
         model = tf.keras.models.Sequential(layers=[inputs, wt_layer])
         model.build()
Example No. 5
 def test_save_file_h5(self, base_layer, input_shape):
     self.create_tempfile("wrapper_test_model.h5")
     base_layer = base_layer()
     wn_conv = wrappers.WeightNormalization(base_layer)
     model = tf.keras.Sequential(layers=[wn_conv])
     model.build([None] + input_shape)
     model.save_weights("wrapper_test_model.h5")
Example No. 6
 def test_with_time_dist(self):
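     # WeightNormalization should compose with TimeDistributed over a 5-D (batch, time, height, width, channels) input.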
     batch_shape = (32, 16, 64, 64, 3)
     inputs = tf.keras.layers.Input(batch_shape=batch_shape)
     a = tf.keras.layers.Conv2D(3, 5)
     b = wrappers.WeightNormalization(a)
     out = tf.keras.layers.TimeDistributed(b)(inputs)
     tf.keras.Model(inputs, out)
Example No. 7
 def test_save_file_h5(self):
     self.create_tempfile('wrapper_test_model.h5')
     conv = tf.keras.layers.Conv1D(1, 1)
     wn_conv = wrappers.WeightNormalization(conv)
     model = tf.keras.Sequential(layers=[wn_conv])
     model.build([1, 2, 3])
     model.save_weights('wrapper_test_model.h5')
Example No. 8
 def test_forward_pass(self, base_layer, input_shape):
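     # Variant of the forward-pass check that also works in graph mode, using self.evaluate.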
     sample_data = np.ones([1] + input_shape, dtype=np.float32)
     base_layer = base_layer()
     base_output = base_layer(sample_data)
     wn_layer = wrappers.WeightNormalization(base_layer, False)
     wn_output = wn_layer(sample_data)
     self.evaluate(tf.compat.v1.global_variables_initializer())
     self.assertAllClose(self.evaluate(base_output), self.evaluate(wn_output))
Example No. 9
def test_forward_pass(base_layer, input_shape):
    sample_data = np.ones([1] + input_shape, dtype=np.float32)
    base_layer = base_layer()
    base_output = base_layer(sample_data)
    wn_layer = wrappers.WeightNormalization(base_layer, False)
    wn_output = wn_layer(sample_data)
    # Effectively a no-op when running eagerly, since variables are initialized on creation.
    tf.compat.v1.global_variables_initializer()
    np.testing.assert_allclose(base_output, wn_output)
Example No. 10
def test_removal(base_layer_fn, input_shape, data_init):
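    # remove() should return the underlying layer with the normalization folded back into its
    # weights, producing the same output as the wrapper.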
    sample_data = np.ones([1] + input_shape, dtype=np.float32)

    base_layer = base_layer_fn()
    wn_layer = wrappers.WeightNormalization(base_layer, data_init)
    wn_output = wn_layer(sample_data)
    wn_removed_layer = wn_layer.remove()
    wn_removed_output = wn_removed_layer(sample_data)
    np.testing.assert_allclose(wn_removed_output.numpy(), wn_output.numpy())
    assert isinstance(wn_removed_layer, base_layer.__class__)
Example No. 11
 def test_serialization(self, base_layer, rnn):
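     # Round-trips the wrapper through serialize/deserialize and checks that its config is preserved.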
     base_layer = base_layer()
     wn_layer = wrappers.WeightNormalization(base_layer, not rnn)
     new_wn_layer = tf.keras.layers.deserialize(tf.keras.layers.serialize(wn_layer))
     self.assertEqual(wn_layer.data_init, new_wn_layer.data_init)
     self.assertEqual(wn_layer.is_rnn, new_wn_layer.is_rnn)
     self.assertEqual(wn_layer.is_rnn, rnn)
     if not isinstance(base_layer, tf.keras.layers.LSTM):
         # Issue with LSTM serialization, check with TF-core
         # Before serialization: tensorflow.python.keras.layers.recurrent_v2.LSTM
         # After serialization: tensorflow.python.keras.layers.recurrent.LSTM
         self.assertTrue(isinstance(new_wn_layer.layer, base_layer.__class__))
Example No. 12
 def test_weightnorm_dense_train(self):
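     # Trains a small weight-normalized Dense model and checks that the scale variable `g` was created.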
     model = tf.keras.models.Sequential()
     model.add(
         wrappers.WeightNormalization(tf.keras.layers.Dense(2),
                                      input_shape=(3, 4)))
     model.compile(
         optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
         loss='mse')
     model.fit(np.random.random((10, 3, 4)),
               np.random.random((10, 3, 2)),
               epochs=3,
               batch_size=10)
     self.assertTrue(hasattr(model.layers[0], 'g'))
Example No. 13
    def test_removal(self, base_layer_fn, input_shape):
        sample_data = np.ones([1] + input_shape, dtype=np.float32)

        for data_init in [True, False]:
            base_layer = base_layer_fn()
            wn_layer = wrappers.WeightNormalization(base_layer, data_init)
            wn_output = wn_layer(sample_data)
            self.evaluate(tf.compat.v1.global_variables_initializer())
            with tf.control_dependencies([wn_output]):
                wn_removed_layer = wn_layer.remove()
                wn_removed_output = wn_removed_layer(sample_data)

            self.evaluate(tf.compat.v1.global_variables_initializer())
            self.assertAllClose(self.evaluate(wn_removed_output),
                                self.evaluate(wn_output))
            self.assertTrue(isinstance(wn_removed_layer, base_layer.__class__))
Example No. 14
    def test_weightnorm_conv2d(self):
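        # End-to-end training check for a weight-normalized Conv2D layer.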
        model = tf.keras.models.Sequential()
        model.add(
            wrappers.WeightNormalization(tf.keras.layers.Conv2D(
                5, (2, 2), padding='same'),
                                         input_shape=(4, 4, 3)))

        model.add(tf.keras.layers.Activation('relu'))
        model.compile(optimizer=tf.optimizers.RMSprop(learning_rate=0.001),
                      loss='mse')
        model.fit(np.random.random((2, 4, 4, 3)),
                  np.random.random((2, 4, 4, 5)),
                  epochs=3,
                  batch_size=10)

        self.assertTrue(hasattr(model.layers[0], 'g'))
Example No. 15
 def test_non_kernel_layer(self):
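     # Wrapping a layer without a `kernel` attribute (here MaxPooling2D) should raise a ValueError.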
     images = tf.random.uniform((2, 2, 2))
     with self.assertRaisesRegex(ValueError, "contains a `kernel`"):
         non_kernel_layer = tf.keras.layers.MaxPooling2D(2, 2)
         wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
         wn_wrapper(images)
Example No. 16
 def test_weightnorm_nonlayer(self):
     images = tf.random.uniform((2, 4, 3))
     with self.assertRaises(AssertionError):
         wrappers.WeightNormalization(images)
Example No. 17
def test_non_kernel_layer():
    images = tf.random.uniform((2, 2, 2))
    with pytest.raises(ValueError, match="contains a `kernel`"):
        non_kernel_layer = tf.keras.layers.MaxPooling2D(2, 2)
        wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
        wn_wrapper(images)
Example No. 18
def test_non_layer():
    images = tf.random.uniform((2, 4, 3))
    with pytest.raises(AssertionError):
        wrappers.WeightNormalization(images)
Example No. 19
 def test_weightnorm_nokernel(self):
     with self.assertRaises(ValueError):
         wrappers.WeightNormalization(tf.keras.layers.MaxPooling2D(
             2, 2)).build((2, 2))
Example No. 20
 def test_weightnorm_with_rnn(self):
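     # WeightNormalization can also wrap recurrent layers such as SimpleRNN inside a Sequential model.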
     inputs = tf.keras.layers.Input(shape=(None, 3))
     rnn_layer = tf.keras.layers.SimpleRNN(4)
     wt_rnn = wrappers.WeightNormalization(rnn_layer)
     dense = tf.keras.layers.Dense(1)
     model = tf.keras.models.Sequential(layers=[inputs, wt_rnn, dense])
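
The snippets above appear to exercise the WeightNormalization wrapper from TensorFlow Addons. Below is a minimal end-to-end sketch of the same API, assuming `wrappers` is `tensorflow_addons.layers.wrappers`; all other names are illustrative, not taken from the tests.

import numpy as np
import tensorflow as tf
from tensorflow_addons.layers import wrappers

# Wrap a Dense layer; data_init=True (the default) performs data-dependent
# initialization of the scale variable `g` on the first call.
wn_dense = wrappers.WeightNormalization(tf.keras.layers.Dense(2), data_init=True)

model = tf.keras.Sequential([tf.keras.layers.Input(shape=(4,)), wn_dense])
model.compile(optimizer="rmsprop", loss="mse")
model.fit(np.random.random((8, 4)), np.random.random((8, 2)), epochs=1, verbose=0)

# remove() returns the wrapped Dense layer with the normalization folded back
# into its weights, so it can be used without the wrapper.
plain_dense = wn_dense.remove()
assert isinstance(plain_dense, tf.keras.layers.Dense)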