Ejemplo n.º 1
0
def test_invalid_inputs():
    """Calling the layer with three inputs must raise a ValueError.

    Only the layer call sits inside ``pytest.raises`` so that a ValueError
    from input or layer setup cannot make the test pass spuriously.
    """
    x0 = np.random.random((12, 5))
    x = np.random.random((12, 5))
    x1 = np.random.random((12, 5))
    layer = PolynomialCrossing(projection_dim=6)
    with pytest.raises(ValueError, match="must be a tuple or list of size 2"):
        layer([x0, x, x1])
Ejemplo n.º 2
0
def test_invalid_proj_dim():
    """A non-None ``projection_dim`` must be rejected with a ValueError.

    The expected message fragment is "is not supported yet".
    """
    x0 = np.random.random((12, 5))
    x = np.random.random((12, 5))
    # Layer construction stays inside the context: the projection_dim check
    # may run either when the layer is built or when it is first called.
    with pytest.raises(ValueError, match="is not supported yet"):
        layer = PolynomialCrossing(projection_dim=6)
        layer([x0, x])
Ejemplo n.º 3
0
def test_invalid_proj_dim():
    """A ``projection_dim`` too large for the input width must be rejected.

    The input's last dimension is 5 and ``projection_dim`` is 6. The call
    must raise a ValueError mentioning "should be smaller than last_dim / 2".
    """
    x0 = np.random.random((12, 5))
    x = np.random.random((12, 5))
    layer = PolynomialCrossing(projection_dim=6)
    # Only the call is guarded: the check depends on the input's last
    # dimension, which is unknown until the layer sees its inputs.
    with pytest.raises(ValueError,
                       match="should be smaller than last_dim / 2"):
        layer([x0, x])
Ejemplo n.º 4
0
def test_invalid_inputs():
    """Calling the layer with three inputs must raise a ValueError.

    ``match=`` performs the message check inside ``pytest.raises``. Only the
    call is guarded so that setup errors cannot satisfy the assertion.
    """
    x0 = np.random.random((12, 5))
    x = np.random.random((12, 5))
    x1 = np.random.random((12, 5))
    layer = PolynomialCrossing(projection_dim=6)
    with pytest.raises(ValueError, match="must be a tuple or list of size 2"):
        layer([x0, x, x1])
Ejemplo n.º 5
0
 def test_invalid_inputs(self):
     """Calling the layer with three inputs must raise a ValueError."""
     x0 = np.random.random((12, 5))
     x = np.random.random((12, 5))
     x1 = np.random.random((12, 5))
     layer = PolynomialCrossing(projection_dim=6)
     # assertRaisesRegexp is a deprecated alias removed in Python 3.12;
     # assertRaisesRegex is the supported name. Only the call is guarded.
     with self.assertRaisesRegex(ValueError,
                                 r"must be a tuple or list of size 2"):
         layer([x0, x, x1])
Ejemplo n.º 6
0
 def test_full_matrix(self):
     """Check the layer output with all-ones kernels and no projection.

     The variables initializer runs before the output is compared with a
     fixed expected array.
     """
     x0 = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
     x = np.array([[0.4, 0.5, 0.6]], dtype=np.float32)
     expected = np.asarray([[0.55, 0.8, 1.05]])
     layer = PolynomialCrossing(
         projection_dim=None,
         kernel_initializer="ones",
     )
     output = layer([x0, x])
     self.evaluate(tf.compat.v1.global_variables_initializer())
     self.assertAllClose(expected, output)
Ejemplo n.º 7
0
def test_diag_scale():
    """Check the layer output with ``diag_scale=1.0`` and all-ones kernels.

    The inputs are float32, so the comparison uses ``rtol=atol=1e-6``, the
    defaults of ``tf.test.TestCase.assertAllClose``. numpy's default
    ``rtol=1e-7`` is tighter than float32 resolution, so a one-ulp difference
    in the computed value could fail the test spuriously.
    """
    x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32)
    x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32)
    layer = PolynomialCrossing(projection_dim=None,
                               diag_scale=1.0,
                               kernel_initializer="ones")
    output = layer([x0, x])
    np.testing.assert_allclose([[0.59, 0.9, 1.23]], output,
                               rtol=1e-6, atol=1e-6)
Ejemplo n.º 8
0
def test_serialization():
    """Check that a layer survives a Keras serialize/deserialize round trip.

    The configuration reported by the original layer must equal the one
    reported by the deserialized copy.
    """
    original = PolynomialCrossing(projection_dim=None)
    payload = tf.keras.layers.serialize(original)
    restored = tf.keras.layers.deserialize(payload)
    original_config = original.get_config()
    restored_config = restored.get_config()
    assert original_config == restored_config
Ejemplo n.º 9
0
def test_full_matrix():
    """Check the layer output with all-ones kernels and no projection.

    The inputs are float32, so the comparison uses ``rtol=atol=1e-6``, the
    defaults of ``tf.test.TestCase.assertAllClose``. numpy's default
    ``rtol=1e-7`` is tighter than float32 resolution and makes the check
    fragile.
    """
    x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32)
    x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32)
    layer = PolynomialCrossing(projection_dim=None, kernel_initializer="ones")
    output = layer([x0, x])
    np.testing.assert_allclose([[0.55, 0.8, 1.05]], output,
                               rtol=1e-6, atol=1e-6)
Ejemplo n.º 10
0
 def test_serialization(self):
     """Check that a layer's config survives a serialize/deserialize round trip."""
     original = PolynomialCrossing(projection_dim=None)
     restored = tf.keras.layers.deserialize(
         tf.keras.layers.serialize(original))
     self.assertEqual(original.get_config(), restored.get_config())
Ejemplo n.º 11
0
 def test_invalid_proj_dim(self):
     """A non-None ``projection_dim`` must be rejected with a ValueError."""
     x0 = np.random.random((12, 5))
     x = np.random.random((12, 5))
     # assertRaisesRegexp is a deprecated alias removed in Python 3.12;
     # assertRaisesRegex is the supported name. Construction stays inside
     # the context because the check may run at construction or at call time.
     with self.assertRaisesRegex(ValueError, r"is not supported yet"):
         layer = PolynomialCrossing(projection_dim=6)
         layer([x0, x])