예제 #1
0
def amr_subgraphs_optimized(g, n_min=1, n_max=None):  # g: AMRGraph object
    """ -> connected subgraphs whose number of nodes is >= n_min & <= n_max

    Candidates are built around each node ``n`` reachable from the AMR root
    (visited in DFS preorder): ``n`` plus ``i - 1`` of its descendants whose
    shortest directed distance from ``n`` is at most ``i - 1``. A candidate is
    kept when its induced subgraph is connected with edge direction ignored;
    it then gets a dummy edge '@' -> (first node in topological order) and is
    wrapped in AMRGraph with a ``meta`` line recording its size. Because every
    other node of a candidate must be reachable from ``n``, and two distinct
    nodes of a DAG cannot reach each other, each node set is produced at most
    once.

    Args:
        g: directed graph whose dummy node '@' has the AMR root as its first
            successor.
        n_min: smallest subgraph size; values below 1 are treated as 1.
        n_max: largest subgraph size; a falsy value means len(g.nodes()).

    Returns:
        defaultdict(list) mapping size -> list of AMRGraph subgraphs. It is
        empty (and a message is printed) when ``g`` is not a DAG.
    """
    output = defaultdict(list)
    # e.g. PROXY_AFP_ENG_20021112_0467.11 is a cyclic graph
    if not nx.is_directed_acyclic_graph(g):
        print('The input graph is not directed acyclic.')
        return output

    amr_root = list(g.successors('@'))[0]
    order = list(nx.dfs_preorder_nodes(g, amr_root))
    if not n_max:
        n_max = len(g.nodes())
    # Every candidate contains its start node, so size 0 is meaningless (and
    # itertools.combinations rejects a negative size).
    n_min = max(n_min, 1)

    # Shortest directed distances from each start node, computed once for the
    # largest size. A depth-limited DFS is not used: it marks a node visited
    # when first reached at the depth limit and never expands it again, so
    # descendants reachable through a shorter path could be missed.
    dist_from = {
        n: nx.single_source_shortest_path_length(
            g, n, cutoff=max(n_max - 1, 0))
        for n in order
    }
    for i in range(n_min, n_max + 1):
        for n in order:
            # descendants of n within i - 1 edges, in BFS discovery order
            pool = [v for v, d in dist_from[n].items() if 0 < d <= i - 1]
            if len(pool) < i - 1:
                continue
            for ns in itertools.combinations(pool, i - 1):
                sg = g.subgraph((n, ) + ns).copy()
                if nx.is_connected(sg.to_undirected()):
                    sg_root = list(nx.topological_sort(sg))[0]
                    sg.add_edge('@', sg_root, label='')
                    sg = AMRGraph(sg)
                    sg.meta = '# connected subgraph of {} nodes'.format(i)
                    output[i].append(sg)
    return output
예제 #2
0
 def _generate_amr_graph(self):
     """Merge every sentence's AMR graph into one document-level graph.

     Creates an empty AMRGraph and calls its ``merge`` with the ``amr_graph``
     of each entry in ``self.amr_sentences``, in order. The AMRGraph that
     ``merge`` was called on is stored in ``self.amr_graph`` (presumably
     ``merge`` updates it in place — confirm). The value returned by each
     ``merge`` call is kept, in sentence order, in ``self.s_graphs``.
     """
     merged = AMRGraph()
     per_sentence = [merged.merge(sent.amr_graph)
                     for sent in self.amr_sentences]
     self.amr_graph = merged
     self.s_graphs = per_sentence
예제 #3
0
def construct_amr_subgraph(g, nodes, root_node=None):
    """ Add dummy root node '@', apply AMRGraph type

    Args:
        g: source graph supporting ``subgraph``.
        nodes: iterable of node ids to keep (any iterable, not only a set).
        root_node: node to attach to '@'; it is added to the kept nodes when
            missing. When None, the first node in topological order of the
            induced subgraph is used.

    Returns:
        AMRGraph wrapping a copy of the induced subgraph plus the edge
        '@' -> root_node, with ``meta`` set to '# root node: <root_node>'.

    Raises:
        IndexError: if root_node is None and the induced subgraph is empty.
    """
    keep = set(nodes)
    if root_node is not None:
        # explicit root: make sure it is part of the subgraph
        keep.add(root_node)
    output = g.subgraph(keep).copy()
    if root_node is None:
        # assumption: nx.is_directed_acyclic_graph(output)==True
        root_node = list(nx.topological_sort(output))[0]
    output.add_edge('@', root_node, label='')
    output = AMRGraph(output)
    output.meta = '# root node: {}'.format(root_node)
    return output
예제 #4
0
def amr_subgraphs(g, num):  # g: AMRGraph object
    """ -> connected subgraphs with at least num nodes (brute force)

    Every combination of the non-'@' nodes of ``g`` with size between ``num``
    and the number of such nodes is tried. Combinations whose induced
    subgraph is connected (edge direction ignored) get a dummy edge
    '@' -> (first node in topological order) and are wrapped in AMRGraph.

    Args:
        g: graph that may contain the dummy root node '@'.
        num: minimum subgraph size; values below 1 are treated as 1.

    Returns:
        defaultdict(list) mapping size -> list of AMRGraph subgraphs.
    """
    output = defaultdict(list)
    # The dummy root '@' is excluded up front: any combination containing it
    # was always discarded, so this only skips wasted work.
    candidates = [v for v in g.nodes() if v != '@']
    # Start at size 1: the empty node set is not a connected graph and
    # nx.is_connected raises on it.
    for i in range(max(num, 1), len(candidates) + 1):
        for nodes in itertools.combinations(candidates, i):
            sg = g.subgraph(nodes).copy()
            if nx.is_connected(sg.to_undirected()):
                amr_root = list(nx.topological_sort(sg))[0]
                sg.add_edge('@', amr_root, label='')
                sg = AMRGraph(sg)
                sg.meta = '# connected subgraph of {} nodes'.format(i)
                output[i].append(sg)
    return output
예제 #5
0
def amr_subgraphs_from_docs(docs, n_min=1, n_max=None):  # docs: amr_table
    """ -> dict mapping each document id to the subgraphs of its sentences

    For every document in ``docs``, one AMRGraph is built per sentence (in
    sorted sentence-key order). The result of
    ``amr_subgraphs_from_doc(doc, n_min, n_max)`` is stored under the
    document id.

    Args:
        docs: amr_table-style mapping doc id -> {sentence key -> sentence}.
        n_min, n_max: subgraph size bounds forwarded to
            amr_subgraphs_from_doc. They used to be accepted but ignored.
            NOTE(review): this assumes amr_subgraphs_from_doc takes
            (doc, n_min, n_max), as in the previously commented-out call —
            confirm.
    """
    output = dict()
    for doc_id in docs:
        sentences = docs[doc_id]
        doc = [AMRGraph(sen=sentences[key]) for key in sorted(sentences.keys())]
        output[doc_id] = amr_subgraphs_from_doc(doc, n_min, n_max)
    return output
예제 #6
0
파일: eval.py 프로젝트: lnsongxf/sembleu
def get_amr_ngrams(path, stat_save_path=None):
    """Read AMRs from *path* and collect their n-gram statistics.

    For every AMR string yielded by ``read_amr(path)``, this function:
      * builds an AMRGraph;
      * calls ``revert_of_edges()``;
      * extracts n-grams up to order 3 with ``multi_roots=True``;
      * records ``NgramInst(ngram=..., length=len(amr.edges))``.

    Args:
        path: input passed to ``read_amr``.
        stat_save_path: optional output path. When given, one line per AMR
            is written with ``len(amr)`` and the numbers of 1-, 2- and
            3-grams, separated by single spaces.

    Returns:
        list of NgramInst, one per AMR, in input order.

    Raises:
        AssertionError: re-raised, after printing the offending line, when
            AMRGraph rejects an AMR string.
    """
    data = []
    stat_file = open(stat_save_path, 'w') if stat_save_path else None
    try:
        for line in read_amr(path):
            try:
                amr = AMRGraph(line.strip())
            except AssertionError:
                # Show the offending AMR, then propagate. A bare re-raise,
                # unlike `assert False`, is not stripped by `python -O`.
                print(line)
                raise
            amr.revert_of_edges()
            ngrams = amr.extract_ngrams(3, multi_roots=True)  # dict(list(tuple))
            data.append(NgramInst(ngram=ngrams, length=len(amr.edges)))
            if stat_file is not None:
                counts = (len(amr), len(ngrams[1]), len(ngrams[2]),
                          len(ngrams[3]))
                stat_file.write(' '.join(str(c) for c in counts) + '\n')
    finally:
        # close the stats file even when parsing fails part-way
        if stat_file is not None:
            stat_file.close()
    return data
예제 #7
0
def lines_from_docs(docs):  # docs: amr_table
    """ -> list of string lines of subgraphs of graphs in all documents

    Each document id in ``docs`` is printed to stdout. For each document, the
    result contains a '# <doc id>' header line, followed by the lines that
    ``lines_from_doc`` returns for the document's sentence graphs. The graphs
    are built in sorted sentence-key order.
    """
    result = []
    for doc_id in docs:
        print(doc_id)
        result.append('# {}'.format(doc_id))
        sentences = docs[doc_id]
        graphs = [AMRGraph(sen=sentences[key])
                  for key in sorted(sentences.keys())]
        result.extend(lines_from_doc(graphs))
    return result
예제 #8
0
 def __init__(self):
     # Start from an empty AMRGraph, an empty deque for self.nodes and an
     # empty dict for self.eval_later.
     self.graph, self.nodes, self.eval_later = AMRGraph(), deque(), {}
예제 #9
0
class AMRParser(object):
    """Build an AMRGraph from a bracketed AMR string.

    Attributes:
        graph: AMRGraph being populated.
        nodes: deque used as a stack of open instance nodes; ``nodes[-1]`` is
            the node that attributes currently attach to.
        eval_later: maps a node name that was referenced before being
            instantiated to a list of ``(func, args)`` callbacks. They are run
            (and removed) once extract_node creates that node.
    """

    def __init__(self):
        self.graph = AMRGraph()
        self.nodes = deque()
        self.eval_later = {}

    def parse(self, amr):
        """
        Builds a DAG from the AMR.

        Senses (watch, boy, etc) never have outgoing edges. Instances of these
        sense have an edge going to the thing that they are an instance of.

        Instances of senses that have arguments also have labeled edges going
        to the instances of each of their arguments.

        Returns the unconsumed remainder of the input, which is always ''
        once the loop finishes.

        Raises:
            Exception: (carrying the unparsed text) if a character other than
                '(', ':' or ')' appears where a token should start. Previously
                no branch consumed it and the loop never terminated.
        """
        # strip() so leading/trailing whitespace cannot reach the dispatch
        # below; every helper already returns stripped remainders
        tree = self.standardize_spacing(amr).strip()
        while tree:
            if tree[0] == '(':
                node, tree = self.extract_node(tree)
                self.nodes.append(node)
            elif tree[0] == ':':
                tree = self.extract_attr(tree)
            elif tree[0] == ')':
                self.nodes.pop()
                tree = tree[1:].strip()
            else:
                raise Exception(tree)

        return tree

    def extract_node(self, tree):
        """Consume one node at the start of *tree*.

        AMR_NODE (defined elsewhere) must match and yield three groups: the
        instance name, its sense, and the remaining text. Both names are added
        to the graph, joined by an 'instance' edge. Any deferred edges waiting
        on this instance name are created now.

        Returns:
            (instance_node, stripped remaining text)

        Raises:
            Exception: (carrying the text) if AMR_NODE does not match.
        """
        m = AMR_NODE.match(tree)
        if not m:
            raise Exception(tree)

        instance, sense, rest = m.groups()
        instance_node = self.graph.add_node(instance)
        sense_node = self.graph.add_node(sense)
        self.graph.add_edge(instance_node, sense_node, label='instance')

        # pop the list from the dict, so we don't ever evaluate it again
        for func, args in self.eval_later.pop(instance, []):
            func(instance_node, *args)

        return instance_node, rest.strip()

    def extract_attr(self, tree):
        """Consume one ':attr value' pair at the start of *tree*.

        AMR_ATTR (defined elsewhere) must yield (attr, rest). The value is
        handled as follows:
          * starts with '(': a nested node, linked from the current node by
            an ``attr`` edge and pushed on the node stack;
          * starts with a letter: a node referenced by name. The edge is added
            now if the node exists; otherwise it is deferred in eval_later;
          * anything else: a literal value (quoted or plain), stored on the
            current node with ``add_attribute``.

        Returns:
            the stripped remaining text.

        Raises:
            Exception: if AMR_ATTR or the value pattern does not match.
        """
        m = AMR_ATTR.match(tree)
        if not m:
            raise Exception(tree)

        attr, rest = m.groups()
        rest = rest.strip()
        if rest[0] == '(':
            n, remainder = self.extract_node(rest)
            self.graph.add_edge(self.nodes[-1], n, label=attr)
            self.nodes.append(n)
        elif rest[0].isalpha():
            # this is a node that has/will be instantiated elsewhere
            span = re.search(r'^\w+\b', rest).span()
            node_name = rest[span[0]:span[1]]
            remainder = rest[span[1]:].strip()
            if node_name not in self.graph.nodes:
                # The node has not been instantiated yet. Store a callback and
                # its args; extract_node runs it when the node appears, adding
                # the edge current node -> referenced node.
                later_func = lambda node, other, label: self.graph.add_edge(
                    other,
                    node,
                    label=label,
                )
                later_args = [self.nodes[-1], attr]
                self.eval_later.setdefault(node_name, []).append(
                    (later_func, later_args))
            else:
                self.graph.add_edge(
                    self.nodes[-1],
                    self.graph.nodes[node_name],
                    label=attr,
                )
        else:
            if rest[0] == '"':
                val_matcher = AMR_QUOTE_ATTR_VALUE
            else:
                val_matcher = AMR_REG_ATTR_VALUE
            val_match = val_matcher.match(rest)
            if not val_match:
                raise Exception(rest)

            val, remainder = val_match.groups()
            # Trailing close parens captured in val close enclosing nodes.
            # Find them by matching the reversed value, then move them back
            # onto the remainder so parse() pops those nodes.
            p = CLOSE_PARENS.match(val[::-1])
            trailing = p.groups()[0] if p else None
            if trailing:
                val = val[:-len(trailing)]
                remainder = trailing + remainder
            remainder = remainder.strip()
            self.nodes[-1].add_attribute(attr, val)

        return remainder

    def standardize_spacing(self, amr):
        """Replace every WHITESPACE match (pattern defined elsewhere) with ' '."""
        return WHITESPACE.sub(' ', amr)
예제 #10
0
    # see amr_table data structure in amr_reader package (.\scr\reader.py)
    #    amr_table = get_amr_table_path(AMR_PILOT_SELECTED)
    file_name = "amr-release-1.0-proxy_selected"
    #    save_data_pkl(amr_table,os.path.join(AMR_PILOT_SELECTED_PKL,
    #                                         ''.join([file_name,'.pkl'])))
    # TypeError: <amr_reader.models.Sentence.Sentence object at ...
    #            is not JSON serializable
    #    save_data_jsn(amr_table,os.path.join(AMR_PILOT_SELECTED_JSN,
    #                                         ''.join([file_name,'.json'])))
    # docid = 'PROXY_AFP_ENG_20020422_0296'
    # docid = 'PROXY_AFP_ENG_20021112_0467'
    docid = 'PROXY_AFP_ENG_20040329_0408'
    amr_table = load_data_pkl(
        os.path.join(AMR_PILOT_SELECTED_PKL, ''.join([file_name, '.pkl'])))
    doc = amr_table[docid]
    doc = [AMRGraph(sen=doc[k]) for k in sorted(doc.keys())]
    # g = doc[4]
    # print(g)
    # print(lines_from_sent(g,3))
    ti = time.time()
    print('Subgraph generation - START')
    # for line in lines_from_sent(g):
    # print(line)

    #    write_lines(lines_from_doc(doc,3,7),
    # write_lines(lines_from_doc(doc),
    # os.path.join(AMR_PILOT_SELECTED_SUBGRAPHS,
    #  ''.join([docid,'_20180728.txt'])))
    # save_data_pkl(amr_subgraphs_from_doc(doc,3),
    save_data_pkl(
        amr_subgraphs_from_doc(doc),
예제 #11
0
    print >> f, ''
    print ''


# Given the sentences, the AMRs and the alignments, enumerate all subgraphs of
# each AMR and dump every (sub-AMR, span) pair that is consistent with the alignment.
if __name__ == '__main__':
    # Script entry point: reads tokens, alignments and AMRs from fixed
    # relative paths under AMR-generation/train/ and dumps to train.dump.
    print 'loading reference'
    # ref[i]: whitespace-split tokens of line i of the token file
    ref = []
    # NOTE(review): 'rU' (universal newlines) is a Python 2 era mode that was
    # removed in Python 3.11; these input handles are never closed explicitly.
    for line in open('AMR-generation/train/token', 'rU'):
        ref.append(line.strip().split())

    # alignment[i]: whitespace-split fields of line i of the alignment file
    alignment = []
    for line in open('AMR-generation/train/alignment', 'rU'):
        alignment.append(line.strip().split())

    # AMRs in the amr file are separated by blank lines; the stripped lines
    # of one AMR are accumulated in amr_line. i counts AMRs dumped so far,
    # pairing the i-th AMR with ref[i] and alignment[i].
    # NOTE(review): in the lines shown here, an AMR that is not followed by a
    # blank line is never dumped, and f is never closed — check the code
    # that follows this loop.
    amr_line = ''
    i = 0
    f = open('train.dump', 'w')
    for line in open('AMR-generation/train/amr', 'rU'):
        line = line.strip()
        if len(line) == 0:
            # a blank line ends the current AMR; extra blank lines are
            # skipped because amr_line is then empty
            if len(amr_line) > 0:
                amr = AMRGraph(amr_line.strip())
                # rst is not used in the lines visible here
                rst = dump(f, amr, ref[i], alignment[i])
                i += 1
                amr_line = ''
        else:
            # '#' comment/metadata lines are not expected in this file
            assert line.startswith('#') == False
            # NOTE(review): lines are concatenated with no separator, so text
            # split across a line break is glued together — confirm that
            # AMRGraph copes with that.
            amr_line = amr_line + line