Example #1
    def get_prediction(self, voc_src, voc_trg, inputs, gold, output):
        tokenizer = get_tokenizer(self._tokenizer_name)

        input_string = tokenizer.detokenize(
            [voc_src[token.item()] for token in inputs]).split("<EOS>")[0]
        gold_string = tokenizer.detokenize(
            [voc_trg[token.item()] for token in gold]).split("<EOS>")[0]
        output_string = tokenizer.detokenize(
            [voc_trg[token.item()] for token in output]).split("<EOS>")[0]

        return input_string, gold_string, output_string
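For context, the method above maps token ids back to strings and truncates each sequence at the first "<EOS>". A minimal sketch of that pattern, using a plain whitespace join in place of the Moses-style detokenize() and a hypothetical vocabulary and id tensor:

import torch

voc_trg = {0: "hallo", 1: "welt", 2: "<EOS>", 3: "<PAD>"}  # hypothetical target vocabulary
output = torch.tensor([0, 1, 2, 3])                        # hypothetical model output ids

detok = " ".join(voc_trg[token.item()] for token in output)
output_string = detok.split("<EOS>")[0].strip()
print(output_string)  # -> "hallo welt"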
Example #2
def align_bytebpe(text: Text,
                  tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
    """Alignment fn for Byte-level BPE tokenizer, used in GPT-2 and RoBERTa
    """
    bow_tokens = space_tokenize_with_bow(text)
    bytebpe_tokenizer = get_tokenizer(tokenizer_name)
    bytebpe_tokens = bytebpe_tokenizer.tokenize(text)

    modified_bytebpe_tokens = list(
        map(process_bytebpe_for_alignment, bytebpe_tokens))
    ta = TokenAligner(bow_tokens, modified_bytebpe_tokens)
    return ta, bytebpe_tokens
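A hypothetical usage sketch for the aligner above (assumes an environment where get_tokenizer can load "roberta-base"; the text and span indices are illustrative):

text = "Members of the House clapped their hands"
ta, bytebpe_tokens = align_bytebpe(text, "roberta-base")
# Project a span over the space-separated source tokens (words 4..5, "clapped their")
# onto the byte-level BPE tokenization.
start, end = ta.project_span(4, 6)
print(bytebpe_tokens[start:end])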
Example #3
def align_sentencepiece(
        text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
    """Alignment fn for SentencePiece Tokenizer, used in XLNET
    """
    bow_tokens = space_tokenize_with_bow(text)
    sentencepiece_tokenizer = get_tokenizer(tokenizer_name)
    sentencepiece_tokens = sentencepiece_tokenizer.tokenize(text)

    modified_sentencepiece_tokens = list(
        map(process_sentencepiece_for_alignment, sentencepiece_tokens))
    ta = TokenAligner(bow_tokens, modified_sentencepiece_tokens)
    return ta, sentencepiece_tokens
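Usage mirrors the byte-level BPE case; a hypothetical sketch (assumes get_tokenizer can load "xlnet-base-cased"):

ta, sentencepiece_tokens = align_sentencepiece("What does anbiya mean?", "xlnet-base-cased")
# Map the second space-separated word ("does") onto the SentencePiece tokenization.
start, end = ta.project_span(1, 2)
print(sentencepiece_tokens[start:end])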
Example #4
def get_aligner_fn(tokenizer_name: Text):
    """Given the tokenzier_name, return the corresponding alignment function.
    An alignment function modified the tokenized input to make it close to source token,
    and choose a space tokenizer with its word-boundary at the same side as tokenizer_name,
    hence the source (from space tokenizer) and target (from tokenizer_name) is sufficiently close.
    Use TokenAligner to project token index from source to target.
    """
    if tokenizer_name == "MosesTokenizer" or tokenizer_name.startswith(
            "transfo-xl-"):
        return align_moses
    elif tokenizer_name.startswith("albert"):
        wpm_tokenizer = AlbertTokenizer(
            vocab_file='/work/dcml0714/albert/albert_base/30k-clean.model')
        return functools.partial(align_wpm,
                                 wpm_tokenizer=wpm_tokenizer,
                                 do_lower_case=True)
    elif tokenizer_name.startswith("bert-"):
        do_lower_case = tokenizer_name.endswith("uncased")
        wpm_tokenizer = get_tokenizer(tokenizer_name)
        return functools.partial(align_wpm,
                                 wpm_tokenizer=wpm_tokenizer,
                                 do_lower_case=do_lower_case)
    elif tokenizer_name.startswith("openai-gpt") or tokenizer_name.startswith(
            "xlm-mlm-en-"):
        bpe_tokenizer = get_tokenizer(tokenizer_name)
        return functools.partial(align_bpe, bpe_tokenizer=bpe_tokenizer)
    elif tokenizer_name.startswith("xlnet-") or tokenizer_name.startswith(
            "albert-"):
        sentencepiece_tokenizer = get_tokenizer(tokenizer_name)
        return functools.partial(
            align_sentencepiece,
            sentencepiece_tokenizer=sentencepiece_tokenizer)
    elif tokenizer_name.startswith("roberta-") or tokenizer_name.startswith(
            "gpt2"):
        bytebpe_tokenizer = get_tokenizer(tokenizer_name)
        return functools.partial(align_bytebpe,
                                 bytebpe_tokenizer=bytebpe_tokenizer)
    else:
        raise ValueError(f"Unsupported tokenizer '{tokenizer_name}'")
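The returned function is called with raw text and yields a TokenAligner plus the target tokenization, as in retokenize_record further below. A hypothetical dispatch sketch ("bert-base-uncased" is an illustrative tokenizer name):

moses_fn = get_aligner_fn("MosesTokenizer")    # -> align_moses
wpm_fn = get_aligner_fn("bert-base-uncased")   # -> align_wpm bound to a BERT wordpiece tokenizer

ta, wpm_tokens = wpm_fn("Members of the House clapped their hands")
start, end = ta.project_span(4, 6)             # source words "clapped their"
print(wpm_tokens[start:end])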
Example #5
def process_sentence(tokenizer_name, sent, max_seq_len):
    """process a sentence """
    max_seq_len -= 2
    assert max_seq_len > 0, "Max sequence length should be at least 2!"
    tokenizer = get_tokenizer(tokenizer_name)
    if tokenizer_name.startswith("bert-"):
        sos_tok, eos_tok = BERT_CLS_TOK, BERT_SEP_TOK
    else:
        sos_tok, eos_tok = SOS_TOK, EOS_TOK
    if isinstance(sent, str):
        return [sos_tok] + tokenizer.tokenize(sent)[:max_seq_len] + [eos_tok]
    elif isinstance(sent, list):
        assert isinstance(sent[0], str), "Invalid sentence found!"
        return [sos_tok] + sent[:max_seq_len] + [eos_tok]
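A hypothetical usage sketch (the tokenizer names are illustrative; SOS_TOK/EOS_TOK and the BERT CLS/SEP constants come from the surrounding module):

# Raw string input: tokenize, truncate to max_seq_len - 2, and wrap in special tokens.
process_sentence("bert-base-uncased", "The quick brown fox jumps over the lazy dog .", 8)
# -> ["[CLS]", "the", "quick", "brown", "fox", "jumps", "over", "[SEP]"]
#    (8 tokens total; the exact pieces depend on the wordpiece vocab)

# Pre-tokenized input is truncated and wrapped the same way.
process_sentence("MosesTokenizer", ["The", "quick", "brown", "fox"], 6)
# -> [SOS_TOK, "The", "quick", "brown", "fox", EOS_TOK]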
Example #6
def align_wpm(text: Text,
              tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
    """Alignment fn for WPM tokenizer, used in BERT
    """
    # If using lowercase, do this for the source tokens for better matching.
    do_lower_case = tokenizer_name.endswith("uncased")
    bow_tokens = space_tokenize_with_bow(
        text.lower() if do_lower_case else text)
    wpm_tokenizer = get_tokenizer(tokenizer_name)
    wpm_tokens = wpm_tokenizer.tokenize(text)

    # Align using <w> markers for stability w.r.t. word boundaries.
    modified_wpm_tokens = list(map(process_wordpiece_for_alignment,
                                   wpm_tokens))
    ta = TokenAligner(bow_tokens, modified_wpm_tokens)
    return ta, wpm_tokens
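The <w> markers mentioned in the comment flag word-initial pieces so the aligner can match word boundaries between the two tokenizations. A minimal sketch of what such a helper is expected to do (illustrative only, not the module's actual definition):

def _wordpiece_with_boundary_marker(t: str) -> str:
    # Continuation pieces ("##...") lose their prefix; word-initial pieces get "<w>".
    return t[2:] if t.startswith("##") else "<w>" + t

print([_wordpiece_with_boundary_marker(t) for t in ["un", "##believ", "##able"]])
# -> ['<w>un', 'believ', 'able']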
Example #7
def test_bert_get_tokenized_string_span_map1():
    text = "What does أنۢبياء anbiyā' mean in English?"
    b_tokenizer = get_tokenizer("bert-large-cased")
    result = bert_get_tokenized_string_span_map(text,
                                                b_tokenizer.tokenize(text))
    assert tuple(result) == (
        ("What", 0, 5),
        ("does", 5, 9),
        ("[UNK]", 9, 18),
        ("an", 18, 20),
        ("##bi", 20, 22),
        ("##y", 22, 23),
        ("##ā", 23, 24),
        ("'", 24, 26),
        ("mean", 26, 31),
        ("in", 31, 34),
        ("English", 34, 41),
        ("?", 41, 42),
    )
def test_bert_get_tokenized_string_span_map2():
    text = "What does أنۢبياء أنۢبياء anbiyā' mean in English?"
    b_tokenizer = get_tokenizer("bert-large-cased")
    result = bert_get_tokenized_string_span_map(text,
                                                b_tokenizer.tokenize(text))
    assert tuple(result) == (
        ("What", 0, 5),
        ("does", 5, 9),
        ("[UNK]", 9, 26),
        ("[UNK]", 26, 26),
        ("an", 26, 28),
        ("##bi", 28, 30),
        ("##y", 30, 31),
        ("##ā", 31, 32),
        ("'", 32, 34),
        ("mean", 34, 39),
        ("in", 39, 42),
        ("English", 42, 49),
        ("?", 49, 50),
    )
Example #8
import argparse
import logging as log

from pytorch_pretrained_bert import BertTokenizer
from jiant.utils import retokenize, tokenizers, utils

log.basicConfig(format="%(asctime)s: %(message)s", datefmt="%m/%d %I:%M:%S %p", level=log.INFO)

PARSER = argparse.ArgumentParser()
PARSER.add_argument("-t", dest="tokenizer_name", type=str, required=True, help="Tokenizer name.")
PARSER.add_argument(
    "--num_parallel", type=int, default=4, help="Number of parallel processes to use."
)
PARSER.add_argument("inputs", type=str, nargs="+", help="Input JSON files.")

# For now, this module expects MosesTokenizer as the default.
# TODO: change this once we have better support in core utils.
MosesTokenizer = tokenizers.get_tokenizer("MosesTokenizer")
assert MosesTokenizer is not None


def retokenize_record(record, tokenizer_name):
    """Retokenize an edge probing example. Modifies in-place."""
    text = record["text"]
    aligner_fn = retokenize.get_aligner_fn(tokenizer_name)
    ta, new_tokens = aligner_fn(text)
    record["text"] = " ".join(new_tokens)
    for target in record["targets"]:
        if "span1" in target:
            target["span1"] = list(map(int, ta.project_span(*target["span1"])))
        if "span2" in target:
            target["span2"] = list(map(int, ta.project_span(*target["span2"])))
    return record
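A hypothetical record to illustrate the in-place retokenization (the field layout mirrors what the function reads and writes; the tokenizer name is illustrative):

record = {
    "text": "Members of the House clapped their hands",
    "targets": [{"span1": [0, 1], "span2": [4, 6], "label": "ARG0"}],
}
retokenize_record(record, "bert-base-uncased")
# record["text"] is now the space-joined target tokenization, and span1/span2 have
# been projected from the original whitespace tokens onto the new token indices.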