Example no. 1
def train():
    """训练LSTM + CTC的语音识别系统.

  """
    data_dir = FLAGS.data_dir
    batch_size = FLAGS.batch_size
    dev_batch_size = FLAGS.dev_batch_size

    dev_config_file = path.join(data_dir, FLAGS.dev_config_file)
    train_config_file = path.join(data_dir, FLAGS.train_config_file)
    train_data_config = asr.read_data_config(train_config_file)
    dev_data_config = asr.read_data_config(dev_config_file)

    # Initialize bucket sizes and the bucket readers.
    _buckets = _get_buckets(train_data_config.frame_max_length,
                            train_data_config.label_max_length)
    train_reader = asr.BucketReader(_buckets, train_data_config.feature_cols,
                                    batch_size)
    dev_reader = asr.BucketReader(_buckets, dev_data_config.feature_cols,
                                  dev_batch_size)

    # train files
    train_feature_file = os.path.join(data_dir, FLAGS.train_feature_file)
    train_feature_len_file = os.path.join(data_dir,
                                          FLAGS.train_feature_len_file)
    train_label_file = os.path.join(data_dir, FLAGS.train_label_file)
    train_label_len_file = os.path.join(data_dir, FLAGS.train_label_len_file)

    # train file open
    train_feature_fr = open(train_feature_file, "r")
    train_feature_len_fr = open(train_feature_len_file, "r")
    train_label_fr = open(train_label_file, "r")
    train_label_len_fr = open(train_label_len_file, "r")

    # dev files
    dev_feature_file = os.path.join(data_dir, FLAGS.dev_feature_file)
    dev_feature_len_file = os.path.join(data_dir, FLAGS.dev_feature_len_file)
    dev_label_file = os.path.join(data_dir, FLAGS.dev_label_file)
    dev_label_len_file = os.path.join(data_dir, FLAGS.dev_label_len_file)

    # dev file open
    dev_feature_fr = open(dev_feature_file, "r")
    dev_feature_len_fr = open(dev_feature_len_file, "r")
    dev_label_fr = open(dev_label_file, "r")
    dev_label_len_fr = open(dev_label_len_file, "r")

    # Data configuration: number of examples and number of batches per epoch.
    dev_examples_num = dev_data_config.example_number
    dev_num_batches_per_epoch = int(dev_examples_num / dev_batch_size)
    train_num_examples = train_data_config.example_number
    train_num_batches_per_epoch = int(train_num_examples / batch_size)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False,
                                  dtype=tf.int32)

    optimizer = tf.train.AdamOptimizer(FLAGS.initial_learning_rate,
                                       name="AdamOpt")
    # optimizer = tf.train.MomentumOptimizer(lr, 0.9)

    # placeholder
    train_seq_len = tf.placeholder(tf.int32, [batch_size],
                                   name="seq_len_placeholder")
    train_label_len = tf.placeholder(tf.int32, [batch_size],
                                     name="label_len_placeholder")
    train_feature_area = tf.placeholder(tf.float32, [None, None],
                                        name="feature_area_placeholder")
    train_label_area = tf.placeholder(
        tf.float32, [batch_size, train_data_config.label_max_length],
        name="label_area_placeholder")

    dev_seq_len = tf.placeholder(tf.int32, [dev_batch_size],
                                 name="dev_seq_len_placeholder")
    dev_label_len = tf.placeholder(tf.int32, [dev_batch_size],
                                   name="dev_label_len_placeholder")
    dev_feature_area = tf.placeholder(tf.float32, [None, None],
                                      name="dev_feature_area_placeholder")
    dev_label_area = tf.placeholder(
        tf.float32, [dev_batch_size, train_data_config.label_max_length],
        name="dev_label_area_placeholder")

    with tf.variable_scope("inference") as scope:
        train_ctc_in, train_targets, train_seq_len = asr.rnn(
            train_data_config, batch_size, train_feature_area, train_seq_len,
            train_label_area, train_label_len)
        scope.reuse_variables()
        dev_ctc_in, dev_targets, dev_seq_len = asr.rnn(
            dev_data_config, dev_batch_size, dev_feature_area, dev_seq_len,
            dev_label_area, dev_label_len)

    train_ctc_losses = tf.nn.ctc_loss(train_ctc_in, train_targets,
                                      train_seq_len)
    train_cost = tf.reduce_mean(train_ctc_losses, name="train_cost")

    # Clip gradients to a bounded range.
    grads_and_vars = optimizer.compute_gradients(train_cost)
    capped_grads_and_vars = [(tf.clip_by_value(gv[0], -50.0, 50.0), gv[1])
                             for gv in grads_and_vars]
    train_op = optimizer.apply_gradients(capped_grads_and_vars,
                                         global_step=global_step)

    tf.scalar_summary("train_cost", train_cost)

    # dev evaluation ops
    dev_decoded, dev_log_prob = tf.nn.ctc_greedy_decoder(
        dev_ctc_in, dev_seq_len)
    dev_edit_distance = tf.edit_distance(tf.to_int32(dev_decoded[0]),
                                         dev_targets,
                                         normalize=False)
    dev_batch_error_count = tf.reduce_sum(dev_edit_distance)
    dev_batch_label_count = tf.shape(dev_targets.values)[0]

    # train
    train_decoded, train_log_prob = tf.nn.ctc_greedy_decoder(
        train_ctc_in, train_seq_len)
    train_edit_distance = tf.edit_distance(tf.to_int32(train_decoded[0]),
                                           train_targets,
                                           normalize=False)
    train_batch_error_count = tf.reduce_sum(train_edit_distance)
    train_batch_label_count = tf.shape(train_targets.values)[0]

    # initialization ops
    init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    saver = tf.train.Saver()
    summary_op = tf.merge_all_summaries()

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as session:
        if FLAGS.reload_model == 1:
            ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
            saver.restore(session, ckpt.model_checkpoint_path)

            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])

            logging.info("从%s载入模型参数, global_step = %d",
                         ckpt.model_checkpoint_path, global_step)
        else:
            logging.info("Created model with fresh parameters.")
            session.run(init)
            session.run(local_init)

        summary_writer = tf.train.SummaryWriter(FLAGS.model_dir, session.graph)

        step = 0
        epoch = 0
        while True:
            step += 1
            patch_data, bucket_id = train_reader.read_data(
                train_feature_fr, train_feature_len_fr, train_label_fr,
                train_label_len_fr)
            feature_ids, seq_len_ids, label_ids, label_len_ids = patch_data
            feed_dict = {
                train_feature_area: feature_ids,
                train_seq_len: seq_len_ids,
                train_label_area: label_ids,
                train_label_len: label_len_ids
            }
            _, loss = session.run([train_op, train_cost], feed_dict=feed_dict)

            if step % 20 == 0:
                start_time = time.time()
                patch_data, bucket_id = train_reader.read_data(
                    train_feature_fr, train_feature_len_fr, train_label_fr,
                    train_label_len_fr)
                feature_ids, seq_len_ids, label_ids, label_len_ids = patch_data
                feed_dict = {
                    train_feature_area: feature_ids,
                    train_seq_len: seq_len_ids,
                    train_label_area: label_ids,
                    train_label_len: label_len_ids
                }
                train_error_count_value, train_label_count = session.run(
                    [train_batch_error_count, train_batch_label_count],
                    feed_dict=feed_dict)
                train_acc_ratio = (train_label_count -
                                   train_error_count_value) / train_label_count
                duration = time.time() - start_time
                examples_per_sec = batch_size / duration

                logging.info(
                    'step %d, loss = %.2f (%.1f examples/sec) bucketid=%d train_acc= %.3f',
                    step, loss, examples_per_sec, bucket_id, train_acc_ratio)

            if step % train_num_batches_per_epoch == 0:
                saver.save(session,
                           FLAGS.model_dir + "model.ckpt",
                           global_step=step)
                logging.info("保存模型参数.")
                epoch += 1
                dev_error_count = 0
                dev_label_count = 0

                for batch in range(dev_num_batches_per_epoch):
                    patch_data, bucket_id = dev_reader.read_data(
                        dev_feature_fr, dev_feature_len_fr, dev_label_fr,
                        dev_label_len_fr)
                    feature_ids, seq_len_ids, label_ids, label_len_ids = patch_data
                    feed_dict = {
                        dev_feature_area: feature_ids,
                        dev_seq_len: seq_len_ids,
                        dev_label_area: label_ids,
                        dev_label_len: label_len_ids
                    }

                    dev_error_count_value, dev_label_count_value = session.run(
                        [dev_batch_error_count, dev_batch_label_count],
                        feed_dict=feed_dict)

                    dev_error_count += dev_error_count_value
                    dev_label_count += dev_label_count_value

                dev_acc_ratio = (dev_label_count -
                                 dev_error_count) / dev_label_count

                logging.info("eval: step = %d epoch = %d eval_acc = %.3f ",
                             step, epoch, dev_acc_ratio)

    train_feature_fr.close()
    train_feature_len_fr.close()
    train_label_fr.close()
    train_label_len_fr.close()

    dev_feature_fr.close()
    dev_feature_len_fr.close()
    dev_label_fr.close()
    dev_label_len_fr.close()
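
The helper `_get_buckets` used above is referenced but not shown in the listing. The sketch below is purely hypothetical: it assumes buckets are evenly spaced (frame_length, label_length) pairs that end at the maximum lengths from the training config, which fits how `BucketReader` is constructed here but is not taken from the original code.

def _get_buckets(frame_max_length, label_max_length, num_buckets=5):
    """Hypothetical: evenly spaced (frame_length, label_length) buckets."""
    buckets = []
    for i in range(1, num_buckets + 1):
        buckets.append((int(frame_max_length * i / num_buckets),
                        int(label_max_length * i / num_buckets)))
    return buckets

# _get_buckets(1000, 50) -> [(200, 10), (400, 20), (600, 30), (800, 40), (1000, 50)]
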
Example no. 2
def test():
  """单独的,测试模型效果.

  """
  data_dir = FLAGS.data_dir
  cv_batch_size = FLAGS.cv_batch_size
  cv_maxsize_file = path.join(data_dir, FLAGS.cv_maxsize_file)
  dev_data_config = asr.read_data_config(cv_maxsize_file)
  dev_data = asr.get_dev_data(dev_data_config, cv_batch_size)
  dev_examples_num = dev_data_config.example_number
  dev_num_batches_per_epoch = int(dev_examples_num / cv_batch_size)

  with tf.variable_scope("inference") as scope:
    dev_ctc_in, dev_targets, dev_seq_len = asr.rnn(dev_data, dev_data_config,
                                                   cv_batch_size)

    dev_decoded, dev_log_prob = tf.nn.ctc_greedy_decoder(dev_ctc_in,
                                                         dev_seq_len)

  edit_distance = tf.edit_distance(tf.to_int32(dev_decoded[0]), dev_targets,
                                   normalize=False)

  batch_error_count = tf.reduce_sum(edit_distance, name="batch_error_count")
  batch_label_count = tf.shape(dev_targets.values)[0]

  local_init = tf.local_variables_initializer()
  saver = tf.train.Saver()

  gpu_options = tf.GPUOptions(
    per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)

  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as session:

    ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
    saver.restore(session, ckpt.model_checkpoint_path)
    # Local variables (e.g. the epoch counters used by the input queues) are not
    # restored from the checkpoint, so initialize them explicitly.
    session.run(local_init)

    global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])

    logging.info("从%s载入模型参数, global_step = %d",
                 ckpt.model_checkpoint_path, global_step)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=session, coord=coord)

    try:
      dev_error_count = 0
      dev_label_count = 0

      for batch in range(dev_num_batches_per_epoch):
        cv_error_count_value, cv_label_count = session.run(
          [batch_error_count, batch_label_count])

        dev_error_count += cv_error_count_value
        dev_label_count += cv_label_count

      dev_acc_ratio = (dev_label_count - dev_error_count) / dev_label_count

      logging.info("eval:  eval_acc = %.3f ", dev_acc_ratio)
    except tf.errors.OutOfRangeError:
      logging.info("训练完成.")
    finally:
      coord.request_stop()

    coord.join(threads)
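
The accuracy reported above is one minus the label error rate: the summed edit distance divided by the number of reference labels. A minimal standalone sketch (TensorFlow 1.x; the toy SparseTensors are invented for illustration) of the same computation:

import tensorflow as tf

# Reference labels [3, 7, 5] vs. decoder output [3, 5] for one utterance.
truth = tf.SparseTensor(indices=[[0, 0], [0, 1], [0, 2]],
                        values=[3, 7, 5], dense_shape=[1, 3])
hyp = tf.SparseTensor(indices=[[0, 0], [0, 1]],
                      values=[3, 5], dense_shape=[1, 2])

edit_dist = tf.edit_distance(hyp, truth, normalize=False)  # errors per utterance
error_count = tf.reduce_sum(edit_dist)                     # total label errors
label_count = tf.shape(truth.values)[0]                    # total reference labels

with tf.Session() as sess:
    errors, labels = sess.run([error_count, label_count])
    print("acc = %.3f" % ((labels - errors) / labels))     # 2 of 3 labels correct here
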
Example no. 3
def train():
    """训练LSTM + CTC的语音识别系统.
  
  """

    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)

    server = tf.train.Server(cluster,
                             config=tf.ConfigProto(gpu_options=gpu_options),
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    issync = FLAGS.issync

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        data_dir = FLAGS.data_dir
        cv_maxsize_file = path.join(data_dir, FLAGS.cv_maxsize_file)
        train_maxsize_file = path.join(data_dir, FLAGS.train_maxsize_file)

        batch_size = FLAGS.batch_size
        cv_batch_size = FLAGS.cv_batch_size

        train_data_config = asr.read_data_config(train_maxsize_file)
        dev_data_config = asr.read_data_config(cv_maxsize_file)
        train_data = asr.distort_inputs(train_data_config, batch_size)
        dev_data = asr.get_dev_data(dev_data_config, cv_batch_size)

        dev_examples_num = dev_data_config.example_number
        dev_num_batches_per_epoch = int(dev_examples_num / cv_batch_size)
        train_num_examples = train_data_config.example_number
        train_num_batches_per_epoch = int(train_num_examples / batch_size)

        # Number of steps after which the learning rate decays.
        decay_steps = int(train_num_batches_per_epoch *
                          FLAGS.num_epochs_per_decay)

        initial_learning_rate = FLAGS.initial_learning_rate
        learning_rate_decay_factor = FLAGS.learning_rate_decay_factor

        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % FLAGS.task_index,
                    cluster=cluster)):
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False,
                dtype=tf.int32)
            #optimizer = tf.train.GradientDescentOptimizer(initial_learning_rate)
            optimizer = tf.train.AdamOptimizer(initial_learning_rate)

            with tf.variable_scope("inference") as scope:
                ctc_input, train_targets, train_seq_len = asr.rnn(
                    train_data, train_data_config, batch_size)

                scope.reuse_variables()
                dev_ctc_in, dev_targets, dev_seq_len = asr.rnn(
                    dev_data, dev_data_config, cv_batch_size)

            example_losses = tf.nn.ctc_loss(ctc_input, train_targets,
                                            train_seq_len)
            train_cost = tf.reduce_mean(example_losses)

            if issync == 1:
                rep_op = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=len(worker_hosts),
                    replica_id=FLAGS.task_index,
                    total_num_replicas=len(worker_hosts),
                    use_locking=True)
                train_op = rep_op.minimize(train_cost, global_step=global_step)
                init_token_op = rep_op.get_init_tokens_op()
                chief_queue_runner = rep_op.get_chief_queue_runner()
            else:
                train_op = optimizer.minimize(train_cost,
                                              global_step=global_step)

            tf.scalar_summary("train_cost", train_cost)

            train_decoded, train_log_prob = tf.nn.ctc_greedy_decoder(
                ctc_input, train_seq_len)
            dev_decoded, dev_log_prob = tf.nn.ctc_greedy_decoder(
                dev_ctc_in, dev_seq_len)

            train_edit_distance = tf.edit_distance(tf.to_int32(
                train_decoded[0]),
                                                   train_targets,
                                                   normalize=False)
            edit_distance = tf.edit_distance(tf.to_int32(dev_decoded[0]),
                                             dev_targets,
                                             normalize=False)

            train_batch_error_count = tf.reduce_sum(train_edit_distance)
            train_batch_label_count = tf.shape(train_targets.values)[0]
            batch_error_count = tf.reduce_sum(edit_distance)
            batch_label_count = tf.shape(dev_targets.values)[0]

            init_op = tf.global_variables_initializer()
            local_init = tf.local_variables_initializer()
            saver = tf.train.Saver()
            summary_op = tf.merge_all_summaries()

            sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                     logdir=FLAGS.model_dir,
                                     init_op=init_op,
                                     local_init_op=local_init,
                                     summary_op=summary_op,
                                     saver=saver,
                                     global_step=global_step,
                                     save_model_secs=600)
            with sv.prepare_or_wait_for_session(server.target) as sess:

                if FLAGS.task_index == 0 and issync == 1:
                    sv.start_queue_runners(sess, [chief_queue_runner])
                    sess.run(init_token_op)

                summary_writer = tf.train.SummaryWriter(
                    FLAGS.model_dir, sess.graph)

                step = 0
                valid_step = 0
                train_acc_step = 0
                epoch = 0
                while not sv.should_stop() and step < 100000000:

                    coord = tf.train.Coordinator()
                    threads = tf.train.start_queue_runners(sess=sess,
                                                           coord=coord)

                    try:
                        while not coord.should_stop():
                            train_cost_value, _, step = sess.run(
                                [train_cost, train_op, global_step])

                            if step % 100 == 0:
                                logging.info("step: %d,  loss: %f" %
                                             (step, train_cost_value))

                            if step - train_acc_step > 1000 and FLAGS.task_index == 0:
                                train_acc_step = step
                                train_error_count_value, train_label_count = sess.run(
                                    [
                                        train_batch_error_count,
                                        train_batch_label_count
                                    ])
                                train_acc_ratio = (train_label_count -
                                                   train_error_count_value
                                                   ) / train_label_count
                                logging.info(
                                    "eval: step = %d train_acc = %.3f ", step,
                                    train_acc_ratio)

                            # Evaluate on the dev set after roughly one epoch of steps, and only on the chief worker.
                            # Training is distributed and steps are spread across workers, so a
                            # `step % N == 0` check may never fire on any single worker.
                            if step - valid_step > train_num_batches_per_epoch and FLAGS.task_index == 0:
                                epoch += 1

                                valid_step = step
                                dev_error_count = 0
                                dev_label_count = 0

                                for batch in range(dev_num_batches_per_epoch):
                                    cv_error_count_value, cv_label_count = sess.run(
                                        [batch_error_count, batch_label_count])

                                    dev_error_count += cv_error_count_value
                                    dev_label_count += cv_label_count

                                dev_acc_ratio = (
                                    dev_label_count -
                                    dev_error_count) / dev_label_count

                                logging.info(
                                    "epoch: %d eval: step = %d eval_acc = %.3f ",
                                    epoch, step, dev_acc_ratio)

                    except tf.errors.OutOfRangeError:
                        print("Done training after reading all data")
                    finally:
                        coord.request_stop()

                    # Wait for threads to exit
                    coord.join(threads)
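
The comment above about avoiding `step % N == 0` is worth a concrete illustration. In the plain-Python sketch below (the step numbers are made up), an asynchronous worker only observes a subset of the global steps, so a modulo test may never fire, while the `step - valid_step > N` gap test fires roughly once per epoch:

observed_steps = [3, 7, 12, 18, 25, 31, 40, 47]  # global steps seen by one worker
N = 10                                           # pretend one epoch is 10 steps

valid_step = 0
for step in observed_steps:
    by_modulo = (step % N == 0)       # never True for this worker
    by_gap = (step - valid_step) > N  # True at 12, 25, 40 -> roughly once per "epoch"
    if by_gap:
        valid_step = step
    print(step, by_modulo, by_gap)
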
Example no. 4
def train():
    """训练LSTM + CTC的语音识别系统.
  
  """

    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)

    server = tf.train.Server(cluster,
                             config=tf.ConfigProto(gpu_options=gpu_options),
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    issync = FLAGS.issync

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        data_dir = FLAGS.data_dir
        batch_size = FLAGS.batch_size
        dev_batch_size = FLAGS.dev_batch_size

        dev_config_file = path.join(data_dir, FLAGS.dev_config_file)
        train_config_file = path.join(data_dir, FLAGS.train_config_file)
        train_data_config = asr.read_data_config(train_config_file)
        dev_data_config = asr.read_data_config(dev_config_file)

        # Initialize bucket sizes and the bucket readers.
        _buckets = _get_buckets(train_data_config.frame_max_length,
                                train_data_config.label_max_length)
        train_reader = asr.BucketReader(_buckets,
                                        train_data_config.feature_cols,
                                        batch_size)
        dev_reader = asr.BucketReader(_buckets, dev_data_config.feature_cols,
                                      dev_batch_size)

        # train files
        train_feature_file = os.path.join(data_dir, FLAGS.train_feature_file)
        train_feature_len_file = os.path.join(data_dir,
                                              FLAGS.train_feature_len_file)
        train_label_file = os.path.join(data_dir, FLAGS.train_label_file)
        train_label_len_file = os.path.join(data_dir,
                                            FLAGS.train_label_len_file)

        # train file open
        train_feature_fr = open(train_feature_file, "r")
        train_feature_len_fr = open(train_feature_len_file, "r")
        train_label_fr = open(train_label_file, "r")
        train_label_len_fr = open(train_label_len_file, "r")

        # dev files
        dev_feature_file = os.path.join(data_dir, FLAGS.dev_feature_file)
        dev_feature_len_file = os.path.join(data_dir,
                                            FLAGS.dev_feature_len_file)
        dev_label_file = os.path.join(data_dir, FLAGS.dev_label_file)
        dev_label_len_file = os.path.join(data_dir, FLAGS.dev_label_len_file)

        # dev file open
        dev_feature_fr = open(dev_feature_file, "r")
        dev_feature_len_fr = open(dev_feature_len_file, "r")
        dev_label_fr = open(dev_label_file, "r")
        dev_label_len_fr = open(dev_label_len_file, "r")

        # Data configuration: number of examples and number of batches per epoch.
        dev_examples_num = dev_data_config.example_number
        dev_num_batches_per_epoch = int(dev_examples_num / dev_batch_size)
        train_num_examples = train_data_config.example_number
        train_num_batches_per_epoch = int(train_num_examples / batch_size)

        initial_learning_rate = FLAGS.initial_learning_rate

        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % FLAGS.task_index,
                    cluster=cluster)):
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False,
                dtype=tf.int32)

            #optimizer = tf.train.GradientDescentOptimizer(initial_learning_rate)
            optimizer = tf.train.AdamOptimizer(initial_learning_rate)

            # placeholder
            train_seq_len = tf.placeholder(tf.int32, [batch_size],
                                           name="seq_len_placeholder")
            train_label_len = tf.placeholder(tf.int32, [batch_size],
                                             name="label_len_placeholder")
            train_feature_area = tf.placeholder(
                tf.float32, [None, None], name="feature_area_placeholder")
            train_label_area = tf.placeholder(
                tf.float32, [batch_size, train_data_config.label_max_length],
                name="label_area_placeholder")

            dev_seq_len = tf.placeholder(tf.int32, [dev_batch_size],
                                         name="dev_seq_len_placeholder")
            dev_label_len = tf.placeholder(tf.int32, [dev_batch_size],
                                           name="dev_label_len_placeholder")
            dev_feature_area = tf.placeholder(
                tf.float32, [None, None], name="dev_feature_area_placeholder")
            dev_label_area = tf.placeholder(
                tf.float32,
                [dev_batch_size, train_data_config.label_max_length],
                name="dev_label_area_placeholder")

            with tf.variable_scope("inference") as scope:
                train_ctc_in, train_targets, train_seq_len = asr.rnn(
                    train_data_config, batch_size, train_feature_area,
                    train_seq_len, train_label_area, train_label_len)

                scope.reuse_variables()
                dev_ctc_in, dev_targets, dev_seq_len = asr.rnn(
                    dev_data_config, dev_batch_size, dev_feature_area,
                    dev_seq_len, dev_label_area, dev_label_len)

            train_ctc_losses = tf.nn.ctc_loss(train_ctc_in, train_targets,
                                              train_seq_len)
            train_cost = tf.reduce_mean(train_ctc_losses, name="train_cost")
            # Clip gradients to a bounded range.
            grads_and_vars = optimizer.compute_gradients(train_cost)
            capped_grads_and_vars = [(tf.clip_by_value(gv[0], -50.0,
                                                       50.0), gv[1])
                                     for gv in grads_and_vars]

            if issync == 1:
                # Synchronous mode: aggregate gradients across replicas before applying the update.
                rep_op = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=len(worker_hosts),
                    replica_id=FLAGS.task_index,
                    total_num_replicas=len(worker_hosts),
                    use_locking=True)
                train_op = rep_op.apply_gradients(capped_grads_and_vars,
                                                  global_step=global_step)
                init_token_op = rep_op.get_init_tokens_op()
                chief_queue_runner = rep_op.get_chief_queue_runner()
            else:
                # Asynchronous mode: each worker applies its gradients independently.
                train_op = optimizer.apply_gradients(capped_grads_and_vars,
                                                     global_step=global_step)

            # Record the loss for display in TensorBoard.
            #tf.scalar_summary("train_cost", train_cost)

            # dev evaluation ops
            dev_decoded, dev_log_prob = tf.nn.ctc_greedy_decoder(
                dev_ctc_in, dev_seq_len)
            dev_edit_distance = tf.edit_distance(tf.to_int32(dev_decoded[0]),
                                                 dev_targets,
                                                 normalize=False)
            dev_batch_error_count = tf.reduce_sum(dev_edit_distance)
            dev_batch_label_count = tf.shape(dev_targets.values)[0]

            # train evaluation ops
            train_decoded, train_log_prob = tf.nn.ctc_greedy_decoder(
                train_ctc_in, train_seq_len)
            train_edit_distance = tf.edit_distance(tf.to_int32(
                train_decoded[0]),
                                                   train_targets,
                                                   normalize=False)
            train_batch_error_count = tf.reduce_sum(train_edit_distance)
            train_batch_label_count = tf.shape(train_targets.values)[0]

            # Initialization ops, saver, and summary op.
            init_op = tf.global_variables_initializer()
            local_init = tf.local_variables_initializer()
            saver = tf.train.Saver()
            summary_op = tf.merge_all_summaries()

            sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                     logdir=FLAGS.model_dir,
                                     init_op=init_op,
                                     local_init_op=local_init,
                                     summary_op=summary_op,
                                     saver=saver,
                                     global_step=global_step,
                                     save_model_secs=600)

            with sv.prepare_or_wait_for_session(server.target) as sess:
                # In synchronous mode, the chief starts the sync queue runner and enqueues init tokens.
                if FLAGS.task_index == 0 and issync == 1:
                    sv.start_queue_runners(sess, [chief_queue_runner])
                    sess.run(init_token_op)

                summary_writer = tf.train.SummaryWriter(
                    FLAGS.model_dir, sess.graph)

                step = 0
                valid_step = 0
                train_acc_step = 0
                epoch = 0
                while not sv.should_stop() and step < 100000000:
                    patch_data, bucket_id = train_reader.read_data(
                        train_feature_fr, train_feature_len_fr, train_label_fr,
                        train_label_len_fr)
                    feature_ids, seq_len_ids, label_ids, label_len_ids = patch_data
                    feed_dict = {
                        train_feature_area: feature_ids,
                        train_seq_len: seq_len_ids,
                        train_label_area: label_ids,
                        train_label_len: label_len_ids
                    }
                    _, loss, step = sess.run(
                        [train_op, train_cost, global_step],
                        feed_dict=feed_dict)

                    if (step -
                            train_acc_step) > 1000 and FLAGS.task_index == 0:
                        train_acc_step = step
                        patch_data, bucket_id = train_reader.read_data(
                            train_feature_fr, train_feature_len_fr,
                            train_label_fr, train_label_len_fr)
                        feature_ids, seq_len_ids, label_ids, label_len_ids = patch_data
                        feed_dict = {
                            train_feature_area: feature_ids,
                            train_seq_len: seq_len_ids,
                            train_label_area: label_ids,
                            train_label_len: label_len_ids
                        }
                        train_error_count_value, train_label_count = sess.run(
                            [train_batch_error_count, train_batch_label_count],
                            feed_dict=feed_dict)
                        train_acc_ratio = (
                            train_label_count -
                            train_error_count_value) / train_label_count
                        logging.info(
                            "eval: step = %d loss = %.3f train_acc = %.3f ",
                            step, loss, train_acc_ratio)

                    # Evaluate on the dev set after roughly one epoch of steps, and only on the chief worker.
                    # Training is distributed and steps are spread across workers, so a
                    # `step % N == 0` check may never fire on any single worker.
                    if step - valid_step > train_num_batches_per_epoch and FLAGS.task_index == 0:
                        epoch += 1
                        valid_step = step
                        dev_error_count = 0
                        dev_label_count = 0

                        for batch in range(dev_num_batches_per_epoch):
                            patch_data, bucket_id = dev_reader.read_data(
                                dev_feature_fr, dev_feature_len_fr,
                                dev_label_fr, dev_label_len_fr)
                            feature_ids, seq_len_ids, label_ids, label_len_ids = patch_data
                            feed_dict = {
                                dev_feature_area: feature_ids,
                                dev_seq_len: seq_len_ids,
                                dev_label_area: label_ids,
                                dev_label_len: label_len_ids
                            }

                            dev_error_count_value, dev_label_count_value = sess.run(
                                [dev_batch_error_count, dev_batch_label_count],
                                feed_dict=feed_dict)
                            dev_error_count += dev_error_count_value
                            dev_label_count += dev_label_count_value

                        dev_acc_ratio = (dev_label_count -
                                         dev_error_count) / dev_label_count
                        logging.info(
                            "epoch: %d eval: step = %d eval_acc = %.3f ",
                            epoch, step, dev_acc_ratio)

        train_feature_fr.close()
        train_feature_len_fr.close()
        train_label_fr.close()
        train_label_len_fr.close()

        dev_feature_fr.close()
        dev_feature_len_fr.close()
        dev_label_fr.close()
        dev_label_len_fr.close()
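
The gradient-clipping update used above calls tf.clip_by_value directly on each gradient. Below is a small self-contained sketch (TensorFlow 1.x) of the same clipped-update pattern; the tiny dummy loss and the None-gradient guard are assumptions for illustration, not part of the original code.

import tensorflow as tf

x = tf.get_variable("x", [], initializer=tf.constant_initializer(3.0))
global_step = tf.get_variable("global_step", [], dtype=tf.int32, trainable=False,
                              initializer=tf.constant_initializer(0))
loss = 1000.0 * tf.square(x)                    # deliberately steep loss
optimizer = tf.train.AdamOptimizer(0.01)

grads_and_vars = optimizer.compute_gradients(loss)
# compute_gradients yields (None, var) for variables the loss does not depend on;
# skip those instead of clipping them.
capped = [(tf.clip_by_value(g, -50.0, 50.0) if g is not None else g, v)
          for g, v in grads_and_vars]
train_op = optimizer.apply_gradients(capped, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run([x, global_step]))           # x nudged down, global_step == 1
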
Example no. 5
def train():
  """训练LSTM + CTC的语音识别系统.

  """
  data_dir = FLAGS.data_dir
  batch_size = FLAGS.batch_size
  cv_batch_size = FLAGS.cv_batch_size

  cv_maxsize_file = path.join(data_dir, FLAGS.cv_maxsize_file)
  train_maxsize_file = path.join(data_dir, FLAGS.train_maxsize_file)
  train_data_config = asr.read_data_config(train_maxsize_file)
  dev_data_config = asr.read_data_config(cv_maxsize_file)
  train_data = asr.distort_inputs(train_data_config, batch_size)
  dev_data = asr.get_dev_data(dev_data_config, cv_batch_size)

  dev_examples_num = dev_data_config.example_number
  dev_num_batches_per_epoch = int(dev_examples_num / cv_batch_size)
  train_num_examples = train_data_config.example_number
  train_num_batches_per_epoch = int(train_num_examples / batch_size)

  # Number of steps after which the learning rate decays.
  decay_steps = int(train_num_batches_per_epoch * FLAGS.num_epochs_per_decay)

  global_step = tf.get_variable('global_step', [],
                                initializer=tf.constant_initializer(0),
                                trainable=False, dtype=tf.int32)

  lr = tf.train.exponential_decay(FLAGS.initial_learning_rate, global_step,
                                  decay_steps, FLAGS.learning_rate_decay_factor,
                                  staircase=True, name="decay_learning_rate")

  optimizer = tf.train.AdamOptimizer(lr, name="AdamOpt")
  # optimizer = tf.train.MomentumOptimizer(lr, 0.9)

  with tf.variable_scope("inference") as scope:
    ctc_input, train_targets, train_seq_len = asr.rnn(train_data,
                                                      train_data_config,
                                                      batch_size)

    scope.reuse_variables()
    dev_ctc_in, dev_targets, dev_seq_len = asr.rnn(dev_data, dev_data_config,
                                                   cv_batch_size)

  example_losses = tf.nn.ctc_loss(ctc_input, train_targets, train_seq_len)
  train_cost = tf.reduce_mean(example_losses, name="train_cost")
  grads_and_vars = optimizer.compute_gradients(train_cost)
  capped_grads_and_vars = [(tf.clip_by_value(gv[0], -50.0, 50.0), gv[1]) for gv
                           in grads_and_vars]
  train_op = optimizer.apply_gradients(capped_grads_and_vars,
                                       global_step=global_step)
  #train_op = optimizer.minimize(train_cost, global_step=global_step)

  tf.scalar_summary("train_cost", train_cost)

  dev_decoded, dev_log_prob = tf.nn.ctc_greedy_decoder(dev_ctc_in, dev_seq_len)

  edit_distance = tf.edit_distance(tf.to_int32(dev_decoded[0]), dev_targets,
                                   normalize=False)

  batch_error_count = tf.reduce_sum(edit_distance, name="batch_error_count")
  batch_label_count = tf.shape(dev_targets.values)[0]
  init = tf.global_variables_initializer()
  local_init = tf.local_variables_initializer()
  saver = tf.train.Saver()
  summary_op = tf.merge_all_summaries()

  gpu_options = tf.GPUOptions(
    per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)

  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as session:
    if FLAGS.reload_model == 1:
      ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
      saver.restore(session, ckpt.model_checkpoint_path)

      global_step = int(
        ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])

      logging.info("从%s载入模型参数, global_step = %d",
                   ckpt.model_checkpoint_path, global_step)
    else:
      logging.info("Created model with fresh parameters.")
      session.run(init)
      session.run(local_init)

    summary_writer = tf.train.SummaryWriter(FLAGS.model_dir, session.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=session, coord=coord)
    step = 0
    epoch = 0

    try:
      while not coord.should_stop():
        step += 1
        start_time = time.time()
        train_cost_value, _ = session.run([train_cost, train_op])
        duration = time.time() - start_time
        examples_per_sec = batch_size / duration
        sec_per_batch = float(duration)

        if step % 2 == 0:
          summary_str = session.run(summary_op)
          summary_writer.add_summary(summary_str, step)
        if step % 20 == 0:
          logging.info(
            'step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)',
            step, train_cost_value, examples_per_sec, sec_per_batch)

        if step % train_num_batches_per_epoch == 0:
          saver.save(session, FLAGS.model_dir + "model.ckpt", global_step=step)
          logging.info("保存模型参数.")
          epoch += 1
          dev_error_count = 0
          dev_label_count = 0

          for batch in range(dev_num_batches_per_epoch):
            cv_error_count_value, cv_label_count = session.run(
              [batch_error_count, batch_label_count])

            dev_error_count += cv_error_count_value
            dev_label_count += cv_label_count

          dev_acc_ratio = (dev_label_count - dev_error_count) / dev_label_count

          logging.info("eval: step = %d epoch = %d eval_acc = %.3f ",
                       step, epoch, dev_acc_ratio)
    except tf.errors.OutOfRangeError:
      logging.info("训练完成.")
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()

    coord.join(threads)
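
The learning-rate schedule built above with tf.train.exponential_decay and staircase=True is equivalent to the small plain-Python computation below (the numeric values are made up, not taken from FLAGS): the rate is multiplied by learning_rate_decay_factor once every decay_steps steps.

initial_learning_rate = 1e-3
learning_rate_decay_factor = 0.5
decay_steps = 1000   # train_num_batches_per_epoch * num_epochs_per_decay

def decayed_lr(global_step):
    # staircase=True: the exponent only changes every decay_steps steps
    return initial_learning_rate * learning_rate_decay_factor ** (global_step // decay_steps)

for s in (0, 999, 1000, 2500, 5000):
    print(s, decayed_lr(s))   # 1e-3, 1e-3, 5e-4, 2.5e-4, 3.125e-05
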