Code Example #1
File: test.py Project: szha/DeepBiaffineParserMXNet
import os
from functools import reduce  # reduce is not a builtin in Python 3

# DataLoader is this project's bucketed data loader, defined elsewhere in the repo.
def test(parser,
         vocab,
         num_buckets_test,
         test_batch_size,
         test_file,
         output_file,
         debug=False):
    # Bucket the test set by sentence length; idx_sequence remembers the
    # order in which sentences come back so results can be restored later.
    data_loader = DataLoader(test_file, num_buckets_test, vocab)
    record = data_loader.idx_sequence
    results = [None] * len(record)
    idx = 0
    for words, tags, arcs, rels in data_loader.get_batches(
            batch_size=test_batch_size, shuffle=False):
        outputs = parser.run(words, tags, is_train=False)
        # Scatter the batch outputs back to their original sentence positions.
        for output in outputs:
            sent_idx = record[idx]
            results[sent_idx] = output
            idx += 1

    # Concatenate the per-sentence head and relation predictions into flat lists.
    arcs = reduce(lambda x, y: x + y, [list(result[0]) for result in results])
    rels = reduce(lambda x, y: x + y, [list(result[1]) for result in results])
    idx = 0
    with open(test_file) as f:
        if debug:
            f = f.readlines()[:1000]  # debug mode: only rewrite the first 1000 lines
        with open(output_file, 'w') as fo:
            for line in f:
                info = line.strip().split()
                if info:
                    # CoNLL format: 10 tab-separated columns per token; columns
                    # 7 and 8 (indices 6 and 7) hold the head and the relation.
                    assert len(info) == 10, 'Illegal line: %s' % line
                    info[6] = str(arcs[idx])
                    info[7] = vocab.id2rel(rels[idx])
                    fo.write('\t'.join(info) + '\n')
                    idx += 1
                else:
                    fo.write('\n')

    # Score the predictions with the CoNLL evaluation script, then read
    # LAS and UAS from the tail of its report.
    os.system('perl run/eval.pl -q -b -g %s -s %s -o tmp' %
              (test_file, output_file))
    os.system('tail -n 3 tmp > score_tmp')
    LAS, UAS = [
        float(line.strip().split()[-2])
        for line in open('score_tmp').readlines()[:2]
    ]
    print('LAS %.2f, UAS %.2f' % (LAS, UAS))
    os.system('rm tmp score_tmp')
    return LAS, UAS
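The subtle step above is restoring corpus order after bucketed batching: idx_sequence maps each batch position back to a sentence's original index. Below is a minimal self-contained sketch of that scatter-by-index pattern; all names and values here are illustrative, not from the project.

# Standalone illustration of the order-restoring pattern used in test().
record = [2, 0, 3, 1]                      # batch position -> original index
batch_outputs = ['s2', 's0', 's3', 's1']   # outputs arrive in batch order

results = [None] * len(record)
for idx, output in enumerate(batch_outputs):
    results[record[idx]] = output          # scatter back to original slots

assert results == ['s0', 's1', 's2', 's3']  # original corpus order restored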
Code Example #2
File: train.py Project: szha/DeepBiaffineParserMXNet
                                 vocab)  # excerpt begins mid-statement; the call closed here is truncated in the source
        # The DyNet trainer from the original implementation, kept for reference:
        # trainer = dy.AdamTrainer(pc, config.learning_rate, config.beta_1, config.beta_2, config.epsilon)
        trainer = gluon.Trainer(parser.collect_params(), 'adam',
                                {'learning_rate': config.learning_rate})

        global_step = 0
        epoch = 0
        best_UAS = 0.
        # Append each pair of validation scores to a history file in save_dir.
        def history(x, y):
            with open(os.path.join(config.save_dir, 'valid_history'), 'a') as f:
                f.write('%.2f %.2f\n' % (x, y))
        while global_step < config.train_iters:
            print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                  ' Start training epoch #%d' % (epoch, ))
            epoch += 1
            for words, tags, arcs, rels in data_loader.get_batches(
                    batch_size=config.train_batch_size, shuffle=True):
                with autograd.record():
                    arc_accuracy, rel_accuracy, overall_accuracy, loss = parser.run(
                        words, tags, arcs, rels)
                    loss = loss * 0.5
                    # asscalar() forces the lazy forward computation so the
                    # loss can be printed.
                    loss_value = loss.asscalar()
                    print(
                        "Step #%d: Acc: arc %.2f, rel %.2f, overall %.2f, loss %.3f\r\r"
                        % (global_step, arc_accuracy, rel_accuracy,
                           overall_accuracy, loss_value))
                    # trainer.set_learning_rate(config.learning_rate * config.decay ** (global_step / config.decay_steps))
                loss.backward()
                trainer.step(config.train_batch_size)
                global_step += 1
                if global_step % config.validate_every == 0:
                    print('\nTest on development set')
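
The commented-out trainer.set_learning_rate(...) line in the loop above implements an exponential learning-rate decay schedule. Below is a self-contained sketch of that schedule; the constants are illustrative stand-ins, since the project's actual config values are not shown in this excerpt.

# Exponential decay: lr(t) = base_lr * decay ** (t / decay_steps)
learning_rate = 2e-3    # illustrative constants, not the project's config
decay = 0.75
decay_steps = 5000

for global_step in (0, 5000, 10000, 20000):
    lr = learning_rate * decay ** (global_step / decay_steps)
    print('step %6d -> lr %.6f' % (global_step, lr))
# step      0 -> lr 0.002000
# step   5000 -> lr 0.001500
# step  10000 -> lr 0.001125
# step  20000 -> lr 0.000633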