def test_hidden_types(hidden_cell, fft, rng, seed):
    """An LMU with a hidden cell should match a memory-only LMU whose output is
    then fed through that same hidden component applied separately."""

    x = rng.uniform(-1, 1, size=(2, 5, 32))

    lmu_kwargs = dict(
        memory_d=1,
        order=3,
        theta=4,
        kernel_initializer=tf.keras.initializers.glorot_uniform(seed=seed),
    )

    # reference: LMU memory alone (no hidden component)
    memory_only = tf.keras.layers.RNN(
        layers.LMUCell(hidden_cell=None, **lmu_kwargs),
        return_sequences=True,
    )
    reference = memory_only(x)

    # apply the hidden component on its own, matching how the LMU would use it
    if isinstance(hidden_cell, tf.keras.layers.SimpleRNNCell):
        reference = tf.keras.layers.RNN(hidden_cell, return_sequences=True)(
            reference
        )
    elif isinstance(hidden_cell, tf.keras.layers.Dense):
        reference = hidden_cell(reference)

    # combined LMU (FFT or cell implementation) with the hidden cell built in
    if fft:
        lmu = layers.LMUFFT(
            hidden_cell=hidden_cell, return_sequences=True, **lmu_kwargs
        )
    else:
        lmu = tf.keras.layers.RNN(
            layers.LMUCell(hidden_cell=hidden_cell, **lmu_kwargs),
            return_sequences=True,
        )

    combined = lmu(x)

    # FFT implementation accumulates slightly more floating-point error
    assert np.allclose(combined, reference, atol=2e-6 if fft else 1e-8)
def test_connection_params(fft, hidden_cell):
    """Kernel/output shapes should track the input_to_hidden / hidden_to_memory /
    memory_to_memory connection flags for both implementations."""

    input_shape = (32, 7 if fft else None, 6)
    x = tf.keras.Input(batch_shape=input_shape)

    layer_cls = layers.LMUFFT if fft else layers.LMUCell

    kwargs = dict(
        memory_d=1,
        order=3,
        theta=4,
        hidden_cell=hidden_cell if hidden_cell is None else hidden_cell(units=5),
        input_to_hidden=False,
    )
    if not fft:
        # these connections only exist in the cell implementation
        kwargs["hidden_to_memory"] = False
        kwargs["memory_to_memory"] = False

    # --- all optional connections disabled ---
    lmu = layer_cls(**kwargs)
    if fft:
        y = lmu(x)
    else:
        y = tf.keras.layers.RNN(lmu)(x)

    assert lmu.kernel.shape == (input_shape[-1], lmu.memory_d)
    if not fft:
        assert lmu.recurrent_kernel is None
    if hidden_cell is not None:
        assert lmu.hidden_cell.kernel.shape == (
            lmu.memory_d * lmu.order,
            lmu.hidden_cell.units,
        )
    assert y.shape == (
        input_shape[0],
        lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
    )

    # --- optional connections enabled (where applicable) ---
    kwargs["input_to_hidden"] = hidden_cell is not None
    if not fft:
        kwargs["hidden_to_memory"] = hidden_cell is not None
        kwargs["memory_to_memory"] = True

    lmu = layer_cls(**kwargs)
    if hidden_cell is not None:
        lmu.hidden_cell.built = False  # so that the kernel will be rebuilt
    if fft:
        y = lmu(x)
    else:
        y = tf.keras.layers.RNN(lmu)(x)

    # hidden_to_memory widens the input kernel by the hidden units
    assert lmu.kernel.shape == (
        input_shape[-1]
        + (0 if fft or hidden_cell is None else lmu.hidden_cell.units),
        lmu.memory_d,
    )
    if not fft:
        assert lmu.recurrent_kernel.shape == (
            lmu.order * lmu.memory_d,
            lmu.memory_d,
        )
    if hidden_cell is not None:
        # input_to_hidden widens the hidden kernel by the input dimensionality
        assert lmu.hidden_cell.kernel.shape == (
            lmu.memory_d * lmu.order + input_shape[-1],
            lmu.hidden_cell.units,
        )
    assert y.shape == (
        input_shape[0],
        lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
    )
def test_multivariate_lmu(rng):
    """One multivariate LMU should equal ``memory_d`` independent
    one-dimensional LMUs (the hidden part is an identity, so the raw memory
    states are exposed directly)."""

    memory_d = 4
    order = 16
    n_steps = 10
    input_d = 32

    input_enc = rng.uniform(0, 1, size=(input_d, memory_d))

    def identity_lmu(mem_d, enc):
        # LMU whose hidden cell is a no-op identity map, so the layer output
        # is exactly the flattened memory state
        return tf.keras.layers.RNN(
            layers.LMUCell(
                memory_d=mem_d,
                order=order,
                theta=n_steps,
                kernel_initializer=tf.initializers.constant(enc),
                hidden_cell=tf.keras.layers.SimpleRNNCell(
                    units=mem_d * order,
                    activation=None,
                    kernel_initializer=tf.initializers.constant(
                        np.eye(mem_d * order)
                    ),
                    recurrent_initializer=tf.initializers.zeros(),
                ),
            ),
            return_sequences=True,
        )

    # check that one multivariate LMU is the same as n one-dimensional LMUs
    # (omitting the hidden part)
    inp = tf.keras.Input(shape=(n_steps, input_d))
    multi_lmu = identity_lmu(memory_d, input_enc)(inp)
    lmus = [identity_lmu(1, input_enc[:, [i]])(inp) for i in range(memory_d)]

    model = tf.keras.Model(inp, [multi_lmu] + lmus)
    results = model.predict(rng.uniform(0, 1, size=(1, n_steps, input_d)))

    for i in range(memory_d):
        assert np.allclose(
            results[0][..., i * order : (i + 1) * order], results[i + 1], atol=2e-6
        )
def test_validation_errors():
    """Unsupported argument combinations should raise informative errors."""

    # the FFT implementation requires a known sequence length
    fft_layer = layers.LMUFFT(1, 2, 3, None)
    with pytest.raises(ValueError, match="temporal axis be fully specified"):
        fft_layer(tf.keras.Input((None, 32)))

    # connection flags that require a hidden cell must be rejected without one
    for cls, flag, msg in [
        (layers.LMUCell, "hidden_to_memory", "hidden_to_memory must be False"),
        (layers.LMUCell, "input_to_hidden", "input_to_hidden must be False"),
        (layers.LMUFFT, "input_to_hidden", "input_to_hidden must be False"),
    ]:
        with pytest.raises(ValueError, match=msg):
            cls(1, 2, 3, None, **{flag: True})
def test_fft(return_sequences, hidden_cell, memory_d, rng):
    """The FFT implementation should produce the same output as the cell
    implementation when given identical weights."""

    params = dict(memory_d=memory_d, order=2, theta=3, hidden_cell=hidden_cell())

    inputs = rng.uniform(-1, 1, size=(2, 10, 32))

    # reference output from the recurrent cell implementation
    recurrent = tf.keras.layers.RNN(
        layers.LMUCell(**params),
        return_sequences=return_sequences,
    )
    expected = recurrent(inputs)

    # build the FFT layer and copy over the cell's encoder weights
    fft_layer = layers.LMUFFT(return_sequences=return_sequences, **params)
    fft_layer.build(inputs.shape)
    fft_layer.kernel.assign(recurrent.cell.kernel)

    actual = fft_layer(inputs)

    assert np.allclose(expected, actual, atol=2e-6)
def test_save_load_serialization(mode, tmp_path):
    """Models containing each layer variant should round-trip through
    ``model.save`` / ``load_model`` and predict identically afterwards."""

    # FFT needs a fixed sequence length; the others accept variable length
    inp = tf.keras.Input((10 if mode == "fft" else None, 32))

    if mode == "cell":
        out = tf.keras.layers.RNN(
            layers.LMUCell(1, 2, 3, tf.keras.layers.SimpleRNNCell(4)),
            return_sequences=True,
        )(inp)
    elif mode == "lmu":
        out = layers.LMU(
            1,
            2,
            3,
            tf.keras.layers.SimpleRNNCell(4),
            return_sequences=True,
            memory_to_memory=True,
        )(inp)
    elif mode == "fft":
        out = layers.LMUFFT(
            1,
            2,
            3,
            tf.keras.layers.SimpleRNNCell(4),
            return_sequences=True,
        )(inp)

    model = tf.keras.Model(inp, out)
    model.save(str(tmp_path))

    custom_objects = {
        "LMUCell": layers.LMUCell,
        "LMU": layers.LMU,
        "LMUFFT": layers.LMUFFT,
    }
    restored = tf.keras.models.load_model(str(tmp_path), custom_objects=custom_objects)

    test_input = np.ones((32, 10, 32))
    assert np.allclose(model.predict(test_input), restored.predict(test_input))
def test_layer_vs_cell(has_input_kernel, fft, rng):
    """The ``LMU`` wrapper layer should match a hand-built RNN(LMUCell) once
    their weights are synchronized, and should auto-select the FFT
    implementation when possible."""

    n_steps = 10
    input_d = 32

    params = dict(
        memory_d=4 if has_input_kernel else input_d,
        order=12,
        theta=n_steps,
        kernel_initializer="glorot_uniform" if has_input_kernel else None,
        memory_to_memory=not fft,
    )

    def make_hidden():
        # fresh hidden cell per layer so weights are independent until copied
        return tf.keras.layers.SimpleRNNCell(units=64)

    x = rng.uniform(-1, 1, size=(2, n_steps, input_d))

    cell_rnn = tf.keras.layers.RNN(
        layers.LMUCell(hidden_cell=make_hidden(), **params),
        return_sequences=True,
    )
    cell_out = cell_rnn(x)

    wrapper = layers.LMU(return_sequences=True, hidden_cell=make_hidden(), **params)
    wrapper.build(x.shape)
    wrapper.layer.set_weights(cell_rnn.get_weights())
    wrapper_out = wrapper(x)

    # the wrapper should pick FFT when memory_to_memory is off
    assert isinstance(wrapper.layer, layers.LMUFFT if fft else tf.keras.layers.RNN)

    # weight-by-weight comparison (matched up by shape)
    by_shape = lambda w: w.shape.as_list()
    for w0, w1 in zip(
        sorted(cell_rnn.weights, key=by_shape),
        sorted(wrapper.weights, key=by_shape),
    ):
        assert np.allclose(w0.numpy(), w1.numpy())

    atol = 2e-6 if fft else 1e-8
    assert np.allclose(cell_out, cell_rnn(x), atol=atol)
    assert np.allclose(cell_out, wrapper_out, atol=atol)