Esempio n. 1
0
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
    text_format: Union[str, Sequence[str]] = ("python", "json"),
) -> Schedule:
    """Serialize a traced schedule to text, then replay the text trace by applying it to
    a fresh new schedule, verifying the reproducibility of scheduling.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    text_format: Union[str, Sequence[str]]
        The text format or formats whose round-trip behavior should be
        validated.  If a single string, validate round-trips through that
        format only; supported values are "json" and "python".  If a sequence
        of strings, validate round-trips through each listed format in turn.

    Returns
    -------
    new_sch : tir.Schedule
        The schedule rebuilt from the serialized trace.  When several formats
        are validated, the schedule produced by the last one is returned.

    Raises
    ------
    ValueError
        If `text_format` is an empty sequence.
    AssertionError
        If `text_format` names an unsupported format, or if the replayed
        schedule does not reproduce the original module or trace.
    """
    if not isinstance(text_format, str):
        formats = tuple(text_format)
        if not formats:
            # Without at least one format there is nothing to verify and no schedule to return.
            raise ValueError("text_format must name at least one text format")
        new_sch = None
        for opt in formats:
            new_sch = verify_trace_roundtrip(
                sch,
                mod,
                debug_mask=debug_mask,
                text_format=opt,
            )
        return new_sch

    if text_format not in ("json", "python"):
        # Raised explicitly instead of via `assert` so the check survives `python -O`;
        # the exception type matches the original assertion.
        raise AssertionError(f"Unknown text format: {text_format}")

    trace = sch.trace
    assert trace is not None

    # Step 1. Perform a round-trip through the text-format
    new_sch = Schedule(mod=mod, debug_mask=debug_mask)
    if text_format == "json":
        json_obj = trace.as_json()
        Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
    else:  # text_format == "python"
        py_trace = "\n".join(trace.as_python())
        # The generated script refers to `sch` and resolves other names from tvm.tir.
        exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used

    # Step 2. Verify that the round-trip produced the same scheduling
    assert structural_equal(new_sch.mod, sch.mod)

    # Step 3. Check the consistency of the text format between the old and new traces
    py_repr = "\n".join(trace.as_python())
    new_py_repr = "\n".join(new_sch.trace.as_python())
    assert py_repr == new_py_repr

    # Step 4. Return the new schedule in case it could be useful
    return new_sch
Esempio n. 2
0
def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool:
    """Report whether two schedules recorded textually identical traces.

    When ``remove_decisions`` is truthy, each trace is rebuilt from its
    instructions alone with an empty decision map before printing, so
    sampling decisions do not take part in the comparison.  Otherwise the
    recorded traces are printed and compared as they are.
    """
    traces = [sch_1.trace, sch_2.trace]
    if remove_decisions:
        traces = [Trace(recorded.insts, {}) for recorded in traces]
    left_text, right_text = (str(item) for item in traces)
    return left_text == right_text
Esempio n. 3
0
def check_trace(spaces: List[Schedule], expected: List[List[str]]):
    """Assert that the simplified traces of ``spaces`` are exactly ``expected``.

    Each schedule's trace is stripped of decisions, simplified with
    ``remove_postproc=True`` and normalized line by line.  Every resulting
    trace must be one of the expected traces (each given as a list of lines),
    and the number of distinct traces produced must equal the number of
    distinct expected traces.
    """
    wanted = set("\n".join(lines) for lines in expected)
    seen = set()
    for space in spaces:
        simplified = Trace(space.trace.insts, {}).simplified(remove_postproc=True)
        # Re-join after splitting so the comparison ignores line-ending style.
        text = "\n".join(str(simplified).strip().splitlines())
        seen.add(text)
        assert text in wanted, "\n" + text
    assert len(wanted) == len(seen)
Esempio n. 4
0
def _find_match_sketch_id(
    mod: IRModule,
    sketches: List[Schedule],
    expected_mod: IRModule,
    expected_decision: List[Tuple[str, List[int]]],
    *,
    debug_mask="all",
) -> Optional[int]:
    """Return the index of the first sketch that reproduces ``expected_mod``.

    For each sketch, its ``Sample*`` instructions are walked in order and
    matched, by instruction-kind name, against ``expected_decision``.  A
    mismatching instruction is skipped; a matching one is assigned the
    expected decision.  Sketches that do not consume every expected decision
    are skipped.  The others are replayed on a fresh schedule of ``mod``
    (postprocessing removed).  The first one whose module is structurally
    equal to ``expected_mod`` is round-trip verified, and its index is
    returned.  Returns None if no sketch matches.

    Raises AssertionError if a sketch has a further ``Sample*`` instruction
    after all expected decisions have been consumed.
    """
    num_expected = len(expected_decision)
    for candidate_id, candidate in enumerate(sketches):
        cursor = 0
        matched = {}
        sample_insts = (
            inst for inst in candidate.trace.insts if inst.kind.name.startswith("Sample")
        )
        for inst in sample_insts:
            assert cursor < num_expected
            wanted = expected_decision[cursor]
            if inst.kind.name != wanted[0]:
                continue
            matched[inst] = wanted[1]
            cursor += 1
        if len(matched) != num_expected:
            continue
        replay = Schedule(mod, debug_mask=debug_mask)
        Trace(
            insts=candidate.trace.insts,
            decisions=matched,
        ).apply_to_schedule(replay, remove_postproc=True)
        if not structural_equal(replay.mod, expected_mod):
            continue
        verify_trace_roundtrip(sch=replay, mod=mod, debug_mask=debug_mask)
        return candidate_id
    return None
def test_conv2d_winograd_cuda():
    """Replay the CUDA design space of winograd conv2d with fixed sampling
    decisions and compare the result with the module from ``_get_mod()``."""
    mod = IRModule({"main": conv2d_winograd_cuda})
    context = TuneContext(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultCUDA._sch_rules(),  # pylint: disable=protected-access
    )
    for rule in context.sch_rules:
        rule.initialize_with_tune_context(context)
    space_generator = PostOrderApply()
    space_generator.initialize_with_tune_context(context)
    (design_space,) = space_generator.generate_design_space(mod)

    sample_insts = [
        inst for inst in design_space.trace.insts if inst.kind.name.startswith("Sample")
    ]
    chosen = [
        # data_pack
        [3, 3],
        [64, 2],
        2,
        # inverse
        [3, 3],
        [2, 64],
        2,
        # bgemm
        [1, 1, 1, 1, 6],
        [1, 1, 1, 3, 2],
        [3, 1, 1, 1, 3],
        [4, 2, 1, 4, 4],
        [32, 1, 4],
        1,
        1,
        # root anno
        2,
        # conv2d
        2,
    ]
    # zip pairs the sampling instructions with the decisions above, in order.
    fixed_trace = Trace(design_space.trace.insts, decisions=dict(zip(sample_insts, chosen)))
    replayed = Schedule(mod=mod)
    fixed_trace.apply_to_schedule(replayed, remove_postproc=False)
    tvm.ir.assert_structural_equal(replayed.mod, _get_mod())
def _make_trace_1(b0, l1, l2):  # pylint: disable=invalid-name
    """Build a decision-free trace that gets block "block" into ``b0`` and
    its loops into ``l1`` and ``l2``."""
    instructions = [
        _make_get_block(name="block", output=b0),
        _make_get_loops(input=b0, outputs=[l1, l2]),
    ]
    return Trace(insts=instructions, decisions={})
def _make_trace_2(b0):  # pylint: disable=invalid-name
    """Build a decision-free trace that gets block "B" into ``b0`` and
    compute-inlines it."""
    fetch_b = _make_get_block(name="B", output=b0)
    inline_b = _make_compute_inline(input=b0)
    return Trace(insts=[fetch_b, inline_b], decisions={})
Esempio n. 8
0
def _make_trace_4(b0, l1, l2, l3):  # pylint: disable=invalid-name
    """Build a decision-free trace that gets block "B" into ``b0``, its single
    loop into ``l1``, and splits ``l1`` with factors ``[None, 32]`` into
    ``l2`` and ``l3``."""
    instructions = [
        _make_get_block(name="B", output=b0),
        _make_get_loops(input=b0, outputs=[l1]),
        _make_split([l1, None, 32], [l2, l3]),
    ]
    return Trace(insts=instructions, decisions={})
Esempio n. 9
0
def test_conv2d_winograd_cpu():
    """Replay the LLVM design space of winograd conv2d, minus its last four
    instructions, with fixed sampling decisions and compare the result with
    the module from ``_get_mod()``."""
    mod = IRModule({"main": conv2d_winograd_cpu})
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultLLVM._sch_rules(),  # pylint: disable=protected-access
    )
    space_generator = PostOrderApply()
    space_generator.initialize_with_tune_context(context)
    (design_space,) = space_generator.generate_design_space(mod)

    # The trailing four instructions are excluded both from decision
    # assignment and from the replayed trace.
    kept_insts = design_space.trace.insts[:-4]
    sample_insts = [inst for inst in kept_insts if inst.kind.name.startswith("Sample")]
    chosen = [
        # data_pack
        [9, 1],
        [32, 4],
        # input_tile
        4,
        # data_pad
        -2,
        # inverse
        [1, 9],
        [2, 64],
        # bgemm
        [1, 2, 3, 1],
        [1, 1, 1, 6],
        [1, 1, 1, 9],
        [2, 1, 16, 4],
        [16, 8],
    ]
    fixed_trace = Trace(kept_insts, decisions=dict(zip(sample_insts, chosen)))
    replayed = Schedule(mod=mod)
    fixed_trace.apply_to_schedule(replayed, remove_postproc=False)
    tvm.ir.assert_structural_equal(replayed.mod, _get_mod())
Esempio n. 10
0
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
) -> Schedule:
    """Check that a traced schedule can be reproduced from its JSON trace.

    The trace of ``sch`` is serialized to JSON and replayed onto a brand-new
    schedule built from ``mod``.  The replayed module must be structurally
    equal to ``sch.mod``, and both traces must print the same Python text.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask

    Returns
    -------
    tir.Schedule
        The replayed schedule, returned in case the caller needs it.
    """
    original_trace = sch.trace
    assert original_trace is not None
    serialized = original_trace.as_json()

    replayed = Schedule(mod=mod, debug_mask=debug_mask)
    Trace.apply_json_to_schedule(json_obj=serialized, sch=replayed)
    assert structural_equal(replayed.mod, sch.mod)

    expected_text, actual_text = (
        "\n".join(recorded.as_python()) for recorded in (original_trace, replayed.trace)
    )
    assert expected_text == actual_text
    return replayed
def _make_trace_3(b0, b1, add_postproc):  # pylint: disable=invalid-name
    """Build a decision-free trace that gets and inlines block "B" (``b0``),
    then gets block "C" (``b1``).

    If ``add_postproc`` is truthy, the trace also enters postprocessing and
    compute-inlines ``b1``.
    """
    instructions = [
        _make_get_block(name="B", output=b0),
        _make_compute_inline(input=b0),
        _make_get_block(name="C", output=b1),
    ]
    if add_postproc:
        instructions += [
            _make_enter_postproc(),
            _make_compute_inline(input=b1),
        ]
    return Trace(insts=instructions, decisions={})
def test_apply_json_to_schedule_1():
    """Replaying the JSON form of ``_make_trace_2`` on ``elementwise`` must
    yield ``elementwise_inlined``."""
    serialized = _make_trace_2(BlockRV()).as_json()
    schedule = tir.Schedule(elementwise, debug_mask="all")
    Trace.apply_json_to_schedule(serialized, schedule)
    tvm.ir.assert_structural_equal(elementwise_inlined, schedule.mod["main"])
def test_trace_construct_pop_2():
    """An empty trace prints as "", pops None, and remains empty afterwards."""
    empty = Trace([], {})
    assert str(empty) == ""
    popped = empty.pop()
    assert popped is None
    assert str(empty) == ""
Esempio n. 14
0
 def apply(self, trace: Trace, _) -> Optional[Trace]:
     """Return a new Trace with ``trace``'s instructions and an empty decision map."""
     without_decisions = Trace(trace.insts, {})
     return without_decisions
Esempio n. 15
0
def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions: bool = True) -> bool:
    """Check whether two schedules recorded textually identical traces.

    Parameters
    ----------
    sch_1, sch_2 : Schedule
        The schedules whose traces are compared.
    remove_decisions : bool
        If True (the default, and the previous fixed behavior), each trace is
        rebuilt from its instructions with an empty decision map before being
        printed, so sampling decisions are left out of the comparison.  If
        False, the recorded traces are printed and compared as they are.

    Returns
    -------
    bool
        True if the string forms of the two traces are equal.
    """
    if remove_decisions:
        trace_1 = Trace(sch_1.trace.insts, {})
        trace_2 = Trace(sch_2.trace.insts, {})
    else:
        trace_1 = sch_1.trace
        trace_2 = sch_2.trace
    return str(trace_1) == str(trace_2)