示例#1
0
    def test_collate_batch_idxs(self) -> None:
        """
        Tests that (q|c)[len_sorted][orig_idxs] == q|c.

        ``*_len_idxs`` lists the sample indices ordered by decreasing
        length, and ``*_orig_idxs`` must undo that permutation, so indexing
        the words by ``len_idxs`` and then ``orig_idxs`` round-trips back to
        the original word tensor for both questions and contexts.

        Comparisons use ``np.testing.assert_array_equal``: the data are
        integers (no float tolerance wanted), the check must fail on a shape
        mismatch rather than broadcast, and a failure prints both arrays.
        """
        samples: List[EncodedSample] = [
            # Word counts (context, question): (3, 2), (1, 3), (2, 1).
            self.make_sample("c1 c2 c3", [], "q0", "c1 c2"),
            self.make_sample("c1", [], "q1", "c1 c2 c3"),
            self.make_sample("c1 c2", [], "q2", "c1"),
        ]

        batch: QABatch = collate_batch(samples)

        np.testing.assert_array_equal(batch.context_len_idxs, [0, 2, 1])
        np.testing.assert_array_equal(batch.question_len_idxs, [1, 0, 2])

        # Sorting by length and then applying orig_idxs must restore the
        # original (unsorted) word tensors.
        question_sorted = batch.question_words[batch.question_len_idxs]
        np.testing.assert_array_equal(
            question_sorted[batch.question_orig_idxs], batch.question_words)

        context_sorted = batch.context_words[batch.context_len_idxs]
        np.testing.assert_array_equal(
            context_sorted[batch.context_orig_idxs], batch.context_words)
示例#2
0
    def test_collate_batch_q_len_sorting(self) -> None:
        """
        Tests that question lengths in batch are sorted, and
        len_idxs indices map to correct indices.

        ``question_words`` stays in original sample order (zero-padded to
        the longest question), ``question_lens`` is in descending order,
        and ``question_len_idxs[i]`` is the original index of the i-th
        longest question.
        """
        samples: List[EncodedSample] = [
            self.make_sample("c1", [], "q0", "c1"),
            self.make_sample("c1", [], "q1", "c1 c2 c3"),
            self.make_sample("c1", [], "q2", "c1 c2"),
        ]

        batch: QABatch = collate_batch(samples)

        # assert_array_equal checks shape as well as values (np.all(a == b)
        # broadcasts mismatched shapes) and reports both arrays on failure.
        np.testing.assert_array_equal(
            batch.question_words.numpy(),
            [[1, 0, 0], [1, 2, 3], [1, 2, 0]],
            err_msg="Batch questions",
        )
        np.testing.assert_array_equal(
            batch.question_lens,
            [3, 2, 1],
            err_msg="Question lens",
        )
        np.testing.assert_array_equal(
            batch.question_len_idxs,
            [1, 2, 0],
            err_msg="Question len idxs",
        )
示例#3
0
 def test_collate_batch_different_context_word_lens(self) -> None:
     """
     Check that collate_batch handles samples whose context words have
     different character lengths.

     Each sample has one context word, "c00" (three characters) and "c1"
     (two characters), so the expected context_chars shape is (2, 1, 3).
     """
     pairs = (("c00", "q0"), ("c1", "q1"))
     mixed_lens: List[EncodedSample] = [
         self.make_sample(context_text, [], question_id, "c1")
         for context_text, question_id in pairs
     ]
     collated: QABatch = collate_batch(mixed_lens)
     expected_shape = t.Size([2, 1, 3])
     self.assertEqual(collated.context_chars.shape, expected_shape)
     self.check_collated_chars(collated.context_chars,
                               collated.context_words)
示例#4
0
 def test_collate_batch_context_chars(self) -> None:
     """
     Tests that collate batch includes every context word's characters and
     that they are parsed correctly.

     Each sample has a single two-character context word ("c1", "c2",
     "c3"), so the collated context_chars tensor must have shape (3, 1, 2).
     """
     samples: List[EncodedSample] = [
         self.make_sample("c1", [], "q0", "c1"),
         self.make_sample("c2", [], "q1", "c1"),
         self.make_sample("c3", [], "q2", "c1"),
     ]
     batch: QABatch = collate_batch(samples)
     # Check the context tensor this test is about; question_chars would
     # pass trivially here because every question is the same "c1".
     self.assertEqual(batch.context_chars.shape, t.Size([3, 1, 2]))
     self.check_collated_chars(batch.context_chars, batch.context_words)
示例#5
0
 def test_collate_batch_different_question_word_numbers_and_lens(
         self) -> None:
     """
     Check that collate_batch copes with questions that differ both in how
     many words they contain and in how long those words are.

     The longest question has two words and the longest word ("c00") has
     three characters, so question_chars is expected to be (3, 2, 3).
     """
     id_and_question = (("q0", "c1"), ("q1", "c00"), ("q2", "c1 c00"))
     qa_samples: List[EncodedSample] = [
         self.make_sample("c1", [], qid, question_text)
         for qid, question_text in id_and_question
     ]
     collated: QABatch = collate_batch(qa_samples)
     self.assertEqual(collated.question_chars.shape, t.Size([3, 2, 3]))
     self.check_collated_chars(collated.question_chars,
                               collated.question_words)
示例#6
0
    def test_collate_batch_ctx_len_sorting(self) -> None:
        """
        Tests that context lengths in batch are sorted, and
        context_len_idxs indices map to correct indices.

        Contexts have 3, 1 and 2 words, so the sorted lengths are
        [3, 2, 1] and the original indices in that order are [0, 2, 1].
        """
        samples: List[EncodedSample] = [
            self.make_sample("c1 c2 c3", [], "q0", "c1"),
            self.make_sample("c1", [], "q1", "c1"),
            self.make_sample("c1 c2", [], "q2", "c1"),
        ]

        batch: QABatch = collate_batch(samples)
        # Exact, shape-checked integer comparison with a readable failure
        # message; np.allclose applies float tolerances and broadcasts.
        np.testing.assert_array_equal(batch.context_lens, [3, 2, 1])
        np.testing.assert_array_equal(batch.context_len_idxs, [0, 2, 1])
示例#7
0
 def test_collate_batch_simple_words(self) -> None:
     """
     Tests that collate batch includes all question ids in original order,
     together with the matching single-word question word ids.
     """
     samples: List[EncodedSample] = [
         self.make_sample("c1", [], "q0", "c1"),
         self.make_sample("c1", [], "q1", "c2"),
         self.make_sample("c1", [], "q2", "c3"),
     ]
     batch: QABatch = collate_batch(samples)
     self.assertEqual(
         batch.question_ids,
         [QuestionId("q0"),
          QuestionId("q1"),
          QuestionId("q2")])
     # Shape-checked comparison: np.all(a == b) broadcasts mismatched
     # shapes instead of failing, so a wrongly-shaped result could pass.
     np.testing.assert_array_equal(
         batch.question_words.numpy(),
         [[1], [2], [3]],
         err_msg="Batch questions",
     )