def cse(expr):
    """
    Simplify a complicated sympy expression into a list of statements
    using the sympy ``cse`` function.

    Only expressions containing ``Sum`` atoms are processed. Every common
    sub-expression found by ``sympy_cse`` is bound to a new variable. When
    the sub-expression depends on symbols that are not free in ``expr``,
    that variable is an ``IndexedBase`` element indexed by those symbols,
    and an allocation (``empty``) plus a ``For`` loop are generated for it.

    Parameters
    ----------
    expr : sympy expression
        The expression to simplify.

    Returns
    -------
    list
        ``[expr]`` unchanged when there is nothing to extract. Otherwise,
        in order: an ``Import`` of ``empty`` from ``numpy``, the allocation
        statements, the statements computing each sub-expression, and the
        ``Assign`` of the simplified expression.

    Raises
    ------
    ValueError
        If the range of an index of an indexed temporary cannot be found.
    """
    sums = list(expr.atoms(Sum))
    if not sums:
        return [expr]
    sums.append(expr)
    replacements, _ = sympy_cse(sums)
    # Nothing was factored out: unpacking zip(*[]) below would raise, so
    # behave exactly like the "no Sum" case above.
    if not replacements:
        return [expr]

    vars_old, stmts = map(list, zip(*replacements))

    # Symbols "global" to the expression must never become indices of the
    # temporaries. Work on a copy so that nothing owned by expr is mutated.
    free_gl = set(expr.free_symbols)
    free_gl.update(expr.atoms(IndexedBase))
    free_gl.update(vars_old)
    stmts.append(expr)
    n_tmp = len(stmts) - 1

    # One new variable per sub-expression, indexed by its local free symbols.
    # NOTE(review): create_variable is called here with a single argument,
    # whereas elsewhere in this file it takes (used_names, prefix, counter=...)
    # and returns a (symbol, counter) pair. Confirm which signature applies.
    vars_new = []
    for sub_expr in stmts[:-1]:
        local_free = list(sub_expr.free_symbols.difference(free_gl))
        var = create_variable(sub_expr)
        if local_free:
            var = IndexedBase(var)[local_free]
        vars_new.append(var)

    # Substitute the cse dummies (and the sub-expressions in the final
    # expression) by the new variables. Order matters: stmts is updated in place.
    for i in range(n_tmp):
        stmts[i + 1] = stmts[i + 1].replace(vars_old[i], vars_new[i])
        stmts[-1] = stmts[-1].replace(stmts[i], vars_new[i])

    allocate = []
    for i in range(n_tmp):
        stmts[i] = pyccel_sum(Assign(vars_new[i], stmts[i]))
        if isinstance(vars_new[i], Indexed):
            ind = vars_new[i].indices
            # stmts[i + 1] is not yet wrapped: look for (index, start, end)
            # tuples in it to deduce the extent of each index.
            ranges = list(stmts[i + 1].atoms(Tuple))
            size = [None] * len(ind)
            for j, k in enumerate(ind):
                for t in ranges:
                    if k == t[0]:
                        size[j] = t[2] - t[1] + 1
                        break
            # Test for None explicitly: a zero-sized range is a valid size.
            if any(s is None for s in size):
                raise ValueError('Unable to find range of index')
            var = Symbol(str(vars_new[i].base))
            # NOTE(review): only the first index is used for the allocation
            # and the loop, even when the temporary has several indices.
            allocate.append(Assign(var, Function('empty')(size[0])))
            stmts[i] = For(ind[0], Function('range')(size[0]), [stmts[i]],
                           strict=False)

    lhs = create_variable(expr)
    stmts[-1] = Assign(lhs, stmts[-1])
    imports = [Import('empty', 'numpy')]
    return imports + allocate + stmts
def lambdify(expr, args):
    """
    Convert a lambda, or a function whose body is source code, into a
    FunctionDef.

    Parameters
    ----------
    expr :
        Syntax node. In the ``Lambda`` case it is attached (via ``set_fst``)
        to the generated ``Return``. Otherwise it is only passed to
        ``create_variable`` when a result variable must be created.
    args : Lambda or function-like object
        Either a ``Lambda``, or an object exposing ``body``, ``name``,
        ``arguments``, ``fst`` and ``decorators``. Presumably a
        ``SympyFunction`` whose ``body[0]`` is the function's source text;
        TODO confirm.

    Returns
    -------
    FunctionDef
        For a ``Lambda``: a function named ``'lambda'`` returning its
        expression. Otherwise: a function named ``args.name`` whose body is
        the ``cse`` decomposition of the expression returned by calling the
        compiled source on ``args.arguments``, followed by a ``Return``.
    """
    if isinstance(args, Lambda):
        ret = Return(args.expr)
        ret.set_fst(expr)
        return FunctionDef('lambda', args.variables, [], [ret])

    # NOTE(review): this executes source text taken from the program being
    # translated. That is only acceptable because it is the user's own code;
    # never route untrusted input here.
    namespace = {}
    # 'exec' mode accepts any module-level source ('single' rejects more than
    # one statement), and exec is the builtin meant for running statements.
    exec(compile(args.body[0], '', 'exec'), namespace)
    f_name = str(args.name)
    sympy_func = namespace[f_name]
    new_args = args.arguments
    new_expr = sympy_func(*new_args)

    stmts = cse(new_expr)
    if isinstance(stmts[-1], (Assign, GC)):
        var = stmts[-1].lhs
    else:
        # NOTE(review): create_variable is called with a single argument here,
        # whereas elsewhere in this file it takes (used_names, prefix,
        # counter=...) and returns a (symbol, counter) pair. Confirm.
        var = create_variable(expr)
        stmts[-1] = Assign(var, stmts[-1])
    stmts.append(Return([var]))
    set_fst(stmts, args.fst)
    return FunctionDef(f_name, new_args, [], stmts, decorators=args.decorators)
def get_new_variable(self, prefix=None):
    """
    Create a new sympy Symbol whose name is not in ``self._used_names``.

    When ``prefix`` is given, it is forwarded to ``create_variable`` and the
    counter returned alongside the symbol is discarded. When ``prefix`` is
    None, the standard prefix is used. ``self._dummy_counter`` is passed as
    the starting counter and then replaced by the counter returned by
    ``create_variable``, so the next request of this common case continues
    from there.

    Parameters
    ----------
    prefix : str, optional
        Prefix of the new name; None selects the standard prefix.

    Returns
    -------
    variable : sympy.Symbol
        The newly created symbol.
    """
    if prefix is None:
        # Common case: resume numbering from the stored dummy counter.
        new_symbol, self._dummy_counter = create_variable(
            self._used_names, prefix, counter=self._dummy_counter)
        return new_symbol

    new_symbol, _ = create_variable(self._used_names, prefix)
    return new_symbol
def _visit_FunctionDef(self, stmt):
    """
    Build the syntactic representation of a function definition.

    Groups the decorators of ``stmt`` by name and handles the following:
    ``bypass``, ``stack_array``, ``allow_negative_index``, ``template``,
    ``types``, ``sympy``, ``python``, ``pure``, ``elemental`` and ``private``.
    Builds the template and header objects described by the decorators,
    visits the body and deduces the result variables from the ``Return``
    statements of the body.

    Parameters
    ----------
    stmt :
        Function definition node. Presumably an ``ast.FunctionDef`` (it is
        accessed through ``name``, ``args``, ``body``, ``decorator_list``,
        ``lineno`` and ``col_offset``); TODO confirm.

    Returns
    -------
    FunctionDef or EmptyNode
        ``EmptyNode`` for ``@bypass``, ``@sympy`` and ``@python`` functions.
        The latter two are registered through ``insert_function`` instead.
    """
    # TODO check all inputs and which ones should be treated in stage 1 or 2

    name = self._visit(stmt.name)
    name = name.replace("'", '')

    arguments = self._visit(stmt.args)

    local_vars = []
    global_vars = []
    headers = []
    templates = {}
    hide = False
    kind = 'function'
    imports = []

    bounding_box = (stmt.lineno, stmt.col_offset)

    def fill_types(ls):
        """Convert decorator arguments (Symbols or string literals) to type names."""
        container = []
        for arg in ls:
            if isinstance(arg, Symbol):
                container.append(arg.name)
            elif isinstance(arg, LiteralString):
                container.append(str(arg).strip("'").strip('"'))
            else:
                msg = 'Invalid argument of type {} passed to types decorator'.format(
                    type(arg))
                errors.report(msg,
                              bounding_box=bounding_box,
                              severity='error')
        return container

    # Group decorators by name. Non-Symbol decorators are keyed by
    # str(type(d)); presumably applied sympy Functions, whose class name is
    # the decorator name (e.g. 'types'). TODO confirm.
    decorators = {}
    for d in self._visit(stmt.decorator_list):
        tmp_var = str(d) if isinstance(d, Symbol) else str(type(d))
        decorators.setdefault(tmp_var, []).append(d)

    if 'bypass' in decorators:
        return EmptyNode()

    # These decorators only carry a list of names: flatten their arguments.
    for key in ('stack_array', 'allow_negative_index'):
        if key in decorators:
            decorators[key] = tuple(
                str(b) for a in decorators[key] for b in a.args)

    # extract the templates
    for comb_types in decorators.get('template', ()):
        cache.clear_cache()
        n_args_valid = len(comb_types.args) == 2
        if not n_args_valid:
            msg = 'Number of Arguments provided to the template decorator is not valid'
            errors.report(msg,
                          symbol=comb_types,
                          bounding_box=bounding_box,
                          severity='error')

        for arg in comb_types.args:
            if isinstance(arg, ValuedArgument) and arg.name not in ('name', 'types'):
                msg = 'Argument provided to the template decorator is not valid'
                errors.report(msg,
                              symbol=comb_types,
                              bounding_box=bounding_box,
                              severity='error')

        # An 'error' report does not abort, and the indexing below needs
        # exactly two arguments: skip this malformed template.
        if not n_args_valid:
            continue

        if all(isinstance(arg, ValuedArgument) for arg in comb_types.args):
            first, second = comb_types.args
            if first.name == 'name':
                tp_name, ls = first.value, second.value
            else:
                tp_name, ls = second.value, first.value
        else:
            tp_name = comb_types.args[0]
            ls = comb_types.args[1]
            ls = ls.value if isinstance(ls, ValuedArgument) else ls

        try:
            tp_name = str(tp_name)
            ls = ls if isinstance(ls, PythonTuple) else list(ls)
        except TypeError:
            msg = 'Argument provided to the template decorator is not valid'
            errors.report(msg,
                          symbol=comb_types,
                          bounding_box=bounding_box,
                          severity='fatal')

        types = fill_types(ls)

        txt = '#$ header template ' + str(tp_name)
        txt += '(' + '|'.join(types) + ')'
        if tp_name in templates:
            msg = 'The template "{}" is duplicated'.format(tp_name)
            errors.report(msg,
                          bounding_box=bounding_box,
                          severity='warning')

        templates[tp_name] = hdr_parse(stmts=txt)

    # extract the types to construct a header
    for comb_types in decorators.get('types', ()):
        cache.clear_cache()
        results = []
        ls = comb_types.args

        # A trailing keyword argument may only be `results=...`.
        if len(ls) > 0 and isinstance(ls[-1], ValuedArgument):
            arg_name = ls[-1].name
            if arg_name != 'results':
                msg = 'Argument "{}" provided to the types decorator is not valid'.format(
                    arg_name)
                errors.report(msg,
                              symbol=comb_types,
                              bounding_box=bounding_box,
                              severity='error')
            else:
                container = ls[-1].value
                container = container if isinstance(
                    container, PythonTuple) else [container]
                results = fill_types(container)
            types = fill_types(ls[:-1])
        else:
            types = fill_types(ls)

        txt = '#$ header ' + name
        txt += '(' + ','.join(types) + ')'
        if results:
            txt += ' results(' + ','.join(results) + ')'

        header = hdr_parse(stmts=txt)
        if name in self.namespace.static_functions:
            header = header.to_static()

        headers.append(header)

    # NOTE(review): an ast.FunctionDef exposes `decorator_list`, not
    # `decorators`, and str(stmt) is not the function's source text. The
    # sympy/python paths below look stale; confirm the type of `stmt`.
    if 'sympy' in decorators:
        # TODO maybe we should run pylint here
        stmt.decorators.pop()
        func = SympyFunction(name, arguments, [], [stmt.__str__()])
        func.set_fst(stmt)
        self.insert_function(func)
        return EmptyNode()

    if 'python' in decorators:
        # TODO maybe we should run pylint here
        stmt.decorators.pop()
        func = PythonFunction(name, arguments, [], [stmt.__str__()])
        func.set_fst(stmt)
        self.insert_function(func)
        return EmptyNode()

    body = self._visit(stmt.body)

    is_pure = 'pure' in decorators

    is_elemental = 'elemental' in decorators
    if is_elemental and len(arguments) > 1:
        errors.report(FORTRAN_ELEMENTAL_SINGLE_ARGUMENT,
                      symbol=decorators['elemental'],
                      bounding_box=bounding_box,
                      severity='error')

    is_private = 'private' in decorators

    returns = [r.expr for r in _atomic(body, cls=Return)]
    # All return statements must return the same number of values.
    assert all(len(r) == len(returns[0]) for r in returns)
    results = []
    result_counter = 1
    for values in zip(*returns):
        first = values[0]
        # A new 'Out' variable is needed when the return statements
        # disagree, when the value is not a plain Symbol, or when it is one
        # of the function's arguments.
        needs_new_name = (not all(first == v for v in values)
                          or not isinstance(first, Symbol)
                          or any(first.name == x.name for x in arguments))
        if needs_new_name:
            result_name, result_counter = create_variable(
                self._used_names, prefix='Out', counter=result_counter)
            results.append(result_name)
        else:
            results.append(first)

    func = FunctionDef(
        name,
        arguments,
        results,
        body,
        local_vars=local_vars,
        global_vars=global_vars,
        hide=hide,
        kind=kind,
        is_pure=is_pure,
        is_elemental=is_elemental,
        is_private=is_private,
        imports=imports,
        decorators=decorators,
        headers=headers,
        templates=templates)

    func.set_fst(stmt)
    return func