def get_asr_dataset_from_json(
    data_path, split, tgt_dict, combine, upsample_primary,
    max_source_positions, max_target_positions,
    seed=1, specaugment_config=None,
):
    """Build an AsrDataset from one or more ``<split>.json`` shard files.

    See espresso/tools/asr_prep_json.py which packs such json from raw files.
    Json example:
    {
        "011c0202": {
            "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819",
            "token_text": "T H E <space> H O T E L",
            "utt2num_frames": "693",
        },
        "011c0203": {...}
    }
    """
    src_datasets, tgt_datasets = [], []
    # Shards are named "<split>.json", "<split>1.json", "<split>2.json", ...
    for shard in itertools.count():
        shard_split = split if shard == 0 else split + str(shard)
        data_json_path = os.path.join(data_path, "{}.json".format(shard_split))
        if not os.path.isfile(data_json_path):
            # The primary shard must exist; higher-numbered shards are optional.
            if shard == 0:
                raise FileNotFoundError(
                    "Dataset not found: {}".format(data_json_path))
            break

        with open(data_json_path, "rb") as fin:
            examples = json.load(fin, object_pairs_hook=OrderedDict)

        utt_ids, feats, token_text, utt2num_frames = [], [], [], []
        for utt_id, entry in examples.items():
            utt_ids.append(utt_id)
            feats.append(entry["feat"])
            if "token_text" in entry:
                token_text.append(entry["token_text"])
            if "utt2num_frames" in entry:
                utt2num_frames.append(int(entry["utt2num_frames"]))

        # Frame counts, when present at all, must cover every utterance.
        assert not utt2num_frames or len(utt2num_frames) == len(utt_ids)
        src_datasets.append(
            FeatScpCachedDataset(
                utt_ids, feats, utt2num_frames=utt2num_frames, seed=seed,
                # SpecAugment is a training-time augmentation only.
                specaugment_config=specaugment_config if split == "train" else None,
                ordered_prefetch=True,
            )
        )
        if token_text:
            assert len(token_text) == len(utt_ids)
            assert tgt_dict is not None
            tgt_datasets.append(AsrTextDataset(utt_ids, token_text, tgt_dict))

        logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1])))

        if not combine:
            break

    # Targets are either present for every shard or absent everywhere.
    assert len(tgt_datasets) in (0, len(src_datasets))

    feat_dim = src_datasets[0].feat_dim
    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if tgt_datasets else None
    else:
        for other in src_datasets[1:]:
            assert feat_dim == other.feat_dim, \
                "feature dimension does not match across multiple json files"
        # The primary (first) shard may be upsampled relative to the rest.
        sample_ratios = [upsample_primary] + [1] * (len(src_datasets) - 1)
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = (
            ConcatDataset(tgt_datasets, sample_ratios) if tgt_datasets else None
        )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return AsrDataset(
        src_dataset, src_dataset.sizes,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=False,
        left_pad_target=False,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
    )
def _asr_dataset_helper(
    self, all_in_memory=False, ordered_prefetch=False, has_utt2num_frames=False,
):
    """Shared checks for an AsrDataset built from cached vs. in-memory features."""
    frame_counts = self.utt2num_frames if has_utt2num_frames else None
    if all_in_memory:
        src_dataset = FeatScpInMemoryDataset(
            utt_ids=self.feats_utt_ids,
            rxfiles=self.rxfiles,
            utt2num_frames=frame_counts,
        )
    else:
        src_dataset = FeatScpCachedDataset(
            utt_ids=self.feats_utt_ids,
            rxfiles=self.rxfiles,
            utt2num_frames=frame_counts,
            ordered_prefetch=ordered_prefetch,
            cache_size=self.cache_size,
        )
    tgt_dataset = AsrTextDataset(
        utt_ids=self.text_utt_ids,
        token_text=self.token_text,
        dictionary=self.dictionary,
    )
    dataset = AsrDataset(
        src_dataset, src_dataset.sizes,
        tgt_dataset, tgt_dataset.sizes,
        self.dictionary,
        left_pad_source=False,
        left_pad_target=False,
        max_source_positions=1000,
        max_target_positions=200,
    )

    # assume one is a subset of the other
    expected_dataset_size = min(self.num_audios, self.num_transripts)
    self.assertEqual(len(dataset.src), expected_dataset_size)
    self.assertEqual(len(dataset.tgt), expected_dataset_size)

    indices = list(range(expected_dataset_size))
    batch_sampler = [
        indices[start:start + self.batch_size]
        for start in range(0, expected_dataset_size, self.batch_size)
    ]

    # The cached variant requires an explicit prefetch pass before iterating.
    if not all_in_memory:
        dataset.prefetch(indices)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
    )
    for batch_idx, batch in enumerate(dataloader):
        bsz = batch["nsentences"]
        self.assertEqual(bsz, len(batch_sampler[batch_idx]))
        src_frames = batch["net_input"]["src_tokens"]
        src_lengths = batch["net_input"]["src_lengths"]
        tgt_tokens = self.dictionary.string(
            batch["target"], extra_symbols_to_ignore={self.dictionary.pad()}
        ).split("\n")
        tgt_tokens = [line.split(" ") for line in tgt_tokens]
        self.assertEqual(bsz, src_frames.size(0))
        self.assertEqual(bsz, src_lengths.numel())
        self.assertEqual(bsz, len(tgt_tokens))
        for j, utt_id in enumerate(batch["utt_id"]):
            # Padded frames beyond src_lengths[j] are excluded from the compare.
            self.assertTensorEqual(
                torch.from_numpy(self.expected_feats[utt_id]).float(),
                src_frames[j, :src_lengths[j], :],
            )
            self.assertEqual(
                self.expected_text[utt_id],
                tgt_tokens[j],
            )
def get_asr_dataset_from_json(
    data_path, split, dictionary, combine, upsample_primary,
    num_buckets=0, shuffle=True, lf_mmi=True,
    seed=1, specaugment_config=None,
    chunk_width=None, chunk_left_context=None, chunk_right_context=None,
    label_delay=0,
):
    """
    Parse data json and create dataset for hybrid (LF-MMI or cross-entropy) training.
    When ``lf_mmi`` is True, targets are numerator FSTs and an
    :class:`AsrChainDataset` is returned; otherwise targets are frame-level
    alignments and an :class:`AsrXentDataset` is returned.
    See espresso/tools/asr_prep_json.py which pack json from raw files
    Json example:
    {
        "011c0202": {
            "feat": "data/train_si284_spe2e_hires/data/raw_mfcc_train_si284_spe2e_hires.1.ark:24847",
            "numerator_fst": "exp/chain/e2e_bichar_tree_tied1a/fst.1.ark:6704",
            "alignment": "exp/tri3/ali.ark:8769",
            "text": "THE HOTEL OPERATOR'S EMBASSY",
            "utt2num_frames": "693",
        },
        "011c0203": {...}
    }
    """
    src_datasets = []
    tgt_datasets = []
    text_datasets = []
    # Shards are named "<split>.json", "<split>1.json", "<split>2.json", ...
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        data_json_path = os.path.join(data_path, "{}.json".format(split_k))
        if not os.path.isfile(data_json_path):
            # Only the first shard is mandatory; later shards are optional.
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {}".format(data_json_path))

        with open(data_json_path, "rb") as f:
            # OrderedDict preserves the utterance order of the json file.
            loaded_json = json.load(f, object_pairs_hook=OrderedDict)

        utt_ids, feats, numerator_fsts, alignments, text, utt2num_frames = [], [], [], [], [], []
        for utt_id, val in loaded_json.items():
            utt_ids.append(utt_id)
            feats.append(val["feat"])
            # All fields below are optional per utterance.
            if "numerator_fst" in val:
                numerator_fsts.append(val["numerator_fst"])
            if "alignment" in val:
                alignments.append(val["alignment"])
            if "text" in val:
                text.append(val["text"])
            if "utt2num_frames" in val:
                utt2num_frames.append(int(val["utt2num_frames"]))

        # Frame counts, if present at all, must cover every utterance.
        assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames)
        src_datasets.append(
            FeatScpCachedDataset(
                utt_ids, feats, utt2num_frames=utt2num_frames, seed=seed,
                # SpecAugment is a training-time augmentation only.
                specaugment_config=specaugment_config if split == "train" else None,
                ordered_prefetch=True,
            ))
        if lf_mmi:
            # LF-MMI targets: per-utterance numerator graphs.
            if len(numerator_fsts) > 0:
                assert len(utt_ids) == len(numerator_fsts)
                tgt_datasets.append(
                    NumeratorGraphDataset(utt_ids, numerator_fsts))
        else:  # cross-entropy
            # Xent targets: frame-level alignments.
            if len(alignments) > 0:
                assert len(utt_ids) == len(alignments)
                tgt_datasets.append(
                    AliScpCachedDataset(utt_ids, alignments,
                                        utt2num_frames=utt2num_frames,
                                        ordered_prefetch=True))

        # Optional word-level transcripts (no EOS appended).
        if len(text) > 0:
            assert len(utt_ids) == len(text)
            text_datasets.append(
                AsrTextDataset(utt_ids, text, dictionary, append_eos=False))

        logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1])))

        if not combine:
            break

    # Optional streams are either present in every shard or in none.
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
    assert len(src_datasets) == len(text_datasets) or len(text_datasets) == 0

    feat_dim = src_datasets[0].feat_dim

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
        text_dataset = text_datasets[0] if len(text_datasets) > 0 else None
    else:
        for i in range(1, len(src_datasets)):
            assert feat_dim == src_datasets[i].feat_dim, \
                "feature dimension does not match across multiple json files"
        # The primary (first) shard may be upsampled relative to the rest.
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None
        if len(text_datasets) > 0:
            text_dataset = ConcatDataset(text_datasets, sample_ratios)
        else:
            text_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    if lf_mmi:
        return AsrChainDataset(
            src_dataset, src_dataset.sizes,
            tgt_dataset, tgt_dataset_sizes,
            text=text_dataset,
            num_buckets=num_buckets,
            shuffle=shuffle,
        )
    else:
        return AsrXentDataset(
            src_dataset, src_dataset.sizes,
            tgt_dataset, tgt_dataset_sizes,
            text=text_dataset,
            num_buckets=num_buckets,
            shuffle=shuffle,
            seed=seed,
            chunk_width=chunk_width,
            chunk_left_context=chunk_left_context,
            chunk_right_context=chunk_right_context,
            label_delay=label_delay,
            # Chunks are sampled randomly only while training.
            random_chunking=(split == "train" and chunk_width is not None),
        )