Example #1
0
    def _replace_expressions(self, expressions, name, y_sub, t_sub=None):
        """Replace expressions of df part.

        Parameters
        ----------
        expressions : list, tuple
            The list/tuple of expressions.
        name : str
            The name of the new expression.
        y_sub : str
            The new name of the variable "y".
        t_sub : str, optional
            The new name of the variable "t".

        Returns
        -------
        list_of_expr : list
            A list of expressions.
        """
        return_expressions = []

        # replacements
        replacement = {self.var_name: y_sub}
        if t_sub is not None:
            replacement[self.t_name] = t_sub

        # replace variables in expressions
        for expr in expressions:
            replace = False
            identifiers = expr.identifiers
            for repl_var in replacement.keys():
                if repl_var in identifiers:
                    replace = True
                    break
            if replace:
                code = tools.word_replace(expr.code, replacement)
                new_expr = Expression(f"{expr.var_name}_{name}", code)
                return_expressions.append(new_expr)
                replacement[expr.var_name] = new_expr.var_name
        return return_expressions
Example #2
0
    def visit_For(self, node):
        """Unroll ``for <target> in <self.iter_name>:`` over ``self.loop_values``.

        Only loops whose iterable source code equals ``self.iter_name`` are
        transformed; any other loop is left unchanged. For each value in
        ``self.loop_values`` a copy of the loop body is generated:

        * ``Base`` objects: the body is rewritten by
          ``_analyze_cls_func_body`` with ``target`` playing the role of
          ``self``.
        * other callables: the function is JIT-transformed by ``_jit_func``
          and calls of ``target`` are redirected to a uniquely named copy.

        Side effects: sets ``self.success = True`` and updates
        ``self.arguments``, ``self.arg2call``, ``self.nodes`` and
        ``self.code_scope``.

        Raises
        ------
        errors.BrainPyError
            If the loop target is not a plain name, or a loop value is
            neither a ``Base`` object nor callable.
        """
        iter_ = tools.ast2code(ast.fix_missing_locations(node.iter))

        if iter_.strip() == self.iter_name:
            # maps the per-iteration placeholder name to its final name
            data_to_replace = Collector()
            final_node = ast.Module(body=[])
            self.success = True

            # target: only ``for x in ...`` (a plain name) is supported
            if not isinstance(node.target, ast.Name):
                raise errors.BrainPyError(
                    f'Only support scalar iter, like "for x in xxxx:", not "for '
                    f'{tools.ast2code(ast.fix_missing_locations(node.target))} '
                    f'in {iter_}:"')
            target = node.target.id

            # for loop values
            for i, value in enumerate(self.loop_values):
                # fresh copy of the loop body for this iteration
                module = ast.Module(body=deepcopy(node).body)
                code = tools.ast2code(module)

                if isinstance(value, Base):  # transform Base objects
                    r = _analyze_cls_func_body(host=value,
                                               self_name=target,
                                               code=code,
                                               tree=module,
                                               show_code=self.show_code,
                                               **self.jit_setting)

                    new_code, arguments, arg2call, nodes, code_scope = r
                    self.arguments.update(arguments)
                    self.arg2call.update(arg2call)
                    self.nodes.update(nodes)
                    self.code_scope.update(code_scope)

                    final_node.body.extend(ast.parse(new_code).body)

                elif callable(value):  # transform functions
                    r = _jit_func(obj_or_fun=value,
                                  show_code=self.show_code,
                                  **self.jit_setting)
                    tree = _replace_func_call_by_tree(
                        deepcopy(module),
                        func_call=target,
                        arg_to_append=r['arguments'],
                        new_func_name=f'{target}_{i}')

                    # update import parameters
                    self.arguments.update(r['arguments'])
                    self.arg2call.update(r['arg2call'])
                    self.nodes.update(r['nodes'])

                    # Name of the JIT function in the code scope. ``value``
                    # cannot be a Base object here (handled above). Bound
                    # methods of a Base object are named after their host;
                    # the index keeps two methods of the same host from
                    # overwriting each other in ``self.code_scope``.
                    host = getattr(value, '__self__', None)
                    if isinstance(host, Base):
                        replace_name = f'{host.name}_{target}_{i}'
                    else:
                        replace_name = f'{target}_{i}'
                    self.code_scope[replace_name] = r['func']
                    data_to_replace[f'{target}_{i}'] = replace_name

                    final_node.body.extend(tree.body)

                else:
                    raise errors.BrainPyError(
                        f'Only support JIT an iterable objects of function '
                        f'or Base object, but we got:\n\n {value}')

            # replace the per-iteration placeholders by their final names
            final_code = tools.ast2code(final_node)
            final_code = tools.word_replace(final_code,
                                            data_to_replace,
                                            exclude_dot=True)
            final_node = ast.parse(final_code)

        else:
            final_node = node

        self.generic_visit(final_node)
        return final_node
Example #3
0
def _analyze_cls_func_body(host,
                           self_name,
                           code,
                           tree,
                           show_code=False,
                           has_func_def=False,
                           **jit_setting):
    """Rewrite ``<self_name>.xxx`` accesses in a class-function body.

    Every attribute chain accessed through ``self_name`` is resolved against
    ``host``, descending through nested ``Base`` objects (but not into
    ``Integrator`` objects). The resolved attribute is then handled by kind:

    * ``math.numpy.Variable``: becomes a function argument ``<node>_<attr>``
      whose call expression is ``<node>.<attr>.value``;
    * ``np.random.RandomState``: replaced by the ``np.random`` module;
    * callable: JIT-transformed with ``_jit_func``;
    * dict (via ``.values()``) / list / tuple of functions or ``Base``
      objects: the ``for`` loop over it is unrolled; a list/tuple of other
      data is passed through as a constant;
    * anything else: passed through as a constant in the code scope.

    Parameters
    ----------
    host : Base
        The object that ``self_name`` refers to.
    self_name : str
        The identifier used for ``host`` in ``code`` (e.g. ``"self"``).
    code : str
        The source code of ``tree``.
    tree : ast.AST
        The parsed body to rewrite.
    show_code : bool
        Forwarded to the JIT helpers.
    has_func_def : bool
        If True, ``tree.body[0]`` is a function definition. Its decorators
        and ``**kwargs`` are removed, and every collected argument is
        appended with default ``None``.
    **jit_setting
        Forwarded to the JIT helpers.

    Returns
    -------
    tuple
        ``(code, arguments, arg2call, nodes, code_scope)``.

    Raises
    ------
    errors.BrainPyError
        If an access chain cannot be analyzed, or a dict is iterated other
        than through ``.values()``.
    """
    arguments, arg2call, nodes, code_scope = set(), dict(), Collector(), dict()

    # All attribute chains accessed through ``self_name``. They are sorted so
    # that the processing order, and hence the generated code, is
    # deterministic across interpreter runs (set order depends on hashing).
    pattern = r'\b' + re.escape(self_name) + r'\.[A-Za-z_][A-Za-z0-9_.]*\b'
    self_data = sorted(set(re.findall(pattern, code)))

    # analyze variables and functions accessed by the self.xx
    data_to_replace = {}
    for key in self_data:
        split_keys = key.split('.')
        if len(split_keys) < 2:
            raise errors.BrainPyError(f'Cannot analyze the access "{key}".')

        # Walk down nested Base objects; stop at the first attribute that is
        # an Integrator or not a Base object.
        target = host
        for i in range(1, len(split_keys)):
            next_target = getattr(target, split_keys[i])
            if isinstance(next_target, Integrator) or not isinstance(next_target, Base):
                break
            target = next_target
        else:
            raise errors.BrainPyError(
                f'Cannot analyze "{key}": every attribute in the chain is a '
                f'Base object.')
        attr = split_keys[i]
        data = getattr(target, attr)

        # Flat name used in the generated code, e.g. ``node_V`` for
        # ``self.node.V``; ``replacement`` re-appends any trailing attributes
        # (e.g. ``node_V.value`` for ``self.node.V.value``).
        flat_name = f'{target.name}_{attr}'
        replacement = '.'.join([flat_name] + split_keys[i + 1:])

        # analyze data
        if isinstance(data, math.numpy.Variable):  # data is a variable
            arguments.add(flat_name)
            # The call must read the variable itself (``attr``), not the last
            # element of the access chain.
            arg2call[flat_name] = f'{target.name}.{attr}.value'
            nodes[target.name] = target
            data_to_replace[key] = replacement

        elif isinstance(data, np.random.RandomState):  # data is a RandomState
            # replace RandomState by the ``np.random`` module
            code_scope[flat_name] = np.random
            data_to_replace[key] = replacement

        elif callable(data):  # data is a function
            assert len(split_keys) == i + 1, (
                f'Accessing an attribute of the function "{key}" is not supported.')
            r = _jit_func(obj_or_fun=data, show_code=show_code, **jit_setting)
            tree = _replace_func_call_by_tree(tree,
                                              func_call=key,
                                              arg_to_append=r['arguments'])
            arguments.update(r['arguments'])
            arg2call.update(r['arg2call'])
            nodes.update(r['nodes'])
            code_scope[flat_name] = r['func']
            data_to_replace[key] = flat_name

        elif isinstance(data, (dict, list, tuple)):
            # data is a list/tuple/dict of function/object
            if isinstance(data, dict):
                # only ``for x in self.xxx.values():`` is supported
                if len(split_keys) != i + 2 or split_keys[-1] != 'values':
                    raise errors.BrainPyError(
                        f'Only support iter dict.values(). while we got '
                        f'dict.{split_keys[-1]}  for data: \n\n{data}')
                values = list(data.values())
                iter_name = key + '()'
            else:  # check list / tuple
                assert len(split_keys) == i + 1, (
                    f'Accessing an attribute of the sequence "{key}" is not supported.')
                values = list(data)
                iter_name = key
                # A sequence of plain data (not functions / Base objects) is
                # passed through as a constant instead of unrolling loops.
                if values and not (callable(values[0]) or isinstance(values[0], Base)):
                    code_scope[flat_name] = data
                    data_to_replace[key] = replacement
                    continue
            # replace this for-loop
            r = _replace_this_forloop(tree=tree,
                                      iter_name=iter_name,
                                      loop_values=values,
                                      show_code=show_code,
                                      **jit_setting)
            tree, _arguments, _arg2call, _nodes, _code_scope = r
            arguments.update(_arguments)
            arg2call.update(_arg2call)
            nodes.update(_nodes)
            code_scope.update(_code_scope)

        else:  # constants
            code_scope[flat_name] = data
            data_to_replace[key] = replacement

    if has_func_def:
        # Turn the definition into a plain function that receives every
        # collected argument (default ``None``) and takes no ``**kwargs``.
        func_def = tree.body[0]
        sorted_args = sorted(arguments)
        func_def.decorator_list.clear()
        func_def.args.args.extend([ast.Name(id=a) for a in sorted_args])
        func_def.args.defaults.extend([ast.Constant(None) for _ in sorted_args])
        func_def.args.kwarg = None

    # replace words
    code = tools.ast2code(tree)
    code = tools.word_replace(code, data_to_replace, exclude_dot=True)

    return code, arguments, arg2call, nodes, code_scope
Example #4
0
    def get_integral_step(diff_eq, *args):
        """Build the source code of one exponential-Euler update of ``diff_eq``.

        The drift ``f`` is expanded with sympy. If it depends on the state
        variable, its linear coefficient ``A`` is extracted and the drift
        increment is ``(exp(A*dt) - 1) / A * f``; otherwise the increment is
        ``dt * f``. For stochastic equations a noise term ``g * dW`` is added,
        multiplied by the generated ``_<var>_linear_exp`` value.

        Parameters
        ----------
        diff_eq : object
            The differential-equation description. This function reads its
            ``var_name``, ``is_stochastic``, ``return_intermediates``,
            ``func_args`` and ``expr_names`` attributes and calls its
            ``get_f_expressions`` / ``get_g_expressions`` methods.
        *args
            Not used.

        Returns
        -------
        str
            The generated code. It ends with ``_res = <var>, ...``, and every
            name in ``func_args`` and ``expr_names`` is prefixed with ``_``.
        """
        dt = backend.get_dt()
        var_name = diff_eq.var_name
        f_exprs = diff_eq.get_f_expressions(substitute_vars=var_name)

        # All but the last f-expression are intermediate assignments.
        lines = [str(e) for e in f_exprs[:-1]]

        # The last f-expression is the drift itself; expand it symbolically.
        drift = f_exprs[-1]
        drift_expr = ast_analysis.str2sympy(drift.code).expr.expand()
        drift_symbol = sympy.Symbol(f"{drift.var_name}")
        lines.append(f'{drift_symbol.name} = {ast_analysis.sympy2str(drift_expr)}')
        state = sympy.Symbol(var_name, real=True)

        # drift increment ("df part")
        linear_symbol = sympy.Symbol(f'_{var_name}_linear')
        linear_exp_symbol = sympy.Symbol(f'_{var_name}_linear_exp')
        df_part_symbol = sympy.Symbol(f'_{var_name}_df_part')
        if not drift_expr.has(state):
            # Drift does not depend on the state: plain ``dt * f`` increment;
            # the noise factor becomes sqrt(dt).
            lines.append(f'{linear_exp_symbol.name} = sqrt({dt})')
            lines.append(
                f'{df_part_symbol.name} = {ast_analysis.sympy2str(dt * drift_symbol)}')
        else:
            # coefficient of the term linear in the state variable
            coefficient = sympy.collect(drift_expr, state, evaluate=False)[state]
            lines.append(
                f'{linear_symbol.name} = {ast_analysis.sympy2str(coefficient)}')
            lines.append(
                f'{linear_exp_symbol.name} = '
                f'{ast_analysis.sympy2str(sympy.exp(coefficient * dt))}')
            increment = (linear_exp_symbol - 1) / linear_symbol * drift_symbol
            lines.append(
                f'{df_part_symbol.name} = {ast_analysis.sympy2str(increment)}')

        # noise increment ("dg part")
        if not diff_eq.is_stochastic:
            dg_part_symbol = 0
        else:
            dw_name = f'_{var_name}_dW'
            lines.append(f'{dw_name} = _normal_like_({var_name})')
            g_exprs = diff_eq.get_g_expressions()
            lines.extend(str(e) for e in g_exprs[:-1])
            g_code = g_exprs[-1].code
            dg_part_symbol = sympy.Symbol(f'_{var_name}_dg_part')
            lines.append(f'_{var_name}_dg_part = {g_code} * {dw_name}')

        # the actual update step and the returned values
        new_value = state + df_part_symbol + dg_part_symbol * linear_exp_symbol
        lines.append(f'{var_name} = {ast_analysis.sympy2str(new_value)}')
        returned = [var_name] + diff_eq.return_intermediates
        lines.append(f'_res = {", ".join(returned)}')

        # prefix function arguments and expression names with "_"
        source = '\n'.join(lines)
        renames = {
            arg: f'_{arg}'
            for arg in diff_eq.func_args + diff_eq.expr_names
        }
        return tools.word_replace(source, renames)
Example #5
0
def class2func(cls_func, host, func_name=None, show_code=False):
    """Transform the function in a class into the ordinary function which is
    compatible with the Numba JIT compilation.

    Parameters
    ----------
    cls_func : function
        The function of the instantiated class.
    host : object
        The object owning ``cls_func``. Its ``name`` attribute is used to
        build the call/assign expressions, and its attributes are resolved
        for the data the function reads.
    func_name : str, optional
        The function name. If not given, it is taken from
        ``cls_func.__name__``.
    show_code : bool
        Whether to print the generated code and its code scope.

    Returns
    -------
    func : function
        The generated function ``new_<func_name>``, compiled with
        ``numba.jit(**NUMBA_PROFILE)``.
    calls : list of str
        The expressions whose values are passed, in order, as the arguments
        of ``func``.
    assigns : list of str
        The expressions that receive, in order, the values returned by
        ``func``.

    Raises
    ------
    errors.ModelDefError
        If an argument of ``cls_func`` is neither an attribute of ``host``
        nor a system keyword.
    ValueError
        If a delay push/pull call has an unsupported number of arguments.
    """
    _, arguments = utils.get_args(cls_func)
    func_name = cls_func.__name__ if func_name is None else func_name
    host_name = host.name

    # arguments 1: the declared arguments of the step function
    calls = []
    for arg in arguments:
        if hasattr(host, arg):
            calls.append(f'{host_name}.{arg}')
        elif arg in backend.SYSTEM_KEYWORDS:
            calls.append(arg)
        else:
            raise errors.ModelDefError(
                f'Step function "{func_name}" of {host} '
                f'define an unknown argument "{arg}" which is not '
                f'an attribute of {host} nor the system keywords '
                f'{backend.SYSTEM_KEYWORDS}.')

    # analysis
    analyzed_results = analyze_step_func(host=host, f=cls_func)
    delay_call = analyzed_results['delay_call']
    main_code = analyzed_results['code_string']
    code_scope = analyzed_results['code_scope']
    self_data_in_right = analyzed_results['self_data_in_right']
    self_data_without_index_in_left = analyzed_results[
        'self_data_without_index_in_left']
    self_data_with_index_in_left = analyzed_results[
        'self_data_with_index_in_left']
    num_indent = get_num_indent(main_code)
    data_need_pass = sorted(
        set(self_data_in_right + self_data_with_index_in_left))
    data_need_return = self_data_without_index_in_left

    # check delay: replace delay push/pull calls by jitted helper functions
    replaces_early = {}
    replaces_later = {}
    for delay_ in delay_call.values():
        num_delay_args = len(delay_['args'] + delay_['keywords'])
        if delay_['type'] == 'push':
            if num_delay_args == 2:
                func = numba.njit(delay.push_type2)
            elif num_delay_args == 1:
                func = numba.njit(delay.push_type1)
            else:
                raise ValueError(f'Unknown delay push. {delay_}')
        else:
            if num_delay_args == 1:
                func = numba.njit(delay.pull_type1)
            elif num_delay_args == 0:
                func = numba.njit(delay.pull_type0)
            else:
                raise ValueError(f'Unknown delay pull. {delay_}')
        delay_call_name = delay_['func']
        flat_delay_name = delay_call_name.replace('.', '_')
        # The delay method is provided through ``code_scope``, so it must not
        # be passed as an argument. ``data_need_pass`` is de-duplicated, so
        # guard the removal in case several delay calls share the method.
        if delay_call_name in data_need_pass:
            data_need_pass.remove(delay_call_name)
        data_need_pass.extend(delay_['data_need_pass'])
        replaces_early[delay_['org_call']] = delay_['rep_call']
        replaces_later[delay_call_name] = flat_delay_name
        code_scope[flat_delay_name] = func
    for target, dest in replaces_early.items():
        main_code = main_code.replace(target, dest)

    # arguments 2: data need pass (callables go to the code scope instead)
    new_args = list(arguments)
    for data in sorted(set(data_need_pass)):
        splits = data.split('.')
        replaces_later[data] = data.replace('.', '_')
        obj = host
        for attr in splits[1:]:
            obj = getattr(obj, attr)
        if callable(obj):
            code_scope[data.replace('.', '_')] = obj
            continue
        new_args.append(data.replace('.', '_'))
        calls.append('.'.join([host_name] + splits[1:]))

    # data need return
    assigns = []
    returns = []
    for data in data_need_return:
        splits = data.split('.')
        assigns.append('.'.join([host_name] + splits[1:]))
        returns.append(data.replace('.', '_'))
        replaces_later[data] = data.replace('.', '_')

    # code scope
    code_scope[host_name] = host

    # codes
    header = f'def new_{func_name}({", ".join(new_args)}):\n'
    main_code = header + tools.indent(main_code, spaces_per_tab=2)
    if returns:
        main_code += f'\n{" " * num_indent + "  "}return {", ".join(returns)}'
    main_code = tools.word_replace(main_code, replaces_later)
    if show_code:
        print(main_code)
        print(code_scope)
        print()

    # recompile: execute the generated source inside ``code_scope``
    exec(compile(main_code, '', 'exec'), code_scope)
    func = code_scope[f'new_{func_name}']
    func = numba.jit(**NUMBA_PROFILE)(func)
    return func, calls, assigns