def filter_datasets_by_stt(data_paths: List[str], metadata_path: str, stt_path: str, save_path: str) -> None:
    """This function takes in a list of dataset paths `data_paths`, combines them, and filters out
    the examples where the transcript does not equal the transcript from Google's speech-to-text API
    saved at `stt_path`. The filtered dataset (a filtered subset of the union of the datasets in
    `data_paths`) is written to `save_path`.

    Args:
        data_paths: list of datasets to combine and filter to output
        metadata_path: path to speak metadata tsv file
        stt_path: path to speech-to-text saved output from `stt_on_datasets` function
        save_path: path where filtered examples will be saved
    """
    data_dict = combine_sort_datasets(data_paths)
    metadata = get_record_ids_map(metadata_path, has_url=True)
    stt_data = read_data_json(stt_path)

    filtered_data = list()
    count = {"total": 0, "filtered": 0}

    for datum in stt_data:
        audio_id = path_to_id(datum['audio_path'])
        spk_trans = process_text(metadata[audio_id]['target_sentence'])
        ggl_trans = process_text(datum['transcript'])
        count['total'] += 1
        if spk_trans == ggl_trans:
            count['filtered'] += 1
            filtered_data.append(data_dict[datum['audio_path']])

    write_data_json(filtered_data, save_path)

    print(f"number of total de-duplicated examples: {count['total']}")
    print(f"number of filtered examples: {count['filtered']}")
def main(data_path: str) -> None:
    """Prints various stats, like the total audio length, for the input dataset in `data_path`.

    Args:
        data_path (str): path to the dataset
    """
    dataset = read_data_json(data_path)

    # iterate through the dataset
    durations = list()
    path_counter = dict()
    data_disk_prefix = "/mnt/disks/data_disk/home"
    for elem in dataset:
        durations.append(elem['duration'])
        path = elem['audio']
        assert path.startswith(data_disk_prefix), f"path {path} is not a data disk path"
        path_counter[path] = path_counter.get(path, 0) + 1

    # print out the total time
    total_sec = sum(durations)
    total_hr = round(total_sec / 3600, 2)
    data_name = os.path.basename(data_path)
    print(f"total duration for {data_name}: {total_hr} hrs")

    # check if any paths occurred more than once
    dup_paths = {path: count for path, count in path_counter.items() if count > 1}
    print(f"duplicated paths: {dup_paths}")
def download_sample_from_GCP():
    """This function downloads a random sample of audio from processed data in Google Cloud."""
    import subprocess

    sample_size = 100
    src_dataset_path = "/home/dzubke/awni_speech/data/speak_test_data/eval/eval2/eval2_data_2020-12-05.json"
    vm_string = "dzubke@phoneme-3:"
    full_dataset_string = vm_string + src_dataset_path
    dst_dir = "/Users/dustin/CS/work/speak/src/data/speak_eval/eval2/v1/"
    cmd_base = ["gcloud", "compute", "scp"]

    # ensure the dst directory exists
    os.makedirs(dst_dir, exist_ok=True)

    # copy the dataset from the VM
    subprocess.run(cmd_base + [full_dataset_string, dst_dir])
    print(f"copied dataset {full_dataset_string} to {dst_dir}")

    dataset = read_data_json(os.path.join(dst_dir, os.path.basename(src_dataset_path)))
    data_subset = random.sample(dataset, sample_size)

    dst_audio_dir = os.path.join(dst_dir, "audio")
    os.makedirs(dst_audio_dir, exist_ok=True)
    print("destination audio dir: ", dst_audio_dir)

    for xmpl in data_subset:
        audio_path = xmpl['audio']
        full_audio_string = vm_string + audio_path
        dst_audio_path = os.path.join(dst_audio_dir, os.path.basename(audio_path))
        subprocess.run(cmd_base + [full_audio_string, dst_audio_path])
def __init__(self, data_json, preproc, batch_size):
    """This code sorts the samples in `data_json` based on the length of the transcript labels
    and the audio sample duration. It does this by creating a number of buckets, sorting the
    samples into the buckets based on the length of their labels, and then sorting each bucket
    by the duration of the audio samples.
    """
    data = read_data_json(data_json)    # loads the data_json into a list
    self.preproc = preproc              # assign the preproc object

    bucket_diff = 4                                  # label-length range covered by each bucket
    max_len = max(len(x['text']) for x in data)      # max number of phoneme labels in data
    num_buckets = max_len // bucket_diff             # the number of buckets
    buckets = [[] for _ in range(num_buckets)]       # creating an empty list for each bucket

    for sample in data:
        bucket_id = min(len(sample['text']) // bucket_diff, num_buckets - 1)
        buckets[bucket_id].append(sample)

    sort_fn = lambda x: (round(x['duration'], 1), len(x['text']))
    for bucket in buckets:
        bucket.sort(key=sort_fn)

    # unpack the data in the buckets into a list
    data = [sample for bucket in buckets for sample in bucket]
    self.data = data
    print(f"in AudioDataset: length of data: {len(data)}")
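# --- illustrative sketch (not part of the original module) ---
# A minimal, self-contained rehearsal of the bucketing-then-sorting idea used in the
# __init__ above. The sample records here are made up; only the 'text' and 'duration'
# keys mirror the dataset schema that the loader assumes.
def _example_bucket_sort() -> None:
    samples = [
        {"text": ["AA", "B"], "duration": 1.23},
        {"text": ["AA", "B", "CH", "D", "EH", "F"], "duration": 0.87},
        {"text": ["AA", "B", "CH", "D", "EH", "F", "G", "HH", "IH"], "duration": 2.4},
        {"text": ["AA"], "duration": 1.21},
    ]
    bucket_diff = 4
    max_len = max(len(s["text"]) for s in samples)
    num_buckets = max_len // bucket_diff
    buckets = [[] for _ in range(num_buckets)]
    for s in samples:
        # short transcripts land in low buckets, long transcripts in the last bucket
        buckets[min(len(s["text"]) // bucket_diff, num_buckets - 1)].append(s)
    for bucket in buckets:
        bucket.sort(key=lambda s: (round(s["duration"], 1), len(s["text"])))
    # prints [[1, 2], [6, 9]]: two buckets, each sorted by duration then label length
    print([[len(s["text"]) for s in bucket] for bucket in buckets])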
def dataset_duration(data_path: str) -> float:
    """Returns the total time (in hours) of the input data path.

    Args:
        data_path: path to dataset

    Returns:
        (float): total duration of the dataset in hours
    """
    data = read_data_json(data_path)
    total_duration_s = sum([xmpl['duration'] for xmpl in data])

    return round(total_duration_s / 3600, 3)
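# --- illustrative sketch (not part of the original module) ---
# The json-lines schema assumed by the `read_data_json` consumers in this file appears to be
# one dict per line with at least 'audio', 'text', and 'duration' keys. The records and the
# path below are hypothetical; real dataset lines may carry additional keys.
def _example_dataset_duration(tmp_path: str = "/tmp/example_dataset.json") -> None:
    import json
    example_records = [
        {"audio": "/data/speak/audio/rec-001.wav", "text": ["DH", "IH", "S"], "duration": 1.8},
        {"audio": "/data/speak/audio/rec-002.wav", "text": ["AH"], "duration": 0.9},
    ]
    with open(tmp_path, 'w') as fid:
        for record in example_records:
            fid.write(json.dumps(record) + "\n")
    # 2.7 seconds total -> 0.001 hours after rounding to 3 decimal places
    print(dataset_duration(tmp_path))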
def get_dataset_ids(dataset_path: str) -> Set[str]:
    """This function reads a dataset path and returns a set of the record ID's in that dataset.
    The record ID's mainly correspond to recordings from the speak dataset. For other datasets,
    this function will return the filename without the extension.

    Args:
        dataset_path (str): path to the dataset

    Returns:
        Set[str]: a set of the record ID's
    """
    # dataset is a list of dictionaries with the audio path as the value of the 'audio' key
    dataset = read_data_json(dataset_path)

    return set([path_to_id(xmpl['audio']) for xmpl in dataset])
def output_dict_from_json(json_dataset_path: str) -> dict:
    """This function returns a formatted output dict using a json dataset path."""
    output_dict = {}    # dictionary containing the printed outputs
    dataset = read_data_json(json_dataset_path)

    for xmpl in dataset:
        record_id = path_to_id(xmpl['audio'])
        output_dict[record_id] = {
            "header": {"labels": " ".join(xmpl['text'])},
            "reference_phones": xmpl['text']    # used for PER calculation
        }

    return output_dict
def main(config: dict):
    data_cfg = config.get('data')
    log_cfg = config.get('logger')
    preproc_cfg = config.get('preproc')

    # create logger
    logger = logging.getLogger("sig_aug")
    logger.setLevel(logging.DEBUG)
    # create file handler which logs even debug messages
    fh = logging.FileHandler(log_cfg["log_file"])
    fh.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S")
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    logger.info(f"config:\n{config}")

    dataset = read_data_json(data_cfg['data_set'])
    audio_list = [example['audio'] for example in dataset]
    audio_subset = random.sample(audio_list, data_cfg['num_examples'])

    for audio_path in audio_subset:
        aug_audio_data, samp_rate = apply_augmentation(audio_path, preproc_cfg, logger)

        os.makedirs(os.path.join(data_cfg['save_dir'], "aug"), exist_ok=True)
        os.makedirs(os.path.join(data_cfg['save_dir'], "org"), exist_ok=True)

        # save the augmented file
        basename = os.path.basename(audio_path)
        save_path = os.path.join(data_cfg['save_dir'], "aug", basename)
        array_to_wave(save_path, aug_audio_data, samp_rate)

        # copy the original audio file for comparison
        save_org_path = os.path.join(data_cfg['save_dir'], "org", basename)
        shutil.copyfile(audio_path, save_org_path)

        if preproc_cfg['play_audio']:
            print(f"sample rate: {samp_rate}")
            print(f"Saved to: {save_path}")
            print("Playing original audio...")
            os_play(audio_path)
            print("Playing augmented audio...")
            os_play(save_path)
def combine_sort_datasets(data_paths: List[str]) -> Dict[str, dict]:
    """Combines all examples in the datasets in `data_paths` and creates a de-duplicated,
    sorted dict mapping audio_paths to examples.
    """
    data_dict = dict()
    total_xmpls = 0
    for data_path in data_paths:
        dataset = read_data_json(data_path)
        total_xmpls += len(dataset)
        for xmpl in dataset:
            if xmpl['audio'] not in data_dict:
                data_dict[xmpl['audio']] = xmpl
            else:
                # checks that the same entry in different datasets has the same labels
                assert data_dict[xmpl['audio']]['text'] == xmpl['text'], \
                    "same entry in different datasets differs in phoneme labels"

    print(f"number of total examples: {total_xmpls}")

    return OrderedDict((audio, xmpl) for audio, xmpl in sorted(data_dict.items()))
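# --- illustrative sketch (not part of the original module) ---
# A hypothetical call to `combine_sort_datasets`; the paths are placeholders. The returned
# OrderedDict is keyed by audio path, de-duplicated across the inputs, and sorted by path,
# which is what `filter_datasets_by_stt` relies on when it looks up `data_dict[audio_path]`.
def _example_combine_sort() -> None:
    data_dict = combine_sort_datasets([
        "/data/speak/train_v1.json",
        "/data/speak/train_v2.json",
    ])
    first_audio_path = next(iter(data_dict))
    print(f"{len(data_dict)} unique examples; first audio path: {first_audio_path}")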
def get_id_sets(dataset_paths: List[str]) -> Dict[str, Tuple[set, int]]:
    """This function returns a dictionary with the dataset name as the keys and a tuple of
    the set of record-ids and the dataset size as the values.

    Args:
        dataset_paths (List[str]): a list of dataset paths (str)

    Returns:
        Dict[str, Tuple[set, int]]: a dict with a (record-id set, dataset size) tuple as values
    """
    data_dict = dict()
    for data_path in dataset_paths:
        # _extract_id on the data path will return the dataset name
        data_name = _extract_id(data_path)
        dataset = read_data_json(data_path)
        # set comprehension that extracts the record-id from each audio path in the dataset
        id_set = {_extract_id(xmpl['audio']) for xmpl in dataset}
        data_dict.update({data_name: (id_set, len(dataset))})

    return data_dict
def assess_from_json(eval_phn_path, ds_json_path):
    ds_preds = read_data_json(ds_json_path)
    rec_to_eval_phns = read_eval_file(eval_phn_path)

    for xmpl in ds_preds:
        ref_phns = xmpl['label']
        hyp_phns = xmpl['prediction']
        edit_ops = get_editops(hyp_phns, ref_phns)
        rec_id = path_to_id(xmpl['filename'])
        rec_id, has_mispro, eval_phns = rec_to_eval_phns[rec_id]
        for eval_phn in eval_phns:
            print(f"record id: {rec_id}")
            print(f"evaluation phone: {eval_phn}")
            print(f"has mispro: {bool(has_mispro)}")
            print_editops(edit_ops, hyp_phns, ref_phns)
            mispro_detected = check_mispro(edit_ops, hyp_phns, ref_phns, eval_phn)
            print(f"mispro detected?: {mispro_detected}")
            print(f"detector is correct?: {has_mispro == mispro_detected}")
            print('\n\n')
def review_audio(data_path: str, gcsfuse_dir: str, restart_path: str = None):
    """
    Args:
        data_path: path to dataset
        gcsfuse_dir: path to directory connected to gcs bucket
        restart_path: path of example that the script will start from
    """
    dataset = read_data_json(data_path)

    # start from restart_path
    if restart_path is not None:
        restart_id = path_to_id(restart_path)
        restart_idx = 0
        for i, xmpl in enumerate(dataset):
            if path_to_id(xmpl['audio']) == restart_id:
                restart_idx = i
                break
        dataset = dataset[restart_idx:]

    for xmpl in dataset:
        next_recording = False
        while not next_recording:
            print('\n\n')
            print(xmpl['text'])
            play_fn(xmpl['audio'], gcsfuse_dir)
            print("(f) next rec, (j) play again, (p) print full entry")
            action = input()
            if action == 'f':
                next_recording = True
            elif action == 'j':
                pass
            elif action == 'p':
                print(xmpl)
            else:
                print("invalid entry")
def match_filename(label: list, dataset_json: str, return_order=False) -> str:
    """Returns the filename in `dataset_json` that matches the phonemes in `label`."""
    dataset = read_data_json(dataset_json)

    matches = []
    for i, sample in enumerate(dataset):
        if sample['text'] == label:
            matches.append(sample["audio"])
            order = i

    if len(matches) > 1:
        print(f"multiple matches found {matches} for label {label}")
        print("Would you like to continue? (y/n)")
        response = input()
        if response.lower() == "n":
            raise AssertionError

    assert len(matches) > 0, f"no matches found for {label}"

    if return_order:
        output = (matches[0], order)
    else:
        output = matches[0]

    return output
def dataset_stats(dataset_path: str) -> None:
    """This function prints a variety of stats (like mean and std-dev) for the input dataset.

    Args:
        dataset_path (str): path to the dataset
    """
    dataset = read_data_json(dataset_path)

    data_features = {
        "target_len": [len(xmpl['text']) for xmpl in dataset],
        "audio_dur": [xmpl['duration'] for xmpl in dataset]
    }
    stat_functions = {
        "mean": np.mean,
        "stddev": np.std,
    }

    print(f"stats for dataset: {os.path.basename(dataset_path)}")
    for data_name, data in data_features.items():
        for stat_name, stat_fn in stat_functions.items():
            print(f"\t{stat_name} of {data_name} is: {round(stat_fn(data), 3)}")
        print()
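# --- illustrative sketch (not part of the original module) ---
# Because the reporting loop above is driven by the `stat_functions` dict, extra statistics
# can be added by extending that dict. A hypothetical variant that also reports the median
# of the audio durations:
def _example_dataset_stats_with_median(dataset_path: str) -> None:
    dataset = read_data_json(dataset_path)
    durations = [xmpl['duration'] for xmpl in dataset]
    stat_functions = {"mean": np.mean, "stddev": np.std, "median": np.median}
    for stat_name, stat_fn in stat_functions.items():
        print(f"\t{stat_name} of audio_dur is: {round(stat_fn(durations), 3)}")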
def get_train_test_ids() -> Set[str]:
    """This function returns a set of ids for records that are included in the speak training
    and test sets. The paths to the training and test sets are hardcoded to the paths on the
    cloud VM's.

    Returns:
        Set[str]: a set of record_ids for the training and test recordings
    """
    # train_data_trim_2020-09-22.json is the entire 7M recordings in the full speak training set
    train_test_paths = [
        "/home/dzubke/awni_speech/data/speak_train/train_data_trim_2020-09-22.json",
        "/home/dzubke/awni_speech/data/speak_test_data/2020-05-27/speak-test_2020-05-27.json",
        "/home/dzubke/awni_speech/data/speak_test_data/2019-11-29/speak-test_2019-11-29.json"
    ]

    datasets = [read_data_json(path) for path in train_test_paths]

    train_test_ids = set()
    # loop through the datasets and add the id's output from `path_to_id` to the set
    for dataset in datasets:
        train_test_ids.update([path_to_id(datum['audio']) for datum in dataset])

    return train_test_ids
def download_dataset(self):
    """This method loops through the firestore document database using paginated queries based
    on the document id. It filters out documents where `target != guess` if `self.target_eq_guess`
    is True and saves the audio file and target text into separate files.
    """
    PROJECT_ID = 'speak-v2-2a1f1'
    QUERY_LIMIT = 2000          # max size of query
    SAMPLES_PER_QUERY = 200     # number of random samples downloaded per query
    AUDIO_EXT = '.m4a'          # extension of downloaded audio

    audio_dir = os.path.join(self.output_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)

    # verify and set the credentials
    CREDENTIAL_PATH = "/home/dzubke/awni_speech/speak-v2-2a1f1-d8fc553a3437.json"
    assert os.path.exists(CREDENTIAL_PATH), \
        "Credential file does not exist or is in the wrong location."
    # set the environment variable that `firebase_admin.credentials` will use
    os.putenv("GOOGLE_APPLICATION_CREDENTIALS", CREDENTIAL_PATH)

    # initialize the credentials and firebase db client
    cred = credentials.ApplicationDefault()
    firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
    db = firestore.client()

    # create the data-label path and initialize the tsv headers
    date = datetime.date.today().isoformat()
    self.data_label_path = os.path.join(self.output_dir, "eval2-v4_data_" + date + ".tsv")
    self.metadata_path = os.path.join(self.output_dir, "eval2-v4_metadata_" + date + ".json")

    # re-calculate the constraints in the `config` as integer counts based on the `dataset_size`
    self.constraints = {
        name: int(self.constraints[name] * self.num_examples)
        for name in self.constraints.keys()
    }
    # constraint_names will help to ensure the dict keys created later are consistent
    constraint_names = list(self.constraints.keys())
    print("constraints: ", self.constraints)

    # id_counter keeps track of the counts for each speaker, lesson, and line ids
    id_counter = {name: dict() for name in constraint_names}

    # create a mapping from record_id to lesson, line, and speaker ids
    disjoint_ids_map = get_record_ids_map(metadata_path, constraint_names)

    # create a dict of sets of all the ids in the disjoint datasets that will not
    # be included in the filtered dataset
    disjoint_id_sets = {name: set() for name in self.disjoint_id_names}
    for disj_dataset_path in self.disjoint_datasets:
        disj_dataset = read_data_json(disj_dataset_path)
        # extracts the record_ids from the excluded datasets
        record_ids = [path_to_id(example['audio']) for example in disj_dataset]
        # loop through each record id
        for record_id in record_ids:
            # loop through each id_name and update the disjoint_id_sets
            for disjoint_id_name, disjoint_id_set in disjoint_id_sets.items():
                disjoint_id_set.add(disjoint_ids_map[record_id][disjoint_id_name])

    # creating a date range from `self.days_from_today` in the correct format
    now = datetime.datetime.utcnow()
    day_delta = datetime.timedelta(days=self.days_from_today)
    day_range = now - day_delta
    day_range = day_range.isoformat("T") + "Z"

    with open(self.data_label_path, 'w', newline='\n') as tsv_file:
        tsv_writer = csv.writer(tsv_file, delimiter='\t')
        header = [
            "id", "target", "guess", "lessonId", "target_sentence", "lineId", "uid",
            "redWords_score", "date"
        ]
        tsv_writer.writerow(header)

        # create the first query based on the constant QUERY_LIMIT
        rec_ref = db.collection(u'recordings')
        # this is the final record_id that was downloaded from the speak training set
        speak_train_last_id = 'SR9TIlF8bSWApZa1tqEBIHOQs5z1-1583920255'
        next_query = rec_ref\
            .order_by(u'id')\
            .start_after({u'id': speak_train_last_id})\
            .limit(QUERY_LIMIT)

        # loop through the queries until the example_count is at least the num_examples
        example_count = 0
        # get the ids from the training and test sets to ensure the downloaded set is disjoint
        train_test_set = self.get_train_test_ids()

        while example_count < self.num_examples:
            print(f"another loop with {example_count} examples written")
            # convert the generator to a list to retrieve the last doc_id
            docs = list(map(lambda x: self._doc_trim_to_dict(x), next_query.stream()))
            try:
                # this id will be used to start the next query
                last_id = docs[-1]['id']
            # if the docs list is empty, there are no new documents
            # and an IndexError will be raised and break the while loop
            except IndexError:
                print("Exiting while loop")
                break

            # selects a random sample of `SAMPLES_PER_QUERY` from the total queries
            #docs = random.sample(docs, SAMPLES_PER_QUERY)

            for doc in docs:
                # if num_examples is reached, break
                if example_count >= self.num_examples:
                    break

                target = process_text(doc['info']['target'])
                # check that the speaker, target-sentence, and record_id are disjoint
                if doc['user']['uid'] not in disjoint_id_sets['speaker']\
                and target not in disjoint_id_sets['target_sentence']\
                and doc['id'] not in train_test_set:
                    # set `self.target_eq_guess` to True in `__init__` if you want
                    # to filter by `target` == `guess`
                    if self.target_eq_guess:
                        # process the target and guess and remove apostrophes for comparison
                        guess = process_text(doc['result']['guess'])
                        # removing apostrophes for comparison
                        target_no_apostrophe = target.replace("'", "")
                        guess_no_apostrophe = guess.replace("'", "")
                        # if target != guess, skip the record
                        if target_no_apostrophe != guess_no_apostrophe:
                            continue

                    # if `True`, constraints on the records downloaded will be checked
                    if self.check_constraints:
                        # create a mapping to feed into `check_update_contraints`
                        record_ids_map = {
                            doc['id']: {
                                'lesson': doc['info']['lessonId'],
                                'target_sentence': target,    # using processed target as id
                                'speaker': doc['user']['uid']
                            }
                        }
                        pass_constraint = check_update_contraints(
                            doc['id'], record_ids_map, id_counter, self.constraints)
                        # if the record doesn't pass the constraints, continue to the next record
                        if not pass_constraint:
                            continue

                    # save the audio file from the link in the document
                    audio_url = doc['result']['audioDownloadUrl']
                    audio_path = os.path.join(audio_dir, doc['id'] + AUDIO_EXT)

                    # convert the downloaded file to .wav format
                    # usually, this conversion is done in the preprocessing step,
                    # but some eval sets don't need PER labels, so this removes the need to
                    # preprocess the eval set.
                    base, raw_ext = os.path.splitext(audio_path)
                    # the converted file uses the `.wav` extension
                    wav_path = base + os.path.extsep + "wav"
                    # if the wave file doesn't exist, convert to wav
                    if not os.path.exists(wav_path):
                        try:
                            to_wave(audio_path, wav_path)
                        except subprocess.CalledProcessError:
                            # if the file can't be converted, skip the file by continuing
                            logging.info(f"Process Error converting file: {audio_path}")
                            continue

                    # save the target in a tsv row
                    # tsv header: "id", "target", "guess", "lessonId", "target_sentence",
                    # "lineId", "uid", "redWords_score", "date"
                    tsv_row = [
                        doc['id'],
                        doc['info']['target'],
                        doc['result']['guess'],
                        doc['info']['lessonId'],
                        target,                    # using this to replace lineId
                        doc['info']['lineId'],
                        doc['user']['uid'],
                        doc['result']['score'],
                        doc['info']['date']
                    ]
                    tsv_writer.writerow(tsv_row)

                    # save all the metadata in a separate file
                    #with open(self.metadata_path, 'a') as jsonfile:
                    #    json.dump(doc, jsonfile)
                    #    jsonfile.write("\n")

                    example_count += 1

            # create the next query starting after the last_id
            next_query = (
                rec_ref\
                .order_by(u'id')\
                .start_after({u'id': last_id})\
                .limit(QUERY_LIMIT)
            )
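# --- illustrative sketch (not part of the original module) ---
# The core pagination pattern used by `download_dataset` above, isolated for clarity:
# cursor-style paging with `order_by` + `start_after` + `limit` on a Firestore client.
# The collection name, field name, and page size below are placeholders.
def _example_paginated_query(db, last_id: str, page_size: int = 2000) -> list:
    """Returns one page of documents whose 'id' field sorts after `last_id`."""
    query = db.collection(u'recordings')\
        .order_by(u'id')\
        .start_after({u'id': last_id})\
        .limit(page_size)
    docs = [doc.to_dict() for doc in query.stream()]
    # an empty page means the cursor has reached the end of the collection
    return docs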
def stt_on_sample(data_path: str, metadata_path: str, save_path: str,
                  stt_provider: str = 'ibm') -> None:
    """Pulls a random sample of audio files from `data_path` and calls a speech-to-text API
    to get transcript predictions. The STT output is formatted and written to `save_path`
    along with each file's transcript from `metadata_path`.

    Args:
        data_path: path to training json
        metadata_path: path to metadata tsv containing transcript
        save_path: path where output txt will be saved
        stt_provider: name of company providing STT model
    """
    random.seed(0)
    SAMPLE_SIZE = 100

    data = read_data_json(data_path)
    data_sample = random.choices(data, k=SAMPLE_SIZE)
    print(f"sampling {len(data_sample)} samples from {data_path}")

    # mapping from audio_id to transcript
    metadata = get_record_ids_map(metadata_path, has_url=True)
    client = get_stt_client(stt_provider)

    preds_with_two_trans = set()
    match_trans_entries = list()    # output list for matching transcripts
    diff_trans_entries = list()     # output list for non-matching transcripts

    for datum in data_sample:
        audio_path = datum['audio']
        audio_id = path_to_id(audio_path)
        id_plus_dir = os.path.join(*audio_path.split('/')[-2:])

        response = get_stt_response(audio_path, client, stt_provider)
        resp_dict = format_response_dict(audio_path, response, stt_provider)

        ggl_trans = process_text(resp_dict['transcript'])
        apl_trans = process_text(metadata[audio_id]['target_sentence'])

        out_txt = format_txt_from_dict(resp_dict, apl_trans, id_plus_dir)
        if apl_trans == ggl_trans:
            match_trans_entries.append(out_txt)
        else:
            diff_trans_entries.append(out_txt)

    with open(save_path, 'w') as fid:
        for entries in [diff_trans_entries, match_trans_entries]:
            fid.write("-" * 10 + '\n')
            for entry in entries:
                fid.write(entry + '\n\n')
def __init__(self, dataset_path: str, subset_size: int):
    self.dataset_path = dataset_path
    self.data_json = read_data_json(dataset_path)
    self.subset = random.sample(self.data_json, k=subset_size)
def assess_speak_train(dataset_paths: List[str],
                       metadata_path: str,
                       out_dir: str,
                       use_json: bool = True) -> None:
    """This function creates counts of the speaker, lesson, and line ids in a speak training dataset.

    Args:
        dataset_paths (List[str]): paths to speak training.json datasets
        metadata_path (str): path to tsv file that contains speaker, line, and lesson ids
        out_dir (str): directory where plots and txt files will be saved
        use_json (bool): if true, the data will be read from a training.json file

    Returns:
        None
    """
    def _increment_key(in_dict, key):
        in_dict[key] = in_dict.get(key, 0) + 1

    # this will read the data from a metadata.tsv file
    if not use_json:
        # count dictionaries for the lessons, lines, and users (speakers)
        lesson_dict, line_dict, user_dict, target_dict = {}, {}, {}, {}

        # create count_dicts for each
        with open(metadata_path, 'r') as tsv_file:
            tsv_reader = csv.reader(tsv_file, delimiter='\t')
            header = next(tsv_reader)
            print(header)
            for row in tsv_reader:
                _increment_key(lesson_dict, row[2])
                _increment_key(line_dict, row[3])
                _increment_key(user_dict, row[4])
                _increment_key(target_dict, process_text(row[1]))

        # put the labels and count_dicts in a dict for the loops below
        constraint_names = ['lesson', 'line', 'speaker', 'target_sent']
        counter = {
            "lesson": lesson_dict,
            "line": line_dict,
            "speaker": user_dict,
            "target_sent": target_dict
        }

    # reading from a training.json file supported by a metadata.tsv file
    if use_json:
        # create mapping from record_id to speaker, line, and lesson ids
        rec_ids_map = dict()
        constraint_names = ['lesson', 'line', 'speaker', 'target_sent']
        counter = {name: dict() for name in constraint_names}
        with open(metadata_path, 'r') as tsv_file:
            tsv_reader = csv.reader(tsv_file, delimiter='\t')
            # header: id, text, lessonId, lineId, uid(speaker_id), date
            header = next(tsv_reader)
            rec_ids_map = dict()
            for row in tsv_reader:
                rec_ids_map[row[0]] = {
                    constraint_names[0]: row[2],                  # lesson
                    constraint_names[1]: row[3],                  # line
                    constraint_names[2]: row[4],                  # speaker
                    constraint_names[3]: process_text(row[1]),    # target-sentence
                    "date": row[6]                                # date
                }

    total_date_counter = dict()
    # `unq_date_counter` keeps track of the unique ids per date
    unq_date_counter = {name: dict() for name in constraint_names}

    # iterate through the datasets
    for dataset_path in dataset_paths:
        dataset = read_data_json(dataset_path)
        print(f"dataset {path_to_id(dataset_path)} size is: {len(dataset)}")

        # iterate through the examples in the dataset
        for xmpl in dataset:
            rec_id = path_to_id(xmpl['audio'])
            date = rec_ids_map[rec_id]['date']
            # date has format 2020-09-10T04:24:03.073Z, so splitting on '-'
            # and joining the first two elements gives `2020-09`
            yyyy_mm_date = '-'.join(date.split('-')[:2])
            _increment_key(total_date_counter, yyyy_mm_date)

            # iterate through the constraints and update the id counters
            for name in constraint_names:
                constraint_id = rec_ids_map[rec_id][name]
                _increment_key(counter[name], constraint_id)
                update_unq_date_counter(unq_date_counter, name, constraint_id, yyyy_mm_date)

    # create the plots
    fig, axs = plt.subplots(1, len(constraint_names))
    fig.set_size_inches(8, 6)

    # plot and calculate stats of the count_dicts
    for ax, name in zip(axs, constraint_names):
        plot_count(ax, counter[name], name)
        print(f"{name} stats")
        print_stats(counter[name])
        print()

    # ensures the directory of `out_dir` exists
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, os.path.basename(out_dir))
    print("out_path: ", out_path)
    plt.savefig(out_path + "_count_plot.png")
    plt.close()

    # plot the total_date histogram
    fig, ax = plt.subplots(1, 1)
    dates = sorted(total_date_counter.keys())
    date_counts = [total_date_counter[date] for date in dates]
    ax.plot(range(len(date_counts)), date_counts)
    plt.xticks(range(len(date_counts)), dates, rotation=60)
    #ax.set_title(label)
    #ax.set_xlabel(f"unique {label}")
    #ax.set_ylabel(f"utterance per {label}")
    #ax.xaxis.set_major_formatter(tick.FuncFormatter(reformat_large_tick_values))
    ax.yaxis.set_major_formatter(tick.FuncFormatter(reformat_large_tick_values))
    plt.tight_layout()
    plt.savefig(out_path + "_date_count.png")
    plt.close()

    # plot the cumulative fraction of unique ids over time
    for name in constraint_names:
        fig, ax = plt.subplots(1, 1)
        date_counts = []
        dates = sorted(unq_date_counter[name].keys())
        total_count = sum([unq_date_counter[name][date]['count'] for date in dates])
        cumulative_count = 0
        for date in dates:
            cumulative_count += unq_date_counter[name][date]['count']
            date_counts.append(round(cumulative_count / total_count, 2))

        ax.plot(range(len(date_counts)), date_counts)
        plt.xticks(range(len(date_counts)), dates, rotation=60)
        ax.set_title(name)
        ax.set_xlabel("Date")
        ax.set_ylabel("% of total unique ID's")
        #ax.xaxis.set_major_formatter(tick.FuncFormatter(reformat_large_tick_values))
        #ax.yaxis.set_major_formatter(tick.FuncFormatter(reformat_large_tick_values))
        plt.tight_layout()
        plt.savefig(out_path + f"_unq_cum_date_{name}.png")
        plt.close()

    # sort the lesson_ids and line_ids and write to txt file
    for name in counter:
        sorted_ids = sorted(list(counter[name].keys()))
        with open(f"{out_path}_{name}.txt", 'w') as fid:
            for ids in sorted_ids:
                fid.write(ids + "\n")
def filter_speak_train(config: dict) -> None:
    """This script filters the dataset in `full_json_path` and writes the new dataset to
    `filter_json_path`. The constraints on the filtered dataset are:
        - utterances per speaker, lesson, and line cannot exceed the decimal values expressed
          as a fraction of the `dataset_size`. Older config files have an absolute value on the
          `max_speaker_count`.
        - the utterances are not also included in the datasets specified in `disjoint_datasets`.

    Config contents:
        full_json_path (str): path to the source json file that the output will be filtered from
        metadata_tsv_path (str): path to the tsv file that includes metadata on each recording,
            like the speaker_id
        filter_json_path (str): path to the filtered, written json file
        dataset_size (int): number of utterances included in the output dataset
        constraints (dict): dict of constraints on the number of utterances per speaker, lesson,
            and line expressed as decimal fractions of the total dataset
        disjoint_datasets (Dict[Tuple[str], str]): dict whose keys are a tuple of the ids that
            will be disjoint and whose values are the dataset paths whose examples will be
            disjoint from the output

    Returns:
        None, only files written.
    """
    # unpacking the config
    # TODO, only unpack what is necessary
    full_json_path = config['full_json_path']
    metadata_path = config['metadata_tsv_path']
    filter_json_path = config['filter_json_path']
    dataset_size = config['dataset_size']

    # re-calculate the constraints as integer counts based on the `dataset_size`
    constraints = {
        name: int(value * dataset_size)
        for name, value in config['constraints'].items()
    }
    print("constraints: ", constraints)

    # read and shuffle the full dataset and convert to an iterator to save memory
    full_dataset = read_data_json(full_json_path)
    random.shuffle(full_dataset)
    full_dataset = iter(full_dataset)

    # get the mapping from record_id to other ids (like speaker, lesson, line) for each example
    record_ids_map = get_record_ids_map(metadata_path, list(constraints.keys()))

    # create a defaultdict with set values for each disjoint-id name
    disjoint_id_sets = get_disjoint_sets(config['disjoint_datasets'])
    print("all disjoint names: ", disjoint_id_sets.keys())

    # id_counter keeps track of the counts for each speaker, lesson, and line ids
    id_counter = {name: dict() for name in constraints}

    examples_written = 0
    # loop until the number of examples in dataset_size has been written
    with open(filter_json_path, 'w') as fid:
        while examples_written < dataset_size:
            if examples_written != 0 and examples_written % config['print_modulus'] == 0:
                print(f"{examples_written} examples written")
            try:
                example = next(full_dataset)
            except StopIteration:
                print(f"Stop encountered {examples_written} examples written")
                break

            record_id = path_to_id(example['audio'])

            # check that the ids associated with the record_id are not included in the disjoint_datasets
            pass_filter = check_disjoint_filter(record_id, disjoint_id_sets, record_ids_map)
            if pass_filter:
                # check if the record_id passes the speaker, line, lesson constraints
                pass_constraint = check_update_contraints(
                    record_id, record_ids_map, id_counter, constraints)
                if pass_constraint:
                    # if you don't want to use distribution filtering, the example always passes
                    if not config['dist_filter']['use']:
                        pass_distribution_filter = True
                    else:
                        # creates a filter based on the params in `dist_filter`
                        pass_distribution_filter = check_distribution_filter(
                            example, config['dist_filter'])

                    if pass_distribution_filter:
                        json.dump(example, fid)
                        fid.write("\n")
                        # increment counters
                        examples_written += 1
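# --- illustrative sketch (not part of the original module) ---
# A hypothetical config for `filter_speak_train`, shaped after the "Config contents" section of
# the docstring above. All paths and values are placeholders, and the exact structure expected
# by `get_disjoint_sets` for `disjoint_datasets` is an assumption.
EXAMPLE_FILTER_CONFIG = {
    "full_json_path": "/data/speak/train_full.json",
    "metadata_tsv_path": "/data/speak/metadata.tsv",
    "filter_json_path": "/data/speak/train_filtered.json",
    "dataset_size": 500000,
    "print_modulus": 10000,
    # fractions of `dataset_size`: e.g. no single speaker may exceed 0.2% of the output
    "constraints": {"speaker": 0.002, "lesson": 0.005, "line": 0.005},
    # datasets whose listed ids must be disjoint from the output
    "disjoint_datasets": {
        ("speaker", "target_sentence"): "/data/speak/test_2020-05-27.json",
    },
    # optional distribution-based filtering; disabled here
    "dist_filter": {"use": False},
}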
def __init__(self, data_json: list, preproc_cfg: dict, logger=None,
             max_samples: int = 1000, start_and_end=False):
    """Builds a preprocessor from a dataset.

    Arguments:
        data_json (string): A file containing a json representation of each example per line.
        preproc_cfg (dict): A config defining the preprocessing with attributes
            preprocessor: "log_spectrogram", "log_mel", or "mfcc" to determine the type of preprocessing
            window_size: the size of the window in the spectrogram transform
            step_size: the size of the step in the spectrogram transform
        max_samples (int): The maximum number of examples to be used in computing summary statistics.
        start_and_end (bool): Include start and end tokens in labels.
    """
    # if true, data augmentation will be applied
    self.train_status = True

    assert preproc_cfg['preprocessor'] in ['log_spectrogram', 'log_mel', 'mfcc'], \
        f"preprocessor name: {preproc_cfg['preprocessor']} is unacceptable"
    self.preprocessor = preproc_cfg['preprocessor']
    self.window_size = preproc_cfg['window_size']
    self.step_size = preproc_cfg['step_size']
    self.use_feature_normalize = preproc_cfg['use_feature_normalize']
    self.augment_from_normal = preproc_cfg.get('augment_from_normal', False)

    self.tempo_gain_pitch_perturb = preproc_cfg['tempo_gain_pitch_perturb']
    self.tempo_gain_pitch_prob = preproc_cfg.get('tempo_gain_pitch_prob', 1.0)
    self.tempo_range = preproc_cfg['tempo_range']
    self.gain_range = preproc_cfg['gain_range']
    self.pitch_range = preproc_cfg['pitch_range']

    self.synthetic_gaussian_noise = preproc_cfg.get('synthetic_gaussian_noise', False)
    self.gauss_noise_prob = preproc_cfg.get('gauss_noise_prob', 1.0)
    self.gauss_snr_db_range = preproc_cfg.get(
        'gauss_snr_db_range', preproc_cfg.get('signal_to_noise_range_db'))

    self.background_noise = preproc_cfg.get('background_noise', preproc_cfg.get('inject_noise'))
    self.noise_dir = preproc_cfg.get('background_noise_dir', preproc_cfg.get('noise_directory'))
    self.background_noise_prob = preproc_cfg.get(
        'background_noise_prob', preproc_cfg.get('noise_prob'))
    self.background_noise_range = preproc_cfg.get(
        'background_noise_range', preproc_cfg.get('noise_levels'))

    self.spec_augment = preproc_cfg.get('spec_augment', preproc_cfg.get('use_spec_augment'))
    self.spec_augment_prob = preproc_cfg.get('spec_augment_prob', 1.0)
    self.spec_augment_policy = preproc_cfg['spec_augment_policy']

    # Compute data mean, std from sample
    data = read_data_json(data_json)
    audio_files = [sample['audio'] for sample in data]
    random.shuffle(audio_files)
    self.mean, self.std = compute_mean_std(
        audio_files[:max_samples],
        self.preprocessor,
        window_size=self.window_size,
        step_size=self.step_size,
        use_feature_normalize=self.use_feature_normalize)
    self._input_dim = self.mean.shape[0]
    self.use_log = (logger is not None)
    self.logger = logger

    # Make char map
    chars = sorted(list(set(label for datum in data for label in datum['text'])))
    if start_and_end:
        # START must be last so it can easily be
        # excluded in the output classes of a model.
        chars.extend([self.END, self.START])
    self.start_and_end = start_and_end

    assert preproc_cfg['blank_idx'] in ['first', 'last'], \
        f"blank_idx: {preproc_cfg['blank_idx']} must be either 'first' or 'last'"
    # if the blank_idx is 'first', then the int_to_char mapping must start at 1
    # as 0 is already reserved for the blank
    if preproc_cfg['blank_idx'] == 'first':
        start_idx = 1
    else:
        # if the blank_idx is 'last', then the int_to_char mapping can start at 0
        start_idx = 0

    # when start_idx is 1, zero is left free to serve as the blank for the native loss
    self.int_to_char = dict(enumerate(chars, start_idx))
    self.char_to_int = {v: k for k, v in self.int_to_char.items()}
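# --- illustrative sketch (not part of the original module) ---
# A small rehearsal of the blank-index bookkeeping at the end of __init__ above: with
# blank_idx == 'first', the character indices start at 1 so that 0 stays reserved for the
# blank token used by the loss; with blank_idx == 'last', they start at 0. The label set
# here is made up.
def _example_char_map(blank_idx: str = 'first') -> None:
    chars = sorted({"AA", "B", "CH"})
    start_idx = 1 if blank_idx == 'first' else 0
    int_to_char = dict(enumerate(chars, start_idx))
    char_to_int = {v: k for k, v in int_to_char.items()}
    # blank_idx == 'first' -> {1: 'AA', 2: 'B', 3: 'CH'}; index 0 is left for the blank
    print(int_to_char, char_to_int)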