class TestGeneration(tf.test.TestCase):
    '''Smoke tests for the two prediction entry points of WaveNetModel.

    Every test builds its ops on the small network created in setUp and
    feeds it integer samples drawn from a NumPy generator seeded with 0.
    '''

    def setUp(self):
        '''Create the network instance used by each test.'''
        model_kwargs = dict(
            batch_size=1,
            dilations=[2 ** exponent for exponent in range(9)],  # 1 .. 256
            filter_width=2,
            residual_channels=16,
            dilation_channels=16,
            quantization_channels=128,
            skip_channels=32)
        self.net = WaveNetModel(**model_kwargs)

    def testGenerateSimple(self):
        '''Run predict_proba on 1000 random samples and sanity-check
        the shape and the value range of what comes back.'''
        sample_input = tf.placeholder(tf.int32)
        np.random.seed(0)
        random_samples = np.random.randint(128, size=1000)
        proba_op = self.net.predict_proba(sample_input)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            result = sess.run(proba_op,
                              feed_dict={sample_input: random_samples})

        self.assertAllEqual(result.shape, [128])
        within_bounds = np.logical_and(result >= 0, result <= (128 - 1))
        self.assertTrue(np.all(within_bounds))

    def testGenerateFast(self):
        '''Run predict_proba_incremental on one random sample and
        sanity-check the shape and the value range of what comes back.'''
        sample_input = tf.placeholder(tf.int32)
        np.random.seed(0)
        random_sample = np.random.randint(128)
        proba_op = self.net.predict_proba_incremental(sample_input)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            # The incremental path additionally needs the model's init_ops.
            sess.run(self.net.init_ops)
            result = sess.run(proba_op,
                              feed_dict={sample_input: random_sample})

        self.assertAllEqual(result.shape, [128])
        within_bounds = np.logical_and(result >= 0, result <= (128 - 1))
        self.assertTrue(np.all(within_bounds))

    def testCompareSimpleFast(self):
        '''Check that predict_proba and predict_proba_incremental
        produce (numerically close) identical output for the same input.'''
        sample_input = tf.placeholder(tf.int32)
        np.random.seed(0)
        random_samples = np.random.randint(128, size=1)
        simple_op = self.net.predict_proba(sample_input)
        fast_op = self.net.predict_proba_incremental(sample_input)
        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            sess.run(self.net.init_ops)
            # Feed every sample but the last to the incremental generator,
            # running the push ops together with the prediction each time.
            # (With size=1 above this slice is empty.)
            priming_samples = random_samples[:-1]
            for value in priming_samples:
                sess.run([fast_op, self.net.push_ops],
                         feed_dict={sample_input: value})

            # The final sample is evaluated without running the push ops.
            fast_result = sess.run(
                fast_op, feed_dict={sample_input: random_samples[-1]})
            # The simple generator receives the whole sequence at once.
            simple_result = sess.run(
                simple_op, feed_dict={sample_input: random_samples})
            self.assertAllClose(simple_result, fast_result)
Beispiel #2
0
def main():
    '''Generate audio with a restored WaveNet and print per-stage timings.

    Steps, in order:
      1. Parse the command line (``get_arguments``) and load the JSON file
         named by ``args.wavenet_params``.
      2. Build the model and the prediction op: ``predict_proba_incremental``
         when ``args.fast_generation`` is set, ``predict_proba`` otherwise.
      3. Restore every variable whose name contains neither ``state_buffer``
         nor ``pointer`` from ``args.checkpoint``.
      4. Start from the seed built from ``args.wav_seed``, or from a single
         random sample when no seed is given.  In fast mode with a seed, all
         seed samples except the last are fed through the network first.
      5. Draw ``args.samples`` new samples, one ``sess.run`` per sample,
         after rescaling the predicted distribution by ``args.temperature``.
      6. Decode and write the result to ``args.wav_out_path`` (if set), both
         every ``args.save_every`` steps and once at the end.

    All timings use ``time.clock()`` and are printed to stdout.

    Raises:
        AssertionError: if ``args.temperature == 1.0`` and the rescaled
            distribution differs from the raw prediction (atol 1e-5).
    '''
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    # NOTE: logdir is only referenced by the disabled audio-summary code
    # near the end of this function.
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    # Placeholders that are fed on every sess.run call below.
    samples = tf.placeholder(tf.int32)
    input_ID = tf.placeholder(tf.int32)

    # --- Timing 1: building the prediction op (+ init ops in fast mode) ---
    startime_fastgeration = time.clock()
    if args.fast_generation:
        print("#########using_fast_generation")
        next_sample = net.predict_proba_incremental(samples, input_ID)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples, input_ID)
    endtime_fastgernation = time.clock()
    time_of_fast = endtime_fastgernation - startime_fastgeration

    # --- Timing: building the saver ---
    start_vari_saver = time.clock()
    # var.name[:-2] drops the last two characters of each variable name to
    # obtain the key used in the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)
    end_vari_saver = time.clock()
    print('variables_to_restore{}'.format(end_vari_saver - start_vari_saver))

    # --- Timing 2: restoring from the checkpoint ---
    starttime_restore = time.clock()
    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)
    endtime_restore = time.clock()
    time_of_restore = endtime_restore - starttime_restore
    print('%%%%%%%%%%%%{}'.format(time_of_restore))

    # Op that turns the fed sample indices back into audio for write_wav.
    decode = mu_law_decode(
        samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    time_of_seed = 0
    if args.wav_seed:
        start_using_create_seed = time.clock()
        print('#######using_create_seed')
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
        end_using_create_seed = time.clock()
        time_of_seed = end_using_create_seed - start_using_create_seed
    else:
        print('#######not_using_create_seed')
        waveform = np.random.randint(quantization_channels,
                                     size=(1, )).tolist()

    # The ops fetched per step never change, so build the list once.
    # In fast mode the push ops (a list) are run together with the prediction.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    predict_of_fast_seed = 0
    if args.fast_generation and args.wav_seed:
        starttime_fast_and_seed = time.clock()
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        for i, x in enumerate(waveform[:-1]):
            if i % 1600 == 0:
                print('Priming sample {}'.format(i), end='\r')
                sys.stdout.flush()
            sess.run(outputs,
                     feed_dict={
                         samples: x,
                         input_ID: args.ID_generation
                     })
        print('Done.')
        endtime_fast_seed = time.clock()
        predict_of_fast_seed = predict_of_fast_seed + (endtime_fast_seed -
                                                       starttime_fast_and_seed)

    last_sample_timestamp = datetime.now()
    predict = 0  # accumulated time spent inside sess.run while generating
    # Index of the first sample written by the final save: in fast mode the
    # samples present before generation starts are left out.
    index_begin_generate = 0 if (False
                                 == args.fast_generation) else len(waveform)
    startime_total_predict_sample = time.clock()
    for step in range(args.samples):
        if args.fast_generation:
            # Incremental mode consumes one sample at a time.
            window = waveform[-1]
        elif len(waveform) > args.window:
            # Waveform longer than the window: feed only the latest part.
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        starttime_run_net_predict = time.clock()
        prediction = sess.run(outputs,
                              feed_dict={
                                  samples: window,
                                  input_ID: args.ID_generation
                              })[0]
        endtime_run_net_predict = time.clock()
        print('run net to predict samples per step{}'.format(
            endtime_run_net_predict - starttime_run_net_predict))
        predict = predict + (endtime_run_net_predict -
                             starttime_run_net_predict)

        # Scale prediction distribution using temperature: divide the log
        # probabilities by the temperature and renormalise.  log(0) is
        # expected here, so divide warnings are silenced for these lines.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = scaled_prediction - np.logaddexp.reduce(
                scaled_prediction)
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg=
                'Prediction scaling at temperature=1.0 is not working as intended.'
            )

        # Draw the next quantization channel according to the distribution.
        sample = np.random.choice(
            np.arange(quantization_channels),
            p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            sys.stdout.flush()
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            print(
                '$$$$$$$$$$If we have partial writing, save the result so far')
            # decode depends on the samples placeholder, so it must be fed.
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
    endtime_total_predicttime = time.clock()
    print('total predic time{}'.format(endtime_total_predicttime -
                                       startime_total_predict_sample))
    print('run net predict time{}'.format(predict))

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary (currently disabled: the code is
    # kept inside a bare string literal and never executed).
    '''
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform[index_begin_generate:], [-1, 1]), input_ID: args.ID_generation})
    writer.add_summary(summary_out)
    '''

    # Save the result as a wav file.
    if args.wav_out_path:
        start_save_wav_time = time.clock()
        out = sess.run(decode,
                       feed_dict={samples: waveform[index_begin_generate:]})

        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
        end_save_wave_time = time.clock()
        print('write wave time{}'.format(end_save_wave_time -
                                         start_save_wav_time))

    # Timing summary.  The restore/seed lines previously lacked the '{}'
    # placeholder, so their values were silently dropped by str.format.
    print('time_of_fast_initops{}'.format(time_of_fast))
    print('time_of_restore{}'.format(time_of_restore))
    print('time_of_fast_and_seed{}'.format(predict_of_fast_seed))
    print('time_of_seed{}'.format(time_of_seed))
    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #3
0
def main():
    '''Generate audio from a restored WaveNet conditioned on per-step labels.

    The first (audio, labels, filename) item produced by
    ``audio_reader.load_generic_audio_label`` supplies both the label
    sequence fed to the network and the number of samples to generate
    (``audio_test.shape[0]``).  The prediction op is
    ``predict_proba_incremental`` when ``args.fast_generation`` is set and
    ``predict_proba`` otherwise.  The generated waveform is written as an
    audio summary under ``args.logdir`` and, if ``args.wav_out_path`` is
    set, as a wav file (also every ``args.save_every`` steps).

    Raises:
        AssertionError: if ``args.temperature == 1.0`` and the rescaled
            distribution differs from the raw prediction (atol 1e-5).
    '''
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    labels = tf.placeholder(tf.float32)

    data_dir = DATA_DIRECTORY
    file_list = FILE_LIST
    label_dir = data_dir + 'binary_label_norm/'
    audio_dir = data_dir + 'wav/'
    label_dim = 425

    iterator = audio_reader.load_generic_audio_label(file_list, audio_dir,
                                                     label_dir, label_dim)
    # next(iterator) works on both Python 2 and 3; iterator.next() is
    # Python-2-only.
    audio_test, labels_test, filename = next(iterator)
    n_samples_read = audio_test.shape[0]
    # Add a leading axis of size 1 in front of the two label axes.
    labels_test = labels_test.reshape(
        (1, labels_test.shape[0], labels_test.shape[1]))

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        label_dim=label_dim,
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        skip_channels=wavenet_params['skip_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        histograms=False)

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, labels)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples, labels)

    # Restore everything except the variables whose names contain
    # 'state_buffer' or 'pointer'; var.name[:-2] drops the last two
    # characters of each variable name to form the checkpoint key.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels,
                                     size=(1, )).tolist()

    # The ops fetched per step never change, so build the list once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        for i, x in enumerate(waveform[-args.window:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            # NOTE(review): labels are indexed from 0 here and again from 0
            # in the generation loop below - confirm this alignment is
            # intended.
            sess.run(outputs,
                     feed_dict={
                         samples: x,
                         labels: labels_test[:, i:i + 1, :]
                     })
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(n_samples_read):
        if args.fast_generation:
            window = waveform[-1]
            labels_window = labels_test[:, step:step + 1, :]
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            # NumPy slicing truncates instead of raising when the end index
            # runs past the label sequence, so near the end this slice can be
            # shorter than the sample window.
            labels_window = labels_test[:, step:step + min(
                len(window), args.window
            ), :]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs,
                              feed_dict={
                                  samples: window,
                                  labels: labels_window
                              })[0]

        # Scale prediction distribution using temperature.  log(0) is
        # expected here, so divide warnings are silenced for these lines.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = scaled_prediction - np.logaddexp.reduce(
                scaled_prediction)
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg=
                'Prediction scaling at temperature=1.0 is not working as intended.'
            )

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.  The total is the number of
        # iterations of this loop (n_samples_read), not args.samples.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, n_samples_read),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #4
0
def main():
    '''Generate a motion sequence with the incremental WaveNet and save it.

    The seed returned by ``create_seed(args.motion_seed, args.window)`` is
    split at ``cut_index``: the rows before it prime the network and start
    the output, the rows from it onward supply the trailing
    ``input_channels - output_channels`` values of every generated frame.
    ``args.samples`` frames are generated; the first ``output_channels``
    columns of the whole sequence are written with ``np.savetxt`` below
    ``args.skeleton_out_path`` when that option is set.

    Raises:
        ValueError: if ``args.motion_seed`` is empty, or if the seed data
            contains NaN.
        IndexError: if the seed has fewer than ``cut_index + args.samples``
            rows (raised while generating).
    '''
    args = get_arguments()
    # Validate up front: the seed path is consumed by create_seed() below,
    # so checking it only after that call would be too late.
    if not args.motion_seed:
        raise ValueError('motion seed not specified!')

    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    #logdir is where logging file is saved. different from where generated mat is saved.
    # NOTE: logdir is computed but not used anywhere else in this function.
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    input_channels = wavenet_params['input_channels']
    output_channels = wavenet_params['output_channels']
    gt, cut_index = create_seed(args.motion_seed, args.window)
    if np.isnan(np.sum(gt)):
        print('nan detected')
        raise ValueError('NAN detected in seed file')
    seed = tf.constant(gt)

    net = WaveNetModel(batch_size=1,
                       dilations=wavenet_params['dilations'],
                       filter_width=wavenet_params['filter_width'],
                       residual_channels=wavenet_params['residual_channels'],
                       dilation_channels=wavenet_params['dilation_channels'],
                       skip_channels=wavenet_params['skip_channels'],
                       use_biases=wavenet_params['use_biases'],
                       input_channels=input_channels,
                       output_channels=output_channels,
                       global_condition_channels=args.gc_channels)

    samples = tf.placeholder(dtype=tf.float32)

    next_sample = net.predict_proba_incremental(samples)
    sess.run(tf.initialize_all_variables())
    sess.run(net.init_ops)
    #TODO: run init_ops only once

    # Restore everything except the variables whose names contain
    # 'state_buffer' or 'pointer'.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # Evaluate the seed constant and convert it to nested Python lists
    # (presumably T rows of input_channels values - TODO confirm against
    # create_seed).
    gt_list = sess.run(seed).tolist()
    # motion: list of generated frames (along with the seed rows).
    motion = gt_list[:cut_index]
    # When using the incremental generation, we need to
    # feed in all priming samples one by one before starting the
    # actual generation.
    # TODO This could be done much more efficiently by passing the waveform
    # to the incremental generator as an optional argument, which would be
    # used to fill the queues initially.
    outputs = [next_sample]
    outputs.extend(net.push_ops)
    #TODO: question: everytime runs next_sample <- predict_proba_incremental(samples), will the q be reinitialized? or just use the queue with elements inserted before?
    print('Priming generation...')
    # NOTE(review): this slice stops at -2, so motion[-2] is never fed while
    # the generation loop starts from motion[-1]; the earlier variant used
    # motion[-net.receptive_field:-1]. Confirm which bound is intended.
    for i, x in enumerate(motion[-net.receptive_field:-2]):
        if i % 10 == 0:
            print('Priming sample {}'.format(i))
        sess.run(outputs,
                 feed_dict={samples: np.reshape(x, (1, input_channels))})
    print('Done.')
    #TODO: check how next_sample <- net.predict_proba_incremental(samples) works. sample is of size 1 x input_channels.
    #TODO: then check if motion[-1] is fed into network twice.
    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        window = motion[-1]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(
            outputs,
            feed_dict={samples: np.reshape(window, (1, input_channels))})[0]

        # Next input frame = predicted output channels followed by the
        # remaining channels taken from the seed row at cut_index + step.
        # The frame is stored as a flat list, like the seed rows already in
        # `motion`, so np.array(motion) below yields a regular 2-D array
        # (appending (1, C) arrays onto flat rows would not).
        predicted_part = np.reshape(prediction, (output_channels, ))
        carried_part = np.reshape(
            gt_list[cut_index + step][output_channels:],
            (input_channels - output_channels, ))
        motion.append(
            np.concatenate((predicted_part, carried_part)).tolist())

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

    print()

    # Save the result as delimited text.
    if args.skeleton_out_path:
        #TODO: save according to Hanbyul rules
        # Output layout:
        #   <skeleton_out_path>/<checkpoint dir name>/<seed dir name>/<seed file name>
        outdir_base = os.path.join(
            args.skeleton_out_path,
            os.path.basename(os.path.dirname(args.checkpoint)))
        if not os.path.exists(outdir_base):
            os.makedirs(outdir_base)
        scene_name = os.path.basename(os.path.dirname(args.motion_seed))
        scene_dir = os.path.join(outdir_base, scene_name)
        if not os.path.exists(scene_dir):
            os.makedirs(scene_dir)
        filedir = os.path.join(scene_dir, os.path.basename(args.motion_seed))

        motion_array = np.array(motion)
        # Only the first output_channels columns are written.
        np.savetxt(filedir, motion_array[:, :output_channels], delimiter=',')
        print(len(motion))
        print('generated filedir:{0}'.format(filedir))
    print('Finished generating. The result can be viewed in Matlab.')
def main():
    '''Generate audio samples from a restored WaveNet checkpoint.

    Parses the command line with ``get_arguments``, loads the model
    parameters from the JSON file ``args.wavenet_params``, restores the
    weights from ``args.checkpoint`` and draws ``args.samples`` new samples
    from the predicted distribution.  The prediction op is
    ``predict_proba_incremental`` when ``args.fast_generation`` is set and
    ``predict_proba`` (fed the last ``args.window`` samples) otherwise.
    The waveform is written as an audio summary under ``args.logdir`` and,
    if ``args.wav_out_path`` is set, as a wav file (also every
    ``args.save_every`` steps).
    '''
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore everything except the variables whose names contain
    # 'state_buffer' or 'pointer'; var.name[:-2] drops the last two
    # characters of each variable name to form the checkpoint key.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    # The ops fetched per step never change, so build the list once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        # Prime with every seed sample except the last one, which is the
        # first input of the generation loop below.  Stopping the priming
        # earlier (at -(args.window + 1)) would leave a gap of args.window
        # samples between the primed state and that first input.
        for i, x in enumerate(waveform[:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # Incremental mode consumes one sample at a time.
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        # Draw the next quantization channel according to the prediction.
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
def main():
    """Generate audio with a trained WaveNet, optionally locally conditioned.

    Parses the command-line arguments, rebuilds the network described by the
    ``wavenet_params`` JSON file, restores the checkpoint and then draws one
    sample per step from the (temperature-scaled) predicted distribution.
    The number of generated samples is always overwritten with
    ``lc_duration * len(local_condition)``, i.e. every row of the loaded
    local-conditioning data is used for ``lc_duration`` consecutive samples.

    The result is written as a TensorBoard audio summary and, when
    ``--wav_out_path`` is given, as a wav file (optionally also every
    ``--save_every`` samples).
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    lc_enabled = args.lc_channels is not None
    lc_channels = args.lc_channels
    lc_duration = args.lc_duration

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        local_condition_channels=args.lc_channels)

    samples = tf.placeholder(tf.int32)
    lc_piece = tf.placeholder(tf.float32, [1, lc_channels])
    # NOTE(review): the conditioning file is loaded (and drives the sample
    # count) even when --lc_channels is not given - confirm this is intended.
    local_condition = load_lc(os.path.join(os.getcwd(), args.lc_path))

    # One conditioning row covers lc_duration generated samples.
    args.samples = args.lc_duration * len(local_condition)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id,
                                                    lc_piece)
    else:
        next_sample = net.predict_proba(samples, args.gc_id, lc_piece)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # The incremental generator's buffers are not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels, net.receptive_field)
        waveform = sess.run(seed).tolist()
    else:
        # Silence with a single random sample at the end. Floor division
        # keeps the silence value an integer, matching the int32 placeholder.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    # The fetch list does not change between steps, so build it once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        # NOTE(review): lc_piece is not fed while priming; if next_sample
        # requires it when local conditioning is enabled this will fail -
        # confirm against the model.
        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            window = waveform[-1]
        elif len(waveform) > net.receptive_field:
            window = waveform[-net.receptive_field:]
        else:
            window = waveform

        feed = {samples: window}
        if lc_enabled:
            if step % lc_duration == 0:
                # Move on to the conditioning row that belongs to this block
                # of lc_duration samples. (The original statement
                # `local_condition[1:, :]` discarded its result, so row 0 was
                # reused for the whole generation.)
                lc_index = step // lc_duration
                lc_window = local_condition[lc_index:lc_index + 1, :]
            feed[lc_piece] = lc_window

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict=feed)[0]

        # Scale prediction distribution using temperature. log(0) is expected
        # for zero-probability bins, so silence the divide warning locally.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                'is not working as intended.')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.summary.FileWriter(logdir)
    tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.summary.merge_all()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)
    # Flush the event file so the summary is on disk before the script exits.
    writer.close()

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #7
0
def main():
    """Generate a piano-roll with a trained WaveNet and write it as MIDI.

    The latest checkpoint and ``wavenet_params.json`` are both looked up in
    ``args.resdir``. Generation starts either from a seed produced by
    ``create_seed`` or from silence followed by one random active note, and
    then appends one frame per step by thresholding the network output at
    0.5. The resulting roll is written with ``midiwrite`` to
    ``<wav_out_path>sample_<gen_num>.mid`` (``wav_out_path`` defaults to
    ``args.resdir``).
    """
    midi_dims = 88
    args = get_arguments()
    checkpoint = tf.train.latest_checkpoint(args.resdir)
    print('checkpoint: ', checkpoint)
    # os.path.join works whether or not resdir ends with a path separator.
    wavenet_params_fname = os.path.join(args.resdir, 'wavenet_params.json')
    print('wavenet params fname', wavenet_params_fname)
    with open(wavenet_params_fname, 'r') as config_file:
        wavenet_params = json.load(config_file)
        wavenet_params['midi_dims'] = midi_dims

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        midi_dims=wavenet_params['midi_dims'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    samples = tf.placeholder(tf.float32, shape=(None, midi_dims))

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # The incremental generator's buffers are not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)
    print('vars to restore')
    for key, var in variables_to_restore.items():
        print(key, var)

    print('Restoring model from {}'.format(checkpoint))
    saver.restore(sess, checkpoint)

    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           midi_dims, net.receptive_field)
        # Keep the seed as an ndarray: the generation loop relies on 2-D
        # indexing, `.shape` and np.concatenate, none of which work on the
        # nested list that `.tolist()` used to produce here.
        # assumes the seed has shape (frames, midi_dims) - TODO confirm
        waveform = np.asarray(sess.run(seed))
    else:
        # Silence with a single random active note in the last frame.
        random_note = np.zeros((1, midi_dims))
        # randint's upper bound is exclusive, so passing midi_dims makes
        # every note (including the last one) selectable.
        random_note[0, np.random.randint(midi_dims)] = 1.0
        waveform = np.concatenate((np.zeros(
            (net.receptive_field - 1, midi_dims)), random_note),
                                  axis=0)

    if args.fast_generation and args.wav_seed:
        print('fast gen')
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            # The placeholder is (None, midi_dims), so a single frame has to
            # be fed as a 1-row matrix, exactly like in the loop below.
            sess.run(outputs, feed_dict={samples: np.expand_dims(x, 0)})
        print('Done.')

    print('receptive field is %d' % net.receptive_field)
    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = np.expand_dims(waveform[-1, :], 0)
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        print(step, 'wave shape', waveform.shape, 'window shape', window.shape)
        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Notes are switched on by thresholding each output independently;
        # no temperature scaling or categorical sampling is applied here.
        sample = 1 * (prediction > 0.5)
        print('num notes', np.count_nonzero(sample))
        waveform = np.concatenate((waveform, np.expand_dims(sample, 0)),
                                  axis=0)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as a MIDI file. The output location is used as a plain
    # string prefix, so the default directory gets a trailing separator.
    if args.wav_out_path is None:
        args.wav_out_path = os.path.join(args.resdir, '')

    print(args.wav_out_path)
    filename = args.wav_out_path + ('sample_%d.mid' % int(args.gen_num))
    midiwrite(filename, waveform)
Beispiel #8
0
def main():
    """Generate aperiodicity (AP) frames with a conditioned WaveNet.

    Rebuilds the network from the ``wavenet_params`` JSON file, restores the
    checkpoint and generates ``args.samples`` frames. Each step feeds the
    most recent frames together with a window of the local-conditioning
    array; the predicted AP values are concatenated with the MFSC row loaded
    by ``load_lc`` and appended to the running ``waveform`` array. The AP
    columns of the result are saved with ``np.save`` to ``--wav_out_path``.

    Raises:
        NotImplementedError: if ``--wav_seed`` is given; seeding is not
            implemented in this script.
    """
    args = get_arguments()
    if args.wav_seed:
        # The seed branch never created `waveform`, `lc_array` or
        # `mfsc_array`, so generation died later with a NameError. Fail
        # immediately with an explicit message instead.
        raise NotImplementedError(
            'Generation from a wav seed is not supported by this script.')

    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        MFSC_channels=wavenet_params["MFSC_channels"],
        AP_channels=wavenet_params["AP_channels"],
        F0_channels=wavenet_params["F0_channels"],
        phone_channels=wavenet_params["phones_channels"],
        phone_pos_channels=wavenet_params["phone_pos_channels"])

    samples = tf.placeholder(tf.float32)
    lc = tf.placeholder(tf.float32)

    AP_channels = wavenet_params["AP_channels"]
    MFSC_channels = wavenet_params["MFSC_channels"]

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        outputs = net.predict_proba(samples, AP_channels, lc, args.gc_id)
        outputs = tf.reshape(outputs, [1, AP_channels])

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # The incremental generator's buffers are not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # Silence with a single random frame at the end.
    waveform = np.zeros(
        (net.receptive_field - 1, AP_channels + MFSC_channels))
    waveform = np.append(waveform,
                         np.random.randn(1, AP_channels + MFSC_channels),
                         axis=0)

    lc_array, mfsc_array = load_lc(
        scramble=True)  # clip_id:[003, 004, 007, 010, 012 ...]
    # Left-pad both arrays with receptive_field rows of zeros so that they
    # line up with the silent lead-in of `waveform`.
    lc_array = np.pad(lc_array, ((net.receptive_field, 0), (0, 0)),
                      'constant',
                      constant_values=((0, 0), (0, 0)))
    mfsc_array = np.pad(mfsc_array, ((net.receptive_field, 0), (0, 0)),
                        'constant',
                        constant_values=((0, 0), (0, 0)))

    # Priming of the incremental generator only applied to seeded runs,
    # which are rejected above, so there is nothing to prime here.

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # NOTE(review): this path looks unfinished - `window` is 1-D
            # here, so `window.shape[-2]` below raises IndexError, and the
            # fetched list is not a single AP row. Confirm before relying on
            # --fast_generation.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:

            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:, :]
            else:
                window = waveform

        # Run the WaveNet to predict the next sample.
        window = window.reshape(1, window.shape[-2], window.shape[-1])

        prediction = sess.run(outputs,
                              feed_dict={
                                  samples:
                                  window,
                                  lc:
                                  lc_array[step + 1:step + 1 +
                                           net.receptive_field, :].reshape(
                                               1, net.receptive_field, -1)
                              })

        # The network only predicts the AP part; the MFSC part of the new
        # frame is taken from the loaded data.
        prediction = np.concatenate(
            (prediction, mfsc_array[step + net.receptive_field].reshape(1,
                                                                        -1)),
            axis=-1)
        waveform = np.append(waveform, prediction, axis=0)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Frame {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        # NOTE(review): the partial save stores all columns while the final
        # save below keeps only the AP columns - confirm which is wanted.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            np.save(args.wav_out_path, waveform)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    if args.wav_out_path:
        # Only the leading AP columns are the model output; use the
        # configured channel count instead of a hard-coded 4.
        np.save(args.wav_out_path, waveform[:, :AP_channels])

    print('Finished generating. The result was saved as .npy file.')
Beispiel #9
0
def main():
    """Generate skeleton motion with a trained WaveNet and save it as .mat.

    Loads the ground-truth motion given by ``--motion_seed`` through
    ``create_seed``, uses its first ``cut_index`` frames as the seed and
    then generates ``args.samples`` further frames. Each new frame is the
    network prediction concatenated with the ground-truth features beyond
    ``input_channels``. When ``--skeleton_out_path`` is set the generated
    motion and the ground truth are written with ``scipy.io.savemat``.

    Raises:
        ValueError: if no motion seed is specified or the seed contains NaN.
    """
    args = get_arguments()
    # Validate before the seed path is used; this check used to run only
    # after create_seed() had already been called with the value.
    if not args.motion_seed:
        raise ValueError('motion seed not specified!')

    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    input_channels = wavenet_params['input_channels']
    output_channels = wavenet_params['output_channels']
    gt, cut_index = create_seed(
        os.path.join(args.motion_seed, os.path.basename(args.motion_seed)),
        args.window)
    if np.isnan(np.sum(gt)):
        print('nan detected')
        raise ValueError('NAN detected in seed file')
    seed = tf.constant(gt)

    net = WaveNetModel(batch_size=1,
                       dilations=wavenet_params['dilations'],
                       filter_width=wavenet_params['filter_width'],
                       residual_channels=wavenet_params['residual_channels'],
                       dilation_channels=wavenet_params['dilation_channels'],
                       skip_channels=wavenet_params['skip_channels'],
                       use_biases=wavenet_params['use_biases'],
                       input_channels=input_channels,
                       output_channels=output_channels,
                       global_condition_channels=args.gc_channels)

    samples = tf.placeholder(dtype=tf.float32)

    # NOTE(review): both the fast and the non-fast mode build the
    # incremental predictor here; the non-fast mode presumably should use
    # net.predict_proba - confirm against the model.
    next_sample = net.predict_proba_incremental(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    # The incremental generator's buffers are not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # Evaluate the constant and convert it to a nested list;
    # motion[i] is the feature list of the i-th frame.
    gt_list = sess.run(seed).tolist()
    motion = gt_list[:cut_index]
    if args.fast_generation:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(motion[-net.receptive_field:-1]):
            if i % 10 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs,
                     feed_dict={samples: np.reshape(x, (1, input_channels))})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = motion[-1]
        else:
            if len(motion) > net.receptive_field:
                window = motion[-net.receptive_field:]
            else:
                window = motion
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        # NOTE(review): the reshape only succeeds when `window` holds exactly
        # input_channels values, i.e. a single frame - the multi-frame
        # non-fast window looks incompatible with it; confirm.
        prediction = sess.run(
            outputs,
            feed_dict={samples: np.reshape(window, (1, input_channels))})[0]
        # The next input frame is the prediction followed by the
        # ground-truth features beyond input_channels.
        # NOTE(review): axis=1 needs both operands to be 2-D while the
        # ground-truth slice is a flat list - verify the shapes.
        motion.append(
            np.concatenate(
                (prediction, gt_list[cut_index + step][input_channels:]),
                axis=1))
        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

    print()

    # save result in .mat file
    if args.skeleton_out_path:
        # Output goes to <skeleton_out_path>/<checkpoint dir name>/<seed>.mat
        outdir = os.path.join(
            args.skeleton_out_path,
            os.path.basename(os.path.dirname(args.checkpoint)))
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        filedir = os.path.join(
            outdir,
            str(os.path.basename(args.motion_seed)) + '.mat')
        sio.savemat(filedir, {'wavenet_predict': motion, 'gt': gt})
        print(len(motion))
        print('generated filedir:{0}'.format(filedir))
    print('Finished generating. The result can be viewed in Matlab.')
Beispiel #10
0
def main():
    """Generate audio with a trained WaveNet using a static or dynamic
    sampling temperature.

    The ``wavenet_params`` argument may be a full ``wavenet_params/...``
    path, a file name starting with ``wavenet_params`` or just the suffix
    after ``wavenet_params_``. The temperature is either fixed
    (``--temperature``), re-drawn at random five times over the run
    (``--temperature_change dynamic`` without ``--tform``) or taken per step
    from ``dynamic.generate_value`` (``dynamic`` with ``--tform``).

    The result is written as a TensorBoard audio summary and, when
    ``--wav_out_path`` is given, as a wav file.

    Raises:
        ValueError: if both a temperature frequency and period are given, or
            if ``--temperature_change`` has an unsupported value.
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)

    # Resolve the three accepted spellings of the parameter file.
    if args.wavenet_params.startswith('wavenet_params/'):
        params_path = args.wavenet_params
    elif args.wavenet_params.startswith('wavenet_params'):
        params_path = 'wavenet_params/' + args.wavenet_params
    else:
        params_path = 'wavenet_params/wavenet_params_' + args.wavenet_params
    with open(params_path, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # The incremental generator's buffers are not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels,
                           net.receptive_field,
                           args.silence_threshold)
        waveform = sess.run(seed).tolist()
    else:
        # Silence with a single random sample at the end. Floor division
        # keeps the silence value an integer, matching the int32 placeholder.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    # The fetch list does not change between steps, so build it once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field: -1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()

    # Frequency to period change: at most one of the two may be given.
    if args.tfrequency is not None:
        if args.tperiod is not None:
            raise ValueError("Frequency and Period both assigned. Assign only one of them.")
        period = dynamic.frequency_to_period(args.tfrequency)
    elif args.tperiod is not None:
        period = args.tperiod
    else:
        period = 1

    # Reject unsupported modes before spending time on generation.
    if args.temperature_change not in (None, "dynamic"):
        raise ValueError("wrong temperature_change value")
    dynamic_temperature = args.temperature_change == "dynamic"

    # Generate an array of temperatures for each step when a wave form is
    # requested for the dynamic temperature change.
    if dynamic_temperature and args.tform is not None:
        temp_array = dynamic.generate_value(0, wavenet_params['sample_rate'], args.tform, args.tmin, args.tmax, period, args.samples, args.tphaseshift)

    # In random mode the temperature is re-drawn five times over the run;
    # clamp the interval to 1 so that fewer than five samples do not cause a
    # modulo-by-zero.
    random_interval = max(1, args.samples // 5)
    current_temperature = args.temperature

    for step in range(args.samples):
        if args.fast_generation:
            window = waveform[-1]
        elif len(waveform) > net.receptive_field:
            window = waveform[-net.receptive_field:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Pick the temperature for this step.
        if not dynamic_temperature:
            current_temperature = args.temperature
        elif args.tform is None:
            if step % random_interval == 0:
                current_temperature = args.temperature * np.random.rand()
        else:
            current_temperature = temp_array[1][step]

        # Scale prediction distribution using temperature. log(0) is expected
        # for zero-probability bins, so silence the divide warning locally.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / current_temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0 and args.temperature_change is None:
            np.testing.assert_allclose(
                    prediction, scaled_prediction, atol=1e-5,
                    err_msg = 'Prediction scaling at temperature=1.0 is not working as intended.')

        sample = np.random.choice(
            np.arange(quantization_channels), p=scaled_prediction)
        # Feed the drawn sample back into the signal; without this the
        # window never changes and only the seed is written out.
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}, temperature {:3<f}'.format(step + 1, args.samples, current_temperature),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.summary.FileWriter(logdir)
    tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.summary.merge_all()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)
    # Flush the event file so the summary is on disk before the script exits.
    writer.close()

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #11
0
def main():
    """Sample a waveform from a newly initialised WaveNet and print it.

    Reads the command-line arguments and the WaveNet parameter JSON, builds
    the model, then draws ``args.samples`` values one at a time from the
    predicted distribution (scaled by ``args.temperature``) and prints the
    resulting list of sample values. No checkpoint is restored here, so the
    network runs with its initial weights.
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)
    quantization_channels = wavenet_params['quantization_channels']

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=quantization_channels,
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    # White noise generator params
    gi_sampler = get_generator_input_sampler()
    white_mean = 0
    white_sigma = 1
    white_length = 20234
    white_noise = gi_sampler(white_mean, white_sigma, white_length)

    # The returned loss tensor is not used below; the call is kept so the
    # graph is constructed exactly as before.
    net.loss(input_batch=tf.convert_to_tensor(white_noise, dtype=np.float32),
             name='generator')

    samples = tf.placeholder(tf.int32)
    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    # Nothing is restored from a checkpoint in this function, so the
    # variables have to be initialised on BOTH generation paths; previously
    # this only happened for fast generation and the naive path ran on
    # uninitialised variables.
    sess.run(tf.global_variables_initializer())
    if args.fast_generation:
        sess.run(net.init_ops)

    # The fetch list does not change between steps, so build it once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    # Placeholder first sample; it is dropped again before printing.
    waveform = [0]
    candidates = np.arange(quantization_channels)

    for _ in range(args.samples):
        if args.fast_generation:
            # The incremental path consumes one sample per step.
            window = waveform[-1]
        else:
            # At most the last receptive_field samples (the slice returns the
            # whole list while it is still shorter than that).
            window = waveform[-net.receptive_field:]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature. log(0) is expected
        # for zero-probability bins, hence the locally silenced warning.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                'is not working as intended.')

        sample = np.random.choice(candidates, p=scaled_prediction)
        waveform.append(sample)

    print()
    # Drop the placeholder sample that seeded the loop.
    del waveform[0]
    print(waveform)
    print()

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #12
0
def main():
    """Generate audio from a trained WaveNet checkpoint.

    Builds the model from the parameter JSON, restores ``args.checkpoint``,
    optionally primes generation with a seed wav file, then samples
    ``args.samples`` values one at a time. The result is written as an audio
    summary to the log directory and, if ``args.wav_out_path`` is set, as a
    wav file (also periodically when ``args.save_every`` is set).
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    if args.using_magna:
        # Multi-hot condition vector with a 1 at the header position of every
        # configured style.
        styles = wavenet_params['styles']
        header = get_data(N_CLASSES)[0]
        conditions = np.zeros(N_CLASSES)
        for style in styles:
            conditions[header.index(style)] = 1
        cardinality = N_CLASSES
    else:
        conditions = args.gc_id
        cardinality = args.gc_cardinality

    sess = tf.Session()
    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=cardinality,
        glove_channels=args.glove_channels,
        residual_postproc=wavenet_params["residual_postproc"],
        use_magna=args.using_magna)

    samples = tf.placeholder(tf.int32)

    if args.text is not None:
        nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors')
        word_vec = nlp(u'%s' % args.text).vector
    else:
        word_vec = None

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(
            samples, conditions, word_vec)
        # Initialise everything first; the restore below then overwrites all
        # variables except the generation state, which is excluded from the
        # saver.
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples, conditions, word_vec)

    # Restore by name without the ':0' suffix, skipping the incremental
    # generation state ('state_buffer' / 'pointer' variables).
    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    # The fetch list does not change between steps, so build it once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        # Feed every seed sample except the last one: the last sample is the
        # first input of the generation loop below. The previous slice
        # ([:-(args.window + 1)]) skipped the samples directly preceding it,
        # leaving a gap in the model's incremental state.
        for i, x in enumerate(waveform[:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    candidates = np.arange(quantization_channels)
    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            window = waveform[-1]
        else:
            # At most the last args.window samples (the slice returns the
            # whole list while it is still shorter than that).
            window = waveform[-args.window:]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature. log(0) is expected
        # for zero-probability bins, hence the locally silenced warning.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction, scaled_prediction, atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 is not working as intended.')

        sample = np.random.choice(candidates, p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)
    # Close the writer so the event is flushed to disk before the script
    # exits.
    writer.close()

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #13
0
def _build_sampling_ops(sess, net, gc_id, fast_generation):
    """Create a fresh input placeholder and the fetch list for sampling.

    Returns ``(samples, outputs)`` where ``samples`` is the int32 placeholder
    and ``outputs`` is the list to pass to ``sess.run``; its first element is
    the next-sample distribution.

    On the fast path only the variables created by this call are initialised.
    Re-running the global initialiser here would reset every variable in the
    graph, including weights that have already been trained.
    """
    samples = tf.placeholder(tf.int32)
    if not fast_generation:
        return samples, [net.predict_proba(samples, gc_id)]

    known_names = {var.name for var in tf.global_variables()}
    next_sample = net.predict_proba_incremental(samples, gc_id)
    fresh_vars = [var for var in tf.global_variables()
                  if var.name not in known_names]
    sess.run(tf.variables_initializer(fresh_vars))
    sess.run(net.init_ops)

    outputs = [next_sample]
    outputs.extend(net.push_ops)
    return samples, outputs


def _sample_waveform(sess, net, samples, outputs, length,
                     quantization_channels, fast_generation):
    """Draw ``length`` samples one at a time and return them as a list."""
    candidates = np.arange(quantization_channels)
    # Placeholder first sample; it is dropped again before returning.
    waveform = [0]
    for _ in range(length):
        if fast_generation:
            window = waveform[-1]
        else:
            # At most the last receptive_field samples.
            window = waveform[-net.receptive_field:]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Renormalise in log space (temperature is fixed at 1 here). log(0)
        # is expected for zero-probability bins, hence the silenced warning.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction)
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        waveform.append(np.random.choice(candidates, p=scaled_prediction))
    return waveform[1:]


def main():
    """Pre-train a feed-forward discriminator, then run the GAN loop.

    The discriminator (``ffnn``) is built and optionally restored from
    ``args.restore_from``, trained on samples from ``./sampleTrue`` (label 1)
    and ``./sampleFalse`` (label 0), and then alternated with a WaveNet
    generator for ``NUM_EPOCHS`` iterations: discriminator on real data,
    discriminator on generated data, then the generator update.
    """
    with tf.Graph().as_default():
        coord = tf.train.Coordinator()
        sess = tf.Session()

        batch_size = 1
        hidden1_units = 5202
        hidden2_units = 2601
        hidden3_units = 1300
        hidden4_units = 650
        hidden5_units = 325
        max_training_steps = 1

        global_step = tf.Variable(0, name='global_step', trainable=False)
        initial_training_learning_rate = 3e-2
        training_learning_rate = tf.train.exponential_decay(
            initial_training_learning_rate,
            global_step,
            100,
            0.9,
            staircase=True)

        inputs_placeholder, labels_placeholder = placeholder_inputs(batch_size)

        logits = ffnn.inference(inputs_placeholder, hidden1_units,
                                hidden2_units, hidden3_units, hidden4_units,
                                hidden5_units)
        loss = ffnn.loss(logits, labels_placeholder)
        train_op = ffnn.training(loss, training_learning_rate, global_step)
        eval_correct = ffnn.evaluation(logits, labels_placeholder)

        summary = tf.summary.merge_all()
        init = tf.global_variables_initializer()
        # Remember what this initialiser covers so that variables created
        # later can be initialised without touching these again.
        initialized_names = {var.name for var in tf.global_variables()}
        saver = tf.train.Saver()

        summary_writer = tf.summary.FileWriter('./logdir', sess.graph)

        sess.run(init)

        args = get_arguments()

        # Load parameters from wavenet params json file
        with open(args.wavenet_params, 'r') as f:
            wavenet_params = json.load(f)

        quantization_channels = wavenet_params['quantization_channels']

        if args.restore_from is not None:
            restore_from = args.restore_from
            print("Restoring from: ")
            print(restore_from)
        else:
            restore_from = ""

        try:
            saved_global_step = load(saver, sess, restore_from)
            if saved_global_step is None:
                # The first training step will be saved_global_step + 1,
                # therefore we put -1 here for new or overwritten trainings.
                saved_global_step = -1
        except Exception:
            print(
                "Something went wrong while restoring checkpoint. "
                "We will terminate training to avoid accidentally overwriting "
                "the previous model.")
            raise

        # TODO: Find a more robust way to find different data sets

        # Training data
        directory = './sampleTrue'
        reader = AudioReader(directory,
                             coord,
                             sample_rate=16000,
                             gc_enabled=False,
                             receptive_field=5117,
                             sample_size=15117,
                             silence_threshold=0.05)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        reader.start_threads(sess)

        directory = './sampleFalse'
        reader2 = AudioReader(directory,
                              coord,
                              sample_rate=16000,
                              gc_enabled=False,
                              receptive_field=5117,
                              sample_size=15117,
                              silence_threshold=0.05)
        threads2 = tf.train.start_queue_runners(sess=sess, coord=coord)
        reader2.start_threads(sess)

        # Discriminator pre-training.
        # TODO: Periodic checkpointing is not implemented yet; the train
        # script still has to be updated to add data to a new directory.
        for step in range(saved_global_step + 1, max_training_steps):
            start_time = time.time()

            batch_data = []
            label_data = []

            if step % 100 == 0:
                print('Current learning rate: %6f' %
                      (sess.run(training_learning_rate)))

            for _ in range(batch_size):
                # Pick the positive or negative data set at random.
                label = randint(0, 1)
                source = reader if label == 1 else reader2

                # Skip clips that are shorter than the network input.
                data = sess.run(source.dequeue(1))
                while len(data[0]) < ffnn.INPUT_SIZE:
                    data = sess.run(source.dequeue(1))

                # Keep only the first INPUT_SIZE entries.
                data = list(np.array(data[0])[:ffnn.INPUT_SIZE])

                # processing
                batch_data.append(process(data, quantization_channels, 1))
                label_data.append(label)

            feed_dict = fill_feed_dict(batch_data, label_data,
                                       inputs_placeholder, labels_placeholder)

            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

            duration = time.time() - start_time

            print('Step %d: loss = %.7f (%.3f sec)' %
                  (step, loss_value, duration))

        # Lambda for white noise sampler
        gi_sampler = get_generator_input_sampler()

        # Initialize generator WaveNet
        G = WaveNetModel(
            batch_size=1,
            dilations=wavenet_params["dilations"],
            filter_width=wavenet_params["filter_width"],
            residual_channels=wavenet_params["residual_channels"],
            dilation_channels=wavenet_params["dilation_channels"],
            skip_channels=wavenet_params["skip_channels"],
            quantization_channels=wavenet_params["quantization_channels"],
            use_biases=wavenet_params["use_biases"],
            initial_filter_width=wavenet_params["initial_filter_width"])

        # White noise generator params
        white_mean = 0
        white_sigma = 1
        white_length = ffnn.INPUT_SIZE

        white_noise = gi_sampler(white_mean, white_sigma, white_length)
        white_noise = process(white_noise, quantization_channels, 1)
        white_noise_t = tf.convert_to_tensor(white_noise)

        # initialize generator
        w_loss, w_prediction = G.loss(input_batch=white_noise_t,
                                      name='generator')

        G_variables = tf.trainable_variables(scope='wavenet')
        optimizer = optimizer_factory[args.optimizer](learning_rate=3e-2,
                                                      momentum=args.momentum)
        optim = optimizer.minimize(w_loss, var_list=G_variables)

        # Initialise only what was created since the first initialiser ran
        # (generator weights, optimizer state, ...). Re-running the global
        # initialiser would throw away the restored / pre-trained
        # discriminator weights.
        new_vars = [var for var in tf.global_variables()
                    if var.name not in initialized_names]
        sess.run(tf.variables_initializer(new_vars))

        print(sess.run(tf.shape(w_prediction)))

        # NOTE(review): the sampling ops below are rebuilt on every iteration,
        # so the graph keeps growing. Hoisting them requires knowing whether
        # G.init_ops can safely be re-run to reset the incremental state --
        # confirm against WaveNetModel before changing this.

        # main GAN training loop
        for _ in range(NUM_EPOCHS):
            # train D on real
            batch_data = []
            label_data = []
            for _ in range(batch_size):
                data = sess.run(reader.dequeue(1))[0]
                batch_data.append(process(data, quantization_channels, 1))
                label_data.append(1)

            feed_dict = fill_feed_dict(batch_data, label_data,
                                       inputs_placeholder, labels_placeholder)

            _, d_real_loss = sess.run([train_op, loss], feed_dict=feed_dict)

            print("Real loss")
            print(d_real_loss)

            # train D on fake
            batch_data = []
            label_data = []
            for _ in range(batch_size):
                samples, outputs = _build_sampling_ops(
                    sess, G, args.gc_id, args.fast_generation)
                waveform = _sample_waveform(
                    sess, G, samples, outputs, ffnn.INPUT_SIZE,
                    quantization_channels, args.fast_generation)
                batch_data.append(process(waveform, quantization_channels, 0))
                label_data.append(0)

            feed_dict = fill_feed_dict(batch_data, label_data,
                                       inputs_placeholder, labels_placeholder)

            _, d_fake_loss = sess.run([train_op, loss], feed_dict=feed_dict)

            print("Fake loss")
            print(d_fake_loss)

            # train G, but don't train D: generated data is labelled as real
            # and only the generator's optimizer op is run.
            batch_data = []
            label_data = []
            for _ in range(batch_size):
                samples, outputs = _build_sampling_ops(
                    sess, G, args.gc_id, args.fast_generation)
                waveform = _sample_waveform(
                    sess, G, samples, outputs, ffnn.INPUT_SIZE,
                    quantization_channels, args.fast_generation)
                batch_data.append(process(waveform, quantization_channels, 0))
                label_data.append(1)

            feed_dict = fill_feed_dict(batch_data, label_data,
                                       inputs_placeholder, labels_placeholder)

            _, g_loss = sess.run([optim, loss], feed_dict=feed_dict)

            print("Generator loss")
            print(g_loss)
        '''
def main():
    """Generate wav files from a mel spectrogram with a trained WaveNet.

    Loads hyper-parameters and the checkpoint from ``config.checkpoint_dir``,
    upsamples the mel spectrogram in ``config.mel`` as local condition, and
    generates ``mel frames * hparams.hop_size`` samples per batch entry with
    the incremental (fast) generation op. Every batch entry is written to
    ``<logdir>/test-<i>.wav`` together with a mel-spectrogram plot.
    """
    # Author's note: CPU is faster here. Pinning to GPU raised an error, and
    # omitting tf.device altogether was slower.
    with tf.device('/cpu:0'):
        config = get_arguments()
        started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
        logdir = os.path.join(config.logdir, 'generate', started_datestring)
        print('logdir0-------------' + logdir)

        if not os.path.exists(logdir):
            os.makedirs(logdir)

        load_hparams(hparams, config.checkpoint_dir)

        sess = tf.Session()
        scalar_input = hparams.scalar_input
        # During training global_condition_cardinality was determined by the
        # AudioReader; here it has to be passed in explicitly.
        net = WaveNetModel(
            batch_size=config.batch_size,
            dilations=hparams.dilations,
            filter_width=hparams.filter_width,
            residual_channels=hparams.residual_channels,
            dilation_channels=hparams.dilation_channels,
            quantization_channels=hparams.quantization_channels,
            out_channels=hparams.out_channels,
            skip_channels=hparams.skip_channels,
            use_biases=hparams.use_biases,
            scalar_input=hparams.scalar_input,
            global_condition_channels=hparams.gc_channels,
            global_condition_cardinality=config.gc_cardinality,
            local_condition_channels=hparams.num_mels,
            upsample_factor=hparams.upsample_factor,
            legacy=hparams.legacy,
            residual_legacy=hparams.residual_legacy,
            train_mode=False)

        # samples: the network input before any one-hot conversion,
        # shaped (batch_size, length).
        if scalar_input:
            samples = tf.placeholder(tf.float32, shape=[net.batch_size, None])
        else:
            samples = tf.placeholder(tf.int32, shape=[net.batch_size, None])

        # The local condition is fed one time step at a time, so the
        # placeholder is (N, num_mels) rather than (N, T, num_mels).
        upsampled_local_condition = tf.placeholder(
            tf.float32, shape=[net.batch_size, hparams.num_mels])

        # Author's note: applies the Fast WaveNet Generation Algorithm
        # (arXiv 1611.09482).
        next_sample = net.predict_proba_incremental(
            samples, upsampled_local_condition,
            [config.gc_id] * net.batch_size)

        # Build the upsampled local-condition data that is fed through the
        # upsampled_local_condition placeholder.
        print('logdir0-------------' + logdir)
        mel_input = np.load(config.mel)
        sample_size = mel_input.shape[0] * hparams.hop_size
        mel_input = np.tile(mel_input, (config.batch_size, 1, 1))
        with tf.variable_scope('wavenet', reuse=tf.AUTO_REUSE):
            upsampled_local_condition_data = net.create_upsample(
                mel_input, upsample_type=hparams.upsample_type)

        # Variables with 'queue' in their name are not restored.
        var_list = [
            var for var in tf.global_variables() if 'queue' not in var.name
        ]
        saver = tf.train.Saver(var_list)

        # Initialise BEFORE restoring. The initialiser covers all variables,
        # so running it after load() (as before) overwrote the weights that
        # had just been restored from the checkpoint.
        init_op = tf.group(tf.initialize_all_variables(),
                           net.queue_initializer)
        sess.run(init_op)

        print('Restoring model from {}'.format(config.checkpoint_dir))
        load(saver, sess, config.checkpoint_dir)

        quantization_channels = hparams.quantization_channels
        if config.wav_seed:
            # Author's note: a seed shorter than receptive_field is returned
            # as-is (not padded), so a seed that is too short causes an error.
            seed = create_seed(config.wav_seed, hparams.sample_rate,
                               quantization_channels, net.receptive_field,
                               scalar_input)
            if scalar_input:
                waveform = seed.tolist()
            else:
                waveform = sess.run(seed).tolist()

            print('Priming generation...')
            # The very last seed sample is not fed here; the first iteration
            # of the generation loop below feeds it.
            for i, x in enumerate(waveform[-net.receptive_field:-1]):
                if i % 100 == 0:
                    print('Priming sample {}/{}'.format(
                        i, net.receptive_field),
                          end='\r')
                sess.run(next_sample,
                         feed_dict={
                             samples:
                             np.array([x] * net.batch_size).reshape(
                                 net.batch_size, 1),
                             upsampled_local_condition:
                             np.zeros([net.batch_size, hparams.num_mels])
                         })
            print('Done.')
            waveform = np.array([waveform[-net.receptive_field:]] *
                                net.batch_size)
        else:
            # Silence with a single random sample at the end; the result has
            # shape (batch_size, receptive_field).
            if scalar_input:
                silence = np.zeros((net.batch_size, net.receptive_field - 1))
                # One random number in [-1, 1) per batch entry.
                tail = 2 * np.random.rand(net.batch_size).reshape(
                    net.batch_size, -1) - 1
            else:
                # Integer division: the samples placeholder is int32.
                silence = np.full(
                    (net.batch_size, net.receptive_field - 1),
                    quantization_channels // 2)
                tail = np.random.randint(quantization_channels,
                                         size=net.batch_size).reshape(
                                             net.batch_size, -1)
            waveform = np.concatenate([silence, tail], axis=-1)

        start_time = time.time()
        upsampled_local_condition_data = sess.run(
            upsampled_local_condition_data)
        candidates = np.arange(quantization_channels)
        last_sample_timestamp = datetime.now()
        for step in range(sample_size):
            # Only the most recent sample is fed; window has shape (N, 1).
            window = waveform[:, -1:]

            # Run the WaveNet to predict the next sample.
            prediction = sess.run(
                next_sample,
                feed_dict={
                    samples: window,
                    upsampled_local_condition:
                    upsampled_local_condition_data[:, step, :]
                })

            if scalar_input:
                # Author's note: the prediction was sampled from a logistic
                # distribution, so it already contains randomness.
                sample = prediction
            else:
                # Scale prediction distribution using temperature: divide the
                # log-probabilities by the temperature and renormalise so
                # every row sums to 1 again. log(0) is expected for
                # zero-probability bins, hence the silenced warning.
                with np.errstate(divide='ignore'):
                    scaled_prediction = np.log(prediction) / config.temperature
                    scaled_prediction = (
                        scaled_prediction - np.logaddexp.reduce(
                            scaled_prediction, axis=-1, keepdims=True))
                    scaled_prediction = np.exp(scaled_prediction)

                # Prediction distribution at temperature=1.0 should be unchanged after
                # scaling.
                if config.temperature == 1.0:
                    np.testing.assert_allclose(
                        prediction,
                        scaled_prediction,
                        atol=1e-5,
                        err_msg=
                        'Prediction scaling at temperature=1.0 is not working as intended.'
                    )

                # Sampled rather than argmax-ed, so identical inputs can give
                # different outputs. One sample per batch entry.
                sample = [[np.random.choice(candidates, p=p)]
                          for p in scaled_prediction]

            waveform = np.concatenate([waveform, sample], axis=-1)

            # Show progress only once per second.
            current_sample_timestamp = datetime.now()
            time_since_print = current_sample_timestamp - last_sample_timestamp
            if time_since_print.total_seconds() > 1.:
                # Average time per generated step so far (the message says
                # sec/step; the total elapsed time was reported before).
                duration = (time.time() - start_time) / (step + 1)
                print('Sample {:3<d}/{:3<d}, ({:.3f} sec/step)'.format(
                    step + 1, sample_size, duration),
                      end='\r')
                last_sample_timestamp = current_sample_timestamp

        # Introduce a newline to clear the carriage return from the progress.
        print()

        # Drop the initial receptive-field context and decode if needed.
        generated = waveform[:, net.receptive_field:]
        if hparams.input_type == 'raw':
            out = generated
        else:
            # 'mulaw' decodes without quantization, anything else
            # ('mulaw-quantize') with quantization.
            decode = mu_law_decode(
                samples,
                quantization_channels,
                quantization=(hparams.input_type != 'mulaw'))
            out = sess.run(decode, feed_dict={samples: generated})

        # save wav
        for i in range(net.batch_size):
            config.wav_out_path = os.path.join(logdir,
                                               'test-{}.wav'.format(i))
            mel_path = config.wav_out_path.replace(".wav", ".png")

            # Author's note: save_wav modifies out[i], so the spectrogram is
            # computed first.
            gen_mel_spectrogram = audio.melspectrogram(out[i], hparams).astype(
                np.float32).T
            audio.save_wav(out[i], config.wav_out_path, hparams.sample_rate)

            plot.plot_spectrogram(gen_mel_spectrogram,
                                  mel_path,
                                  title='generated mel spectrogram',
                                  target_spectrogram=mel_input[i])
        print('Finished generating.')
Beispiel #15
0
def main():
    """Predict the held-out part of a series with a trained WaveNet and plot it.

    Restores the checkpoint given by ``args.checkpoint``, runs the network on
    every window returned by ``create_test_dataset`` for the CSV file
    ``fname``, uses the most probable class as the prediction for each step,
    prints the RMSE against the expected values and plots the predictions on
    top of the original series.

    ``fname`` and ``SPLIT`` are read from the enclosing module.
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # The incremental generator's buffers are not part of the checkpoint, so
    # they are excluded from the restore map.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # The CSV is parsed once; the first row is consumed as the header.
    dataframe = pd.read_csv(fname, engine='python')
    dataset = dataframe.values

    # Everything after the first SPLIT fraction of the rows is the test set.
    train_size = int(len(dataset) * SPLIT)

    testX, testY = create_test_dataset(dataset, train_size)

    print(str(len(testX)) + " steps to go...")

    testP = []
    # NOTE(review): net.push_ops is never run here, so the fast_generation
    # path does not advance the incremental queues -- confirm it is meant to
    # be used with this script at all.
    for step, window in enumerate(testX):
        # Run the WaveNet to predict the next sample and keep the most
        # probable class instead of sampling from the distribution.
        prediction = sess.run([next_sample], feed_dict={samples: window})[0]
        sample = np.argmax(prediction)
        testP.append(sample)
        print(step, sample)

    testPredict = np.array(testP).reshape((-1, 1))

    testScore = math.sqrt(mean_squared_error(testY, testPredict[:, 0]))
    print('Test Score: %.2f RMSE' % (testScore))

    # Align the predictions with their position in the full series; rows
    # without a prediction stay NaN so matplotlib leaves them blank.
    testPredictPlot = np.empty_like(dataset, dtype=float)
    testPredictPlot[:, :] = np.nan
    testPredictPlot[train_size + 1:len(dataset) - 1, :] = testPredict

    # plot baseline and predictions
    plt.plot(dataset)
    plt.plot(testPredictPlot)
    plt.show()

    print('Finished generating.')
Beispiel #16
0
def main():
    """Generate audio from a trained WaveNet and store it.

    Restores the checkpoint given by ``args.checkpoint`` and draws
    ``args.samples`` new samples, starting either from the seed file
    ``args.wav_seed`` or from a single random sample. The decoded result is
    written as an audio summary below ``args.logdir`` and, when
    ``args.wav_out_path`` is set, as a wav file (also every
    ``args.save_every`` steps while generating).
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    # The fetch list is the same for every step, so it is built once here.
    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
        outputs = [next_sample]
        outputs.extend(net.push_ops)
    else:
        next_sample = net.predict_proba(samples)
        outputs = [next_sample]

    # The incremental generator's buffers are not part of the checkpoint, so
    # they are excluded from the restore map.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        # Push every seed sample except the last one: the last sample is fed
        # by the first generation step below, and skipping any earlier ones
        # would leave the queues out of sync with the seed.
        for i, x in enumerate(waveform[:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # The incremental generator only needs the newest sample.
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #17
0
def main():
    """Generate frames with a trained WaveNet variant and save them with numpy.

    Builds a ``WaveNetModel`` from the JSON file ``args.wavenet_params``,
    restores ``args.checkpoint`` and runs ``args.samples`` generation steps.
    Each step feeds the last ``net.receptive_field`` rows of ``waveform``
    together with the matching slice of the conditioning array returned by
    ``load_lc``; the network output is concatenated with one row of
    ``mfsc_array`` and appended to ``waveform``. When ``args.wav_out_path``
    is set, the full array is saved every ``args.save_every`` steps and the
    first four columns are saved at the end.

    NOTE(review): the ``args.wav_seed`` branch is a no-op, so ``waveform``,
    ``lc_array`` and ``mfsc_array`` look undefined on that path -- confirm
    seeding is unsupported here.
    NOTE(review): with ``args.fast_generation`` the window is a single row
    (``waveform[-1]``), which the later ``window.shape[-2]`` reshape does not
    appear to handle -- confirm the fast path is unused.
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        MFSC_channels=wavenet_params["MFSC_channels"],
        AP_channels=wavenet_params["AP_channels"],
        F0_channels=wavenet_params["F0_channels"],
        phone_channels=wavenet_params["phones_channels"],
        phone_pos_channels=wavenet_params["phone_pos_channels"])

    # Placeholders for the generated frames and the local conditioning input.
    samples = tf.placeholder(tf.float32)
    lc = tf.placeholder(tf.float32)
    
    AP_channels = wavenet_params["AP_channels"]
    MFSC_channels = wavenet_params["MFSC_channels"]

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples,  args.gc_id)         # TODO: understand shape of next_sample
    else:
        # samples = tf.reshape(samples, [1, samples.shape[-2], samples.shape[-1]])
        # lc = tf.reshape(lc, [1, lc.shape[-2], lc.shape[-1]])
        # mvn1, mvn2 , mvn3 ,mvn4, w1, w2, w3, w4 = net.predict_proba(samples, lc, args.gc_id)
        outputs = net.predict_proba(samples, AP_channels, lc, args.gc_id)
        # Reshaped to a single row so it can be appended to waveform below.
        outputs = tf.reshape(outputs, [1, AP_channels])

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # Variables whose name contains 'state_buffer' or 'pointer' are left out
    # of the restore map.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # decode = mu_law_decode(samples, wavenet_params['quantization_channels'])      

    # quantization_channels = wavenet_params['quantization_channels']
    # MFSC_channels = wavenet_params["MFSC_channels"]
    if args.wav_seed:
        # Seeding is disabled; nothing is initialised on this path.
        # seed = create_seed(args.wav_seed,
        #                    wavenet_params['sample_rate'],
        #                    quantization_channels,
        #                    net.receptive_field)
        # waveform = sess.run(seed).tolist()
        pass
    else:

        # Silence with a single random sample at the end.
        # waveform = [quantization_channels / 2] * (net.receptive_field - 1)
        # waveform.append(np.random.randint(quantization_channels))

        # Start from receptive_field - 1 all-zero rows followed by one
        # random row, each AP_channels + MFSC_channels wide.
        waveform = np.zeros((net.receptive_field - 1, AP_channels + MFSC_channels))
        waveform = np.append(waveform, np.random.randn(1, AP_channels + MFSC_channels), axis=0)

        lc_array, mfsc_array = load_lc(scramble=True) # clip_id:[003, 004, 007, 010, 012 ...]
        # print ("before pading:", lc_array.shape)
        # lc_array = lc_array.reshape(1, lc_array.shape[-2], lc_array.shape[-1])
        # Prepend receptive_field zero rows to both conditioning arrays.
        lc_array = np.pad(lc_array, ((net.receptive_field, 0), (0, 0)), 'constant', constant_values=((0, 0),(0,0)))
        mfsc_array = np.pad(mfsc_array, ((net.receptive_field, 0), (0, 0)), 'constant', constant_values=((0, 0),(0,0)))
        # print ("after pading:", lc_array.shape)
    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)                                   # TODO: understand net.push_ops and the shape of outputs

        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field: -1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:

            # Feed at most the last receptive_field rows.
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:, :]
            else:
                window = waveform
            # outputs = [next_sample]
            ############################################## TODO ############################################
            # Modified code to get one output of the GMM
            # outputs = estimate_output(mvn1, mvn2, mvn3, mvn4, w1, w2, w3, w4, MFSC_channels)
            # outputs = tf.reshape(outputs, [1,60])
            
            # outputs = w1 * mvn1.sample([1]) + w2 * mvn2.sample([1]) + w3 * mvn3.sample([1]) + w4 * mvn4.sample([1])

        # Run the WaveNet to predict the next sample.
        # A leading axis of size 1 is added to the window before feeding it.
        window = window.reshape(1, window.shape[-2], window.shape[-1])
        # print ("lc_batch_shape:", lc_array[step+1: step+1+net.receptive_field, :].reshape(1, net.receptive_field, -1).shape)
        prediction = sess.run(outputs, feed_dict={samples: window, 
                                                lc: lc_array[step+1: step+1+net.receptive_field, :].reshape(1, net.receptive_field, -1)})
        # print(prediction.shape)
        # Append the matching mfsc_array row to the prediction so the new row
        # has the same width as the existing waveform rows.
        prediction = np.concatenate((prediction, mfsc_array[step + net.receptive_field].reshape(1,-1)), axis=-1)
        waveform = np.append(waveform, prediction, axis=0)

        ########### temp control is in model.py currently ############

        # # Scale prediction distribution using temperature.
        # np.seterr(divide='ignore')
        # scaled_prediction = np.log(prediction) / args.temperature           ############# understand this
        # scaled_prediction = (scaled_prediction -
        #                      np.logaddexp.reduce(scaled_prediction))
        # scaled_prediction = np.exp(scaled_prediction)
        # np.seterr(divide='warn')

        # # Prediction distribution at temperature=1.0 should be unchanged after
        # # scaling.
        # if args.temperature == 1.0:
        #     np.testing.assert_allclose(
        #             prediction, scaled_prediction, atol=1e-5,
        #             err_msg='Prediction scaling at temperature=1.0 '
        #                     'is not working as intended.')

        # sample = np.random.choice(                                      ############ replace with sampling multivariate gaussian
        #     np.arange(quantization_channels), p=scaled_prediction)
        # waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Frame {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            # out = sess.run(decode, feed_dict={samples: waveform})
            # write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
            np.save(args.wav_out_path, waveform)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    # datestring = str(datetime.now()).replace(' ', 'T')
    # writer = tf.summary.FileWriter(logdir)
    # tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    # summaries = tf.summary.merge_all()
    # summary_out = sess.run(summaries,
    #                        feed_dict={samples: np.reshape(waveform, [-1, 1])})
    # writer.add_summary(summary_out)

    # Save the result as a numpy file.
    if args.wav_out_path:
        # out = sess.run(decode, feed_dict={samples: waveform})
        # write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
        # NOTE(review): the column count 4 is hard-coded; presumably it
        # should equal AP_channels -- confirm.
        np.save(args.wav_out_path, waveform[:,:4])    # only take the first four columns, which are aps

    # NOTE(review): the TensorBoard summary above is commented out, so this
    # message looks stale.
    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #18
0
class TestGeneration(tf.test.TestCase):
    '''Sanity tests for the naive and the incremental probability
    prediction of a small WaveNet.'''

    def setUp(self):
        '''Create the 128-channel network shared by all tests.'''
        model_kwargs = {
            'batch_size': 1,
            'dilations': [2 ** exponent for exponent in range(9)],
            'filter_width': 2,
            'residual_channels': 16,
            'dilation_channels': 16,
            'quantization_channels': 128,
            'skip_channels': 32,
        }
        self.net = WaveNetModel(**model_kwargs)

    def testGenerateSimple(self):
        '''Run the naive predictor on 1000 random samples and check the
        shape and value range of its output.'''
        input_ph = tf.placeholder(tf.int32)
        np.random.seed(0)
        audio = np.random.randint(128, size=1000)
        proba_op = self.net.predict_proba(input_ph)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            result = sess.run(proba_op, feed_dict={input_ph: audio})

        self.assertAllEqual(result.shape, [128])
        within_bounds = np.logical_and(result >= 0, result <= (128 - 1))
        self.assertTrue(np.all(within_bounds))

    def testGenerateFast(self):
        '''Run the incremental predictor on one random sample and check
        the shape and value range of its output.'''
        input_ph = tf.placeholder(tf.int32)
        np.random.seed(0)
        single_sample = np.random.randint(128)
        proba_op = self.net.predict_proba_incremental(input_ph)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(self.net.init_ops)
            result = sess.run(proba_op, feed_dict={input_ph: single_sample})

        self.assertAllEqual(result.shape, [128])
        within_bounds = np.logical_and(result >= 0, result <= (128 - 1))
        self.assertTrue(np.all(within_bounds))

    def testCompareSimpleFast(self):
        '''Check that both predictors agree on the distribution that
        follows the same 1000 random samples.'''
        input_ph = tf.placeholder(tf.int32)
        np.random.seed(0)
        audio = np.random.randint(128, size=1000)
        naive_op = self.net.predict_proba(input_ph)
        fast_op = self.net.predict_proba_incremental(input_ph)
        priming_fetches = [fast_op, self.net.push_ops]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(self.net.init_ops)

            # Push every sample but the last through the incremental
            # generator so its queues hold the full history.
            history, newest = audio[:-1], audio[-1]
            for value in history:
                sess.run(priming_fetches, feed_dict={input_ph: value})

            # The last sample is fed without pushing; the naive generator
            # sees the whole sequence at once.
            fast_result = sess.run(fast_op, feed_dict={input_ph: newest})
            naive_result = sess.run(naive_op, feed_dict={input_ph: audio})
            self.assertAllClose(naive_result, fast_result)
Beispiel #19
0
def main(checkpoint=None):
    """Generate a poem character by character and write it to a text file.

    Restores the model from ``args.checkpoint``, draws ``args.samples``
    characters starting from a single space, capitalises everything up to the
    first newline (the title) and hands the result to ``write_text`` together
    with a header describing the model.

    Args:
        checkpoint: When not ``None``, a message naming it is printed before
            restoring.
            NOTE(review): the model is still restored from
            ``args.checkpoint``; confirm whether this argument should
            override it.
    """
    args = get_arguments()

    # Create the folder in which to put txt files of generated poems.
    directory = "GENERATED/" + args.start_time
    if not os.path.exists(directory):
        os.makedirs(directory)

    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    # The fetch list is the same for every step, so it is built once here.
    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
        outputs = [next_sample]
        outputs.extend(net.push_ops)
    else:
        next_sample = net.predict_proba(samples)
        outputs = [next_sample]

    # The incremental generator's buffers are not part of the checkpoint, so
    # they are excluded from the restore map.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    powr = int((len(wavenet_params['dilations']) / 2) - 1)
    md = ''.join(args.checkpoint.split("-")[-1:])

    # The header is built on every path so that it is always defined when the
    # text file is written at the end.
    intro = """DIR: {}\tMODEL: {}\t\tLOSS: {}\ndilations: {}\t\t\t\tfilter_width: {}\t\tresidual_channels: {}\ndilation_channels: {}\t\t\tskip_channels: {}\tquantization_channels: {}\n_______________________________________________________________________________________________""".format(
        args.checkpoint.split("/")[-2], md, args.loss, "2^" + str(powr),
        wavenet_params['filter_width'],
        wavenet_params['residual_channels'],
        wavenet_params['dilation_channels'],
        wavenet_params['skip_channels'],
        wavenet_params['quantization_channels'])

    if checkpoint is not None:
        print('Restoring model from PARAMETER {}'.format(checkpoint))
    saver.restore(sess, args.checkpoint)

    # Fetching the placeholder itself returns the fed samples unchanged.
    decode = samples

    quantization_channels = wavenet_params['quantization_channels']
    # Start from a single space character (code point 32).
    waveform = [32.]

    # Generated characters are collected in a list and joined once at the end.
    pieces = ["\n\n"]
    in_title = True

    print("")

    for step in range(args.samples):
        print("Generating:", step, "/", args.samples, end="\r")

        if args.fast_generation:
            # The incremental generator only needs the newest sample.
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)

        if in_title:
            # Everything up to the first newline is the title: capitalise it.
            pieces.append(chr(sample).title())
            if sample == 10:
                in_title = False
                pieces.append("\n\n")
        else:
            pieces.append(chr(sample))

        # The default output path is only filled in once generation has
        # produced at least one character.
        if args.text_out_path is None:
            args.text_out_path = "GENERATED/{}/{}_DIR-{}_Model-{}_Loss-{}_Chars-{}.txt".format(
                args.start_time,
                datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M'),
                args.checkpoint.split("/")[-2],
                args.checkpoint.split("-")[-1], args.loss, args.samples)

    ml = "Model: {}  |  Loss: {}  |  {}".format(
        args.checkpoint.split("-")[-1], args.loss,
        args.checkpoint.split("/")[-2])

    # Save the result as a text file.
    if args.text_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        # Blank out the progress line before anything else is printed.
        print("                                          ", end="\r")
        write_text(out, args.text_out_path, intro, ''.join(pieces), ml)
Beispiel #20
0
def main(checkpoint=None):
    """Generate text character by character, echo it and save it to a file.

    Restores the model from ``args.checkpoint``, draws ``args.samples``
    characters starting from a single space and writes each one to stdout as
    it is produced; everything up to the first newline (the title) is
    capitalised on screen. The generated samples are passed to ``write_text``
    together with a header describing the model, every ``args.save_every``
    steps and once more at the end.

    Args:
        checkpoint: When ``None`` the model header is printed before
            restoring; otherwise a message naming ``checkpoint`` is printed.
            NOTE(review): the model is still restored from
            ``args.checkpoint``; confirm whether this argument should
            override it.
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    # The fetch list is the same for every step, so it is built once here.
    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
        outputs = [next_sample]
        outputs.extend(net.push_ops)
    else:
        next_sample = net.predict_proba(samples)
        outputs = [next_sample]

    # The incremental generator's buffers are not part of the checkpoint, so
    # they are excluded from the restore map.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    powr = int((len(wavenet_params['dilations']) / 2) - 1)
    # Take the text after the last '-' as a plain string so the header shows
    # it without list brackets.
    md = args.checkpoint.split("-")[-1]

    # The header is built on every path so that it is always defined when the
    # text file is written.
    intro = """\n________________________________________________________________________________________________\n\n\t\t~~~~~~******~~~~~~\n________________________________________________________________________________________________\n\n\tDIR: {}\n\tMODEL: {}\n\tLOSS: {}\n\n
dilations={}\t        filter_width={}\t                residual_channels={}
dilation_channels={}\tquantization_channels={}\tskip_channels={}\n________________________________________________________________________________________________\n\n""".format(
        args.checkpoint.split("/")[-2], md, args.loss, "2^" + str(powr),
        wavenet_params['filter_width'],
        wavenet_params['residual_channels'],
        wavenet_params['dilation_channels'],
        wavenet_params['quantization_channels'],
        wavenet_params['skip_channels'])

    if checkpoint is None:
        print(intro)
    else:
        print('Restoring model from PARAMETER {}'.format(checkpoint))
    saver.restore(sess, args.checkpoint)

    # Fetching the placeholder itself returns the fed samples unchanged.
    decode = samples

    quantization_channels = wavenet_params['quantization_channels']
    # Start from a single space character (code point 32).
    waveform = [32.]

    in_title = True
    for step in range(args.samples):
        if args.fast_generation:
            # The incremental generator only needs the newest sample.
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)

        # Show character by character in the terminal; everything up to the
        # first newline is the title and is capitalised.
        if in_title:
            sys.stdout.write(chr(sample).capitalize())
            if sample == 10:
                in_title = False
                sys.stdout.write("\n\n")
        else:
            sys.stdout.write(chr(sample))

        # The default output path is only filled in once generation has
        # produced at least one character.
        if args.text_out_path is None:
            args.text_out_path = "GENERATED/{}_DIR-{}_Model-{}_Loss-{}_Chars-{}.txt".format(
                datetime.strftime(datetime.now(), '%Y-%m-%d_%H:%M'),
                args.checkpoint.split("/")[-2],
                args.checkpoint.split("-")[-1], args.loss, args.samples)

        # If we have partial writing, save the result so far.
        if (args.text_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_text(out, args.text_out_path, intro)

    # End the line of characters echoed to the terminal.
    print()

    # Save the result as a text file.
    if args.text_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_text(out, args.text_out_path, intro)
Beispiel #21
0
def main():
    """Generate audio with a trained WaveNet using incremental generation.

    Reads the run configuration from ``get_arguments()``, loads the
    hyper-parameters and the model weights from ``config.checkpoint_dir``,
    optionally primes the generation queues with ``config.wav_seed``, draws
    ``config.samples`` new samples for every batch entry and writes one
    ``test-<i>.wav`` file per batch entry into a timestamped directory below
    ``config.logdir``.
    """
    config = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(config.logdir, 'generate', started_datestring)

    if not os.path.exists(logdir):
        os.makedirs(logdir)

    load_hparams(hparams, config.checkpoint_dir)

    # The session is used as a context manager so that its resources are
    # released when generation finishes or fails.
    with tf.device('/cpu:0'), tf.Session() as sess:

        scalar_input = hparams.scalar_input
        # During training the global-condition cardinality was determined by
        # the AudioReader; here it has to be supplied explicitly.
        net = WaveNetModel(
            batch_size=BATCH_SIZE,
            dilations=hparams.dilations,
            filter_width=hparams.filter_width,
            residual_channels=hparams.residual_channels,
            dilation_channels=hparams.dilation_channels,
            quantization_channels=hparams.quantization_channels,
            out_channels=hparams.out_channels,
            skip_channels=hparams.skip_channels,
            use_biases=hparams.use_biases,
            scalar_input=scalar_input,
            initial_filter_width=hparams.initial_filter_width,
            global_condition_channels=hparams.gc_channels,
            global_condition_cardinality=config.gc_cardinality,
            train_mode=False
        )

        if scalar_input:
            samples = tf.placeholder(tf.float32, shape=[net.batch_size, None])
        else:
            # samples: mu-law encoded values, before one-hot conversion.
            # Shape (batch_size, length).
            samples = tf.placeholder(
                tf.int32, shape=[net.batch_size, None]
            )

        # Uses the Fast WaveNet Generation Algorithm (arXiv:1611.09482).
        next_sample = net.predict_proba_incremental(
            samples, [config.gc_id] * net.batch_size
        )

        # Variables whose name contains 'queue' are not restored from the
        # checkpoint; they are initialized separately below.
        var_list = [
            var for var in tf.global_variables() if 'queue' not in var.name
        ]
        saver = tf.train.Saver(var_list)
        print('Restoring model from {}'.format(config.checkpoint_dir))

        load(saver, sess, config.checkpoint_dir)

        # NOTE(review): the original comment stated that without this step
        # the queues would contain values restored from the checkpoint --
        # confirm against the saver's variable list above.
        sess.run(net.queue_initializer)

        quantization_channels = hparams.quantization_channels
        if config.wav_seed:
            # NOTE(review): if the seed is shorter than the receptive field
            # it is used as-is (no padding), so a seed that is too short is
            # reported to cause an error -- consider padding.
            seed = create_seed(config.wav_seed, hparams.sample_rate,
                               quantization_channels, net.receptive_field,
                               scalar_input)  # mu-law encoded
            if scalar_input:
                waveform = seed.tolist()
            else:
                waveform = sess.run(
                    seed).tolist()  # e.g. [116, 114, 120, 121, 127, ...]

            print('Priming generation...')
            # The very last seed sample is not fed here; it is fed by the
            # first iteration of the generation loop below.
            for i, x in enumerate(waveform[-net.receptive_field:-1]):
                if i % 100 == 0:
                    print('Priming sample {}'.format(i))
                sess.run(next_sample,
                         feed_dict={
                             samples:
                             np.array([x] * net.batch_size).reshape(
                                 net.batch_size, 1)
                         })
            print('Done.')
            waveform = np.array([waveform[-net.receptive_field:]] *
                                net.batch_size)
        else:
            # Silence with a single random sample at the end.
            if scalar_input:
                waveform = [0.0] * (net.receptive_field - 1)
                waveform = np.array(waveform * net.batch_size).reshape(
                    net.batch_size, -1)
                # Append one random number in [-1, 1) per batch entry.
                waveform = np.concatenate(
                    [
                        waveform, 2 * np.random.rand(net.batch_size).reshape(
                            net.batch_size, -1) - 1
                    ],
                    axis=-1)
            else:
                # One sample shorter than the receptive field; a random
                # sample is appended below.  Floor division keeps the
                # mu-law codes integral (true division would silently turn
                # the whole waveform into floats on Python 3).
                waveform = [quantization_channels // 2] * (
                    net.receptive_field - 1
                )
                waveform = np.array(waveform * net.batch_size).reshape(
                    net.batch_size, -1)
                # Still integer codes (not one-hot), shape
                # (batch_size, receptive_field).
                waveform = np.concatenate(
                    [
                        waveform,
                        np.random.randint(quantization_channels,
                                          size=net.batch_size).reshape(
                                              net.batch_size, -1)
                    ],
                    axis=-1)

        # Candidate sample values; loop-invariant, so built once.
        sample_values = np.arange(quantization_channels)

        last_sample_timestamp = datetime.now()
        for step in range(config.samples):  # one new sample per iteration

            # Incremental generation only needs the most recent sample.
            window = waveform[:, -1:]

            # Run the WaveNet to predict the next sample.  `samples` holds
            # mu-law codes that the model converts to one-hot internally;
            # the prediction is expected to be (batch_size, channels).
            prediction = sess.run(
                next_sample, feed_dict={samples: window}
            )

            if scalar_input:
                sample = prediction
            else:
                # Scale prediction distribution using temperature.
                # With temperature == 1 this only re-normalizes an already
                # normalized softmax output, so the values do not change.
                # Otherwise the log-probabilities are divided by the
                # temperature and rescaled so that they sum to 1.
                # errstate restores the previous numpy error settings even
                # if an exception is raised (log(0) is expected here).
                with np.errstate(divide='ignore'):
                    scaled_prediction = np.log(
                        prediction
                    ) / config.temperature
                    # Subtract np.log(np.sum(np.exp(scaled_prediction))).
                    scaled_prediction = (
                        scaled_prediction - np.logaddexp.reduce(
                            scaled_prediction, axis=-1, keepdims=True)
                    )
                    scaled_prediction = np.exp(scaled_prediction)

                # Prediction distribution at temperature=1.0 should be unchanged after
                # scaling.
                if config.temperature == 1.0:
                    np.testing.assert_allclose(
                        prediction,
                        scaled_prediction,
                        atol=1e-5,
                        err_msg=
                        'Prediction scaling at temperature=1.0 is not working as intended.'
                    )

                # Choose one sample per batch entry.
                sample = [[
                    np.random.choice(sample_values, p=p)
                ] for p in scaled_prediction]

            waveform = np.concatenate([waveform, sample], axis=-1)

            # Show progress only once per second.
            current_sample_timestamp = datetime.now()
            time_since_print = current_sample_timestamp - last_sample_timestamp
            if time_since_print.total_seconds() > 1.:
                print('Sample {:3<d}/{:3<d}'.format(step + 1, config.samples),
                      end='\r')
                last_sample_timestamp = current_sample_timestamp

        # Introduce a newline to clear the carriage return from the progress.
        print()

        # Save the result as one wav file per batch entry.
        if scalar_input:
            out = waveform
        else:
            decode = mu_law_decode(samples, quantization_channels)
            out = sess.run(decode, feed_dict={samples: waveform})
        for i in range(net.batch_size):
            config.wav_out_path = os.path.join(logdir,
                                               'test-{}.wav'.format(i))
            write_wav(out[i], hparams.sample_rate, config.wav_out_path)

        print('Finished generating.')
def main():
    """Generate a sequence of sample codes with a trained WaveNet.

    Reads the run configuration from ``get_arguments()`` and the network
    parameters from the JSON file ``args.wavenet_params``, restores the
    weights from ``args.checkpoint`` and draws ``args.samples`` new samples,
    either incrementally (``args.fast_generation``) or by re-running the
    network on the last ``args.window`` samples.  When
    ``args.text_out_path`` is set the sequence is written with
    ``write_text``, both at the end and every ``args.save_every`` steps.
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    # The session is used as a context manager so that its resources are
    # released when generation finishes or fails.
    with tf.Session() as sess:
        net = WaveNetModel(
            batch_size=1,
            dilations=wavenet_params['dilations'],
            filter_width=wavenet_params['filter_width'],
            residual_channels=wavenet_params['residual_channels'],
            dilation_channels=wavenet_params['dilation_channels'],
            quantization_channels=wavenet_params['quantization_channels'],
            skip_channels=wavenet_params['skip_channels'],
            use_biases=wavenet_params['use_biases'])

        samples = tf.placeholder(tf.int32)

        if args.fast_generation:
            next_sample = net.predict_proba_incremental(samples)
            # Initialize everything first; the variables excluded from the
            # saver below are not covered by the checkpoint restore.
            sess.run(tf.initialize_all_variables())
            sess.run(net.init_ops)
        else:
            next_sample = net.predict_proba(samples)

        # Map checkpoint names (variable name without the ':0' suffix) to
        # variables, skipping the incremental-generation state.
        variables_to_restore = {
            var.name[:-2]: var
            for var in tf.all_variables()
            if not ('state_buffer' in var.name or 'pointer' in var.name)
        }
        saver = tf.train.Saver(variables_to_restore)

        print('Restoring model from {}'.format(args.checkpoint))
        saver.restore(sess, args.checkpoint)

        # The samples are written out as they are; running them through the
        # placeholder only converts them to an int32 array.
        decode = samples

        quantization_channels = wavenet_params['quantization_channels']
        # Integer seed, matching the int32 placeholder.
        # NOTE(review): 32 is presumably the character code of a space --
        # confirm against write_text.
        waveform = [32]

        # Both are loop-invariant, so they are built once.
        outputs = [next_sample]
        if args.fast_generation:
            outputs.extend(net.push_ops)
        sample_values = np.arange(quantization_channels)

        last_sample_timestamp = datetime.now()
        for step in range(args.samples):
            if args.fast_generation:
                # Incremental generation only needs the latest sample.
                window = waveform[-1]
            elif len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform

            # Run the WaveNet to predict the next sample.
            prediction = sess.run(outputs, feed_dict={samples: window})[0]
            sample = np.random.choice(sample_values, p=prediction)
            waveform.append(sample)

            # Show progress only once per second.
            current_sample_timestamp = datetime.now()
            time_since_print = current_sample_timestamp - last_sample_timestamp
            if time_since_print.total_seconds() > 1.:
                print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                      end='\r')
                last_sample_timestamp = current_sample_timestamp

            # If we have partial writing, save the result so far.
            if (args.text_out_path and args.save_every
                    and (step + 1) % args.save_every == 0):
                out = sess.run(decode, feed_dict={samples: waveform})
                write_text(out, args.text_out_path)

        # Introduce a newline to clear the carriage return from the progress.
        print()

        # Save the final result as a text file.
        if args.text_out_path:
            out = sess.run(decode, feed_dict={samples: waveform})
            write_text(out, args.text_out_path)

    print('Finished generating.')
Beispiel #23
0
def main():
    """Generate audio with a globally/locally conditioned WaveNet.

    Reads the run configuration from ``get_arguments()`` and the network
    parameters from the JSON file ``args.wavenet_params``, restores the
    latest checkpoint found in ``args.checkpoint`` and generates
    ``args.n_samples`` samples with the incremental (fast) sampler,
    optionally primed with ``args.wav_seed`` and conditioned on
    ``args.gc_id`` and/or the embeddings loaded from ``args.lc_embedding``.
    The result is written as an audio summary into a timestamped directory
    below ``args.logdir`` and, if ``args.wav_out_path`` is set, as a wav
    file (also every ``args.save_every`` steps).
    """
    import sys

    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    # The session is used as a context manager so that its resources are
    # released when generation finishes or fails.
    with tf.Session(
        # config=tf.ConfigProto(device_count={'GPU': 0})
    ) as sess:

        # Build the WaveNet model
        net = WaveNetModel(
            batch_size=1,
            dilations=wavenet_params['dilations'],
            filter_width=wavenet_params['filter_width'],
            residual_channels=wavenet_params['residual_channels'],
            quantization_channels=wavenet_params['quantization_channels'],
            skip_channels=wavenet_params['skip_channels'],
            gc_channels=args.gc_channels,
            gc_cardinality=args.gc_cardinality,
            lc_channels=args.lc_channels)

        # Create placeholders
        # Default to fast generation
        samples = tf.placeholder(tf.int32)
        lc = tf.placeholder(tf.float32) if args.lc_embedding else None
        # NOTE(review): `or None` also turns a gc_id of 0 into None --
        # confirm that 0 is not a valid global-condition id.
        gc = args.gc_id or None

        # TODO: right now we pre-calculated lc embeddings of the same length
        # as the audio we'd like to generate so they're naturally aligned.
        # Add function to load a length of `args.n_samples` of embeddings
        # from pre-calculated (full-length) embeddings.
        if args.lc_embedding is not None:
            lc_embedding = load_lc_embedding(args.lc_embedding)
            lc_embedding = tf.convert_to_tensor(lc_embedding)
            lc_embedding = tf.reshape(lc_embedding,
                                      [1, -1, args.lc_channels])
            lc_embedding = net._enc_upsampling_conv(lc_embedding,
                                                    args.n_samples)
            lc_embedding = tf.reshape(lc_embedding, [-1, args.lc_channels])

        next_sample = net.predict_proba_incremental(samples, gc, lc)

        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)
        # Group the ops we need to run
        output_ops = [next_sample]
        output_ops.extend(net.push_ops)
        # Convert mu-law encoded samples back to (-1, 1) of R
        QUANTIZATION_CHANNELS = wavenet_params['quantization_channels']
        decode = mu_law_decode(samples, QUANTIZATION_CHANNELS)

        # Map checkpoint names (variable name without the ':0' suffix) to
        # variables, skipping the incremental-generation state.
        variables_to_restore = {
            var.name[:-2]: var
            for var in tf.global_variables()
            if not ('state_buffer' in var.name or 'pointer' in var.name)
        }
        saver = tf.train.Saver(variables_to_restore)

        ckpt = tf.train.get_checkpoint_state(args.checkpoint)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Restoring model from {}'.format(
                ckpt.model_checkpoint_path))
        else:
            # Keep going (as before), but do not do so silently: the
            # variables only hold the values from the initializer above.
            print('WARNING: no checkpoint found in {}; generating with '
                  'freshly initialized weights.'.format(args.checkpoint))

        if args.wav_seed:
            seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                               QUANTIZATION_CHANNELS, net.receptive_field)
            waveform = sess.run(seed).tolist()
        else:
            # Silence with a single random sample at the end.
            waveform = [0] * (net.receptive_field - 1)
            waveform.append(
                np.random.randint(-QUANTIZATION_CHANNELS // 2,
                                  QUANTIZATION_CHANNELS // 2))

        if args.lc_embedding is not None:
            # Evaluate the upsampled embeddings once; rows are indexed per
            # generated step below.
            lc_embedding = sess.run(lc_embedding)

        if args.wav_seed:
            # When using the incremental generation, we need to
            # feed in all priming samples one by one before starting the
            # actual generation.
            # TODO This could be done much more efficiently by passing the
            # waveform to the incremental generator as an optional argument,
            # which would be used to fill the queues initially.

            print('Priming generation...')
            for i, x in enumerate(waveform[-net.receptive_field:-1]):
                if i % 100 == 0:
                    print('Priming sample {}'.format(i))
                # Only feed the local condition when one was supplied;
                # priming without an lc embedding used to fail because
                # `lc_embedding` was undefined and `lc` was None.
                feed_dict = {samples: x}
                if args.lc_embedding is not None:
                    feed_dict[lc] = lc_embedding[i, :]
                sess.run(output_ops, feed_dict=feed_dict)
            print('Done.')

        # Candidate sample values; loop-invariant, so built once.
        sample_values = np.arange(-QUANTIZATION_CHANNELS // 2,
                                  QUANTIZATION_CHANNELS // 2)

        for step in range(args.n_samples):
            if step % 1000 == 0:
                print("Generating {} of {}.".format(step, args.n_samples))
                sys.stdout.flush()

            # Incremental generation only needs the latest sample.
            window = waveform[-1]

            # Run the WaveNet to predict the next sample.
            feed_dict = {samples: window}
            if args.lc_embedding is not None:
                feed_dict[lc] = lc_embedding[step, :]
            results = sess.run(output_ops, feed_dict=feed_dict)

            pred = results[0]

            # Scale prediction distribution using temperature.
            # errstate restores the previous numpy error settings even if
            # an exception is raised (log(0) is expected here).
            with np.errstate(divide='ignore'):
                scaled_prediction = np.log(pred) / args.temperature
                scaled_prediction = (scaled_prediction -
                                     np.logaddexp.reduce(scaled_prediction))
                scaled_prediction = np.exp(scaled_prediction)

            # Prediction distribution at temperature=1.0 should be unchanged
            # after scaling.
            if args.temperature == 1.0:
                np.testing.assert_allclose(
                    pred,
                    scaled_prediction,
                    atol=1e-5,
                    err_msg='Prediction scaling at temperature=1.0 '
                    'is not working as intended.')

            sample = np.random.choice(sample_values, p=scaled_prediction)
            waveform.append(sample)

            # If we have partial writing, save the result so far.
            if (args.wav_out_path and args.save_every
                    and (step + 1) % args.save_every == 0):
                out = sess.run(decode, feed_dict={samples: waveform})
                write_wav(out, wavenet_params['sample_rate'],
                          args.wav_out_path)

        # Separate the progress output from the final messages.
        print()

        # Save the result as an audio summary.
        writer = tf.summary.FileWriter(logdir)
        tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
        summaries = tf.summary.merge_all()
        summary_out = sess.run(
            summaries, feed_dict={samples: np.reshape(waveform, [-1, 1])})
        writer.add_summary(summary_out)
        # Flush the pending event to disk and release the file.
        writer.close()

        # Save the result as a wav file.
        if args.wav_out_path:
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')