def testSelectClassIDEmpty(self):
        """Expects empty labels/preds to produce empty outputs for class 1."""
        # A single np.array call suffices; re-wrapping an ndarray in
        # np.array only makes a redundant copy.
        labels = np.array([])
        preds = np.array([])
        got_labels, got_preds = metric_util.select_class_id(1, labels, preds)

        self.assertAllClose(got_labels, np.array([]))
        self.assertAllClose(got_preds, np.array([]))

    def testSelectClassIDMultiDim(self):
        """Expects a (1, 1, 3) input to yield (1, 1, 1) outputs for class 1."""
        one_hot = np.array([[[0, 0, 1]]])
        scores = np.array([[[0.2, 0.7, 0.1]]])

        label_col, score_col = metric_util.select_class_id(1, one_hot, scores)

        expected_label_col = np.array([[[0]]])
        expected_score_col = np.array([[[0.7]]])
        self.assertAllClose(label_col, expected_label_col)
        self.assertAllClose(score_col, expected_score_col)

    def testSelectClassIDWithMultipleValues(self):
        """Expects column 1 of every label row and prediction row."""
        class_id = 1
        labels = np.array([[0, 0, 1], [0, 0, 1], [0, 1, 0]])
        preds = np.array([[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7]])

        selected_labels, selected_preds = metric_util.select_class_id(
            class_id, labels, preds)

        # Checked in order: labels first, then predictions.
        checks = (
            (selected_labels, np.array([[0], [0], [1]])),
            (selected_preds, np.array([[0.7], [0.6], [0.2]])),
        )
        for actual, expected in checks:
            self.assertAllClose(actual, expected)

    def testSelectClassIDSparseBatched(self):
        """Expects index labels [[0], [2], [1]] to map to [[0], [0], [1]] for class 1."""
        sparse_labels = np.array([[0], [2], [1]])
        preds = np.array([[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7]])
        expected_labels = np.array([[0], [0], [1]])
        expected_preds = np.array([[0.7], [0.6], [0.2]])

        got_labels, got_preds = metric_util.select_class_id(
            1, sparse_labels, preds)

        self.assertAllClose(got_labels, expected_labels)
        self.assertAllClose(got_preds, expected_preds)

    def testSelectClassIDSparseNoShape(self):
        """Expects a 0-d label with 1-d preds to give length-1 outputs."""
        scalar_label = np.array(2)
        scores = np.array([0.2, 0.7, 0.1])

        result = metric_util.select_class_id(1, scalar_label, scores)
        got_label, got_score = result

        self.assertAllClose(got_label, np.array([0]))
        self.assertAllClose(got_score, np.array([0.7]))

 def testRaisesErrorForInvalidNonSparseSettings(self):
     """Expects ValueError for a one-entry label with three-class preds when sparse_labels=False.

     NOTE(review): the exact validation rule lives in select_class_id;
     presumably dense labels must match preds' class dimension — confirm.
     """
     # Build fixtures outside the assertRaises context so that only the
     # call under test can satisfy the expected ValueError.
     labels = np.array([5])
     preds = np.array([0.2, 0.7, 0.1])
     with self.assertRaises(ValueError):
         metric_util.select_class_id(1, labels, preds, sparse_labels=False)