import json
import os
import re

# NOTE: AsrDataset (and W2VAsrDataset in the variant below) are assumed to be
# importable from the project's data module; adjust the import to the actual
# package layout.


def get_asr_dataset_from_json(data_json_path, tgt_dict):
    """
    Parse data json and create dataset.
    See scripts/asr_prep_json.py, which packs the json from raw files.

    Json example:
    {
        "utts": {
            "4771-29403-0025": {
                "input": {
                    "length_ms": 170,
                    "path": "/tmp/file1.flac"
                },
                "output": {
                    "text": "HELLO \n",
                    "token": "HE LLO",
                    "tokenid": "4815, 861"
                }
            },
            "1564-142299-0096": {
                ...
            }
        }
    }
    """
    if not os.path.isfile(data_json_path):
        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
    with open(data_json_path, "rb") as f:
        data_samples = json.load(f)["utts"]
    assert len(data_samples) != 0
    # Sort utterances by duration, longest first.
    sorted_samples = sorted(
        data_samples.items(),
        key=lambda sample: int(sample[1]["input"]["length_ms"]),
        reverse=True,
    )
    tgt, aud_paths, speakers, frame_sizes, ids = [[] for _ in range(5)]
    for s in sorted_samples:
        try:
            res = [int(i) for i in s[1]["output"]["tokenid"].split(", ")]
        except (KeyError, ValueError):
            # Drop utterances whose token ids are missing or unparsable.
            continue
        tgt.append(res)
        aud_paths.append(s[1]["input"]["path"])
        ids.append(s[0])
        # Derive the speaker id from the utterance id: either a
        # "speaker-chapter-utterance" style id or an id of the form
        # "BAC...S...W..." (e.g. AISHELL-style ids).
        if "-" in s[0]:
            m = re.search("(.+?)-(.+?)-(.+?)", s[0])
        else:
            m = re.search("(BAC[0-9]+)(S[0-9]+)(W[0-9]+)", s[0])
        speakers.append(m.group(1) + "_" + m.group(2))
        frame_sizes.append(s[1]["input"]["length_ms"])
    print(
        "loaded {} samples, dropped {}".format(
            len(tgt), len(sorted_samples) - len(tgt)
        )
    )
    # Append eos to each target sequence.
    tgt = [[*t, tgt_dict.eos()] for t in tgt]
    return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
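# A minimal usage sketch, not part of the original source: it assumes the json
# was produced by scripts/asr_prep_json.py, that tgt_dict is a fairseq
# Dictionary (which provides the eos() used above), and that the file paths
# below are placeholders.
if __name__ == "__main__":
    from fairseq.data import Dictionary

    tgt_dict = Dictionary.load("data/lang/dict.txt")  # placeholder dictionary path
    dataset = get_asr_dataset_from_json("data/train.json", tgt_dict)  # placeholder json path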
def get_asr_dataset_from_json(data_json_path, tgt_dict, wav2vec):
    """
    Parse data json and create dataset.
    See scripts/asr_prep_json.py, which packs the json from raw files.

    If ``wav2vec`` is True, each "input" entry is also expected to provide
    "wav2vec_path" and "wav2vec_num_tokens", and a W2VAsrDataset built from
    pre-extracted wav2vec embeddings is returned instead of an AsrDataset.

    Json example:
    {
        "utts": {
            "4771-29403-0025": {
                "input": {
                    "length_ms": 170,
                    "path": "/tmp/file1.flac"
                },
                "output": {
                    "text": "HELLO \n",
                    "token": "HE LLO",
                    "tokenid": "4815, 861"
                }
            },
            "1564-142299-0096": {
                ...
            }
        }
    }
    """
    if not os.path.isfile(data_json_path):
        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
    with open(data_json_path, "rb") as f:
        data_samples = json.load(f)["utts"]
    assert len(data_samples) != 0
    # Sort utterances by duration, longest first.
    sorted_samples = sorted(
        data_samples.items(),
        key=lambda sample: int(sample[1]["input"]["length_ms"]),
        reverse=True,
    )
    ids = [s[0] for s in sorted_samples]
    speakers = []
    for s in sorted_samples:
        # Derive the speaker id from a "speaker-chapter-utterance" style id.
        m = re.search("(.+?)-(.+?)-(.+?)", s[0])
        speakers.append(m.group(1) + "_" + m.group(2))
    tgt = [
        [int(i) for i in s[1]["output"]["tokenid"].split(", ")]
        for s in sorted_samples
    ]
    # Append eos to each target sequence.
    tgt = [[*t, tgt_dict.eos()] for t in tgt]
    if wav2vec:
        # Use pre-extracted wav2vec embeddings instead of raw audio.
        emb_paths = [s[1]["input"]["wav2vec_path"] for s in sorted_samples]
        emb_num_tokens = [
            s[1]["input"]["wav2vec_num_tokens"] for s in sorted_samples
        ]
        return W2VAsrDataset(emb_paths, emb_num_tokens, tgt, tgt_dict, ids, speakers)
    else:
        aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
        frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
        return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
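# Illustrative shape, not from the original source, of one "utts" entry that
# the wav2vec=True branch reads: "wav2vec_path" and "wav2vec_num_tokens" are
# the keys accessed above, while the concrete values and file names here are
# made up.
_example_wav2vec_utt = {
    "input": {
        "length_ms": 170,
        "path": "/tmp/file1.flac",
        "wav2vec_path": "/tmp/file1.w2v.pt",  # hypothetical embedding file
        "wav2vec_num_tokens": 530,            # hypothetical embedding length
    },
    "output": {
        "text": "HELLO \n",
        "token": "HE LLO",
        "tokenid": "4815, 861",
    },
}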