Example #1
def connect_graph_if_not_connected(graph):
    try:
        encoded = pm_encode(graph)
        return graph, ParsedStatus.OK
    except:
        pass

    nxgraph = nx.MultiGraph()
    variables = graph.variables()
    for v1, _, v2 in graph.triples:
        if v1 in variables and v2 in variables:
            nxgraph.add_edge(v1, v2)
        elif v1 in variables:
            nxgraph.add_edge(v1, v1)

    triples = graph.triples.copy()
    new_triples = []
    addition = f'a{len(variables) + 1}'
    triples.append(penman.Triple(addition, ':instance', 'and'))
    for i, conn_set in enumerate(nx.connected_components(nxgraph), start=1):
        edge = f':op{i}'
        conn_set = sorted(conn_set, key=lambda x: int(x[1:]))
        conn_set = [c for c in conn_set if c in variables]
        node = conn_set[0]
        new_triples.append(penman.Triple(addition, edge, node))
    triples = new_triples + triples
    metadata = graph.metadata
    graph = penman.Graph(triples)
    graph.metadata.update(metadata)
    pm_encode(graph)

    return graph, ParsedStatus.FIXED
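
A minimal usage sketch (the input graph is hypothetical; pm_encode and ParsedStatus are assumed to come from the surrounding module, as in the imports shown in Example #15):

# Sketch only: two disconnected components get joined under an 'and' node.
# Assumes pm_encode rejects a disconnected graph, which is what the
# try/except above relies on.
import penman

disconnected = penman.Graph([
    ('a1', ':instance', 'dog'),
    ('a2', ':instance', 'cat'),
])
fixed, status = connect_graph_if_not_connected(disconnected)
# expected: status is ParsedStatus.FIXED and pm_encode(fixed) gives, roughly,
# (a3 / and :op1 (a1 / dog) :op2 (a2 / cat))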
Example #2
def _split_name_ops(graph):
    # identify name triples
    name_vars = {}
    for i, (v1, rel, v2) in enumerate(graph.triples):
        if rel == ':instance' and v2 == 'name':
            name_vars[v1] = 1

    # check if they have ops
    name_vars_to_ops = defaultdict(list)
    for i, (v1, rel, v2) in enumerate(graph.triples):
        if v1 in name_vars and rel.startswith(':op'):
            name_vars_to_ops[v1].append((i, rel, v2.strip('"')))

    triples = graph.triples.copy()
    for nv, ops in name_vars_to_ops.items():
        ops = sorted(ops, key=lambda x: int(x[1][3:]))
        idx, _, lits = zip(*ops)
        for i in idx:
            triples[i] = None

        lits = ['"' + l + '"' for lit in lits for l in lit.split('_')]

        tt = []
        for i, l in enumerate(lits, start=1):
            rel = ':op' + str(i)
            tt.append(penman.Triple(nv, rel, l))

        triples[min(idx)] = tt

    triples = [t if isinstance(t, list) else [t] for t in triples if t is not None]
    triples = [t for tt in triples for t in tt]

    graph_ = penman.Graph(triples)
    graph_.metadata = graph.metadata
    return graph_
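
For illustration, a toy call (output layout shown approximately):

# Sketch: a single underscore-joined :op literal is split into numbered ops.
import penman

g = penman.decode('(c / city :name (n / name :op1 "New_York"))')
print(penman.encode(_split_name_ops(g)))
# (c / city
#    :name (n / name
#              :op1 "New"
#              :op2 "York"))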
Example #3
    def _collapse_name_ops(self, amr):
        # identify name triples
        name_vars = {}
        for i, (v1, rel, v2) in enumerate(amr.triples):
            if rel == ':instance' and v2 == 'name':
                name_vars[v1] = 1

        # check if they have ops
        name_vars_to_ops = defaultdict(list)
        for i, (v1, rel, v2) in enumerate(amr.triples):
            if v1 in name_vars and rel.startswith(':op'):
                name_vars_to_ops[v1].append((i, rel, v2.strip('"')))

        triples = amr.triples.copy()
        for nv, ops in name_vars_to_ops.items():
            ops = sorted(ops, key=lambda x: int(x[1][3:]))
            idx, _, lits = zip(*ops)
            for i in idx:
                triples[i] = None
            lit = '"' + '_'.join(lits) + '"'
            triples[min(idx)] = penman.Triple(nv, ':op1', lit)

        triples = [t for t in triples if t is not None]
        amr_ = penman.Graph(triples)
        amr_.metadata = amr.metadata
        return amr_
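
This method is the inverse of Example #2; a short sketch, with a hypothetical enclosing object:

# Sketch: numbered :op literals on a name node are collapsed into one string.
import penman

g = penman.decode('(c / city :name (n / name :op1 "New" :op2 "York"))')
collapsed = postprocessor._collapse_name_ops(g)  # postprocessor is hypothetical
# the name node now carries the single triple (n, :op1, "New_York")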
Example #4
    def add_edge(self, source, target, label):
        self._G.add_edge(source, target, label=label)
        t = penman.Triple(source=source.identifier,
                          relation=label,
                          target=target.identifier)
        triples = self._triples + [t]
        triples = penman.alphanum_order(triples)
        self._update_penman_graph(triples)
Example #5
File: amr_rw.py Project: bjascob/amrlib
def _replace_wiki(graph):
    metadata = graph.metadata
    triples = []
    for t in graph.triples:
        v1, rel, v2 = t
        if rel == ':wiki':
            t = penman.Triple(v1, rel, '+')
        triples.append(t)
    graph = penman.Graph(triples)
    graph.metadata = metadata
    return graph
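
A quick sketch of the effect (encoded output shown approximately):

# Sketch: concrete :wiki targets are replaced by the '+' placeholder.
import penman

g = penman.decode('(p / person :wiki "Barack_Obama" :name (n / name :op1 "Obama"))')
print(penman.encode(_replace_wiki(g)))
# (p / person :wiki + :name (n / name :op1 "Obama"))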
Example #6
File: amr.py Project: wjiawei97/stog
    def replace_node_attribute(self, node, attr, old, new):
        node.replace_attribute(attr, old, new)
        triples = []
        found = False
        for t in self._triples:
            if t.source == node.identifier and t.relation == attr and t.target == old:
                found = True
                t = penman.Triple(source=node.identifier, relation=attr, target=new)
            triples.append(t)
        if not found:
            raise KeyError
        self._triples = penman.alphanum_order(triples)
Example #7
File: norman.py Project: goodmami/norman
def reify_attributes(g):
    variables = g.variables()
    # filter out triples with empty instances
    triples = [t for t in g.triples()
               if t.relation != 'instance' or t.target]
    # ensure every node has a type
    # types = {t.source: t for t in triples if t.relation == 'instance'}
    # for src in variables.difference(types):
    #     triples.append(penman.Triple(src, 'instance', 'amr-missing'))
    # ensure constants are nodes
    new_triples = []
    for triple in triples:
        if triple.relation != 'instance' and triple.target not in variables:
            var = _unique_var('', variables, '_')
            new_triples.extend([
                penman.Triple(var, 'instance', triple.target),
                penman.Triple(triple.source, triple.relation, var)
            ])
            variables.add(var)
        else:
            new_triples.append(triple)

    return penman.Graph(new_triples, g.top)
Example #8
File: norman.py Project: goodmami/norman
def reify(g, re_map, prefix=None):
    variables = g.variables()
    counts = defaultdict(int)
    triples = []
    for triple in g.triples():
        if triple.relation in re_map:
            concept, srcrole, tgtrole = re_map[triple.relation]
            var = _unique_var(concept, variables, prefix)
            variables.add(var)
            triples.extend([
                penman.Triple(var, 'instance', concept),
                # source triple is inverse direction of original relation
                penman.Triple(var, srcrole, triple.source,
                              inverted=not triple.inverted),
                # target triple is same direction as original relation
                penman.Triple(var, tgtrole, triple.target,
                              inverted=triple.inverted)
            ])
            counts[triple.relation] += 1
        else:
            triples.append(triple)
    g = penman.Graph(triples, top=g.top)
    g.reified_counts = counts
    return g
Example #9
def create_amr_graph_from_prediction(prediction):
    nodes = prediction['nodes']
    heads = prediction['heads']
    head_labels = prediction['head_labels']
    sentence_id = prediction['sentence_id']
    root = prediction['root']

    variable_map = dict()
    triples = []
    for idx, node in enumerate(nodes):
        variable_map['vv' + str(idx + 1)] = node

    # find top
    top = 'vv1'
    for var, value in variable_map.items():
        if value == root:
            top = var

    # rename nodes
    for key in variable_map:
        variable_map[key] = variable_map[key].split('_')[0]

    # create instances
    for variable in variable_map:
        triples.append((variable, 'instance', variable_map[variable]))

    # create connections
    for idx, head in enumerate(heads):
        if (head != 0):
            head_var = 'vv{}'.format(head)
            dep_var = 'vv{}'.format(idx + 1)
            label = head_labels[idx]
            triple = (head_var, label, dep_var)
            triples.append(triple)

    graph = penman.Graph()

    graph.heads = heads
    graph.nodes = nodes
    graph.head_labels = head_labels
    graph._top = top
    graph._triples = [penman.Triple(*t) for t in triples]
    graph.id = sentence_id

    return graph
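
A hypothetical prediction for a two-node graph, assuming the older penman 0.x API this snippet relies on (a no-argument Graph() plus the internal _top/_triples fields):

# Sketch only; all field values are illustrative.
prediction = {
    'nodes': ['bark-01', 'dog'],
    'heads': [0, 1],                 # node 1 is the root; node 2 attaches to node 1
    'head_labels': ['root', 'ARG0'],
    'root': 'bark-01',
    'sentence_id': 'snt-1',
}
graph = create_amr_graph_from_prediction(prediction)
# expected triples: (vv1, instance, bark-01), (vv2, instance, dog), (vv1, ARG0, vv2); top is vv1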
Example #10
File: norman.py Project: goodmami/norman
def _dereification_agenda(g, co_map):
    """
    Find eligible dereifications and return the replacements.
    """
    agenda = {}
    variables = g.variables()
    fixed = {tgt for _, _, tgt in g.edges()}.union([g.top])
    for triple in g.triples(relation='instance'):
        if triple.source not in fixed and triple.target in co_map:
            rels = {t.relation: t
                    for t in g.triples(source=triple.source)
                    if t.relation != 'instance'}
            used = set()
            agendum = []
            incoming_triple = None
            for role, src_role, tgt_role in co_map[triple.target]:
                if not (src_role in rels and tgt_role in rels):
                    continue  # source and target must exist
                src = rels[src_role]
                tgt = rels[tgt_role]
                if (src_role in used and tgt_role in used):
                    continue  # don't duplicate info
                elif src.target not in variables:
                    continue  # don't create new nodes from attributes
                agendum.append(penman.Triple(src.target, role, tgt.target,
                                             inverted=tgt.inverted))
                used.add(src_role)
                used.add(tgt_role)
                if src.inverted:
                    incoming_triple = src
                elif tgt.inverted:
                    incoming_triple = tgt
            # only include for a full mapping
            if used == set(rels):
                assert incoming_triple is not None
                agenda[triple.source] = (incoming_triple, agendum)
    return agenda
Example #11
    def from_prediction(cls, prediction):
        def is_attribute_value(value):
            return re.search(r'(^".*"$|^[^a-zA-Z]+$)', value) is not None

        def is_attribute_edge(label):
            return label in ('instance', 'mode', 'li', 'value', 'month',
                             'year', 'day', 'decade', 'ARG6')

        def normalize_number(text):
            if re.search(r'^\d+,\d+$', text):
                text = text.replace(',', '')
            return text

        def abstract_node(value):
            return re.search(
                r'^([A-Z]+|DATE_ATTRS|SCORE_ENTITY|ORDINAL_ENTITY)_\d+$',
                value)

        def abstract_attribute(value):
            return re.search(r'^_QUANTITY_\d+$', value)

        def correct_multiroot(heads):
            for i in range(1, len(heads)):
                if heads[i] == 0:
                    heads[i] = 1
            return heads

        nodes = [normalize_number(n) for n in prediction['nodes']]
        heads = correct_multiroot(prediction['heads'])
        corefs = [int(x) for x in prediction['corefs']]
        head_labels = prediction['head_labels']

        triples = []
        top = None
        # Build the variable map from variable to instance.
        variable_map = {}
        for coref_index in corefs:
            node = nodes[coref_index - 1]
            head_label = head_labels[coref_index - 1]
            if (re.search(r'[/:\\()]', node) or is_attribute_value(node)
                    or is_attribute_edge(head_label)
                    or abstract_attribute(node)):
                continue
            variable_map['vv{}'.format(coref_index)] = node
        for head_index in heads:
            if head_index == 0:
                continue
            node = nodes[head_index - 1]
            coref_index = corefs[head_index - 1]
            variable_map['vv{}'.format(coref_index)] = node
        # Build edge triples and other attribute triples.
        for i, head_index in enumerate(heads):
            if head_index == 0:
                top_variable = 'vv{}'.format(corefs[i])
                if top_variable not in variable_map:
                    variable_map[top_variable] = nodes[i]
                top = top_variable
                continue
            head_variable = 'vv{}'.format(corefs[head_index - 1])
            modifier = nodes[i]
            modifier_variable = 'vv{}'.format(corefs[i])
            label = head_labels[i]
            assert head_variable in variable_map
            if modifier_variable in variable_map:
                triples.append((head_variable, label, modifier_variable))
            else:
                # Add quotes if there's a backslash.
                if re.search(r'[/:\\()]',
                             modifier) and not re.search(r'^".*"$', modifier):
                    modifier = '"{}"'.format(modifier)
                triples.append((head_variable, label, modifier))

        for var, node in variable_map.items():
            if re.search(r'^".*"$', node):
                node = node[1:-1]
            if re.search(r'[/:\\()]', node):
                parts = re.split(r'[/:\\()]', node)
                for part in parts[::-1]:
                    if len(part):
                        node = part
                        break
                else:
                    node = re.sub(r'[/:\\()]', '_', node)
            triples.append((var, 'instance', node))

        if len(triples) == 0:
            triples.append(('vv1', 'instance', 'string-entity'))
            top = 'vv1'
        triples.sort(key=lambda x: int(x[0].replace('vv', '')))
        graph = penman.Graph()
        graph._top = top
        graph._triples = [penman.Triple(*t) for t in triples]
        graph = cls(graph)
        try:
            GraphRepair.do(graph, nodes)
            amr_codec.encode(graph)
        except Exception as e:
            graph._top = top
            graph._triples = [penman.Triple(*t) for t in triples]
            graph = cls(graph)
        return graph
Example #12
    def add_node_attribute(self, node, attr, value):
        node.add_attribute(attr, value)
        t = penman.Triple(source=node.identifier, relation=attr, target=value)
        self._triples = penman.alphanum_order(self._triples + [t])
Example #13
    def _fix_and_make_graph(self, nodes):

        nodes_ = []
        for n in nodes:
            if isinstance(n, str):
                if n.startswith('<') and n.endswith('>') and (not n.startswith('<pointer:')):
                    pass
                else:
                    nodes_.append(n)
            else:
                nodes_.append(n)
        nodes = nodes_
        if not nodes:
            return penman.Graph()

        if self.use_pointer_tokens:

            i = 0
            nodes_ = []
            while i < len(nodes):
                nxt = nodes[i]
                pst = None
                if isinstance(nxt, str) and nxt.startswith('<pointer:'):
                    e = nxt.find('>')
                    if e != len(nxt) - 1:
                        pst = nxt[e + 1:]
                        nxt = nxt[:e + 1]
                    nodes_.append(nxt)
                    if pst is not None:
                        nodes_.append(pst)
                else:
                    nodes_.append(nxt)
                i += 1
            nodes = nodes_

            i = 1
            nodes_ = [nodes[0]]
            while i < len(nodes):
                nxt = nodes[i]
                if isinstance(nxt, str) and nxt.startswith('<pointer:') and i + 1 < len(nodes):
                    nxt = 'z' + nxt[9:-1]
                    fol = nodes[i + 1]
                    # is not expansion
                    if isinstance(fol, str) and (fol.startswith(':') or (fol == ')')):
                        nodes_.append(nxt)
                    else:
                        if self.remove_pars:
                            nodes_.append('(')
                        else:
                            if nodes_[-1] != '(':
                                nodes_.append('(')
                                # pass
                        nodes_.append(nxt)
                        nodes_.append('/')
                else:
                    nodes_.append(nxt)
                i += 1
            nodes = nodes_

        i = 0
        nodes_ = []
        last = True  # default, in case the loop below never runs
        while i < (len(nodes) - 1):
            if nodes[i] == ':':
                nodes_.append(nodes[i] + nodes[i + 1])
                i += 2
                last = False
            else:
                nodes_.append(nodes[i])
                i += 1
                last = True
        if last:
            nodes_.append(nodes[-1])
        nodes = nodes_

        i = 0
        nodes_ = []
        while i < (len(nodes)):
            if i < 2:
                nodes_.append(nodes[i])
                i += 1
            elif nodes_[-2] == '/' and nodes[i] == '/':
                i += 2
            else:
                nodes_.append(nodes[i])
                i += 1
        nodes = nodes_

        i = 0
        newvars = 0
        variables = set()
        remap = {}
        nodes_ = []
        while i < (len(nodes)):

            next = nodes[i]

            if next == '/':
                last = nodes_[-1]
                if last in variables:
                    last_remap = f"z{newvars + 1000}"
                    newvars += 1
                    nodes_[-1] = last_remap
                    remap[last] = last_remap
                variables.add(last)
                nodes_.append(next)

            elif self._classify(next) == 'VAR' and next in remap and (i < len(nodes) - 1) and nodes[i + 1] != '/':
                next = remap[next]
                nodes_.append(next)

            else:
                nodes_.append(next)

            i += 1

        nodes = nodes_
        pieces_ = []
        open_cnt = 0
        closed_cnt = 0
        if nodes[0] != '(':
            pieces_.append('(')
            open_cnt += 1
        for p in nodes:
            if p == '(':
                open_cnt += 1
            elif p == ')':
                closed_cnt += 1
            pieces_.append(p)
            if open_cnt == closed_cnt:
                break
        nodes = pieces_ + [')'] * (open_cnt - closed_cnt)

        pieces = []
        for piece in nodes:
            if not pieces:
                pieces.append('(')
            else:
                piece = str(piece)
                if piece.startswith('"') or piece.startswith('"') or '"' in piece.strip('"'):
                    piece = '"' + piece.replace('"', '') + '"'

                prev = self._classify(pieces[-1])
                next = self._classify(piece)

                if next == 'CONST':
                    quote = False
                    for char in (',', ':', '/', '(', ')', '.', '!', '?', '\\', '_', '='):
                        if char in piece:
                            quote = True
                            break
                    if quote:
                        piece = '"' + piece.strip('"') + '"'

                if prev == '(':
                    if next in ('VAR', 'I'):
                        pieces.append(piece)
                elif prev == ')':
                    if next in (')', 'EDGE', 'MODE'):
                        pieces.append(piece)
                elif prev == 'VAR':
                    if next in ('/', 'EDGE', 'MODE', ')'):
                        pieces.append(piece)
                elif prev == '/':
                    if next in ('INST', 'I'):
                        pieces.append(piece)
                elif prev == 'INST':
                    if next in (')', 'EDGE', 'MODE'):
                        pieces.append(piece)
                elif prev == 'I':
                    if next in ('/', ')', 'EDGE', 'MODE'):
                        pieces.append(piece)
                elif prev == 'EDGE':
                    if next in ('(', 'VAR', 'CONST', 'I'):
                        pieces.append(piece)
                    elif next == ')':
                        pieces[-1] = piece
                    elif next in ('EDGE', 'MODE'):
                        pieces[-1] = piece
                elif prev == 'MODE':
                    if next == 'INST':
                        pieces.append(piece)
                elif prev == 'CONST':
                    if next in (')', 'EDGE', 'MODE'):
                        pieces.append(piece)

        pieces_ = []
        open_cnt = 0
        closed_cnt = 0
        if pieces[0] != '(':
            pieces_.append('(')
            open_cnt += 1
        for p in pieces:
            if p == '(':
                open_cnt += 1
            elif p == ')':
                closed_cnt += 1
            pieces_.append(p)
            if open_cnt == closed_cnt:
                break
        pieces = pieces_ + [')'] * (open_cnt - closed_cnt)

        linearized = re.sub(r'\s+', ' ', ' '.join(pieces)).strip()

        """
        line = linearized
        # make sure parentheses match
        # copied from https://github.com/RikVN/AMR/blob/master/restoreAMR/restore_amr.py
        open_count = 0
        close_count = 0
        for i, c in enumerate(line):
            if c == '(':
                open_count += 1
            elif c == ')':
                close_count += 1
            if open_count == close_count and open_count > 0:
                line = line[:i].strip()
                break
        old_line = line
        while True:
            open_count = len(re.findall(r'\(', line))
            close_count = len(re.findall(r'\)', line))
            if open_count > close_count:
                line += ')' * (open_count - close_count)
            elif close_count > open_count:
                for i in range(close_count - open_count):
                    line = line.rstrip(')')
                    line = line.rstrip(' ')
            if old_line == line:
                break
            old_line = line
        """

        graph = penman.decode(linearized + ' ')
        triples = []
        newvars = 2000
        for triple in graph.triples:
            x, rel, y = triple
            if x is None:
                pass
            elif rel == ':instance' and y is None:
                triples.append(penman.Triple(x, rel, 'thing'))
            elif y is None:
                var = f'z{newvars}'
                newvars += 1
                triples.append(penman.Triple(x, rel, var))
                triples.append(penman.Triple(var, ':instance', 'thing'))
            else:
                triples.append(triple)
        graph = penman.Graph(triples)
        linearized = encode(graph)

        def fix_text(linearized=linearized):
            n = 0

            def _repl1(match):
                nonlocal n
                out = match.group(1) + match.group(2) + str(3000 + n) + ' / ' + match.group(2) + match.group(3)
                n += 1
                return out

            linearized = re.sub(r'(\(\s?)([a-z])([^\/:\)]+[:\)])', _repl1, linearized,
                                flags=re.IGNORECASE | re.MULTILINE)

            def _repl2(match):
                return match.group(1)

            linearized = re.sub(r'(\(\s*[a-z][\d+]\s*\/\s*[^\s\)\(:\/]+\s*)((?:/\s*[^\s\)\(:\/]+\s*)+)', _repl2,
                                linearized,
                                flags=re.IGNORECASE | re.MULTILINE)

            # adds a ':' to args w/o it
            linearized = re.sub(r'([^:])(ARG)', r'\1 :\2', linearized)

            # removes edges with no node
            # linearized = re.sub(r':[^\s\)\(:\/]+?\s*\)', ')', linearized, flags=re.MULTILINE)

            return linearized

        linearized = fix_text(linearized)

        g = penman.decode(linearized)
        return g
Example #14
def _reconstruct_graph_from_nodes(nodes, backreferences):
    triples = []
    triples_added = set()

    variable2index = {}
    index2variable = {}
    start_index = 0

    cnt = defaultdict(Counter)

    while start_index < len(nodes):
        stop_index = index_of('<stop>',
                              nodes,
                              default=len(nodes) + 1,
                              start=start_index)
        old_start_index = start_index
        start_index = stop_index + 1

        src_node, src_backr = nodes[old_start_index], backreferences[
            old_start_index]

        if src_node == '<stop>':
            continue

        trg_nodes_edges = nodes[old_start_index:stop_index]
        trg_nodes_edges_backr = backreferences[old_start_index:stop_index]
        trg_nodes_edges_indices = list(range(old_start_index, stop_index))

        if isinstance(src_node, str):
            if src_node in ('<s>', '</s>', '<stop>'):
                continue
            elif ('/' in src_node) or (':' in src_node) or (
                    '(' in src_node) or (')' in src_node):
                src_node = 'thing'

        if src_node is not None:
            src_node = str(src_node)
            src_var = src_node[0].lower()
            if src_var not in 'abcdefghijklmnopqrstuvwxyz':
                src_var = 'x'
            # src_var = f'{src_var}_{len(variable2index)}'
            src_var = f'{src_var}{len(variable2index)}'
            src_var_i = old_start_index
            variable2index[src_var] = src_var_i
            index2variable[src_var_i] = src_var
            triple = penman.Triple(src_var, ':instance', src_node)
            if triple not in triples_added:
                triples.append(triple)
                triples_added.add(triple)
        else:
            if src_backr in index2variable:
                src_var = index2variable[src_backr]
        # more resilient logic here
        (trg_edges, trg_nodes), (_, trg_nodes_backr), (_, trg_nodes_indices) = \
            separate_edges_nodes(
                trg_nodes_edges,
                trg_nodes_edges,
                trg_nodes_edges_backr,
                trg_nodes_edges_indices)

        for n, e, nb, ni in zip(trg_nodes, trg_edges, trg_nodes_backr,
                                trg_nodes_indices):

            if isinstance(n, str) and n.startswith(':'):
                continue
            if isinstance(n, str) and n.startswith('<') and n.endswith('>'):
                continue
            if e == ':li':
                pass
            elif len(e) < 4 or (not e.startswith(':')):
                continue

            # same edge more than once
            num = cnt[src_var][e]
            # num = 0
            if num:

                if e.startswith(':op') or e.startswith(':snt'):
                    continue
                # elif e.startswith(':ARG'):
                #    continue
                elif num > 3:
                    continue

            if n is None:
                if nb not in index2variable:
                    continue
                trg_var = index2variable[nb]
                trg = trg_var
            elif e == ':mode':
                trg = n
            elif (not isinstance(n, str)) or re.match(
                    r"^[+-]?\d+\.?\d*$", n) or (n == '-') or (n == '+'):
                trg = str(n)
            elif (n.startswith('"') and n.endswith('"') and len(n) > 2):
                trg = '"' + n.replace('"', '') + '"'
            elif ('/' in n) or (':' in n) or ('(' in n) or (')'
                                                            in n) or ('='
                                                                      in n):
                trg = f'"{n}"'
            elif n == '"':
                continue
            elif (n.startswith('"') and
                  (not n.endswith('"'))) or (not n.startswith('"') and
                                             (n.endswith('"'))) or ('"' in n):
                trg = '"' + n.replace('"', '') + '"'
            else:
                trg_var = n[0].lower()
                if trg_var not in 'abcdefghijklmnopqrstuvwxyz':
                    trg_var = 'x'
                # trg_var = f'{trg_var}_{len(variable2index)}'
                trg_var = f'{trg_var}{len(variable2index)}'
                trg_var_i = ni
                variable2index[trg_var] = trg_var_i
                index2variable[trg_var_i] = trg_var
                triple = penman.Triple(trg_var, ':instance', n)
                if triple not in triples_added:
                    triples.append(triple)
                    triples_added.add(triple)
                trg = trg_var

            triple = penman.Triple(src_var, e, trg)
            if triple not in triples_added:
                triples.append(triple)
                triples_added.add(triple)

            cnt[src_var][e] += 1

    return penman.Graph(triples)
Example #15
from collections import defaultdict, Counter
import enum
import re
import networkx as nx
import penman

from hanlp.components.amr.seq2seq.dataset.penman import pm_encode

BACKOFF = penman.Graph([
    penman.Triple('d2', ':instance', 'dog'),
    penman.Triple('b1', ':instance', 'bark-01'),
    penman.Triple('b1', ':ARG0', 'd2'),
])


def token_processing(tok):
    if tok is None:
        return None
    elif tok.isdigit():
        try:
            return eval(tok)
        except:
            return tok
    elif tok.startswith('"') and (not tok.endswith('"')):
        return tok + '"'
    elif tok.endswith('"') and (not tok.startswith('"')):
        return '"' + tok
    else:
        return tok
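
A few illustrative calls:

token_processing('12')      # -> 12 (digit strings are evaluated to numbers)
token_processing('"New')    # -> '"New"' (a missing closing quote is added)
token_processing('York"')   # -> '"York"' (a missing opening quote is added)
token_processing('dog')     # -> 'dog' (anything else passes through unchanged)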

Example #16
File: utils.py Project: soon2soon/tb-amr
    def print_graph(self, gold=False):

        #Fill the nodes attribute for the gold graph
        if gold:

            for a in self.A:
                h = a.head
                r = a.rel
                d = a.dependent
                self.nodes[h.id].add_pred_children((r, d.id))
                self.nodes[d.id].add_pred_relation((r, h.id))

        id_root_id = "-1"
        if gold:
            if "-1" in self.nodes:
                id_root_id = "-1"
            else:
                id_root_id = "0"

        #Post-processing steps: in the predicted graph some nodes might
        #not have been attached to any other node.
        #We define a heuristic and connect them to the root as a "snt" branch
        #of the AMR graph.
        nodes_without_root = True
        while nodes_without_root:

            nodes_to_connect = []
            for nid in self.nodes:
                if not self.path_to_root(self.nodes[nid], [], gold):
                    nodes_to_connect.append(self.nodes[nid])

            nodes_without_root = (nodes_to_connect != [])
            if nodes_without_root:

                aux = sorted(nodes_to_connect,
                             key=lambda x: len(x.pred_children),
                             reverse=True)[0]
                self.nodes[aux.id].add_pred_relation(("snt", id_root_id))
                self.nodes[id_root_id].add_pred_children(("snt", aux.id))
                nodes_to_connect = []

        is_multisentence = self.is_multisentence_graph(self.nodes[id_root_id])
        root_abr = ID_MULTISENTENCE_ABBR

        if is_multisentence:
            if self.nodes[id_root_id].concept == ID_ROOT_SYMBOL or self.nodes[
                    id_root_id].concept == "and":
                self.nodes[id_root_id].concept = "and"
                root_abr = "a"
            else:
                self.nodes[id_root_id].concept = ID_MULTISENTENCE
                root_abr = ID_MULTISENTENCE_ABBR
        d_abbr = {id_root_id: root_abr}

        added_nodes = set([])
        triplets = []

        nodes_id = [
            e[0] for e in sorted(self.nodes.items(),
                                 key=lambda kv: len(kv[1].pred_children),
                                 reverse=True)
        ]

        for n in nodes_id:

            node = self.nodes[n]
            n_concept = self.nodes[n].concept
            n_id = self.nodes[n].id

            n_abbr = self.abbreviation_for_concept(d_abbr, node, n_id)

            if n_id not in added_nodes and not node.is_constant:
                triplets.append(
                    penman.Triple(source=n_abbr.encode("utf-8"),
                                  relation=INSTANCE_TRIPLET,
                                  target=n_concept.encode("utf-8")))
                added_nodes.add(n_id)

            for r, childid in self.nodes[n].pred_children:

                child_node = self.nodes[childid]
                c_concept = self.nodes[childid].concept
                c_abbr = self.abbreviation_for_concept(d_abbr, child_node,
                                                       childid)

                if childid not in added_nodes and not child_node.is_constant:
                    triplets.append(
                        penman.Triple(source=c_abbr.encode("utf-8"),
                                      relation=INSTANCE_TRIPLET,
                                      target=c_concept.encode("utf-8")))
                    added_nodes.add(childid)

                triplets.append(
                    penman.Triple(source=n_abbr.encode("utf-8"),
                                  relation=r,
                                  target=c_abbr.encode("utf-8")))

        top, triplets = self.get_graph_root(triplets, is_multisentence)

        new_triplets, n_snt = [], 1
        #new_triplets = []
        for t in triplets:

            if t.relation == "*root*" or t.relation.startswith("snt"):

                if is_multisentence and self.nodes[id_root_id].concept == "and":
                    new_triplets.append(
                        penman.Triple(source=t.source,
                                      relation="op" + str(n_snt),
                                      target=t.target))
                    n_snt += 1
                else:
                    new_triplets.append(
                        penman.Triple(source=t.source,
                                      relation="snt" + str(n_snt),
                                      target=t.target))
                    n_snt += 1
            else:
                new_triplets.append(t)

        #Create a dummy AMR graph (for evaluation purposes) if an empty graph was predicted.
        if new_triplets == []:
            top = EMPTY_GRAPH_ABBR  #"e"
            new_triplets.append(
                penman.Triple(source=top,
                              relation=INSTANCE_TRIPLET,
                              target=EMPTY_GRAPH_CONCEPT))
            #new_triplets.append(penman.Triple(source="e",relation=INSTANCE_TRIPLET, target="empty-graph"))

        g = penman.Graph(data=new_triplets, top=top)
        return penman.encode(g)