import numpy as np
import tensorflow as tf
from tensorflow import keras

# Import path assumed from the package layout; adjust if ESNCell lives elsewhere.
from tensorflow_addons.rnn import ESNCell


def test_esn_echo_state_property_norm2():
    """With use_norm2=True the recurrent weights are rescaled so that the
    largest eigenvalue magnitude stays below 1 (echo state property)."""
    use_norm2 = True
    units = 3
    cell = ESNCell(
        units=units,
        use_norm2=use_norm2,
        recurrent_initializer="ones",
        connectivity=1.0,
    )
    cell.build((3, 3))
    recurrent_weights = tf.constant(cell.get_weights()[0])
    max_eig = tf.reduce_max(tf.abs(tf.linalg.eig(recurrent_weights)[0]))
    assert max_eig < 1, "max(eig(W)) < 1"
def test_esn_config():
    """get_config()/from_config() round-trips the full cell configuration."""
    cell = ESNCell(
        units=3,
        connectivity=1,
        leaky=1,
        spectral_radius=0.9,
        use_norm2=False,
        use_bias=True,
        activation="tanh",
        kernel_initializer="glorot_uniform",
        recurrent_initializer="glorot_uniform",
        bias_initializer="glorot_uniform",
        name="esn_cell_3",
    )
    expected_config = {
        "name": "esn_cell_3",
        "trainable": True,
        "dtype": "float32",
        "units": 3,
        "connectivity": 1,
        "leaky": 1,
        "spectral_radius": 0.9,
        "use_norm2": False,
        "use_bias": True,
        "activation": tf.keras.activations.serialize(tf.keras.activations.get("tanh")),
        "kernel_initializer": tf.keras.initializers.serialize(
            tf.keras.initializers.get("glorot_uniform")
        ),
        "recurrent_initializer": tf.keras.initializers.serialize(
            tf.keras.initializers.get("glorot_uniform")
        ),
        "bias_initializer": tf.keras.initializers.serialize(
            tf.keras.initializers.get("glorot_uniform")
        ),
    }
    config = cell.get_config()
    assert config == expected_config

    restored_cell = ESNCell.from_config(config)
    restored_config = restored_cell.get_config()
    assert config == restored_config
def test_esn_connectivity():
    """The fraction of non-zero recurrent weights matches the requested connectivity."""
    units = 1000
    connectivity = 0.5
    cell = ESNCell(
        units=units,
        connectivity=connectivity,
        use_norm2=True,
        recurrent_initializer="ones",
    )
    cell.build((3, 3))
    recurrent_weights = tf.constant(cell.get_weights()[0])
    num_non_zero = tf.math.count_nonzero(recurrent_weights)
    actual_connectivity = tf.divide(num_non_zero, units**2)
    np.testing.assert_allclose(
        np.asarray([actual_connectivity]), np.asanyarray([connectivity]), 1e-2
    )
# ESN layer constructor: builds an ESNCell and hands it to the Keras RNN base class.
# (TensorLike, FloatTensorLike, Activation and Initializer are the package's type
# aliases, assumed to be imported in the module that defines this class.)
def __init__(
    self,
    units: TensorLike,
    connectivity: FloatTensorLike = 0.1,
    leaky: FloatTensorLike = 1,
    spectral_radius: FloatTensorLike = 0.9,
    use_norm2: bool = False,
    use_bias: bool = True,
    activation: Activation = "tanh",
    kernel_initializer: Initializer = "glorot_uniform",
    recurrent_initializer: Initializer = "glorot_uniform",
    bias_initializer: Initializer = "zeros",
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    **kwargs,
):
    cell = ESNCell(
        units,
        connectivity=connectivity,
        leaky=leaky,
        spectral_radius=spectral_radius,
        use_norm2=use_norm2,
        use_bias=use_bias,
        activation=activation,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        dtype=kwargs.get("dtype"),
    )
    super().__init__(
        cell,
        return_sequences=return_sequences,
        go_backwards=go_backwards,
        unroll=unroll,
        **kwargs,
    )
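# Usage sketch for the constructor above (assumption: it belongs to an `ESN` layer
# class exported by this package; the class name and the shapes below are
# illustrative, not taken from the tests in this file):
#
#     inputs = tf.random.normal((8, 20, 4))            # (batch, time, features)
#     esn_layer = ESN(units=32, connectivity=0.1,
#                     spectral_radius=0.9, return_sequences=True)
#     reservoir_states = esn_layer(inputs)             # -> (8, 20, 32)
#
# In an echo state network only a downstream readout is trained; the reservoir
# weights created by ESNCell stay fixed.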
def test_base_esn():
    """Forward pass of ESNCell with constant weights matches hand-computed values."""
    units = 3
    expected_output = np.array(
        [[2.77, 2.77, 2.77], [4.77, 4.77, 4.77], [6.77, 6.77, 6.77]],
        dtype=np.float32,
    )
    const_initializer = tf.constant_initializer(0.5)
    cell = ESNCell(
        units=units,
        connectivity=1,
        leaky=1,
        spectral_radius=0.9,
        use_norm2=True,
        use_bias=True,
        activation=None,
        kernel_initializer=const_initializer,
        recurrent_initializer=const_initializer,
        bias_initializer=const_initializer,
    )
    inputs = tf.constant(
        np.array(
            [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]],
            dtype=np.float32,
        ),
        dtype=tf.float32,
    )
    state_value = tf.constant(
        0.3 * np.ones((units, units), dtype=np.float32), dtype=tf.float32
    )
    init_state = [state_value, state_value]
    output, state = cell(inputs, init_state)

    np.testing.assert_allclose(output, expected_output, 1e-5)
    np.testing.assert_allclose(state, expected_output, 1e-5)
def test_esn_keras_rnn():
    """ESNCell plugs into keras.layers.RNN and produces (batch, units) outputs."""
    cell = ESNCell(10)
    seq_input = tf.convert_to_tensor(
        np.random.rand(2, 3, 5), name="seq_input", dtype=tf.float32
    )
    rnn_layer = keras.layers.RNN(cell=cell)
    rnn_outputs = rnn_layer(seq_input)
    assert rnn_outputs.shape == (2, 10)
def test_esn_keras_rnn_e2e():
    """An ESNCell-based model compiles and trains end to end for one epoch."""
    inputs = np.random.random((2, 3, 4))
    targets = np.abs(np.random.random((2, 5)))
    targets /= targets.sum(axis=-1, keepdims=True)

    cell = ESNCell(5)
    model = keras.models.Sequential()
    model.add(keras.layers.Masking(input_shape=(3, 4)))
    model.add(keras.layers.RNN(cell))
    model.compile(loss="categorical_crossentropy", optimizer="rmsprop")
    model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)