def train_model(params, load_dataset=None):
    """
    Training function. Sets the training parameters from params. Builds or loads
    the model and launches the training.
    :param params: Dictionary of network hyperparameters.
    :param load_dataset: Load dataset from file or build it from the parameters.
    :return: None
    """
    check_params(params)

    if params['RELOAD'] > 0:
        logging.info('Resuming training.')
        # Load data
        if load_dataset is None:
            if params['REBUILD_DATASET']:
                logging.info('Rebuilding dataset.')
                dataset = build_dataset(params)
            else:
                logging.info('Updating dataset.')
                dataset = loadDataset(params['DATASET_STORE_PATH'] +
                                      '/Dataset_' + params['DATASET_NAME'] +
                                      '_' + params['SRC_LAN'] + params['TRG_LAN'] + '.pkl')
                params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
                    int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)
                for split, filename in params['TEXT_FILES'].items():
                    dataset = update_dataset_from_file(
                        dataset,
                        params['DATA_ROOT_PATH'] + '/' + filename + params['SRC_LAN'],
                        params,
                        splits=list([split]),
                        output_text_filename=params['DATA_ROOT_PATH'] + '/' + filename + params['TRG_LAN'],
                        remove_outputs=False,
                        compute_state_below=True,
                        recompute_references=True)
                dataset.name = params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN']
                saveDataset(dataset, params['DATASET_STORE_PATH'])
        else:
            logging.info('Reloading and using dataset.')
            dataset = loadDataset(load_dataset)
    else:
        # Load data
        if load_dataset is None:
            dataset = build_dataset(params)
        else:
            dataset = loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]

    # Build model
    set_optimizer = True if params['RELOAD'] == 0 else False
    clear_dirs = True if params['RELOAD'] == 0 else False

    # build new model
    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 set_optimizer=set_optimizer,
                                 clear_dirs=clear_dirs)

    # Define the inputs and outputs mapping from our Dataset instance to our model
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        pos_source = dataset.ids_inputs.index(id_in)
        id_dest = nmt_model.ids_inputs[i]
        inputMapping[id_dest] = pos_source
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        pos_target = dataset.ids_outputs.index(id_out)
        id_dest = nmt_model.ids_outputs[i]
        outputMapping[id_dest] = pos_target
    nmt_model.setOutputsMapping(outputMapping)

    if params['RELOAD'] > 0:
        nmt_model = updateModel(nmt_model, params['STORE_PATH'], params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        if params.get('EPOCH_OFFSET') is None:
            params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
                int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Store configuration as pkl
    dict2pkl(params, params['STORE_PATH'] + '/config')

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        'lr_decay': params.get('LR_DECAY', None),  # LR decay parameters
        'reduce_each_epochs': params.get('LR_REDUCE_EACH_EPOCHS', True),
        'start_reduction_on_epoch': params.get('LR_START_REDUCTION_ON_EPOCH', 0),
        'lr_gamma': params.get('LR_GAMMA', 0.9),
        'lr_reducer_type': params.get('LR_REDUCER_TYPE', 'linear'),
        'lr_reducer_exp_base': params.get('LR_REDUCER_EXP_BASE', 0),
        'lr_half_life': params.get('LR_HALF_LIFE', 50000),
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        'patience': params.get('PATIENCE', 0),  # early stopping parameters
        'metric_check': params.get('STOP_METRIC', None) if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0),
        'tensorboard': params.get('TENSORBOARD', False),
        'tensorboard_params': {
            'log_dir': params.get('LOG_DIR', 'tensorboard_logs'),
            'histogram_freq': params.get('HISTOGRAM_FREQ', 0),
            'batch_size': params.get('TENSORBOARD_BATCH_SIZE', params['BATCH_SIZE']),
            'write_graph': params.get('WRITE_GRAPH', True),
            'write_grads': params.get('WRITE_GRADS', False),
            'write_images': params.get('WRITE_IMAGES', False),
            'embeddings_freq': params.get('EMBEDDINGS_FREQ', 0),
            'embeddings_layer_names': params.get('EMBEDDINGS_LAYER_NAMES', None),
            'embeddings_metadata': params.get('EMBEDDINGS_METADATA', None),
            'label_word_embeddings_with_vocab': params.get('LABEL_WORD_EMBEDDINGS_WITH_VOCAB', False),
            'word_embeddings_labels': params.get('WORD_EMBEDDINGS_LABELS', None),
        }
    }
    nmt_model.trainNet(dataset, training_params)

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('Total time: {0:.2f}s ({1:.2f}m).'.format(time_difference, time_difference / 60.0))
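
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original file). It assumes
# the project's config module exposes load_parameters(), as in nmt-keras, and
# that the returned params dictionary satisfies check_params(). The dataset
# .pkl path below is purely a hypothetical example.
def _example_train_from_config(train_fn=train_model):  # hypothetical helper
    # Default-bind the train_model variant defined just above, since this file
    # contains several variants and late binding would pick the last one.
    from config import load_parameters  # assumed project config module
    params = load_parameters()
    # Fresh run: RELOAD == 0 builds the Dataset from params and trains from scratch.
    train_fn(params)
    # To reuse a previously stored Dataset object instead of (re)building it:
    # train_fn(params, load_dataset=params['DATASET_STORE_PATH'] + '/Dataset_mydata_enfr.pkl')
# ---------------------------------------------------------------------------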
def train_model(params, weights_dict, load_dataset=None, trainable_pred=True, trainable_est=True, weights_path=None):
    """
    Training function. Sets the training parameters from params. Builds or loads
    the model and launches the training.
    :param params: Dictionary of network hyperparameters.
    :param weights_dict: Optional dict mapping layer names to weights; matching
                         layers are initialised from it before training and the
                         dict is updated with the trained weights afterwards.
    :param load_dataset: Load dataset from file or build it from the parameters.
    :param trainable_pred: Passed to TranslationModel; whether the predictor part
                           of the model is trainable.
    :param trainable_est: Passed to TranslationModel; whether the estimator part
                          of the model is trainable.
    :param weights_path: Optional path to pre-trained weights, passed to TranslationModel.
    :return: None
    """
    check_params(params)

    if params['RELOAD'] > 0:
        logging.info('Resuming training.')
        # Load data
        if load_dataset is None:
            if params['REBUILD_DATASET']:
                logging.info('Rebuilding dataset.')
                pred_vocab = params.get('PRED_VOCAB', None)
                if pred_vocab is not None:
                    dataset_voc = loadDataset(params['PRED_VOCAB'])
                    dataset = build_dataset(params, dataset_voc.vocabulary, dataset_voc.vocabulary_len)
                else:
                    dataset = build_dataset(params)
            else:
                logging.info('Updating dataset.')
                dataset = loadDataset(params['DATASET_STORE_PATH'] +
                                      '/Dataset_' + params['DATASET_NAME'] +
                                      '_' + params['SRC_LAN'] + params['TRG_LAN'] + '.pkl')
                for split, filename in params['TEXT_FILES'].items():
                    dataset = update_dataset_from_file(
                        dataset,
                        params['DATA_ROOT_PATH'] + '/' + filename + params['SRC_LAN'],
                        params,
                        splits=list([split]),
                        output_text_filename=params['DATA_ROOT_PATH'] + '/' + filename + params['TRG_LAN'],
                        remove_outputs=False,
                        compute_state_below=True,
                        recompute_references=True)
                dataset.name = params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN']
                saveDataset(dataset, params['DATASET_STORE_PATH'])
        else:
            logging.info('Reloading and using dataset.')
            dataset = loadDataset(load_dataset)
    else:
        # Load data
        if load_dataset is None:
            pred_vocab = params.get('PRED_VOCAB', None)
            if pred_vocab is not None:
                dataset_voc = loadDataset(params['PRED_VOCAB'])
                # for the testing phase, handle model vocabulary differences
                # dataset_voc.vocabulary['target_text'] = dataset_voc.vocabulary['target']
                # dataset_voc.vocabulary_len['target_text'] = dataset_voc.vocabulary_len['target']
                dataset = build_dataset(params, dataset_voc.vocabulary, dataset_voc.vocabulary_len)
            else:
                dataset = build_dataset(params)
        else:
            dataset = loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    # params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET_FULL'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len['target_text']

    # Build model
    if params['RELOAD'] == 0:  # build new model
        nmt_model = TranslationModel(params,
                                     model_type=params['MODEL_TYPE'],
                                     verbose=params['VERBOSE'],
                                     model_name=params['MODEL_NAME'],
                                     vocabularies=dataset.vocabulary,
                                     store_path=params['STORE_PATH'],
                                     trainable_pred=trainable_pred,
                                     trainable_est=trainable_est,
                                     clear_dirs=True,
                                     weights_path=weights_path)

        # Define the inputs and outputs mapping from our Dataset instance to our model
        inputMapping = dict()
        for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
            pos_source = dataset.ids_inputs.index(id_in)
            id_dest = nmt_model.ids_inputs[i]
            inputMapping[id_dest] = pos_source
        nmt_model.setInputsMapping(inputMapping)

        outputMapping = dict()
        for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
            pos_target = dataset.ids_outputs.index(id_out)
            id_dest = nmt_model.ids_outputs[i]
            outputMapping[id_dest] = pos_target
        nmt_model.setOutputsMapping(outputMapping)

    else:  # resume from previously trained model
        nmt_model = TranslationModel(params,
                                     model_type=params['MODEL_TYPE'],
                                     verbose=params['VERBOSE'],
                                     model_name=params['MODEL_NAME'],
                                     vocabularies=dataset.vocabulary,
                                     store_path=params['STORE_PATH'],
                                     set_optimizer=False,
                                     trainable_pred=trainable_pred,
                                     trainable_est=trainable_est,
                                     weights_path=weights_path)

        # Define the inputs and outputs mapping from our Dataset instance to our model
        inputMapping = dict()
        for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
            pos_source = dataset.ids_inputs.index(id_in)
            id_dest = nmt_model.ids_inputs[i]
            inputMapping[id_dest] = pos_source
        nmt_model.setInputsMapping(inputMapping)

        outputMapping = dict()
        for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
            pos_target = dataset.ids_outputs.index(id_out)
            id_dest = nmt_model.ids_outputs[i]
            outputMapping[id_dest] = pos_target
        nmt_model.setOutputsMapping(outputMapping)

        nmt_model = updateModel(nmt_model, params['STORE_PATH'], params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
            int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Store configuration as pkl
    dict2pkl(params, params['STORE_PATH'] + '/config')

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        'lr_decay': params.get('LR_DECAY', None),  # LR decay parameters
        'reduce_each_epochs': params.get('LR_REDUCE_EACH_EPOCHS', True),
        'start_reduction_on_epoch': params.get('LR_START_REDUCTION_ON_EPOCH', 0),
        'lr_gamma': params.get('LR_GAMMA', 0.9),
        'lr_reducer_type': params.get('LR_REDUCER_TYPE', 'linear'),
        'lr_reducer_exp_base': params.get('LR_REDUCER_EXP_BASE', 0),
        'lr_half_life': params.get('LR_HALF_LIFE', 50000),
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        'patience': params.get('PATIENCE', 0),  # early stopping parameters
        'metric_check': params.get('STOP_METRIC', None) if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0)
    }

    # Initialise matching layers from the shared weights dictionary, if given.
    if weights_dict is not None:
        for layer in nmt_model.model.layers:
            if layer.name in weights_dict:
                layer.set_weights(weights_dict[layer.name])

    nmt_model.trainNet(dataset, training_params)

    # Write the trained weights back into the shared dictionary.
    if weights_dict is not None:
        for layer in nmt_model.model.layers:
            weights_dict[layer.name] = layer.get_weights()

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('Total time: {0:.2f}s ({1:.2f}m).'.format(time_difference, time_difference / 60.0))
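
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original file). It shows how
# weights_dict lets successive calls share layer weights by name: weights found
# in the dict are loaded before training and written back afterwards. The
# two-stage predictor/estimator schedule and load_parameters() are assumptions,
# not a prescription from this codebase.
def _example_two_stage_training(train_fn=train_model):  # hypothetical helper
    # Default-bind the weights_dict variant defined just above, since this file
    # contains several train_model variants.
    from config import load_parameters  # assumed project config module
    params = load_parameters()
    shared_weights = {}  # layer name -> list of weight arrays, filled after stage 1
    # Stage 1: train with the estimator frozen.
    train_fn(params, shared_weights, trainable_pred=True, trainable_est=False)
    # Stage 2: reuse the stored weights and train the estimator on top.
    train_fn(params, shared_weights, trainable_pred=False, trainable_est=True)
# ---------------------------------------------------------------------------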
def train_model(params, load_dataset=None):
    """
    Training function. Sets the training parameters from params. Builds or loads
    the model and launches the training.
    :param params: Dictionary of network hyperparameters.
    :param load_dataset: Load dataset from file or build it from the parameters.
    :return: None
    """
    if params['RELOAD'] > 0:
        logging.info('Resuming training.')

    check_params(params)

    # Load data
    if load_dataset is None:
        dataset = build_dataset(params)
    else:
        dataset = loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]

    # Build model
    if params['RELOAD'] == 0:  # build new model
        nmt_model = TranslationModel(params,
                                     model_type=params['MODEL_TYPE'],
                                     verbose=params['VERBOSE'],
                                     model_name=params['MODEL_NAME'],
                                     vocabularies=dataset.vocabulary,
                                     store_path=params['STORE_PATH'])
        dict2pkl(params, params['STORE_PATH'] + '/config')

        # Define the inputs and outputs mapping from our Dataset instance to our model
        inputMapping = dict()
        for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
            pos_source = dataset.ids_inputs.index(id_in)
            id_dest = nmt_model.ids_inputs[i]
            inputMapping[id_dest] = pos_source
        nmt_model.setInputsMapping(inputMapping)

        outputMapping = dict()
        for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
            pos_target = dataset.ids_outputs.index(id_out)
            id_dest = nmt_model.ids_outputs[i]
            outputMapping[id_dest] = pos_target
        nmt_model.setOutputsMapping(outputMapping)

    else:  # resume from previously trained model
        nmt_model = TranslationModel(params,
                                     model_type=params['MODEL_TYPE'],
                                     verbose=params['VERBOSE'],
                                     model_name=params['MODEL_NAME'],
                                     vocabularies=dataset.vocabulary,
                                     store_path=params['STORE_PATH'],
                                     set_optimizer=False,
                                     clear_dirs=False)

        # Define the inputs and outputs mapping from our Dataset instance to our model
        inputMapping = dict()
        for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
            pos_source = dataset.ids_inputs.index(id_in)
            id_dest = nmt_model.ids_inputs[i]
            inputMapping[id_dest] = pos_source
        nmt_model.setInputsMapping(inputMapping)

        outputMapping = dict()
        for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
            pos_target = dataset.ids_outputs.index(id_out)
            id_dest = nmt_model.ids_outputs[i]
            outputMapping[id_dest] = pos_target
        nmt_model.setOutputsMapping(outputMapping)

        nmt_model = updateModel(nmt_model, params['STORE_PATH'], params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
            int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs': params['MAX_EPOCH'],
        'batch_size': params['BATCH_SIZE'],
        'homogeneous_batches': params['HOMOGENEOUS_BATCHES'],
        'maxlen': params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches': params['JOINT_BATCHES'],
        'lr_decay': params['LR_DECAY'],
        'lr_gamma': params['LR_GAMMA'],
        'epochs_for_save': params['EPOCHS_FOR_SAVE'],
        'verbose': params['VERBOSE'],
        'eval_on_sets': params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders': params['PARALLEL_LOADERS'],
        'extra_callbacks': callbacks,
        'reload_epoch': params['RELOAD'],
        'epoch_offset': params.get('EPOCH_OFFSET', 0),
        'data_augmentation': params['DATA_AUGMENTATION'],
        'patience': params.get('PATIENCE', 0),  # early stopping parameters
        'metric_check': params.get('STOP_METRIC', None) if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs': params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs': params.get('EVAL_EACH', 1),
        'start_eval_on_epoch': params.get('START_EVAL_ON_EPOCH', 0)
    }
    nmt_model.trainNet(dataset, training_params)

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('Total time: {0:.2f}s ({1:.2f}m).'.format(time_difference, time_difference / 60.0))
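
# Worked example (not in the original file): when training resumes by update
# count (RELOAD_EPOCH is falsy), the EPOCH_OFFSET computed above is the number
# of full passes the reloaded update counter represents:
#     EPOCH_OFFSET = int(RELOAD * BATCH_SIZE / len_train)
# e.g. RELOAD = 30000 updates with BATCH_SIZE = 50 over a training split of
# 100000 samples gives int(30000 * 50 / 100000) = 15 completed epochs, so
# training resumes counting from epoch 15.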
def apply_NMT_model(params, load_dataset=None):
    """
    Sample from a previously trained model.
    :param params: Dictionary of network hyperparameters.
    :param load_dataset: Load dataset from file or build it from the parameters.
    :return: None
    """
    pred_vocab = params.get('PRED_VOCAB', None)
    if pred_vocab is not None:
        dataset_voc = loadDataset(params['PRED_VOCAB'])
        dataset = build_dataset(params, dataset_voc.vocabulary, dataset_voc.vocabulary_len)
    else:
        dataset = build_dataset(params)

    # Load data
    # if load_dataset is None:
    #     dataset = build_dataset(params)
    # else:
    #     dataset = loadDataset(load_dataset)
    # params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    # params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]
    # vocab_y = dataset.vocabulary[params['INPUTS_IDS_DATASET'][1]]['idx2words']
    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len['target_text']

    # Load model
    # nmt_model = loadModel(params['STORE_PATH'], params['RELOAD'], reload_epoch=params['RELOAD_EPOCH'])
    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 set_optimizer=False,
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 trainable_pred=True,
                                 trainable_est=True,
                                 weights_path=None)
    nmt_model = updateModel(nmt_model, params['STORE_PATH'], params['RELOAD'],
                            reload_epoch=params['RELOAD_EPOCH'])
    nmt_model.setParams(params)
    nmt_model.setOptimizer()

    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        pos_source = dataset.ids_inputs.index(id_in)
        id_dest = nmt_model.ids_inputs[i]
        inputMapping[id_dest] = pos_source
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        pos_target = dataset.ids_outputs.index(id_out)
        id_dest = nmt_model.ids_outputs[i]
        outputMapping[id_dest] = pos_target
    nmt_model.setOutputsMapping(outputMapping)
    nmt_model.setOptimizer()

    for s in params["EVAL_ON_SETS"]:
        # Evaluate training
        extra_vars = {
            'language': params.get('TRG_LAN', 'en'),
            'n_parallel_loaders': params['PARALLEL_LOADERS'],
            'tokenize_f': eval('dataset.' + params['TOKENIZATION_METHOD']),
            'detokenize_f': eval('dataset.' + params['DETOKENIZATION_METHOD']),
            'apply_detokenization': params['APPLY_DETOKENIZATION'],
            'tokenize_hypotheses': params['TOKENIZE_HYPOTHESES'],
            'tokenize_references': params['TOKENIZE_REFERENCES']
        }
        # vocab = dataset.vocabulary[params['OUTPUTS_IDS_DATASET'][0]]['idx2words']
        # vocab = dataset.vocabulary[params['INPUTS_IDS_DATASET'][1]]['idx2words']
        extra_vars[s] = dict()
        if not params.get('NO_REF', False):
            extra_vars[s]['references'] = dataset.extra_variables[s][params['OUTPUTS_IDS_DATASET'][0]]
        # input_text_id = None
        # vocab_src = None
        input_text_id = params['INPUTS_IDS_DATASET'][0]
        vocab_x = dataset.vocabulary[input_text_id]['idx2words']
        vocab_y = dataset.vocabulary[params['INPUTS_IDS_DATASET'][1]]['idx2words']

        if params['BEAM_SEARCH']:
            extra_vars['beam_size'] = params.get('BEAM_SIZE', 6)
            extra_vars['state_below_index'] = params.get('BEAM_SEARCH_COND_INPUT', -1)
            extra_vars['maxlen'] = params.get('MAX_OUTPUT_TEXT_LEN_TEST', 30)
            extra_vars['optimized_search'] = params.get('OPTIMIZED_SEARCH', True)
            extra_vars['model_inputs'] = params['INPUTS_IDS_MODEL']
            extra_vars['model_outputs'] = params['OUTPUTS_IDS_MODEL']
            extra_vars['dataset_inputs'] = params['INPUTS_IDS_DATASET']
            extra_vars['dataset_outputs'] = params['OUTPUTS_IDS_DATASET']
            extra_vars['normalize_probs'] = params.get('NORMALIZE_SAMPLING', False)
            extra_vars['search_pruning'] = params.get('SEARCH_PRUNING', False)
            extra_vars['alpha_factor'] = params.get('ALPHA_FACTOR', 1.0)
            extra_vars['coverage_penalty'] = params.get('COVERAGE_PENALTY', False)
            extra_vars['length_penalty'] = params.get('LENGTH_PENALTY', False)
            extra_vars['length_norm_factor'] = params.get('LENGTH_NORM_FACTOR', 0.0)
            extra_vars['coverage_norm_factor'] = params.get('COVERAGE_NORM_FACTOR', 0.0)
            extra_vars['pos_unk'] = params['POS_UNK']
            extra_vars['output_max_length_depending_on_x'] = params.get('MAXLEN_GIVEN_X', True)
            extra_vars['output_max_length_depending_on_x_factor'] = params.get('MAXLEN_GIVEN_X_FACTOR', 3)
            extra_vars['output_min_length_depending_on_x'] = params.get('MINLEN_GIVEN_X', True)
            extra_vars['output_min_length_depending_on_x_factor'] = params.get('MINLEN_GIVEN_X_FACTOR', 2)

            if params['POS_UNK']:
                extra_vars['heuristic'] = params['HEURISTIC']
                input_text_id = params['INPUTS_IDS_DATASET'][0]
                vocab_src = dataset.vocabulary[input_text_id]['idx2words']
                if params['HEURISTIC'] > 0:
                    extra_vars['mapping'] = dataset.mapping

        callback_metric = PrintPerformanceMetricOnEpochEndOrEachNUpdates(
            nmt_model,
            dataset,
            gt_id=params['OUTPUTS_IDS_DATASET'][0],
            metric_name=params['METRICS'],
            set_name=params['EVAL_ON_SETS'],
            batch_size=params['BATCH_SIZE'],
            each_n_epochs=params['EVAL_EACH'],
            extra_vars=extra_vars,
            reload_epoch=params['RELOAD'],
            is_text=True,
            input_text_id=input_text_id,
            save_path=nmt_model.model_path,
            index2word_y=vocab_y,
            index2word_x=vocab_x,
            sampling_type=params['SAMPLING'],
            beam_search=params['BEAM_SEARCH'],
            start_eval_on_epoch=params['START_EVAL_ON_EPOCH'],
            write_samples=True,
            write_type=params['SAMPLING_SAVE_MODE'],
            eval_on_epochs=params['EVAL_EACH_EPOCHS'],
            save_each_evaluation=False,
            verbose=params['VERBOSE'],
            no_ref=params['NO_REF'])

        callback_metric.evaluate(
            params['RELOAD'],
            counter_name='epoch' if params['EVAL_EACH_EPOCHS'] else 'update')
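
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original file). Assumes the
# project's config module exposes load_parameters() and that RELOAD /
# RELOAD_EPOCH select a checkpoint previously written under STORE_PATH by
# train_model(). The checkpoint index and set name are hypothetical.
def _example_apply_model():  # hypothetical helper
    from config import load_parameters  # assumed project config module
    params = load_parameters()
    params['RELOAD'] = 4             # checkpoint index (epoch or update, per RELOAD_EPOCH)
    params['RELOAD_EPOCH'] = True
    params['EVAL_ON_SETS'] = ['test']
    apply_NMT_model(params)          # decodes the chosen sets and reports METRICS
# ---------------------------------------------------------------------------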