Example 1
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(input_batch=audio_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

    init = tf.global_variables_initializer()
    sess.run(init)
    # Optional tfdbg wrappers (left disabled). The CLI wrapper dumps tensors to
    # a local directory; the TensorBoard wrapper streams them to the debugger
    # plugin.
    # sess = tf_debug.LocalCLIDebugWrapperSession(
    #     sess, thread_name_filter="MainThread$",
    #     dump_root="C:\\MProjects\\WaveNet\\tensorflow-wavenet-master\\debugDump")
    # sess = tf_debug.TensorBoardDebugWrapperSession(
    #     sess, "RUSSKOV-NB-W10:6064", send_traceback_and_source_code=False)
    #
    # Useful tfdbg CLI commands:
    #   run --node_name_filter wavenet_1/loss/Reshape_1:0  -- (36352, 256)
    #   run --node_name_filter (.*loss.*)|(.*encode.*)
    #   pt -a tensorName > C:/Users/russkov.alexander/Desktop/WaveNet/tensorflow-wavenet-master/myDebugInfo/file.txt
    # Observed tensors:
    #   encoded_input = Tensor("wavenet_1/encode/ToInt32:0", shape=(1, ?, 1), dtype=int32)  -- (1, 59901, 1)
    #   encoded = Tensor("wavenet_1/one_hot_encode/Reshape:0", shape=(1, ?, 256), dtype=float32)  -- (1, 59901, 256)
    #
    # On tfdbg dumps filling the disk, see:
    # https://www.tensorflow.org/guide/debugger#frequently_asked_questions
    # https://github.com/tensorflow/tensorflow/issues/8753

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
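
Every example in this listing calls load(saver, sess, restore_from) and save(saver, sess, logdir, step) without defining them; they come from the surrounding repository. A minimal sketch of such checkpoint helpers, assuming the usual tf.train.Saver layout (the exact helpers in each repository may differ):

import os
import tensorflow as tf


def save(saver, sess, logdir, step):
    # Write a checkpoint named model.ckpt-<step> under logdir.
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step)


def load(saver, sess, logdir):
    # Restore the most recent checkpoint in logdir and return its global step,
    # or None if no checkpoint exists yet.
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        global_step = int(ckpt.model_checkpoint_path.split('-')[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        return global_step
    return None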
Example 2
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            audio_dir=args.data_dir,
            coord=coord,
            sample_rate=wavenet_params["sample_rate"],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            mfsc_dim=wavenet_params["MFSC_channels"],
            ap_dim=wavenet_params["AP_channels"],
            F0_dim=wavenet_params["F0_channels"],
            phone_dim=wavenet_params["phones_channels"],
            phone_pos_dim=wavenet_params["phone_pos_channels"],
            silence_threshold=silence_threshold)

        ap_batch, lc_batch = reader.dequeue(args.batch_size)
        # print ("mfsc_batch_shape:", mfsc_batch.get_shape().as_list())
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality,
        MFSC_channels=wavenet_params["MFSC_channels"],
        F0_channels=wavenet_params["F0_channels"],
        phone_channels=wavenet_params["phones_channels"],
        phone_pos_channels=wavenet_params["phone_pos_channels"])

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    # pdb.set_trace()
    loss = net.loss(
        input_batch=ap_batch,  # shape: [receptive_field + sample_size, 1]
        lc_batch=lc_batch,
        global_condition_batch=gc_id_batch,  # gc_id_batch shape: scalar
        l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)
    print("========================================")
    print("Total number of parameters for mfsc model:",
          np.sum([np.prod(v.get_shape().as_list())
                  for v in tf.trainable_variables()]))
    # pdb.set_trace()
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()

            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            if step % 10 == 0:
                print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                    step, loss_value, duration))
            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
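
Example 2 computes the total number of trainable parameters inline with np.sum and np.prod; the same count can be factored into a small helper. A sketch, not part of the original code:

import numpy as np
import tensorflow as tf


def count_trainable_parameters():
    # Sum of the element counts of every trainable variable in the graph.
    return int(np.sum([np.prod(v.get_shape().as_list())
                       for v in tf.trainable_variables()]))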
Example 3
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            max_samples=get_max_samples(args.data_dir,
                                        wavenet_params['sample_rate']),
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=args.silence_threshold
            if args.silence_threshold > EPSILON else None)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(input_batch=audio_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    learning_rate_placeholder = tf.placeholder(tf.float32, [])
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate=learning_rate_placeholder, momentum=args.momentum)
    train_op = optimizer.minimize(loss)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    loss_value = None
    update = 0
    last_saved_step = saved_global_step
    learning_rate = args.learning_rate
    print('learning_rate {:f}'.format(learning_rate))
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()

            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run(
                    [summaries, loss, train_op],
                    feed_dict={learning_rate_placeholder: learning_rate},
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run(
                    [summaries, loss, train_op],
                    feed_dict={learning_rate_placeholder: learning_rate})
                writer.add_summary(summary, step)

            # Decay the learning rate by 10x when the loss first drops to 1.5
            # or below, and by another 10x when it drops to 0.5 or below.
            if 1.5 >= loss_value > 0.5 and update == 0:
                learning_rate = learning_rate * 0.1
                update += 1
                print('learning_rate {:f}'.format(learning_rate))
            elif loss_value <= 0.5 and update == 1:
                learning_rate = learning_rate * 0.1
                update += 1
                print('learning_rate {:f}'.format(learning_rate))

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
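
Example 3 feeds the learning rate through a placeholder so it can be decayed at run time based on the loss. The inline schedule could be expressed as a helper; a sketch with the same thresholds, not taken from the original code:

def decay_learning_rate(loss_value, learning_rate, update):
    # Mirrors the schedule in the loop above: the rate is cut by a factor of
    # 10 at most twice, tracked by the 'update' counter.
    if 1.5 >= loss_value > 0.5 and update == 0:
        return learning_rate * 0.1, update + 1
    if loss_value <= 0.5 and update == 1:
        return learning_rate * 0.1, update + 1
    return learning_rate, update

# Usage inside the training loop:
#   learning_rate, update = decay_learning_rate(loss_value, learning_rate, update)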
Example 4
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    # aleix: this fork's net.loss returns the loss plus intermediate tensors
    # so they can be fetched and inspected during training.
    (loss, global_condition_batch, gc_embedding, conv_filter, conv_filter0,
     conv_filter1, conv_gate, embedding_table, weights_gc_filter,
     input_batch) = net.loss(
         input_batch=audio_batch,
         global_condition_batch=gc_id_batch,
         l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step
    loss_plot = []  #store loss function (aleix)
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                #aleix
                (summary, loss_value, global_condition_batch0, gc_embedding0,
                 conv_filter_end, conv_filter0_0, conv_filter0_1, conv_gate0,
                 embedding_table0, weights_gc_filter0, input_batch0, _) = sess.run([
                     summaries, loss, global_condition_batch, gc_embedding,
                     conv_filter, conv_filter0, conv_filter1, conv_gate,
                     embedding_table, weights_gc_filter, input_batch, optim])
                #print('global_condition_batch:')
                #print(global_condition_batch0)
                #print(global_condition_batch0.shape)
                #print()
                #print('gc_embedding')
                #print(gc_embedding0)
                #print(gc_embedding0.shape)
                #print()
                #print('conv_filter')
                #print(conv_filter_end)
                #print(conv_filter_end.shape)
                #print()
                #print('conv_filter0')
                #print(conv_filter0_0)
                #print(conv_filter0_0.shape)
                #print()
                #print('conv_filter1')
                #print(conv_filter0_1)
                #print(conv_filter0_1.shape)
                #print()
                #print('conv_gate')
                #print(conv_gate0)
                #print(conv_gate0.shape)
                #print()
                #print('embedding_table')
                #print(embedding_table0)
                #print(embedding_table0.shape)
                #print(target_output00)
                #print(target_output00.shape)
                #print(target_output10)
                #print(target_output10.shape)
                #print()
                #print('weights_gc_filter')
                #print(weights_gc_filter0)
                #print(weights_gc_filter.shape)
                #print(input_batch0.shape)
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))
            loss_plot.append(loss_value)
            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step
        plt.figure(1)  #store loss function (aleix)
        plt.plot(loss_plot)
        #plt.show()
        plt.savefig(os.path.join(args.data_dir, 'loss.png'))
        print()
        print('Loss plot saved')
        with open(os.path.join(args.data_dir, 'loss.txt'), 'w') as loss_file:
            for item in loss_plot:
                loss_file.write("%s\n" % item)
        print('Loss .txt saved')
        print()
    except KeyboardInterrupt:
        plt.figure(1)  #store loss function (aleix)
        plt.plot(loss_plot)
        plt.savefig(os.path.join(args.data_dir, 'loss.png'))
        print()
        print('Loss plot saved')
        with open(os.path.join(args.data_dir, 'loss.txt'), 'w') as loss_file:
            for item in loss_plot:
                loss_file.write("%s\n" % item)
        print('Loss .txt saved')
        print()
        #plt.show()

        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
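
Example 4 writes the collected loss values to loss.png and loss.txt in two places (normal completion and the KeyboardInterrupt branch). A sketch of how that duplication could be factored into one helper, assuming matplotlib is imported as plt as in the example:

import os
import matplotlib.pyplot as plt


def save_loss_artifacts(loss_plot, out_dir):
    # Save the per-step losses both as a plot and as a plain-text file.
    plt.figure(1)
    plt.plot(loss_plot)
    plt.savefig(os.path.join(out_dir, 'loss.png'))
    with open(os.path.join(out_dir, 'loss.txt'), 'w') as f:
        for item in loss_plot:
            f.write("%s\n" % item)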
Example 5
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    with tf.device("/cpu:0"):
        # Create coordinator.
        coord = tf.train.Coordinator()

        # Load raw waveform from VCTK corpus.
        with tf.name_scope('create_inputs'):
            # Allow silence trimming to be skipped by specifying a threshold near
            # zero.
            silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                          EPSILON else None
            gc_enabled = args.gc_channels is not None
            reader = AudioReader(
                args.data_dir,
                coord,
                sample_rate=wavenet_params['sample_rate'],
                gc_enabled=gc_enabled,
                sample_size=args.sample_size,
                silence_threshold=silence_threshold)

        # Create network.
        net = WaveNetModel(
            batch_size=args.batch_size,
            dilations=wavenet_params["dilations"],
            filter_width=wavenet_params["filter_width"],
            residual_channels=wavenet_params["residual_channels"],
            dilation_channels=wavenet_params["dilation_channels"],
            skip_channels=wavenet_params["skip_channels"],
            quantization_channels=wavenet_params["quantization_channels"],
            use_biases=wavenet_params["use_biases"],
            scalar_input=wavenet_params["scalar_input"],
            initial_filter_width=wavenet_params["initial_filter_width"],
            histograms=args.histograms,
            global_condition_channels=args.gc_channels,
            global_condition_cardinality=reader.gc_category_cardinality)

        if args.l2_regularization_strength == 0:
            args.l2_regularization_strength = None

        global_step = tf.get_variable("global_step", [], initializer=tf.constant_initializer(0), trainable=False)

        optimizer = optimizer_factory[args.optimizer](
            learning_rate=args.learning_rate,
            momentum=args.momentum)

        tower_grads = []
        tower_losses = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(args.gpu_nums):
                with tf.device("/gpu:%d" % i), tf.name_scope("tower_%d" % i) as scope:
                    audio_batch = reader.dequeue(args.batch_size)
                    if gc_enabled:
                        gc_id_batch = reader.dequeue_gc(args.batch_size)
                    else:
                        gc_id_batch = None

                    loss = net.loss(input_batch=audio_batch,
                                    global_condition_batch=gc_id_batch,
                                    l2_regularization_strength=args.l2_regularization_strength)
                    tower_losses.append(loss)

                    trainable = tf.trainable_variables()
                    grads = optimizer.compute_gradients(loss, var_list=trainable)
                    tower_grads.append(grads)

                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
                    tf.get_variable_scope().reuse_variables()

        # calculate the mean of each gradient. Synchronization point across all towers
        grads = average_gradients(tower_grads)
        train_ops = optimizer.apply_gradients(grads, global_step=global_step)

        # calculate the mean loss
        loss = tf.reduce_mean(tower_losses)

        # Set up logging for TensorBoard.
        writer = tf.summary.FileWriter(logdir)
        writer.add_graph(tf.get_default_graph())
        run_metadata = tf.RunMetadata()
        summaries_ops = tf.summary.merge(summaries)

        # Set up session
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=False, allow_soft_placement=True))
        init = tf.global_variables_initializer()
        sess.run(init)

        # Saver for storing checkpoints of the model.
        saver = tf.train.Saver(var_list=tf.trainable_variables())

        try:
            saved_global_step = load(saver, sess, restore_from)
            if is_overwritten_training or saved_global_step is None:
                # The first training step will be saved_global_step + 1,
                # therefore we put -1 here for new or overwritten trainings.
                saved_global_step = -1

        except:
            print("Something went wrong while restoring checkpoint. "
                  "We will terminate training to avoid accidentally overwriting "
                  "the previous model.")
            raise

        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        reader.start_threads(sess)

        step = None
        try:
            last_saved_step = saved_global_step
            for step in range(saved_global_step + 1, args.num_steps):
                start_time = time.time()
                if args.store_metadata and step % 50 == 0:
                    # Slow run that stores extra information for debugging.
                    print('Storing metadata')
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    summary, loss_value, _ = sess.run(
                        [summaries_ops, loss, train_ops],
                        options=run_options,
                        run_metadata=run_metadata)
                    writer.add_summary(summary, step)
                    writer.add_run_metadata(run_metadata,
                                            'step_{:04d}'.format(step))
                    tl = timeline.Timeline(run_metadata.step_stats)
                    timeline_path = os.path.join(logdir, 'timeline.trace')
                    with open(timeline_path, 'w') as f:
                        f.write(tl.generate_chrome_trace_format(show_memory=True))
                else:
                    summary, loss_value, _ = sess.run([summaries_ops, loss, train_ops])
                    writer.add_summary(summary, step)

                duration = time.time() - start_time
                print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                      .format(step, loss_value, duration))

                if step % args.checkpoint_every == 0:
                    save(saver, sess, logdir, step)
                    last_saved_step = step

        except KeyboardInterrupt:
            # Introduce a line break after ^C is displayed so save message
            # is on its own line.
            print()
        finally:
            if step is not None and step > last_saved_step:
                save(saver, sess, logdir, step)
            coord.request_stop()
            coord.join(threads)
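
Example 5 calls an average_gradients helper that is not shown in the listing. A minimal sketch in the style of the classic TensorFlow multi-GPU (CIFAR-10) tutorial, assuming every tower produces a gradient for every variable:

import tensorflow as tf


def average_gradients(tower_grads):
    # tower_grads: list over GPUs, each a list of (gradient, variable) pairs.
    # Returns a single list of (averaged gradient, variable) pairs.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        # The variables are shared across towers, so take the one from tower 0.
        _, var = grad_and_vars[0]
        average_grads.append((grad, var))
    return average_grads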
Example 6
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
                                                                   wavenet_params["dilations"],
                                                                   wavenet_params["scalar_input"],
                                                                   wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(input_batch=audio_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
                    learning_rate=args.learning_rate,
                    momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run(
                    [summaries, loss, optim],
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                  .format(step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
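
Most of these examples select their optimizer via optimizer_factory[args.optimizer](learning_rate=..., momentum=...), which is defined outside the snippets. A sketch of what such a factory typically looks like in tensorflow-wavenet-style code (the momentum argument is accepted by every entry even where it is unused):

import tensorflow as tf


def create_adam_optimizer(learning_rate, momentum):
    return tf.train.AdamOptimizer(learning_rate=learning_rate)


def create_sgd_optimizer(learning_rate, momentum):
    return tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                      momentum=momentum)


def create_rmsprop_optimizer(learning_rate, momentum):
    return tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                     momentum=momentum)


optimizer_factory = {'adam': create_adam_optimizer,
                     'sgd': create_sgd_optimizer,
                     'rmsprop': create_rmsprop_optimizer}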
Example 7
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold,
            normalize_peak=args.normalize_peak,
            queue_size=32 * max(args.num_gpus, 1))
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    if args.num_gpus <= 1:
        print("Falling back to single computation unit.")
        audio_batch = reader.dequeue(args.batch_size)
        net = make_model(args, wavenet_params, reader)
        loss = net.loss(
            input_batch=audio_batch,
            global_condition_batch=gc_id_batch,
            l2_regularization_strength=args.l2_regularization_strength)
        optimizer = optimizer_factory[args.optimizer](
            learning_rate=args.learning_rate, momentum=args.momentum)
        trainable = tf.trainable_variables()
        gradients = optimizer.compute_gradients(loss, var_list=trainable)
        for gradient, variable in gradients:
            if gradient is not None:
                tf.summary.scalar(variable.name + '/gradient',
                                  tf.norm(gradient))
        optim = optimizer.apply_gradients(gradients)
    else:
        print("Using {} GPUs for compuation.".format(args.num_gpus))
        with tf.device('/gpu:0'), tf.name_scope('tower_0'):
            optimizer = optimizer_factory[args.optimizer](
                learning_rate=args.learning_rate, momentum=args.momentum)
        losses = []
        gradients = []
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            for i in range(args.num_gpus):
                with tf.device('/gpu:%d' % i), tf.name_scope('tower_%d' % i):
                    audio_batch = reader.dequeue(args.batch_size)
                    net = make_model(args, wavenet_params, reader, i)
                    loss = net.loss(
                        input_batch=audio_batch,
                        global_condition_batch=gc_id_batch,
                        l2_regularization_strength=args.l2_regularization_strength)
                    trainable = tf.trainable_variables()
                    gradient = optimizer.compute_gradients(loss,
                                                           var_list=trainable)
                    losses.append(loss)
                    gradients.append(gradient)
                    scope.reuse_variables()

        with tf.device('/gpu:0'), tf.name_scope('tower_0'):
            loss = tf.reduce_mean(losses)
            tf.summary.scalar('mean_total_loss', loss)
            average_gradients = []
            for grouped_gradients in zip(*gradients):
                expanded_gradients = []
                for gradient, _ in grouped_gradients:
                    if gradient is not None:
                        expanded_gradients.append(tf.expand_dims(gradient, 0))

                # Since all GPUs share the same variable we can just take the one from gpu:0
                _, variable = grouped_gradients[0]
                if len(expanded_gradients) == 0:
                    print('No gradient for %s' % variable.name)
                    average_gradients.append((None, variable))
                    continue

                merged_gradients = tf.concat(expanded_gradients, 0)
                average_gradient = tf.reduce_mean(merged_gradients, 0)
                average_gradients.append((average_gradient, variable))

                tf.summary.scalar(variable.name + '/gradient',
                                  tf.norm(average_gradient))
            optim = optimizer.apply_gradients(average_gradients)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    # Workaround for avoiding allocating memory on all GPUs due to tensorflow#8021
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
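
Example 7 builds each tower through a make_model helper that is not part of the listing. A hypothetical sketch, mirroring the WaveNetModel construction used in the other examples (the real helper in that repository may take different arguments):

def make_model(args, wavenet_params, reader, tower_index=0):
    # Hypothetical helper: builds one WaveNetModel per tower with the same
    # hyperparameters the single-GPU examples pass directly.
    return WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)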