Esempio n. 1
0
    def epoch(self, subset, batch_size, shuffle=False):
        """Yield padded mini-batches of signals with sparse text targets.

        Every example of ``subset`` is visited at least once.  When the
        subset size is not a multiple of ``batch_size``, the final batch is
        taken from the last ``batch_size`` positions of the (possibly
        shuffled) ordering, so it overlaps the previous batch rather than
        being smaller.  If the subset holds fewer than ``batch_size``
        examples, a single batch containing all of them is yielded.

        Args:
            subset: key into ``self.subset``.
            batch_size: number of examples per batch.
            shuffle: if True, visit examples in a random order.

        Yields:
            ``(signals_batch, (text_indices, text_values, text_shape))``.
            ``signals_batch`` stacks the batch's signals, each zero-padded
            with ``utils.random_zeropad`` along axis -2 to the longest signal
            of the batch.  The text triple is a sparse (indices, values,
            dense_shape) encoding: ``text_indices`` is an ``(nnz, 2)`` array
            of ``(example, position)`` pairs, ``text_values`` the
            concatenated texts, and ``text_shape`` is
            ``(examples_in_batch, longest_text_in_batch)``.

        Raises:
            KeyError: if ``subset`` is not a known subset.
        """
        # TODO add a validation set ?
        if subset not in self.subset:
            raise KeyError(
                'Unknown subset "%s", valid options are %s' %
                (subset, list(self.subset.keys())))
        signals_li, phonemes_li, texts_li = self.subset[subset]
        tot_size = len(signals_li)
        # Internal consistency of the loaded subset.
        assert tot_size == len(phonemes_li)
        assert tot_size == len(texts_li)
        if shuffle:
            idx_li = np.random.permutation(tot_size)
        else:
            idx_li = np.arange(tot_size)

        def make_batch(batch_idx):
            """Build one batch from the example indices ``batch_idx``."""
            signals_batch_li = [signals_li[j] for j in batch_idx]
            texts_batch_li = [texts_li[j] for j in batch_idx]
            # Pad to the longest signal / text of *this* batch (the previous
            # remainder path looked at unshuffled neighbours instead).
            sig_len = max(map(len, signals_batch_li))
            txt_len = max(map(len, texts_batch_li))
            signals_batch = np.stack(
                [utils.random_zeropad(s, sig_len - len(s), axis=-2)
                 for s in signals_batch_li])
            text_indices = np.empty(
                (sum(map(len, texts_batch_li)), 2), dtype=hparams.INTX)
            text_values = np.concatenate(texts_batch_li)

            # Row k of the sparse matrix holds the k-th text of the batch.
            offset = 0
            for row, text in enumerate(texts_batch_li):
                length = len(text)
                text_indices[offset:offset + length, 0] = row
                text_indices[offset:offset + length, 1] = np.arange(length)
                offset += length

            text_shape = (len(batch_idx), txt_len)
            return signals_batch, (text_indices, text_values, text_shape)

        # `+ 1` keeps the batch that ends exactly at tot_size (previously it
        # was dropped, and a subset of exactly batch_size yielded nothing).
        for start in range(0, tot_size - batch_size + 1, batch_size):
            yield make_batch(idx_li[start:start + batch_size])

        if tot_size % batch_size:
            yield make_batch(idx_li[-batch_size:])
Esempio n. 2
0
 def epoch(self, subset, batch_size, shuffle=False):
     """Yield batches of zero-padded spectra from ``subset``.

     The split size is read from ``self.h5file.attrs['split']``.  The index
     list is rounded up to a whole number of batches by wrapping around
     (``indices %= dset_size``), so every batch has exactly ``batch_size``
     examples and a few examples may appear twice per epoch.

     Args:
         subset: one of ``'train'``, ``'valid'`` or ``'test'``.
         batch_size: number of examples per batch.
         shuffle: if True, visit examples in random order.

     Yields:
         1-tuples ``(spectra,)`` where ``spectra`` stacks the batch's
         examples, each zero-padded with ``utils.random_zeropad`` along
         axis -2 to the longest example of the batch.

     Raises:
         KeyError: if ``subset`` is not a known split.
     """
     split_idx = dict(train=0, valid=1, test=2)
     if subset not in split_idx or subset not in self.subset:
         raise KeyError(
             'Unknown subset "%s", valid options are %s' %
             (subset, sorted(split_idx)))
     dataset = self.subset[subset]
     dset_size = self.h5file.attrs['split'][split_idx[subset]][3]
     # Round the epoch up to whole batches, wrapping back to the start.
     indices = np.arange(
         ((dset_size + batch_size - 1) // batch_size) * batch_size)
     indices %= dset_size
     if shuffle:
         np.random.shuffle(indices)
     req_itor = SequentialScheme(
         examples=indices, batch_size=batch_size).get_request_iterator()
     handle = dataset.open()
     try:
         for req in req_itor:
             data_pt = dataset.get_data(handle, req)
             max_len = max(map(len, data_pt[0]))
             spectra = np.stack([
                 utils.random_zeropad(x, max_len - len(x), axis=-2)
                 for x in data_pt[0]])
             yield (spectra,)
     finally:
         # Also runs when the consumer stops early (e.g. `break`), which
         # closes this generator via GeneratorExit; otherwise the handle
         # would leak.
         dataset.close(handle)
Esempio n. 3
0
def _first_batch(batches):
    """Return the first item of the iterable ``batches``, then close it.

    Closing matters for generator-based ``epoch`` implementations that keep
    a resource (such as an open dataset handle) alive while iterating.

    Raises:
        RuntimeError: if ``batches`` yields nothing.
    """
    try:
        for batch in batches:
            return batch
    finally:
        close = getattr(batches, 'close', None)
        if close is not None:
            close()
    raise RuntimeError('dataset epoch yielded no batch')


def _pad_and_stack(signals):
    """Zero-pad ``signals`` along axis -2 to a common length and stack them.

    The common length is the longest signal rounded up to a multiple of
    ``hparams.LENGTH_ALIGN``; padding is done with ``utils.random_zeropad``.
    """
    max_len = max(map(len, signals))
    max_len += (-max_len) % hparams.LENGTH_ALIGN
    return np.stack([
        utils.random_zeropad(x, max_len - len(x), axis=-2) for x in signals
    ])


def main():
    """Command-line entry point: parse arguments, build the model, run a mode.

    Supported modes are "train", "valid", "test", "demo", "debug" and
    "interactive".  Sets the module globals ``g_args``, ``g_model`` and
    ``g_dataset``; "interactive" mode returns right after building so they
    can be used from a ``python -i`` session.

    Raises:
        ValueError: if ``--mode`` is not a supported mode.
    """
    global g_args, g_model, g_dataset
    modes = ('train', 'valid', 'test', 'demo', 'debug', 'interactive')
    parser = argparse.ArgumentParser()
    parser.add_argument('-n',
                        '--name',
                        default='UnnamedExperiment',
                        help='name of experiment, affects checkpoint saves')
    parser.add_argument(
        '-m',
        '--mode',
        default='train',
        help='Mode, "train", "valid", "test", "demo", "debug" or '
             '"interactive"')
    parser.add_argument('-i',
                        '--input-pfile',
                        help='path to input model parameter file')
    parser.add_argument('-o',
                        '--output-pfile',
                        help='path to output model parameters file')
    parser.add_argument('-ne',
                        '--num-epoch',
                        type=int,
                        default=10,
                        help='number of training epoch')
    parser.add_argument('--no-save-on-epoch',
                        action='store_true',
                        help="don't save parameter after each epoch")
    parser.add_argument('--no-valid-on-epoch',
                        action='store_true',
                        help="don't sweep validation set after training epoch")
    parser.add_argument('-if',
                        '--input-file',
                        help='input WAV file for "demo" mode')
    parser.add_argument(
        '-ds',
        '--dataset',
        help='choose dataset to use, overrides hparams.DATASET_TYPE')
    parser.add_argument('-lr',
                        '--learn-rate',
                        help='Learn rate, overrides hparams.LR')
    parser.add_argument(
        '-tl',
        '--train-length',
        help='segment length during training, overrides hparams.MAX_TRAIN_LEN')
    parser.add_argument('-bs',
                        '--batch-size',
                        help='set batch size, overrides hparams.BATCH_SIZE')
    g_args = parser.parse_args()

    # Fail fast, before the (slow) dataset loading and model building.
    if g_args.mode not in modes:
        raise ValueError('Unknown mode "%s"' % g_args.mode)

    # TODO manage device

    # Do override from arguments.  Use parser.error rather than assert so
    # the checks survive `python -O`.
    if g_args.learn_rate is not None:
        hparams.LR = float(g_args.learn_rate)
        if not hparams.LR >= 0.:
            parser.error('--learn-rate must be >= 0')
    if g_args.train_length is not None:
        hparams.MAX_TRAIN_LEN = int(g_args.train_length)
        if hparams.MAX_TRAIN_LEN < 2:
            parser.error('--train-length must be >= 2')
    if g_args.dataset is not None:
        hparams.DATASET_TYPE = g_args.dataset
    if g_args.batch_size is not None:
        hparams.BATCH_SIZE = int(g_args.batch_size)
        if hparams.BATCH_SIZE <= 0:
            parser.error('--batch-size must be > 0')

    stdout.write('Preparing dataset "%s" ... ' % hparams.DATASET_TYPE)
    stdout.flush()
    g_dataset = hparams.get_dataset()()
    g_dataset.install_and_load()
    stdout.write('done\n')
    stdout.flush()

    print('Encoder type: "%s"' % hparams.ENCODER_TYPE)
    print('Separator type: "%s"' % hparams.SEPARATOR_TYPE)
    print('Training estimator type: "%s"' % hparams.TRAIN_ESTIMATOR_METHOD)
    print('Inference estimator type: "%s"' % hparams.INFER_ESTIMATOR_METHOD)

    stdout.write('Building model ... ')
    stdout.flush()
    g_model = Model(name=g_args.name)
    if g_args.mode in ['demo', 'debug']:
        # Single-example inference; must be set before g_model.build().
        hparams.BATCH_SIZE = 1
        print(
            ('\n  Warning: setting hparams.BATCH_SIZE to 1 for "%s" mode'
             '\n... ') % g_args.mode,
            end='')
        if g_args.mode == 'debug':
            hparams.DEBUG = True
    g_model.build()
    stdout.write('done\n')

    g_model.reset()
    if g_args.input_pfile is not None:
        stdout.write('Loading parameters from %s ... ' % g_args.input_pfile)
        g_model.load_params(g_args.input_pfile)
        stdout.write('done\n')
    stdout.flush()

    if g_args.mode == 'interactive':
        print('Now in interactive mode, you should run this with python -i')
        return
    elif g_args.mode == 'train':
        g_model.train(n_epoch=g_args.num_epoch, dataset=g_dataset)
        if g_args.output_pfile is not None:
            stdout.write('Saving parameters into %s ... ' %
                         g_args.output_pfile)
            stdout.flush()
            g_model.save_params(g_args.output_pfile)
            stdout.write('done\n')
            stdout.flush()
    elif g_args.mode == 'test':
        g_model.test(g_dataset)
    elif g_args.mode == 'valid':
        g_model.test(g_dataset, 'valid', 'Valid')
    elif g_args.mode == 'demo':
        # prepare data point: one evenly spaced hue per source signal
        colors = np.asarray([
            hsv_to_rgb(h, .95, .98)
            for h in np.arange(hparams.MAX_N_SIGNAL, dtype=np.float32) /
            hparams.MAX_N_SIGNAL
        ])
        if g_args.input_file is None:
            filename = 'demo.wav'
            batch = _first_batch(
                g_dataset.epoch('test', hparams.MAX_N_SIGNAL))
            src_signals = _pad_and_stack(batch[0])
            raw_mixture = np.sum(src_signals, axis=0)
            save_wavfile(filename, raw_mixture)
            true_mixture = np.log1p(np.abs(src_signals))
            true_mixture = -np.einsum('nwh,nc->whc', true_mixture, colors)
            true_mixture /= np.min(true_mixture)
        else:
            filename = g_args.input_file
            raw_mixture = load_wavfile(g_args.input_file)
            true_mixture = np.log1p(np.abs(raw_mixture))

        # run with inference mode and save results
        data_pt = (np.expand_dims(raw_mixture, 0), )
        result = g_sess.run(
            g_model.infer_fetches,
            dict(
                zip(g_model.infer_feed_keys,
                    data_pt + (hparams.DROPOUT_KEEP_PROB, ))))
        signals = result['signals'][0]
        filename, fileext = os.path.splitext(filename)
        for i, s in enumerate(signals):
            save_wavfile(filename + ('_separated_%d' % (i + 1)) + fileext, s)

        # visualize result
        if 'DISPLAY' not in os.environ:
            print('Warning: no display found, not generating plot')
            return

        import matplotlib.pyplot as plt
        signals = np.log1p(np.abs(signals))
        signals = -np.einsum('nwh,nc->nwhc', signals, colors)
        signals /= np.min(signals)
        n_plots = len(signals) + 2
        for i, s in enumerate(signals):
            plt.subplot(1, n_plots, i + 1)
            plt.imshow(np.log1p(np.abs(s)))
        fake_mixture = 0.9 * np.sum(signals, axis=0)
        plt.subplot(1, n_plots, len(signals) + 1)
        plt.imshow(fake_mixture)
        plt.subplot(1, n_plots, len(signals) + 2)
        plt.imshow(true_mixture)
        plt.show()
    elif g_args.mode == 'debug':
        batch = _first_batch(
            g_dataset.epoch('test', hparams.MAX_N_SIGNAL, shuffle=True))
        input_ = np.expand_dims(_pad_and_stack(batch[0]), 0)
        data_pt = (input_, )
        debug_data = g_sess.run(
            g_model.debug_fetches,
            dict(zip(g_model.debug_feed_keys, data_pt + (1., ))))
        debug_data['input'] = input_
        # savemat does not create missing directories.
        if not os.path.isdir('debug'):
            os.makedirs('debug')
        scipy.io.savemat('debug/debug_data.mat', debug_data)
        print('Debug data written to debug/debug_data.mat')