def __init__(
    self,
    audio_list,
    audio_length_threshold=None,
    audio_load_fn=sf.read,
    return_filename=False,
    allow_cache=False,
):
    """Initialize dataset.

    Args:
        audio_list (str): Filename of the list of audio files.
        audio_load_fn (func): Function to load audio file.
        audio_length_threshold (int): Threshold to remove short audio files.
        return_filename (bool): Whether to return the filename with arrays.
        allow_cache (bool): Whether to allow cache of the loaded files.

    """
    # load audio file list
    audio_files = read_txt(audio_list)

    # filter out files shorter than the threshold
    if audio_length_threshold is not None:
        # NOTE: every file is loaded once here just to measure its length,
        # which can be slow for long file lists.
        audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
        idxs = [
            idx for idx in range(len(audio_files))
            if audio_lengths[idx] > audio_length_threshold
        ]
        if len(audio_files) != len(idxs):
            logging.warning(
                f"some files are filtered by audio length threshold "
                f"({len(audio_files)} -> {len(idxs)}).")
        audio_files = [audio_files[idx] for idx in idxs]

    # assert the number of files
    # FIX: removed the stray "$" that leaked into the f-string error message
    # (f-string interpolation uses {}, not ${}).
    assert len(audio_files) != 0, f"Not found any audio files in {audio_list}."

    self.audio_files = audio_files
    self.audio_load_fn = audio_load_fn
    self.return_filename = return_filename
    self.allow_cache = allow_cache
    if allow_cache:
        # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
        self.manager = Manager()
        self.caches = self.manager.list()
        self.caches += [() for _ in range(len(audio_files))]
def __init__(
    self,
    stats,
    audio_list,
    world_list,
    audio_load_fn=sf.read,
    world_load_fn=lambda x: read_hdf5(x, "world"),
    hop_size=110,
    audio_length_threshold=None,
    world_length_threshold=None,
    return_filename=False,
    allow_cache=False,
    mean_path="/world/mean",
    scale_path="/world/scale",
):
    """Initialize dataset.

    Args:
        stats (str): Filename of the statistic hdf5 file.
        audio_list (str): Filename of the list of audio files.
        world_list (str): Filename of the list of world feature files.
        audio_load_fn (func): Function to load audio file.
        world_load_fn (func): Function to load world feature file.
        hop_size (int): Hop size of world feature.
        audio_length_threshold (int): Threshold to remove short audio files.
        world_length_threshold (int): Threshold to remove short world feature files.
        return_filename (bool): Whether to return the filename with arrays.
        allow_cache (bool): Whether to allow cache of the loaded files.
        mean_path (str): The data path (channel) of the mean in the statistic hdf5 file.
        scale_path (str): The data path (channel) of the scale in the statistic hdf5 file.

    """
    # load audio and world file list
    audio_files = read_txt(audio_list)
    world_files = read_txt(world_list)

    # check filename
    assert check_filename(audio_files, world_files)

    # filter out pairs whose audio is shorter than the threshold
    if audio_length_threshold is not None:
        # NOTE: every file is loaded once here just to measure its length,
        # which can be slow for long file lists.
        audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
        idxs = [
            idx for idx in range(len(audio_files))
            if audio_lengths[idx] > audio_length_threshold
        ]
        if len(audio_files) != len(idxs):
            logging.warning(
                f"Some files are filtered by audio length threshold "
                f"({len(audio_files)} -> {len(idxs)}).")
        # keep audio/world lists aligned by filtering both with the same indices
        audio_files = [audio_files[idx] for idx in idxs]
        world_files = [world_files[idx] for idx in idxs]

    # filter out pairs whose world feature is shorter than the threshold
    if world_length_threshold is not None:
        world_lengths = [world_load_fn(f).shape[0] for f in world_files]
        idxs = [
            idx for idx in range(len(world_files))
            if world_lengths[idx] > world_length_threshold
        ]
        if len(world_files) != len(idxs):
            logging.warning(
                f"Some files are filtered by world length threshold "
                f"({len(world_files)} -> {len(idxs)}).")
        audio_files = [audio_files[idx] for idx in idxs]
        world_files = [world_files[idx] for idx in idxs]

    # assert the number of files
    # FIX: removed the stray "$" that leaked into the f-string error message
    # (f-string interpolation uses {}, not ${}).
    assert len(audio_files) != 0, f"Not found any audio files in {audio_list}."
    assert len(audio_files) == len(world_files), \
        f"Number of audio and world files are different ({len(audio_files)} vs {len(world_files)})."

    self.audio_files = audio_files
    self.world_files = world_files
    self.audio_load_fn = audio_load_fn
    self.world_load_fn = world_load_fn
    self.return_filename = return_filename
    self.allow_cache = allow_cache
    self.hop_size = hop_size
    if allow_cache:
        # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
        self.manager = Manager()
        self.caches = self.manager.list()
        self.caches += [() for _ in range(len(audio_files))]

    # define feature pre-processing function: standardize world features
    # with the mean/scale stored in the statistics file
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(stats, mean_path)
    scaler.scale_ = read_hdf5(stats, scale_path)
    # for sklearn >= 0.23.0, this attribute is required for transform()
    scaler.n_features_in_ = scaler.mean_.shape[0]
    # IDIOM: bound method reference instead of a needless lambda wrapper;
    # scaler stays alive via the reference, behavior is identical
    self.feat_transform = scaler.transform