def test_esn_echo_state_property_norm2():
    """Check the echo state property of an ESNCell built with ``use_norm2=True``.

    A fully connected, all-ones-initialised cell with three units is built,
    and the largest absolute eigenvalue of its first weight (used here as the
    recurrent weight matrix) must be strictly below one.
    """
    esn_cell = ESNCell(
        units=3,
        use_norm2=True,
        recurrent_initializer="ones",
        connectivity=1.0,
    )
    esn_cell.build((3, 3))

    # get_weights()[0] is the matrix this test treats as the recurrent kernel.
    kernel = tf.constant(esn_cell.get_weights()[0])

    # tf.linalg.eig returns (eigenvalues, eigenvectors); only the former matter.
    eigenvalues = tf.linalg.eig(kernel)[0]
    spectral_radius = tf.reduce_max(tf.abs(eigenvalues))

    assert spectral_radius < 1, "max(eig(W)) < 1"
def test_esn_connectivity():
    """Check that the share of non-zero recurrent weights matches ``connectivity``.

    A 1000-unit cell is built with ``connectivity=0.5``; the fraction of
    non-zero entries in its first weight (used here as the recurrent weight
    matrix) must agree with the requested value within a relative tolerance
    of 1e-2.
    """
    units = 1000
    connectivity = 0.5

    esn_cell = ESNCell(
        units=units,
        connectivity=connectivity,
        use_norm2=True,
        recurrent_initializer="ones",
    )
    esn_cell.build((3, 3))

    kernel = tf.constant(esn_cell.get_weights()[0])

    # Fraction of non-zero entries out of the units * units total.
    nonzero_count = tf.math.count_nonzero(kernel)
    measured = tf.divide(nonzero_count, units**2)

    observed = np.asarray([measured])
    expected = np.asanyarray([connectivity])
    np.testing.assert_allclose(observed, expected, rtol=1e-2)