def test(parser, vocab, num_buckets_test, test_batch_size, test_file, output_file):
    data_loader = DataLoader(test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, tags, arcs, rels in data_loader.get_batches(batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        outputs = parser.run(words, tags, isTrain=False)
        for output in outputs:
            # map batched outputs back to the original sentence order
            sent_idx = record[idx]
            results[sent_idx] = output
            idx += 1

    arcs = reduce(lambda x, y: x + y, [list(result[0]) for result in results])
    rels = reduce(lambda x, y: x + y, [list(result[1]) for result in results])

    # write the predicted heads (column 7) and relations (column 8) back into the CoNLL file
    idx = 0
    with open(test_file) as f:
        with open(output_file, 'w') as fo:
            for line in f.readlines():
                info = line.strip().split()
                if info:
                    assert len(info) == 10, 'Illegal line: %s' % line
                    info[6] = str(arcs[idx])
                    info[7] = vocab.id2rel(rels[idx])
                    fo.write('\t'.join(info) + '\n')
                    idx += 1
                else:
                    fo.write('\n')

    # score with eval.pl and read LAS/UAS from the tail of its summary
    os.system('perl eval.pl -q -b -g %s -s %s -o tmp' % (test_file, output_file))
    os.system('tail -n 3 tmp > score_tmp')
    LAS, UAS = [float(line.strip().split()[-2]) for line in open('score_tmp').readlines()[:2]]
    print('LAS %.2f, UAS %.2f' % (LAS, UAS))
    os.system('rm tmp score_tmp')
    return LAS, UAS
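# Illustrative usage sketch (not part of the original file): restoring a trained model and
# scoring a test set.  The field names load_model_path, num_buckets_test, test_batch_size and
# test_file are assumptions that mirror the train-side Configurable fields used in the scripts
# below.
#     parser.parameter_collection.populate(config.load_model_path)   # DyNet weight restore
#     LAS, UAS = test(parser, vocab, config.num_buckets_test, config.test_batch_size,
#                     config.test_file, os.path.join(config.save_dir, 'test_output.conllu'))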
argparser.add_argument('--config_file', default='./config.cfg')
argparser.add_argument('--model', default='BaseParser')
args, extra_args = argparser.parse_known_args()
config = Configurable(args.config_file, extra_args)
Parser = getattr(models, args.model)

vocab = Vocab(config.train_file, config.pretrained_embeddings_file, config.min_occur_count)
cPickle.dump(vocab, open(config.save_vocab_path, 'w'))
parser = Parser(vocab, config.word_dims, config.pret_dims, config.lemma_dims, config.flag_dims,
                config.tag_dims, config.dropout_emb, config.encoder_type, config.use_si_dropout,
                config.lstm_layers, config.lstm_hiddens, config.dropout_lstm_input,
                config.dropout_lstm_hidden, config.mlp_rel_size, config.dropout_mlp,
                config.transformer_layers, config.transformer_heads, config.transformer_hiddens,
                config.transformer_ffn, config.transformer_dropout, config.transformer_maxlen,
                config.transformer_max_timescale, config.use_lm, config.lm_path, config.lm_dims,
                config.lm_hidden_size, config.lm_sentences, config.use_pos, config.use_lemma,
                config.unified)
data_loader = DataLoader(config.train_file, config.num_buckets_train, vocab)
pc = parser.parameter_collection
trainer = dy.AdamTrainer(pc, config.learning_rate, config.beta_1, config.beta_2, config.epsilon)
# optimizer = OptimizerManager(parser, lr_scheduler_type=config.lr_scheduler_type, optim=config.optim,
#                              warmup_steps=config.warmup_steps, eta=config.eta,
#                              patience=config.patience, clip=config.clip, lr=config.learning_rate,
#                              beta1=config.beta_1, beta2=config.beta_2, epsilon=config.epsilon)

global_step = 0

def update_parameters():
    trainer.learning_rate = config.learning_rate * config.decay ** (global_step / config.decay_steps)
    trainer.update()

epoch = 0
best_F1 = 0.
history = lambda x, y: open(os.path.join(config.save_dir, 'valid_history'), 'a').write('%.2f %.2f\n' % (x, y))
while global_step < config.train_iters:
argparser = argparse.ArgumentParser()
argparser.add_argument('--config_file', default='../configs/default.cfg')
argparser.add_argument('--model', default='BaseParser')
args, extra_args = argparser.parse_known_args()
config = Configurable(args.config_file, extra_args)
Parser = getattr(models, args.model)

vocab = Vocab(config.train_file, config.pretrained_embeddings_file, config.min_occur_count)
pickle.dump(vocab, open(config.save_vocab_path, 'wb'))
parser = Parser(vocab, config.word_dims, config.tag_dims, config.dropout_emb,
                config.lstm_layers, config.lstm_hiddens, config.dropout_lstm_input,
                config.dropout_lstm_hidden, config.mlp_arc_size, config.mlp_rel_size,
                config.dropout_mlp)
data_loader = DataLoader(config.train_file, config.num_buckets_train, vocab)
pc = parser.parameter_collection
trainer = dy.AdamTrainer(pc, config.learning_rate, config.beta_1, config.beta_2, config.epsilon)

global_step = 0

def update_parameters():
    trainer.learning_rate = config.learning_rate * config.decay ** (global_step / config.decay_steps)
    trainer.update()

epoch = 0
best_UAS = 0.
history = lambda x, y: open(os.path.join(config.save_dir, 'valid_history'), 'a').write('%.2f %.2f\n' % (x, y))
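# Illustrative sketch (not part of the original script) of the training loop these objects are
# set up for.  The training-time signature and return value of parser.run(), the save helper,
# and the field names train_batch_size, validate_every, num_buckets_valid, dev_file and
# dev_output_path are assumptions, not taken from the repository; the sketch only shows how
# data_loader, update_parameters(), global_step and the test() helper typically interact.
#     while global_step < config.train_iters:
#         for words, tags, arcs, rels in data_loader.get_batches(
#                 batch_size=config.train_batch_size, shuffle=True):
#             dy.renew_cg()
#             loss = parser.run(words, tags, arcs, rels)      # assumed: returns a DyNet loss expression
#             loss.backward()
#             update_parameters()                             # Adam step with the annealed learning rate
#             global_step += 1
#             if global_step % config.validate_every == 0:
#                 LAS, UAS = test(parser, vocab, config.num_buckets_valid, config.test_batch_size,
#                                 config.dev_file, config.dev_output_path)
#                 history(LAS, UAS)
#                 if UAS > best_UAS:
#                     best_UAS = UAS
#                     parser.save(config.save_model_path)     # hypothetical checkpoint helper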
def test(parser, vocab, num_buckets_test, test_batch_size, unlabeled_test_file,
         labeled_test_file, raw_test_file, output_file):
    data_loader = DataLoader(unlabeled_test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, lemmas, tags, arcs, rels in \
            data_loader.get_batches(batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        outputs = parser.run(words, lemmas, tags, arcs, isTrain=False)
        for output in outputs:
            sent_idx = record[idx]
            results[sent_idx] = output
            idx += 1

    global gold_preds
    global gold_args
    global gold_rels
    global test_data
    if not gold_rels:
        print 'prepare test gold data'
        # gather the gold semantic relations, grouped by predicate index, one dict per sentence
        gold_rels = []
        _predicate_cnt = 0
        _preds_args = defaultdict(list)
        _end_flag = False
        with open(labeled_test_file) as f:
            for line in f:
                info = line.strip().split()
                if info:
                    assert len(info) == 10, 'Illegal line: %s' % line
                    if _end_flag and len(_preds_args) > 0 and int(info[6]) == 0:
                        gold_rels.append(_preds_args)
                        _preds_args = defaultdict(list)
                        _end_flag = False
                    _preds_args[int(info[6])].append(vocab.rel2id(info[7]))
                    if info[7] != '_':
                        if int(info[6]) == 0:
                            gold_preds += 1
                        else:
                            gold_args += 1
                else:
                    _end_flag = True
        if len(_preds_args) > 0:
            gold_rels.append(_preds_args)

    if not test_data:
        print 'prepare for writing out the prediction'
        with open(raw_test_file) as f:
            test_data = []
            test_sent = []
            for line in f:
                info = line.strip().split()
                if info:
                    test_sent.append([info[0], info[1], info[2], info[3], info[4], info[5],
                                      info[6], info[7], info[8], info[9], info[10], info[11],
                                      '_', '_'])
                elif len(test_sent) > 0:
                    test_data.append(test_sent)
                    test_sent = []
            if len(test_sent) > 0:
                test_data.append(test_sent)
                test_sent = []

    # write the predicted predicates and arguments next to the gold columns
    with open(output_file, 'w') as f:
        for test_sent, predict_sent, gold_sent in zip(test_data, results, gold_rels):
            for i, rel in enumerate(predict_sent[0]):
                if rel != vocab.NONE:
                    test_sent[i][12] = 'Y'
                    test_sent[i][13] = '%s.%s' % (test_sent[i][2], vocab.id2rel(rel))
            for k in xrange(1, len(test_sent) + 1):
                if k in gold_sent and k in predict_sent:
                    for i, (prel, grel) in enumerate(zip(predict_sent[k], gold_sent[k])):
                        test_sent[i].append(vocab.id2rel(grel))
                        test_sent[i].append(vocab.id2rel(prel))
                else:
                    if k in gold_sent:
                        for i, grel in enumerate(gold_sent[k]):
                            test_sent[i].append(vocab.id2rel(grel))
                    if k in predict_sent:
                        for i, prel in enumerate(predict_sent[k]):
                            test_sent[i].append(vocab.id2rel(prel))
            for tokens in test_sent:
                f.write('\t'.join(tokens))
                f.write('\n')
            f.write('\n')

    predict_preds = 0.
    correct_preds = 0.
    predict_args = 0.
    correct_args = 0.
    num_correct = 0.
    total = 0.
    # compare predictions with the gold relations, predicate by predicate
    for psent, gsent in zip(results, gold_rels):
        predict_preds += len(psent) - 1
        for g_pred, g_args in gsent.iteritems():
            if g_pred == 0:
                p_args = psent.pop(g_pred)
                for p_rel, g_rel in zip(p_args, g_args):
                    if g_rel != vocab.NONE and g_rel == p_rel:
                        correct_preds += 1
            else:
                if g_pred in psent:
                    p_args = psent.pop(g_pred)
                    for p_rel, g_rel in zip(p_args, g_args):
                        total += 1
                        if p_rel != vocab.NONE:
                            predict_args += 1
                        if g_rel != vocab.NONE and g_rel == p_rel:
                            correct_args += 1
                        if g_rel == p_rel:
                            num_correct += 1
                else:
                    for i in xrange(len(g_args)):
                        total += 1
        # predicted arguments left for predicates absent from the gold count as false positives
        for p_pred, p_args in psent.iteritems():
            for p_rel in p_args:
                if p_rel != vocab.NONE:
                    predict_args += 1

    print 'arguments: correct:%d, gold:%d, predicted:%d' % (correct_args, gold_args, predict_args)
    print 'predicates: correct:%d, gold:%d, predicted:%d' % (correct_preds, gold_preds, predict_preds)
    P = (correct_args + correct_preds) / (predict_args + predict_preds + 1e-13)
    R = (correct_args + correct_preds) / (gold_args + gold_preds + 1e-13)
    NP = correct_args / (predict_args + 1e-13)
    NR = correct_args / (gold_args + 1e-13)
    PP = correct_preds / (predict_preds + 1e-13)
    PR = correct_preds / (gold_preds + 1e-13)
    PF1 = 2 * PP * PR / (PP + PR + 1e-13)
    F1 = 2 * P * R / (P + R + 1e-13)
    NF1 = 2 * NP * NR / (NP + NR + 1e-13)
    print '\teval accurate:%.4f predict:%d golden:%d correct:%d' % \
        (num_correct / total * 100, predict_args, gold_args, correct_args)
    print '\tP:%.4f R:%.4f F1:%.4f' % (P * 100, R * 100, F1 * 100)
    print '\tNP:%.4f NR:%.4f NF1:%.4f' % (NP * 100, NR * 100, NF1 * 100)
    print '\tcorrect predicate:%d \tgold predicate:%d' % (correct_preds, gold_preds)
    print '\tpredicate disambiguation PP:%.4f PR:%.4f PF1:%.4f' % (PP * 100, PR * 100, PF1 * 100)
    os.system('perl ../lib/eval.pl -g %s -s %s > %s.eval' % (raw_test_file, output_file, output_file))
    return NF1, F1
def test(parser, vocab, num_buckets_test, test_batch_size, pro_test_file, raw_test_file,
         output_file, unified=True, disambiguation_file=None, disambiguation_accuracy=None):
    if not unified:
        assert ((disambiguation_file is not None and len(disambiguation_file) > 0)
                or disambiguation_accuracy is not None), \
            'The accuracy of predicate disambiguation should be provided.'
    data_loader = DataLoader(pro_test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, lemmas, tags, arcs, rels, syn_masks, seq_lens in \
            data_loader.get_batches(batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        outputs = parser.run(words, lemmas, tags, arcs, isTrain=False,
                             syn_mask=syn_masks, seq_lens=seq_lens)
        for output in outputs:
            sent_idx = record[idx]
            results[sent_idx] = output
            idx += 1
    rels = reduce(lambda x, y: x + y, [list(result) for result in results])

    # global gold_rels
    # global test_data
    # global test_stat
    # if not gold_rels:
    gold_rels = []
    gold_sent = []
    with open(pro_test_file) as f:
        for line in f:
            info = line.strip().split()
            if info:
                assert len(info) == 11, 'Illegal line: %s' % line
                gold_sent.append(info[7])
            else:
                if len(gold_sent) > 0:
                    gold_rels.append(gold_sent)
                    gold_sent = []
    if len(gold_sent) > 0:
        gold_rels.append(gold_sent)

    # if not test_data:
    # print 'prepare for writing out the prediction'
    with open(raw_test_file) as f:
        test_data = []
        test_sent = []
        test_stat = []
        pred_index = []
        for line in f:
            info = line.strip().split()
            if info:
                test_sent.append([info[0], info[1], info[2], info[3], info[4], info[5],
                                  info[6], info[7], info[8], info[9], info[10], info[11],
                                  info[12], '_' if unified else info[13]])
                if info[12] == 'Y':
                    pred_index.append(int(info[0]))
            elif len(test_sent) > 0:
                test_data.append(test_sent)
                test_sent = []
                test_stat.append(pred_index)
                pred_index = []
        if len(test_sent) > 0:
            test_data.append(test_sent)
            test_sent = []
            test_stat.append(pred_index)
            pred_index = []

    if disambiguation_file is not None and len(disambiguation_file) > 0:
        gold_disamb = []
        gold_dsent = []
        with open(disambiguation_file) as f:
            for line in f:
                info = line.strip().split()
                if info:
                    gold_dsent.append(info[0])
                else:
                    if len(gold_dsent) > 0:
                        gold_disamb.append(gold_dsent)
                        gold_dsent = []
            if len(gold_dsent) > 0:
                gold_disamb.append(gold_dsent)
        assert len(test_data) == len(gold_disamb)

    idx = 0
    oidx = 0
    with open(output_file, 'w') as f:
        for test_s, pred_index in zip(test_data, test_stat):
            test_sent = copy.deepcopy(test_s)
            for p_idx in pred_index:
                if unified:
                    test_sent[p_idx - 1][13] = test_sent[p_idx - 1][2] + '.' + vocab.id2rel(results[idx][0])
                else:
                    if disambiguation_file is not None and len(disambiguation_file) > 0:
                        test_sent[p_idx - 1][13] = gold_disamb[oidx][p_idx - 1]
                    else:
                        # simulate predicate disambiguation at the given accuracy
                        if random.random() < disambiguation_accuracy:
                            test_sent[p_idx - 1][13] = test_sent[p_idx - 1][13]
                        else:
                            test_sent[p_idx - 1][13] = test_sent[p_idx - 1][2]
                if unified:
                    for i, p_rel in enumerate(results[idx][1:]):
                        test_sent[i].append(vocab.id2rel(p_rel))
                else:
                    for i, p_rel in enumerate(results[idx]):
                        test_sent[i].append(vocab.id2rel(p_rel))
                idx += 1
            oidx += 1
            for tokens in test_sent:
                f.write('\t'.join(tokens))
                f.write('\n')
            f.write('\n')

    predict_args = 0.
    correct_args = 0.
    gold_args = 0.
    correct_preds = 0.
    gold_preds = len(gold_rels)
    num_correct = 0.
    total = 0.
    idx = 0
    for sent in gold_rels:
        for i in range(len(sent)):
            gold = sent[i]
            pred = rels[idx]
            if unified and i == 0:
                if sent[i] == vocab.id2rel(pred):
                    correct_preds += 1
            else:
                total += 1
                if vocab.id2rel(pred) != '_':
                    predict_args += 1
                if gold != '_':
                    gold_args += 1
                if gold != '_' and gold == vocab.id2rel(pred):
                    correct_args += 1
                if gold == vocab.id2rel(pred):
                    num_correct += 1
            idx += 1

    correct_preds = correct_preds if unified else gold_preds * disambiguation_accuracy
    P = (correct_args + correct_preds) / (predict_args + gold_preds + 1e-13)
    R = (correct_args + correct_preds) / (gold_args + gold_preds + 1e-13)
    NP = correct_args / (predict_args + 1e-13)
    NR = correct_args / (gold_args + 1e-13)
    F1 = 2 * P * R / (P + R + 1e-13)
    NF1 = 2 * NP * NR / (NP + NR + 1e-13)
    print '\teval accurate:%.4f predict:%d golden:%d correct:%d' % \
        (num_correct / total * 100, predict_args, gold_args, correct_args)
    print '\tP:%.4f R:%.4f F1:%.4f' % (P * 100, R * 100, F1 * 100)
    print '\tNP:%.2f NR:%.2f NF1:%.2f' % (NP * 100, NR * 100, NF1 * 100)
    print '\tpredicate disambiguation accurate:%.2f' % (correct_preds / gold_preds * 100)
    os.system('perl ../lib/eval.pl -g %s -s %s > %s.eval' % (raw_test_file, output_file, output_file))
    return NF1, F1
import os

from lib import Vocab, DataLoader

root_path = '/home/clementine/projects/树库转换/ctb51_processed/dep_zhang_clark_conlu/'
training = os.path.join(root_path, 'train.txt.3.pa.gs.tab.conllu')
vocab = Vocab(training, pret_file=None, min_occur_count=2)
ctb_loader = DataLoader(training, 40, vocab)
for i, example in enumerate(ctb_loader.get_batches(40, shuffle=True)):
    print(example[0].shape, example[1].shape, example[2].shape, example[3].shape)
    if i == 0:
        break
argparser.add_argument('--model', default='BaseParser')
args, extra_args = argparser.parse_known_args()
config = Configurable(args.config_file, extra_args)
Parser = getattr(models, args.model)

vocab = Vocab(config.wsj_file, config.pretrained_embeddings_file, config.min_occur_count)
cPickle.dump(vocab, open(config.save_vocab_path, 'w'))
parser = Parser(vocab, config.word_dims, config.tag_dims, config.dropout_emb,
                config.lstm_layers, config.lstm_hiddens, config.dropout_lstm_input,
                config.dropout_lstm_hidden, config.mlp_arc_size, config.mlp_rel_size,
                config.dropout_mlp, config.filter_size, config.domain_num)
wsj = DataLoader(config.wsj_file, config.num_buckets_train, vocab, isTrain=True)
answer = DataLoader(config.answers_file, config.num_buckets_train, vocab, isTrain=True,
                    len_counter=wsj.len_counter)
pc = parser.parameter_collection
trainer = dy.AdamTrainer(pc, config.learning_rate, config.beta_1, config.beta_2, config.epsilon)

data = []
for i, item in enumerate([wsj, answer]):
    for words, tags, arcs, rels in item.get_batches(