Example #1
    def Predict(self, treebanks, datasplit, options):
        reached_max_swap = 0
        char_map = {}
        if options.char_map_file:
            with open(options.char_map_file, encoding='utf-8') as char_map_fh:
                char_map = json.loads(char_map_fh.read())
        # should probably use a namedtuple in get_vocab to make this prettier
        _, test_words, test_chars, _, _, _, test_treebanks, test_langs = utils.get_vocab(treebanks, datasplit, char_map)

        # get external embeddings for the set of words and chars in the
        # test vocab but not in the training vocab
        test_embeddings = defaultdict(lambda: {})
        if options.word_emb_size > 0 and options.ext_word_emb_file:
            new_test_words = \
                set(test_words) - self.feature_extractor.words.keys()

            logger.debug(f"Number of OOV word types at test time: {len(new_test_words)} (out of {len(test_words)})")

            if len(new_test_words) > 0:
                # no point loading embeddings if there are no words to look for
                for lang in test_langs:
                    embeddings = utils.get_external_embeddings(
                        options,
                        emb_file=options.ext_word_emb_file,
                        lang=lang,
                        words=new_test_words
                    )
                    test_embeddings["words"].update(embeddings)
                if len(test_langs) > 1 and test_embeddings["words"]:
                    logger.debug(
                        "External embeddings found for {0} words (out of {1})".format(
                            len(test_embeddings["words"]),
                            len(new_test_words),
                        ),
                    )

        if options.char_emb_size > 0:
            new_test_chars = \
                set(test_chars) - self.feature_extractor.chars.keys()
            logger.debug(
                f"Number of OOV char types at test time: {len(new_test_chars)} (out of {len(test_chars)})"
            )

            if len(new_test_chars) > 0:
                for lang in test_langs:
                    embeddings = utils.get_external_embeddings(
                        options,
                        emb_file=options.ext_char_emb_file,
                        lang=lang,
                        words=new_test_chars,
                        chars=True
                    )
                    test_embeddings["chars"].update(embeddings)
                if len(test_langs) > 1 and test_embeddings["chars"]:
                    logger.debug(
                        "External embeddings found for {0} chars (out of {1})".format(
                            len(test_embeddings["chars"]),
                            len(new_test_chars),
                        ),
                    )

        data = utils.read_conll_dir(treebanks, datasplit, char_map=char_map)

        pbar = tqdm.tqdm(
            data,
            desc="Parsing",
            unit="sentences",
            mininterval=1.0,
            leave=False,
            disable=options.quiet,
        )

        for iSentence, osentence in enumerate(pbar, 1):
            sentence = deepcopy(osentence)
            reached_swap_for_i_sentence = False
            max_swap = 2*len(sentence)
            iSwap = 0
            self.feature_extractor.Init(options)
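            # keep only ConllEntry tokens, dropping any non-token entries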
            conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
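            # rotate the dummy root from the front of the sentence to the end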
            conll_sentence = conll_sentence[1:] + [conll_sentence[0]]
            self.feature_extractor.getWordEmbeddings(conll_sentence, False, options, test_embeddings)
            stack = ParseForest([])
            buf = ParseForest(conll_sentence)

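            # when headFlag is set, one extra LSTM-input slot per token is
            # reserved for the head representation; hoffset accounts for it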
            hoffset = 1 if self.headFlag else 0

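            # initialise each token's LSTM input vectors; relations unseen in
            # training fall back to the special 'runk' (unknown relation) label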
            for root in conll_sentence:
                root.lstms = [root.vec] if self.headFlag else []
                root.lstms += [root.vec for _ in range(self.nnvecs - hoffset)]
                root.relation = root.relation if root.relation in self.irels else 'runk'


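            # transition loop: parse until the buffer holds only the root and
            # the stack is empty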
            while not (len(buf) == 1 and len(stack) == 0):
                scores = self.__evaluate(stack, buf, False)
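                # once the swap budget (max_swap = 2 * sentence length) is
                # spent, exclude SWAP from the candidates so parsing terminates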
                best = max(chain(*(scores if iSwap < max_swap else scores[:3])), key=itemgetter(2))
                if iSwap == max_swap and not reached_swap_for_i_sentence:
                    reached_max_swap += 1
                    reached_swap_for_i_sentence = True
                    logger.debug(f"reached max swap in {reached_max_swap:d} out of {iSentence:d} sentences")
                self.apply_transition(best, stack, buf, hoffset)
                if best[1] == SWAP:
                    iSwap += 1

            dy.renew_cg()

            #keep in memory the information we need, not all the vectors
            oconll_sentence = [entry for entry in osentence if isinstance(entry, utils.ConllEntry)]
            oconll_sentence = oconll_sentence[1:] + [oconll_sentence[0]]
            for tok_o, tok in zip(oconll_sentence, conll_sentence):
                tok_o.pred_relation = tok.pred_relation
                tok_o.pred_parent_id = tok.pred_parent_id
            yield osentence
Example #2
def run(experiment, options):
    if options.graph_based:
        from uuparser.mstlstm import MSTParserLSTM as Parser
        logger.info('Working with a graph-based parser')
    else:
        from uuparser.arc_hybrid import ArcHybridLSTM as Parser
        logger.info('Working with a transition-based parser')

    if not options.predict:  # training

        paramsfile = os.path.join(experiment.outdir, options.params)

        if not options.continueTraining:
            logger.debug('Preparing vocab')
            vocab = utils.get_vocab(experiment.treebanks, "train")
            logger.debug('Finished collecting vocab')

            with open(paramsfile, 'wb') as paramsfp:
                logger.info(f'Saving params to {paramsfile}')
                pickle.dump((vocab, options), paramsfp)

                logger.debug('Initializing the model')
                parser = Parser(vocab, options)
        else:  # continue training from a saved model
            if options.continueParams:
                paramsfile = options.continueParams
            with open(paramsfile, 'rb') as paramsfp:
                stored_vocab, stored_options = pickle.load(paramsfp)
                logger.debug('Initializing the model:')
                parser = Parser(stored_vocab, stored_options)

            parser.Load(options.continueModel)

        dev_best = [options.epochs, -1.0]  # best epoch, best score

        for epoch in range(options.first_epoch, options.epochs + 1):

            logger.info(f'Starting epoch {epoch}')
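            # read the training data (optionally capped at max_sentences) into
            # memory for this epoch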
            traindata = list(
                utils.read_conll_dir(experiment.treebanks, "train",
                                     options.max_sentences))
            parser.Train(traindata, options)
            logger.info(f'Finished epoch {epoch}')

            model_file = os.path.join(experiment.outdir,
                                      options.model + str(epoch))
            parser.Save(model_file)

            if options.pred_dev:  # use the model to predict on dev data

                # not all treebanks necessarily have dev data
                pred_treebanks = [
                    treebank for treebank in experiment.treebanks
                    if treebank.pred_dev
                ]
                if pred_treebanks:
                    for treebank in pred_treebanks:
                        treebank.outfilename = os.path.join(
                            treebank.outdir,
                            'dev_epoch_' + str(epoch) + '.conllu')
                        logger.info(
                            f"Predicting on dev data for {treebank.name}")
                    pred = list(parser.Predict(pred_treebanks, "dev", options))
                    utils.write_conll_multiling(pred, pred_treebanks)

                    if options.pred_eval:  # evaluate the prediction against gold data
                        mean_score = 0.0
                        for treebank in pred_treebanks:
                            score = utils.evaluate(treebank.dev_gold,
                                                   treebank.outfilename,
                                                   options.conllu)
                            logger.info(
                                f"Dev score {score:.2f} at epoch {epoch:d} for {treebank.name}"
                            )
                            mean_score += score
                        if len(pred_treebanks) > 1:  # multiling case
                            mean_score = mean_score / len(pred_treebanks)
                            logger.info(
                                f"Mean dev score {mean_score:.2f} at epoch {epoch:d}"
                            )
                        if options.model_selection:
                            if mean_score > dev_best[1]:
                                dev_best = [epoch, mean_score]  # update best dev score
                            # hack to print the word "mean" if the dev score is an average
                            mean_string = "mean " if len(pred_treebanks) > 1 else ""
                            logger.info(
                                f"Best {mean_string}dev score {dev_best[1]:.2f} at epoch {dev_best[0]:d}"
                            )

            # at the last epoch choose which model to copy to barchybrid.model
            if epoch == options.epochs:
                bestmodel_file = os.path.join(
                    experiment.outdir, "barchybrid.model" + str(dev_best[0]))
                model_file = os.path.join(experiment.outdir,
                                          "barchybrid.model")
                logger.info(f"Copying {bestmodel_file} to {model_file}")
                copyfile(bestmodel_file, model_file)
                best_dev_file = os.path.join(experiment.outdir,
                                             "best_dev_epoch.txt")
                with open(best_dev_file, 'w') as fh:
                    logger.info(f"Writing best scores to: {best_dev_file}")
                    if len(experiment.treebanks) == 1:
                        fh.write(
                            f"Best dev score {dev_best[1]} at epoch {dev_best[0]:d}\n"
                        )
                    else:
                        fh.write(
                            f"Best mean dev score {dev_best[1]} at epoch {dev_best[0]:d}\n"
                        )

    else:  # predict mode

        params = os.path.join(experiment.modeldir, options.params)
        logger.info(f'Reading params from {params}')
        with open(params, 'rb') as paramsfp:
            stored_vocab, stored_opt = pickle.load(paramsfp)

            # we need to update/add certain options based on new user input
            utils.fix_stored_options(stored_opt, options)

            parser = Parser(stored_vocab, stored_opt)
            model = os.path.join(experiment.modeldir, options.model)
            parser.Load(model)

            ts = time.time()

            for treebank in experiment.treebanks:
                if options.predict_all_epochs:  # name outfile after epoch number in model file
                    try:
                        m = re.search(r'(\d+)$', options.model)
                        epoch = m.group(1)
                        treebank.outfilename = f'dev_epoch_{epoch}.conllu'
                    except AttributeError:
                        raise Exception(
                            "No epoch number found in model file (e.g. barchybrid.model22)"
                        )
                if not treebank.outfilename:
                    treebank.outfilename = 'out' + (
                        '.conll' if not options.conllu else '.conllu')
                treebank.outfilename = os.path.join(treebank.outdir,
                                                    treebank.outfilename)

            pred = list(
                parser.Predict(experiment.treebanks, "test", stored_opt))
            utils.write_conll_multiling(pred, experiment.treebanks)

            te = time.time()

            if options.pred_eval:
                for treebank in experiment.treebanks:
                    logger.debug(f"Evaluating on {treebank.name}")
                    score = utils.evaluate(treebank.test_gold,
                                           treebank.outfilename,
                                           options.conllu)
                    logger.info(
                        f"Obtained LAS F1 score of {score:.2f} on {treebank.name}"
                    )

            logger.debug('Finished predicting')
Example #3
    def Predict(self, treebanks, datasplit, options):
        char_map = {}
        if options.char_map_file:
            with open(options.char_map_file, encoding='utf-8') as char_map_fh:
                char_map = json.loads(char_map_fh.read())
        # should probably use a namedtuple in get_vocab to make this prettier
        _, test_words, test_chars, _, _, _, test_treebanks, test_langs = utils.get_vocab(
            treebanks, datasplit, char_map)

        # get external embeddings for the set of words and chars in the
        # test vocab but not in the training vocab
        test_embeddings = defaultdict(lambda: {})
        if options.word_emb_size > 0 and options.ext_word_emb_file:
            new_test_words = set(test_words) - self.feature_extractor.words.keys()

            print("Number of OOV word types at test time: %i (out of %i)" %
                  (len(new_test_words), len(test_words)))

            if len(new_test_words) > 0:
                # no point loading embeddings if there are no words to look for
                for lang in test_langs:
                    embeddings = utils.get_external_embeddings(
                        options,
                        emb_file=options.ext_word_emb_file,
                        lang=lang,
                        words=new_test_words)
                    test_embeddings["words"].update(embeddings)
                if len(test_langs) > 1 and test_embeddings["words"]:
                    print("External embeddings found for %i words (out of %i)" %
                          (len(test_embeddings["words"]), len(new_test_words)))

        if options.char_emb_size > 0:
            new_test_chars = set(test_chars) - self.feature_extractor.chars.keys()
            print("Number of OOV char types at test time: %i (out of %i)" %
                  (len(new_test_chars), len(test_chars)))

            if len(new_test_chars) > 0:
                for lang in test_langs:
                    embeddings = utils.get_external_embeddings(
                        options,
                        emb_file=options.ext_char_emb_file,
                        lang=lang,
                        words=new_test_chars,
                        chars=True)
                    test_embeddings["chars"].update(embeddings)
                if len(test_langs) > 1 and test_embeddings["chars"]:
                    print("External embeddings found for %i chars (out of %i)" %
                          (len(test_embeddings["chars"]), len(new_test_chars)))

        data = utils.read_conll_dir(treebanks, datasplit, char_map=char_map)
        for iSentence, osentence in enumerate(data, 1):
            sentence = deepcopy(osentence)
            self.feature_extractor.Init(options)
            conll_sentence = [
                entry for entry in sentence
                if isinstance(entry, utils.ConllEntry)
            ]
            self.feature_extractor.getWordEmbeddings(conll_sentence, False,
                                                     options, test_embeddings)

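            # score every candidate head-dependent arc, then decode a full
            # tree from the resulting score matrix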
            scores, exprs = self.__evaluate(conll_sentence, True)
            if self.proj:
                heads = decoder.parse_proj(scores)
                # LATTICE solution to multiple roots, see
                # https://github.com/jujbob/multilingual-bist-parser/blob/master/bist-parser/bmstparser/src/mstlstm.py
                # the projective decoder may attach several tokens directly to the
                # root; keep the first root token and reattach the others to it
                rootHead = [seq for seq, head in enumerate(heads) if head == 0]
                if len(rootHead) != 1:
                    print("Sentence has multiple roots; reattaching the extra "
                          "roots to the first root token")
                    for seq in rootHead[1:]:
                        heads[seq] = rootHead[0]
                # end of multi-root handling

            else:
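                # non-projective decoding: Chu-Liu/Edmonds maximum spanning
                # tree, constrained to a single root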
                heads = chuliu_edmonds_one_root(scores.T)

            for entry, head in zip(conll_sentence, heads):
                entry.pred_parent_id = head
                entry.pred_relation = '_'

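            # second pass: choose the highest-scoring relation label for each
            # predicted arc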
            if self.labelsFlag:
                for modifier, head in enumerate(heads[1:]):
                    scores, exprs = self.__evaluateLabel(
                        conll_sentence, head, modifier + 1)
                    best_label = max(enumerate(scores), key=itemgetter(1))[0]
                    conll_sentence[modifier + 1].pred_relation = \
                        self.feature_extractor.irels[best_label]

            dy.renew_cg()

            #keep in memory the information we need, not all the vectors
            oconll_sentence = [
                entry for entry in osentence
                if isinstance(entry, utils.ConllEntry)
            ]
            for tok_o, tok in zip(oconll_sentence, conll_sentence):
                tok_o.pred_relation = tok.pred_relation
                tok_o.pred_parent_id = tok.pred_parent_id
            yield osentence