Exemplo n.º 1
0
def test_spike_neurongroup():
    """
    Test dictionary representation of spiking neuron

    Two groups are collected with ``collect_NeuronGroup``:

    1. a group with threshold, reset and refractoriness, for which the
       equations and the threshold/reset/refractory entries of the
       ``spike`` event are checked;
    2. a group with a threshold only, for which the ``reset`` and
       ``refractory`` entries must be absent.
    """
    eqn = ''' dv/dt = (v_th - v) / tau : volt
              v_th = 900 * mV :volt
              v_rest = -70 * mV :volt
              tau :second (constant)'''

    # NOTE: local names matter here, the local namespace of this function is
    # handed to the collector via get_local_namespace(0)
    tau = 10 * ms
    size = 10

    grp = NeuronGroup(size, eqn, threshold='v > v_th',
                      reset='v = v_rest',
                      refractory=2 * ms)

    neuron_dict = collect_NeuronGroup(grp, get_local_namespace(0))

    assert neuron_dict['N'] == size
    assert neuron_dict['user_method'] is None

    # Reference equations, parsed independently of the group
    eqns = Equations(eqn)
    assert neuron_dict['equations']['v']['type'] == DIFFERENTIAL_EQUATION
    assert neuron_dict['equations']['v']['unit'] == volt
    assert neuron_dict['equations']['v']['var_type'] == FLOAT
    assert neuron_dict['equations']['v']['expr'] == eqns['v'].expr.code

    assert neuron_dict['equations']['v_th']['type'] == SUBEXPRESSION
    assert neuron_dict['equations']['v_th']['unit'] == volt
    assert neuron_dict['equations']['v_th']['var_type'] == FLOAT
    assert neuron_dict['equations']['v_th']['expr'] == eqns['v_th'].expr.code

    assert neuron_dict['equations']['v_rest']['type'] == SUBEXPRESSION
    assert neuron_dict['equations']['v_rest']['unit'] == volt
    assert neuron_dict['equations']['v_rest']['var_type'] == FLOAT

    assert neuron_dict['equations']['tau']['type'] == PARAMETER
    assert neuron_dict['equations']['tau']['unit'] == second
    assert neuron_dict['equations']['tau']['var_type'] == FLOAT
    assert neuron_dict['equations']['tau']['flags'][0] == 'constant'

    # The threshold entry has to mirror the group's thresholder object
    thresholder = grp.thresholder['spike']
    neuron_events = neuron_dict['events']['spike']
    assert neuron_events['threshold']['code'] == 'v > v_th'
    assert neuron_events['threshold']['when'] == thresholder.when
    assert neuron_events['threshold']['order'] == thresholder.order
    assert neuron_events['threshold']['dt'] == grp.clock.dt

    # The reset entry has to mirror the group's resetter object
    resetter = grp.resetter['spike']
    assert neuron_events['reset']['code'] == 'v = v_rest'
    assert neuron_events['reset']['when'] == resetter.when
    assert neuron_events['reset']['order'] == resetter.order
    assert neuron_events['reset']['dt'] == resetter.clock.dt

    assert neuron_dict['events']['spike']['refractory'] == Quantity(2 * ms)

    # example 2 with threshold but no reset

    start_scope()
    grp2 = NeuronGroup(size, '''dv/dt = (100 * mV - v) / tau_n : volt''',
                             threshold='v > 800 * mV',
                             method='euler')
    # Defined after the group is created but before the namespace is
    # collected below
    tau_n = 10 * ms

    neuron_dict2 = collect_NeuronGroup(grp2, get_local_namespace(0))
    thresholder = grp2.thresholder['spike']
    neuron_events = neuron_dict2['events']['spike']
    assert neuron_events['threshold']['code'] == 'v > 800 * mV'
    assert neuron_events['threshold']['when'] == thresholder.when
    assert neuron_events['threshold']['order'] == thresholder.order
    assert neuron_events['threshold']['dt'] == grp2.clock.dt

    # Neither a reset nor a refractory entry may exist for this group.
    # Each lookup gets its own ``pytest.raises`` block: inside a shared
    # block the first KeyError would end the block and the second lookup
    # would never be executed (and therefore never be checked).
    with pytest.raises(KeyError):
        neuron_dict2['events']['spike']['reset']
    with pytest.raises(KeyError):
        neuron_dict2['events']['spike']['refractory']
'''
Test the simulation class - runtime
'''
import pytest
import numpy as np
from numpy.testing.utils import assert_equal, assert_raises
from brian2 import (Equations, NeuronGroup, StateMonitor, Network, ms,
                    start_scope, mV)
from brian2.devices.device import Dummy
from brian2modelfitting.simulator import (initialize_neurons,
                                          initialize_parameter, Simulator,
                                          RuntimeSimulator)
from brian2.devices.device import reinit_devices

# Equations without any differential equation: `I` and `v` are given as
# expressions, `g` and `E` are declared as constant parameters.
model = Equations('''
    I = g*(v-E) : amp
    v = 10*mvolt :volt
    g : siemens (constant)
    E : volt (constant)
    ''')

# Equations with one differential equation for `v` (containing an exponential
# term), a fixed current `I` and the constant parameters `gL` and `C`.
# `EL`, `VT` and `DeltaT` are passed to `Equations` as keyword arguments --
# presumably they are substituted into the equations; verify against the
# Brian2 `Equations` documentation.
model2 = Equations(
    '''
                  dv/dt = (gL*(EL-v)+gL*DeltaT*exp((v-VT)/DeltaT) + I)/C : volt
                  I = 20* nA :amp
                  gL: siemens (constant)
                  C: farad (constant)
                  ''',
    EL=-70 * mV,
    VT=-50 * mV,
    DeltaT=2 * mV,
)
Exemplo n.º 3
0
def test_construction_errors():
    """
    Verify that invalid input to ``Equations`` (and to ``check_flags``) is
    rejected with the expected exception type.
    """
    # An equation that cannot be parsed
    with assert_raises(EquationError):
        Equations('dv/dt = -v / tau volt')

    # The constructor only takes a single string or a list of SingleEquation
    # objects, everything else has to be refused
    for invalid_input in (None, 42, ['dv/dt = -v / tau : volt']):
        with assert_raises(TypeError):
            Equations(invalid_input)

    # The same variable name defined twice, first via a string ...
    with assert_raises(EquationError):
        Equations('''dv/dt = -v / tau : volt
                                                    v = 2 * t/second * volt : volt''')

    # ... then via a list of SingleEquation objects
    duplicated = [
        SingleEquation(DIFFERENTIAL_EQUATION, 'v', volt.dim,
                       expr=Expression('-v / tau')),
        SingleEquation(SUBEXPRESSION, 'v', volt.dim,
                       expr=Expression('2 * t/second * volt')),
    ]
    with assert_raises(EquationError):
        Equations(duplicated)

    # Variable names that are not allowed, paired with the expected error
    illegal_names = [
        (ValueError, 'ddt/dt = -dt / tau : volt'),
        (ValueError, 'dt/dt = -t / tau : volt'),
        (ValueError, 'dxi/dt = -xi / tau : volt'),
        (ValueError, 'for : volt'),
        ((EquationError, ValueError), 'd1a/dt = -1a / tau : volt'),
        (ValueError, 'd_x/dt = -_x / tau : volt'),
    ]
    for expected_error, code in illegal_names:
        with assert_raises(expected_error):
            Equations(code)

    # xi used inside a subexpression
    with assert_raises(EquationError):
        Equations('''dv/dt = -(v + I) / (5 * ms) : volt
                                       I = second**-1*xi**-2*volt : volt''')

    # The same xi used in more than one equation
    with assert_raises(EquationError):
        Equations('''dv/dt = -v / tau + xi/tau**.5 : volt
                                       dx/dt = -x / tau + 2*xi/tau : volt
                                       tau : second''')

    # Flags that are not in the allowed set
    flagged = Equations('dv/dt = -v / (5 * ms) : volt (flag)')
    flagged.check_flags({DIFFERENTIAL_EQUATION: ['flag']})  # allow this flag
    rejected_flag_specs = (
        {DIFFERENTIAL_EQUATION: []},
        {},
        {SUBEXPRESSION: ['flag']},
        {DIFFERENTIAL_EQUATION: ['otherflag']},
    )
    for allowed_flags in rejected_flag_specs:
        with assert_raises(ValueError):
            flagged.check_flags(allowed_flags)

    both_flags = {DIFFERENTIAL_EQUATION: ['flag1', 'flag2']}
    exclusive_pairs = [('flag1', 'flag2')]

    combined = Equations('dv/dt = -v / (5 * ms) : volt (flag1, flag2)')
    combined.check_flags(both_flags)  # allow both flags
    # Both flags on one equation must fail once they are declared incompatible
    with assert_raises(ValueError):
        combined.check_flags(both_flags, incompatible_flags=exclusive_pairs)

    # On separate equations the very same flags are fine
    separate = Equations('''dv/dt = -v / (5 * ms) : volt (flag1)
                       dw/dt = -w / (5 * ms) : volt (flag2)''')
    separate.check_flags(both_flags, incompatible_flags=exclusive_pairs)

    # Subexpressions that refer to each other in a circle
    with assert_raises(ValueError):
        Equations('''dv/dt = -(v + w) / (10 * ms) : 1
                                                   w = 2 * x : 1
                                                   x = 3 * w : 1''')

    # Differential equations must not be of boolean or integer type
    non_float_equations = (
        'dv/dt = -v / (10*ms) : boolean',
        'dv/dt = -v / (10*ms) : integer',
    )
    for code in non_float_equations:
        with assert_raises(TypeError):
            Equations(code)
from brian2modelfitting import (NevergradOptimizer, TraceFitter, MSEMetric,
                                OnlineTraceFitter, Simulator, Metric,
                                Optimizer, GammaFactor)
from brian2.devices.device import reinit_devices, reset_device
from brian2modelfitting.fitter import get_param_dic


# Module-level fixtures.

# `E` is referenced by name in the equation strings below (`g*(v-E)`)
E = 40*mV
# Input array of shape (10, 5) in volt, initialised with zeros; rows 5 and
# onwards of column i are then set to i*10 mV
input_traces = zeros((10, 5))*volt
for i in range(5):
    input_traces[5:, i] = i*10*mV

# Output is the input scaled by 10 nS
output_traces = 10*nS*input_traces

# The same equations twice: once wrapped in an `Equations` object ...
model = Equations('''
    I = g*(v-E) : amp
    g : siemens (constant)
    ''')

# ... and once as a plain string
strmodel = '''
    I = g*(v-E) : amp
    g : siemens (constant)
    '''

# `v` is given as the sum of the constant parameter `c` and `x`
# (`x` is not defined in these equations)
constant_model = Equations('''
    v = c + x: volt
    c : volt (constant)''')

# Optimizer and metric instances created with their default arguments
n_opt = NevergradOptimizer()
metric = MSEMetric()