def _make_apply_block(self, interval_block):
    """Create a ``gt_ir.ApplyBlock`` from the statements of *interval_block*.

    Every statement whose ``stmt`` is a ``gt_ir.Decl`` is discarded; the
    remaining statements are kept in their original order and wrapped in a
    single ``gt_ir.BlockStmt`` that becomes the apply block's body.

    Args:
        interval_block: object exposing ``stmts`` (items with a ``stmt``
            attribute) and ``interval`` (converted via
            ``self._make_axis_interval``).

    Returns:
        The newly built ``gt_ir.ApplyBlock``.
    """
    kept = [
        info.stmt
        for info in interval_block.stmts
        if not isinstance(info.stmt, gt_ir.Decl)
    ]
    body = gt_ir.BlockStmt(stmts=kept)
    axis_interval = self._make_axis_interval(interval_block.interval)
    return gt_ir.ApplyBlock(interval=axis_interval, body=body)
def _make_stage(self, ij_block):
    """Build a ``gt_ir.Stage`` from *ij_block*.

    One ``gt_ir.ApplyBlock`` is created per entry of
    ``ij_block.interval_blocks``. Within each interval block:

    * non-``Decl`` statements form the apply block's body, in order;
    * ``Decl`` statements whose name is NOT in ``self.data.symbols`` must be
      ``gt_ir.VarDecl`` and are stored in the apply block's ``local_symbols``;
    * ``Decl`` statements whose name IS in ``self.data.symbols`` are skipped.

    Accessors are created for every input (with its extent) and every output.
    A name that is both an input and an output gets one read-write accessor
    whose extent is widened with the zero extent; outputs that are not inputs
    get a zero-extent read-write accessor, emitted in sorted name order so
    the result does not depend on set iteration order.

    Args:
        ij_block: block exposing ``interval_blocks``, ``inputs`` (name ->
            extent mapping), ``outputs``, ``id`` and ``compute_extent``.

    Returns:
        The new ``gt_ir.Stage``, named ``"stage__<ij_block.id>"``.
    """
    # Single zero extent with the stencil's dimensionality, used both to widen
    # read-write input extents and for pure outputs (previously the former used
    # ``Extent.zeros()`` with its default dimensionality).
    zero_extent = Extent.zeros(self.data.ndims)

    # Apply blocks
    apply_blocks = []
    for int_block in ij_block.interval_blocks:
        stmts = []
        local_symbols = {}
        for stmt_info in int_block.stmts:
            stmt = stmt_info.stmt
            if not isinstance(stmt, gt_ir.Decl):
                stmts.append(stmt)
            elif stmt.name not in self.data.symbols:
                # Declarations of names unknown to self.data.symbols are
                # recorded as symbols local to this apply block.
                assert isinstance(stmt, gt_ir.VarDecl)
                local_symbols[stmt.name] = stmt
            # NOTE(review): declarations of names already present in
            # self.data.symbols are skipped; the original collected them into
            # a list that was never used.

        apply_blocks.append(
            gt_ir.ApplyBlock(
                interval=self._make_axis_interval(int_block.interval),
                local_symbols=local_symbols,
                body=gt_ir.BlockStmt(stmts=stmts),
            )
        )

    # Accessors
    accessors = []
    remaining_outputs = set(ij_block.outputs)
    for name, extent in ij_block.inputs.items():
        read_write = name in remaining_outputs
        if read_write:
            remaining_outputs.remove(name)
            # Rebind rather than augment so the extent stored in
            # ij_block.inputs is never touched.
            extent = extent | zero_extent
        accessors.append(self._make_accessor(name, extent, read_write))

    # Iterating a set of strings gives an order that changes between runs
    # (hash randomization); sort to keep the generated stage deterministic.
    for name in sorted(remaining_outputs):
        accessors.append(self._make_accessor(name, zero_extent, True))

    return gt_ir.Stage(
        name="stage__{}".format(ij_block.id),
        accessors=accessors,
        apply_blocks=apply_blocks,
        compute_extent=ij_block.compute_extent,
    )