Example #1
def concept_mismatch(all_amr_list, comparison_amr_list, concept):
    """Find cases where dev amr should have 'possibilty' but doesn't"""
    missing = list()
    spurious = list()
    correct = list()
    for comparison_amr in comparison_amr_list:
        matches = [
            base_tuple for base_tuple in all_amr_list
            if base_tuple[0] == comparison_amr[0]
        ]
        if len(matches) > 0:
            match = matches[0]
            match_amr = AMR.parse_AMR_line(match[2])
            match_nodes = match_amr.node_values
            comparison_nodes = AMR.parse_AMR_line(
                comparison_amr[2]).node_values
            id_number = int(comparison_amr[0].split('::')[0].split('.')[1])
            if concept in match_nodes and concept not in comparison_nodes:
                #missing.append(comparison_amr[0])
                missing.append((id_number, comparison_amr[1]))
            elif concept in comparison_nodes and concept not in match_nodes:
                #spurious.append(comparison_amr[0])
                spurious.append((id_number, comparison_amr[1]))
            else:
                #correct.append(comparison_amr[0])
                correct.append((id_number, comparison_amr[1]))
    return sorted(correct), sorted(missing), sorted(spurious)
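A minimal usage sketch, assuming gold_amr_list and dev_amr_list are already loaded as (id, snt, amr_string) tuples whose ids look like "corpus.123 ::...", matching the id parsing above:

correct, missing, spurious = concept_mismatch(gold_amr_list, dev_amr_list, 'possible')
print(f'{len(correct)} correct, {len(missing)} missing, {len(spurious)} spurious')
for id_number, snt in missing[:5]:
    print(id_number, snt)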
Example #2
    def html(text, delete_x_ids=True):
        amr = AMR(text)
        elems = [e for e in amr.text_elements]
        nodes = [id for id in amr.node_ids()]
        edges = [id for id in amr.edge_ids()]
        node_indices = [i for i,e in enumerate(amr.text_elements) if amr.NODE_RE.match(e)]
        edge_indices = [i for i,e in enumerate(amr.text_elements) if amr.EDGE_RE.match(e)]
        Named_Entity_RE = re.compile('x[0-9]+/".*?"')
        for i,e in enumerate(elems):
            if i in node_indices:
                id = nodes.pop(0)
                frame = e.split('/')[-1] if '/' in e else '_'
                node = e
                if delete_x_ids:
                    node = re.sub('^x[0-9]+/', '', e, 1)
                if frame in propbank_frames_dictionary:
                    description = propbank_frames_dictionary[frame].replace('\t','\n')
                    elems[i] = f'<span class="amr-frame" tok-id="{id}" title="{description}">{node}</span>'
                elif Named_Entity_RE.match(e):
                    elems[i] = f'<span class="amr-entity" tok-id="{id}">{node}</span>'
                else:
                    elems[i] = f'<span class="amr-node" tok-id="{id}">{node}</span>'
            elif i in edge_indices:
                id = edges.pop(0)
                elems[i] = f'<span class="amr-edge" tok-id="{id}">{e}</span>'
        text = ''.join(elems)
        return '\n<div class="amr-container">\n<pre>\n'+text+'\n</pre>\n</div>\n'
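A minimal usage sketch, assuming this is the AMR_HTML.html method invoked in Example #12 and that propbank_frames_dictionary has been loaded; the AMR string is an illustrative guess at the x-id annotation format, not taken from the data:

amr_str = '(x1/want-01 :ARG0 (x2/boy) :ARG1 (x3/go-01 :ARG0 x2))'
print(AMR_HTML.html(amr_str))                      # frames/nodes/edges wrapped in <span> tags, x-ids stripped
print(AMR_HTML.html(amr_str, delete_x_ids=False))  # keep the x1/, x2/ prefixes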
Example #3
def main():

    amr_file = r'test-data/amrs.txt'
    sentence_file = r'test-data/sentences.txt'
    if len(sys.argv) > 2:
        amr_file = sys.argv[1]
        sentence_file = sys.argv[2]

    failed_amrs = Counter()
    failed_words = Counter()
    with open(sentence_file, 'r', encoding='utf8') as f1:
        sentences = [s for s in re.split(r'\n\s*\n', f1.read()) if s]
        with open(amr_file, 'r', encoding='utf8') as f2:
            for i, amr in enumerate(AMR.amr_iter(f2.read())):
                print('#' + str(i + 1))
                words = sentences[i].strip().split()
                amr = AMR(amr)
                # test_rules(amr, words)
                alignments, amr_unal, words_unal = align_amr(amr, words)
                print('# AMR:')
                print('\n'.join('# ' + l for l in str(amr).split('\n')))
                print('# Sentence:')
                print('# ' + ' '.join(words))
                print('# Alignments:')
                for a in alignments:
                    print('#', a.readible())
                for a in alignments:
                    print(a)
                print()
Example #4
def read_amr(file_path):
    # read all lines
    ref_amr_lines = []
    with open(file_path, encoding='utf8') as fid:
        line = AMR.get_amr_line(fid)
        while line:
            ref_amr_lines.append(line)
            line = AMR.get_amr_line(fid)
    return ref_amr_lines
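A usage sketch; AMR.get_amr_line consumes one AMR block per call, so the result is one AMR string per AMR in the file:

ref_amr_lines = read_amr('test-data/amrs.txt')
print(f'read {len(ref_amr_lines)} AMRs')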
Example #5
    def latex(text):
        amr = AMR(text)
        text = str(amr)
        for x in re.findall(r'x[0-9]+ ?/ ?[^()\s]+', text):
            text = text.replace(x, '(' + x + ')')
        edges = [(e, id) for e, id in zip(amr.edges(), amr.edge_ids())]
        elems = []
        max_depth = paren_utils.max_depth(text)
        prev_depth = 0
        depth = 0

        i = 0
        node_depth = {}
        for t in paren_utils.paren_iter(text):
            node = amr.NODE_RE.match(t).group()
            id = node.split('/')[0].strip()
            # clean node
            if re.match('x[0-9]+/', node):
                node = node.split('/')[1]
            node = node.replace('"', '``', 1).replace('"', "''", 1)
            prev_depth = depth
            depth = paren_utils.depth_at(text, text.index(t))
            if depth > prev_depth:
                i = 0
            node_depth[id] = depth
            num_nodes = paren_utils.mark_depth(text).count(f'<{depth}>')
            x = AMR_Latex.get_x(i, num_nodes)
            y = AMR_Latex.get_y(depth, max_depth)
            color = AMR_Latex.get_color(i)
            elems.append(f'\t\\node[{color}]({id}) at ({x},{y}) {{{node}}};')
            i += 1
        for edge, id in edges:
            source = id.split('_')[0]
            target = id.split('_')[2]
            dir1 = 'south'
            dir2 = 'north'
            if node_depth[source] > node_depth[target]:
                dir1 = 'north'
                dir2 = 'south'
            if node_depth[source] == node_depth[target]:
                dir1 = 'north'
                dir2 = 'north'
            elems.append(
                f'\t\\draw[->, thick] ({source}.{dir1}) -- ({target}.{dir2}) node[midway, above, sloped] {{{edge}}};'
            )
        latex = '\n\\begin{tikzpicture}[\n'
        latex += 'red/.style={rectangle, draw=red!60, fill=red!5, very thick, minimum size=7mm},\n'
        latex += 'blue/.style={rectangle, draw=blue!60, fill=blue!5, very thick, minimum size=7mm},\n'
        latex += 'green/.style={rectangle, draw=green!60, fill=green!5, very thick, minimum size=7mm},\n'
        latex += 'purple/.style={rectangle, draw=purple!60, fill=purple!5, very thick, minimum size=7mm},\n'
        latex += 'orange/.style={rectangle, draw=orange!60, fill=orange!5, very thick, minimum size=7mm},\n'
        latex += ']\n'
        latex += '\n'.join(elems)
        latex += '\n\\end{tikzpicture}\n'

        return latex
Example #6
def normalize_entity(root, nodes, edges):
    normalize_ids = {
        id: i
        for i, id in enumerate(sorted(nodes, key=lambda x: nodes[x]))
    }
    normalized_entity = AMR()
    for n in nodes:
        normalized_entity.nodes[normalize_ids[n]] = nodes[n]
    for s, r, t in edges:
        normalized_entity.edges.append((normalize_ids[s], r, normalize_ids[t]))
    normalized_entity.edges = sorted(normalized_entity.edges)
    normalized_entity.root = normalize_ids[root]
    return normalized_entity
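A usage sketch with a hand-built two-node entity; the ids and labels are made up for illustration and assume AMR() exposes nodes, edges, and root as used above:

nodes = {'x3': 'name', 'x7': '"Rome"'}
edges = [('x3', ':op1', 'x7')]
entity = normalize_entity('x3', nodes, edges)
print(entity.root, entity.nodes, entity.edges)  # ids renumbered by sorted label, edges sorted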
Example #7
def main(data):
    logging.basicConfig(level=logging.ERROR)
    logger = logging.getLogger(__name__)
    test = codecs.open(data.test, 'r', 'utf-8')
    gold = codecs.open(data.gold, 'r', 'utf-8')
    flag = False
    sema = Sema()
    while True:
        cur_amr1 = AMR.get_amr_line(test)
        cur_amr2 = AMR.get_amr_line(gold)

        if cur_amr1 == '' and cur_amr2 == '':
            break
        if cur_amr1 == '':
            logger.error('Error: File 1 has fewer AMRs than file 2')
            logger.error('Ignoring remaining AMRs')
            flag = True
            break
        if cur_amr2 == '':
            logger.error('Error: File 2 has fewer AMRs than file 1')
            logger.error('Ignoring remaining AMRs')
            flag = True
            break
        try:
            amr1 = AMR.parse_AMR_line(cur_amr1)
        except Exception as e:
            logger.error('Error in parsing amr 1: %s' % cur_amr1)
            logger.error(
                "Please check if the AMR is ill-formatted. Ignoring remaining AMRs"
            )
            logger.error("Error message: %s" % str(e))
            flag = True
            break
        try:
            amr2 = AMR.parse_AMR_line(cur_amr2)
        except Exception as e:
            logger.error("Error in parsing amr 2: %s" % cur_amr2)
            logger.error(
                "Please check if the AMR is ill-formatted. Ignoring remaining AMRs"
            )
            logger.error("Error message: %s" % str(e))
            flag = True
            break
        prefix_test = 'a'
        prefix_gold = 'b'
        amr1.rename_node(prefix_test)
        amr2.rename_node(prefix_gold)
        sema.compute_sema(amr1, amr2)
    if not flag:
        precision, recall, f1 = sema.get_sema_value()
        print(f'SEMA: P {precision:.2f} R {recall:.2f} F1 {f1:.2f}')
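main() only needs an object exposing .test and .gold file paths, so a quick way to drive it (with placeholder file names, assuming Sema and AMR from this project are importable) is an argparse Namespace:

from argparse import Namespace
main(Namespace(test='parsed.amr.txt', gold='gold.amr.txt'))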
Example #8
def score_amr_pair(ref_amr_line, rec_amr_line, restart_num, justinstance=False, 
                   justattribute=False, justrelation=False):

    # parse lines
    amr1 = AMR.parse_AMR_line(ref_amr_line)
    amr2 = AMR.parse_AMR_line(rec_amr_line)

    # Fix prefix
    prefix1 = "a"
    prefix2 = "b"
    # Rename nodes to "a1", "a2", etc.
    amr1.rename_node(prefix1)
    # Rename nodes to "b1", "b2", etc.
    amr2.rename_node(prefix2)

    # get triples
    (instance1, attributes1, relation1) = amr1.get_triples()
    (instance2, attributes2, relation2) = amr2.get_triples()

    # optionally turn off some of the node comparison
    doinstance = doattribute = dorelation = True
    if justinstance:
        doattribute = dorelation = False
    if justattribute:
        doinstance = dorelation = False
    if justrelation:
        doinstance = doattribute = False

    (best_mapping, best_match_num) = smatch.get_best_match(
        instance1, attributes1, relation1,
        instance2, attributes2, relation2,
        prefix1, prefix2, 
        restart_num,
        doinstance=doinstance,
        doattribute=doattribute,
        dorelation=dorelation
    )

    if justinstance:
        test_triple_num = len(instance1)
        gold_triple_num = len(instance2)
    elif justattribute:
        test_triple_num = len(attributes1)
        gold_triple_num = len(attributes2)
    elif justrelation:
        test_triple_num = len(relation1)
        gold_triple_num = len(relation2)
    else:
        test_triple_num = len(instance1) + len(attributes1) + len(relation1)
        gold_triple_num = len(instance2) + len(attributes2) + len(relation2)
    return best_match_num, test_triple_num, gold_triple_num
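A sketch of turning the returned counts into precision/recall/F1 the way smatch does; the two single-line AMRs are illustrative, and restart_num=4 mirrors smatch's default number of restarts:

ref_line = '(w / want-01 :ARG0 (b / boy) :ARG1 (g / go-01 :ARG0 b))'
rec_line = '(w / want-01 :ARG0 (b / boy))'
best, n_test, n_gold = score_amr_pair(ref_line, rec_line, restart_num=4)
p = best / n_test if n_test else 0.0   # matched triples over the first AMR's triples
r = best / n_gold if n_gold else 0.0   # matched triples over the second AMR's triples
f1 = 2 * p * r / (p + r) if (p + r) else 0.0
print(f'Smatch P {p:.3f} R {r:.3f} F1 {f1:.3f}')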
Example #9
    def __init__(self, tokens, verbose=False, add_unaligned=0):
        tokens = tokens.copy()

        # add unaligned
        if add_unaligned and '<unaligned>' not in tokens:
            for i in range(add_unaligned):
                tokens.append('<unaligned>')
        # add root
        if '<ROOT>' not in tokens:
            tokens.append("<ROOT>")
        # init stack, buffer
        self.stack = []
        self.buffer = list(
            reversed([
                i + 1 for i, tok in enumerate(tokens) if tok != '<unaligned>'
            ]))
        self.latent = list(
            reversed([
                i + 1 for i, tok in enumerate(tokens) if tok == '<unaligned>'
            ]))

        # init amr
        self.amr = AMR(tokens=tokens)
        for i, tok in enumerate(tokens):
            if tok != "<ROOT>":
                self.amr.nodes[i + 1] = tok
        # add root
        self.buffer[0] = -1
        self.amr.nodes[-1] = "<ROOT>"

        self.new_id = len(tokens) + 1
        self.verbose = verbose
        # parser target output
        self.actions = []
        self.labels = []
        self.labelsA = []
        self.predicates = []

        # information for oracle
        self.merged_tokens = {}
        self.entities = []
        self.is_confirmed = set()
        self.is_confirmed.add(-1)
        self.swapped_words = {}

        if self.verbose:
            print('INIT')
            print(self.printStackBuffer())
Example #10
def get_named_entities(all_amr_file=None):
    """Get all the named entities
    Inputs:
        amr_file: file with all the AMRs
    Returns:
        list of (id, snt, amr) tuples
    """
    if all_amr_file is None:
        all_amr_file = GOLD_AMRS
    match_amrs = list()  #(id,snt)
    comments_and_amrs = read_amrz(all_amr_file)  #(comment_list, amr_list)
    comments = comments_and_amrs[0]  #{'snt','id'}
    amrs = comments_and_amrs[1]
    for i in range(len(amrs)):
        amr_graph = AMR.parse_AMR_line(amrs[i])
        # amr_evaluation var2concept
        v2c = {}
        for n, v in zip(amr_graph.nodes, amr_graph.node_values):
            v2c[n] = v
        # print(v2c)
        # I don't know why we need these indices but we do
        triples = [t for t in amr_graph.get_triples()[1]]
        triples.extend([t for t in amr_graph.get_triples()[2]])
        #print(triples)
        # named_ent(v2c, triples)
        named_entities = [
            str(v2c[v1]) for (l, v1, v2) in triples if l == "name"
        ]
        print(named_entities)
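A usage sketch; GOLD_AMRS and read_amrz are module-level definitions this function relies on, and the alternate file path below is a placeholder:

get_named_entities()                      # iterate over GOLD_AMRS, print the concept of each node carrying a name relation
get_named_entities('test-data/amrs.txt')  # or point it at another AMR file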
Example #11
def main():
    input_file = r'test-data/amrs.txt'
    if len(sys.argv) > 1:
        input_file = sys.argv[1]

    with open(input_file, 'r', encoding='utf8') as f:
        for amr in AMR.amr_iter(f.read()):
            amr = AMR_Latex.latex(amr)
            print(amr)
            print()
Example #12
def main():
    input_file = r'test-data/amrs.txt'
    ids = '-x' in sys.argv
    if len(sys.argv) > 1:
        input_file = sys.argv[-1]

    with open(input_file, 'r', encoding='utf8') as f:
        for amr in AMR.amr_iter(f.read()):
            amr = AMR_HTML.html(amr, ids)
            print(amr)
            print()
Example #13
File: api.py  Project: Oneplus/camr
def _smatch(cur_amr1, cur_amr2, n_iter):
    clear_match_triple_dict()

    amr1 = AMR.parse_AMR_line(cur_amr1)
    amr2 = AMR.parse_AMR_line(cur_amr2)
    prefix1 = "a"
    prefix2 = "b"

    amr1.rename_node(prefix1)
    amr2.rename_node(prefix2)
    instance1, attributes1, relation1 = amr1.get_triples()
    instance2, attributes2, relation2 = amr2.get_triples()

    best_mapping, best_match_num = get_best_match(instance1, attributes1,
                                                  relation1, instance2,
                                                  attributes2, relation2,
                                                  prefix1, prefix2)

    test_triple_num = len(instance1) + len(attributes1) + len(relation1)
    gold_triple_num = len(instance2) + len(attributes2) + len(relation2)
    return best_match_num, test_triple_num, gold_triple_num
Example #14
def get_amr(tokens, actions, entity_rules):

    # play state machine to get AMR
    state_machine = AMRStateMachine(tokens, entity_rules=entity_rules)
    for action in actions:
        # FIXME: It is unclear whether this will always be the right option;
        # manual exploration of dev yielded 4 cases and it works for all 4
        if action == "<unk>":
            action = f'PRED({state_machine.get_top_of_stack()[0].lower()})'
        state_machine.applyAction(action)

    # sanity check: force close
    if not state_machine.is_closed:
        alert_str = yellow_font('Machine not closed!')
        print(alert_str)
        state_machine.CLOSE()

    # TODO: Probably wasting resources here
    amr_str = state_machine.amr.toJAMRString()
    return AMR.get_amr_line(amr_str.split('\n'))
Example #15
def get_amrs_with_concept(concept, all_amr_file=None):
    """Get the IDs of all AMRs with 'possible' concept
    Inputs:
        amr_file: file with all the AMRs
    Returns:
        list of (id, snt, amr) tuples
    """
    if all_amr_file is None:
        all_amr_file = GOLD_AMRS
    match_amrs = list()  # (id, snt, amr)
    comments_and_amrs = read_amrz(all_amr_file)  #(comment_list, amr_list)
    comments = comments_and_amrs[0]  #{'snt','id'}
    amrs = comments_and_amrs[1]
    for i in range(len(amrs)):
        amr_graph = AMR.parse_AMR_line(amrs[i])
        node_values = amr_graph.node_values
        if concept in node_values:
            match_amrs.append((comments[i]['id'], comments[i]['snt'], amrs[i]))
            #possible_ids.append((comments[i]['id'].encode('utf8'),comments[i]['snt'].encode('utf8'),amrs[i].encode('utf8')))
    print("Total number of AMRs with '{}': {}".format(concept,
                                                      len(match_amrs)))
    return sorted(match_amrs,
                  key=lambda x: int(x[0].split(' ')[0].split('.')[1])
                  )  #sort by id number
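A usage sketch; GOLD_AMRS is the module default file, and the printout shows the sorted (id, snt, amr) tuples:

for amr_id, snt, _amr in get_amrs_with_concept('possible')[:3]:
    print(amr_id, '|', snt)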
Example #16
class Transitions:
    def __init__(self, tokens, verbose=False, add_unaligned=0):
        tokens = tokens.copy()

        # add unaligned
        if add_unaligned and '<unaligned>' not in tokens:
            for i in range(add_unaligned):
                tokens.append('<unaligned>')
        # add root
        if '<ROOT>' not in tokens:
            tokens.append("<ROOT>")
        # init stack, buffer
        self.stack = []
        self.buffer = list(
            reversed([
                i + 1 for i, tok in enumerate(tokens) if tok != '<unaligned>'
            ]))
        self.latent = list(
            reversed([
                i + 1 for i, tok in enumerate(tokens) if tok == '<unaligned>'
            ]))

        # init amr
        self.amr = AMR(tokens=tokens)
        for i, tok in enumerate(tokens):
            if tok != "<ROOT>":
                self.amr.nodes[i + 1] = tok
        # add root
        self.buffer[0] = -1
        self.amr.nodes[-1] = "<ROOT>"

        self.new_id = len(tokens) + 1
        self.verbose = verbose
        # parser target output
        self.actions = []
        self.labels = []
        self.labelsA = []
        self.predicates = []

        # information for oracle
        self.merged_tokens = {}
        self.entities = []
        self.is_confirmed = set()
        self.is_confirmed.add(-1)
        self.swapped_words = {}

        if self.verbose:
            print('INIT')
            print(self.printStackBuffer())

    def __str__(self):
        s = '\t'.join(self.amr.tokens) + '\n'
        s += '\t'.join([a for a in self.actions]) + '\n'
        return s + '\n'

    @classmethod
    def readAction(cls, action):
        s = [action]
        if action.startswith('DEPENDENT') or action in [
                'LA(root)', 'RA(root)', 'LA1(root)', 'RA1(root)'
        ]:
            return s
        if '(' in action:
            paren_idx = action.index('(')
            s[0] = action[:paren_idx]
            properties = action[paren_idx + 1:-1]
            if ',' in properties:
                s.extend(properties.split(','))
            else:
                s.append(properties)
        return s
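    # How readAction splits action strings (illustrative inputs, not taken from the data):
    #   readAction('SHIFT')       -> ['SHIFT']
    #   readAction('PRED(boy)')   -> ['PRED', 'boy']
    #   readAction('LA(ARG1)')    -> ['LA', 'ARG1']
    #   readAction('LA(root)')    -> ['LA(root)']   (root and DEPENDENT actions are kept whole)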

    def applyAction(self, act):
        action = self.readAction(act)
        action_label = action[0]
        if action_label in ['SHIFT']:
            if self.buffer:
                self.SHIFT()
            else:
                self.CLOSE()
                return True
        elif action_label in ['REDUCE', 'REDUCE1']:
            self.REDUCE()
        elif action_label in ['LA', 'LA1']:
            self.LA(action[1] if action[1].startswith(':') else ':' +
                    action[1])
        elif action_label in ['RA', 'RA1']:
            self.RA(action[1] if action[1].startswith(':') else ':' +
                    action[1])
        elif action_label in ['LA(root)', 'LA1(root)']:
            self.LA('root')
        elif action_label in ['RA(root)', 'RA1(root)']:
            self.RA('root')
        elif action_label in ['PRED', 'CONFIRM']:
            self.CONFIRM(action[-1])
        elif action_label in ['SWAP', 'UNSHIFT', 'UNSHIFT1']:
            self.SWAP()
        elif action_label in ['DUPLICATE']:
            self.DUPLICATE()
        elif action_label in ['INTRODUCE']:
            self.INTRODUCE()
        elif action_label.startswith('DEPENDENT'):
            paren_idx = action_label.index('(')
            properties = action_label[paren_idx + 1:-1].split(',')
            self.DEPENDENT(properties[1], properties[0])
        elif action_label in ['ADDNODE', 'ENTITY']:
            self.ENTITY(','.join(action[1:]))
        elif action_label in ['MERGE']:
            self.MERGE()
        elif action_label in ['CLOSE']:
            self.CLOSE()
            return True
        else:
            raise Exception(f'Unrecognized action: {act}')

    def applyActions(self, actions):
        for action in actions:
            is_closed = self.applyAction(action)
            if is_closed:
                return
        self.CLOSE()

    def SHIFT(self):
        """SHIFT : move buffer[-1] to stack[-1]"""

        if not self.buffer:
            self.CLOSE()
        tok = self.buffer.pop()
        self.stack.append(tok)
        self.actions.append('SHIFT')
        self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if self.verbose:
            print('SHIFT')
            print(self.printStackBuffer())

    def REDUCE(self):
        """REDUCE : delete token"""

        stack0 = self.stack.pop()
        # if stack0 has no edges, delete it from the amr
        if stack0 != -1 and stack0 not in self.entities:
            if len([e for e in self.amr.edges if stack0 in e]) == 0:
                if stack0 in self.amr.nodes:
                    del self.amr.nodes[stack0]
        self.actions.append('REDUCE')
        self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if self.verbose:
            print('REDUCE')
            print(self.printStackBuffer())

    def CONFIRM(self, node_label):
        """CONFIRM : assign a propbank label"""

        stack0 = self.stack[-1]
        # old_label = self.amr.nodes[stack0].split(',')[-1]
        # old_label = old_label.replace(',','-COMMA-').replace(')','-PAREN-')
        self.amr.nodes[stack0] = node_label
        self.actions.append(f'PRED({node_label})')
        self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append(node_label)
        self.is_confirmed.add(stack0)
        if self.verbose:
            print(f'PRED({node_label})')
            print(self.printStackBuffer())

    def LA(self, edge_label):
        """LA : add an edge from stack[-1] to stack[-2]"""

        head = self.stack[-1]
        dependent = self.stack[-2]
        self.amr.edges.append((head, edge_label, dependent))
        self.actions.append(f'LA({edge_label.replace(":","")})')
        if edge_label != 'root':
            self.labels.append(edge_label)
        else:
            self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if self.verbose:
            print(f'LA({edge_label})')
            print(self.printStackBuffer())

    def RA(self, edge_label):
        """RA : add an edge from stack[-2] to stack[-1]"""

        head = self.stack[-2]
        dependent = self.stack[-1]
        self.amr.edges.append((head, edge_label, dependent))
        self.actions.append(f'RA({edge_label.replace(":","")})')
        if edge_label != 'root':
            self.labels.append(edge_label)
        else:
            self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if self.verbose:
            print(f'RA({edge_label})')
            print(self.printStackBuffer())

    def MERGE(self):
        """MERGE : merge two tokens to be the same node"""

        lead = self.stack.pop()
        sec = self.stack.pop()
        self.stack.append(lead)

        # maintain merged tokens dict
        if lead not in self.merged_tokens:
            self.merged_tokens[lead] = [lead]
        if sec in self.merged_tokens:
            self.merged_tokens[
                lead] = self.merged_tokens[sec] + self.merged_tokens[lead]
        else:
            self.merged_tokens[lead].insert(0, sec)
        merged = ','.join(self.amr.tokens[x - 1].replace(',', '-COMMA-')
                          for x in self.merged_tokens[lead])

        for i, e in enumerate(self.amr.edges):
            if e[1] == 'entity':
                continue
            if sec == e[0]:
                self.amr.edges[i] = (lead, e[1], e[2])
            if sec == e[2]:
                self.amr.edges[i] = (e[0], e[1], lead)

        # Just in case you merge entities. This shouldn't happen but might.
        if lead in self.entities:
            entity_edges = [
                e for e in self.amr.edges if e[0] == lead and e[1] == 'entity'
            ]
            lead = [t for s, r, t in entity_edges][0]
        if sec in self.entities:
            entity_edges = [
                e for e in self.amr.edges if e[0] == sec and e[1] == 'entity'
            ]
            child = [t for s, r, t in entity_edges][0]
            del self.amr.nodes[sec]
            for e in entity_edges:
                self.amr.edges.remove(e)
            self.entities.remove(sec)
            sec = child

        # make tokens into a single node
        del self.amr.nodes[sec]
        self.amr.nodes[lead] = merged

        self.actions.append(f'MERGE')
        self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if self.verbose:
            print(f'MERGE({self.amr.nodes[lead]})')
            print(self.printStackBuffer())

    def ENTITY(self, entity_type):
        """ENTITY : create a named entity"""

        head = self.stack[-1]
        child_id = self.new_id
        self.new_id += 1

        self.amr.nodes[child_id] = self.amr.nodes[head]
        self.amr.nodes[head] = f'({entity_type})'
        self.amr.edges.append((head, 'entity', child_id))
        self.entities.append(head)

        self.actions.append(f'ADDNODE({entity_type})')
        self.labels.append('_')
        self.labelsA.append(f'{entity_type}')
        self.predicates.append('_')
        if self.verbose:
            print(f'ADDNODE({entity_type})')
            print(self.printStackBuffer())

    def DEPENDENT(self, edge_label, node_label, node_id=None):
        """DEPENDENT : add a single edge and node"""

        head = self.stack[-1]
        new_id = self.new_id

        edge_label = edge_label if edge_label.startswith(
            ':') else ':' + edge_label

        if node_id:
            new_id = node_id
        else:
            self.amr.nodes[new_id] = node_label
        self.amr.edges.append((head, edge_label, new_id))
        self.new_id += 1
        self.actions.append(
            f'DEPENDENT({node_label},{edge_label.replace(":","")})')
        self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if self.verbose:
            print(f'DEPENDENT({edge_label},{node_label})')
            print(self.printStackBuffer())

    def SWAP(self):
        """SWAP : move stack[1] to buffer"""

        stack0 = self.stack.pop()
        stack1 = self.stack.pop()
        self.buffer.append(stack1)
        self.stack.append(stack0)
        self.actions.append('UNSHIFT')
        self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if stack1 not in self.swapped_words:
            self.swapped_words[stack1] = []
        self.swapped_words[stack1].append(stack0)
        if self.verbose:
            print('UNSHIFT')
            print(self.printStackBuffer())

    def INTRODUCE(self):
        """INTRODUCE : move latent[-1] to stack"""

        latent0 = self.latent.pop()
        self.stack.append(latent0)
        self.actions.append('INTRODUCE')
        self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if self.verbose:
            print('INTRODUCE')
            print(self.printStackBuffer())

    def CLOSE(self, training=False, gold_amr=None, use_addnonde_rules=False):
        """CLOSE : finish parsing"""

        self.buffer = []
        self.stack = []
        if training and not use_addnonde_rules:
            self.postprocessing_training(gold_amr)
        else:
            self.postprocessing(gold_amr)

        for item in self.latent:
            if item in self.amr.nodes and not any(
                    item == s or item == t for s, r, t in self.amr.edges):
                del self.amr.nodes[item]
        self.latent = []
        # clean concepts
        for n in self.amr.nodes:
            if self.amr.nodes[n] in ['.', '?', '!', ',', ';', '"', "'"]:
                self.amr.nodes[n] = 'PUNCT'
            if self.amr.nodes[n].startswith(
                    '"') and self.amr.nodes[n].endswith('"'):
                self.amr.nodes[n] = '"' + self.amr.nodes[n].replace('"',
                                                                    '') + '"'
            if not (self.amr.nodes[n].startswith('"')
                    and self.amr.nodes[n].endswith('"')):
                for ch in ['/', ':', '(', ')', '\\']:
                    if ch in self.amr.nodes[n]:
                        self.amr.nodes[n] = self.amr.nodes[n].replace(ch, '-')
            if not self.amr.nodes[n]:
                self.amr.nodes[n] = 'None'
            if ',' in self.amr.nodes[n]:
                self.amr.nodes[n] = '"' + self.amr.nodes[n].replace('"',
                                                                    '') + '"'
            if not self.amr.nodes[n][0].isalpha() and not self.amr.nodes[n][
                    0].isdigit() and not self.amr.nodes[n][0] in ['-', '+']:
                self.amr.nodes[n] = '"' + self.amr.nodes[n].replace('"',
                                                                    '') + '"'
        # clean edges
        for j, e in enumerate(self.amr.edges):
            s, r, t = e
            if not r.startswith(':'):
                r = ':' + r
            e = (s, r, t)
            self.amr.edges[j] = e
        # handle missing nodes (this shouldn't happen but a bad sequence of actions can produce it)
        for s, r, t in self.amr.edges:
            if s not in self.amr.nodes:
                self.amr.nodes[s] = 'NA'
            if t not in self.amr.nodes:
                self.amr.nodes[t] = 'NA'
        self.connectGraph()

        self.actions.append('SHIFT')
        self.labels.append('_')
        self.labelsA.append('_')
        self.predicates.append('_')
        if self.verbose:
            print('CLOSE')
            print(self.printStackBuffer())
            print(self.amr.toJAMRString())

    def printStackBuffer(self):
        s = 'STACK [' + ' '.join(
            self.amr.nodes[x] if x in self.amr.nodes else 'None'
            for x in self.stack) + '] '
        s += 'BUFFER [' + ' '.join(
            self.amr.nodes[x] if x in self.amr.nodes else 'None'
            for x in reversed(self.buffer)) + ']\n'
        if self.latent:
            s += 'LATENT [' + ' '.join(
                self.amr.nodes[x] if x in self.amr.nodes else 'None'
                for x in reversed(self.latent)) + ']\n'
        return s

    def connectGraph(self):
        assigned_root = None
        root_edges = []
        if -1 in self.amr.nodes:
            del self.amr.nodes[-1]
        for s, r, t in self.amr.edges:
            if s == -1 and r == "root":
                assigned_root = t
            if s == -1 or t == -1:
                root_edges.append((s, r, t))
        for e in root_edges:
            self.amr.edges.remove(e)

        if not self.amr.nodes:
            return

        descendents = {n: {n} for n in self.amr.nodes}
        potential_roots = [n for n in self.amr.nodes]
        for x, r, y in self.amr.edges:
            if y in potential_roots and x not in descendents[y]:
                potential_roots.remove(y)
            descendents[x].update(descendents[y])
            for n in descendents:
                if x in descendents[n]:
                    descendents[n].update(descendents[x])

        disconnected = potential_roots.copy()
        for n in potential_roots.copy():
            if len([e for e in self.amr.edges if e[0] == n]) == 0:
                potential_roots.remove(n)

        # assign root
        if potential_roots:
            self.amr.root = potential_roots[0]
            for n in potential_roots:
                if self.amr.nodes[n] == 'multi-sentence' or n == assigned_root:
                    self.amr.root = n
            disconnected.remove(self.amr.root)
        else:
            self.amr.root = max(
                self.amr.nodes.keys(),
                key=lambda x: len([e for e in self.amr.edges if e[0] == x]) -
                len([e for e in self.amr.edges if e[2] == x]))
        # connect graph
        if len(disconnected) > 0:
            for n in disconnected:
                self.amr.edges.append((self.amr.root, default_rel, n))

    def postprocessing_training(self, gold_amr):

        for entity_id in self.entities:

            entity_edges = [
                e for e in self.amr.edges
                if e[0] == entity_id and e[1] == 'entity'
            ]

            for e in entity_edges:
                self.amr.edges.remove(e)

            child_id = [t for s, r, t in entity_edges][0]
            del self.amr.nodes[child_id]

            new_node_ids = []

            entity_alignment = gold_amr.alignmentsToken2Node(entity_id)
            gold_entity_subgraph = gold_amr.findSubGraph(entity_alignment)

            for i, n in enumerate(entity_alignment):
                if i == 0:
                    self.amr.nodes[entity_id] = gold_amr.nodes[n]
                    new_node_ids.append(entity_id)
                else:
                    self.amr.nodes[self.new_id] = gold_amr.nodes[n]
                    new_node_ids.append(self.new_id)
                    self.new_id += 1

            for s, r, t in gold_entity_subgraph.edges:
                new_s = new_node_ids[entity_alignment.index(s)]
                new_t = new_node_ids[entity_alignment.index(t)]
                self.amr.edges.append((new_s, r, new_t))

    def postprocessing(self, gold_amr):
        global entity_rules_json, entity_rule_stats, entity_rule_totals, entity_rule_fails

        if not entity_rules_json:
            with open('entity_rules.json', 'r', encoding='utf8') as f:
                entity_rules_json = json.load(f)

        for entity_id in self.entities:

            if entity_id not in self.amr.nodes:
                continue
            # Test postprocessing ----------------------------
            gold_concepts = []
            if gold_amr:
                entity_alignment = gold_amr.alignmentsToken2Node(entity_id)
                gold_entity_subgraph = gold_amr.findSubGraph(entity_alignment)
                for n in gold_entity_subgraph.nodes:
                    node = gold_entity_subgraph.nodes[n]
                    if n == gold_entity_subgraph.root:
                        gold_concepts.append(node)
                    for s, r, t in gold_entity_subgraph.edges:
                        if t == n:
                            edge = r
                            gold_concepts.append(edge + ' ' + node)
            # -------------------------------------------

            new_concepts = []

            entity_type = self.amr.nodes[entity_id]
            if entity_type.startswith('('):
                entity_type = entity_type[1:-1]
            entity_edges = [
                e for e in self.amr.edges
                if e[0] == entity_id and e[1] == 'entity'
            ]
            if not entity_edges:
                continue

            child_id = [t for s, r, t in entity_edges][0]
            entity_tokens = self.amr.nodes[child_id].split(',')

            for e in entity_edges:
                self.amr.edges.remove(e)
            del self.amr.nodes[child_id]

            # date-entity special rules
            if entity_type == 'date-entity':
                date_entity_rules = entity_rules_json['date-entity']
                assigned_edges = ['' for _ in entity_tokens]
                if len(entity_tokens) == 1:
                    date = entity_tokens[0]
                    if date.isdigit() and len(date) == 8:
                        # format yyyymmdd
                        entity_tokens = [date[:4], date[4:6], date[6:]]
                        assigned_edges = [':year', ':month', ':day']
                    elif date.isdigit() and len(date) == 6:
                        # format yymmdd
                        entity_tokens = [date[:2], date[2:4], date[4:]]
                        assigned_edges = [':year', ':month', ':day']
                    elif '/' in date and date.replace('/', '').isdigit():
                        # format mm/dd/yyyy
                        entity_tokens = date.split('/')
                        assigned_edges = ['' for _ in entity_tokens]
                    elif '-' in date and date.replace('-', '').isdigit():
                        # format mm-dd-yyyy
                        entity_tokens = date.split('-')
                        assigned_edges = ['' for _ in entity_tokens]
                    elif date.lower() == 'tonight':
                        entity_tokens = ['night', 'today']
                        assigned_edges = [':dayperiod', ':mod']
                    elif date[0].isdigit() and (date.endswith('BC')
                                                or date.endswith('AD')
                                                or date.endswith('BCE')
                                                or date.endswith('CE')):
                        # 10,000BC
                        idx = 0
                        for i in range(len(date)):
                            if date[i].isalpha():
                                idx = i
                        entity_tokens = [date[:idx], date[idx:]]
                        assigned_edges = [':year', ':era']
                for j, tok in enumerate(entity_tokens):
                    if assigned_edges[j]:
                        continue
                    if tok.lower() in date_entity_rules[':weekday']:
                        assigned_edges[j] = ':weekday'
                        continue
                    if tok in date_entity_rules[':timezone']:
                        assigned_edges[j] = ':timezone'
                        continue
                    if tok.lower() in date_entity_rules[':calendar']:
                        assigned_edges[j] = ':calendar'
                        if tok.lower() == 'lunar':
                            entity_tokens[j] = 'moon'
                        continue
                    if tok.lower() in date_entity_rules[':dayperiod']:
                        assigned_edges[j] = ':dayperiod'
                        for idx, tok in enumerate(entity_tokens):
                            if tok.lower() == 'this':
                                entity_tokens[idx] = 'today'
                            elif tok.lower() == 'last':
                                entity_tokens[idx] = 'yesterday'
                        idx = j - 1
                        if idx >= 0 and entity_tokens[idx].lower() == 'one':
                            assigned_edges[idx] = ':quant'
                        continue
                    if tok in date_entity_rules[':era'] or tok.lower() in date_entity_rules[':era'] \
                            or ('"' in tok and tok.replace('"', '') in date_entity_rules[':era']):
                        assigned_edges[j] = ':era'
                        continue
                    if tok.lower() in date_entity_rules[':season']:
                        assigned_edges[j] = ':season'
                        continue

                    months = entity_rules_json['normalize']['months']
                    if tok.lower() in months or len(
                            tok.lower()) == 4 and tok.lower().endswith(
                                '.') and tok.lower()[:3] in months:
                        if ':month' in assigned_edges:
                            idx = assigned_edges.index(':month')
                            if entity_tokens[idx].isdigit():
                                assigned_edges[idx] = ':day'
                        assigned_edges[j] = ':month'
                        continue
                    ntok = self.normalize_token(tok)
                    if ntok.isdigit():
                        if j + 1 < len(entity_tokens) and entity_tokens[
                                j + 1].lower() == 'century':
                            assigned_edges[j] = ':century'
                            continue
                        if 1 <= int(
                                ntok) <= 12 and ':month' not in assigned_edges:
                            if not (tok.endswith('th') or tok.endswith('st') or
                                    tok.endswith('nd') or tok.endswith('rd')):
                                assigned_edges[j] = ':month'
                                continue
                        if 1 <= int(
                                ntok) <= 31 and ':day' not in assigned_edges:
                            assigned_edges[j] = ':day'
                            continue
                        if 1 <= int(
                                ntok
                        ) <= 10001 and ':year' not in assigned_edges:
                            assigned_edges[j] = ':year'
                            continue
                    if tok.startswith("'") and len(
                            tok) == 3 and tok[1:].isdigit():
                        # 'yy
                        assigned_edges[j] = ':year'
                        entity_tokens[j] = tok[1:]
                        continue
                    decades = entity_rules_json['normalize']['decades']
                    if tok.lower() in decades:
                        assigned_edges[j] = ':decade'
                        entity_tokens[j] = str(decades[tok.lower()])
                        continue
                    if tok.endswith(
                            's') and len(tok) > 2 and tok[:2].isdigit():
                        assigned_edges[j] = ':decade'
                        entity_tokens[j] = tok[:-1]
                        continue
                    assigned_edges[j] = ':mod'

                self.amr.nodes[entity_id] = 'date-entity'
                new_concepts.append('date-entity')
                for tok, rel in zip(entity_tokens, assigned_edges):
                    if tok.lower() in [
                            '-comma-', 'of', 'the', 'in', 'at', 'on',
                            'century', '-', '/', '', '(', ')', '"'
                    ]:
                        continue
                    tok = tok.replace('"', '')
                    if rel in [':year', ':decade']:
                        year = tok
                        if len(year) == 2:
                            tok = '20' + year if (
                                0 <= int(year) <= 30) else '19' + year
                    if rel in [':month', ':day'
                               ] and tok.isdigit() and int(tok) == 0:
                        continue
                    if tok.isdigit():
                        while tok.startswith('0') and len(tok) > 1:
                            tok = tok[1:]
                    if rel in [
                            ':day', ':month', ':year', ':era', ':calendar',
                            ':century', ':quant', ':timezone'
                    ]:
                        self.amr.nodes[self.new_id] = self.normalize_token(tok)
                    else:
                        self.amr.nodes[self.new_id] = tok.lower()
                    self.amr.edges.append((entity_id, rel, self.new_id))
                    new_concepts.append(rel + ' ' +
                                        self.amr.nodes[self.new_id])
                    self.new_id += 1
                if gold_amr and set(gold_concepts) == set(new_concepts):
                    entity_rule_stats['date-entity'] += 1
                entity_rule_totals['date-entity'] += 1
                continue

            rule = entity_type + '\t' + ','.join(entity_tokens).lower()
            # check if singular is in fixed rules
            if rule not in entity_rules_json['fixed'] and len(
                    entity_tokens) == 1 and entity_tokens[0].endswith('s'):
                rule = entity_type + '\t' + entity_tokens[0][:-1]

            # fixed rules
            if rule in entity_rules_json['fixed']:
                edges = entity_rules_json['fixed'][rule]['edges']
                nodes = entity_rules_json['fixed'][rule]['nodes']
                root = entity_rules_json['fixed'][rule]['root']
                id_map = {}
                for j, n in enumerate(nodes):
                    node_label = nodes[n]
                    n = int(n)

                    id_map[n] = entity_id if n == root else self.new_id
                    self.new_id += 1
                    self.amr.nodes[id_map[n]] = node_label
                    new_concepts.append(node_label)
                for s, r, t in edges:
                    self.amr.edges.append((id_map[s], r, id_map[t]))
                    concept = self.amr.nodes[id_map[t]]
                    if concept in new_concepts:
                        idx = new_concepts.index(concept)
                        new_concepts[idx] = r + ' ' + new_concepts[idx]
                if gold_amr and set(gold_concepts) == set(new_concepts):
                    entity_rule_stats['fixed'] += 1
                else:
                    entity_rule_fails[entity_type] += 1
                entity_rule_totals['fixed'] += 1
                continue

            rule = entity_type + '\t' + str(len(entity_tokens))

            # variable rules
            if rule in entity_rules_json['var']:
                edges = entity_rules_json['var'][rule]['edges']
                nodes = entity_rules_json['var'][rule]['nodes']
                root = entity_rules_json['var'][rule]['root']
                node_map = {}
                ntok = None
                for i, tok in enumerate(entity_tokens):
                    ntok = self.normalize_token(tok)
                    node_map[f'X{i}'] = ntok if not ntok.startswith(
                        '"') else tok.lower()
                id_map = {}
                for j, n in enumerate(nodes):
                    node_label = nodes[n]
                    n = int(n)

                    id_map[n] = entity_id if n == root else self.new_id
                    self.new_id += 1
                    self.amr.nodes[id_map[n]] = node_map[
                        node_label] if node_label in node_map else node_label
                    new_concepts.append(self.amr.nodes[id_map[n]])
                for s, r, t in edges:
                    node_label = self.amr.nodes[id_map[t]]
                    if 'date-entity' not in entity_type and (
                            node_label.isdigit() or node_label
                            in ['many', 'few', 'some', 'multiple', 'none']):
                        r = ':quant'
                    self.amr.edges.append((id_map[s], r, id_map[t]))
                    concept = self.amr.nodes[id_map[t]]
                    if concept in new_concepts:
                        idx = new_concepts.index(concept)
                        new_concepts[idx] = r + ' ' + new_concepts[idx]
                if gold_amr and set(gold_concepts) == set(new_concepts):
                    entity_rule_stats['var'] += 1
                else:
                    entity_rule_fails[entity_type] += 1
                entity_rule_totals['var'] += 1
                continue

            rule = entity_type

            # named entities rules
            if entity_type.endswith(',name') or entity_type == 'name':
                name_id = None
                if rule in entity_rules_json['names']:
                    edges = entity_rules_json['names'][rule]['edges']
                    nodes = entity_rules_json['names'][rule]['nodes']
                    root = entity_rules_json['names'][rule]['root']
                    id_map = {}
                    for j, n in enumerate(nodes):
                        node_label = nodes[n]
                        n = int(n)

                        id_map[n] = entity_id if n == root else self.new_id
                        if node_label == 'name':
                            name_id = id_map[n]
                        self.new_id += 1
                        self.amr.nodes[id_map[n]] = node_label
                        new_concepts.append(node_label)
                    for s, r, t in edges:
                        self.amr.edges.append((id_map[s], r, id_map[t]))
                        concept = self.amr.nodes[id_map[t]]
                        if concept in new_concepts:
                            idx = new_concepts.index(concept)
                            new_concepts[idx] = r + ' ' + new_concepts[idx]
                else:
                    nodes = entity_type.split(',')
                    nodes.remove('name')
                    name_id = entity_id if len(nodes) == 0 else self.new_id
                    self.amr.nodes[name_id] = 'name'
                    self.new_id += 1
                    if len(nodes) == 0:
                        new_concepts.append('name')
                    for j, node in enumerate(nodes):
                        new_id = entity_id if j == 0 else self.new_id
                        self.amr.nodes[new_id] = node
                        if j == 0:
                            new_concepts.append(node)
                        self.new_id += 1
                        if j == len(nodes) - 1:
                            rel = ':name'
                            self.amr.edges.append((new_id, rel, name_id))
                            new_concepts.append(':name ' + 'name')
                        else:
                            rel = default_rel
                            self.amr.edges.append((new_id, rel, self.new_id))
                            new_concepts.append(default_rel + ' ' +
                                                self.amr.nodes[new_id])

                op_idx = 1
                for tok in entity_tokens:
                    tok = tok.replace('"', '')
                    if tok in ['(', ')', '']:
                        continue
                    new_tok = '"' + tok[0].upper() + tok[1:] + '"'
                    self.amr.nodes[self.new_id] = new_tok
                    rel = f':op{op_idx}'
                    self.amr.edges.append((name_id, rel, self.new_id))
                    new_concepts.append(rel + ' ' + new_tok)
                    self.new_id += 1
                    op_idx += 1
                if gold_amr and set(gold_concepts) == set(new_concepts):
                    entity_rule_stats['names'] += 1
                entity_rule_totals['names'] += 1
                continue

            # unknown entity types
            nodes = entity_type.split(',')
            idx = 0
            prev_id = None
            for node in nodes:
                if node in ['(', ')', '"', '']:
                    continue
                new_id = entity_id if idx == 0 else self.new_id
                self.amr.nodes[new_id] = node
                self.new_id += 1
                if idx > 0:
                    self.amr.edges.append((prev_id, default_rel, new_id))
                    new_concepts.append(default_rel + ' ' + node)
                else:
                    new_concepts.append(node)
                prev_id = new_id
            for tok in entity_tokens:
                tok = tok.replace('"', '')
                if tok in ['(', ')', '']:
                    continue
                self.amr.nodes[self.new_id] = tok.lower()
                self.amr.edges.append((prev_id, default_rel, self.new_id))
                new_concepts.append(default_rel + ' ' + tok.lower())
                self.new_id += 1
            if gold_amr and set(gold_concepts) == set(new_concepts):
                entity_rule_stats['unknown'] += 1
            else:
                entity_rule_fails[entity_type] += 1
            entity_rule_totals['unknown'] += 1

    def normalize_token(self, string):
        global entity_rules_json

        if not entity_rules_json:
            with open('entity_rules.json', 'r', encoding='utf8') as f:
                entity_rules_json = json.load(f)

        lstring = string.lower()
        months = entity_rules_json['normalize']['months']
        units = entity_rules_json['normalize']['units']
        cardinals = entity_rules_json['normalize']['cardinals']
        ordinals = entity_rules_json['normalize']['ordinals']

        # number or ordinal
        if NUM_RE.match(lstring):
            return lstring.replace(',', '').replace('st', '').replace(
                'nd', '').replace('rd', '').replace('th', '')

        # months
        if lstring in months:
            return str(months[lstring])
        if len(lstring) == 4 and lstring.endswith(
                '.') and lstring[:3] in months:
            return str(months[lstring[:3]])

        # cardinal numbers
        if lstring in cardinals:
            return str(cardinals[lstring])

        # ordinal numbers
        if lstring in ordinals:
            return str(ordinals[lstring])

        # unit abbreviations
        if lstring in units:
            return str(units[lstring])
        if lstring.endswith('s') and lstring[:-1] in units:
            return str(units[lstring[:-1]])
        if lstring in units.values():
            return lstring
        if string.endswith('s') and lstring[:-1] in units.values():
            return lstring[:-1]

        return '"' + string + '"'
Example #17
def main(arguments):
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_file', help="Path of the file containing AMRs of each sentence", type=str,
       default='/home/prerna/Documents/thesis_work/LDC2015E86_DEFT_Phase_2_AMR_Annotation_R1/' + \
       'data/amrs/split/test/deft-p2-amr-r1-amrs-test-alignments-proxy.txt')
    parser.add_argument('--dataset',
                        help="Name of dataset",
                        type=str,
                        default='')
    parser.add_argument(
        '--display',
        help="Path of the file containing AMRs of each sentence",
        type=bool,
        default=False)

    args = parser.parse_args(arguments)

    input_file = args.input_file
    dataset = args.dataset

    # '''
    # 'docs' is a list of 'documents'; each 'document' is a list of dictionaries. Each dictionary
    # holds the information for one sentence under keys like 'alignments', 'amr', 'tok', etc.,
    # i.e. each key maps to the relevant information such as the AMR, text, or alignment.
    # '''

    # Remove alignments from the new file
    os.system('cp ' + input_file + ' auxiliary/temp')
    with codecs.open('auxiliary/temp', 'r') as data_file:
        original_data = data_file.readlines()

    os.system('sed -i \'s/~e.[	0-9]*//g\' auxiliary/temp')
    os.system('sed -i \'s/,[	0-9]*//g\' auxiliary/temp')

    with codecs.open('auxiliary/temp', 'r') as data_file:
        data = data_file.readlines()
    for index_line, line in enumerate(data):
        if line.startswith('#'):
            data[index_line] = original_data[index_line]

    with codecs.open('auxiliary/temp', 'w') as data_file:
        for line in data:
            data_file.write(line)

    input_file = 'auxiliary/temp'

    docs, target_summaries, stories = read_data(input_file)

    os.system('rm auxiliary/temp')
    save_stories(stories, 'auxiliary/stories.txt')

    with open('auxiliary/target_summaries.txt', 'w') as f:
        for summary in target_summaries:
            f.write(tok_to_std_format_convertor(summary) + '\n')

    f = open('auxiliary/predicted_summaries.txt', 'w')
    summary_sentences_per_story = []
    # currently all the information of a node is stored as a list, changing it to a dictionary
    debug = False
    # 'document_amrs' is the list of document amrs formed after joining nodes and collapsing same entities etc.
    target_summaries_amrs = []
    predicted_summaries_amrs = []
    document_amrs = []
    selected_sents = []
    for index_doc, doc in enumerate(docs):
        current_doc_sent_amr_list = []
        current_target_summary_sent_amr_list = []
        for index_dict, dict_sentence in enumerate(doc):
            if dict_sentence['amr'] != []:
                if dict_sentence['tok'].strip()[-1] != '.':
                    dict_sentence['tok'] = dict_sentence['tok'] + ' .'
                # Get the AMR class for each sentence using just the text
                if dict_sentence['snt-type'] == 'summary':
                    current_target_summary_sent_amr_list.append(
                        AMR(dict_sentence['amr'],
                            amr_with_attributes=False,
                            text=dict_sentence['tok'],
                            alignments=dict_sentence['alignments']))
                if dict_sentence['snt-type'] == 'body':
                    docs[index_doc][index_dict]['amr'] = AMR(
                        dict_sentence['amr'],
                        amr_with_attributes=False,
                        text=dict_sentence['tok'],
                        alignments=dict_sentence['alignments'])
                    current_doc_sent_amr_list.append(
                        docs[index_doc][index_dict]['amr'])
        # merging the sentence AMRs to form a single AMR
        amr_as_list, document_text, document_alignments,var_to_sent = \
                  merge_sentence_amrs(current_doc_sent_amr_list,debug=False)
        new_document_amr = AMR(text_list=amr_as_list,
                               text=document_text,
                               alignments=document_alignments,
                               amr_with_attributes=True,
                               var_to_sent=var_to_sent)
        document_amrs.append(new_document_amr)
        target_summaries_amrs.append(current_target_summary_sent_amr_list)
        imp_doc = index_doc
        if imp_doc == 1000:
            # just the first sentence of the story is the summary
            predicted_summaries_amrs.append([current_doc_sent_amr_list[0]])

        print index_doc
        if index_doc == imp_doc:
            document_amrs[index_doc] = resolve_coref_doc_AMR(
                amr=document_amrs[index_doc],
                resolved=True,
                story=' '.join(document_amrs[index_doc].text),
                # location_of_resolved_story='auxiliary/human_corefs.txt',
                location_of_resolved_story='auxiliary/' + dataset +
                '_predicted_resolutions.txt',
                location_of_story_in_file=index_doc,
                location_of_resolver='.',
                debug=False)

            pr = document_amrs[index_doc].directed_graph.rank_sent_in_degree()
            ranks, weights = zip(*pr)
            print ranks
            print weights

            # get pairs in order of importance
            ranked_pairs = document_amrs[index_doc].directed_graph.rank_pairs(
                ranks=ranks, weights=weights, pairs_to_rank=3)
            # print 'ranked_pairs', ranked_pairs
            paths_and_sub_graphs = document_amrs[
                index_doc].directed_graph.max_imp_path(
                    ordered_pairs=ranked_pairs)

            # add method to check no repeated sub_graph
            summary_paths = []
            summary_amrs = []
            summary_amrs_text = []
            for path_and_sub_graph in paths_and_sub_graphs:
                path, sub_graph, sent = path_and_sub_graph

                path_sent_dict = {}
                if sent == -1:
                    path_sent_dict = document_amrs[
                        index_doc].break_path_by_sentences(path=path)
                else:
                    path_sent_dict[sent] = path

                for key in path_sent_dict.keys():
                    temp_path = path_sent_dict[key]

                    # path = document_amrs[index_doc].concept_relation_list.get_concepts_given_path(sent_index=key,path=temp_path)
                    path = -1
                    # key = 0
                    if path == -1:
                        path = document_amrs[index_doc].get_sent_amr(
                            sent_index=key)

                    nodes, sub_graph = document_amrs[
                        index_doc].directed_graph.get_name_path(nodes=path)

                    new_amr_graph = document_amrs[
                        index_doc].get_AMR_from_directed_graph(
                            sub_graph=sub_graph)

                    repeated_path = False
                    # removing repeating sentences/AMRs
                    for var_set in summary_paths:
                        if set(var_set) == set(nodes): repeated_path = True

                    if repeated_path: continue

                    summary_paths.append(list(nodes))
                    summary_amrs_text.append(
                        new_amr_graph.print_amr(file=f,
                                                print_indices=False,
                                                write_in_file=True,
                                                one_line_output=True,
                                                return_str=True,
                                                to_print=False))
                    print ''
                    summary_amrs.append(new_amr_graph)

            final_summary_amrs_text = []
            final_summary_amrs = []
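            # Keep only maximal paths: any path whose node set is a strict subset of
            # another extracted path is dropped, so the printed summary is not redundant.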
            for index, path in enumerate(summary_paths):
                indices_to_search_at = range(len(summary_paths))
                indices_to_search_at.remove(index)
                to_print = True
                for index_2 in indices_to_search_at:
                    if set(path) < set(summary_paths[index_2]):
                        to_print = False
                if to_print:
                    final_summary_amrs_text.append(summary_amrs_text[index])
                    final_summary_amrs.append(summary_amrs[index])

            for summary_amr in final_summary_amrs_text:
                try:
                    summary_sentences_per_story[index_doc] += 1
                except:
                    summary_sentences_per_story.append(1)

                print summary_amr

            predicted_summaries_amrs.append(final_summary_amrs)

    # use a separate handle so that 'f' still refers to predicted_summaries.txt below
    with open('auxiliary/' + dataset + '_eos_stories.txt', 'w') as f_eos:
        for document_amr in document_amrs:
            f_eos.write(' <eos> '.join(document_amr.text) + '\n')

    f.close()
    with open('auxiliary/num_sent_per_story.txt', 'w') as f3:
        pickle.dump(summary_sentences_per_story, f3)
    # save document AMR in file
    with open('auxiliary/text_amr.txt', 'w') as f2:
        f2.write(
            '# :id PROXY_AFP_ENG_20050317_010.10 ::amr-annotator SDL-AMR-09  ::preferred ::snt-type body\n'
        )
        f2.write('# ::snt On 21 March 2005\n')
        f2.write('# ::tok On 21 March 2005\n')
        if imp_doc >= 0 and imp_doc < len(document_amrs):
            for index_node, node in enumerate(document_amrs[imp_doc].amr):
                f2.write('\t' * node['depth'] + node['text'] + '\n')

        # an option to generate the graphical representations
        # return document_amrs
    target_summaries_nodes = []
    for target_summary_amrs in target_summaries_amrs:
        current_summary_nodes = []
        for target_summary_amr in target_summary_amrs:
            current_summary_nodes.extend(target_summary_amr.get_nodes())
        target_summaries_nodes.append(current_summary_nodes)

    with open('auxiliary/target_summary_nodes.txt', 'w') as f6:
        for node_list in target_summaries_nodes:
            f6.write(' '.join([node for node in node_list]) + '\n')

    predicted_summaries_nodes = []
    for predicted_summary_amrs in predicted_summaries_amrs:
        current_summary_nodes = []
        for predicted_summary_amr in predicted_summary_amrs:
            current_summary_nodes.extend(predicted_summary_amr.get_nodes())
        predicted_summaries_nodes.append(current_summary_nodes)

    with open('auxiliary/predicted_summary_nodes.txt', 'w') as f7:
        for node_list in predicted_summaries_nodes:
            f7.write(' '.join([node for node in node_list]) + '\n')
예제 #18
0
#!/usr/bin/env python2.7
#coding=utf-8
'''

@author: Nathan Schneider ([email protected])
@since: 2015-05-06
'''
from __future__ import print_function
import sys, re, fileinput, codecs
from collections import Counter, defaultdict

from amr import AMR, AMRSyntaxError, AMRError, Concept, AMRConstant

c = Counter()
for ln in fileinput.input():
    try:
        a = AMR(ln)
        c.update(map(repr, a.nodes.keys()))    # vars, concepts, constants: count once per AMR
        c.update('.'+repr(x) for _,r,x in a.triples(rel=':instance-of'))  # concepts count once per variable
        c.update(map((lambda x: x[1]), a.triples()))    # relations
        c.update('.'+repr(x) for _,_,x in a.triples() if isinstance(x,AMRConstant))  # constants count once per relation
    except AMRSyntaxError as ex:
        print(ex, file=sys.stderr)
    except AMRError as ex:
        print(ex, file=sys.stderr)
    
for k,n in c.most_common():
    print(k,n, sep='\t')
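# Usage sketch (the script name is hypothetical; the input format follows from the
# AMR(ln) call above, i.e. one AMR per line):
#   python2.7 count_amr_elements.py amrs_one_per_line.txt > counts.tsv
# The output is a tab-separated "item<TAB>count" listing, most frequent first.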
예제 #19
0
            args.p_ctx), str(args.p_proj))
    logging.basicConfig(filename=logfilename,
                        level=logging.INFO,
                        format='%(asctime)s :: %(levelname)s :: %(message)s')
    logging.info('log info to ' + logfilename)

logging.info(args)
if args.dataset == 'amazon':
    ds = ds_amazon(logging, args)
else:
    raise Exception('no dataset ' + args.dataset)

if args.model == 'bpr':
    model = BPR(ds, args, logging)
elif args.model == 'cbpr':
    model = CBPR(ds, args, logging)
elif args.model == 'vbpr':
    model = VBPR(ds, args, logging)
elif args.model == 'amr':
    model = AMR(ds, args, logging)
elif args.model == 'mtpr':
    model = MTPR(ds, args, logging)
else:
    raise Exception('unknown model type', args.model)

model.train()

weight_filename = 'weights/%s_%s_%s_%s_%s.npy' % (
    args.dataset, args.model, str(args.p_emb), str(args.p_ctx), str(
        args.p_proj))
model.save(weight_filename)
예제 #20
0
def main(arguments):
	parser = argparse.ArgumentParser(
		description=__doc__,
		formatter_class=argparse.ArgumentDefaultsHelpFormatter)
	parser.add_argument('--input_file', help="Path of the file containing AMRs of each sentence", type=str, 
				default='/home/shibhansh/UGP-2/data/LDC2015E86_DEFT_Phase_2_AMR_Annotation_R1/' + \
				'data/amrs/split/test/deft-p2-amr-r1-amrs-test-alignments-proxy.txt')
	parser.add_argument('--dataset', help="Name of dataset",
				type=str, default='')
	parser.add_argument('--display', help="Whether to display the output",
				type=bool, default=False)

	args = parser.parse_args(arguments)

	input_file = args.input_file
	dataset = args.dataset

	'''
	'docs' is a list of 'documents'; each 'document' is a list of dictionaries. Each
	dictionary holds the information about one sentence, with keys such as 'alignments',
	'amr', 'tok' and 'snt-type' mapping to the corresponding data.
	'''

	# Remove alignments from a temporary copy of the input file
	os.system('cp '+ input_file +' auxiliary/temp')
	with codecs.open('auxiliary/temp', 'r') as data_file:
		original_data = data_file.readlines()

	os.system('sed -i \'s/~e.[	0-9]*//g\' auxiliary/temp')
	os.system('sed -i \'s/,[	0-9]*//g\' auxiliary/temp')

	with codecs.open('auxiliary/temp', 'r') as data_file:
		data = data_file.readlines()
	for index_line,line in enumerate(data):
		if line.startswith('#'):
			data[index_line] = original_data[index_line]

	with codecs.open('auxiliary/temp', 'w') as data_file:
		for line in data:
			data_file.write(line)

	input_file = 'auxiliary/temp'

	docs, target_summaries, stories = read_data(input_file)

	os.system('rm auxiliary/temp')
	save_stories(stories,'auxiliary/stories.txt')

	with open('auxiliary/target_summaries.txt','w') as f:
		for summary in target_summaries:
			f.write(tok_to_std_format_convertor(summary)+'\n')
	idf = {}
	with open('auxiliary/'+dataset+'_idf.txt','r') as f:
		idf = pickle.load(f) 

	f = open('auxiliary/predicted_summaries.txt','w')
	summary_sentences_per_story = []
	# currently all the information of a node is stored as a list, changing it to a dictionary
	debug = False
	# 'document_amrs' is the list of document amrs formed after joining nodes and collapsing same entities etc.
	target_summaries_amrs = []
	predicted_summaries_amrs = []
	document_amrs = []
	selected_sents = []
	for index_doc, doc in enumerate(docs):
		current_doc_sent_amr_list = []
		current_target_summary_sent_amr_list = []
		for index_dict, dict_sentence in enumerate(doc):
			if dict_sentence['amr'] != []:
				if dict_sentence['tok'].strip()[-1] != '.': dict_sentence['tok'] = dict_sentence['tok'] + ' .' 
				# Get the AMR class for each sentence using just the text
				if dict_sentence['snt-type'] == 'summary':
					current_target_summary_sent_amr_list.append(AMR(dict_sentence['amr'],
													amr_with_attributes=False,
													text=dict_sentence['tok'],
													alignments=dict_sentence['alignments']))
				if dict_sentence['snt-type'] == 'body':
					docs[index_doc][index_dict]['amr'] = AMR(dict_sentence['amr'],
														amr_with_attributes=False,
														text=dict_sentence['tok'],
														alignments=dict_sentence['alignments'])
					current_doc_sent_amr_list.append(docs[index_doc][index_dict]['amr'])
		# merging the sentence AMRs to form a single AMR
		amr_as_list, document_text, document_alignments,var_to_sent = \
												merge_sentence_amrs(current_doc_sent_amr_list,debug=False)
		new_document_amr = AMR(text_list=amr_as_list,
							text=document_text,
							alignments=document_alignments,
							amr_with_attributes=True,
							var_to_sent=var_to_sent)
		document_amrs.append(new_document_amr)
		target_summaries_amrs.append(current_target_summary_sent_amr_list)

		# number of nodes required in summary

		imp_doc = index_doc
		# imp_doc = 1000
		if imp_doc == 1000:
			# just the first sentence of the story is the summary
			predicted_summaries_amrs.append([current_doc_sent_amr_list[0]])
		if imp_doc == 2000:
			# just the first two sentences of the story are the summary
			predicted_summaries_amrs.append([current_doc_sent_amr_list[0],current_doc_sent_amr_list[1]])
		if imp_doc == 3000:
			# just the first three sentences of the story are the summary
			predicted_summaries_amrs.append([current_doc_sent_amr_list[0],current_doc_sent_amr_list[1]\
												,current_doc_sent_amr_list[2]])
		if imp_doc == -1:
			# all the sentences of the story are the summary
			predicted_summaries_amrs.append(current_doc_sent_amr_list)
		if index_doc == imp_doc:
			document_amrs[index_doc], phrases,idf_vars = resolve_coref_doc_AMR(amr=document_amrs[index_doc], 
									resolved=True,story=' '.join(document_amrs[index_doc].text),
									location_of_resolved_story='auxiliary/'+dataset+'_predicted_resolutions.txt',
									location_of_story_in_file=index_doc,
									location_of_resolver='.',
									idf=idf,
									debug=False)

			cn_freq_dict,cn_sent_lists,cn_var_lists=document_amrs[index_doc].get_common_nouns(phrases=phrases)
			idf_vars = document_amrs[index_doc].get_idf_vars(idf_vars=idf_vars,idf=idf)
		
			# range equal to the std_deviation of the summary size in the dataset
			if dataset == '':
				current_summary_nodes = []
				for target_summary_amr in current_target_summary_sent_amr_list:
					current_summary_nodes.extend(target_summary_amr.get_nodes() )

				num_summary_nodes = len(current_summary_nodes)
				range_num_nodes = 0
				range_num_nodes = int((len(document_amrs[index_doc].get_nodes())*4)/100)
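				# range_num_nodes acts as a tolerance around the target summary size:
				# here it is roughly 4% of the number of nodes in the document AMR.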

			document_amrs[index_doc].get_concept_relation_list(story_index=index_doc,debug=False)

			pr = document_amrs[index_doc].directed_graph.rank_sent_in_degree()

			# rank the nodes with the 'meta_nodes'
			pr = document_amrs[index_doc].directed_graph.rank_with_meta_nodes(var_freq_list=pr,
																			cn_freq_dict=cn_freq_dict,
																			cn_sent_lists=cn_sent_lists,
																			cn_var_dict=cn_var_lists)
			ranks, weights, _ = zip(*pr)
			print ranks
			print weights

			pr = document_amrs[index_doc].directed_graph.add_idf_ranking(var_freq_list=pr,
																		default_idf=5.477,
																		idf_vars=idf_vars,
																		num_vars_to_add=5)

			ranks, weights, _ = zip(*pr)
			print ranks
			print weights

			new_graph = document_amrs[index_doc].directed_graph.construct_greedily_first(ranks=ranks,weights=weights,
							concept_relation_list=document_amrs[index_doc].concept_relation_list,
							use_true_sent_rank=False,num_nodes=num_summary_nodes,range_num_nodes=range_num_nodes)

			# generate AMR from the graphical representation
			new_amr_graph = document_amrs[index_doc].get_AMR_from_directed_graph(sub_graph=new_graph)
			new_amr_graph.print_amr()
			predicted_summaries_amrs.append([new_amr_graph])

	# use a separate handle so that 'f' still refers to predicted_summaries.txt below
	with open('auxiliary/'+dataset+'_eos_stories.txt','w') as f_eos:
		for document_amr in document_amrs:
			f_eos.write(' <eos> '.join(document_amr.text)+'\n')

	f.close()
	with open('auxiliary/num_sent_per_story.txt','w') as f3:
		pickle.dump(summary_sentences_per_story,f3)
	# save document AMR in file
	with open('auxiliary/text_amr.txt','w') as f2:
		f2.write('# :id PROXY_AFP_ENG_20050317_010.10 ::amr-annotator SDL-AMR-09  ::preferred ::snt-type body\n')
		f2.write('# ::snt On 21 March 2005\n')
		f2.write('# ::tok On 21 March 2005\n')
		if imp_doc >= 0 and imp_doc < len(document_amrs):
			for index_node, node in enumerate(document_amrs[imp_doc].amr):
				f2.write('\t'*node['depth']+node['text']+'\n')

	target_summaries_nodes = []
	for target_summary_amrs in target_summaries_amrs:
		current_summary_nodes = []
		for target_summary_amr in target_summary_amrs:
			# current_summary_nodes.extend(target_summary_amr.get_edge_tuples() )
			current_summary_nodes.extend(target_summary_amr.get_nodes() )
		target_summaries_nodes.append(current_summary_nodes)

	target_summary_lengths = [len(i) for i in target_summaries_nodes]
	document_lengths = [len(i.get_nodes()) for i in document_amrs]

	ratios = []
	for i in range(len(document_lengths)):
		ratios.append(float(target_summary_lengths[i]) / document_lengths[i] * 100)

	average_ratio = (float(sum(ratios)) / len(ratios))
	deviations = [abs(ratio - average_ratio) for ratio in ratios]

	mean_deviation = (float(sum(deviations)) / len(deviations))

	# average ratio in 'gold' dataset is 9%, and deviation is 4%
	print 'average_ratio', average_ratio, 'mean_deviation', mean_deviation

	with open('auxiliary/target_summary_nodes.txt','w') as f6:
		for node_list in target_summaries_nodes:
			f6.write(' '.join([node for node in node_list]) + '\n')

	predicted_summaries_nodes = []
	for predicted_summary_amrs in predicted_summaries_amrs:
		current_summary_nodes = []
		for predicted_summary_amr in predicted_summary_amrs:
			# current_summary_nodes.extend(predicted_summary_amr.get_edge_tuples() )
			current_summary_nodes.extend(predicted_summary_amr.get_nodes() )
		predicted_summaries_nodes.append(current_summary_nodes)

	with open('auxiliary/predicted_summary_nodes.txt','w') as f7:
		for node_list in predicted_summaries_nodes:
			f7.write(' '.join([node for node in node_list]) + '\n')
예제 #21
0
#!/usr/bin/env python2.7
#coding=utf-8
'''

@author: Nathan Schneider ([email protected])
@since: 2015-05-06
'''
from __future__ import print_function
import sys, re, fileinput, codecs
from collections import Counter, defaultdict

from amr import AMR, AMRSyntaxError, AMRError, Concept, AMRConstant

c = defaultdict(Counter)
for ln in fileinput.input():
    try:
        a = AMR(ln)
        for h, r, d in a.role_triples(normalize_inverses=True,
                                      normalize_mod=False):
            if a._v2c[h].is_frame():
                c[str(a._v2c[h])][r] += 1
    except AMRSyntaxError as ex:
        print(ex, file=sys.stderr)
    except AMRError as ex:
        print(ex, file=sys.stderr)

for f, roles in sorted(c.items()):
    print(f,
          '\t'.join(' '.join([r, str(n)]) for r, n in sorted(roles.items())),
          sep='\t')
예제 #22
0
def main(args):

    # First, let's read the graphs and surface forms
    with open(args.input_amr) as f:
        amrs = f.readlines()
    with open(args.input_surface) as f:
        surfs = f.readlines()

    if args.triples_output is not None:
        triples_out = open(args.triples_output, 'w')

    # Iterate
    anon_surfs = []
    anon_maps = []
    anon_surfs_scope = []
    i = 0
    with open(args.output, 'w') as out, open(args.output_surface,
                                             'w') as surf_out:
        for amr, surf in zip(amrs, surfs):
            graph = AMR(amr, surf.split())

            # Get variable: concept map for reentrancies
            v2c = graph.var2concept()  # needed below by the LIN and GRAPH modes

            if args.mode == 'LIN':
                # Linearisation mode for seq2seq

                tokens = amr.split()
                new_tokens = simplify(tokens, v2c)
                out.write(' '.join(new_tokens) + '\n')

            elif args.mode == 'GRAPH':
                # Triples mode for graph2seq
                #import ipdb; ipdb.set_trace()
                # Get concepts and generate IDs
                v_ids, rev_v_ids = get_nodes2(graph)

                # Triples
                triples = get_triples(graph, v_ids, rev_v_ids)

                # Print concepts/constants and triples
                #cs = [get_name(c) for c in rev_c_ids]
                cs = [get_name(v, v2c) for v in rev_v_ids]
                out.write(' '.join(cs) + '\n')
                triples_out.write(
                    ' '.join(['(' + ','.join(adj) + ')'
                              for adj in triples]) + '\n')

            elif args.mode == 'LINE_GRAPH':
                # Similar to GRAPH, but with edges as extra nodes
                #import ipdb; ipdb.set_trace()
                print(i)
                i += 1
                #if i == 98:
                #    import ipdb; ipdb.set_trace()
                nodes, triples, anon_surf, anon_map, anon_surf_scope = get_line_graph(
                    graph, surf, anon=args.anon)
                out.write(' '.join(nodes) + '\n')
                triples_out.write(
                    ' '.join(['(%d,%d,%s)' % adj for adj in triples]) + '\n')
                #surf = ' '.join(new_surf)
                anon_surfs.append(anon_surf)
                anon_maps.append(json.dumps(anon_map))
                anon_surfs_scope.append(anon_surf_scope)

            # Process the surface form
            surf_out.write(surf.lower())
    if args.anon:
        with open(args.anon_surface, 'w') as f:
            for anon_surf in anon_surfs:
                f.write(anon_surf + '\n')
        with open(args.map_output, 'w') as f:
            for anon_map in anon_maps:
                f.write(anon_map + '\n')
        with open(args.anon_surface_scope, 'w') as f:
            for anon_surf_scope in anon_surfs_scope:
                f.write(anon_surf_scope + '\n')
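# Hedged usage note: the argparse setup is not shown in this snippet, but main() reads
# the following attributes from `args`, so the expected inputs are roughly:
#   args.input_amr / args.input_surface : files with one AMR / one sentence per line
#   args.mode : 'LIN', 'GRAPH' or 'LINE_GRAPH'
#   args.output, args.output_surface, args.triples_output : output paths
#   args.anon with args.anon_surface, args.map_output, args.anon_surface_scope : anonymisation outputs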
예제 #23
0
    parser.add_argument('-s', '--significant', type=int, default=2, help='significant digits to output (default: 2)')
    parser.add_argument('--ms', action='store_true', default=False,
                        help='Output multiple scores (one score per AMR pair) '
                             'instead of a single document-level smatch score (default: false)')
    args = vars(parser.parse_args())
    if args['v']:
        log.getLogger().setLevel(level=log.INFO)

    if args['vv']:
        log.getLogger().setLevel(level=log.DEBUG)

    file1, file2 = args['amrfile']
    float_fmt = '%%.%df' % args['significant']
    # Note: instead of computing an average, we sum matches over all AMRs
    total_match, file1_count, file2_count = 0, 0, 0
    for amr1, amr2 in AMR.read_amrs(file1, file2):
        smatch = SmatchILP(amr1, amr2)
        score, match_count = smatch.solve()
        total_match += match_count
        file1_count += smatch.arg1size
        file2_count += smatch.arg2size
        if args['ms']:
            out = float_fmt % score
            print('F-score: %s' % out)
    if total_match > 0:
        prec = total_match / file1_count
        recall = total_match / file2_count
        smatch_score = SmatchILP.f_mneasure(prec, recall)
        out = float_fmt % smatch_score
        print('\nAggregated F-score: %s' % out)
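        # This is a micro-averaged, document-level score: precision = total matched
        # triples / triples in file1, recall = total matched / triples in file2, and
        # the F-score combines the two (harmonic mean in standard smatch).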
    else:
예제 #24
0
def compute_subscores(pred, gold):
    inters = defaultdict(int)
    golds = defaultdict(int)
    preds = defaultdict(int)
    # Loop through all entries
    for amr_pred, amr_gold in zip(pred, gold):
        # Create the predicted data
        amr_pred = AMR.parse_AMR_line(amr_pred.replace("\n", ""))
        if amr_pred is None:
            logger.error('Empty amr_pred entry')
            continue
        dict_pred = var2concept(amr_pred)
        triples_pred = [t for t in amr_pred.get_triples()[1]]
        triples_pred.extend([t for t in amr_pred.get_triples()[2]])
        # Create the gold data
        amr_gold = AMR.parse_AMR_line(amr_gold.replace("\n", ""))
        if amr_gold is None:
            logger.error('Empty amr_gold entry')
            continue
        dict_gold = var2concept(amr_gold)
        triples_gold = [t for t in amr_gold.get_triples()[1]]
        triples_gold.extend([t for t in amr_gold.get_triples()[2]])
        # Non_sense_frames scores
        list_pred = non_sense_frames(dict_pred)
        list_gold = non_sense_frames(dict_gold)
        inters["Non_sense_frames"] += len(
            list(set(list_pred) & set(list_gold)))
        preds["Non_sense_frames"] += len(set(list_pred))
        golds["Non_sense_frames"] += len(set(list_gold))
        # Wikification scores
        list_pred = wikification(triples_pred)
        list_gold = wikification(triples_gold)
        inters["Wikification"] += len(list(set(list_pred) & set(list_gold)))
        preds["Wikification"] += len(set(list_pred))
        golds["Wikification"] += len(set(list_gold))
        # Named entity scores
        list_pred = namedent(dict_pred, triples_pred)
        list_gold = namedent(dict_gold, triples_gold)
        inters["Named Ent."] += len(list(set(list_pred) & set(list_gold)))
        preds["Named Ent."] += len(set(list_pred))
        golds["Named Ent."] += len(set(list_gold))
        # Negation scores
        list_pred = negations(dict_pred, triples_pred)
        list_gold = negations(dict_gold, triples_gold)
        inters["Negations"] += len(list(set(list_pred) & set(list_gold)))
        preds["Negations"] += len(set(list_pred))
        golds["Negations"] += len(set(list_gold))
        # Ignore Vars scores
        list_pred = everything(dict_pred, triples_pred)
        list_gold = everything(dict_gold, triples_gold)
        inters["IgnoreVars"] += len(list(set(list_pred) & set(list_gold)))
        preds["IgnoreVars"] += len(set(list_pred))
        golds["IgnoreVars"] += len(set(list_gold))
        # Concepts scores
        list_pred = concepts(dict_pred)
        list_gold = concepts(dict_gold)
        inters["Concepts"] += len(list(set(list_pred) & set(list_gold)))
        preds["Concepts"] += len(set(list_pred))
        golds["Concepts"] += len(set(list_gold))
        # Frames scores
        list_pred = frames(dict_pred)
        list_gold = frames(dict_gold)
        inters["Frames"] += len(list(set(list_pred) & set(list_gold)))
        preds["Frames"] += len(set(list_pred))
        golds["Frames"] += len(set(list_gold))
    # Create the return dictionary
    rdict = OrderedDict()
    for score in preds:
        pr = 0 if preds[score] <= 0 else inters[score] / float(preds[score])
        rc = 0 if golds[score] <= 0 else inters[score] / float(golds[score])
        f = 0 if pr + rc <= 0 else 2 * (pr * rc) / (pr + rc)
        rdict[score] = (pr, rc, f)
    return rdict
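# Usage sketch (assuming `pred` and `gold` are equal-length lists of single-line AMR
# strings and the helper functions used above are in scope):
#   scores = compute_subscores(pred, gold)
#   for name, (p, r, f) in scores.items():
#       print('%s\tP=%.3f R=%.3f F=%.3f' % (name, p, r, f))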
예제 #25
0
#!/usr/bin/env python2.7
#coding=utf-8
'''

@author: Nathan Schneider ([email protected])
@since: 2015-05-06
'''
from __future__ import print_function
import sys, re, fileinput, codecs
from collections import Counter, defaultdict

from amr import AMR, AMRSyntaxError, AMRError, Concept, AMRConstant

c = defaultdict(Counter)
for ln in fileinput.input():
    try:
        a = AMR(ln)
        for h,r,d in a.role_triples(normalize_inverses=True, normalize_mod=False):
            if a._v2c[h].is_frame():
                c[str(a._v2c[h])][r] += 1
    except AMRSyntaxError as ex:
        print(ex, file=sys.stderr)
    except AMRError as ex:
        print(ex, file=sys.stderr)
    
for f,roles in sorted(c.items()):
    print(f,'\t'.join(' '.join([r,str(n)]) for r,n in sorted(roles.items())), sep='\t')