Ejemplo n.º 1
0
    def setUp(self):
        """Create the data sources and tensorizers shared by the tests.

        Two ``SquadDataSource`` instances are built, one with the JSON
        fixture and one with the TSV fixture as the train file (eval and
        test files are left unset). Two ``SquadTensorizer`` instances are
        built with ``max_seq_len=250``: one using a WordPiece tokenizer and
        one using a regex-splitting tokenizer.
        """
        # Data sources: train file only; eval/test explicitly disabled.
        json_source_config = SquadDataSource.Config(
            train_filename=tests_module.test_file("squad_tiny.json"),
            eval_filename=None,
            test_filename=None,
        )
        self.json_data_source = SquadDataSource.from_config(json_source_config)

        tsv_source_config = SquadDataSource.Config(
            train_filename=tests_module.test_file("squad_tiny.tsv"),
            eval_filename=None,
            test_filename=None,
        )
        self.tsv_data_source = SquadDataSource.from_config(tsv_source_config)

        # Tensorizer backed by the 1k-entry WordPiece vocabulary fixture.
        wordpiece_tokenizer_config = WordPieceTokenizer.Config(
            wordpiece_vocab_path="pytext/data/test/data/wordpiece_1k.txt"
        )
        wordpiece_config = SquadTensorizer.Config(
            tokenizer=wordpiece_tokenizer_config, max_seq_len=250
        )
        self.tensorizer_with_wordpiece = SquadTensorizer.from_config(
            wordpiece_config
        )

        # Tensorizer whose tokenizer splits on runs of non-word characters.
        regex_tokenizer_config = Tokenizer.Config(split_regex=r"\W+")
        alphanumeric_config = SquadTensorizer.Config(
            tokenizer=regex_tokenizer_config, max_seq_len=250
        )
        self.tensorizer_with_alphanumeric = SquadTensorizer.from_config(
            alphanumeric_config
        )
Ejemplo n.º 2
0
    def test_squad_tensorizer(self):
        """Check answer spans from ``SquadForBERTTensorizer.numberize``.

        The first eval row of the tiny SQuAD fixture is numberized twice:
        once with ``max_seq_len=250``, where the answer positions must match
        the manually verified indices, and once with ``max_seq_len=50``,
        where the expected start/end values are all ``-100``. In both cases
        the token and segment lists must have length ``seq_len``.
        """
        source = SquadDataSource.from_config(
            SquadDataSource.Config(
                eval_filename=tests_module.test_file("squad_tiny.json")
            )
        )
        row = next(iter(source.eval))
        tensorizer = SquadForBERTTensorizer.from_config(
            SquadForBERTTensorizer.Config(
                tokenizer=WordPieceTokenizer.Config(
                    wordpiece_vocab_path="pytext/data/test/data/wordpiece_1k.txt"
                ),
                max_seq_len=250,
            )
        )
        tokens, segments, seq_len, start, end = tensorizer.numberize(row)
        # check against manually verified answer positions in tokenized output
        # there are 4 identical answers
        self.assertEqual(start, [83, 83, 83, 83])
        self.assertEqual(end, [87, 87, 87, 87])
        self.assertEqual(len(tokens), seq_len)
        self.assertEqual(len(segments), seq_len)

        tensorizer.max_seq_len = 50
        # answer should be truncated out
        # Rebind tokens/segments/seq_len here so the length assertions below
        # examine this truncated result rather than the first call's output.
        tokens, segments, seq_len, start, end = tensorizer.numberize(row)
        self.assertEqual(start, [-100, -100, -100, -100])
        self.assertEqual(end, [-100, -100, -100, -100])
        self.assertEqual(len(tokens), seq_len)
        self.assertEqual(len(segments), seq_len)