# Example no. 1 (scraped sample separator; original text: "Exemplo n.º 1")
# 0
    def __init__(self, config):
        """Build the dual-encoder / dual-decoder seq2seq graph (TF1 graph mode).

        Constructs two encoders and two decoders that are reused across six
        training "modes" (variable scopes are reused after the first pass),
        optional RL losses driven by an external language-model scorer
        (``LMScore`` via ``tf.py_func``), per-mode train ops, and a saver.

        Args:
            config: dict of hyper-parameters. Keys read here include 'LR',
                'WE_LR', 'ENCODER_LR', 'DECODER_LR', 'LR_DECAY', 'SPLIT_LR',
                'OPTIMIZER', 'GLOBAL_STEP', 'BATCH_SIZE',
                'INPUT_VOCAB_SIZE', 'OUTPUT_VOCAB_SIZE',
                'ENCODER_HIDDEN_SIZE', 'DECODER_HIDDEN_SIZE',
                'WORD_EMBEDDING_SIZE', 'ENCODER_LAYERS', 'DECODER_LAYERS',
                'CELL', 'INPUT_DROPOUT', 'OUTPUT_DROPOUT',
                'BIDIRECTIONAL_ENCODER', 'ATTENTION_DECODER',
                'ATTENTION_MECHANISE', 'USE_BS', 'BEAM_WIDTH', 'IS_TRAIN',
                'ID_END_1', 'ID_END_2', 'MAX_OUT_LEN', 'RL_ENABLE',
                'BLEU_RL_ENABLE', 'LM_MODEL_X', 'LM_MODEL_Y', 'SRC_DICT',
                'DST_DICT', 'CLIP_NORM', 'MAX_TO_KEEP'.

        Raises:
            Exception: if config['OPTIMIZER'] is neither 'Adam' nor 'GD'.
        """
        print('The model is built for training:', config['IS_TRAIN'])

        # Current training mode (0..5); switched externally by the trainer.
        self.train_mode = 0

        self.rl_enable = config['RL_ENABLE']
        self.bleu_enable = config['BLEU_RL_ENABLE']
        # Learning rates are non-trainable variables so they can be
        # decayed / reset through graph ops at run time.
        self.learning_rate = tf.Variable(config['LR'],
                                         dtype=tf.float32,
                                         name='model_learning_rate',
                                         trainable=False)
        self.word_embedding_learning_rate = tf.Variable(
            config['WE_LR'],
            dtype=tf.float32,
            name='model_we_learning_rate',
            trainable=False)
        self.encoder_learning_rate = tf.Variable(
            config['ENCODER_LR'],
            dtype=tf.float32,
            name='model_enc_learning_rate',
            trainable=False)
        self.decoder_learning_rate = tf.Variable(
            config['DECODER_LR'],
            dtype=tf.float32,
            name='model_dec_learning_rate',
            trainable=False)
        if config['SPLIT_LR']:
            # BUG FIX: the original wrapped the three assigns in a local
            # helper whose return value (None) was stored in
            # `self.lr_decay_op`, so the decay op could never be fetched.
            # Group the three assign ops into one runnable op instead.
            self.lr_decay_op = tf.group(
                self.word_embedding_learning_rate.assign(
                    self.word_embedding_learning_rate * config['LR_DECAY']),
                self.encoder_learning_rate.assign(self.encoder_learning_rate *
                                                  config['LR_DECAY']),
                self.decoder_learning_rate.assign(self.decoder_learning_rate *
                                                  config['LR_DECAY']))
        else:
            self.lr_decay_op = self.learning_rate.assign(self.learning_rate *
                                                         config['LR_DECAY'])
        self.lr_reset_op = self.learning_rate.assign(config['LR'])

        if config['OPTIMIZER'] == 'Adam':
            self.optimizer = tf.train.AdamOptimizer
        elif config['OPTIMIZER'] == 'GD':
            self.optimizer = tf.train.GradientDescentOptimizer
        else:
            raise Exception("Wrong optimizer name...")

        self.global_step = tf.Variable(config['GLOBAL_STEP'],
                                       dtype=tf.int32,
                                       name='model_global_step',
                                       trainable=False)
        self.batch_size = config['BATCH_SIZE']
        self.input_size_1 = config['INPUT_VOCAB_SIZE']
        self.input_size_2 = config['OUTPUT_VOCAB_SIZE']
        self.output_size_1 = config['INPUT_VOCAB_SIZE']
        self.output_size_2 = config['OUTPUT_VOCAB_SIZE']
        self.encoder_hidden_size = config['ENCODER_HIDDEN_SIZE']
        self.decoder_hidden_size = config['DECODER_HIDDEN_SIZE']
        self.embedding_size = config['WORD_EMBEDDING_SIZE']

        # Placeholders are time-major for ids (time, batch) and batch-major
        # for masks (batch, time) — note sequence_loss below transposes the
        # targets back to batch-major to match the masks.
        self.encoder_inputs = tf.placeholder(dtype=tf.int32,
                                             shape=(None, self.batch_size),
                                             name='encoder_inputs')
        self.encoder_inputs_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='encoder_inputs_length')
        self.encoder_inputs_mask = tf.placeholder(dtype=tf.float32,
                                                  shape=(self.batch_size,
                                                         None),
                                                  name='encoder_inputs_mask')
        self.decoder_inputs = tf.placeholder(dtype=tf.int32,
                                             shape=(None, self.batch_size),
                                             name='decoder_inputs')
        self.decoder_inputs_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='decoder_inputs_length')
        self.decoder_inputs_mask = tf.placeholder(dtype=tf.float32,
                                                  shape=(self.batch_size,
                                                         None),
                                                  name='decoder_inputs_mask')
        self.decoder_targets = tf.placeholder(dtype=tf.int32,
                                              shape=(None, self.batch_size),
                                              name='decoder_targets')
        self.decoder_targets_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='decoder_targets_length')
        self.decoder_targets_mask = tf.placeholder(dtype=tf.float32,
                                                   shape=(self.batch_size,
                                                          None),
                                                   name='decoder_targets_mask')

        # --- Encoder 1: embeds encoder_inputs with the source-1 vocab. ---
        with tf.variable_scope("DynamicEncoder_1") as scope:
            self.input_word_embedding_matrix_1 = modelInitWordEmbedding(
                self.input_size_1, self.embedding_size, name='we_input_1')
            self.encoder_inputs_embedded_1 = modelGetWordEmbedding(
                self.input_word_embedding_matrix_1, self.encoder_inputs)

            self.encoder_cell_1 = modelInitRNNCells(self.encoder_hidden_size,
                                                    config['ENCODER_LAYERS'],
                                                    config['CELL'],
                                                    config['INPUT_DROPOUT'],
                                                    config['OUTPUT_DROPOUT'])
            if config['BIDIRECTIONAL_ENCODER']:
                self.encoder_outputs_1, self.encoder_state_1 = modelInitBidirectionalEncoder(
                    self.encoder_cell_1,
                    self.encoder_inputs_embedded_1,
                    self.encoder_inputs_length,
                    encoder_type='stack')
            else:
                self.encoder_outputs_1, self.encoder_state_1 = modelInitUndirectionalEncoder(
                    self.encoder_cell_1, self.encoder_inputs_embedded_1,
                    self.encoder_inputs_length)

            if config['USE_BS'] and not config['IS_TRAIN']:
                # Beam search: replicate state/outputs BEAM_WIDTH times.
                # Outputs are time-major, so transpose to batch-major for
                # tile_batch and back again.
                self.encoder_state_1 = seq2seq.tile_batch(
                    self.encoder_state_1, config['BEAM_WIDTH'])
                self.encoder_outputs_1 = tf.transpose(
                    seq2seq.tile_batch(
                        tf.transpose(self.encoder_outputs_1, [1, 0, 2]),
                        config['BEAM_WIDTH']), [1, 0, 2])

            self.encoder_1_variables = scope.trainable_variables()

        # --- Encoder 2: same structure, separate vocab/embeddings. ---
        with tf.variable_scope("DynamicEncoder_2") as scope:
            self.input_word_embedding_matrix_2 = modelInitWordEmbedding(
                self.input_size_2, self.embedding_size, name='we_input_2')
            self.encoder_inputs_embedded_2 = modelGetWordEmbedding(
                self.input_word_embedding_matrix_2, self.encoder_inputs)

            self.encoder_cell_2 = modelInitRNNCells(self.encoder_hidden_size,
                                                    config['ENCODER_LAYERS'],
                                                    config['CELL'],
                                                    config['INPUT_DROPOUT'],
                                                    config['OUTPUT_DROPOUT'])
            if config['BIDIRECTIONAL_ENCODER']:
                self.encoder_outputs_2, self.encoder_state_2 = modelInitBidirectionalEncoder(
                    self.encoder_cell_2,
                    self.encoder_inputs_embedded_2,
                    self.encoder_inputs_length,
                    encoder_type='stack')
            else:
                self.encoder_outputs_2, self.encoder_state_2 = modelInitUndirectionalEncoder(
                    self.encoder_cell_2, self.encoder_inputs_embedded_2,
                    self.encoder_inputs_length)

            if config['USE_BS'] and not config['IS_TRAIN']:
                self.encoder_state_2 = seq2seq.tile_batch(
                    self.encoder_state_2, config['BEAM_WIDTH'])
                self.encoder_outputs_2 = tf.transpose(
                    seq2seq.tile_batch(
                        tf.transpose(self.encoder_outputs_2, [1, 0, 2]),
                        config['BEAM_WIDTH']), [1, 0, 2])

            self.encoder_2_variables = scope.trainable_variables()

        # Attention memory lengths must be tiled the same way as the memory.
        self.encoder_inputs_length_att = self.encoder_inputs_length
        if config['USE_BS'] and not config['IS_TRAIN']:
            self.encoder_inputs_length_att = seq2seq.tile_batch(
                self.encoder_inputs_length_att, config['BEAM_WIDTH'])

        self.encoder_outputs_list = [
            self.encoder_outputs_1, self.encoder_outputs_2
        ]
        self.encoder_state_list = [self.encoder_state_1, self.encoder_state_2]
        ru = False
        self.train_loss = []
        self.train_loss_rl = []
        self.final_loss = []
        self.eval_loss = []
        self.infer_outputs_all = []
        # Six modes share the decoder variables: scopes are created on the
        # first pass (reuse=None) and reused afterwards (reuse=True).
        # Even modes use encoder 1's memory, odd modes encoder 2's.
        for mode in range(6):
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=True if ru else None):
                ru = True
                self.encoder_outputs = self.encoder_outputs_list[mode % 2]
                self.encoder_state = self.encoder_state_list[mode % 2]

                with tf.variable_scope("DynamicDecoder_1") as scope:
                    self.output_word_embedding_matrix_1 = modelInitWordEmbedding(
                        self.output_size_1,
                        self.embedding_size,
                        name='we_output_1')
                    self.decoder_inputs_embedded_1 = modelGetWordEmbedding(
                        self.output_word_embedding_matrix_1,
                        self.decoder_inputs)

                    self.decoder_cell_1 = modelInitRNNCells(
                        self.decoder_hidden_size, config['DECODER_LAYERS'],
                        config['CELL'], config['INPUT_DROPOUT'],
                        config['OUTPUT_DROPOUT'])
                    if config['ATTENTION_DECODER']:
                        self.decoder_cell_1 = modelInitAttentionDecoderCell(
                            self.decoder_cell_1,
                            self.decoder_hidden_size,
                            self.encoder_outputs,
                            self.encoder_inputs_length_att,
                            att_type=config['ATTENTION_MECHANISE'],
                            wrapper_type='whole')
                    else:
                        # BUG FIX: original referenced the never-defined
                        # attribute `self.decoder_cell` here.
                        self.decoder_cell_1 = modelInitRNNDecoderCell(
                            self.decoder_cell_1)

                    initial_state_1 = None

                    if config['USE_BS'] and not config['IS_TRAIN']:
                        initial_state_1 = self.decoder_cell_1.zero_state(
                            batch_size=self.batch_size * config['BEAM_WIDTH'],
                            dtype=tf.float32)
                        if config['ATTENTION_DECODER']:
                            # Feed the encoder state into the bottom decoder
                            # layer, keeping the wrapper's zero state for the
                            # rest.
                            cat_state = tuple(
                                [self.encoder_state] +
                                list(initial_state_1.cell_state)[:-1])
                            # BUG FIX: clone() returns a new state object;
                            # the original discarded it, so the decoder was
                            # initialised from the zero state.
                            initial_state_1 = initial_state_1.clone(
                                cell_state=cat_state)
                        else:
                            initial_state_1 = tuple([self.encoder_state] +
                                                    list(initial_state_1[:-1]))
                    else:
                        initial_state_1 = self.decoder_cell_1.zero_state(
                            batch_size=self.batch_size, dtype=tf.float32)

                        if config['ATTENTION_DECODER']:
                            cat_state = tuple(
                                [self.encoder_state] +
                                list(initial_state_1.cell_state)[:-1])
                            # BUG FIX: keep clone()'s return value (see above).
                            initial_state_1 = initial_state_1.clone(
                                cell_state=cat_state)
                        else:
                            initial_state_1 = tuple([self.encoder_state] +
                                                    list(initial_state_1[:-1]))

                    self.output_projection_layer_1 = layers_core.Dense(
                        self.output_size_1, use_bias=False)
                    if config['IS_TRAIN']:
                        self.train_outputs_1 = modelInitDecoderForTrain(
                            self.decoder_cell_1,
                            self.decoder_inputs_embedded_1,
                            self.decoder_inputs_length, initial_state_1,
                            self.output_projection_layer_1)
                        self.blind_outputs_1 = modelInitDecoderForBlindTrain(
                            self.decoder_cell_1,
                            self.decoder_inputs_embedded_1,
                            self.decoder_inputs_length,
                            self.output_word_embedding_matrix_1,
                            initial_state_1, self.output_projection_layer_1)
                    if config['USE_BS'] and not config['IS_TRAIN']:
                        self.infer_outputs_1 = modelInitDecoderForBSInfer(
                            self.decoder_cell_1, self.decoder_inputs[0],
                            self.output_word_embedding_matrix_1,
                            config['BEAM_WIDTH'], config['ID_END_1'],
                            config['MAX_OUT_LEN'], initial_state_1,
                            self.output_projection_layer_1)
                    else:
                        self.infer_outputs_1 = modelInitDecoderForGreedyInfer(
                            self.decoder_cell_1, self.decoder_inputs[0],
                            self.output_word_embedding_matrix_1,
                            config['ID_END_1'], config['MAX_OUT_LEN'],
                            initial_state_1, self.output_projection_layer_1)

                with tf.variable_scope("DynamicDecoder_2") as scope:
                    self.output_word_embedding_matrix_2 = modelInitWordEmbedding(
                        self.output_size_2,
                        self.embedding_size,
                        name='we_output_2')
                    self.decoder_inputs_embedded_2 = modelGetWordEmbedding(
                        self.output_word_embedding_matrix_2,
                        self.decoder_inputs)

                    self.decoder_cell_2 = modelInitRNNCells(
                        self.decoder_hidden_size, config['DECODER_LAYERS'],
                        config['CELL'], config['INPUT_DROPOUT'],
                        config['OUTPUT_DROPOUT'])
                    if config['ATTENTION_DECODER']:
                        self.decoder_cell_2 = modelInitAttentionDecoderCell(
                            self.decoder_cell_2,
                            self.decoder_hidden_size,
                            self.encoder_outputs,
                            self.encoder_inputs_length_att,
                            att_type=config['ATTENTION_MECHANISE'],
                            wrapper_type='whole')
                    else:
                        # BUG FIX: original referenced the never-defined
                        # attribute `self.decoder_cell` here.
                        self.decoder_cell_2 = modelInitRNNDecoderCell(
                            self.decoder_cell_2)

                    initial_state_2 = None

                    if config['USE_BS'] and not config['IS_TRAIN']:
                        initial_state_2 = self.decoder_cell_2.zero_state(
                            batch_size=self.batch_size * config['BEAM_WIDTH'],
                            dtype=tf.float32)
                        if config['ATTENTION_DECODER']:
                            cat_state = tuple(
                                [self.encoder_state] +
                                list(initial_state_2.cell_state)[:-1])
                            # BUG FIX: keep clone()'s return value (see above).
                            initial_state_2 = initial_state_2.clone(
                                cell_state=cat_state)
                        else:
                            initial_state_2 = tuple([self.encoder_state] +
                                                    list(initial_state_2[:-1]))
                    else:
                        initial_state_2 = self.decoder_cell_2.zero_state(
                            batch_size=self.batch_size, dtype=tf.float32)

                        if config['ATTENTION_DECODER']:
                            cat_state = tuple(
                                [self.encoder_state] +
                                list(initial_state_2.cell_state)[:-1])
                            # BUG FIX: keep clone()'s return value (see above).
                            initial_state_2 = initial_state_2.clone(
                                cell_state=cat_state)
                        else:
                            initial_state_2 = tuple([self.encoder_state] +
                                                    list(initial_state_2[:-1]))

                    self.output_projection_layer_2 = layers_core.Dense(
                        self.output_size_2, use_bias=False)
                    if config['IS_TRAIN']:
                        self.train_outputs_2 = modelInitDecoderForTrain(
                            self.decoder_cell_2,
                            self.decoder_inputs_embedded_2,
                            self.decoder_inputs_length, initial_state_2,
                            self.output_projection_layer_2)
                        self.blind_outputs_2 = modelInitDecoderForBlindTrain(
                            self.decoder_cell_2,
                            self.decoder_inputs_embedded_2,
                            self.decoder_inputs_length,
                            self.output_word_embedding_matrix_2,
                            initial_state_2, self.output_projection_layer_2)
                    if config['USE_BS'] and not config['IS_TRAIN']:
                        self.infer_outputs_2 = modelInitDecoderForBSInfer(
                            self.decoder_cell_2, self.decoder_inputs[0],
                            self.output_word_embedding_matrix_2,
                            config['BEAM_WIDTH'], config['ID_END_2'],
                            config['MAX_OUT_LEN'], initial_state_2,
                            self.output_projection_layer_2)
                    else:
                        self.infer_outputs_2 = modelInitDecoderForGreedyInfer(
                            self.decoder_cell_2, self.decoder_inputs[0],
                            self.output_word_embedding_matrix_2,
                            config['ID_END_2'], config['MAX_OUT_LEN'],
                            initial_state_2, self.output_projection_layer_2)

                if config['IS_TRAIN']:
                    outputs2use = None
                    # Modes 1/2/5 train decoder 1; modes 0/3/4 train
                    # decoder 2.  Modes 4 and 5 are "blind" (RL-only): the
                    # cross-entropy term is zeroed out.
                    if mode in [1, 2, 5]:
                        self.train_loss.append(
                            seq2seq.sequence_loss(
                                logits=self.train_outputs_1,
                                targets=tf.transpose(self.decoder_targets,
                                                     perm=[1, 0]),
                                weights=self.decoder_targets_mask))
                        outputs2use = self.train_outputs_1
                        if mode == 5:
                            self.train_loss[-1] = tf.constant(0.0)
                            outputs2use = self.blind_outputs_1
                        # RL reward from an external LM scorer (runs as a
                        # py_func outside the graph).
                        self.rewards = tf.py_func(LMScore, [
                            outputs2use,
                            tf.constant(config['LM_MODEL_X'], dtype=tf.string),
                            tf.constant(config['SRC_DICT'], dtype=tf.string)
                        ], tf.float32)
                        # NOTE(review): unlike the decoder-2 branch below,
                        # the rewards shape is not pinned here with
                        # set_shape — confirm whether that is intentional.
                    else:
                        self.train_loss.append(
                            seq2seq.sequence_loss(
                                logits=self.train_outputs_2,
                                targets=tf.transpose(self.decoder_targets,
                                                     perm=[1, 0]),
                                weights=self.decoder_targets_mask))
                        outputs2use = self.train_outputs_2
                        if mode == 4:
                            self.train_loss[-1] = tf.constant(0.0)
                            outputs2use = self.blind_outputs_2
                        self.rewards = tf.py_func(LMScore, [
                            outputs2use,
                            tf.constant(config['LM_MODEL_Y'], dtype=tf.string),
                            tf.constant(config['DST_DICT'], dtype=tf.string)
                        ], tf.float32)
                        self.rewards.set_shape(outputs2use.get_shape())

                    self.train_loss_rl.append(
                        rlloss.sequence_loss_rl(
                            logits=outputs2use,
                            rewards=self.rewards,
                            weights=self.decoder_targets_mask))

                    self.eval_loss.append(
                        seq2seq.sequence_loss(
                            logits=outputs2use,
                            targets=tf.transpose(self.decoder_targets,
                                                 perm=[1, 0]),
                            weights=self.decoder_targets_mask))

                    # Final loss per mode = cross-entropy + RL term.
                    self.final_loss.append(self.train_loss[-1] +
                                           self.train_loss_rl[-1])
                if mode in [1, 2, 5]:
                    self.infer_outputs_all.append(self.infer_outputs_1)
                else:
                    self.infer_outputs_all.append(self.infer_outputs_2)

                self.decoder_variables = scope.trainable_variables()

        print('All Trainable Variables:')
        self.all_trainable_variables = tf.trainable_variables()
        print(self.all_trainable_variables)
        if config['IS_TRAIN']:
            # One train op per mode; all update every trainable variable
            # with the shared (non-split) learning rate.
            self.train_op = []
            for mode in range(6):
                self.train_op.append(
                    updateBP(self.final_loss[mode], [self.learning_rate],
                             [self.all_trainable_variables],
                             self.optimizer,
                             norm=config['CLIP_NORM']))
        self.saver = initSaver(tf.global_variables(), config['MAX_TO_KEEP'])
# Example no. 2 (scraped sample separator; original text: "Exemplo n.º 2")
# 0
    def __init__(self, config):

        print('The model is built for training:', config['IS_TRAIN'])

        self.train_mode = 0

        self.learning_rate = tf.Variable(config['LR'],
                                         dtype=tf.float32,
                                         name='model_learning_rate',
                                         trainable=False)
        self.lr_decay_op = self.learning_rate.assign(self.learning_rate *
                                                     config['LR_DECAY'])
        self.lr_reset_op = self.learning_rate.assign(config['LR'])

        if config['OPTIMIZER'] == 'Adam':
            self.optimizer = tf.train.AdamOptimizer
        elif config['OPTIMIZER'] == 'GD':
            self.optimizer = tf.train.GradientDescentOptimizer
        else:
            raise Exception("Wrong optimizer name...")

        self.global_step = tf.Variable(config['GLOBAL_STEP'],
                                       dtype=tf.int32,
                                       name='model_global_step',
                                       trainable=False)
        self.batch_size = config['BATCH_SIZE']
        self.max_len = config['MAX_OUT_LEN']
        self.input_sizes = config['MODELS_INPUT_VOCAB_SIZES']
        self.output_sizes = config['MODELS_OUTPUT_VOCAB_SIZES']
        self.encoder_hidden_size = config['ENCODER_HIDDEN_SIZE']
        self.decoder_hidden_size = config['DECODER_HIDDEN_SIZE']
        self.embedding_size = config['WORD_EMBEDDING_SIZE']

        self.encoder_inputs = tf.placeholder(dtype=tf.int32,
                                             shape=(None, self.batch_size),
                                             name='encoder_inputs')
        self.encoder_inputs_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='encoder_inputs_length')
        self.encoder_inputs_mask = tf.placeholder(dtype=tf.float32,
                                                  shape=(self.batch_size,
                                                         None),
                                                  name='encoder_inputs_mask')
        self.decoder_inputs = tf.placeholder(dtype=tf.int32,
                                             shape=(None, self.batch_size),
                                             name='decoder_inputs')
        self.decoder_inputs_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='decoder_inputs_length')
        self.decoder_inputs_mask = tf.placeholder(dtype=tf.float32,
                                                  shape=(self.batch_size,
                                                         None),
                                                  name='decoder_inputs_mask')
        self.decoder_targets = tf.placeholder(dtype=tf.int32,
                                              shape=(None, self.batch_size),
                                              name='decoder_targets')
        self.decoder_targets_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='decoder_targets_length')
        self.decoder_targets_mask = tf.placeholder(dtype=tf.float32,
                                                   shape=(self.batch_size,
                                                          None),
                                                   name='decoder_targets_mask')

        self.maps_g2l_src = []
        self.maps_g2l_tgt = []
        r_global_dict_src, global_dict_src = loadDict(config['SRC_DICT'])
        r_global_dict_tgt, global_dict_tgt = loadDict(config['DST_DICT'])

        def make_maps(dx, dy):
            """Map every index of vocab ``dx`` to an index in vocab ``dy``.

            ``dx`` maps int index -> word; ``dy`` maps word -> int index.
            Bible-version tags (``<ASV>``, ``<NIV>``, ...) are remapped to
            whichever version tag exists in ``dy``; any word missing from
            ``dy`` falls back to ``'<UNK>'``.

            Returns:
                ``[ret_index, ret_weights]`` where ``ret_index[i]`` is the
                ``dy`` index for ``dx[i]`` and ``ret_weights[i]`` is
                ``1 / frequency`` of that target index, so words that
                collapse onto the same target share unit total weight.
            """
            ret_index = []
            ret_weights = []
            dcnt = {}
            version_tags = [
                '<ASV>', '<BBE>', '<DARBY>', '<DRA>', '<WEB>', '<YLT>',
                '<AMP>', '<CJB>', '<CSB>', '<ERV>', '<ESV>', '<KJ21>',
                '<MEV>', '<NCV>', '<NIV>', '<NOG>'
            ]
            for i in range(len(dx)):
                word = dx[i]
                if word in version_tags:
                    # Remap to a version tag present in dy (last match wins,
                    # matching the original behaviour).
                    for w in version_tags:
                        if w in dy:
                            word = w
                            print(i, word, dy[word])
                if word not in dy:
                    word = '<UNK>'

                ret_index.append(dy[word])

                # BUG FIX: dcnt is keyed by the target index dy[word], but
                # the original tested `word not in dcnt` (a string against
                # int keys), which was always true and reset the count to 0
                # on every iteration — making every weight 1.0.
                if dy[word] not in dcnt:
                    dcnt[dy[word]] = 0
                dcnt[dy[word]] += 1
            for index in ret_index:
                ret_weights.append(1.0 / dcnt[index])
            return [ret_index, ret_weights]

        for model_no in range(len(config['MODEL_PREFIX'])):
            local_dict, _ = loadDict('auto-train-cc/' +
                                     config['MODEL_PREFIX'][model_no] +
                                     '/all.in.dict')
            self.maps_g2l_src.append(make_maps(global_dict_src, local_dict))
            local_dict, _ = loadDict('auto-train-cc/' +
                                     config['MODEL_PREFIX'][model_no] +
                                     '/all.out.dict')
            self.maps_g2l_tgt.append(make_maps(global_dict_tgt, local_dict))

        self.encoder_real_inputs = []
        self.decoder_real_inputs = []
        for model_no in range(len(config['MODEL_PREFIX'])):
            self.encoder_real_inputs.append(
                tf.gather(self.maps_g2l_src[model_no][0], self.encoder_inputs))
            self.decoder_real_inputs.append(
                tf.gather(self.maps_g2l_tgt[model_no][0], self.decoder_inputs))

        if config['CORENET'] == "FULL":
            self.maps_g2l_src.insert(
                0, make_maps(global_dict_src, r_global_dict_src))
            self.maps_g2l_tgt.insert(
                0, make_maps(global_dict_tgt, r_global_dict_tgt))

        self.input_word_embedding_matrixs = []
        self.encoder_inputs_embeddeds = []
        self.encoder_cells = []
        self.encoder_final_outputs = []
        self.encoder_final_states = []
        self.encoder_inputs_length_atts = []

        self.output_word_embedding_matrixs = []
        self.decoder_inputs_embeddeds = []
        self.decoder_cells = []
        self.decoder_initial_states = []
        self.output_projection_layers = []
        self.decoders = []

        self.encoder_variables = []
        self.decoder_variables = []

        self.decoders_outputs = []

        for model_no in range(len(config['MODEL_PREFIX'])):
            model_prefix = "asv-%s-02-" % config['MODEL_PREFIX'][model_no]

            with tf.variable_scope(model_prefix + "DynamicEncoder") as scope:
                self.input_word_embedding_matrixs.append(
                    modelInitWordEmbedding(self.input_sizes[model_no],
                                           self.embedding_size,
                                           name='we_input'))
                self.encoder_inputs_embeddeds.append(
                    modelGetWordEmbedding(
                        self.input_word_embedding_matrixs[model_no],
                        self.encoder_real_inputs[model_no]))

                self.encoder_cells.append(
                    modelInitRNNCells(self.encoder_hidden_size,
                                      config['ENCODER_LAYERS'], config['CELL'],
                                      config['INPUT_DROPOUT'],
                                      config['OUTPUT_DROPOUT']))

                if config['BIDIRECTIONAL_ENCODER']:
                    encoder_outputs, encoder_state = modelInitBidirectionalEncoder(
                        self.encoder_cells[model_no],
                        self.encoder_inputs_embeddeds[model_no],
                        self.encoder_inputs_length,
                        encoder_type='stack')
                    self.encoder_final_outputs.append(encoder_outputs)
                    self.encoder_final_states.append(encoder_state)
                else:
                    encoder_outputs, encoder_state = modelInitUndirectionalEncoder(
                        self.encoder_cells[model_no],
                        self.encoder_inputs_embeddeds[model_no],
                        self.encoder_inputs_length)
                    self.encoder_final_outputs.append(encoder_outputs)
                    self.encoder_final_states.append(encoder_state)

                if config['USE_BS'] and not config['IS_TRAIN']:
                    self.encoder_final_states[model_no] = seq2seq.tile_batch(
                        self.encoder_final_states[model_no],
                        config['BEAM_WIDTH'])
                    self.encoder_final_outputs[model_no] = tf.transpose(
                        seq2seq.tile_batch(
                            tf.transpose(self.encoder_final_outputs[model_no],
                                         [1, 0, 2]), config['BEAM_WIDTH']),
                        [1, 0, 2])

                self.encoder_variables.append(scope.trainable_variables())

            self.encoder_inputs_length_att = self.encoder_inputs_length
            if config['USE_BS'] and not config['IS_TRAIN']:
                self.encoder_inputs_length_att = seq2seq.tile_batch(
                    self.encoder_inputs_length_att, config['BEAM_WIDTH'])

            with tf.variable_scope(model_prefix + "DynamicDecoder") as scope:
                self.output_word_embedding_matrixs.append(
                    modelInitWordEmbedding(self.output_sizes[model_no],
                                           self.embedding_size,
                                           name='we_output'))
                self.decoder_inputs_embeddeds.append(
                    modelGetWordEmbedding(
                        self.output_word_embedding_matrixs[model_no],
                        self.decoder_real_inputs[model_no]))

                self.decoder_cells.append(
                    modelInitRNNCells(self.decoder_hidden_size,
                                      config['DECODER_LAYERS'], config['CELL'],
                                      config['INPUT_DROPOUT'],
                                      config['OUTPUT_DROPOUT']))
                if config['ATTENTION_DECODER']:
                    self.decoder_cells[
                        model_no] = modelInitAttentionDecoderCell(
                            self.decoder_cells[model_no],
                            self.decoder_hidden_size,
                            self.encoder_final_outputs[model_no],
                            self.encoder_inputs_length_att,
                            att_type=config['ATTENTION_MECHANISE'],
                            wrapper_type='whole')
                else:
                    self.decoder_cells[model_no] = modelInitRNNDecoderCell(
                        self.decoder_cells[model_no])

                initial_state = None

                if config['USE_BS'] and not config['IS_TRAIN']:
                    initial_state = self.decoder_cells[model_no].zero_state(
                        batch_size=self.batch_size * config['BEAM_WIDTH'],
                        dtype=tf.float32)
                    if config['ATTENTION_DECODER']:
                        cat_state = tuple(
                            [self.encoder_final_states[model_no]] +
                            list(initial_state.cell_state)[:-1])
                        initial_state.clone(cell_state=cat_state)
                    else:
                        initial_state = tuple(
                            [self.encoder_final_states[model_no]] +
                            list(initial_state[:-1]))
                else:
                    initial_state = self.decoder_cells[model_no].zero_state(
                        batch_size=self.batch_size, dtype=tf.float32)

                    if config['ATTENTION_DECODER']:
                        cat_state = tuple(
                            [self.encoder_final_states[model_no]] +
                            list(initial_state.cell_state)[:-1])
                        initial_state.clone(cell_state=cat_state)
                    else:
                        initial_state = tuple(
                            [self.encoder_final_states[model_no]] +
                            list(initial_state[:-1]))

                self.decoder_initial_states.append(initial_state)
                self.output_projection_layers.append(
                    layers_core.Dense(self.output_sizes[model_no],
                                      use_bias=False,
                                      name=model_prefix + 'Opl'))

                decoder_tmp, _ = modelInitPretrainedDecoder(
                    self.decoder_cells[model_no],
                    self.decoder_inputs_embeddeds[model_no],
                    self.decoder_inputs_length,
                    self.decoder_initial_states[model_no],
                    self.output_projection_layers[model_no])
                self.decoders.append(decoder_tmp)
                # self.decoders_outputs.append(output_tmp)

                self.decoder_variables.append(scope.trainable_variables())

        # print('Encoder Trainable Variables')
        # print(self.encoder_variables)
        # print('Decoder Trainable Variables')
        # print(self.decoder_variables)

        with tf.variable_scope("Core") as scope:
            if config['CORENET'] == 'FULL':
                self.input_word_embedding_matrixs.insert(
                    0,
                    modelInitWordEmbedding(config['INPUT_VOCAB_SIZE'],
                                           self.embedding_size,
                                           name='we_input'))
                self.encoder_inputs_embeddeds.insert(
                    0,
                    modelGetWordEmbedding(self.input_word_embedding_matrixs[0],
                                          self.encoder_inputs))

                self.encoder_cells.insert(
                    0,
                    modelInitRNNCells(self.encoder_hidden_size,
                                      config['ENCODER_LAYERS'], config['CELL'],
                                      config['INPUT_DROPOUT'],
                                      config['OUTPUT_DROPOUT']))

                if config['BIDIRECTIONAL_ENCODER']:
                    encoder_outputs, encoder_state = modelInitBidirectionalEncoder(
                        self.encoder_cells[0],
                        self.encoder_inputs_embeddeds[0],
                        self.encoder_inputs_length,
                        encoder_type='stack')
                    self.encoder_final_outputs.insert(0, encoder_outputs)
                    self.encoder_final_states.insert(0, encoder_state)
                else:
                    encoder_outputs, encoder_state = modelInitUndirectionalEncoder(
                        self.encoder_cells[0],
                        self.encoder_inputs_embeddeds[0],
                        self.encoder_inputs_length)
                    self.encoder_final_outputs.insert(0, encoder_outputs)
                    self.encoder_final_states.insert(0, encoder_state)
                self.encoder_inputs_length_att = self.encoder_inputs_length
                self.output_word_embedding_matrixs.insert(
                    0,
                    modelInitWordEmbedding(config['OUTPUT_VOCAB_SIZE'],
                                           self.embedding_size,
                                           name='we_output'))
                self.decoder_inputs_embeddeds.insert(
                    0,
                    modelGetWordEmbedding(
                        self.output_word_embedding_matrixs[0],
                        self.decoder_inputs))

                self.decoder_cells.insert(
                    0,
                    modelInitRNNCells(self.decoder_hidden_size,
                                      config['DECODER_LAYERS'], config['CELL'],
                                      config['INPUT_DROPOUT'],
                                      config['OUTPUT_DROPOUT']))
                if config['ATTENTION_DECODER']:
                    self.decoder_cells[0] = modelInitAttentionDecoderCell(
                        self.decoder_cells[0],
                        self.decoder_hidden_size,
                        self.encoder_final_outputs[0],
                        self.encoder_inputs_length_att,
                        att_type=config['ATTENTION_MECHANISE'],
                        wrapper_type='whole')
                else:
                    self.decoder_cells[0] = modelInitRNNDecoderCell(
                        self.decoder_cells[0])

                initial_state = self.decoder_cells[0].zero_state(
                    batch_size=self.batch_size, dtype=tf.float32)

                if config['ATTENTION_DECODER']:
                    cat_state = tuple([self.encoder_final_states[0]] +
                                      list(initial_state.cell_state)[:-1])
                    initial_state.clone(cell_state=cat_state)
                else:
                    initial_state = tuple([self.encoder_final_states[0]] +
                                          list(initial_state[:-1]))

                self.decoder_initial_states.insert(0, initial_state)
                self.output_projection_layers.insert(
                    0,
                    layers_core.Dense(config['OUTPUT_VOCAB_SIZE'],
                                      use_bias=False,
                                      name='Opl'))

                decoder_tmp, _ = modelInitPretrainedDecoder(
                    self.decoder_cells[0], self.decoder_inputs_embeddeds[0],
                    self.decoder_inputs_length, self.decoder_initial_states[0],
                    self.output_projection_layers[0])
                self.decoders.insert(0, decoder_tmp)
                # self.decoders_outputs.append(output_tmp)

            self.ma_policy = tf.get_variable(name='ma_policy',
                                             shape=[
                                                 config['OUTPUT_VOCAB_SIZE'],
                                                 len(config['MODEL_PREFIX'])
                                             ],
                                             dtype=tf.float32)
            print('Core Trainable Variables')
            self.core_variables = scope.trainable_variables()
            print(self.core_variables)

        self.outputs, self.output_weights = modelInitMultiDecodersForTrain(
            self.decoders,
            self.ma_policy,
            self.maps_g2l_tgt,
            self.output_word_embedding_matrixs,
            decode_func=dynamic_multi_decode,
            policy_mode=config['CORENET'])
        if config['IS_TRAIN']:
            self.train_outputs = self.outputs.rnn_output
        if config['USE_BS'] and not config['IS_TRAIN']:
            # self.infer_outputs = modelInitMultiDecodersForBSInfer(self.decoder_cells, [self.decoder_inputs[0]]*len(config[MODEL_PREFIX]), self.output_word_embedding_matrixs, config['BEAM_WIDTH'], config['IDS_END'], config['MAX_OUT_LEN'], self.decoder_initial_states, self.output_projection_layers, self.ma_policy, self.maps_g2l, decode_func=dynamic_multi_decode)
            pass
        else:
            self.infer_outputs = self.outputs.rnn_output

        if config['IS_TRAIN']:
            self.train_loss = seq2seq.sequence_loss(
                logits=self.train_outputs,
                targets=tf.transpose(self.decoder_targets, perm=[1, 0]),
                weights=self.decoder_targets_mask)

            self.rewards = tf.py_func(LMScore, [
                self.train_outputs,
                tf.constant(config['LM_MODEL_Y'], dtype=tf.string),
                tf.constant(config['DST_DICT'], dtype=tf.string)
            ], tf.float32)

            self.train_loss_rl = rlloss.sequence_loss_rl(
                logits=self.train_outputs,
                rewards=self.rewards,
                weights=self.decoder_targets_mask)

            self.eval_loss = seq2seq.sequence_loss(
                logits=self.train_outputs,
                targets=tf.transpose(self.decoder_targets, perm=[1, 0]),
                weights=self.decoder_targets_mask)

            self.final_loss = self.train_loss  #+config['RL_RATIO']*self.train_loss_rl

        print('All Trainable Variables:')
        self.all_trainable_variables = self.core_variables
        self.preload_variables = [
            i + j
            for i, j in zip(self.encoder_variables, self.decoder_variables)
        ]
        print(self.all_trainable_variables)
        if config['IS_TRAIN']:
            self.train_op = updateBP(self.final_loss, [self.learning_rate],
                                     [self.all_trainable_variables],
                                     self.optimizer,
                                     norm=config['CLIP_NORM'])
            # self.train_op = tf.constant(0.0)
        self.saver = initSaver(tf.global_variables(), config['MAX_TO_KEEP'])
        self.preload_savers = []
        for mno in range(len(config['MODEL_PREFIX'])):
            self.preload_savers.append(initSaver(self.preload_variables[mno]))
# Exemplo n.º 3 (listing separator from the original source collection)
# 0
    def __init__(self, config):
        """Build the full seq2seq graph: placeholders, embeddings, encoder,
        (optionally attentional) decoder, losses and training ops.

        Everything is driven by the ``config`` dict (keys such as
        'IS_TRAIN', 'USE_BS', 'ATTENTION_DECODER', 'RL_ENABLE', ...).
        Side effects: TF variables, placeholders and ops are created and
        bound to ``self``; nothing is returned.
        """

        print('The model is built for training:', config['IS_TRAIN'])

        self.rl_enable = config['RL_ENABLE']
        self.bleu_enable = config['BLEU_RL_ENABLE']
        # Learning rates live in non-trainable variables so they can be
        # decayed/reassigned at run time without rebuilding the graph.
        self.learning_rate = tf.Variable(config['LR'],
                                         dtype=tf.float32,
                                         name='model_learning_rate',
                                         trainable=False)
        self.word_embedding_learning_rate = tf.Variable(
            config['WE_LR'],
            dtype=tf.float32,
            name='model_we_learning_rate',
            trainable=False)
        self.encoder_learning_rate = tf.Variable(
            config['ENCODER_LR'],
            dtype=tf.float32,
            name='model_enc_learning_rate',
            trainable=False)
        self.decoder_learning_rate = tf.Variable(
            config['DECODER_LR'],
            dtype=tf.float32,
            name='model_dec_learning_rate',
            trainable=False)
        if config['SPLIT_LR']:
            # BUGFIX: the original wrapped the three assign ops in a helper
            # that returned None, leaving ``lr_decay_op`` unrunnable.
            # tf.group bundles them into a single op that decays all three
            # component learning rates together.
            self.lr_decay_op = tf.group(
                self.word_embedding_learning_rate.assign(
                    self.word_embedding_learning_rate * config['LR_DECAY']),
                self.encoder_learning_rate.assign(self.encoder_learning_rate *
                                                  config['LR_DECAY']),
                self.decoder_learning_rate.assign(self.decoder_learning_rate *
                                                  config['LR_DECAY']))
        else:
            self.lr_decay_op = self.learning_rate.assign(self.learning_rate *
                                                         config['LR_DECAY'])

        if config['OPTIMIZER'] == 'Adam':
            self.optimizer = tf.train.AdamOptimizer
        elif config['OPTIMIZER'] == 'GD':
            self.optimizer = tf.train.GradientDescentOptimizer
        else:
            raise Exception("Wrong optimizer name...")

        self.global_step = tf.Variable(config['GLOBAL_STEP'],
                                       dtype=tf.int32,
                                       name='model_global_step',
                                       trainable=False)
        self.batch_size = config['BATCH_SIZE']
        self.input_size = config['INPUT_VOCAB_SIZE']
        self.output_size = config['OUTPUT_VOCAB_SIZE']
        self.encoder_hidden_size = config['ENCODER_HIDDEN_SIZE']
        self.decoder_hidden_size = config['DECODER_HIDDEN_SIZE']
        self.embedding_size = config['WORD_EMBEDDING_SIZE']

        # Sequence placeholders are time-major (time, batch); the masks and
        # length vectors are batch-major.  TODO confirm time-major layout
        # against the modelInit* helpers.
        self.encoder_inputs = tf.placeholder(dtype=tf.int32,
                                             shape=(None, self.batch_size),
                                             name='encoder_inputs')
        self.encoder_inputs_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='encoder_inputs_length')
        self.encoder_inputs_mask = tf.placeholder(dtype=tf.float32,
                                                  shape=(self.batch_size,
                                                         None),
                                                  name='encoder_inputs_mask')
        self.decoder_inputs = tf.placeholder(dtype=tf.int32,
                                             shape=(None, self.batch_size),
                                             name='decoder_inputs')
        self.decoder_inputs_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='decoder_inputs_length')
        self.decoder_inputs_mask = tf.placeholder(dtype=tf.float32,
                                                  shape=(self.batch_size,
                                                         None),
                                                  name='decoder_inputs_mask')
        self.decoder_targets = tf.placeholder(dtype=tf.int32,
                                              shape=(None, self.batch_size),
                                              name='decoder_targets')
        self.decoder_targets_length = tf.placeholder(
            dtype=tf.int32,
            shape=(self.batch_size, ),
            name='decoder_targets_length')
        self.decoder_targets_mask = tf.placeholder(dtype=tf.float32,
                                                   shape=(self.batch_size,
                                                          None),
                                                   name='decoder_targets_mask')
        # BUGFIX: this placeholder was named 'decoder_targets_mask' (copy-
        # paste from the placeholder above); give it its own name.  Note it
        # is shadowed below by a tf.py_func tensor when IS_TRAIN.
        self.rewards = tf.placeholder(dtype=tf.float32,
                                      shape=(self.batch_size, None,
                                             self.output_size),
                                      name='rewards')

        with tf.variable_scope("InputWordEmbedding") as scope:

            self.input_word_embedding_matrix = modelInitWordEmbedding(
                self.input_size, self.embedding_size, name='we_input')

            self.encoder_inputs_embedded = modelGetWordEmbedding(
                self.input_word_embedding_matrix, self.encoder_inputs)

            # Collected per-scope so the embedding can be trained with its
            # own learning rate / saved separately.
            self.input_embedding_variables = scope.trainable_variables()

        with tf.variable_scope("OutputWordEmbedding") as scope:

            self.output_word_embedding_matrix = modelInitWordEmbedding(
                self.output_size, self.embedding_size, name='we_output')
            self.decoder_inputs_embedded = modelGetWordEmbedding(
                self.output_word_embedding_matrix, self.decoder_inputs)

            self.output_embedding_variables = scope.trainable_variables()

        with tf.variable_scope("DynamicEncoder") as scope:
            self.encoder_cell = modelInitRNNCells(self.encoder_hidden_size,
                                                  config['ENCODER_LAYERS'],
                                                  config['CELL'],
                                                  config['INPUT_DROPOUT'],
                                                  config['OUTPUT_DROPOUT'])
            # Four encoder variants: {uni,bi}directional x {plain,VAE}.
            # The VAE variants additionally return a KL-style loss term.
            if config['BIDIRECTIONAL_ENCODER']:
                if config['VAE_ENCODER']:
                    self.encoder_outputs, self.encoder_state, self.vae_loss = modelInitVAEBidirectionalEncoder(
                        self.encoder_cell,
                        self.encoder_inputs_embedded,
                        self.encoder_inputs_length,
                        encoder_type='stack')
                else:
                    self.encoder_outputs, self.encoder_state = modelInitBidirectionalEncoder(
                        self.encoder_cell,
                        self.encoder_inputs_embedded,
                        self.encoder_inputs_length,
                        encoder_type='stack')
            else:
                if config['VAE_ENCODER']:
                    self.encoder_outputs, self.encoder_state, self.vae_loss = modelInitVAEUndirectionalEncoder(
                        self.encoder_cell, self.encoder_inputs_embedded,
                        self.encoder_inputs_length)
                else:
                    self.encoder_outputs, self.encoder_state = modelInitUndirectionalEncoder(
                        self.encoder_cell, self.encoder_inputs_embedded,
                        self.encoder_inputs_length)
            # SAE auxiliary loss: mean squared distance between the two
            # halves of the encoder outputs along axis 1 — presumably the
            # batch holds paired examples; verify against the data pipeline.
            o_1, o_2 = tf.split(self.encoder_outputs, 2, 1)
            euclidean_dis = tf.reduce_mean(tf.square(o_1 - o_2), 2)
            self.sae_loss = tf.reduce_mean(euclidean_dis)

            if config['USE_BS'] and not config['IS_TRAIN']:
                # Beam-search inference: tile encoder results BEAM_WIDTH
                # times.  Outputs are time-major, so transpose to batch-major
                # for tile_batch and back again.
                self.encoder_state = seq2seq.tile_batch(
                    self.encoder_state, config['BEAM_WIDTH'])
                self.encoder_outputs = tf.transpose(
                    seq2seq.tile_batch(
                        tf.transpose(self.encoder_outputs, [1, 0, 2]),
                        config['BEAM_WIDTH']), [1, 0, 2])
                self.encoder_inputs_length_att = seq2seq.tile_batch(
                    self.encoder_inputs_length, config['BEAM_WIDTH'])
            else:
                self.encoder_inputs_length_att = self.encoder_inputs_length

            self.encoder_variables = scope.trainable_variables()

        with tf.variable_scope("DynamicDecoder") as scope:
            self.decoder_cell = modelInitRNNCells(self.decoder_hidden_size,
                                                  config['DECODER_LAYERS'],
                                                  config['CELL'],
                                                  config['INPUT_DROPOUT'],
                                                  config['OUTPUT_DROPOUT'])
            if config['ATTENTION_DECODER']:
                self.decoder_cell = modelInitAttentionDecoderCell(
                    self.decoder_cell,
                    self.decoder_hidden_size,
                    self.encoder_outputs,
                    self.encoder_inputs_length_att,
                    att_type=config['ATTENTION_MECHANISE'],
                    wrapper_type='whole')

            initial_state = None

            # Seed the decoder with the encoder final state: replace the
            # first layer of the zero state with encoder_state.
            if config['USE_BS'] and not config['IS_TRAIN']:
                initial_state = self.decoder_cell.zero_state(
                    batch_size=self.batch_size * config['BEAM_WIDTH'],
                    dtype=tf.float32)
                if config['ATTENTION_DECODER']:
                    cat_state = tuple([self.encoder_state] +
                                      list(initial_state.cell_state)[:-1])
                    # BUGFIX: clone() returns a new AttentionWrapperState;
                    # the original discarded it, silently keeping the zero
                    # state and ignoring the encoder state entirely.
                    initial_state = initial_state.clone(cell_state=cat_state)
                else:
                    initial_state = tuple([self.encoder_state] +
                                          list(initial_state[:-1]))
            else:
                initial_state = self.decoder_cell.zero_state(
                    batch_size=self.batch_size, dtype=tf.float32)

                if config['ATTENTION_DECODER']:
                    cat_state = tuple([self.encoder_state] +
                                      list(initial_state.cell_state)[:-1])
                    # BUGFIX: rebind the cloned state (see above).
                    initial_state = initial_state.clone(cell_state=cat_state)
                else:
                    initial_state = tuple([self.encoder_state] +
                                          list(initial_state[:-1]))

            # Projects decoder outputs to vocabulary logits.
            self.output_projection_layer = layers_core.Dense(self.output_size,
                                                             use_bias=False)
            if config['IS_TRAIN']:
                self.train_outputs = modelInitDecoderForTrain(
                    self.decoder_cell, self.decoder_inputs_embedded,
                    self.decoder_inputs_length, initial_state,
                    self.output_projection_layer)
            if config['USE_BS'] and not config['IS_TRAIN']:
                # Beam-search inference; decoder_inputs[0] supplies the
                # start tokens for the batch.
                self.infer_outputs = modelInitDecoderForBSInfer(
                    self.decoder_cell, self.decoder_inputs[0],
                    self.output_word_embedding_matrix, config['BEAM_WIDTH'],
                    config['ID_END'], config['MAX_OUT_LEN'], initial_state,
                    self.output_projection_layer)
            else:
                self.infer_outputs = modelInitDecoderForGreedyInfer(
                    self.decoder_cell, self.decoder_inputs[0],
                    self.output_word_embedding_matrix, config['ID_END'],
                    config['MAX_OUT_LEN'], initial_state,
                    self.output_projection_layer)

            if config['IS_TRAIN']:
                # Cross-entropy loss; targets transposed to batch-major to
                # match the (batch, time) mask.
                self.train_loss = seq2seq.sequence_loss(
                    logits=self.train_outputs,
                    targets=tf.transpose(self.decoder_targets, perm=[1, 0]),
                    weights=self.decoder_targets_mask)
                # Auxiliary encoder losses only apply when the encoder is
                # trained here (no pre-trained encoder is being loaded).
                if config['VAE_ENCODER'] and config['PRE_ENCODER'] is None:
                    self.train_loss += 0.01 * self.vae_loss
                if config['SAE_ENCODER'] and config['PRE_ENCODER'] is None:
                    self.train_loss += self.sae_loss
                # Content-penalty rewards computed in Python; shadows the
                # rewards placeholder defined above.
                self.rewards = tf.py_func(contentPenalty, [
                    tf.transpose(self.encoder_inputs, perm=[1, 0]),
                    self.train_outputs,
                    tf.constant(config['SRC_DICT'], dtype=tf.string),
                    tf.constant(config['DST_DICT'], dtype=tf.string),
                    tf.transpose(self.decoder_targets, perm=[1, 0])
                ], tf.float32)
                # py_func loses static shape info; restore it from logits.
                self.rewards.set_shape(self.train_outputs.get_shape())
                if config['RL_ENABLE']:
                    self.train_loss_rl = rlloss.sequence_loss_rl(
                        logits=self.train_outputs,
                        rewards=self.rewards,
                        weights=self.decoder_targets_mask)
                else:
                    self.train_loss_rl = tf.constant(0.0)
                self.rewards_bleu = tf.py_func(bleuPenalty, [
                    tf.transpose(self.encoder_inputs, perm=[1, 0]),
                    self.train_outputs,
                    tf.constant(config['SRC_DICT'], dtype=tf.string),
                    tf.constant(config['DST_DICT'], dtype=tf.string),
                    tf.constant(config['HYP_FILE_PATH'], dtype=tf.string),
                    tf.constant(config['REF_FILE_PATH_FORMAT'],
                                dtype=tf.string)
                ], tf.float32)
                self.rewards_bleu.set_shape(self.train_outputs.get_shape())
                if config['BLEU_RL_ENABLE']:
                    self.train_loss_rl_bleu = rlloss.sequence_loss_rl(
                        logits=self.train_outputs,
                        rewards=self.rewards_bleu,
                        weights=self.decoder_targets_mask) / 2
                else:
                    self.train_loss_rl_bleu = tf.constant(0.0)
                self.eval_loss = seq2seq.sequence_loss(
                    logits=self.train_outputs,
                    targets=tf.transpose(self.decoder_targets, perm=[1, 0]),
                    weights=self.decoder_targets_mask)

                if config['TRAIN_ON_EACH_STEP']:
                    self.final_loss = self.train_loss
                    if config['RL_ENABLE']:
                        self.final_loss = self.final_loss + self.train_loss_rl
                    if config['BLEU_RL_ENABLE']:
                        self.final_loss = self.final_loss + self.train_loss_rl_bleu
                else:
                    self.final_loss = self.eval_loss

            self.decoder_variables = scope.trainable_variables()

        print('All Trainable Variables:')
        if config['PRE_ENCODER']:
            # Pre-trained encoder: freeze encoder + input-embedding vars.
            self.all_trainable_variables = list(
                set(tf.trainable_variables()).difference(
                    set(self.encoder_variables)).difference(
                        set(self.input_embedding_variables)))
        else:
            self.all_trainable_variables = tf.trainable_variables()
        print(self.all_trainable_variables)
        if config['IS_TRAIN']:
            if config['SPLIT_LR']:
                # BUGFIX: the original referenced self.embedding_variables,
                # which is never defined (AttributeError at graph build);
                # the embedding group is the input + output embedding vars.
                self.train_op = updateBP(self.final_loss, [
                    self.word_embedding_learning_rate,
                    self.encoder_learning_rate, self.decoder_learning_rate
                ], [
                    self.input_embedding_variables +
                    self.output_embedding_variables,
                    self.encoder_variables, self.decoder_variables
                ],
                                         self.optimizer,
                                         norm=config['CLIP_NORM'])
            else:
                self.train_op = updateBP(self.final_loss, [self.learning_rate],
                                         [self.all_trainable_variables],
                                         self.optimizer,
                                         norm=config['CLIP_NORM'])
        self.saver = initSaver(tf.global_variables(), config['MAX_TO_KEEP'])
        # Separate saver so a pre-trained encoder (+ its input embedding)
        # can be restored into this graph.
        self.encoder_saver = initSaver(
            self.encoder_variables + self.input_embedding_variables,
            config['MAX_TO_KEEP'])