# Preprocess every validation sample eagerly into a list.
test_list = list(map(data_prep, val_samples))

# Both splits share preprocessing and batching settings; only the sample set
# and whether batches are shuffled differ.
_dataset_kwargs = {
    'preprocessor': data_prep,
    'batch_size': config['batch_size'],
    'mel_channels': config['mel_channels'],
}
train_dataset = Dataset(samples=train_samples, shuffle=True, **_dataset_kwargs)
val_dataset = Dataset(samples=val_samples, shuffle=False, **_dataset_kwargs)
del _dataset_kwargs

# Logging, checkpoint management, and restoration of the newest weights
# (restore() is handed None when no checkpoint exists yet).
summary_manager = SummaryManager(
    model=model,
    log_dir=config_manager.log_dir,
    config=config,
)
checkpoint = tf.train.Checkpoint(
    step=tf.Variable(1),
    optimizer=model.optimizer,
    net=model,
)
manager = tf.train.CheckpointManager(
    checkpoint,
    directory=str(config_manager.weights_dir),
    max_to_keep=config['keep_n_weights'],
    keep_checkpoint_every_n_hours=config['keep_checkpoint_every_n_hours'],
)
checkpoint.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print(
        f'\nresuming training from step {model.step} ({manager.latest_checkpoint})'
    )
else:
# Example #2
# 0
            with open(str(train_predictions_dir / f'{c}_batch_prediction.npy'),
                      'wb') as file:
                pickle.dump(batch, file)
        train_batches.append(batch)
else:
    val_batches = [
        batch_file for batch_file in val_predictions_dir.iterdir()
        if batch_file.suffix == '.npy'
    ]
    train_batches = [
        batch_file for batch_file in train_predictions_dir.iterdir()
        if batch_file.suffix == '.npy'
    ]

# Summaries for this run go to a per-tag subdirectory of the log dir, with
# the same tag used as the default writer.
summary_manager = SummaryManager(
    model=model,
    log_dir=config_manager.log_dir / writer_tag,
    config=config,
    default_writer=writer_tag,
)

# TODO: val/train _batches hold either the in-memory batches or the paths of
# the saved batch files, depending on running_predictions; unify this.
all_val_durations = np.array([])
new_alignments = []
total_val_samples = 0
iterator = tqdm(enumerate(val_batches))
for c, batch_file in iterator:
    iterator.set_description('Extracting validation alignments')
    if running_predictions:
        # Batch is already an in-memory (mel, text, alignments) triple.
        val_mel, val_text, val_alignments = batch_file
    else:
        # Batch was saved to disk; load it back (pickled content).
        val_mel, val_text, val_alignments = np.load(str(batch_file),
                                                    allow_pickle=True)
def _create_file_if_missing(path, contents, description):
    """Create ``path`` containing ``contents`` unless a file already exists there.

    Used by ``train`` to seed its on-disk state files on a fresh run. If the
    file cannot be written, prints ``"Could not create <description> <path>"``
    plus the traceback and exits the process with status 1.
    """
    if os.path.isfile(path):
        return
    try:
        with open(path, 'w') as f:
            f.write(contents)
    except OSError:
        print("Could not create %s %s" % (description, path))
        traceback.print_exc()
        sys.exit(1)


def train():
    """Train the model, resuming from the latest checkpoint when one exists.

    All settings come from the ``FLAGS`` mapping. Besides the TensorFlow
    checkpoints in ``FLAGS['weights_dir']``, two text files hold state:
    the current learning rate (``FLAGS['learning_rate_path']``, rewritten on
    every decay) and an epoch counter (``FLAGS['epoch_path']``, only
    initialised to 0 here).

    After every training epoch the model is evaluated on the dev set. The
    validation results drive two updates:

    - Once 50 epochs have been evaluated, the learning rate is multiplied by
      0.2 whenever the mean PER of the last 50 epochs is not at least
      ``FLAGS['valid_thresh']`` higher than the mean of the last 10.
    - A checkpoint is saved whenever the average WER over a sliding window
      of recent epochs drops below the best average seen so far.

    Exits the process with status 1 if a required state file or directory
    cannot be created, or if the learning-rate file cannot be read on resume.
    """
    print_config()
    model = get_model(train=True)
    lr = FLAGS['learning_rate']
    model._compile(optimizer=tf.keras.optimizers.Adam(
        lr, beta_1=0.9, beta_2=0.98, epsilon=1e-9))

    # Seed the state files on a fresh run; existing files are left untouched
    # so a resumed run keeps the values written by the previous one.
    _create_file_if_missing(FLAGS['learning_rate_path'],
                            str(FLAGS['learning_rate']),
                            'learning rate backup')
    _create_file_if_missing(FLAGS['epoch_path'], str(0), 'epoch count file')

    try:
        os.makedirs(FLAGS['weights_dir'], exist_ok=True)
    except OSError:
        print("Could not create weights folder %s that contains model "
              "directories for trained weights " % FLAGS['weights_dir'])
        traceback.print_exc()
        sys.exit(1)

    checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     optimizer=model.optimizer,
                                     net=model)
    manager = tf.train.CheckpointManager(
        checkpoint,
        str(FLAGS['weights_dir']),
        max_to_keep=FLAGS['keep_n_weights'],
        keep_checkpoint_every_n_hours=FLAGS['keep_checkpoint_every_n_hours'])
    summary_manager = SummaryManager(model=model,
                                     log_dir=FLAGS['train_dir'],
                                     config=FLAGS)
    checkpoint.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print(
            f'\nresuming training from step {model.step} ({manager.latest_checkpoint})'
        )
        try:
            with open(FLAGS['learning_rate_path'], 'r') as f:
                lr = float(f.readlines()[0])
        except (OSError, ValueError, IndexError):
            print("Could not load learning rate file %s" %
                  FLAGS['learning_rate_path'])
            traceback.print_exc()
            sys.exit(1)
        print("Successfully loaded learning rate.")
        # Apply the restored rate. Checkpoints are only saved when the WER
        # improves, but the backup file is rewritten on every decay, so the
        # optimizer's compiled/restored rate can be stale.
        model.set_constants(learning_rate=lr)
    else:
        print('\nStarting training from scratch ...')
    if FLAGS['debug'] is True:
        print('\nWARNING: DEBUG is set to True. Training in eager mode.')

    print('\nTRAINING...')

    train_data = data_utils.read_and_bucket_data(
        os.path.join(FLAGS['data_dir'], "train.pkl"))
    dev_set = data_utils.read_and_bucket_data(
        os.path.join(FLAGS['data_dir'], "dev.pkl"))

    train_dataset = Dataset(train_data[0],
                            FLAGS['batch_size'],
                            isTraining=True,
                            bucket_id=None,
                            drop_remainder=True)

    # Recent validation WERs. Checkpoint selection uses their average so a
    # single lucky evaluation does not overwrite the saved model.
    val_wer_window = []
    window_size = 3
    # Minimum PER improvement (last-50 mean minus last-10 mean) required to
    # keep the current learning rate.
    validation_improvement_threshold = FLAGS['valid_thresh']
    val_pers = []  # validation PER of every evaluated epoch

    if model.step > 0:
        # Resumed run: the restored model's own dev WER is the bar to beat.
        # calc_levenshtein_loss yields (wer, per, loss) -- see the epoch-end
        # call below -- so take only the WER here.
        best_wer = calc_levenshtein_loss(model, dev_set)[0]
    else:
        best_wer = 1.0

    loss = 0.0
    steps = 0
    epoch = train_dataset.epoch  # last epoch whose statistics were reported
    t = trange(model.step, FLAGS['max_steps'], leave=True)
    for _ in t:
        t.set_description(f'Step {model.step}')
        encoder_inputs, seq_len, decoder_inputs, seq_len_target = (
            train_dataset.next_batch())
        model_out = model.train_step(encoder_inputs, seq_len, decoder_inputs,
                                     seq_len_target)
        loss += model_out['loss']
        steps += 1
        t.display(f'epoch : {train_dataset.epoch}', pos=2)
        if model.step % FLAGS['train_images_plotting_frequency'] == 0:
            summary_manager.display_attention_heads(model_out,
                                                    tag='TrainAttentionHeads')
        # train_dataset.epoch presumably advances inside next_batch(); any
        # increase means at least one epoch just finished. Using ">" rather
        # than "== previous + 1" keeps reporting working if it ever skips.
        if train_dataset.epoch > epoch:
            epoch = train_dataset.epoch
            loss /= steps  # mean training loss over the finished epoch
            summary_manager.display_scalar(tag='Meta/epoch',
                                           scalar_value=epoch,
                                           plot_all=True)
            summary_manager.display_loss(loss, tag='Train', plot_all=True)

            # Very large losses are reported as infinite perplexity instead of
            # a huge exp() value.
            perplexity = np.exp(loss) if loss < 300 else float('inf')
            t.display("Epoch %d"
                      " perplexity %.4f" % (train_dataset.epoch, perplexity),
                      pos=3)

            steps = 0
            loss = 0.0
            # Calculate validation result
            val_wer, val_per, val_loss = calc_levenshtein_loss(
                model,
                dev_set,
                summary_manager=summary_manager,
                step=model.step)
            val_pers.append(val_per)
            summary_manager.display_loss(val_loss,
                                         tag='Validation-loss',
                                         plot_all=True)
            # NOTE(review): this perplexity comes from the *training* loss
            # above despite the tag name; tag kept for dashboard continuity.
            summary_manager.display_loss(perplexity,
                                         tag='Validation-perplexity',
                                         plot_all=True)
            summary_manager.display_loss(val_per,
                                         tag='Validation-per',
                                         plot_all=True)
            summary_manager.display_loss(val_wer,
                                         tag='Validation-wer',
                                         plot_all=True)
            summary_manager.display_scalar(tag='Meta/learning_rate',
                                           scalar_value=model.optimizer.lr,
                                           plot_all=True)

            t.display("Validation WER: %.5f, PER: %.5f" % (val_wer, val_per),
                      pos=4)
            # Plateau detection: decay the rate when the recent PER mean is
            # not sufficiently below the longer-term mean.
            # NOTE(review): the PER history is not reset after a decay, so a
            # persisting plateau decays again on every following epoch --
            # confirm this is intended.
            if len(val_pers) >= 50:
                global_avg = sum(val_pers[-50:]) / 50.0
                last_10_avg = sum(val_pers[-10:]) / 10.0
                if global_avg - last_10_avg < validation_improvement_threshold:
                    lr *= 0.2
                    t.display("Learning rate updated.", pos=5)
                    model.set_constants(learning_rate=lr)
                    # Persist the new rate so a resumed run can restore it.
                    with open(FLAGS['learning_rate_path'], 'w') as f:
                        f.write(str(lr))
            # Moving window: append the newest WER; once more than
            # window_size entries exist, drop the oldest and evaluate the
            # average for checkpointing.
            val_wer_window.append(val_wer)
            if len(val_wer_window) > window_size:
                val_wer_window.pop(0)
                avg_wer = sum(val_wer_window) / float(len(val_wer_window))
                t.display("Average Validation WER %.5f" % (avg_wer), pos=6)
                if best_wer > avg_wer:
                    # Save the best model
                    best_wer = avg_wer
                    t.display("Saving Updated Model", pos=7)
                    manager.save()

        print()