예제 #1
0
파일: parser.py 프로젝트: vin-ivar/uuparser
def run(om, options, i):

    if options.multiling:
        outdir = options.outdir
    else:
        cur_treebank = om.languages[i]
        outdir = cur_treebank.outdir

    if options.shared_task:
        outdir = options.shared_task_outdir

    if not options.predict:  # training

        print 'Preparing vocab'
        if options.multiling:
            path_is_dir = True,
            words, w2i, pos, cpos, rels, langs, ch = utils.vocab(om.languages,\
                                                                 path_is_dir,
                                                                 options.shareWordLookup,\
                                                                 options.shareCharLookup)

        else:
            words, w2i, pos, cpos, rels, langs, ch = utils.vocab(
                cur_treebank.trainfile)

        paramsfile = os.path.join(outdir, options.params)
        with open(paramsfile, 'w') as paramsfp:
            print 'Saving params to ' + paramsfile
            pickle.dump((words, w2i, pos, rels, cpos, langs, options, ch),
                        paramsfp)
            print 'Finished collecting vocab'

        print 'Initializing blstm arc hybrid:'
        parser = ArcHybridLSTM(words, pos, rels, cpos, langs, w2i, ch, options)
        if options.continueModel is not None:
            parser.Load(options.continueModel)

        for epoch in xrange(options.first_epoch,
                            options.first_epoch + options.epochs):

            print 'Starting epoch ' + str(epoch)

            if options.multiling:
                traindata = list(
                    utils.read_conll_dir(om.languages, "train",
                                         options.max_sentences))
            else:
                traindata = list(
                    utils.read_conll(cur_treebank.trainfile,
                                     cur_treebank.iso_id,
                                     options.max_sentences))

            parser.Train(traindata)
            print 'Finished epoch ' + str(epoch)

            model_file = os.path.join(outdir, options.model + str(epoch))
            parser.Save(model_file)

            if options.pred_dev:  # use the model to predict on dev data

                if options.multiling:
                    pred_langs = [
                        lang for lang in om.languages if lang.pred_dev
                    ]  # languages which have dev data on which to predict
                    for lang in pred_langs:
                        lang.outfilename = os.path.join(
                            lang.outdir, 'dev_epoch_' + str(epoch) + '.conllu')
                        print "Predicting on dev data for " + lang.name
                    devdata = utils.read_conll_dir(pred_langs, "dev")
                    pred = list(parser.Predict(devdata))
                    if len(pred) > 0:
                        utils.write_conll_multiling(pred, pred_langs)
                    else:
                        print "Warning: prediction empty"
                    if options.pred_eval:
                        for lang in pred_langs:
                            print "Evaluating dev prediction for " + lang.name
                            utils.evaluate(lang.dev_gold, lang.outfilename,
                                           om.conllu)
                else:  # monolingual case
                    if cur_treebank.pred_dev:
                        print "Predicting on dev data for " + cur_treebank.name
                        devdata = utils.read_conll(cur_treebank.devfile,
                                                   cur_treebank.iso_id)
                        cur_treebank.outfilename = os.path.join(
                            outdir, 'dev_epoch_' + str(epoch) +
                            ('.conll' if not om.conllu else '.conllu'))
                        pred = list(parser.Predict(devdata))
                        utils.write_conll(cur_treebank.outfilename, pred)
                        if options.pred_eval:
                            print "Evaluating dev prediction for " + cur_treebank.name
                            score = utils.evaluate(cur_treebank.dev_gold,
                                                   cur_treebank.outfilename,
                                                   om.conllu)
                            if options.model_selection:
                                if score > cur_treebank.dev_best[1]:
                                    cur_treebank.dev_best = [epoch, score]

            if epoch == options.epochs:  # at the last epoch choose which model to copy to barchybrid.model
                if not options.model_selection:
                    best_epoch = options.epochs  # take the final epoch if model selection off completely (for example multilingual case)
                else:
                    best_epoch = cur_treebank.dev_best[
                        0]  # will be final epoch by default if model selection not on for this treebank
                    if cur_treebank.model_selection:
                        print "Best dev score of " + str(
                            cur_treebank.dev_best[1]
                        ) + " found at epoch " + str(cur_treebank.dev_best[0])

                bestmodel_file = os.path.join(
                    outdir, "barchybrid.model" + str(best_epoch))
                model_file = os.path.join(outdir, "barchybrid.model")
                print "Copying " + bestmodel_file + " to " + model_file
                copyfile(bestmodel_file, model_file)

    else:  #if predict - so

        if options.multiling:
            modeldir = options.modeldir
        else:
            modeldir = om.languages[i].modeldir

        params = os.path.join(modeldir, options.params)
        print 'Reading params from ' + params
        with open(params, 'r') as paramsfp:
            words, w2i, pos, rels, cpos, langs, stored_opt, ch = pickle.load(
                paramsfp)

            parser = ArcHybridLSTM(words, pos, rels, cpos, langs, w2i, ch,
                                   stored_opt)
            model = os.path.join(modeldir, options.model)
            parser.Load(model)

            if options.multiling:
                testdata = utils.read_conll_dir(om.languages, "test")
            else:
                testdata = utils.read_conll(cur_treebank.testfile,
                                            cur_treebank.iso_id)

            ts = time.time()

            if options.multiling:
                for l in om.languages:
                    l.outfilename = os.path.join(outdir, l.outfilename)
                pred = list(parser.Predict(testdata))
                utils.write_conll_multiling(pred, om.languages)
            else:
                if cur_treebank.outfilename:
                    cur_treebank.outfilename = os.path.join(
                        outdir, cur_treebank.outfilename)
                else:
                    cur_treebank.outfilename = os.path.join(
                        outdir,
                        'out' + ('.conll' if not om.conllu else '.conllu'))
                utils.write_conll(cur_treebank.outfilename,
                                  parser.Predict(testdata))

            te = time.time()

            if options.pred_eval:
                if options.multiling:
                    for l in om.languages:
                        print "Evaluating on " + l.name
                        score = utils.evaluate(l.test_gold, l.outfilename,
                                               om.conllu)
                        print "Obtained LAS F1 score of %.2f on %s" % (score,
                                                                       l.name)
                else:
                    print "Evaluating on " + cur_treebank.name
                    score = utils.evaluate(cur_treebank.test_gold,
                                           cur_treebank.outfilename, om.conllu)
                    print "Obtained LAS F1 score of %.2f on %s" % (
                        score, cur_treebank.name)

            print 'Finished predicting'
예제 #2
0
def run(om,options,i):

    if options.multiling:
        outdir = options.outdir
    else:
        cur_treebank = om.languages[i]
        outdir = cur_treebank.outdir

    if options.shared_task:
        outdir = options.shared_task_outdir

    if not options.predict: # training

        fineTune = False
        start_from = 1
        if options.continueModel is None:
            continueTraining = False
        else:
            continueTraining = True
            trainedModel = options.continueModel
            if options.fineTune:
                fineTune = True
            else:
                start_from = options.first_epoch - 1

        if not continueTraining:
            print 'Preparing vocab'
            if options.multiling:
                path_is_dir=True,
                words, w2i, pos, cpos, rels, langs, ch = utils.vocab(om.languages,\
                                                                     path_is_dir,
                                                                     options.shareWordLookup,\
                                                                     options.shareCharLookup)

            else:
                words, w2i, pos, cpos, rels, langs, ch = utils.vocab(cur_treebank.trainfile)

            paramsfile = os.path.join(outdir, options.params)
            with open(paramsfile, 'w') as paramsfp:
                print 'Saving params to ' + paramsfile
                pickle.dump((words, w2i, pos, rels, cpos, langs,
                             options, ch), paramsfp)
                print 'Finished collecting vocab'
        else:
            paramsfile = os.path.join(outdir, options.params)
            with open(paramsfile, 'rb') as paramsfp:
                print 'Load params from ' + paramsfile
                words, w2i, pos, rels, cpos, langs, options, ch = pickle.load(paramsfp)
                print 'Finished loading vocab'

        max_epochs = options.first_epoch + options.epochs
        print 'Initializing blstm arc hybrid:'
        parser = ArcHybridLSTM(words, pos, rels, cpos, langs, w2i,
                               ch, options)

        if continueTraining:
            if not fineTune: 
                # continue training only, not doing fine tuning
                options.first_epoch = start_from + 1
                max_epochs = options.epochs
            else:
                # fine tune model
                options.first_epoch = options.epochs + 1
                max_epochs = options.first_epoch + 15
                print 'Fine tune model for another', max_epochs - options.first_epoch, 'epochs'

            parser.Load(trainedModel)
            

        best_multi_las = -1
        best_multi_epoch = 0
        
        if continueTraining:
            train_stats = codecs.open(os.path.join(outdir, 'train.stats'), 'a', encoding='utf-8')
        else:
            train_stats = codecs.open(os.path.join(outdir, 'train.stats'), 'w', encoding='utf-8')
                
        for epoch in xrange(options.first_epoch, max_epochs + 1):

            print 'Starting epoch ' + str(epoch)

            if options.multiling:
                traindata = list(utils.read_conll_dir(om.languages, "train", options.max_sentences))
            else:
                traindata = list(utils.read_conll(cur_treebank.trainfile, cur_treebank.iso_id,options.max_sentences))

            parser.Train(traindata)
            train_stats.write(unicode('Epoch ' + str(epoch) + '\n'))
            print 'Finished epoch ' + str(epoch)

            model_file = os.path.join(outdir, options.model + '.tmp')
            parser.Save(model_file)

            if options.pred_dev: # use the model to predict on dev data
                if options.multiling:
                    pred_langs = [lang for lang in om.languages if lang.pred_dev] # languages which have dev data on which to predict
                    for lang in pred_langs:
                        lang.outfilename = os.path.join(lang.outdir, 'dev_epoch_' + str(epoch) + '.conllu')
                        print "Predicting on dev data for " + lang.name
                    devdata = utils.read_conll_dir(pred_langs,"dev")
                    pred = list(parser.Predict(devdata))

                    if len(pred)>0:
                        utils.write_conll_multiling(pred,pred_langs)
                    else:
                        print "Warning: prediction empty"
                    
                    if options.pred_eval:
                        total_las = 0
                        for lang in pred_langs:
                            print "Evaluating dev prediction for " + lang.name
                            las_score = utils.evaluate(lang.dev_gold, lang.outfilename,om.conllu)
                            total_las += las_score
                            train_stats.write(unicode('Dev LAS ' + lang.name + ': ' + str(las_score) + '\n'))
                        if options.model_selection:
                            if total_las > best_multi_las:
                                best_multi_las = total_las
                                best_multi_epoch = epoch 

                else: # monolingual case
                    if cur_treebank.pred_dev:
                        print "Predicting on dev data for " + cur_treebank.name
                        devdata = utils.read_conll(cur_treebank.devfile, cur_treebank.iso_id)
                        cur_treebank.outfilename = os.path.join(outdir, 'dev_epoch_' + str(epoch) + ('.conll' if not om.conllu else '.conllu'))
                        pred = list(parser.Predict(devdata))
                        utils.write_conll(cur_treebank.outfilename, pred)
                        if options.pred_eval:
                            print "Evaluating dev prediction for " + cur_treebank.name
                            las_score = utils.evaluate(cur_treebank.dev_gold, cur_treebank.outfilename, om.conllu)
                            if options.model_selection:
                                if las_score > cur_treebank.dev_best[1]:
                                    cur_treebank.dev_best = [epoch, las_score]
                                    train_stats.write(unicode('Dev LAS ' + cur_treebank.name + ': ' + str(las_score) + '\n'))
                                    

            if epoch == max_epochs: # at the last epoch choose which model to copy to barchybrid.model
                if not options.model_selection:
                    best_epoch = options.epochs # take the final epoch if model selection off completely (for example multilingual case)
                else:
                    if options.multiling:
                        best_epoch = best_multi_epoch
                    else:
                        best_epoch = cur_treebank.dev_best[0] # will be final epoch by default if model selection not on for this treebank
                        if cur_treebank.model_selection:
                            print "Best dev score of " + str(cur_treebank.dev_best[1]) + " found at epoch " + str(cur_treebank.dev_best[0])

                bestmodel_file = os.path.join(outdir,"barchybrid.model.tmp")
                model_file = os.path.join(outdir,"barchybrid.model")
                if fineTune:
                    model_file = os.path.join(outdir,"barchybrid.tuned.model")
                print "Best epoch: " + str(best_epoch)
                print "Copying " + bestmodel_file + " to " + model_file
                copyfile(bestmodel_file,model_file)

        train_stats.close()

    else: #if predict - so

        # import pdb;pdb.set_trace()
        eval_type = options.evaltype
        print "Eval type: ", eval_type
        if eval_type == "train":
            if options.multiling:
                for l in om.languages:
                    l.test_gold = l.test_gold.replace('test', 'train')
            else:
                cur_treebank.testfile = cur_treebank.trainfile
                cur_treebank.test_gold = cur_treebank.trainfile

        elif eval_type == "dev":
            if options.multiling:
                for l in om.languages:
                    l.test_gold = l.test_gold.replace('test', 'dev')
            else:
                cur_treebank.testfile = cur_treebank.devfile
                cur_treebank.test_gold = cur_treebank.devfile

        if options.multiling:
            modeldir = options.modeldir
            if options.fineTune:
                prefix = [os.path.join(outdir, os.path.basename(l.test_gold) + '-tuned') for l in om.languages] 
            else:
                prefix = [os.path.join(outdir, os.path.basename(l.test_gold)) for l in om.languages] 
        else:
            modeldir = om.languages[i].modeldir
            if options.fineTune:
                prefix = os.path.join(outdir, os.path.basename(cur_treebank.testfile)) + '-tuned'
            else:
                prefix = os.path.join(outdir, os.path.basename(cur_treebank.testfile))

        if not options.extract_vectors:
            prefix = None


        params = os.path.join(modeldir, options.params)
        print 'Reading params from ' + params
        with open(params, 'r') as paramsfp:
            words, w2i, pos, rels, cpos, langs, stored_opt, ch = pickle.load(paramsfp)

            parser = ArcHybridLSTM(words, pos, rels, cpos, langs, w2i,
                               ch, stored_opt)

            if options.fineTune:
                options.model = options.model.replace('.model', '.tuned.model')
            model = os.path.join(modeldir, options.model)
            parser.Load(model)

            if options.multiling:
                testdata = utils.read_conll_dir(om.languages, eval_type)
            else:
                testdata = utils.read_conll(cur_treebank.testfile, cur_treebank.iso_id)

            ts = time.time()

            if options.multiling:
                for l in om.languages:
                    l.outfilename = os.path.join(outdir, eval_type + "-" + l.outfilename)
                pred = list(parser.Predict(testdata, prefix))
                utils.write_conll_multiling(pred,om.languages)
            else:
                if cur_treebank.outfilename:
                    cur_treebank.outfilename = os.path.join(outdir, eval_type + "-" + cur_treebank.outfilename)
                else:
                    cur_treebank.outfilename = os.path.join(outdir, 'out' + ('.conll' if not om.conllu else '.conllu'))
                utils.write_conll(cur_treebank.outfilename, parser.Predict(testdata, prefix))

            te = time.time()

            if options.pred_eval:
                if options.multiling:
                    for l in om.languages:
                        print "Evaluating on " + l.name
                        score = utils.evaluate(l.test_gold, l.outfilename, om.conllu)
                        print "Obtained LAS F1 score of %.2f on %s" %(score, l.name)
                else:
                    print "Evaluating on " + cur_treebank.name
                    score = utils.evaluate(cur_treebank.test_gold, cur_treebank.outfilename, om.conllu)
                    print "Obtained LAS F1 score of %.2f on %s" %(score,cur_treebank.name)

            print 'Finished predicting'
예제 #3
0
def run(om, options, i):

    if options.multiling:
        outdir = options.outdir
    else:
        cur_treebank = om.languages[i]
        outdir = cur_treebank.outdir

    if options.shared_task:
        outdir = options.shared_task_outdir

    if not options.predict:  # training

        print 'Preparing vocab'
        if options.multiling:
            words, w2i, pos, cpos, rels, langs, ch = utils.vocab(
                om.languages, path_is_dir=True)

        else:
            words, w2i, pos, cpos, rels, langs, ch = utils.vocab(
                cur_treebank.trainfile)

        paramsfile = os.path.join(outdir, options.params)
        with open(paramsfile, 'w') as paramsfp:
            print 'Saving params to ' + paramsfile
            pickle.dump((words, w2i, pos, rels, cpos, langs, options, ch),
                        paramsfp)
            print 'Finished collecting vocab'

        print 'Initializing blstm arc hybrid:'
        parser = ArcHybridLSTM(words, pos, rels, cpos, langs, w2i, ch, options)

        durations = []
        for epoch in xrange(options.first_epoch,
                            options.first_epoch + options.epochs):

            print 'Starting epoch ' + str(epoch)
            start_time = time.time()

            if options.multiling:
                traindata = list(
                    utils.read_conll_dir(om.languages, "train",
                                         options.max_sentences))
            else:
                traindata = list(
                    utils.read_conll(cur_treebank.trainfile,
                                     cur_treebank.iso_id,
                                     options.max_sentences))

            parser.Train(traindata)
            print 'Finished epoch ' + str(epoch)

            if not options.overwrite_model:
                model_file = os.path.join(outdir, options.model + str(epoch))
                parser.Save(model_file)

            if options.pred_dev:  # use the model to predict on dev data

                if options.multiling:
                    pred_langs = [
                        lang for lang in om.languages if lang.pred_dev
                    ]  # languages which have dev data on which to predict
                    for lang in pred_langs:
                        lang.outfilename = os.path.join(
                            lang.outdir, 'dev_epoch_' + str(epoch) + '.conllu')
                        print "Predicting on dev data for " + lang.name
                    devdata = utils.read_conll_dir(pred_langs, "dev")
                    pred = list(parser.Predict(devdata))
                    if len(pred) > 0:
                        utils.write_conll_multiling(pred, pred_langs)
                    else:
                        print "Warning: prediction empty"
                    if options.pred_eval:
                        for lang in pred_langs:
                            print "Evaluating dev prediction for " + lang.name
                            utils.evaluate(lang.dev_gold, lang.outfilename,
                                           om.conllu)
                else:  # monolingual case
                    if cur_treebank.pred_dev:
                        print "Predicting on dev data for " + cur_treebank.name
                        devdata = utils.read_conll(cur_treebank.devfile,
                                                   cur_treebank.iso_id)
                        cur_treebank.outfilename = os.path.join(
                            outdir, 'dev_epoch_' + str(epoch) +
                            ('.conll' if not om.conllu else '.conllu'))
                        pred = list(parser.Predict(devdata))
                        utils.write_conll(cur_treebank.outfilename, pred)
                        if options.pred_eval:
                            print "Evaluating dev prediction for " + cur_treebank.name
                            score = utils.evaluate(cur_treebank.dev_gold,
                                                   cur_treebank.outfilename,
                                                   om.conllu)
                            if options.model_selection:
                                if score > cur_treebank.dev_best[1]:
                                    cur_treebank.dev_best = [epoch, score]
                                if options.overwrite_model:
                                    print "Overwriting model due to higher dev score"
                                    model_file = os.path.join(
                                        cur_treebank.outdir, options.model)
                                    parser.Save(model_file)

            if options.deadline:
                # keep track of duration of training+eval
                now = time.time()
                duration = now - start_time
                durations.append(duration)
                # estimate when next epoch will finish
                last_five_durations = durations[-5:]
                eta = time.time() + max(last_five_durations)
                print 'Deadline in %.1f seconds' % (options.deadline - now)
                print 'ETA of next epoch in %.1f seconds' % (eta - now)
                # does it exceed the deadline?
                exceeds_deadline = eta > options.deadline
            else:
                # no deadline
                exceeds_deadline = False

            if exceeds_deadline or epoch == options.epochs:
                # at the last epoch copy the best model to barchybrid.model
                if not options.model_selection:
                    # model selection off completely (for example multilingual case)
                    # --> take the final epoch, i.e. the current epoch
                    best_epoch = epoch
                else:
                    best_epoch = cur_treebank.dev_best[
                        0]  # will be final epoch by default if model selection not on for this treebank
                    if cur_treebank.model_selection:
                        print "Best dev score of " + str(
                            cur_treebank.dev_best[1]
                        ) + " found at epoch " + str(cur_treebank.dev_best[0])

                if not options.overwrite_model:
                    bestmodel_file = os.path.join(
                        outdir, "barchybrid.model" + str(best_epoch))
                    model_file = os.path.join(outdir, "barchybrid.model")
                    print "Copying " + bestmodel_file + " to " + model_file
                    copyfile(bestmodel_file, model_file)

            if exceeds_deadline and epoch < options.epochs:
                print 'Leaving epoch loop early to avoid exceeding deadline'
                break

            if exceeds_deadline and epoch < options.epochs:
                print 'Leaving epoch loop early to avoid exceeding deadline'
                break

    else:  #if predict - so

        if options.multiling:
            modeldir = options.modeldir
        else:
            modeldir = om.languages[i].modeldir

        params = os.path.join(modeldir, options.params)
        print 'Reading params from ' + params
        with open(params, 'r') as paramsfp:
            words, w2i, pos, rels, cpos, langs, stored_opt, ch = pickle.load(
                paramsfp)

            parser = ArcHybridLSTM(words, pos, rels, cpos, langs, w2i, ch,
                                   stored_opt)
            model = os.path.join(modeldir, options.model)
            parser.Load(model)

            if options.multiling:
                testdata = utils.read_conll_dir(om.languages, "test")
            else:
                testdata = utils.read_conll(cur_treebank.testfile,
                                            cur_treebank.iso_id)

            ts = time.time()

            if options.multiling:
                for l in om.languages:
                    l.outfilename = os.path.join(outdir, l.outfilename)
                pred = list(parser.Predict(testdata))
                utils.write_conll_multiling(pred, om.languages)
            else:
                if cur_treebank.outfilename:
                    cur_treebank.outfilename = os.path.join(
                        outdir, cur_treebank.outfilename)
                else:
                    cur_treebank.outfilename = os.path.join(
                        outdir,
                        'out' + ('.conll' if not om.conllu else '.conllu'))
                utils.write_conll(cur_treebank.outfilename,
                                  parser.Predict(testdata))

            te = time.time()

            if options.pred_eval:
                if options.multiling:
                    for l in om.languages:
                        print "Evaluating on " + l.name
                        score = utils.evaluate(l.test_gold, l.outfilename,
                                               om.conllu)
                        print "Obtained LAS F1 score of %.2f on %s" % (score,
                                                                       l.name)
                else:
                    print "Evaluating on " + cur_treebank.name
                    score = utils.evaluate(cur_treebank.test_gold,
                                           cur_treebank.outfilename, om.conllu)
                    print "Obtained LAS F1 score of %.2f on %s" % (
                        score, cur_treebank.name)

            print 'Finished predicting'
예제 #4
0
def train_parser(options,
                 sentences_train=None,
                 sentences_dev=None,
                 sentences_test=None):
    """Train an arc-hybrid BiLSTM parser, predicting on dev/test each epoch.

    Sentences not supplied by the caller are loaded from the CoNLL files
    named in `options`.  After every epoch the parser is saved and, when
    dev/test data exist, predictions are written and scored with the
    CoNLL-17 evaluation script.  Only the most recent `options.max_model`
    saved models are kept on disk.
    """
    current_path = os.path.dirname(__file__)
    set_proc_name(options.title)
    # At least one feature-direction flag is required for the model to
    # have any head/dependent features at all.
    if not (options.rlFlag or options.rlMostFlag or options.headFlag):
        print(
            'You must use either --userlmost or --userl or --usehead (you can use multiple)'
        )
        sys.exit()

    # Fall back to reading sentences from the configured CoNLL files when
    # the caller did not pass them in.
    if not sentences_train:
        sentences_train = get_sentences(options.conll_train)
    if not sentences_dev:
        sentences_dev = get_sentences(options.conll_dev) \
            if options.conll_dev is not None else None
    if not sentences_test:
        sentences_test = get_sentences(options.conll_test) \
            if options.conll_test is not None else None

    print('Preparing vocab')
    words, w2i, pos, rels = tree_utils.vocab(sentences_train)
    if not os.path.exists(options.output):
        os.mkdir(options.output)
    # Persist the vocab and options so the model can be reloaded later.
    with open(os.path.join(options.output, options.params), 'wb') as paramsfp:
        pickle.dump((words, w2i, pos, rels, options), paramsfp)
    print('Finished collecting vocab')
    print('Initializing blstm arc hybrid:')
    parser = ArcHybridLSTM(words, pos, rels, w2i, options)
    for epoch in range(options.epochs):
        print('Starting epoch', epoch)
        parser.Train(sentences_train)

        def predict(sentences, gold_file, output_file):
            """Write predictions for `sentences` to `output_file` and score
            them against `gold_file` with the CoNLL-17 eval script; the
            script's report goes to `output_file` + '.txt' and is echoed
            to stdout."""

            with open(output_file, "w") as f:
                result = parser.Predict(sentences)
                for i in result:
                    f.write(i.to_string())

            eval_script = os.path.join(
                current_path, "utils/evaluation_script/conll17_ud_eval.py")
            weight_file = os.path.join(current_path,
                                       "utils/evaluation_script/weights.clas")
            eval_process = sh.python(eval_script,
                                     "-v",
                                     "-w",
                                     weight_file,
                                     gold_file,
                                     output_file,
                                     _out=output_file + '.txt')
            eval_process.wait()
            sh.cat(output_file + '.txt', _out=sys.stdout)

            print('Finished predicting {}'.format(gold_file))

        if sentences_dev:
            dev_output = os.path.join(
                options.output, 'dev_epoch_' + str(epoch + 1) + '.conllu')
            predict(sentences_dev, options.conll_dev, dev_output)

        if sentences_test:
            test_output = os.path.join(
                options.output, 'test_epoch_' + str(epoch + 1) + '.conllu')
            predict(sentences_test, options.conll_test, test_output)

        # Prune old snapshots so at most options.max_model models remain.
        for i in range(epoch + 1 - options.max_model):
            filename = os.path.join(options.output, options.model + str(i))
            if os.path.exists(filename):
                os.remove(filename)
        parser.Save(
            os.path.join(options.output, options.model + str(epoch + 1)))
예제 #5
0
def run(experiment, options):
    """Train a BiLSTM arc-hybrid parsing model on the experiment's treebanks,
    or run a previously stored model in prediction mode.

    Training mode (options.predict false): collect (or reload) the vocab,
    train for options.epochs epochs, optionally predict and evaluate on dev
    data after each epoch, and finally copy the best-scoring (or last) model
    to barchybrid.model in experiment.outdir.

    Predict mode: reload params and model from experiment.modeldir, parse the
    test data of every treebank and optionally evaluate against gold files.
    """

    if not options.predict:  # training

        paramsfile = os.path.join(experiment.outdir, options.params)

        if not options.continueTraining:
            print 'Preparing vocab'
            vocab = utils.get_vocab(experiment.treebanks, "train")
            print 'Finished collecting vocab'

            with open(paramsfile, 'w') as paramsfp:
                print 'Saving params to ' + paramsfile
                # persist vocab + options so predict mode can rebuild the
                # parser with an identical configuration
                pickle.dump((vocab, options), paramsfp)

                print 'Initializing blstm arc hybrid:'
                parser = ArcHybridLSTM(vocab, options)
        else:  # continue training an existing model
            if options.continueParams:
                paramsfile = options.continueParams
            with open(paramsfile, 'r') as paramsfp:
                stored_vocab, stored_options = pickle.load(paramsfp)
                print 'Initializing blstm arc hybrid:'
                parser = ArcHybridLSTM(stored_vocab, stored_options)

            parser.Load(options.continueModel)

        # [best epoch, best score]; defaults to the last epoch so that
        # barchybrid.model is still produced when model selection is off
        dev_best = [options.epochs, -1.0]

        for epoch in xrange(options.first_epoch, options.epochs + 1):

            print 'Starting epoch ' + str(epoch)
            # training data is re-read every epoch (read_conll_dir is lazy;
            # list() materialises it for this epoch's pass)
            traindata = list(
                utils.read_conll_dir(experiment.treebanks, "train",
                                     options.max_sentences))
            parser.Train(traindata, options)
            print 'Finished epoch ' + str(epoch)

            # checkpoint the model after every epoch
            model_file = os.path.join(experiment.outdir,
                                      options.model + str(epoch))
            parser.Save(model_file)

            if options.pred_dev:  # use the model to predict on dev data

                # not all treebanks necessarily have dev data
                pred_treebanks = [
                    treebank for treebank in experiment.treebanks
                    if treebank.pred_dev
                ]
                if pred_treebanks:
                    for treebank in pred_treebanks:
                        treebank.outfilename = os.path.join(
                            treebank.outdir,
                            'dev_epoch_' + str(epoch) + '.conllu')
                        print "Predicting on dev data for " + treebank.name
                    pred = list(parser.Predict(pred_treebanks, "dev", options))
                    utils.write_conll_multiling(pred, pred_treebanks)

                    if options.pred_eval:  # evaluate the prediction against gold data
                        mean_score = 0.0
                        for treebank in pred_treebanks:
                            score = utils.evaluate(treebank.dev_gold,
                                                   treebank.outfilename,
                                                   options.conllu)
                            print "Dev score %.2f at epoch %i for %s" % (
                                score, epoch, treebank.name)
                            mean_score += score
                        if len(pred_treebanks) > 1:  # multiling case
                            mean_score = mean_score / len(pred_treebanks)
                            print "Mean dev score %.2f at epoch %i" % (
                                mean_score, epoch)
                        if options.model_selection:
                            if mean_score > dev_best[1]:
                                dev_best = [epoch, mean_score
                                            ]  # update best dev score
                            # hack to print the word "mean" if the dev score is an average
                            mean_string = "mean " if len(
                                pred_treebanks) > 1 else ""
                            print "Best %sdev score %.2f at epoch %i" % (
                                mean_string, dev_best[1], dev_best[0])

            # at the last epoch choose which model to copy to barchybrid.model
            if epoch == options.epochs:
                bestmodel_file = os.path.join(
                    experiment.outdir, "barchybrid.model" + str(dev_best[0]))
                model_file = os.path.join(experiment.outdir,
                                          "barchybrid.model")
                print "Copying " + bestmodel_file + " to " + model_file
                copyfile(bestmodel_file, model_file)
                best_dev_file = os.path.join(experiment.outdir,
                                             "best_dev_epoch.txt")
                with open(best_dev_file, 'w') as fh:
                    print "Writing best scores to: " + best_dev_file
                    if len(experiment.treebanks) == 1:
                        fh.write("Best dev score %s at epoch %i\n" %
                                 (dev_best[1], dev_best[0]))
                    else:
                        fh.write("Best mean dev score %s at epoch %i\n" %
                                 (dev_best[1], dev_best[0]))

    else:  # predict mode

        params = os.path.join(experiment.modeldir, options.params)
        print 'Reading params from ' + params
        with open(params, 'r') as paramsfp:
            stored_vocab, stored_opt = pickle.load(paramsfp)

            # we need to update/add certain options based on new user input
            utils.fix_stored_options(stored_opt, options)

            parser = ArcHybridLSTM(stored_vocab, stored_opt)
            model = os.path.join(experiment.modeldir, options.model)
            parser.Load(model)

            ts = time.time()

            for treebank in experiment.treebanks:
                if options.predict_all_epochs:  # name outfile after epoch number in model file
                    try:
                        # m is None when the model name has no trailing digits;
                        # m.group then raises AttributeError, caught below
                        m = re.search('(\d+)$', options.model)
                        epoch = m.group(1)
                        treebank.outfilename = 'dev_epoch_%s.conllu' % epoch
                    except AttributeError:
                        raise Exception(
                            "No epoch number found in model file (e.g. barchybrid.model22)"
                        )
                if not treebank.outfilename:
                    treebank.outfilename = 'out' + (
                        '.conll' if not options.conllu else '.conllu')
                treebank.outfilename = os.path.join(treebank.outdir,
                                                    treebank.outfilename)

            pred = list(
                parser.Predict(experiment.treebanks, "test", stored_opt))
            utils.write_conll_multiling(pred, experiment.treebanks)

            te = time.time()

            if options.pred_eval:
                for treebank in experiment.treebanks:
                    print "Evaluating on " + treebank.name
                    score = utils.evaluate(treebank.test_gold,
                                           treebank.outfilename,
                                           options.conllu)
                    print "Obtained LAS F1 score of %.2f on %s" % (
                        score, treebank.name)

            print 'Finished predicting'
예제 #6
0
def run(om, options, i):
    outdir = options.output
    if options.multi_monoling:
        cur_treebank = om.languages[i]
        outdir = cur_treebank.outdir
        modelDir = cur_treebank.modelDir
    else:
        outdir = options.output
        modelDir = om.languages[i].modelDir

    if options.shared_task:
        outdir = options.shared_task_outdir

    if not options.include:
        cur_treebank = om.treebank

    if not options.predictFlag:

        print 'Preparing vocab'
        if options.multiling:
            words, w2i, pos, cpos, rels, langs, ch = utils.vocab(
                om.languages, path_is_dir=True)

        else:
            words, w2i, pos, cpos, rels, langs, ch = utils.vocab(
                cur_treebank.trainfile)

        with open(os.path.join(outdir, options.params), 'w') as paramsfp:
            pickle.dump((words, w2i, pos, rels, cpos, langs, options, ch),
                        paramsfp)
            print 'Finished collecting vocab'

        print 'Initializing blstm arc hybrid:'
        parser = ArcHybridLSTM(words, pos, rels, cpos, langs, w2i, ch, options)

        for epoch in xrange(options.first_epoch - 1,
                            options.first_epoch - 1 + options.epochs):
            if options.multiling:
                traindata = list(
                    utils.read_conll_dir(om.languages, "train",
                                         options.drop_proj, options.maxCorpus))
                devdata = enumerate(utils.read_conll_dir(om.languages, "dev"))

            else:
                conllFP = open(cur_treebank.trainfile, 'r')
                traindata = list(
                    utils.read_conll(conllFP, options.drop_proj,
                                     cur_treebank.iso_id))
                if os.path.exists(cur_treebank.devfile):
                    conllFP = open(cur_treebank.devfile, 'r')
                    devdata = enumerate(
                        utils.read_conll(conllFP, False, cur_treebank.iso_id))
                else:
                    tot_sen = len(traindata)
                    #take a bit less than 5% of train sentences for dev
                    if tot_sen > 1000:
                        import random
                        random.shuffle(traindata)
                        dev_len = int(0.05 * tot_sen)
                        #gen object * 2
                        devdata, dev_gold = itertools.tee(traindata[:dev_len])
                        devdata = enumerate(devdata)
                        dev_gold_f = os.path.join(outdir,
                                                  'dev_gold' + '.conllu')
                        utils.write_conll(dev_gold_f, dev_gold)
                        cur_treebank.dev_gold = dev_gold_f
                        traindata = traindata[dev_len:]
                    else:
                        devdata = None

            print 'Starting epoch', epoch
            parser.Train(traindata)

            if options.multiling:
                for l in om.languages:
                    l.outfilename = os.path.join(
                        l.outdir, 'dev_epoch_' + str(epoch + 1) + '.conllu')
                pred = list(parser.Predict(devdata))
                if len(pred) > 0:
                    utils.write_conll_multiling(pred, om.languages)
            else:
                cur_treebank.outfilename = os.path.join(
                    outdir, 'dev_epoch_' + str(epoch + 1) +
                    ('.conll' if not om.conllu else '.conllu'))
                if devdata:
                    pred = list(parser.Predict(devdata))
                    utils.write_conll(cur_treebank.outfilename, pred)

            if options.multiling:
                for l in om.languages:
                    utils.evaluate(l.dev_gold, l.outfilename, om.conllu)
            else:
                utils.evaluate(cur_treebank.dev_gold, cur_treebank.outfilename,
                               om.conllu)

            print 'Finished predicting dev'
            parser.Save(os.path.join(outdir, options.model + str(epoch + 1)))

    else:  #if predict - so
        params = os.path.join(modelDir, options.params)
        with open(params, 'r') as paramsfp:
            words, w2i, pos, rels, cpos, langs, stored_opt, ch = pickle.load(
                paramsfp)

            parser = ArcHybridLSTM(words, pos, rels, cpos, langs, w2i, ch,
                                   stored_opt)
            model = os.path.join(modelDir, options.model)
            parser.Load(model)

            if options.multiling:
                testdata = enumerate(utils.read_conll_dir(
                    om.languages, "test"))

            if not options.multiling:
                conllFP = open(cur_treebank.testfile, 'r')
                testdata = enumerate(
                    utils.read_conll(conllFP, False, cur_treebank.iso_id))

            ts = time.time()

            if options.multiling:
                for l in om.languages:
                    l.outfilename = os.path.join(outdir, l.outfilename)
                pred = list(parser.Predict(testdata))
                utils.write_conll_multiling(pred, om.languages)
            else:
                cur_treebank.outfilename = os.path.join(
                    outdir, cur_treebank.outfilename)
                utils.write_conll(cur_treebank.outfilename,
                                  parser.Predict(testdata))

            te = time.time()

            if options.predEval:
                if options.multiling:
                    for l in om.languages:
                        utils.evaluate(l.test_gold, l.outfilename, om.conllu)
                else:
                    utils.evaluate(cur_treebank.test_gold,
                                   cur_treebank.outfilename, om.conllu)

            print 'Finished predicting test', te - ts