    def generate_waveform(self, sess):
        samples = tf.placeholder(tf.int32)
        next_sample_probs = self.net.predict_proba_all(samples)
        operations = [next_sample_probs]

        waveform = []
        seed = create_seed("sine_train.wav",
                           SAMPLE_RATE_HZ,
                           QUANTIZATION_CHANNELS,
                           window_size=WINDOW_SIZE,
                           silence_threshold=0)
        input_waveform = sess.run(seed).tolist()
        decode = mu_law_decode(samples, QUANTIZATION_CHANNELS)
        slide_windows = 256
        for slide_start in range(0, len(input_waveform), slide_windows):
            if slide_start + slide_windows >= len(input_waveform):
                break
            input_audio_window = input_waveform[slide_start:slide_start +
                                                slide_windows]

            # Run the WaveNet to predict the next sample.
            all_prediction = sess.run(operations,
                                      feed_dict={samples:
                                                 input_audio_window})[0]
            all_prediction = np.asarray(all_prediction)
            output_waveform = get_all_output_from_predictions(all_prediction)
            print("Prediction {}".format(output_waveform))
            waveform.extend(output_waveform)

        waveform = np.array(waveform)
        decoded_waveform = sess.run(decode, feed_dict={samples: waveform})
        return decoded_waveform
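The snippet above depends on a get_all_output_from_predictions helper that is not shown. A minimal sketch, assuming predict_proba_all returns one softmax distribution per input sample and the helper greedily picks the most probable level (the actual helper may sample instead):

import numpy as np

def get_all_output_from_predictions(all_prediction):
    # all_prediction: shape [T, quantization_channels], one probability
    # distribution per input sample. Greedy decoding: take the most
    # probable quantization level at every step.
    return np.argmax(all_prediction, axis=1).tolist()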
Example #2
def generate_waveform(sess, net, fast_generation):
    samples = tf.placeholder(tf.int32)
    if fast_generation:
        next_sample_probs = net.predict_proba_incremental(samples)
        sess.run(net.init_ops)
        operations = [next_sample_probs]
        operations.extend(net.push_ops)
    else:
        next_sample_probs = net.predict_proba(samples)
        operations = [next_sample_probs]

    waveform = [128]
    decode = mu_law_decode(samples, QUANTIZATION_CHANNELS)
    for i in range(GENERATE_SAMPLES):
        if fast_generation:
            window = waveform[-1]
        else:
            if len(waveform) > 256:
                window = waveform[-256:]
            else:
                window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(operations, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(QUANTIZATION_CHANNELS), p=prediction)
        waveform.append(sample)
        # print("Generated {} of {}: {}".format(i, GENERATE_SAMPLES, sample))
        # sys.stdout.flush()

    # Skip the first number of samples equal to the size of the receptive
    # field.
    waveform = np.array(waveform[256:])
    decoded_waveform = sess.run(decode, feed_dict={samples: waveform})
    return decoded_waveform
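The hard-coded 256 above stands in for the model's receptive field. For reference, a simplified sketch of how the receptive field follows from the dilation stack (this ignores the extra context of the initial causal layer, so it is an approximation):

def calculate_receptive_field(filter_width, dilations):
    # Each dilated layer sees (filter_width - 1) * dilation additional
    # past samples; the +1 counts the current sample itself.
    return (filter_width - 1) * sum(dilations) + 1

# e.g. filter_width=2 with dilations 1..128 repeated twice:
# (2 - 1) * 510 + 1 = 511 samples of context.
print(calculate_receptive_field(2, [2 ** i for i in range(8)] * 2))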
Example #3
def generate_waveform(sess, net, fast_generation, gc, samples_placeholder,
                      gc_placeholder, operations):
    waveform = [128] * net.receptive_field
    if fast_generation:
        for sample in waveform[:-1]:
            sess.run(operations, feed_dict={samples_placeholder: [sample]})

    for i in range(GENERATE_SAMPLES):
        if i % 100 == 0:
            print("Generating {} of {}.".format(i, GENERATE_SAMPLES))
            sys.stdout.flush()
        if fast_generation:
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform

        # Run the WaveNet to predict the next sample.
        feed_dict = {samples_placeholder: window}
        if gc is not None:
            feed_dict[gc_placeholder] = gc
        results = sess.run(operations, feed_dict=feed_dict)

        sample = np.random.choice(np.arange(QC), p=results[0])
        waveform.append(sample)

    # Skip the first number of samples equal to the size of the receptive
    # field minus one.
    waveform = np.array(waveform[net.receptive_field - 1:])
    decode = mu_law_decode(samples_placeholder, QC)
    decoded_waveform = sess.run(decode,
                                feed_dict={samples_placeholder: waveform})
    return decoded_waveform
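The function above and Example #4 below expect samples_placeholder and operations to be built by the caller. A sketch of that setup, assuming a constructed WaveNetModel net, an active sess, and a fast_generation flag with the same API as Example #2 (predict_proba, predict_proba_incremental, init_ops, push_ops):

import tensorflow as tf

samples_placeholder = tf.placeholder(tf.int32)
gc_placeholder = tf.placeholder(tf.int32)

if fast_generation:
    # Incremental prediction caches layer activations in queues;
    # push_ops advance that cached state by one sample per sess.run.
    next_sample_probs = net.predict_proba_incremental(samples_placeholder)
    sess.run(net.init_ops)
    operations = [next_sample_probs] + net.push_ops
else:
    operations = [net.predict_proba(samples_placeholder)]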
Example #4
def generate_waveform(sess, net, fast_generation, gc, samples_placeholder,
                      gc_placeholder, operations):
    waveform = [128]
    results = []
    for i in range(GENERATE_SAMPLES):
        if i % 100 == 0:
            print("Generating {} of {}.".format(i, GENERATE_SAMPLES))
            sys.stdout.flush()
        if fast_generation:
            window = waveform[-1]
        else:
            if len(waveform) > RECEPTIVE_FIELD:
                # Keep only the last RECEPTIVE_FIELD samples; older samples
                # cannot influence the next prediction.
                window = waveform[-RECEPTIVE_FIELD:]
            else:
                window = waveform

        # Run the WaveNet to predict the next sample.
        feed_dict = {samples_placeholder: window}
        if gc is not None:
            feed_dict[gc_placeholder] = gc
        results = sess.run(operations, feed_dict=feed_dict)

        sample = np.random.choice(np.arange(QUANTIZATION_CHANNELS),
                                  p=results[0])
        waveform.append(sample)

    # Skip the first number of samples equal to the size of the receptive
    # field.
    waveform = waveform[RECEPTIVE_FIELD:]
    decode = mu_law_decode(samples_placeholder, QUANTIZATION_CHANNELS)
    decoded_waveform = sess.run(decode,
                                feed_dict={samples_placeholder: waveform})
    return decoded_waveform
Example #5
    def testEncodeDecode(self):
        x = np.linspace(-1, 1, 1000).astype(np.float32)
        channels = 256

        # Test whether decoded signal is roughly equal to
        # what was encoded before
        with self.test_session() as sess:
            encoded = mu_law_encode(x, channels)
            x1 = sess.run(mu_law_decode(encoded, channels))

        self.assertAllClose(x, x1, rtol=1e-1, atol=0.05)

        # Make sure that re-encoding leaves the waveform invariant
        with self.test_session() as sess:
            encoded = mu_law_encode(x1, channels)
            x2 = sess.run(mu_law_decode(encoded, channels))

        self.assertAllClose(x1, x2)
Example #6
def write_audio_not_one_hot(
        filename,
        audio,
        session,
        sample_rate=get_model_params('SAMPLE_RATE'),
        quantization_channels=get_model_params('QUANTIZATION_CHANNELS'),
        verbose=False):
    out = mu_law_decode(audio, quantization_channels)
    out_wave = session.run(out)
    write_wav(out_wave, os.path.join(get_dirs('OUTPUT'), filename),
              sample_rate, verbose)
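A hypothetical call to the function above, assuming the module's get_model_params/get_dirs defaults are configured; the constant levels are made up for illustration, and in practice they would come from a generation loop:

import tensorflow as tf

audio_levels = tf.constant([128, 131, 127, 125], dtype=tf.int32)
with tf.Session() as sess:
    write_audio_not_one_hot('generated.wav', audio_levels, sess, verbose=True)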
Example #7
    def testDecodeUniformRandomNoise(self):
        np.random.seed(1944)  # For repeatability of test.

        channels = 256
        number_of_samples = 10
        x = np.random.uniform(-1, 1, number_of_samples).astype(np.float32)
        y = manual_mu_law_encode(x, channels)
        manual_decode = manual_mu_law_decode(y, channels)

        with self.test_session() as sess:
            decode = sess.run(mu_law_decode(y, channels))

        self.assertAllEqual(manual_decode, decode)
Example #8
    def testDecodeUniformRandomNoise(self):
        np.random.seed(40)

        channels = 128
        number_of_samples = 512
        x = np.random.uniform(-1, 1, number_of_samples)
        y = manual_mu_law_encode(x, channels)
        decoded_manual = manual_mu_law_decode(y, channels)

        with self.test_session() as sess:
            decode = sess.run(mu_law_decode(y, channels))

        self.assertAllEqual(decoded_manual, decode)
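The decode tests above and below compare the TensorFlow op against NumPy reference helpers that are not shown. One plausible sketch of manual_mu_law_encode and manual_mu_law_decode, using the standard mu-law companding formulas (an assumption; the actual helpers may differ in rounding details):

import numpy as np

def manual_mu_law_encode(signal, channels):
    # Compress the amplitude with log1p, then quantize the result in
    # [-1, 1] to integer levels 0 .. channels - 1.
    mu = channels - 1
    magnitude = np.log1p(mu * np.abs(signal)) / np.log1p(mu)
    encoded = np.sign(signal) * magnitude
    return ((encoded + 1) / 2 * mu + 0.5).astype(np.int32)

def manual_mu_law_decode(encoded, channels):
    # Map integer levels back to [-1, 1], then expand the amplitude.
    mu = channels - 1
    signal = 2.0 * (encoded / mu) - 1.0
    magnitude = (1.0 / mu) * ((1.0 + mu) ** np.abs(signal) - 1.0)
    return np.sign(signal) * magnitude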
Example #9
    def testDecodeZeros(self):
        np.random.seed(40)

        channels = 128
        number_of_samples = 100
        x = np.zeros(number_of_samples)
        y = manual_mu_law_encode(x, channels)
        decoded_manual = manual_mu_law_decode(y, channels)

        with self.test_session() as sess:
            decode = sess.run(mu_law_decode(y, channels))

        self.assertAllEqual(decoded_manual, decode)
Example #10
    def testDecodeRamp(self):
        np.random.seed(40)

        channels = 128
        number_of_samples = 512
        number_of_steps = 2.0 / number_of_samples
        x = np.arange(-1.0, 1.0, number_of_steps)
        y = manual_mu_law_encode(x, channels)
        decoded_manual = manual_mu_law_decode(y, channels)

        with self.test_session() as sess:
            decode = sess.run(mu_law_decode(y, channels))

        self.assertAllEqual(decoded_manual, decode)
Example #11
    def testDecodeRandomConstant(self):
        np.random.seed(40)

        channels = 128
        number_of_samples = 512
        x = np.zeros(number_of_samples)
        x.fill(np.random.uniform(-1, 1))
        y = manual_mu_law_encode(x, channels)
        decoded_manual = manual_mu_law_decode(y, channels)

        with self.test_session() as sess:
            decode = sess.run(mu_law_decode(y, channels))

        self.assertAllEqual(decoded_manual, decode)
Example #12
    def testDecodeEncode(self):
        # Generate every possible quantized level.
        x = np.array(range(QUANT_LEVELS), dtype=np.int)

        # Encode, then decode, every value.
        with self.test_session() as sess:
            # Decode into floating-point scalar.
            decoded = mu_law_decode(x, QUANT_LEVELS)
            # Encode back into an integer quantization level.
            encoded = mu_law_encode(decoded, QUANT_LEVELS)
            round_tripped = sess.run(encoded)

        # Decoding then encoding every level should produce what we started
        # with.
        self.assertAllEqual(x, round_tripped)
Example #13
    def testMinMaxRange(self):
        # Generate every possible quantized level.
        x = np.array(range(QUANT_LEVELS), dtype=np.int)

        # Decode back into float scalars.
        with self.test_session() as sess:
            # Decode into floating-point scalar.
            decoded = mu_law_decode(x, QUANT_LEVELS)
            all_scalars = sess.run(decoded)

        # Our range should be exactly [-1,1].
        max_val = np.max(all_scalars)
        min_val = np.min(all_scalars)
        EPSILON = 1e-10
        self.assertNear(max_val, 1.0, EPSILON)
        self.assertNear(min_val, -1.0, EPSILON)
Example #14
    def testEncodeDecodeShift(self):
        x = np.linspace(-1, 1, 1000).astype(np.float32)
        with self.test_session() as sess:
            encoded = mu_law_encode(x, QUANT_LEVELS)
            decoded = mu_law_decode(encoded, QUANT_LEVELS)
            roundtripped = sess.run(decoded)

        # Detect non-unity scaling and non-zero shift in the roundtripped
        # signal by asserting that slope = 1 and y-intercept = 0 of line fit to
        # roundtripped vs x values.
        coeffs = np.polyfit(x, roundtripped, 1)
        slope = coeffs[0]
        y_intercept = coeffs[1]
        EPSILON = 1e-4
        self.assertNear(slope, 1.0, EPSILON)
        self.assertNear(y_intercept, 0.0, EPSILON)
Example #15
def generate_waveform(sess, net, fast_generation, gc, samples_placeholder,
                      gc_placeholder, operations):
    waveform = [128] * net.receptive_field
    if fast_generation:
        for sample in waveform[:-1]:
            sess.run(operations, feed_dict={samples_placeholder: [sample]})

    for i in range(GENERATE_SAMPLES):
        if i % 100 == 0:
            print("Generating {} of {}.".format(i, GENERATE_SAMPLES))
            sys.stdout.flush()
        if fast_generation:
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform

        # Run the WaveNet to predict the next sample.
        feed_dict = {samples_placeholder: window}
        if gc is not None:
            feed_dict[gc_placeholder] = gc
        results = sess.run(operations, feed_dict=feed_dict)

        sample = np.random.choice(
            np.arange(QUANTIZATION_CHANNELS), p=results[0])
        waveform.append(sample)

    # Skip the first number of samples equal to the size of the receptive
    # field minus one.
    waveform = np.array(waveform[net.receptive_field - 1:])
    decode = mu_law_decode(samples_placeholder, QUANTIZATION_CHANNELS)
    decoded_waveform = sess.run(decode,
                                feed_dict={samples_placeholder: waveform})
    return decoded_waveform
Example #16
    def testDecodeNegativeDilation(self):
        channels = 10
        y = [0, 255, 243, 31, 156, 229, 0, 235, 202, 18]

        with self.test_session() as sess:
            self.assertRaises(TypeError, sess.run, mu_law_decode(y, channels))
Example #17
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    
    if args.wavenet_params.startswith('wavenet_params/'):
        with open(args.wavenet_params, 'r') as config_file:
            wavenet_params = json.load(config_file)
    elif args.wavenet_params.startswith('wavenet_params'):
        with open('wavenet_params/'+args.wavenet_params, 'r') as config_file:
            wavenet_params = json.load(config_file)
    else:
        with open('wavenet_params/wavenet_params_'+args.wavenet_params, 'r') as config_file:
            wavenet_params = json.load(config_file)    

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels,
                           net.receptive_field,
                           args.silence_threshold)
        waveform = sess.run(seed).tolist()
    else:
        # Silence with a single random sample at the end.
        waveform = [quantization_channels / 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field: -1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()

    # Convert the temperature-modulation frequency to a period.
    if args.tfrequency is not None:
        if args.tperiod is not None:
            raise ValueError("Frequency and Period both assigned. Assign only one of them.")
        else:
            PERIOD = dynamic.frequency_to_period(args.tfrequency)
    else:
        if args.tperiod is not None:
            PERIOD = args.tperiod
        else:
            PERIOD = 1

    # Generate a per-step temperature array when temperature_change is "dynamic".
    if args.temperature_change == "dynamic" and args.tform is not None:
        temp_array = dynamic.generate_value(0, wavenet_params['sample_rate'], args.tform, args.tmin, args.tmax, PERIOD, args.samples, args.tphaseshift)


    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature.
        np.seterr(divide='ignore')

        # temperature change
        if args.temperature_change is None:  # static
            _temp_temperature = args.temperature
        elif args.temperature_change == "dynamic":
            if args.tform is None:  # random
                if step % int(args.samples / 5) == 0:
                    _temp_temperature = args.temperature * np.random.rand()
            else:
                _temp_temperature = temp_array[1][step]
        else:
            raise Exception("Invalid temperature_change value.")

        scaled_prediction = np.log(prediction) / _temp_temperature
        scaled_prediction = (scaled_prediction -
                             np.logaddexp.reduce(scaled_prediction))
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0 and args.temperature_change is None:
            np.testing.assert_allclose(
                    prediction, scaled_prediction, atol=1e-5,
                    err_msg='Prediction scaling at temperature=1.0 is not '
                            'working as intended.')

        sample = np.random.choice(
            np.arange(quantization_channels), p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:<3d}/{:<3d}, temperature {:<3f}'.format(
                      step + 1, args.samples, _temp_temperature),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.summary.FileWriter(logdir)
    tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.summary.merge_all()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
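For reference, the log-space temperature scaling inside the loop above can be factored into a standalone helper; a minimal sketch (the function name is hypothetical):

import numpy as np

def scale_by_temperature(prediction, temperature):
    # Divide the log-probabilities by the temperature, then renormalize
    # via logaddexp for numerical stability. At temperature == 1.0 the
    # distribution is unchanged up to floating-point error.
    with np.errstate(divide='ignore'):
        scaled = np.log(prediction) / temperature
    scaled = scaled - np.logaddexp.reduce(scaled)
    return np.exp(scaled)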
Example #18
def generate_waveform(sess, net, fast_generation, wav_seed=False):
    samples = tf.placeholder(tf.int32)
    if fast_generation:
        next_sample_probs = net.predict_proba_incremental(samples)
        sess.run(net.init_ops)
        operations = [next_sample_probs]
        operations.extend(net.push_ops)
    else:
        next_sample_probs = net.predict_proba(samples)
        operations = [next_sample_probs]

    waveform = [128]
    if wav_seed:
        seed = create_seed("sine_train.wav",
                           SAMPLE_RATE_HZ,
                           QUANTIZATION_CHANNELS,
                           window_size=WINDOW_SIZE,
                           silence_threshold=0)
        input_waveform = sess.run(seed).tolist()
    decode = mu_law_decode(samples, QUANTIZATION_CHANNELS)
    for i in range(GENERATE_SAMPLES):
        print("=====================================================")
        if fast_generation:
            window = waveform[-1]
            if wav_seed and i < len(input_waveform):
                window = input_waveform[i]
        else:
            if len(waveform) > 256:
                window = waveform[-256:]
            else:
                window = waveform
            if wav_seed:
                if i >= len(input_waveform):
                    break
                if i - 256 >= 0:
                    f_window = input_waveform[i - 256:i]
                else:
                    f_window = input_waveform[:i]
                    print("Input {}".format(f_window))
                if len(f_window) == 0:
                    continue

        # Run the WaveNet over the full seed waveform. Note that the
        # function returns from inside the loop as soon as it reaches this
        # point, so the commented-out sampling code below never executes.
        all_prediction = sess.run([net.predict_proba_all(samples)],
                                  feed_dict={samples: input_waveform})[0]
        all_prediction = np.asarray(all_prediction)
        output_waveform = get_all_output_from_predictions(all_prediction)
        print("Prediction {}".format(output_waveform))
        decoded_waveform = sess.run(decode,
                                    feed_dict={samples: output_waveform})
        return decoded_waveform
        # prediction = sess.run(operations, feed_dict={samples: f_window})[0]
        # sample = np.random.choice(
        #     np.arange(QUANTIZATION_CHANNELS), p=prediction)
        # waveform.append(sample)

        # print("Generated {} of {}: {}".format(i, GENERATE_SAMPLES, sample))
        # sys.stdout.flush()

    # Unlike the other examples, the full waveform is kept here; no
    # receptive-field samples are skipped.
    waveform = np.array(waveform)
    decoded_waveform = sess.run(decode, feed_dict={samples: waveform})
    return decoded_waveform
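Several of these examples prime generation with a create_seed helper that is not shown. A plausible sketch, assuming librosa for loading and the mu_law_encode op from the same module (the real helper may also trim leading silence; silence_threshold is accepted but unused here):

import librosa
import tensorflow as tf

def create_seed(filename, sample_rate, quantization_channels, window_size,
                silence_threshold=0):
    # Load the seed wav, mu-law encode it, and keep at most window_size
    # samples so the seed fits the model's receptive field.
    audio, _ = librosa.load(filename, sr=sample_rate, mono=True)
    quantized = mu_law_encode(audio, quantization_channels)
    cut_index = tf.minimum(tf.size(quantized), window_size)
    return quantized[:cut_index]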
Example #19
def main(waveform, num_predictions):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    #args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    #logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(WAVENET_PARAMS, 'r') as config_file:
        wavenet_params = json.load(config_file)
    tf.reset_default_graph()
    config = tf.ConfigProto(
        device_count={'GPU': 0}  # todo this is only to generate test stuff on nonGPU
    )
    sess = tf.Session(config=config)

    # sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=None,
        global_condition_cardinality=None)

    samples = tf.placeholder(tf.int32)


    # next_sample = net.predict_proba_incremental(samples, None) #fastgen
    next_sample = net.predict_proba(samples, None) #regulargen


    #sess.run(tf.global_variables_initializer())
    #sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    #print('Restoring model from {}'.format(args.checkpoint))
    #global INITIALIZED
    #if not INITIALIZED:
    #saver.restore(sess, "logdir/train/2017-06-05T19-03-11/model.ckpt-107")
    saver.restore(sess, "logdir/train/last144kwl2/model.ckpt-27050")
    #    INITIALIZED = True

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']

    seed = create_seed(waveform,
                       wavenet_params['sample_rate'],
                       quantization_channels,
                       net.receptive_field)
    waveform = sess.run(seed).tolist()
    #print("priming size is " + str(net.receptive_field))  # TODO debug so I can figure out how long a sequence to feed


    # When using the incremental generation, we need to
    # feed in all priming samples one by one before starting the
    # actual generation.
    # TODO This could be done much more efficiently by passing the waveform
    # to the incremental generator as an optional argument, which would be
    # used to fill the queues initially.
    # outputs = [next_sample] begin fastgen commented section
    # outputs.extend(net.push_ops)
    #
    # # print('Priming generation...') # todo remove these prints for going into production
    # for i, x in enumerate(waveform[-net.receptive_field: -1]):
    #     # if i % 1000 == 0:
    #         # print('Priming sample {}'.format(i))
    #     sess.run(outputs, feed_dict={samples: x})

    last_sample_timestamp = datetime.now()
    for step in range(num_predictions):
        if False:  # fast-generation branch, disabled in this example
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature.
        np.seterr(divide='ignore')
        scaled_prediction = np.log(prediction)
        scaled_prediction = (scaled_prediction -
                             np.logaddexp.reduce(scaled_prediction))
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')
        # plt.plot(range(len(scaled_prediction)), scaled_prediction) # todo debug tool, sharp peaks means operating correctly
        # plt.show()  #
        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        # if args.temperature == 1.0:
        #     np.testing.assert_allclose(
        #             prediction, scaled_prediction, atol=1e-5,
        #             err_msg='Prediction scaling at temperature=1.0 '
        #                     'is not working as intended.')
        # TODO consider grabbing only top probability instead of this range
        sample = np.random.choice(np.arange(quantization_channels), p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        # current_sample_timestamp = datetime.now()
        # time_since_print = current_sample_timestamp - last_sample_timestamp
        # if time_since_print.total_seconds() > 1.:
        #     print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
        #           end='\r')
        #     last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        # if (args.wav_out_path and args.save_every and
        #         (step + 1) % args.save_every == 0):
        #     out = sess.run(decode, feed_dict={samples: waveform})
        #     write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    # print()

    # Save the result as an audio summary.
    # datestring = str(datetime.now()).replace(' ', 'T')
    # writer = tf.summary.FileWriter(logdir)
    # tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    # summaries = tf.summary.merge_all()
    # summary_out = sess.run(summaries,
    #                        feed_dict={samples: np.reshape(waveform, [-1, 1])})
    # writer.add_summary(summary_out)

    # Save the result as a wav file.
    # if args.wav_out_path:
    #     out = sess.run(decode, feed_dict={samples: waveform})
    #     write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
    out = sess.run(decode, feed_dict={samples: waveform})
    sess.close()
    return out
Example #20
def generate_waveform(sess,
                      net,
                      fast_generation,
                      gc,
                      lc,
                      samples_placeholder,
                      gc_placeholder,
                      lc_placeholder,
                      operations,
                      test=False,
                      fixed_wave=None):
    if test:
        logits = []

    waveform = [128] * net.receptive_field

    _lc = np.zeros((1, net.receptive_field))

    # initial_lc = [0] * net.receptive_field if lc is not None else None
    if fast_generation:
        for sample in waveform[:-1]:
            feed_dict = {samples_placeholder: [sample]}
            feed_dict[gc_placeholder] = gc if gc is not None else None
            feed_dict[lc_placeholder] = [0] if lc is not None else None
            sess.run(operations, feed_dict)
            if fixed_wave is not None:
                np.insert(fixed_wave, 0, 0)

    for i in range(GENERATE_SAMPLES):
        if i % 100 == 0:
            print("Generating {} of {}.".format(i, GENERATE_SAMPLES))
            sys.stdout.flush()
        if fast_generation:
            window = waveform[-1]
            current_lc = lc[i] if lc is not None else None
        else:
            _lc[:, :-1] = _lc[:, 1:]
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
                _lc[:, -1] = lc[i] if lc is not None else None
            else:
                window = waveform
                # current_lc = initial_lc if lc is not None else None
                _lc[:, -1] = lc[i] if lc is not None else None
            current_lc = _lc
            # current_lc = current_lc.reshape((1,))
            # print("current lc")
            # print(current_lc.shape)
            # print(gc.shape)

        # Run the WaveNet to predict the next sample.
        feed_dict = {samples_placeholder: window, lc_placeholder: current_lc}
        if gc is not None:
            feed_dict[gc_placeholder] = gc

        # if lc is not None:
        #     feed_dict[lc_placeholder] = current_lc
        results = sess.run(operations, feed_dict=feed_dict)

        if test:
            logits.append(results)
        else:
            sample = np.random.choice(np.arange(QUANTIZATION_CHANNELS),
                                      p=results[0])
            waveform.append(sample)

    # Skip the first number of samples equal to the size of the receptive
    # field minus one.
    if test:
        return logits
    else:
        waveform = np.array(waveform[net.receptive_field - 1:])
        decode = mu_law_decode(samples_placeholder, QUANTIZATION_CHANNELS)
        decoded_waveform = sess.run(decode,
                                    feed_dict={samples_placeholder: waveform})
        return decoded_waveform
Example #21
def main():

    with tf.device('/cpu:0'):
        # CPU is faster here; running on the GPU raises an error, and
        # omitting tf.device altogether is slower.
        config = get_arguments()
        started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
        logdir = os.path.join(config.logdir, 'generate', started_datestring)
        print('logdir0-------------' + logdir)

        if not os.path.exists(logdir):
            os.makedirs(logdir)

        load_hparams(hparams, config.checkpoint_dir)

        sess = tf.Session()
        scalar_input = hparams.scalar_input
        net = WaveNetModel(
            batch_size=config.batch_size,
            dilations=hparams.dilations,
            filter_width=hparams.filter_width,
            residual_channels=hparams.residual_channels,
            dilation_channels=hparams.dilation_channels,
            quantization_channels=hparams.quantization_channels,
            out_channels=hparams.out_channels,
            skip_channels=hparams.skip_channels,
            use_biases=hparams.use_biases,
            scalar_input=hparams.scalar_input,
            global_condition_channels=hparams.gc_channels,
            global_condition_cardinality=config.gc_cardinality,
            local_condition_channels=hparams.num_mels,
            upsample_factor=hparams.upsample_factor,
            legacy=hparams.legacy,
            residual_legacy=hparams.residual_legacy,
            train_mode=False
        )  # During training, AudioReader infers global_condition_cardinality; at generation time it must be supplied explicitly.

        if scalar_input:
            samples = tf.placeholder(tf.float32, shape=[net.batch_size, None])
        else:
            samples = tf.placeholder(
                tf.int32, shape=[net.batch_size, None]
            )  # samples: mu-law encoded values, before one-hot conversion. Shape: (batch_size, length).

        # The local condition is nominally (N, T, num_mels), but it is fed one step at a time as (N, 1, num_mels), which squeezes to (N, num_mels).
        upsampled_local_condition = tf.placeholder(
            tf.float32, shape=[net.batch_size, hparams.num_mels])

        next_sample = net.predict_proba_incremental(
            samples, upsampled_local_condition, [config.gc_id] * net.batch_size
        )  # Applies the Fast WaveNet Generation Algorithm (arXiv:1611.09482).

        # Build the upsampled local-condition data to feed the upsampled_local_condition placeholder.
        print('logdir0-------------' + logdir)
        mel_input = np.load(config.mel)
        sample_size = mel_input.shape[0] * hparams.hop_size
        mel_input = np.tile(mel_input, (config.batch_size, 1, 1))
        with tf.variable_scope('wavenet', reuse=tf.AUTO_REUSE):
            upsampled_local_condition_data = net.create_upsample(
                mel_input, upsample_type=hparams.upsample_type)

        var_list = [
            var for var in tf.global_variables() if 'queue' not in var.name
        ]
        saver = tf.train.Saver(var_list)
        print('Restoring model from {}'.format(config.checkpoint_dir))

        load(saver, sess, config.checkpoint_dir)
        init_op = tf.group(tf.initialize_all_variables(),
                           net.queue_initializer)

        sess.run(init_op)  # Without this, the variables keep the values restored from the checkpoint.

        quantization_channels = hparams.quantization_channels
        if config.wav_seed:
            # If wav_seed is shorter than the receptive field it arguably should be padded; as written it is returned as-is, so a seed that is too short causes an error.
            seed = create_seed(config.wav_seed, hparams.sample_rate,
                               quantization_channels, net.receptive_field,
                               scalar_input)  # --> mu-law encoded.
            if scalar_input:
                waveform = seed.tolist()
            else:
                waveform = sess.run(
                    seed).tolist()  # [116, 114, 120, 121, 127, ...]

            print('Priming generation...')
            # The very last sample is fed in the first iteration of the
            # generation loop below.
            for i, x in enumerate(waveform[-net.receptive_field:-1]):
                if i % 100 == 0:
                    print('Priming sample {}/{}'.format(
                        i, net.receptive_field),
                          end='\r')
                sess.run(next_sample,
                         feed_dict={
                             samples:
                             np.array([x] * net.batch_size).reshape(
                                 net.batch_size, 1),
                             upsampled_local_condition:
                             np.zeros([net.batch_size, hparams.num_mels])
                         })
            print('Done.')
            waveform = np.array([waveform[-net.receptive_field:]] *
                                net.batch_size)
        else:
            # Silence with a single random sample at the end.
            if scalar_input:
                waveform = [0.0] * (net.receptive_field - 1)
                waveform = np.array(waveform * net.batch_size).reshape(
                    net.batch_size, -1)
                waveform = np.concatenate(
                    [
                        waveform, 2 * np.random.rand(net.batch_size).reshape(
                            net.batch_size, -1) - 1
                    ],
                    axis=-1)  # Append a random number in [-1, 1] to the end.
                # waveform: shape (batch_size, net.receptive_field)
            else:
                waveform = [quantization_channels / 2] * (
                    net.receptive_field - 1
                )  # Build one sample short of the receptive field; a random sample is appended below.
                waveform = np.array(waveform * net.batch_size).reshape(
                    net.batch_size, -1)
                waveform = np.concatenate(
                    [
                        waveform,
                        np.random.randint(quantization_channels,
                                          size=net.batch_size).reshape(
                                              net.batch_size, -1)
                    ],
                    axis=-1)  # Before one-hot conversion. Shape: (batch_size, 5117).

        start_time = time.time()
        upsampled_local_condition_data = sess.run(
            upsampled_local_condition_data)
        last_sample_timestamp = datetime.now()
        for step in range(sample_size):  # Loop until the desired output length is reached.

            # Feed only the last generated sample. window: shape (N, 1).
            window = waveform[:, -1:]

            # Run the WaveNet to predict the next sample.

            # Without fast generation, window would be the full history,
            # e.g. [128.0, 128.0, ..., 128.0, 178, 185]; with fast
            # generation it is a single sample.
            prediction = sess.run(
                next_sample,
                feed_dict={
                    samples:
                    window,
                    upsampled_local_condition:
                    upsampled_local_condition_data[:, step, :]
                }
            )  # samples holds mu-law encoded values; they are one-hot encoded internally --> (batch_size, 256).

            if scalar_input:
                sample = prediction  # Sampled from a logistic distribution, so it is already stochastic.
            else:
                # Scale the prediction distribution using the temperature.
                # With config.temperature == 1 this merely renormalizes an
                # already-softmaxed distribution, leaving it unchanged;
                # otherwise the log-probabilities are divided by the
                # temperature and rescaled to sum to 1.
                np.seterr(divide='ignore')
                scaled_prediction = np.log(
                    prediction
                ) / config.temperature  # No change when config.temperature == 1.
                scaled_prediction = (
                    scaled_prediction - np.logaddexp.reduce(
                        scaled_prediction, axis=-1, keepdims=True)
                )  # np.log(np.sum(np.exp(scaled_prediction)))
                scaled_prediction = np.exp(scaled_prediction)
                np.seterr(divide='warn')

                # Prediction distribution at temperature=1.0 should be unchanged after
                # scaling.
                if config.temperature == 1.0:
                    np.testing.assert_allclose(
                        prediction,
                        scaled_prediction,
                        atol=1e-5,
                        err_msg=
                        'Prediction scaling at temperature=1.0 is not working as intended.'
                    )

                # Because sampling is used instead of argmax, the same input can produce different outputs.
                sample = [[
                    np.random.choice(np.arange(quantization_channels), p=p)
                ] for p in scaled_prediction]  # choose one sample per batch

            waveform = np.concatenate([waveform, sample],
                                      axis=-1)  # sample: shape (N, 1)

            # Show progress only once per second.
            current_sample_timestamp = datetime.now()
            time_since_print = current_sample_timestamp - last_sample_timestamp
            if time_since_print.total_seconds() > 1.:
                duration = time.time() - start_time
                print('Sample {:<3d}/{:<3d}, ({:.3f} sec/step)'.format(
                    step + 1, sample_size, duration),
                      end='\r')
                last_sample_timestamp = current_sample_timestamp

        # Introduce a newline to clear the carriage return from the progress.
        print()

        # Save the result as a wav file.
        if hparams.input_type == 'raw':
            out = waveform[:, net.receptive_field:]
        elif hparams.input_type == 'mulaw':
            decode = mu_law_decode(samples,
                                   quantization_channels,
                                   quantization=False)
            out = sess.run(
                decode, feed_dict={samples: waveform[:, net.receptive_field:]})
        else:  # 'mulaw-quantize'
            decode = mu_law_decode(samples,
                                   quantization_channels,
                                   quantization=True)
            out = sess.run(
                decode, feed_dict={samples: waveform[:, net.receptive_field:]})

        # save wav

        for i in range(net.batch_size):
            config.wav_out_path = logdir + '/test-{}.wav'.format(i)
            mel_path = config.wav_out_path.replace(".wav", ".png")

            gen_mel_spectrogram = audio.melspectrogram(out[i], hparams).astype(
                np.float32).T
            audio.save_wav(out[i], config.wav_out_path,
                           hparams.sample_rate)  # save_wav mutates out[i].

            plot.plot_spectrogram(gen_mel_spectrogram,
                                  mel_path,
                                  title='generated mel spectrogram',
                                  target_spectrogram=mel_input[i])
        print('Finished generating.')
Example #22
quantized_oh[0][1000:1020].eval(session=sess)

# let RNN out be exact RNN input (for test)
#
# turn it back to 8 bit signal

# In[17]:

quantized_deoh = _de_one_hot(quantized_oh)
quantized_deoh[0][1000:1050].eval(session=sess)

# from 8 bit signal to real sound

# In[18]:

out = mu_law_decode(quantized_deoh,
                    quantization_channels=M_PARAMS['QUANTISATION_CHANNELS'])
out[0][1000:1050].eval(session=sess)

# evaluate real_sound from tf to numpy

# In[19]:

out_wave = sess.run(out[0])

# write into file

# In[20]:

write_wav(out_wave, M_PARAMS['SAMPLE_RATE'], wav_fname_new)
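The notebook snippet above uses a _de_one_hot helper that is not shown; a minimal sketch, assuming the one-hot channel axis is last:

import tensorflow as tf

def _de_one_hot(one_hot_signal):
    # Collapse a one-hot tensor back to integer quantization levels by
    # taking the argmax over the channel (last) axis.
    return tf.argmax(one_hot_signal, axis=-1)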
Example #23
    print(sess.run(D_real, feed_dict={X: batch_data, Z: sample_Z(1, 100)}))

    print("fake logit")
    print(sess.run(D_fake, feed_dict={Z: sample_Z(1, 100)}))

    print("Equal?")
    print(np.array_equal(prevA, nextA))
    prevA = nextA
    '''

    duration = time.time() - start_time

    if (it % 20 == 0):
        waveform = []
        waveform = np.reshape(
            sess.run(G_sample, feed_dict={Z: sample_Z(1, 100)}), [w1])
        print(waveform)
        name = '5-3-2simplegenerate-' + str(it) + '.wav'
        write_wav(waveform, 22000, name)

    print('Step %d: 1st D loss = %.7f, 10th G loss = %.7f (%.3f sec)' %
          (it, D_loss_curr, G_loss_curr, duration))

samples = tf.placeholder(tf.int32)
decode = mu_law_decode(samples, 256)

waveform = []
waveform = np.reshape(sess.run(G_sample, feed_dict={Z: sample_Z(1, 100)}),
                      [w1])

write_wav(waveform, 22000, '5-3-2simplegenerate.wav')
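The fragment above uses a sample_Z noise helper that is not shown; a plausible sketch of the usual uniform GAN prior (an assumption, the actual helper may differ):

import numpy as np

def sample_Z(m, n):
    # Uniform noise in [-1, 1] as the generator's input prior.
    return np.random.uniform(-1., 1., size=[m, n])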
Example #24
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)  # Placeholder tensor, initialized later.
    input_ID = tf.placeholder(tf.int32)
    startime_fastgeration = time.clock()
    if args.fast_generation:
        print("#########using_fast_generation")
        next_sample = net.predict_proba_incremental(samples, input_ID)

        #print next_sample

    else:
        next_sample = net.predict_proba(samples, input_ID)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    endtime_fastgeneration = time.clock()
    # print('fast_generation time {}'.format(endtime_fastgeneration - startime_fastgeration))
    time_of_fast = endtime_fastgeneration - startime_fastgeration  # 1

    start_vari_saver = time.clock()  # variable saver
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)
    end_vari_saver = time.clock()
    print('variables_to_restore{}'.format(end_vari_saver - start_vari_saver))

    starttime_restore = time.clock()  # Restore from the checkpoint.
    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)
    endtime_restore = time.clock()
    # print('restore model time {}'.format(endtime_restore - starttime_restore))
    time_of_restore = endtime_restore - starttime_restore  # 2
    print('%%%%%%%%%%%%{}'.format(time_of_restore))
    #return 0
    decode = mu_law_decode(
        samples, wavenet_params['quantization_channels'])  #namescope(encode)

    quantization_channels = wavenet_params['quantization_channels']
    time_of_seed = 0
    if args.wav_seed:
        start_using_create_seed = time.clock()
        print('#######using_create_seed')
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
        end_using_create_seed = time.clock()
        time_of_seed = end_using_create_seed - start_using_create_seed
        #print ('using create_seed time{}'.format(end_using_create_seed - start_using_create_seed))
    else:
        print('#######not_using_create_seed')
        waveform = np.random.randint(quantization_channels,
                                     size=(1, )).tolist()
    predict_of_fast_seed = 0
    if args.fast_generation and args.wav_seed:
        starttime_fast_and_seed = time.clock()
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)  # push_ops is a list.

        print('Priming generation...')
        for i, x in enumerate(waveform[:-1]):
            if i % 1600 == 0:
                print('Priming sample {}'.format(i), end='\r')
                sys.stdout.flush()
            sess.run(outputs,
                     feed_dict={
                         samples: x,
                         input_ID: args.ID_generation
                     })
        print('Done.')
        endtime_fast_seed = time.clock()
        #print('fast_generation and create_seed time{}'.format(endtime_fast_seed - starttime_fast_and_seed))
        predict_of_fast_seed = predict_of_fast_seed + (endtime_fast_seed -
                                                       starttime_fast_and_seed)

    #return 0
    last_sample_timestamp = datetime.now()
    predict = 0
    index_begin_generate = 0 if not args.fast_generation else len(waveform)
    startime_total_predict_sample = time.clock()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > args.window:  # Is the waveform longer than the window?
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        starttime_run_net_predict = time.clock()
        prediction = sess.run(outputs,
                              feed_dict={
                                  samples: window,
                                  input_ID: args.ID_generation
                              })[0]
        endtime_run_net_predict = time.clock()
        print('run net to predict samples per step{}'.format(
            endtime_run_net_predict - starttime_run_net_predict))
        predict = predict + (endtime_run_net_predict -
                             starttime_run_net_predict)

        # Scale the prediction distribution using the temperature.
        np.seterr(divide='ignore')  # Suppress divide-by-zero warnings from log(0).
        scaled_prediction = np.log(prediction) / args.temperature
        scaled_prediction = scaled_prediction - np.logaddexp.reduce(
            scaled_prediction)
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')

        # Prediction distribution at temperature=1.0 should be unchanged after scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg=
                'Prediction scaling at temperature=1.0 is not working as intended.'
            )

        sample = np.random.choice(  # Sample a quantization level according to the distribution.
            np.arange(quantization_channels),
            p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:<3d}/{:<3d}'.format(step + 1, args.samples),
                  end='\r')
            sys.stdout.flush()
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            print('Partial write: saving the result so far.')
            out = sess.run(decode,
                           feed_dict={samples: waveform})  # This tensor requires a feed.
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
    endtime_total_predicttime = time.clock()
    print('total predict time: {}'.format(endtime_total_predicttime -
                                          starttime_total_predict_sample))
    print('run net predict time: {}'.format(predict))

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    '''
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform[index_begin_generate:], [-1, 1]), input_ID: args.ID_generation})
    writer.add_summary(summary_out)
    '''

    # Save the result as a wav file.
    if args.wav_out_path:
        start_save_wav_time = time.clock()
        out = sess.run(decode,
                       feed_dict={samples: waveform[index_begin_generate:]})

        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
        end_save_wave_time = time.clock()
        print('write wave time{}'.format(end_save_wave_time -
                                         start_save_wav_time))

    print('time_of_fast_initops: {}'.format(time_of_fast))
    print('time_of_restore: {}'.format(time_of_restore))
    print('time_of_fast_and_seed: {}'.format(predict_of_fast_seed))
    print('time_of_seed: {}'.format(time_of_seed))
    print('Finished generating. The result can be viewed in TensorBoard.')
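
The temperature trick in the loop above is independent of TensorFlow and easy to sanity-check in isolation. A minimal NumPy sketch (the probabilities below are made up for illustration):

import numpy as np

def scale_by_temperature(probs, temperature):
    # Divide the log-probabilities by the temperature, then renormalize
    # via logaddexp so the result sums to 1 without overflow.
    with np.errstate(divide='ignore'):
        scaled = np.log(probs) / temperature
    scaled = scaled - np.logaddexp.reduce(scaled)
    return np.exp(scaled)

probs = np.array([0.1, 0.2, 0.3, 0.4])
print(scale_by_temperature(probs, 1.0))  # unchanged: [0.1 0.2 0.3 0.4]
print(scale_by_temperature(probs, 0.5))  # sharper: mass moves to the mode
print(scale_by_temperature(probs, 2.0))  # flatter: closer to uniform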
Example #34
def main():
    args = get_arguments()

    if args.isDebug in ["True", "true", "t", "1"]:
        isDebug = True
        print("Running train.py for debugging...")
    elif args.isDebug in ["False", "false", "f", "0"]:
        isDebug = False
        print("Running train.py for actual training...")
    else:
        print("--isDebug has to be True or False")
        exit()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']
    print(restore_from)

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        lc_enabled = args.lc_channels is not None
        # reader = AudioReader(
        #     args.data_dir,
        #     coord,
        #     sample_rate=wavenet_params['sample_rate'],
        #     gc_enabled=gc_enabled,
        #     lc_enabled=lc_enabled,
        #     receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
        #                                                            wavenet_params["dilations"],
        #                                                            wavenet_params["scalar_input"],
        #                                                            wavenet_params["initial_filter_width"]),
        #     sample_size=args.sample_size,
        #     silence_threshold=silence_threshold)
        # audio_batch = reader.dequeue(args.batch_size)
        # if gc_enabled:
        #     gc_id_batch = reader.dequeue_gc(args.batch_size)
        # else:
        #     gc_id_batch = None
        #
        # if lc_enabled:
        #     lc_id_batch = reader.dequeue_lc(args.batch_size)
        # else:
        #     lc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        # global_condition_cardinality=reader.gc_category_cardinality,
        local_condition_channels=args.lc_channels)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    audio_placeholder_training = tf.placeholder(dtype=tf.float32, shape=None)
    gc_placeholder_training = tf.placeholder(
        dtype=tf.int32) if gc_enabled else None
    lc_placeholder_training = tf.placeholder(
        dtype=tf.float32, shape=(net.batch_size, None,
                                 512)) if lc_enabled else None
    loss = net.loss(input_batch=audio_placeholder_training,
                    global_condition_batch=gc_placeholder_training,
                    local_condition_batch=lc_placeholder_training,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)
    """variables for validation"""
    net.batch_size = 1
    audio_placeholder_validation = tf.placeholder(dtype=tf.float32, shape=None)
    gc_placeholder_validation = tf.placeholder(
        dtype=tf.int32) if gc_enabled else None
    lc_placeholder_validation = tf.placeholder(
        dtype=tf.float32, shape=(net.batch_size, None,
                                 512)) if lc_enabled else None
    validation = net.validation(
        input_batch=audio_placeholder_validation,
        global_condition_batch=gc_placeholder_validation,
        local_condition_batch=lc_placeholder_validation)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

    # if args.restore_model is not None:
    #     variables_to_restore = {
    #         var.name[:-2]: var for var in tf.global_variables()
    #         if not ('state_buffer' in var.name or 'pointer' in var.name)}
    #     saver = tf.train.Saver(variables_to_restore)
    #
    #     print('Restoring model from {}'.format(args.checkpoint))
    #     saver.restore(sess, args.checkpoint)
    #
    #     print("Restoring model done")
    # else:
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # reader.start_threads(sess)

    training_log_file = open(DATA_DIRECTORY + "training_log.txt", "w")
    validation_log_file = open(DATA_DIRECTORY + "validation_log.txt", "w")

    last_saved_step = saved_global_step

    with open('pickle/audio_lists_training_x_6.pkl', 'rb') as f1:
        audio_lists_training = pickle.load(f1)

    with open('pickle/img_vec_lists_training_x_6.pkl', 'rb') as f2:
        img_vec_lists_training = pickle.load(f2)

    with open('pickle/audio_lists_validation.pkl', 'rb') as f3:
        audio_lists_validation = pickle.load(f3)

    with open('pickle/img_vec_lists_validation.pkl', 'rb') as f4:
        img_vec_lists_validation = pickle.load(f4)

    try:
        for epoch in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            """ training """
            num_video_frames = []
            # training_data = audio_reader.load_generic_audio_video_without_downloading(DATA_DIRECTORY, SAMPLE_RATE,
            #                                                                             reader.i2v, "training", num_video_frames)
            training_data_order = np.arange(6)
            net.batch_size = 3
            random.shuffle(training_data_order)
            print(training_data_order)

            for o in range(2):

                samples_per_frame = int(16000 / 25)  # 640 audio samples per video frame
                video_matrix = np.zeros(
                    (net.batch_size, net.receptive_field + samples_per_frame,
                     512))
                frame_index = 1
                for index in range(len(img_vec_lists_training[0])):
                    # Gather the three videos selected for this mini-batch and
                    # repeat each frame vector across its audio samples.
                    audio = audio_lists_training[training_data_order[o * 3]][index]
                    audio = audio.reshape(1, -1)
                    img_vec = img_vec_lists_training[training_data_order[o * 3]][index]
                    img_vecs = np.repeat(img_vec, samples_per_frame, axis=1)
                    # audio = np.pad(audio, [[net.receptive_field, 0], [0, 0]], 'constant')
                    audio1 = audio_lists_training[training_data_order[o * 3 + 1]][index]
                    audio1 = audio1.reshape(1, -1)
                    img_vec1 = img_vec_lists_training[training_data_order[o * 3 + 1]][index]
                    img_vecs1 = np.repeat(img_vec1, samples_per_frame, axis=1)
                    # audio1 = np.pad(audio1, [[net.receptive_field, 0], [0, 0]], 'constant')
                    audio2 = audio_lists_training[training_data_order[o * 3 + 2]][index]
                    audio2 = audio2.reshape(1, -1)
                    img_vec2 = img_vec_lists_training[training_data_order[o * 3 + 2]][index]
                    img_vecs2 = np.repeat(img_vec2, samples_per_frame, axis=1)
                    # audio2 = np.pad(audio2, [[net.receptive_field, 0], [0, 0]], 'constant')
                    audio = np.vstack((audio, audio1, audio2))
                    img_vecs = np.vstack((img_vecs, img_vecs1, img_vecs2))

                    # Slide the conditioning buffer left by one frame and write
                    # the new frame's repeated vectors into the freed slots.
                    video_matrix[:, :-samples_per_frame, :] = \
                        video_matrix[:, samples_per_frame:, :]
                    video_matrix[:, -samples_per_frame:, :] = img_vecs
                    # print(audio.shape)
                    # print(video_matrix.shape)

                    summary, loss_value, _ = sess.run(
                        [summaries, loss, optim],
                        feed_dict={
                            audio_placeholder_training: audio,
                            lc_placeholder_training: video_matrix
                        })

                    duration = time.time() - start_time
                    if frame_index % 10 == 0:
                        print(
                            'epoch {:d}, frame_index {:d}/{:d} - loss = {:.3f}, ({:.3f} sec/epoch)'
                            .format(epoch, frame_index,
                                    len(img_vec_lists_training[0]), loss_value,
                                    duration))
                        training_log_file.write(
                            'epoch {:d}, frame_index {:d}/{:d} - loss = {:.3f}, ({:.3f} sec/epoch)\n'
                            .format(epoch, frame_index,
                                    len(img_vec_lists_training[0]), loss_value,
                                    duration))
                    frame_index += 1

                    if frame_index == 2 and isDebug:
                        break
            """validation and generation"""
            if epoch % args.generate_every == 0:
                print("calculating validation score...")
                num_video_frames = []
                # validation_data = audio_reader.load_generic_audio_video_without_downloading(DATA_DIRECTORY, SAMPLE_RATE,
                #                                                                             reader.i2v, "validation", num_video_frames)
                validation_score = 0
                # pad = np.zeros((512, net.receptive_field))
                frame_index = 1
                waveform = []
                # prediction = None

                net.batch_size = 1
                samples_per_frame = int(16000 / 25)
                video_matrix = np.zeros(
                    (net.batch_size, net.receptive_field + samples_per_frame,
                     512))

                for index in range(len(img_vec_lists_validation)):
                    audio = audio_lists_validation[index]
                    img_vec = img_vec_lists_validation[index]
                    video_matrix[:, :-samples_per_frame, :] = \
                        video_matrix[:, samples_per_frame:, :]
                    video_matrix[:, -samples_per_frame:, :] = img_vec

                    # return the error and prediction at the same time
                    validation_value, prediction = sess.run(
                        validation,
                        feed_dict={
                            audio_placeholder_validation: audio,
                            lc_placeholder_validation: video_matrix
                        })

                    validation_score += validation_value

                    if prediction is not None:
                        for i in range(prediction.shape[0]):
                            # generate a sample based on the prediction
                            sample = prediction2sample(
                                prediction[i, :], 1.0,
                                net.quantization_channels)
                            waveform.append(sample)

                    if frame_index % 10 == 0:
                        # Show the progress.
                        print('validation {:d}/{:d}'.format(
                            frame_index, len(img_vec_lists_validation)))

                    frame_index += 1

                    if frame_index == 10 and isDebug:
                        break

                print('epoch {:d} - validation = {:.3f}'.format(
                    epoch, sum(validation_score)))
                validation_log_file.write(
                    'epoch {:d} - validation = {:.3f}\n'.format(
                        epoch, sum(validation_score)))

                if len(waveform) > 0:
                    decode = mu_law_decode(
                        audio_placeholder_validation,
                        wavenet_params['quantization_channels'])
                    out = sess.run(
                        decode,
                        feed_dict={audio_placeholder_validation: waveform})
                    write_wav(out, wavenet_params['sample_rate'],
                              DATA_DIRECTORY + "epoch_" + str(epoch) + ".wav")

            if epoch % args.checkpoint_every == 0:
                save(saver, sess, logdir, epoch)
                last_saved_step = epoch

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        validation_log_file.close()
        training_log_file.close()
        if epoch > last_saved_step:
            save(saver, sess, logdir, epoch)
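
The video_matrix bookkeeping in the training and validation loops above amounts to a fixed-size sliding buffer of upsampled conditioning vectors: each video frame covers int(16000 / 25) = 640 audio samples and is shifted in from the right. A reduced sketch (the receptive field and channel count are made up for readability):

import numpy as np

samples_per_frame = 16000 // 25  # 640 audio samples per 25 fps video frame
receptive_field = 1024           # made-up value for illustration
channels = 4                     # reduced from 512 for readability

def push_frame(buffer, frame_vec):
    # Shift the buffer left by one frame's worth of samples and write the
    # new frame vector (broadcast over time) into the freed slots.
    buffer[:, :-samples_per_frame, :] = buffer[:, samples_per_frame:, :]
    buffer[:, -samples_per_frame:, :] = frame_vec
    return buffer

buffer = np.zeros((1, receptive_field + samples_per_frame, channels))
frame_vec = np.ones((1, 1, channels))  # one conditioning vector per frame
buffer = push_frame(buffer, frame_vec)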
Example #35
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)

    next_sample = net.predict_proba_all(samples)

    # if args.fast_generation:
    #     sess.run(tf.initialize_all_variables())
    #     sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    seed = create_seed(args.wav_seed,
                       wavenet_params['sample_rate'],
                       quantization_channels)
    input_waveform = sess.run(seed).tolist()
    waveform = []
    print('seed waveform length: {}'.format(len(input_waveform)))
    print('samples {}'.format(args.samples))
    last_sample_timestamp = datetime.now()
    for slide_start in range(0, len(input_waveform), args.step_length):
        if slide_start + args.samples >= len(input_waveform):
            break
        input_audio_window = input_waveform[slide_start:slide_start + args.samples]

        outputs = [next_sample]
        # Run the WaveNet to predict the next sample.
        all_prediction = sess.run(outputs, feed_dict={samples: input_audio_window})[0]
        all_prediction = np.asarray(all_prediction)
        output_waveform = get_all_output_from_predictions(all_prediction, net.quantization_channels)

        if len(waveform) > 0:
            overlap_waveform = waveform[slide_start:len(waveform)]
            output_overlap_waveform = output_waveform[:-args.step_length]
            print(len(overlap_waveform), len(output_overlap_waveform), len(waveform))
            result = np.divide(np.add(output_overlap_waveform, overlap_waveform), 2.0)
            waveform[slide_start:len(waveform)] = result
            waveform.extend(output_waveform[-args.step_length:])

        else:
            waveform = output_waveform

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(slide_start + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                        (slide_start + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
            print("current step is {}".format(slide_start))

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        print("The error between expected and actual is {}".format(mse_with_output(out, OUTPUT_FILE, wavenet_params['sample_rate'])))
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
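
The overlap handling above averages each newly predicted window with the samples it overlaps from the running waveform and appends only the non-overlapping tail. A minimal sketch of that blending, assuming the shapes line up as in the snippet (the values are made up):

import numpy as np

def blend(waveform, output_window, slide_start, step_length):
    # Average the overlapping region, then extend by the new tail.
    if not waveform:
        return list(output_window)
    overlap_old = waveform[slide_start:]
    overlap_new = output_window[:-step_length]
    blended = (np.asarray(overlap_old) + np.asarray(overlap_new)) / 2.0
    waveform[slide_start:] = blended.tolist()
    waveform.extend(output_window[-step_length:])
    return waveform

wave = blend([], [1.0, 2.0, 3.0, 4.0], slide_start=0, step_length=2)
wave = blend(wave, [3.0, 4.0, 5.0, 6.0], slide_start=2, step_length=2)
print(wave)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]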
Example #36
def main():
    config = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(config.logdir, 'generate', started_datestring)

    if not os.path.exists(logdir):
        os.makedirs(logdir)

    load_hparams(hparams, config.checkpoint_dir)

    with tf.device('/cpu:0'):

        sess = tf.Session()
        scalar_input = hparams.scalar_input
        net = WaveNetModel(
            batch_size=BATCH_SIZE,
            dilations=hparams.dilations,
            filter_width=hparams.filter_width,
            residual_channels=hparams.residual_channels,
            dilation_channels=hparams.dilation_channels,
            quantization_channels=hparams.quantization_channels,
            out_channels=hparams.out_channels,
            skip_channels=hparams.skip_channels,
            use_biases=hparams.use_biases,
            scalar_input=scalar_input,
            initial_filter_width=hparams.initial_filter_width,
            global_condition_channels=hparams.gc_channels,
            global_condition_cardinality=config.gc_cardinality,
            train_mode=False
        )  # During training the AudioReader infers global_condition_cardinality;
        # at generation time it must be supplied explicitly.

        if scalar_input:
            samples = tf.placeholder(tf.float32, shape=[net.batch_size, None])
        else:
            samples = tf.placeholder(
                tf.int32, shape=[net.batch_size, None]
            )  # samples: mu-law encoded values, before one-hot conversion;
            # shape (batch_size, length).

        next_sample = net.predict_proba_incremental(
            samples, [config.gc_id] * net.batch_size
        )  # Fast WaveNet generation algorithm (arXiv:1611.09482).

        var_list = [
            var for var in tf.global_variables() if 'queue' not in var.name
        ]
        saver = tf.train.Saver(var_list)
        print('Restoring model from {}'.format(config.checkpoint_dir))

        load(saver, sess, config.checkpoint_dir)

        sess.run(net.queue_initializer)  # Without this, the queues would still
        # hold the values restored from the checkpoint.

        quantization_channels = hparams.quantization_channels
        if config.wav_seed:
            # If wav_seed is shorter than the receptive field, create_seed
            # returns it as-is rather than padding, so a too-short seed
            # results in an error.
            seed = create_seed(config.wav_seed, hparams.sample_rate,
                               quantization_channels, net.receptive_field,
                               scalar_input)  # mu-law encoded
            if scalar_input:
                waveform = seed.tolist()
            else:
                waveform = sess.run(seed).tolist()  # e.g. [116, 114, 120, 121, 127, ...]

            print('Priming generation...')
            # The very last seed sample is fed in the first pass of the
            # generation loop below.
            for i, x in enumerate(waveform[-net.receptive_field:-1]):
                if i % 100 == 0:
                    print('Priming sample {}'.format(i))
                sess.run(next_sample,
                         feed_dict={
                             samples:
                             np.array([x] * net.batch_size).reshape(
                                 net.batch_size, 1)
                         })
            print('Done.')
            waveform = np.array([waveform[-net.receptive_field:]] *
                                net.batch_size)
        else:
            # Silence with a single random sample at the end.
            if scalar_input:
                waveform = [0.0] * (net.receptive_field - 1)
                waveform = np.array(waveform * net.batch_size).reshape(
                    net.batch_size, -1)
                waveform = np.concatenate(
                    [
                        waveform, 2 * np.random.rand(net.batch_size).reshape(
                            net.batch_size, -1) - 1
                    ],
                    axis=-1)  # append a random number in [-1, 1) at the end
            else:
                waveform = [quantization_channels / 2] * (
                    net.receptive_field - 1
                )  # one short of the receptive field; a random sample is appended below
                waveform = np.array(waveform * net.batch_size).reshape(
                    net.batch_size, -1)
                waveform = np.concatenate(
                    [
                        waveform,
                        np.random.randint(quantization_channels,
                                          size=net.batch_size).reshape(
                                              net.batch_size, -1)
                    ],
                    axis=-1)  # before one-hot conversion; shape (batch_size, 5117)

        last_sample_timestamp = datetime.now()
        for step in range(config.samples):  # loop until the requested length

            window = waveform[:, -1:]  # feed only the most recent sample

            # Run the WaveNet to predict the next sample.

            # Without fast generation, window would hold the full context,
            # e.g. [128.0, 128.0, ..., 128.0, 178, 185]; with fast generation
            # it is a single sample per batch row.
            prediction = sess.run(
                next_sample, feed_dict={samples: window}
            )  # samples holds mu-law encoded values; the graph converts them
            # to one-hot internally --> (batch_size, 256)

            if scalar_input:
                sample = prediction
            else:
                # Scale prediction distribution using temperature.
                # At temperature == 1 this merely renormalizes an already
                # softmaxed distribution, so the values are unchanged; at any
                # other temperature the log-probabilities are divided by the
                # temperature and rescaled to sum to 1.
                np.seterr(divide='ignore')
                scaled_prediction = np.log(
                    prediction) / config.temperature  # unchanged at temperature == 1
                scaled_prediction = (
                    scaled_prediction - np.logaddexp.reduce(
                        scaled_prediction, axis=-1, keepdims=True)
                )  # np.log(np.sum(np.exp(scaled_prediction)))
                scaled_prediction = np.exp(scaled_prediction)
                np.seterr(divide='warn')

                # Prediction distribution at temperature=1.0 should be unchanged after
                # scaling.
                if config.temperature == 1.0:
                    np.testing.assert_allclose(
                        prediction,
                        scaled_prediction,
                        atol=1e-5,
                        err_msg=
                        'Prediction scaling at temperature=1.0 is not working as intended.'
                    )

                sample = [[
                    np.random.choice(np.arange(quantization_channels), p=p)
                ] for p in scaled_prediction]  # choose one sample per batch

            waveform = np.concatenate([waveform, sample], axis=-1)

            # Show progress only once per second.
            current_sample_timestamp = datetime.now()
            time_since_print = current_sample_timestamp - last_sample_timestamp
            if time_since_print.total_seconds() > 1.:
                print('Sample {:3<d}/{:3<d}'.format(step + 1, config.samples),
                      end='\r')
                last_sample_timestamp = current_sample_timestamp

        # Introduce a newline to clear the carriage return from the progress.
        print()

        # Save the result as a wav file.
        if scalar_input:
            out = waveform
        else:
            decode = mu_law_decode(samples, quantization_channels)
            out = sess.run(decode, feed_dict={samples: waveform})
        for i in range(net.batch_size):
            config.wav_out_path = logdir + '/test-{}.wav'.format(i)
            write_wav(out[i], hparams.sample_rate, config.wav_out_path)

        print('Finished generating.')
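
mu_law_decode appears in nearly every snippet here; the companding itself is simple. A plain-NumPy sketch of mu-law encoding to 256 classes and decoding back (a standalone illustration, not the project's TensorFlow implementation):

import numpy as np

QUANTIZATION_CHANNELS = 256

def mu_law_encode(audio, channels=QUANTIZATION_CHANNELS):
    # audio in [-1, 1] -> integer class ids in [0, channels - 1]
    mu = channels - 1
    magnitude = np.log1p(mu * np.abs(audio)) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    return ((signal + 1) / 2 * mu + 0.5).astype(np.int32)

def mu_law_decode(ids, channels=QUANTIZATION_CHANNELS):
    # integer class ids -> reconstructed amplitudes in [-1, 1]
    mu = channels - 1
    signal = 2.0 * (ids / mu) - 1.0
    magnitude = (1.0 / mu) * ((1.0 + mu) ** np.abs(signal) - 1.0)
    return np.sign(signal) * magnitude

x = np.linspace(-1.0, 1.0, 5)
print(mu_law_decode(mu_law_encode(x)))  # approximately x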
Example #37
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    gi_sampler = get_generator_input_sampler()

    # White noise generator params
    white_mean = 0
    white_sigma = 1
    white_length = 20234

    white_noise = gi_sampler(white_mean, white_sigma, white_length)

    loss = net.loss(input_batch=tf.convert_to_tensor(white_noise,
                                                     dtype=np.float32),
                    name='generator')

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    '''
    # Silence with a single random sample at the end.
    waveform = [quantization_channels / 2] * (net.receptive_field - 1)
    waveform.append(np.random.randint(quantization_channels))
    '''
    waveform = [0]

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature.
        np.seterr(divide='ignore')
        scaled_prediction = np.log(prediction) / args.temperature
        scaled_prediction = (scaled_prediction -
                             np.logaddexp.reduce(scaled_prediction))
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                'is not working as intended.')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

    # Introduce a newline to clear the carriage return from the progress.
    print()
    del waveform[0]
    print(waveform)
    print()

    print('Finished generating. The result can be viewed in TensorBoard.')
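
The two branches at the top of the generation loop differ only in how much context is fed back: fast (incremental) generation feeds a single sample per step because past activations are cached inside the network, while the naive path re-feeds up to a full window of history. The window selection in isolation:

def select_window(waveform, fast_generation, receptive_field):
    # Fast generation: the cached queues already hold the context, so only
    # the newest sample is needed.
    if fast_generation:
        return waveform[-1]
    # Naive generation: re-feed the available history, capped at the
    # receptive field (slicing handles short waveforms automatically).
    return waveform[-receptive_field:]

print(select_window([1, 2, 3], False, 2))  # [2, 3]
print(select_window([1, 2, 3], True, 2))   # 3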
Example #38
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
    else:
        next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[:-(args.window + 1)]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
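
The push_ops run alongside next_sample implement the Fast WaveNet caching scheme (arXiv:1611.09482): each dilated layer keeps a FIFO of past activations so that one new sample per step suffices. A toy sketch of such a queue, independent of this codebase:

from collections import deque

class DilatedQueue:
    # A fixed-length FIFO holding `dilation` past activations.
    def __init__(self, dilation, init_value=0.0):
        self.buffer = deque([init_value] * dilation, maxlen=dilation)

    def push_pop(self, new_activation):
        # Return the activation from `dilation` steps ago and store the
        # current one, mimicking a dequeue + push_op pair.
        old = self.buffer.popleft()
        self.buffer.append(new_activation)
        return old

q = DilatedQueue(dilation=4)
for t in range(8):
    past = q.push_pop(float(t))
    # past is 0.0 for the first four steps, then the value pushed
    # four steps earlier (t - 4).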
Example #39
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
    else:
        next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[:-(args.window + 1)]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
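
tf.train.SummaryWriter, tf.audio_summary and tf.merge_all_summaries above are TensorFlow 0.x names; under TensorFlow 1.x the same block uses the renamed tf.summary API, as the snippet after Example #41 already does. A minimal sketch assuming TensorFlow 1.x (the logdir and audio tensor are stand-ins):

import tensorflow as tf

# 0.x name                       -> 1.x name
# tf.train.SummaryWriter(logdir) -> tf.summary.FileWriter(logdir)
# tf.audio_summary(...)          -> tf.summary.audio(...)
# tf.merge_all_summaries()       -> tf.summary.merge_all()
writer = tf.summary.FileWriter('/tmp/wavenet-generate')  # hypothetical logdir
audio_tensor = tf.zeros([1, 16000])                      # stand-in for `decode`
tf.summary.audio('generated', audio_tensor, sample_rate=16000)
summaries = tf.summary.merge_all()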
Example #40
    def testDecodeNegativeDilation(self):
        channels = 10
        y = [0, 255, 243, 31, 156, 229, 0, 235, 202, 18]

        with self.test_session() as sess:
            # assertRaises needs a callable: pass sess.run and its argument
            # separately so the exception is raised inside assertRaises.
            self.assertRaises(TypeError, sess.run, mu_law_decode(y, channels))
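
Equivalently, the context-manager form of assertRaises reads more naturally and makes it obvious that the exception must be raised inside the managed block (a rephrasing of the same check, not additional coverage):

        with self.test_session() as sess:
            with self.assertRaises(TypeError):
                sess.run(mu_law_decode(y, channels))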
Example #41
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    lc_enabled = args.lc_channels is not None
    lc_channels = args.lc_channels
    lc_duration = args.lc_duration

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        local_condition_channels=args.lc_channels)

    samples = tf.placeholder(tf.int32)
    lc_piece = tf.placeholder(tf.float32, [1, lc_channels])
    local_condition = load_lc(os.path.join(os.getcwd(), args.lc_path))

    args.samples = args.lc_duration * len(local_condition)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id,
                                                    lc_piece)
    else:
        next_sample = net.predict_proba(samples, args.gc_id, lc_piece)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels, net.receptive_field)
        waveform = sess.run(seed).tolist()
    else:
        # Silence with a single random sample at the end.
        waveform = [quantization_channels / 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        if lc_enabled:
            if step % lc_duration == 0:
                # Take the current conditioning row and drop it from the
                # queue so the next lc_duration samples use the next row.
                lc_window = local_condition[:1, :]
                local_condition = local_condition[1:, :]
            prediction = sess.run(outputs,
                                  feed_dict={
                                      samples: window,
                                      lc_piece: lc_window
                                  })[0]
        else:
            # Run the WaveNet to predict the next sample.
            prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature.
        np.seterr(divide='ignore')
        scaled_prediction = np.log(prediction) / args.temperature
        scaled_prediction = (scaled_prediction -
                             np.logaddexp.reduce(scaled_prediction))
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                'is not working as intended.')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.summary.FileWriter(logdir)
    tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.summary.merge_all()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
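
Each local-conditioning row above is held for lc_duration audio samples before the next row is consumed. The consumption pattern in isolation (the row values are made up):

import numpy as np

def lc_windows(local_condition, lc_duration, num_steps):
    # Yield, for every audio step, the conditioning row active at that step.
    lc_window = None
    for step in range(num_steps):
        if step % lc_duration == 0:
            lc_window = local_condition[:1, :]        # current row
            local_condition = local_condition[1:, :]  # consume it
        yield lc_window

lc = np.arange(6).reshape(3, 2)  # three rows of 2-channel conditioning
for step, window in enumerate(lc_windows(lc, lc_duration=2, num_steps=6)):
    print(step, window.ravel())  # rows [0 1], [2 3], [4 5], two steps each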
Example #42
def eval_step(sess, logdir, step, waveform, upsampled_local_condition_data,
              speaker_id_data, mel_input_data, samples, speaker_id,
              upsampled_local_condition, next_sample, temperature=1.0):
    waveform = waveform[:, :1]

    sample_size = upsampled_local_condition_data.shape[1]
    last_sample_timestamp = datetime.now()
    start_time = time.time()
    for step2 in range(sample_size):  # loop until sample_size samples are generated
        window = waveform[:, -1:]  # feed only the most recent sample; shape (N, 1)

        prediction = sess.run(next_sample,
                              feed_dict={
                                  samples: window,
                                  upsampled_local_condition:
                                  upsampled_local_condition_data[:, step2, :],
                                  speaker_id: speaker_id_data
                              })


        if hparams.scalar_input:
            sample = prediction  # sampled from a logistic distribution, so already stochastic
        else:
            # Scale prediction distribution using temperature.
            # At temperature == 1 this merely renormalizes an already softmaxed
            # distribution; at any other temperature the log-probabilities are
            # divided by the temperature and rescaled to sum to 1.
            np.seterr(divide='ignore')
            scaled_prediction = np.log(prediction) / temperature  # unchanged at temperature == 1
            scaled_prediction = (scaled_prediction - np.logaddexp.reduce(
                scaled_prediction, axis=-1, keepdims=True))  # np.log(np.sum(np.exp(scaled_prediction)))
            scaled_prediction = np.exp(scaled_prediction)
            np.seterr(divide='warn')
    
            # Prediction distribution at temperature=1.0 should be unchanged after
            # scaling.
            if temperature == 1.0:
                np.testing.assert_allclose(
                    prediction, scaled_prediction, atol=1e-5,
                    err_msg='Prediction scaling at temperature=1.0 '
                    'is not working as intended.')

            # Sampling rather than taking the argmax, so the same input can
            # produce different outputs.
            sample = [[np.random.choice(np.arange(hparams.quantization_channels), p=p)]
                      for p in scaled_prediction]  # one sample per batch row

        waveform = np.concatenate([waveform, sample], axis=-1)  # waveform grows by one sample per step

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            duration = time.time() - start_time
            print('Sample {:3<d}/{:3<d}, ({:.3f} sec/step)'.format(step2 + 1, sample_size, duration), end='\r')
            last_sample_timestamp = current_sample_timestamp
    
    print('\n')
    # Save the result as a wav file.    
    if hparams.input_type == 'raw':
        out = waveform[:,1:]
    elif hparams.input_type == 'mulaw':
        decode = mu_law_decode(samples, hparams.quantization_channels,quantization=False)
        out = sess.run(decode, feed_dict={samples: waveform[:,1:]})
    else:  # 'mulaw-quantize'
        decode = mu_law_decode(samples, hparams.quantization_channels,quantization=True)
        out = sess.run(decode, feed_dict={samples: waveform[:,1:]})          
        
        
    # save wav
    
    for i in range(1):
        wav_out_path= logdir + '/test-{}-{}.wav'.format(step,i)
        mel_path =  wav_out_path.replace(".wav", ".png")
        
        gen_mel_spectrogram = audio.melspectrogram(out[i], hparams).astype(np.float32).T
        audio.save_wav(out[i], wav_out_path, hparams.sample_rate)  # save_wav modifies out[i] in place
        
        plot.plot_spectrogram(gen_mel_spectrogram, mel_path,
                              title='generated mel spectrogram{}'.format(step),
                              target_spectrogram=mel_input_data[i])
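
The sampling comprehension inside eval_step draws one class id per batch row of scaled_prediction; the same operation in plain NumPy, with a toy distribution:

import numpy as np

def sample_per_batch(probs):
    # probs: (batch, channels); draw one class id per batch row so the
    # result can be concatenated onto a (batch, time) waveform.
    return np.array([[np.random.choice(probs.shape[1], p=p)] for p in probs])

probs = np.full((2, 4), 0.25)         # uniform toy distribution over 4 classes
print(sample_per_batch(probs).shape)  # (2, 1)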