def __init__(
    self,
    manifest_path,
    sample_rate,
    max_sample_size=None,
    min_sample_size=0,
    shuffle=True,
    pad=False,
    normalize=False,
    num_buckets=0,
    compute_mask_indices=False,
    text_compression_level=TextCompressionLevel.none,
    **mask_compute_kwargs,
):
    """Build the file index from a tab-separated manifest.

    The first manifest line is stored (stripped) as ``self.root_dir``.
    Every following line must contain exactly two tab-separated fields:
    a file name and an integer size. Entries whose size is smaller than
    ``min_sample_size`` are left out, and their zero-based line index
    (counted after the root line) is added to ``self.skipped_indices``.

    Args:
        manifest_path: path of the manifest file to read.
        sample_rate, max_sample_size, min_sample_size, shuffle, pad,
            normalize, compute_mask_indices, **mask_compute_kwargs:
            forwarded unchanged to the parent constructor.
        num_buckets: passed to ``self.set_bucket_info`` once the index
            has been built.
        text_compression_level: level given to the ``TextCompressor``
            that compresses the stored file names.

    Raises:
        AssertionError: if a manifest line does not have two fields.
        ValueError: if the size field is not an integer literal.
    """
    super().__init__(
        sample_rate=sample_rate,
        max_sample_size=max_sample_size,
        min_sample_size=min_sample_size,
        shuffle=shuffle,
        pad=pad,
        normalize=normalize,
        compute_mask_indices=compute_mask_indices,
        **mask_compute_kwargs,
    )

    self.text_compressor = TextCompressor(level=text_compression_level)

    skipped = 0
    self.fnames = []
    sizes = []
    self.skipped_indices = set()

    with open(manifest_path, "r") as f:
        self.root_dir = f.readline().strip()
        for i, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) == 2, line
            sz = int(items[1])
            if min_sample_size is not None and sz < min_sample_size:
                skipped += 1
                # Remember the dropped line so parallel per-line files can
                # be filtered with the same indices.
                self.skipped_indices.add(i)
                continue
            self.fnames.append(self.text_compressor.compress(items[0]))
            sizes.append(sz)
    logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

    self.sizes = np.array(sizes, dtype=np.int64)

    # Optional optimisation: keep the names in a pyarrow array when pyarrow
    # is usable. This stays best-effort, but only ordinary exceptions are
    # swallowed so KeyboardInterrupt / SystemExit still propagate.
    try:
        import pyarrow

        self.fnames = pyarrow.array(self.fnames)
    except Exception:
        logger.debug(
            "Could not create a pyarrow array. Please install pyarrow for better performance"
        )

    self.set_bucket_info(num_buckets)
def __init__(
    self,
    dataset,
    labels,
    pad,
    eos,
    batch_targets,
    process_label=None,
    label_len_fn=None,
    add_to_input=False,
    text_compression_level=TextCompressionLevel.none,
):
    """Wrap ``dataset`` and remember the label configuration.

    Args:
        dataset: the dataset handed to the parent constructor.
        labels: label store, kept as ``self.labels``.
        pad: stored as ``self.pad``.
        eos: stored as ``self.eos``.
        batch_targets: stored as ``self.batch_targets``.
        process_label: optional callable stored as ``self.process_label``.
        label_len_fn: optional callable stored as ``self.label_len_fn``.
        add_to_input: stored as ``self.add_to_input``.
        text_compression_level: level used to build ``self.text_compressor``.
    """
    super().__init__(dataset)

    # Label storage and the compressor paired with it.
    self.labels = labels
    self.text_compressor = TextCompressor(level=text_compression_level)

    # Optional label callbacks.
    self.process_label = process_label
    self.label_len_fn = label_len_fn

    # Collation settings.
    self.batch_targets = batch_targets
    self.add_to_input = add_to_input
    self.pad = pad
    self.eos = eos
def load_dataset(
    self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs
):
    """Load ``split`` through the parent task and attach its labels.

    After the parent has populated ``self.datasets[split]``, the label
    file ``{split}.{task_cfg.labels}`` under ``self.cfg.data`` is read line
    by line. Lines whose index appears in the dataset's ``skipped_indices``
    attribute (an empty set when the attribute is absent) are ignored; the
    rest are compressed and handed to an ``AddTargetDataset`` wrapper that
    replaces ``self.datasets[split]``.

    Args:
        split: name of the split to load.
        task_cfg: optional config; ``self.cfg`` is used when it is falsy.
        **kwargs: forwarded to the parent ``load_dataset``.

    Raises:
        AssertionError: if ``task_cfg.labels`` is ``None`` or the number of
            labels kept differs from the dataset length.
    """
    super().load_dataset(split, task_cfg, **kwargs)

    task_cfg = task_cfg or self.cfg
    assert task_cfg.labels is not None

    # NOTE: the compression level is read from self.cfg, not task_cfg.
    level_name = str(self.cfg.text_compression_level)
    text_compression_level = getattr(TextCompressionLevel, level_name)

    label_file_name = f"{split}.{task_cfg.labels}"
    label_path = os.path.join(self.cfg.data, label_file_name)

    skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
    text_compressor = TextCompressor(level=text_compression_level)

    labels = []
    with open(label_path, "r") as f:
        for line_index, raw_label in enumerate(f):
            if line_index in skipped_indices:
                continue
            labels.append(text_compressor.compress(raw_label))

    assert len(labels) == len(self.datasets[split]), (
        f"labels length ({len(labels)}) and dataset length "
        f"({len(self.datasets[split])}) do not match"
    )

    target_dictionary = self.target_dictionary
    process_label = LabelEncoder(target_dictionary)

    wrapped = AddTargetDataset(
        self.datasets[split],
        labels,
        pad=target_dictionary.pad(),
        eos=target_dictionary.eos(),
        batch_targets=True,
        process_label=process_label,
        label_len_fn=label_len_fn,
        add_to_input=task_cfg.get("autoregressive", False),
        text_compression_level=text_compression_level,
    )
    self.datasets[split] = wrapped
class AddTargetDataset(BaseWrapperDataset):
    """Wrapper dataset that adds a ``"label"`` entry to every item.

    Labels are kept in ``self.labels`` and passed through
    ``self.text_compressor.decompress`` whenever they are read.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        label_len_fn=None,
        add_to_input=False,
        text_compression_level=TextCompressionLevel.none,
    ):
        """Store the wrapped dataset and the label configuration.

        Args:
            dataset: the dataset to wrap (given to the parent constructor).
            labels: indexable label store; entries are decompressed with a
                ``TextCompressor`` built from ``text_compression_level``.
            pad: padding index used when collating targets.
            eos: end-of-sequence index used when ``add_to_input`` is set.
            batch_targets: whether ``collater`` pads targets into a batch.
            process_label: optional callable applied to each label in
                ``__getitem__``.
            label_len_fn: callable used by ``size`` to measure a label.
            add_to_input: whether ``collater`` also builds
                ``prev_output_tokens``.
            text_compression_level: level for the label ``TextCompressor``.
        """
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.label_len_fn = label_len_fn
        self.add_to_input = add_to_input
        self.text_compressor = TextCompressor(level=text_compression_level)

    def get_label(self, index, process_fn=None):
        """Return the decompressed label at ``index``.

        When ``process_fn`` is given, its result on the decompressed label
        is returned instead.
        """
        lbl = self.text_compressor.decompress(self.labels[index])
        return lbl if process_fn is None else process_fn(lbl)

    def __getitem__(self, index):
        """Return the wrapped item with a processed ``"label"`` entry added."""
        item = self.dataset[index]
        item["label"] = self.get_label(index, process_fn=self.process_label)
        return item

    def size(self, index):
        """Return ``(wrapped size, label length)`` for ``index``.

        The label length is ``self.label_len_fn`` applied to the
        decompressed, unprocessed label.
        """
        sz = self.dataset.size(index)
        own_sz = self.label_len_fn(self.get_label(index))
        return sz, own_sz

    def collater(self, samples):
        """Collate ``samples`` with the wrapped collater and add targets.

        Only samples whose ``"id"`` is present in the wrapped collater's
        output contribute a target. With ``add_to_input`` an EOS token is
        prepended to build ``prev_output_tokens`` and appended to each
        target. With ``batch_targets`` the targets (and
        ``prev_output_tokens``, when present) are padded with ``self.pad``.

        Returns:
            The collated batch, or the wrapped collater's output unchanged
            when that output is empty.
        """
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated

        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.add_to_input:
            eos = torch.LongTensor([self.eos])
            prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
            target = [torch.cat([t, eos], axis=-1) for t in target]
            collated["net_input"]["prev_output_tokens"] = prev_output_tokens

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
            # This class only sets "prev_output_tokens" when add_to_input is
            # enabled; padding it unconditionally raised KeyError otherwise.
            if "prev_output_tokens" in collated["net_input"]:
                collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
                    collated["net_input"]["prev_output_tokens"],
                    pad_idx=self.pad,
                    left_pad=False,
                )
        else:
            collated["ntokens"] = sum(len(t) for t in target)

        collated["target"] = target
        return collated

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter ``indices`` with ``data_utils._filter_by_size_dynamic``.

        ``self.size`` is used as the size function. Returns the pair
        ``(indices, ignored)`` produced by that helper.
        """
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored
class FileAudioDataset(RawAudioDataset):
    """Audio dataset whose items are read from files listed in a manifest."""

    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        text_compression_level=TextCompressionLevel.none,
        **mask_compute_kwargs,
    ):
        """Build the file index from a tab-separated manifest.

        The first manifest line is stored (stripped) as ``self.root_dir``.
        Every following line must contain exactly two tab-separated fields:
        a file name and an integer size. Entries whose size is smaller than
        ``min_sample_size`` are left out, and their zero-based line index
        (counted after the root line) is added to ``self.skipped_indices``.

        Args:
            manifest_path: path of the manifest file to read.
            sample_rate, max_sample_size, min_sample_size, shuffle, pad,
                normalize, compute_mask_indices, **mask_compute_kwargs:
                forwarded unchanged to the parent constructor.
            num_buckets: passed to ``self.set_bucket_info`` once the index
                has been built.
            text_compression_level: level given to the ``TextCompressor``
                that compresses the stored file names.

        Raises:
            AssertionError: if a manifest line does not have two fields.
            ValueError: if the size field is not an integer literal.
        """
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        self.text_compressor = TextCompressor(level=text_compression_level)

        skipped = 0
        self.fnames = []
        sizes = []
        self.skipped_indices = set()

        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    # Remember the dropped line so parallel per-line files
                    # can be filtered with the same indices.
                    self.skipped_indices.add(i)
                    continue
                self.fnames.append(self.text_compressor.compress(items[0]))
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        # Optional optimisation: keep the names in a pyarrow array when
        # pyarrow is usable. This stays best-effort, but only ordinary
        # exceptions are swallowed so KeyboardInterrupt / SystemExit still
        # propagate.
        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except Exception:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )

        self.set_bucket_info(num_buckets)

    def __getitem__(self, index):
        """Read the audio for ``index`` and return it as a feature tensor.

        Returns:
            A dict with ``"id"`` (the index) and ``"source"`` (the result of
            ``self.postprocess`` on the float32 waveform).

        Raises:
            AssertionError: if bytes read from a stored zip are rejected by
                ``is_sf_audio_data``.
        """
        import soundfile as sf

        fn = self.fnames[index]
        # With a plain list the entry is used directly; otherwise (the
        # pyarrow array built in __init__) it is converted with ``as_py()``.
        fn = fn if isinstance(self.fnames, list) else fn.as_py()
        fn = self.text_compressor.decompress(fn)

        path_or_fp = os.path.join(self.root_dir, fn)
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            # Two slice values select a byte range inside a stored zip
            # (presumably offset and length — confirm against parse_path).
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)

        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}