コード例 #1
0
def build_dict(opt):
    """Build a dictionary from the training set and save it to ``dict_file``.

    Always returns ``None``. Nothing is built when ``opt['dict_file']`` is
    missing or empty, or when a file already exists at that path; a notice is
    printed instead.

    Args:
        opt: option mapping. Reads ``dict_file``, ``dict_class`` (optional)
            and ``dict_maxexs``. The build world is created from a deep copy,
            so this function itself never assigns into ``opt``.
    """
    dict_path = opt.get('dict_file')
    if not dict_path:
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return
    print('[ setting up dictionary. ]')
    if os.path.isfile(dict_path):
        # A dictionary file is already on disk, so there is nothing to build.
        print("[ dictionary already built .]")
        return

    # A configured custom class takes precedence over the default agent.
    if opt.get('dict_class'):
        dict_cls = str2class(opt['dict_class'])
    else:
        dict_cls = DictionaryAgent
    dictionary = dict_cls(opt)

    # One ordered, single-threaded, unbatched pass over the training data.
    build_opt = copy.deepcopy(opt)
    for key, value in (('datatype', 'train:ordered'),
                       ('numthreads', 1),
                       ('batchsize', 1)):
        build_opt[key] = value
    world = create_task(build_opt, dictionary)

    # Every parley passes another example to the dictionary.
    for seen, _ in enumerate(world, start=1):
        if opt['dict_maxexs'] > 0 and seen > opt['dict_maxexs']:
            # A positive `dict_maxexs` caps how many examples are processed.
            print('Processed {} exs, moving on.'.format(opt['dict_maxexs']))
            break
        world.parley()
    print('[ dictionary built. ]')
    dictionary.save(opt['dict_file'], sort=True)
コード例 #2
0
ファイル: build_dict.py プロジェクト: jojonki/ParlAI
def build_dict(opt):
    """Build a dictionary from the training set and save it to ``dict_file``.

    Always returns ``None``. The build is skipped (with a printed notice) when
    ``opt['dict_file']`` is missing/empty or already exists on disk.

    Args:
        opt: option mapping. Reads ``dict_file``, ``dict_class`` (optional),
            ``datatype`` and ``dict_maxexs``. Settings for the build world are
            written to a deep copy only.
    """
    target = opt.get('dict_file')
    if not target:
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return
    print('[ setting up dictionary. ]')
    if os.path.isfile(target):
        # Nothing to do: a dictionary file is already present.
        print("[ dictionary already built .]")
        return

    # Prefer the custom dictionary class when one is configured.
    if opt.get('dict_class'):
        dict_cls = str2class(opt['dict_class'])
    else:
        dict_cls = DictionaryAgent
    dictionary = dict_cls(opt)

    # The dictionary is built from the training set, read in order.
    build_opt = copy.deepcopy(opt)
    build_opt['datatype'] = 'train:ordered'
    if 'stream' in opt['datatype']:
        # Keep streaming if the caller's datatype asked for it.
        build_opt['datatype'] = build_opt['datatype'] + ':stream'
    build_opt['numthreads'] = 1
    build_opt['batchsize'] = 1
    world = create_task(build_opt, dictionary)

    # Every parley passes another example to the dictionary.
    for seen, _ in enumerate(world, start=1):
        if opt['dict_maxexs'] > 0 and seen > opt['dict_maxexs']:
            # A positive `dict_maxexs` caps how many examples are processed.
            print('Processed {} exs, moving on.'.format(opt['dict_maxexs']))
            break
        world.parley()
    print('[ dictionary built. ]')
    dictionary.save(opt['dict_file'], sort=True)
コード例 #3
0
ファイル: build_dict.py プロジェクト: rapalizsolt/ParlAI
def build_dict(opt):
    """Build a dictionary from the training set and save it to ``dict_file``.

    Always returns ``None``. The build is skipped silently when
    ``opt['dict_file']`` is missing, ``None`` or empty, and skipped with a
    printed notice when a file already exists at that path.

    Args:
        opt: option mapping. Reads ``dict_file``, ``dict_class`` (optional,
            in ``"module.path:ClassName"`` form) and ``dict_maxexs``.
    """
    # Test for a usable value, not just key presence: an option that is
    # declared but unset holds None, which os.path.isfile() cannot handle.
    if not opt.get('dict_file'):
        return
    print('[ setting up dictionary. ]')
    if os.path.isfile(opt['dict_file']):
        # Dictionary already built
        print("[ dictionary already built .]")
        return
    if opt.get('dict_class'):
        # Custom dictionary class, given as "module:ClassName".
        spec = opt['dict_class'].split(':')
        module = importlib.import_module(spec[0])
        dict_class = getattr(module, spec[1])
        dictionary = dict_class(opt)
    else:
        # Default dictionary class
        dictionary = DictionaryAgent(opt)

    # The dictionary is built from one ordered, single-threaded, unbatched
    # pass over the training set; settings go to a copy of the options.
    ordered_opt = copy.deepcopy(opt)
    ordered_opt['datatype'] = 'train:ordered'
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    world_dict = create_task(ordered_opt, dictionary)

    # Pass examples to the dictionary, stopping early once a positive
    # `dict_maxexs` limit has been reached.
    for cnt, _ in enumerate(world_dict, start=1):
        if cnt > opt['dict_maxexs'] and opt['dict_maxexs'] > 0:
            print('Processed {} exs, moving on.'.format(opt['dict_maxexs']))
            break
        world_dict.parley()
    print('[ dictionary built. ]')
    dictionary.save(opt['dict_file'], sort=True)
コード例 #4
0
ファイル: test_dict.py プロジェクト: Taekyung2/MichinAI
    def test_save_reload(self):
        """
        Round-trip a byte-level BPE dictionary through save and reload.

        The reloaded dictionary must encode text the same way as the one it
        was saved from.
        """
        parser = ParlaiParser()
        DictionaryAgent.add_cmdline_args(parser)
        build_args = [
            '--dict-tokenizer',
            'bytelevelbpe',
            '--bpe-merge',
            DEFAULT_BYTELEVEL_BPE_MERGE,
            '--bpe-vocab',
            DEFAULT_BYTELEVEL_BPE_VOCAB,
        ]
        original = DictionaryAgent(parser.parse_args(build_args))
        # An empty encoding would mean the merge/vocab files failed to load.
        assert original.txt2vec("hello") != []

        with testing_utils.tempdir() as tmpdir:
            saved_path = os.path.join(tmpdir, "dict")
            original.save(saved_path)

            # Reload from the saved file only (no merge/vocab arguments).
            reload_args = [
                '--dict-tokenizer', 'bytelevelbpe', '--dict-file', saved_path
            ]
            reloaded = DictionaryAgent(parser.parse_args(reload_args))
            assert reloaded.txt2vec("hello") == original.txt2vec("hello")
コード例 #5
0
ファイル: build_dict.py プロジェクト: yoichimatsuyama/ParlAI
def build_dict(opt, skip_if_built=False):
    """Build (or load) a dictionary and return the dictionary agent.

    Args:
        opt: option mapping, or (deprecated) a ``ParlaiParser`` whose
            ``parse_args()`` result is used instead.
        skip_if_built: when true and ``opt['dict_file']`` already exists,
            return ``None`` without constructing a dictionary agent.

    Returns:
        ``None`` when ``dict_file`` is unset or the build was skipped via
        ``skip_if_built``; otherwise the dictionary agent, either as
        constructed while the file already existed or freshly built and saved.
    """
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: should be passed opt not Parser ]')
        opt = opt.parse_args()
    if not opt.get('dict_file'):
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return
    print('[ setting up dictionary. ]')

    if skip_if_built and os.path.isfile(opt['dict_file']):
        # Already on disk and the caller does not need the agent itself.
        print("[ dictionary already built .]")
        return None

    # A configured custom class takes precedence over the default agent.
    if opt.get('dict_class'):
        dict_cls = str2class(opt['dict_class'])
    else:
        dict_cls = DictionaryAgent
    dictionary = dict_cls(opt)

    if os.path.isfile(opt['dict_file']):
        # Already on disk: hand back the agent that was just constructed.
        print("[ dictionary already built .]")
        return dictionary

    # Settings for the build pass go to a copy of the options.
    build_opt = copy.deepcopy(opt)
    build_opt['numthreads'] = 1
    build_opt['batchsize'] = 1
    build_opt['image_mode'] = 'none'
    if build_opt['task'] == 'pytorch_teacher':
        # Swap in the configured build-teacher task, if there is one.
        replacement_task = build_opt.get('pytorch_buildteacher', '')
        if replacement_task != '':
            build_opt['task'] = replacement_task

    # The training set is always used; validation/test sets are opt-in.
    datatypes = ['train:ordered:stream']
    for flag, extra in (('dict_include_valid', 'valid:stream'),
                        ('dict_include_test', 'test:stream')):
        if opt.get(flag):
            datatypes.append(extra)

    # `seen` counts examples across all datatypes, not per datatype.
    seen = 0
    for datatype in datatypes:
        build_opt['datatype'] = datatype
        world = create_task(build_opt, dictionary)
        # Every parley passes another example to the dictionary.
        while not world.epoch_done():
            seen += 1
            if opt['dict_maxexs'] > 0 and seen > opt['dict_maxexs']:
                # A positive `dict_maxexs` caps how many examples are used.
                print('Processed {} exs, moving on.'.format(
                    opt['dict_maxexs']))
                break
            world.parley()
    dictionary.save(opt['dict_file'], sort=True)
    print('[ dictionary built with {} tokens ]'.format(len(dictionary)))
    return dictionary
コード例 #6
0
ファイル: test_dict.py プロジェクト: bonbert81/ParlAI
 def test_byte_level_bpe_tokenize(self):
     """
     Check byte-level BPE tokenization through a ParlAI DictionaryAgent.

     The same tokenize / vec2txt / txt2vec checks run before and after a
     save-and-reload cycle, followed by a check of the special token ids.
     """
     parser = ParlaiParser()
     parser.set_params(
         dict_tokenizer='bytelevelbpe',
         bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
         bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
         bpe_add_prefix_space=False,
     )
     agent = DictionaryAgent(parser.parse_args([], print_args=False))
     # The sample text ends with the grinning face emoji (U+1F600).
     sample = u'Hello, ParlAI! \U0001f600'

     def check_round_trip():
         """Assert text -> tokens, ids -> text and text -> ids all agree."""
         self.assertEqual(
             agent.bytelevelbpe_tokenize(sample),
             BYTELEVEL_BPE_RESULT,
         )
         self.assertEqual(
             agent.vec2txt([agent.tok2ind[w] for w in BYTELEVEL_BPE_RESULT]),
             sample,
         )
         self.assertEqual(
             agent.txt2vec(sample),
             [agent.tok2ind[w] for w in BYTELEVEL_BPE_RESULT],
         )

     check_round_trip()

     # Saving and loading must leave the vocabulary size and the
     # tokenization results unchanged.
     size_before = agent.byte_level_bpe.tokenizer.get_vocab_size()
     with testing_utils.tempdir() as tmpdir:
         checkpoint = os.path.join(tmpdir, 'dict-checkpoint')
         agent.save(filename=checkpoint)
         agent.load(filename=checkpoint)
     self.assertEqual(size_before,
                      agent.byte_level_bpe.tokenizer.get_vocab_size())
     check_round_trip()

     # The four ParlAI special tokens are expected at ids 0 through 3.
     special_ids = (
         ("__null__", 0),
         ("__start__", 1),
         ("__end__", 2),
         ("__unk__", 3),
     )
     for token, index in special_ids:
         assert agent.txt2vec(token) == [index]
コード例 #7
0
class NERDictionaryAgent(DictionaryAgent):
    """Dictionary agent for NER that keeps a second dictionary for labels.

    Text is handled by this agent itself; labels go to ``self.labels_dict``,
    which is stored alongside the main dictionary with a ``.labels.dict``
    suffix.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        """Register the base dictionary arguments plus ``--dict_class``."""
        group = DictionaryAgent.add_cmdline_args(argparser)
        group.add_argument(
            '--dict_class',
            default=class2str(NERDictionaryAgent),
            help="Sets the dictionary's class",
        )

    def __init__(self, opt, shared=None):
        """Create the labels dictionary first, then the main dictionary.

        The labels dictionary receives a copy of ``opt`` whose ``dict_file``
        has ``.labels.dict`` appended; ``opt`` itself is passed on unchanged.
        """
        labels_opt = copy.deepcopy(opt)
        labels_opt['dict_file'] += '.labels.dict'
        self.labels_dict = DictionaryAgent(labels_opt, shared)
        self.char_dict = get_char_dict()
        super().__init__(opt, shared)

    def observe(self, observation):
        """Route text to this dictionary and labels to the labels dictionary.

        Works on deep copies, so the caller's observation is left untouched.
        """
        text_obs = copy.deepcopy(observation)
        label_obs = copy.deepcopy(text_obs)
        # Blank out the field that each dictionary should ignore.
        label_obs['text'] = None
        text_obs['labels'] = None
        self.labels_dict.observe(label_obs)
        return super().observe(text_obs)

    def act(self):
        """Let both dictionaries act, then return a fixed id-only reply."""
        self.labels_dict.act()
        super().act()
        return {'id': 'NERDictionary'}

    def save(self, filename=None, append=False, sort=True):
        """Save the labels dictionary and then the main dictionary.

        Args:
            filename: target path; falls back to ``self.opt['model_file']``
                when ``None``. The labels dictionary is written to this path
                with ``.labels.dict`` appended.
            append: forwarded to the parent ``save``.
            sort: forwarded to the parent ``save``.

        Returns:
            Whatever the parent ``save`` returns.
        """
        if filename is None:
            filename = self.opt['model_file']
        self.labels_dict.save(filename + '.labels.dict')
        return super().save(filename, append, sort)

    def tokenize(self, text, building=False):
        """Split ``text`` on single spaces; empty or ``None`` gives ``[]``."""
        if not text:
            return []
        return text.split(' ')
コード例 #8
0
def __build_bag_of_words(opt):
    """Build and save a dictionary up front for models that need one.

    Does nothing unless ``opt['dict_build_first']`` is truthy and ``opt`` has
    a ``dict_file`` key. When ``dict_file`` is ``None`` it is derived, in
    place, from ``pretrained_model`` or else ``model_file`` by appending
    ``.dict``.

    Args:
        opt: option dictionary as returned by argument parsing. May have its
            ``dict_file`` entry filled in by this function.
    """
    if not (opt['dict_build_first'] and 'dict_file' in opt):
        return

    # Fill in a default dictionary path; the first available source wins.
    for source_key in ('pretrained_model', 'model_file'):
        if opt['dict_file'] is None and opt.get(source_key):
            opt['dict_file'] = opt[source_key] + '.dict'
    print("[ building dictionary first... ]")

    dict_path = opt.get('dict_file')
    if not dict_path:
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return
    print('[ setting up dictionary. ]')
    if os.path.isfile(dict_path):
        # Nothing to do: a dictionary file is already present.
        print("[ dictionary already built .]")
        return

    # A configured custom class takes precedence over the default agent.
    if opt.get('dict_class'):
        dict_cls = str2class(opt['dict_class'])
    else:
        dict_cls = DictionaryAgent
    dictionary = dict_cls(opt)

    # The dictionary is built from the training set, read in order.
    build_opt = copy.deepcopy(opt)
    build_opt['datatype'] = 'train:ordered'
    if 'stream' in opt['datatype']:
        # Keep streaming if the caller's datatype asked for it.
        build_opt['datatype'] = build_opt['datatype'] + ':stream'
    build_opt['numthreads'] = 1
    build_opt['batchsize'] = 1
    world = create_task(build_opt, dictionary)

    # Every parley passes another example to the dictionary.
    for seen, _ in enumerate(world, start=1):
        if opt['dict_maxexs'] > 0 and seen > opt['dict_maxexs']:
            # A positive `dict_maxexs` caps how many examples are processed.
            print('Processed {} exs, moving on.'.format(opt['dict_maxexs']))
            break
        world.parley()
    print('[ dictionary built. ]')
    dictionary.save(opt['dict_file'], sort=True)
コード例 #9
0
def main():
    """Build a dictionary from the train and valid sets and optionally save it.

    Options come from the command line. The dictionary is saved only when the
    parsed options contain a ``dict_savepath`` key.
    """
    argparser = ParlaiParser()
    DictionaryAgent.add_cmdline_args(argparser)
    opt = argparser.parse_args()

    dictionary = DictionaryAgent(opt)

    # Both the training and the validation set feed the dictionary.
    for datatype in ('train:ordered', 'valid'):
        opt['datatype'] = datatype
        world = create_task(opt, dictionary)

        # One parley per unit of the world's reported length.
        remaining = len(world)
        while remaining > 0:
            world.parley()
            remaining -= 1

    if 'dict_savepath' not in opt:
        return
    dictionary.save(opt['dict_savepath'])
コード例 #10
0
ファイル: build_dict.py プロジェクト: zhangziliang04/ParlAI
def build_dict(opt):
    """Build a dictionary from the streamed training set and save it.

    Always returns ``None``. The build is skipped (with a printed notice) when
    ``opt['dict_file']`` is missing/empty or already exists on disk.

    Args:
        opt: option mapping. Reads ``dict_file``, ``dict_class`` (optional),
            ``task``, ``pytorch_preprocess``, ``pytorch_buildteacher`` and
            ``dict_maxexs``. Build settings are written to a deep copy only.
    """
    dict_path = opt.get('dict_file')
    if not dict_path:
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return
    print('[ setting up dictionary. ]')
    if os.path.isfile(dict_path):
        # Nothing to do: a dictionary file is already present.
        print("[ dictionary already built .]")
        return

    # A configured custom class takes precedence over the default agent.
    if opt.get('dict_class'):
        dict_cls = str2class(opt['dict_class'])
    else:
        dict_cls = DictionaryAgent
    dictionary = dict_cls(opt)

    # The dictionary is built from the training set, streamed in order.
    build_opt = copy.deepcopy(opt)
    build_opt['datatype'] = 'train:ordered:stream'
    build_opt['numthreads'] = 1
    build_opt['batchsize'] = 1
    build_opt['image_mode'] = 'none'
    wants_build_teacher = (
        build_opt['task'] == 'pytorch_teacher'
        and build_opt.get('pytorch_preprocess', False)
    )
    if wants_build_teacher:
        # Swap in the configured build-teacher task, if there is one.
        replacement_task = build_opt.get('pytorch_buildteacher', '')
        if replacement_task != '':
            build_opt['task'] = replacement_task
    world = create_task(build_opt, dictionary)

    # Every parley passes another example to the dictionary.
    seen = 0
    while not world.epoch_done():
        seen += 1
        if opt['dict_maxexs'] > 0 and seen > opt['dict_maxexs']:
            # A positive `dict_maxexs` caps how many examples are processed.
            print('Processed {} exs, moving on.'.format(opt['dict_maxexs']))
            break
        world.parley()
    print('[ dictionary built. ]')
    dictionary.save(opt['dict_file'], sort=True)
コード例 #11
0
class NERDictionaryAgent(DictionaryAgent):
    """Dictionary agent that tracks text and labels in separate dictionaries.

    Labels are delegated to ``self.labels_dict``, a plain ``DictionaryAgent``
    whose file name carries a ``.labels.dict`` suffix.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        """Register the base dictionary arguments plus ``--dict_class``."""
        group = DictionaryAgent.add_cmdline_args(argparser)
        group.add_argument(
            '--dict_class',
            default=class2str(NERDictionaryAgent),
            help="Sets the dictionary's class",
        )

    def __init__(self, opt, shared=None):
        """Set up the labels dictionary, then the main dictionary.

        The labels dictionary gets a copy of ``opt`` whose ``dict_file`` has
        ``.labels.dict`` appended; ``opt`` itself is passed on unchanged.
        """
        labels_opt = copy.deepcopy(opt)
        labels_opt['dict_file'] += '.labels.dict'
        self.labels_dict = DictionaryAgent(labels_opt, shared)
        self.char_dict = get_char_dict()
        super().__init__(opt, shared)

    def observe(self, observation):
        """Send labels to the labels dictionary and text to this one.

        Operates on deep copies, so the incoming observation is not modified.
        """
        text_obs = copy.deepcopy(observation)
        label_obs = copy.deepcopy(text_obs)
        # Blank out the field that each dictionary should ignore.
        label_obs['text'] = None
        text_obs['labels'] = None
        self.labels_dict.observe(label_obs)
        return super().observe(text_obs)

    def act(self):
        """Let both dictionaries act, then return a fixed id-only reply."""
        self.labels_dict.act()
        super().act()
        return {'id': 'NERDictionary'}

    def save(self, filename=None, append=False, sort=True):
        """Save the labels dictionary, then the main dictionary.

        ``filename`` defaults to ``self.opt['model_file']``; the labels
        dictionary is written to that path with ``.labels.dict`` appended.
        ``append`` and ``sort`` are forwarded to the parent ``save``, whose
        result is returned.
        """
        if filename is None:
            filename = self.opt['model_file']
        self.labels_dict.save(filename + '.labels.dict')
        return super().save(filename, append, sort)

    def tokenize(self, text, building=False):
        """Split ``text`` on single spaces; empty or ``None`` gives ``[]``."""
        if not text:
            return []
        return text.split(' ')
コード例 #12
0
ファイル: build_dict.py プロジェクト: simplecoka/cortx
def build_dict(opt, skip_if_built=False):
    """Build (or load) a dictionary and return the dictionary agent.

    Args:
        opt: option object, or (discouraged) a ``ParlaiParser`` whose
            ``parse_args()`` result is used instead.
        skip_if_built: when true and ``opt['dict_file']`` already exists,
            return ``None`` without constructing a dictionary agent.

    Returns:
        ``None`` when ``dict_file`` is unset or the build was skipped via
        ``skip_if_built``; otherwise the dictionary agent.

    Raises:
        ValueError: if the dictionary still has to be built while running
            distributed.
    """
    if isinstance(opt, ParlaiParser):
        logging.error('Should be passed opt not Parser')
        opt = opt.parse_args()
    if not opt.get('dict_file'):
        logging.error(
            'Tried to build dictionary but `--dict-file` is not set. Set '
            'this param so the dictionary can be saved.')
        return
    if skip_if_built and PathManager.exists(opt['dict_file']):
        # Already on disk and the caller does not need the agent itself.
        logging.debug("dictionary already built.")
        return None

    # A configured custom class takes precedence over the default agent.
    if opt.get('dict_class'):
        dict_cls = str2class(opt['dict_class'])
    else:
        dict_cls = DictionaryAgent
    dictionary = dict_cls(opt)

    # Either the file exists or the agent itself reports being pre-built.
    already_built = PathManager.exists(opt['dict_file']) or (
        hasattr(dictionary, 'is_prebuilt') and dictionary.is_prebuilt())
    if already_built:
        logging.debug("dictionary already built.")
        return dictionary

    if is_distributed():
        raise ValueError(
            'Dictionaries should be pre-built before distributed train.')

    # Settings for the build pass go to a copy of the options.
    build_opt = copy.deepcopy(opt)
    build_opt['batchsize'] = 1
    # Per the original author's note, this image mode keeps image features
    # from being computed when the Teacher is instantiated for the build.
    build_opt['image_mode'] = 'no_image_model'

    build_opt.log()

    # The training set is always used; validation/test sets are opt-in.
    datatypes = ['train:ordered:stream']
    for flag, extra in (('dict_include_valid', 'valid:stream'),
                        ('dict_include_test', 'test:stream')):
        if opt.get(flag):
            datatypes.append(extra)

    # `seen` counts examples across all datatypes, not per datatype.
    seen = 0
    for datatype in datatypes:
        build_opt['datatype'] = datatype
        world = create_task(build_opt, dictionary)
        # The timer restarts per datatype; the final log uses the last one.
        log_time = TimeLogger()
        total = world.num_examples()
        if opt['dict_maxexs'] >= 0:
            total = min(total, opt['dict_maxexs'])

        # A progress bar is shown only when periodic logging is configured.
        if opt.get('log_every_n_secs', None):
            pbar = tqdm.tqdm(total=total,
                             desc='Building dictionary',
                             unit='ex',
                             unit_scale=True)
        else:
            pbar = None

        # Every parley passes another example to the dictionary.
        while not world.epoch_done():
            seen += 1
            if opt['dict_maxexs'] >= 0 and seen > opt['dict_maxexs']:
                # A non-negative `dict_maxexs` caps the examples processed.
                logging.info('Processed {} exs, moving on.'.format(
                    opt['dict_maxexs']))
                break
            world.parley()
            if pbar:
                pbar.update(1)
        if pbar:
            pbar.close()

    dictionary.save(opt['dict_file'], sort=True)
    token_count = len(dictionary)
    elapsed = log_time.total_time()
    logging.info(f'dictionary built with {token_count} tokens '
                 f'in {elapsed:.1f}s')
    return dictionary
コード例 #13
0
ファイル: train.py プロジェクト: daram529/ParlAI
def main():
    """Train a Seq2seqAgent on the task given on the command line.

    Steps, in order: parse arguments, pick a default ``model_file`` if none
    was given, select the CUDA device, load or build the dictionary, create
    the agent (loading saved parameters if the model file exists), then
    alternate 200 training parleys with a full validation pass. The model is
    saved whenever validation accuracy improves, and the loop stops once
    validation accuracy exceeds 99.5.

    Validation results and a training log are written next to the model file
    (``.validations`` and ``.log``).
    """
    # Get command line arguments
    parser = ParlaiParser(add_model_args=True)
    DictionaryAgent.add_cmdline_args(parser)
    Seq2seqAgent.add_cmdline_args(parser)
    parser.add_argument('--dict-maxexs', default=100000, type=int)
    opt = parser.parse_args()

    # set model_file if none set, default is based on task name
    if not opt['model_file']:
        logdir = os.path.join(opt['parlai_home'], 'logs')
        bld.make_dir(logdir)
        # only the first 30 characters of the lower-cased task name are used
        task_short = opt['task'].lower()[:30]
        opt['model_file'] = os.path.join(logdir, task_short + '.model')

    # use CUDA only when it is both allowed by the options and available
    opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
    if opt['cuda']:
        print('[ Using CUDA ]')
        torch.cuda.set_device(opt['gpu'])

    # set up dictionary
    print('Setting up dictionary.')
    # the dictionary file name is derived from the model file name
    if '.model' in opt['model_file']:
        dict_fn = opt['model_file'].replace('.model', '.dict')
    else:
        dict_fn = opt['model_file'] + '.dict'
    if os.path.isfile(dict_fn):
        # an existing dictionary file is handed to the agent via the options
        opt['dict_loadpath'] = dict_fn
    dictionary = DictionaryAgent(opt)
    ordered_opt = copy.deepcopy(opt)
    # counts examples across both datatypes below
    cnt = 0

    # if dictionary was not loaded, create one
    if not opt.get('dict_loadpath'):
        for datatype in ['train:ordered', 'valid']:
            # we use train and valid sets to build dictionary
            ordered_opt['datatype'] = datatype
            ordered_opt['numthreads'] = 1
            ordered_opt['batchsize'] = 1
            world_dict = create_task(ordered_opt, dictionary)

            # pass examples to dictionary
            for _ in world_dict:
                cnt += 1
                if cnt > opt['dict_maxexs'] and opt['dict_maxexs'] > 0:
                    print('Processed {} exs, moving on.'.format(
                          opt['dict_maxexs']))
                    # don't wait too long...
                    # (this only leaves the inner loop; the outer loop still
                    # moves on to the next datatype)
                    break
                world_dict.parley()
        dictionary.save(dict_fn, sort=True)

    # create agent
    agent = Seq2seqAgent(opt, {'dictionary': dictionary})

    if os.path.isfile(opt['model_file']):
        print('Loading existing model parameters from ' + opt['model_file'])
        agent.load(opt['model_file'])

    # create train and validation worlds
    opt['datatype'] = 'train'
    world_train = create_task(opt, agent)

    opt['datatype'] = 'valid'
    world_valid = create_task(opt, agent)

    # set up logging
    start = time.time()
    best_accuracy = 0
    # log file names are derived from the model file name, like dict_fn above
    if '.model' in opt['model_file']:
        valid_fn = opt['model_file'].replace('.model', '.validations')
        log_fn = opt['model_file'].replace('.model', '.log')
    else:
        valid_fn = opt['model_file'] + '.validations'
        log_fn = opt['model_file'] + '.log'

    # train / valid loop
    # running count of training examples, assuming one batch per parley
    total = 0
    # both files are opened in 'w' mode, so earlier contents are overwritten
    with open(valid_fn, 'w') as validations, open(log_fn, 'w') as log:
        while True:
            # train for a bit
            print('[ training ]')
            world_train.reset()
            for _ in range(200):
                world_train.parley()
                total += opt['batchsize']
            log.write('[ training example. ]\n')
            log.write(world_train.display() + '\n')

            # log training results
            print('[ training summary. ]')
            log.write('[ training summary. ]\n')
            report_train = world_train.report()
            report_train['cumulative_total'] = total
            print(report_train)
            log.write(str(report_train))
            log.write('\n')
            log.flush()

            # do one epoch of validation
            print('[ validating ]')
            world_valid.reset()
            for _ in world_valid:  # check valid accuracy
                world_valid.parley()
            log.write('[ validation example. ]\n')
            log.write(world_valid.display() + '\n')

            # get validation summary
            print('[ validation summary. ]')
            log.write('[ validation summary. ]\n')
            report_valid = world_valid.report()

            # update best accuracy if applicable
            annotation = ''
            if report_valid['accuracy'] > best_accuracy:
                best_accuracy = report_valid['accuracy']
                # the model is saved only when validation accuracy improves
                agent.save(opt['model_file'])
                annotation = '*'  # mark this validation as a best one
            curr_time = time.strftime('%Y/%m/%d %H:%M:%S', time.localtime())
            validations.write('{}: {} {}\n'.format(
                curr_time, report_valid['accuracy'], annotation))
            validations.flush()
            report_valid['best_accuracy'] = best_accuracy

            # log validation summary
            print(report_valid)
            log.write(str(report_valid))
            log.write('\n')
            log.flush()

            # break if accuracy reaches ~100%
            # NOTE(review): this threshold assumes accuracy is reported on a
            # 0-100 scale — confirm against the world's report(); with a 0-1
            # scale the loop would never terminate here.
            if report_valid['accuracy'] > 99.5:
                break

    print('finished in {} s'.format(round(time.time() - start, 2)))
コード例 #14
0
def build_dict(opt, skip_if_built=False):
    """Build (or load) a dictionary and return the dictionary agent.

    Args:
        opt: option mapping, or (deprecated) a ``ParlaiParser`` whose
            ``parse_args()`` result is used instead.
        skip_if_built: when true and ``opt['dict_file']`` already exists,
            return ``None`` without constructing a dictionary agent.

    Returns:
        ``None`` when ``dict_file`` is unset or the build was skipped via
        ``skip_if_built``; otherwise the dictionary agent, either as
        constructed while the file already existed or freshly built and saved.

    Raises:
        ValueError: if the dictionary file does not exist yet and the process
            is running distributed.
    """
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: should be passed opt not Parser ]')
        opt = opt.parse_args()
    if not opt.get('dict_file'):
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return
    if skip_if_built and os.path.isfile(opt['dict_file']):
        # Dictionary already built, skip all loading or setup
        print("[ dictionary already built .]")
        return None

    if opt.get('dict_class'):
        # Custom dictionary class
        dictionary = str2class(opt['dict_class'])(opt)
    else:
        # Default dictionary class
        dictionary = DictionaryAgent(opt)

    if os.path.isfile(opt['dict_file']):
        # Dictionary already built, return loaded dictionary agent
        print("[ dictionary already built .]")
        return dictionary

    # Only refuse to run distributed once we know a build is really needed;
    # a dictionary that is already on disk has been returned above.
    if is_distributed():
        raise ValueError(
            'Dictionaries should be pre-built before distributed train.')

    # Settings for the build pass go to a copy of the options.
    ordered_opt = copy.deepcopy(opt)
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    # Per the original author's note, this image mode keeps image features
    # from being computed when the Teacher is instantiated for the build.
    ordered_opt['image_mode'] = 'no_image_model'
    ordered_opt['pytorch_teacher_batch_sort'] = False
    if ordered_opt['task'] == 'pytorch_teacher' or not ordered_opt['task']:
        # Swap in the configured pytorch teacher task, if there is one.
        pytorch_teacher_task = ordered_opt.get('pytorch_teacher_task', '')
        if pytorch_teacher_task != '':
            ordered_opt['task'] = pytorch_teacher_task

    # The training set is always used; validation/test sets are opt-in.
    datatypes = ['train:ordered:stream']
    if opt.get('dict_include_valid'):
        datatypes.append('valid:stream')
    if opt.get('dict_include_test'):
        datatypes.append('test:stream')

    # `cnt` counts examples across all datatypes, not per datatype.
    cnt = 0
    for dt in datatypes:
        ordered_opt['datatype'] = dt
        world_dict = create_task(ordered_opt, dictionary)
        # pass examples to dictionary
        print('[ running dictionary over data.. ]')
        # The timer restarts per datatype; the final message uses the last.
        log_time = TimeLogger()
        total = world_dict.num_examples()
        if opt['dict_maxexs'] >= 0:
            total = min(total, opt['dict_maxexs'])

        # A progress bar is shown only when periodic logging is configured.
        log_every_n_secs = opt.get('log_every_n_secs', None)
        if log_every_n_secs:
            pbar = tqdm.tqdm(total=total,
                             desc='Building dictionary',
                             unit='ex',
                             unit_scale=True)
        else:
            pbar = None
        while not world_dict.epoch_done():
            cnt += 1
            if cnt > opt['dict_maxexs'] and opt['dict_maxexs'] >= 0:
                print('Processed {} exs, moving on.'.format(
                    opt['dict_maxexs']))
                # don't wait too long...
                break
            world_dict.parley()
            if pbar:
                pbar.update(1)
        if pbar:
            pbar.close()

    dictionary.save(opt['dict_file'], sort=True)
    print('[ dictionary built with {} tokens in {}s ]'.format(
        len(dictionary), round(log_time.total_time(), 2)))
    return dictionary
コード例 #15
0
class IrBaselineAgent(Agent):
    """Information Retrieval baseline.

    Ranks candidate responses against a query made of the most recent
    dialogue utterances and replies with the top-ranked candidate.
    """

    @staticmethod
    def add_cmdline_args(parser):
        """Add command line args specific to this agent."""
        group = parser.add_argument_group('IrBaseline Arguments')
        group.add_argument('-lp',
                           '--length_penalty',
                           type=float,
                           default=0.5,
                           help='length penalty for responses')
        group.add_argument(
            '-hsz',
            '--history_size',
            type=int,
            default=1,
            help='number of utterances from the dialogue history to take use '
            'as the query')
        group.add_argument('--label_candidates_file',
                           type=str,
                           default=None,
                           help='file of candidate responses to choose from')

    def __init__(self, opt, shared=None):
        """Initialize agent.

        When ``opt['label_candidates_file']`` is set, the file is read once
        and its lines become ``self.label_candidates``; otherwise that
        attribute is left undefined.
        """
        super().__init__(opt)
        self.id = 'IRBaselineAgent'
        self.length_penalty = float(opt['length_penalty'])
        self.dictionary = DictionaryAgent(opt)
        self.opt = opt
        self.history = []
        self.episodeDone = True
        candidates_path = opt.get('label_candidates_file')
        if candidates_path:
            # Use a context manager so the file handle is always closed.
            with open(candidates_path) as f:
                self.label_candidates = f.read().split('\n')

    def reset(self):
        """Reset agent properties."""
        self.observation = None
        self.history = []
        self.episodeDone = True

    def observe(self, obs):
        """Store and remember incoming observation message dict."""
        self.observation = obs
        self.dictionary.observe(obs)
        if self.episodeDone:
            # The previous episode ended, so its history no longer applies.
            self.history = []
        if 'text' in obs:
            self.history.append(obs.get('text', ''))
        self.episodeDone = obs.get('episode_done', False)
        return obs

    def act(self):
        """Generate a response to the previously seen observation(s).

        :returns: reply dict with 'id' and 'text'; when candidates are
                  available it also carries the ranked 'text_candidates'.
        """
        if self.opt.get('datatype', '').startswith('train'):
            # The dictionary only acts while training.
            self.dictionary.act()

        obs = self.observation
        reply = {}
        reply['id'] = self.getID()

        # Rank candidates
        cands = None
        if 'label_candidates' in obs and len(obs['label_candidates']) > 0:
            cands = obs['label_candidates']
        if hasattr(self, 'label_candidates'):
            # override label candidates with candidate file if set
            cands = self.label_candidates
        if cands:
            # The query is the last `history_size` utterances joined by spaces.
            hist_sz = self.opt.get('history_size', 1)
            left_idx = max(0, len(self.history) - hist_sz)
            text = ' '.join(self.history[left_idx:])
            rep = self.build_query_representation(text)
            reply['text_candidates'] = (rank_candidates(
                rep, cands, self.length_penalty, self.dictionary))
            reply['text'] = reply['text_candidates'][0]
        else:
            reply['text'] = "I don't know."
        return reply

    def save(self, fname=None):
        """Save dictionary tokenizer if available.

        ``fname`` defaults to ``opt['model_file']``; nothing is saved when
        neither is set. The dictionary goes to ``fname + '.dict'``.
        """
        fname = self.opt.get('model_file', None) if fname is None else fname
        if fname:
            self.dictionary.save(fname + '.dict')

    def load(self, fname):
        """Load internal dictionary from ``fname + '.dict'``."""
        self.dictionary.load(fname + '.dict')

    def build_query_representation(self, query):
        """Build representation of query, e.g. words or n-grams.

        :param query: string to represent.

        :returns: dictionary containing 'words' dictionary (token => weight)
                  and 'norm' float (square root of the number of tokens)
        """
        rep = {}
        rep['words'] = {}
        words = list(self.dictionary.tokenize(query.lower()))
        rw = rep['words']
        for w in words:
            if len(self.dictionary.freqs()) > 0:
                # Weight each token inversely to the log of its frequency.
                rw[w] = 1.0 / (1.0 +
                               math.log(1.0 + self.dictionary.freqs()[w]))
            else:
                # Without frequency data, give non-stopwords a flat weight.
                if w not in stopwords:
                    rw[w] = 1
        rep['norm'] = math.sqrt(len(words))
        return rep
コード例 #16
0
class IrBaselineAgent(Agent):
    """Retrieval baseline that ranks label candidates against recent history.

    The agent keeps the texts observed during the current episode and, when
    asked to act, scores each label candidate against the last
    ``history_size`` utterances via ``rank_candidates``.
    """

    @staticmethod
    def add_cmdline_args(parser):
        """Register the dictionary arguments plus this agent's own options."""
        DictionaryAgent.add_cmdline_args(parser)
        parser.add_argument(
            '-lp', '--length_penalty', type=float, default=0.5,
            help='length penalty for responses')
        parser.add_argument(
            '-hsz', '--history_size', type=int, default=1,
            help='number of utterances from the dialogue history to take use as the query')

    def __init__(self, opt, shared=None):
        """Create the agent and its own DictionaryAgent from ``opt``."""
        super().__init__(opt)
        self.id = 'IRBaselineAgent'
        self.length_penalty = float(opt['length_penalty'])
        self.dictionary = DictionaryAgent(opt)
        self.opt = opt
        self.history = []
        # True so that the first observation starts from a clean history
        self.episodeDone = True

    def reset(self):
        """Forget the last observation and the dialogue history."""
        self.observation = None
        self.history = []
        self.episodeDone = True

    def observe(self, obs):
        """Store ``obs``, feed it to the dictionary and extend the history.

        The history is cleared first when the previous observation ended an
        episode.

        :param obs: observation dict; 'text' and 'episode_done' are read.

        :returns: the observation, unchanged.
        """
        self.observation = obs
        self.dictionary.observe(obs)
        if self.episodeDone:
            self.history = []
        if 'text' in obs:
            self.history.append(obs.get('text', ''))
        self.episodeDone = obs.get('episode_done', False)
        return obs

    def act(self):
        """Rank the observed label candidates and reply with the best one.

        :returns: dict with 'id' and 'text'; 'text_candidates' (the full
                  ranking) is included when candidates were provided.
                  Without candidates the text is "I don't know.".
        """
        if self.opt.get('datatype', '').startswith('train'):
            # let the dictionary process its last observation while training
            self.dictionary.act()

        obs = self.observation
        reply = {'id': self.getID()}

        # Rank candidates
        if 'label_candidates' in obs and len(obs['label_candidates']) > 0:
            # the query is the last `history_size` utterances of the episode
            first = max(0, len(self.history) - self.opt.get('history_size', 1))
            text = ' '.join(self.history[first:])
            rep = self.build_query_representation(text)
            reply['text_candidates'] = rank_candidates(
                rep, obs['label_candidates'],
                self.length_penalty, self.dictionary)
            reply['text'] = reply['text_candidates'][0]
        else:
            reply['text'] = "I don't know."
        return reply

    def save(self, fname=None):
        """Save the dictionary to ``<fname>.dict``.

        :param fname: base path; defaults to ``opt['model_file']``. Nothing
                      is written when no path is available.
        """
        fname = self.opt.get('model_file', None) if fname is None else fname
        if fname:
            self.dictionary.save(fname + '.dict')

    def load(self, fname):
        """Load the dictionary from ``<fname>.dict``."""
        self.dictionary.load(fname + '.dict')

    def build_query_representation(self, query):
        """Build representation of query, e.g. words or n-grams.

        :param query: string to represent.

        :returns: dictionary containing a 'words' dictionary (token => weight)
                  and a 'norm' float (square root of the number of tokens).
        """
        tokens = list(self.dictionary.tokenize(query.lower()))
        # Look the frequency table up once instead of once per token.
        freqs = self.dictionary.freqs()
        weights = {}
        if len(freqs) > 0:
            for tok in tokens:
                weights[tok] = 1.0 / (1.0 + math.log(1.0 + freqs[tok]))
        else:
            # No counts available: fall back to a stopword filter.
            for tok in tokens:
                if tok not in stopwords:
                    weights[tok] = 1
        return {'words': weights, 'norm': math.sqrt(len(tokens))}
コード例 #17
0
def main():
    """Build a dictionary if needed, then train and validate a remote agent.

    Default '--remote-cmd' / '--remote-args' values pointing at the luajit
    memnn scripts under $PARLAI_HOME are appended to ``sys.argv`` when the
    caller did not pass them.
    """
    # Get command line arguments
    parser = ParlaiParser()
    DictionaryAgent.add_cmdline_args(parser)
    ParsedRemoteAgent.add_cmdline_args(parser)
    parser.add_argument('--num-examples', default=1000, type=int)
    parser.add_argument('--num-its', default=100, type=int)
    parser.add_argument('--dict-max-exs', default=10000, type=int)
    home = os.environ['PARLAI_HOME']
    if '--remote-cmd' not in sys.argv:
        luajit_missing = os.system('which luajit') != 0
        if luajit_missing:
            raise RuntimeError(
                'Could not detect torch luajit installed: ' +
                'please install torch from http://torch.ch ' +
                'or manually set --remote-cmd for this example.')
        remote_cmd = ('luajit {}/parlai/agents/'.format(home) +
                      'memnn_luatorch_cpu/memnn_zmq_parsed.lua')
        sys.argv += ['--remote-cmd', remote_cmd]
    if '--remote-args' not in sys.argv:
        remote_args = ('{}/examples/'.format(home) +
                       'memnn_luatorch_cpu/params_default.lua')
        sys.argv += ['--remote-args', remote_args]

    opt = parser.parse_args()

    # set up dictionary
    print('Setting up dictionary.')
    dictionary = DictionaryAgent(opt)
    if not opt.get('dict_loadpath'):
        # nothing was loaded, so build the dictionary from train and valid
        build_opt = copy.deepcopy(opt)
        for datatype in ['train:ordered', 'valid']:
            build_opt['datatype'] = datatype
            build_opt['numthreads'] = 1
            world_dict = create_task(build_opt, dictionary)

            print('Dictionary building on {} data.'.format(datatype))
            # pass examples to dictionary, capped at --dict-max-exs if > 0
            for seen, _ in enumerate(world_dict, start=1):
                limit = opt['dict_max_exs']
                if seen > limit and limit > 0:
                    print('Processed {} exs, moving on.'.format(limit))
                    # don't wait too long...
                    break

                world_dict.parley()

        # we need to save the dictionary to load it in memnn (sort it by freq)
        dictionary.save('/tmp/dict.txt', sort=True)

    print('Dictionary ready, moving on to training.')

    # the same agent is shared by the training and the validation world
    opt['datatype'] = 'train'
    agent = ParsedRemoteAgent(opt, {'dictionary': dictionary})
    world_train = create_task(opt, agent)
    opt['datatype'] = 'valid'
    world_valid = create_task(opt, agent)

    started = time.time()
    with world_train:
        for _ in range(opt['num_its']):
            print('[ training ]')
            parleys = opt['num_examples'] * opt.get('numthreads', 1)
            for _ in range(parleys):
                world_train.parley()
            world_train.synchronize()

            print('[ validating ]')
            world_valid.reset()
            for _ in world_valid:  # check valid accuracy
                world_valid.parley()

            print('[ validation summary. ]')
            report_valid = world_valid.report()
            print(report_valid)
            if report_valid['accuracy'] > 0.95:
                # good enough, stop training early
                break

        # show some example dialogs after training:
        world_valid = create_task(opt, agent)
        for _k in range(3):
            world_valid.parley()
            print(world_valid.display())

    elapsed = round(time.time() - started, 2)
    print('finished in {} s'.format(elapsed))
コード例 #18
0
def main():
    """Build a dictionary if needed, then train and validate a remote agent.

    Default '--remote-cmd' / '--remote-args' values pointing at the luajit
    memnn scripts under $PARLAI_HOME are appended to ``sys.argv`` when the
    caller did not pass them.
    """
    # Get command line arguments
    parser = ParlaiParser()
    DictionaryAgent.add_cmdline_args(parser)
    ParsedRemoteAgent.add_cmdline_args(parser)
    parser.add_argument('--num_examples', default=1000, type=int)
    parser.add_argument('--num_its', default=100, type=int)
    home = os.environ['PARLAI_HOME']
    if '--remote-cmd' not in sys.argv:
        remote_cmd = ('luajit {}/parlai/agents/'.format(home) +
                      'memnn_luatorch_cpu/memnn_zmq_parsed.lua')
        sys.argv += ['--remote-cmd', remote_cmd]
    if '--remote-args' not in sys.argv:
        remote_args = ('{}/examples/'.format(home) +
                       'memnn_luatorch_cpu/params_default.lua')
        sys.argv += ['--remote-args', remote_args]

    opt = parser.parse_args()

    # set up dictionary
    print('Setting up dictionary.')
    dictionary = DictionaryAgent(opt)
    if not opt.get('dict_loadpath'):
        # nothing was loaded, so build the dictionary from train and valid
        build_opt = copy.deepcopy(opt)
        for datatype in ['train:ordered', 'valid']:
            build_opt['datatype'] = datatype
            build_opt['numthreads'] = 1
            world_dict = create_task(build_opt, dictionary)
            # pass every example to the dictionary
            for _ in world_dict:
                world_dict.parley()

        # we need to save the dictionary to load it in memnn (sort it by freq)
        dictionary.save('/tmp/dict.txt', sort=True)

    print('Dictionary ready, moving on to training.')

    opt['datatype'] = 'train'
    agent = ParsedRemoteAgent(opt, {'dictionary': dictionary})
    world_train = create_task(opt, agent)
    # from here on, worlds created from `opt` are validation worlds
    opt['datatype'] = 'valid'

    started = time.time()
    with world_train:
        for _ in range(opt['num_its']):
            print('[ training ]')
            parleys = opt['num_examples'] * opt.get('numthreads', 1)
            for _ in range(parleys):
                world_train.parley()
            world_train.synchronize()

            print('[ validating ]')
            # a fresh validation world is created for every iteration
            world_valid = create_task(opt, agent)
            for _ in world_valid:  # check valid accuracy
                world_valid.parley()

            print('[ validation summary. ]')
            report_valid = world_valid.report()
            print(report_valid)
            if report_valid['accuracy'] > 0.95:
                # good enough, stop training early
                break

        # show some example dialogs after training:
        world_valid = create_task(opt, agent)
        for _k in range(3):
            world_valid.parley()
            print(world_valid.display())

    elapsed = round(time.time() - started, 2)
    print('finished in {} s'.format(elapsed))
コード例 #19
0
ファイル: ir_baseline.py プロジェクト: rikima/ParlAI
class IrBaselineAgent(Agent):
    """Retrieval baseline that ranks label candidates against the last text.

    When asked to act, the agent scores each label candidate of the current
    observation against the observation's text via ``rank_candidates``.
    """

    @staticmethod
    def add_cmdline_args(parser):
        """Register the dictionary arguments plus the length penalty."""
        DictionaryAgent.add_cmdline_args(parser)
        parser.add_argument('-lp',
                            '--length_penalty',
                            type=float,
                            default=0.5,
                            help='length penalty for responses')

    def __init__(self, opt, shared=None):
        """Create the agent and its own DictionaryAgent from ``opt``."""
        super().__init__(opt)
        self.id = 'IRBaselineAgent'
        self.length_penalty = float(opt['length_penalty'])
        self.dictionary = DictionaryAgent(opt)
        self.opt = opt

    def observe(self, obs):
        """Store ``obs`` and feed it to the dictionary.

        :returns: the observation, unchanged.
        """
        self.observation = obs
        self.dictionary.observe(obs)
        return obs

    def act(self):
        """Rank the observed label candidates and reply with the best one.

        :returns: dict with 'id' and 'text'; 'text_candidates' (the full
                  ranking) is included when candidates were provided.
                  Without candidates the text is "I don't know.".
        """
        if self.opt.get('datatype', '').startswith('train'):
            # let the dictionary process its last observation while training
            self.dictionary.act()

        obs = self.observation
        reply = {'id': self.getID()}

        # Rank candidates
        if 'label_candidates' in obs and len(obs['label_candidates']) > 0:
            rep = self.build_query_representation(obs['text'])
            reply['text_candidates'] = rank_candidates(
                rep, obs['label_candidates'], self.length_penalty,
                self.dictionary)
            reply['text'] = reply['text_candidates'][0]
        else:
            reply['text'] = "I don't know."
        return reply

    def save(self, fname=None):
        """Save the dictionary to ``<fname>.dict``.

        :param fname: base path; defaults to ``opt['model_file']``. Nothing
                      is written when no path is available.
        """
        fname = self.opt.get('model_file', None) if fname is None else fname
        if fname:
            self.dictionary.save(fname + '.dict')

    def load(self, fname):
        """Load the dictionary from ``<fname>.dict``."""
        self.dictionary.load(fname + '.dict')

    def build_query_representation(self, query):
        """Build representation of query, e.g. words or n-grams.

        :param query: string to represent.

        :returns: dictionary containing a 'words' dictionary (token => weight)
                  and a 'norm' float (square root of the number of tokens).
        """
        tokens = list(self.dictionary.tokenize(query.lower()))
        # Look the frequency table up once instead of once per token.
        freqs = self.dictionary.freqs()
        weights = {}
        if len(freqs) > 0:
            for tok in tokens:
                weights[tok] = 1.0 / (1.0 + math.log(1.0 + freqs[tok]))
        else:
            # No counts available: fall back to a stopword filter.
            for tok in tokens:
                if tok not in stopwords:
                    weights[tok] = 1
        return {'words': weights, 'norm': math.sqrt(len(tokens))}
コード例 #20
0
class TransresnetAgent(Agent):
    """
    Model described in (https://arxiv.org/abs/1811.00945).

    A model for producing engaging captions about an image. Given an image and
    a personality, this model will attempt to predict an appropriate
    next utterance in the dialog, in the context of the given personality.

    See the paper linked above for more information.
    """
    @staticmethod
    def add_cmdline_args(argparser):
        """
        Add command line args.

        :param argparser:
            parser to which the model, agent and dictionary arguments are added

        :return:
            the 'Transresnet Arguments' argument group
        """
        arg_group = argparser.add_argument_group('Transresnet Arguments')
        TransresnetModel.add_cmdline_args(argparser)
        argparser.add_argument(
            '--freeze-patience',
            type=int,
            default=-1,
            help='How long to freeze text encoders',
        )
        argparser.add_argument(
            '--one-cand-set',
            type='bool',
            default=False,
            help='True if each example has one set of shared '
            'label candidates',
        )
        argparser.add_argument(
            '--fixed-cands-path',
            type=str,
            default=None,
            help='path to text file with candidates',
        )
        argparser.add_argument('--pretrained',
                               type='bool',
                               default=False,
                               help='True if pretrained model')
        DictionaryAgent.add_cmdline_args(argparser)
        return arg_group

    def __init__(self, opt, shared=None):
        """
        Set up the agent, building or sharing the dictionary, model and candidates.

        :param opt:
            options dict
        :param shared:
            attributes produced by ``share()``; when given, the dictionary,
            model, personalities and fixed candidates are reused from it

        :raises RuntimeError:
            if ``numthreads`` is greater than 1
        """
        if opt.get('numthreads', 1) > 1:
            raise RuntimeError('Warning: You cannot use multithreading with '
                               'this agent, as the current metrics do not '
                               'support sharing of lists (for median rank '
                               'calculation). Please set --numthreads to 1')
        self.metrics = {
            'hits@1/100': 0.0,
            'loss': 0.0,
            'num_samples': 0,
            'med_rank': [],
        }
        # zero vector substituted when an observation has no usable features
        self.blank_image_features = torch.FloatTensor(
            opt.get('image_features_dim')).fill_(0)
        self.opt = opt
        self.model_file = opt['model_file']
        self.id = 'TransresnetAgent'
        self.one_cand_set = opt.get('one_cand_set', False)
        self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
        # path of the fixed candidates file, if any
        self.fcp = None
        if opt.get('fixed_cands_path') is not None:
            self.fcp = opt['fixed_cands_path']
        self.episode_done = True

        if not shared:
            # setup dict
            self._setup_dict()
            # load the list of personalities
            self.personalities_list = self.load_personalities()
            # possibly load the model from a model file
            self._build_model()
            # load candidates if specified
            self._setup_cands()
            self.freeze_patience = self.opt['freeze_patience']
            if self.freeze_patience != -1:
                # For fine-tuning
                self.model.freeze_text_encoder()
                self.freeze_impatience = 0
                self.freeze_best_metric = 0
                self.is_frozen = True
        else:
            self.dict = shared['dict']
            self.model = shared['model']
            self.personalities_list = shared['personalities_list']
            self.fixed_cands = shared['fixed_cands']
            self.fixed_cands_enc = shared['fixed_cands_enc']

        super().__init__(opt, shared)

    def share(self):
        """
        Share appropriate attributes.

        :return:
            the parent's shared dict extended with the dictionary, model,
            personalities list and fixed candidates (with their encodings)
        """
        shared = super().share()
        shared['dict'] = self.dict
        shared['model'] = self.model
        shared['personalities_list'] = self.personalities_list
        shared['fixed_cands'] = self.fixed_cands
        shared['fixed_cands_enc'] = self.fixed_cands_enc
        return shared

    def _build_model(self, path=None):
        """
        Create the model and load saved weights when a weights file is found.

        The weights file is, in order of preference: ``opt['init_model']`` if
        it is an existing file, ``opt['model_file']`` if it is an existing
        file, otherwise ``path`` if given.

        :param path:
            fallback path to load weights from
        """
        init_model_path = None
        if self.opt.get('init_model') and os.path.isfile(
                self.opt['init_model']):
            init_model_path = self.opt['init_model']
        elif self.opt.get('model_file') and os.path.isfile(
                self.opt['model_file']):
            init_model_path = self.opt['model_file']
        elif path is not None:
            init_model_path = path
        print('Creating or loading model')
        self.model = TransresnetModel(self.opt, self.personalities_list,
                                      self.dict)
        if init_model_path is not None:
            self.load(init_model_path)
        if self.use_cuda:
            self.model.cuda()

    def _setup_cands(self):
        """
        Load the fixed candidates and their encodings, if a path was given.

        Encodings are read from ``<fixed_cands_path>.cands_enc`` when that
        file exists; otherwise they are computed in batches and saved there.
        """
        self.fixed_cands = None
        self.fixed_cands_enc = None
        if self.fcp is None:
            return
        with open(self.fcp) as f:
            self.fixed_cands = [c.replace('\n', '') for c in f.readlines()]
        cands_enc_file = '{}.cands_enc'.format(self.fcp)
        if os.path.isfile(cands_enc_file):
            print('loading saved cand encodings')
            # NOTE(review): a cache written before the batching fix below may
            # hold fewer rows than there are candidates - regenerate if so.
            self.fixed_cands_enc = torch.load(
                cands_enc_file, map_location=lambda cpu, _: cpu)
            return
        print('Extracting cand encodings')
        self.model.eval()
        pbar = tqdm.tqdm(
            total=len(self.fixed_cands),
            unit='cand',
            unit_scale=True,
            desc='Extracting candidate encodings',
        )
        batch_size = 50
        fixed_cands_enc = []
        # Step through *all* candidates; stopping at len - batch_size would
        # drop the final batch and leave the encodings out of line with
        # self.fixed_cands.
        for start in range(0, len(self.fixed_cands), batch_size):
            batch = self.fixed_cands[start:start + batch_size]
            embedding = self.model(None, None, batch)[1].detach()
            fixed_cands_enc.append(embedding)
            pbar.update(len(batch))
        pbar.close()
        self.fixed_cands_enc = torch.cat(fixed_cands_enc, 0)
        torch.save(self.fixed_cands_enc, cands_enc_file)

    def load_personalities(self):
        """
        Load and return the list of personalities.

        Lines of the personalities file containing 'Trait' are skipped.
        """
        personality_path = os.path.join(
            self.opt['datapath'], 'personality_captions/personalities.txt')
        # `build` is called with a 'yfcc_path' present; a placeholder is set
        # if the caller gave none, and the key is removed again afterwards
        if 'yfcc_path' not in self.opt:
            self.opt['yfcc_path'] = 'temp_path'
        build(self.opt)
        del self.opt['yfcc_path']
        perss = []
        with open(personality_path) as f:
            for line in f:
                if 'Trait' not in line:
                    # drop the last character (the trailing newline)
                    perss.append(line[0:-1])
        return perss

    def observe(self, observation):
        """
        Observe.

        :param observation:
            observation dict, stored for the next call to ``act``

        :return:
            the observation, unchanged
        """
        self.observation = observation
        return observation

    def act(self):
        """
        Act.

        :return:
            the act produced by ``batch_act`` for the stored observation
        """
        return self.batch_act([self.observation])[0]

    def train_step(self, valid_obs, image_feats, personalities):
        """
        Model train step.

        :param valid_obs:
            list of valid observations

        :param image_feats:
            list of image features, one per example

        :param personalities:
            list of personalities, one per example

        :return:
            the total loss, number of correct examples, and total number of
            examples evaluated
        """
        # one label is sampled per example as the target comment
        comments = [random.choice(v['labels']) for v in valid_obs]
        loss, num_correct, num_examples = self.model.train_batch(
            image_feats, personalities, comments)
        return loss, num_correct, num_examples

    def eval_step(self, valid_obs, image_feats, personalities):
        """
        Model eval step.

        :param valid_obs:
            list of valid observations

        :param image_feats:
            list of image features, one per example

        :param personalities:
            list of personalities, one per example

        :return:
            the total loss, number of correct examples, the total number of
            examples evaluated, the ranked position of each correct caption,
            and the ranked lists of candidates (one per example)
        """
        med_rank = None
        chosen_captions = None
        if 'label_candidates' in valid_obs[0] or self.fixed_cands is not None:
            # User provides candidates, used as negatives for evaluation
            candidates_encoded = None
            if self.fixed_cands is not None:
                candidates_encoded = self.fixed_cands_enc
                candidates = self.fixed_cands
            else:
                candidates = [v['label_candidates'] for v in valid_obs]
                if self.one_cand_set:
                    # all examples share one candidate set: encode it once
                    candidates_encoded = self.model(None, None,
                                                    candidates[0])[1].detach()
            chosen_captions = self.model.choose_best_caption(
                image_feats,
                personalities,
                candidates,
                candidates_encoded=candidates_encoded,
                k=-1 if self.fixed_cands is None else 100,
            )
            # calculate median ranks
            num_examples = len(chosen_captions)
            loss = -1
            if self.fixed_cands is not None:
                # fixed candidates are not scored against the labels here
                num_correct = 0
            else:
                comments = [v['eval_labels'] for v in valid_obs]
                med_rank = []
                for i, c_list in enumerate(chosen_captions):
                    # best (lowest) rank reached by any of the gold labels
                    lowest_rank = len(c_list) + 1
                    for c in comments[i]:
                        lowest_rank = min(lowest_rank, c_list.index(c) + 1)
                    med_rank.append(lowest_rank)
                # An example is correct when the top-ranked caption is one of
                # its gold labels. (Testing membership in the ranked list
                # itself is always true.)
                num_correct = sum(
                    1 if ranked[0] in comments[i] else 0
                    for i, ranked in enumerate(chosen_captions)
                )
        else:
            comments = [random.choice(v['eval_labels']) for v in valid_obs]
            loss, num_correct, num_examples = self.model.eval_batch(
                image_feats, personalities, comments)

        return loss, num_correct, num_examples, med_rank, chosen_captions

    def batch_act(self, observations):
        """
        Act on a batch of observations.

        :param observations:
            list of observations

        :return:
            A list of acts, one for each observation
        """
        # the batch is treated as training if any observation has 'labels'
        is_training = any('labels' in obs for obs in observations)
        valid_obs, valid_indexes = self.filter_valid_obs(
            observations, is_training)
        image_feats = self.extract_image_feats(valid_obs)
        # the personality is carried in the observation's 'text' field
        personalities = [v.get('text', '') for v in valid_obs]

        chosen_captions = None
        med_rank = None
        if is_training:
            loss, num_correct, num_examples = self.train_step(
                valid_obs, image_feats, personalities)
        else:
            loss, num_correct, num_examples, med_rank, chosen_captions = self.eval_step(
                valid_obs, image_feats, personalities)

        self.update_metrics(loss, num_correct, num_examples, med_rank)
        result = [{
            'text': 'No Response During Training'
        } for _ in range(len(observations))]
        if chosen_captions is not None:
            # write predictions back at the positions of the valid observations
            for i, index_obs in enumerate(valid_indexes):
                result[index_obs]['text'] = chosen_captions[i][0]
                result[index_obs]['text_candidates'] = chosen_captions[i]
        return result

    def extract_image_feats(self, obs):
        """
        Extract image features from the observations.

        :param obs:
            list of observations

        :return:
            list of image features
        """
        tmp_image_feats = [v.get('image') for v in obs]
        for i, im in enumerate(tmp_image_feats):
            try:
                # Check if given img features of form [1, <dim>, 1, 1]
                if len(im.size()) == 4:
                    tmp_image_feats[i] = im[0, :, 0, 0]
            except (TypeError, AttributeError):  # No Image Feats Given
                # TypeError: `size` exists but is not callable;
                # AttributeError: no `size` at all (e.g. the image is None).
                tmp_image_feats[i] = self.blank_image_features
        return [img.detach() for img in tmp_image_feats]

    def filter_valid_obs(self, observations, is_training):
        """
        Filter out invalid observations.

        An observation is kept only if it has an 'image' entry. Without fixed
        candidates, observations whose first label repeats one already seen
        in this batch are dropped as well.

        :param observations:
            list of observations
        :param is_training:
            whether labels are read from 'labels' (True) or 'eval_labels'

        :return:
            the kept observations and their indexes in ``observations``
        """
        label_key = 'labels' if is_training else 'eval_labels'
        valid_obs = []
        valid_indexes = []
        seen_texts = set()
        for i, observation in enumerate(observations):
            if 'image' not in observation:
                continue
            if self.fixed_cands is None:
                text = observation[label_key][0]
                if text in seen_texts:
                    continue
                seen_texts.add(text)
            valid_obs.append(observation)
            valid_indexes.append(i)
        return valid_obs, valid_indexes

    def update_metrics(self, loss, num_correct, num_samples, med_rank=None):
        """
        Update Metrics.

        :param loss:
            float loss
        :param num_correct:
            number of examples for which chosen caption is correct
        :param num_samples:
            total number of examples
        :param med_rank:
            rank of correct caption for each example
        """
        self.metrics['hits@1/100'] += num_correct
        self.metrics['loss'] += loss
        self.metrics['num_samples'] += num_samples
        if med_rank:
            self.metrics['med_rank'] += med_rank

    def _setup_dict(self):
        """
        Set up the dictionary.

        The pretrained model used a separate dictionary from the standard ParlAI one.
        """
        self.dict = DictionaryAgent(self.opt)
        if self.opt.get('pretrained', False):
            # shift every index down by 4, dropping tokens whose index is < 4
            new_tok2ind = {}
            new_ind2tok = {}
            for key in self.dict.tok2ind:
                val = self.dict.tok2ind[key]
                if val - 4 >= 0:
                    new_tok2ind[key] = val - 4
                    new_ind2tok[val - 4] = key
            self.dict.null_token = '<PAD>'
            self.dict.unk_token = '<UNK>'
            self.dict.tok2ind = new_tok2ind
            self.dict.ind2tok = new_ind2tok

    def receive_metrics(self, metrics_dict):
        """
        Receive the metrics from validation.

        Unfreeze text encoder weights after a certain number of rounds without improvement.

        :param metrics_dict:
            the metrics dictionary
        """
        if 'tasks' in metrics_dict:
            metrics_dict = metrics_dict['tasks']['personality_captions']
        if self.freeze_patience != -1 and self.is_frozen:
            m = metrics_dict['hits@1/100']
            if m > self.freeze_best_metric:
                # still improving: stay frozen and reset the counter
                self.freeze_impatience = 0
                self.freeze_best_metric = m
                print('performance not good enough to unfreeze the model.')
            else:
                self.freeze_impatience += 1
                print('Growing impatience for unfreezing')
                if self.freeze_impatience >= self.freeze_patience:
                    self.is_frozen = False
                    print('Reached impatience for fine tuning. '
                          'Reloading the best model so far.')
                    # NOTE(review): _build_model prefers opt['init_model']
                    # over this path when that file exists - confirm that is
                    # the intended checkpoint to reload.
                    self._build_model(self.model_file)
                    if self.use_cuda:
                        self.model = self.model.cuda()
                    print('Unfreezing.')
                    self.model.unfreeze_text_encoder()
                    print('Done')

    def reset(self):
        """
        Reset metrics.
        """
        super().reset()
        self.reset_metrics()

    def reset_metrics(self):
        """
        Reset the metrics.
        """
        self.metrics['hits@1/100'] = 0.0
        self.metrics['loss'] = 0.0
        self.metrics['num_samples'] = 0.0
        if 'med_rank' in self.metrics:
            self.metrics['med_rank'] = []

    def report(self):
        """
        Report the current metrics.

        :return:
            a metrics dict; empty when no samples have been recorded
        """
        m = {}
        if self.metrics['num_samples'] > 0:
            m['hits@1/100'] = round_sigfigs(
                self.metrics['hits@1/100'] / self.metrics['num_samples'], 4)
            m['loss'] = round_sigfigs(
                self.metrics['loss'] / self.metrics['num_samples'], 4)
            if 'med_rank' in self.metrics:
                m['med_rank'] = np.median(self.metrics['med_rank'])
        return m

    def save(self, path=None):
        """
        Save the model.

        Writes the dictionary to ``<path>.dict``, the model state to ``path``
        and the options as JSON to ``<path>.opt``.

        :param path:
            path for saving model; defaults to ``opt['model_file']``
        """
        path = self.opt.get('model_file', None) if path is None else path
        self.dict.save(path + '.dict', sort=False)
        print('Saving best model')
        states = {}
        states['model'] = self.model.state_dict()
        torch.save(states, path)

        with open(path + '.opt', 'w') as handle:
            json.dump(self.opt, handle)
            handle.write('\n')

    def load(self, path):
        """
        Load a model.

        :param path:
            path from which to load model
        """
        # map_location keeps the loaded tensors on the CPU
        states = torch.load(path, map_location=lambda cpu, _: cpu)
        if 'model' in states:
            self.model.load_state_dict(states['model'])
コード例 #21
0
def main():
    """Build a dictionary if needed, then train and validate a remote agent.

    Default '--remote-cmd' / '--remote-args' values pointing at the luajit
    memnn scripts under $PARLAI_HOME are appended to ``sys.argv`` when the
    caller did not pass them.
    """
    # Get command line arguments
    parser = ParlaiParser()
    DictionaryAgent.add_cmdline_args(parser)
    ParsedRemoteAgent.add_cmdline_args(parser)
    parser.add_argument('--num-examples', default=1000, type=int)
    parser.add_argument('--num-its', default=100, type=int)
    parser.add_argument('--dict-max-exs', default=10000, type=int)
    home = os.environ['PARLAI_HOME']
    if '--remote-cmd' not in sys.argv:
        luajit_missing = os.system('which luajit') != 0
        if luajit_missing:
            raise RuntimeError('Could not detect torch luajit installed: ' +
                               'please install torch from http://torch.ch ' +
                               'or manually set --remote-cmd for this example.')
        remote_cmd = ('luajit {}/parlai/agents/'.format(home) +
                      'memnn_luatorch_cpu/memnn_zmq_parsed.lua')
        sys.argv += ['--remote-cmd', remote_cmd]
    if '--remote-args' not in sys.argv:
        remote_args = ('{}/examples/'.format(home) +
                       'memnn_luatorch_cpu/params_default.lua')
        sys.argv += ['--remote-args', remote_args]

    opt = parser.parse_args()

    # set up dictionary
    print('Setting up dictionary.')
    dictionary = DictionaryAgent(opt)
    if not opt.get('dict_file'):
        # no dictionary file was given, so build one from the training set
        build_opt = copy.deepcopy(opt)
        build_opt['datatype'] = 'train:ordered'
        build_opt['numthreads'] = 1
        world_dict = create_task(build_opt, dictionary)

        print('Dictionary building on training data.')
        # pass examples to dictionary, capped at --dict-max-exs if > 0
        for seen, _ in enumerate(world_dict, start=1):
            limit = opt['dict_max_exs']
            if seen > limit and limit > 0:
                print('Processed {} exs, moving on.'.format(limit))
                # don't wait too long...
                break

            world_dict.parley()

        # we need to save the dictionary to load it in memnn (sort it by freq)
        dictionary.sort()
        dictionary.save('/tmp/dict.txt', sort=True)

    print('Dictionary ready, moving on to training.')

    # the same agent is shared by the training and the validation world
    opt['datatype'] = 'train'
    agent = ParsedRemoteAgent(opt, {'dictionary_shared': dictionary.share()})
    world_train = create_task(opt, agent)
    opt['datatype'] = 'valid'
    world_valid = create_task(opt, agent)

    started = time.time()
    with world_train:
        for _ in range(opt['num_its']):
            print('[ training ]')
            parleys = opt['num_examples'] * opt.get('numthreads', 1)
            for _ in range(parleys):
                world_train.parley()
            world_train.synchronize()

            print('[ validating ]')
            world_valid.reset()
            for _ in world_valid:  # check valid accuracy
                world_valid.parley()

            print('[ validation summary. ]')
            report_valid = world_valid.report()
            print(report_valid)
            if report_valid['accuracy'] > 0.95:
                # good enough, stop training early
                break

        # show some example dialogs after training:
        world_valid = create_task(opt, agent)
        for _k in range(3):
            world_valid.parley()
            print(world_valid.display())

    elapsed = round(time.time() - started, 2)
    print('finished in {} s'.format(elapsed))
コード例 #22
0
ファイル: ir_baseline.py プロジェクト: ahiroto/ParlAI
class IrBaselineAgent(Agent):
    """Retrieval baseline that replies by ranking the label candidates.

    Every observation is forwarded to an internal ``DictionaryAgent``. When
    the observation carries ``label_candidates``, the reply text is the first
    entry of the ranking returned by ``rank_candidates``; otherwise a fixed
    fallback string is returned.
    """

    @staticmethod
    def add_cmdline_args(parser):
        """Register the dictionary arguments plus ``--length_penalty``."""
        DictionaryAgent.add_cmdline_args(parser)
        parser.add_argument(
            '-lp', '--length_penalty', default=0.5,
            help='length penalty for responses')

    def __init__(self, opt, shared=None):
        """Create the agent from ``opt``.

        ``opt`` must contain ``length_penalty`` (anything ``float()``
        accepts). ``shared`` is accepted for interface compatibility but is
        not used by this constructor.
        """
        super().__init__(opt)
        self.id = 'IRBaselineAgent'
        # The option may arrive as a string, so coerce it.
        self.length_penalty = float(opt['length_penalty'])
        self.dictionary = DictionaryAgent(opt)
        self.opt = opt

    def observe(self, obs):
        """Remember ``obs``, pass it on to the dictionary, and return it."""
        self.observation = obs
        self.dictionary.observe(obs)
        return obs

    def act(self):
        """Return a reply dict for the most recent observation.

        The reply always has ``id`` and ``text``. When the observation has a
        non-empty ``label_candidates`` list, it also has ``text_candidates``
        (the ranked candidates) and ``text`` is the first of them.
        """
        # Only let the dictionary act while the datatype is a training one.
        # presumably this is what updates the dictionary counts -- confirm
        # against DictionaryAgent.act.
        if self.opt.get('datatype', '').startswith('train'):
            self.dictionary.act()

        obs = self.observation
        reply = {}
        reply['id'] = self.getID()

        # Rank candidates
        if 'label_candidates' in obs and len(obs['label_candidates']) > 0:
            rep = self.build_query_representation(obs['text'])
            reply['text_candidates'] = (
                rank_candidates(rep, obs['label_candidates'],
                                self.length_penalty, self.dictionary))
            reply['text'] = reply['text_candidates'][0]
        else:
            reply['text'] = "I don't know."
        return reply

    def save(self, fname=None):
        """Save the dictionary to ``fname + '.dict'``.

        When ``fname`` is None the ``model_file`` option is used; if that is
        missing or empty as well, nothing is written.
        """
        fname = self.opt.get('model_file', None) if fname is None else fname
        if fname:
            self.dictionary.save(fname + '.dict')

    def load(self, fname):
        """Load the dictionary from ``fname + '.dict'``."""
        self.dictionary.load(fname + '.dict')

    def build_query_representation(self, query):
        """Build a bag-of-words representation of ``query``.

        Returns a dict with two entries:

        * ``'words'``: maps each token of the lower-cased query to a weight.
          If the dictionary's frequency table is non-empty the weight is
          ``1 / (1 + log(1 + freq))``; otherwise every token that is not in
          ``stopwords`` gets weight ``1``.
        * ``'norm'``: square root of the number of tokens (repeated tokens
          are counted each time).
        """
        words = list(self.dictionary.tokenize(query.lower()))
        # Look the frequency table up once rather than on every token.
        freqs = self.dictionary.freqs()
        weights = {}
        for w in words:
            if len(freqs) > 0:
                weights[w] = 1.0 / (1.0 + math.log(1.0 + freqs[w]))
            elif w not in stopwords:
                weights[w] = 1
        return {'words': weights, 'norm': math.sqrt(len(words))}
コード例 #23
0
import pickle
import os
from parlai.core.dict import DictionaryAgent

# Converts a pickled vocabulary into a ParlAI dictionary file.
# NOTE(review): `path` ends in ".tar" but is joined like a directory below --
# confirm that this is an extracted folder and not the archive itself.
path = 'dat/MovieTriples_Dataset.tar'
pickle_path = os.path.join(path, 'Training.dict.pkl')

# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# run this against a file from a trusted source.
with open(pickle_path, 'rb') as data_file:
    dictionary = pickle.load(data_file)

parlai_dict = DictionaryAgent({'vocab_size': 10004})

# Order the entries by their second field and print the first ten as a quick
# sanity check.
dictionary = list(dictionary)
dictionary.sort(key=lambda entry: entry[1])
print(dictionary[:10])

# Field 0 of each entry is the token and field 2 is stored as its frequency.
for word in dictionary:
    token = word[0]
    parlai_dict.add_to_dict([token])
    parlai_dict.freq[token] = word[2]

parlai_dict.save('test_hred.dict', sort=True)
コード例 #24
0
ファイル: build_dict.py プロジェクト: ehosseiniasl/Enigma
def build_dict(opt, skip_if_built=False):
    """Build the dictionary described by ``opt`` and save it to disk.

    Args:
        opt: option mapping. Passing a ``ParlaiParser`` is deprecated but
            still accepted; it is parsed into options first. Keys read here
            include ``dict_file``, ``dict_class``, ``task``,
            ``pytorch_teacher_task``, ``dict_include_valid``,
            ``dict_include_test``, ``dict_maxexs`` and ``log_every_n_secs``.
        skip_if_built: when True and ``opt['dict_file']`` already exists,
            return immediately without constructing a dictionary agent.

    Returns:
        ``None`` when ``dict_file`` is not set, or when ``skip_if_built`` is
        True and the file already exists; otherwise the dictionary agent.
    """
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: should be passed opt not Parser ]')
        opt = opt.parse_args()
    if not opt.get('dict_file'):
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return None

    if skip_if_built and os.path.isfile(opt['dict_file']):
        # Dictionary already built, skip all loading or setup
        print("[ dictionary already built .]")
        return None

    if opt.get('dict_class'):
        # Custom dictionary class
        dictionary = str2class(opt['dict_class'])(opt)
    else:
        # Default dictionary class
        dictionary = DictionaryAgent(opt)

    if os.path.isfile(opt['dict_file']):
        # Dictionary file already exists: hand back the agent as constructed.
        # presumably the agent loaded the file during construction -- confirm.
        print("[ dictionary already built .]")
        return dictionary

    # Work on a copy so the caller's options are left untouched.
    ordered_opt = copy.deepcopy(opt)
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    ordered_opt['image_mode'] = 'none'
    if ordered_opt['task'] == 'pytorch_teacher':
        # Run over the underlying task instead of the pytorch teacher.
        pytorch_teacher_task = ordered_opt.get('pytorch_teacher_task', '')
        if pytorch_teacher_task != '':
            ordered_opt['task'] = pytorch_teacher_task

    # The train set is always used; valid/test are opt-in.
    datatypes = ['train:ordered:stream']
    if opt.get('dict_include_valid'):
        datatypes.append('valid:stream')
    if opt.get('dict_include_test'):
        datatypes.append('test:stream')

    # A non-positive value means "no limit on the number of examples".
    max_exs = opt.get('dict_maxexs', 0)
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')

    # One logger for the whole run, so the reported time covers every
    # datatype rather than only the last one processed.
    log_time = TimeLogger()
    cnt = 0
    reached_max = False
    for dt in datatypes:
        ordered_opt['datatype'] = dt
        world_dict = create_task(ordered_opt, dictionary)
        # pass examples to dictionary
        print('[ running dictionary over data.. ]')
        while not world_dict.epoch_done():
            cnt += 1
            if max_exs > 0 and cnt > max_exs:
                print('Processed {} exs, moving on.'.format(max_exs))
                # don't wait too long...
                reached_max = True
                break
            world_dict.parley()
            if log_time.time() > log_every_n_secs:
                sys.stdout.write('\r')
                # NOTE(review): the total passed here is the larger of the
                # cap and the dataset size -- confirm that is intended.
                text, _log = log_time.log(
                    cnt, max(max_exs, world_dict.num_examples()))
                sys.stdout.write(text)
                sys.stdout.flush()
        if reached_max:
            # The cap counts examples across all datatypes, so nothing more
            # would be processed; skip building the remaining worlds.
            break

    dictionary.save(opt['dict_file'], sort=True)
    print('[ dictionary built with {} tokens in {}s ]'.format(
        len(dictionary), round(log_time.total_time(), 2)))
    return dictionary