Code example #1
File: eval_model.py | Project: jderiu/ParlAI
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report, either as an '
        'extension to the model-file (if it begins with a ".") or as a '
        'whole file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-world-logs',
        type='bool',
        default=False,
        help='Saves a jsonl file containing all of the task examples and '
        'model replies. Must also specify --report-filename.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='List of metrics to show/compute, e.g. "all", "default", or a '
        'comma-separated list like "ppl,f1,accuracy,hits@1,rouge,bleu". '
        'The rouge metrics are computed as rouge-1, rouge-2, and rouge-L.',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser)
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_params(datatype='valid')
    return parser
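
A minimal usage sketch for the parser built above, assuming a stock ParlAI install; the babi task and the built-in repeat_label agent are placeholders rather than part of the original example:

parser = setup_args()
opt = parser.parse_args(
    ['--task', 'babi:task10k:1', '--model', 'repeat_label', '--metrics', 'ppl,f1']
)
print(opt['report_filename'])  # '' -> no report file is written by default
print(opt['metrics'])          # 'ppl,f1'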
Code example #2
File: utils.py | Project: simplecoka/cortx
def get_context_generator(
        override_opt: Optional[Dict[str, Any]] = None) -> ContextGenerator:
    """
    Return an object to return BlendedSkillTalk-style context info (personas, etc.).
    """
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    if override_opt is not None:
        argparser.set_params(**override_opt)
    opt = argparser.parse_args([])
    context_generator = ContextGenerator(opt, datatype='test', seed=0)
    # We pull from the test set so that the model can't regurgitate
    # memorized conversations
    return context_generator
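
A usage sketch, assuming the get_context() call that the larger MTurk example later in this list makes on the same class; the datapath value is a placeholder:

generator = get_context_generator(override_opt={'datapath': '/tmp/parlai_data'})
context_info = generator.get_context()  # personas, previous utterances, etc.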
Code example #3
File: safety.py | Project: scy6500/openchat
    def _create_safety_model(self, custom_model_file, device):
        from parlai.core.params import ParlaiParser

        parser = ParlaiParser(False, False)
        TransformerClassifierAgent.add_cmdline_args(parser, partial_opt=None)
        parser.set_params(
            model='transformer/classifier',
            model_file=custom_model_file,
            print_scores=True,
            data_parallel=False,
        )
        safety_opt = parser.parse_args([])
        safety_opt["override"]["no_cuda"] = "cuda" not in device
        return create_agent(safety_opt, requireModelExists=True)
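
The override line above simply disables CUDA whenever the device string does not mention it; a standalone sketch of that check:

for device in ('cuda:0', 'cuda', 'cpu'):
    no_cuda = 'cuda' not in device
    print(f'{device} -> no_cuda={no_cuda}')  # only 'cpu' yields no_cuda=True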
Code example #4
File: test_torch_agent.py | Project: rhamnett/ParlAI
def get_agent(**kwargs):
    r"""
    Return opt-initialized agent.

    :param kwargs: any kwargs you want to set using parser.set_params(\*\*kwargs)
    """
    if 'no_cuda' not in kwargs:
        kwargs['no_cuda'] = True
    from parlai.core.params import ParlaiParser

    parser = ParlaiParser()
    MockTorchAgent.add_cmdline_args(parser)
    parser.set_params(**kwargs)
    opt = parser.parse_args([], print_args=False)
    return MockTorchAgent(opt)
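
A usage sketch, assuming MockTorchAgent keeps the parsed options on agent.opt the way TorchAgent does:

agent = get_agent(truncate=10)
assert agent.opt['no_cuda']          # default injected by get_agent
assert agent.opt['truncate'] == 10   # forwarded through parser.set_params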
Code example #5
File: test_torch_agent.py | Project: XinnuoXu/DRank
    def test_maintain_dialog_history(self):
        try:
            from parlai.core.torch_agent import TorchAgent
        except ImportError as e:
            if 'pytorch' in e.msg:
                print(
                    'Skipping TestTorchAgent.test_maintain_dialog_history, no pytorch.'
                )
                return
            raise  # any other import failure is a real error

        from parlai.core.params import ParlaiParser
        parser = ParlaiParser()
        TorchAgent.add_cmdline_args(parser)
        parser.set_params(no_cuda=True, truncate=5)
        opt = parser.parse_args(print_args=False)
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)

        observation = {
            "text": "What is a painting?",
            "labels": ["Paint on a canvas."],
            "episode_done": False
        }

        agent.maintain_dialog_history(observation)

        self.assertTrue('dialog' in agent.history,
                        "Failed initializing self.history.")
        self.assertTrue('episode_done' in agent.history,
                        "Failed initializing self.history.")
        self.assertTrue('labels' in agent.history,
                        "Failed initializing self.history.")
        self.assertTrue(
            list(agent.history['dialog']) == [7, 8, 9],
            "Failed adding vectorized text to dialog.")
        self.assertTrue(not agent.history['episode_done'],
                        "Failed to properly store episode_done field.")
        self.assertTrue(agent.history['labels'] == observation['labels'],
                        "Failed saving labels.")

        observation['text_vec'] = agent.maintain_dialog_history(observation)
        print(agent.history['dialog'])
        self.assertTrue(
            list(agent.history['dialog']) == [8, 9, 7, 8, 9],
            "Failed adding vectorized text to dialog.")
Code example #6
File: worlds.py | Project: convobox/ParlAI
 def _set_up_knowledge_agent(self,
                             add_token_knowledge: bool = False,
                             shared=None) -> None:
     """
     Set up knowledge agent for knowledge retrieval generated from WoW project.
     """
     parser = ParlaiParser(False, False)
     KnowledgeRetrieverAgent.add_cmdline_args(parser)
     parser.set_params(
         model='projects:wizard_of_wikipedia:knowledge_retriever',
         add_token_knowledge=add_token_knowledge,
     )
     knowledge_opt = parser.parse_args([])
     if shared:
         self.knowledge_agent = KnowledgeRetrieverAgent(
             knowledge_opt, shared.get('knowledge_retriever', None))
     else:
         self.knowledge_agent = KnowledgeRetrieverAgent(knowledge_opt)
Code example #7
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True,
                              'compute statistics from model predictions')
    DictionaryAgent.add_cmdline_args(parser)

    # These defaults can be overridden by both the .opt file and the user's command line flags
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-ed',
        '--external-dict',
        type=str,
        default=None,
        help='External dictionary for stat computation',
    )
    parser.add_argument(
        '-fb',
        '--freq-bins',
        type=str,
        default='0,100,1000,10000',
        help='Bins boundaries for rare words stat',
    )
    parser.add_argument(
        '-gr',
        '--gold-response',
        type='bool',
        default=False,
        help='Compute stats for gold response',
    )

    # These settings override .opt file but not user's command line flags
    parser.set_params(
        datatype='valid',
        task='projects.controllable_dialogue.tasks.agents',
        model=
        'projects.controllable_dialogue.controllable_seq2seq.controllable_seq2seq:ControllableSeq2seqAgent',  # noqa: E501
        batchsize=64,
        beam_size=20,
        beam_min_n_best=10,
        use_reply='model',
    )
    TensorboardLogger.add_cmdline_args(parser)
    return parser
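
A sketch of the precedence described by the two comments above, assuming the controllable_dialogue project code is importable: an explicit command-line flag beats set_params, which in turn beats the .opt file:

parser = setup_args()
opt = parser.parse_args(['--datatype', 'test'])
print(opt['datatype'])   # 'test': the user's flag wins over set_params' 'valid'
print(opt['batchsize'])  # 64: taken from set_params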
Code example #8
File: utils.py | Project: sagar-spkt/ParlAI
def get_context_generator(
    override_opt: Optional[Dict[str, Any]] = None,
    task: Optional[str] = 'blended_skill_talk',
    **kwargs,
) -> ContextGenerator:
    """
    Return an object to return BlendedSkillTalk-style context info (personas, etc.).
    """
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    if override_opt is not None:
        argparser.set_params(**override_opt)
    opt = argparser.parse_args([])
    task_module = load_task_module(task)
    context_generator_class = getattr(task_module, 'ContextGenerator', None)
    context_generator = context_generator_class(opt, datatype='test', seed=0, **kwargs)
    # We pull from the test set so that the model can't regurgitate
    # memorized conversations
    return context_generator
Code example #9
    def test_resize_embeddings(self):
        # train original model
        with testing_utils.tempdir() as tmpdir:
            model_file = os.path.join(tmpdir, 'model_file')
            _, _ = testing_utils.train_model(
                dict(
                    model='transformer/generator',
                    task='integration_tests:short_fixed',
                    n_layers=1,
                    n_encoder_layers=2,
                    n_decoder_layers=4,
                    num_epochs=1,
                    dict_tokenizer='bytelevelbpe',
                    bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
                    bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
                    bpe_add_prefix_space=False,
                    model_file=model_file,
                    save_after_valid=True,
                )
            )

            # now create agent with special tokens
            parser = ParlaiParser()
            parser.set_params(
                model='transformer/generator',
                task='integration_tests:short_fixed',
                n_layers=1,
                n_encoder_layers=2,
                n_decoder_layers=4,
                dict_tokenizer='bytelevelbpe',
                bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
                bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
                bpe_add_prefix_space=False,
                model_file=model_file,
                save_after_valid=True,
                special_tok_lst='PARTY,PARROT',
            )
            opt = parser.parse_args([])
            agent = create_agent(opt)
            # assert that the embeddings were resized
            assert agent.resized_embeddings
            # assert model has special tokens
            self.assertEqual(agent.special_toks, ['PARTY', 'PARROT'])
Code example #10
    def get_bot_agents(args: DictConfig,
                       model_opts: Dict[str, str],
                       no_cuda=False) -> Dict[str, dict]:
        """
        Return shared bot agents.

        Pass in model opts with the `model_opts` arg, where `model_opts` is a dictionary
        whose keys are model names and whose values are strings that specify model
        params (e.g. `--model image_seq2seq`).
        """

        # Set up overrides
        model_overrides = {
            'model_parallel': args.blueprint.task_model_parallel
        }
        if no_cuda:
            # If we load many models at once, we have to keep it on CPU
            model_overrides['no_cuda'] = no_cuda
        else:
            logging.warning(
                'WARNING: MTurk task has no_cuda FALSE. Models will run on GPU. Will '
                'not work if loading many models at once.')

        # Convert opt strings to Opt objects
        parser = ParlaiParser(True, True)
        parser.set_params(**model_overrides)
        processed_opts = {}
        for name, opt_string in model_opts.items():
            processed_opts[name] = parser.parse_args(opt_string.split())

        # Load and share all model agents
        logging.info(
            f'Got {len(processed_opts)} models: {list(processed_opts.keys())}.'
        )
        shared_bot_agents = {}
        for model_name, model_opt in processed_opts.items():
            logging.info('\n\n--------------------------------')
            logging.info(f'model_name: {model_name}, opt_dict: {model_opt}')
            model_agent = create_agent(model_opt, requireModelExists=True)
            shared_bot_agents[model_name] = model_agent.share()
        return shared_bot_agents
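
A usage sketch with a hypothetical one-model config; OmegaConf.create stands in for the Hydra-built args, the zoo path is only an example model, and the method is assumed to be exposed on TurkLikeAgent, as it is called elsewhere in this list:

from omegaconf import OmegaConf

args = OmegaConf.create({'blueprint': {'task_model_parallel': False}})
model_opts = {'blender_90M': '--model-file zoo:blender/blender_90M/model'}
shared_bot_agents = TurkLikeAgent.get_bot_agents(args, model_opts, no_cuda=True)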
Code example #11
 def add_cmdline_args(cls,
                      parser: ParlaiParser,
                      partial_opt: Optional[Opt] = None) -> ParlaiParser:
     super().add_cmdline_args(parser, partial_opt)
     group = parser.add_argument_group(
         'Gender Multiclass Interactive World')
     group.add_argument(
         '--self-threshold',
         type=float,
         default=0.52,
         help='Threshold for choosing unknown for self',
     )
     group.add_argument(
         '--partner-threshold',
         type=float,
         default=0.52,
         help='Threshold for choosing unknown for partner',
     )
     parser.set_params(
         single_turn=True,  # this is a single turn task currently
         eval_candidates='inline',
         return_cand_scores=True,
     )
     return parser
Code example #12
File: test_wizard.py | Project: skybirdhe/ParlAI
    def test_knowledge_retriever(self):
        from parlai.core.params import ParlaiParser

        parser = ParlaiParser(False, False)
        KnowledgeRetrieverAgent.add_cmdline_args(parser)
        parser.set_params(
            model='projects:wizard_of_wikipedia:knowledge_retriever',
            add_token_knowledge=True,
        )
        knowledge_opt = parser.parse_args([], print_args=False)
        knowledge_agent = create_agent(knowledge_opt)

        knowledge_agent.observe(
            {
                'text': 'what do you think of mountain dew?',
                'chosen_topic': 'Mountain Dew',
                'episode_done': False,
            }
        )

        knowledge_act = knowledge_agent.act()

        title = knowledge_act['title']
        self.assertEqual(title, 'Mountain Dew', 'Did not save chosen topic correctly')

        knowledge = knowledge_act['text']
        self.assertIn(
            TOKEN_KNOWLEDGE, knowledge, 'Knowledge token was not inserted correctly'
        )

        checked_sentence = knowledge_act['checked_sentence']
        self.assertEqual(
            checked_sentence,
            'Mountain Dew (stylized as Mtn Dew) is a carbonated soft drink brand produced and owned by PepsiCo.',
            'Did not correctly choose the checked sentence',
        )
Code example #13
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Interact with a pre-trained model.
This seq2seq model was trained on convai2:self.
"""

from parlai.core.build_data import download_models
from parlai.core.params import ParlaiParser
from parlai.scripts.interactive import interactive

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.set_params(
        model='legacy:seq2seq:0',
        model_file='models:convai2/seq2seq/convai2_self_seq2seq_model',
        dict_file='models:convai2/seq2seq/convai2_self_seq2seq_model.dict',
        dict_lower=True,
        batchsize=1,
    )
    opt = parser.parse_args()
    if 'convai2/seq2seq/convai2_self_seq2seq_model' in opt.get('model_file', ''):
        opt['model_type'] = 'seq2seq'
        fnames = ['convai2_self_seq2seq_model.tgz',
                  'convai2_self_seq2seq_model.dict',
                  'convai2_self_seq2seq_model.opt']
        download_models(opt, fnames, 'convai2', version='v3.0')
    interactive(opt)
Code example #14
Results on unseen test set (run with flag
`-t wizard_of_wikipedia:WizardDialogKnowledge:topic_split`):
Hits@1/100: 68.96
"""

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.add_argument('-n', '--num-examples', default=100000000)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    WizardTransformerRankerAgent.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(
        task='wizard_of_wikipedia',
        model='projects:wizard_of_wikipedia:wizard_transformer_ranker',
        model_file=
        'models:wizard_of_wikipedia/full_dialogue_retrieval_model/model',
        datatype='test',
        n_heads=6,
        ffn_size=1200,
        embeddings_scale=False,
        delimiter=' __SOC__ ',
        n_positions=1000,
        legacy=True,
    )

    opt = parser.parse_args()
    download(opt['datapath'])  # download pretrained retrieval model

    eval_model(opt)
Code example #15
File: eval_model.py | Project: skywalker023/ParlAI
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report, either as an '
        'extension to the model-file (if it begins with a ".") or as a '
        'whole file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument(
        '--area-under-curve-digits',
        '-auc',
        type=int,
        default=-1,
        help='A positive number enables calculating the area under the ROC '
        'curve and sets how many decimal digits of the predictions to keep '
        '(higher numbers are more precise); a non-positive value disables '
        'the AUC metric.',
    )
    parser.add_argument(
        '--area-under-curve-class',
        '-auclass',
        type=str,
        default=None,
        nargs='*',
        help='the name(s) of the class to calculate the auc for',
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='List of metrics to show/compute, e.g. "all", "default", or a '
        'comma-separated list like "ppl,f1,accuracy,hits@1,rouge,bleu". '
        'The rouge metrics are computed as rouge-1, rouge-2, and rouge-L.',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(datatype='valid')
    return parser
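
A sketch exercising the newer --world-logs flag defined above; the task, model, and output path are placeholders:

parser = setup_args()
opt = parser.parse_args([
    '--task', 'babi:task10k:1',
    '--model', 'repeat_label',
    '--world-logs', '/tmp/eval_logs.jsonl',
    '--save-format', 'conversations',
])
print(opt['world_logs'])  # '/tmp/eval_logs.jsonl'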
Code example #16
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser()
    parser.set_params(
        model='parlai.agents.local_human.local_human:LocalHumanAgent')
    return parser
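
set_params here only installs a default, so the value still shows up in the parsed opt even though no --model argument was explicitly added; a quick check, assuming a stock ParlaiParser:

parser = setup_args()
opt = parser.parse_args([])
print(opt['model'])  # parlai.agents.local_human.local_human:LocalHumanAgent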
Code example #17
 def add_cmdline_args(cls,
                      parser: ParlaiParser,
                      partial_opt=None) -> ParlaiParser:
     super().add_cmdline_args(parser, partial_opt)
     parser.set_params(prepend_gold_knowledge=True)
     return parser
Code example #18
from parlai.core.build_data import download_models
from parlai.core.params import ParlaiParser
from parlai.scripts.interactive import interactive
from parlai.agents.language_model.language_model import LanguageModelAgent
'''Interact with a pre-trained model.
Language model trained on the OpenSubtitles 2018 dataset.
Run from the ParlAI directory.
'''

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    LanguageModelAgent.add_cmdline_args(parser)
    parser.set_params(
        dict_file=
        'models:personachat/language_model/languagemodel_esz512_hid1024_nl2.pt.dict',
        sampling_mode=True,
        task='parlai.agents.local_human.local_human:LocalHumanAgent',
        model='language_model',
        model_file=
        'models:personachat/language_model/languagemodel_esz512_hid1024_nl2.pt'
    )

    opt = parser.parse_args()
    opt['model_type'] = 'language_model'  # for builder
    # build all profile memory models
    fnames = [
        'languagemodel_esz512_hid1024_nl2.pt',
        'languagemodel_esz512_hid1024_nl2.pt.opt',
        'languagemodel_esz512_hid1024_nl2.pt.dict'
    ]
    download_models(opt, fnames, 'personachat', version='v3.0')
    interactive(opt)
Code example #19
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Interact with a pre-trained model. Key-Value Memory Net model trained on personachat
using persona 'self'.

[Note: no persona in this example code is actually given to the model.]
"""

from parlai.core.build_data import download_models
from parlai.core.params import ParlaiParser
from parlai.scripts.interactive import interactive

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.set_params(
        model='projects.personachat.kvmemnn.kvmemnn:KvmemnnAgent',
        model_file='models:convai2/kvmemnn/model',
        interactive_mode=True,
    )
    opt = parser.parse_args()
    # build all profile memory models
    fnames = ['kvmemnn.tgz']
    opt['model_type'] = 'kvmemnn'  # for builder
    download_models(opt, fnames, 'convai2')
    interactive(opt)
Code example #20
File: kvmemnn_interactive.py | Project: xlrshop/Parl
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from parlai.core.build_data import download_models
from parlai.core.params import ParlaiParser
from parlai.scripts.interactive import interactive
'''Interact with a pre-trained model.
Key-Value Memory Net model trained on personachat using persona 'self'.
[Note: no persona in this example code is actually given to the model.]
'''

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.set_params(
        task='parlai.agents.local_human.local_human:LocalHumanAgent',
        model='projects.personachat.kvmemnn.kvmemnn:KvmemnnAgent',
        model_file=
        'models:personachat/kvmemnn/kvmemnn/persona-self_rephraseTrn-True_rephraseTst-False_lr-0.1_esz-500_margin-0.1_tfidf-False_shareEmb-True_hops1_lins0_model',
        interactive_mode=True,
    )
    opt = parser.parse_args()
    # build all profile memory models
    fnames = ['kvmemnn.tgz']
    opt['model_type'] = 'kvmemnn'  # for builder
    download_models(opt, fnames, 'personachat')
    interactive(opt)
Code example #21
 def add_cmdline_args(cls,
                      parser: ParlaiParser,
                      partial_opt: Optional[Opt] = None) -> ParlaiParser:
     """
     Add command line arguments.
     """
     Transformer.add_common_cmdline_args(parser)
     agent = parser.add_argument_group('TransresnetModel arguments')
     agent.add_argument(
         '--truncate',
         type=int,
         default=32,
          help='Maximum number of tokens allowed in a text sequence',
     )
     agent.add_argument(
         '--image-features-dim',
         type=int,
         default=2048,
         help='dimensionality of image features',
     )
     agent.add_argument(
         '--embedding-type',
         type=str,
         default=None,
         choices=[None, 'fasttext_cc'],
         help='Specify if using pretrained embeddings',
     )
     agent.add_argument(
         '--load-encoder-from',
         type=str,
         default=None,
         help='Specify if using a pretrained transformer encoder',
     )
     agent.add_argument(
         '--hidden-dim',
         type=int,
         default=300,
          help='Hidden dimensionality of the personality and image encoders',
     )
     agent.add_argument(
         '--num-layers-all',
         type=int,
         default=-1,
         help='If >= 1, number of layers for both the text '
         'and image encoders.',
     )
     agent.add_argument(
         '--num-layers-text-encoder',
         type=int,
         default=1,
         help='Number of layers for the text encoder',
     )
     agent.add_argument(
         '--num-layers-image-encoder',
         type=int,
         default=1,
         help='Number of layers for the image encoder',
     )
     agent.add_argument('--no-cuda',
                        dest='no_cuda',
                        action='store_true',
                        help='If True, perform ops on CPU only')
     agent.add_argument(
         '--learningrate',
         type=float,
         default=0.0005,
         help='learning rate for optimizer',
     )
     agent.add_argument(
         '--additional-layer-dropout',
         type=float,
         default=0.2,
         help='dropout for additional linear layer',
     )
     parser.set_params(ffn_size=1200,
                       attention_dropout=0.2,
                       relu_dropout=0.2,
                       n_positions=1000)
     return parser
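
A sketch wiring the argument group above into a bare parser and reading back one added argument plus one set_params default; TransresnetModel is an assumed name for the class that defines this add_cmdline_args:

parser = ParlaiParser(False, False)
TransresnetModel.add_cmdline_args(parser, partial_opt=None)
opt = parser.parse_args([])
print(opt['hidden_dim'], opt['ffn_size'])  # 300 1200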
Code example #22
File: run.py | Project: advi1012/ParlAITest
def run_task(override_opt: Optional[dict] = None):
    """
    This task consists of an MTurk worker talking to a model; the worker also
    evaluates each utterance of the bot for various buckets (see constants).
    """

    config_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                 'task_config')
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    default_task_folder = os.path.join(argparser.parlai_home, 'data',
                                       'turn_annotations')
    argparser.add_mturk_args()
    argparser.add_argument('-num_t',
                           '--num_turns',
                           default=6,
                           type=int,
                           help='minimum number of turns')
    argparser.add_argument(
        '--conversations-needed',
        dest='conversations_needed_string',
        default=None,
        type=str,
        help=
        'Number of convos needed for each model. For example: "modelA:50,modelB:20"',
    )
    argparser.add_argument(
        '--task-model-parallel',
        default=True,
        type='bool',
        help='Whether to load models to be used with model_parallel True.',
    )
    argparser.add_argument(
        '--auto-approve-delay',
        dest='auto_approve_delay',
        type=int,
        default=3600 * 24 * 5,
        help='how long to wait for auto approval',
    )
    argparser.add_argument(
        '--max-resp-time',
        type=int,
        default=180,
        help='time limit for entering a dialog message',
    )
    argparser.add_argument(
        '--max-onboard-time',
        type=int,
        default=300,
        help='time limit for accepting onboarding',
    )
    argparser.add_argument(
        '--base-save-folder',
        default=default_task_folder,
        type=str,
        help='base folder for saving all crowdsourcing results',
    )
    argparser.add_argument(
        '--base-model-folder',
        default=None,
        type=str,
        help='base folder for loading model files from',
    )
    argparser.add_argument(
        '--onboard-worker-answer-folder',
        default=os.path.join(default_task_folder, 'onboard_answers'),
        type=str,
        help=
        'base folder for saving all worker answer results during onboarding',
    )
    argparser.add_argument(
        '--worker-blocklist-paths',
        default=None,
        type=str,
        help=
        'Path(s) to a list of IDs of workers to soft-block, separated by newlines. Use commas to indicate multiple lists',
    )
    argparser.add_argument(
        '--check-acceptability',
        default=False,
        type='bool',
        help=
        "Check worker's responses against several metrics of acceptability",
    )
    argparser.add_argument('--include-persona',
                           default=False,
                           type='bool',
                           help="Show persona to the bot")
    argparser.add_argument(
        '--conversation-start-mode',
        default='hi',
        type=str,
        choices=['hi', 'bst'],
        help=
        'Whether to show "Hi!" or two previous utterances (as in BlendedSkillTalk) at the beginning of the conversation',
    )
    argparser.add_argument(
        '--context-seed',
        default=None,
        type=int,
        help="Set seed for pulling the context info (for testing)",
    )
    argparser.add_argument(
        '--hit-config-path',
        default=os.path.join(config_folder, 'hit_config.json'),
        type=str,
        help=
        'Path to file of parameters describing how MTurk will describe the HIT to the workers',
    )
    argparser.add_argument(
        '--task-description-path',
        default=os.path.join(config_folder, 'task_description.html'),
        type=str,
        help='Path to file of HTML to show on the task-description page',
    )
    argparser.add_argument(
        '--left-pane-text-path',
        default=os.path.join(config_folder, 'left_pane_text.html'),
        type=str,
        help=
        'Path to file of HTML to show on the left-hand pane of the chat window',
    )
    argparser.add_argument(
        '--annotations-intro',
        default=
        'Does this comment from your partner have any of the following attributes? (Check all that apply)',
        type=str,
        help='Text shown to worker before they fill out annotation form',
    )
    argparser.add_argument(
        '--annotations-config-path',
        default=os.path.join(config_folder, 'annotations_config.json'),
        type=str,
        help='Path to JSON of annotation categories',
    )
    argparser.add_argument(
        '--onboard-task-data-path',
        default=os.path.join(config_folder, 'onboard_task_data.json'),
        type=str,
        help='Path to JSON containing settings for running onboarding',
    )
    argparser.add_argument(
        '--final-rating-question',
        default='Please rate your partner on a scale of 1-5.',
        type=str,
        help='Text to show when asking worker to make their final rating',
    )

    # NOTE: you have to set all three of these opts to enforce the MTurk core
    # param max_hits_per_worker.
    #  - Without unique_qual_name, MTurkManager creates a different
    #    qualification for each run (so a worker could do N HITs per run).
    #    Also, the worker has to reach N HITs in at least one run or they
    #    won't be given the qualification.
    #  - allowed_conversations is the max number of concurrent conversations;
    #    it needs to be 1, or the actual max would be N + allowed_conversations.
    #    The worker gets notified via a frontend message that they aren't
    #    eligible (second description screen), UNLESS the frontend overwrites
    #    that functionality.
    # There's also still a race condition where the worker might be able to
    # open one extra task.
    argparser.set_defaults(
        unique_qual_name='turn_annotations_max_submissions',
        max_hits_per_worker=10,
        allowed_conversations=3,
    )

    if override_opt is not None:
        argparser.set_params(**override_opt)
        opt = argparser.parse_args([])
    else:
        opt = argparser.parse_args()
    directory_path = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(directory_path)

    # Set the number of conversations needed
    if opt.get('conversations_needed_string') is not None:
        parts = opt['conversations_needed_string'].split(',')
        conversations_needed = {}
        for part in parts:
            model_name, num_string = part.split(':')
            conversations_needed[model_name] = int(num_string)
        opt['conversations_needed'] = conversations_needed

    # Read in workers to soft-block
    if opt.get('worker_blocklist_paths') is not None:
        blocklist_paths = opt['worker_blocklist_paths'].split(',')
        worker_blocklist = set()
        for path in blocklist_paths:
            with open(path) as f:
                worker_blocklist |= set(f.read().strip().split('\n'))
        opt['worker_blocklist'] = worker_blocklist

    # Read in and define text shown to users
    if opt.get('hit_config') is None:
        with open(opt['hit_config_path']) as f:
            opt['hit_config'] = json.load(f)
        opt.update(opt['hit_config'])
        # Add all of the settings in hit_config into the base opt
    if opt.get('task_description') is None:
        with open(opt['task_description_path']) as f:
            opt['task_description'] = f.readlines()
    if opt.get('left_pane_text') is None:
        with open(opt['left_pane_text_path']) as f:
            opt['left_pane_text'] = f.readlines()
    if opt.get('annotations_config') is None:
        with open(opt['annotations_config_path']) as f:
            opt['annotations_config'] = json.load(f)
    if opt.get('onboard_task_data') is None:
        with open(opt['onboard_task_data_path']) as f:
            opt['onboard_task_data'] = json.load(f)

    # Limits the number of models that can generate at once
    max_concurrent_responses = 1
    semaphore = threading.Semaphore(max_concurrent_responses)

    run_statistics = {model: 0 for model in opt['conversations_needed']}
    onboard_statistics = {}

    save_folder = 'sandbox' if opt['is_sandbox'] else 'live'
    opt['save_folder'] = os.path.join(opt['base_save_folder'], save_folder,
                                      time.strftime("%Y_%m_%d"))
    os.makedirs(opt['save_folder'], exist_ok=True)

    print(
        f'Going to start collecting {opt["num_conversations"]} conversations, max_hits_per_worker: {opt["max_hits_per_worker"]}, reward: {opt["reward"]}, is_sandbox: {opt["is_sandbox"]}.'
    )

    # Create the models before launching the Heroku backend, because model loading takes a while
    models_needed = list(opt['conversations_needed'].keys())
    active_models = [
        m for m in models_needed if opt['conversations_needed'][m] > 0
    ]
    shared_bot_agents = TurkLikeAgent.get_bot_agents(opt, active_models)

    mturk_agent_ids = [AGENT_0]
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=mturk_agent_ids)
    mturk_manager.setup_server(task_directory_path=directory_path)

    if opt['include_persona'] or opt['conversation_start_mode'] == 'bst':
        context_generator = ContextGenerator(opt, datatype='test', seed=0)
        # We pull from the test set so that the model can't regurgitate
        # memorized conversations
    else:
        context_generator = None

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        if not opt['is_sandbox']:
            # Soft-block all chosen workers
            if len(opt['worker_blocklist']) > 0:
                print(
                    f"About to soft-block {len(opt['worker_blocklist'])} workers."
                )
                for w in set(opt['worker_blocklist']):
                    try:
                        print('Soft Blocking {}\n'.format(w))
                        mturk_manager.soft_block_worker(w)
                    except Exception as e:
                        print(f'Did not soft block worker {w}: {e}')
                    time.sleep(0.1)
            else:
                print(
                    'WARNING: We are in live mode, but a list of workers to soft-block '
                    'has not been passed in.')

        def run_onboard(worker):
            world = TurnAnnotationsOnboardWorld(opt, worker)
            status = world.parley()
            if status not in onboard_statistics:
                onboard_statistics[status] = 0
            onboard_statistics[status] += 1
            print(
                f'After onboard world parley. About to shutdown onboard world for {worker.worker_id}, status was: {status}. Total onboard statistics for this run are: {onboard_statistics}.'
            )
            world.shutdown()

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_worker_eligibility(worker):
            return True

        def assign_worker_roles(workers):
            workers[0].id = mturk_agent_ids[0]

        def run_conversation(mturk_manager, opt, workers):
            remaining_counts_needed = [
                (m, c - run_statistics[m])
                for (m, c) in opt['conversations_needed'].items()
            ]
            remaining_counts_needed.sort(reverse=True, key=lambda x: x[1])
            model_name = remaining_counts_needed[0][0]
            print(
                f'Remaining conversation counts needed: {remaining_counts_needed}'
            )

            # Get a bot and add it to the list of "workers"
            print(f'Choosing the "{model_name}" model for the bot.')
            agent = create_agent_from_shared(shared_bot_agents[model_name])
            bot_worker = TurkLikeAgent(
                opt,
                model_name=model_name,
                model_agent=agent,
                num_turns=opt['num_turns'],
                semaphore=semaphore,
            )
            workers_including_bot = workers + [bot_worker]

            assert len(workers_including_bot) == 2

            # Get context: personas, previous utterances, etc.
            if context_generator is not None:
                context_info = context_generator.get_context()
            else:
                context_info = None

            conv_idx = mturk_manager.conversation_index
            world = TurnAnnotationsChatWorld(
                opt=opt,
                agents=workers_including_bot,
                num_turns=opt['num_turns'],
                max_resp_time=opt['max_resp_time'],
                tag='conversation t_{}'.format(conv_idx),
                context_info=context_info,
            )
            while not world.episode_done():
                print('About to parley')
                world.parley()
            model_nickname, worker_is_unacceptable, convo_finished = world.save_data(
            )
            if worker_is_unacceptable:
                print(f'Soft-blocking worker {workers[0].worker_id}')
                mturk_manager.soft_block_worker(workers[0].worker_id)
                time.sleep(0.1)
            if not worker_is_unacceptable and convo_finished:
                run_statistics[model_nickname] += 1

            world.shutdown()
            world.review_work()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )

    except BaseException:
        raise
    finally:
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
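
The --conversations-needed string above is parsed by plain splitting; a standalone sketch of that transformation:

spec = 'modelA:50,modelB:20'
conversations_needed = {
    name: int(count)
    for name, count in (part.split(':') for part in spec.split(','))
}
print(conversations_needed)  # {'modelA': 50, 'modelB': 20}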
Code example #23
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Interact with a pre-trained model.
Key-Value Memory Net model trained on personachat using persona 'self'
[Note: no persona in this example code is actually given to the model.]
"""

from parlai.core.build_data import download_models
from parlai.core.params import ParlaiParser
from parlai.scripts.interactive import interactive

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.set_params(
        model='projects.personachat.kvmemnn.arms:ArmsAgent',
        model_file='/tmp/persona_self_original.checkpoint',
        #model='projects.personachat.kvmemnn.kvmemnn:KvmemnnAgent',
        #model_file='models:convai2/kvmemnn/model',
        interactive_mode=True,
    )
    opt = parser.parse_args()
    # build all profile memory models
    fnames = ['kvmemnn.tgz']
    opt['model_type'] = 'kvmemnn'  # for builder
    download_models(opt, fnames, 'convai2')
    interactive(opt)
Code example #24
File: interactive.py | Project: elnaaz/ParlAI
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Interact with a pre-trained model.
This transformer model was trained on convai2:self.
"""

from parlai.core.build_data import download_models, modelzoo_path
from parlai.core.params import ParlaiParser
from parlai.scripts.interactive import interactive

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.set_params(
        model='transformer',
        model_file='models:convai2/transformer/convai2_self_transformer_model',
        dict_file=
        'models:convai2/transformer/convai2_self_transformer_model.dict',
        dict_lower=True,
    )
    opt = parser.parse_args()
    if opt.get('model_file', '').startswith(
            modelzoo_path(opt.get('datapath'), "models:convai2")):
        opt['model_type'] = 'transformer'
        fnames = [
            'convai2_self_transformer_model.tgz',
            'convai2_self_transformer_model.dict',
            'convai2_self_transformer_model.opt'
        ]
        download_models(opt, fnames, 'convai2', version='v3.0')
    interactive(opt)
Code example #25
    def get_parlai_opt(self) -> Opt:
        """
        Parser for converting fairseq arguments to a ParlAI opt.

        :return opt:
            opt parsed by ParlAI Parser
        """
        # assume encoder/decoder are symmetrical except for the number of layers
        state = self.state
        fairseq_args = state['args'].__dict__

        transformer_common_config = {}

        # 1. Map transformer params
        for each, mapped in TRANSFORMER_PARAMETER_MAPPING.items():
            transformer_common_config[mapped] = fairseq_args[f'encoder_{each}']
        # 2. Map dropout
        for each in TRANSFORMER_DROPOUT:
            transformer_common_config[each] = fairseq_args[each]

        if 'activation_dropout' in fairseq_args:
            transformer_common_config['relu_dropout'] = fairseq_args[
                'activation_dropout']
        else:
            transformer_common_config['relu_dropout'] = fairseq_args[
                'relu_dropout']

        # 3. Map other options
        transformer_common_config.update({
            'model': self.opt['model'],
            # number of layers
            'n_encoder_layers': fairseq_args['encoder_layers'],
            'n_decoder_layers': fairseq_args['decoder_layers'],
            # tokenization args
            'dict_tokenizer': self.opt['tokenizer'],
            'bpe_vocab': self.opt['vocab'],
            'bpe_merge': self.opt['merge'],
            'n_positions': fairseq_args['max_source_positions'],
        })

        # 4. Embedding scale
        if 'encoder_embed_scale' in fairseq_args:
            transformer_common_config['embeddings_scale'] = (
                fairseq_args['encoder_embed_scale'] != 1.0)
        else:
            transformer_common_config[
                'embeddings_scale'] = not fairseq_args['no_scale_embedding']

        # 5. Determine variant
        if fairseq_args['encoder_normalize_before']:
            transformer_common_config['variant'] = 'prelayernorm'
        elif fairseq_args['layernorm_embedding']:
            transformer_common_config['variant'] = 'bart'
        else:
            transformer_common_config['variant'] = 'aiayn'

        if self.opt['add_prefix_space']:
            transformer_common_config['bpe_add_prefix_space'] = True
        parser = ParlaiParser()
        parser.set_params(**transformer_common_config)
        opt = parser.parse_args([])

        # 6. Augment opt with additional ParlAI options
        opt['fp16'] = self.opt['fp16']
        opt['activation'] = self.opt['activation']
        opt['delimiter'] = self.opt['delimiter']
        opt['history_add_global_end_token'] = self.opt[
            'history_add_global_end_token']
        # Makes model fp16 ready for fine-tuning, means 4 extra padding tokens.
        opt['force_fp16_tokens'] = True
        opt['converting'] = True

        return opt
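
Step 5 above is a three-way branch on two fairseq flags; the same logic as a small standalone helper:

def choose_variant(fairseq_args: dict) -> str:
    # mirrors step 5: prelayernorm, then bart, then the aiayn default
    if fairseq_args.get('encoder_normalize_before'):
        return 'prelayernorm'
    if fairseq_args.get('layernorm_embedding'):
        return 'bart'
    return 'aiayn'

print(choose_variant({'encoder_normalize_before': True}))  # prelayernorm
print(choose_variant({'layernorm_embedding': True}))       # bart
print(choose_variant({}))                                  # aiayn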
Code example #26
File: test_torch_agent.py | Project: XinnuoXu/DRank
    def test_vectorize(self):
        """
        Make sure that the vectorize function is actually adding a new field.
        """
        try:
            from parlai.core.torch_agent import TorchAgent
        except ImportError as e:
            if 'pytorch' in e.msg:
                print('Skipping TestTorchAgent.test_vectorize, no pytorch.')
                return
            raise  # any other import failure is a real error

        from parlai.core.params import ParlaiParser
        parser = ParlaiParser()
        TorchAgent.add_cmdline_args(parser)
        parser.set_params(no_cuda=True)
        opt = parser.parse_args(print_args=False)
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)
        observation = {}
        observation["text"] = "What does the dog do?"
        observation["labels"] = ["The dog jumps over the cat."]

        # add start and end
        obs_vec = agent.vectorize(observation, add_start=True, add_end=True)
        self.assertTrue(
            'text_vec' in obs_vec,
            "Field 'text_vec' missing from vectorized observation")
        self.assertTrue(obs_vec['text_vec'].numpy().tolist() == [7, 8, 9],
                        "Vectorized text is incorrect.")
        self.assertTrue(
            'labels_vec' in obs_vec,
            "Field 'labels_vec' missing from vectorized observation")
        self.assertTrue(
            obs_vec['labels_vec'].numpy().tolist() == [
                mdict.START_IDX, 7, 8, 9, mdict.END_IDX
            ], "Vectorized label is incorrect.")
        # no start, add end
        obs_vec = agent.vectorize(observation, add_start=False, add_end=True)
        self.assertTrue(
            obs_vec['labels_vec'].numpy().tolist() == [7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
        # add start, no end
        obs_vec = agent.vectorize(observation, add_start=True, add_end=False)
        self.assertTrue(
            obs_vec['labels_vec'].numpy().tolist() == [
                mdict.START_IDX, 7, 8, 9
            ], "Vectorized label is incorrect.")
        # no start, no end
        obs_vec = agent.vectorize(observation, add_start=False, add_end=False)
        self.assertTrue(obs_vec['labels_vec'].numpy().tolist() == [7, 8, 9],
                        "Vectorized label is incorrect.")

        observation = {}
        observation["text"] = "What does the dog do?"
        observation["eval_labels"] = ["The dog jumps over the cat."]

        # eval_labels
        obs_vec = agent.vectorize(observation)
        self.assertTrue(
            'eval_labels_vec' in obs_vec,
            "Field \'eval_labels_vec\' missing from vectorized observation")
        self.assertTrue(
            obs_vec['eval_labels_vec'].numpy().tolist() == [
                mdict.START_IDX, 7, 8, 9, mdict.END_IDX
            ], "Vectorized label is incorrect.")
        # truncate
        obs_vec = agent.vectorize(observation, truncate=2)
        self.assertTrue(
            'eval_labels_vec' in obs_vec,
            "Field \'eval_labels_vec\' missing from vectorized observation")
        self.assertTrue(
            obs_vec['eval_labels_vec'].numpy().tolist() == [
                mdict.START_IDX, 7
            ], "Vectorized label is incorrect: " +
            str(obs_vec['eval_labels_vec']))

        # truncate
        obs_vec = agent.vectorize(observation, truncate=10)
        self.assertTrue(
            'eval_labels_vec' in obs_vec,
            "Field \'eval_labels_vec\' missing from vectorized observation")
        self.assertTrue(
            obs_vec['eval_labels_vec'].numpy().tolist() == [
                mdict.START_IDX, 7, 8, 9, mdict.END_IDX
            ], "Vectorized label is incorrect.")
Code example #27
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Interact with a pre-trained model.
This seq2seq model was trained on convai2:self.
"""

from parlai.core.build_data import download_models
from parlai.core.params import ParlaiParser
from parlai.scripts.interactive import interactive

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.set_params(
        model='seq2seq',
        model_file='models:convai2/seq2seq/convai2_self_seq2seq_model',
        dict_file='models:convai2/seq2seq/convai2_self_seq2seq_model.dict',
        dict_lower=True,
    )
    opt = parser.parse_args()
    if opt.get('model_file', '').startswith('models:convai2'):
        opt['model_type'] = 'seq2seq'
        fnames = [
            'convai2_self_seq2seq_model.tgz',
            'convai2_self_seq2seq_model.dict', 'convai2_self_seq2seq_model.opt'
        ]
        download_models(opt, fnames, 'convai2', version='v3.0')
    interactive(opt)
Code example #28
# of patent rights can be found in the PATENTS file in the same directory.
"""Interact with a pre-trained model.
This transformer model was trained on convai2:self.
"""

from parlai.core.build_data import download_models, modelzoo_path
from parlai.core.params import ParlaiParser
from parlai.scripts.interactive import interactive

if __name__ == '__main__':
    parser = ParlaiParser(add_model_args=True)
    parser.set_params(
        model='transformer',
        #model_file='models:convai2/transformer/convai2_self_transformer_model',
        #dict_file='models:convai2/transformer/convai2_self_transformer_model.dict',
        model_file=
        './checkpoints/convai2_transformer_volta_[l=4,h=2,dw=256,dm=256,di=2048,dk=64,dv=64,src_tgt_share=False,tgt_prj=False,smooth=False]',
        dict_file=
        './checkpoints/convai2_transformer_volta_[l=4,h=2,dw=256,dm=256,di=2048,dk=64,dv=64,src_tgt_share=False,tgt_prj=False,smooth=False].dict',
        dict_lower=True,
    )
    opt = parser.parse_args()
    if opt.get('model_file', '').startswith(
            modelzoo_path(opt.get('datapath'), "models:convai2")):
        opt['model_type'] = 'transformer'
        #fnames = ['convai2_self_transformer_model.tgz',
        #          'convai2_self_transformer_model.dict',
        #          'convai2_self_transformer_model.opt']
        fnames = [
            'convai2_transformer_volta_[l=4,h=2,dw=256,dm=256,di=2048,dk=64,dv=64,src_tgt_share=False,tgt_prj=False,smooth=False].tgz',
            'convai2_transformer_volta_[l=4,h=2,dw=256,dm=256,di=2048,dk=64,dv=64,src_tgt_share=False,tgt_prj=False,smooth=False].dict',
            'convai2_transformer_volta_[l=4,h=2,dw=256,dm=256,di=2048,dk=64,dv=64,src_tgt_share=False,tgt_prj=False,smooth=False].opt',
        ]
        download_models(opt, fnames, 'convai2', version='v3.0')
    interactive(opt)
Code example #29
    def __init__(self, task_run: "TaskRun", args: "DictConfig",
                 shared_state: "SharedTaskState"):
        # Set the number of conversations needed
        conversations_needed_string = args.blueprint.conversations_needed_string
        conversations_needed = {}
        parts = conversations_needed_string.split(',')
        for part in parts:
            model_name, num_string = part.split(':')
            conversations_needed[model_name] = int(num_string)
        self.conversations_needed = conversations_needed
        shared_state.conversations_needed = conversations_needed
        args.blueprint.num_conversations = sum(conversations_needed.values())

        # Default conversation initialization
        super().__init__(task_run, args=args, shared_state=shared_state)
        random.seed(self.args.blueprint.random_seed)
        np.random.seed(self.args.blueprint.random_seed)

        # Load task configuration data beyond the task description, since the super handles that
        left_pane_path = os.path.expanduser(args.blueprint.left_pane_text_path)
        with open(left_pane_path, "r") as left_pane_file:
            self.left_pane_text = left_pane_file.read()
        annotations_config_path = os.path.expanduser(
            args.blueprint.annotations_config_path)
        with open(annotations_config_path, "r") as annotations_config_file:
            self.annotations_config = annotations_config_file.read()
        onboard_task_data_path = os.path.expanduser(
            args.blueprint.onboard_task_data_path)
        with open(onboard_task_data_path, "r") as onboard_task_data_file:
            self.onboard_task_data = json.load(onboard_task_data_file)

        run_statistics = {r: 0 for (r, v) in self.conversations_needed.items()}
        shared_state.run_statistics = run_statistics

        # Initialize models
        models_needed = list(conversations_needed.keys())
        active_models = [
            m for m in models_needed if conversations_needed[m] > 0
        ]
        shared_bot_agents = TurkLikeAgent.get_bot_agents(args, active_models)
        shared_state.shared_models = shared_bot_agents

        # Context generation needs ParlAI options
        argparser = ParlaiParser(False, False)
        argparser.add_parlai_data_path()
        if len(args.blueprint.override_opt) > 0:
            argparser.set_params(**args.blueprint.override_opt)
        opt = argparser.parse_args([])

        if (args.blueprint.include_persona
                or args.blueprint.conversation_start_mode == 'bst'):
            context_generator = ContextGenerator(opt, datatype='test', seed=0)
            # We pull from the test set so that the model can't regurgitate
            # memorized conversations
        else:
            context_generator = None
        shared_state.context_generator = context_generator

        # Limits the number of models that can generate at once
        max_concurrent_responses = 1
        semaphore = Semaphore(max_concurrent_responses)

        # Lock for editing run statistics between threads
        statistics_condition = Condition()

        # Move shared state into the world and onboarding opts, such that these
        # can be used by the worlds
        shared_state.onboarding_world_opt.update({
            'onboard_statistics': shared_state.onboard_statistics,
            'statistics_condition': statistics_condition,
            'max_onboard_time': args.blueprint.max_onboard_time,
            'onboard_task_data': self.onboard_task_data,
            'onboarding_qualification': args.blueprint.onboarding_qualification,
        })
        shared_state.world_opt.update({
            'annotations_config': self.annotations_config,
            'block_qualification': args.blueprint.block_qualification,
            'conversations_needed': conversations_needed,
            'run_statistics': shared_state.run_statistics,
            'context_generator': context_generator,
            'semaphore': semaphore,
            'shared_bot_agents': shared_bot_agents,
            'num_turns': args.blueprint.num_turns,
            'max_resp_time': args.blueprint.max_resp_time,
            'is_sandbox': args.provider.requester_name == 'MOCK_REQUESTER',
            'statistics_condition': statistics_condition,
            'check_acceptability': args.blueprint.check_acceptability,
            'include_persona': args.blueprint.include_persona,
            'conversation_start_mode': args.blueprint.conversation_start_mode,
            'chat_data_folder': args.blueprint.chat_data_folder,
        })
Code example #30
File: test_torch_agent.py | Project: XinnuoXu/DRank
    def test_map_unmap(self):
        try:
            from parlai.core.torch_agent import TorchAgent, Output
        except ImportError as e:
            if 'pytorch' in e.msg:
                print('Skipping TestTorchAgent.test_map_unmap, no pytorch.')
                return
            raise  # any other import failure is a real error

        observations = []
        observations.append({
            "text": "What is a painting?",
            "labels": ["Paint on a canvas."]
        })
        observations.append({})
        observations.append({})
        observations.append({
            "text": "What is a painting?",
            "labels": ["Paint on a canvas."]
        })
        observations.append({})
        observations.append({})

        from parlai.core.params import ParlaiParser
        parser = ParlaiParser()
        TorchAgent.add_cmdline_args(parser)
        parser.set_params(no_cuda=True)
        opt = parser.parse_args(print_args=False)
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)

        vec_observations = [agent.vectorize(obs) for obs in observations]

        batch = agent.batchify(vec_observations)

        self.assertTrue(batch.text_vec is not None,
                        "Missing 'text_vec' field.")
        self.assertTrue(
            batch.text_vec.numpy().tolist() == [[7, 8, 9], [7, 8, 9]],
            "Incorrectly vectorized text field of obs_batch.")
        self.assertTrue(batch.label_vec is not None,
                        "Missing 'label_vec' field.")
        self.assertTrue(
            batch.label_vec.numpy().tolist() == [[
                mdict.START_IDX, 7, 8, 9, mdict.END_IDX
            ], [mdict.START_IDX, 7, 8, 9, mdict.END_IDX]],
            "Incorrectly vectorized labels field of obs_batch.")
        self.assertTrue(
            batch.labels == ["Paint on a canvas.", "Paint on a canvas."],
            "Doesn't return correct labels: " + str(batch.labels))
        true_i = [0, 3]
        self.assertTrue(
            all(batch.valid_indices[i] == true_i[i] for i in range(2)),
            "Returns incorrect indices of valid observations.")

        observations = []
        observations.append({
            "text": "What is a painting?",
            "eval_labels": ["Paint on a canvas."]
        })
        observations.append({})
        observations.append({})
        observations.append({
            "text": "What is a painting?",
            "eval_labels": ["Paint on a canvas."]
        })
        observations.append({})
        observations.append({})

        vec_observations = [agent.vectorize(obs) for obs in observations]

        batch = agent.batchify(vec_observations)

        self.assertTrue(batch.label_vec is not None,
                        "Missing 'label_vec' field.")
        self.assertTrue(
            batch.label_vec.numpy().tolist() == [[
                mdict.START_IDX, 7, 8, 9, mdict.END_IDX
            ], [mdict.START_IDX, 7, 8, 9, mdict.END_IDX]],
            "Incorrectly vectorized eval_labels field of obs_batch.")

        batch_reply = [{} for i in range(6)]
        predictions = ["Oil on a canvas.", "Oil on a canvas."]
        output = Output(predictions, None)
        expected_unmapped = batch_reply.copy()
        expected_unmapped[0]["text"] = "Oil on a canvas."
        expected_unmapped[3]["text"] = "Oil on a canvas."
        self.assertTrue(
            agent.match_batch(batch_reply, batch.valid_indices,
                              output) == expected_unmapped,
            "Unmapped predictions do not match expected results.")