Example #1
def eval_model(opt, skip_valid=False, skip_test=False):
    """
    Run through an evaluation loop.

    :param opt:
        Any non-default options you wish to set.
    :param bool skip_valid:
        If True, skip the valid evaluation; the second return value will be None.
    :param bool skip_test:
        If True, skip the test evaluation; the third return value will be None.

    :return: (stdout, valid_results, test_results)
    :rtype: (str, dict, dict)

    If model_file is not in opt, this helper will create a temporary directory
    to store the model files and clean up afterwards. You can keep the directory
    by disabling autocleanup.
    """
    import parlai.scripts.eval_model as ems

    parser = ems.setup_args()
    parser.set_params(**opt)
    parser.set_params(log_every_n_secs=10)
    popt = parser.parse_args(print_args=False)

    if popt.get('model_file') and not popt.get('dict_file'):
        popt['dict_file'] = popt['model_file'] + '.dict'

    with capture_output() as output:
        popt['datatype'] = 'valid'
        valid = None if skip_valid else ems.eval_model(popt)
        popt['datatype'] = 'test'
        test = None if skip_test else ems.eval_model(popt)

    return (output.getvalue(), valid, test)
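
A minimal usage sketch for the helper above; the task and model names ('integration_tests', 'repeat_label') are illustrative assumptions, not taken from the snippet.

# Hypothetical usage of the eval_model helper defined above.
opt = {'task': 'integration_tests', 'model': 'repeat_label', 'batchsize': 1}
stdout, valid, test = eval_model(opt, skip_test=True)
assert test is None  # the test split was skipped
print(valid)  # dict of validation metrics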
Example #2
def eval_model(opt):
    """
    Runs through an evaluation loop.

    :return: (stdout, valid_results, test_results)
    :rtype: (str, dict, dict)

    If model_file is not in opt, this helper will create a temporary directory
    to store the model files and clean up afterwards. You can keep the directory
    by disabling autocleanup.
    """

    import parlai.scripts.eval_model as ems
    parser = ems.setup_args()
    parser.set_params(**opt)
    popt = parser.parse_args(print_args=False)

    if popt.get('model_file') and not popt.get('dict_file'):
        popt['dict_file'] = popt['model_file'] + '.dict'

    with capture_output() as output:
        popt['datatype'] = 'valid'
        valid = ems.eval_model(popt)
        popt['datatype'] = 'test'
        test = ems.eval_model(popt)

    return (
        output.getvalue(),
        valid,
        test,
    )
Example #3
    def test_hogwild_eval(self):
        """Test eval with numthreads > 1 and batchsize in [1,2,3]."""
        parser = setup_args()
        NUM_EXS = 500
        parser.set_defaults(
            task='tasks.repeat:RepeatTeacher:{}'.format(NUM_EXS),
            model='repeat_label',
            datatype='valid',
            num_examples=-1,
            display_examples=False,
        )

        old_out = sys.stdout
        output = display_output()
        try:
            sys.stdout = output
            for nt in [2, 5, 10]:
                parser.set_defaults(numthreads=nt)
                for bs in [1, 2, 3]:
                    parser.set_defaults(batchsize=bs)
                    parser.set_defaults(batch_sort=(bs % 2 == 0))
                    report = eval_model(parser, printargs=False)
                    self.assertEqual(report['total'], NUM_EXS)
        finally:
            # restore sys.stdout
            sys.stdout = old_out
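
The manual sys.stdout swap above predates contextlib.redirect_stdout; a sketch of the same capture using only the standard library (assuming the same parser setup as in the test):

import io
from contextlib import redirect_stdout

# Equivalent stdout capture without swapping sys.stdout by hand.
buf = io.StringIO()
with redirect_stdout(buf):
    report = eval_model(parser, printargs=False)
captured = buf.getvalue()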
Example #4
def main():
    parser = eval_model.setup_args()
    parser.add_distributed_training_args()
    parser.add_argument('--port', type=int, default=61337, help='TCP port number')
    opt = parser.parse_args(print_args=(os.environ['SLURM_PROCID'] == '0'))

    with distributed_utils.slurm_distributed_context(opt) as opt:
        return eval_model.eval_model(opt)
Example #5
def multiprocess_eval(
    rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
):
    """
    Run a multiprocessing evaluation.

    Invoked by launch_and_eval, not instantiated directly.
    """
    with distributed_utils.distributed_context(
        rank, opt, port, rank_offset, gpu, hostname
    ) as opt:
        return eval_model.eval_model(opt)
Example #6
def multiprocess_eval(
    rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
):
    """
    Run a multiprocessing evaluation.

    Invoked by launch_and_eval, not instantiated directly.
    """
    init_method = f'tcp://{hostname}:{port}'
    with distributed_utils.distributed_context(
        rank, opt, rank_offset, gpu, init_method=init_method
    ) as opt:
        opt['multiprocessing'] = True
        return eval_model.eval_model(opt)
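
For context, a minimal sketch of how a launcher such as launch_and_eval might drive multiprocess_eval via torch.multiprocessing; the world-size handling below is an assumption, not ParlAI's actual implementation.

import torch.multiprocessing as mp

def launch_and_eval(opt, port=61337):
    # Hypothetical launcher: spawn one worker per rank, each of which
    # calls multiprocess_eval(rank, opt, port). mp.spawn passes the rank
    # as the first argument; worker return values are discarded.
    nprocs = opt.get('distributed_world_size', 1)
    mp.spawn(multiprocess_eval, args=(opt, port), nprocs=nprocs, join=True)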
Example #7
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Evaluate a pre-trained retriever-reader model on open squad.
"""
from parlai.scripts.eval_model import setup_args, eval_model

if __name__ == '__main__':
    parser = setup_args()
    parser.set_params(
        task='squad:opensquad',
        model='retriever_reader',
        retriever_model_file='models:wikipedia_full/tfidf_retriever/model',
        reader_model_file='models:drqa/squad/model',
    )
    opt = parser.parse_args(print_args=False)
    eval_model(opt, print_parser=parser)
Example #8
def eval_hits(opt, print_parser):
    report = eval_model(opt, print_parser)
    print("============================")
    print("FINAL Hits@1: " + str(report['hits@1']))
Example #9
'''Evaluate a pre-trained generative model trained for the hits@1 metric.

The model was trained on personachat using persona 'self'.
Run from the ParlAI directory.
'''
from parlai.core.build_data import download_models
from parlai.core.params import ParlaiParser
from parlai.scripts.eval_model import eval_model
from projects.personachat.persona_seq2seq import PersonachatSeqseqAgentBasic

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.add_argument('-n', '--num-examples', default=100000000)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    PersonachatSeqseqAgentBasic.add_cmdline_args(parser)
    parser.set_defaults(
        dict_file='models:personachat/seq2seq_personachat/fulldict.dict',
        rank_candidates=True,
        task='personachat:self',
        model='projects.personachat.persona_seq2seq:PersonachatSeqseqAgentBasic',
        model_file='models:personachat/seq2seq_personachat/seq2seq_no_dropout0.2_lstm_1024_1e-3',
        datatype='test',
    )

    opt = parser.parse_args()
    opt['model_type'] = 'seq2seq_personachat'  # for builder
    # build all profile memory models
    fnames = ['seq2seq_no_dropout0.2_lstm_1024_1e-3', 'fulldict.dict']
    download_models(opt, fnames, 'personachat')

    eval_model(parser)
Example #10
def eval_f1(opt, print_parser):
    report = eval_model(opt, print_parser)
    print('============================')
    print('FINAL F1: ' + str(report['f1']))
    return report
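
A hypothetical call to this wrapper, reusing the setup_args pattern from the earlier examples; the task and model_file values are placeholders.

from parlai.scripts.eval_model import setup_args

parser = setup_args()
parser.set_params(
    task='convai2',  # illustrative task name
    model_file='models:placeholder/model',  # placeholder path
)
opt = parser.parse_args(print_args=False)
report = eval_f1(opt, print_parser=parser)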
Example #11
"""
Results on unseen test set (run with flag
`-t wizard_of_wikipedia:WizardDialogKnowledge:topic_split`):
Hits@1/100: 68.96
"""

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.add_argument('-n', '--num-examples', default=100000000)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    WizardTransformerRankerAgent.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(
        task='wizard_of_wikipedia',
        model='projects:wizard_of_wikipedia:wizard_transformer_ranker',
        model_file='models:wizard_of_wikipedia/full_dialogue_retrieval_model/model',
        datatype='test',
        n_heads=6,
        ffn_size=1200,
        embeddings_scale=False,
        delimiter=' __SOC__ ',
        n_positions=1000,
        legacy=True,
    )

    opt = parser.parse_args()
    download(opt['datapath'])  # download pretrained retrieval model

    eval_model(opt)
Example #12
def eval_f1(opt, print_parser):
    report = eval_model(opt, print_parser)
    print("============================")
    print("FINAL F1: " + str(report['f1']))
Example #13
def eval_f1(opt, print_parser):
    report = eval_model(opt, print_parser)
    print('============================')
    print('Final F1@1: {}, BLEU:  {}'.format(report['f1'], report['bleu']))
Example #14
def eval_hits(opt, print_parser):
    report = eval_model(opt, print_parser)
    print('============================')
    print('FINAL Hits@1: ' + str(report['hits@1']))
    return report
Example #15
    popt = parser.parse_args(print_args=False)

    if popt.get('model_file') and not popt.get('dict_file'):
        popt['dict_file'] = popt['model_file'] + '.dict'

    popt['datatype'] = 'valid' if valid_datatype is None else valid_datatype
    valid = None if skip_valid else ems.eval_model(popt)
    popt['datatype'] = 'test'
    test = None if skip_test else ems.eval_model(popt)

    return valid, test


def display_data(opt):
    """
    Run through a display data run.

    :return: (stdout_train, stdout_valid, stdout_test)
    :rtype: (str, str, str)
    """
    import parlai.scripts.display_data as dd
Example #16
    def run(self):
        with distributed_utils.slurm_distributed_context(self.opt) as opt:
            return eval_model.eval_model(opt)