def Predict(self, treebanks, datasplit, options):
    """Parse *treebanks* with the trained transition-based (arc-hybrid) model.

    Reads the requested datasplit, loads external embeddings for
    out-of-vocabulary words/chars, runs the transition system on each
    sentence, and yields each original sentence object with
    ``pred_parent_id`` / ``pred_relation`` filled in on its ConllEntry
    tokens.

    Args:
        treebanks: treebank objects to read sentences from.
        datasplit: which split to parse (e.g. "dev" or "test").
        options: parsed command-line options.

    Yields:
        The original sentence objects, annotated with predictions.
    """
    reached_max_swap = 0
    char_map = {}
    if options.char_map_file:
        # FIX: the original opened this file and never closed it (fd leak);
        # a context manager guarantees the handle is released. json.load
        # also avoids materializing the whole file as a string first.
        with open(options.char_map_file, encoding='utf-8') as char_map_fh:
            char_map = json.load(char_map_fh)
    # should probably use a namedtuple in get_vocab to make this prettier
    _, test_words, test_chars, _, _, _, test_treebanks, test_langs = \
        utils.get_vocab(treebanks, datasplit, char_map)

    # get external embeddings for the set of words and chars in the
    # test vocab but not in the training vocab
    test_embeddings = defaultdict(lambda: {})
    if options.word_emb_size > 0 and options.ext_word_emb_file:
        new_test_words = set(test_words) - self.feature_extractor.words.keys()
        logger.debug(f"Number of OOV word types at test time: {len(new_test_words)} (out of {len(test_words)})")
        if len(new_test_words) > 0:
            # no point loading embeddings if there are no words to look for
            for lang in test_langs:
                embeddings = utils.get_external_embeddings(
                    options,
                    emb_file=options.ext_word_emb_file,
                    lang=lang,
                    words=new_test_words,
                )
                test_embeddings["words"].update(embeddings)
            if len(test_langs) > 1 and test_embeddings["words"]:
                logger.debug(
                    "External embeddings found for {0} words (out of {1})".format(
                        len(test_embeddings["words"]),
                        len(new_test_words),
                    ),
                )

    if options.char_emb_size > 0:
        new_test_chars = set(test_chars) - self.feature_extractor.chars.keys()
        logger.debug(
            f"Number of OOV char types at test time: {len(new_test_chars)} (out of {len(test_chars)})"
        )
        if len(new_test_chars) > 0:
            for lang in test_langs:
                embeddings = utils.get_external_embeddings(
                    options,
                    emb_file=options.ext_char_emb_file,
                    lang=lang,
                    words=new_test_chars,
                    chars=True,
                )
                test_embeddings["chars"].update(embeddings)
            if len(test_langs) > 1 and test_embeddings["chars"]:
                logger.debug(
                    "External embeddings found for {0} chars (out of {1})".format(
                        len(test_embeddings["chars"]),
                        len(new_test_chars),
                    ),
                )

    data = utils.read_conll_dir(treebanks, datasplit, char_map=char_map)
    pbar = tqdm.tqdm(
        data,
        desc="Parsing",
        unit="sentences",
        mininterval=1.0,
        leave=False,
        disable=options.quiet,
    )
    for iSentence, osentence in enumerate(pbar, 1):
        # work on a copy; predictions are copied back to osentence below
        sentence = deepcopy(osentence)
        reached_swap_for_i_sentence = False
        # cap the number of SWAP transitions per sentence to guarantee
        # termination of the parsing loop
        max_swap = 2 * len(sentence)
        iSwap = 0
        self.feature_extractor.Init(options)
        conll_sentence = [entry for entry in sentence
                          if isinstance(entry, utils.ConllEntry)]
        # move the root token from the front to the end of the sentence
        conll_sentence = conll_sentence[1:] + [conll_sentence[0]]
        self.feature_extractor.getWordEmbeddings(conll_sentence, False,
                                                 options, test_embeddings)
        stack = ParseForest([])
        buf = ParseForest(conll_sentence)
        hoffset = 1 if self.headFlag else 0

        for root in conll_sentence:
            root.lstms = [root.vec] if self.headFlag else []
            root.lstms += [root.vec for _ in range(self.nnvecs - hoffset)]
            # relations unseen at training time fall back to 'runk'
            root.relation = root.relation if root.relation in self.irels else 'runk'

        # parse until the buffer holds only the root and the stack is empty
        while not (len(buf) == 1 and len(stack) == 0):
            scores = self.__evaluate(stack, buf, False)
            # once the swap budget is spent, exclude SWAP (scores[3:])
            # from the candidate transitions
            best = max(chain(*(scores if iSwap < max_swap else scores[:3])),
                       key=itemgetter(2))
            if iSwap == max_swap and not reached_swap_for_i_sentence:
                reached_max_swap += 1
                reached_swap_for_i_sentence = True
                logger.debug(f"reached max swap in {reached_max_swap:d} out of {iSentence:d} sentences")
            self.apply_transition(best, stack, buf, hoffset)
            if best[1] == SWAP:
                iSwap += 1

        dy.renew_cg()

        # keep in memory the information we need, not all the vectors
        oconll_sentence = [entry for entry in osentence
                           if isinstance(entry, utils.ConllEntry)]
        oconll_sentence = oconll_sentence[1:] + [oconll_sentence[0]]
        for tok_o, tok in zip(oconll_sentence, conll_sentence):
            tok_o.pred_relation = tok.pred_relation
            tok_o.pred_parent_id = tok.pred_parent_id
        yield osentence
def run(experiment, options):
    """Entry point: train a parser, or load a saved model and predict.

    When ``options.predict`` is false: collects (or reloads) the vocab,
    trains for the configured number of epochs, optionally predicting and
    evaluating on dev data after each epoch, and finally copies the best
    model to ``barchybrid.model``. When true: loads stored params and a
    saved model, predicts on the test split, and optionally evaluates.

    Args:
        experiment: object holding treebanks plus outdir/modeldir paths.
        options: parsed command-line options.
    """
    if options.graph_based:
        from uuparser.mstlstm import MSTParserLSTM as Parser
        logger.info('Working with a graph-based parser')
    else:
        from uuparser.arc_hybrid import ArcHybridLSTM as Parser
        logger.info('Working with a transition-based parser')

    if not options.predict:  # training
        paramsfile = os.path.join(experiment.outdir, options.params)
        if not options.continueTraining:
            logger.debug('Preparing vocab')
            vocab = utils.get_vocab(experiment.treebanks, "train")
            logger.debug('Finished collecting vocab')
            with open(paramsfile, 'wb') as paramsfp:
                logger.info(f'Saving params to {paramsfile}')
                pickle.dump((vocab, options), paramsfp)
            logger.debug('Initializing the model')
            parser = Parser(vocab, options)
        else:  # continue training an existing model
            if options.continueParams:
                paramsfile = options.continueParams
            # FIX: pickle files must be opened in binary mode; the original
            # used 'r' (text mode), which makes pickle.load fail on Python 3.
            # The predict branch below already (correctly) used 'rb'.
            with open(paramsfile, 'rb') as paramsfp:
                stored_vocab, stored_options = pickle.load(paramsfp)
            logger.debug('Initializing the model:')
            parser = Parser(stored_vocab, stored_options)
            parser.Load(options.continueModel)

        dev_best = [options.epochs, -1.0]  # best epoch, best score

        for epoch in range(options.first_epoch, options.epochs + 1):
            logger.info(f'Starting epoch {epoch}')
            traindata = list(
                utils.read_conll_dir(experiment.treebanks, "train",
                                     options.max_sentences))
            parser.Train(traindata, options)
            logger.info(f'Finished epoch {epoch}')

            model_file = os.path.join(experiment.outdir,
                                      options.model + str(epoch))
            parser.Save(model_file)

            if options.pred_dev:
                # use the model to predict on dev data
                # not all treebanks necessarily have dev data
                pred_treebanks = [
                    treebank for treebank in experiment.treebanks
                    if treebank.pred_dev
                ]
                if pred_treebanks:
                    for treebank in pred_treebanks:
                        treebank.outfilename = os.path.join(
                            treebank.outdir,
                            'dev_epoch_' + str(epoch) + '.conllu')
                        logger.info(f"Predicting on dev data for {treebank.name}")
                    pred = list(parser.Predict(pred_treebanks, "dev", options))
                    utils.write_conll_multiling(pred, pred_treebanks)

                    if options.pred_eval:
                        # evaluate the prediction against gold data
                        mean_score = 0.0
                        for treebank in pred_treebanks:
                            score = utils.evaluate(treebank.dev_gold,
                                                   treebank.outfilename,
                                                   options.conllu)
                            logger.info(
                                f"Dev score {score:.2f} at epoch {epoch:d} for {treebank.name}"
                            )
                            mean_score += score
                        if len(pred_treebanks) > 1:  # multiling case
                            mean_score = mean_score / len(pred_treebanks)
                            logger.info(
                                f"Mean dev score {mean_score:.2f} at epoch {epoch:d}"
                            )
                        if options.model_selection:
                            if mean_score > dev_best[1]:
                                dev_best = [epoch, mean_score]  # update best dev score
                            # hack to print the word "mean" if the dev score is an average
                            mean_string = "mean " if len(pred_treebanks) > 1 else ""
                            logger.info(
                                f"Best {mean_string}dev score {dev_best[1]:.2f} at epoch {dev_best[0]:d}"
                            )

            # at the last epoch choose which model to copy to barchybrid.model
            if epoch == options.epochs:
                # NOTE(review): the copy uses a hard-coded "barchybrid.model"
                # prefix, whereas parser.Save above uses options.model —
                # confirm these always agree when options.model is overridden.
                bestmodel_file = os.path.join(
                    experiment.outdir, "barchybrid.model" + str(dev_best[0]))
                model_file = os.path.join(experiment.outdir,
                                          "barchybrid.model")
                logger.info(f"Copying {bestmodel_file} to {model_file}")
                copyfile(bestmodel_file, model_file)
                best_dev_file = os.path.join(experiment.outdir,
                                             "best_dev_epoch.txt")
                with open(best_dev_file, 'w') as fh:
                    logger.info(f"Writing best scores to: {best_dev_file}")
                    if len(experiment.treebanks) == 1:
                        fh.write(
                            f"Best dev score {dev_best[1]} at epoch {dev_best[0]:d}\n"
                        )
                    else:
                        fh.write(
                            f"Best mean dev score {dev_best[1]} at epoch {dev_best[0]:d}\n"
                        )

    else:  # predict
        params = os.path.join(experiment.modeldir, options.params)
        logger.info(f'Reading params from {params}')
        with open(params, 'rb') as paramsfp:
            stored_vocab, stored_opt = pickle.load(paramsfp)

        # we need to update/add certain options based on new user input
        utils.fix_stored_options(stored_opt, options)
        parser = Parser(stored_vocab, stored_opt)
        model = os.path.join(experiment.modeldir, options.model)
        parser.Load(model)

        ts = time.time()
        for treebank in experiment.treebanks:
            if options.predict_all_epochs:
                # name outfile after epoch number in model file
                try:
                    # FIX: raw string for the regex — '\d' in a plain string
                    # is an invalid escape (DeprecationWarning on py3.6+)
                    m = re.search(r'(\d+)$', options.model)
                    epoch = m.group(1)
                    treebank.outfilename = f'dev_epoch_{epoch}.conllu'
                except AttributeError:
                    # m is None when the model name carries no trailing digits
                    raise Exception(
                        "No epoch number found in model file (e.g. barchybrid.model22)"
                    )
            if not treebank.outfilename:
                treebank.outfilename = 'out' + (
                    '.conll' if not options.conllu else '.conllu')
            treebank.outfilename = os.path.join(treebank.outdir,
                                                treebank.outfilename)

        pred = list(parser.Predict(experiment.treebanks, "test", stored_opt))
        utils.write_conll_multiling(pred, experiment.treebanks)
        te = time.time()

        if options.pred_eval:
            for treebank in experiment.treebanks:
                logger.debug(f"Evaluating on {treebank.name}")
                score = utils.evaluate(treebank.test_gold,
                                       treebank.outfilename, options.conllu)
                logger.info(
                    f"Obtained LAS F1 score of {score:.2f} on {treebank.name}"
                )

        logger.debug('Finished predicting')
def Predict(self, treebanks, datasplit, options):
    """Parse *treebanks* with the trained graph-based (MST) model.

    Reads the requested datasplit, loads external embeddings for
    out-of-vocabulary words/chars, scores all head candidates per
    sentence, decodes a tree (projective or Chu-Liu/Edmonds), and yields
    each original sentence with ``pred_parent_id`` / ``pred_relation``
    set on its ConllEntry tokens.

    Args:
        treebanks: treebank objects to read sentences from.
        datasplit: which split to parse (e.g. "dev" or "test").
        options: parsed command-line options.

    Yields:
        The original sentence objects, annotated with predictions.
    """
    char_map = {}
    if options.char_map_file:
        # FIX: the original opened this file and never closed it (fd leak);
        # a context manager guarantees the handle is released. json.load
        # also avoids building an intermediate string.
        with open(options.char_map_file, encoding='utf-8') as char_map_fh:
            char_map = json.load(char_map_fh)
    # should probably use a namedtuple in get_vocab to make this prettier
    _, test_words, test_chars, _, _, _, test_treebanks, test_langs = \
        utils.get_vocab(treebanks, datasplit, char_map)

    # get external embeddings for the set of words and chars in the
    # test vocab but not in the training vocab
    # NOTE(review): the sibling (transition-based) Predict logs these
    # messages via logger.debug; this one prints to stdout. Worth
    # unifying once we confirm this module has a logger in scope.
    test_embeddings = defaultdict(lambda: {})
    if options.word_emb_size > 0 and options.ext_word_emb_file:
        new_test_words = set(test_words) - self.feature_extractor.words.keys()
        print("Number of OOV word types at test time: %i (out of %i)" %
              (len(new_test_words), len(test_words)))
        if len(new_test_words) > 0:
            # no point loading embeddings if there are no words to look for
            for lang in test_langs:
                embeddings = utils.get_external_embeddings(
                    options,
                    emb_file=options.ext_word_emb_file,
                    lang=lang,
                    words=new_test_words)
                test_embeddings["words"].update(embeddings)
            if len(test_langs) > 1 and test_embeddings["words"]:
                print("External embeddings found for %i words "\
                      "(out of %i)" % \
                      (len(test_embeddings["words"]), len(new_test_words)))

    if options.char_emb_size > 0:
        new_test_chars = set(test_chars) - self.feature_extractor.chars.keys()
        print("Number of OOV char types at test time: %i (out of %i)" %
              (len(new_test_chars), len(test_chars)))
        if len(new_test_chars) > 0:
            for lang in test_langs:
                embeddings = utils.get_external_embeddings(
                    options,
                    emb_file=options.ext_char_emb_file,
                    lang=lang,
                    words=new_test_chars,
                    chars=True)
                test_embeddings["chars"].update(embeddings)
            if len(test_langs) > 1 and test_embeddings["chars"]:
                print("External embeddings found for %i chars "\
                      "(out of %i)" % \
                      (len(test_embeddings["chars"]), len(new_test_chars)))

    data = utils.read_conll_dir(treebanks, datasplit, char_map=char_map)
    for iSentence, osentence in enumerate(data, 1):
        # work on a copy; predictions are copied back to osentence below
        sentence = deepcopy(osentence)
        self.feature_extractor.Init(options)
        conll_sentence = [
            entry for entry in sentence
            if isinstance(entry, utils.ConllEntry)
        ]
        self.feature_extractor.getWordEmbeddings(conll_sentence, False,
                                                 options, test_embeddings)
        scores, exprs = self.__evaluate(conll_sentence, True)
        if self.proj:
            heads = decoder.parse_proj(scores)
            # LATTICE solution to multiple roots
            # see https://github.com/jujbob/multilingual-bist-parser/blob/master/bist-parser/bmstparser/src/mstlstm.py
            ## ADD for handling multi-roots problem
            rootHead = [head for head in heads if head == 0]
            if len(rootHead) != 1:
                print(
                    "it has multi-root, changing it for heading first root for other roots"
                )
                # reattach every extra root to the first root token
                rootHead = [
                    seq for seq, head in enumerate(heads) if head == 0
                ]
                for seq in rootHead[1:]:
                    heads[seq] = rootHead[0]
            ## finish to multi-roots
        else:
            # non-projective decoding; enforces exactly one root
            heads = chuliu_edmonds_one_root(scores.T)

        for entry, head in zip(conll_sentence, heads):
            entry.pred_parent_id = head
            entry.pred_relation = '_'

        if self.labelsFlag:
            # second pass: predict a relation label for each (head, modifier) arc
            for modifier, head in enumerate(heads[1:]):
                scores, exprs = self.__evaluateLabel(
                    conll_sentence, head, modifier + 1)
                conll_sentence[
                    modifier + 1].pred_relation = self.feature_extractor.irels[max(
                        enumerate(scores), key=itemgetter(1))[0]]

        dy.renew_cg()

        # keep in memory the information we need, not all the vectors
        oconll_sentence = [
            entry for entry in osentence
            if isinstance(entry, utils.ConllEntry)
        ]
        for tok_o, tok in zip(oconll_sentence, conll_sentence):
            tok_o.pred_relation = tok.pred_relation
            tok_o.pred_parent_id = tok.pred_parent_id
        yield osentence