예제 #1
0
def build_dict(opt):
    """Build a word dictionary from the ordered training split.

    The caller's ``opt`` is not modified: a deep copy is taken, ``batchsize``
    is forced to 1 and ``datatype`` to ``'train:ordered'``, and that copy is
    used to create a ``SimpleDictionaryAgent`` and a task world driven with it.

    Args:
        opt: option dict produced by the command-line parser.

    Returns:
        The ``SimpleDictionaryAgent``, after ``sort()`` has been called on it.
    """
    local_opt = copy.deepcopy(opt)
    local_opt['batchsize'] = 1
    word_dict = SimpleDictionaryAgent(local_opt)

    # Only the training split is used to populate the dictionary.
    logger.info('[ Building dictionary... ]')
    local_opt['datatype'] = 'train:ordered'
    task_world = create_task(local_opt, word_dict)
    # Iterating the world determines how many parley() steps are executed.
    for _ in task_world:
        task_world.parley()

    word_dict.sort()
    logger.info('[ Dictionary built. ]')
    num_words = len(word_dict)
    logger.info('[ Num words = %d ]' % num_words)
    return word_dict
예제 #2
0
def build_dict(opt):
    """Build a word dictionary from the ordered training split.

    A deep copy of ``opt`` (``batchsize`` forced to 1, ``datatype`` forced to
    ``'train:ordered'``) is used to create a ``SimpleDictionaryAgent`` and a
    task world driven with it. Afterwards the dictionary is reduced with
    ``sort_and_keep(vocab_size)`` when ``vocab_size > 0``, otherwise it is
    sorted in full.

    Args:
        opt: option dict. ``opt['vocab_size']`` is optional; a missing or
            ``None`` value is treated as 0, i.e. no size limit.

    Returns:
        The populated ``SimpleDictionaryAgent``.
    """
    opt = copy.deepcopy(opt)
    opt['batchsize'] = 1
    dictionary = SimpleDictionaryAgent(opt)

    # We use the train set to build the dictionary.
    logger.info('[ Building word dictionary... ]')
    opt['datatype'] = 'train:ordered'
    world = create_task(opt, dictionary)
    for _ in world:
        world.parley()

    # Treat a missing or None vocab_size as 0 ("no limit") instead of raising
    # KeyError / TypeError on the comparison below.
    vocab_size = opt.get('vocab_size') or 0
    if vocab_size > 0:
        dictionary.sort_and_keep(vocab_size)
    else:
        dictionary.sort()
        # NOTE(review): ``opt`` is the deep copy made at the top of this
        # function, so this value never reaches the caller's options. It is
        # only visible to objects holding a reference to this copy (e.g. the
        # dictionary or world, if they store it without copying) — confirm
        # whether the caller was meant to receive it.
        opt['vocab_size'] = len(dictionary) + 1

    logger.info('[ Dictionary built (full size). ]')
    # Lazy %-args: formatting is deferred to the logging framework.
    logger.info('[ Num words = %d ]', len(dictionary))
    return dictionary
예제 #3
0
파일: train.py 프로젝트: zgsxwsdxg/ParlAI
        iteration += 1


if __name__ == '__main__':
    # Get command line arguments
    argparser = ParlaiParser()
    argparser.add_arg(
        '--train_interval', type=int, default=1000,
        help='Validate after every N train updates',
    )
    argparser.add_arg(
        '--patience', type=int, default=10,
        help='Number of intervals to continue without improvement'
    )
    SimpleDictionaryAgent.add_cmdline_args(argparser)
    DocReaderAgent.add_cmdline_args(argparser)
    opt = argparser.parse_args()

    # Set logging: the module-level ``logger`` (also used by build_dict) writes
    # timestamped messages to the console and, optionally, to a log file.
    logger = logging.getLogger('DrQA')
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter('%(asctime)s: %(message)s', '%m/%d/%Y %I:%M:%S %p')
    console = logging.StreamHandler()
    console.setFormatter(fmt)
    logger.addHandler(console)
    # The parsed options may contain 'log_file' with a None value when the
    # argument was not given; logging.FileHandler(None) raises TypeError, so
    # only attach a file handler when an actual path is present.
    if opt.get('log_file'):
        logfile = logging.FileHandler(opt['log_file'], 'w')
        logfile.setFormatter(fmt)
        logger.addHandler(logfile)
    logger.info('[ COMMAND: %s ]', ' '.join(sys.argv))