Example #1
            try:
                generated_utterances[g] += 1
                if generated_utterances[g] > num_utterances:
                    continue
            except KeyError:
                generated_utterances[g] = 1

        if output_html:
            # identity pass-through: skip the tqdm progress bar for HTML output
            def _tqdm(x):
                return x
        else:
            _tqdm = tqdm

        # Generate
        y_hats = batch_wavegen(model,
                               c=c,
                               g=g,
                               fast=True,
                               tqdm=_tqdm,
                               writing_dir=writing_dir)
        # Save each utt.
        has_ref_file = len(ref_files) > 0
        for i, (ref, gen, length) in enumerate(zip(x, y_hats, input_lengths)):
            if has_ref_file:
                if is_mulaw_quantize(hparams.input_type):
                    # needs to be float since mulaw_inv returns in range of [-1, 1]
                    ref = ref.view(-1).float().cpu().numpy()[:length]
                else:
                    ref = ref.view(-1).cpu().numpy()[:length]
            gen = gen[:length]
            if has_ref_file:
                target_audio_path = ref_files[i]
                name = splitext(basename(target_audio_path))[0].replace(
                    "-wave", "")
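The `is_mulaw_quantize` branch above converts the quantized reference back to a real-valued waveform before saving, which is why `ref` is cast to float. As a self-contained sketch of that round trip (the repo delegates this to a preprocessing library; `mu=255` is an assumed quantization level, not taken from the example):

import numpy as np

def mulaw_quantize(x, mu=255):
    # compand to [-1, 1] with the mu-law curve, then quantize to ints in [0, mu]
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return ((y + 1) / 2 * mu + 0.5).astype(np.int64)

def inv_mulaw_quantize(q, mu=255):
    # map ints back to [-1, 1], then invert the companding; the recovered
    # waveform is real-valued, hence the .float() cast on ref above
    y = 2 * q.astype(np.float64) / mu - 1
    return np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)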
Example #2
            try:
                generated_utterances[g] += 1
                if generated_utterances[g] > num_utterances:
                    continue
            except KeyError:
                generated_utterances[g] = 1

        if output_html:

            def _tqdm(x):
                return x
        else:
            _tqdm = tqdm

        # Generate
        y_hats = batch_wavegen(model,
                               c=c[:, :, 10:130],
                               g=g,
                               fast=True,
                               tqdm=_tqdm)

        # Save each utt.
        has_ref_file = len(ref_files) > 0
        for i, (ref, gen, length) in enumerate(zip(x, y_hats, input_lengths)):
            if has_ref_file:
                if is_mulaw_quantize(hparams.input_type):
                    # needs to be float since mulaw_inv returns in range of [-1, 1]
                    ref = ref.max(0)[1].view(-1).float().cpu().numpy()[:length]
                else:
                    ref = ref.view(-1).cpu().numpy()[:length]
            gen = gen[:length]
            if has_ref_file:
                target_audio_path = ref_files[i]
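Unlike the other examples, this one slices the local conditioning tensor to `c[:, :, 10:130]`, i.e. a 120-frame window, so only that stretch of audio is generated. The conditioning features are laid out frame by frame and upsampled by the hop size inside the model, so the output length is roughly frames times hop. A back-of-the-envelope check, with the hop size and sample rate assumed rather than taken from the example:

hop_size = 256                 # assumed; hparams.hop_size in the repo
sample_rate = 22050            # assumed; hparams.sample_rate in the repo
frames = 130 - 10              # the window sliced from c above
expected_samples = frames * hop_size
print(expected_samples, expected_samples / sample_rate)  # 30720 samples, ~1.39 s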
Example #3
            try:
                generated_utterances[g] += 1
                if generated_utterances[g] > num_utterances:
                    continue
            except KeyError:
                generated_utterances[g] = 1

        if output_html:

            def _tqdm(x):
                return x
        else:
            _tqdm = tqdm

        # Generate
        y_hats = batch_wavegen(model, c=c, g=g, fast=True, tqdm=_tqdm)

        # Save each utt.
        has_ref_file = len(ref_files) > 0
        for i, (ref, gen, length) in enumerate(zip(x, y_hats, input_lengths)):
            if has_ref_file:
                if is_mulaw_quantize(hparams.input_type):
                    # needs to be float since mulaw_inv returns in range of [-1, 1]
                    ref = ref.max(0)[1].view(-1).float().cpu().numpy()[:length]
                else:
                    ref = ref.view(-1).cpu().numpy()[:length]
            gen = gen[:length]
            if has_ref_file:
                target_audio_path = ref_files[i]
                name = splitext(basename(target_audio_path))[0].replace(
                    "-wave", "")
Example #4
    from train import build_model
    from synthesis import batch_wavegen

    # Model
    model = build_model().to(device)

    # Load checkpoint
    print("Load checkpoint from {}".format(checkpoint_path))
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)
    else:
        # remap tensors saved on GPU onto CPU storage when CUDA is unavailable
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["state_dict"])

    # Generate
    y_hat = batch_wavegen(model,
                          c=None,
                          g=None,
                          fast=True,
                          tqdm=tqdm,
                          length=length)
    gen = y_hat[0, :length]
    gen = np.clip(gen, -1.0, 1.0)

    # Write, random name
    os.makedirs(dst_dir, exist_ok=True)
    dst_wav_path = join(dst_dir, "{}_gen.wav".format(random.randint(0, 2**16)))
    wavfile.write(dst_wav_path, hparams.sample_rate, to_int16(gen))
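`to_int16` is not shown in the excerpt; since `scipy.io.wavfile.write` infers the output format from the array dtype, a plausible stand-in scales the already-clipped float waveform to 16-bit PCM. A minimal sketch, assuming the helper's only job is that dtype conversion:

import numpy as np

def to_int16(x):
    # scale a float waveform in [-1, 1] to the int16 range wavfile.write expects
    x = np.clip(x, -1.0, 1.0)
    return (x * 32767).astype(np.int16)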