Example #1
class ParseState(object):
    """Support object for read()."""
    def __init__(self, config=defaults, name=None):
        self.config = config
        self.dataset = Dataset(name=name)
        self.document = Document()
        self.texts = []
        self.tags = []

    def sentence_break(self):
        if len(self.texts) == 0:
            return
        if self.config.iobes:
            self.tags = iob_to_iobes(self.tags)
        tokens = [Token(t, g) for t, g in zip(self.texts, self.tags)]
        self.document.add_child(Sentence(tokens=tokens))
        self.texts = []
        self.tags = []

    def document_break(self):
        self.sentence_break()
        if len(self.document) == 0:
            return
        self.dataset.add_child(self.document)
        self.document = Document()

    def finish(self):
        self.document_break()
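For context, a minimal sketch of the kind of read() loop this helper is meant to support, assuming CoNLL-style input with whitespace-separated "token tag" columns, blank lines as sentence boundaries and -DOCSTART- lines as document boundaries; the column layout and markers are assumptions, not taken from the project:

def read(f, config, name=None):
    # Hypothetical driver for the ParseState class above; the input format
    # (two columns, blank-line sentence breaks, -DOCSTART- document breaks)
    # is an assumption.
    state = ParseState(config=config, name=name)
    for line in f:
        line = line.rstrip('\n')
        if not line.strip():
            state.sentence_break()
        elif line.startswith('-DOCSTART-'):
            state.document_break()
        else:
            text, tag = line.split()[:2]
            state.texts.append(text)
            state.tags.append(tag)
    state.finish()
    return state.dataset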
Example #3
    def stream_metadata(self):
        # get the IDs of all documents we need
        documents = self.docs_from_ids()
        # read the metadata file and extract all categories for the documents we want
        logger.info("reading metadata from " + self.file_metadata)
        metadata = util.json_read_lines(
            self.file_metadata)  # type: List[Dict[str,Any]]
        category_count = Counter()
        for meta_dict in metadata:
            doc_id = meta_dict['header']['identifier'].split(':')[-1]
            # match doc ids
            if doc_id in documents:
                doc = documents[doc_id]
                categories = meta_dict['header']['setSpecs']
                categories_clean = sorted(
                    set(c.split(':')[0] for c in categories))
                doc.categories = categories_clean
                for cat in categories_clean:
                    category_count[cat] += 1
        # integrity check
        for doc in documents.values():
            if doc.categories is None:
                logger.warning(
                    "there was no metadata entry for document '{}'".format(
                        doc.doc_id))
        # reading finished. print stats and write to file
        logger.info("categories for {} documents have been read: {}".format(
            len(documents), category_count.items()))
        util.json_write(Document.store_documents(documents.values()),
                        self.file_docs,
                        pretty=False)
Example #4
def read_jsonl(path, _log, _run, name='test', encoding='utf-8', lower=True):
    _log.info('Reading %s JSONL file from %s', name, path)
    with open(path, encoding=encoding) as f:
        for line in f:
            yield Document.from_mapping(json.loads(line.strip()), lower=lower)
    if SAVE_FILES:
        _run.add_resource(path)
Example #5
    def annotate(self, text, properties):
        with open('input.txt', 'w') as t:
            for x in text:
                t.write("{}\n".format(x))

        with open('props.properties', 'w') as props:
            if 'annotators' not in properties:
                props.write("annotators = {}\n".format(
                    self.default_annotators))
            props.write("file = input.txt\n")
            for key, value in properties.items():
                props.write("{} = {}\n".format(key, value))

        path = "{}/*".format(self.path)
        args = []
        try:
            sp = subprocess.Popen([
                'java', "-cp", path, "-Xmx2g",
                "edu.stanford.nlp.pipeline.StanfordCoreNLP", "-props",
                "props.properties"
            ],
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE)
            sp.wait()
        except (OSError, subprocess.SubprocessError):
            # A JVM OutOfMemoryError surfaces on the subprocess's stderr rather
            # than as a Python exception; only launch failures are caught here.
            raise MemoryError("Out of Memory")

        return Document()
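A hypothetical call to this method; the enclosing wrapper class is not shown here, and self.path (the CoreNLP classpath directory) and self.default_annotators are assumed to be set elsewhere in it:

# Hypothetical usage; `pipeline` is an instance of the enclosing wrapper class.
sentences = ["Stanford CoreNLP runs as a subprocess here.",
             "Each list element becomes one line of input.txt."]
doc = pipeline.annotate(sentences, {"annotators": "tokenize,ssplit,pos"})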
Example #6
def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    with open(args.path, encoding=args.encoding) as f:
        for line in f:
            doc = Document.from_mapping(json.loads(line.strip()),
                                        lower=args.lower)
            write_neuralsum_oracle(doc,
                                   args.output_dir,
                                   encoding=args.encoding)
Example #7
def make_document(token_texts, label):
    """Return Document object initialized with given token texts."""
    tokens = [Token(t) for t in token_texts]
    # We don't have sentence splitting, but the data structure expects
    # Documents to contain Sentences which in turn contain Tokens.
    # Create a dummy sentence containing all document tokens to work
    # around this constraint.
    sentences = [Sentence(tokens=tokens)]
    return Document(target_str=label, sentences=sentences)
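A minimal usage sketch, assuming the Token, Sentence and Document classes imported by the surrounding module:

# Minimal usage sketch; Token, Sentence and Document come from the surrounding project.
doc = make_document(["This", "movie", "was", "great", "."], "positive")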
Example #8
def read_jsonl(path, _log, _run, name='test', encoding='utf-8', lower=True, remove_puncts=True,
               replace_digits=True, stopwords_path=None):
    _log.info('Reading %s JSONL file from %s', name, path)
    if SAVE_FILES:
        _run.add_resource(path)
    stopwords = None if stopwords_path is None else read_stopwords(stopwords_path)

    with open(path, encoding=encoding) as f:
        for line in f:
            yield Document.from_mapping(
                json.loads(line.strip()), lower=lower, remove_puncts=remove_puncts,
                replace_digits=replace_digits, stopwords=stopwords)
Example #9
    def docs_from_metadata(self, topics: List[Topic]) -> Dict[str, Document]:
        # restore documents
        topic_dict = {t.topic_id: t for t in topics}
        documents = Document.restore_documents(util.json_read(self.file_docs),
                                               topic_dict)
        # add topics to documents (one for each category)
        if self.category_layer:
            for doc in documents.values():
                if doc.categories:
                    for category in doc.categories:
                        doc.add_topic(topic_dict[category], 1.0)
                else:
                    logger.warning("Document {} has no categories!".format(
                        doc.doc_id))
        return documents
Example #10
def main(args):
    docs = []
    with open(args.path, encoding=args.encoding) as f:
        for linum, line in enumerate(f):
            try:
                obj = json.loads(line.strip())
                docs.append(Document.from_mapping(obj))
            except Exception as e:
                message = f'line {linum+1}: {e}'
                raise RuntimeError(message)

    with Executor(max_workers=args.max_workers) as ex:
        results = ex.map(label_sentences, docs)
        for best_rouge, doc in results:
            print(json.dumps(doc.to_dict(), sort_keys=True))
            if args.verbose:
                print(f'ROUGE-1-F: {best_rouge:.2f}', file=sys.stderr)
Example #11
def main(args):
    objs = []
    with open(args.path, encoding=args.encoding) as f:
        for linum, line in enumerate(f):
            try:
                objs.append(json.loads(line.strip()))
            except Exception as e:
                message = f'line {linum+1}: {e}'
                raise RuntimeError(message)

    nlp = spacy.blank('id')
    with ProcessPoolExecutor(max_workers=args.max_workers) as exc:
        tok_objs = exc.map(partial(tokenize_obj, nlp), objs, chunksize=args.chunk_size)
        docs = [Document.from_mapping(obj) for obj in tok_objs]
        if args.discard_long_summary:
            docs = [doc for doc in docs if not has_long_summary(doc)]
        print('\n'.join(json.dumps(doc.to_dict(), sort_keys=True) for doc in docs))
Example #12
def parse_documents():
    arxiv_parser = ArxivParser(TextProcessor())
    springer_parser = SpringerParser(TextProcessor())

    for document in Document.select(lambda doc: not doc.is_processed):
        if "arxiv.org" in urlparse(document.url)[1]:
            cur_parser = arxiv_parser
        elif "springer.com" in urlparse(document.url)[1]:
            cur_parser = springer_parser
        else:
            continue

        page = WebPage.from_disk(document.url, document.file_path)

        if document.document_hash != page.page_hash:
            Document[document.id].delete()
            continue

        parsed = cur_parser.parse(page)
        document.is_processed = True
        commit()

        logging.debug(("Article: {}" if parsed else "{}").format(document.url))
Example #13
    def __init__(self, config=defaults, name=None):
        self.config = config
        self.dataset = Dataset(name=name)
        self.document = Document()
        self.texts = []
        self.tags = []
Example #14
    def process_data(self, input_folder, summary_path, qap_path, document_path, pickle_folder, small_number=-1, summary_only=False, interval=50):
        reload(sys)
        sys.setdefaultencoding('utf8')
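        # Note: reload(sys) / sys.setdefaultencoding exist only in Python 2;
        # these two lines fail under Python 3 and should be dropped there.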

        # # Takes time to load so only do this inside function rather than in constructor
        # self.nlp =spacy.load('en_core_web_md', disable= ["tagger", "parser"])

        # Here we load files that contain the summaries, questions, answers and information about the documents
        # Not the documents themselves
        # assuming every unique id has one summary only

        to_anonymize = ["GPE", "PERSON", "ORG", "LOC"]
        def _getNER(string_data,entity_dict,other_dict):
            doc = self.nlp(string_data)
            data = string_data.split()
            NE_data = ""
            start_pos = 0
            for ents in doc.ents:
                start = ents.start_char
                end = ents.end_char
                label = ents.label_
                tokens = ents.text
                key = tokens.lower()
                if label in to_anonymize:
                    if key not in data:
                        if key not in entity_dict:
                            entity_dict[key] = "@ent" + str(len(entity_dict)) + "~ner:" + label
                        NE_data += string_data[start_pos:start] + entity_dict[key] + " "
                        start_pos = end + 1
                else:
                    other_dict[key] = tokens + "~ner:" + label
                    NE_data += string_data[start_pos:start] + tokens + "~ner:" + label + " "
                    start_pos = end + 1

            NE_data += string_data[start_pos:]
            return NE_data.split()


        summaries = {}
        with codecs.open(summary_path, "r", encoding='utf-8', errors='replace') as fin:
            first = True
            for line in reader(fin):
                if first:
                    first=False
                    continue
                id = line[0]
                summary_tokens = line[2]
                ner_summary, pos_summary, tokens = self.getNER(line[2])
                summaries[id] = (tokens, ner_summary, pos_summary)
        print("Loaded summaries")
        qaps = {}

        candidates_per_doc = defaultdict(list)
        ner_candidates_per_doc = defaultdict(list)
        pos_candidates_per_doc = defaultdict(list)
        count = 0
        with codecs.open(qap_path, "r") as fin:
            first = True
            for line in reader(fin):
                if first:
                    first = False
                    continue
                id = line[0]

                if id in qaps:
                    ner_answer, pos_answer, tokens = self.getNER(line[3])
                    ner_candidates_per_doc[id].append(ner_answer)
                    pos_candidates_per_doc[id].append(pos_answer)
                    candidates_per_doc[id].append(tokens)

                    ner_answer, pos_answer, tokens = self.getNER(line[4])
                    ner_candidates_per_doc[id].append(ner_answer)
                    pos_candidates_per_doc[id].append(pos_answer)
                    candidates_per_doc[id].append(tokens)

                    indices = [candidate_index, candidate_index + 1]
                    candidate_index += 2

                    ner_question, pos_question, tokens = self.getNER(line[2])
                    qaps[id].append(
                        Query(tokens, ner_question, pos_question, indices))
                else:
                    # print(id)
                    qaps[id] = []
                    candidates_per_doc[id] = []
                    candidate_index = 0

                    ner_answer, pos_answer, tokens = self.getNER(line[3])
                    ner_candidates_per_doc[id].append(ner_answer)
                    pos_candidates_per_doc[id].append(pos_answer)
                    candidates_per_doc[id].append(tokens)

                    ner_answer, pos_answer, tokens = self.getNER(line[4])
                    ner_candidates_per_doc[id].append(ner_answer)
                    pos_candidates_per_doc[id].append(pos_answer)
                    candidates_per_doc[id].append(tokens)

                    indices = [candidate_index, candidate_index + 1]
                    candidate_index += 2

                    ner_question, pos_question, tokens = self.getNER(line[2])
                    qaps[id].append(
                        Query(tokens, ner_question, pos_question, indices))

        print("Loaded question answer pairs")
        documents = {}
        with codecs.open(document_path, "r") as fin:
            index = 0
            for line in reader(fin):

                tokens = line
                assert len(tokens) == 10

                if index > 0:
                    doc_id = tokens[0]
                    set = tokens[1]
                    kind = tokens[2]
                    start_tag = tokens[8]
                    end_tag = tokens[9]
                    documents[doc_id] = (set, kind, start_tag, end_tag)

                index = index + 1


        # Create lists of document objects for the summaries
        train_summaries = []
        valid_summaries = []
        test_summaries = []

        if small_number > 0:
            small_summaries = []

        for doc_id in documents:
            set, kind, _, _ = documents[doc_id]
            tokens, ner_summary, pos_summary = summaries[doc_id]
            summary = Document(doc_id, set, kind, tokens, qaps[doc_id], {}, {},
                               candidates_per_doc[doc_id], ner_candidates_per_doc[doc_id],
                               pos_candidates_per_doc[doc_id], ner_summary, pos_summary)

            # When constructing small data set, just add to one pile and save when we have a sufficient number
            if small_number > 0:
                small_summaries.append(summary)
                if len(small_summaries)==small_number:
                    with open(pickle_folder + "small_summaries.pickle", "wb") as fout:
                        pickle.dump(small_summaries, fout)
                    break
            else:
                if set == 'train':
                    train_summaries.append(summary)
                elif set == 'valid':
                    valid_summaries.append(summary)
                elif set == 'test':
                    test_summaries.append(summary)

        print("Pickling summaries")
        with open(pickle_folder + "train_summaries.pickle", "wb") as fout:
            pickle.dump(train_summaries, fout)
        with open(pickle_folder + "valid_summaries.pickle", "wb") as fout:
            pickle.dump(valid_summaries, fout)
        with open(pickle_folder + "test_summaries.pickle", "wb") as fout:
            pickle.dump(test_summaries, fout)

        # If only interested in summaries, return here so we don't process the documents
        if summary_only:
            return

        train_docs = []
        valid_docs = []
        test_docs = []

        # In case of creation of small test dataset
        if small_number > 0:
            small_docs = []
            small_train_docs = []
            small_valid_docs = []
            small_test_docs = []

        # Here we load documents, tokenize them, and create Document class instances
        print("Processing documents")
        filenames=glob.glob(os.path.join(input_folder, '*.content'))
        for file_number in range(len(filenames)):
            filename=filenames[file_number]
            doc_id = os.path.basename(filename).replace(".content", "")
            print("Processing:{0}".format(doc_id))
            try:
                (set, kind, start_tag, end_tag) = documents[doc_id]
            except KeyError:
                print("Document id not found: {0}".format(doc_id))
                exit(0)                

            if kind == "gutenberg":
                try:
                    with codecs.open(input_folder + doc_id + ".content", "r", encoding='utf-8', errors='replace') as fin:
                        data = fin.read()
                        data = data.replace('"', '')
                        tokenized_data = " ".join(word_tokenize(data))
                        start_index = tokenized_data.find(start_tag)
                        end_index = tokenized_data.rfind(end_tag, start_index)
                        filtered_data = tokenized_data[start_index:end_index]
                        if len(filtered_data) == 0:
                            print("Error in book extraction: ",
                                    filename, start_tag, end_tag)
                        else:
                            filtered_data = filtered_data.replace(
                                " 's ", " s ")
                            document_tokens = word_tokenize(filtered_data)

                except Exception as error:
                    print(error)
                    print("Books for which 'utf-8' doesnt work: ", doc_id)
            else:
                try:
                    # Here we remove some annotation that is unique to movie scripts
                    with codecs.open(input_folder + doc_id + ".content", "r", encoding="utf-8",
                                        errors="replace") as fin:
                        text = fin.read()
                        text = text.replace('"', '')
                        script_regex = r"<script.*>.*?</script>|<SCRIPT.*>.*?</SCRIPT>"
                        text = re.sub(script_regex, '', text)
                        for tag in start_tags_with_attributes:
                            my_regex = r'{0}.*=.*?>'.format(tag)
                            text = re.sub(my_regex, '', text)
                        for tag in end_tags:
                            text = text.replace(tag, "")
                        for tag in start_tags:
                            text = text.replace(tag, "")
                        start_tag = start_tag.replace(
                            " S ", " 'S ").replace(" s ", " 's ")
                        tokenized_data = " ".join(word_tokenize(text))
                        start_index = tokenized_data.find(start_tag)
                        if start_index == -1:
                            pass
                        end_index = tokenized_data.rfind(end_tag, start_index)
                        filtered_data = tokenized_data[start_index:end_index]
                        if len(filtered_data) == 0:
                            print("Error in movie extraction: ",
                                    filename, start_tag)
                        else:
                            filtered_data = filtered_data.replace(
                                " 's ", " s ")
                            document_tokens = word_tokenize(filtered_data)

                except Exception as error:
                    print(error)
                    print("Movie for which html extraction doesn't work: ", doc_id)

            #Get NER
            entity_dictionary = {}
            other_dictionary = {}
            title_document_tokens = [token.lower() if token.isupper() else token for token in document_tokens]
            string_doc = " ".join(title_document_tokens)
            if len(string_doc) > 1000000:
                q1 = len(string_doc) // 4  # integer division so the slices below get int indices

                first_quarter = string_doc[0:q1]
                second_quarter = string_doc[q1:q1*2]
                third_quarter = string_doc[q1 * 2:q1*3]
                fourth_quarter = string_doc[q1*3:]
                first_q_tokens = _getNER(first_quarter,entity_dictionary,other_dictionary)
                second_q_tokens = _getNER(second_quarter, entity_dictionary,other_dictionary)
                third_q_tokens = _getNER(third_quarter, entity_dictionary,other_dictionary)
                fourth_q_tokens = _getNER(fourth_quarter, entity_dictionary,other_dictionary)

                NER_document_tokens = first_q_tokens + second_q_tokens + third_q_tokens + fourth_q_tokens
            else:
                NER_document_tokens = _getNER(string_doc,entity_dictionary,other_dictionary)

            doc = Document(
                doc_id, set, kind, NER_document_tokens, qaps[doc_id],
                entity_dictionary, other_dictionary, candidates_per_doc[doc_id],
                ner_candidates_per_doc[doc_id], pos_candidates_per_doc[doc_id])

            if (file_number+1) % interval == 0:
                print("Processed {} documents".format(file_number+1))

            # If testing, add to test list, pickle and return when sufficient documents retrieved
            if small_number > 0:
                small_docs.append(doc)
                if set == "train":
                    small_train_docs.append(doc)
                elif set == "valid":
                    small_valid_docs.append(doc)
                else:
                    small_test_docs.append(doc)
                if len(small_docs) == small_number:
                    with open(pickle_folder + "small_train_docs.pickle", "wb") as fout:
                        pickle.dump(small_train_docs, fout)
                    with open(pickle_folder + "small_valid_docs.pickle", "wb") as fout:
                        pickle.dump(small_valid_docs, fout)
                    with open(pickle_folder + "small_test_docs.pickle", "wb") as fout:
                        pickle.dump(small_test_docs, fout)
                    return


            else:
                if set == "train":
                    train_docs.append(doc)
                elif set == "valid":
                    valid_docs.append(doc)
                else:
                    test_docs.append(doc)

        # Save documents to pickle
        print("Pickling documents")
        with open(pickle_folder + "train_docs.pickle", "wb") as fout:
            pickle.dump(train_docs, fout)
        with open(pickle_folder + "validate_docs.pickle", "wb") as fout:
            pickle.dump(valid_docs, fout)
        with open(pickle_folder + "test_docs.pickle", "wb") as fout:
            pickle.dump(test_docs, fout)
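The _getNER helper above replaces selected spaCy entity spans with stable placeholders. A standalone sketch of the same idea, simplified, with the model name, function name and placeholder format chosen here for illustration (it assumes the en_core_web_sm model is installed):

import spacy

def anonymize_entities(text, entity_dict, labels=("PERSON", "ORG", "GPE", "LOC")):
    # Simplified illustration of the anonymization idea used in _getNER above:
    # each selected entity string maps to a stable "@entN~ner:LABEL" placeholder.
    nlp = spacy.load("en_core_web_sm")  # assumes the model is installed
    doc = nlp(text)
    out, last = [], 0
    for ent in doc.ents:
        if ent.label_ not in labels:
            continue
        key = ent.text.lower()
        if key not in entity_dict:
            entity_dict[key] = "@ent{}~ner:{}".format(len(entity_dict), ent.label_)
        out.append(text[last:ent.start_char])
        out.append(entity_dict[key])
        last = ent.end_char
    out.append(text[last:])
    return "".join(out)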
Example #15
    def __init__(self, config=defaults, name=None):
        self.config = config
        self.dataset = Dataset(name=name)
        self.document = Document()
        self.texts = []
        self.tags = []
Example #16
    def document_break(self):
        self.sentence_break()
        if len(self.document) == 0:
            return
        self.dataset.add_child(self.document)
        self.document = Document()
Example #17
    def document_break(self):
        self.sentence_break()
        if len(self.document) == 0:
            return
        self.dataset.add_child(self.document)
        self.document = Document()
Example #18
    def _prepare_doc(self, curr_doc: Document) -> Dict:
        """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since
        data inside same document does not get shuffled. """
        ret = {}

        preprocessed_sents, max_len = [], 0
        for curr_sent in curr_doc.raw_sentences():
            # TODO: uncased/cased option
            curr_processed_sent = list(map(lambda s: s.lower().strip(), curr_sent)) + ["<PAD>"]
            preprocessed_sents.append(curr_processed_sent)
            if len(curr_processed_sent) > max_len:
                max_len = len(curr_processed_sent)

        for i in range(len(preprocessed_sents)):
            preprocessed_sents[i].extend(["<PAD>"] * (max_len - len(preprocessed_sents[i])))

        cluster_sets = []
        mention_to_cluster_id = {}
        for i, curr_cluster in enumerate(curr_doc.clusters):
            cluster_sets.append(set(curr_cluster))
            for mid in curr_cluster:
                mention_to_cluster_id[mid] = i

        all_candidate_data = []
        for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), start=1):
            gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]

            # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`)
            candidates, candidate_data = [None], []
            starts, ends = [], []
            candidate_attention = []
            correct_antecedents = []

            curr_head_data = [[], []]
            for curr_token in head_mention.tokens:
                curr_head_data[0].append(curr_token.sentence_index)
                curr_head_data[1].append(curr_token.position_in_sentence)

            num_tokens = len(head_mention.tokens)
            if num_tokens > self.max_span_size:
                curr_head_data[0] = curr_head_data[0][:self.max_span_size]
                curr_head_data[1] = curr_head_data[1][:self.max_span_size]
            else:
                curr_head_data[0] += [head_mention.tokens[0].sentence_index] * (self.max_span_size - num_tokens)
                curr_head_data[1] += [-1] * (self.max_span_size - num_tokens)

            head_start = 0
            head_end = num_tokens
            head_attention = torch.ones((1, self.max_span_size), dtype=torch.bool)
            head_attention[0, num_tokens:] = False

            for idx_candidate, (cand_id, cand_mention) in enumerate(curr_doc.mentions.items(), start=1):
                if idx_candidate >= idx_head:
                    break

                candidates.append(cand_id)

                # Maps tokens to positions inside document (idx_sent, idx_inside_sent) for efficient indexing later
                curr_candidate_data = [[], []]
                for curr_token in cand_mention.tokens:
                    curr_candidate_data[0].append(curr_token.sentence_index)
                    curr_candidate_data[1].append(curr_token.position_in_sentence)

                num_tokens = len(cand_mention.tokens)
                if num_tokens > self.max_span_size:
                    curr_candidate_data[0] = curr_candidate_data[0][:self.max_span_size]
                    curr_candidate_data[1] = curr_candidate_data[1][:self.max_span_size]
                else:
                    curr_candidate_data[0] += [cand_mention.tokens[0].sentence_index] * (self.max_span_size - num_tokens)
                    curr_candidate_data[1] += [-1] * (self.max_span_size - num_tokens)

                candidate_data.append(curr_candidate_data)
                starts.append(0)
                ends.append(num_tokens)

                curr_attention = torch.ones((1, self.max_span_size), dtype=torch.bool)
                curr_attention[0, num_tokens:] = False
                candidate_attention.append(curr_attention)

                is_coreferent = cand_id in gt_antecedent_ids
                if is_coreferent:
                    correct_antecedents.append(idx_candidate)

            if len(correct_antecedents) == 0:
                correct_antecedents.append(0)

            candidate_attention = torch.cat(candidate_attention) if len(candidate_attention) > 0 else []

            all_candidate_data.append({
                "head_id": head_id,
                "head_data": torch.tensor([curr_head_data]),
                "head_attention": head_attention,
                "head_start": head_start,
                "head_end": head_end,
                "candidates": candidates,
                "candidate_data": torch.tensor(candidate_data),
                "candidate_attention": candidate_attention,
                "correct_antecedents": correct_antecedents
            })

        ret["preprocessed_sents"] = preprocessed_sents
        ret["steps"] = all_candidate_data

        return ret
Example #19
    def docs_from_ids(self) -> Dict[str, Document]:
        return {
            doc_id: Document(doc_id)
            for doc_id in util.json_read(self.file_ids)
        }
Example #20
    def _train_doc(self, curr_doc: Document, eval_mode=False):
        """ Trains/evaluates (if `eval_mode` is True) model on specific document.
            Returns predictions, loss and number of examples evaluated. """

        if len(curr_doc.mentions) == 0:
            return {}, (0.0, 0)

        if not hasattr(curr_doc, "_cache_nc"):
            curr_doc._cache_nc = self._prepare_doc(curr_doc)
        cache = curr_doc._cache_nc  # type: dict

        embedded_doc = []
        for curr_sent in cache["preprocessed_sents"]:
            embedded_doc.append(self.embed_sequence(curr_sent))
        embedded_doc = torch.stack(embedded_doc)  # [num_sents, max_tokens_in_any_sent + 1, embedding_size]

        doc_loss, n_examples = 0.0, len(cache["steps"])
        preds = {}

        for curr_step in cache["steps"]:
            head_id = curr_step["head_id"]
            head_data = curr_step["head_data"]

            candidates = curr_step["candidates"]
            candidate_data = curr_step["candidate_data"]
            correct_antecedents = curr_step["correct_antecedents"]

            # Note: num_candidates includes dummy antecedent + actual candidates
            num_candidates = len(candidates)
            if num_candidates == 1:
                curr_pred = 0
            else:
                idx_sent = candidate_data[:, 0, :]
                idx_in_sent = candidate_data[:, 1, :]

                # [num_candidates, max_span_size, embedding_size]
                candidate_data = embedded_doc[idx_sent, idx_in_sent]
                # [1, head_size, embedding_size]
                head_data = embedded_doc[head_data[:, 0, :], head_data[:, 1, :]]
                head_data = head_data.repeat((num_candidates - 1, 1, 1))

                candidate_scores = self.scorer(candidate_data, head_data,
                                               curr_step["candidate_attention"],
                                               curr_step["head_attention"].repeat((num_candidates - 1, 1)))
                # [1, num_candidates]
                candidate_scores = torch.cat((torch.tensor([0.0], device=DEVICE),
                                              candidate_scores.flatten())).unsqueeze(0)

                curr_pred = torch.argmax(candidate_scores)
                doc_loss += self.loss(candidate_scores.repeat((len(correct_antecedents), 1)),
                                      torch.tensor(correct_antecedents, device=DEVICE))

            # { antecedent: [mention(s)] } pair
            existing_refs = preds.get(candidates[int(curr_pred)], [])
            existing_refs.append(head_id)
            preds[candidates[int(curr_pred)]] = existing_refs

        if not eval_mode:
            doc_loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        return preds, (float(doc_loss), n_examples)
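A hypothetical outer loop over documents, based only on the return signature shown above (preds plus a (doc_loss, n_examples) pair per document); model, train_docs and num_epochs are assumptions, not part of the snippet:

# Hypothetical training driver; `model` exposes _train_doc as shown above
# and `train_docs` is an iterable of Document objects.
for epoch in range(num_epochs):
    total_loss, total_examples = 0.0, 0
    for doc in train_docs:
        _, (doc_loss, n_examples) = model._train_doc(doc)
        total_loss += doc_loss
        total_examples += n_examples
    print("epoch {}: avg loss = {:.4f}".format(epoch, total_loss / max(total_examples, 1)))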
Example #21
    def _prepare_doc(self, curr_doc: Document) -> Dict:
        """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since
        data inside same document does not get shuffled. """
        ret = {}

        # By default, each sentence is its own segment, meaning sentences are processed independently
        if self.max_segment_size is None:

            def get_position(t):
                return t.sentence_index, t.position_in_sentence

            _encoded_segments = batch_to_ids(curr_doc.raw_sentences())
        # Optionally, one can specify max_segment_size, in which case segments of tokens are processed independently
        else:

            def get_position(t):
                doc_position = t.position_in_document
                return doc_position // self.max_segment_size, doc_position % self.max_segment_size

            flattened_doc = list(chain(*curr_doc.raw_sentences()))
            num_segments = (len(flattened_doc) + self.max_segment_size -
                            1) // self.max_segment_size
            _encoded_segments = \
                batch_to_ids([flattened_doc[idx_seg * self.max_segment_size: (idx_seg + 1) * self.max_segment_size]
                              for idx_seg in range(num_segments)])

        encoded_segments = []
        # Convention: Add a PAD word ([0] * max_chars vector) at the end of each segment, for padding mentions
        for curr_sent in _encoded_segments:
            encoded_segments.append(
                torch.cat((curr_sent,
                           torch.zeros(
                               (1, ELMoCharacterMapper.max_word_length),
                               dtype=torch.long))))
        encoded_segments = torch.stack(encoded_segments)

        cluster_sets = []
        mention_to_cluster_id = {}
        for i, curr_cluster in enumerate(curr_doc.clusters):
            cluster_sets.append(set(curr_cluster))
            for mid in curr_cluster:
                mention_to_cluster_id[mid] = i

        all_candidate_data = []
        for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), start=1):
            gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]

            # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`)
            candidates, candidate_data = [None], []
            candidate_attention = []
            correct_antecedents = []

            curr_head_data = [[], []]
            num_head_words = 0
            for curr_token in head_mention.tokens:
                idx_segment, idx_inside_segment = get_position(curr_token)
                curr_head_data[0].append(idx_segment)
                curr_head_data[1].append(idx_inside_segment)
                num_head_words += 1

            if num_head_words > self.max_span_size:
                curr_head_data[0] = curr_head_data[0][:self.max_span_size]
                curr_head_data[1] = curr_head_data[1][:self.max_span_size]
            else:
                curr_head_data[0] += [curr_head_data[0][-1]] * (self.max_span_size - num_head_words)
                curr_head_data[1] += [-1] * (self.max_span_size - num_head_words)

            head_attention = torch.ones((1, self.max_span_size),
                                        dtype=torch.bool)
            head_attention[0, num_head_words:] = False

            for idx_candidate, (cand_id, cand_mention) in enumerate(
                    curr_doc.mentions.items(), start=1):
                if idx_candidate >= idx_head:
                    break

                candidates.append(cand_id)

                # Maps tokens to positions inside segments (idx_seg, idx_inside_seg) for efficient indexing later
                curr_candidate_data = [[], []]
                num_candidate_words = 0
                for curr_token in cand_mention.tokens:
                    idx_segment, idx_inside_segment = get_position(curr_token)
                    curr_candidate_data[0].append(idx_segment)
                    curr_candidate_data[1].append(idx_inside_segment)
                    num_candidate_words += 1

                if num_candidate_words > self.max_span_size:
                    curr_candidate_data[0] = curr_candidate_data[0][:self.max_span_size]
                    curr_candidate_data[1] = curr_candidate_data[1][:self.max_span_size]
                else:
                    # padding tokens index into the PAD token of the last segment
                    curr_candidate_data[0] += [curr_candidate_data[0][-1]] * (self.max_span_size - num_candidate_words)
                    curr_candidate_data[1] += [-1] * (self.max_span_size - num_candidate_words)

                candidate_data.append(curr_candidate_data)
                curr_attention = torch.ones((1, self.max_span_size),
                                            dtype=torch.bool)
                curr_attention[0, num_candidate_words:] = False
                candidate_attention.append(curr_attention)

                is_coreferent = cand_id in gt_antecedent_ids
                if is_coreferent:
                    correct_antecedents.append(idx_candidate)

            if len(correct_antecedents) == 0:
                correct_antecedents.append(0)

            candidate_attention = torch.cat(candidate_attention) if len(candidate_attention) > 0 else []
            all_candidate_data.append({
                "head_id": head_id,
                "head_data": torch.tensor([curr_head_data]),
                "head_attention": head_attention,
                "candidates": candidates,
                "candidate_data": torch.tensor(candidate_data),
                "candidate_attention": candidate_attention,
                "correct_antecedents": correct_antecedents
            })

        ret["preprocessed_segments"] = encoded_segments
        ret["steps"] = all_candidate_data

        return ret