def test_collate_batch_idxs(self) -> None:
    """
    Checks that the length-sort permutation and the original-order
    permutation undo each other: indexing the question/context words with
    ``*_len_idxs`` and then with ``*_orig_idxs`` must give back the words
    unchanged, i.e. (q|c)[len_sorted][orig_idxs] == q|c.
    """
    samples: List[EncodedSample] = [
        self.make_sample("c1 c2 c3", [], "q0", "c1 c2"),
        self.make_sample("c1", [], "q1", "c1 c2 c3"),
        self.make_sample("c1 c2", [], "q2", "c1"),
    ]
    batch: QABatch = collate_batch(samples)

    # Contexts have 3, 1, 2 words, so longest-first order is 0, 2, 1.
    self.assertTrue(np.allclose(batch.context_len_idxs, [0, 2, 1]))
    # Questions have 2, 3, 1 words, so longest-first order is 1, 0, 2.
    self.assertTrue(np.allclose(batch.question_len_idxs, [1, 0, 2]))

    sorted_questions = batch.question_words[batch.question_len_idxs]
    restored_questions = sorted_questions[batch.question_orig_idxs]
    self.assertTrue(np.allclose(batch.question_words, restored_questions))

    sorted_contexts = batch.context_words[batch.context_len_idxs]
    restored_contexts = sorted_contexts[batch.context_orig_idxs]
    self.assertTrue(np.allclose(batch.context_words, restored_contexts))
def test_collate_batch_q_len_sorting(self) -> None:
    """
    Verifies that question words keep the sample order while question
    lengths are reported longest-first, with question_len_idxs naming the
    sample each sorted length came from.
    """
    samples: List[EncodedSample] = [
        self.make_sample("c1", [], "q0", "c1"),
        self.make_sample("c1", [], "q1", "c1 c2 c3"),
        self.make_sample("c1", [], "q2", "c1 c2"),
    ]
    batch: QABatch = collate_batch(samples)

    # Word ids in sample order; shorter questions are filled with 0.
    expected_words = [[1, 0, 0], [1, 2, 3], [1, 2, 0]]
    words_match = np.all(
        batch.question_words.numpy() == np.stack(expected_words))
    self.assertTrue(
        words_match,
        "Batch questions: {0} Expected: {1}".format(
            batch.question_words, expected_words),
    )

    expected_lens = [3, 2, 1]
    self.assertTrue(
        np.allclose(batch.question_lens, expected_lens),
        "Question lens: {0} expected: {1}".format(
            batch.question_lens, expected_lens),
    )

    # Sample 1 is longest (3 words), then sample 2 (2), then sample 0 (1).
    expected_idxs = [1, 2, 0]
    self.assertTrue(
        np.allclose(batch.question_len_idxs, expected_idxs),
        "Question len idxs: {0} expected: {1}".format(
            batch.question_len_idxs, expected_idxs),
    )
def test_collate_batch_different_context_word_lens(self) -> None:
    """
    Tests that collate batch deals with contexts whose words have different
    character lengths
    """
    samples: List[EncodedSample] = [
        self.make_sample("c00", [], "q0", "c1"),
        self.make_sample("c1", [], "q1", "c1"),
    ]
    batch: QABatch = collate_batch(samples)
    # Two samples, one context word each; the longest context word ("c00")
    # has 3 characters, so the char dimension is 3.
    self.assertEqual(batch.context_chars.shape, t.Size([2, 1, 3]))
    self.check_collated_chars(batch.context_chars, batch.context_words)
def test_collate_batch_context_chars(self) -> None:
    """
    Tests that collate batch includes all context word characters that are
    parsed correctly
    """
    samples: List[EncodedSample] = [
        self.make_sample("c1", [], "q0", "c1"),
        self.make_sample("c2", [], "q1", "c1"),
        self.make_sample("c3", [], "q2", "c1"),
    ]
    batch: QABatch = collate_batch(samples)
    # Assert on context_chars: question_chars also has shape [3, 1, 2] for
    # these samples, so checking it would not exercise context collation.
    self.assertEqual(batch.context_chars.shape, t.Size([3, 1, 2]))
    self.check_collated_chars(batch.context_chars, batch.context_words)
def test_collate_batch_different_question_word_numbers_and_lens(self) -> None:
    """
    Covers questions that differ both in how many words they have and in
    how many characters those words have. The question char tensor must be
    sized by the batch-wide maximum word count (2, from "c1 c00") and
    maximum word length (3, from "c00").
    """
    question_specs = (("q0", "c1"), ("q1", "c00"), ("q2", "c1 c00"))
    samples: List[EncodedSample] = [
        self.make_sample("c1", [], qid, question_text)
        for qid, question_text in question_specs
    ]
    batch: QABatch = collate_batch(samples)

    expected_shape = t.Size([3, 2, 3])
    self.assertEqual(batch.question_chars.shape, expected_shape)
    self.check_collated_chars(batch.question_chars, batch.question_words)
def test_collate_batch_ctx_len_sorting(self) -> None:
    """
    Tests that context lengths in batch are sorted longest-first, and that
    context_len_idxs maps each sorted length back to its sample index
    """
    samples: List[EncodedSample] = [
        self.make_sample("c1 c2 c3", [], "q0", "c1"),
        self.make_sample("c1", [], "q1", "c1"),
        self.make_sample("c1 c2", [], "q2", "c1"),
    ]
    batch: QABatch = collate_batch(samples)
    # Contexts have 3, 1, 2 words; longest-first gives lengths 3, 2, 1
    # coming from samples 0, 2, 1.
    self.assertTrue(
        np.allclose(batch.context_lens, [3, 2, 1]),
        "Context lens: {0} expected: {1}".format(batch.context_lens,
                                                 [3, 2, 1]),
    )
    self.assertTrue(
        np.allclose(batch.context_len_idxs, [0, 2, 1]),
        "Context len idxs: {0} expected: {1}".format(
            batch.context_len_idxs, [0, 2, 1]),
    )
def test_collate_batch_simple_words(self) -> None:
    """
    Checks that question ids, and the word ids of single-word questions,
    come out of collation in the same order the samples went in.
    """
    specs = [("q0", "c1"), ("q1", "c2"), ("q2", "c3")]
    samples: List[EncodedSample] = [
        self.make_sample("c1", [], qid, question_text)
        for qid, question_text in specs
    ]
    batch: QABatch = collate_batch(samples)

    expected_ids = [QuestionId(qid) for qid, _ in specs]
    self.assertEqual(batch.question_ids, expected_ids)

    # "c1", "c2", "c3" are expected to encode to word ids 1, 2, 3.
    expected_words = [[1], [2], [3]]
    words_match = np.all(
        batch.question_words.numpy() == np.stack(expected_words))
    self.assertTrue(
        words_match,
        "Batch questions: {0} Expected: {1}".format(
            batch.question_words, expected_words),
    )