def main(
    metadata_path,
    source_dir,
    target_dir,
    output_dir,
    root,
    batch_size,
    reload,
    reload_dir,
):
    """Run voice conversion for every pair in the metadata and save .wav files.

    Args:
        metadata_path: path to the JSON metadata describing conversion pairs
            (expects keys ``source_corpus``, ``target_corpus``, ``n_samples``,
            ``pairs`` — as read below).
        source_dir: directory with source-speaker utterances.
        target_dir: directory with target-speaker utterances.
        output_dir: root directory under which converted audio is written.
        root: path to the inferencer checkpoint; its stem names the model
            (``BLOW`` selects the scipy wavfile writer).
        batch_size: nominal batch size used to derive the vocoder memory budget.
        reload: if truthy, reload previously converted mels from ``reload_dir``
            instead of re-running conversion.
        reload_dir: directory holding saved numpy mels (used when ``reload``).
    """
    inferencer = Inferencer(root)
    device = inferencer.device
    sample_rate = inferencer.sample_rate
    print(f"[INFO]: Inferencer is loaded from {root}.")

    # Fix: close the metadata file instead of leaking the handle.
    with open(metadata_path) as metadata_file:
        metadata = json.load(metadata_file)
    print(f"[INFO]: Metadata list is loaded from {metadata_path}.")

    output_dir = Path(output_dir) / Path(root).stem / \
        f"{metadata['source_corpus']}2{metadata['target_corpus']}"
    output_dir.mkdir(parents=True, exist_ok=True)

    if reload:
        metadata, conv_mels = reload_from_numpy(device, metadata, reload_dir)
    else:
        metadata, conv_mels = conversion(inferencer, device, root, metadata,
                                         source_dir, target_dir, output_dir)

    waveforms = []
    # Memory budget: (length of the first mel) x (requested batch size).
    max_memory_use = conv_mels[0].size(0) * batch_size

    with torch.no_grad():
        pbar = tqdm(total=metadata["n_samples"])
        left = 0
        while left < metadata["n_samples"]:
            # Adapt the batch so step * current_mel_length stays within the
            # budget. Guard with max(1, ...): an unusually long mel could
            # otherwise drive the step to 0 or negative and hang the loop.
            step = max(1, max_memory_use // conv_mels[left].size(0) - 1)
            right = left + min(step, metadata["n_samples"] - left)
            waveforms.extend(
                inferencer.spectrogram2waveform(conv_mels[left:right]))
            # Fix: advance by the number of samples actually vocoded; the old
            # code stepped by the uncapped batch size, which overshot the
            # progress-bar total on the final chunk.
            pbar.update(right - left)
            left = right
        pbar.close()

    for pair, waveform in tqdm(zip(metadata["pairs"], waveforms)):
        waveform = waveform.detach().cpu().numpy()
        prefix = Path(pair["src_utt"]).stem
        postfix = Path(pair["tgt_utts"][0]).stem
        file_path = output_dir / f"{prefix}_to_{postfix}.wav"
        # Record the output filename back into the metadata for the dump below.
        pair["converted"] = f"{prefix}_to_{postfix}.wav"
        # BLOW checkpoints are written via scipy.io.wavfile; others via soundfile
        # (note the two APIs take (rate, data) vs (data, rate)).
        if Path(root).stem == "BLOW":
            wavfile.write(file_path, sample_rate, waveform)
        else:
            sf.write(file_path, waveform, sample_rate)

    # Fix: close the output file deterministically instead of relying on GC.
    metadata_output_path = output_dir / "metadata.json"
    with metadata_output_path.open("w") as out_file:
        json.dump(metadata, out_file, indent=2)