Example #1
def process_article(filename):
    # Name the output after the input file (without the .gz suffix) and place it in OUTPUT_DIR.
    new_name = os.path.basename(filename).replace(".gz", "") + ".elmo"
    new_name = os.path.join(OUTPUT_DIR, new_name)

    # Parse the article and write it out with one tokenized sentence per line.
    _, doc = process_xml_text(filename, correct_idx=False, stem=False)
    with open(new_name, "w") as outfile:
        outfile.write("\n".join(" ".join(s) for s in doc))
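A minimal usage sketch, assuming the module-level OUTPUT_DIR constant that process_article reads is already defined, and that the gzipped CoreNLP XML articles live under a hypothetical articles/ directory:

import glob
import os

# Make sure the output directory exists before writing the .elmo files.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Produce a plain-text, one-sentence-per-line .elmo file for every article,
# ready to be embedded with ELMo.
for path in glob.glob("articles/*.txt.xml.gz"):
    process_article(path)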
def extractEmbeddings(args, articles):
    # Initializes a dictionary that lets us go from (article, sent_id) -> sentence
    sentences = {}
    if args.search_entities:
        training_embeddings = EmbeddingManager()
    else:
        training_embeddings = SummedEmbeddings()
    list_invalid_articles = []
    example_number = 0

    nrc_vocab = load_nrc_vocab(nrc_file)

    for article in tqdm(articles):
        h5_path = os.path.join(
            data_path, articleToName(article, append_str=".txt.xml.hdf5"))
        xml_path = os.path.join(
            data_path, articleToName(article, append_str=".txt.xml"))
        # os.path.join does not raise OSError, so check for missing files explicitly.
        if not (os.path.exists(h5_path) and os.path.exists(xml_path)):
            print("Unable to read file {}".format(articleToName(article)))
            list_invalid_articles.append(article)
            continue

        root, document = process_xml_text(xml_path,
                                          correct_idx=False,
                                          stem=False,
                                          lower=True)
        sentence_nodes = [
            sent_node for sent_node in extractItem(
                root, ['document', 'sentences'], 'sentence')
        ]
        try:
            with h5py.File(h5_path, 'r') as h5py_file:
                sent_to_idx = ast.literal_eval(
                    h5py_file.get("sentence_to_index")[0])
                idx_to_sent = {
                    int(idx): sent
                    for sent, idx in sent_to_idx.items()
                }
                # TODO: Replace with logging
                if len(sentence_nodes) != len(idx_to_sent):
                    print(
                        "Mismatch in number of sentences, {} vs {} for article {}. Skipping article."
                        .format(len(sentence_nodes), len(idx_to_sent), article))
                    list_invalid_articles.append(article)
                    continue

                # Get coreference for the doc and parse it in a format that we can work with.
                doc_coreference_dict = None
                if args.use_coref:
                    doc_coreference = [
                        coreference for coreference in extractItem(
                            root, ['document', 'coreference', 'coreference'],
                            'coreference')
                    ]
                    doc_coreference_dict = build_coreference_dict(
                        doc_coreference)

                for sent_idx in idx_to_sent:
                    sent_tokens = [
                        tok for tok in extractItem(sentence_nodes[sent_idx],
                                                   ['tokens'], 'token')
                    ]
                    sent_words = [
                        extractItem(tok, ['word']).text.lower()
                        for tok in sent_tokens
                    ]
                    sent_POS = [
                        extractItem(tok, ['POS']).text for tok in sent_tokens
                    ]
                    sent_lemmas = [
                        extractItem(tok, ['lemma']).text.lower()
                        for tok in sent_tokens
                    ]
                    sent_embeddings = h5py_file.get(str(sent_idx))
                    sent_dependencies = [
                        dep for dep in extractItem(sentence_nodes[sent_idx],
                                                   ['dependencies'], 'dep')
                    ]

                    if sent_embeddings.shape[1] != len(sent_lemmas):
                        print(
                            "Mismatch in number of tokens in sentence {}: {} vs {}. Skipping sentence."
                            .format(sent_idx, sent_embeddings.shape[1],
                                    len(sent_lemmas)))
                        continue

                    # Filter sentences based on word content
                    # These are to be used in the evaluation portion
                    examples = filterSentence(args, article, sent_idx,
                                              sent_dependencies, sent_lemmas,
                                              sent_POS, doc_coreference_dict,
                                              nrc_vocab)
                    example_number += len(examples)

                    # NOTE: weights combine the three stored layers: 0) inputs, 1) left context, 2) right context.
                    def retrieveWordEmbedding(sent_embedding, verb_idx,
                                              weights=[0, 1, 0]):
                        return (sent_embedding[0][verb_idx] * weights[0] +
                                sent_embedding[1][verb_idx] * weights[1] +
                                sent_embedding[2][verb_idx] * weights[2])

                    for example in examples:
                        # TODO: Keep track of the other fields of example in particular: - ENTITYGROUP (whether the obj/subj was prot/gov)
                        #                                                                - DEPTYPE --> needed to establish which model we use
                        #example['embedding'] =retrieveWordEmbedding(sent_embeddings, example['verb_idx'])
                        example['embedding'] = retrieveWordEmbedding(
                            sent_embeddings, example['tok_idx'])
                        training_embeddings.addItem(example)
        except OSError:
            list_invalid_articles.append(article)
            print("Invalid HDF5 file {}".format(articleToName(article)))
        except Exception as e:
            # Catch all for other errors
            list_invalid_articles.append(article)
            print("{} occured. Skipping article {}".format(
                e, articleToName(article)))
    print("Total number of examples processed:", example_number)

    # TODO: Store the trained models and embeddings for future reference
    with open(args.emb_file, 'wb+') as embedding_fh:
        pickle.dump(training_embeddings, embedding_fh)
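A hedged driver sketch for extractEmbeddings, assuming an argparse namespace with the flags the function reads (search_entities, use_coref, emb_file) and that data_path and nrc_file are module-level globals as the body implies; the article identifiers and the pickle path below are placeholders:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--search_entities", action="store_true")
parser.add_argument("--use_coref", action="store_true")
parser.add_argument("--emb_file", default="training_embeddings.pkl")
args = parser.parse_args()

# Article identifiers are whatever articleToName() expects; these are placeholders.
articles = ["article_0001", "article_0002"]
extractEmbeddings(args, articles)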
def extract_entities(filename):
    root, full_doc = process_xml_text(filename)

    name_to_verbs = defaultdict(list)
    for coref in root.find('document').find('coreference').iter('coreference'):
        verbs_to_cache = []
        name = "Unknown"
        for mention in coref.findall('mention'):
            if 'representative' in mention.attrib:
                name = mention.find('text').text

            sent_id = int(mention.find('sentence').text) - 1

            sentence = root.find('document').find('sentences')[sent_id]
            for dep in sentence.find('dependencies').iter('dep'):
                if int(dep.find('dependent').get("idx")) != int(
                        mention.find('end').text) - 1:
                    continue

                parent_id = int(dep.find('governor').get("idx")) - 1
                parent = dep.find('governor').text

                parent_lemma = sentence.find('tokens')[int(parent_id)].find(
                    'lemma').text

                # Save the sentence id, the parent (verb) id, the entity name, the dependency type, and the article.
                # With the sentence id and parent id we can later look up the verb's embedding.
                if dep.get("type") in ["nsubj", "nsubjpass", "dobj"]:
                    verbs_to_cache.append(
                        VerbInstance(sent_id, parent_id, parent, parent_lemma,
                                     dep.get("type"),
                                     mention.find('text').text, "", filename))

        # end of coreference chain
        # Collect first and assign afterwards so that a name found anywhere in the chain applies to every mention in it.
        if verbs_to_cache:
            name_to_verbs[name] += verbs_to_cache

    final_verb_dict = {}
    for name, tupls in name_to_verbs.items():
        for t in tupls:
            key = (t.sent_id, t.verb_id)
            final_verb_dict[key] = t._replace(entity_name=name)

    id_to_sent = {}
    # Also keep all verbs that are in lex
    for s in root.find('document').find('sentences').iter('sentence'):
        sent = []
        sent_id = int(s.get("id")) - 1
        for tok in s.find('tokens').iter('token'):
            sent.append(tok.find("word").text.lower())
            verb_id = int(tok.get("id")) - 1
            key = (sent_id, verb_id)
            if key in final_verb_dict:
                continue

            if tok.find('POS').text.startswith("VB"):
                final_verb_dict[key] = VerbInstance(
                    sent_id, verb_id,
                    tok.find("word").text,
                    tok.find('lemma').text.lower(), "", "", "", filename)
        id_to_sent[sent_id] = " ".join(sent)

    return final_verb_dict, id_to_sent
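A brief usage sketch; the path is hypothetical, and only the VerbInstance fields actually referenced above (sent_id, verb_id, entity_name) are assumed:

verbs, id_to_sent = extract_entities("articles/0001.txt.xml")

for (sent_id, verb_id), verb in verbs.items():
    # entity_name is non-empty only for verbs that were attached to a coreference chain.
    if verb.entity_name:
        print(verb.entity_name, "->", id_to_sent[sent_id])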