def get_coref_infos(
    key_lines, sys_lines, NP_only=False, remove_nested=False, keep_singletons=True, min_span=False, doc="dummy_doc"
):
    """Collect CoVal coreference information for a single key/system document pair.

    Args:
        key_lines: Key (gold) annotation lines of the document.
        sys_lines: System annotation lines of the same document.
        NP_only: Forwarded to ``reader.set_annotated_parse_trees``; if this or
            ``min_span`` is true, that call is made for both the key and the
            system clusters.
        remove_nested: If true, ``reader.remove_nested_coref_mentions`` is run on
            both cluster sets and the counts it returns are logged.
        keep_singletons: Forwarded to the reader calls. When false, the singleton
            counts returned by ``reader.get_doc_mentions`` are logged as removed.
        min_span: Forwarded to ``reader.set_annotated_parse_trees`` (see ``NP_only``).
        doc: Document name handed to the reader and used as the result key.

    Returns:
        ``{doc: (key_clusters, sys_clusters, key_mention_sys_cluster,
        sys_mention_key_cluster)}``, where the last two items are
        ``reader.get_mention_assignments(key_clusters, sys_clusters)`` and
        ``reader.get_mention_assignments(sys_clusters, key_clusters)``.
    """
    use_parse_trees = NP_only or min_span

    key_clusters, key_singletons_num = reader.get_doc_mentions(doc, key_lines, keep_singletons)
    if use_parse_trees:
        key_clusters = reader.set_annotated_parse_trees(key_clusters, key_lines, NP_only, min_span)

    sys_clusters, sys_singletons_num = reader.get_doc_mentions(doc, sys_lines, keep_singletons)
    if use_parse_trees:
        # Parse trees are read from the *key* lines for the system side as well,
        # presumably because system output need not carry parse annotations -- confirm.
        sys_clusters = reader.set_annotated_parse_trees(sys_clusters, key_lines, NP_only, min_span)

    if remove_nested:
        # The clusters are not reassigned here, so the reader presumably edits them in place.
        key_nested_coref_num, key_removed_nested_clusters = reader.remove_nested_coref_mentions(
            key_clusters, keep_singletons
        )
        sys_nested_coref_num, sys_removed_nested_clusters = reader.remove_nested_coref_mentions(
            sys_clusters, keep_singletons
        )

    sys_mention_key_cluster = reader.get_mention_assignments(sys_clusters, key_clusters)
    key_mention_sys_cluster = reader.get_mention_assignments(key_clusters, sys_clusters)

    # Lazy %-style arguments: messages are only formatted if INFO is enabled.
    if remove_nested:
        logger.info(
            "Number of removed nested coreferring mentions in the key annotation: %s; and system annotation: %s",
            key_nested_coref_num,
            sys_nested_coref_num,
        )
        logger.info(
            "Number of resulting singleton clusters in the key annotation: %s; and system annotation: %s",
            key_removed_nested_clusters,
            sys_removed_nested_clusters,
        )

    if not keep_singletons:
        logger.info(
            "%d and %d singletons are removed from the key and system files, respectively",
            key_singletons_num,
            sys_singletons_num,
        )

    return {doc: (key_clusters, sys_clusters, key_mention_sys_cluster, sys_mention_key_cluster)}
def get_coref_infos(key_lines, sys_lines, NP_only=False, remove_nested=False,
                    keep_singletons=True, min_span=False, doc="dummy_doc"):
    """Build the CoVal coreference tuple for one key/system document and report stats.

    Both annotations are read with ``reader.get_doc_mentions``. When ``NP_only`` or
    ``min_span`` is set, both cluster sets are passed through
    ``reader.set_annotated_parse_trees`` using the *key* lines. When
    ``remove_nested`` is set, nested coreferring mentions are removed from both
    sides via ``reader.remove_nested_coref_mentions``. Statistics are printed to
    stdout.

    NOTE(review): SOURCE also contains an earlier ``get_coref_infos`` that logs via
    ``logger``; if both live in the same module, this later definition shadows it --
    confirm which one is intended.

    Returns:
        ``{doc: (key_clusters, sys_clusters, key_mention_sys_cluster,
        sys_mention_key_cluster)}``.
    """

    def load_clusters(lines):
        # Parse-tree annotation always comes from the key lines, for both sides.
        clusters, n_singletons = reader.get_doc_mentions(doc, lines, keep_singletons)
        if NP_only or min_span:
            clusters = reader.set_annotated_parse_trees(
                clusters, key_lines, NP_only, min_span)
        return clusters, n_singletons

    key_clusters, key_singletons = load_clusters(key_lines)
    sys_clusters, sys_singletons = load_clusters(sys_lines)

    if remove_nested:
        key_nested, key_removed = reader.remove_nested_coref_mentions(
            key_clusters, keep_singletons)
        sys_nested, sys_removed = reader.remove_nested_coref_mentions(
            sys_clusters, keep_singletons)

    sys_to_key = reader.get_mention_assignments(sys_clusters, key_clusters)
    key_to_sys = reader.get_mention_assignments(key_clusters, sys_clusters)
    infos = {doc: (key_clusters, sys_clusters, key_to_sys, sys_to_key)}

    if remove_nested:
        print('Number of removed nested coreferring mentions in the key '
              'annotation: %s; and system annotation: %s'
              % (key_nested, sys_nested))
        print('Number of resulting singleton clusters in the key '
              'annotation: %s; and system annotation: %s'
              % (key_removed, sys_removed))

    if not keep_singletons:
        print('%d and %d singletons are removed from the key and system '
              'files, respectively' % (key_singletons, sys_singletons))

    return infos
def get_coref_from_orig_hyp_gts_dcts(
    self,
    hyp_orig_dct,
    gts_orig_dct,
    met_inp=None,
    conv_dct=None,
):
    """Average CoVal F1 scores of argument coreference over all ground-truth sets.

    A coreference "mention" is the string ``f"{event}_{arg_name}"`` for events
    ``"Ev1"`` .. ``"Ev5"``; mentions sharing the same argument value form one
    cluster. Argument keys are passed through ``arg_mapper`` and only names in
    ``self.args_used`` are kept.

    For each ground-truth index ``gtix`` the scorers are reset, every annotation in
    ``hyp_orig_dct`` is scored against ``gts_orig_dct[ann_idx][gtix]``, and one F1
    per metric is recorded. The result is the mean over all ``gtix``.

    Args:
        hyp_orig_dct: Annotation index -> hypothesis. If ``hyp["Ev1"]`` has an
            ``"Args"`` key the hypothesis is read like a ground truth; otherwise
            ``hyp[event][arg_name]`` is read for each argument slot present in the
            ground truth. A hypothesis without ``"Ev1"`` is scored as having no
            predicted clusters.
        gts_orig_dct: Annotation index -> indexable sequence of ground truths, each
            with ``gt[event]["Args"]`` mapping argument key -> value. The number of
            ground-truth sets is taken from the first entry.
        met_inp: Only read when ``conv_dct`` is given; ``met_inp["cider_sent"]`` is
            indexed with the ``"aix"`` field of ``conv_dct`` entries.
        conv_dct: Optional mapping whose values carry ``"ann_idx"``, ``"ev_agname"``
            (matched against the mention strings above) and ``"aix"``. When given,
            the ``"lea_soft"`` metric is also computed, passing the looked-up
            ``met_inp["cider_sent"]`` values as ``cider_for_sys``.

    Returns:
        Metric name -> mean F1 for ``"mentions"``, ``"muc"``, ``"bcub"``,
        ``"ceafe"``, ``"lea"`` and, only when ``conv_dct`` is given, ``"lea_soft"``.
    """
    self.reset_coval_scorer_dict()

    ev_lst = [f"Ev{ix}" for ix in range(1, 6)]

    def get_coref_dct_for_gt1(gt1):
        # Group ground-truth-format mentions by argument value: value -> [mention, ...].
        coref_dct = {}
        for ev_i in ev_lst:
            for gt_ag, gtv1 in gt1[ev_i]["Args"].items():
                gt_ag_name = arg_mapper(gt_ag)
                if gt_ag_name in self.args_used:
                    coref_dct.setdefault(gtv1, []).append(f"{ev_i}_{gt_ag_name}")
        return coref_dct

    def get_coref_dct_for_pred(pred, gt1):
        # Group predicted values, visiting only the argument slots present in gt1.
        # pred[ev_i] is looked up lazily so events with no relevant slot are never read.
        coref_dct = {}
        for ev_i in ev_lst:
            for gt_ag in gt1[ev_i]["Args"]:
                gt_ag_name = arg_mapper(gt_ag)
                if gt_ag_name in self.args_used and gt_ag_name in pred[ev_i]:
                    coref_dct.setdefault(pred[ev_i][gt_ag_name], []).append(f"{ev_i}_{gt_ag_name}")
        return coref_dct

    def preproc_dct(dct1):
        # CoVal only needs the clusters (lists of mentions), not the shared values.
        return list(dct1.values())

    ann_idx_keys = sorted(hyp_orig_dct)
    is_lea_soft = conv_dct is not None
    coval_mets = ["mentions", "muc", "bcub", "ceafe", "lea"]
    if is_lea_soft:
        # Without conv_dct there are no per-mention scores for ``cider_for_sys``,
        # so "lea_soft" can only be computed when conv_dct is supplied.
        coval_mets.append("lea_soft")
    out_f1_scores = {cmet: [] for cmet in coval_mets}

    # Group conversion entries by annotation index.
    conv_dct2 = {}
    if is_lea_soft:
        for c in conv_dct.values():
            conv_dct2.setdefault(c["ann_idx"], []).append(c)

    gt_max = len(gts_orig_dct[list(gts_orig_dct)[0]])
    for gtix in range(gt_max):
        self.reset_coval_scorer_dict()
        for ann_idx in ann_idx_keys:
            gts1 = gts_orig_dct[ann_idx][gtix]
            hypo_1 = hyp_orig_dct[ann_idx]

            if "Ev1" not in hypo_1:
                # No usable predictions: score as empty rather than reusing the
                # clusters computed for a different annotation.
                sys_dct = []
            elif "Args" in hypo_1["Ev1"]:
                sys_dct = preproc_dct(get_coref_dct_for_gt1(hypo_1))
            else:
                sys_dct = preproc_dct(get_coref_dct_for_pred(hypo_1, gts1))

            cid_sc_lst = None
            if is_lea_soft:
                conv11 = {v["ev_agname"]: v for v in conv_dct2[ann_idx]}
                # One score per system mention, shaped like sys_dct.
                cid_sc_lst = [
                    [met_inp["cider_sent"][conv11[mention]["aix"]] for mention in cluster]
                    for cluster in sys_dct
                ]

            key_dct = preproc_dct(get_coref_dct_for_gt1(gts1))
            tup = (
                key_dct,
                sys_dct,
                get_mention_assignments(key_dct, sys_dct),
                get_mention_assignments(sys_dct, key_dct),
            )
            for cmet in coval_mets:
                if cmet == "lea_soft":
                    self.coval_scorer_dict[cmet].update(tup, cider_for_sys=cid_sc_lst)
                else:
                    self.coval_scorer_dict[cmet].update(tup)

        for cmet in coval_mets:
            out_f1_scores[cmet].append(self.coval_scorer_dict[cmet].get_f1())

    return {cmet: sum(scores) / len(scores) for cmet, scores in out_f1_scores.items()}