class TestNet(tf.test.TestCase):
    def setUp(self):
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=256,
            skip_channels=32)
        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        audio = MakeSineWaves()
        np.random.seed(42)

        audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
        loss = self.net.loss(audio_tensor)
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=MOMENTUM)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()

        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss)
            for i in range(TRAIN_ITERATIONS):
                loss_val, _ = sess.run([loss, optim])
                # if i % 10 == 0 or i == TRAIN_ITERATIONS-1:
                #    print("i: %d loss: %f" % (i, loss_val))

        # Sanity check the initial loss was larger.
        self.assertGreater(initial_loss, max_allowed_loss)

        # Loss after training should be small.
        self.assertLess(loss_val, max_allowed_loss)

        # Loss should be at least two orders of magnitude better
        # than before training.
        self.assertLess(loss_val / initial_loss, 0.01)
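
The MakeSineWaves() helper these tests call is not shown here. A minimal sketch of what it presumably produces, assuming module-level SAMPLE_RATE_HZ and SAMPLE_DURATION constants and the three tones of an E-flat major triad (the exact constants are assumptions, not taken from the original):

import numpy as np

SAMPLE_RATE_HZ = 2000.0  # assumed sampling rate
SAMPLE_DURATION = 0.5    # assumed clip length in seconds
F1, F2, F3 = 155.56, 196.00, 233.08  # Eb3, G3, Bb3

def MakeSineWaves():
    """Creates a short clip of three superimposed sine waves."""
    sample_period = 1.0 / SAMPLE_RATE_HZ
    times = np.arange(0.0, SAMPLE_DURATION, sample_period)
    amplitudes = (np.sin(times * 2.0 * np.pi * F1) / 3.0 +
                  np.sin(times * 2.0 * np.pi * F2) / 3.0 +
                  np.sin(times * 2.0 * np.pi * F3) / 3.0)
    return amplitudes.astype(np.float32)

(The make_sine_waves() variant used in later examples returns an input/target pair instead of a single array.)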
Example #2
class TestNet(tf.test.TestCase):

    def setUp(self):
        self.net = WaveNetModel(batch_size=1,
                                dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256,
                                           1, 2, 4, 8, 16, 32, 64, 128, 256],
                                filter_width=2,
                                residual_channels=16,
                                dilation_channels=16,
                                quantization_channels=256,
                                skip_channels=32)

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        audio = MakeSineWaves()
        np.random.seed(42)

        audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
        loss = self.net.loss(audio_tensor)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.02)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()

        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss)
            for i in range(50):
                loss_val, _ = sess.run([loss, optim])
                # print("i: %d loss: %f" % (i, loss_val))

        # Sanity check the initial loss was larger.
        self.assertGreater(initial_loss, max_allowed_loss)

        # Loss after training should be small.
        self.assertLess(loss_val, max_allowed_loss)

        # Loss should be at least two orders of magnitude better
        # than before training.
        self.assertLess(loss_val / initial_loss, 0.01)
Example #3
def make_net(args, wavenet_params, audio_batch, reuse_variables):
    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        reuse_variables=reuse_variables,
        histograms=args.histograms)
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(audio_batch, args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    return loss, optimizer, trainable
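
make_net() and several tests above look optimizers up in an optimizer_factory mapping. A sketch consistent with those call sites (every entry accepts learning_rate and momentum so the factory can be called uniformly; note that one fork below also passes an epsilon argument, which this sketch omits):

import tensorflow as tf

def create_adam_optimizer(learning_rate, momentum):
    # Adam does not use momentum; the argument exists only so all
    # factory entries share one signature.
    return tf.train.AdamOptimizer(learning_rate=learning_rate)

def create_sgd_optimizer(learning_rate, momentum):
    return tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                      momentum=momentum)

def create_rmsprop_optimizer(learning_rate, momentum):
    return tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                     momentum=momentum)

optimizer_factory = {'adam': create_adam_optimizer,
                     'sgd': create_sgd_optimizer,
                     'rmsprop': create_rmsprop_optimizer}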
Example #4
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(input_batch=audio_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
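
The save() and load() helpers used above are not shown. A sketch matching how they are called (load() returns the restored global step, or None when no checkpoint exists yet):

import os
import tensorflow as tf

def save(saver, sess, logdir, step):
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step)

def load(saver, sess, logdir):
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        # Recover the global step from a path like ".../model.ckpt-1234".
        global_step = int(ckpt.model_checkpoint_path
                          .split('/')[-1]
                          .split('-')[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        return global_step
    return None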
Example #5
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            max_samples=get_max_samples(args.data_dir,
                                        wavenet_params['sample_rate']),
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=args.silence_threshold
            if args.silence_threshold > EPSILON else None)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(input_batch=audio_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    learning_rate_placeholder = tf.placeholder(tf.float32, [])
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate=learning_rate_placeholder, momentum=args.momentum)
    train_op = optimizer.minimize(loss)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    loss_value = None
    update = 0
    last_saved_step = saved_global_step
    learning_rate = args.learning_rate
    print('learning_rate {:f}'.format(learning_rate))
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()

            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run(
                    [summaries, loss, train_op],
                    feed_dict={learning_rate_placeholder: learning_rate},
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run(
                    [summaries, loss, train_op],
                    feed_dict={learning_rate_placeholder: learning_rate})
                writer.add_summary(summary, step)

            if 1.5 >= loss_value > 0.5 and update == 0:
                learning_rate = learning_rate * 0.1
                update += 1
                print('learning_rate {:f}'.format(learning_rate))
            elif loss_value <= 0.5 and update == 1:
                learning_rate = learning_rate * 0.1
                update += 1
                print('learning_rate {:f}'.format(learning_rate))

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
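
Most of these scripts size their input queues with WaveNetModel.calculate_receptive_field(). A sketch of what that static method computes (each dilated layer widens the field by (filter_width - 1) * dilation, plus the span of the initial causal layer); note that one fork below calls it with only two arguments, so signatures vary:

def calculate_receptive_field(filter_width, dilations, scalar_input,
                              initial_filter_width):
    # Dilated stack: each layer adds (filter_width - 1) * dilation samples.
    receptive_field = (filter_width - 1) * sum(dilations) + 1
    # The initial causal convolution adds its own, non-dilated span.
    if scalar_input:
        receptive_field += initial_filter_width - 1
    else:
        receptive_field += filter_width - 1
    return receptive_field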
Example #6
def main():
    args = get_arguments()
    data_dir = 'midi-Corpus/' + args.data_set + '/'
    logdir = data_dir + 'max_dilation=%d_reps=%d/' % (args.max_dilation_pow,
                                                      args.expansion_reps)
    print('*************************************************')
    print(logdir)
    print('*************************************************')
    sys.stdout.flush()
    restore_from = logdir
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    wavenet_params = loadParams(args.max_dilation_pow, args.expansion_reps,
                                args.dil_chan, args.res_chan, args.skip_chan)

    with open(logdir + 'wavenet_params.json', 'w') as outfile:
        json.dump(wavenet_params, outfile)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        gc_enabled = False
        # data queue for the training set
        train_dir = data_dir + 'train/'
        train_reader = MidiReader(
            train_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size)
        train_batch = train_reader.dequeue(args.batch_size)
        # data queue for the validation set
        #valid_dir = data_dir + 'valid/';
        #valid_reader = MidiReader(
        #    valid_dir,
        #    coord,
        #    sample_rate=wavenet_params['sample_rate'],
        #    gc_enabled=gc_enabled,
        #    receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
        #                                                           wavenet_params["dilations"],
        #                                                           wavenet_params["scalar_input"],
        #                                                           wavenet_params["initial_filter_width"]),
        #    sample_size=args.sample_size)
        #valid_batch = valid_reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = train_reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=BATCH_SIZE,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=None,
        global_condition_cardinality=train_reader.gc_category_cardinality)
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    print('constructing training loss')
    sys.stdout.flush()
    train_loss, recon_loss, latent_loss, target_output, prediction, mu_enc, layers = net.loss(
        input_batch=train_batch,
        global_condition_batch=gc_id_batch,
        l2_regularization_strength=args.l2_regularization_strength)
    print('constructing validation loss')
    sys.stdout.flush()
    #valid_loss, target_output, prediction = net.loss(input_batch=valid_batch,
    #                global_condition_batch=gc_id_batch,
    #                l2_regularization_strength=args.l2_regularization_strength)

    print('making optimizer')
    sys.stdout.flush()
    optimizer = optimizer_factory['adam'](learning_rate=args.learning_rate,
                                          momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(train_loss, var_list=trainable)

    print('setting up tensorboard')
    sys.stdout.flush()
    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    valid_input = tf.placeholder(dtype=tf.float32, shape=(1, None, 88))
    valid_loss, valid_recon_loss, valid_latent_loss, valid_target_output, valid_prediction, valid_mu, valid_enc_layers = net.loss(
        input_batch=valid_input,
        global_condition_batch=gc_id_batch,
        l2_regularization_strength=args.l2_regularization_strength)

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    print('saver')
    sys.stdout.flush()
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=5)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    print('thread stuff')
    sys.stdout.flush()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    train_reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step

    # load validation data
    validation_audio = load_all_audio(data_dir + 'valid/')
    num_valid_files = len(validation_audio)
    valid_loss_values = np.zeros((int(np.ceil(args.num_steps / 500)), ))
    vl_ind = 0

    valid_losses_step = np.zeros((num_valid_files, ))
    audio_0 = np.expand_dims(validation_audio[2], 0)
    print('audio 0', audio_0.shape)
    print(audio_0)
    mu_enc_0, enc_layers_0 = sess.run([valid_mu, valid_enc_layers],
                                      {valid_input: audio_0})
    print('layer shapes')
    for layer in enc_layers_0:
        print(layer.shape)
    print('mu 0', mu_enc_0.shape)
    print(mu_enc_0)
    valid_loss_0 = sess.run(valid_loss, {valid_input: audio_0})
    print('valid_loss_0', valid_loss_0)
    #print('validation loss 0', valid_losses_step_0);

    print('optimization time')
    sys.stdout.flush()
    min_valid_loss = 1e10
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            print('step', step)
            sys.stdout.flush()
            start_time = time.time()
            if args.store_metadata and step % 500 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                sys.stdout.flush()
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                print('mu comp')
                sys.stdout.flush()
                _mu_enc = sess.run(mu_enc,
                                   options=run_options,
                                   run_metadata=run_metadata)
                print(_mu_enc.shape)
                summary, loss_value, _ = sess.run(
                    [summaries, train_loss, optim],
                    options=run_options,
                    run_metadata=run_metadata)
                print('writing summary')
                sys.stdout.flush()
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                valid_losses_step = np.zeros((num_valid_files, ))
                for i in range(num_valid_files):
                    audio_i = np.expand_dims(validation_audio[i], 0)
                    valid_losses_step[i] = sess.run(valid_loss,
                                                    {valid_input: audio_i})
                valid_loss_value_step = np.mean(valid_losses_step)
                valid_loss_values[vl_ind] = valid_loss_value_step
                np.savez(logdir + 'validation.npz',
                         validation_loss=valid_loss_values)
                vl_ind += 1
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))

                if (valid_loss_value_step < min_valid_loss):
                    print('new min!')
                    if (not np.isnan(valid_loss_value_step)):
                        print('saving')
                        min_valid_loss = valid_loss_value_step
                        save(saver, sess, logdir, step)
                        last_saved_step = step
                    else:
                        print('ignoring model bc of nan')
            else:
                _rec_ls, _lat_ls, _tot_ls, _pred = sess.run(
                    [recon_loss, latent_loss, train_loss, prediction])
                print('recon', _rec_ls, 'latent', _lat_ls, 'total', _tot_ls,
                      'max pred', np.max(_pred), 'min pred', np.min(_pred))
                summary, loss_value, _ = sess.run(
                    [summaries, train_loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))
            sys.stdout.flush()

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        #    if step > last_saved_step:
        #        save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
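
loadParams() is specific to this MIDI fork and is not shown. A hypothetical sketch of the dict it presumably builds; the key names come from how wavenet_params is consumed above, while the dilation schedule and defaults are assumptions:

def loadParams(max_dilation_pow, expansion_reps, dil_chan, res_chan,
               skip_chan):
    # Dilations double up to 2**max_dilation_pow and the stack is
    # repeated expansion_reps times, e.g. [1, 2, 4, 8, 1, 2, 4, 8].
    dilations = [2 ** p for p in range(max_dilation_pow + 1)] * expansion_reps
    return {
        'sample_rate': 16000,        # assumed default
        'filter_width': 2,
        'initial_filter_width': 32,  # assumed
        'scalar_input': False,
        'dilations': dilations,
        'dilation_channels': dil_chan,
        'residual_channels': res_chan,
        'skip_channels': skip_chan,
        'use_biases': True,
    }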
Example #7
def run(target,
        is_chief,
        train_steps,
        job_dir,
        train_files,
        reader_config,
        batch_size,
        learning_rate,
        residual_channels,
        dilation_channels,
        skip_channels,
        dilations,
        use_biases,
        gc_channels,
        lc_channels,
        filter_width,
        sample_size,
        initial_filter_width,
        l2_regularization_strength,
        momentum,
        optimizer):

    # Run the training and evaluation graph.

    # In between-graph replication, the chief is the one node in the
    # cluster with extra responsibilities and is worker task zero by
    # default. Here the `master` server is assigned as the chief.
    #
    # See https://youtu.be/la_M6bCV91M?t=1203 for details on
    # distributed TensorFlow and motivation about chief.
    # TODO: hooks
    hooks = []

    # Create a new graph and specify that as default
    with tf.Graph().as_default():
        # Placement of ops on devices using replica device setter
        # which automatically places the parameters on the `ps` server
        # and the `ops` on the workers
        #
        # See:
        # https://www.tensorflow.org/api_docs/python/tf/train/replica_device_setter
        with tf.device(tf.train.replica_device_setter()):

            with open(reader_config) as json_file:
                reader_config = json.load(json_file)

            # Reader
            receptive_field_size = WaveNetModel.calculate_receptive_field(filter_width,
                                                                          dilations,
                                                                          False,
                                                                          initial_filter_width)

            reader = CsvReader(
                train_files,
                batch_size=batch_size,
                receptive_field=receptive_field_size,
                sample_size=sample_size,
                config=reader_config
            )

            # Create network.
            net = WaveNetModel(
                batch_size=batch_size,
                dilations=dilations,
                filter_width=filter_width,
                residual_channels=residual_channels,
                dilation_channels=dilation_channels,
                skip_channels=skip_channels,
                quantization_channels=reader.data_dim,
                use_biases=use_biases,
                scalar_input=False,
                initial_filter_width=initial_filter_width,
                histograms=False,
                global_channels=gc_channels,
                local_channels=lc_channels)

            global_step_tensor = tf.contrib.framework.get_or_create_global_step()

            if l2_regularization_strength == 0:
                l2_regularization_strength = None

            loss = net.loss(input_batch=reader.data_batch,
                            global_condition=reader.gc_batch,
                            local_condition=reader.lc_batch,
                            l2_regularization_strength=l2_regularization_strength)

            optimizer = optimizer_factory[optimizer](learning_rate=learning_rate, momentum=momentum)

            trainable = tf.trainable_variables()

            train_op = optimizer.minimize(loss, var_list=trainable, global_step=global_step_tensor)

            # Add Generation operator to graph for later use in generate.py
            tf.add_to_collection("config", tf.constant(reader.data_dim, name='data_dim'))
            tf.add_to_collection("config", tf.constant(receptive_field_size, name='receptive_field_size'))
            tf.add_to_collection("config", tf.constant(sample_size, name='sample_size'))

            samples = tf.placeholder(tf.float32, shape=(receptive_field_size, reader.data_dim), name="samples")
            gc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="gc")
            lc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="lc")  # TODO set to one

            gc = tf.one_hot(gc, gc_channels)
            lc = tf.one_hot(lc, lc_channels)  # TODO set to one...

            tf.add_to_collection("predict_proba", net.predict_proba(samples, gc, lc))

            # TODO: Implement fast generation
            """
            if filter_width <= 2:
                samples_fast = tf.placeholder(tf.float32, shape=(1, reader.data_dim), name="samples_fast")
                gc_fast = tf.placeholder(tf.int32, shape=(1), name="gc_fast")
                lc_fast = tf.placeholder(tf.int32, shape=(1), name="lc_fast")

                gc_fast = tf.one_hot(gc_fast, gc_channels)
                lc_fast = tf.one_hot(lc_fast, lc_channels)

                tf.add_to_collection("predict_proba_incremental", net.predict_proba_incremental(samples_fast, gc_fast, lc_fast))
                tf.add_to_collection("push_ops", net.push_ops)
            """

        # Creates a MonitoredSession for training
        # MonitoredSession is a Session-like object that handles
        # initialization, recovery and hooks
        # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
        with tf.train.MonitoredTrainingSession(master=target,
                                               is_chief=is_chief,
                                               checkpoint_dir=job_dir,
                                               hooks=hooks,
                                               save_checkpoint_secs=120,
                                               save_summaries_steps=20) as session:  # TODO: SUMMARIES HERE

            # Global step to keep track of global number of steps particularly in
            # distributed setting
            step = global_step_tensor.eval(session=session)
            # Run the training graph which returns the step number as tracked by
            # the global step tensor.
            # When train epochs is reached, session.should_stop() will be true.
            try:
                while (train_steps is None or
                       step < train_steps) and not session.should_stop():

                    step, _, loss_val = session.run([global_step_tensor, train_op, loss])
                    print("step %d loss %.4f" % (step, loss_val), end='\r')
                    sys.stdout.flush()

                    # For debugging
                    # dat, gc, lc = session.run([reader.data_batch, reader.gc_batch, reader.lc_batch])
                    # print(colored(str(dat.shape), 'red', 'on_grey'))
                    # for field in dat:
                    #     print(colored(str(field), 'red'))
                    # print(colored(str(lc.shape), 'red', 'on_grey'))
                    # for i in lc[0, -1, :]:
                    #     print("%1d" % i, end='')
                    #     sys.stdout.flush()
                    # print(colored(str(gc.shape), 'red', 'on_grey'))
                    # for i in gc[0, -1, :]:
                    #     print("%1d" % i, end='')
                    #     sys.stdout.flush()
                    # for field in lc:
                    #     print(colored(str(field), 'blue'))
                    # print(colored(str(gc.shape), 'red', 'on_grey'))
                    # for field in gc:
                    #     print(colored(str(field), 'green'))

            except KeyboardInterrupt:
                pass
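
The config constants and predict_proba op stored with tf.add_to_collection above would be read back by the companion generate.py. A minimal sketch of that consumer, assuming the checkpoint and meta-graph were written to job_dir by the MonitoredTrainingSession:

import tensorflow as tf

def restore_generation_ops(job_dir):
    ckpt = tf.train.get_checkpoint_state(job_dir)
    saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
    sess = tf.Session()
    saver.restore(sess, ckpt.model_checkpoint_path)
    # Constants were added to "config" in this order during training.
    data_dim, receptive_field_size, sample_size = sess.run(
        tf.get_collection('config'))
    predict_proba = tf.get_collection('predict_proba')[0]
    return sess, predict_proba, data_dim, receptive_field_size, sample_size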
Example #8
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"])
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(audio_batch, args.l2_regularization_strength)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    try:
        last_saved_step = saved_global_step
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run(
                    [summaries, loss, optim],
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                  .format(step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
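
validate_directories() is shared by several of these training scripts. A sketch inferred from how its result is unpacked above (the default constants and flag names here are assumptions):

import os

LOGDIR_ROOT = './logdir'    # assumed default root
STARTED_DATESTRING = 'run'  # the real script stamps a datetime here

def validate_directories(args):
    """Resolves logdir arguments; raises ValueError on bad combinations."""
    if args.logdir and args.logdir_root:
        raise ValueError('--logdir and --logdir_root cannot be '
                         'specified at the same time.')
    logdir_root = args.logdir_root or LOGDIR_ROOT
    logdir = args.logdir or os.path.join(logdir_root, 'train',
                                         STARTED_DATESTRING)
    restore_from = args.restore_from or logdir
    return {'logdir': logdir,
            'logdir_root': logdir_root,
            'restore_from': restore_from}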
Example #9
class TestGeneration(tf.test.TestCase):

    def testGenerateSimple(self):
        # Reader config
        with open(TEST_DATA + "/config.json") as json_file:
            self.reader_config = json.load(json_file)

        # Initialize the reader
        receptive_field_size = WaveNetModel.calculate_receptive_field(2, LAYERS, False, 8)

        self.reader = CsvReader(
            [TEST_DATA + "/test.dat", TEST_DATA + "/test.emo", TEST_DATA + "/test.pho"],
            batch_size=1,
            receptive_field=receptive_field_size,
            sample_size=SAMPLE_SIZE,
            config=self.reader_config
        )

        # WaveNet model
        self.net = WaveNetModel(batch_size=1,
                                dilations=LAYERS,
                                filter_width=2,
                                residual_channels=8,
                                dilation_channels=8,
                                skip_channels=8,
                                quantization_channels=2,
                                use_biases=True,
                                scalar_input=False,
                                initial_filter_width=8,
                                histograms=False,
                                global_channels=GC_CHANNELS,
                                local_channels=LC_CHANNELS)

        loss = self.net.loss(input_batch=self.reader.data_batch,
                             global_condition=self.reader.gc_batch,
                             local_condition=self.reader.lc_batch,
                             l2_regularization_strength=L2)

        optimizer = optimizer_factory['adam'](learning_rate=0.003, momentum=0.9)
        trainable = tf.trainable_variables()
        train_op = optimizer.minimize(loss, var_list=trainable)

        samples = tf.placeholder(tf.float32, shape=(receptive_field_size, self.reader.data_dim), name="samples")
        gc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="gc")
        lc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="lc")

        gc = tf.one_hot(gc, GC_CHANNELS)
        lc = tf.one_hot(lc, LC_CHANNELS)

        predict = self.net.predict_proba(samples, gc, lc)

        with self.test_session() as session:
            session.run([
                tf.local_variables_initializer(),
                tf.global_variables_initializer(),
                tf.tables_initializer(),
            ])
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            for ITER in range(1):

                for i in range(1000):
                    _, loss_val = session.run([train_op, loss])
                    print("step %d loss %.4f" % (i, loss_val), end='\r')
                    sys.stdout.flush()
                print()

                data_samples = np.random.random((receptive_field_size, self.reader.data_dim))
                gc_samples = np.zeros((receptive_field_size))
                lc_samples = np.zeros((receptive_field_size))

                output = []

                for EMO in range(3):
                    for PHO in range(3):
                        for _ in range(100):
                            prediction = session.run(predict, feed_dict={'samples:0': data_samples, 'gc:0': gc_samples, 'lc:0': lc_samples})
                            data_samples = data_samples[1:, :]
                            data_samples = np.append(data_samples, prediction, axis=0)

                            gc_samples = gc_samples[1:]
                            gc_samples = np.append(gc_samples, [EMO], axis=0)
                            lc_samples = lc_samples[1:]
                            lc_samples = np.append(lc_samples, [PHO], axis=0)

                            output.append(prediction[0])

                output = np.array(output)
                print("ITER %d" % ITER)
                plt.imsave("./test/SINE_test_%d.png" % ITER, np.kron(output[:, :], np.ones([1, 500])), vmin=0.0, vmax=1.0)
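
Note the string keys in the feed_dict above: the Python names gc and lc were rebound to the one-hot tensors, so the underlying placeholders are addressed by their graph names. A minimal illustration of the same pattern:

import tensorflow as tf

gc = tf.placeholder(tf.int32, shape=(4,), name="gc")
gc = tf.one_hot(gc, 3)  # "gc" now refers to the one-hot op

with tf.Session() as sess:
    # The placeholder is still reachable via its graph name "gc:0".
    print(sess.run(gc, feed_dict={'gc:0': [0, 1, 2, 0]}))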
Example #10
class TestNet(tf.test.TestCase):
    def setUp(self):
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=256,
            use_biases=True,
            skip_channels=32)
        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = True
        self.momentum = MOMENTUM

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        audio, output_audio = make_sine_waves()
        np.random.seed(42)
        librosa.output.write_wav('sine_train.wav', audio, int(SAMPLE_RATE_HZ))
        librosa.output.write_wav('sine_expected_answered.wav', output_audio,
                                 int(SAMPLE_RATE_HZ))
        # if self.generate:
        #
        #    power_spectrum = np.abs(np.fft.fft(audio))**2
        #    freqs = np.fft.fftfreq(audio.size, SAMPLE_PERIOD_SECS)
        #    indices = np.argsort(freqs)
        #    indices = [index for index in indices if freqs[index] >= 0 and
        #                                             freqs[index] <= 500.0]
        #    plt.plot(freqs[indices], power_spectrum[indices])
        #    plt.show()
        run_metadata = tf.RunMetadata()

        audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
        output_audio_tensor = tf.convert_to_tensor(output_audio,
                                                   dtype=tf.float32)
        loss = self.net.loss(audio_tensor, output_audio_tensor)
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        generated_waveform = None
        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss)
            for i in range(TRAIN_ITERATIONS):
                loss_val, _ = sess.run([loss, optim],
                                       options=run_options,
                                       run_metadata=run_metadata)
                if i % 10 == 0:
                    print("i: %d loss: %f" % (i, loss_val))
                    tl = timeline.Timeline(run_metadata.step_stats)
                    timeline_path = os.path.join('.', 'timeline.trace')
                    # with open(timeline_path, 'w') as f:
                    #     f.write(tl.generate_chrome_trace_format(show_memory=True))

            # Sanity check the initial loss was larger.
            # self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            # self.assertLess(loss_val, max_allowed_loss)

            # Loss should be at least two orders of magnitude better
            # than before training.
            # self.assertLess(loss_val / initial_loss, 0.01)

            # saver = tf.train.Saver(var_list=tf.trainable_variables())
            # saver.save(sess, '/tmp/sine_test_model.ckpt', global_step=i)
            if self.generate:
                # Check non-incremental generation
                generated_waveform = generate_waveform(sess,
                                                       self.net,
                                                       False,
                                                       wav_seed=True)
                check_waveform(self.assertGreater, generated_waveform)
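
check_waveform() is assumed to verify that the generated audio concentrates its spectral power near the trained chord tones. A sketch, assuming module-level SAMPLE_RATE_HZ and F1/F2/F3 constants and an arbitrary 0.5 power threshold:

import numpy as np

def check_waveform(assertion, generated_waveform):
    # Power spectrum restricted to the audible band of interest.
    power = np.abs(np.fft.fft(generated_waveform)) ** 2
    freqs = np.fft.fftfreq(generated_waveform.size, 1.0 / SAMPLE_RATE_HZ)
    keep = (freqs >= 0) & (freqs <= 500.0)
    power, freqs = power[keep], freqs[keep]
    # Fraction of power within a few Hz of the trained chord tones.
    near_chord = np.zeros_like(freqs, dtype=bool)
    for f in (F1, F2, F3):
        near_chord |= np.abs(freqs - f) < 5.0
    chord_ratio = power[near_chord].sum() / power.sum()
    assertion(chord_ratio, 0.5)  # assertGreater: most power on the chord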
Example #11
class TestMoveNet(tf.test.TestCase):
    def generate_waveform(self, sess):
        samples = tf.placeholder(tf.int32)
        next_sample_probs = self.net.predict_proba_all(samples)
        operations = [next_sample_probs]

        waveform = []
        seed = create_seed("sine_train.wav",
                           SAMPLE_RATE_HZ,
                           QUANTIZATION_CHANNELS,
                           window_size=WINDOW_SIZE,
                           silence_threshold=0)
        input_waveform = sess.run(seed).tolist()
        decode = mu_law_decode(samples, QUANTIZATION_CHANNELS)
        slide_windows = 256
        for slide_start in range(0, len(input_waveform), slide_windows):
            if slide_start + slide_windows >= len(input_waveform):
                break
            input_audio_window = input_waveform[slide_start:slide_start +
                                                slide_windows]

            # Run the WaveNet to predict the next sample.
            all_prediction = sess.run(operations,
                                      feed_dict={samples:
                                                 input_audio_window})[0]
            all_prediction = np.asarray(all_prediction)
            output_waveform = get_all_output_from_predictions(all_prediction)
            print("Prediction {}".format(output_waveform))
            waveform.extend(output_waveform)

        waveform = np.array(waveform[:])
        decoded_waveform = sess.run(decode, feed_dict={samples: waveform})
        return decoded_waveform

    def setUp(self):
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=256,
            use_biases=True,
            skip_channels=32)
        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = True
        self.momentum = MOMENTUM

    def testEndToEndTraining(self):
        audio, output_audio = make_sine_waves()
        np.random.seed(42)
        librosa.output.write_wav('sine_train.wav', audio, int(SAMPLE_RATE_HZ))
        librosa.output.write_wav('sine_expected_answered.wav', output_audio,
                                 int(SAMPLE_RATE_HZ))

        input_samples = tf.placeholder(tf.float32)
        output_samples = tf.placeholder(tf.float32)

        loss = self.net.loss(input_samples, output_samples)
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()

        generated_waveform = None
        max_allowed_loss = 0.1
        slide_windows = 256
        slide_start = 0
        with self.test_session() as sess:
            sess.run(init)
            for i in range(TRAIN_ITERATIONS):
                if slide_start + slide_windows >= min(len(audio),
                                                      len(output_audio)):
                    slide_start = 0
                    print("slide from beginning...")
                input_audio_window = audio[slide_start:slide_start +
                                           slide_windows]
                output_audio_window = output_audio[slide_start:slide_start +
                                                   slide_windows]
                slide_start += 1
                loss_val, _ = sess.run(
                    [loss, optim],
                    feed_dict={
                        input_samples: input_audio_window,
                        output_samples: output_audio_window
                    })
                if i % 10 == 0:
                    print("i: %d loss: %f" % (i, loss_val))
            # saver.save(sess, '/tmp/sine_test_model.ckpt', global_step=i)
            if self.generate:
                # Check non-incremental generation
                generated_waveform = self.generate_waveform(sess)
                check_waveform(self.assertGreater, generated_waveform)
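
generate_waveform() above decodes predictions with mu_law_decode(). The standard inverse mu-law companding transform, in the TF 1.x style of these examples, looks roughly like:

import tensorflow as tf

def mu_law_decode(output, quantization_channels):
    """Recovers waveform amplitudes from mu-law quantized indices."""
    with tf.name_scope('decode'):
        mu = quantization_channels - 1
        # Map discrete channel indices back onto [-1, 1].
        signal = 2 * (tf.to_float(output) / mu) - 1
        # Invert the non-linear companding of the magnitude.
        magnitude = (1 / mu) * ((1 + mu) ** tf.abs(signal) - 1)
        return tf.sign(signal) * magnitude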
Example #12
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return
    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.

    coord = tf.train.Coordinator()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        gc_enabled = args.gc_channels is not None
        reader = SkeletonReader(
            args.data_dir,
            coord,
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"]),
            input_channels=wavenet_params["input_channels"],
            sample_size=args.sample_size)
        #print ('batch_size:{0}'.format(args.batch_size))
        #skeleton_batch: batch_sizes x (receptive field+sample_size) x skeleton_channels
        skeleton_batch = reader.dequeue(args.batch_size)
        #print('skeleton_batch shape:')
        print_node = tf.Print(skeleton_batch, [tf.size(skeleton_batch)])
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None
        #sess.run(print_node)
    # Create network.
    net = WaveNetModel(batch_size=args.batch_size,
                       dilations=wavenet_params["dilations"],
                       filter_width=wavenet_params["filter_width"],
                       residual_channels=wavenet_params["residual_channels"],
                       input_channels=wavenet_params["input_channels"],
                       output_channels=wavenet_params["output_channels"],
                       dilation_channels=wavenet_params["dilation_channels"],
                       skip_channels=wavenet_params["skip_channels"],
                       use_biases=wavenet_params["use_biases"],
                       histograms=args.histograms,
                       global_condition_channels=args.gc_channels)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(input_batch=skeleton_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate,
        momentum=args.momentum,
        epsilon=args.epsilon)
    trainable = tf.trainable_variables()
    total_parameters = 0
    for variable in trainable:
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print('total number of parameters: {0}'.format(total_parameters))
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session (the tf.Session was created earlier, before the reader).
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    try:
        last_saved_step = saved_global_step
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
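
The inline parameter count in the example above can be factored into a reusable helper; this is a sketch against the same TF1 graph API used throughout these examples.

import tensorflow as tf

def count_trainable_parameters():
    """Sum the number of scalar weights over all trainable variables (TF1)."""
    total = 0
    for variable in tf.trainable_variables():
        n = 1
        for dim in variable.get_shape():
            n *= int(dim)
        total += n
    return total
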
Example #13
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Read TFRecords and create network.
    tf.reset_default_graph()

    data_train = get_tfrecord(name='train',
                              sample_size=args.sample_size,
                              batch_size=args.batch_size,
                              seed=None,
                              repeat=None,
                              data_path=args.data_path)
    data_test = get_tfrecord(name='test',
                             sample_size=args.sample_size,
                             batch_size=args.batch_size,
                             seed=None,
                             repeat=None,
                             data_path=args.data_path)

    train_itr = data_train.make_one_shot_iterator()
    test_itr = data_test.make_one_shot_iterator()

    train_batch, train_label = train_itr.get_next()
    test_batch, test_label = test_itr.get_next()

    train_batch = tf.reshape(train_batch, [-1, train_batch.shape[1], 1])
    test_batch = tf.reshape(test_batch, [-1, test_batch.shape[1], 1])

    # Create network.
    net = WaveNetModel(sample_size=args.sample_size,
                       batch_size=args.batch_size,
                       dilations=wavenet_params["dilations"],
                       filter_width=wavenet_params["filter_width"],
                       residual_channels=wavenet_params["residual_channels"],
                       dilation_channels=wavenet_params["dilation_channels"],
                       skip_channels=wavenet_params["skip_channels"],
                       histograms=args.histograms)

    train_loss = net.loss(train_batch, train_label)
    test_loss = net.loss(test_batch, test_label)

    # Optimizer
    # Temporarily set to momentum optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate=args.learning_rate,
                                           momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(train_loss, var_list=trainable)

    # Accuracy of test data
    pred_test = net.predict_proba(test_batch, audio_only=True)
    equals = tf.equal(tf.squeeze(test_label), tf.round(pred_test))
    acc = tf.reduce_mean(tf.cast(equals, tf.float32))

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    init2 = tf.local_variables_initializer()
    sess.run([init, init2])

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            # Trace only the first step; otherwise run without profiling
            # (the original left run_options set, so every later step ran
            # with a full trace).
            if step == saved_global_step + 1:
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
            else:
                run_options = None
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary_, train_loss_, test_loss_, acc_, _ = sess.run(
                    [summaries, train_loss, test_loss, acc, optim],
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary_, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary_, train_loss_, test_loss_, acc_, _ = sess.run(
                    [summaries, train_loss, test_loss, acc, optim],
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary_, step)

            duration = time.time() - start_time
            print("step {:d}:  trainloss = {:.3f}, "
                  "testloss = {:.3f}, acc = {:.3f}, ({:.3f} sec/step)".format(
                      step, train_loss_, test_loss_, acc_, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        elif step is None:
            print("No training performed during session.")
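
The test accuracy op above (tf.round on predicted probabilities, compared element-wise with the labels) is the standard construction for binary targets; a self-contained sketch of the same idea:

import tensorflow as tf

def binary_accuracy(labels, probabilities):
    """Fraction of thresholded predictions that match binary labels."""
    predictions = tf.round(probabilities)  # maps probabilities to 0.0 / 1.0
    correct = tf.equal(tf.squeeze(labels), predictions)
    return tf.reduce_mean(tf.cast(correct, tf.float32))
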
Example #14
def main():
    args = get_arguments()

    with open(args.model_params, 'r') as f:
        model_params = json.load(f)

    with open(args.training_params, 'r') as f:
        train_params = json.load(f)

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print('Some arguments are wrong:')
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    receptive_field = WaveNetModel.calculate_receptive_field(
        model_params['filter_width'],
        model_params['dilations'],
        model_params['initial_filter_width'])
    # Save arguments and model params into file
    save_run_config(args, receptive_field, STARTED_DATESTRING, logdir)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Create data loader.
    with tf.name_scope('create_inputs'):
        reader = WavMidReader(data_dir=args.data_dir_train,
                              coord=coord,
                              audio_sample_rate=model_params['audio_sr'],
                              receptive_field=receptive_field,
                              velocity=args.velocity,
                              sample_size=args.sample_size,
                              queues_size=(10, 10*args.batch_size))
        data_batch = reader.dequeue(args.batch_size)

    # Create model.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=model_params['dilations'],
        filter_width=model_params['filter_width'],
        residual_channels=model_params['residual_channels'],
        dilation_channels=model_params['dilation_channels'],
        skip_channels=model_params['skip_channels'],
        output_channels=model_params['output_channels'],
        use_biases=model_params['use_biases'],
        initial_filter_width=model_params['initial_filter_width'])

    input_data = tf.placeholder(dtype=tf.float32,
                                shape=(args.batch_size, None, 1))
    input_labels = tf.placeholder(dtype=tf.float32,
                                  shape=(args.batch_size, None,
                                         model_params['output_channels']))

    loss, probs = net.loss(input_data=input_data,
                           input_labels=input_labels,
                           pos_weight=train_params['pos_weight'],
                           l2_reg_str=train_params['l2_reg_str'])
    optimizer = optimizer_factory[args.optimizer](
                    learning_rate=train_params['learning_rate'],
                    momentum=train_params['momentum'])
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()
    histograms = tf.summary.merge_all(key=HKEY)

    # Separate summary ops for validation, since they are
    # calculated only once per evaluation cycle.
    with tf.name_scope('validation_summaries'):

        metric_summaries = metrics_empty_dict()
        metric_value = tf.placeholder(tf.float32)
        for name in metric_summaries.keys():
            metric_summaries[name] = tf.summary.scalar(name, metric_value)

        images_buffer = tf.placeholder(tf.string)
        images_batch = tf.stack(
            [tf.image.decode_png(images_buffer[0], channels=4),
             tf.image.decode_png(images_buffer[1], channels=4),
             tf.image.decode_png(images_buffer[2], channels=4)])
        images_summary = tf.summary.image('estim', images_batch)

        audio_data = tf.placeholder(tf.float32)
        audio_summary = tf.summary.audio('input', audio_data,
                                         model_params['audio_sr'])

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    # Trainer for keeping best validation-performing model
    # and optional early stopping.
    trainer = Trainer(sess, logdir, train_params['early_stop_limit'], 0.999)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print('Something went wrong while restoring checkpoint. '
              'Training will be terminated to avoid accidentally '
              'overwriting the previous model.')
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, train_params['num_steps']):
            waveform, pianoroll = sess.run([data_batch[0], data_batch[1]])
            feed_dict = {input_data: waveform, input_labels: pianoroll}
            # Reload switches from file on each step
            with open(RUNTIME_SWITCHES, 'r') as f:
                switch = json.load(f)

            start_time = time.time()
            if switch['store_meta'] and step % switch['store_every'] == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run(
                    [summaries, loss, optim],
                    feed_dict=feed_dict,
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  feed_dict=feed_dict)
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                  .format(step, loss_value, duration))

            if step % switch['checkpoint_every'] == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

            # Evaluate model performance on validation data
            if step % switch['evaluate_every'] == 0:
                if switch['histograms']:
                    hist_summary = sess.run(histograms)
                    writer.add_summary(hist_summary, step)
                print('evaluating...')
                stats = 0, 0, 0, 0, 0, 0
                est = np.empty([0, model_params['output_channels']])
                ref = np.empty([0, model_params['output_channels']])

                b_data, b_labels, b_cntr = (
                    np.empty((0, args.sample_size + receptive_field - 1, 1)),
                    np.empty((0, model_params['output_channels'])),
                    args.batch_size)

                # If the validation set yields less than one full batch
                # (batch_size * sample_size > available data), run
                # single_pass() again until at least one prediction exists.
                while est.size == 0:

                    for data, labels in reader.single_pass(
                        sess, args.data_dir_valid):

                        # cumulate batch
                        if b_cntr > 1:
                            b_data, b_labels, decr = cumulateBatch(
                                data, labels, b_data, b_labels)
                            b_cntr -= decr
                            continue
                        elif args.batch_size > 1:
                            b_data, b_labels, decr = cumulateBatch(
                                data, labels, b_data, b_labels)
                            if not decr:
                                continue
                            data = b_data
                            labels = b_labels
                            # reset batch cumulation variables
                            b_data, b_labels, b_cntr = (
                                np.empty((
                                    0, args.sample_size + receptive_field - 1, 1
                                )),
                                np.empty((0, model_params['output_channels'])),
                                args.batch_size)

                        predictions = sess.run(
                            probs, feed_dict={input_data : data})
                        # Aggregate sums for metrics calculation
                        stats_chunk = calc_stats(
                            predictions, labels, args.threshold)
                        stats = tuple([sum(x) for x in zip(stats, stats_chunk)])
                        est = np.append(est, predictions, axis=0)
                        ref = np.append(ref, labels, axis=0)

                metrics = calc_metrics(None, None, None, stats=stats)
                write_metrics(metrics, metric_summaries, metric_value,
                              writer, step, sess)
                trainer.check(metrics['f1_measure'])

                # Render evaluation results
                if switch['log_image'] or switch['log_sound']:
                    sub_fac = int(model_params['audio_sr']/switch['midi_sr'])
                    est = roll_subsample(est.T, sub_fac)
                    ref = roll_subsample(ref.T, sub_fac)
                if switch['log_image']:
                    write_images(est, ref, switch['midi_sr'], args.threshold,
                                 (8, 6), images_summary, images_buffer,
                                 writer, step, sess)
                if switch['log_sound']:
                    write_audio(est, ref, switch['midi_sr'],
                                model_params['audio_sr'], 0.007,
                                audio_summary, audio_data,
                                writer, step, sess)

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
        flush_n_close(writer, sess)
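
One pattern worth noting in this example is that RUNTIME_SWITCHES is re-read from disk on every step, so logging, evaluation, and checkpoint cadence can be changed mid-run without restarting training. A slightly hardened sketch of that loader (file name and defaults are illustrative assumptions):

import json

def load_switches(path='runtime_switches.json', defaults=None):
    """Re-read tunable training flags once per step.

    Falls back to the supplied defaults if the file is mid-edit and
    temporarily unparseable, so a bad save cannot kill the run.
    """
    try:
        with open(path, 'r') as f:
            return json.load(f)
    except (IOError, ValueError):
        return defaults or {}
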
Example #15
def main():
    args = get_arguments()

    if (args.logdir is not None and os.path.isdir(args.logdir)):
        logdir = args.logdir
    else:
        print('Argument --logdir=\'{}\' is not (but should be) '
              'a path to valid directory.'.format(args.logdir))
        return

    with open(args.model_params, 'r') as f:
        model_params = json.load(f)
    with open(RUNTIME_SWITCHES, 'r') as f:
        switch = json.load(f)

    receptive_field = WaveNetModel.calculate_receptive_field(
        model_params['filter_width'],
        model_params['dilations'],
        model_params['initial_filter_width'])

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Create data loader.
    with tf.name_scope('create_inputs'):
        reader = WavMidReader(data_dir=args.data_dir_test,
                              coord=coord,
                              audio_sample_rate=model_params['audio_sr'],
                              receptive_field=receptive_field,
                              velocity=args.velocity,
                              sample_size=args.sample_size,
                              queues_size=(100, 100*BATCH_SIZE))

    # Create model.
    net = WaveNetModel(
        batch_size=BATCH_SIZE,
        dilations=model_params['dilations'],
        filter_width=model_params['filter_width'],
        residual_channels=model_params['residual_channels'],
        dilation_channels=model_params['dilation_channels'],
        skip_channels=model_params['skip_channels'],
        output_channels=model_params['output_channels'],
        use_biases=model_params['use_biases'],
        initial_filter_width=model_params['initial_filter_width'])

    input_data = tf.placeholder(dtype=tf.float32,
                                shape=(BATCH_SIZE, None, 1))
    input_labels = tf.placeholder(dtype=tf.float32,
                                  shape=(BATCH_SIZE, None,
                                         model_params['output_channels']))

    _, probs = net.loss(input_data=input_data,
                        input_labels=input_labels,
                        pos_weight=1.0,
                        l2_reg_str=None)

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    try:
        load(saver, sess, logdir)

    except:
        print('Something went wrong while restoring checkpoint.')
        raise

    try:
        stats = 0, 0, 0, 0, 0, 0
        est = np.empty([model_params['output_channels'], 0])
        ref = np.empty([model_params['output_channels'], 0])
        sub_fac = int(model_params['audio_sr']/switch['midi_sr'])
        for data, labels in reader.single_pass(sess,
                                               args.data_dir_test):

            predictions = sess.run(probs, feed_dict={input_data : data})
            # Aggregate sums for metrics calculation
            stats_chunk = calc_stats(predictions, labels, args.threshold)
            stats = tuple([sum(x) for x in zip(stats, stats_chunk)])
            est = np.append(est, roll_subsample(predictions.T, sub_fac), axis=1)
            ref = np.append(ref, roll_subsample(labels.T, sub_fac, b=True),
                            axis=1)

        metrics = calc_metrics(None, None, None, stats=stats)
        write_metrics(metrics, None, None, None, None, None, logdir=logdir)

        # Save subsampled data for further arbitrary evaluation
        np.save(logdir+'/est.npy', est)
        np.save(logdir+'/ref.npy', ref)

        # Render evaluation results
        figsize = (int(args.plot_scale*est.shape[1]/switch['midi_sr']),
                   int(args.plot_scale*model_params['output_channels']/12))
        if args.media:
            write_images(est, ref, switch['midi_sr'],
                         args.threshold, figsize,
                         None, None, None, 0, None,
                         noterange=(21, 109),
                         legend=args.plot_legend,
                         logdir=logdir)
            write_audio(est, ref, switch['midi_sr'],
                        model_params['audio_sr'], 0.007,
                        None, None, None, 0, None, logdir=logdir)

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        coord.request_stop()
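
calc_stats and calc_metrics are not defined in this excerpt. Assuming the six-element stats tuple accumulates confusion-matrix counts across chunks, the final metrics reduce to simple ratios of those sums; a sketch under that assumption:

def metrics_from_counts(tp, fp, fn):
    """Precision, recall and F1 from aggregated counts.

    Sketch assuming calc_stats sums true positives, false positives and
    false negatives; the real tuple layout is not shown above.
    """
    precision = tp / float(tp + fp) if tp + fp else 0.0
    recall = tp / float(tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {'precision': precision, 'recall': recall, 'f1_measure': f1}
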
Example #16
def main():
    args = get_arguments()
    data_dir = 'midi-Corpus/' + args.data_set + '/'
    logdir = data_dir + 'max_dilation=%d_reps=%d/' % (args.max_dilation_pow,
                                                      args.expansion_reps)
    print('*************************************************')
    print(logdir)
    print('*************************************************')
    sys.stdout.flush()
    restore_from = logdir
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    wavenet_params = loadParams(args.max_dilation_pow, args.expansion_reps,
                                args.dil_chan, args.res_chan, args.skip_chan)

    with open(logdir + 'wavenet_params.json', 'w') as outfile:
        json.dump(wavenet_params, outfile)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        gc_enabled = False
        # Note: no data reader is constructed in this script, so global
        # conditioning stays disabled and the dequeue_gc branch never runs.
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=BATCH_SIZE,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=None)
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    print('constructing training loss')
    sys.stdout.flush()
    print('constructing validation loss')
    sys.stdout.flush()
    #valid_loss, target_output, prediction = net.loss(input_batch=valid_batch,
    #                global_condition_batch=gc_id_batch,
    #                l2_regularization_strength=args.l2_regularization_strength)

    print('making optimizer')
    sys.stdout.flush()

    print('setting up tensorboard')
    sys.stdout.flush()
    # Set up logging for TensorBoard.

    valid_input = tf.placeholder(dtype=tf.float32, shape=(1, None, 88))
    loss, recon_loss, latent_loss, target_output, prediction, mu, log_sigma_sq = net.loss(
        input_batch=valid_input,
        global_condition_batch=gc_id_batch,
        l2_regularization_strength=args.l2_regularization_strength)

    # The mu and log_sigma_sq returned by net.loss are shadowed below by
    # placeholders, so sampling can be driven by external latent statistics.
    mu = tf.placeholder(dtype=tf.float32, shape=(1, None, args.res_chan))
    log_sigma_sq = tf.placeholder(dtype=tf.float32, shape=(1, None, args.res_chan))
    output_width = tf.placeholder(dtype=tf.int32, shape=())
    sample = net.sample(mu=mu, log_sigma_sq=log_sigma_sq,
                        network_input_width=output_width)

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    print('saver')
    sys.stdout.flush()
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=5)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    step = None
    last_saved_step = saved_global_step
 
    # Load validation data and evaluate the loss on the first file.
    validation_audio = load_all_audio(data_dir + 'valid/')
    num_valid_files = len(validation_audio)

    valid_losses_step = np.zeros((num_valid_files,))
    song = validation_audio[0]
    audio_0 = np.expand_dims(song, 0)
    _loss, _rec_loss, _lat_loss, _target, _pred = sess.run(
        [loss, recon_loss, latent_loss, target_output, prediction],
        {valid_input: audio_0})

    # Draw a sample from a zero-mean, unit-variance latent prior,
    # binarize it, and write the resulting piano roll as MIDI.
    N = 200
    n = N - net.receptive_field + 1
    _mu = np.zeros((1, n, args.res_chan))
    _log_sigma_sq = np.zeros((1, n, args.res_chan))
    _net_samp = sess.run(sample, {mu: _mu, log_sigma_sq: _log_sigma_sq,
                                  output_width: N})
    _net_samp = np.squeeze(_net_samp)
    print('net samp', _net_samp.shape)
    _net_samp = 1 * (_net_samp > 0.5)
    plt.figure()
    plt.imshow(_net_samp)
    plt.show()

    filename = args.wav_out_path + ('sample_%d.mid' % int(args.gen_num))
    midiwrite(filename, _net_samp)
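
net.sample consumes a per-timestep mean and log-variance, which points to the usual reparameterization draw from a diagonal Gaussian; a NumPy sketch of that draw (the WaveNet decoding pass itself is model-specific and not reproduced here):

import numpy as np

def sample_latent(mu, log_sigma_sq, rng=np.random):
    """Draw z ~ N(mu, sigma^2) elementwise via the reparameterization trick.

    Shapes are assumed to be (1, n, res_channels), matching the
    placeholders fed to net.sample above.
    """
    sigma = np.exp(0.5 * log_sigma_sq)
    return mu + sigma * rng.standard_normal(mu.shape)
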
Example #17
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    with tf.device("/cpu:0"):
        # Create coordinator.
        coord = tf.train.Coordinator()

        # Load raw waveform from VCTK corpus.
        with tf.name_scope('create_inputs'):
            # Allow silence trimming to be skipped by specifying a threshold near
            # zero.
            silence_threshold = (args.silence_threshold
                                 if args.silence_threshold > EPSILON else None)
            gc_enabled = args.gc_channels is not None
            reader = AudioReader(
                args.data_dir,
                coord,
                sample_rate=wavenet_params['sample_rate'],
                gc_enabled=gc_enabled,
                sample_size=args.sample_size,
                silence_threshold=silence_threshold)

        # Create network.
        net = WaveNetModel(
            batch_size=args.batch_size,
            dilations=wavenet_params["dilations"],
            filter_width=wavenet_params["filter_width"],
            residual_channels=wavenet_params["residual_channels"],
            dilation_channels=wavenet_params["dilation_channels"],
            skip_channels=wavenet_params["skip_channels"],
            quantization_channels=wavenet_params["quantization_channels"],
            use_biases=wavenet_params["use_biases"],
            scalar_input=wavenet_params["scalar_input"],
            initial_filter_width=wavenet_params["initial_filter_width"],
            histograms=args.histograms,
            global_condition_channels=args.gc_channels,
            global_condition_cardinality=reader.gc_category_cardinality)

        if args.l2_regularization_strength == 0:
            args.l2_regularization_strength = None

        global_step = tf.get_variable(
            "global_step", [], initializer=tf.constant_initializer(0),
            trainable=False)

        optimizer = optimizer_factory[args.optimizer](
            learning_rate=args.learning_rate,
            momentum=args.momentum)

        tower_grads = []
        tower_losses = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(args.gpu_nums):
                with tf.device("/gpu:%d" % i), tf.name_scope("tower_%d" % i) as scope:
                    audio_batch = reader.dequeue(args.batch_size)
                    if gc_enabled:
                        gc_id_batch = reader.dequeue_gc(args.batch_size)
                    else:
                        gc_id_batch = None

                    loss = net.loss(input_batch=audio_batch,
                                    global_condition_batch=gc_id_batch,
                                    l2_regularization_strength=args.l2_regularization_strength)
                    tower_losses.append(loss)

                    trainable = tf.trainable_variables()
                    grads = optimizer.compute_gradients(loss, var_list=trainable)
                    tower_grads.append(grads)

                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
                    tf.get_variable_scope().reuse_variables()

        # calculate the mean of each gradient. Synchronization point across all towers
        grads = average_gradients(tower_grads)
        train_ops = optimizer.apply_gradients(grads, global_step=global_step)

        # calculate the mean loss
        loss = tf.reduce_mean(tower_losses)

        # Set up logging for TensorBoard.
        writer = tf.summary.FileWriter(logdir)
        writer.add_graph(tf.get_default_graph())
        run_metadata = tf.RunMetadata()
        summaries_ops = tf.summary.merge(summaries)

        # Set up session
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=False, allow_soft_placement=True))
        init = tf.global_variables_initializer()
        sess.run(init)

        # Saver for storing checkpoints of the model.
        saver = tf.train.Saver(var_list=tf.trainable_variables())

        try:
            saved_global_step = load(saver, sess, restore_from)
            if is_overwritten_training or saved_global_step is None:
                # The first training step will be saved_global_step + 1,
                # therefore we put -1 here for new or overwritten trainings.
                saved_global_step = -1

        except:
            print("Something went wrong while restoring checkpoint. "
                  "We will terminate training to avoid accidentally overwriting "
                  "the previous model.")
            raise

        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        reader.start_threads(sess)

        step = None
        try:
            last_saved_step = saved_global_step
            for step in range(saved_global_step + 1, args.num_steps):
                start_time = time.time()
                if args.store_metadata and step % 50 == 0:
                    # Slow run that stores extra information for debugging.
                    print('Storing metadata')
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    summary, loss_value, _ = sess.run(
                        [summaries_ops, loss, train_ops],
                        options=run_options,
                        run_metadata=run_metadata)
                    writer.add_summary(summary, step)
                    writer.add_run_metadata(run_metadata,
                                            'step_{:04d}'.format(step))
                    tl = timeline.Timeline(run_metadata.step_stats)
                    timeline_path = os.path.join(logdir, 'timeline.trace')
                    with open(timeline_path, 'w') as f:
                        f.write(tl.generate_chrome_trace_format(show_memory=True))
                else:
                    summary, loss_value, _ = sess.run([summaries_ops, loss, train_ops])
                    writer.add_summary(summary, step)

                duration = time.time() - start_time
                print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                      .format(step, loss_value, duration))

                if step % args.checkpoint_every == 0:
                    save(saver, sess, logdir, step)
                    last_saved_step = step

        except KeyboardInterrupt:
            # Introduce a line break after ^C is displayed so save message
            # is on its own line.
            print()
        finally:
            if step is not None and step > last_saved_step:
                save(saver, sess, logdir, step)
            coord.request_stop()
            coord.join(threads)
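
average_gradients is called above but not defined in this excerpt; the standard multi-tower averaging routine (as in the TensorFlow CIFAR-10 tutorial) fits the call site:

import tensorflow as tf

def average_gradients(tower_grads):
    """Average gradients across towers.

    tower_grads: list (over towers) of lists of (gradient, variable)
    pairs, as returned by optimizer.compute_gradients in each tower.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars holds one variable's gradient from every tower.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        # Variables are shared between towers, so the first tower's
        # variable reference suffices.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
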
Example #18
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = (args.silence_threshold
                             if args.silence_threshold > EPSILON else None)
        reader = AudioReader(args.data_dir,
                             coord,
                             sample_rate=wavenet_params['sample_rate'],
                             sample_size=args.sample_size,
                             silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"])
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(audio_batch, args.l2_regularization_strength)
    if args.optimizer == ADAM_OPTIMIZER:
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    elif args.optimizer == SGD_OPTIMIZER:
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=args.sgd_momentum)
    else:
        # This shouldn't happen, given the choices specified in argument
        # specification.
        raise RuntimeError('Invalid optimizer option.')
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    try:
        last_saved_step = saved_global_step
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
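
Other examples in this collection dispatch through an optimizer_factory table rather than the explicit if/elif above; a minimal equivalent is sketched below (some variants also accept an epsilon keyword, as the call in Example #12 suggests).

import tensorflow as tf

def _adam(learning_rate, momentum):
    # momentum is accepted but unused, keeping a uniform signature.
    return tf.train.AdamOptimizer(learning_rate=learning_rate)

def _sgd(learning_rate, momentum):
    return tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                      momentum=momentum)

optimizer_factory = {'adam': _adam, 'sgd': _sgd}
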
Example #19
def main():
    args = get_arguments()
    data_dir = 'midi-Corpus/' + args.data_set + '/'
    logdir = data_dir + 'max_dilation=%d_reps=%d/' % (args.max_dilation_pow,
                                                      args.expansion_reps)
    print('*************************************************')
    print(logdir)
    print('*************************************************')
    sys.stdout.flush()
    restore_from = logdir
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    wavenet_params = loadParams(args.max_dilation_pow, args.expansion_reps,
                                args.dil_chan, args.res_chan, args.skip_chan)

    with open(logdir + 'wavenet_params.json', 'w') as outfile:
        json.dump(wavenet_params, outfile)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        gc_enabled = False
        # data queue for the training set
        train_dir = data_dir + 'train/'
        train_reader = MidiReader(
            train_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size)
        train_batch = train_reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = train_reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=BATCH_SIZE,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=None,
        global_condition_cardinality=train_reader.gc_category_cardinality)
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    print('constructing training loss')
    sys.stdout.flush()
    train_loss, target_output, prediction = net.loss(
        input_batch=train_batch,
        global_condition_batch=gc_id_batch,
        l2_regularization_strength=args.l2_regularization_strength)
    print('constructing validation loss')
    sys.stdout.flush()

    print('making optimizer')
    sys.stdout.flush()
    optimizer = optimizer_factory['adam'](learning_rate=args.learning_rate,
                                          momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(train_loss, var_list=trainable)

    print('setting up tensorboard')
    sys.stdout.flush()
    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    test_input = tf.placeholder(dtype=tf.float32, shape=(1, None, 88))
    test_loss, test_target_output, test_prediction = net.loss(
        input_batch=test_input,
        global_condition_batch=gc_id_batch,
        l2_regularization_strength=args.l2_regularization_strength)
    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    print('saver')
    sys.stdout.flush()
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=5)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    test_audio = load_all_audio(data_dir + 'test/')
    num_test_files = len(test_audio)
    test_losses = np.zeros((num_test_files, ))
    for i in range(num_test_files):
        test_i = np.expand_dims(test_audio[i], 0)
        test_losses[i] = sess.run(test_loss, {test_input: test_i})
    test_loss_value = np.mean(test_losses)
    np.savez(logdir + 'test.npz', test_loss=test_loss_value)
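
WaveNetModel.calculate_receptive_field is invoked with varying signatures across these examples. For a stack of dilated causal convolutions the receptive field follows from the sum of the dilations; the sketch below is consistent with the four-argument call in this example, with defaults covering the shorter call sites.

def calculate_receptive_field(filter_width, dilations, scalar_input=False,
                              initial_filter_width=1):
    """Receptive field of a dilated causal convolution stack.

    Each dilated layer widens the field by (filter_width - 1) * dilation;
    the initial causal layer contributes its own width. Sketch inferred
    from the call sites above, not from the (unshown) model definition.
    """
    receptive_field = (filter_width - 1) * sum(dilations) + 1
    if scalar_input:
        receptive_field += initial_filter_width - 1
    else:
        receptive_field += filter_width - 1
    return receptive_field
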
Example #20
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()
    
    #prediction_lag = 360
    #quantization_thresholds = [0.01]
    
    # Load data from csv files
    with tf.name_scope('create_inputs'):
        reader = DataReader(
            args.data_dir,
            args.testdata_dir,
            coord,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"],
                wavenet_params["dilations"]),
            prediction_lag=wavenet_params["prediction_lag"],
            quantization_thresholds=wavenet_params["quantization_thresholds"],
            sample_size=args.sample_size)
        data_batch = reader.dequeue(args.batch_size)
    
    dropout = tf.placeholder(dtype=tf.float32, shape=None)
    
    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        input_channels=reader.num_features(),
        output_channels=reader.num_return_categories(),
        dropout=dropout,
        use_biases=wavenet_params["use_biases"],
        histograms=args.histograms)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
        
    # Alternative weighting: weights=reader.get_category_inv_weights()
    loss, raw_loss, correct_ratio = net.loss(
        input_batch=data_batch,
        l2_regularization_strength=args.l2_regularization_strength,
        weights=[1., 1., 1.])
    
    train_loss_summary = tf.summary.merge([
                            tf.summary.scalar('train_raw_loss', raw_loss),
                            tf.summary.scalar('train_total_loss', loss)])
    test_loss_summary = tf.summary.scalar('test_raw_loss', raw_loss)
    test_correct_ratio_sell_summary = tf.summary.scalar('correct_ratio_sell', correct_ratio[2][0])
    test_correct_ratio_mid_summary = tf.summary.scalar('correct_ratio_mid', correct_ratio[2][1])
    test_correct_ratio_buy_summary = tf.summary.scalar('correct_ratio_buy', correct_ratio[2][2])
    test_summaries = [test_loss_summary, test_correct_ratio_sell_summary,
                      test_correct_ratio_mid_summary, test_correct_ratio_buy_summary]
    test_summaries = tf.summary.merge(test_summaries)
    
    optimizer = optimizer_factory[args.optimizer](
                    learning_rate=args.learning_rate,
                    momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    #summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    test_every = 50
    reader.start_threads(sess, test_every)

    step = None
    last_saved_step = saved_global_step
    min_test_loss = 10
    
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                # tf.summary.merge_all() is commented out above, so use the
                # train summary op and feed the dropout placeholder here too.
                summary, loss_value, _ = sess.run(
                    [train_loss_summary, raw_loss, optim],
                    feed_dict={dropout: 1},
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([train_loss_summary, raw_loss, optim],
                    feed_dict={dropout: 1})
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                  .format(step, loss_value, duration))
            
            if step % test_every == 0:
                summary, test_loss, ratios = sess.run([test_summaries, raw_loss, correct_ratio], feed_dict={dropout: 1})
                writer.add_summary(summary, step)
                min_test_loss = min(min_test_loss, test_loss)
                print('test loss = {:.3f}, min loss = {:.3f}'.format(test_loss, min_test_loss))
                print('correct ratio = {:.3f}'.format(ratios[0]))
                print('test true categories = {}'.format(ratios[1]))
                print('correct ratios = {}'.format(ratios[2]))
                print('false positives = {}'.format(ratios[3]))
            
            if step != 0 and step % args.checkpoint_every == 0:
                #loss_test_value = sess.run(loss_test)
                #print('test loss = {:.3f}'.format(loss_test_value))
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
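
save() and load() are used by every training script here but defined elsewhere; the sketch below matches the observed contract (load returns the restored checkpoint's global step, or None when there is nothing to restore).

import os
import sys
import tensorflow as tf

def save(saver, sess, logdir, step):
    """Write a checkpoint for the current step."""
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    print('Storing checkpoint to {} ...'.format(logdir))
    sys.stdout.flush()
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step)

def load(saver, sess, logdir):
    """Restore the latest checkpoint; return its global step, or None."""
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        # Checkpoint files are named like model.ckpt-1234.
        global_step = int(
            ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        return global_step
    return None
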
Example #21
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = (args.silence_threshold
                             if args.silence_threshold > EPSILON else None)
        reader = AudioReader(args.data_dir,
                             coord,
                             sample_rate=wavenet_params['sample_rate'],
                             sample_size=args.sample_size,
                             silence_threshold=silence_threshold)
        # audio_batch, input_IDs = reader.dequeue(args.batch_size)
        # (single-GPU dequeue; replaced by the per-tower dequeue below)

    # Create network.
    batch_size_single_GPU = int(1.0 * args.batch_size / args.num_gpus)
    net = WaveNetModel(
        batch_size=batch_size_single_GPU,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        ID_channels=wavenet_params["ID_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],  # scalar input vs. one-hot vector input?
        initial_filter_width=wavenet_params["initial_filter_width"])
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()

    tower_grads = []
    #for i in range(args.num_gpus):
    with tf.device('/gpu:0'):
        with tf.name_scope('losstower_0') as scope:
            audio_batch, input_IDs = reader.dequeue(batch_size_single_GPU)
            all_loss = net.loss(audio_batch, input_IDs,
                                args.l2_regularization_strength)
            loss, L1 = all_loss  #total loss
            tf.get_variable_scope().reuse_variables()
            grads_vars = optimizer.compute_gradients(loss, var_list=trainable)
            tower_grads.append(grads_vars)
    update_wei_op = []
    with tf.device('/cpu:0'):
        for gv in tower_grads:
            app_grad = optimizer.apply_gradients(gv)
            update_wei_op.append(app_grad)

    with tf.control_dependencies(update_wei_op):
        train_op = tf.no_op()

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess, N_THREADS)

    step = None
    try:
        last_saved_step = saved_global_step
        avg_loss_value = 0.0
        avg_L1_value = 0.0
        start_time = time.time()
        for step in range(saved_global_step + 1, args.num_steps):
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, all_loss_value, _ = sess.run(
                    [summaries, all_loss, train_op],
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                all_loss_value, _ = sess.run([all_loss, train_op])
                #writer.add_summary(summary, step)
            loss_value, L1_value = all_loss_value
            avg_loss_value += loss_value
            avg_L1_value += L1_value

            if step % args.checkloss_every == 0:
                avg_loss_value = avg_loss_value / args.checkloss_every
                avg_L1_value = avg_L1_value / args.checkloss_every
                duration = (time.time() -
                            start_time) * 1.0 / args.checkloss_every
                print(
                    'step {:d} - avg_loss = {:.3f}, avg_L1 = {:.3f}, ({:.3f} sec/step)'
                    .format(step, avg_loss_value, avg_L1_value, duration))
                sys.stdout.flush()
                avg_loss_value = 0.0
                avg_L1_value = 0.0
                start_time = time.time()

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
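Example #21 collects per-tower gradients in tower_grads but applies each entry separately. If the commented-out multi-GPU loop is re-enabled, the usual pattern is to average the gradients across towers first. A sketch of that helper, assuming a TF >= 1.0 concat API (this function is not part of the example):

import tensorflow as tf


def average_gradients(tower_grads):
    # tower_grads: one list of (gradient, variable) pairs per GPU tower.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars pairs up the same variable across all towers.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        # The variable is shared, so take it from the first tower.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads

# Assumed usage: build tower_grads on each GPU, then apply once on the CPU:
#     train_op = optimizer.apply_gradients(average_gradients(tower_grads))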
Example #22
def main():
    args = get_arguments()

    if args.isDebug in ["True", "true", "t", "1"]:
        isDebug = True
        print("Running train.py for debugging...")
    elif args.isDebug in ["False", "false", "f", "0"]:
        isDebug = False
        print("Running train.py for actual training...")
    else:
        print("--isDebug has to be True or False")
        return

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']
    print(restore_from)

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        lc_enabled = args.lc_channels is not None
        # reader = AudioReader(
        #     args.data_dir,
        #     coord,
        #     sample_rate=wavenet_params['sample_rate'],
        #     gc_enabled=gc_enabled,
        #     lc_enabled=lc_enabled,
        #     receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
        #                                                            wavenet_params["dilations"],
        #                                                            wavenet_params["scalar_input"],
        #                                                            wavenet_params["initial_filter_width"]),
        #     sample_size=args.sample_size,
        #     silence_threshold=silence_threshold)
        # audio_batch = reader.dequeue(args.batch_size)
        # if gc_enabled:
        #     gc_id_batch = reader.dequeue_gc(args.batch_size)
        # else:
        #     gc_id_batch = None
        #
        # if lc_enabled:
        #     lc_id_batch = reader.dequeue_lc(args.batch_size)
        # else:
        #     lc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        # global_condition_cardinality=reader.gc_category_cardinality,
        local_condition_channels=args.lc_channels)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    audio_placeholder_training = tf.placeholder(dtype=tf.float32, shape=None)
    gc_placeholder_training = tf.placeholder(
        dtype=tf.int32) if gc_enabled else None
    lc_placeholder_training = tf.placeholder(
        dtype=tf.float32, shape=(net.batch_size, None,
                                 512)) if lc_enabled else None
    loss = net.loss(input_batch=audio_placeholder_training,
                    global_condition_batch=gc_placeholder_training,
                    local_condition_batch=lc_placeholder_training,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)
    """variables for validation"""
    net.batch_size = 1
    audio_placeholder_validation = tf.placeholder(dtype=tf.float32, shape=None)
    gc_placeholder_validation = tf.placeholder(
        dtype=tf.int32) if gc_enabled else None
    lc_placeholder_validation = tf.placeholder(
        dtype=tf.float32, shape=(net.batch_size, None,
                                 512)) if lc_enabled else None
    validation = net.validation(
        input_batch=audio_placeholder_validation,
        global_condition_batch=gc_placeholder_validation,
        local_condition_batch=lc_placeholder_validation)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

    # if args.restore_model is not None:
    #     variables_to_restore = {
    #         var.name[:-2]: var for var in tf.global_variables()
    #         if not ('state_buffer' in var.name or 'pointer' in var.name)}
    #     saver = tf.train.Saver(variables_to_restore)
    #
    #     print('Restoring model from {}'.format(args.checkpoint))
    #     saver.restore(sess, args.checkpoint)
    #
    #     print("Restoring model done")
    # else:
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # reader.start_threads(sess)

    training_log_file = open(DATA_DIRECTORY + "training_log.txt", "w")
    validation_log_file = open(DATA_DIRECTORY + "validation_log.txt", "w")

    last_saved_step = saved_global_step

    with open('pickle/audio_lists_training_x_6.pkl', 'rb') as f1:
        audio_lists_training = pickle.load(f1)

    with open('pickle/img_vec_lists_training_x_6.pkl', 'rb') as f2:
        img_vec_lists_training = pickle.load(f2)

    with open('pickle/audio_lists_validation.pkl', 'rb') as f3:
        audio_lists_validation = pickle.load(f3)

    with open('pickle/img_vec_lists_validation.pkl', 'rb') as f4:
        img_vec_lists_validation = pickle.load(f4)

    try:
        for epoch in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            """ training """
            num_video_frames = []
            # training_data = audio_reader.load_generic_audio_video_without_downloading(DATA_DIRECTORY, SAMPLE_RATE,
            #                                                                             reader.i2v, "training", num_video_frames)
            training_data_order = np.arange(6)
            net.batch_size = 3
            random.shuffle(training_data_order)
            print(training_data_order)

            for o in range(2):

                video_matrix = np.zeros(
                    (net.batch_size, net.receptive_field + int(16000 / 25),
                     512))
                frame_index = 1
                for index in range(len(img_vec_lists_training[0])):

                    idx0 = training_data_order[o * 3]
                    idx1 = training_data_order[o * 3 + 1]
                    idx2 = training_data_order[o * 3 + 2]

                    audio = audio_lists_training[idx0][index].reshape(1, -1)
                    img_vec = img_vec_lists_training[idx0][index]
                    img_vecs = np.repeat(img_vec, int(16000 / 25), axis=1)
                    # audio = np.pad(audio, [[net.receptive_field, 0], [0, 0]], 'constant')
                    audio1 = audio_lists_training[idx1][index].reshape(1, -1)
                    img_vec1 = img_vec_lists_training[idx1][index]
                    img_vecs1 = np.repeat(img_vec1, int(16000 / 25), axis=1)
                    audio2 = audio_lists_training[idx2][index].reshape(1, -1)
                    img_vec2 = img_vec_lists_training[idx2][index]
                    img_vecs2 = np.repeat(img_vec2, int(16000 / 25), axis=1)

                    audio = np.vstack((audio, audio1, audio2))
                    img_vecs = np.vstack((img_vecs, img_vecs1, img_vecs2))

                    video_matrix[:, :-int(16000 / 25), :] = \
                        video_matrix[:, int(16000 / 25):, :]
                    video_matrix[:, -int(16000 / 25):, :] = img_vecs
                    # print(audio.shape)
                    # print(video_matrix.shape)

                    summary, loss_value, _ = sess.run(
                        [summaries, loss, optim],
                        feed_dict={
                            audio_placeholder_training: audio,
                            lc_placeholder_training: video_matrix
                        })

                    duration = time.time() - start_time
                    if frame_index % 10 == 0:
                        print(
                            'epoch {:d}, frame_index {:d}/{:d} - loss = {:.3f}, ({:.3f} sec/epoch)'
                            .format(epoch, frame_index,
                                    len(img_vec_lists_training[0]), loss_value,
                                    duration))
                        training_log_file.write(
                            'epoch {:d}, frame_index {:d}/{:d} - loss = {:.3f}, ({:.3f} sec/epoch)\n'
                            .format(epoch, frame_index,
                                    len(img_vec_lists_training[0]), loss_value,
                                    duration))
                    frame_index += 1

                    if frame_index == 2 and isDebug:
                        break
            """validation and generation"""
            if epoch % args.generate_every == 0:
                print("calculating validation score...")
                num_video_frames = []
                # validation_data = audio_reader.load_generic_audio_video_without_downloading(DATA_DIRECTORY, SAMPLE_RATE,
                #                                                                             reader.i2v, "validation", num_video_frames)
                validation_score = 0
                # pad = np.zeros((512, net.receptive_field))
                frame_index = 1
                waveform = []
                # prediction = None

                net.batch_size = 1
                video_matrix = np.zeros(
                    (net.batch_size, net.receptive_field + int(16000 / 25),
                     512))

                for index in range(len(img_vec_lists_validation)):
                    audio = audio_lists_validation[index]
                    img_vec = img_vec_lists_validation[index]
                    video_matrix[:, :-int(16000 / 25), :] = \
                        video_matrix[:, int(16000 / 25):, :]
                    video_matrix[:, -int(16000 / 25):, :] = img_vec

                    # return the error and prediction at the same time
                    validation_value, prediction = sess.run(
                        validation,
                        feed_dict={
                            audio_placeholder_validation: audio,
                            lc_placeholder_validation: video_matrix
                        })

                    validation_score += validation_value

                    if prediction is not None:
                        for i in range(prediction.shape[0]):
                            # generate a sample based on the prediction
                            sample = prediction2sample(
                                prediction[i, :], 1.0,
                                net.quantization_channels)
                            waveform.append(sample)

                    if frame_index % 10 == 0:
                        # show the progress

                        print('validation {:d}/{:d}'.format(
                            frame_index, len(img_vec_lists_validation)))

                    frame_index += 1

                    if frame_index == 10 and isDebug:
                        break

                print('epoch {:d} - validation = {:.3f}'.format(
                    epoch, sum(validation_score)))
                validation_log_file.write(
                    'epoch {:d} - validation = {:.3f}\n'.format(
                        epoch, sum(validation_score)))

                if len(waveform) > 0:
                    decode = mu_law_decode(
                        audio_placeholder_validation,
                        wavenet_params['quantization_channels'])
                    out = sess.run(
                        decode,
                        feed_dict={audio_placeholder_validation: waveform})
                    write_wav(out, wavenet_params['sample_rate'],
                              DATA_DIRECTORY + "epoch_" + str(epoch) + ".wav")

            if epoch % args.checkpoint_every == 0:
                save(saver, sess, logdir, epoch)
                last_saved_step = epoch

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        validation_log_file.close()
        training_log_file.close()
        if epoch > last_saved_step:
            save(saver, sess, logdir, epoch)
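Example #22 converts network output back to audio with mu_law_decode. For reference, here is a standalone NumPy sketch of mu-law companding as WaveNet uses it; the repo implements the same formulas as TF ops, so this version is only illustrative:

import numpy as np


def mu_law_encode_np(audio, quantization_channels=256):
    # Map float audio in [-1, 1] to integer levels in [0, channels - 1].
    mu = quantization_channels - 1
    safe_audio = np.clip(audio, -1.0, 1.0)
    magnitude = np.log1p(mu * np.abs(safe_audio)) / np.log1p(mu)
    signal = np.sign(safe_audio) * magnitude
    return ((signal + 1) / 2 * mu + 0.5).astype(np.int32)


def mu_law_decode_np(output, quantization_channels=256):
    # Invert the quantization back to float audio in [-1, 1].
    mu = quantization_channels - 1
    signal = 2.0 * (output.astype(np.float32) / mu) - 1.0
    return np.sign(signal) * np.expm1(np.abs(signal) * np.log1p(mu)) / mu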
Example #23
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    #aleix
    loss, global_condition_batch, gc_embedding, conv_filter, conv_filter0, conv_filter1, conv_gate, \
    embedding_table, weights_gc_filter, input_batch = net.loss(input_batch=audio_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step
    loss_plot = []  #store loss function (aleix)
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                #aleix
                summary, loss_value, global_condition_batch0, gc_embedding0, conv_filter_end, conv_filter0_0, \
                conv_filter0_1, conv_gate0,embedding_table0, weights_gc_filter0,input_batch0, _ = sess.run([
                    summaries, loss, global_condition_batch, gc_embedding, conv_filter, conv_filter0, conv_filter1,
                    conv_gate, embedding_table, weights_gc_filter, input_batch, optim])
                #print('global_condition_batch:')
                #print(global_condition_batch0)
                #print(global_condition_batch0.shape)
                #print()
                #print('gc_embedding')
                #print(gc_embedding0)
                #print(gc_embedding0.shape)
                #print()
                #print('conv_filter')
                #print(conv_filter_end)
                #print(conv_filter_end.shape)
                #print()
                #print('conv_filter0')
                #print(conv_filter0_0)
                #print(conv_filter0_0.shape)
                #print()
                #print('conv_filter1')
                #print(conv_filter0_1)
                #print(conv_filter0_1.shape)
                #print()
                #print('conv_gate')
                #print(conv_gate0)
                #print(conv_gate0.shape)
                #print()
                #print('embedding_table')
                #print(embedding_table0)
                #print(embedding_table0.shape)
                #print(target_output00)
                #print(target_output00.shape)
                #print(target_output10)
                #print(target_output10.shape)
                #print()
                #print('weights_gc_filter')
                #print(weights_gc_filter0)
                #print(weights_gc_filter.shape)
                #print(input_batch0.shape)
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))
            loss_plot.append(loss_value)
            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step
        plt.figure(1)  #store loss function (aleix)
        plt.plot(loss_plot)
        #plt.show()
        plt.savefig(os.path.join(args.data_dir, 'loss.png'))
        print()
        print('Loss plot saved')
        file00 = open(os.path.join(args.data_dir, 'loss.txt'), 'w')
        for item in loss_plot:
            file00.write("%s\n" % item)
        file00.close()
        print('Loss .txt saved')
        print()
    except KeyboardInterrupt:
        plt.figure(1)  #store loss function (aleix)
        plt.plot(loss_plot)
        plt.savefig(os.path.join(args.data_dir, 'loss.png'))
        print()
        print('Loss plot saved')
        file00 = open(os.path.join(args.data_dir, 'loss.txt'), 'w')
        for item in loss_plot:
            file00.write("%s\n" % item)
        file00.close()
        print('Loss .txt saved')
        print()
        #plt.show()

        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
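Several of these examples look up optimizer_factory[args.optimizer]. A sketch of how that factory is commonly defined in tensorflow-wavenet; every entry accepts momentum so the call sites can pass it uniformly (epsilon values and names are assumptions):

import tensorflow as tf


def create_adam_optimizer(learning_rate, momentum):
    # momentum is unused here; kept so all factories share one signature.
    return tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=1e-4)


def create_sgd_optimizer(learning_rate, momentum):
    return tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                      momentum=momentum)


def create_rmsprop_optimizer(learning_rate, momentum):
    return tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                     momentum=momentum,
                                     epsilon=1e-5)


optimizer_factory = {'adam': create_adam_optimizer,
                     'sgd': create_sgd_optimizer,
                     'rmsprop': create_rmsprop_optimizer}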
Example #24
class TestGeneration(tf.test.TestCase):
    def testGenerateSimple(self):
        # Reader config
        with open(TEST_DATA + "/config.json") as json_file:
            self.reader_config = json.load(json_file)

        # Initialize the reader
        receptive_field_size = WaveNetModel.calculate_receptive_field(
            2, [1, 1], False, 8)

        self.reader = CsvReader([
            TEST_DATA + "/test.dat", TEST_DATA + "/test.emo",
            TEST_DATA + "/test.pho"
        ],
                                batch_size=1,
                                receptive_field=receptive_field_size,
                                sample_size=SAMPLE_SIZE,
                                config=self.reader_config)

        # WaveNet model
        self.net = WaveNetModel(batch_size=1,
                                dilations=[1, 1],
                                filter_width=2,
                                residual_channels=8,
                                dilation_channels=8,
                                skip_channels=8,
                                quantization_channels=2,
                                use_biases=True,
                                scalar_input=False,
                                initial_filter_width=8,
                                histograms=False,
                                global_channels=GC_CHANNELS,
                                local_channels=LC_CHANNELS)

        loss = self.net.loss(input_batch=self.reader.data_batch,
                             global_condition=self.reader.gc_batch,
                             local_condition=self.reader.lc_batch,
                             l2_regularization_strength=0)

        optimizer = optimizer_factory['adam'](learning_rate=0.003,
                                              momentum=0.9)
        trainable = tf.trainable_variables()
        train_op = optimizer.minimize(loss, var_list=trainable)

        samples = tf.placeholder(tf.float32,
                                 shape=(receptive_field_size,
                                        self.reader.data_dim),
                                 name="samples")
        gc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="gc")
        lc = tf.placeholder(tf.int32,
                            shape=(receptive_field_size, 4),
                            name="lc")

        gc = tf.one_hot(gc, GC_CHANNELS)
        lc = tf.one_hot(lc, int(LC_CHANNELS / 4))

        predict = self.net.predict_proba(samples, gc, lc)
        # does nothing
        with self.test_session() as session:
            session.run([
                tf.local_variables_initializer(),
                tf.global_variables_initializer(),
                tf.tables_initializer(),
            ])
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            for i in range(5000):
                _, loss_val = session.run([train_op, loss])
                print("step %d loss %.4f" % (i, loss_val), end='\r')
                sys.stdout.flush()
            print()

            data_samples = np.random.random(
                (receptive_field_size, self.reader.data_dim))
            gc_samples = np.zeros((receptive_field_size))
            lc_samples = np.zeros((receptive_field_size, 4))

            # WITH CONDITIONING.
            error = 0.0
            i = 0.0
            for p in range(3):
                for q in range(3):
                    gc_samples[:] = p
                    lc_samples[:, :] = q
                    for _ in range(64):
                        prediction = session.run(predict,
                                                 feed_dict={
                                                     'samples:0': data_samples,
                                                     'gc:0': gc_samples,
                                                     'lc:0': lc_samples
                                                 })
                        data_samples = data_samples[1:, :]
                        data_samples = np.append(data_samples,
                                                 prediction,
                                                 axis=0)
                    print("G%d L%d - %.2f vs %.2f ERR %.2f" %
                          (p, q, i, np.average(prediction),
                           np.abs(i - np.average(prediction))))
                    error += np.abs(i - np.average(prediction))
                    data_samples = np.random.random(
                        (receptive_field_size, self.reader.data_dim))
                    i += 0.1

            print("TOTAL ERROR CONDITIONING: %.5f" % error)
            # WITHOUT CONDITIONING.

            data_samples = np.random.random(
                (receptive_field_size, self.reader.data_dim))

            errorNo = 0.0
            i = 0.0
            for p in range(3):
                for q in range(3):
                    gc_samples[:] = 0
                    lc_samples[:, :] = 0
                    for _ in range(64):
                        prediction = session.run(predict,
                                                 feed_dict={
                                                     'samples:0': data_samples,
                                                     'gc:0': gc_samples,
                                                     'lc:0': lc_samples
                                                 })
                        data_samples = data_samples[1:, :]
                        data_samples = np.append(data_samples,
                                                 prediction,
                                                 axis=0)
                    print("G%d L%d - %.2f vs %.2f ERR %.2f" %
                          (p, q, i, np.average(prediction),
                           (i - np.average(prediction))))
                    errorNo += np.abs(i - np.average(prediction))
                    data_samples = np.random.random(
                        (receptive_field_size, self.reader.data_dim))
                    i += 0.1

            print("TOTAL ERROR NO CONDITIONING: %.5f" % errorNo)
            self.assertLess(error, 0.5)
            self.assertGreater(errorNo, 0.05)
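TestGeneration sizes its reader with WaveNetModel.calculate_receptive_field(2, [1, 1], False, 8). A sketch of that computation, mirroring the repo's staticmethod (treat the exact form as an assumption): each dilated layer contributes (filter_width - 1) * dilation samples of context, plus the initial causal layer.

def calculate_receptive_field(filter_width, dilations, scalar_input,
                              initial_filter_width):
    receptive_field = (filter_width - 1) * sum(dilations) + 1
    if scalar_input:
        receptive_field += initial_filter_width - 1
    else:
        receptive_field += filter_width - 1
    return receptive_field

# For the test above: (2 - 1) * (1 + 1) + 1 + (2 - 1) = 4 samples.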
Example #25
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        initial_filter_width=wavenet_params["initial_filter_width"])

    gi_sampler = get_generator_input_sampler()

    # White noise generator params
    white_mean = 0
    white_sigma = 1
    white_length = 27117

    Z = tf.placeholder(tf.float32, shape=[None, white_length], name='Z')

    # initialize generator
    _, w_prediction = G.loss(input_batch=Z, name='generator')

    theta_G = tf.trainable_variables(scope='wavenet')

    X = tf.placeholder(tf.float32, shape=[None, w1], name='X')

    init = tf.global_variables_initializer()
    sess.run(init)

    levels = list(range(quantization_channels))

    levels_tensor = tf.reshape(tf.constant(levels, dtype=tf.float32),
                               [quantization_channels, 1])
    G_pre_stand = tf.matmul(tf.nn.softmax(w_prediction), levels_tensor)
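The fragment above maps the generator's softmax over quantization levels to a scalar by matmul with a column of level indices, i.e. it takes the expected level (a soft argmax). A NumPy sketch of the same idea, purely for illustration:

import numpy as np


def softmax_to_expected_level(logits):
    # logits: (batch, quantization_channels) raw network outputs.
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    levels = np.arange(logits.shape[-1], dtype=np.float32)
    # Probability-weighted average of the level indices.
    return probs @ levels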
Example #26
class TestNet(tf.test.TestCase):
    def setUp(self):
        print('TestNet setup.')
        sys.stdout.flush()

        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        self.momentum = MOMENTUM
        self.global_conditioning = False
        self.train_iters = TRAIN_ITERATIONS
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=QUANTIZATION_CHANNELS,
            skip_channels=32,
            global_condition_channels=None,
            global_condition_cardinality=None)

    def _save_net(self, sess):
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, os.path.join('tmp', 'test.ckpt'))

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        def CreateTrainingFeedDict(audio, speaker_ids, audio_placeholder,
                                   gc_placeholder, i):
            speaker_index = 0
            if speaker_ids is None:
                # No global conditioning.
                feed_dict = {audio_placeholder: audio}
            else:
                feed_dict = {
                    audio_placeholder: audio,
                    gc_placeholder: speaker_ids
                }
            return feed_dict, speaker_index

        np.random.seed(42)
        audio, speaker_ids = make_sine_waves(self.global_conditioning)
        # Pad with 0s (silence) times size of the receptive field minus one,
        # because the first sample of the training data is 0 and if the network
        # learns to predict silence based on silence, it will generate only
        # silence.
        if self.global_conditioning:
            audio = np.pad(audio, ((0, 0), (self.net.receptive_field - 1, 0)),
                           'constant')
        else:
            audio = np.pad(audio, (self.net.receptive_field - 1, 0),
                           'constant')

        audio_placeholder = tf.placeholder(dtype=tf.float32)
        gc_placeholder = tf.placeholder(dtype=tf.int32)  \
            if self.global_conditioning else None

        loss = self.net.loss(input_batch=audio_placeholder,
                             global_condition_batch=gc_placeholder)
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.global_variables_initializer()

        generated_waveform = None
        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        operations = [loss, optim]
        with self.test_session() as sess:
            feed_dict, speaker_index = CreateTrainingFeedDict(
                audio, speaker_ids, audio_placeholder, gc_placeholder, 0)
            sess.run(init)
            initial_loss = sess.run(loss, feed_dict=feed_dict)
            for i in range(self.train_iters):
                feed_dict, speaker_index = CreateTrainingFeedDict(
                    audio, speaker_ids, audio_placeholder, gc_placeholder, i)
                [results] = sess.run([operations], feed_dict=feed_dict)
                if i % 100 == 0:
                    print("i: %d loss: %f" % (i, results[0]))

            loss_val = results[0]

            # Sanity check the initial loss was larger.
            self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            self.assertLess(loss_val, max_allowed_loss)

            # Loss should be at least two orders of magnitude better
            # than before training.
            self.assertLess(loss_val / initial_loss, 0.02)

            if self.generate:
                # self._save_net(sess)
                if self.global_conditioning:
                    # Check non-fast-generated waveform.
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, False, speaker_ids)
                    for (waveform, id) in zip(generated_waveforms, ids):
                        check_waveform(self.assertGreater, waveform, id[0])

                    # Check fast-generated waveform.
                    # generated_waveforms, ids = generate_waveforms(sess,
                    #     self.net, True, speaker_ids)
                    # for (waveform, id) in zip(generated_waveforms, ids):
                    #     print("Checking fast wf for id{}".format(id[0]))
                    #     check_waveform( self.assertGreater, waveform, id[0])

                else:
                    # Check non-incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, False, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
                    # Check incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, True, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
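The TestNet smoke tests above and below build their training signal with make_sine_waves(...). A minimal sketch of such a generator for the superimposed e-flat chord; the constants and the conditioning layout are assumptions, not the repo's exact helper:

import numpy as np

SAMPLE_RATE_HZ = 2000.0                # assumed test constants
SAMPLE_DURATION = 0.5
F1, F2, F3 = 155.56, 196.00, 233.08    # E-flat major triad (Eb3, G3, Bb3)


def make_sine_waves_sketch(global_conditioning):
    times = np.arange(0.0, SAMPLE_DURATION, 1.0 / SAMPLE_RATE_HZ)
    if global_conditioning:
        # One pure tone per "speaker", plus integer speaker ids.
        audio = np.stack([0.6 * np.sin(times * 2.0 * np.pi * f)
                          for f in (F1, F2, F3)])
        speaker_ids = np.arange(3).reshape(-1, 1)
        return audio, speaker_ids
    # Single waveform: the three notes superimposed.
    audio = (0.6 * np.sin(times * 2.0 * np.pi * F1) +
             0.5 * np.sin(times * 2.0 * np.pi * F2) +
             0.4 * np.sin(times * 2.0 * np.pi * F3))
    return audio, None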
Example #27
class TestNet(tf.test.TestCase):
    def setUp(self):
        print('TestNet setup.')
        sys.stdout.flush()

        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        self.momentum = MOMENTUM
        # The test body below unconditionally feeds both gc and lc
        # placeholders, so both kinds of conditioning must be enabled.
        self.global_conditioning = True
        self.local_conditioning = True
        self.train_iters = TRAIN_ITERATIONS
        self.generate_two_waves = False  # referenced in the generation branch below
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=QUANTIZATION_CHANNELS,
            skip_channels=32,
            global_condition_channels=None,
            global_condition_cardinality=None)

    def _save_net(self, sess):
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, os.path.join('tmp', 'test.ckpt'))

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        def shuffle_row(audio, gc, lc):
            from copy import deepcopy
            for i in range(10):
                index1 = random.randint(0, audio.shape[0] - 1)
                index2 = random.randint(0, audio.shape[0] - 1)
                audio1 = deepcopy(audio[index1, :])
                audio2 = deepcopy(audio[index2, :])
                audio[index1, :] = audio2
                audio[index2, :] = audio1
                lc1 = deepcopy(lc[index1, :])
                lc2 = deepcopy(lc[index2, :])
                lc[index1, :] = lc2
                lc[index2, :] = lc1
                gc1 = deepcopy(gc[index1])
                gc2 = deepcopy(gc[index2])
                gc[index1] = gc2
                gc[index2] = gc1
            return audio, gc, lc

        def CreateTrainingFeedDict(audio, gc, lc, audio_placeholder,
                                   gc_placeholder, lc_placeholder, i):
            speaker_index = 0

            i = i % int(audio.shape[0] / self.net.batch_size)
            if i == 0:
                audio, gc, lc = shuffle_row(audio, gc, lc)
            _audio = audio[i * self.net.batch_size:(i + 1) *
                           self.net.batch_size]
            _gc = gc[i * self.net.batch_size:(i + 1) * self.net.batch_size]
            _lc = lc[i * self.net.batch_size:(i + 1) * self.net.batch_size]
            print("training audio length")
            print(_audio.shape)
            exit()

            if gc is None:
                # No global conditioning.
                feed_dict = {audio_placeholder: _audio}
            elif self.global_conditioning and not self.local_conditioning:
                feed_dict = {audio_placeholder: _audio, gc_placeholder: _gc}
            elif not self.global_conditioning and self.local_conditioning:
                feed_dict = {audio_placeholder: _audio, lc_placeholder: _lc}
            elif self.global_conditioning and self.local_conditioning:
                feed_dict = {
                    audio_placeholder: _audio,
                    gc_placeholder: _gc,
                    lc_placeholder: _lc
                }
            return feed_dict, speaker_index, audio, gc, lc

        np.random.seed(42)

        receptive_field = self.net.receptive_field
        audio, gc, lc, duration_lists = make_sine_waves(
            self.global_conditioning, self.local_conditioning, True)
        waveform_size = audio.shape[1]

        print("shape check 1")
        print(audio.shape)
        print(gc.shape)
        print(lc.shape)
        # Pad with 0s (silence) times size of the receptive field minus one,
        # because the first sample of the training data is 0 and if the network
        # learns to predict silence based on silence, it will generate only
        # silence.
        # if self.global_conditioning:
        #     # print(audio.shape)
        #     audio = np.pad(audio, ((0, 0), (self.net.receptive_field - 1, 0)), 'constant')
        #     # lc = np.pad(lc, ((0,0), (self.net.receptive_field - 1, 0)), 'maximum')
        #     # to set lc=0 for the initial silence
        #     lc = np.pad(lc, ((0, 0), (self.net.receptive_field - 1, 0)), 'constant')
        #     # print(audio.shape)
        #     # exit()
        # else:
        #     # print(audio.shape)
        #     audio = np.pad(audio, (self.net.receptive_field - 1, 0),
        #                    'constant')
        # print(audio.shape)
        # exit()

        audio_placeholder = tf.placeholder(dtype=tf.float32)
        gc_placeholder = tf.placeholder(dtype=tf.int32)  \
            if self.global_conditioning else None
        lc_placeholder = tf.placeholder(dtype=tf.int32) \
            if self.local_conditioning else None

        loss = self.net.loss(input_batch=audio_placeholder,
                             global_condition_batch=gc_placeholder,
                             local_condition_batch=lc_placeholder)
        self.net.batch_size = 1
        validation = self.net.loss(input_batch=audio_placeholder,
                                   global_condition_batch=gc_placeholder,
                                   local_condition_batch=lc_placeholder)
        self.net.batch_size = 3
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.global_variables_initializer()

        generated_waveform = None
        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        operations = [loss, optim]
        with self.test_session() as sess:
            # feed_dict, speaker_index, audio, gc, lc  = CreateTrainingFeedDict(
            #     audio, gc, lc, audio_placeholder, gc_placeholder, lc_placeholder, 0)
            sess.run(init)
            # print("shape check 2")
            # print(audio.shape)
            # print(lc.shape)
            # print(gc.shape)
            # print(feed_dict[audio_placeholder].shape)
            # print(feed_dict[gc_placeholder].shape)
            # print(feed_dict[lc_placeholder].shape)
            # initial_loss = sess.run(loss, feed_dict=feed_dict)

            _gc = np.zeros(3)
            """validation data"""
            lc_1 = np.full(900, 1)
            lc_2 = np.full(900, 2)
            lc_3 = np.full(900, 3)
            val_lc = np.zeros((1, 900))
            val_lc[0, :300] = lc_1[:300]
            val_lc[0, 300:600] = lc_2[300:600]
            val_lc[0, 600:] = lc_3[600:900]
            val_lc = np.pad(val_lc, ((0, 0), (receptive_field - 1, 0)),
                            'constant')
            sample_period = 1.0 / SAMPLE_RATE_HZ
            times = np.arange(0.0, SAMPLE_DURATION, sample_period)
            note1 = 0.6 * np.sin(times * 2.0 * np.pi * F1)
            note2 = 0.5 * np.sin(times * 2.0 * np.pi * F2)
            note3 = 0.4 * np.sin(times * 2.0 * np.pi * F3)
            val_audio = np.zeros((1, 900))
            val_audio[0, :300] = note1[:300]
            val_audio[0, 300:600] = note2[300:600]
            val_audio[0, 600:] = note3[600:900]
            val_audio = np.pad(val_audio, ((0, 0), (receptive_field - 1, 0)),
                               'constant')
            val_list = []
            error_list = []

            for i in range(self.train_iters):
                # for lc_index in range(3):
                #     current_audio = audio[:, int(lc_index * (waveform_size / 3)): int(
                #         (lc_index + 1) * (waveform_size / 3) + self.net.receptive_field)]
                #     # print(current_audio.shape)
                #     current_lc = lc[:, int(lc_index * (waveform_size / 3)): int(
                #         (lc_index + 1) * (waveform_size / 3) + self.net.receptive_field)]
                #
                #     [results] = sess.run([operations],
                #                          feed_dict={audio_placeholder: current_audio, lc_placeholder: current_lc,
                #                                     gc_placeholder: gc})
                self.net.batch_size = 3
                current_audio = audio[i % 10]
                current_lc = lc[i % 10]
                duration_list = duration_lists[i % 10]
                start_time = 0
                error_total = 0
                for duration in duration_list:
                    _audio = current_audio[:, start_time:duration +
                                           receptive_field]
                    _lc = current_lc[:, start_time:duration + receptive_field]
                    start_time = duration

                    [results] = sess.run(
                        [operations],
                        feed_dict={
                            audio_placeholder: _audio,
                            lc_placeholder: _lc,
                            gc_placeholder: _gc
                        })
                    error_total += results[0]
                # feed_dict, speaker_index, audio, gc, lc = CreateTrainingFeedDict(
                #     audio, gc, lc, audio_placeholder, gc_placeholder, lc_placeholder, i)
                # [results] = sess.run([operations], feed_dict=feed_dict)
                if i % 10 == 0:
                    print("i: %d loss: %f" % (i, results[0]))
                    error_list.append(error_total / len(duration_list))

                if i % 10 == 0:
                    self.net.batch_size = 1
                    validation_score = 0
                    # Use a separate loop variable so the outer training
                    # counter i is not clobbered.
                    for j in range(3):
                        validation_score += sess.run(
                            validation,
                            feed_dict={
                                audio_placeholder:
                                val_audio[:, 300 * j:300 * (j + 1) +
                                          receptive_field],
                                lc_placeholder:
                                val_lc[:, 300 * j:300 * (j + 1) +
                                       receptive_field],
                                gc_placeholder:
                                _gc[0]
                            })
                    val_list.append(validation_score / 3)
                    print("i: %d validation: %f" % (i, validation_score / 3))

            with open('complicated_error.pkl', 'wb') as f1:
                pickle.dump(error_list, f1)

            with open('complicated_validation.pkl', 'wb') as f2:
                pickle.dump(val_list, f2)
            loss_val = results[0]

            # Sanity check the initial loss was larger.
            # self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            # self.assertLess(loss_val, max_allowed_loss)

            # Loss should be at least two orders of magnitude better
            # than before training.
            # self.assertLess(loss_val / initial_loss, 0.02)

            if self.generate:
                # self._save_net(sess)
                if self.global_conditioning and not self.local_conditioning:
                    # Check non-fast-generated waveform.
                    generated_waveforms, ids = generate_waveforms(
                        # sess, self.net, True, speaker_ids)
                        sess,
                        self.net,
                        True,
                        np.array((0, )))
                    for (waveform, id) in zip(generated_waveforms, ids):
                        # check_waveform(self.assertGreater, waveform, id[0])
                        check_waveform(self.assertGreater, waveform, id)

                elif self.global_conditioning and self.local_conditioning:
                    lc_0 = np.full(int(GENERATE_SAMPLES / 3), 1)
                    lc_1 = np.full(int(GENERATE_SAMPLES / 3), 2)
                    lc_2 = np.full(int(GENERATE_SAMPLES / 3), 3)
                    lc = np.concatenate((lc_0, lc_1, lc_2))
                    lc = lc.reshape((lc.shape[0], 1))
                    print(lc.shape)
                    """ * test * """
                    test = False
                    if test:
                        # compare_logits(sess, self.net, np.array((0,)), lc)
                        logits_fast, logits_slow = check_logits(
                            sess, self.net, np.array((0, )), lc)
                        np.save("../data/logits_fast", logits_fast)
                        np.save("../data/logits_slow", logits_slow)
                        # np.save("../data/proba_fast", proba_fast)
                        # np.save("../data/proba_slow", proba_slow)
                        exit()
                    # Check non-fast-generated waveform.
                    if self.generate_two_waves:
                        generated_waveforms, ids = generate_waveforms(
                            sess, self.net, True, np.array((0, 1)), lc)
                    else:
                        generated_waveforms, ids = generate_waveforms(
                            sess, self.net, True, np.array((0, )), lc)
                    for (waveform, id) in zip(generated_waveforms, ids):
                        # check_waveform(self.assertGreater, waveform, id[0])
                        if id == 0:
                            np.save("../data/wave_fast", waveform)
                            np.save("../data/lc_fast", lc)
                            # plot_waveform(waveform)
                        else:
                            np.save("../data/wave_t", waveform)

                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, False, np.array((0, )), lc[:, 0])

                    for (waveform, id) in zip(generated_waveforms, ids):
                        # check_waveform(self.assertGreater, waveform, id[0])
                        if id == 0:
                            np.save("../data/wave_slow", waveform)
                            np.save("../data/lc_slow", lc)
                            # plot_waveform(waveform)
                        else:
                            np.save("../data/wave_t", waveform)
                            # np.save("../data/lc", lc)
                        # plot_waveform4eachLC(waveform, lc)
                        # check_waveform(self.assertGreater, waveform, id)

                    # Check fast-generated waveform.
                    # generated_waveforms, ids = generate_waveforms(sess,
                    #     self.net, True, speaker_ids)
                    # for (waveform, id) in zip(generated_waveforms, ids):
                    #     print("Checking fast wf for id{}".format(id[0]))
                    #     check_waveform( self.assertGreater, waveform, id[0])

                else:
                    # Check non-incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, False, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
                    # Check incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, True, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
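
The test above writes its fast- and slow-path outputs to .npy files rather than asserting on them. A minimal offline check, assuming the files saved by the test exist under ../data/ (not part of the original example):

import numpy as np

# np.save appends the .npy extension automatically.
wave_fast = np.load("../data/wave_fast.npy")
wave_slow = np.load("../data/wave_slow.npy")

# Given identical weights and local conditioning, the incremental (fast)
# and batched (slow) generation paths should produce near-identical audio.
diff = np.abs(np.asarray(wave_fast) - np.asarray(wave_slow))
print("max abs difference:", diff.max())
if diff.max() > 1e-3:
    print("first divergence at sample", int(np.argmax(diff > 1e-3)))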
Example #28
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw stock data.
    with tf.name_scope('create_inputs'):
        reader = NumberReader(args.data_dir,
                              coord,
                              sample_size=args.sample_size)
        text_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"])
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(text_batch, args.l2_regularization_strength)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(max_to_keep=CHECKPOINTS_TO_KEEP)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = saved_global_step + 1

    try:
        last_saved_step = saved_global_step
        while step < saved_global_step + args.num_steps:
            start_time = time.time()
            loss_value, _ = sess.run([loss, optim])
            print("done step", step)
            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step
            step += 1

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
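
The save and load helpers used by this training loop are defined elsewhere. A minimal sketch of what they typically look like in these WaveNet scripts (names and checkpoint layout assumed, not taken from this example):

import os
import tensorflow as tf

def save(saver, sess, logdir, step):
    # Write a checkpoint named model.ckpt-<step> under logdir.
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)

def load(saver, sess, logdir):
    # Restore the latest checkpoint under logdir and return its global
    # step, or None if no checkpoint exists yet.
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        return None
    step = int(ckpt.model_checkpoint_path.split('-')[-1])
    saver.restore(sess, ckpt.model_checkpoint_path)
    return step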
Example #29
class TestNet(tf.test.TestCase):
    def setUp(self):
        print('TestNet setup.')
        sys.stdout.flush()

        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        self.momentum = MOMENTUM
        self.global_conditioning = False
        self.train_iters = TRAIN_ITERATIONS
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=QUANTIZATION_CHANNELS,
            skip_channels=32,
            global_condition_channels=None,
            global_condition_cardinality=None)

    def _save_net(self, sess):
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, '/tmp/test.ckpt')

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        def CreateTrainingFeedDict(audio, speaker_ids, audio_placeholder,
                                   gc_placeholder, i):
            speaker_index = 0
            if speaker_ids is None:
                # No global conditioning.
                feed_dict = {audio_placeholder: audio}
            else:
                feed_dict = {
                    audio_placeholder: audio,
                    gc_placeholder: speaker_ids
                }
            return feed_dict, speaker_index

        np.random.seed(42)
        audio, speaker_ids = make_sine_waves(self.global_conditioning)

        # if self.generate:
        #     if len(audio.shape) == 2:
        #       for i in range(audio.shape[0]):
        #            librosa.output.write_wav(
        #                  '/tmp/sine_train{}.wav'.format(i), audio[i,:],
        #                  SAMPLE_RATE_HZ)
        #            power_spectrum = np.abs(np.fft.fft(audio[i,:]))**2
        #            freqs = np.fft.fftfreq(audio[i,:].size,
        #                                   SAMPLE_PERIOD_SECS)
        #            indices = np.argsort(freqs)
        #            indices = [index for index in indices if
        #                         freqs[index] >= 0 and
        #                         freqs[index] <= 500.0]
        #            plt.plot(freqs[indices], power_spectrum[indices])
        #            plt.show()

        audio_placeholder = tf.placeholder(dtype=tf.float32)
        gc_placeholder = tf.placeholder(dtype=tf.int32)  \
            if self.global_conditioning else None

        loss = self.net.loss(input_batch=audio_placeholder,
                             global_condition_batch=gc_placeholder)
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()

        generated_waveform = None
        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        operations = [loss, optim]
        with self.test_session() as sess:
            feed_dict, speaker_index = CreateTrainingFeedDict(
                audio, speaker_ids, audio_placeholder, gc_placeholder, 0)
            sess.run(init)
            initial_loss = sess.run(loss, feed_dict=feed_dict)
            for i in range(self.train_iters):
                feed_dict, speaker_index = CreateTrainingFeedDict(
                    audio, speaker_ids, audio_placeholder, gc_placeholder, i)
                [results] = sess.run([operations], feed_dict=feed_dict)
                if i % 10 == 0:
                    print("i: %d loss: %f" % (i, results[0]))

            loss_val = results[0]

            # Sanity check the initial loss was larger.
            self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            self.assertLess(loss_val, max_allowed_loss)

            # Loss should be at least two orders of magnitude better
            # than before training.
            self.assertLess(loss_val / initial_loss, 0.02)

            if self.generate:
                # self._save_net(sess)
                if self.global_conditioning:
                    # Check non-fast-generated waveform.
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, False, speaker_ids)
                    for (waveform, id) in zip(generated_waveforms, ids):
                        check_waveform(self.assertGreater, waveform, id[0])

                    # Check fast-generated waveform.
                    # generated_waveforms, ids = generate_waveforms(sess,
                    #     self.net, True, speaker_ids)
                    # for (waveform, id) in zip(generated_waveforms, ids):
                    #     print("Checking fast wf for id{}".format(id[0]))
                    #     check_waveform( self.assertGreater, waveform, id[0])

                else:
                    # Check non-incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, False, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
                    if not self.net.scalar_input:
                        # Check incremental generation
                        generated_waveforms, _ = generate_waveforms(
                            sess, self.net, True, None)
                        check_waveform(self.assertGreater,
                                       generated_waveforms[0], None)
Example #30
class TestNet(tf.test.TestCase):
    def setUp(self):
        self.net = WaveNetModel(batch_size=1,
                                dilations=[1, 2, 4, 8, 16, 32, 64,
                                           1, 2, 4, 8, 16, 32, 64],
                                filter_width=2,
                                residual_channels=32,
                                dilation_channels=32,
                                quantization_channels=256,
                                skip_channels=32)
        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        self.momentum = MOMENTUM

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        audio = make_sine_waves()
        np.random.seed(42)

        # if self.generate:
        #    librosa.output.write_wav('/tmp/sine_train.wav', audio,
        #                             SAMPLE_RATE_HZ)
        #    power_spectrum = np.abs(np.fft.fft(audio))**2
        #    freqs = np.fft.fftfreq(audio.size, SAMPLE_PERIOD_SECS)
        #    indices = np.argsort(freqs)
        #    indices = [index for index in indices if freqs[index] >= 0 and
        #                                             freqs[index] <= 500.0]
        #    plt.plot(freqs[indices], power_spectrum[indices])
        #    plt.show()

        audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
        encode_output = mu_law_encode(audio_tensor, QUANTIZATION_CHANNELS)
        loss = self.net.loss(encode_output)
        optimizer = optimizer_factory[self.optimizer_type](
                      learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()

        generated_waveform = None
        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss)
            for i in range(TRAIN_ITERATIONS):
                loss_val, _ = sess.run([loss, optim])
                # if i % 10 == 0:
                #     print("i: %d loss: %f" % (i, loss_val))

            # Sanity check the initial loss was larger.
            self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            self.assertLess(loss_val, max_allowed_loss)

            # Loss should be at least two orders of magnitude better
            # than before training.
            self.assertLess(loss_val / initial_loss, 0.01)

            # saver = tf.train.Saver(var_list=tf.trainable_variables())
            # saver.save(sess, '/tmp/sine_test_model.ckpt', global_step=i)
            if self.generate:
                # Check non-incremental generation
                generated_waveform = generate_waveform(sess, self.net, False)
                check_waveform(self.assertGreater, generated_waveform)

                # Check incremental generation
                generated_waveform = generate_waveform(sess, self.net, True)
                check_waveform(self.assertGreater, generated_waveform)
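
None of the examples reproduce make_sine_waves itself. A minimal sketch of the training signal the comments describe (three superimposed sine waves forming an E-flat chord); the sample rate, duration, and exact frequencies are assumptions:

import numpy as np

SAMPLE_RATE_HZ = 2000.0  # assumed low rate, adequate for tones below 500 Hz
SAMPLE_DURATION = 0.5    # seconds, assumed

def make_sine_waves():
    # Superimpose Eb3, G3 and Bb3 at equal amplitude (an E-flat chord).
    times = np.arange(0.0, SAMPLE_DURATION, 1.0 / SAMPLE_RATE_HZ)
    frequencies = [155.56, 196.00, 233.08]  # Hz
    audio = sum(np.sin(2.0 * np.pi * f * times) / 3.0 for f in frequencies)
    return audio.astype(np.float32)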
Example #31
def main():
    # Get default and command line parameters
    args = get_arguments()

    # Get logdir
    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_init = directories['logdir_init']
    restore_from = directories['restore_from']
    restore_from_init = directories['restore_from_init']

    # Lambda for white noise sampler
    gi_sampler = get_generator_input_sampler()

    # Some TensorFlow setup variables
    sess = tf.Session()
    coord = tf.train.Coordinator()

    # White noise generation and verification

    # White noise generator params
    white_mean = 0
    white_sigma = 1
    white_length = 20234

    white_noise = gi_sampler(white_mean, white_sigma, white_length)
    if args.view_initial_white:
        plt.plot(white_noise)
        plt.ylabel('Amplitude')
        plt.xlabel('Time')
        plt.show()

    # Load parameters from wavenet params json file
    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)  

    # Initialize generator WaveNet
    G = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        initial_filter_width=wavenet_params["initial_filter_width"])

    # Calculate loss for white noise input
    # loss = G.loss(input_batch=tf.convert_to_tensor(white_noise, dtype=np.float32), name='generator')
    result = G.loss(input_batch=tf.convert_to_tensor(white_noise,
                                                     dtype=np.float32),
                    name='generator')
    loss = result['loss']
    output = result['output']
    optimizer = optimizer_factory[args.optimizer](
                    learning_rate=args.learning_rate,
                    momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    init = tf.global_variables_initializer()
    sess.run(init)

    try:
        init_step = load(saver, sess, restore_from_init)
        if init_step is None:
            init_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    if init_step == -1:
        print('--------- Begin dummy weight setup ---------')
        start_time = time.time()
        loss_value, _, output_value = sess.run([loss, optim, output])
        duration = time.time() - start_time
        print('loss = {:.3f}, ({:.3f} sec)'.format(loss_value, duration))

    else: 
        print('---------- Loading initial weight ----------')
        print('... Done')
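
get_generator_input_sampler is only called, never defined, in this example. Judging from the call sites (mean, sigma, length), a plausible sketch is a factory for a Gaussian white-noise sampler:

import numpy as np

def get_generator_input_sampler():
    # Returns a lambda that draws `length` samples of Gaussian noise
    # with the given mean and standard deviation.
    return lambda mean, sigma, length: np.random.normal(
        mean, sigma, size=length).astype(np.float32)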
Example #32
def main():
    args = get_arguments()

    # Load parameters from wavenet params json file
    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)  

    quantization_channels = wavenet_params['quantization_channels']

    with tf.Graph().as_default():
        coord = tf.train.Coordinator()
        sess = tf.Session()

        # Lambda for white noise sampler
        gi_sampler = get_generator_input_sampler()

        # Initialize generator WaveNet
        G = WaveNetModel(
            batch_size=1,
            dilations=wavenet_params["dilations"],
            filter_width=wavenet_params["filter_width"],
            residual_channels=wavenet_params["residual_channels"],
            dilation_channels=wavenet_params["dilation_channels"],
            skip_channels=wavenet_params["skip_channels"],
            quantization_channels=wavenet_params["quantization_channels"],
            use_biases=wavenet_params["use_biases"],
            initial_filter_width=wavenet_params["initial_filter_width"])

        # White noise generator params
        white_mean = 0
        white_sigma = 1
        white_length = ffnn.INPUT_SIZE

        white_noise = gi_sampler(white_mean, white_sigma, white_length)
        white_noise = process(white_noise, quantization_channels, 1)
        white_noise_t = tf.convert_to_tensor(white_noise)

        directory = './sampleTrue'
        reader = AudioReader(directory,
                             coord,
                             sample_rate=16000,
                             gc_enabled=False,
                             receptive_field=5117,
                             sample_size=15117,
                             silence_threshold=0.05)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        reader.start_threads(sess)

        audio_batch = reader.dequeue(1)

        # initialize generator
        w_loss, w_prediction = G.loss(input_batch=white_noise_t, name='generator')
        #w_loss, w_prediction = G.loss(input_batch=audio_batch, name='generator')

        G_variables = tf.trainable_variables(scope='wavenet')
        optimizer = optimizer_factory[args.optimizer](
                    learning_rate=1e-3,
                    momentum=args.momentum)
        optim = optimizer.minimize(w_loss, var_list=G_variables)

        init = tf.global_variables_initializer()
        sess.run(init)

        '''
        for step in range(300):
            loss_value, _ = sess.run([w_loss, optim])
            print('step {:d} - loss = {:.3f}'.format(step, loss_value))

        prediction = sess.run(w_prediction)
        '''

        '''
        maxs = []
        maxs_2 = []
        maxs_3 = []

        for i in range(0, 10000):
            temp = prediction[i]
            temp.sort()
            maxs_3.append(temp[253])
            maxs_2.append(temp[254])
            maxs.append(temp[255])
        
        plt.plot(maxs)
        plt.plot(maxs_2)
        plt.plot(maxs_3)
        plt.ylabel('Value')
        plt.xlabel('Sample')
        plt.savefig('logits_after.png')
        
        np.set_printoptions(threshold=np.nan)
        print(sess.run(tf.nn.softmax(w_prediction)))
        ''' 
        
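The process helper that prepares the white noise above is also external. A plausible mu-law quantization sketch, with the name and third (batch size) argument inferred from the call site:

import numpy as np

def process(audio, quantization_channels, batch_size):
    # Mu-law compand, quantize to [0, quantization_channels - 1], and
    # reshape to the [batch, samples, 1] layout WaveNet inputs expect.
    mu = quantization_channels - 1
    safe = np.minimum(np.abs(audio), 1.0)
    magnitude = np.log1p(mu * safe) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    quantized = ((signal + 1) / 2 * mu + 0.5).astype(np.int32)
    return quantized.reshape(batch_size, -1, 1)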
Example #33
def main():
    args = get_arguments()

    with open(args.model_params, 'r') as f:
        model_params = json.load(f)

    with open(args.training_params, 'r') as f:
        train_params = json.load(f)

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print('Some arguments are wrong:')
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    receptive_field = WaveNetModel.calculate_receptive_field(
        model_params['filter_width'], model_params['dilations'],
        model_params['initial_filter_width'])
    # Save arguments and model params into file
    save_run_config(args, receptive_field, STARTED_DATESTRING, logdir)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Create data loader.
    with tf.name_scope('create_inputs'):
        reader = WavMidReader(data_dir=args.data_dir_train,
                              coord=coord,
                              audio_sample_rate=model_params['audio_sr'],
                              receptive_field=receptive_field,
                              velocity=args.velocity,
                              sample_size=args.sample_size,
                              queues_size=(10, 10 * args.batch_size))
        data_batch = reader.dequeue(args.batch_size)

    # Create model.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=model_params['dilations'],
        filter_width=model_params['filter_width'],
        residual_channels=model_params['residual_channels'],
        dilation_channels=model_params['dilation_channels'],
        skip_channels=model_params['skip_channels'],
        output_channels=model_params['output_channels'],
        use_biases=model_params['use_biases'],
        initial_filter_width=model_params['initial_filter_width'])

    input_data = tf.placeholder(dtype=tf.float32,
                                shape=(args.batch_size, None, 1))
    input_labels = tf.placeholder(dtype=tf.float32,
                                  shape=(args.batch_size, None,
                                         model_params['output_channels']))

    loss, probs = net.loss(input_data=input_data,
                           input_labels=input_labels,
                           pos_weight=train_params['pos_weight'],
                           l2_reg_str=train_params['l2_reg_str'])
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=train_params['learning_rate'],
        momentum=train_params['momentum'])
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()
    histograms = tf.summary.merge_all(key=HKEY)

    # Separate summary ops for validation, since they are
    # calculated only once per evaluation cycle.
    with tf.name_scope('validation_summaries'):

        metric_summaries = metrics_empty_dict()
        metric_value = tf.placeholder(tf.float32)
        for name in metric_summaries.keys():
            metric_summaries[name] = tf.summary.scalar(name, metric_value)

        images_buffer = tf.placeholder(tf.string)
        images_batch = tf.stack([
            tf.image.decode_png(images_buffer[0], channels=4),
            tf.image.decode_png(images_buffer[1], channels=4),
            tf.image.decode_png(images_buffer[2], channels=4)
        ])
        images_summary = tf.summary.image('estim', images_batch)

        audio_data = tf.placeholder(tf.float32)
        audio_summary = tf.summary.audio('input', audio_data,
                                         model_params['audio_sr'])

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    # Trainer for keeping best validation-performing model
    # and optional early stopping.
    trainer = Trainer(sess, logdir, train_params['early_stop_limit'], 0.999)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print('Something went wrong while restoring checkpoint. '
              'Training will be terminated to avoid accidentally '
              'overwriting the previous model.')
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, train_params['num_steps']):
            waveform, pianoroll = sess.run([data_batch[0], data_batch[1]])
            feed_dict = {input_data: waveform, input_labels: pianoroll}
            # Reload switches from file on each step
            with open(RUNTIME_SWITCHES, 'r') as f:
                switch = json.load(f)

            start_time = time.time()
            if switch['store_meta'] and step % switch['store_every'] == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  feed_dict=feed_dict,
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  feed_dict=feed_dict)
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % switch['checkpoint_every'] == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

            # Evaluate model performance on validation data
            if step % switch['evaluate_every'] == 0:
                if switch['histograms']:
                    hist_summary = sess.run(histograms)
                    writer.add_summary(hist_summary, step)
                print('evaluating...')
                stats = 0, 0, 0, 0, 0, 0
                est = np.empty([0, model_params['output_channels']])
                ref = np.empty([0, model_params['output_channels']])

                b_data, b_labels, b_cntr = (np.empty(
                    (0, args.sample_size + receptive_field - 1,
                     1)), np.empty((0, model_params['output_channels'])),
                                            args.batch_size)

                # If batch_size * sample_size exceeds the validation data,
                # run single_pass() again until predictions accumulate.
                while est.size == 0:  # ref.size == 0 and sum(stats) == 0 too

                    for data, labels in reader.single_pass(
                            sess, args.data_dir_valid):

                        # cumulate batch
                        if b_cntr > 1:
                            b_data, b_labels, decr = cumulateBatch(
                                data, labels, b_data, b_labels)
                            b_cntr -= decr
                            continue
                        elif args.batch_size > 1:
                            b_data, b_labels, decr = cumulateBatch(
                                data, labels, b_data, b_labels)
                            if not decr:
                                continue
                            data = b_data
                            labels = b_labels
                            # reset batch cumulation variables
                            b_data, b_labels, b_cntr = (
                                np.empty(
                                    (0, args.sample_size + receptive_field - 1,
                                     1)),
                                np.empty((0, model_params['output_channels'])),
                                args.batch_size)

                        predictions = sess.run(probs,
                                               feed_dict={input_data: data})
                        # Aggregate sums for metrics calculation
                        stats_chunk = calc_stats(predictions, labels,
                                                 args.threshold)
                        stats = tuple(
                            [sum(x) for x in zip(stats, stats_chunk)])
                        est = np.append(est, predictions, axis=0)
                        ref = np.append(ref, labels, axis=0)

                metrics = calc_metrics(None, None, None, stats=stats)
                write_metrics(metrics, metric_summaries, metric_value, writer,
                              step, sess)
                trainer.check(metrics['f1_measure'])

                # Render evaluation results
                if switch['log_image'] or switch['log_sound']:
                    sub_fac = int(model_params['audio_sr'] / switch['midi_sr'])
                    est = roll_subsample(est.T, sub_fac)
                    ref = roll_subsample(ref.T, sub_fac)
                if switch['log_image']:
                    write_images(est, ref, switch['midi_sr'], args.threshold,
                                 (8, 6), images_summary, images_buffer, writer,
                                 step, sess)
                if switch['log_sound']:
                    write_audio(est, ref, switch['midi_sr'],
                                model_params['audio_sr'], 0.007, audio_summary,
                                audio_data, writer, step, sess)

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
        flush_n_close(writer, sess)
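
calc_stats and calc_metrics are external helpers; the validation loop only sums their outputs. A sketch of the core aggregation (the script's stats tuple actually carries six sums; this simplified version keeps the four needed for F1):

import numpy as np

def calc_stats(predictions, labels, threshold):
    # Confusion-matrix counts for thresholded frame-wise predictions.
    est = predictions >= threshold
    ref = labels >= 0.5
    return (np.sum(est & ref),    # true positives
            np.sum(est & ~ref),   # false positives
            np.sum(~est & ref),   # false negatives
            np.sum(~est & ~ref))  # true negatives

def f1_from_stats(tp, fp, fn):
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return (2 * precision * recall / (precision + recall)
            if precision + recall else 0.0)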
Example #34
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            audio_dir=args.data_dir,
            coord=coord,
            sample_rate=wavenet_params["sample_rate"],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            mfsc_dim=wavenet_params["MFSC_channels"],
            ap_dim=wavenet_params["AP_channels"],
            F0_dim=wavenet_params["F0_channels"],
            phone_dim=wavenet_params["phones_channels"],
            phone_pos_dim=wavenet_params["phone_pos_channels"],
            silence_threshold=silence_threshold)

        ap_batch, lc_batch = reader.dequeue(args.batch_size)
        # print ("mfsc_batch_shape:", mfsc_batch.get_shape().as_list())
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality,
        MFSC_channels=wavenet_params["MFSC_channels"],
        F0_channels=wavenet_params["F0_channels"],
        phone_channels=wavenet_params["phones_channels"],
        phone_pos_channels=wavenet_params["phone_pos_channels"])

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    # pdb.set_trace()
    loss = net.loss(
        input_batch=ap_batch,  # shape: [receptive_field + sample_size, 1]
        lc_batch=lc_batch,
        global_condition_batch=gc_id_batch,  # gc_id_batch shape: scalar
        l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)
    print("========================================")
    print(
        "Total number of parameteres for mfsc model:",
        np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))
    # pdb.set_trace()
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()

            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            if step % 10 == 0:
                print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                    step, loss_value, duration))
            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    # finally:
    if step > last_saved_step:
        save(saver, sess, logdir, step)
    coord.request_stop()
    coord.join(threads)
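
Both training scripts above call WaveNetModel.calculate_receptive_field, once with three arguments and once with four. A sketch of the usual computation for a stack of dilated convolutions (mirroring the common tensorflow-wavenet implementation; treat it as an assumption here):

def calculate_receptive_field(filter_width, dilations, scalar_input,
                              initial_filter_width):
    # Each dilated layer widens the context by (filter_width - 1) * dilation;
    # the input layer contributes its own filter width on top.
    receptive_field = (filter_width - 1) * sum(dilations) + 1
    if scalar_input:
        receptive_field += initial_filter_width - 1
    else:
        receptive_field += filter_width - 1
    return receptive_field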
class TestNet(tf.test.TestCase):
    def setUp(self):
        print('TestNet setup.')
        sys.stdout.flush()

        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        self.momentum = MOMENTUM
        self.global_conditioning = False
        self.train_iters = TRAIN_ITERATIONS
        self.net = WaveNetModel(batch_size=1,
                                dilations=[1, 2, 4, 8, 16, 32, 64,
                                           1, 2, 4, 8, 16, 32, 64],
                                filter_width=2,
                                residual_channels=32,
                                dilation_channels=32,
                                quantization_channels=QUANTIZATION_CHANNELS,
                                skip_channels=32,
                                global_condition_channels=None,
                                global_condition_cardinality=None)

    def _save_net(self, sess):
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, os.path.join('tmp', 'test.ckpt'))

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        def CreateTrainingFeedDict(audio, speaker_ids, audio_placeholder,
                                   gc_placeholder, i):
            speaker_index = 0
            if speaker_ids is None:
                # No global conditioning.
                feed_dict = {audio_placeholder: audio}
            else:
                feed_dict = {audio_placeholder: audio,
                             gc_placeholder: speaker_ids}
            return feed_dict, speaker_index

        np.random.seed(42)
        audio, speaker_ids = make_sine_waves(self.global_conditioning)
        # Pad with 0s (silence) times size of the receptive field minus one,
        # because the first sample of the training data is 0 and if the network
        # learns to predict silence based on silence, it will generate only
        # silence.
        if self.global_conditioning:
            audio = np.pad(audio, ((0, 0), (self.net.receptive_field - 1, 0)),
                           'constant')
        else:
            audio = np.pad(audio, (self.net.receptive_field - 1, 0),
                           'constant')

        audio_placeholder = tf.placeholder(dtype=tf.float32)
        gc_placeholder = tf.placeholder(dtype=tf.int32)  \
            if self.global_conditioning else None

        loss = self.net.loss(input_batch=audio_placeholder,
                             global_condition_batch=gc_placeholder)
        optimizer = optimizer_factory[self.optimizer_type](
                      learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.global_variables_initializer()

        generated_waveform = None
        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        operations = [loss, optim]
        with self.test_session() as sess:
            feed_dict, speaker_index = CreateTrainingFeedDict(
                audio, speaker_ids, audio_placeholder, gc_placeholder, 0)
            sess.run(init)
            initial_loss = sess.run(loss, feed_dict=feed_dict)
            for i in range(self.train_iters):
                feed_dict, speaker_index = CreateTrainingFeedDict(
                    audio, speaker_ids, audio_placeholder, gc_placeholder, i)
                [results] = sess.run([operations], feed_dict=feed_dict)
                if i % 100 == 0:
                    print("i: %d loss: %f" % (i, results[0]))

            loss_val = results[0]

            # Sanity check the initial loss was larger.
            self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            self.assertLess(loss_val, max_allowed_loss)

            # Loss should be at least two orders of magnitude better
            # than before training.
            self.assertLess(loss_val / initial_loss, 0.02)

            if self.generate:
                # self._save_net(sess)
                if self.global_conditioning:
                    # Check non-fast-generated waveform.
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, False, speaker_ids)
                    for (waveform, id) in zip(generated_waveforms, ids):
                        check_waveform(self.assertGreater, waveform, id[0])

                    # Check fast-generated waveform.
                    # generated_waveforms, ids = generate_waveforms(sess,
                    #     self.net, True, speaker_ids)
                    # for (waveform, id) in zip(generated_waveforms, ids):
                    #     print("Checking fast wf for id{}".format(id[0]))
                    #     check_waveform( self.assertGreater, waveform, id[0])

                else:
                    # Check non-incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, False, None)
                    check_waveform(
                        self.assertGreater, generated_waveforms[0], None)
                    # Check incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, True, None)
                    check_waveform(
                        self.assertGreater, generated_waveforms[0], None)
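
The padding step near the top of this test is easy to see in isolation. A minimal numpy sketch with an assumed receptive field of 4:

import numpy as np

receptive_field = 4  # assumed for illustration
audio = np.array([0.1, 0.5, -0.3], dtype=np.float32)

# Prepend receptive_field - 1 zeros so the first real sample is predicted
# from silence instead of being consumed as context.
padded = np.pad(audio, (receptive_field - 1, 0), 'constant')
print(padded)  # [ 0.   0.   0.   0.1  0.5 -0.3]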