Example #1

import os
import pathlib
import pickle

import tensorflow as tf

# Encoder, Decoder, BahdanauAttention and translate() are assumed to be defined
# elsewhere in this project.
def transform(sentence, l1, l2, op_lang):
    if op_lang == "hi":
        dir = "training_checkpoints"
        with open("tensors1.pkl", 'rb') as f:
            example_input_batch = pickle.load(f)
            example_target_batch = pickle.load(f)
    elif op_lang == "nep":
        dir = "training_nepali"
        with open("tensors_nep.pkl", 'rb') as f:
            example_input_batch = pickle.load(f)
            example_target_batch = pickle.load(f)

    elif op_lang == "pun":
        dir = "training_punjabi"
        with open("tensors_pun.pkl", 'rb') as f:
            example_input_batch = pickle.load(f)
            example_target_batch = pickle.load(f)
    print(example_input_batch)

    global input_lang, output_lang
    input_lang = l1
    output_lang = l2
    global BATCH_SIZE, units, embedding_dim
    BATCH_SIZE = 64
    units = 512
    if op_lang == "pun":
        units = 1024
    embedding_dim = 256
    vocab_inp_size = len(input_lang.word2index) + 1
    vocab_tar_size = len(output_lang.word2index) + 1
    global encoder, decoder, optimizer
    optimizer = tf.keras.optimizers.Adam()
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
    sample_hidden = encoder.initialize_hidden_state()

    sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)
    attention_layer = BahdanauAttention(10)
    attention_result, attention_weights = attention_layer(
        sample_hidden, sample_output)
    decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)
    sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),
                                          sample_hidden, sample_output)

    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder)

    curr = pathlib.Path(__file__)
    print(encoder.summary())
    temp = os.path.join(curr.parent, "available_models")
    checkpoint_dir = os.path.join(temp, dir)
    print(checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    print(checkpoint_dir)
    print(checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)))
    # print(checkpoint.restore(checkpoint_dir))
    print(encoder.summary())
    return translate(sentence)
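
A hedged usage sketch for the entry point above; the pickle file names lang1_hi.pkl and lang2_hi.pkl are illustrative assumptions, and the Lang objects are expected to expose the word2index mapping that transform() uses:

import pickle

if __name__ == "__main__":
    # Load the two pickled Lang vocabularies (hypothetical file names).
    with open("lang1_hi.pkl", "rb") as f:
        l1 = pickle.load(f)
    with open("lang2_hi.pkl", "rb") as f:
        l2 = pickle.load(f)
    print(transform("how are you", l1, l2, op_lang="hi"))
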
Example #2

import os

import anyconfig
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

import dataset
from model import Encoder, Decoder, BahdanauAttention

if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    config = anyconfig.load(open("config.yaml", 'rb'))
    BATCH_SIZE = 1
    SEQ_LENGTH = config["trainer"]["sequnce_length"]
    # encoder
    encoder = Encoder(config["trainer"]["gru_units"], BATCH_SIZE)
    sample_hidden = encoder.initialize_hidden_state()
    # attention
    attention_layer = BahdanauAttention(config["trainer"]["attention_units"])
    # decoder
    decoder = Decoder(config["trainer"]["label_length"],
                      config["trainer"]["gru_units"], BATCH_SIZE)
    # reload the checkpoint
    optimizer = tf.keras.optimizers.Adam()
    checkpoint_dir = config["trainer"]["checkpoint_dir"]
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder)
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

    # test
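
    # A minimal greedy-decoding sketch to fill in the test step (an addition, not
    # part of the original example). It assumes that `dataset.preprocess` returns a
    # batch-of-one input tensor compatible with the Encoder, that label id 0 marks
    # the start token, and that the Decoder follows the usual
    # (predictions, hidden, attention) call signature; all of these are assumptions
    # about this project's `dataset` and `model` modules.
    sample = dataset.preprocess("example input", SEQ_LENGTH)  # hypothetical helper
    enc_output, enc_hidden = encoder(sample, sample_hidden)
    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([0], 0)  # assumed start-token id
    predicted_ids = []
    for _ in range(SEQ_LENGTH):  # decode at most SEQ_LENGTH steps (arbitrary cap)
        predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
        predicted_id = int(tf.argmax(predictions[0]))
        predicted_ids.append(predicted_id)
        dec_input = tf.expand_dims([predicted_id], 0)
    print(predicted_ids)
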
Example #3

import time

import tensorflow as tf

# PATH_TO_FILE, SAVE_PATH, NNConfig, load_dataset, save_index2word,
# save_max_length, Encoder and Decoder are assumed to be defined elsewhere
# in this project.
def train():
    print("Loading training data...")
    # Get sentences to tensors
    input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(
        PATH_TO_FILE, num_examples=None)

    # NNConfig
    vocab_inp_size = len(inp_lang.word_index) + 1
    vocab_targ_size = len(targ_lang.word_index) + 1
    config = NNConfig(vocab_inp_size, vocab_targ_size)

    # Save i2w file for test and translate
    save_index2word(inp_lang, "./dataset/input_dict.txt")
    save_index2word(targ_lang, "./dataset/target_dict.txt")
    save_max_length(input_tensor, target_tensor, vocab_inp_size,
                    vocab_targ_size, "./dataset/max_len.txt")

    # Set up the training data batches
    BUFFER_SIZE = len(input_tensor)
    steps_per_epoch = len(input_tensor) // config.BATCH_SIZE

    dataset = tf.data.Dataset.from_tensor_slices(
        (input_tensor, target_tensor)).shuffle(BUFFER_SIZE)
    dataset = dataset.batch(config.BATCH_SIZE, drop_remainder=True)

    print("Setting Seq2Seq model...")
    # Setup the NN Structure
    encoder = Encoder(config.VOCAB_INP_SIZE, config.EMBEDDING_DIM,
                      config.UNITS, config.BATCH_SIZE)
    decoder = Decoder(config.VOCAB_TARG_SIZE, config.EMBEDDING_DIM,
                      config.UNITS, config.BATCH_SIZE)
    # Setup optimizer
    optimizer = tf.keras.optimizers.Adam()
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    # Setup Checkpoint
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder)

    def loss_function(real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        return tf.reduce_mean(loss_)

    @tf.function
    def train_step(inp, targ, enc_hidden, optimizer):
        loss = 0
        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(inp, enc_hidden)
            dec_hidden = enc_hidden
            dec_input = tf.expand_dims([targ_lang.word_index['<start>']] *
                                       config.BATCH_SIZE, 1)

            # Teacher forcing - feeding the target as the next input
            for t in range(1, targ.shape[1]):
                # passing enc_output to the decoder
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden,
                                                     enc_output)
                loss += loss_function(targ[:, t], predictions)
                # using teacher forcing
                dec_input = tf.expand_dims(targ[:, t], 1)

        batch_loss = (loss / int(targ.shape[1]))
        variables = encoder.trainable_variables + decoder.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        return batch_loss

    print("Start training ...")
    for epoch in range(config.EPOCHS):

        start = time.time()
        enc_hidden = encoder.initialize_hidden_state()

        total_loss = 0

        for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):

            batch_loss = train_step(inp, targ, enc_hidden, optimizer)
            total_loss += batch_loss

            if batch % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(
                    epoch + 1, batch, batch_loss.numpy()))

        # saving (checkpoint) the model every 2 epochs
        if (epoch + 1) % 2 == 0:
            checkpoint.save(file_prefix=SAVE_PATH)

        print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                            total_loss / steps_per_epoch))
        print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
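
The helpers save_index2word and save_max_length are not shown in this example; below is a minimal sketch of what they could look like, assuming a Keras Tokenizer-style object with an index_word mapping and numpy-padded tensors. The file formats here are assumptions, not necessarily the project's actual ones:

def save_index2word(lang, path):
    # Hypothetical helper: dump "index<TAB>word" pairs, one per line.
    with open(path, "w", encoding="utf-8") as f:
        for idx, word in lang.index_word.items():
            f.write("{}\t{}\n".format(idx, word))

def save_max_length(input_tensor, target_tensor, vocab_inp_size, vocab_targ_size, path):
    # Hypothetical helper: record the padded lengths and vocabulary sizes used at
    # training time so that test/translate can rebuild identical shapes.
    with open(path, "w", encoding="utf-8") as f:
        f.write("{} {} {} {}\n".format(input_tensor.shape[1], target_tensor.shape[1],
                                       vocab_inp_size, vocab_targ_size))
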
Example #4

import os
import time

import tensorflow as tf

# Encoder, Decoder, load_vocabulary_data and nmt_generator_from_data are assumed
# to be defined elsewhere in this project.
def train():
    # args is a global variable in this task
    BATCH_SIZE = args.batch_size
    EPOCH = args.epoch
    EMBED_DIM = args.embeddingDim
    MAXLEN = args.maxLen
    NUM_UNITS = args.units
    LEARNING_RATE = args.learning_rate
    DROPOUT = args.dropout
    METHOD = args.method
    GPUNUM = args.gpuNum
    CKPT = args.checkpoint
    LIMIT = args.limit
    start_word = "<s>"
    end_word = "</s>"
    #The tokenizer stores all the information needed to split the data;
    #it is not itself part of the data.
    source_vocabulary_data = load_vocabulary_data(filepath=r'./data/vocab.50K.en', limit=None)
    target_vocabulary_data = load_vocabulary_data(filepath=r'./data/vocab.50K.de', limit=None)
    buffer_size = BATCH_SIZE*100

    #train_source_tensor, train_source_tokenizer, train_target_tensor, train_target_tokenizer = \
    #    load_data(pad_length = MAXLEN, limit=LIMIT)

    #train_source_tensor, val_source_tensor, train_target_tensor, val_target_tensor = \
    #    train_test_split(train_source_tensor, train_target_tensor, random_state=2019)

    vocab_source_size = len(source_vocabulary_data)
    print("vocab_input_size: ",vocab_source_size)
    vocab_target_size = len(target_vocabulary_data)
    print("vocab_target_size: ", vocab_target_size)

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.001)
    # dataset = generator_input_fn(train_source_tensor, train_target_tensor, buffer_size, EPOCH, BATCH_SIZE, MAXLEN)
    apply_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    encoder = Encoder(vocab_source_size, EMBED_DIM, NUM_UNITS, dropout_rate = DROPOUT, batch_size=BATCH_SIZE)
    decoder = Decoder(vocab_target_size, EMBED_DIM, NUM_UNITS, batch_size= BATCH_SIZE, method= None, dropout_rate=DROPOUT)
    # set up checkpoint
    if not os.path.exists(CKPT):
        os.makedirs(CKPT)
    else:
        print("Warning: current Checkpoint dir already exist! ",
              "\nPlease consider to choose a new dir to save your checkpoint!")
    checkpoint = tf.train.Checkpoint(optimizer = optimizer, encoder=encoder,
                                     decoder=decoder)
    checkpoint_prefix = os.path.join(CKPT, "ckpt")


    def train_wrapper(source, target, train_target_tokenizer):
        # with tf.GradientTape(watch_accessed_variables=False) as tape:
        with tf.GradientTape() as tape:
            # source_out, source_state, source_trainable_var, tape = encoder(source, encoder_state, vocab_source_size,
            #                                                          EMBED_DIM, NUM_UNITS, activation="tanh",
            #                                                          dropout_rate = DROPOUT)
            source_out, source_state = encoder(source, encoder_state, activation="tanh")

            initial = tf.expand_dims([train_target_tokenizer.word_index[start_word]] * BATCH_SIZE, 1)
            attention_state = tf.zeros((BATCH_SIZE, 1, EMBED_DIM))
            # cur_total_loss is the summed loss over the current steps, i.e. the batch loss
            cur_total_loss, cur_total_plex, cur_loss = 0, 0, 0
            for i in range(1, target.shape[1]):
                output_state, source_state, attention_state = decoder(initial, source_state, source_out, attention_state)
                # TODO: check for the case where target is 0
                cur_loss = apply_loss(target[:, i], output_state)
                perplex = tf.nn.sparse_softmax_cross_entropy_with_logits(target[:, i], output_state)
                # 0 should be the padding value in target.
                # We assume there are no real 0 values in target;
                # as a safety measure, we apply this mask to the final loss.
                # The mask is an array of binary values (0 or 1).
                mask = tf.math.logical_not(tf.math.equal(target[:, i], 0))
                mask = tf.cast(mask, dtype=cur_loss.dtype)
                cur_loss *= mask
                perplex *= mask
                cur_total_loss += tf.reduce_mean(cur_loss)
                cur_total_plex += tf.reduce_mean(perplex)
                initial = tf.expand_dims(target[:, i], 1)
                # print(cur_loss)
                # print(cur_total_loss)
        batch_loss = cur_total_loss / target.shape[1]
        batch_perplex = cur_total_plex / target.shape[1]
        ## debug
        variables = encoder.trainable_variables + decoder.trainable_variables
        # print("check variable: ", len(variables))
        #variables = encoder.trainable_variables
        # print("check var:", len(variables), variables[12:])
        gradients = tape.gradient(cur_total_loss, variables)
        # print("check gradient: ", len(gradients))
        # g_e = [type(ele) for ele in gradients if not isinstance(ele, tf.IndexedSlices)]
        # sum_g = [ele.numpy().sum() for ele in gradients if not isinstance(ele, tf.IndexedSlices)]

        # print(len(gradients), len(sum_g))
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5)
        optimizer.apply_gradients(zip(clipped_gradients, variables))
        return batch_loss, batch_perplex
    # print(len(train_source_tensor),BATCH_SIZE,training_steps,LIMIT)
    for epoch in range(EPOCH):
        per_epoch_loss, per_epoch_plex = 0, 0
        start = time.time()
        encoder_hidden = encoder.initialize_hidden_state()
        encoder_cell = encoder.initialize_cell_state()
        encoder_state = [[encoder_hidden, encoder_cell], [encoder_hidden, encoder_cell],
                         [encoder_hidden, encoder_cell], [encoder_hidden, encoder_cell]]
        # TODO : Double check to make sure all re-initialization is performed
        gen = \
            nmt_generator_from_data(source_data_path=r"./data/train.en", target_data_path=r'./data/train.de',
                         source_vocabulary_data=source_vocabulary_data, target_vocabulary_data=target_vocabulary_data, pad_length=90, batch_size=BATCH_SIZE,
                         limit=args.limit)
        for idx, data in enumerate(gen):
            train_source_tensor, train_source_tokenizer, train_target_tensor, train_target_tokenizer = data
            source, target = tf.convert_to_tensor(train_source_tensor), tf.convert_to_tensor(train_target_tensor)
            #print(type(source))
            #print("The shape of source is: ",source.shape,train_source_tensor.shape)
            #print(source[0])
            #print(source[1])
            cur_total_loss, batch_perplex = train_wrapper(source, target, train_target_tokenizer)
            per_epoch_loss += cur_total_loss
            per_epoch_plex += tf.exp(batch_perplex).numpy()
            if idx % 10 == 0:
                # print("current step is: "+str(tf.compat.v1.train.get_global_step()))
                # print(dir(optimizer))
                with open("checkpoint/logger.txt","a") as filelogger:
                    print("current learning rate is:"+str(optimizer._learning_rate),file=filelogger)
                    print('Epoch {}/{} Batch {} Loss {:.4f} perplex {:.4f}'.format(epoch + 1,
                                                                       EPOCH,
                                                                       idx + 10,
                                                                       cur_total_loss.numpy(),
                                                                        tf.exp(batch_perplex).numpy()),
                          file=filelogger)
                # tf.print(step)
                # print(dir(step))
                # print(int(step))

        with open("checkpoint/logger.txt", "a") as filelogger:
            print('Epoch {}/{} Total Loss per epoch {:.4f} - {} sec'.format(epoch + 1,
                                                                        EPOCH,
                                                                        per_epoch_loss / (idx+1.0),
                                                                        time.time() - start),file=filelogger)
            print("Epoch perplex: ",str(per_epoch_plex),file=filelogger)
        # TODO: for evaluation add bleu score
        if epoch % 1 == 0:
            print('Saving checkpoint for each epoch')
            checkpoint.save(file_prefix=checkpoint_prefix)

        if epoch == 3:
            optimizer._learning_rate = LEARNING_RATE / 10.0
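
The helper load_vocabulary_data is not shown here; a minimal sketch under the assumption that the vocab.50K.* files contain one token per line (the signature matches the calls above, the file format is an assumption):

def load_vocabulary_data(filepath, limit=None):
    # Hypothetical implementation: read one vocabulary token per line,
    # optionally truncated to `limit` entries.
    with open(filepath, encoding="utf-8") as f:
        vocab = [line.strip() for line in f]
    return vocab[:limit]
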
Example #5

import json
import os
import pickle
import random
import time
from argparse import Namespace

import tensorflow as tf

# Encoder, Decoder, BahdanauAttention, loadAndCreateDataset, statsDataPreprocessed,
# createDataset and trainStep are assumed to be defined elsewhere in this project.
def train(args: Namespace):
    '''
    Trains the model according to the various parameters passed when this function is called.

    args: carries all the required arguments
    '''

    #Downloads the fra-eng.zip dataset and extracts it (in my case to C:\Users\theof\.keras\datasets\), only if it has not already been downloaded and extracted
    pathFile = tf.keras.utils.get_file(
        'fra-eng.zip',
        origin=
        'http://storage.googleapis.com/download.tensorflow.org/data/fra-eng.zip',
        extract=True)

    #Create the first dataset
    inputTensor, inputVocab, inputSentencesSize, targetTensor, targetVocab, targetSentencesSize = loadAndCreateDataset(
        os.path.dirname(pathFile) + "/fra.txt", args.sentences_size)
    print('Datasets loaded')

    #Some useful information is printed to the terminal
    statsDataPreprocessed(inputTensor, targetTensor, inputVocab, targetVocab)

    epochSize = args.epoch
    batchSize = args.batch_size
    units = args.units
    bufferSize = len(inputTensor)
    embeddingDim = args.embedding_dim

    #Create the dataset used for training
    inputBatch, targetBatch, dataset = createDataset(inputTensor, targetTensor,
                                                     batchSize, bufferSize)

    #Store the vocabulary sizes in variables
    vocabInputSize = len(inputVocab.word_index) + 1
    vocabTargetSize = len(targetVocab.word_index) + 1

    #Create the encoder with the chosen characteristics
    encoderModel = Encoder(vocabInputSize, embeddingDim, units, batchSize)
    #Set all hidden states to 0
    encodeHiddenLayers = encoderModel.initialize_hidden_state()
    #Get the encoder outputs and hidden states
    encodedOutputLayers, encodeHiddenLayers = encoderModel(
        inputBatch, encodeHiddenLayers)

    print(
        'Encoder output shape: (batch size, sequence length, units) {}'.format(
            encodedOutputLayers.shape))
    print('Encoder Hidden state shape: (batch size, units) {}'.format(
        encodeHiddenLayers.shape))

    #Create the attention mechanism with a size of 10 units
    attention_layer = BahdanauAttention(10)
    attention_result, attention_weights = attention_layer(
        encodeHiddenLayers, encodedOutputLayers)

    print("Attention result shape: (batch size, units) {}".format(
        attention_result.shape))
    print(
        "Attention weights shape: (batch size, sequence_length, 1) {}".format(
            attention_weights.shape))

    #Create the decoder with the chosen characteristics
    decoder = Decoder(vocabTargetSize, embeddingDim, units, batchSize)
    sampleDecoderOutput, _, _ = decoder(tf.random.uniform(
        (batchSize, 1)), encodeHiddenLayers, encodedOutputLayers)

    print('Decoder output shape: (batch size, vocab size) {}'.format(
        sampleDecoderOutput.shape))

    #Define the training optimizer, Adam, which is based on a stochastic gradient algorithm
    optimizer = tf.keras.optimizers.Adam()
    #This object provides a method that computes the cross-entropy loss between the targets and the predictions.
    lossObject = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')

    #If no configuration name was given in the arguments
    if args.config_name is None:
        #The program generates a random number between 0 and 9999
        configName = random.randint(0, 9999)
    #Otherwise it copies the argument value into a variable
    else:
        configName = args.config_name

    #Path where the checkpoints and the configuration file for this training run are stored
    checkpointDirectory = './outputs/{}/'.format(configName)
    checkpoint_prefix = os.path.join(checkpointDirectory, "ckpt")

    #Object that stores the training weights for the optimizer, the encoder and the decoder
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoderModel,
                                     decoder=decoder)

    #For each training epoch
    for epoch in range(epochSize):
        start = time.time()
        #Initialize the encoder hidden states to 0
        enc_hidden = encoderModel.initialize_hidden_state()
        total_loss = 0

        #For each row of the dataset, get the sentences along with their positions (they share the same one)
        for (batch, (inp, targ)) in enumerate(dataset.take(len(inputTensor))):
            #Returns the loss
            batch_loss = trainStep(inp, targ, enc_hidden, encoderModel,
                                   decoder, targetVocab, lossObject, optimizer,
                                   batchSize)
            #Accumulate the total training loss
            total_loss += batch_loss
            #Print progress every 100 batches
            if batch % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(
                    epoch + 1, batch, batch_loss.numpy()))
        #If more than one epoch was requested
        if (epochSize > 1):
            #Save a checkpoint every other epoch
            if (epoch + 1) % 2 == 0:
                checkpoint.save(file_prefix=checkpoint_prefix)
        #Otherwise save the single training run (useful for tests)
        else:
            checkpoint.save(file_prefix=checkpoint_prefix)
        print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                            total_loss / len(inputTensor)))
        print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

    #Create an object holding all the information needed for future predictions
    config = {}
    config['bufferSize'] = bufferSize
    config['batchSize'] = batchSize
    config['embeddingDim'] = embeddingDim
    config['units'] = units
    config['epoch'] = epochSize
    config['vocabInputSize'] = vocabInputSize
    config['vocabTargetSize'] = vocabTargetSize
    config['inputSentencesSize'] = inputSentencesSize
    config['targetSentencesSize'] = targetSentencesSize

    #Save the configuration file as JSON in the directory chosen by the user
    with open('{}config.json'.format(checkpointDirectory),
              'w',
              encoding='UTF-8') as handle:
        json.dump(config, handle, indent=2, sort_keys=True)
    #Save the tokenizers of both languages, needed for translation
    with open('{}inputVocab.pickle'.format(checkpointDirectory),
              'wb') as handle:
        pickle.dump(inputVocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open('{}targetVocab.pickle'.format(checkpointDirectory),
              'wb') as handle:
        pickle.dump(targetVocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
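
The artifacts written above (config.json plus the two pickled tokenizers) are what a prediction script would need to rebuild the model; a minimal reload sketch, where the function name load_training_artifacts is hypothetical and the paths mirror the layout used above:

import json
import os
import pickle

def load_training_artifacts(checkpointDirectory):
    # Read back the training configuration and both tokenizers.
    with open(os.path.join(checkpointDirectory, 'config.json'), encoding='UTF-8') as handle:
        config = json.load(handle)
    with open(os.path.join(checkpointDirectory, 'inputVocab.pickle'), 'rb') as handle:
        inputVocab = pickle.load(handle)
    with open(os.path.join(checkpointDirectory, 'targetVocab.pickle'), 'rb') as handle:
        targetVocab = pickle.load(handle)
    return config, inputVocab, targetVocab
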
Example #6

import collections
import datetime
import json
import math
import time

import jieba
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from sklearn.model_selection import RepeatedKFold
from tensorflow.keras.utils import plot_model
from tqdm import tqdm

# Encoder, Decoder and BahdanauAttention are assumed to be defined elsewhere
# in this project.
class MyQA:
    # Mixed training corpus
    me_train_dir = './baiduQA_corpus/me_train.json'
    # Test set
    # One question paired with a single passage, answer available
    me_test_ann_dir = './baiduQA_corpus/me_test.ann.json'
    # One question paired with multiple passages, answer may or may not exist
    me_test_ir_dir = './baiduQA_corpus/me_test.ir.json'
    # Validation set
    # One question paired with a single passage, answer available
    me_validation_ann_dir = './baiduQA_corpus/me_validation.ann.json'
    # One question paired with multiple passages, answer may or may not exist
    me_validation_ir_dir = './baiduQA_corpus/me_validation.ir.json'

    # Stop words

    def __init__(self):

        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = './logs/gradient_tape/' + current_time + '/train'
        test_log_dir = './logs/gradient_tape/' + current_time + '/test'
        self.train_summary_writer = tf.summary.create_file_writer(
            train_log_dir)
        self.test_summary_writer = tf.summary.create_file_writer(test_log_dir)

        self.m = tf.keras.metrics.SparseCategoricalAccuracy()
        # self.recall = tf.keras.metrics.Recall()
        self.recall = [0]
        # self.F1Score = 2*self.m.result()*self.recall.result()/(self.recall.result()+self.m.result())
        self.BATCH_SIZE = 128
        self.embedding_dim = 24
        self.units = 64
        # Experiment with datasets of different sizes
        stop_word_dir = './stop_words.utf8'
        self.stop_words = self.get_stop_words(stop_word_dir) + ['']
        num_examples = 30000
        QA_dir = './QA_data.txt'
        # QA_dir = 'C:/Users/Administrator/raw_chat_corpus/qingyun-11w/qinyun-11w.csv'
        self.input_tensor, self.target_tensor, self.inp_tokenizer, self.targ_tokenizer = self.load_dataset(
            QA_dir, num_examples)
        self.num_classes = len(self.targ_tokenizer.index_word)  # number of target word classes
        # Initialize the confusion matrices (one for training, one for testing):
        self.train_confusion_matrix = tfa.metrics.MultiLabelConfusionMatrix(
            num_classes=self.num_classes)
        self.test_confusion_matrix = tfa.metrics.MultiLabelConfusionMatrix(
            num_classes=self.num_classes)

        self.F1Score = tfa.metrics.F1Score(num_classes=len(
            self.targ_tokenizer.index_word),
                                           average="micro")
        # self.F1Score = tfa.metrics.F1Score(num_classes=self.max_length_targ, average="micro")
        # input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(
        #     self.input_tensor,
        #     self.target_tensor,
        #     test_size=0.2)
        # self.load_split_dataset(input_tensor_train,target_tensor_train)
        self.vocab_inp_size = len(self.inp_tokenizer.word_index) + 1
        self.vocab_tar_size = len(self.targ_tokenizer.word_index) + 1

        # Encoder initialization
        self.encoder = Encoder(self.vocab_inp_size, self.embedding_dim,
                               self.units, self.BATCH_SIZE)
        plot_model(self.encoder,
                   to_file='encoder.png',
                   show_shapes=True,
                   show_layer_names=True,
                   rankdir='TB',
                   dpi=900,
                   expand_nested=True)
        # Sample input
        # sample_hidden = self.encoder.initialize_hidden_state()
        # sample_output, sample_hidden = self.encoder.call(self.example_input_batch, sample_hidden)
        # print('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
        # print('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))

        # Attention initialization
        attention_layer = BahdanauAttention(10)
        # attention_result, attention_weights = attention_layer(sample_hidden, sample_output)
        plot_model(attention_layer,
                   to_file='attention_layer.png',
                   show_shapes=True,
                   show_layer_names=True,
                   rankdir='TB',
                   dpi=900,
                   expand_nested=True)

        # print("Attention result shape: (batch size, units) {}".format(attention_result.shape))
        # print("Attention weights shape: (batch_size, sequence_length, 1) {}".format(attention_weights.shape))

        # Decoder initialization
        self.decoder = Decoder(self.vocab_tar_size, self.embedding_dim,
                               self.units, self.BATCH_SIZE)
        plot_model(self.decoder,
                   to_file='decoder.png',
                   show_shapes=True,
                   show_layer_names=True,
                   rankdir='TB',
                   dpi=900,
                   expand_nested=True)
        # sample_decoder_output, _, _ = self.decoder(tf.random.uniform((self.BATCH_SIZE, 1)),
        #                                       sample_hidden, sample_output)
        #
        # print('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))

        # Optimizer initialization
        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')

        # checkpoint & save model as object 初始化
        self.checkpoint_dir = './training_checkpoints'
        self.checkpoint_prefix = os.path.join(self.checkpoint_dir, "ckpt")
        self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer,
                                              encoder=self.encoder,
                                              decoder=self.decoder)

    def load_word2vec(self):
        import gensim
        path = "./word2vec_model/baike_26g_news_13g_novel_229g.bin"
        self.word2vec_model = gensim.models.KeyedVectors.load_word2vec_format(
            path, limit=500000, binary=True)
        print(self.word2vec_model.similarity("西红柿", "番茄"))

    def split_load_Data(self):
        # # Split into training and validation sets with an 80/20 ratio
        # input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(
        #     self.input_tensor,
        #     self.target_tensor,
        #     test_size=0.2)
        # input_tensor_train, input_tensor_val, = \
        self.splited_data_group = self.Repeated_KFold(
            input=self.input_tensor, target=self.target_tensor)
        # print(self.splited_data_group)
        # target_tensor_train, target_tensor_val = self.Repeated_KFold(input=self.target_tensor)
        # self.load_split_dataset(input_tensor_train, target_tensor_train)
    def Repeated_KFold(self, input, target, n_splits=10, n_repeats=1):
        # X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
        # y = np.array([1, 2, 3, 4])
        splited_data_group = []
        kf = RepeatedKFold(n_splits=n_splits,
                           n_repeats=n_repeats,
                           random_state=0)
        for train_index, test_index in kf.split(input):
            splited_data = {}
            input_train = np.array(
                list(map(lambda x: input[x, :], train_index)))
            splited_data['input_train'] = input_train
            input_val = np.array(list(map(lambda x: input[x, :], test_index)))
            splited_data['input_val'] = input_val
            target_train = np.array(
                list(map(lambda x: target[x, :], train_index)))
            splited_data['target_train'] = target_train
            target_val = np.array(list(map(lambda x: target[x, :],
                                           test_index)))
            splited_data['target_val'] = target_val
            splited_data_group.append(splited_data)
            # print('train_index', train_index, 'test_index', test_index)
        return splited_data_group

    def load_split_dataset(self, input_tensor, target_tensor):
        # Show the lengths
        # print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))
        # 24000 24000 6000 6000
        # me_train = json.load(open(self.me_train_dir))
        # print(me_train['Q_TRN_010878'])
        # print(me_train['Q_TRN_010878']['question'])
        #
        # print("Input ; index to word mapping")
        # self.convert(self.inp_tokenizer, input_tensor_train[0])
        # print()
        # print("Target ; index to word mapping")
        # self.convert(self.targ_tokenizer, target_tensor_train[0])

        # Create a tf.data dataset
        self.BUFFER_SIZE = len(input_tensor)
        self.steps_per_epoch = len(input_tensor) // self.BATCH_SIZE

        dataset = tf.data.Dataset.from_tensor_slices(
            (input_tensor, target_tensor)).shuffle(
                self.BUFFER_SIZE)  # BUFFER_SIZE controls how thoroughly the data are shuffled; larger is stronger

        dataset = dataset.batch(
            self.BATCH_SIZE, drop_remainder=True
        )  # drop_remainder=True: drop the last batch when it has fewer than BATCH_SIZE samples (TF v1.10+)
        self.example_input_batch, self.example_target_batch = next(
            iter(dataset))  # (TF v1.10+)
        print(self.example_input_batch.shape,
              self.example_target_batch.shape)
        return dataset

    def json_load(self, s: str):
        return json.load(open(s))

    def get_stop_words(self, path):
        with open(path, encoding='utf8') as f:
            return [l.strip() for l in f]

    def segement(self, strs):
        sw = self.stop_words
        words = jieba.lcut(strs, cut_all=False)
        # get_TF(words)
        # words = ["<start>"]+words
        # words = words.append("<end>")
        return [x for x in words if x not in sw]

    def get_TF(self, words, topK=20):
        tf_dic = {}
        for w in words:
            tf_dic[w] = tf_dic.get(w, 0) + 1
        return sorted(tf_dic.items(), key=lambda x: x[1], reverse=True)[:topK]

    def preprocess_sentence(self, s):
        words = self.segement(s)
        words = ["<start>"] + words + ["<end>"]
        return words

    def lls_to_ls(self, lls: list):
        '''
        :param lls: List[List[Str]]
        :return: List[Str]
        '''
        ls = []
        for l in lls:
            ls = ls + l
        # ls = str(lls)
        # ls = ls.replace('[', '')
        # ls = ls.replace(']', '')
        # ls = ls.split(',')
        print(ls)
        return ls

    def create_dataset(self, path, num_examples):
        data = self.json_load(path)
        # if num_examples ==None:
        #     num_examples = len(data)
        word_pairs = []
        pair = []
        for w in tqdm(list(data.values())[:num_examples]):
            lq = ["<start> "] + self.segement(w['question']) + [" <end>"]
            pair.append(" ".join(lq))

            # print(pair[0])
            for e in w['evidences'].values():
                if e['answer'][0] != 'no_answer':
                    # print(e['answer'])
                    la = ["<start> "] + self.segement(
                        e['answer'][0]) + [" <end>"]
                    pair.append(" ".join(la))
                    break
            if e['answer'][0] == 'no_answer':
                la = ["<start> "] + e['answer'] + [" <end>"]
                pair.append(" ".join(la))
            word_pairs.append(pair)
            # print(pair[1])
            # print(pair)
            pair = []
        # q= [x[0] for x in word_pairs]
        # print(q)
        # print(q[0])
        # print(q[-1])
        # q=lls_to_ls(q)
        # print("question的词频:",str(get_TF(q)))
        return zip(*word_pairs)
        # return [x[0] for x in word_pairs],[x[1] for x in word_pairs]

        # return zip(*zip(*word_pairs))

    def create_dataset_txt(self, path, num_examples):
        import io
        lines = io.open(path).read().strip().split('\n')

        word_pairs = [[self.preprocess_sentence(w) for w in l.split(' ')]
                      for l in lines[:num_examples]]

        # return [x[0] for x in word_pairs], [x[1] for x in word_pairs]
        return zip(*word_pairs)

    def create_dataset_csv(self, path, num_examples):
        import csv
        # lines = io.open(path, encoding='gbk').read().strip().split('\n')
        with open(path, "r") as csvfile:
            reader = csv.reader(csvfile, delimiter='|')
            # readlines is not needed here; materialize the rows before the file closes
            rows = list(reader)

        word_pairs = [[self.preprocess_sentence(w) for w in row]
                      for row in rows[:num_examples]]

        # return [x[0] for x in word_pairs], [x[1] for x in word_pairs]
        return zip(*word_pairs)

    # Q,A=create_dataset(me_train_dir,None)
    # print(Q[-1])
    # print(A[-1])
    def max_length(self, tensor):
        result = max(len(t) for t in tensor)
        print('Max length:')
        return result

    def average_length(self, tensor):
        result = sum(len(np.trim_zeros(t)) for t in tensor) // len(tensor)
        print('Average length:')
        print(result)
        return result

    def tokenize(self, input):
        tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='',
                                                          split=" ")
        # Build the vocabulary
        tokenizer.fit_on_texts(input)
        # Convert the texts to sequences of word indices (tensors)
        tensor = tokenizer.texts_to_sequences(input)
        # Zero-pad at the end

        # tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,
        #                                                        padding='post',
        #                                                        truncating='post',
        #                                                        value=0)
        return tensor, tokenizer

    def load_dataset(self, path, num_examples=None, txt=True):
        if txt:
            # inp, targ = create_dataset(path, num_examples)
            inp, targ = self.create_dataset_txt(path, num_examples)
        else:
            inp, targ = self.create_dataset(path, num_examples)
        # inp,targ = list(temp)[0], list(temp)[1]
        input_tensor, inp_tokenizer = self.tokenize(inp)
        target_tensor, targ_tokenizer = self.tokenize(targ)
        # Compute the (average) lengths used as max_length for padding
        self.max_length_inp, self.max_length_targ = self.average_length(
            input_tensor), self.average_length(target_tensor)
        input_tensor = tf.keras.preprocessing.sequence.pad_sequences(
            input_tensor,
            maxlen=self.max_length_inp,
            padding='post',
            truncating='post',
            value=0)
        target_tensor = tf.keras.preprocessing.sequence.pad_sequences(
            target_tensor,
            maxlen=self.max_length_targ,
            padding='post',
            truncating='post',
            value=0)

        print(self.max_length_inp)
        print(self.max_length_targ)
        return input_tensor, target_tensor, inp_tokenizer, targ_tokenizer

    # print(me_train['Q_TRN_010878']['evidences']['Q_TRN_010878#05'])

    #
    # q1= me_train['Q_TRN_010878']['question']
    # print(list(me_train.values())[:1])
    # print(q1)
    # # q1 = preprocess_sentence(q1)
    # seg_q1= segement(q1)
    # print(seg_q1)
    def convert(self, tokenizer, tensor):
        for t in tensor:
            if t != 0:
                print("%d ----> %s" % (t, tokenizer.index_word[t]))

    # loss
    def loss_function(self, real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = self.loss_object(real, pred)

        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask

        return tf.reduce_mean(loss_)
        # return tf.nn.softmax_cross_entropy_with_logits(loss_)

    @tf.function
    def train_step(self, inp, targ, enc_hidden):
        loss = 0
        accruacy = 0
        recall = 0
        with tf.GradientTape() as tape:
            enc_output, enc_hidden = self.encoder(inp, enc_hidden)

            dec_hidden = enc_hidden

            dec_input = tf.expand_dims(
                [self.targ_tokenizer.word_index['<start>']] * self.BATCH_SIZE,
                1)

            # Teacher forcing - feed the target word as the next input
            for t in range(1, targ.shape[1]):
                actual = targ[:, t]
                end = tf.constant([0])  # mask out the effect of the padding zeros
                mask = tf.math.not_equal(actual, end)
                reduced_actual = tf.boolean_mask(actual, mask)
                # self.m.reset_states()
                # pass the encoder output (enc_output) to the decoder
                predictions, dec_hidden, _ = self.decoder(
                    dec_input, dec_hidden, enc_output)
                predicted_id = tf.argmax(predictions, 1)
                reduced_predictions = tf.boolean_mask(predictions, mask)
                reduced_predicted_id = tf.argmax(reduced_predictions, 1)
                actual = tf.cast(actual, tf.float32)

                predicted_id = tf.cast(predicted_id, tf.float32)

                loss += self.loss_function(actual, predictions)

                reduced_actual_onehot = tf.one_hot(reduced_actual,
                                                   self.num_classes)
                # reduced_actual_onehot = tf.cast(reduced_actual_onehot,tf.float32)
                reduced_predicted_id_onehot = tf.one_hot(
                    reduced_predicted_id, self.num_classes)
                # reduced_predicted_id_onehot = tf.cast(reduced_predicted_id_onehot,tf.float32)
                self.train_confusion_matrix.update_state(
                    reduced_actual_onehot, reduced_predicted_id_onehot)
                _ = self.m.update_state(reduced_actual, reduced_predictions)
                reduced_actual = tf.cast(reduced_actual, tf.float32)
                reduced_predicted_id = tf.cast(reduced_predicted_id,
                                               tf.float32)

                # With teacher forcing (commented out):
                # dec_input = tf.expand_dims(targ[:, t], 1)
                # Here the model's own prediction is fed back instead:
                dec_input = tf.expand_dims(reduced_predicted_id, 1)
                # # Without teacher forcing:
                # dec_input = tf.expand_dims(predictions[:,t], 1)
        # recall = 1
        batch_accuracy = self.m.result()  # average accuracy across the batch (13 groups of 12 words)
        batch_loss = (loss / int(targ.shape[1])
                      )  # batch loss = accumulated loss / target sequence length
        # batch_pres, batch_recalls, batch_f1s = self.cal_precision_recall_F1Score(self.train_confusion_matrix.result().numpy())
        # self.train_confusion_matrix.reset_states()

        variables = self.encoder.trainable_variables + self.decoder.trainable_variables  # i.e. parameters such as W, b, U, h, c
        gradients = tape.gradient(loss, variables)  # i.e. the partial derivatives of the loss w.r.t. the parameters above
        self.optimizer.apply_gradients(zip(gradients, variables))

        return batch_loss, batch_accuracy

    def training(self, inp_train, tar_train, inp_val, tar_val):
        # train
        self.train_dataset = self.load_split_dataset(inp_train, tar_train)
        # self.val_dataset = self.load_split_dataset(inp_val,tar_val)
        EPOCHS = 20
        # gpu_options = tf.GPUOptions(allow_growth=True)
        # sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        for epoch in range(EPOCHS):
            start = time.time()
            self.m.reset_states()
            self.recall = []
            # self.recall.reset_states()
            self.F1Score.reset_states()
            enc_hidden = self.encoder.initialize_hidden_state()
            total_loss = 0
            total_accuracy = 0
            total_pres = 0
            total_recalls = 0
            total_f1s = 0

            test_total_loss = 0
            test_total_accuracy = 0
            test_total_pres = 0
            test_total_recalls = 0
            test_total_f1s = 0
            for (batch, (inp, targ)) in enumerate(
                    self.train_dataset.take(self.steps_per_epoch)
            ):  # split the full dataset into steps_per_epoch (2000) batches
                batch_loss, batch_accuracy = self.train_step(
                    inp, targ, enc_hidden)
                # batch_pres, batch_recalls, batch_f1s = self.cal_precision_recall_F1Score(self.train_confusion_matrix.result().numpy())
                # self.train_confusion_matrix.reset_states()
                total_loss += batch_loss
                total_accuracy += batch_accuracy
                # total_pres += batch_pres
                # total_recalls += batch_recalls
                # total_f1s += batch_f1s
                with self.train_summary_writer.as_default():
                    tf.summary.scalar('loss',
                                      total_loss / self.steps_per_epoch,
                                      step=epoch)
                    tf.summary.scalar('accuracy',
                                      total_accuracy / self.steps_per_epoch,
                                      step=epoch)
                #     tf.summary.scalar('Precision', total_pres / self.steps_per_epoch, step=epoch)
                #     tf.summary.scalar('Recall', total_recalls / self.steps_per_epoch, step=epoch)
                #     tf.summary.scalar('F1Score', total_f1s / self.steps_per_epoch, step=epoch)
            # for (batch, (inp, targ)) in enumerate(self.val_dataset.take(self.steps_per_epoch)):
            #     test_batch_loss, test_batch_accuracy, test_batch_pres, test_batch_recalls, test_batch_f1s = self.evaluate(inp_vec=inp, targ_vec=targ)
            #     test_total_loss += batch_loss
            #     test_total_accuracy += batch_accuracy
            #     test_total_pres += batch_pres
            #     test_total_recalls += batch_recalls
            #     test_total_f1s += batch_f1s
            if batch % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f} & Accuracy {:.4f}'.format(
                    epoch + 1, batch, batch_loss.numpy(),
                    batch_accuracy.numpy()))
                # print('Test Epoch {} Batch {} Loss {:.4f} & Accuracy {:.4f} & Precision {:.4f} & Recall {:.4f} & F1Score {:.4f}'.format(
                #         epoch + 1,
                #         batch,
                #         test_batch_loss.numpy(),
                #         test_batch_accuracy.numpy(),
                #         test_batch_pres,
                #         test_batch_recalls,
                #         test_batch_f1s))
            # Save (checkpoint) the model every 2 epochs
            if (epoch + 1) % 2 == 0:
                self.checkpoint.save(file_prefix=self.checkpoint_prefix)

            print('Epoch {} Loss {:.4f} & Accuracy{:.4f}'.format(
                epoch + 1, total_loss / self.steps_per_epoch,
                total_accuracy / self.steps_per_epoch))

            # print('Test Epoch {} Loss {:.4f} & Accuracy{:.4f} & F1Score{:.4f}'.format(epoch + 1,
            #       test_total_loss / self.steps_per_epoch,
            #       test_total_accuracy / self.steps_per_epoch,
            #       test_total_pres / self.steps_per_epoch,
            #       test_total_recalls / self.steps_per_epoch,
            #       test_total_f1s / self.steps_per_epoch))
            print('Time taken for 1 epoch {} sec\n'.format(time.time() -
                                                           start))
            # Reset the metrics
            self.m.reset_states()
            # print(self.train_confusion_matrix.result().numpy())
            pres, recalls, f1s = self.cal_precision_recall_F1Score(
                self.train_confusion_matrix.result().numpy())
            self.train_confusion_matrix.reset_states()
            with self.train_summary_writer.as_default():
                tf.summary.scalar('Precision', pres, step=epoch)
                tf.summary.scalar('Recall', recalls, step=epoch)
                tf.summary.scalar('F1Score', f1s, step=epoch)
            print('Epoch {} precision {:.4f} & recall{:.4f} & F1Score{:.4f}'.
                  format(epoch + 1, pres, recalls, f1s))
        # end train

    # Translation

    # @tf.function
    def evaluate(self, sentence='', inp_vec=[], targ_vec=[]):
        attention_plot = np.zeros((self.max_length_targ, self.max_length_inp))
        if inp_vec == []:
            inp = self.preprocess_sentence(sentence)

            # inp = sentence.split(' ')
            print(self.inp_tokenizer)

            inputs = []
            for i in inp:
                if i in self.inp_tokenizer.word_index:
                    inputs.append(self.inp_tokenizer.word_index[i])
                else:
                    self.active_learning(i)
                    # print('请问\''+i+'\'是什么意思?')
                    # inputs.append('')
        else:
            inputs = inp_vec
        # inputs = [inp_tokenizer.word_index[i] for i in inp]  # word-index vector
        inputs = tf.keras.preprocessing.sequence.pad_sequences(
            [inputs], maxlen=self.max_length_inp, padding='post', value=0)
        inputs = tf.convert_to_tensor(inputs)
        result_ids = [1]
        result_predictions = []
        result = ''

        hidden = [tf.zeros((1, self.units))]  # initialize first; it is overwritten after the checkpoint restore
        enc_out, enc_hidden = self.encoder(inputs, hidden)  # this invokes the encoder's call method
        loss = 0
        dec_hidden = enc_hidden
        dec_input = tf.expand_dims([self.targ_tokenizer.word_index['<start>']],
                                   0)
        for t in range(1, self.max_length_targ):
            predictions, dec_hidden, attention_weights = self.decoder(
                dec_input, dec_hidden, enc_out)
            result_predictions.append(predictions)
            # result_predictions.append(predictions)
            # store the attention weights for plotting later
            attention_weights = tf.reshape(attention_weights, (-1, ))
            attention_plot[t] = attention_weights.numpy()

            # For a text GAN, the discriminator needs a whole sentence (specifically
            # its embeddings) as input, so the index returned by max (i.e. argmax) is
            # used to build a one-hot vector, then an embedding, which is fed to the
            # discriminator.
            predicted_id = tf.argmax(predictions[0]).numpy()
            # The following computes metrics from targ_vec
            if targ_vec != []:
                end = tf.constant([0])  # mask out the effect of the padding zeros
                mask = tf.math.not_equal([targ_vec[t]], end)
                actual = tf.cast([targ_vec[t]], tf.float32)
                reduced_actual = tf.boolean_mask(actual, mask)
                mask2 = tf.math.not_equal([predicted_id], end)
                predicted_idx = tf.cast([predicted_id], tf.float32)
                reduced_predictions = tf.boolean_mask(predictions, mask2)
                reduced_predicted_idx = tf.boolean_mask(predicted_idx, mask2)
                _ = self.m.update_state(reduced_actual, reduced_predictions)
                _ = self.F1Score.update_state(reduced_actual,
                                              reduced_predicted_idx)

                reduced_actual = tf.cast(reduced_actual, tf.int32)
                reduced_predicted_idx = tf.cast(reduced_predicted_idx,
                                                tf.int32)
                # loss += self.loss_function(actual, predictions)
                reduced_actual_onehot = tf.one_hot(reduced_actual,
                                                   self.num_classes)
                # reduced_actual_onehot = tf.cast(reduced_actual_onehot,tf.float32)
                reduced_predicted_id_onehot = tf.one_hot(
                    reduced_predicted_idx, self.num_classes)
                # reduced_predicted_id_onehot = tf.cast(reduced_predicted_id_onehot,tf.float32)

                self.test_confusion_matrix.update_state(
                    reduced_actual_onehot, reduced_predicted_id_onehot)
                # n = len(self.targ_tokenizer.index_word)
                # self.recall = [0] * n
                # self.precision = [0] * n
                # for k in range(n):
                #     re = tf.keras.metrics.Recall()
                #     prec = tf.keras.metrics.Precision()
                #
                #     y_true = tf.equal(reduced_actual, k)
                #     y_pred = tf.equal(reduced_predicted_idx, k)
                #     re.update_state(y_true, y_pred)
                #     prec.update_state(y_true, y_pred)
                #     self.recall[k] = re.result().numpy()
                #     self.precision[k] = prec.result().numpy()

            # if self.targ_tokenizer.index_word[predicted_id] == '<end>' and t == 1:
            #     temp_predicts = list(predictions[0])
            #     del temp_predicts[predicted_id]
            #     predicted_id = tf.argmax(temp_predicts).numpy()
            #     result += self.targ_tokenizer.index_word[predicted_id] + ' '
            #     result_ids.append(predicted_id)

            if self.targ_tokenizer.index_word[predicted_id] == '<end>':
                result_ids.append(2)
                result_ids = tf.keras.preprocessing.sequence.pad_sequences(
                    [result_ids],
                    maxlen=self.max_length_targ,
                    padding='post',
                    value=0)
                return result, sentence, attention_plot, result_ids, result_predictions
            else:
                result += self.targ_tokenizer.index_word[predicted_id] + ' '
                result_ids.append(predicted_id)
                # result_predictions.append(predictions)
            # the predicted ID is fed back into the model
            dec_input = tf.expand_dims([predicted_id], 0)

        result_ids = tf.keras.preprocessing.sequence.pad_sequences(
            [result_ids], maxlen=self.max_length_targ, padding='post', value=0)

        # batch_accuracy = self.m.result()  # average accuracy across the batch
        # batch_loss = loss   # batch loss = summed loss / batch size
        # batch_pres, batch_recalls, batch_f1s = self.cal_precision_recall_F1Score(self.train_confusion_matrix.result().numpy())
        return result, sentence, attention_plot, result_ids, result_predictions,

    def getYesterday(self):
        today = datetime.date.today()
        oneday = datetime.timedelta(days=1)
        yesterday = today - oneday
        return yesterday

    def Automatic_train(self):
        filename = 'new_knowledge_' + datetime.datetime.now().strftime(
            '%Y-%m-%d') + '.txt'
        self.checkpoint.restore(tf.train.latest_checkpoint(
            self.checkpoint_dir))

        new_QA_dir = './new_knowledge_' + self.getYesterday().strftime(
            '%Y-%m-%d') + '.txt'

        input_tensor, target_tensor, inp_tokenizer, targ_tokenizer = self.load_dataset(
            new_QA_dir)
        new_max_length_inp, new_max_length_tar = self.max_length(
            input_tensor), self.max_length(target_tensor)
        self.max_length_inp, self.max_length_targ = max(
            new_max_length_inp,
            self.max_length_inp), max(new_max_length_tar, self.max_length_targ)
        # new_max_length = self.max_length()
        # Split into training and validation sets with an 80/20 ratio
        input_tensor_train, target_tensor_train = input_tensor, target_tensor
        dataset = tf.data.Dataset.from_tensor_slices(
            (input_tensor_train, target_tensor_train)).shuffle(
                self.BUFFER_SIZE)  # BUFFER_SIZE controls how thoroughly the data are shuffled; larger is stronger
        dataset = dataset.batch(
            self.BATCH_SIZE, drop_remainder=True
        )  # drop_remainder=True: drop the last batch when it has fewer than BATCH_SIZE samples (TF v1.10+)
        # train
        EPOCHS = 5
        # gpu_options = tf.GPUOptions(allow_growth=True)
        # sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        for epoch in range(EPOCHS):
            start = time.time()

            enc_hidden = self.encoder.initialize_hidden_state()
            total_loss = 0

            for (batch,
                 (inp,
                  targ)) in enumerate(dataset.take(self.steps_per_epoch)
                                      ):  # split the full dataset into steps_per_epoch (2000) batches
                batch_loss = self.train_step(inp, targ, enc_hidden)
                total_loss += batch_loss

                if batch % 100 == 0:
                    print('Epoch {} Batch {} Loss {:.4f}'.format(
                        epoch + 1, batch, batch_loss.numpy()))
            # Save (checkpoint) the model every 2 epochs
            if (epoch + 1) % 2 == 0:
                self.checkpoint.save(file_prefix=self.checkpoint_prefix)

            print('Epoch {} Loss {:.4f}'.format(
                epoch + 1, total_loss / self.steps_per_epoch))
            print('Time taken for 1 epoch {} sec\n'.format(time.time() -
                                                           start))
        # end train

    def active_learning(self, s):
        q = '请问\'' + s + '\'是什么意思?'  # i.e. "May I ask what '<s>' means?"
        # print(q)
        answer = input(q)
        if answer == '不知道':  # '不知道' means "I don't know"
            return
        filename = 'new_knowledge_' + datetime.datetime.now().strftime(
            '%Y-%m-%d') + '.txt'
        with open(filename, 'a') as file_object:
            file_object.write(q + ' ' + answer + '\n')
        file_object.close()

    # Attention weight plotting function
    def plot_attention(self, attention, sentence, predicted_sentence):
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(1, 1, 1)
        ax.matshow(attention, cmap='viridis')

        fontdict = {'fontsize': 14}

        ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
        ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)

        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

        plt.show()

    def translate(self, sentence):
        result, sentence, attention_plot, result_ids, result_predictions = self.evaluate(
            sentence)

        print('Input: %s' % (sentence))
        print('Predicted translation: {}'.format(result))
        # sentence =preprocess_sentence(sentence)
        # result = preprocess_sentence(result)
        # attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
        # preprocess_sentence

        attention_plot = attention_plot[:len(self.preprocess_sentence(
            result)), :len(self.preprocess_sentence(sentence))]
        self.plot_attention(attention_plot, self.preprocess_sentence(sentence),
                            self.preprocess_sentence(result))

    def bleu(self, pred_tokens, label_tokens, k):
        len_pred, len_label = len(pred_tokens), len(label_tokens)
        score = math.exp(min(0, 1 - len_label / len_pred))
        for n in range(1, k + 1):
            num_matches, label_subs = 0, collections.defaultdict(int)
            for i in range(len_label - n + 1):
                label_subs[''.join(
                    list(
                        map(lambda x: self.targ_tokenizer.index_word[x],
                            label_tokens[i:i + n])))] += 1
            for i in range(len_pred - n + 1):
                if label_subs[''.join(
                        list(
                            map(lambda x: self.targ_tokenizer.index_word[x],
                                pred_tokens[i:i + n])))] > 0:
                    num_matches += 1
                    label_subs[''.join(
                        list(
                            map(lambda x: self.targ_tokenizer.index_word[x],
                                pred_tokens[i:i + n])))] -= 1
            score *= math.pow(num_matches / (len_pred - n + 1),
                              math.pow(0.5, n))
        return score

    def cal_precision_recall_F1Score(self, cf_matrixes):
        num_classes = len(cf_matrixes)
        # NOTE: the original opened "wri2te_test.txt" for writing here, but all
        # f.write(...) calls below are commented out, so the unused (and never
        # closed) handle has been removed:
        # file_name = "wri2te_test.txt"
        # f = open(file_name, 'wb')
        total_precision = 0
        total_recall = 0
        total_F1Score = 0

        num_precison = 0
        num_recall = 0
        num_f1score = 0
        # np.array(cf_matrixes).tofile("result.txt")
        for i in range(num_classes):

            # f.write(np.array(cf_matrixes[i]).tostring().)

            # f.write('\n')
            precision = 0
            recall = 0
            f1score = 0
            # if( cf_matrixes[i, 1, 1] == 0 and cf_matrixes[i,0,1] ==0 and  )

            # skip classes with an empty denominator to avoid division by zero
            if (cf_matrixes[i, 1, 1] + cf_matrixes[i, 0, 1]) != 0:
                num_precison += 1
                precision = cf_matrixes[i, 1, 1] / (cf_matrixes[i, 1, 1] +
                                                    cf_matrixes[i, 0, 1])
                total_precision += precision

            if (cf_matrixes[i, 1, 1] + cf_matrixes[i, 1, 0]) != 0:
                num_recall += 1
                recall = cf_matrixes[i, 1, 1] / (cf_matrixes[i, 1, 1] +
                                                 cf_matrixes[i, 1, 0])
                total_recall += recall

            if (2 * cf_matrixes[i, 1, 1] + cf_matrixes[i, 0, 1] +
                    cf_matrixes[i, 1, 0]) != 0:
                num_f1score += 1
                f1score = 2 * cf_matrixes[i, 1, 1] / (
                    2 * cf_matrixes[i, 1, 1] + cf_matrixes[i, 0, 1] +
                    cf_matrixes[i, 1, 0])
                total_F1Score += f1score
        # print(total_precision,total_recall,total_F1Score)
        average_precision = 0
        average_recall = 0
        average_f1score = 0
        if num_precison != 0:
            average_precision = total_precision / num_precison
        if num_recall != 0:
            average_recall = total_recall / num_recall
        if num_f1score != 0:
            average_f1score = total_F1Score / num_f1score
        return average_precision, average_recall, average_f1score
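
The bleu method above implements the familiar BLEU-style score: a brevity penalty exp(min(0, 1 - len_label / len_pred)) multiplied by modified n-gram precisions raised to the power 0.5 ** n. A minimal standalone sketch of the same idea on plain token lists (a hypothetical helper, not part of the example; it skips the tokenizer lookup and, like the original, assumes the prediction has at least k tokens):

import collections
import math

def bleu_score(pred_tokens, label_tokens, k):
    # Brevity penalty: penalize predictions shorter than the reference.
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        # Count reference n-grams, then count clipped n-gram matches in the prediction.
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[tuple(label_tokens[i:i + n])] += 1
        for i in range(len_pred - n + 1):
            ngram = tuple(pred_tokens[i:i + n])
            if label_subs[ngram] > 0:
                num_matches += 1
                label_subs[ngram] -= 1
        # Exponent 0.5 ** n shrinks with n, so longer n-grams (harder to match)
        # are penalized less for a given precision.
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score

# bleu_score(['a', 'b', 'c', 'd'], ['a', 'b', 'c', 'd'], k=2) -> 1.0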
Example #7
0
def test(args: Namespace):
    cfg = json.load(open(args.config_path, 'r', encoding='UTF-8'))

    batch_size = 1  # for predicting one sentence.

    encoder = Encoder(cfg['vocab_input_size'], cfg['embedding_dim'],
                      cfg['units'], batch_size, 0)
    decoder = Decoder(cfg['vocab_target_size'], cfg['embedding_dim'],
                      cfg['units'], cfg['method'], batch_size, 0)
    optimizer = select_optimizer(cfg['optimizer'], cfg['learning_rate'])

    ckpt = tf.train.Checkpoint(optimizer=optimizer,
                               encoder=encoder,
                               decoder=decoder)
    manager = tf.train.CheckpointManager(ckpt,
                                         cfg['checkpoint_dir'],
                                         max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)

    while True:
        sentence = input(
            'Input Sentence or If you want to quit, type Enter Key : ')

        if sentence == '': break

        sentence = re.sub(r"(\.\.\.|[?.!,¿])", r" \1 ", sentence)
        sentence = re.sub(r'[" "]+', " ", sentence)

        sentence = '<s> ' + sentence.lower().strip() + ' </s>'

        input_vocab = load_vocab('./data/', 'en')
        target_vocab = load_vocab('./data/', 'de')

        input_lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
            filters='', oov_token='<unk>')
        input_lang_tokenizer.word_index = input_vocab

        target_lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
            filters='', oov_token='<unk>')
        target_lang_tokenizer.word_index = target_vocab

        convert_vocab(input_lang_tokenizer, input_vocab)
        convert_vocab(target_lang_tokenizer, target_vocab)

        inputs = [
            input_lang_tokenizer.word_index[i]
            if i in input_lang_tokenizer.word_index else
            input_lang_tokenizer.word_index['<unk>']
            for i in sentence.split(' ')
        ]
        inputs = tf.keras.preprocessing.sequence.pad_sequences(
            [inputs], maxlen=cfg['max_len_input'], padding='post')

        inputs = tf.convert_to_tensor(inputs)

        result = ''

        enc_hidden = encoder.initialize_hidden_state()
        enc_cell = encoder.initialize_cell_state()
        enc_state = [[enc_hidden, enc_cell], [enc_hidden, enc_cell],
                     [enc_hidden, enc_cell], [enc_hidden, enc_cell]]

        enc_output, enc_hidden = encoder(inputs, enc_state)

        dec_hidden = enc_hidden
        #dec_input = tf.expand_dims([target_lang_tokenizer.word_index['<eos>']], 0)
        dec_input = tf.expand_dims([target_lang_tokenizer.word_index['<s>']],
                                   1)

        print('dec_input:', dec_input)

        h_t = tf.zeros((batch_size, 1, cfg['embedding_dim']))

        for t in range(int(cfg['max_len_target'])):
            predictions, dec_hidden, h_t = decoder(dec_input, dec_hidden,
                                                   enc_output, h_t)

            # predictions shape == (1, 50002)

            predicted_id = tf.argmax(predictions[0]).numpy()
            print('predicted_id', predicted_id)

            result += target_lang_tokenizer.index_word[predicted_id] + ' '

            if target_lang_tokenizer.index_word[predicted_id] == '</s>':
                print('Early stopping')
                break

            dec_input = tf.expand_dims([predicted_id], 1)
            print('dec_input:', dec_input)

        print('<s> ' + result)
        print(sentence)
        sys.stdout.flush()
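
For reference, the two re.sub calls in the example above only pad punctuation with spaces and collapse runs of spaces before the sentence is wrapped in <s> ... </s>. A quick self-contained check of that preprocessing (illustrative input only):

import re

sentence = 'Hello, world!How are you?'
sentence = re.sub(r"(\.\.\.|[?.!,¿])", r" \1 ", sentence)  # put spaces around '...', '?', '.', '!', ',', '¿'
sentence = re.sub(r'[" "]+', " ", sentence)                # collapse repeated spaces
sentence = '<s> ' + sentence.lower().strip() + ' </s>'
print(sentence)  # <s> hello , world ! how are you ? </s>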
Example #8
0
def train(args: Namespace):
    input_tensor, target_tensor, input_lang_tokenizer, target_lang_tokenizer = load_dataset(
        './data/', args.max_len, limit_size=None)

    max_len_input = len(input_tensor[0])
    max_len_target = len(target_tensor[0])

    print('max len of each seq:', max_len_input, ',', max_len_target)

    input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(
        input_tensor, target_tensor, test_size=args.dev_split)

    # initialize hyperparameters
    EPOCHS = args.epoch
    batch_size = args.batch_size
    steps_per_epoch = len(input_tensor_train) // batch_size
    embedding_dim = args.embedding_dim
    units = args.units
    vocab_input_size = len(input_lang_tokenizer.word_index) + 1
    vocab_target_size = len(target_lang_tokenizer.word_index) + 1
    BUFFER_SIZE = len(input_tensor_train)
    learning_rate = args.learning_rate

    setattr(args, 'max_len_input', max_len_input)
    setattr(args, 'max_len_target', max_len_target)

    setattr(args, 'steps_per_epoch', steps_per_epoch)
    setattr(args, 'vocab_input_size', vocab_input_size)
    setattr(args, 'vocab_target_size', vocab_target_size)
    setattr(args, 'BUFFER_SIZE', BUFFER_SIZE)

    dataset = tf.data.Dataset.from_tensor_slices(
        (input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)
    dataset = dataset.batch(batch_size)

    print('dataset shape (batch_size, max_len):', dataset)

    encoder = Encoder(vocab_input_size, embedding_dim, units, batch_size,
                      args.dropout)
    decoder = Decoder(vocab_target_size, embedding_dim, units, args.method,
                      batch_size, args.dropout)

    optimizer = select_optimizer(args.optimizer, args.learning_rate)

    loss_object = tf.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                          reduction='none')

    @tf.function
    def train_step(_input, _target, enc_state):
        loss = 0

        with tf.GradientTape() as tape:
            enc_output, enc_state = encoder(_input, enc_state)

            dec_hidden = enc_state

            dec_input = tf.expand_dims(
                [target_lang_tokenizer.word_index['<s>']] * batch_size, 1)

            # First input feeding definition
            h_t = tf.zeros((batch_size, 1, embedding_dim))

            for idx in range(1, _target.shape[1]):
                # idx means target character index.
                predictions, dec_hidden, h_t = decoder(dec_input, dec_hidden,
                                                       enc_output, h_t)

                #tf.print(tf.argmax(predictions, axis=1))

                loss += loss_function(loss_object, _target[:, idx],
                                      predictions)

                dec_input = tf.expand_dims(_target[:, idx], 1)

        batch_loss = (loss / int(_target.shape[1]))

        variables = encoder.trainable_variables + decoder.trainable_variables

        gradients = tape.gradient(loss, variables)

        optimizer.apply_gradients(zip(gradients, variables))

        return batch_loss

    # Setting checkpoint
    now_time = dt.datetime.now().strftime("%m%d%H%M")
    checkpoint_dir = './training_checkpoints/' + now_time
    setattr(args, 'checkpoint_dir', checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder)

    os.makedirs(checkpoint_dir, exist_ok=True)

    # save the model configuration alongside the checkpoints
    with open('{}/config.json'.format(checkpoint_dir), 'w',
              encoding='UTF-8') as fout:
        json.dump(vars(args), fout, indent=2, sort_keys=True)

    min_total_loss = float('inf')  # track the best epoch loss seen so far

    for epoch in range(EPOCHS):
        start = time.time()

        enc_hidden = encoder.initialize_hidden_state()
        enc_cell = encoder.initialize_cell_state()
        enc_state = [[enc_hidden, enc_cell], [enc_hidden, enc_cell],
                     [enc_hidden, enc_cell], [enc_hidden, enc_cell]]

        total_loss = 0

        for (batch, (_input,
                     _target)) in enumerate(dataset.take(steps_per_epoch)):
            batch_loss = train_step(_input, _target, enc_state)
            total_loss += batch_loss

            if batch % 10 == 0:
                print('Epoch {}/{} Batch {}/{} Loss {:.4f}'.format(
                    epoch + 1, EPOCHS, batch, steps_per_epoch,
                    batch_loss.numpy()))

        print('Epoch {}/{} Total Loss per epoch {:.4f} - {} sec'.format(
            epoch + 1, EPOCHS, total_loss / steps_per_epoch,
            time.time() - start))

        # saving checkpoint
        if min_total_loss > total_loss / steps_per_epoch:
            print('Saving checkpoint...')
            min_total_loss = total_loss / steps_per_epoch
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('\n')
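
The train_step in the example above calls a loss_function helper that is not shown in this excerpt. A common masked implementation, offered here only as a sketch under that assumption (not necessarily the author's exact helper), zeroes out the loss wherever the target is the padding id 0, which is what the reduction='none' loss object is set up for:

import tensorflow as tf

def loss_function(loss_object, real, pred):
    # real: (batch,) target ids for one timestep, 0 = padding
    # pred: (batch, vocab) logits for that timestep
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)          # per-example loss (reduction='none')
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask                            # drop contributions from padded positions
    return tf.reduce_mean(loss_)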
Example #9
0
def test():
    #Dropout rate is 0 in test setup?
    BATCH_SIZE = args.batch_size
    EMBED_DIM = args.embeddingDim
    MAXLEN = args.maxLen
    NUM_UNITS = args.units
    CKPT = args.checkpoint
    LEARNING_RATE = args.learning_rate
    EPOCH = 1
    DROPOUT = 0
    start_word = "<s>"
    end_word = "</s>"

    test_source_tensor, test_source_tokenizer, test_target_tensor, test_target_tokenizer = \
        load_test_data(test_translate_from=r"./data/newstest2015.en", test_translate_to=r'./data/newstest2015.de',
           vocab_from=r'./data/vocab.50K.en', vocab_to=r'./data/vocab.50K.de',
           pad_length=90, limit=args.limit)
    #test_source_tensor, test_source_tokenizer, test_target_tensor, test_target_tokenizer = \
    #    load_data(pad_length = MAXLEN, limit=None)
    print(len(test_source_tensor))
    vocab_source_size = len(test_source_tokenizer.word_index) + 1
    print("vocab_input_size: ", vocab_source_size)
    vocab_target_size = len(test_target_tokenizer.word_index) + 1
    print("vocab_target_size: ", vocab_target_size)
    # note: this optimizer is replaced below by a GradientDescentOptimizer
    # before the checkpoint object is built
    optimizer = tf.optimizers.Adam(learning_rate=LEARNING_RATE)
    buffer_size = len(test_source_tensor)

    test_steps = len(test_source_tensor) // BATCH_SIZE
    #print(test_source_tensor[0])
    #print("Type: ",type(test_source_tensor))
    #for ele in test_target_tensor:
    #    print(type(ele))
    #print("dir of test_target__token: ",dir(test_target_tokenizer))
    #print(test_target_tokenizer.index_word)
    dataset = train_input_fn(test_source_tensor, test_target_tensor,
                             buffer_size, EPOCH, BATCH_SIZE)

    # apply_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    encoder = Encoder(vocab_source_size,
                      EMBED_DIM,
                      NUM_UNITS,
                      dropout_rate=DROPOUT,
                      batch_size=BATCH_SIZE)
    decoder = Decoder(vocab_target_size,
                      EMBED_DIM,
                      NUM_UNITS,
                      batch_size=BATCH_SIZE,
                      method=None,
                      dropout_rate=DROPOUT)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=0.001)
    ckpt = tf.train.Checkpoint(optimizer=optimizer,
                               encoder=encoder,
                               decoder=decoder)
    manager = tf.train.CheckpointManager(ckpt, args.checkpoint, max_to_keep=10)
    ckpt.restore(manager.latest_checkpoint)
    per_epoch_loss, per_epoch_plex = 0, 0

    def test_wrapper(source, target):
        # source_out, source_state, source_trainable_var, tape = encoder(source, encoder_state, vocab_source_size,
        #                                                          EMBED_DIM, NUM_UNITS, activation="tanh",
        #                                                          dropout_rate = DROPOUT)
        result = ""
        source_out, source_state = encoder(source,
                                           encoder_state,
                                           activation="tanh")

        initial = tf.expand_dims(
            [test_target_tokenizer.word_index[start_word]] * BATCH_SIZE, 1)
        attention_state = tf.zeros((BATCH_SIZE, 1, EMBED_DIM))
        apply_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
        # cur_total_loss is a sum of loss for current steps, namely batch loss
        cur_total_loss, cur_total_plex, cur_loss = 0, 0, 0
        #print("word_index: ", test_target_tokenizer.word_index,len(test_target_tokenizer.word_index))
        #print("index_word: ",test_target_tokenizer.index_word)
        for i in range(target.shape[1]):
            output_state, source_state, attention_state = decoder(
                initial, source_state, source_out, attention_state)
            # TODO: check for the case where target is 0
            # 0 is the padding value in target; real targets should not contain 0,
            # but for safety the padded positions are masked out of the final loss.
            # The mask is a binary (0/1) array.
            #print(output_state.numpy().shape)
            #print(output_state[0].numpy())
            cur_loss = apply_loss(target[:, i], output_state)
            perplex = tf.nn.sparse_softmax_cross_entropy_with_logits(
                target[:, i], output_state)
            current_ind = tf.argmax(output_state[0])
            mask = tf.math.logical_not(tf.math.equal(target[:, i], 0))
            mask = tf.cast(mask, dtype=cur_loss.dtype)
            cur_loss *= mask
            perplex *= mask
            cur_total_loss += tf.reduce_mean(cur_loss)
            cur_total_plex += tf.reduce_mean(perplex)
            #tf.print("check current id: ",current_ind)
            #tf.print(test_target_tokenizer.index_word[29])
            #print(current_ind.numpy())
            if current_ind.numpy() == 0:
                # 0 is for pad value, we don't need to record it
                continue
            result += test_target_tokenizer.index_word[current_ind.numpy()]
            if test_target_tokenizer.index_word[
                    current_ind.numpy()] == end_word:
                break
            initial = tf.expand_dims(target[:, i], 1)
        batch_loss = cur_total_loss / target.shape[1]
        batch_perplex = cur_total_plex / target.shape[1]
        return batch_loss, batch_perplex, result

    encoder_hidden = encoder.initialize_hidden_state()
    encoder_cell = encoder.initialize_cell_state()
    encoder_state = [[encoder_hidden, encoder_cell],
                     [encoder_hidden, encoder_cell],
                     [encoder_hidden, encoder_cell],
                     [encoder_hidden, encoder_cell]]
    # TODO : Double check to make sure all re-initialization is performed
    result_by_batch = []
    for idx, data in tqdm(enumerate(dataset.take(test_steps)),
                          total=args.limit):
        source, target = data
        batch_loss, batch_perplex, result = test_wrapper(source, target)
        with open("checkpoint/test_logger.txt", "a") as filelogger:
            print("The validation loss in batch " + str(idx) + " is : ",
                  str(batch_loss.numpy() / (idx + 1.0)),
                  file=filelogger)
            print("The validation perplex in batch " + str(idx) + " is : ",
                  str(batch_perplex.numpy() / (idx + 1.0)),
                  file=filelogger)
        per_epoch_loss += batch_loss
        per_epoch_plex += batch_perplex
        assert type(result) == str
        result_by_batch.append(result)
        #if idx>=3:
        #    break
    with open("checkpoint/test_logger.txt", "a") as filelogger:
        print("The validation loss is : ",
              str(per_epoch_loss.numpy() / (idx + 1.0)),
              file=filelogger)
        print("The validation perplex is: ",
              str(tf.exp(per_epoch_plex).numpy() / (idx + 1.0)),
              file=filelogger)
    return test_target_tokenizer, result_by_batch
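
The test above accumulates masked sparse-softmax cross-entropy per decoding step and exponentiates the accumulated total at the very end. Perplexity is conventionally the exponential of the average per-token cross-entropy; a minimal sketch of that formulation (illustrative names, not the exact bookkeeping used in the example):

import tensorflow as tf

def sequence_perplexity(labels, logits):
    # labels: (batch, seq_len) int token ids, 0 = padding
    # logits: (batch, seq_len, vocab) unnormalized decoder outputs
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    mask = tf.cast(tf.not_equal(labels, 0), ce.dtype)
    mean_ce = tf.reduce_sum(ce * mask) / tf.reduce_sum(mask)  # average over real tokens only
    return tf.exp(mean_ce)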
Example #10
0
def train():
    # args is a global variable in this task
    BATCH_SIZE = args.batch_size
    EPOCH = args.epoch
    EMBED_DIM = args.embeddingDim
    MAXLEN = args.maxLen
    NUM_UNITS = args.units
    LEARNING_RATE = args.learning_rate
    DROPOUT = args.dropout
    METHOD = args.method
    GPUNUM = args.gpuNum
    CKPT = args.checkpoint
    LIMIT = args.limit
    start_word = "<s>"
    end_word = "</s>"
    # The tokenizer holds everything needed to split the data;
    # it is not itself part of the data.
    train_source_tensor, train_source_tokenizer, train_target_tensor, train_target_tokenizer = \
        load_data(pad_length = MAXLEN, limit=LIMIT)
    buffer_size = len(train_source_tensor)
    train_source_tensor, val_source_tensor, train_target_tensor, val_target_tensor = \
        train_test_split(train_source_tensor, train_target_tensor, random_state=2019)

    #TODO: check if we need target tokenizer
    training_steps = len(train_source_tensor) // BATCH_SIZE
    vocab_source_size = len(train_source_tokenizer.word_index) + 1
    print("vocab_input_size: ", vocab_source_size)
    vocab_target_size = len(train_target_tokenizer.word_index) + 1
    print("vocab_target_size: ", vocab_target_size)

    step = tf.Variable(0, trainable=False)
    # boundaries = [100, 200]
    # values = [1.0, 0.5, 0.1]
    # boundaries = [30, 40]
    # values = [1.0, 0.5, 0.0]
    # learning_rate_fn = tf.compat.v1.train.piecewise_constant(step,
    #     boundaries, values)
    # optimizer = tf.optimizers.SGD(learning_rate=learning_rate_fn(step))
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=0.001)
    # set up checkpoint
    if not os.path.exists(CKPT):
        os.makedirs(CKPT)
    else:
        print(
            "Warning: the checkpoint dir already exists!",
            "\nPlease consider choosing a new dir to save your checkpoints!")
    # note: only the optimizer is tracked here; the encoder/decoder (created
    # below) are not part of this checkpoint, so their weights are not saved
    checkpoint = tf.train.Checkpoint(optimizer=optimizer)
    checkpoint_prefix = os.path.join(CKPT, "ckpt")

    dataset = train_input_fn(train_source_tensor, train_target_tensor,
                             buffer_size, EPOCH, BATCH_SIZE)
    apply_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    encoder = Encoder(vocab_source_size,
                      EMBED_DIM,
                      NUM_UNITS,
                      dropout_rate=DROPOUT,
                      batch_size=BATCH_SIZE)
    decoder = Decoder(vocab_target_size,
                      EMBED_DIM,
                      NUM_UNITS,
                      batch_size=BATCH_SIZE,
                      method=None,
                      dropout_rate=DROPOUT)

    def train_wrapper(source, target):
        # with tf.GradientTape(watch_accessed_variables=False) as tape:
        with tf.GradientTape() as tape:
            # source_out, source_state, source_trainable_var, tape = encoder(source, encoder_state, vocab_source_size,
            #                                                          EMBED_DIM, NUM_UNITS, activation="tanh",
            #                                                          dropout_rate = DROPOUT)
            source_out, source_state = encoder(source,
                                               encoder_state,
                                               activation="tanh")

            initial = tf.expand_dims(
                [train_target_tokenizer.word_index[start_word]] * BATCH_SIZE,
                1)
            attention_state = tf.zeros((BATCH_SIZE, 1, EMBED_DIM))
            # cur_total_loss is a sum of loss for current steps, namely batch loss
            cur_total_loss, cur_loss = 0, 0
            for i in range(1, target.shape[1]):
                output_state, source_state, attention_state = decoder(
                    initial, source_state, source_out, attention_state)
                # TODO: check for the case where target is 0
                cur_loss = apply_loss(target[:, i], output_state)
                # 0 is the padding value in target; real targets should not
                # contain 0, but for safety the padded positions are masked
                # out of the final loss. The mask is a binary (0/1) array.
                mask = tf.math.logical_not(tf.math.equal(target[:, i], 0))
                mask = tf.cast(mask, dtype=cur_loss.dtype)
                cur_loss *= mask
                cur_total_loss += tf.reduce_mean(cur_loss)
                initial = tf.expand_dims(target[:, i], 1)
                # print(cur_loss)
                # print(cur_total_loss)
        batch_loss = cur_total_loss / target.shape[1]
        ## debug
        variables = encoder.trainable_variables + decoder.trainable_variables
        # print("check variable: ", len(variables))
        #variables = encoder.trainable_variables
        # print("check var:", len(variables), variables[12:])
        gradients = tape.gradient(cur_total_loss, variables)
        # print("check gradient: ", len(gradients))
        # g_e = [type(ele) for ele in gradients if not isinstance(ele, tf.IndexedSlices)]
        # sum_g = [ele.numpy().sum() for ele in gradients if not isinstance(ele, tf.IndexedSlices)]

        # print(len(gradients), len(sum_g))
        optimizer.apply_gradients(zip(gradients, variables), global_step=step)
        return batch_loss

    # print(len(train_source_tensor),BATCH_SIZE,training_steps,LIMIT)
    for epoch in range(EPOCH):
        per_epoch_loss = 0
        start = time.time()
        encoder_hidden = encoder.initialize_hidden_state()
        encoder_cell = encoder.initialize_cell_state()
        encoder_state = [[encoder_hidden, encoder_cell],
                         [encoder_hidden, encoder_cell],
                         [encoder_hidden, encoder_cell],
                         [encoder_hidden, encoder_cell]]
        # TODO : Double check to make sure all re-initialization is performed
        for idx, data in enumerate(dataset.take(training_steps)):

            source, target = data
            cur_total_loss = train_wrapper(source, target)
            per_epoch_loss += cur_total_loss
            if idx % 10 == 0:
                # print("current step is: "+str(tf.compat.v1.train.get_global_step()))
                # print(dir(optimizer))
                print("current learning rate is:" +
                      str(optimizer._learning_rate))
                print('Epoch {}/{} Batch {}/{} Loss {:.4f}'.format(
                    epoch + 1, EPOCH, idx, training_steps,
                    cur_total_loss.numpy()))
                # tf.print(step)
                # print(dir(step))
                # print(int(step))
            # note: once the global step reaches 5 this halves the learning
            # rate on every subsequent batch, which decays it very quickly
            if step >= 5:
                optimizer._learning_rate /= 2.0

        print('Epoch {}/{} Total Loss per epoch {:.4f} - {} sec'.format(
            epoch + 1, EPOCH, per_epoch_loss / training_steps,
            time.time() - start))
        # TODO: for evaluation add bleu score
        if epoch % 10 == 0:
            print('Saving checkpoint (every 10 epochs)')
            checkpoint.save(file_prefix=checkpoint_prefix)
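
One last detail: the tf.train.Checkpoint created near the top of this train() tracks only the optimizer, so the saved checkpoints would not contain the encoder or decoder weights. A minimal sketch of the more usual setup, with a toy model standing in for the encoder/decoder (hypothetical names, not the example's code):

import tensorflow as tf

# Track the model(s) and the optimizer together so that restoring a checkpoint
# restores the weights as well as the optimizer state.
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
optimizer = tf.keras.optimizers.Adam()
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(ckpt, './tmp_ckpt', max_to_keep=3)
manager.save()                            # write a checkpoint during training
ckpt.restore(manager.latest_checkpoint)   # later: restore the newest checkpoint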