예제 #1
0
파일: ITE.py 프로젝트: neeraj2296/ANNarchy
    def process_ITE(condition):
        """Build the C++ ternary expression for one if-then-else triple.

        ``condition`` is indexed as ``[if, then, else]``. The ``then`` and
        ``else`` entries may themselves be lists (nested conditionals), which
        are translated recursively. ``name``, ``description`` and
        ``untouched`` are taken from the enclosing scope.
        """
        test_part, then_part, else_part = condition[0], condition[1], condition[2]

        def branch(statement):
            # Nested conditional: recurse. Otherwise parse it as a 'return'
            # expression and keep only the text before the first ';'.
            if isinstance(statement, list):
                return process_ITE(statement)
            parsed = Equation(name,
                              statement,
                              description,
                              untouched=untouched.keys(),
                              type='return').parse()
            return parsed.split(';')[0]

        test_cpp = Equation(name,
                            test_part,
                            description,
                            untouched=untouched.keys(),
                            type='cond').parse()
        then_cpp = branch(then_part)
        else_cpp = branch(else_part)
        return ''.join(['(', test_cpp, ' ? ', then_cpp, ' : ', else_cpp, ')'])
예제 #2
0
파일: ITE.py 프로젝트: neeraj2296/ANNarchy
def translate_ITE(name, eq, condition, description, untouched, split=True):
    """Translate an equation containing an if-then-else (ITE) statement.

    Args:
        name: name of the variable the equation belongs to.
        eq: equation whose ITE part was replaced by ``__conditional__``;
            only parsed when ``split`` is True.
        condition: ``[if, then, else]`` list; ``then``/``else`` may be
            nested lists of the same form.
        description: object description forwarded to ``Equation``.
        untouched: dict of special terms; its keys are forwarded to ``Equation``.
        split: when False, ``eq`` is ignored and the result is the bare
            ternary expression.

    Returns:
        The parsed code with ``__conditional__`` replaced by the ternary.
        If the parsed code is not a string, its first element is patched
        in place and the same object is returned.
    """
    def _ternary(triple):
        # Recursively turn an [if, then, else] triple into '(c ? a : b)'.
        test_part, then_part, else_part = triple[0], triple[1], triple[2]

        def _branch(part):
            if isinstance(part, list):  # nested conditional
                return _ternary(part)
            return Equation(name,
                            part,
                            description,
                            untouched=untouched.keys(),
                            type='return').parse().split(';')[0]

        test_cpp = Equation(name,
                            test_part,
                            description,
                            untouched=untouched.keys(),
                            type='cond').parse()
        pieces = ['(', test_cpp, ' ? ', _branch(then_part), ' : ',
                  _branch(else_part), ')']
        return ''.join(pieces)

    placeholder = '__conditional__'

    # The main equation is parsed first; its right-hand side is the placeholder.
    if not split:
        code = placeholder
    else:
        code = Equation(name,
                        eq,
                        description,
                        untouched=untouched.keys()).parse()

    ternary = _ternary(condition)

    if not isinstance(code, str):
        code[0] = code[0].replace(placeholder, ternary)
        return code
    return code.replace(placeholder, ternary)
예제 #3
0
def extract_spike_variable(description):
    """Translate the spike condition and reset statements of a spiking neuron.

    Args:
        description: neuron description. 'raw_spike' is always read;
            'raw_reset' is processed when present and non-empty.

    Returns:
        dict with:
        * 'spike_cond': parsed code of the spike condition,
        * 'spike_cond_dependencies': names the condition depends on (kept
          because CUDA code generation may need them),
        * 'spike_reset': list of reset statements from ``process_equations``,
          each augmented in place with 'cpp' and 'dependencies'.
    """
    cond = prepare_string(description['raw_spike'])
    if len(cond) > 1:
        # Show the offending text, then report the error. The previous
        # ``Global.Global._print`` did not match the ``Global._error`` call
        # below and would fail with an AttributeError on this error path.
        Global._print(description['raw_spike'])
        Global._error('The spike condition must be a single expression')

    translator = Equation('raw_spike_cond', cond[0].strip(), description)
    raw_spike_code = translator.parse()
    # Also store the variables used in the condition, as it may be needed for CUDA generation
    spike_code_dependencies = translator.dependencies()

    reset_desc = []
    if description.get('raw_reset'):
        reset_desc = process_equations(description['raw_reset'])
        for var in reset_desc:
            translator = Equation(var['name'], var['eq'], description)
            var['cpp'] = translator.parse()
            var['dependencies'] = translator.dependencies()

    return {
        'spike_cond': raw_spike_code,
        'spike_cond_dependencies': spike_code_dependencies,
        'spike_reset': reset_desc
    }
예제 #4
0
def extract_stop_condition(pop):
    """Translate a population's stop condition into a parenthesised C++ test.

    Reads ``pop['stop_condition']['eq']``. Text between the first and second
    ':' is treated as space-separated flags. The flag ``all`` sets
    ``pop['stop_condition']['type']`` to 'all' (otherwise 'any'). The
    generated code is stored in ``pop['stop_condition']['cpp']``. ``pop`` is
    modified in place and nothing is returned.
    """
    expression = pop['stop_condition']['eq']
    pop['stop_condition']['type'] = 'any'

    pieces = expression.split(':')
    if len(pieces) > 1:  # flags were provided after ':'
        expression = pieces[0]
        flag_words = pieces[1].strip().split(' ')
        if any(word.strip() == 'all' for word in flag_words):
            pop['stop_condition']['type'] = 'all'

    # Parse the remaining expression as a boolean condition
    parsed = Equation('stop_cond', expression, pop, type='cond').parse()
    pop['stop_condition']['cpp'] = '(' + parsed + ')'
예제 #5
0
def extract_stop_condition(pop):
    """Convert ``pop['stop_condition']['eq']`` into C++ code.

    The optional flag section (between the first and second ':') may contain
    the word ``all``. It switches ``pop['stop_condition']['type']`` from the
    default 'any' to 'all'. The parenthesised C++ condition is written to
    ``pop['stop_condition']['cpp']``.
    """
    raw = pop['stop_condition']['eq']
    pop['stop_condition']['type'] = 'any'

    sections = raw.split(':')
    condition_text = raw
    if len(sections) > 1:
        condition_text = sections[0]
        for token in sections[1].strip().split(' '):
            if token.strip() == 'all':
                pop['stop_condition']['type'] = 'all'

    translator = Equation('stop_cond', condition_text, pop, type='cond')
    cpp = translator.parse()
    pop['stop_condition']['cpp'] = '(' + cpp + ')'
예제 #6
0
파일: ITE.py 프로젝트: vitay/ANNarchy
def translate_ITE(name, eq, condition, description, untouched, split=True):
    """Recursively translate the if-then-else (ITE) statements of an equation.

    Args:
        name: name of the variable the equation belongs to.
        eq: equation in which the i-th ITE was replaced by the placeholder
            ``__conditional__<i>``.
        condition: list of ``[if, then, else]`` triples, one per placeholder,
            where ``then``/``else`` may be nested triples.
        description: object description forwarded to ``Equation``.
        untouched: dict of special terms; its keys are forwarded to ``Equation``.
        split: when True, ``eq`` is parsed by ``Equation`` first; otherwise
            it is used as-is.

    Returns:
        The code with every placeholder substituted. If the code is not a
        string, its first element is patched in place.
    """
    def process_condition(condition):
        # Translate one [if, then, else] triple into '(c ? a : b)'.
        if_statement = condition[0]
        then_statement = condition[1]
        else_statement = condition[2]

        if_code = Equation(name, if_statement, description,
                           untouched=untouched.keys(),
                           type='cond').parse()
        if isinstance(then_statement, list):  # nested conditional
            then_code = process_condition(then_statement)
        else:
            then_code = Equation(name, then_statement, description,
                                 untouched=untouched.keys(),
                                 type='return').parse().split(';')[0]
        if isinstance(else_statement, list):  # nested conditional
            else_code = process_condition(else_statement)
        else:
            else_code = Equation(name, else_statement, description,
                                 untouched=untouched.keys(),
                                 type='return').parse().split(';')[0]

        return '(' + if_code + ' ? ' + then_code + ' : ' + else_code + ')'

    if split:
        # Main equation, where the right part contains the placeholders
        translator = Equation(name, eq, description,
                              untouched=untouched.keys())
        code = translator.parse()
    else:
        code = eq

    # Translate every ITE in order.
    itecodes = [process_condition(cond) for cond in condition]

    # Substitute from the highest index down. '__conditional__1' is a prefix
    # of '__conditional__10', so ascending order would corrupt the
    # higher-numbered placeholders when there are more than ten ITEs.
    for idx in reversed(range(len(itecodes))):
        placeholder = '__conditional__' + str(idx)
        if isinstance(code, str):
            code = code.replace(placeholder, itecodes[idx])
        else:
            code[0] = code[0].replace(placeholder, itecodes[idx])

    return code
예제 #7
0
    def __init__(self, description, variables):
        """Store the description and the coupled variables to be solved together.

        Args:
            description: object description; its 'local' and 'global'
                entries are read.
            variables: list of variable dicts, each with 'name' and
                'transformed_eq'; the first one must also have 'untouched'.
        """
        self.description = description
        self.variables = variables

        # The untouched terms are taken from the first variable.
        self.untouched = variables[0]['untouched']

        # Variable name -> transformed equation.
        self.expression_list = {
            var['name']: var['transformed_eq'] for var in self.variables
        }
        self.names = self.expression_list.keys()

        self.local_variables = self.description['local']
        self.global_variables = self.description['global']

        # A placeholder 'tmp' Equation is created only to obtain its local_dict.
        scratch = Equation('tmp',
                           '',
                           self.description,
                           method='implicit',
                           untouched=self.untouched)
        self.local_dict = scratch.local_dict
예제 #8
0
def extract_spike_variable(description):
    """Translate the spike condition and reset statements of a spiking neuron.

    Args:
        description: neuron description. 'raw_spike' is always read;
            'raw_reset' is processed when present and non-empty.

    Returns:
        dict with 'spike_cond' (parsed condition code),
        'spike_cond_dependencies' (names the condition depends on, kept for
        CUDA generation) and 'spike_reset' (reset statements from
        ``process_equations``, each augmented with 'cpp' and 'dependencies').
    """
    cond = prepare_string(description['raw_spike'])
    if len(cond) > 1:
        # Print the offending text, then report the error. ``Global._print``
        # matches the ``Global._error`` call below; the previous
        # ``Global.Global._print`` would raise AttributeError here.
        Global._print(description['raw_spike'])
        Global._error('The spike condition must be a single expression')

    translator = Equation('raw_spike_cond',
                          cond[0].strip(),
                          description)
    raw_spike_code = translator.parse()
    # Also store the variables used in the condition, as it may be needed for CUDA generation
    spike_code_dependencies = translator.dependencies()

    reset_desc = []
    if description.get('raw_reset'):
        reset_desc = process_equations(description['raw_reset'])
        for var in reset_desc:
            translator = Equation(var['name'], var['eq'],
                                  description)
            var['cpp'] = translator.parse()
            var['dependencies'] = translator.dependencies()

    return {'spike_cond': raw_spike_code,
            'spike_cond_dependencies': spike_code_dependencies,
            'spike_reset': reset_desc}
예제 #9
0
def extract_spike_variable(description):
    """Translate the spike condition and reset statements of a spiking neuron.

    Args:
        description: neuron description. 'raw_spike' is always read;
            'raw_reset' is processed when present and non-empty.

    Returns:
        dict with 'spike_cond' (parsed condition code) and 'spike_reset'
        (reset statements from ``process_equations``, each augmented with
        its parsed 'cpp' code).
    """
    import sys

    cond = prepare_string(description['raw_spike'])
    if len(cond) > 1:
        _error('The spike condition must be a single expression')
        _print(description['raw_spike'])
        # Abort with a failure status; exit(0) would signal success.
        sys.exit(1)

    translator = Equation('raw_spike_cond',
                          cond[0].strip(),
                          description)
    raw_spike_code = translator.parse()

    reset_desc = []
    if description.get('raw_reset'):
        reset_desc = process_equations(description['raw_reset'])
        for var in reset_desc:
            translator = Equation(var['name'], var['eq'],
                                  description)
            var['cpp'] = translator.parse()

    return {'spike_cond': raw_spike_code, 'spike_reset': reset_desc}
예제 #10
0
    def __init__(self, description, variables):
        """Keep references to the description and the coupled variables.

        Args:
            description: object description ('local' and 'global' are read).
            variables: list of variable dicts providing 'name' and
                'transformed_eq'; 'untouched' is read from the first one.
        """
        self.description = description
        self.variables = variables

        first = variables[0]
        self.untouched = first['untouched']

        # Map each variable name to its transformed equation.
        self.expression_list = dict(
            (var['name'], var['transformed_eq']) for var in self.variables
        )
        self.names = self.expression_list.keys()

        self.local_variables = self.description['local']
        self.global_variables = self.description['global']

        # Only the local_dict of this throw-away Equation is kept.
        self.local_dict = Equation(
            'tmp', '', self.description,
            method='implicit', untouched=self.untouched,
        ).local_dict
예제 #11
0
def extract_spike_variable(description):
    """Translate the spike condition and reset statements of a spiking neuron.

    Args:
        description: neuron description. 'raw_spike' is always read;
            'raw_reset' is processed when present and non-empty.

    Returns:
        dict with 'spike_cond' (parsed condition code) and 'spike_reset'
        (reset statements from ``process_equations``, each augmented with
        its parsed 'cpp' code).
    """
    import sys

    cond = prepare_string(description['raw_spike'])
    if len(cond) > 1:
        _error('The spike condition must be a single expression')
        _print(description['raw_spike'])
        # Abort with a failure status; exit(0) would signal success.
        sys.exit(1)

    translator = Equation('raw_spike_cond', cond[0].strip(), description)
    raw_spike_code = translator.parse()

    reset_desc = []
    if description.get('raw_reset'):
        reset_desc = process_equations(description['raw_reset'])
        for var in reset_desc:
            translator = Equation(var['name'], var['eq'], description)
            var['cpp'] = translator.parse()

    return {'spike_cond': raw_spike_code, 'spike_reset': reset_desc}
예제 #12
0
def extract_structural_plasticity(statement, description):
    """Parse a structural-plasticity (creating/pruning) condition.

    Args:
        statement: condition text, optionally followed by ``: flags``.
        description: synapse description; its ``dependencies['pre']`` and
            ``dependencies['post']`` lists are extended in place.

    Returns:
        dict with 'eq' (condition after random-distribution and pre/post
        substitution), 'cpp' (generated code), 'bounds', 'flags' and 'rd'
        (descriptor of the single allowed random distribution, or None).
    """
    import sys

    # Extract flags. Without a ':' the unpacking fails and the whole
    # statement is the condition. SystemExit/KeyboardInterrupt are no longer
    # swallowed (the previous bare except caught them too).
    try:
        eq, constraint = statement.rsplit(':', 1)
        bounds, flags = extract_flags(constraint)
    except Exception:
        eq = statement.strip()
        bounds = {}
        flags = []

    # Extract RD: at most one random distribution with fixed numeric arguments
    rd = None
    for dist in available_distributions:
        pattern = r'(?P<pre>[^\w.])' + dist + r'\(([^()]+)\)'
        for _pre, v in re.findall(pattern, eq):
            arguments = v.split(',')
            expected = distributions_arguments[dist]
            # Check the number of provided arguments
            if len(arguments) < expected:
                _error(eq)
                _error('The distribution ' + dist + ' requires ' + str(expected) + ' parameters')
            elif len(arguments) > expected:
                _error(eq)
                _error('Too many parameters provided to the distribution ' + dist)

            # Every argument must be a literal number (no global parameters)
            values = []
            for argument in arguments:
                try:
                    values.append(str(float(argument)))
                except ValueError:
                    _error(eq)
                    _error('Random distributions for creating/pruning synapses must use fixed values.')
                    sys.exit(1)
            processed_arguments = ', '.join(values)
            definition = distributions_equivalents[dist] + '(' + processed_arguments + ')'

            # Store its definition
            if rd:
                _error(eq)
                _error('Only one random distribution per equation is allowed.')
                sys.exit(1)

            rd = {'name': 'rand_0',
                  'origin': dist + '(' + v + ')',
                  'dist': dist,
                  'definition': definition,
                  'args': processed_arguments,
                  'template': distributions_equivalents[dist]}

    if rd:
        eq = eq.replace(rd['origin'], 'rd(rng)')

    # Extract pre/post dependencies
    eq, untouched, dependencies = extract_prepost('test', eq, description)

    # Parse code.
    # NOTE(review): other condition parsers in this file pass type='cond';
    # this one passes method='cond' -- confirm against Equation's signature.
    translator = Equation('test', eq,
                          description,
                          method='cond',
                          untouched={})
    code = translator.parse()

    # Replace untouched variables with their original name
    for prev, new in untouched.items():
        code = code.replace(prev, new)

    # Add new dependencies
    description['dependencies']['pre'].extend(dependencies['pre'])
    description['dependencies']['post'].extend(dependencies['post'])

    return {'eq': eq, 'cpp': code, 'bounds': bounds, 'flags': flags, 'rd': rd}
예제 #13
0
def analyse_neuron(neuron):
    """
    Parses the structure and generates code snippets for the neuron type.

    It returns a ``description`` dictionary with the following fields:

    * 'object': 'neuron' by default, to distinguish it from 'synapse'
    * 'type': either 'rate' or 'spike'
    * 'raw_parameters': provided field
    * 'raw_equations': provided field
    * 'raw_functions': provided field
    * 'raw_reset': provided field (spiking neurons only)
    * 'raw_spike': provided field (spiking neurons only)
    * 'raw_axon_spike': provided field (spiking neurons only)
    * 'raw_axon_reset': provided field (spiking neurons only)
    * 'refractory': provided field (spiking neurons only)
    * 'parameters': list of parameters defined for the neuron type
    * 'variables': list of variables defined for the neuron type
    * 'functions': list of functions defined for the neuron type
    * 'attributes': list of names of all parameters and variables
    * 'local': list of names of parameters and variables which are local to each neuron
    * 'semiglobal': always empty for neurons (only used by projections)
    * 'global': list of names of parameters and variables which are global to the population
    * 'targets': list of targets used in the equations
    * 'random_distributions': list of random number generators used in the neuron equations
    * 'global_operations': list of global operations (min/max/mean...) collected from the equations
    * 'spike': when defined, contains the equations of the spike conditions and reset.
    * 'axon_spike': when defined, contains the axonal spike condition (spiking neurons only)

    Each parameter is a dictionary with the following elements:

    * 'bounds': unused
    * 'ctype': type of the parameter: 'float', 'double', 'int' or 'bool'
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local' or 'global'
    * 'name': name of the parameter

    Each variable is a dictionary with the following elements:

    * 'bounds': dictionary of bounds ('init', 'min', 'max') provided after the :
    * 'cpp': C++ code snippet updating the variable
    * 'ctype': type of the variable: 'float', 'double', 'int' or 'bool'
    * 'dependencies': list of variable and parameter names on which the equation depends
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local' or 'global'
    * 'method': numerical method for ODEs
    * 'name': name of the variable
    * 'pre_loop': ODEs have a pre_loop term for precomputing dt/tau. dict with 'name' and 'value'. type must be inferred.
    * 'switch': ODEs have a switch term
    * 'transformed_eq': same as eq, except special terms (sums, rds) are replaced with a temporary name
    * 'untouched': dictionary of special terms, with their new name as keys and replacement values as values.

    The 'spike' element (when present) is a dictionary containing:

    * 'spike_cond': the C++ code snippet containing the spike condition ("v%(local_index)s > v_T")
    * 'spike_cond_dependencies': list of variables/parameters on which the spike condition depends
    * 'spike_reset': a list of reset statements, each of them composed of :
        * 'constraint': either '' or 'unless_refractory'
        * 'cpp': C++ code snippet
        * 'dependencies': list of variables on which the reset statement depends
        * 'eq': original equation in text format
        * 'name': name of the reset variable

    """

    # Store basic information
    description = {
        'object': 'neuron',
        'type': neuron.type,
        'raw_parameters': neuron.parameters,
        'raw_equations': neuron.equations,
        'raw_functions': neuron.functions,
    }

    # Spiking neurons additionally store the spike condition, the reset statements and a refractory period
    if neuron.type == 'spike':
        description['raw_reset'] = neuron.reset
        description['raw_spike'] = neuron.spike
        description['raw_axon_spike'] = neuron.axon_spike
        description['raw_axon_reset'] = neuron.axon_reset
        description['refractory'] = neuron.refractory

    # Extract parameters and variables names
    parameters = extract_parameters(neuron.parameters, neuron.extra_values)
    variables = extract_variables(neuron.equations)
    description['parameters'] = parameters
    description['variables'] = variables

    # Make sure r is defined for rate-coded networks
    from ANNarchy.extensions.bold.BoldModel import BoldModel
    if isinstance(neuron, BoldModel):
        # BOLD models get a default 'r' variable if they do not define one
        found = False
        for var in description['parameters'] + description['variables']:
            if var['name'] == 'r':
                found = True
        if not found:
            description['variables'].append({
                'name': 'r',
                'locality': 'local',
                'bounds': {},
                'ctype': config['precision'],
                'init': 0.0,
                'flags': [],
                'eq': '',
                'cpp': ""
            })
    elif neuron.type == 'rate':
        for var in description['parameters'] + description['variables']:
            if var['name'] == 'r':
                break
        else:
            _error('Rate-coded neurons must define the variable "r".')

    else:  # spiking neurons define r by default, it contains the average FR if enabled
        for var in description['parameters'] + description['variables']:
            if var['name'] == 'r':
                _error(
                    'Spiking neurons use the variable "r" for the average FR, use another name.'
                )

        description['variables'].append({
            'name': 'r',
            'locality': 'local',
            'bounds': {},
            'ctype': config['precision'],
            'init': 0.0,
            'flags': [],
            'eq': '',
            'cpp': ""
        })

    # Extract functions
    functions = extract_functions(neuron.functions, False)
    description['functions'] = functions

    # Build lists of all attributes (param + var), which are local or global
    attributes, local_var, global_var, _ = get_attributes(parameters,
                                                          variables,
                                                          neuron=True)

    # Test if attributes are declared only once
    if len(attributes) != len(set(attributes)):
        _error('Attributes must be declared only once.', attributes)

    # Store the attributes
    description['attributes'] = attributes
    description['local'] = local_var
    description['semiglobal'] = []  # only for projections
    description['global'] = global_var

    # Extract all targets
    targets = sorted(list(set(extract_targets(variables))))
    description['targets'] = targets
    if neuron.type == 'spike':  # Add a default reset behaviour for conductances
        for target in targets:
            found = False
            for var in description['variables']:
                if var['name'] == 'g_' + target:
                    found = True
                    break
            if not found:
                description['variables'].append({
                    'name': 'g_' + target,
                    'locality': 'local',
                    'bounds': {},
                    'ctype': config['precision'],
                    'init': 0.0,
                    'flags': [],
                    'eq': 'g_' + target + ' = 0.0'
                })
                description['attributes'].append('g_' + target)
                description['local'].append('g_' + target)

    # Extract RandomDistribution objects
    random_distributions = extract_randomdist(description)
    description['random_distributions'] = random_distributions

    # Extract the spike condition if any
    if neuron.type == 'spike':
        description['spike'] = extract_spike_variable(description)
        description['axon_spike'] = extract_axon_spike_condition(description)

    # Global operation TODO
    description['global_operations'] = []

    # The ODEs may be interdependent (implicit, midpoint), so they need to be passed explicitly to CoupledEquations
    concurrent_odes = []

    # Translate the equations to C++
    for variable in description['variables']:
        # Get the equation
        eq = variable['transformed_eq']
        if eq.strip() == "":
            continue

        # Special variables (sums, global operations, rd) are placed in untouched, so that Sympy ignores them
        untouched = {}

        # Dependencies must be gathered
        dependencies = []

        # Replace sum(target) with _sum_exc__[i]
        for target in description['targets']:
            # sum() is valid for all targets
            eq = re.sub(r'(?P<pre>[^\w.])sum\(\)', r'\1sum(__all__)', eq)
            # Replace sum(target) with __sum_target__
            eq = re.sub(r'sum\(\s*' + target + r'\s*\)',
                        '__sum_' + target + '__', eq)
            untouched['__sum_' + target +
                      '__'] = '_sum_' + target + '%(local_index)s'

        # Extract global operations
        eq, untouched_globs, global_ops = extract_globalops_neuron(
            variable['name'], eq, description)

        # Add the untouched variables to the global list
        for name, val in untouched_globs.items():
            if name not in untouched:
                untouched[name] = val
        description['global_operations'] += global_ops

        # Extract if-then-else statements
        eq, condition = extract_ite(variable['name'], eq, description)

        # Find the numerical method if any
        method = find_method(variable)

        # Process the bounds
        if 'min' in variable['bounds'].keys():
            if isinstance(variable['bounds']['min'], str):
                translator = Equation(variable['name'],
                                      variable['bounds']['min'],
                                      description,
                                      type='return',
                                      untouched=untouched)
                variable['bounds']['min'] = translator.parse().replace(';', '')
                dependencies += translator.dependencies()

        if 'max' in variable['bounds'].keys():
            if isinstance(variable['bounds']['max'], str):
                translator = Equation(variable['name'],
                                      variable['bounds']['max'],
                                      description,
                                      type='return',
                                      untouched=untouched)
                variable['bounds']['max'] = translator.parse().replace(';', '')
                dependencies += translator.dependencies()

        # Analyse the equation
        if condition == []:  # No if-then-else
            translator = Equation(variable['name'],
                                  eq,
                                  description,
                                  method=method,
                                  untouched=untouched)
            code = translator.parse()
            dependencies += translator.dependencies()
        else:  # An if-then-else statement
            code, deps = translate_ITE(variable['name'], eq, condition,
                                       description, untouched)
            dependencies += deps

        # ODEs have a switch statement:
        #   double _r = (1.0 - r)/tau;
        #   r[i] += dt* _r;
        # while direct assignments are one-liners:
        #   r[i] = 1.0
        if isinstance(code, str):
            pre_loop = {}
            cpp_eq = code
            switch = None
        else:  # ODE
            pre_loop = code[0]
            cpp_eq = code[1]
            switch = code[2]

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            if prev.startswith('g_'):
                # The ([^_]+) group guards against matching inside a longer
                # '_'-prefixed name. It must be re-emitted with \1 in every
                # substitution; the pre_loop/switch substitutions previously
                # omitted it and deleted the text preceding the match.
                cpp_eq = re.sub(r'([^_]+)' + prev, r'\1' + new,
                                ' ' + cpp_eq).strip()
                if len(pre_loop) > 0:
                    pre_loop['value'] = re.sub(r'([^_]+)' + prev, r'\1' + new,
                                               ' ' + pre_loop['value']).strip()
                if switch:
                    switch = re.sub(r'([^_]+)' + prev, r'\1' + new,
                                    ' ' + switch).strip()
            else:
                cpp_eq = re.sub(prev, new, cpp_eq)
                if len(pre_loop) > 0:
                    pre_loop['value'] = re.sub(prev, new, pre_loop['value'])
                if switch:
                    switch = re.sub(prev, new, switch)

        # Replace local functions
        for f in description['functions']:
            cpp_eq = re.sub(r'([^\w]*)' + f['name'] + r'\(',
                            r'\1' + f['name'] + '(', ' ' + cpp_eq).strip()

        # Store the result
        variable['pre_loop'] = pre_loop  # Things to be declared before the for loop (eg. dt)
        variable['cpp'] = cpp_eq  # the C++ equation
        variable['switch'] = switch  # switch value of ODE
        variable['untouched'] = untouched  # may be needed later
        variable['method'] = method  # may be needed later
        variable['dependencies'] = list(set(dependencies))  # may be needed later

        # If the method is implicit or midpoint, the equations must be solved concurrently (depend on v[t+1])
        if method in ['implicit', 'midpoint', 'runge-kutta4'] and switch is not None:
            concurrent_odes.append(variable)

    # After all variables are processed, do it again if they are concurrent
    if len(concurrent_odes) > 1:
        solver = CoupledEquations(description, concurrent_odes)
        new_eqs = solver.parse()
        for idx, variable in enumerate(description['variables']):
            for new_eq in new_eqs:
                if variable['name'] == new_eq['name']:
                    description['variables'][idx] = new_eq

    return description
예제 #14
0
def analyse_synapse(synapse):
    """
    Parses the structure and generates code snippets for the synapse type.

    It returns a ``description`` dictionary with the following fields:

    * 'object': 'synapse' by default, to distinguish it from 'neuron'
    * 'type': either 'rate' or 'spike'
    * 'raw_parameters': provided field
    * 'raw_equations': provided field
    * 'raw_functions': provided field
    * 'raw_psp': provided field (defaults to "w*pre.r" for rate-coded synapses without psp)
    * 'raw_pre_spike': provided field (spiking synapses only)
    * 'raw_post_spike': provided field (spiking synapses only)
    * 'plasticity': True if 'w' is a variable or is modified by a pre/post-spike statement
    * 'parameters': list of parameters defined for the synapse type
    * 'variables': list of variables defined for the synapse type
    * 'functions': list of functions defined for the synapse type
    * 'attributes': list of names of all parameters and variables
    * 'local': list of names of parameters and variables which are local to each synapse
    * 'semiglobal': list of names of parameters and variables which are local to each postsynaptic neuron
    * 'global': list of names of parameters and variables which are global to the projection
    * 'random_distributions': list of random number generators used in the synapse equations
    * 'global_operations': list of global operations (min/max/mean...) used in the equations
    * 'pre_global_operations': list of global operations (min/max/mean...) on the pre-synaptic population
    * 'post_global_operations': list of global operations (min/max/mean...) on the post-synaptic population
    * 'pre_spike': list of variables updated after a pre-spike event
    * 'post_spike': list of variables updated after a post-spike event
    * 'dependencies': dictionary ('pre', 'post') of lists of pre (resp. post) variables accessed by the synapse (used for delaying variables)
    * 'psp': dictionary ('eq', 'cpp', 'dependencies') for the psp code to be summed
    * 'pruning' and 'creating': statements for structural plasticity (only when provided)


    Each parameter is a dictionary with the following elements:

    * 'bounds': unused
    * 'ctype': type of the parameter: 'float', 'double', 'int' or 'bool'
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local', 'semiglobal' or 'global'
    * 'name': name of the parameter

    Each variable is a dictionary with the following elements:

    * 'bounds': dictionary of bounds ('init', 'min', 'max') provided after the :
    * 'cpp': C++ code snippet updating the variable
    * 'ctype': type of the variable: 'float', 'double', 'int' or 'bool'
    * 'dependencies': list of variable and parameter names on which the equation depends
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local', 'semiglobal' or 'global'
    * 'method': numerical method for ODEs
    * 'name': name of the variable
    * 'pre_loop': ODEs have a pre_loop term for precomputing dt/tau
    * 'prepost_dependencies': dictionary ('pre', 'post') of pre/post variables accessed by this variable
    * 'switch': ODEs have a switch term
    * 'transformed_eq': same as eq, except special terms (sums, rds) are replaced with a temporary name
    * 'untouched': dictionary of special terms, with their new name as keys and replacement values as values.

    """

    def _remove_duplicates(items):
        # Keeps the last occurrence of each entry. Entries may be dicts
        # (unhashable), so a list-based scan is used instead of a set.
        return [
            item for idx, item in enumerate(items)
            if item not in items[idx + 1:]
        ]

    def _declare_event_variables(desc, events, already_handled):
        # Event-driven variables that are not declared in the equations are
        # added as local variables without any equation.
        for var in events:
            if var['name'] in already_handled:  # Already dealt with
                continue
            if any(var['name'] == avar['name'] for avar in desc['variables']):
                continue  # defined already
            desc['variables'].append({
                'name': var['name'],
                'bounds': var['bounds'],
                'ctype': var['ctype'],
                'init': var['init'],
                'locality': var['locality'],
                'flags': [],
                'transformed_eq': '',
                'eq': '',
                'cpp': '',
                'switch': '',
                'pre_loop': '',
                'untouched': '',
                'method': 'explicit'
            })
            desc['local'].append(var['name'])
            desc['attributes'].append(var['name'])

    # Store basic information
    description = {
        'object': 'synapse',
        'type': synapse.type,
        'raw_parameters': synapse.parameters,
        'raw_equations': synapse.equations,
        'raw_functions': synapse.functions
    }

    # Psps is what is actually summed over the incoming weights
    if synapse.psp:
        description['raw_psp'] = synapse.psp
    elif synapse.type == 'rate':
        description['raw_psp'] = "w*pre.r"

    # Spiking synapses additionally store pre_spike and post_spike
    if synapse.type == 'spike':
        description['raw_pre_spike'] = synapse.pre_spike
        description['raw_post_spike'] = synapse.post_spike

    # Extract parameters and variables names
    parameters = extract_parameters(synapse.parameters, synapse.extra_values)
    variables = extract_variables(synapse.equations)

    # Extract functions
    functions = extract_functions(synapse.functions, False)

    # Check the presence of w: add it as a local parameter if undeclared
    description['plasticity'] = False
    for var in parameters + variables:
        if var['name'] == 'w':
            break
    else:
        parameters.append({
            'name': 'w',
            'bounds': {},
            'ctype': config['precision'],
            'init': 0.0,
            'flags': [],
            'eq': 'w=0.0',
            'locality': 'local'
        })

    # Find out a plasticity rule
    for var in variables:
        if var['name'] == 'w':
            description['plasticity'] = True
            break

    # Build lists of all attributes (param+var), which are local or global
    attributes, local_var, global_var, semiglobal_var = get_attributes(
        parameters, variables, neuron=False)

    # Test if attributes are declared only once
    if len(attributes) != len(list(set(attributes))):
        _error('Attributes must be declared only once.', attributes)

    # Add this info to the description
    description['parameters'] = parameters
    description['variables'] = variables
    description['functions'] = functions
    description['attributes'] = attributes
    description['local'] = local_var
    description['semiglobal'] = semiglobal_var
    description['global'] = global_var
    description['global_operations'] = []

    # Lists of global operations needed at the pre and post populations
    description['pre_global_operations'] = []
    description['post_global_operations'] = []

    # Extract RandomDistribution objects
    description['random_distributions'] = extract_randomdist(description)

    # Extract event-driven info
    if description['type'] == 'spike':
        # pre_spike event
        description['pre_spike'] = extract_pre_spike_variable(description)
        _declare_event_variables(description, description['pre_spike'],
                                 ['g_target'])

        # post_spike event
        description['post_spike'] = extract_post_spike_variable(description)
        _declare_event_variables(description, description['post_spike'],
                                 ['g_target', 'w'])

    # Variables names for the parser which should be left untouched
    untouched = {}
    description['dependencies'] = {'pre': [], 'post': []}

    # The ODEs may be interdependent (implicit, midpoint), so they need to be passed explicitely to CoupledEquations
    concurrent_odes = []

    # Iterate over all variables
    for variable in description['variables']:
        # Equation
        eq = variable['transformed_eq']
        if eq.strip() == '':
            continue

        # Dependencies must be gathered
        dependencies = []

        # Extract global operations
        eq, untouched_globs, global_ops = extract_globalops_synapse(
            variable['name'], eq, description)
        description['pre_global_operations'] = _remove_duplicates(
            description['pre_global_operations'] + global_ops['pre'])
        description['post_global_operations'] = _remove_duplicates(
            description['post_global_operations'] + global_ops['post'])

        # Extract pre- and post_synaptic variables
        eq, untouched_var, prepost_dependencies = extract_prepost(
            variable['name'], eq, description)

        # Store the pre-post dependencies at the synapse level
        description['dependencies']['pre'] += prepost_dependencies['pre']
        description['dependencies']['post'] += prepost_dependencies['post']
        # and also on the variable for checking
        variable['prepost_dependencies'] = prepost_dependencies

        # Extract if-then-else statements
        eq, condition = extract_ite(variable['name'], eq, description)

        # Add the untouched variables to the global list (first one wins)
        for name, val in untouched_globs.items():
            untouched.setdefault(name, val)
        for name, val in untouched_var.items():
            untouched.setdefault(name, val)

        # Save the tranformed equation
        variable['transformed_eq'] = eq

        # Find the numerical method if any
        method = find_method(variable)

        # Process the bounds given as expressions (min first, then max)
        for bound in ('min', 'max'):
            if bound in variable['bounds'] and isinstance(
                    variable['bounds'][bound], str):
                translator = Equation(variable['name'],
                                      variable['bounds'][bound],
                                      description,
                                      type='return',
                                      untouched=untouched.keys())
                variable['bounds'][bound] = translator.parse().replace(';', '')
                dependencies += translator.dependencies()

        # Analyse the equation
        if condition == []:  # Call Equation
            translator = Equation(variable['name'],
                                  eq,
                                  description,
                                  method=method,
                                  untouched=untouched.keys())
            code = translator.parse()
            dependencies += translator.dependencies()

        else:  # An if-then-else statement
            code, deps = translate_ITE(variable['name'], eq, condition,
                                       description, untouched)
            dependencies += deps

        if isinstance(code, str):
            pre_loop = {}
            cpp_eq = code
            switch = None
        else:  # ODE
            pre_loop = code[0]
            cpp_eq = code[1]
            switch = code[2]

        # Replace untouched variables with their original name, also in the
        # ODE-specific parts (pre_loop, switch), as done for neurons.
        for prev, new in untouched.items():
            cpp_eq = cpp_eq.replace(prev, new)
            if 'value' in pre_loop:
                pre_loop['value'] = pre_loop['value'].replace(prev, new)
            if isinstance(switch, str):
                switch = switch.replace(prev, new)

        # Replace local functions
        for f in description['functions']:
            cpp_eq = re.sub(r'([^\w]*)' + f['name'] + r'\(',
                            r'\1' + f['name'] + '(', ' ' + cpp_eq).strip()

        # Store the result
        variable['pre_loop'] = pre_loop  # Things to be declared before the for loop (eg. dt)
        variable['cpp'] = cpp_eq  # the C++ equation
        variable['switch'] = switch  # switch value id ODE
        variable['untouched'] = untouched  # may be needed later
        variable['method'] = method  # may be needed later
        variable['dependencies'] = dependencies

        # If the method is implicit or midpoint, the equations must be solved concurrently (depend on v[t+1])
        if method in ['implicit', 'midpoint'] and switch is not None:
            concurrent_odes.append(variable)

    # After all variables are processed, do it again if they are concurrent
    if len(concurrent_odes) > 1:
        solver = CoupledEquations(description, concurrent_odes)
        new_eqs = solver.parse()
        for idx, variable in enumerate(description['variables']):
            for new_eq in new_eqs:
                if variable['name'] == new_eq['name']:
                    description['variables'][idx] = new_eq

    # Translate the psp code if any
    if 'raw_psp' in description:
        psp = {'eq': description['raw_psp'].strip()}

        # Extract global operations (deduplicated like the equations' ones)
        eq, untouched_globs, global_ops = extract_globalops_synapse(
            'psp', " " + psp['eq'] + " ", description)
        description['pre_global_operations'] = _remove_duplicates(
            description['pre_global_operations'] + global_ops['pre'])
        description['post_global_operations'] = _remove_duplicates(
            description['post_global_operations'] + global_ops['post'])

        # Replace pre- and post_synaptic variables
        eq, untouched, prepost_dependencies = extract_prepost(
            'psp', eq, description)
        description['dependencies']['pre'] += prepost_dependencies['pre']
        description['dependencies']['post'] += prepost_dependencies['post']
        for name, val in untouched_globs.items():
            untouched.setdefault(name, val)

        # Extract if-then-else statements
        eq, condition = extract_ite('psp', eq, description, split=False)

        # Analyse the equation
        if condition == []:
            translator = Equation('psp',
                                  eq,
                                  description,
                                  method='explicit',
                                  untouched=untouched.keys(),
                                  type='return')
            code = translator.parse()
            deps = translator.dependencies()
        else:
            code, deps = translate_ITE('psp', eq, condition, description,
                                       untouched)

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            code = code.replace(prev, new)

        # Store the result
        psp['cpp'] = code
        psp['dependencies'] = deps
        description['psp'] = psp

    # Process event-driven info
    if description['type'] == 'spike':
        for variable in description['pre_spike'] + description['post_spike']:
            # Find plasticity
            if variable['name'] == 'w':
                description['plasticity'] = True

            # Retrieve the equation
            eq = variable['eq']

            # Extract if-then-else statements
            eq, condition = extract_ite(variable['name'], eq, description)

            # Extract pre- and post_synaptic variables
            eq, untouched, prepost_dependencies = extract_prepost(
                variable['name'], eq, description)

            # Update dependencies
            description['dependencies']['pre'] += prepost_dependencies['pre']
            description['dependencies']['post'] += prepost_dependencies['post']
            # and also on the variable for checking
            variable['prepost_dependencies'] = prepost_dependencies

            # Analyse the equation
            dependencies = []
            if condition == []:
                translator = Equation(variable['name'],
                                      eq,
                                      description,
                                      method='explicit',
                                      untouched=untouched)
                code = translator.parse()
                dependencies += translator.dependencies()
            else:
                code, deps = translate_ITE(variable['name'], eq, condition,
                                           description, untouched)
                dependencies += deps

            if isinstance(code, list):  # an ode in a pre/post statement
                Global._print(eq)
                if variable in description['pre_spike']:
                    Global._error(
                        'It is forbidden to use ODEs in a pre_spike term.')
                elif variable in description['post_spike']:
                    Global._error(
                        'It is forbidden to use ODEs in a post_spike term.')
                else:
                    Global._error('It is forbidden to use ODEs here.')

            # Replace untouched variables with their original name
            for prev, new in untouched.items():
                code = code.replace(prev, new)

            # Process the bounds given as expressions (min first, then max)
            for bound in ('min', 'max'):
                if bound in variable['bounds'] and isinstance(
                        variable['bounds'][bound], str):
                    translator = Equation(variable['name'],
                                          variable['bounds'][bound],
                                          description,
                                          type='return',
                                          untouched=untouched)
                    variable['bounds'][bound] = translator.parse().replace(
                        ';', '')
                    dependencies += translator.dependencies()

            # Store the result
            variable['cpp'] = code  # the C++ equation
            variable['dependencies'] = dependencies

    # Structural plasticity
    if synapse.pruning:
        description['pruning'] = extract_structural_plasticity(
            synapse.pruning, description)
    if synapse.creating:
        description['creating'] = extract_structural_plasticity(
            synapse.creating, description)

    return description
예제 #15
0
class CoupledEquations(object):
    """
    Solves a set of interdependent ODEs together.

    Each entry of ``variables`` is a variable dictionary (with at least
    'name', 'transformed_eq', 'method' and 'untouched'). Solving fills in the
    'cpp' and 'switch' entries of these dictionaries and returns the list.
    """

    def __init__(self, description, variables):
        self.description = description
        self.variables = variables

        # The untouched dictionary of the first variable is used for all of
        # them (assumes they share the same dictionary — TODO confirm).
        self.untouched = variables[0]['untouched']

        # Transformed equation of each coupled variable, keyed by name
        self.expression_list = {}
        for var in self.variables:
            self.expression_list[var['name']] = var['transformed_eq']

        self.names = self.expression_list.keys()

        self.local_variables = self.description['local']
        self.global_variables = self.description['global']

        # Symbols known to the parser for this description
        self.local_dict = Equation('tmp',
                                   '',
                                   self.description,
                                   method='implicit',
                                   untouched=self.untouched).local_dict

    def parse(self):
        """
        Solves the coupled system and returns the updated variables.

        Alias of ``process_variables()``, provided because callers invoke
        ``solver.parse()``.
        """
        return self.process_variables()

    def process_variables(self):
        """
        Dispatches to the solver matching the (common) numerical method.

        Returns the list of updated variables, or None when the method is
        neither implicit/semiimplicit nor midpoint.
        """
        # Check if the numerical method is the same for all ODEs
        methods = [var['method'] for var in self.variables]
        if len(set(methods)) > 1:  # mixture of methods
            _error(
                'Can not mix different numerical methods when solving a coupled system of equations.'
            )
            exit(0)
        else:
            method = methods[0]

        if method == 'implicit' or method == 'semiimplicit':
            return self.solve_implicit(self.expression_list)
        elif method == 'midpoint':
            return self.solve_midpoint(self.expression_list)

    @staticmethod
    def _substitute_future_values(expression, names):
        """
        Prefixes every whole-word occurrence of each name with '_'.

        Lookarounds are used so that adjacent occurrences (``v*v``) and
        occurrences at the start/end of the expression are all replaced.
        """
        for n in names:
            expression = re.sub(r'(?<!\w)' + re.escape(n) + r'(?!\w)',
                                '_' + n, expression)
        return expression

    def solve_implicit(self, expression_list):
        """
        Solves the coupled ODEs with the implicit Euler method.

        Note: ``expression_list`` is filled from ``self.expression_list``;
        callers pass ``self.expression_list`` itself, which is modified in place.
        """
        equations = {}
        new_vars = {}

        # Pre-processing to replace the gradient
        for name, expression in self.expression_list.items():
            # transform the expression to suppress =
            if '=' in expression:
                expression = expression.replace('=', '- (')
                expression += ')'
            # Suppress spaces to extract dvar/dt
            expression = expression.replace(' ', '')
            # Transform the gradient into a difference TODO: more robust...
            expression = expression.replace('d' + name, '_t_gradient_')
            expression_list[name] = expression

        # replace the variables by their future value
        for name, expression in expression_list.items():
            expression = self._substitute_future_values(expression, self.names)
            expression = expression.replace('_t_gradient_',
                                            '(_' + name + ' - ' + name + ')')
            expression_list[name] = expression + '-' + name

            new_var = Symbol('_' + name)
            self.local_dict['_' + name] = new_var
            new_vars[new_var] = name

        for name, expression in expression_list.items():
            analysed = parse_expr(expression,
                                  local_dict=self.local_dict,
                                  transformations=(standard_transformations +
                                                   (convert_xor, )))
            equations[name] = analysed

        try:
            solution = solve(list(equations.values()), list(new_vars.keys()))
        except Exception:
            _error(
                'The multiple ODEs can not be solved together using the implicit Euler method.'
            )
            exit(0)

        for var, sol in solution.items():
            # simplify the solution
            sol = collect(sol, self.local_dict['dt'])

            # Generate the code
            cpp_eq = 'double _' + new_vars[var] + ' = ' + ccode(sol) + ';'
            switch = ccode(
                self.local_dict[new_vars[var]]) + ' += _' + new_vars[var] + ';'

            # Replace untouched variables with their original name
            for prev, new in self.untouched.items():
                cpp_eq = re.sub(prev, new, cpp_eq)
                switch = re.sub(prev, new, switch)

            # Store the result
            for variable in self.variables:
                if variable['name'] == new_vars[var]:
                    variable['cpp'] = cpp_eq
                    variable['switch'] = switch

        return self.variables

    def solve_midpoint(self, expression_list):
        """
        Solves the coupled ODEs with the explicit midpoint method.

        The ``expression_list`` argument is ignored; a fresh dictionary is
        built from ``self.expression_list``.
        """
        expression_list = {}
        equations = {}
        evaluations = {}

        # Pre-processing to replace the gradient
        for name, expression in self.expression_list.items():
            # transform the expression to suppress =
            if '=' in expression:
                expression = expression.replace('=', '- (')
                expression += ')'
            # Suppress spaces to extract dvar/dt
            expression = expression.replace(' ', '')
            # Transform the gradient into a difference TODO: more robust...
            expression = expression.replace('d' + name + '/dt',
                                            '_gradient_' + name)
            self.local_dict['_gradient_' + name] = Symbol('_gradient_' + name)
            expression_list[name] = expression

        for name, expression in expression_list.items():
            analysed = parse_expr(expression,
                                  local_dict=self.local_dict,
                                  transformations=(standard_transformations +
                                                   (convert_xor, )))
            equations[name] = analysed
            evaluations[name] = solve(analysed,
                                      self.local_dict['_gradient_' + name])

        # Compute the k = f(x, t)
        ks = {}
        for name, evaluation in evaluations.items():
            ks[name] = 'double _k_' + name + ' = ' + ccode(evaluation[0]) + ';'

        # New dictionary replacing x by x+dt/2*k)
        tmp_dict = dict(self.local_dict)
        for name in evaluations:
            tmp_dict[name] = Symbol('(' + ccode(self.local_dict[name]) +
                                    ' + 0.5*dt*_k_' + name + ' )')

        # Compute the new values _x_new = f(x + dt/2*_k)
        news = {}
        for name, expression in expression_list.items():
            tmp_analysed = parse_expr(
                expression,
                local_dict=tmp_dict,
                transformations=(standard_transformations + (convert_xor, )))
            solved = solve(tmp_analysed, self.local_dict['_gradient_' + name])
            news[name] = 'double _' + name + ' = ' + ccode(solved[0]) + ';'

        # Compute the switches
        switches = {}
        for name in expression_list:
            switches[name] = ccode(
                self.local_dict[name]) + ' += dt * _' + name + ' ;'

        # Store the generated code in the variables
        for name in self.names:
            k = ks[name]
            n = news[name]
            switch = switches[name]

            # Replace untouched variables with their original name
            for prev, new in self.untouched.items():
                k = re.sub(prev, new, k)
                n = re.sub(prev, new, n)
                switch = re.sub(prev, new, switch)

            # Store the result
            for variable in self.variables:
                if variable['name'] == name:
                    variable['cpp'] = [k, n]
                    variable['switch'] = switch

        return self.variables
예제 #16
0
def extract_structural_plasticity(statement, description):
    """
    Parses a structural plasticity statement (pruning or creating).

    The statement is an expression optionally followed by ``: flags``. At most
    one random distribution may appear. Its arguments must be numeric
    constants, and it is replaced by ``rd(rng)`` in the equation.

    :param statement: raw statement string.
    :param description: synapse description; its 'dependencies' ('pre',
        'post') lists are extended with the variables used by the statement.
    :return: dictionary with keys 'eq', 'cpp', 'bounds', 'flags', 'rd'
        (None or the random distribution definition) and 'dependencies'.
    """
    # Extract flags
    try:
        eq, constraint = statement.rsplit(':', 1)
        bounds, flags = extract_flags(constraint)
    except Exception:
        # No ':' (or unparsable flags): the whole statement is the equation
        eq = statement.strip()
        bounds = {}
        flags = []

    # Extract RD. The search string is padded so that a distribution at the
    # very beginning of the statement still has a preceding character.
    rd = None
    for dist in available_distributions:
        matches = re.findall(r'(?P<pre>[^\w.])' + dist + r'\(([^()]+)\)',
                             ' ' + eq)
        for _, args_string in matches:
            # Check the arguments
            arguments = args_string.split(',')
            # Check the number of provided arguments
            if len(arguments) < distributions_arguments[dist]:
                Global._print(eq)
                Global._error('The distribution ' + dist + ' requires ' +
                              str(distributions_arguments[dist]) +
                              ' parameters')
            elif len(arguments) > distributions_arguments[dist]:
                Global._print(eq)
                Global._error(
                    'Too many parameters provided to the distribution ' + dist)

            # Process the arguments: only numeric constants are accepted
            values = []
            for argument in arguments:
                try:
                    values.append(str(float(argument)))
                except ValueError:  # e.g. a global parameter
                    Global._print(eq)
                    Global._error(
                        'Random distributions for creating/pruning synapses must use fixed values.'
                    )
            processed_arguments = ', '.join(values)
            definition = distributions_equivalents[
                dist] + '(' + processed_arguments + ')'

            # Store its definition
            if rd:
                Global._print(eq)
                Global._error(
                    'Only one random distribution per equation is allowed.')

            rd = {
                'name': 'rand_' + str(0),
                'origin': dist + '(' + args_string + ')',
                'dist': dist,
                'definition': definition,
                'args': processed_arguments,
                'template': distributions_equivalents[dist]
            }

    if rd:
        eq = eq.replace(rd['origin'], 'rd(rng)')

    # Extract pre/post dependencies
    eq, untouched, dependencies = extract_prepost('test', eq, description)

    # Parse code
    translator = Equation('test', eq, description, method='cond', untouched={})

    code = translator.parse()
    deps = translator.dependencies()

    # Replace untouched variables with their original name
    for prev, new in untouched.items():
        code = code.replace(prev, new)

    # Add new dependencies
    description['dependencies']['pre'].extend(dependencies['pre'])
    description['dependencies']['post'].extend(dependencies['post'])

    return {
        'eq': eq,
        'cpp': code,
        'bounds': bounds,
        'flags': flags,
        'rd': rd,
        'dependencies': deps
    }
예제 #17
0
def _unique_keep_last(items):
    """Return ``items`` without duplicates (compared with ``==``), keeping the last occurrence of each."""
    return [item for idx, item in enumerate(items) if item not in items[idx + 1:]]


def _register_event_driven_variables(description, event_variables, skip):
    """Declare the variables that only appear in pre_spike/post_spike statements.

    Entries whose name is listed in ``skip`` or already present in
    ``description['variables']`` are ignored. The others are appended with an
    empty 'transformed_eq' (so the equation loop of analyse_synapse skips them)
    and registered as local attributes.
    """
    for var in event_variables:
        if var['name'] in skip:  # Already dealt with
            continue
        if any(var['name'] == avar['name'] for avar in description['variables']):
            continue  # defined already
        description['variables'].append({
            'name': var['name'], 'bounds': var['bounds'], 'ctype': var['ctype'],
            'init': var['init'], 'locality': var['locality'],
            'flags': [], 'transformed_eq': '', 'eq': '',
            'cpp': '', 'switch': '', 'pre_loop': {},
            'untouched': '', 'method': 'explicit',
        })
        description['local'].append(var['name'])
        description['attributes'].append(var['name'])


def _translate_bounds(variable, description, untouched):
    """Translate the string-valued 'min'/'max' bounds of ``variable`` into C++ in place.

    ``untouched`` is forwarded unchanged to ``Equation``. Returns the list of
    dependencies found in the bound expressions ('min' first, then 'max').
    """
    dependencies = []
    for bound in ('min', 'max'):
        value = variable['bounds'].get(bound)
        if isinstance(value, str):
            translator = Equation(variable['name'], value,
                                  description,
                                  type='return',
                                  untouched=untouched)
            variable['bounds'][bound] = translator.parse().replace(';', '')
            dependencies += translator.dependencies()
    return dependencies


def analyse_synapse(synapse):
    """
    Parses the structure and generates code snippets for the synapse type.

    It returns a ``description`` dictionary with the following fields:

    * 'object': 'synapse' by default, to distinguish it from 'neuron'
    * 'type': either 'rate' or 'spiking'
    * 'raw_parameters': provided field
    * 'raw_equations': provided field
    * 'raw_functions': provided field
    * 'raw_psp': provided field
    * 'raw_pre_spike': provided field
    * 'raw_post_spike': provided field
    * 'parameters': list of parameters defined for the synapse type
    * 'variables': list of variables defined for the synapse type
    * 'functions': list of functions defined for the synapse type
    * 'attributes': list of names of all parameters and variables
    * 'local': list of names of parameters and variables which are local to each synapse
    * 'semiglobal': list of names of parameters and variables which are local to each postsynaptic neuron
    * 'global': list of names of parameters and variables which are global to the projection
    * 'random_distributions': list of random number generators used in the neuron equations
    * 'global_operations': list of global operations (min/max/mean...) used in the equations
    * 'pre_global_operations': list of global operations (min/max/mean...) on the pre-synaptic population
    * 'post_global_operations': list of global operations (min/max/mean...) on the post-synaptic population
    * 'pre_spike': list of variables updated after a pre-spike event
    * 'post_spike': list of variables updated after a post-spike event
    * 'dependencies': dictionary ('pre', 'post') of lists of pre (resp. post) variables accessed by the synapse (used for delaying variables)
    * 'psp': dictionary ('eq' and 'psp') for the psp code to be summed
    * 'pruning' and 'creating': statements for structural plasticity


    Each parameter is a dictionary with the following elements:

    * 'bounds': unused
    * 'ctype': 'type of the parameter: 'float', 'double', 'int' or 'bool'
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local', 'semiglobal' or 'global'
    * 'name': name of the parameter

    Each variable is a dictionary with the following elements:

    * 'bounds': dictionary of bounds ('init', 'min', 'max') provided after the :
    * 'cpp': C++ code snippet updating the variable
    * 'ctype': type of the variable: 'float', 'double', 'int' or 'bool'
    * 'dependencies': list of variable and parameter names on which the equation depends
    * 'eq': original equation in text format
    * 'flags': list of flags provided after the :
    * 'init': initial value
    * 'locality': 'local', 'semiglobal' or 'global'
    * 'method': numerical method for ODEs
    * 'name': name of the variable
    * 'pre_loop': ODEs have a pre_loop term for precomputing dt/tau ({} otherwise)
    * 'switch': ODEs have a switch term
    * 'transformed_eq': same as eq, except special terms (sums, rds) are replaced with a temporary name
    * 'untouched': dictionary of special terms, with their new name as keys and replacement values as values.

    """

    # Store basic information
    description = {
        'object': 'synapse',
        'type': synapse.type,
        'raw_parameters': synapse.parameters,
        'raw_equations': synapse.equations,
        'raw_functions': synapse.functions
    }

    # Psps is what is actually summed over the incoming weights
    if synapse.psp:
        description['raw_psp'] = synapse.psp
    elif synapse.type == 'rate':
        description['raw_psp'] = "w*pre.r"

    # Spiking synapses additionally store pre_spike and post_spike
    if synapse.type == 'spike':
        description['raw_pre_spike'] = synapse.pre_spike
        description['raw_post_spike'] = synapse.post_spike

    # Extract parameters and variables names
    parameters = extract_parameters(synapse.parameters, synapse.extra_values)
    variables = extract_variables(synapse.equations)

    # Extract functions
    functions = extract_functions(synapse.functions, False)

    # Every synapse has a weight w: declare it as a local parameter if the user did not
    if not any(var['name'] == 'w' for var in parameters + variables):
        parameters.append(
            {
                'name': 'w', 'bounds': {}, 'ctype': config['precision'],
                'init': 0.0, 'flags': [], 'eq': 'w=0.0', 'locality': 'local'
            }
        )

    # The synapse is plastic if w is updated by an equation
    description['plasticity'] = any(var['name'] == 'w' for var in variables)

    # Build lists of all attributes (param+var), which are local or global
    attributes, local_var, global_var, semiglobal_var = get_attributes(parameters, variables, neuron=False)

    # Test if attributes are declared only once
    if len(attributes) != len(set(attributes)):
        _error('Attributes must be declared only once.', attributes)

    # Add this info to the description
    description['parameters'] = parameters
    description['variables'] = variables
    description['functions'] = functions
    description['attributes'] = attributes
    description['local'] = local_var
    description['semiglobal'] = semiglobal_var
    description['global'] = global_var
    description['global_operations'] = []

    # Lists of global operations needed at the pre and post populations
    description['pre_global_operations'] = []
    description['post_global_operations'] = []

    # Extract RandomDistribution objects
    description['random_distributions'] = extract_randomdist(description)

    # Extract event-driven info; variables only updated on events are declared here
    if description['type'] == 'spike':
        description['pre_spike'] = extract_pre_spike_variable(description)
        _register_event_driven_variables(description, description['pre_spike'], ['g_target'])

        description['post_spike'] = extract_post_spike_variable(description)
        _register_event_driven_variables(description, description['post_spike'], ['g_target', 'w'])

    # Variables names for the parser which should be left untouched
    untouched = {}
    description['dependencies'] = {'pre': [], 'post': []}

    # The ODEs may be interdependent (implicit, midpoint), so they need to be passed explicitely to CoupledEquations
    concurrent_odes = []

    # Iterate over all variables
    for variable in description['variables']:
        eq = variable['transformed_eq']
        if eq.strip() == '':  # event-driven only, nothing to integrate
            continue

        # Dependencies must be gathered
        dependencies = []

        # Extract global operations (and remove doubled entries)
        eq, untouched_globs, global_ops = extract_globalops_synapse(variable['name'], eq, description)
        description['pre_global_operations'] = _unique_keep_last(
            description['pre_global_operations'] + global_ops['pre'])
        description['post_global_operations'] = _unique_keep_last(
            description['post_global_operations'] + global_ops['post'])

        # Extract pre- and post_synaptic variables
        eq, untouched_var, prepost_dependencies = extract_prepost(variable['name'], eq, description)

        # Store the pre-post dependencies at the synapse level
        description['dependencies']['pre'] += prepost_dependencies['pre']
        description['dependencies']['post'] += prepost_dependencies['post']
        # and also on the variable for checking
        variable['prepost_dependencies'] = prepost_dependencies

        # Extract if-then-else statements
        eq, condition = extract_ite(variable['name'], eq, description)

        # Add the untouched variables to the global list (first definition wins)
        for name, val in untouched_globs.items():
            untouched.setdefault(name, val)
        for name, val in untouched_var.items():
            untouched.setdefault(name, val)

        # Save the tranformed equation
        variable['transformed_eq'] = eq

        # Find the numerical method if any
        method = find_method(variable)

        # Process the bounds
        dependencies += _translate_bounds(variable, description, untouched.keys())

        # Analyse the equation
        if condition == []:  # Call Equation
            translator = Equation(variable['name'], eq,
                                  description,
                                  method=method,
                                  untouched=untouched.keys())
            code = translator.parse()
            dependencies += translator.dependencies()
        else:  # An if-then-else statement
            code, deps = translate_ITE(variable['name'], eq, condition, description,
                                       untouched)
            dependencies += deps

        if isinstance(code, str):
            pre_loop = {}
            cpp_eq = code
            switch = None
        else:  # ODE: [pre_loop, cpp, switch]
            pre_loop = code[0]
            cpp_eq = code[1]
            switch = code[2]

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            cpp_eq = cpp_eq.replace(prev, new)

        # Replace local functions
        for f in description['functions']:
            cpp_eq = re.sub(r'([^\w]*)' + f['name'] + r'\(', r'\1' + f['name'] + '(', ' ' + cpp_eq).strip()

        # Store the result
        variable['pre_loop'] = pre_loop  # Things to be declared before the for loop (eg. dt)
        variable['cpp'] = cpp_eq  # the C++ equation
        variable['switch'] = switch  # switch value id ODE
        variable['untouched'] = untouched  # may be needed later
        variable['method'] = method  # may be needed later
        variable['dependencies'] = dependencies

        # If the method is implicit or midpoint, the equations must be solved concurrently (depend on v[t+1])
        if method in ['implicit', 'midpoint'] and switch is not None:
            concurrent_odes.append(variable)

    # After all variables are processed, do it again if they are concurrent
    if len(concurrent_odes) > 1:
        solver = CoupledEquations(description, concurrent_odes)
        new_eqs = solver.parse()
        for idx, variable in enumerate(description['variables']):
            for new_eq in new_eqs:
                if variable['name'] == new_eq['name']:
                    description['variables'][idx] = new_eq

    # Translate the psp code if any
    if 'raw_psp' in description:
        psp = {'eq': description['raw_psp'].strip()}

        # Extract global operations (deduplicated like for the variables)
        eq, untouched_globs, global_ops = extract_globalops_synapse('psp', " " + psp['eq'] + " ", description)
        description['pre_global_operations'] = _unique_keep_last(
            description['pre_global_operations'] + global_ops['pre'])
        description['post_global_operations'] = _unique_keep_last(
            description['post_global_operations'] + global_ops['post'])

        # Replace pre- and post_synaptic variables
        eq, untouched, prepost_dependencies = extract_prepost('psp', eq, description)
        description['dependencies']['pre'] += prepost_dependencies['pre']
        description['dependencies']['post'] += prepost_dependencies['post']
        for name, val in untouched_globs.items():
            untouched.setdefault(name, val)

        # Extract if-then-else statements
        eq, condition = extract_ite('psp', eq, description, split=False)

        # Analyse the equation
        if condition == []:
            translator = Equation('psp', eq,
                                  description,
                                  method='explicit',
                                  untouched=untouched.keys(),
                                  type='return')
            code = translator.parse()
            deps = translator.dependencies()
        else:
            code, deps = translate_ITE('psp', eq, condition, description, untouched)

        # Replace untouched variables with their original name
        for prev, new in untouched.items():
            code = code.replace(prev, new)

        # Store the result
        psp['cpp'] = code
        psp['dependencies'] = deps
        description['psp'] = psp

    # Process event-driven info
    if description['type'] == 'spike':
        for variable in description['pre_spike'] + description['post_spike']:
            # Find plasticity
            if variable['name'] == 'w':
                description['plasticity'] = True

            # Retrieve the equation
            eq = variable['eq']

            # Extract if-then-else statements
            eq, condition = extract_ite(variable['name'], eq, description)

            # Extract pre- and post_synaptic variables
            eq, untouched, prepost_dependencies = extract_prepost(variable['name'], eq, description)

            # Update dependencies
            description['dependencies']['pre'] += prepost_dependencies['pre']
            description['dependencies']['post'] += prepost_dependencies['post']
            # and also on the variable for checking
            variable['prepost_dependencies'] = prepost_dependencies

            # Analyse the equation
            dependencies = []
            if condition == []:
                translator = Equation(variable['name'],
                                      eq,
                                      description,
                                      method='explicit',
                                      untouched=untouched)
                code = translator.parse()
                dependencies += translator.dependencies()
            else:
                code, deps = translate_ITE(variable['name'], eq, condition, description, untouched)
                dependencies += deps

            if isinstance(code, list):  # an ode in a pre/post statement
                Global._print(eq)
                if variable in description['pre_spike']:
                    Global._error('It is forbidden to use ODEs in a pre_spike term.')
                elif variable in description['post_spike']:
                    Global._error('It is forbidden to use ODEs in a post_spike term.')
                else:
                    Global._error('It is forbidden to use ODEs here.')

            # Replace untouched variables with their original name
            for prev, new in untouched.items():
                code = code.replace(prev, new)

            # Process the bounds
            dependencies += _translate_bounds(variable, description, untouched)

            # Store the result
            variable['cpp'] = code  # the C++ equation
            variable['dependencies'] = dependencies

    # Structural plasticity
    if synapse.pruning:
        description['pruning'] = extract_structural_plasticity(synapse.pruning, description)
    if synapse.creating:
        description['creating'] = extract_structural_plasticity(synapse.creating, description)

    return description
예제 #18
0
class CoupledEquations(object):
    """Solves a set of interdependent ODEs together (implicit Euler or midpoint).

    ``variables`` are variable descriptions built by the analyser; each is read
    for its 'name', 'transformed_eq', 'method' and (for the first one)
    'untouched' entries. The generated C++ code is written back into the
    'cpp' and 'switch' fields of these same dictionaries.
    """

    def __init__(self, description, variables):
        self.description = description
        self.variables = variables

        # NOTE(review): assumes all coupled variables share one 'untouched' dict.
        self.untouched = variables[0]['untouched']

        self.expression_list = {}
        for var in self.variables:
            self.expression_list[var['name']] = var['transformed_eq']

        self.names = self.expression_list.keys()

        self.local_variables = self.description['local']
        self.global_variables = self.description['global']

        self.local_dict = Equation('tmp', '',
                                   self.description,
                                   method='implicit',
                                   untouched=self.untouched
                                   ).local_dict

    def parse(self):
        """Alias of process_variables(); callers of the analyser invoke ``solver.parse()``."""
        return self.process_variables()

    def process_variables(self):
        """Solve the coupled system with the numerical method shared by all variables.

        Returns the list of variables with updated 'cpp' and 'switch' fields,
        or None when the method is neither implicit/semiimplicit nor midpoint.
        """
        # Check if the numerical method is the same for all ODEs
        methods = [var['method'] for var in self.variables]
        if len(set(methods)) > 1:  # mixture of methods
            _print(methods)
            _error('Can not mix different numerical methods when solving a coupled system of equations.')
        else:
            method = methods[0]

        if method in ('implicit', 'semiimplicit'):
            return self.solve_implicit(self.expression_list)
        elif method == 'midpoint':
            return self.solve_midpoint(self.expression_list)

    def solve_implicit(self, expression_list):
        """Solve the coupled ODEs with the implicit (backward) Euler method.

        Each ODE ``dx/dt = f(x, y, ...)`` becomes the residual
        ``(_x - x)/dt - f(_x, _y, ...) = 0`` where ``_x`` is the value at the
        next time step. The system is solved symbolically and the generated
        switch assigns ``x = _x``. ``expression_list`` is overwritten in place
        with the residual expressions.
        """
        equations = {}
        new_vars = {}

        # Pre-processing to replace the gradient
        for name, expression in self.expression_list.items():
            # transform the expression to suppress =
            if '=' in expression:
                expression = expression.replace('=', '- (')
                expression += ')'
            # Suppress spaces to extract dvar/dt
            expression = expression.replace(' ', '')
            # Transform the gradient into a difference TODO: more robust...
            expression = expression.replace('d' + name, '_t_gradient_')
            expression_list[name] = expression

        # Replace the variables by their future value
        for name, expression in expression_list.items():
            for n in self.names:
                # Lookarounds (instead of consumed delimiters) so that adjacent
                # occurrences such as 'v*v' and names at the edges are replaced.
                expression = re.sub(r'(?<!\w)' + re.escape(n) + r'(?!\w)', '_' + n, expression)
            expression = expression.replace('_t_gradient_', '(_' + name + ' - ' + name + ')')
            expression_list[name] = expression

            new_var = Symbol('_' + name)
            self.local_dict['_' + name] = new_var
            new_vars[new_var] = name

        for name, expression in expression_list.items():
            equations[name] = parse_expr(
                expression,
                local_dict=self.local_dict,
                transformations=(standard_transformations + (convert_xor,))
            )

        # dict=True always yields a list of {symbol: solution} mappings
        try:
            solutions = solve(list(equations.values()), list(new_vars.keys()), dict=True)
        except Exception:  # sympy raises e.g. NotImplementedError for unsupported systems
            solutions = []
        if len(solutions) != 1:  # no solution, or several (non-linear system)
            _print(expression_list)
            _error('The multiple ODEs can not be solved together using the implicit Euler method.')
            return self.variables

        for var, sol in solutions[0].items():
            name = new_vars[var]
            # simplify the solution
            sol = collect(sol, self.local_dict['dt'])

            # Generate the code: _name is the value at t+dt, which replaces the current one
            cpp_eq = 'double _' + name + ' = ' + ccode(sol) + ';'
            switch = ccode(self.local_dict[name]) + ' = _' + name + ';'

            # Replace untouched variables with their original name (plain text, not regexes)
            for prev, new in self.untouched.items():
                cpp_eq = cpp_eq.replace(prev, new)
                switch = switch.replace(prev, new)

            # Store the result
            for variable in self.variables:
                if variable['name'] == name:
                    variable['cpp'] = cpp_eq
                    variable['switch'] = switch

        return self.variables

    def solve_midpoint(self, expression_list):
        """Solve the coupled ODEs with the explicit midpoint method.

        For each variable x with gradient f, the generated code computes
        ``_k_x = f(x, ...)``, then ``_x = f(x + 0.5*dt*_k_x, ...)`` and the
        switch ``x += dt * _x``. The ``expression_list`` argument is ignored;
        the expressions stored at construction time are used instead.
        """
        expression_list = {}
        equations = {}
        evaluations = {}

        # Pre-processing to replace the gradient
        for name, expression in self.expression_list.items():
            # transform the expression to suppress =
            if '=' in expression:
                expression = expression.replace('=', '- (')
                expression += ')'
            # Suppress spaces to extract dvar/dt
            expression = expression.replace(' ', '')
            # Transform the gradient into a symbol TODO: more robust...
            expression = expression.replace('d' + name + '/dt', '_gradient_' + name)
            self.local_dict['_gradient_' + name] = Symbol('_gradient_' + name)
            expression_list[name] = expression

        for name, expression in expression_list.items():
            analysed = parse_expr(expression,
                                  local_dict=self.local_dict,
                                  transformations=(standard_transformations + (convert_xor,))
                                  )
            equations[name] = analysed
            evaluations[name] = solve(analysed, self.local_dict['_gradient_' + name])

        # Compute the k = f(x, t)
        ks = {}
        for name, evaluation in evaluations.items():
            ks[name] = 'double _k_' + name + ' = ' + ccode(evaluation[0]) + ';'

        # New dictionary replacing x by (x + dt/2*k)
        tmp_dict = dict(self.local_dict)
        for name in evaluations:
            tmp_dict[name] = Symbol('(' + ccode(self.local_dict[name]) + ' + 0.5*dt*_k_' + name + ' )')

        # Compute the new values _x_new = f(x + dt/2*_k)
        news = {}
        for name, expression in expression_list.items():
            tmp_analysed = parse_expr(expression,
                                      local_dict=tmp_dict,
                                      transformations=(standard_transformations + (convert_xor,))
                                      )
            solved = solve(tmp_analysed, self.local_dict['_gradient_' + name])
            news[name] = 'double _' + name + ' = ' + ccode(solved[0]) + ';'

        # Compute the switches
        switches = {}
        for name in expression_list:
            switches[name] = ccode(self.local_dict[name]) + ' += dt * _' + name + ' ;'

        # Store the generated code in the variables
        for name in self.names:
            k = ks[name]
            n = news[name]
            switch = switches[name]

            # Replace untouched variables with their original name (plain text, not regexes)
            for prev, new in self.untouched.items():
                k = k.replace(prev, new)
                n = n.replace(prev, new)
                switch = switch.replace(prev, new)

            # Store the result
            for variable in self.variables:
                if variable['name'] == name:
                    variable['cpp'] = [k, n]
                    variable['switch'] = switch

        return self.variables