def main():
    args = get_arguments()
    logdir = os.path.join(args.logdir, 'train', str(datetime.now()))
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNet(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    next_sample = net.predict_proba(samples)

    saver = tf.train.Saver()
    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = net.decode(samples)

    quantization_channels = wavenet_params['quantization_channels']
    waveform = np.random.randint(quantization_channels, size=(1,)).tolist()
    for step in range(args.samples):
        if len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform
        prediction = sess.run(next_sample, feed_dict={samples: window})
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)
        print('Sample {:<3d}/{:<3d}: {}'.format(step + 1, args.samples,
                                                sample))
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(
        os.path.join(logdir, 'generation', datestring))
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()

    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
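For reference, this is the pre-1.0 summary workflow the example above relies on; a minimal, self-contained sketch (TensorFlow 0.x assumed, where tf.audio_summary and tf.train.SummaryWriter still exist; the logdir path is an assumption):

import numpy as np
import tensorflow as tf

# One second of a 440 Hz tone, shaped [batch, frames, channels].
sample_rate = 16000
t = np.linspace(0, 1, sample_rate, dtype=np.float32)
audio = np.sin(2 * np.pi * 440 * t)[None, :, None]

audio_ph = tf.placeholder(tf.float32, audio.shape)
summary_op = tf.audio_summary('tone', audio_ph, sample_rate)  # TF < 1.0 API

with tf.Session() as sess:
    writer = tf.train.SummaryWriter('/tmp/audio_demo')  # hypothetical logdir
    writer.add_summary(sess.run(summary_op, feed_dict={audio_ph: audio}))
    writer.close()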
Example #2
  def testSummaries(self):
    with self.cached_session() as s:
      var = tf.Variable([1, 2, 3], dtype=tf.float32)
      s.run(tf.initialize_all_variables())
      x, y = np.meshgrid(np.linspace(-10, 10, 256), np.linspace(-10, 10, 256))
      image = np.sin(x**2 + y**2) / np.sqrt(x**2 + y**2) * .5 + .5
      image = image[None, :, :, None]

      # make a dummy sound
      freq = 440  # A = 440Hz
      sampling_frequency = 11000
      audio = np.sin(2 * np.pi * np.linspace(0, 1, sampling_frequency) * freq)
      audio = audio[None, :, None]
      test_dir = tempfile.mkdtemp()
      # test summaries
      writer = tf.train.SummaryWriter(test_dir)
      summaries = [
          tf.scalar_summary("scalar_var", var[0]),
          tf.scalar_summary("scalar_reduce_var", tf.reduce_sum(var)),
          tf.histogram_summary("var_histogram", var),
          tf.image_summary("sin_image", image),
          tf.audio_summary("sin_wave", audio, sampling_frequency),
      ]
      run_summaries = s.run(summaries)
      writer.add_summary(s.run(tf.merge_summary(inputs=run_summaries)))
      # This is redundant, but we want to be able to rewrite the command
      writer.add_summary(s.run(tf.merge_all_summaries()))
      writer.close()
      shutil.rmtree(test_dir)
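To check what such a test wrote before the directory is removed, the events file can be read back with tf.train.summary_iterator; a minimal sketch (the events path is an assumption):

import glob
import tensorflow as tf

# TensorBoard event files are named events.out.tfevents.<timestamp>.<host>.
for events_path in glob.glob('/tmp/test_dir/events.out.tfevents.*'):
    for event in tf.train.summary_iterator(events_path):
        for value in event.summary.value:
            print(value.tag)  # e.g. scalar_var, sin_image, sin_wave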
def WriteAudioSeries(writer, tag, n_audio=1):
    """Write a few dummy audio clips to writer."""
    step = 0
    session = tf.Session()

    min_frequency_hz = 440
    max_frequency_hz = 880
    sample_rate = 4000
    duration_frames = int(sample_rate * 0.5)  # 0.5 seconds.
    frequencies_per_run = 1
    num_channels = 2

    p = tf.placeholder("float32", (frequencies_per_run, duration_frames, num_channels))
    s = tf.audio_summary(tag, p, sample_rate)

    for _ in xrange(n_audio):
        # Generate a different frequency for each channel to show stereo works.
        frequencies = np.random.random_integers(
            min_frequency_hz, max_frequency_hz, size=(frequencies_per_run, num_channels)
        )
        tiled_frequencies = np.tile(frequencies, (1, duration_frames))
        tiled_increments = np.tile(np.arange(0, duration_frames), (num_channels, 1)).T.reshape(
            1, duration_frames * num_channels
        )
        tones = np.sin(2.0 * np.pi * tiled_frequencies * tiled_increments / sample_rate)
        tones = tones.reshape(frequencies_per_run, duration_frames, num_channels)

        summ = session.run(s, feed_dict={p: tones})
        writer.add_summary(summ, step)
        step += 20
    session.close()
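A hedged usage sketch of the helper above (the logdir and tag are assumptions):

writer = tf.train.SummaryWriter('/tmp/audio_series')  # TF < 1.0 writer class
WriteAudioSeries(writer, 'stereo_tones', n_audio=3)
writer.close()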
  def testAudioSummary(self):
    np.random.seed(7)
    with self.test_session() as sess:
      num_frames = 7
      for channels in 1, 2, 5, 8:
        shape = (4, num_frames, channels)
        # Generate random audio in the range [-1.0, 1.0).
        const = 2.0 * np.random.random(shape) - 1.0

        # Summarize
        sample_rate = 8000
        summ = tf.audio_summary("snd",
                                const,
                                max_outputs=3,
                                sample_rate=sample_rate)
        value = sess.run(summ)
        self.assertEqual([], summ.get_shape())
        audio_summ = self._AsSummary(value)

        # Check the rest of the proto
        self._CheckProto(audio_summ, sample_rate, channels, num_frames)
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
    else:
        next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[:-(args.window + 1)]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:<3d}/{:<3d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
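mu_law_decode itself is not shown in these excerpts; a minimal NumPy sketch of the standard mu-law expansion it performs (mu = quantization_channels - 1, matching the WaveNet paper; the function name here is illustrative):

import numpy as np

def mu_law_decode_np(output, quantization_channels=256):
    """Map discrete values in [0, quantization_channels - 1] back to [-1, 1]."""
    mu = quantization_channels - 1
    signal = 2.0 * (np.asarray(output) / mu) - 1.0  # rescale to [-1, 1]
    # Inverse of the mu-law companding transform.
    return np.sign(signal) * (np.exp(np.abs(signal) * np.log1p(mu)) - 1) / mu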
Example #8
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
    else:
        next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[:-(args.window + 1)]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:<3d}/{:<3d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Example #9
def audio_summary(name, x, sampling_rate=16e3):
    try:
        summ = tf.summary.audio(name, x, sampling_rate)
    except AttributeError:
        summ = tf.audio_summary(name, x, sampling_rate)
    return summ
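The same try/except guard extends to the writer class, which moved in TF 1.0 from tf.train.SummaryWriter to tf.summary.FileWriter; a sketch in the same spirit:

def summary_writer(logdir):
    try:
        return tf.summary.FileWriter(logdir)   # TF >= 1.0
    except AttributeError:
        return tf.train.SummaryWriter(logdir)  # TF < 1.0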
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)

    next_sample = net.predict_proba_all(samples)

    # if args.fast_generation:
    #     sess.run(tf.initialize_all_variables())
    #     sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    seed = create_seed(args.wav_seed,
                       wavenet_params['sample_rate'],
                       quantization_channels)
    input_waveform = sess.run(seed).tolist()
    waveform = []
    print('waveform seed length: {}'.format(len(input_waveform)))
    print('samples {}'.format(args.samples))
    last_sample_timestamp = datetime.now()
    for slide_start in range(0, len(input_waveform), args.step_length):
        if slide_start + args.samples >= len(input_waveform):
            break
        input_audio_window = input_waveform[slide_start:slide_start + args.samples]

        outputs = [next_sample]
        # Run the WaveNet to predict the next sample.
        all_prediction = sess.run(outputs, feed_dict={samples: input_audio_window})[0]
        all_prediction = np.asarray(all_prediction)
        output_waveform = get_all_output_from_predictions(all_prediction, net.quantization_channels)

        if len(waveform) > 0:
            overlap_waveform = waveform[slide_start:len(waveform)]
            output_overlap_waveform = output_waveform[:-args.step_length]
            print(len(overlap_waveform), len(output_overlap_waveform), len(waveform))
            result = np.divide(np.add(output_overlap_waveform, overlap_waveform), 2.0)
            waveform[slide_start:len(waveform)] = result
            waveform.extend(output_waveform[-args.step_length:])
        else:
            waveform = output_waveform

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:<3d}/{:<3d}'.format(slide_start + 1,
                                                len(input_waveform)),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (slide_start + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
            print("current step is {}".format(slide_start))

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        print("The error between expected and actual is {}".format(mse_with_output(out, OUTPUT_FILE, wavenet_params['sample_rate'])))
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Example #11
    def __init__(self,
                 model,
                 log_dir,
                 histogram_freq=0,
                 image_freq=0,
                 audio_freq=0,
                 write_graph=False):
        super(TensorBoardBatch, self).__init__()
        if K._BACKEND != 'tensorflow':
            raise Exception('TensorBoardBatch callback only works '
                            'with the TensorFlow backend.')
        import tensorflow as tf
        self.tf = tf
        import keras.backend.tensorflow_backend as KTF
        self.KTF = KTF

        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        self.image_freq = image_freq
        self.audio_freq = audio_freq
        self.histograms = None
        self.write_graph = write_graph
        self.iter = 0
        self.scalars = []
        self.images = []
        self.audios = []

        self.model = model
        self.sess = KTF.get_session()

        if self.histogram_freq != 0:
            layers = self.model.layers
            for layer in layers:
                if hasattr(layer, 'name'):
                    layer_name = layer.name
                else:
                    layer_name = layer

                if hasattr(layer, 'W'):
                    name = '{}_W'.format(layer_name)
                    tf.histogram_summary(name,
                                         layer.W,
                                         collections=['histograms'])
                if hasattr(layer, 'b'):
                    name = '{}_b'.format(layer_name)
                    tf.histogram_summary(name,
                                         layer.b,
                                         collections=['histograms'])
                if hasattr(layer, 'output'):
                    name = '{}_out'.format(layer_name)
                    tf.histogram_summary(name,
                                         layer.output,
                                         collections=['histograms'])

        if self.image_freq != 0:
            tf.image_summary('input',
                             self.model.input,
                             max_images=2,
                             collections=['images'])
            tf.image_summary('output',
                             self.model.output,
                             max_images=2,
                             collections=['images'])

        if self.audio_freq != 0:
            tf.audio_summary('input',
                             self.model.input,
                             max_outputs=1,
                             collections=['audios'])
            tf.audio_summary('output',
                             self.model.output,
                             max_outputs=1,
                             collections=['audios'])

        if self.write_graph:
            if self.tf.__version__ >= '0.8.0':
                self.writer = self.tf.train.SummaryWriter(
                    self.log_dir, self.sess.graph)
            else:
                self.writer = self.tf.train.SummaryWriter(
                    self.log_dir, self.sess.graph_def)
        else:
            self.writer = self.tf.train.SummaryWriter(self.log_dir)
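The collections=['histograms'], ['images'] and ['audios'] tags above imply that the callback later merges each collection separately, so each group can be evaluated on its own schedule; a hedged sketch of that step (the method name is an assumption; tf.merge_all_summaries accepts a collection key):

    def _merge_collected(self):
        # Merge each tagged collection into a single op (None if empty).
        self.histograms = self.tf.merge_all_summaries(key='histograms')
        self.images = self.tf.merge_all_summaries(key='images')
        self.audios = self.tf.merge_all_summaries(key='audios')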
Example #12
    def build_model(self, dataset):
        # if self.y_dim:
        #     self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')
        #G
        if dataset == 'wav':
            # self.audio_samples = tf.placeholder(tf.float32, [self.batch_size] + [self.output_length, 1],
            #                         name='real_samples')
#            self.images = tf.placeholder(tf.float32, [self.batch_size] + [self.output_length, 1],
#                                    name='real_images')
            # self.gen_audio_samples= tf.placeholder(tf.float32, [self.batch_size] + [self.output_length, 1],
            #                             name='gen_audio_samples')
            self.coord = tf.train.Coordinator()
            self.reader = self.load_wav(self.coord)
            audio_batch = self.reader.dequeue(self.batch_size)
            #import IPython; IPython.embed()

        self.z = tf.placeholder(tf.float32, [None, self.z_dim],
                                name='z')

        self.z_sum = tf.histogram_summary("z", self.z)

        #G deprecated, this only applies for mnist
        # if self.y_dim:
        #     self.G = self.generator(self.z, self.y)
        #     self.D, self.D_logits = self.discriminator(self.images, self.y, reuse=False)

        #     self.sampler = self.sampler(self.z, self.y)
        #     self.D_, self.D_logits = self.discriminator(self.G, self.y, reuse=True)
        # else:
        self.G = self.generator(self.z)
        self.D, self.D_logits = self.discriminator(audio_batch, include_fourier=self.use_fourier)

        self.sampler = self.sampler(self.z)
        self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True, include_fourier=self.use_fourier)

        #import IPython; IPython.embed()
        self.d_sum = tf.histogram_summary("d", self.D)
        self.d__sum = tf.histogram_summary("d_", self.D_)
        #G need to check sample rate
        self.G_sum = tf.audio_summary("G", self.G, sample_rate=self.audio_params['sample_rate'])

        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))
        self.d_loss = self.d_loss_real + self.d_loss_fake

        #G experiments with losses
        #self.max_like_g_loss = tf.reduce_mean(-tf.exp(self.D_logits_)/2.)
        #self.g_loss = self.max_like_g_loss
        #import IPython; IPython.embed()
        #self.g_loss = self.g_loss - self.d_loss

        self.d_loss_real_sum = tf.scalar_summary("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.scalar_summary("d_loss_fake", self.d_loss_fake)

        self.g_loss_sum = tf.scalar_summary("g_loss", self.g_loss)
        self.d_loss_sum = tf.scalar_summary("d_loss", self.d_loss)

        t_vars = tf.trainable_variables()

        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        self.saver = tf.train.Saver()
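The d_vars/g_vars split at the end of build_model conventionally feeds two separate optimizers, one per network; a hedged sketch of the usual DCGAN training ops (learning rate and beta1 values are assumptions, not taken from this code):

        d_optim = tf.train.AdamOptimizer(2e-4, beta1=0.5) \
            .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(2e-4, beta1=0.5) \
            .minimize(self.g_loss, var_list=self.g_vars)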
def main():
    args = get_arguments()
    logdir = os.path.join(args.logdir, 'train', str(datetime.now()))
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNet(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        fast_generation=args.fast_generation)

    samples = tf.placeholder(tf.int32)

    next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels,
                                     size=(1, )).tolist()

    for step in range(args.samples):
        if args.fast_generation:
            window = waveform[-1]
            outputs = [next_sample]
            outputs.extend(net.push_ops)
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)
        print('Sample {:<3d}/{:<3d}: {}'.format(step + 1, args.samples,
                                                sample))

        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(
        os.path.join(logdir, 'generation', datestring))
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()

    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels, net.receptive_field)
        waveform = sess.run(seed).tolist()
    elif args.wav_replicate:
        seed = create_seed(args.wav_replicate,
                           wavenet_params['sample_rate'],
                           quantization_channels,
                           net.receptive_field,
                           cut=False)
        orig_waveform = sess.run(seed).tolist()
        waveform = []
    else:
        # Silence with a single random sample at the end.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    if args.wav_replicate:
        num_steps = min(args.samples,
                        len(orig_waveform) - net.receptive_field + 1)
    else:
        num_steps = args.samples
    for step in range(num_steps):
        if args.wav_replicate:
            # if net.receptive_field < step:
            #     window = orig_waveform[:step]
            # else:
            #     window = orig_waveform[step - net.receptive_field:step]
            # print("receptive_field = {}".format(net.receptive_field))
            # print("length = {}".format(len(orig_waveform)))
            window = orig_waveform[step:step + net.receptive_field]
            outputs = [next_sample]
        else:
            if args.fast_generation:
                outputs = [next_sample]
                outputs.extend(net.push_ops)
                window = waveform[-1]
            else:
                if len(waveform) > net.receptive_field:
                    window = waveform[-net.receptive_field:]
                else:
                    window = waveform
                # print("waveform_length = {}".format(len(waveform)))
                outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature.
        np.seterr(divide='ignore')
        scaled_prediction = np.log(prediction) / args.temperature
        scaled_prediction = (scaled_prediction -
                             np.logaddexp.reduce(scaled_prediction))
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                'is not working as intended.')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:<3d}/{:<3d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
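A standalone NumPy sketch of the temperature scaling used above, equivalent to softmax(log(p) / T) computed stably in log space (the values are illustrative):

import numpy as np

def apply_temperature(p, temperature):
    """Rescale a categorical distribution p by a sampling temperature."""
    logits = np.log(p) / temperature
    logits -= np.logaddexp.reduce(logits)  # normalize in log space
    return np.exp(logits)

p = np.array([0.1, 0.2, 0.7])
print(apply_temperature(p, 0.5))  # sharper: favors the most likely sample
print(apply_temperature(p, 1.0))  # unchanged, up to floating-point error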
Example #15
    def audio_summary(self, *args, **kwargs):
        summary = tf.audio_summary(*args, **kwargs)
        self._summaries.append(summary)
        return summary
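A hedged sketch of how such a collector might later be flushed (the method name, writer, and session arguments are assumptions; tf.merge_summary is the pre-1.0 merge op):

    def flush_summaries(self, writer, session, step):
        # Merge the collected summary ops, evaluate them once, and write.
        merged = tf.merge_summary(self._summaries)
        writer.add_summary(session.run(merged), step)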
def main():
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    labels = tf.placeholder(tf.float32)

    data_dir = DATA_DIRECTORY
    file_list = FILE_LIST
    label_dir = data_dir + 'binary_label_norm/'
    audio_dir = data_dir + 'wav/'
    label_dim = 425
    n_out = 1

    iterator = audio_reader.load_generic_audio_label(file_list, audio_dir,
                                                     label_dir, label_dim)
    audio_test, labels_test, filename = next(iterator)
    n_samples_read = audio_test.shape[0]
    labels_test = labels_test.reshape(
        (1, labels_test.shape[0], labels_test.shape[1]))

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        label_dim=label_dim,
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        skip_channels=wavenet_params['skip_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        histograms=False)

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, labels)
    else:
        next_sample = net.predict_proba(samples, labels)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels,
                                     size=(1, )).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-args.window:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs,
                     feed_dict={
                         samples: x,
                         labels: labels_test[:, i:i + 1, :]
                     })
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(n_samples_read):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
            labels_window = labels_test[:, step:step + 1, :]
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]
            labels_window = labels_test[:, step:step + min(
                len(window), args.window
            ), :]  # NOTE: this slice may run past the end of labels_test.

        # Run the WaveNet to predict the next sample.
        #if (step%100 == 0):
        #    print('step = ', step, ' , ')
        prediction = sess.run(outputs,
                              feed_dict={
                                  samples: window,
                                  labels: labels_window
                              })[0]

        # Scale prediction distribution using temperature.
        np.seterr(divide='ignore')
        scaled_prediction = np.log(prediction) / args.temperature
        scaled_prediction = scaled_prediction - np.logaddexp.reduce(
            scaled_prediction)
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')

        # Prediction distribution at temperature=1.0 should be unchanged
        # after scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                        'is not working as intended.')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:<3d}/{:<3d}'.format(step + 1, n_samples_read),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')