Example #1
def load_model(self, checkpoint_path: str = None, verbose=True):
    # build and compile the model, then wrap it in a checkpoint object so
    # weights can be restored from disk
    model = self.get_model()
    self.compile_model(model)
    ckpt = tf.train.Checkpoint(net=model)
    manager = tf.train.CheckpointManager(ckpt,
                                         self.weights_dir,
                                         max_to_keep=None)
    if checkpoint_path:
        # restore from an explicitly given checkpoint path
        ckpt.restore(checkpoint_path)
        if verbose:
            print(
                f'restored weights from {checkpoint_path} at step {model.step}'
            )
    else:
        # otherwise fall back to the latest checkpoint under weights_dir
        if manager.latest_checkpoint is None:
            print(
                f'WARNING: could not find weights file. Trying to load from \n {self.weights_dir}.'
            )
            print(
                'Edit data_config.yaml to point at the right log directory.'
            )
        ckpt.restore(manager.latest_checkpoint)
        if verbose:
            print(
                f'restored weights from {manager.latest_checkpoint} at step {model.step}'
            )
    # set step-dependent constants: the prenet dropout schedule always applies,
    # the reduction factor only for autoregressive models
    decoder_prenet_dropout = piecewise_linear_schedule(
        model.step, self.config['decoder_prenet_dropout_schedule'])
    reduction_factor = None
    if self.model_kind == 'autoregressive':
        reduction_factor = reduction_schedule(
            model.step, self.config['reduction_factor_schedule'])
    model.set_constants(reduction_factor=reduction_factor,
                        decoder_prenet_dropout=decoder_prenet_dropout)
    return model
Example #2
if manager.latest_checkpoint:
    print(
        f'\nresuming training from step {model.step} ({manager.latest_checkpoint})'
    )
else:
    print(f'\nstarting training from scratch')

if config['debug'] is True:
    print('\nWARNING: DEBUG is set to True. Training in eager mode.')
# main event: the training loop
print('\nTRAINING')
losses = []
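# pull one batch up front and discard it (presumably to warm up the data pipeline)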
_ = train_dataset.next_batch()
t = trange(model.step, config['max_steps'], leave=True)
for _ in t:
    t.set_description(f'step {model.step}')
    mel, phonemes, stop = train_dataset.next_batch()
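    # evaluate the per-step schedules: decoder prenet dropout, learning rate,
    # reduction factor, and how many attention heads to drop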
    decoder_prenet_dropout = piecewise_linear_schedule(
        model.step, config['decoder_prenet_dropout_schedule'])
    learning_rate = piecewise_linear_schedule(model.step,
                                              config['learning_rate_schedule'])
    reduction_factor = reduction_schedule(model.step,
                                          config['reduction_factor_schedule'])
    drop_n_heads = tf.cast(
        reduction_schedule(model.step, config['head_drop_schedule']), tf.int32)
    t.display(f'reduction factor {reduction_factor}', pos=10)
    model.set_constants(decoder_prenet_dropout=decoder_prenet_dropout,
                        learning_rate=learning_rate,
                        reduction_factor=reduction_factor,
                        drop_n_heads=drop_n_heads)
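    # one optimization step on this batch; record the scalar loss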
    output = model.train_step(inp=phonemes, tar=mel, stop_prob=stop)
    losses.append(float(output['loss']))

    t.display(f'step loss: {losses[-1]}', pos=1)
Example #3
print('\nTRAINING')
losses = []
texts = []
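# read every text file listed under 'text_prediction' in the config; each file is kept as a list of lines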
for text_file in config_dict['text_prediction']:
    with open(text_file, 'r') as file:
        text = file.readlines()
    texts.append(text)

all_files = len(set(
    train_data_handler.metadata_reader.filenames))  # without duplicates
all_durations = {}
t = trange(model.step, config_dict['max_steps'], leave=True)
for _ in t:
    t.set_description(f'step {model.step}')
    phonemes, mel, durations, spk_emb, fname = train_dataset.next_batch()
    learning_rate = piecewise_linear_schedule(
        model.step, config_dict['learning_rate_schedule'])
    model.set_constants(learning_rate=learning_rate)

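    # jitter the speaker embeddings, then re-normalize each embedding to unit L2 norm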
    spk_emb += jitter(512)[:spk_emb.shape[0], :]
    spk_emb /= np.linalg.norm(spk_emb, axis=1, keepdims=True)

    output = model.train_step(input_sequence=phonemes,
                              target_sequence=mel,
                              target_durations=durations,
                              spk_emb=spk_emb)
    losses.append(float(output['loss']))

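    # collect the predicted symbol durations keyed by filename; once every file in
    # the dataset has a prediction, plot the duration distributions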
    predicted_durations = dict(zip(fname.numpy(), output['duration'].numpy()))
    all_durations.update(predicted_durations)
    if len(all_durations) >= all_files:  # all the dataset has been processed
        display_predicted_symbol_duration_distributions(all_durations)