Example #1
def main(_):
    train_mfcc_dir = os.path.join(FLAGS.input_data_dir, FLAGS.level, 'TRAIN',
                                  'mfcc')
    train_label_dir = os.path.join(FLAGS.input_data_dir, FLAGS.level, 'TRAIN',
                                   'label')
    test_mfcc_dir = os.path.join(FLAGS.input_data_dir, FLAGS.level, 'TEST',
                                 'mfcc')
    test_label_dir = os.path.join(FLAGS.input_data_dir, FLAGS.level, 'TEST',
                                  'label')

    savedir = os.path.join(FLAGS.exp_dir, FLAGS.level, 'save')
    resultdir = os.path.join(FLAGS.exp_dir, FLAGS.level, 'result')

    if FLAGS.is_training:
        batched_data, max_time_steps, total_n = load_batched_data(
            train_mfcc_dir, train_label_dir, FLAGS.batch_size, FLAGS.level)
    else:
        batched_data, max_time_steps, total_n = load_batched_data(
            test_mfcc_dir, test_label_dir, FLAGS.batch_size, FLAGS.level)

    hparams = {}
    hparams['level'] = FLAGS.level
    hparams['batch_size'] = FLAGS.batch_size
    hparams['partition_size'] = FLAGS.partition_size
    hparams['num_hidden'] = FLAGS.num_hidden
    hparams['feature_length'] = FLAGS.feature_length
    hparams['num_classes'] = FLAGS.num_classes
    hparams['num_proj'] = FLAGS.num_proj
    hparams['learning_rate'] = FLAGS.learning_rate
    hparams['keep_prob'] = FLAGS.keep_prob
    hparams['clip_gradient_norm'] = FLAGS.clip_gradient_norm
    hparams['use_peepholes'] = FLAGS.use_peepholes
    if FLAGS.activation == 'tanh':
        hparams['activation'] = tf.tanh
    elif FLAGS.activation == 'relu':
        hparams['activation'] = tf.nn.relu
    hparams['max_time_steps'] = max_time_steps
    with tf.Graph().as_default():
        model = DRNN(FLAGS.cell, hparams, FLAGS.is_training)
        train_writer = tf.summary.FileWriter(resultdir + '/train')
        test_writer = tf.summary.FileWriter(resultdir + '/test')
        with tf.Session(FLAGS.master) as sess:
            # restore from stored model
            if FLAGS.restore:
                ckpt = tf.train.get_checkpoint_state(savedir)
                if ckpt and ckpt.model_checkpoint_path:
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Model restored from:' + ckpt.model_checkpoint_path)
            else:
                print('Initializing')
                sess.run(model.initial_op)
            train_writer.add_graph(sess.graph)
            for epoch in range(FLAGS.num_epochs):
                ## training
                start = time.time()
                if FLAGS.is_training:
                    print('Epoch', epoch + 1, '...')
                batch_errors = np.zeros(len(batched_data))
                batched_random_idx = np.random.permutation(len(batched_data))

                for batch, batch_original_idx in enumerate(batched_random_idx):
                    batch_inputs, batch_target_sparse, batch_seq_length = batched_data[
                        batch_original_idx]
                    batch_tgt_idx, batch_tgt_vals, batch_tgt_shape = batch_target_sparse
                    feeddict = {
                        model.x: batch_inputs,
                        model.tgt_idx: batch_tgt_idx,
                        model.tgt_vals: batch_tgt_vals,
                        model.tgt_shape: batch_tgt_shape,
                        model.seq_length: batch_seq_length
                    }

                    if FLAGS.is_training and (
                        (epoch * len(batched_random_idx) + batch + 1) % 20 == 0
                            or (epoch == FLAGS.num_epochs - 1
                                and batch == len(batched_random_idx) - 1)):
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=model.global_step)
                        print('Model has been saved in {}'.format(savedir))

                    if FLAGS.level == 'cha':
                        if FLAGS.is_training:
                            _, l, pre, y, er, global_step = sess.run(
                                [
                                    model.train_op, model.loss,
                                    model.predictions, model.y,
                                    model.error_rate, model.global_step
                                ],
                                feed_dict=feeddict)
                            batch_errors[batch] = er
                            if global_step % 10 == 0:
                                log_scalar(train_writer, 'CER',
                                           er / FLAGS.batch_size, global_step)
                                print(
                                    '{} mode, global_step:{}, lr:{}, total:{}, '
                                    'batch:{}/{},epoch:{}/{},train loss={:.3f},mean train '
                                    'CER={:.3f}'.format(
                                        FLAGS.level, global_step,
                                        FLAGS.learning_rate,
                                        total_n, batch + 1,
                                        len(batched_random_idx), epoch + 1,
                                        FLAGS.num_epochs, l,
                                        er / FLAGS.batch_size))

                        elif not FLAGS.is_training:
                            l, pre, y, er, global_step = sess.run(
                                [
                                    model.loss, model.predictions, model.y,
                                    model.error_rate, model.global_step
                                ],
                                feed_dict=feeddict)
                            batch_errors[batch] = er
                            log_scalar(test_writer, 'CER',
                                       er / FLAGS.batch_size, global_step)
                            print(
                                '{} mode, global_step:{}, total:{}, batch:{}/{},test '
                                'loss={:.3f},mean test CER={:.3f}'.format(
                                    FLAGS.level, global_step, total_n,
                                    batch + 1, len(batched_random_idx), l,
                                    er / FLAGS.batch_size))

                    elif FLAGS.level == 'phn':
                        if FLAGS.is_training:
                            _, l, pre, y, global_step = sess.run(
                                [
                                    model.train_op, model.loss,
                                    model.predictions, model.y,
                                    model.global_step
                                ],
                                feed_dict=feeddict)
                            er = get_edit_distance([pre.values], [y.values],
                                                   True, FLAGS.level)
                            if global_step % 10 == 0:
                                log_scalar(train_writer, 'PER', er,
                                           global_step)
                                print(
                                    '{} mode, global_step:{}, lr:{}, total:{}, '
                                    'batch:{}/{},epoch:{}/{},train loss={:.3f},mean train '
                                    'PER={:.3f}'.format(
                                        FLAGS.level, global_step,
                                        FLAGS.learning_rate,
                                        total_n, batch + 1,
                                        len(batched_random_idx), epoch + 1,
                                        FLAGS.num_epochs, l, er))
                            batch_errors[batch] = er * len(batch_seq_length)
                        elif not FLAGS.is_training:
                            l, pre, y, global_step = sess.run(
                                [
                                    model.loss, model.predictions, model.y,
                                    model.global_step
                                ],
                                feed_dict=feeddict)
                            er = get_edit_distance([pre.values], [y.values],
                                                   True, FLAGS.level)
                            log_scalar(test_writer, 'PER', er, global_step)
                            print(
                                '{} mode, global_step:{}, total:{}, batch:{}/{},test '
                                'loss={:.3f},mean test PER={:.3f}'.format(
                                    FLAGS.level, global_step, total_n,
                                    batch + 1, len(batched_random_idx), l, er))
                            batch_errors[batch] = er * len(batch_seq_length)

                    # NOTE: stop this epoch early once the batch error rate
                    # saturates at 1.0 (every label in the batch was wrong)
                    if er / FLAGS.batch_size == 1.0:
                        break

                    if batch % 100 == 0:
                        print('Truth:\n' +
                              output_to_sequence(y, level=FLAGS.level))
                        print('Output:\n' +
                              output_to_sequence(pre, level=FLAGS.level))

                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' +
                      str(delta_time) + ' s')

                if FLAGS.is_training:
                    if (epoch + 1) % 1 == 0:  # always true: checkpoint after every epoch
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=model.global_step)
                        print('Model has been saved in {}'.format(savedir))
                    epoch_er = batch_errors.sum() / total_n
                    print('Epoch', epoch + 1, 'mean train error rate:',
                          epoch_er)

                if not FLAGS.is_training:
                    with tf.gfile.GFile(
                            os.path.join(resultdir,
                                         FLAGS.level + '_result.txt'),
                            'a') as result:
                        result.write(
                            output_to_sequence(y, level=FLAGS.level) + '\n')
                        result.write(
                            output_to_sequence(pre, level=FLAGS.level) + '\n')
                        result.write('\n')
                    epoch_er = batch_errors.sum() / total_n
                    print(' test error rate:', epoch_er)
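
The log_scalar helper called in Example #1 is not shown here. A minimal sketch for TensorFlow 1.x, assuming it simply wraps one scalar in a tf.Summary protobuf and hands it to the tf.summary.FileWriter:

import tensorflow as tf

def log_scalar(writer, tag, value, step):
    # build a one-value summary protobuf and record it at the given global step
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
    writer.add_summary(summary, step)
    writer.flush()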
Example #2
    def run(self):
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)
        batchedData, maxTimeSteps, totalN = self.load_data(args,
                                                           mode=mode,
                                                           type=level)
        model = model_fn(args, maxTimeSteps)

        # count the num of params
        num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        model.config['trainable params'] = num_params
        model.config['all params'] = all_num_params
        print(model.config)

        #with tf.Session(graph=model.graph) as sess:
        with tf.Session() as sess:
            # restore from stored model
            if keep:
                ckpt = tf.train.get_checkpoint_state(savedir)
                if ckpt and ckpt.model_checkpoint_path:
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Model restored from:' + savedir)
            else:
                print('Initializing')
                sess.run(model.initial_op)

            for epoch in range(num_epochs):
                ## training
                start = time.time()
                if mode == 'train':
                    print('Epoch', epoch + 1, '...')
                batchErrors = np.zeros(len(batchedData))
                batchRandIxs = np.random.permutation(len(batchedData))

                for batch, batchOrigI in enumerate(batchRandIxs):
                    batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                        batchOrigI]
                    batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                    feedDict = {
                        model.inputX: batchInputs,
                        model.targetIxs: batchTargetIxs,
                        model.targetVals: batchTargetVals,
                        model.targetShape: batchTargetShape,
                        model.seqLengths: batchSeqLengths
                    }

                    if level == 'cha':
                        if mode == 'train':
                            _, l, pre, y, er = sess.run([
                                model.optimizer, model.loss, model.predictions,
                                model.targetY, model.errorRate
                            ],
                                                        feed_dict=feedDict)

                            batchErrors[batch] = er
                            print(
                                '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), epoch + 1,
                                        num_epochs, l, er / batch_size))

                        elif mode == 'test':
                            l, pre, y, er = sess.run([
                                model.loss, model.predictions, model.targetY,
                                model.errorRate
                            ],
                                                     feed_dict=feedDict)
                            batchErrors[batch] = er
                            print(
                                '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test CER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), l, er / batch_size))

                    elif level == 'phn':
                        if mode == 'train':
                            _, l, pre, y = sess.run([
                                model.optimizer, model.loss, model.predictions,
                                model.targetY
                            ],
                                                    feed_dict=feedDict)

                            er = get_edit_distance([pre.values], [y.values],
                                                   True, level)
                            print(
                                '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train PER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), epoch + 1,
                                        num_epochs, l, er))
                            batchErrors[batch] = er * len(batchSeqLengths)
                        elif mode == 'test':
                            l, pre, y = sess.run(
                                [model.loss, model.predictions, model.targetY],
                                feed_dict=feedDict)
                            er = get_edit_distance([pre.values], [y.values],
                                                   True, level)
                            print(
                                '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test PER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), l, er))
                            batchErrors[batch] = er * len(batchSeqLengths)

                    # NOTE: stop this epoch early once the batch error rate
                    # saturates at 1.0
                    if er / batch_size == 1.0:
                        break

                    if batch % 30 == 0:
                        print('Truth:\n' + output_to_sequence(y, type=level))
                        print('Output:\n' +
                              output_to_sequence(pre, type=level))

                    if mode == 'train' and (
                        (epoch * len(batchRandIxs) + batch + 1) % 20 == 0 or
                        (epoch == num_epochs - 1
                         and batch == len(batchRandIxs) - 1)):
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=epoch)
                        print('Model has been saved in {}'.format(savedir))
                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' +
                      str(delta_time) + ' s')

                if mode == 'train':
                    if (epoch + 1) % 1 == 0:  # always true: checkpoint after every epoch
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=epoch)
                        print('Model has been saved in {}'.format(savedir))
                    epochER = batchErrors.sum() / totalN
                    print('Epoch', epoch + 1, 'mean train error rate:',
                          epochER)
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode='config')
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode=mode)

                if mode == 'test':
                    with open(os.path.join(resultdir, level + '_result.txt'),
                              'a') as result:
                        result.write(output_to_sequence(y, type=level) + '\n')
                        result.write(
                            output_to_sequence(pre, type=level) + '\n')
                        result.write('\n')
                    epochER = batchErrors.sum() / totalN
                    print(' test error rate:', epochER)
                    logging(model, logfile, epochER, mode=mode)
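
The logging helper called at the end of Example #2 is defined elsewhere in the repository; its signature and behavior here are assumptions. A hypothetical sketch that appends one record per call:

def logging(model, logfile, error_rate, epoch=0, delta_time=0, mode='train'):
    # append either the model configuration or one epoch record to logfile;
    # this is an illustrative guess at the helper, not the repository's code
    with open(logfile, 'a') as f:
        if mode == 'config':
            f.write(str(model.config) + '\n')
        else:
            f.write('mode:{} epoch:{} error rate:{:.4f} time:{:.1f}s\n'.format(
                mode, epoch + 1, error_rate, delta_time))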
Example #3
    def run(self):
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)  # wrap the dict in dotdict for attribute-style access
        feature_dirs, label_dirs = get_data(datadir, level, train_dataset,
                                            dev_dataset, test_dataset, mode)

        # batchedData, maxTimeSteps, totalN = self.load_data(feature_dirs[0], label_dirs[0], mode, level)
        # model = model_fn(args, maxTimeSteps)
        # # the purpose of these two lines is unclear; what removing them changes is unknown

        # record each epoch's ...
        # shuffle feature_dir and label_dir by same order
        FL_pair = list(zip(feature_dirs, label_dirs))  # zip() yields pairs lazily; list() materializes them
        random.shuffle(FL_pair)  # shuffle the pairs in place
        feature_dirs, label_dirs = zip(*FL_pair)

        for feature_dir, label_dir in zip(
                feature_dirs, label_dirs):  # iterate the zipped pairs directly
            id_dir = feature_dirs.index(feature_dir)
            print('dir id:{}'.format(id_dir))
            batchedData, maxTimeSteps, totalN = self.load_data(
                feature_dir, label_dir, mode, level)

            model = model_fn(args, maxTimeSteps)  # build the neural-network graph

            num_params = count_params(model, mode='trainable')
            all_num_params = count_params(model, mode='all')
            model.config['trainable params'] = num_params
            model.config['all params'] = all_num_params
            print(model.config)

            with tf.Session(graph=model.graph, config=config) as sess:
                # restore from stored model
                if keep:  # resume training from a saved checkpoint when keep is True
                    ckpt = tf.train.get_checkpoint_state(savedir)
                    # Returns CheckpointState proto from the "checkpoint" file.
                    if ckpt and ckpt.model_checkpoint_path:  # The checkpoint file
                        model.saver.restore(sess, ckpt.model_checkpoint_path)
                        print('Model restored from:' + savedir)
                else:
                    print('Initializing')
                    sess.run(model.initial_op)

                for step in range(num_steps):
                    # training
                    start = time.time()
                    if mode == 'train':
                        print('step {} ...'.format(step + 1))

                    batchErrors = np.zeros(len(batchedData))
                    batchRandIxs = np.random.permutation(len(batchedData))
                    # given an array, np.random.permutation returns a shuffled copy

                    for batch, batchOrigI in enumerate(batchRandIxs):
                        # enumerate pairs each element of an iterable (list,
                        # string, ...) with its index, yielding both at once;
                        # the block below assembles the feed dict
                        batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                            batchOrigI]
                        batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                        feedDict = {
                            model.inputX: batchInputs,
                            model.targetIxs: batchTargetIxs,
                            model.targetVals: batchTargetVals,
                            model.targetShape: batchTargetShape,
                            model.seqLengths: batchSeqLengths
                        }

                        if level == 'cha':
                            if mode == 'train':
                                _, l, pre, y, er = sess.run([
                                    model.optimizer, model.loss,
                                    model.predictions, model.targetY,
                                    model.errorRate
                                ],
                                                            feed_dict=feedDict)

                                batchErrors[batch] = er

                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},step:{},train loss={:.3f},mean '
                                    'train CER={:.3f}, epoch: {}\n'.format(
                                        level, totalN, id_dir + 1,
                                        len(feature_dirs), batch + 1,
                                        len(batchRandIxs), step + 1, l,
                                        er / batch_size, num_epochs))

                            elif mode == 'dev':
                                l, pre, y, er = sess.run([
                                    model.loss, model.predictions,
                                    model.targetY, model.errorRate
                                ],
                                                         feed_dict=feedDict)
                                batchErrors[batch] = er
                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},dev loss={:.3f},'
                                    'mean dev CER={:.3f}\n'.format(
                                        level, totalN, id_dir + 1,
                                        len(feature_dirs), batch + 1,
                                        len(batchRandIxs), l, er / batch_size))

                            elif mode == 'test':
                                l, pre, y, er = sess.run([
                                    model.loss, model.predictions,
                                    model.targetY, model.errorRate
                                ],
                                                         feed_dict=feedDict)
                                batchErrors[batch] = er
                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},test loss={:.3f},'
                                    'mean test CER={:.3f}\n'.format(
                                        level, totalN, id_dir + 1,
                                        len(feature_dirs), batch + 1,
                                        len(batchRandIxs), l, er / batch_size))
                        elif level == 'seq2seq':
                            raise ValueError('level %s is not supported now' %
                                             str(level))

                        # NOTE: early stop when the batch error rate saturates
                        # at 1.0; disabled here (its purpose was unclear to the
                        # original author)
                        # if er / batch_size == 1.0:
                        #     break

                        if batch % 20 == 0:
                            print('Truth:\n' +
                                  output_to_sequence(y, type=level))
                            print('Output:\n' +
                                  output_to_sequence(pre, type=level))

                        if mode == 'train' and (
                            (step * len(batchRandIxs) + batch + 1) % 20 == 0 or
                            (step == num_steps - 1
                             and batch == len(batchRandIxs) - 1)):
                            # save whenever the global batch count is a multiple
                            # of 20, or after the final batch of the final step
                            checkpoint_path = os.path.join(
                                savedir, 'model.ckpt')
                            model.saver.save(sess,
                                             checkpoint_path,
                                             global_step=step)
                            print('Model has been saved in {}'.format(savedir))

                    end = time.time()
                    delta_time = end - start
                    print('subdir ' + str(id_dir + 1) + ' needs time:' +
                          str(delta_time) + ' s')

                    if mode == 'train':
                        if (step + 1) % 1 == 0:  # always true: checkpoint after every step
                            checkpoint_path = os.path.join(
                                savedir, 'model.ckpt')
                            model.saver.save(sess,
                                             checkpoint_path,
                                             global_step=step)
                            print('Model has been saved in {}'.format(savedir))
                        epochER = batchErrors.sum() / totalN
                        print('subdir', id_dir + 1, 'mean train error rate:',
                              epochER)  # reported per subdir instead of per epoch
                        logging(model,
                                logfile,
                                epochER,
                                id_dir,
                                delta_time,
                                mode='config')
                        logging(model,
                                logfile,
                                epochER,
                                id_dir,
                                delta_time,
                                mode=mode)

                    if mode == 'test' or mode == 'dev':
                        with open(
                                os.path.join(resultdir, level + '_result.txt'),
                                'a') as result:
                            result.write(
                                output_to_sequence(y, type=level) + '\n')
                            result.write(
                                output_to_sequence(pre, type=level) + '\n')
                            result.write('\n')
                        epochER = batchErrors.sum() / totalN
                        print(' test error rate:', epochER)
                        logging(model, logfile, epochER, mode=mode)
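
Several branches above call get_edit_distance on the decoded sparse outputs. The repository's implementation is not shown; as a self-contained stand-in, here is a plain-Python sketch of a mean (optionally length-normalized) Levenshtein distance, with the level argument kept only for signature compatibility:

import numpy as np

def get_edit_distance(hyp_arr, truth_arr, normalize=True, level='phn'):
    def levenshtein(a, b):
        # classic dynamic-programming edit distance between two sequences
        d = np.zeros((len(a) + 1, len(b) + 1), dtype=np.int32)
        d[:, 0] = np.arange(len(a) + 1)
        d[0, :] = np.arange(len(b) + 1)
        for i in range(1, len(a) + 1):
            for j in range(1, len(b) + 1):
                cost = 0 if a[i - 1] == b[j - 1] else 1
                d[i, j] = min(d[i - 1, j] + 1,          # deletion
                              d[i, j - 1] + 1,          # insertion
                              d[i - 1, j - 1] + cost)   # substitution
        return int(d[len(a), len(b)])

    dists = [levenshtein(h, t) / (max(len(t), 1) if normalize else 1)
             for h, t in zip(hyp_arr, truth_arr)]
    return float(np.mean(dists))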
Example #4
    def run(self):
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)
        feature_dirs, label_dirs = get_data(datadir, level, train_dataset,
                                            dev_dataset, test_dataset, mode)
        batchedData, maxTimeSteps, totalN = self.load_data(
            feature_dirs[0], label_dirs[0], mode, level)
        model = model_fn(args, maxTimeSteps)
        FL_pair = list(zip(feature_dirs, label_dirs))
        random.shuffle(FL_pair)
        feature_dirs, label_dirs = zip(*FL_pair)
        print("Feature dirs:", feature_dirs)
        for feature_dir, label_dir in zip(feature_dirs, label_dirs):
            id_dir = feature_dirs.index(feature_dir)
            print('dir id:{}'.format(id_dir))
            batchedData, maxTimeSteps, totalN = self.load_data(
                feature_dir, label_dir, mode, level)
            model = model_fn(args, maxTimeSteps)
            num_params = count_params(model, mode='trainable')
            all_num_params = count_params(model, mode='all')
            model.config['trainable params'] = num_params
            model.config['all params'] = all_num_params
            print(model.config)
            with tf.Session(graph=model.graph) as sess:
                # restore from stored model
                if keep:
                    ckpt = tf.train.get_checkpoint_state(savedir)
                    if ckpt and ckpt.model_checkpoint_path:
                        model.saver.restore(sess, ckpt.model_checkpoint_path)
                        print('Model restored from:' + savedir)
                else:
                    print('Initializing')
                    sess.run(model.initial_op)
                total_count = 0  # global batch counter, used as the checkpoint global_step
                for epoch in range(num_epochs):
                    ## training
                    start = time.time()
                    if mode == 'train':
                        print('Epoch {} ...'.format(epoch + 1))
                    batchErrors = np.zeros(len(batchedData))
                    batchRandIxs = np.random.permutation(len(batchedData))
                    for batch, batchOrigI in enumerate(batchRandIxs):
                        batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                            batchOrigI]
                        batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                        feedDict = {
                            model.inputX: batchInputs,
                            model.targetIxs: batchTargetIxs,
                            model.targetVals: batchTargetVals,
                            model.targetShape: batchTargetShape,
                            model.seqLengths: batchSeqLengths
                        }

                        _, l, pre, y, er = sess.run([
                            model.optimizer, model.loss, model.predictions,
                            model.targetY, model.errorRate
                        ],
                                                    feed_dict=feedDict)
                        batchErrors[batch] = er
                        print(
                            '\n{} mode, total:{},subdir:{}/{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                            .format(level, totalN, id_dir + 1,
                                    len(feature_dirs), batch + 1,
                                    len(batchRandIxs), epoch + 1, num_epochs,
                                    l, er / batch_size))
                        total_count += 1
                        if batch % 20 == 0:
                            print('Truth:\n' +
                                  output_to_sequence(y, type=level))
                            print('Output:\n' +
                                  output_to_sequence(pre, type=level))
                            checkpoint_path = os.path.join(
                                savedir, 'model.ckpt')
                            model.saver.save(sess,
                                             checkpoint_path,
                                             global_step=total_count)
                            print('Model has been saved in {}'.format(savedir))

                    end = time.time()
                    delta_time = end - start
                    print('Epoch ' + str(epoch + 1) + ' needs time:' +
                          str(delta_time) + ' s')
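
The run() examples all wrap their config in dotdict, a common convenience class; a minimal sketch of the usual implementation:

class dotdict(dict):
    # dict subclass with attribute-style access: d.key reads d['key']
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

With it, args.batch_size reads the same entry as args['batch_size'], which is how the examples above consume the config.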
Example #5
def main(_):
    print('%s mode...' % str(FLAGS.mode))
    savedir = os.path.join(FLAGS.exp_dir, FLAGS.level, 'save')
    resultdir = os.path.join(FLAGS.exp_dir, FLAGS.level, 'result')
    check_path_exists([savedir, resultdir])
    # load data
    hparams = {}
    hparams['level'] = FLAGS.level
    hparams['batch_size'] = FLAGS.batch_size
    hparams['partition_size'] = FLAGS.partition_size
    hparams['num_hidden'] = FLAGS.num_hidden
    hparams['feature_length'] = FLAGS.feature_length
    hparams['num_classes'] = FLAGS.num_classes
    hparams['num_proj'] = FLAGS.num_proj
    hparams['learning_rate'] = FLAGS.learning_rate
    hparams['keep_prob'] = FLAGS.keep_prob
    hparams['clip_gradient_norm'] = FLAGS.clip_gradient_norm
    hparams['use_peepholes'] = FLAGS.use_peepholes
    if FLAGS.activation == 'tanh':
        hparams['activation'] = tf.tanh
    elif FLAGS.activation == 'relu':
        hparams['activation'] = tf.nn.relu
    feature_dirs, label_dirs = get_data(FLAGS.input_data_dir, FLAGS.level,
                                        FLAGS.train_dataset, FLAGS.dev_dataset,
                                        FLAGS.test_dataset, FLAGS.mode)
    batched_data, max_time_steps, total_n = load_batched_data(
        feature_dirs[0], label_dirs[0], FLAGS.batch_size, FLAGS.level)
    hparams['max_time_steps'] = max_time_steps
    ## shuffle feature_dir and label_dir by same order
    FL_pair = list(zip(feature_dirs, label_dirs))
    random.shuffle(FL_pair)
    feature_dirs, label_dirs = zip(*FL_pair)
    train_writer = tf.summary.FileWriter(resultdir + '/train')
    test_writer = tf.summary.FileWriter(resultdir + '/test')

    for feature_dir, label_dir in zip(feature_dirs, label_dirs):
        id_dir = feature_dirs.index(feature_dir)
        print('dir id:{}'.format(id_dir))
        batched_data, max_time_steps, total_n = load_batched_data(
            feature_dir, label_dir, FLAGS.batch_size, FLAGS.level)
        hparams['max_time_steps'] = max_time_steps
        model = DRNN(FLAGS.cell, hparams, FLAGS.mode == 'train')

        with tf.Session(FLAGS.master) as sess:
            # restore from stored model
            if FLAGS.restore:
                ckpt = tf.train.get_checkpoint_state(savedir)
                if ckpt and ckpt.model_checkpoint_path:
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Model restored from:' + savedir)
            else:
                print('Initializing')
                sess.run(model.initial_op)

            for epoch in range(FLAGS.num_epochs):
                ## training
                start = time.time()
                if FLAGS.mode == 'train':
                    print('Epoch {} ...'.format(epoch + 1))

                batch_errors = np.zeros(len(batched_data))
                batched_random_idx = np.random.permutation(len(batched_data))

                for batch, batch_original_idx in enumerate(batched_random_idx):
                    batch_inputs, batch_target_sparse, batch_seq_length = batched_data[
                        batch_original_idx]
                    batch_tgt_idx, batch_tgt_vals, batch_tgt_shape = batch_target_sparse
                    feedDict = {
                        model.x: batch_inputs,
                        model.tgt_idx: batch_tgt_idx,
                        model.tgt_vals: batch_tgt_vals,
                        model.tgt_shape: batch_tgt_shape,
                        model.seq_length: batch_seq_length
                    }

                    if FLAGS.level == 'cha':
                        if FLAGS.mode == 'train':
                            _, l, pre, y, er = sess.run([
                                model.train_op, model.loss, model.predictions,
                                model.y, model.error_rate
                            ],
                                                        feed_dict=feedDict)

                            batch_errors[batch] = er
                            print(
                                '\n{} mode, total:{},subdir:{}/{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                                .format(FLAGS.level, total_n, id_dir + 1,
                                        len(feature_dirs), batch + 1,
                                        len(batched_random_idx), epoch + 1,
                                        FLAGS.num_epochs, l,
                                        er / FLAGS.batch_size))

                        elif FLAGS.mode == 'dev':
                            l, pre, y, er = sess.run([
                                model.loss, model.predictions, model.y,
                                model.error_rate
                            ],
                                                     feed_dict=feedDict)
                            batch_errors[batch] = er
                            print(
                                '\n{} mode, total:{},subdir:{}/{},batch:{}/{},dev loss={:.3f},mean dev CER={:.3f}\n'
                                .format(FLAGS.level, total_n, id_dir + 1,
                                        len(feature_dirs), batch + 1,
                                        len(batched_random_idx), l,
                                        er / FLAGS.batch_size))

                        elif FLAGS.mode == 'test':
                            l, pre, y, er = sess.run([
                                model.loss, model.predictions, model.y,
                                model.error_rate
                            ],
                                                     feed_dict=feedDict)
                            batch_errors[batch] = er
                            print(
                                '\n{} mode, total:{},subdir:{}/{},batch:{}/{},test loss={:.3f},mean test CER={:.3f}\n'
                                .format(FLAGS.level, total_n, id_dir + 1,
                                        len(feature_dirs), batch + 1,
                                        len(batched_random_idx), l,
                                        er / FLAGS.batch_size))
                    elif FLAGS.level == 'seq2seq':
                        raise ValueError('level %s is not supported now' %
                                         str(FLAGS.level))

                    # NOTE: stop this epoch early once the batch error rate
                    # saturates at 1.0
                    if er / FLAGS.batch_size == 1.0:
                        break

                    if batch % 20 == 0:
                        print('Truth:\n' +
                              output_to_sequence(y, level=FLAGS.level))
                        print('Output:\n' +
                              output_to_sequence(pre, level=FLAGS.level))

                    if FLAGS.mode == 'train' and (
                        (epoch * len(batched_random_idx) + batch + 1) % 20 == 0
                            or (epoch == FLAGS.num_epochs - 1
                                and batch == len(batched_random_idx) - 1)):
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=epoch)
                        print('Model has been saved in {}'.format(savedir))

                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' +
                      str(delta_time) + ' s')

                if FLAGS.mode == 'train':
                    if (epoch + 1) % 1 == 0:  # always true: checkpoint after every epoch
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=epoch)
                        print('Model has been saved in {}'.format(savedir))
                    epoch_er = batch_errors.sum() / total_n
                    print('Epoch', epoch + 1, 'mean train error rate:',
                          epoch_er)

                if FLAGS.mode == 'test' or FLAGS.mode == 'dev':
                    with open(
                            os.path.join(resultdir,
                                         FLAGS.level + '_result.txt'),
                            'a') as result:
                        result.write(
                            output_to_sequence(y, level=FLAGS.level) + '\n')
                        result.write(
                            output_to_sequence(pre, level=FLAGS.level) + '\n')
                        result.write('\n')
                    epoch_er = batch_errors.sum() / total_n
                    print(' test error rate:', epoch_er)
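
Finally, output_to_sequence turns a decoded tf.SparseTensorValue into readable text for the Truth/Output printouts. The real index-to-symbol tables live in the repository; the maps below are illustrative placeholders only:

def output_to_sequence(sparse_tensor, level='cha'):
    indices = list(sparse_tensor.values)
    if level == 'cha':
        # hypothetical map: 0 -> space, 1..26 -> 'a'..'z'
        return ''.join(' ' if i == 0 else chr(ord('a') + i - 1) for i in indices)
    if level == 'phn':
        # a real table would map ids to TIMIT phoneme labels
        return ' '.join(str(i) for i in indices)
    raise ValueError('unsupported level: %s' % level)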