示例#1
0
def test_store_restore():
    """Check that ``Network.restore`` rewinds a network to a stored state.

    The network is stored at t=0 (default slot) and at t=10 ms (slot
    ``'second'``) while being run to 20 ms. Restoring either snapshot must
    reset ``defaultclock.t`` and ``net.t`` to the snapshot time, and running
    up to 20 ms again must reproduce the recorded voltages, spike indices
    and spike times exactly.
    """
    source = NeuronGroup(10, '''dv/dt = rates : 1
                                rates : Hz''', threshold='v>1', reset='v=0')
    # v of neuron i grows at i*100 per second, so the neurons cross the
    # threshold at different rates (neuron 0 never does).
    source.rates = 'i*100*Hz'
    target = NeuronGroup(10, 'v:1')
    # One-to-one connections (i -> i) with weight i and delay i ms. Written
    # with ``on_pre`` and the ``connect`` method, matching the file-based
    # store/restore tests, rather than the old ``pre=``/``connect=``
    # constructor arguments.
    synapses = Synapses(source, target, model='w:1', on_pre='v+=w')
    synapses.connect(j='i')
    synapses.w = 'i*1.0'
    synapses.delay = 'i*ms'
    state_mon = StateMonitor(target, 'v', record=True)
    spike_mon = SpikeMonitor(source)
    net = Network(source, target, synapses, state_mon, spike_mon)
    net.store()  # default time slot
    net.run(10*ms)
    net.store('second')
    net.run(10*ms)
    # Reference recordings from the uninterrupted 20 ms run.
    v_values = state_mon.v[:, :]
    spike_indices, spike_times = spike_mon.it_

    net.restore()  # Go back to beginning
    assert defaultclock.t == 0*ms
    assert net.t == 0*ms
    net.run(20*ms)
    assert_equal(v_values, state_mon.v[:, :])
    assert_equal(spike_indices, spike_mon.i[:])
    assert_equal(spike_times, spike_mon.t_[:])

    # Go back to middle
    net.restore('second')
    assert defaultclock.t == 10*ms
    assert net.t == 10*ms
    net.run(10*ms)
    assert_equal(v_values, state_mon.v[:, :])
    assert_equal(spike_indices, spike_mon.i[:])
    assert_equal(spike_times, spike_mon.t_[:])
示例#2
0
def test_dt_restore():
    """Check that ``Network.restore`` brings back the clock's ``dt``.

    The network is stored while ``defaultclock.dt`` is 0.5 ms. After
    switching to 1 ms and running further, restoring must discard the
    monitor's recorded times and return to the original 0.5 ms step.
    """
    defaultclock.dt = 0.5*ms
    group = NeuronGroup(1, 'dv/dt = -v/(10*ms) : 1')
    monitor = StateMonitor(group, 'v', record=True)
    network = Network(group, monitor)
    network.store()

    fine_steps = [0, 0.5]*ms
    mixed_steps = [0, 0.5, 1, 2]*ms

    # Two steps at the original resolution...
    network.run(1*ms)
    assert_equal(monitor.t[:], fine_steps)
    # ...followed by two coarser steps once dt has been changed.
    defaultclock.dt = 1*ms
    network.run(2*ms)
    assert_equal(monitor.t[:], mixed_steps)

    # Rewind: the recording is emptied and dt is back at 0.5 ms.
    network.restore()
    assert_equal(monitor.t[:], [])
    network.run(1*ms)
    assert defaultclock.dt == 0.5*ms
    assert_equal(monitor.t[:], fine_steps)
示例#3
0
def test_store_restore_to_file_differing_nets():
    """Check that a stored state file is rejected by a different network.

    A state written to a file by one network must fail to restore, with
    ``KeyError``, into a network whose group has a different name, and into
    a network that contains only a subset of the stored objects.
    """
    import os  # used only to clean up the temporary state file

    # NOTE(review): ``tempfile.mktemp`` is race-prone. It is presumably kept
    # because ``Network.store`` expects the path to be absent or to hold a
    # valid state, which an empty ``mkstemp`` file would not be — confirm.
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')
    try:
        source = SpikeGeneratorGroup(5, [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]*ms,
                                     name='source_1')
        mon = SpikeMonitor(source, name='monitor')
        net = Network(source, mon)
        net.store(filename=filename)

        # Same structure, but the group is named 'source_2' instead of
        # 'source_1'.
        source_2 = SpikeGeneratorGroup(5, [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]*ms,
                                       name='source_2')
        mon = SpikeMonitor(source_2, name='monitor')
        net = Network(source_2, mon)
        assert_raises(KeyError, lambda: net.restore(filename=filename))

        net = Network(source)  # Without the monitor
        assert_raises(KeyError, lambda: net.restore(filename=filename))
    finally:
        # Remove the state file even if an assertion failed; it may not
        # exist if storing failed, so this is best-effort.
        try:
            os.remove(filename)
        except OSError:
            pass
示例#4
0
def test_store_restore_to_file_differing_nets():
    """Check that a stored state file is rejected by a different network.

    A state written to a file by one network must fail to restore, with
    ``KeyError``, into a network whose group has a different name, and into
    a network that contains only a subset of the stored objects.
    """
    import os  # used only to clean up the temporary state file

    # NOTE(review): ``tempfile.mktemp`` is race-prone. It is presumably kept
    # because ``Network.store`` expects the path to be absent or to hold a
    # valid state, which an empty ``mkstemp`` file would not be — confirm.
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')
    try:
        source = SpikeGeneratorGroup(5, [0, 1, 2, 3, 4], [0, 1, 2, 3, 4] * ms,
                                     name='source_1')
        mon = SpikeMonitor(source, name='monitor')
        net = Network(source, mon)
        net.store(filename=filename)

        # Same structure, but the group is named 'source_2' instead of
        # 'source_1'.
        source_2 = SpikeGeneratorGroup(5, [0, 1, 2, 3, 4],
                                       [0, 1, 2, 3, 4] * ms,
                                       name='source_2')
        mon = SpikeMonitor(source_2, name='monitor')
        net = Network(source_2, mon)
        assert_raises(KeyError, lambda: net.restore(filename=filename))

        net = Network(source)  # Without the monitor
        assert_raises(KeyError, lambda: net.restore(filename=filename))
    finally:
        # Remove the state file even if an assertion failed; it may not
        # exist if storing failed, so this is best-effort.
        try:
            os.remove(filename)
        except OSError:
            pass
示例#5
0
def test_store_restore_to_file():
    """Check that ``Network.store``/``restore`` round-trip through a file.

    Two snapshots, the default slot at t=0 and ``'second'`` at t=10 ms, are
    written to the same file. Restoring either one must reset the clocks to
    the snapshot time, and running up to 20 ms again must reproduce the
    recorded voltages and spikes exactly. The state file is removed
    afterwards, even when an assertion fails.
    """
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')
    try:
        source = NeuronGroup(10,
                             '''dv/dt = rates : 1
                                rates : Hz''',
                             threshold='v>1',
                             reset='v=0')
        source.rates = 'i*100*Hz'
        target = NeuronGroup(10, 'v:1')
        synapses = Synapses(source, target, model='w:1', on_pre='v+=w')
        synapses.connect(j='i')  # one-to-one
        synapses.w = 'i*1.0'
        synapses.delay = 'i*ms'
        state_mon = StateMonitor(target, 'v', record=True)
        spike_mon = SpikeMonitor(source)
        net = Network(source, target, synapses, state_mon, spike_mon)
        net.store(filename=filename)  # default time slot
        net.run(10 * ms)
        net.store('second', filename=filename)
        net.run(10 * ms)
        # Reference recordings from the uninterrupted 20 ms run.
        v_values = state_mon.v[:, :]
        spike_indices, spike_times = spike_mon.it_

        net.restore(filename=filename)  # Go back to beginning
        assert defaultclock.t == 0 * ms
        assert net.t == 0 * ms
        net.run(20 * ms)
        assert_equal(v_values, state_mon.v[:, :])
        assert_equal(spike_indices, spike_mon.i[:])
        assert_equal(spike_times, spike_mon.t_[:])

        # Go back to middle
        net.restore('second', filename=filename)
        assert defaultclock.t == 10 * ms
        assert net.t == 10 * ms
        net.run(10 * ms)
        assert_equal(v_values, state_mon.v[:, :])
        assert_equal(spike_indices, spike_mon.i[:])
        assert_equal(spike_times, spike_mon.t_[:])
    finally:
        # Best-effort cleanup: the file may not exist if storing failed.
        try:
            os.remove(filename)
        except OSError:
            pass
示例#6
0
def test_store_restore_to_file():
    """Check that ``Network.store``/``restore`` round-trip through a file.

    Two snapshots, the default slot at t=0 and ``'second'`` at t=10 ms, are
    written to the same file. Restoring either one must reset the clocks to
    the snapshot time, and running up to 20 ms again must reproduce the
    recorded voltages and spikes exactly. The state file is removed
    afterwards, even when an assertion fails.
    """
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')
    try:
        source = NeuronGroup(10, '''dv/dt = rates : 1
                                rates : Hz''', threshold='v>1', reset='v=0')
        source.rates = 'i*100*Hz'
        target = NeuronGroup(10, 'v:1')
        synapses = Synapses(source, target, model='w:1', on_pre='v+=w')
        synapses.connect(j='i')  # one-to-one
        synapses.w = 'i*1.0'
        synapses.delay = 'i*ms'
        state_mon = StateMonitor(target, 'v', record=True)
        spike_mon = SpikeMonitor(source)
        net = Network(source, target, synapses, state_mon, spike_mon)
        net.store(filename=filename)  # default time slot
        net.run(10*ms)
        net.store('second', filename=filename)
        net.run(10*ms)
        # Reference recordings from the uninterrupted 20 ms run.
        v_values = state_mon.v[:, :]
        spike_indices, spike_times = spike_mon.it_

        net.restore(filename=filename)  # Go back to beginning
        assert defaultclock.t == 0*ms
        assert net.t == 0*ms
        net.run(20*ms)
        assert_equal(v_values, state_mon.v[:, :])
        assert_equal(spike_indices, spike_mon.i[:])
        assert_equal(spike_times, spike_mon.t_[:])

        # Go back to middle
        net.restore('second', filename=filename)
        assert defaultclock.t == 10*ms
        assert net.t == 10*ms
        net.run(10*ms)
        assert_equal(v_values, state_mon.v[:, :])
        assert_equal(spike_indices, spike_mon.i[:])
        assert_equal(spike_times, spike_mon.t_[:])
    finally:
        # Best-effort cleanup: the file may not exist if storing failed.
        try:
            os.remove(filename)
        except OSError:
            pass