# Example 1
def test_hidden_types(hidden_cell, fft, rng, seed):
    """Check that an LMU with a hidden cell matches a hidden-cell-free LMU
    followed by the equivalent standalone layer applied to its output."""

    inputs = rng.uniform(-1, 1, size=(2, 5, 32))

    shared_params = dict(
        memory_d=1,
        order=3,
        theta=4,
        kernel_initializer=tf.keras.initializers.glorot_uniform(seed=seed),
    )

    # Reference pipeline: run the memory component on its own, then apply
    # the hidden component separately on top of its output.
    memory_only = tf.keras.layers.RNN(
        layers.LMUCell(hidden_cell=None, **shared_params),
        return_sequences=True,
    )
    reference = memory_only(inputs)
    if isinstance(hidden_cell, tf.keras.layers.SimpleRNNCell):
        reference = tf.keras.layers.RNN(hidden_cell, return_sequences=True)(reference)
    elif isinstance(hidden_cell, tf.keras.layers.Dense):
        reference = hidden_cell(reference)

    # Combined pipeline: the same hidden cell embedded inside the LMU itself,
    # using either the FFT or the RNN implementation.
    if fft:
        combined = layers.LMUFFT(
            hidden_cell=hidden_cell, return_sequences=True, **shared_params
        )
    else:
        combined = tf.keras.layers.RNN(
            layers.LMUCell(hidden_cell=hidden_cell, **shared_params),
            return_sequences=True,
        )
    combined_output = combined(inputs)

    # The FFT implementation accumulates slightly more floating point error.
    assert np.allclose(combined_output, reference, atol=2e-6 if fft else 1e-8)
# Example 2
def test_connection_params(fft, hidden_cell):
    """Check that the optional connection flags (``input_to_hidden``,
    ``hidden_to_memory``, ``memory_to_memory``) control which kernels are
    created and what shapes they take.

    Parameterized over the FFT vs. RNN implementations and over
    ``hidden_cell`` being ``None`` or a cell class to instantiate.
    """
    # The FFT implementation requires a fully-specified temporal axis,
    # so only give it a concrete number of timesteps.
    input_shape = (32, 7 if fft else None, 6)

    x = tf.keras.Input(batch_shape=input_shape)

    lmu_args = dict(
        memory_d=1,
        order=3,
        theta=4,
        hidden_cell=hidden_cell if hidden_cell is None else hidden_cell(units=5),
        input_to_hidden=False,
    )
    if not fft:
        # These recurrent connections only exist in the cell/RNN implementation.
        lmu_args["hidden_to_memory"] = False
        lmu_args["memory_to_memory"] = False

    # First pass: all optional connections disabled.
    lmu = layers.LMUCell(**lmu_args) if not fft else layers.LMUFFT(**lmu_args)
    y = lmu(x) if fft else tf.keras.layers.RNN(lmu)(x)
    # Encoder kernel maps raw input straight into the memory.
    assert lmu.kernel.shape == (input_shape[-1], lmu.memory_d)
    if not fft:
        # No memory_to_memory connection means no recurrent kernel at all.
        assert lmu.recurrent_kernel is None
    if hidden_cell is not None:
        # With input_to_hidden off, the hidden cell only sees the memory state.
        assert lmu.hidden_cell.kernel.shape == (
            lmu.memory_d * lmu.order,
            lmu.hidden_cell.units,
        )
    assert y.shape == (
        input_shape[0],
        lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
    )

    # Second pass: enable every connection that applies to this configuration.
    lmu_args["input_to_hidden"] = hidden_cell is not None
    if not fft:
        lmu_args["hidden_to_memory"] = hidden_cell is not None
        lmu_args["memory_to_memory"] = True

    lmu = layers.LMUCell(**lmu_args) if not fft else layers.LMUFFT(**lmu_args)
    if hidden_cell is not None:
        lmu.hidden_cell.built = False  # so that the kernel will be rebuilt
    y = lmu(x) if fft else tf.keras.layers.RNN(lmu)(x)
    # hidden_to_memory widens the encoder input by the hidden units
    # (cell/RNN implementation only).
    assert lmu.kernel.shape == (
        input_shape[-1] + (0 if fft or hidden_cell is None else lmu.hidden_cell.units),
        lmu.memory_d,
    )
    if not fft:
        # memory_to_memory adds a recurrent kernel over the flattened memory.
        assert lmu.recurrent_kernel.shape == (
            lmu.order * lmu.memory_d,
            lmu.memory_d,
        )
    if hidden_cell is not None:
        # input_to_hidden concatenates the raw input onto the hidden cell input.
        assert lmu.hidden_cell.kernel.shape == (
            lmu.memory_d * lmu.order + input_shape[-1],
            lmu.hidden_cell.units,
        )
    assert y.shape == (
        input_shape[0],
        lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
    )
# Example 3
def test_multivariate_lmu(rng):
    """Check that one multivariate LMU equals ``memory_d`` independent
    one-dimensional LMUs (omitting the hidden part)."""

    memory_d = 4
    order = 16
    n_steps = 10
    input_d = 32

    input_enc = rng.uniform(0, 1, size=(input_d, memory_d))

    def identity_lmu(dims, encoders):
        # Build an LMU whose hidden cell is an identity pass-through,
        # so the layer output exposes the raw memory state directly.
        return tf.keras.layers.RNN(
            layers.LMUCell(
                memory_d=dims,
                order=order,
                theta=n_steps,
                kernel_initializer=tf.initializers.constant(encoders),
                hidden_cell=tf.keras.layers.SimpleRNNCell(
                    units=dims * order,
                    activation=None,
                    kernel_initializer=tf.initializers.constant(np.eye(dims * order)),
                    recurrent_initializer=tf.initializers.zeros(),
                ),
            ),
            return_sequences=True,
        )

    inp = tf.keras.Input(shape=(n_steps, input_d))
    multi_lmu = identity_lmu(memory_d, input_enc)(inp)
    # One single-dimension LMU per memory dimension, each seeing only its
    # own column of the shared encoder.
    lmus = [identity_lmu(1, input_enc[:, [i]])(inp) for i in range(memory_d)]

    model = tf.keras.Model(inp, [multi_lmu] + lmus)

    results = model.predict(rng.uniform(0, 1, size=(1, n_steps, input_d)))

    # Each order-sized slice of the multivariate memory should match the
    # corresponding one-dimensional LMU's output.
    for i in range(memory_d):
        assert np.allclose(
            results[0][..., i * order : (i + 1) * order], results[i + 1], atol=2e-6
        )
# Example 4
def test_validation_errors():
    """Check that unsupported configurations raise informative ValueErrors."""

    # The FFT implementation requires a fully-specified temporal axis at
    # call time.
    fft_layer = layers.LMUFFT(1, 2, 3, None)
    with pytest.raises(ValueError, match="temporal axis be fully specified"):
        fft_layer(tf.keras.Input((None, 32)))

    # Connection options that are invalid when there is no hidden cell
    # (or, for input_to_hidden, invalid for the given implementation).
    invalid_configs = [
        (layers.LMUCell, "hidden_to_memory"),
        (layers.LMUCell, "input_to_hidden"),
        (layers.LMUFFT, "input_to_hidden"),
    ]
    for layer_cls, flag in invalid_configs:
        with pytest.raises(ValueError, match=f"{flag} must be False"):
            layer_cls(1, 2, 3, None, **{flag: True})
# Example 5
def test_fft(return_sequences, hidden_cell, memory_d, rng):
    """Check that the FFT implementation matches the RNN implementation
    when both share the same weights."""

    layer_kwargs = dict(memory_d=memory_d, order=2, theta=3, hidden_cell=hidden_cell())

    data = rng.uniform(-1, 1, size=(2, 10, 32))

    rnn_impl = tf.keras.layers.RNN(
        layers.LMUCell(**layer_kwargs),
        return_sequences=return_sequences,
    )
    expected = rnn_impl(data)

    fft_impl = layers.LMUFFT(return_sequences=return_sequences, **layer_kwargs)
    # Build first so the encoder kernel exists, then copy it from the RNN
    # cell so both implementations compute with identical weights.
    fft_impl.build(data.shape)
    fft_impl.kernel.assign(rnn_impl.cell.kernel)

    assert np.allclose(expected, fft_impl(data), atol=2e-6)
# Example 6
def test_save_load_serialization(mode, tmp_path):
    """Check that a model containing an LMU layer survives a save/load
    round trip with identical predictions."""

    # Only the FFT implementation needs a concrete number of timesteps.
    inp = tf.keras.Input((10 if mode == "fft" else None, 32))
    hidden = tf.keras.layers.SimpleRNNCell(4)
    if mode == "cell":
        out = tf.keras.layers.RNN(
            layers.LMUCell(1, 2, 3, hidden),
            return_sequences=True,
        )(inp)
    elif mode == "lmu":
        out = layers.LMU(
            1,
            2,
            3,
            hidden,
            return_sequences=True,
            memory_to_memory=True,
        )(inp)
    elif mode == "fft":
        out = layers.LMUFFT(
            1,
            2,
            3,
            hidden,
            return_sequences=True,
        )(inp)

    model = tf.keras.Model(inp, out)
    model.save(str(tmp_path))

    # The custom layer classes must be supplied for deserialization.
    model_load = tf.keras.models.load_model(
        str(tmp_path),
        custom_objects={
            "LMUCell": layers.LMUCell,
            "LMU": layers.LMU,
            "LMUFFT": layers.LMUFFT,
        },
    )

    probe = np.ones((32, 10, 32))
    assert np.allclose(model.predict(probe), model_load.predict(probe))
# Example 7
def test_layer_vs_cell(has_input_kernel, fft, rng):
    """Check that the high-level ``layers.LMU`` layer matches wrapping an
    ``LMUCell`` in a Keras RNN, and that it auto-selects the FFT
    implementation when ``memory_to_memory`` is disabled.

    Parameterized over whether an input kernel is trained and over the
    FFT vs. RNN code path.
    """
    n_steps = 10
    input_d = 32
    kwargs = dict(
        memory_d=4 if has_input_kernel else input_d,
        order=12,
        theta=n_steps,
        kernel_initializer="glorot_uniform" if has_input_kernel else None,
        memory_to_memory=not fft,
    )

    # Factory rather than a shared instance, so each layer builds its own
    # hidden-cell weights. (def instead of an assigned lambda, per PEP 8.)
    def hidden_cell():
        return tf.keras.layers.SimpleRNNCell(units=64)

    inp = rng.uniform(-1, 1, size=(2, n_steps, input_d))

    lmu_cell = tf.keras.layers.RNN(
        layers.LMUCell(hidden_cell=hidden_cell(), **kwargs),
        return_sequences=True,
    )
    cell_out = lmu_cell(inp)

    lmu_layer = layers.LMU(return_sequences=True, hidden_cell=hidden_cell(), **kwargs)
    lmu_layer.build(inp.shape)
    # Copy weights across so both implementations compute identical values.
    lmu_layer.layer.set_weights(lmu_cell.get_weights())
    layer_out = lmu_layer(inp)

    # Without memory_to_memory, the layer should pick the FFT implementation.
    assert isinstance(lmu_layer.layer, layers.LMUFFT if fft else tf.keras.layers.RNN)

    # Weight lists may be ordered differently between the two layers, so
    # pair them up by shape before comparing.
    for w0, w1 in zip(
        sorted(lmu_cell.weights, key=lambda w: w.shape.as_list()),
        sorted(lmu_layer.weights, key=lambda w: w.shape.as_list()),
    ):
        assert np.allclose(w0.numpy(), w1.numpy())

    # The FFT implementation accumulates slightly more floating point error;
    # also confirm the cell itself is deterministic across calls.
    atol = 2e-6 if fft else 1e-8
    assert np.allclose(cell_out, lmu_cell(inp), atol=atol)
    assert np.allclose(cell_out, layer_out, atol=atol)