# Example #1
# 0
def build_dict(opt, skip_if_built=False):
    """Build (or load) the dictionary for the task described by ``opt``.

    :param opt: options dict (passing a ``ParlaiParser`` is deprecated and
        will be parsed automatically).
    :param skip_if_built: if True, return ``None`` immediately when the
        dict file already exists instead of loading it.
    :return: the built/loaded dictionary agent, or ``None`` when skipped
        or when ``--dict-file`` is unset.
    :raises ValueError: if called while distributed training is active.
    """
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: should be passed opt not Parser ]')
        opt = opt.parse_args()
    if not opt.get('dict_file'):
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return
    if skip_if_built and os.path.isfile(opt['dict_file']):
        # Dictionary already built, skip all loading or setup
        print("[ dictionary already built .]")
        return None

    if is_distributed():
        raise ValueError(
            'Dictionaries should be pre-built before distributed train.')

    if opt.get('dict_class'):
        # Custom dictionary class
        dictionary = str2class(opt['dict_class'])(opt)
    else:
        # Default dictionary class
        dictionary = DictionaryAgent(opt)

    if os.path.isfile(opt['dict_file']):
        # Dictionary already built, return loaded dictionary agent
        print("[ dictionary already built .]")
        return dictionary

    # We use the train set to build the dictionary; force a single-threaded,
    # unbatched pass so examples are seen exactly once and in order.
    ordered_opt = copy.deepcopy(opt)
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    # Disable image features so the Teacher does not compute them while
    # we are only collecting text tokens
    ordered_opt['image_mode'] = 'no_image_model'
    ordered_opt['pytorch_teacher_batch_sort'] = False
    if ordered_opt['task'] == 'pytorch_teacher' or not ordered_opt['task']:
        pytorch_teacher_task = ordered_opt.get('pytorch_teacher_task', '')
        if pytorch_teacher_task != '':
            ordered_opt['task'] = pytorch_teacher_task

    datatypes = ['train:ordered:stream']
    if opt.get('dict_include_valid'):
        datatypes.append('valid:stream')
    if opt.get('dict_include_test'):
        datatypes.append('test:stream')

    # cnt counts examples across ALL datatypes, so dict_maxexs is a global
    # cap on the whole build (was assigned twice before; the first was dead).
    cnt = 0
    # Start the timer before the loop so the final report covers the entire
    # build — previously it was re-created per datatype, so the printed time
    # only measured the last datatype.
    log_time = TimeLogger()
    for dt in datatypes:
        ordered_opt['datatype'] = dt
        world_dict = create_task(ordered_opt, dictionary)
        # pass examples to dictionary
        print('[ running dictionary over data.. ]')
        total = world_dict.num_examples()
        if opt['dict_maxexs'] >= 0:
            total = min(total, opt['dict_maxexs'])

        # A progress bar is only shown when periodic logging is enabled
        log_every_n_secs = opt.get('log_every_n_secs', None)
        if log_every_n_secs:
            pbar = tqdm.tqdm(total=total,
                             desc='Building dictionary',
                             unit='ex',
                             unit_scale=True)
        else:
            pbar = None
        while not world_dict.epoch_done():
            cnt += 1
            if cnt > opt['dict_maxexs'] and opt['dict_maxexs'] >= 0:
                print('Processed {} exs, moving on.'.format(
                    opt['dict_maxexs']))
                # don't wait too long...
                break
            world_dict.parley()
            if pbar:
                pbar.update(1)
        if pbar:
            pbar.close()

    dictionary.save(opt['dict_file'], sort=True)
    print('[ dictionary built with {} tokens in {}s ]'.format(
        len(dictionary), round(log_time.total_time(), 2)))
    return dictionary
# Example #2
# 0
def build_dict(cfg):
    """Build (or load) the dictionary as configured by a Hydra ``cfg``.

    :param cfg: Hydra/OmegaConf config with ``cfg.dict`` (instantiable
        dictionary agent, with ``cfg.dict.file`` as its save path) and
        ``cfg.build_dict`` (``skip_if_build``, ``include_valid``,
        ``include_test``, ``maxexs``).
    :return: the built/loaded dictionary agent, or ``None`` when skipped.
    :raises ValueError: if called while distributed training is active.
    """
    if cfg.build_dict.skip_if_build and os.path.isfile(cfg.dict.file):
        # Dictionary already built, skip all loading or setup
        log.info("dictionary already built")
        return None

    if is_distributed():
        raise ValueError(
            'Dictionaries should be pre-built before distributed train.')

    dictionary = hydra.utils.instantiate(cfg.dict)

    if os.path.isfile(cfg.dict.file):
        # Dictionary already built, return loaded dictionary agent
        log.info("dictionary already built")
        return dictionary

    # TODO: override
    # ordered_opt['numthreads'] = 1
    # ordered_opt['batchsize'] = 1

    # Set this to none so that image features are not calculated when Teacher is
    # instantiated while building the dict
    # TODO: change 'none' to 'no_image_model'
    # ordered_opt['image_mode'] = 'none'

    # ordered_opt['pytorch_teacher_batch_sort'] = False
    # TODO: how to check if task if ??? (not set)
    # if cfg.teacher.task['task'] == 'pytorch_teacher' or cfg.teacher.task is None:
    #     pytorch_teacher_task = ordered_opt.get('pytorch_teacher_task', '')
    #     if pytorch_teacher_task != '':
    #         ordered_opt['task'] = pytorch_teacher_task

    datatypes = ['train:ordered:stream']
    if cfg.build_dict.include_valid:
        datatypes.append('valid:stream')
    if cfg.build_dict.include_test:
        datatypes.append('test:stream')

    # cnt counts examples across ALL datatypes, so maxexs is a global cap.
    cnt = 0
    # Timer hoisted out of the loop so the final report covers the whole
    # build, not just the last datatype.
    log_time = TimeLogger()
    for dt in datatypes:
        # NOTE(review): `dt` is never written into cfg before create_task,
        # so every iteration appears to run the same datatype — confirm
        # whether create_task reads the datatype from cfg elsewhere.
        world_dict = create_task(cfg, dictionary)
        # pass examples to dictionary
        log.info('running dictionary over data..')
        total = world_dict.num_examples()
        # Bug fix: this branch referenced an undefined `opt` (copy-paste
        # from the opt-based variant) and raised NameError at runtime.
        maxexs = cfg.build_dict.maxexs
        if maxexs >= 0:
            total = min(total, maxexs)

        # OmegaConf DictConfig supports .get() with a default, like a dict;
        # assumes log_every_n_secs lives under build_dict — TODO confirm.
        log_every_n_secs = cfg.build_dict.get('log_every_n_secs', None)
        if log_every_n_secs:
            pbar = tqdm.tqdm(total=total,
                             desc='Building dictionary',
                             unit='ex',
                             unit_scale=True)
        else:
            pbar = None
        while not world_dict.epoch_done():
            cnt += 1
            if cnt > maxexs and maxexs >= 0:
                log.info('Processed {} exs, moving on.'.format(maxexs))
                # don't wait too long...
                break
            world_dict.parley()
            if pbar:
                pbar.update(1)
        if pbar:
            pbar.close()

    dictionary.save(cfg.dict.file, sort=True)
    log.info('dictionary built with {} tokens in {}s'.format(
        len(dictionary), round(log_time.total_time(), 2)))
    return dictionary