Beispiel #1
0
def train_model(params, load_dataset=None):
    """
    Training function. Sets the training parameters from params, builds or reloads the model and
    launches the training.

    ``params`` is modified in place: INPUT_VOCABULARY_SIZE and OUTPUT_VOCABULARY_SIZE are always
    set and, when resuming (``params['RELOAD'] > 0``), EPOCH_OFFSET is set as well. The resulting
    configuration is stored at ``params['STORE_PATH'] + '/config'``.

    :param params: Dictionary of network hyperparameters.
    :param load_dataset: Stored dataset to load. If None, the dataset is built from the parameters
                         or, when resuming without REBUILD_DATASET, loaded from DATASET_STORE_PATH
                         and updated with the files listed in TEXT_FILES.
    :return: None
    """
    check_params(params)

    resuming = params['RELOAD'] > 0

    def _epoch_offset(current_dataset):
        # With RELOAD_EPOCH set, RELOAD is used directly as the epoch offset; otherwise it is
        # scaled by BATCH_SIZE / len_train (presumably RELOAD counts updates then - confirm).
        if params['RELOAD_EPOCH']:
            return params['RELOAD']
        return int(params['RELOAD'] * params['BATCH_SIZE'] / current_dataset.len_train)

    # Load data
    if resuming:
        logging.info('Resuming training.')
        if load_dataset is not None:
            logging.info('Reloading and using dataset.')
            dataset = loadDataset(load_dataset)
        elif params['REBUILD_DATASET']:
            logging.info('Rebuilding dataset.')
            dataset = build_dataset(params)
        else:
            logging.info('Updating dataset.')
            dataset_name = params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN']
            dataset = loadDataset(params['DATASET_STORE_PATH'] + '/Dataset_' + dataset_name + '.pkl')
            # The offset is taken from the stored dataset, before the update below runs.
            params['EPOCH_OFFSET'] = _epoch_offset(dataset)
            # items() exists on both Python 2 and 3 dicts; iteritems() is Python 2 only.
            for split, filename in params['TEXT_FILES'].items():
                text_prefix = params['DATA_ROOT_PATH'] + '/' + filename
                dataset = update_dataset_from_file(
                    dataset,
                    text_prefix + params['SRC_LAN'],
                    params,
                    splits=[split],
                    output_text_filename=text_prefix + params['TRG_LAN'],
                    remove_outputs=False,
                    compute_state_below=True,
                    recompute_references=True)
                dataset.name = dataset_name
            saveDataset(dataset, params['DATASET_STORE_PATH'])
    elif load_dataset is None:
        dataset = build_dataset(params)
    else:
        dataset = loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]

    # Build model. The optimizer is set and the directories are cleared only when RELOAD is 0.
    set_optimizer = params['RELOAD'] == 0
    clear_dirs = params['RELOAD'] == 0

    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 set_optimizer=set_optimizer,
                                 clear_dirs=clear_dirs)

    # Define the inputs and outputs mapping from our Dataset instance to our model
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        pos_source = dataset.ids_inputs.index(id_in)
        inputMapping[nmt_model.ids_inputs[i]] = pos_source
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        pos_target = dataset.ids_outputs.index(id_out)
        outputMapping[nmt_model.ids_outputs[i]] = pos_target
    nmt_model.setOutputsMapping(outputMapping)

    if resuming:
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        # Only fill in the offset if it was neither configured nor computed while updating the dataset.
        if params.get('EPOCH_OFFSET') is None:
            params['EPOCH_OFFSET'] = _epoch_offset(dataset)

    # Store configuration as pkl
    dict2pkl(params, params['STORE_PATH'] + '/config')

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        # LR decay parameters
        'lr_decay': params.get('LR_DECAY', None),
        'reduce_each_epochs': params.get('LR_REDUCE_EACH_EPOCHS', True),
        'start_reduction_on_epoch': params.get('LR_START_REDUCTION_ON_EPOCH', 0),
        'lr_gamma': params.get('LR_GAMMA', 0.9),
        'lr_reducer_type': params.get('LR_REDUCER_TYPE', 'linear'),
        'lr_reducer_exp_base': params.get('LR_REDUCER_EXP_BASE', 0),
        'lr_half_life': params.get('LR_HALF_LIFE', 50000),
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        # Early stopping parameters; the stop metric is only forwarded when EARLY_STOP is enabled.
        'patience': params.get('PATIENCE', 0),
        'metric_check': params.get('STOP_METRIC', None) if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0),
        'tensorboard': params.get('TENSORBOARD', False),
        'tensorboard_params': {
            'log_dir': params.get('LOG_DIR', 'tensorboard_logs'),
            'histogram_freq': params.get('HISTOGRAM_FREQ', 0),
            'batch_size': params.get('TENSORBOARD_BATCH_SIZE', params['BATCH_SIZE']),
            'write_graph': params.get('WRITE_GRAPH', True),
            'write_grads': params.get('WRITE_GRADS', False),
            'write_images': params.get('WRITE_IMAGES', False),
            'embeddings_freq': params.get('EMBEDDINGS_FREQ', 0),
            'embeddings_layer_names': params.get('EMBEDDINGS_LAYER_NAMES', None),
            'embeddings_metadata': params.get('EMBEDDINGS_METADATA', None),
            'label_word_embeddings_with_vocab': params.get('LABEL_WORD_EMBEDDINGS_WITH_VOCAB', False),
            'word_embeddings_labels': params.get('WORD_EMBEDDINGS_LABELS', None),
        }
    }
    nmt_model.trainNet(dataset, training_params)

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(
        time_difference, time_difference / 60.0))
def train_model(params,
                weights_dict,
                load_dataset=None,
                trainable_pred=True,
                trainable_est=True,
                weights_path=None):
    """
    Training function. Sets the training parameters from params, builds or reloads the model and
    launches the training.

    ``params`` is modified in place (INPUT_VOCABULARY_SIZE, OUTPUT_VOCABULARY_SIZE and, when
    ``params['RELOAD'] != 0``, EPOCH_OFFSET) and then stored at
    ``params['STORE_PATH'] + '/config'``.

    :param params: Dictionary of network hyperparameters.
    :param weights_dict: Dictionary mapping layer names to weights, or None. When given, every
                         model layer whose name is a key receives those weights before training,
                         and after training the weights of all layers are written back into it.
    :param load_dataset: Stored dataset to load. If None, the dataset is built from the parameters
                         or, when resuming without REBUILD_DATASET, loaded from DATASET_STORE_PATH
                         and updated with the files listed in TEXT_FILES.
    :param trainable_pred: Forwarded unchanged to TranslationModel.
    :param trainable_est: Forwarded unchanged to TranslationModel.
    :param weights_path: Forwarded unchanged to TranslationModel.
    :return: None
    """
    check_params(params)

    def _build_dataset():
        # When PRED_VOCAB is set, the vocabulary of the dataset stored there is reused.
        pred_vocab = params.get('PRED_VOCAB', None)
        if pred_vocab is None:
            return build_dataset(params)
        dataset_voc = loadDataset(pred_vocab)
        return build_dataset(params, dataset_voc.vocabulary, dataset_voc.vocabulary_len)

    # Load data
    if params['RELOAD'] > 0:
        logging.info('Resuming training.')
        if load_dataset is not None:
            logging.info('Reloading and using dataset.')
            dataset = loadDataset(load_dataset)
        elif params['REBUILD_DATASET']:
            logging.info('Rebuilding dataset.')
            dataset = _build_dataset()
        else:
            logging.info('Updating dataset.')
            dataset_name = params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN']
            dataset = loadDataset(params['DATASET_STORE_PATH'] + '/Dataset_' + dataset_name + '.pkl')

            # items() exists on both Python 2 and 3 dicts; iteritems() is Python 2 only.
            for split, filename in params['TEXT_FILES'].items():
                text_prefix = params['DATA_ROOT_PATH'] + '/' + filename
                dataset = update_dataset_from_file(
                    dataset,
                    text_prefix + params['SRC_LAN'],
                    params,
                    splits=[split],
                    output_text_filename=text_prefix + params['TRG_LAN'],
                    remove_outputs=False,
                    compute_state_below=True,
                    recompute_references=True)
                dataset.name = dataset_name
            saveDataset(dataset, params['DATASET_STORE_PATH'])
    elif load_dataset is None:
        dataset = _build_dataset()
    else:
        dataset = loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    # The output vocabulary size is read from the fixed 'target_text' id, not from OUTPUTS_IDS_DATASET.
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len['target_text']

    # Build model
    new_model = params['RELOAD'] == 0
    # A new model clears its directories; a resumed one postpones the optimizer until the stored
    # weights are loaded. Only the keyword relevant to each case is passed, so the other one keeps
    # TranslationModel's own default.
    mode_kwargs = {'clear_dirs': True} if new_model else {'set_optimizer': False}
    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 trainable_pred=trainable_pred,
                                 trainable_est=trainable_est,
                                 weights_path=weights_path,
                                 **mode_kwargs)

    # Define the inputs and outputs mapping from our Dataset instance to our model
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        pos_source = dataset.ids_inputs.index(id_in)
        inputMapping[nmt_model.ids_inputs[i]] = pos_source
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        pos_target = dataset.ids_outputs.index(id_out)
        outputMapping[nmt_model.ids_outputs[i]] = pos_target
    nmt_model.setOutputsMapping(outputMapping)

    if not new_model:  # resume from previously trained model
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        # With RELOAD_EPOCH set, RELOAD is used directly as the epoch offset; otherwise it is
        # scaled by BATCH_SIZE / len_train (presumably RELOAD counts updates then - confirm).
        if params['RELOAD_EPOCH']:
            params['EPOCH_OFFSET'] = params['RELOAD']
        else:
            params['EPOCH_OFFSET'] = int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Store configuration as pkl
    dict2pkl(params, params['STORE_PATH'] + '/config')

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        # LR decay parameters
        'lr_decay': params.get('LR_DECAY', None),
        'reduce_each_epochs': params.get('LR_REDUCE_EACH_EPOCHS', True),
        'start_reduction_on_epoch': params.get('LR_START_REDUCTION_ON_EPOCH', 0),
        'lr_gamma': params.get('LR_GAMMA', 0.9),
        'lr_reducer_type': params.get('LR_REDUCER_TYPE', 'linear'),
        'lr_reducer_exp_base': params.get('LR_REDUCER_EXP_BASE', 0),
        'lr_half_life': params.get('LR_HALF_LIFE', 50000),
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        # Early stopping parameters; the stop metric is only forwarded when EARLY_STOP is enabled.
        'patience': params.get('PATIENCE', 0),
        'metric_check': params.get('STOP_METRIC', None) if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0)
    }

    # Seed the layers that have an entry in weights_dict before training starts.
    if weights_dict is not None:
        for layer in nmt_model.model.layers:
            if layer.name in weights_dict:
                layer.set_weights(weights_dict[layer.name])

    nmt_model.trainNet(dataset, training_params)

    # Export the weights of every layer back into the caller's dictionary.
    if weights_dict is not None:
        for layer in nmt_model.model.layers:
            weights_dict[layer.name] = layer.get_weights()

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(
        time_difference, time_difference / 60.0))
Beispiel #3
0
def train_model(params, load_dataset=None):
    """
    Build a new translation model, or reload a stored one, and train it.

    The vocabulary sizes (and, when reloading, EPOCH_OFFSET) are written into params.
    :param params: Dictionary of network hyperparameters.
    :param load_dataset: Stored dataset to load; when None the dataset is built from params.
    :return: None
    """

    if params['RELOAD'] > 0:
        logging.info('Resuming training.')

    check_params(params)

    # Dataset: load the stored one if given, otherwise build it.
    dataset = build_dataset(params) if load_dataset is None else loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]

    # Model: a new one takes the constructor defaults, a reloaded one skips the optimizer
    # setup and the directory clearing.
    starting_fresh = params['RELOAD'] == 0
    if starting_fresh:
        construction_flags = {}
    else:
        construction_flags = {'set_optimizer': False, 'clear_dirs': False}

    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 **construction_flags)
    if starting_fresh:
        # The configuration is only stored for a newly built model.
        dict2pkl(params, params['STORE_PATH'] + '/config')

    # Link every model input/output id to the position of its id in the dataset.
    input_links = {}
    for slot, dataset_id in enumerate(params['INPUTS_IDS_DATASET']):
        position = dataset.ids_inputs.index(dataset_id)
        input_links[nmt_model.ids_inputs[slot]] = position
    nmt_model.setInputsMapping(input_links)

    output_links = {}
    for slot, dataset_id in enumerate(params['OUTPUTS_IDS_DATASET']):
        position = dataset.ids_outputs.index(dataset_id)
        output_links[nmt_model.ids_outputs[slot]] = position
    nmt_model.setOutputsMapping(output_links)

    if not starting_fresh:
        # Resume from the previously trained model.
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        if params['RELOAD_EPOCH']:
            params['EPOCH_OFFSET'] = params['RELOAD']
        else:
            params['EPOCH_OFFSET'] = int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        'lr_decay': params['LR_DECAY'],
        'lr_gamma': params['LR_GAMMA'],
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        # Early stopping: the metric is only passed on when EARLY_STOP is enabled.
        'patience': params.get('PATIENCE', 0),
        'metric_check': params.get('STOP_METRIC', None) if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0)
    }
    nmt_model.trainNet(dataset, training_params)

    elapsed = timer() - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(elapsed, elapsed / 60.0))
def apply_NMT_model(params, load_dataset=None):
    """
    Sample from a previously trained model and evaluate it on the sets in params['EVAL_ON_SETS'].

    The dataset is always built from params; when params['PRED_VOCAB'] is set, the vocabulary of
    the dataset stored there is reused. The model is constructed, updated from
    params['STORE_PATH'] at the point given by params['RELOAD'], and one evaluation callback is
    run per entry of params['EVAL_ON_SETS'].

    :param params: Dictionary of network hyperparameters. INPUT_VOCABULARY_SIZE and
                   OUTPUT_VOCABULARY_SIZE are written into it.
    :param load_dataset: Currently unused; kept so existing callers keep working.
    :return: None
    """
    pred_vocab = params.get('PRED_VOCAB', None)
    if pred_vocab is not None:
        dataset_voc = loadDataset(pred_vocab)
        dataset = build_dataset(params, dataset_voc.vocabulary,
                                dataset_voc.vocabulary_len)
    else:
        dataset = build_dataset(params)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    # The output vocabulary size is read from the fixed 'target_text' id, not from OUTPUTS_IDS_DATASET.
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len['target_text']

    # Load model: build the architecture first, then update it from the stored model.
    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 set_optimizer=False,
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 trainable_pred=True,
                                 trainable_est=True,
                                 weights_path=None)
    nmt_model = updateModel(nmt_model,
                            params['STORE_PATH'],
                            params['RELOAD'],
                            reload_epoch=params['RELOAD_EPOCH'])
    nmt_model.setParams(params)
    nmt_model.setOptimizer()

    # Define the inputs and outputs mapping from our Dataset instance to our model
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        pos_source = dataset.ids_inputs.index(id_in)
        inputMapping[nmt_model.ids_inputs[i]] = pos_source
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        pos_target = dataset.ids_outputs.index(id_out)
        outputMapping[nmt_model.ids_outputs[i]] = pos_target
    nmt_model.setOutputsMapping(outputMapping)
    # NOTE(review): second setOptimizer() call kept from the original; presumably redundant - confirm.
    nmt_model.setOptimizer()

    # Vocabularies used to turn indices back into words; they do not depend on the evaluated set.
    input_text_id = params['INPUTS_IDS_DATASET'][0]
    vocab_x = dataset.vocabulary[input_text_id]['idx2words']
    # The target-side vocabulary is taken from the second *input* id of the dataset.
    vocab_y = dataset.vocabulary[params['INPUTS_IDS_DATASET'][1]]['idx2words']
    no_ref = params.get('NO_REF', False)

    for s in params["EVAL_ON_SETS"]:
        # Evaluate training
        extra_vars = {
            'language': params.get('TRG_LAN', 'en'),
            'n_parallel_loaders': params['PARALLEL_LOADERS'],
            # The (de)tokenizers are looked up by name on the dataset. getattr replaces the
            # former eval('dataset.' + name), so the configured value must be a plain attribute name.
            'tokenize_f': getattr(dataset, params['TOKENIZATION_METHOD']),
            'detokenize_f': getattr(dataset, params['DETOKENIZATION_METHOD']),
            'apply_detokenization': params['APPLY_DETOKENIZATION'],
            'tokenize_hypotheses': params['TOKENIZE_HYPOTHESES'],
            'tokenize_references': params['TOKENIZE_REFERENCES']
        }
        extra_vars[s] = dict()
        if not no_ref:
            extra_vars[s]['references'] = dataset.extra_variables[s][
                params['OUTPUTS_IDS_DATASET'][0]]

        if params['BEAM_SEARCH']:
            extra_vars['beam_size'] = params.get('BEAM_SIZE', 6)
            extra_vars['state_below_index'] = params.get('BEAM_SEARCH_COND_INPUT', -1)
            extra_vars['maxlen'] = params.get('MAX_OUTPUT_TEXT_LEN_TEST', 30)
            extra_vars['optimized_search'] = params.get('OPTIMIZED_SEARCH', True)
            extra_vars['model_inputs'] = params['INPUTS_IDS_MODEL']
            extra_vars['model_outputs'] = params['OUTPUTS_IDS_MODEL']
            extra_vars['dataset_inputs'] = params['INPUTS_IDS_DATASET']
            extra_vars['dataset_outputs'] = params['OUTPUTS_IDS_DATASET']
            extra_vars['normalize_probs'] = params.get('NORMALIZE_SAMPLING', False)
            extra_vars['search_pruning'] = params.get('SEARCH_PRUNING', False)
            extra_vars['alpha_factor'] = params.get('ALPHA_FACTOR', 1.0)
            extra_vars['coverage_penalty'] = params.get('COVERAGE_PENALTY', False)
            extra_vars['length_penalty'] = params.get('LENGTH_PENALTY', False)
            extra_vars['length_norm_factor'] = params.get('LENGTH_NORM_FACTOR', 0.0)
            extra_vars['coverage_norm_factor'] = params.get('COVERAGE_NORM_FACTOR', 0.0)
            extra_vars['pos_unk'] = params['POS_UNK']
            extra_vars['output_max_length_depending_on_x'] = params.get('MAXLEN_GIVEN_X', True)
            extra_vars['output_max_length_depending_on_x_factor'] = params.get('MAXLEN_GIVEN_X_FACTOR', 3)
            extra_vars['output_min_length_depending_on_x'] = params.get('MINLEN_GIVEN_X', True)
            extra_vars['output_min_length_depending_on_x_factor'] = params.get('MINLEN_GIVEN_X_FACTOR', 2)

            if params['POS_UNK']:
                extra_vars['heuristic'] = params['HEURISTIC']
                if params['HEURISTIC'] > 0:
                    extra_vars['mapping'] = dataset.mapping

        # NOTE(review): set_name receives every configured set although extra_vars only carries
        # the entry for `s`; kept as in the original - confirm whether [s] was intended.
        callback_metric = PrintPerformanceMetricOnEpochEndOrEachNUpdates(
            nmt_model,
            dataset,
            gt_id=params['OUTPUTS_IDS_DATASET'][0],
            metric_name=params['METRICS'],
            set_name=params['EVAL_ON_SETS'],
            batch_size=params['BATCH_SIZE'],
            each_n_epochs=params['EVAL_EACH'],
            extra_vars=extra_vars,
            reload_epoch=params['RELOAD'],
            is_text=True,
            input_text_id=input_text_id,
            save_path=nmt_model.model_path,
            index2word_y=vocab_y,
            index2word_x=vocab_x,
            sampling_type=params['SAMPLING'],
            beam_search=params['BEAM_SEARCH'],
            start_eval_on_epoch=params['START_EVAL_ON_EPOCH'],
            write_samples=True,
            write_type=params['SAMPLING_SAVE_MODE'],
            eval_on_epochs=params['EVAL_EACH_EPOCHS'],
            save_each_evaluation=False,
            verbose=params['VERBOSE'],
            # Same default as the references check above, so a missing NO_REF key no longer raises.
            no_ref=no_ref)

        callback_metric.evaluate(
            params['RELOAD'],
            counter_name='epoch' if params['EVAL_EACH_EPOCHS'] else 'update')