def test(parser, vocab, num_buckets_test, test_batch_size, test_file, output_file):
    """Parse ``test_file``, write the predicted trees, and score them.

    Every sentence of ``test_file`` is run through ``parser`` in inference
    mode (batches are not shuffled). The predicted head of each token replaces
    column 7 and the predicted relation label replaces column 8 of a copy of
    the input, which is written to ``output_file``. The copy is then scored
    with the external ``eval.pl`` script.

    Args:
        parser: model whose ``run(words, tags, isTrain=False)`` yields one
            output per sentence. Each output is indexed as ``output[0]``
            (per-token heads) and ``output[1]`` (per-token relation ids).
        vocab: vocabulary; ``id2rel`` maps relation ids back to labels.
        num_buckets_test: number of length buckets for the ``DataLoader``.
        test_batch_size: batch size passed to ``get_batches``.
        test_file: 10-column CoNLL-style input file.
        output_file: path the predictions are written to.

    Returns:
        ``(LAS, UAS)`` as floats parsed from the ``eval.pl`` summary.
    """
    data_loader = DataLoader(test_file, num_buckets_test, vocab)
    # record[i] is the original position of the i-th sentence produced by
    # the loader, so predictions can be put back into file order.
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, tags, _arcs, _rels in data_loader.get_batches(batch_size=test_batch_size, shuffle=False):
        dy.renew_cg()
        outputs = parser.run(words, tags, isTrain=False)
        for output in outputs:
            results[record[idx]] = output
            idx += 1

    # Flatten per-sentence predictions into token-level lists in one linear
    # pass (reduce over ``+`` was quadratic and failed on an empty test set).
    arcs = [arc for result in results for arc in result[0]]
    rels = [rel for result in results for rel in result[1]]

    idx = 0
    with open(test_file) as f:
        with open(output_file, 'w') as fo:
            for line in f:
                info = line.strip().split()
                if info:
                    assert len(info) == 10, 'Illegal line: %s' % line
                    info[6] = str(arcs[idx])
                    info[7] = vocab.id2rel(rels[idx])
                    fo.write('\t'.join(info) + '\n')
                    idx += 1
                else:
                    fo.write('\n')

    # NOTE(review): paths are interpolated into shell commands unquoted.
    # This is fine for trusted config paths, but not for arbitrary input.
    os.system('perl eval.pl -q -b -g %s -s %s -o tmp' % (test_file, output_file))
    os.system('tail -n 3 tmp > score_tmp')
    # The first two of the last three eval.pl lines hold LAS and UAS; the
    # score is the second-to-last whitespace-separated field.
    with open('score_tmp') as score_file:
        LAS, UAS = [float(line.strip().split()[-2]) for line in score_file.readlines()[:2]]
    print('LAS %.2f, UAS %.2f' % (LAS, UAS))
    os.system('rm tmp score_tmp')
    return LAS, UAS
Example #2
0
	argparser.add_argument('--config_file', default='./config.cfg')
	argparser.add_argument('--model', default='BaseParser')
	args, extra_args = argparser.parse_known_args()
	config = Configurable(args.config_file, extra_args)
	Parser = getattr(models, args.model)

	vocab = Vocab(config.train_file, config.pretrained_embeddings_file, config.min_occur_count)
	cPickle.dump(vocab, open(config.save_vocab_path, 'w'))
	parser = Parser(vocab, config.word_dims, config.pret_dims, config.lemma_dims, config.flag_dims, config.tag_dims, config.dropout_emb, 
					config.encoder_type, config.use_si_dropout,
					config.lstm_layers, config.lstm_hiddens, config.dropout_lstm_input, config.dropout_lstm_hidden, config.mlp_rel_size, config.dropout_mlp, 
					config.transformer_layers, config.transformer_heads, config.transformer_hiddens, config.transformer_ffn, config.transformer_dropout, config.transformer_maxlen, config.transformer_max_timescale,
					config.use_lm, config.lm_path, config.lm_dims, config.lm_hidden_size, config.lm_sentences,
					config.use_pos, config.use_lemma,
					config.unified)
	data_loader = DataLoader(config.train_file, config.num_buckets_train, vocab)
	pc = parser.parameter_collection
	trainer = dy.AdamTrainer(pc, config.learning_rate , config.beta_1, config.beta_2, config.epsilon)
	# optimizer = OptimizerManager(parser, lr_scheduler_type=config.lr_scheduler_type, optim=config.optim, warmup_steps=config.warmup_steps, 
	# 							eta=config.eta, patience=config.patience, clip=config.clip,
	# 							lr = config.learning_rate, beta1=config.beta_1, beta2 = config.beta_2, epsilon=config.epsilon)
	
	global_step = 0
	def update_parameters():
		"""Set the decayed learning rate, then apply one AdamTrainer update.

		The rate is learning_rate * decay ** (global_step / decay_steps).
		global_step is read from the enclosing scope when this is called.
		NOTE(review): this snippet uses cPickle (Python 2). If global_step and
		config.decay_steps are both ints, `/` floors and the decay becomes a
		staircase rather than smooth -- confirm that is intended.
		"""
		trainer.learning_rate =config.learning_rate*config.decay**(global_step / config.decay_steps)
		trainer.update()

	epoch = 0
	best_F1 = 0.
	history = lambda x, y : open(os.path.join(config.save_dir, 'valid_history'),'a').write('%.2f %.2f\n'%(x,y))
	while global_step < config.train_iters:
Example #3
0
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--config_file', default='../configs/default.cfg')
    argparser.add_argument('--model', default='BaseParser')
    args, extra_args = argparser.parse_known_args()
    config = Configurable(args.config_file, extra_args)
    Parser = getattr(models, args.model)

    vocab = Vocab(config.train_file, config.pretrained_embeddings_file,
                  config.min_occur_count)
    pickle.dump(vocab, open(config.save_vocab_path, 'wb'))
    parser = Parser(vocab, config.word_dims, config.tag_dims,
                    config.dropout_emb, config.lstm_layers,
                    config.lstm_hiddens, config.dropout_lstm_input,
                    config.dropout_lstm_hidden, config.mlp_arc_size,
                    config.mlp_rel_size, config.dropout_mlp)
    data_loader = DataLoader(config.train_file, config.num_buckets_train,
                             vocab)
    pc = parser.parameter_collection
    trainer = dy.AdamTrainer(pc, config.learning_rate, config.beta_1,
                             config.beta_2, config.epsilon)

    global_step = 0

    def update_parameters():
        """Set the decayed learning rate, then apply one AdamTrainer update.

        The rate is ``learning_rate * decay ** (global_step / decay_steps)``.
        ``global_step`` is looked up in the enclosing scope at call time, so
        the value used is whatever the surrounding code has set it to.
        NOTE(review): on Python 2 with int operands, ``/`` floors, which gives
        staircase decay; on Python 3 it is smooth -- confirm which is intended.
        """
        trainer.learning_rate = config.learning_rate * config.decay**(
            global_step / config.decay_steps)
        trainer.update()

    epoch = 0
    best_UAS = 0.
    history = lambda x, y: open(os.path.join(config.save_dir, 'valid_history'),
                                'a').write('%.2f %.2f\n' % (x, y))
Example #4
0
def test(parser, vocab, num_buckets_test, test_batch_size, unlabeled_test_file,
         labeled_test_file, raw_test_file, output_file):
    data_loader = DataLoader(unlabeled_test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, lemmas, tags, arcs, rels in \
      data_loader.get_batches(batch_size = test_batch_size, shuffle = False):
        dy.renew_cg()
        outputs = parser.run(words, lemmas, tags, arcs, isTrain=False)
        for output in outputs:
            sent_idx = record[idx]
            results[sent_idx] = output
            idx += 1

    global gold_preds
    global gold_args
    global gold_rels
    global test_data

    if not gold_rels:
        print 'prepare test gold data'
        gold_rels = []
        _predicate_cnt = 0
        _preds_args = defaultdict(list)
        _end_flag = False

        with open(labeled_test_file) as f:
            for line in f:
                info = line.strip().split()
                if info:
                    assert len(info) == 10, 'Illegal line: %s' % line
                    if _end_flag and len(_preds_args) > 0 and int(
                            info[6]) == 0:
                        gold_rels.append(_preds_args)
                        _preds_args = defaultdict(list)
                        _end_flag = False
                    _preds_args[int(info[6])].append(vocab.rel2id(info[7]))
                    if info[7] != '_':
                        if int(info[6]) == 0:
                            gold_preds += 1
                        else:
                            gold_args += 1
                else:
                    _end_flag = True
            if len(_preds_args) > 0:
                gold_rels.append(_preds_args)

    if not test_data:
        print 'prepare for writing out the prediction'
        with open(raw_test_file) as f:
            test_data = []
            test_sent = []
            for line in f:
                info = line.strip().split()
                if info:
                    test_sent.append([
                        info[0], info[1], info[2], info[3], info[4], info[5],
                        info[6], info[7], info[8], info[9], info[10], info[11],
                        '_', '_'
                    ])
                elif len(test_sent) > 0:
                    test_data.append(test_sent)
                    test_sent = []

            if len(test_sent) > 0:
                test_data.append(test_sent)
                test_sent = []

    with open(output_file, 'w') as f:
        for test_sent, predict_sent, gold_sent in zip(test_data, results,
                                                      gold_rels):
            for i, rel in enumerate(predict_sent[0]):
                if rel != vocab.NONE:
                    test_sent[i][12] = 'Y'
                    test_sent[i][13] = '%s.%s' % (test_sent[i][2],
                                                  vocab.id2rel(rel))

            for k in xrange(1, len(test_sent) + 1):
                if k in gold_sent and k in predict_sent:
                    for i, (prel, grel) in enumerate(
                            zip(predict_sent[k], gold_sent[k])):
                        test_sent[i].append(vocab.id2rel(grel))
                        test_sent[i].append(vocab.id2rel(prel))
                else:
                    if k in gold_sent:
                        for i, grel in enumerate(gold_sent[k]):
                            test_sent[i].append(vocab.id2rel(grel))
                    if k in predict_sent:
                        for i, prel in enumerate(predict_sent[k]):
                            test_sent[i].append(vocab.id2rel(prel))

            for tokens in test_sent:
                f.write('\t'.join(tokens))
                f.write('\n')
            f.write('\n')

    predict_preds = 0.
    correct_preds = 0.
    predict_args = 0.
    correct_args = 0.
    num_correct = 0.
    total = 0.
    for psent, gsent in zip(results, gold_rels):
        predict_preds += len(psent) - 1
        for g_pred, g_args in gsent.iteritems():
            if g_pred == 0:
                p_args = psent.pop(g_pred)
                for p_rel, g_rel in zip(p_args, g_args):
                    if g_rel != vocab.NONE and g_rel == p_rel:
                        correct_preds += 1
            else:
                if g_pred in psent:
                    p_args = psent.pop(g_pred)
                    for p_rel, g_rel in zip(p_args, g_args):
                        total += 1
                        if p_rel != vocab.NONE:
                            predict_args += 1
                        if g_rel != vocab.NONE and g_rel == p_rel:
                            correct_args += 1
                        if g_rel == p_rel:
                            num_correct += 1
                else:
                    for i in xrange(len(g_args)):
                        total += 1

        for p_pred, p_args in psent.iteritems():
            for p_rel in p_args:
                if p_rel != vocab.NONE:
                    predict_args += 1

    print 'arguments: correct:%d, gold:%d, predicted:%d' % (
        correct_args, gold_args, predict_args)
    print 'predicates: correct:%d, gold:%d, predicted:%d' % (
        correct_preds, gold_preds, predict_preds)

    P = (correct_args + correct_preds) / (predict_args + predict_preds + 1e-13)
    R = (correct_args + correct_preds) / (gold_args + gold_preds + 1e-13)
    NP = correct_args / (predict_args + 1e-13)
    NR = correct_args / (gold_args + 1e-13)
    PP = correct_preds / (predict_preds + 1e-13)
    PR = correct_preds / (gold_preds + 1e-13)
    PF1 = 2 * PP * PR / (PP + PR + 1e-13)
    F1 = 2 * P * R / (P + R + 1e-13)
    NF1 = 2 * NP * NR / (NP + NR + 1e-13)

    print '\teval accurate:%.4f predict:%d golden:%d correct:%d' % \
       (num_correct / total * 100, predict_args, gold_args, correct_args)
    print '\tP:%.4f R:%.4f F1:%.4f' % (P * 100, R * 100, F1 * 100)
    print '\tNP:%.4f NR:%.4f NF1:%.4f' % (NP * 100, NR * 100, NF1 * 100)
    print '\tcorrect predicate:%d \tgold predicate:%d' % (correct_preds,
                                                          gold_preds)
    print '\tpredicate disambiguation PP:%.4f PR:%.4f PF1:%.4f' % (
        PP * 100, PR * 100, PF1 * 100)
    os.system('perl ../lib/eval.pl -g %s -s %s > %s.eval' %
              (raw_test_file, output_file, output_file))
    return NF1, F1
Example #5
0
def test(parser, vocab, num_buckets_test, test_batch_size, pro_test_file, raw_test_file, output_file,
	unified = True, disambiguation_file = None, disambiguation_accuracy = None):
	if not unified:
		assert ((disambiguation_file is not None and len(disambiguation_file)>0) or disambiguation_accuracy is not None), \
				'The accuracy of predicate disambiguation shold be provied.'
	data_loader = DataLoader(pro_test_file, num_buckets_test, vocab)
	record = data_loader.idx_sequence
	results = [None] * len(record)
	idx = 0
	for words, lemmas, tags, arcs, rels, syn_masks, seq_lens in \
			data_loader.get_batches(batch_size = test_batch_size, shuffle = False):
		dy.renew_cg()
		outputs = parser.run(words, lemmas, tags, arcs, isTrain = False, syn_mask = syn_masks, seq_lens=seq_lens)
		for output in outputs:
			sent_idx = record[idx]
			results[sent_idx] = output
			idx += 1
	rels = reduce(lambda x, y: x + y, [ list(result) for result in results ])
	
	# global gold_rels
	# global test_data
	# global test_stat

	# if not gold_rels:
	gold_rels = []
	gold_sent = []
	with open(pro_test_file) as f:
		for line in f:
			info = line.strip().split()
			if info:
				assert len(info) == 11, 'Illegal line: %s' % line
				gold_sent.append(info[7])
			else:
				if len(gold_sent)>0:
					gold_rels.append(gold_sent)
					gold_sent = []
		if len(gold_sent)>0:
			gold_rels.append(gold_sent)

	
	# if not test_data:
	#print 'prepare for writing out the prediction'
	with open(raw_test_file) as f:
		test_data = []
		test_sent = []
		test_stat = []
		pred_index = []
		for line in f:
			info = line.strip().split()
			if info:
				test_sent.append([info[0], info[1], info[2], info[3], info[4], info[5], 
								info[6], info[7], info[8], info[9], info[10], info[11], 
								info[12], '_' if unified else info[13]])
				if info[12] == 'Y':
					pred_index.append(int(info[0]))
			elif len(test_sent) > 0:
				test_data.append(test_sent)
				test_sent = []
				test_stat.append(pred_index)
				pred_index = []
		if len(test_sent) > 0:
			test_data.append(test_sent)
			test_sent = []
			test_stat.append(pred_index)
			pred_index = []


	if disambiguation_file is not None and len(disambiguation_file)>0:
		gold_disamb = []
		gold_dsent = []
		with open(disambiguation_file) as f:
			for line in f:
				info = line.strip().split()
				if info:
					gold_dsent.append(info[0])
				else:
					if len(gold_dsent)>0:
						gold_disamb.append(gold_dsent)
						gold_dsent = []
			if len(gold_dsent)>0:
				gold_disamb.append(gold_dsent)

		assert len(test_data) == len(gold_disamb)

	idx = 0
	oidx = 0
	with open(output_file, 'w') as f:
		for test_s, pred_index in zip(test_data, test_stat):
			test_sent = copy.deepcopy(test_s)
			for p_idx in pred_index:
				if unified:
					test_sent[p_idx - 1][13] = test_sent[p_idx - 1][2] + '.' + vocab.id2rel(results[idx][0])
				else:
					if disambiguation_file is not None and len(disambiguation_file)>0:
						test_sent[p_idx - 1][13] = gold_disamb[oidx][p_idx - 1]
					else:
						if random.random() < disambiguation_accuracy:
							test_sent[p_idx - 1][13] = test_sent[p_idx - 1][13]
						else:
							test_sent[p_idx - 1][13] = test_sent[p_idx - 1][2]
				if unified:
					for i, p_rel in enumerate(results[idx][1:]):	
						test_sent[i].append(vocab.id2rel(p_rel))
				else:
					for i, p_rel in enumerate(results[idx]):	
						test_sent[i].append(vocab.id2rel(p_rel))
				idx += 1
			oidx += 1
			for tokens in test_sent:
				f.write('\t'.join(tokens))
				f.write('\n')
			f.write('\n')
	
	predict_args = 0.
	correct_args = 0.
	gold_args = 0.
	correct_preds = 0.
	gold_preds = len(gold_rels)
	num_correct = 0.
	total = 0.
	idx = 0
	for sent in gold_rels:
		
		for i in range(len(sent)):
			gold = sent[i]
			pred = rels[idx]
			if unified and i==0:
				if sent[i] == vocab.id2rel(pred):
					correct_preds += 1
			else:
				total += 1
				if vocab.id2rel(pred)!='_':
					predict_args += 1
				if gold != '_':
					gold_args += 1
				if gold != '_' and gold == vocab.id2rel(pred):
					correct_args += 1
				if gold == vocab.id2rel(pred):
					num_correct += 1
			idx += 1
	
	correct_preds = correct_preds if unified else gold_preds * disambiguation_accuracy
	P = (correct_args + correct_preds) / (predict_args + gold_preds + 1e-13)
	R = (correct_args + correct_preds) / (gold_args + gold_preds + 1e-13)
	NP = correct_args / (predict_args + 1e-13)
	NR = correct_args / (gold_args + 1e-13)	   
	F1 = 2 * P * R / (P + R + 1e-13)
	NF1 = 2 * NP * NR / (NP + NR + 1e-13)
	print '\teval accurate:%.4f predict:%d golden:%d correct:%d' % \
				(num_correct / total * 100, predict_args, gold_args, correct_args)
	print '\tP:%.4f R:%.4f F1:%.4f' % (P * 100, R * 100, F1 * 100)
	print '\tNP:%.2f NR:%.2f NF1:%.2f' % (NP * 100, NR * 100, NF1 * 100)
	print '\tpredicate disambiguation accurate:%.2f' % (correct_preds / gold_preds * 100)
	os.system('perl ../lib/eval.pl -g %s -s %s > %s.eval' % (raw_test_file, output_file, output_file))
	return NF1, F1
Example #6
0
import os

from lib import Vocab, DataLoader

# Smoke test: build a vocabulary from a training file and print the array
# shapes of the first batch produced by DataLoader.
# NOTE(review): root_path is a machine-specific absolute path (the directory
# name "树库转换" means "treebank conversion"); adjust it for your machine.
root_path = '/home/clementine/projects/树库转换/ctb51_processed/dep_zhang_clark_conlu/'
training = os.path.join(root_path, 'train.txt.3.pa.gs.tab.conllu')
vocab = Vocab(training, pret_file=None, min_occur_count=2)

ctb_loader = DataLoader(training, 40, vocab)
# Print the .shape of the first four elements of a batch; the loop stops at
# the first batch (i == 0).
for i, example in enumerate(ctb_loader.get_batches(40, shuffle=True)):
    print(example[0].shape, example[1].shape, example[2].shape,
          example[3].shape)
    if i == 0:
        break
    argparser.add_argument('--model', default='BaseParser')
    args, extra_args = argparser.parse_known_args()
    config = Configurable(args.config_file, extra_args)
    Parser = getattr(models, args.model)

    vocab = Vocab(config.wsj_file, config.pretrained_embeddings_file,
                  config.min_occur_count)
    cPickle.dump(vocab, open(config.save_vocab_path, 'w'))
    parser = Parser(vocab, config.word_dims, config.tag_dims,
                    config.dropout_emb, config.lstm_layers,
                    config.lstm_hiddens, config.dropout_lstm_input,
                    config.dropout_lstm_hidden, config.mlp_arc_size,
                    config.mlp_rel_size, config.dropout_mlp,
                    config.filter_size, config.domain_num)
    wsj = DataLoader(config.wsj_file,
                     config.num_buckets_train,
                     vocab,
                     isTrain=True)
    answer = DataLoader(config.answers_file,
                        config.num_buckets_train,
                        vocab,
                        isTrain=True,
                        len_counter=wsj.len_counter)
    pc = parser.parameter_collection

    trainer = dy.AdamTrainer(pc, config.learning_rate, config.beta_1,
                             config.beta_2, config.epsilon)

    data = []

    for i, item in enumerate([wsj, answer]):
        for words, tags, arcs, rels in item.get_batches(