Example #1
0
def test_unit_checking():
    '''
    Test that ``Equations.check_units`` raises a ``DimensionMismatchError``
    for differential equations and subexpressions with inconsistent units.
    '''
    # dummy Variable class: a stand-in that only carries a ``unit`` attribute
    class S(object):
        def __init__(self, unit):
            self.unit = unit

    # inconsistent unit for a differential equation
    eqs = Equations('dv/dt = -v : volt')
    group = SimpleGroup({'v': S(volt)})
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(group))

    # tau is a voltage here, so -v / tau is dimensionless instead of volt/second
    eqs = Equations('dv/dt = -v / tau: volt')
    group = SimpleGroup(namespace={'tau': 5*mV}, variables={'v': S(volt)})
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(group))
    # I (a time, taken from the namespace) is added to v (a voltage)
    group = SimpleGroup(namespace={'I': 3*second}, variables={'v': S(volt)})
    eqs = Equations('dv/dt = -(v + I) / (5 * ms): volt')
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(group))

    # same as above, but I is a parameter of the equations
    eqs = Equations('''dv/dt = -(v + I) / (5 * ms): volt
                       I : second''')
    group = SimpleGroup(variables={'v': S(volt),
                                   'I': S(second)}, namespace={})
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(group))

    # inconsistent unit for a subexpression: 2 * v is a voltage, but I is
    # declared as amp. The variable for I carries its declared unit so that
    # the expression itself is the only inconsistency under test.
    eqs = Equations('''dv/dt = -v / (5 * ms) : volt
                       I = 2 * v : amp''')
    group = SimpleGroup(variables={'v': S(volt),
                                   'I': S(amp)}, namespace={})
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(group))
Example #2
0
def test_unit_checking():
    '''
    Test that ``Equations.check_units`` raises a ``DimensionMismatchError``
    for equations with inconsistent units.

    All checks share one namespace that is extended step by step (``tau``,
    then ``I``), so later checks also see the entries added by earlier ones.
    '''
    # dummy Variable class: a stand-in that only carries a ``unit`` attribute
    class S(object):
        def __init__(self, unit):
            self.unit = unit

    # Create a namespace around an (initially empty) user-defined namespace
    # that we can update later on
    namespace = create_namespace({})
    # inconsistent unit for a differential equation
    eqs = Equations('dv/dt = -v : volt')
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(namespace, {'v': S(volt)}))

    # tau is a voltage, so -v / tau is dimensionless instead of volt/second
    eqs = Equations('dv/dt = -v / tau: volt')
    namespace['tau'] = 5*mV
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(namespace, {'v': S(volt)}))
    # I (a time, from the namespace) is added to v (a voltage)
    namespace['I'] = 3*second
    eqs = Equations('dv/dt = -(v + I) / (5 * ms): volt')
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(namespace, {'v': S(volt)}))

    # same as above, but I is a parameter of the equations
    eqs = Equations('''dv/dt = -(v + I) / (5 * ms): volt
                       I : second''')
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(namespace, {'v': S(volt),
                                                      'I': S(second)}))

    # inconsistent unit for a static equation (2 * v is a voltage, not amp)
    eqs = Equations('''dv/dt = -v / (5 * ms) : volt
                       I = 2 * v : amp''')
    assert_raises(DimensionMismatchError,
                  lambda: eqs.check_units(namespace, {'v': S(volt),
                                                      'I': S(amp)}))
Example #3
0
def test_repeated_construction():
    '''
    Constructing equations from the same string twice, the second time with
    a variable replacement, yields two independent objects.
    '''
    plain = Equations('dx/dt = x : 1')
    renamed = Equations('dx/dt = x : 1', x='y')
    for eqs, varname in ((plain, 'x'), (renamed, 'y')):
        assert len(eqs) == 1
        assert varname in eqs
        assert eqs[varname].expr == Expression(varname)
Example #4
0
def test_substitute():
    '''``Equations.substitute`` must return a copy and leave the original intact.'''
    original = Equations('dx/dt = x : 1')
    substituted = original.substitute(x='y')

    expectations = [(original, 'x'),     # unaffected by the substitution
                    (substituted, 'y')]  # x replaced by y on both sides
    for eqs, varname in expectations:
        assert len(eqs) == 1 and varname in eqs
        assert eqs[varname].expr == Expression(varname)
Example #5
0
def test_substitute():
    '''
    Check that ``Equations.substitute`` hands back an independent copy.
    '''
    source_eqs = Equations('dx/dt = x : 1')
    copied_eqs = source_eqs.substitute(x='y')

    # The source object still describes x ...
    assert len(source_eqs) == 1
    assert 'x' in source_eqs
    assert source_eqs['x'].expr == Expression('x')

    # ... while the copy describes y
    assert len(copied_eqs) == 1
    assert 'y' in copied_eqs
    assert copied_eqs['y'].expr == Expression('y')
Example #6
0
def test_correct_replacements():
    ''' Test replacing variables via keyword arguments '''
    # Renaming v -> V has to affect both sides of the equation
    renamed = Equations('dv/dt = -v / tau : 1', v='V')
    assert 'V' in renamed and 'v' not in renamed  # left-hand side
    rhs_identifiers = renamed['V'].identifiers
    assert 'V' in rhs_identifiers and 'v' not in rhs_identifiers  # right-hand side

    # Replacing an identifier by a value removes it from the expression
    valued = Equations('dv/dt = -v / tau : 1', tau=10 * ms)
    assert 'tau' not in valued['v'].identifiers
Example #7
0
def test_identifier_checks():
    '''
    Test the individual identifier checks and the registry of identifier
    checks on ``Equations``.
    '''
    legal_identifiers = ['v', 'Vm', 'V', 'x', 'ge', 'g_i', 'a2', 'gaba_123']
    illegal_identifiers = ['_v', '1v', 'ü', 'ge!', 'v.x', 'for', 'else', 'if']

    for identifier in legal_identifiers:
        try:
            check_identifier_basic(identifier)
            check_identifier_reserved(identifier)
        # The checks signal problems with SyntaxError (see below); ValueError
        # is kept as well so either failure gets the descriptive message.
        except (SyntaxError, ValueError) as ex:
            raise AssertionError(
                f'check complained about identifier "{identifier}": {ex}')

    for identifier in illegal_identifiers:
        with pytest.raises(SyntaxError):
            check_identifier_basic(identifier)

    for identifier in ('t', 'dt', 'xi', 'i', 'N'):
        with pytest.raises(SyntaxError):
            check_identifier_reserved(identifier)

    for identifier in ('not_refractory', 'refractory', 'refractory_until'):
        with pytest.raises(SyntaxError):
            check_identifier_refractory(identifier)

    for identifier in ('exp', 'sin', 'sqrt'):
        with pytest.raises(SyntaxError):
            check_identifier_functions(identifier)

    for identifier in ('e', 'pi', 'inf'):
        with pytest.raises(SyntaxError):
            check_identifier_constants(identifier)

    for identifier in ('volt', 'second', 'mV', 'nA'):
        with pytest.raises(SyntaxError):
            check_identifier_units(identifier)

    # Check identifier registry
    assert check_identifier_basic in Equations.identifier_checks
    assert check_identifier_reserved in Equations.identifier_checks
    assert check_identifier_refractory in Equations.identifier_checks
    assert check_identifier_functions in Equations.identifier_checks
    assert check_identifier_constants in Equations.identifier_checks
    assert check_identifier_units in Equations.identifier_checks

    # Set up a dummy identifier check that disallows the variable name
    # gaba_123 (that is otherwise valid)
    def disallow_gaba_123(identifier):
        if identifier == 'gaba_123':
            raise SyntaxError("I do not like this name")

    Equations.check_identifier('gaba_123')
    # Keep a copy (not a reference) so restoring undoes the registration
    old_checks = set(Equations.identifier_checks)
    Equations.register_identifier_check(disallow_gaba_123)
    try:
        with pytest.raises(SyntaxError):
            Equations.check_identifier('gaba_123')
    finally:
        # Restore the registry even if the check above fails, so the dummy
        # check cannot leak into other tests
        Equations.identifier_checks = old_checks

    # registering a non-function should not work
    with pytest.raises(ValueError):
        Equations.register_identifier_check('no function')
Example #8
0
 def brian2(self) -> BrianObject:
     """Create one-to-one ``Synapses`` that add four weights on each spike.

     A presynaptic spike adds ``w1`` .. ``w4`` to the postsynaptic variables
     ``self.post_variable_name[0]`` .. ``[3]`` respectively; every weight is
     initialised to 1.
     """
     model = Equations('''
     w1 : 1
     w2 : 1
     w3 : 1
     w4 : 1
     ''')
     on_pre = f'''
     {self.post_variable_name[0]} += w1
     {self.post_variable_name[1]} += w2
     {self.post_variable_name[2]} += w3
     {self.post_variable_name[3]} += w4
     '''
     synapses = Synapses(
         source=self.origin.brian2,
         target=self.target.brian2,
         method='euler',
         model=model,
         on_pre=on_pre,
         name=self.ref,
     )
     # connect each source neuron i to the target neuron j == i
     synapses.connect(j='i')
     for weight in ('w1', 'w2', 'w3', 'w4'):
         getattr(synapses, weight)[:] = 1
     return synapses
Example #9
0
def plot_inf(ax, parameters):
    """Plot gating variable steady-state values as function of membrane potential.

    The steady states of m, n and h are evaluated on a NeuronGroup of 100
    neurons whose membrane potentials are spread evenly from -100 mV to 100 mV.

    ax -- matplotlib axes to be plotted on
    parameters -- dictionary of parameters for gating variable steady-state equations
    """
    membrane = Equations('v : volt')
    steady_states = reduce(operator.add,
                           [construct_gating_variable_inf_equation(gv)
                            for gv in ['m', 'n', 'h']])
    inf_group = NeuronGroup(100, membrane + steady_states,
                            method='euler', namespace=parameters)
    inf_group.v = np.linspace(-100, 100, len(inf_group)) * mV

    curves = (('m_inf', r'$m_\infty$'),
              ('n_inf', r'$n_\infty$'),
              ('h_inf', r'$h_\infty$'))
    for attribute, label in curves:
        ax.plot(inf_group.v / mV, getattr(inf_group, attribute), label=label)

    ax.set_xlabel('$v$ (mV)')
    ax.set_ylabel('steady-state activation')
    # y axis label and ticks on the right-hand side
    ax.yaxis.set_label_position("right")
    ax.yaxis.tick_right()
    ax.legend()
Example #10
0
def test_dependency_calculation():
    '''
    Check ``Equations.dependencies`` for a chain of subexpressions: for every
    variable, the set of equations it depends on and the names via which
    each dependency is reached.
    '''
    eqs = Equations('''dv/dt = I_m / C_m : volt
                       I_m = I_ext + I_pas : amp
                       I_ext = 1*nA + sin(2*pi*100*Hz*t)*nA : amp
                       I_pas = g_L*(E_L - v) : amp''')
    deps = eqs.dependencies
    assert set(deps.keys()) == {'v', 'I_m', 'I_ext', 'I_pas'}

    def check(varname, expected_via):
        # The dependencies of ``varname`` must be exactly the keys of
        # ``expected_via``, each reached via the given tuple of names
        found = deps[varname]
        assert len(found) == len(expected_via)
        assert {d.equation.varname for d in found} == set(expected_via)
        assert all(d.via == expected_via[d.equation.varname] for d in found)

    # v depends directly on I_m, on I_ext and I_pas via I_m, and on v via I_m -> I_pas
    check('v', {'I_m': (),
                'I_pas': ('I_m', ),
                'I_ext': ('I_m', ),
                'v': ('I_m', 'I_pas')})
    # I_m depends directly on I_ext and I_pas, and on v via I_pas
    check('I_m', {'I_ext': (), 'I_pas': (), 'v': ('I_pas', )})
    # I_ext does not depend on anything
    check('I_ext', {})
    # I_pas depends on v directly
    check('I_pas', {'v': ()})
Example #11
0
 def brian2(self) -> BrianObject:
     """Create ``Synapses`` with a rising (``x``) and decaying (``s``) variable.

     Each presynaptic spike increments the variable named
     ``self.post_nonlinear_name`` (``x`` in the template); ``x`` drives ``s``,
     and ``w * s`` is summed into ``{self.post_variable_name_tot}_post``.
     ``connect()`` is called without arguments and all weights are set to 1.
     """
     replacements = dict(
         s_post=f'{self.post_variable_name_tot}_post',
         s=self.post_variable_name,
         x=self.post_nonlinear_name,
         tau=self[IP.TAU],
         tau_rise=self[IP.TAU_NMDA_RISE],
         alpha=self[IP.ALPHA],
     )
     model = Equations(
         '''
         w : 1
         s_post = w * s : 1 (summed)
         ds / dt = - s / tau + alpha * x * (1 - s) : 1 (clock-driven)
         dx / dt = - x / tau_rise : 1 (clock-driven)
         ''',
         **replacements
     )
     on_pre = f'''
     {self.post_nonlinear_name} += 1
     '''
     synapses = Synapses(
         source=self.origin.brian2,
         target=self.target.brian2,
         method='euler',
         model=model,
         on_pre=on_pre,
         name=self.ref,
     )
     synapses.connect()
     synapses.w[:] = 1
     return synapses
Example #12
0
def test_str_repr():
    '''
    Test the string representation (only that it does not throw errors).
    '''
    tau = 10 * ms
    eqs = Equations('''dv/dt = -(v + I)/ tau : volt (unless refractory)
                       I = sin(2 * 22/7. * f * t)* volt : volt
                       f : Hz''')
    for text in (str(eqs), repr(eqs)):
        assert len(text) > 0

    # Each SingleEquation must render on its own as well (str/repr might
    # already have been called on them while rendering the Equations object)
    for single_eq in eqs.itervalues():
        for render in (str, repr):
            assert len(render(single_eq)) > 0
Example #13
0
def test_str_repr():
    '''
    Check that str() and repr() of equations produce non-empty text
    without raising.
    '''
    tau = 10 * ms
    eqs = Equations('''dv/dt = -(v + I)/ tau : volt (unless refractory)
                       I = sin(2 * 22/7. * f * t)* volt : volt
                       f : Hz''')
    assert len(str(eqs)) > 0 and len(repr(eqs)) > 0

    # The single equations render on their own as well (Equations may already
    # have rendered them while rendering itself)
    for eq in eqs.itervalues():
        assert len(str(eq)) > 0 and len(repr(eq)) > 0
def setup_spikes(request):
    '''
    Simulate a single exponential integrate-and-fire neuron driven by
    ``input_current`` for 60 ms and record its spikes.

    ``fin`` (registered via ``request.addfinalizer``) calls
    ``reinit_devices`` when the requesting test is finished.
    NOTE(review): ``request`` suggests this is meant to be a pytest fixture;
    its decorator is not visible here -- confirm.

    Returns the ``SpikeMonitor`` and its ``t_`` attribute (presumably the
    spike times without units).
    '''
    def fin():
        reinit_devices()

    request.addfinalizer(fin)
    # The names below are only referenced inside the model string; they are
    # presumably resolved by Brian from this function's locals when ``run``
    # is called, so they must not be renamed or moved to another scope.
    EL = -70 * mV
    VT = -50 * mV
    DeltaT = 2 * mV
    C = 1 * nF
    gL = 30 * nS
    # input current sampled every 0.01 ms, used as I(t) in the model
    I = TimedArray(input_current, dt=0.01 * ms)
    model = Equations('''
                      dv/dt = (gL*(EL-v)+gL*DeltaT*exp((v-VT)/DeltaT) + I(t))/C : volt
                      ''')
    group = NeuronGroup(1,
                        model,
                        threshold='v > -50*mV',
                        reset='v = -70*mV',
                        method='exponential_euler')
    group.v = -70 * mV
    spike_mon = SpikeMonitor(group)
    run(60 * ms)
    spikes = getattr(spike_mon, 't_')

    return spike_mon, spikes
def test_fitter_fit_methods(method):
    '''
    Check that ``TraceFitter.fit`` runs (for two rounds) with a
    ``NevergradOptimizer`` using the optimization ``method``.

    ``method`` is the name of the Nevergrad method under test; methods
    containing "BO" are skipped.
    '''
    # Skip all BO methods for now (TODO: check what is going on). Checking
    # this first avoids building a TraceFitter that would never be used.
    if 'BO' in method:
        pytest.skip(f'Skipping method {method}')
    dt = 0.01 * ms
    model = Equations('''
        I = g*(v-E) : amp
        g : siemens (constant)
        E : volt (constant)
        ''')
    tf = TraceFitter(dt=dt,
                     model=model,
                     input_var='v',
                     output_var='I',
                     input=input_traces,
                     output=output_traces,
                     n_samples=30)
    optimizer = NevergradOptimizer(method)
    # Just make sure that it can run at all
    tf.fit(n_rounds=2,
           optimizer=optimizer,
           metric=metric,
           g=[1 * nS, 30 * nS],
           E=[-60 * mV, -20 * mV],
           callback=None)
Example #16
0
    def brian2(self) -> BrianObject:
        """Create ``Synapses`` whose efficacy depends on state variables ``u`` and ``x``.

        On each presynaptic spike ``u`` first relaxes towards ``U`` (time
        constant ``TAU_F``) and ``x`` towards 1 (time constant ``TAU_D``)
        since the last update; then ``w * u * x`` is added to the
        postsynaptic variable, ``x`` is reduced by the fraction ``u`` and
        ``u`` is increased by ``U * (1 - u)``.
        """
        model = Equations('''
        x : 1
        u : 1
        w : 1
        ''')

        # U, tauf and taud are interpolated into the on_pre code below
        U = self[IP.U]
        tauf = self[IP.TAU_F]
        taud = self[IP.TAU_D]

        on_pre = f'''
        u = {U} + (u - {U}) * exp(- (t - lastupdate) / ({tauf / units.ms} * ms))
        x = 1 + (x - 1) * exp(- (t - lastupdate) / ({taud / units.ms} * ms))
        
        {self.post_variable_name} += w * u * x
        
        x *= (1 - u)
        u += {U} * (1 - u)
        '''
        synapses = Synapses(source=self.origin.brian2,
                            target=self.target.brian2,
                            method='euler',
                            model=model,
                            on_pre=on_pre,
                            name=self.ref)
        self.connection.simulation(synapses)
        initial_values = (('w', self[IP.W]), ('x', 1), ('u', 1))
        for name, value in initial_values:
            getattr(synapses, name)[:] = value
        return synapses
Example #17
0
def plot_tau(ax, parameters):
    """Plot gating variable time constants as function of membrane potential.

    The time constants of m, n and h are evaluated on a NeuronGroup of 100
    neurons whose membrane potentials are spread evenly from -100 mV to
    100 mV, and plotted in ms.

    ax -- matplotlib axes to be plotted on
    parameters -- dictionary of parameters for gating variable time constant equations
    """
    gating_variables = ('m', 'n', 'h')
    tau_equations = Equations('v : volt') + reduce(
        operator.add,
        [construct_gating_variable_tau_equation(gv) for gv in gating_variables])
    tau_group = NeuronGroup(100, tau_equations,
                            method='euler', namespace=parameters)

    lowest_mV, highest_mV = -100, 100
    tau_group.v = np.linspace(lowest_mV, highest_mV, len(tau_group)) * mV

    labels = (r'$\tau_m$', r'$\tau_n$', r'$\tau_h$')
    for gv, label in zip(gating_variables, labels):
        ax.plot(tau_group.v / mV, getattr(tau_group, 'tau_' + gv) / ms,
                label=label)

    ax.set_xlabel('$v$ (mV)')
    ax.set_ylabel(r'$\tau$ (ms)')
    # y axis label and ticks on the right-hand side
    ax.yaxis.set_label_position("right")
    ax.yaxis.tick_right()
    ax.legend()
Example #18
0
    def introspect_population(pop):
        """Print a short summary of a population and its equations.

        Prints the class name, ``pop.name`` and ``pop.N``, followed by an
        ``Equations`` object rebuilt from ``str(pop.user_equations)``.
        """
        # Debugging output, permanently disabled by the constant condition
        if False:
            print(pop.thresholder)
            print(pop.resetter)
            print(pop.events['spike'])
            print(pop.event_codes.get('spike'))

        print('{} [{}] n: {}'.format(pop.__class__.__name__, pop.name, pop.N))
        # NOTE(review): ``globals`` must be a mapping defined in an enclosing
        # scope (not visible here); if it is the builtin ``globals`` function,
        # ``**globals`` raises a TypeError -- confirm.
        print(Equations(str(pop.user_equations), **globals))
Example #19
0
def test_resolve():
    '''
    Test resolving identifiers in equations.

    ``eq.resolve()`` is expected to pick up ``tau``, ``foo`` and ``ms`` from
    the scope of this test and ``tau_long`` from the explicitly given
    namespace, so the local names below must not be renamed.
    '''
    def foo(x):
        return 5 * x
    tau = 10 * ms
    tau2 = 20
    # v and u are model variables and param is a parameter, so none of them
    # is expected in the resolved namespace
    eqs = '''dv/dt = -v / tau : volt
             du/dt = -foo(u) * param / (tau_long * ms): volt
             param : 1
          '''
    eq = Equations(eqs, namespace={'tau_long': tau2}, exhaustive=False)
    namespace = eq.resolve()
    # identity checks: the resolved objects are the very objects defined here
    assert namespace['tau'] is tau
    assert namespace['tau_long'] is tau2
    assert namespace['foo'] is foo
    assert namespace['ms'] is ms
    assert set(namespace.keys()) == set(('tau', 'tau_long', 'foo', 'ms'))
Example #20
0
def construct_gating_variable_ode(gating_variable):
    """Construct the ordinary differential equation of the gating variable.

    The variable relaxes towards its steady state ``<gating_variable>_inf``
    with time constant ``tau_<gating_variable>``.

    gating_variable -- gating variable, typically one of "m", "n" and "h"
    """
    names = dict(x=gating_variable,
                 xinf=f'{gating_variable}_inf',
                 tau=f'tau_{gating_variable}')
    return Equations('dx/dt = (xinf - x)/tau : 1', **names)
Example #21
0
def test_extract_subexpressions():
    '''
    Test that ``extract_constant_subexpressions`` turns a subexpression
    flagged ``(constant over dt)`` into a parameter and collects the
    original subexpression separately.
    '''
    eqs = Equations("""dv/dt = -v / (10*ms) : 1
                       s1 = 2*v : 1
                       s2 = -v : 1 (constant over dt)
                    """)
    variable, constant = extract_constant_subexpressions(eqs)
    # A non-empty list is always truthy, so each membership must be checked
    assert all(var in variable for var in ['v', 's1', 's2'])
    assert variable['s1'].type == SUBEXPRESSION
    assert variable['s2'].type == PARAMETER
    assert constant['s2'].type == SUBEXPRESSION
Example #22
0
def test_construction_errors():
    '''
    Test that the Equations constructor raises errors correctly
    '''
    # parse error (the ':' before the unit is missing)
    assert_raises(EquationError, lambda: Equations('dv/dt = -v / tau volt'))

    # Only a single string or a list of SingleEquation objects is allowed
    # (a list of strings is rejected as well)
    assert_raises(TypeError, lambda: Equations(None))
    assert_raises(TypeError, lambda: Equations(42))
    assert_raises(TypeError, lambda: Equations(['dv/dt = -v / tau : volt']))

    # duplicate variable names
    assert_raises(EquationError, lambda: Equations('''dv/dt = -v / tau : volt
                                                    v = 2 * t/second * volt : volt'''))

    # duplicate variable names in a list of SingleEquation objects
    eqs = [SingleEquation(DIFFERENTIAL_EQUATION, 'v', volt,
                          expr=Expression('-v / tau')),
           SingleEquation(SUBEXPRESSION, 'v', volt,
                          expr=Expression('2 * t/second * volt'))
           ]
    assert_raises(EquationError, lambda: Equations(eqs))

    # illegal variable names: the names dt, t and xi, the keyword "for",
    # a leading digit and a leading underscore
    assert_raises(ValueError, lambda: Equations('ddt/dt = -dt / tau : volt'))
    assert_raises(ValueError, lambda: Equations('dt/dt = -t / tau : volt'))
    assert_raises(ValueError, lambda: Equations('dxi/dt = -xi / tau : volt'))
    assert_raises(ValueError, lambda: Equations('for : volt'))
    # presumably a leading digit can already fail while parsing, hence both
    # exception types are accepted
    assert_raises((EquationError, ValueError),
                  lambda: Equations('d1a/dt = -1a / tau : volt'))
    assert_raises(ValueError, lambda: Equations('d_x/dt = -_x / tau : volt'))

    # xi in a subexpression
    assert_raises(EquationError,
                  lambda: Equations('''dv/dt = -(v + I) / (5 * ms) : volt
                                       I = second**-1*xi**-2*volt : volt'''))

    # more than one xi
    assert_raises(EquationError,
                  lambda: Equations('''dv/dt = -v / tau + xi/tau**.5 : volt
                                       dx/dt = -x / tau + 2*xi/tau : volt
                                       tau : second'''))
    # using not-allowed flags: check_flags maps equation types to the flags
    # allowed for them
    eqs = Equations('dv/dt = -v / (5 * ms) : volt (flag)')
    eqs.check_flags({DIFFERENTIAL_EQUATION: ['flag']})  # allow this flag
    assert_raises(ValueError, lambda: eqs.check_flags({DIFFERENTIAL_EQUATION: []}))
    assert_raises(ValueError, lambda: eqs.check_flags({}))
    assert_raises(ValueError, lambda: eqs.check_flags({SUBEXPRESSION: ['flag']}))
    assert_raises(ValueError, lambda: eqs.check_flags({DIFFERENTIAL_EQUATION: ['otherflag']}))

    # Circular subexpression (w depends on x, x depends on w)
    assert_raises(ValueError, lambda: Equations('''dv/dt = -(v + w) / (10 * ms) : 1
                                                   w = 2 * x : 1
                                                   x = 3 * w : 1'''))

    # Boolean/integer differential equations
    assert_raises(TypeError, lambda: Equations('dv/dt = -v / (10*ms) : boolean'))
    assert_raises(TypeError, lambda: Equations('dv/dt = -v / (10*ms) : integer'))
Example #23
0
 def brian2(self) -> BrianObject:
     """Create ``Synapses`` that add the weight ``w`` to the post variable.

     Every synapse weight is set to ``self[IP.W]`` after
     ``self.connection.simulation`` has been applied to the new object.
     """
     weight_model = Equations('w : 1')
     pre_code = f'{self.post_variable_name} += w'
     synapses = Synapses(source=self.origin.brian2,
                         target=self.target.brian2,
                         method='euler',
                         model=weight_model,
                         on_pre=pre_code,
                         name=self.ref)
     self.connection.simulation(synapses)
     synapses.w[:] = self[IP.W]
     return synapses
Example #24
0
def test_ipython_pprint():
    '''
    Test that IPython's ``pprint`` writes some output for an ``Equations``
    object.
    '''
    from io import StringIO
    eqs = Equations("""dv/dt = -(v + I)/ tau : volt (unless refractory)
                       I = sin(2 * 22/7. * f * t)* volt : volt
                       f : Hz""")
    # Test ipython's pretty printing: capture what it writes to sys.stdout
    old_stdout = sys.stdout
    string_output = StringIO()
    sys.stdout = string_output
    try:
        pprint(eqs)
    finally:
        # Restore stdout even if pprint raises, so later tests are unaffected
        sys.stdout = old_stdout
    assert len(string_output.getvalue()) > 0
Example #25
0
def construct_gating_variable_inf_equation(gating_variable):
    """Construct the voltage-dependent steady-state gating variable equation.

    Approximated by Boltzmann function: the steady state
    ``<gating_variable>_inf`` is 1/(1+exp((v_half-v)/k)) with half-activation
    voltage ``v_<gating_variable>_half`` and slope ``k_<gating_variable>``.

    gating_variable -- gating variable, typically one of "m", "n" and "h"
    """
    templates = {'xinf': '{}_inf', 'v_half': 'v_{}_half', 'k': 'k_{}'}
    replacements = {placeholder: template.format(gating_variable)
                    for placeholder, template in templates.items()}
    return Equations('xinf = 1/(1+exp((v_half-v)/k)) : 1', **replacements)
Example #26
0
def test_identifier_checks():
    '''
    Test the individual identifier checks and the registry of identifier
    checks on ``Equations``.
    '''
    legal_identifiers = ['v', 'Vm', 'V', 'x', 'ge', 'g_i', 'a2', 'gaba_123']
    illegal_identifiers = ['_v', '1v', u'ü', 'ge!', 'v.x', 'for', 'else', 'if']

    for identifier in legal_identifiers:
        try:
            check_identifier_basic(identifier)
            check_identifier_reserved(identifier)
        except ValueError as ex:
            raise AssertionError('check complained about '
                                 'identifier "%s": %s' % (identifier, ex))

    # assert_raises calls each lambda immediately, so the loop variable is
    # bound to the current identifier when the check runs
    for identifier in illegal_identifiers:
        assert_raises(ValueError, lambda: check_identifier_basic(identifier))

    for identifier in ('t', 'dt', 'xi'):
        assert_raises(ValueError, lambda: check_identifier_reserved(identifier))

    for identifier in ('is_active', 'refractory', 'refractory_until'):
        assert_raises(ValueError, lambda: check_identifier_refractory(identifier))

    # Check identifier registry
    assert check_identifier_basic in Equations.identifier_checks
    assert check_identifier_reserved in Equations.identifier_checks
    assert check_identifier_refractory in Equations.identifier_checks

    # Set up a dummy identifier check that disallows the variable name
    # gaba_123 (that is otherwise valid)
    def disallow_gaba_123(identifier):
        if identifier == 'gaba_123':
            raise ValueError('I do not like this name')

    check_identifier('gaba_123')
    # Store a copy rather than a reference: with a reference, an in-place
    # addition by register_identifier_check would survive the "restore".
    old_checks = set(Equations.identifier_checks)
    Equations.register_identifier_check(disallow_gaba_123)
    try:
        assert_raises(ValueError, lambda: check_identifier('gaba_123'))
    finally:
        # Restore the registry even if the check above fails, so the dummy
        # check cannot leak into other tests
        Equations.identifier_checks = old_checks

    # registering a non-function should not work
    assert_raises(ValueError, lambda: Equations.register_identifier_check('no function'))
Example #27
0
def test_wrong_replacements():
    '''Tests for replacements that should not work'''
    # Replacing a variable name with an illegal new name (containing a
    # space, starting with an underscore, or the name t)
    for new_name in ('illegal name', '_reserved', 't'):
        assert_raises(SyntaxError,
                      lambda name=new_name: Equations('dv/dt = -v / tau : 1',
                                                      v=name))

    # Replacing a variable name with a name that already exists (x)
    clashing_eqs = '''
                                                    dv/dt = -v / tau : 1
                                                    dx/dt = -x / tau : 1
                                                    '''
    assert_raises(EquationError, lambda: Equations(clashing_eqs, v='x'))

    # Replacing a model variable name with a value
    assert_raises(ValueError,
                  lambda: Equations('dv/dt = -v / tau : 1', v=3 * mV))

    # Replacing with an illegal value
    assert_raises(SyntaxError,
                  lambda: Equations('dv/dt = -v/tau : 1', tau=np.arange(5)))
Example #28
0
def test_concatenation():
    '''
    Equations combined with ``+``, ``+=`` or with a string must match the
    same equations defined in a single string.
    '''
    reference = Equations('''dv/dt = -(v + I) / tau : volt
                        I = sin(2*pi*freq*t) : volt
                        freq : Hz''')

    # Concatenate two equation objects
    added = (Equations('dv/dt = -(v + I) / tau : volt') +
             Equations('''I = sin(2*pi*freq*t) : volt
                         freq : Hz'''))

    # Concatenate using "in-place" addition (which is not actually in-place)
    augmented = Equations('dv/dt = -(v + I) / tau : volt')
    augmented += Equations('''I = sin(2*pi*freq*t) : volt
                         freq : Hz''')

    # Concatenate with a string (will be parsed first)
    from_string = Equations('dv/dt = -(v + I) / tau : volt')
    from_string += '''I = sin(2*pi*freq*t) : volt
               freq : Hz'''

    # Concatenating with something that is not a string should not work
    assert_raises(TypeError, lambda: from_string + 5)

    # The string representation is canonical, therefore all four must agree
    assert str(reference) == str(added) == str(augmented) == str(from_string)
Example #29
0
def test_properties():
    '''
    Test accessing the various properties of equation objects
    '''
    tau = 10 * ms
    eqs = Equations('''dv/dt = -(v + I)/ tau : volt
                       I = sin(2 * 22/7. * f * t)* volt : volt
                       f = freq * Hz: Hz
                       freq : 1''')
    # one differential equation (v), two subexpressions (I, f), one
    # parameter (freq)
    assert (len(eqs.diff_eq_expressions) == 1 and
            eqs.diff_eq_expressions[0][0] == 'v' and
            isinstance(eqs.diff_eq_expressions[0][1], Expression))
    assert eqs.diff_eq_names == {'v'}
    assert (len(eqs.eq_expressions) == 3 and
            set([name for name, _ in eqs.eq_expressions]) == {'v', 'I', 'f'} and
            all((isinstance(expr, Expression) for _, expr in eqs.eq_expressions)))
    assert len(eqs.eq_names) == 3 and eqs.eq_names == {'v', 'I', 'f'}
    assert set(eqs.keys()) == {'v', 'I', 'f', 'freq'}
    # test that the equations object is iterable itself
    assert all((isinstance(eq, SingleEquation) for eq in eqs.itervalues()))
    # NOTE(review): ``basestring`` only exists on Python 2 unless it is
    # imported from a compatibility module elsewhere -- confirm.
    assert all((isinstance(eq, basestring) for eq in eqs))
    # the expected order lists f before I (which uses f) before v (which
    # uses I), followed by the parameter freq
    assert (len(eqs.ordered) == 4 and
            all((isinstance(eq, SingleEquation) for eq in eqs.ordered)) and
            [eq.varname for eq in eqs.ordered] == ['f', 'I', 'v', 'freq'])
    assert [eq.unit for eq in eqs.ordered] == [Hz, volt, volt, 1]
    assert eqs.names == {'v', 'I', 'f', 'freq'}
    assert eqs.parameter_names == {'freq'}
    assert eqs.subexpr_names == {'I', 'f'}
    dimensions = eqs.dimensions
    assert set(dimensions.keys()) == {'v', 'I', 'f', 'freq'}
    assert dimensions['v'] is volt.dim
    assert dimensions['I'] is volt.dim
    assert dimensions['f'] is Hz.dim
    assert dimensions['freq'] is DIMENSIONLESS
    assert eqs.names == set(eqs.dimensions.keys())
    assert eqs.identifiers == {'tau', 'volt', 'Hz', 'sin', 't'}

    # stochastic equations
    assert len(eqs.stochastic_variables) == 0
    assert eqs.stochastic_type is None

    # noise term not multiplied by a time- or state-dependent factor
    eqs = Equations('''dv/dt = -v / tau + 0.1*second**-.5*xi : 1''')
    assert eqs.stochastic_variables == {'xi'}
    assert eqs.stochastic_type == 'additive'

    eqs = Equations('''dv/dt = -v / tau + 0.1*second**-.5*xi_1 +  0.1*second**-.5*xi_2: 1''')
    assert eqs.stochastic_variables == {'xi_1', 'xi_2'}
    assert eqs.stochastic_type == 'additive'

    # noise multiplied by t or by the state variable v
    eqs = Equations('''dv/dt = -v / tau + 0.1*second**-1.5*xi*t : 1''')
    assert eqs.stochastic_type == 'multiplicative'

    eqs = Equations('''dv/dt = -v / tau + 0.1*second**-1.5*xi*v : 1''')
    assert eqs.stochastic_type == 'multiplicative'
Example #30
0
 def brian2_model(self) -> Equations:
     """Returns Brian2 dynamic (Equations) affecting specified populations.

     The current ``I = g * (v - vrev) * s`` is gated by a variable ``s``
     that decays exponentially (``ds/dt = -s/tau``); the template names are
     replaced by the names/values taken from ``self``.
     """
     replacements = {
         'I': self.current_name,
         'g': self[IP.GM],
         's': self.post_variable_name,
         'vrev': self[IP.VREV],
         'tau': self[IP.TAU],
     }
     return Equations(
         '''
         I = g * (v - vrev) * s : amp
         ds / dt = - s / tau : 1
         ''',
         **replacements
     )
Example #31
0
def test_ipython_pprint():
    '''
    Test that IPython's ``pprint`` writes some output for an ``Equations``
    object (skipped if IPython is not available).
    '''
    if pprint is None:
        raise SkipTest('ipython is not available')
    eqs = Equations('''dv/dt = -(v + I)/ tau : volt (unless refractory)
                       I = sin(2 * 22/7. * f * t)* volt : volt
                       f : Hz''')
    # Test ipython's pretty printing: capture what it writes to sys.stdout
    old_stdout = sys.stdout
    string_output = StringIO()
    sys.stdout = string_output
    try:
        pprint(eqs)
    finally:
        # Restore stdout even if pprint raises, so later tests are unaffected
        sys.stdout = old_stdout
    assert len(string_output.getvalue()) > 0
Example #32
0
 def brian2_model(self) -> Equations:
     """Return the Brian2 equations of a voltage-dependent synaptic current.

     ``I = g * (v - vrev) / (1 + mg * gamma * exp(-beta * v / mV)) * s``,
     where ``s`` is a parameter (named ``self.post_variable_name_tot``).
     NOTE(review): the voltage-dependent denominator looks like an NMDA-style
     magnesium block -- confirm.
     """
     names_and_values = {
         's': self.post_variable_name_tot,
         'I': self.current_name,
         'g': self[IP.GM],
         'vrev': self[IP.VREV],
         'beta': self[IP.BETA],
         'gamma': self[IP.GAMMA],
         'mg': self[IP.MG],
     }
     return Equations(
         '''
         I = g * (v - vrev) / (1 + mg * gamma * exp(- beta * v / mV) ) * s : amp
         s: 1
         ''',
         **names_and_values
     )
Example #33
0
 def brian2_model(self) -> Equations:
     """Return the Brian2 equations of a voltage-dependent synaptic current.

     ``I = g * (v - vrev) / (1 + gamma * exp(-beta * v)) * s``, where ``s``
     (named ``self.post_variable_name``) decays with time constant
     ``tau_decay``. ``beta`` is passed as ``self[IP.BETA] / units.mV``.
     """
     # Only names that occur in the equations are replaced; the former
     # ``s_post`` replacement referred to no identifier and was dropped.
     return Equations(
         '''
         I = g * (v - vrev ) / (1 + gamma * exp(- beta * v)) * s : amp
         ds / dt = - s / tau_decay : 1
         ''',
         I=self.current_name,
         g=self[IP.GM],
         s=self.post_variable_name,
         vrev=self[IP.VREV],
         tau_decay=self[IP.TAU],
         gamma=self[IP.GAMMA],
         beta=self[IP.BETA] / units.mV,
     )
Example #34
0
def test_str_repr():
    '''
    Test the string representation (only that it does not throw errors).
    '''
    tau = 10 * ms
    eqs = Equations('''dv/dt = -(v + I)/ tau : volt (unless-refractory)
                       I = sin(2 * 22/7. * f * t)* volt : volt
                       f : Hz''')
    assert len(str(eqs)) > 0
    assert len(repr(eqs)) > 0

    # Test str and repr of SingleEquations explicitly (might already have been
    # called by Equations)
    for eq in eqs.itervalues():
        assert len(str(eq)) > 0
        assert len(repr(eq)) > 0

    # Test ipython's pretty printing: capture what it writes to sys.stdout
    old_stdout = sys.stdout
    string_output = StringIO()
    sys.stdout = string_output
    try:
        pprint(eqs)
    finally:
        # Restore stdout even if pprint raises, so later tests are unaffected
        sys.stdout = old_stdout
    assert len(string_output.getvalue()) > 0
Example #35
0
def test_str_repr():
    '''
    Test the string representation (only that it does not throw errors).
    '''
    tau = 10 * ms
    eqs = Equations('''dv/dt = -(v + I)/ tau : volt (unless-refractory)
                       I = sin(2 * 22/7. * f * t)* volt : volt
                       f : Hz''')
    assert len(str(eqs)) > 0
    assert len(repr(eqs)) > 0

    # Test str and repr of SingleEquations explicitly (might already have been
    # called by Equations)
    for eq in eqs.itervalues():
        assert len(str(eq)) > 0
        assert len(repr(eq)) > 0

    # Test ipython's pretty printing, redirecting sys.stdout into a buffer
    old_stdout = sys.stdout
    string_output = StringIO()
    sys.stdout = string_output
    try:
        pprint(eqs)
    finally:
        # Always put the real stdout back, even when pprint fails
        sys.stdout = old_stdout
    assert len(string_output.getvalue()) > 0
Example #36
0
def construct_gating_variable_tau_equation(gating_variable):
    """Build the Gaussian-shaped, voltage-dependent time constant of a gate.

    The resulting equation defines ``tau_<g>`` (in seconds) as
    ``c_<g>_base + c_<g>_amp * exp(-(v_<g>_max - v)**2 / sigma_<g>**2)``,
    i.e. it reaches ``c_<g>_base + c_<g>_amp`` at ``v == v_<g>_max`` and
    decays towards ``c_<g>_base`` with width ``sigma_<g>``.

    gating_variable -- gating variable name, typically one of "m", "n" and
        "h"; it is inserted into every parameter name of the template.
    """
    template = 'tau = c_base + c_amp*exp(-(v_max - v)**2/sigma**2) : second'

    # Map each generic template identifier to its gate-specific name.
    renames = {
        'tau': f'tau_{gating_variable}',
        'c_base': f'c_{gating_variable}_base',
        'c_amp': f'c_{gating_variable}_amp',
        'v_max': f'v_{gating_variable}_max',
        'sigma': f'sigma_{gating_variable}',
    }
    return Equations(template, **renames)
Example #37
0
def test_ipython_pprint():
    '''
    Test that pretty printing an ``Equations`` object via ``pprint`` writes
    some output to stdout.
    '''
    try:
        from cStringIO import StringIO  # Python 2
    except ImportError:
        from io import StringIO  # Python 3
    eqs = Equations('''dv/dt = -(v + I)/ tau : volt (unless refractory)
                       I = sin(2 * 22/7. * f * t)* volt : volt
                       f : Hz''')
    # Test ipython's pretty printing. stdout is restored in ``finally`` so a
    # failing pprint cannot leave it redirected for later tests.
    old_stdout = sys.stdout
    string_output = StringIO()
    sys.stdout = string_output
    try:
        pprint(eqs)
    finally:
        sys.stdout = old_stdout
    assert len(string_output.getvalue()) > 0
Example #38
0
def test_identifier_checks():
    '''
    Test the individual identifier check functions and the identifier-check
    registry of ``Equations``.
    '''
    legal_identifiers = ['v', 'Vm', 'V', 'x', 'ge', 'g_i', 'a2', 'gaba_123']
    illegal_identifiers = ['_v', '1v', u'ü', 'ge!', 'v.x', 'for', 'else', 'if']

    for identifier in legal_identifiers:
        try:
            check_identifier_basic(identifier)
            check_identifier_reserved(identifier)
        except ValueError as ex:
            raise AssertionError('check complained about '
                                 'identifier "%s": %s' % (identifier, ex))

    for identifier in illegal_identifiers:
        assert_raises(ValueError, lambda: check_identifier_basic(identifier))

    for identifier in ('t', 'dt', 'xi'):
        assert_raises(ValueError,
                      lambda: check_identifier_reserved(identifier))

    for identifier in ('not_refractory', 'refractory', 'refractory_until'):
        assert_raises(ValueError,
                      lambda: check_identifier_refractory(identifier))

    for identifier in ('exp', 'sin', 'sqrt'):
        assert_raises(ValueError,
                      lambda: check_identifier_functions(identifier))

    for identifier in ('volt', 'second', 'mV', 'nA'):
        assert_raises(ValueError, lambda: check_identifier_units(identifier))

    # Check identifier registry
    assert check_identifier_basic in Equations.identifier_checks
    assert check_identifier_reserved in Equations.identifier_checks
    assert check_identifier_refractory in Equations.identifier_checks
    assert check_identifier_functions in Equations.identifier_checks
    assert check_identifier_units in Equations.identifier_checks

    # Set up a dummy identifier check that disallows the variable name
    # gaba_123 (that is otherwise valid)
    def disallow_gaba_123(identifier):
        if identifier == 'gaba_123':
            raise ValueError('I do not like this name')

    # Sanity check: the name passes the default checks before registration
    Equations.check_identifier('gaba_123')
    old_checks = set(Equations.identifier_checks)
    try:
        Equations.register_identifier_check(disallow_gaba_123)
        assert_raises(ValueError,
                      lambda: Equations.check_identifier('gaba_123'))
    finally:
        # The registry is class-level (global) state: always restore it so the
        # dummy check cannot leak into other tests if an assertion fails.
        Equations.identifier_checks = old_checks

    # registering a non-function should not work
    assert_raises(ValueError,
                  lambda: Equations.register_identifier_check('no function'))
Example #39
0
    def brian2_model(self) -> Optional[Equations]:
        """Assemble the membrane equation together with all input sub-models.

        The membrane potential ``v`` (volt) follows
        ``dv/dt = (-g * (v - vl) - I) / cm`` with ``g = self[PP.GM]``,
        ``vl = self[PP.VL]`` and ``cm = self[PP.CM]`` (presumably leak
        conductance, leak reversal potential and capacitance — confirm).
        The model of every entry in ``self.inputs`` followed by
        ``self.noises`` is appended. ``I`` (amp) is then defined as the sum
        of their ``current_name`` variables, or as ``0`` if there are none.

        This implementation always returns an ``Equations`` object, never
        ``None``.
        """
        model = Equations(
            'dv / dt = (- g * (v - vl) - I) / cm : volt',
            g=self[PP.GM],
            vl=self[PP.VL],
            cm=self[PP.CM]
        )

        current_names = []
        for source in self.inputs + self.noises:
            model += source.brian2_model
            current_names.append(source.current_name)

        # With no sources the total current collapses to a literal zero.
        total_current = ' + '.join(current_names) if current_names else '0'
        model += f'I = {total_current} : amp'
        return model
Example #40
0
def test_construction_errors():
    '''
    Test that the Equations constructor raises errors correctly
    '''
    # parse error
    assert_raises(SyntaxError, lambda: Equations('dv/dt = -v / tau volt'))
    
    # Only a single string or a list of SingleEquation objects is allowed
    assert_raises(TypeError, lambda: Equations(None))
    assert_raises(TypeError, lambda: Equations(42))
    assert_raises(TypeError, lambda: Equations(['dv/dt = -v / tau : volt']))
    
    # duplicate variable names
    assert_raises(SyntaxError, lambda: Equations('''dv/dt = -v / tau : volt
                                                    v = 2 * t/second * volt : volt'''))
        
    
    # illegal variable names
    assert_raises(ValueError, lambda: Equations('ddt/dt = -dt / tau : volt'))
    assert_raises(ValueError, lambda: Equations('dt/dt = -t / tau : volt'))
    assert_raises(ValueError, lambda: Equations('dxi/dt = -xi / tau : volt'))
    assert_raises(ValueError, lambda: Equations('for : volt'))
    assert_raises((SyntaxError, ValueError),
                  lambda: Equations('d1a/dt = -1a / tau : volt'))
    assert_raises(ValueError, lambda: Equations('d_x/dt = -_x / tau : volt'))
    
    # inconsistent unit for a differential equation
    assert_raises(DimensionMismatchError,
                  lambda: Equations('dv/dt = -v : volt'))
    assert_raises(DimensionMismatchError,
                  lambda: Equations('dv/dt = -v / tau: volt',
                                    namespace={'tau': 5 * mV}))
    assert_raises(DimensionMismatchError,
                  lambda: Equations('dv/dt = -(v + I) / (5 * ms): volt',
                                    namespace={'I': 3 * second}))    
    
    # inconsistent unit for a static equation
    assert_raises(DimensionMismatchError,
                  lambda: Equations('''dv/dt = -v / (5 * ms) : volt
                                       I = 2 * v : amp'''))
    
    # xi in a static equation
    assert_raises(SyntaxError,
                  lambda: Equations('''dv/dt = -(v + I) / (5 * ms) : volt
                                       I = second**-1*xi**-2*volt : volt''' ))
    
    # more than one xi    
    assert_raises(SyntaxError,                  
                  lambda: Equations('''dv/dt = -v / tau + xi/tau**.5 : volt
                                       dx/dt = -x / tau + 2*xi/tau : volt
                                       tau : second'''))
    # using not-allowed flags
    eqs = Equations('dv/dt = -v / (5 * ms) : volt (flag)')    
    eqs.check_flags({DIFFERENTIAL_EQUATION: ['flag']}) # allow this flag
    assert_raises(ValueError, lambda: eqs.check_flags({DIFFERENTIAL_EQUATION: []}))
    assert_raises(ValueError, lambda: eqs.check_flags({}))
    assert_raises(ValueError, lambda: eqs.check_flags({STATIC_EQUATION: ['flag']}))
    assert_raises(ValueError, lambda: eqs.check_flags({DIFFERENTIAL_EQUATION: ['otherflag']}))
    
    # Circular static equations
    assert_raises(ValueError, lambda: Equations('''dv/dt = -(v + w) / (10 * ms) : 1
                                                   w = 2 * x : 1
                                                   x = 3 * w : 1'''))