def read_deps(corpus, section='all', nary_enc='chain', rew_pseudo_rels=False,
              mrg_same_units=False):
    """Collect dependencies from the corpus.

    Parameters
    ----------
    corpus : dict from str to dict from FileId to RSTTree
        Corpus of RST c-trees indexed by {'train', 'test'} then FileId.
    section : str, one of {'train', 'test', 'all'}
        Section of interest in the RST-DT.
    nary_enc : str, one of {'tree', 'chain'}
        Encoding of n-ary relations used in the c-to-d conversion.
    rew_pseudo_rels : boolean, defaults to False
        If True, rewrite pseudo-relations; see
        `educe.rst_dt.pseudo_relations`.
    mrg_same_units : boolean, defaults to False
        If True, merge fragmented EDUs; see
        `educe.rst_dt.pseudo_relations`.

    Returns
    -------
    edu_df : pandas.DataFrame
        Table of EDUs read from the corpus.
    dep_df : pandas.DataFrame
        Table of dependencies read from the corpus.
    """
    # restrict to the requested section(s)
    if section != 'all':
        corpus = {section: corpus[section]}
    # experimental: rewrite pseudo-relations
    if rew_pseudo_rels:
        for sec_name, sec_corpus in corpus.items():
            corpus[sec_name] = {
                doc_id: rewrite_pseudo_rels(doc_id, rst_ctree)
                for doc_id, rst_ctree in sec_corpus.items()
            }
    if mrg_same_units:
        for sec_name, sec_corpus in corpus.items():
            corpus[sec_name] = {
                doc_id: merge_same_units(doc_id, rst_ctree)
                for doc_id, rst_ctree in sec_corpus.items()
            }
    # convert to d-trees, collect dependencies
    edus = []
    deps = []
    for sec_name, sec_corpus in corpus.items():
        for doc_id, rst_ctree in sorted(sec_corpus.items()):
            doc_name = doc_id.doc
            doc_text = rst_ctree.text()
            # DIRTY infer (approximate) sentence and paragraph indices
            # from newlines in the text (\n and \n\n)
            sent_idx = 0
            para_idx = 0
            # end DIRTY
            rst_dtree = RstDepTree.from_rst_tree(rst_ctree, nary_enc=nary_enc)
            for dep_idx, (edu, hd_idx, lbl, nuc, hd_order) in enumerate(
                    zip(rst_dtree.edus[1:], rst_dtree.heads[1:],
                        rst_dtree.labels[1:], rst_dtree.nucs[1:],
                        rst_dtree.ranks[1:]),
                    start=1):
                char_beg = edu.span.char_start
                char_end = edu.span.char_end
                edus.append((sec_name, doc_name, dep_idx, char_beg, char_end,
                             sent_idx, para_idx))
                deps.append((doc_name, dep_idx, hd_idx, lbl, nuc, hd_order))
                # DIRTY search for paragraph or sentence breaks in the
                # text of the EDU *plus the next three characters* (yerk)
                edu_txt_plus = doc_text[char_beg:char_end + 3]
                if '\n\n' in edu_txt_plus:
                    para_idx += 1
                    sent_idx += 1  # sometimes wrong; to be fixed
                elif '\n' in edu_txt_plus:
                    sent_idx += 1
                # end DIRTY
    # turn into DataFrames
    edu_df = pd.DataFrame(edus, columns=[
        'section', 'doc_name', 'dep_idx', 'char_beg', 'char_end',
        'sent_idx', 'para_idx'
    ])
    dep_df = pd.DataFrame(
        deps,
        columns=['doc_name', 'dep_idx', 'hd_idx', 'rel', 'nuc', 'hd_order'])
    # additional columns
    # * attachment length in EDUs
    dep_df['len_edu'] = dep_df['dep_idx'] - dep_df['hd_idx']
    dep_df['len_edu_abs'] = abs(dep_df['len_edu'])
    # * attachment length, in sentences and paragraphs
    if False:
        # TODO rewrite in a pandas-ic manner; my previous attempts have
        # failed but I think I got pretty close
        # NB: the current implementation is *extremely* slow: 155 seconds
        # on my laptop for the RST-DT, just for this (minor) computation
        len_sent = []
        len_para = []
        for _, row in dep_df[['doc_name', 'dep_idx', 'hd_idx']].iterrows():
            edu_dep = edu_df[(edu_df['doc_name'] == row['doc_name']) &
                             (edu_df['dep_idx'] == row['dep_idx'])]
            if row['hd_idx'] == 0:
                # {sent,para}_idx + 1 for dependents of the fake root
                lsent = edu_dep['sent_idx'].values[0] + 1
                lpara = edu_dep['para_idx'].values[0] + 1
            else:
                edu_hd = edu_df[(edu_df['doc_name'] == row['doc_name']) &
                                (edu_df['dep_idx'] == row['hd_idx'])]
                lsent = (edu_dep['sent_idx'].values[0] -
                         edu_hd['sent_idx'].values[0])
                lpara = (edu_dep['para_idx'].values[0] -
                         edu_hd['para_idx'].values[0])
            len_sent.append(lsent)
            len_para.append(lpara)
        dep_df['len_sent'] = pd.Series(len_sent)
        dep_df['len_sent_abs'] = abs(dep_df['len_sent'])
        dep_df['len_para'] = pd.Series(len_para)
        dep_df['len_para_abs'] = abs(dep_df['len_para'])
    # * class of relation (FIXME we need to handle interaction with
    #   rewrite_pseudo_rels)
    rel_conv = RstRelationConverter(RELMAP_112_18_FILE).convert_label
    dep_df['rel_class'] = dep_df['rel'].apply(rel_conv)
    # * boolean indicator for pseudo-relations; NB: the 'Style-' prefix
    #   can only apply if rew_pseudo_rels (otherwise no occurrence)
    dep_df['pseudo_rel'] = (
        (dep_df['rel'].str.startswith('Style')) |
        (dep_df['rel'].str.endswith('Same-Unit')) |
        (dep_df['rel'].str.endswith('TextualOrganization')))
    return edu_df, dep_df
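
# ----------------------------------------------------------------------
# Usage sketch for `read_deps` (illustrative, not part of the original
# module).  It assumes the RST-DT is read with educe's corpus Reader and
# that CD_TRAIN / CD_TEST point at the TRAINING and TEST folders, as in
# the companion scripts; adapt the paths to your own setup.
#
#     from educe.rst_dt.corpus import Reader
#
#     corpus = {
#         'train': Reader(CD_TRAIN).slurp(),
#         'test': Reader(CD_TEST).slurp(),
#     }
#     edu_df, dep_df = read_deps(corpus, nary_enc='chain')
#     # distribution of attachment lengths, in EDUs
#     print(dep_df['len_edu_abs'].describe())
# ----------------------------------------------------------------------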
def score_cspans(dpacks, dpredictions, coarse_rels=True, binary_trees=True,
                 oracle_ctree_gold=False, verbose=1):
    """Count correctly predicted spans.

    Parameters
    ----------
    dpacks : list of DataPack
        A DataPack per document.
    dpredictions : list of ?
        Prediction for each document.
    coarse_rels : boolean, optional
        If True, convert relation labels to their coarse-grained version.
    binary_trees : boolean, optional
        If True, convert (gold) constituency trees to their binary version.
    oracle_ctree_gold : boolean, optional
        If True, use oracle gold constituency trees, rebuilt from the gold
        dependency tree. This should emulate the evaluation in (Li 2014).
    verbose : int, defaults to 1
        Verbosity level; currently set to 1 because this is still
        considered WIP.

    Returns
    -------
    cnt_s : CSpanCount
        Count for S (spans).
    cnt_sn : CSpanCount
        Count for S+N (spans and nuclearity).
    cnt_sr : CSpanCount
        Count for S+R (spans and relations).
    cnt_snr : CSpanCount
        Count for S+N+R (spans, nuclearity and relations).
    """
    # trim down DataPacks
    att_packs = [attached_only(dpack, dpack.target)[0] for dpack in dpacks]
    # ctree_gold: oracle (from dependency version) vs true gold
    if oracle_ctree_gold:
        edges_golds = [[(edu1.id, edu2.id, att_pack.get_label(rel))
                        for (edu1, edu2), rel
                        in zip(att_pack.pairings, att_pack.target)
                        if att_pack.get_label(rel) != UNRELATED]
                       for att_pack in att_packs]
        ctree_golds = [get_oracle_ctrees(edges_gold, att_pack.edus)
                       for edges_gold, att_pack
                       in zip(edges_golds, att_packs)]
    else:
        ctree_golds = [dpack.ctarget.values() for dpack in dpacks]
    # WIP coarse-grained rels and binary trees
    # these probably don't belong here because they leak educe stuff in
    rel_conv = RstRelationConverter(RELMAP_112_18_FILE).convert_tree
    binarize_tree = _binarize
    if coarse_rels:
        ctree_golds = [[rel_conv(ctg) for ctg in ctree_gold]
                       for ctree_gold in ctree_golds]
    if binary_trees:
        ctree_golds = [[binarize_tree(ctg) for ctg in ctree_gold]
                       for ctree_gold in ctree_golds]
    # end WIP
    # spans of the gold constituency trees
    ctree_spans_golds = [list(itertools.chain.from_iterable(
        get_spans(ctg) for ctg in ctree_gold))
        for ctree_gold in ctree_golds]
    # spans of the predicted oracle constituency trees
    edges_preds = [[(edu1, edu2, rel)
                    for edu1, edu2, rel in predictions
                    if rel != UNRELATED]
                   for predictions in dpredictions]
    ctree_spans_preds = [oracle_ctree_spans(edges_pred, att_pack.edus)
                         for edges_pred, att_pack
                         in zip(edges_preds, att_packs)]
    # FIXME replace loop with attelo.metrics.constituency.XXX
    cnts = []
    for metric_type, lbl_fn in LBL_FNS:
        y_true = [[(span[0], lbl_fn(span)) for span in ctree_spans]
                  for ctree_spans in ctree_spans_golds]
        y_pred = [[(span[0], lbl_fn(span)) for span in ctree_spans]
                  for ctree_spans in ctree_spans_preds]
        # WIP
        if verbose:
            digits = 4
            values = [metric_type]
            p, r, f1, s = precision_recall_fscore_support(
                y_true, y_pred, labels=None, average='micro')
            for v in (p, r, f1):
                values += ["{0:0.{1}f}".format(v, digits)]
            values += ["{0}".format(s)]
            print('\t'.join(values))
        # end WIP
        # FIXME replace with calls to
        # attelo.metrics.classification_structured.precision_recall_fscore_support(
        #     y_true, y_pred, labels=None, average='micro')
        y_tpos = sum(len(set(yt) & set(yp))
                     for yt, yp in zip(y_true, y_pred))
        y_tpos_fpos = sum(len(yp) for yp in y_pred)
        y_tpos_fneg = sum(len(yt) for yt in y_true)
        cnts.append(CSpanCount(tpos=y_tpos,
                               tpos_fpos=y_tpos_fpos,
                               tpos_fneg=y_tpos_fneg))
    return cnts[0], cnts[1], cnts[2], cnts[3]
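
# ----------------------------------------------------------------------
# Usage sketch for `score_cspans` (illustrative, not part of the
# original module).  `dpacks` and `dpredictions` are assumed to come
# from an attelo decoding run: one DataPack per document and, for each
# document, a list of (gov_id, dep_id, rel) predicted edges.
#
#     cnt_s, cnt_sn, cnt_sr, cnt_snr = score_cspans(
#         dpacks, dpredictions, coarse_rels=True, binary_trees=True)
#     # micro-averaged P/R/F1 on labelled spans (S+R), from the counts
#     prec = cnt_sr.tpos / float(cnt_sr.tpos_fpos)
#     rec = cnt_sr.tpos / float(cnt_sr.tpos_fneg)
#     f1 = 2 * prec * rec / (prec + rec)
# ----------------------------------------------------------------------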
    align_edus_with_sentences)


# RST corpus
# TODO import CORPUS_DIR/CD_TRAIN e.g. from educe.rst_dt.rst_wsj_corpus
CORPUS_DIR = os.path.join(
    os.path.dirname(__file__), '..', '..',
    'data',  # alt: '..', '..', 'corpora', 'rst_discourse_treebank', 'data',
    'RSTtrees-WSJ-main-1.0')
CD_TRAIN = os.path.join(CORPUS_DIR, 'TRAINING')
CD_TEST = os.path.join(CORPUS_DIR, 'TEST')

# relation converter (fine- to coarse-grained labels)
REL_CONV = RstRelationConverter(RELMAP_112_18_FILE).convert_tree


def is_internal_node(node):
    """Return True iff the node is an internal node of an RSTTree.

    Maybe this function should be moved to `educe.rst_dt.annotation`.
    """
    return isinstance(node, RSTTree) and len(node) > 1


def load_corpus_as_dataframe(selection='train'):
    """Load training section of the RST-WSJ corpus as a pandas.DataFrame.

    Parameters
    ----------
"""Dependency format for RST discourse trees. One line per EDU. """ from __future__ import absolute_import, print_function import codecs import csv import os from educe.rst_dt.corpus import RELMAP_112_18_FILE, RstRelationConverter RELCONV = RstRelationConverter(RELMAP_112_18_FILE).convert_label def _dump_disdep_file(rst_deptree, f): """Actually do dump""" writer = csv.writer(f, dialect=csv.excel_tab) # 0 is the fake root, there is no point in writing its info edus = rst_deptree.edus[1:] heads = rst_deptree.heads[1:] labels = rst_deptree.labels[1:] nucs = rst_deptree.nucs[1:] ranks = rst_deptree.ranks[1:] for i, (edu, head, label, nuc, rank) in enumerate(zip(edus, heads, labels, nucs, ranks), start=1): # text of EDU ; some EDUs have newlines in their text, so convert # those to simple spaces txt = edu.text().replace('\n', ' ')