def testEncodeDecode(self):
        """Check that a mu-law encode/decode round trip is close to lossless.

        The first round trip of a 1000-point ramp over [-1, 1] must reproduce
        the ramp within rtol=1e-1 / atol=0.05. A second round trip of the
        already-decoded signal must reproduce it to assertAllClose's default
        tolerance, i.e. re-quantizing is (nearly) idempotent.
        """
        n_channels = 256
        ramp = np.linspace(-1, 1, 1000).astype(np.float32)

        def round_trip(signal):
            # Encode then decode `signal` inside a fresh test session.
            with self.test_session() as session:
                codes = mu_law_encode(signal, n_channels)
                return session.run(mu_law_decode(codes, n_channels))

        first_pass = round_trip(ramp)
        self.assertAllClose(ramp, first_pass, rtol=1e-1, atol=0.05)

        # Re-encoding an already-quantized waveform should leave it unchanged.
        second_pass = round_trip(first_pass)
        self.assertAllClose(first_pass, second_pass)
    def testDecodeZeros(self):
        """Decoding the mu-law codes of an all-zero signal matches the manual reference."""
        # Seeding is retained although this test draws no random numbers.
        np.random.seed(40)

        levels = 128
        silence = np.zeros(100)
        codes = manual_mu_law_encode(silence, levels)
        reference = manual_mu_law_decode(codes, levels)

        with self.test_session() as session:
            evaluated = session.run(mu_law_decode(codes, levels))

        self.assertAllEqual(reference, evaluated)
    def testDecodeUniformRandomNoise128Channels(self):
        """Compare mu_law_decode with the manual reference on uniform noise.

        Uses 128 quantization channels and 512 float64 samples drawn
        uniformly from [-1, 1). Both decoders receive the same integer codes
        produced by manual_mu_law_encode, and their outputs must be equal.

        This method was previously named ``testDecodeUniformRandomNoise``,
        the same name as a later method in this class. Python kept only the
        later definition, so this test never ran; the distinct name makes it
        run again.
        """
        # Fixed seed so the noise, and therefore the test, is reproducible.
        np.random.seed(40)

        channels = 128
        number_of_samples = 512
        x = np.random.uniform(-1, 1, number_of_samples)
        y = manual_mu_law_encode(x, channels)
        decoded_manual = manual_mu_law_decode(y, channels)

        with self.test_session() as sess:
            decode = sess.run(mu_law_decode(y, channels))

        self.assertAllEqual(decoded_manual, decode)
    def testDecodeUniformRandomNoise(self):
        """Compare mu_law_decode with the manual reference on a short float32 noise burst."""
        np.random.seed(1944)  # Fixed seed keeps the test repeatable.

        levels = 256
        noise = np.random.uniform(-1, 1, 10).astype(np.float32)
        codes = manual_mu_law_encode(noise, levels)
        reference = manual_mu_law_decode(codes, levels)

        with self.test_session() as session:
            evaluated = session.run(mu_law_decode(codes, levels))

        self.assertAllEqual(reference, evaluated)
    def testDecodeRamp(self):
        """Decode the mu-law codes of an evenly spaced ramp over [-1, 1)."""
        # Seeding is retained although this test draws no random numbers.
        np.random.seed(40)

        levels = 128
        num_samples = 512
        step = 2.0 / num_samples  # spacing between consecutive ramp values
        ramp = np.arange(-1.0, 1.0, step)
        codes = manual_mu_law_encode(ramp, levels)
        reference = manual_mu_law_decode(codes, levels)

        with self.test_session() as session:
            evaluated = session.run(mu_law_decode(codes, levels))

        self.assertAllEqual(reference, evaluated)
    def testDecodeRandomConstant(self):
        """Decode a 512-sample signal that repeats one random value in [-1, 1)."""
        np.random.seed(40)

        levels = 128
        num_samples = 512
        level = np.random.uniform(-1, 1)
        constant = np.full(num_samples, level, dtype=np.float64)
        codes = manual_mu_law_encode(constant, levels)
        reference = manual_mu_law_decode(codes, levels)

        with self.test_session() as session:
            evaluated = session.run(mu_law_decode(codes, levels))

        self.assertAllEqual(reference, evaluated)
# Example No. 7
0
def main():
    """Generate audio from a trained WaveNet checkpoint.

    Steps:
    - Read the command-line arguments and the JSON hyper-parameter file
      ``args.wavenet_params``.
    - Restore the model from ``args.checkpoint``.
    - Draw ``args.samples`` samples one at a time. Each sample is chosen from
      the predicted probability distribution over the quantization channels.

    Outputs:
    - The decoded waveform is written as an audio summary under
      ``<args.logdir>/train/<timestamp>/generation/<timestamp>``.
    - If ``args.wav_out_path`` is set, the waveform is also written as a WAV
      file every ``args.save_every`` steps (when that is set) and once at the
      end.
    """
    args = get_arguments()
    logdir = os.path.join(args.logdir, 'train', str(datetime.now()))
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)
    quantization_channels = wavenet_params['quantization_channels']
    sample_rate = wavenet_params['sample_rate']

    sess = tf.Session()

    net = WaveNet(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=quantization_channels,
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        fast_generation=args.fast_generation)

    samples = tf.placeholder(tf.int32)

    next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    # Variables whose names contain 'state_buffer' or 'pointer' are left out
    # of the restore. Presumably they are generation-only state that is
    # absent from training checkpoints -- confirm against WaveNet.
    # name[:-2] presumably strips the ':0' tensor-output suffix.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, quantization_channels)

    if args.wav_seed:
        seed = create_seed(args.wav_seed, sample_rate, quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels,
                                     size=(1, )).tolist()

    for step in range(args.samples):
        if args.fast_generation:
            # Fast generation feeds only the newest sample and runs the
            # net's push ops together with the prediction.
            window = waveform[-1]
            outputs = [next_sample]
            outputs.extend(net.push_ops)
        else:
            # Otherwise feed at most the last `args.window` samples.
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)
        print('Sample {:3<d}/{:3<d}: {}'.format(step + 1, args.samples,
                                                sample))

        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):

            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, sample_rate, args.wav_out_path)

    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(
        os.path.join(logdir, 'generation', datestring))
    tf.audio_summary('generated', decode, sample_rate)
    summaries = tf.merge_all_summaries()

    # audio_summary interprets a 2-D tensor as [batch_size, frames]. Feed the
    # whole waveform as ONE clip of len(waveform) frames. The former [-1, 1]
    # shape produced len(waveform) clips of a single frame each.
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [1, -1])})
    writer.add_summary(summary_out)
    # Close the writer so pending events are flushed before the script exits.
    writer.close()

    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, sample_rate, args.wav_out_path)

    sess.close()
    print('Finished generating. The result can be viewed in TensorBoard.')
    def testDecodeNegativeDilation(self):
        """Assert that decoding ``y`` with ``channels = 10`` raises TypeError.

        NOTE(review): the method name does not match what is tested (no
        dilation is involved). Most values in ``y`` exceed ``channels - 1``;
        confirm which condition is meant to raise, and that mu_law_decode
        actually raises TypeError for it. The previous assertion passed
        regardless of what mu_law_decode did, so this test may now fail and
        expose that mismatch.
        """
        channels = 10
        y = [0, 255, 243, 31, 156, 229, 0, 235, 202, 18]

        with self.test_session() as sess:
            # Build and run the op inside the assertion context. The old form,
            # assertRaises(TypeError, sess.run(...)), evaluated the call
            # eagerly. Any exception escaped the assertion. If none was
            # raised, assertRaises tried to call the returned array, and that
            # "not callable" TypeError made the test pass vacuously.
            with self.assertRaises(TypeError):
                sess.run(mu_law_decode(y, channels))