def _map_senses(svsm, tokens, postags=None, lemmas=None, use_postag=False, use_lemma=False):
    """Given loaded LMMS and a list of tokens, returns a list of scored sensekeys.

    The tokens are joined with single spaces and embedded as one sentence;
    every token position is then matched against the sense inventory.

    Args:
        svsm: Loaded sense vector space. Must expose ``ndims`` and
            ``match_senses(vec, lemma, postag, topn=...)``.
        tokens: Tokens of a single sentence.
        postags: Optional POS tags, one per token. Ignored (with
            ``use_postag`` forced off) when its length differs from
            ``tokens``.
        lemmas: Optional lemmas, one per token. Ignored (with ``use_lemma``
            forced off) when its length differs from ``tokens``.
        use_postag: Restrict matching by the POS tag of each token.
        use_lemma: Restrict matching by the lemma of each token.

    Returns:
        A list with one entry per token, each being whatever
        ``svsm.match_senses`` returned for that token.
    """
    # ``None`` means "not provided"; this avoids sharing one mutable default
    # list object between calls.
    postags = [] if postags is None else postags
    lemmas = [] if lemmas is None else lemmas

    # A constraint is only usable when aligned one-to-one with the tokens.
    if len(tokens) != len(postags):
        use_postag = False
    if len(tokens) != len(lemmas):
        use_lemma = False

    # Function-scope import, kept from the original implementation.
    from lmms_api.bert_as_service import bert_embed

    sent_bert = bert_embed([" ".join(tokens)], merge_strategy="mean")[0]

    matches = []
    for idx in range(len(tokens)):
        # Position 1 of each embedding entry holds the token vector.
        idx_vec = sent_bert[idx][1]
        idx_vec = idx_vec / np.linalg.norm(idx_vec)

        if svsm.ndims == 1024 + 1024:
            # Double-width sense vectors: repeat the contextual vector so the
            # dimensions line up, then restore unit length.
            idx_vec = np.hstack((idx_vec, idx_vec))
            idx_vec = idx_vec / np.linalg.norm(idx_vec)

        lemma = lemmas[idx] if use_lemma else None
        postag = postags[idx] if use_postag else None
        matches.append(svsm.match_senses(idx_vec, lemma, postag, topn=None))

    return matches
eval_instances = load_wsd_fw_set(wsd_fw_set_path) """ Iterate over evaluation instances and write predictions in WSD_Evaluation_Framework's format. File with predictions is processed by the official scorer after iterating over all instances. """ results_path = 'data/results/%d.%s.%s.key' % (int( time()), args.test_set, args.merge_strategy) with open(results_path, 'w') as results_f: for batch_idx, batch in enumerate( chunks(eval_instances, args.batch_size)): batch_sents = [ sent_info['tokenized_sentence'] for sent_info in batch ] # process contextual embeddings in sentences batches of size args.batch_size batch_bert = bert_embed(batch_sents, merge_strategy=args.merge_strategy) for sent_info, sent_bert in zip(batch, batch_bert): idx_map_abs = sent_info['idx_map_abs'] for mw_idx, tok_idxs in idx_map_abs: curr_sense = sent_info['senses'][mw_idx] if curr_sense is None: continue curr_lemma = sent_info['lemmas'][mw_idx] if args.use_lemma and curr_lemma not in senses_vsm.known_lemmas: continue # skips hurt performance in official scorer
choices=['train', 'dev', 'test']) args = parser.parse_args() results_path = 'data/results/wic.compare.%s.txt' % args.eval_set logging.info('Loading SensesVSM ...') senses_vsm = SensesVSM(args.lmms_path, normalize=True) logging.info('Processing sentences ...') n_instances, n_correct = 0, 0 with open(results_path, 'w') as results_f: # store results in WiC's format for wic_idx, wic_entry in enumerate( load_wic(args.eval_set, wic_path='external/wic')): word, postag, idx1, idx2, ex1, ex2, gold = wic_entry bert_ex1, bert_ex2 = bert_embed([ex1, ex2], merge_strategy='mean') # example1 ex1_curr_word, ex1_curr_vector = bert_ex1[idx1] ex1_curr_lemma = wn_lemmatize(word, postag) ex1_curr_vector = ex1_curr_vector / np.linalg.norm(ex1_curr_vector) if senses_vsm.ndims == 1024: ex1_curr_vector = ex1_curr_vector elif senses_vsm.ndims == 1024 + 1024: ex1_curr_vector = np.hstack((ex1_curr_vector, ex1_curr_vector)) ex1_curr_vector = ex1_curr_vector / np.linalg.norm(ex1_curr_vector) ex1_matches = senses_vsm.match_senses(ex1_curr_vector, lemma=ex1_curr_lemma,
def train(train_path, eval_path, vecs_path, merge_strategy='mean', max_seq_len=512, max_instances=float('inf')):
    """Build averaged sense vectors from sense-annotated sentences.

    Sentences are read with ``read_xml_sents(train_path)``, embedded in
    batches of ``args.batch_size`` with ``bert_embed``, and the vector of
    every sense-annotated (possibly multi-word) token is accumulated per
    sense. The per-sense average is finally written to ``vecs_path``, one
    sense per line as ``<sense> <v1> <v2> ...`` with values rounded to six
    decimals.

    Args:
        train_path: Path handed to ``read_xml_sents``.
        eval_path: Path handed to ``get_sense_mapping``; the resulting
            mapping is looked up by the ``id`` attribute of each token.
        vecs_path: Output text file for the sense vectors.
        merge_strategy: Forwarded to ``bert_embed``.
        max_seq_len: Sentences whose BERT tokenization has this many tokens
            or more are skipped.
        max_instances: Maximum number of vectors accumulated per sense.
    """
    sense_vecs = {}
    sense_mapping = get_sense_mapping(eval_path)
    batch_size = args.batch_size  # read from the module-level CLI arguments

    batch, batch_t0 = [], time()
    sent_idx = -1  # keeps the final log line valid when no sentence is read
    for sent_idx, sent_et in enumerate(read_xml_sents(train_path)):
        entry = _read_sentence_entry(sent_et, sense_mapping)

        bert_tokens = bert_tokenizer.tokenize(entry['sentence'])
        if len(bert_tokens) < max_seq_len:
            batch.append(entry)

        if len(batch) == batch_size:
            _process_batch(batch, batch_t0, sent_idx, sense_vecs,
                           merge_strategy, max_instances)
            batch, batch_t0 = [], time()

    # Flush the trailing partial batch; without this the last
    # (< batch_size) sentences would never contribute to the sense vectors.
    if batch:
        _process_batch(batch, batch_t0, sent_idx, sense_vecs,
                       merge_strategy, max_instances)

    logging.info('#sents: %d' % sent_idx)

    logging.info('Writing Sense Vectors ...')
    with open(vecs_path, 'w') as vecs_f:
        for sense, vecs_info in sense_vecs.items():
            vec = vecs_info['vecs_sum'] / vecs_info['vecs_num']
            vec_str = ' '.join([str(round(v, 6)) for v in vec.tolist()])
            vecs_f.write('%s %s\n' % (sense, vec_str))
    logging.info('Written %s' % vecs_path)


def _read_sentence_entry(sent_et, sense_mapping):
    """Collect the per-token fields of one sentence element into a dict."""
    entry = {
        f: []
        for f in ['token', 'token_mw', 'lemma', 'senses', 'pos', 'id']
    }
    # Iterating an element yields its children; ``getchildren()`` was removed
    # from xml.etree.ElementTree in Python 3.9.
    for ch in sent_et:
        for k, v in ch.items():
            entry[k].append(v)
        entry['token_mw'].append(ch.text)

        if 'id' in ch.attrib:
            entry['senses'].append(sense_mapping[ch.attrib['id']])
        else:
            entry['senses'].append(None)

    # A multi-word entry contributes one token per whitespace-separated part.
    entry['token'] = [tok for t in entry['token_mw'] for tok in t.split()]
    entry['sentence'] = ' '.join(entry['token_mw'])
    return entry


def _process_batch(batch, batch_t0, sent_idx, sense_vecs, merge_strategy, max_instances):
    """Embed one batch, fold it into ``sense_vecs`` and log the throughput."""
    _accumulate_batch(batch, sense_vecs, merge_strategy, max_instances)
    batch_tspan = time() - batch_t0
    logging.info('%.3f sents/sec - %d sents, %d senses' %
                 (len(batch) / batch_tspan, sent_idx, len(sense_vecs)))


def _accumulate_batch(batch, sense_vecs, merge_strategy, max_instances):
    """Embed the batch sentences and add annotated token vectors per sense."""
    batch_sents = [e['sentence'] for e in batch]
    batch_bert = bert_embed(batch_sents, merge_strategy=merge_strategy)

    for sent_info, sent_bert in zip(batch, batch_bert):
        # Handling multi-word expressions: each entry of 'token_mw' covers as
        # many consecutive token positions as it has whitespace-separated
        # parts, so a running offset maps entries to absolute positions.
        token_counter = 0
        for mw_idx, mw_text in enumerate(sent_info['token_mw']):
            n_parts = len(mw_text.split())
            tok_idxs = range(token_counter, token_counter + n_parts)
            token_counter += n_parts

            senses = sent_info['senses'][mw_idx]
            if senses is None:
                continue

            # Position 1 of each embedding entry holds the token vector; a
            # multi-word expression is represented by the mean of its parts.
            vec = np.array([sent_bert[i][1] for i in tok_idxs],
                           dtype=np.float32).mean(axis=0)

            for sense in senses:
                sense_info = sense_vecs.get(sense)
                if sense_info is None:
                    # Copy so that several new senses of the same token do
                    # not share one array that later in-place additions
                    # would corrupt for all of them.
                    sense_vecs[sense] = {'vecs_sum': vec.copy(), 'vecs_num': 1}
                elif sense_info['vecs_num'] < max_instances:
                    sense_info['vecs_sum'] += vec
                    sense_info['vecs_num'] += 1