Example #1
 def make_strings(cls,
                  num,
                  unique=True,
                  exclusive=None,
                  embedded=False,
                  multiple=1,
                  verbose=False,
                  interleave=True):
     # Check input
     if exclusive is None: exclusive = []
     elif not isinstance(exclusive, list):
         raise TypeError('!! exclusive must be a list of Reber strings')
     # Make strings
     reber_list = []
     long_token = None
     for i in range(num):
         if interleave:
             long_token = 'T' if long_token in ('P', None) else 'P'
         while True:
             string = ReberGrammar(embedded,
                                   multiple=multiple,
                                   specification=long_token)
             if unique and string in reber_list: continue
             if string in exclusive: continue
             reber_list.append(string)
             break
         if verbose:
             console.clear_line()
             console.print_progress(i + 1, num)
     if verbose: console.clear_line()
     # Return a list of Reber strings
     return reber_list
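For reference, a usage sketch (hypothetical calls): the cls parameter and the ReberGrammar(...) constructor above suggest make_strings is a classmethod of ReberGrammar, so it can be used to build disjoint data splits like this:

# Hypothetical usage of the classmethod above.
# Draw 50 unique embedded Reber strings, alternating the 'T'/'P'
# specification between consecutive draws (interleave=True).
train_set = ReberGrammar.make_strings(50, embedded=True, verbose=True)
# Draw a disjoint validation set by excluding the training strings.
val_set = ReberGrammar.make_strings(
    20, embedded=True, exclusive=train_set, verbose=True)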
Example #2
 def recover_progress(self, start_time=None):
     # Print progress bar
     if self.th.progress_bar and self.th.round_length is not None:
         assert isinstance(self._training_set, TFRData)
         progress = self.th.round_progress
         assert progress is not None
         console.print_progress(progress=progress, start_time=start_time)
Example #3
def separate_verified_data(data_dir,
                           csv_path,
                           to_path,
                           verbose=True,
                           vname='audio_train_verified',
                           uvname='audio_train_unverified'):
    # Check path
    check_path(data_dir, create_path=False)
    check_path(csv_path, create_path=False)

    # Read raw csv file and split
    raw_csv = pd.read_csv(csv_path)
    verified_csv = raw_csv.loc[raw_csv[MANUALLY_VERIFIED] == 1]
    unverified_csv = raw_csv.loc[raw_csv[MANUALLY_VERIFIED] == 0]

    # Copy verified data
    for df, dn in ((verified_csv, vname), (unverified_csv, uvname)):
        check_path(to_path, dn, create_path=True)
        console.show_status('Generating {} data ... '.format(dn))
        num_files = len(df)
        for i, file_name in enumerate(df[FNAME]):
            from_file = os.path.join(data_dir, file_name)
            to_file = os.path.join(to_path, dn, file_name)
            sh.copyfile(from_file, to_file)
            if verbose: console.print_progress(i, num_files)
        file_name = os.path.join(to_path, dn + '.csv')
        df.to_csv(file_name, index=False)
        console.show_status('CSV file saved to {}'.format(file_name))
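The column-name constants used above (FNAME, MANUALLY_VERIFIED) are defined elsewhere in the module, and sh.copyfile is presumably shutil imported under the alias sh. Assuming this function targets the Freesound General-Purpose Audio Tagging train csv, the constants would be:

# Assumed definitions (not shown in the snippet above); hypothetical,
# matching the column names of the Freesound audio-tagging train csv.
FNAME = 'fname'
MANUALLY_VERIFIED = 'manually_verified'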
Example #4
 def _print_progress(self, epi, start_time, steps, **kwargs):
     """Use a awkward way to avoid IDE warning :("""
     console.clear_line()
     console.show_status(
         'Episode {} [{} total] {} steps, Time elapsed = {:.2f} sec'.format(
             epi, self.counter, steps,
             time.time() - start_time))
     console.print_progress(epi, kwargs.get('total'))
Example #5
 def _inter_cut(self, content, prompt='>>', start_time=None):
     # Show content
     console.show_status(content, symbol=prompt)
     # Print progress bar
     if self.th.progress_bar and self.th.round_length is not None:
         assert isinstance(self._training_set, TFRData)
         progress = self.th.round_progress
         assert progress is not None
         console.print_progress(progress=progress, start_time=start_time)
Example #6
    def _inter_cut(self, content, start_time=None):
        # If run on the cloud, do not show progress bar
        if not FLAGS.progress_bar:
            console.show_status(content)
            return

        console.clear_line()
        console.show_status(content)
        console.print_progress(progress=self._training_set.progress,
                               start_time=start_time)
Example #7
    def _snapshot(self, progress):
        if self._snapshot_function is None:
            return

        filename = 'train_{}_episode'.format(self.counter)
        fullname = "{}/{}".format(self.snapshot_dir, filename)
        self._snapshot_function(fullname)

        console.clear_line()
        console.write_line("[Snapshot] snapshot saved to {}".format(filename))
        console.print_progress(progress=progress)
Example #8
    def _generate_meta(self, data_dir):
        console.show_status('Scanning data directory ...')
        file_list = self._get_tfd_list(data_dir)

        # Scan directory
        num_files = len(file_list)
        for i, file_name in enumerate(file_list):
            data_set = self._load_data_set(file_name)
            self.files[os.path.basename(file_name)] = data_set.structure
            console.print_progress(i + 1, num_files)
            del data_set
Example #9
 def _check_data_files(self, data_dir):
     file_list = self._get_tfd_list(data_dir)
     console.show_status('Integrity checking ...')
     for i, f in enumerate(file_list):
         if os.path.basename(f) not in self.files.keys():
             raise AssertionError(
                 '!! Cannot find {} in metadata'.format(f))
         console.print_progress(i, len(file_list))
     if len(self.files) != len(file_list):
         raise AssertionError(
             '!! {} files are expected but only {} are found'.format(
                 len(self.files), len(file_list)))
Example #10
  def _print_progress(self, epc, start_time, info_dict, **kwargs):
    # Generate loss string
    loss_strings = ['{} = {:.3f}'.format(k, info_dict[k])
                    for k in info_dict.keys()]
    loss_string = ', '.join(loss_strings)

    total_epoch = self._counter / self._training_set.batches_per_epoch
    if FLAGS.progress_bar: console.clear_line()
    console.show_status(
      'Epoch {} [{:.1f} Total] {}'.format(epc + 1, total_epoch, loss_string))
    if FLAGS.progress_bar:
      console.print_progress(progress=self._training_set.progress,
                             start_time=start_time)
Example #11
 def _synthesize(cls, size, L, N, fixed_length, verbose=False):
     features, targets = [], []
     for i in range(size):
         x, y = engine(L, N, fixed_length)
         features.append(x)
         targets.append(y)
         if verbose:
             console.clear_line()
             console.print_progress(i + 1, size)
     # Wrap data into a SequenceSet
     data_set = SequenceSet(features,
                            summ_dict={'targets': targets},
                            n_to_one=True,
                            name='TemporalOrder')
     return data_set
Example #12
 def _init_big_data(self, csv_path, lb_sheet_path):
     if csv_path is None or lb_sheet_path is None: return
     self.with_labels = True
     self.properties[self.DATA_INFO] = pd.read_csv(csv_path)
     self.properties[self.LABEL_INDEX] = pd.read_csv(lb_sheet_path)
     # Generate groups
     self.properties[self.GROUPS] = collections.OrderedDict()
     console.show_status('Generating group information ...')
     for i, label in enumerate(self.label_index[du.LABEL]):
         file_list = list(self.data_info.loc[self.data_info[du.LABEL] ==
                                             label][du.FNAME])
         self.groups[label] = sorted(file_list,
                                     key=self._get_length,
                                     reverse=True)
         console.print_progress(i, len(self.label_index))
     console.show_status('Group population:')
     console.pprint(self.group_population)
Example #13
    def _training_match(self, agent, rounds, progress, rate_thresh):
        # TODO: inference graph is hard to build under this frame => compromise
        if self._opponent is None:
            return
        assert isinstance(agent, FMDPAgent)

        console.clear_line()
        title = 'Training match with {}'.format(self._opponent.player_name)
        rate = self.compete(agent, rounds, self._opponent, title=title)
        if rate >= rate_thresh and isinstance(self._opponent, TDPlayer):
            # Find a stronger opponent
            self._opponent._load()
            self._opponent.player_name = 'Shadow_{}'.format(
                self._opponent.counter)
            console.show_status('Opponent updated')

        console.print_progress(progress=progress)
Example #14
def convert_to_tframe_files(data_dir,
                            fs,
                            csv_path=None,
                            label_path=None,
                            to_dir=None,
                            verbose=True):
    # Check paths
    check_path(data_dir, create_path=False)
    if csv_path is not None: check_path(csv_path, create_path=False)
    if label_path is not None: check_path(label_path, create_path=False)
    if to_dir is None:
        to_dir = os.path.join(
            os.path.dirname(data_dir),
            'tfd_{}_{}Hz'.format(os.path.basename(data_dir), fs))
    check_path(to_dir, create_path=True)

    # Generate file list
    srcs, dsts = [], []
    for f in os.listdir(data_dir):
        file_path = os.path.join(data_dir, f)
        if not os.path.isfile(file_path) or f[-4:] != '.wav': continue
        srcs.append(file_path)
        dsts.append(os.path.join(to_dir, f))

    # Load csv file and label sheet if necessary
    csv, lb_sheet = None, None
    if csv_path is not None and label_path is not None:
        csv = pd.read_csv(csv_path)
        lb_sheet = pd.read_csv(label_path)

    # Wrap each .wav file in data_dir into a SignalSet
    num_files = len(srcs)
    console.show_status('Converting {} ...'.format(data_dir))
    for i, src, dst in zip(range(num_files), srcs, dsts):
        s = wav_to_signal(src, fs)
        ss = SignalSet(s, name=os.path.basename(src))
        if csv is not None:
            fname = os.path.basename(src)
            ss.data_dict[pedia.labels] = [_get_one_hot(fname, csv, lb_sheet)]

        ss.save(dst)
        if verbose: console.print_progress(i + 1, num_files)

    # Show status
    console.show_status('Data saved to {}'.format(to_dir))
Example #15
 def make_strings(cls, num, unique=True, exclusive=None, embedded=False,
                  verbose=False):
   # Check input
   if exclusive is None: exclusive = []
   elif not isinstance(exclusive, list):
     raise TypeError('!! exclusive must be a list of Reber strings')
   # Make strings
   reber_list = []
   for i in range(num):
     while True:
       string = ReberGrammar(embedded)
       if unique and string in reber_list: continue
       if string in exclusive: continue
       reber_list.append(string)
       break
     if verbose:
       console.clear_line()
       console.print_progress(i + 1, num)
   if verbose: console.clear_line()
   # Return a list of Reber strings
   return reber_list
Example #16
    def merge_to_signal_set(self, save_as=None):
        signals = []
        onehot_labels = []
        file_names = []
        groups = []

        # Put each signal together with its one-hot label into lists
        for wav_names in self.groups.values():
            indices = []
            for wav_name in wav_names:
                file_path = os.path.join(self.data_dir, wav_name + '.tfds')
                # Load signal set from disk
                data_set = self._load_data_set(file_path)
                assert isinstance(data_set, SignalSet)
                assert len(data_set.signals) == 1 and len(
                    data_set[pedia.labels]) == 1
                # Append signal and label to corresponding list
                signals.append(data_set.signals[0])
                onehot_labels.append(data_set[pedia.labels][0])
                file_names.append(wav_name)
                # Print progress
                console.print_progress(len(signals), len(self.files))
                indices.append(len(signals) - 1)

            groups.append(indices)

        # Wrap data_dict and properties
        data_dict, properties = {}, {}
        data_dict[pedia.labels] = onehot_labels
        properties[self.LABEL_INDEX] = self.label_index
        properties[self.DATA_INFO] = self.data_info
        properties[self.WAV_NAMES] = file_names
        properties[self.GROUPS] = groups

        # Save and return
        gpat_signal_set = GPATSignalSet(signals,
                                        data_dict=data_dict,
                                        **properties)
        if save_as is not None: gpat_signal_set.save(save_as)
        return gpat_signal_set
Example #17
def down_sample(data_dir, sample_rate, to_path=None, verbose=True):
    """ Common sample frequency list:
   8000 Hz - fs for telephone
  11025 Hz -
  22050 Hz - fs for radio
  32000 Hz - fs for miniDV
  44100 Hz - fs for CD
  """
    # Check data directory
    check_path(data_dir, create_path=False)

    # Check to_path
    if to_path is None:
        to_path = os.path.join(
            os.path.dirname(data_dir),
            '{}_{}Hz'.format(os.path.basename(data_dir), sample_rate))
    check_path(to_path, create_path=True)

    # Generate file list
    srcs, dsts = [], []
    for f in os.listdir(data_dir):
        file_path = os.path.join(data_dir, f)
        if not os.path.isfile(file_path) or f[-4:] != '.wav': continue
        srcs.append(file_path)
        dsts.append(os.path.join(to_path, f))

    # Down sample each .wav file in data_dir
    num_files = len(srcs)
    console.show_status('Down sampling ...')
    for i, src, dst in zip(range(num_files), srcs, dsts):
        data, _ = librosa.core.load(src, sample_rate, res_type='kaiser_fast')
        librosa.output.write_wav(dst, data, sample_rate)
        if verbose: console.print_progress(i + 1, num_files)

    # Show status
    console.show_status('Data saved to {}'.format(to_path))
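One portability note: librosa.output.write_wav, used above, was removed in librosa 0.8, and recent releases require the sample rate as a keyword argument to load. A sketch of the equivalent inner-loop step for a modern librosa, assuming the separate soundfile package is installed:

# Equivalent resample-and-write step for librosa >= 0.8 (sketch only;
# 'kaiser_fast' additionally needs the optional resampy dependency).
import librosa
import soundfile as sf

data, _ = librosa.load(src, sr=sample_rate, res_type='kaiser_fast')
sf.write(dst, data, sample_rate)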
Example #18
 def show(self, index):
     console.print_progress(index, self._total, start_time=self._start_time)
Example #19
 def _progress(count, block_size, total_size):
     console.clear_line()
     console.print_progress(count * block_size, total_size, start_time)
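Incidentally, the (count, block_size, total_size) signature in Example #19 matches the reporthook callback of urllib.request.urlretrieve. All of the snippets above share one small console API: show_status, clear_line, write_line, pprint and print_progress. The stand-in below is only a sketch that satisfies the call signatures seen on this page, not the real tframe console module:

# Minimal stand-in for the console helpers used in these examples.
# Sketch only: it mimics the call signatures seen above, not the actual
# tframe implementation.
import pprint as _pprint
import sys
import time


def clear_line():
    # Blank out the current terminal line and return the cursor to column 0
    sys.stdout.write('\r' + ' ' * 79 + '\r')
    sys.stdout.flush()


def show_status(content, symbol='>>'):
    sys.stdout.write('{} {}\n'.format(symbol, content))


def write_line(content):
    sys.stdout.write('{}\n'.format(content))


def pprint(obj):
    _pprint.pprint(obj)


def print_progress(index=None, total=None, start_time=None, progress=None):
    # Accept either (index, total) or a pre-computed progress in [0, 1]
    if progress is None:
        assert index is not None and total
        progress = min(float(index) / total, 1.0)
    bar_length = 30
    filled = int(round(bar_length * progress))
    bar = '=' * filled + '-' * (bar_length - filled)
    elapsed = ''
    if start_time is not None:
        elapsed = ' ({:.1f} sec)'.format(time.time() - start_time)
    sys.stdout.write('\r[{}] {:.0f}%{}'.format(bar, 100 * progress, elapsed))
    sys.stdout.flush()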