Beispiel #1
0
def main(
    metadata_path,
    source_dir,
    target_dir,
    output_dir,
    root,
    batch_size,
    reload,
    reload_dir,
):
    """Convert utterances, synthesize waveforms in batches, and save them.

    The function loads an ``Inferencer`` from ``root``, reads the metadata
    JSON, obtains converted spectrograms (either freshly via ``conversion``
    or from disk via ``reload_from_numpy``), turns them into waveforms with
    ``inferencer.spectrogram2waveform`` and writes one ``.wav`` file per
    pair plus an updated ``metadata.json``.

    Args:
        metadata_path: Path to a JSON file. It must provide the keys
            ``source_corpus`` and ``target_corpus``; the metadata returned
            by the conversion/reload step must provide ``n_samples`` and
            ``pairs`` (each pair with ``src_utt`` and ``tgt_utts``).
        source_dir: Forwarded unchanged to ``conversion``.
        target_dir: Forwarded unchanged to ``conversion``.
        output_dir: Base output directory. Files are written to
            ``output_dir/<stem of root>/<source_corpus>2<target_corpus>``.
        root: Passed to ``Inferencer`` and ``conversion``; its path stem is
            used as the model name.
        batch_size: Number of spectrograms of the size of ``conv_mels[0]``
            that may be processed at once. It defines a budget from which
            the actual per-step batch size is derived.
        reload: When truthy, spectrograms are loaded with
            ``reload_from_numpy`` instead of being converted.
        reload_dir: Forwarded unchanged to ``reload_from_numpy``.
    """
    # NOTE(review): Inferencer is not imported inside this function, so it
    # has to be resolvable from the enclosing module.
    inferencer = Inferencer(root)
    device = inferencer.device
    sample_rate = inferencer.sample_rate
    print(f"[INFO]: Inferencer is loaded from {root}.")

    # Use a context manager so the file handle is closed deterministically.
    with open(metadata_path, encoding="utf-8") as metadata_file:
        metadata = json.load(metadata_file)
    print(f"[INFO]: Metadata list is loaded from {metadata_path}.")

    model_name = Path(root).stem
    output_dir = (
        Path(output_dir)
        / model_name
        / f"{metadata['source_corpus']}2{metadata['target_corpus']}"
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    if reload:
        metadata, conv_mels = reload_from_numpy(device, metadata, reload_dir)
    else:
        metadata, conv_mels = conversion(inferencer, device, root, metadata,
                                         source_dir, target_dir, output_dir)

    n_samples = metadata["n_samples"]
    waveforms = []

    # Skip synthesis entirely for empty input; conv_mels[0] would not exist.
    if n_samples > 0:
        # Budget expressed in "first-dimension units" of a spectrogram
        # (presumably frames -- TODO confirm against conversion()).
        max_memory_use = conv_mels[0].size(0) * batch_size

        with torch.no_grad(), tqdm(total=n_samples) as pbar:
            left = 0
            while left < n_samples:
                length = max(conv_mels[left].size(0), 1)
                # Keep the original one-item safety margin, but never let
                # the step drop below 1: a step of 0 would loop forever and
                # a negative step would move backwards.
                step = max(max_memory_use // length - 1, 1)
                right = min(left + step, n_samples)
                waveforms.extend(
                    inferencer.spectrogram2waveform(conv_mels[left:right]))
                # Advance by what was actually processed so the bar never
                # overshoots its total on the final, shorter batch.
                pbar.update(right - left)
                left = right

    # scipy's wavfile.write and soundfile's write take their arguments in a
    # different order, hence the two branches below.
    use_wavfile = model_name == "BLOW"

    for pair, waveform in tqdm(zip(metadata["pairs"], waveforms),
                               total=len(waveforms)):
        waveform = waveform.detach().cpu().numpy()

        prefix = Path(pair["src_utt"]).stem
        postfix = Path(pair["tgt_utts"][0]).stem
        file_name = f"{prefix}_to_{postfix}.wav"
        file_path = output_dir / file_name
        # Record the output file name so the saved metadata references it.
        pair["converted"] = file_name

        if use_wavfile:
            wavfile.write(file_path, sample_rate, waveform)
        else:
            sf.write(file_path, waveform, sample_rate)

    metadata_output_path = output_dir / "metadata.json"
    # Close (and thereby flush) the file explicitly instead of relying on
    # garbage collection of an anonymous handle.
    with metadata_output_path.open("w", encoding="utf-8") as metadata_file:
        json.dump(metadata, metadata_file, indent=2)