def visit_Attribute(self, node: ast3.Attribute) -> VisitorOutput:
    """Rewrite attribute access so it goes through a PythonValue's ``attr`` mapping.

    For example, it converts::

        expr.val

    into::

        expr.attr['val']

    or, when ``pos_as_tuple`` yields a position::

        expr.attr[('val', pos)]
    """
    # Transform the children (notably ``node.value``) first.
    self.generic_visit(node)
    position = pos_as_tuple(node)
    key = ast3.Str(s=node.attr)  # type: ast3.expr
    if position is not None:
        key = ast3.Tuple(elts=[key, position], ctx=ast3.Load())
    attr_table = ast3.Attribute(value=node.value, attr='attr', ctx=ast3.Load())
    # Keep the original context so stores/deletes stay stores/deletes.
    return ast3.Subscript(value=attr_table, slice=ast3.Index(value=key), ctx=node.ctx)
def _Class(self, node: ET.Element):  # pylint: disable=invalid-name
    """Translate a ``Class`` XML element into a Python base-class expression.

    The class name is resolved in ``self.namespaces[context]``. Any template
    arguments (``name<A, B>``) are stripped off and logged. The Python
    counterpart is parsed from ``CPP_PYTHON_CLASS_PAIRS[full_name]``.

    For a class in ``CPP_STL_CLASSES`` that has template arguments, the member
    typedef named ``value_type`` is located. The type it references is marked
    as relevant (if not already), and the result is subscripted with that
    type id as a string, e.g. ``<base>['_42']``.

    Returns:
        The typed_ast3 expression for the base class.
    """
    context = node.attrib['context']
    assert context in self.namespaces, context
    cls_name = node.attrib['name']
    # Bound up front: non-template classes have no generic arguments. Without
    # this, a non-template class in CPP_STL_CLASSES raised UnboundLocalError.
    generic_args = []
    if '<' in cls_name:
        _LOG.warning('processing template class %s', cls_name)
        assert '>' in cls_name
        cls_name, _, rest = cls_name.partition('<')
        rest = rest[:-1]  # drop the closing '>'
        generic_args = [arg.strip() for arg in rest.split(',')]
        _LOG.warning('found generic args: %s', generic_args)
    full_name = '{}::{}'.format(self.namespaces[context], cls_name)
    is_stl_class = full_name in CPP_STL_CLASSES and bool(generic_args)
    value_type = None
    # TODO: handle non-STL classes too
    if is_stl_class:
        for member_id in node.attrib['members'].split():
            if member_id not in self.all_types:
                continue
            member_type = self.all_types[member_id]
            if member_type.tag != 'Typedef' or member_type.attrib['name'] != 'value_type':
                continue
            referenced_id = member_type.attrib['type']
            assert referenced_id in self.all_types
            if referenced_id not in self.relevant_types \
                    and referenced_id not in self._new_relevant_types:
                self._new_relevant_types[referenced_id] = self.all_types[referenced_id]
                _LOG.debug(
                    'type marked as relevant due to being container value type %s',
                    ET.tostring(self.all_types[referenced_id]).decode().rstrip())
            value_type = referenced_id
    base_class = typed_ast3.parse(CPP_PYTHON_CLASS_PAIRS[full_name], mode='eval').body
    if is_stl_class:
        assert value_type is not None
        base_class = typed_ast3.Subscript(
            value=base_class,
            slice=typed_ast3.Index(typed_ast3.Str(value_type, '')),
            ctx=typed_ast3.Load())
    return base_class
def visit_Assign(self, node):
    """Annotate an assignment with the shape recorded for its line.

    When ``node.lineno`` is a key of ``self.line2shape``, the
    ``Assign(targets, value)`` node is replaced by
    ``AnnAssign(target, annotation, value, simple=1)``. The annotation is a
    string node holding ``str(shape)``. Any other assignment is returned
    unchanged (children are not visited).
    """
    lineno = node.lineno
    if lineno not in self.line2shape:
        return node
    targets = node.targets
    # Only single-target assignments can become an AnnAssign.
    assert len(targets) == 1
    shape = self.line2shape[lineno]
    col = node.col_offset
    annotation = ast.Str(s=f'{shape}', kind='', lineno=lineno, col_offset=col)
    return ast.AnnAssign(target=targets[0], annotation=annotation, value=node.value,
                         simple=1, lineno=lineno, col_offset=col)
def _emulate_yield_from(self, targets: Optional[List[ast.Name]],
                        node: ast.YieldFrom) -> Iterable[ast.AST]:
    """Yield the statements that emulate ``yield from <node.value>``.

    Two statements are produced:

    1. ``<gen> = iter(<node.value>)``
    2. ``while True:`` around a ``try`` that yields ``next(<gen>)``; on
       ``StopIteration`` (bound as ``<exc>``) it breaks out. If ``targets`` is
       non-empty, it first assigns ``<exc>.value`` to them when
       ``hasattr(<exc>, 'value')``.

    ``<gen>``/``<exc>`` embed ``self._name_suffix``, which is incremented only
    after this generator has been fully consumed.
    """
    suffix = self._name_suffix
    generator = ast.Name(id='_py_backwards_generator_{}'.format(suffix))
    exception = ast.Name(id='_py_backwards_generator_exception_{}'.format(suffix))

    yield ast.Assign(targets=[generator],
                     value=ast.Call(func=ast.Name(id='iter'), args=[node.value], keywords=[]))

    on_stop = [ast.Break()]
    if targets:
        # Propagate the sub-generator's return value, if it has one.
        copy_value = ast.If(
            test=ast.Call(func=ast.Name(id='hasattr'),
                          args=[exception, ast.Str(s='value')],
                          keywords=[]),
            body=[ast.Assign(targets=targets,
                             value=ast.Attribute(value=exception, attr='value'))],
            orelse=[])
        on_stop.insert(0, copy_value)

    next_item = ast.Expr(value=ast.Yield(value=ast.Call(
        func=ast.Name(id='next'), args=[generator], keywords=[])))
    guarded = ast.Try(
        body=[next_item],
        handlers=[ast.ExceptHandler(type=ast.Name(id='StopIteration'),
                                    name=exception.id,
                                    body=on_stop)],
        orelse=[],
        finalbody=[])
    yield ast.While(test=ast.NameConstant(value=True), body=[guarded], orelse=[])
    self._name_suffix += 1
def visit_JoinedStr(self, node: ast.JoinedStr) -> ast.Call:
    """Replace an f-string with ``''.join([<parts>])``.

    The parts are left to ``generic_visit`` for further rewriting, and the
    transformer records that it changed the tree.
    """
    self._tree_changed = True
    joiner = ast.Attribute(value=ast.Str(s=''), attr='join')
    parts = ast.List(elts=node.values)
    return self.generic_visit(ast.Call(func=joiner, args=[parts], keywords=[]))  # type: ignore
def delete_declaration(declaration):
    """Swap declarations that must not survive inlining for explanatory comments.

    - ``Import``/``ImportFrom`` nodes become a 'skipping a "use" statement' comment.
    - ``Assign``/``AnnAssign`` nodes whose ``fortran_metadata`` has an ``intent``
      of ``in``/``out``/``inout``, or a truthy ``is_declaration``, become a comment.
    - Every other node is returned unchanged.
    """
    if isinstance(declaration, (typed_ast3.Import, typed_ast3.ImportFrom)):
        # TODO: it's a hack
        return horast_nodes.Comment(
            typed_ast3.Str(' skipping a "use" statement when inlining', ''), eol=False)
    if not isinstance(declaration, (typed_ast3.Assign, typed_ast3.AnnAssign)):
        return declaration
    metadata = getattr(declaration, 'fortran_metadata', {})
    intent = metadata.get('intent', None)
    if intent in {'in', 'out', 'inout'}:
        message = ' skipping intent({}) declaration when inlining'.format(intent)
        return horast_nodes.Comment(typed_ast3.Str(message, ''), eol=False)
    if metadata.get('is_declaration', False):
        # TODO: it's a hack
        return horast_nodes.Comment(
            typed_ast3.Str(' skipping a declaration when inlining', ''), eol=False)
    return declaration
def visit_FormattedValue(self, node: ast.FormattedValue) -> ast.Call:
    """Replace an f-string replacement field with an equivalent ``str.format`` call.

    Examples::

        f'{x}'        ->  '{}'.format(x)
        f'{x!r:>10}'  ->  '{!r:>10}'.format(x)
        f'{x:{w}}'    ->  '{:{}}'.format(x, <spec>)

    A format spec made only of literal text is inlined into the template. A
    spec containing nested replacement fields is passed as an extra
    ``format`` argument, which ``str.format`` substitutes into the spec.
    The conversion (``!s``/``!r``/``!a``) is preserved.
    """
    # ``conversion`` is -1 (or None) when no ``!s``/``!r``/``!a`` was written.
    conversion = node.conversion
    conv = '!' + chr(conversion) if conversion is not None and conversion >= 0 else ''
    args = [node.value]
    spec = node.format_spec
    if spec is None:
        template = '{' + conv + '}'
    else:
        # ``format_spec`` is a JoinedStr (not a Str), so its text lives in its
        # ``values``; a bare string node is accepted as well.
        parts = getattr(spec, 'values', [spec])
        texts = [getattr(part, 's', getattr(part, 'value', None)) for part in parts]
        if all(isinstance(text, str) for text in texts):
            template = '{' + conv + ':' + ''.join(texts) + '}'
        else:
            # Dynamic spec: let str.format fill the nested field at runtime.
            template = '{' + conv + ':{}}'
            args.append(spec)
    format_call = ast.Call(func=ast.Attribute(value=ast.Str(s=template), attr='format'),
                           args=args,
                           keywords=[])
    return self.generic_visit(format_call)  # type: ignore
def test_non_str_type_comment(self):
    """Resolving non-str type comments must log at the expected level.

    A ``Str``-node type comment must produce a log record at DEBUG or above;
    an ``int`` type comment must produce one at WARNING or above.
    """
    def make_assign(type_comment):
        return typed_ast3.Assign(
            targets=[typed_ast3.Name('x', typed_ast3.Store())],
            value=typed_ast3.Str('universe, life, and everything'),
            type_comment=type_comment)

    cases = [
        (make_assign(typed_ast3.Str('42')), logging.DEBUG),
        (make_assign(42), logging.WARNING),
    ]
    for example, expected_level in cases:
        resolver = TypeHintResolver[typed_ast3, typed_ast3](eval_=False)
        with self.subTest(example=example, expected_level=expected_level):
            with self.assertLogs(level=expected_level):
                resolver.visit(example)
def visit_Constant(self, node):  # pylint: disable=invalid-name
    """Transform a C ``Constant`` node into ``Num`` (int) or ``Str`` (string).

    The node exposes ``type``, ``value`` (the literal's source text) and
    ``coord`` — presumably a pycparser ``c_ast.Constant``; confirm.

    Integer literals follow C rules:

    - any ``u``/``U``/``l``/``L`` suffix is dropped;
    - ``0x``/``0X`` means hexadecimal and ``0b``/``0B`` binary;
    - a leading ``0`` means octal.

    String literals lose their surrounding double quotes. Any other constant
    type falls back to ``generic_visit``.

    Raises:
        ValueError: if an ``int`` literal is not a valid C integer constant.
    """
    type_ = node.type
    value = node.value
    _ = self.visit(node.coord)
    if type_ in ('int', ):
        digits = value.rstrip('uUlL')
        if digits[:2].lower() in ('0x', '0b'):
            number = int(digits, 0)
        elif len(digits) > 1 and digits[0] == '0':
            # C reads a leading zero as octal; plain int() would turn '010' into 10.
            number = int(digits, 8)
        else:
            number = int(digits)
        return typed_ast3.Num(number)
    if type_ in ('string', ):
        assert value[0] == '"' and value[-1] == '"', value
        return typed_ast3.Str(value[1:-1], '')
    return self.generic_visit(node)
def visit_Print(self, n):
    """Convert a ``Print`` statement node into an ``Expr(print(...))`` call.

    ``n.dest`` (when not None) becomes a ``file=`` keyword. A false ``n.nl``
    becomes ``end=" "``. Synthesized nodes take ``n.lineno`` and
    ``col_offset=-1``.
    """
    line = n.lineno
    keywords = []
    if n.dest is not None:
        keywords.append(ast3.keyword("file", self.visit(n.dest)))
    if not n.nl:
        keywords.append(ast3.keyword("end", ast3.Str(" ", lineno=line, col_offset=-1)))
    print_name = ast3.Name("print", ast3.Load(), lineno=line, col_offset=-1)
    call = ast3.Call(print_name, self.visit(n.values), keywords, lineno=line, col_offset=-1)
    return ast3.Expr(call)
def visit_Name(self, node: ast3.Name) -> VisitorOutput:
    """Rewrite a bare name lookup as a subscript on the ``st`` store.

    For example, it converts::

        var

    into::

        st[('var', ...)]

    The plain ``st['var']`` form is used when ``pos_as_tuple`` returns None.
    """
    position = pos_as_tuple(node)
    key = ast3.Str(s=node.id)  # type: ast3.expr
    if position is not None:
        key = ast3.Tuple(elts=[key, position], ctx=ast3.Load())
    store = ast3.Name(id='st', ctx=ast3.Load())
    # Preserve the original context (Load/Store/Del).
    return ast3.Subscript(value=store, slice=ast3.Index(value=key), ctx=node.ctx)
def _PointerType(self, node: ET.Element):  # pylint: disable=invalid-name
    """Translate a pointer-type XML element into a pointer type annotation.

    The referenced type id comes from ``node.attrib['type']``. A trailing
    ``c`` on it is stripped and makes the result wrapped with ``make_const``.
    The referenced type is recorded in ``self._new_relevant_types`` unless it
    is already known to be relevant.

    Returns:
        ``make_pointer(Str(type_id))``, possibly wrapped by ``make_const``.
    """
    type_id = node.attrib['type']
    is_const = type_id.endswith('c')
    if is_const:
        type_id = type_id[:-1]
    assert type_id in self.all_types
    if type_id not in self.relevant_types and type_id not in self._new_relevant_types:
        self._new_relevant_types[type_id] = self.all_types[type_id]
        _LOG.debug('type marked as relevant through a pointer: %s',
                   ET.tostring(self.all_types[type_id]).decode().rstrip())
    type_info = make_pointer(typed_ast3.Str(type_id, ''))
    if is_const:
        type_info = make_const(type_info)
    return type_info
def _inline_call(self, call, replacers):
    """Build the statements that replace ``call`` by the inlined function body.

    Each statement of ``self._inlined_function.body`` is deep-copied, passed
    through ``st.augment(..., eval_=False)``, then through each replacer's
    ``visit`` in order. A replacer returning ``None`` drops the statement, and
    the remaining replacers are skipped for it. With ``self._verbose`` set, the
    result is framed by ``inlined <call>`` / ``end of inlined <call>`` comments.

    Returns:
        The single statement if exactly one results, otherwise the list of them.
    """
    call_code = typed_astunparse.unparse(call).strip()
    inlined_statements = []
    if self._verbose:
        inlined_statements.append(horast_nodes.Comment(
            value=typed_ast3.Str(' inlined {}'.format(call_code), ''), eol=False))
    for stmt in self._inlined_function.body:
        stmt = st.augment(copy.deepcopy(stmt), eval_=False)
        for replacer in replacers:
            stmt = replacer.visit(stmt)
            if stmt is None:
                # Statement removed; visiting None with the next replacer would fail
                # (assuming NodeVisitor-style replacers).
                break
        if stmt is not None:
            inlined_statements.append(stmt)
    if self._verbose:
        inlined_statements.append(horast_nodes.Comment(
            value=typed_ast3.Str(' end of inlined {}'.format(call_code), ''), eol=False))
    _LOG.warning('inlining a call %s using replacers %s', call_code, replacers)
    assert inlined_statements
    if len(inlined_statements) == 1:
        return inlined_statements[0]
    return inlined_statements
def test_ast_validator_synthetic(self):
    """Hand-built nodes must pass ``AstValidator`` with ``fields_first`` off and on."""
    formatted = typed_ast3.FormattedValue(typed_ast3.Str('value'), None, None)
    # keyword with arg=None, i.e. a ``**value`` argument
    double_star = typed_ast3.keyword(None, typed_ast3.Name('value', typed_ast3.Load()))
    import_from = typed_ast3.ImportFrom('pkg', [typed_ast3.alias('module', None)], None)
    examples = [(typed_ast3, example) for example in (formatted, double_star, import_from)]
    for fields_first in (False, True):
        for ast_module, example in examples:
            with self.subTest(example=example):
                validator = AstValidator[ast_module](fields_first=fields_first, mode=None)
                validator.visit(example)
def test_partial_inline_burn(self):
    """Inline ``bn_mapNetworkToSpecies`` into ``Burn.F90`` and write the result.

    The last top-level object of each generalized file is taken as the
    function. The inlined loop is annotated with a ``$acc parallel loop``
    directive, and the unparsed Fortran goes to
    ``<APPS_RESULTS_ROOT>/flash5-inlined/Burn.inlined_some.F90``.
    """
    burn_dir = self.app_source_folder.joinpath(
        'physics', 'sourceTerms', 'Burn', 'BurnMain', 'nuclearBurn')
    inlined_path = burn_dir.joinpath('Aprox13', 'bn_mapNetworkToSpecies.F90')
    target_path = burn_dir.joinpath('Burn.F90')

    reader = CodeReader()
    inlined_code = reader.read_file(inlined_path)
    target_code = reader.read_file(target_path)

    parser = Parser.find(Language.find('Fortran'))()
    inlined_fortran_ast = parser.parse(inlined_code, inlined_path)
    target_fortran_ast = parser.parse(target_code, target_path)

    generalizer = AstGeneralizer.find(Language.find('Fortran'))()
    # TODO: implement object finding instead of taking the last top-level object
    inlined_function = generalizer.generalize(inlined_fortran_ast).body[-1]
    target_function = generalizer.generalize(target_fortran_ast).body[-1]

    inlined_syntax = inline_syntax(target_function, inlined_function, verbose=True)
    directive = horast_nodes.Directive(typed_ast3.Str('$acc parallel loop', ''))
    annotate_loop_syntax(inlined_syntax, directive)

    fortran_unparser = Unparser.find(Language.find('Fortran'))()
    transformed_code = fortran_unparser.unparse(inlined_syntax)
    results_path = pathlib.Path(APPS_RESULTS_ROOT, 'flash5-inlined')
    results_path.mkdir(exist_ok=True)
    CodeWriter().write_file(transformed_code, results_path.joinpath('Burn.inlined_some.F90'))
def _PointerType(self, node: ET.Element):  # pylint: disable=invalid-name
    """Return ``(id, annotation)`` for a pointer-type XML element.

    The annotation is ``Pointer[<base>]``, further wrapped as ``Const[...]``
    when the referenced type id ends in ``c``. ``<base>`` is the matching
    entry of ``self.fundamental_types``, or the type id as a string node when
    there is none.
    """
    def wrap(wrapper_name, inner):
        return typed_ast3.Subscript(
            value=typed_ast3.Name(id=wrapper_name, ctx=typed_ast3.Load()),
            slice=typed_ast3.Index(inner),
            ctx=typed_ast3.Load())

    id_ = node.attrib['id']
    type_ = node.attrib['type']
    is_const = type_.endswith('c')
    if is_const:
        type_ = type_[:-1]
    try:
        base_type = self.fundamental_types[type_]
    except KeyError:
        # Not a fundamental type: refer to it by its id.
        base_type = typed_ast3.Str(type_, '')
    annotation = wrap('Pointer', base_type)
    if is_const:
        annotation = wrap('Const', annotation)
    return id_, annotation
def visit_ImportFrom(self, node: ast3.ImportFrom) -> VisitorOutput:
    """Defines how to import (from) modules (supported and nonsupported)

    For example, it converts::

        from numpy import array
        from numpy import *
        from somelib import var, othervar as var2
        from otherlib import *

    into::

        from pytropos.libs_checking import numpy_module
        st['array'] = numpy_module.attr['array', pos...]

        from pytropos.libs_checking import numpy_module
        st.importStar(numpy_module)

        st['var'] = pt.Top
        st['var2'] = pt.Top

        st.importStar()

    Relative imports (``from .numpy import array``) name a local module, not
    the supported library, so they always take the unsupported path.
    """
    libs: 'List[ast3.AST]' = []
    star_import = node.names[0].name == '*'
    # A non-zero ``level`` marks a relative import; its ``module`` merely shares
    # a name with a supported library.
    if not node.level and node.module in self._supported_modules:
        module_name = self._supported_modules[node.module]
        # from pytropos.libs_checking import module_name
        libs.append(
            ast3.ImportFrom(
                module='pytropos.libs_checking',
                names=[ast3.alias(name=module_name, asname=None)],
                level=0,
            ))
        if star_import:
            # st.importStar(module_name)
            libs.append(
                ast3.Expr(value=ast3.Call(
                    func=ast3.Attribute(
                        value=ast3.Name(id='st', ctx=ast3.Load()),
                        attr='importStar',
                        ctx=ast3.Load(),
                    ),
                    args=[ast3.Name(id=module_name, ctx=ast3.Load())],
                    keywords=[],
                )))
        else:
            for alias in node.names:
                # st['asname'] = module_name.attr['name']
                pos = pos_as_tuple(node)
                if pos is not None:
                    attrname = ast3.Tuple(
                        elts=[ast3.Str(s=alias.name), pos],
                        ctx=ast3.Load())  # type: ast3.expr
                else:
                    attrname = ast3.Str(s=alias.name)
                libs.append(
                    ast3.Assign(
                        targets=[
                            ast3.Subscript(
                                value=ast3.Name(id='st', ctx=ast3.Load()),
                                slice=ast3.Index(value=ast3.Str(s=alias.asname or alias.name)),
                                ctx=ast3.Store(),
                            ),
                        ],
                        value=ast3.Subscript(
                            value=ast3.Attribute(
                                value=ast3.Name(id=module_name, ctx=ast3.Load()),
                                attr='attr',
                                ctx=ast3.Load(),
                            ),
                            slice=ast3.Index(value=attrname),
                            ctx=ast3.Load(),
                        ),
                    ))
    elif star_import:
        # st.importStar()
        libs.append(
            ast3.Expr(value=ast3.Call(
                func=ast3.Attribute(
                    value=ast3.Name(id='st', ctx=ast3.Load()),
                    attr='importStar',
                    ctx=ast3.Load(),
                ),
                args=[],
                keywords=[],
            )))
    else:
        # st['asname'] = pt.Top, one assignment per imported name
        libs.extend(
            ast3.parse(  # type: ignore
                '\n'.join(
                    "st['{asname}'] = pt.Top".format(asname=alias.asname or alias.name)
                    for alias in node.names)).body)
    return libs
def visit_Call(self, node: ast3.Call) -> VisitorOutput:
    """Transforms a call to be handled by Pytropos

    For example, it converts::

        func(3, b, *c, d=2)

    into::

        func.call(st, Args((pt.int(3), st['b']), st['c'], {'d': pt.int(2)}), pos=...)

    Raises:
        AstTransformerError: if any positional argument follows a starred
            one, or if a ``**kwargs`` argument is used.
    """
    self.generic_visit(node)

    args = []  # type: List[ast3.expr]
    starred = None  # type: Optional[ast3.expr]
    kwargs_keys = []  # type: List[ast3.expr]
    kwargs_values = []  # type: List[ast3.expr]

    last_index = len(node.args) - 1
    for i, arg in enumerate(node.args):
        if isinstance(arg, ast3.Starred):
            # Only one starred expression, in last position, is supported.
            # (Previously this check sat in a for-else and could never fire,
            # silently dropping any arguments after the starred one.)
            if i < last_index:
                raise AstTransformerError(
                    f"{self.filename}:{arg.lineno}:{arg.col_offset}: Fatal Error: "
                    "Only one expression starred is allowed when calling a function"
                )
            starred = arg.value
            break
        args.append(arg)

    for kw in node.keywords:
        if kw.arg is None:
            # Report the call's own position: ``kw`` has none in older ASTs and
            # its (already transformed) value may not carry one either.
            raise AstTransformerError(
                f"{self.filename}:{node.lineno}:{node.col_offset}: Fatal Error: "
                "No kargs parameters is allowed when calling a function")
        kwargs_keys.append(ast3.Str(s=kw.arg))
        kwargs_values.append(kw.value)

    new_call_args = [
        ast3.Tuple(
            elts=args,
            ctx=ast3.Load(),
        ),
    ]  # type: List[ast3.expr]
    if kwargs_keys:
        new_call_args.append(
            ast3.NameConstant(value=None) if starred is None else starred)
        new_call_args.append(
            ast3.Dict(keys=kwargs_keys, values=kwargs_values))
    elif starred is not None:
        new_call_args.append(starred)

    pos = pos_as_tuple(node)
    if pos is None:
        # A keyword needs a value node; pass an explicit ``pos=None`` rather than
        # a raw None. NOTE(review): confirm ``call`` accepts pos=None.
        pos = ast3.NameConstant(value=None)

    return ast3.Call(
        func=ast3.Attribute(
            value=node.func,
            attr='call',
            ctx=ast3.Load(),
        ),
        args=[
            ast3.Name(id='st', ctx=ast3.Load()),
            ast3.Call(
                func=ast3.Attribute(
                    value=ast3.Name(id='pt', ctx=ast3.Load()),
                    attr='Args',
                    ctx=ast3.Load(),
                ),
                args=new_call_args,
                keywords=[],
            )
        ],
        keywords=[
            ast3.keyword(arg='pos', value=pos)
        ],
    )
def visit_Str(self, s):
    """Rebuild a ``Str`` node as ``ast3.Bytes`` or ``ast3.Str``.

    A ``bytes`` payload in ``s.s`` yields ``ast3.Bytes``; any other payload
    yields ``ast3.Str``. The ``kind`` is carried over in both cases.
    """
    node_type = ast3.Bytes if isinstance(s.s, bytes) else ast3.Str
    return node_type(s.s, s.kind)