def make_expr(expr: ExprType):
    """
    Convert a concrete SIR expression node into the generic Expr container.

    An argument that is already an Expr is handed back untouched. Any other
    supported node is copied (via ``CopyFrom``) into the matching field of a
    freshly created Expr.

    :param expr: Concrete expression node (e.g. VarAccessExpr) or an Expr
    :return: Expr holding a copy of ``expr`` (or ``expr`` itself if it is
             already an Expr)
    :raises SIRError: If ``expr`` is not one of the supported node types
    """
    if isinstance(expr, Expr):
        return expr

    # Each concrete node type paired with the Expr field that stores it.
    # The first isinstance() match decides which field gets populated.
    field_for_type = (
        (UnaryOperator, "unary_operator"),
        (BinaryOperator, "binary_operator"),
        (AssignmentExpr, "assignment_expr"),
        (TernaryOperator, "ternary_operator"),
        (FunCallExpr, "fun_call_expr"),
        (StencilFunCallExpr, "stencil_fun_call_expr"),
        (StencilFunArgExpr, "stencil_fun_arg_expr"),
        (VarAccessExpr, "var_access_expr"),
        (FieldAccessExpr, "field_access_expr"),
        (LiteralAccessExpr, "literal_access_expr"),
    )

    for node_type, field_name in field_for_type:
        if isinstance(expr, node_type):
            container = Expr()
            getattr(container, field_name).CopyFrom(expr)
            return container

    raise SIRError("cannot create Expr from type {}".format(type(expr)))
def make_stmt(stmt: StmtType):
    """
    Wrap a concrete statement (e.g. ExprStmt) into a Stmt object.

    An argument that is already a Stmt is returned unchanged. Otherwise the
    statement is copied (via ``CopyFrom``) into the Stmt field that matches
    its concrete type.

    :param stmt: Statement to wrap (concrete statement node or Stmt)
    :return: Statement wrapped into Stmt
    :raises SIRError: If ``stmt`` is not one of the supported statement types
    """
    if isinstance(stmt, Stmt):
        return stmt

    wrapped_stmt = Stmt()
    if isinstance(stmt, BlockStmt):
        wrapped_stmt.block_stmt.CopyFrom(stmt)
    elif isinstance(stmt, ExprStmt):
        wrapped_stmt.expr_stmt.CopyFrom(stmt)
    elif isinstance(stmt, ReturnStmt):
        wrapped_stmt.return_stmt.CopyFrom(stmt)
    elif isinstance(stmt, VarDeclStmt):
        wrapped_stmt.var_decl_stmt.CopyFrom(stmt)
    elif isinstance(stmt, VerticalRegionDeclStmt):
        wrapped_stmt.vertical_region_decl_stmt.CopyFrom(stmt)
    elif isinstance(stmt, StencilCallDeclStmt):
        # Each statement kind must go into its own field: CopyFrom only
        # accepts a message of the field's exact type.
        wrapped_stmt.stencil_call_decl_stmt.CopyFrom(stmt)
    elif isinstance(stmt, BoundaryConditionDeclStmt):
        wrapped_stmt.boundary_condition_decl_stmt.CopyFrom(stmt)
    elif isinstance(stmt, IfStmt):
        wrapped_stmt.if_stmt.CopyFrom(stmt)
    else:
        raise SIRError("cannot create Stmt from type {}".format(type(stmt)))
    return wrapped_stmt
def makeAST(root: StmtType) -> AST:
    """
    Create an AST rooted at the given block statement.

    :param root: Root node of the AST. Must be a BlockStmt, or a Stmt whose
                 active ``stmt`` field is ``block_stmt``.
    :return: New AST whose ``root`` is a copy of ``root`` wrapped into a Stmt
    :raises SIRError: If ``root`` is not a block statement
    """
    is_block = isinstance(root, BlockStmt) or (
        isinstance(root, Stmt) and root.WhichOneof("stmt") == "block_stmt"
    )
    if not is_block:
        raise SIRError("root statement of an AST needs to be a BlockStmt")

    ast = AST()
    # make_stmt normalizes a bare BlockStmt into the Stmt wrapper that
    # ast.root holds; an already-wrapped Stmt passes through unchanged.
    ast.root.CopyFrom(make_stmt(root))
    return ast