def test_base_cascade_model_obs_irrelevant():
    """Check BaseCascadeModel(0.3, threshold) click patterns for several thresholds.

    All cases run after a single ``np.random.seed(42)``. They are executed in the
    original order because they may draw from one shared RNG stream.
    """
    np.random.seed(42)
    relevances = [1, 2, 0, 4, 3]
    expected_by_threshold = [
        (0, [1, 1, 1, 1, 1]),
        (1, [1, 1, 0, 1, 0]),
        (3, [0, 0, 0, 1, 1]),
        (4, [0, 0, 0, 1, 0]),
    ]
    for threshold, expected in expected_by_threshold:
        assert click(BaseCascadeModel(0.3, threshold), [], relevances) == expected
def test_only_relevant_above_threshold_click_model():
    """OnlyRelevantClickModel(2) clicks exactly the documents with relevance >= 2.

    That is the behavior the expectations below encode.
    """
    model = OnlyRelevantClickModel(2)

    single_doc_cases = (
        ([2], [1]),
        ([1], [0]),
        ([0], [0]),
    )
    for relevance, expected in single_doc_cases:
        # A fresh feature array per call, as in the original assertions.
        assert click(model, np.array([[0, 1]]), relevance) == expected

    three_docs = np.array([[1, 1], [1, 0], [0, 0]])
    assert click(model, three_docs, [0, 1, 2]) == [0, 0, 1]
def test_inner_click_model_should_just_get_unmasked_docs():
    """The single click lands on the one unpadded document.

    Setup: one real document followed by four padded ones. The wrapped
    RandomClickModel(n_clicks=1) must click the real document, and every padded
    position must still read PADDED_Y_VALUE.
    """
    np.random.seed(42)
    model = MaskedRemainMasked(RandomClickModel(n_clicks=1))

    n_padded = 4
    features = np.ones((5, 1))
    labels = np.array([0] + [PADDED_Y_VALUE] * n_padded)

    clicks = click(model, features, labels)
    assert clicks == [1] + [PADDED_Y_VALUE] * n_padded
def test_base_cascade_model_eta():
    """Compare BaseCascadeModel click patterns for eta=0.3 and eta=0.5 (threshold 1).

    Both models are built right after seeding. All eta=0.3 cases then run before
    the eta=0.5 cases, keeping the original call order on the seeded RNG.
    """
    np.random.seed(42)
    low_eta = BaseCascadeModel(0.3, 1)
    high_eta = BaseCascadeModel(0.5, 1)

    cases = [
        (low_eta, [1, 2], [1, 0]),
        (low_eta, [1, 2, 3], [1, 1, 1]),
        (low_eta, [1, 2, 3, 4], [1, 1, 0, 1]),
        (high_eta, [1, 2], [1, 1]),
        (high_eta, [1, 2, 3], [1, 0, 1]),
        (high_eta, [1, 2, 3, 4], [1, 1, 1, 0]),
    ]
    for model, relevances, expected in cases:
        assert click(model, [], relevances) == expected
def test_masked_should_remain_masked():
    """A padded position keeps PADDED_Y_VALUE when MaskedRemainMasked wraps a FixedClickModel."""
    inner = FixedClickModel(click_positions=[1])
    model = MaskedRemainMasked(inner)

    features = np.ones((3, 1))
    labels = np.array([0, 0, PADDED_Y_VALUE])

    assert click(model, features, labels) == [0, 1, PADDED_Y_VALUE]
def test_fixed_click_model_single():
    """FixedClickModel([0]) clicks only the first document, whatever the list length."""
    model = FixedClickModel([0])
    for n_docs in (1, 2, 3):
        relevances = list(range(1, n_docs + 1))
        expected = [1] + [0] * (n_docs - 1)
        assert click(model, [], relevances) == expected
def test_fixed_click_model_multiple():
    """FixedClickModel clicks exactly the configured positions of a four-document list."""
    expected_by_positions = (
        ([0, 1], [1, 1, 0, 0]),
        ([0, 1, 2], [1, 1, 1, 0]),
        ([0, 2, 3], [1, 0, 1, 1]),
    )
    for positions, expected in expected_by_positions:
        assert click(FixedClickModel(positions), [], [1, 2, 3, 4]) == expected
def test_feature_click_model_everything():
    """The default EverythingButDuplicatesClickModel clicks every document in these inputs."""
    model = EverythingButDuplicatesClickModel()
    feature_sets = (
        [[0, 1]],
        [[1, 1], [1, 0]],
        [[1, 1], [1, 0], [0, 0]],
    )
    for rows in feature_sets:
        assert click(model, np.array(rows), []) == [1] * len(rows)
def test_feature_click_model_except_near_duplicates():
    """EverythingButDuplicatesClickModel(0.1) skips near-duplicate documents.

    A document whose features nearly repeat an earlier one (e.g. [1, 0.99] after
    [1, 1]) is not clicked; a sufficiently different one (e.g. [1, 0.8]) is.
    """
    model = EverythingButDuplicatesClickModel(0.1)
    cases = (
        ([[0, 1]], [1]),
        ([[1, 1], [1, 1]], [1, 0]),
        ([[1, 1], [1, 0.99], [1, 0.8]], [1, 0, 1]),
    )
    for rows, expected in cases:
        assert click(model, np.array(rows), []) == expected
def test_only_relevant_click_model():
    """OnlyRelevantClickModel(1) clicks the documents labelled 1 and skips those labelled 0."""
    model = OnlyRelevantClickModel(1)
    cases = (
        ([[0, 1]], [1], [1]),
        ([[0, 1]], [0], [0]),
        ([[1, 1], [1, 0], [0, 0]], [1, 0, 0], [1, 0, 0]),
    )
    for rows, relevances, expected in cases:
        assert click(model, np.array(rows), relevances) == expected
def test_random_click_model_single():
    """RandomClickModel(1) produces exactly one click per list under a fixed seed.

    The model is built before seeding, as in the original. The three lists are
    then clicked in order on the same seeded RNG stream.
    """
    model = RandomClickModel(1)
    np.random.seed(42)
    cases = (
        ([1], [1]),
        ([1, 2], [0, 1]),
        ([1, 2, 3], [0, 1, 0]),
    )
    for relevances, expected in cases:
        assert click(model, [], relevances) == expected
def test_random_click_model_multiple():
    """RandomClickModel(n) clicks n of four documents under a fixed seed.

    Cases run in the original order, each building its model just before its
    click, because they may share one seeded RNG stream.
    """
    np.random.seed(42)
    expected_by_n_clicks = (
        (2, [0, 1, 0, 1]),
        (3, [1, 1, 0, 1]),
        (4, [1, 1, 1, 1]),
    )
    for n_clicks, expected in expected_by_n_clicks:
        assert click(RandomClickModel(n_clicks), [], [1, 2, 3, 4]) == expected
def test_base_cascade_model_no_eta():
    """BaseCascadeModel(0.0, 1) clicks every document when all relevances are >= 1."""
    model = BaseCascadeModel(0.0, 1)
    for n_docs in (1, 2, 3):
        relevances = list(range(1, n_docs + 1))
        assert click(model, [], relevances) == [1] * n_docs
def test_base_cascade_model_below_threshold():
    """With eta=0.0, only documents whose relevance reaches the threshold are clicked.

    That is the behavior the expectations below encode for thresholds 1, 2 and 4.
    """
    relevances = [1, 2, 0, 4, 3]
    expected_by_threshold = (
        (1, [1, 1, 0, 1, 1]),
        (2, [0, 1, 0, 1, 1]),
        (4, [0, 0, 0, 1, 0]),
    )
    for threshold, expected in expected_by_threshold:
        assert click(BaseCascadeModel(0.0, threshold), [], relevances) == expected