Code example #1
    def __init__(
        self,
        audio_list,
        audio_length_threshold=None,
        audio_load_fn=sf.read,
        return_filename=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            audio_list (str): Filename of the list of audio files.
            audio_length_threshold (int): Threshold to remove short audio files.
            audio_load_fn (func): Function to load audio file.
            return_filename (bool): Whether to return the filename with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # load audio file list
        audio_files = read_txt(audio_list)

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
            idxs = [
                idx for idx in range(len(audio_files))
                if audio_lengths[idx] > audio_length_threshold
            ]
            if len(audio_files) != len(idxs):
                logging.warning(
                    f"some files are filtered by audio length threshold "
                    f"({len(audio_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]

        # assert the number of files
        assert len(audio_files) != 0, \
            f"No audio files found in {audio_list}."

        self.audio_files = audio_files
        self.audio_load_fn = audio_load_fn
        self.return_filename = return_filename
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory across dataloader workers when num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(audio_files))]
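
A minimal usage sketch for code example #1, assuming the __init__ above belongs to a dataset class named AudioDataset; the class name, the list filename, and the threshold value below are placeholders, and the class's __getitem__/__len__ are not shown in the snippet. Note that the length filter calls .shape[0] on the loader's return value, so a loader returning a bare waveform array is passed here instead of raw sf.read, which returns an (array, sample_rate) tuple.

import soundfile as sf
from torch.utils.data import DataLoader

# "AudioDataset" and "train_wav.list" are placeholder names for this sketch.
dataset = AudioDataset(
    audio_list="train_wav.list",            # text file with one audio path per line
    audio_load_fn=lambda x: sf.read(x)[0],  # return only the waveform array
    audio_length_threshold=16000,           # drop clips of 16000 samples or fewer
    return_filename=True,
    allow_cache=True,                       # cache loaded arrays across DataLoader workers
)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)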
Code example #2
    def __init__(
        self,
        stats,
        audio_list,
        world_list,
        audio_load_fn=sf.read,
        world_load_fn=lambda x: read_hdf5(x, "world"),
        hop_size=110,
        audio_length_threshold=None,
        world_length_threshold=None,
        return_filename=False,
        allow_cache=False,
        mean_path="/world/mean",
        scale_path="/world/scale",
    ):
        """Initialize dataset.

        Args:
            stats (str): Filename of the statistics hdf5 file.
            audio_list (str): Filename of the list of audio files.
            world_list (str): Filename of the list of world feature files.
            audio_load_fn (func): Function to load audio file.
            world_load_fn (func): Function to load world feature file.
            hop_size (int): Hop size of the world features.
            audio_length_threshold (int): Threshold to remove short audio files.
            world_length_threshold (int): Threshold to remove short world feature files.
            return_filename (bool): Whether to return the filename with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.
            mean_path (str): The data path (channel) of the mean in the statistics hdf5 file.
            scale_path (str): The data path (channel) of the scale in the statistics hdf5 file.

        """
        # load audio and world file list
        audio_files = read_txt(audio_list)
        world_files = read_txt(world_list)
        # check that audio and world filenames correspond
        assert check_filename(audio_files, world_files)

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
            idxs = [
                idx for idx in range(len(audio_files))
                if audio_lengths[idx] > audio_length_threshold
            ]
            if len(audio_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]
            world_files = [world_files[idx] for idx in idxs]
        if world_length_threshold is not None:
            world_lengths = [world_load_fn(f).shape[0] for f in world_files]
            idxs = [
                idx for idx in range(len(world_files))
                if world_lengths[idx] > world_length_threshold
            ]
            if len(world_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by world length threshold "
                    f"({len(world_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]
            world_files = [world_files[idx] for idx in idxs]

        # assert the number of files
        assert len(audio_files) != 0, \
            f"No audio files found in {audio_list}."
        assert len(audio_files) == len(world_files), \
            f"Number of audio and world files are different ({len(audio_files)} vs {len(world_files)})."

        self.audio_files = audio_files
        self.world_files = world_files
        self.audio_load_fn = audio_load_fn
        self.world_load_fn = world_load_fn
        self.return_filename = return_filename
        self.allow_cache = allow_cache
        self.hop_size = hop_size
        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory across dataloader workers when num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(audio_files))]
        # define feature pre-processing function
        scaler = StandardScaler()
        scaler.mean_ = read_hdf5(stats, mean_path)
        scaler.scale_ = read_hdf5(stats, scale_path)
        # NOTE: from scikit-learn 0.23.0, the n_features_in_ attribute is required
        scaler.n_features_in_ = scaler.mean_.shape[0]
        self.feat_transform = lambda x: scaler.transform(x)
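
The last block of code example #2 wires precomputed statistics into a scikit-learn StandardScaler without calling fit. Below is a self-contained sketch of that pattern with toy statistics; in the real code the mean and scale come from read_hdf5(stats, mean_path) and read_hdf5(stats, scale_path).

import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy statistics standing in for the values stored in the stats hdf5 file.
mean = np.array([0.0, 1.0, 2.0])
scale = np.array([1.0, 2.0, 4.0])

# Build a "pre-fitted" scaler by assigning the fitted attributes directly.
scaler = StandardScaler()
scaler.mean_ = mean
scaler.scale_ = scale
scaler.n_features_in_ = mean.shape[0]  # required by scikit-learn >= 0.23 (as noted in code example #2)

# Apply the same normalization as self.feat_transform: (x - mean_) / scale_ per dimension.
feats = np.random.randn(100, 3)        # (num_frames, feature_dim) world features
normalized = scaler.transform(feats)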