import numpy as np
import pytest

from keras import backend as K
from keras import metrics


def test_unweighted_all_correct(self):
    s_obj = metrics.PrecisionAtRecall(0.7)
    inputs = np.random.randint(0, 2, size=(100, 1))
    y_pred = K.constant(inputs, dtype='float32')
    y_true = K.constant(inputs)
    result = s_obj(y_true, y_pred)
    # With predictions identical to the labels, precision is 1 at every
    # achievable recall.
    assert np.isclose(1, K.eval(result))
def test_PrecisionAtRecall(self, distribution):
    # Runs under a tf.distribute strategy: the `distribution` argument is
    # supplied by the test framework, and metric state is aggregated across
    # replicas.
    label_prediction = ([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
    with distribution.scope():
        metric = metrics.PrecisionAtRecall(0.5)
        self.evaluate([v.initializer for v in metric.variables])
        updates = distribution.run(metric, args=label_prediction)
        self.evaluate(updates)
    self.assertAllClose(metric.result(), 0.5)
def test_unweighted_low_recall(self):
    s_obj = metrics.PrecisionAtRecall(0.4)
    pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.15, 0.25, 0.26, 0.26]
    label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    # For a score between 0.25 and 0.26, we expect 0.5 precision, 0.4 recall.
    y_pred = K.constant(pred_values, dtype='float32')
    y_true = K.constant(label_values)
    result = s_obj(y_true, y_pred)
    assert np.isclose(0.5, K.eval(result))
def test_unweighted_high_recall(self):
    s_obj = metrics.PrecisionAtRecall(0.8)
    pred_values = [0.0, 0.1, 0.2, 0.3, 0.5, 0.4, 0.5, 0.6, 0.8, 0.9]
    label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    # For a score between 0.4 and 0.5, we expect 0.8 precision, 0.8 recall.
    y_pred = K.constant(pred_values, dtype='float32')
    y_true = K.constant(label_values)
    result = s_obj(y_true, y_pred)
    assert np.isclose(0.8, K.eval(result))
def test_weighted(self):
    s_obj = metrics.PrecisionAtRecall(0.4)
    pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
    label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    weight_values = [2, 2, 1, 1, 1, 1, 1, 2, 2, 2]
    # For a score between 0.25 and 0.26: weighted tp = 4, fp = 2, fn = 4,
    # giving recall 0.5 (the threshold closest to the 0.4 target) and
    # precision 4 / 6 = 2 / 3.
    y_pred = K.constant(pred_values, dtype='float32')
    y_true = K.constant(label_values, dtype='float32')
    weights = K.constant(weight_values)
    result = s_obj(y_true, y_pred, sample_weight=weights)
    assert np.isclose(2. / 3., K.eval(result))
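# A plain-numpy cross-check sketch for the expected values above. It assumes,
# based on the asserted numbers, that PrecisionAtRecall reports precision at
# the threshold whose recall is *closest* to the requested value. The helper
# name and the simple 200-point threshold grid are illustrative assumptions,
# not the library's internals.
def _precision_at_recall_reference(label_values, pred_values, recall_target,
                                   weight_values=None):
    y_true = np.asarray(label_values, dtype=float)
    y_pred = np.asarray(pred_values, dtype=float)
    w = (np.ones_like(y_true) if weight_values is None
         else np.asarray(weight_values, dtype=float))
    best_gap, best_precision = None, 0.0
    for t in np.linspace(0.0, 1.0, 200):
        pred_pos = y_pred > t
        tp = np.sum(w * (pred_pos & (y_true == 1)))
        fp = np.sum(w * (pred_pos & (y_true == 0)))
        fn = np.sum(w * (~pred_pos & (y_true == 1)))
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        gap = abs(recall - recall_target)
        if best_gap is None or gap < best_gap:
            best_gap, best_precision = gap, precision
    return best_precision

# E.g. _precision_at_recall_reference(label_values, pred_values, 0.4,
# weight_values) with the data from test_weighted reproduces the expected
# 2 / 3.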
def test_config(self):
    s_obj = metrics.PrecisionAtRecall(0.4,
                                      num_thresholds=100,
                                      name='precision_at_recall_1')
    assert s_obj.name == 'precision_at_recall_1'
    assert s_obj.recall == 0.4
    assert s_obj.num_thresholds == 100

    # Check save and restore config.
    s_obj2 = metrics.PrecisionAtRecall.from_config(s_obj.get_config())
    assert s_obj2.name == 'precision_at_recall_1'
    assert s_obj2.recall == 0.4
    assert s_obj2.num_thresholds == 100
def test_invalid_num_thresholds(self):
    with pytest.raises(Exception):
        metrics.PrecisionAtRecall(0.4, num_thresholds=-1)
def test_invalid_recall(self):
    # The requested recall must lie in [0, 1].
    with pytest.raises(Exception):
        metrics.PrecisionAtRecall(-1)
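# Usage sketch, assuming the standard Keras API: the metric is stateful like
# any other and can be passed straight to `model.compile`. The tiny model
# below is illustrative only and is not part of the test suite.
from keras.models import Sequential
from keras.layers import Dense

model = Sequential([Dense(1, activation='sigmoid', input_shape=(4,))])
model.compile(optimizer='sgd',
              loss='binary_crossentropy',
              metrics=[metrics.PrecisionAtRecall(0.8)])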