Exemplo n.º 1
0
def cse(expr):
    """Simplify a complicated sympy expression into a list of statements.

    Every ``Sum`` atom of ``expr`` is handed, together with ``expr`` itself,
    to sympy's common-subexpression elimination (``sympy_cse``). Each
    extracted subexpression is assigned to a freshly created variable; when
    the subexpression still depends on symbols that are not free in ``expr``
    the variable is an ``Indexed`` one, an ``empty`` allocation is emitted
    for it and its assignment is wrapped in a ``For`` loop.

    Parameters
    ----------
    expr : sympy expression
        The expression to split.

    Returns
    -------
    list
        ``[expr]`` when ``expr`` contains no ``Sum``. Otherwise the
        ``Import('empty', 'numpy')`` statement, followed by the allocation
        statements, followed by one assignment per extracted subexpression
        and finally the assignment of the rewritten ``expr``.

    Raises
    ------
    ValueError
        If no range tuple can be found for an index of an indexed
        temporary.
    """
    sums = list(expr.atoms(Sum))
    if not sums:
        return [expr]
    sums.append(expr)
    replacements, _ = sympy_cse(sums)

    vars_old, stmts = map(list, zip(*replacements))

    # Names that are already bound outside the extracted subexpressions.
    # Copy the set so that whatever ``free_symbols`` returned is not mutated.
    free_gl = set(expr.free_symbols)
    free_gl.update(expr.atoms(IndexedBase))
    free_gl.update(vars_old)
    stmts.append(expr)

    # One new variable per extracted subexpression (every statement but the
    # trailing ``expr``). Symbols that are not bound globally become the
    # indices of the new variable.
    vars_new = []
    for stmt in stmts[:-1]:
        free = list(stmt.free_symbols.difference(free_gl))
        var = create_variable(stmt)
        if free:
            var = IndexedBase(var)[free]
        vars_new.append(var)

    # Order matters: each substitution feeds the next statement and the
    # final expression.
    for i, (old, new) in enumerate(zip(vars_old, vars_new)):
        stmts[i + 1] = stmts[i + 1].replace(old, new)
        stmts[-1] = stmts[-1].replace(stmts[i], new)

    allocate = []
    for i, new in enumerate(vars_new):
        stmts[i] = pyccel_sum(Assign(new, stmts[i]))
        if not isinstance(new, Indexed):
            continue

        ind = new.indices
        # The ranges are looked up in the tuples of the *next* statement,
        # each read as (index, lower, upper).
        ranges = list(stmts[i + 1].atoms(Tuple))
        size = [None] * len(ind)
        for j, k in enumerate(ind):
            for t in ranges:
                if k == t[0]:
                    size[j] = t[2] - t[1] + 1
                    break
        # Test for ``None`` explicitly: a found extent may legitimately be
        # falsy (e.g. zero), which must not be reported as "not found".
        if any(s is None for s in size):
            raise ValueError('Unable to find range of index')

        # NOTE(review): only the first index/extent is used for the
        # allocation and the loop, as in the original implementation.
        var = Symbol(str(new.base))
        allocate.append(Assign(var, Function('empty')(size[0])))
        stmts[i] = For(ind[0],
                       Function('range')(size[0]), [stmts[i]],
                       strict=False)

    lhs = create_variable(expr)
    stmts[-1] = Assign(lhs, stmts[-1])
    imports = [Import('empty', 'numpy')]
    return imports + allocate + stmts
Exemplo n.º 2
0
def lambdify(expr, args):
    """Build a ``FunctionDef`` from a ``Lambda`` or from a function given as source.

    Parameters
    ----------
    expr : object
        Attached as fst to the ``Return`` statement in the ``Lambda`` case,
        and passed to ``create_variable`` when a result variable has to be
        created in the other case.
    args : Lambda or function-like object
        Either a ``Lambda`` (its ``expr`` and ``variables`` are used), or an
        object providing ``body``, ``name``, ``arguments``, ``fst`` and
        ``decorators``; ``body[0]`` is compiled and executed to obtain the
        callable named ``name``.

    Returns
    -------
    FunctionDef
        A function named ``'lambda'`` in the ``Lambda`` case, otherwise a
        function named after ``args.name`` whose body is the result of
        ``cse`` applied to the evaluated expression, followed by a
        ``Return``.
    """
    if isinstance(args, Lambda):
        returned = Return(args.expr)
        returned.set_fst(expr)
        return FunctionDef('lambda', args.variables, [], [returned])

    # NOTE(security): this compiles and executes ``args.body[0]``. It must
    # only ever be fed source that is trusted.
    code = compile(args.body[0], '', 'single')
    namespace = {}
    eval(code, namespace)
    f_name = str(args.name)
    py_func = namespace[f_name]

    # Call the freshly defined function on its own arguments to obtain the
    # expression it computes.
    new_args = args.arguments
    new_expr = py_func(*new_args)

    stmts = cse(new_expr)
    if isinstance(stmts[-1], (Assign, GC)):
        var = stmts[-1].lhs
    else:
        # The last statement is a bare expression: bind it to a variable so
        # that it can be returned.
        var = create_variable(expr)
        stmts[-1] = Assign(var, stmts[-1])
    stmts.append(Return([var]))
    set_fst(stmts, args.fst)
    return FunctionDef(f_name, new_args, [], stmts,
                       decorators=args.decorators)
Exemplo n.º 3
0
    def get_new_variable(self, prefix=None):
        """
        Create a new sympy Symbol whose name is not already in use.

        When a prefix is given it is used to build the name and the dummy
        counter is left untouched. When the prefix is None the standard
        prefix is used together with the dummy counter, which is updated so
        that the next request for this common case starts from the right
        value.

          Parameters
          ----------
          prefix   : str

          Returns
          -------
          variable : sympy.Symbol
        """
        if prefix is None:
            # Common case: continue from, and then remember, the counter.
            new_symbol, next_counter = create_variable(
                self._used_names, prefix, counter=self._dummy_counter)
            self._dummy_counter = next_counter
            return new_symbol

        new_symbol, _ = create_variable(self._used_names, prefix)
        return new_symbol
Exemplo n.º 4
0
    def _visit_FunctionDef(self, stmt):
        """
        Build a FunctionDef from a function definition node.

        The name, arguments and decorators of ``stmt`` are visited first.
        The decorators then drive the rest of the processing:

        * ``bypass``: an EmptyNode is returned immediately.
        * ``stack_array`` / ``allow_negative_index``: the decorator
          arguments are flattened into a tuple of strings.
        * ``template``: a ``#$ header template ...`` line is built and parsed
          for every template decorator.
        * ``types``: a ``#$ header ...`` line is built and parsed for every
          types decorator.
        * ``sympy`` / ``python``: a SympyFunction / PythonFunction is
          inserted with ``self.insert_function`` and an EmptyNode is
          returned.
        * ``pure`` / ``elemental`` / ``private``: set the matching flags.

        Finally the result variables are deduced from the Return statements
        found in the visited body.

          Parameters
          ----------
          stmt : function definition node
                 Must provide name, args, decorator_list, body, lineno and
                 col_offset.

          Returns
          -------
          FunctionDef or EmptyNode
        """

        #  TODO check all inputs and which ones should be treated in stage 1 or 2

        name = self._visit(stmt.name)
        # Remove any single quotes from the visited name
        name = name.replace("'", '')

        arguments = self._visit(stmt.args)

        # Defaults forwarded to the FunctionDef built at the end; only the
        # is_* flags, headers and templates are modified below.
        local_vars = []
        global_vars = []
        headers = []
        templates = {}
        hide = False
        kind = 'function'
        is_pure = False
        is_elemental = False
        is_private = False
        imports = []

        def fill_types(ls):
            # Convert decorator arguments to a list of type names: a Symbol
            # contributes its name, a LiteralString its text without the
            # surrounding quotes. Anything else is reported as an error and
            # is not added to the list.
            container = []
            for arg in ls:
                if isinstance(arg, Symbol):
                    arg = arg.name
                    container.append(arg)
                elif isinstance(arg, LiteralString):
                    arg = str(arg)
                    arg = arg.strip("'").strip('"')
                    container.append(arg)
                else:
                    msg = 'Invalid argument of type {} passed to types decorator'.format(
                        type(arg))
                    errors.report(msg,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='error')
            return container

        # Group the visited decorators in a dict of lists, so that a
        # decorator that appears several times keeps all its occurrences.
        # NOTE(review): for a non-Symbol decorator the key is str(type(d));
        # the lookups below presumably rely on this being the bare decorator
        # name - confirm.
        decorators = {}
        for d in self._visit(stmt.decorator_list):
            tmp_var = str(d) if isinstance(d, Symbol) else str(type(d))
            if tmp_var in decorators:
                decorators[tmp_var] += [d]
            else:
                decorators[tmp_var] = [d]

        if 'bypass' in decorators:
            return EmptyNode()

        # Replace the list of decorator objects by the flat tuple of their
        # arguments converted to strings
        if 'stack_array' in decorators:
            decorators['stack_array'] = tuple(
                str(b) for a in decorators['stack_array'] for b in a.args)

        if 'allow_negative_index' in decorators:
            decorators['allow_negative_index'] = tuple(
                str(b) for a in decorators['allow_negative_index']
                for b in a.args)

        # extract the templates
        if 'template' in decorators:
            for comb_types in decorators['template']:
                cache.clear_cache()
                types = []
                # Exactly two arguments (name and types) are expected.
                # NOTE(review): the code below indexes args[0] and args[1]
                # after this report; whether a report of severity 'error'
                # stops execution is not visible here - confirm.
                if len(comb_types.args) != 2:
                    msg = 'Number of Arguments provided to the template decorator is not valid'
                    errors.report(msg,
                                  symbol=comb_types,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='error')

                # Keyword arguments may only be called 'name' or 'types'
                for i in comb_types.args:
                    if isinstance(i,
                                  ValuedArgument) and not i.name in ('name',
                                                                     'types'):
                        msg = 'Argument provided to the template decorator is not valid'
                        errors.report(msg,
                                      symbol=comb_types,
                                      bounding_box=(stmt.lineno,
                                                    stmt.col_offset),
                                      severity='error')
                # Two keyword arguments may come in any order; otherwise the
                # first argument is the name and the second the types.
                if all(isinstance(i, ValuedArgument) for i in comb_types.args):
                    tp_name, ls = (comb_types.args[0].value, comb_types.args[1].value) if\
                            comb_types.args[0].name == 'name' else\
                            (comb_types.args[1].value, comb_types.args[0].value)
                else:
                    tp_name = comb_types.args[0]
                    ls = comb_types.args[1]
                    ls = ls.value if isinstance(ls, ValuedArgument) else ls
                try:
                    tp_name = str(tp_name)
                    ls = ls if isinstance(ls, PythonTuple) else list(ls)
                except TypeError:
                    msg = 'Argument provided to the template decorator is not valid'
                    errors.report(msg,
                                  symbol=comb_types,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='fatal')

                types = fill_types(ls)

                # Header line of the form: #$ header template name(t1|t2|...)
                txt = '#$ header template ' + str(tp_name)
                txt += '(' + '|'.join(types) + ')'
                # A duplicated template only triggers a warning; the later
                # definition replaces the earlier one in the dict.
                if tp_name in templates:
                    msg = 'The template "{}" is duplicated'.format(tp_name)
                    errors.report(msg,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='warning')

                templates[tp_name] = hdr_parse(stmts=txt)

        # extract the types to construct a header
        if 'types' in decorators:
            for comb_types in decorators['types']:

                cache.clear_cache()
                results = []
                ls = comb_types.args

                # A trailing keyword argument is only accepted if it is
                # called 'results'; it provides the result types. In both
                # cases the remaining arguments are the argument types.
                if len(ls) > 0 and isinstance(ls[-1], ValuedArgument):
                    arg_name = ls[-1].name
                    if not arg_name == 'results':
                        msg = 'Argument "{}" provided to the types decorator is not valid'.format(
                            arg_name)
                        errors.report(msg,
                                      symbol=comb_types,
                                      bounding_box=(stmt.lineno,
                                                    stmt.col_offset),
                                      severity='error')
                    else:
                        container = ls[-1].value
                        container = container if isinstance(
                            container, PythonTuple) else [container]
                        results = fill_types(container)
                    types = fill_types(ls[:-1])
                else:
                    types = fill_types(ls)

                # Header line of the form:
                # #$ header name(t1,t2,...) results(r1,r2,...)
                txt = '#$ header ' + name
                txt += '(' + ','.join(types) + ')'

                if results:
                    txt += ' results(' + ','.join(results) + ')'

                header = hdr_parse(stmts=txt)
                if name in self.namespace.static_functions:
                    header = header.to_static()
                headers += [header]

        body = stmt.body

        # Functions decorated with sympy or python are not visited: they are
        # stored with the string form of stmt as their body and replaced by
        # an EmptyNode.
        if 'sympy' in decorators.keys():
            # TODO maybe we should run pylint here
            stmt.decorators.pop()
            func = SympyFunction(name, arguments, [], [stmt.__str__()])
            func.set_fst(stmt)
            self.insert_function(func)
            return EmptyNode()

        elif 'python' in decorators.keys():

            # TODO maybe we should run pylint here

            stmt.decorators.pop()
            func = PythonFunction(name, arguments, [], [stmt.__str__()])
            func.set_fst(stmt)
            self.insert_function(func)
            return EmptyNode()

        else:
            body = self._visit(body)

        if 'pure' in decorators.keys():
            is_pure = True

        if 'elemental' in decorators.keys():
            is_elemental = True
            # More than one argument is reported as an error
            if len(arguments) > 1:
                errors.report(FORTRAN_ELEMENTAL_SINGLE_ARGUMENT,
                              symbol=decorators['elemental'],
                              bounding_box=(stmt.lineno, stmt.col_offset),
                              severity='error')

        if 'private' in decorators.keys():
            is_private = True

        # Deduce the results from the Return statements of the body. All
        # Return statements must provide the same number of values.
        returns = [i.expr for i in _atomic(body, cls=Return)]
        assert all(len(i) == len(returns[0]) for i in returns)
        results = []
        result_counter = 1
        # i holds the values returned at one position by every Return.
        # A new 'Out' variable is created when these values differ, when
        # they are not a Symbol, or when the Symbol has the same name as an
        # argument. Otherwise the returned Symbol is the result.
        for i in zip(*returns):
            if not all(i[0] == j for j in i) or not isinstance(i[0], Symbol):
                result_name, result_counter = create_variable(
                    self._used_names, prefix='Out', counter=result_counter)
                results.append(result_name)
            elif isinstance(i[0], Symbol) and any(i[0].name == x.name
                                                  for x in arguments):
                result_name, result_counter = create_variable(
                    self._used_names, prefix='Out', counter=result_counter)
                results.append(result_name)
            else:
                results.append(i[0])

        func = FunctionDef(name,
                           arguments,
                           results,
                           body,
                           local_vars=local_vars,
                           global_vars=global_vars,
                           hide=hide,
                           kind=kind,
                           is_pure=is_pure,
                           is_elemental=is_elemental,
                           is_private=is_private,
                           imports=imports,
                           decorators=decorators,
                           headers=headers,
                           templates=templates)

        func.set_fst(stmt)
        return func