Example #1
 def predict_and_output(cls, parser, options, sentences, output_file):
     total_result = ScorerResult.zero()
     with open(output_file + ".txt", "w") as f, \
             open(output_file + ".graph", "w") as f_graph, \
             open(output_file + ".gold_graph", "w") as f_gold_graph:
         for sent_id, graph in parser.predict(sentences):
             with gzip.open(options.deepbank_dir + "/" + sent_id + ".gz",
                            "rb") as f_gz:
                 contents = f_gz.read().decode("utf-8")
             fields = contents.strip().split("\n\n")
             if options.graph_type == "eds":
                 eds_literal = fields[-2]
                 eds_literal = re.sub(r"\{.*\}", "", eds_literal)
                 e = eds.loads_one(eds_literal)
                 gold_graph = EDSScorer.from_eds(e, sent_id)
                 f_gold_graph.write(eds_for_smatch(sent_id, e))
             else:
                 assert options.graph_type == "dmrs"
                 m = simplemrs.loads_one(fields[-3])
                 gold_graph = EDSScorer.from_mrs(m)
                 f_gold_graph.write(mrs_for_smatch(sent_id, m))
             result = EDSScorer.from_hypergraph(graph).compare_with(
                 gold_graph,
                 True, log_func=lambda x: print(x, file=f))
             f.write(str(result))
             f_graph.write(output_hg(sent_id, graph))
             total_result += result
         print("Total:")
         print(total_result)
         f.write(str(total_result))
     current_path = os.path.dirname(__file__)
     os.system('{}/../utils/smatch_1 {} {} > {}.smatch'.format(
         current_path, output_file + ".graph", output_file + ".gold_graph", output_file))
     os.system('cat {}.smatch'.format(output_file))
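The two `os.system` calls at the end shell out to a bundled smatch script; below is a minimal sketch of the same step using `subprocess`, assuming the same `output_file` naming convention and the `utils/smatch_1` path used above (the helper name is hypothetical, not part of the original code):

import subprocess
from pathlib import Path

def run_smatch(output_file):
    # Hypothetical helper mirroring the os.system calls above: run the
    # bundled smatch script on the predicted and gold graph files,
    # capture its report in <output_file>.smatch, then echo it.
    script = Path(__file__).resolve().parent.parent / "utils" / "smatch_1"
    with open(output_file + ".smatch", "w") as f_out:
        subprocess.run(
            [str(script), output_file + ".graph", output_file + ".gold_graph"],
            stdout=f_out, check=False)
    print(Path(output_file + ".smatch").read_text(), end="")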
Example #2
 def eds(self):
     """
     Deserialize and return an Eds object for native- or
     JSON-formatted EDS data; otherwise return the original string.
     """
     _eds = self.get('eds')
     if _eds is not None:
         if isinstance(_eds, dict):
             _eds = eds.Eds.from_dict(_eds)
         elif isinstance(_eds, stringtypes):
             _eds = eds.loads_one(_eds)
     return _eds
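A stand-alone sketch of the same dispatch, assuming the pydelphin 0.x layout in which `delphin.mrs.eds` provides `loads_one` and `Eds.from_dict` (the import path is an assumption, and `stringtypes` in the listing is presumably a `str` alias kept for Python 2/3 compatibility):

from delphin.mrs import eds  # assumed pydelphin 0.x module path

def to_eds(value):
    # Mirror the accessor above: dicts and native EDS strings are
    # deserialized; anything else (including an already-built Eds object)
    # is returned unchanged.
    if isinstance(value, dict):
        return eds.Eds.from_dict(value)
    if isinstance(value, str):
        return eds.loads_one(value)
    return value

e = to_eds('{e2: e2:_rain_v_1<3:9>[]}')
print(e.top)  # 'e2'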
Example #3
def test_deserialize():
    e = eds.loads_one('{}')
    assert e.top is None
    assert len(e.nodes()) == 0

    e = eds.loads_one('{:}')
    assert e.top is None
    assert len(e.nodes()) == 0

    e = eds.loads_one('{e2: e2:_rain_v_1<3:9>[]}')
    assert e.top == 'e2'
    assert len(e.nodes()) == 1
    assert e.nodes()[0].pred == '_rain_v_1_rel'

    e = eds.loads_one('{: e2:_rain_v_1<3:9>[]}')
    assert e.top is None
    assert len(e.nodes()) == 1
    assert e.nodes()[0].pred == '_rain_v_1_rel'

    e = eds.loads_one('{e2:\n'
                      ' e2:_rain_v_1<3:9>{e SF prop, TENSE pres}[]\n'
                      '}')
    assert e.top == 'e2'
    assert len(e.nodes()) == 1
    assert e.nodes()[0].properties == {'SF': 'prop', 'TENSE': 'pres'}

    e = eds.loads_one('{e2: (fragmented)\n'
                      '|e5:_nearly_x_deg<0:6>[]\n'
                      ' _1:_every_q<7:12>[BV x3]\n'
                      ' x3:_dog_n_1<13:16>[]\n'
                      ' e2:_bark_v_1<17:24>[ARG1 x3]\n'
                      '}')
    assert e.top == 'e2'
    assert len(e.nodes()) == 4
Example #4
def test_deserialize():
    e = eds.loads_one('{}')
    assert e.top is None
    assert len(e.nodes()) == 0

    e = eds.loads_one('{:}')
    assert e.top is None
    assert len(e.nodes()) == 0

    e = eds.loads_one('{e2: e2:_rain_v_1<3:9>[]}')
    assert e.top == 'e2'
    assert len(e.nodes()) == 1
    assert e.nodes()[0].pred == '_rain_v_1_rel'

    e = eds.loads_one('{: e2:_rain_v_1<3:9>[]}')
    assert e.top is None
    assert len(e.nodes()) == 1
    assert e.nodes()[0].pred == '_rain_v_1_rel'

    e = eds.loads_one(
        '{e2:\n'
        ' e2:_rain_v_1<3:9>{e SF prop, TENSE pres}[]\n'
        '}'
    )
    assert e.top == 'e2'
    assert len(e.nodes()) == 1
    assert e.nodes()[0].properties == {'SF': 'prop', 'TENSE': 'pres'}

    e = eds.loads_one(
        '{e2: (fragmented)\n'
        '|e5:_nearly_x_deg<0:6>[]\n'
        ' _1:_every_q<7:12>[BV x3]\n'
        ' x3:_dog_n_1<13:16>[]\n'
        ' e2:_bark_v_1<17:24>[ARG1 x3]\n'
        '}'
    )
    assert e.top == 'e2'
    assert len(e.nodes()) == 4

    # GitHub issue #203
    # _thing_n_of-about was tripping up the parser due to the hyphen,
    # and the empty property list of _business_n_1 does not have a space
    # before } (without the space is better, I think)
    e = eds.loads_one(
        '{e3:\n'
        ' _1:udef_q<0:35>[BV x6]\n'
        ' e9:_successful_a_1<0:10>{e SF prop, TENSE untensed, MOOD indicative, PROG -, PERF -}[ARG1 x6]\n'
        ' e10:_american_a_1<11:19>{e SF prop, TENSE untensed, MOOD indicative, PROG -, PERF -}[ARG1 x6]\n'
        ' e12:compound<20:35>{e SF prop, TENSE untensed, MOOD indicative, PROG -, PERF -}[ARG1 x6, ARG2 x11]\n'
        ' _2:udef_q<20:28>[BV x11]\n'
        ' x11:_business_n_1<20:28>{x}[]\n'
        ' x6:_owner_n_of<29:35>{x PERS 3, NUM pl, IND +}[]\n'
        ' e3:_do_v_1<36:38>{e SF prop, TENSE pres, MOOD indicative, PROG -, PERF -}[ARG1 x6, ARG2 x18]\n'
        ' _3:_the_q<39:42>[BV x18]\n'
        ' e23:_same_a_as<43:47>{e SF prop, TENSE untensed, MOOD indicative, PROG -, PERF -}[ARG1 x18]\n'
        ' e25:comp_equal<43:47>{e SF prop, TENSE untensed, MOOD indicative, PROG -, PERF -}[ARG1 e23]\n'
        ' x18:_thing_n_of-about<48:54>{x PERS 3, NUM sg, IND +}[]\n'
        '}'
    )
    assert e.top == 'e3'
    assert len(e.nodes()) == 12
    assert e.nodes()[5].properties == {}
    assert e.nodes()[11].pred == '_thing_n_of-about'
Example #5
def mapper(options):
    main_dir, bank, strip_tree, is_train, graph_type, detect_func_name = options
    detect_func = {
        "small": HRGDerivation.detect_small,
        "large": HRGDerivation.detect_large,
        "lexicalized": HRGDerivation.detect_lexicalized
    }[detect_func_name]
    result = []
    if bank.startswith("."):
        # skip hidden files (e.g. ".DS_Store") without opening them
        return
    with open(main_dir + bank, encoding="utf-8") as f:
        while True:
            sent_id = f.readline().strip()
            if not sent_id:
                break
            assert sent_id.startswith("#")
            sent_id = sent_id[1:]
            tree_literal = f.readline().strip()
            try:
                with gzip.open(
                        deepbank_export_path + bank + "/" + sent_id + ".gz",
                        "rb") as f_gz:
                    contents = f_gz.read().decode("utf-8")
                cfg = ConstTree.from_java_code_deepbank_1_1(
                    tree_literal, contents)

                # strip labels
                if strip_tree == STRIP_ALL_LABELS or strip_tree == STRIP_INTERNAL_LABELS:
                    if strip_tree == STRIP_ALL_LABELS:
                        strip_label(cfg)
                    elif strip_tree == STRIP_INTERNAL_LABELS:
                        strip_label_internal(cfg)
                    strip_unary(cfg)
                elif strip_tree == STRIP_TO_UNLABEL or strip_tree == FUZZY_TREE:
                    strip_to_unlabel(cfg)

                cfg = cfg.condensed_unary_chain()
                cfg.populate_spans_internal()
                fix_punct_hyphen(cfg)
                fields = contents.strip().split("\n\n")
                if graph_type == "eds":
                    eds_literal = fields[-2]
                    eds_literal = re.sub(r"\{.*\}", "", eds_literal)
                    e = eds.loads_one(eds_literal)
                    hg = HyperGraph.from_eds(e)
                elif graph_type == "dmrs":
                    mrs_literal = fields[-3]
                    mrs_obj = simplemrs.loads_one(mrs_literal)
                    hg = HyperGraph.from_mrs(mrs_obj)
                else:
                    raise Exception("Invalid graph type!")
                names, args = extract_features(hg, cfg)
                if strip_tree == 3:
                    cfg = fuzzy_cfg(cfg, names)
                derivations = CFGRule.extract(
                    hg,
                    cfg,
                    # draw=True,
                    sent_id=sent_id,
                    detect_func=detect_func)
                sent_id_info = "# ID: " + sent_id + "\n"
                span_info = "# DelphinSpans: " + repr(
                    [i.span for i in cfg.generate_words()]) + "\n"
                args_info = "# Args: " + repr(list(args)) + "\n"
                names_info = "# Names: " + repr(list(names)) + "\n"
                header = cfg.get_words()
                original_cfg = cfg.to_string(with_comma=False).replace(
                    "+++", "+!+")
                rules = list(cfg.generate_rules())
                assert len(derivations) == len(rules)
                for syn_rule, cfg_rule in zip(derivations, rules):
                    assert cfg_rule.tag == syn_rule.lhs
                    new_name = "{}#{}".format(cfg_rule.tag,
                                              len(syn_rule.hrg.lhs.nodes) \
                                                  if syn_rule.hrg is not None else 0)
                    cfg_rule.tag = new_name
                additional_cfg = cfg.to_string(with_comma=False).replace(
                    "+++", "+!+")
                if any(rule for rule in cfg.generate_rules()
                       if len(rule.child) > 2):
                    if is_train:
                        print("{} Not binary tree!".format(sent_id))
                    else:
                        raise Exception("Not binary tree!")
                result.append((sent_id, derivations, header,
                               header + original_cfg, header + additional_cfg))
            except Exception as e:
                print(sent_id)
                print(e.__class__.__name__)
                traceback.print_exc()
    return bank, result
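`mapper` takes a single options tuple and returns `(bank, result)`, which makes it easy to fan out over banks with a process pool. A minimal driver sketch, assuming `main_dir` holds one tree file per bank and reusing the module's `STRIP_ALL_LABELS` constant (the directory layout and the parameter choices here are assumptions):

import os
from multiprocessing import Pool

main_dir = "data/extracted_trees/"  # assumed location of the per-bank tree files
banks = sorted(b for b in os.listdir(main_dir) if not b.startswith("."))
tasks = [(main_dir, bank, STRIP_ALL_LABELS, True, "eds", "small")
         for bank in banks]

with Pool() as pool:
    for bank, result in pool.imap_unordered(mapper, tasks):
        print(bank, len(result), "sentences extracted")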
Example #6
def relate_eds_coref(coref_discourse, eds_discourse, db_discourse,
                     index_discourse, syntax_discourse):
    def wash_name(name):
        # join the words on either side of "-" into a single token
        name_list = name.split(' ')
        new_list = []
        i = 0
        while i < len(name_list):
            word = name_list[i]
            word = word.replace("'", "’")
            word = word.replace('``', '“')
            word = word.replace('’’', '”')
            word = word.replace('-LCB-', '{')
            word = word.replace('-RCB-', '}')
            if word == '-':
                new_list[-1] = new_list[-1] + '-' + name_list[i + 1]
                i += 1
            elif word == '--':
                new_list.append('–')
            elif word == '-LRB-':
                new_list.append('(')
            elif word == '-RRB-':
                new_list.append(')')
            else:
                new_list.append(word)
            i += 1
        return " ".join(new_list)

    def match_name_sent(initial_name, sent, index_sent, span_index=None):
        sent_str = " ".join(sent)
        name = wash_name(initial_name)
        if not name:
            print("Initial name for void:", initial_name)
        #print("Washed: ", name)
        name_list = name.split(' ')
        accumulated_pos = 0
        result = []
        for i in range(len(sent) - len(name_list) + 1):
            flag = 1
            for j in range(i, i + len(name_list)):
                if sent[j] != name_list[j - i]:
                    flag = 0
            if flag:
                start = int(index_sent[i].split(':')[0])
                end = int(index_sent[i + len(name_list) - 1].split(':')[-1])
                #return start, end
                result.append([start, end, i])

        # pick the candidate span closest to the annotated start index
        if len(result):
            span_start = int(span_index.split('-')[0])
            distance = 10000
            for k, span in enumerate(result):
                if abs(span[2] - span_start) < distance:
                    distance = abs(span_start - span[2])
                    record = k
            return result[record][:2]

        # e.g. Packwood & Packwood-Roth
        if len(name_list) == 1:
            for i in range(len(sent)):
                if name_list[0] in sent[i]:
                    start = int(index_sent[i].split(':')[0])
                    end = int(index_sent[i].split(':')[-1])
                    return start, end
        if 'U.S .' in sent_str and 'U.S.' in name:
            name = name.replace('U.S.', 'U.S .')
            name_list = name.split(' ')
            print(name_list)
            print(name in sent_str)
            for i in range(len(sent) - len(name_list) + 1):
                flag = 1
                for j in range(i, i + len(name_list)):
                    print(sent[j])
                    print(name_list[j - i])
                    print('')
                    if sent[j] != name_list[j - i]:
                        flag = 0
                if flag:
                    start = int(index_sent[i].split(':')[0])
                    end = int(index_sent[i + len(name_list) -
                                         1].split(':')[-1])
                    return start, end

        print("Washed:", name)
        print(sent_str)
        print("")
        return -1, -1

    def get_related_node(start, end, nodes, sent=None):
        res = []
        for item in nodes.items():
            item_start = item[1].lnk.data[0]
            item_end = item[1].lnk.data[1]
            if (start - 2 <= item_start and end + 2 >= item_end
                    and item_start <= end and item_end >= start):
                res.append(item[0])
            #elif ((start <= item_start <= end) or (start <= item_end <= end)) and not item[0].startswith('e'):
            elif ((start <= item_start <= end) or
                  (start <= item_end <= end)) and item[0].startswith('e'):
                res.append(item[0])
        return res

    basic_information = []
    for i in range(len(db_discourse)):
        basic_information.append({
            'sentence': " ".join(db_discourse[i]),
            'eds': eds_discourse[i],
            'syntax': syntax_discourse[i]
        })
    global_dict['discourse'] = basic_information

    ambiguous = 0
    count_all = 0
    global not_matched_chunk
    coref_information_discourse = []
    for value in coref_discourse.values():
        coref_information = []
        for item in value:
            name = item.split('|')[0]
            sent_index = int(item.split('|')[1])
            span_index = item.split('|')[2]
            sent = db_discourse[sent_index]
            if not sent:
                continue
            start, end = match_name_sent(name, sent,
                                         index_discourse[sent_index],
                                         span_index)
            if start == -1 and end == -1:
                not_matched_chunk += 1
            eds_literal = " ".join(eds_discourse[sent_index])
            eds_literal = re.sub(r"\{.*\}", "", eds_literal)
            try:
                eds = loads_one(eds_literal)
            except Exception:
                print(eds_literal)
                raise AssertionError("failed to parse EDS literal")
            related_nodes = get_related_node(start, end, eds._nodes,
                                             " ".join(sent))
            subgraph = []
            for node in related_nodes:
                for line in eds_discourse[sent_index]:
                    if line.startswith(' ' + node):
                        subgraph.append(line)
                        break
            try:
                index_to_node = parse_eds(subgraph)
                context = parse_eds(eds_discourse[sent_index][1:-1])
            except Exception:
                for line in eds_discourse[sent_index][1:-1]:
                    print(line, end='')
                traceback.print_exc()
                exit(0)
            try:
                """
                if len([node for node in index_to_node.values() if node.index.startswith('i') and node.name == 'compound']):
                    print(name, '***', sent_index, ' ', " ".join(db_discourse[sent_index]), '\nSubgraph:')
                    for line in subgraph:
                        print(line, end='')
                    input('Check This.')
                """
                cores, reference_count = match_pattern(index_to_node, context)
            except Exception:
                print("Error occurred when extracting cores!")
                print(name, '***', sent_index, ' ',
                      " ".join(db_discourse[sent_index]), '\nSubgraph:')
                for line in subgraph:
                    print(line, end='')
                traceback.print_exc()
                exit(0)
            count_all += 1
            if len(cores) > 1 or len(cores) == 0:
                check = [1 for cnt in reference_count if cnt > 0]
                if len(check) == 1:
                    idx = [cnt > 0 for cnt in reference_count].index(True)
                    coref_information.append(item + '|' + cores[idx].index)
                    continue
                if len(cores) > 1:
                    ambiguous += 1
                else:
                    not_matched_chunk += 1
                if debug:
                    print(start, end)
                    if db_discourse:
                        print(name, '***', sent_index, ' ',
                              " ".join(db_discourse[sent_index]),
                              '\nSubgraph:')
                    else:
                        print(name, '***', sent_index, '\nSubgraph:')
                    for line in subgraph:
                        print(line, end='')
                    print('*******')
                    for line in eds_discourse[sent_index][1:-1]:
                        print(line, end='')
                    for j, node in enumerate(cores):
                        print(node.index, reference_count[j])
                    #for line in eds_discourse[sent_index]:
                    #    print(line, end='')
                    #print(eds_discourse[sent_index])
                    if input("Error occurred. Check this. Continue? y/n\n") != 'y':
                        exit(0)
                    print('=============')
            else:
                coref_information.append(item + '|' + cores[0].index)
        #if len(coref_information) > 1:
        #    for item in coref_information:
        #file.write(item + '\n')
        #file.write('\n')
        coref_information_discourse.append(coref_information)
    global_dict['coreference'] = coref_information_discourse
    json.dump(global_dict, file)
    return ambiguous, count_all
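The exact-match search inside `match_name_sent` relies on per-token character offsets stored as `start:end` strings. Below is a small stand-alone sketch of that lookup on toy data (the helper `find_span` and the sample tokens are hypothetical, not part of the original code):

def find_span(tokens, char_spans, mention_tokens):
    # Hypothetical stand-alone version of the exact-match loop in
    # match_name_sent: return the character span covered by the first
    # occurrence of mention_tokens, or (-1, -1) if there is none.
    n = len(mention_tokens)
    for i in range(len(tokens) - n + 1):
        if tokens[i:i + n] == mention_tokens:
            start = int(char_spans[i].split(':')[0])
            end = int(char_spans[i + n - 1].split(':')[-1])
            return start, end
    return -1, -1

tokens = ['Successful', 'American', 'business', 'owners', 'do', 'the', 'same']
char_spans = ['0:10', '11:19', '20:28', '29:35', '36:38', '39:42', '43:47']
print(find_span(tokens, char_spans, ['business', 'owners']))  # (20, 35)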