Beispiel #1
0
def test_op_constant(dtype, diff, sess):
    """Check ``SignalDict.op_constant`` on a group of two ``SimNeurons`` ops.

    Both ops wrap LIF neurons with 10-element input signals. Their ``tau_rc``
    values are 1 and either 2 (``diff=True``) or 1 (``diff=False``).
    ``op_constant`` is called with its default ``ndims`` and with ``ndims=1``
    and ``ndims=3``. The test verifies:

    - the dtype of every result;
    - that the three layouts hold the same values once evaluated.

    Parameters
    ----------
    dtype : tf.DType
        Requested dtype for the constants (presumably supplied by a pytest
        fixture/parametrization -- confirm).
    diff : bool
        Whether the two ops use different ``tau_rc`` values.
    sess : tf.Session
        Session used to initialize and evaluate the constants.
    """
    ops = (SimNeurons(LIF(tau_rc=1), Signal(np.zeros(10)), None),
           SimNeurons(LIF(tau_rc=2 if diff else 1), Signal(np.zeros(10)),
                      None))

    # The same op group and per-op sizes feed every op_constant call below.
    neurons = [op.neurons for op in ops]
    sizes = [op.J.shape[0] for op in ops]

    signals = SignalDict(tf.float32, 1)
    const = signals.op_constant(neurons, sizes, "tau_rc", dtype)
    const1 = signals.op_constant(neurons, sizes, "tau_rc", dtype, ndims=1)
    const3 = signals.op_constant(neurons, sizes, "tau_rc", dtype, ndims=3)

    # Every layout must honour the requested dtype, not only the default one.
    for c in (const, const1, const3):
        assert c.dtype.base_dtype == dtype

    sess.run(tf.variables_initializer(tf.get_collection("constants")),
             feed_dict=signals.constant_phs)
    x, x1, x3 = sess.run([const, const1, const3])

    if diff:
        # One row per neuron: 10 rows of tau_rc=1 followed by 10 of tau_rc=2.
        assert np.array_equal(x, [[1]] * 10 + [[2]] * 10)
        assert np.array_equal(x[:, 0], x1)
        assert np.array_equal(x, x3[..., 0])
    else:
        # np.array_equal against a scalar only passes for a 0-d result, so
        # this also checks that identical values collapse to a single scalar.
        assert np.array_equal(x, 1.0)
        assert np.array_equal(x, x1)
        assert np.array_equal(x, x3)
Beispiel #2
0
def test_op_constant(dtype, diff):
    """Check ``SignalDict.op_constant`` on a group of two ``SimNeurons`` ops.

    Both ops wrap LIF neurons with 10-element input signals. Their ``tau_rc``
    values are 1 and either 2 (``diff=True``) or 1 (``diff=False``).
    ``op_constant`` is called with its default ``shape``, with
    ``shape=(-1,)`` and with ``shape=(1, -1, 1)``. The test verifies:

    - the dtype of every result;
    - that the three layouts hold the same values.

    Parameters
    ----------
    dtype : tf.DType
        Requested dtype for the constants (presumably supplied by a pytest
        fixture/parametrization -- confirm).
    diff : bool
        Whether the two ops use different ``tau_rc`` values.
    """
    ops = (
        SimNeurons(LIF(tau_rc=1), Signal(np.zeros(10)), None),
        SimNeurons(LIF(tau_rc=2 if diff else 1), Signal(np.zeros(10)), None),
    )

    # The same op group and per-op sizes feed every op_constant call below.
    neurons = [op.neurons for op in ops]
    sizes = [op.J.shape[0] for op in ops]

    signals = SignalDict(tf.float32, 1)
    const = signals.op_constant(neurons, sizes, "tau_rc", dtype)
    const1 = signals.op_constant(neurons, sizes, "tau_rc", dtype, shape=(-1,))
    const3 = signals.op_constant(
        neurons, sizes, "tau_rc", dtype, shape=(1, -1, 1)
    )

    # Every layout must honour the requested dtype, not only the default one.
    for c in (const, const1, const3):
        assert c.dtype.base_dtype == dtype

    # No session is used: the tensors are compared directly with numpy
    # (presumably eager execution -- confirm).
    if diff:
        # Default layout is a single row: 10 x tau_rc=1 then 10 x tau_rc=2.
        assert np.array_equal(const, [[1.0] * 10 + [2.0] * 10])
        assert np.array_equal(const[0], const1)
        assert np.array_equal(const, const3[..., 0])
    else:
        # np.array_equal against a scalar only passes for a 0-d result, so
        # this also checks that identical values collapse to a single scalar.
        assert np.array_equal(const, 1.0)
        assert np.array_equal(const, const1)
        assert np.array_equal(const, const3)