예제 #1
0
    def _parse_amr_from_metadata(tokens, metadata):
        '''
           Build an AMR and its JAMR-style alignments from parsed metadata.

           Metadata format is ...
           # ::id sentence id
           # ::tok tokens...
           # ::node node_id node alignments
           # ::root root_id root
           # ::edge src label trg src_id trg_id alignments
           amr graph

           Args:
               tokens: sentence tokens, passed to ``AMR(tokens=...)``.
               metadata: mapping from metadata key ('node', 'root', 'edge')
                   to a list of field lists for that key. 'node' and 'root'
                   are required; 'edge' is optional.

           Returns:
               ``(amr, alignments)`` where ``alignments`` is a list of
               ``AMR_Alignment(type='jamr', ...)`` built from the optional
               trailing alignment field of node and edge lines.
           '''
        amr = AMR(tokens=tokens)
        alignments = []

        nodes = metadata['node']
        edges = metadata['edge'] if 'edge' in metadata else []
        # first '::root' line; its first field is the root id
        root = metadata['root'][0]
        amr.root = root[0]
        for data in nodes:
            n, label = data[:2]
            # optional third field holds the aligned tokens
            if len(data) > 2:
                toks = data[2]
                alignments.append(
                    AMR_Alignment(type='jamr', nodes=[n], tokens=toks))
            amr.nodes[n] = label
        for data in edges:
            # fields 0 and 2 are the src/trg labels; the ids (3, 4) are used
            _, r, _, s, t = data[:5]
            # Normalize the role BEFORE building the alignment so the aligned
            # edge tuple is identical to the one stored in amr.edges.
            if not r.startswith(':'):
                r = ':' + r
            if len(data) > 5:
                toks = data[5]
                alignments.append(
                    AMR_Alignment(type='jamr', edges=[(s, r, t)], tokens=toks))
            amr.edges.append((s, r, t))
        return amr, alignments
예제 #2
0
    def __init__(self,
                 amrs,
                 subgraph_alignments,
                 relation_alignments,
                 alpha=1):
        """Set up distance models, counters and alignment references.

        Args:
            amrs: AMRs to align; forwarded to the base class with ``alpha``.
            subgraph_alignments: dict keyed by ``amr.id`` holding each AMR's
                subgraph alignments (read in the loop below and stored).
            relation_alignments: relation alignments, stored on the instance.
            alpha: forwarded to ``super().__init__``.
        """
        super().__init__(amrs, alpha)

        # separate distance models for the parent side and the child side
        self.distance_model_parent = Skellam_Distance_Model()
        self.distance_model_child = Skellam_Distance_Model()

        self.subgraph_alignments = subgraph_alignments
        self.relation_alignments = relation_alignments

        self.edges_count = Counter()
        self.edges_total = 0

        # NOTE(review): presumably a lazily-filled memo; only initialized here
        self.allowed_types_memo_ = None

        # Collect the labels get_alignment_label() produces for edges that
        # cross the boundary of a subgraph alignment.
        # NOTE(review): edge_labels and the per-span `edges` set below are not
        # read in the lines shown here -- confirm they are used or drop them.
        edge_labels = set()
        for amr in amrs:
            # parents[n]: edges whose source is n (n is the parent);
            # children[n]: edges whose target is n (n is the child).
            parents = {s: [] for s in amr.nodes}
            children = {t: [] for t in amr.nodes}
            # lemma strings of node-less alignments already handled for this AMR
            taken_tokens = set()
            for s, r, t in amr.edges:
                parents[s].append((s, r, t))
                children[t].append((s, r, t))
            for align in subgraph_alignments[amr.id]:
                token_label = ' '.join(amr.lemmas[t] for t in align.tokens)
                for n in align.nodes:
                    # outgoing edges that leave the aligned subgraph
                    for e in parents[n]:
                        s, r, t = e
                        if t in align.nodes: continue
                        partial_align = AMR_Alignment(type='relation',
                                                      tokens=align.tokens,
                                                      edges=[e])
                        edge_label = self.get_alignment_label(
                            amr, partial_align)
                        edge_labels.add(edge_label)
                    # incoming edges that enter the aligned subgraph
                    for e in children[n]:
                        s, r, t = e
                        if s in align.nodes: continue
                        partial_align = AMR_Alignment(type='relation',
                                                      tokens=align.tokens,
                                                      edges=[e])
                        edge_label = self.get_alignment_label(
                            amr, partial_align)
                        edge_labels.add(edge_label)
                # alignment without nodes: label every edge of the AMR against
                # this span, once per distinct lemma string in this AMR
                if not align.nodes and token_label not in taken_tokens:
                    edges = set()
                    for e in amr.edges:
                        partial_align = AMR_Alignment(type='relation',
                                                      tokens=align.tokens,
                                                      edges=[e])
                        edge_label = self.get_alignment_label(
                            amr, partial_align)
                        edges.add(edge_label)
                    taken_tokens.add(token_label)
예제 #3
0
def main():
    """Convert ISI alignments of the Szubert AMRs into subgraph/relation JSON.

    Loads the LDC training AMRs and the Szubert AMRs (wiki removed), reads
    ISI-style alignments from ``file``, maps ISI node/edge ids onto the
    reference AMR with ``node_map``/``edge_map``, and builds one single-token
    'subgraph' alignment and one 'relation' alignment per token. The results
    are written to ``<output>.subgraph_alignments.json`` and
    ``<output>.relation_alignments.json``.

    Raises:
        Exception: if an ISI AMR is tokenized differently from its reference
            AMR, or if a Szubert AMR received no alignments. Both checks run
            before anything is written.
    """
    file = '../data/szubert/szubert_amrs.isi_alignments.txt'
    ids_file = '../data/szubert/szubert_ids.isi.txt'
    output = '../data/szubert/szubert_amrs.isi.txt'

    amr_file1 = '../data/ldc_train.txt'
    amr_file2 = '../data/szubert/szubert_amrs.txt'
    reader = AMR_Reader()
    amrs = reader.load(amr_file1, remove_wiki=True)
    szubert_amrs = reader.load(amr_file2, remove_wiki=True)
    # a set: only used for membership tests in the loop below
    szubert_amr_ids = {amr.id for amr in szubert_amrs}
    amrs += szubert_amrs
    amrs = {amr.id: amr for amr in amrs}

    # NOTE(review): amr_ids is collected but not used below.
    amr_ids = []
    with open(ids_file, encoding='utf8') as f:
        for line in f:
            # A line read from a file keeps its '\n' and is always truthy;
            # strip first so blank lines are really skipped.
            line = line.strip()
            if line:
                amr_ids.append(line)

    isi_amrs, isi_alignments = reader.load(file, output_alignments=True)

    subgraph_alignments = {}
    relation_alignments = {}
    for isi_amr in isi_amrs:
        if isi_amr.id not in szubert_amr_ids:
            continue
        amr = amrs[isi_amr.id]
        if len(amr.tokens) != len(isi_amr.tokens):
            raise Exception('Inconsistent Tokenization:', amr.id)
        # ISI node/edge ids -> node/edge ids of the reference AMR
        node_labels = node_map(isi_amr, amr)
        edge_labels = edge_map(isi_amr, amr)
        isi_aligns = isi_alignments[amr.id]
        subgraph_alignments[amr.id] = []
        relation_alignments[amr.id] = []
        for i, _ in enumerate(amr.tokens):
            aligns = [align for align in isi_aligns if i in align.tokens]
            nodes = [node_labels[n] for align in aligns for n in align.nodes]
            edges = [edge_labels[e] for align in aligns for e in align.edges]
            subgraph_alignments[amr.id].append(
                AMR_Alignment(type='subgraph', tokens=[i], nodes=nodes))
            relation_alignments[amr.id].append(
                AMR_Alignment(type='relation', tokens=[i], edges=edges))

    # Validate before writing so a failing run leaves no partial output.
    for amr in szubert_amrs:
        if amr.id not in subgraph_alignments:
            raise Exception('Missing AMR:', amr.id)

    reader.save_alignments_to_json(
        output.replace('.txt', '.subgraph_alignments.json'),
        subgraph_alignments)
    reader.save_alignments_to_json(
        output.replace('.txt', '.relation_alignments.json'),
        relation_alignments)
예제 #4
0
    def align_all(self, amrs, alignments=None, preprocess=True, debug=False):
        """Run the base aligner, then patch graph coverage for every AMR.

        After ``super().align_all``:

        * if an AMR has nodes and tokens but none of its alignments is truthy,
          the whole graph is aligned to its first span (degenerate sentences);
        * every unaligned node is attached to the alignment of one of its
          aligned parents. This repeats until nothing more can be attached,
          so chains of unaligned nodes are handled regardless of node order;
        * every alignment with more than one node gains the AMR edges that
          connect two of its nodes.

        Returns:
            The alignments dict (keyed by ``amr.id``) from the base class,
            modified in place.
        """
        alignments = super().align_all(amrs, alignments, preprocess, debug)

        for amr in amrs:
            # hack to handle degenerate sentences
            if amr.nodes and amr.tokens and not any(
                    align for align in alignments[amr.id]):
                new_align = AMR_Alignment(type='subgraph',
                                          tokens=amr.spans[0],
                                          nodes=[n for n in amr.nodes],
                                          edges=[e for e in amr.edges],
                                          amr=amr)
                alignments[amr.id].append(new_align)

            # Attach unaligned nodes to a parent's alignment. A parent that is
            # itself unaligned yields a detached empty alignment from
            # get_alignment, so only attach when the parent really is in the
            # returned alignment; loop until no further progress is made.
            pending = [n for n in amr.nodes
                       if not amr.get_alignment(alignments, node_id=n)]
            progress = True
            while pending and progress:
                progress = False
                for n in list(pending):
                    for s, _, t in amr.edges:
                        if t != n:
                            continue
                        align = amr.get_alignment(alignments, node_id=s)
                        if s in align.nodes:
                            align.nodes.append(n)
                            pending.remove(n)
                            progress = True
                            break

            # add subgraph edges
            for align in alignments[amr.id]:
                if len(align.nodes) > 1:
                    for e in amr.edges:
                        s, r, t = e
                        if s in align.nodes and t in align.nodes and e not in align.edges:
                            align.edges.append(e)

        return alignments
예제 #5
0
 def get_alignment(self,
                   alignments,
                   token_id=None,
                   node_id=None,
                   edge=None):
     """Return this AMR's first alignment matching any supplied criterion.

     Alignments in ``alignments[self.id]`` are scanned in order. An alignment
     matches when ``token_id`` is in its tokens, ``node_id`` is in its nodes,
     or ``edge`` is in its edges (criteria left as ``None`` are ignored).

     Raises:
         Exception: if ``alignments`` is not a dict.

     Returns:
         The matching alignment, or a new empty ``AMR_Alignment()`` when this
         AMR has no entry or nothing matches.
     """
     if not isinstance(alignments, dict):
         raise Exception('Alignments object must be a dict.')
     if self.id in alignments:
         for candidate in alignments[self.id]:
             by_token = token_id is not None and token_id in candidate.tokens
             if by_token or (node_id is not None and node_id in candidate.nodes) \
                     or (edge is not None and edge in candidate.edges):
                 return candidate
     return AMR_Alignment()
예제 #6
0
def clean_alignments(amr, alignments, spans):
    """Replace ``alignments[amr.id]`` with exactly one alignment per span.

    For each span (in order), the alignment that ``amr.get_alignment`` finds
    for the span's first token is reused when it is truthy. Otherwise a new,
    empty 'subgraph' alignment over the span is created.
    """
    def _for_span(span):
        found = amr.get_alignment(alignments, token_id=span[0])
        if found:
            return found
        return AMR_Alignment(type='subgraph', tokens=span, amr=amr)

    alignments[amr.id] = [_for_span(span) for span in spans]
예제 #7
0
    def _parse_isi_alignments(amr, amr_file, aligns, isi_labels,
                              isi_edge_labels):
        aligns = [(int(a.split('-')[0]), a.split('-')[-1]) for a in aligns
                  if '-' in a]

        alignments = []
        xml_offset = 1 if amr.tokens[0].startswith(
            '<') and amr.tokens[0].endswith('>') else 0
        if any(t + xml_offset >= len(amr.tokens) for t, n in aligns):
            xml_offset = 0

        for tok, component in aligns:
            tok += xml_offset
            nodes = []
            edges = []
            if component.replace('.r', '') in isi_labels:
                # node or attribute
                n = isi_labels[component.replace('.r', '')]
                if n == 'ignore': continue
                nodes.append(n)
                if n not in amr.nodes:
                    raise Exception('Could not parse alignment:', amr_file,
                                    amr.id, tok, component)
            elif not component.endswith(
                    '.r'
            ) and component not in isi_labels and component + '.r' in isi_edge_labels:
                # reentrancy
                e = isi_edge_labels[component + '.r']
                edges.append(e)
                if e not in amr.edges:
                    raise Exception('Could not parse alignment:', amr_file,
                                    amr.id, tok, component)
            elif component.endswith('.r'):
                # edge
                e = isi_edge_labels[component]
                if e == 'ignore': continue
                edges.append(e)
                if e not in amr.edges:
                    raise Exception('Could not parse alignment:', amr_file,
                                    amr.id, tok, component)
            elif component == '0.r':
                nodes.append(amr.root)
            else:
                raise Exception('Could not parse alignment:', amr_file, amr.id,
                                tok, component)
            if tok >= len(amr.tokens):
                raise Exception('Could not parse alignment:', amr_file, amr.id,
                                tok, component)
            new_align = AMR_Alignment(type='isi',
                                      tokens=[tok],
                                      nodes=nodes,
                                      edges=edges)
            alignments.append(new_align)
        return alignments
예제 #8
0
def _merge_duplicate_alignments(alignments):
    """Collapse 'dupl*' alignments into one 'subgraph:dupl' alignment per span.

    Alignments whose type starts with 'dupl' are grouped by token span. The
    nodes of each group are unioned, and spans keep first-seen order.
    """
    duplicates = {}
    for align in alignments:
        if align.type.startswith('dupl'):
            duplicates.setdefault(tuple(align.tokens), set()).update(align.nodes)
    return [
        AMR_Alignment(type='subgraph:dupl',
                      tokens=list(span),
                      nodes=list(nodes)) for span, nodes in duplicates.items()
    ]


def evaluate_duplicates(amrs, pred_alignments, gold_alignments):
    """Evaluate predicted vs. gold duplicate-subgraph alignments on nodes.

    Both sides are merged per AMR with ``_merge_duplicate_alignments`` and
    passed to ``evaluate(..., mode='nodes')``.
    """
    print('duplicates')
    duplicate_alignments = {}
    gold_duplicate_alignments = {}

    for amr in amrs:
        duplicate_alignments[amr.id] = _merge_duplicate_alignments(
            pred_alignments[amr.id])
        gold_duplicate_alignments[amr.id] = _merge_duplicate_alignments(
            gold_alignments[amr.id])

    evaluate(amrs,
             duplicate_alignments,
             gold_duplicate_alignments,
             mode='nodes')
예제 #9
0
    def get_initial_alignments(self, amrs, preprocess=True):
        """Seed one empty 'relation' alignment per span and apply the rules.

        For each AMR, one relation alignment is created per span. Then
        ``rule_based_align_relations`` and ``exact_match_relations`` run
        against ``self.subgraph_alignments``. ``preprocess`` is accepted for
        interface compatibility and is not read here.

        Returns:
            dict keyed by ``amr.id`` with the relation alignments.
        """
        relation_alignments = {}
        total = len(amrs)
        for idx, amr in enumerate(amrs):
            print(f'\r{idx} / {total} preprocessed', end='')
            relation_alignments[amr.id] = [
                AMR_Alignment(type='relation', tokens=span, amr=amr)
                for span in amr.spans
            ]
            rule_based_align_relations(amr, self.subgraph_alignments,
                                       relation_alignments)
            exact_match_relations(amr, self.subgraph_alignments,
                                  relation_alignments)
        print('\r', end='')
        coverage_value = self.coverage(amrs, relation_alignments)
        print('Preprocessing coverage:', coverage_value)
        return relation_alignments
예제 #10
0
def add_relation_alignment(amr, relation_alignments, edge, span):
    """Attach ``edge`` to the relation alignment covering exactly ``span``.

    If an alignment with ``tokens == span`` exists, the edge is added to it
    (at most once). Otherwise a new 'relation' alignment is created and the
    AMR's alignments are re-sorted by first token.

    Raises:
        Exception: if ``span`` is empty.
    """
    if not span:
        raise Exception('Tried to align to empty span.')
    for align in relation_alignments[amr.id]:
        if align.tokens == span:
            # don't record the same edge twice on one alignment
            if edge not in align.edges:
                align.edges.append(edge)
            return
    relation_alignments[amr.id].append(
        AMR_Alignment(type='relation', tokens=span, edges=[edge]))
    relation_alignments[amr.id] = sorted(relation_alignments[amr.id],
                                         key=lambda x: x.tokens[0])
예제 #11
0
def separate_components(amr, align):
    """Split a subgraph alignment into connected pieces.

    * No nodes, or nodes already forming a subgraph (``is_subgraph``):
      ``[align]`` is returned unchanged.
    * Several nodes that all carry the same concept label: one
      single-node 'subgraph' alignment per node.
    * Otherwise: one 'subgraph' alignment per connected component found by
      ``get_connected_components``.
    """
    if not align.nodes:
        return [align]

    labels = [amr.nodes[n] for n in align.nodes]
    same_concept = all(label == labels[0] for label in labels)
    if len(labels) > 1 and same_concept:
        return [
            AMR_Alignment(type='subgraph', tokens=align.tokens, nodes=[n], amr=amr)
            for n in align.nodes
        ]

    if is_subgraph(amr, align.nodes):
        return [align]

    node_groups = [
        list(piece.nodes.keys())
        for piece in get_connected_components(amr, align.nodes)
    ]
    return [
        AMR_Alignment(type='subgraph', tokens=align.tokens, nodes=group, amr=amr)
        for group in node_groups
    ]
예제 #12
0
    def _parse_jamr_alignments(amr, amr_file, aligns, jamr_labels,
                               metadata_parser):
        aligns = [(metadata_parser.get_token_range(a.split('|')[0]),
                   a.split('|')[-1].split('+')) for a in aligns if '|' in a]

        alignments = []
        for toks, components in aligns:
            if not all(n in jamr_labels
                       for n in components) or any(t >= len(amr.tokens)
                                                   for t in toks):
                raise Exception('Could not parse alignment:', amr_file, amr.id,
                                toks, components)
            nodes = [jamr_labels[n] for n in components]
            new_align = AMR_Alignment(type='jamr', tokens=toks, nodes=nodes)
            alignments.append(new_align)
        return alignments
예제 #13
0
 def get_initial_alignments(self, amrs, preprocess=True):
     """Build the initial subgraph alignments for every AMR.

     One empty 'subgraph' alignment is created per span of each AMR. When
     ``preprocess`` is true, ``fuzzy_align_subgraphs`` runs, then each
     alignment is passed through ``postprocess_subgraph`` and
     ``clean_subgraph``. If ``clean_subgraph`` returns None, that alignment's
     nodes are cleared.

     Returns:
         dict keyed by ``amr.id`` with the subgraph alignments.
     """
     print(f'Apply Rules = {preprocess}')
     alignments = {}
     for j, amr in enumerate(amrs):
         print(f'\rPreprocessing: {j} / {len(amrs)}', end='')
         alignments[amr.id] = []
         for span in amr.spans:
             alignments[amr.id].append(
                 AMR_Alignment(type='subgraph', tokens=span, amr=amr))
         if preprocess:
             fuzzy_align_subgraphs(amr, alignments, english=ENGLISH)
             # NOTE(review): the helpers below receive the whole alignments
             # dict and may modify it; this loop iterates the live list.
             for align in alignments[amr.id]:
                 postprocess_subgraph(amr,
                                      alignments,
                                      align,
                                      english=ENGLISH)
                 test = clean_subgraph(amr, alignments, align)
                 # clean_subgraph returning None discards this alignment's nodes
                 if test is None:
                     align.nodes.clear()
     print('\r', end='')
     print('Preprocessing coverage:', coverage(amrs, alignments))
     return alignments
예제 #14
0
 def align_primary_edges(self, amr, alignments):
     """Add a 'reentrancy:primary' alignment for each reentrant node.

     For every node with more than one incoming edge, one incoming edge is
     chosen as primary:

     * If the relation alignment at the node's first aligned token already
       holds one of the incoming edges, that edge is chosen and aligned to
       the node's tokens.
     * Otherwise, the candidate whose source is closest to the node is
       chosen. Closeness is ``abs(self.distance_model_parent.distance(...))``,
       with ties broken by the source's first token. Only relation-aligned
       candidates are considered when at least one exists. The chosen edge
       is aligned to its relation alignment's tokens.

     Nodes without aligned tokens, and candidates whose source has no aligned
     tokens, are skipped because there is no position to measure from.
     Appends to ``alignments[amr.id]``; sets ``amr.reentrancies`` if missing.
     """
     if not hasattr(amr, 'reentrancies'):
         # count incoming edges once instead of rescanning all edges per edge
         in_degree = {}
         for edge in amr.edges:
             in_degree[edge[-1]] = in_degree.get(edge[-1], 0) + 1
         amr.reentrancies = [e for e in amr.edges if in_degree[e[-1]] > 1]
     targets = {e[-1] for e in amr.reentrancies}
     for t in targets:
         candidates = [e for e in amr.reentrancies if e[-1] == t]
         talign = amr.get_alignment(self.subgraph_alignments, node_id=t)
         if not talign.tokens:
             # unaligned reentrant node: nothing to anchor on
             continue
         rel_align = amr.get_alignment(self.relation_alignments,
                                       token_id=talign.tokens[0])
         if rel_align and any(e in rel_align.edges for e in candidates):
             span = talign.tokens
             e = next(e for e in candidates if e in rel_align.edges)
         else:
             any_aligned = any(
                 amr.get_alignment(self.relation_alignments, edge=e2)
                 for e2 in candidates)
             dists = {}
             for cand in candidates:
                 # prefer relation-aligned candidates when any exist
                 if any_aligned and not amr.get_alignment(
                         self.relation_alignments, edge=cand):
                     continue
                 salign = amr.get_alignment(self.subgraph_alignments,
                                            node_id=cand[0])
                 if not salign.tokens:
                     # unaligned source: no distance can be measured
                     continue
                 dist = self.distance_model_parent.distance(
                     amr, salign.tokens, talign.tokens)
                 dists[cand] = (abs(dist), salign.tokens[0])
             if not dists:
                 continue
             e = min(dists, key=dists.get)
             span = amr.get_alignment(self.relation_alignments, edge=e).tokens
         if not span:
             continue
         alignments[amr.id].append(
             AMR_Alignment(type='reentrancy:primary',
                           tokens=span,
                           edges=[e]))
예제 #15
0
    def parse_amr(self, tokens, amr_string):
        """Decode one PENMAN string into an AMR plus label lookup tables.

        The graph is decoded with ``penman.decode(..., model=TreePenmanModel())``
        and its triples are walked in order. Each node/attribute is assigned
        names in three parallel schemes:

        * letter labels -- the original variable; while the candidate name is
          already a key of the table it is replaced by a fresh ``x<N>``;
        * JAMR labels   -- dotted paths rooted at ``'0'``;
        * ISI labels    -- dotted paths rooted at ``'1'``. ISI edge labels are
          the source's ISI label plus the edge index plus ``'.r'``.

        ``self.style`` selects which scheme becomes the node ids of the
        returned AMR: ``'isi'``, ``'jamr'``, or letter labels otherwise.

        Args:
            tokens: sentence tokens, passed to ``AMR(tokens=...)``.
            amr_string: the AMR graph in PENMAN notation.

        Returns:
            ``(amr, (letter_labels, jamr_labels, isi_labels, isi_edge_labels,
            aligns))``.
            * Each ``*_labels`` dict maps a scheme-specific label to the
              chosen node id, or, for ``isi_edge_labels``, to the AMR edge
              tuple.
            * ``'ignore'`` marks skipped duplicate attribute triples.
            * ``aligns`` holds ``AMR_Alignment(type='isi')`` objects built
              from alignment objects found in the graph's epidata.
        """
        amr = AMR(tokens=tokens)
        g = penman.decode(amr_string, model=TreePenmanModel())
        # triples is a method or an attribute depending on the penman version
        triples = g.triples() if callable(g.triples) else g.triples

        # Label tables filled while walking the triples. Keys are variables
        # (nodes) or whole triples (attributes / edges).
        letter_labels = {}
        isi_labels = {g.top: '1'}
        isi_edge_labels = {}
        jamr_labels = {g.top: '0'}

        # counter for fresh 'x<N>' names
        new_idx = 0

        # next child index per source node, per numbering scheme
        isi_edge_idx = {g.top: 1}
        jamr_edge_idx = {g.top: 0}

        nodes = []
        attributes = []
        edges = []
        reentrancies = []

        for i, tr in enumerate(triples):
            s, r, t = tr
            # an amr node
            if r == ':instance':
                # NOTE(review): when the most recently recorded edge is a
                # reentrancy, its target is relabelled here as a child of that
                # edge's source -- intent not evident from this block.
                if reentrancies and edges[-1] == reentrancies[-1]:
                    s2, r2, t2 = edges[-1]
                    jamr_labels[t2] = jamr_labels[s2] + '.' + str(
                        jamr_edge_idx[s2])
                    isi_labels[t2] = isi_labels[s2] + '.' + str(
                        isi_edge_idx[s2])
                new_s = s
                while new_s in letter_labels:
                    new_idx += 1
                    new_s = f'x{new_idx}'
                letter_labels[s] = new_s
                nodes.append(tr)
            # an amr edge
            elif t not in letter_labels:
                # Heuristic: a target longer than 5 chars or not starting with
                # a letter is treated as a constant (attribute); otherwise it
                # is treated as the variable of a new child node.
                if len(t) > 5 or not t[0].isalpha():
                    # exact repeat of an attribute triple: consume its index
                    # and record its labels under the 'ignore' key
                    if tr in letter_labels:
                        isi_labels['ignore'] = isi_labels[s] + '.' + str(
                            isi_edge_idx[s])
                        isi_edge_labels['ignore'] = isi_labels[s] + '.' + str(
                            isi_edge_idx[s]) + '.r'
                        isi_edge_idx[s] += 1
                        jamr_edge_idx[s] += 1
                        continue
                    # attribute
                    new_s = s
                    while new_s in letter_labels:
                        new_idx += 1
                        new_s = f'x{new_idx}'
                    letter_labels[tr] = new_s
                    jamr_labels[tr] = jamr_labels[s] + '.' + str(
                        jamr_edge_idx[s])
                    isi_labels[tr] = isi_labels[s] + '.' + str(isi_edge_idx[s])
                    isi_edge_labels[tr] = isi_labels[s] + '.' + str(
                        isi_edge_idx[s]) + '.r'
                    isi_edge_idx[s] += 1
                    jamr_edge_idx[s] += 1
                    attributes.append(tr)
                else:
                    # edge
                    jamr_edge_idx[t] = 0
                    isi_edge_idx[t] = 1
                    jamr_labels[t] = jamr_labels[s] + '.' + str(
                        jamr_edge_idx[s])
                    # the JAMR index of s advances only when the next triple
                    # is an ':instance' triple
                    if i + 1 < len(triples) and triples[i +
                                                        1][1] == ':instance':
                        jamr_edge_idx[s] += 1
                    isi_labels[t] = isi_labels[s] + '.' + str(isi_edge_idx[s])
                    isi_edge_labels[tr] = isi_labels[s] + '.' + str(
                        isi_edge_idx[s]) + '.r'
                    isi_edge_idx[s] += 1
                    edges.append(tr)
            else:
                # reentrancy: target already labelled, only the edge is new
                isi_edge_labels[tr] = isi_labels[s] + '.' + str(
                    isi_edge_idx[s]) + '.r'
                isi_edge_idx[s] += 1
                edges.append(tr)
                reentrancies.append(tr)

        # pick the naming scheme used for node ids in the returned AMR
        default_labels = letter_labels
        if self.style == 'isi':
            default_labels = isi_labels
        elif self.style == 'jamr':
            default_labels = jamr_labels

        amr.root = default_labels[g.top]
        # original triple -> edge tuple in the chosen naming scheme
        edge_map = {}
        for tr in nodes:
            s, r, t = tr
            amr.nodes[default_labels[s]] = t
        for tr in attributes:
            s, r, t = tr
            if not r.startswith(':'): r = ':' + r
            amr.nodes[default_labels[tr]] = t
            amr.edges.append((default_labels[s], r, default_labels[tr]))
            edge_map[tr] = (default_labels[s], r, default_labels[tr])
        for tr in edges:
            s, r, t = tr
            if not r.startswith(':'): r = ':' + r
            amr.edges.append((default_labels[s], r, default_labels[t]))
            edge_map[tr] = (default_labels[s], r, default_labels[t])

        # Convert alignment objects stored in the epidata into 'isi'
        # alignments: instance -> node, attribute-like target -> attribute
        # node, anything else -> edge.
        aligns = []
        for tr, epidata in g.epidata.items():
            for align in epidata:
                if 'Alignment' in type(align).__name__:
                    indices = align.indices
                    s, r, t = tr
                    if tr[1] == ':instance':
                        align = AMR_Alignment(type='isi',
                                              tokens=list(indices),
                                              nodes=[default_labels[s]])
                    elif len(t) > 5 or not t[0].isalpha():
                        align = AMR_Alignment(type='isi',
                                              tokens=list(indices),
                                              nodes=[default_labels[tr]])
                    else:
                        align = AMR_Alignment(type='isi',
                                              tokens=list(indices),
                                              edges=[edge_map[tr]])
                    aligns.append(align)

        # Re-key each table by its scheme-specific label, pointing at the
        # chosen node id / edge tuple ('ignore' entries stay 'ignore').
        letter_labels = {
            v: default_labels[k]
            for k, v in letter_labels.items()
        }
        jamr_labels = {v: default_labels[k] for k, v in jamr_labels.items()}
        isi_labels = {
            v: default_labels[k] if k != 'ignore' else k
            for k, v in isi_labels.items()
        }
        isi_edge_labels = {
            v: edge_map[k] if k in edge_map else k
            for k, v in isi_edge_labels.items()
        }

        return amr, (letter_labels, jamr_labels, isi_labels, isi_edge_labels,
                     aligns)
예제 #16
0
def main():
    """Build gold alignment JSON files from a hand-annotated TSV file.

    Usage: ``<script> AMR_FILE HAND_ALIGNMENTS_FILE``.

    ``row[0]`` of each TSV row selects the row kind:

    * ``amr``        -- starts a new AMR (``row[1]`` is its id);
    * ``node``       -- ``row[1]`` node label, ``row[3]`` comma-separated
      token ids (a leading ``*`` marks a duplicate subgraph);
    * ``edge``       -- ``row[1]`` edge label, ``row[3]`` token ids;
    * ``reentrancy`` -- ``row[1]`` edge label, ``row[3]`` token ids or ``_``
      (reuse the edge's relation alignment, tagged 'primary'), ``row[4]`` tag.

    Tokens not covered by an annotated span get one-token spans. Alignments
    are normalized with ``clean_alignments``. Three files
    ``<amr_file>.{subgraph,relation,reentrancy}_alignments.gold.json`` are
    written.

    Raises:
        Exception: on missing/malformed annotations, unknown node or edge
            labels, overlapping spans, or relation alignments inconsistent
            with the subgraph alignments.
    """
    amr_file = sys.argv[1]
    hand_alignments_file = sys.argv[2]

    reader = AMR_Reader()
    amrs = reader.load(amr_file, remove_wiki=True)
    amrs = {amr.id: amr for amr in amrs}

    subgraph_alignments = {}
    relation_alignments = {}
    reentrancy_alignments = {}
    all_spans = {amr_id: set() for amr_id in amrs}

    amr = None
    node_labels = {}
    # csv expects newline=''; read as UTF-8 rather than the platform default
    with open(hand_alignments_file, encoding='utf8', newline='') as f:
        hand_alignments = csv.reader(f, delimiter="\t")
        for row in hand_alignments:
            if row[0] == 'amr':
                amr_id = row[1]
                subgraph_alignments[amr_id] = []
                relation_alignments[amr_id] = []
                reentrancy_alignments[amr_id] = []
                amr = amrs[amr_id]
                taken = set()
                # invert so annotation labels map back to node ids / edges
                node_labels = get_node_labels(amr)
                node_labels = {v: k for k, v in node_labels.items()}
                edge_labels = get_edge_labels(amr)
                edge_labels = {v: k for k, v in edge_labels.items()}
            elif row[0] == 'node':
                align_type = 'subgraph'
                if row[3].startswith('*'):
                    align_type = 'dupl-subgraph'
                    row[3] = row[3].replace('*', '')
                if not row[3]:
                    raise Exception('Missing Annotation:', amr_id)
                node_id = row[1]
                if node_id not in node_labels:
                    raise Exception('Failed to parse node labels:', amr.id,
                                    node_id)
                n = node_labels[node_id]
                token_ids = [int(t) for t in row[3].split(',')]
                if any(t >= len(amr.tokens) for t in token_ids):
                    raise Exception('Bad Annotation:', amr_id)
                if tuple(token_ids) not in all_spans[amr_id] and any(
                        t in taken for t in token_ids):
                    raise Exception('Bad Span Annotation', amr_id)
                all_spans[amr_id].add(tuple(token_ids))
                taken.update(token_ids)
                align = amr.get_alignment(subgraph_alignments,
                                          token_id=token_ids[0])
                if align and align.type == align_type:
                    align.nodes.append(n)
                else:
                    new_align = AMR_Alignment(type=align_type,
                                              tokens=token_ids,
                                              nodes=[n],
                                              amr=amr)
                    subgraph_alignments[amr.id].append(new_align)
            elif row[0] == 'edge':
                align_type = 'relation'
                if row[3].startswith('*'):
                    row[3] = row[3].replace('*', '')
                if not row[3]:
                    raise Exception('Missing Annotation:', amr_id)
                edge_id = row[1]
                if edge_id not in edge_labels:
                    # report the offending edge id (not a stale node id)
                    raise Exception('Failed to parse edge labels:', amr.id,
                                    edge_id)
                e = edge_labels[edge_id]
                token_ids = [int(t) for t in row[3].split(',')]
                if any(t >= len(amr.tokens) for t in token_ids):
                    raise Exception('Bad Annotation:', amr_id)
                if tuple(token_ids) not in all_spans[amr_id] and any(
                        t in taken for t in token_ids):
                    raise Exception('Bad Span Annotation', amr_id, token_ids)
                all_spans[amr_id].add(tuple(token_ids))
                taken.update(token_ids)
                align = amr.get_alignment(relation_alignments,
                                          token_id=token_ids[0])
                if align and align.type == align_type:
                    align.edges.append(e)
                else:
                    new_align = AMR_Alignment(type=align_type,
                                              tokens=token_ids,
                                              edges=[e],
                                              amr=amr)
                    relation_alignments[amr.id].append(new_align)
            elif row[0] == 'reentrancy':
                if not row[3]:
                    raise Exception('Missing Annotation:', amr_id)
                edge_id = row[1]
                if edge_id not in edge_labels:
                    raise Exception('Failed to parse edge labels:', amr.id,
                                    edge_id)
                e = edge_labels[edge_id]
                if row[3].startswith('*'):
                    row[3] = row[3].replace('*', '')
                if row[3] == '_':
                    # '_' reuses the tokens of the edge's relation alignment
                    token_ids = amr.get_alignment(relation_alignments,
                                                  edge=e).tokens
                else:
                    token_ids = [int(t) for t in row[3].split(',')]
                tag = row[4]
                if row[3] == '_':
                    tag = 'primary'
                if not tag:
                    raise Exception('Missing reentrancy tag:', amr.id)
                align_type = f'reentrancy:{tag}'
                if any(t >= len(amr.tokens) for t in token_ids):
                    raise Exception('Bad Annotation:', amr_id)
                if tuple(token_ids) not in all_spans[amr_id] and any(
                        t in taken for t in token_ids):
                    raise Exception('Bad Span Annotation', amr_id, token_ids)
                all_spans[amr_id].add(tuple(token_ids))
                taken.update(token_ids)
                new_align = AMR_Alignment(type=align_type,
                                          tokens=token_ids,
                                          edges=[e],
                                          amr=amr)
                reentrancy_alignments[amr.id].append(new_align)
    for amr_id in subgraph_alignments:
        amr = amrs[amr_id]
        # every token not covered by an annotated span gets its own span
        for t in range(len(amr.tokens)):
            if not any(t in span for span in all_spans[amr_id]):
                all_spans[amr_id].add((t, ))
        spans = [
            list(span)
            for span in sorted(all_spans[amr_id], key=lambda x: x[0])
        ]

        for align in subgraph_alignments[amr_id]:
            if align.nodes and not is_subgraph(amr, align.nodes):
                print('Possible Bad align:',
                      amr.id,
                      align.tokens,
                      ' '.join(amr.tokens[t] for t in align.tokens),
                      file=sys.stderr)
        # a relation aligned to a span that also carries a subgraph must
        # touch that subgraph (':manner' on 'without' is tolerated)
        for align in relation_alignments[amr_id]:
            subgraph_aligns = [
                a for a in subgraph_alignments[amr.id]
                if a.tokens == align.tokens
            ]
            for s, r, t in align.edges:
                if subgraph_aligns and not any(
                        s in a.nodes or t in a.nodes or not a.nodes
                        for a in subgraph_aligns):
                    if r == ':manner' and amr.tokens[
                            align.tokens[0]] == 'without':
                        continue
                    raise Exception('Bad Relation align:', amr.id,
                                    align.tokens, s, r, t)
        dupl_sub_aligns = [
            align for align in subgraph_alignments[amr_id]
            if align.type.startswith('dupl')
        ]
        subgraph_alignments[amr_id] = [
            align for align in subgraph_alignments[amr_id]
            if not align.type.startswith('dupl')
        ]
        clean_alignments(amr, subgraph_alignments, dupl_sub_aligns, spans)
        clean_alignments(amr, relation_alignments, [], spans, mode='relations')
        # every token must belong to exactly one span
        for t, _ in enumerate(amr.tokens):
            count = [span for span in spans if t in span]
            if len(count) != 1:
                raise Exception('Bad Span:', amr.id, count)

    align_file = amr_file.replace('.txt',
                                  '') + '.subgraph_alignments.gold.json'
    print(f'Writing subgraph alignments to: {align_file}')
    reader.save_alignments_to_json(align_file, subgraph_alignments)

    align_file = amr_file.replace('.txt',
                                  '') + '.relation_alignments.gold.json'
    print(f'Writing relation alignments to: {align_file}')
    reader.save_alignments_to_json(align_file, relation_alignments)

    align_file = amr_file.replace('.txt',
                                  '') + '.reentrancy_alignments.gold.json'
    print(f'Writing reentrancy alignments to: {align_file}')
    reader.save_alignments_to_json(align_file, reentrancy_alignments)
예제 #17
0
    def align(self, amr, relation_alignments, e, unaligned=None, return_all=False):
        """Find the best token span to carry the relation edge ``e``.

        Candidates come from two sources:

        * free spans -- spans whose subgraph alignment has no nodes, that do
          not yet carry a relation alignment and that ``english_ignore_tokens``
          does not reject; choosing one creates a new relation alignment that
          holds only ``e``;
        * anchor spans -- the subgraph-aligned spans of the nodes returned by
          ``rule_based_anchor_relation(e)``; choosing one appends ``e`` to the
          relation alignment already sitting on that span.

        Free spans are further restricted to start strictly between the first
        tokens of the parent's and the child's subgraph alignments (or to have
        the lemma "ago"). They must also start outside the token range covered
        by the child and its *direct* children, unless the parent itself lies
        inside that range.

        Every candidate is scored as ``logp(with e) - logp(without e)``. When a
        free span and an anchor span coincide, the anchor candidate wins.

        Args:
            amr: AMR whose edge is being aligned.
            relation_alignments: current relation alignments, indexed by AMR id.
            e: ``(source, role, target)`` edge to align.
            unaligned: not used by this method.
            return_all: when true, return every scored candidate.

        Returns:
            ``(None, None)`` if no candidate exists; ``(alignments, scores)``
            dicts keyed by span tuple if ``return_all``; otherwise
            ``(best_alignment, best_score)``.

        Raises:
            Exception: if an anchor node's subgraph span is not in ``amr.spans``.
        """
        subgraphs = self.subgraph_alignments

        # Free spans: node-less subgraph spans that are still available.
        free_spans = [a.tokens for a in subgraphs[amr.id] if not a.nodes]
        free_spans = [
            tokens for tokens in free_spans
            if not amr.get_alignment(relation_alignments, token_id=tokens[0])
            and not english_ignore_tokens(amr, tokens)
        ]
        anchor_nodes = rule_based_anchor_relation(e)

        parent = amr.get_alignment(subgraphs, node_id=e[0])
        child = amr.get_alignment(subgraphs, node_id=e[2])

        # Free spans must sit between parent and child (e.g. prepositions),
        # with "ago" as the one exception.
        if free_spans:
            parent_first, child_first = parent.tokens[0], child.tokens[0]
            low = min(parent_first, child_first)
            high = max(parent_first, child_first)
            free_spans = [
                tokens for tokens in free_spans
                if low < tokens[0] < high
                or ' '.join(amr.lemmas[t] for t in tokens) == 'ago'
            ]

        # Keep the relation out of the token range covered by the child and
        # its direct children (only one level of edges is followed here).
        covered = set(child.tokens)
        for src, _, tgt in amr.edges:
            if src == e[2]:
                covered.update(amr.get_alignment(subgraphs, node_id=tgt).tokens)
        if covered:
            first, last = min(covered), max(covered)
            if not first <= parent.tokens[0] <= last:
                free_spans = [tokens for tokens in free_spans
                              if not first <= tokens[0] <= last]

        scores = {}
        candidates = {}

        for tokens in free_spans:
            proposed = AMR_Alignment(type='relation', tokens=tokens, edges=[e], amr=amr)
            baseline = AMR_Alignment(type='relation', tokens=tokens, edges=[], amr=amr)
            gain = (self.logp(amr, relation_alignments, proposed)
                    - self.logp(amr, relation_alignments, baseline))
            key = tuple(proposed.tokens)
            scores[key] = gain
            candidates[key] = proposed

        for node in anchor_nodes:
            tokens = amr.get_alignment(subgraphs, node_id=node).tokens
            if not tokens:
                continue
            if tokens not in amr.spans:
                raise Exception('Subgraph Alignment has Faulty Span:', tokens)
            current = amr.get_alignment(relation_alignments, token_id=tokens[0])
            extended = AMR_Alignment(type='relation', tokens=current.tokens,
                                     edges=current.edges + [e], amr=amr)
            gain = (self.logp(amr, relation_alignments, extended)
                    - self.logp(amr, relation_alignments, current))
            key = tuple(extended.tokens)
            scores[key] = gain
            candidates[key] = extended

        if not scores:
            return None, None
        if return_all:
            return candidates, scores

        # max() keeps the first key on ties, i.e. insertion order decides.
        best = max(scores, key=scores.get)
        return candidates[best], scores[best]
예제 #18
0
    def align(self, amr, alignments, n, unaligned=None, return_all=False):
        """Find the best subgraph alignment for node ``n``.

        Three kinds of candidates are scored:

        1. free spans -- spans in ``alignments[amr.id]`` that hold no nodes;
           ``n`` would get a new subgraph alignment on that span;
        2. neighbors -- aligned nodes adjacent to the subgraph built around
           ``n`` by ``postprocess_subgraph``, plus (when ``n`` has no outgoing
           edges) aligned single-node leaves that share a parent with ``n``;
           ``n`` would be added to the neighbor's alignment;
        3. duplicates (only if ``self.align_duplicates``) -- spans already
           aligned to another node with the same label; ``n`` would get a
           ``dupl-subgraph`` alignment there, penalised by
           ``log(DUPLICATE_RATE)``.

        Each candidate is scored as ``logp(new) - logp(replaced)``. Candidates
        are keyed by span; a later kind overwrites an earlier one on the same
        span.

        Args:
            amr: AMR containing ``n``.
            alignments: current subgraph alignments, indexed by AMR id.
            n: id of the node to align.
            unaligned: collection of unaligned node ids; computed with
                ``self.get_unaligned`` when omitted.
            return_all: when true, return every scored candidate.

        Returns:
            ``(None, None)`` if there is no candidate; ``(alignments, scores)``
            dicts keyed by span tuple if ``return_all``; otherwise
            ``(best_alignment, best_score)``.
        """
        if unaligned is None:
            unaligned = self.get_unaligned(amr, alignments)

        # --- free spans ---------------------------------------------------
        candidate_spans = [
            align.tokens for align in alignments[amr.id] if not align.nodes
        ]

        # --- neighbors ----------------------------------------------------
        # postprocess_subgraph may extend tmp_align.nodes beyond [n]; the
        # neighbors of that (possibly extended) subgraph are candidates.
        tmp_align = AMR_Alignment(type='subgraph', tokens=[0], nodes=[n])
        postprocess_subgraph(amr, alignments, tmp_align, english=ENGLISH)
        candidate_neighbors = [s for s, r, t in amr.edges if t in tmp_align.nodes and s not in unaligned] + \
                              [t for s, r, t in amr.edges if s in tmp_align.nodes and t not in unaligned]
        kept_neighbors = []
        for n2 in candidate_neighbors:
            nalign = amr.get_alignment(alignments, node_id=n2)
            if nalign and nalign.type != 'dupl-subgraph':
                kept_neighbors.append(n2)
        candidate_neighbors = kept_neighbors

        label = amr.nodes[n]
        coordination = ('multi-sentence', 'and', 'or')

        # Handle "never => ever, -" and similar cases: a node without outgoing
        # edges may join the alignment of an aligned single-node leaf (with a
        # different label) that shares a parent with it.
        edge_map = {node: [] for node in amr.nodes}
        for s, r, t in amr.edges:
            edge_map[s].append(t)
        if not edge_map[n]:
            parents_of_n = [p for p in amr.nodes if n in edge_map[p]]
            for n2 in amr.nodes:
                if edge_map[n2] or n2 in unaligned or amr.nodes[n2] == label:
                    continue
                nalign = amr.get_alignment(alignments, node_id=n2)
                if len(nalign.nodes) != 1:
                    continue
                if any(n2 in edge_map[p] for p in parents_of_n):
                    candidate_neighbors.append(n2)

        # --- special rules for multi-sentence, and, or ---------------------
        if ENGLISH:
            allowed_spans = [
                span for span in candidate_spans
                if not english_is_alignment_forbidden(amr, span, n)
            ]
            # Fallback: if the language rules forbid every span for a
            # multi-sentence/and node, keep the unfiltered spans. The emptiness
            # test must be on the *filtered* list; testing the unfiltered list
            # makes this fallback unreachable.
            if label in ('multi-sentence', 'and') and not allowed_spans:
                allowed_spans = candidate_spans
            candidate_spans = allowed_spans

        # While free spans exist, a coordination node takes no neighbor
        # candidates and coordination nodes are not offered as neighbors.
        if candidate_spans:
            if label in coordination:
                candidate_neighbors = []
            candidate_neighbors = [
                n2 for n2 in candidate_neighbors
                if amr.nodes[n2] not in coordination
            ]

        # If the label repeats, n hangs off a comparison/coordination
        # predicate, the sentence has fewer lemmas matching the label's stem
        # than there are nodes with that label, and a same-label node sits
        # across that predicate (via a role ending in '1' or '1-of'), then
        # no free span is offered to n.
        same_label_count = sum(1 for n2 in amr.nodes if amr.nodes[n2] == label)
        if same_label_count > 1:
            stem = label.split('-')[0]
            stem_count = sum(1 for lemma in amr.lemmas if lemma == stem)
            for s, r, t in amr.edges:
                if t != n or amr.nodes[s] not in (
                        'include-91', 'same-01', 'instead-of-91',
                        'resemble-01', 'differ-02', 'and', 'or'):
                    continue
                if stem_count >= same_label_count:
                    break
                for s2, r2, t2 in amr.edges:
                    if s2 == s and t2 != t and amr.nodes[t2] == label \
                            and r2.endswith('1'):
                        candidate_spans = []
                        break
                    elif t2 == s and amr.nodes[s2] == label \
                            and r2.endswith('1-of'):
                        candidate_spans = []
                        break

        # --- duplicates -----------------------------------------------------
        # Numeric labels and quoted constants never get duplicate candidates.
        candidate_duplicates = []
        if not (label.isdigit() or '"' in label):
            for n2 in amr.nodes:
                if n2 != n and amr.nodes[n2] == label:
                    dupl = amr.get_alignment(alignments, node_id=n2)
                    if dupl:
                        candidate_duplicates.append(dupl.tokens)

        # --- scoring ------------------------------------------------------
        all_scores = {}
        all_aligns = {}

        def record(candidate, score):
            span = tuple(candidate.tokens)
            all_scores[span] = score
            all_aligns[span] = candidate

        for span in candidate_spans:
            new_align = AMR_Alignment(type='subgraph',
                                      tokens=span,
                                      nodes=[n],
                                      amr=amr)
            replaced_align = AMR_Alignment(type='subgraph',
                                           tokens=span,
                                           nodes=[],
                                           amr=amr)
            record(new_align,
                   self.logp(amr, alignments, new_align)
                   - self.logp(amr, alignments, replaced_align))

        for neighbor in candidate_neighbors:
            replaced_align = amr.get_alignment(alignments, node_id=neighbor)
            if replaced_align.type.startswith('dupl'):
                continue
            new_align = AMR_Alignment(type=replaced_align.type,
                                      tokens=replaced_align.tokens,
                                      nodes=replaced_align.nodes + [n],
                                      amr=amr)
            record(new_align,
                   self.logp(amr, alignments, new_align)
                   - self.logp(amr, alignments, replaced_align,
                               postprocess=False))

        if self.align_duplicates:
            for span in candidate_duplicates:
                new_align = AMR_Alignment(type='dupl-subgraph',
                                          tokens=span,
                                          nodes=[n],
                                          amr=amr)
                replaced_align = amr.get_alignment(alignments,
                                                   token_id=span[0])
                record(new_align,
                       math.log(DUPLICATE_RATE)
                       + self.logp(amr, alignments, new_align)
                       - self.logp(amr, alignments, replaced_align,
                                   postprocess=False))

        if not all_scores:
            return None, None

        if return_all:
            return all_aligns, all_scores

        # max() keeps the first key on ties, i.e. insertion order decides.
        best_span = max(all_scores, key=all_scores.get)
        return all_aligns[best_span], all_scores[best_span]
예제 #19
0
def _read_tamr_alignments(directory):
    """Parse every ``*.tamr_alignment`` file in *directory*.

    Each file is read as a sequence of records: a non-blank line that is not
    an alignment line sets the current AMR id (resetting its list), and a
    ``# ::alignments`` line holds whitespace-separated
    ``start-end|node+node+...`` entries (entries without ``|`` are ignored).
    Token ranges are end-exclusive.

    Returns:
        dict mapping AMR id to a list of subgraph ``AMR_Alignment`` objects
        whose ``nodes`` are still TAMR node ids.

    Raises:
        ValueError: if an alignment line appears before any AMR id in a file.
    """
    alignments = {}
    for filename in os.listdir(directory):
        if not filename.endswith(".tamr_alignment"):
            continue
        path = os.path.join(directory, filename)
        amr_id = None
        with open(path, encoding='utf-8') as f:
            for line in f:
                if line.startswith('# ::alignments'):
                    if amr_id is None:
                        raise ValueError(
                            f'{path}: alignment line before any AMR id')
                    for entry in line[len('# ::alignments '):].split():
                        if '|' not in entry:
                            continue
                        parts = entry.split('|')
                        bounds = parts[0].split('-')
                        start, end = int(bounds[0]), int(bounds[1])
                        align = AMR_Alignment(type='subgraph',
                                              tokens=list(range(start, end)),
                                              nodes=parts[1].split('+'))
                        alignments[amr_id].append(align)
                elif line.strip():
                    amr_id = line.strip()
                    alignments[amr_id] = []
    return alignments


def _map_tamr_nodes(amr, aligns):
    """Map the TAMR node ids used in *aligns* onto node ids of *amr*.

    Pass 1 adds one to every dot-separated component of a TAMR id (TAMR ids
    look 0-based relative to *amr*'s ids -- inferred from this shift) and
    keeps the result if it is a node of *amr*. Pass 2 maps each remaining
    node to the first (string-sorted) not-yet-used child of its mapped
    parent.

    Returns:
        dict from TAMR id (and TAMR id prefixes) to *amr* node id, or
        ``None`` if some node cannot be mapped.
    """
    node_map = {}

    def shifted(parts):
        return '.'.join(str(int(i) + 1) for i in parts)

    def by_depth(node_id):
        return (len(node_id), node_id)

    # Pass 1: direct index shift.
    for n in sorted((n for a in aligns for n in a.nodes), key=by_depth):
        parts = n.split('.')
        prefix_parts = parts[:-1]
        prefix = '.'.join(prefix_parts)
        last = int(parts[-1])
        if prefix:
            if prefix not in node_map:
                new_prefix = shifted(prefix_parts)
                if new_prefix not in amr.nodes:
                    continue
                node_map[prefix] = new_prefix
            new_n = node_map[prefix] + '.' + str(last + 1)
        else:
            new_n = str(last + 1)
        if new_n in amr.nodes:
            node_map[n] = new_n

    # Pass 2: attach leftovers to an unused child of their mapped parent.
    remaining = sorted(
        (n for a in aligns for n in a.nodes if n not in node_map), key=by_depth)
    for n in remaining:
        # A node aligned more than once is mapped only on first sight;
        # otherwise it would be moved to a different child.
        if n in node_map:
            continue
        prefix_parts = n.split('.')[:-1]
        prefix = '.'.join(prefix_parts)
        if prefix not in node_map:
            new_prefix = shifted(prefix_parts)
            if new_prefix not in amr.nodes:
                return None
            node_map[prefix] = new_prefix
        used = set(node_map.values())
        children = sorted(t for s, r, t in amr.edges
                          if s == node_map[prefix] and t not in used)
        if not children:
            return None
        node_map[n] = children[0]
    return node_map


def main():
    """Convert TAMR subgraph alignments into this project's JSON format.

    Reads the TAMR alignment files in ``../data/tamr`` and keeps alignments
    for AMRs present in both the Szubert AMR file and the LDC 2017 training
    file. It maps TAMR node ids onto the loaded AMR's node ids, dropping
    AMRs whose nodes cannot all be mapped. Every token not covered by an
    alignment gets an empty subgraph alignment. The result is written to
    ``../data/szubert/szubert_amrs.tamr.subgraph_alignments.json``.

    Raises:
        ValueError: if a TAMR file has an alignment line before any AMR id.
    """
    tamr_dir = '../data/tamr'
    szubert_amrs = '../data/szubert/szubert_amrs.txt'
    output = '../data/szubert/szubert_amrs.tamr.subgraph_alignments.json'

    ldc_file = '../data/tamr/ldc_train_2017.txt'

    reader = AMR_Reader()
    amrs = reader.load(szubert_amrs, remove_wiki=True)
    ldc_amrs = reader.load(ldc_file, remove_wiki=True)

    alignments = _read_tamr_alignments(tamr_dir)

    ldc_ids = {amr.id for amr in ldc_amrs}
    amrs = [amr for amr in amrs if amr.id in alignments and amr.id in ldc_ids]
    # NOTE(review): AMRs are not checked for structural agreement with their
    # LDC counterpart, so the node mapping below may be unreliable for AMRs
    # whose graphs differ between the two files.

    kept_ids = {amr.id for amr in amrs}
    alignments = {
        amr_id: aligns
        for amr_id, aligns in alignments.items() if amr_id in kept_ids
    }

    for amr in amrs:
        node_map = _map_tamr_nodes(amr, alignments[amr.id])
        if node_map is None:
            del alignments[amr.id]
            continue
        aligns = alignments[amr.id]
        for align in aligns:
            align.nodes = [node_map[n] for n in align.nodes]
            align.amr = amr
        # Give every uncovered token its own empty subgraph alignment.
        for t, _ in enumerate(amr.tokens):
            if not amr.get_alignment(alignments, token_id=t):
                aligns.append(AMR_Alignment(type='subgraph',
                                            tokens=[t],
                                            nodes=[],
                                            amr=amr))
        alignments[amr.id] = sorted(aligns, key=lambda a: a.tokens[0])

    reader.save_alignments_to_json(output, alignments)
예제 #20
0
    def align(self,
              amr,
              reentrancy_alignments,
              e,
              unaligned=None,
              return_all=False):
        """Find the best token span to explain the reentrant edge ``e``.

        Candidate spans are the spans of ``amr.spans`` for which
        ``self.get_allowed_types(amr)[e][span]`` is non-empty. The first
        allowed type names the alignment (``reentrancy:<type>``). Each
        candidate is scored with ``self.logp``, and ``pragmatic`` candidates
        get an extra ``log(PRAGMATIC_RATE)``.

        A second candidate source reuses the relation-aligned span of a
        neighbouring edge as a ``reentrancy:primary`` alignment. It is wired up
        but currently disabled: ``candidate_neighbors`` is always empty.

        Args:
            amr: AMR whose reentrant edge is being aligned.
            reentrancy_alignments: current reentrancy alignments, indexed by
                AMR id.
            e: ``(source, role, target)`` reentrant edge.
            unaligned: not used by this method.
            return_all: when true, return every scored candidate.

        Returns:
            ``(None, None)`` if there is no candidate; ``(alignments, scores)``
            dicts keyed by span tuple if ``return_all``; otherwise
            ``(best_alignment, best_score)``.

        Raises:
            Exception: from the (currently disabled) neighbor path, if a
                relation alignment's span is not in ``amr.spans``.
        """
        allowed_types = self.get_allowed_types(amr)
        candidate_spans = [
            span for span in amr.spans if allowed_types[e][tuple(span)]
        ]
        # Disabled candidate source (was [e]); kept so it can be re-enabled.
        candidate_neighbors = []

        all_scores = {}
        all_aligns = {}

        for span in candidate_spans:
            align_type = allowed_types[e][tuple(span)][0]
            new_align = AMR_Alignment(type=f'reentrancy:{align_type}',
                                      tokens=span,
                                      edges=[e],
                                      amr=amr)
            score = self.logp(amr, reentrancy_alignments, new_align)
            if align_type == 'pragmatic':
                score += math.log(PRAGMATIC_RATE)
            key = tuple(new_align.tokens)
            all_scores[key] = score
            all_aligns[key] = new_align

        for neighbor in candidate_neighbors:
            rel_align = amr.get_alignment(self.relation_alignments,
                                          edge=neighbor)
            span = rel_align.tokens
            if not span:
                continue
            if span not in amr.spans:
                raise Exception('Relation Alignment has Faulty Span:', span)
            new_align = AMR_Alignment(type='reentrancy:primary',
                                      tokens=rel_align.tokens,
                                      edges=[e],
                                      amr=amr)
            key = tuple(new_align.tokens)
            all_scores[key] = self.logp(amr, reentrancy_alignments, new_align)
            all_aligns[key] = new_align

        if not all_scores:
            return None, None

        if return_all:
            return all_aligns, all_scores

        # max() keeps the first key on ties, i.e. insertion order decides.
        best_span = max(all_scores, key=all_scores.get)
        return all_aligns[best_span], all_scores[best_span]