def process_ITE(condition):
    """Turn one if-then-else triple into a C++ ternary expression string.

    ``condition`` is indexed as ``[if, then, else]``. The ``then`` and
    ``else`` entries may themselves be lists, which are treated as nested
    conditionals and translated recursively; other entries are parsed with
    ``Equation(..., type='return')`` and only the text before the first ';'
    is kept.

    NOTE(review): ``name``, ``description`` and ``untouched`` are read as
    free (module-level) names here, not parameters. This looks like a copy of
    the closure nested inside ``translate_ITE`` -- confirm it is actually used
    at top level.
    """
    # Read all three parts up front (same failure point as indexing each).
    test_part, true_part, false_part = condition[0], condition[1], condition[2]

    def branch_to_cpp(part):
        # Nested list -> recurse; plain expression -> strip trailing ';...'.
        if isinstance(part, list):
            return process_ITE(part)
        parsed = Equation(name, part, description,
                          untouched=untouched.keys(), type='return').parse()
        return parsed.split(';')[0]

    test_cpp = Equation(name, test_part, description,
                        untouched=untouched.keys(), type='cond').parse()
    true_cpp = branch_to_cpp(true_part)
    false_cpp = branch_to_cpp(false_part)
    return ''.join(('(', test_cpp, ' ? ', true_cpp, ' : ', false_cpp, ')'))
def translate_ITE(name, eq, condition, description, untouched, split=True):
    """Translate an equation containing a single if-then-else (ITE) statement.

    Args:
        name: name of the variable being translated (passed to ``Equation``).
        eq: the main equation; when ``split`` is true its right-hand side is
            expected to contain the ``__conditional__`` placeholder.
        condition: ``[if, then, else]`` triple; ``then``/``else`` may be
            nested triples.
        description: neuron/synapse description dict passed to ``Equation``.
        untouched: dict whose keys are names the parser must leave alone.
        split: when false, ``eq`` is not parsed and the result is just the
            translated ternary expression.

    Returns:
        The translated code with the placeholder replaced: a string, or --
        when ``Equation.parse`` returns a list -- that list with its first
        element rewritten in place.

    NOTE(review): only the code is returned; some callers in this file unpack
    ``code, deps = translate_ITE(...)`` -- confirm which version they expect.
    """
    placeholder = '__conditional__'

    def ite_to_cpp(triple):
        # Build "(cond ? then : else)", recursing into nested triples.
        test_part, true_part, false_part = triple[0], triple[1], triple[2]
        test_cpp = Equation(name, test_part, description,
                            untouched=untouched.keys(), type='cond').parse()
        if not isinstance(true_part, list):
            true_cpp = Equation(name, true_part, description,
                                untouched=untouched.keys(),
                                type='return').parse().split(';')[0]
        else:
            true_cpp = ite_to_cpp(true_part)
        if not isinstance(false_part, list):
            false_cpp = Equation(name, false_part, description,
                                 untouched=untouched.keys(),
                                 type='return').parse().split(';')[0]
        else:
            false_cpp = ite_to_cpp(false_part)
        return ''.join(('(', test_cpp, ' ? ', true_cpp, ' : ', false_cpp, ')'))

    # Main equation first (its right part is the placeholder), then the ITE.
    if not split:
        code = placeholder
    else:
        code = Equation(name, eq, description, untouched=untouched.keys()).parse()
    ternary = ite_to_cpp(condition)

    if not isinstance(code, str):
        code[0] = code[0].replace(placeholder, ternary)
        return code
    return code.replace(placeholder, ternary)
def extract_spike_variable(description):
    """Translate the spike condition and reset statements of a spiking neuron.

    Args:
        description: neuron description dict. Reads ``'raw_spike'`` and, when
            present and non-empty, ``'raw_reset'``.

    Returns:
        dict with:
          * ``'spike_cond'``: C++ code of the spike condition,
          * ``'spike_cond_dependencies'``: names the condition depends on
            (kept because CUDA code generation may need them),
          * ``'spike_reset'``: list of reset statement dicts produced by
            ``process_equations``, each augmented with ``'cpp'`` and
            ``'dependencies'``.
    """
    cond = prepare_string(description['raw_spike'])
    # Exactly one expression is allowed (an empty definition is invalid too).
    if len(cond) != 1:
        # Print the offending definition before reporting the error, so it is
        # shown even if ``_error`` aborts.
        Global._print(description['raw_spike'])
        Global._error('The spike condition must be a single expression')

    translator = Equation('raw_spike_cond', cond[0].strip(), description)
    raw_spike_code = translator.parse()
    # Also store the variables used in the condition, as it may be needed for CUDA generation
    spike_code_dependencies = translator.dependencies()

    reset_desc = []
    if description.get('raw_reset'):
        reset_desc = process_equations(description['raw_reset'])
        for var in reset_desc:
            translator = Equation(var['name'], var['eq'], description)
            var['cpp'] = translator.parse()
            var['dependencies'] = translator.dependencies()

    return {
        'spike_cond': raw_spike_code,
        'spike_cond_dependencies': spike_code_dependencies,
        'spike_reset': reset_desc,
    }
def extract_stop_condition(pop):
    """Translate a population's stop condition into a C++ boolean expression.

    ``pop['stop_condition']['eq']`` may carry flags after a ':' (separated by
    spaces). If one of them is ``all`` the condition type becomes ``'all'``,
    otherwise it is ``'any'``. The text before the first ':' is parsed with
    ``Equation(..., type='cond')`` and stored, wrapped in parentheses, as
    ``pop['stop_condition']['cpp']``. ``pop`` is modified in place.
    """
    stop_condition = pop['stop_condition']
    expression = stop_condition['eq']
    stop_condition['type'] = 'any'

    pieces = expression.split(':')
    if len(pieces) > 1:
        # Flags were given after the ':'.
        expression = pieces[0]
        flag_list = pieces[1].strip().split(' ')
        if any(flag.strip() == 'all' for flag in flag_list):
            stop_condition['type'] = 'all'

    # Convert the expression
    cpp_code = Equation('stop_cond', expression, pop, type='cond').parse()
    pop['stop_condition']['cpp'] = ''.join(('(', cpp_code, ')'))
def extract_stop_condition(pop):
    """Translate a population's stop condition into a C++ boolean expression.

    ``pop['stop_condition']['eq']`` may carry flags after a ':' (separated by
    spaces). If one of them is ``all`` the condition type becomes ``'all'``,
    otherwise it is ``'any'``. The text before the first ':' is parsed with
    ``Equation(..., type='cond')`` and stored, wrapped in parentheses, as
    ``pop['stop_condition']['cpp']``. ``pop`` is modified in place.
    """
    stop_condition = pop['stop_condition']
    expression = stop_condition['eq']
    stop_condition['type'] = 'any'

    pieces = expression.split(':')
    if len(pieces) > 1:
        # Flags were given after the ':'.
        expression = pieces[0]
        flag_list = pieces[1].strip().split(' ')
        if any(flag.strip() == 'all' for flag in flag_list):
            stop_condition['type'] = 'all'

    # Convert the expression
    cpp_code = Equation('stop_cond', expression, pop, type='cond').parse()
    pop['stop_condition']['cpp'] = ''.join(('(', cpp_code, ')'))
def translate_ITE(name, eq, condition, description, untouched, split=True):
    """Translate an equation containing one or more if-then-else (ITE) statements.

    Args:
        name: name of the variable being translated (passed to ``Equation``).
        eq: the equation text. Its i-th ITE is expected to have been replaced
            by the placeholder ``'__conditional__' + str(i)``.
        condition: list of ITE triples ``[if, then, else]``; ``then``/``else``
            may themselves be nested triples.
        description: neuron/synapse description dict passed to ``Equation``.
        untouched: dict whose keys are names the parser must leave alone.
        split: when true, ``eq`` is first parsed with ``Equation``; when
            false it is used as-is.

    Returns:
        The translated code with every placeholder replaced by its C++
        ternary expression: a string, or -- when ``Equation.parse`` returns a
        list -- that list with its first element rewritten in place.

    NOTE(review): only the code is returned; some callers in this file unpack
    ``code, deps = translate_ITE(...)`` -- confirm which version they expect.
    """
    def process_condition(condition):
        " Recursively processes the different parts of an ITE statement"
        if_statement = condition[0]
        then_statement = condition[1]
        else_statement = condition[2]

        if_code = Equation(name, if_statement, description,
                           untouched=untouched.keys(), type='cond').parse()

        if isinstance(then_statement, list):  # nested conditional
            then_code = process_condition(then_statement)
        else:
            then_code = Equation(name, then_statement, description,
                                 untouched=untouched.keys(),
                                 type='return').parse().split(';')[0]

        if isinstance(else_statement, list):  # nested conditional
            else_code = process_condition(else_statement)
        else:
            else_code = Equation(name, else_statement, description,
                                 untouched=untouched.keys(),
                                 type='return').parse().split(';')[0]

        return '(' + if_code + ' ? ' + then_code + ' : ' + else_code + ')'

    if split:
        # Main equation, where the right part holds the __conditional__ placeholders
        translator = Equation(name, eq, description, untouched=untouched.keys())
        code = translator.parse()
    else:
        code = eq

    # Translate every ITE first, in their original order.
    itecodes = [process_condition(ite) for ite in condition]

    # Substitute from the highest index down: '__conditional__1' is a prefix
    # of '__conditional__10', so an ascending str.replace would corrupt the
    # placeholders of the 11th and later conditionals.
    for i in reversed(range(len(itecodes))):
        placeholder = '__conditional__' + str(i)
        if isinstance(code, str):
            code = code.replace(placeholder, itecodes[i])
        else:
            # NOTE(review): only the first element of a list result is
            # searched for placeholders.
            code[0] = code[0].replace(placeholder, itecodes[i])
    return code
def __init__(self, description, variables):
    """Store a set of coupled ODE variables and the parser's symbol mapping.

    Args:
        description: neuron/synapse description dict; its ``'local'`` and
            ``'global'`` entries are kept as ``local_variables`` and
            ``global_variables``.
        variables: list of variable dicts to be solved together. The
            ``'untouched'`` mapping of the first entry is used for all.
    """
    self.description = description
    self.variables = variables
    # The first variable's untouched mapping is used for the whole system.
    self.untouched = variables[0]['untouched']
    # variable name -> transformed equation text
    self.expression_list = {var['name']: var['transformed_eq'] for var in self.variables}
    self.names = self.expression_list.keys()
    self.local_variables = self.description['local']
    self.global_variables = self.description['global']
    # A throw-away Equation is built only to reuse its ``local_dict``.
    scratch = Equation('tmp', '', self.description, method='implicit', untouched=self.untouched)
    self.local_dict = scratch.local_dict
def extract_spike_variable(description):
    """Translate the spike condition and reset statements of a spiking neuron.

    Args:
        description: neuron description dict. Reads ``'raw_spike'`` and, when
            present and non-empty, ``'raw_reset'``.

    Returns:
        dict with:
          * ``'spike_cond'``: C++ code of the spike condition,
          * ``'spike_cond_dependencies'``: names the condition depends on
            (kept because CUDA code generation may need them),
          * ``'spike_reset'``: list of reset statement dicts produced by
            ``process_equations``, each augmented with ``'cpp'`` and
            ``'dependencies'``.
    """
    cond = prepare_string(description['raw_spike'])
    # Exactly one expression is allowed (an empty definition is invalid too).
    if len(cond) != 1:
        # Print the offending definition before reporting the error, so it is
        # shown even if ``_error`` aborts.
        Global._print(description['raw_spike'])
        Global._error('The spike condition must be a single expression')

    translator = Equation('raw_spike_cond', cond[0].strip(), description)
    raw_spike_code = translator.parse()
    # Also store the variables used in the condition, as it may be needed for CUDA generation
    spike_code_dependencies = translator.dependencies()

    reset_desc = []
    if description.get('raw_reset'):
        reset_desc = process_equations(description['raw_reset'])
        for var in reset_desc:
            translator = Equation(var['name'], var['eq'], description)
            var['cpp'] = translator.parse()
            var['dependencies'] = translator.dependencies()

    return {
        'spike_cond': raw_spike_code,
        'spike_cond_dependencies': spike_code_dependencies,
        'spike_reset': reset_desc,
    }
def extract_spike_variable(description):
    """Translate the spike condition and reset statements of a spiking neuron.

    Args:
        description: neuron description dict. Reads ``'raw_spike'`` and, when
            present and non-empty, ``'raw_reset'``.

    Returns:
        dict with ``'spike_cond'`` (C++ code of the spike condition) and
        ``'spike_reset'`` (list of reset statement dicts produced by
        ``process_equations``, each augmented with ``'cpp'``).
    """
    cond = prepare_string(description['raw_spike'])
    # Exactly one expression is allowed (an empty definition is invalid too).
    if len(cond) != 1:
        # Show the offending definition before reporting the error: if
        # ``_error`` aborts, anything placed after it would never run.
        _print(description['raw_spike'])
        _error('The spike condition must be a single expression')
        # NOTE(review): status 0 reports success to the shell on this error
        # path; kept for compatibility with the rest of this module.
        exit(0)

    translator = Equation('raw_spike_cond', cond[0].strip(), description)
    raw_spike_code = translator.parse()

    reset_desc = []
    if description.get('raw_reset'):
        reset_desc = process_equations(description['raw_reset'])
        for var in reset_desc:
            translator = Equation(var['name'], var['eq'], description)
            var['cpp'] = translator.parse()

    return {'spike_cond': raw_spike_code, 'spike_reset': reset_desc}
def __init__(self, description, variables):
    """Store a set of coupled ODE variables and the parser's symbol mapping.

    Args:
        description: neuron/synapse description dict; its ``'local'`` and
            ``'global'`` entries are kept as ``local_variables`` and
            ``global_variables``.
        variables: list of variable dicts to be solved together. The
            ``'untouched'`` mapping of the first entry is used for all.
    """
    self.description = description
    self.variables = variables
    # The first variable's untouched mapping is used for the whole system.
    self.untouched = variables[0]['untouched']
    # variable name -> transformed equation text
    self.expression_list = {var['name']: var['transformed_eq'] for var in self.variables}
    self.names = self.expression_list.keys()
    self.local_variables = self.description['local']
    self.global_variables = self.description['global']
    # A throw-away Equation is built only to reuse its ``local_dict``.
    scratch = Equation('tmp', '', self.description, method='implicit', untouched=self.untouched)
    self.local_dict = scratch.local_dict
def extract_spike_variable(description):
    """Translate the spike condition and reset statements of a spiking neuron.

    Args:
        description: neuron description dict. Reads ``'raw_spike'`` and, when
            present and non-empty, ``'raw_reset'``.

    Returns:
        dict with ``'spike_cond'`` (C++ code of the spike condition) and
        ``'spike_reset'`` (list of reset statement dicts produced by
        ``process_equations``, each augmented with ``'cpp'``).
    """
    cond = prepare_string(description['raw_spike'])
    # Exactly one expression is allowed (an empty definition is invalid too).
    if len(cond) != 1:
        # Show the offending definition before reporting the error: if
        # ``_error`` aborts, anything placed after it would never run.
        _print(description['raw_spike'])
        _error('The spike condition must be a single expression')
        # NOTE(review): status 0 reports success to the shell on this error
        # path; kept for compatibility with the rest of this module.
        exit(0)

    translator = Equation('raw_spike_cond', cond[0].strip(), description)
    raw_spike_code = translator.parse()

    reset_desc = []
    if description.get('raw_reset'):
        reset_desc = process_equations(description['raw_reset'])
        for var in reset_desc:
            translator = Equation(var['name'], var['eq'], description)
            var['cpp'] = translator.parse()

    return {'spike_cond': raw_spike_code, 'spike_reset': reset_desc}
def extract_structural_plasticity(statement, description):
    """Translate a structural-plasticity (pruning/creating) statement to C++.

    The statement has the form ``condition [: flags]``. At most one random
    distribution call (one of ``available_distributions``) may appear in the
    condition; its arguments must be numeric literals. It is replaced by
    ``rd(rng)`` in the equation and described in the returned ``'rd'`` entry.

    Args:
        statement: the raw pruning/creating statement.
        description: synapse description dict. Pre/post dependencies found in
            the condition are appended to ``description['dependencies']``.

    Returns:
        dict with ``'eq'`` (transformed equation), ``'cpp'`` (C++ code of the
        condition), ``'bounds'`` and ``'flags'`` (from ``extract_flags``) and
        ``'rd'`` (random distribution description, or ``None``).
    """
    # Extract flags
    try:
        eq, constraint = statement.rsplit(':', 1)
        bounds, flags = extract_flags(constraint)
    except Exception:
        # No ':' in the statement (or unparsable flags): the whole statement
        # is the condition.
        eq = statement.strip()
        bounds = {}
        flags = []

    # Extract RD
    rd = None
    for dist in available_distributions:
        matches = re.findall(r'(?P<pre>[^\w.])' + dist + r'\(([^()]+)\)', eq)
        for _pre, v in matches:
            # Check the arguments
            arguments = v.split(',')
            # Check the number of provided arguments
            if len(arguments) < distributions_arguments[dist]:
                _error(eq)
                _error('The distribution ' + dist + ' requires ' + str(distributions_arguments[dist]) + ' parameters')
            elif len(arguments) > distributions_arguments[dist]:
                _error(eq)
                _error('Too many parameters provided to the distribution ' + dist)

            # Process the arguments: only numeric literals are accepted.
            values = []
            for argument in arguments:
                try:
                    values.append(float(argument))
                except ValueError:  # e.g. a global parameter name
                    _error(eq)
                    _error('Random distributions for creating/pruning synapses must use fixed values.')
                    exit(0)
            processed_arguments = ', '.join(str(value) for value in values)
            definition = distributions_equivalents[dist] + '(' + processed_arguments + ')'

            # Store its definition
            if rd:
                _error(eq)
                _error('Only one random distribution per equation is allowed.')
                exit(0)

            rd = {'name': 'rand_' + str(0),
                  'origin': dist + '(' + v + ')',
                  'dist': dist,
                  'definition': definition,
                  'args': processed_arguments,
                  'template': distributions_equivalents[dist]}

    if rd:
        eq = eq.replace(rd['origin'], 'rd(rng)')

    # Extract pre/post dependencies
    eq, untouched, dependencies = extract_prepost('test', eq, description)

    # Parse code
    # NOTE(review): other conditions in this module pass type='cond';
    # here the numerical method is set to 'cond' instead -- confirm intent.
    translator = Equation('test', eq, description, method='cond', untouched={})
    code = translator.parse()

    # Replace untouched variables with their original name
    for prev, new in untouched.items():
        code = code.replace(prev, new)

    # Add new dependencies
    description['dependencies']['pre'].extend(dependencies['pre'])
    description['dependencies']['post'].extend(dependencies['post'])

    return {'eq': eq, 'cpp': code, 'bounds': bounds, 'flags': flags, 'rd': rd}
def analyse_neuron(neuron):
    """
    Parses the structure and generates code snippets for the neuron type.

    It returns a ``description`` dictionary with the following fields:

    * 'object': 'neuron' by default, to distinguish it from 'synapse'
    * 'type': either 'rate' or 'spike'
    * 'raw_parameters': provided field
    * 'raw_equations': provided field
    * 'raw_functions': provided field
    * 'raw_reset': provided field (spiking neurons only)
    * 'raw_spike': provided field (spiking neurons only)
    * 'raw_axon_spike': provided field (spiking neurons only)
    * 'raw_axon_reset': provided field (spiking neurons only)
    * 'refractory': provided field (spiking neurons only)
    * 'parameters': list of parameters defined for the neuron type
    * 'variables': list of variables defined for the neuron type
    * 'functions': list of functions defined for the neuron type
    * 'attributes': list of names of all parameters and variables
    * 'local': list of names of parameters and variables which are local to each neuron
    * 'semiglobal': always empty for neurons (only used by projections)
    * 'global': list of names of parameters and variables which are global to the population
    * 'targets': list of targets used in the equations
    * 'random_distributions': list of random number generators used in the neuron equations
    * 'global_operations': list of global operations (min/max/mean...) used in the equations
    * 'spike': when defined, contains the equations of the spike conditions and reset.
    * 'axon_spike': when defined, result of ``extract_axon_spike_condition``.

    Each parameter is a dictionary with the following elements:

    * 'bounds': unused
    * 'ctype': type of the parameter: 'float', 'double', 'int' or 'bool'
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local' or 'global'
    * 'name': name of the parameter

    Each variable is a dictionary with the following elements:

    * 'bounds': dictionary of bounds ('init', 'min', 'max') provided after the :
    * 'cpp': C++ code snippet updating the variable
    * 'ctype': type of the variable: 'float', 'double', 'int' or 'bool'
    * 'dependencies': list of variable and parameter names on which the equation depends
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local' or 'global'
    * 'method': numerical method for ODEs
    * 'name': name of the variable
    * 'pre_loop': ODEs have a pre_loop term for precomputing dt/tau. dict with 'name' and 'value'. type must be inferred.
    * 'switch': ODEs have a switch term
    * 'transformed_eq': same as eq, except special terms (sums, rds) are replaced with a temporary name
    * 'untouched': dictionary of special terms, with their new name as keys and replacement values as values.

    The 'spike' element (when present) is a dictionary containing:

    * 'spike_cond': the C++ code snippet containing the spike condition ("v%(local_index)s > v_T")
    * 'spike_cond_dependencies': list of variables/parameters on which the spike condition depends
    * 'spike_reset': a list of reset statements, each of them composed of :
        * 'constraint': either '' or 'unless_refractory'
        * 'cpp': C++ code snippet
        * 'dependencies': list of variables on which the reset statement depends
        * 'eq': original equation in text format
        * 'name': name of the reset variable
    """
    # Store basic information
    description = {
        'object': 'neuron',
        'type': neuron.type,
        'raw_parameters': neuron.parameters,
        'raw_equations': neuron.equations,
        'raw_functions': neuron.functions,
    }

    # Spiking neurons additionally store the spike condition, the reset statements and a refractory period
    if neuron.type == 'spike':
        description['raw_reset'] = neuron.reset
        description['raw_spike'] = neuron.spike
        description['raw_axon_spike'] = neuron.axon_spike
        description['raw_axon_reset'] = neuron.axon_reset
        description['refractory'] = neuron.refractory

    # Extract parameters and variables names
    parameters = extract_parameters(neuron.parameters, neuron.extra_values)
    variables = extract_variables(neuron.equations)
    description['parameters'] = parameters
    description['variables'] = variables

    # Make sure r is defined for rate-coded networks
    from ANNarchy.extensions.bold.BoldModel import BoldModel
    if isinstance(neuron, BoldModel):
        found = False
        for var in description['parameters'] + description['variables']:
            if var['name'] == 'r':
                found = True
        if not found:
            description['variables'].append({
                'name': 'r', 'locality': 'local', 'bounds': {}, 'ctype': config['precision'],
                'init': 0.0, 'flags': [], 'eq': '', 'cpp': ""
            })
    elif neuron.type == 'rate':
        for var in description['parameters'] + description['variables']:
            if var['name'] == 'r':
                break
        else:
            _error('Rate-coded neurons must define the variable "r".')
    else:  # spiking neurons define r by default, it contains the average FR if enabled
        for var in description['parameters'] + description['variables']:
            if var['name'] == 'r':
                _error('Spiking neurons use the variable "r" for the average FR, use another name.')
        description['variables'].append({
            'name': 'r', 'locality': 'local', 'bounds': {}, 'ctype': config['precision'],
            'init': 0.0, 'flags': [], 'eq': '', 'cpp': ""
        })

    # Extract functions
    functions = extract_functions(neuron.functions, False)
    description['functions'] = functions

    # Build lists of all attributes (param + var), which are local or global
    attributes, local_var, global_var, _ = get_attributes(parameters, variables, neuron=True)

    # Test if attributes are declared only once
    if len(attributes) != len(list(set(attributes))):
        _error('Attributes must be declared only once.', attributes)

    # Store the attributes
    description['attributes'] = attributes
    description['local'] = local_var
    description['semiglobal'] = []  # only for projections
    description['global'] = global_var

    # Extract all targets
    targets = sorted(list(set(extract_targets(variables))))
    description['targets'] = targets
    if neuron.type == 'spike':  # Add a default reset behaviour for conductances
        for target in targets:
            found = False
            for var in description['variables']:
                if var['name'] == 'g_' + target:
                    found = True
                    break
            if not found:
                description['variables'].append({
                    'name': 'g_' + target, 'locality': 'local', 'bounds': {},
                    'ctype': config['precision'], 'init': 0.0, 'flags': [],
                    'eq': 'g_' + target + ' = 0.0'
                })
                description['attributes'].append('g_' + target)
                description['local'].append('g_' + target)

    # Extract RandomDistribution objects
    random_distributions = extract_randomdist(description)
    description['random_distributions'] = random_distributions

    # Extract the spike condition if any
    if neuron.type == 'spike':
        description['spike'] = extract_spike_variable(description)
        description['axon_spike'] = extract_axon_spike_condition(description)

    # Global operation TODO
    description['global_operations'] = []

    # The ODEs may be interdependent (implicit, midpoint), so they need to be passed explicitely to CoupledEquations
    concurrent_odes = []

    # Translate the equations to C++
    for variable in description['variables']:
        # Get the equation
        eq = variable['transformed_eq']
        if eq.strip() == "":
            continue

        # Special variables (sums, global operations, rd) are placed in untouched, so that Sympy ignores them
        untouched = {}

        # Dependencies must be gathered
        dependencies = []

        # Replace sum(target) with _sum_exc__[i]
        for target in description['targets']:
            # sum() is valid for all targets
            eq = re.sub(r'(?P<pre>[^\w.])sum\(\)', r'\1sum(__all__)', eq)
            # Replace sum(target) with __sum_target__
            eq = re.sub(r'sum\(\s*' + target + r'\s*\)', '__sum_' + target + '__', eq)
            untouched['__sum_' + target + '__'] = '_sum_' + target + '%(local_index)s'

        # Extract global operations
        eq, untouched_globs, global_ops = extract_globalops_neuron(variable['name'], eq, description)

        # Add the untouched variables to the global list
        for name, val in untouched_globs.items():
            if name not in untouched:
                untouched[name] = val
        description['global_operations'] += global_ops

        # Extract if-then-else statements
        eq, condition = extract_ite(variable['name'], eq, description)

        # Find the numerical method if any
        method = find_method(variable)

        # Process the bounds
        if 'min' in variable['bounds'].keys():
            if isinstance(variable['bounds']['min'], str):
                translator = Equation(variable['name'], variable['bounds']['min'], description,
                                      type='return', untouched=untouched)
                variable['bounds']['min'] = translator.parse().replace(';', '')
                dependencies += translator.dependencies()

        if 'max' in variable['bounds'].keys():
            if isinstance(variable['bounds']['max'], str):
                translator = Equation(variable['name'], variable['bounds']['max'], description,
                                      type='return', untouched=untouched)
                variable['bounds']['max'] = translator.parse().replace(';', '')
                dependencies += translator.dependencies()

        # Analyse the equation
        if condition == []:  # No if-then-else
            translator = Equation(variable['name'], eq, description, method=method, untouched=untouched)
            code = translator.parse()
            dependencies += translator.dependencies()
        else:  # An if-then-else statement
            code, deps = translate_ITE(variable['name'], eq, condition, description, untouched)
            dependencies += deps

        # ODEs have a switch statement:
        #   double _r = (1.0 - r)/tau;
        #   r[i] += dt* _r;
        # while direct assignments are one-liners:
        #   r[i] = 1.0
        if isinstance(code, str):
            pre_loop = {}
            cpp_eq = code
            switch = None
        else:  # ODE
            pre_loop = code[0]
            cpp_eq = code[1]
            switch = code[2]

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            if prev.startswith('g_'):
                # Group 1 is the text preceding the name (up to the last '_');
                # it must be written back in every substitution, otherwise it
                # is deleted from the generated code.
                pattern = r'([^_]+)' + prev
                replacement = r'\1' + new
                cpp_eq = re.sub(pattern, replacement, ' ' + cpp_eq).strip()
                if len(pre_loop) > 0:
                    pre_loop['value'] = re.sub(pattern, replacement, ' ' + pre_loop['value']).strip()
                if switch:
                    switch = re.sub(pattern, replacement, ' ' + switch).strip()
            else:
                cpp_eq = re.sub(prev, new, cpp_eq)
                if len(pre_loop) > 0:
                    pre_loop['value'] = re.sub(prev, new, pre_loop['value'])
                if switch:
                    switch = re.sub(prev, new, switch)

        # Replace local functions
        for f in description['functions']:
            cpp_eq = re.sub(r'([^\w]*)' + f['name'] + r'\(', r'\1' + f['name'] + '(', ' ' + cpp_eq).strip()

        # Store the result
        variable['pre_loop'] = pre_loop  # Things to be declared before the for loop (eg. dt)
        variable['cpp'] = cpp_eq  # the C++ equation
        variable['switch'] = switch  # switch value of ODE
        variable['untouched'] = untouched  # may be needed later
        variable['method'] = method  # may be needed later
        variable['dependencies'] = list(set(dependencies))  # may be needed later

        # If the method is implicit or midpoint, the equations must be solved concurrently (depend on v[t+1])
        if method in ['implicit', 'midpoint', 'runge-kutta4'] and switch is not None:
            concurrent_odes.append(variable)

    # After all variables are processed, do it again if they are concurrent
    if len(concurrent_odes) > 1:
        solver = CoupledEquations(description, concurrent_odes)
        new_eqs = solver.parse()
        for idx, variable in enumerate(description['variables']):
            for new_eq in new_eqs:
                if variable['name'] == new_eq['name']:
                    description['variables'][idx] = new_eq

    return description
def analyse_synapse(synapse):
    """
    Parses the structure and generates code snippets for the synapse type.

    It returns a ``description`` dictionary with the following fields:

    * 'object': 'synapse' by default, to distinguish it from 'neuron'
    * 'type': either 'rate' or 'spike'
    * 'raw_parameters': provided field
    * 'raw_equations': provided field
    * 'raw_functions': provided field
    * 'raw_psp': provided field (defaults to "w*pre.r" for rate-coded synapses)
    * 'raw_pre_spike': provided field (spiking synapses only)
    * 'raw_post_spike': provided field (spiking synapses only)
    * 'plasticity': True when 'w' is updated by an equation or a pre/post spike statement
    * 'parameters': list of parameters defined for the synapse type
    * 'variables': list of variables defined for the synapse type
    * 'functions': list of functions defined for the synapse type
    * 'attributes': list of names of all parameters and variables
    * 'local': list of names of parameters and variables which are local to each synapse
    * 'semiglobal': list of names of parameters and variables which are local to each postsynaptic neuron
    * 'global': list of names of parameters and variables which are global to the projection
    * 'random_distributions': list of random number generators used in the synapse equations
    * 'global_operations': list of global operations (min/max/mean...) used in the equations
    * 'pre_global_operations': list of global operations (min/max/mean...) on the pre-synaptic population
    * 'post_global_operations': list of global operations (min/max/mean...) on the post-synaptic population
    * 'pre_spike': list of variables updated after a pre-spike event
    * 'post_spike': list of variables updated after a post-spike event
    * 'dependencies': dictionary ('pre', 'post') of lists of pre (resp. post) variables accessed by the synapse (used for delaying variables)
    * 'psp': dictionary ('eq', 'cpp', 'dependencies') for the psp code to be summed
    * 'pruning' and 'creating': statements for structural plasticity

    Each parameter is a dictionary with the following elements:

    * 'bounds': unused
    * 'ctype': type of the parameter: 'float', 'double', 'int' or 'bool'
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local', 'semiglobal' or 'global'
    * 'name': name of the parameter

    Each variable is a dictionary with the following elements:

    * 'bounds': dictionary of bounds ('init', 'min', 'max') provided after the :
    * 'cpp': C++ code snippet updating the variable
    * 'ctype': type of the variable: 'float', 'double', 'int' or 'bool'
    * 'dependencies': list of variable and parameter names on which the equation depends
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local', 'semiglobal' or 'global'
    * 'method': numerical method for ODEs
    * 'name': name of the variable
    * 'pre_loop': ODEs have a pre_loop term for precomputing dt/tau
    * 'switch': ODEs have a switch term
    * 'transformed_eq': same as eq, except special terms (sums, rds) are replaced with a temporary name
    * 'untouched': dictionary of special terms, with their new name as keys and replacement values as values.
    """
    # Store basic information
    description = {
        'object': 'synapse',
        'type': synapse.type,
        'raw_parameters': synapse.parameters,
        'raw_equations': synapse.equations,
        'raw_functions': synapse.functions
    }

    # Psps is what is actually summed over the incoming weights
    if synapse.psp:
        description['raw_psp'] = synapse.psp
    elif synapse.type == 'rate':
        description['raw_psp'] = "w*pre.r"

    # Spiking synapses additionally store pre_spike and post_spike
    if synapse.type == 'spike':
        description['raw_pre_spike'] = synapse.pre_spike
        description['raw_post_spike'] = synapse.post_spike

    # Extract parameters and variables names
    parameters = extract_parameters(synapse.parameters, synapse.extra_values)
    variables = extract_variables(synapse.equations)

    # Extract functions
    functions = extract_functions(synapse.functions, False)

    # Check the presence of w
    description['plasticity'] = False
    for var in parameters + variables:
        if var['name'] == 'w':
            break
    else:
        parameters.append({
            'name': 'w', 'bounds': {}, 'ctype': config['precision'],
            'init': 0.0, 'flags': [], 'eq': 'w=0.0', 'locality': 'local'
        })

    # Find out a plasticity rule
    for var in variables:
        if var['name'] == 'w':
            description['plasticity'] = True
            break

    # Build lists of all attributes (param+var), which are local or global
    attributes, local_var, global_var, semiglobal_var = get_attributes(parameters, variables, neuron=False)

    # Test if attributes are declared only once
    if len(attributes) != len(list(set(attributes))):
        _error('Attributes must be declared only once.', attributes)

    # Add this info to the description
    description['parameters'] = parameters
    description['variables'] = variables
    description['functions'] = functions
    description['attributes'] = attributes
    description['local'] = local_var
    description['semiglobal'] = semiglobal_var
    description['global'] = global_var
    description['global_operations'] = []

    # Lists of global operations needed at the pre and post populations
    description['pre_global_operations'] = []
    description['post_global_operations'] = []

    # Extract RandomDistribution objects
    description['random_distributions'] = extract_randomdist(description)

    # Extract event-driven info
    if description['type'] == 'spike':
        # pre_spike event
        description['pre_spike'] = extract_pre_spike_variable(description)
        for var in description['pre_spike']:
            if var['name'] in ['g_target']:  # Already dealt with
                continue
            for avar in description['variables']:
                if var['name'] == avar['name']:
                    break
            else:  # not defined already
                description['variables'].append({
                    'name': var['name'], 'bounds': var['bounds'], 'ctype': var['ctype'],
                    'init': var['init'], 'locality': var['locality'], 'flags': [],
                    'transformed_eq': '', 'eq': '', 'cpp': '', 'switch': '',
                    'pre_loop': '', 'untouched': '', 'method': 'explicit'
                })
                description['local'].append(var['name'])
                description['attributes'].append(var['name'])

        # post_spike event
        description['post_spike'] = extract_post_spike_variable(description)
        for var in description['post_spike']:
            if var['name'] in ['g_target', 'w']:  # Already dealt with
                continue
            for avar in description['variables']:
                if var['name'] == avar['name']:
                    break
            else:  # not defined already
                description['variables'].append({
                    'name': var['name'], 'bounds': var['bounds'], 'ctype': var['ctype'],
                    'init': var['init'], 'locality': var['locality'], 'flags': [],
                    'transformed_eq': '', 'eq': '', 'cpp': '', 'switch': '',
                    'untouched': '', 'method': 'explicit'
                })
                description['local'].append(var['name'])
                description['attributes'].append(var['name'])

    # Variables names for the parser which should be left untouched
    # (shared by all variables of the synapse)
    untouched = {}
    description['dependencies'] = {'pre': [], 'post': []}

    # The ODEs may be interdependent (implicit, midpoint), so they need to be passed explicitely to CoupledEquations
    concurrent_odes = []

    # Iterate over all variables
    for variable in description['variables']:
        # Equation
        eq = variable['transformed_eq']
        if eq.strip() == '':
            continue

        # Dependencies must be gathered
        dependencies = []

        # Extract global operations
        eq, untouched_globs, global_ops = extract_globalops_synapse(variable['name'], eq, description)
        description['pre_global_operations'] += global_ops['pre']
        description['post_global_operations'] += global_ops['post']
        # Remove doubled entries (the last occurrence of each entry is kept)
        description['pre_global_operations'] = [
            i for n, i in enumerate(description['pre_global_operations'])
            if i not in description['pre_global_operations'][n + 1:]
        ]
        description['post_global_operations'] = [
            i for n, i in enumerate(description['post_global_operations'])
            if i not in description['post_global_operations'][n + 1:]
        ]

        # Extract pre- and post_synaptic variables
        eq, untouched_var, prepost_dependencies = extract_prepost(variable['name'], eq, description)

        # Store the pre-post dependencies at the synapse level
        description['dependencies']['pre'] += prepost_dependencies['pre']
        description['dependencies']['post'] += prepost_dependencies['post']
        # and also on the variable for checking
        variable['prepost_dependencies'] = prepost_dependencies

        # Extract if-then-else statements
        eq, condition = extract_ite(variable['name'], eq, description)

        # Add the untouched variables to the global list
        for name, val in untouched_globs.items():
            if name not in untouched:
                untouched[name] = val
        for name, val in untouched_var.items():
            if name not in untouched:
                untouched[name] = val

        # Save the transformed equation
        variable['transformed_eq'] = eq

        # Find the numerical method if any
        method = find_method(variable)

        # Process the bounds
        if 'min' in variable['bounds'].keys():
            if isinstance(variable['bounds']['min'], str):
                translator = Equation(variable['name'], variable['bounds']['min'], description,
                                      type='return', untouched=untouched.keys())
                variable['bounds']['min'] = translator.parse().replace(';', '')
                dependencies += translator.dependencies()

        if 'max' in variable['bounds'].keys():
            if isinstance(variable['bounds']['max'], str):
                translator = Equation(variable['name'], variable['bounds']['max'], description,
                                      type='return', untouched=untouched.keys())
                variable['bounds']['max'] = translator.parse().replace(';', '')
                dependencies += translator.dependencies()

        # Analyse the equation
        if condition == []:  # Call Equation
            translator = Equation(variable['name'], eq, description, method=method, untouched=untouched.keys())
            code = translator.parse()
            dependencies += translator.dependencies()
        else:  # An if-then-else statement
            code, deps = translate_ITE(variable['name'], eq, condition, description, untouched)
            dependencies += deps

        if isinstance(code, str):
            pre_loop = {}
            cpp_eq = code
            switch = None
        else:  # ODE
            pre_loop = code[0]
            cpp_eq = code[1]
            switch = code[2]

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            cpp_eq = cpp_eq.replace(prev, new)

        # Replace local functions
        for f in description['functions']:
            cpp_eq = re.sub(r'([^\w]*)' + f['name'] + r'\(', r'\1' + f['name'] + '(', ' ' + cpp_eq).strip()

        # Store the result
        variable['pre_loop'] = pre_loop  # Things to be declared before the for loop (eg. dt)
        variable['cpp'] = cpp_eq  # the C++ equation
        variable['switch'] = switch  # switch value of ODE
        variable['untouched'] = untouched  # may be needed later
        variable['method'] = method  # may be needed later
        variable['dependencies'] = dependencies

        # If the method is implicit or midpoint, the equations must be solved concurrently (depend on v[t+1])
        if method in ['implicit', 'midpoint'] and switch is not None:
            concurrent_odes.append(variable)

    # After all variables are processed, do it again if they are concurrent
    if len(concurrent_odes) > 1:
        solver = CoupledEquations(description, concurrent_odes)
        new_eqs = solver.parse()
        for idx, variable in enumerate(description['variables']):
            for new_eq in new_eqs:
                if variable['name'] == new_eq['name']:
                    description['variables'][idx] = new_eq

    # Translate the psp code if any
    if 'raw_psp' in description.keys():
        psp = {'eq': description['raw_psp'].strip()}

        # Extract global operations
        eq, untouched_globs, global_ops = extract_globalops_synapse('psp', " " + psp['eq'] + " ", description)
        description['pre_global_operations'] += global_ops['pre']
        description['post_global_operations'] += global_ops['post']

        # Replace pre- and post_synaptic variables
        eq, untouched, prepost_dependencies = extract_prepost('psp', eq, description)
        description['dependencies']['pre'] += prepost_dependencies['pre']
        description['dependencies']['post'] += prepost_dependencies['post']
        for name, val in untouched_globs.items():
            if name not in untouched:
                untouched[name] = val

        # Extract if-then-else statements
        eq, condition = extract_ite('psp', eq, description, split=False)

        # Analyse the equation
        if condition == []:
            translator = Equation('psp', eq, description, method='explicit',
                                  untouched=untouched.keys(), type='return')
            code = translator.parse()
            deps = translator.dependencies()
        else:
            # The psp is an expression, not an assignment: the ITE was
            # extracted with split=False, so it must be translated the same way.
            code, deps = translate_ITE('psp', eq, condition, description, untouched, split=False)

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            code = code.replace(prev, new)

        # Store the result
        psp['cpp'] = code
        psp['dependencies'] = deps
        description['psp'] = psp

    # Process event-driven info
    if description['type'] == 'spike':
        for variable in description['pre_spike'] + description['post_spike']:
            # Find plasticity
            if variable['name'] == 'w':
                description['plasticity'] = True

            # Retrieve the equation
            eq = variable['eq']

            # Extract if-then-else statements
            eq, condition = extract_ite(variable['name'], eq, description)

            # Extract pre- and post_synaptic variables
            eq, untouched, prepost_dependencies = extract_prepost(variable['name'], eq, description)

            # Update dependencies
            description['dependencies']['pre'] += prepost_dependencies['pre']
            description['dependencies']['post'] += prepost_dependencies['post']
            # and also on the variable for checking
            variable['prepost_dependencies'] = prepost_dependencies

            # Analyse the equation
            dependencies = []
            if condition == []:
                translator = Equation(variable['name'], eq, description, method='explicit', untouched=untouched)
                code = translator.parse()
                dependencies += translator.dependencies()
            else:
                code, deps = translate_ITE(variable['name'], eq, condition, description, untouched)
                dependencies += deps

            if isinstance(code, list):  # an ode in a pre/post statement
                Global._print(eq)
                if variable in description['pre_spike']:
                    Global._error('It is forbidden to use ODEs in a pre_spike term.')
                elif variable in description['post_spike']:
                    Global._error('It is forbidden to use ODEs in a post_spike term.')
                else:
                    Global._error('It is forbidden to use ODEs here.')

            # Replace untouched variables with their original name
            for prev, new in untouched.items():
                code = code.replace(prev, new)

            # Process the bounds
            if 'min' in variable['bounds'].keys():
                if isinstance(variable['bounds']['min'], str):
                    translator = Equation(variable['name'], variable['bounds']['min'], description,
                                          type='return', untouched=untouched)
                    variable['bounds']['min'] = translator.parse().replace(';', '')
                    dependencies += translator.dependencies()

            if 'max' in variable['bounds'].keys():
                if isinstance(variable['bounds']['max'], str):
                    translator = Equation(variable['name'], variable['bounds']['max'], description,
                                          type='return', untouched=untouched)
                    variable['bounds']['max'] = translator.parse().replace(';', '')
                    dependencies += translator.dependencies()

            # Store the result
            variable['cpp'] = code  # the C++ equation
            variable['dependencies'] = dependencies

    # Structural plasticity
    if synapse.pruning:
        description['pruning'] = extract_structural_plasticity(synapse.pruning, description)
    if synapse.creating:
        description['creating'] = extract_structural_plasticity(synapse.creating, description)

    return description
class CoupledEquations(object):
    """Jointly solve several coupled ODEs of a synapse.

    Used for ODEs integrated with the implicit (or semi-implicit) Euler method
    or the midpoint method, whose updates may depend on the other coupled
    variables. The generated C++ snippets are written back into the
    ``'cpp'`` and ``'switch'`` fields of the variable dictionaries given to
    the constructor.
    """

    def __init__(self, description, variables):
        """
        :param description: synapse description dictionary (reads 'local' and
            'global').
        :param variables: non-empty list of variable dictionaries; each must
            provide 'name', 'method' and 'transformed_eq'. The 'untouched'
            dictionary of the first variable is used for all of them.
        """
        self.description = description
        self.variables = variables
        self.untouched = variables[0]['untouched']
        self.expression_list = {}
        for var in self.variables:
            self.expression_list[var['name']] = var['transformed_eq']
        self.names = list(self.expression_list.keys())
        self.local_variables = self.description['local']
        self.global_variables = self.description['global']
        self.local_dict = Equation('tmp', '', self.description,
                                   method='implicit',
                                   untouched=self.untouched).local_dict

    def parse(self):
        """Entry point used by analyse_synapse(); same as process_variables()."""
        return self.process_variables()

    def process_variables(self):
        """Dispatch to the solver matching the (common) numerical method.

        :returns: the list of variable dictionaries, updated in place.
        """
        methods = [var['method'] for var in self.variables]
        if len(set(methods)) > 1:  # mixture of methods
            _error(
                'Can not mix different numerical methods when solving a coupled system of equations.'
            )
            # Exit with a failure status in case _error() returns.
            exit(1)
        method = methods[0]
        if method in ('implicit', 'semiimplicit'):
            return self.solve_implicit(self.expression_list)
        if method == 'midpoint':
            return self.solve_midpoint(self.expression_list)
        # No joint resolution needed: leave the variables untouched.
        return self.variables

    @staticmethod
    def _zero_form(expression):
        """Rewrite 'lhs = rhs' as 'lhs-(rhs)' and remove all spaces."""
        if '=' in expression:
            expression = expression.replace('=', '- (') + ')'
        return expression.replace(' ', '')

    @staticmethod
    def _future_values(expression, names):
        """Prefix every standalone occurrence of each name with '_'.

        Lookarounds are used (instead of consuming the surrounding
        delimiters) so that adjacent occurrences such as 'v*v' and a name at
        the very start or end of the expression are all replaced.
        """
        for n in names:
            expression = re.sub(r'(?<!\w)' + re.escape(n) + r'(?!\w)',
                                '_' + n, expression)
        return expression

    def _restore_untouched(self, *snippets):
        """Replace the temporary untouched names by their original values."""
        restored = []
        for code in snippets:
            for prev, new in self.untouched.items():
                code = code.replace(prev, new)
            restored.append(code)
        return restored

    def solve_implicit(self, expression_list):
        """Solve the coupled system with the implicit Euler method.

        Each coupled variable x is replaced by its future value _x, and the
        gradient term by (_x - x). The resulting system is solved with sympy.
        The generated code is 'double _x = <solution>;' with the switch
        'x += _x;'.

        :param expression_list: dictionary name -> equation, updated in place
            (process_variables() passes self.expression_list).
        """
        equations = {}
        new_vars = {}
        transformations = standard_transformations + (convert_xor, )

        # Pre-processing to replace the gradient
        for name, expression in self.expression_list.items():
            expression = self._zero_form(expression)
            # Transform the gradient into a difference TODO: more robust...
            expression_list[name] = expression.replace('d' + name, '_t_gradient_')

        # Replace the variables by their future value
        for name, expression in expression_list.items():
            expression = self._future_values(expression, self.names)
            expression = expression.replace('_t_gradient_',
                                            '(_' + name + ' - ' + name + ')')
            expression_list[name] = expression + '-' + name
            new_var = Symbol('_' + name)
            self.local_dict['_' + name] = new_var
            new_vars[new_var] = name

        for name, expression in expression_list.items():
            equations[name] = parse_expr(expression,
                                         local_dict=self.local_dict,
                                         transformations=transformations)

        try:
            solution = solve(list(equations.values()), list(new_vars.keys()))
        except Exception:
            _error(
                'The multiple ODEs can not be solved together using the implicit Euler method.'
            )
            # Exit with a failure status in case _error() returns.
            exit(1)

        for var, sol in solution.items():
            target = new_vars[var]
            # Simplify the solution
            sol = collect(sol, self.local_dict['dt'])
            cpp_eq = 'double _' + target + ' = ' + ccode(sol) + ';'
            switch = ccode(self.local_dict[target]) + ' += _' + target + ';'
            cpp_eq, switch = self._restore_untouched(cpp_eq, switch)
            for variable in self.variables:
                if variable['name'] == target:
                    variable['cpp'] = cpp_eq
                    variable['switch'] = switch

        return self.variables

    def solve_midpoint(self, expression_list):
        """Solve the coupled system with the midpoint method.

        For each variable x: '_k_x = f(x)', then '_x = f(x + 0.5*dt*_k_x)',
        and the switch 'x += dt * _x'. 'cpp' receives the two declarations
        as a list.

        :param expression_list: ignored; expressions are rebuilt from
            self.expression_list.
        """
        expression_list = {}
        evaluations = {}
        transformations = standard_transformations + (convert_xor, )

        # Pre-processing to replace the gradient
        for name, expression in self.expression_list.items():
            expression = self._zero_form(expression)
            gradient = '_gradient_' + name
            # Transform the gradient into a symbol TODO: more robust...
            expression_list[name] = expression.replace('d' + name + '/dt', gradient)
            self.local_dict[gradient] = Symbol(gradient)

        for name, expression in expression_list.items():
            analysed = parse_expr(expression,
                                  local_dict=self.local_dict,
                                  transformations=transformations)
            evaluations[name] = solve(analysed,
                                      self.local_dict['_gradient_' + name])

        # Compute k = f(x, t)
        ks = {}
        for name, evaluation in evaluations.items():
            ks[name] = 'double _k_' + name + ' = ' + ccode(evaluation[0]) + ';'

        # New dictionary replacing x by x + dt/2*k
        tmp_dict = dict(self.local_dict)
        for name in evaluations:
            tmp_dict[name] = Symbol('(' + ccode(self.local_dict[name]) +
                                    ' + 0.5*dt*_k_' + name + ' )')

        # Compute the new values _x_new = f(x + dt/2*_k)
        news = {}
        for name, expression in expression_list.items():
            tmp_analysed = parse_expr(expression,
                                      local_dict=tmp_dict,
                                      transformations=transformations)
            solved = solve(tmp_analysed, self.local_dict['_gradient_' + name])
            news[name] = 'double _' + name + ' = ' + ccode(solved[0]) + ';'

        # Compute the switches
        switches = {}
        for name in expression_list:
            switches[name] = ccode(self.local_dict[name]) + ' += dt * _' + name + ' ;'

        # Store the generated code in the variables
        for name in self.names:
            k, n, switch = self._restore_untouched(ks[name], news[name],
                                                   switches[name])
            for variable in self.variables:
                if variable['name'] == name:
                    variable['cpp'] = [k, n]
                    variable['switch'] = switch

        return self.variables
def extract_structural_plasticity(statement, description):
    """Parse a structural-plasticity statement (pruning or creating).

    ``statement`` is a condition optionally followed by ``: flags``. At most
    one random distribution with constant numerical arguments may appear in
    the condition; it is replaced by ``rd(rng)`` in the parsed equation.
    Pre-/post-synaptic dependencies of the condition are appended to
    ``description['dependencies']``.

    :returns: dictionary with keys 'eq', 'cpp', 'bounds', 'flags', 'rd'
        (None or a description of the distribution) and 'dependencies'.
    """
    # Extract flags
    try:
        eq, constraint = statement.rsplit(':', 1)
        bounds, flags = extract_flags(constraint)
    except Exception:
        # No ':' separator (or unparsable flags): the whole statement is the condition.
        eq = statement.strip()
        bounds = {}
        flags = []

    # Extract the random distribution, if any
    rd = None
    for dist in available_distributions:
        pattern = r'(?P<pre>[^\w.])' + dist + r'\(([^()]+)\)'
        # The leading space lets a distribution at the very start of eq match.
        for _, args in re.findall(pattern, ' ' + eq):
            arguments = args.split(',')
            expected = distributions_arguments[dist]
            # Check the number of provided arguments
            if len(arguments) < expected:
                Global._print(eq)
                Global._error('The distribution ' + dist + ' requires ' +
                              str(expected) + ' parameters')
            elif len(arguments) > expected:
                Global._print(eq)
                Global._error(
                    'Too many parameters provided to the distribution ' + dist)

            # Only constant arguments are accepted
            values = []
            for argument in arguments:
                try:
                    values.append(str(float(argument)))
                except ValueError:
                    Global._print(eq)
                    Global._error(
                        'Random distributions for creating/pruning synapses must use fixed values.'
                    )
            processed_arguments = ', '.join(values)
            definition = distributions_equivalents[dist] + '(' + processed_arguments + ')'

            # Store its definition
            if rd:
                Global._print(eq)
                Global._error(
                    'Only one random distribution per equation is allowed.')
            rd = {
                'name': 'rand_' + str(0),
                'origin': dist + '(' + args + ')',
                'dist': dist,
                'definition': definition,
                'args': processed_arguments,
                'template': distributions_equivalents[dist]
            }

    if rd:
        eq = eq.replace(rd['origin'], 'rd(rng)')

    # Extract pre/post dependencies
    eq, untouched, dependencies = extract_prepost('test', eq, description)

    # Parse code
    # NOTE(review): conditions elsewhere use type='cond'; confirm that
    # method='cond' is what Equation expects here.
    translator = Equation('test', eq, description, method='cond', untouched={})
    code = translator.parse()
    deps = translator.dependencies()

    # Replace untouched variables with their original name
    for prev, new in untouched.items():
        code = code.replace(prev, new)

    # Add new dependencies
    description['dependencies']['pre'].extend(dependencies['pre'])
    description['dependencies']['post'].extend(dependencies['post'])

    return {
        'eq': eq,
        'cpp': code,
        'bounds': bounds,
        'flags': flags,
        'rd': rd,
        'dependencies': deps
    }
def _synapse_translate_bounds(variable, description, untouched):
    """Translate string-valued 'min'/'max' bounds of a variable into C++ in place; return their dependencies."""
    dependencies = []
    for bound in ('min', 'max'):
        value = variable['bounds'].get(bound)
        if isinstance(value, str):
            translator = Equation(variable['name'], value, description,
                                  type='return', untouched=untouched)
            variable['bounds'][bound] = translator.parse().replace(';', '')
            dependencies += translator.dependencies()
    return dependencies


def _synapse_register_event_variables(description, events, skipped):
    """Declare event-driven variables not yet known as synapse variables (names in ``skipped`` are ignored)."""
    for var in events:
        if var['name'] in skipped:  # Already dealt with
            continue
        if any(known['name'] == var['name'] for known in description['variables']):
            continue
        # Not defined already: add it as a local variable without equation
        description['variables'].append({
            'name': var['name'],
            'bounds': var['bounds'],
            'ctype': var['ctype'],
            'init': var['init'],
            'locality': var['locality'],
            'flags': [],
            'transformed_eq': '',
            'eq': '',
            'cpp': '',
            'switch': '',
            'pre_loop': {},
            'untouched': '',
            'method': 'explicit'
        })
        description['local'].append(var['name'])
        description['attributes'].append(var['name'])


def analyse_synapse(synapse):
    """
    Parses the structure and generates code snippets for the synapse type.

    It returns a ``description`` dictionary with the following fields:

    * 'object': 'synapse' by default, to distinguish it from 'neuron'
    * 'type': either 'rate' or 'spike'
    * 'raw_parameters': provided field
    * 'raw_equations': provided field
    * 'raw_functions': provided field
    * 'raw_psp': provided field (defaults to "w*pre.r" for rate-coded synapses)
    * 'raw_pre_spike': provided field
    * 'raw_post_spike': provided field
    * 'plasticity': True when w is a variable or is updated in a pre/post-spike statement
    * 'parameters': list of parameters defined for the synapse type
    * 'variables': list of variables defined for the synapse type
    * 'functions': list of functions defined for the synapse type
    * 'attributes': list of names of all parameters and variables
    * 'local': list of names of parameters and variables which are local to each synapse
    * 'semiglobal': list of names of parameters and variables which are local to each postsynaptic neuron
    * 'global': list of names of parameters and variables which are global to the projection
    * 'random_distributions': list of random number generators used in the synapse equations
    * 'global_operations': list of global operations (min/max/mean...) used in the equations
    * 'pre_global_operations': list of global operations (min/max/mean...) on the pre-synaptic population
    * 'post_global_operations': list of global operations (min/max/mean...) on the post-synaptic population
    * 'pre_spike': list of variables updated after a pre-spike event
    * 'post_spike': list of variables updated after a post-spike event
    * 'dependencies': dictionary ('pre', 'post') of lists of pre (resp. post) variables accessed by the synapse (used for delaying variables)
    * 'psp': dictionary with 'eq' (raw expression), 'cpp' (generated code) and 'dependencies', for the psp to be summed
    * 'pruning' and 'creating': statements for structural plasticity

    Each parameter is a dictionary with the following elements:

    * 'bounds': unused
    * 'ctype': type of the parameter: 'float', 'double', 'int' or 'bool'
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local', 'semiglobal' or 'global'
    * 'name': name of the parameter

    Each variable is a dictionary with the following elements:

    * 'bounds': dictionary of bounds ('init', 'min', 'max') provided after the :
    * 'cpp': C++ code snippet updating the variable
    * 'ctype': type of the variable: 'float', 'double', 'int' or 'bool'
    * 'dependencies': list of variable and parameter names on which the equation depends
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local', 'semiglobal' or 'global'
    * 'method': numerical method for ODEs
    * 'name': name of the variable
    * 'pre_loop': ODEs have a pre_loop term for precomputing dt/tau
    * 'prepost_dependencies': pre/post-synaptic dependencies of the equation
    * 'switch': ODEs have a switch term
    * 'transformed_eq': same as eq, except special terms (sums, rds) are replaced with a temporary name
    * 'untouched': dictionary of special terms, with their new name as keys and replacement values as values.
    """
    # Store basic information
    description = {
        'object': 'synapse',
        'type': synapse.type,
        'raw_parameters': synapse.parameters,
        'raw_equations': synapse.equations,
        'raw_functions': synapse.functions
    }

    # Psps is what is actually summed over the incoming weights
    if synapse.psp:
        description['raw_psp'] = synapse.psp
    elif synapse.type == 'rate':
        description['raw_psp'] = "w*pre.r"

    # Spiking synapses additionally store pre_spike and post_spike
    if synapse.type == 'spike':
        description['raw_pre_spike'] = synapse.pre_spike
        description['raw_post_spike'] = synapse.post_spike

    # Extract parameters and variables names
    parameters = extract_parameters(synapse.parameters, synapse.extra_values)
    variables = extract_variables(synapse.equations)

    # Extract functions
    functions = extract_functions(synapse.functions, False)

    # Check the presence of w
    description['plasticity'] = False
    if not any(var['name'] == 'w' for var in parameters + variables):
        parameters.append({
            'name': 'w',
            'bounds': {},
            'ctype': config['precision'],
            'init': 0.0,
            'flags': [],
            'eq': 'w=0.0',
            'locality': 'local'
        })

    # Find out a plasticity rule
    if any(var['name'] == 'w' for var in variables):
        description['plasticity'] = True

    # Build lists of all attributes (param+var), which are local or global
    attributes, local_var, global_var, semiglobal_var = get_attributes(
        parameters, variables, neuron=False)

    # Test if attributes are declared only once
    if len(attributes) != len(list(set(attributes))):
        _error('Attributes must be declared only once.', attributes)

    # Add this info to the description
    description['parameters'] = parameters
    description['variables'] = variables
    description['functions'] = functions
    description['attributes'] = attributes
    description['local'] = local_var
    description['semiglobal'] = semiglobal_var
    description['global'] = global_var
    description['global_operations'] = []

    # Lists of global operations needed at the pre and post populations
    description['pre_global_operations'] = []
    description['post_global_operations'] = []

    # Extract RandomDistribution objects
    description['random_distributions'] = extract_randomdist(description)

    # Extract event-driven info
    if description['type'] == 'spike':
        description['pre_spike'] = extract_pre_spike_variable(description)
        _synapse_register_event_variables(description,
                                          description['pre_spike'],
                                          ['g_target'])
        description['post_spike'] = extract_post_spike_variable(description)
        _synapse_register_event_variables(description,
                                          description['post_spike'],
                                          ['g_target', 'w'])

    # Variables names for the parser which should be left untouched
    untouched = {}
    description['dependencies'] = {'pre': [], 'post': []}

    # The ODEs may be interdependent (implicit, midpoint), so they need to be
    # passed explicitly to CoupledEquations
    concurrent_odes = []

    # Iterate over all variables
    for variable in description['variables']:
        eq = variable['transformed_eq']
        if eq.strip() == '':
            continue

        # Dependencies must be gathered
        dependencies = []

        # Extract global operations
        eq, untouched_globs, global_ops = extract_globalops_synapse(
            variable['name'], eq, description)
        description['pre_global_operations'] += global_ops['pre']
        description['post_global_operations'] += global_ops['post']
        # Remove doubled entries (the last occurrence of each is kept)
        pre_ops = description['pre_global_operations']
        description['pre_global_operations'] = [
            i for n, i in enumerate(pre_ops) if i not in pre_ops[n + 1:]]
        post_ops = description['post_global_operations']
        description['post_global_operations'] = [
            i for n, i in enumerate(post_ops) if i not in post_ops[n + 1:]]

        # Extract pre- and post_synaptic variables
        eq, untouched_var, prepost_dependencies = extract_prepost(
            variable['name'], eq, description)
        description['dependencies']['pre'] += prepost_dependencies['pre']
        description['dependencies']['post'] += prepost_dependencies['post']
        # and also on the variable for checking
        variable['prepost_dependencies'] = prepost_dependencies

        # Extract if-then-else statements
        eq, condition = extract_ite(variable['name'], eq, description)

        # Add the untouched variables to the global list
        for name, val in untouched_globs.items():
            untouched.setdefault(name, val)
        for name, val in untouched_var.items():
            untouched.setdefault(name, val)

        # Save the transformed equation
        variable['transformed_eq'] = eq

        # Find the numerical method if any
        method = find_method(variable)

        # Process the bounds
        dependencies += _synapse_translate_bounds(variable, description,
                                                  untouched.keys())

        # Analyse the equation
        if condition == []:
            translator = Equation(variable['name'], eq, description,
                                  method=method, untouched=untouched.keys())
            code = translator.parse()
            dependencies += translator.dependencies()
        else:  # An if-then-else statement
            code, deps = translate_ITE(variable['name'], eq, condition,
                                       description, untouched)
            dependencies += deps

        if isinstance(code, str):
            pre_loop = {}
            cpp_eq = code
            switch = None
        else:  # ODE
            pre_loop = code[0]
            cpp_eq = code[1]
            switch = code[2]

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            cpp_eq = cpp_eq.replace(prev, new)

        # Replace local functions (as written, this only trims surrounding whitespace)
        for f in description['functions']:
            cpp_eq = re.sub(r'([^\w]*)' + f['name'] + r'\(',
                            r'\1' + f['name'] + '(', ' ' + cpp_eq).strip()

        # Store the result
        variable['pre_loop'] = pre_loop  # Things to be declared before the for loop (eg. dt)
        variable['cpp'] = cpp_eq  # the C++ equation
        variable['switch'] = switch  # switch value if ODE
        variable['untouched'] = untouched  # may be needed later
        variable['method'] = method  # may be needed later
        variable['dependencies'] = dependencies

        # Implicit or midpoint ODEs must be solved concurrently (depend on v[t+1])
        if method in ['implicit', 'midpoint'] and switch is not None:
            concurrent_odes.append(variable)

    # After all variables are processed, do it again if they are concurrent
    if len(concurrent_odes) > 1:
        solver = CoupledEquations(description, concurrent_odes)
        new_eqs = solver.process_variables() or []
        replacements = {new_eq['name']: new_eq for new_eq in new_eqs}
        for idx, variable in enumerate(description['variables']):
            if variable['name'] in replacements:
                description['variables'][idx] = replacements[variable['name']]

    # Translate the psp code if any
    if 'raw_psp' in description:
        psp = {'eq': description['raw_psp'].strip()}

        # Extract global operations
        eq, untouched_globs, global_ops = extract_globalops_synapse(
            'psp', " " + psp['eq'] + " ", description)
        description['pre_global_operations'] += global_ops['pre']
        description['post_global_operations'] += global_ops['post']

        # Replace pre- and post_synaptic variables
        eq, untouched, prepost_dependencies = extract_prepost('psp', eq, description)
        description['dependencies']['pre'] += prepost_dependencies['pre']
        description['dependencies']['post'] += prepost_dependencies['post']
        for name, val in untouched_globs.items():
            untouched.setdefault(name, val)

        # Extract if-then-else statements
        eq, condition = extract_ite('psp', eq, description, split=False)

        # Analyse the equation
        if condition == []:
            translator = Equation('psp', eq, description, method='explicit',
                                  untouched=untouched.keys(), type='return')
            code = translator.parse()
            deps = translator.dependencies()
        else:
            code, deps = translate_ITE('psp', eq, condition, description,
                                       untouched)

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            code = code.replace(prev, new)

        # Store the result
        psp['cpp'] = code
        psp['dependencies'] = deps
        description['psp'] = psp

    # Process event-driven info
    if description['type'] == 'spike':
        for variable in description['pre_spike'] + description['post_spike']:
            # Find plasticity
            if variable['name'] == 'w':
                description['plasticity'] = True

            # Retrieve the equation and extract if-then-else statements
            eq = variable['eq']
            eq, condition = extract_ite(variable['name'], eq, description)

            # Extract pre- and post_synaptic variables
            eq, untouched, prepost_dependencies = extract_prepost(
                variable['name'], eq, description)
            description['dependencies']['pre'] += prepost_dependencies['pre']
            description['dependencies']['post'] += prepost_dependencies['post']
            # and also on the variable for checking
            variable['prepost_dependencies'] = prepost_dependencies

            # Analyse the equation
            dependencies = []
            if condition == []:
                translator = Equation(variable['name'], eq, description,
                                      method='explicit', untouched=untouched)
                code = translator.parse()
                dependencies += translator.dependencies()
            else:
                code, deps = translate_ITE(variable['name'], eq, condition,
                                           description, untouched)
                dependencies += deps

            if isinstance(code, list):  # an ode in a pre/post statement
                Global._print(eq)
                if variable in description['pre_spike']:
                    Global._error('It is forbidden to use ODEs in a pre_spike term.')
                elif variable in description['post_spike']:
                    Global._error('It is forbidden to use ODEs in a post_spike term.')
                else:
                    Global._error('It is forbidden to use ODEs here.')

            # Replace untouched variables with their original name
            for prev, new in untouched.items():
                code = code.replace(prev, new)

            # Process the bounds
            dependencies += _synapse_translate_bounds(variable, description,
                                                      untouched)

            # Store the result
            variable['cpp'] = code  # the C++ equation
            variable['dependencies'] = dependencies

    # Structural plasticity
    if synapse.pruning:
        description['pruning'] = extract_structural_plasticity(
            synapse.pruning, description)
    if synapse.creating:
        description['creating'] = extract_structural_plasticity(
            synapse.creating, description)

    return description
class CoupledEquations(object):
    """Jointly solve several coupled ODEs of a synapse.

    Used for ODEs integrated with the implicit (or semi-implicit) Euler method
    or the midpoint method, whose updates may depend on the other coupled
    variables. The generated C++ snippets are written back into the
    ``'cpp'`` and ``'switch'`` fields of the variable dictionaries given to
    the constructor.
    """

    def __init__(self, description, variables):
        """
        :param description: synapse description dictionary (reads 'local' and
            'global').
        :param variables: non-empty list of variable dictionaries; each must
            provide 'name', 'method' and 'transformed_eq'. The 'untouched'
            dictionary of the first variable is used for all of them.
        """
        self.description = description
        self.variables = variables
        self.untouched = variables[0]['untouched']
        self.expression_list = {}
        for var in self.variables:
            self.expression_list[var['name']] = var['transformed_eq']
        self.names = list(self.expression_list.keys())
        self.local_variables = self.description['local']
        self.global_variables = self.description['global']
        self.local_dict = Equation('tmp', '', self.description,
                                   method='implicit',
                                   untouched=self.untouched).local_dict

    def parse(self):
        """Entry point used by analyse_synapse(); same as process_variables()."""
        return self.process_variables()

    def process_variables(self):
        """Dispatch to the solver matching the (common) numerical method.

        :returns: the list of variable dictionaries, updated in place.
        """
        methods = [var['method'] for var in self.variables]
        if len(set(methods)) > 1:  # mixture of methods
            _print(methods)
            _error('Can not mix different numerical methods when solving a coupled system of equations.')
            # In case _error() returns: keep the individually generated code.
            return self.variables
        method = methods[0]
        if method in ('implicit', 'semiimplicit'):
            return self.solve_implicit(self.expression_list)
        if method == 'midpoint':
            return self.solve_midpoint(self.expression_list)
        # No joint resolution needed: leave the variables untouched.
        return self.variables

    @staticmethod
    def _zero_form(expression):
        """Rewrite 'lhs = rhs' as 'lhs-(rhs)' and remove all spaces."""
        if '=' in expression:
            expression = expression.replace('=', '- (') + ')'
        return expression.replace(' ', '')

    @staticmethod
    def _future_values(expression, names):
        """Prefix every standalone occurrence of each name with '_'.

        Lookarounds are used (instead of consuming the surrounding
        delimiters) so that adjacent occurrences such as 'v*v' and a name at
        the very start or end of the expression are all replaced.
        """
        for n in names:
            expression = re.sub(r'(?<!\w)' + re.escape(n) + r'(?!\w)',
                                '_' + n, expression)
        return expression

    def _restore_untouched(self, *snippets):
        """Replace the temporary untouched names by their original values."""
        restored = []
        for code in snippets:
            for prev, new in self.untouched.items():
                code = code.replace(prev, new)
            restored.append(code)
        return restored

    def solve_implicit(self, expression_list):
        """Solve the coupled system with the implicit Euler method.

        Each coupled variable x is replaced by its future value _x, and the
        gradient term by (_x - x). The resulting system is solved with sympy.
        The generated code is 'double _x = <solution>;' with the switch
        'x += _x;'.

        :param expression_list: dictionary name -> equation, updated in place
            (process_variables() passes self.expression_list).
        """
        equations = {}
        new_vars = {}
        transformations = standard_transformations + (convert_xor, )

        # Pre-processing to replace the gradient
        for name, expression in self.expression_list.items():
            expression = self._zero_form(expression)
            # Transform the gradient into a difference TODO: more robust...
            expression_list[name] = expression.replace('d' + name, '_t_gradient_')

        # Replace the variables by their future value
        for name, expression in expression_list.items():
            expression = self._future_values(expression, self.names)
            expression = expression.replace('_t_gradient_',
                                            '(_' + name + ' - ' + name + ')')
            expression_list[name] = expression + '-' + name
            new_var = Symbol('_' + name)
            self.local_dict['_' + name] = new_var
            new_vars[new_var] = name

        for name, expression in expression_list.items():
            equations[name] = parse_expr(expression,
                                         local_dict=self.local_dict,
                                         transformations=transformations)

        try:
            solution = solve(list(equations.values()), list(new_vars.keys()))
        except Exception:
            _print(expression_list)
            _error('The multiple ODEs can not be solved together using the implicit Euler method.')
            # In case _error() returns: keep the individually generated code.
            return self.variables

        for var, sol in solution.items():
            target = new_vars[var]
            # Simplify the solution
            sol = collect(sol, self.local_dict['dt'])
            cpp_eq = 'double _' + target + ' = ' + ccode(sol) + ';'
            switch = ccode(self.local_dict[target]) + ' += _' + target + ';'
            cpp_eq, switch = self._restore_untouched(cpp_eq, switch)
            for variable in self.variables:
                if variable['name'] == target:
                    variable['cpp'] = cpp_eq
                    variable['switch'] = switch

        return self.variables

    def solve_midpoint(self, expression_list):
        """Solve the coupled system with the midpoint method.

        For each variable x: '_k_x = f(x)', then '_x = f(x + 0.5*dt*_k_x)',
        and the switch 'x += dt * _x'. 'cpp' receives the two declarations
        as a list.

        :param expression_list: ignored; expressions are rebuilt from
            self.expression_list.
        """
        expression_list = {}
        evaluations = {}
        transformations = standard_transformations + (convert_xor, )

        # Pre-processing to replace the gradient
        for name, expression in self.expression_list.items():
            expression = self._zero_form(expression)
            gradient = '_gradient_' + name
            # Transform the gradient into a symbol TODO: more robust...
            expression_list[name] = expression.replace('d' + name + '/dt', gradient)
            self.local_dict[gradient] = Symbol(gradient)

        for name, expression in expression_list.items():
            analysed = parse_expr(expression,
                                  local_dict=self.local_dict,
                                  transformations=transformations)
            evaluations[name] = solve(analysed,
                                      self.local_dict['_gradient_' + name])

        # Compute k = f(x, t)
        ks = {}
        for name, evaluation in evaluations.items():
            ks[name] = 'double _k_' + name + ' = ' + ccode(evaluation[0]) + ';'

        # New dictionary replacing x by x + dt/2*k
        tmp_dict = dict(self.local_dict)
        for name in evaluations:
            tmp_dict[name] = Symbol('(' + ccode(self.local_dict[name]) +
                                    ' + 0.5*dt*_k_' + name + ' )')

        # Compute the new values _x_new = f(x + dt/2*_k)
        news = {}
        for name, expression in expression_list.items():
            tmp_analysed = parse_expr(expression,
                                      local_dict=tmp_dict,
                                      transformations=transformations)
            solved = solve(tmp_analysed, self.local_dict['_gradient_' + name])
            news[name] = 'double _' + name + ' = ' + ccode(solved[0]) + ';'

        # Compute the switches
        switches = {}
        for name in expression_list:
            switches[name] = ccode(self.local_dict[name]) + ' += dt * _' + name + ' ;'

        # Store the generated code in the variables
        for name in self.names:
            k, n, switch = self._restore_untouched(ks[name], news[name],
                                                   switches[name])
            for variable in self.variables:
                if variable['name'] == name:
                    variable['cpp'] = [k, n]
                    variable['switch'] = switch

        return self.variables