コード例 #1
0
    def train(self, docbin_file, cutoff=None):
        """Train the HMM annotator on the documents of a DocBin file.

        The parameters are first initialised from the corpus (start
        probabilities, transition matrix, emissions) and checked with
        ``self._check()``. They are then refined by up to ``self.n_iter`` EM
        iterations. Each iteration re-reads the corpus for the E-step, runs
        the M-step, and stops early once ``self.monitor_`` reports
        convergence.

        Args:
            docbin_file: path of the DocBin file, read with
                ``annotations.docbin_reader``.
            cutoff: forwarded as ``cutoff`` to ``annotations.docbin_reader``
                (presumably a maximum number of documents -- confirm).

        Returns:
            ``self`` once training is finished.

            NOTE(review): if ``framelogprob.max(axis=1).min()`` of a
            document is below -100000, the method prints "problem found!"
            and returns that document's ``framelogprob`` array instead of
            ``self``, aborting training (debugging aid kept as-is).
        """
        import random

        spacy_docs = annotations.docbin_reader(docbin_file, cutoff=cutoff)
        X_stream = (self.extract_sequence(doc) for doc in spacy_docs)
        # Each initialiser gets its own copy of the sequence stream. If each
        # one consumes its copy fully before the next starts, tee keeps the
        # extracted sequences buffered in memory in the meantime.
        streams = itertools.tee(X_stream, 3)
        self._initialise_startprob(streams[0])
        self._initialise_transmat(streams[1])
        self._initialise_emissions(streams[2])
        self._check()

        self.monitor_._reset()
        for iteration in range(self.n_iter):
            print("Starting iteration", (iteration + 1))
            stats = self._initialize_sufficient_statistics()
            curr_logprob = 0

            nb_docs = 0
            for doc in annotations.docbin_reader(docbin_file, cutoff=cutoff):
                # Documents whose "score" annotation is not 5 are randomly
                # subsampled: they are skipped whenever randint(0, 9) > 3,
                # so only about 40% of them contribute to each E-step.
                # (randint is only drawn for those documents.)
                if (doc.user_data["annotations"]["score"] != 5
                        and random.randint(0, 9) > 3):
                    continue

                X = self.extract_sequence(doc)
                framelogprob = self._compute_log_likelihood(X)
                # Debugging guard: some row of the log-likelihood matrix has
                # all of its entries below -100000.
                if framelogprob.max(axis=1).min() < -100000:
                    print("problem found!")
                    return framelogprob

                logprob, fwdlattice = self._do_forward_pass(framelogprob)
                curr_logprob += logprob
                bwdlattice = self._do_backward_pass(framelogprob)
                posteriors = self._compute_posteriors(fwdlattice, bwdlattice)
                self._accumulate_sufficient_statistics(stats, X, framelogprob,
                                                       posteriors, fwdlattice,
                                                       bwdlattice)
                nb_docs += 1

                if nb_docs % 1000 == 0:
                    print("Number of processed documents:", nb_docs)
            print("Finished E-step with %i documents" % nb_docs)

            # XXX must be before convergence check, because otherwise
            #     there won't be any updates for the case ``n_iter=1``.
            self._do_mstep(stats)

            self.monitor_.report(curr_logprob)
            if self.monitor_.converged:
                break

        return self
コード例 #2
0
ファイル: labelling3.py プロジェクト: justkk/CIS620Project1
 def train(self, docbin_file, cutoff=None):
     """Train the Snorkel label model on the documents of a DocBin file.

     Each document is passed through ``self.specialise_annotations``.
     ``self._get_inputs`` then returns ``(spans, obs)`` for it. The ``obs``
     matrices of all documents are stacked row-wise with ``np.vstack`` and
     used to fit a ``snorkel.labeling.LabelModel`` built with
     ``len(LABELS) + 1`` as its first argument. The fitted model is stored
     in ``self.label_model``.

     Args:
         docbin_file: path of the DocBin file, read with
             ``annotations.docbin_reader``.
         cutoff: optional limit forwarded to ``annotations.docbin_reader``.
             The previous version silently stopped after the first six
             documents (a leftover debugging cap). Passing ``cutoff=6``
             presumably reproduces that, assuming ``cutoff`` caps the
             number of documents read -- confirm.

     Raises:
         ValueError: if no document is read from ``docbin_file``.
     """
     import snorkel.labeling

     all_obs = []
     for doc in annotations.docbin_reader(docbin_file, cutoff=cutoff):
         doc = self.specialise_annotations(doc)
         _, obs = self._get_inputs(doc)
         all_obs.append(obs)
     if not all_obs:
         # np.vstack would raise an opaque ValueError on an empty list.
         raise ValueError("No documents read from %r" % (docbin_file,))
     all_obs = np.vstack(all_obs)
     self.label_model = snorkel.labeling.LabelModel(len(LABELS) + 1)
     self.label_model.fit(all_obs)
コード例 #3
0
ファイル: ner2.py プロジェクト: justkk/CIS620Project1
def generate_from_docbin(docbin_file,
                         target_source=None,
                         cutoff=None,
                         nb_to_skip=0,
                         labels_to_retain=None,
                         labels_to_map=None,
                         loop=False):
    """Generate spaCy documents from a DocBin file, with their entities reset.

    For every document read, the entities are rebuilt from a list of
    ``(start, end, label)`` spans:

    - if ``target_source`` is None, the document's current ``ents``;
    - otherwise, the spans stored in
      ``doc.user_data["annotations"][target_source]`` whose confidence is
      strictly greater than 0.5.

    Labels are renamed through ``labels_to_map`` (if given). Only labels in
    ``labels_to_retain`` are kept (if given).

    Args:
        docbin_file: path of the DocBin file, read with
            ``annotations.docbin_reader``.
        target_source: name of the annotation source to take spans from,
            or None to keep the document's own entities.
        cutoff: if not None, stop after this many documents in total.
            Also forwarded to ``annotations.docbin_reader``.
        nb_to_skip: forwarded to ``annotations.docbin_reader``.
        labels_to_retain: optional collection of labels to keep.
        labels_to_map: optional dict mapping old labels to new ones.
        loop: if True, re-read the file from the start once it is
            exhausted. Looping stops when a full pass yields no document,
            instead of spinning forever.

    Yields:
        The spaCy documents, with ``ents`` replaced.
    """

    import annotations

    nb_generated = 0
    # NOTE(review): loads the full model on every call only to get its vocab.
    vocab = spacy.load("en_core_web_md").vocab

    while True:
        reader = annotations.docbin_reader(docbin_file,
                                           vocab=vocab,
                                           cutoff=cutoff,
                                           nb_to_skip=nb_to_skip)
        nb_in_pass = 0
        for spacy_doc in reader:

            if target_source is None:
                spans = [(ent.start, ent.end, ent.label_)
                         for ent in spacy_doc.ents]
            else:
                source_spans = spacy_doc.user_data["annotations"][target_source]
                spans = [(start, end, label)
                         for (start, end), vals in source_spans.items()
                         for label, conf in vals
                         if conf > 0.5]

            new_spans = []
            for start, end, label in spans:
                if labels_to_map is not None:
                    label = labels_to_map.get(label, label)
                if labels_to_retain is None or label in labels_to_retain:
                    ent = spacy.tokens.Span(spacy_doc, start, end,
                                            spacy_doc.vocab.strings[label])
                    new_spans.append(ent)
            spacy_doc.ents = tuple(new_spans)

            yield spacy_doc
            nb_generated += 1
            nb_in_pass += 1
            if cutoff is not None and nb_generated >= cutoff:
                return
        # A pass that produced nothing would produce nothing again: stop
        # rather than re-reading the file in an endless busy loop.
        if not loop or nb_in_pass == 0:
            break
コード例 #4
0
def _read_clean_docs(docbin_path, nlp, max_docs=1000):
    """Return up to ``max_docs`` documents of ``docbin_path`` passing two filters.

    A document is skipped when its text contains "&", "<" or ">". It is also
    skipped when ``spacy_wrapper._correct_tokenisation`` changes the token
    texts of the document obtained by re-running ``nlp`` on its text. The
    documents returned are the ones read from the DocBin, not the re-parsed
    ones.
    """
    import itertools

    import annotations
    import spacy_wrapper

    doc_stream, text_stream = itertools.tee(
        annotations.docbin_reader(docbin_path), 2)
    reparsed = nlp.pipe(x.text for x in text_stream)
    kept = []
    for doc, doc2 in zip(doc_stream, reparsed):
        if "&" in doc.text or "<" in doc.text or ">" in doc.text:
            continue
        corrected = spacy_wrapper._correct_tokenisation(doc2)
        if [tok.text for tok in corrected] != [tok.text for tok in doc2]:
            continue
        kept.append(doc)
        if len(kept) >= max_docs:
            break
    return kept


def _add_crowd_annotations(doc, sent, annotated_text):
    """Store the crowd spans of one annotated sentence on ``doc``.

    Each whitespace-separated item of ``annotated_text`` that contains "/"
    is parsed as ``start-end/label``. Offsets are character offsets relative
    to ``sent`` and ``end`` is inclusive. The label is upper-cased.

    Matching token spans are written to
    ``doc.user_data["annotations"]["crowd"]`` as
    ``{(start, end): ((LABEL, 1.0),)}``. Items that do not align with token
    boundaries (``doc.char_span`` returns None) are reported and skipped.
    """
    for span in annotated_text.split():
        if "/" not in span:
            continue
        entity = span.split("/")[1].upper()
        start = int(span.split("-")[0])
        end = int(span.split("-")[1].split("/")[0]) + 1
        ent_span = doc.char_span(sent.start_char + start,
                                 sent.start_char + end)
        if ent_span is None:
            print("strange span", sent, span)
            continue
        crowd = doc.user_data["annotations"].setdefault("crowd", {})
        crowd[(ent_span.start, ent_span.end)] = ((entity, 1.0), )


def _sentence_doc(doc, sent):
    """Copy ``sent`` into its own Doc, with its annotations re-indexed.

    Every annotation source of ``doc`` except "crowd_sents" is copied. Only
    spans whose start token lies inside ``sent`` are kept, with their token
    offsets shifted to be relative to the sentence.
    """
    sent_doc = sent.as_doc()
    sent_doc.user_data["annotations"] = {}
    for source, source_spans in doc.user_data["annotations"].items():
        if source == "crowd_sents":
            continue
        kept = sent_doc.user_data["annotations"][source] = {}
        for (start, end), vals in source_spans.items():
            if sent.start <= start < sent.end:
                kept[(start - sent.start, end - sent.start)] = vals
    return sent_doc


def _crowd_entities(doc, vocab):
    """Build the entity spans of ``doc`` from its "crowd" annotations.

    Spans are processed in sorted ``(start, end)`` order. When a span starts
    before the end of the previously added one, the overlap is reported and
    the previous span is replaced by one starting at the previous start.
    The label "DATETIME" is renamed to "DATE".
    """
    import spacy

    crowd = doc.user_data["annotations"]["crowd"]
    spans = []
    for start, end in sorted(crowd):
        for val, conf in crowd[(start, end)]:
            if spans:
                other_start, other_end = spans[-1].start, spans[-1].end
            else:
                other_start, other_end = 0, 0
            if other_end > start:
                print("overlap between", start, end, other_start, other_end)
                spans = spans[:-1]
                # Reassigned start also applies to later labels of this key.
                start = other_start
            label = val if val != "DATETIME" else "DATE"
            spans.append(spacy.tokens.Span(doc, start, end,
                                           vocab.strings[label]))
    return tuple(spans)


def get_crowd_data():
    """Build sentence-level spaCy docs carrying crowd-sourced annotations.

    Up to 1000 filtered documents are read from each of
    ``./data/reuters.docbin`` and ``./data/bloomberg1.docbin`` (see
    ``_read_clean_docs``). The annotations in
    ``data/second_launch_annotations.json`` are then matched to them. Each
    entry's ``source`` ("Bloomberg" or otherwise Reuters) and
    ``source_doc`` index select the document. Its ``original_text`` is
    matched against the document's sentences after stripping.

    For each matching sentence, the crowd spans are added to the document
    under the "crowd" source. A sentence-level copy of the document, with
    its annotations, is collected. Finally, each collected sentence doc
    that has crowd annotations gets its ``ents`` set from them.

    Returns:
        list of the sentence-level docs, in annotation-file order.
    """
    import json

    import spacy

    nlp = spacy.load("en_core_web_md", disable=["tagger", "parser", "ner"])

    reuters_docs = _read_clean_docs("./data/reuters.docbin", nlp)
    bloomberg_docs = _read_clean_docs("./data/bloomberg1.docbin", nlp)
    print("Number of read documents:", len(reuters_docs), len(bloomberg_docs))

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open("data/second_launch_annotations.json", "r",
              encoding="utf-8") as fd:
        dic = json.load(fd)

    crowd_docs = []
    for v in dic.values():
        if v["source"] == "Bloomberg":
            doc = bloomberg_docs[int(v["source_doc"])]
        else:
            doc = reuters_docs[int(v["source_doc"])]
        for sent in doc.sents:
            if sent.text.strip() != v["original_text"].strip():
                continue
            _add_crowd_annotations(doc, sent, v["annotated_text"])
            crowd_docs.append(_sentence_doc(doc, sent))

    for doc in crowd_docs:
        if "crowd" in doc.user_data["annotations"]:
            doc.ents = _crowd_entities(doc, nlp.vocab)

    return crowd_docs