Example #1
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/session_paths.yaml')
parser.add_argument('--skip_phonemes', action='store_true')
parser.add_argument('--skip_mels', action='store_true')
parser.add_argument('--skip_speakers', action='store_true')

args = parser.parse_args()
for arg in vars(args):
    print('{}: {}'.format(arg, getattr(args, arg)))

cm = Config(args.config, asr=True)
cm.create_remove_dirs()
metadatareader = DataReader.from_config(cm, kind='original')
summary_manager = SummaryManager(model=None,
                                 log_dir=cm.log_dir / 'data_preprocessing',
                                 config=cm.config,
                                 default_writer='data_preprocessing')
print(f'\nFound {len(metadatareader.filenames)} audio files.')
audio = Audio(config=cm.config)

if not args.skip_mels:

    def process_file(tuples):
        len_dict = {}
        spk_file_dict = {}
        remove_files = []
        for idx in trange(len(tuples), desc=''):
            file_name, fullpath, data_type, spk_name, _ = tuples[idx]
            _, trim_type = cm.data_type[data_type]
            try:
                y, sr = audio.load_file(fullpath)

Example #2
train_dataset = train_data_handler.get_dataset(
    bucket_batch_sizes=config['bucket_batch_sizes'],
    bucket_boundaries=config['bucket_boundaries'],
    shuffle=True)
valid_dataset = valid_data_handler.get_dataset(
    bucket_batch_sizes=[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1],
    bucket_boundaries=config['bucket_boundaries'],
    shuffle=False,
    drop_remainder=True)

# create logger and checkpointer and restore latest model

summary_manager = SummaryManager(model=model,
                                 log_dir=config_manager.log_dir,
                                 config=config)
checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                 optimizer=model.optimizer,
                                 net=model)
manager = tf.train.CheckpointManager(
    checkpoint,
    str(config_manager.weights_dir),
    max_to_keep=config['keep_n_weights'],
    keep_checkpoint_every_n_hours=config['keep_checkpoint_every_n_hours'])
manager_training = tf.train.CheckpointManager(checkpoint,
                                              str(config_manager.weights_dir /
                                                  'latest'),
                                              max_to_keep=1,
                                              checkpoint_name='latest')
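
Several of these examples build their datasets from bucket_boundaries and bucket_batch_sizes. Independently of the repository's dataset wrappers (which presumably do something equivalent), the same idea can be shown with plain tf.data (TF 2.6+); the lengths, boundaries, and batch sizes below are made up purely for illustration:

import tensorflow as tf

# Toy variable-length sequences standing in for mel/text examples.
sequences = [tf.range(n) for n in (5, 12, 37, 80, 150)]
ds = tf.data.Dataset.from_generator(
    lambda: sequences,
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))

# Two boundaries define three buckets (<50, 50-99, >=100); each bucket gets
# its own batch size, so short sequences can be batched more densely.
ds = ds.bucket_by_sequence_length(
    element_length_func=lambda seq: tf.shape(seq)[0],
    bucket_boundaries=[50, 100],
    bucket_batch_sizes=[32, 16, 8])

for batch in ds:
    print(batch.shape)  # padded to the longest sequence in its batch
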
Example #3
config.print_config()

model = config.get_model()
config.compile_model(model)

data_handler = ASRDataset.from_config(config,
                                      tokenizer=model.text_pipeline.tokenizer,
                                      kind='valid')
dataset = data_handler.get_dataset(
    bucket_batch_sizes=config_dict['bucket_batch_sizes'],
    bucket_boundaries=config_dict['bucket_boundaries'],
    shuffle=False)

# create logger and checkpointer and restore latest model
summary_manager = SummaryManager(model=model,
                                 log_dir=config.log_dir,
                                 config=config_dict)
checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                 optimizer=model.optimizer,
                                 net=model)
manager = tf.train.CheckpointManager(
    checkpoint,
    config.weights_dir,
    max_to_keep=config_dict['keep_n_weights'],
    keep_checkpoint_every_n_hours=config_dict['keep_checkpoint_every_n_hours'])
manager_training = tf.train.CheckpointManager(checkpoint,
                                              str(config.weights_dir /
                                                  'latest'),
                                              max_to_keep=1,
                                              checkpoint_name='latest')
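
The snippet cuts off before any weights are loaded; for evaluation, a restore step along these lines would typically follow (a sketch reusing the objects defined above; expect_partial() only silences warnings about optimizer slots that the evaluation pass never touches):

# Load the newest saved weights before running the validation dataset.
if manager.latest_checkpoint is None:
    raise RuntimeError(f'No checkpoints found in {config.weights_dir}')
checkpoint.restore(manager.latest_checkpoint).expect_partial()
print(f'Restored weights from {manager.latest_checkpoint}')
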
Example #4
valid_data_handler = TTSDataset.from_config(config,
                                            preprocessor=data_prep,
                                            kind='valid')
train_dataset = train_data_handler.get_dataset(
    bucket_batch_sizes=config_dict['bucket_batch_sizes'],
    bucket_boundaries=config_dict['bucket_boundaries'],
    shuffle=True)
valid_dataset = valid_data_handler.get_dataset(
    bucket_batch_sizes=config_dict['val_bucket_batch_size'],
    bucket_boundaries=config_dict['bucket_boundaries'],
    shuffle=False,
    drop_remainder=True)

# create logger and checkpointer and restore latest model
summary_manager = SummaryManager(model=model,
                                 log_dir=config.log_dir,
                                 config=config_dict)
checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                 optimizer=model.optimizer,
                                 net=model)
manager = tf.train.CheckpointManager(
    checkpoint,
    config.weights_dir,
    max_to_keep=config_dict['keep_n_weights'],
    keep_checkpoint_every_n_hours=config_dict['keep_checkpoint_every_n_hours'])
manager_training = tf.train.CheckpointManager(checkpoint,
                                              str(config.weights_dir /
                                                  'latest'),
                                              max_to_keep=1,
                                              checkpoint_name='latest')
Example #5
np.random.seed(42)

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--skip_phonemes', action='store_true')
parser.add_argument('--skip_mels', action='store_true')

args = parser.parse_args()
for arg in vars(args):
    print('{}: {}'.format(arg, getattr(args, arg)))

cm = Config(args.config, model_kind='autoregressive')
cm.create_remove_dirs()
metadatareader = DataReader.from_config(cm, kind='original', scan_wavs=True)
summary_manager = SummaryManager(model=None,
                                 log_dir=cm.log_dir / 'data_preprocessing',
                                 config=cm.config,
                                 default_writer='data_preprocessing')
if not args.skip_mels:

    def process_wav(wav_path: Path):
        file_name = wav_path.stem
        y, sr = audio.load_wav(str(wav_path))
        mel = audio.mel_spectrogram(y)
        # Sanity checks: the mel must be 2-D with the configured channel count.
        assert len(mel.shape) == 2
        assert mel.shape[1] == audio.config['mel_channels']
        mel_path = (cm.mel_dir / file_name).with_suffix('.npy')
        np.save(mel_path, mel)
        return (file_name, mel.shape[0])

    print(f"Creating mels from all wavs found in {metadatareader.data_directory}")
Example #6
valid_data_handler = TextMelDurDataset.from_config(config,
                                                   preprocessor=data_prep,
                                                   kind='valid')
train_dataset = train_data_handler.get_dataset(
    bucket_batch_sizes=config_dict['bucket_batch_sizes'],
    bucket_boundaries=config_dict['bucket_boundaries'],
    shuffle=True)
valid_dataset = valid_data_handler.get_dataset(
    bucket_batch_sizes=[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1],
    bucket_boundaries=config_dict['bucket_boundaries'],
    shuffle=False,
    drop_remainder=True)

# create logger and checkpointer and restore latest model
summary_manager = SummaryManager(model=model,
                                 log_dir=config.log_dir,
                                 config=config_dict)
checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                 optimizer=model.optimizer,
                                 net=model)
manager = tf.train.CheckpointManager(
    checkpoint,
    config.weights_dir,
    max_to_keep=config_dict['keep_n_weights'],
    keep_checkpoint_every_n_hours=config_dict['keep_checkpoint_every_n_hours'])
manager_training = tf.train.CheckpointManager(checkpoint,
                                              str(config.weights_dir /
                                                  'latest'),
                                              max_to_keep=1,
                                              checkpoint_name='latest')
Example #7
train_data_handler = ASRDataset.from_config(config,
                                            tokenizer=model.text_pipeline.tokenizer,
                                            kind='train')
valid_data_handler = ASRDataset.from_config(config,
                                            tokenizer=model.text_pipeline.tokenizer,
                                            kind='valid')
train_dataset = train_data_handler.get_dataset(bucket_batch_sizes=config_dict['bucket_batch_sizes'],
                                               bucket_boundaries=config_dict['bucket_boundaries'],
                                               shuffle=True)
valid_dataset = valid_data_handler.get_dataset(bucket_batch_sizes=config_dict['val_bucket_batch_size'],
                                               bucket_boundaries=config_dict['bucket_boundaries'],
                                               shuffle=False,
                                               drop_remainder=True)

# create logger and checkpointer and restore latest model
summary_manager = SummaryManager(model=model, log_dir=config.log_dir, config=config_dict)
checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                 optimizer=model.optimizer,
                                 net=model)
manager = tf.train.CheckpointManager(checkpoint, config.weights_dir,
                                     max_to_keep=config_dict['keep_n_weights'],
                                     keep_checkpoint_every_n_hours=config_dict['keep_checkpoint_every_n_hours'])
manager_training = tf.train.CheckpointManager(checkpoint, str(config.weights_dir / 'latest'),
                                              max_to_keep=1, checkpoint_name='latest')

checkpoint.restore(manager_training.latest_checkpoint)
if manager_training.latest_checkpoint:
    print(f'\nresuming training from step {model.step} ({manager_training.latest_checkpoint})')
else:
    print(f'\nstarting training from scratch')
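
What happens after the restore is not shown in these excerpts. As a rough sketch, training would alternate step updates with periodic saves through both managers; model.train_step, its arguments, and the save intervals below are assumptions rather than this codebase's actual API:

# Hypothetical training loop around the checkpoint objects created above.
for batch in train_dataset:
    _ = model.train_step(*batch)    # assumed signature
    checkpoint.step.assign_add(1)
    step = int(checkpoint.step)
    if step % 1_000 == 0:
        manager_training.save()     # rolling 'latest' checkpoint used for resuming
    if step % 10_000 == 0:
        manager.save()              # long-term snapshots, pruned to keep_n_weights
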
Example #8
n_dense = int(config_manager.config['decoder_dense_blocks'])
n_convs = int(n_layers - n_dense)
if args.extract_layer > 0:
    if n_convs > 0:
        last_layer_key = f'Decoder_ConvBlock{args.extract_layer}_CrossAttention'
    else:
        last_layer_key = f'Decoder_DenseBlock{args.extract_layer}_CrossAttention'
else:
    if n_convs > 0:
        last_layer_key = f'Decoder_ConvBlock{n_convs}_CrossAttention'
    else:
        last_layer_key = f'Decoder_DenseBlock{n_dense}_CrossAttention'
print(f'Extracting attention from layer {last_layer_key}')

summary_manager = SummaryManager(model=model,
                                 log_dir=config_manager.log_dir / writer_tag,
                                 config=config,
                                 default_writer=writer_tag)
all_durations = np.array([])
new_alignments = []
iterator = tqdm(enumerate(dataset.all_batches()))
step = 0
for c, (mel_batch, text_batch, stop_batch, file_name_batch) in iterator:
    iterator.set_description(f'Processing dataset')
    outputs = model.val_step(inp=text_batch,
                             tar=mel_batch,
                             stop_prob=stop_batch)
    attention_values = outputs['decoder_attention'][last_layer_key].numpy()
    text = text_batch.numpy()

    if args.use_GT:
        mel = mel_batch.numpy()
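
The excerpt ends mid-loop, but the usual reason for collecting decoder cross-attention like this is to turn it into per-character durations. A common recipe (a sketch, not necessarily the exact extraction this script performs) counts, for each mel frame, which input character received the most attention:

import numpy as np

def durations_from_attention(attn: np.ndarray) -> np.ndarray:
    # attn: cross-attention weights of shape (mel_frames, text_len),
    # e.g. one head of one batch item sliced out of attention_values.
    winners = attn.argmax(axis=-1)  # best-attended character per mel frame
    return np.bincount(winners, minlength=attn.shape[-1])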