def test_gist_case(self):
        """Per-group (unreduced) metrics must match the worked gist example.

        Three groups (ids 1, 2, 3) of three items each are scored 1, 2, 3 and
        evaluated at k = 1, 2, 3 with ``reduce=False``, so every metric is
        expected to come back as one value per group.  The expected numbers
        follow the worked example referenced in the gist comment below.
        """
        # https://gist.github.com/bwhite/3726239#gistcomment-2852580

        res = recommender_metrics.calculate_metrics(
            k_list=[1, 2, 3],
            group_ids=np.asarray([1, 1, 1, 2, 2, 2, 3, 3, 3]),
            scores=np.asarray([1, 2, 3, 1, 2, 3, 1, 2, 3]),
            labels=np.asarray([1, 1, 1, 0, 1, 1, 0, 0, 1]),
            reduce=False,
        )

        target = {
            "mAP@1": np.asarray([1.0, 1.0, 1.0]),
            "precision@1": np.asarray([1.0, 1.0, 1.0]),
            "recall@1": np.asarray([0.33333333, 0.5, 1.0]),
            "mAP@2": np.asarray([1.0, 1.0, 1.0]),
            "precision@2": np.asarray([1.0, 1.0, 0.5]),
            "recall@2": np.asarray([0.66666667, 1.0, 1.0]),
            "mAP@3": np.asarray([1.0, 1.0, 1.0]),
            "precision@3": np.asarray([1.0, 0.66666667, 0.33333333]),
            "recall@3": np.asarray([1.0, 1.0, 1.0]),
        }

        for key, expected in target.items():
            assert key in res, f"missing metric {key!r}"
            actual = np.asarray(res[key])
            # np.allclose broadcasts, so a scalar (i.e. reduced) result would
            # silently match an expected array of identical values; pin the
            # per-group shape explicitly.
            assert actual.shape == expected.shape, (
                f"{key}: shape {actual.shape} != expected {expected.shape}"
            )
            assert np.allclose(expected, actual), f"{key}: {actual} != {expected}"
def basic_usage_example_1():
    """Run ``calculate_metrics`` with its default settings on seeded random data.

    Builds 100 rows from ``np.random.RandomState(1234)``: group ids drawn from
    ``[0, 10)``, normally distributed scores, and labels that are ``True`` when
    a uniform draw exceeds 0.8.  Prints the result as indented JSON.
    """
    import json

    import numpy as np
    from recommender_metrics import calculate_metrics

    print("Running example1")

    random_state = np.random.RandomState(1234)
    # The draw order (ids, then scores, then labels) fixes the generated data,
    # so keep these three lines in this sequence.
    ids = random_state.randint(0, 10, 100)
    predicted = random_state.normal(0, 1, 100)
    relevant = random_state.rand(100) > 0.8

    result = calculate_metrics(group_ids=ids, scores=predicted, labels=relevant)
    print(json.dumps(result, indent=2))
    print("\n\n\n")
def basic_usage_large_dataset():
    """Evaluate metrics on a larger random dataset (``n_users=50000``) and print them.

    Prints the number of group ids, scores and labels that were generated,
    followed by the metrics dictionary as indented JSON.
    """
    import json

    from recommender_metrics import calculate_metrics, generate_random_data

    group_ids, predicted, relevant = generate_random_data(n_users=50000)

    print("Larger data:")
    for caption, column in (
        ("#groups", group_ids),
        ("#scores", predicted),
        ("#labels", relevant),
    ):
        print(f"  {caption}:", len(column))
    print()

    summary = calculate_metrics(group_ids=group_ids, scores=predicted, labels=relevant)
    print("Metrics:")
    print(json.dumps(summary, indent=2))
    print("\n\n\n")
def basic_usage_group_filttering():
    """Evaluate metrics on random data with ``remove_empty=True`` and print them.

    Prints the size of each generated input array, then the metrics
    dictionary returned by ``calculate_metrics`` as indented JSON.
    """
    import json

    from recommender_metrics import calculate_metrics, generate_random_data

    print("Running example3")

    group_ids, predicted, relevant = generate_random_data()

    print("Data:")
    for caption, column in zip(
        ("#groups", "#scores", "#labels"),
        (group_ids, predicted, relevant),
    ):
        print(f"  {caption}:", len(column))
    print()

    summary = calculate_metrics(
        group_ids=group_ids,
        scores=predicted,
        labels=relevant,
        remove_empty=True,
    )
    print("Metrics:")
    print(json.dumps(summary, indent=2))
    print("\n\n\n")
def basic_usage_ascending_scores():
    """Evaluate ``search_data`` positions with ``ascending=True`` and print the metrics.

    The positions returned by ``search_data`` are passed as ``scores``; the
    ``ascending=True`` flag presumably makes smaller positions rank higher
    (confirm against ``calculate_metrics``).  The raw inputs are printed
    before the metrics, which are printed as indented JSON.
    """
    import json

    from recommender_metrics import calculate_metrics, search_data

    print("Running example2")

    group_ids, ranks, relevant = search_data()

    print("Data:")
    print("     groups:", group_ids)
    print("  positions:", ranks)
    print("     labels:", relevant)
    print()

    summary = calculate_metrics(
        group_ids=group_ids,
        scores=ranks,
        labels=relevant,
        ascending=True,
    )
    print("Metrics:")
    print(json.dumps(summary, indent=2))
    print("\n\n\n")
def basic_usage_custom_metrics_k():
    """Evaluate a custom metric list at custom cut-offs on random data and print it.

    Requests mAP, precision, recall, ndcg and auroc at k in {1, 2, 4, 8, 16},
    prints the size of each generated input array, then prints the resulting
    metrics dictionary as indented JSON.
    """
    import json

    from recommender_metrics import calculate_metrics, generate_random_data

    print("Running example4")

    group_ids, predicted, relevant = generate_random_data()

    print("Data:")
    for caption, column in (
        ("#groups", group_ids),
        ("#scores", predicted),
        ("#labels", relevant),
    ):
        print(f"  {caption}:", len(column))
    print()

    cutoffs = [1, 2, 4, 8, 16]
    requested = ["mAP", "precision", "recall", "ndcg", "auroc"]
    summary = calculate_metrics(
        group_ids=group_ids,
        scores=predicted,
        labels=relevant,
        k_list=cutoffs,
        metrics=requested,
    )
    print("Metrics:")
    print(json.dumps(summary, indent=2))
    print("\n\n\n")
    def test_empty(self):
        """Metrics computed with ``remove_empty=True`` must match pinned values.

        The expected numbers are hard-coded, so this test assumes that
        ``generate_random_data()`` with its default arguments is deterministic
        (TODO confirm it is seeded).  ``remove_empty=True`` presumably drops
        groups that have no positive labels before averaging; verify against
        ``calculate_metrics``.
        """
        target = {
            "mAP@1": 0.2631578947368421,
            "precision@1": 0.2631578947368421,
            "recall@1": 0.07631578947368421,
            "mAP@5": 0.5046783625730993,
            "precision@5": 0.3140350877192983,
            "recall@5": 0.45743525480367586,
            "mAP@10": 0.4538314536340852,
            "precision@10": 0.2941520467836258,
            "recall@10": 0.6738512949039264,
            "mAP@20": 0.424754158736318,
            "precision@20": 0.33355881010883454,
            "recall@20": 1.0,
        }

        groups, scores, labels = recommender_metrics.generate_random_data()
        metrics = recommender_metrics.calculate_metrics(group_ids=groups,
                                                        scores=scores,
                                                        labels=labels,
                                                        remove_empty=True)
        # Plain string: the original f-string had no placeholders (ruff F541).
        self.dict_vals_all_close(target=target,
                                 pred=metrics,
                                 desc="Removal of empty group labels")