Example #1
File: train.py Project: keithito/tacotron
def train(log_dir, args):
  commit = get_git_commit() if args.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')
  input_path = os.path.join(args.base_dir, args.input)
  log('Checkpoint path: %s' % checkpoint_path)
  log('Loading training data from: %s' % input_path)
  log('Using model: %s' % args.model)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    feeder = DataFeeder(coord, input_path, hparams)

  # Set up model:
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
    model.add_loss()
    model.add_optimizer(global_step)
    stats = add_stats(model)

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train!
  with tf.Session() as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      if args.restore_step:
        # Restore from a checkpoint if the user requested it.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt = sess.run([global_step, model.loss, model.optimize])
        time_window.append(time.time() - start_time)
        loss_window.append(loss)
        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
          step, time_window.average, loss, loss_window.average)
        log(message, slack=(step % args.checkpoint_interval == 0))

        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % args.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          summary_writer.add_summary(sess.run(stats), step)

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')
          input_seq, spectrogram, alignment = sess.run([
            model.inputs[0], model.linear_outputs[0], model.alignments[0]])
          waveform = audio.inv_spectrogram(spectrogram.T)
          audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
          plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
            info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
          log('Input: %s' % sequence_to_text(input_seq))

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
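
Note: ValueWindow is a small helper from the same repo, used above to smooth the step time and loss readouts. A minimal sketch consistent with how it is called here (an append() method plus an average property over the last N values); this is an assumption about the helper, not the repo's verbatim code:

class ValueWindow:
  def __init__(self, window_size=100):
    self._window_size = window_size
    self._values = []

  def append(self, x):
    # Keep only the most recent window_size values.
    self._values = (self._values + [x])[-self._window_size:]

  @property
  def average(self):
    return sum(self._values) / max(1, len(self._values))
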
Example #2
def train(log_dir, args):
  commit = get_git_commit() if args.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')
  input_path = os.path.join(args.base_dir, args.input)
  log('Checkpoint path: %s' % checkpoint_path)
  log('Loading training data from: %s' % input_path)
  log('Using model: %s' % args.model)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    feeder = DataFeeder(coord, input_path, hparams)

  # Set up model:
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
    model.add_loss()
    model.add_optimizer(global_step)
    stats = add_stats(model)

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train!

  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      if args.restore_step:
        # Restore from a checkpoint if the user requested it.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt = sess.run([global_step, model.loss, model.optimize])
        time_window.append(time.time() - start_time)
        loss_window.append(loss)
        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
          step, time_window.average, loss, loss_window.average)
        log(message, slack=(step % args.checkpoint_interval == 0))

        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % args.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          summary_writer.add_summary(sess.run(stats), step)

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')
          input_seq, spectrogram, alignment = sess.run([
            model.inputs[0], model.linear_outputs[0], model.alignments[0]])
          waveform = audio.inv_spectrogram(spectrogram.T)
          audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
          plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
            info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
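
The only change from Example #1 is the session config: allow_growth=True makes TensorFlow 1.x allocate GPU memory on demand instead of claiming the whole device at startup, which lets several processes share one GPU. An alternative is to cap the process at a fixed share of device memory (the 0.5 below is an illustrative value, not from the repo):

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # grow the allocation as needed
# Or instead, cap this process at 50% of the device's memory:
# config.gpu_options.per_process_gpu_memory_fraction = 0.5
with tf.Session(config=config) as sess:
  pass  # training loop as above
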
Example #3
def train(log_dir, config):

    data_dirs = [os.path.join(data_path, "data")
                 for data_path in config.data_paths]
    num_speakers = len(data_dirs)
    config.num_test = config.num_test_per_speaker * num_speakers

    if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:
        raise Exception("[!] Unknown model_type for multi-speaker: {}".format(
            hparams.model_type))

    commit = get_git_commit() if config.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')

    log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())
    log('=' * 50)
    #log(' [*] git diff:\n%s' % get_git_diff())
    log('=' * 50)
    log(' [*] Checkpoint path: %s' % checkpoint_path)
    log(' [*] Loading training data from: %s' % data_dirs)
    log(' [*] Using model: %s' % config.model_dir)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        train_feeder = DataFeeder(coord,
                                  data_dirs,
                                  hparams,
                                  config,
                                  32,
                                  data_type='train',
                                  batch_size=hparams.batch_size)
        test_feeder = DataFeeder(coord,
                                 data_dirs,
                                 hparams,
                                 config,
                                 8,
                                 data_type='test',
                                 batch_size=config.num_test)

    # Set up model:
    is_randomly_initialized = config.initialize_path is None
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Counts training steps; trainable=False keeps it out of gradient updates.

    # Build the Tacotron model
    with tf.variable_scope('model') as scope:
        model = create_model(hparams)
        model.initialize(train_feeder.inputs,
                         train_feeder.input_lengths,
                         num_speakers,
                         train_feeder.speaker_id,
                         train_feeder.mel_targets,
                         train_feeder.linear_targets,
                         train_feeder.loss_coeff,
                         is_randomly_initialized=is_randomly_initialized)

        model.add_loss()
        model.add_optimizer(global_step)
        train_stats = add_stats(model, scope_name='stats')  # legacy

    with tf.variable_scope('model', reuse=True) as scope:
        test_model = create_model(hparams)
        test_model.initialize(test_feeder.inputs,
                              test_feeder.input_lengths,
                              num_speakers,
                              test_feeder.speaker_id,
                              test_feeder.mel_targets,
                              test_feeder.linear_targets,
                              test_feeder.loss_coeff,
                              rnn_decoder_test_mode=True,
                              is_randomly_initialized=is_randomly_initialized)
        test_model.add_loss()

    test_stats = add_stats(test_model, model, scope_name='test')
    test_stats = tf.summary.merge([test_stats, train_stats])

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    # Train!
    with tf.Session(config=sess_config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            # First, check whether an existing checkpoint path was given in the arguments.
            if config.load_path:
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.model_dir)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            elif config.initialize_path:
                restore_path = get_most_recent_checkpoint(
                    config.initialize_path)
                saver.restore(sess, restore_path)
                log('Initialized from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)

                # Added by soundbrew: recover the full global step from the checkpoint filename.
                reg_step = re.compile(r"model\.ckpt-(\d+)")
                # reg_initial_target = re.compile("logs/(\w+)_")
                # reg_new_target = re.compile("datasets/(\w+)")

                m1 = reg_step.search(restore_path)
                # m2 = reg_initial_target.search(restore_path)
                # m3 = reg_new_target.search(config.data_paths)

                prev_global_step = m1.group(1)
                # prev_target = m2.group(1)
                # new_target = m3.group(1)

                #if new_target == prev_target:
                #    zero_step_assign = tf.assign(global_step, int(prev_global_step))
                #else:
                # zero_step_assign = tf.assign(global_step, 0)
                zero_step_assign = tf.assign(global_step,
                                             int(prev_global_step))
                sess.run(zero_step_assign)

                start_step = sess.run(global_step)
                log('=' * 50)
                log(' [*] Global step is reset to {}'.format(start_step))
                log('=' * 50)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            start_step = sess.run(global_step)

            train_feeder.start_in_session(sess, start_step)
            test_feeder.start_in_session(sess, start_step)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss_without_coeff, model.optimize],
                    feed_dict=model.get_dummy_feed_dict())

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % config.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % config.summary_interval == 0:
                    log('Writing summary at step: %d' % step)

                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }
                    summary_writer.add_summary(
                        sess.run(test_stats, feed_dict=feed_dict), step)

                if step % config.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % config.test_interval == 0:
                    log('Saving audio and alignment...')
                    num_test = config.num_test

                    fetches = [
                        model.inputs[:num_test],
                        model.linear_outputs[:num_test],
                        model.alignments[:num_test],
                        test_model.inputs[:num_test],
                        test_model.linear_outputs[:num_test],
                        test_model.alignments[:num_test],
                    ]
                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }

                    sequences, spectrograms, alignments, \
                            test_sequences, test_spectrograms, test_alignments = \
                                    sess.run(fetches, feed_dict=feed_dict)

                    save_and_plot(sequences[:1], spectrograms[:1],
                                  alignments[:1], log_dir, step, loss, "train")
                    save_and_plot(test_sequences, test_spectrograms,
                                  test_alignments, log_dir, step, loss, "test")

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
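
get_most_recent_checkpoint is a repo helper that is not shown here. A hypothetical sketch of the behavior the restore calls above rely on (resolving a directory to its newest model.ckpt-<step> prefix), built on the standard tf.train.latest_checkpoint:

import glob
import os
import re

def get_most_recent_checkpoint(checkpoint_dir):
    # Prefer the 'checkpoint' state file that tf.train.Saver maintains.
    path = tf.train.latest_checkpoint(checkpoint_dir)
    if path is not None:
        return path
    # Fall back to scanning model.ckpt-<step> prefixes by step number.
    steps = [int(m.group(1))
             for f in glob.glob(os.path.join(checkpoint_dir, 'model.ckpt-*'))
             for m in [re.search(r'model\.ckpt-(\d+)', f)] if m]
    return os.path.join(checkpoint_dir, 'model.ckpt-%d' % max(steps))
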
Example #4
def train(log_dir, args):
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    # Input paths: one for the positive data and one for the negative data.
    input_path_pos = os.path.join(args.base_dir, args.input_pos)
    input_path_neg = os.path.join(args.base_dir, args.input_neg)

    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading positive training data from: %s' % input_path_pos)
    log('Loading negative training data from: %s' % input_path_neg)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path_pos, input_path_neg, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs_pos, feeder.input_lengths_pos,
                         feeder.mel_targets_pos, feeder.linear_targets_pos,
                         feeder.mel_targets_neg, feeder.linear_targets_neg,
                         feeder.labels_pos, feeder.labels_neg)
        model.add_loss()
        model.add_optimizer(global_step)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            #summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                # train d
                sess.run(model.d_optimize)
                # train g
                step, rec_loss, style_loss, d_loss, g_loss, _ = sess.run([
                    global_step, model.rec_loss, model.style_loss,
                    model.d_loss, model.g_loss, model.g_optimize
                ])
                time_window.append(time.time() - start_time)
                message = 'Step %-7d [%.03f sec/step, rec_loss=%.05f, style_loss=%.05f, d_loss=%.05f, g_loss=%.05f]' % (
                    step, time_window.average, rec_loss, style_loss, d_loss,
                    g_loss)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram_pos, spectrogram_neg, alignment_pos, alignment_neg = sess.run(
                        [
                            model.inputs[0], model.linear_outputs_pos[0],
                            model.linear_outputs_neg[0],
                            model.alignments_pos[0], model.alignments_neg[0]
                        ])

                    waveform_pos = audio.inv_spectrogram(spectrogram_pos.T)
                    waveform_neg = audio.inv_spectrogram(spectrogram_neg.T)
                    audio.save_wav(
                        waveform_pos,
                        os.path.join(log_dir, 'step-%d-audio_pos.wav' % step))
                    audio.save_wav(
                        waveform_neg,
                        os.path.join(log_dir, 'step-%d-audio_neg.wav' % step))
                    plot.plot_alignment(
                        alignment_pos,
                        os.path.join(log_dir, 'step-%d-align_pos.png' % step),
                        info='%s, %s, %s, step=%d, loss=%.5f' %
                        (args.model, commit, time_string(), step, rec_loss))
                    plot.plot_alignment(
                        alignment_neg,
                        os.path.join(log_dir, 'step-%d-align_neg.png' % step),
                        info='%s, %s, %s, step=%d, loss=%.5f' %
                        (args.model, commit, time_string(), step, rec_loss))
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
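
This variant trains adversarially: each iteration runs one discriminator update (model.d_optimize) and then one generator update (model.g_optimize), fetching the losses alongside the generator step. A sketch of how such paired optimizers are typically wired in TF 1.x; the scope names and learning rate are illustrative assumptions, not taken from this repo:

d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='model/discriminator')
g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='model/generator')
# Each optimizer updates only its own sub-network; only the generator step
# advances global_step, so one alternating iteration counts as one step.
d_optimize = tf.train.AdamOptimizer(1e-4).minimize(d_loss, var_list=d_vars)
g_optimize = tf.train.AdamOptimizer(1e-4).minimize(g_loss, var_list=g_vars,
                                                   global_step=global_step)
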
Example #5
def train(log_dir, config):
    sub_dirs =['A11', 'A12', 'A13', 'A14', 'A19', 'A2', 'A22', 'A23', 'A32', 'A33', 'A34', 'A36', 'A4', 'A6', 'A7', 'A8',
               'B11', 'B12', 'B15','B2', 'B22', 'B31', 'B32', 'B33', 'B4', 'B6', 'B7', 'B8',
               'C12', 'C13', 'C14', 'C17', 'C18', 'C19', 'C2', 'C20', 'C21', 'C22', 'C23', 'C31', 'C32', 'C4', 'C6', 'C7', 'C8',
               'D11', 'D12', 'D13', 'D21', 'D31', 'D32', 'D4', 'D6', 'D7', 'D8',]

    data_dirs = [os.path.join(config.data_paths[0], sub_dir) for sub_dir in sub_dirs]
    num_speakers = len(data_dirs)
    setattr(hparams, "num_speakers", len(data_dirs))
    config.num_test = config.num_test_per_speaker * num_speakers

    if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:
        raise Exception("[!] Unknown model_type for multi-speaker: {}".format(hparams.model_type))

    commit = get_git_commit() if config.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')

    log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())
    log('='*50)
    #log(' [*] git diff:\n%s' % get_git_diff())
    log('='*50)
    log(' [*] Checkpoint path: %s' % checkpoint_path)
    log(' [*] Loading training data from: %s' % data_dirs)
    log(' [*] Using model: %s' % config.model_dir)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        train_feeder = DataFeeder(
                coord, data_dirs, hparams, config, 32,
                data_type='train', batch_size=hparams.batch_size)
        test_feeder = DataFeeder(
                coord, data_dirs, hparams, config, 8,
                data_type='test', batch_size=config.num_test)

    # Set up model:
    is_randomly_initialized = config.initialize_path is None
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('model') as scope:
        model = create_model(hparams)
        model.initialize(
                train_feeder.inputs, train_feeder.input_lengths,
                num_speakers,  train_feeder.speaker_id,
                train_feeder.mel_targets, train_feeder.linear_targets,
                train_feeder.loss_coeff,
                is_randomly_initialized=is_randomly_initialized)

        model.add_loss()
        model.add_optimizer(global_step)
        train_stats = add_stats(model, scope_name='stats') # legacy

    with tf.variable_scope('model', reuse=True) as scope:
        test_model = create_model(hparams)
        test_model.initialize(
                test_feeder.inputs, test_feeder.input_lengths,
                num_speakers, test_feeder.speaker_id,
                test_feeder.mel_targets, test_feeder.linear_targets,
                test_feeder.loss_coeff, rnn_decoder_test_mode=True,
                is_randomly_initialized=is_randomly_initialized)
        test_model.add_loss()

    test_stats = add_stats(test_model, model, scope_name='test')
    test_stats = tf.summary.merge([test_stats, train_stats])

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)

    sess_config = tf.ConfigProto(
            log_device_placement=False,
            allow_soft_placement=True)
    sess_config.gpu_options.allow_growth=True

    # Train!
    with tf.Session(config=sess_config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if config.load_path:
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.model_dir)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
            elif config.initialize_path:
                restore_path = get_most_recent_checkpoint(config.initialize_path)
                saver.restore(sess, restore_path)
                log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)

                zero_step_assign = tf.assign(global_step, 0)
                sess.run(zero_step_assign)

                start_step = sess.run(global_step)
                log('='*50)
                log(' [*] Global step is reset to {}'.format(start_step))
                log('='*50)
            else:
                log('Starting new training run at commit: %s' % commit, slack=True)

            start_step = sess.run(global_step)

            train_feeder.start_in_session(sess, start_step)
            test_feeder.start_in_session(sess, start_step)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                        [global_step, model.loss_without_coeff, model.optimize],
                        feed_dict=model.get_dummy_feed_dict())

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                        step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % config.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
                    raise Exception('Loss Exploded')

                if step % config.summary_interval == 0:
                    log('Writing summary at step: %d' % step)

                    feed_dict = {
                            **model.get_dummy_feed_dict(),
                            **test_model.get_dummy_feed_dict()
                    }
                    summary_writer.add_summary(sess.run(
                            test_stats, feed_dict=feed_dict), step)

                if step % config.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % config.test_interval == 0:
                    log('Saving audio and alignment...')
                    num_test = config.num_test

                    fetches = [
                            model.inputs[:num_test],
                            model.linear_outputs[:num_test],
                            model.alignments[:num_test],
                            test_model.inputs[:num_test],
                            test_model.linear_outputs[:num_test],
                            test_model.alignments[:num_test],
                    ]
                    feed_dict = {
                            **model.get_dummy_feed_dict(),
                            **test_model.get_dummy_feed_dict()
                    }

                    sequences, spectrograms, alignments, \
                            test_sequences, test_spectrograms, test_alignments = \
                                    sess.run(fetches, feed_dict=feed_dict)

                    save_and_plot(sequences[:1], spectrograms[:1], alignments[:1],
                            log_dir, step, loss, "train")
                    save_and_plot(test_sequences, test_spectrograms, test_alignments,
                            log_dir, step, loss, "test")

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
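
Unlike Example #3, this variant resets global_step to 0 after initializing from a checkpoint. To keep the original step count instead, the filename parsing from Example #3 works here too (a sketch, assuming the standard model.ckpt-<step> naming):

import re

m = re.search(r'model\.ckpt-(\d+)', restore_path)
if m:
    sess.run(tf.assign(global_step, int(m.group(1))))
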
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/hparams.yaml')
    parser.add_argument('-load_model', type=str, default=None)
    parser.add_argument('-model_name', type=str, default='P_S_Transformer_debug',
                        help='model name')
    # parser.add_argument('-batches_per_allreduce', type=int, default=1,
    #                     help='number of batches processed locally before '
    #                          'executing allreduce across workers; it multiplies '
    #                          'total batch size.')
    parser.add_argument('-num_workers', type=int, default=0,
                        help='how many subprocesses to use for data loading. '
                             '0 means that the data will be loaded in the main process')
    parser.add_argument('-log', type=str, default='train.log')
    opt = parser.parse_args()

    configfile = open(opt.config)
    config = AttrDict(yaml.load(configfile,Loader=yaml.FullLoader))

    log_name = opt.model_name or config.model.name
    log_folder = os.path.join(os.getcwd(),'logdir/logging',log_name)
    if not os.path.isdir(log_folder):
        os.makedirs(log_folder)  # create intermediate directories as needed
    logger = init_logger(log_folder+'/'+opt.log)

    # TODO: build dataloader
    train_datafeeder = DataFeeder(config,'debug')

    # TODO: build model or load pre-trained model
    global global_step
    global_step = 0
    learning_rate = CustomSchedule(config.model.d_model)
    # learning_rate = 0.00002
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=config.optimizer.beta1, beta_2=config.optimizer.beta2,
                                         epsilon=config.optimizer.epsilon)
    logger.info('config.optimizer.beta1:' + str(config.optimizer.beta1))
    logger.info('config.optimizer.beta2:' + str(config.optimizer.beta2))
    logger.info('config.optimizer.epsilon:' + str(config.optimizer.epsilon))
    # print(str(config))
    model = Speech_transformer(config=config,logger=logger)

    #Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every n epochs.
    checkpoint_path = log_folder
    ckpt = tf.train.Checkpoint(transformer=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        logger.info('Latest checkpoint restored!!')
    else:
        logger.info('Start new run')


    # define metrics and summary writer
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    # summary_writer = tf.keras.callbacks.TensorBoard(log_dir=log_folder)
    summary_writer = summary_ops_v2.create_file_writer_v2(log_folder+'/train')


    # @tf.function
    def train_step(batch_data):
        inp = batch_data['the_inputs'] # batch*time*feature
        tar = batch_data['the_labels'] # batch*time
        # inp_len = batch_data['input_length']
        # tar_len = batch_data['label_length']
        gtruth = batch_data['ground_truth']
        tar_inp = tar
        tar_real = gtruth
        # enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp[:,:,0], tar_inp)
        combined_mask = create_combined_mask(tar=tar_inp)
        with tf.GradientTape() as tape:
            predictions, _ = model(inp, tar_inp, True, None,
                                   combined_mask, None)
            # logger.info('config.train.label_smoothing_epsilon:' + str(config.train.label_smoothing_epsilon))
            loss = LableSmoothingLoss(tar_real, predictions,config.model.vocab_size,config.train.label_smoothing_epsilon)
        gradients = tape.gradient(loss, model.trainable_variables)
        # Clip by global norm and apply the clipped gradients.
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
        optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables))
        train_loss(loss)
        train_accuracy(tar_real, predictions)

    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)
    logger.info('config.train.epoches:' + str(config.train.epoches))
    first_time = True
    for epoch in range(config.train.epoches):
        logger.info('start epoch '+ str(epoch))
        logger.info('total wavs: '+ str(len(train_datafeeder)))
        logger.info('batch size: ' + str(train_datafeeder.batch_size))
        logger.info('batch per epoch: ' + str(len(train_datafeeder)//train_datafeeder.batch_size))
        train_data = train_datafeeder.get_batch()
        start_time = time.time()
        train_loss.reset_states()
        train_accuracy.reset_states()

        for step in range(len(train_datafeeder)//train_datafeeder.batch_size):
            batch_data = next(train_data)
            step_time = time.time()
            train_step(batch_data)
            if first_time:
                model.summary()
                first_time=False
            time_window.append(time.time()-step_time)
            loss_window.append(train_loss.result())
            acc_window.append(train_accuracy.result())
            message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f]' % (
                    global_step, time_window.average, train_loss.result(), loss_window.average, train_accuracy.result(),acc_window.average)
            logger.info(message)

            if global_step % 10 == 0:
                with summary_ops_v2.always_record_summaries():
                    with summary_writer.as_default():
                        summary_ops_v2.scalar('train_loss', train_loss.result(), step=global_step)
                        summary_ops_v2.scalar('train_acc', train_accuracy.result(), step=global_step)

            global_step += 1

        ckpt_save_path = ckpt_manager.save()
        logger.info('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
        logger.info('Time taken for 1 epoch: {} secs\n'.format(time.time() - start_time))
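
CustomSchedule is not defined in this snippet. In the TensorFlow Transformer tutorials, a schedule of this name implements the warmup learning rate from "Attention Is All You Need", scaled by d_model; a sketch under that assumption (warmup_steps=4000 is the usual default):

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)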