コード例 #1
0
 def test_gpt2_bpe_tokenize(self):
     """
     Check GPT-2 BPE tokenization of a string containing an emoji.

     The text is tokenized and compared against the expected tokens, then
     the tokens are mapped to indices and turned back into text, which must
     equal the original string.
     """
     # the trailing character is the grinning face emoji
     text = u'Hello, ParlAI! \U0001f600'
     bpe_tokens = [
         'Hello',
         ',',
         r'\xc4\xa0Par',
         'l',
         'AI',
         '!',
         r'\xc4\xa0\xc3\xb0\xc5\x81\xc4\xba',
         r'\xc4\xa2',
     ]

     # agent construction output is captured rather than printed
     with testing_utils.capture_output():
         agent = DictionaryAgent(
             Opt({'dict_tokenizer': 'gpt2', 'datapath': './data'})
         )

     # text -> tokens
     self.assertEqual(agent.gpt2_tokenize(text), bpe_tokens)
     # tokens -> indices -> text
     self.assertEqual(
         agent.vec2txt(agent.tok2ind[token] for token in bpe_tokens),
         text,
     )
コード例 #2
0
    def test_opt(self):
        """
        Check that Opt records a history of writes and that deepcopy keeps it.

        Two writes are made to key ``'x'`` (an increment to 1, then an
        assignment of 10); the second element of the first two history
        entries must hold those values, on the original and on a deep copy.
        """

        def check_history(entries, message):
            # entry[1] is compared against the value written at that step
            self.assertEqual(entries[0][1], 1, message)
            self.assertEqual(entries[1][1], 10, message)

        tracked = Opt({'x': 0})
        tracked['x'] += 1
        tracked['x'] = 10
        check_history(tracked.history['x'], 'History not set properly')

        duplicate = deepcopy(tracked)
        check_history(
            duplicate.history['x'], 'Deepcopy history not set properly'
        )
コード例 #3
0
    def _process_args_to_opts(self,
                              args_that_override: Optional[List[str]] = None):
        """
        Convert the parsed arguments into ``self.opt`` and record overrides.

        Every known option found in ``args_that_override`` is copied into
        ``self.overridable``, which is then stored under ``opt['override']``.
        Afterwards the ``init_opt`` file is loaded if one was given,
        ``model_file`` and ``dict_file`` (in both the opt and its override)
        are passed through ``modelzoo_path``, and a ``starttime`` stamp is
        added.

        :param args_that_override:
            Argument strings to scan for explicitly given options. When
            ``None``, the process's command-line arguments (``argv[1:]``)
            are used.
        """
        self.opt = Opt(vars(self.args))

        # custom post-parsing
        self.opt['parlai_home'] = self.parlai_home
        self.opt = self._infer_datapath(self.opt)

        # set all arguments specified in command line as overridable
        option_strings_dict = {}
        # sets: these are only ever used for membership tests
        store_true = set()
        store_false = set()
        for group in self._action_groups:
            for a in group._group_actions:
                if not hasattr(a, 'option_strings'):
                    continue
                action_type = str(type(a))
                for option in a.option_strings:
                    option_strings_dict[option] = a.dest
                    if '_StoreTrueAction' in action_type:
                        store_true.add(option)
                    elif '_StoreFalseAction' in action_type:
                        store_false.add(option)

        if args_that_override is None:
            args_that_override = _sys.argv[1:]

        last_index = len(args_that_override) - 1
        for i, arg in enumerate(args_that_override):
            if arg not in option_strings_dict:
                continue
            key = option_strings_dict[arg]
            if arg in store_true:
                self.overridable[key] = True
            elif arg in store_false:
                self.overridable[key] = False
            elif (i < last_index
                  and args_that_override[i + 1] not in option_strings_dict):
                # The following token is this option's value unless it is
                # itself a known option string. Checking for a leading '-'
                # instead would wrongly skip negative values such as ``-1``.
                self.overridable[key] = self.opt[key]
        self.opt['override'] = self.overridable

        # load opts if a file is provided.
        if self.opt.get('init_opt', None) is not None:
            self._load_opts(self.opt)

        # map filenames that start with 'zoo:' to point to the model zoo dir
        for file_key in ('model_file', 'dict_file'):
            if self.opt.get(file_key) is not None:
                self.opt[file_key] = modelzoo_path(
                    self.opt.get('datapath'), self.opt[file_key])
            if self.opt['override'].get(file_key) is not None:
                # also check override
                self.opt['override'][file_key] = modelzoo_path(
                    self.opt.get('datapath'), self.opt['override'][file_key])

        # add start time of an experiment
        self.opt['starttime'] = datetime.datetime.today().strftime(
            '%b%d_%H-%M')
コード例 #4
0
        self.id = 'query_agent'
        self.opt = opt
        self.searchInput = search_input
    
    def observe(self, observation):
        """
        Accept an observation and ignore it.

        :param observation: the incoming observation; it is not used.
        """
        return None

    def act(self):
        """
        Build the message this agent sends.

        :return: a ``Message`` whose ``'id'`` is this agent's ID and whose
            ``'text'`` is the stored search input.
        """
        outgoing = Message()
        # identify the sender first, then attach the query text
        outgoing['id'] = self.getID()
        outgoing['text'] = self.searchInput
        return outgoing

# Set up the ParlAI environment at import time: the options are filled in by
# hand (model file path, no task) and one retriever agent is created from
# them as a module-level object.
# Commented-out parser-based setup:
# parlai_parser = ParlaiParser(True, True, 'Parser For MedTrialConnect Server')
parlai_opt = Opt()
parlai_opt['model_file'] = MODEL_FILE_PATH
parlai_opt['task'] = None
# NOTE(review): requireModelExists=True presumably makes create_agent fail
# when the model file is missing -- confirm against create_agent.
retriever_agent = create_agent(parlai_opt, requireModelExists=True)

# Set up the Flask application.
app = Flask(__name__)

@app.route(URL, methods=['GET'])
def search():
    query_agent = QueryAgent(parlai_opt, search_input=request.args.get('query'))
    world = MultiAgentDialogWorld(parlai_opt, [query_agent, retriever_agent])
    world.parley()
    retriever_output = world.acts[-1]
    candidates = retriever_output['candidates']
    candidate_scores = retriever_output['candidate_scores']