def test_ranker_indicator_function_logistic():
    """Check the logistic indicator's scores and derivatives.

    The predictions are a symmetric grid from -1 to 1 in steps of 0.25.
    kappa is all zeros, so the indicator is evaluated directly on the
    predictions with sigma = 0.1.
    """
    model = factories.ranker()
    model.set_indicator_type('logistic')
    model.sigma = 0.1
    model.cutoff = 5

    # -1.0, -0.75, ..., 0.75, 1.0 (multiples of 0.25 are exact in float64)
    grid = np.arange(-4, 5) * 0.25
    model.predict = grid
    model.kappa = np.zeros_like(grid)

    model.indicator_func()

    # Indicator score: sigmoid-shaped, 0.5 exactly at the threshold.
    score_expected = [
        4.5397869e-05, 5.5277864e-04, 6.6928509e-03,
        7.5858180e-02, 5.0000000e-01, 9.2414182e-01,
        9.9330715e-01, 9.9944722e-01, 9.9995460e-01,
    ]
    np.testing.assert_almost_equal(score_expected, model.indicator_score)

    # Indicator derivative: symmetric, peaking at 0.25 / sigma = 2.5.
    derivative_expected = [
        4.5395808e-04, 5.5247307e-03, 6.6480567e-02,
        7.0103717e-01, 2.5000000e+00, 7.0103717e-01,
        6.6480567e-02, 5.5247307e-03, 4.5395808e-04,
    ]
    np.testing.assert_almost_equal(
        derivative_expected, model.indicator_derivative)
def test_ranker_indicator_function_relu():
    """Check the relu indicator's scores and derivatives.

    The predictions are a grid from -1 to 1 in steps of 0.25, kappa is all
    zeros, and delta = 0.1. Only the point sitting exactly on the
    threshold gets a fractional score and a non-zero derivative.
    """
    model = factories.ranker()
    model.set_indicator_type('relu')
    model.delta = 0.1
    model.cutoff = 3

    # -1.0, -0.75, ..., 0.75, 1.0 (multiples of 0.25 are exact in float64)
    grid = np.arange(-4, 5) * 0.25
    model.predict = grid
    model.kappa = np.zeros_like(grid)

    model.indicator_func()

    # Indicator score: 0 below the threshold, 0.5 at it, 1 above it.
    score_expected = [0., 0., 0., 0., 0.5, 1., 1., 1., 1.]
    np.testing.assert_almost_equal(score_expected, model.indicator_score)

    # Indicator derivative: non-zero only at the threshold point.
    derivative_expected = [0., 0., 0., 0., 5., 0., 0., 0., 0.]
    np.testing.assert_almost_equal(
        derivative_expected, model.indicator_derivative)
def test_cascade_computed_kappa_when_training():
    """Check that training-mode cutoff sets kappa to the cutoff-th best score.

    All five documents belong to one query. Document 2 is masked out by the
    previous stage, so with cutoff 2 the retained top-2 scores are 1.0 and
    0.5. The test expects kappa to equal 0.5 for every document of the
    query. It also checks that the returned scores are a different object
    from ranker.predict.
    """
    query_ids = np.ones(5, dtype=int)
    start, end = next(group_offsets(query_ids))

    cascade = factories.dummy_cascade()
    model = factories.ranker()
    model.cutoff = 2

    previous_mask = [1, 1, 0, 1, 1]
    raw_scores = np.array([0.1, 1.0, -0.03, 0.5, 0.25])

    # The ranker's predictions mirror raw_scores, except the document
    # excluded by previous_mask carries the cascade's sentinel mask score.
    masked_predictions = raw_scores.copy()
    masked_predictions[2] = Cascade.SCORE_MASK
    model.predict = masked_predictions

    result = cascade.ranker_apply_cutoff(
        model, raw_scores, previous_mask, query_ids, is_train=True)

    np.testing.assert_almost_equal(model.kappa[start:end], [0.5] * 5)
    assert result is not model.predict
def test_ranker_set_unkown_indicator_type():
    """Check that an unregistered indicator type name raises KeyError."""
    model = factories.ranker()
    pytest.raises(KeyError, model.set_indicator_type, 'unkown-foo')