Example #1
import numpy as np


def _map_senses(svsm, tokens, postags=None, lemmas=None, use_postag=False, use_lemma=False):
    """Given loaded LMMS and a list of tokens, returns a list of scored sensekeys."""

    matches = []

    # fall back to embedding-only matching when POS tags or lemmas are missing or mismatched
    if postags is None or len(tokens) != len(postags):
        use_postag = False

    if lemmas is None or len(tokens) != len(lemmas):
        use_lemma = False

    from lmms_api.bert_as_service import bert_embed
    # bert_embed returns, per sentence, a list of (token, vector) pairs
    sent_bert = bert_embed([" ".join(tokens)], merge_strategy="mean")[0]

    for idx in range(len(tokens)):
        idx_vec = sent_bert[idx][1]
        idx_vec = idx_vec / np.linalg.norm(idx_vec)  # unit-normalize for cosine matching

        if svsm.ndims == 1024:
            pass  # sense vectors already match BERT's dimensionality

        elif svsm.ndims == 1024 + 1024:
            # 2048-d sense vectors: duplicate the contextual vector and renormalize
            idx_vec = np.hstack((idx_vec, idx_vec))
            idx_vec = idx_vec / np.linalg.norm(idx_vec)

        # restrict candidate senses by lemma and/or POS only when that info is available
        lemma = lemmas[idx] if use_lemma else None
        postag = postags[idx] if use_postag else None
        idx_matches = svsm.match_senses(idx_vec, lemma, postag, topn=None)

        matches.append(idx_matches)

    return matches
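
A minimal usage sketch for _map_senses, assuming a SensesVSM instance loaded from LMMS vectors and a running bert-as-service backend; the import path and vectors path below are assumptions, not taken from the repository:

from lmms_api.senses_vsm import SensesVSM  # hypothetical import path

svsm = SensesVSM('data/vectors/lmms_1024.bert-large-cased.npz', normalize=True)  # placeholder path
tokens = ['The', 'bank', 'approved', 'the', 'loan']
matches = _map_senses(svsm, tokens, use_postag=False, use_lemma=False)
print(matches[1][:3])  # top-scored (sensekey, similarity) candidates for 'bank'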
Example #2
    # Iterate over evaluation instances and write predictions in the
    # WSD_Evaluation_Framework's format; the official scorer processes the
    # predictions file after all instances have been covered.
    eval_instances = load_wsd_fw_set(wsd_fw_set_path)

    results_path = 'data/results/%d.%s.%s.key' % (
        int(time()), args.test_set, args.merge_strategy)
    with open(results_path, 'w') as results_f:
        for batch_idx, batch in enumerate(
                chunks(eval_instances, args.batch_size)):
            batch_sents = [
                sent_info['tokenized_sentence'] for sent_info in batch
            ]

            # process contextual embeddings in sentences batches of size args.batch_size
            batch_bert = bert_embed(batch_sents,
                                    merge_strategy=args.merge_strategy)

            for sent_info, sent_bert in zip(batch, batch_bert):
                idx_map_abs = sent_info['idx_map_abs']

                for mw_idx, tok_idxs in idx_map_abs:
                    curr_sense = sent_info['senses'][mw_idx]

                    if curr_sense is None:
                        continue

                    curr_lemma = sent_info['lemmas'][mw_idx]

                    if args.use_lemma and curr_lemma not in senses_vsm.known_lemmas:
                        continue  # note: skipping these instances hurts the official scorer's results
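
Example #2 relies on a chunks helper that the excerpt does not show; a typical batching generator would look like the following sketch (the repository's own implementation may differ):

def chunks(seq, n):
    """Yield successive n-sized slices from seq."""
    for i in range(0, len(seq), n):
        yield seq[i:i + n]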
Example #3

    parser.add_argument('--eval_set', choices=['train', 'dev', 'test'])  # name inferred from args.eval_set usage below
    args = parser.parse_args()

    results_path = 'data/results/wic.compare.%s.txt' % args.eval_set

    logging.info('Loading SensesVSM ...')
    senses_vsm = SensesVSM(args.lmms_path, normalize=True)

    logging.info('Processing sentences ...')
    n_instances, n_correct = 0, 0
    with open(results_path, 'w') as results_f:  # store results in WiC's format
        for wic_idx, wic_entry in enumerate(
                load_wic(args.eval_set, wic_path='external/wic')):
            word, postag, idx1, idx2, ex1, ex2, gold = wic_entry

            bert_ex1, bert_ex2 = bert_embed([ex1, ex2], merge_strategy='mean')

            # example 1: contextual vector at the target word's position
            ex1_curr_word, ex1_curr_vector = bert_ex1[idx1]
            ex1_curr_lemma = wn_lemmatize(word, postag)
            ex1_curr_vector = ex1_curr_vector / np.linalg.norm(ex1_curr_vector)

            if senses_vsm.ndims == 1024:
                pass  # dimensions already match, use the contextual vector as-is

            elif senses_vsm.ndims == 1024 + 1024:
                # 2048-d sense vectors: duplicate the contextual vector to match
                ex1_curr_vector = np.hstack((ex1_curr_vector, ex1_curr_vector))

            ex1_curr_vector = ex1_curr_vector / np.linalg.norm(ex1_curr_vector)
            ex1_matches = senses_vsm.match_senses(ex1_curr_vector,
                                                  lemma=ex1_curr_lemma,
                                                  postag=postag, topn=None)
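
The excerpt ends before the WiC decision step; a minimal sketch of one plausible rule, assuming match_senses returns (sensekey, score) pairs sorted by descending score (this is an illustration, not the repository's exact logic):

def same_sense(ex1_matches, ex2_matches):
    """Predict True when both contexts rank the same sensekey first."""
    if not ex1_matches or not ex2_matches:
        return False
    return ex1_matches[0][0] == ex2_matches[0][0]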
Example #4
import logging
from time import time

import numpy as np

from lmms_api.bert_as_service import bert_embed  # as in Example #1

# bert_tokenizer, read_xml_sents, get_sense_mapping and args are assumed to be
# defined at module level in the source script.


def train(train_path,
          eval_path,
          vecs_path,
          merge_strategy='mean',
          max_seq_len=512,
          max_instances=float('inf')):

    sense_vecs = {}
    sense_mapping = get_sense_mapping(eval_path)

    batch, batch_idx, batch_t0 = [], 0, time()
    for sent_idx, sent_et in enumerate(read_xml_sents(train_path)):
        entry = {
            f: []
            for f in ['token', 'token_mw', 'lemma', 'senses', 'pos', 'id']
        }
        for ch in sent_et:  # direct iteration; Element.getchildren() was removed in Python 3.9
            for k, v in ch.items():
                entry[k].append(v)
            entry['token_mw'].append(ch.text)

            if 'id' in ch.attrib.keys():
                entry['senses'].append(sense_mapping[ch.attrib['id']])
            else:
                entry['senses'].append(None)

        # flatten multi-word tokens into individual tokens
        entry['token'] = sum([t.split() for t in entry['token_mw']], [])
        entry['sentence'] = ' '.join(entry['token_mw'])

        bert_tokens = bert_tokenizer.tokenize(entry['sentence'])
        if len(bert_tokens) < max_seq_len:
            batch.append(entry)

        if len(batch) == args.batch_size:  # note: a trailing partial batch is never flushed

            batch_sents = [e['sentence'] for e in batch]
            batch_bert = bert_embed(batch_sents, merge_strategy=merge_strategy)

            for sent_info, sent_bert in zip(batch, batch_bert):
                # handle multi-word expressions: map each mw feature index to its
                # absolute token positions, e.g. token_mw = ['the', 'river bank']
                # -> idx_map_abs = [[0, [0]], [1, [1, 2]]]
                idx_map_abs = []
                idx_map_rel = [(i, list(range(len(t.split()))))
                               for i, t in enumerate(sent_info['token_mw'])]
                token_counter = 0
                for idx_group, idx_tokens in idx_map_rel:  # relative -> absolute positions
                    idx_tokens = [i + token_counter for i in idx_tokens]
                    token_counter += len(idx_tokens)
                    idx_map_abs.append([idx_group, idx_tokens])

                for mw_idx, tok_idxs in idx_map_abs:
                    if sent_info['senses'][mw_idx] is None:
                        continue

                    vec = np.array([sent_bert[i][1] for i in tok_idxs],
                                   dtype=np.float32).mean(axis=0)  # average across mw tokens

                    for sense in sent_info['senses'][mw_idx]:
                        try:
                            if sense_vecs[sense]['vecs_num'] < max_instances:
                                sense_vecs[sense]['vecs_sum'] += vec
                                sense_vecs[sense]['vecs_num'] += 1
                        except KeyError:
                            sense_vecs[sense] = {
                                'vecs_sum': vec,
                                'vecs_num': 1
                            }

            batch_tspan = time() - batch_t0
            logging.info(
                '%.3f sents/sec - %d sents, %d senses' %
                (args.batch_size / batch_tspan, sent_idx, len(sense_vecs)))

            batch, batch_t0 = [], time()
            batch_idx += 1

    logging.info('#sents: %d' % (sent_idx + 1))

    logging.info('Writing Sense Vectors ...')
    with open(vecs_path, 'w') as vecs_f:
        for sense, vecs_info in sense_vecs.items():
            vec = vecs_info['vecs_sum'] / vecs_info['vecs_num']
            vec_str = ' '.join([str(round(v, 6)) for v in vec.tolist()])
            vecs_f.write('%s %s\n' % (sense, vec_str))
    logging.info('Written %s' % vecs_path)
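
The file written above is plain text, one sense per line: a sensekey followed by its space-separated vector components. A minimal loader sketch for that format (the function name is hypothetical):

import numpy as np

def load_sense_vecs(vecs_path):
    """Parse 'sensekey v1 v2 ...' lines back into a dict of numpy vectors."""
    sense_vecs = {}
    with open(vecs_path) as vecs_f:
        for line in vecs_f:
            fields = line.rstrip().split(' ')
            sense_vecs[fields[0]] = np.array(fields[1:], dtype=np.float32)
    return sense_vecs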