# Number of batches consumed per epoch (integer division, so a trailing
# partial batch is not counted).
steps_per_epoch = len(dataset_flow.input_tensor) // BATCH_SIZE
timer = Timer()
# Lowest mean epoch loss seen so far, i.e. the loss of the weights most
# recently written to disk. Starts at a large sentinel value.
previous_loss = 1e5

# Train for EPOCHS epochs, saving encoder/decoder weights whenever the mean
# epoch loss beats the best loss recorded so far.
for epoch in range(EPOCHS):
    timer.start()
    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    print('Something', end='\r')
    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
        # '\r' rewrites the same console line so the counter updates in place.
        print("Epoch {} batch [{}/{}]".format(epoch + 1, batch + 1,
                                              steps_per_epoch),
              end='\r',
              flush=True)
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

    # Convert the accumulated loss into the mean loss per batch.
    total_loss = total_loss / steps_per_epoch

    if total_loss < previous_loss:
        print_str = 'Loss improved from {:.4f} to {:.4f} Saving model to file...'.format(
            previous_loss, total_loss)
        print(print_str)
        encoder.save_weights('model_checkpoint/encoder.h5')
        decoder.save_weights('model_checkpoint/decoder.h5')
        # Only move the reference loss on an improvement. Updating it every
        # epoch would let a later, worse epoch (e.g. 1.0 -> 2.0 -> 1.5)
        # overwrite the best weights already saved.
        previous_loss = total_loss

    print('Epoch {} Loss {:.4f} : {:.0f} s'.format(epoch + 1, total_loss,
                                                   timer.stop()))
# Train for EPOCHS epochs, logging the batch loss every 100 batches and
# writing a checkpoint after every second epoch.
for epoch in range(EPOCHS):
    epoch_number = epoch + 1  # 1-based epoch index used for display and checkpoint cadence
    start = time.time()

    # Each epoch starts from a fresh encoder hidden state and a zeroed loss sum.
    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for batch, (inp, targ) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

        # Only report on every 100th batch (including batch 0).
        if batch % 100 != 0:
            continue
        progress_line = 'Epoch {} Batch {} Loss {:.4f}'.format(
            epoch_number, batch, batch_loss.numpy())
        print(progress_line)

    # saving (checkpoint) the model every 2 epochs
    if epoch_number % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

    # Mean loss per batch for this epoch.
    mean_loss = total_loss / steps_per_epoch
    print('Epoch {} Loss {:.4f}'.format(epoch_number, mean_loss))

    elapsed = time.time() - start
    print('Time taken for 1 epoch {} sec\n'.format(elapsed))

# saving encoder and decoder weights (TensorFlow checkpoint format)
for _network, _weights_prefix in ((encoder, './encoder'), (decoder, './decoder')):
    _network.save_weights(_weights_prefix, save_format='tf')

# Reload the most recent checkpoint found in checkpoint_dir.
# NOTE(review): if the latest checkpoint predates the final epoch, this may
# discard the most recent training step(s) - confirm that is intended.
latest_checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint_path)

# final call to translation model
image_to_txt()