Ejemplo n.º 1
0
import sys

from simple_elmo import ElmoModel
from spacy.lang.en import English

from blimp_utils import get_ppl

# The directory/path of the ELMo model is taken from the first command-line
# argument; fail with a readable usage message instead of a bare IndexError.
if len(sys.argv) < 2:
    sys.exit(f"Usage: {sys.argv[0]} <path-to-elmo-model>")

model = ElmoModel()
# NOTE(review): full=True presumably loads the language-model part needed by
# get_elmo_substitutes() -- confirm against the simple_elmo documentation.
model.load(sys.argv[1], max_batch_size=1, full=True)

tokenizer = English().tokenizer

# Sentence pairs to compare; each tuple is unpacked below as (good, bad).
pairs = [("Could that window ever shut?", "That window could ever shut."),
         ("could that window ever shut?", "that window could ever shut."),
         ("Could that window ever shut.", "That window could ever shut?"),
         ("Can it mean something?", "Piece does mean something."),
         ("Has the river ever frozen?", "The river has ever frozen."),
         ("Is it really good?", "It is really good?"),
         ("Is it really good.", "It is really good.")]


def _sentence_ppl(sentence):
    """Return the "bidirectional" ``get_ppl`` score for a raw sentence.

    The sentence is tokenized with the spaCy English tokenizer, re-joined
    with single spaces, and passed to ``model.get_elmo_substitutes`` as a
    one-sentence batch with ``topn`` set to the full vocabulary size. The
    result for that single sentence is then handed to ``get_ppl``.
    """
    tokenized = " ".join(token.text for token in tokenizer(sentence))
    substitutes = model.get_elmo_substitutes([tokenized], topn=model.vocab.size)
    # Only one sentence was fed, so its result is the first (and only) item.
    return get_ppl(substitutes[0], "bidirectional")


print("=========================")
for good, bad in pairs:
    # Print each sentence followed by its score: the "good" member of the
    # pair first, then the "bad" one.
    for sentence in (good, bad):
        print(sentence)
        print(_sentence_ppl(sentence))
    print("=========================")



    target_substitutes = {w: [] for w in targets}

    start = time.time()
    CACHE = 1000
    lines_processed = 0
    lines_cache = []

    with open(data_path, "r") as dataset:
        for line in dataset:
            res = line.strip().split()[:WORD_LIMIT]
            if targets & set(res):
                lines_cache.append(" ".join(res))
            lines_processed += 1
            if len(lines_cache) == CACHE:
                lex_substitutes = model.get_elmo_substitutes(lines_cache)
                for sent in lex_substitutes:
                    for word in sent:
                        if word["word"] in targets:
                            data2add = {
                                el: word[el]
                                for el in word if el != "word"
                            }
                            target_substitutes[word["word"]].append(data2add)
                lines_cache = []
            if lines_processed % 5120 == 0:
                logger.info(f"{data_path}; Lines processed: {lines_processed}")
        if lines_cache:
            logger.debug(f"We fed {len(lines_cache)} sentences")
            lex_substitutes = model.get_elmo_substitutes(lines_cache)
            logger.debug(f"We have {len(lex_substitutes)} sentences")