def __init__(self,
                 root_dir,
                 audio_query="*.h5",
                 mel_query="*.h5",
                 audio_load_fn=lambda x: read_hdf5(x, "wave"),
                 mel_load_fn=lambda x: read_hdf5(x, "feats"),
                 audio_length_threshold=None,
                 mel_length_threshold=None,
                 return_filename=False,
                 ):
        """Initialize dataset.

        Audio and feature files are discovered under ``root_dir``, sorted, and
        paired by position. Optional length thresholds drop pairs whose audio
        or feature array is not strictly longer than the threshold.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find audio files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            audio_load_fn (func): Function to load audio file.
            mel_load_fn (func): Function to load feature file.
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_filename (bool): Whether to return the filename with arrays.

        Raises:
            AssertionError: If no audio file is found, if the numbers of audio
                and feature files differ, or if no file is left after
                filtering.

        """
        # find all of audio and mel files
        audio_files = sorted(find_files(root_dir, audio_query))
        mel_files = sorted(find_files(root_dir, mel_query))

        # assert the number of files
        # NOTE: this must happen before filtering, because the filters below
        # index both lists with the same positions; a count mismatch would
        # otherwise surface as an IndexError or as silently mispaired files.
        assert len(audio_files) != 0, f"Not found any audio files in {root_dir}."
        assert len(audio_files) == len(mel_files), \
            f"Number of audio and mel files are different ({len(audio_files)} vs {len(mel_files)})."

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
            idxs = [idx for idx, length in enumerate(audio_lengths)
                    if length > audio_length_threshold]
            if len(audio_files) != len(idxs):
                logging.info(f"some files are filtered by audio length threshold "
                             f"({len(audio_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]
            mel_files = [mel_files[idx] for idx in idxs]
        if mel_length_threshold is not None:
            mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
            idxs = [idx for idx, length in enumerate(mel_lengths)
                    if length > mel_length_threshold]
            if len(mel_files) != len(idxs):
                logging.info(f"some files are filtered by mel length threshold "
                             f"({len(mel_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]
            mel_files = [mel_files[idx] for idx in idxs]

        # the thresholds may have removed every file
        assert len(audio_files) != 0, \
            f"No audio files in {root_dir} remain after length filtering."

        self.audio_files = audio_files
        self.mel_files = mel_files
        self.audio_load_fn = audio_load_fn
        self.mel_load_fn = mel_load_fn
        self.return_filename = return_filename
Ejemplo n.º 2
0
    def register_stats(self, stats):
        """Register stats for de-normalization as buffer.

        The mean and scale vectors are flattened to 1-D, converted to float
        tensors and registered through ``self.register_buffer`` under the
        names ``"mean"`` and ``"scale"``.

        Args:
            stats (str): Path of statistics file (".npy" or ".h5").

        Raises:
            AssertionError: If ``stats`` ends with neither ".h5" nor ".npy".

        """
        assert stats.endswith(".h5") or stats.endswith(".npy")
        if stats.endswith(".h5"):
            mean = read_hdf5(stats, "mean").reshape(-1)
            scale = read_hdf5(stats, "scale").reshape(-1)
        else:
            # Read the file only once; index 0 holds the mean, index 1 the scale.
            stats_array = np.load(stats)
            mean = stats_array[0].reshape(-1)
            scale = stats_array[1].reshape(-1)
        self.register_buffer("mean", torch.from_numpy(mean).float())
        self.register_buffer("scale", torch.from_numpy(scale).float())
        logging.info("Successfully registered stats as buffer.")
    def __init__(
        self,
        stats,
        audio_list,
        world_list,
        audio_load_fn=sf.read,
        world_load_fn=lambda x: read_hdf5(x, "world"),
        hop_size=110,
        audio_length_threshold=None,
        world_length_threshold=None,
        return_filename=False,
        allow_cache=False,
        mean_path="/world/mean",
        scale_path="/world/scale",
    ):
        """Initialize dataset.

        Args:
            stats (str): Filename of the statistic hdf5 file.
            audio_list (str): Filename of the list of audio files.
            world_list (str): Filename of the list of world feature files.
            audio_load_fn (func): Function to load audio file.
            world_load_fn (func): Function to load world feature file.
            hop_size (int): Hop size of world feature.
            audio_length_threshold (int): Threshold to remove short audio files.
            world_length_threshold (int): Threshold to remove short world feature files.
            return_filename (bool): Whether to return the filename with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.
            mean_path (str): The data path (channel) of the mean in the statistic hdf5 file.
            scale_path (str): The data path (channel) of the scale in the statistic hdf5 file.

        Raises:
            AssertionError: If the filename check fails, if no audio file is
                left after filtering, or if the numbers of audio and world
                files differ.

        """
        # load audio and world file list
        audio_files = read_txt(audio_list)
        world_files = read_txt(world_list)
        # check filename
        assert check_filename(audio_files, world_files)

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = []
            for path in audio_files:
                loaded = audio_load_fn(path)
                # The default loader (sf.read) returns a (data, samplerate)
                # tuple rather than a bare array; keep only the samples.
                if isinstance(loaded, tuple):
                    loaded = loaded[0]
                audio_lengths.append(loaded.shape[0])
            idxs = [
                idx for idx, length in enumerate(audio_lengths)
                if length > audio_length_threshold
            ]
            if len(audio_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]
            world_files = [world_files[idx] for idx in idxs]
        if world_length_threshold is not None:
            world_lengths = [world_load_fn(f).shape[0] for f in world_files]
            idxs = [
                idx for idx, length in enumerate(world_lengths)
                if length > world_length_threshold
            ]
            if len(world_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by world length threshold "
                    f"({len(world_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]
            world_files = [world_files[idx] for idx in idxs]

        # assert the number of files
        assert len(
            audio_files) != 0, f"Not found any audio files in {audio_list}."
        assert len(audio_files) == len(world_files), \
            f"Number of audio and world files are different ({len(audio_files)} vs {len(world_files)})."

        self.audio_files = audio_files
        self.world_files = world_files
        self.audio_load_fn = audio_load_fn
        self.world_load_fn = world_load_fn
        self.return_filename = return_filename
        self.allow_cache = allow_cache
        self.hop_size = hop_size
        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            # one empty placeholder per file
            self.caches += [() for _ in range(len(audio_files))]
        # define feature pre-processing function
        scaler = StandardScaler()
        scaler.mean_ = read_hdf5(stats, mean_path)
        scaler.scale_ = read_hdf5(stats, scale_path)
        # for version 0.23.0, this information is needed
        scaler.n_features_in_ = scaler.mean_.shape[0]
        self.feat_transform = lambda x: scaler.transform(x)
Ejemplo n.º 4
0
def main():
    """Run preprocessing process.

    Parses the command line, loads the YAML config, iterates over the dumped
    feature files under ``--rootdir`` and accumulates their mean and scale
    with an incrementally fitted ``StandardScaler``. The result is written to
    ``stats.h5`` or ``stats.npy`` in ``--dumpdir``, depending on the
    configured format.

    Raises:
        ValueError: If the configured format is neither "hdf5" nor "npy".

    """
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of dumped raw features.")
    parser.add_argument("--rootdir",
                        default=None,
                        type=str,
                        required=True,
                        help="Direcotry including feature files.")
    parser.add_argument("--dumpdir",
                        default=None,
                        type=str,
                        help="Direcotry to save statistics.")
    parser.add_argument("--config",
                        default="hparam.yml",
                        type=str,
                        required=True,
                        help="Yaml format configuration file.")
    parser.add_argument("--verbose",
                        type=int,
                        default=1,
                        help="logging level (higher is more logging)")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning('skip DEBUG/INFO messages')

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass configuration files from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check directory existence
    if args.dumpdir is None:
        # dirname() is "" for a bare relative name such as "dump"; fall back
        # to the current directory so makedirs() is not given an empty path.
        args.dumpdir = os.path.dirname(args.rootdir) or "."
    os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if config["format"] == "hdf5":
        mel_query = "*.h5"
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        mel_query = "*-feats.npy"
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = MelDataset(args.rootdir,
                         mel_query=mel_query,
                         mel_load_fn=mel_load_fn)
    logging.info(f"the number of files = {len(dataset)}.")

    # calculate statistics incrementally, one feature array at a time
    scaler = StandardScaler()
    for mel in tqdm(dataset):
        scaler.partial_fit(mel)

    # save statistics in the same format as the features
    if config["format"] == "hdf5":
        stats_path = os.path.join(args.dumpdir, "stats.h5")
        write_hdf5(stats_path, "mean", scaler.mean_.astype(np.float32))
        write_hdf5(stats_path, "scale", scaler.scale_.astype(np.float32))
    else:
        # index 0 holds the mean, index 1 the scale
        stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
        np.save(os.path.join(args.dumpdir, "stats.npy"),
                stats.astype(np.float32),
                allow_pickle=False)
Ejemplo n.º 5
0
def main():
    """Run preprocessing process.

    Parses the command line, loads the YAML config, restores a
    ``StandardScaler`` from the statistics file and applies it to every
    feature array found under ``--rootdir``. Each audio array and its
    normalized feature array are written to ``--dumpdir`` under their original
    base filenames, in the configured format. Files are processed in parallel
    with ``--n_jobs`` jobs.

    Raises:
        ValueError: If the configured format is neither "hdf5" nor "npy".

    """
    parser = argparse.ArgumentParser(
        description=
        "Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--rootdir",
        type=str,
        required=True,
        help="directory including feature files to be normalized.")
    parser.add_argument("--dumpdir",
                        type=str,
                        required=True,
                        help="directory to dump normalized feature files.")
    parser.add_argument("--stats",
                        type=str,
                        required=True,
                        help="statistics file.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--n_jobs",
                        type=int,
                        default=16,
                        help="number of parallel jobs. (default=16)")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass configuration files from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check directory existence or mkdir new one
    os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = AudioMelDataset(
        root_dir=args.rootdir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        return_filename=True,
    )
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler
    scaler = StandardScaler()
    if config["format"] == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    elif config["format"] == "npy":
        # read the file only once; index 0 holds the mean, index 1 the scale
        stats_array = np.load(args.stats)
        scaler.mean_ = stats_array[0]
        scaler.scale_ = stats_array[1]
    else:
        raise ValueError("support only hdf5 or npy format.")
    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    def _process_single_file(data):
        """Normalize one (audio, feature) pair and write both to dumpdir."""
        # parse inputs for each audio
        audio_name, mel_name, audio, mel = data

        # normalize with the restored statistics
        mel = scaler.transform(mel)

        # save under the original base filenames
        audio_path = os.path.join(args.dumpdir, os.path.basename(audio_name))
        mel_path = os.path.join(args.dumpdir, os.path.basename(mel_name))
        if config["format"] == "hdf5":
            write_hdf5(audio_path, "wave", audio.astype(np.float32))
            write_hdf5(mel_path, "feats", mel.astype(np.float32))
        elif config["format"] == "npy":
            np.save(audio_path, audio.astype(np.float32), allow_pickle=False)
            np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
        else:
            raise ValueError("support only hdf5 or npy format.")

    # process in parallel
    # delayed() captures the function and its arguments for Parallel to run
    Parallel(n_jobs=args.n_jobs, verbose=args.verbose)(
        [delayed(_process_single_file)(data) for data in tqdm(dataset)])
Ejemplo n.º 6
0
 def audio_load_fn(x):
     """Return what ``read_hdf5`` yields for the "wave" entry of file ``x``."""
     waveform = read_hdf5(x, "wave")
     return waveform
Ejemplo n.º 7
0
def main():
    """Run decoding process.

    Parses the command line, loads the config (by default ``config.yml`` next
    to the checkpoint), restores the Parallel WaveGAN generator and generates
    one waveform per feature array. Features come either from ``--dumpdir`` or
    from a Kaldi-style ``--scp`` file; exactly one of the two must be given.
    Each waveform is written to ``--outdir`` as ``<utt_id>_gen.wav`` and the
    average real-time factor is logged at the end.

    Raises:
        ValueError: If both or neither of ``--scp`` and ``--dumpdir`` are
            given, or if the configured format is neither "hdf5" nor "npy".

    """
    parser = argparse.ArgumentParser(
        description="Decode dumped features with trained Parallel WaveGAN Generator.")
    parser.add_argument("--scp", default=None, type=str,
                        help="Kaldi-style feats.scp file.")
    parser.add_argument("--dumpdir", default=None, type=str,
                        help="Directory including feature files.")
    parser.add_argument("--outdir", default=None, type=str, required=True,
                        help="Direcotry to save generated speech.")
    parser.add_argument("--checkpoint", default=None, type=str, required=True,
                        help="Checkpoint file.")
    parser.add_argument("--config", default=None, type=str,
                        help="Yaml format configuration file.")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level (higher is more logging)")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("skip DEBUG/INFO messages")

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass configuration files from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments: exactly one of scp / dumpdir must be given
    if (args.scp is None) == (args.dumpdir is None):
        raise ValueError("Please specify either dumpdir or scp.")

    # get dataset
    if args.scp is None:
        if config["format"] == "hdf5":
            mel_query = "*.h5"
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            mel_query = "*-feats.npy"
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
        dataset = MelDataset(
            args.dumpdir,
            mel_query=mel_query,
            mel_load_fn=mel_load_fn,
            return_filename=True)
        logging.info(f"the number of features to be decoded = {len(dataset)}.")
    else:
        dataset = kaldiio.ReadHelper(f"scp:{args.scp}")
        logging.info(f"the feature loaded from {args.scp}.")

    # setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ParallelWaveGANGenerator(**config["generator_params"])
    model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["model"]["generator"])
    model.remove_weight_norm()
    model = model.eval().to(device)
    logging.info(f"loaded model parameters from {args.checkpoint}.")

    # start generation
    pad_size = (config["generator_params"]["aux_context_window"],
                config["generator_params"]["aux_context_window"])
    total_rtf = 0.0
    idx = 0  # stays 0 when the dataset yields nothing
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for idx, (feat_path, c) in enumerate(pbar, 1):
            # generate each utterance
            # the noise length is taken from the unpadded feature length
            z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device)
            # pad the first axis on both sides by the context window
            c = np.pad(c, (pad_size, (0, 0)), "edge")
            c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device)
            start = time.time()
            y = model(z, c).view(-1).cpu().numpy()
            # real-time factor: elapsed time over generated audio duration
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            # save as PCM 16 bit wav file
            utt_id = os.path.splitext(os.path.basename(feat_path))[0]
            sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"),
                     y, config["sampling_rate"], "PCM_16")

    # report average RTF
    if idx == 0:
        # nothing was decoded; avoid dividing by zero below
        logging.warning("no utterances were decoded.")
        return
    logging.info(f"finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).")
Ejemplo n.º 8
0
    def _wav_to_melgan_spec(self, wav, sample_rate, introduce_noise = False, wav_path = None):
        """Convert wav file to a mel spectrogram using the methods used by the melgan model (e.g. 24khz when using the default dict),
        this is different from the normal AutoVC mel-spectrograms conversion methods and would thus have different results.

        This method should probably be avoided when calculating speech embeddings, as the speaker encoder is trained on 16khz data with the normal spectrogram format

        The steps are: optional silence trimming, resampling to the melgan
        sampling rate if needed, log-mel filterbank extraction, and
        normalization with the mean/scale stored in ``self.melgan_stats_path``.

        Args:
            wav (numpy array): audio data either 1-d (mono) or 2-d (stereo)
            sample_rate (int): the sampling rate of the .wav (sf.read[1])
            introduce_noise (bool): Not implemented; a truthy value logs an error and exits.
            wav_path (str): Path to original wav file (not used by this method)
            note that these two variables can be loaded using:
                wavfile, sample_rate = sf.read(os.path.join(input_dir, speaker, fileName))

        Returns:
            np.array: Mel spectrogram (converted using melgan spec)

        Raises:
            SystemExit: If ``introduce_noise`` is truthy.
            ValueError: If the melgan config format is not "hdf5".
        """
        print("Converting using wav to melgan!")
        if self.melgan_config["trim_silence"]:
            wav, _ = librosa.effects.trim(wav,
                                            top_db=self.melgan_config["trim_threshold_in_db"],
                                            frame_length=self.melgan_config["trim_frame_size"],
                                            hop_length=self.melgan_config["trim_hop_size"])

        if introduce_noise:
            log.error(f"Introduce_noise is set tot {introduce_noise}, however, this is not implemented. Exiting...")
            # Exit with a non-zero status: this is an error, not a success.
            raise SystemExit(1)

        if sample_rate != self.melgan_config["sampling_rate"]: #Resampling
            # Keyword arguments work with both old and new librosa releases;
            # newer releases reject positional sample rates.
            wav = librosa.resample(wav,
                                   orig_sr=sample_rate,
                                   target_sr=self.melgan_config["sampling_rate"])
            print(f"Wav file with sr {sample_rate} != {self.melgan_config['sampling_rate']}, Now resampling to {self.melgan_config['sampling_rate']}")

        mel = self.logmelfilterbank( #Create mel spectrogram using the melGAN settings
                        wav,
                        sampling_rate=self.melgan_config["sampling_rate"],
                        hop_size=self.melgan_config["hop_size"],
                        fft_size=self.melgan_config["fft_size"],
                        win_length=self.melgan_config["win_length"],
                        window=self.melgan_config["window"],
                        num_mels=self.melgan_config["num_mels"],
                        fmin=self.melgan_config["fmin"],
                        fmax=self.melgan_config["fmax"])

        # make sure the audio length and feature length are matched
        # (the adjusted wav is only used for this check; it is not returned)
        wav = np.pad(wav, (0, self.melgan_config["fft_size"]), mode="reflect")
        wav = wav[:len(mel) * self.melgan_config["hop_size"]]
        assert len(mel) * self.melgan_config["hop_size"] == len(wav)

        #================================================Normalization=========================================================
        # restore scaler
        scaler = StandardScaler()
        if self.melgan_config["format"] == "hdf5":
            scaler.mean_ = read_hdf5(self.melgan_stats_path, "mean")
            scaler.scale_ = read_hdf5(self.melgan_stats_path, "scale")
        else:
            raise ValueError("support only hdf5 (and normally npy - but not now) format.... cannot load in scaler mean/scale, exiting")

        # from version 0.23.0, this information is needed
        scaler.n_features_in_ = scaler.mean_.shape[0]
        mel = scaler.transform(mel)
        return mel
Ejemplo n.º 9
0
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(description=(
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    ))
    parser.add_argument(
        "--train-wav-scp",
        default=None,
        type=str,
        help=("kaldi-style wav.scp file for training. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--train-feats-scp",
        default=None,
        type=str,
        help=("kaldi-style feats.scp file for training. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--train-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for training.",
    )
    parser.add_argument(
        "--train-dumpdir",
        default=None,
        type=str,
        help=("directory including training data. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--dev-wav-scp",
        default=None,
        type=str,
        help=("kaldi-style wav.scp file for validation. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--dev-feats-scp",
        default=None,
        type=str,
        help=("kaldi-style feats.scp file for vaidation. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--dev-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for validation.",
    )
    parser.add_argument(
        "--dev-dumpdir",
        default=None,
        type=str,
        help=("directory including development data. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save checkpoints.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.",
    )
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if (args.train_feats_scp is not None and args.train_dumpdir
            is not None) or (args.train_feats_scp is None
                             and args.train_dumpdir is None):
        raise ValueError(
            "Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or (
            args.dev_feats_scp is None and args.dev_dumpdir is None):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"] + 2 * config["generator_params"].get(
                "aux_context_window", 0)
    else:
        mel_length_threshold = None
    if args.train_wav_scp is None or args.dev_wav_scp is None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get("generator_type",
                                   "ParallelWaveGANGenerator")
        in ["ParallelWaveGANGenerator"],
    )
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler

        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        sampler["dev"] = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train":
        DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["train"],
            pin_memory=config["pin_memory"],
        ),
        "dev":
        DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["dev"],
            pin_memory=config["pin_memory"],
        ),
    }

    # define models
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"], ).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"], ).to(device),
    }

    # define criterions
    criterion = {
        "gen_adv":
        GeneratorAdversarialLoss(
            # keep compatibility
            **config.get("generator_adv_loss_params", {})).to(device),
        "dis_adv":
        DiscriminatorAdversarialLoss(
            # keep compatibility
            **config.get("discriminator_adv_loss_params", {})).to(device),
    }
    if config.get("use_stft_loss", True):  # keep compatibility
        config["use_stft_loss"] = True
        criterion["stft"] = MultiResolutionSTFTLoss(
            **config["stft_loss_params"], ).to(device)
    if config.get("use_subband_stft_loss", False):  # keep compatibility
        assert config["generator_params"]["out_channels"] > 1
        criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **config["subband_stft_loss_params"], ).to(device)
    else:
        config["use_subband_stft_loss"] = False
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["feat_match"] = FeatureMatchLoss(
            # keep compatibility
            **config.get("feat_match_loss_params", {}), ).to(device)
    else:
        config["use_feat_match_loss"] = False
    if config.get("use_mel_loss", False):  # keep compatibility
        if config.get("mel_loss_params", None) is None:
            criterion["mel"] = MelSpectrogramLoss(
                fs=config["sampling_rate"],
                fft_size=config["fft_size"],
                hop_size=config["hop_size"],
                win_length=config["win_length"],
                window=config["window"],
                num_mels=config["num_mels"],
                fmin=config["fmin"],
                fmax=config["fmax"],
            ).to(device)
        else:
            criterion["mel"] = MelSpectrogramLoss(**config["mel_loss_params"],
                                                  ).to(device)
    else:
        config["use_mel_loss"] = False

    # define special module for subband processing
    if config["generator_params"]["out_channels"] > 1:
        criterion["pqmf"] = PQMF(
            subbands=config["generator_params"]["out_channels"],
            # keep compatibility
            **config.get("pqmf_params", {}),
        ).to(device)

    # define optimizers and schedulers
    generator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator":
        generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator":
        discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator":
        generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator":
        discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])

    # show settings
    logging.info(model["generator"])
    logging.info(model["discriminator"])
    logging.info(optimizer["generator"])
    logging.info(optimizer["discriminator"])
    logging.info(scheduler["generator"])
    logging.info(scheduler["discriminator"])
    for criterion_ in criterion.values():
        logging.info(criterion_)

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Ejemplo n.º 10
0
 def mel_load_fn(x):
     """Load the ``"feats"`` dataset from the HDF5 file ``x`` via ``read_hdf5``."""
     feats_key = "feats"
     return read_hdf5(x, feats_key)  # NOQA
Ejemplo n.º 11
0
#                                sampling_rate=sampling_rate,
#                                hop_size=hop_size,
#                                fft_size=config["fft_size"],
#                                win_length=config["win_length"],
#                                window=config["window"],
#                                num_mels=config["num_mels"],
#                                fmin=config["fmin"],
#                                fmax=config["fmax"])
# eps=1e-10

# mel_basis = librosa.filters.mel(sr, n_fft, config["num_mels"], config["fmin"], config["fmax"])

# mel_out = np.log10(np.maximum(eps, np.dot(lin_out, mel_basis.T)))

# Normalize melgan mel spect
# Restore the normalization statistics from `stats` in the format named by
# the config, then standardize the mel features before vocoding.
scaler = StandardScaler()
if config["format"] == "hdf5":
    scaler.mean_ = read_hdf5(stats, "mean")
    scaler.scale_ = read_hdf5(stats, "scale")
elif config["format"] == "npy":
    # Read the file once; index 0 is used as the mean, index 1 as the scale.
    stats_array = np.load(stats)
    scaler.mean_ = stats_array[0]
    scaler.scale_ = stats_array[1]
else:
    # Fail here with a clear message instead of later inside transform().
    raise ValueError("support only hdf5 or npy format.")
# from scikit-learn 0.23.0, a manually restored scaler needs this attribute
scaler.n_features_in_ = scaler.mean_.shape[0]

mel_out = scaler.transform(mel_out)

# Process by melgan
audio = melgan(melgan_model, device, mel_out)

# audio = librosa.griffinlim(lin_out, win_length=win_length, hop_length=hop_length)

# NOTE(review): the output sample rate is hard-coded to 16000 -- confirm it
# matches the sampling rate the vocoder was trained with.
sf.write("output/test.wav", audio, 16000)
Ejemplo n.º 12
0
def main():
    """Compute mean/scale statistics of dumped features and save them.

    Features are read either from ``--rootdir`` (hdf5 or npy files, chosen by
    ``config["format"]``) or from a kaldi-style ``--feats-scp`` file. A
    ``StandardScaler`` is fitted incrementally over the dataset selected by
    ``--ftype`` and its ``mean_`` / ``scale_`` are written to
    ``<dumpdir>/<ftype>_mean_std.h5`` (hdf5) or ``<dumpdir>/<ftype>_mean_std.npy``
    (any other format).

    Raises:
        ValueError: If both or neither of ``--rootdir`` / ``--feats-scp`` are
            given, if ``--dumpdir`` is omitted together with ``--rootdir``,
            if the format is neither hdf5 nor npy when reading from
            ``--rootdir``, or if ``--ftype`` has no dataset for the format.
    """
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of dumped raw features "
                    "(See detail in parallel_wavegan/bin/compute_statistics.py).")
    parser.add_argument("--feats-scp", "--scp", default=None, type=str,
                        help="kaldi-style feats.scp file. "
                             "you need to specify either feats-scp or rootdir.")
    # Not required: exactly one of --rootdir / --feats-scp is enforced below.
    # Marking it required made the --feats-scp path impossible to use.
    parser.add_argument("--rootdir", default=None, type=str,
                        help="directory including feature files. "
                             "you need to specify either feats-scp or rootdir.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--dumpdir", default=None, type=str,
                        help="directory to save statistics. if not provided, "
                             "stats will be saved in the above root directory. (default=None)")
    parser.add_argument("--ftype", default='mel', type=str,
                        help="feature type")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # use this with trusted configuration files.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments: exactly one feature source must be given
    if (args.feats_scp is None) == (args.rootdir is None):
        raise ValueError("Please specify either --rootdir or --feats-scp.")

    # check directory existence
    if args.dumpdir is None:
        if args.rootdir is None:
            # There is no root directory to derive a default from.
            raise ValueError("Please specify --dumpdir when using --feats-scp.")
        args.dumpdir = os.path.dirname(args.rootdir)
    if args.dumpdir:
        # An empty dirname means the current directory; nothing to create.
        os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset(s), keyed by the feature type they provide
    datasets = {}
    if args.feats_scp is None:
        if config["format"] == "hdf5":
            # Only mel features have an hdf5 loader here.
            datasets["mel"] = MelDataset(
                args.rootdir,
                mel_query="*.h5",
                mel_load_fn=lambda x: read_hdf5(x, "feats"))  # NOQA
        elif config["format"] == "npy":
            datasets["mel"] = MelDataset(
                args.rootdir,
                mel_query="*.mel.npy",
                mel_load_fn=np.load)
            datasets["spec"] = SpcDataset(
                args.rootdir,
                spc_query="*.spec.npy",
                spc_load_fn=np.load)
        else:
            raise ValueError("support only hdf5 or npy format.")
    else:
        datasets[args.ftype] = MelSCPDataset(args.feats_scp)

    # Keep the historical "spc" wording in the log message for spec features.
    log_names = {"spec": "spc"}
    for ftype, dataset in datasets.items():
        logging.info(f"The number of files in {log_names.get(ftype, ftype)} "
                     f"dataset = {len(dataset)}.")

    if args.ftype not in datasets:
        raise ValueError(
            f"feature type '{args.ftype}' is not available for format "
            f"'{config['format']}' (available: {sorted(datasets)}).")

    # calculate statistics
    scaler = StandardScaler()
    for feats in tqdm(datasets[args.ftype]):
        scaler.partial_fit(feats)

    if config["format"] == "hdf5":
        stats_path = os.path.join(args.dumpdir, "{}_mean_std.h5".format(args.ftype))
        write_hdf5(stats_path, "mean", scaler.mean_.astype(np.float32))
        write_hdf5(stats_path, "scale", scaler.scale_.astype(np.float32))
    else:
        # Row 0 holds the mean and row 1 the scale.
        stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
        np.save(os.path.join(args.dumpdir, "{}_mean_std.npy".format(args.ftype)),
                stats.astype(np.float32), allow_pickle=False)
Ejemplo n.º 13
0
# assert len(mel) * config["hop_size"] == len(audio) #TODO: not sure

# apply global gain
if config["global_gain_scale"] > 0.0:
    audio *= config["global_gain_scale"]
if np.abs(audio).max() >= 1.0:
    # logging.warn is a deprecated alias (removed in Python 3.13).
    logging.warning(
        "Loaded audio file causes clipping. "
        "it is better to re-consider global gain scale. Now exiting.")
    # NOTE(review): this exits with status 0 even though it is an error path
    # -- confirm whether callers rely on that.
    exit(0)

#================================================Normalization=========================================================
# restore scaler
# NOTE(review): Windows-style relative path; it only resolves when run from
# the expected working directory on Windows -- confirm.
melgan_stats_path = ".\\vocoders\\melgan\\stats.h5"
scaler = StandardScaler()
if config["format"] == "hdf5":
    scaler.mean_ = read_hdf5(melgan_stats_path, "mean")
    scaler.scale_ = read_hdf5(melgan_stats_path, "scale")
# elif config["format"] == "npy":
# scaler.mean_ = np.load(args.stats)[0]
# scaler.scale_ = np.load(args.stats)[1]
else:
    raise ValueError(
        "support only hdf5 (and normally npy - but not now) format.")
# from version 0.23.0, this information is needed
scaler.n_features_in_ = scaler.mean_.shape[0]
mel = scaler.transform(mel)

# plt.imshow(mel)
# plt.show()

#==============================================Put it through network==================================================
Ejemplo n.º 14
0
def main():
    """Run the Parallel WaveGAN training process.

    Parses the command-line arguments, selects the device (initializing
    ``torch.distributed`` when CUDA is available and ``WORLD_SIZE`` > 1),
    merges the arguments into the YAML config and saves it to
    ``<outdir>/config.yml``, builds the datasets, data loaders, models,
    criterions, optimizers and schedulers, optionally restores a checkpoint,
    and runs the ``Trainer``. On ``KeyboardInterrupt`` a checkpoint is saved.

    Raises:
        ValueError: If both or neither of the dumpdir / feats-scp sources are
            given for the training or the development set, or if the data
            format is neither hdf5 nor npy.
        ImportError: If distributed training is requested but apex is not
            installed.
    """
    parser = argparse.ArgumentParser(description="Train Parallel WaveGAN.")

    # Data sources: either kaldi-style scp files or a dump directory.
    parser.add_argument("--train-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for training. "
                             "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--train-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for training. "
                             "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--train-segments", default=None, type=str,
                        help="kaldi-style segments file for training.")
    parser.add_argument("--train-dumpdir", default=None, type=str,
                        help="directory including training data. "
                             "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--dev-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for validation. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--dev-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for vaidation. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--dev-segments", default=None, type=str,
                        help="kaldi-style segments file for validation.")
    parser.add_argument("--dev-dumpdir", default=None, type=str,
                        help="directory including development data. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    # nargs="?" lets the option be given without a value.
    parser.add_argument("--pretrain", default="", type=str, nargs="?",
                        help="checkpoint file path to load pretrained params. (default=\"\")")
    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--rank", "--local_rank", default=0, type=int,
                        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    # select device; distributed mode is only enabled on CUDA machines
    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # cudnn benchmark mode is effective when using fixed size inputs
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # suppress logging for distributed training: only rank 0 writes to stdout
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, stream=sys.stdout, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, stream=sys.stdout, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, stream=sys.stdout, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    # exist_ok avoids a race when several ranks create the directory at once.
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments: exactly one of dumpdir / feats-scp per data split
    if ((args.train_feats_scp is not None and args.train_dumpdir is not None)
            or (args.train_feats_scp is None and args.train_dumpdir is None)):
        raise ValueError("Please specify either --train-dumpdir or --train-*-scp.")
    if ((args.dev_feats_scp is not None and args.dev_dumpdir is not None)
            or (args.dev_feats_scp is None and args.dev_dumpdir is None)):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # use this with trusted configuration files.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    # command-line arguments are merged into (and override) the config
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # batch frames plus the auxiliary context on both sides
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if args.train_wav_scp is None or args.dev_wav_scp is None:
        # at least one split is read from a dump directory, so file queries
        # and loaders are needed
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")

    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    # With a sampler the DataLoader itself must not shuffle.
    loader_shuffle = not args.distributed
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=loader_shuffle,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=loader_shuffle,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(
            **config["generator_params"]).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(
            **config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError("apex is not installed. please check https://github.com/NVIDIA/apex.")
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop; save a checkpoint if the user interrupts it
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Ejemplo n.º 15
0
def main():
    """Compute and save the mean and scale of dumped features.

    Parses the command line, builds a feature dataset either from
    ``--rootdir`` (``MelDataset``) or from ``--feats-scp``
    (``MelSCPDataset``), feeds every item to ``StandardScaler.partial_fit``
    and writes the statistics into ``--dumpdir``:

    * ``stats.h5`` with the datasets ``mean`` and ``scale`` when the
      configured format is ``hdf5``;
    * ``stats.npy`` holding mean and scale stacked along axis 0 otherwise.

    Raises:
        ValueError: If both or neither of ``--rootdir`` / ``--feats-scp`` are
            given, or if ``--rootdir`` is used with a format other than
            ``hdf5`` / ``npy``.

    """
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of dumped raw features "
        "(See detail in parallel_wavegan/bin/compute_statistics.py).")
    parser.add_argument(
        "--feats-scp",
        "--scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file. "
        "you need to specify either feats-scp or rootdir.",
    )
    parser.add_argument(
        "--rootdir",
        type=str,
        help="directory including feature files. "
        "you need to specify either feats-scp or rootdir.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    # NOTE(review): the option is required, so the "if not provided" part of
    # the help text can never apply; the text is kept as-is for now.
    parser.add_argument(
        "--dumpdir",
        default=None,
        type=str,
        required=True,
        help="directory to save statistics. if not provided, "
        "stats will be saved in the above root directory. (default=None)",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if args.verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")

    # load config; command-line values override entries from the yaml file
    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so the
    # config file must come from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments: exactly one of --rootdir / --feats-scp must be given
    if (args.feats_scp is None) == (args.rootdir is None):
        raise ValueError("Please specify either --rootdir or --feats-scp.")

    # create the output directory; exist_ok avoids the check-then-create race
    os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if args.feats_scp is None:
        if config["format"] == "hdf5":
            mel_query = "*.h5"
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            mel_query = "*-feats.npy"
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
        dataset = MelDataset(args.rootdir,
                             mel_query=mel_query,
                             mel_load_fn=mel_load_fn)
    else:
        dataset = MelSCPDataset(args.feats_scp)
    logging.info(f"The number of files = {len(dataset)}.")

    # calculate statistics incrementally, one dataset item at a time
    scaler = StandardScaler()
    for mel in tqdm(dataset):
        scaler.partial_fit(mel)

    # save statistics as float32
    if config["format"] == "hdf5":
        stats_path = os.path.join(args.dumpdir, "stats.h5")
        write_hdf5(stats_path, "mean", scaler.mean_.astype(np.float32))
        write_hdf5(stats_path, "scale", scaler.scale_.astype(np.float32))
    else:
        # row 0 = mean, row 1 = scale
        stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
        np.save(
            os.path.join(args.dumpdir, "stats.npy"),
            stats.astype(np.float32),
            allow_pickle=False,
        )
Ejemplo n.º 16
0
    def __init__(
        self,
        root_dir,
        audio_query="*.h5",
        mel_query="*.h5",
        audio_load_fn=lambda x: read_hdf5(x, "wave"),
        mel_load_fn=lambda x: read_hdf5(x, "feats"),
        audio_length_threshold=None,
        mel_length_threshold=None,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find audio files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            audio_load_fn (func): Function to load audio file.
            mel_load_fn (func): Function to load feature file.
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_utt_id (bool): Whether to return the utterance id with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.

        Raises:
            AssertionError: If no audio file is found, if the numbers of audio
                and mel files differ, or if no file survives the length
                filters.

        """
        # find all of audio and mel files
        audio_files = sorted(find_files(root_dir, audio_query))
        mel_files = sorted(find_files(root_dir, mel_query))

        # Audio and mel files are paired by position in the sorted lists, so
        # the counts have to agree BEFORE filtering. Checking only afterwards
        # would let a mismatch either crash with an IndexError or be silently
        # truncated into wrongly paired files.
        assert len(audio_files) != 0, \
            f"Not found any audio files in {root_dir}."
        assert len(audio_files) == len(mel_files), \
            f"Number of audio and mel files are different ({len(audio_files)} vs {len(mel_files)})."

        # filter by threshold (an entry is kept only if strictly longer)
        if audio_length_threshold is not None:
            keep = [
                audio_load_fn(f).shape[0] > audio_length_threshold
                for f in audio_files
            ]
            if not all(keep):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(keep)} -> {sum(keep)}).")
            audio_files = [f for f, k in zip(audio_files, keep) if k]
            mel_files = [f for f, k in zip(mel_files, keep) if k]
        if mel_length_threshold is not None:
            keep = [
                mel_load_fn(f).shape[0] > mel_length_threshold
                for f in mel_files
            ]
            if not all(keep):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(keep)} -> {sum(keep)}).")
            audio_files = [f for f, k in zip(audio_files, keep) if k]
            mel_files = [f for f, k in zip(mel_files, keep) if k]

        # the filters may have removed every file
        assert len(audio_files) != 0, \
            f"Not found any audio files in {root_dir}."

        self.audio_files = audio_files
        self.audio_load_fn = audio_load_fn
        self.mel_load_fn = mel_load_fn
        self.mel_files = mel_files
        # utterance ids are derived from the audio file names
        if ".npy" in audio_query:
            self.utt_ids = [
                os.path.basename(f).replace("-wave.npy", "")
                for f in audio_files
            ]
        else:
            self.utt_ids = [
                os.path.splitext(os.path.basename(f))[0] for f in audio_files
            ]
        self.return_utt_id = return_utt_id
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            # one empty placeholder per file
            self.caches += [() for _ in range(len(audio_files))]
Ejemplo n.º 17
0
def main():
    """Normalize dumped features with precomputed statistics.

    Parses the command line, builds a dataset from ``--rootdir`` or from
    kaldi-style scp files, restores a ``StandardScaler`` from ``--stats``,
    applies ``scaler.transform`` to every feature item and writes the result
    (plus the audio, unless ``--skip-wav-copy`` is given) into ``--dumpdir``
    in the configured format (``hdf5`` or ``npy``).

    Raises:
        ValueError: If both or neither of ``--rootdir`` / ``--feats-scp`` are
            given, or if the configured format is neither ``hdf5`` nor
            ``npy``.

    """
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        help="directory including feature files to be normalized. "
        "you need to specify either *-scp or rootdir.",
    )
    parser.add_argument(
        "--wav-scp",
        default=None,
        type=str,
        help="kaldi-style wav.scp file. "
        "you need to specify either *-scp or rootdir.",
    )
    parser.add_argument(
        "--feats-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file. "
        "you need to specify either *-scp or rootdir.",
    )
    parser.add_argument(
        "--segments",
        default=None,
        type=str,
        help="kaldi-style segments file.",
    )
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump normalized feature files.",
    )
    parser.add_argument(
        "--stats",
        type=str,
        required=True,
        help="statistics file.",
    )
    parser.add_argument(
        "--skip-wav-copy",
        default=False,
        action="store_true",
        help="whether to skip the copy of wav files.",
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if args.verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")

    # load config; command-line values override entries from the yaml file
    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so the
    # config file must come from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments: exactly one of --rootdir / --feats-scp must be given
    if (args.feats_scp is None) == (args.rootdir is None):
        raise ValueError("Please specify either --rootdir or --feats-scp.")

    # create the output directory; exist_ok avoids the check-then-create race
    os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if args.rootdir is not None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
        if not args.skip_wav_copy:
            dataset = AudioMelDataset(
                root_dir=args.rootdir,
                audio_query=audio_query,
                mel_query=mel_query,
                audio_load_fn=audio_load_fn,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
        else:
            dataset = MelDataset(
                root_dir=args.rootdir,
                mel_query=mel_query,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
    else:
        if not args.skip_wav_copy:
            dataset = AudioMelSCPDataset(
                wav_scp=args.wav_scp,
                feats_scp=args.feats_scp,
                segments=args.segments,
                return_utt_id=True,
            )
        else:
            dataset = MelSCPDataset(
                feats_scp=args.feats_scp,
                return_utt_id=True,
            )
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler; this also validates the format for the save step below
    fmt = config["format"]
    scaler = StandardScaler()
    if fmt == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    elif fmt == "npy":
        # load the file once: row 0 = mean, row 1 = scale
        stats = np.load(args.stats)
        scaler.mean_ = stats[0]
        scaler.scale_ = stats[1]
    else:
        raise ValueError("support only hdf5 or npy format.")
    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    # process each file
    for items in tqdm(dataset):
        if not args.skip_wav_copy:
            utt_id, audio, mel = items
        else:
            utt_id, mel = items

        # normalize
        mel = scaler.transform(mel)

        # save (fmt is known to be "hdf5" or "npy" at this point)
        if fmt == "hdf5":
            out_path = os.path.join(args.dumpdir, f"{utt_id}.h5")
            write_hdf5(out_path, "feats", mel.astype(np.float32))
            if not args.skip_wav_copy:
                write_hdf5(out_path, "wave", audio.astype(np.float32))
        else:
            np.save(
                os.path.join(args.dumpdir, f"{utt_id}-feats.npy"),
                mel.astype(np.float32),
                allow_pickle=False,
            )
            if not args.skip_wav_copy:
                np.save(
                    os.path.join(args.dumpdir, f"{utt_id}-wave.npy"),
                    audio.astype(np.float32),
                    allow_pickle=False,
                )
Ejemplo n.º 18
0
# Trim the waveform to exactly len(mel) * hop_size samples; the assert fires
# when the audio is shorter than that.
audio = audio[:len(mel) * config["hop_size"]]
assert len(mel) * config["hop_size"] == len(audio)

# apply global gain (skipped when the configured scale is not positive)
if config["global_gain_scale"] > 0.0:
    audio *= config["global_gain_scale"]
if np.abs(audio).max() >= 1.0:
    # logging.warn is a deprecated alias of logging.warning (removed in
    # Python 3.13), so the supported name is used here.
    logging.warning("Loaded audio file causes clipping. "
                    "it is better to re-consider global gain scale. Now exiting.")
    # NOTE(review): exits with status 0 although this is an error path --
    # confirm whether a non-zero status is wanted.
    exit(0)

#================================================Normalization=========================================================
# restore scaler from the hard-coded statistics file
scaler = StandardScaler()
if config["format"] == "hdf5":
    scaler.mean_ = read_hdf5("./parallel_wavegan/stats.h5", "mean")
    scaler.scale_ = read_hdf5("./parallel_wavegan/stats.h5", "scale")
# elif config["format"] == "npy":
    # scaler.mean_ = np.load(args.stats)[0]
    # scaler.scale_ = np.load(args.stats)[1]
else:
    raise ValueError("support only hdf5 (and normally npy - but not now) format.")
# from version 0.23.0, this information is needed
scaler.n_features_in_ = scaler.mean_.shape[0]
mel = scaler.transform(mel)

# plt.imshow(mel)
# plt.show()

#==============================================Put it through network==================================================
# converter.output_to_wav([[mel]])
Ejemplo n.º 19
0
def make_wav(args):
    """Run decoding process.

    Builds a feature dataset from ``args.dumpdir`` or ``args.feats_scp``,
    loads the generator weights from ``args.checkpoint``, runs the generator
    on every item and writes ``<utt_id>_gen.wav`` (PCM 16 bit) into
    ``args.outdir``.

    Args:
        args: Object supporting ``vars()`` with the attributes ``verbose``,
            ``outdir``, ``config``, ``checkpoint``, ``feats_scp`` and
            ``dumpdir``. When ``args.config`` is None it is overwritten with
            ``config.yml`` from the checkpoint directory.

    Raises:
        ValueError: If both or neither of ``dumpdir`` / ``feats_scp`` are
            given, or if ``dumpdir`` is used with a format other than
            ``hdf5`` / ``npy``.

    """
    import logging
    import os
    import time

    import numpy as np
    import soundfile as sf
    import torch
    import yaml

    from tqdm import tqdm

    import parallel_wavegan.models

    from parallel_wavegan.datasets import MelDataset
    from parallel_wavegan.datasets import MelSCPDataset
    from parallel_wavegan.utils import read_hdf5

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if args.verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")

    # create the output directory; exist_ok avoids the check-then-create race
    os.makedirs(args.outdir, exist_ok=True)

    # load config (defaults to config.yml next to the checkpoint)
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so the
    # config file must come from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments: exactly one of dumpdir / feats_scp must be given
    if (args.feats_scp is None) == (args.dumpdir is None):
        raise ValueError("Please specify either --dumpdir or --feats-scp.")

    # get dataset
    if args.dumpdir is not None:
        if config["format"] == "hdf5":
            mel_query = "*.h5"
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            mel_query = "*-feats.npy"
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
        dataset = MelDataset(
            args.dumpdir,
            mel_query=mel_query,
            mel_load_fn=mel_load_fn,
            return_utt_id=True,
        )
    else:
        dataset = MelSCPDataset(
            feats_scp=args.feats_scp,
            return_utt_id=True,
        )

    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator"))
    model = model_class(**config["generator_params"])
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    model.load_state_dict(
        torch.load(args.checkpoint, map_location="cpu")["model"]["generator"])
    logging.info(f"Loaded model parameters from {args.checkpoint}.")
    model.remove_weight_norm()
    model = model.eval().to(device)
    # every generator except MelGANGenerator gets a noise tensor as first input
    use_noise_input = not isinstance(model,
                                     parallel_wavegan.models.MelGANGenerator)
    pad_fn = torch.nn.ReplicationPad1d(config["generator_params"].get(
        "aux_context_window", 0))

    # start generation
    total_rtf = 0.0
    num_utts = 0  # stays 0 when the dataset is empty
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for num_utts, (utt_id, c) in enumerate(pbar, 1):
            x = ()
            if use_noise_input:
                # the third axis of c is used as the frame count, i.e. c is
                # indexed as a 3-D array here
                # NOTE(review): presumably (1, channels, frames) -- confirm
                z = torch.randn(1, 1,
                                np.shape(c)[2] * config["hop_size"]).to(device)
                x += (z, )
            logging.debug(f"feature shape: {c.shape}")
            # convert to float32 on the selected device; works on CPU-only
            # hosts too, unlike casting to torch.cuda.FloatTensor
            c = torch.from_numpy(c).float().to(device)
            c = pad_fn(c)
            x += (c, )

            # generate
            start = time.time()
            y = model(*x).view(-1).cpu().numpy()
            # real-time factor = processing time / generated audio duration
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            # save as PCM 16 bit wav file
            sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"), y,
                     config["sampling_rate"], "PCM_16")

    # report average RTF (nothing to average when no item was decoded)
    if num_utts == 0:
        logging.warning("No utterances were decoded.")
        return
    logging.info(
        f"Finished generation of {num_utts} utterances (RTF = {total_rtf / num_utts:.03f})."
    )
Ejemplo n.º 20
0
def main():
    """Run training process.

    Parses the command line, merges it into the yaml configuration (a copy
    of the merged configuration is written to ``<outdir>/config.yml``),
    builds the train/dev datasets and data loaders, the generator and
    discriminator with their criteria, optimizers and schedulers, optionally
    resumes from ``--resume`` and runs the ``Trainer``. A checkpoint is
    saved when ``trainer.run()`` ends, whether it returned or raised.

    Raises:
        ValueError: If the configured format is neither ``hdf5`` nor ``npy``.

    """
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir",
                        type=str,
                        required=True,
                        help="directory including trainning data.")
    parser.add_argument("--dev-dumpdir",
                        type=str,
                        required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if args.verbose <= 0:
        logging.warning('skip DEBUG/INFO messages')

    # create the output directory; exist_ok avoids the check-then-create race
    os.makedirs(args.outdir, exist_ok=True)

    # load and save config; command-line values override yaml entries
    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so the
    # config file must come from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # frames covering one batch plus the context window on both sides
        mel_length_threshold = (
            config["batch_max_steps"] // config["hop_size"]
            + 2 * config["generator_params"]["aux_context_window"])
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train":
        AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibilty
        ),
        "dev":
        AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibilty
        ),
    }

    # get data loader
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    data_loader = {
        "train":
        DataLoader(dataset=dataset["train"],
                   shuffle=True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   pin_memory=config["pin_memory"]),
        "dev":
        DataLoader(dataset=dataset["dev"],
                   shuffle=True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   pin_memory=config["pin_memory"]),
    }

    # define models and optimizers
    model = {
        "generator":
        ParallelWaveGANGenerator(**config["generator_params"]).to(device),
        "discriminator":
        ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator":
        RAdam(model["generator"].parameters(),
              **config["generator_optimizer_params"]),
        "discriminator":
        RAdam(model["discriminator"].parameters(),
              **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    # With nargs="?", a bare "--resume" yields None, so test truthiness
    # rather than len() to cover both None and the empty default.
    if args.resume:
        trainer.load_checkpoint(args.resume)
        logging.info(f"resumed from {args.resume}.")

    # run training loop; the checkpoint is written even if run() raises
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
Ejemplo n.º 21
0
def main():
    """Run decoding process.

    Parses the command line, builds a feature dataset from ``--dumpdir`` or
    ``--feats-scp``, loads the model from ``--checkpoint`` via
    ``load_model``, calls ``model.inference`` on every item and writes
    ``<utt_id>_gen.wav`` (PCM 16 bit) into ``--outdir``.

    Raises:
        ValueError: If both or neither of ``--dumpdir`` / ``--feats-scp`` are
            given, or if ``--dumpdir`` is used with a format other than
            ``hdf5`` / ``npy``.

    """
    parser = argparse.ArgumentParser(
        description=
        "Decode dumped features with trained Parallel WaveGAN Generator "
        "(See detail in parallel_wavegan/bin/decode.py).")
    parser.add_argument("--feats-scp",
                        "--scp",
                        default=None,
                        type=str,
                        help="kaldi-style feats.scp file. "
                        "you need to specify either feats-scp or dumpdir.")
    parser.add_argument("--dumpdir",
                        default=None,
                        type=str,
                        help="directory including feature files. "
                        "you need to specify either feats-scp or dumpdir.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save generated speech.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if args.verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")

    # create the output directory; exist_ok avoids the check-then-create race
    os.makedirs(args.outdir, exist_ok=True)

    # load config (defaults to config.yml next to the checkpoint)
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so the
    # config file must come from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments: exactly one of --dumpdir / --feats-scp must be given
    if (args.feats_scp is None) == (args.dumpdir is None):
        raise ValueError("Please specify either --dumpdir or --feats-scp.")

    # get dataset
    if args.dumpdir is not None:
        if config["format"] == "hdf5":
            mel_query = "*.h5"
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            mel_query = "*-feats.npy"
            mel_load_fn = np.load
        else:
            raise ValueError("Support only hdf5 or npy format.")
        dataset = MelDataset(
            args.dumpdir,
            mel_query=mel_query,
            mel_load_fn=mel_load_fn,
            return_utt_id=True,
        )
    else:
        dataset = MelSCPDataset(
            feats_scp=args.feats_scp,
            return_utt_id=True,
        )
    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup model
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model = load_model(args.checkpoint, config)
    logging.info(f"Loaded model parameters from {args.checkpoint}.")
    model.remove_weight_norm()
    model = model.eval().to(device)

    # start generation
    total_rtf = 0.0
    num_utts = 0  # stays 0 when the dataset is empty
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for num_utts, (utt_id, c) in enumerate(pbar, 1):
            # generate
            c = torch.tensor(c, dtype=torch.float).to(device)
            start = time.time()
            y = model.inference(c).view(-1)
            # real-time factor = processing time / generated audio duration
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            # save as PCM 16 bit wav file
            sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"),
                     y.cpu().numpy(), config["sampling_rate"], "PCM_16")

    # report average RTF (nothing to average when no item was decoded)
    if num_utts == 0:
        logging.warning("No utterances were decoded.")
        return
    logging.info(
        f"Finished generation of {num_utts} utterances (RTF = {total_rtf / num_utts:.03f})."
    )
Ejemplo n.º 22
0
 def mel_load_fn(x):
     """Read the entry named by ``feat_query`` from ``x`` via ``read_hdf5``."""
     feats = read_hdf5(x, feat_query)  # NOQA
     return feats
Ejemplo n.º 23
0
def main():
    """Run preprocessing process.

    Normalize dumped features with pre-computed statistics. The normalized
    features are written into ``--dumpdir``. When ``--dumpdir ""`` is given
    together with the npy format, each normalized feature is instead written
    next to its source file as ``<source>-norm.npy`` (in-place mode).

    Raises:
        ValueError: If both or neither of ``--rootdir`` / ``--feats-scp`` are
            given, if the configured format is neither hdf5 nor npy, if
            ``--ftype`` is unknown or unsupported for the format, or if
            in-place mode is requested in a setup that cannot provide the
            source feature path.

    """
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py).")
    parser.add_argument("--rootdir", default=None, type=str,
                        help="directory including feature files to be normalized. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--segments", default=None, type=str,
                        help="kaldi-style segments file.")
    parser.add_argument("--dumpdir", type=str, required=True,
                        help="directory to dump normalized feature files.")
    parser.add_argument("--stats", type=str, required=True,
                        help="statistics file.")
    parser.add_argument("--skip-wav-copy", default=False, action="store_true",
                        help="whether to skip the copy of wav files.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--ftype", default='mel', type=str,
                        help="feature type")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass configuration files from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments
    if (args.feats_scp is None) == (args.rootdir is None):
        raise ValueError("Please specify either --rootdir or --feats-scp.")
    if config["format"] not in ("hdf5", "npy"):
        raise ValueError("support only hdf5 or npy format.")

    # An empty --dumpdir with npy format means "write next to the source
    # feature file". Only the feature-only datasets provide that path, so the
    # wav copy must be disabled in this mode.
    in_place = config["format"] == "npy" and args.dumpdir == ""
    if in_place and not args.skip_wav_copy:
        raise ValueError("Please include --skip-wav-copy in arguments")

    # check directory existence
    if args.dumpdir != "":
        os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if args.rootdir is not None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
            # no spectrogram layout is defined for hdf5 dumps
            spc_query, spc_load_fn = None, None
        else:
            audio_query, mel_query, spc_query = "*.wav.npy", "*.mel.npy", "*.spec.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
            spc_load_fn = np.load
        if not args.skip_wav_copy:
            dataset = AudioMelDataset(
                root_dir=args.rootdir,
                audio_query=audio_query,
                mel_query=mel_query,
                audio_load_fn=audio_load_fn,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
        elif args.ftype == 'mel':
            dataset = MelDatasetNew(
                root_dir=args.rootdir,
                mel_query=mel_query,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
        elif args.ftype == 'spec':
            if spc_query is None:
                raise ValueError("--ftype spec is supported only with npy format.")
            dataset = SpcDatasetNew(
                root_dir=args.rootdir,
                spc_query=spc_query,
                spc_load_fn=spc_load_fn,
                return_utt_id=True,
            )
        else:
            raise ValueError(
                f"unsupported --ftype: {args.ftype} (expected 'mel' or 'spec').")
    else:
        if not args.skip_wav_copy:
            dataset = AudioMelSCPDataset(
                wav_scp=args.wav_scp,
                feats_scp=args.feats_scp,
                segments=args.segments,
                return_utt_id=True,
            )
        else:
            dataset = MelSCPDataset(
                feats_scp=args.feats_scp,
                return_utt_id=True,
            )
    logging.info(f"The number of files to be normalized = {len(dataset)}.")

    # restore scaler
    scaler = StandardScaler()
    if config["format"] == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    else:
        # row 0 holds the mean and row 1 the scale; load the file only once
        stats = np.load(args.stats)
        scaler.mean_ = stats[0]
        scaler.scale_ = stats[1]
    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    # process each file
    for items in tqdm(dataset):
        if not args.skip_wav_copy:
            utt_id, audio, feat = items
            feat_file = None
        else:
            # Feature-only datasets yield (utt_id, feat), optionally followed
            # by the path of the source feature file.
            utt_id, feat, *extra = items
            feat_file = extra[0] if extra else None

        # normalize (identical to (feat - mean_) / scale_)
        feat = scaler.transform(feat).astype(np.float32)

        # save
        if config["format"] == "hdf5":
            h5_path = os.path.join(args.dumpdir, f"{utt_id}.h5")
            write_hdf5(h5_path, "feats", feat)
            if not args.skip_wav_copy:
                write_hdf5(h5_path, "wave", audio.astype(np.float32))
        elif in_place:
            if feat_file is None:
                raise ValueError(
                    "in-place normalization (--dumpdir \"\") requires a dataset "
                    "that returns the source feature path.")
            # strip only the trailing extension, not every '.npy' in the path
            if feat_file.endswith('.npy'):
                feat_file = feat_file[:-len('.npy')]
            np.save(feat_file + "-norm.npy", feat, allow_pickle=False)
        else:
            np.save(os.path.join(args.dumpdir, f"{utt_id}.npy"),
                    feat, allow_pickle=False)
            if not args.skip_wav_copy:
                np.save(os.path.join(args.dumpdir, f"{utt_id}.wav.npy"),
                        audio.astype(np.float32), allow_pickle=False)
Ejemplo n.º 24
0
def main():
    """Run training process.

    Parse the command line, build the train/dev datasets and data loaders,
    instantiate generator/discriminator models with their criteria,
    optimizers and schedulers, optionally resume from a checkpoint, and run
    the trainer. On ``KeyboardInterrupt`` a checkpoint is saved into the
    output directory before returning.

    Raises:
        ValueError: If the configured format is neither hdf5 nor npy.
        ImportError: If distributed training is requested but apex is not
            installed.

    """
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir",
                        type=str,
                        required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir",
                        type=str,
                        required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG, stream=sys.stdout, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO, stream=sys.stdout, format=log_format)
    else:
        logging.basicConfig(
            level=logging.WARN, stream=sys.stdout, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    # exist_ok avoids a FileExistsError when several ranks start together
    os.makedirs(args.outdir, exist_ok=True)

    # load and save config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass configuration files from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train":
        AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev":
        AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    # when a sampler is supplied the loader itself must not shuffle
    loader_shuffle = not args.distributed
    data_loader = {
        "train":
        DataLoader(
            dataset=dataset["train"],
            shuffle=loader_shuffle,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev":
        DataLoader(
            dataset=dataset["dev"],
            shuffle=loader_shuffle,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"]).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator":
        RAdam(model["generator"].parameters(),
              **config["generator_optimizer_params"]),
        "discriminator":
        RAdam(model["discriminator"].parameters(),
              **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    # A bare ``--resume`` (nargs="?" without a value) yields None, so test
    # truthiness instead of calling len() on it.
    if args.resume:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Ejemplo n.º 25
0
def main():
    """Run decoding process.

    Load a trained generator from ``--checkpoint``, iterate over the features
    listed by ``--world_test``, synthesize a waveform for each one and write
    it as a 16-bit PCM wav file. The output path is the feature path with its
    extension removed and ``--indir`` replaced by ``--outdir``. The average
    real-time factor is logged at the end.

    """
    parser = argparse.ArgumentParser(
        description=
        "Decode dumped features with trained Parallel WaveGAN Generator "
        "(See detail in parallel_wavegan/bin/decode.py).")
    parser.add_argument("--world_test",
                        required=True,
                        type=str,
                        help="list or directory of testing aux feat files")
    parser.add_argument("--stats",
                        required=True,
                        type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--indir",
                        required=True,
                        type=str,
                        help="directory of input feature files")
    parser.add_argument("--outdir",
                        required=True,
                        type=str,
                        help="directory to save generated samples")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)")
    parser.add_argument("--feat_path1",
                        default="world",
                        type=str,
                        help="default feature path(channel) of hdf5 files.")
    parser.add_argument("--feat_path2",
                        default=None,
                        type=str,
                        help="second feature path(channel) of hdf5 files.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--seed", default=100, type=int, help="seed number")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # fix seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)

    # check directory existence
    # Create the output directory itself; taking its dirname first would be
    # empty for a bare relative path and make os.makedirs fail.
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass configuration files from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    world_load_fn = lambda x: read_hdf5(  # NOQA
        x, hdf5_path1=args.feat_path1, hdf5_path2=args.feat_path2)

    dataset = WorldDataset(
        stats=args.stats,
        world_list=args.world_test,
        world_load_fn=world_load_fn,
        return_filename=True,
        mean_path=config.get("mean_path", "/world/mean"),
        scale_path=config.get("scale_path", "/world/scale"),
    )
    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator"))
    model = model_class(**config["generator_params"])
    model.load_state_dict(
        torch.load(args.checkpoint, map_location="cpu")["model"]["generator"])
    logging.info(f"Loaded model parameters from {args.checkpoint}.")
    model.remove_weight_norm()
    model = model.eval().to(device)
    use_noise_input = not isinstance(model,
                                     parallel_wavegan.models.MelGANGenerator)
    pad_fn = torch.nn.ReplicationPad1d(config["generator_params"].get(
        "aux_context_window", 0))

    # start generation
    total_rtf = 0.0
    idx = 0  # stays 0 when the dataset yields nothing
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for idx, (feat_path, c) in enumerate(pbar, 1):
            # setup input
            c = pad_fn(torch.FloatTensor(c).unsqueeze(0).transpose(
                2, 1)).to(device)
            x = (c, )
            if use_noise_input:
                # noise length = number of unpadded frames times the hop size
                z_size = (1, 1, (c.size(2) - sum(pad_fn.padding)) *
                          config["hop_size"])
                z = torch.randn(z_size).to(device)
                x = (z, ) + x

            # generate
            start = time.time()
            y = model(*x).view(-1).cpu().numpy()
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            # save as PCM 16 bit wav file
            feat_path = os.path.splitext(feat_path)[0]
            feat_path = feat_path.replace(args.indir, args.outdir)
            wav_filename = "%s.wav" % feat_path
            wav_dir = os.path.dirname(wav_filename)
            if wav_dir:  # empty when the wav goes to the current directory
                os.makedirs(wav_dir, exist_ok=True)
            sf.write(wav_filename, y, config["sampling_rate"], "PCM_16")

    # report average RTF
    if idx == 0:
        logging.warning("No utterances were decoded.")
        return
    logging.info(
        f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f})."
    )