Example #1
    def track_top_models(self, mel_loss, gen_wav, model):
        """ Keeps track of top k models and saves them according to their current rank """
        if (len(self.top_k_models) < self.train_cfg['keep_top_k']
                or mel_loss < self.top_k_models[-1][0]):
            m_step = model.get_step()
            model_name = f'model_loss{mel_loss:#0.5}_step{m_step}_weights.pyt'
            self.top_k_models.append(
                (mel_loss, gen_wav, m_step, model_name))
            self.top_k_models.sort(key=lambda t: t[0])
            self.top_k_models = self.top_k_models[:self.train_cfg['keep_top_k']]
            model.save(self.paths.voc_top_k / model_name)
            # remove checkpoint files that have dropped out of the top k
            all_models = get_files(self.paths.voc_top_k, extension='pyt')
            top_k_names = {m[-1] for m in self.top_k_models}
            for model_file in all_models:
                if model_file.name not in top_k_names:
                    print(f'removing {model_file}')
                    os.remove(model_file)
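            # persist the current ranking alongside the checkpoint files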
            pickle_binary(self.top_k_models, self.paths.voc_top_k / 'top_k.pkl')

            for i, (m_loss, g_wav, m_step, m_name) in enumerate(self.top_k_models, 1):
                self.writer.add_audio(tag=f'Top_K_Models/generated_top_{i}',
                                      snd_tensor=g_wav,
                                      global_step=m_step,
                                      sample_rate=self.dsp.sample_rate)
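
The append / sort / truncate pattern above is easy to reuse outside this trainer. Below is a minimal, self-contained sketch of the same top-k bookkeeping; the Checkpoint tuple and the save_fn callback are hypothetical stand-ins for the trainer's model and path plumbing, not part of the original code.

from typing import Callable, List, NamedTuple

class Checkpoint(NamedTuple):  # hypothetical container for illustration
    loss: float
    name: str

def track_top_k(top_k: List[Checkpoint], loss: float, name: str,
                k: int, save_fn: Callable[[str], None]) -> List[Checkpoint]:
    # keep only if the list is not yet full or the loss beats the current worst
    if len(top_k) < k or loss < top_k[-1].loss:
        top_k.append(Checkpoint(loss, name))
        top_k.sort(key=lambda c: c.loss)  # ascending: best model first
        del top_k[k:]  # the new entry can never be the one dropped here
        save_fn(name)  # persist only models that made the cut
    return top_k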

Example #2
def create_align_features(model: Tacotron,
                          train_set: DataLoader,
                          val_set: DataLoader,
                          save_path_alg: Path):
    assert model.r == 1, f'Reduction factor of tacotron must be 1 for creating alignment features! ' \
                         f'Reduction factor was: {model.r}'
    model.eval()
    device = next(model.parameters()).device  # use same device as model parameters
    if val_set is not None:
        iters = len(val_set) + len(train_set)
        dataset = itertools.chain(train_set, val_set)
    else:
        iters = len(train_set)
        dataset = itertools.chain(train_set)

    att_score_dict = {}

    if hp.extract_durations_with_dijkstra:
        print('Extracting durations using Dijkstra...')
        dur_extraction_func = extract_durations_with_dijkstra
    else:
        print('Extracting durations using attention peak counts...')
        dur_extraction_func = extract_durations_per_count
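    # Dijkstra-based extraction finds a monotonic minimum-cost path through the
    # attention matrix, which tends to be more robust to noisy attention maps
    # than simply counting the argmax peak per decoder step.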
    for i, (x, mels, ids, x_lens, mel_lens) in enumerate(dataset, 1):
        x, mels = x.to(device), mels.to(device)
        with torch.no_grad():
            _, _, att_batch = model(x, mels)
        align_score, sharp_score = attention_score(att_batch, mel_lens, r=1)
        att_batch = np_now(att_batch)
        seq, att, mel_len, item_id = x[0], att_batch[0], mel_lens[0], ids[0]
        align_score, sharp_score = float(align_score[0]), float(sharp_score[0])
        att_score_dict[item_id] = (align_score, sharp_score)
        durs = dur_extraction_func(seq, att, mel_len)
        if np.sum(durs) != mel_len:
            print(f'WARNING: Sum of durations did not match mel length for item {item_id}!')
        np.save(str(save_path_alg / f'{item_id}.npy'), durs, allow_pickle=False)
        bar = progbar(i, iters)
        msg = f'{bar} {i}/{iters} Batches '
        stream(msg)
    pickle_binary(att_score_dict, paths.data / 'att_score_dict.pkl')
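
For reference, the peak-count strategy used as the fallback above can be sketched in a few lines: every mel frame is assigned to the input symbol with the highest attention weight, and a symbol's duration is the number of frames assigned to it. This is a minimal sketch of the idea, not the repository's extract_durations_per_count implementation.

import numpy as np

def durations_by_peak_count(att: np.ndarray, seq_len: int, mel_len: int) -> np.ndarray:
    # att: attention matrix of shape (mel_len, seq_len)
    durs = np.zeros(seq_len, dtype=np.int64)
    peaks = np.argmax(att[:mel_len, :seq_len], axis=1)  # attended input index per frame
    for p in peaks:
        durs[p] += 1
    return durs  # by construction, durs.sum() == mel_len
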
Example #3
    pool = Pool(processes=n_workers)
    dataset = []
    cleaned_texts = []
    for i, (item_id, length, cleaned_text) in enumerate(
            pool.imap_unordered(process_wav, wav_files), 1):
        if item_id in text_dict:
            dataset.append((item_id, length))
            cleaned_texts.append((item_id, cleaned_text))
        bar = progbar(i, len(wav_files))
        message = f'{bar} {i}/{len(wav_files)} '
        stream(message)

    random = Random(hp.seed)
    random.shuffle(dataset)
    train_dataset = dataset[hp.n_val:]
    val_dataset = dataset[:hp.n_val]
    # sort val dataset longest to shortest
    val_dataset.sort(key=lambda d: -d[1])

    for item_id, text in cleaned_texts:
        text_dict[item_id] = text

    pickle_binary(text_dict, paths.data / 'text_dict.pkl')
    pickle_binary(train_dataset, paths.data / 'train_dataset.pkl')
    pickle_binary(val_dataset, paths.data / 'val_dataset.pkl')

    print('\n\nCompleted. Ready to run "python train_tacotron.py" or "python train_wavernn.py". \n')
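
The worker passed to pool.imap_unordered is not shown above; the loop only relies on it returning an (item_id, length, cleaned_text) triple per wav file. The sketch below is a hypothetical stand-in under that assumption (the real worker presumably also computes and saves acoustic features); librosa, the 22050 Hz sample rate, the sibling .txt transcript layout, and the clean_text helper are all assumptions made for illustration.

from pathlib import Path

import librosa  # assumed audio loader

def clean_text(text: str) -> str:
    # hypothetical cleaner; real pipelines apply language-specific text cleaners
    return text.strip().lower()

def process_wav(wav_path: Path):
    # returns the (item_id, length, cleaned_text) triple unpacked by the loop above
    item_id = wav_path.stem
    y, _ = librosa.load(str(wav_path), sr=22050)  # assumed sample rate
    txt_file = wav_path.with_suffix('.txt')       # assumed transcript layout
    text = txt_file.read_text() if txt_file.exists() else ''
    return item_id, len(y), clean_text(text)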