def _parse_amr_from_metadata(tokens, metadata):
    '''
    Build an AMR and its JAMR-style alignments from parsed graph metadata.

    Metadata format is ...
    # ::id sentence id
    # ::tok tokens...
    # ::node node_id node alignments
    # ::root root_id root
    # ::edge src label trg src_id trg_id alignments
    amr graph

    Args:
        tokens: sentence tokens for the AMR.
        metadata: dict with 'node', 'root' and optionally 'edge' entries,
            each a list of whitespace-split field lists.
    Returns:
        (amr, alignments): the populated AMR and a list of AMR_Alignments.
    '''
    amr = AMR(tokens=tokens)
    alignments = []
    nodes = metadata['node']
    edges = metadata.get('edge', [])
    root = metadata['root'][0]
    amr.root = root[0]
    for data in nodes:
        n, label = data[:2]
        if len(data) > 2:
            # third field is the token alignment for this node
            toks = data[2]
            alignments.append(AMR_Alignment(type='jamr', nodes=[n], tokens=toks))
        amr.nodes[n] = label
    for data in edges:
        _, r, _, s, t = data[:5]
        # FIX: normalize the relation label BEFORE recording the alignment, so
        # the alignment's edge tuple matches the edge appended to amr.edges
        # (previously an alignment could hold (s, 'ARG0', t) while amr.edges
        # held (s, ':ARG0', t)).
        if not r.startswith(':'):
            r = ':' + r
        if len(data) > 5:
            toks = data[5]
            alignments.append(AMR_Alignment(type='jamr', edges=[(s, r, t)], tokens=toks))
        amr.edges.append((s, r, t))
    return amr, alignments
def __init__(self, amrs, subgraph_alignments, relation_alignments, alpha=1):
    """Initialize the relation-alignment model and warm up edge-label statistics.

    Args:
        amrs: list of AMR graphs used to collect statistics.
        subgraph_alignments: dict of amr.id -> list of subgraph AMR_Alignments.
        relation_alignments: dict of amr.id -> list of relation AMR_Alignments.
        alpha: smoothing parameter forwarded to the base model.
    """
    super().__init__(amrs, alpha)
    # separate distance distributions for parent-side and child-side offsets
    self.distance_model_parent = Skellam_Distance_Model()
    self.distance_model_child = Skellam_Distance_Model()
    self.subgraph_alignments = subgraph_alignments
    self.relation_alignments = relation_alignments
    self.edges_count = Counter()
    self.edges_total = 0
    self.allowed_types_memo_ = None
    edge_labels = set()
    for amr in amrs:
        # adjacency maps: outgoing edges per source node, incoming per target node
        parents = {s: [] for s in amr.nodes}
        children = {t: [] for t in amr.nodes}
        taken_tokens = set()
        for s, r, t in amr.edges:
            parents[s].append((s, r, t))
            children[t].append((s, r, t))
        for align in subgraph_alignments[amr.id]:
            token_label = ' '.join(amr.lemmas[t] for t in align.tokens)
            for n in align.nodes:
                # edges leaving the aligned subgraph (target outside the alignment)
                for e in parents[n]:
                    s, r, t = e
                    if t in align.nodes:
                        continue
                    partial_align = AMR_Alignment(type='relation', tokens=align.tokens, edges=[e])
                    edge_label = self.get_alignment_label(amr, partial_align)
                    edge_labels.add(edge_label)
                # edges entering the aligned subgraph (source outside the alignment)
                for e in children[n]:
                    s, r, t = e
                    if s in align.nodes:
                        continue
                    partial_align = AMR_Alignment(type='relation', tokens=align.tokens, edges=[e])
                    edge_label = self.get_alignment_label(amr, partial_align)
                    edge_labels.add(edge_label)
            # unaligned spans: consider every edge once per distinct token label
            if not align.nodes and token_label not in taken_tokens:
                edges = set()
                for e in amr.edges:
                    partial_align = AMR_Alignment(type='relation', tokens=align.tokens, edges=[e])
                    edge_label = self.get_alignment_label(amr, partial_align)
                    edges.add(edge_label)
                taken_tokens.add(token_label)
    # NOTE(review): edge_labels and the per-span `edges` sets are built but not
    # stored on self — presumably this warms caches in get_alignment_label;
    # confirm before removing.
def main():
    """Convert ISI alignments for the Szubert AMR subset into subgraph and
    relation alignment JSON files.

    Reads the ISI-aligned AMR file plus the reference AMR corpora, maps ISI
    node/edge labels onto the reference AMRs, and writes one single-token
    alignment per token to two JSON files.
    """
    file = '../data/szubert/szubert_amrs.isi_alignments.txt'
    ids_file = '../data/szubert/szubert_ids.isi.txt'
    output = '../data/szubert/szubert_amrs.isi.txt'
    amr_file1 = '../data/ldc_train.txt'
    amr_file2 = '../data/szubert/szubert_amrs.txt'
    reader = AMR_Reader()
    amrs = reader.load(amr_file1, remove_wiki=True)
    szubert_amrs = reader.load(amr_file2, remove_wiki=True)
    # set for O(1) membership tests below (was a list)
    szubert_amr_ids = {amr.id for amr in szubert_amrs}
    amrs += szubert_amrs
    amrs = {amr.id: amr for amr in amrs}
    amr_ids = []
    with open(ids_file, encoding='utf8') as f:
        for line in f:
            # FIX: lines read from a file keep their newline, so `if line:` was
            # always true and blank lines appended '' — test the stripped value.
            if line.strip():
                amr_ids.append(line.strip())
    isi_amrs, isi_alignments = reader.load(file, output_alignments=True)
    subgraph_alignments = {}
    relation_alignments = {}
    for isi_amr in isi_amrs:
        if isi_amr.id not in szubert_amr_ids:
            continue
        amr = amrs[isi_amr.id]
        if len(amr.tokens) != len(isi_amr.tokens):
            raise Exception('Inconsistent Tokenization:', amr.id)
        # translate ISI node/edge identifiers to the reference AMR's identifiers
        node_labels = node_map(isi_amr, amr)
        edge_labels = edge_map(isi_amr, amr)
        isi_aligns = isi_alignments[amr.id]
        subgraph_alignments[amr.id] = []
        relation_alignments[amr.id] = []
        for i, tok in enumerate(amr.tokens):
            aligns = [align for align in isi_aligns if i in align.tokens]
            nodes = [node_labels[n] for align in aligns for n in align.nodes]
            edges = [edge_labels[e] for align in aligns for e in align.edges]
            subgraph_alignments[amr.id].append(
                AMR_Alignment(type='subgraph', tokens=[i], nodes=nodes))
            relation_alignments[amr.id].append(
                AMR_Alignment(type='relation', tokens=[i], edges=edges))
    reader.save_alignments_to_json(
        output.replace('.txt', '.subgraph_alignments.json'), subgraph_alignments)
    reader.save_alignments_to_json(
        output.replace('.txt', '.relation_alignments.json'), relation_alignments)
    for amr in szubert_amrs:
        if amr.id not in subgraph_alignments:
            raise Exception('Missing AMR:', amr.id)
def align_all(self, amrs, alignments=None, preprocess=True, debug=False):
    """Run the base aligner, then patch up degenerate sentences, orphan nodes,
    and subgraph-internal edges."""
    alignments = super().align_all(amrs, alignments, preprocess, debug)
    for amr in amrs:
        amr_aligns = alignments[amr.id]
        # hack to handle degenerate sentences: if nothing aligned at all,
        # dump the entire graph onto the first token span
        if amr.nodes and amr.tokens and not any(amr_aligns):
            whole_graph = AMR_Alignment(type='subgraph',
                                        tokens=amr.spans[0],
                                        nodes=list(amr.nodes),
                                        edges=list(amr.edges),
                                        amr=amr)
            amr_aligns.append(whole_graph)
        # attach each still-unaligned node to its first parent's alignment
        for node in amr.nodes:
            if amr.get_alignment(alignments, node_id=node):
                continue
            parent_edge = next((e for e in amr.edges if e[-1] == node), None)
            if parent_edge is not None:
                parent_align = amr.get_alignment(alignments, node_id=parent_edge[0])
                parent_align.nodes.append(node)
        # add subgraph edges: any edge fully inside a multi-node alignment
        for align in amr_aligns:
            if len(align.nodes) <= 1:
                continue
            for edge in amr.edges:
                src, _, tgt = edge
                if src in align.nodes and tgt in align.nodes and edge not in align.edges:
                    align.edges.append(edge)
    return alignments
def get_alignment(self, alignments, token_id=None, node_id=None, edge=None):
    """Return the first alignment of this AMR containing the given token,
    node, or edge; an empty AMR_Alignment when nothing matches."""
    if not isinstance(alignments, dict):
        raise Exception('Alignments object must be a dict.')
    if self.id not in alignments:
        return AMR_Alignment()
    for candidate in alignments[self.id]:
        token_hit = token_id is not None and token_id in candidate.tokens
        node_hit = node_id is not None and node_id in candidate.nodes
        edge_hit = edge is not None and edge in candidate.edges
        if token_hit or node_hit or edge_hit:
            return candidate
    return AMR_Alignment()
def clean_alignments(amr, alignments, spans):
    """Rebuild alignments[amr.id] as one alignment per span, in span order,
    inserting an empty subgraph alignment for spans with no existing one."""
    rebuilt = []
    for span in spans:
        existing = amr.get_alignment(alignments, token_id=span[0])
        if existing:
            rebuilt.append(existing)
        else:
            rebuilt.append(AMR_Alignment(type='subgraph', tokens=span, amr=amr))
    alignments[amr.id] = rebuilt
def _parse_isi_alignments(amr, amr_file, aligns, isi_labels, isi_edge_labels):
    """Convert raw ISI 'token-component' alignment strings into AMR_Alignments.

    Args:
        amr: the AMR the alignments belong to.
        amr_file: source filename (only used in error messages).
        aligns: raw alignment strings of the form '<token index>-<component>'.
        isi_labels: ISI dotted index -> node id ('ignore' marks skipped entries).
        isi_edge_labels: ISI dotted index + '.r' -> edge tuple.
    Returns:
        list of AMR_Alignments, one per parsed alignment string.
    Raises:
        Exception: when a component or token index cannot be resolved.
    """
    aligns = [(int(a.split('-')[0]), a.split('-')[-1]) for a in aligns if '-' in a]
    alignments = []
    # ISI token indices may be shifted by 1 when the sentence starts with an
    # XML-ish tag token; undo the shift if it would run past the token list.
    xml_offset = 1 if amr.tokens[0].startswith('<') and amr.tokens[0].endswith('>') else 0
    if any(t + xml_offset >= len(amr.tokens) for t, n in aligns):
        xml_offset = 0
    for tok, component in aligns:
        tok += xml_offset
        nodes = []
        edges = []
        if component.replace('.r', '') in isi_labels:
            # node or attribute
            n = isi_labels[component.replace('.r', '')]
            if n == 'ignore':
                continue
            nodes.append(n)
            if n not in amr.nodes:
                raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component)
        elif not component.endswith('.r') and component not in isi_labels \
                and component + '.r' in isi_edge_labels:
            # reentrancy
            e = isi_edge_labels[component + '.r']
            edges.append(e)
            if e not in amr.edges:
                raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component)
        elif component == '0.r' and component not in isi_edge_labels:
            # FIX: root alignment must be checked BEFORE the generic '.r' edge
            # branch — '0.r' ends with '.r', so the original ordering made this
            # branch unreachable and raised KeyError instead (the root has no
            # incoming edge). Guarded so behavior is unchanged whenever '0.r'
            # does appear in isi_edge_labels.
            nodes.append(amr.root)
        elif component.endswith('.r'):
            # edge
            e = isi_edge_labels[component]
            if e == 'ignore':
                continue
            edges.append(e)
            if e not in amr.edges:
                raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component)
        else:
            raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component)
        if tok >= len(amr.tokens):
            raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component)
        new_align = AMR_Alignment(type='isi', tokens=[tok], nodes=nodes, edges=edges)
        alignments.append(new_align)
    return alignments
def evaluate_duplicates(amrs, pred_alignments, gold_alignments):
    """Evaluate predicted vs gold duplicate-subgraph alignments in node mode.

    Collects all 'dupl*'-typed alignments from both sides, merging those that
    share a token span into a single 'subgraph:dupl' alignment, then delegates
    to evaluate().
    """
    print('duplicates')

    def _collect_duplicates(aligns):
        # Merge all dupl-* alignments sharing a token span into one node set.
        # (Replaces two copy-pasted loops that also misspelled 'duplicates'.)
        by_span = {}
        for align in aligns:
            if align.type.startswith('dupl'):
                by_span.setdefault(tuple(align.tokens), set()).update(align.nodes)
        return [AMR_Alignment(type='subgraph:dupl', tokens=list(span), nodes=list(nodes))
                for span, nodes in by_span.items()]

    duplicate_alignments = {}
    gold_duplicate_alignments = {}
    for amr in amrs:
        duplicate_alignments[amr.id] = _collect_duplicates(pred_alignments[amr.id])
        gold_duplicate_alignments[amr.id] = _collect_duplicates(gold_alignments[amr.id])
    evaluate(amrs, duplicate_alignments, gold_duplicate_alignments, mode='nodes')
def get_initial_alignments(self, amrs, preprocess=True):
    """Seed one empty relation alignment per token span, then apply the
    rule-based and exact-match relation aligners."""
    initial = {}
    total = len(amrs)
    for idx, amr in enumerate(amrs):
        print(f'\r{idx} / {total} preprocessed', end='')
        initial[amr.id] = [AMR_Alignment(type='relation', tokens=span, amr=amr)
                           for span in amr.spans]
        rule_based_align_relations(amr, self.subgraph_alignments, initial)
        exact_match_relations(amr, self.subgraph_alignments, initial)
    print('\r', end='')
    print('Preprocessing coverage:', self.coverage(amrs, initial))
    return initial
def add_relation_alignment(amr, relation_alignments, edge, span):
    """Attach edge to the alignment with the given token span, creating a new
    relation alignment if none exists.

    Keeps relation_alignments[amr.id] sorted by span start.

    Raises:
        Exception: if span is empty.
    """
    if not span:
        raise Exception('Tried to align to empty span.')
    for align in relation_alignments[amr.id]:
        if align.tokens == span:
            # reuse the existing alignment (no need for a separate alias)
            align.edges.append(edge)
            return
    new_align = AMR_Alignment(type='relation', tokens=span, edges=[edge])
    relation_alignments[amr.id].append(new_align)
    # sorted() already returns a list; the original wrapped it in a redundant
    # identity comprehension — an in-place sort is simpler still
    relation_alignments[amr.id].sort(key=lambda a: a.tokens[0])
def separate_components(amr, align):
    """Split an alignment into one alignment per connected component of its
    nodes; returns the alignment unchanged when it is already a subgraph."""
    labels = [amr.nodes[n] for n in align.nodes]
    # several identically-labeled nodes: treat each node as its own component
    if len(labels) > 1 and len(set(labels)) == 1:
        return [AMR_Alignment(type='subgraph', tokens=align.tokens, nodes=[node], amr=amr)
                for node in align.nodes]
    if not align.nodes or is_subgraph(amr, align.nodes):
        return [align]
    subgraphs = get_connected_components(amr, align.nodes)
    return [AMR_Alignment(type='subgraph',
                          tokens=align.tokens,
                          nodes=list(sub.nodes.keys()),
                          amr=amr)
            for sub in subgraphs]
def _parse_jamr_alignments(amr, amr_file, aligns, jamr_labels, metadata_parser):
    """Convert raw JAMR 'range|n1+n2+...' alignment strings into AMR_Alignments.

    Raises an Exception when a node label or token index cannot be resolved.
    """
    alignments = []
    for raw in aligns:
        if '|' not in raw:
            continue
        pieces = raw.split('|')
        toks = metadata_parser.get_token_range(pieces[0])
        components = pieces[-1].split('+')
        bad_node = any(c not in jamr_labels for c in components)
        bad_token = any(t >= len(amr.tokens) for t in toks)
        if bad_node or bad_token:
            raise Exception('Could not parse alignment:', amr_file, amr.id, toks, components)
        alignments.append(AMR_Alignment(type='jamr',
                                        tokens=toks,
                                        nodes=[jamr_labels[c] for c in components]))
    return alignments
def get_initial_alignments(self, amrs, preprocess=True):
    """Seed one empty subgraph alignment per token span; optionally run the
    fuzzy/rule-based subgraph preprocessing."""
    print(f'Apply Rules = {preprocess}')
    alignments = {}
    total = len(amrs)
    for idx, amr in enumerate(amrs):
        print(f'\rPreprocessing: {idx} / {total}', end='')
        alignments[amr.id] = [AMR_Alignment(type='subgraph', tokens=span, amr=amr)
                              for span in amr.spans]
        if preprocess:
            fuzzy_align_subgraphs(amr, alignments, english=ENGLISH)
            for align in alignments[amr.id]:
                postprocess_subgraph(amr, alignments, align, english=ENGLISH)
                # clean_subgraph returning None marks the alignment as invalid
                if clean_subgraph(amr, alignments, align) is None:
                    align.nodes.clear()
    print('\r', end='')
    print('Preprocessing coverage:', coverage(amrs, alignments))
    return alignments
def align_primary_edges(self, amr, alignments):
    """For each reentrant node, pick one incoming edge as its 'primary' edge
    and append a 'reentrancy:primary' alignment on that edge's span.

    The primary edge is either one already relation-aligned to the node's own
    span, or the candidate whose parent span is closest by the parent
    distance model.
    """
    # lazily cache edges whose target has more than one incoming edge
    if not hasattr(amr, 'reentrancies'):
        amr.reentrancies = [e for e in amr.edges
                            if len([e2 for e2 in amr.edges if e2[-1] == e[-1]]) > 1]
    ts = {t for s, r, t in amr.reentrancies}
    for t in ts:
        candidates = [e for e in amr.reentrancies if e[-1] == t]
        talign = amr.get_alignment(self.subgraph_alignments, node_id=t)
        rel_align = amr.get_alignment(self.relation_alignments, token_id=talign.tokens[0])
        if rel_align and any(e in rel_align.edges for e in candidates):
            # a candidate edge is already relation-aligned to the node's span
            span = talign.tokens
            e = [e for e in candidates if e in rel_align.edges][0]
        else:
            # otherwise score candidates by parent-to-child token distance
            dists = {}
            # NOTE: this loop's s, r, t shadow the outer t; the outer value is
            # restored implicitly because every candidate has the same target
            for s, r, t in candidates:
                # skip unaligned candidates when at least one candidate IS aligned
                if not amr.get_alignment(self.relation_alignments, edge=(s, r, t)) \
                        and any(amr.get_alignment(self.relation_alignments, edge=e2)
                                for e2 in candidates):
                    continue
                salign = amr.get_alignment(self.subgraph_alignments, node_id=s)
                talign = amr.get_alignment(self.subgraph_alignments, node_id=t)
                dist = self.distance_model_parent.distance(amr, salign.tokens, talign.tokens)
                # tie-break equal distances by parent span position
                dists[(s, r, t)] = (abs(dist), salign.tokens[0])
            e = min(dists, key=lambda x: dists[x])
            ealign = amr.get_alignment(self.relation_alignments, edge=e)
            span = ealign.tokens
        if not span:
            continue
        alignments[amr.id].append(
            AMR_Alignment(type='reentrancy:primary', tokens=span, edges=[e]))
def parse_amr(self, tokens, amr_string):
    """Parse a penman AMR string into an AMR, assigning node/edge identifiers
    in three styles (letter variables, JAMR dotted indices, ISI dotted
    indices) and extracting any embedded ISI alignments.

    Returns:
        (amr, (letter_labels, jamr_labels, isi_labels, isi_edge_labels, aligns))
        where the label dicts map style-specific identifiers to the chosen
        default identifiers, and aligns is a list of AMR_Alignments.
    """
    amr = AMR(tokens=tokens)
    g = penman.decode(amr_string, model=TreePenmanModel())
    # penman API difference: triples may be a method or a property
    triples = g.triples() if callable(g.triples) else g.triples
    letter_labels = {}
    isi_labels = {g.top: '1'}       # ISI dotted indices start at 1
    isi_edge_labels = {}
    jamr_labels = {g.top: '0'}      # JAMR dotted indices start at 0
    new_idx = 0
    isi_edge_idx = {g.top: 1}       # next child index per node, per style
    jamr_edge_idx = {g.top: 0}
    nodes = []
    attributes = []
    edges = []
    reentrancies = []
    for i, tr in enumerate(triples):
        s, r, t = tr
        # an amr node
        if r == ':instance':
            # if the node was introduced by the edge just seen as a reentrancy,
            # re-label it with fresh dotted indices
            if reentrancies and edges[-1] == reentrancies[-1]:
                s2, r2, t2 = edges[-1]
                jamr_labels[t2] = jamr_labels[s2] + '.' + str(jamr_edge_idx[s2])
                isi_labels[t2] = isi_labels[s2] + '.' + str(isi_edge_idx[s2])
            # pick an unused letter-style variable name
            new_s = s
            while new_s in letter_labels:
                new_idx += 1
                new_s = f'x{new_idx}'
            letter_labels[s] = new_s
            nodes.append(tr)
        # an amr edge
        elif t not in letter_labels:
            if len(t) > 5 or not t[0].isalpha():
                # repeated attribute triple: mark as ignorable, keep indices moving
                if tr in letter_labels:
                    isi_labels['ignore'] = isi_labels[s] + '.' + str(isi_edge_idx[s])
                    isi_edge_labels['ignore'] = isi_labels[s] + '.' + str(isi_edge_idx[s]) + '.r'
                    isi_edge_idx[s] += 1
                    jamr_edge_idx[s] += 1
                    continue
                # attribute
                new_s = s
                while new_s in letter_labels:
                    new_idx += 1
                    new_s = f'x{new_idx}'
                letter_labels[tr] = new_s
                jamr_labels[tr] = jamr_labels[s] + '.' + str(jamr_edge_idx[s])
                isi_labels[tr] = isi_labels[s] + '.' + str(isi_edge_idx[s])
                isi_edge_labels[tr] = isi_labels[s] + '.' + str(isi_edge_idx[s]) + '.r'
                isi_edge_idx[s] += 1
                jamr_edge_idx[s] += 1
                attributes.append(tr)
            else:
                # edge
                jamr_edge_idx[t] = 0
                isi_edge_idx[t] = 1
                jamr_labels[t] = jamr_labels[s] + '.' + str(jamr_edge_idx[s])
                # JAMR only advances the child counter when the child has its
                # own :instance (i.e. is a real node) — presumably to mirror
                # JAMR's numbering; confirm against JAMR output
                if i + 1 < len(triples) and triples[i + 1][1] == ':instance':
                    jamr_edge_idx[s] += 1
                isi_labels[t] = isi_labels[s] + '.' + str(isi_edge_idx[s])
                isi_edge_labels[tr] = isi_labels[s] + '.' + str(isi_edge_idx[s]) + '.r'
                isi_edge_idx[s] += 1
                edges.append(tr)
        else:
            # reentrancy (target variable already seen)
            isi_edge_labels[tr] = isi_labels[s] + '.' + str(isi_edge_idx[s]) + '.r'
            isi_edge_idx[s] += 1
            edges.append(tr)
            reentrancies.append(tr)
    # choose the identifier style for the returned AMR
    default_labels = letter_labels
    if self.style == 'isi':
        default_labels = isi_labels
    elif self.style == 'jamr':
        default_labels = jamr_labels
    amr.root = default_labels[g.top]
    edge_map = {}
    for tr in nodes:
        s, r, t = tr
        amr.nodes[default_labels[s]] = t
    for tr in attributes:
        s, r, t = tr
        if not r.startswith(':'):
            r = ':' + r
        # attributes become leaf nodes keyed by their triple
        amr.nodes[default_labels[tr]] = t
        amr.edges.append((default_labels[s], r, default_labels[tr]))
        edge_map[tr] = (default_labels[s], r, default_labels[tr])
    for tr in edges:
        s, r, t = tr
        if not r.startswith(':'):
            r = ':' + r
        amr.edges.append((default_labels[s], r, default_labels[t]))
        edge_map[tr] = (default_labels[s], r, default_labels[t])
    # pull ISI alignments out of penman epidata
    aligns = []
    for tr, epidata in g.epidata.items():
        for align in epidata:
            if 'Alignment' in type(align).__name__:
                indices = align.indices
                s, r, t = tr
                if tr[1] == ':instance':
                    align = AMR_Alignment(type='isi', tokens=list(indices),
                                          nodes=[default_labels[s]])
                elif len(t) > 5 or not t[0].isalpha():
                    # attribute alignment (attribute nodes are keyed by triple)
                    align = AMR_Alignment(type='isi', tokens=list(indices),
                                          nodes=[default_labels[tr]])
                else:
                    align = AMR_Alignment(type='isi', tokens=list(indices),
                                          edges=[edge_map[tr]])
                aligns.append(align)
    # invert the label maps: style-specific id -> default id
    letter_labels = {v: default_labels[k] for k, v in letter_labels.items()}
    jamr_labels = {v: default_labels[k] for k, v in jamr_labels.items()}
    isi_labels = {v: default_labels[k] if k != 'ignore' else k
                  for k, v in isi_labels.items()}
    isi_edge_labels = {v: edge_map[k] if k in edge_map else k
                       for k, v in isi_edge_labels.items()}
    return amr, (letter_labels, jamr_labels, isi_labels, isi_edge_labels, aligns)
def main():
    """Load hand-annotated gold alignments from a TSV file and write
    subgraph / relation / reentrancy alignment JSON files.

    Usage: script <amr_file> <hand_alignments_file>. Rows are dispatched on
    their first column ('amr', 'node', 'edge', 'reentrancy'); a leading '*'
    on the token column marks duplicate annotations.
    """
    amr_file = sys.argv[1]
    hand_alignments_file = sys.argv[2]
    reader = AMR_Reader()
    amrs = reader.load(amr_file, remove_wiki=True)
    amrs = {amr.id: amr for amr in amrs}
    subgraph_alignments = {}
    relation_alignments = {}
    reentrancy_alignments = {}
    all_spans = {amr_id: set() for amr_id in amrs}
    amr = None
    node_labels = {}
    with open(hand_alignments_file) as f:
        hand_alignments = csv.reader(f, delimiter="\t")
        for row in hand_alignments:
            if row[0] == 'amr':
                # start of a new AMR's annotations: reset per-AMR state
                amr_id = row[1]
                subgraph_alignments[amr_id] = []
                relation_alignments[amr_id] = []
                reentrancy_alignments[amr_id] = []
                amr = amrs[amr_id]
                taken = set()
                # invert label maps: human-readable label -> internal id
                node_labels = get_node_labels(amr)
                node_labels = {v: k for k, v in node_labels.items()}
                edge_labels = get_edge_labels(amr)
                edge_labels = {v: k for k, v in edge_labels.items()}
            elif row[0] == 'node':
                type = 'subgraph'
                if row[3].startswith('*'):
                    # '*' marks a duplicate-subgraph annotation
                    type = 'dupl-subgraph'
                    row[3] = row[3].replace('*', '')
                if not row[3]:
                    raise Exception('Missing Annotation:', amr_id)
                node_id = row[1]
                if node_id not in node_labels:
                    raise Exception('Failed to parse node labels:', amr.id, node_id)
                n = node_labels[node_id]
                token_ids = [int(t) for t in row[3].split(',')]
                if any(t >= len(amr.tokens) for t in token_ids):
                    raise Exception('Bad Annotation:', amr_id)
                # a new span must not overlap previously used tokens
                if tuple(token_ids) not in all_spans[amr_id] and any(
                        t in taken for t in token_ids):
                    raise Exception('Bad Span Annotation', amr_id)
                all_spans[amr_id].add(tuple(token_ids))
                taken.update(token_ids)
                # merge into an existing same-type alignment on this span
                align = amr.get_alignment(subgraph_alignments, token_id=token_ids[0])
                if align and align.type == type:
                    align.nodes.append(n)
                else:
                    new_align = AMR_Alignment(type=type, tokens=token_ids,
                                              nodes=[n], amr=amr)
                    subgraph_alignments[amr.id].append(new_align)
            elif row[0] == 'edge':
                type = 'relation'
                if row[3].startswith('*'):
                    row[3] = row[3].replace('*', '')
                if not row[3]:
                    raise Exception('Missing Annotation:', amr_id)
                edge_id = row[1]
                if edge_id not in edge_labels:
                    # NOTE(review): message reports node_id from a previous row,
                    # not edge_id — confirm intended
                    raise Exception('Failed to parse edge labels:', amr.id, node_id)
                e = edge_labels[edge_id]
                token_ids = [int(t) for t in row[3].split(',')]
                if any(t >= len(amr.tokens) for t in token_ids):
                    raise Exception('Bad Annotation:', amr_id)
                if tuple(token_ids) not in all_spans[amr_id] and any(
                        t in taken for t in token_ids):
                    raise Exception('Bad Span Annotation', amr_id, token_ids)
                all_spans[amr_id].add(tuple(token_ids))
                taken.update(token_ids)
                align = amr.get_alignment(relation_alignments, token_id=token_ids[0])
                if align and align.type == type:
                    align.edges.append(e)
                else:
                    new_align = AMR_Alignment(type=type, tokens=token_ids,
                                              edges=[e], amr=amr)
                    relation_alignments[amr.id].append(new_align)
            elif row[0] == 'reentrancy':
                if not row[3]:
                    raise Exception('Missing Annotation:', amr_id)
                edge_id = row[1]
                e = edge_labels[edge_id]
                if row[3].startswith('*'):
                    row[3] = row[3].replace('*', '')
                # '_' means: reuse the relation alignment's span, tag as primary
                if row[3] == '_':
                    token_ids = amr.get_alignment(relation_alignments, edge=e).tokens
                else:
                    token_ids = [int(t) for t in row[3].split(',')]
                tag = row[4]
                if row[3] == '_':
                    tag = 'primary'
                if not tag:
                    raise Exception('Missing reentrancy tag:', amr.id)
                type = f'reentrancy:{tag}'
                if any(t >= len(amr.tokens) for t in token_ids):
                    raise Exception('Bad Annotation:', amr_id)
                if tuple(token_ids) not in all_spans[amr_id] and any(
                        t in taken for t in token_ids):
                    raise Exception('Bad Span Annotation', amr_id, token_ids)
                all_spans[amr_id].add(tuple(token_ids))
                taken.update(token_ids)
                new_align = AMR_Alignment(type=type, tokens=token_ids,
                                          edges=[e], amr=amr)
                reentrancy_alignments[amr.id].append(new_align)
    # post-process: make spans cover every token, sanity-check alignments
    for amr_id in subgraph_alignments:
        amr = amrs[amr_id]
        # every token not inside an annotated span becomes a singleton span
        for t in range(len(amr.tokens)):
            if not any(t in span for span in all_spans[amr_id]):
                all_spans[amr_id].add((t, ))
        spans = [list(span) for span in sorted(all_spans[amr_id], key=lambda x: x[0])]
        for align in subgraph_alignments[amr_id]:
            if align.nodes and not is_subgraph(amr, align.nodes):
                print('Possible Bad align:', amr.id, align.tokens,
                      ' '.join(amr.tokens[t] for t in align.tokens),
                      file=sys.stderr)
        # each relation-aligned edge must touch a node of a co-spanning subgraph
        for align in relation_alignments[amr_id]:
            subgraph_aligns = [a for a in subgraph_alignments[amr.id]
                               if a.tokens == align.tokens]
            for s, r, t in align.edges:
                if subgraph_aligns and not any(
                        s in a.nodes or t in a.nodes or not a.nodes
                        for a in subgraph_aligns):
                    # known exception: ':manner' on the token 'without'
                    if r == ':manner' and amr.tokens[align.tokens[0]] == 'without':
                        continue
                    raise Exception('Bad Relation align:', amr.id, align.tokens, s, r, t)
        # separate duplicate-subgraph alignments from the main set
        dupl_sub_aligns = [align for align in subgraph_alignments[amr_id]
                           if align.type.startswith('dupl')]
        subgraph_alignments[amr_id] = [align for align in subgraph_alignments[amr_id]
                                       if not align.type.startswith('dupl')]
        # dupl_rel_aligns = [align for align in relation_alignments[amr_id] if align.type.startswith('dupl')]
        # relation_alignments[amr_id] = [align for align in relation_alignments[amr_id] if not align.type.startswith('dupl')]
        clean_alignments(amr, subgraph_alignments, dupl_sub_aligns, spans)
        clean_alignments(amr, relation_alignments, [], spans, mode='relations')
        # every token must belong to exactly one span
        for t, _ in enumerate(amr.tokens):
            count = [span for span in spans if t in span]
            if len(count) != 1:
                raise Exception('Bad Span:', amr.id, count)
    # amr_file = amr_file.replace('.txt', '.jakob')
    align_file = amr_file.replace('.txt', '') + f'.subgraph_alignments.gold.json'
    print(f'Writing subgraph alignments to: {align_file}')
    reader.save_alignments_to_json(align_file, subgraph_alignments)
    align_file = amr_file.replace('.txt', '') + f'.relation_alignments.gold.json'
    print(f'Writing relation alignments to: {align_file}')
    reader.save_alignments_to_json(align_file, relation_alignments)
    align_file = amr_file.replace('.txt', '') + f'.reentrancy_alignments.gold.json'
    print(f'Writing reentrancy alignments to: {align_file}')
    reader.save_alignments_to_json(align_file, reentrancy_alignments)
def align(self, amr, relation_alignments, e, unaligned=None, return_all=False):
    """Score candidate spans for aligning edge e and return the best one.

    Candidates come from (1) unfilled token spans lying between the edge's
    parent and child alignments, and (2) the spans of rule-chosen neighbor
    nodes. Scores are log-probability gains of adding e to the span.

    Returns:
        (best_align, best_score), (all_aligns, all_scores) if return_all,
        or (None, None) when there are no candidates.
    """
    # get candidates: spans with no subgraph nodes, no relation alignment yet,
    # and not in the English ignore list
    candidate_spans = [align.tokens for align in self.subgraph_alignments[amr.id]
                       if not align.nodes]
    candidate_spans = [span for span in candidate_spans
                       if not amr.get_alignment(relation_alignments, token_id=span[0])]
    candidate_spans = [span for span in candidate_spans
                       if not english_ignore_tokens(amr, span)]
    candidate_neighbors = rule_based_anchor_relation(e)
    # only align to prepositions between parent and child
    parent = amr.get_alignment(self.subgraph_alignments, node_id=e[0])
    child = amr.get_alignment(self.subgraph_alignments, node_id=e[2])
    candidate_spans = [span for span in candidate_spans
                       if (parent.tokens[0] < span[0] < child.tokens[0])
                       or (child.tokens[0] < span[0] < parent.tokens[0])
                       or ' '.join(amr.lemmas[t] for t in span) == 'ago']
    # make sure rel alignment does not interfere with child and any of its descendents
    child_descendents = set()
    child_descendents.update(child.tokens)
    for s, r, t in amr.edges:
        if s == e[2]:
            talign = amr.get_alignment(self.subgraph_alignments, node_id=t)
            child_descendents.update(talign.tokens)
    if child_descendents:
        start, end = min(child_descendents), max(child_descendents)
        # unless the parent itself sits inside the child's token range
        if not (start <= parent.tokens[0] <= end):
            candidate_spans = [span for span in candidate_spans
                               if not (start <= span[0] <= end)]
    # score unfilled spans: gain of aligning e there vs leaving the span empty
    scores1 = {}
    aligns1 = {}
    for i, span in enumerate(candidate_spans):
        new_align = AMR_Alignment(type='relation', tokens=span, edges=[e], amr=amr)
        replaced_align = AMR_Alignment(type='relation', tokens=span, edges=[], amr=amr)
        scores1[i] = self.logp(amr, relation_alignments, new_align) \
            - self.logp(amr, relation_alignments, replaced_align)
        aligns1[i] = new_align
    # score neighbor spans: gain of appending e to the span's existing edges
    scores2 = {}
    aligns2 = {}
    for i, neighbor in enumerate(candidate_neighbors):
        sub_align = amr.get_alignment(self.subgraph_alignments, node_id=neighbor)
        span = sub_align.tokens
        if not span:
            continue
        if span not in amr.spans:
            raise Exception('Subgraph Alignment has Faulty Span:', span)
        replaced_align = amr.get_alignment(relation_alignments, token_id=span[0])
        new_align = AMR_Alignment(type='relation', tokens=replaced_align.tokens,
                                  edges=replaced_align.edges + [e], amr=amr)
        scores2[i] = self.logp(amr, relation_alignments, new_align) \
            - self.logp(amr, relation_alignments, replaced_align)
        aligns2[i] = new_align
    # merge candidates keyed by token span (scores2 overrides scores1 on ties)
    all_scores = {}
    all_aligns = {}
    for x in scores1:
        span = tuple(aligns1[x].tokens)
        all_scores[span] = scores1[x]
        all_aligns[span] = aligns1[x]
    for x in scores2:
        span = tuple(aligns2[x].tokens)
        all_scores[span] = scores2[x]
        all_aligns[span] = aligns2[x]
    if not all_scores:
        return None, None
    if return_all:
        return all_aligns, all_scores
    best_span = max(all_scores.keys(), key=lambda x: all_scores[x])
    best_score = all_scores[best_span]
    best_align = all_aligns[best_span]
    # readable = [r for r in sorted(readable, key=lambda x:x['score'], reverse=True)]
    return best_align, best_score
def align(self, amr, alignments, n, unaligned=None, return_all=False):
    """Score candidate placements for node n and return the best one.

    Candidates are (1) empty token spans, (2) spans of aligned neighbor
    nodes (n is merged into the neighbor's alignment), and (3) spans of
    identically-labeled duplicate nodes (dupl-subgraph). Scores are
    log-probability gains; duplicates are penalized by DUPLICATE_RATE.

    Returns:
        (best_align, best_score), (all_aligns, all_scores) if return_all,
        or (None, None) when there are no candidates.
    """
    # get candidates
    if unaligned is None:
        unaligned = self.get_unaligned(amr, alignments)
    candidate_spans = [align.tokens for align in alignments[amr.id] if not align.nodes]
    # build a throwaway alignment for n so postprocessing can pull in the
    # nodes it would carry along
    tmp_align = AMR_Alignment(type='subgraph', tokens=[0], nodes=[n])
    postprocess_subgraph(amr, alignments, tmp_align, english=ENGLISH)
    candidate_neighbors = [s for s, r, t in amr.edges
                           if t in tmp_align.nodes and s not in unaligned] + \
                          [t for s, r, t in amr.edges
                           if s in tmp_align.nodes and t not in unaligned]
    # drop neighbors that are unaligned or only duplicate-aligned
    for n2 in candidate_neighbors[:]:
        nalign = amr.get_alignment(alignments, node_id=n2)
        if not nalign or nalign.type == 'dupl-subgraph':
            candidate_neighbors.remove(n2)
    # handle "never => ever, -" and other similar cases
    edge_map = {n: [] for n in amr.nodes}
    for s, r, t in amr.edges:
        edge_map[s].append(t)
    if not edge_map[n]:
        # n is a leaf: also consider aligned sibling leaves under a shared parent
        for n2 in amr.nodes:
            if edge_map[n2]:
                continue
            if n2 in unaligned:
                continue
            if amr.nodes[n] == amr.nodes[n2]:
                continue
            nalign = amr.get_alignment(alignments, node_id=n2)
            if len(nalign.nodes) != 1:
                continue
            if any(n in edge_map[p] and n2 in edge_map[p] for p in amr.nodes):
                candidate_neighbors.append(n2)
    # special rules for multi-sentence, and, or
    if ENGLISH:
        candidate_spans2 = [span for span in candidate_spans
                            if not english_is_alignment_forbidden(amr, span, n)]
        if amr.nodes[n] == 'multi-sentence' and not candidate_spans:
            candidate_spans2 = candidate_spans
        elif amr.nodes[n] == 'and' and not candidate_spans:
            candidate_spans2 = candidate_spans
        candidate_spans = candidate_spans2
    # connectives prefer empty spans over neighbor merging
    if amr.nodes[n] in ['multi-sentence', 'and', 'or'] and candidate_spans:
        candidate_neighbors = []
    for n2 in candidate_neighbors[:]:
        if amr.nodes[n2] in ['multi-sentence', 'and', 'or'] and candidate_spans:
            candidate_neighbors.remove(n2)
    # when n's label occurs on several nodes under a comparison/coordination
    # parent, suppress span candidates unless the text has enough matching lemmas
    if len([n2 for n2 in amr.nodes if amr.nodes[n] == amr.nodes[n2]]) > 1:
        for s, r, t in amr.edges:
            if t == n and amr.nodes[s] in ['include-91', 'same-01', 'instead-of-91',
                                           'resemble-01', 'differ-02', 'and', 'or']:
                if len([lemma for lemma in amr.lemmas
                        if amr.nodes[n].split('-')[0] == lemma]) >= \
                        len([n2 for n2 in amr.nodes if amr.nodes[n] == amr.nodes[n2]]):
                    break
                for s2, r2, t2 in amr.edges:
                    if s2 == s and t2 != t and amr.nodes[t2] == amr.nodes[n] \
                            and r2.endswith('1'):
                        candidate_spans = []
                        break
                    elif t2 == s and amr.nodes[s2] == amr.nodes[n] \
                            and r2.endswith('1-of'):
                        candidate_spans = []
                        break
    # duplicate candidates: spans of other nodes with the same label
    # (skipped entirely for numbers and quoted constants)
    candidate_duplicates = []
    for n2 in amr.nodes:
        if amr.nodes[n].isdigit() or '"' in amr.nodes[n]:
            break
        if n2 != n and amr.nodes[n] == amr.nodes[n2]:
            align = amr.get_alignment(alignments, node_id=n2)
            if align:
                candidate_duplicates.append(align.tokens)
    # score empty spans: gain of placing n there vs leaving the span empty
    scores1 = {}
    aligns1 = {}
    for i, span in enumerate(candidate_spans):
        new_align = AMR_Alignment(type='subgraph', tokens=span, nodes=[n], amr=amr)
        replaced_align = AMR_Alignment(type='subgraph', tokens=span, nodes=[], amr=amr)
        scores1[i] = self.logp(amr, alignments, new_align) \
            - self.logp(amr, alignments, replaced_align)
        aligns1[i] = new_align
    # score neighbor merges: gain of adding n to the neighbor's alignment
    scores2 = {}
    aligns2 = {}
    for i, neighbor in enumerate(candidate_neighbors):
        replaced_align = amr.get_alignment(alignments, node_id=neighbor)
        if replaced_align.type.startswith('dupl'):
            continue
        new_align = AMR_Alignment(type=replaced_align.type,
                                  tokens=replaced_align.tokens,
                                  nodes=replaced_align.nodes + [n], amr=amr)
        scores2[i] = self.logp(amr, alignments, new_align) \
            - self.logp(amr, alignments, replaced_align, postprocess=False)
        aligns2[i] = new_align
    # score duplicates with a fixed penalty of log(DUPLICATE_RATE)
    scores3 = {}
    aligns3 = {}
    if self.align_duplicates:
        for i, span in enumerate(candidate_duplicates):
            new_align = AMR_Alignment(type='dupl-subgraph', tokens=span,
                                      nodes=[n], amr=amr)
            replaced_align = amr.get_alignment(alignments, token_id=span[0])
            scores3[i] = math.log(DUPLICATE_RATE) \
                + self.logp(amr, alignments, new_align) \
                - self.logp(amr, alignments, replaced_align, postprocess=False)
            aligns3[i] = new_align
    # merge all candidates keyed by token span (later groups override earlier)
    all_scores = {}
    all_aligns = {}
    for x in scores1:
        span = tuple(aligns1[x].tokens)
        all_scores[span] = scores1[x]
        all_aligns[span] = aligns1[x]
    for x in scores2:
        span = tuple(aligns2[x].tokens)
        all_scores[span] = scores2[x]
        all_aligns[span] = aligns2[x]
    for x in scores3:
        span = tuple(aligns3[x].tokens)
        all_scores[span] = scores3[x]
        all_aligns[span] = aligns3[x]
    if not all_scores:
        return None, None
    if return_all:
        return all_aligns, all_scores
    best_span = max(all_scores.keys(), key=lambda x: all_scores[x])
    best_score = all_scores[best_span]
    best_align = all_aligns[best_span]
    # readable = [r for r in sorted(readable, key=lambda x:x['score'], reverse=True)]
    return best_align, best_score
def main():
    """Convert TAMR alignment files for the Szubert AMRs into a subgraph
    alignment JSON file.

    Reads per-file TAMR alignments, keeps only AMRs whose graph matches the
    2017 LDC version, remaps TAMR dotted node ids onto this corpus's node
    ids, pads every token with an (possibly empty) alignment, and saves.
    """
    dir = '../data/tamr'
    szubert_amrs = '../data/szubert/szubert_amrs.txt'
    output = '../data/szubert/szubert_amrs.tamr.subgraph_alignments.json'
    file2 = '../data/tamr/ldc_train_2017.txt'
    reader = AMR_Reader()
    amrs = reader.load(szubert_amrs, remove_wiki=True)
    amrs2 = reader.load(file2, remove_wiki=True)
    alignments = {}
    # parse every .tamr_alignment file: an id line, then '# ::alignments' lines
    for filename in os.listdir(dir):
        if filename.endswith(".tamr_alignment"):
            file = os.path.join(dir, filename)
            amr_id = ''
            with open(file) as f:
                for line in f:
                    if line.startswith('# ::alignments'):
                        aligns = line[len('# ::alignments '):].split()
                        aligns = [s.split('|') for s in aligns if '|' in s]
                        aligns = [(a[0], a[1].split('+')) for a in aligns]
                        for span, nodes in aligns:
                            # span format 'start-end', end exclusive
                            start = int(span.split('-')[0])
                            end = int(span.split('-')[1])
                            span = [t for t in range(start, end)]
                            align = AMR_Alignment(type='subgraph', tokens=span, nodes=nodes)
                            alignments[amr_id].append(align)
                    elif line.strip():
                        # any other non-blank line is the AMR id for what follows
                        amr_id = line.strip()
                        alignments[amr_id] = []
    amrs2 = {amr.id: amr for amr in amrs2}
    amrs = [amr for amr in amrs if amr.id in alignments and amr.id in amrs2]
    # keep AMRs whose node and edge label sets match the 2017 release
    # NOTE(review): amrs3 is built but never read afterwards — confirm
    amrs3 = []
    for amr in amrs[:]:
        amr2 = amrs2[amr.id]
        nodes = {amr.nodes[n] for n in amr.nodes}
        nodes2 = {amr2.nodes[n] for n in amr2.nodes}
        edges = {(amr.nodes[s], r, amr.nodes[t]) for s, r, t in amr.edges}
        edges2 = {(amr2.nodes[s], r, amr2.nodes[t]) for s, r, t in amr2.edges}
        if nodes == nodes2 and edges == edges2:
            amrs3.append(amr)
    amr_ids = [amr.id for amr in amrs]
    alignments = {amr_id: alignments[amr_id] for amr_id in alignments
                  if amr_id in amr_ids}
    for amr in amrs:
        # TAMR node ids are 0-based dotted paths; this corpus appears to use
        # 1-based paths, so try shifting each path component by +1 — confirm
        node_map = {}
        nodes = [n for align in alignments[amr.id] for n in align.nodes]
        nodes = [n for n in sorted(nodes, key=lambda x: (len(x), x))]
        for n in nodes:
            prefix = '.'.join(i for i in n.split('.')[:-1])
            last = int(n.split('.')[-1])
            if prefix:
                if prefix not in node_map:
                    new_prefix = '.'.join(str(int(i) + 1) for i in n.split('.')[:-1])
                    if new_prefix not in amr.nodes:
                        continue
                    node_map[prefix] = new_prefix
                new_n = node_map[prefix] + '.' + str(last + 1)
            else:
                new_n = str(last + 1)
            if new_n in amr.nodes:
                node_map[n] = new_n
        # second pass: resolve remaining ids via the mapped parent's edges
        nodes = [n for align in alignments[amr.id] for n in align.nodes
                 if n not in node_map]
        nodes = [n for n in sorted(nodes, key=lambda x: (len(x), x))]
        for n in nodes:
            prefix = '.'.join(i for i in n.split('.')[:-1])
            if prefix not in node_map:
                new_prefix = '.'.join(str(int(i) + 1) for i in n.split('.')[:-1])
                if new_prefix in amr.nodes:
                    node_map[prefix] = new_prefix
                else:
                    # unmappable: drop this AMR's alignments entirely
                    del alignments[amr.id]
                    break
            candidates = [t for s, r, t in amr.edges if s == node_map[prefix]]
            candidates = [t for t in candidates if t not in node_map.values()]
            candidates = [t for t in sorted(candidates)]
            if not candidates:
                del alignments[amr.id]
                break
            new_n = candidates[0]
            node_map[n] = new_n
        if amr.id in alignments:
            # rewrite node ids and attach the AMR to each alignment
            for align in alignments[amr.id]:
                align.nodes = [node_map[n] for n in align.nodes]
                align.amr = amr
            # ensure every token has an alignment, even if empty
            for t, tok in enumerate(amr.tokens):
                align = amr.get_alignment(alignments, token_id=t)
                if not align:
                    align = AMR_Alignment(type='subgraph', tokens=[t], nodes=[], amr=amr)
                    alignments[amr.id].append(align)
            alignments[amr.id] = [align for align in
                                  sorted(alignments[amr.id], key=lambda a: a.tokens[0])]
    reader.save_alignments_to_json(output, alignments)
def align(self, amr, reentrancy_alignments, e, unaligned=None, return_all=False):
    """Score candidate spans for aligning reentrant edge e and return the best.

    Candidate spans are those with a non-empty allowed reentrancy type for e;
    each is scored by the model log-probability, with 'pragmatic' candidates
    penalized by PRAGMATIC_RATE.

    Returns:
        (best_align, best_score), (all_aligns, all_scores) if return_all,
        or (None, None) when there are no candidates.
    """
    # get candidates
    allowed_types = self.get_allowed_types(amr)
    candidate_spans = [span for span in amr.spans if allowed_types[e][tuple(span)]]
    # candidate_spans = [span for span in candidate_spans if not amr.get_alignment(reentrancy_alignments, token_id=span[0])]
    candidate_neighbors = []  #[e]
    # alignments of the other edges re-entering the same node (currently unused
    # beyond inspection — the commented filter below consumed it)
    neighbor_aligns = [amr.get_alignment(reentrancy_alignments, edge=(s, r, t))
                       for s, r, t in amr.reentrancies
                       if t == e[-1] and e != (s, r, t)]
    # if all(a.type!='reentrancy:primary' for a in neighbor_aligns):
    #     candidate_spans = []
    readable = []
    # score spans with their first allowed reentrancy type
    scores1 = {}
    aligns1 = {}
    for i, span in enumerate(candidate_spans):
        type = allowed_types[e][tuple(span)][0]
        new_align = AMR_Alignment(type=f'reentrancy:{type}', tokens=span,
                                  edges=[e], amr=amr)
        # replaced_align = AMR_Alignment(type='relation', tokens=span, edges=[], amr=amr)
        scores1[i] = self.logp(amr, reentrancy_alignments, new_align)  #- self.logp(amr, reentrancy_alignments, replaced_align)
        # scores1[i] = self.inductive_bias(amr, reentrancy_alignments, new_align) - self.inductive_bias(amr, reentrancy_alignments, replaced_align)
        if type == 'pragmatic':
            scores1[i] += math.log(PRAGMATIC_RATE)
        aligns1[i] = new_align
        # readable.append(self.readable_logp(amr,alignments, new_align))
    # score neighbor candidates as 'primary' (dead code while
    # candidate_neighbors stays empty)
    scores2 = {}
    aligns2 = {}
    for i, neighbor in enumerate(candidate_neighbors):
        rel_align = amr.get_alignment(self.relation_alignments, edge=neighbor)
        span = rel_align.tokens
        if not span:
            continue
        if span not in amr.spans:
            raise Exception('Relation Alignment has Faulty Span:', span)
        new_align = AMR_Alignment(type='reentrancy:primary', tokens=rel_align.tokens,
                                  edges=[e], amr=amr)
        scores2[i] = self.logp(amr, reentrancy_alignments, new_align)  #- self.logp(amr, reentrancy_alignments, replaced_align)
        # scores2[i] = self.inductive_bias(amr, reentrancy_alignments, new_align) - self.inductive_bias(amr, reentrancy_alignments, replaced_align)
        aligns2[i] = new_align
        # readable.append(self.readable_logp(amr, alignments, new_align))
    # merge candidates keyed by token span (scores2 overrides scores1 on ties)
    all_scores = {}
    all_aligns = {}
    for x in scores1:
        span = tuple(aligns1[x].tokens)
        all_scores[span] = scores1[x]
        all_aligns[span] = aligns1[x]
    for x in scores2:
        span = tuple(aligns2[x].tokens)
        all_scores[span] = scores2[x]
        all_aligns[span] = aligns2[x]
    if not all_scores:
        return None, None
    if return_all:
        return all_aligns, all_scores
    best_span = max(all_scores.keys(), key=lambda x: all_scores[x])
    best_score = all_scores[best_span]
    best_align = all_aligns[best_span]
    # readable = [r for r in sorted(readable, key=lambda x:x['score'], reverse=True)]
    return best_align, best_score