def test(parser, vocab, num_buckets_test, test_batch_size, test_file, output_file):
    """Parse ``test_file``, write the predictions to ``output_file`` and score them.

    Every non-blank line of ``test_file`` must have 10 whitespace-separated
    columns; columns 6 and 7 are replaced by the predicted head index and
    relation label. Scoring is delegated to ``perl eval.pl`` (resolved
    relative to the current working directory); LAS and UAS are read from
    the second-to-last field of the first two of its last three lines.

    Args:
        parser: model whose ``run(words, tags, isTrain=False)`` yields one
            result per sentence in the batch; ``result[0]`` holds head
            indices and ``result[1]`` relation ids.
        vocab: vocabulary providing ``id2rel``.
        num_buckets_test: number of buckets for the ``DataLoader``.
        test_batch_size: batch size passed to ``get_batches``.
        test_file: gold 10-column input file.
        output_file: destination for the predicted file.

    Returns:
        ``(LAS, UAS)`` as floats parsed from the scorer output.
    """
    import subprocess
    import tempfile

    data_loader = DataLoader(test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, tags, arcs, rels in data_loader.get_batches(batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        outputs = parser.run(words, tags, isTrain=False)
        for output in outputs:
            # record maps batch order back to the order of the input file.
            results[record[idx]] = output
            idx += 1

    # Token-level flattening in linear time (reduce with `+` re-copies the
    # accumulated list on every step, and fails on an empty result list).
    arcs = [arc for result in results for arc in result[0]]
    rels = [rel for result in results for rel in result[1]]

    idx = 0
    with open(test_file) as f, open(output_file, 'w') as fo:
        for line in f:
            info = line.strip().split()
            if info:
                assert len(info) == 10, 'Illegal line: %s' % line
                info[6] = str(arcs[idx])
                info[7] = vocab.id2rel(rels[idx])
                fo.write('\t'.join(info) + '\n')
                idx += 1
            else:
                fo.write('\n')

    # Score into a private temp file instead of fixed 'tmp'/'score_tmp' names
    # in the CWD, and pass arguments as a list so paths are not shell-parsed.
    fd, eval_path = tempfile.mkstemp(suffix='.eval')
    os.close(fd)
    try:
        subprocess.call(['perl', 'eval.pl', '-q', '-b', '-g', test_file,
                         '-s', output_file, '-o', eval_path])
        with open(eval_path) as score_file:
            # Equivalent of `tail -n 3`.
            score_lines = score_file.readlines()[-3:]
    finally:
        os.remove(eval_path)

    LAS, UAS = [float(line.strip().split()[-2]) for line in score_lines[:2]]
    print('LAS %.2f, UAS %.2f' % (LAS, UAS))
    return LAS, UAS
예제 #2
0
	# 							eta=config.eta, patience=config.patience, clip=config.clip,
	# 							lr = config.learning_rate, beta1=config.beta_1, beta2 = config.beta_2, epsilon=config.epsilon)
	
	global_step = 0
	def update_parameters():
		"""Decay the learning rate exponentially with global_step, then step the trainer."""
		# NOTE(review): this code uses Python 2 print statements; if global_step
		# and config.decay_steps are both ints, the division below floors and the
		# schedule becomes a staircase. Confirm that is intended.
		decay_exponent = global_step / config.decay_steps
		trainer.learning_rate = config.learning_rate * config.decay ** decay_exponent
		trainer.update()

	epoch = 0
	best_F1 = 0.
	history = lambda x, y : open(os.path.join(config.save_dir, 'valid_history'),'a').write('%.2f %.2f\n'%(x,y))
	while global_step < config.train_iters:
		print time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), '\nStart training epoch #%d'%(epoch, )
		epoch += 1
		for words, lemmas, tags, arcs, rels, syn_masks, seq_lens in \
				data_loader.get_batches(batch_size = config.train_batch_size, shuffle = True):
			num = int(words.shape[1]/2)
			words_ = [words[:,:num], words[:,num:]]
			lemmas_ = [lemmas[:,:num], lemmas[:,num:]]
			tags_ = [tags[:,:num], tags[:,num:]]
			arcs_ = [arcs[:,:num], arcs[:,num:]]
			rels_ = [rels[:,:num], rels[:,num:]]
			syn_masks_ = [syn_masks[:,:num], syn_masks[:,num:]]
			seq_lens_ = [seq_lens[:num], seq_lens[num:]]
			for step in xrange(2):
				dy.renew_cg()
				rel_accuracy, loss = parser.run(words_[step], lemmas_[step], tags_[step], arcs_[step], rels_[step], syn_mask = syn_masks_[step], seq_lens = seq_lens_[step])
				loss = loss * 0.5
				loss_value = loss.scalar_value()
				loss.backward()
				sys.stdout.write("Step #%d: Acc: rel %.2f, loss %.3f\r\r" % 
예제 #3
0
    def update_parameters():
        """Decay the learning rate exponentially with global_step, then step the trainer."""
        # NOTE(review): this code uses Python 2 print statements; if global_step
        # and config.decay_steps are both ints, the division below floors and the
        # schedule becomes a staircase. Confirm that is intended.
        decay_exponent = global_step / config.decay_steps
        trainer.learning_rate = config.learning_rate * config.decay ** decay_exponent
        trainer.update()

    epoch = 0
    best_UAS = 0.
    history = lambda x, y: open(os.path.join(config.save_dir, 'valid_history'),
                                'a').write('%.2f %.2f\n' % (x, y))
    while global_step < config.train_iters:
        print time.strftime(
            "%Y-%m-%d %H:%M:%S",
            time.localtime()), '\nStart training epoch #%d' % (epoch, )
        epoch += 1
        for words, tags, arcs, rels in data_loader.get_batches(
                batch_size=config.train_batch_size, shuffle=True):
            dy.renew_cg()
            arc_accuracy, rel_accuracy, overall_accuracy, loss = parser.run(
                words, tags, arcs, rels)
            loss_value = loss.scalar_value()
            loss.backward()
            update_parameters()
            sys.stdout.write(
                "Step #%d: Acc: arc %.2f, rel %.2f, overall %.2f, loss %.3f\r\r"
                % (global_step, arc_accuracy, rel_accuracy, overall_accuracy,
                   loss_value))
            sys.stdout.flush()

            global_step += 1
            if global_step % config.validate_every == 0:
                print '\nTest on development set'
예제 #4
0
def test(parser, vocab, num_buckets_test, test_batch_size, unlabeled_test_file,
         labeled_test_file, raw_test_file, output_file):
    """Run the parser on a test set, write the predictions and print SRL scores.

    Predictions are made on ``unlabeled_test_file``. Gold labels come from
    ``labeled_test_file`` and the output rows from ``raw_test_file``; both are
    parsed only on the first call and cached in the module-level globals
    ``gold_rels`` / ``gold_preds`` / ``gold_args`` / ``test_data`` (these
    globals must already exist, with the counters initialised to a number).

    Args:
        parser: model whose ``run(words, lemmas, tags, arcs, isTrain=False)``
            yields one result per sentence; a result appears to be a dict
            keyed by head position (key 0 = predicate labels) mapping to a
            per-token list of relation ids.
        vocab: vocabulary providing ``rel2id``, ``id2rel`` and ``NONE``.
        num_buckets_test: number of buckets for the ``DataLoader``.
        test_batch_size: batch size passed to ``get_batches``.
        unlabeled_test_file: input fed to the ``DataLoader``.
        labeled_test_file: 10-column file with heads (col 6) and relation
            labels (col 7).
        raw_test_file: file whose first 12 columns are copied to the output.
        output_file: destination for predictions; the perl scorer's report
            is written to ``output_file + '.eval'``.

    Returns:
        ``(NF1, F1)``: argument-only F1 and combined predicate+argument F1.
    """
    import copy
    import subprocess

    data_loader = DataLoader(unlabeled_test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, lemmas, tags, arcs, rels in data_loader.get_batches(
            batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        outputs = parser.run(words, lemmas, tags, arcs, isTrain=False)
        for output in outputs:
            # record maps batch order back to the order of the input file.
            results[record[idx]] = output
            idx += 1

    global gold_preds
    global gold_args
    global gold_rels
    global test_data

    if not gold_rels:
        print('prepare test gold data')
        gold_rels = []
        _preds_args = defaultdict(list)
        _end_flag = False

        with open(labeled_test_file) as f:
            for line in f:
                info = line.strip().split()
                if info:
                    assert len(info) == 10, 'Illegal line: %s' % line
                    head = int(info[6])
                    # A head-0 row after a blank line starts a new gold unit.
                    if _end_flag and len(_preds_args) > 0 and head == 0:
                        gold_rels.append(_preds_args)
                        _preds_args = defaultdict(list)
                        _end_flag = False
                    _preds_args[head].append(vocab.rel2id(info[7]))
                    if info[7] != '_':
                        if head == 0:
                            gold_preds += 1
                        else:
                            gold_args += 1
                else:
                    _end_flag = True
            if len(_preds_args) > 0:
                gold_rels.append(_preds_args)

    if not test_data:
        print('prepare for writing out the prediction')
        with open(raw_test_file) as f:
            test_data = []
            test_sent = []
            for line in f:
                info = line.strip().split()
                if info:
                    # First 12 columns are kept; the two '_' slots are filled
                    # from the predictions when writing the output.
                    test_sent.append([
                        info[0], info[1], info[2], info[3], info[4], info[5],
                        info[6], info[7], info[8], info[9], info[10], info[11],
                        '_', '_'
                    ])
                elif len(test_sent) > 0:
                    test_data.append(test_sent)
                    test_sent = []

            if len(test_sent) > 0:
                test_data.append(test_sent)

    with open(output_file, 'w') as f:
        for cached_sent, predict_sent, gold_sent in zip(test_data, results,
                                                        gold_rels):
            # test_data is cached in a global across calls; edit a copy so the
            # column assignments/appends below don't accumulate between calls.
            test_sent = copy.deepcopy(cached_sent)
            for i, rel in enumerate(predict_sent[0]):
                if rel != vocab.NONE:
                    test_sent[i][12] = 'Y'
                    test_sent[i][13] = '%s.%s' % (test_sent[i][2],
                                                  vocab.id2rel(rel))

            for k in range(1, len(test_sent) + 1):
                if k in gold_sent and k in predict_sent:
                    for i, (prel, grel) in enumerate(
                            zip(predict_sent[k], gold_sent[k])):
                        test_sent[i].append(vocab.id2rel(grel))
                        test_sent[i].append(vocab.id2rel(prel))
                else:
                    if k in gold_sent:
                        for i, grel in enumerate(gold_sent[k]):
                            test_sent[i].append(vocab.id2rel(grel))
                    if k in predict_sent:
                        for i, prel in enumerate(predict_sent[k]):
                            test_sent[i].append(vocab.id2rel(prel))

            for tokens in test_sent:
                f.write('\t'.join(tokens))
                f.write('\n')
            f.write('\n')

    predict_preds = 0.
    correct_preds = 0.
    predict_args = 0.
    correct_args = 0.
    num_correct = 0.
    total = 0.
    for psent, gsent in zip(results, gold_rels):
        # Every key other than 0 is a predicted predicate position.
        predict_preds += len(psent) - 1
        for g_pred, g_args in gsent.items():
            if g_pred == 0:
                p_args = psent.pop(g_pred)
                for p_rel, g_rel in zip(p_args, g_args):
                    if g_rel != vocab.NONE and g_rel == p_rel:
                        correct_preds += 1
            elif g_pred in psent:
                p_args = psent.pop(g_pred)
                for p_rel, g_rel in zip(p_args, g_args):
                    total += 1
                    if p_rel != vocab.NONE:
                        predict_args += 1
                    if g_rel != vocab.NONE and g_rel == p_rel:
                        correct_args += 1
                    if g_rel == p_rel:
                        num_correct += 1
            else:
                # Gold predicate the model did not predict.
                total += len(g_args)

        # Whatever remains are predicates absent from gold.
        for p_args in psent.values():
            for p_rel in p_args:
                if p_rel != vocab.NONE:
                    predict_args += 1

    print('arguments: correct:%d, gold:%d, predicted:%d' % (
        correct_args, gold_args, predict_args))
    print('predicates: correct:%d, gold:%d, predicted:%d' % (
        correct_preds, gold_preds, predict_preds))

    P = (correct_args + correct_preds) / (predict_args + predict_preds + 1e-13)
    R = (correct_args + correct_preds) / (gold_args + gold_preds + 1e-13)
    NP = correct_args / (predict_args + 1e-13)
    NR = correct_args / (gold_args + 1e-13)
    PP = correct_preds / (predict_preds + 1e-13)
    PR = correct_preds / (gold_preds + 1e-13)
    PF1 = 2 * PP * PR / (PP + PR + 1e-13)
    F1 = 2 * P * R / (P + R + 1e-13)
    NF1 = 2 * NP * NR / (NP + NR + 1e-13)

    # Guard against an empty evaluation set.
    accuracy = num_correct / total * 100 if total else 0.
    print('\teval accurate:%.4f predict:%d golden:%d correct:%d' %
          (accuracy, predict_args, gold_args, correct_args))
    print('\tP:%.4f R:%.4f F1:%.4f' % (P * 100, R * 100, F1 * 100))
    print('\tNP:%.4f NR:%.4f NF1:%.4f' % (NP * 100, NR * 100, NF1 * 100))
    print('\tcorrect predicate:%d \tgold predicate:%d' % (correct_preds,
                                                          gold_preds))
    print('\tpredicate disambiguation PP:%.4f PR:%.4f PF1:%.4f' % (
        PP * 100, PR * 100, PF1 * 100))
    # Argument list + explicit stdout file instead of a shell string, so
    # paths containing spaces or shell metacharacters are passed safely.
    with open('%s.eval' % output_file, 'w') as eval_out:
        subprocess.call(['perl', '../lib/eval.pl', '-g', raw_test_file,
                         '-s', output_file], stdout=eval_out)
    return NF1, F1
예제 #5
0
def test(parser, vocab, num_buckets_test, test_batch_size, pro_test_file, raw_test_file, output_file,
	unified = True, disambiguation_file = None, disambiguation_accuracy = None):
	"""Evaluate role-label predictions on a test set, write them out and print scores.

	Args:
		parser: model whose ``run(..., isTrain=False, syn_mask=..., seq_lens=...)``
			yields one sequence of label ids per instance in the batch.
		vocab: vocabulary providing ``id2rel``.
		num_buckets_test: number of buckets for the ``DataLoader``.
		test_batch_size: batch size passed to ``get_batches``.
		pro_test_file: preprocessed 11-column test file; column 7 holds the
			gold label of each token. Each blank-line-separated block appears
			to be one predicate instance (the block count is used as the
			number of gold predicates).
		raw_test_file: file whose rows are copied to the output; rows with
			column 12 == 'Y' are predicates, and one predicted-label column is
			appended per predicate.
		output_file: destination for the predictions; the perl scorer's
			report goes to ``output_file + '.eval'``.
		unified: if True, the first label of each instance is the predicted
			predicate sense; otherwise senses come from ``disambiguation_file``
			or are simulated with ``disambiguation_accuracy``.
		disambiguation_file: optional file whose first column gives a sense
			per token (used only when ``unified`` is False).
		disambiguation_accuracy: accuracy of the external sense disambiguator.
			When omitted but ``disambiguation_file`` is given, it is measured
			against column 13 of ``raw_test_file``.

	Returns:
		``(NF1, F1)``: argument-only F1 and combined F1 (arguments + predicates).
	"""
	import subprocess

	use_disamb_file = disambiguation_file is not None and len(disambiguation_file) > 0
	if not unified:
		assert (use_disamb_file or disambiguation_accuracy is not None), \
				'The accuracy of predicate disambiguation should be provided.'
	data_loader = DataLoader(pro_test_file, num_buckets_test, vocab)
	record = data_loader.idx_sequence
	results = [None] * len(record)
	idx = 0
	for words, lemmas, tags, arcs, rels, syn_masks, seq_lens in \
			data_loader.get_batches(batch_size = test_batch_size, shuffle = False):
		dy.renew_cg()
		outputs = parser.run(words, lemmas, tags, arcs, isTrain = False, syn_mask = syn_masks, seq_lens=seq_lens)
		for output in outputs:
			# record maps batch order back to the order of the input file.
			results[record[idx]] = output
			idx += 1
	# Token-level label ids in file order; linear-time flattening (reduce with
	# `+` re-copies the accumulated list each step and fails on empty input).
	rels = [rel for result in results for rel in result]

	gold_rels = []
	gold_sent = []
	with open(pro_test_file) as f:
		for line in f:
			info = line.strip().split()
			if info:
				assert len(info) == 11, 'Illegal line: %s' % line
				gold_sent.append(info[7])
			elif len(gold_sent) > 0:
				gold_rels.append(gold_sent)
				gold_sent = []
		if len(gold_sent) > 0:
			gold_rels.append(gold_sent)

	test_data = []
	test_stat = []
	with open(raw_test_file) as f:
		test_sent = []
		pred_index = []
		for line in f:
			info = line.strip().split()
			if info:
				test_sent.append([info[0], info[1], info[2], info[3], info[4], info[5],
								info[6], info[7], info[8], info[9], info[10], info[11],
								info[12], '_' if unified else info[13]])
				if info[12] == 'Y':
					pred_index.append(int(info[0]))
			elif len(test_sent) > 0:
				test_data.append(test_sent)
				test_stat.append(pred_index)
				test_sent = []
				pred_index = []
		if len(test_sent) > 0:
			test_data.append(test_sent)
			test_stat.append(pred_index)

	if use_disamb_file:
		gold_disamb = []
		gold_dsent = []
		with open(disambiguation_file) as f:
			for line in f:
				info = line.strip().split()
				if info:
					gold_dsent.append(info[0])
				elif len(gold_dsent) > 0:
					gold_disamb.append(gold_dsent)
					gold_dsent = []
			if len(gold_dsent) > 0:
				gold_disamb.append(gold_dsent)

		assert len(test_data) == len(gold_disamb)

	idx = 0
	disamb_correct = 0
	disamb_total = 0
	with open(output_file, 'w') as f:
		for oidx, (test_s, pred_index) in enumerate(zip(test_data, test_stat)):
			# Work on a copy so the appended label columns leave test_data intact.
			test_sent = copy.deepcopy(test_s)
			for p_idx in pred_index:
				pred_row = test_sent[p_idx - 1]
				if unified:
					# The first label of a unified instance is the predicted sense.
					pred_row[13] = pred_row[2] + '.' + vocab.id2rel(results[idx][0])
				elif use_disamb_file:
					predicted_sense = gold_disamb[oidx][p_idx - 1]
					# Column 13 still holds the raw_test_file sense at this point.
					# NOTE(review): assumes both files use the same sense format.
					disamb_total += 1
					if predicted_sense == pred_row[13]:
						disamb_correct += 1
					pred_row[13] = predicted_sense
				elif random.random() >= disambiguation_accuracy:
					# Simulated disambiguation error: fall back to the bare lemma.
					pred_row[13] = pred_row[2]
				arg_labels = results[idx][1:] if unified else results[idx]
				for i, p_rel in enumerate(arg_labels):
					test_sent[i].append(vocab.id2rel(p_rel))
				idx += 1
			for tokens in test_sent:
				f.write('\t'.join(tokens))
				f.write('\n')
			f.write('\n')

	predict_args = 0.
	correct_args = 0.
	gold_args = 0.
	correct_preds = 0.
	gold_preds = len(gold_rels)
	num_correct = 0.
	total = 0.
	idx = 0
	for sent in gold_rels:
		for i, gold in enumerate(sent):
			pred = vocab.id2rel(rels[idx])
			idx += 1
			if unified and i == 0:
				# Position 0 of a unified instance is the predicate sense.
				if gold == pred:
					correct_preds += 1
				continue
			total += 1
			if pred != '_':
				predict_args += 1
			if gold != '_':
				gold_args += 1
				if gold == pred:
					correct_args += 1
			if gold == pred:
				num_correct += 1

	if not unified:
		if disambiguation_accuracy is None:
			# Only reachable with a disambiguation file (see the assert above);
			# previously this multiplied gold_preds by None and crashed.
			disambiguation_accuracy = float(disamb_correct) / disamb_total if disamb_total else 0.
		correct_preds = gold_preds * disambiguation_accuracy
	P = (correct_args + correct_preds) / (predict_args + gold_preds + 1e-13)
	R = (correct_args + correct_preds) / (gold_args + gold_preds + 1e-13)
	NP = correct_args / (predict_args + 1e-13)
	NR = correct_args / (gold_args + 1e-13)
	F1 = 2 * P * R / (P + R + 1e-13)
	NF1 = 2 * NP * NR / (NP + NR + 1e-13)
	# Guard against empty evaluation sets.
	accuracy = num_correct / total * 100 if total else 0.
	disamb_percent = correct_preds / gold_preds * 100 if gold_preds else 0.
	print('\teval accurate:%.4f predict:%d golden:%d correct:%d' %
				(accuracy, predict_args, gold_args, correct_args))
	print('\tP:%.4f R:%.4f F1:%.4f' % (P * 100, R * 100, F1 * 100))
	print('\tNP:%.2f NR:%.2f NF1:%.2f' % (NP * 100, NR * 100, NF1 * 100))
	print('\tpredicate disambiguation accurate:%.2f' % disamb_percent)
	# Argument list + explicit stdout file instead of a shell string.
	with open('%s.eval' % output_file, 'w') as eval_out:
		subprocess.call(['perl', '../lib/eval.pl', '-g', raw_test_file, '-s', output_file],
						stdout=eval_out)
	return NF1, F1
예제 #6
0
import os

from lib import Vocab, DataLoader

# Smoke test: build a vocabulary from the training file (a CTB 5.1
# Zhang & Clark dependency conversion, judging by the path) and print the
# shapes of the first four arrays of one shuffled batch from DataLoader.
root_path = '/home/clementine/projects/树库转换/ctb51_processed/dep_zhang_clark_conlu/'
training = os.path.join(root_path, 'train.txt.3.pa.gs.tab.conllu')
vocab = Vocab(training, pret_file=None, min_occur_count=2)

ctb_loader = DataLoader(training, 40, vocab)
for batch in ctb_loader.get_batches(40, shuffle=True):
    # Only the first batch is inspected; stop right after printing it.
    shapes = [component.shape for component in batch[:4]]
    print(shapes[0], shapes[1], shapes[2], shapes[3])
    break