コード例 #1
0
def test_meta_schedule_database_reload():
    """Check that tuning records survive re-opening a JSON database.

    Three records with distinct mean run times are committed to a temporary
    database. A brand-new ``JSONDatabase`` is then opened on the same
    workload / tuning-record files and asked for its top 2 records for the
    workload. The two fastest records are expected, fastest first.
    """
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        token = database.commit_workload(mod)
        trace = _create_schedule(mod, _schedule_matmul).trace
        # Mean run times are 8.0, 2.0 and 5.0 respectively, so the expected
        # top-2 ordering is unambiguous: records[1], then records[2].
        records = [
            TuningRecord(
                trace,
                run_secs,
                token,
                tvm.target.Target("llvm"),
                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
            )
            for run_secs in ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
        ]
        for record in records:
            database.commit_tuning_record(record)
        # Re-load the same on-disk files through a fresh database object.
        new_database = JSONDatabase(
            path_workload=database.path_workload,
            path_tuning_record=database.path_tuning_record,
        )
        token = new_database.commit_workload(mod)
        ret = new_database.get_top_k(token, 2)
        assert len(ret) == 2
        # Expect fastest-first ordering. Accepting either order (as a
        # try/except fallback would) would hide an ordering regression.
        _equal_record(ret[0], records[1])
        _equal_record(ret[1], records[2])
コード例 #2
0
def test_meta_schedule_tuning_record_round_trip():
    """A TuningRecord rebuilt from its own ``as_json`` output must equal the original."""
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        token = database.commit_workload(mod)
        sch_trace = _create_schedule(mod, _schedule_matmul).trace
        arg_infos = ArgInfo.from_prim_func(func=mod["main"])  # pylint: disable=unsubscriptable-object
        original = TuningRecord(
            sch_trace,
            [1.5, 2.5, 1.8],
            token,
            tvm.target.Target("llvm"),
            arg_infos,
        )
        database.commit_tuning_record(original)
        # Serialize, then deserialize against the same committed workload.
        restored = TuningRecord.from_json(original.as_json(), token)
        _equal_record(original, restored)
コード例 #3
0
def test_meta_schedule_database_add_entry():
    """After committing one record, the database holds it as its only entry."""
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        token = database.commit_workload(mod)
        sch_trace = _create_schedule(mod, _schedule_matmul).trace
        arg_infos = ArgInfo.from_prim_func(func=mod["main"])  # pylint: disable=unsubscriptable-object
        record = TuningRecord(
            sch_trace,
            [1.5, 2.5, 1.8],
            token,
            tvm.target.Target("llvm"),
            arg_infos,
        )
        database.commit_tuning_record(record)
        assert len(database) == 1
        # Requesting up to 3 records must still yield exactly one; the
        # single-element unpacking raises otherwise.
        [top] = database.get_top_k(token, 3)
        _equal_record(top, record)
コード例 #4
0
def test_meta_schedule_arg_info_from_prim_func():
    """Check the printed form of the three ArgInfo entries extracted from ``Matmul``."""
    expected = (
        'TensorInfo("float32", [128, 256])',
        'TensorInfo("float32", [256, 512])',
        'TensorInfo("float32", [128, 512])',
    )
    # Unpack explicitly so a wrong argument count still fails here.
    a_info, b_info, c_info = ArgInfo.from_prim_func(Matmul)
    for info, want in zip((a_info, b_info, c_info), expected):
        assert str(info) == want