def testSelectTopKBatched(self):
        """Checks select_top_k with k=2 on a batch of two examples."""
        predictions = np.array([[0.4, 0.1, 0.2, 0.3],
                                [0.1, 0.2, 0.1, 0.6]])
        labels = np.array([[2], [3]])
        top_labels, top_preds = metric_util.select_top_k(
            2, labels, predictions)

        # Expected preds are each row's two largest scores in descending
        # order (row 0: indices 0, 3; row 1: indices 3, 1). Expected labels
        # mark with 1 the top-k slot whose class index equals the row's
        # label: only row 1 (label 3) hits, in its first slot.
        expected_labels = np.array([[0, 0], [1, 0]])
        expected_preds = np.array([[0.4, 0.3], [0.6, 0.2]])
        self.assertAllClose(top_labels, expected_labels)
        self.assertAllClose(top_preds, expected_preds)
  def testSelectTopKWithBinaryClassification(self):
    """Checks select_top_k with k=1 on a single two-class prediction."""
    labels = np.array([0])
    preds = np.array([0.2, 0.8])
    got_labels, got_preds = metric_util.select_top_k(1, labels, preds)

    # The single top score is 0.8 (class 1), which does not match label 0,
    # so the selected label slot is 0.
    self.assertAllClose(got_labels, np.array([0]))
    self.assertAllClose(got_preds, np.array([0.8]))

  def testSelectTopK(self):
    """Checks select_top_k with k=2 on a single unbatched example."""
    labels = np.array([3])
    preds = np.array([0.4, 0.1, 0.2, 0.3])
    got_labels, got_preds = metric_util.select_top_k(2, labels, preds)

    # Top-2 scores are 0.4 (class 0) then 0.3 (class 3); only the second
    # slot matches label 3.
    self.assertAllClose(got_labels, np.array([0, 1]))
    self.assertAllClose(got_preds, np.array([0.4, 0.3]))
  def testSelectTopKUsingSeparateScores(self):
    """Checks that select_top_k ranks by an explicit `scores` array.

    The predictions here are strings, so the ranking must come from the
    separate scores passed as the fourth argument.
    """
    scores = np.array([0.4, 0.1, 0.2, 0.3])
    preds = np.array(['b', 'c', 'a', 'd'])
    labels = np.array(['', '', '', 'c'])
    top_labels, top_preds = metric_util.select_top_k(
        2, labels, preds, scores)

    # The two highest scores are 0.4 (index 0) and 0.3 (index 3), so labels
    # and preds at those positions are selected, in that order.
    expected_labels = ['', 'c']
    expected_preds = ['b', 'd']
    self.assertSequenceEqual(list(top_labels), expected_labels)
    self.assertSequenceEqual(list(top_preds), expected_preds)