Пример #1
0
    def test_build(self):
        """Check that a TranslationModel can be built for many RNN configurations.

        The dataset is built once. A model is then constructed for every
        combination of encoder RNN type ('LSTM', 'GRU'), decoder RNN type
        ('LSTM', 'GRU', 'ConditionalLSTM', 'ConditionalGRU') and
        encoder/decoder layer count (0 and 1), and each one is asserted to be
        a ``Model_Wrapper``. Finally, the dataset-to-model input/output
        mappings are set on the last model built.
        """
        params = load_parameters()
        params['DATASET_STORE_PATH'] = './'
        params['REBUILD_DATASET'] = True
        dataset = build_dataset(params)
        # The model is sized from the vocabularies of the freshly built dataset.
        params['INPUT_VOCABULARY_SIZE'] = \
            dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
        params['OUTPUT_VOCABULARY_SIZE'] = \
            dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]
        for encoder_rnn_type in ['LSTM', 'GRU']:
            for decoder_rnn_type in [
                    'LSTM', 'GRU', 'ConditionalLSTM', 'ConditionalGRU'
            ]:
                params['ENCODER_RNN_TYPE'] = encoder_rnn_type
                params['DECODER_RNN_TYPE'] = decoder_rnn_type
                for n_layers in range(2):
                    params['N_LAYERS_DECODER'] = n_layers
                    params['N_LAYERS_ENCODER'] = n_layers
                    nmt_model = \
                        TranslationModel(params,
                                         model_type=params['MODEL_TYPE'],
                                         verbose=params['VERBOSE'],
                                         model_name=params['MODEL_NAME'],
                                         vocabularies=dataset.vocabulary,
                                         store_path=params['STORE_PATH'],
                                         clear_dirs=False)
                    self.assertIsInstance(nmt_model, Model_Wrapper)

        # Check Inputs: map each model input/output id to the position of the
        # matching dataset id (applied to the last model built above).
        inputMapping = dict()
        for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
            pos_source = dataset.ids_inputs.index(id_in)
            id_dest = nmt_model.ids_inputs[i]
            inputMapping[id_dest] = pos_source
        nmt_model.setInputsMapping(inputMapping)
        outputMapping = dict()
        for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
            pos_target = dataset.ids_outputs.index(id_out)
            id_dest = nmt_model.ids_outputs[i]
            outputMapping[id_dest] = pos_target
        nmt_model.setOutputsMapping(outputMapping)
        # No `return True`: since Python 3.11, unittest deprecates test
        # methods that return a value other than None.
Пример #2
0
            def test_train():
                """Train a deliberately tiny TranslationModel for one epoch.

                Every embedding, hidden and attention size is forced to 2 and
                training runs for a single epoch with batches of 50 and
                sentences of at most 10 tokens. Returns True once ``trainNet``
                has finished without raising.
                """
                params = load_parameters()
                params['REBUILD_DATASET'] = True
                params['DATASET_STORE_PATH'] = './'
                dataset = build_dataset(params)
                vocab_len = dataset.vocabulary_len
                params['INPUT_VOCABULARY_SIZE'] = \
                    vocab_len[params['INPUTS_IDS_DATASET'][0]]
                params['OUTPUT_VOCABULARY_SIZE'] = \
                    vocab_len[params['OUTPUTS_IDS_DATASET'][0]]

                # Shrink every layer so that this smoke test stays cheap.
                for size_key in ('SOURCE_TEXT_EMBEDDING_SIZE',
                                 'TARGET_TEXT_EMBEDDING_SIZE',
                                 'ENCODER_HIDDEN_SIZE',
                                 'DECODER_HIDDEN_SIZE',
                                 'ATTENTION_SIZE',
                                 'SKIP_VECTORS_HIDDEN_SIZE'):
                    params[size_key] = 2
                params['DEEP_OUTPUT_LAYERS'] = [('linear', 2)]
                params['STORE_PATH'] = './'

                nmt_model = TranslationModel(
                    params,
                    model_type=params['MODEL_TYPE'],
                    verbose=params['VERBOSE'],
                    model_name=params['MODEL_NAME'],
                    vocabularies=dataset.vocabulary,
                    store_path=params['STORE_PATH'],
                    clear_dirs=False)

                # Model input/output id -> position of the matching dataset id.
                inputs_map = {}
                for pos, ds_id in enumerate(params['INPUTS_IDS_DATASET']):
                    src_pos = dataset.ids_inputs.index(ds_id)
                    inputs_map[nmt_model.ids_inputs[pos]] = src_pos
                nmt_model.setInputsMapping(inputs_map)
                outputs_map = {}
                for pos, ds_id in enumerate(params['OUTPUTS_IDS_DATASET']):
                    trg_pos = dataset.ids_outputs.index(ds_id)
                    outputs_map[nmt_model.ids_outputs[pos]] = trg_pos
                nmt_model.setOutputsMapping(outputs_map)

                extra_callbacks = buildCallbacks(params, nmt_model, dataset)
                training_params = dict(
                    n_epochs=1,
                    batch_size=50,
                    homogeneous_batches=False,
                    maxlen=10,
                    joint_batches=params['JOINT_BATCHES'],
                    lr_decay=params['LR_DECAY'],
                    lr_gamma=params['LR_GAMMA'],
                    epochs_for_save=1,
                    verbose=params['VERBOSE'],
                    eval_on_sets=params['EVAL_ON_SETS_KERAS'],
                    n_parallel_loaders=params['PARALLEL_LOADERS'],
                    extra_callbacks=extra_callbacks,
                    reload_epoch=0,
                    epoch_offset=0,
                    data_augmentation=False,
                    patience=1,  # early stopping parameters
                    metric_check='Bleu_4',
                    eval_on_epochs=True,
                    each_n_epochs=1,
                    start_eval_on_epoch=0,
                )
                nmt_model.trainNet(dataset, training_params)
                return True
Пример #3
0
def train_model(params, load_dataset=None):
    """
    Training function. Sets the training parameters from params. Builds or loads the model and launches the training.

    If ``params['RELOAD'] > 0``, training is resumed from that checkpoint.
    The dataset is rebuilt (when ``REBUILD_DATASET`` is set), updated from the
    files listed in ``params['TEXT_FILES']`` and re-saved, or loaded from
    ``load_dataset``. The model weights are then restored with ``updateModel``.

    :param params: Dictionary of network hyperparameters. Modified in place
        (vocabulary sizes and, when resuming, ``EPOCH_OFFSET``).
    :param load_dataset: Path of a stored dataset to load. If None, the dataset
        is built (or updated) from the parameters.
    :return: None
    """
    check_params(params)

    if params['RELOAD'] > 0:
        logging.info('Resuming training.')
        # Load data
        if load_dataset is None:
            if params['REBUILD_DATASET']:
                logging.info('Rebuilding dataset.')
                dataset = build_dataset(params)
            else:
                logging.info('Updating dataset.')
                dataset = loadDataset(params['DATASET_STORE_PATH'] +
                                      '/Dataset_' + params['DATASET_NAME'] +
                                      '_' + params['SRC_LAN'] +
                                      params['TRG_LAN'] + '.pkl')
                # RELOAD is used directly as the epoch when RELOAD_EPOCH is
                # set. Otherwise it is converted via RELOAD * BATCH_SIZE /
                # len_train, computed on the dataset *before* it is updated.
                params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
                    int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)
                # dict.items() works on Python 2 and 3; dict.iteritems() does
                # not exist in Python 3.
                for split, filename in params['TEXT_FILES'].items():
                    dataset = update_dataset_from_file(
                        dataset,
                        params['DATA_ROOT_PATH'] + '/' + filename +
                        params['SRC_LAN'],
                        params,
                        splits=list([split]),
                        output_text_filename=params['DATA_ROOT_PATH'] + '/' +
                        filename + params['TRG_LAN'],
                        remove_outputs=False,
                        compute_state_below=True,
                        recompute_references=True)
                    dataset.name = params['DATASET_NAME'] + '_' + params[
                        'SRC_LAN'] + params['TRG_LAN']
                saveDataset(dataset, params['DATASET_STORE_PATH'])

        else:
            logging.info('Reloading and using dataset.')
            dataset = loadDataset(load_dataset)
    else:
        # Load data
        if load_dataset is None:
            dataset = build_dataset(params)
        else:
            dataset = loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['OUTPUTS_IDS_DATASET'][0]]

    # Build model: only a fresh run (RELOAD == 0) sets up the optimizer at
    # construction time and clears the output directories.
    set_optimizer = params['RELOAD'] == 0
    clear_dirs = params['RELOAD'] == 0

    # build new model
    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 set_optimizer=set_optimizer,
                                 clear_dirs=clear_dirs)

    # Define the inputs and outputs mapping from our Dataset instance to our
    # model: model id -> position of the matching id in the dataset.
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        pos_source = dataset.ids_inputs.index(id_in)
        id_dest = nmt_model.ids_inputs[i]
        inputMapping[id_dest] = pos_source
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        pos_target = dataset.ids_outputs.index(id_out)
        id_dest = nmt_model.ids_outputs[i]
        outputMapping[id_dest] = pos_target
    nmt_model.setOutputsMapping(outputMapping)

    if params['RELOAD'] > 0:
        # Restore the stored weights, then set up the optimizer that was
        # skipped at construction time.
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        # Only compute the offset if the dataset-update path did not already.
        if params.get('EPOCH_OFFSET') is None:
            params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
                int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Store configuration as pkl
    dict2pkl(params, params['STORE_PATH'] + '/config')

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs':
        params['MAX_EPOCH'],
        'batch_size':
        params['BATCH_SIZE'],
        'homogeneous_batches':
        params['HOMOGENEOUS_BATCHES'],
        'maxlen':
        params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches':
        params['JOINT_BATCHES'],
        'lr_decay':
        params.get('LR_DECAY', None),  # LR decay parameters
        'reduce_each_epochs':
        params.get('LR_REDUCE_EACH_EPOCHS', True),
        'start_reduction_on_epoch':
        params.get('LR_START_REDUCTION_ON_EPOCH', 0),
        'lr_gamma':
        params.get('LR_GAMMA', 0.9),
        'lr_reducer_type':
        params.get('LR_REDUCER_TYPE', 'linear'),
        'lr_reducer_exp_base':
        params.get('LR_REDUCER_EXP_BASE', 0),
        'lr_half_life':
        params.get('LR_HALF_LIFE', 50000),
        'epochs_for_save':
        params['EPOCHS_FOR_SAVE'],
        'verbose':
        params['VERBOSE'],
        'eval_on_sets':
        params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders':
        params['PARALLEL_LOADERS'],
        'extra_callbacks':
        callbacks,
        'reload_epoch':
        params['RELOAD'],
        'epoch_offset':
        params.get('EPOCH_OFFSET', 0),
        'data_augmentation':
        params['DATA_AUGMENTATION'],
        'patience':
        params.get('PATIENCE', 0),  # early stopping parameters
        'metric_check':
        params.get('STOP_METRIC', None)
        if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs':
        params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs':
        params.get('EVAL_EACH', 1),
        'start_eval_on_epoch':
        params.get('START_EVAL_ON_EPOCH', 0),
        'tensorboard':
        params.get('TENSORBOARD', False),
        'tensorboard_params': {
            'log_dir':
            params.get('LOG_DIR', 'tensorboard_logs'),
            'histogram_freq':
            params.get('HISTOGRAM_FREQ', 0),
            'batch_size':
            params.get('TENSORBOARD_BATCH_SIZE', params['BATCH_SIZE']),
            'write_graph':
            params.get('WRITE_GRAPH', True),
            'write_grads':
            params.get('WRITE_GRADS', False),
            'write_images':
            params.get('WRITE_IMAGES', False),
            'embeddings_freq':
            params.get('EMBEDDINGS_FREQ', 0),
            'embeddings_layer_names':
            params.get('EMBEDDINGS_LAYER_NAMES', None),
            'embeddings_metadata':
            params.get('EMBEDDINGS_METADATA', None),
            'label_word_embeddings_with_vocab':
            params.get('LABEL_WORD_EMBEDDINGS_WITH_VOCAB', False),
            'word_embeddings_labels':
            params.get('WORD_EMBEDDINGS_LABELS', None),
        }
    }
    nmt_model.trainNet(dataset, training_params)

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(
        time_difference, time_difference / 60.0))
Пример #4
0
# Script: builds a 'GroundHogModel' TranslationModel initialised from the
# weights file given in `weights_path`, wires the dataset inputs/outputs to the
# model inputs/outputs and trains it on `ds`.
# NOTE(review): `ds` (the Dataset) and `params` are defined outside this
# snippet — confirm they exist before this point.
print(ds)

nmt_model = TranslationModel(params, 
    model_type='GroundHogModel',
    weights_path='trained_models/1024_Base/epoch_10_init.h5',
    model_name='1024_Trained_w2v_Base',
    vocabularies=ds.vocabulary,
    store_path='trained_models/1024_Trained_w2v_Base/',
    verbose=True)

# Map each model input id to the position of the corresponding dataset input
# id (taken from params['INPUTS_IDS_DATASET']) inside ds.ids_inputs.
inputMapping = dict()
for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
    pos_source = ds.ids_inputs.index(id_in)
    id_dest = nmt_model.ids_inputs[i]
    inputMapping[id_dest] = pos_source

nmt_model.setInputsMapping(inputMapping)
# Same mapping for outputs: model output id -> position in ds.ids_outputs.
outputMapping = dict()
for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
    pos_target = ds.ids_outputs.index(id_out)
    id_dest = nmt_model.ids_outputs[i]
    outputMapping[id_dest] = pos_target
nmt_model.setOutputsMapping(outputMapping)

# NOTE(review): reload_epoch/epoch_offset = 10 presumably continue counting
# from the epoch-10 weights named in weights_path — confirm against trainNet.
training_params = {'n_epochs': 12, 'batch_size': 20,'maxlen': 30, 'epochs_for_save': 1, 'verbose': 1, 'eval_on_sets': [], 'reload_epoch': 10, 'epoch_offset': 10}

print(ds)

nmt_model.trainNet(ds, training_params)
Пример #5
0
def train_model(params, load_dataset=None):
    """
    Train a translation model described by ``params``.

    When ``params['RELOAD'] == 0`` a new model is built and its configuration
    is pickled to ``STORE_PATH/config``. Otherwise the model stored for
    checkpoint ``params['RELOAD']`` is restored, and ``params['EPOCH_OFFSET']``
    is set so that training continues from that point.

    :param params: Dictionary of network hyperparameters (modified in place).
    :param load_dataset: Path of a stored dataset to load; if None the dataset
        is built from the parameters.
    :return: None
    """
    if params['RELOAD'] > 0:
        logging.info('Resuming training.')

    check_params(params)

    # Either build the dataset from the configuration or load a stored one.
    if load_dataset is not None:
        dataset = loadDataset(load_dataset)
    else:
        dataset = build_dataset(params)

    vocab_len = dataset.vocabulary_len
    params['INPUT_VOCABULARY_SIZE'] = \
        vocab_len[params['INPUTS_IDS_DATASET'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = \
        vocab_len[params['OUTPUTS_IDS_DATASET'][0]]

    def wire_dataset_to(model):
        """Map each model input/output id to its position in the dataset."""
        in_map = {}
        for idx, ds_id in enumerate(params['INPUTS_IDS_DATASET']):
            src_pos = dataset.ids_inputs.index(ds_id)
            in_map[model.ids_inputs[idx]] = src_pos
        model.setInputsMapping(in_map)

        out_map = {}
        for idx, ds_id in enumerate(params['OUTPUTS_IDS_DATASET']):
            trg_pos = dataset.ids_outputs.index(ds_id)
            out_map[model.ids_outputs[idx]] = trg_pos
        model.setOutputsMapping(out_map)

    model_kwargs = dict(model_type=params['MODEL_TYPE'],
                        verbose=params['VERBOSE'],
                        model_name=params['MODEL_NAME'],
                        vocabularies=dataset.vocabulary,
                        store_path=params['STORE_PATH'])

    if params['RELOAD'] != 0:
        # Resume: build the architecture without optimizer / dir cleanup,
        # restore the stored weights, then set up the optimizer.
        nmt_model = TranslationModel(params,
                                     set_optimizer=False,
                                     clear_dirs=False,
                                     **model_kwargs)
        wire_dataset_to(nmt_model)
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        if params['RELOAD_EPOCH']:
            params['EPOCH_OFFSET'] = params['RELOAD']
        else:
            params['EPOCH_OFFSET'] = int(
                params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)
    else:
        # Fresh run: build the model and keep a copy of its configuration.
        nmt_model = TranslationModel(params, **model_kwargs)
        dict2pkl(params, params['STORE_PATH'] + '/config')
        wire_dataset_to(nmt_model)

    callbacks = buildCallbacks(params, nmt_model, dataset)

    start = timer()

    logger.debug('Starting training!')
    training_params = dict(
        n_epochs=params['MAX_EPOCH'],
        batch_size=params['BATCH_SIZE'],
        homogeneous_batches=params['HOMOGENEOUS_BATCHES'],
        maxlen=params['MAX_OUTPUT_TEXT_LEN'],
        joint_batches=params['JOINT_BATCHES'],
        lr_decay=params['LR_DECAY'],
        lr_gamma=params['LR_GAMMA'],
        epochs_for_save=params['EPOCHS_FOR_SAVE'],
        verbose=params['VERBOSE'],
        eval_on_sets=params['EVAL_ON_SETS_KERAS'],
        n_parallel_loaders=params['PARALLEL_LOADERS'],
        extra_callbacks=callbacks,
        reload_epoch=params['RELOAD'],
        epoch_offset=params.get('EPOCH_OFFSET', 0),
        data_augmentation=params['DATA_AUGMENTATION'],
        patience=params.get('PATIENCE', 0),  # early stopping parameters
        metric_check=(params.get('STOP_METRIC', None)
                      if params.get('EARLY_STOP', False) else None),
        eval_on_epochs=params.get('EVAL_EACH_EPOCHS', True),
        each_n_epochs=params.get('EVAL_EACH', 1),
        start_eval_on_epoch=params.get('START_EVAL_ON_EPOCH', 0),
    )
    nmt_model.trainNet(dataset, training_params)

    elapsed = timer() - start
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(
        elapsed, elapsed / 60.0))
def train_model(params,
                weights_dict,
                load_dataset=None,
                trainable_pred=True,
                trainable_est=True,
                weights_path=None):
    """
    Training function. Sets the training parameters from params. Builds or loads the model and launches the training.

    :param params: Dictionary of network hyperparameters (modified in place:
        vocabulary sizes and, when resuming, ``EPOCH_OFFSET``).
    :param weights_dict: Either None or a dict keyed by layer name. Before
        training, every model layer whose name is a key receives those weights
        via ``layer.set_weights``. After training, every layer's
        ``layer.get_weights()`` is written back into the dict.
    :param load_dataset: Load dataset from file or build it from the parameters.
        When it is None and ``params['PRED_VOCAB']`` is set, the dataset is
        built with the vocabularies of the dataset stored at that path.
    :param trainable_pred: Forwarded to ``TranslationModel``.
    :param trainable_est: Forwarded to ``TranslationModel``.
    :param weights_path: Forwarded to ``TranslationModel``.
    :return: None
    """
    check_params(params)

    if params['RELOAD'] > 0:
        logging.info('Resuming training.')
        # Load data
        if load_dataset is None:
            if params['REBUILD_DATASET']:
                logging.info('Rebuilding dataset.')

                pred_vocab = params.get('PRED_VOCAB', None)
                if pred_vocab is not None:
                    dataset_voc = loadDataset(params['PRED_VOCAB'])
                    dataset = build_dataset(params, dataset_voc.vocabulary,
                                            dataset_voc.vocabulary_len)
                else:
                    dataset = build_dataset(params)
            else:
                logging.info('Updating dataset.')
                dataset = loadDataset(params['DATASET_STORE_PATH'] +
                                      '/Dataset_' + params['DATASET_NAME'] +
                                      '_' + params['SRC_LAN'] +
                                      params['TRG_LAN'] + '.pkl')

                # dict.items() works on Python 2 and 3; dict.iteritems() does
                # not exist in Python 3.
                for split, filename in params['TEXT_FILES'].items():
                    dataset = update_dataset_from_file(
                        dataset,
                        params['DATA_ROOT_PATH'] + '/' + filename +
                        params['SRC_LAN'],
                        params,
                        splits=list([split]),
                        output_text_filename=params['DATA_ROOT_PATH'] + '/' +
                        filename + params['TRG_LAN'],
                        remove_outputs=False,
                        compute_state_below=True,
                        recompute_references=True)
                    dataset.name = params['DATASET_NAME'] + '_' + params[
                        'SRC_LAN'] + params['TRG_LAN']
                saveDataset(dataset, params['DATASET_STORE_PATH'])

        else:
            logging.info('Reloading and using dataset.')
            dataset = loadDataset(load_dataset)
    else:
        # Load data
        if load_dataset is None:
            pred_vocab = params.get('PRED_VOCAB', None)
            if pred_vocab is not None:
                dataset_voc = loadDataset(params['PRED_VOCAB'])
                # for the testing phase handle model vocab differences
                #dataset_voc.vocabulary['target_text'] = dataset_voc.vocabulary['target']
                #dataset_voc.vocabulary_len['target_text'] = dataset_voc.vocabulary_len['target']
                dataset = build_dataset(params, dataset_voc.vocabulary,
                                        dataset_voc.vocabulary_len)
            else:
                dataset = build_dataset(params)
        else:
            dataset = loadDataset(load_dataset)

    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['INPUTS_IDS_DATASET'][0]]
    # The output vocabulary size is read from the fixed 'target_text' id
    # rather than from OUTPUTS_IDS_DATASET.
    #params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET_FULL'][0]]
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len['target_text']

    # Build model
    if params['RELOAD'] == 0:  # build new model
        nmt_model = TranslationModel(params,
                                     model_type=params['MODEL_TYPE'],
                                     verbose=params['VERBOSE'],
                                     model_name=params['MODEL_NAME'],
                                     vocabularies=dataset.vocabulary,
                                     store_path=params['STORE_PATH'],
                                     trainable_pred=trainable_pred,
                                     trainable_est=trainable_est,
                                     clear_dirs=True,
                                     weights_path=weights_path)

        # Define the inputs and outputs mapping from our Dataset instance to our model
        inputMapping = dict()
        for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
            pos_source = dataset.ids_inputs.index(id_in)
            id_dest = nmt_model.ids_inputs[i]
            inputMapping[id_dest] = pos_source
        nmt_model.setInputsMapping(inputMapping)

        outputMapping = dict()
        for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
            pos_target = dataset.ids_outputs.index(id_out)
            id_dest = nmt_model.ids_outputs[i]
            outputMapping[id_dest] = pos_target
        nmt_model.setOutputsMapping(outputMapping)

    else:  # resume from previously trained model
        # clear_dirs=False: the store directory holds the checkpoint that
        # updateModel() reloads below, so it must not be cleared here
        # (consistent with the other train_model variants).
        nmt_model = TranslationModel(params,
                                     model_type=params['MODEL_TYPE'],
                                     verbose=params['VERBOSE'],
                                     model_name=params['MODEL_NAME'],
                                     vocabularies=dataset.vocabulary,
                                     store_path=params['STORE_PATH'],
                                     set_optimizer=False,
                                     trainable_pred=trainable_pred,
                                     trainable_est=trainable_est,
                                     clear_dirs=False,
                                     weights_path=weights_path)

        # Define the inputs and outputs mapping from our Dataset instance to our model
        inputMapping = dict()
        for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
            pos_source = dataset.ids_inputs.index(id_in)
            id_dest = nmt_model.ids_inputs[i]
            inputMapping[id_dest] = pos_source
        nmt_model.setInputsMapping(inputMapping)

        outputMapping = dict()
        for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
            pos_target = dataset.ids_outputs.index(id_out)
            id_dest = nmt_model.ids_outputs[i]
            outputMapping[id_dest] = pos_target

        nmt_model.setOutputsMapping(outputMapping)
        nmt_model = updateModel(nmt_model,
                                params['STORE_PATH'],
                                params['RELOAD'],
                                reload_epoch=params['RELOAD_EPOCH'])
        nmt_model.setParams(params)
        nmt_model.setOptimizer()
        # RELOAD is an epoch when RELOAD_EPOCH is set; otherwise it is
        # converted to epochs via RELOAD * BATCH_SIZE / len_train.
        params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
            int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)

    # Store configuration as pkl
    dict2pkl(params, params['STORE_PATH'] + '/config')

    # Callbacks
    callbacks = buildCallbacks(params, nmt_model, dataset)

    # Training
    total_start_time = timer()

    logger.debug('Starting training!')
    training_params = {
        'n_epochs':
        params['MAX_EPOCH'],
        'batch_size':
        params['BATCH_SIZE'],
        'homogeneous_batches':
        params['HOMOGENEOUS_BATCHES'],
        'maxlen':
        params['MAX_OUTPUT_TEXT_LEN'],
        'joint_batches':
        params['JOINT_BATCHES'],
        'lr_decay':
        params.get('LR_DECAY', None),  # LR decay parameters
        'reduce_each_epochs':
        params.get('LR_REDUCE_EACH_EPOCHS', True),
        'start_reduction_on_epoch':
        params.get('LR_START_REDUCTION_ON_EPOCH', 0),
        'lr_gamma':
        params.get('LR_GAMMA', 0.9),
        'lr_reducer_type':
        params.get('LR_REDUCER_TYPE', 'linear'),
        'lr_reducer_exp_base':
        params.get('LR_REDUCER_EXP_BASE', 0),
        'lr_half_life':
        params.get('LR_HALF_LIFE', 50000),
        'epochs_for_save':
        params['EPOCHS_FOR_SAVE'],
        'verbose':
        params['VERBOSE'],
        'eval_on_sets':
        params['EVAL_ON_SETS_KERAS'],
        'n_parallel_loaders':
        params['PARALLEL_LOADERS'],
        'extra_callbacks':
        callbacks,
        'reload_epoch':
        params['RELOAD'],
        'epoch_offset':
        params.get('EPOCH_OFFSET', 0),
        'data_augmentation':
        params['DATA_AUGMENTATION'],
        'patience':
        params.get('PATIENCE', 0),  # early stopping parameters
        'metric_check':
        params.get('STOP_METRIC', None)
        if params.get('EARLY_STOP', False) else None,
        'eval_on_epochs':
        params.get('EVAL_EACH_EPOCHS', True),
        'each_n_epochs':
        params.get('EVAL_EACH', 1),
        'start_eval_on_epoch':
        params.get('START_EVAL_ON_EPOCH', 0)
    }
    # Seed layers with externally supplied weights (matched by layer name).
    if weights_dict is not None:
        for layer in nmt_model.model.layers:
            if layer.name in weights_dict:
                layer.set_weights(weights_dict[layer.name])

    nmt_model.trainNet(dataset, training_params)

    # Export the trained weights of every layer back to the caller's dict.
    if weights_dict is not None:
        for layer in nmt_model.model.layers:
            weights_dict[layer.name] = layer.get_weights()

    total_end_time = timer()
    time_difference = total_end_time - total_start_time
    logging.info('In total is {0:.2f}s = {1:.2f}m'.format(
        time_difference, time_difference / 60.0))
def apply_NMT_model(params, load_dataset=None):
    """
    Sample from a previously trained model and evaluate it on every set
    listed in ``params['EVAL_ON_SETS']``.

    The dataset is always built from ``params``. When ``params['PRED_VOCAB']``
    is set, the build reuses the vocabularies of the dataset stored at that
    path. The model weights are restored from ``params['STORE_PATH']`` for
    checkpoint ``params['RELOAD']``.

    :param params: Dictionary of network hyperparameters (the vocabulary sizes
        are written into it).
    :param load_dataset: Currently ignored — the dataset is always built from
        ``params``. Kept for backward compatibility with callers.
    :return: None
    """
    pred_vocab = params.get('PRED_VOCAB', None)
    if pred_vocab is not None:
        dataset_voc = loadDataset(params['PRED_VOCAB'])
        dataset = build_dataset(params, dataset_voc.vocabulary,
                                dataset_voc.vocabulary_len)
    else:
        dataset = build_dataset(params)
    # NOTE(review): `load_dataset` is not honoured; loading a stored dataset
    # was disabled here. Re-enable if callers rely on it.
    params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[
        params['INPUTS_IDS_DATASET'][0]]
    # The output vocabulary size is read from the fixed 'target_text' id.
    params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len['target_text']

    # Build the architecture, then restore the stored weights into it.
    nmt_model = TranslationModel(params,
                                 model_type=params['MODEL_TYPE'],
                                 verbose=params['VERBOSE'],
                                 model_name=params['MODEL_NAME'],
                                 set_optimizer=False,
                                 vocabularies=dataset.vocabulary,
                                 store_path=params['STORE_PATH'],
                                 trainable_pred=True,
                                 trainable_est=True,
                                 weights_path=None)
    nmt_model = updateModel(nmt_model,
                            params['STORE_PATH'],
                            params['RELOAD'],
                            reload_epoch=params['RELOAD_EPOCH'])
    nmt_model.setParams(params)
    nmt_model.setOptimizer()

    # Map each model input/output id to the position of its dataset id.
    inputMapping = dict()
    for i, id_in in enumerate(params['INPUTS_IDS_DATASET']):
        pos_source = dataset.ids_inputs.index(id_in)
        id_dest = nmt_model.ids_inputs[i]
        inputMapping[id_dest] = pos_source
    nmt_model.setInputsMapping(inputMapping)

    outputMapping = dict()
    for i, id_out in enumerate(params['OUTPUTS_IDS_DATASET']):
        pos_target = dataset.ids_outputs.index(id_out)
        id_dest = nmt_model.ids_outputs[i]
        outputMapping[id_dest] = pos_target
    nmt_model.setOutputsMapping(outputMapping)
    # NOTE(review): setOptimizer() was already called above; this second
    # call looks redundant but is kept to preserve existing behaviour.
    nmt_model.setOptimizer()

    for s in params["EVAL_ON_SETS"]:
        # Evaluate on set `s`. getattr() resolves the configured
        # (de)tokenization method by name without eval()'ing config text.
        extra_vars = {
            'language': params.get('TRG_LAN', 'en'),
            'n_parallel_loaders': params['PARALLEL_LOADERS'],
            'tokenize_f': getattr(dataset, params['TOKENIZATION_METHOD']),
            'detokenize_f': getattr(dataset, params['DETOKENIZATION_METHOD']),
            'apply_detokenization': params['APPLY_DETOKENIZATION'],
            'tokenize_hypotheses': params['TOKENIZE_HYPOTHESES'],
            'tokenize_references': params['TOKENIZE_REFERENCES']
        }
        extra_vars[s] = dict()
        if not params.get('NO_REF', False):
            extra_vars[s]['references'] = dataset.extra_variables[s][
                params['OUTPUTS_IDS_DATASET'][0]]
        input_text_id = params['INPUTS_IDS_DATASET'][0]
        vocab_x = dataset.vocabulary[input_text_id]['idx2words']
        # The target-side words are taken from the vocabulary of the second
        # dataset input.
        vocab_y = dataset.vocabulary[params['INPUTS_IDS_DATASET']
                                     [1]]['idx2words']

        if params['BEAM_SEARCH']:
            extra_vars['beam_size'] = params.get('BEAM_SIZE', 6)
            extra_vars['state_below_index'] = params.get(
                'BEAM_SEARCH_COND_INPUT', -1)
            extra_vars['maxlen'] = params.get('MAX_OUTPUT_TEXT_LEN_TEST', 30)
            extra_vars['optimized_search'] = params.get(
                'OPTIMIZED_SEARCH', True)
            extra_vars['model_inputs'] = params['INPUTS_IDS_MODEL']
            extra_vars['model_outputs'] = params['OUTPUTS_IDS_MODEL']
            extra_vars['dataset_inputs'] = params['INPUTS_IDS_DATASET']
            extra_vars['dataset_outputs'] = params['OUTPUTS_IDS_DATASET']
            extra_vars['normalize_probs'] = params.get('NORMALIZE_SAMPLING',
                                                       False)
            extra_vars['search_pruning'] = params.get('SEARCH_PRUNING', False)
            extra_vars['alpha_factor'] = params.get('ALPHA_FACTOR', 1.0)
            extra_vars['coverage_penalty'] = params.get(
                'COVERAGE_PENALTY', False)
            extra_vars['length_penalty'] = params.get('LENGTH_PENALTY', False)
            extra_vars['length_norm_factor'] = params.get(
                'LENGTH_NORM_FACTOR', 0.0)
            extra_vars['coverage_norm_factor'] = params.get(
                'COVERAGE_NORM_FACTOR', 0.0)
            extra_vars['pos_unk'] = params['POS_UNK']
            extra_vars['output_max_length_depending_on_x'] = params.get(
                'MAXLEN_GIVEN_X', True)
            extra_vars['output_max_length_depending_on_x_factor'] = params.get(
                'MAXLEN_GIVEN_X_FACTOR', 3)
            extra_vars['output_min_length_depending_on_x'] = params.get(
                'MINLEN_GIVEN_X', True)
            extra_vars['output_min_length_depending_on_x_factor'] = params.get(
                'MINLEN_GIVEN_X_FACTOR', 2)

            if params['POS_UNK']:
                extra_vars['heuristic'] = params['HEURISTIC']
                if params['HEURISTIC'] > 0:
                    extra_vars['mapping'] = dataset.mapping

        # Evaluate only the current set: extra_vars holds references for `s`
        # alone, so passing every set here would re-evaluate the others
        # without their references.
        callback_metric = PrintPerformanceMetricOnEpochEndOrEachNUpdates(
            nmt_model,
            dataset,
            gt_id=params['OUTPUTS_IDS_DATASET'][0],
            metric_name=params['METRICS'],
            set_name=[s],
            batch_size=params['BATCH_SIZE'],
            each_n_epochs=params['EVAL_EACH'],
            extra_vars=extra_vars,
            reload_epoch=params['RELOAD'],
            is_text=True,
            input_text_id=input_text_id,
            save_path=nmt_model.model_path,
            index2word_y=vocab_y,
            index2word_x=vocab_x,
            sampling_type=params['SAMPLING'],
            beam_search=params['BEAM_SEARCH'],
            start_eval_on_epoch=params['START_EVAL_ON_EPOCH'],
            write_samples=True,
            write_type=params['SAMPLING_SAVE_MODE'],
            eval_on_epochs=params['EVAL_EACH_EPOCHS'],
            save_each_evaluation=False,
            verbose=params['VERBOSE'],
            no_ref=params.get('NO_REF', False))

        callback_metric.evaluate(
            params['RELOAD'],
            counter_name='epoch' if params['EVAL_EACH_EPOCHS'] else 'update')