def evaluate(args):
    """
    Run a saved MWT expander over the evaluation file and optionally score it.

    Reads these keys from ``args``: ``output_file``, ``gold_file``,
    ``save_dir``, ``save_name``, ``shorthand``, ``cuda``, ``cpu``,
    ``batch_size`` and ``eval_file``.  Predictions are written in CoNLL
    format to ``args['output_file']``; when ``args['gold_file']`` is not
    None the output is scored against it and the score is logged.
    """
    # file paths
    system_pred_file = args['output_file']
    gold_file = args['gold_file']
    if args['save_name'] is not None:
        model_file = args['save_dir'] + '/' + args['save_name']
    else:
        model_file = '{}/{}_mwt_expander.pt'.format(args['save_dir'], args['shorthand'])

    # load model
    trainer = Trainer(model_file=model_file, use_cuda=args['cuda'] and not args['cpu'])
    loaded_args = trainer.args
    vocab = trainer.vocab

    # paths and the shorthand come from the current invocation, not from
    # whatever was stored alongside the saved model
    for key in args:
        if key.endswith(('_dir', '_file')) or key == 'shorthand':
            loaded_args[key] = args[key]
    max_dec_len_msg = 'max_dec_len: %d' % loaded_args['max_dec_len']
    logger.debug(max_dec_len_msg)

    # load data
    loading_msg = "Loading data with batch size {}...".format(args['batch_size'])
    logger.debug(loading_msg)
    eval_doc = Document(CoNLL.conll2dict(input_file=args['eval_file']))
    batch = DataLoader(eval_doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)

    if len(batch) == 0:
        # skip eval if dev data does not exist
        preds = []
    else:
        # the dictionary predictions are computed up front, even when the
        # seq2seq branch below ends up being the one that is used
        dict_preds = trainer.predict_dict(batch.doc.get_mwt_expansions(evaluation=True))
        # decide trainer type and run eval
        if loaded_args['dict_only']:
            preds = dict_preds
        else:
            logger.info("Running the seq2seq model...")
            preds = []
            for minibatch in batch:
                preds.extend(trainer.predict(minibatch))

            if loaded_args.get('ensemble_dict', False):
                expansions = batch.doc.get_mwt_expansions(evaluation=True)
                preds = trainer.ensemble(expansions, preds)

    # write to file and score; the predictions are applied to a copy so the
    # document held by the data loader is left untouched
    output_doc = copy.deepcopy(batch.doc)
    output_doc.set_mwt_expansions(preds)
    CoNLL.dict2conll(output_doc.to_dict(), system_pred_file)

    if gold_file is None:
        return

    _, _, score = scorer.score(system_pred_file, gold_file)
    logger.info("MWT expansion score: {} {:.2f}".format(args['shorthand'], score * 100))
def test_dict_to_doc_and_doc_to_dict():
    """
    Check that DICT survives a trip through Document and back.

    The dicts produced by Document.to_dict() have every non-tuple 'id'
    wrapped in a 1-tuple before the result is compared with DICT.
    """
    def with_tuple_id(token):
        # rewrite the id in place so that it is always a tuple
        token_id = token['id']
        token['id'] = token_id if isinstance(token_id, tuple) else (token_id, )
        return token

    round_tripped = Document(DICT).to_dict()
    normalized = [[with_tuple_id(token) for token in sentence]
                  for sentence in round_tripped]
    assert normalized == DICT
def test_dict_to_doc_and_doc_to_dict():
    """
    Test the conversion from raw dict to Document and back

    This code path will first turn start_char|end_char into start_char & end_char fields in the Document
    That version to a dict will have separate fields for each of those
    Finally, the conversion from that dict to a list of conll entries should convert that back to misc
    """
    doc = Document(DICT)
    dicts = doc.to_dict()

    # wrap every non-tuple id in a 1-tuple before handing the dicts on
    dicts_tupleid = []
    for sentence in dicts:
        items = []
        for item in sentence:
            if not isinstance(item['id'], tuple):
                item['id'] = (item['id'], )
            items.append(item)
        dicts_tupleid.append(items)

    # convert the round-tripped dicts, not the raw DICT fixture: passing DICT
    # here would bypass the Document entirely and leave the loop above unused
    conll = CoNLL.convert_dict(dicts_tupleid)
    assert conll == CONLL