コード例 #1
0
ファイル: test_network.py プロジェクト: ttxtea/brian2
def test_store_restore():
    # Network under test: 10 integrating sources wired one-to-one (i==j) onto
    # 10 targets, each synapse with its own weight and delay.
    source = NeuronGroup(10, '''dv/dt = rates : 1
                                rates : Hz''', threshold='v>1', reset='v=0')
    source.rates = 'i*100*Hz'
    target = NeuronGroup(10, 'v:1')
    synapses = Synapses(source, target, model='w:1', pre='v+=w', connect='i==j')
    synapses.w = 'i*1.0'
    synapses.delay = 'i*ms'
    state_mon = StateMonitor(target, 'v', record=True)
    spike_mon = SpikeMonitor(source)
    net = Network(source, target, synapses, state_mon, spike_mon)

    # Take snapshots at t=0 (default slot) and t=10ms (slot 'second'), and
    # keep the recordings of the complete 20ms run as the reference.
    net.store()
    net.run(10*ms)
    net.store('second')
    net.run(10*ms)
    reference_v = state_mon.v[:, :]
    reference_i, reference_t = spike_mon.it_

    def matches_reference():
        """Assert that both monitors reproduce the reference 20ms run."""
        assert_equal(reference_v, state_mon.v[:, :])
        assert_equal(reference_i, spike_mon.i[:])
        assert_equal(reference_t, spike_mon.t_[:])

    # Rewind to the very beginning and replay the whole 20ms.
    net.restore()
    assert defaultclock.t == 0*ms
    assert net.t == 0*ms
    net.run(20*ms)
    matches_reference()

    # Rewind to the midpoint and replay only the remaining 10ms.
    net.restore('second')
    assert defaultclock.t == 10*ms
    assert net.t == 10*ms
    net.run(10*ms)
    matches_reference()
コード例 #2
0
ファイル: cache.py プロジェクト: kburel/SNN_HFO_iEEG
def create_cache(configuration):
    """Assemble the spiking network described by ``configuration`` into a Cache.

    Builds an input layer, a hidden layer and a one-neuron output layer,
    wires them input -> hidden -> output, and attaches spike monitors to the
    hidden and output layers. An artifact filter and/or an advanced artifact
    filter are added when the configuration asks for them. The finished
    network is snapshotted with ``network.store()`` before it is returned.

    Args:
        configuration: run configuration passed on to the neuron-count and
            artifact-filter helpers. Its schema is not visible here.

    Returns:
        Cache: the model paths, neuron counts, spike monitors, network and
        input layer. It also holds the advanced artifact filter's input layer
        when that filter was added, and ``None`` in that field otherwise.
    """
    model_paths = load_model_paths()
    neuron_counts = _read_neuron_counts(configuration)

    hidden_layer = create_non_input_layer(model_paths, neuron_counts.hidden,
                                          'hidden')
    # NOTE(review): 15.3e-3 is an undocumented tuning constant for Itau.
    hidden_layer.Itau = get_current(15.3e-3) * amp

    # The output layer has exactly one neuron. Its num_inputs is the number
    # of layers connected to the output.
    layers_connected_to_output = _get_layers_connected_to_output_count(
        configuration)
    output_layer = create_non_input_layer(
        model_paths, 1, 'output', num_inputs=layers_connected_to_output)
    hidden_to_output_synapses = create_hidden_to_output_synapses(
        'main', hidden_layer, output_layer, model_paths, neuron_counts)

    input_layer = create_input_layer('main', neuron_counts.input)
    input_to_hidden_synapses = create_input_to_hidden_synapses(
        name='main',
        input_layer=input_layer,
        hidden_layer=hidden_layer,
        model_paths=model_paths,
        neuron_counts=neuron_counts)

    spike_monitors = SpikeMonitors(hidden=SpikeMonitor(hidden_layer),
                                   output=SpikeMonitor(output_layer))
    network = Network(input_layer, input_to_hidden_synapses, hidden_layer,
                      spike_monitors.hidden, spike_monitors.output,
                      output_layer, hidden_to_output_synapses)

    # Optional filters, each added to the already-built network in place.
    if should_add_artifact_filter(configuration):
        add_artifact_filter_to_network(model_paths, input_layer, output_layer,
                                       network)

    if should_add_advanced_artifact_filter(configuration):
        advanced_artifact_filter_input_layer = (
            add_advanced_artifact_filter_to_network(
                network, output_layer, model_paths, neuron_counts))
    else:
        advanced_artifact_filter_input_layer = None

    # Snapshot the fully assembled network (brian2 store/restore mechanism).
    network.store()

    return Cache(
        model_paths=model_paths,
        neuron_counts=neuron_counts,
        spike_monitors=spike_monitors,
        network=network,
        input_layer=input_layer,
        advanced_artifact_filter_input_layer=(
            advanced_artifact_filter_input_layer),
    )
コード例 #3
0
ファイル: test_network.py プロジェクト: ttxtea/brian2
def test_dt_restore():
    # Begin with a 0.5ms timestep and snapshot the untouched network.
    defaultclock.dt = 0.5*ms
    group = NeuronGroup(1, 'dv/dt = -v/(10*ms) : 1')
    monitor = StateMonitor(group, 'v', record=True)
    net = Network(group, monitor)
    net.store()

    # Record two 0.5ms steps, then switch to a 1ms timestep for two more.
    net.run(1*ms)
    assert_equal(monitor.t[:], [0, 0.5]*ms)
    defaultclock.dt = 1*ms
    net.run(2*ms)
    assert_equal(monitor.t[:], [0, 0.5, 1, 2]*ms)

    # Restoring must empty the recording and bring back the stored 0.5ms dt.
    net.restore()
    assert_equal(monitor.t[:], [])
    net.run(1*ms)
    assert defaultclock.dt == 0.5*ms
    assert_equal(monitor.t[:], [0, 0.5]*ms)
コード例 #4
0
def test_store_restore_to_file_differing_nets():
    # Check that the store/restore mechanism is not used with differing
    # networks: restoring a file written by one network into a network whose
    # objects are named differently, or which lacks one of the stored
    # objects, must raise KeyError.
    import os

    # NOTE(review): tempfile.mktemp is deprecated and race-prone. It is kept
    # because mkstemp would pre-create an empty file, which the store
    # mechanism may try to load. Consider a path inside tempfile.mkdtemp().
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')
    try:
        source = SpikeGeneratorGroup(5, [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]*ms,
                                     name='source_1')
        mon = SpikeMonitor(source, name='monitor')
        net = Network(source, mon)
        net.store(filename=filename)

        # Same structure, but the group is named 'source_2', not 'source_1'.
        source_2 = SpikeGeneratorGroup(5, [0, 1, 2, 3, 4],
                                       [0, 1, 2, 3, 4]*ms,
                                       name='source_2')
        mon = SpikeMonitor(source_2, name='monitor')
        net = Network(source_2, mon)
        assert_raises(KeyError, lambda: net.restore(filename=filename))

        net = Network(source)  # Without the monitor
        assert_raises(KeyError, lambda: net.restore(filename=filename))
    finally:
        # Best-effort removal of the state file written by net.store(). This
        # runs even when an assertion above fails, so no file is leaked.
        try:
            os.remove(filename)
        except OSError:
            pass
コード例 #5
0
def test_store_restore_to_file_differing_nets():
    # Check that the store/restore mechanism is not used with differing
    # networks: restoring a file written by one network into a network whose
    # objects are named differently, or which lacks one of the stored
    # objects, must raise KeyError.
    import os

    # NOTE(review): tempfile.mktemp is deprecated and race-prone. It is kept
    # because mkstemp would pre-create an empty file, which the store
    # mechanism may try to load. Consider a path inside tempfile.mkdtemp().
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')
    try:
        source = SpikeGeneratorGroup(5, [0, 1, 2, 3, 4],
                                     [0, 1, 2, 3, 4] * ms,
                                     name='source_1')
        mon = SpikeMonitor(source, name='monitor')
        net = Network(source, mon)
        net.store(filename=filename)

        # Same structure, but the group is named 'source_2', not 'source_1'.
        source_2 = SpikeGeneratorGroup(5, [0, 1, 2, 3, 4],
                                       [0, 1, 2, 3, 4] * ms,
                                       name='source_2')
        mon = SpikeMonitor(source_2, name='monitor')
        net = Network(source_2, mon)
        assert_raises(KeyError, lambda: net.restore(filename=filename))

        net = Network(source)  # Without the monitor
        assert_raises(KeyError, lambda: net.restore(filename=filename))
    finally:
        # Best-effort removal of the state file written by net.store(). This
        # runs even when an assertion above fails, so no file is leaked.
        try:
            os.remove(filename)
        except OSError:
            pass
コード例 #6
0
def test_store_restore_to_file():
    # Store two snapshots (t=0 and t=10ms) of a small network in a file.
    # Restoring either one and re-running must reproduce the recordings of
    # the original 20ms run.
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')
    try:
        source = NeuronGroup(10,
                             '''dv/dt = rates : 1
                                rates : Hz''',
                             threshold='v>1',
                             reset='v=0')
        source.rates = 'i*100*Hz'
        target = NeuronGroup(10, 'v:1')
        synapses = Synapses(source, target, model='w:1', on_pre='v+=w')
        synapses.connect(j='i')
        synapses.w = 'i*1.0'
        synapses.delay = 'i*ms'
        state_mon = StateMonitor(target, 'v', record=True)
        spike_mon = SpikeMonitor(source)
        net = Network(source, target, synapses, state_mon, spike_mon)
        net.store(filename=filename)  # default time slot
        net.run(10 * ms)
        net.store('second', filename=filename)
        net.run(10 * ms)
        v_values = state_mon.v[:, :]
        spike_indices, spike_times = spike_mon.it_

        net.restore(filename=filename)  # Go back to beginning
        assert defaultclock.t == 0 * ms
        assert net.t == 0 * ms
        net.run(20 * ms)
        assert_equal(v_values, state_mon.v[:, :])
        assert_equal(spike_indices, spike_mon.i[:])
        assert_equal(spike_times, spike_mon.t_[:])

        # Go back to middle
        net.restore('second', filename=filename)
        assert defaultclock.t == 10 * ms
        assert net.t == 10 * ms
        net.run(10 * ms)
        assert_equal(v_values, state_mon.v[:, :])
        assert_equal(spike_indices, spike_mon.i[:])
        assert_equal(spike_times, spike_mon.t_[:])
    finally:
        # Best-effort cleanup in `finally` so the state file is removed even
        # when one of the assertions above fails.
        try:
            os.remove(filename)
        except OSError:
            pass
コード例 #7
0
def test_store_restore_to_file():
    # Store two snapshots (t=0 and t=10ms) of a small network in a file.
    # Restoring either one and re-running must reproduce the recordings of
    # the original 20ms run.
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')
    try:
        source = NeuronGroup(10, '''dv/dt = rates : 1
                                rates : Hz''', threshold='v>1', reset='v=0')
        source.rates = 'i*100*Hz'
        target = NeuronGroup(10, 'v:1')
        synapses = Synapses(source, target, model='w:1', on_pre='v+=w')
        synapses.connect(j='i')
        synapses.w = 'i*1.0'
        synapses.delay = 'i*ms'
        state_mon = StateMonitor(target, 'v', record=True)
        spike_mon = SpikeMonitor(source)
        net = Network(source, target, synapses, state_mon, spike_mon)
        net.store(filename=filename)  # default time slot
        net.run(10*ms)
        net.store('second', filename=filename)
        net.run(10*ms)
        v_values = state_mon.v[:, :]
        spike_indices, spike_times = spike_mon.it_

        net.restore(filename=filename) # Go back to beginning
        assert defaultclock.t == 0*ms
        assert net.t == 0*ms
        net.run(20*ms)
        assert_equal(v_values, state_mon.v[:, :])
        assert_equal(spike_indices, spike_mon.i[:])
        assert_equal(spike_times, spike_mon.t_[:])

        # Go back to middle
        net.restore('second', filename=filename)
        assert defaultclock.t == 10*ms
        assert net.t == 10*ms
        net.run(10*ms)
        assert_equal(v_values, state_mon.v[:, :])
        assert_equal(spike_indices, spike_mon.i[:])
        assert_equal(spike_times, spike_mon.t_[:])
    finally:
        # Best-effort cleanup in `finally` so the state file is removed even
        # when one of the assertions above fails.
        try:
            os.remove(filename)
        except OSError:
            pass
コード例 #8
0
    defaultclock.dt = DT

    net = Network([
        neurons,
        ee_synapses,
        ei_synapses,
        ie_synapses,
        ii_synapses,
        static_synapses_exc,
        static_synapses_inh,
        stimulus,
        spike_monitor_exc,
        spike_monitor_inh,
    ])
    net.store()

    collected_pairs = collect_stimulus_pairs()

    # add only jittered pairs
    collected_pairs[0] = [
        [generate_poisson(DURATION / ms, STIMULUS_POISSON_RATE / Hz / 1e3)] * 2
        for _ in range(N_PAIRS)
    ]

    def map_sim(spike_times):
        """Run ``sim`` on the enclosing ``net`` for one set of spike times.

        Wrapper to sim for multiprocessing: it closes over ``net``, the
        Network built and stored earlier in the enclosing function, so a
        mapper only has to supply ``spike_times``.

        NOTE(review): a function defined inside another function cannot be
        pickled by the standard ``pickle`` module, so passing it to
        ``multiprocessing.Pool.map`` would fail. Confirm how the caller
        dispatches it.
        """
        return sim(net, spike_times)

    result = defaultdict(list)