t.display(f'reduction factor {reduction_factor}', pos=10)
    model.set_constants(decoder_prenet_dropout=decoder_prenet_dropout,
                        learning_rate=learning_rate,
                        reduction_factor=reduction_factor,
                        drop_n_heads=drop_n_heads)
    output = model.train_step(inp=phonemes, tar=mel, stop_prob=stop)
    losses.append(float(output['loss']))

    t.display(f'step loss: {losses[-1]}', pos=1)
    for pos, n_steps in enumerate(config['n_steps_avg_losses']):
        if len(losses) > n_steps:
            t.display(
                f'{n_steps}-steps average loss: {sum(losses[-n_steps:]) / n_steps}',
                pos=pos + 2)

    summary_manager.display_loss(output, tag='Train')
    summary_manager.display_scalar(tag='Meta/decoder_prenet_dropout',
                                   scalar_value=model.decoder_prenet.rate)
    summary_manager.display_scalar(tag='Meta/learning_rate',
                                   scalar_value=model.optimizer.lr)
    summary_manager.display_scalar(tag='Meta/reduction_factor',
                                   scalar_value=model.r)
    summary_manager.display_scalar(tag='Meta/drop_n_heads',
                                   scalar_value=model.drop_n_heads)
    if model.step % config['train_images_plotting_frequency'] == 0:
        summary_manager.display_attention_heads(output,
                                                tag='TrainAttentionHeads')
        summary_manager.display_mel(mel=output['mel_linear'][0],
                                    tag=f'Train/linear_mel_out')
        summary_manager.display_mel(mel=output['final_output'][0],
                                    tag=f'Train/predicted_mel')
def _write_file_if_missing(path, contents, error_message):
    """Create ``path`` holding ``contents`` unless a file already exists there.

    On an I/O failure, prints ``error_message % path`` and the traceback, then
    exits the process with status 1.
    """
    if os.path.isfile(path):
        return
    try:
        with open(path, 'w+') as f:
            f.write(contents)
    except OSError:
        print(error_message % path)
        traceback.print_exc()
        sys.exit(1)


def train():
    """Run the main training loop, checkpointing on validation-WER improvement.

    Steps performed:
      * build the model and compile it with Adam at ``FLAGS['learning_rate']``;
      * create the learning-rate backup file, the epoch-count file and the
        weights directory when missing (the process exits with status 1 if
        any of them cannot be created);
      * restore the latest checkpoint from ``FLAGS['weights_dir']`` if one
        exists, and reload the learning rate from
        ``FLAGS['learning_rate_path']``;
      * train from ``model.step`` up to ``FLAGS['max_steps']``. At every epoch
        boundary of the training dataset: log training loss and perplexity
        and validation WER/PER/loss; multiply the learning rate by 0.2 when
        validation PER plateaus; save a checkpoint when the moving average of
        validation WER improves on the best seen so far.
    """
    from collections import deque

    print_config()
    model = get_model(train=True)
    lr = FLAGS['learning_rate']
    model._compile(optimizer=tf.keras.optimizers.Adam(
        lr, beta_1=0.9, beta_2=0.98, epsilon=1e-9))

    _write_file_if_missing(FLAGS['learning_rate_path'],
                           str(FLAGS['learning_rate']),
                           "Could not create learning rate backup %s")
    # NOTE(review): the epoch file is only initialised here; nothing in this
    # function reads or updates it afterwards.
    _write_file_if_missing(FLAGS['epoch_path'], str(0),
                           "Could not create epoch count file %s")

    try:
        os.makedirs(FLAGS['weights_dir'], exist_ok=True)
    except OSError:
        print("Could not create weights folder %s that contains model "
              "directories for trained weights" % FLAGS['weights_dir'])
        traceback.print_exc()
        sys.exit(1)

    checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     optimizer=model.optimizer,
                                     net=model)
    manager = tf.train.CheckpointManager(
        checkpoint,
        str(FLAGS['weights_dir']),
        max_to_keep=FLAGS['keep_n_weights'],
        keep_checkpoint_every_n_hours=FLAGS['keep_checkpoint_every_n_hours'])
    summary_manager = SummaryManager(model=model,
                                     log_dir=FLAGS['train_dir'],
                                     config=FLAGS)
    checkpoint.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print(
            f'\nresuming training from step {model.step} ({manager.latest_checkpoint})'
        )
        try:
            with open(FLAGS['learning_rate_path'], 'r') as f:
                lr = float(f.readline())
        except (OSError, IndexError, ValueError):
            print("Could not load learning rate file %s" %
                  FLAGS['learning_rate_path'])
            traceback.print_exc()
            sys.exit(1)
        print("Successfully loaded learning rate.")
        # The backup file is rewritten on every decay, but checkpoints are only
        # saved on WER improvement, so the restored optimizer may hold an older
        # rate. Apply the file's value so training uses the same `lr` that
        # later decays are computed from.
        model.set_constants(learning_rate=lr)
    else:
        print('\nStarting training from scratch ...')
    if FLAGS['debug'] is True:
        print('\nWARNING: DEBUG is set to True. Training in eager mode.')

    print('\nTRAINING...')

    train_data = data_utils.read_and_bucket_data(
        os.path.join(FLAGS['data_dir'], "train.pkl"))
    dev_set = data_utils.read_and_bucket_data(
        os.path.join(FLAGS['data_dir'], "dev.pkl"))

    train_dataset = Dataset(train_data[0],
                            FLAGS['batch_size'],
                            isTraining=True,
                            bucket_id=None,
                            drop_remainder=True)

    # Moving window of the most recent validation WERs. The best model is
    # chosen on their average, which smooths out one-off lucky evaluations.
    window_size = 3
    val_wer_window = deque(maxlen=window_size)
    # If the 50-epoch PER average beats the 10-epoch average by less than
    # this, validation is treated as plateaued and the learning rate decays.
    validation_improvement_threshold = FLAGS['valid_thresh']
    val_losses = []  # per-epoch validation PER history
    loss = 0.0  # training loss summed since the last epoch boundary
    steps = 0  # training steps taken since the last epoch boundary

    if model.step > 0:
        # Resuming: seed the best WER with the restored model's dev-set WER so
        # a worse model cannot overwrite the checkpoint. Index rather than
        # unpack: further down the same helper's result has three fields
        # (wer, per, loss).
        best_wer = calc_levenshtein_loss(model, dev_set)[0]
    else:
        best_wer = 1.0

    epoch = train_dataset.epoch  # last epoch whose boundary was handled
    t = trange(model.step, FLAGS['max_steps'], leave=True)
    for _ in t:
        t.set_description(f'Step {model.step}')
        encoder_inputs, seq_len, decoder_inputs, seq_len_target = (
            train_dataset.next_batch())
        model_out = model.train_step(encoder_inputs, seq_len, decoder_inputs,
                                     seq_len_target)
        loss += model_out['loss']
        steps += 1
        t.display(f'epoch : {train_dataset.epoch}', pos=2)
        if model.step % FLAGS['train_images_plotting_frequency'] == 0:
            summary_manager.display_attention_heads(model_out,
                                                    tag='TrainAttentionHeads')
        # `>` rather than `== epoch + 1`: if the counter ever skipped a value,
        # an equality test would never fire again.
        if train_dataset.epoch > epoch:
            epoch = train_dataset.epoch
            loss /= steps  # mean training loss over the finished epoch
            summary_manager.display_scalar(tag='Meta/epoch',
                                           scalar_value=epoch,
                                           plot_all=True)
            summary_manager.display_loss(loss, tag='Train', plot_all=True)

            # Clamp so np.exp cannot overflow on a diverged loss.
            perplexity = np.exp(loss) if loss < 300 else float('inf')
            t.display("Epoch %d"
                      " perplexity %.4f" % (train_dataset.epoch, perplexity),
                      pos=3)

            steps = 0
            loss = 0

            val_wer, val_per, val_loss = calc_levenshtein_loss(
                model,
                dev_set,
                summary_manager=summary_manager,
                step=model.step)
            val_losses.append(val_per)
            summary_manager.display_loss(val_loss,
                                         tag='Validation-loss',
                                         plot_all=True)
            # NOTE(review): `perplexity` comes from the *training* loss above,
            # yet it is logged under a "Validation" tag.
            summary_manager.display_loss(perplexity,
                                         tag='Validation-perplexity',
                                         plot_all=True)
            summary_manager.display_loss(val_per,
                                         tag='Validation-per',
                                         plot_all=True)
            summary_manager.display_loss(val_wer,
                                         tag='Validation-wer',
                                         plot_all=True)
            summary_manager.display_scalar(tag='Meta/learning_rate',
                                           scalar_value=model.optimizer.lr,
                                           plot_all=True)

            t.display("Validation WER: %.5f, PER: %.5f" % (val_wer, val_per),
                      pos=4)
            # Plateau detection on validation PER: decay when the last 10
            # epochs barely improve on the last 50.
            # NOTE(review): no cooldown and the history is not reset, so once a
            # plateau is detected the rate may decay again on every epoch.
            if len(val_losses) >= 50:
                global_avg = sum(val_losses[-50:]) / 50.0
                last_10_avg = sum(val_losses[-10:]) / 10.0
                if global_avg - last_10_avg < validation_improvement_threshold:
                    lr *= 0.2
                    t.display("Learning rate updated.", pos=5)
                    model.set_constants(learning_rate=lr)
                    # Persist so a restart resumes with the decayed rate.
                    with open(FLAGS['learning_rate_path'], 'w') as f:
                        f.write(str(lr))

            # The deque drops the oldest WER automatically; average as soon
            # as the window is full.
            val_wer_window.append(val_wer)
            if len(val_wer_window) == window_size:
                avg_wer = sum(val_wer_window) / float(window_size)
                t.display("Average Validation WER %.5f" % (avg_wer), pos=6)
                if best_wer > avg_wer:
                    best_wer = avg_wer
                    t.display("Saving Updated Model", pos=7)
                    manager.save()

        print()