Example #1
class TestNet(tf.test.TestCase):

    def setUp(self):
        self.net = WaveNet(batch_size=1,
                           dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256,
                                      1, 2, 4, 8, 16, 32, 64, 128, 256],
                           filter_width=2,
                           residual_channels=16,
                           dilation_channels=16,
                           quantization_channels=256,
                           skip_channels=32)

    # Train a net on a short clip of three superimposed sine waves
    # (an E-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This serves as a
    # smoke test: we just check that training runs end to end and that the
    # network learns this waveform.

    def testEndToEndTraining(self):
        audio = MakeSineWaves()
        np.random.seed(42)

        audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
        loss = self.net.loss(audio_tensor)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.02)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()

        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss)
            for i in range(50):
                loss_val, _ = sess.run([loss, optim])
                # print "i: %d loss: %f" % (i, loss_val)

        # Sanity check the initial loss was larger.
        self.assertGreater(initial_loss, max_allowed_loss)

        # Loss after training should be small.
        self.assertLess(loss_val, max_allowed_loss)

        # Loss should be at least two orders of magnitude better
        # than before training.
        self.assertLess(loss_val / initial_loss, 0.01)
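The MakeSineWaves() helper used by this test is not shown in the example. A minimal sketch of what it could look like, assuming a hypothetical sample rate and clip length and building the E-flat chord mentioned in the comment from three superimposed sine waves:

import numpy as np


def MakeSineWaves(sample_rate=16000, seconds=0.25):
    # Hypothetical helper: sum of three sine waves forming an E-flat chord.
    times = np.arange(int(sample_rate * seconds)) / float(sample_rate)
    # Approximate frequencies of Eb4, G4 and Bb4 in Hz.
    frequencies = [311.13, 392.00, 466.16]
    audio = sum(np.sin(2.0 * np.pi * f * times) for f in frequencies)
    # Scale into [-1, 1] so the amplitude stays in the range the
    # quantization expects.
    return (audio / 3.0).astype(np.float32)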
Example #2
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        reader = AudioReader(args.data_dir,
                             coord,
                             sample_rate=wavenet_params['sample_rate'],
                             sample_size=args.sample_size)
        audio_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNet(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"])
    loss = net.loss(audio_batch)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten training runs.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    try:
        last_saved_step = saved_global_step
        step = last_saved_step  # In case the training loop never runs.
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % 50 == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so that the save
        # message is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
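Examples #2 and #3 rely on save() and load() helpers that are not shown here. A minimal sketch of how such helpers could wrap tf.train.Saver, assuming checkpoints are named model.ckpt-<step> (the naming and step-parsing convention is an assumption):

import os

import tensorflow as tf


def save(saver, sess, logdir, step):
    # Hypothetical helper: write a checkpoint named model.ckpt-<step>.
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    print('Storing checkpoint to {}'.format(checkpoint_path))
    saver.save(sess, checkpoint_path, global_step=step)


def load(saver, sess, logdir):
    # Hypothetical helper: restore the latest checkpoint in logdir and
    # return its global step, or None if there is no checkpoint.
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        return None
    global_step = int(ckpt.model_checkpoint_path.split('-')[-1])
    saver.restore(sess, ckpt.model_checkpoint_path)
    return global_step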
Example #3
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_new_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # create coordinator
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        custom_runner = CustomRunner(args, wavenet_params, coord)
        audio_batch, _ = custom_runner.get_inputs()

    # Create network.
    net = WaveNet(args.batch_size, wavenet_params["quantization_steps"],
                  wavenet_params["dilations"], wavenet_params["filter_width"],
                  wavenet_params["residual_channels"],
                  wavenet_params["dilation_channels"])
    loss = net.loss(audio_batch)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_new_training or saved_global_step is None:
            # For "new" training with using pre-trained model,
            # We should ignore saved_global_step

            # The training step is start from saved_global_step + 1
            # Therefore put -1 here if the new training starts.
            saved_global_step = -1

    except:
        print("Something is wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    custom_runner.start_threads(sess)

    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step %d - loss = %.3f, (%.3f sec/step)' %
                  (step, loss_value, duration))

            if step % 50 == 0:
                save(saver, sess, logdir, step)

    finally:
        coord.request_stop()
        coord.join(threads)
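The CustomRunner class used in Example #3 and in the last training script below is also not shown. It presumably feeds audio into a TensorFlow queue from a background Python thread; a minimal sketch of that pattern (the class layout, the queue capacity and the load_audio generator are assumptions) could be:

import threading

import tensorflow as tf


class CustomRunner(object):
    # Hypothetical queue-based input pipeline.

    def __init__(self, args, wavenet_params, coord):
        self.args = args
        self.coord = coord
        self.sample_placeholder = tf.placeholder(tf.float32, shape=None)
        self.queue = tf.PaddingFIFOQueue(capacity=32,
                                         dtypes=[tf.float32],
                                         shapes=[(None, 1)])
        self.enqueue_op = self.queue.enqueue([self.sample_placeholder])

    def get_inputs(self):
        # Return a batch of audio plus a placeholder second value, matching
        # the `audio_batch, _ = custom_runner.get_inputs()` calls above.
        audio_batch = self.queue.dequeue_many(self.args.batch_size)
        return audio_batch, None

    def thread_main(self, sess):
        # load_audio() is a hypothetical generator yielding [length, 1]
        # float32 arrays read from disk.
        for audio in load_audio(self.args.data_dir):
            if self.coord.should_stop():
                return
            sess.run(self.enqueue_op,
                     feed_dict={self.sample_placeholder: audio})

    def start_threads(self, sess, n_threads=1):
        threads = []
        for _ in range(n_threads):
            thread = threading.Thread(target=self.thread_main, args=(sess,))
            thread.daemon = True
            thread.start()
            threads.append(thread)
        return threads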
Example #4
def main():
    args = get_arguments()
    # If a checkpoint is found, training_step is restored from it;
    # otherwise it stays 0.
    save_to, ckpt, training_step = store_and_restore(args.save_to,
                                                     args.restore_from)

    # Get wavenet object
    wavenet = WaveNet(dilations=DILATIONS,
                      residual_channels=RESIDUAL_CHANNELS,
                      dilation_channels=DILATION_CHANNELS,
                      skip_channels=SKIP_CHANNELS,
                      quantization_channels=QUANTIZATION_CHANNELS,
                      use_aux_features=args.use_aux_features,
                      n_mfcc=args.n_mfcc)

    # create dataset
    data_iterator = get_audio_data(args.data_dir,
                                   mean_sub=args.mean_sub,
                                   normalise=args.normalise,
                                   mean=args.mean_file,
                                   std=args.std_file,
                                   sample_rate=SAMPLE_RATE,
                                   hop_length=args.hop_length,
                                   n_mfcc=args.n_mfcc,
                                   n_fft=args.n_fft,
                                   receptive_field=wavenet.receptive_field,
                                   sample_size=SAMPLE_SIZE)
    data = tf.data.Dataset.from_generator(data_iterator,
                                          (tf.float32, tf.float32))
    data = data.map(lambda audio, aux: encode(audio, aux, n_mfcc=args.n_mfcc))
    data = data.prefetch(30)
    iterator = data.make_initializable_iterator()
    next_element = iterator.get_next()

    # build wavenet and return loss
    loss = wavenet.loss(input_batch=next_element[0], aux_input=next_element[1])
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE,
                                       epsilon=1e-4)

    # add trainable variables to optimizer
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    writer = tf.summary.FileWriter(save_to)
    # save graph if first training run
    if not args.restore_from:
        writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()

    # Saver and variable initializer
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)
    init = tf.global_variables_initializer()
    sess = tf.Session()

    last_saved_step = training_step
    step = None
    try:
        sess.run(init)
        if ckpt:
            saver.restore(sess, ckpt)
        sess.run(iterator.initializer)

        for step in range(training_step + 1, args.num_steps):
            start_time = time.time()
            summary, loss_value, _ = sess.run([summaries, loss, optim])
            writer.add_summary(summary, step)
            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))
            if step % CHECKPOINT_EVERY == 0:
                saver.save(sess, save_to + "model.ckpt" + str(step))
                last_saved_step = step

    except KeyboardInterrupt:
        print()
    finally:
        if step is not None and step > last_saved_step:
            saver.save(sess,
                       os.path.join(save_to, "model.ckpt" + str(step)))
Example #5

def main():
    args = get_arguments()
    datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'train', datestring)

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # create coordinator
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        custom_runner = CustomRunner(args, wavenet_params, coord)
        audio_batch, _ = custom_runner.get_inputs()

    # Create network.
    net = WaveNet(args.batch_size, wavenet_params["quantization_steps"],
                  wavenet_params["dilations"], wavenet_params["filter_width"],
                  wavenet_params["residual_channels"],
                  wavenet_params["dilation_channels"])
    loss = net.loss(audio_batch)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    custom_runner.start_threads(sess)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        for step in range(args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step %d - loss = %.3f, (%.3f sec/step)' %
                  (step, loss_value, duration))

            if step % 50 == 0:
                checkpoint_path = os.path.join(logdir, 'model.ckpt')
                print('Storing checkpoint to {}'.format(checkpoint_path))
                saver.save(sess, checkpoint_path, global_step=step)

    finally:
        coord.request_stop()
        coord.join(threads)
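All of these scripts obtain their configuration from a get_arguments() helper that is not shown. A minimal argparse sketch covering a subset of the flags referenced above (Example #4 adds several more of its own; the defaults and help strings are assumptions):

import argparse


def get_arguments():
    # Hypothetical argument parser; all defaults are placeholders.
    parser = argparse.ArgumentParser(description='WaveNet training script')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Directory containing the training audio.')
    parser.add_argument('--logdir', type=str, default='./logdir',
                        help='Directory for logs and checkpoints.')
    parser.add_argument('--restore_from', type=str, default=None,
                        help='Checkpoint directory to restore from.')
    parser.add_argument('--wavenet_params', type=str,
                        default='wavenet_params.json',
                        help='JSON file with the WaveNet hyperparameters.')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--sample_size', type=int, default=100000)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_steps', type=int, default=4000)
    parser.add_argument('--store_metadata', action='store_true',
                        help='Store TensorFlow run metadata and timelines.')
    return parser.parse_args()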