示例#1
0
def evaluate(args):
    """Run a saved MWT expander on the evaluation file and write its output.

    The model is loaded from ``args['save_dir']``; the file name is
    ``args['save_name']`` when given, otherwise it is derived from
    ``args['shorthand']``.  Predictions for ``args['eval_file']`` are written
    in CoNLL format to ``args['output_file']``.  When ``args['gold_file']`` is
    not None the written file is scored against it and the score is logged.

    Args:
        args: mapping of run-time options.  Keys read here: ``output_file``,
            ``gold_file``, ``save_dir``, ``save_name``, ``shorthand``,
            ``cuda``, ``cpu``, ``batch_size`` and ``eval_file``.

    Returns:
        None.
    """
    # Resolve all file paths first.
    system_pred_file = args['output_file']
    gold_file = args['gold_file']
    if args['save_name'] is not None:
        model_file = args['save_dir'] + '/' + args['save_name']
    else:
        model_file = '{}/{}_mwt_expander.pt'.format(args['save_dir'], args['shorthand'])

    # Load the model together with the options and vocab stored alongside it.
    trainer = Trainer(model_file=model_file,
                      use_cuda=args['cuda'] and not args['cpu'])
    loaded_args = trainer.args
    vocab = trainer.vocab

    # Paths and the shorthand always come from the current invocation rather
    # than from the options saved with the model.
    for key in args:
        if key.endswith(('_dir', '_file')) or key == 'shorthand':
            loaded_args[key] = args[key]
    logger.debug('max_dec_len: %d' % loaded_args['max_dec_len'])

    # Load the evaluation data.
    batch_size = args['batch_size']
    logger.debug("Loading data with batch size {}...".format(batch_size))
    doc = Document(CoNLL.conll2dict(input_file=args['eval_file']))
    batch = DataLoader(doc, batch_size, loaded_args, vocab=vocab, evaluation=True)

    # Predictions stay empty when there is no data to evaluate on.
    preds = []
    if len(batch) > 0:
        dict_preds = trainer.predict_dict(
            batch.doc.get_mwt_expansions(evaluation=True))
        if loaded_args['dict_only']:
            # Dictionary-only trainer: the lookup result is the final answer.
            preds = dict_preds
        else:
            logger.info("Running the seq2seq model...")
            for minibatch in batch:
                preds.extend(trainer.predict(minibatch))

            if loaded_args.get('ensemble_dict', False):
                preds = trainer.ensemble(
                    batch.doc.get_mwt_expansions(evaluation=True), preds)

    # Write the predictions into a copy so the loader's document is untouched.
    doc = copy.deepcopy(batch.doc)
    doc.set_mwt_expansions(preds)
    CoNLL.dict2conll(doc.to_dict(), system_pred_file)

    if gold_file is None:
        return

    _, _, score = scorer.score(system_pred_file, gold_file)
    logger.info("MWT expansion score: {} {:.2f}".format(
        args['shorthand'], score * 100))
def test_dict_to_doc_and_doc_to_dict():
    """Check that DICT survives a round trip through Document unchanged.

    ``Document(DICT).to_dict()`` may return token ids that are not tuples, so
    every such id is wrapped in a 1-tuple before comparing against DICT.
    """
    # NOTE(review): a second function with this exact name is defined further
    # down in this file and rebinds the name, so this version is shadowed.

    def _with_tuple_id(token):
        # Normalise the id in place and hand the same token object back.
        if not isinstance(token['id'], tuple):
            token['id'] = (token['id'], )
        return token

    round_tripped = [
        [_with_tuple_id(token) for token in sentence]
        for sentence in Document(DICT).to_dict()
    ]
    assert round_tripped == DICT
def test_dict_to_doc_and_doc_to_dict():
    """
    Test the conversion from raw dict to Document and back

    This code path will first turn start_char|end_char into start_char &
    end_char fields in the Document.  Converting that Document to a dict
    gives separate fields for each of those.  Finally, the conversion from
    that dict to a list of conll entries should fold them back into misc.
    """

    def _with_tuple_id(token):
        # Wrap a non-tuple id in a 1-tuple, in place, and return the token.
        if not isinstance(token['id'], tuple):
            token['id'] = (token['id'], )
        return token

    doc = Document(DICT)
    round_tripped = [
        [_with_tuple_id(token) for token in sentence]
        for sentence in doc.to_dict()
    ]

    # Direct conversion of the raw fixture: the check this test always made.
    assert CoNLL.convert_dict(DICT) == CONLL

    # The round-tripped dicts used to be built and then thrown away, so the
    # Document -> dict -> CoNLL path described above was never exercised.
    assert CoNLL.convert_dict(round_tripped) == CONLL