def lambdify(expr, args):
    """Build a ``FunctionDef`` from a lambda-like description.

    Two input kinds are handled:

    * ``args`` is a ``Lambda``: its expression is wrapped in a ``Return``
      and a function named ``'lambda'`` taking ``args.variables`` is built.
    * otherwise ``args`` is expected to expose ``body``, ``name``,
      ``arguments``, ``fst`` and ``decorators``.  ``args.body[0]`` is
      compiled and executed, the callable it defines under ``args.name`` is
      invoked on ``args.arguments`` and the outcome is passed through
      ``cse`` to produce the statements of the new function.

    Parameters
    ----------
    expr :
        Node given to ``set_fst`` on the generated ``Return`` (``Lambda``
        path) or to ``create_variable`` when a result variable has to be
        introduced (compiled path).
    args :
        A ``Lambda`` or a function-like node as described above.

    Returns
    -------
    FunctionDef
        The generated function definition.
    """
    if isinstance(args, Lambda):
        returned = Return(args.expr)
        returned.set_fst(expr)
        return FunctionDef('lambda', args.variables, [], [returned])

    # NOTE(review): this executes the source held in args.body[0]; it must
    # only ever be fed code that is trusted -- confirm against the callers.
    code = compile(args.body[0], '', 'single')
    namespace = {}
    eval(code, namespace)

    f_name = str(args.name)
    python_func = namespace[f_name]
    new_args = args.arguments
    new_expr = python_func(*new_args)

    stmts = cse(new_expr)
    last = stmts[-1]
    if isinstance(last, (Assign, GC)):
        # The last statement already stores the value to return.
        var = last.lhs
    else:
        # Otherwise bind the trailing expression to a fresh variable.
        var = create_variable(expr)
        stmts[-1] = Assign(var, last)
    stmts += [Return([var])]
    set_fst(stmts, args.fst)

    return FunctionDef(f_name, new_args, [], stmts,
                       decorators=args.decorators)
def as_static_function_call(func, mod_name, name=None):
    """Wrap ``func`` in a static function that forwards to the module version.

    The returned function keeps the name of ``func``; its body is a single
    call to ``func`` imported from ``mod_name`` under the alias
    ``'mod_' + func.name``.

    Parameters
    ----------
    func : FunctionDef
        Function to wrap.
    mod_name : str
        Name of the module ``func`` is imported from.
    name : optional
        Forwarded unchanged to ``as_static_function``.

    Returns
    -------
    The value returned by ``as_static_function`` for the forwarding function.
    """
    assert isinstance(func, FunctionDef)
    assert isinstance(mod_name, str)

    # Alias under which the original function is imported.
    alias = func.clone('mod_' + str(func.name))
    module_import = Import(target=AsName(func.name, alias.name),
                           source=mod_name)

    call_args = sanitize_arguments(func.arguments)
    invocation = FunctionCall(alias, call_args)

    # A bare call when nothing is returned, an assignment otherwise; a single
    # result is assigned directly rather than as a one-element sequence.
    outputs = func.results
    if len(outputs) == 1:
        outputs = outputs[0]
    if len(func.results) == 0:
        statement = invocation
    else:
        statement = Assign(outputs, invocation)

    forwarding_func = FunctionDef(
        func.name,
        list(call_args),
        func.results,
        [statement],
        arguments_inout=func.arguments_inout,
        functions=func.functions,
        interfaces=func.interfaces,
        imports=[module_import],
        doc_string=func.doc_string,
    )

    # Make it compatible with C.
    return as_static_function(forwarding_func, name)
def _create_wrapper_check(self, check_var, parse_args, types_dict, used_names, func_name):
    """Build the function that encodes the argument types into ``check_var``.

    For every key of ``types_dict`` an ``If`` is generated with one branch
    per accepted type (ordered by precision).  A matching branch adds the
    type flag, shifted into the 4-bit group of that argument, to
    ``check_var``; the fallback branch sets a ``PyExc_TypeError`` and
    returns 0.

    Parameters
    ----------
    check_var :
        Variable accumulating the flags; also the result of the function.
    parse_args :
        Arguments of the generated function.
    types_dict : dict
        Maps each argument to a collection of
        ``(variable, check, flag_value)`` entries.
    used_names : set
        Names already in use; combined with ``self._global_names`` when the
        function name is chosen.
    func_name :
        Not used by this method.

    Returns
    -------
    FunctionDef
        The type-check function.
    """
    def describe(var):
        # Booleans are reported without a bit width.
        if isinstance(var.dtype, NativeBool):
            return str_dtype(var.dtype)
        return '{} bit {}'.format(var.precision * 8, str_dtype(var.dtype))

    per_argument_checks = []
    # The first argument occupies the highest 4-bit group.
    shift = (len(types_dict) - 1) * 4
    for arg in types_dict:
        candidates = sorted(types_dict[arg], key=lambda entry: entry[0].precision)

        reported_name = ""
        branches = []
        accepted = []
        for entry in candidates:
            reported_name = entry[0].name
            branches.append((entry[1], [AugAssign(check_var, '+', entry[2] << shift)]))
            accepted.append(entry[0])
        shift -= 4

        expected = ' or '.join([describe(v) for v in accepted])
        failure = [
            PyErr_SetString('PyExc_TypeError',
                            '"{} must be {}"'.format(reported_name, expected)),
            Return([LiteralInteger(0)]),
        ]
        branches.append((LiteralTrue(), failure))
        per_argument_checks.append(If(*branches))

    check_func_body = [Assign(check_var, LiteralInteger(0))] + per_argument_checks
    check_func_body.append(Return([check_var]))

    # Reserve a globally unique name for the check function.
    check_func_name = self.get_new_name(used_names.union(self._global_names), 'type_check')
    self._global_names.add(check_func_name)

    return FunctionDef(name=check_func_name,
                       arguments=parse_args,
                       results=[check_var],
                       body=check_func_body,
                       local_vars=[])
def __new__(cls, name, arguments, results, body, **kwargs):
    """Create the node, storing the ``generators`` and ``m_results`` extras.

    ``generators`` (default ``{}``) and ``m_results`` (default ``[]``) are
    removed from ``kwargs`` before the remaining arguments are forwarded to
    ``FunctionDef.__new__``; they are then stored on the new object as
    ``_generators`` and ``_m_results``.
    """
    extra_generators = kwargs.pop('generators', {})
    extra_m_results = kwargs.pop('m_results', [])

    instance = FunctionDef.__new__(cls, name, arguments, results, body,
                                   **kwargs)
    instance._generators, instance._m_results = extra_generators, extra_m_results
    return instance
def as_static_function_call(func):
    """Return a static function whose body is a single call to ``func``.

    A new ``FunctionDef`` with the same name, sanitized arguments and no
    results is built around ``FunctionCall(func, args)`` and then converted
    with ``as_static_function``.

    Parameters
    ----------
    func : FunctionDef
        Function to call from the generated wrapper.
    """
    assert (isinstance(func, FunctionDef))

    call_args = sanitize_arguments(func.arguments)
    inner_functions = func.functions
    forwarding_body = [FunctionCall(func, call_args)]

    forwarding_func = FunctionDef(
        func.name,
        list(call_args),
        [],
        forwarding_body,
        arguments_inout=func.arguments_inout,
        functions=inner_functions,
    )
    return as_static_function(forwarding_func)
def _visit_FunctionDef(self, stmt):
    """Convert a function definition node into a ``FunctionDef``.

    The decorators found on ``stmt`` drive the conversion:

    * ``bypass``: an ``EmptyNode`` is returned immediately.
    * ``stack_array`` / ``allow_negative_index``: the decorator arguments
      are flattened into a tuple of strings.
    * ``template``: a ``#$ header template`` line is built and parsed with
      ``hdr_parse`` for each decorator, stored in ``templates``.
    * ``types``: a ``#$ header`` line is built and parsed for each
      decorator, collected in ``headers``.
    * ``sympy`` / ``python``: the function is registered through
      ``self.insert_function`` and an ``EmptyNode`` is returned.
    * ``pure``, ``elemental``, ``private``: set the matching flag.

    Otherwise the body is visited, the result names are derived from the
    ``Return`` statements and a ``FunctionDef`` is returned.

    Parameters
    ----------
    stmt :
        Function definition node; ``name``, ``args``, ``decorator_list``,
        ``body``, ``lineno`` and ``col_offset`` are read from it.

    Returns
    -------
    FunctionDef or EmptyNode
    """
    # TODO check all inputs and which ones should be treated in stage 1 or 2
    name = self._visit(stmt.name)
    name = name.replace("'", '')
    arguments = self._visit(stmt.args)
    local_vars = []
    global_vars = []
    headers = []
    templates = {}
    hide = False
    kind = 'function'
    is_pure = False
    is_elemental = False
    is_private = False
    imports = []

    def fill_types(ls):
        # Turn decorator arguments into plain type strings; anything that is
        # neither a Symbol nor a LiteralString is reported and left out.
        container = []
        for arg in ls:
            if isinstance(arg, Symbol):
                arg = arg.name
                container.append(arg)
            elif isinstance(arg, LiteralString):
                arg = str(arg)
                arg = arg.strip("'").strip('"')
                container.append(arg)
            else:
                msg = 'Invalid argument of type {} passed to types decorator'.format(
                    type(arg))
                errors.report(msg,
                              bounding_box=(stmt.lineno, stmt.col_offset),
                              severity='error')
        return container

    # Group the visited decorators: Symbols are keyed by their string form,
    # everything else by the string form of its type.
    decorators = {}
    for d in self._visit(stmt.decorator_list):
        tmp_var = str(d) if isinstance(d, Symbol) else str(type(d))
        if tmp_var in decorators:
            decorators[tmp_var] += [d]
        else:
            decorators[tmp_var] = [d]

    if 'bypass' in decorators:
        return EmptyNode()

    if 'stack_array' in decorators:
        decorators['stack_array'] = tuple(
            str(b) for a in decorators['stack_array'] for b in a.args)

    if 'allow_negative_index' in decorators:
        decorators['allow_negative_index'] = tuple(
            str(b) for a in decorators['allow_negative_index'] for b in a.args)

    # extract the templates
    if 'template' in decorators:
        for comb_types in decorators['template']:
            cache.clear_cache()
            types = []
            # NOTE(review): after this report the code below still indexes
            # args[0] and args[1]; whether severity='error' aborts is not
            # visible here -- confirm.
            if len(comb_types.args) != 2:
                msg = 'Number of Arguments provided to the template decorator is not valid'
                errors.report(msg,
                              symbol=comb_types,
                              bounding_box=(stmt.lineno, stmt.col_offset),
                              severity='error')

            # Only the keywords 'name' and 'types' are accepted.
            for i in comb_types.args:
                if isinstance(i, ValuedArgument) and not i.name in ('name', 'types'):
                    msg = 'Argument provided to the template decorator is not valid'
                    errors.report(msg,
                                  symbol=comb_types,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='error')

            # With two keyword arguments the order is free; otherwise the
            # first argument is the template name and the second the types.
            if all(isinstance(i, ValuedArgument) for i in comb_types.args):
                tp_name, ls = (comb_types.args[0].value, comb_types.args[1].value) if\
                    comb_types.args[0].name == 'name' else\
                    (comb_types.args[1].value, comb_types.args[0].value)
            else:
                tp_name = comb_types.args[0]
                ls = comb_types.args[1]
                ls = ls.value if isinstance(ls, ValuedArgument) else ls
            try:
                tp_name = str(tp_name)
                ls = ls if isinstance(ls, PythonTuple) else list(ls)
            except TypeError:
                msg = 'Argument provided to the template decorator is not valid'
                errors.report(msg,
                              symbol=comb_types,
                              bounding_box=(stmt.lineno, stmt.col_offset),
                              severity='fatal')

            types = fill_types(ls)

            txt = '#$ header template ' + str(tp_name)
            txt += '(' + '|'.join(types) + ')'
            # A duplicated template only triggers a warning; the later
            # definition replaces the earlier one.
            if tp_name in templates:
                msg = 'The template "{}" is duplicated'.format(tp_name)
                errors.report(msg,
                              bounding_box=(stmt.lineno, stmt.col_offset),
                              severity='warning')
            templates[tp_name] = hdr_parse(stmts=txt)

    # extract the types to construct a header
    if 'types' in decorators:
        for comb_types in decorators['types']:
            cache.clear_cache()
            results = []
            ls = comb_types.args

            # A trailing keyword argument must be 'results'; it is split
            # off from the positional argument types.
            if len(ls) > 0 and isinstance(ls[-1], ValuedArgument):
                arg_name = ls[-1].name
                if not arg_name == 'results':
                    msg = 'Argument "{}" provided to the types decorator is not valid'.format(
                        arg_name)
                    errors.report(msg,
                                  symbol=comb_types,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='error')
                else:
                    container = ls[-1].value
                    container = container if isinstance(
                        container, PythonTuple) else [container]
                    results = fill_types(container)
                types = fill_types(ls[:-1])
            else:
                types = fill_types(ls)

            txt = '#$ header ' + name
            txt += '(' + ','.join(types) + ')'

            if results:
                txt += ' results(' + ','.join(results) + ')'

            header = hdr_parse(stmts=txt)
            if name in self.namespace.static_functions:
                header = header.to_static()
            headers += [header]

    body = stmt.body

    if 'sympy' in decorators.keys():
        # TODO maybe we should run pylint here
        stmt.decorators.pop()
        func = SympyFunction(name, arguments, [], [stmt.__str__()])
        func.set_fst(stmt)
        self.insert_function(func)
        return EmptyNode()

    elif 'python' in decorators.keys():
        # TODO maybe we should run pylint here
        stmt.decorators.pop()
        func = PythonFunction(name, arguments, [], [stmt.__str__()])
        func.set_fst(stmt)
        self.insert_function(func)
        return EmptyNode()

    else:
        body = self._visit(body)

    if 'pure' in decorators.keys():
        is_pure = True

    if 'elemental' in decorators.keys():
        is_elemental = True
        if len(arguments) > 1:
            errors.report(FORTRAN_ELEMENTAL_SINGLE_ARGUMENT,
                          symbol=decorators['elemental'],
                          bounding_box=(stmt.lineno, stmt.col_offset),
                          severity='error')

    if 'private' in decorators.keys():
        is_private = True

    # Every Return found in the body must carry the same number of values.
    returns = [i.expr for i in _atomic(body, cls=Return)]
    assert all(len(i) == len(returns[0]) for i in returns)
    results = []
    result_counter = 1
    # i holds the values found at one result position across all Returns.
    for i in zip(*returns):
        # Values that differ between Returns, or that are not Symbols, get
        # a freshly created 'Out' variable.
        if not all(i[0] == j for j in i) or not isinstance(i[0], Symbol):
            result_name, result_counter = create_variable(
                self._used_names, prefix='Out', counter=result_counter)
            results.append(result_name)
        # So does a returned Symbol that shares its name with an argument.
        elif isinstance(i[0], Symbol) and any(i[0].name == x.name for x in arguments):
            result_name, result_counter = create_variable(
                self._used_names, prefix='Out', counter=result_counter)
            results.append(result_name)
        else:
            results.append(i[0])

    func = FunctionDef(name,
                       arguments,
                       results,
                       body,
                       local_vars=local_vars,
                       global_vars=global_vars,
                       hide=hide,
                       kind=kind,
                       is_pure=is_pure,
                       is_elemental=is_elemental,
                       is_private=is_private,
                       imports=imports,
                       decorators=decorators,
                       headers=headers,
                       templates=templates)

    func.set_fst(stmt)
    return func
def ompfy(stmt, **options):
    """
    Converts some statements to OpenMP statements.

    The statement tree is rebuilt recursively.  A ``ForIterator`` for which
    ``get_for_clauses`` yields clauses becomes an ``OMP_For``, a ``With``
    for which ``get_with_clauses`` yields clauses and every
    ``ParallelBlock`` become an ``OMP_Parallel``, and ``Module`` /
    ``Program`` nodes get an extra ``Import('omp_lib')``.  Anything that is
    not handled explicitly is returned unchanged.

    stmt: stmt, list
        statement or a list of statements

    options: dict
        forwarded unchanged to every recursive call

    Returns the converted statement, or a list of converted statements when
    ``stmt`` is a list, tuple or ``Tuple``.
    """
    # Sequences (lists included) are converted element by element.
    if isinstance(stmt, (list, tuple, Tuple)):
        return [ompfy(i, **options) for i in stmt]

    if isinstance(stmt, Tensor):
        # TODO to implement
        return stmt

    # ForIterator is tested before For, as in the original ordering.
    if isinstance(stmt, ForIterator):
        iterable = ompfy(stmt.iterable, **options)
        target = stmt.target
        body = ompfy(stmt.body, **options)

        info, clauses = get_for_clauses(iterable)
        loop = ForIterator(target, iterable, body, strict=False)
        if clauses is None:
            return loop
        return OMP_For(loop, clauses, info['nowait'])

    if isinstance(stmt, For):
        iterable = ompfy(stmt.iterable, **options)
        target = stmt.target
        body = ompfy(stmt.body, **options)
        return For(target, iterable, body, strict=False)

    if isinstance(stmt, While):
        test = ompfy(stmt.test, **options)
        body = ompfy(stmt.body, **options)
        return While(test, body)

    if isinstance(stmt, With):
        test = ompfy(stmt.test, **options)
        body = ompfy(stmt.body, **options)
        settings = ompfy(stmt.settings, **options)

        clauses = get_with_clauses(test)
        if clauses is None:
            return With(test, body, settings)

        # TODO to be defined
        variables = []
        return OMP_Parallel(clauses, variables, body)

    if isinstance(stmt, If):
        args = []
        for block in stmt.args:
            test = ompfy(block[0], **options)
            stmts = ompfy(block[1], **options)
            args.append((test, stmts))
        return If(*args)

    if isinstance(stmt, FunctionDef):
        name = ompfy(stmt.name, **options)
        arguments = ompfy(stmt.arguments, **options)
        results = ompfy(stmt.results, **options)
        body = ompfy(stmt.body, **options)
        local_vars = ompfy(stmt.local_vars, **options)
        global_vars = ompfy(stmt.global_vars, **options)
        return FunctionDef(name, arguments, results, body,
                           local_vars, global_vars)

    if isinstance(stmt, ClassDef):
        name = ompfy(stmt.name, **options)
        attributs = ompfy(stmt.attributs, **options)
        methods = ompfy(stmt.methods, **options)
        # Kept under its own name so the keyword options are not rebound.
        class_options = ompfy(stmt.options, **options)
        return ClassDef(name, attributs, methods, class_options)

    if isinstance(stmt, Module):
        name = ompfy(stmt.name, **options)
        variables = ompfy(stmt.variables, **options)
        funcs = ompfy(stmt.funcs, **options)
        classes = ompfy(stmt.classes, **options)
        imports = ompfy(stmt.imports, **options)
        imports += [Import('omp_lib')]
        return Module(name, variables, funcs, classes, imports=imports)

    if isinstance(stmt, Program):
        name = ompfy(stmt.name, **options)
        variables = ompfy(stmt.variables, **options)
        funcs = ompfy(stmt.funcs, **options)
        classes = ompfy(stmt.classes, **options)
        imports = ompfy(stmt.imports, **options)
        body = ompfy(stmt.body, **options)
        modules = ompfy(stmt.modules, **options)
        imports += [Import('omp_lib')]
        return Program(name, variables, funcs, classes, body,
                       imports=imports, modules=modules)

    if isinstance(stmt, ParallelBlock):
        return OMP_Parallel(stmt.clauses, stmt.variables, stmt.body)

    return stmt
def _print_FunctionDef(self, expr):
    """Print the Python-callable wrapper of the function ``expr``.

    A wrapper ``FunctionDef`` taking ``(self, args, kwargs)`` PyObjects and
    returning one PyObject is assembled and printed with
    ``CCodePrinter._print_FunctionDef``.  Its body parses the Python
    arguments, converts them, calls the static function obtained from
    ``self._get_static_function`` and builds the return value.

    Private functions and functions taking a ``FunctionAddress`` argument
    get a wrapper that only sets ``PyExc_NotImplementedError`` and returns
    ``Nil()``.

    Parameters
    ----------
    expr :
        Function to wrap; ``name``, ``arguments``, ``results`` and
        ``is_private`` are read from it.

    Returns
    -------
    The value returned by ``CCodePrinter._print_FunctionDef`` for the
    wrapper.
    """
    # Save all used names
    used_names = set([a.name for a in expr.arguments] + [r.name for r in expr.results] + [expr.name.name])

    # Find a name for the wrapper function
    wrapper_name = self._get_wrapper_name(used_names, expr)
    used_names.add(wrapper_name)

    # Collect local variables
    wrapper_vars = {a.name: a for a in expr.arguments}
    wrapper_vars.update({r.name: r for r in expr.results})
    python_func_args = self.get_new_PyObject("args", used_names)
    python_func_kwargs = self.get_new_PyObject("kwargs", used_names)
    python_func_selfarg = self.get_new_PyObject("self", used_names)

    # Collect arguments and results
    wrapper_args = [
        python_func_selfarg, python_func_args, python_func_kwargs
    ]
    wrapper_results = [self.get_new_PyObject("result", used_names)]

    # Private functions only get a stub that raises on the Python side.
    if expr.is_private:
        wrapper_func = FunctionDef(
            name=wrapper_name,
            arguments=wrapper_args,
            results=wrapper_results,
            body=[
                PyErr_SetString(
                    'PyExc_NotImplementedError',
                    '"Private functions are not accessible from python"'),
                AliasAssign(wrapper_results[0], Nil()),
                Return(wrapper_results)
            ])
        return CCodePrinter._print_FunctionDef(self, wrapper_func)

    # Functions taking a function as argument get the same kind of stub.
    if any(isinstance(arg, FunctionAddress) for arg in expr.arguments):
        wrapper_func = FunctionDef(
            name=wrapper_name,
            arguments=wrapper_args,
            results=wrapper_results,
            body=[
                PyErr_SetString('PyExc_NotImplementedError',
                                '"Cannot pass a function as an argument"'),
                AliasAssign(wrapper_results[0], Nil()),
                Return(wrapper_results)
            ])
        return CCodePrinter._print_FunctionDef(self, wrapper_func)

    # Collect argument names for PyArgParse
    arg_names = [a.name for a in expr.arguments]
    keyword_list_name = self.get_new_name(used_names, 'kwlist')
    keyword_list = PyArgKeywords(keyword_list_name, arg_names)

    wrapper_body = [keyword_list]
    # Conversion statements are kept aside so that they can be emitted
    # after the argument parsing.
    wrapper_body_translations = []

    parse_args = []
    collect_vars = {}
    for arg in expr.arguments:
        collect_var, cast_func = self.get_PyArgParseType(used_names, arg)
        collect_vars[arg] = collect_var

        body, tmp_variable = self._body_management(used_names, arg, collect_var, cast_func, True)
        if tmp_variable:
            wrapper_vars[tmp_variable.name] = tmp_variable

        # If the variable cannot be collected from PyArgParse directly
        wrapper_vars[collect_var.name] = collect_var

        # Save cast to argument variable
        wrapper_body_translations.extend(body)

        parse_args.append(collect_var)

        # Write default values (parse_args[-1] is the collect_var appended
        # just above)
        if isinstance(arg, ValuedVariable):
            wrapper_body.append(
                self.get_default_assign(parse_args[-1], arg))

    # Parse arguments; the wrapper returns Nil() when parsing fails
    parse_node = PyArg_ParseTupleNode(python_func_args,
                                      python_func_kwargs,
                                      expr.arguments, parse_args,
                                      keyword_list)
    wrapper_body.append(If((PyccelNot(parse_node), [Return([Nil()])])))
    wrapper_body.extend(wrapper_body_translations)

    # Call function
    static_function, static_args, additional_body = self._get_static_function(
        used_names, expr, collect_vars)
    wrapper_body.extend(additional_body)
    for var in static_args:
        wrapper_vars[var.name] = var

    if len(expr.results) == 0:
        func_call = FunctionCall(static_function, static_args)
    else:
        # A single result is assigned directly, several as a sequence.
        results = expr.results if len(
            expr.results) > 1 else expr.results[0]
        func_call = Assign(results,
                           FunctionCall(static_function, static_args))

    wrapper_body.append(func_call)

    # Loop over results to carry out necessary casts and collect Py_BuildValue type string
    res_args = []
    for a in expr.results:
        collect_var, cast_func = self.get_PyBuildValue(used_names, a)
        if cast_func is not None:
            wrapper_vars[collect_var.name] = collect_var
            wrapper_body.append(AliasAssign(collect_var, cast_func))

        res_args.append(
            VariableAddress(collect_var) if collect_var.is_pointer else collect_var)

    # Call PyBuildNode
    wrapper_body.append(
        AliasAssign(wrapper_results[0], PyBuildValueNode(res_args)))

    # Call free function for python type; the list is emptied afterwards.
    # NOTE(review): presumably filled by the helper calls above -- not
    # visible in this block.
    wrapper_body += [
        FunctionCall(Py_DECREF, [i]) for i in self._to_free_PyObject_list
    ]
    self._to_free_PyObject_list.clear()

    # Return
    wrapper_body.append(Return(wrapper_results))

    # Create FunctionDef and write using classic method
    wrapper_func = FunctionDef(name=wrapper_name,
                               arguments=wrapper_args,
                               results=wrapper_results,
                               body=wrapper_body,
                               local_vars=wrapper_vars.values())
    return CCodePrinter._print_FunctionDef(self, wrapper_func)
def _print_Interface(self, expr):
    """Print the Python-callable wrapper of the interface ``expr``.

    Three kinds of functions are generated and printed one after the other:

    * one "mini wrapper" per function of the interface, which converts the
      parsed arguments, calls the static function and builds the result;
    * a type-check function (see ``_create_wrapper_check``) returning an
      integer that encodes the types of the received arguments, 4 bits per
      argument;
    * the wrapper itself, which parses the Python arguments, evaluates the
      type check and dispatches to the mini wrapper whose flags match.

    Default values are assigned to the parse variable of the argument they
    belong to, once per argument, just before the arguments are parsed.

    Parameters
    ----------
    expr :
        Interface to wrap; its ``functions`` are read, the arguments of the
        first one define the Python signature.

    Returns
    -------
    str
        A separator comment followed by the printed functions.
    """
    # Collecting all functions
    funcs = expr.functions
    # Save all used names
    used_names = set(n.name for n in funcs)

    # Find a name for the wrapper function
    wrapper_name = self._get_wrapper_name(used_names, expr)
    self._global_names.add(wrapper_name)

    # Collect local variables
    python_func_args = self.get_new_PyObject("args", used_names)
    python_func_kwargs = self.get_new_PyObject("kwargs", used_names)
    python_func_selfarg = self.get_new_PyObject("self", used_names)

    # Collect wrapper arguments and results
    wrapper_args = [
        python_func_selfarg, python_func_args, python_func_kwargs
    ]
    wrapper_results = [self.get_new_PyObject("result", used_names)]

    # Collect parser arguments
    wrapper_vars = {}

    # Collect argument names for PyArgParse
    arg_names = [a.name for a in funcs[0].arguments]
    keyword_list_name = self.get_new_name(used_names, 'kwlist')
    keyword_list = PyArgKeywords(keyword_list_name, arg_names)
    wrapper_body = [keyword_list]

    # (condition, body) pairs dispatching to the mini wrappers
    body_tmp = []

    # To store the mini functions responsible for collecting values,
    # calling the interface functions and returning the built value
    funcs_def = []
    # Default-value initialisations needed in the wrapper, keyed by the
    # name of the parse variable so that each one is emitted once
    default_value = {}

    check_var = Variable(dtype=NativeInteger(),
                         name=self.get_new_name(used_names, "check"))
    wrapper_vars[check_var.name] = check_var
    # Possible types of each argument and the corresponding flags
    types_dict = OrderedDict((a, set()) for a in funcs[0].arguments)

    # Collect parse args: array arguments are received as array objects,
    # everything else as generic Python objects
    parse_args = [
        Variable(dtype=PyccelPyArrayObject(),
                 is_pointer=True,
                 rank=a.rank,
                 order=a.order,
                 name=self.get_new_name(used_names, a.name + "_tmp"))
        if a.rank > 0 else Variable(
            dtype=PyccelPyObject(),
            name=self.get_new_name(used_names, a.name + "_tmp"),
            is_pointer=True) for a in funcs[0].arguments
    ]

    # Managing the body of wrapper
    for func in funcs:
        mini_wrapper_func_body = []
        res_args = []
        mini_wrapper_func_vars = {a.name: a for a in func.arguments}
        flags = 0
        collect_vars = {}

        # Loop over the arguments of the function and create the
        # corresponding condition and body
        for p_arg, f_arg in zip(parse_args, func.arguments):
            collect_vars[f_arg] = p_arg
            body, tmp_variable = self._body_management(
                used_names, f_arg, p_arg, None)
            if tmp_variable:
                mini_wrapper_func_vars[tmp_variable.name] = tmp_variable

            # get check type function
            check = self._get_check_type_statement(f_arg, p_arg)

            # If the variable cannot be collected from PyArgParse directly
            wrapper_vars[p_arg.name] = p_arg

            # Write default values: the assignment targets the parse
            # variable of this argument and is recorded once per argument
            if isinstance(f_arg, ValuedVariable):
                default_value[p_arg.name] = self.get_default_assign(
                    p_arg, f_arg)

            flag_value = flags_registry[(f_arg.dtype, f_arg.precision)]
            flags = (flags << 4) + flag_value  # shift by 4 to the left
            # collect variable type for each argument
            types_dict[f_arg].add((f_arg, check, flag_value))
            mini_wrapper_func_body += body

        # create the corresponding function call
        static_function, static_args, additional_body = self._get_static_function(
            used_names, func, collect_vars)
        mini_wrapper_func_body.extend(additional_body)

        for var in static_args:
            mini_wrapper_func_vars[var.name] = var

        if len(func.results) == 0:
            func_call = FunctionCall(static_function, static_args)
        else:
            results = func.results if len(
                func.results) > 1 else func.results[0]
            func_call = Assign(results,
                               FunctionCall(static_function, static_args))

        mini_wrapper_func_body.append(func_call)

        # Loop over the results of the function and create the
        # corresponding body and cast
        for r in func.results:
            collect_var, cast_func = self.get_PyBuildValue(used_names, r)
            mini_wrapper_func_vars[collect_var.name] = collect_var
            if cast_func is not None:
                mini_wrapper_func_vars[r.name] = r
                mini_wrapper_func_body.append(
                    AliasAssign(collect_var, cast_func))
            res_args.append(
                VariableAddress(collect_var)
                if collect_var.is_pointer else collect_var)

        # Building PyBuildValue and freeing the allocated variables after
        mini_wrapper_func_body.append(
            AliasAssign(wrapper_results[0], PyBuildValueNode(res_args)))
        mini_wrapper_func_body += [
            FunctionCall(Py_DECREF, [i])
            for i in self._to_free_PyObject_list
        ]
        mini_wrapper_func_body.append(Return(wrapper_results))
        self._to_free_PyObject_list.clear()

        # Building mini wrapper function
        mini_wrapper_func_name = self.get_new_name(
            used_names.union(self._global_names),
            func.name.name + '_mini_wrapper')
        self._global_names.add(mini_wrapper_func_name)

        mini_wrapper_func_def = FunctionDef(
            name=mini_wrapper_func_name,
            arguments=parse_args,
            results=wrapper_results,
            body=mini_wrapper_func_body,
            local_vars=mini_wrapper_func_vars.values())
        funcs_def.append(mini_wrapper_func_def)

        # append check condition to the function call
        body_tmp.append((PyccelEq(check_var, LiteralInteger(flags)), [
            AliasAssign(wrapper_results[0],
                        FunctionCall(mini_wrapper_func_def, parse_args))
        ]))

    # Errors / Types management
    # Creating check_type function
    check_func_def = self._create_wrapper_check(check_var, parse_args,
                                                types_dict, used_names,
                                                funcs[0].name.name)
    funcs_def.append(check_func_def)

    # Create the wrapper body with collected information: a zero check
    # value aborts, no matching combination raises a TypeError
    body_tmp = [((PyccelNot(check_var), [Return([Nil()])]))] + body_tmp
    body_tmp.append((LiteralTrue(), [
        PyErr_SetString('PyExc_TypeError',
                        '"Arguments combinations don\'t exist"'),
        Return([Nil()])
    ]))
    wrapper_body_translations = [If(*body_tmp)]

    # Parsing arguments
    parse_node = PyArg_ParseTupleNode(python_func_args,
                                      python_func_kwargs,
                                      funcs[0].arguments, parse_args,
                                      keyword_list, True)
    wrapper_body += list(default_value.values())
    wrapper_body.append(If((PyccelNot(parse_node), [Return([Nil()])])))

    # Finishing the wrapper body
    wrapper_body.append(
        Assign(check_var, FunctionCall(check_func_def, parse_args)))
    wrapper_body.extend(wrapper_body_translations)
    wrapper_body.append(Return(wrapper_results))  # Return

    # Create FunctionDef
    funcs_def.append(
        FunctionDef(name=wrapper_name,
                    arguments=wrapper_args,
                    results=wrapper_results,
                    body=wrapper_body,
                    local_vars=wrapper_vars.values()))

    sep = self._print(SeparatorComment(40))
    return sep + '\n'.join(
        CCodePrinter._print_FunctionDef(self, f) for f in funcs_def)
def as_static_function(func):
    """Return a static (``f2py_``-prefixed) version of ``func``.

    Every argument of rank ``r > 0`` is preceded by ``r`` integer arguments
    named ``n<i>_<name>`` which are also used as its shape.  A single result
    of rank > 0 is first moved to the argument list; more than one result is
    not supported.

    ``func`` itself is left untouched: the intent list is copied before it
    is extended.

    Parameters
    ----------
    func : FunctionDef
        Function to convert.

    Returns
    -------
    FunctionDef
        Function named ``'f2py_' + func.name`` (lower-cased), created with
        ``is_static=True``.

    Raises
    ------
    NotImplementedError
        If ``func`` has more than one result.
    TypeError
        If an argument is not a ``Variable``.
    """
    assert (isinstance(func, FunctionDef))

    args = func.arguments
    results = func.results
    body = func.body
    arguments_inout = func.arguments_inout
    functions = func.functions

    _results = []
    if results:
        if len(results) == 1:
            result = results[0]
            if result.rank > 0:
                # An array result is passed as an additional argument.
                # A new list is built so that func.arguments_inout is not
                # modified in place.
                args = list(args) + [result]
                arguments_inout = list(arguments_inout) + [False]
            else:
                _results = results
        else:
            raise NotImplementedError('when len(results) > 1')

    name = 'f2py_{}'.format(func.name).lower()

    results_names = [i.name for i in results]
    _args = []
    _arguments_inout = []
    for i_a, a in enumerate(args):
        if not isinstance(a, Variable):
            raise TypeError('Expecting a Variable type for {}'.format(a))

        rank = a.rank
        if rank > 0:
            # One integer argument per dimension, used as explicit shape.
            additional_args = [
                Variable('int', 'n{i}_{name}'.format(name=str(a.name), i=i))
                for i in range(0, rank)
            ]
            shape_new = Tuple(*additional_args, sympify=False)

            _args += additional_args
            _arguments_inout.extend([False] * len(additional_args))

            a_new = Variable(a.dtype, a.name,
                             allocatable=a.allocatable,
                             is_pointer=a.is_pointer,
                             is_target=a.is_target,
                             is_optional=a.is_optional,
                             shape=shape_new,
                             rank=a.rank,
                             order=a.order,
                             precision=a.precision)

            # An array carrying the name of a result goes back to the
            # results, with its explicit shape.
            if not (a.name in results_names):
                _args += [a_new]
            else:
                _results += [a_new]
        else:
            _args += [a]

        # NOTE(review): the intent is recorded for every entry of args, also
        # for an array that ends up in the results -- confirm this is what
        # the consumers of arguments_inout expect.
        intent = arguments_inout[i_a]
        _arguments_inout += [intent]

    return FunctionDef(name, list(_args), _results, body,
                       local_vars=func.local_vars,
                       is_static=True,
                       arguments_inout=_arguments_inout,
                       functions=functions)
def __new__(cls, func, import_lambda):
    """Build the ``interface_<name>`` function wrapping ``func``.

    The generated function takes the arguments of ``func`` that are not in
    ``func.m_results`` plus an optional ``out`` argument defaulting to
    ``Nil()``.  Its body is made of ``import_lambda``, an ``If`` on
    ``out is Nil()`` that imports ``zeros`` / ``float64`` from numpy and
    allocates every result listed in ``func.m_results`` with the shape given
    by ``compute_shape``, a call to ``func`` and a ``Return`` of the output
    symbol(s).

    Parameters
    ----------
    func :
        Function to wrap; ``m_results``, ``name``, ``arguments``,
        ``results`` and ``generators`` are read from it.
    import_lambda :
        First statement of the generated body.

    Returns
    -------
    FunctionDef
    """
    m_results = func.m_results
    name = 'interface_{}'.format(func.name)
    args = [arg for arg in func.arguments if arg not in m_results]
    results = list(func.results) + list(m_results)

    # Output symbols: a single 'out', or 'out_<i>' for each result.
    if len(results) == 1:
        outs = [Symbol('out')]
    else:
        outs = [Symbol('out_{}'.format(idx)) for idx in range(0, len(results))]

    generators = func.generators
    d_shapes = {res: compute_shape(res, generators) for res in m_results}

    # TODO build statements
    if_cond = Is(Symbol('out'), Nil())

    # TODO add imports from numpy
    if_body = [Import('zeros', 'numpy'), Import('float64', 'numpy')]
    for idx, var in enumerate(results):
        if var not in m_results:
            continue
        shaping = d_shapes[var]
        if_body += shaping.stmts
        if_body.append(Assign(outs[idx], Zeros(shaping.var, var.dtype)))

    # Allocation guard, then the call to the python or pyccelized function.
    stmts = [If((if_cond, if_body)), FunctionCall(func, args + outs)]

    # Return the single output symbol directly, several as a list.
    returned = outs[0] if len(outs) == 1 else outs
    stmts.append(Return(returned))

    body = [import_lambda] + stmts

    # update arguments with optional
    args.append(Assign(Symbol('out'), Nil()))

    return FunctionDef(name, args, results, body)