def test(parser, vocab, num_buckets_test, test_batch_size, test_file, output_file):
    """Parse ``test_file``, write the predictions, and score them with eval.pl.

    Runs ``parser`` over every sentence of ``test_file`` and writes a copy of
    that file to ``output_file``. In the copy, the 0-based columns 6 and 7 of
    each token row are replaced by the predicted head and relation. It then
    runs ``perl eval.pl`` on the gold/predicted pair and reads LAS/UAS from
    the report.

    Args:
        parser: model whose ``run(words, tags, isTrain=False)`` returns one
            output per sentence; ``output[0]`` holds the predicted heads and
            ``output[1]`` the predicted relation ids.
        vocab: vocabulary providing ``id2rel``.
        num_buckets_test: number of buckets passed to ``DataLoader``.
        test_batch_size: batch size used when iterating the test data.
        test_file: gold file with 10 whitespace-separated columns per token
            row and blank lines between sentences; also the eval.pl gold file.
        output_file: path the predicted file is written to.

    Returns:
        ``(LAS, UAS)`` as floats.
    """
    data_loader = DataLoader(test_file, num_buckets_test, vocab)
    # get_batches yields sentences in bucket order; idx_sequence maps the n-th
    # produced sentence back to its slot in ``results`` (i.e. file order).
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, tags, _, _ in data_loader.get_batches(
            batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        for output in parser.run(words, tags, isTrain=False):
            results[record[idx]] = output
            idx += 1

    # Flatten per-sentence predictions into one token-level stream in a single
    # pass (reduce(+) over lists was quadratic and failed on empty input).
    arcs = [arc for result in results for arc in result[0]]
    rels = [rel for result in results for rel in result[1]]

    idx = 0
    with open(test_file) as f, open(output_file, 'w') as fo:
        for line in f:
            info = line.strip().split()
            if not info:
                # Sentence separator: copy through unchanged.
                fo.write('\n')
                continue
            assert len(info) == 10, 'Illegal line: %s' % line
            info[6] = str(arcs[idx])
            info[7] = vocab.id2rel(rels[idx])
            fo.write('\t'.join(info) + '\n')
            idx += 1

    os.system('perl eval.pl -q -b -g %s -s %s -o tmp' % (test_file, output_file))
    # Read the summary straight from the report instead of shelling out to
    # `tail`/`rm`; the report is removed even if reading it fails.
    try:
        with open('tmp') as report:
            # Same lines `tail -n 3` + [:2] selected: the first two of the
            # last three report lines, which hold LAS then UAS.
            summary = report.readlines()[-3:][:2]
    finally:
        if os.path.exists('tmp'):
            os.remove('tmp')
    # The score is the second-to-last whitespace-separated field of each line.
    LAS, UAS = [float(line.strip().split()[-2]) for line in summary]
    print('LAS %.2f, UAS %.2f' % (LAS, UAS))
    return LAS, UAS
# eta=config.eta, patience=config.patience, clip=config.clip, # lr = config.learning_rate, beta1=config.beta_1, beta2 = config.beta_2, epsilon=config.epsilon) global_step = 0 def update_parameters(): trainer.learning_rate =config.learning_rate*config.decay**(global_step / config.decay_steps) trainer.update() epoch = 0 best_F1 = 0. history = lambda x, y : open(os.path.join(config.save_dir, 'valid_history'),'a').write('%.2f %.2f\n'%(x,y)) while global_step < config.train_iters: print time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), '\nStart training epoch #%d'%(epoch, ) epoch += 1 for words, lemmas, tags, arcs, rels, syn_masks, seq_lens in \ data_loader.get_batches(batch_size = config.train_batch_size, shuffle = True): num = int(words.shape[1]/2) words_ = [words[:,:num], words[:,num:]] lemmas_ = [lemmas[:,:num], lemmas[:,num:]] tags_ = [tags[:,:num], tags[:,num:]] arcs_ = [arcs[:,:num], arcs[:,num:]] rels_ = [rels[:,:num], rels[:,num:]] syn_masks_ = [syn_masks[:,:num], syn_masks[:,num:]] seq_lens_ = [seq_lens[:num], seq_lens[num:]] for step in xrange(2): dy.renew_cg() rel_accuracy, loss = parser.run(words_[step], lemmas_[step], tags_[step], arcs_[step], rels_[step], syn_mask = syn_masks_[step], seq_lens = seq_lens_[step]) loss = loss * 0.5 loss_value = loss.scalar_value() loss.backward() sys.stdout.write("Step #%d: Acc: rel %.2f, loss %.3f\r\r" %
def update_parameters(): trainer.learning_rate = config.learning_rate * config.decay**( global_step / config.decay_steps) trainer.update() epoch = 0 best_UAS = 0. history = lambda x, y: open(os.path.join(config.save_dir, 'valid_history'), 'a').write('%.2f %.2f\n' % (x, y)) while global_step < config.train_iters: print time.strftime( "%Y-%m-%d %H:%M:%S", time.localtime()), '\nStart training epoch #%d' % (epoch, ) epoch += 1 for words, tags, arcs, rels in data_loader.get_batches( batch_size=config.train_batch_size, shuffle=True): dy.renew_cg() arc_accuracy, rel_accuracy, overall_accuracy, loss = parser.run( words, tags, arcs, rels) loss_value = loss.scalar_value() loss.backward() update_parameters() sys.stdout.write( "Step #%d: Acc: arc %.2f, rel %.2f, overall %.2f, loss %.3f\r\r" % (global_step, arc_accuracy, rel_accuracy, overall_accuracy, loss_value)) sys.stdout.flush() global_step += 1 if global_step % config.validate_every == 0: print '\nTest on development set'
def test(parser, vocab, num_buckets_test, test_batch_size, unlabeled_test_file,
         labeled_test_file, raw_test_file, output_file):
    """Predict SRL labels for a test set, write them out, and score them.

    The parser is run over ``unlabeled_test_file``. Gold labels are read from
    ``labeled_test_file``, which has 10 columns per token. The token rows to
    write are read from ``raw_test_file``, which has at least 12 columns per
    token. Both files are parsed only on the first call and cached in the
    module-level globals ``gold_rels``/``gold_preds``/``gold_args`` and
    ``test_data``.

    Each raw token row is written to ``output_file`` as follows:

    * its first 12 columns;
    * columns 12 and 13: ``'Y'`` and ``'<column 2>.<sense>'`` for tokens
      predicted as predicates, ``'_'`` otherwise;
    * one appended label column per predicate index k. When k is both a gold
      and a predicted key, the gold label comes first, then the predicted
      label; otherwise whichever one exists is written.

    Finally ``perl ../lib/eval.pl`` is run and its report is saved to
    ``<output_file>.eval``.

    NOTE(review): relies on module-level globals defined elsewhere.
    ``gold_rels``/``test_data`` must start falsy and ``gold_preds``/
    ``gold_args`` must start as counters (presumably 0) -- confirm.

    Returns:
        ``(NF1, F1)``: the F1 over arguments only, and the F1 over arguments
        plus predicate senses.
    """
    global gold_preds
    global gold_args
    global gold_rels
    global test_data

    # 1. Run the parser. idx_sequence maps the n-th sentence produced by the
    #    bucketed loader back to its slot in ``results``.
    data_loader = DataLoader(unlabeled_test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, lemmas, tags, arcs, _ in data_loader.get_batches(
            batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        for output in parser.run(words, lemmas, tags, arcs, isTrain=False):
            results[record[idx]] = output
            idx += 1

    # 2. Gold labels (first call only). Column 6 is the key: 0 rows carry
    #    predicate-sense labels, k > 0 rows carry arguments of predicate k.
    #    A new sentence starts at the first key-0 row after a blank line.
    if not gold_rels:
        print('prepare test gold data')
        gold_rels = []
        preds_args = defaultdict(list)
        end_flag = False
        with open(labeled_test_file) as f:
            for line in f:
                info = line.strip().split()
                if not info:
                    end_flag = True
                    continue
                assert len(info) == 10, 'Illegal line: %s' % line
                key = int(info[6])
                if end_flag and len(preds_args) > 0 and key == 0:
                    gold_rels.append(preds_args)
                    preds_args = defaultdict(list)
                    end_flag = False
                preds_args[key].append(vocab.rel2id(info[7]))
                if info[7] != '_':
                    if key == 0:
                        gold_preds += 1
                    else:
                        gold_args += 1
        if len(preds_args) > 0:
            gold_rels.append(preds_args)

    # 3. Raw token rows to write out (first call only): first 12 columns plus
    #    two '_' placeholders for the predicate flag and sense.
    if not test_data:
        print('prepare for writing out the prediction')
        with open(raw_test_file) as f:
            test_data = []
            test_sent = []
            for line in f:
                info = line.strip().split()
                if info:
                    # Explicit indexing keeps the IndexError on short rows.
                    test_sent.append([info[c] for c in range(12)] + ['_', '_'])
                elif len(test_sent) > 0:
                    test_data.append(test_sent)
                    test_sent = []
            if len(test_sent) > 0:
                test_data.append(test_sent)

    # 4. Write predictions.
    with open(output_file, 'w') as f:
        for cached_sent, predict_sent, gold_sent in zip(test_data, results, gold_rels):
            # Work on a per-call copy. test_data is a cached global, and the
            # marks/columns added below previously leaked into every later
            # call, producing stale 'Y' flags and a growing number of columns.
            test_sent = [list(row) for row in cached_sent]
            for i, rel in enumerate(predict_sent[0]):
                if rel != vocab.NONE:
                    test_sent[i][12] = 'Y'
                    test_sent[i][13] = '%s.%s' % (test_sent[i][2], vocab.id2rel(rel))
            for k in range(1, len(test_sent) + 1):
                if k in gold_sent and k in predict_sent:
                    for i, (prel, grel) in enumerate(zip(predict_sent[k], gold_sent[k])):
                        test_sent[i].append(vocab.id2rel(grel))
                        test_sent[i].append(vocab.id2rel(prel))
                else:
                    if k in gold_sent:
                        for i, grel in enumerate(gold_sent[k]):
                            test_sent[i].append(vocab.id2rel(grel))
                    if k in predict_sent:
                        for i, prel in enumerate(predict_sent[k]):
                            test_sent[i].append(vocab.id2rel(prel))
            for tokens in test_sent:
                f.write('\t'.join(tokens))
                f.write('\n')
            f.write('\n')

    # 5. Score. Gold keys are popped from the prediction dicts so that the
    #    keys left over afterwards are predicates absent from gold.
    predict_preds = 0.
    correct_preds = 0.
    predict_args = 0.
    correct_args = 0.
    num_correct = 0.
    total = 0.
    for psent, gsent in zip(results, gold_rels):
        # Every predicted key except 0 is a predicted predicate.
        predict_preds += len(psent) - 1
        for g_pred, g_args in gsent.items():
            if g_pred == 0:
                p_args = psent.pop(g_pred)
                for p_rel, g_rel in zip(p_args, g_args):
                    if g_rel != vocab.NONE and g_rel == p_rel:
                        correct_preds += 1
            elif g_pred in psent:
                p_args = psent.pop(g_pred)
                for p_rel, g_rel in zip(p_args, g_args):
                    total += 1
                    if p_rel != vocab.NONE:
                        predict_args += 1
                    if g_rel == p_rel:
                        num_correct += 1
                        if g_rel != vocab.NONE:
                            correct_args += 1
            else:
                # Gold predicate with no prediction: its slots still count.
                total += len(g_args)
        # Predicted predicates missing from gold: every non-NONE label counts
        # as a predicted argument.
        for p_args in psent.values():
            for p_rel in p_args:
                if p_rel != vocab.NONE:
                    predict_args += 1

    print('arguments: correct:%d, gold:%d, predicted:%d' % (
        correct_args, gold_args, predict_args))
    print('predicates: correct:%d, gold:%d, predicted:%d' % (
        correct_preds, gold_preds, predict_preds))
    P = (correct_args + correct_preds) / (predict_args + predict_preds + 1e-13)
    R = (correct_args + correct_preds) / (gold_args + gold_preds + 1e-13)
    NP = correct_args / (predict_args + 1e-13)
    NR = correct_args / (gold_args + 1e-13)
    PP = correct_preds / (predict_preds + 1e-13)
    PR = correct_preds / (gold_preds + 1e-13)
    PF1 = 2 * PP * PR / (PP + PR + 1e-13)
    F1 = 2 * P * R / (P + R + 1e-13)
    NF1 = 2 * NP * NR / (NP + NR + 1e-13)
    # Guard the one unprotected division (empty test set).
    accuracy = num_correct / total * 100 if total else 0.
    print('\teval accurate:%.4f predict:%d golden:%d correct:%d' % (
        accuracy, predict_args, gold_args, correct_args))
    print('\tP:%.4f R:%.4f F1:%.4f' % (P * 100, R * 100, F1 * 100))
    print('\tNP:%.4f NR:%.4f NF1:%.4f' % (NP * 100, NR * 100, NF1 * 100))
    print('\tcorrect predicate:%d \tgold predicate:%d' % (correct_preds, gold_preds))
    print('\tpredicate disambiguation PP:%.4f PR:%.4f PF1:%.4f' % (
        PP * 100, PR * 100, PF1 * 100))
    os.system('perl ../lib/eval.pl -g %s -s %s > %s.eval' %
              (raw_test_file, output_file, output_file))
    return NF1, F1
def test(parser, vocab, num_buckets_test, test_batch_size, pro_test_file,
         raw_test_file, output_file, unified=True, disambiguation_file=None,
         disambiguation_accuracy=None):
    """Predict SRL labels for a test set, write them out, and score them.

    Each blank-line separated block of ``pro_test_file`` (11 columns per row,
    gold label in column 7) is one parser input. Presumably there is one
    block per predicate: one result is consumed per 'Y' token of
    ``raw_test_file``.

    ``raw_test_file`` supplies the token rows written to ``output_file``. For
    every token whose column 12 is 'Y', column 13 receives a predicate sense
    and one column of predicted labels is appended. The sense is chosen as
    follows:

    * ``unified=True``: ``'<column 2>.<first predicted label of the block>'``.
    * ``unified=False`` with ``disambiguation_file``: column 0 of that file.
    * ``unified=False`` otherwise: the raw column-13 sense is kept with
      probability ``disambiguation_accuracy``, else replaced by column 2.

    Finally ``perl ../lib/eval.pl`` is run and its report is saved to
    ``<output_file>.eval``.

    Returns:
        ``(NF1, F1)``: the F1 over arguments only, and the F1 over arguments
        plus predicate senses.
    """
    has_disamb_file = disambiguation_file is not None and len(disambiguation_file) > 0
    if not unified:
        assert has_disamb_file or disambiguation_accuracy is not None, \
            'The accuracy of predicate disambiguation should be provided.'

    def read_column_blocks(path, column, expected_len=None):
        """Return one list per blank-line separated block of ``path``,
        holding field ``column`` of each row."""
        blocks, block = [], []
        with open(path) as f:
            for line in f:
                info = line.strip().split()
                if info:
                    if expected_len is not None:
                        assert len(info) == expected_len, 'Illegal line: %s' % line
                    block.append(info[column])
                elif block:
                    blocks.append(block)
                    block = []
        if block:
            blocks.append(block)
        return blocks

    # 1. Run the parser. idx_sequence maps the n-th block produced by the
    #    bucketed loader back to its slot in ``results``.
    data_loader = DataLoader(pro_test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, lemmas, tags, arcs, _, syn_masks, seq_lens in data_loader.get_batches(
            batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        outputs = parser.run(words, lemmas, tags, arcs, isTrain=False,
                             syn_mask=syn_masks, seq_lens=seq_lens)
        for output in outputs:
            results[record[idx]] = output
            idx += 1
    # Token-level stream of every predicted label, in file order (single pass;
    # reduce(+) over lists was quadratic).
    rels = [rel for result in results for rel in result]

    # 2. Gold labels, one list per pro_test_file block.
    gold_rels = read_column_blocks(pro_test_file, 7, expected_len=11)

    # 3. Raw token rows plus, per sentence, the 1-based ids of 'Y' tokens.
    test_data = []
    test_stat = []
    test_sent = []
    pred_index = []
    with open(raw_test_file) as f:
        for line in f:
            info = line.strip().split()
            if info:
                test_sent.append([info[c] for c in range(13)] +
                                 ['_' if unified else info[13]])
                if info[12] == 'Y':
                    pred_index.append(int(info[0]))
            elif test_sent:
                test_data.append(test_sent)
                test_stat.append(pred_index)
                test_sent = []
                pred_index = []
    if test_sent:
        test_data.append(test_sent)
        test_stat.append(pred_index)

    if has_disamb_file:
        gold_disamb = read_column_blocks(disambiguation_file, 0)
        assert len(test_data) == len(gold_disamb)

    # 4. Write predictions. test_data is rebuilt on every call, so its rows
    #    can be filled in place without copying.
    # Predicates whose disambiguation_file sense equals raw column 13.
    correct_disamb = 0.
    idx = 0
    with open(output_file, 'w') as f:
        for oidx, (test_sent, pred_index) in enumerate(zip(test_data, test_stat)):
            for p_idx in pred_index:
                row = test_sent[p_idx - 1]
                if unified:
                    # The first label of a block is the predicate sense.
                    row[13] = row[2] + '.' + vocab.id2rel(results[idx][0])
                    labels = results[idx][1:]
                elif has_disamb_file:
                    sense = gold_disamb[oidx][p_idx - 1]
                    if sense == row[13]:
                        correct_disamb += 1
                    row[13] = sense
                    labels = results[idx]
                else:
                    # Simulated disambiguator: keep column 13 with probability
                    # disambiguation_accuracy, otherwise fall back to column 2.
                    if random.random() >= disambiguation_accuracy:
                        row[13] = row[2]
                    labels = results[idx]
                for i, p_rel in enumerate(labels):
                    test_sent[i].append(vocab.id2rel(p_rel))
                idx += 1
            for tokens in test_sent:
                f.write('\t'.join(tokens))
                f.write('\n')
            f.write('\n')

    # 5. Score against the gold labels of pro_test_file.
    predict_args = 0.
    correct_args = 0.
    gold_args = 0.
    correct_preds = 0.
    gold_preds = len(gold_rels)
    num_correct = 0.
    total = 0.
    idx = 0
    for sent in gold_rels:
        for i, gold in enumerate(sent):
            pred = vocab.id2rel(rels[idx])
            idx += 1
            if unified and i == 0:
                if gold == pred:
                    correct_preds += 1
                continue
            total += 1
            if pred != '_':
                predict_args += 1
            if gold != '_':
                gold_args += 1
            if gold == pred:
                num_correct += 1
                if gold != '_':
                    correct_args += 1

    if not unified:
        if disambiguation_accuracy is not None:
            correct_preds = gold_preds * disambiguation_accuracy
        else:
            # Only a disambiguation file was given. Previously this computed
            # gold_preds * None and crashed; use the measured agreement.
            # NOTE(review): assumes raw column 13 holds the gold sense, as the
            # simulated branch above treats it -- confirm.
            correct_preds = correct_disamb

    P = (correct_args + correct_preds) / (predict_args + gold_preds + 1e-13)
    R = (correct_args + correct_preds) / (gold_args + gold_preds + 1e-13)
    NP = correct_args / (predict_args + 1e-13)
    NR = correct_args / (gold_args + 1e-13)
    F1 = 2 * P * R / (P + R + 1e-13)
    NF1 = 2 * NP * NR / (NP + NR + 1e-13)
    # Guard the two unprotected divisions (empty test set).
    accuracy = num_correct / total * 100 if total else 0.
    disamb_accuracy = correct_preds / gold_preds * 100 if gold_preds else 0.
    print('\teval accurate:%.4f predict:%d golden:%d correct:%d' % (
        accuracy, predict_args, gold_args, correct_args))
    print('\tP:%.4f R:%.4f F1:%.4f' % (P * 100, R * 100, F1 * 100))
    print('\tNP:%.2f NR:%.2f NF1:%.2f' % (NP * 100, NR * 100, NF1 * 100))
    print('\tpredicate disambiguation accurate:%.2f' % disamb_accuracy)
    os.system('perl ../lib/eval.pl -g %s -s %s > %s.eval' %
              (raw_test_file, output_file, output_file))
    return NF1, F1
import os
from lib import Vocab, DataLoader

# Quick sanity check of Vocab/DataLoader: build a vocabulary from the training
# file, then print the .shape of the first four elements of one shuffled batch.
# NOTE(review): under Python 2 the non-ASCII path literal below needs a
# `# -*- coding: utf-8 -*-` declaration on the first or second line of the file.
root_path = '/home/clementine/projects/树库转换/ctb51_processed/dep_zhang_clark_conlu/'
training = os.path.join(root_path, 'train.txt.3.pa.gs.tab.conllu')
vocab = Vocab(training, pret_file=None, min_occur_count=2)
ctb_loader = DataLoader(training, 40, vocab)  # 40: presumably the bucket count
for batch in ctb_loader.get_batches(40, shuffle=True):  # 40: batch size
    print(batch[0].shape, batch[1].shape, batch[2].shape, batch[3].shape)
    break  # a single batch is all this check needs