def _visit_arguments(self, stmt):
    """Convert a Python ``ast.arguments`` node into a list of pyccel arguments.

    Positional parameters (positional-only ones included, when the AST
    provides them) become ``Argument`` objects. Those with a default value
    are wrapped in a ``ValuedArgument``. Keyword-only parameters are
    appended afterwards with ``kwonly=True``. Every argument carries its
    visited annotation.

    Parameters
    ----------
    stmt : ast.arguments
        The ``args`` field of a function definition node.

    Returns
    -------
    list
        ``Argument`` / ``ValuedArgument`` objects in declaration order.

    Notes
    -----
    ``*args`` and ``**kwargs`` are not supported; a fatal ``VARARGS`` error
    is reported for them.
    """
    if stmt.vararg or stmt.kwarg:
        errors.report(VARARGS, symbol = stmt,
                severity='fatal')

    arguments = []

    # Python 3.8+ stores parameters declared before ``/`` in ``posonlyargs``.
    # ``stmt.defaults`` spans the tail of ``posonlyargs + args``, so the two
    # lists must be merged. Otherwise the defaults misalign and positional-only
    # parameters are silently dropped. ``getattr`` keeps older ASTs working.
    positional = list(getattr(stmt, 'posonlyargs', [])) + list(stmt.args)

    if positional:
        # The defaults belong to the *last* len(defaults) positional parameters.
        n_expl = len(positional) - len(stmt.defaults)
        positional_args = [Argument(a.arg, annotation=self._visit(a.annotation))
                           for a in positional[:n_expl]]
        valued_arguments = [ValuedArgument(Argument(a.arg, annotation=self._visit(a.annotation)),
                                           self._visit(d))
                            for a, d in zip(positional[n_expl:], stmt.defaults)]
        arguments = positional_args + valued_arguments

    if stmt.kwonlyargs:
        # kw_defaults is aligned with kwonlyargs; None means "no default".
        for a, d in zip(stmt.kwonlyargs, stmt.kw_defaults):
            annotation = self._visit(a.annotation)
            if d is not None:
                arg = Argument(a.arg, annotation=annotation)
                arg = ValuedArgument(arg, self._visit(d), kwonly=True)
            else:
                arg = Argument(a.arg, kwonly=True, annotation=annotation)
            arguments.append(arg)

    return arguments
def _visit_arguments(self, stmt):
    """Convert a Python ``ast.arguments`` node into a list of pyccel arguments.

    Positional parameters (positional-only ones included, when the AST
    provides them) become ``Argument`` objects. Those with a default value
    are wrapped in a ``ValuedArgument``. Keyword-only parameters follow,
    flagged with ``kwonly=True``. Annotations are ignored here.

    Parameters
    ----------
    stmt : ast.arguments
        The ``args`` field of a function definition node.

    Returns
    -------
    list
        ``Argument`` / ``ValuedArgument`` objects in declaration order.

    Notes
    -----
    ``*args`` and ``**kwargs`` are not supported; a fatal ``VARARGS`` error
    is reported for them.
    """
    arguments = []
    if stmt.vararg or stmt.kwarg:
        errors.report(VARARGS, symbol=stmt, severity='fatal')

    # Python 3.8+ stores parameters declared before ``/`` in ``posonlyargs``.
    # ``stmt.defaults`` spans the tail of ``posonlyargs + args``, so the two
    # lists must be merged. Otherwise the defaults misalign and positional-only
    # parameters are silently dropped. ``getattr`` keeps older ASTs working.
    positional = list(getattr(stmt, 'posonlyargs', [])) + list(stmt.args)

    if positional:
        # The defaults belong to the *last* len(defaults) positional parameters.
        n_expl = len(positional) - len(stmt.defaults)
        arguments += [Argument(a.arg) for a in positional[:n_expl]]
        arguments += [
            ValuedArgument(Argument(a.arg), self._visit(d))
            for a, d in zip(positional[n_expl:], stmt.defaults)
        ]

    if stmt.kwonlyargs:
        # kw_defaults is aligned with kwonlyargs; None means "no default".
        arguments += [
            ValuedArgument(Argument(a.arg), self._visit(d), kwonly=True)
            if d is not None else Argument(a.arg, kwonly=True)
            for a, d in zip(stmt.kwonlyargs, stmt.kw_defaults)
        ]

    return arguments
def expr(self):
    """Return the expression this argument stands for.

    * A ``MacroList`` argument yields a ``Tuple`` of its entries' ``expr``.
    * Otherwise the argument's text becomes a ``Symbol``.
      - If ``self.value`` is ``None``, that symbol is returned as is.
      - Otherwise the symbol is paired with the value in a ``ValuedArgument``.
        ``MacroStmt``/``StringStmt`` values contribute their ``expr``. Any
        other value is converted with ``sympify`` applied to its string form,
        with ``N`` and ``S`` bound to plain symbols instead of SymPy's own
        ``N``/``S`` objects.
    """
    raw = self.arg
    if isinstance(raw, MacroList):
        return Tuple(*raw.expr)

    symbol = Symbol(str(raw))
    default = self.value
    if default is None:
        return symbol

    if isinstance(default, (MacroStmt, StringStmt)):
        converted = default.expr
    else:
        converted = sympify(str(default),
                            locals={'N': Symbol('N'), 'S': Symbol('S')})
    return ValuedArgument(symbol, converted)
def _visit_keyword(self, stmt):
    """Translate an ``ast.keyword`` (``name=value``) into a ``ValuedArgument``.

    The keyword name ``stmt.arg`` is used unchanged. The value node is
    visited before being attached.
    """
    return ValuedArgument(stmt.arg, self._visit(stmt.value))
def _visit_FunctionDef(self, stmt):
    """Build a ``FunctionDef`` (or a placeholder node) from a function definition.

    Overview of the steps:

    * The arguments are visited with ``self._visit(stmt.args)``.
    * The decorators are collected into the ``decorators`` dict.
    * ``#$ header`` / ``#$ header template`` strings are synthesised from any
      ``@types`` / ``@template`` decorators and parsed with ``hdr_parse``.
    * When no argument annotation is ``Nil``, an implicit ``@types`` entry is
      built from the annotations. This also happens when there are no
      arguments at all. The return annotation, if present, is added to it as
      a ``results=`` keyword.

    Parameters
    ----------
    stmt :
        Presumably an ``ast.FunctionDef``. The code reads ``name``, ``args``,
        ``returns``, ``decorator_list``, ``body``, ``lineno`` and
        ``col_offset``. TODO confirm.

    Returns
    -------
    EmptyNode
        Returned when the function is decorated with ``@bypass``. Also
        returned for ``@sympy`` / ``@python``; those functions are registered
        through ``self.insert_function`` instead.
    FunctionDef
        Returned otherwise. It carries the visited body (minus a leading
        ``CommentBlock``, kept as ``doc_string``), the inferred result names,
        the pure/elemental/private flags, headers, templates and decorators.
    """
    # TODO check all inputs and which ones should be treated in stage 1 or 2
    name = self._visit(stmt.name)
    # drop any single-quote characters from the function name
    name = name.replace("'", '')
    arguments = self._visit(stmt.args)
    local_vars = []
    global_vars = []
    headers = []
    templates = {}
    is_pure = False
    is_elemental = False
    is_private = False
    imports = []
    doc_string = None

    def fill_types(ls):
        # Convert decorator arguments to type strings:
        # - Symbol -> its name
        # - LiteralString -> its text with surrounding quotes stripped
        # Anything else is reported as an error and skipped.
        container = []
        for arg in ls:
            if isinstance(arg, Symbol):
                arg = arg.name
                container.append(arg)
            elif isinstance(arg, LiteralString):
                arg = str(arg)
                arg = arg.strip("'").strip('"')
                container.append(arg)
            else:
                msg = 'Invalid argument of type {} passed to types decorator'.format(type(arg))
                errors.report(msg, bounding_box = (stmt.lineno, stmt.col_offset),
                        severity='error')
        return container

    # maps decorator key -> list of decorator expressions
    decorators = {}

    # add the decorator @types if the arguments are annotated
    annotated_args = []
    for a in arguments:
        if isinstance(a, Argument):
            annotated_args.append(a.annotation)
        elif isinstance(a, ValuedArgument):
            annotated_args.append(a.argument.annotation)

    # all() is also True for a function with no arguments
    if all(not isinstance(a, Nil) for a in annotated_args):
        if stmt.returns:
            returns = ValuedArgument(Symbol('results'),self._visit(stmt.returns))
            annotated_args.append(returns)
        decorators['types'] = [Function('types')(*annotated_args)]

    # Key: str(d) for a bare symbol, str(type(d)) otherwise (presumably the
    # decorator name for call-style decorators). Explicit @types decorators
    # are appended after the annotation-derived entry above.
    for d in self._visit(stmt.decorator_list):
        tmp_var = str(d) if isinstance(d, Symbol) else str(type(d))
        if tmp_var in decorators:
            decorators[tmp_var] += [d]
        else:
            decorators[tmp_var] = [d]

    if 'bypass' in decorators:
        return EmptyNode()

    # flatten the arguments of every such decorator into one tuple of strings
    if 'stack_array' in decorators:
        decorators['stack_array'] = tuple(str(b) for a in decorators['stack_array']
            for b in a.args)

    if 'allow_negative_index' in decorators:
        decorators['allow_negative_index'] = tuple(str(b) for a in decorators['allow_negative_index'] for b in a.args)

    # extract the templates
    # Each @template(name, types) becomes one '#$ header template' statement.
    # Its arguments may be positional or given as name=/types= keywords.
    if 'template' in decorators:
        for comb_types in decorators['template']:
            # NOTE(review): presumably resets a symbolic-expression cache
            # between combinations; TODO confirm why this is needed.
            cache.clear_cache()
            types = []
            if len(comb_types.args) != 2:
                msg = 'Number of Arguments provided to the template decorator is not valid'
                errors.report(msg,
                        symbol = comb_types,
                        bounding_box = (stmt.lineno, stmt.col_offset),
                        severity='error')

            for i in comb_types.args:
                if isinstance(i, ValuedArgument) and not i.name in ('name', 'types'):
                    msg = 'Argument provided to the template decorator is not valid'
                    errors.report(msg,
                            symbol = comb_types,
                            bounding_box = (stmt.lineno, stmt.col_offset),
                            severity='error')
            if all(isinstance(i, ValuedArgument) for i in comb_types.args):
                # keyword form: the order of name=/types= may be either way round
                tp_name, ls = (comb_types.args[0].value, comb_types.args[1].value) if\
                        comb_types.args[0].name == 'name' else\
                        (comb_types.args[1].value, comb_types.args[0].value)
            else:
                tp_name = comb_types.args[0]
                ls = comb_types.args[1]
                ls = ls.value if isinstance(ls, ValuedArgument) else ls
            try:
                tp_name = str(tp_name)
                ls = ls if isinstance(ls, PythonTuple) else list(ls)
            except TypeError:
                msg = 'Argument provided to the template decorator is not valid'
                errors.report(msg,
                        symbol = comb_types,
                        bounding_box = (stmt.lineno, stmt.col_offset),
                        severity='fatal')

            types = fill_types(ls)

            txt = '#$ header template ' + str(tp_name)
            txt += '(' + '|'.join(types) + ')'
            if tp_name in templates:
                msg = 'The template "{}" is duplicated'.format(tp_name)
                errors.report(msg,
                        bounding_box = (stmt.lineno, stmt.col_offset),
                        severity='warning')

            # a duplicated template name overwrites the earlier one
            templates[tp_name] = hdr_parse(stmts=txt)

    # extract the types to construct a header
    # One '#$ header' per @types entry. A trailing results= keyword supplies
    # the return types; any other trailing keyword is reported as an error.
    if 'types' in decorators:
        for comb_types in decorators['types']:

            cache.clear_cache()
            results = []
            ls = comb_types.args

            if len(ls) > 0 and isinstance(ls[-1], ValuedArgument):
                arg_name = ls[-1].name
                if not arg_name == 'results':
                    msg = 'Argument "{}" provided to the types decorator is not valid'.format(arg_name)
                    errors.report(msg,
                            symbol = comb_types,
                            bounding_box = (stmt.lineno, stmt.col_offset),
                            severity='error')
                else:
                    container = ls[-1].value
                    container = container if isinstance(container, PythonTuple) else [container]
                    results = fill_types(container)
                types = fill_types(ls[:-1])
            else:
                types = fill_types(ls)

            txt = '#$ header ' + name
            txt += '(' + ','.join(types) + ')'

            if results:
                txt += ' results(' + ','.join(results) + ')'

            header = hdr_parse(stmts=txt)
            if name in self.namespace.static_functions:
                header = header.to_static()
            headers += [header]

    body = stmt.body

    if 'sympy' in decorators.keys():
        # TODO maybe we should run pylint here
        # NOTE(review): the rest of this method reads ``stmt.decorator_list``
        # (the ``ast.FunctionDef`` field). ``stmt.decorators`` looks like it
        # would raise AttributeError on such a node, and ``stmt.__str__()``
        # would not give source text. TODO confirm the node type here.
        stmt.decorators.pop()
        func = SympyFunction(name, arguments, [],
                [stmt.__str__()])
        func.set_fst(stmt)
        self.insert_function(func)
        return EmptyNode()

    elif 'python' in decorators.keys():

        # TODO maybe we should run pylint here
        # NOTE(review): same ``stmt.decorators`` / ``stmt.__str__()`` concern
        # as in the sympy branch above.
        stmt.decorators.pop()
        func = PythonFunction(name, arguments, [],
                [stmt.__str__()])
        func.set_fst(stmt)
        self.insert_function(func)
        return EmptyNode()

    else:
        body = self._visit(body)

    # a leading CommentBlock in the body is taken as the docstring
    if len(body) > 0 and isinstance(body[0], CommentBlock):
        doc_string = body[0]
        doc_string.header = ''
        body = body[1:]

    if 'pure' in decorators.keys():
        is_pure = True

    if 'elemental' in decorators.keys():
        is_elemental = True
        if len(arguments) > 1:
            errors.report(FORTRAN_ELEMENTAL_SINGLE_ARGUMENT,
                    symbol=decorators['elemental'],
                    bounding_box=(stmt.lineno, stmt.col_offset),
                    severity='error')

    if 'private' in decorators.keys():
        is_private = True

    # expressions returned by every Return statement found in the body
    returns = [i.expr for i in _atomic(body, cls=Return)]
    # every return statement must return the same number of values
    # (note: assert is skipped under python -O)
    assert all(len(i) == len(returns[0]) for i in returns)
    results = []
    result_counter = 1
    # For each return position: reuse the Symbol if every return statement
    # returns the same Symbol and it is not an argument name; otherwise create
    # a fresh 'Out' variable.
    for i in zip(*returns):
        if not all(i[0]==j for j in i) or not isinstance(i[0], Symbol):
            result_name, result_counter = create_variable(self._used_names, prefix = 'Out',
                    counter = result_counter)
            results.append(result_name)
        elif isinstance(i[0], Symbol) and any(i[0].name==x.name for x in arguments):
            result_name, result_counter = create_variable(self._used_names, prefix = 'Out',
                    counter = result_counter)
            results.append(result_name)
        else:
            results.append(i[0])

    func = FunctionDef(
            name,
            arguments,
            results,
            body,
            local_vars=local_vars,
            global_vars=global_vars,
            is_pure=is_pure,
            is_elemental=is_elemental,
            is_private=is_private,
            imports=imports,
            decorators=decorators,
            headers=headers,
            templates=templates,
            doc_string=doc_string)

    func.set_fst(stmt)
    return func