예제 #1
0
 def testVars(self):
     """Checks the names of the variables created by `f1_score`.

     The expected names are compared against both the local variables and
     the `METRIC_VARIABLES` graph collection.
     """
     classification.f1_score(predictions=array_ops.ones((10, 1)),
                             labels=array_ops.ones((10, 1)),
                             num_thresholds=3)
     expected = {
         'f1/true_positives:0', 'f1/false_positives:0',
         'f1/false_negatives:0'
     }
     # `expected` is already a set, so a single comparison against the local
     # variables is enough.
     self.assertEqual(expected,
                      {v.name for v in variables.local_variables()})
     self.assertEqual(
         expected,
         {v.name
          for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)})
예제 #2
0
 def testUpdatesCollection(self):
     """Checks the update op is the only entry of the requested collection."""
     collection = '__updates__'
     ones_predictions = array_ops.ones((10, 1))
     ones_labels = array_ops.ones((10, 1))
     metric_ops = classification.f1_score(predictions=ones_predictions,
                                          labels=ones_labels,
                                          num_thresholds=3,
                                          updates_collections=[collection])
     # The second element of the returned pair is the update op.
     update_op = metric_ops[1]
     registered = ops.get_collection(collection)
     self.assertListEqual(registered, [update_op])
예제 #3
0
 def testMetricsCollection(self):
     """Checks the value tensor is the only entry of the requested collection."""
     collection = '__metrics__'
     ones_predictions = array_ops.ones((10, 1))
     ones_labels = array_ops.ones((10, 1))
     metric_ops = classification.f1_score(predictions=ones_predictions,
                                          labels=ones_labels,
                                          num_thresholds=3,
                                          metrics_collections=[collection])
     # The first element of the returned pair is the metric value tensor.
     value_tensor = metric_ops[0]
     registered = ops.get_collection(collection)
     self.assertListEqual(registered, [value_tensor])
예제 #4
0
    def testWithMultipleUpdates(self):
        """Compares the streamed F1 score with a NumPy reference value.

        Random labels and noisy predictions are generated, the best F1 over
        three thresholds is computed with NumPy, and the metric is then
        updated batch by batch until every sample has been consumed.
        """
        num_samples = 1000
        batch_size = 10
        num_batches = num_samples // batch_size

        # Create the labels and data.
        labels = np.random.randint(0, 2, size=(num_samples, 1))
        noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
        predictions = 0.4 + 0.2 * labels + noise
        predictions[predictions > 1] = 1
        predictions[predictions < 0] = 0
        thresholds = [-0.01, 0.5, 1.01]

        # Reference computation: best F1 over all thresholds, using
        # vectorized confusion counts instead of a per-sample loop.
        epsilon = 1e-7
        is_positive = labels == 1
        expected_max_f1 = -1.0
        for threshold in thresholds:
            predicted_positive = predictions >= threshold
            tp = np.count_nonzero(predicted_positive & is_positive)
            fp = np.count_nonzero(predicted_positive & ~is_positive)
            fn = np.count_nonzero(~predicted_positive & is_positive)
            expected_prec = tp / (epsilon + tp + fp)
            expected_rec = tp / (epsilon + tp + fn)
            expected_f1 = (2 * expected_prec * expected_rec /
                           (epsilon + expected_prec + expected_rec))
            expected_max_f1 = max(expected_max_f1, expected_f1)

        labels = labels.astype(np.float32)
        predictions = predictions.astype(np.float32)
        tf_predictions, tf_labels = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensor_slices(
                (predictions, labels)).repeat().batch(batch_size)).get_next()
        # Keyword arguments make the labels/predictions binding explicit.
        f1, f1_op = classification.f1_score(labels=tf_labels,
                                            predictions=tf_predictions,
                                            num_thresholds=3)

        with self.cached_session() as sess:
            sess.run(variables.local_variables_initializer())
            for _ in range(num_batches):
                sess.run([f1_op])
            # Since this is only approximate, we can't expect a 6 digits match.
            # Although with higher number of samples/thresholds we should see the
            # accuracy improving
            self.assertAlmostEqual(expected_max_f1, f1.eval(), 2)
예제 #5
0
    def testZeroLabelsPredictions(self):
        """Checks F1 is 0 when both labels and predictions are all zeros."""
        with self.cached_session() as sess:
            predictions = array_ops.zeros([4], dtype=dtypes.float32)
            labels = array_ops.zeros([4])
            # Pass by keyword so labels and predictions cannot be bound to
            # the wrong parameter.
            f1, f1_op = classification.f1_score(labels=labels,
                                                predictions=predictions,
                                                num_thresholds=3)
            sess.run(variables.local_variables_initializer())
            sess.run([f1_op])

            self.assertAlmostEqual(0.0, f1.eval(), places=5)
예제 #6
0
 def testSomeCorrect(self):
     """Checks F1 on a small fixed input where half the predictions match."""
     predictions = constant_op.constant([1, 0, 1, 0],
                                        shape=(1, 4),
                                        dtype=dtypes.float32)
     labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
     # Pass by keyword so labels and predictions cannot be bound to the
     # wrong parameter.
     f1, f1_op = classification.f1_score(labels=labels,
                                         predictions=predictions,
                                         num_thresholds=1)
     with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         sess.run([f1_op])
         # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
         # score of 2 * 0.5 * 1 / (1 + 0.5).
         self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval())
예제 #7
0
    def testAllCorrect(self):
        """Checks F1 is 1 when predictions are identical to the labels."""
        inputs = np.random.randint(0, 2, size=(100, 1))

        with self.cached_session() as sess:
            predictions = constant_op.constant(inputs, dtype=dtypes.float32)
            labels = constant_op.constant(inputs)
            # Pass by keyword so labels and predictions cannot be bound to
            # the wrong parameter.
            f1, f1_op = classification.f1_score(labels=labels,
                                                predictions=predictions,
                                                num_thresholds=3)

            sess.run(variables.local_variables_initializer())
            sess.run([f1_op])

            self.assertEqual(1, f1.eval())
예제 #8
0
    def testWeights2d(self):
        """Checks F1 with 2-D weights that zero out the first row.

        Only the second row carries weight, and there the predictions equal
        the labels, so the expected score is 1.
        """
        with self.cached_session() as sess:
            predictions = constant_op.constant([[1, 0], [1, 0]],
                                               shape=(2, 2),
                                               dtype=dtypes.float32)
            labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
            weights = constant_op.constant([[0, 0], [1, 1]],
                                           shape=(2, 2),
                                           dtype=dtypes.float32)
            # Pass by keyword so labels, predictions and weights cannot be
            # bound to the wrong parameter.
            f1, f1_op = classification.f1_score(labels=labels,
                                                predictions=predictions,
                                                weights=weights,
                                                num_thresholds=3)
            sess.run(variables.local_variables_initializer())
            sess.run([f1_op])

            self.assertAlmostEqual(1.0, f1.eval(), places=5)
예제 #9
0
    def testAllIncorrect(self):
        """Checks F1 when every prediction is the opposite of its label."""
        inputs = np.random.randint(0, 2, size=(10000, 1))

        with self.cached_session() as sess:
            predictions = constant_op.constant(inputs, dtype=dtypes.float32)
            labels = constant_op.constant(1 - inputs, dtype=dtypes.float32)
            # Pass by keyword so labels and predictions cannot be bound to
            # the wrong parameter.
            f1, f1_op = classification.f1_score(labels=labels,
                                                predictions=predictions,
                                                num_thresholds=3)

            sess.run(variables.local_variables_initializer())
            sess.run([f1_op])

            # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
            # score of 2 * 0.5 * 1 / (1 + 0.5).
            self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5),
                                   f1.eval(),
                                   places=2)
예제 #10
0
    def testValueTensorIsIdempotent(self):
        """Checks that evaluating the value tensor repeatedly gives one result.

        After several update steps, the value tensor is evaluated repeatedly
        without running the update op and must keep returning the same score.
        """
        predictions = random_ops.random_uniform((10, 3),
                                                maxval=1,
                                                dtype=dtypes.float32,
                                                seed=1)
        labels = random_ops.random_uniform((10, 3),
                                           maxval=2,
                                           dtype=dtypes.int64,
                                           seed=2)
        # Pass by keyword so labels and predictions cannot be bound to the
        # wrong parameter.
        f1, f1_op = classification.f1_score(labels=labels,
                                            predictions=predictions,
                                            num_thresholds=3)

        with self.cached_session() as sess:
            sess.run(variables.local_variables_initializer())

            # Run several updates.
            for _ in range(10):
                sess.run([f1_op])

            # Then verify idempotency.
            initial_f1 = f1.eval()
            for _ in range(10):
                self.assertAllClose(initial_f1, f1.eval())