def gen_tf_records_files(vocab_file, category_map_file, mentions_file, valid_idxs_file, tok_sents_file,
                         max_seq_len, train_output_file, valid_output_file):
    """Encode tokenized sentences as TFRecord examples, split into train/valid files.

    Line i of tok_sents_file corresponds to mentions[i]; examples whose index is in
    valid_idxs_file go to valid_output_file, all others to train_output_file.
    """
    tokenizer = tokenization.SpaceTokenizer(vocab_file)
    category_map_dict, category_id_dict = datautils.load_category_mapping(category_map_file)
    mentions = datautils.read_json_objs(mentions_file)
    # valid_idxs_file holds a single JSON list of example indices for the validation split
    valid_idxs = set(datautils.read_json_objs(valid_idxs_file)[0])
    # context managers so the writers and input file are closed (and records flushed)
    # even if encoding one example raises mid-stream
    with tf.python_io.TFRecordWriter(train_output_file) as writer_train, \
            tf.python_io.TFRecordWriter(valid_output_file) as writer_valid, \
            open(tok_sents_file, encoding='utf-8') as f:
        for i, tokens_str in enumerate(f):
            mention = mentions[i]
            token_span = mention['token_span']
            categories = mention['tcategory']
            y_categories = onehot_encode(categories, category_id_dict)
            input_ids, input_mask, segment_ids, tokens = get_tfrec_example(
                tokens_str.lower(), max_seq_len, tokenizer)
            features = __get_feature_dict(
                input_ids, input_mask, segment_ids, len(tokens), token_span, y_categories)
            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            if i in valid_idxs:
                writer_valid.write(tf_example.SerializeToString())
            else:
                writer_train.write(tf_example.SerializeToString())
def load_hyp_preds(hypext_file, verif_hypext_file, hyp_verif_logits_file, use_verif_logits=True):
    """Load hypernym-extraction predictions paired with their verification logits.

    Returns {mention_id: (extraction_result_obj, verification_logit)}. When
    use_verif_logits is True, only mentions with a positive verification logit
    are kept.
    """
    from utils import datautils
    extraction_objs = datautils.read_json_objs(hypext_file)
    verification_objs = datautils.read_json_objs(verif_hypext_file)
    with open(hyp_verif_logits_file, encoding='utf-8') as fin:
        logits = [float(ln.strip()) for ln in fin]
    assert len(verification_objs) == len(logits)
    if len(extraction_objs) != len(verification_objs):
        print('len(hypext_results) != len(verif_hypext_results)')

    # mention_id -> (index into logits, verification object)
    verif_by_mid = {obj['mention_id']: (idx, obj) for idx, obj in enumerate(verification_objs)}

    results = dict()
    for ext_obj in extraction_objs:
        mid = ext_obj['mention_id']
        entry = verif_by_mid.get(mid)
        if entry is None:
            continue
        logit_idx, _ = entry
        if use_verif_logits and not logits[logit_idx] > 0:
            continue
        results[mid] = (ext_obj, logits[logit_idx])
    return results
def samples_from_man_labeled(token_id_dict, unknown_token_id, type_id_dict, mentions_file, sents_file,
                             dep_tags_file, srl_results_file, man_label_file):
    """Build SRL-augmented samples for the mentions that have manual type labels.

    Only mentions whose id appears in man_label_file are kept; for each such
    mention, one sample is produced per SRL frame matched to the mention span.
    Each sample is a tuple:
    (mention_id, mention_str, span_beg, span_end, None, type_ids,
     sent_token_ids, srl_info)
    where srl_info = (V span, ARG0 span, ARG1 span, ARG2 span, matched arg position).
    """
    # labeled_samples entries are indexable: x[0] is the mention id, x[2] the labels
    labeled_samples = datautils.real_manual_label_file(man_label_file, None)
    mentions = datautils.read_json_objs(mentions_file)
    sents = datautils.read_json_objs(sents_file)
    sent_dict = {sent['sent_id']: (i, sent) for i, sent in enumerate(sents)}
    dep_tag_seq_list = None
    if dep_tags_file is not None:
        # one dependency-tag sequence per sentence, in the same order as sents
        with open(dep_tags_file, encoding='utf-8') as f:
            dep_tag_seq_list = [
                datautils.next_sent_dependency(f) for _ in range(len(sents))
            ]
    # assumes srl_results_list is parallel to sents — TODO confirm (the print below
    # looks like a sanity check of exactly that)
    srl_results_list = datautils.read_srl_results(srl_results_file)
    print(len(sents), len(srl_results_list))
    mid_manual_label_dict = {x[0]: x[2] for x in labeled_samples}
    samples = list()
    for m in mentions:
        manual_labels = mid_manual_label_dict.get(m['mention_id'], None)
        if manual_labels is None:
            continue  # mention was not manually labeled
        mspan = m['span']
        sent_idx, sent = sent_dict[m['sent_id']]
        sent_tokens = sent['text'].split(' ')
        dep_tag_seq = dep_tag_seq_list[
            sent_idx] if dep_tag_seq_list is not None else None
        srl_results = srl_results_list[sent_idx]
        matched_tag_list, matched_tag_spans_list = utils.match_srl_to_mentions_all(
            sent_tokens, srl_results, mspan, dep_tag_seq)
        if not matched_tag_list:
            continue  # no SRL frame covers this mention
        type_ids = [type_id_dict[t] for t in manual_labels]
        for matched_tag, matched_tag_spans in zip(matched_tag_list, matched_tag_spans_list):
            # last character of the matched tag encodes the argument position (e.g. 'ARG1' -> 1)
            matched_tag_pos = int(matched_tag[-1:])
            srl_info = (utils.get_srl_tag_span(matched_tag_spans, 'V'),
                        utils.get_srl_tag_span(matched_tag_spans, 'ARG0'),
                        utils.get_srl_tag_span(matched_tag_spans, 'ARG1'),
                        utils.get_srl_tag_span(matched_tag_spans, 'ARG2'),
                        matched_tag_pos)
            sent_token_ids = [
                token_id_dict.get(token, unknown_token_id)
                for token in sent_tokens
            ]
            sample = (m['mention_id'], m['str'], mspan[0], mspan[1], None,
                      type_ids, sent_token_ids, srl_info)
            samples.append(sample)
    return samples
def __init__(self, base_preds_file, srl_preds_file, hypext_file, verif_hypext_file, hypext_logits_file):
    """Load base, SRL, and hypernym prediction results, each indexed by mention_id."""
    self.base_preds_dict = {
        obj['mention_id']: obj for obj in datautils.read_json_objs(base_preds_file)}
    self.srl_preds_dict = get_srl_pred_dict(datautils.read_json_objs(srl_preds_file))
    self.hyp_preds_dict = load_hyp_preds(hypext_file, verif_hypext_file, hypext_logits_file)
def gen_mrel_title_to_wid_file(cands_files, output_file):
    """Collect title -> Wikipedia id pairs from candidate files and save them as CSV.

    Later files overwrite earlier ones when the same title appears more than once.
    """
    title_to_wid = dict()
    for path in cands_files:
        for cand_obj in datautils.read_json_objs(path):
            # (title, wid) pairs; update keeps the last occurrence of each title
            title_to_wid.update(zip(cand_obj['titles'], cand_obj['wids']))
    datautils.save_csv(list(title_to_wid.items()), ['title', 'wid'], output_file)
def samples_from_test(gres: ResData, child_type_vecs, test_file_tup):
    """Assemble ensemble-input samples for the test set.

    Returns (samples, true_labels_dict), where each sample is
    (mention_id, base_logits, srl_logits, hyp_logits, hyp_verif_logit, label_ids)
    and true_labels_dict maps mention_id to its gold label strings.
    """
    (mentions_file, sents_file, base_preds_file, srl_preds_file, hypext_file,
     verif_hypext_file, hypext_logits_file) = test_file_tup
    pred_collect = PredResultCollect(
        base_preds_file, srl_preds_file, hypext_file, verif_hypext_file, hypext_logits_file)
    mentions = datautils.read_json_objs(mentions_file)
    true_labels_dict = {m['mention_id']: m['labels'] for m in mentions}
    sents = datautils.read_json_objs(sents_file)  # loaded for completeness; not used below
    samples = []
    for mention in mentions:
        mid = mention['mention_id']
        base_logits, srl_logits, hyp_logits, hyp_verif_logit = get_pred_results(
            pred_collect, gres.n_types, gres.type_id_dict, child_type_vecs, mid)
        label_ids = [gres.type_id_dict[t] for t in mention['labels']]
        samples.append((mid, base_logits, srl_logits, hyp_logits, hyp_verif_logit, label_ids))
    return samples, true_labels_dict
def model_samples_from_json(token_id_dict, unknown_token_id, mention_token_id, type_id_dict,
                            mentions_file, sents_file):
    """Build one model sample per mention read from mentions_file.

    Sentence texts are pre-converted to token-id sequences keyed by sent_id.
    (type_id_dict is accepted for signature compatibility but not used here.)
    """
    sent_tokens_dict = datautils.read_sents_to_token_id_seq_dict(
        sents_file, token_id_dict, unknown_token_id)
    return [
        get_model_sample(m['mention_id'],
                         mention_str=m['str'],
                         mention_span=m['span'],
                         sent_tokens=sent_tokens_dict[m['sent_id']],
                         mention_token_id=mention_token_id)
        for m in datautils.read_json_objs(mentions_file)
    ]
def model_samples_from_json(config, token_id_dict, unknown_token_id, type_id_dict, mentions_file, sents_file):
    """Build labeled model samples, optionally BERT-tokenized.

    Each sample is [mention_id, mention_token_ids, sentence_token_ids,
    mention_token_idx, label_ids]. When config.use_bert is set, the mention span
    is replaced by a single '[MASK]' token, the sentence is re-tokenized with the
    BERT word-piece tokenizer, and mention_token_idx points at that mask.
    """
    if config.use_bert:
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
        print('bert tokenizer loaded')
    # sent_id -> token-id sequence and sent_id -> raw token list, built in one pass
    sent_tokens_id_dict = dict()
    sent_tokens_dict = dict()
    with open(sents_file, encoding='utf-8') as f:
        for line in f:
            sent = json.loads(line)
            tokens = sent['text'].split(' ')
            sent_tokens_id_dict[sent['sent_id']] = [token_id_dict.get(t, unknown_token_id) for t in tokens]
            sent_tokens_dict[sent['sent_id']] = [t for t in tokens]
    samples = list()
    mentions = datautils.read_json_objs(mentions_file)
    for m in mentions:
        if config.use_bert:
            # splice out the mention span and put a single [MASK] in its place
            org_tok_sents = sent_tokens_dict[m['sent_id']]
            bert_sent_tokens = org_tok_sents[:m['span'][0]] + ['[MASK]'] + org_tok_sents[m['span'][1]:]
            full_sent = ' '.join(bert_sent_tokens)
            tokens = ["[CLS]"]
            t = tokenizer.tokenize(full_sent)
            tokens.extend(t)
            # locate the mask after word-piece tokenization; assumes the tokenizer
            # keeps '[MASK]' as a single token — TODO confirm for this tokenizer config
            mention_token_idx = 0
            for i, x in enumerate(tokens):
                if x == '[MASK]':
                    mention_token_idx = i
                    break
            tokens.append("[SEP]")
            sentence_token = tokenizer.convert_tokens_to_ids(tokens)
        else:
            # non-BERT path: use the precomputed vocab ids and the raw span start
            sentence_token = sent_tokens_id_dict[m['sent_id']]
            mention_token_idx = m['span'][0]
        labels = m['labels']
        label_ids = [type_id_dict[t] for t in labels]
        sample = [m['mention_id'],
                  # vocab ids of the mention's own tokens (always from the non-BERT vocab)
                  sent_tokens_id_dict[m['sent_id']][m['span'][0]:m['span'][1]],
                  sentence_token,
                  mention_token_idx,
                  label_ids
                  ]
        samples.append(sample)
    return samples
def check_retrieved_sents():
    """Dump retrieved WebIsA sentences next to each ultra-fine test sample for inspection.

    For every retrieval-result line, writes the sample string, its gold y_str, and
    each retrieved block's sentence and label to output_file.
    """
    import tensorflow as tf
    output_file = os.path.join(config.DATA_DIR, 'tmp/uf_wia_results_200_ins.txt')
    results_file = os.path.join(config.DATA_DIR, 'realm_output/uf_wia_results_200.txt')
    samples_file = os.path.join(config.DATA_DIR, 'ultrafine/uf_data/crowd/test.json')
    wia_block_records_path = os.path.join(config.DATA_DIR, 'ultrafine/zoutput/webisa_full_uffilter.tfr')
    wia_block_labels_file = os.path.join(config.DATA_DIR, 'ultrafine/zoutput/webisa_full_uffilter_labels.txt')

    samples = datautils.read_json_objs(samples_file)
    blocks_dataset = tf.data.TFRecordDataset(
        wia_block_records_path, buffer_size=512 * 1024 * 1024)
    # materialize all block sentences so they can be indexed by block id below
    sents = list()
    for i, sent in enumerate(blocks_dataset):
        sents.append(sent.numpy().decode('utf-8'))
        if i % 500000 == 0:
            print(i)  # progress indicator for the long TFRecord scan
    with open(wia_block_labels_file, encoding='utf-8') as f:
        wia_labels = [line.strip() for line in f]

    # context managers ensure the output is flushed/closed even if a line fails to parse
    with open(output_file, 'w', encoding='utf-8') as fout, \
            open(results_file, encoding='utf-8') as f:
        for i, line in enumerate(f):
            x = json.loads(line)
            bids = x['block_ids']
            uf_sample_str = get_uf_sample_str(samples[i])
            fout.write('{}\n{}\n'.format(uf_sample_str, samples[i]['y_str']))
            for bid in bids:
                fout.write('{}\n'.format(sents[bid]))
                fout.write('{}\n'.format(wia_labels[bid]))
            fout.write('\n')
def train_fetel(device, gres: exputils.GlobalRes, el_entityvec: ELDirectEntityVec, train_samples_pkl,
                dev_samples_pkl, test_mentions_file, test_sents_file, test_noel_preds_file, type_embed_dim,
                context_lstm_hidden_dim, learning_rate, batch_size, n_iter, dropout, rand_per, per_penalty,
                use_mlp=False, pred_mlp_hdim=None, save_model_file=None, nil_rate=0.5, single_type_path=False,
                stack_lstm=False, concat_lstm=False, results_file=None):
    """Train a FETEL entity-typing model and periodically evaluate on dev and test.

    Trains for n_iter passes over train_samples_pkl in mini-batches of batch_size,
    evaluating every 1000 steps; the model (and optionally result files) are saved
    whenever dev accuracy improves.
    """
    logging.info('result_file={}'.format(results_file))
    logging.info(
        'type_embed_dim={} cxt_lstm_hidden_dim={} pmlp_hdim={} nil_rate={} single_type_path={}'.format(
            type_embed_dim, context_lstm_hidden_dim, pred_mlp_hdim, nil_rate, single_type_path))
    logging.info('rand_per={} per_pen={}'.format(rand_per, per_penalty))
    logging.info('stack_lstm={} cat_lstm={}'.format(stack_lstm, concat_lstm))
    if stack_lstm:
        model = FETELStack(
            device, gres.type_vocab, gres.type_id_dict, gres.embedding_layer, context_lstm_hidden_dim,
            type_embed_dim=type_embed_dim, dropout=dropout, use_mlp=use_mlp, mlp_hidden_dim=pred_mlp_hdim,
            concat_lstm=concat_lstm)
    else:
        # NOTE(review): only the stacked-LSTM variant is implemented here; with
        # stack_lstm=False the .cuda()/training calls below would fail on None
        model = None
    if device.type == 'cuda':
        model = model.cuda(device.index)
    train_samples = datautils.load_pickle_data(train_samples_pkl)
    dev_samples = datautils.load_pickle_data(dev_samples_pkl)
    dev_samples = anchor_samples_to_model_samples(dev_samples, gres.mention_token_id,
                                                  gres.parent_type_ids_dict)
    lr_gamma = 0.7
    eval_batch_size = 32
    logging.info('{}'.format(model.__class__.__name__))
    dev_true_labels_dict = {s.mention_id: [gres.type_vocab[l] for l in s.labels] for s in dev_samples}
    dev_entity_vecs, dev_el_sgns, dev_el_probs = __get_entity_vecs_for_samples(el_entityvec, dev_samples, None)
    test_samples = model_samples_from_json(gres.token_id_dict, gres.unknown_token_id, gres.mention_token_id,
                                           gres.type_id_dict, test_mentions_file, test_sents_file)
    test_noel_pred_results = datautils.read_pred_results_file(test_noel_preds_file)
    test_mentions = datautils.read_json_objs(test_mentions_file)
    test_entity_vecs, test_el_sgns, test_el_probs = __get_entity_vecs_for_mentions(
        el_entityvec, test_mentions, test_noel_pred_results, gres.n_types)
    # test labels may be absent (unlabeled test set); checked on the first mention only
    test_true_labels_dict = {m['mention_id']: m['labels'] for m in test_mentions} if (
        'labels' in next(iter(test_mentions))) else None
    person_type_id = gres.type_id_dict.get('/person')
    l2_person_type_ids, person_loss_vec = None, None
    if person_type_id is not None:
        l2_person_type_ids = __get_l2_person_type_ids(gres.type_vocab)
        person_loss_vec = exputils.get_person_type_loss_vec(
            l2_person_type_ids, gres.n_types, per_penalty, model.device)
    dev_results_file = None  # dev-result saving is disabled (kept for optional debugging)
    n_batches = (len(train_samples) + batch_size - 1) // batch_size
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # decay lr by lr_gamma once per full pass over the training data
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=n_batches, gamma=lr_gamma)
    losses = list()
    best_dev_acc = -1
    logging.info('{} steps, {} steps per iter, lr_decay={}, start training ...'.format(
        n_iter * n_batches, n_batches, lr_gamma))
    step = 0
    n_steps = n_iter * n_batches
    while step < n_steps:
        batch_idx = step % n_batches
        batch_beg, batch_end = batch_idx * batch_size, min((batch_idx + 1) * batch_size, len(train_samples))
        batch_samples = anchor_samples_to_model_samples(
            train_samples[batch_beg:batch_end], gres.mention_token_id, gres.parent_type_ids_dict)
        if rand_per:
            # randomly perturb person-typed entities (rand_per training trick)
            entity_vecs, el_sgns, el_probs = __get_entity_vecs_for_samples(
                el_entityvec, batch_samples, None, True, person_type_id, l2_person_type_ids, gres.type_vocab)
        else:
            entity_vecs, el_sgns, el_probs = __get_entity_vecs_for_samples(el_entityvec, batch_samples,
                                                                           None, True)
        use_entity_vecs = True
        model.train()
        (context_token_seqs, mention_token_idxs, mstrs, mstr_token_seqs, y_true
         ) = exputils.get_mstr_cxt_label_batch_input(model.device, gres.n_types, batch_samples)
        if use_entity_vecs:
            # zero out entity vectors at random (nil_rate) so the model learns to
            # cope with missing/NIL entity-linking information
            for i in range(entity_vecs.shape[0]):
                if np.random.uniform() < nil_rate:
                    entity_vecs[i] = np.zeros(entity_vecs.shape[1], np.float32)
            el_probs = torch.tensor(el_probs, dtype=torch.float32, device=model.device)
            entity_vecs = torch.tensor(entity_vecs, dtype=torch.float32, device=model.device)
        else:
            entity_vecs = None
        logits = model(context_token_seqs, mention_token_idxs, mstr_token_seqs, entity_vecs, el_probs)
        loss = model.get_loss(y_true, logits, person_loss_vec=person_loss_vec)
        # NOTE(review): scheduler.step() is called before optimizer.step(); recent
        # PyTorch expects the opposite order — confirm intended lr schedule
        scheduler.step()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0, float('inf'))
        optimizer.step()
        losses.append(loss.data.cpu().numpy())
        step += 1
        if step % 1000 == 0:
            # evaluate on dev and test every 1000 steps
            acc_v, pacc_v, _, _, dev_results = eval_fetel(
                gres, model, dev_samples, dev_entity_vecs, dev_el_probs, eval_batch_size,
                use_entity_vecs=use_entity_vecs, single_type_path=single_type_path,
                true_labels_dict=dev_true_labels_dict)
            acc_t, _, maf1, mif1, test_results = eval_fetel(
                gres, model, test_samples, test_entity_vecs, test_el_probs, eval_batch_size,
                use_entity_vecs=use_entity_vecs, single_type_path=single_type_path,
                true_labels_dict=test_true_labels_dict)
            # best_tag is computed before best_dev_acc is updated, so '*' marks a new best
            best_tag = '*' if acc_v > best_dev_acc else ''
            logging.info(
                'i={} l={:.4f} accv={:.4f} paccv={:.4f} acct={:.4f} maf1={:.4f} mif1={:.4f}{}'.format(
                    step, sum(losses), acc_v, pacc_v, acc_t, maf1, mif1, best_tag))
            if acc_v > best_dev_acc and save_model_file:
                torch.save(model.state_dict(), save_model_file)
                logging.info('model saved to {}'.format(save_model_file))
            if dev_results_file is not None and acc_v > best_dev_acc:
                datautils.save_json_objs(dev_results, dev_results_file)
                logging.info('dev reuslts saved {}'.format(dev_results_file))
            if results_file is not None and acc_v > best_dev_acc:
                datautils.save_json_objs(test_results, results_file)
                logging.info('test reuslts saved {}'.format(results_file))
            if acc_v > best_dev_acc:
                best_dev_acc = acc_v
            # reset the windowed loss accumulator used in the log line above
            losses = list()