def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    """Wrap raw source inputs in an AsrDataset suitable for inference.

    Args:
        src_tokens: source inputs, passed positionally to AsrDataset.
        src_lengths: lengths of ``src_tokens``, passed positionally to AsrDataset.
        constraints: optional decoding constraints forwarded to AsrDataset.

    Returns:
        An AsrDataset built with this task's target dictionary.
    """
    dataset_kwargs = {
        "dictionary": self.target_dictionary,
        "constraints": constraints,
    }
    return AsrDataset(src_tokens, src_lengths, **dataset_kwargs)
def _asr_dataset_helper(
    self,
    all_in_memory=False,
    ordered_prefetch=False,
    has_utt2num_frames=False,
):
    """Build an AsrDataset from the test fixture and check every collated batch.

    Reads fixture attributes from ``self`` (feats_utt_ids, rxfiles,
    utt2num_frames, cache_size, text_utt_ids, token_text, dictionary,
    num_audios, num_transripts, batch_size, expected_feats, expected_text).

    Args:
        all_in_memory: use FeatScpInMemoryDataset instead of
            FeatScpCachedDataset; also skips the explicit prefetch call.
        ordered_prefetch: forwarded to FeatScpCachedDataset.
        has_utt2num_frames: pass ``self.utt2num_frames`` to the source
            dataset instead of None.
    """
    frames_per_utt = self.utt2num_frames if has_utt2num_frames else None
    if all_in_memory:
        src_dataset = FeatScpInMemoryDataset(
            utt_ids=self.feats_utt_ids,
            rxfiles=self.rxfiles,
            utt2num_frames=frames_per_utt,
        )
    else:
        src_dataset = FeatScpCachedDataset(
            utt_ids=self.feats_utt_ids,
            rxfiles=self.rxfiles,
            utt2num_frames=frames_per_utt,
            ordered_prefetch=ordered_prefetch,
            cache_size=self.cache_size,
        )
    tgt_dataset = AsrTextDataset(
        utt_ids=self.text_utt_ids,
        token_text=self.token_text,
        dictionary=self.dictionary,
    )
    dataset = AsrDataset(
        src_dataset,
        src_dataset.sizes,
        tgt_dataset,
        tgt_dataset.sizes,
        self.dictionary,
        left_pad_source=False,
        left_pad_target=False,
        max_source_positions=1000,
        max_target_positions=200,
    )

    # assume one is a subset of the other
    num_utts = min(self.num_audios, self.num_transripts)
    self.assertEqual(len(dataset.src), num_utts)
    self.assertEqual(len(dataset.tgt), num_utts)

    # consecutive, fixed-size batches of indices (last one may be shorter)
    indices = list(range(num_utts))
    batch_sampler = [
        indices[start:start + self.batch_size]
        for start in range(0, num_utts, self.batch_size)
    ]

    if not all_in_memory:
        dataset.prefetch(indices)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
    )

    for batch_idx, batch in enumerate(dataloader):
        bsz = batch["nsentences"]
        self.assertEqual(bsz, len(batch_sampler[batch_idx]))

        net_input = batch["net_input"]
        src_frames = net_input["src_tokens"]
        src_lengths = net_input["src_lengths"]
        # decode targets back to text (one line per utterance), dropping padding
        decoded = self.dictionary.string(
            batch["target"], extra_symbols_to_ignore={self.dictionary.pad()}
        )
        tgt_tokens = [line.split(" ") for line in decoded.split("\n")]

        self.assertEqual(bsz, src_frames.size(0))
        self.assertEqual(bsz, src_lengths.numel())
        self.assertEqual(bsz, len(tgt_tokens))

        for j, utt_id in enumerate(batch["utt_id"]):
            # compare only the unpadded prefix of each utterance's frames
            self.assertTensorEqual(
                torch.from_numpy(self.expected_feats[utt_id]).float(),
                src_frames[j, :src_lengths[j], :],
            )
            self.assertEqual(self.expected_text[utt_id], tgt_tokens[j])
def build_dataset_for_inference(self, src_tokens, src_lengths):
    """Wrap raw source inputs and their lengths in an AsrDataset for inference.

    Args:
        src_tokens: source inputs, passed positionally to AsrDataset.
        src_lengths: lengths of ``src_tokens``, passed positionally to AsrDataset.

    Returns:
        An AsrDataset constructed from only these two arguments.
    """
    inference_dataset = AsrDataset(src_tokens, src_lengths)
    return inference_dataset
def get_asr_dataset_from_json(
    data_path,
    split,
    tgt_dict,
    combine,
    upsample_primary,
    max_source_positions,
    max_target_positions,
    seed=1,
    specaugment_config=None,
):
    """
    Parse data json and create dataset.
    See espresso/tools/asr_prep_json.py which pack json from raw files
    Json example:
    {
        "011c0202": {
            "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819",
            "token_text": "T H E <space> H O T E L",
            "utt2num_frames": "693",
        },
        "011c0203": {
            ...
        }
    }

    Args:
        data_path: directory containing ``{split}.json`` and, when ``combine``
            is true, optionally ``{split}1.json``, ``{split}2.json``, ...
        split: split name; SpecAugment is only enabled when it equals "train".
        tgt_dict: target dictionary; required if any entry has "token_text".
        combine: if true, keep loading numbered json files until one is missing.
        upsample_primary: sample ratio of the first json when several are combined.
        max_source_positions: forwarded to AsrDataset.
        max_target_positions: forwarded to AsrDataset.
        seed: forwarded to FeatScpCachedDataset.
        specaugment_config: forwarded to FeatScpCachedDataset for "train" only.

    Returns:
        AsrDataset over the (possibly concatenated) source/target datasets;
        the target side is None when no entry has "token_text".

    Raises:
        FileNotFoundError: if ``{split}.json`` does not exist.
        KeyError: if an entry has no "feat" field.
    """
    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        data_json_path = os.path.join(data_path, "{}.json".format(split_k))
        if not os.path.isfile(data_json_path):
            if k > 0:
                # ran out of numbered shards
                break
            raise FileNotFoundError("Dataset not found: {}".format(data_json_path))

        with open(data_json_path, "rb") as f:
            loaded_json = json.load(f, object_pairs_hook=OrderedDict)

        utt_ids, feats, token_text, utt2num_frames = [], [], [], []
        for utt_id, val in loaded_json.items():
            if "feat" not in val:
                # name the offending utterance and file instead of a bare KeyError('feat')
                raise KeyError(
                    f"'feat' should be present as a field for the entry {utt_id} in {data_json_path}"
                )
            utt_ids.append(utt_id)
            feats.append(val["feat"])
            if "token_text" in val:
                token_text.append(val["token_text"])
            if "utt2num_frames" in val:
                utt2num_frames.append(int(val["utt2num_frames"]))

        # utt2num_frames must be either absent everywhere or present everywhere
        assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames)
        src_datasets.append(
            FeatScpCachedDataset(
                utt_ids,
                feats,
                utt2num_frames=utt2num_frames,
                seed=seed,
                specaugment_config=specaugment_config if split == "train" else None,
                ordered_prefetch=True,
            )
        )
        if len(token_text) > 0:
            assert len(utt_ids) == len(token_text)
            assert tgt_dict is not None
            tgt_datasets.append(AsrTextDataset(utt_ids, token_text, tgt_dict))

        logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1])))

        if not combine:
            break

    # either every shard has transcripts or none does
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    feat_dim = src_datasets[0].feat_dim
    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        for i in range(1, len(src_datasets)):
            assert feat_dim == src_datasets[i].feat_dim, \
                "feature dimension does not match across multiple json files"
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return AsrDataset(
        src_dataset,
        src_dataset.sizes,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=False,
        left_pad_target=False,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
    )
def get_asr_dataset_from_json(
    data_path,
    split,
    tgt_dict,
    combine,
    upsample_primary=1,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    seed=1,
    global_cmvn_stats_path=None,
    specaugment_config=None,
):
    """
    Parse data json and create dataset.
    See espresso/tools/asr_prep_json.py which pack json from raw files
    Json example:
    {
        "011c0202": {
            "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819" or
            "wave": "/export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1" or
            "command": "sph2pipe -f wav /export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1 |",
            "text": "THE HOTEL",
            "utt2num_frames": "693",
        },
        "011c0203": {
            ...
        }
    }

    Args:
        data_path: directory containing ``{split}.json`` and, when ``combine``
            is true, optionally ``{split}1.json``, ``{split}2.json``, ...
        split: split name; SpecAugment is only enabled when it equals "train".
        tgt_dict: target dictionary; required if any entry has "text".
        combine: if true, keep loading numbered json files until one is missing.
        upsample_primary: sample ratio of the first json when several are combined.
        num_buckets, shuffle, pad_to_multiple: forwarded to AsrDataset.
        seed: forwarded to AudioFeatDataset.
        global_cmvn_stats_path: if not None, npz stats file used to build a
            "global_cmvn" feature transform passed to AudioFeatDataset.
        specaugment_config: forwarded to AudioFeatDataset for "train" only.

    Returns:
        AsrDataset over the (possibly concatenated) source/target datasets;
        the target side is None when no entry has "text".

    Raises:
        FileNotFoundError: if ``{split}.json`` does not exist.
        KeyError: if an entry has none of "feat", "wave" or "command".
    """
    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        data_json_path = os.path.join(data_path, "{}.json".format(split_k))
        if not os.path.isfile(data_json_path):
            if k > 0:
                # ran out of numbered shards
                break
            raise FileNotFoundError("Dataset not found: {}".format(data_json_path))

        with open(data_json_path, "rb") as f:
            loaded_json = json.load(f, object_pairs_hook=OrderedDict)

        utt_ids, audios, texts, utt2num_frames = [], [], [], []
        for utt_id, val in loaded_json.items():
            utt_ids.append(utt_id)
            # precedence: precomputed features, then wave file, then command pipe
            if "feat" in val:
                audio = val["feat"]
            elif "wave" in val:
                audio = val["wave"]
            elif "command" in val:
                audio = val["command"]
            else:
                raise KeyError(
                    f"'feat', 'wave' or 'command' should be present as a field for the entry {utt_id} in {data_json_path}"
                )
            audios.append(audio)
            if "text" in val:
                texts.append(val["text"])
            if "utt2num_frames" in val:
                utt2num_frames.append(int(val["utt2num_frames"]))

        # utt2num_frames must be either absent everywhere or present everywhere
        assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames)

        # Decide the input type from the first entry's *dict*. The previous
        # `next(iter(loaded_json.items()))` yielded a (utt_id, entry) tuple, so
        # `"feat" in ...` was essentially never true and precomputed features
        # were always treated as raw audio. Defaulting to {} avoids a
        # StopIteration on an empty json file.
        first_entry = next(iter(loaded_json.values()), {})
        if "feat" in first_entry:
            extra_kwargs = {}
        else:
            # NOTE(review): presumably tells AudioFeatDataset to extract
            # 80-dim fbank from wave/command inputs — confirm in AudioFeatDataset
            extra_kwargs = {"feat_dim": 80, "feature_type": "fbank"}
        if global_cmvn_stats_path is not None:
            extra_kwargs["feature_transforms_config"] = {
                "transforms": ["global_cmvn"],
                "global_cmvn": {"stats_npz_path": global_cmvn_stats_path},
            }

        src_datasets.append(
            AudioFeatDataset(
                utt_ids,
                audios,
                utt2num_frames=utt2num_frames,
                seed=seed,
                specaugment_config=specaugment_config if split == "train" else None,
                **extra_kwargs,
            )
        )
        if len(texts) > 0:
            assert len(utt_ids) == len(texts)
            assert tgt_dict is not None
            tgt_datasets.append(AsrTextDataset(utt_ids, texts, tgt_dict))

        logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1])))

        if not combine:
            break

    # either every shard has transcripts or none does
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    feat_dim = src_datasets[0].feat_dim
    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        for i in range(1, len(src_datasets)):
            assert (
                feat_dim == src_datasets[i].feat_dim
            ), "feature dimension does not match across multiple json files"
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return AsrDataset(
        src_dataset,
        src_dataset.sizes,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=False,
        left_pad_target=False,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )