def _make_apply_block(self, interval_block):
    """Wrap the non-declaration statements of ``interval_block`` in an ApplyBlock.

    Every entry of ``interval_block.stmts`` whose ``stmt`` is a ``gt_ir.Decl``
    is dropped; all other statements are kept in their original order.

    Args:
        interval_block: object exposing ``stmts`` (items with a ``.stmt``
            attribute) and ``interval``.

    Returns:
        gt_ir.ApplyBlock: the interval converted by ``self._make_axis_interval``
        paired with a ``gt_ir.BlockStmt`` of the kept statements.
    """
    kept_stmts = [
        info.stmt
        for info in interval_block.stmts
        if not isinstance(info.stmt, gt_ir.Decl)
    ]
    block_body = gt_ir.BlockStmt(stmts=kept_stmts)
    axis_interval = self._make_axis_interval(interval_block.interval)
    return gt_ir.ApplyBlock(interval=axis_interval, body=block_body)
def visit_If(self, node: ast.If) -> gt_ir.If:
    """Lower a Python ``if`` statement into a run-time ``gt_ir.If`` node.

    The ``if`` body and (when present) the ``else`` body are visited in that
    order, then the test expression is visited and turned into an expression
    via ``gt_ir.utils.make_expr``.

    Args:
        node: the ``ast.If`` node being visited.

    Returns:
        gt_ir.If: with ``else_body`` set to ``None`` when the ``else`` branch
        produced no statements.
    """

    def _lower_branch(py_stmts):
        # Each visited statement may yield one node or a list of nodes.
        lowered = []
        for py_stmt in py_stmts:
            lowered.extend(gt_utils.listify(self.visit(py_stmt)))
        assert all(isinstance(item, gt_ir.Statement) for item in lowered)
        return lowered

    then_stmts = _lower_branch(node.body)
    otherwise_stmts = _lower_branch(node.orelse) if node.orelse else []

    condition = gt_ir.utils.make_expr(self.visit(node.test))
    main_body = gt_ir.BlockStmt(stmts=then_stmts)
    if otherwise_stmts:
        else_body = gt_ir.BlockStmt(stmts=otherwise_stmts)
    else:
        else_body = None
    return gt_ir.If(condition=condition, main_body=main_body, else_body=else_body)
def _visit_interval_node(self, node: ast.With) -> gt_ir.ComputationBlock:
    """Lower a ``with interval(...):`` statement into a ``gt_ir.ComputationBlock``.

    Accepted interval specifications:

    * ``interval(...)`` -- the full interval;
    * ``interval(start, end)`` -- two positional bounds;
    * ``interval(start=..., end=...)`` -- keywords, in any order;
    * ``interval(start, end=...)`` -- positional start, keyword end.

    Each explicit bound must be an ``ast.Num``, ``ast.UnaryOp`` or
    ``ast.NameConstant`` node; the bound pair is validated by
    ``gt_ir.utils.make_axis_interval``.

    Args:
        node: the ``with`` statement. Its first item's context expression is
            expected to be the ``interval(...)`` call (NOTE(review): assumed
            to be checked by the caller -- it is not re-checked here).

    Returns:
        gt_ir.ComputationBlock: with PARALLEL iteration order and the visited
        body statements.

    Raises:
        GTScriptSyntaxError: if the call arguments are malformed (wrong count,
            unknown/duplicate keyword, a bound given twice, ``end`` without a
            ``start``) or if the range bounds are invalid.
    """
    loc = gt_ir.Location.from_ast_node(node)
    interval_error = GTScriptSyntaxError(
        f"Invalid 'interval' specification at line {loc.line} (column {loc.column})",
        loc=loc,
    )
    range_error = GTScriptSyntaxError(
        f"Invalid interval range specification at line {loc.line} (column {loc.column})",
        loc=loc,
    )

    interval_node = node.items[0].context_expr
    n_specs = len(interval_node.args) + len(interval_node.keywords)
    if not 1 <= n_specs <= 2 or any(
        keyword.arg not in ("start", "end") for keyword in interval_node.keywords
    ):
        raise interval_error

    keyword_values = {keyword.arg: keyword.value for keyword in interval_node.keywords}
    if len(keyword_values) != len(interval_node.keywords):
        # The same keyword was repeated (e.g. ``interval(start=0, start=1)``).
        raise interval_error

    # Bind bounds by name, not by keyword order: positional args fill the
    # leading slots, then ``start``/``end`` keywords must fill the next free slot.
    range_node = list(interval_node.args)
    for position, name in enumerate(("start", "end")):
        if name in keyword_values:
            if position != len(range_node):
                # Bound already given positionally, or ``end`` without ``start``.
                raise interval_error
            range_node.append(keyword_values[name])

    if len(range_node) == 1 and isinstance(range_node[0], ast.Ellipsis):
        interval = gt_ir.AxisInterval.full_interval()
    elif len(range_node) == 2 and all(
        isinstance(elem, (ast.Num, ast.UnaryOp, ast.NameConstant)) for elem in range_node
    ):
        range_value = tuple(self.visit(elem) for elem in range_node)
        try:
            interval = gt_ir.utils.make_axis_interval(range_value)
        except AssertionError as e:
            raise range_error from e
    else:
        raise range_error

    self.parsing_context = ParsingContext.INTERVAL
    try:
        stmts = []
        for stmt in node.body:
            stmts.extend(gt_utils.listify(self.visit(stmt)))
    finally:
        # Restore the enclosing context even if lowering the body fails.
        self.parsing_context = ParsingContext.COMPUTATION

    return gt_ir.ComputationBlock(
        interval=interval,
        iteration_order=gt_ir.IterationOrder.PARALLEL,
        body=gt_ir.BlockStmt(stmts=stmts),
    )
def visit_If(self, node: ast.If) -> gt_ir.If:
    """Lower a Python ``if`` statement, folding it when the test is static.

    The test is first evaluated with ``gt_utils.meta.ast_eval`` against
    ``self.externals``.

    * If that yields a value (anything other than ``NOTHING``), only the
      selected branch is visited. The result is a plain *list* of statements,
      not a ``gt_ir.If``, despite the annotation. It is empty when the test
      is falsy and there is no ``else``.
    * Otherwise a run-time ``gt_ir.If`` is built. Both branches are visited
      first, then the test; ``else_body`` is ``None`` when the ``else``
      branch produced no statements.

    Args:
        node: the ``ast.If`` node being visited.

    Returns:
        list | gt_ir.If: see above.
    """

    def _lower(py_stmts):
        # Each visited statement may yield one node or a list of nodes.
        lowered = []
        for py_stmt in py_stmts:
            lowered.extend(gt_utils.listify(self.visit(py_stmt)))
        return lowered

    static_value = gt_utils.meta.ast_eval(node.test, self.externals, default=NOTHING)
    if static_value is not NOTHING:
        # Compile-time evaluation: inline only the taken branch.
        return _lower(node.body if static_value else node.orelse)

    # Run-time evaluation.
    then_stmts = _lower(node.body)
    assert all(isinstance(item, gt_ir.Statement) for item in then_stmts)
    otherwise_stmts = _lower(node.orelse)
    assert all(isinstance(item, gt_ir.Statement) for item in otherwise_stmts)

    condition = gt_ir.utils.make_expr(self.visit(node.test))
    main_body = gt_ir.BlockStmt(stmts=then_stmts)
    if otherwise_stmts:
        else_body = gt_ir.BlockStmt(stmts=otherwise_stmts)
    else:
        else_body = None
    return gt_ir.If(condition=condition, main_body=main_body, else_body=else_body)
def _make_stage(self, ij_block):
    """Build a ``gt_ir.Stage`` from an IJ block.

    For each interval block, declarations are separated from executable
    statements:

    * A declaration whose name is in ``self.data.symbols`` is not made local
      (NOTE(review): it is not attached to the stage here -- presumably
      handled elsewhere; confirm).
    * Any other declaration must be a ``gt_ir.VarDecl`` and becomes an
      apply-block local symbol.
    * All other statements form the apply block's body.

    Accessors:

    * One accessor per input, in ``ij_block.inputs`` order. It is read-write
      when the name is also an output; its extent is then widened to include
      the zero extent.
    * Then one read-write, zero-extent accessor per output that is not an
      input, in ``ij_block.outputs`` order.

    Args:
        ij_block: object exposing ``id``, ``interval_blocks``, ``inputs``
            (mapping name -> extent), ``outputs`` and ``compute_extent``.

    Returns:
        gt_ir.Stage: named ``"stage__<id>"``.
    """
    apply_blocks = []
    for int_block in ij_block.interval_blocks:
        stmts = []
        local_symbols = {}
        for stmt_info in int_block.stmts:
            stmt = stmt_info.stmt
            if not isinstance(stmt, gt_ir.Decl):
                stmts.append(stmt)
            elif stmt.name not in self.data.symbols:
                assert isinstance(stmt, gt_ir.VarDecl)
                local_symbols[stmt.name] = stmt
        apply_blocks.append(
            gt_ir.ApplyBlock(
                interval=self._make_axis_interval(int_block.interval),
                local_symbols=local_symbols,
                body=gt_ir.BlockStmt(stmts=stmts),
            )
        )

    # One zero extent with the stencil's dimensionality, shared by read-write
    # inputs and write-only outputs (previously inputs used the default ndims).
    zero_extent = Extent.zeros(self.data.ndims)
    # Ordered, de-duplicated outputs: iterating a set here made the accessor
    # order depend on string hash randomization.
    output_names = dict.fromkeys(ij_block.outputs)

    accessors = []
    for name, extent in ij_block.inputs.items():
        read_write = name in output_names
        if read_write:
            # Non-mutating union so the caller's extent object is left untouched.
            extent = extent | zero_extent
        accessors.append(self._make_accessor(name, extent, read_write))
    for name in output_names:
        if name not in ij_block.inputs:
            accessors.append(self._make_accessor(name, zero_extent, True))

    return gt_ir.Stage(
        name="stage__{}".format(ij_block.id),
        accessors=accessors,
        apply_blocks=apply_blocks,
        compute_extent=ij_block.compute_extent,
    )