def dispatch_var_type(self, tree):
    """Write the target-language type declaration for a type annotation node.

    The annotation is handled according to the first rule that matches:

    * its unparsed (and stripped) source is a key of
      ``PYTHON_FORTRAN_TYPE_PAIRS``: the mapped type name is written, followed
      by ``*<precision>`` when a precision is given;
    * ``_match_array`` accepts it: the element type is written recursively,
      followed by ``, dimension(...)``;
    * ``_match_io`` accepts it: ``integer`` is written;
    * it is a call to the name ``type``: it is dispatched unchanged.

    Raises:
        NotImplementedError: when none of the rules above applies.
    """
    type_key = horast.unparse(tree).strip()
    if type_key in PYTHON_FORTRAN_TYPE_PAIRS:
        fortran_type, kind = PYTHON_FORTRAN_TYPE_PAIRS[type_key]
        self.write(fortran_type)
        if kind is not None:
            self.write('*')
            self.write(str(kind))
        return

    if _match_array(tree):
        index = tree.slice
        assert isinstance(index, typed_ast3.Index), typed_astunparse.dump(tree)
        assert isinstance(index.value, typed_ast3.Tuple)
        params = index.value.elts
        # Subscript holds either 2 elements or 3 (the third being the sizes).
        assert len(params) in (2, 3), params
        self.dispatch_var_type(params[1])
        self.write(', ')
        self.write('dimension(')
        if len(params) == 2:
            # No explicit sizes available.
            self.write(':')
        elif not self._context_input_args:
            self.dispatch(params[2])
        else:
            shape = params[2]
            assert isinstance(shape, typed_ast3.Tuple)
            _LOG.warning('coercing indices of %i dims to 0-based', len(shape.elts))

            def zero_based_range(extent):
                """Build the slice ``0:extent - 1`` for a single dimension."""
                if isinstance(extent, typed_ast3.Num):
                    assert isinstance(extent.n, int)
                    last = typed_ast3.Num(n=extent.n - 1)
                else:
                    assert isinstance(extent, typed_ast3.Name)
                    last = typed_ast3.BinOp(
                        left=extent, op=typed_ast3.Sub(), right=typed_ast3.Num(n=1))
                return typed_ast3.Slice(lower=typed_ast3.Num(n=0), upper=last, step=None)

            ranges = [zero_based_range(extent) for extent in shape.elts]
            self.dispatch(typed_ast3.Tuple(elts=ranges, ctx=typed_ast3.Load()))
        self.write(')')
        return

    if _match_io(tree):
        self.write('integer')
        return

    is_type_call = isinstance(tree, typed_ast3.Call) \
        and isinstance(tree.func, typed_ast3.Name) and tree.func.id == 'type'
    if is_type_call:
        self.dispatch(tree)
        return

    raise NotImplementedError('not yet implemented: {}'.format(
        typed_astunparse.dump(tree)))
def pos_as_tuple(node: Union[ast3.expr, ast3.stmt]) -> Optional[ast3.Tuple]:
    """Build an AST tuple ``((lineno, col_offset), fn)`` describing a node's position.

    Returns ``None`` when ``node`` has no ``lineno`` attribute. The second
    element of the outer tuple is a load of the name ``fn``.
    """
    if hasattr(node, 'lineno'):
        line = ast3.Num(node.lineno)
        column = ast3.Num(node.col_offset)
        position = ast3.Tuple(elts=[line, column], ctx=ast3.Load())
        fn_ref = ast3.Name(id='fn', ctx=ast3.Load())
        return ast3.Tuple(elts=[position, fn_ref], ctx=ast3.Load())
    return None
def inferred_range_args(node: typed_ast3.Call) -> t.Tuple[
        typed_ast3.AST, typed_ast3.AST, typed_ast3.AST]:
    """Expand the arguments of a range() call into (begin, end, step).

    A missing begin is filled in with the literal 0 and a missing step with
    the literal 1; when the call has neither one nor two arguments, its
    arguments are returned as a tuple unchanged.
    """
    assert match_range_call(node), type(node)
    call_args = node.args
    count = len(call_args)
    if count == 1:
        return typed_ast3.Num(n=0), call_args[0], typed_ast3.Num(n=1)
    if count == 2:
        begin, end = call_args
        return begin, end, typed_ast3.Num(n=1)
    return tuple(call_args)
def visit_Constant(self, node):  # pylint: disable=invalid-name
    """Convert a Constant node into a Num or a Str.

    ``int`` constants become ``Num`` holding ``int(value)``; ``string``
    constants must be wrapped in double quotes and become ``Str`` with the
    quotes removed. Any other constant type is handed to ``generic_visit``.
    """
    kind = node.type
    text = node.value
    # The coordinate is visited, but its result is not used.
    self.visit(node.coord)
    if kind == 'int':
        return typed_ast3.Num(int(text))
    if kind == 'string':
        assert text[0] == '"' and text[-1] == '"', text
        unquoted = text[1:-1]
        return typed_ast3.Str(unquoted, '')
    return self.generic_visit(node)
def visit_UnaryOp(self, node):  # pylint: disable=invalid-name
    """Convert a UnaryOp node using the C_UNARY_OPERATORS_TO_PYTHON table.

    The table maps the operator to a pair (node type, operator). Depending on
    the node type the result is:

    * ``Call``: a call of the name given by the operator, with the operand as
      the only argument;
    * ``AugAssign``: an augmented assignment of the literal 1 to the operand;
    * anything else: ``node_type(op=<operator instance>, operand=<operand>)``.
    """
    node_type, py_op = C_UNARY_OPERATORS_TO_PYTHON[node.op]
    operand = self.visit(node.expr)
    # The coordinate is visited, but its result is not used.
    self.visit(node.coord)
    if node_type is typed_ast3.Call:
        callee = typed_ast3.Name(id=py_op, ctx=typed_ast3.Load())
        return node_type(func=callee, args=[operand], keywords=[])
    if node_type is typed_ast3.AugAssign:
        return node_type(target=operand, op=py_op(), value=typed_ast3.Num(n=1))
    return node_type(op=py_op(), operand=operand)
def _transform_print_call(call):
    """Rewrite the arguments of a print call in place and return the call.

    The call is flagged with ``fortran_metadata['is_transformed'] = True``.
    When its only argument is itself a call of an attribute, the number left
    after removing ``format_label_`` from the attribute owner's name becomes
    the first argument, followed by the inner call's arguments. In every
    other case an ``Ellipsis`` node is inserted as the first argument.
    """
    try:
        metadata = call.fortran_metadata
    except AttributeError:
        metadata = call.fortran_metadata = {}
    metadata['is_transformed'] = True

    if len(call.args) == 1:
        (only_arg,) = call.args
        is_attribute_call = isinstance(only_arg, typed_ast3.Call) \
            and isinstance(only_arg.func, typed_ast3.Attribute)
        if is_attribute_call:
            owner_name = only_arg.func.value.id
            label = int(owner_name.replace('format_label_', ''))
            call.args = [typed_ast3.Num(n=label)] + only_arg.args
            return call

    call.args.insert(0, typed_ast3.Ellipsis())
    return call
def dispatch_for_iter(self, tree):
    """Write the bounds of a ``range(...)`` loop iterator as ``lower, last[, step]``.

    Anything other than a call to the name ``range`` with 1 to 3 arguments is
    reported through ``self._unsupported_syntax``. The written upper bound is
    inclusive: ``x + 1`` is written as ``x``, any other ``upper`` is written
    as ``upper - 1``.
    """
    is_supported = isinstance(tree, typed_ast3.Call) \
        and isinstance(tree.func, typed_ast3.Name) and tree.func.id == 'range' \
        and len(tree.args) in (1, 2, 3)
    if not is_supported:
        self._unsupported_syntax(tree)

    range_args = tree.args
    if len(range_args) == 1:
        begin, end, stride = typed_ast3.Num(n=0), range_args[0], None
    else:
        # Pad with None so a missing step unpacks as None.
        begin, end, stride = (range_args + [None, None])[:3]

    self.dispatch(begin)
    self.write(', ')

    ends_with_plus_one = isinstance(end, typed_ast3.BinOp) \
        and isinstance(end.op, typed_ast3.Add) \
        and isinstance(end.right, typed_ast3.Num) and end.right.n == 1
    if ends_with_plus_one:
        last = end.left
    else:
        last = typed_ast3.BinOp(left=end, op=typed_ast3.Sub(), right=typed_ast3.Num(n=1))
    self.dispatch(last)

    if stride is not None:
        self.write(', ')
        self.dispatch(stride)
def visit_ListComp(self, node):
    """Lower a list comprehension into a temporary list filled by a for loop.

    Only a single ``for`` clause with at most one ``if`` is accepted. Two
    items are appended to ``self.prepends``: a string initialising the
    temporary variable with an empty list literal, and a ``For`` node that
    appends each element to it. The temporary variable is returned.

    Raises:
        InvalidOperation: for more than one ``for`` clause or ``if`` clause.
    """
    from parser.functions import FunctionImplementation

    generators = node.generators
    if len(generators) > 1:
        raise InvalidOperation(
            "Only one for statement permitted in comprehensions")
    generator = generators[0]
    conditions = generator.ifs
    if len(conditions) > 1:
        raise InvalidOperation(
            "Only one if statement allowed in List Comprehension")

    # Determine the element type from a synthetic function that binds the
    # loop target to the first item of the iterable and returns the element.
    first_item = ast.Subscript(value=generator.iter, slice=ast.Index(ast.Num(0)))
    bind_target = ast.Assign(targets=[generator.target], value=first_item)
    return_element = ast.Return(value=node.elt)
    no_arguments = ast.arguments(args=[], vararg=None, kwonlyargs=[],
                                 kw_defaults=[], kwarg=None, defaults=[])
    probe = ast.FunctionDef(name="temp", args=no_arguments,
                            body=[bind_target, return_element])
    probe_impl = FunctionImplementation(probe, (), self.context)
    list_type = TypeDB.get_list([probe_impl.retval.tp])

    # Temporary list that collects the values.
    collected = self.context.get_temp_var(list_type)
    self.prepends.append(
        f"{collected.code} = {list_type.as_literal([])};\n")

    # Loop body: append the element, guarded by the condition if there is one.
    append_method = ast.Attribute(
        value=ast.Name(id=collected.code, ctx=ast.Load()),
        attr="append", ctx=ast.Load())
    loop_step = ast.Expr(
        ast.Call(func=append_method, args=[node.elt], keywords=[]))
    if conditions:
        loop_step = ast.If(test=conditions[0], body=[loop_step], orelse=[])

    loop = ast.For(target=generator.target, iter=generator.iter,
                   body=[loop_step], orelse=[])
    self.prepends.append(loop)
    return collected
def visit_Num(self, node: ast3.Num) -> VisitorOutput:
    """Wrap a numeric literal in a call to the matching ``pt`` constructor.

    Integers become ``pt.int(<n>)`` and floats become ``pt.float(<n>)``;
    for example, the number `3` is turned into `pt.int(3)`.

    Raises:
        AstTransformerError: when the literal is neither an int nor a float.
    """
    value = node.n
    if not isinstance(value, (int, float)):
        raise AstTransformerError(
            f"Number of type {type(node.n)} isn't supported by pytropos. Sorry :S"
        )

    wrapper_name = 'int' if isinstance(value, int) else 'float'
    wrapper = ast3.Attribute(
        value=ast3.Name(id='pt', ctx=ast3.Load()),
        attr=wrapper_name,
        ctx=ast3.Load())
    return ast3.Call(func=wrapper, args=[ast3.Num(n=value)], keywords=[])
def make_st_ndarray(data_type: typed_ast3.AST,
                    dimensions_or_sizes: t.Union[int, list]) -> typed_ast3.Subscript:
    """Create a typed_ast node equivalent to: st.ndarray[dimensions, data_type, sizes].

    Args:
        data_type: AST node used as the second element of the subscript.
        dimensions_or_sizes: either the number of dimensions (an int), or a
            list with one entry per dimension; each entry is converted with
            ``make_expression_from_slice``.

    Returns:
        A ``Subscript`` of ``st.ndarray`` whose index tuple is
        ``(dimensions, data_type)``, extended with a tuple of sizes when a
        non-empty list of sizes was given.
    """
    if isinstance(dimensions_or_sizes, int):
        dimensions = dimensions_or_sizes
        sizes = None
    else:
        dimensions = len(dimensions_or_sizes)
        sizes = [make_expression_from_slice(size) for size in dimensions_or_sizes]

    # The dimension count and data type are always present; only the sizes
    # tuple is optional. A conditional expression binds looser than ``+``, so
    # writing ``[...] + [...] if sizes else []`` would drop everything when
    # there are no sizes.
    elts = [typed_ast3.Num(n=dimensions), data_type]
    if sizes:
        elts.append(typed_ast3.Tuple(elts=sizes, ctx=typed_ast3.Load()))

    return typed_ast3.Subscript(
        value=typed_ast3.Attribute(
            value=typed_ast3.Name(id='st', ctx=typed_ast3.Load()),
            attr='ndarray', ctx=typed_ast3.Load()),
        slice=typed_ast3.Index(value=typed_ast3.Tuple(
            elts=elts, ctx=typed_ast3.Load())),
        ctx=typed_ast3.Load())
'np.argmin': 'minloc', 'np.argmax': 'maxloc', 'np.array': lambda _: _.args[0], 'np.conj': 'conjg', 'np.cos': 'cos', 'np.dot': 'dot_product', 'np.finfo.eps': 'epsilon', 'np.finfo.max': 'huge', 'np.finfo.tiny': 'tiny', 'np.maximum': 'max', 'np.minimum': 'min', 'np.sign': 'sign', 'np.sin': 'sin', 'np.sinh': 'sinh', 'np.sqrt': 'sqrt', 'np.zeros': lambda _: typed_ast3.Num(n=0), 'print': _transform_print_call, 'os.environ': 'getenv', 'is_not_none': 'present', 'MPI.Init': 'MPI_Init', 'MPI.COMM_WORLD.Comm_size': 'MPI_Comm_size', 'MPI.COMM_WORLD.Comm_rank': 'MPI_Comm_rank', 'MPI.COMM_WORLD.Barrier': 'MPI_Barrier', 'MPI.Bcast': 'MPI_Bcast', 'MPI.Allreduce': 'MPI_Allreduce', 'MPI.Finalize': 'MPI_Finalize', '{expression}.sum': None, 'Fortran.file_handles[{name}].read': None, 'Fortran.file_handles[{name}].close': None, '{name}.rstrip': None, 'slice': make_slice_from_call
def _get_code():
    """Return the source of `initial_code` with the snippet body appended.

    The statements produced by ``my_snippet.get_body`` are added to the end of
    the body of the first top-level statement of the parsed code, and the
    resulting tree is unparsed.
    """
    module = ast.parse(initial_code)
    first_statement_body = module.body[0].body
    snippet_statements = my_snippet.get_body(
        class_name='MyClass', x_value=ast.Num(10))
    first_statement_body.extend(snippet_statements)
    return unparse(module)