예제 #1
0
def test_good_less(spark, sample_cache):
    """Predict from two seed items with a smaller top_k and check each slot.

    Seeds are items 0 and 2 with ratings 10 and 3, seen items are kept
    (remove_seen=False), and the first three recommendations are compared
    against their expected (item, score) pairs via assert_compare.
    """
    sar = SARModel(sample_cache)
    recs = sar.predict([0, 2], [10, 3], top_k=5, remove_seen=False)

    expected_pairs = ((0, 1), (1, 11.6), (2, 12.3))
    for rank, (item_id, score) in enumerate(expected_pairs):
        assert_compare(item_id, score, recs[rank])
예제 #2
0
def test_good(spark, sample_cache):
    """Predict from seed items given in ascending order and check all results.

    Seeds are items 0 and 1 with ratings 10 and 20 (remove_seen=False).
    test_good_require_sort feeds the same seeds in reverse order and asserts
    exactly three recommendations, so the length is pinned here as well.
    """
    model = SARModel(sample_cache)
    y = model.predict([0, 1], [10, 20], top_k=10, remove_seen=False)

    assert_compare(0, 5, y[0])
    assert_compare(1, 44, y[1])
    assert_compare(2, 64, y[2])

    # Guard against extra items slipping in after the three checked ones.
    assert 3 == len(y)
예제 #3
0
def test_good_require_sort(spark, sample_cache):
    """Seeds passed out of item order must still give the expected result.

    Items 1 and 0 (ratings 20 and 10) are supplied unsorted. The output must
    contain exactly three recommendations with the expected (item, score)
    pairs.
    """
    seed_items, seed_ratings = [1, 0], [20, 10]
    recs = SARModel(sample_cache).predict(
        seed_items, seed_ratings, top_k=10, remove_seen=False
    )

    for rank, expected in enumerate([(0, 5), (1, 44), (2, 64)]):
        assert_compare(*expected, recs[rank])

    assert len(recs) == 3
예제 #4
0
def test_pandas(spark, sample_cache):
    """Predict from seed items/ratings taken from pandas DataFrame columns.

    The `.values` of the "itemID" and "score" columns are passed to
    SARModel.predict, and the first three recommendations are checked.
    """
    item_scores = pd.DataFrame([(0, 2.3), (1, 3.1)],
                               columns=["itemID", "score"])
    seed_items = item_scores["itemID"].values
    seed_ratings = item_scores["score"].values

    recs = SARModel(sample_cache).predict(seed_items, seed_ratings,
                                          top_k=10, remove_seen=False)

    expected_pairs = [(0, 0.85), (1, 6.9699), (2, 9.92)]
    for rank, (item_id, score) in enumerate(expected_pairs):
        assert_compare(item_id, score, recs[rank])
예제 #5
0
        def sar_predict_udf(df):
            """Score the rated items in ``df`` with SAR and return predictions.

            Reads the "idx" and "rating" columns of ``df`` and calls
            SARModel.predict with the enclosing ``top_k``/``remove_seen``.
            Only the first row's user id is used, so ``df`` is presumably a
            single user's group (TODO confirm against the caller).

            Returns a DataFrame with integer columns 0, 1, 2 holding
            (user, item id, score) for each prediction.
            """
            # cache_path_input points to the file written by
            # com.microsoft.sarplus. That file already has exactly the memory
            # layout SARModel needs, and because it is memory mapped, the
            # memory cost is paid once per worker and shared by all Python
            # processes.
            sar = SARModel(cache_path_input)
            recommendations = sar.predict(
                df["idx"].values, df["rating"].values, top_k, remove_seen
            )

            user_id = df[local_header["col_user"]].iloc[0]
            rows = [(user_id, rec.id, rec.score) for rec in recommendations]
            return pd.DataFrame(rows, columns=range(3))
예제 #6
0
def test_good_require_sort_remove_seen(spark, sample_cache):
    """With remove_seen=True only item 2 is expected in the output.

    Seeds 1 and 0 are supplied unsorted. The result must be a single
    recommendation for item 2 with score 64.
    """
    recs = SARModel(sample_cache).predict(
        [1, 0], [20, 10], top_k=10, remove_seen=True
    )

    assert_compare(2, 64, recs[0])
    assert len(recs) == 1