def epoch(self, subset, batch_size, shuffle=False):
    # TODO add a validation set ?
    if subset not in self.subset:
        raise KeyError(
            'Unknown subset "%s", valid options are %s'
            % (subset, list(self.subset.keys())))
    signals_li, phonemes_li, texts_li = self.subset[subset]
    tot_size = len(signals_li)
    batch_size_ = batch_size
    assert tot_size == len(phonemes_li)
    assert tot_size == len(texts_li)
    if shuffle:
        idx_li = np.random.permutation(tot_size)
    else:
        idx_li = np.arange(tot_size)
    # iterate over full batches; the "+ 1" upper bound includes the last
    # full batch when tot_size is an exact multiple of batch_size_
    for i in range(0, tot_size - batch_size_ + 1, batch_size_):
        signals_batch_li = [signals_li[j] for j in idx_li[i:i+batch_size_]]
        texts_batch_li = [texts_li[j] for j in idx_li[i:i+batch_size_]]
        sig_len = max(map(len, signals_batch_li))
        txt_len = max(map(len, texts_batch_li))
        # zero-pad every signal to the longest one in this batch
        signals_batch = np.stack(
            [utils.random_zeropad(s, sig_len - len(s), axis=-2)
             for s in signals_batch_li])
        # build the COO (indices, values, shape) triple for the ragged
        # text batch: row = example index, column = position in example
        text_indices = np.empty(
            (sum(map(len, texts_batch_li)), 2), dtype=hparams.INTX)
        text_values = np.concatenate(texts_batch_li)
        idx = 0
        for j, t in enumerate(texts_batch_li):
            l = len(t)
            text_indices[idx:idx+l, 0] = j
            text_indices[idx:idx+l, 1] = np.arange(l)
            idx += l
        text_shape = (batch_size_, txt_len)
        yield signals_batch, (text_indices, text_values, text_shape)
    if tot_size % batch_size_:
        # trailing partial batch: reuse the last batch_size_ permuted
        # indices so the batch stays full-sized (a few examples repeat)
        signals_batch_li = [signals_li[j] for j in idx_li[-batch_size_:]]
        texts_batch_li = [texts_li[j] for j in idx_li[-batch_size_:]]
        # pad to the longest item in this batch
        sig_len = max(map(len, signals_batch_li))
        txt_len = max(map(len, texts_batch_li))
        signals_batch = np.stack(
            [np.pad(s, ((0, sig_len - len(s)), (0, 0)), mode='constant')
             for s in signals_batch_li])
        text_indices = np.empty(
            (sum(map(len, texts_batch_li)), 2), dtype=hparams.INTX)
        text_values = np.concatenate(texts_batch_li)
        idx = 0
        for j, t in enumerate(texts_batch_li):
            l = len(t)
            text_indices[idx:idx+l, 0] = j
            text_indices[idx:idx+l, 1] = np.arange(l)
            idx += l
        text_shape = (batch_size_, txt_len)
        yield signals_batch, (text_indices, text_values, text_shape)
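# Usage sketch: how a batch from epoch() above could be fed to a
# TensorFlow 1.x graph (TF1 is assumed, since this project drives g_sess
# with feed dicts elsewhere). The function and placeholder names here are
# hypothetical, for illustration only.
def _example_epoch_consumer(dataset, batch_size=32):
    import tensorflow as tf
    # dense (batch, time, feature) signals
    signals_ph = tf.placeholder(tf.float32, shape=[None, None, None])
    # ragged texts arrive as exactly the COO triple a SparseTensor expects
    texts_ph = tf.sparse_placeholder(tf.int32)
    for signals_batch, (indices, values, shape) in dataset.epoch(
            'train', batch_size, shuffle=True):
        feed = {signals_ph: signals_batch,
                texts_ph: tf.SparseTensorValue(indices, values, shape)}
        # a real loop would call e.g. sess.run(train_op, feed_dict=feed)
        del feed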
def epoch(self, subset, batch_size, shuffle=False):
    dataset = self.subset[subset]
    handle = dataset.open()
    # attrs['split'] rows are ordered (train, valid, test); index 3 of the
    # selected row is used as the number of examples in that split
    dset_size = self.h5file.attrs['split'][
        dict(train=0, valid=1, test=2)[subset]][3]
    # round the index list up to a multiple of batch_size and wrap around
    # with a modulo, so every batch is full-sized
    indices = np.arange(
        ((dset_size + batch_size - 1) // batch_size) * batch_size)
    indices %= dset_size
    if shuffle:
        np.random.shuffle(indices)
    req_itor = SequentialScheme(
        examples=indices, batch_size=batch_size).get_request_iterator()
    for req in req_itor:
        data_pt = dataset.get_data(handle, req)
        # zero-pad each spectrogram along the time axis to the batch maximum
        max_len = max(map(len, data_pt[0]))
        spectra_li = [utils.random_zeropad(x, max_len - len(x), axis=-2)
                      for x in data_pt[0]]
        spectra = np.stack(spectra_li)
        yield (spectra,)
    dataset.close(handle)
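# Sketch of the index wrap-around used above (illustration only).
# With dset_size=10 and batch_size=4, the index list is rounded up to a
# multiple of batch_size and wrapped, so the last batch stays full-sized
# at the cost of repeating a couple of examples:
#
#   >>> import numpy as np
#   >>> dset_size, batch_size = 10, 4
#   >>> idx = np.arange(((dset_size + batch_size - 1) // batch_size)
#   ...                 * batch_size) % dset_size
#   >>> idx.reshape(-1, batch_size)
#   array([[0, 1, 2, 3],
#          [4, 5, 6, 7],
#          [8, 9, 0, 1]])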
def main():
    global g_args, g_model, g_dataset
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--name', default='UnnamedExperiment',
                        help='name of experiment, affects checkpoint saves')
    parser.add_argument('-m', '--mode', default='train',
                        help='mode: "train", "valid", "test", "demo",'
                             ' "debug" or "interactive"')
    parser.add_argument('-i', '--input-pfile',
                        help='path to input model parameter file')
    parser.add_argument('-o', '--output-pfile',
                        help='path to output model parameter file')
    parser.add_argument('-ne', '--num-epoch', type=int, default=10,
                        help='number of training epochs')
    parser.add_argument('--no-save-on-epoch', action='store_true',
                        help="don't save parameters after each epoch")
    parser.add_argument('--no-valid-on-epoch', action='store_true',
                        help="don't sweep the validation set after each"
                             " training epoch")
    parser.add_argument('-if', '--input-file',
                        help='input WAV file for "demo" mode')
    parser.add_argument('-ds', '--dataset',
                        help='dataset to use, overrides hparams.DATASET_TYPE')
    parser.add_argument('-lr', '--learn-rate',
                        help='learning rate, overrides hparams.LR')
    parser.add_argument('-tl', '--train-length',
                        help='segment length during training,'
                             ' overrides hparams.MAX_TRAIN_LEN')
    parser.add_argument('-bs', '--batch-size',
                        help='batch size, overrides hparams.BATCH_SIZE')
    g_args = parser.parse_args()

    # TODO manage device

    # apply overrides from arguments
    if g_args.learn_rate is not None:
        hparams.LR = float(g_args.learn_rate)
        assert hparams.LR >= 0.
    if g_args.train_length is not None:
        hparams.MAX_TRAIN_LEN = int(g_args.train_length)
        assert hparams.MAX_TRAIN_LEN >= 2
    if g_args.dataset is not None:
        hparams.DATASET_TYPE = g_args.dataset
    if g_args.batch_size is not None:
        hparams.BATCH_SIZE = int(g_args.batch_size)
        assert hparams.BATCH_SIZE > 0

    stdout.write('Preparing dataset "%s" ... ' % hparams.DATASET_TYPE)
    stdout.flush()
    g_dataset = hparams.get_dataset()()
    g_dataset.install_and_load()
    stdout.write('done\n')
    stdout.flush()

    print('Encoder type: "%s"' % hparams.ENCODER_TYPE)
    print('Separator type: "%s"' % hparams.SEPARATOR_TYPE)
    print('Training estimator type: "%s"' % hparams.TRAIN_ESTIMATOR_METHOD)
    print('Inference estimator type: "%s"' % hparams.INFER_ESTIMATOR_METHOD)

    stdout.write('Building model ... ')
    stdout.flush()
    g_model = Model(name=g_args.name)
    if g_args.mode in ['demo', 'debug']:
        hparams.BATCH_SIZE = 1
        print(
            '\n Warning: setting hparams.BATCH_SIZE to 1 for "%s" mode'
            '\n... ' % g_args.mode, end='')
    if g_args.mode == 'debug':
        hparams.DEBUG = True
    g_model.build()
    stdout.write('done\n')

    g_model.reset()
    if g_args.input_pfile is not None:
        stdout.write('Loading parameters from %s ... ' % g_args.input_pfile)
        g_model.load_params(g_args.input_pfile)
        stdout.write('done\n')
    stdout.flush()

    if g_args.mode == 'interactive':
        print('Now in interactive mode, you should run this with python -i')
        return
    elif g_args.mode == 'train':
        g_model.train(n_epoch=g_args.num_epoch, dataset=g_dataset)
        if g_args.output_pfile is not None:
            stdout.write(
                'Saving parameters into %s ... ' % g_args.output_pfile)
            stdout.flush()
            g_model.save_params(g_args.output_pfile)
            stdout.write('done\n')
            stdout.flush()
    elif g_args.mode == 'test':
        g_model.test(g_dataset)
    elif g_args.mode == 'valid':
        g_model.test(g_dataset, 'valid', 'Valid')
    elif g_args.mode == 'demo':
        # prepare data point: one distinct hue per source for visualization
        colors = np.asarray([
            hsv_to_rgb(h, .95, .98)
            for h in (np.arange(hparams.MAX_N_SIGNAL, dtype=np.float32)
                      / hparams.MAX_N_SIGNAL)])
        if g_args.input_file is None:
            filename = 'demo.wav'
            # grab a single batch of MAX_N_SIGNAL test signals
            for src_signals in g_dataset.epoch('test', hparams.MAX_N_SIGNAL):
                break
            max_len = max(map(len, src_signals[0]))
            max_len += (-max_len) % hparams.LENGTH_ALIGN
            src_signals_li = [
                utils.random_zeropad(x, max_len - len(x), axis=-2)
                for x in src_signals[0]]
            src_signals = np.stack(src_signals_li)
            raw_mixture = np.sum(src_signals, axis=0)
            save_wavfile(filename, raw_mixture)
            # render the ground-truth mixture as a colored spectrogram
            true_mixture = np.log1p(np.abs(src_signals))
            true_mixture = -np.einsum('nwh,nc->whc', true_mixture, colors)
            true_mixture /= np.min(true_mixture)
        else:
            filename = g_args.input_file
            raw_mixture = load_wavfile(g_args.input_file)
            true_mixture = np.log1p(np.abs(raw_mixture))

        # run in inference mode and save results
        data_pt = (np.expand_dims(raw_mixture, 0),)
        result = g_sess.run(
            g_model.infer_fetches,
            dict(zip(g_model.infer_feed_keys,
                     data_pt + (hparams.DROPOUT_KEEP_PROB,))))
        signals = result['signals'][0]
        filename, fileext = os.path.splitext(filename)
        for i, s in enumerate(signals):
            save_wavfile(
                filename + ('_separated_%d' % (i + 1)) + fileext, s)

        # visualize result
        if 'DISPLAY' not in os.environ:
            print('Warning: no display found, not generating plot')
            return
        import matplotlib.pyplot as plt
        signals = np.log1p(np.abs(signals))
        signals = -np.einsum('nwh,nc->nwhc', signals, colors)
        signals /= np.min(signals)
        for i, s in enumerate(signals):
            plt.subplot(1, len(signals) + 2, i + 1)
            plt.imshow(np.log1p(np.abs(s)))
        fake_mixture = 0.9 * np.sum(signals, axis=0)
        # fake_mixture /= np.max(fake_mixture)
        plt.subplot(1, len(signals) + 2, len(signals) + 1)
        plt.imshow(fake_mixture)
        plt.subplot(1, len(signals) + 2, len(signals) + 2)
        plt.imshow(true_mixture)
        plt.show()
    elif g_args.mode == 'debug':
        import matplotlib.pyplot as plt
        # grab a single shuffled test batch and dump fetches for inspection
        for input_ in g_dataset.epoch(
                'test', hparams.MAX_N_SIGNAL, shuffle=True):
            break
        max_len = max(map(len, input_[0]))
        max_len += (-max_len) % hparams.LENGTH_ALIGN
        input_li = [
            utils.random_zeropad(x, max_len - len(x), axis=-2)
            for x in input_[0]]
        input_ = np.expand_dims(np.stack(input_li), 0)
        data_pt = (input_,)
        debug_data = g_sess.run(
            g_model.debug_fetches,
            dict(zip(g_model.debug_feed_keys, data_pt + (1.,))))
        debug_data['input'] = input_
        scipy.io.savemat('debug/debug_data.mat', debug_data)
        print('Debug data written to debug/debug_data.mat')
    else:
        raise ValueError('Unknown mode "%s"' % g_args.mode)
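# Example invocations (the script name and paths are placeholders; the
# flags are the ones defined in main() above):
#
#   # train for 20 epochs, then save parameters
#   python main.py -m train -n MyExperiment -ne 20 -o params.ckpt
#
#   # load saved parameters and sweep the test set
#   python main.py -m test -i params.ckpt
#
#   # separate sources from a WAV file with a trained model
#   python main.py -m demo -i params.ckpt -if mixture.wav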