Ejemplo n.º 1
0
def test_esn_echo_state_property_norm2():
    """Check that a ``use_norm2=True`` ESN cell has recurrent spectral radius < 1.

    A 3-unit cell is built with an all-ones recurrent initializer and full
    connectivity (``connectivity=1.0``). The largest eigenvalue magnitude
    (spectral radius) of the resulting recurrent weight matrix must be below 1.
    Presumably ``use_norm2`` makes the cell rescale the raw matrix using its
    2-norm. The exact rule lives in ``ESNCell`` and is not visible here.
    """
    use_norm2 = True
    units = 3
    cell = ESNCell(
        units=units, use_norm2=use_norm2, recurrent_initializer="ones", connectivity=1.0
    )
    cell.build((3, 3))
    # NOTE(review): assumes get_weights()[0] is the (units x units) recurrent
    # matrix, i.e. the recurrent kernel is the first weight created by build().
    recurrent_weights = tf.constant(cell.get_weights()[0])
    # Only eigenvalues are needed; eigvals skips computing the eigenvectors.
    spectral_radius = tf.reduce_max(tf.abs(tf.linalg.eigvals(recurrent_weights)))
    # Report the observed value on failure instead of restating the condition.
    assert spectral_radius < 1, (
        f"expected spectral radius max(|eig(W)|) < 1, got {float(spectral_radius)}"
    )
Ejemplo n.º 2
0
def test_esn_connectivity():
    """Check that the fraction of non-zero recurrent weights matches ``connectivity``.

    A 1000-unit cell is built with an all-ones recurrent initializer and
    ``connectivity=0.5``. The number of non-zero entries in the recurrent
    matrix, divided by its ``units**2`` entries, must equal ``connectivity``
    within a 1% relative tolerance.

    Presumably every zero comes from the connectivity mask, because the
    initializer itself produces no zeros. The large matrix keeps the observed
    fraction close to the target if that mask is sampled randomly.
    """
    units = 1000
    connectivity = 0.5
    cell = ESNCell(
        units=units,
        connectivity=connectivity,
        use_norm2=True,
        recurrent_initializer="ones",
    )
    cell.build((3, 3))
    # NOTE(review): assumes get_weights()[0] is the (units x units) recurrent
    # matrix, i.e. the recurrent kernel is the first weight created by build().
    recurrent_weights = tf.constant(cell.get_weights()[0])
    num_non_zero = tf.math.count_nonzero(recurrent_weights)
    actual_connectivity = float(num_non_zero) / units**2
    # Compare plain scalars; name the tolerance explicitly (it is relative).
    np.testing.assert_allclose(actual_connectivity, connectivity, rtol=1e-2)