Example #1
def test_meta_schedule_tuning_record_round_trip():
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        workload = database.commit_workload(mod)
        record = TuningRecord(
            _create_schedule(mod, _schedule_matmul).trace,
            [1.5, 2.5, 1.8],
            workload,
            tvm.target.Target("llvm"),
            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
        )
        database.commit_tuning_record(record)
        new_record = TuningRecord.from_json(record.as_json(), workload)
        _equal_record(record, new_record)
Example #2
def test_meta_schedule_database_reload():
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        token = database.commit_workload(mod)
        trace = _create_schedule(mod, _schedule_matmul).trace
        records = [
            TuningRecord(
                trace,
                [7.0, 8.0, 9.0],
                token,
                tvm.target.Target("llvm"),
                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
            ),
            TuningRecord(
                trace,
                [1.0, 2.0, 3.0],
                token,
                tvm.target.Target("llvm"),
                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
            ),
            TuningRecord(
                trace,
                [4.0, 5.0, 6.0],
                token,
                tvm.target.Target("llvm"),
                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
            ),
        ]
        for record in records:
            database.commit_tuning_record(record)
        new_database = JSONDatabase(  # pylint: disable=unused-variable
            path_workload=database.path_workload,
            path_tuning_record=database.path_tuning_record,
        )
        token = new_database.commit_workload(mod)
        ret = new_database.get_top_k(token, 2)
        assert len(ret) == 2
        # get_top_k may rank the two best records in either order, so accept both.
        try:
            _equal_record(ret[0], records[2])
            _equal_record(ret[1], records[1])
        except AssertionError:
            _equal_record(ret[0], records[1])
            _equal_record(ret[1], records[2])
Example #3
def test_meta_schedule_integration_apply_history_best():
    class DummyDatabase(PyDatabase):
        def __init__(self):
            super().__init__()
            self.records = []
            self.workload_reg = []

        def has_workload(self, mod: IRModule) -> bool:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return True
            return False

        def commit_tuning_record(self, record: TuningRecord) -> None:
            self.records.append(record)

        def commit_workload(self, mod: IRModule) -> Workload:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return workload
            workload = Workload(mod)
            self.workload_reg.append(workload)
            return workload

        def get_top_k(self, workload: Workload,
                      top_k: int) -> List[TuningRecord]:
            return list(
                filter(
                    lambda x: x.workload == workload,
                    sorted(self.records,
                           key=lambda x: sum(x.run_secs) / len(x.run_secs)),
                ))[:int(top_k)]

        def __len__(self) -> int:
            return len(self.records)

        def print_results(self) -> None:
            print("\n".join([str(r) for r in self.records]))

    mod, _, _, _ = get_network(
        name="resnet-18",
        batch_size=1,
        layout="NHWC",
        dtype="float32",
    )
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")
    workload = database.commit_workload(MockModule)
    database.commit_tuning_record(
        TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []))
    mod = env.query(task_name="mock-task",
                    mod=mod,
                    target=target,
                    dispatched=[MockModule])
    assert tvm.ir.structural_equal(mod, workload.mod)
Example #4
def test_meta_schedule_integration_apply_history_best():
    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")
    workload = database.commit_workload(MockModule)
    database.commit_tuning_record(
        TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
    )
    mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
    assert tvm.ir.structural_equal(mod, workload.mod)
Example #5
def test_meta_schedule_database_add_entry():
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        workload = database.commit_workload(mod)
        record = TuningRecord(
            _create_schedule(mod, _schedule_matmul).trace,
            [1.5, 2.5, 1.8],
            workload,
            tvm.target.Target("llvm"),
            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
        )
        database.commit_tuning_record(record)
        assert len(database) == 1
        (ret, ) = database.get_top_k(workload, 3)
        _equal_record(ret, record)
Example #6
def apply_fixed_schedules(
    relay_mod: Union[RelayFunc, IRModule],
    target: Union[str, Target],
    params: Optional[Dict[str, NDArray]],
    schedule_fn: Callable[[ExtractedTask, Schedule], bool],
):
    """Apply fixed schedules (manually written, without any tunable knobs) as specified by
    schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.

    Parameters
    ----------
    relay_mod : Union[RelayFunc, IRModule]
        The Relay module to which the fixed schedules are applied.
    target : Union[str, Target]
        The target used to extract tasks.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the module.
    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
        A callable applied to each extracted task and its corresponding default schedule.
        Returns True if the given schedule should be committed to the database, False otherwise.

    Returns
    -------
    database : Database
        The database containing dummy tuning records for manually scheduled traces.
    """
    target = Target(target) if isinstance(target, str) else target
    extracted_tasks = extract_task_from_relay(relay_mod, target, params)

    database = DummyDatabase()

    for task in extracted_tasks:
        mod = Parse._mod(task.dispatched[0])
        sch = Schedule(mod)

        if schedule_fn(task, sch):
            workload = database.commit_workload(mod)
            tune_rec = TuningRecord(sch.trace, [0.0], workload, target, [])
            database.commit_tuning_record(tune_rec)

    return database
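

# Minimal usage sketch of apply_fixed_schedules (illustrative, not part of the
# original test suite). It assumes a Relay module `relay_mod` and its `params`
# have already been built, e.g. as in test_meta_schedule_relay_lowering below;
# the conv2d task-name filter and the helper name `_example_apply_fixed_schedules`
# are placeholders chosen for demonstration.
def _example_apply_fixed_schedules(relay_mod, params):
    def schedule_fn(task: ExtractedTask, sch: Schedule) -> bool:
        # Commit a schedule only for conv2d tasks; all other tasks keep the
        # normal lowering path.
        if "conv2d" not in task.task_name:
            return False
        # A real caller would apply manual scheduling primitives to `sch` here;
        # returning True commits the (here unmodified) trace to the database.
        return True

    target = Target("llvm --num-cores=16")
    database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)

    # The returned database is consumed exactly like a tuned one.
    with ApplyHistoryBest(database):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_meta_schedule": True},
        ):
            return relay.build(relay_mod, target=target, params=params)
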
def test_meta_schedule_relay_lowering():
    data_shape = (1, 3, 16, 16)
    weight_shape = (8, 3, 5, 5)
    data = relay.var("data", relay.TensorType(data_shape, "float32"))
    weight = relay.var("weight", relay.TensorType(weight_shape, "float32"))
    y = relay.nn.conv2d(
        data,
        weight,
        padding=(2, 2),
        kernel_size=(5, 5),
        kernel_layout="OIHW",
        out_dtype="float32",
    )
    f = relay.Function([data, weight], y)
    mod = tvm.IRModule.from_expr(f)
    mod = relay.transform.InferType()(mod)

    data_sample = np.random.rand(*data_shape).astype("float32")
    weight_sample = np.random.rand(*weight_shape).astype("float32")
    params = {mod["main"].params[1].name_hint: weight_sample}

    input_name = "data"
    dev = tvm.cpu()
    target = Target("llvm --num-cores=16")
    data = tvm.nd.array(data_sample, dev)

    with tempfile.TemporaryDirectory() as work_dir:
        database = JSONDatabase(osp.join(work_dir, "workload.json"),
                                osp.join(work_dir, "records.json"))

        database.commit_tuning_record(
            TuningRecord(
                Trace([], {}),
                [0.0],
                database.commit_workload(
                    tvmgen_default_fused_nn_contrib_conv2d_NCHWc),
                target=target,
                args_info=[],
            ))

        with ApplyHistoryBest(database):
            with tvm.transform.PassContext(
                    opt_level=3,
                    config={"relay.backend.use_meta_schedule": True},
            ):
                rt_mod1 = relay.build(mod, target=target, params=params)

        # Compile without meta-scheduler for correctness check
        with tvm.transform.PassContext(opt_level=0):
            rt_mod2 = relay.build(mod, target=target, params=params)

        def get_output(data, lib):
            module = graph_executor.GraphModule(lib["default"](dev))
            module.set_input(input_name, data)
            module.run()
            return module.get_output(0).numpy()

        # Check correctness
        actual_output = get_output(data, rt_mod1)
        expected_output = get_output(data, rt_mod2)
        assert np.allclose(actual_output,
                           expected_output,
                           rtol=1e-4,
                           atol=2e-4)