コード例 #1
0
def amr_subgraphs_optimized(g, n_min=1, n_max=None):  # g: an AMRGraph
    """Collect connected subgraphs of ``g`` whose node count is in [n_min, n_max].

    Returns a defaultdict mapping each size to a list of AMRGraph subgraphs,
    each re-rooted with the dummy '@' node. Returns an empty mapping when the
    input graph is cyclic.
    """
    result = defaultdict(list)
    # PROXY_AFP_ENG_20021112_0467.11 - a cyclic graph
    if not nx.is_directed_acyclic_graph(g):
        print('The input graph is not directed acyclic.')
        return result

    root = list(g.successors('@'))[0]
    traversal = list(nx.dfs_preorder_nodes(g, root))
    if not n_max:
        n_max = len(g.nodes())
    # caller guarantees n_min <= n_max
    for size in range(n_min, n_max + 1):
        for head in traversal:
            # candidates: nodes reachable from `head` within size-1 hops
            reachable = set()
            for succs in nx.dfs_successors(g, head, depth_limit=size - 1).values():
                reachable.update(succs)
            if len(reachable) < size - 1:
                continue
            for companions in itertools.combinations(reachable, size - 1):
                candidate = g.subgraph((head,) + companions).copy()
                if not nx.is_connected(candidate.to_undirected()):
                    continue
                sub_root = list(nx.topological_sort(candidate))[0]
                candidate.add_edge('@', sub_root, label='')
                candidate = AMRGraph(candidate)
                candidate.meta = '# connected subgraph of {} nodes'.format(size)
                result[size].append(candidate)
    return result
コード例 #2
0
 def _generate_amr_graph(self):
     """Merge every sentence-level AMR graph into one document-level graph.

     Stores the merged graph on ``self.amr_graph`` and the per-sentence
     merge results on ``self.s_graphs``.
     """
     merged = AMRGraph()
     per_sentence = []
     for sent in self.amr_sentences:
         per_sentence.append(merged.merge(sent.amr_graph))
     self.amr_graph = merged
     self.s_graphs = per_sentence
コード例 #3
0
def construct_amr_subgraph(g, nodes, root_node=None):
    """Induce a subgraph of *g*, attach the dummy root '@', wrap as AMRGraph.

    When *root_node* is not given (or falsy), the topologically first node of
    the induced subgraph is used as the root — the induced subgraph is
    assumed to be a DAG in that case.
    """
    selection = nodes.union({root_node}) if root_node else nodes
    sub = g.subgraph(selection).copy()
    if not root_node:
        # assumption: nx.is_directed_acyclic_graph(sub) == True
        root_node = list(nx.topological_sort(sub))[0]
    sub.add_edge('@', root_node, label='')
    sub = AMRGraph(sub)
    sub.meta = '# root node: {}'.format(root_node)
    return sub
コード例 #4
0
def amr_subgraphs(g, num):  # g: an AMRGraph
    """Brute-force enumeration of connected subgraphs with >= num nodes.

    Returns a defaultdict mapping each size to its list of AMRGraph
    subgraphs, each re-rooted at the dummy '@' node.
    """
    collected = defaultdict(list)
    total = len(g.nodes())
    # caller guarantees num < total + 1
    for size in range(num, total + 1):
        for combo in itertools.combinations(g.nodes(), size):
            candidate = g.subgraph(combo).copy()
            if nx.is_connected(candidate.to_undirected()) and '@' not in candidate.nodes():
                root = list(nx.topological_sort(candidate))[0]
                candidate.add_edge('@', root, label='')
                candidate = AMRGraph(candidate)
                candidate.meta = '# connected subgraph of {} nodes'.format(size)
                collected[size].append(candidate)
    return collected
コード例 #5
0
def amr_subgraphs_from_docs(docs, n_min=1, n_max=None):  # docs: amr_table
    """Map each document id in *docs* to the subgraphs of its sentence graphs.

    NOTE: n_min/n_max are accepted but currently not forwarded to
    amr_subgraphs_from_doc (see the commented-out call below).
    """
    result = dict()
    for doc_id in docs:
        sentence_graphs = [AMRGraph(sen=docs[doc_id][sid])
                           for sid in sorted(docs[doc_id].keys())]
        # result[doc_id] = amr_subgraphs_from_doc(sentence_graphs, n_min, n_max)
        result[doc_id] = amr_subgraphs_from_doc(sentence_graphs)
    return result
コード例 #6
0
ファイル: eval.py プロジェクト: lnsongxf/sembleu
def get_amr_ngrams(path, stat_save_path=None):
    """Read AMRs from *path* and return a list of NgramInst objects.

    For each parsed AMR: of-edges are reverted, ngrams up to length 3 are
    extracted, and an NgramInst(ngram=..., length=len(amr.edges)) is
    collected. When *stat_save_path* is given, per-AMR size statistics are
    written there (Python 2 `print >>` syntax — this snippet is Python 2).
    """
    data = []
    if stat_save_path:
        f = open(stat_save_path, 'w')
    for line in read_amr(path):
        try:
            amr = AMRGraph(line.strip())
        except AssertionError:
            # Echo the offending line, then abort hard.
            # NOTE(review): `f` is never closed on this path when
            # stat_save_path is set — consider try/finally. Confirm intent.
            print line
            assert False
        amr.revert_of_edges()
        ngrams = amr.extract_ngrams(3, multi_roots=True)  # dict(list(tuple))
        data.append(NgramInst(ngram=ngrams, length=len(amr.edges)))
        if stat_save_path:
            # per-AMR stats: node count and ngram counts for lengths 1..3
            print >> f, len(amr), len(ngrams[1]), len(ngrams[2]), len(
                ngrams[3])
    if stat_save_path:
        f.close()
    return data
コード例 #7
0
def lines_from_docs(docs):  # docs: amr_table
    """Produce string lines describing subgraphs of every document's graphs.

    Each document contributes a '# <doc_id>' header line followed by the
    lines produced by lines_from_doc for its sentence graphs.
    """
    all_lines = []
    for doc_id in docs:
        print(doc_id)
        all_lines.append('# {}'.format(doc_id))
        sentence_graphs = [AMRGraph(sen=docs[doc_id][sid])
                           for sid in sorted(docs[doc_id].keys())]
        # all_lines.extend(lines_from_doc(sentence_graphs, n_min, n_max))
        all_lines.extend(lines_from_doc(sentence_graphs))
    return all_lines
コード例 #8
0
ファイル: amr_parser.py プロジェクト: jonpiffle/amr_discourse
 def __init__(self):
     """Initialize an empty AMR graph, a node stack, and a registry of
     edges deferred until their target node is instantiated."""
     self.eval_later = {}
     self.nodes = deque()
     self.graph = AMRGraph()
コード例 #9
0
ファイル: amr_parser.py プロジェクト: jonpiffle/amr_discourse
class AMRParser(object):
    """Parses an AMR string into a DAG held in ``self.graph``.

    ``self.nodes`` is a stack of the currently-open instance nodes;
    ``self.eval_later`` maps a node name to deferred edge-insertion
    callbacks, applied once the node is actually instantiated.
    Depends on module-level regexes (AMR_NODE, AMR_ATTR, WHITESPACE, ...)
    and an AMRGraph type defined elsewhere in the project.
    """

    def __init__(self):
        self.graph = AMRGraph()
        self.nodes = deque()
        self.eval_later = {}

    def parse(self, amr):
        """
        Builds a DAG from the AMR.

        Senses (watch, boy, etc) never have outgoing edges. Instances of these
        sense have an edge going to the thing that they are an instance of.

        Instances of senses that have arguments also have labeled edges going
        to the instances of each of their arguments.
        """
        tree = self.standardize_spacing(amr)
        # NOTE(review): if tree[0] is ever not '(', ':' or ')' no branch
        # consumes input, so this loop would not terminate — presumably
        # the regex helpers always leave one of those at the front; verify.
        while tree:
            if tree[0] == '(':
                node, tree = self.extract_node(tree)
                self.nodes.append(node)
            elif tree[0] == ':':
                tree = self.extract_attr(tree)
            elif tree[0] == ')':
                # close the most recently opened node
                self.nodes.pop()
                tree = tree[1:].strip()

        return tree

    def extract_node(self, tree):
        """Consume one '(instance / sense ...' prefix from *tree*.

        Adds the instance and sense nodes plus an 'instance' edge to the
        graph, flushes any deferred edges waiting on this instance, and
        returns (instance_node, remaining_text).
        """
        m = AMR_NODE.match(tree)
        if not m:
            raise Exception(tree)

        instance, sense, rest = m.groups()
        instance_node, sense_node = [self.graph.add_node(n)
                                     for n in [instance, sense]]
        self.graph.add_edge(instance_node, sense_node, label='instance')

        if instance in self.eval_later:
            # pop the list from the dict, so we don't ever evaluate it again
            later_list = self.eval_later.pop(instance)
            for func, args in later_list:
                func(instance_node, *args)

        return instance_node, rest.strip()

    def extract_attr(self, tree):
        """Consume one ':attr value' prefix from *tree*; return the remainder.

        The value may be a nested node, a reference to a node defined
        elsewhere (possibly not yet instantiated), or a literal
        (quoted or bare) attribute value.
        """
        m = AMR_ATTR.match(tree)
        if not m:
            raise Exception(tree)

        attr, rest = m.groups()
        rest = rest.strip()
        if rest[0] == '(':
            # nested node: parse it and push it as the new open node
            n, remainder = self.extract_node(rest)
            self.graph.add_edge(self.nodes[-1], n, label=attr)
            self.nodes.append(n)
        elif rest[0].isalpha():
            # this is a node that has/will be instantiated elsewhere
            span = re.search(r'^\w+\b', rest).span()
            node_name = rest[span[0]:span[1]]
            remainder = rest[span[1]:].strip()
            if node_name not in self.graph.nodes:
                # the node has not actually been instantiated. store a lambda
                # and relevant args to insert the relevant edge whenever the
                # node actually gets instantiated
                later_list = self.eval_later.get(node_name, [])
                later_func = lambda node, other, label: self.graph.add_edge(
                    other,
                    node,
                    label=label,
                )
                # captured now: the currently open node and the attr label
                later_args = [self.nodes[-1], attr]
                later_list.append((later_func, later_args))
                self.eval_later[node_name] = later_list
            else:
                self.graph.add_edge(
                    self.nodes[-1],
                    self.graph.nodes[node_name],
                    label=attr,
                )
        else:
            # literal attribute value: quoted or bare
            if rest[0] == '"':
                val_matcher = AMR_QUOTE_ATTR_VALUE
            else:
                val_matcher = AMR_REG_ATTR_VALUE
            val_match = val_matcher.match(rest)
            if not val_match:
                raise Exception(rest)

            val, remainder = val_match.groups()
            p = CLOSE_PARENS.match(val[::-1])  # strip trailing closed parens
            # put trailing close parens back on the remainder
            if p.groups()[0]:
                val = val[:-len(p.groups()[0])]
                remainder = p.groups()[0] + remainder
            remainder = remainder.strip()
            self.nodes[-1].add_attribute(attr, val)

        return remainder

    def standardize_spacing(self, amr):
        """Collapse any run of whitespace in *amr* to a single space."""
        return WHITESPACE.sub(' ', amr)
コード例 #10
0
    # see amr_table data structure in amr_reader package (.\scr\reader.py)
    #    amr_table = get_amr_table_path(AMR_PILOT_SELECTED)
    file_name = "amr-release-1.0-proxy_selected"
    #    save_data_pkl(amr_table,os.path.join(AMR_PILOT_SELECTED_PKL,
    #                                         ''.join([file_name,'.pkl'])))
    # TypeError: <amr_reader.models.Sentence.Sentence object at ...
    #            is not JSON serializable
    #    save_data_jsn(amr_table,os.path.join(AMR_PILOT_SELECTED_JSN,
    #                                         ''.join([file_name,'.json'])))
    # docid = 'PROXY_AFP_ENG_20020422_0296'
    # docid = 'PROXY_AFP_ENG_20021112_0467'
    docid = 'PROXY_AFP_ENG_20040329_0408'
    amr_table = load_data_pkl(
        os.path.join(AMR_PILOT_SELECTED_PKL, ''.join([file_name, '.pkl'])))
    doc = amr_table[docid]
    doc = [AMRGraph(sen=doc[k]) for k in sorted(doc.keys())]
    # g = doc[4]
    # print(g)
    # print(lines_from_sent(g,3))
    ti = time.time()
    print('Subgraph generation - START')
    # for line in lines_from_sent(g):
    # print(line)

    #    write_lines(lines_from_doc(doc,3,7),
    # write_lines(lines_from_doc(doc),
    # os.path.join(AMR_PILOT_SELECTED_SUBGRAPHS,
    #  ''.join([docid,'_20180728.txt'])))
    # save_data_pkl(amr_subgraphs_from_doc(doc,3),
    save_data_pkl(
        amr_subgraphs_from_doc(doc),
コード例 #11
0
    print >> f, ''
    print ''


# given the sentence, amr and alignment,
# we enumerate all subgraphs of the amr, dump the (subamr, span) pair if it is consistent under the alignment
if __name__ == '__main__':
    # Python 2 script: pairs each AMR with its tokenized sentence and
    # alignment, and dumps consistent (subamr, span) pairs via dump().
    print 'loading reference'
    # tokenized reference sentences, one per AMR
    ref = []
    for line in open('AMR-generation/train/token', 'rU'):
        ref.append(line.strip().split())

    # word<->node alignments, parallel to `ref`
    alignment = []
    for line in open('AMR-generation/train/alignment', 'rU'):
        alignment.append(line.strip().split())

    # AMRs are multi-line, separated by blank lines; accumulate until blank
    amr_line = ''
    i = 0
    f = open('train.dump', 'w')
    for line in open('AMR-generation/train/amr', 'rU'):
        line = line.strip()
        if len(line) == 0:
            if len(amr_line) > 0:
                amr = AMRGraph(amr_line.strip())
                rst = dump(f, amr, ref[i], alignment[i])
                i += 1
                amr_line = ''
        else:
            # comment lines are assumed to have been stripped upstream
            assert line.startswith('#') == False
            amr_line = amr_line + line
            # NOTE(review): if the amr file does not end with a blank line,
            # the final accumulated AMR is never dumped — confirm whether
            # trailing-flush/close happens beyond this excerpt.