# Excerpt from the Ethos-U cascader's TE scheduler. Part, Proposal, EthosuPart,
# tile_nd, update_readers and get_copy_cycles_hint come from the surrounding
# cascader package; only the generic imports are shown here.
from collections import defaultdict
from typing import Tuple

from tvm import te, tir


def cascade_part(
    part: Part, stripe_stage: te.Stage, stripe_axis: tir.IterVar, sch: te.Schedule
) -> None:
    """Schedule a Part into a cascade indicated by a stripe Stage."""
    te_subgraph = part.subgraph
    g = sch.create_group(
        outputs=te_subgraph.output_tensor,
        inputs=te_subgraph.input_tensors,
        include_inputs=False,
    )
    g.compute_at(stripe_stage, stripe_axis)
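# A minimal stand-alone sketch (not cascader code) of the grouping pattern
# cascade_part relies on: the producer's stages are grouped and computed at a
# stripe axis of the consumer, so each stripe computes only the producer
# values it needs. All tensor names and shapes below are illustrative.
import tvm
from tvm import te


def _demo_cascade():
    a = te.placeholder((64,), name="a")
    b = te.compute((64,), lambda i: a[i] + 1, name="b")  # intermediate Part
    c = te.compute((64,), lambda i: b[i] * 2, name="c")  # output Part
    sch = te.create_schedule(c.op)
    # Stripe the output, then cascade the intermediate under the stripe loop
    outer, _ = sch[c].split(c.op.axis[0], factor=16)
    g = sch.create_group(outputs=[b], inputs=[a], include_inputs=False)
    g.compute_at(sch[c], outer)
    print(tvm.lower(sch, [a, c], simple_mode=True))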
def stripe_part(
    part: Part, stripe_shape: Tuple[int, ...], sch: te.Schedule
) -> Tuple[te.Stage, tir.IterVar]:
    """Apply a striping schedule to the TE subgraph represented by a Part."""
    te_subgraph = part.subgraph
    te_output_tensor = te_subgraph.output_tensor
    outer_indices, _ = tile_nd(sch, te_output_tensor, stripe_shape)
    g = sch.create_group(
        outputs=te_output_tensor.op.input_tensors,
        inputs=te_subgraph.input_tensors,
        include_inputs=False,
    )
    g.compute_at(sch[te_output_tensor], outer_indices[-1])
    for axis in outer_indices:
        sch[te_output_tensor].unroll(axis)
    return sch[te_output_tensor], outer_indices[-1]
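# tile_nd is defined alongside these helpers; a minimal sketch of the same
# N-D tiling pattern using only core TE APIs (shapes illustrative) is:
import tvm
from tvm import te


def _demo_tile_nd():
    a = te.placeholder((64, 64), name="a")
    b = te.compute((64, 64), lambda i, j: a[i, j] * 2, name="b")
    sch = te.create_schedule(b.op)
    stripe_shape = [16, 32]
    outer_indices, inner_indices = [], []
    for axis, size in zip(b.op.axis, stripe_shape):
        outer, inner = sch[b].split(axis, factor=size)
        outer_indices.append(outer)
        inner_indices.append(inner)
    # Hoist every outer (stripe) loop above every inner loop
    sch[b].reorder(*outer_indices, *inner_indices)
    print(tvm.lower(sch, [a, b], simple_mode=True))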
def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
    """Apply a Proposal to a Schedule, converting all the Plans into TE scheduling instructions.

    Note that the Schedule is mutated in-place.

    Parameters
    ----------
    proposal : Proposal
        The Proposal to apply to the Schedule.
    sch : te.Schedule
        The Schedule to apply the Proposal to.

    """
    for plan in proposal.plans:
        for part in plan.part_group:
            if isinstance(part, EthosuPart):
                tensor_config = plan.tensor_configs[part.output_tensor]
                stripe_config = tensor_config.stripe_configs[0]
                buffer_mode = tensor_config.buffer_mode
                block_config = part.get_block_config(stripe_config)
                compute_cycles = part.get_performance_info(
                    stripe_config, buffer_mode
                ).compute_cycles
                iv = part.subgraph.output_tensor.op.axis[0]
                block_shape = block_config.output_shape
                if len(block_shape) == 4:
                    height, width, depth = block_shape[1:]
                else:
                    height = block_shape[1]
                    width = block_shape[3]
                    depth = block_shape[2] * block_shape[4]
                sch[part.subgraph.output_tensor].pragma(iv, "block_config_height", height)
                sch[part.subgraph.output_tensor].pragma(iv, "block_config_width", width)
                sch[part.subgraph.output_tensor].pragma(iv, "block_config_depth", depth)
                # Attach the AttrStmt directly to the NPU op so it isn't removed
                # by ReplaceOperators
                npu_op = part.subgraph.output_tensor.op.input_tensors[0].op.input_tensors[0]
                sch[npu_op].pragma(npu_op.op.axis[0], "compute_cycles_hint", compute_cycles)

        output_tensor_config = plan.output_config
        output_tensor = output_tensor_config.tensor
        output_part = output_tensor.producers[0]
        if output_part.in_line:
            continue
        stripe_config = output_tensor_config.stripe_configs[0]
        stripe_shape = [int(x) for x in stripe_config.shape]
        stripe_stage, stripe_axis = stripe_part(output_part, stripe_shape, sch)
        copy_te_tensors = []
        compute_cycles_hints = []
        readers = defaultdict(list)
        for part in plan.part_group:
            if part != output_part:
                cascade_part(part, stripe_stage, stripe_axis, sch)

            update_readers(part, readers)
            for i, input_tensor in enumerate(part.input_tensors):
                tensor_config = plan.tensor_configs[input_tensor]
                if tensor_config.home_region != tensor_config.copy_region:
                    copy_te_tensors.append(part.subgraph.input_tensors[i])
                    compute_cycles_hint, _ = get_copy_cycles_hint(tensor_config)
                    compute_cycles_hints.append(compute_cycles_hint)

        for te_tensor, compute_cycles_hint in zip(copy_te_tensors, compute_cycles_hints):
            copy_stage = sch.cache_read(te_tensor, "global", readers[te_tensor])
            sch[copy_stage].pragma(
                copy_stage.op.axis[0], "compute_cycles_hint", compute_cycles_hint
            )
            sch[copy_stage].compute_at(stripe_stage, stripe_axis)
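# A hedged sketch of the pragma mechanism apply_proposal uses: Stage.pragma
# attaches an AttrStmt to a loop, which downstream passes can read back.
# The keys mirror the ones above; the values here are illustrative only.
import tvm
from tvm import te


def _demo_pragma():
    a = te.placeholder((16, 16), name="a")
    b = te.compute((16, 16), lambda i, j: a[i, j] + 1, name="b")
    sch = te.create_schedule(b.op)
    iv = b.op.axis[0]
    sch[b].pragma(iv, "block_config_height", 4)  # lowers to attr "pragma_block_config_height"
    sch[b].pragma(iv, "compute_cycles_hint", 1024)
    print(tvm.lower(sch, [a, b], simple_mode=True))  # AttrStmts visible in the TIR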
# GPU schedule template excerpt. The helpers referenced through self
# (_split, _split_axes, _fuse_axes, _bind_axes, _get_iv_extent, _config,
# _thread_ivs) as well as MAX_THREADS_PER_BLOCK and _interleave_shift are
# defined elsewhere in the same class/module.
from functools import reduce
from typing import List

from tvm import te
from tvm.te import Stage


def __call__(self, s: te.Schedule, op):
    """Build a GPU schedule for `op`: cache the output in registers, cache the
    inputs, tile/fuse/bind the spatial axes, then unroll and vectorize."""
    write_cache = s.cache_write(op.output(0), "local")
    read_caches = [s.cache_read(t, "local", [write_cache]) for t in op.input_tensors]
    op_stg: Stage = s[op]
    wc_stg: Stage = s[write_cache]
    rc_stgs: List[Stage] = [s[rc] for rc in read_caches]
    sp_ivs = list(op_stg.op.axis)
    assert len(sp_ivs) > 0, "empty spatial axes"

    def set_pragma():
        # Split off a length-1 outer scope to carry the unrolling pragmas.
        n = sp_ivs[0]
        outer_scope, n = self._split(op_stg, n, nparts=1)
        sp_ivs[0] = n
        unroll = self._config["unroll"]
        if unroll is not None:
            step, explicit = unroll
            op_stg.pragma(outer_scope, "auto_unroll_max_step", step)
            op_stg.pragma(outer_scope, "unroll_explicit", explicit)

    set_pragma()

    def tile_and_fuse():
        # Tile each spatial axis, regroup the parts into levels (all outermost
        # parts first), then optionally fuse axes within each level.
        sp_parts = self._split_axes(op_stg, sp_ivs, self._config["spatial"])
        sp_levels = list(zip(*sp_parts))
        op_stg.reorder(*(iv for lv in sp_levels for iv in lv))
        sp_levels = self._fuse_axes(op_stg, sp_levels, self._config["fuse"])
        return sp_levels

    sp_levels = tile_and_fuse()

    def bind_and_check():
        # Bind the outer levels to blocks, virtual threads and threads, and
        # reject configurations that exceed the per-block thread limit.
        blocks = [te.thread_axis(f"blockIdx.{x}") for x in "xyz"]
        threads = [te.thread_axis(f"threadIdx.{x}") for x in "xyz"]
        vthreads = [te.thread_axis("vthread") for _ in "xyz"]
        local_write_pos = self._bind_axes(op_stg, sp_levels, [blocks, vthreads, threads])
        n_threads_per_block = reduce(
            lambda a, b: a * b,
            (
                self._get_iv_extent(iv)
                for (t, iv) in self._thread_ivs.items()
                if t.startswith("threadIdx")
            ),
            1,
        )
        if n_threads_per_block > MAX_THREADS_PER_BLOCK:
            raise RuntimeError(
                "Work group exceeds the size limit: {} (required) vs. {} (allowed)".format(
                    n_threads_per_block, MAX_THREADS_PER_BLOCK
                )
            )
        return local_write_pos

    local_write_pos = bind_and_check()

    def unroll_and_vectorize():
        # Unroll the unbound innermost-level axes and vectorize the last one
        # with the largest power-of-two factor that divides its extent.
        bound_axes = set(self._thread_ivs.values())
        for iv in sp_levels[-1][:-1]:
            if iv not in bound_axes:
                op_stg.unroll(iv)
        last_iv = sp_levels[-1][-1]
        if last_iv not in bound_axes:
            last_ext = self._get_iv_extent(last_iv)

            def vec(x):
                outer, inner = self._split(op_stg, last_iv, factor=x)
                op_stg.unroll(outer)
                op_stg.vectorize(inner)

            for factor in (16, 8, 4, 2):
                if last_ext % factor == 0:
                    vec(factor)
                    break

    unroll_and_vectorize()

    def handle_write_cache():
        # Compute the register cache at the innermost bound position, tile the
        # reduction axes and interleave them with the cached spatial axes.
        local_read_pos = None
        wc_stg.compute_at(op_stg, local_write_pos)
        wc_sp_ivs = list(wc_stg.op.axis)
        re_ivs = wc_stg.op.reduce_axis
        if len(re_ivs) > 0:
            re_parts = self._split_axes(wc_stg, re_ivs, self._config["reduce"])
            re_levels = list(zip(*re_parts))
            last_lv = list(re_levels[-1])
            reorder_lst = [iv for lv in re_levels[:-1] for iv in lv]
            pos = self._config["reorder"]
            if pos is None:
                reorder_lst.extend(last_lv + wc_sp_ivs)
            else:
                reorder_lst.extend(_interleave_shift(last_lv, wc_sp_ivs, pos))
            wc_stg.reorder(*reorder_lst)
            for iv in wc_sp_ivs:
                wc_stg.unroll(iv)
            local_read_pos = last_lv[-1]
        return local_read_pos

    local_read_pos = handle_write_cache()

    def handle_read_caches():
        # Inline the input caches when there is no reduction to attach to;
        # otherwise compute them at the reduction position, then unroll and
        # vectorize their innermost axis as above.
        for rc_stg in rc_stgs:
            if local_read_pos is None:
                rc_stg.compute_inline()
            else:
                rc_stg.compute_at(wc_stg, local_read_pos)
                rc_sp_ivs = rc_stg.op.axis
                for iv in rc_sp_ivs[:-1]:
                    rc_stg.unroll(iv)
                last_iv = rc_sp_ivs[-1]
                last_ext = self._get_iv_extent(last_iv)

                def vec(x):
                    outer, inner = self._split(rc_stg, last_iv, factor=x)
                    rc_stg.unroll(outer)
                    rc_stg.vectorize(inner)

                for factor in (16, 8, 4, 2):
                    if last_ext % factor == 0:
                        vec(factor)
                        break

    handle_read_caches()
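# A self-contained sketch of the pattern __call__ builds, applied by hand to a
# toy matmul: cache_write to "local", cache_read of both inputs, block/thread
# binding and reduction tiling. All split factors are illustrative.
import tvm
from tvm import te


def _demo_gpu_template():
    m = n = k = 64
    a = te.placeholder((m, k), name="a")
    b = te.placeholder((k, n), name="b")
    kk = te.reduce_axis((0, k), name="kk")
    c = te.compute((m, n), lambda i, j: te.sum(a[i, kk] * b[kk, j], axis=kk), name="c")
    sch = te.create_schedule(c.op)

    c_local = sch.cache_write(c, "local")
    a_local = sch.cache_read(a, "local", [c_local])
    b_local = sch.cache_read(b, "local", [c_local])

    i, j = sch[c].op.axis
    bi, ti = sch[c].split(i, factor=8)
    bj, tj = sch[c].split(j, factor=8)
    sch[c].reorder(bi, bj, ti, tj)
    sch[c].bind(bi, te.thread_axis("blockIdx.y"))
    sch[c].bind(bj, te.thread_axis("blockIdx.x"))
    sch[c].bind(ti, te.thread_axis("threadIdx.y"))
    sch[c].bind(tj, te.thread_axis("threadIdx.x"))

    # Accumulate in registers; tile the reduction and stage the input caches
    sch[c_local].compute_at(sch[c], tj)
    ko, ki = sch[c_local].split(sch[c_local].op.reduce_axis[0], factor=4)
    sch[a_local].compute_at(sch[c_local], ki)
    sch[b_local].compute_at(sch[c_local], ki)
    print(tvm.lower(sch, [a, b, c], simple_mode=True))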