Example no. 1
def get_snacs_refined_ucca(passage):
    p_snacs = convert.join_passages([passage])
    p_refined = convert.join_passages([passage])
    edges_snacs = (c.edge for _c in extract_candidates(p_snacs).values() for c in _c)
    edges_refined = (c.edge for _c in extract_candidates(p_refined).values() for c in _c)
    for e_snacs, e_refined in zip(edges_snacs, edges_refined):
        assert e_snacs.parent.ID == e_refined.parent.ID and e_snacs.child.ID == e_refined.child.ID
        old_tags, e_snacs.categories, e_refined.categories = e_snacs.tags, [], []
        all_old_tags = []
        for tag in old_tags:
            all_old_tags.extend(tag.split(':'))
        if any(t.startswith('p.') for t in all_old_tags):
            for t in all_old_tags:
                if t.startswith('p.'):
                    e_snacs.add(t)
                elif t[0] not in '?`':
                    e_refined.add(t)
                else:
                    assert False, (t, str(e_snacs.parent), str(e_snacs.child), all_old_tags)
    return p_snacs, p_refined
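The snippets in this listing do not show their imports. A plausible shared preamble, assuming they build on the UCCA toolkit (the module names below are an assumption, not part of the original):

# assumed imports for the snippets in this listing
from ucca import convert, layer0, layer1
from ucca.constructions import extract_candidates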
Example no. 2
def remove_preterminals(passage):
    _p = convert.join_passages([passage])
    for edge in (c.edge for _c in extract_candidates(_p, constructions=('preterminals',)).values() for c in _c):
        non_preterminal_cats, pss = [], []
        for c in edge.categories:
            if c.tag.startswith('Preterminal'):
                tags = c.tag.split(':')
                for t in tags:
                    if t.startswith('p.'):
                        pss.append(t)
            else:
                non_preterminal_cats.append(c.tag)
        assert len(pss) <= 1, (str(edge.parent), pss)
        prepreterminal = edge.parent
        outgoing = [(e.categories, e.child) for e in edge.child.outgoing if isinstance(e.child, layer0.Terminal)]
        assert len(outgoing) <= 1, (prepreterminal, [([c.tag for c in _cats], str(n)) for _cats, n in outgoing])
        if non_preterminal_cats:
            edge.categories = [c for c in edge.categories if not c.tag.startswith('Preterminal')]
            print('WARNING: preterminals and non-preterminals', prepreterminal, outgoing)
        else:
            edge.child.destroy()
            for _cats, n in outgoing:
                new_edge = prepreterminal.add_multiple([(c.tag, '', c.layer, '') for c in _cats] + [(t,) for t in pss], n)
                if pss:
                    assert n.text
                    new_edge.refinement = pss[0]

    return _p
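A minimal usage sketch for remove_preterminals, assuming UCCA XML files and the file2passage/passage2file helpers from ucca.convert; the file names are placeholders:

from ucca import convert

passage = convert.file2passage('input_passage.xml')    # placeholder path
cleaned = remove_preterminals(passage)                 # works on a joined copy; the input passage is untouched
convert.passage2file(cleaned, 'output_passage.xml')    # placeholder path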
Example no. 3
def convert_concat(passage):
    for edge in (c.edge for _c in extract_candidates(passage).values() for c in _c):
        old_tags, edge.categories = edge.tags, []
        for tag in old_tags:
            ucca_snacs = tag.split(':')
            for t in ucca_snacs:
                if t[0] not in '?`':
                    edge.add(t)
Example no. 4
def extract_and_check(p, constructions=None, expected=None):
    d = OrderedDict(
        (construction, [candidate.edge for candidate in candidates])
        for construction, candidates in extract_candidates(
            p, constructions=constructions).items() if candidates)
    if expected is not None:
        hist = {c.name: len(e) for c, e in d.items()}
        assert hist == expected, " != ".join(",".join(sorted(h))
                                             for h in (hist, expected))
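A usage sketch for the expected-histogram check; get_by_names comes from ucca.constructions (as in the other examples here), and the counts are purely illustrative assumptions:

from ucca.constructions import get_by_names

# p is a ucca Passage; assert it yields exactly five primary and two remote candidate edges
extract_and_check(p,
                  constructions=get_by_names(['primary', 'remote']),
                  expected={'primary': 5, 'remote': 2})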
Example no. 5
def get_snacs_ucca(passage):
    _p = convert.join_passages([passage])
    for edge in (c.edge for _c in extract_candidates(_p).values() for c in _c):
        old_tags, edge.categories = edge.tags, []
        for tag in old_tags:
            ucca_snacs = tag.split(':')
            if len(ucca_snacs) > 1:
                for t in ucca_snacs[1:]:
                    edge.add(t)
    return _p
Example no. 6
def get_refined_ucca(passage):
    _p = convert.join_passages([passage])
    for edge in (c.edge for _c in extract_candidates(_p).values() for c in _c):
        old_tags, edge.categories = edge.tags, []
        for tag in old_tags:
            ucca_snacs = tag.split(':')
            if len(ucca_snacs) >= 2:
                if any(t.startswith('p.') for t in ucca_snacs[1:]):
                    edge.add(ucca_snacs[0])
    return _p
Example no. 7
def get_vanilla_ucca(passage):
    _p = convert.join_passages([passage])
    for edge in (c.edge for _c in extract_candidates(_p).values() for c in _c):
        old_tags, edge.categories = edge.tags, []
        for tag in old_tags:
            ucca_snacs = tag.split(':')
            edge.add(ucca_snacs[0])
    return _p
Example no. 8
def convert_refinement_to_concat(passage):
    for edge in (c.edge for _c in extract_candidates(passage).values()
                 for c in _c):
        ss = edge.refinement
        if ss is None: continue
        old_tag, old_tags, edge.categories = edge.tag, edge.tags, []
        edge.add(f'{old_tag}:{ss}')
        for tag in old_tags:
            if tag != old_tag:
                edge.add(tag)
Example no. 9
def get_snacs_ucca(passage):
    _p = convert.join_passages([passage])
    edges = set()
    for edge in (c.edge for _c in extract_candidates(_p).values() for c in _c):
        old_tags, edge.categories = edge.tags, []
        all_old_tags = []
        for tag in old_tags:
            all_old_tags.extend(tag.split(':'))
        if any(t.startswith('p.') for t in all_old_tags):
            for t in all_old_tags:
                if t.startswith('p.'):
                    # record each (edge, original-tags) pair; the set itself deduplicates repeats
                    edges.add((edge, tuple(sorted(all_old_tags))))
                    edge.add(t)
    return _p, edges
Example no. 10
def main(args):
    for passage in get_passages_with_progress_bar(args.passages):
        c2es = OrderedDict((c, [candidate.edge for candidate in candidates]) for c, candidates in
                           extract_candidates(passage, constructions=args.constructions, verbose=args.verbose).items()
                           if candidates)
        if any(c2es.values()):
            with external_write_mode():
                if not args.verbose:
                    print("%s:" % passage.ID)
                for construction, edges in c2es.items():
                    if edges:
                        print("  %s:" % construction.description)
                        for edge in edges:
                            print("    %s [%s %s]" % (edge, edge.tag, edge.child))
                print()
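This snippet relies on two I/O helpers that are not imported here; a plausible source (an assumption) is ucca.ioutil, whose external_write_mode lets the prints coexist with the progress bar:

# assumed imports for the snippet above
from ucca.ioutil import get_passages_with_progress_bar, external_write_mode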
Example no. 11
def get_full_ucca(passage):
    _p = convert.join_passages([passage])
    for edge in (c.edge for _c in extract_candidates(_p).values() for c in _c):
        old_tags, edge.categories = edge.tags, []
        # collect every original tag component (UCCA category or SNACS supersense)
        all_old_tags = []
        for tag in old_tags:
            all_old_tags.extend(tag.split(':'))
        # re-attach everything as one concatenated, deduplicated category
        edge.add(':'.join(sorted(set(all_old_tags))))
    return _p
Example no. 12
def main(args):

    streusle_file = args[0]
    ucca_path = args[1]
    outpath = args[2]

    for doc, passage, term2tok in get_passages(streusle_file,
                                               ucca_path,
                                               annotate=True,
                                               target='prep'):

        # use a list: the same id sequence is consumed by both split_passage() and zip()
        sent_ids = [''.join(x['sent_id'].split('-')[-2:]) for x in doc['sents']]

        sent_passage = zip(sent_ids,
                           uconv.split_passage(passage, doc['ends'], sent_ids))

        for sent, psg in sent_passage:

            p = uconv.join_passages([psg])
            l0 = p.layer(ul0.LAYER_ID)
            l1 = p.layer(ul1.LAYER_ID)

            for pos, terminal in l0.pairs:

                # skip terminals without an adpositional ('p.*') supersense annotation
                ss = terminal.extra.get('ss')
                if not isinstance(ss, str) or not ss.startswith('p'):
                    continue

                unit = doc["exprs"][tuple(
                    map(int, terminal.extra["toknums"].split()))]

                ID = f'{doc["id"]}_{unit["sent_offs"]}_{unit["local_toknums"][0]}-{unit["local_toknums"][-1]}'

                refined, error = find_refined(terminal,
                                              dict(l0.pairs),
                                              local=True)

                for _, term in p.layer(ul0.LAYER_ID).pairs:
                    _pt = term.incoming[0].parent
                    toks = [t.text for t in _pt.get_terminals()]
                    term.extra['lexlemma'] = ' '.join(toks)
                    term.extra['lexcat'] = _pt.ftag
                    # term.extra.update(unit.get('heuristic_relation', {}))
                    term.extra['is_part_of_mwe'] = len(toks) > 1
                    term.extra['identified_for_pss'] = str(
                        term.ID == terminal.ID)

                edges = [
                    c.edge for cs in uconst.extract_candidates(p).values()
                    for c in cs
                ]
                for edge in edges:
                    edge.categories = []
                    edge.add(str(edge in refined))

                uconv.passage2file(p, f'{outpath}/{ID}.xml')
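A usage sketch for this main; the positional arguments match the reads at the top of the function, get_passages and find_refined are assumed to be project-local helpers already in scope, and the paths are placeholders:

main(['streusle.govobj.json',   # STREUSLE govobj JSON (placeholder path)
      'UCCA_English-EWT',       # UCCA corpus directory (placeholder path)
      'out'])                   # output directory for the per-unit XML files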
Example no. 13
def convert_refinement(passage):
    for edge in (c.edge for _c in extract_candidates(passage).values() for c in _c):
        ss = edge.refinement
        if ss is None: continue
        if ss.startswith('p.'):
            edge.add(ss)
Example no. 14
def extract_and_check(p, constructions=None, expected=None):
    d = OrderedDict((construction, [candidate.edge for candidate in candidates]) for construction, candidates in
                    extract_candidates(p, constructions=constructions).items() if candidates)
    if expected is not None:
        hist = {c.name: len(e) for c, e in d.items()}
        assert hist == expected, " != ".join(",".join(sorted(h)) for h in (hist, expected))
Example no. 15
def main(args):
    try:
        integrate_full = True
        integrate_term = False
        concatenate = False
        pss_feature = False
        annotate = True
        object_mode = False
        v2_only = True
        draw = False
        output = True
        inp_ucca = False
        if '-I' in args:
            args.remove('-I')
            args.append('--no-integrate')
        if '--no-integrate' in args:
            integrate_full = False
            args.remove('--no-integrate')

        if '-c' in args:
            args.remove('-c')
            args.append('--concatenate')
        if '--concatenate' in args:
            concatenate = True
            args.remove('--concatenate')

        if '-A' in args:
            args.remove('-A')
            args.append('--no-annotate')
        if '--no-annotate' in args:
            integrate_full = False
            annotate = False
            args.remove('--no-annotate')

        if '-s' in args:
            args.remove('-s')
            args.append('--pss-feature')
        if '--pss-feature' in args:
            pss_feature = True
            args.remove('--pss-feature')

        if '--term' in args:
            integrate_term = True
            integrate_full = False
            args.remove('--term')

        if '--inp_ucca' in args:
            inp_ucca = True
            args.remove('--inp_ucca')

        if '-o' in args:
            args.remove('-o')
            args.append('--object')
        if '--object' in args:
            object_mode = True
            args.remove('--object')

        if '-n' in args:
            args.remove('-n')
            args.append('--no-output')
        if '--no-output' in args:
            output = False
            args.remove('--no-output')

        if '--all' in args:
            v2_only = False
            args.remove('--all')

        if '--draw' in args:
            draw = True
            args.remove('--draw')
            import visualization as uviz
            import matplotlib.pyplot as plt

        streusle_file = args[0]  # e.g. STREUSLE's streusle.govobj.json
        ucca_path = args[1]      # e.g. the UCCA_English-EWT directory
        out_dir = args[2]

    except Exception:
        print(f'usage: python3 {sys.argv[0]} STREUSLE_JSON UCCA_PATH OUT_DIR',
              file=sys.stderr)
        sys.exit(1)

    with open(streusle_file) as f:
        streusle = json.load(f)

    print()

    global_error = Counter()

    unit_counter = 0
    successful_units = 0
    unsuccessful_units = 0
    deductible_multiple_successes = 0
    deductible_multiple_fails = 0
    deductible_fail_and_success = 0
    units_with_remote = 0

    doc_error = 0

    primary_edges = 0
    remote_edges = 0

    _doc_id = None

    v2_docids = set()
    if v2_only:
        with open(ucca_path + '/v2.txt') as f:
            for line in f:
                v2_docids.add(line.strip())

    ignore = []  # doc ids to skip, if any

    unit_times = []

    tag_refinements = Counter()

    for doc, passage, term2tok in get_passages(
            streusle_file,
            ucca_path,
            annotate=(integrate_term or integrate_full or annotate),
            target='obj' if object_mode else 'prep',
            ignore=ignore,
            docids=v2_docids):

        if output and (not integrate_full and not integrate_term):
            for p in uconv.split_passage(
                    passage, doc['ends'],
                    map(lambda x: ''.join(x['sent_id'].split('-')[-2:]),
                        doc['sents'])):
                uconv.passage2file(p, out_dir + '/' + p.ID + '.xml')
            continue

        l1 = passage.layer('1')

        if not output:
            primary_edges += len(
                uconstr.extract_candidates(
                    passage, constructions=(uconstr.PRIMARY, ))['primary'])
            remote_edges += len(
                uconstr.extract_candidates(passage,
                                           constructions=uconstr.get_by_names(
                                               ['remote']))['remote'])

        for terminal in passage.layer('0').words:

            if integrate_term and concatenate:  # and not terminal.incoming[0].parent.tag.startswith('Preterminal'):
                old_term_edge = terminal.incoming[0]
                preterminal = old_term_edge.parent
                preterminal._outgoing.remove(old_term_edge)
                terminal._incoming.remove(old_term_edge)
                passage._remove_edge(old_term_edge)
                # re-attach the terminal under a fresh 'Preterminal' node
                new_preterminal = l1.add_fnode(preterminal, 'Preterminal')
                new_preterminal.add_multiple(
                    [[c.tag, '', c.layer, ''] for c in old_term_edge.categories],
                    terminal)

            # skip terminals without an adpositional ('p.*') supersense annotation
            pss_label = terminal.extra.get('ss', '')
            if not pss_label.startswith('p'):
                continue

            start_time = time.time()
            unit_counter += 1

            if integrate_term:
                if concatenate:
                    refined = terminal.incoming[0].parent.incoming
                else:
                    refined = terminal.incoming
            else:
                refined, error = find_refined(
                    terminal, dict(passage.layer(ul0.LAYER_ID).pairs))

                global_error += Counter(
                    {k: v
                     for k, v in error.items() if isinstance(v, int)})

                if error['successes_for_unit'] >= 1:
                    successful_units += 1
                    deductible_multiple_successes += error[
                        'successes_for_unit'] - 1
                    if error['fails_for_unit'] >= 1:
                        deductible_fail_and_success += 1
                else:
                    unsuccessful_units += 1

                if error['fails_for_unit'] >= 1:
                    deductible_multiple_fails += error['fails_for_unit'] - 1

                if error['remotes'] >= 1:
                    units_with_remote += 1

                if not output:
                    if 'larger_UNA_warn' in error['failed_heuristics']:
                        print(terminal, terminal.incoming[0].parent)

                    if 'PP_idiom_not_UNA' in error['failed_heuristics']:
                        print('PP_idiom:', terminal.extra['lexlemma'],
                              terminal, terminal.incoming[0].parent)

                    if 'MWP_not_UNA' in error['failed_heuristics']:
                        print('MWP:', terminal.extra['lexlemma'], terminal,
                              terminal.incoming[0].parent)

            for r in refined:
                # TODO: deal with doubly refined edges
                if (not concatenate and r.refinement) or (concatenate
                                                          and ':' in r.tag):
                    pass
                else:
                    if concatenate:
                        cats, r.categories = r.categories, []
                        for c in cats:
                            composit_tag = f'{c.tag}:{pss_label}'
                            r.add(composit_tag)
                            tag_refinements[composit_tag] += 1
                    else:
                        r.refinement = pss_label

            unit_times.append(time.time() - start_time)

            if not pss_feature:
                terminal.extra.pop('ss')  # ensuring pss is not also a feature

        if draw:
            for sent, psg in zip(doc['sents'],
                                 uconv.split_passage(passage, doc['ends'])):
                uviz.draw(psg)
                plt.savefig(f'../graphs/{sent["sent_id"]}.svg')
                plt.clf()


        if output:
            for p in uconv.split_passage(
                    passage, doc['ends'],
                    map(lambda x: ''.join(x['sent_id'].split('-')[-2:]),
                        doc['sents'])):
                uconv.passage2file(p, out_dir + '/' + p.ID + '.xml')

    for x, y in tag_refinements.most_common(len(tag_refinements)):
        print(x, y, sep='\t')

    if integrate_full and not output:

        print('\n\n')
        print(f'total units\t{unit_counter}')
        #   print(f'gov and obj present\t{gov_and_obj_counter}')
        print(f'document error\t{doc_error}\t{100*doc_error/unit_counter}%')
        print(
            f'document success\t{unit_counter - doc_error}\t{100-(100 * doc_error / unit_counter)}%'
        )
        print(f'total primary edges\t{primary_edges}')
        print(f'total remote edges\t{remote_edges}')
        print('----------------------------------------------------')
        print(
            f'successful units\t{successful_units}\t{100*successful_units/(unit_counter-doc_error)}%'
        )
        print(
            f'unsuccessful units\t{unsuccessful_units}\t{100-(100*successful_units/(unit_counter-doc_error))}%'
        )  #={unit_counter - doc_error - successful_units}={mwe_una_fail+abgh_fail+c_fail+d_fail+e_fail+f_fail+g_fail+no_match}
        print(f'warnings\t{global_error["warnings"]}')
        print('---------------------------------')
        #    for ftype, count in fail_counts.most_common():
        #        print(f'{ftype}\t{count}')
        print(
            f'syntactic and semantic obj match\t{global_error["synt_sem_obj_match"]}'
        )
        print('---------------------------------')
        print(f'\tMWE but not UNA\t{global_error["mwe_una_fail"]}')
        print(f'\tPP idiom\t{global_error["idiom"]}')
        print(
            f'\tR, N, F ({global_error["abgh"]}) but A and B miss\t{global_error["abgh_fail"]}'
        )
        print(f'\tA (scene mod)\t{global_error["a"]}')
        print(f'\tB (non-scene mod) \t{global_error["b"]}')

        print(f'\tG (inh purpose) \t{global_error["g"]}')
        print(f'\t  scn \t{global_error["g_scn_mod"]}')
        print(f'\t  non scn \t{global_error["g"] - global_error["g_scn_mod"]}')

        print(f'\tH (approximator) \t{global_error["h"]}')
        print(f'\t  scn \t{global_error["h_scn_mod"]}')
        print(f'\t  non scn \t{global_error["h"] - global_error["h_scn_mod"]}')

        print(
            f'\tP, S ({global_error["c"]}) but C miss\t{global_error["c_fail"]}'
        )
        print(
            f'\tL ({global_error["d"]}) but D miss\t{global_error["d_fail"]}')

        print(
            f'\tA, D, E, T ({global_error["ef"]}) but E miss\t{global_error["ef_fail"]}'
        )

        print(f'\tE (intr adp) \t{global_error["e"]}')
        print(f'\t  scn \t{global_error["e_scn_mod"]}')
        print(f'\t  non scn \t{global_error["e"] - global_error["e_scn_mod"]}')

        print(f'\tF (poss pron) \t{global_error["f"]}')
        print(f'\t  scn \t{global_error["f_scn_mod"]}')
        print(f'\t  non scn \t{global_error["f"] - global_error["f_scn_mod"]}')

        #print(f'\tA ({f}) but F miss\t{f_fail}')
        #print(f'\tF ({g}) but G miss\t{g_fail}')
        print(f'\tno match\t{global_error["no_match"]}')
        print(f'\tnon-semantic role\t{global_error["non_semrole"]}')
        print(
            f'\tmultiple preterminals\t{global_error["multiple_preterminals"]}'
        )
        print(
            f'\tunits with remote\t{units_with_remote} (total {global_error["remotes"]})'
        )
        print('---------------------------------')
        print(
            f'\tdeductible (multiple successes for single unit)\t{deductible_multiple_successes}'
        )
        print(
            f'\tdeductible (multiple fails for single unit)\t{deductible_multiple_fails}'
        )
        print(
            f'\tdeductible (fail and success for single unit)\t{deductible_fail_and_success}'
        )