Example #1
def cascade_part(part: Part, stripe_stage: te.Stage, stripe_axis: tir.IterVar,
                 sch: te.Schedule) -> None:
    """Schedule a Part into a cascade indicated by a stripe Stage."""
    te_subgraph = part.subgraph
    g = sch.create_group(outputs=te_subgraph.output_tensor,
                         inputs=te_subgraph.input_tensors,
                         include_inputs=False)
    g.compute_at(stripe_stage, stripe_axis)
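
cascade_part is meant to be driven by stripe_part (Example #3): the output Part of a Plan is striped first, and every other Part in the Plan is then cascaded onto the resulting stage and axis. A minimal sketch of that pairing; plan, output_part and sch are illustrative names taken from apply_proposal (Example #2), not a fixed API:

# Hedged sketch: `plan`, `output_part` and `sch` are assumed to be supplied
# by the cascader, as in apply_proposal below.
stripe_shape = [int(x) for x in plan.output_config.stripe_configs[0].shape]
stripe_stage, stripe_axis = stripe_part(output_part, stripe_shape, sch)
for part in plan.part_group:
    if part != output_part:
        cascade_part(part, stripe_stage, stripe_axis, sch)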
Example #2
def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
    """Apply a Proposal to a Schedule, converting all the Plans into TE scheduling instructions.

    Note that the Schedule is mutated in-place.

    Parameters
    ----------
    proposal : Proposal
        The Proposal to apply to the Schedule.
    sch : te.Schedule
        The Schedule to apply the Proposal to.

    """
    for plan in proposal.plans:
        for part in plan.part_group:
            if isinstance(part, EthosuPart):
                tensor_config = plan.tensor_configs[part.output_tensor]
                stripe_config = tensor_config.stripe_configs[0]
                block_config = part.get_block_config(stripe_config)
                iv = part.subgraph.output_tensor.op.axis[0]
                block_shape = block_config.output_shape
                if len(block_shape) == 4:
                    # NHWC block: [batch, height, width, channels]
                    height, width, depth = block_shape[1:]
                else:
                    # NHCWB16 block: [batch, height, channels // 16, width, 16]
                    height = block_shape[1]
                    width = block_shape[3]
                    depth = block_shape[2] * block_shape[4]
                sch[part.subgraph.output_tensor].pragma(
                    iv, "block_config_height", height)
                sch[part.subgraph.output_tensor].pragma(
                    iv, "block_config_width", width)
                sch[part.subgraph.output_tensor].pragma(
                    iv, "block_config_depth", depth)

        output_tensor_config = plan.output_config
        output_tensor = output_tensor_config.tensor
        output_part = output_tensor.producers[0]
        if output_part.in_line:
            continue
        stripe_config = output_tensor_config.stripe_configs[0]
        stripe_shape = [int(x) for x in stripe_config.shape]
        stripe_stage, stripe_axis = stripe_part(output_part, stripe_shape, sch)
        copy_te_tensors = []
        readers = defaultdict(list)
        for part in plan.part_group:
            if part != output_part:
                cascade_part(part, stripe_stage, stripe_axis, sch)

            update_readers(part, readers)
            for i, input_tensor in enumerate(part.input_tensors):
                tensor_config = plan.tensor_configs[input_tensor]
                if tensor_config.home_region != tensor_config.copy_region:
                    copy_te_tensors.append(part.subgraph.input_tensors[i])

        for te_tensor in copy_te_tensors:
            copy_stage = sch.cache_read(te_tensor, "global",
                                        readers[te_tensor])
            sch[copy_stage].compute_at(stripe_stage, stripe_axis)
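
The cache_read/compute_at pair at the end is what turns a cross-region input into an explicit copy stage nested inside the stripe loop. A self-contained toy showing the same pattern on a plain TE schedule (the tensors here are illustrative, not from the cascader):

import tvm
from tvm import te

a = te.placeholder((64,), name="a")
b = te.compute((64,), lambda i: a[i] + 1, name="b")
s = te.create_schedule(b.op)
outer, inner = s[b].split(b.op.axis[0], factor=16)
# Give `a` its own copy stage in "global" scope, read by `b`, then nest the
# copy inside b's outer loop, mirroring the copy_stage handling above.
a_copy = s.cache_read(a, "global", [b])
s[a_copy].compute_at(s[b], outer)
print(tvm.lower(s, [a, b], simple_mode=True))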
Example #3
def stripe_part(part: Part, stripe_shape: Tuple[int, ...],
                sch: te.Schedule) -> Tuple[te.Stage, tir.IterVar]:
    """Apply a striping schedule to the TE subgraph represented by a Part."""
    te_subgraph = part.subgraph
    te_output_tensor = te_subgraph.output_tensor
    outer_indices, _ = tile_nd(sch, te_output_tensor, stripe_shape)
    g = sch.create_group(
        outputs=te_output_tensor.op.input_tensors,
        inputs=te_subgraph.input_tensors,
        include_inputs=False,
    )
    g.compute_at(sch[te_output_tensor], outer_indices[-1])
    for axis in outer_indices:
        sch[te_output_tensor].unroll(axis)

    return sch[te_output_tensor], outer_indices[-1]
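
tile_nd is not shown in these snippets; for the call above to work it must split every output axis by the stripe shape, reorder the outer loops in front of the inner ones, and return both index lists. A minimal sketch of such a helper (the cascader's actual implementation may differ):

def tile_nd(sch, tensor, tile):
    # Split each axis by the corresponding tile size, hoist all outer loops
    # before all inner loops, and return the (outer, inner) IterVar lists.
    outer_indices = []
    inner_indices = []
    for axis, size in zip(tensor.op.axis, tile):
        outer, inner = sch[tensor].split(axis, size)
        outer_indices.append(outer)
        inner_indices.append(inner)
    sch[tensor].reorder(*outer_indices, *inner_indices)
    return outer_indices, inner_indices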
Example #4
def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
    """Apply a Proposal to a Schedule, converting all the Plans into TE scheduling instructions.

    Note that the Schedule is mutated in-place.

    Parameters
    ----------
    proposal : Proposal
        The Proposal to apply to the Schedule.
    sch : te.Schedule
        The Schedule to apply the Proposal to.

    """
    for plan in proposal.plans:
        output_tensor_config = plan.output_config
        output_tensor = output_tensor_config.tensor
        output_part = output_tensor.producers[0]
        if output_part.in_line:
            continue
        stripe_config = output_tensor_config.stripe_configs[0]
        stripe_shape = [int(x) for x in stripe_config.shape]
        stripe_stage, stripe_axis = stripe_part(output_part, stripe_shape, sch)
        copy_te_tensors = []
        readers = defaultdict(list)
        for part in plan.part_group:
            if part != output_part:
                cascade_part(part, stripe_stage, stripe_axis, sch)

            update_readers(part, readers)
            for i, input_tensor in enumerate(part.input_tensors):
                tensor_config = plan.tensor_configs[input_tensor]
                if tensor_config.home_region != tensor_config.copy_region:
                    copy_te_tensors.append(part.subgraph.input_tensors[i])

        for te_tensor in copy_te_tensors:
            copy_stage = sch.cache_read(te_tensor, "global", readers[te_tensor])
            sch[copy_stage].compute_at(stripe_stage, stripe_axis)
Example #5
def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
    """Apply a Proposal to a Schedule, converting all the Plans into TE scheduling instructions.

    Note that the Schedule is mutated in-place.

    Parameters
    ----------
    proposal : Proposal
        The Proposal to apply to the Schedule.
    sch : te.Schedule
        The Schedule to apply the Proposal to.

    """
    for plan in proposal.plans:
        for part in plan.part_group:
            if isinstance(part, EthosuPart):
                tensor_config = plan.tensor_configs[part.output_tensor]
                stripe_config = tensor_config.stripe_configs[0]
                buffer_mode = tensor_config.buffer_mode
                block_config = part.get_block_config(stripe_config)
                compute_cycles = part.get_performance_info(
                    stripe_config, buffer_mode
                ).compute_cycles
                iv = part.subgraph.output_tensor.op.axis[0]
                block_shape = block_config.output_shape
                if len(block_shape) == 4:
                    height, width, depth = block_shape[1:]
                else:
                    height = block_shape[1]
                    width = block_shape[3]
                    depth = block_shape[2] * block_shape[4]
                sch[part.subgraph.output_tensor].pragma(iv, "block_config_height", height)
                sch[part.subgraph.output_tensor].pragma(iv, "block_config_width", width)
                sch[part.subgraph.output_tensor].pragma(iv, "block_config_depth", depth)

                # Attach AttrStmt directly to npu op so it isn't removed by ReplaceOperators
                npu_op = part.subgraph.output_tensor.op.input_tensors[0].op.input_tensors[0]
                sch[npu_op].pragma(npu_op.op.axis[0], "compute_cycles_hint", compute_cycles)

        output_tensor_config = plan.output_config
        output_tensor = output_tensor_config.tensor
        output_part = output_tensor.producers[0]
        if output_part.in_line:
            continue
        stripe_config = output_tensor_config.stripe_configs[0]
        stripe_shape = [int(x) for x in stripe_config.shape]
        stripe_stage, stripe_axis = stripe_part(output_part, stripe_shape, sch)
        copy_te_tensors = []
        compute_cycles_hints = []
        readers = defaultdict(list)
        for part in plan.part_group:
            if part != output_part:
                cascade_part(part, stripe_stage, stripe_axis, sch)

            update_readers(part, readers)
            for i, input_tensor in enumerate(part.input_tensors):
                tensor_config = plan.tensor_configs[input_tensor]
                if tensor_config.home_region != tensor_config.copy_region:
                    copy_te_tensors.append(part.subgraph.input_tensors[i])

                    compute_cycles_hint, _ = get_copy_cycles_hint(tensor_config)
                    compute_cycles_hints.append(compute_cycles_hint)

        for te_tensor, compute_cycles_hint in zip(copy_te_tensors, compute_cycles_hints):
            copy_stage = sch.cache_read(te_tensor, "global", readers[te_tensor])
            sch[copy_stage].pragma(
                copy_stage.op.axis[0], "compute_cycles_hint", compute_cycles_hint
            )
            sch[copy_stage].compute_at(stripe_stage, stripe_axis)
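
A hedged sketch of how apply_proposal is driven end to end; te_graph and proposal are assumed to come from the cascader's graph extraction and proposal generation, and the names are illustrative:

# Hypothetical driver: `te_graph` and `proposal` are assumed to already exist.
sch = te.create_schedule([t.op for t in te_graph.outputs])
apply_proposal(proposal, sch)  # mutates `sch` in place, per the docstring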
Example #6
    def __call__(self, s: te.Schedule, op):
        write_cache = s.cache_write(op.output(0), "local")
        read_caches = [s.cache_read(t, "local", [write_cache])
                       for t in op.input_tensors]

        op_stg: Stage = s[op]
        wc_stg: Stage = s[write_cache]
        rc_stgs: List[Stage] = [s[rc] for rc in read_caches]

        sp_ivs = list(op_stg.op.axis)
        assert len(sp_ivs) > 0, "empty spatial axes"

        def set_pragma():
            n = sp_ivs[0]
            outer_scope, n = self._split(op_stg, n, nparts=1)
            sp_ivs[0] = n

            unroll = self._config["unroll"]
            if unroll is not None:
                step, explicit = unroll
                op_stg.pragma(outer_scope, "auto_unroll_max_step", step)
                op_stg.pragma(outer_scope, "unroll_explicit", explicit)
        set_pragma()

        def tile_and_fuse():
            sp_parts = self._split_axes(
                op_stg, sp_ivs, self._config["spatial"])
            sp_levels = list(zip(*sp_parts))
            op_stg.reorder(*(iv for lv in sp_levels for iv in lv))

            sp_levels = self._fuse_axes(
                op_stg, sp_levels, self._config["fuse"])
            return sp_levels
        sp_levels = tile_and_fuse()

        def bind_and_check():
            blocks = [te.thread_axis(f"blockIdx.{x}") for x in "xyz"]
            threads = [te.thread_axis(f"threadIdx.{x}") for x in "xyz"]
            vthreads = [te.thread_axis("vthread") for _ in "xyz"]
            local_write_pos = self._bind_axes(
                op_stg, sp_levels, [blocks, vthreads, threads])
            n_threads_per_block = reduce(
                lambda a, b: a * b,
                (self._get_iv_extent(iv)
                 for (t, iv) in self._thread_ivs.items()
                 if t.startswith("threadIdx")),
                1)
            if n_threads_per_block > MAX_THREADS_PER_BLOCK:
                raise RuntimeError(
                    "Work group size exceeds limit: {} (required) vs. {} (allowed)".format(
                        n_threads_per_block, MAX_THREADS_PER_BLOCK))
            return local_write_pos
        local_write_pos = bind_and_check()

        def unroll_and_vectorize():
            bound_axes = set(self._thread_ivs.values())
            for iv in sp_levels[-1][:-1]:
                if iv not in bound_axes:
                    op_stg.unroll(iv)
            last_iv = sp_levels[-1][-1]
            if last_iv not in bound_axes:
                last_ext = self._get_iv_extent(last_iv)
                def vec(x):
                    outer, inner = self._split(op_stg, last_iv, factor=x)
                    op_stg.unroll(outer)
                    op_stg.vectorize(inner)
                # Vectorize with the widest power-of-two lane count that
                # divides the innermost extent.
                for width in (16, 8, 4, 2):
                    if last_ext % width == 0:
                        vec(width)
                        break
        unroll_and_vectorize()

        def handle_write_cache():
            local_read_pos = None
            # compute at
            wc_stg.compute_at(op_stg, local_write_pos)
            # split reduce axis
            wc_sp_ivs = wc_stg.op.axis
            re_ivs = wc_stg.op.reduce_axis
            if len(re_ivs) > 0:
                re_parts = self._split_axes(
                    wc_stg, re_ivs, self._config["reduce"])
                re_levels = list(zip(*re_parts))
                last_lv = re_levels[-1]
                # interleave reorder
                reorder_lst = [iv for lv in re_levels[:-1] for iv in lv]
                pos = self._config["reorder"]
                if pos is None:
                    reorder_lst.extend(last_lv + wc_sp_ivs)
                else:
                    reorder_lst.extend(_interleave_shift(last_lv, wc_sp_ivs, pos))
                wc_stg.reorder(*reorder_lst)
                # unroll
                for iv in wc_sp_ivs:
                    wc_stg.unroll(iv)

                local_read_pos = last_lv[-1]

            return local_read_pos
        local_read_pos = handle_write_cache()

        def handle_read_caches():
            for rc_stg in rc_stgs:
                if local_read_pos is None:
                    rc_stg.compute_inline()
                else:
                    # compute at
                    rc_stg.compute_at(wc_stg, local_read_pos)
                    # unroll and vectorize
                    rc_sp_ivs = rc_stg.op.axis
                    for iv in rc_sp_ivs[:-1]:
                        rc_stg.unroll(iv)
                    last_iv = rc_sp_ivs[-1]
                    last_ext = self._get_iv_extent(last_iv)
                    def vec(x):
                        outer, inner = self._split(rc_stg, last_iv, factor=x)
                        rc_stg.unroll(outer)
                        rc_stg.vectorize(inner)
                    for width in (16, 8, 4, 2):
                        if last_ext % width == 0:
                            vec(width)
                            break
        handle_read_caches()
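
The helpers _split, _get_iv_extent, _split_axes, _fuse_axes and _bind_axes belong to the surrounding class and are not shown. One constraint is visible from the modulo checks above: _get_iv_extent must return a concrete Python int. A minimal sketch consistent with that, assuming statically known loop bounds:

def _get_iv_extent(self, iv: tir.IterVar) -> int:
    # Hedged sketch: the `last_ext % 16` checks above require a plain int,
    # which assumes the IterVar's domain extent is a static IntImm.
    return int(iv.dom.extent)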