コード例 #1
0
def train_model(params, load_dataset=None):
    """
    Training function. Sets the training parameters from params. Builds or loads the model and launches the training.

    When params['RELOAD'] > 0 the training is resumed: the model is passed through updateModel and, unless
    load_dataset is given, the dataset is either rebuilt (params['REBUILD_DATASET']) or reloaded from
    params['DATASET_STORE_PATH'] and updated with the files listed in params['TEXT_FILES'].

    params is modified in place: 'INPUT_VOCABULARY_SIZE' and 'OUTPUT_VOCABULARY_SIZE' are always set and
    'EPOCH_OFFSET' may be set when resuming. The configuration is stored at params['STORE_PATH'] + '/config'.

    :param params: Dictionary of network hyperparameters.
    :param load_dataset: Stored dataset to load. If None, the dataset is built (or updated) from the parameters.
    :return: None
    """
    check_params(params)

    resuming = params['RELOAD'] > 0
    if resuming:
        logging.info('Resuming training.')

    # Load data
    if load_dataset is not None:
        if resuming:
            logging.info('Reloading and using dataset.')
        dataset = loadDataset(load_dataset)
    elif resuming and not params['REBUILD_DATASET']:
        logging.info('Updating dataset.')
        dataset_name = params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN']
        dataset = loadDataset(params['DATASET_STORE_PATH'] + '/Dataset_' + dataset_name + '.pkl')
        # The offset is computed before the dataset is updated, i.e. with the stored training set length.
        params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
            int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)
        # items() is available on both Python 2 and Python 3 dictionaries (iteritems() is Python 2 only).
        for split, filename in params['TEXT_FILES'].items():
            dataset = update_dataset_from_file(
                dataset,
                params['DATA_ROOT_PATH'] + '/' + filename + params['SRC_LAN'],
                params,
                splits=[split],
                output_text_filename=params['DATA_ROOT_PATH'] + '/' + filename + params['TRG_LAN'],
                remove_outputs=False,
                compute_state_below=True,
                recompute_references=True)
            dataset.name = dataset_name
        saveDataset(dataset, params['DATASET_STORE_PATH'])
    else:
        if resuming:
            logging.info('Rebuilding dataset.')
        dataset = build_dataset(params)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]

    # Build model. The optimizer is set and the directories are cleared only when RELOAD is exactly 0.
    set_optimizer = params['RELOAD'] == 0
    clear_dirs = params['RELOAD'] == 0

    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 set_optimizer=set_optimizer,
                                 clear_dirs=clear_dirs)

    # Define the inputs and outputs mapping from our Dataset instance to our model
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        pos_source = dataset.ids_inputs.index(id_in)
        id_dest = nmt_model.ids_inputs[i]
        inputMapping[id_dest] = pos_source
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        pos_target = dataset.ids_outputs.index(id_out)
        id_dest = nmt_model.ids_outputs[i]
        outputMapping[id_dest] = pos_target
    nmt_model.setOutputsMapping(outputMapping)

    if resuming:
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        # Only computed here when no offset is present yet (e.g. the dataset was not updated above).
        if params.get('EPOCH_OFFSET') is None:
            params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
                int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Store configuration as pkl
    dict2pkl(params, params['STORE_PATH'] + '/config')

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        # LR decay parameters
        'lr_decay': params.get('LR_DECAY', None),
        'reduce_each_epochs': params.get('LR_REDUCE_EACH_EPOCHS', True),
        'start_reduction_on_epoch': params.get('LR_START_REDUCTION_ON_EPOCH', 0),
        'lr_gamma': params.get('LR_GAMMA', 0.9),
        'lr_reducer_type': params.get('LR_REDUCER_TYPE', 'linear'),
        'lr_reducer_exp_base': params.get('LR_REDUCER_EXP_BASE', 0),
        'lr_half_life': params.get('LR_HALF_LIFE', 50000),
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        # Early stopping parameters
        'patience': params.get('PATIENCE', 0),
        'metric_check': params.get('STOP_METRIC', None) if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0),
        'tensorboard': params.get('TENSORBOARD', False),
        'tensorboard_params': {
            'log_dir': params.get('LOG_DIR', 'tensorboard_logs'),
            'histogram_freq': params.get('HISTOGRAM_FREQ', 0),
            'batch_size': params.get('TENSORBOARD_BATCH_SIZE', params['BATCH_SIZE']),
            'write_graph': params.get('WRITE_GRAPH', True),
            'write_grads': params.get('WRITE_GRADS', False),
            'write_images': params.get('WRITE_IMAGES', False),
            'embeddings_freq': params.get('EMBEDDINGS_FREQ', 0),
            'embeddings_layer_names': params.get('EMBEDDINGS_LAYER_NAMES', None),
            'embeddings_metadata': params.get('EMBEDDINGS_METADATA', None),
            'label_word_embeddings_with_vocab': params.get('LABEL_WORD_EMBEDDINGS_WITH_VOCAB', False),
            'word_embeddings_labels': params.get('WORD_EMBEDDINGS_LABELS', None),
        }
    }
    nmt_model.trainNet(dataset, training_params)

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(
        time_difference, time_difference / 60.0))
コード例 #2
0
def main():
    """
    Launch the NMT HTTP server.

    Parses the command-line arguments, assembles the configuration (load_parameters, optionally updated with a
    stored config, the online-learning parameters and the key=value overrides in args.changes), loads the dataset
    and the model(s), builds an NMTSampler, attaches it to the HTTP server and serves requests forever.

    :return: None
    """
    args = parse_args()
    server_address = (args.address, args.port)
    httpd = HTTPServer(server_address, NMTHandler)
    logger.setLevel(args.logging_level)
    parameters = load_parameters()
    if args.config is not None:
        logger.info("Loading parameters from %s" % str(args.config))
        parameters = update_parameters(parameters, pkl2dict(args.config))

    if args.online:
        online_parameters = load_parameters_online()
        parameters = update_parameters(parameters, online_parameters)

    # Apply the key=value overrides given on the command line
    for arg in args.changes:
        try:
            k, v = arg.split('=')
        except ValueError:
            print(
                'Overwritten arguments must have the form key=Value. \n Currently are: %s'
                % str(args.changes))
            exit(1)
        try:
            parameters[k] = ast.literal_eval(v)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. a bare word or a file path such as /tmp/x, for which
            # literal_eval raises SyntaxError): keep the raw string.
            parameters[k] = v
    dataset = loadDataset(args.dataset)

    # For converting predictions into sentences
    # Dataset backwards compatibility
    bpe_separator = dataset.BPE_separator if hasattr(
        dataset,
        "BPE_separator") and dataset.BPE_separator is not None else '@@'
    # Build BPE tokenizer if necessary
    if 'bpe' in parameters['TOKENIZATION_METHOD'].lower():
        logger.info('Building BPE')
        if not dataset.BPE_built:
            dataset.build_bpe(parameters.get(
                'BPE_CODES_PATH',
                parameters['DATA_ROOT_PATH'] + '/training_codes.joint'),
                              separator=bpe_separator)
    # Build tokenization function
    # NOTE(review): eval() runs on method names that can be overridden from the command line
    # (args.changes); only start this server with trusted configuration.
    tokenize_f = eval('dataset.' +
                      parameters.get('TOKENIZATION_METHOD', 'tokenize_bpe'))
    detokenize_function = eval(
        'dataset.' + parameters.get('DETOKENIZATION_METHOD', 'detokenize_bpe'))
    dataset.build_moses_tokenizer(language=parameters['SRC_LAN'])
    dataset.build_moses_detokenizer(language=parameters['TRG_LAN'])
    tokenize_general = dataset.tokenize_moses
    detokenize_general = dataset.detokenize_moses

    # Prediction parameters
    params_prediction = dict()
    params_prediction['max_batch_size'] = parameters.get('BATCH_SIZE', 20)
    params_prediction['n_parallel_loaders'] = parameters.get(
        'PARALLEL_LOADERS', 1)
    params_prediction['beam_size'] = parameters.get('BEAM_SIZE', 6)
    params_prediction['maxlen'] = parameters.get('MAX_OUTPUT_TEXT_LEN_TEST',
                                                 100)
    params_prediction['optimized_search'] = parameters['OPTIMIZED_SEARCH']
    params_prediction['model_inputs'] = parameters['INPUTS_IDS_MODEL']
    params_prediction['model_outputs'] = parameters['OUTPUTS_IDS_MODEL']
    params_prediction['dataset_inputs'] = parameters['INPUTS_IDS_DATASET']
    params_prediction['dataset_outputs'] = parameters['OUTPUTS_IDS_DATASET']
    params_prediction['search_pruning'] = parameters.get(
        'SEARCH_PRUNING', False)
    params_prediction['normalize_probs'] = True
    params_prediction['alpha_factor'] = parameters.get('ALPHA_FACTOR', 1.0)
    params_prediction['coverage_penalty'] = True
    params_prediction['length_penalty'] = True
    params_prediction['length_norm_factor'] = parameters.get(
        'LENGTH_NORM_FACTOR', 0.0)
    params_prediction['coverage_norm_factor'] = parameters.get(
        'COVERAGE_NORM_FACTOR', 0.0)
    params_prediction['pos_unk'] = parameters.get('POS_UNK', False)
    params_prediction['heuristic'] = parameters.get('HEURISTIC', 0)
    params_prediction['state_below_index'] = -1
    params_prediction['output_text_index'] = 0
    params_prediction['state_below_maxlen'] = -1 if parameters.get(
        'PAD_ON_BATCH', True) else parameters.get('MAX_OUTPUT_TEXT_LEN', 50)
    params_prediction['output_max_length_depending_on_x'] = parameters.get(
        'MAXLEN_GIVEN_X', True)
    params_prediction[
        'output_max_length_depending_on_x_factor'] = parameters.get(
            'MAXLEN_GIVEN_X_FACTOR', 3)
    params_prediction['output_min_length_depending_on_x'] = parameters.get(
        'MINLEN_GIVEN_X', True)
    params_prediction[
        'output_min_length_depending_on_x_factor'] = parameters.get(
            'MINLEN_GIVEN_X_FACTOR', 2)
    params_prediction['attend_on_output'] = parameters.get(
        'ATTEND_ON_OUTPUT', 'transformer' in parameters['MODEL_TYPE'].lower())

    # Manage pos_unk strategies: an empty dataset mapping is passed on as None
    if parameters['POS_UNK']:
        mapping = None if dataset.mapping == dict() else dataset.mapping
    else:
        mapping = None

    # pos_unk and the coverage penalty are switched off for transformer models
    if 'transformer' in parameters['MODEL_TYPE'].lower():
        params_prediction['pos_unk'] = False
        params_prediction['coverage_penalty'] = False

    # Training parameters (left empty unless the server runs in online mode)
    parameters_training = dict()
    if args.online:
        logger.info('Loading models from %s' % str(args.models))
        parameters_training = {  # Training parameters
            'n_epochs': parameters['MAX_EPOCH'],
            'shuffle': False,
            'loss': parameters.get('LOSS', 'categorical_crossentropy'),
            'batch_size': parameters.get('BATCH_SIZE', 1),
            'homogeneous_batches': False,
            'optimizer': parameters.get('OPTIMIZER', 'SGD'),
            'lr': parameters.get('LR', 0.1),
            'lr_decay': parameters.get('LR_DECAY', None),
            'lr_gamma': parameters.get('LR_GAMMA', 1.),
            'epochs_for_save': -1,
            'verbose': args.verbose,
            'eval_on_sets': parameters.get('EVAL_ON_SETS_KERAS', None),
            'n_parallel_loaders': parameters['PARALLEL_LOADERS'],
            'extra_callbacks': [],  # callbacks,
            'reload_epoch': parameters['RELOAD'],
            'epoch_offset': parameters['RELOAD'],
            'data_augmentation': parameters['DATA_AUGMENTATION'],
            'patience': parameters.get('PATIENCE', 0),
            'metric_check': parameters.get('STOP_METRIC', None),
            'eval_on_epochs': parameters.get('EVAL_EACH_EPOCHS', True),
            'each_n_epochs': parameters.get('EVAL_EACH', 1),
            'start_eval_on_epoch': parameters.get('START_EVAL_ON_EPOCH', 0),
            'additional_training_settings': {
                'k': parameters.get('K', 1),
                'tau': parameters.get('TAU', 1),
                'lambda': parameters.get('LAMBDA', 0.5),
                'c': parameters.get('C', 0.5),
                'd': parameters.get('D', 0.5)
            }
        }
        # One model instance per path in args.models, each passed through updateModel with its path
        model_instances = [
            TranslationModel(
                parameters,
                model_type=parameters['MODEL_TYPE'],
                verbose=parameters['VERBOSE'],
                model_name=parameters['MODEL_NAME'] + '_' + str(i),
                vocabularies=dataset.vocabulary,
                store_path=parameters['STORE_PATH'],
                set_optimizer=False) for i in range(len(args.models))
        ]
        models = [
            updateModel(model, path, -1, full_path=True)
            for (model, path) in zip(model_instances, args.models)
        ]
    else:
        models = [loadModel(m, -1, full_path=True) for m in args.models]

    for nmt_model in models:
        nmt_model.setParams(parameters)
        nmt_model.setOptimizer()

    parameters['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        parameters['INPUTS_IDS_DATASET'][0]]
    parameters['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        parameters['OUTPUTS_IDS_DATASET'][0]]

    # Get word2index and index2word dictionaries
    index2word_y = dataset.vocabulary[parameters['OUTPUTS_IDS_DATASET']
                                      [0]]['idx2words']
    word2index_y = dataset.vocabulary[parameters['OUTPUTS_IDS_DATASET']
                                      [0]]['words2idx']
    word2index_x = dataset.vocabulary[parameters['INPUTS_IDS_DATASET']
                                      [0]]['words2idx']

    excluded_words = None
    interactive_beam_searcher = NMTSampler(models,
                                           dataset,
                                           parameters,
                                           params_prediction,
                                           parameters_training,
                                           tokenize_f,
                                           detokenize_function,
                                           tokenize_general,
                                           detokenize_general,
                                           mapping=mapping,
                                           word2index_x=word2index_x,
                                           word2index_y=word2index_y,
                                           index2word_y=index2word_y,
                                           eos_symbol=args.eos_symbol,
                                           excluded_words=excluded_words,
                                           online=args.online,
                                           verbose=args.verbose)

    # The request handler reaches the sampler through the server object
    httpd.sampler = interactive_beam_searcher

    logger.info('Server starting at %s' % str(server_address))
    httpd.serve_forever()
コード例 #3
0
def main():
    """
    Launch the NMT HTTP server on the port given in the command-line arguments.

    Reads the configuration (config.py or a stored config, plus the key=value overrides in args.changes),
    loads the dataset and the model(s), builds an NMTSampler, generates one dummy sample and then serves
    requests forever.

    :return: None
    """
    args = parse_args()
    server_address = ('', args.port)
    httpd = BaseHTTPServer.HTTPServer(server_address, NMTHandler)

    if args.config is None:
        logging.info("Reading parameters from config.py")
        from config import load_parameters
        params = load_parameters()
    else:
        logging.info("Loading parameters from %s" % str(args.config))
        params = pkl2dict(args.config)

    # Apply the key=value overrides given on the command line.
    # print is called with a single parenthesised argument so that it behaves identically as a
    # Python 2 print statement and as the Python 3 print function.
    for arg in args.changes:
        try:
            k, v = arg.split('=')
        except ValueError:
            print('Overwritten arguments must have the form key=Value. \n Currently are: %s' % str(
                args.changes))
            exit(1)
        try:
            params[k] = ast.literal_eval(v)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. a bare word or a file path such as /tmp/x, for which
            # literal_eval raises SyntaxError): keep the raw string.
            params[k] = v
    dataset = loadDataset(args.dataset)

    # For converting predictions into sentences
    # Dataset backwards compatibility
    bpe_separator = dataset.BPE_separator if hasattr(
        dataset,
        "BPE_separator") and dataset.BPE_separator is not None else '@@'
    # Build BPE tokenizer if necessary
    if 'bpe' in params['TOKENIZATION_METHOD'].lower():
        logger.info('Building BPE')
        if not dataset.BPE_built:
            dataset.build_bpe(
                params.get('BPE_CODES_PATH',
                           params['DATA_ROOT_PATH'] + '/training_codes.joint'),
                bpe_separator)
    # Build tokenization function
    # NOTE(review): eval() runs on method names that can be overridden from the command line
    # (args.changes); only start this server with trusted configuration.
    tokenize_f = eval('dataset.' +
                      params.get('TOKENIZATION_METHOD', 'tokenize_none'))

    detokenize_function = eval(
        'dataset.' + params.get('DETOKENIZATION_METHOD', 'detokenize_none'))
    dataset.build_moses_tokenizer(language=params['SRC_LAN'])
    dataset.build_moses_detokenizer(language=params['TRG_LAN'])
    tokenize_general = dataset.tokenize_moses
    detokenize_general = dataset.detokenize_moses

    # Prediction parameters
    params_prediction = dict()
    params_prediction['max_batch_size'] = params.get('BATCH_SIZE', 20)
    params_prediction['n_parallel_loaders'] = params.get('PARALLEL_LOADERS', 1)
    params_prediction['beam_size'] = params.get('BEAM_SIZE', 6)
    params_prediction['maxlen'] = params.get('MAX_OUTPUT_TEXT_LEN_TEST', 100)
    params_prediction['optimized_search'] = params['OPTIMIZED_SEARCH']
    params_prediction['model_inputs'] = params['INPUTS_IDS_MODEL']
    params_prediction['model_outputs'] = params['OUTPUTS_IDS_MODEL']
    params_prediction['dataset_inputs'] = params['INPUTS_IDS_DATASET']
    params_prediction['dataset_outputs'] = params['OUTPUTS_IDS_DATASET']
    params_prediction['search_pruning'] = params.get('SEARCH_PRUNING', False)
    params_prediction['normalize_probs'] = params.get('NORMALIZE_SAMPLING',
                                                      False)
    params_prediction['alpha_factor'] = params.get('ALPHA_FACTOR', 1.0)
    params_prediction['coverage_penalty'] = params.get('COVERAGE_PENALTY',
                                                       False)
    params_prediction['length_penalty'] = params.get('LENGTH_PENALTY', False)
    params_prediction['length_norm_factor'] = params.get(
        'LENGTH_NORM_FACTOR', 0.0)
    params_prediction['coverage_norm_factor'] = params.get(
        'COVERAGE_NORM_FACTOR', 0.0)
    params_prediction['pos_unk'] = params.get('POS_UNK', False)
    params_prediction['heuristic'] = params.get('HEURISTIC', 0)

    params_prediction['state_below_maxlen'] = -1 if params.get('PAD_ON_BATCH', True) \
        else params.get('MAX_OUTPUT_TEXT_LEN', 50)
    params_prediction['output_max_length_depending_on_x'] = params.get(
        'MAXLEN_GIVEN_X', True)
    params_prediction['output_max_length_depending_on_x_factor'] = params.get(
        'MAXLEN_GIVEN_X_FACTOR', 3)
    params_prediction['output_min_length_depending_on_x'] = params.get(
        'MINLEN_GIVEN_X', True)
    params_prediction['output_min_length_depending_on_x_factor'] = params.get(
        'MINLEN_GIVEN_X_FACTOR', 2)
    # Manage pos_unk strategies: an empty dataset mapping is passed on as None
    if params['POS_UNK']:
        mapping = None if dataset.mapping == dict() else dataset.mapping
    else:
        mapping = None

    if args.online:
        logging.info('Loading models from %s' % str(args.models))

        # One model instance per path in args.models, each passed through updateModel with its path
        model_instances = [
            TranslationModel(params,
                             model_type=params['MODEL_TYPE'],
                             verbose=params['VERBOSE'],
                             model_name=params['MODEL_NAME'] + '_' + str(i),
                             vocabularies=dataset.vocabulary,
                             store_path=params['STORE_PATH'],
                             set_optimizer=False)
            for i in range(len(args.models))
        ]
        models = [
            updateModel(model, path, -1, full_path=True)
            for (model, path) in zip(model_instances, args.models)
        ]

        # Set additional inputs to models if using a custom loss function
        params['USE_CUSTOM_LOSS'] = 'PAS' in params['OPTIMIZER']
        if params['N_BEST_OPTIMIZER']:
            logging.info('Using N-best optimizer')

        models = build_online_models(models, params)
        # NOTE(review): params_training is not defined anywhere in this function; unless it exists at
        # module level this call raises NameError in online mode - confirm and define it.
        # NOTE(review): online_trainer is not used afterwards in this function - confirm it is needed.
        online_trainer = OnlineTrainer(models,
                                       dataset,
                                       None,
                                       None,
                                       params_training,
                                       verbose=args.verbose)
    else:
        models = [loadModel(m, -1, full_path=True) for m in args.models]

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['OUTPUTS_IDS_DATASET'][0]]

    # Get word2index and index2word dictionaries
    index2word_y = dataset.vocabulary[params['OUTPUTS_IDS_DATASET']
                                      [0]]['idx2words']
    word2index_y = dataset.vocabulary[params['OUTPUTS_IDS_DATASET']
                                      [0]]['words2idx']
    word2index_x = dataset.vocabulary[params['INPUTS_IDS_DATASET']
                                      [0]]['words2idx']

    excluded_words = None
    interactive_beam_searcher = NMTSampler(models,
                                           dataset,
                                           params_prediction,
                                           tokenize_f,
                                           detokenize_function,
                                           tokenize_general,
                                           detokenize_general,
                                           mapping=mapping,
                                           word2index_x=word2index_x,
                                           word2index_y=word2index_y,
                                           index2word_y=index2word_y,
                                           excluded_words=excluded_words,
                                           verbose=args.verbose)

    # Compile Theano sampling function by generating a fake sample # TODO: Find a better way of doing this
    print("Compiling sampler...")
    interactive_beam_searcher.generate_sample('i')

    # The request handler reaches the sampler through the server object
    httpd.sampler = interactive_beam_searcher

    print('Server starting at localhost:' + str(args.port))
    httpd.serve_forever()
コード例 #4
0
        params['OUTPUTS_IDS_DATASET'][0]]
    logger.info("<<< Using an ensemble of %d models >>>" % len(args.models))
    # Load trainable model(s)
    logging.info('Loading models from %s' % str(args.models))
    model_instances = [
        TranslationModel(params,
                         model_type=params['MODEL_TYPE'],
                         verbose=params['VERBOSE'],
                         model_name=params['MODEL_NAME'] + '_' + str(i),
                         vocabularies=dataset.vocabulary,
                         store_path=params['STORE_PATH'],
                         clear_dirs=False,
                         set_optimizer=False) for i in range(len(args.models))
    ]
    models = [
        updateModel(model, path, -1, full_path=True)
        for (model, path) in zip(model_instances, args.models)
    ]

    # Set additional inputs to models if using a custom loss function
    params['USE_CUSTOM_LOSS'] = True if 'PAS' in params['OPTIMIZER'] else False
    if params.get('N_BEST_OPTIMIZER', False):
        logging.info('Using N-best optimizer')

    models = build_online_models(models, params)
    online_trainer = OnlineTrainer(models,
                                   dataset,
                                   None,
                                   None,
                                   params_training,
                                   verbose=args.verbose)
コード例 #5
0
ファイル: main.py プロジェクト: JHnlp/nmt-keras
def train_model(params, load_dataset=None):
    """
    Train a translation model, either from scratch or resuming a stored one.

    With params['RELOAD'] == 0 a new TranslationModel is created and the configuration is stored at
    params['STORE_PATH'] + '/config'. Otherwise the model is created without optimizer and without clearing
    directories, passed through updateModel, and params['EPOCH_OFFSET'] is computed. In both cases the
    dataset-to-model input/output mappings are set and trainNet is called.

    :param params: Dictionary of network hyperparameters. It is modified in place.
    :param load_dataset: Stored dataset to load. If None, the dataset is built from the parameters.
    :return: None
    """

    def _position_mapping(wanted_ids, dataset_ids, model_ids):
        # Maps the i-th model id to the position that the i-th wanted id has in the dataset ids.
        positions = dict()
        for slot, wanted in enumerate(wanted_ids):
            where = dataset_ids.index(wanted)
            positions[model_ids[slot]] = where
        return positions

    if params['RELOAD'] > 0:
        logging.info('Resuming training.')

    check_params(params)

    # Load data
    dataset = build_dataset(params) if load_dataset is None else loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]

    # Build model
    from_scratch = params['RELOAD'] == 0
    model_options = dict(model_type=params['MODEL_TYPE'],
                         verbose=params['VERBOSE'],
                         model_name=params['MODEL_NAME'],
                         vocabularies=dataset.vocabulary,
                         store_path=params['STORE_PATH'])
    if not from_scratch:
        # A previously trained model is resumed: keep its directories and set the optimizer later
        model_options.update(set_optimizer=False, clear_dirs=False)
    nmt_model = TranslationModel(params, **model_options)
    if from_scratch:
        dict2pkl(params, params['STORE_PATH'] + '/config')

    # Define the inputs and outputs mapping from our Dataset instance to our model
    nmt_model.setInputsMapping(
        _position_mapping(params['INPUTS_IDS_DATASET'], dataset.ids_inputs, nmt_model.ids_inputs))
    nmt_model.setOutputsMapping(
        _position_mapping(params['OUTPUTS_IDS_DATASET'], dataset.ids_outputs, nmt_model.ids_outputs))

    if not from_scratch:
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        if params['RELOAD_EPOCH']:
            params['EPOCH_OFFSET'] = params['RELOAD']
        else:
            params['EPOCH_OFFSET'] = int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    # Early stopping only checks a metric when EARLY_STOP is enabled
    stop_metric = params.get('STOP_METRIC', None) if params.get('EARLY_STOP', False) else None
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        'lr_decay': params['LR_DECAY'],
        'lr_gamma': params['LR_GAMMA'],
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        'patience': params.get('PATIENCE', 0),
        'metric_check': stop_metric,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0),
    }
    nmt_model.trainNet(dataset, training_params)

    elapsed = timer() - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(elapsed, elapsed / 60.0))
コード例 #6
0
def main():
    """
    Start an HTTP server that serves the loaded model(s) through NMTHandler.

    Steps:
        1. Parse the command line and create the HTTP server on (args.address, args.port).
        2. Read the parameters from a pickled config (args.config) or from config.py, merge the
           online-learning parameters if args.online is set and apply the key=value overrides
           listed in args.changes.
        3. Load the dataset, build the tokenizers and load the model(s) given in args.models.
        4. Build the sampler, attach it to the server as `httpd.sampler` and call serve_forever().

    :return: None
    """
    args = parse_args()
    server_address = (args.address, args.port)
    httpd = BaseHTTPServer.HTTPServer(server_address, NMTHandler)
    logger.setLevel(args.logging_level)
    if args.config is not None:
        logger.info('Reading parameters from %s.' % args.config)
        parameters = update_parameters({}, pkl2dict(args.config))
    else:
        logger.info('Reading parameters from config.py.')
        parameters = load_parameters()

    if args.online:
        online_parameters = load_parameters_online(parameters)
        parameters = update_parameters(parameters, online_parameters)

    # Overwrite parameters with the key=value pairs given in the command line
    for arg in args.changes:
        try:
            # Split on the first '=' only, so that the value itself may contain '='
            k, v = arg.split('=', 1)
        except ValueError:
            print(
                'Overwritten arguments must have the form key=Value. \n Currently are: %s'
                % str(args.changes))
            exit(1)
        try:
            parameters[k] = ast.literal_eval(v)
        except (ValueError, SyntaxError):
            # The value is not a Python literal (e.g. a bare word or a path): keep it as a string
            parameters[k] = v

    check_params(parameters)
    if args.verbose:
        logging.info("parameters = " + str(parameters))
    dataset = loadDataset(args.dataset)

    # Dataset backwards compatibility
    bpe_separator = getattr(dataset, 'BPE_separator', None)
    if bpe_separator is None:
        bpe_separator = '@@'
    # Build BPE tokenizer if necessary
    if 'bpe' in parameters['TOKENIZATION_METHOD'].lower():
        logger.info('Building BPE')
        if not dataset.BPE_built:
            dataset.build_bpe(
                parameters.get(
                    'BPE_CODES_PATH',
                    parameters['DATA_ROOT_PATH'] + '/training_codes.joint'),
                bpe_separator)
    # Build tokenization functions. The methods are looked up by name with getattr instead of
    # eval, so that a parameter value can never be executed as code.
    tokenize_f = getattr(
        dataset, parameters.get('TOKENIZATION_METHOD', 'tokenize_bpe'))
    detokenize_function = getattr(
        dataset, parameters.get('DETOKENIZATION_METHOD', 'detokenize_bpe'))
    dataset.build_moses_tokenizer(language=parameters['TRG_LAN'])
    dataset.build_moses_detokenizer(language=parameters['TRG_LAN'])
    tokenize_general = dataset.tokenize_moses
    detokenize_general = dataset.detokenize_moses

    parameters_training = dict()
    if args.online:
        parameters_training = {  # Training parameters
            'n_epochs': parameters['MAX_EPOCH'],
            'shuffle': False,
            'loss': parameters.get('LOSS', 'categorical_crossentropy'),
            'batch_size': parameters.get('BATCH_SIZE', 1),
            'homogeneous_batches': False,
            'optimizer': parameters.get('OPTIMIZER', 'SGD'),
            'lr': parameters.get('LR', 0.1),
            'lr_decay': parameters.get('LR_DECAY', None),
            'lr_gamma': parameters.get('LR_GAMMA', 1.),
            'epochs_for_save': -1,
            'verbose': args.verbose,
            'eval_on_sets': parameters['EVAL_ON_SETS_KERAS'],
            'n_parallel_loaders': parameters['PARALLEL_LOADERS'],
            'extra_callbacks': [],  # callbacks,
            'reload_epoch': parameters['RELOAD'],
            'epoch_offset': parameters['RELOAD'],
            'data_augmentation': parameters['DATA_AUGMENTATION'],
            'patience': parameters.get('PATIENCE', 0),
            'metric_check': parameters.get('STOP_METRIC', None),
            'eval_on_epochs': parameters.get('EVAL_EACH_EPOCHS', True),
            'each_n_epochs': parameters.get('EVAL_EACH', 1),
            'start_eval_on_epoch': parameters.get('START_EVAL_ON_EPOCH', 0),
            'additional_training_settings': {
                'k': parameters.get('K', 1),
                'tau': parameters.get('TAU', 1),
                'lambda': parameters.get('LAMBDA', 0.5),
                'c': parameters.get('C', 0.5),
                'd': parameters.get('D', 0.5)
            }
        }
        # Load trainable model(s): build one fresh instance per path, then load the stored
        # weights into each of them
        logging.info('Loading models from %s' % str(args.models))
        model_instances = [
            Captioning_Model(
                parameters,
                model_type=parameters['MODEL_TYPE'],
                verbose=parameters['VERBOSE'],
                model_name=parameters['MODEL_NAME'] + '_' + str(i),
                vocabularies=dataset.vocabulary,
                store_path=parameters['STORE_PATH'],
                set_optimizer=False) for i in range(len(args.models))
        ]
        models = [
            updateModel(model, path, -1, full_path=True)
            for (model, path) in zip(model_instances, args.models)
        ]
        for model in models:
            model.setParams(parameters)
            model.setOptimizer()
    else:
        # Otherwise, load regular model(s)
        models = [loadModel(m, -1, full_path=True) for m in args.models]

    output_id = parameters['OUTPUTS_IDS_DATASET'][0]
    parameters['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[output_id]

    # Get word2index and index2word dictionaries
    index2word_y = dataset.vocabulary[output_id]['idx2words']
    word2index_y = dataset.vocabulary[output_id]['words2idx']
    unk_id = dataset.extra_words['<unk>']

    # NOTE(review): the (de)tokenization defaults below ('tokenize_none' / 'detokenize_none')
    # differ from the ones used for tokenize_f / detokenize_function above; kept as they were.
    parameters_prediction = {
        'max_batch_size': parameters['BATCH_SIZE'],
        'n_parallel_loaders': parameters['PARALLEL_LOADERS'],
        'predict_on_sets': [args.split],
        'beam_size': parameters['BEAM_SIZE'],
        'maxlen': parameters['MAX_OUTPUT_TEXT_LEN_TEST'],
        'optimized_search': parameters['OPTIMIZED_SEARCH'],
        'model_inputs': parameters['INPUTS_IDS_MODEL'],
        'model_outputs': parameters['OUTPUTS_IDS_MODEL'],
        'dataset_inputs': parameters['INPUTS_IDS_DATASET'],
        'dataset_outputs': parameters['OUTPUTS_IDS_DATASET'],
        'normalize_probs': parameters['NORMALIZE_SAMPLING'],
        'alpha_factor': parameters['ALPHA_FACTOR'],
        'normalize': parameters.get('NORMALIZATION', False),
        'normalization_type': parameters.get('NORMALIZATION_TYPE', None),
        'data_augmentation': parameters.get('DATA_AUGMENTATION', False),
        'mean_substraction': parameters.get('MEAN_SUBTRACTION', False),
        'wo_da_patch_type': parameters.get('WO_DA_PATCH_TYPE', 'whole'),
        'da_patch_type': parameters.get('DA_PATCH_TYPE', 'resize_and_rndcrop'),
        'da_enhance_list': parameters.get('DA_ENHANCE_LIST', None),
        'pos_unk': parameters.get('POS_UNK', None),
        'heuristic': parameters.get('HEURISTIC', None),
        'search_pruning': parameters.get('SEARCH_PRUNING', False),
        'state_below_index': -1,
        'output_text_index': 0,
        'apply_tokenization': parameters.get('APPLY_TOKENIZATION', False),
        'tokenize_f': getattr(
            dataset, parameters.get('TOKENIZATION_METHOD', 'tokenize_none')),
        'apply_detokenization': parameters.get('APPLY_DETOKENIZATION', True),
        'detokenize_f': getattr(
            dataset,
            parameters.get('DETOKENIZATION_METHOD', 'detokenize_none')),
        'coverage_penalty': parameters.get('COVERAGE_PENALTY', False),
        'length_penalty': parameters.get('LENGTH_PENALTY', False),
        'length_norm_factor': parameters.get('LENGTH_NORM_FACTOR', 0.0),
        'coverage_norm_factor': parameters.get('COVERAGE_NORM_FACTOR', 0.0),
        'output_max_length_depending_on_x': parameters.get('MAXLEN_GIVEN_X', False),
        'output_max_length_depending_on_x_factor': parameters.get('MAXLEN_GIVEN_X_FACTOR', 3),
        'output_min_length_depending_on_x': parameters.get('MINLEN_GIVEN_X', False),
        'output_min_length_depending_on_x_factor': parameters.get('MINLEN_GIVEN_X_FACTOR', 2),
        # Unless stated otherwise, attend on the output only for transformer-like models
        'attend_on_output': parameters.get(
            'ATTEND_ON_OUTPUT', 'transformer' in parameters['MODEL_TYPE'].lower()),
        'n_best_optimizer': parameters.get('N_BEST_OPTIMIZER', False)
    }

    excluded_words = None
    interactive_beam_searcher = VideoDescSampler(models,
                                                 dataset,
                                                 parameters,
                                                 parameters_prediction,
                                                 parameters_training,
                                                 tokenize_f,
                                                 detokenize_function,
                                                 tokenize_general,
                                                 detokenize_general,
                                                 split=args.split,
                                                 word2index_y=word2index_y,
                                                 index2word_y=index2word_y,
                                                 eos_symbol=args.eos_symbol,
                                                 excluded_words=excluded_words,
                                                 unk_id=unk_id,
                                                 online=args.online,
                                                 verbose=args.verbose)

    # The request handler reaches the sampler through the server instance
    httpd.sampler = interactive_beam_searcher

    logger.info('Server starting at %s' % str(server_address))
    httpd.serve_forever()
コード例 #7
0
def train_model(params,
                weights_dict,
                load_dataset=None,
                trainable_pred=True,
                trainable_est=True,
                weights_path=None):
    """
    Training function. Sets the training parameters from params, builds or reloads the model and launches the training.

    :param params: Dictionary of network hyperparameters. It is updated in place with the vocabulary sizes and,
                   when resuming, with EPOCH_OFFSET.
    :param weights_dict: Dictionary mapping layer names to layer weights, or None. If given, the layers of the
                         model whose name is in the dictionary are initialized from it before training, and the
                         weights of every layer are written back to it after training.
    :param load_dataset: Argument passed to loadDataset to load a stored dataset. If None, the dataset is built
                         (or, when resuming, optionally updated) from the parameters.
    :param trainable_pred: Forwarded to TranslationModel.
    :param trainable_est: Forwarded to TranslationModel.
    :param weights_path: Forwarded to TranslationModel.
    :return: None
    """
    check_params(params)

    def _build_dataset():
        # Build the dataset from params, reusing the vocabulary of the dataset stored in PRED_VOCAB (if any)
        pred_vocab = params.get('PRED_VOCAB', None)
        if pred_vocab is None:
            return build_dataset(params)
        dataset_voc = loadDataset(pred_vocab)
        return build_dataset(params, dataset_voc.vocabulary,
                             dataset_voc.vocabulary_len)

    resuming = params['RELOAD'] > 0
    if resuming:
        logging.info('Resuming training.')

    # Load data
    if load_dataset is not None:
        if resuming:
            logging.info('Reloading and using dataset.')
        dataset = loadDataset(load_dataset)
    elif resuming and not params['REBUILD_DATASET']:
        logging.info('Updating dataset.')
        dataset_name = params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN']
        dataset = loadDataset(params['DATASET_STORE_PATH'] + '/Dataset_' +
                              dataset_name + '.pkl')

        # .items() (instead of the Python 2 only .iteritems()) works on both Python 2 and 3
        for split, filename in params['TEXT_FILES'].items():
            dataset = update_dataset_from_file(
                dataset,
                params['DATA_ROOT_PATH'] + '/' + filename + params['SRC_LAN'],
                params,
                splits=list([split]),
                output_text_filename=params['DATA_ROOT_PATH'] + '/' +
                filename + params['TRG_LAN'],
                remove_outputs=False,
                compute_state_below=True,
                recompute_references=True)
            dataset.name = dataset_name
        saveDataset(dataset, params['DATASET_STORE_PATH'])
    else:
        if resuming:
            logging.info('Rebuilding dataset.')
        dataset = _build_dataset()

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['INPUTS_IDS_DATASET'][0]]
    # The output vocabulary size is always taken from the 'target_text' id
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len['target_text']

    # Build model
    is_new_model = params['RELOAD'] == 0
    model_options = {
        'model_type': params['MODEL_TYPE'],
        'verbose': params['VERBOSE'],
        'model_name': params['MODEL_NAME'],
        'vocabularies': dataset.vocabulary,
        'store_path': params['STORE_PATH'],
        'trainable_pred': trainable_pred,
        'trainable_est': trainable_est,
        'weights_path': weights_path
    }
    if is_new_model:
        # A new model starts from clean directories
        model_options['clear_dirs'] = True
    else:
        # When resuming, the optimizer is set after the stored model has been loaded
        model_options['set_optimizer'] = False
    nmt_model = TranslationModel(params, **model_options)

    # Define the inputs and outputs mapping from our Dataset instance to our model
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        inputMapping[nmt_model.ids_inputs[i]] = dataset.ids_inputs.index(id_in)
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        outputMapping[nmt_model.ids_outputs[i]] = dataset.ids_outputs.index(id_out)
    nmt_model.setOutputsMapping(outputMapping)

    if not is_new_model:
        # Resume from previously trained model
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        # RELOAD counts epochs if RELOAD_EPOCH is set; otherwise it counts updates and is converted to epochs
        params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
            int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Store configuration as pkl
    dict2pkl(params, params['STORE_PATH'] + '/config')

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        'lr_decay': params.get('LR_DECAY', None),  # LR decay parameters
        'reduce_each_epochs': params.get('LR_REDUCE_EACH_EPOCHS', True),
        'start_reduction_on_epoch': params.get('LR_START_REDUCTION_ON_EPOCH', 0),
        'lr_gamma': params.get('LR_GAMMA', 0.9),
        'lr_reducer_type': params.get('LR_REDUCER_TYPE', 'linear'),
        'lr_reducer_exp_base': params.get('LR_REDUCER_EXP_BASE', 0),
        'lr_half_life': params.get('LR_HALF_LIFE', 50000),
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        'patience': params.get('PATIENCE', 0),  # early stopping parameters
        'metric_check': params.get('STOP_METRIC', None)
        if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0)
    }

    # Initialize the layers for which weights were provided
    if weights_dict is not None:
        for layer in nmt_model.model.layers:
            if layer.name in weights_dict:
                layer.set_weights(weights_dict[layer.name])

    nmt_model.trainNet(dataset, training_params)

    # Hand the trained weights of every layer back to the caller
    if weights_dict is not None:
        for layer in nmt_model.model.layers:
            weights_dict[layer.name] = layer.get_weights()

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(
        time_difference, time_difference / 60.0))
コード例 #8
0
def apply_NMT_model(params, load_dataset=None):
    """
    Sample from a previously trained model and evaluate it on each set of params['EVAL_ON_SETS'].

    :param params: Dictionary of network hyperparameters. It is updated in place with the vocabulary sizes.
    :param load_dataset: Currently unused: the dataset is always built from the parameters
                         (reusing the vocabulary of the dataset stored in params['PRED_VOCAB'], if given).
    :return: None
    """
    # Load data
    pred_vocab = params.get('PRED_VOCAB', None)
    if pred_vocab is not None:
        dataset_voc = loadDataset(pred_vocab)
        dataset = build_dataset(params, dataset_voc.vocabulary,
                                dataset_voc.vocabulary_len)
    else:
        dataset = build_dataset(params)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['INPUTS_IDS_DATASET'][0]]
    # The output vocabulary size is always taken from the 'target_text' id
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len['target_text']

    # Load model: build the architecture and load the stored weights into it
    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 set_optimizer=False,
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 trainable_pred=True,
                                 trainable_est=True,
                                 weights_path=None)
    nmt_model = updateModel(nmt_model,
                            params['STORE_PATH'],
                            params['RELOAD'],
                            reload_epoch=params['RELOAD_EPOCH'])
    nmt_model.setParams(params)
    nmt_model.setOptimizer()

    # Define the inputs and outputs mapping from our Dataset instance to our model
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        inputMapping[nmt_model.ids_inputs[i]] = dataset.ids_inputs.index(id_in)
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        outputMapping[nmt_model.ids_outputs[i]] = dataset.ids_outputs.index(id_out)
    nmt_model.setOutputsMapping(outputMapping)
    nmt_model.setOptimizer()

    no_ref = params.get('NO_REF', False)

    for s in params["EVAL_ON_SETS"]:
        # Evaluate training
        extra_vars = {
            'language': params.get('TRG_LAN', 'en'),
            'n_parallel_loaders': params['PARALLEL_LOADERS'],
            # The methods are looked up by name with getattr instead of eval
            'tokenize_f': getattr(dataset, params['TOKENIZATION_METHOD']),
            'detokenize_f': getattr(dataset, params['DETOKENIZATION_METHOD']),
            'apply_detokenization': params['APPLY_DETOKENIZATION'],
            'tokenize_hypotheses': params['TOKENIZE_HYPOTHESES'],
            'tokenize_references': params['TOKENIZE_REFERENCES']
        }
        extra_vars[s] = dict()
        if not no_ref:
            extra_vars[s]['references'] = dataset.extra_variables[s][
                params['OUTPUTS_IDS_DATASET'][0]]

        input_text_id = params['INPUTS_IDS_DATASET'][0]
        vocab_x = dataset.vocabulary[input_text_id]['idx2words']
        # The target-side vocabulary is read from the second dataset input
        vocab_y = dataset.vocabulary[params['INPUTS_IDS_DATASET'][1]]['idx2words']

        if params['BEAM_SEARCH']:
            extra_vars['beam_size'] = params.get('BEAM_SIZE', 6)
            extra_vars['state_below_index'] = params.get(
                'BEAM_SEARCH_COND_INPUT', -1)
            extra_vars['maxlen'] = params.get('MAX_OUTPUT_TEXT_LEN_TEST', 30)
            extra_vars['optimized_search'] = params.get(
                'OPTIMIZED_SEARCH', True)
            extra_vars['model_inputs'] = params['INPUTS_IDS_MODEL']
            extra_vars['model_outputs'] = params['OUTPUTS_IDS_MODEL']
            extra_vars['dataset_inputs'] = params['INPUTS_IDS_DATASET']
            extra_vars['dataset_outputs'] = params['OUTPUTS_IDS_DATASET']
            extra_vars['normalize_probs'] = params.get('NORMALIZE_SAMPLING',
                                                       False)
            extra_vars['search_pruning'] = params.get('SEARCH_PRUNING', False)
            extra_vars['alpha_factor'] = params.get('ALPHA_FACTOR', 1.0)
            extra_vars['coverage_penalty'] = params.get(
                'COVERAGE_PENALTY', False)
            extra_vars['length_penalty'] = params.get('LENGTH_PENALTY', False)
            extra_vars['length_norm_factor'] = params.get(
                'LENGTH_NORM_FACTOR', 0.0)
            extra_vars['coverage_norm_factor'] = params.get(
                'COVERAGE_NORM_FACTOR', 0.0)
            extra_vars['pos_unk'] = params['POS_UNK']
            extra_vars['output_max_length_depending_on_x'] = params.get(
                'MAXLEN_GIVEN_X', True)
            extra_vars['output_max_length_depending_on_x_factor'] = params.get(
                'MAXLEN_GIVEN_X_FACTOR', 3)
            extra_vars['output_min_length_depending_on_x'] = params.get(
                'MINLEN_GIVEN_X', True)
            extra_vars['output_min_length_depending_on_x_factor'] = params.get(
                'MINLEN_GIVEN_X_FACTOR', 2)

            if params['POS_UNK']:
                extra_vars['heuristic'] = params['HEURISTIC']
                if params['HEURISTIC'] > 0:
                    extra_vars['mapping'] = dataset.mapping

        # NOTE(review): set_name receives every set of EVAL_ON_SETS, while extra_vars only holds
        # the entry of the current set `s` -- confirm this is intended when evaluating several sets.
        callback_metric = PrintPerformanceMetricOnEpochEndOrEachNUpdates(
            nmt_model,
            dataset,
            gt_id=params['OUTPUTS_IDS_DATASET'][0],
            metric_name=params['METRICS'],
            set_name=params['EVAL_ON_SETS'],
            batch_size=params['BATCH_SIZE'],
            each_n_epochs=params['EVAL_EACH'],
            extra_vars=extra_vars,
            reload_epoch=params['RELOAD'],
            is_text=True,
            input_text_id=input_text_id,
            save_path=nmt_model.model_path,
            index2word_y=vocab_y,
            index2word_x=vocab_x,
            sampling_type=params['SAMPLING'],
            beam_search=params['BEAM_SEARCH'],
            start_eval_on_epoch=params['START_EVAL_ON_EPOCH'],
            write_samples=True,
            write_type=params['SAMPLING_SAVE_MODE'],
            eval_on_epochs=params['EVAL_EACH_EPOCHS'],
            save_each_evaluation=False,
            verbose=params['VERBOSE'],
            # Same default as above, so that a missing NO_REF does not raise a KeyError here
            no_ref=no_ref)

        callback_metric.evaluate(
            params['RELOAD'],
            counter_name='epoch' if params['EVAL_EACH_EPOCHS'] else 'update')
コード例 #9
0
ファイル: main.py プロジェクト: lvapeab/TMA
def train_model(params):
    """
    Train a video description model.

    Builds the dataset, then either reloads a model for finetuning, builds a new model (optionally
    transferring weights from pre-trained ones) or resumes a stored model, and finally runs the training.

    :param params: Dictionary of network hyperparameters. It is modified in place.
    :return: None
    """

    if params['RELOAD'] > 0:
        logging.info('Resuming training.')

    check_params(params)

    # Load data
    dataset = build_dataset(params)
    if '-vidtext-embed' in params['DATASET_NAME']:
        vocabulary_id = params['INPUTS_IDS_DATASET'][1]
    else:
        vocabulary_id = params['OUTPUTS_IDS_DATASET'][0]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[vocabulary_id]

    def connect_to_dataset(model):
        # Define the inputs and outputs mapping from our Dataset instance to the given model.
        # Dataset ids beyond the number of model inputs/outputs are ignored.
        inputs_map = dict()
        for position, dataset_id in enumerate(params['INPUTS_IDS_DATASET']):
            if position < len(model.ids_inputs):
                inputs_map[model.ids_inputs[position]] = dataset.ids_inputs.index(dataset_id)
        model.setInputsMapping(inputs_map)

        outputs_map = dict()
        for position, dataset_id in enumerate(params['OUTPUTS_IDS_DATASET']):
            if position < len(model.ids_outputs):
                outputs_map[model.ids_outputs[position]] = dataset.ids_outputs.index(dataset_id)
        model.setOutputsMapping(outputs_map)

    # Build model
    if params['MODE'] == 'finetuning':
        video_model = VideoDesc_Model(
            params,
            type=params['MODEL_TYPE'],
            verbose=params['VERBOSE'],
            model_name=params['MODEL_NAME'] + '_reloaded',
            vocabularies=dataset.vocabulary,
            store_path=params['STORE_PATH'],
            set_optimizer=False,
            clear_dirs=False)
        video_model = updateModel(
            video_model,
            params['RELOAD_PATH'],
            params['RELOAD'],
            reload_epoch=False)
        video_model.setParams(params)

        connect_to_dataset(video_model)

        video_model.setOptimizer()
        # Train for MAX_EPOCH additional epochs on top of the reloaded ones
        params['MAX_EPOCH'] += params['RELOAD']

    elif params['RELOAD'] == 0 or params['LOAD_WEIGHTS_ONLY']:
        # Build a new model
        video_model = VideoDesc_Model(
            params,
            type=params['MODEL_TYPE'],
            verbose=params['VERBOSE'],
            model_name=params['MODEL_NAME'],
            vocabularies=dataset.vocabulary,
            store_path=params['STORE_PATH'],
            set_optimizer=True)
        dict2pkl(params, params['STORE_PATH'] + '/config')

        connect_to_dataset(video_model)

        # Only load weights from the pre-trained model(s)
        if params['LOAD_WEIGHTS_ONLY'] and params['RELOAD'] > 0:
            for idx in range(len(params['RELOAD'])):
                pretrained = loadModel(
                    params['PRE_TRAINED_MODEL_STORE_PATHS'][idx],
                    params['RELOAD'][idx])
                video_model = transferWeights(
                    pretrained, video_model, params['LAYERS_MAPPING'][idx])
            video_model.setOptimizer()
            # From here on, the model is treated as a new one
            params['RELOAD'] = 0

    else:
        # Resume from a previously trained model
        video_model = loadModel(
            params['PRE_TRAINED_MODEL_STORE_PATHS'], params['RELOAD'])
        video_model.params['LR'] = params['LR']
        video_model.setOptimizer()

        if video_model.model_path != params['STORE_PATH']:
            video_model.setName(
                params['MODEL_NAME'],
                models_path=params['STORE_PATH'],
                clear_dirs=False)

    # Update optimizer either if we are loading or building a model
    video_model.params = params
    video_model.setOptimizer()

    # Callbacks
    callbacks = buildCallbacks(params, video_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs':
        params['MAX_EPOCH'],
        'batch_size':
        params['BATCH_SIZE'],
        'homogeneous_batches':
        params['HOMOGENEOUS_BATCHES'],
        'maxlen':
        params['MAX_OUTPUT_TEXT_LEN'],
        'lr_decay':
        params['LR_DECAY'],
        'lr_gamma':
        params['LR_GAMMA'],
        'epochs_for_save':
        params['EPOCHS_FOR_SAVE'],
        'verbose':
        params['VERBOSE'],
        'eval_on_sets':
        params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders':
        params['PARALLEL_LOADERS'],
        'extra_callbacks':
        callbacks,
        'reload_epoch':
        params['RELOAD'],
        'epoch_offset':
        params['RELOAD'],
        'data_augmentation':
        params['DATA_AUGMENTATION'],
        # Early stopping parameters
        'patience':
        params.get('PATIENCE', 0),
        'metric_check':
        params.get('STOP_METRIC', None),
        'eval_on_epochs':
        params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs':
        params.get('EVAL_EACH', 1),
        'start_eval_on_epoch':
        params.get('START_EVAL_ON_EPOCH', 0)
    }

    video_model.trainNet(dataset, training_params)

    elapsed = timer() - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(
        elapsed, elapsed / 60.0))
コード例 #10
0
# Script fragment: builds a TranslationModel from `params` and `dataset`, loads into it the
# weights stored at SRC_MODEL_PATH for `epoch_choice`, and saves it again at DST_MODEL_PATH.
# `params`, `dataset`, `epoch_choice`, SRC_MODEL_PATH and DST_MODEL_PATH are expected to be
# defined before this point.
params['RELOAD'] = epoch_choice
# params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
# params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]

cpu_model = TranslationModel(params,
                             model_type=params['MODEL_TYPE'],
                             verbose=True,
                             model_name=params['MODEL_NAME'],
                             vocabularies=dataset.vocabulary,
                             store_path=params['STORE_PATH'],
                             set_optimizer=True,
                             clear_dirs=True)
# NOTE(review): this exit() stops the script right after the model is built, so none of the
# statements below are executed. It looks like a debugging leftover -- confirm before removing.
exit()

# Load the stored weights into the freshly built model
cpu_model = updateModel(cpu_model,
                        SRC_MODEL_PATH,
                        params['RELOAD'],
                        reload_epoch=True)

# Store the model at the destination path
saveModel(cpu_model,
          update_num=epoch_choice,
          path=DST_MODEL_PATH,
          full_path=True)

# Define the inputs and outputs mapping from our Dataset instance to our model
# (only the inputs mapping is set in this fragment)
inputMapping = dict()
for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
    pos_source = dataset.ids_inputs.index(id_in)
    id_dest = cpu_model.ids_inputs[i]
    inputMapping[id_dest] = pos_source
cpu_model.setInputsMapping(inputMapping)
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or 0.0 if ``denominator`` is zero."""
    if not denominator:
        return 0.0
    return float(numerator) / denominator


def _print_effort_ratios(total_errors, total_mouse_actions, total_words,
                         total_chars):
    """Print the accumulated WSR, MAR, MAR_c and KSMR ratios."""
    print(u"WSR: %f" % _safe_ratio(total_errors, total_words))
    print(u"MAR: %f" % _safe_ratio(total_mouse_actions, total_words))
    print(u"MAR_c: %f" % _safe_ratio(total_mouse_actions, total_chars))
    print(u"KSMR: %f" %
          _safe_ratio(total_errors + total_mouse_actions, total_chars))


def _build_online_training_params(params, verbose):
    """Build the training-settings dictionary passed to the online trainer."""
    return {
        'n_epochs': params['MAX_EPOCH'],
        'shuffle': False,
        'loss': params.get('LOSS', 'categorical_crossentropy'),
        'batch_size': params.get('BATCH_SIZE', 1),
        'homogeneous_batches': False,
        'optimizer': params.get('OPTIMIZER', 'SGD'),
        'lr': params.get('LR', 0.1),
        'lr_decay': params.get('LR_DECAY', None),
        'lr_gamma': params.get('LR_GAMMA', 1.),
        'epochs_for_save': -1,
        'verbose': verbose,
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': [],
        'reload_epoch': 0,
        'epoch_offset': 0,
        'data_augmentation': params['DATA_AUGMENTATION'],
        'patience': params.get('PATIENCE', 0),
        'metric_check': params.get('STOP_METRIC', None),
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0),
        'additional_training_settings': {
            'k': params.get('K', 1),
            'tau': params.get('TAU', 1),
            'lambda': params.get('LAMBDA', 0.5),
            'c': params.get('C', 0.5),
            'd': params.get('D', 0.5)
        }
    }


def _build_prediction_params(params, dataset, split):
    """Build the sampling-settings dictionary for predicting on ``split``."""
    return {
        'max_batch_size': params['BATCH_SIZE'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'predict_on_sets': [split],
        'beam_size': params['BEAM_SIZE'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN_TEST'],
        'optimized_search': params['OPTIMIZED_SEARCH'],
        'model_inputs': params['INPUTS_IDS_MODEL'],
        'model_outputs': params['OUTPUTS_IDS_MODEL'],
        'dataset_inputs': params['INPUTS_IDS_DATASET'],
        'dataset_outputs': params['OUTPUTS_IDS_DATASET'],
        'normalize_probs': params.get('NORMALIZE_SAMPLING', False),
        'alpha_factor': params.get('ALPHA_FACTOR', 1.0),
        'normalize': params.get('NORMALIZATION', False),
        'normalization_type': params.get('NORMALIZATION_TYPE', None),
        'data_augmentation': params.get('DATA_AUGMENTATION', False),
        'mean_substraction': params.get('MEAN_SUBTRACTION', False),
        'wo_da_patch_type': params.get('WO_DA_PATCH_TYPE', 'whole'),
        'da_patch_type': params.get('DA_PATCH_TYPE', 'resize_and_rndcrop'),
        'da_enhance_list': params.get('DA_ENHANCE_LIST', None),
        'heuristic': params.get('HEURISTIC', None),
        'search_pruning': params.get('SEARCH_PRUNING', False),
        'state_below_index': -1,
        'output_text_index': 0,
        'apply_tokenization': params.get('APPLY_TOKENIZATION', False),
        # Plain attribute lookup on the dataset; no need to eval() a
        # configuration string.
        'tokenize_f': getattr(
            dataset, params.get('TOKENIZATION_METHOD', 'tokenize_none')),
        'apply_detokenization': params.get('APPLY_DETOKENIZATION', True),
        'detokenize_f': getattr(
            dataset, params.get('DETOKENIZATION_METHOD', 'detokenize_none')),
        'coverage_penalty': params.get('COVERAGE_PENALTY', False),
        'length_penalty': params.get('LENGTH_PENALTY', False),
        'length_norm_factor': params.get('LENGTH_NORM_FACTOR', 0.0),
        'coverage_norm_factor': params.get('COVERAGE_NORM_FACTOR', 0.0),
        'pos_unk': False,
        'state_below_maxlen':
        -1 if params.get('PAD_ON_BATCH', True) else params.get(
            'MAX_OUTPUT_TEXT_LEN_TEST', 50),
        'output_max_length_depending_on_x': params.get('MAXLEN_GIVEN_X',
                                                       False),
        'output_max_length_depending_on_x_factor': params.get(
            'MAXLEN_GIVEN_X_FACTOR', 3),
        'output_min_length_depending_on_x': params.get('MINLEN_GIVEN_X',
                                                       False),
        'output_min_length_depending_on_x_factor': params.get(
            'MINLEN_GIVEN_X_FACTOR', 2),
        'attend_on_output': params.get(
            'ATTEND_ON_OUTPUT', 'transformer' in params['MODEL_TYPE'].lower()),
        'n_best_optimizer': params.get('N_BEST_OPTIMIZER', False)
    }


def interactive_simulation():
    """
    Simulate a user who corrects the model hypotheses character by character.

    For every sample of every split in ``args.splits`` a first hypothesis is
    generated. While it does not match a reference, the next character of the
    reference is taken as the user correction and a new hypothesis constrained
    by the validated prefix is generated. The number of corrections and mouse
    actions is accumulated and reported as ratios:

        * WSR:   corrections / words
        * MAR:   mouse actions / words
        * MAR_c: mouse actions / characters
        * KSMR:  (corrections + mouse actions) / characters

    Final hypotheses are appended to ``args.dest``; the first (uncorrected)
    hypotheses are appended to ``args.original_dest`` when given. With
    ``args.online`` the models are additionally trained on each corrected
    sample. A ratio whose denominator is zero is reported as 0.0.

    Command-line arguments are obtained through ``parse_args()``; the ones
    read here are: config, online, changes, verbose, dataset, tokenize_bpe,
    models, dest, original_dest, references, splits, tokenize_references,
    detokenize_bpe and prefix.

    :return: None
    :raises SystemExit: With code 1 if an override in ``args.changes`` is not of the form key=Value.
    :raises AssertionError: If the final hypothesis of a sample is not among its references.
    :raises Exception: If ``args.original_dest`` is set and params['SAMPLING_SAVE_MODE'] is not 'list'.
    """
    args = parse_args()
    # Update parameters
    if args.config is not None:
        logger.info('Reading parameters from %s.' % args.config)
        params = update_parameters({}, pkl2dict(args.config))
    else:
        logger.info('Reading parameters from config.py.')
        params = load_parameters()

    if args.online:
        from config_online import load_parameters as load_parameters_online
        online_parameters = load_parameters_online(params)
        params = update_parameters(params, online_parameters)

    # Apply the key=Value overrides given in the command line
    for arg in args.changes:
        try:
            k, v = arg.split('=')
        except ValueError:
            print(
                'Overwritten arguments must have the form key=Value. \n Currently are: %s'
                % str(args.changes))
            exit(1)
        try:
            params[k] = ast.literal_eval(v)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. a bare word or a path, for which
            # literal_eval raises SyntaxError): keep the raw string.
            params[k] = v

    check_params(params)
    if args.verbose:
        logging.info("params = " + str(params))
    dataset = loadDataset(args.dataset)
    # Dataset backwards compatibility
    bpe_separator = dataset.BPE_separator if hasattr(
        dataset,
        "BPE_separator") and dataset.BPE_separator is not None else u'@@'
    # Set tokenization method
    params['TOKENIZATION_METHOD'] = 'tokenize_bpe' if args.tokenize_bpe \
        else params.get('TOKENIZATION_METHOD', 'tokenize_none')
    # Build BPE tokenizer if necessary
    if 'bpe' in params['TOKENIZATION_METHOD'].lower():
        logger.info('Building BPE')
        if not dataset.BPE_built:
            dataset.build_bpe(params.get(
                'BPE_CODES_PATH',
                params['DATA_ROOT_PATH'] + '/training_codes.joint'),
                              separator=bpe_separator)
    # Build tokenization function
    tokenize_f = getattr(dataset,
                         params.get('TOKENIZATION_METHOD', 'tokenize_none'))

    # Training params (only needed for online learning)
    if args.online:
        params_training = _build_online_training_params(params, args.verbose)
    else:
        params_training = dict()

    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['OUTPUTS_IDS_DATASET'][0]]
    logger.info("<<< Using an ensemble of %d models >>>" % len(args.models))
    if args.online:
        # Load trainable model(s)
        logging.info('Loading models from %s' % str(args.models))
        model_instances = [
            Captioning_Model(params,
                             model_type=params['MODEL_TYPE'],
                             verbose=params['VERBOSE'],
                             model_name=params['MODEL_NAME'] + '_' + str(i),
                             vocabularies=dataset.vocabulary,
                             store_path=params['STORE_PATH'],
                             clear_dirs=False,
                             set_optimizer=False)
            for i in range(len(args.models))
        ]
        models = [
            updateModel(model, path, -1, full_path=True)
            for (model, path) in zip(model_instances, args.models)
        ]

        # Set additional inputs to models if using a custom loss function
        params['USE_CUSTOM_LOSS'] = 'PAS' in params['OPTIMIZER']
        if params['N_BEST_OPTIMIZER']:
            logging.info('Using N-best optimizer')

        models = build_online_models(models, params)
        online_trainer = OnlineTrainer(models,
                                       dataset,
                                       None,
                                       None,
                                       params_training,
                                       verbose=args.verbose)
    else:
        # Otherwise, load regular model(s)
        models = [loadModel(m, -1, full_path=True) for m in args.models]

    # Create/truncate the output files; results are appended to them later
    logger.info("<<< Storing corrected hypotheses into: %s >>>" %
                str(args.dest))
    with open(args.dest, 'w'):
        pass

    # Do we want to save the original sentences?
    if args.original_dest is not None:
        logger.info("<<< Storing original hypotheses into: %s >>>" %
                    str(args.original_dest))
        with open(args.original_dest, 'w'):
            pass

    if args.references is not None:
        # File with post-edited (or reference) sentences.
        with codecs.open(args.references, 'r', encoding='utf-8') as ftrg:
            all_references = ftrg.read().split('\n')
        if all_references[-1] == u'':
            all_references = all_references[:-1]
        # NOTE(review): each entry read here is a single string, while the
        # loop below indexes/iterates `references` like a collection of
        # references — confirm the expected format of this file.

    # Get word2index and index2word dictionaries
    index2word_y = dataset.vocabulary[params['OUTPUTS_IDS_DATASET']
                                      [0]]['idx2words']
    word2index_y = dataset.vocabulary[params['OUTPUTS_IDS_DATASET']
                                      [0]]['words2idx']
    unk_id = dataset.extra_words['<unk>']

    # Initialize counters
    total_errors = 0
    total_words = 0
    total_chars = 0
    total_mouse_actions = 0
    try:
        for s in args.splits:
            # Apply model predictions
            params_prediction = _build_prediction_params(params, dataset, s)

            # Build interactive sampler
            interactive_beam_searcher = InteractiveBeamSearchSampler(
                models,
                dataset,
                params_prediction,
                excluded_words=None,
                verbose=args.verbose)
            start_time = time.time()

            if args.verbose:
                logging.info("Params prediction = " + str(params_prediction))
                if args.online:
                    logging.info("Params training = " + str(params_training))
            n_samples = getattr(dataset, 'len_' + s)
            if args.references is None:
                all_references = dataset.extra_variables[s][
                    params['OUTPUTS_IDS_DATASET'][0]]

            # Start to translate the source file interactively
            for n_sample in range(n_samples):
                errors_sentence = 0
                mouse_actions_sentence = 0
                hypothesis_number = 0
                # Load data from dataset
                current_input = dataset.getX_FromIndices(
                    s, [n_sample],
                    normalization_type=params_prediction.get(
                        'normalization_type'),
                    normalization=params_prediction.get('normalize', False),
                    dataAugmentation=params_prediction.get(
                        'data_augmentation', False),
                    wo_da_patch_type=params_prediction.get(
                        'wo_da_patch_type', 'whole'),
                    da_patch_type=params_prediction.get(
                        'da_patch_type', 'resize_and_rndcrop'),
                    da_enhance_list=params_prediction.get(
                        'da_enhance_list', None))[0][0]

                # Load references
                references = all_references[n_sample]

                tokenized_references = list(map(
                    tokenize_f,
                    references)) if args.tokenize_references else references

                # Get reference as desired by the user, i.e. detokenized if necessary
                reference = list(map(params_prediction['detokenize_f'], tokenized_references)) if \
                    args.detokenize_bpe else tokenized_references

                # Detokenize line for nicer logging :)
                logger.debug(u'\n\nProcessing sample %d' % (n_sample + 1))
                logger.debug(u'Target: %s' % reference)

                # 1. Get a first hypothesis
                trans_indices, costs, alphas = interactive_beam_searcher.sample_beam_search_interactive(
                    current_input)

                # 1.2 Decode hypothesis
                hypothesis = decode_predictions_beam_search([trans_indices],
                                                            index2word_y,
                                                            pad_sequences=True,
                                                            verbose=0)[0]
                # 1.3 Store result (optional)
                hypothesis = params_prediction['detokenize_f'](hypothesis) \
                    if params_prediction.get('apply_detokenization', False) else hypothesis
                if args.original_dest is not None:
                    if params['SAMPLING_SAVE_MODE'] == 'list':
                        list2file(args.original_dest, [hypothesis],
                                  permission='a')
                    else:
                        raise Exception(
                            'Only "list" is allowed in "SAMPLING_SAVE_MODE"')
                logger.debug(u'Hypo_%d: %s' % (hypothesis_number, hypothesis))

                # 2.0 Interactive translation
                # 2.1 If the sentence is correct, it is validated as is
                if hypothesis not in tokenized_references:
                    # 2.2 Wrong hypothesis -> Interactively translate the sentence
                    last_correct_pos = 0
                    while True:
                        # 2.2.1 Empty data structures for the next sentence
                        fixed_words_user = OrderedDict()
                        unk_words_dict = OrderedDict()
                        isle_indices = []
                        unks_in_isles = []

                        if args.prefix:
                            # 2.2.2 Compute longest common character prefix (LCCP)
                            reference_idx, next_correction_pos, validated_prefix = common_prefixes(
                                hypothesis, tokenized_references)
                        else:
                            # 2.2.2 Compute common character segments
                            # TODO
                            # NOTE(review): this branch never assigns
                            # `reference_idx`, which is read right below —
                            # segment-based mode looks unfinished.
                            next_correction_pos, validated_prefix, validated_segments = common_segments(
                                hypothesis, reference)
                        reference = tokenized_references[reference_idx]
                        if next_correction_pos == len(reference):
                            # The whole reference has been validated
                            break
                        # 2.2.3 Get next correction by checking against the reference
                        next_correction = reference[next_correction_pos]

                        # 2.2.4 Tokenize the prefix properly (possibly applying BPE)
                        tokenized_validated_prefix = tokenize_f(
                            validated_prefix + next_correction)

                        # 2.2.5 Validate words
                        for pos, word in enumerate(
                                tokenized_validated_prefix.split()):
                            fixed_words_user[pos] = word2index_y.get(
                                word, unk_id)
                            if word2index_y.get(word) is None:
                                unk_words_dict[pos] = word

                        # 2.2.6 Constrain search for the last word
                        last_user_word_pos = list(fixed_words_user.keys())[-1]
                        if next_correction != u' ':
                            # The last word may be incomplete: allow every
                            # vocabulary word that starts with it.
                            last_user_word = tokenized_validated_prefix.split(
                            )[-1]
                            filtered_idx2word = {
                                candidate_idx: candidate_word
                                for candidate_word, candidate_idx in
                                word2index_y.items()
                                if candidate_word.startswith(last_user_word)
                            }
                            if filtered_idx2word:
                                # The word is left to the constrained search
                                # instead of being fixed.
                                del fixed_words_user[last_user_word_pos]
                                unk_words_dict.pop(last_user_word_pos, None)
                        else:
                            filtered_idx2word = dict()

                        logger.debug(u'"%s" to character %d.' %
                                     (next_correction, next_correction_pos))

                        # 2.2.7 Generate a hypothesis compatible with the feedback provided by the user
                        hypothesis = generate_constrained_hypothesis(
                            interactive_beam_searcher, current_input,
                            fixed_words_user, params_prediction, args,
                            isle_indices, filtered_idx2word, index2word_y,
                            None, None, None, unk_words_dict.keys(),
                            unk_words_dict.values(), unks_in_isles)
                        hypothesis_number += 1
                        hypothesis = u' '.join(
                            hypothesis)  # Hypothesis is unicode
                        hypothesis = params_prediction['detokenize_f'](hypothesis) \
                            if args.detokenize_bpe else hypothesis
                        logger.debug(u'Target: %s' % reference)
                        logger.debug(u"Hypo_%d: %s" %
                                     (hypothesis_number, hypothesis))
                        # 2.2.8 Add a keystroke
                        errors_sentence += 1
                        # 2.2.9 Add a mouse action if we moved the pointer
                        if next_correction_pos - last_correct_pos > 1:
                            mouse_actions_sentence += 1
                        last_correct_pos = next_correction_pos

                    # 2.3 Final check: The reference is a subset of the hypothesis: Cut the hypothesis
                    if len(reference) < len(hypothesis):
                        hypothesis = hypothesis[:len(reference)]
                        errors_sentence += 1
                        logger.debug(u"Cutting hypothesis")

                # 2.4 Security assertion
                assert hypothesis in references, "Error: The final hypothesis does not match with the reference! \n" \
                                                "\t Split: %s \n" \
                                                "\t Sentence: %d \n" \
                                                "\t Hypothesis: %s\n" \
                                                "\t Reference: %s" % (s, n_sample + 1,
                                                                      hypothesis,
                                                                      reference)
                # 3. Update user effort counters
                mouse_actions_sentence += 1  # This +1 is the validation action
                words_sentence = len(hypothesis.split())
                chars_sentence = len(hypothesis)
                total_errors += errors_sentence
                total_words += words_sentence
                total_chars += chars_sentence
                total_mouse_actions += mouse_actions_sentence

                # 3.1 Log some info. Sentence-level ratios use the same
                # denominators as the accumulated ones (words for WSR/MAR,
                # characters for MAR_c/KSMR).
                logger.debug(u"Final hypotesis: %s" % hypothesis)
                logger.debug(
                    u"%d errors. "
                    u"Sentence WSR: %4f. "
                    u"Sentence mouse strokes: %d "
                    u"Sentence MAR: %4f. "
                    u"Sentence MAR_c: %4f. "
                    u"Sentence KSMR: %4f. "
                    u"Accumulated (should only be considered for debugging purposes!) "
                    u"WSR: %4f. "
                    u"MAR: %4f. "
                    u"MAR_c: %4f. "
                    u"KSMR: %4f.\n\n\n\n" %
                    (errors_sentence,
                     _safe_ratio(errors_sentence, words_sentence),
                     mouse_actions_sentence,
                     _safe_ratio(mouse_actions_sentence, words_sentence),
                     _safe_ratio(mouse_actions_sentence, chars_sentence),
                     _safe_ratio(errors_sentence + mouse_actions_sentence,
                                 chars_sentence),
                     _safe_ratio(total_errors, total_words),
                     _safe_ratio(total_mouse_actions, total_words),
                     _safe_ratio(total_mouse_actions, total_chars),
                     _safe_ratio(total_errors + total_mouse_actions,
                                 total_chars)))
                # 4. If we are performing OL after each correct sample:
                if args.online:
                    # 4.1 Compute model inputs
                    # 4.1.1 Source text -> Already computed (used for the INMT process)
                    # 4.1.2 State below
                    state_below = dataset.loadText(
                        [reference],
                        vocabularies=dataset.vocabulary[
                            params['OUTPUTS_IDS_DATASET'][0]],
                        max_len=params['MAX_OUTPUT_TEXT_LEN_TEST'],
                        offset=1,
                        fill=dataset.fill_text[params['INPUTS_IDS_DATASET']
                                               [-1]],
                        pad_on_batch=dataset.pad_on_batch[
                            params['INPUTS_IDS_DATASET'][-1]],
                        words_so_far=False,
                        loading_X=True)[0]

                    # 4.1.3 Ground truth sample -> Interactively translated sentence
                    trg_seq = dataset.loadTextOneHot(
                        [reference],
                        vocabularies=dataset.vocabulary[
                            params['OUTPUTS_IDS_DATASET'][0]],
                        vocabulary_len=dataset.vocabulary_len[
                            params['OUTPUTS_IDS_DATASET'][0]],
                        max_len=params['MAX_OUTPUT_TEXT_LEN_TEST'],
                        offset=0,
                        fill=dataset.fill_text[params['OUTPUTS_IDS_DATASET']
                                               [0]],
                        pad_on_batch=dataset.pad_on_batch[
                            params['OUTPUTS_IDS_DATASET'][0]],
                        words_so_far=False,
                        sample_weights=params['SAMPLE_WEIGHTS'],
                        loading_X=False)
                    # 4.2 Train online!
                    online_trainer.train_online(
                        [np.asarray([current_input]), state_below],
                        trg_seq,
                        trg_words=[reference])
                # 5 Write correct sentences into a file
                list2file(args.dest, [hypothesis], permission='a')

                if (n_sample + 1) % 50 == 0:
                    logger.info(u"%d sentences processed" % (n_sample + 1))
                    logger.info(u"Current speed is {} per sentence".format(
                        (time.time() - start_time) / (n_sample + 1)))
                    logger.info(u"Current WSR is: %f" %
                                _safe_ratio(total_errors, total_words))
                    logger.info(u"Current MAR is: %f" %
                                _safe_ratio(total_mouse_actions, total_words))
                    logger.info(u"Current MAR_c is: %f" %
                                _safe_ratio(total_mouse_actions, total_chars))
                    logger.info(u"Current KSMR is: %f" %
                                _safe_ratio(total_errors + total_mouse_actions,
                                            total_chars))
        # 6. Final!
        # 6.1 Log some information
        print(u"Total number of errors:", total_errors)
        print(u"Total number selections", total_mouse_actions)
        _print_effort_ratios(total_errors, total_mouse_actions, total_words,
                             total_chars)

    except KeyboardInterrupt:
        # Report the effort accumulated so far; the ratios are guarded so an
        # interruption before the first finished sentence does not crash here.
        print(u'Interrupted!')
        print(u"Total number of corrections (up to now):", total_errors)
        _print_effort_ratios(total_errors, total_mouse_actions, total_words,
                             total_chars)