Example #1
0
    def compile_functions_C(self, eqs, freeze=False):
        """
        Compile all functions defined as strings into C function bodies.

        Every expression in ``eqs._string`` is frozen with ``optimiser.freeze``
        (names of equations, differential equations, aliases and ``t`` stay
        symbolic) and wrapped into a C body of the form
        ``(float <diffeq var>,...,float t) {float result = <expr>; return result; }``,
        stored in ``eqs._function_C_String[name]``.

        Expressions are frozen regardless of ``freeze``; the flag only decides
        ``eqs._frozen``, which is ``freeze and <every frozen expression is
        truthy (non-empty)>``.
        ALL FUNCTIONS MUST HAVE STRINGS.
        """
        all_variables = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys() + ['t']
        # Freeze each expression exactly once; the result serves both the
        # freezability check and the generated C code.
        frozen = {}
        for name, expr in eqs._string.iteritems():
            frozen[name] = optimiser.freeze(expr, all_variables,
                                            eqs._namespace[name])
        eqs._frozen = freeze and all(frozen.itervalues())

        # C parameter list: one float per differential variable, then t
        signature = ",".join(["float " + var for var in eqs._diffeq_names] +
                             ["float t"])
        for name, expr in frozen.iteritems():
            eqs._function_C_String[name] = ("(" + signature +
                                           ") {float result = " + expr +
                                           "; return result; }")
Example #2
0
    def generate_forward_euler_code(self):
        """
        Generate, compile and load a CUDA forward Euler state-update kernel.

        The kernel ``stateupdate(int N, SCALAR t, SCALAR *S)`` handles one
        index ``i`` per thread; state variable ``j`` of element ``i`` lives at
        ``S[i + j*N]``. For every variable in ``eqs._diffeq_names_nonzero``
        the derivative expression is frozen, rewritten as C, evaluated from
        the old values, and then ``x = x + clock_dt * dx`` is stored back.
        ``SCALAR`` is replaced by ``self.precision`` before compilation.

        Sets ``self.gpu_mod`` / ``self.gpu_func`` and returns the kernel source.
        """
        eqs = self.eqs
        all_variables = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys() + ['t']
        lines = ['__global__ void stateupdate(int N, SCALAR t, SCALAR *S)',
                 '{',
                 '    int i = blockIdx.x * blockDim.x + threadIdx.x;',
                 '    if(i>=N) return;']
        # Position of each state variable in the variable-major state array
        for j, name in enumerate(eqs._diffeq_names):
            lines.append('    int _index_' + name + ' = i+' + str(j) + '*N;')
        # Load the current values into locals
        for name in eqs._diffeq_names:
            lines.append('    SCALAR ' + name + ' = S[_index_' + name + '];')
        # All derivatives are computed from the old values before any write-back
        for name in eqs._diffeq_names:
            if name not in eqs._diffeq_names_nonzero:
                continue  # zero derivative: nothing to integrate
            namespace = eqs._namespace[name]
            expr = optimiser.freeze(eqs._string[name], all_variables, namespace)
            expr = rewrite_to_c_expression(expr)
            lines.append('    SCALAR ' + name + '__tmp = ' + expr + ';')
        for name in eqs._diffeq_names_nonzero:
            lines.append('    S[_index_' + name + '] = ' + name + '+' +
                         str(self.clock_dt) + '*' + name + '__tmp;')
        lines.append('}')
        clines = '\n'.join(lines) + '\n'
        clines = clines.replace('SCALAR', self.precision)
        self.gpu_mod = compiler.SourceModule(clines)
        self.gpu_func = self.gpu_mod.get_function("stateupdate")
        return clines
Example #3
0
 def generate_threshold_code(self, src):
     """
     Substitute the C threshold condition for ``%THRESHOLD%`` in ``src``.

     Supported thresholds:
       * exactly ``Threshold``  -> ``state>value``
       * ``VariableThreshold``  -> ``state>threshold_state``
       * ``StringThreshold``    -> the frozen expression
     Integer state indices (``state`` and ``threshold_state``) are resolved
     to names through ``eqs._diffeq_names``.

     Raises ValueError for any other threshold type.
     """
     eqs = self.eqs
     threshold = self.G._threshold

     def state_name(state):
         # States may be given as an index into the differential equation names
         if isinstance(state, int):
             return eqs._diffeq_names[state]
         return state

     # Exact class test, so subclasses of Threshold fall through to the checks below
     if threshold.__class__ is Threshold:
         condition = state_name(threshold.state) + '>' + str(float(threshold.threshold))
     elif isinstance(threshold, VariableThreshold):
         # threshold_state can be an integer index too; concatenating it
         # directly would raise TypeError
         condition = state_name(threshold.state) + '>' + state_name(threshold.threshold_state)
     elif isinstance(threshold, StringThreshold):
         all_variables = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys() + ['t']
         condition = optimiser.freeze(threshold._expr, all_variables, threshold._namespace)
     else:
         raise ValueError('Threshold must be constant, VariableThreshold or StringThreshold.')

     # Substitute threshold
     return src.replace('%THRESHOLD%', condition)
Example #4
0
    def generate_threshold_code(self, src):
        """
        Substitute the C threshold condition for ``%THRESHOLD%`` in ``src``.

        Supported thresholds:
          * exactly ``Threshold``  -> ``state>value``
          * ``VariableThreshold``  -> ``state>threshold_state``
          * ``StringThreshold``    -> the frozen expression
        Integer state indices (``state`` and ``threshold_state``) are resolved
        to names through ``eqs._diffeq_names``.

        Raises ValueError for any other threshold type.
        """
        eqs = self.eqs
        threshold = self.G._threshold

        def state_name(state):
            # States may be given as an index into the differential equation names
            if isinstance(state, int):
                return eqs._diffeq_names[state]
            return state

        # Exact class test, so subclasses of Threshold fall through to the checks below
        if threshold.__class__ is Threshold:
            condition = (state_name(threshold.state) + '>' +
                         str(float(threshold.threshold)))
        elif isinstance(threshold, VariableThreshold):
            # threshold_state can be an integer index too; concatenating it
            # directly would raise TypeError
            condition = (state_name(threshold.state) + '>' +
                         state_name(threshold.threshold_state))
        elif isinstance(threshold, StringThreshold):
            all_variables = (eqs._eq_names + eqs._diffeq_names +
                             eqs._alias.keys() + ['t'])
            condition = optimiser.freeze(threshold._expr, all_variables,
                                         threshold._namespace)
        else:
            raise ValueError(
                'Threshold must be constant, VariableThreshold or StringThreshold.'
            )

        # Substitute threshold
        return src.replace('%THRESHOLD%', condition)
Example #5
0
    def generate_reset_code(self, src):
        """
        Substitute the C reset statements for ``%RESET%`` in ``src``.

        Supported resets:
          * exactly ``Reset``  -> ``state = value``
          * ``VariableReset``  -> ``state = resetvaluestate``
          * ``StringReset``    -> the frozen expression
        Integer state indices (``state`` and ``resetvaluestate``) are resolved
        to names through ``eqs._diffeq_names``. Each non-blank line of the
        reset code becomes one ``;``-terminated statement.
        """
        eqs = self.eqs
        reset = self.G._resetfun

        def state_name(state):
            # States may be given as an index into the differential equation names
            if isinstance(state, int):
                return eqs._diffeq_names[state]
            return state

        # NOTE(review): any other object is used as-is below and must therefore
        # already be a string of reset code.
        code = reset
        # Exact class test, so subclasses of Reset fall through to the checks below
        if reset.__class__ is Reset:
            code = state_name(reset.state) + ' = ' + str(float(reset.resetvalue))
        elif isinstance(reset, VariableReset):
            # resetvaluestate can be an integer index too; concatenating it
            # directly would raise TypeError
            code = state_name(reset.state) + ' = ' + state_name(reset.resetvaluestate)
        elif isinstance(reset, StringReset):
            all_variables = (eqs._eq_names + eqs._diffeq_names +
                             eqs._alias.keys() + ['t'])
            code = optimiser.freeze(reset._expr, all_variables, reset._namespace)

        # Substitute reset
        code = '\n            '.join(line.strip() + ';'
                                     for line in code.split('\n')
                                     if line.strip())
        return src.replace('%RESET%', code)
Example #6
0
    def generate_reset_code(self, src):
        """
        Substitute the C reset statements for ``%RESET%`` in ``src``.

        Supported resets:
          * exactly ``Reset``  -> ``state = value``
          * ``VariableReset``  -> ``state = resetvaluestate``
          * ``StringReset``    -> the frozen expression
        Integer state indices (``state`` and ``resetvaluestate``) are resolved
        to names through ``eqs._diffeq_names``. Each non-blank line of the
        reset code becomes one ``;``-terminated statement.
        """
        eqs = self.eqs
        reset = self.G._resetfun

        def state_name(state):
            # States may be given as an index into the differential equation names
            if isinstance(state, int):
                return eqs._diffeq_names[state]
            return state

        # NOTE(review): any other object is used as-is below and must therefore
        # already be a string of reset code.
        code = reset
        # Exact class test, so subclasses of Reset fall through to the checks below
        if reset.__class__ is Reset:
            code = state_name(reset.state) + ' = ' + str(float(reset.resetvalue))
        elif isinstance(reset, VariableReset):
            # resetvaluestate can be an integer index too; concatenating it
            # directly would raise TypeError
            code = state_name(reset.state) + ' = ' + state_name(reset.resetvaluestate)
        elif isinstance(reset, StringReset):
            all_variables = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys() + ['t']
            code = optimiser.freeze(reset._expr, all_variables, reset._namespace)

        # Substitute reset
        code = '\n            '.join(line.strip() + ';' for line in code.split('\n') if line.strip())
        return src.replace('%RESET%', code)
Example #7
0
def frozen_equations(eqs):
    '''
    Returns a frozen set of equations.

    Calls ``eqs.prepare()``, then freezes the expression of every
    differential equation in ``eqs`` against that equation's own namespace.
    Names of equations, differential equations and aliases, as well as
    ``dt`` and ``t``, are kept symbolic, as in ``freeze_with_equations``.

    Returns a dict mapping each differential equation name to its frozen
    expression.
    '''
    frozen_eqs = {}
    eqs.prepare()
    # 'dt' must stay symbolic like 't'; otherwise a 'dt' in the namespace
    # would be baked into the expression as a constant.
    all_variables = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys() + ['dt', 't']
    for var in eqs._diffeq_names:
        namespace = eqs._namespace[var]
        frozen_eqs[var] = freeze(eqs._string[var], all_variables, namespace)
    return frozen_eqs
Example #8
0
 def compile_functions_C(self, eqs, freeze=False):
     """
     Compile all functions defined as strings into C function bodies.

     Every expression in ``eqs._string`` is frozen with ``optimiser.freeze``
     (names of equations, differential equations, aliases and ``t`` stay
     symbolic) and wrapped into a C body of the form
     ``(float <diffeq var>,...,float t) {float result = <expr>; return result; }``,
     stored in ``eqs._function_C_String[name]``.

     Expressions are frozen regardless of ``freeze``; the flag only decides
     ``eqs._frozen``, which is ``freeze and <every frozen expression is
     truthy (non-empty)>``.
     ALL FUNCTIONS MUST HAVE STRINGS.
     """
     all_variables = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys() + ['t']
     # Freeze each expression exactly once; the result serves both the
     # freezability check and the generated C code.
     frozen = {}
     for name, expr in eqs._string.iteritems():
         frozen[name] = optimiser.freeze(expr, all_variables, eqs._namespace[name])
     eqs._frozen = freeze and all(frozen.itervalues())

     # C parameter list: one float per differential variable, then t
     signature = ",".join(["float " + var for var in eqs._diffeq_names] + ["float t"])
     for name, expr in frozen.iteritems():
         eqs._function_C_String[name] = "(" + signature + ") {float result = " + expr + "; return result; }"
Example #9
0
def freeze_with_equations(inputcode, eqs, ns):
    '''
    Return ``inputcode`` frozen against the namespace ``ns``, keeping model names.

    The code is first stripped of surrounding whitespace; then every name
    from ``ns`` whose value is of int or float type is substituted by that
    value. Names belonging to :class:`brian.Equations` ``eqs`` (equations,
    differential equations, aliases) as well as ``dt`` and ``t`` are left
    untouched.
    '''
    stripped = inputcode.strip()
    keep_symbolic = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys()
    keep_symbolic = keep_symbolic + ['dt', 't']
    return freeze(stripped, keep_symbolic, ns)
Example #10
0
def frozen_equations(eqs):
    '''
    Return a dict of frozen differential equations.

    ``eqs.prepare()`` is called first. Then, for every name in
    ``eqs._diffeq_names``, its expression string is frozen against that
    equation's own namespace, as in :func:`freeze_with_equations` (equation,
    alias, ``dt`` and ``t`` names stay symbolic).
    '''
    result = {}
    eqs.prepare()
    keep = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys() + ['dt', 't']
    for name in eqs._diffeq_names:
        ns = eqs._namespace[name]
        result[name] = freeze(eqs._string[name], keep, ns)
    return result
Example #11
0
def freeze_with_equations(inputcode, eqs, ns):
    '''
    Freeze ``inputcode`` with the namespace ``ns``, protecting equation names.

    After stripping surrounding whitespace, each name of ``ns`` bound to an
    int or float value is replaced by that value. Variables named in the
    :class:`brian.Equations` ``eqs`` (equations, differential equations,
    aliases) are protected from replacement, and so are ``dt`` and ``t``.
    '''
    code = inputcode.strip()
    model_names = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys()
    return freeze(code, model_names + ['dt', 't'], ns)
Example #12
0
    def calc(self, var, res_gpu):
        """
        Evaluate the expression of ``var`` on the GPU, writing into ``res_gpu``.

        On first use for ``var`` a kernel is generated and cached in
        ``self.calc_dict``: the frozen expression is wrapped in a device
        function ``f(float <diffeq vars>..., float t)``, and the ``calc``
        kernel evaluates ``f`` on row ``blockIdx.x`` of ``self.S_out_gpu``
        (``var_len`` consecutive floats per row). The kernel is launched on a
        ``(len(self.S_in), 1)`` grid of single-thread blocks.
        """
        n = len(self.S_in)
        var_len = len(self.S_in[0])

        if var not in self.already_calc:
            import re  # only needed when a kernel is generated

            vars_strings = ["float " + var_aux for var_aux in self.eqs._diffeq_names]
            vars_strings.append("float t")

            expr = self.eqs._string[var]
            namespace = self.eqs._namespace[var]  # name space of the function
            all_variables = (self.eqs._eq_names + self.eqs._diffeq_names +
                             self.eqs._alias.keys() + ['t'])
            expr = optimiser.freeze(expr, all_variables, namespace)
            s = "(" + ",".join(
                vars_strings
            ) + ") { float result = " + expr + "; return result; }"
            # C has no ** operator: expand small integer powers (x**2..x**5) of
            # state variables into a parenthesised product. The base must be a
            # whole identifier and the exponent a whole integer, so ``vv**2``
            # and ``v**20`` are left intact. The parentheses preserve
            # precedence, e.g. ``a/v**2`` -> ``a/(v*v)``.
            for var_aux in self.eqs._diffeq_names:
                pattern = (r'\b' + re.escape(var_aux) +
                           r'\s*\*\*\s*([2-5])(?![\w.])')
                s = re.sub(
                    pattern,
                    lambda m, v=var_aux: '(' + '*'.join([v] * int(m.group(1))) + ')',
                    s)
            args_fun = ["S_out[" + str(i) + " + blockIdx.x * var_len]"
                        for i in xrange(var_len)]
            mod = SourceModule("""
                    __device__ float f""" + s + """
                    
                    __global__ void calc(float *res,float *S_out, int var_len)
                    { 
                        int idx = blockIdx.x;
                        res[idx] = f(""" + ",".join(args_fun) + """);
                        
                    }
                    """)
            self.calc_dict[var] = mod.get_function("calc")
            self.calc_dict[var].prepare(['P', 'P', 'i'], block=(1, 1, 1))
            self.already_calc[var] = True
        self.calc_dict[var].prepared_call((n, 1), res_gpu, self.S_out_gpu,
                                          numpy.int32(var_len))
Example #13
0
def frozen_equations(eqs):
    '''
    Build a mapping from differential equation names to frozen expressions.

    ``eqs.prepare()`` runs first. Each differential equation's expression is
    then frozen as in :func:`freeze_with_equations`, using the namespace
    attached to that equation; equation names, aliases, ``dt`` and ``t``
    remain symbolic.
    '''
    eqs.prepare()
    symbolic = eqs._eq_names + eqs._diffeq_names + eqs._alias.keys()
    symbolic = symbolic + ['dt', 't']

    def _frozen(var):
        namespace = eqs._namespace[var]
        return freeze(eqs._string[var], symbolic, namespace)

    return dict((var, _frozen(var)) for var in eqs._diffeq_names)
Example #14
0
 def calc(self, var, res_gpu):
     """
     Evaluate the expression of ``var`` on the GPU, writing into ``res_gpu``.

     On first use for ``var`` a kernel is generated and cached in
     ``self.calc_dict``: the frozen expression is wrapped in a device
     function ``f(float <diffeq vars>..., float t)``, and the ``calc``
     kernel evaluates ``f`` on row ``blockIdx.x`` of ``self.S_out_gpu``
     (``var_len`` consecutive floats per row). The kernel is launched on a
     ``(len(self.S_in), 1)`` grid of single-thread blocks.
     """
     n = len(self.S_in)
     var_len = len(self.S_in[0])

     if var not in self.already_calc:
         import re  # only needed when a kernel is generated

         vars_strings = ["float " + var_aux for var_aux in self.eqs._diffeq_names]
         vars_strings.append("float t")

         expr = self.eqs._string[var]
         namespace = self.eqs._namespace[var] # name space of the function
         all_variables = self.eqs._eq_names + self.eqs._diffeq_names + self.eqs._alias.keys() + ['t']
         expr = optimiser.freeze(expr, all_variables, namespace)
         s = "(" + ",".join(vars_strings) + ") { float result = " + expr + "; return result; }"
         # C has no ** operator: expand small integer powers (x**2..x**5) of
         # state variables into a parenthesised product. The base must be a
         # whole identifier and the exponent a whole integer, so ``vv**2`` and
         # ``v**20`` are left intact. The parentheses preserve precedence,
         # e.g. ``a/v**2`` -> ``a/(v*v)``.
         for var_aux in self.eqs._diffeq_names:
             pattern = r'\b' + re.escape(var_aux) + r'\s*\*\*\s*([2-5])(?![\w.])'
             s = re.sub(pattern,
                        lambda m, v=var_aux: '(' + '*'.join([v] * int(m.group(1))) + ')',
                        s)
         args_fun = ["S_out[" + str(i) + " + blockIdx.x * var_len]" for i in xrange(var_len)]
         mod = SourceModule("""
                 __device__ float f"""+ s +"""
                 
                 __global__ void calc(float *res,float *S_out, int var_len)
                 { 
                     int idx = blockIdx.x;
                     res[idx] = f("""+",".join(args_fun)+""");
                     
                 }
                 """)
         self.calc_dict[var] = mod.get_function("calc")
         self.calc_dict[var].prepare(['P','P','i'],block=(1,1,1))
         self.already_calc[var] = True
     self.calc_dict[var].prepared_call((n,1),res_gpu,self.S_out_gpu,numpy.int32(var_len))
Example #15
0
    def __init__(self,
                 C,
                 eqs,
                 pre,
                 post,
                 wmin=0,
                 wmax=Inf,
                 level=0,
                 clock=None,
                 delay_pre=None,
                 delay_post=None):
        """
        Build C-code STDP updaters for the Connection ``C``.

        ``eqs`` (an Equations object or a string) is separated into
        presynaptic and postsynaptic equations, each simulated in its own
        NeuronGroup (``self.pre_group`` / ``self.post_group``). ``pre`` and
        ``post`` are the code run on pre- and postsynaptic spikes; lines that
        assign to ``w`` are run per synapse, after all other (per-neuron)
        lines. After each synaptic update, ``w`` is clipped to the finite
        bounds among ``wmin`` / ``wmax``.

        For a plain Connection, when at most one of ``delay_pre`` /
        ``delay_post`` is given, the missing one(s) are chosen so that
        ``delay_pre + delay_post`` equals the connection delay. For a
        DelayConnection both must be None; past values are then read from
        RecentStateMonitors.

        Raises ValueError if the equations do not separate into exactly two
        groups, or if delays are given with a DelayConnection; Exception if a
        variable ends up both presynaptic and postsynaptic; AttributeError if
        a given delay is larger than the connection delay.
        """
        NetworkOperation.__init__(self, lambda: None, clock=clock)
        C.compress()
        # Convert to equations object
        if isinstance(eqs, Equations):
            eqs_obj = eqs
        else:
            eqs_obj = Equations(eqs, level=level + 1)
        # handle multi-line pre, post equations and multi-statement equations separated by ;
        if '\n' in pre:
            pre = flattened_docstring(pre)
        elif ';' in pre:
            pre = '\n'.join([line.strip() for line in pre.split(';')])
        if '\n' in post:
            post = flattened_docstring(post)
        elif ';' in post:
            post = '\n'.join([line.strip() for line in post.split(';')])

        # Check units
        eqs_obj.compile_functions()
        eqs_obj.check_units()

        # Get variable names (not called `vars`, to avoid shadowing the builtin)
        diffeq_vars = eqs_obj._diffeq_names
        # Find which ones are directly modified (e.g. regular expression matching; careful with comments)
        vars_pre = [var for var in diffeq_vars if var in modified_variables(pre)]
        vars_post = [var for var in diffeq_vars if var in modified_variables(post)]

        # additional dependencies are used to ensure that if there are multiple
        # pre/post separated equations they are grouped together as one
        additional_deps = [
            '__pre_deps=' + '+'.join(vars_pre),
            '__post_deps=' + '+'.join(vars_post)
        ]
        separated_equations = separate_equations(eqs_obj, additional_deps)
        if len(separated_equations) != 2:
            # Report what was found in the error itself rather than on stdout
            raise ValueError(
                'Equations should separate into pre and postsynaptic variables.'
                ' Got: ' + str(separated_equations))
        sep_pre, sep_post = separated_equations
        # separate_equations does not say which group is which: make sure the
        # group containing presynaptically modified variables is sep_pre
        for v in vars_pre:
            if v in sep_post._diffeq_names:
                sep_pre, sep_post = sep_post, sep_pre
                break

        # A variable belongs to a side if that side's code modifies it or it is
        # in that side's equations; the original equation order is kept.
        vars_pre = [v for v in diffeq_vars
                    if v in vars_pre or v in sep_pre._diffeq_names]
        vars_post = [v for v in diffeq_vars
                     if v in vars_post or v in sep_post._diffeq_names]

        # Check pre/post consistency
        shared_vars = set(vars_pre).intersection(vars_post)
        if shared_vars:
            raise Exception(
                str(list(shared_vars)) + " are both presynaptic and postsynaptic!")

        # Substitute equations/aliases into pre/post code
        def substitute_eqs(code):
            for name in sep_pre._eq_names[-1::-1] + sep_post._eq_names[
                    -1::-1]:  # reverse order, as in equations.py
                if name in sep_pre._eq_names:
                    expr = sep_pre._string[name]
                else:
                    expr = sep_post._string[name]
                code = re.sub("\\b" + name + "\\b", '(' + expr + ')', code)
            return code

        pre = substitute_eqs(pre)
        post = substitute_eqs(post)

        # Create namespaces for pre and post codes
        pre_namespace = namespace(pre, level=level + 1)
        post_namespace = namespace(post, level=level + 1)

        def splitcode(incode):
            # Split code into per-neuron lines and per-synapse lines (those
            # assigning to w). Per-neuron lines are executed first, so warn
            # once if a per-neuron line came after a per-synapse one.
            num_persynapse = 0
            reordering_warning = False
            per_neuron_lines = []
            per_synapse_lines = []
            for line in incode.split('\n'):
                line = line.strip()
                if not line:
                    continue
                # lines of the form w = ..., w *= ..., etc.
                if re.search(r'\bw\b\s*[^><=]?=', line):
                    num_persynapse += 1
                    per_synapse_lines.append(line)
                else:
                    if num_persynapse != 0 and not reordering_warning:
                        log_warn(
                            'brian.experimental.cstdp',
                            'STDP operations are being re-ordered, results may be wrong.'
                        )
                        reordering_warning = True
                    per_neuron_lines.append(line)
            return per_neuron_lines, per_synapse_lines

        per_neuron_pre, per_synapse_pre = splitcode(pre)
        per_neuron_post, per_synapse_post = splitcode(post)

        all_vars = vars_pre + vars_post + ['w']

        def to_c(lines, ns):
            # Freeze each line (state variables and w stay symbolic) and turn
            # it into one C statement per line
            return '\n'.join(
                c_single_statement(freeze(line, all_vars, ns)) for line in lines)

        per_neuron_pre = to_c(per_neuron_pre, pre_namespace)
        per_neuron_post = to_c(per_neuron_post, post_namespace)
        per_synapse_pre = to_c(per_synapse_pre, pre_namespace)
        per_synapse_post = to_c(per_synapse_post, post_namespace)

        # Neuron groups
        G_pre = NeuronGroup(len(C.source), model=sep_pre, clock=self.clock)
        G_post = NeuronGroup(len(C.target), model=sep_post, clock=self.clock)
        G_pre._S[:] = 0
        G_post._S[:] = 0
        self.pre_group = G_pre
        self.post_group = G_post
        var_group = {}
        for v in vars_pre:
            var_group[v] = G_pre
        for v in vars_post:
            var_group[v] = G_post
        self.var_group = var_group

        self.contained_objects += [G_pre, G_post]

        vars_pre_ind = dict((var, i) for i, var in enumerate(vars_pre))
        vars_post_ind = dict((var, i) for i, var in enumerate(vars_post))

        prevars_dict = dict((k, G_pre.state(k)) for k in vars_pre)
        postvars_dict = dict((k, G_post.state(k)) for k in vars_post)

        clipcode = ''
        if isfinite(wmin):
            clipcode += 'if(w<%wmin%) w = %wmin%;\n'.replace(
                '%wmin%', repr(float(wmin)))
        if isfinite(wmax):
            clipcode += 'if(w>%wmax%) w = %wmax%;\n'.replace(
                '%wmax%', repr(float(wmax)))

        if not isinstance(C, DelayConnection):
            precode = iterate_over_spikes(
                '_j', '_spikes',
                (load_required_variables(
                    '_j', prevars_dict), transform_code(per_neuron_pre),
                 iterate_over_row(
                     '_k', 'w', C.W, '_j', (load_required_variables(
                         '_k', postvars_dict), transform_code(per_synapse_pre),
                                            ConnectionCode(clipcode)))))
            postcode = iterate_over_spikes(
                '_j', '_spikes',
                (load_required_variables(
                    '_j', postvars_dict), transform_code(per_neuron_post),
                 iterate_over_col('_i', 'w', C.W, '_j',
                                  (load_required_variables('_i', prevars_dict),
                                   transform_code(per_synapse_post),
                                   ConnectionCode(clipcode)))))
            log_debug('brian.experimental.c_stdp',
                      'CSTDP Pre code:\n' + str(precode))
            log_debug('brian.experimental.c_stdp',
                      'CSTDP Post code:\n' + str(postcode))
            connection_delay = C.delay * C.source.clock.dt
            if (delay_pre is None) and (
                    delay_post is None):  # same delays as the Connection C
                delay_pre = connection_delay
                delay_post = 0 * ms
            elif delay_pre is None:
                delay_pre = connection_delay - delay_post
                if delay_pre < 0 * ms:
                    raise AttributeError("Postsynaptic delay is too large")
            elif delay_post is None:
                delay_post = connection_delay - delay_pre
                if delay_post < 0 * ms:
                    # here it is the given presynaptic delay that exceeds C's delay
                    raise AttributeError("Presynaptic delay is too large")
            # create forward and backward Connection objects or SpikeMonitor objects
            pre_updater = SpikeMonitor(C.source,
                                       function=precode,
                                       delay=delay_pre)
            post_updater = SpikeMonitor(C.target,
                                        function=postcode,
                                        delay=delay_post)
            self.contained_objects += [pre_updater, post_updater]
        else:
            if delay_pre is not None or delay_post is not None:
                raise ValueError(
                    "Must use delay_pre=delay_post=None for the moment.")
            max_delay = C._max_delay * C.target.clock.dt
            # Ensure that the source and target neuron spikes are kept for at least the
            # DelayConnection's maximum delay
            C.source.set_max_delay(max_delay)
            C.target.set_max_delay(max_delay)

            self.G_pre_monitors = {}
            self.G_post_monitors = {}
            self.G_pre_monitors.update(
                ((var,
                  RecentStateMonitor(G_pre,
                                     vars_pre_ind[var],
                                     duration=(C._max_delay + 1) *
                                     C.target.clock.dt,
                                     clock=G_pre.clock)) for var in vars_pre))
            self.G_post_monitors.update(
                ((var,
                  RecentStateMonitor(
                      G_post,
                      vars_post_ind[var],
                      duration=(C._max_delay + 1) * C.target.clock.dt,
                      clock=G_post.clock)) for var in vars_post))
            self.contained_objects += self.G_pre_monitors.values()
            self.contained_objects += self.G_post_monitors.values()

            prevars_dict_delayed = dict(
                (k, self.G_pre_monitors[k]) for k in prevars_dict.keys())
            postvars_dict_delayed = dict(
                (k, self.G_post_monitors[k]) for k in postvars_dict.keys())

            precode_immediate = iterate_over_spikes(
                '_j', '_spikes', (load_required_variables(
                    '_j', prevars_dict), transform_code(per_neuron_pre)))
            precode_delayed = iterate_over_spikes(
                '_j', '_spikes',
                iterate_over_row('_k',
                                 'w',
                                 C.W,
                                 '_j',
                                 extravars={'_delay': C.delayvec},
                                 code=(ConnectionCode(
                                     'double _t_past = _max_delay-_delay;',
                                     vars={'_max_delay': float(max_delay)}),
                                       load_required_variables_pastvalue(
                                           '_k', '_t_past',
                                           postvars_dict_delayed),
                                       transform_code(per_synapse_pre),
                                       ConnectionCode(clipcode))))
            postcode = iterate_over_spikes(
                '_j', '_spikes',
                (load_required_variables(
                    '_j', postvars_dict), transform_code(per_neuron_post),
                 iterate_over_col('_i',
                                  'w',
                                  C.W,
                                  '_j',
                                  extravars={'_delay': C.delayvec},
                                  code=(load_required_variables_pastvalue(
                                      '_i', '_delay', prevars_dict_delayed),
                                        transform_code(per_synapse_post),
                                        ConnectionCode(clipcode)))))
            log_debug('brian.experimental.c_stdp',
                      'CSTDP Pre code (immediate):\n' + str(precode_immediate))
            log_debug('brian.experimental.c_stdp',
                      'CSTDP Pre code (delayed):\n' + str(precode_delayed))
            log_debug('brian.experimental.c_stdp',
                      'CSTDP Post code:\n' + str(postcode))
            pre_updater_immediate = SpikeMonitor(C.source,
                                                 function=precode_immediate)
            pre_updater_delayed = SpikeMonitor(C.source,
                                               function=precode_delayed,
                                               delay=max_delay)
            post_updater = SpikeMonitor(C.target, function=postcode)
            updaters = [
                pre_updater_immediate, pre_updater_delayed, post_updater
            ]
            self.contained_objects += updaters
Example #16
0
    def __init__(self, C, eqs, pre, post, wmin=0, wmax=Inf, level=0,
                 clock=None, delay_pre=None, delay_post=None):
        NetworkOperation.__init__(self, lambda:None, clock=clock)
        C.compress()
        # Convert to equations object
        if isinstance(eqs, Equations):
            eqs_obj = eqs
        else:
            eqs_obj = Equations(eqs, level=level + 1)
        # handle multi-line pre, post equations and multi-statement equations separated by ;
        if '\n' in pre:
            pre = flattened_docstring(pre)
        elif ';' in pre:
            pre = '\n'.join([line.strip() for line in pre.split(';')])
        if '\n' in post:
            post = flattened_docstring(post)
        elif ';' in post:
            post = '\n'.join([line.strip() for line in post.split(';')])

        # Check units
        eqs_obj.compile_functions()
        eqs_obj.check_units()

        # Get variable names
        vars = eqs_obj._diffeq_names
        # Find which ones are directly modified (e.g. regular expression matching; careful with comments)
        vars_pre = [var for var in vars if var in modified_variables(pre)]
        vars_post = [var for var in vars if var in modified_variables(post)]

        # additional dependencies are used to ensure that if there are multiple
        # pre/post separated equations they are grouped together as one
        additional_deps = ['__pre_deps='+'+'.join(vars_pre),
                           '__post_deps='+'+'.join(vars_post)]
        separated_equations = separate_equations(eqs_obj, additional_deps)
        if not len(separated_equations) == 2:
            print separated_equations
            raise ValueError('Equations should separate into pre and postsynaptic variables.')
        sep_pre, sep_post = separated_equations
        for v in vars_pre:
            if v in sep_post._diffeq_names:
                sep_pre, sep_post = sep_post, sep_pre
                break

        index_pre = [i for i in range(len(vars)) if vars[i] in vars_pre or vars[i] in sep_pre._diffeq_names]
        index_post = [i for i in range(len(vars)) if vars[i] in vars_post or vars[i] in sep_post._diffeq_names]

        vars_pre = array(vars)[index_pre].tolist()
        vars_post = array(vars)[index_post].tolist()

        # Check pre/post consistency
        shared_vars = set(vars_pre).intersection(vars_post)
        if shared_vars != set([]):
            raise Exception, str(list(shared_vars)) + " are both presynaptic and postsynaptic!"

        # Substitute equations/aliases into pre/post code
        def substitute_eqs(code):
            for name in sep_pre._eq_names[-1::-1]+sep_post._eq_names[-1::-1]: # reverse order, as in equations.py
                if name in sep_pre._eq_names:
                    expr = sep_pre._string[name]
                else:
                    expr = sep_post._string[name]
                code = re.sub("\\b" + name + "\\b", '(' + expr + ')', code)
            return code
        pre = substitute_eqs(pre)
        post = substitute_eqs(post)

        # Create namespaces for pre and post codes
        pre_namespace = namespace(pre, level=level + 1)
        post_namespace = namespace(post, level=level + 1)

        def splitcode(incode):
            num_perneuron = num_persynapse = 0
            reordering_warning = False
            incode_lines = [line.strip() for line in incode.split('\n') if line.strip()]
            per_neuron_lines = []
            per_synapse_lines = []
            for line in incode_lines:
                if not line.strip(): continue
                m = re.search(r'\bw\b\s*[^><=]?=', line) # lines of the form w = ..., w *= ..., etc.
                if m:
                    num_persynapse += 1
                    per_synapse_lines.append(line)
                else:
                    num_perneuron += 1
                    if num_persynapse!=0 and not reordering_warning:
                        log_warn('brian.experimental.cstdp', 'STDP operations are being re-ordered, results may be wrong.')
                        reordering_warning = True
                    per_neuron_lines.append(line)
            return per_neuron_lines, per_synapse_lines

        per_neuron_pre, per_synapse_pre = splitcode(pre)
        per_neuron_post, per_synapse_post = splitcode(post)

        all_vars = vars_pre + vars_post + ['w']        

        per_neuron_pre = [c_single_statement(freeze(line, all_vars, pre_namespace)) for line in per_neuron_pre]
        per_neuron_post = [c_single_statement(freeze(line, all_vars, post_namespace)) for line in per_neuron_post]
        per_synapse_pre = [c_single_statement(freeze(line, all_vars, pre_namespace)) for line in per_synapse_pre]
        per_synapse_post = [c_single_statement(freeze(line, all_vars, post_namespace)) for line in per_synapse_post]

        per_neuron_pre = '\n'.join(per_neuron_pre)
        per_neuron_post = '\n'.join(per_neuron_post)
        per_synapse_pre = '\n'.join(per_synapse_pre)
        per_synapse_post = '\n'.join(per_synapse_post)

        # Neuron groups
        G_pre = NeuronGroup(len(C.source), model=sep_pre, clock=self.clock)
        G_post = NeuronGroup(len(C.target), model=sep_post, clock=self.clock)
        G_pre._S[:] = 0
        G_post._S[:] = 0
        self.pre_group = G_pre
        self.post_group = G_post
        var_group = {}
        for i, v in enumerate(vars_pre):
            var_group[v] = G_pre
        for i, v in enumerate(vars_post):
            var_group[v] = G_post
        self.var_group = var_group

        self.contained_objects += [G_pre, G_post]

        vars_pre_ind = {}
        for i, var in enumerate(vars_pre):
            vars_pre_ind[var] = i
        vars_post_ind = {}
        for i, var in enumerate(vars_post):
            vars_post_ind[var] = i

        prevars_dict = dict((k, G_pre.state(k)) for k in vars_pre)
        postvars_dict = dict((k, G_post.state(k)) for k in vars_post)

        clipcode = ''
        if isfinite(wmin):
            clipcode += 'if(w<%wmin%) w = %wmin%;\n'.replace('%wmin%', repr(float(wmin)))
        if isfinite(wmax):
            clipcode += 'if(w>%wmax%) w = %wmax%;\n'.replace('%wmax%', repr(float(wmax)))

        if not isinstance(C, DelayConnection):
            precode = iterate_over_spikes('_j', '_spikes',
                        (load_required_variables('_j', prevars_dict),
                         transform_code(per_neuron_pre),
                         iterate_over_row('_k', 'w', C.W, '_j',
                            (load_required_variables('_k', postvars_dict),
                             transform_code(per_synapse_pre),
                             ConnectionCode(clipcode)))))
            postcode = iterate_over_spikes('_j', '_spikes',
                        (load_required_variables('_j', postvars_dict),
                         transform_code(per_neuron_post),
                         iterate_over_col('_i', 'w', C.W, '_j',
                            (load_required_variables('_i', prevars_dict),
                             transform_code(per_synapse_post),
                             ConnectionCode(clipcode)))))
            log_debug('brian.experimental.c_stdp', 'CSTDP Pre code:\n' + str(precode))
            log_debug('brian.experimental.c_stdp', 'CSTDP Post code:\n' + str(postcode))
            connection_delay = C.delay * C.source.clock.dt
            if (delay_pre is None) and (delay_post is None): # same delays as the Connnection C
                delay_pre = connection_delay
                delay_post = 0 * ms
            elif delay_pre is None:
                delay_pre = connection_delay - delay_post
                if delay_pre < 0 * ms: raise AttributeError, "Postsynaptic delay is too large"
            elif delay_post is None:
                delay_post = connection_delay - delay_pre
                if delay_post < 0 * ms: raise AttributeError, "Postsynaptic delay is too large"
            # create forward and backward Connection objects or SpikeMonitor objects
            pre_updater = SpikeMonitor(C.source, function=precode, delay=delay_pre)
            post_updater = SpikeMonitor(C.target, function=postcode, delay=delay_post)
            updaters = [pre_updater, post_updater]
            self.contained_objects += [pre_updater, post_updater]
        else:
            if delay_pre is not None or delay_post is not None:
                raise ValueError("Must use delay_pre=delay_post=None for the moment.")
            max_delay = C._max_delay * C.target.clock.dt
            # Ensure that the source and target neuron spikes are kept for at least the
            # DelayConnection's maximum delay
            C.source.set_max_delay(max_delay)
            C.target.set_max_delay(max_delay)

            self.G_pre_monitors = {}
            self.G_post_monitors = {}
            self.G_pre_monitors.update(((var, RecentStateMonitor(G_pre, vars_pre_ind[var], duration=(C._max_delay + 1) * C.target.clock.dt, clock=G_pre.clock)) for var in vars_pre))
            self.G_post_monitors.update(((var, RecentStateMonitor(G_post, vars_post_ind[var], duration=(C._max_delay + 1) * C.target.clock.dt, clock=G_post.clock)) for var in vars_post))
            self.contained_objects += self.G_pre_monitors.values()
            self.contained_objects += self.G_post_monitors.values()

            prevars_dict_delayed = dict((k, self.G_pre_monitors[k]) for k in prevars_dict.keys())
            postvars_dict_delayed = dict((k, self.G_post_monitors[k]) for k in postvars_dict.keys())

            precode_immediate = iterate_over_spikes('_j', '_spikes',
                                    (load_required_variables('_j', prevars_dict),
                                     transform_code(per_neuron_pre)))
            precode_delayed = iterate_over_spikes('_j', '_spikes',
                                     iterate_over_row('_k', 'w', C.W, '_j', extravars={'_delay':C.delayvec},
                                        code=(
                                         ConnectionCode('double _t_past = _max_delay-_delay;', vars={'_max_delay':float(max_delay)}),
                                         load_required_variables_pastvalue('_k', '_t_past', postvars_dict_delayed),
                                         transform_code(per_synapse_pre),
                                         ConnectionCode(clipcode))))
            postcode = iterate_over_spikes('_j', '_spikes',
                            (load_required_variables('_j', postvars_dict),
                             transform_code(per_neuron_post),
                             iterate_over_col('_i', 'w', C.W, '_j', extravars={'_delay':C.delayvec},
                                code=(
                                 load_required_variables_pastvalue('_i', '_delay', prevars_dict_delayed),
                                 transform_code(per_synapse_post),
                                 ConnectionCode(clipcode)))))
            log_debug('brian.experimental.c_stdp', 'CSTDP Pre code (immediate):\n' + str(precode_immediate))
            log_debug('brian.experimental.c_stdp', 'CSTDP Pre code (delayed):\n' + str(precode_delayed))
            log_debug('brian.experimental.c_stdp', 'CSTDP Post code:\n' + str(postcode))
            pre_updater_immediate = SpikeMonitor(C.source, function=precode_immediate)
            pre_updater_delayed = SpikeMonitor(C.source, function=precode_delayed, delay=max_delay)
            post_updater = SpikeMonitor(C.target, function=postcode)
            updaters = [pre_updater_immediate, pre_updater_delayed, post_updater]
            self.contained_objects += updaters
Example #17
0
 def __init__(self,
              G,
              eqs,
              I,
              I_offset,
              spiketimes,
              spiketimes_offset,
              spikedelays,
              refractory,
              delta,
              onset=0 * ms,
              coincidence_count_algorithm='exclusive',
              precision=default_precision,
              scheme=euler_scheme):
     """
     Build the model-fitting kernel source for the neuron group ``G``.

     The group's threshold and reset objects are translated into code
     strings (string-defined expressions are passed through
     ``optimiser.freeze`` with the model variables of ``eqs``), then handed
     to ``generate_modelfitting_kernel_src`` with the simulation settings.

     Sets ``precision``, ``mydtype`` (float64 for 'double', else float32),
     ``N``, ``dt``, ``delta``, ``onset``, ``eqs``, ``G``,
     ``coincidence_count_algorithm``, ``threshold``, ``reset`` and
     ``kernel_src`` on ``self``.

     ``I``, ``I_offset``, ``spiketimes``, ``spiketimes_offset``,
     ``spikedelays`` and ``refractory`` are accepted but not read here.

     Raises ValueError if the group's threshold is not a plain
     ``Threshold``, a ``VariableThreshold`` or a ``StringThreshold``.
     """
     eqs.prepare()
     self.precision = precision
     if precision == 'double':
         self.mydtype = float64
     else:
         self.mydtype = float32
     self.N = N = len(G)
     self.dt = dt = G.clock.dt
     self.delta = delta
     self.onset = onset
     self.eqs = eqs
     self.G = G
     self.coincidence_count_algorithm = coincidence_count_algorithm

     def state_name(state):
         # A state may be given as an index into the differential equations.
         if isinstance(state, int):
             return eqs._diffeq_names[state]
         return state

     def frozen_expr(expr, ns):
         # Freeze expr against the model's variables (and t), resolving
         # other names through ns. list() is needed because keys() is a view,
         # not a list, on Python 3.
         all_variables = (eqs._eq_names + eqs._diffeq_names +
                          list(eqs._alias.keys()) + ['t'])
         return optimiser.freeze(expr, all_variables, ns)

     threshold = G._threshold
     # Exact class test (not isinstance): subclasses of Threshold do not
     # match here and fall through to the more specific branches below.
     if threshold.__class__ is Threshold:
         # repr() instead of str(): Python 2's str(float) keeps only 12
         # significant digits, which would round the constant in the kernel.
         threshold = (state_name(threshold.state) + '>' +
                      repr(float(threshold.threshold)))
     elif isinstance(threshold, VariableThreshold):
         threshold = (state_name(threshold.state) + '>' +
                      threshold.threshold_state)
     elif isinstance(threshold, StringThreshold):
         threshold = frozen_expr(threshold._expr, threshold._namespace)
     else:
         raise ValueError(
             'Threshold must be constant, VariableThreshold or StringThreshold.'
         )
     self.threshold = threshold

     reset = G._resetfun
     if reset.__class__ is Reset:
         reset = (state_name(reset.state) + ' = ' +
                  repr(float(reset.resetvalue)))
     elif isinstance(reset, VariableReset):
         reset = state_name(reset.state) + ' = ' + reset.resetvaluestate
     elif isinstance(reset, StringReset):
         reset = frozen_expr(reset._expr, reset._namespace)
     # NOTE(review): any other reset object is passed through unchanged
     # (unlike the threshold, which raises) -- confirm that
     # generate_modelfitting_kernel_src is meant to receive it as-is.
     self.reset = reset
     self.kernel_src = generate_modelfitting_kernel_src(
         self.G,
         eqs,
         threshold,
         reset,
         dt,
         N,
         delta,
         coincidence_count_algorithm=coincidence_count_algorithm,
         precision=precision,
         scheme=scheme)
Example #18
0
def freeze_with_equations(inputcode, eqs, ns):
    """
    Strip surrounding whitespace from ``inputcode`` and freeze it.

    The names of ``eqs``'s equations, differential equations and aliases,
    plus ``t``, are passed to ``freeze`` as the model variables; all other
    names are resolved through the namespace ``ns``.

    Returns the frozen code string produced by ``freeze``.
    """
    inputcode = inputcode.strip()
    # list() is needed because dict.keys() is a view (not a list) on
    # Python 3 and cannot be concatenated with lists.
    all_variables = (eqs._eq_names + eqs._diffeq_names +
                     list(eqs._alias.keys()) + ["t"])
    return freeze(inputcode, all_variables, ns)
Example #19
0
 def __init__(self, G, eqs, I, I_offset, spiketimes, spiketimes_offset,
                    spikedelays, refractory,
                    delta, onset=0 * ms,
                    coincidence_count_algorithm='exclusive',
                    precision=default_precision,
                    scheme=euler_scheme
                    ):
     """
     Build the model-fitting kernel source for the neuron group ``G``.

     The group's threshold and reset objects are translated into code
     strings (string-defined expressions are passed through
     ``optimiser.freeze`` with the model variables of ``eqs``), then handed
     to ``generate_modelfitting_kernel_src`` with the simulation settings.

     Sets ``precision``, ``mydtype`` (float64 for 'double', else float32),
     ``N``, ``dt``, ``delta``, ``onset``, ``eqs``, ``G``,
     ``coincidence_count_algorithm``, ``threshold``, ``reset`` and
     ``kernel_src`` on ``self``.

     ``I``, ``I_offset``, ``spiketimes``, ``spiketimes_offset``,
     ``spikedelays`` and ``refractory`` are accepted but not read here.

     Raises ValueError if the group's threshold is not a plain
     ``Threshold``, a ``VariableThreshold`` or a ``StringThreshold``.
     """
     eqs.prepare()
     self.precision = precision
     if precision == 'double':
         self.mydtype = float64
     else:
         self.mydtype = float32
     self.N = N = len(G)
     self.dt = dt = G.clock.dt
     self.delta = delta
     self.onset = onset
     self.eqs = eqs
     self.G = G
     self.coincidence_count_algorithm = coincidence_count_algorithm

     def state_name(state):
         # A state may be given as an index into the differential equations.
         if isinstance(state, int):
             return eqs._diffeq_names[state]
         return state

     def frozen_expr(expr, ns):
         # Freeze expr against the model's variables (and t), resolving
         # other names through ns. list() is needed because keys() is a view,
         # not a list, on Python 3.
         all_variables = (eqs._eq_names + eqs._diffeq_names +
                          list(eqs._alias.keys()) + ['t'])
         return optimiser.freeze(expr, all_variables, ns)

     threshold = G._threshold
     # Exact class test (not isinstance): subclasses of Threshold do not
     # match here and fall through to the more specific branches below.
     if threshold.__class__ is Threshold:
         # repr() instead of str(): Python 2's str(float) keeps only 12
         # significant digits, which would round the constant in the kernel.
         threshold = (state_name(threshold.state) + '>' +
                      repr(float(threshold.threshold)))
     elif isinstance(threshold, VariableThreshold):
         threshold = (state_name(threshold.state) + '>' +
                      threshold.threshold_state)
     elif isinstance(threshold, StringThreshold):
         threshold = frozen_expr(threshold._expr, threshold._namespace)
     else:
         raise ValueError('Threshold must be constant, VariableThreshold or StringThreshold.')
     self.threshold = threshold

     reset = G._resetfun
     if reset.__class__ is Reset:
         reset = (state_name(reset.state) + ' = ' +
                  repr(float(reset.resetvalue)))
     elif isinstance(reset, VariableReset):
         reset = state_name(reset.state) + ' = ' + reset.resetvaluestate
     elif isinstance(reset, StringReset):
         reset = frozen_expr(reset._expr, reset._namespace)
     # NOTE(review): any other reset object is passed through unchanged
     # (unlike the threshold, which raises) -- confirm that
     # generate_modelfitting_kernel_src is meant to receive it as-is.
     self.reset = reset
     self.kernel_src = generate_modelfitting_kernel_src(
               self.G, eqs, threshold, reset, dt, N, delta,
               coincidence_count_algorithm=coincidence_count_algorithm,
               precision=precision, scheme=scheme)