def do_train(sess, args):
    # Set the CPU as the default device for the graph; some operations will be moved to a GPU later.
    with tf.device('/cpu:0'):

        # Placeholders for the input features and their labels
        # images_ph = tf.placeholder(tf.float32, shape=(None,) + tuple(args.processed_size), name='input')
        images_ph = tf.placeholder(tf.float32, shape=None, name='input')
        labels_ph = tf.placeholder(tf.int32, shape=None, name='label')
        max_seq_len_ph = tf.placeholder(tf.int32,
                                        shape=None,
                                        name='max_seq_len')
        label_length_batch_ph = tf.placeholder(tf.int32,
                                               shape=None,
                                               name='label_length_batch')

        # A placeholder indicating whether we are training or validating the network. It is used to set dropout rates and batch-norm parameters.
        is_training_ph = tf.placeholder(tf.bool, name='is_training')

        # epoch number
        # Note: worth a closer look
        epoch_number = tf.get_variable(
            'epoch_number', [],
            dtype=tf.int32,
            initializer=tf.constant_initializer(0),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, SAVE_VARIABLES])
        global_step = tf.get_variable(
            'global_step', [],
            dtype=tf.int32,
            initializer=tf.constant_initializer(0),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, SAVE_VARIABLES])
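        # Note: SAVE_VARIABLES is assumed to be a custom collection key defined
        # elsewhere in this module; adding these counters to it (in addition to
        # GLOBAL_VARIABLES) lets the checkpoint saver store them, so training can
        # resume from the recorded epoch_number and global_step.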

        # Weight Decay policy
        wd = utils.get_policy(args.WD_policy, args.WD_details)

        # Learning rate decay policy (if needed)
        # lr = utils.get_policy(args.LR_policy, args.LR_details)
        # TODO: the hard-coded learning rate below may be problematic
        lr = 0.0001

        # Create an optimizer that performs gradient descent.
        optimizer = utils.get_optimizer(args.optimizer, lr)

        # Create a pipeline to read data from disk
        # A placeholder for setting the input-pipeline batch size. It is used so that each validation example is fed to the network exactly once.
        # Because only one GPU is used for validation, the validation batch size should not exceed 512.
        batch_size_tf = tf.placeholder_with_default(min(512, args.batch_size),
                                                    shape=())

        # A data loader pipeline to read training images and their labels
        train_loader = Loader(args.train_info, args.delimiter, args.raw_size,
                              args.processed_size, True,
                              args.chunked_batch_size, args.num_prefetch,
                              args.num_threads, args.path_prefix, args.shuffle)
        # The loader returns the MFCC feature batch, label batch, and related sequence-length tensors
        # images, labels, info = train_loader.load()
        mfcc_feat_batch, label_batch, feat_shape_batch, seq_len_batch, max_seq_len, label_length_batch = train_loader.load(
        )
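        # Presumably: mfcc_feat_batch holds the padded MFCC feature frames,
        # label_batch the target label sequences, feat_shape_batch the per-example
        # feature shapes, seq_len_batch the per-example frame counts, max_seq_len
        # the longest sequence in the batch, and label_length_batch the length of
        # each label sequence.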

        # build the computational graph using the provided configuration.
        dnn_model = model(images_ph,
                          labels_ph,
                          utils.loss,
                          optimizer,
                          wd,
                          args.architecture,
                          args.depth,
                          args.num_chars,
                          args.num_classes,
                          is_training_ph,
                          max_seq_len_ph,
                          label_length_batch_ph,
                          args.transfer_mode,
                          num_gpus=args.num_gpus)
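        # The model class (defined elsewhere in this project) is expected to build
        # the forward pass, attach the loss from utils.loss, and create the
        # training ops, replicating the graph across args.num_gpus devices.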

        # If validation data are provided, we create an input pipeline to load the validation data
        if args.run_validation:
            val_loader = Loader(args.val_info, args.delimiter, args.raw_size,
                                args.processed_size, False, batch_size_tf,
                                args.num_prefetch, args.num_threads,
                                args.path_prefix)
            # TODO: uncomment
            # val_images, val_labels, val_info = val_loader.load()

        # Get training operations to run from the deep learning model
        train_ops = dnn_model.train_ops()

        # Build an initialization operation to run below.
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess.run(init)

        if args.retrain_from is not None:
            dnn_model.load(sess, args.retrain_from)

        # Set the start epoch number (resume from the epoch after the one stored in the checkpoint)
        start_epoch = sess.run(epoch_number + 1)

        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
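        # These background threads keep the queue-based input pipeline filled;
        # they are stopped through the coordinator once training finishes.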

        # Setup a summary writer and merge all summaries once, outside the training loop
        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)
        merged_summary_op = tf.summary.merge_all()

        # The main training loop
        for epoch in range(start_epoch, start_epoch + args.num_epochs):

            # update epoch_number
            sess.run(epoch_number.assign(epoch))

            print("Epoch %d started" % (epoch))
            # Training batches
            for step in range(args.num_batches):
                sess.run(global_step.assign(step + epoch * args.num_batches))
                # train the network on a batch of data (It also measures time)
                start_time = time.time()

                # load a batch from the input pipeline
                # img, lbl = sess.run([images, labels], options=args.run_options, run_metadata=args.run_metadata)
                # Fetch into fresh names so the pipeline tensors are not overwritten by the returned numpy arrays
                mfcc_feat, labels_np, feat_shape, seq_len, max_len, label_length \
                    = sess.run([mfcc_feat_batch, label_batch, feat_shape_batch, seq_len_batch, max_seq_len, label_length_batch],
                               options=args.run_options, run_metadata=args.run_metadata)

                # train on the loaded batch of data
                _, loss_value, top1_accuracy, topn_accuracy = \
                    sess.run(train_ops,
                             feed_dict={images_ph: mfcc_feat, labels_ph: labels_np,
                                        max_seq_len_ph: max_len, label_length_batch_ph: label_length,
                                        is_training_ph: True},
                             options=args.run_options, run_metadata=args.run_metadata)
                duration = time.time() - start_time

                # Check for errors
                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                # Logging every ten batches and writing tensorboard summaries every hundred batches
                if step % 10 == 0:
                    num_examples_per_step = args.chunked_batch_size * args.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / args.num_gpus

                    # Log
                    format_str = (
                        '%s: epoch %d, step %d, loss = %.2f, Top-1 = %.2f Top-'
                        + str(args.top_n) +
                        ' = %.2f (%.1f examples/sec; %.3f sec/batch)')
                    print(format_str %
                          (datetime.now(), epoch, step, loss_value,
                           top1_accuracy, topn_accuracy, examples_per_sec,
                           sec_per_batch))
                    sys.stdout.flush()

                if step % 100 == 0:
                    summary_str = sess.run(merged_summary_op,
                                           feed_dict={
                                               images_ph: mfcc_feat,
                                               labels_ph: labels_np,
                                               max_seq_len_ph: max_len,
                                               label_length_batch_ph: label_length,
                                               is_training_ph: True
                                           })
                    summary_writer.add_summary(summary_str,
                                               args.num_batches * epoch + step)
                    # TODO: there seems to be a bug here
                    # if args.log_debug_info:
                    #     summary_writer.add_run_metadata(run_metadata, 'epoch%d step%d' % (epoch, step))

            # Save the model checkpoint periodically after each training epoch
            checkpoint_path = os.path.join(args.log_dir, args.snapshot_prefix)
            dnn_model.save(sess, checkpoint_path, global_step=epoch)

            print("Epoch %d ended. a checkpoint saved at %s" %
                  (epoch, args.log_dir))
            sys.stdout.flush()
            # if validation data are provided, evaluate accuracy on the validation set after the end of each epoch
            if args.run_validation:

                print("Evaluating on validation set")
                """

                true_predictions_count = 0  # Counts the number of correct predictions
                true_topn_predictions_count = 0  # Counts the number of top-n correct predictions
                total_loss = 0.0  # measures cross entropy loss
                all_count = 0  # Count the total number of examples

                # The validation loop
                for step in range(args.num_val_batches):
                    # Load a batch of data
                    val_img, val_lbl = sess.run([val_images, val_labels], feed_dict={
                        batch_size_tf: args.num_val_samples % min(512, args.batch_size)} if step == args.num_val_batches - 1 else None,
                                                options=args.run_options, run_metadata=args.run_metadata)

                    # validate the network on the loaded batch
                    val_loss, top1_predictions, topn_predictions = sess.run([train_ops[1], train_ops[2], train_ops[3]],
                                                                            feed_dict={images_ph: val_img, labels_ph: val_lbl,
                                                                                       is_training_ph: False},
                                                                            options=args.run_options, run_metadata=args.run_metadata)

                    all_count += val_lbl.shape[0]
                    true_predictions_count += int(round(val_lbl.shape[0] * top1_predictions))
                    true_topn_predictions_count += int(round(val_lbl.shape[0] * topn_predictions))
                    total_loss += val_loss * val_lbl.shape[0]
                    if step % 10 == 0:
                        print("Validation step %d of %d" % (step, args.num_val_batches))
                        sys.stdout.flush()

                print("Total number of validation examples %d, Loss %.2f, Top-1 Accuracy %.2f, Top-%d Accuracy %.2f" %
                      (all_count, total_loss / all_count, true_predictions_count / all_count, args.top_n,
                       true_topn_predictions_count / all_count))
                sys.stdout.flush()
                """

        coord.request_stop()
        coord.join(threads)
        sess.close()
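

# A minimal sketch of how do_train might be driven (hypothetical driver, kept as
# a comment so this module stays unchanged when imported; `args` stands for the
# namespace produced by this project's argument parser):
#
#     config = tf.ConfigProto(allow_soft_placement=True)
#     config.gpu_options.allow_growth = True
#     with tf.Session(config=config) as sess:
#         do_train(sess, args)
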
def do_evaluate(sess, args):
    with tf.device('/cpu:0'):
        # Images and labels placeholders
        images_ph = tf.placeholder(tf.float32,
                                   shape=(None, ) + tuple(args.processed_size),
                                   name='input')
        labels_ph = tf.placeholder(tf.int32, shape=(None), name='label')

        # A placeholder indicating whether we are training or validating the network. It is used to set dropout rates and batch-norm parameters.
        is_training_ph = tf.placeholder(tf.bool, name='is_training')

        # build a deep learning model using the provided configuration
        dnn_model = model(images_ph, labels_ph, utils.loss, None, 0.0,
                          args.architecture, args.depth, args.num_chars,
                          args.num_classes, is_training_ph, args.transfer_mode)

        # creating an input pipeline to read data from disk
        # a placeholder for setting the input pipeline batch size. This is employed to ensure that we feed each validation example only once to the network.
        batch_size_tf = tf.placeholder_with_default(args.batch_size, shape=())

        # a data loader pipeline to read test data
        val_loader = Loader(args.val_info,
                            args.delimiter,
                            args.raw_size,
                            args.processed_size,
                            False,
                            batch_size_tf,
                            args.num_prefetch,
                            args.num_threads,
                            args.path_prefix,
                            inference_only=args.inference_only)

        # if we want to do inference only (i.e., no labels are provided) we only load images and their paths
        if not args.inference_only:
            val_images, val_labels, val_info = val_loader.load()
        else:
            val_images, val_info = val_loader.load()

        # get evaluation operations from the dnn model
        eval_ops = dnn_model.evaluate_ops(args.inference_only)
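        # Judging from how eval_ops is unpacked below, it yields the loss, top-1
        # correctness, top-n correctness, top-n guesses, and top-n confidences when
        # labels are available, and only the top-n guesses and confidences in
        # inference-only mode.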

        # Build an initialization operation to run below.
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess.run(init)

        # Load pretrained parameters from disk
        dnn_model.load(sess, args.log_dir)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # evaluation
        if not args.inference_only:
            true_predictions_count = 0  # Counts the number of correct predictions
            true_topn_predictions_count = 0  # Counts the number of correct top-n predictions
            total_loss = 0.0  # Measures cross entropy loss
            all_count = 0  # Counts the total number of examples

            # Open an output file to write predictions
            out_file = open(args.save_predictions, 'w')
            predictions_format_str = ('%d, %s, %d, %s, %s\n')
            for step in range(args.num_val_batches):
                # Load a batch of data
                val_img, val_lbl, val_inf = sess.run(
                    [val_images, val_labels, val_info],
                    feed_dict={
                        batch_size_tf: args.num_val_samples % args.batch_size
                    } if step == args.num_val_batches - 1 else None)
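                # Note: the feed above assumes args.num_val_samples is not an exact
                # multiple of args.batch_size; otherwise the remainder is zero and
                # the final batch would be requested with size 0.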

                # Evaluate the network on the loaded batch
                val_loss, top1_predictions, topn_predictions, topnguesses, topnconf = sess.run(
                    eval_ops,
                    feed_dict={
                        images_ph: val_img,
                        labels_ph: val_lbl,
                        is_training_ph: False
                    },
                    options=args.run_options,
                    run_metadata=args.run_metadata)

                true_predictions_count += np.sum(top1_predictions)
                true_topn_predictions_count += np.sum(topn_predictions)
                all_count += top1_predictions.shape[0]
                total_loss += val_loss * val_lbl.shape[0]
                print(
                    'Batch Number: %d, Top-1 Hit: %d, Top-%d Hit: %d, Loss %.2f, Top-1 Accuracy: %.3f, Top-%d Accuracy: %.3f'
                    % (step, true_predictions_count, args.top_n,
                       true_topn_predictions_count, total_loss / all_count,
                       true_predictions_count / all_count, args.top_n,
                       true_topn_predictions_count / all_count))

                # log results into an output file
                for i in range(0, val_inf.shape[0]):
                    out_file.write(
                        predictions_format_str %
                        (step * args.batch_size + i + 1, str(
                            val_inf[i]).encode('utf-8'), val_lbl[i], ', '.join(
                                '%d' % item
                                for item in topnguesses[i]), ', '.join(
                                    '%.4f' % item for item in topnconf[i])))
                    out_file.flush()

            out_file.close()
        # inference
        else:

            # Open an output file to write predictions
            out_file = open(args.save_predictions, 'w')
            predictions_format_str = ('%d, %s, %s, %s\n')

            for step in range(args.num_val_batches):
                # Load a batch of data
                val_img, val_inf = sess.run(
                    [val_images, val_info],
                    feed_dict={
                        batch_size_tf: args.num_val_samples % args.batch_size
                    } if step == args.num_val_batches - 1 else None)

                # Run the network on the loaded batch
                topnguesses, topnconf = sess.run(
                    eval_ops,
                    feed_dict={
                        images_ph: val_img,
                        is_training_ph: False
                    },
                    options=args.run_options,
                    run_metadata=args.run_metadata)
                print('Batch Number: %d of %d is done' %
                      (step, args.num_val_batches))

                # Log to an output file
                for i in range(0, val_inf.shape[0]):
                    out_file.write(predictions_format_str %
                                   (step * args.batch_size + i + 1,
                                    str(val_inf[i]).encode('utf-8'), ', '.join(
                                        '%d' % item
                                        for item in topnguesses[i]), ', '.join(
                                            '%.4f' % item
                                            for item in topnconf[i])))
                    out_file.flush()

            out_file.close()

        coord.request_stop()
        coord.join(threads)
        sess.close()