def test(parser, vocab, num_buckets_test, test_batch_size, test_file,
         output_file, debug=False):
    """Evaluate *parser* on a CoNLL test file and score it with eval.pl.

    Runs the parser over bucketed batches, restores corpus order via the
    loader's ``idx_sequence``, writes predicted heads/relations into columns
    6/7 of a copy of *test_file*, then shells out to ``run/eval.pl`` to
    compute attachment scores.

    Parameters
    ----------
    parser : project parser object exposing ``run(words, tags, is_train=False)``
    vocab : project vocabulary exposing ``id2rel``
    num_buckets_test : int — bucket count handed to DataLoader
    test_batch_size : int — batch size for inference
    test_file : str — path to the gold CoNLL file
    output_file : str — path the annotated copy is written to
    debug : bool — if True, only the first 1000 lines of *test_file* are used

    Returns
    -------
    (LAS, UAS) : tuple of float — labeled / unlabeled attachment scores
        as parsed from the eval.pl report.

    Raises
    ------
    ValueError — if a non-empty line of *test_file* does not have 10 columns.
    """
    data_loader = DataLoader(test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    # Batches come out bucket-shuffled; record[idx] maps each prediction
    # back to its original sentence position.
    for words, tags, arcs, rels in data_loader.get_batches(
            batch_size=test_batch_size, shuffle=False):
        outputs = parser.run(words, tags, is_train=False)
        for output in outputs:
            sent_idx = record[idx]
            results[sent_idx] = output
            idx += 1
    # Flatten (arcs, rels) per sentence into one token-level list each.
    # (The original reduce(lambda x, y: x + y, ...) rebuilt the list on
    # every step — O(n^2) in total token count; this is linear.)
    arcs = [arc for result in results for arc in list(result[0])]
    rels = [rel for result in results for rel in list(result[1])]
    idx = 0
    with open(test_file) as f:
        if debug:
            # Rebind to a truncated line list; the file object is still
            # closed by the enclosing `with`.
            f = f.readlines()[:1000]
        with open(output_file, 'w') as fo:
            for line in f:
                info = line.strip().split()
                if info:
                    # Raise instead of assert: asserts vanish under `-O`,
                    # and a malformed gold file must not pass silently.
                    if len(info) != 10:
                        raise ValueError('Illegal line: %s' % line)
                    info[6] = str(arcs[idx])
                    info[7] = vocab.id2rel(rels[idx])
                    fo.write('\t'.join(info) + '\n')
                    idx += 1
                else:
                    # Blank line = sentence boundary in CoNLL format.
                    fo.write('\n')
    # NOTE(review): os.system with interpolated paths is shell-injection
    # prone; paths come from local config here, but
    # subprocess.run([...], shell=False) would be safer.
    os.system('perl run/eval.pl -q -b -g %s -s %s -o tmp' %
              (test_file, output_file))
    os.system('tail -n 3 tmp > score_tmp')
    # Use a context manager so the score file's descriptor is released
    # (the original `open(...)` was never closed).
    with open('score_tmp') as score_file:
        LAS, UAS = [
            float(line.strip().split()[-2])
            for line in score_file.readlines()[:2]
        ]
    print('LAS %.2f, UAS %.2f' % (LAS, UAS))
    os.system('rm tmp score_tmp')
    return LAS, UAS
vocab) # trainer = dy.AdamTrainer(pc, config.learning_rate, config.beta_1, config.beta_2, config.epsilon) trainer = gluon.Trainer(parser.collect_params(), 'adam', {'learning_rate': config.learning_rate}) global_step = 0 epoch = 0 best_UAS = 0. history = lambda x, y: open( os.path.join(config.save_dir, 'valid_history'), 'a').write( '%.2f %.2f\n' % (x, y)) while global_step < config.train_iters: print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), ' Start training epoch #%d' % (epoch, )) epoch += 1 for words, tags, arcs, rels in data_loader.get_batches( batch_size=config.train_batch_size, shuffle=True): with autograd.record(): arc_accuracy, rel_accuracy, overall_accuracy, loss = parser.run( words, tags, arcs, rels) loss = loss * 0.5 loss_value = loss.asscalar() print( "Step #%d: Acc: arc %.2f, rel %.2f, overall %.2f, loss %.3f\r\r" % (global_step, arc_accuracy, rel_accuracy, overall_accuracy, loss_value)) # trainer.set_learning_rate(config.learning_rate * config.decay ** (global_step / config.decay_steps)) loss.backward() trainer.step(config.train_batch_size) global_step += 1 if global_step % config.validate_every == 0: print('\nTest on development set')