def test_nvt_tf_rossmann_inference_triton(asv_db, bench_info, n_rows, err_tol):
    """Serve the Rossmann TF model via Triton, run one query, and check accuracy.

    The element-wise difference between served and reference outputs must stay
    below ``err_tol``; the run time is captured as a benchmark result.
    """
    server_ctx = test_utils.run_triton_server(
        os.path.expanduser(MODEL_DIR),
        "rossmann",
        TRITON_SERVER_PATH,
        TRITON_DEVICE_ID,
        "tensorflow",
    )
    with server_ctx as client:
        diff, run_time = _run_rossmann_query(client, n_rows)

        # Every element-wise deviation must be within tolerance.
        assert (diff < err_tol).all()
        benchmark_results = []
        benchmark_results.append(
            create_bench_result(
                "test_nvt_tf_rossmann_inference_triton",
                [("n_rows", n_rows)],
                run_time,
                "datetime",
            )
        )
def test_nvt_tf_movielens_inference_triton_mt(asv_db, bench_info, n_rows,
                                              err_tol):
    """Fan MovieLens Triton queries out across threads and validate each result.

    One query is submitted per entry in ``n_rows``; each completed query's
    diff must be within ``err_tol`` and its run time is recorded.
    """
    pending = []
    with test_utils.run_triton_server(
            os.path.expanduser(MODEL_DIR),
            "movielens",
            TRITON_SERVER_PATH,
            TRITON_DEVICE_ID,
            "tensorflow",
    ) as client:
        # The executor's __exit__ joins all submitted work, so every future
        # is finished before the Triton server context shuts down.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            pending.extend(
                executor.submit(_run_movielens_query, client, row_count)
                for row_count in n_rows
            )

    for done in concurrent.futures.as_completed(pending):
        diff, run_time = done.result()
        assert (diff < err_tol).all()
        benchmark_results = []
        benchmark_results.append(
            create_bench_result(
                "test_nvt_tf_movielens_inference_triton_mt",
                [("n_rows", n_rows)],
                run_time,
                "datetime",
            )
        )
# Example #3
def test_inference(n_rows, err_tol):
    """End-to-end HugeCTR ensemble inference through Triton with tolerance check.

    Writes the parameter-server config, launches Triton with the HugeCTR
    backend, runs a single query, asserts the diff is within ``err_tol``,
    and records the run time as a benchmark result.
    """
    warnings.simplefilter("ignore")

    data_path = DATA_DIR + "test/data.csv"
    output_path = DATA_DIR + "test/output.csv"
    ps_file = TRAIN_DIR + "ps.json"
    workflow_path = MODEL_DIR + MODEL_NAME + "_nvt/1/workflow"
    ensemble_name = MODEL_NAME + "_ens"

    # HugeCTR needs its parameter-server config on disk before Triton starts.
    _write_ps_hugectr(ps_file, MODEL_NAME, SPARSE_FILES, DENSE_FILE, NETWORK_FILE)

    server_ctx = test_utils.run_triton_server(
        os.path.expanduser(MODEL_DIR),
        ensemble_name,
        TRITON_SERVER_PATH,
        TRITON_DEVICE_ID,
        "hugectr",
        ps_file,
    )
    with server_ctx as client:
        diff, run_time = _run_query(
            client,
            n_rows,
            ensemble_name,
            workflow_path,
            data_path,
            output_path,
            "OUTPUT0",
            CATEGORICAL_COLUMNS,
            "hugectr",
        )

    assert (diff < err_tol).all()
    benchmark_results = [
        create_bench_result(
            "test_nvt_hugectr_inference", [("n_rows", n_rows)], run_time, "datetime"
        )
    ]