from collections import OrderedDict
from itertools import combinations
from typing import Dict, List, Set, Tuple

import numpy as np

# The type aliases below are assumptions that reflect how spans are used in
# this module. Helpers referenced but not defined here (is_x_in_y,
# does_overlap, get_features_for_sections, collapse_paragraphs,
# break_paragraphs) are assumed to be imported from elsewhere in the
# codebase; is_x_in_y(x, y) is True iff span x lies entirely inside span y.
Span = Tuple[int, int]
ClusterName = str
BaseEntityType = str
EntityType = Tuple[BaseEntityType, str]


def verify_json_dict(json_dict):
    sentences: List[List[Span]] = json_dict["sentences"]
    sections: List[Span] = json_dict["sections"]
    entities: Dict[Span, str] = json_dict["ner"]
    corefs: Dict[str, List[Span]] = json_dict["coref"]

    # `assert cond, breakpoint()` drops into the debugger when the condition
    # fails (the message expression is only evaluated on failure).
    # Each entity must lie in exactly one section and exactly one sentence.
    assert all(sum(is_x_in_y(e, s) for s in sections) == 1 for e in entities), breakpoint()
    assert all(
        sum(is_x_in_y(e, ss) for s in sentences for ss in s) == 1 for e in entities
    ), breakpoint()

    # Each section must start at its first sentence's start and end at its
    # last sentence's end.
    assert all(
        sections[i][0] == sentences[i][0][0] and sections[i][-1] == sentences[i][-1][-1]
        for i in range(len(sections))
    ), breakpoint()

    # Every coreference span must be a known entity.
    assert all(x in entities for k, v in corefs.items() for x in v), breakpoint()
def group_sentences_to_sections(sentences: List[Span],
                                sections: List[Span]) -> List[List[Span]]:
    grouped_sentences = [[] for _ in range(len(sections))]
    for s in sentences:
        done = 0
        for i, sec in enumerate(sections):
            if is_x_in_y(s, sec):
                grouped_sentences[i].append(s)
                done += 1
        # Every sentence should fall into exactly one section.
        if done != 1:
            breakpoint()

    return grouped_sentences
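
# Minimal usage sketch with hypothetical spans, assuming is_x_in_y((a, b), (c, d))
# is True iff c <= a and b <= d:
#
#   group_sentences_to_sections(
#       sentences=[(0, 5), (5, 12), (12, 20)],
#       sections=[(0, 12), (12, 20)],
#   )
#   # -> [[(0, 5), (5, 12)], [(12, 20)]]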
def extract_sentence_features(sentences, words, entities):
    entities_to_features_map = {}
    sentence_features = [get_features_for_sections(sents, words) for sents in sentences]

    for e in entities:
        # Locate the unique (section, sentence) pair containing this entity.
        index = [
            (i, j)
            for i, sents in enumerate(sentences)
            for j, sspan in enumerate(sents)
            if is_x_in_y(e, sspan)
        ]
        assert len(index) == 1, breakpoint()
        i, j = index[0]
        entities_to_features_map[(e[0], e[1])] = sentence_features[i][j]

    return entities_to_features_map
def move_boundaries(plist, elist):
    # plist: paragraph lengths; elist: entity spans.
    # Shift paragraph boundaries so that no entity straddles two paragraphs.
    ends = np.cumsum(plist)
    starts = ends - np.array(plist)
    starts, ends = list(starts), list(ends)
    elist = sorted(elist, key=lambda x: (x[0], x[1]))

    para_stack = list(zip(starts, ends))
    new_paragraphs = []
    eix = 0
    while len(para_stack) > 0:
        p = para_stack.pop(0)
        while True:
            if eix >= len(elist):
                # No entities left; keep the paragraph as-is.
                new_paragraphs.append(p)
                break
            elif elist[eix][0] >= p[0] and elist[eix][1] <= p[1]:
                # Entity fits inside this paragraph; move to the next entity.
                eix += 1
            elif elist[eix][0] >= p[1]:
                # Entity starts after this paragraph; paragraph is done.
                new_paragraphs.append(p)
                break
            elif elist[eix][0] >= p[0]:
                # Entity starts in this paragraph but ends in the next one:
                # extend this paragraph to the entity's end, shrink the next.
                p1 = para_stack.pop(0)
                new_paragraphs.append((p[0], elist[eix][1]))
                para_stack.insert(0, (elist[eix][1], p1[1]))
                eix += 1
                break

    # Sanity checks: outer boundaries are preserved, paragraphs tile the
    # document contiguously, and every entity lies inside one paragraph.
    assert new_paragraphs[0][0] == starts[0]
    assert new_paragraphs[-1][1] == ends[-1]
    for p, q in zip(new_paragraphs[:-1], new_paragraphs[1:]):
        assert p[1] == q[0]

    for e in elist:
        done = False
        for p in new_paragraphs:
            if is_x_in_y((e[0], e[1]), p):
                done = True
        assert done

    return new_paragraphs
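
# Worked example (hypothetical numbers): three paragraphs of lengths
# 10, 10, 10 and an entity (8, 13) straddling the first boundary. The
# boundary is moved from 10 to 13 so the entity stays whole:
#
#   move_boundaries([10, 10, 10], [(8, 13)])
#   # -> [(0, 13), (13, 20), (20, 30)]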
def clean_json_dict(json_dict):
    # Get fields from JSON dict
    entities: List[Tuple[int, int, BaseEntityType]] = json_dict["ner"]

    # Convert entities to dictionary {(s, e) -> type}
    entities = sorted(entities, key=lambda x: (x[0], x[1]))
    entities: Dict[Span, BaseEntityType] = OrderedDict([((s, e), t) for s, e, t in entities])

    clusters_dict: Dict[ClusterName, List[Span]] = {
        cluster_name: sorted(list(set([tuple(x) for x in spans])))
        for cluster_name, spans in json_dict.get("coref", dict()).items()
    }

    n_ary_relations: List[Dict[BaseEntityType, ClusterName]] = [
        x for x in json_dict.get("n_ary_relations", list())
    ]
    existing_entities = set(
        [v for relation in n_ary_relations for k, v in relation.items()])

    cluster_to_type: Dict[ClusterName, BaseEntityType] = {}
    for rel in n_ary_relations:
        for k, v in rel.items():
            cluster_to_type[v] = k

    # Under the current model, we do not use method subrelations as a
    # separate component. Therefore, we add each submethod as a separate entity.
    if "method_subrelations" in json_dict:
        # Map each method to the set containing all its submethod names and
        # the method name itself.
        method_subrelations: Dict[ClusterName, Set[ClusterName]] = {
            k: set([k] + [x[1] for x in v])
            for k, v in json_dict["method_subrelations"].items()
        }

        # Merge each submethod's cluster into its parent method's cluster.
        for method_name, method_sub in method_subrelations.items():
            for m in method_sub:
                if m in clusters_dict and m != method_name and m not in existing_entities:
                    clusters_dict[method_name] += clusters_dict[m]
                    clusters_dict[method_name] = sorted(
                        list(set(clusters_dict[method_name])))
                    del clusters_dict[m]

    for cluster, spans in clusters_dict.items():
        for span in spans:
            assert span in entities, breakpoint()
            if cluster not in cluster_to_type:
                continue
            entities[span] = cluster_to_type[cluster]

    # Augment each entity type with a flag marking whether the entity
    # appears in any coreference cluster.
    for e in entities:
        entities[e]: EntityType = (entities[e], str(
            any(e in v for v in clusters_dict.values())))

    json_dict["ner"]: Dict[Span, BaseEntityType] = entities
    json_dict["coref"]: Dict[ClusterName, List[Span]] = clusters_dict

    # Merge sentences where needed so every entity lies in one sentence.
    for e in entities:
        in_sentences = [
            i for i, s in enumerate(json_dict["sentences"]) if is_x_in_y(e, s)
        ]
        # Check the entity lies in one sentence
        if len(in_sentences) > 1:
            breakpoint()

        if len(in_sentences) == 0:
            # Entity crosses sentence boundaries; the overlapped sentences
            # must be contiguous, and are merged into one.
            in_sentences = [
                i for i, s in enumerate(json_dict["sentences"]) if does_overlap(e, s)
            ]
            assert sorted(in_sentences) == list(
                range(min(in_sentences), max(in_sentences) + 1)), breakpoint()

            in_sentences = sorted(in_sentences)
            json_dict["sentences"][in_sentences[0]][1] = json_dict[
                "sentences"][in_sentences[-1]][1]
            json_dict["sentences"] = [
                s for i, s in enumerate(json_dict["sentences"])
                if i not in in_sentences[1:]
            ]

    json_dict["sentences"]: List[List[Span]] = group_sentences_to_sections(
        json_dict["sentences"], json_dict["sections"])

    return json_dict
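
# Hypothetical input shape for clean_json_dict (all values invented for
# illustration):
#
#   json_dict = {
#       "sections": [[0, 20], [20, 35]],
#       "sentences": [[0, 8], [8, 20], [20, 35]],
#       "ner": [[0, 2, "Method"], [21, 23, "Task"]],
#       "coref": {"BERT": [[0, 2]]},
#       "n_ary_relations": [{"Method": "BERT", "Task": "QA"}],
#   }
#
# Afterwards, "ner" maps spans to (type, in-a-cluster flag), e.g.
# {(0, 2): ("Method", "True"), (21, 23): ("Task", "False")}, and
# "sentences" is grouped per section.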
def resize_sections_and_group(self, sections: List[Span],
                              sentences: List[List[Span]],
                              entities: Dict[Span, EntityType]):
    broken_sections = move_boundaries(
        break_paragraphs(
            collapse_paragraphs(sections, min_len=20,
                                max_len=self._max_paragraph_length),
            max_len=self._max_paragraph_length,
        ),
        list(entities.keys()),
    )

    for p, q in zip(broken_sections[:-1], broken_sections[1:]):
        if p[1] != q[0] or p[1] < p[0] or q[1] < q[0]:
            breakpoint()

    sections = broken_sections
    entities_grouped = [{} for _ in range(len(sections))]
    sentences_grouped = [[] for _ in range(len(sections))]

    # Sections are resized to respect BERT's maximum input length; group
    # entities into the sections they now belong to.
    for e in entities:
        is_in_n_para = 0
        for para_id, p in enumerate(sections):
            if is_x_in_y(e, p):
                entities_grouped[para_id][(e[0], e[1])] = entities[e]
                is_in_n_para += 1
        assert is_in_n_para == 1, breakpoint()

    # Sentences also need to be realigned with the resized sections.
    sentences = [sent for section in sentences for sent in section]
    assert all([
        sentences[i + 1][0] == sentences[i][1] for i in range(len(sentences) - 1)
    ]), breakpoint()
    assert sentences[-1][1] == sections[-1][1], breakpoint()

    # Re-split sentences on the union of sentence and section boundaries,
    # so no sentence crosses a section boundary.
    sentence_indices = sorted(
        list(set([0] + [s[1] for s in sentences] + [s[1] for s in sections])))
    sentences = list(zip(sentence_indices[:-1], sentence_indices[1:]))

    for e in sentences:
        is_in_n_para = 0
        for para_id, p in enumerate(sections):
            if is_x_in_y(e, p):
                sentences_grouped[para_id].append(e)
                is_in_n_para += 1
        assert is_in_n_para == 1, breakpoint()

    zipped = zip(sections, sentences_grouped, entities_grouped)

    # Remove empty sections.
    sections, sentences_grouped, entities_grouped = [], [], []
    for p, s, e in zipped:
        if p[1] - p[0] == 0:
            assert len(e) == 0, breakpoint()
            assert len(s) == 0, breakpoint()
            continue
        sections.append(p)
        entities_grouped.append(e)
        sentences_grouped.append(s)

    return sections, sentences_grouped, entities_grouped
def convert_scirex_instance_to_scierc_format(instance):
    words = instance['paragraph']
    sentence_indices = instance['sentence_indices']
    mentions = instance['ner_dict']
    start_ix, end_ix = instance['start_ix'], instance['end_ix']
    metadata = instance['document_metadata']
    instance_id = metadata['doc_id'] + ':' + str(instance['paragraph_num'])

    # Re-index mentions and sentences relative to the paragraph start.
    mentions = {(span[0] - start_ix, span[1] - start_ix): label
                for span, label in mentions.items()}
    sentence_indices = [(sent[0] - start_ix, sent[1] - start_ix)
                        for sent in sentence_indices]

    ner = [[] for _ in range(len(sentence_indices))]
    for mention in mentions:
        in_sent = set([
            i for i, sent in enumerate(sentence_indices)
            if is_x_in_y(mention, sent)
        ])
        assert len(in_sent) == 1, breakpoint()
        ner[list(in_sent)[0]].append(
            [mention[0], mention[1], mentions[mention][0]])

    sentences = [words[sent[0]:sent[1]] for sent in sentence_indices]

    span_to_cluster_ids = metadata['span_to_cluster_ids']
    num_clusters = len(metadata['cluster_name_to_id'])
    clusters = [[] for _ in range(num_clusters)]
    for span, cluster_ids in span_to_cluster_ids.items():
        span = (span[0] - start_ix, span[1] - start_ix)
        if span in mentions and len(cluster_ids) > 0:
            clusters[cluster_ids[0]].append(span)

    # Mark every pair of coreferent mentions within a sentence as related.
    relations = [[] for _ in range(len(sentence_indices))]
    for idx, sentence_mentions in enumerate(ner):
        for span_1, span_2 in combinations(sentence_mentions, 2):
            span_1_orig = (span_1[0] + start_ix, span_1[1] + start_ix)
            span_2_orig = (span_2[0] + start_ix, span_2[1] + start_ix)
            if span_1_orig in span_to_cluster_ids and span_2_orig in span_to_cluster_ids:
                ids_1 = span_to_cluster_ids[span_1_orig]
                ids_2 = span_to_cluster_ids[span_2_orig]
                if len(set(ids_1) & set(ids_2)) > 0:
                    relations[idx].append((span_1[0], span_1[1], span_2[0],
                                           span_2[1], 'USED_FOR'))

    # SciERC uses inclusive span ends, so convert from exclusive to inclusive.
    ner = [[(int(s), int(e - 1), v) for (s, e, v) in sentence]
           for sentence in ner]
    clusters = [[(int(s), int(e - 1)) for (s, e) in cluster]
                for cluster in clusters if len(cluster) > 0]
    relations = [[(int(s1), int(e1 - 1), int(s2), int(e2 - 1), l)
                  for (s1, e1, s2, e2, l) in sentence]
                 for sentence in relations]

    return {
        'doc_key': instance_id,
        'ner': ner,
        'sentences': sentences,
        'clusters': clusters,
        'relations': relations
    }
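
# The returned dict follows the SciERC JSON layout; a hypothetical example:
#
#   {
#       'doc_key': 'doc_1:3',
#       'sentences': [['BERT', 'is', 'used', 'for', 'QA', '.']],
#       'ner': [[(0, 0, 'Method'), (4, 4, 'Task')]],
#       'clusters': [[(0, 0)], [(4, 4)]],
#       'relations': [[(0, 0, 4, 4, 'USED_FOR')]],
#   }
#
# with all span ends inclusive, token-indexed within the paragraph.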