Exemple #1
0
    def visit_Call(self, node, level=0):
        """Rewrite a call to ``self.func_name`` so it carries extra keywords.

        Calls whose callee source text differs from ``self.func_name`` are
        returned untouched.  For a matching call, the leading argument named
        ``self.remove_self`` (if any) is dropped, every name in
        ``self.arg_to_append`` that is not already passed by keyword is added
        as ``name=name``, and the callee is replaced by
        ``self.new_func_name`` when that attribute is set.

        Parameters
        ----------
        node : ast.Call
            The call node being visited.
        level : int
            Not used by this method.

        Returns
        -------
        ast.AST
            A new ``ast.Call`` for a matching call, otherwise ``node`` itself.

        Raises
        ------
        ValueError
            If ``node`` has a non-None ``starargs`` or ``kwargs`` attribute.
        """
        if getattr(node, 'starargs', None) is not None:
            raise ValueError(
                "Variable number of arguments (*args) are not supported")
        if getattr(node, 'kwargs', None) is not None:
            raise ValueError("Keyword arguments (**kwargs) are not supported")

        # only calls whose callee text equals the target name are rewritten
        call = tools.ast2code(node.func)
        if call != self.func_name:
            return node

        # positional arguments
        args = [self.generic_visit(arg) for arg in node.args]
        # Drop the leading "self" argument.  The call may have no positional
        # arguments at all, or its first argument may not be a plain name
        # (e.g. a constant or an attribute), so guard before reading ``.id``.
        if self.remove_self and args:
            if getattr(args[0], 'id', None) == self.remove_self:
                args.pop(0)

        # keyword arguments already present at the call site
        kwargs = [self.generic_visit(keyword) for keyword in node.keywords]

        # names still to be appended: skip those already given by keyword
        arg_to_append = deepcopy(self.arg_to_append)
        for keyword in kwargs:
            if keyword.arg in arg_to_append:
                arg_to_append.remove(keyword.arg)
        if len(arg_to_append):
            # parse a throw-away call just to obtain well-formed keyword nodes
            code = f'f({", ".join([f"{k}={k}" for k in arg_to_append])})'
            tree = ast.parse(code)
            new_keywords = tree.body[0].value.keywords
            kwargs.extend(new_keywords)

        # callee of the rewritten call
        if self.new_func_name:
            func_call = ast.parse(
                f'{self.new_func_name}()').body[0].value.func
        else:
            func_call = node.func
        return ast.Call(func=func_call, args=args, keywords=kwargs)
Exemple #2
0
    def visit_For(self, node):
        """Unroll a ``for`` loop that iterates over ``self.iter_name``.

        When the source text of the loop's iterable equals
        ``self.iter_name``, the loop body is copied once per entry of
        ``self.loop_values`` and specialised for that entry:

        * a ``Base`` instance is analysed with ``_analyze_cls_func_body``,
          using the loop variable as its "self" name;
        * any other callable is compiled with ``_jit_func`` and calls to the
          loop variable are redirected to ``f'{target}_{i}'``.

        The collected arguments, argument-to-call mapping, nodes and code
        scope are merged into the visitor's attributes and ``self.success``
        is set to True.  Loops over any other iterable are left as they are.
        In both cases the resulting node is traversed with ``generic_visit``
        before being returned.

        Raises
        ------
        errors.BrainPyError
            If the loop target is not a single name, or if a loop value is
            neither a ``Base`` object nor callable.
        """
        iter_ = tools.ast2code(ast.fix_missing_locations(node.iter))

        if iter_.strip() == self.iter_name:
            data_to_replace = Collector()
            final_node = ast.Module(body=[])
            self.success = True

            # target
            if not isinstance(node.target, ast.Name):
                raise errors.BrainPyError(
                    f'Only support scalar iter, like "for x in xxxx:", not "for '
                    f'{tools.ast2code(ast.fix_missing_locations(node.target))} '
                    f'in {iter_}:')
            target = node.target.id

            # for loop values
            for i, value in enumerate(self.loop_values):
                # a fresh copy of the loop body, as a module and as code
                module = ast.Module(body=deepcopy(node).body)
                code = tools.ast2code(module)

                if isinstance(value, Base):  # transform Base objects
                    r = _analyze_cls_func_body(host=value,
                                               self_name=target,
                                               code=code,
                                               tree=module,
                                               show_code=self.show_code,
                                               **self.jit_setting)

                    new_code, arguments, arg2call, nodes, code_scope = r
                    self.arguments.update(arguments)
                    self.arg2call.update(arg2call)
                    self.nodes.update(nodes)
                    self.code_scope.update(code_scope)

                    final_node.body.extend(ast.parse(new_code).body)

                elif callable(value):  # transform functions
                    r = _jit_func(obj_or_fun=value,
                                  show_code=self.show_code,
                                  **self.jit_setting)
                    tree = _replace_func_call_by_tree(
                        deepcopy(module),
                        func_call=target,
                        arg_to_append=r['arguments'],
                        new_func_name=f'{target}_{i}')

                    # update import parameters
                    self.arguments.update(r['arguments'])
                    self.arg2call.update(r['arg2call'])
                    self.nodes.update(r['nodes'])

                    # Name under which the compiled function is stored.
                    # ``value`` cannot be a Base instance in this branch (that
                    # case is handled above), so only bound methods of a Base
                    # host get a host-qualified name.
                    bound_host = getattr(value, '__self__', None)
                    if isinstance(bound_host, Base):
                        replace_name = f'{bound_host.name}_{target}'
                    else:
                        replace_name = f'{target}_{i}'
                    self.code_scope[replace_name] = r['func']
                    data_to_replace[f'{target}_{i}'] = replace_name

                    final_node.body.extend(tree.body)

                else:
                    raise errors.BrainPyError(
                        f'Only support JIT an iterable objects of function '
                        f'or Base object, but we got:\n\n {value}')

            # replace the temporary per-iteration names, then re-parse
            final_code = tools.ast2code(final_node)
            final_code = tools.word_replace(final_code,
                                            data_to_replace,
                                            exclude_dot=True)
            final_node = ast.parse(final_code)

        else:
            final_node = node

        self.generic_visit(final_node)
        return final_node
Exemple #3
0
def _flattened_reference(target, split_keys, i):
    """Build the flat name that replaces ``split_keys[i]`` found on *target*.

    Returns ``(flat_name, replacement)`` where ``flat_name`` is
    ``f'{target.name}_{split_keys[i]}'`` and ``replacement`` is ``flat_name``
    followed by whatever attribute path trails ``split_keys[i]`` in the key.
    """
    flat_name = f'{target.name}_{split_keys[i]}'
    trailing = split_keys[i + 1:]
    if trailing:
        return flat_name, f'{flat_name}.{".".join(trailing)}'
    return flat_name, flat_name


def _analyze_cls_func_body(host,
                           self_name,
                           code,
                           tree,
                           show_code=False,
                           has_func_def=False,
                           **jit_setting):
    """Replace ``<self_name>.xxx`` accesses in *code* with flat names.

    Every dotted access that starts with ``self_name`` is resolved on *host*
    by walking through nested ``Base`` attributes, then handled by the kind
    of object it resolves to:

    * ``math.numpy.Variable``: becomes a function argument; ``arg2call``
      records the expression that supplies its value.
    * ``np.random.RandomState``: replaced by ``np.random`` in the code scope.
    * callable: compiled with ``_jit_func`` and its calls in *tree* are
      extended with the arguments it needs.
    * dict / list / tuple: a container whose first element is neither
      callable nor a ``Base`` is stored in the code scope; otherwise the
      ``for`` loop over it is rewritten by ``_replace_this_forloop``.
    * anything else: stored in the code scope as a constant.

    Parameters
    ----------
    host : object
        Object on which the attribute paths are resolved.
    self_name : str
        Name used in *code* to refer to *host*.
    code : str
        Source text searched for ``self_name.xxx`` accesses.
    tree : ast.AST
        Syntax tree that is rewritten and finally turned back into code.
    show_code : bool
        Forwarded to ``_jit_func`` / ``_replace_this_forloop``.
    has_func_def : bool
        If True, ``tree.body[0]`` is treated as a function definition: its
        decorators are removed and the collected arguments are appended to
        its signature with ``None`` defaults.
    **jit_setting
        Forwarded to ``_jit_func`` / ``_replace_this_forloop``.

    Returns
    -------
    tuple
        ``(code, arguments, arg2call, nodes, code_scope)``.

    Raises
    ------
    errors.BrainPyError
        If an access cannot be analysed, or a dict is used other than
        through ``.values``.
    """
    arguments, arg2call, nodes, code_scope = set(), dict(), Collector(), dict()

    # all self data
    self_data = re.findall('\\b' + self_name + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b',
                           code)
    self_data = list(set(self_data))

    # analyze variables and functions accessed by the self.xx
    data_to_replace = {}
    for key in self_data:
        split_keys = key.split('.')
        if len(split_keys) < 2:
            raise errors.BrainPyError(
                f'Cannot analyze the attribute access "{key}".')

        # Walk down nested Base objects; stop at the first attribute that is
        # an Integrator or not a Base.  ``i`` then indexes that attribute and
        # ``target`` is the object owning it.
        target = host
        for i in range(1, len(split_keys)):
            next_target = getattr(target, split_keys[i])
            if (isinstance(next_target, Integrator)
                    or not isinstance(next_target, Base)):
                break
            target = next_target
        else:
            # every component resolved to a Base object: nothing to extract
            raise errors.BrainPyError(
                f'Cannot analyze "{key}": every attribute on the path is a '
                f'Base object.')
        data = getattr(target, split_keys[i])

        # analyze data
        if isinstance(data, math.numpy.Variable):  # data is a variable
            flat_name, replacement = _flattened_reference(target, split_keys, i)
            arguments.add(flat_name)
            # The variable itself is ``split_keys[i]``; the last key component
            # may be a trailing attribute of it, which must not appear here.
            arg2call[flat_name] = f'{target.name}.{split_keys[i]}.value'
            nodes[target.name] = target
            data_to_replace[key] = replacement

        elif isinstance(data, np.random.RandomState):  # data is a RandomState
            flat_name, replacement = _flattened_reference(target, split_keys, i)
            code_scope[flat_name] = np.random
            data_to_replace[key] = replacement

        elif callable(data):  # data is a function
            assert len(split_keys) == i + 1
            r = _jit_func(obj_or_fun=data, show_code=show_code, **jit_setting)
            tree = _replace_func_call_by_tree(tree,
                                              func_call=key,
                                              arg_to_append=r['arguments'])
            arguments.update(r['arguments'])
            arg2call.update(r['arg2call'])
            nodes.update(r['nodes'])
            flat_name, replacement = _flattened_reference(target, split_keys, i)
            code_scope[flat_name] = r['func']
            data_to_replace[key] = replacement

        elif isinstance(
                data, (dict, list,
                       tuple)):  # data is a list/tuple/dict of function/object
            # get all values
            if isinstance(data, dict):  # check dict
                # the only accepted form is exactly ``<dict>.values``
                if len(split_keys) != i + 2 or split_keys[-1] != 'values':
                    raise errors.BrainPyError(
                        f'Only support iter dict.values(). while we got '
                        f'dict.{split_keys[-1]}  for data: \n\n{data}')
                values = list(data.values())
                iter_name = key + '()'
            else:  # check list / tuple
                assert len(split_keys) == i + 1
                values = list(data)
                iter_name = key
                if len(values) > 0:
                    if not (callable(values[0])
                            or isinstance(values[0], Base)):
                        # plain data container: keep it in the code scope
                        flat_name, replacement = _flattened_reference(
                            target, split_keys, i)
                        code_scope[flat_name] = data
                        data_to_replace[key] = replacement
                        continue
            # replace this for-loop
            r = _replace_this_forloop(tree=tree,
                                      iter_name=iter_name,
                                      loop_values=values,
                                      show_code=show_code,
                                      **jit_setting)
            tree, _arguments, _arg2call, _nodes, _code_scope = r
            arguments.update(_arguments)
            arg2call.update(_arg2call)
            nodes.update(_nodes)
            code_scope.update(_code_scope)

        else:  # constants
            flat_name, replacement = _flattened_reference(target, split_keys, i)
            code_scope[flat_name] = data
            data_to_replace[key] = replacement

    if has_func_def:
        tree.body[0].decorator_list.clear()
        tree.body[0].args.args.extend(
            [ast.Name(id=a) for a in sorted(arguments)])
        tree.body[0].args.defaults.extend(
            [ast.Constant(None) for _ in sorted(arguments)])
        tree.body[0].args.kwarg = None

    # replace words
    code = tools.ast2code(tree)
    code = tools.word_replace(code, data_to_replace, exclude_dot=True)

    return code, arguments, arg2call, nodes, code_scope
Exemple #4
0
def _jit_intg(f, show_code=False, **jit_setting):
    """JIT-compile an ``Integrator`` together with its derivative functions.

    If ``f.integral`` is a bound method, the work is delegated to
    ``_jit_cls_func`` using the integrator's code lines.  Otherwise each
    function in ``f.derivative`` is compiled (``Dispatcher`` instances are
    skipped), the integrator's own code is patched to pass the extra
    arguments those functions need, and the patched code is executed and
    compiled with ``numba.jit``.

    Parameters
    ----------
    f : Integrator
        The integrator; its ``integral``, ``code_lines``, ``code_scope``,
        ``func_name`` and ``derivative`` attributes are read.
    show_code : bool
        If True, the generated code is shown before it is executed.
    **jit_setting
        Keyword arguments forwarded to ``numba.jit`` and ``_jit_cls_func``.

    Returns
    -------
    dict
        With keys ``func``, ``arguments``, ``arg2call`` and ``nodes``.
        ``func`` is ``f`` itself when nothing had to be recompiled.
    """
    assert isinstance(f, Integrator)

    # exponential euler methods
    if hasattr(f.integral, '__self__'):
        return _jit_cls_func(f=f.integral,
                             code="\n".join(f.code_lines),
                             show_code=show_code,
                             **jit_setting)

    # information carried by the integrator
    func_name = f.func_name
    raw_func = f.derivative
    tree = ast.parse('\n'.join(f.code_lines))
    code_scope = dict(f.code_scope.items())

    # results accumulated over all derivative functions
    arguments, arg2call, nodes = set(), dict(), Collector()

    # an integrator bound to a DynamicalSystem has a leading "self" parameter
    # that must be removed from the signature and from the calls
    f_node, remove_self = None, None
    if hasattr(f, '__self__') and isinstance(f.__self__, DynamicalSystem):
        f_node = f.__self__
        remove_self = tree.body[0].args.args.pop(0).arg

    need_recompile = False
    for key, func in raw_func.items():
        # host of this derivative function, if there is one
        if f_node:
            func_node = f_node
        elif hasattr(func, '__self__') and isinstance(func.__self__,
                                                      DynamicalSystem):
            func_node = func.__self__
        else:
            func_node = None

        # already-compiled functions are left alone
        if isinstance(func, Dispatcher):
            continue

        need_recompile = True
        if func_node is None:
            # free function: compile it directly
            code_scope[key] = numba.jit(func, **jit_setting)
            continue

        compiled = _jit_cls_func(f=func,
                                 host=func_node,
                                 show_code=show_code,
                                 **jit_setting)
        if len(compiled['arguments']) or remove_self:
            tree = _replace_func_call_by_tree(
                tree,
                func_call=key,
                arg_to_append=compiled['arguments'],
                remove_self=remove_self)
        code_scope[key] = compiled['func']
        arguments.update(compiled['arguments'])
        arg2call.update(compiled['arg2call'])
        nodes.update(compiled['nodes'])
        nodes[func_node.name] = func_node

    if not need_recompile:
        return {'func': f,
                'arguments': arguments,
                'arg2call': arg2call,
                'nodes': nodes}

    # extend the signature with the collected arguments (default None) and
    # rebuild the integrator function from the patched tree
    func_def = tree.body[0]
    ordered = sorted(arguments)
    func_def.decorator_list.clear()
    func_def.args.args.extend([ast.Name(id=a) for a in ordered])
    func_def.args.defaults.extend([ast.Constant(None) for _ in ordered])
    code = tools.ast2code(tree)
    if show_code:
        _show_compiled_codes(code, code_scope)
    exec(compile(code, '', 'exec'), code_scope)
    jit_f = numba.jit(code_scope[func_name], **jit_setting)
    return {'func': jit_f,
            'arguments': arguments,
            'arg2call': arg2call,
            'nodes': nodes}
Exemple #5
0
def analyze_step_func(host, f):
    """Analyze a step function and collect the ``self.xxx`` data it touches.

    Parameters
    ----------
    host : Population
        The data and the function host; passed to ``StepFuncReader``.
    f : callable
        The step function; its source is obtained with ``inspect.getsource``.

    Returns
    -------
    results : dict
        ``delay_call`` and ``code_string`` as gathered by the
        ``StepFuncReader``, ``code_scope`` (the nonlocal and global names
        ``f`` closes over), and three sorted lists of attribute accesses on
        the first argument: ``self_data_in_right``,
        ``self_data_without_index_in_left`` and
        ``self_data_with_index_in_left``.  The three lists are empty when the
        first argument is not one of ``backend.CLASS_KEYWORDS``.
    """
    source = tools.deindent(inspect.getsource(f)).strip()
    tree = ast.parse(source)

    # textual argument list of the function definition
    args = tools.ast2code(ast.fix_missing_locations(
        tree.body[0].args)).split(',')

    # walk the code; the reader gathers right-hand sides, left-hand sides,
    # output lines and delay calls
    formatter = StepFuncReader(host=host)
    formatter.visit(tree)

    is_method = args[0] in backend.CLASS_KEYWORDS

    # self.xx used on the right-hand side of assignments
    in_right = set()
    if is_method:
        rights_code = ', \n'.join(formatter.rights)
        in_right = set(re.findall(
            '\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b', rights_code))

    # self.xx assigned on the left-hand side, without and with indexing
    lefts_code = ', \n'.join(formatter.lefts)
    plain_in_left, indexed_in_left = set(), set()
    if is_method:
        class_p1 = '\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b'
        class_p2 = '(\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*)\\[.*\\]'
        plain_in_left = set(re.findall(class_p1, lefts_code))
        indexed_in_left = set(re.findall(class_p2, lefts_code))

    # names the function closes over; globals take precedence on clashes
    closure_vars = inspect.getclosurevars(f)
    code_scope = {**closure_vars.nonlocals, **closure_vars.globals}

    return {
        'delay_call': formatter.delay_call,
        'code_string': '\n'.join(formatter.lines),
        'code_scope': code_scope,
        'self_data_in_right': sorted(in_right),
        'self_data_without_index_in_left': sorted(plain_in_left),
        'self_data_with_index_in_left': sorted(indexed_in_left),
    }
Exemple #6
0
    def visit_Call(self, node, level=0):
        """Record delay ``push`` / ``pull`` calls made through the host.

        The callee's attribute chain is resolved on ``self.host``.  When the
        call is ``push`` or ``pull`` on a ``delays.ConstantDelay`` object, an
        indexing expression on the delay's ``delay_data`` that replaces the
        call is built and stored in ``self.visited_calls[node]`` together with
        the call type, the original call text and the data that must be
        passed in.  Children of the node are then visited.

        Parameters
        ----------
        node : ast.Call
            The call node being visited.
        level : int
            Not used by this method.

        Returns
        -------
        ast.Call or None
            ``node`` if it was already recorded or if the call does not start
            with one of ``backend.CLASS_KEYWORDS``; otherwise None.

        Raises
        ------
        ValueError
            If ``node`` has a non-None ``starargs`` or ``kwargs`` attribute.
        errors.CodeError
            If a delay call has an unsupported number of arguments or lacks
            the expected keyword.
        """
        if getattr(node, 'starargs', None) is not None:
            raise ValueError("Variable number of arguments not supported")
        if getattr(node, 'kwargs', None) is not None:
            raise ValueError("Keyword arguments not supported")

        if node in self.visited_calls:
            return node

        # attribute chain of the callee, outermost name first
        calls = self.visit_attr(node.func)
        calls = calls[::-1]

        # get the object and the function
        if calls[0] not in backend.CLASS_KEYWORDS:
            return node
        obj = self.host
        for data in calls[1:-1]:
            obj = getattr(obj, data)
        obj_func = getattr(obj, calls[-1])

        # get function arguments, as source text
        args = [tools.ast2code(ast.fix_missing_locations(arg))
                for arg in node.args]
        kw_args = OrderedDict()
        for keyword in node.keywords:
            kw_args[keyword.arg] = tools.ast2code(
                ast.fix_missing_locations(keyword.value))

        def _cannot_analyze():
            return errors.CodeError(
                f'Cannot analyze the code: \n\n'
                f'{tools.ast2code(ast.fix_missing_locations(node))}')

        def _argument(position, name):
            # Positional value if supplied, otherwise the keyword of that
            # name.  A missing keyword is reported as an unanalysable call
            # rather than leaking a KeyError.
            if len(args) > position:
                return args[position]
            if name not in kw_args:
                raise _cannot_analyze()
            return kw_args[name]

        # TASK 1 : extract delay push and delay pull
        # ------
        # Replace the delay function call to the delay_data
        # index. In such a way, delay function will be removed.
        # ------

        if calls[-1] in ['push', 'pull'] and isinstance(
                obj, delays.ConstantDelay) and callable(obj_func):
            dvar4call = '.'.join(calls[0:-1])
            uniform_delay = getattr(obj, 'uniform_delay')
            n_inputs = len(args) + len(kw_args)
            if calls[-1] == 'push':
                data_need_pass = [
                    f'{dvar4call}.delay_data', f'{dvar4call}.delay_in_idx'
                ]
                # Arguments are looked up only after the arity is known, so a
                # call with no arguments reaches the CodeError below.
                if n_inputs == 1:
                    idx_or_val = _argument(0, 'idx_or_val')
                    rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx] = {idx_or_val}'
                elif n_inputs == 2:
                    idx_or_val = _argument(0, 'idx_or_val')
                    value = _argument(1, 'value')
                    if uniform_delay:
                        rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx, {idx_or_val}] = {value}'
                    else:
                        rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx[{idx_or_val}], {idx_or_val}] = {value}'
                else:
                    raise _cannot_analyze()
            else:
                data_need_pass = [
                    f'{dvar4call}.delay_data', f'{dvar4call}.delay_out_idx'
                ]
                if n_inputs == 0:
                    rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx]'
                elif n_inputs == 1:
                    idx = _argument(0, 'idx')
                    if uniform_delay:
                        rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx, {idx}]'
                    else:
                        rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx[{idx}], {idx}]'
                else:
                    raise _cannot_analyze()

            org_call = tools.ast2code(ast.fix_missing_locations(node))
            self.visited_calls[node] = dict(type=calls[-1],
                                            org_call=org_call,
                                            rep_call=rep_expression,
                                            data_need_pass=data_need_pass)

        self.generic_visit(node)