def get_model():
    """Build an uncompiled Keras model with two stacked ``Cross`` layers.

    The model takes a 13-feature input. Each ``Cross`` layer is called with
    the original input as its first argument and the previous output as its
    second. A single-unit ``Dense`` layer then produces the output.

    Returns:
        A ``tf.keras.Model`` mapping the 13-feature input to one logit.
    """
    features = tf.keras.layers.Input(shape=(13,))
    crossed = features
    # Two Cross layers, created and applied in order:
    # Cross(features, features), then Cross(features, <first output>).
    for _ in range(2):
        crossed = Cross(projection_dim=None)(features, crossed)
    output = tf.keras.layers.Dense(units=1)(crossed)
    return tf.keras.Model(features, output)
def test_invalid_diag_scale(self):
    """A negative ``diag_scale`` must be rejected with a ValueError."""
    x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32)
    x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32)
    # ``assertRaisesRegexp`` is a deprecated alias that was removed from
    # unittest in Python 3.12; ``assertRaisesRegex`` is the supported name.
    # Cross may validate in its constructor or on first call, so both the
    # construction and the call stay inside the context manager.
    with self.assertRaisesRegex(
        ValueError, r"`diag_scale` should be non-negative"
    ):
        layer = Cross(diag_scale=-1.)
        layer(x0, x)
def test_invalid_proj_dim(self):
    """A ``projection_dim`` too large for the input width must raise."""
    # Only the shapes matter here: last_dim is 5 and projection_dim is 6.
    x0 = np.random.random((12, 5))
    x = np.random.random((12, 5))
    # ``assertRaisesRegexp`` is a deprecated alias that was removed from
    # unittest in Python 3.12; ``assertRaisesRegex`` is the supported name.
    # Cross may validate in its constructor or on first call, so both the
    # construction and the call stay inside the context manager.
    with self.assertRaisesRegex(
        ValueError, r"should be smaller than last_dim / 2"
    ):
        layer = Cross(projection_dim=6)
        layer(x0, x)
def test_low_rank_matrix(self):
    """Check Cross output with ``projection_dim=1`` and all-ones kernels."""
    base = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
    prev = np.array([[0.4, 0.5, 0.6]], dtype=np.float32)
    # NOTE(review): these values look like x0 * (W @ x) + x with an all-ones
    # rank-1 W; confirm against the Cross implementation.
    expected = np.array([[0.55, 0.8, 1.05]])

    cross = Cross(projection_dim=1, kernel_initializer="ones")
    result = cross(base, prev)
    # Initialize the layer's variables (required under TF1-style graph mode).
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertAllClose(expected, result)
def test_diag_scale(self):
    """Check Cross output with a full kernel and ``diag_scale=1.``."""
    base = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
    prev = np.array([[0.4, 0.5, 0.6]], dtype=np.float32)
    # NOTE(review): these values look like x0 * ((W + I) @ x) + x with an
    # all-ones W; confirm against the Cross implementation.
    expected = np.array([[0.59, 0.9, 1.23]])

    cross = Cross(
        projection_dim=None, diag_scale=1., kernel_initializer="ones"
    )
    result = cross(base, prev)
    # Initialize the layer's variables (required under TF1-style graph mode).
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertAllClose(expected, result)
def test_serialization(self):
    """A Cross layer must survive a Keras serialize/deserialize round trip."""
    original = Cross(projection_dim=None)
    restored = tf.keras.layers.deserialize(
        tf.keras.layers.serialize(original)
    )
    # The rebuilt layer must report exactly the same configuration.
    self.assertEqual(original.get_config(), restored.get_config())
def test_unsupported_input_dim(self):
    """Inputs whose last dimensions differ (5 vs 7) must be rejected."""
    # Only the shapes matter here; the random values are irrelevant.
    x0 = np.random.random((12, 5))
    x = np.random.random((12, 7))
    # ``assertRaisesRegexp`` is a deprecated alias that was removed from
    # unittest in Python 3.12; ``assertRaisesRegex`` is the supported name.
    # Cross may validate in its constructor or on first call, so both the
    # construction and the call stay inside the context manager.
    with self.assertRaisesRegex(ValueError, r"dimension mismatch"):
        layer = Cross()
        layer(x0, x)
def test_one_input(self):
    """Calling Cross with a single input tensor yields the expected output."""
    base = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
    # NOTE(review): these values look like the single-input call behaving as
    # Cross(x0, x0) with an all-ones W; confirm against Cross.
    expected = np.array([[0.16, 0.32, 0.48]])

    cross = Cross(projection_dim=None, kernel_initializer="ones")
    result = cross(base)
    # Initialize the layer's variables (required under TF1-style graph mode).
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertAllClose(expected, result)