예제 #1
0
def test_sketch_search_policy_basic():
    """Run the shared search routine with the sketch policy and a fixed seed."""
    search_kwargs = dict(seed=944563397, search_policy="sketch")
    # The search runs on its own thread so that Python's multiprocessing
    # does not collide with TVM's thread pool.
    worker = PropagatingThread(target=search_common, kwargs=search_kwargs)
    worker.start()
    worker.join()
예제 #2
0
def test_sketch_search_policy_xgbmodel():
    """Run the sketch-policy search with ``auto_scheduler.XGBModel`` as cost model."""
    search_kwargs = {
        "seed": 944563397,
        "search_policy": "sketch",
        "cost_model": auto_scheduler.XGBModel(),
    }
    # Threaded so Python's multiprocessing stays clear of TVM's thread pool.
    worker = PropagatingThread(target=search_common, kwargs=search_kwargs)
    worker.start()
    worker.join()
예제 #3
0
def test_sketch_search_policy_cuda_rpc_runner():
    """Run the sketch-policy search on CUDA, measuring through a local RPC runner.

    Returns early (reported as a pass) when the CUDA runtime is not enabled,
    the same way test_sketch_search_policy_cuda_xgbmodel_rpc_runner does.
    """
    # Check before creating the measure context: without CUDA the search
    # cannot run against the "cuda" target, so don't start the RPC context.
    if not tvm.runtime.enabled("cuda"):
        return
    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
    # wrap the search in a new thread to avoid the conflict
    # between python's multiprocessing and tvm's thread pool
    t = PropagatingThread(target=search_common,
                          kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
                                  'runner': measure_ctx.runner})
    t.start()
    t.join()
예제 #4
0
def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
    """Sketch-policy search on CUDA using XGBModel and a local RPC runner.

    Does nothing and returns early when the CUDA runtime is not enabled.
    """
    cuda_available = tvm.runtime.enabled("cuda")
    if not cuda_available:
        return
    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
    search_kwargs = dict(
        seed=944563397,
        search_policy="sketch",
        target="cuda",
        runner=measure_ctx.runner,
        cost_model=auto_scheduler.XGBModel(),
    )
    # A dedicated thread keeps Python's multiprocessing from conflicting
    # with TVM's thread pool.
    worker = PropagatingThread(target=search_common, kwargs=search_kwargs)
    worker.start()
    worker.join()
예제 #5
0
def test_sketch_search_policy_cuda_rpc_runner():
    """Run the sketch-policy search on CUDA, measuring through a local RPC runner.

    Returns early (reported as a pass) when the CUDA runtime is not enabled,
    the same way test_sketch_search_policy_cuda_xgbmodel_rpc_runner does.
    """
    # Check before creating the measure context: without CUDA the search
    # cannot run against the "cuda" target, so don't start the RPC context.
    if not tvm.runtime.enabled("cuda"):
        return
    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
    # wrap the search in a new thread to avoid the conflict
    # between python's multiprocessing and tvm's thread pool
    t = PropagatingThread(
        target=search_common,
        kwargs={
            # Request the sketch policy explicitly so the test exercises what
            # its name says, independent of search_common's default.
            "search_policy": "sketch",
            "target": "cuda",
            "runner": measure_ctx.runner,
        },
    )
    t.start()
    t.join()


def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
    """Sketch-policy search on CUDA using XGBModel and a local RPC runner.

    Does nothing and returns early when the CUDA runtime is not enabled.
    """
    cuda_available = tvm.runtime.enabled("cuda")
    if not cuda_available:
        return
    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
    search_kwargs = dict(
        seed=944563397,
        search_policy="sketch",
        target="cuda",
        runner=measure_ctx.runner,
        cost_model=auto_scheduler.XGBModel(),
    )
    # A dedicated thread keeps Python's multiprocessing from conflicting
    # with TVM's thread pool.
    worker = PropagatingThread(target=search_common, kwargs=search_kwargs)
    worker.start()
    worker.join()
예제 #7
0
def test_workload_registry_search_basic():
    """Run the search three times with a fixed seed.

    The first run uses search_common's default workload. The next two name
    the workload by string key (presumably resolved through the workload
    registry; confirm in search_common). Runs happen one after another.
    """
    fixed_seed = 944563397
    run_configs = [
        {"seed": fixed_seed},
        {"seed": fixed_seed, "workload": "matmul_auto_scheduler_test"},
        {"seed": fixed_seed, "workload": "matmul_auto_scheduler_test_rename_1"},
    ]
    for config in run_configs:
        # Each run gets its own thread so Python's multiprocessing does not
        # clash with TVM's thread pool.
        worker = PropagatingThread(target=search_common, kwargs=config)
        worker.start()
        worker.join()
예제 #8
0
def test_workload_registry_search_basic():
    """Run the "empty" search policy with 2 measure trials on three workloads.

    The first run uses search_common's default workload. The next two name
    the workload by string key (presumably resolved through the workload
    registry; confirm in search_common). Runs happen one after another.
    """
    base_kwargs = {"search_policy": "empty", "num_measure_trials": 2}
    run_configs = [
        dict(base_kwargs),
        dict(base_kwargs, workload="matmul_auto_scheduler_test"),
        dict(base_kwargs, workload="matmul_auto_scheduler_test_rename_1"),
    ]
    for config in run_configs:
        # Each run gets its own thread so Python's multiprocessing does not
        # clash with TVM's thread pool.
        worker = PropagatingThread(target=search_common, kwargs=config)
        worker.start()
        worker.join()
예제 #9
0
def test_workload_registry_search_basic():
    """Run the search three times with a fixed seed, if LLVM is enabled.

    The first run uses search_common's default workload. The next two name
    the workload by string key (presumably resolved through the workload
    registry; confirm in search_common). Runs happen one after another.
    """
    llvm_available = tvm.runtime.enabled("llvm")
    if not llvm_available:
        return
    fixed_seed = 944563397
    workload_keys = (
        None,
        "matmul_auto_scheduler_test",
        "matmul_auto_scheduler_test_rename_1",
    )
    for key in workload_keys:
        config = {'seed': fixed_seed}
        if key is not None:
            config['workload'] = key
        # Each run gets its own thread so Python's multiprocessing does not
        # clash with TVM's thread pool.
        worker = PropagatingThread(target=search_common, kwargs=config)
        worker.start()
        worker.join()