Example #1
0
 def get_model():
     """Build a small Keras model that stacks two `Cross` layers.

     The model takes a 13-feature input. It applies `Cross` to (input, input),
     then again to (input, previous cross output). It projects the result to a
     single logit with a `Dense(1)` layer.

     Returns:
         A `tf.keras.Model` mapping the 13-feature input to one logit.
     """
     features = tf.keras.layers.Input(shape=(13, ))
     # Both cross layers take the raw input as their first argument; only the
     # second argument changes from layer to layer.
     crossed = Cross(projection_dim=None)(features, features)
     crossed = Cross(projection_dim=None)(features, crossed)
     head = tf.keras.layers.Dense(units=1)
     return tf.keras.Model(features, head(crossed))
Example #2
0
 def test_invalid_diag_scale(self):
     """A negative `diag_scale` must be rejected with a ValueError."""
     x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32)
     x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32)
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
     # which was removed from unittest in Python 3.12. Fixture setup stays
     # outside the context so only the layer itself can satisfy the assertion.
     with self.assertRaisesRegex(ValueError,
                                 r"`diag_scale` should be non-negative"):
         # The error may be raised at construction or on first call (build);
         # both are kept inside the context.
         layer = Cross(diag_scale=-1.)
         layer(x0, x)
Example #3
0
 def test_invalid_proj_dim(self):
     """A `projection_dim` too large for the input width must raise ValueError.

     Inputs have a last dimension of 5 and `projection_dim=6` is requested.
     """
     # Generated in the same order as before (x0 first, then x).
     x0 = np.random.random((12, 5))
     x = np.random.random((12, 5))
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(ValueError,
                                 r"should be smaller than last_dim / 2"):
         # Presumably the check runs when the layer is built on first call;
         # construction is kept inside the context in case it raises earlier.
         layer = Cross(projection_dim=6)
         layer(x0, x)
Example #4
0
 def test_low_rank_matrix(self):
     """Check `Cross` output with a rank-1 projection and all-ones kernels."""
     base = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
     current = np.array([[0.4, 0.5, 0.6]], dtype=np.float32)
     # NOTE(review): assumes the layer computes x0 * (V(U(x)) + b) + x with a
     # zero bias; with all-ones kernels that gives x0 * sum(x) + x — confirm
     # against the Cross implementation.
     expected = np.asarray([[0.55, 0.8, 1.05]])
     cross_layer = Cross(projection_dim=1, kernel_initializer="ones")
     actual = cross_layer(base, current)
     self.evaluate(tf.compat.v1.global_variables_initializer())
     self.assertAllClose(expected, actual)
Example #5
0
 def test_diag_scale(self):
     """Check full-rank `Cross` output when `diag_scale=1.` is applied."""
     base = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
     current = np.array([[0.4, 0.5, 0.6]], dtype=np.float32)
     # NOTE(review): assumes diag_scale adds diag_scale * x to the kernel
     # product before the x0 multiply — confirm against the Cross
     # implementation.
     expected = np.asarray([[0.59, 0.9, 1.23]])
     cross_layer = Cross(
         projection_dim=None,
         diag_scale=1.,
         kernel_initializer="ones",
     )
     actual = cross_layer(base, current)
     self.evaluate(tf.compat.v1.global_variables_initializer())
     self.assertAllClose(expected, actual)
Example #6
0
 def test_serialization(self):
     """A `Cross` layer keeps its config through Keras serialize/deserialize."""
     original = Cross(projection_dim=None)
     restored = tf.keras.layers.deserialize(
         tf.keras.layers.serialize(original))
     self.assertEqual(original.get_config(), restored.get_config())
Example #7
0
 def test_unsupported_input_dim(self):
     """Inputs whose last dimensions differ (5 vs 7) must raise ValueError."""
     x0 = np.random.random((12, 5))
     x = np.random.random((12, 7))
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
     # which was removed from unittest in Python 3.12. Only the layer
     # construction and call sit inside the assertion context.
     with self.assertRaisesRegex(ValueError, r"dimension mismatch"):
         layer = Cross()
         layer(x0, x)
Example #8
0
 def test_one_input(self):
     """Calling `Cross` with a single input crosses that input with itself."""
     only_input = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
     # NOTE(review): the expected value is consistent with x = x0 when the
     # second argument is omitted — confirm against the Cross implementation.
     expected = np.asarray([[0.16, 0.32, 0.48]])
     cross_layer = Cross(projection_dim=None, kernel_initializer="ones")
     actual = cross_layer(only_input)
     self.evaluate(tf.compat.v1.global_variables_initializer())
     self.assertAllClose(expected, actual)