def walk(subtree, posinfo=PosInfo(text=0, edu=0)):
    """
    Walk down the first-cut tree, keeping track of text and EDU
    positions, and build a fancier tree along the way.

    Returns a triple: the converted (sub)tree, the position info
    after consuming it, and its nuclearity marker (``"leaf"`` for
    leaves).

    Raises RSTTreeException if an internal node label does not match
    the expected nuclearity/relation pattern.
    """
    if not isinstance(subtree, Tree):
        # leaf: a piece of text that becomes one EDU
        leaf_start = posinfo.text
        leaf_end = leaf_start + len(subtree)
        after_leaf = PosInfo(text=leaf_end, edu=posinfo.edu + 1)
        return (EDU(posinfo.edu, Span(leaf_start, leaf_end), subtree),
                after_leaf,
                "leaf")

    # internal node: remember where we started, then consume the kids
    start = copy.copy(posinfo)
    children = []
    nuc_kids = []
    for kid in subtree:
        converted, posinfo, kid_nuc = walk(kid, posinfo)
        children.append(converted)
        nuc_kids.append(kid_nuc)
    nuclearity = ''.join(nuc_kids)

    match = _lw_type_re.match(treenode(subtree))
    if not match:
        raise RSTTreeException("Missing nuclearity annotation in " +
                               str(subtree))
    nuc = match.group("nuc")
    rel = match.group("rel") or "leaf"
    # posinfo now points just past the last EDU/character we consumed
    node = Node(nuclearity,
                (start.edu, posinfo.edu - 1),
                Span(start.text, posinfo.text),
                rel)
    return SimpleRSTTree(node, children), posinfo, nuc
def _postprocess(tree, start=0, edu_start=1):
    """
    Helper function: Convert the NLTK-parsed representation of an RST tree
    to one using educe-style Standoff objects

    Parameters
    ----------
    tree : Tree or string
        Internal node, or leaf text of the form ``[...]``.
    start : int
        Character offset at which this (sub)tree starts.
    edu_start : int
        EDU number handed to `_parse_edu` when `tree` is a leaf.

    Returns
    -------
    An `RSTTree` for internal nodes, or whatever `_parse_edu` returns
    for leaves.

    Raises
    ------
    RSTTreeException
        If a leaf does not start with ``[``.
    """
    if isinstance(tree, Tree):
        children = []
        position = start - 1  # compensate for virtual whitespace added below
        node = _parse_node(treenode(tree), Span(-1, -1))
        edu_start2 = node.edu_span[0]

        for child_ in tree:
            # (NB: +1 to add virtual whitespace between EDUs)
            child = _postprocess(child_, position + 1, edu_start2)
            children.append(child)
            # pylint: disable=E1101
            child_sp = _tree_span(child)
            # pylint: enable=E1101
            position = child_sp.char_end

        # the node covers everything from its start to its last child's end
        node.span = Span(start, position)
        return RSTTree(node, children)
    else:
        if tree.startswith("["):
            # strip the surrounding brackets
            return _parse_edu(tree[1:-1], edu_start, start)
        else:
            # report the offending leaf itself; `child` is not bound in
            # this branch, so naming it here would raise a NameError
            raise RSTTreeException("ERROR in rst tree format for leaf : ",
                                   tree)
def assertOverlap(self, expected, pair1, pair2, **kwargs):
    """
    Check that the spans built from `pair1` and `pair2` overlap, and
    that the overlap equals the span built from `expected` (all three
    arguments are (start, end) pairs; `kwargs` go to `Span.overlaps`)
    """
    left_start, left_end = pair1
    right_start, right_end = pair2
    want_start, want_end = expected
    left = Span(left_start, left_end)
    right = Span(right_start, right_end)
    overlap = left.overlaps(right, **kwargs)
    self.assertTrue(overlap)
    self.assertEqual(Span(want_start, want_end), overlap)
def split_doc(doc, middle):
    """
    Break a document into two pieces at a given split point.

    A negative split point is counted from the end of the document
    (so -1 means the very end of the text).

    NOTE(review): earlier documentation mentioned a `None` split point
    meaning "the whole document"; no special handling for `None` is
    visible in this function -- confirm against callers.

    Raise an exception if there are any annotations that span the point.

    Parameters
    ----------
    doc : Document
        The document we want to split.
    middle : int
        Split point.

    Returns
    -------
    doc_prefix : Document
        Result of `narrow_to_span` on `doc` for span [:middle]
    doc_suffix : Document
        Result of `narrow_to_span` on `doc` for span [middle:]

    Raises
    ------
    StacDocException
        If some annotation straddles the split point.
    """
    doc_len = doc.text_span().char_end
    if middle < 0:
        middle = doc_len + 1 + middle

    def straddles(point, span):
        """
        True if the point is strictly inside the span (sitting on
        either edge doesn't count, so this is not the same thing as
        enclosure)
        """
        if span is None:
            return False
        return span.char_start < point < span.char_end

    offenders = []
    for anno in doc.annotations():
        if straddles(middle, anno.text_span()):
            offenders.append(anno)

    if offenders:
        template = ("Can't split document [{origin}] at {middle} because it is "
                    "straddled by the following annotations:\n"
                    "{annotations}\n"
                    "Either split at a different place or remove the annotations")
        listing = '\n'.join(' * %s %s' % (x.text_span(), x)
                            for x in offenders)
        raise StacDocException(template.format(origin=doc.origin,
                                               middle=middle,
                                               annotations=listing))

    head = Span(0, middle)
    tail = Span(middle, doc_len)
    return narrow_to_span(doc, head), narrow_to_span(doc, tail)
def test_simple_align(self):
    "trivial token realignment"
    tokens = ["a", "bb", "ccc"]
    # the run of four spaces before "ccc" is what puts it at offset 8,
    # as the expected spans below require
    text = "a bb    ccc"
    spans = list(generic_token_spans(text, tokens))
    expected = [Span(0, 1), Span(2, 4), Span(8, 11)]
    # assertEqual: `assertEquals` is a deprecated alias
    self.assertEqual(expected, spans)
def test_messy_align(self):
    "ignore whitespace in token"
    # whitespace inside the tokens must not affect the alignment
    tokens = ["a", "b b", "c c c"]
    # the run of four spaces before "ccc" is what puts it at offset 8,
    # as the expected spans below require
    text = "a bb    ccc"
    spans = list(generic_token_spans(text, tokens))
    expected = [Span(0, 1), Span(2, 4), Span(8, 11)]
    # assertEqual: `assertEquals` is a deprecated alias
    self.assertEqual(expected, spans)
def _actually_split(tcache, doc, dialogue, turn):
    """Split the dialogue before the given turn.

    The original dialogue object is kept for the first half; a deep
    copy of it (with its features reset to an empty dict) is appended
    to `doc.units` for the second half.
    """
    whole = dialogue.text_span()
    # the cut sits one character before the turn's start
    cut = turn.text_span().char_start - 1
    first_half = Span(whole.char_start, cut)
    second_half = Span(cut, whole.char_end)

    # copy before touching the original
    new_dialogue = copy.deepcopy(dialogue)
    _set(tcache, first_half, dialogue)
    _set(tcache, second_half, new_dialogue)
    doc.units.append(new_dialogue)
    new_dialogue.features = {}
def generic_token_spans(text, tokens, offset=0, txtfn=None):
    """
    Given a string and a sequence of substrings within that string,
    infer a span for each of the substrings.

    We infer these spans by walking the text as we consume substrings,
    skipping over any whitespace (including that which is within the
    tokens). For this to work, the substring sequence must be
    identical to the text modulo whitespace.

    Spans are relative to the start of the string itself, but can be
    shifted by passing an offset (the start of the original string's
    span). Empty tokens are accepted but have a zero-length span.

    Note: this function is lazy so you can use it incrementally
    provided you can generate the tokens lazily too

    You probably want `token_spans` instead; this function is meant
    to be used for similar tasks outside of pos tagging

    :param txtfn: function to extract text from a token (default None,
                  treated as identity function)

    Raises EducePosTagException if the text runs out before the tokens
    do (including part-way through a token), or if a token's
    non-whitespace characters differ from those of the text.
    """
    # (index, char) pairs for the non-whitespace characters of the text
    txt_iter = ifilterfalse(lambda x: x[1].isspace(), enumerate(text))
    txtfn = txtfn or (lambda x: x)
    last = offset  # for corner case of empty tokens
    for token in tokens:
        tok_chars = list(ifilterfalse(lambda x: x.isspace(), txtfn(token)))
        if not tok_chars:
            yield Span(last, last)
            continue
        prefix = list(islice(txt_iter, len(tok_chars)))
        if not prefix:
            msg = "Too many tokens (current: %s)" % txtfn(token)
            raise EducePosTagException(msg)
        first_idx = prefix[0][0]
        end_idx = prefix[-1][0] + 1
        last = end_idx + offset
        span = Span(first_idx + offset, last)
        # slice with text-relative indices: the span itself is shifted
        # by `offset`, which would point at the wrong part of `text`
        pretty_prefix = text[first_idx:end_idx]
        if len(prefix) < len(tok_chars):
            # the text ran out part-way through this token; zip() below
            # would silently ignore the token's remaining characters
            msg = "text ends before token is complete\n"\
                + " token: [%s]\n" % token\
                + " text: [%s]" % pretty_prefix
            raise EducePosTagException(msg)
        # check the text prefix to make sure we have the same
        # non-whitespace characters
        for txt_pair, tok_char in zip(prefix, tok_chars):
            idx, txt_char = txt_pair
            if txt_char != tok_char:
                msg = "token mismatch at char %d (%s vs %s)\n"\
                    % (idx, txt_char, tok_char)\
                    + " token: [%s]\n" % token\
                    + " text: [%s]" % pretty_prefix
                raise EducePosTagException(msg)
        yield span
def _load_rst_wsj_corpus_text_file_wsj(f):
    """Read the text from `f` and segment it into sentences and
    paragraphs.

    Returns a triple: the raw text, the list of all Sentence objects,
    and the list of Paragraph objects (each holding its sentences).
    Sentence spans exclude trailing whitespace; chunks that are empty
    once right-stripped yield no sentence.
    """
    text = f.read()
    all_sents = []
    all_paras = []
    cursor = 0  # offset of the current chunk within the text
    next_sent_id = 0
    for para_id, paragraph in enumerate(text.split(WSJ_SEP_PARA)):
        in_para = []
        for chunk in paragraph.split(WSJ_SEP_SENT):
            # length of the chunk without its trailing whitespace
            kept = len(chunk.rstrip())
            if kept > 0:
                in_para.append(Sentence(next_sent_id,
                                        Span(cursor, cursor + kept)))
                next_sent_id += 1
            # skip the whole chunk, + 1 for len(WSJ_SEP_SENT)
            cursor += len(chunk) + 1
        all_paras.append(Paragraph(para_id, in_para))
        all_sents.extend(in_para)
        cursor += 2  # whitespace and second newline
    return text, all_sents, all_paras
class Token(RawToken, Standoff):
    """
    A part-of-speech tagged token together with the character span it
    occupies.
    """

    # word, tag and span used for the special left padding Token
    _lpad_word = '__START__'
    _lpad_tag = '__START__'
    _lpad_span = Span(0, 0)

    def __init__(self, tok, span):
        # take word and tag over from the raw token, then add the span
        RawToken.__init__(self, tok.word, tok.tag)
        Standoff.__init__(self)
        self.span = span

    def __str__(self):
        raw = RawToken.__str__(self)
        return '%s\t%s' % (raw, self.span)

    def __unicode__(self):
        raw = RawToken.__unicode__(self)
        return '%s\t%s' % (raw, self.span)

    @classmethod
    def left_padding(cls):
        "Return a special Token for left padding"
        raw_pad = RawToken(cls._lpad_word, cls._lpad_tag)
        return Token(raw_pad, cls._lpad_span)
def token_spans(text, tokens, offset=0):
    """
    Given a string and a sequence of RawToken representing tokens in
    that string, work out the span of each token and return the
    results as a sequence of Token objects.

    The spans are inferred by `generic_token_spans`, which walks the
    text as it consumes tokens and skips over whitespace. For this to
    work, the raw token text must be identical to the text modulo
    whitespace.

    Spans are relative to the start of the string itself, but can be
    shifted by passing an offset (the start of the original string's
    span)
    """
    words = [tok.word for tok in tokens]
    res = []
    for raw_tok, tok_span in zip(tokens,
                                 generic_token_spans(text, words, offset)):
        res.append(Token(raw_tok, tok_span))

    # sanity checks that should be moved to tests
    for orig_tok, new_tok in zip(tokens, res):
        # undo the offset to get back to text-relative positions
        local = Span(new_tok.span.char_start - offset,
                     new_tok.span.char_end - offset)
        assert text[local.char_start:local.char_end] == new_tok.word
        assert orig_tok.word == new_tok.word
        assert orig_tok.tag == new_tok.tag
    return res
def shift_span(span, updates, stretch_right=False):
    """
    Return a copy of `span` whose boundaries have been shifted to
    reflect `updates`.

    Parameters
    ----------
    span: Span
    updates: Updates
    stretch_right : boolean, optional
        If True, stretch the right boundary of an annotation that buts up
        against the left of a new annotation. This is recommended for
        annotations that should fully cover a given span, like dialogues
        for documents.

    Returns
    -------
    span: Span

    See also
    --------
    shift_char: for details on how this works
    """
    new_start = shift_char(span.char_start, updates)
    # without stretching, shift the last character inside the span
    # rather than the right boundary itself: this avoids spurious
    # overstretching of the right boundary of an annotation that buts
    # up against the left of a new annotation
    new_end = (shift_char(span.char_end, updates) if stretch_right
               else 1 + shift_char(span.char_end - 1, updates))
    return Span(new_start, new_end)
def approximate_cover(elts, tgt):
    """Returns True if elts covers tgt's span.

    This is approximate because we only check that:
    * the first and last elements respectively begin and end at the
      extremities of tgt's span (via `span_eq` with `eps=1`),
    * consecutive elements don't overlap.

    Because of the second item, we assume that elts has been sorted by
    span.

    Parameters
    ----------
    elts : sorted list of Annotation
        Sequence of elements
    tgt : Annotation
        Target annotation

    Returns
    -------
    res : boolean
        True if elts approximately cover tgt's span; False for an
        empty `elts`.
    """
    if not elts:
        # nothing can't cover anything (and has no first/last element)
        return False
    span_seq = Span(elts[0].span.char_start, elts[-1].span.char_end)
    res = (span_eq(span_seq, tgt.text_span(), eps=1) and
           all(elt_cur.overlaps(elt_nxt) is None
               for elt_cur, elt_nxt in zip(elts[:-1], elts[1:])))
    return res
def main(args):
    """
    Subcommand main.

    For every document in the corpus: split the EDU designated by
    `args.spans` in a copy of the document, print a small diff to
    stderr and save the copy. Unless disabled, a commit message for
    the last document processed is printed at the end.

    You shouldn't need to call this yourself if you're using
    `config_argparser`
    """
    corpus = read_corpus_with_unannotated(args)
    tcache = TimestampCache()
    output_dir = get_output_dir(args, default_overwrite=True)

    commit_info = None
    for k in corpus:
        before = corpus[k]
        after = copy.deepcopy(before)
        span = Span.merge_all(args.spans)
        _split_edu(tcache, k, after, args.spans)

        report = "\n".join(_mini_diff(k, before, after, span))
        print(report.encode('utf-8'), file=sys.stderr)
        save_document(output_dir, k, after)

        # for commit message generation
        commit_info = CommitInfo(key=k,
                                 annotator=args.annotator,
                                 before=before,
                                 after=after,
                                 span=span)

    if commit_info and not args.no_commit_msg:
        print("-----8<------")
        print(commit_msg(commit_info))
    announce_output_dir(output_dir)
def _dialogues_in_turns(corpus, turn1, turn2): """ Given a pair of turns """ # grab a document from the set (assumption here is that # they are all morally the same doc) if not corpus.values(): sys.exit("No documents selected") doc = corpus.values()[0] starting_turn = get_turn(turn1, doc) ending_turn = get_turn(turn2, doc) # there's a bit of fuzz for whitespace before/after the # turns span = Span(starting_turn.text_span().char_start - 1, ending_turn.text_span().char_end + 1) def is_in_range(anno): """ If the annotation is a dialogue that is covered by the turns in question """ return is_dialogue(anno) and span.encloses(anno.span) return [ anno_id_to_tuple(x.local_id()) for x in doc.annotations() if is_in_range(x) ]
def merge_turn_stars(doc):
    """Return a copy of the document in which consecutive turns
    by the same speaker have been merged.

    Within each dialogue, turns are sorted by span and grouped by
    speaker; the first turn of each group has its span stretched over
    the whole group and the other turns are removed from the copy.

    Additionally turn prefix text (containing turn numbers and
    speakers) from the removed turns are stripped out.
    """
    merged_doc = copy.deepcopy(doc)

    def by_span(anno):
        "sort key: the annotation's text span"
        return anno.text_span()

    def prefix_span(turn):
        "given a turn annotation, return the span of its prefix"
        prefix, _ = split_turn_text(merged_doc.text(turn.text_span()))
        start = turn.text_span().char_start
        return start, start + len(prefix)

    dialogues = sorted((x for x in merged_doc.units if is_dialogue(x)),
                       key=by_span)
    rejects = []  # the "deleted" turns, whose prefixes get blanked out
    for dia in dialogues:
        dia_turns = sorted(turns_in_span(merged_doc, dia.text_span()),
                           key=by_span)
        for _, group in itr.groupby(dia_turns, anno_speaker):
            group = list(group)
            keeper, followers = group[0], group[1:]
            keeper.span = Span.merge_all(x.text_span() for x in group)
            rejects.extend(followers)
            for anno in followers:
                merged_doc.units.remove(anno)

    # pylint: disable=protected-access
    merged_doc._text = _blank_out(merged_doc._text,
                                  [prefix_span(x) for x in rejects])
    # pylint: enable=protected-access
    return merged_doc
def _mk_token(ttoken, span):
    """
    Build a proper Token object from a tweaked token and the span it
    has been aligned with; a non-zero token offset moves the start of
    the span rightwards by that amount.
    """
    if ttoken.offset == 0:
        return Token(ttoken, span)
    adjusted = Span(span.char_start + ttoken.offset, span.char_end)
    return Token(ttoken, adjusted)
def anno(doc, prefix, tspan):
    """
    Return `prefix` followed by the annotated text running from
    `info.span`'s start to `tspan`'s end; an ellipsis is inserted in
    between if `tspan`'s start plus the prefix length falls short of
    `info.span`'s start.

    (`info` is taken from the enclosing scope.)
    """
    if tspan.char_start + len(prefix) < info.span.char_start:
        padding = "..."
    else:
        padding = ""
    shown = Span(info.span.char_start, tspan.char_end)
    return "".join([prefix, padding, annotate_doc(doc, span=shown)])
def compute_updates(src_doc, tgt_doc, matches):
    """Return updates that would need to be made on the target
    document.

    Given matches between the source and target document, return span
    updates along with any source annotations that do not have an
    equivalent in the target document (the latter may indicate that
    resegmentation has taken place, or that there is some kind of problem)

    Parameters
    ----------
    src_doc : Document
    tgt_doc : Document
    matches : [Match]
        Triples (source offset, target offset, size).

    Returns
    -------
    updates: Updates
    """
    updates = Updates()

    # start by assuming every unit is unmatched (case 2 and 5);
    # the units we manage to account for are pruned below
    updates.expected_src_only.extend(src_doc.units)
    updates.abnormal_tgt_only.extend(tgt_doc.units)

    # case 1, 2 and 4
    for src, tgt, size in matches:
        tgt_to_src = src - tgt
        updates.shift_if_ge[tgt] = tgt_to_src
        # case 1 and 2: units that sit within the matched stretches
        src_annos = enclosed(Span(src, src + size), src_doc.units)
        tgt_annos = enclosed(Span(tgt, tgt + size), tgt_doc.units)
        for src_anno in src_annos:
            # prune from case 5
            updates.expected_src_only.remove(src_anno)
            wanted = src_anno.text_span()
            # target units that land exactly on this source unit
            # once shifted
            counterparts = [x for x in tgt_annos
                            if x.text_span().shift(tgt_to_src) == wanted]
            if not counterparts:
                # case 4
                updates.abnormal_src_only.append(src_anno)
                continue
            for tgt_anno in counterparts:
                # prune from case 2
                if tgt_anno in updates.abnormal_tgt_only:
                    updates.abnormal_tgt_only.remove(tgt_anno)
    return updates
def __init__(self, node, children, origin=None):
    """
    Build a tree whose span stretches from the leftmost start to the
    rightmost end of its children's spans.

    Raises an Exception if `children` is empty.
    """
    SearchableTree.__init__(self, node, children)
    Standoff.__init__(self, origin)
    if not children:
        raise Exception("Can't create a tree with no children")
    self.children = children
    starts = [kid.span.char_start for kid in children]
    ends = [kid.span.char_end for kid in children]
    self.span = Span(min(starts), max(ends))
def __init__(self, t, offset, origin=None):
    """
    Build a token from a dictionary `t` with (at least) the keys
    'extent', 'word', 'POS' and 's_id'; `offset` is used to shift
    the token's span.

    The remaining entries of `t` are kept as the token's features.

    NOTE(review): `origin` is accepted but not used in this body.
    """
    extent, word, tag = t['extent'], t['word'], t['POS']
    first, last = extent[0], extent[1]
    # the span ends one past the second element of the extent
    span = Span(first, last + 1).shift(offset)
    postag.Token.__init__(self, postag.RawToken(word, tag), span)
    # everything not already stored on the token becomes a feature
    self.features = copy.copy(t)
    for key in ('s_id', 'word', 'extent', 'POS'):
        del self.features[key]
def _enclosing_turn_span(doc, span):
    """
    Merge the given span with the spans of all turn annotations in
    the document that enclose it. If there are no such turns, the
    result is the merge of the span on its own.
    """
    spans = [span]
    for unit in doc.units:
        if educe.stac.is_turn(unit) and unit.text_span().encloses(span):
            spans.append(unit.text_span())
    return Span.merge_all(spans)
def __init__(self, node, children, link, origin=None):
    """
    Build a tree with an associated `link`.

    The span stretches over the spans of the children and, unless
    this tree is a root, of the tree's own label.
    """
    SearchableTree.__init__(self, node, children)
    Standoff.__init__(self, origin)
    # work on a copy: appending the label below must not modify the
    # list the caller handed to us
    nodes = list(children)
    if not self.is_root():
        nodes.append(self.label())
    start = min(x.span.char_start for x in nodes)
    end = max(x.span.char_end for x in nodes)
    self.link = link
    self.span = Span(start, end)
    self.origin = origin
def _recompute_spans(tree, context):
    """
    Recalculate tree node spans from the bottom up, setting each
    node's context along the way
    (helper for _align_with_context)

    Anything that is not a Tree is left untouched.
    """
    if not isinstance(tree, Tree):
        return
    child_spans = []
    for child in tree:
        # children first, so that their spans are up to date
        _recompute_spans(child, context)
        child_spans.append(_tree_span(child))
    node = treenode(tree)
    node.span = Span.merge_all(child_spans)
    treenode(tree).context = context
def _split_edu(tcache, k, doc, spans):
    """
    Look for the EDU whose span is exactly the merge of the given
    spans and split it.

    If there is no such EDU: in the discourse stage fall back to
    `_tweak_presplit`, otherwise just report it on stderr.
    """
    # seek edu
    big_span = Span.merge_all(spans)
    matches = [unit for unit in doc.units
               if unit.text_span() == big_span and educe.stac.is_edu(unit)]
    if matches:
        _actually_split(tcache, doc, spans, matches[0])
    elif k.stage != 'discourse':
        print("No matches found in %s" % k, file=sys.stderr)
    else:
        _tweak_presplit(tcache, doc, spans)
def _split_edu(tcache, k, doc, spans):
    """
    Split the EDU that is covered exactly by the given spans.

    When no EDU has that span, the discourse stage gets
    `_tweak_presplit` instead; any other stage gets a message on
    stderr and nothing else.
    """
    # seek edu
    target = Span.merge_all(spans)
    matches = []
    for unit in doc.units:
        if unit.text_span() == target and educe.stac.is_edu(unit):
            matches.append(unit)

    if matches:
        _actually_split(tcache, doc, spans, matches[0])
        return
    if k.stage == 'discourse':
        _tweak_presplit(tcache, doc, spans)
    else:
        print("No matches found in %s" % k, file=sys.stderr)
def main(args):
    """Subcommand main.

    Replace the text at `args.span` with `args.sub_text` in every
    requested document, print a diff to stderr and save the result.
    Unless disabled, a commit message based on the first document is
    printed at the end; nothing is printed if no document was
    selected.

    You shouldn't need to call this yourself if you're using
    `config_argparser`.
    """
    output_dir = get_output_dir(args, default_overwrite=True)

    # locate insertion site: target document
    reader = educe.stac.Reader(args.corpus)
    tgt_files = reader.filter(reader.files(), is_requested(args))
    tgt_corpus = reader.slurp(tgt_files)

    # TODO mark units with FIXME, optionally delete in/out relations
    span = args.span
    sub_text = args.sub_text
    minor = args.minor

    # store before/after
    annos_before = []
    annos_after = []

    for tgt_k, tgt_doc in tgt_corpus.items():
        annos_before.append(annotate_doc(tgt_doc, span=span))
        # process
        new_tgt_doc = replace_text_at_span(tgt_doc, span, sub_text,
                                           minor=minor)
        # WIP new_span, depends on the offset
        # (the replacement text may be longer or shorter than the
        # text it replaces)
        offset = len(sub_text) - (span.char_end - span.char_start)
        new_span = Span(span.char_start, span.char_end + offset)
        # end WIP
        annos_after.append(annotate_doc(new_tgt_doc, span=new_span))

        # show diff and save doc
        diffs = ["======= REPLACE TEXT IN %s ========" % tgt_k,
                 show_diff(tgt_doc, new_tgt_doc)]
        print("\n".join(diffs).encode('utf-8'), file=sys.stderr)
        save_document(output_dir, tgt_k, new_tgt_doc)
    announce_output_dir(output_dir)

    # commit message
    if not annos_before:
        # no document was processed: there is no first document to
        # build a commit message from
        return
    tgt_k, _ = list(tgt_corpus.items())[0]
    anno_str_before = annos_before[0]
    anno_str_after = annos_after[0]
    if tgt_k and not args.no_commit_msg:
        print("-----8<------")
        print(commit_msg(tgt_k, anno_str_before, anno_str_after))
def _actually_merge(tcache, edus, doc):
    """
    Given a timestamp cache, a document and a collection of edus,
    replace the edus with a single merged edu in the document

    The merged EDU is a copy of the first one, stretched over the
    spans of all of them and stamped with a date and author. In the
    units stage, its type and features combine those of the
    originals.

    The first of the EDUs found in the document is replaced in place,
    the others are removed, and `retarget` is called for each original
    EDU with the merged one as its replacement.

    Does nothing if `edus` is empty.
    """
    def one_or_join(strs):
        "Return element if singleton, otherwise moosh together"
        present = [x for x in strs if x is not None]
        if len(present) == 1:
            return present[0]
        return _MERGE_PREFIX + "/".join(present)

    if not edus:
        return

    merged = copy.deepcopy(edus[0])
    merged.span = Span.merge_all(x.text_span() for x in edus)
    stamp = tcache.get(merged.span)
    set_anno_date(merged, stamp)
    set_anno_author(merged, _AUTHOR)

    if doc.origin.stage == 'units':
        merged.type = one_or_join(frozenset(x.type for x in edus))
        # every feature key used by at least one of the edus
        all_keys = frozenset(key for edu in edus
                             for key in edu.features.keys())
        for key in all_keys:
            values = frozenset(x.features.get(key) for x in edus)
            merged.features[key] = one_or_join(values)

    # in-place replacement: the merged edu takes the slot of the
    # first of the old ones that we come across
    for idx, unit in enumerate(doc.units):
        if unit in edus:
            doc.units[idx] = merged
            break

    for edu in edus:
        if edu in doc.units:
            doc.units.remove(edu)
        retarget(doc, edu.local_id(), merged)
def shift_anno(anno, offset, point):
    """Get a shifted copy of an annotation

    Only `Unit` annotations are affected; anything else is returned
    as a plain deep copy.
    """
    shifted = copy.deepcopy(anno)
    if isinstance(anno, Unit):
        span = shifted.text_span()
        if span.char_start >= point:
            # the annotation starts at or after the point: move the
            # whole span
            shifted.span = span.shift(offset)
        elif span.char_end >= point:
            # the annotation starts before the point but reaches it:
            # stretch (only the end moves)
            shifted.span = Span(span.char_start, span.char_end + offset)
    return shifted
def _merge_edus(tcache, span, doc):
    """
    Merge the EDUs that `edus_in_span` finds for the given span into
    a single one.

    Taken together, the EDUs should stretch from the beginning to the
    end of the span (gaps OK); the program exits if they do not, or if
    there are no EDUs at all.

    The output EDU should have the same ID in all documents
    """
    edus = edus_in_span(doc, span)
    if not edus:
        sys.exit("No EDUs in span %s" % span)

    covered = Span.merge_all([edu.text_span() for edu in edus])
    if covered != span:
        sys.exit("EDUs in do not cover full span %s [only %s]" %
                 (span, covered))
    _actually_merge(tcache, edus, doc)
def __init__(self, t, offset, origin=None):
    """
    Parameters
    ----------
    t : dict
        Token from corenlp's XML output; must provide the keys
        'extent', 'word', 'POS' and 's_id'.
    offset : int
        Offset from the span of the corenlp token to the document.
    origin : FileId, optional
        Identifier for the document.
        NOTE(review): not used in this body.
    """
    extent = t['extent']
    word = t['word']
    tag = t['POS']
    # the span ends one past the second element of the extent
    local_span = Span(extent[0], extent[1] + 1)
    doc_span = local_span.shift(offset)
    raw_token = postag.RawToken(word, tag)
    postag.Token.__init__(self, raw_token, doc_span)
    # whatever is not already stored on the token becomes a feature
    self.features = copy.copy(t)
    for stored in ('s_id', 'word', 'extent', 'POS'):
        del self.features[stored]
def _actually_split(tcache, doc, spans, edu):
    """
    Replace the EDU by one new EDU per span (in sorted span order),
    each a copy of the original with an id built from the author and
    the timestamp the cache gives for its span -- the idea being to
    generate the same new ID for the same new EDU across all
    sections. In the units stage, the type and features of the new
    EDUs are marked with the split prefix.

    Discourse stage: If the EDU is in any relations or CDUs, replace
    any references to it with a new CDU encompassing the newly created
    EDUs
    """
    new_edus = {}
    for span in sorted(spans):
        stamp = tcache.get(span)
        piece = copy.deepcopy(edu)
        piece_id = anno_id_from_tuple((_AUTHOR, stamp))
        set_anno_date(piece, stamp)
        set_anno_author(piece, _AUTHOR)
        if doc.origin.stage == 'units':
            piece.type = _SPLIT_PREFIX + piece.type
            for key, value in list(piece.features.items()):
                piece.features[key] = _SPLIT_PREFIX + value
        new_edus[piece_id] = piece
        piece.span = span
        doc.units.append(piece)

    # a CDU over all the pieces, to stand in for the original EDU
    cdu_stamp = tcache.get(Span.merge_all(spans))
    cdu_id = anno_id_from_tuple((_AUTHOR, cdu_stamp))
    cdu_metadata = {'author': _AUTHOR,
                    'creation-date': str(cdu_stamp)}
    cdu = educe.annotation.Schema(cdu_id,
                                  frozenset(new_edus),
                                  frozenset(),
                                  frozenset(),
                                  'Complex_discourse_unit',
                                  {},
                                  metadata=cdu_metadata)
    cdu.fleshout(new_edus)

    want_cdu = retarget(doc, edu.local_id(), cdu)
    doc.units.remove(edu)
    # only keep the CDU if retarget says so
    if want_cdu:
        doc.schemas.append(cdu)
def _nudge_down(turn, dialogue, prev_turn, next_dialogue): """ Move last turn to next dialogue. (ie. shorten the right boundary of this dialogue and extend the left boundary of this dialogue) Return encompassing span to show what we've changed """ if not prev_turn: sys.exit("Can't move very first turn. " "Try `stac-util merge-dialogue` instead") elif not next_dialogue: sys.exit("Can't move from last dialogue." "Try `stac-util move` instead") elif turn.span.char_end != dialogue.span.char_end: sys.exit("Turn %d %s is not at the end of its dialogue %s" % (st.turn_id(turn), turn.span, dialogue.span)) offset = prev_turn.span.char_end - turn.span.char_end # take both dialogue boundaries down a bit (to next turn end) next_dialogue.span.char_start += offset dialogue.span.char_end += offset return Span.merge_all([dialogue.span, next_dialogue.span])
def _nudge_up(turn, dialogue, next_turn, prev_dialogue): """ Move first turn to previous dialogue (ie. extend the previous dialogue to incorporate this turn, and push this dialogue to exclude it) Return encompassing span to show what we've changed """ if not next_turn: sys.exit("Can't move very last turn. " "Try `stac-util merge-dialogue` instead") elif not prev_dialogue: sys.exit("Can't move from first dialogue." "Try `stac-util move` instead") elif turn.span.char_start - 1 != dialogue.span.char_start: sys.exit("Turn %d %s is not at the start of its dialogue %s" % (st.turn_id(turn), turn.span, dialogue.span)) offset = next_turn.span.char_start - turn.span.char_start # take both dialogue boundaries up a bit (to prev turn end) prev_dialogue.span.char_end += offset dialogue.span.char_start += offset return Span.merge_all([prev_dialogue.span, dialogue.span])