def make_strings(cls, num, unique=True, exclusive=None, embedded=False,
                 multiple=1, verbose=False, interleave=True):
  # Check input
  if exclusive is None:
    exclusive = []
  elif not isinstance(exclusive, list):
    raise TypeError('!! exclusive must be a list of Reber strings')

  # Make strings
  reber_list = []
  long_token = None
  for i in range(num):
    # Alternate the specification token between 'T' and 'P' if interleaving
    if interleave:
      long_token = 'T' if long_token in ('P', None) else 'P'
    while True:
      string = ReberGrammar(embedded, multiple=multiple,
                            specification=long_token)
      # Discard duplicates and strings reserved in `exclusive`
      if unique and string in reber_list:
        continue
      if string in exclusive:
        continue
      reber_list.append(string)
      break
    if verbose:
      console.clear_line()
      console.print_progress(i + 1, num)
  if verbose:
    console.clear_line()

  # Return a list of Reber strings
  return reber_list
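# A hypothetical usage sketch, not part of the original module: `make_strings`
# takes `cls`, so it is assumed here to be a classmethod on ReberGrammar
# itself; the set sizes are illustrative.
def demo_make_strings():
  # Training strings alternate their long-dependency token between T and P
  train_set = ReberGrammar.make_strings(256, embedded=True, verbose=True)
  # Keep the test set disjoint from the training set via `exclusive`
  test_set = ReberGrammar.make_strings(
    64, embedded=True, exclusive=train_set, verbose=True)
  return train_set, test_set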
def recover_progress(self, start_time=None):
  # Print progress bar
  if self.th.progress_bar and self.th.round_length is not None:
    assert isinstance(self._training_set, TFRData)
    progress = self.th.round_progress
    assert progress is not None
    console.print_progress(progress=progress, start_time=start_time)
def separate_verified_data(data_dir, csv_path, to_path, verbose=True,
                           vname='audio_train_verified',
                           uvname='audio_train_unverified'):
  # Check paths
  check_path(data_dir, create_path=False)
  check_path(csv_path, create_path=False)

  # Read raw csv file and split it by the manual-verification flag
  raw_csv = pd.read_csv(csv_path)
  verified_csv = raw_csv.loc[raw_csv[MANUALLY_VERIFIED] == 1]
  unverified_csv = raw_csv.loc[raw_csv[MANUALLY_VERIFIED] == 0]

  # Copy verified and unverified data into separate directories
  for df, dn in ((verified_csv, vname), (unverified_csv, uvname)):
    check_path(to_path, dn, create_path=True)
    console.show_status('Generating {} data ... '.format(dn))
    num_files = len(df)
    for i, file_name in enumerate(df[FNAME]):
      from_file = os.path.join(data_dir, file_name)
      to_file = os.path.join(to_path, dn, file_name)
      sh.copyfile(from_file, to_file)
      if verbose:
        console.print_progress(i + 1, num_files)
    # Save the corresponding csv file beside the new directory
    file_name = os.path.join(to_path, dn + '.csv')
    df.to_csv(file_name, index=False)
    console.show_status('CSV file saved to {}'.format(file_name))
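# A hypothetical usage sketch with placeholder paths: the csv is assumed to
# follow the FSDKaggle2018-style layout implied above, i.e. a file-name column
# and a 0/1 verification flag matching the FNAME and MANUALLY_VERIFIED keys.
def demo_separate_verified_data():
  separate_verified_data(
    data_dir='data/audio_train',
    csv_path='data/train.csv',
    to_path='data/split')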
def _print_progress(self, epi, start_time, steps, **kwargs):
  """Use an awkward way to avoid IDE warning :("""
  console.clear_line()
  console.show_status(
    'Episode {} [{} total] {} steps, Time elapsed = {:.2f} sec'.format(
      epi, self.counter, steps, time.time() - start_time))
  console.print_progress(epi, kwargs.get('total'))
def _inter_cut(self, content, prompt='>>', start_time=None):
  # Show content
  console.show_status(content, symbol=prompt)
  # Print progress bar
  if self.th.progress_bar and self.th.round_length is not None:
    assert isinstance(self._training_set, TFRData)
    progress = self.th.round_progress
    assert progress is not None
    console.print_progress(progress=progress, start_time=start_time)
def _inter_cut(self, content, start_time=None):
  # If run on the cloud, do not show progress bar
  if not FLAGS.progress_bar:
    console.show_status(content)
    return
  console.clear_line()
  console.show_status(content)
  console.print_progress(
    progress=self._training_set.progress, start_time=start_time)
def _snapshot(self, progress):
  if self._snapshot_function is None:
    return

  filename = 'train_{}_episode'.format(self.counter)
  fullname = "{}/{}".format(self.snapshot_dir, filename)
  self._snapshot_function(fullname)

  console.clear_line()
  console.write_line("[Snapshot] snapshot saved to {}".format(filename))
  console.print_progress(progress=progress)
def _generate_meta(self, data_dir):
  console.show_status('Scanning data directory ...')
  file_list = self._get_tfd_list(data_dir)

  # Scan directory and record the structure of each data set
  num_files = len(file_list)
  for i, file_name in enumerate(file_list):
    data_set = self._load_data_set(file_name)
    self.files[os.path.basename(file_name)] = data_set.structure
    console.print_progress(i + 1, num_files)
    # Release the loaded data set before loading the next one
    del data_set
def _check_data_files(self, data_dir):
  file_list = self._get_tfd_list(data_dir)
  console.show_status('Integrity checking ...')
  # Each file on disk must have a corresponding entry in the metadata
  for i, f in enumerate(file_list):
    if os.path.basename(f) not in self.files.keys():
      raise AssertionError('!! Cannot find {} in metadata'.format(f))
    console.print_progress(i, len(file_list))
  # The metadata must not reference files which are missing from disk
  if len(self.files) != len(file_list):
    raise AssertionError(
      '!! {} files are expected but only {} are found'.format(
        len(self.files), len(file_list)))
def _print_progress(self, epc, start_time, info_dict, **kwargs):
  # Generate loss string
  loss_strings = ['{} = {:.3f}'.format(k, v) for k, v in info_dict.items()]
  loss_string = ', '.join(loss_strings)

  total_epoch = self._counter / self._training_set.batches_per_epoch
  if FLAGS.progress_bar:
    console.clear_line()
  console.show_status('Epoch {} [{:.1f} Total] {}'.format(
    epc + 1, total_epoch, loss_string))
  if FLAGS.progress_bar:
    console.print_progress(progress=self._training_set.progress,
                           start_time=start_time)
def _synthesize(cls, size, L, N, fixed_length, verbose=False):
  features, targets = [], []
  for i in range(size):
    x, y = engine(L, N, fixed_length)
    features.append(x)
    targets.append(y)
    if verbose:
      console.clear_line()
      console.print_progress(i + 1, size)

  # Wrap data into a SequenceSet
  data_set = SequenceSet(features, summ_dict={'targets': targets},
                         n_to_one=True, name='TemporalOrder')
  return data_set
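# A hypothetical usage sketch, not part of the original module: the owning
# class name is not shown here ('TemporalOrder' below is a guess taken from
# the SequenceSet name above), and L/N are simply forwarded to `engine`, so
# the argument values are illustrative only.
def demo_synthesize():
  return TemporalOrder._synthesize(
    size=1000, L=100, N=3, fixed_length=True, verbose=True)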
def _init_big_data(self, csv_path, lb_sheet_path):
  if csv_path is None or lb_sheet_path is None:
    return
  self.with_labels = True
  self.properties[self.DATA_INFO] = pd.read_csv(csv_path)
  self.properties[self.LABEL_INDEX] = pd.read_csv(lb_sheet_path)

  # Generate groups: for each label, collect its file names,
  # sorted by length in descending order
  self.properties[self.GROUPS] = collections.OrderedDict()
  console.show_status('Generating group information ...')
  for i, label in enumerate(self.label_index[du.LABEL]):
    file_list = list(
      self.data_info.loc[self.data_info[du.LABEL] == label][du.FNAME])
    self.groups[label] = sorted(file_list, key=self._get_length, reverse=True)
    console.print_progress(i, len(self.label_index))
  console.show_status('Group population:')
  console.pprint(self.group_population)
def _training_match(self, agent, rounds, progress, rate_thresh):
  # TODO: inference graph is hard to build under this frame => compromise
  if self._opponent is None:
    return
  assert isinstance(agent, FMDPAgent)

  console.clear_line()
  title = 'Training match with {}'.format(self._opponent.player_name)
  rate = self.compete(agent, rounds, self._opponent, title=title)
  if rate >= rate_thresh and isinstance(self._opponent, TDPlayer):
    # Find a stronger opponent
    self._opponent._load()
    self._opponent.player_name = 'Shadow_{}'.format(self._opponent.counter)
    console.show_status('Opponent updated')
  console.print_progress(progress=progress)
def convert_to_tframe_files(data_dir, fs, csv_path=None, label_path=None,
                            to_dir=None, verbose=True):
  # Check paths
  check_path(data_dir, create_path=False)
  if csv_path is not None:
    check_path(csv_path, create_path=False)
  if label_path is not None:
    check_path(label_path, create_path=False)
  if to_dir is None:
    to_dir = os.path.join(
      os.path.dirname(data_dir),
      'tfd_{}_{}Hz'.format(os.path.basename(data_dir), fs))
  check_path(to_dir, create_path=True)

  # Generate file list
  srcs, dsts = [], []
  for f in os.listdir(data_dir):
    file_path = os.path.join(data_dir, f)
    if not os.path.isfile(file_path) or not f.endswith('.wav'):
      continue
    srcs.append(file_path)
    dsts.append(os.path.join(to_dir, f))

  # Load csv file and label sheet if necessary
  csv, lb_sheet = None, None
  if csv_path is not None and label_path is not None:
    csv = pd.read_csv(csv_path)
    lb_sheet = pd.read_csv(label_path)

  # Wrap each .wav file in data_dir into a SignalSet
  num_files = len(srcs)
  console.show_status('Converting {} ...'.format(data_dir))
  for i, (src, dst) in enumerate(zip(srcs, dsts)):
    s = wav_to_signal(src, fs)
    ss = SignalSet(s, name=os.path.basename(src))
    if csv is not None:
      fname = os.path.basename(src)
      ss.data_dict[pedia.labels] = [_get_one_hot(fname, csv, lb_sheet)]
    ss.save(dst)
    if verbose:
      console.print_progress(i + 1, num_files)

  # Show status
  console.show_status('Data saved to {}'.format(to_dir))
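# A hypothetical usage sketch with placeholder paths: wrap each .wav file in
# the directory into a SignalSet saved under to_dir (which defaults to a
# sibling directory named 'tfd_<dir>_<fs>Hz'); one-hot labels are attached
# only when both csv_path and label_path are given.
def demo_convert_to_tframe_files():
  convert_to_tframe_files(
    data_dir='data/audio_train_16000Hz', fs=16000,
    csv_path='data/train.csv', label_path='data/label_sheet.csv')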
def make_strings(cls, num, unique=True, exclusive=None, embedded=False,
                 verbose=False):
  # Check input
  if exclusive is None:
    exclusive = []
  elif not isinstance(exclusive, list):
    raise TypeError('!! exclusive must be a list of Reber strings')

  # Make strings
  reber_list = []
  for i in range(num):
    while True:
      string = ReberGrammar(embedded)
      if unique and string in reber_list:
        continue
      if string in exclusive:
        continue
      reber_list.append(string)
      break
    if verbose:
      console.clear_line()
      console.print_progress(i + 1, num)
  if verbose:
    console.clear_line()

  # Return a list of Reber strings
  return reber_list
def merge_to_signal_set(self, save_as=None):
  signals = []
  onehot_labels = []
  file_names = []
  groups = []

  # Put each signal together with its one-hot label into lists
  for wav_names in self.groups.values():
    indices = []
    for wav_name in wav_names:
      file_path = os.path.join(self.data_dir, wav_name + '.tfds')
      # Load signal set from disk
      data_set = self._load_data_set(file_path)
      assert isinstance(data_set, SignalSet)
      assert len(data_set.signals) == 1 and len(data_set[pedia.labels]) == 1
      # Append signal and label to the corresponding lists
      signals.append(data_set.signals[0])
      onehot_labels.append(data_set[pedia.labels][0])
      file_names.append(wav_name)
      # Print progress
      console.print_progress(len(signals), len(self.files))
      indices.append(len(signals) - 1)
    groups.append(indices)

  # Wrap data_dict and properties
  data_dict, properties = {}, {}
  data_dict[pedia.labels] = onehot_labels
  properties[self.LABEL_INDEX] = self.label_index
  properties[self.DATA_INFO] = self.data_info
  properties[self.WAV_NAMES] = file_names
  properties[self.GROUPS] = groups

  # Save and return
  gpat_signal_set = GPATSignalSet(signals, data_dict=data_dict, **properties)
  if save_as is not None:
    gpat_signal_set.save(save_as)
  return gpat_signal_set
def down_sample(data_dir, sample_rate, to_path=None, verbose=True):
  """Common sample frequency list:
       8000 Hz - fs for telephone
      11025 Hz -
      22050 Hz - fs for radio
      32000 Hz - fs for miniDV
      44100 Hz - fs for CD
  """
  # Check data directory
  check_path(data_dir, create_path=False)
  # Check to_path
  if to_path is None:
    to_path = os.path.join(
      os.path.dirname(data_dir),
      '{}_{}Hz'.format(os.path.basename(data_dir), sample_rate))
  check_path(to_path, create_path=True)

  # Generate file list
  srcs, dsts = [], []
  for f in os.listdir(data_dir):
    file_path = os.path.join(data_dir, f)
    if not os.path.isfile(file_path) or not f.endswith('.wav'):
      continue
    srcs.append(file_path)
    dsts.append(os.path.join(to_path, f))

  # Down-sample each .wav file in data_dir
  num_files = len(srcs)
  console.show_status('Down sampling ...')
  for i, (src, dst) in enumerate(zip(srcs, dsts)):
    data, _ = librosa.core.load(src, sample_rate, res_type='kaiser_fast')
    librosa.output.write_wav(dst, data, sample_rate)
    if verbose:
      console.print_progress(i + 1, num_files)

  # Show status
  console.show_status('Data saved to {}'.format(to_path))
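# A hypothetical usage sketch with a placeholder path: resample every .wav
# file in the directory to 16000 Hz; since to_path is omitted, outputs go to
# a sibling directory named '<data_dir>_16000Hz'. Note that down_sample
# relies on the old librosa API (librosa.output.write_wav was removed in
# librosa 0.8).
def demo_down_sample():
  down_sample('data/audio_train', 16000)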
def show(self, index):
  console.print_progress(index, self._total, start_time=self._start_time)
def _progress(count, block_size, total_size):
  # `start_time` is captured from the enclosing scope
  console.clear_line()
  console.print_progress(count * block_size, total_size, start_time)
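# A plausible usage sketch, not confirmed by the source: the
# (count, block_size, total_size) signature of `_progress` matches the
# `reporthook` callback of urllib.request.urlretrieve, so it can report
# download progress. The sketch assumes `_progress` resolves `start_time`
# at module level, so it is set globally before the download starts;
# `data_url` and `local_path` are placeholders.
import time
import urllib.request

def demo_download(data_url, local_path):
  global start_time
  start_time = time.time()
  urllib.request.urlretrieve(data_url, local_path, reporthook=_progress)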