コード例 #1
0
    def test_squad_tensorizer(self):
        """Check the answer positions produced by SquadForBERTTensorizer.numberize.

        The first eval row of ``squad_tiny.json`` is numberized twice:

        1. With ``max_seq_len=250``: the start/end positions must match the
           manually verified values and the token/segment lists must have
           ``seq_len`` entries.
        2. With ``max_seq_len=50``: the answer no longer fits, so every
           start/end position is expected to be ``-100``; the length
           invariants are re-checked on the truncated output.
        """
        source = SquadDataSource.from_config(
            SquadDataSource.Config(
                eval_filename=tests_module.test_file("squad_tiny.json")
            )
        )
        # Only the first eval row is needed for this test.
        row = next(iter(source.eval))
        tensorizer = SquadForBERTTensorizer.from_config(
            SquadForBERTTensorizer.Config(
                tokenizer=WordPieceTokenizer.Config(
                    wordpiece_vocab_path="pytext/data/test/data/wordpiece_1k.txt"
                ),
                max_seq_len=250,
            )
        )
        tokens, segments, seq_len, start, end = tensorizer.numberize(row)
        # check against manually verified answer positions in tokenized output
        # there are 4 identical answers
        self.assertEqual(start, [83, 83, 83, 83])
        self.assertEqual(end, [87, 87, 87, 87])
        self.assertEqual(len(tokens), seq_len)
        self.assertEqual(len(segments), seq_len)

        tensorizer.max_seq_len = 50
        # answer should be truncated out
        # Rebind every output: the length checks below must look at the
        # truncated result, not at the values left over from the first call.
        tokens, segments, seq_len, start, end = tensorizer.numberize(row)
        self.assertEqual(start, [-100, -100, -100, -100])
        self.assertEqual(end, [-100, -100, -100, -100])
        self.assertEqual(len(tokens), seq_len)
        self.assertEqual(len(segments), seq_len)
コード例 #2
0
 # Input configuration: declares two tensorizer configs, `squad_input` and
 # `has_answer`, extending BaseModel.Config.ModelInput.
 class ModelInput(BaseModel.Config.ModelInput):
     # Accepts either the BERT or the RoBERTa SQuAD tensorizer config;
     # defaults to the BERT variant with max_seq_len=256.
     squad_input: Union[
         SquadForBERTTensorizer.Config, SquadForRoBERTaTensorizer.Config
     ] = SquadForBERTTensorizer.Config(max_seq_len=256)
     # is_impossible label
     # Label tensorizer config bound to the "has_answer" column.
     has_answer: LabelTensorizer.Config = LabelTensorizer.Config(
         column="has_answer"
     )