def testSelectClassIDEmpty(self):
    """Selecting a class from empty labels/predictions yields empty arrays."""
    # A single np.array call suffices; wrapping an ndarray in np.array again
    # only makes a redundant copy.
    labels = np.array([])
    preds = np.array([])
    got_labels, got_preds = metric_util.select_class_id(1, labels, preds)
    self.assertAllClose(got_labels, np.array([]))
    self.assertAllClose(got_preds, np.array([]))
def testSelectClassIDMultiDim(self):
    """Selects class 1 from rank-3 one-hot labels and predictions.

    The expected outputs have shape (1, 1, 1): the selected column is kept
    along the last axis rather than squeezed away.
    """
    class_id = 1
    onehot_labels = np.array([[[0, 0, 1]]])
    scores = np.array([[[0.2, 0.7, 0.1]]])

    selected_labels, selected_preds = metric_util.select_class_id(
        class_id, onehot_labels, scores)

    expected_labels = np.array([[[0]]])
    expected_preds = np.array([[[0.7]]])
    self.assertAllClose(selected_labels, expected_labels)
    self.assertAllClose(selected_preds, expected_preds)
def testSelectClassIDWithMultipleValues(self):
    """Selects class 1 across a batch of three one-hot examples.

    Column 1 of each row is extracted and kept as a trailing axis of size 1.
    """
    class_id = 1
    onehot_labels = np.array([
        [0, 0, 1],
        [0, 0, 1],
        [0, 1, 0],
    ])
    scores = np.array([
        [0.2, 0.7, 0.1],
        [0.3, 0.6, 0.1],
        [0.1, 0.2, 0.7],
    ])

    selected_labels, selected_preds = metric_util.select_class_id(
        class_id, onehot_labels, scores)

    expected_labels = np.array([[0], [0], [1]])
    expected_preds = np.array([[0.7], [0.6], [0.2]])
    self.assertAllClose(selected_labels, expected_labels)
    self.assertAllClose(selected_preds, expected_preds)
def testSelectClassIDSparseBatched(self):
    """Selects class 1 when labels are given as per-example class indices.

    Each label row holds a class index; the selected label is 1 only where
    that index equals the requested class (here the third example, label 1).
    """
    class_id = 1
    index_labels = np.array([[0], [2], [1]])
    scores = np.array([
        [0.2, 0.7, 0.1],
        [0.3, 0.6, 0.1],
        [0.1, 0.2, 0.7],
    ])

    selected_labels, selected_preds = metric_util.select_class_id(
        class_id, index_labels, scores)

    expected_labels = np.array([[0], [0], [1]])
    expected_preds = np.array([[0.7], [0.6], [0.2]])
    self.assertAllClose(selected_labels, expected_labels)
    self.assertAllClose(selected_preds, expected_preds)
def testSelectClassIDSparseNoShape(self):
    """Selects class 1 from a scalar (0-d) sparse label and a 1-D prediction.

    The scalar label 2 does not match class 1, so the expected label is [0];
    the expected prediction is the class-1 score, [0.7].
    """
    class_id = 1
    scalar_label = np.array(2)
    scores = np.array([0.2, 0.7, 0.1])

    selected_labels, selected_preds = metric_util.select_class_id(
        class_id, scalar_label, scores)

    expected_labels = np.array([0])
    expected_preds = np.array([0.7])
    self.assertAllClose(selected_labels, expected_labels)
    self.assertAllClose(selected_preds, expected_preds)
def testRaisesErrorForInvalidNonSparseSettings(self):
    """select_class_id raises ValueError for sparse_labels=False with bad labels.

    The label array [5] has shape (1,) while preds has shape (3,).
    NOTE(review): presumably the error is because these labels are not a dense
    one-hot encoding matching preds; confirm against select_class_id.
    """
    # Build fixtures outside the assertRaises context so that only the call
    # under test can satisfy the expected ValueError.
    labels = np.array([5])
    preds = np.array([0.2, 0.7, 0.1])
    with self.assertRaises(ValueError):
        metric_util.select_class_id(1, labels, preds, sparse_labels=False)