def test_forward_pass(base_layer, input_shape):
    """Wrapping a layer in WeightNormalization (data_init=False) keeps its output.

    The same layer instance is called directly and then through the wrapper on
    an all-ones float32 batch of size 1; both outputs must agree to 1e-6.
    """
    ones = np.ones([1] + input_shape, dtype=np.float32)
    layer = base_layer()
    expected = layer(ones)
    # Second positional argument is data_init; disable data-dependent init.
    wrapped = wrappers.WeightNormalization(layer, False)
    actual = wrapped(ones)
    np.testing.assert_allclose(expected, actual, rtol=1e-6, atol=1e-6)
def test_save_file_h5(base_layer, input_shape):
    """Weights of a model containing a WeightNormalization wrapper save to HDF5.

    The file is written inside a temporary directory that is removed when the
    ``with`` block exits.
    """
    wrapped = wrappers.WeightNormalization(base_layer())
    model = tf.keras.Sequential(layers=[wrapped])
    model.build([None] + input_shape)
    with tempfile.TemporaryDirectory() as tmp_dir:
        target = os.path.join(tmp_dir, "wrapper_test_model.h5")
        model.save_weights(target)
def test_weightnorm_applylayer(self):
    """Calling the wrapper on a batch builds it and creates its `g` variable.

    The layer is invoked via ``layer(inputs)`` rather than ``Layer.apply``:
    ``apply`` is deprecated in TF2 Keras and is not available in newer Keras
    releases, while ``__call__`` is the supported entry point.
    """
    images = tf.random.uniform((2, 4, 4, 3))
    wn_wrapper = wrappers.WeightNormalization(
        tf.keras.layers.Conv2D(32, [2, 2]), input_shape=(4, 4, 3))
    wn_wrapper(images)
    self.assertTrue(hasattr(wn_wrapper, 'g'))
def test_model_build(self, base_layer_fn, input_shape):
    """A Sequential of [Input, WeightNormalization(layer)] builds for both data_init values.

    A fresh base layer is created for each data_init setting; the Input is shared.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    for init_flag in (True, False):
        wrapped = wrappers.WeightNormalization(base_layer_fn(), init_flag)
        seq = tf.keras.models.Sequential(layers=[inputs, wrapped])
        seq.build()
def test_save_file_h5(self, base_layer, input_shape):
    """Weights of a model containing a WeightNormalization wrapper save to HDF5.

    The weights are written to the path of the temp file returned by
    ``create_tempfile``. Previously that return value was discarded and the
    bare relative name was used, so the file landed in the current working
    directory and was left behind after the test.
    """
    h5_file = self.create_tempfile("wrapper_test_model.h5")
    wn_conv = wrappers.WeightNormalization(base_layer())
    model = tf.keras.Sequential(layers=[wn_conv])
    model.build([None] + input_shape)
    model.save_weights(h5_file.full_path)
def test_with_time_dist(self):
    """A WeightNormalization-wrapped Conv2D works inside TimeDistributed.

    Only checks that the functional model can be constructed.
    """
    inputs = tf.keras.layers.Input(batch_shape=(32, 16, 64, 64, 3))
    wrapped_conv = wrappers.WeightNormalization(tf.keras.layers.Conv2D(3, 5))
    outputs = tf.keras.layers.TimeDistributed(wrapped_conv)(inputs)
    tf.keras.Model(inputs, outputs)
def test_save_file_h5(self):
    """Weights of a WeightNormalization-wrapped Conv1D model save to HDF5.

    The weights are written to the temp file's full path. Previously the path
    returned by ``create_tempfile`` was discarded and a bare relative filename
    was used, which wrote into the current working directory and left the
    file behind.
    """
    h5_file = self.create_tempfile('wrapper_test_model.h5')
    wn_conv = wrappers.WeightNormalization(tf.keras.layers.Conv1D(1, 1))
    model = tf.keras.Sequential(layers=[wn_conv])
    model.build([1, 2, 3])
    model.save_weights(h5_file.full_path)
def test_forward_pass(self, base_layer, input_shape):
    """Wrapping a layer in WeightNormalization (data_init=False) keeps its output.

    Variables are initialized through ``self.evaluate`` so the test also works
    when tensors are symbolic.
    """
    ones = np.ones([1] + input_shape, dtype=np.float32)
    layer = base_layer()
    plain_out = layer(ones)
    wrapped_out = wrappers.WeightNormalization(layer, False)(ones)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    plain_val = self.evaluate(plain_out)
    wrapped_val = self.evaluate(wrapped_out)
    self.assertAllClose(plain_val, wrapped_val)
def test_forward_pass(base_layer, input_shape):
    """Wrapping a layer in WeightNormalization (data_init=False) keeps its output.

    Outputs are compared with rtol/atol of 1e-6. numpy's default rtol=1e-7 is
    tighter than float32 machine epsilon (~1.19e-7), so it can fail on
    legitimate float32 rounding.

    The former ``tf.compat.v1.global_variables_initializer()`` call was
    removed: its returned op was discarded, so it had no effect in this eager
    test (the outputs are passed straight to numpy).
    """
    sample_data = np.ones([1] + input_shape, dtype=np.float32)
    base_layer = base_layer()
    base_output = base_layer(sample_data)
    wn_layer = wrappers.WeightNormalization(base_layer, False)
    wn_output = wn_layer(sample_data)
    np.testing.assert_allclose(base_output, wn_output, rtol=1e-6, atol=1e-6)
def test_removal(base_layer_fn, input_shape, data_init):
    """``remove()`` yields a plain layer that reproduces the wrapper's output.

    The returned layer must also be an instance of the wrapped layer's class.
    """
    ones = np.ones([1] + input_shape, dtype=np.float32)
    original = base_layer_fn()
    wrapped = wrappers.WeightNormalization(original, data_init)
    wrapped_out = wrapped(ones)
    unwrapped = wrapped.remove()
    unwrapped_out = unwrapped(ones)
    np.testing.assert_allclose(unwrapped_out.numpy(), wrapped_out.numpy())
    assert isinstance(unwrapped, original.__class__)
def test_serialization(self, base_layer, rnn):
    """A serialize/deserialize round-trip preserves the wrapper's configuration.

    Checks data_init and is_rnn, and (except for LSTM) the wrapped layer's
    class. ``assertIsInstance`` replaces ``assertTrue(isinstance(...))`` so a
    failure reports the actual and expected types.
    """
    base_layer = base_layer()
    # data_init is disabled for the RNN cases.
    wn_layer = wrappers.WeightNormalization(base_layer, not rnn)
    new_wn_layer = tf.keras.layers.deserialize(
        tf.keras.layers.serialize(wn_layer))
    self.assertEqual(wn_layer.data_init, new_wn_layer.data_init)
    self.assertEqual(wn_layer.is_rnn, new_wn_layer.is_rnn)
    self.assertEqual(wn_layer.is_rnn, rnn)
    if not isinstance(base_layer, tf.keras.layers.LSTM):
        # Issue with LSTM serialization, check with TF-core
        # Before serialization: tensorflow.python.keras.layers.recurrent_v2.LSTM
        # After serialization: tensorflow.python.keras.layers.recurrent.LSTM
        self.assertIsInstance(new_wn_layer.layer, base_layer.__class__)
def test_weightnorm_dense_train(self):
    """A WeightNormalization-wrapped Dense layer trains and exposes `g`.

    Fits a tiny model on random data for three epochs.
    """
    model = tf.keras.models.Sequential()
    wrapped_dense = wrappers.WeightNormalization(
        tf.keras.layers.Dense(2), input_shape=(3, 4))
    model.add(wrapped_dense)
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='mse')
    features = np.random.random((10, 3, 4))
    targets = np.random.random((10, 3, 2))
    model.fit(features, targets, epochs=3, batch_size=10)
    self.assertTrue(hasattr(model.layers[0], 'g'))
def test_removal(self, base_layer_fn, input_shape):
    """Check that `remove()` returns a base-class layer matching the wrapper's output.

    Runs once with data_init=True and once with data_init=False, each time on
    a fresh base layer and an all-ones float32 batch of size 1.
    """
    sample_data = np.ones([1] + input_shape, dtype=np.float32)
    for data_init in [True, False]:
        base_layer = base_layer_fn()
        wn_layer = wrappers.WeightNormalization(base_layer, data_init)
        wn_output = wn_layer(sample_data)
        # Initialize all global variables, including any the wrapper's call created.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        # Ops built inside this context get a control dependency on wn_output,
        # so in graph mode the wrapped forward pass runs before removal.
        with tf.control_dependencies([wn_output]):
            wn_removed_layer = wn_layer.remove()
            wn_removed_output = wn_removed_layer(sample_data)
        # Initialize again; presumably needed for variables touched by
        # remove() or the removed layer's call -- TODO confirm.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(wn_removed_output),
                            self.evaluate(wn_output))
        self.assertTrue(isinstance(wn_removed_layer, base_layer.__class__))
def test_weightnorm_conv2d(self):
    """A WeightNormalization-wrapped Conv2D trains and exposes `g`.

    Uses the canonical ``tf.keras.optimizers.RMSprop`` path rather than the
    ``tf.optimizers`` alias, so the optimizer is referenced the same way as in
    the other training tests.
    """
    model = tf.keras.models.Sequential()
    model.add(
        wrappers.WeightNormalization(
            tf.keras.layers.Conv2D(5, (2, 2), padding='same'),
            input_shape=(4, 4, 3)))
    model.add(tf.keras.layers.Activation('relu'))
    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
        loss='mse')
    model.fit(np.random.random((2, 4, 4, 3)),
              np.random.random((2, 4, 4, 5)),
              epochs=3,
              batch_size=10)
    # The wrapper is the first layer of the model.
    self.assertTrue(hasattr(model.layers[0], 'g'))
def test_non_kernel_layer(self):
    """Wrapping a layer without a `kernel` and calling it raises ValueError.

    Uses ``assertRaisesRegex``: the ``assertRaisesRegexp`` alias was
    deprecated and removed from ``unittest`` in Python 3.12.
    """
    images = tf.random.uniform((2, 2, 2))
    with self.assertRaisesRegex(ValueError, "contains a `kernel`"):
        non_kernel_layer = tf.keras.layers.MaxPooling2D(2, 2)
        wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
        wn_wrapper(images)
def test_weightnorm_nonlayer(self):
    """Passing a tensor instead of a Keras layer raises AssertionError.

    The tensor shape is (2, 4, 3), matching the pytest variant of this test;
    the previous (2, 4, 43) appears to be a typo.
    """
    images = tf.random.uniform((2, 4, 3))
    with self.assertRaises(AssertionError):
        wrappers.WeightNormalization(images)
def test_non_kernel_layer():
    """Wrapping a layer without a `kernel` and calling it raises ValueError."""
    batch = tf.random.uniform((2, 2, 2))
    with pytest.raises(ValueError, match="contains a `kernel`"):
        pooling = tf.keras.layers.MaxPooling2D(2, 2)
        wrapper = wrappers.WeightNormalization(pooling)
        wrapper(batch)
def test_non_layer():
    """Passing a tensor instead of a Keras layer raises AssertionError."""
    not_a_layer = tf.random.uniform((2, 4, 3))
    with pytest.raises(AssertionError):
        wrappers.WeightNormalization(not_a_layer)
def test_weightnorm_nokernel(self):
    """Building a wrapper around a kernel-less pooling layer raises ValueError."""
    with self.assertRaises(ValueError):
        pooling = tf.keras.layers.MaxPooling2D(2, 2)
        wrapper = wrappers.WeightNormalization(pooling)
        wrapper.build((2, 2))
def test_weightnorm_with_rnn(self):
    """A WeightNormalization-wrapped SimpleRNN can be placed in a Sequential model.

    Only checks that model construction succeeds.
    """
    inputs = tf.keras.layers.Input(shape=(None, 3))
    wrapped_rnn = wrappers.WeightNormalization(tf.keras.layers.SimpleRNN(4))
    head = tf.keras.layers.Dense(1)
    model = tf.keras.models.Sequential(layers=[inputs, wrapped_rnn, head])