Example #1
def extract_features(args):

    tok_seqs = read_toks(args.tok_file)
    lemma_seqs = read_toks(args.lemma_file)
    pos_seqs = read_toks(args.pos_file)

    print 'A total of %d sentences' % len(tok_seqs)

    assert len(tok_seqs) == len(pos_seqs)
    assert len(tok_seqs) == len(lemma_seqs)

    stop_words = set([line.strip() for line in open(args.stop, 'r')])

    non_map_words = set(['am', 'is', 'are', 'be', 'a', 'an', 'the'])

    der_lemma_file = os.path.join(args.lemma_dir, 'der.lemma')
    der_lemma_map = initialize_lemma(der_lemma_file)

    train_entity_set = load_entities(args.stats_dir)
    all_entities = identify_entities(args.tok_file, args.ner_file, train_entity_set)

    (pred_set, pred_lemma_set, pred_labels, non_pred_mapping, non_pred_lemma_mapping, entity_labels) = load_mappings(args.stats_dir)

    feature_f = open(args.feature_file, 'w')
    #Each entity takes the form (start, end, entity_type)
    for (i, (toks, lemmas, pos_seq, entities_in_sent)) in enumerate(zip(tok_seqs, lemma_seqs, pos_seqs, all_entities)):
        n_toks = len(toks)
        aligned_set = set()

        all_spans = []
        assert len(toks) == len(lemmas)
        assert len(toks) == len(pos_seq)

        for (start, end, entity_typ) in entities_in_sent:
            new_aligned = set(xrange(start, end))
            aligned_set |= new_aligned
            all_spans.append((start, end, False, False, True))
            assert end <= len(toks)

        #Classify each remaining single token as a predicate or non-predicate span
        for index in xrange(n_toks):
            if index in aligned_set:
                continue

            curr_tok = toks[index]
            curr_lem = lemmas[index]
            curr_pos = pos_seq[index]

            aligned_set.add(index)

            if curr_tok in pred_set or curr_lem in pred_set or curr_lem in pred_lemma_set:
                all_spans.append((index, index+1, True, False, False))

            elif curr_tok in non_map_words:
                all_spans.append((index, index+1, False, False, False))
            elif curr_tok in non_pred_mapping or curr_lem in non_pred_lemma_mapping:
                all_spans.append((index, index+1, False, False, False))
            else: #not found in any mapping
                retrieved = False
                if curr_tok in der_lemma_map:
                    for tok in der_lemma_map[curr_tok]:
                        if tok in pred_set or tok in pred_lemma_set:
                            if curr_tok.endswith('ion') or curr_tok.endswith('er'):
                                all_spans.append((index, index+1, True, False, False))
                                retrieved = True
                                break
                if not retrieved:
                    all_spans.append((index, index+1, False, False, False))
            assert index < len(toks)

        all_spans = sorted(all_spans, key=lambda span: (span[0], span[1]))
        for (start, end, is_pred, is_op, is_ent) in all_spans:
            fs = []
            end -= 1
            #print start, end, len(toks)
            fs += extract_span(toks, start, end, 3, 'word')
            fs += extract_bigram(toks, start, end, 3, 'word')
            fs += extract_curr(toks, start, end, 'word')
            if is_ent:
                fs += extract_seq_feat(toks, start, end, 'word')

            #Lemma feature
            fs += extract_span(lemmas, start, end, 3, 'lemma')
            fs += extract_bigram(lemmas, start, end, 3, 'lemma')
            fs += extract_curr(lemmas, start, end, 'lemma')

            #Pos tag feature
            fs += extract_span(pos_seq, start, end, 3, 'POS')
            fs += extract_bigram(pos_seq, start, end, 3, 'POS')
            fs += extract_curr(pos_seq, start, end, 'POS')

            #Length of span feature
            fs.append('Length=%d' % (end - start))

            #Suffix feature
            if not is_ent and start == end:
                fs += suffix(toks[start])

            print >>feature_f, '##### %d-%d %s %s %s' % (start, end+1, '1' if is_pred else '0', '1' if is_ent else '0', ' '.join(fs))
        print >>feature_f, ''
    feature_f.close()
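
extract_features reads all of its inputs from a single args namespace. Below is a minimal, hypothetical driver showing how the function could be wired to a command line; the flag names simply mirror the attributes accessed above (tok_file, lemma_file, pos_file, ner_file, stop, lemma_dir, stats_dir, feature_file), and the original project's actual CLI may differ.

# Hypothetical argparse driver for extract_features (Python 2, to match the
# code above). Flag names are inferred from the args attributes the function
# reads; helpers such as read_toks and load_mappings are assumed to be
# importable from the surrounding module.
import argparse

def build_extract_features_parser():
    parser = argparse.ArgumentParser(description='Extract span features for concept identification')
    parser.add_argument('--tok_file', required=True, help='tokenized sentences, one per line')
    parser.add_argument('--lemma_file', required=True, help='lemmas aligned with tok_file')
    parser.add_argument('--pos_file', required=True, help='POS tags aligned with tok_file')
    parser.add_argument('--ner_file', required=True, help='NER output used by identify_entities')
    parser.add_argument('--stop', required=True, help='stop-word list, one word per line')
    parser.add_argument('--lemma_dir', required=True, help='directory containing der.lemma')
    parser.add_argument('--stats_dir', required=True, help='directory with training statistics and mappings')
    parser.add_argument('--feature_file', required=True, help='output path for the extracted features')
    return parser

if __name__ == '__main__':
    extract_features(build_extract_features_parser().parse_args())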
Example #2
def concept_id(args):

    tok_seqs = read_toks(args.tok_file)
    lemma_seqs = read_toks(args.lemma_file)
    pos_seqs = read_toks(args.pos_file)

    #Identify entities using NER output and the training entity set
    train_entity_set = load_entities(args.stats_dir)
    all_entities = identify_entities(args.tok_file, args.ner_file, train_entity_set)
    all_dates = extract_all_dates(args.tok_file)

    non_map_words = set(['am', 'is', 'are', 'be', 'a', 'an', 'the', ',', '.', '..', '...', ':', '(', ')', '@-@', 'there', 'they', 'do', 'and', '\"', '-@' ])

    special_words = set(['a', 'the', 'an', 'it', 'its', 'are', 'been', 'have', 'has', 'had'])
    begin_words = set(['if', 'it'])

    tok_map = {}
    lemma_map = {}
    tok_f = open(args.tok_map, 'rb')
    tok_map = cPickle.load(tok_f)
    tok_f.close()

    lemma_f = open(args.lemma_map, 'rb')
    lemma_map = cPickle.load(lemma_f)
    lemma_f.close()

    new_tok_f = open('%s.temp' % args.tok_file, 'w')
    new_lemma_f = open('%s.temp' % args.lemma_file, 'w')

    result_f = open(args.output, 'w')

    der_lemma_map = initialize_lemma('./lemmas/der.lemma')

    for (i, (toks, lemmas, poss)) in enumerate(zip(tok_seqs, lemma_seqs, pos_seqs)):
        #print 'sentence %d' % i
        n_toks = len(toks)
        visited = set()

        orig_toks = copy(toks)
        toks = [t.lower() for t in toks]
        lemmas = [t.lower() for t in lemmas]

        ent_in_sent = all_entities[i]
        dates_in_line = all_dates[i]
        assert len(toks) == len(lemmas)
        assert len(toks) == len(poss)

        aligned_toks = set()  #Track which positions are already covered
        aligned_rules = []

        for (start, end, lhs, frag_str) in dates_in_line:
            new_aligned = set(xrange(start, end))
            aligned_toks |= new_aligned

            rule_str = build_rule_str(lhs, toks, start, end, frag_str)
            aligned_rules.append(rule_str)

        if 0 not in aligned_toks:
            (is_li, lhs, frag_str) = extract_li(toks[0])
            if is_li:
                aligned_toks.add(0)
                rule_str = build_rule_str(lhs, toks, 0, 1, frag_str)
                aligned_rules.append(rule_str)

        for start in xrange(n_toks):
            (has_slash, lhs, frag_str) = dealwith_slash(toks[start:start+1], tok_map)
            if has_slash:
                aligned_toks.add(start)
                rule_str = build_rule_str(lhs, toks, start, start+1, frag_str)
                aligned_rules.append(rule_str)
                print 'Retrieved slash'
                print '%s : %s' % (toks[start], rule_str)
                sys.stdout.flush()

        #Dealing with entities
        for (start, end, entity_typ) in ent_in_sent:

            new_aligned = set(xrange(start, end))
            if len(new_aligned & aligned_toks) != 0:
                continue

            aligned_toks |= new_aligned

            curr_ent = '_'.join(toks[start:end])
            curr_lex = ' '.join(toks[start:end])
            if curr_lex in tok_map:
                items = tok_map[curr_lex][''].items()
                assert len(items) > 0
                items = sorted(items, key=lambda it: it[1])

                (lhs, frag_part) = items[-1][0]
                if lhs.strip() == 'Nothing':
                    continue

                rule_str = '%d-%d####%s ## %s ## %s' % (start, end, lhs, ' '.join(toks[start:end]), frag_part)
                aligned_rules.append(rule_str)
            else:
                if not entity_typ:
                    continue
                assert entity_typ
                #if 'PER' in entity_typ: #Identified as a person
                frag_label = 'entity+person'
                if 'PER' in entity_typ:
                    frag_label = 'entity+person'
                elif 'ORG' in entity_typ:
                    frag_label = 'entity+organization'
                elif 'LOC' in entity_typ:
                    frag_label = 'entity+city'

                frag_str = build_one_entity(orig_toks[start:end], frag_label)
                rule_str = '%d-%d####[A1-1] ## %s ## %s' % (start, end, ' '.join(toks[start:end]), frag_str)
                aligned_rules.append(rule_str)

        #continue
        preprocess_tok(toks, aligned_toks)
        preprocess_tok(lemmas, aligned_toks)

        for temp_t in toks:
            assert len(temp_t.strip()) > 0

        print >>new_tok_f, ' '.join(toks)
        print >>new_lemma_f, ' '.join(lemmas)

        possible_aligned = set()
        #Greedy longest-first lookup of remaining spans in the token/lemma maps
        for start in xrange(n_toks):
            if start in visited:
                continue

            for length in xrange(n_toks+1, 0, -1):
                end = start + length

                if end > n_toks:
                    continue

                span_set = set(xrange(start, end))
                if len(span_set & aligned_toks) != 0:
                    continue

                if toks[start] in non_map_words and end-start > 1:
                    continue

                #if length >= 2:
                #    if poss[end-1] == 'IN' or poss[end-1] == 'DT':

                seq_str = ' '.join(toks[start:end])
                lem_seq_str = ' '.join(lemmas[start:end])
                if length == 1 and seq_str in non_map_words:  #TODO: verify whether these two lines should be commented out
                    continue

                contexts = get_context(toks, lemmas, poss, start, end)
                aligned_value = find_maps(seq_str, lem_seq_str, tok_map, lemma_map, contexts)
                if aligned_value:
                    (lhs, frag_part) = aligned_value
                    lhs = lhs.strip()
                    frag_part = frag_part.strip()
                    if lhs == 'Nothing':
                        continue

                    rule_str = build_rule_str(lhs, toks, start, end, frag_part)
                    if length >= 2 and length <= 4:
                        if toks[end-1] in special_words:
                            print 'discarded: %s' % rule_str
                            sys.stdout.flush()
                            continue

                        if toks[start] in begin_words:
                            print 'begin discard: %s' % rule_str
                            sys.stdout.flush()
                            continue

                    aligned_rules.append(rule_str)
                    #if poss[end-1] == 'the' or poss[end-1] == 'DT' or poss[start] == 'IN' or poss[start] == 'DT':
                    #    print rule_str
                    #    sys.stdout.flush()

                    if end - start == 1:
                        aligned_toks.add(start)

                    if num_nonmap(non_map_words, toks[start:end]) > 1:
                        if len(possible_aligned & span_set) == 0:
                            #new_aligned = set(xrange(start, end))
                            aligned_toks |= span_set
                            possible_aligned |= span_set
                            break

                    possible_aligned |= span_set

        unaligned_toks = set(xrange(n_toks)) - aligned_toks
        retrieve_unaligned(unaligned_toks, toks, lemmas, poss, der_lemma_map, aligned_rules, non_map_words, tok_map, lemma_map)

        print >>result_f, '%s ||| %s ||| %s' % (' '.join(toks), ' '.join([str(k) for k in unaligned_toks]), '++'.join(aligned_rules))
        #print ' '.join(['%s/%s/%s' % (toks[k], lemmas[k], poss[k]) for k in unaligned_toks])
    new_tok_f.close()
    new_lemma_f.close()
    result_f.close()
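
Each element of aligned_rules is a string of the form '<start>-<end>####<lhs> ## <lexical span> ## <fragment>', as written explicitly in the entity branches above. Assuming build_rule_str follows the same layout, a small illustrative parser (not part of the original module) could recover the pieces:

# Illustrative only: split a rule string back into its parts, assuming the
# '<start>-<end>####<lhs> ## <lexical span> ## <fragment>' layout used above.
def parse_rule_str(rule_str):
    span_part, rest = rule_str.split('####', 1)
    start, end = [int(x) for x in span_part.split('-')]
    lhs, lex, frag = [part.strip() for part in rest.split(' ## ', 2)]
    return (start, end, lhs, lex, frag)

# Hypothetical usage:
# parse_rule_str('3-5####[A1-1] ## barack obama ## (p / person ...)')
# -> (3, 5, '[A1-1]', 'barack obama', '(p / person ...)')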
Example #3
def linearize_amr(args):
    logger.file = open(os.path.join(args.run_dir, 'logger'), 'w')

    amr_file = os.path.join(args.data_dir, 'amr')
    alignment_file = os.path.join(args.data_dir, 'alignment')
    if args.use_lemma:
        tok_file = os.path.join(args.data_dir, 'lemmatized_token')
    else:
        tok_file = os.path.join(args.data_dir, 'token')
    pos_file = os.path.join(args.data_dir, 'pos')

    amr_graphs = load_amr_graphs(amr_file)
    alignments = [line.strip().split() for line in open(alignment_file, 'r')]
    toks = [line.strip().split() for line in open(tok_file, 'r')]
    poss = [line.strip().split() for line in open(pos_file, 'r')]

    assert len(amr_graphs) == len(alignments) and len(amr_graphs) == len(toks) and len(amr_graphs) == len(poss), '%d %d %d %d' % (len(amr_graphs), len(alignments), len(toks), len(poss))

    num_self_cycle = 0
    used_sents = 0

    amr_statistics = AMR_stats()

    if args.use_stats:
        amr_statistics.loadFromDir(args.stats_dir)
        #print amr_statistics
    else:
        os.system('mkdir -p %s' % args.stats_dir)
        amr_statistics.collect_stats(amr_graphs)
        amr_statistics.dump2dir(args.stats_dir)

    if args.parallel:
        singleton_num = 0.0
        multiple_num = 0.0
        total_num = 0.0
        empty_num = 0.0

        amr_seq_file = os.path.join(args.run_dir, 'amrseq')
        tok_seq_file = os.path.join(args.run_dir, 'tokseq')
        map_seq_file = os.path.join(args.run_dir, 'train_map')

        amrseq_wf = open(amr_seq_file, 'w')
        tokseq_wf = open(tok_seq_file, 'w')
        mapseq_wf = open(map_seq_file, 'w')

        for (sent_index, (tok_seq, pos_seq, alignment_seq, amr)) in enumerate(zip(toks, poss, alignments, amr_graphs)):

            logger.writeln('Sentence #%d' % (sent_index+1))
            logger.writeln(' '.join(tok_seq))

            amr.setStats(amr_statistics)

            edge_alignment = bitarray(len(amr.edges))
            if edge_alignment.count() != 0:
                edge_alignment ^= edge_alignment
            assert edge_alignment.count() == 0

            has_cycle = False
            if amr.check_self_cycle():
                num_self_cycle += 1
                has_cycle = True

            amr.set_sentence(tok_seq)
            amr.set_poss(pos_seq)

            aligned_fragments = []
            reentrancies = {}  #Map multiple spans as reentrancies, keeping only one as original, others as connections

            has_multiple = False
            no_alignment = False

            aligned_set = set()

            (opt_toks, role_toks, node_to_span, edge_to_span, temp_aligned) = extractNodeMapping(alignment_seq, amr)

            temp_unaligned = set(xrange(len(pos_seq))) - temp_aligned

            all_frags = []
            all_alignments = defaultdict(list)

            ####Extract named entities####
            for (frag, wiki_label) in amr.extract_entities():
                if len(opt_toks) == 0:
                    logger.writeln("No alignment for the entity found")

                (aligned_indexes, entity_spans) = all_aligned_spans(frag, opt_toks, role_toks, temp_unaligned)
                root_node = amr.nodes[frag.root]

                entity_mention_toks = root_node.namedEntityMention()

                total_num += 1.0
                if entity_spans:
                    entity_spans = removeRedundant(tok_seq, entity_spans, entity_mention_toks)
                    if len(entity_spans) == 1:
                        singleton_num += 1.0
                        logger.writeln('Single fragment')
                        for (frag_start, frag_end) in entity_spans:
                            logger.writeln(' '.join(tok_seq[frag_start:frag_end]))
                            all_alignments[frag.root].append((frag_start, frag_end, wiki_label))
                            temp_aligned |= set(xrange(frag_start, frag_end))
                    else:
                        multiple_num += 1.0
                        logger.writeln('Multiple fragment')
                        logger.writeln(aligned_indexes)
                        logger.writeln(' '.join([tok_seq[index] for index in aligned_indexes]))

                        for (frag_start, frag_end) in entity_spans:
                            logger.writeln(' '.join(tok_seq[frag_start:frag_end]))
                            all_alignments[frag.root].append((frag_start, frag_end, wiki_label))
                            temp_aligned |= set(xrange(frag_start, frag_end))
                else:
                    empty_num += 1.0

            ####Process date entities
            date_entity_frags = amr.extract_all_dates()
            for frag in date_entity_frags:
                all_date_indices, index_to_attr = getDateAttr(frag)
                covered_toks, non_covered, index_to_toks = getSpanSide(tok_seq, alignment_seq, frag, temp_unaligned)

                covered_set = set(covered_toks)

                all_spans = getContinuousSpans(covered_toks, temp_unaligned, covered_set)
                if all_spans:
                    temp_spans = []
                    for span_start, span_end in all_spans:
                        if span_start > 0 and (span_start-1) in temp_unaligned:
                            if tok_seq[span_start-1] in str(frag) and tok_seq[0] in '0123456789':
                                temp_spans.append((span_start-1, span_end))
                            else:
                                temp_spans.append((span_start, span_end))
                        else:
                            temp_spans.append((span_start, span_end))
                    all_spans = temp_spans
                    all_spans = removeDateRedundant(all_spans)
                    for span_start, span_end in all_spans:
                        all_alignments[frag.root].append((span_start, span_end, None))
                        temp_aligned |= set(xrange(span_start, span_end))
                        if len(non_covered) == 0:
                            print 'Dates: %s' % ' '.join(tok_seq[span_start:span_end])
                else:
                    for index in temp_unaligned:
                        curr_tok = tok_seq[index]
                        found = False
                        for un_tok in non_covered:
                            if curr_tok[0] in '0123456789' and curr_tok in un_tok:
                                print 'recovered: %s' % curr_tok
                                found = True
                                break
                        if found:
                            all_alignments[frag.root].append((index, index+1, None))
                            temp_aligned.add(index)
                            print 'Date: %s' % tok_seq[index]

            #Verbalization list
            verb_map = {}
            for (index, curr_tok) in enumerate(tok_seq):
                if curr_tok in VERB_LIST:

                    for subgraph in VERB_LIST[curr_tok]:

                        matched_frags = amr.matchSubgraph(subgraph)
                        if matched_frags:
                            temp_aligned.add(index)

                        for (node_index, ex_rels) in matched_frags:
                            all_alignments[node_index].append((index, index+1, None))
                            verb_map[node_index] = subgraph

            #####Add remaining node-to-span alignments#####
            for node_index in node_to_span:
                if node_index in all_alignments:
                    continue

                all_alignments[node_index] = node_to_span[node_index]

            ##Recompute unaligned positions from the node-to-span alignments
            temp_unaligned = set(xrange(len(pos_seq))) - temp_aligned

            assert len(tok_seq) == len(pos_seq)

            amr_seq, cate_tok_seq, map_seq = categorizeParallelSequences(amr, tok_seq, all_alignments, temp_unaligned, verb_map, args.min_prd_freq, args.min_var_freq)
            print >> amrseq_wf, ' '.join(amr_seq)
            print >> tokseq_wf, ' '.join(cate_tok_seq)
            print >> mapseq_wf, '##'.join(map_seq)  #Use '##' as separator since map entries may contain spaces

        amrseq_wf.close()
        tokseq_wf.close()
        mapseq_wf.close()

        #print "one to one alignment: %lf" % (singleton_num/total_num)
        #print "one to multiple alignment: %lf" % (multiple_num/total_num)
        #print "one to empty alignment: %lf" % (empty_num/total_num)
    else: #Only build the linearized token sequence

        mle_map = loadMap(args.map_file)
        if args.use_lemma:
            tok_file = os.path.join(args.data_dir, 'lemmatized_token')
        else:
            tok_file = os.path.join(args.data_dir, 'token')

        ner_file = os.path.join(args.data_dir, 'ner')
        date_file = os.path.join(args.data_dir, 'date')

        all_entities = identify_entities(tok_file, ner_file, mle_map)
        all_dates = dateMap(date_file)

        tokseq_result = os.path.join(args.data_dir, 'linearized_tokseq')
        dev_map_file = os.path.join(args.data_dir, 'cate_map')
        tokseq_wf = open(tokseq_result, 'w')
        dev_map_wf = open(dev_map_file, 'w')

        for (sent_index, (tok_seq, pos_seq, entities_in_sent)) in enumerate(zip(toks, poss, all_entities)):
            print 'snt: %d' % sent_index
            n_toks = len(tok_seq)
            aligned_set = set()

            all_spans = []
            date_spans = all_dates[sent_index]
            date_set = set()

            #Align dates
            for (start, end) in date_spans:
                if end - start > 1:
                    new_aligned = set(xrange(start, end))
                    aligned_set |= new_aligned
                    entity_name = ' '.join(tok_seq[start:end])
                    if entity_name in mle_map:
                        entity_typ = mle_map[entity_name]
                    else:
                        entity_typ = ('DATE', "date-entity", "NONE")
                    all_spans.append((start, end, entity_typ))
                    print 'Date:', start, end
                else:
                    date_set.add(start)

            #Align multi-token entities
            for (start, end, entity_typ) in entities_in_sent:
                if end - start > 1:
                    new_aligned = set(xrange(start, end))
                    if len(aligned_set & new_aligned) != 0:
                        continue
                    aligned_set |= new_aligned
                    entity_name = ' '.join(tok_seq[start:end])
                    if entity_name in mle_map:
                        entity_typ = mle_map[entity_name]
                    else:
                        entity_typ = ('NE_person', "person", '-')
                    all_spans.append((start, end, entity_typ))

            #Single token
            for (index, curr_tok) in enumerate(tok_seq):
                if index in aligned_set:
                    continue

                curr_pos = pos_seq[index]
                aligned_set.add(index)

                if curr_tok in mle_map:
                    (category, node_repr, wiki_label) = mle_map[curr_tok]
                    if category.lower() == 'none':
                        all_spans.append((index, index+1, (curr_tok, "NONE", "NONE")))
                    else:
                        all_spans.append((index, index+1, mle_map[curr_tok]))
                else:

                    if curr_tok[0] in '\"\'.':
                        print 'weird token: %s, %s' % (curr_tok, curr_pos)
                        continue
                    if index in date_set:
                        entity_typ = ('DATE', "date-entity", "NONE")
                        all_spans.append((index, index+1, entity_typ))
                    elif curr_tok in VERB_LIST:
                        node_repr = VERB_LIST[curr_tok][0].keys()[0]
                        entity_typ = ('VERBAL', node_repr, "NONE")
                        all_spans.append((index, index+1, entity_typ))

                    elif curr_pos[0] == 'V':
                        node_repr = '%s-01' % curr_tok
                        all_spans.append((index, index+1, ('-VERB-', node_repr, "NONE")))
                    else:
                        node_repr = curr_tok
                        all_spans.append((index, index+1, ('-SURF-', curr_tok, "NONE")))

            all_spans = sorted(all_spans, key=lambda span: (span[0], span[1]))
            print all_spans
            linearized_tokseq, map_repr_seq = getIndexedForm(all_spans)

            print >> tokseq_wf, ' '.join(linearized_tokseq)
            print >> dev_map_wf, '##'.join(map_repr_seq)

        tokseq_wf.close()
        dev_map_wf.close()
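
linearize_amr expects a fixed file layout under args.data_dir: amr, alignment, pos, and token (or lemmatized_token when args.use_lemma is set), plus ner and date for the non-parallel branch. A small pre-flight check under that assumption (the helper itself is not part of the original code):

# Illustrative sanity check for the data_dir layout read by linearize_amr.
# The file names come from the os.path.join calls above; everything else here
# is an assumption.
import os

def check_data_dir(data_dir, use_lemma=False, parallel=True):
    required = ['amr', 'alignment', 'pos',
                'lemmatized_token' if use_lemma else 'token']
    if not parallel:
        required += ['ner', 'date']
    missing = [name for name in required
               if not os.path.exists(os.path.join(data_dir, name))]
    if missing:
        raise IOError('missing files in %s: %s' % (data_dir, ', '.join(missing)))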