def run_ctc(log_dir, args):
    """Build an LSTM+CTC graph and run one evaluation/decoding pass.

    Optionally restores weights from a checkpoint under ``log_dir`` (when
    ``args.restore_step`` is set), feeds ``args.sample_file`` /
    ``args.label_file`` through the network exactly once, and prints the
    decoded label sequence next to the ground truth along with CTC cost
    and label error rate.

    NOTE(review): relies on module-level globals ``num_features``,
    ``num_hidden``, ``num_layers``, ``num_classes`` and
    ``layer_int_to_name_map`` — confirm they match the trained model.

    Args:
        log_dir: Directory containing 'model.ckpt-*' checkpoints.
        args: Parsed CLI namespace; uses ``model``, ``restore_step``,
            ``sample_file`` and ``label_file``.
    """
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    #input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    #log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    #log(hparams_debug_string())

    # Set up DataFeeder:
    #dataset = DataFeeder(input_path)

    # Build the model
    graph = tf.Graph()
    with graph.as_default():
        # e.g: log filter bank or MFCC features
        # Has size [batch_size, max_step_size, num_features], but the
        # batch_size and max_step_size can vary along each step
        inputs = tf.placeholder(tf.float32,
                                [None, None, num_features])  #batch size = 1
        #inputs = tf.placeholder(tf.float32, [None,num_features])
        # Here we use sparse_placeholder that will generate a
        # SparseTensor required by ctc_loss op.
        targets = tf.sparse_placeholder(tf.int32)

        # 1d array of size [batch_size]
        seq_len = tf.placeholder(tf.int32, [None])

        # Defining the cell
        # Can be:
        #   tf.nn.rnn_cell.RNNCell
        #   tf.nn.rnn_cell.GRUCell
        cell = tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True)

        # Stacking rnn cells
        # NOTE(review): [cell] * num_layers repeats the SAME cell object for
        # every layer, so all layers share one set of weights — confirm this
        # matches how the checkpoint was trained.
        stack = tf.contrib.rnn.MultiRNNCell([cell] * num_layers,
                                            state_is_tuple=True)

        # The second output is the last state and we will no use that
        outputs, _ = tf.nn.dynamic_rnn(stack,
                                       inputs,
                                       seq_len,
                                       dtype=tf.float32)

        shape = tf.shape(inputs)
        batch_s, max_time_steps = shape[0], shape[1]

        # Reshaping to apply the same weights over the timesteps
        outputs = tf.reshape(outputs, [-1, num_hidden])

        # Truncated normal with mean 0 and stdev=0.1
        # Tip: Try another initialization
        # see https://www.tensorflow.org/versions/r0.9/api_docs/python/contrib.layers.html#initializers
        W = tf.Variable(
            tf.truncated_normal([num_hidden, num_classes], stddev=0.1))
        # Zero initialization
        # Tip: Is tf.zeros_initializer the same?
        b = tf.Variable(tf.constant(0., shape=[num_classes]))

        # Doing the affine projection
        logits = tf.matmul(outputs, W) + b

        # Reshaping back to the original shape
        logits = tf.reshape(logits, [batch_s, -1, num_classes])

        # Time major: ctc_loss/ctc_greedy_decoder expect
        # [max_time, batch_size, num_classes].
        logits = tf.transpose(logits, (1, 0, 2))

        loss = tf.nn.ctc_loss(targets, logits, seq_len)
        cost = tf.reduce_mean(loss)

        optimizer = tf.train.AdamOptimizer().minimize(cost)
        # optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9).minimize(cost)
        #optimizer = tf.train.MomentumOptimizer(learning_rate=0.005, momentum=0.9).minimize(cost)

        # Option 2: tf.contrib.ctc.ctc_beam_search_decoder
        # (it's slower but you'll get better results)
        decoded, log_prob = tf.nn.ctc_greedy_decoder(logits, seq_len)

        # Inaccuracy: label error rate
        ler = tf.reduce_mean(
            tf.edit_distance(tf.cast(decoded[0], tf.int32), targets))

        stats = add_stats(cost, ler)

        saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Bookkeeping:
    # NOTE(review): time_window, train_cost_window and train_ler_window are
    # never appended to in this function, so their averages below are taken
    # over empty windows — verify ValueWindow.average handles that.
    time_window = ValueWindow(100)
    train_cost_window = ValueWindow(100)
    train_ler_window = ValueWindow(100)
    val_cost_window = ValueWindow(100)
    val_ler_window = ValueWindow(100)

    # Run!
    with tf.Session(graph=graph) as sess:
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        sess.run(tf.global_variables_initializer())

        if args.restore_step:
            # Restore from a checkpoint if the user requested it.
            restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
            saver.restore(sess, restore_path)
            log('Resuming from checkpoint: %s' % (restore_path, ))
        else:
            log('Starting new training run')

        # Runs exactly once; the reassignment of index_val below has no
        # effect on the iteration.
        for index_val in range(0, 1):
            index_val = index_val + 1
            sample_file = args.sample_file
            label_file = args.label_file
            val_inputs, val_targets, val_seq_len, val_original = next_infer_batch(
                sample_file, label_file)
            #print(val_inputs)
            print(val_targets)
            val_feed = {
                inputs: val_inputs,
                targets: val_targets,
                seq_len: val_seq_len
            }
            val_cost, val_ler = sess.run([cost, ler], feed_dict=val_feed)
            val_cost_window.append(val_cost)
            val_ler_window.append(val_ler)

            # Decoding
            # d is a SparseTensorValue; d[1] holds its dense label values.
            d = sess.run(decoded[0], feed_dict=val_feed)
            # Replacing blank label to none
            str_decoded = ''
            for x in np.asarray(d[1]):
                if x in layer_int_to_name_map:
                    str_decoded = str_decoded + layer_int_to_name_map[x] + ' '
                else:
                    print("x=%d MAJOR ERROR? OUT OF PREDICTION SCOPE" % x)

            print('for Sample %s' % sample_file)
            print('Original val: %s' % val_original)  # TODO
            print('Decoded val: %s' % str_decoded)

            message = "avg_train_cost = {:.3f}, avg_train_ler = {:.3f}, " \
                      "val_cost = {:.3f}, val_ler = {:.3f}, " \
                      "avg_val_cost = {:.3f}, avg_val_ler = {:.3f}"
            log(
                message.format(train_cost_window.average,
                               train_ler_window.average, val_cost, val_ler,
                               val_cost_window.average,
                               val_ler_window.average))
Beispiel #2
0
def train(log_dir, args, trans_ckpt_dir=None):
    """Run the training loop, optionally warm-starting from a transfer
    checkpoint directory.

    Args:
        log_dir: Directory where new checkpoints and summaries are written.
        args: Parsed CLI namespace; uses ``git``, ``base_dir``, ``input``,
            ``model``, ``restore_step``, ``summary_interval`` and
            ``checkpoint_interval``.
        trans_ckpt_dir: Optional directory holding a checkpoint to restore
            from. When omitted, restoring falls back to this run's own
            checkpoint path under ``log_dir``.
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    if trans_ckpt_dir is not None:
        trans_checkpoint_path = os.path.join(trans_ckpt_dir, 'model.ckpt')
    else:
        # FIX: trans_checkpoint_path was previously left undefined when
        # trans_ckpt_dir was None, causing a NameError in the log() call
        # below and in the restore branch. Fall back to this run's own
        # checkpoint path.
        trans_checkpoint_path = checkpoint_path

    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % trans_checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                # Restores from the transfer checkpoint (or log_dir's own
                # when no transfer dir was given); new saves still go to
                # checkpoint_path.
                restore_path = '%s-%d' % (trans_checkpoint_path,
                                          args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                # Abort (and alert) on divergence.
                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    # Snapshot the first element of the current batch for a
                    # qualitative check (audio + attention alignment plot).
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0],
                        model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    audio.save_wav(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, %s, step=%d, loss=%.5f' %
                        (args.model, commit, time_string(), step, loss))
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Beispiel #3
0
def train(log_dir, args):
  """Run the single-machine training loop.

  Builds the DataFeeder and model graph, optionally restores a checkpoint,
  then trains until the coordinator stops, periodically writing summaries,
  checkpoints, and a sample audio/alignment snapshot.

  Args:
      log_dir: Directory where checkpoints, summaries, audio and alignment
          plots are written.
      args: Parsed CLI namespace; uses ``git``, ``base_dir``, ``input``,
          ``model``, ``restore_step``, ``summary_interval`` and
          ``checkpoint_interval``.
  """
  commit = get_git_commit() if args.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')
  input_path = os.path.join(args.base_dir, args.input)
  log('Checkpoint path: %s' % checkpoint_path)
  log('Loading training data from: %s' % input_path)
  log('Using model: %s' % args.model)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    feeder = DataFeeder(coord, input_path, hparams)

  # Set up model:
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
    model.add_loss()
    model.add_optimizer(global_step)
    stats = add_stats(model)

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train!
  with tf.Session() as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      if args.restore_step:
        # Restore from a checkpoint if the user requested it.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt = sess.run([global_step, model.loss, model.optimize])
        time_window.append(time.time() - start_time)
        loss_window.append(loss)
        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
          step, time_window.average, loss, loss_window.average)
        log(message, slack=(step % args.checkpoint_interval == 0))

        # Abort (and alert) on divergence.
        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % args.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          summary_writer.add_summary(sess.run(stats), step)

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')
          # Snapshot the first element of the current batch for a
          # qualitative check (audio + attention alignment plot).
          input_seq, spectrogram, alignment = sess.run([
            model.inputs[0], model.linear_outputs[0], model.alignments[0]])
          waveform = audio.inv_spectrogram(spectrogram.T)
          audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
          plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
            info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
          log('Input: %s' % sequence_to_text(input_seq))

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
Beispiel #4
0
def train(log_dir, args):
    """Run distributed (between-graph replicated) training.

    Parameter-server nodes block in ``server.join()``; worker nodes build
    the graph under a replica device setter and train inside a Supervisor
    managed session, with only the chief (task 0) saving checkpoints.

    Args:
        log_dir: Directory where checkpoints, summaries, audio and
            alignment plots are written.
        args: Parsed CLI namespace; uses ``git``, ``base_dir``, ``input``,
            ``model``, ``ps_hosts``, ``worker_hosts``, ``job_name``,
            ``task_index``, ``restore_step``, ``summary_interval`` and
            ``checkpoint_interval``.
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    ps_hosts = args.ps_hosts.split(",")
    worker_hosts = args.worker_hosts.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server = tf.train.Server(cluster,
                             job_name=args.job_name,
                             task_index=args.task_index)

    # Block further graph execution if current node is parameter server
    if args.job_name == "ps":
        server.join()

    # Pin variables to the ps job and ops to this worker task.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % args.task_index,
                cluster=cluster)):

        # Set up DataFeeder:
        coord = tf.train.Coordinator()
        with tf.variable_scope('datafeeder') as scope:
            feeder = DataFeeder(coord, input_path, hparams)

        # Set up model:
        global_step = tf.Variable(0, name='global_step', trainable=False)
        with tf.variable_scope('model') as scope:
            model = create_model(args.model, hparams)
            model.initialize(feeder.inputs, feeder.input_lengths,
                             feeder.mel_targets, feeder.linear_targets)
            model.add_loss()
            model.add_optimizer(global_step)
            stats = add_stats(model)

        # Bookkeeping:
        step = 0
        time_window = ValueWindow(100)
        loss_window = ValueWindow(100)
        # sharded=True: each ps shard is saved by its own op.
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=2,
                               sharded=True)

        # NOTE(review): hooks is built but never passed to any session —
        # confirm whether StopAtStepHook was meant to be used.
        hooks = [tf.train.StopAtStepHook(last_step=1000000)]
        # Train!
        # Monitored... automatically resumes from the checkpoint.
        is_chief = (args.task_index == 0)
        init_op = tf.global_variables_initializer()
        sv = tf.train.Supervisor(is_chief=(args.task_index == 0),
                                 logdir="train_logs",
                                 init_op=init_op,
                                 summary_op=stats,
                                 saver=saver,
                                 save_model_secs=600)
        with sv.managed_session(server.target) as sess:
            try:

                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
                # NOTE(review): the Supervisor already runs init_op /
                # restores on session creation; re-running it here may reset
                # freshly restored variables — confirm this is intended.
                sess.run(init_op)

                if args.restore_step and is_chief:
                    # Restore from a checkpoint if the user requested it.
                    restore_path = '%s-%d' % (checkpoint_path,
                                              args.restore_step)
                    saver.restore(sess, restore_path)
                    log('Resuming from checkpoint: %s at commit: %s' %
                        (restore_path, commit),
                        slack=True)
                else:
                    log('Starting new training run at commit: %s' % commit,
                        slack=True)

                feeder.start_in_session(sess)

                while not coord.should_stop():
                    start_time = time.time()
                    step, loss, opt = sess.run(
                        [global_step, model.loss, model.optimize])
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                        step, time_window.average, loss, loss_window.average)
                    log(message, slack=(step % args.checkpoint_interval == 0))

                    # Abort (and alert) on divergence.
                    if loss > 100 or math.isnan(loss):
                        log('Loss exploded to %.05f at step %d!' %
                            (loss, step),
                            slack=True)
                        raise Exception('Loss Exploded')

                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats), step)

                    # Only the chief worker writes checkpoints and samples.
                    if step % args.checkpoint_interval == 0 and is_chief:
                        log('Saving checkpoint to: %s-%d' %
                            (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('Saving audio and alignment...')
                        input_seq, spectrogram, alignment = sess.run([
                            model.inputs[0], model.linear_outputs[0],
                            model.alignments[0]
                        ])
                        waveform = audio.inv_spectrogram(spectrogram.T)
                        audio.save_wav(
                            waveform,
                            os.path.join(log_dir, 'step-%d-audio.wav' % step))
                        plot.plot_alignment(
                            alignment,
                            os.path.join(log_dir, 'step-%d-align.png' % step),
                            info='%s, %s, %s, step=%d, loss=%.5f' %
                            (args.model, commit, time_string(), step, loss))
                        log('Input: %s' % sequence_to_text(input_seq))

            except Exception as e:
                log('Exiting due to exception: %s' % e, slack=True)
                traceback.print_exc()
                coord.request_stop(e)
Beispiel #5
0
def train(log_dir, args):
    """Train the SkipNet model with alternating A/B dataset iterations.

    Iteration 1 trains on dataset A (image-text pairs), iteration 2 on
    dataset B (speech-text pairs); both share the reconstruction and
    domain-adversarial optimizers.

    Args:
        log_dir: Directory where checkpoints are written.
        args: Parsed CLI namespace; uses ``git``, ``base_dir``, ``input``,
            ``model``, ``restore_step`` and ``checkpoint_interval``.
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = SkipNet(hparams)
        model.initialize(feeder.txt_A, feeder.txt_A_lenth, feeder.txt_B, feeder.txt_B_lenth, \
           feeder.mel_targets, feeder.image_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                # FIX: the placeholders live on the model, not on an
                # undefined ``self`` (this is a module-level function).
                # NOTE(review): feeding feeder.* tensors assumes they are
                # numpy batches, not graph tensors — confirm DataFeeder's
                # contract.
                feed_dict = {
                    model.txt_targets_A: feeder.txt_A,
                    model.txt_lenth_A: feeder.txt_A_lenth,
                    model.txt_targets_B: feeder.txt_B,
                    model.txt_lenth_B: feeder.txt_B_lenth,
                    model.mel_targets: feeder.mel_targets,
                    model.image_targets: feeder.image_targets
                }

                # iter 1: dataset A : image - text pairs
                # FIX: sess.run takes fetches first, then feed_dict (the
                # arguments were reversed), and the unpacking must match the
                # number of fetches.
                step, img_loss, txt_B_loss, d_loss, g_loss, opt1, opt2 = \
                    sess.run([global_step, model.recon_img_loss,
                              model.recon_txt_loss_A, model.domain_d_loss,
                              model.domain_g_loss, model.optimize_recon,
                              model.optimize_domain],
                             feed_dict=feed_dict)

                # iter 2: dataset B: speech-text pairs
                # FIX: previously 7 fetches were unpacked into 5 names,
                # which would raise a ValueError.
                step, speech_loss, txt_A_loss, d_loss, g_loss, opt1, opt2 = \
                    sess.run([global_step, model.recon_speech_loss,
                              model.recon_txt_loss_B, model.domain_d_loss,
                              model.domain_g_loss, model.optimize_recon,
                              model.optimize_domain],
                             feed_dict=feed_dict)

                # FIX: ``loss`` was used below but never defined; aggregate
                # the reconstruction losses for logging and the divergence
                # check. NOTE(review): confirm this combination is the
                # intended monitoring quantity.
                loss = img_loss + txt_B_loss + speech_loss + txt_A_loss

                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                # Abort (and alert) on divergence.
                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Beispiel #6
0
def train(log_dir, pretrain_log_dir, args):
    """Train either the 'tacotron' or 'nnet1' model, optionally starting
    from a pretrained checkpoint.

    For 'nnet1' the loop additionally computes frame-level PPG accuracy.

    Args:
        log_dir: Directory where checkpoints and summaries are written.
        args: Parsed CLI namespace; uses ``base_dir``, ``input``,
            ``model`` ('tacotron' or 'nnet1'), ``summary_interval`` and
            ``checkpoint_interval``.
        pretrain_log_dir: Optional directory to load a pretrained
            checkpoint from; falls back to ``log_dir`` when None.

    Raises:
        ValueError: If ``args.model`` is neither 'tacotron' nor 'nnet1'.
    """
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    ### input_path: linear, mel, frame_num, ppgs
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(input_lengths=feeder.input_lengths,
                         mel_targets=feeder.mel_targets,
                         linear_targets=feeder.linear_targets,
                         ppgs=feeder.ppgs,
                         speakers=feeder.speakers)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            try:
                # Prefer the pretrained checkpoint dir when provided,
                # otherwise resume from this run's own log_dir.
                if pretrain_log_dir is not None:
                    checkpoint_state = tf.train.get_checkpoint_state(
                        pretrain_log_dir)
                else:
                    checkpoint_state = tf.train.get_checkpoint_state(log_dir)
                if (checkpoint_state
                        and checkpoint_state.model_checkpoint_path):
                    log('Loading checkpoint {}'.format(
                        checkpoint_state.model_checkpoint_path),
                        slack=True)
                    saver.restore(sess, checkpoint_state.model_checkpoint_path)
                else:
                    log('No model to load at {}'.format(log_dir), slack=True)
                    saver.save(sess, checkpoint_path, global_step=global_step)
            except tf.errors.OutOfRangeError as e:
                log('Cannot restore checkpoint: {}'.format(e), slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()

                ### how to run training
                if args.model == 'tacotron':
                    step, loss, opt = sess.run(
                        [global_step, model.loss, model.optimize])
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                        step, time_window.average, loss, loss_window.average)
                    log(message, slack=(step % args.checkpoint_interval == 0))
                elif args.model == 'nnet1':
                    step, loss, opt, ppgs, logits = sess.run([
                        global_step, model.loss, model.optimize, model.ppgs,
                        model.logits
                    ])
                    ## cal acc: frame-level accuracy of predicted PPG classes
                    ppgs = np.argmax(ppgs, axis=-1)  # (N, 201, )
                    logits = np.argmax(logits, axis=-1)  # (N, 201, )
                    num_hits = np.sum(np.equal(ppgs, logits))
                    num_targets = np.shape(ppgs)[0] * np.shape(ppgs)[1]
                    acc = num_hits / num_targets
                    ## summarize
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    acc_window.append(acc)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f]' % (
                        step, time_window.average, loss, loss_window.average,
                        acc, acc_window.average)
                    log(message, slack=(step % args.checkpoint_interval == 0))
                else:
                    # FIX: previously print + `assert 1 == 0`, which is
                    # silently stripped under `python -O`; raise instead so
                    # a bad model name always aborts with a clear message.
                    raise ValueError('Unsupported model: %r' % args.model)

                ### save model and logs
                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Beispiel #7
0
def _collect_file_list(dict_path, train_data_arg):
    """Resolve comma-separated dataset names to data-file paths.

    Reads a json dict mapping dataset name -> data path, then for each name
    in *train_data_arg* collects the path and parses the speaker-id count
    that is embedded in the path as an '_id_num_<N>' token.

    Args:
        dict_path: path to the json name->path dictionary.
        train_data_arg: comma-separated dataset names (args.train_data).

    Returns:
        (file_list, id_num): resolved paths and the summed id count.
    """
    with open(dict_path, 'r') as f:
        train_data_dict = json.load(f)
    # The data path itself encodes the number of ids, e.g.
    # 'foo_id_num_42.tfrecord'; extract that number with a regex.
    pattern = '[.]*\\_id\\_num\\_([0-9]+)[.]+'
    file_list = []
    id_num = 0
    for item in train_data_arg.split(','):
        file_list.append(train_data_dict[item])
        id_num += int(re.findall(pattern, train_data_dict[item])[0])
    return file_list, id_num


def train(log_dir, args):
    """Build a (multi-GPU) Tacotron training graph and run the training loop.

    Supports two input pipelines ('tfrecord' or 'npy'), builds one model
    replica per requested GPU under a shared 'model' variable scope, and
    trains until the coordinator stops. At every checkpoint interval it
    saves the model plus debug artifacts: synthesized audio, spectrogram
    npy dumps, raw/Viterbi-decoded alignment plots, per-character wav
    segments cut along the alignment, and text renderings of the input
    sequence under several alignments.

    Args:
        log_dir: directory for checkpoints, logs and debug artifacts.
        args: parsed CLI namespace (model, GPUs_id, data_type, train_data,
            restore_step, summary_interval, checkpoint_interval, ...).

    Raises:
        ValueError: if args.data_type is neither 'tfrecord' nor 'npy'.
    """
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    log('Checkpoint path: %s' % checkpoint_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    sequence_to_text = sequence_to_text2

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Multi-GPU settings.
        # NOTE(review): eval() on a CLI string is unsafe on untrusted input;
        # kept for compatibility -- consider ast.literal_eval instead.
        GPUs_id = eval(args.GPUs_id)
        num_GPU = len(GPUs_id)
        hparams.num_GPU = num_GPU
        models = []

        # Set up DataFeeder:
        coord = tf.train.Coordinator()

        if args.data_type == 'tfrecord':
            file_list, id_num = _collect_file_list('./train_data_dict.json',
                                                   args.train_data)
            log('train data:%s' % args.train_data)

            feeder = DataFeeder_tfrecord(hparams, file_list)
            inputs, input_lengths, linear_targets, mel_targets, n_frames, wavs, identities = feeder._get_batch_input(
            )

        elif args.data_type == 'npy':
            file_list, id_num = _collect_file_list(
                './train_npy_data_dict.json', args.train_data)
            log('train data:%s' % args.train_data)

            feeder = DataFeeder_npy(hparams, file_list, coord)
            inputs = feeder.inputs
            input_lengths = feeder.input_lengths
            mel_targets = feeder.mel_targets
            linear_targets = feeder.linear_targets
            wavs = feeder.wavs
            identities = feeder.identities

        else:
            # Bug fix: the original did `raise ('...')`, which raises a
            # TypeError ("exceptions must derive from BaseException")
            # rather than a meaningful error.
            raise ValueError('unspecified input data type: %s' %
                             args.data_type)

        # Set up model: one replica per GPU, variables shared via the scope.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        with tf.variable_scope('model'):
            for i, GPU_id in enumerate(GPUs_id):
                with tf.device('/gpu:%d' % GPU_id):
                    with tf.name_scope('GPU_%d' % GPU_id):
                        model = create_model(args.model, hparams)
                        model.initialize(inputs=inputs,
                                         input_lengths=input_lengths,
                                         mel_targets=mel_targets,
                                         linear_targets=linear_targets,
                                         identities=identities,
                                         id_num=id_num)
                        model.add_loss()
                        model.add_optimizer(global_step)
                        # stats ends up holding the summary op of the last
                        # replica, matching the original behavior.
                        stats = add_stats(model)
                        models.append(model)

        # Bookkeeping:
        step = 0
        time_window = ValueWindow(250)
        loss_window = ValueWindow(1000)
        saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=8)

        # Train!
        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            try:
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
                sess.run(tf.global_variables_initializer())
                if args.restore_step:
                    # Restore from a checkpoint if the user requested it.
                    restore_path = '%s-%d' % (checkpoint_path,
                                              args.restore_step)
                    saver.restore(sess, restore_path)
                    log('Resuming from checkpoint: %s' % restore_path)
                else:
                    log('Starting new training run')

                # Start the input pipeline appropriate for the data type.
                if args.data_type == 'tfrecord':
                    tf.train.start_queue_runners(sess=sess, coord=coord)
                    feeder.start_threads(sess=sess, coord=coord)
                elif args.data_type == 'npy':
                    feeder.start_in_session(sess)

                while not coord.should_stop():
                    start_time = time.time()

                    # One optimization step; also fetch losses for logging.
                    step, loss, opt, loss_regularity = sess.run([
                        global_step,
                        models[0].loss,
                        models[0].optimize,
                        models[0].loss_regularity,
                    ])

                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step %-7d [%.03f avg_sec/step,  loss=%.05f,  avg_loss=%.05f,  lossw=%.05f]' % (
                        step, time_window.average, loss, loss_window.average,
                        loss_regularity)
                    log(message)

                    # if the gradient seems to explode, then restore to the previous step
                    if loss > 2 * loss_window.average or math.isnan(loss):
                        log('recover to the previous checkpoint')
                        # Round (step - 10) down to the last checkpoint step.
                        restore_step = int(
                            (step - 10) / args.checkpoint_interval
                        ) * args.checkpoint_interval
                        restore_path = '%s-%d' % (checkpoint_path,
                                                  restore_step)
                        saver.restore(sess, restore_path)
                        continue

                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats), step)

                    if step % args.checkpoint_interval == 0:
                        crrt_dir = os.path.join(log_dir, str(step))
                        os.makedirs(crrt_dir, exist_ok=True)

                        log('Saving checkpoint to: %s-%d' %
                            (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('Saving audio and alignment...')
                        # Pull the first element of the batch for debugging.
                        input_seq, spectrogram, alignment, wav_original, melspectogram, spec_original, mel_original, \
                        identity2 = sess.run([models[0].inputs[0], models[0].linear_outputs[0], models[0].alignments[0],
                                              wavs[0],models[0].mel_outputs[0], linear_targets[0], mel_targets[0],
                                              identities[0]])
                        waveform = audio.inv_spectrogram(spectrogram.T)
                        audio.save_wav(
                            waveform,
                            os.path.join(crrt_dir, 'step-%d-audio.wav' % step))
                        audio.save_wav(
                            wav_original,
                            os.path.join(
                                crrt_dir, 'step-%d-audio-original-%d.wav' %
                                (step, identity2)))
                        np.save(os.path.join(crrt_dir, 'spec.npy'),
                                spectrogram,
                                allow_pickle=False)
                        np.save(os.path.join(crrt_dir, 'melspectogram.npy'),
                                melspectogram,
                                allow_pickle=False)
                        np.save(os.path.join(crrt_dir, 'spec_original.npy'),
                                spec_original,
                                allow_pickle=False)
                        np.save(os.path.join(crrt_dir, 'mel_original.npy'),
                                mel_original,
                                allow_pickle=False)
                        plot.plot_alignment(
                            alignment,
                            os.path.join(crrt_dir, 'step-%d-align.png' % step),
                            info='%s, %s, step=%d, loss=%.5f' %
                            (args.model, time_string(), step, loss))

                        # Extract the alignment and check how well it aligns.
                        # Transition scores strongly favor staying on the same
                        # input position or advancing by exactly one (500 vs 0).
                        transition_params = []
                        for i in range(alignment.shape[0]):
                            transition_params.append([])
                            for j in range(alignment.shape[0]):
                                if i == j or j - i == 1:
                                    transition_params[-1].append(500)
                                else:
                                    transition_params[-1].append(0.0)
                        # Force the decode to start at input position 0.
                        alignment[0][0] = 100000
                        alignment2 = np.argmax(alignment, axis=0)
                        alignment3 = tf.contrib.crf.viterbi_decode(
                            alignment.T, transition_params)
                        # One-hot image of the Viterbi path, for plotting.
                        alignment4 = np.zeros(alignment.shape)
                        for i, item in enumerate(alignment3[0]):
                            alignment4[item, i] = 1
                        plot.plot_alignment(
                            alignment4,
                            os.path.join(crrt_dir,
                                         'step-%d-align2.png' % step),
                            info='%s, %s, step=%d, loss=%.5f' %
                            (args.model, time_string(), step, loss))

                        # Cut the synthesized waveform into per-input-symbol
                        # segments along the Viterbi path and save each one.
                        crrt = 0
                        sample_crrt = 0
                        sample_last = 0
                        for i, item in enumerate(alignment3[0]):
                            if item == crrt:
                                sample_crrt += hparams.sample_rate * hparams.frame_shift_ms * hparams.outputs_per_step\
                                               / 1000
                            if not item == crrt:
                                crrt += 1
                                sample_crrt = int(sample_crrt)
                                sample_last = int(sample_last)
                                wav_crrt = waveform[:sample_crrt]
                                wav_crrt2 = waveform[sample_last:sample_crrt]
                                audio.save_wav(
                                    wav_crrt,
                                    os.path.join(crrt_dir, '%d.wav' % crrt))
                                audio.save_wav(
                                    wav_crrt2,
                                    os.path.join(crrt_dir, '%d-2.wav' % crrt))
                                sample_last = sample_crrt
                                sample_crrt += hparams.sample_rate * hparams.frame_shift_ms * hparams.outputs_per_step \
                                               / 1000

                        # Re-expand the input sequence along each alignment.
                        input_seq2 = []
                        input_seq3 = []
                        for item in alignment2:
                            input_seq2.append(input_seq[item])
                        for item in alignment3[0]:
                            input_seq3.append(input_seq[item])

                        #output alignment
                        path_align1 = os.path.join(crrt_dir,
                                                   'step-%d-align1.txt' % step)
                        path_align2 = os.path.join(crrt_dir,
                                                   'step-%d-align2.txt' % step)
                        path_align3 = os.path.join(crrt_dir,
                                                   'step-%d-align3.txt' % step)
                        path_seq1 = os.path.join(crrt_dir,
                                                 'step-%d-input1.txt' % step)
                        path_seq2 = os.path.join(crrt_dir,
                                                 'step-%d-input2.txt' % step)
                        path_seq3 = os.path.join(crrt_dir,
                                                 'step-%d-input3.txt' % step)
                        with open(path_align1, 'w') as f:
                            for row in alignment:
                                for item in row:
                                    f.write('%.3f' % item)
                                    f.write('\t')
                                f.write('\n')
                        with open(path_align2, 'w') as f:
                            for item in alignment2:
                                f.write('%.3f' % item)
                                f.write('\t')
                        with open(path_align3, 'w') as f:
                            for item in alignment3[0]:
                                f.write('%.3f' % item)
                                f.write('\t')
                        with open(path_seq1, 'w') as f:
                            f.write(sequence_to_text(input_seq))
                        with open(path_seq2, 'w') as f:
                            f.write(sequence_to_text(input_seq2))
                        with open(path_seq3, 'w') as f:
                            f.write(sequence_to_text(input_seq3))
                        log('Input: %s' % sequence_to_text(input_seq))
                        log('Input: %s' % str(input_seq))

            except Exception as e:
                log('Exiting due to exception: %s' % e)
                traceback.print_exc()
                coord.request_stop(e)
Beispiel #8
0
def train(log_dir, args):
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    ## input path is lists of both postive path and negtiva path
    input_path_pos = os.path.join(args.base_dir, args.input_pos)
    input_path_neg = os.path.join(args.base_dir, args.input_neg)

    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading positive training data from: %s' % input_path_pos)
    log('Loading negative training data from: %s' % input_path_neg)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path_pos, input_path_neg, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs_pos, feeder.input_lengths_pos,
                         feeder.mel_targets_pos, feeder.linear_targets_pos,
                         feeder.mel_targets_neg, feeder.linear_targets_neg,
                         feeder.labels_pos, feeder.labels_neg)
        model.add_loss()
        model.add_optimizer(global_step)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            #summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                # train d
                sess.run(model.d_optimize)
                # train g
                step, rec_loss, style_loss, d_loss, g_loss, _ = sess.run([
                    global_step, model.rec_loss, model.style_loss,
                    model.d_loss, model.g_loss, model.g_optimize
                ])
                time_window.append(time.time() - start_time)
                message = 'Step %-7d [%.03f sec/step, rec_loss=%.05f, style_loss=%.05f, d_loss=%.05f, g_loss=%.05f]' % (
                    step, time_window.average, rec_loss, style_loss, d_loss,
                    g_loss)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram_pos, spectrogram_neg, alignment_pos, alignment_neg = sess.run(
                        [
                            model.inputs[0], model.linear_outputs_pos[0],
                            model.linear_outputs_neg[0],
                            model.alignments_pos[0], model.alignments_neg[0]
                        ])

                    waveform_pos = audio.inv_spectrogram(spectrogram_pos.T)
                    waveform_neg = audio.inv_spectrogram(spectrogram_neg.T)
                    audio.save_wav(
                        waveform_pos,
                        os.path.join(log_dir, 'step-%d-audio_pos.wav' % step))
                    audio.save_wav(
                        waveform_neg,
                        os.path.join(log_dir, 'step-%d-audio_neg.wav' % step))
                    plot.plot_alignment(
                        alignment_pos,
                        os.path.join(log_dir, 'step-%d-align_pos.png' % step),
                        info='%s, %s, %s, step=%d, loss=%.5f' %
                        (args.model, commit, time_string(), step, rec_loss))
                    plot.plot_alignment(
                        alignment_neg,
                        os.path.join(log_dir, 'step-%d-align_neg.png' % step),
                        info='%s, %s, %s, step=%d, loss=%.5f' %
                        (args.model, commit, time_string(), step, rec_loss))
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Beispiel #9
0
def train(log_dir, args, hparams, input_path):
	"""Train a WaveNet vocoder for up to args.wavenet_train_steps steps.

	Builds the feeder, a train-mode model and an eval-mode model, restores
	the latest shadow-averaged checkpoint if requested, then loops:
	logging, summaries, checkpoint/audio dumps and periodic evaluation.

	Args:
		log_dir: root directory for checkpoints, wavs, plots and eval output.
		args: parsed CLI namespace (base_dir, model, restore_step,
			wavenet_train_steps, summary_interval, checkpoint_interval,
			eval_interval).
		hparams: hyperparameter object (wavenet_random_seed, ...).
		input_path: training metadata path, relative to args.base_dir.

	Returns:
		save_dir on normal completion; None if an exception aborts training.
	"""
	save_dir = os.path.join(log_dir, 'wave_pretrained/')
	eval_dir = os.path.join(log_dir, 'eval-dir')
	# NOTE(review): audio_dir and wav_dir both resolve to log_dir/'wavs';
	# kept as-is for compatibility with existing callers/artifacts.
	audio_dir = os.path.join(log_dir, 'wavs')
	plot_dir = os.path.join(log_dir, 'plots')
	wav_dir = os.path.join(log_dir, 'wavs')
	eval_audio_dir = os.path.join(eval_dir, 'wavs')
	eval_plot_dir = os.path.join(eval_dir, 'plots')
	checkpoint_path = os.path.join(save_dir, 'wavenet_model.ckpt')
	input_path = os.path.join(args.base_dir, input_path)
	os.makedirs(save_dir, exist_ok=True)
	os.makedirs(wav_dir, exist_ok=True)
	os.makedirs(audio_dir, exist_ok=True)
	os.makedirs(plot_dir, exist_ok=True)
	os.makedirs(eval_audio_dir, exist_ok=True)
	os.makedirs(eval_plot_dir, exist_ok=True)

	log('Checkpoint_path: {}'.format(checkpoint_path))
	log('Loading training data from: {}'.format(input_path))
	log('Using model: {}'.format(args.model))
	log(hparams_debug_string())

	#Start by setting a seed for repeatability
	tf.set_random_seed(hparams.wavenet_random_seed)

	#Set up data feeder
	coord = tf.train.Coordinator()
	with tf.variable_scope('datafeeder') as scope:
		feeder = Feeder(coord, input_path, args.base_dir, hparams)

	#Set up model
	global_step = tf.Variable(0, name='global_step', trainable=False)
	model, stats = model_train_mode(args, feeder, hparams, global_step)
	eval_model = model_test_mode(args, feeder, hparams, global_step)

	#book keeping
	step = 0
	time_window = ValueWindow(100)
	loss_window = ValueWindow(100)
	sh_saver = create_shadow_saver(model, global_step)

	log('Wavenet training set to a maximum of {} steps'.format(args.wavenet_train_steps))

	#Memory allocation on the memory
	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True

	#Train
	with tf.Session(config=config) as sess:
		try:
			summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
			sess.run(tf.global_variables_initializer())

			#saved model restoring
			# Bug fix: checkpoint_state was previously left undefined when
			# args.restore_step was falsy, causing an UnboundLocalError below.
			checkpoint_state = None
			if args.restore_step:
				#Restore saved model if the user requested it, default = True
				try:
					checkpoint_state = tf.train.get_checkpoint_state(save_dir)
				except tf.errors.OutOfRangeError as e:
					log('Cannot restore checkpoint: {}'.format(e))

			if (checkpoint_state and checkpoint_state.model_checkpoint_path):
				log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path))
				load_averaged_model(sess, sh_saver, checkpoint_state.model_checkpoint_path)

			else:
				if not args.restore_step:
					log('Starting new training!')
				else:
					log('No model to load at {}'.format(save_dir))

			#initializing feeder
			feeder.start_threads(sess)

			#Training loop
			while not coord.should_stop() and step < args.wavenet_train_steps:
				start_time = time.time()
				step, y_hat, loss, opt = sess.run([global_step, model.y_hat, model.loss, model.optimize])
				time_window.append(time.time() - start_time)
				loss_window.append(loss)

				message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
					step, time_window.average, loss, loss_window.average)
				log(message, end='\r')

				if loss > 100 or np.isnan(loss):
					log('Loss exploded to {:.5f} at step {}'.format(loss, step))
					raise Exception('Loss exploded')

				if step % args.summary_interval == 0:
					log('\nWriting summary at step {}'.format(step))
					summary_writer.add_summary(sess.run(stats), step)

				if step % args.checkpoint_interval == 0 or step == args.wavenet_train_steps:
					save_log(sess, step, model, plot_dir, audio_dir, hparams=hparams)
					save_checkpoint(sess, sh_saver, checkpoint_path, global_step)

				if step % args.eval_interval == 0:
					log('\nEvaluating at step {}'.format(step))
					eval_step(sess, step, eval_model, eval_plot_dir, eval_audio_dir, summary_writer=summary_writer , hparams=model._hparams)

			log('Wavenet training complete after {} global steps'.format(args.wavenet_train_steps))
			return save_dir

		except Exception as e:
			log('Exiting due to Exception: {}'.format(e))
			# Print the stack trace so failures are diagnosable, matching
			# the sibling training functions in this file.
			traceback.print_exc()
Beispiel #10
0
def gst_train(log_dir, args):
	"""Train a GST (Global Style Token) Tacotron for args.gst_train_steps steps.

	Builds a DataFeeder, a train-mode and a test-mode model, restores the
	latest checkpoint when requested, then loops: logging, loss-explosion
	guard, summaries, periodic evaluation over feeder.test_steps batches,
	and checkpoint dumps with debug audio/alignment/spectrogram artifacts.

	Args:
		log_dir: root directory for checkpoints, wavs, plots and eval output.
		args: parsed CLI namespace (base_dir, gst_input, model, git,
			restore_step, gst_train_steps, summary_interval, eval_interval,
			checkpoint_interval).

	Returns:
		save_dir on normal completion; None if an exception aborts training.
	"""
	commit = get_git_commit() if args.git else 'None'
	save_dir = os.path.join(log_dir, 'gst_pretrained/')
	checkpoint_path = os.path.join(save_dir, 'gst_model.ckpt')
	input_path = os.path.join(args.base_dir, args.gst_input)
	plot_dir = os.path.join(log_dir, 'plots')
	wav_dir = os.path.join(log_dir, 'wavs')
	mel_dir = os.path.join(log_dir, 'mel-spectrograms')
	eval_dir = os.path.join(log_dir, 'eval-dir')
	eval_plot_dir = os.path.join(eval_dir, 'plots')
	eval_wav_dir = os.path.join(eval_dir, 'wavs')
	os.makedirs(eval_dir, exist_ok=True)
	os.makedirs(plot_dir, exist_ok=True)
	os.makedirs(wav_dir, exist_ok=True)
	os.makedirs(mel_dir, exist_ok=True)
	os.makedirs(eval_plot_dir, exist_ok=True)
	os.makedirs(eval_wav_dir, exist_ok=True)

	log('Checkpoint path: %s' % checkpoint_path)
	log('Loading training data from: %s' % input_path)
	log(hparams_debug_string())

	#Start by setting a seed for repeatability
	tf.set_random_seed(hparams.random_seed)

	# Set up DataFeeder:
	coord = tf.train.Coordinator()
	with tf.variable_scope('datafeeder') as scope:
		feeder = DataFeeder(coord, input_path, hparams)

	# Set up model:
	global_step = tf.Variable(0, name='global_step', trainable=False)
	model, stats = model_train_mode(args, feeder, hparams, global_step)
	eval_model = model_test_mode(args, feeder, hparams, global_step)

	# Bookkeeping:
	step = 0
	time_window = ValueWindow(100)
	loss_window = ValueWindow(100)
	saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

	#Memory allocation on the GPU as needed
	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True


	# Train!
	with tf.Session(config=config) as sess:
		try:
			summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
			sess.run(tf.global_variables_initializer())
			# Default to "no checkpoint found" so the branch below is safe
			# even when no restore is attempted.
			checkpoint_state = False
			#saved model restoring
			if args.restore_step:
				#Restore saved model if the user requested it, Default = True.
				try:
					checkpoint_state = tf.train.get_checkpoint_state(save_dir)
				except tf.errors.OutOfRangeError as e:
					log('Cannot restore checkpoint: {}'.format(e))

			if (checkpoint_state and checkpoint_state.model_checkpoint_path):
				log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path))
				saver.restore(sess, checkpoint_state.model_checkpoint_path)

			else:
				if not args.restore_step:
					log('Starting new training!')
				else:
					log('No model to load at {}'.format(save_dir))

			feeder.start_in_session(sess)

			while not coord.should_stop() and step < args.gst_train_steps:
				start_time = time.time()
				step, loss, opt = sess.run([global_step, model.loss, model.optimize])
				time_window.append(time.time() - start_time)
				loss_window.append(loss)
				message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
					step, time_window.average, loss, loss_window.average)
				log(message, slack=(step % args.checkpoint_interval == 0))

				if loss > 100 or math.isnan(loss):
					log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
					raise Exception('Loss Exploded')

				if step % args.summary_interval == 0:
					log('Writing summary at step: %d' % step)
					summary_writer.add_summary(sess.run(stats), step)

				if step % args.eval_interval == 0:
					#Run eval and save eval stats
					log('\nRunning evaluation at step {}'.format(step))

					eval_losses = []
					linear_losses = []

					#TODO: FIX TO ENCOMPASS MORE LOSS
					# NOTE(review): mel_p/mel_t/t_len/align/lin_p keep the
					# values of the LAST eval batch; only the losses are
					# averaged. If feeder.test_steps is 0 the names below
					# would be undefined -- presumably test_steps >= 1.
					for i in tqdm(range(feeder.test_steps)):
									eloss, linear_loss, mel_p, mel_t, t_len, align, lin_p = sess.run([eval_model.loss, eval_model.linear_loss, 
										eval_model.mel_outputs[0], eval_model.mel_targets[0], eval_model.targets_lengths[0], eval_model.alignments[0], 
										eval_model.linear_outputs[0]])
									eval_losses.append(eloss)
									linear_losses.append(linear_loss)


					eval_loss = sum(eval_losses) / len(eval_losses)
					linear_loss = sum(linear_losses) / len(linear_losses)

					wav = audio.inv_linear_spectrogram(lin_p.T)
					audio.save_wav(wav, os.path.join(eval_wav_dir, 'step-{}-eval-waveform-linear.wav'.format(step)))
					log('Saving eval log to {}..'.format(eval_dir))
					#Save some log to monitor model improvement on same unseen sequence

					wav = audio.inv_mel_spectrogram(mel_p.T)
					audio.save_wav(wav, os.path.join(eval_wav_dir, 'step-{}-eval-waveform-mel.wav'.format(step)))

					plot.plot_alignment(align, os.path.join(eval_plot_dir, 'step-{}-eval-align.png'.format(step)),
									info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, eval_loss),
									max_len=t_len // hparams.outputs_per_step)
					plot.plot_spectrogram(mel_p, os.path.join(eval_plot_dir, 'step-{}-eval-mel-spectrogram.png'.format(step)),
									info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, eval_loss), target_spectrogram=mel_t,
									)

					log('Eval loss for global step {}: {:.3f}'.format(step, eval_loss))
					log('Writing eval summary!')
					add_eval_stats(summary_writer, step, linear_loss, eval_loss)

				if step % args.checkpoint_interval == 0 or step == args.gst_train_steps:
					log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
					saver.save(sess, checkpoint_path, global_step=step)
					log('Saving audio and alignment...')
					input_seq, mel_pred, alignment, target, target_len = sess.run([model.inputs[0],
						model.mel_outputs[0],
						model.alignments[0],
						model.mel_targets[0],
						model.targets_lengths[0],
						])

					#save predicted mel spectrogram to disk (debug)
					mel_filename = 'mel-prediction-step-{}.npy'.format(step)
					np.save(os.path.join(mel_dir, mel_filename), mel_pred.T, allow_pickle=False)

					#save griffin lim inverted wav for debug (mel -> wav)
					wav = audio.inv_mel_spectrogram(mel_pred.T)
					audio.save_wav(wav, os.path.join(wav_dir, 'step-{}-wave-from-mel.wav'.format(step)))

					#save alignment plot to disk (control purposes)
					plot.plot_alignment(alignment, os.path.join(plot_dir, 'step-{}-align.png'.format(step)),
									info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss),
									max_len=target_len // hparams.outputs_per_step)
					#save real and predicted mel-spectrogram plot to disk (control purposes)
					plot.plot_spectrogram(mel_pred, os.path.join(plot_dir, 'step-{}-mel-spectrogram.png'.format(step)),
									info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, loss), target_spectrogram=target,
									max_len=target_len)
					log('Input at step {}: {}'.format(step, sequence_to_text(input_seq)))

			log('GST Taco training complete after {} global steps!'.format(args.gst_train_steps))
			return save_dir


		except Exception as e:
			log('Exiting due to exception: %s' % e, slack=True)
			traceback.print_exc()
			coord.request_stop(e)
Beispiel #11
0
def train(log_dir, args):
    """Train a Tacotron variant that conditions on a pretrained VGG19 model.

    Builds a DataFeeder and a model initialized with
    args.vgg19_pretrained_model, then loops: one optimization step per
    iteration, loss-explosion guard, periodic summaries, and checkpoints
    that also synthesize and upload a debug wav to slack. Stops when
    asked_to_stop(step) signals the coordinator.

    Args:
        log_dir: directory for checkpoints, logs and debug audio.
        args: parsed CLI namespace (base_dir, input, model, git,
            vgg19_pretrained_model, restore_step, summary_interval,
            checkpoint_interval).
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as _:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as _:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, args.vgg19_pretrained_model,
                         feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    time_window = ValueWindow()
    loss_window = ValueWindow()
    saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            train_start_time = time.time()
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                # The meta graph is re-imported rather than reusing `saver`,
                # matching the original restore strategy of this script.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                checkpoint_saver = tf.train.import_meta_graph(
                    '%s.%s' % (restore_path, 'meta'))
                checkpoint_saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit))
            else:
                log('Starting new training run at commit: %s' % commit)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.summary_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step))
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio...')
                    _, spectrogram, _ = sess.run([
                        model.inputs[0], model.linear_outputs[0],
                        model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    audio_path = os.path.join(log_dir,
                                              'step-%d-audio.wav' % step)
                    audio.save_wav(waveform, audio_path)

                    infolog.upload_to_slack(audio_path, step)

                    time_so_far = time.time() - train_start_time
                    hrs, rest = divmod(time_so_far, 3600)
                    # Renamed from `min` to avoid shadowing the builtin.
                    mins, secs = divmod(rest, 60)
                    log('{:.0f} hrs, {:.0f}mins and {:.1f}sec since the training process began'
                        .format(hrs, mins, secs))

                if asked_to_stop(step):
                    coord.request_stop()

        except Exception as e:
            log('@channel: Exiting due to exception: %s' % e)
            traceback.print_exc()
            coord.request_stop(e)
Beispiel #12
0
def train(log_dir, args):
    """Multi-GPU training loop for a Tacotron-style TTS model.

    Builds one model replica per GPU listed in ``args.GPUs_id``, averages the
    tower gradients on the CPU, then runs the session loop with periodic
    logging, summaries, checkpointing, and audio/alignment dumps.

    Args:
        log_dir: Directory receiving checkpoints, TensorBoard summaries,
            audio samples, and alignment plots.
        args: Parsed command-line arguments (model name, input paths,
            summary/checkpoint intervals, GPU list, hyperparameter
            overrides, weight decay, restore step, ...).

    Raises:
        Exception: Propagated through the coordinator when the loss explodes
            (> 100 or NaN) or any other error occurs during training.
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # graph
    # Build the whole graph with CPU as the default device; per-tower ops are
    # explicitly pinned to GPUs below, gradient averaging stays on the CPU.
    with tf.Graph().as_default(), tf.device('/cpu:0'):

        #new attributes of hparams
        #hparams.num_GPU = len(GPUs_id)
        #hparams.datasets = eval(args.datasets)
        # NOTE(review): eval() on a command-line string executes arbitrary
        # code — unsafe if args ever come from an untrusted source.
        hparams.datasets = eval(args.datasets)
        hparams.prenet_layer1 = args.prenet_layer1
        hparams.prenet_layer2 = args.prenet_layer2
        hparams.gru_size = args.gru_size
        hparams.attention_size = args.attention_size
        hparams.rnn_size = args.rnn_size
        hparams.enable_fv1 = args.enable_fv1
        hparams.enable_fv2 = args.enable_fv2

        # Optional CLI override of the configured batch size.
        if args.batch_size:
            hparams.batch_size = args.batch_size

        # Multi-GPU settings
        # NOTE(review): another eval() on a CLI string (expected to be a list
        # of GPU ids, e.g. "[0,1]") — same safety caveat as above.
        GPUs_id = eval(args.GPUs_id)
        hparams.num_GPU = len(GPUs_id)
        tower_grads = []  # per-tower gradient lists, averaged later
        tower_loss = []   # per-tower loss tensors
        models = []       # per-tower model instances

        # global_step starts at -1; presumably so the first optimizer step
        # reports step 0 — TODO confirm against add_optimizer.
        global_step = tf.Variable(-1, name='global_step', trainable=False)
        if hparams.decay_learning_rate:
            learning_rate = _learning_rate_decay(hparams.initial_learning_rate,
                                                 global_step, hparams.num_GPU)
        else:
            learning_rate = tf.convert_to_tensor(hparams.initial_learning_rate)
        # Set up DataFeeder:
        coord = tf.train.Coordinator()
        with tf.variable_scope('datafeeder') as scope:
            input_path = os.path.join(args.base_dir, args.input)
            feeder = DataFeeder(coord, input_path, hparams)
            # Split every feeder output along the batch dimension so each
            # GPU tower consumes an equal shard of the batch.
            inputs = feeder.inputs
            inputs = tf.split(inputs, hparams.num_GPU, 0)
            input_lengths = feeder.input_lengths
            input_lengths = tf.split(input_lengths, hparams.num_GPU, 0)
            mel_targets = feeder.mel_targets
            mel_targets = tf.split(mel_targets, hparams.num_GPU, 0)
            linear_targets = feeder.linear_targets
            linear_targets = tf.split(linear_targets, hparams.num_GPU, 0)

        # Set up model:
        with tf.variable_scope('model') as scope:
            optimizer = tf.train.AdamOptimizer(learning_rate,
                                               hparams.adam_beta1,
                                               hparams.adam_beta2)
            # One tower (model replica) per requested GPU id.
            for i, GPU_id in enumerate(GPUs_id):
                with tf.device('/gpu:%d' % GPU_id):
                    with tf.name_scope('GPU_%d' % GPU_id):

                        # Optional speaker/voice-print feature extracted from
                        # the mel targets by a ResCNN, averaged over time.
                        if hparams.enable_fv1 or hparams.enable_fv2:
                            net = ResCNN(data=mel_targets[i],
                                         batch_size=hparams.batch_size,
                                         hyparam=hparams)
                            net.inference()

                            voice_print_feature = tf.reduce_mean(
                                net.features, 0)
                        else:
                            voice_print_feature = None

                        models.append(None)
                        models[i] = create_model(args.model, hparams)
                        models[i].initialize(
                            inputs=inputs[i],
                            input_lengths=input_lengths[i],
                            mel_targets=mel_targets[i],
                            linear_targets=linear_targets[i],
                            voice_print_feature=voice_print_feature)
                        models[i].add_loss()
                        """L2 weight decay loss."""
                        if args.weight_decay > 0:
                            # Sum of L2 norms over ALL trainable variables
                            # (the DW-name filter below is commented out).
                            costs = []
                            for var in tf.trainable_variables():
                                #if var.op.name.find(r'DW') > 0:
                                costs.append(tf.nn.l2_loss(var))
                                # tf.summary.histogram(var.op.name, var)
                            weight_decay = tf.cast(args.weight_decay,
                                                   tf.float32)
                            # `cost` keeps the pre-decay loss for logging;
                            # the model's loss is augmented in place.
                            cost = models[i].loss
                            models[i].loss += tf.multiply(
                                weight_decay, tf.add_n(costs))
                            cost_pure_wd = tf.multiply(weight_decay,
                                                       tf.add_n(costs))
                        else:
                            cost = models[i].loss
                            cost_pure_wd = tf.constant([0])

                        tower_loss.append(models[i].loss)

                        # Share variables across towers: every tower after
                        # the first reuses the first tower's weights.
                        tf.get_variable_scope().reuse_variables()
                        models[i].add_optimizer(global_step, optimizer)

                        tower_grads.append(models[i].gradients)

            # calculate average gradient
            gradients = average_gradients(tower_grads)

            # Summaries are built from tower 0 only.
            stats = add_stats(models[0], gradients, learning_rate)
            # Purpose of this pause is unclear from here — possibly to let
            # queue threads warm up; TODO confirm / remove.
            time.sleep(10)

        # apply average gradient

        # Run UPDATE_OPS (e.g. batch-norm moving averages) before the
        # gradient application.
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            apply_gradient_op = optimizer.apply_gradients(
                gradients, global_step=global_step)

        # Bookkeeping:
        step = 0
        time_window = ValueWindow(100)   # rolling average of step durations
        loss_window = ValueWindow(100)   # rolling average of losses
        saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

        # Train!
        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            try:
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
                sess.run(tf.global_variables_initializer())

                if args.restore_step:
                    # Restore from a checkpoint if the user requested it.
                    restore_path = '%s-%d' % (checkpoint_path,
                                              args.restore_step)
                    saver.restore(sess, restore_path)
                    log('Resuming from checkpoint: %s at commit: %s' %
                        (restore_path, commit),
                        slack=True)
                else:
                    log('Starting new training run at commit: %s' % commit,
                        slack=True)

                feeder.start_in_session(sess)

                while not coord.should_stop():
                    start_time = time.time()
                    # Tower 0 is used for all logging/inspection runs.
                    model = models[0]

                    step, loss, opt, loss_wd, loss_pure_wd = sess.run([
                        global_step, cost, apply_gradient_op, model.loss,
                        cost_pure_wd
                    ])
                    # NOTE(review): pokes a private feeder counter from the
                    # outside — fragile coupling with DataFeeder internals.
                    feeder._batch_in_queue -= 1
                    log('feed._batch_in_queue: %s' %
                        str(feeder._batch_in_queue),
                        slack=True)

                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, loss_wd=%.05f, loss_pure_wd=%.05f]' % (
                        step, time_window.average, loss, loss_window.average,
                        loss_wd, loss_pure_wd)
                    log(message, slack=(step % args.checkpoint_interval == 0))

                    #if the gradient seems to explode, then restore to the previous step
                    if loss > 2 * loss_window.average or math.isnan(loss):
                        log('recover to the previous checkpoint')
                        #tf.reset_default_graph()
                        # Round (step - 10) down to the last checkpoint
                        # boundary and reload from there.
                        restore_step = int(
                            (step - 10) / args.checkpoint_interval
                        ) * args.checkpoint_interval
                        restore_path = '%s-%d' % (checkpoint_path,
                                                  restore_step)
                        saver.restore(sess, restore_path)
                        continue

                    # Hard failure: abort the run on a clearly diverged loss.
                    if loss > 100 or math.isnan(loss):
                        log('Loss exploded to %.05f at step %d!' %
                            (loss, step),
                            slack=True)
                        raise Exception('Loss Exploded')

                    # NOTE(review): bare except silently swallows ALL summary
                    # failures (including KeyboardInterrupt) — consider
                    # narrowing and logging.
                    try:
                        if step % args.summary_interval == 0:
                            log('Writing summary at step: %d' % step)
                            summary_writer.add_summary(sess.run(stats), step)
                    except:
                        pass

                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' %
                            (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('Saving audio and alignment...')
                        # Pull one example from tower 0 for inspection.
                        input_seq, spectrogram, alignment = sess.run([
                            model.inputs[0], model.linear_outputs[0],
                            model.alignments[0]
                        ])
                        waveform = audio.inv_spectrogram(spectrogram.T)
                        audio.save_wav(
                            waveform,
                            os.path.join(log_dir, 'step-%d-audio.wav' % step))
                        plot.plot_alignment(
                            alignment,
                            os.path.join(log_dir, 'step-%d-align.png' % step),
                            info='%s, %s, %s, step=%d, loss=%.5f' %
                            (args.model, commit, time_string(), step, loss))
                        log('Input: %s' % sequence_to_text(input_seq))

            except Exception as e:
                log('Exiting due to exception: %s' % e, slack=True)
                traceback.print_exc()
                coord.request_stop(e)
# Beispiel #13
# 0
def train(log_dir, args):
    """Single-GPU training loop for a TTS model with LPC targets.

    Selects the dataset path from ``args.dataset``, builds the data feeder
    and model, optionally restores the latest checkpoint and/or a pretrained
    decoder, then runs the session loop with periodic logging, summaries,
    checkpoints, and alignment/LPC dumps.

    Args:
        log_dir: Directory receiving checkpoints, summaries, alignment plots
            and saved LPC arrays; also searched for an existing checkpoint.
        args: Parsed command-line arguments (model name, dataset key,
            intervals, restore flags, ...).

    Raises:
        KeyError: If ``args.dataset`` is not 'bznsyp' or 'ljspeech'.
        Exception: Propagated through the coordinator when the loss explodes
            (> 100 or NaN) or any other error occurs during training.
    """
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    # Map the dataset key to its on-disk directory name.
    DATA_PATH = {'bznsyp': "BZNSYP", 'ljspeech': "LJSpeech-1.1"}[args.dataset]
    input_path = os.path.join(args.base_dir, 'DATA', DATA_PATH, 'training',
                              'train.txt')
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        # This variant predicts LPC features and stop tokens rather than
        # linear spectrograms.
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.lpc_targets, feeder.stop_token_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)   # rolling average of step durations
    loss_window = ValueWindow(100)   # rolling average of losses
    # Effectively keep every checkpoint (max_to_keep=999).
    saver = tf.train.Saver(max_to_keep=999, keep_checkpoint_every_n_hours=2)

    # Train!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                # Note: the latest checkpoint in log_dir is used; the value
                # of args.restore_step itself is only treated as a flag here.
                checkpoint_state = tf.train.get_checkpoint_state(log_dir)
                # restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                if checkpoint_state is not None:
                    saver.restore(sess, checkpoint_state.model_checkpoint_path)
                    log('Resuming from checkpoint: %s at commit: %s' %
                        (checkpoint_state.model_checkpoint_path, commit),
                        slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            if args.restore_decoder:
                # Locate the pretrained checkpoint under ./pretrain by its
                # .meta file, then restore only decoder-related variables.
                models = [
                    f for f in os.listdir('pretrain') if f.find('.meta') != -1
                ]
                decoder_ckpt_path = os.path.join(
                    'pretrain', models[0].replace('.meta', ''))

                global_vars = tf.global_variables()
                var_list = []
                valid_scope = [
                    'model/inference/decoder', 'model/inference/post_cbhg',
                    'model/inference/dense', 'model/inference/memory_layer'
                ]
                # Keep variables under the valid scopes, excluding any
                # attention weights (those are trained from scratch).
                for v in global_vars:
                    if v.name.find('attention') != -1:
                        continue
                    if v.name.find('Attention') != -1:
                        continue
                    for scope in valid_scope:
                        if v.name.startswith(scope):
                            var_list.append(v)
                decoder_saver = tf.train.Saver(var_list)
                decoder_saver.restore(sess, decoder_ckpt_path)
                print('restore pretrained decoder ...')

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                # Abort the run on a clearly diverged loss.
                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    # Dump one example: input text, predicted LPC features,
                    # and the attention alignment.
                    input_seq, lpc_targets, alignment = sess.run([
                        model.inputs[0], model.lpc_outputs[0],
                        model.alignments[0]
                    ])
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, %s, step=%d, loss=%.5f' %
                        (args.model, commit, time_string(), step, loss))
                    np.save(os.path.join(log_dir, 'step-%d-lpc.npy' % step),
                            lpc_targets)
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
# Beispiel #14
# 0
def train(log_dir, args):
  """Basic single-GPU Tacotron training loop.

  Builds the data feeder and model, optionally restores a checkpoint at
  ``args.restore_step``, then trains until the coordinator stops, writing
  summaries, checkpoints, audio samples and alignment plots periodically.

  Args:
      log_dir: Directory for checkpoints, summaries, audio and plots.
      args: Parsed command-line arguments (model name, input path,
          summary/checkpoint intervals, restore step, ...).
  """
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')
  input_path = os.path.join(args.base_dir, args.input)

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    feeder = DataFeeder(coord, input_path, hparams)

  # Set up model:
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
    model.add_loss()
    model.add_optimizer(global_step)
    stats = add_stats(model)

  # Bookkeeping
  time_window = ValueWindow(100)  # rolling average of step durations
  loss_window = ValueWindow(100)  # rolling average of losses
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train
  with tf.Session() as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      if args.restore_step:
        # Restore from the checkpoint at the requested step.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s' % (restore_path), slack=True)

      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt, mel_loss, linear_loss = \
          sess.run([global_step, model.loss, model.optimize, model.mel_loss, model.linear_loss])
        time_window.append(time.time() - start_time)
        loss_window.append(loss)
        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, mel_loss=%.5f, linear_loss=%.5f]' % (
          step, time_window.average, loss, loss_window.average, mel_loss, linear_loss)
        log(message, slack=(step % args.checkpoint_interval == 0))

        # NOTE(review): unlike the sibling loops, this one has no
        # loss-explosion guard — a NaN loss will train on silently.
        if step % args.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          summary_writer.add_summary(sess.run(stats), step)

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')
          # Dump one example: input text, synthesized audio, and alignment.
          input_seq, spectrogram, alignment = sess.run([
            model.inputs[0], model.linear_outputs[0], model.alignments[0]])
          waveform = audio.inv_spectrogram(spectrogram.T)
          audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
          input_seq = sequence_to_text(input_seq)
          plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step), input_seq,
            info='%s, step=%d, loss=%.5f' % (args.model, step, loss), istrain=1)
          log('Input: %s' % input_seq)

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)