def test_can_modify(self):
        pretok = CharDelimiterSplit("@")
        assert pretok.delimiter == "@"

        # Modify these
        pretok.delimiter = "!"
        assert pretok.delimiter == "!"
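For context, a minimal usage sketch (assuming a recent tokenizers release that exposes pre_tokenize_str) of what this pre-tokenizer produces:

from tokenizers.pre_tokenizers import CharDelimiterSplit

pretok = CharDelimiterSplit("@")
# Splits on the delimiter only and drops it, returning (piece, (start, end)) offsets
print(pretok.pre_tokenize_str("user@example@com"))
# [('user', (0, 4)), ('example', (5, 12)), ('com', (13, 16))]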
Example #2
def test_instantiate(self):
    assert CharDelimiterSplit("-") is not None
    with pytest.raises(Exception, match="delimiter must be a single character"):
        CharDelimiterSplit("")
    assert isinstance(CharDelimiterSplit(" "), PreTokenizer)
    assert isinstance(CharDelimiterSplit(" "), CharDelimiterSplit)
    assert isinstance(pickle.loads(pickle.dumps(CharDelimiterSplit("-"))), CharDelimiterSplit)
Example #3
def test_instantiate(self):
    assert CharDelimiterSplit("-") is not None
    with pytest.raises(ValueError, match="expected a string of length 1"):
        CharDelimiterSplit("")
    assert isinstance(CharDelimiterSplit(" "), PreTokenizer)
    assert isinstance(CharDelimiterSplit(" "), CharDelimiterSplit)
    assert isinstance(pickle.loads(pickle.dumps(CharDelimiterSplit("-"))),
                      CharDelimiterSplit)
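A short sketch of the pickle round-trip these tests rely on; the configured delimiter survives serialization:

import pickle

from tokenizers.pre_tokenizers import CharDelimiterSplit

restored = pickle.loads(pickle.dumps(CharDelimiterSplit("-")))
assert restored.delimiter == "-"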
Example #4
    def __init__(
        self,
        vocab_file,
        delimiter,
        lowercase,
        unk_token,
        eos_token,
        add_eos=False,
        add_double_eos=False,
        normalization: Optional[str] = None,
    ):

        try:
            tokenizer = WordLevel(vocab_file, unk_token=unk_token)
            tokenizer = Tokenizer(tokenizer)
        except Exception:
            raise ValueError(
                "Unable to parse file {}. Unknown format. "
                "If you tried to load a model saved through TransfoXLTokenizer,"
                "please note they are not compatible.".format(vocab_file))

        # Create the correct normalization path
        normalizer = []

        # Include unicode normalization
        if normalization:
            normalizer += [unicode_normalizer_from_str(normalization)]

        # Include case normalization
        if lowercase:
            normalizer += [Lowercase()]

        # Strip normalizer at the end
        normalizer += [Strip(left=True, right=True)]

        if len(normalizer) > 0:
            tokenizer.normalizer = Sequence(
                normalizer) if len(normalizer) > 1 else normalizer[0]

        # Setup the splitter
        tokenizer.pre_tokenizer = CharDelimiterSplit(
            delimiter) if delimiter else WhitespaceSplit()

        if add_double_eos:
            tokenizer.post_processor = BertProcessing(
                (eos_token, tokenizer.token_to_id(eos_token)),
                (eos_token, tokenizer.token_to_id(eos_token)))

        parameters = {
            "model": "TransfoXLModel",
            "add_eos": add_eos,
            "add_double_eos": add_double_eos,
            "unk_token": unk_token,
            "eos_token": eos_token,
            "delimiter": delimiter,
            "lowercase": lowercase,
        }

        super().__init__(tokenizer, parameters)
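As a rough sketch of the pipeline this constructor assembles (not the class itself; the tiny in-memory vocabulary and the <unk>/<eos> tokens below are made up for illustration):

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Lowercase, Sequence, Strip
from tokenizers.pre_tokenizers import CharDelimiterSplit
from tokenizers.processors import BertProcessing

vocab = {"<unk>": 0, "<eos>": 1, "hello": 2, "world": 3}
tok = Tokenizer(WordLevel(vocab, unk_token="<unk>"))
tok.normalizer = Sequence([Lowercase(), Strip()])
tok.pre_tokenizer = CharDelimiterSplit(" ")
# add_double_eos-style post-processing: wrap the sequence in <eos> on both sides
tok.post_processor = BertProcessing(
    ("<eos>", tok.token_to_id("<eos>")),
    ("<eos>", tok.token_to_id("<eos>")),
)
print(tok.encode("Hello world").tokens)
# ['<eos>', 'hello', 'world', '<eos>']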
Example #5
    def __init__(
        self,
        vocab_file,
        delimiter,
        lowercase,
        unk_token,
        eos_token,
        add_eos=False,
        add_double_eos=False,
        normalization: Optional[str] = None,
    ):

        tokenizer = WordLevel.from_files(vocab_file, unk_token=unk_token)
        tokenizer = Tokenizer(tokenizer)

        # Create the correct normalization path
        normalizer = []

        # Include unicode normalization
        if normalization:
            normalizer += [unicode_normalizer_from_str(normalization)]

        # Include case normalization
        if lowercase:
            normalizer += [Lowercase()]

        if len(normalizer) > 0:
            tokenizer.normalizer = Sequence(
                normalizer) if len(normalizer) > 1 else normalizer[0]

        # Setup the splitter
        tokenizer.pre_tokenizer = CharDelimiterSplit(
            delimiter) if delimiter else WhitespaceSplit()

        if add_double_eos:
            tokenizer.post_processor = BertProcessing(
                (eos_token, tokenizer.token_to_id(eos_token)),
                (eos_token, tokenizer.token_to_id(eos_token)))

        parameters = {
            "model": "TransfoXLModel",
            "add_eos": add_eos,
            "add_double_eos": add_double_eos,
            "unk_token": unk_token,
            "eos_token": eos_token,
            "delimiter": delimiter,
            "lowercase": lowercase,
        }

        super().__init__(tokenizer, parameters)
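The CharDelimiterSplit(delimiter) if delimiter else WhitespaceSplit() branch changes behavior noticeably; a quick sketch on a made-up input:

from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit

text = "one two\tthree"
print(CharDelimiterSplit(" ").pre_tokenize_str(text))  # only the space splits; the tab does not
print(WhitespaceSplit().pre_tokenize_str(text))        # any whitespace splits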
Example #6
    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",  # referenced below; the "<unk>" default is an assumption
        sep_token="<sep>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        lowercase: bool = True,
    ):

        tokenizer = Tokenizer(WordLevel(vocab_file, unk_token=unk_token))
        tokenizer.normalizer = Strip()
        tokenizer.pre_tokenizer = CharDelimiterSplit(" ")

        tokenizer.post_processor = BertProcessing(
            ("</s>", tokenizer.token_to_id("</s>")),
            ("<s>", tokenizer.token_to_id("<s>")),
        )
        tokenizer.enable_truncation(max_length=512)

        # Let the tokenizer know about special tokens if they are part of the vocab
        if tokenizer.token_to_id(str(unk_token)) is not None:
            tokenizer.add_special_tokens([str(unk_token)])
        if tokenizer.token_to_id(str(sep_token)) is not None:
            tokenizer.add_special_tokens([str(sep_token)])
        if tokenizer.token_to_id(str(cls_token)) is not None:
            tokenizer.add_special_tokens([str(cls_token)])
        if tokenizer.token_to_id(str(pad_token)) is not None:
            tokenizer.add_special_tokens([str(pad_token)])
        if tokenizer.token_to_id(str(mask_token)) is not None:
            tokenizer.add_special_tokens([str(mask_token)])

        parameters = {
            "model": "WordLevel",
            "unk_token": unk_token,
            "sep_token": sep_token,
            "cls_token": cls_token,
            "pad_token": pad_token,
            "mask_token": mask_token,
            "lowercase": lowercase,
        }

        super().__init__(tokenizer, parameters)
Example #7
data_path = Path('/workspace/poetry2021.gt/data/pan_tadeusz5')
dataset_path = data_path / 'dataset'
vocab_path = data_path / 'vocab.json'
tokenizer_tmp_path = data_path / 'tokenizer_tmp'
tokenizer_path = data_path / 'tokenizer'

text_tokenizer = TextTokenizer(dataset_path)
text_tokenizer.load_vocab(vocab_path)

vocab = text_tokenizer.vocab
vocab_count = len(vocab.keys())
vocab.update({'<|endoftext|>': vocab_count})

tokenizer_tmp = Tokenizer(WordLevel(text_tokenizer.vocab))
tokenizer_tmp.pre_tokenizer = CharDelimiterSplit(' ')

tokenizer_tmp.post_processor = BertProcessing(
    ("<|endoftext|>", tokenizer_tmp.token_to_id("<|endoftext|>")),
    ("<|endoftext|>", tokenizer_tmp.token_to_id("<|endoftext|>")),
)

tokenizer_tmp_path.mkdir(parents=True, exist_ok=True)
tokenizer_tmp.save(str(tokenizer_tmp_path / "tokenizer.json"))

# Re-create as GPT2 compatible tokenizer


class GPT2CompatibleTokenizer(PreTrainedTokenizerFast):
    def save_vocabulary(self,
                        save_directory: str,
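The class above subclasses PreTrainedTokenizerFast; independently of that, a minimal sketch of handing the saved tokenizer.json to PreTrainedTokenizerFast directly (paths and the <|endoftext|> token reused from above):

from transformers import PreTrainedTokenizerFast

hf_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=str(tokenizer_tmp_path / "tokenizer.json"),
    eos_token="<|endoftext|>",
)
hf_tokenizer.save_pretrained(str(tokenizer_path))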
Example #8
import json

from tqdm import tqdm

from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import CharDelimiterSplit
from tokenizers.trainers import WordPieceTrainer

# Check that no value in the trees still contains a comma, since "," is the delimiter below
with open("output/all_new_trees.json") as fin:
    for line in tqdm(fin):
        dp = json.loads(line.strip())
        for d in dp:
            if "value" in d:
                if "," in d["value"]:
                    print('Not cleaned up')

# Extract value/types from trees and store in comma separated raw file (all_raw.json)

with open("output/all_new_trees.json") as fin, open("output/all_raw.json",
                                                    "w") as fout:
    for i, line in enumerate(tqdm(fin)):
        dp = json.loads(line)
        token_list = []
        for d in dp:
            if "value" in d:
                token_list.append(d["value"])
            elif "type" in d:
                token_list.append(d["type"])
        raw = ",".join(token_list)
        print(json.dumps(raw), file=fout)

# Train tokenizer on raw file

tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.pre_tokenizer = CharDelimiterSplit(delimiter=",")
trainer = WordPieceTrainer(special_tokens=["[UNK]", "[PAD]"])

tokenizer.train(["output/all_raw.json"], trainer)

tokenizer.save("output/tokenizer.json")
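To sanity-check the result, the saved tokenizer can be reloaded and run on a comma-separated token string (the input below is made up):

from tokenizers import Tokenizer

tok = Tokenizer.from_file("output/tokenizer.json")
encoding = tok.encode("ExpressionStatement,CallExpression,Identifier")
print(encoding.tokens)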