Code example #1
File: train_mod.py  Project: Cli98/attention-modular
def train():
    # args is a global variable in this task
    BATCH_SIZE = args.batch_size
    EPOCH = args.epoch
    EMBED_DIM = args.embeddingDim
    MAXLEN = args.maxLen
    NUM_UNITS = args.units
    LEARNING_RATE = args.learning_rate
    DROPOUT = args.dropout
    METHOD = args.method
    GPUNUM = args.gpuNum
    CKPT = args.checkpoint
    LIMIT = args.limit
    start_word = "<s>"
    end_word = "</s>"
    # The tokenizer only stores the information needed to encode the data;
    # it is not itself part of the data.
    source_vocabulary_data = load_vocabulary_data(filepath=r'./data/vocab.50K.en', limit=None)
    target_vocabulary_data = load_vocabulary_data(filepath=r'./data/vocab.50K.de', limit=None)
    buffer_size = BATCH_SIZE*100

    #train_source_tensor, train_source_tokenizer, train_target_tensor, train_target_tokenizer = \
    #    load_data(pad_length = MAXLEN, limit=LIMIT)

    #train_source_tensor, val_source_tensor, train_target_tensor, val_target_tensor = \
    #    train_test_split(train_source_tensor, train_target_tensor, random_state=2019)

    vocab_source_size = len(source_vocabulary_data)
    print("vocab_input_size: ",vocab_source_size)
    vocab_target_size = len(target_vocabulary_data)
    print("vocab_target_size: ", vocab_target_size)

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.001)
    # dataset = generator_input_fn(train_source_tensor, train_target_tensor, buffer_size, EPOCH, BATCH_SIZE, MAXLEN)
    apply_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    encoder = Encoder(vocab_source_size, EMBED_DIM, NUM_UNITS,
                      dropout_rate=DROPOUT, batch_size=BATCH_SIZE)
    decoder = Decoder(vocab_target_size, EMBED_DIM, NUM_UNITS,
                      batch_size=BATCH_SIZE, method=None, dropout_rate=DROPOUT)
    # set up checkpoint
    if not os.path.exists(CKPT):
        os.makedirs(CKPT)
    else:
        print("Warning: the checkpoint dir already exists!",
              "\nConsider choosing a new dir for your checkpoints.")
    checkpoint = tf.train.Checkpoint(optimizer = optimizer, encoder=encoder,
                                     decoder=decoder)
    checkpoint_prefix = os.path.join(CKPT, "ckpt")


    def train_wrapper(source, target, train_target_tokenizer):
        # with tf.GradientTape(watch_accessed_variables=False) as tape:
        with tf.GradientTape() as tape:
            # source_out, source_state, source_trainable_var, tape = encoder(source, encoder_state, vocab_source_size,
            #                                                          EMBED_DIM, NUM_UNITS, activation="tanh",
            #                                                          dropout_rate = DROPOUT)
            source_out, source_state = encoder(source, encoder_state, activation="tanh")

            initial = tf.expand_dims([train_target_tokenizer.word_index[start_word]] * BATCH_SIZE, 1)
            attention_state = tf.zeros((BATCH_SIZE, 1, EMBED_DIM))
            # cur_total_loss accumulates the per-step loss over the target sequence, i.e. the batch loss
            cur_total_loss, cur_total_plex, cur_loss = 0, 0, 0
            for i in range(1, target.shape[1]):
                output_state, source_state, attention_state = decoder(initial, source_state, source_out, attention_state)
                # TODO: check for the case where target is 0
                cur_loss = apply_loss(target[:, i], output_state)
                perplex = tf.nn.sparse_softmax_cross_entropy_with_logits(target[:, i], output_state)
                # 0 is the padding id in target; padded positions should not
                # contribute to the loss, so mask them out
                # (mask holds 1 for real tokens, 0 for padding).
                mask = tf.math.logical_not(tf.math.equal(target[:, i], 0))
                mask = tf.cast(mask, dtype=cur_loss.dtype)
                cur_loss *= mask
                perplex *= mask
                cur_total_loss += tf.reduce_mean(cur_loss)
                cur_total_plex += tf.reduce_mean(perplex)
                initial = tf.expand_dims(target[:, i], 1)
                # print(cur_loss)
                # print(cur_total_loss)
        batch_loss = cur_total_loss / target.shape[1]
        batch_perplex = cur_total_plex / target.shape[1]
        ## debug
        variables = encoder.trainable_variables + decoder.trainable_variables
        # print("check variable: ", len(variables))
        #variables = encoder.trainable_variables
        # print("check var:", len(variables), variables[12:])
        gradients = tape.gradient(cur_total_loss, variables)
        # print("check gradient: ", len(gradients))
        # g_e = [type(ele) for ele in gradients if not isinstance(ele, tf.IndexedSlices)]
        # sum_g = [ele.numpy().sum() for ele in gradients if not isinstance(ele, tf.IndexedSlices)]

        # print(len(gradients), len(sum_g))
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5)
        optimizer.apply_gradients(zip(clipped_gradients, variables))
        return batch_loss, batch_perplex
    # print(len(train_source_tensor),BATCH_SIZE,training_steps,LIMIT)
    for epoch in range(EPOCH):
        per_epoch_loss, per_epoch_plex = 0, 0
        start = time.time()
        encoder_hidden = encoder.initialize_hidden_state()
        encoder_ceil = encoder.initialize_cell_state()
        encoder_state = [[encoder_hidden, encoder_ceil], [encoder_hidden, encoder_ceil],
                         [encoder_hidden, encoder_ceil], [encoder_hidden, encoder_ceil]]
        # TODO : Double check to make sure all re-initialization is performed
        gen = nmt_generator_from_data(source_data_path=r"./data/train.en",
                                      target_data_path=r"./data/train.de",
                                      source_vocabulary_data=source_vocabulary_data,
                                      target_vocabulary_data=target_vocabulary_data,
                                      pad_length=90, batch_size=BATCH_SIZE,
                                      limit=args.limit)
        for idx, data in enumerate(gen):
            train_source_tensor, train_source_tokenizer, train_target_tensor, train_target_tokenizer = data
            source, target = tf.convert_to_tensor(train_source_tensor), tf.convert_to_tensor(train_target_tensor)
            #print(type(source))
            #print("The shape of source is: ",source.shape,train_source_tensor.shape)
            #print(source[0])
            #print(source[1])
            cur_total_loss, batch_perplex = train_wrapper(source, target, train_target_tokenizer)
            per_epoch_loss += cur_total_loss
            per_epoch_plex += tf.exp(batch_perplex).numpy()
            if idx % 10 == 0:
                # print("current step is: "+str(tf.compat.v1.train.get_global_step()))
                # print(dir(optimizer))
                with open("checkpoint/logger.txt", "a") as filelogger:
                    print("current learning rate is: " + str(optimizer._learning_rate),
                          file=filelogger)
                    print('Epoch {}/{} Batch {} Loss {:.4f} perplex {:.4f}'.format(
                              epoch + 1, EPOCH, idx + 10,
                              cur_total_loss.numpy(),
                              tf.exp(batch_perplex).numpy()),
                          file=filelogger)
                # tf.print(step)
                # print(dir(step))
                # print(int(step))

        with open("checkpoint/logger.txt", "a") as filelogger:
            print('Epoch {}/{} Total Loss per epoch {:.4f} - {} sec'.format(
                      epoch + 1, EPOCH,
                      float(per_epoch_loss) / (idx + 1.0),
                      time.time() - start),
                  file=filelogger)
            print("Epoch perplex: ", str(per_epoch_plex), file=filelogger)
        # TODO: for evaluation add bleu score
        if epoch % 1 == 0:
            print('Saving checkpoint (every epoch)')
            checkpoint.save(file_prefix=checkpoint_prefix)

        if epoch == 3:
            optimizer._learning_rate = LEARNING_RATE / 10.0
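
A note on the masking idiom inside train_wrapper above: padded positions (id 0) must not contribute to the loss. Below is a minimal, self-contained sketch of that trick; the toy labels, the vocabulary size of 10, and the PAD id of 0 are illustrative assumptions, not values taken from the project.

import tensorflow as tf

# Hypothetical labels for one decoding step; the last position is padding (0).
labels = tf.constant([4, 7, 0])
logits = tf.random.normal((3, 10))  # batch of 3, toy vocabulary of 10

# reduction='none' keeps one loss value per token (as in code example #4),
# so the mask can zero out exactly the padded positions.
loss_fn = tf.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                  reduction='none')
per_token_loss = loss_fn(labels, logits)                     # shape (3,)

mask = tf.cast(tf.math.logical_not(tf.math.equal(labels, 0)),
               dtype=per_token_loss.dtype)                   # 1 = real token, 0 = PAD
masked_mean = tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)
print(float(masked_mean))

Note that train_wrapper above calls SparseCategoricalCrossentropy with its default reduction, which averages over the batch before the mask is applied; keeping reduction='none' (as code example #4 does) lets the mask act on individual tokens.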
Code example #2
def test(args: Namespace):
    cfg = json.load(open(args.config_path, 'r', encoding='UTF-8'))

    batch_size = 1  # for predicting one sentence.

    encoder = Encoder(cfg['vocab_input_size'], cfg['embedding_dim'],
                      cfg['units'], batch_size, 0)
    decoder = Decoder(cfg['vocab_target_size'], cfg['embedding_dim'],
                      cfg['units'], cfg['method'], batch_size, 0)
    optimizer = select_optimizer(cfg['optimizer'], cfg['learning_rate'])

    ckpt = tf.train.Checkpoint(optimizer=optimizer,
                               encoder=encoder,
                               decoder=decoder)
    manager = tf.train.CheckpointManager(ckpt,
                                         cfg['checkpoint_dir'],
                                         max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)

    while True:
        sentence = input(
            'Enter a sentence (or press Enter on an empty line to quit): ')

        if sentence == '': break

        sentence = re.sub(r"(\.\.\.|[?.!,¿])", r" \1 ", sentence)
        sentence = re.sub(r'[" "]+', " ", sentence)

        sentence = '<s> ' + sentence.lower().strip() + ' </s>'

        input_vocab = load_vocab('./data/', 'en')
        target_vocab = load_vocab('./data/', 'de')

        input_lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
            filters='', oov_token='<unk>')
        input_lang_tokenizer.word_index = input_vocab

        target_lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
            filters='', oov_token='<unk>')
        target_lang_tokenizer.word_index = target_vocab

        convert_vocab(input_lang_tokenizer, input_vocab)
        convert_vocab(target_lang_tokenizer, target_vocab)

        inputs = [
            input_lang_tokenizer.word_index[i]
            if i in input_lang_tokenizer.word_index else
            input_lang_tokenizer.word_index['<unk>']
            for i in sentence.split(' ')
        ]
        inputs = tf.keras.preprocessing.sequence.pad_sequences(
            [inputs], maxlen=cfg['max_len_input'], padding='post')

        inputs = tf.convert_to_tensor(inputs)

        result = ''

        enc_hidden = encoder.initialize_hidden_state()
        enc_cell = encoder.initialize_cell_state()
        enc_state = [[enc_hidden, enc_cell], [enc_hidden, enc_cell],
                     [enc_hidden, enc_cell], [enc_hidden, enc_cell]]

        enc_output, enc_hidden = encoder(inputs, enc_state)

        dec_hidden = enc_hidden
        #dec_input = tf.expand_dims([target_lang_tokenizer.word_index['<eos>']], 0)
        dec_input = tf.expand_dims([target_lang_tokenizer.word_index['<s>']],
                                   1)

        print('dec_input:', dec_input)

        h_t = tf.zeros((batch_size, 1, cfg['embedding_dim']))

        for t in range(int(cfg['max_len_target'])):
            predictions, dec_hidden, h_t = decoder(dec_input, dec_hidden,
                                                   enc_output, h_t)

            # predictions shape == (1, 50002)

            predicted_id = tf.argmax(predictions[0]).numpy()
            print('predicted_id', predicted_id)

            result += target_lang_tokenizer.index_word[predicted_id] + ' '

            if target_lang_tokenizer.index_word[predicted_id] == '</s>':
                print('Early stopping')
                break

            dec_input = tf.expand_dims([predicted_id], 1)
            print('dec_input:', dec_input)

        print('<s> ' + result)
        print(sentence)
        sys.stdout.flush()
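
One caveat about the restore step above: tf.train.CheckpointManager.latest_checkpoint is None when the directory holds no checkpoints, and ckpt.restore(None) does not raise, so a missing checkpoint is silently ignored. A small defensive sketch, assuming a hypothetical checkpoint directory and a dummy tracked variable in place of the real optimizer/encoder/decoder:

import tensorflow as tf

ckpt_dir = './training_checkpoints'                 # hypothetical path
ckpt = tf.train.Checkpoint(step=tf.Variable(0))     # stand-in for optimizer/encoder/decoder
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

if manager.latest_checkpoint:
    status = ckpt.restore(manager.latest_checkpoint)
    # Raises if the saved objects do not line up with the tracked objects.
    status.assert_existing_objects_matched()
    print('Restored from', manager.latest_checkpoint)
else:
    print('No checkpoint found in', ckpt_dir)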
Code example #3
File: test.py  Project: Cli98/attention-modular
def test():
    # Dropout rate is set to 0 for the test setup.
    BATCH_SIZE = args.batch_size
    EMBED_DIM = args.embeddingDim
    MAXLEN = args.maxLen
    NUM_UNITS = args.units
    CKPT = args.checkpoint
    LEARNING_RATE = args.learning_rate
    EPOCH = 1
    DROPOUT = 0
    start_word = "<s>"
    end_word = "</s>"

    test_source_tensor, test_source_tokenizer, test_target_tensor, test_target_tokenizer = \
        load_test_data(test_translate_from=r"./data/newstest2015.en",
                       test_translate_to=r"./data/newstest2015.de",
                       vocab_from=r"./data/vocab.50K.en",
                       vocab_to=r"./data/vocab.50K.de",
                       pad_length=90, limit=args.limit)
    #test_source_tensor, test_source_tokenizer, test_target_tensor, test_target_tokenizer = \
    #    load_data(pad_length = MAXLEN, limit=None)
    print(len(test_source_tensor))
    vocab_source_size = len(test_source_tokenizer.word_index) + 1
    print("vocab_input_size: ", vocab_source_size)
    vocab_target_size = len(test_target_tokenizer.word_index) + 1
    print("vocab_target_size: ", vocab_target_size)
    buffer_size = len(test_source_tensor)

    test_steps = len(test_source_tensor) // BATCH_SIZE
    #print(test_source_tensor[0])
    #print("Type: ",type(test_source_tensor))
    #for ele in test_target_tensor:
    #    print(type(ele))
    #print("dir of test_target__token: ",dir(test_target_tokenizer))
    #print(test_target_tokenizer.index_word)
    dataset = train_input_fn(test_source_tensor, test_target_tensor,
                             buffer_size, EPOCH, BATCH_SIZE)

    # apply_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    encoder = Encoder(vocab_source_size,
                      EMBED_DIM,
                      NUM_UNITS,
                      dropout_rate=DROPOUT,
                      batch_size=BATCH_SIZE)
    decoder = Decoder(vocab_target_size,
                      EMBED_DIM,
                      NUM_UNITS,
                      batch_size=BATCH_SIZE,
                      method=None,
                      dropout_rate=DROPOUT)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=0.001)
    ckpt = tf.train.Checkpoint(optimizer=optimizer,
                               encoder=encoder,
                               decoder=decoder)
    manager = tf.train.CheckpointManager(ckpt, args.checkpoint, max_to_keep=10)
    ckpt.restore(manager.latest_checkpoint)
    per_epoch_loss, per_epoch_plex = 0, 0

    def test_wrapper(source, target):
        # source_out, source_state, source_trainable_var, tape = encoder(source, encoder_state, vocab_source_size,
        #                                                          EMBED_DIM, NUM_UNITS, activation="tanh",
        #                                                          dropout_rate = DROPOUT)
        result = ""
        source_out, source_state = encoder(source,
                                           encoder_state,
                                           activation="tanh")

        initial = tf.expand_dims(
            [test_target_tokenizer.word_index[start_word]] * BATCH_SIZE, 1)
        attention_state = tf.zeros((BATCH_SIZE, 1, EMBED_DIM))
        apply_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
        # cur_total_loss accumulates the per-step loss over the target sequence, i.e. the batch loss
        cur_total_loss, cur_total_plex, cur_loss = 0, 0, 0
        #print("word_index: ", test_target_tokenizer.word_index,len(test_target_tokenizer.word_index))
        #print("index_word: ",test_target_tokenizer.index_word)
        for i in range(target.shape[1]):
            output_state, source_state, attention_state = decoder(
                initial, source_state, source_out, attention_state)
            # TODO: check for the case where target is 0
            # 0 is the padding id in target; padded positions should not
            # contribute to the loss, so mask them out
            # (mask holds 1 for real tokens, 0 for padding).
            #print(output_state.numpy().shape)
            #print(output_state[0].numpy())
            cur_loss = apply_loss(target[:, i], output_state)
            perplex = tf.nn.sparse_softmax_cross_entropy_with_logits(
                target[:, i], output_state)
            current_ind = tf.argmax(output_state[0])
            mask = tf.math.logical_not(tf.math.equal(target[:, i], 0))
            mask = tf.cast(mask, dtype=cur_loss.dtype)
            cur_loss *= mask
            perplex *= mask
            cur_total_loss += tf.reduce_mean(cur_loss)
            cur_total_plex += tf.reduce_mean(perplex)
            #tf.print("check current id: ",current_ind)
            #tf.print(test_target_tokenizer.index_word[29])
            #print(current_ind.numpy())
            if current_ind.numpy() == 0:
                # 0 is the padding id; do not add it to the decoded output.
                continue
            result += test_target_tokenizer.index_word[current_ind.numpy()] + " "
            if test_target_tokenizer.index_word[
                    current_ind.numpy()] == end_word:
                break
            initial = tf.expand_dims(target[:, i], 1)
        batch_loss = cur_total_loss / target.shape[1]
        batch_perplex = cur_total_plex / target.shape[1]
        return batch_loss, batch_perplex, result

    encoder_hidden = encoder.initialize_hidden_state()
    encoder_ceil = encoder.initialize_cell_state()
    encoder_state = [[encoder_hidden, encoder_ceil],
                     [encoder_hidden, encoder_ceil],
                     [encoder_hidden, encoder_ceil],
                     [encoder_hidden, encoder_ceil]]
    # TODO : Double check to make sure all re-initialization is performed
    result_by_batch = []
    for idx, data in tqdm(enumerate(dataset.take(test_steps)),
                          total=args.limit):
        source, target = data
        batch_loss, batch_perplex, result = test_wrapper(source, target)
        with open("checkpoint/test_logger.txt", "a") as filelogger:
            print("The validation loss in batch " + str(idx) + " is : ",
                  str(batch_loss.numpy()),
                  file=filelogger)
            print("The validation perplex in batch " + str(idx) + " is : ",
                  str(batch_perplex.numpy()),
                  file=filelogger)
        per_epoch_loss += batch_loss
        per_epoch_plex += batch_perplex
        assert isinstance(result, str)
        result_by_batch.append(result)
        #if idx>=3:
        #    break
    with open("checkpoint/test_logger.txt", "a") as filelogger:
        print("The validation loss is : ",
              str(per_epoch_loss.numpy() / (idx + 1.0)),
              file=filelogger)
        print("The validation perplex is: ",
              str(tf.exp(per_epoch_plex / (idx + 1.0)).numpy()),
              file=filelogger)
    return test_target_tokenizer, result_by_batch
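
On the perplexity bookkeeping: perplexity is the exponential of the average per-token cross-entropy, so the exponentiation has to happen after the cross-entropy values are averaged (as in the corrected final log above), not before. A tiny numeric sketch with made-up cross-entropy values:

import tensorflow as tf

# Hypothetical mean cross-entropies (nats per token) for three batches.
batch_ce = tf.constant([3.2, 2.9, 3.1])

perplexity = tf.exp(tf.reduce_mean(batch_ce))   # exp of the mean ...
wrong = tf.reduce_mean(tf.exp(batch_ce))        # ... not the mean of the exps
print(float(perplexity), float(wrong))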
Code example #4
def train(args: Namespace):
    input_tensor, target_tensor, input_lang_tokenizer, target_lang_tokenizer = load_dataset(
        './data/', args.max_len, limit_size=None)

    max_len_input = len(input_tensor[0])
    max_len_target = len(target_tensor[0])

    print('max len of each seq:', max_len_input, ',', max_len_target)

    input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(
        input_tensor, target_tensor, test_size=args.dev_split)

    # init hyperparameter
    EPOCHS = args.epoch
    batch_size = args.batch_size
    steps_per_epoch = len(input_tensor_train) // batch_size
    embedding_dim = args.embedding_dim
    units = args.units
    vocab_input_size = len(input_lang_tokenizer.word_index) + 1
    vocab_target_size = len(target_lang_tokenizer.word_index) + 1
    BUFFER_SIZE = len(input_tensor_train)
    learning_rate = args.learning_rate

    setattr(args, 'max_len_input', max_len_input)
    setattr(args, 'max_len_target', max_len_target)

    setattr(args, 'steps_per_epoch', steps_per_epoch)
    setattr(args, 'vocab_input_size', vocab_input_size)
    setattr(args, 'vocab_target_size', vocab_target_size)
    setattr(args, 'BUFFER_SIZE', BUFFER_SIZE)

    dataset = tf.data.Dataset.from_tensor_slices(
        (input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)
    dataset = dataset.batch(batch_size)

    print('dataset shape (batch_size, max_len):', dataset)

    encoder = Encoder(vocab_input_size, embedding_dim, units, batch_size,
                      args.dropout)
    decoder = Decoder(vocab_target_size, embedding_dim, units, args.method,
                      batch_size, args.dropout)

    optimizer = select_optimizer(args.optimizer, args.learning_rate)

    loss_object = tf.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                          reduction='none')

    @tf.function
    def train_step(_input, _target, enc_state):
        loss = 0

        with tf.GradientTape() as tape:
            enc_output, enc_state = encoder(_input, enc_state)

            dec_hidden = enc_state

            dec_input = tf.expand_dims(
                [target_lang_tokenizer.word_index['<s>']] * batch_size, 1)

            # First input feeding definition
            h_t = tf.zeros((batch_size, 1, embedding_dim))

            for idx in range(1, _target.shape[1]):
                # idx means target character index.
                predictions, dec_hidden, h_t = decoder(dec_input, dec_hidden,
                                                       enc_output, h_t)

                #tf.print(tf.argmax(predictions, axis=1))

                loss += loss_function(loss_object, _target[:, idx],
                                      predictions)

                dec_input = tf.expand_dims(_target[:, idx], 1)

        batch_loss = (loss / int(_target.shape[1]))

        variables = encoder.trainable_variables + decoder.trainable_variables

        gradients = tape.gradient(loss, variables)

        optimizer.apply_gradients(zip(gradients, variables))

        return batch_loss

    # Setting checkpoint
    now_time = dt.datetime.now().strftime("%m%d%H%M")
    checkpoint_dir = './training_checkpoints/' + now_time
    setattr(args, 'checkpoint_dir', checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder)

    os.makedirs(checkpoint_dir, exist_ok=True)

    # saving information of the model
    with open('{}/config.json'.format(checkpoint_dir), 'w',
              encoding='UTF-8') as fout:
        json.dump(vars(args), fout, indent=2, sort_keys=True)

    min_total_loss = 1000

    for epoch in range(EPOCHS):
        start = time.time()

        enc_hidden = encoder.initialize_hidden_state()
        enc_cell = encoder.initialize_cell_state()
        enc_state = [[enc_hidden, enc_cell], [enc_hidden, enc_cell],
                     [enc_hidden, enc_cell], [enc_hidden, enc_cell]]

        total_loss = 0

        for (batch, (_input,
                     _target)) in enumerate(dataset.take(steps_per_epoch)):
            batch_loss = train_step(_input, _target, enc_state)
            total_loss += batch_loss

            if batch % 10 == 0:
                print('Epoch {}/{} Batch {}/{} Loss {:.4f}'.format(
                    epoch + 1, EPOCHS, batch + 10, steps_per_epoch,
                    batch_loss.numpy()))

        print('Epoch {}/{} Total Loss per epoch {:.4f} - {} sec'.format(
            epoch + 1, EPOCHS, float(total_loss) / steps_per_epoch,
            time.time() - start))

        # saving checkpoint
        if min_total_loss > total_loss / steps_per_epoch:
            print('Saving checkpoint...')
            min_total_loss = total_loss / steps_per_epoch
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('\n')
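
The decoder loop in train_step above uses teacher forcing: whatever the decoder predicts at step t, the ground-truth token at step t is fed back as the input for step t+1. The sketch below strips that control flow down to its essentials; decoder_step, the token ids and the vocabulary size of 100 are stand-ins for illustration, not the project's Decoder.

import tensorflow as tf

def decoder_step(token_ids, state):
    """Stand-in for one decoder step: returns logits over a toy vocabulary and the new state."""
    batch = tf.shape(token_ids)[0]
    return tf.random.normal((batch, 100)), state

target = tf.constant([[1, 5, 9, 2, 0]])        # hypothetical ids: <s> w w </s> PAD
state = tf.zeros((1, 8))
dec_input = tf.expand_dims(target[:, 0], 1)    # start with the <s> token

loss = 0.0
for t in range(1, target.shape[1]):
    logits, state = decoder_step(dec_input, state)
    loss += tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target[:, t],
                                                       logits=logits))
    # Teacher forcing: feed the gold token, not the prediction.
    dec_input = tf.expand_dims(target[:, t], 1)

print(float(loss) / (target.shape[1] - 1))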
Code example #5
File: train.py  Project: Cli98/attention-modular
def train():
    # args is a global variable in this task
    BATCH_SIZE = args.batch_size
    EPOCH = args.epoch
    EMBED_DIM = args.embeddingDim
    MAXLEN = args.maxLen
    NUM_UNITS = args.units
    LEARNING_RATE = args.learning_rate
    DROPOUT = args.dropout
    METHOD = args.method
    GPUNUM = args.gpuNum
    CKPT = args.checkpoint
    LIMIT = args.limit
    start_word = "<s>"
    end_word = "</s>"
    # The tokenizer only stores the information needed to encode the data;
    # it is not itself part of the data.
    train_source_tensor, train_source_tokenizer, train_target_tensor, train_target_tokenizer = \
        load_data(pad_length = MAXLEN, limit=LIMIT)
    buffer_size = len(train_source_tensor)
    train_source_tensor, val_source_tensor, train_target_tensor, val_target_tensor = \
        train_test_split(train_source_tensor, train_target_tensor, random_state=2019)

    #TODO: check if we need target tokenizer
    training_steps = len(train_source_tensor) // BATCH_SIZE
    vocab_source_size = len(train_source_tokenizer.word_index) + 1
    print("vocab_input_size: ", vocab_source_size)
    vocab_target_size = len(train_target_tokenizer.word_index) + 1
    print("vocab_target_size: ", vocab_target_size)

    step = tf.Variable(0, trainable=False)
    # boundaries = [100, 200]
    # values = [1.0, 0.5, 0.1]
    # boundaries = [30, 40]
    # values = [1.0, 0.5, 0.0]
    # learning_rate_fn = tf.compat.v1.train.piecewise_constant(step,
    #     boundaries, values)
    # optimizer = tf.optimizers.SGD(learning_rate=learning_rate_fn(step))
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=0.001)
    # set up checkpoint
    if not os.path.exists(CKPT):
        os.makedirs(CKPT)
    else:
        print(
            "Warning: the checkpoint dir already exists!",
            "\nConsider choosing a new dir for your checkpoints.")
    # NOTE: only the optimizer is tracked here; the encoder and decoder created
    # below are not saved by this checkpoint.
    checkpoint = tf.train.Checkpoint(optimizer=optimizer)
    checkpoint_prefix = os.path.join(CKPT, "ckpt")

    dataset = train_input_fn(train_source_tensor, train_target_tensor,
                             buffer_size, EPOCH, BATCH_SIZE)
    apply_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    encoder = Encoder(vocab_source_size,
                      EMBED_DIM,
                      NUM_UNITS,
                      dropout_rate=DROPOUT,
                      batch_size=BATCH_SIZE)
    decoder = Decoder(vocab_target_size,
                      EMBED_DIM,
                      NUM_UNITS,
                      batch_size=BATCH_SIZE,
                      method=None,
                      dropout_rate=DROPOUT)

    def train_wrapper(source, target):
        # with tf.GradientTape(watch_accessed_variables=False) as tape:
        with tf.GradientTape() as tape:
            # source_out, source_state, source_trainable_var, tape = encoder(source, encoder_state, vocab_source_size,
            #                                                          EMBED_DIM, NUM_UNITS, activation="tanh",
            #                                                          dropout_rate = DROPOUT)
            source_out, source_state = encoder(source,
                                               encoder_state,
                                               activation="tanh")

            initial = tf.expand_dims(
                [train_target_tokenizer.word_index[start_word]] * BATCH_SIZE,
                1)
            attention_state = tf.zeros((BATCH_SIZE, 1, EMBED_DIM))
            # cur_total_loss accumulates the per-step loss over the target sequence, i.e. the batch loss
            cur_total_loss, cur_loss = 0, 0
            for i in range(1, target.shape[1]):
                output_state, source_state, attention_state = decoder(
                    initial, source_state, source_out, attention_state)
                # TODO: check for the case where target is 0
                cur_loss = apply_loss(target[:, i], output_state)
                # 0 is the padding id in target; padded positions should not
                # contribute to the loss, so mask them out
                # (mask holds 1 for real tokens, 0 for padding).
                mask = tf.math.logical_not(tf.math.equal(target[:, i], 0))
                mask = tf.cast(mask, dtype=cur_loss.dtype)
                cur_loss *= mask
                cur_total_loss += tf.reduce_mean(cur_loss)
                initial = tf.expand_dims(target[:, i], 1)
                # print(cur_loss)
                # print(cur_total_loss)
        batch_loss = cur_total_loss / target.shape[1]
        ## debug
        variables = encoder.trainable_variables + decoder.trainable_variables
        # print("check variable: ", len(variables))
        #variables = encoder.trainable_variables
        # print("check var:", len(variables), variables[12:])
        gradients = tape.gradient(cur_total_loss, variables)
        # print("check gradient: ", len(gradients))
        # g_e = [type(ele) for ele in gradients if not isinstance(ele, tf.IndexedSlices)]
        # sum_g = [ele.numpy().sum() for ele in gradients if not isinstance(ele, tf.IndexedSlices)]

        # print(len(gradients), len(sum_g))
        optimizer.apply_gradients(zip(gradients, variables), global_step=step)
        return batch_loss

    # print(len(train_source_tensor),BATCH_SIZE,training_steps,LIMIT)
    for epoch in range(EPOCH):
        per_epoch_loss = 0
        start = time.time()
        encoder_hidden = encoder.initialize_hidden_state()
        encoder_ceil = encoder.initialize_cell_state()
        encoder_state = [[encoder_hidden, encoder_ceil],
                         [encoder_hidden, encoder_ceil],
                         [encoder_hidden, encoder_ceil],
                         [encoder_hidden, encoder_ceil]]
        # TODO : Double check to make sure all re-initialization is performed
        for idx, data in enumerate(dataset.take(training_steps)):

            source, target = data
            cur_total_loss = train_wrapper(source, target)
            per_epoch_loss += cur_total_loss
            if idx % 10 == 0:
                # print("current step is: "+str(tf.compat.v1.train.get_global_step()))
                # print(dir(optimizer))
                print("current learning rate is: " +
                      str(optimizer._learning_rate))
                print('Epoch {}/{} Batch {}/{} Loss {:.4f}'.format(
                    epoch + 1, EPOCH, idx + 10, training_steps,
                    cur_total_loss.numpy()))
                # tf.print(step)
                # print(dir(step))
                # print(int(step))
            if step >= 5:
                # NOTE: this halves the learning rate on every batch once the
                # global step reaches 5, so the rate decays very quickly.
                optimizer._learning_rate /= 2.0

        print('Epoch {}/{} Total Loss per epoch {:.4f} - {} sec'.format(
            epoch + 1, EPOCH, float(per_epoch_loss) / training_steps,
            time.time() - start))
        # TODO: for evaluation add bleu score
        if epoch % 10 == 0:
            print('Saving checkpoint every 10 epochs')
            checkpoint.save(file_prefix=checkpoint_prefix)
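
The learning-rate decay above (and the epoch-based variant in code example #1) reaches into the optimizer's private _learning_rate attribute, which happens to work for the v1 GradientDescentOptimizer. With a tf.keras optimizer the same idea can go through the public learning_rate attribute; a minimal sketch, with the decay point of epoch 3 and the halving factor chosen only for illustration:

import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

for epoch in range(6):
    # ... run the training steps for this epoch ...
    if epoch >= 3:
        # Halve the learning rate once the decay point is reached.
        optimizer.learning_rate = float(optimizer.learning_rate) * 0.5
    print('epoch', epoch, 'lr', float(optimizer.learning_rate))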