def __init__(
    self,
    batch_size: int,
    key_file: str,
    drop_last: bool = False,
    utt2category_file: str = None,
):
    assert check_argument_types()
    assert batch_size > 0
    self.batch_size = batch_size
    self.key_file = key_file
    self.drop_last = drop_last

    # utt2any:
    #    uttA <anything is o.k>
    #    uttB <anything is o.k>
    utt2any = read_2column_text(key_file)
    if len(utt2any) == 0:
        logging.warning(f"{key_file} is empty")
    # In this case, only the first column is used
    keys = list(utt2any)
    if len(keys) == 0:
        raise RuntimeError(f"0 lines found: {key_file}")

    category2utt = {}
    if utt2category_file is not None:
        utt2category = read_2column_text(utt2category_file)
        if set(utt2category) != set(keys):
            raise RuntimeError(
                f"keys are mismatched between {utt2category_file} != {key_file}"
            )
        for k, v in utt2category.items():
            category2utt.setdefault(v, []).append(k)
    else:
        category2utt["default_category"] = keys

    self.batch_list = []
    for category_keys in category2utt.values():
        # Apply max(, 1) to avoid 0-batches
        N = max(len(category_keys) // batch_size, 1)
        if not self.drop_last:
            # Split keys as evenly as possible. Note that if N != 1,
            # these batches always have at least batch_size samples.
            cur_batch_list = [
                category_keys[i * len(category_keys) // N:
                              (i + 1) * len(category_keys) // N]
                for i in range(N)
            ]
        else:
            cur_batch_list = [
                tuple(category_keys[i * batch_size:(i + 1) * batch_size])
                for i in range(N)
            ]
        self.batch_list.extend(cur_batch_list)
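For illustration, the per-category splitting loop above can be exercised in isolation. The helper below is a hypothetical standalone rewrite of that loop; the name make_category_batches is not from the source:

def make_category_batches(category2utt, batch_size, drop_last=False):
    batch_list = []
    for category_keys in category2utt.values():
        # max(, 1) avoids producing zero batches for small categories
        n = max(len(category_keys) // batch_size, 1)
        if not drop_last:
            # Even split: batch sizes never drop below batch_size when n != 1
            batch_list.extend(
                category_keys[i * len(category_keys) // n:
                              (i + 1) * len(category_keys) // n]
                for i in range(n)
            )
        else:
            batch_list.extend(
                tuple(category_keys[i * batch_size:(i + 1) * batch_size])
                for i in range(n)
            )
    return batch_list


# With batch_size=2, five "music" keys split into batches of 2 and 3;
# the single "speech" key still forms one (short) batch.
print(make_category_batches(
    {"music": ["a", "b", "c", "d", "e"], "speech": ["f"]}, batch_size=2))
# -> [['a', 'b'], ['c', 'd', 'e'], ['f']]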
def test_read_2column_text(tmp_path: Path):
    p = tmp_path / "dummy.scp"
    with p.open("w") as f:
        f.write("abc /some/path/a.wav\n")
        f.write("def /some/path/b.wav\n")
    d = read_2column_text(p)
    assert d == {"abc": "/some/path/a.wav", "def": "/some/path/b.wav"}
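The test above pins down the contract of read_2column_text: split each line on the first run of whitespace into a key and a value. A minimal sketch that satisfies the test (an assumption; the real implementation may differ in its edge-case handling):

from pathlib import Path
from typing import Dict, Union


def read_2column_text_sketch(path: Union[Path, str]) -> Dict[str, str]:
    data = {}
    with Path(path).open("r", encoding="utf-8") as f:
        for line in f:
            # "uttA /some/path/a.wav" -> key "uttA", value "/some/path/a.wav"
            sps = line.rstrip("\n").split(maxsplit=1)
            if not sps:
                continue  # skip blank lines
            if len(sps) == 1:
                key, value = sps[0], ""
            else:
                key, value = sps
            data[key] = value
    return data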
def __init__(self, batch_size: int, key_file: str, drop_last: bool = False):
    assert check_argument_types()
    assert batch_size > 0
    self.batch_size = batch_size
    self.key_file = key_file
    self.drop_last = drop_last

    # utt2any:
    #    uttA <anything is o.k>
    #    uttB <anything is o.k>
    utt2any = read_2column_text(key_file)
    if len(utt2any) == 0:
        logging.warning(f"{key_file} is empty")
    # In this case, only the first column is used
    keys = list(utt2any)
    if len(keys) == 0:
        raise RuntimeError(f"0 lines found: {key_file}")

    # Apply max(, 1) to avoid 0-batches
    N = max(len(keys) // batch_size, 1)
    if not self.drop_last:
        # Split keys as evenly as possible. Note that if N != 1,
        # these batches always have at least batch_size samples.
        self.batch_list = [
            keys[i * len(keys) // N:(i + 1) * len(keys) // N]
            for i in range(N)
        ]
    else:
        self.batch_list = [
            tuple(keys[i * batch_size:(i + 1) * batch_size])
            for i in range(N)
        ]
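A quick numeric check of the even-split slice arithmetic above: with 7 keys and batch_size=3, N = max(7 // 3, 1) = 2, and the slices land on indices 0:3 and 3:7.

keys = ["u1", "u2", "u3", "u4", "u5", "u6", "u7"]
batch_size = 3
N = max(len(keys) // batch_size, 1)  # N = 2
batch_list = [
    keys[i * len(keys) // N:(i + 1) * len(keys) // N] for i in range(N)
]
print(batch_list)  # [['u1', 'u2', 'u3'], ['u4', 'u5', 'u6', 'u7']]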
def __init__(
    self,
    fname,
    dtype=np.int16,
    always_2d: bool = False,
    normalize: bool = False,
):
    assert check_argument_types()
    self.fname = fname
    self.dtype = dtype
    self.always_2d = always_2d
    self.normalize = normalize
    self.data = read_2column_text(fname)
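Given that self.data maps utterance IDs to audio paths, a plausible __getitem__ for this reader is sketched below as it might appear inside the class (an assumption; the method body is not shown in the source, and the real one may treat normalize differently):

import numpy as np
import soundfile


def __getitem__(self, key):
    wav_path = self.data[key]
    if self.normalize:
        # soundfile returns float data in [-1.0, 1.0] when no dtype is given
        array, rate = soundfile.read(wav_path, always_2d=self.always_2d)
    else:
        array, rate = soundfile.read(
            wav_path, dtype=self.dtype, always_2d=self.always_2d
        )
    # Follow the (rate, array) convention used elsewhere in these snippets
    return rate, array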
def __init__(
    self,
    batch_size: int,
    shape_files: Union[Tuple[str, ...], List[str]],
    fold_lengths: Sequence[int],
    min_batch_size: int = 1,
    sort_in_batch: str = "descending",
    sort_batch: str = "ascending",
    drop_last: bool = False,
    utt2category_file: str = None,
):
    assert check_argument_types()
    assert batch_size > 0
    if sort_batch != "ascending" and sort_batch != "descending":
        raise ValueError(
            f"sort_batch must be ascending or descending: {sort_batch}")
    if sort_in_batch != "descending" and sort_in_batch != "ascending":
        raise ValueError(
            f"sort_in_batch must be ascending or descending: {sort_in_batch}"
        )

    self.batch_size = batch_size
    self.shape_files = shape_files
    self.sort_in_batch = sort_in_batch
    self.sort_batch = sort_batch
    self.drop_last = drop_last

    # utt2shape: (Length, ...)
    #    uttA 100,...
    #    uttB 201,...
    utt2shapes = [
        load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
    ]

    first_utt2shape = utt2shapes[0]
    for s, d in zip(shape_files, utt2shapes):
        if set(d) != set(first_utt2shape):
            raise RuntimeError(
                f"keys are mismatched between {s} != {shape_files[0]}")

    # Sort samples in ascending order
    # (shape order should be like (Length, Dim))
    keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
    if len(keys) == 0:
        raise RuntimeError(f"0 lines found: {shape_files[0]}")

    category2utt = {}
    if utt2category_file is not None:
        utt2category = read_2column_text(utt2category_file)
        if set(utt2category) != set(first_utt2shape):
            raise RuntimeError("keys are mismatched between "
                               f"{utt2category_file} != {shape_files[0]}")
        for k in keys:
            category2utt.setdefault(utt2category[k], []).append(k)
    else:
        category2utt["default_category"] = keys

    self.batch_list = []
    for category_keys in category2utt.values():
        # Decide batch-sizes
        start = 0
        batch_sizes = []
        while True:
            k = category_keys[start]
            factor = max(
                int(d[k][0] / m) for d, m in zip(utt2shapes, fold_lengths))
            bs = max(min_batch_size, int(batch_size / (1 + factor)))
            if self.drop_last and start + bs > len(category_keys):
                # This if-block avoids 0-batches
                if len(self.batch_list) > 0:
                    break

            bs = min(len(category_keys) - start, bs)
            batch_sizes.append(bs)
            start += bs
            if start >= len(category_keys):
                break

        if len(batch_sizes) == 0:
            # This branch should be unreachable
            raise RuntimeError("0 batches")

        # If the last batch-size is smaller than the minimum batch_size,
        # its samples are redistributed to the other mini-batches
        if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
            for i in range(batch_sizes.pop(-1)):
                batch_sizes[-(i % len(batch_sizes)) - 1] += 1

        if not self.drop_last:
            # Bug check
            assert sum(batch_sizes) == len(
                category_keys
            ), f"{sum(batch_sizes)} != {len(category_keys)}"

        # Set mini-batch
        cur_batch_list = []
        start = 0
        for bs in batch_sizes:
            assert len(category_keys) >= start + bs, "Bug"
            minibatch_keys = category_keys[start:start + bs]
            start += bs
            if sort_in_batch == "descending":
                minibatch_keys.reverse()
            elif sort_in_batch == "ascending":
                # Keys are already sorted in ascending order
                pass
            else:
                raise ValueError("sort_in_batch must be ascending or "
                                 f"descending: {sort_in_batch}")
            cur_batch_list.append(tuple(minibatch_keys))

        if sort_batch == "ascending":
            pass
        elif sort_batch == "descending":
            cur_batch_list.reverse()
        else:
            raise ValueError(
                f"sort_batch must be ascending or descending: {sort_batch}"
            )
        self.batch_list.extend(cur_batch_list)
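A small numeric check of the fold logic above (shapes and fold lengths are made up): an utterance of length 800 against a fold length of 400 gives factor 2, so a nominal batch_size of 20 shrinks to int(20 / (1 + 2)) = 6.

utt2shapes = [{"uttA": [800]}, {"uttA": [120]}]  # e.g. input/output lengths
fold_lengths = [400, 150]
batch_size, min_batch_size = 20, 1

k = "uttA"
factor = max(int(d[k][0] / m) for d, m in zip(utt2shapes, fold_lengths))
bs = max(min_batch_size, int(batch_size / (1 + factor)))
print(factor, bs)  # 2 6: the longer the utterance, the smaller its batch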
def main():
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of the output file name "
        'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate is specified, "
        "resample the audio to that rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()

    out_num_samples = Path(args.outdir) / "utt2num_samples"

    if args.ref_channels is not None:

        def utt2ref_channels(x) -> Tuple[int, ...]:
            return args.ref_channels

    elif args.utt2ref_channels is not None:
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)

        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))

    else:
        utt2ref_channels = None

    Path(args.outdir).mkdir(parents=True, exist_ok=True)
    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
    if args.segments is not None:
        # Note: kaldiio supports only wav-pcm-int16le files.
        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)

        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
            fscp = out_wavscp.open("w")
        else:
            writer = SoundScpWriter(
                args.outdir,
                out_wavscp,
                format=args.audio_format,
            )

        with out_num_samples.open("w") as fnum_samples:
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]

                if args.fs is not None and args.fs != rate:
                    # FIXME(kamo): Use sox instead?
                    wave = resampy.resample(
                        wave.astype(np.float64), rate, args.fs, axis=0
                    )
                    wave = wave.astype(np.int16)
                    rate = args.fs

                if args.audio_format.endswith("ark"):
                    if "flac" in args.audio_format:
                        suf = "flac"
                    elif "wav" in args.audio_format:
                        suf = "wav"
                    else:
                        raise RuntimeError("wav.ark or flac")

                    # NOTE(kamo): Using the extended ark format style here.
                    # This format is incompatible with Kaldi.
                    kaldiio.save_ark(
                        fark,
                        {uttid: (wave, rate)},
                        scp=fscp,
                        append=True,
                        write_function=f"soundfile_{suf}",
                    )
                else:
                    writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
    else:
        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
        else:
            wavdir = Path(args.outdir) / f"data_{args.name}"
            wavdir.mkdir(parents=True, exist_ok=True)

        with Path(args.scp).open("r") as fscp, out_wavscp.open(
                "w") as fout, out_num_samples.open("w") as fnum_samples:
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)

                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)

                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]

                    if args.fs is not None and args.fs != rate:
                        # FIXME(kamo): Use sox instead?
                        wave = resampy.resample(
                            wave.astype(np.float64), rate, args.fs, axis=0
                        )
                        wave = wave.astype(np.int16)
                        rate = args.fs

                    if args.audio_format.endswith("ark"):
                        if "flac" in args.audio_format:
                            suf = "flac"
                        elif "wav" in args.audio_format:
                            suf = "wav"
                        else:
                            raise RuntimeError("wav.ark or flac")

                        # NOTE(kamo): Using the extended ark format style here.
                        # This format is incompatible with Kaldi.
                        kaldiio.save_ark(
                            fark,
                            {uttid: (wave, rate)},
                            scp=fout,
                            append=True,
                            write_function=f"soundfile_{suf}",
                        )
                    else:
                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                else:
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                        save_asis = False
                    elif args.audio_format.endswith("ark"):
                        save_asis = False
                    elif Path(wavpath).suffix == "." + args.audio_format and (
                            args.fs is None or args.fs == rate):
                        save_asis = True
                    else:
                        save_asis = False

                    if save_asis:
                        # Neither --segments nor a new sampling rate is needed
                        # and the line doesn't end with "|" (i.e. no unix
                        # pipe); only in this case is the original file
                        # referenced as is.
                        fout.write(f"{uttid} {wavpath}\n")
                    else:
                        if args.fs is not None and args.fs != rate:
                            # FIXME(kamo): Use sox instead?
                            wave = resampy.resample(
                                wave.astype(np.float64), rate, args.fs, axis=0
                            )
                            wave = wave.astype(np.int16)
                            rate = args.fs

                        if args.audio_format.endswith("ark"):
                            if "flac" in args.audio_format:
                                suf = "flac"
                            elif "wav" in args.audio_format:
                                suf = "wav"
                            else:
                                raise RuntimeError("wav.ark or flac")

                            # NOTE(kamo): Using the extended ark format style
                            # here. This format is not supported by Kaldi.
                            kaldiio.save_ark(
                                fark,
                                {uttid: (wave, rate)},
                                scp=fout,
                                append=True,
                                write_function=f"soundfile_{suf}",
                            )
                        else:
                            owavpath = str(
                                wavdir / f"{uttid}.{args.audio_format}")
                            soundfile.write(owavpath, wave, rate)
                            fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")
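To illustrate the --utt2ref-channels path in main() above: each value in the mapping file is a space-separated list of channel indices, parsed into a tuple and used to pick columns from a (Time, Nmic) array. A self-contained check with made-up data:

from typing import Tuple

import numpy as np

# Stands in for read_2column_text("utt2ref_channels"); values are made up
utt2ref_channels_dict = {"uttA": "0 2"}


def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
    chs_str = d[x]
    return tuple(map(int, chs_str.split()))


wave = np.zeros((16000, 4), dtype=np.int16)  # dummy 4-channel audio
wave = wave[:, utt2ref_channels("uttA")]
print(wave.shape)  # (16000, 2): only channels 0 and 2 are kept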
def __init__(self, fname: Union[Path, str]):
    assert check_argument_types()
    self.fname = Path(fname)
    self.data = read_2column_text(fname)
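Since self.data here is just the mapping returned by read_2column_text, readers like this usually forward dict-style access to it. The methods below, which would sit in the same class, are a hedged guess at the usual companions; they do not appear in the source:

def __getitem__(self, key: str) -> str:
    return self.data[key]

def __contains__(self, key: str) -> bool:
    return key in self.data

def keys(self):
    return self.data.keys()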