def decode_test(outDimPdf=1968, outDimPho=48):

    global args

    if args.preModel == '':
        raise Exception("Expected Pretrained Model.")
    elif not os.path.isfile(args.preModel):
        raise Exception("No such file:{}.".format(args.preModel))

    print("\n############## Parameters Configure ##############")

    # Show configure information and write them to file
    def configLog(message, f):
        print(message)
        f.write(message + '\n')

    f = open(args.outDir + '/configure', "w")
    configLog(
        'Start System Time:{}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %X")), f)
    configLog('Host Name:{}'.format(socket.gethostname()), f)
    configLog('Fix Random Seed:{}'.format(args.randomSeed), f)
    configLog('GPU ID:{}'.format(args.gpu), f)
    configLog('Pretrained Model:{}'.format(args.preModel), f)
    configLog('Output Folder:{}'.format(args.outDir), f)
    configLog('Use CMVN:{}'.format(args.useCMVN), f)
    configLog('Splice N Frames:{}'.format(args.splice), f)
    configLog('Add N Deltas:{}'.format(args.delta), f)
    configLog('Normalize Chunk:{}'.format(args.normalizeChunk), f)
    configLog('Normalize AMP:{}'.format(args.normalizeAMP), f)
    configLog('Decode Minimum Active:{}'.format(args.minActive), f)
    configLog('Decode Maximum Active:{}'.format(args.maxActive), f)
    configLog('Decode Maximum Memory:{}'.format(args.maxMemory), f)
    configLog('Decode Beam:{}'.format(args.beam), f)
    configLog('Decode Lattice Beam:{}'.format(args.latBeam), f)
    configLog('Decode Acoustic Weight:{}'.format(args.acwt), f)
    configLog('Decode minimum Language Weight:{}'.format(args.minLmwt), f)
    configLog('Decode maximum Language Weight:{}'.format(args.maxLmwt), f)
    f.close()

    print("\n############## Decode Test ##############")

    #------------------ STEP 1: Load Pretrained Model ------------------

    print('Load Model...')
    # Initialize model
    featDim = 40
    if args.delta > 0:
        featDim *= (args.delta + 1)
    if args.splice > 0:
        featDim *= (2 * args.splice + 1)
    model = MLP(featDim, outDimPdf, outDimPho)
    chainer.serializers.load_npz(args.preModel, model)
    if args.gpu >= 0:
        model.to_gpu(args.gpu)

    #------------------ STEP 2: Prepare Test Data ------------------

    print('Prepare decode test data...')
    # Fmllr file
    testFilePath = args.TIMITpath + '/data-fmllr-tri3/test/feats.scp'
    testFeat = E.load(testFilePath)
    # Use CMVN
    if args.useCMVN:
        testUttSpk = args.TIMITpath + '/data-fmllr-tri3/test/utt2spk'
        testCmvnState = args.TIMITpath + '/data-fmllr-tri3/test/cmvn.ark'
        testFeat = E.use_cmvn(testFeat, testCmvnState, testUttSpk)
    # Add delta
    if args.delta > 0:
        testFeat = E.add_delta(testFeat, args.delta)
    # Splice frames
    if args.splice > 0:
        testFeat = testFeat.splice(args.splice)
    # Transform to array
    testFeat = testFeat.array
    # Normalize
    if args.normalizeChunk:
        testFeat = testFeat.normalize()
    # Normalize acoustic model output
    if args.normalizeAMP:
        # Compute pdf counts in order to normalize acoustic model posterior probability.
        countFile = args.outDir + '/pdfs_counts.txt'
        # Get statistics file
        if not os.path.isfile(countFile):
            trainAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali/ali.*.gz'
            _ = E.analyze_counts(aliFile=trainAliFile, outFile=countFile)
        with open(countFile) as f:
            line = f.readline().strip().strip("[]").strip()
        # Get AMP bias value
        counts = np.array(list(map(float, line.split())), dtype=np.float32)
        normalizeBias = np.log(counts / np.sum(counts))
    else:
        normalizeBias = 0
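    # Note: subtracting the log class priors (normalizeBias) from the network's
    # log posteriors yields scaled log likelihoods, which is what the Kaldi
    # HMM decoder expects as acoustic scores.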

    #------------------ STEP 3: Decode  ------------------

    temp = E.KaldiDict()
    print('Compute Test WER: Forward network', end=" " * 20 + '\r')
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        for utt in testFeat.keys():
            data = cp.array(testFeat[utt], dtype=cp.float32)
            out1, out2 = model(data)
            out = F.log_softmax(out1, axis=1)
            out.to_cpu()
            temp[utt] = out.array - normalizeBias
    # Transform KaldiDict to KaldiArk format
    print('Compute Test WER: Transform to ark', end=" " * 20 + '\r')
    amp = temp.ark
    # Decode and obtain lattice
    hmm = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_test/final.mdl'
    hclg = args.TIMITpath + '/exp/tri3/graph/HCLG.fst'
    lexicon = args.TIMITpath + '/exp/tri3/graph/words.txt'
    print('Compute Test WER: Generate Lattice', end=" " * 20 + '\r')
    lattice = E.decode_lattice(amp, hmm, hclg, lexicon, args.minActive,
                               args.maxActive, args.maxMemory, args.beam,
                               args.latBeam, args.acwt)
    # Sweep the language-model weight from minLmwt to maxLmwt and get the 1-best words.
    print('Compute Test WER: Get 1Best', end=" " * 20 + '\r')
    outs = lattice.get_1best(lmwt=args.minLmwt,
                             maxLmwt=args.maxLmwt,
                             outFile=args.outDir + '/outRaw.txt')

    #------------------ STEP 4: Score  ------------------

    # If the reference file does not exist, create it.
    phonemap = args.TIMITpath + '/conf/phones.60-48-39.map'
    outFilter = args.TIMITpath + '/local/timit_norm_trans.pl -i - -m {} -from 48 -to 39'.format(
        phonemap)
    if not os.path.isfile(args.outDir + '/test_filt.txt'):
        refText = args.TIMITpath + '/data/test/text'
        cmd = 'cat {} | {} > {}/test_filt.txt'.format(refText, outFilter,
                                                      args.outDir)
        (_, _) = E.run_shell_cmd(cmd)
    # Score WER and find the smallest one.
    print('Compute Test WER: compute WER', end=" " * 20 + '\r')
    minWER = (None, None)
    for k in range(args.minLmwt, args.maxLmwt + 1, 1):
        cmd = 'cat {} | {} > {}/translation_{}.txt'.format(
            outs[k], outFilter, args.outDir, k)
        (_, _) = E.run_shell_cmd(cmd)
        os.remove(outs[k])
        score = E.wer('{}/test_filt.txt'.format(args.outDir),
                      "{}/translation_{}.txt".format(args.outDir, k),
                      mode='all')
        if minWER[0] is None or score['WER'] < minWER[0]:
            minWER = (score['WER'], k)

    # Report the best WER and the language-model weight that produced it.
    print("Best WER:{}% at {}/translation_{}.txt".format(
        minWER[0], args.outDir, minWER[1]))
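# The MLP class instantiated above (and in train_model below) is not defined in
# this excerpt. The following is a minimal sketch, assuming a Chainer model with
# a compatible interface: it maps the spliced feature vectors to two heads of
# logits (senone pdfs and monophones). The hidden sizes and dropout placement
# are illustrative assumptions, not the original architecture.
import chainer
import chainer.functions as F
import chainer.links as L

class MLP(chainer.Chain):
    def __init__(self, featDim, outDimPdf, outDimPho, hiddenDim=1024, dropout=0.2):
        super(MLP, self).__init__()
        self.dropout = dropout
        with self.init_scope():
            self.l1 = L.Linear(featDim, hiddenDim)
            self.l2 = L.Linear(hiddenDim, hiddenDim)
            self.outPdf = L.Linear(hiddenDim, outDimPdf)
            self.outPho = L.Linear(hiddenDim, outDimPho)

    def __call__(self, x):
        # F.dropout is only active while chainer.config.train is True.
        h = F.dropout(F.relu(self.l1(x)), ratio=self.dropout)
        h = F.dropout(F.relu(self.l2(h)), ratio=self.dropout)
        return self.outPdf(h), self.outPho(h)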
def train_model():

    global args

    print("\n############## Parameters Configure ##############")

    # Show configure information and write them to file
    def configLog(message, f):
        print(message)
        f.write(message + '\n')

    f = open(args.outDir + '/configure', "w")
    configLog(
        'Start System Time:{}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %X")), f)
    configLog('Host Name:{}'.format(socket.gethostname()), f)
    configLog('Fix Random Seed:{}'.format(args.randomSeed), f)
    configLog('Mini Batch Size:{}'.format(args.batchSize), f)
    configLog('GPU ID:{}'.format(args.gpu), f)
    configLog('Train Epochs:{}'.format(args.epoch), f)
    configLog('Output Folder:{}'.format(args.outDir), f)
    configLog('Dropout Rate:{}'.format(args.dropout), f)
    configLog('Use CMVN:{}'.format(args.useCMVN), f)
    configLog('Splice N Frames:{}'.format(args.splice), f)
    configLog('Add N Deltas:{}'.format(args.delta), f)
    configLog('Normalize Chunk:{}'.format(args.normalizeChunk), f)
    configLog('Normalize AMP:{}'.format(args.normalizeAMP), f)
    configLog('Decode Minimum Active:{}'.format(args.minActive), f)
    configLog('Decode Maximum Active:{}'.format(args.maxActive), f)
    configLog('Decode Maximum Memory:{}'.format(args.maxMemory), f)
    configLog('Decode Beam:{}'.format(args.beam), f)
    configLog('Decode Lattice Beam:{}'.format(args.latBeam), f)
    configLog('Decode Acoustic Weight:{}'.format(args.acwt), f)
    configLog('Decode minimum Language Weight:{}'.format(args.minLmwt), f)
    configLog('Decode maximum Language Weight:{}'.format(args.maxLmwt), f)
    f.close()

    print("\n############## Train DNN Acoustic Model ##############")

    #------------------------ STEP 1: Prepare Training and Validation Data -----------------------------

    print('Prepare Data Iterator...')
    # Prepare fMLLR feature files
    trainScpFile = args.TIMITpath + '/data-fmllr-tri3/train/feats.scp'
    devScpFile = args.TIMITpath + '/data-fmllr-tri3/dev/feats.scp'
    # Prepare training labels (alignment data)
    trainAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali/ali.*.gz'
    trainLabelPdf = E.get_ali(trainAliFile)
    trainLabelPho = E.get_ali(trainAliFile, returnPhone=True)
    for i in trainLabelPho.keys():
        trainLabelPho[i] = trainLabelPho[i] - 1
    # Prepare validation labels (alignment data)
    devAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_dev/ali.*.gz'
    devLabelPdf = E.get_ali(devAliFile)
    devLabelPho = E.get_ali(devAliFile, returnPhone=True)
    for i in devLabelPho.keys():
        devLabelPho[i] = devLabelPho[i] - 1
    # prepare CMVN files
    trainUttSpk = args.TIMITpath + '/data-fmllr-tri3/train/utt2spk'
    trainCmvnState = args.TIMITpath + '/data-fmllr-tri3/train/cmvn.ark'
    devUttSpk = args.TIMITpath + '/data-fmllr-tri3/dev/utt2spk'
    devCmvnState = args.TIMITpath + '/data-fmllr-tri3/dev/cmvn.ark'

    # Now we try to make training-iterator and validation-iterator.
    # Firstly, customize a function to process feature data.
    def loadChunkData(iterator, feat, otherArgs):
        # <feat> is a KaldiArk object
        global args
        uttSpk, cmvnState, labelPdf, labelPho = otherArgs
        # use CMVN
        if args.useCMVN:
            feat = E.use_cmvn(feat, cmvnState, uttSpk)
        # Add delta
        if args.delta > 0:
            feat = E.add_delta(feat, args.delta)
        # Splice front-back n frames
        if args.splice > 0:
            feat = feat.splice(args.splice)
        # Transform to KaldiDict
        feat = feat.array
        # Normalize
        if args.normalizeChunk:
            feat = feat.normalize()
        # Concatenate data-label in dimension
        datas = feat.concat([labelPdf, labelPho], axis=1)
        # Transform to trainable numpy data
        datas, _ = datas.merge(keepDim=False, sort=False)
        return datas
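    # Note: each merged row has the layout [feature vector | pdf label | phone label];
    # the convert() helper defined in STEP 4 splits these columns apart again.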

    # Then get data iterators
    train = E.DataIterator(trainScpFile,
                           loadChunkData,
                           args.batchSize,
                           chunks=5,
                           shuffle=True,
                           otherArgs=(trainUttSpk, trainCmvnState,
                                      trainLabelPdf, trainLabelPho))
    print('Generate train dataset done. Chunks:{} / Batch size:{}'.format(
        train.chunks, train.batchSize))
    dev = E.DataIterator(devScpFile,
                         loadChunkData,
                         args.batchSize,
                         chunks=1,
                         shuffle=False,
                         otherArgs=(devUttSpk, devCmvnState, devLabelPdf,
                                    devLabelPho))
    print(
        'Generate validation dataset done. Chunks:{} / Batch size:{}.'.format(
            dev.chunks, dev.batchSize))

    #--------------------------------- STEP 2: Prepare Model --------------------------

    print('Prepare Model...')
    # Initialize model
    featDim = 40
    if args.delta > 0:
        featDim *= (args.delta + 1)
    if args.splice > 0:
        featDim *= (2 * args.splice + 1)
    model = MLP(featDim, trainLabelPdf.target, trainLabelPho.target)
    if args.gpu >= 0:
        model.to_gpu(args.gpu)
    # Initialize optimizer
    lr = [(0, 0.08), (10, 0.04), (15, 0.02), (17, 0.01), (19, 0.005),
          (22, 0.0025), (25, 0.001)]
    print('Learning Rate (epoch,newLR):', lr)
    optimizer = chainer.optimizers.MomentumSGD(lr[0][1], momentum=0.0)
    lr.pop(0)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(0.0))
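    # Note: each (epoch, newLR) pair remaining in <lr> is applied at the end of
    # the epoch loop below once the epoch counter reaches that value, then popped.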
    # Prepare a supporter to help handling training information.
    supporter = E.Supporter(args.outDir)

    #------------------ STEP 3: Prepare Decoding Test Data and Function ------------------

    print('Prepare decoding test data...')
    # fMLLR file of test data
    testFilePath = args.TIMITpath + '/data-fmllr-tri3/test/feats.scp'
    testFeat = E.load(testFilePath)
    # Using CMVN
    if args.useCMVN:
        testUttSpk = args.TIMITpath + '/data-fmllr-tri3/test/utt2spk'
        testCmvnState = args.TIMITpath + '/data-fmllr-tri3/test/cmvn.ark'
        testFeat = E.use_cmvn(testFeat, testCmvnState, testUttSpk)
    # Add delta
    if args.delta > 0:
        testFeat = E.add_delta(testFeat, args.delta)
    # Splice frames
    if args.splice > 0:
        testFeat = testFeat.splice(args.splice)
    # Transform to array
    testFeat = testFeat.array
    # Normalize
    if args.normalizeChunk:
        testFeat = testFeat.normalize()
    # Normalize acoustic model output
    if args.normalizeAMP:
        # Compute pdf counts in order to normalize acoustic model posterior probability.
        countFile = args.outDir + '/pdfs_counts.txt'
        # Get statistics file
        if not os.path.isfile(countFile):
            _ = E.analyze_counts(aliFile=trainAliFile, outFile=countFile)
        with open(countFile) as f:
            line = f.readline().strip().strip("[]").strip()
        # Get AMP bias value
        counts = np.array(list(map(float, line.split())), dtype=np.float32)
        normalizeBias = np.log(counts / np.sum(counts))
    else:
        normalizeBias = 0
    # Now, design a function to compute WER score
    def wer_fun(model, testFeat, normalizeBias):
        global args
        # Use decode test data to forward network
        temp = E.KaldiDict()
        print('(testing) Forward network', end=" " * 20 + '\r')
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            for utt in testFeat.keys():
                data = cp.array(testFeat[utt], dtype=cp.float32)
                out1, out2 = model(data)
                out = F.log_softmax(out1, axis=1)
                out.to_cpu()
                temp[utt] = out.array - normalizeBias
        # Transform KaldiDict to KaldiArk format
        print('(testing) Transform to ark', end=" " * 20 + '\r')
        amp = temp.ark
        # Decoding to obtain a lattice
        hmm = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_test/final.mdl'
        hclg = args.TIMITpath + '/exp/tri3/graph/HCLG.fst'
        lexicon = args.TIMITpath + '/exp/tri3/graph/words.txt'
        print('(testing) Generate Lattice', end=" " * 20 + '\r')
        lattice = E.decode_lattice(amp, hmm, hclg, lexicon, args.minActive,
                                   args.maxActive, args.maxMemory, args.beam,
                                   args.latBeam, args.acwt)
        # Sweep the language-model weight from minLmwt to maxLmwt and get the 1-best words.
        print('(testing) Get 1-best words', end=" " * 20 + '\r')
        outs = lattice.get_1best(lmwt=args.minLmwt,
                                 maxLmwt=args.maxLmwt,
                                 outFile=args.outDir + '/outRaw.txt')
        # If the reference file does not exist, create it.
        phonemap = args.TIMITpath + '/conf/phones.60-48-39.map'
        outFilter = args.TIMITpath + '/local/timit_norm_trans.pl -i - -m {} -from 48 -to 39'.format(
            phonemap)
        if not os.path.isfile(args.outDir + '/test_filt.txt'):
            refText = args.TIMITpath + '/data/test/text'
            cmd = 'cat {} | {} > {}/test_filt.txt'.format(
                refText, outFilter, args.outDir)
            (_, _) = E.run_shell_cmd(cmd)
        # Score WER and find the smallest one.
        print('(testing) Score', end=" " * 20 + '\r')
        minWER = None
        for k in range(args.minLmwt, args.maxLmwt + 1, 1):
            cmd = 'cat {} | {} > {}/test_prediction_filt.txt'.format(
                outs[k], outFilter, args.outDir)
            (_, _) = E.run_shell_cmd(cmd)
            os.remove(outs[k])
            score = E.wer('{}/test_filt.txt'.format(args.outDir),
                          "{}/test_prediction_filt.txt".format(args.outDir),
                          mode='all')
            if minWER is None or score['WER'] < minWER:
                minWER = score['WER']
        return minWER

    #-------------------------- STEP 4: Train Model ---------------------------

    # During the first epoch, the epoch size is computed gradually, so the progress information will be inaccurate.
    print('Now Start to Train')
    print('Note: during the first epoch, the total data size is computed gradually.')
    print('Note: the WER of the test dataset is evaluated after each epoch, which costs a few seconds.')

    # Preprocess the batch data obtained from the data iterator
    def convert(batch):
        batch = cp.array(batch, dtype=cp.float32)
        data = batch[:, 0:-2]
        label1 = cp.array(batch[:, -2], dtype=cp.int32)
        label2 = cp.array(batch[:, -1], dtype=cp.int32)
        return data, label1, label2

    # We will save model during training loop, so prepare a model-save function
    def saveFunc(fileName, model):
        global args
        copymodel = model.copy()
        if args.gpu >= 0:
            copymodel.to_cpu()
        chainer.serializers.save_npz(fileName, copymodel)

    # Start training loop
    for e in range(args.epoch):
        supporter.send_report({'epoch': e})
        print()
        i = 1
        usedTime = 0
        # Train
        while True:
            start = time.time()
            # Get data >> Forward network >> Loss back propagation >> Update
            batch = train.next()
            data, label1, label2 = convert(batch)
            with chainer.using_config('train', True):
                h1, h2 = model(data)
                L1 = F.softmax_cross_entropy(h1, label1)
                L2 = F.softmax_cross_entropy(h2, label2)
                loss = L1 + L2
                acc = F.accuracy(F.softmax(h1, axis=1), label1)
            model.cleargrads()
            loss.backward()
            optimizer.update()
            # Compute time cost
            ut = time.time() - start
            usedTime += ut
            print(
                "(training) Epoch:{}>>>{}% Chunk:{}>>>{}% Iter:{} Used-time:{}s Batch-loss:{} Speed:{}iters/s"
                .format(e, int(100 * train.epochProgress), train.chunk,
                        int(100 * train.chunkProgress), i, int(usedTime),
                        "%.4f" % (float(loss.array)), "%.2f" % (1 / ut)),
                end=" " * 5 + '\r')
            i += 1
            supporter.send_report({'train_loss': loss, 'train_acc': acc})
            # If forward all data, break
            if train.isNewEpoch:
                break
        print()
        i = 1
        usedTime = 0
        # Validate
        while True:
            start = time.time()
            # Get data >> Forward network >> Score
            batch = dev.next()
            data, label1, label2 = convert(batch)
            with chainer.using_config('train',
                                      False), chainer.no_backprop_mode():
                h1, h2 = model(data)
                loss = F.softmax_cross_entropy(h1, label1)
                acc = F.accuracy(F.softmax(h1, axis=1), label1)
            # Compute time cost
            ut = time.time() - start
            usedTime += ut
            print(
                "(Validating) Epoch:{}>>>{}% Chunk:{}>>>{}% Iter:{} Used-time:{}s Batch-loss:{} Speed:{}iters/s"
                .format(e, int(100 * dev.epochProgress), dev.chunk,
                        int(100 * dev.chunkProgress), i, int(usedTime),
                        "%.4f" % (float(loss.array)), "%.2f" % (1 / ut)),
                end=" " * 5 + '\r')
            i += 1
            supporter.send_report({'dev_loss': loss, 'dev_acc': acc})
            # If forward all data, break
            if dev.isNewEpoch:
                break
        print()
        # Compute WER score
        WERscore = wer_fun(model, testFeat, normalizeBias)
        supporter.send_report({'test_wer': WERscore, 'lr': optimizer.lr})
        # Collect all the information reported during this epoch and display it
        supporter.collect_report(plot=True)
        # Save model
        supporter.save_model(saveFunc, models={'MLP': model})
        # Change learning rate
        if len(lr) > 0 and supporter.judge('epoch', '>=', lr[0][0]):
            optimizer.lr = lr[0][1]
            lr.pop(0)

    print("DNN Acoustic Model training done.")
    print("The final model has been saved as:", supporter.finalModel)
    print('Over System Time:', datetime.datetime.now().strftime("%Y-%m-%d %X"))
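# A minimal sketch of how these two functions might be invoked (hypothetical;
# the real script builds the global <args> object with an argument parser that
# is not shown in this excerpt):
if __name__ == '__main__':
    if args.preModel != '':
        decode_test()    # decode only, using a pretrained model
    else:
        train_model()    # train a new DNN acoustic model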
def main():

    # ------------- Parse arguments from command line ----------------------
    # 1. Add a description of this program
    args.describe("This program is used to train monophone GMM-HMM model")
    # 2. Add options
    args.add("--expDir",
             abbr="-e",
             dtype=str,
             default="exp",
             discription="The data and output path of current experiment.")
    args.add("--delta",
             abbr="-d",
             dtype=int,
             default=2,
             discription="Add n-order to feature.")
    args.add("--numIters",
             abbr="-n",
             dtype=int,
             default=40,
             discription="How many iterations to train.")
    args.add("--maxIterInc",
             abbr="-m",
             dtype=int,
             default=30,
             discription="The final iteration of increasing gaussians.")
    args.add("--realignIter",
             abbr="-r",
             dtype=int,
             default=[
                 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 23, 26, 29,
                 32, 35, 38
             ],
             discription="the iteration to realign feature.")
    args.add("--order",
             abbr="-o",
             dtype=int,
             default=6,
             minV=1,
             maxV=6,
             discription="Which N-grams model to use.")
    args.add("--beam",
             abbr="-b",
             dtype=int,
             default=13,
             discription="Decode beam size.")
    args.add("--latBeam",
             abbr="-l",
             dtype=int,
             default=6,
             discription="Lattice beam size.")
    args.add("--acwt",
             abbr="-a",
             dtype=float,
             default=0.083333,
             discription="Acoustic model weight.")
    args.add(
        "--parallel",
        abbr="-p",
        dtype=int,
        default=4,
        minV=1,
        maxV=10,
        discription=
        "The number of parallel processes used to compute features of the train dataset.")
    args.add("--skipTrain",
             abbr="-s",
             dtype=bool,
             default=False,
             discription="If True, skip training. Do decoding only.")
    # 3. Then start to parse arguments.
    args.parse()
    # 4. Take a backup of arguments
    args.print_args()  # print arguments to display
    argsLogFile = os.path.join(args.expDir, "conf", "train_mono.args")
    args.save(argsLogFile)

    if not args.skipTrain:
        # ------------- Prepare feature for training ----------------------
        # 1. Load the feature for training (We use the index table format)
        feat = exkaldi.load_index_table(
            os.path.join(args.expDir, "mfcc", "train", "mfcc_cmvn.ark"))
        print(f"Load MFCC+CMVN feature.")
        feat = exkaldi.add_delta(feat,
                                 order=args.delta,
                                 outFile=os.path.join(args.expDir,
                                                      "train_mono",
                                                      "mfcc_cmvn_delta.ark"))
        print(f"Add {args.delta}-order deltas.")
        # 2. Load lexicon bank
        lexicons = exkaldi.load_lex(
            os.path.join(args.expDir, "dict", "lexicons.lex"))
        print(f"Restorage lexicon bank.")

        # ------------- Start training ----------------------
        # 1. Initialize a monophone HMM object
        model = exkaldi.hmm.MonophoneHMM(lexicons=lexicons, name="mono")
        model.initialize(feat=feat,
                         topoFile=os.path.join(args.expDir, "dict", "topo"))
        print(f"Initialized a monophone HMM-GMM model: {model.info}.")

        # 2. Split data for parallel training
        transcription = exkaldi.load_transcription(
            os.path.join(args.expDir, "data", "train", "text"))
        transcription = transcription.sort()
        if args.parallel > 1:
            # split feature
            feat = feat.sort(by="utt").subset(chunks=args.parallel)
            # split transcription depending on utterance IDs of each feature
            temp = []
            for f in feat:
                temp.append(transcription.subset(keys=f.utts))
            transcription = temp

        # 3. Train
        model.train(
            feat,
            transcription,
            LFile=os.path.join(args.expDir, "dict", "L.fst"),
            tempDir=os.path.join(args.expDir, "train_mono"),
            numIters=args.numIters,
            maxIterInc=args.maxIterInc,
            totgauss=1000,
            realignIter=args.realignIter,
            boostSilence=1.0,
        )
        print(model.info)
        # Save the tree
        model.tree.save(os.path.join(args.expDir, "train_mono", "tree"))
        print(f"Tree has been saved.")

        # 4. Realign with boostSilence 1.25
        print("Realign the training feature (boost silence = 1.25)")
        trainGraphFiles = exkaldi.utils.list_files(
            os.path.join(args.expDir, "train_mono", "*train_graph"))
        model.align(
            feat,
            trainGraphFile=trainGraphFiles,  # train graphs have been generated in the train step
            boostSilence=1.25,  # or 1.5
            outFile=os.path.join(args.expDir, "train_mono", "final.ali"))
        del feat
        print("Save the new alignment done.")
        tree = model.tree

    else:
        declare.is_file(os.path.join(args.expDir, "train_mono", "final.mdl"))
        declare.is_file(os.path.join(args.expDir, "train_mono", "tree"))
        model = exkaldi.load_hmm(
            os.path.join(args.expDir, "train_mono", "final.mdl"))
        tree = exkaldi.load_tree(
            os.path.join(args.expDir, "train_mono", "tree"))

    # ------------- Compile WFST training ----------------------
    # Make a WFST decoding graph
    make_WFST_graph(
        outDir=os.path.join(args.expDir, "train_mono", "graph"),
        hmm=model,
        tree=tree,
    )

    # Decode test data
    GMM_decode_mfcc_and_score(
        outDir=os.path.join(args.expDir, "train_mono",
                            f"decode_{args.order}grams"),
        hmm=model,
        HCLGfile=os.path.join(args.expDir, "train_mono", "graph",
                              f"HCLG.{args.order}.fst"),
    )
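# For reference, a run of this script with the options declared above might look
# like the following (hypothetical script name and option values):
#
#   python train_mono.py --expDir exp --numIters 40 --order 6 --parallel 4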
def main():

    # ------------- Parse arguments from command line ----------------------
    # 1. Add a description of this program
    args.discribe("This program is used to train triphone GMM-HMM model")
    # 2. Add options
    args.add("--expDir",
             abbr="-e",
             dtype=str,
             default="exp",
             discription="The data and output path of current experiment.")
    args.add("--delta",
             abbr="-d",
             dtype=int,
             default=2,
             discription="Add n-order to feature.")
    args.add("--numIters",
             abbr="-n",
             dtype=int,
             default=35,
             discription="How many iterations to train.")
    args.add("--maxIterInc",
             abbr="-m",
             dtype=int,
             default=25,
             discription="The final iteration of increasing gaussians.")
    args.add("--realignIter",
             abbr="-r",
             dtype=int,
             default=[10, 20, 30],
             discription="the iteration to realign feature.")
    args.add("--order",
             abbr="-o",
             dtype=int,
             default=6,
             discription="Which N-grams model to use.")
    args.add("--beam",
             abbr="-b",
             dtype=int,
             default=13,
             discription="Decode beam size.")
    args.add("--latBeam",
             abbr="-l",
             dtype=int,
             default=6,
             discription="Lattice beam size.")
    args.add("--acwt",
             abbr="-a",
             dtype=float,
             default=0.083333,
             discription="Acoustic model weight.")
    args.add(
        "--parallel",
        abbr="-p",
        dtype=int,
        default=4,
        minV=1,
        maxV=10,
        discription=
        "The number of parallel processes used to compute features of the train dataset.")
    args.add("--skipTrain",
             abbr="-s",
             dtype=bool,
             default=False,
             discription="If True, skip training. Do decoding only.")
    # 3. Then start to parse arguments.
    args.parse()
    # 4. Take a backup of arguments
    argsLogFile = os.path.join(args.expDir, "conf", "train_delta.args")
    args.save(argsLogFile)

    if not args.skipTrain:
        # ------------- Prepare feature and previous alignment for training ----------------------
        # 1. Load the feature for training
        feat = exkaldi.load_index_table(
            os.path.join(args.expDir, "mfcc", "train", "mfcc_cmvn.ark"))
        print(f"Load MFCC+CMVN feature.")
        feat = exkaldi.add_delta(feat,
                                 order=args.delta,
                                 outFile=os.path.join(args.expDir,
                                                      "train_delta",
                                                      "mfcc_cmvn_delta.ark"))
        print(f"Add {args.delta}-order deltas.")
        # 2. Load lexicon bank
        lexicons = exkaldi.load_lex(
            os.path.join(args.expDir, "dict", "lexicons.lex"))
        print(f"Restorage lexicon bank.")
        # 3. Load previous alignment
        ali = exkaldi.load_index_table(os.path.join(args.expDir, "train_mono",
                                                    "*final.ali"),
                                       useSuffix="ark")

        # -------------- Build the decision tree ------------------------
        print("Start build a tree")
        tree = exkaldi.hmm.DecisionTree(lexicons=lexicons,
                                        contextWidth=3,
                                        centralPosition=1)
        tree.train(
            feat=feat,
            hmm=os.path.join(args.expDir, "train_mono", "final.mdl"),
            ali=ali,
            topoFile=os.path.join(args.expDir, "dict", "topo"),
            numLeaves=2500,
            tempDir=os.path.join(args.expDir, "train_delta"),
        )
        print(f"Build tree done.")

        # ------------- Start training ----------------------
        # 1. Initialize a triphone HMM object
        model = exkaldi.hmm.TriphoneHMM(lexicons=lexicons, name="mono")
        model.initialize(
            tree=tree,
            topoFile=os.path.join(args.expDir, "dict", "topo"),
            treeStatsFile=os.path.join(args.expDir, "train_delta",
                                       "treeStats.acc"),
        )
        print(f"Initialized a monophone HMM-GMM model: {model.info}.")

        # 2. Convert the previous alignment
        print("Transform the previous alignment")
        newAli = exkaldi.hmm.convert_alignment(
            ali=ali,
            originHmm=os.path.join("exp", "train_mono", "final.mdl"),
            targetHmm=model,
            tree=tree,
            outFile=os.path.join(args.expDir, "train_delta", "initial.ali"),
        )

        # 3. Split data for parallel training
        transcription = exkaldi.load_transcription(
            os.path.join(args.expDir, "data", "train", "text"))
        transcription = transcription.sort()
        if args.parallel > 1:
            # split feature
            feat = feat.sort(by="utt").subset(chunks=args.parallel)
            # split transcription depending on utterance IDs of each feat
            tempTrans = []
            tempAli = []
            for f in feat:
                tempTrans.append(transcription.subset(keys=f.utts))
                tempAli.append(newAli.subset(keys=f.utts))
            transcription = tempTrans
            newAli = tempAli

        # 4. Train
        print("Train the triphone model")
        model.train(
            feat,
            transcription,
            os.path.join("exp", "dict", "L.fst"),
            tree,
            tempDir=os.path.join(args.expDir, "train_delta"),
            initialAli=newAli,
            numIters=args.numIters,
            maxIterInc=args.maxIterInc,
            totgauss=15000,
            realignIter=args.realignIter,
            boostSilence=1.0,
        )
        print(model.info)
        # Save the tree
        model.tree.save(os.path.join(args.expDir, "train_delta", "tree"))
        print(f"Tree has been saved.")
        del feat

    else:
        declare.is_file(os.path.join(args.expDir, "train_delta", "final.mdl"))
        declare.is_file(os.path.join(args.expDir, "train_delta", "tree"))
        model = exkaldi.load_hmm(
            os.path.join(args.expDir, "train_delta", "final.mdl"))
        tree = exkaldi.load_tree(
            os.path.join(args.expDir, "train_delta", "tree"))

    # ------------- Compile WFST training ----------------------
    # Make a WFST decoding graph
    make_WFST_graph(
        outDir=os.path.join(args.expDir, "train_delta", "graph"),
        hmm=model,
        tree=tree,
    )
    # Decode test data
    GMM_decode_mfcc_and_score(
        outDir=os.path.join(args.expDir, "train_delta",
                            f"decode_{args.order}grams"),
        hmm=model,
        HCLGfile=os.path.join(args.expDir, "train_delta", "graph",
                              f"HCLG.{args.order}.fst"),
    )
def train_model():

    global args

    print("\n############## Parameters Configure ##############")

    # Show configure information and write them to file
    def configLog(message, f):
        print(message)
        f.write(message + '\n')

    f = open(args.outDir + '/configure', "w")
    configLog(
        'Start System Time:{}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %X")), f)
    configLog('Host Name:{}'.format(socket.gethostname()), f)
    configLog('Fix Random Seed:{}'.format(args.randomSeed), f)
    configLog('Mini Batch Size:{}'.format(args.batchSize), f)
    configLog('GPU ID:{}'.format(args.gpu), f)
    configLog('Train Epochs:{}'.format(args.epoch), f)
    configLog('Output Folder:{}'.format(args.outDir), f)
    configLog('GRU layers:{}'.format(args.layer), f)
    configLog('GRU hidden nodes:{}'.format(args.hiddenNode), f)
    configLog('GRU dropout:{}'.format(args.dropout), f)
    configLog('Use CMVN:{}'.format(args.useCMVN), f)
    configLog('Splice N Frames:{}'.format(args.splice), f)
    configLog('Add N Deltas:{}'.format(args.delta), f)
    configLog('Normalize Chunk:{}'.format(args.normalizeChunk), f)
    configLog('Normalize AMP:{}'.format(args.normalizeAMP), f)
    configLog('Decode Minimum Active:{}'.format(args.minActive), f)
    configLog('Decode Maximum Active:{}'.format(args.maxActive), f)
    configLog('Decode Maximum Memory:{}'.format(args.maxMemory), f)
    configLog('Decode Beam:{}'.format(args.beam), f)
    configLog('Decode Lattice Beam:{}'.format(args.latBeam), f)
    configLog('Decode Acoustic Weight:{}'.format(args.acwt), f)
    configLog('Decode minimum Language Weight:{}'.format(args.minLmwt), f)
    configLog('Decode maximum Language Weight:{}'.format(args.maxLmwt), f)
    f.close()

    print("\n############## Train GRU Acoustic Model ##############")

    #----------------- STEP 1: Prepare Train Data -----------------
    print('Prepare data iterator...')
    # Fmllr data file
    trainScpFile = args.TIMITpath + '/data-fmllr-tri3/train/feats.scp'
    devScpFile = args.TIMITpath + '/data-fmllr-tri3/dev/feats.scp'
    # Alignment label file
    trainAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali/ali.*.gz'
    devAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_dev/ali.*.gz'
    # Load label
    trainLabelPdf = E.load_ali(trainAliFile)
    trainLabelPho = E.load_ali(trainAliFile, returnPhone=True)
    for i in trainLabelPho.keys():
        trainLabelPho[i] = trainLabelPho[i] - 1
    devLabelPdf = E.load_ali(devAliFile)
    devLabelPho = E.load_ali(devAliFile, returnPhone=True)
    for i in devLabelPho.keys():
        devLabelPho[i] = devLabelPho[i] - 1
    # CMVN file
    trainUttSpk = args.TIMITpath + '/data-fmllr-tri3/train/utt2spk'
    trainCmvnState = args.TIMITpath + '/data-fmllr-tri3/train/cmvn.ark'
    devUttSpk = args.TIMITpath + '/data-fmllr-tri3/dev/utt2spk'
    devCmvnState = args.TIMITpath + '/data-fmllr-tri3/dev/cmvn.ark'

    # Design a process function
    def loadChunkData(iterator, feat, otherArgs):
        # <feat> is KaldiArk object
        global args
        uttSpk, cmvnState, labelPdf, labelPho, toDo = otherArgs
        # use CMVN
        if args.useCMVN:
            feat = E.use_cmvn(feat, cmvnState, uttSpk)
        # Add delta
        if args.delta > 0:
            feat = E.add_delta(feat, args.delta)
        # Splice front-back n frames
        if args.splice > 0:
            feat = feat.splice(args.splice)
        # Transform to KaldiDict and sort them by frame length
        feat = feat.array.sort(by='frame')
        # Normalize
        if args.normalizeChunk:
            feat = feat.normalize()
        # Concatenate label
        datas = feat.concat([labelPdf, labelPho], axis=1)
        # cut frames
        if toDo == 'train':
            if iterator.epoch >= 4:
                datas = datas.cut(1000)
            elif iterator.epoch >= 3:
                datas = datas.cut(800)
            elif iterator.epoch >= 2:
                datas = datas.cut(400)
            elif iterator.epoch >= 1:
                datas = datas.cut(200)
            elif iterator.epoch >= 0:
                datas = datas.cut(100)
        # Transform trainable numpy data
        datas, _ = datas.merge(keepDim=True, sortFrame=True)
        return datas
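    # Note: the maximum sequence length passed to cut() grows with the epoch,
    # so early epochs train on shorter (and therefore cheaper) sequences.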

    # Make data iterator
    train = E.DataIterator(trainScpFile,
                           loadChunkData,
                           args.batchSize,
                           chunks=5,
                           shuffle=False,
                           otherArgs=(trainUttSpk, trainCmvnState,
                                      trainLabelPdf, trainLabelPho, 'train'))
    print('Generate train dataset. Chunks:{} / Batch size:{}'.format(
        train.chunks, train.batchSize))
    dev = E.DataIterator(devScpFile,
                         loadChunkData,
                         args.batchSize,
                         chunks=1,
                         shuffle=False,
                         otherArgs=(devUttSpk, devCmvnState, devLabelPdf,
                                    devLabelPho, 'dev'))
    print('Generate validation dataset. Chunks:{} / Batch size:{}.'.format(
        dev.chunks, dev.batchSize))
    print("Done.")

    print('Prepare model...')
    featDim = 40
    if args.delta > 0:
        featDim *= (args.delta + 1)
    if args.splice > 0:
        featDim *= (2 * args.splice + 1)
    model = GRU(featDim, trainLabelPdf.target, trainLabelPho.target)
    lossfunc = nn.NLLLoss()
    if args.gpu >= 0:
        model = model.cuda(args.gpu)
        lossfunc = lossfunc.cuda(args.gpu)
    print('Generate model done.')

    print('Prepare optimizer and supporter...')
    #lr = [(0,0.5),(8,0.25),(13,0.125),(15,0.07),(17,0.035),(20,0.02),(23,0.01)]
    lr = [(0, 0.0004)]
    print('Learning Rate:', lr)
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=lr[0][1],
                                    alpha=0.95,
                                    eps=1e-8,
                                    weight_decay=0,
                                    momentum=0,
                                    centered=False)
    lr.pop(0)
    supporter = E.Supporter(args.outDir)
    print('Done.')

    print('Prepare test data...')
    # Fmllr file
    testFilePath = args.TIMITpath + '/data-fmllr-tri3/test/feats.scp'
    testFeat = E.load(testFilePath)
    # Use CMVN
    if args.useCMVN:
        testUttSpk = args.TIMITpath + '/data-fmllr-tri3/test/utt2spk'
        testCmvnState = args.TIMITpath + '/data-fmllr-tri3/test/cmvn.ark'
        testFeat = E.use_cmvn(testFeat, testCmvnState, testUttSpk)
    # Add delta
    if args.delta > 0:
        testFeat = E.add_delta(testFeat, args.delta)
    # Splice frames
    if args.splice > 0:
        testFeat = testFeat.splice(args.splice)
    # Transform to array
    testFeat = testFeat.array
    # Normalize
    if args.normalizeChunk:
        testFeat = testFeat.normalize()
    # Normalize acoustic model output
    if args.normalizeAMP:
        # compute pdf counts in order to normalize acoustic model posterior probability.
        countFile = args.outDir + '/pdfs_counts.txt'
        if not os.path.isfile(countFile):
            _ = E.analyze_counts(aliFile=trainAliFile, outFile=countFile)
        with open(countFile) as f:
            line = f.readline().strip().strip("[]").strip()
        counts = np.array(list(map(float, line.split())), dtype=np.float32)
        normalizeBias = np.log(counts / np.sum(counts))
    else:
        normalizeBias = 0
    print('Done.')

    print('Prepare test data decode and score function...')

    # Design a function to compute WER of test data
    def wer_fun(model, feat, normalizeBias):
        global args
        # Transform the format of the KaldiDict feature data before forwarding it through the network
        temp = E.KaldiDict()
        utts = feat.utts
        with torch.no_grad():
            for index, utt in enumerate(utts):
                data = torch.Tensor(feat[utt][:, np.newaxis, :])
                data = torch.autograd.Variable(data)
                if args.gpu >= 0:
                    data = data.cuda(args.gpu)
                out1, out2 = model(data, is_training=False, device=args.gpu)
                out = out1.cpu().detach().numpy() - normalizeBias
                temp[utt] = out
                print("(testing) Forward network {}/{}".format(
                    index, len(utts)),
                      end=" " * 20 + '\r')
        # Transform KaldiDict to KaldiArk format
        print('(testing) Transform to ark', end=" " * 20 + '\r')
        amp = temp.ark
        # Decode and obtain lattice
        hmm = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_test/final.mdl'
        hclg = args.TIMITpath + '/exp/tri3/graph/HCLG.fst'
        lexicon = args.TIMITpath + '/exp/tri3/graph/words.txt'
        print('(testing) Generate Lattice', end=" " * 20 + '\r')
        lattice = E.decode_lattice(amp, hmm, hclg, lexicon, args.minActive,
                                   args.maxActive, args.maxMemory, args.beam,
                                   args.latBeam, args.acwt)
        # Sweep the language-model weight from minLmwt to maxLmwt and get the 1-best words.
        print('(testing) Get 1-best words', end=" " * 20 + '\r')
        outs = lattice.get_1best(lmwt=args.minLmwt,
                                 maxLmwt=args.maxLmwt,
                                 outFile=args.outDir + '/outRaw')
        # If the reference file does not exist, create it.
        phonemap = args.TIMITpath + '/conf/phones.60-48-39.map'
        outFilter = args.TIMITpath + '/local/timit_norm_trans.pl -i - -m {} -from 48 -to 39'.format(
            phonemap)
        if not os.path.isfile(args.outDir + '/test_filt.txt'):
            refText = args.TIMITpath + '/data/test/text'
            cmd = 'cat {} | {} > {}/test_filt.txt'.format(
                refText, outFilter, args.outDir)
            (_, _) = E.run_shell_cmd(cmd)
        # Score WER and find the smallest one.
        print('(testing) Score', end=" " * 20 + '\r')
        minWER = None
        for k in range(args.minLmwt, args.maxLmwt + 1, 1):
            cmd = 'cat {} | {} > {}/test_prediction_filt.txt'.format(
                outs[k], outFilter, args.outDir)
            (_, _) = E.run_shell_cmd(cmd)
            os.remove(outs[k])
            score = E.wer('{}/test_filt.txt'.format(args.outDir),
                          "{}/test_prediction_filt.txt".format(args.outDir),
                          mode='all')
            if minWER is None or score['WER'] < minWER:
                minWER = score['WER']
        os.remove("{}/test_prediction_filt.txt".format(args.outDir))
        return minWER

    print('Done.')

    print('Now Start to Train')
    for e in range(args.epoch):
        print()
        i = 0
        usedTime = 0
        supporter.send_report({'epoch': e})
        # Train
        model.train()
        while True:
            start = time.time()
            # Get batch data and label
            batch = train.next()
            batch, lengths = E.pad_sequence(batch, shuffle=True, pad=0)
            batch = torch.Tensor(batch)
            data = batch[:, :, 0:-2]
            label1 = batch[:, :, -2]
            label2 = batch[:, :, -1]
            data = torch.autograd.Variable(data)
            label1 = torch.autograd.Variable(label1).view(-1).long()
            label2 = torch.autograd.Variable(label2).view(-1).long()
            # Send to GPU if use
            if args.gpu >= 0:
                data = data.cuda(args.gpu)
                label1 = label1.cuda(args.gpu)
                label2 = label2.cuda(args.gpu)
            # Clear grad
            optimizer.zero_grad()
            # Forward model
            out1, out2 = model(data, is_training=True, device=args.gpu)
            # Loss back propagation
            loss1 = lossfunc(out1, label1)
            loss2 = lossfunc(out2, label2)
            loss = loss1 + loss2
            loss.backward()
            # Update parameter
            optimizer.step()
            # Compute accuracy
            pred = torch.max(out1, dim=1)[1]
            acc = 1 - torch.mean((pred != label1).float())
            # Record train information
            supporter.send_report({
                'train_loss': float(loss),
                'train_acc': float(acc)
            })
            ut = time.time() - start
            usedTime += ut
            batchLoss = float(loss.cpu().detach().numpy())
            print(
                "(training) Epoch:{}/{}% Chunk:{}/{}% Iter:{} Used-time:{}s Batch-loss:{} Speed:{}iters/s"
                .format(e, int(100 * train.epochProgress), train.chunk,
                        int(100 * train.chunkProgress), i, int(usedTime),
                        "%.4f" % (batchLoss), "%.2f" % (1 / ut)),
                end=" " * 5 + '\r')
            i += 1
            # If forwarded all data, break
            if train.isNewEpoch:
                break
        # Evaluation
        model.eval()
        with torch.no_grad():
            while True:
                start = time.time()
                # Get batch data and label
                batch = dev.next()
                batch, lengths = E.pad_sequence(batch, shuffle=True, pad=0)
                maxLen, bSize, _ = batch.shape
                batch = torch.Tensor(batch)
                data = batch[:, :, 0:-2]
                label1 = batch[:, :, -2]
                label2 = batch[:, :, -1]
                data = torch.autograd.Variable(data)
                label1 = torch.autograd.Variable(label1).view(-1).long()
                # Send to GPU if use
                if args.gpu >= 0:
                    data = data.cuda(args.gpu)
                    label1 = label1.cuda(args.gpu)
                # Forward model
                out1, out2 = model(data, is_training=False, device=args.gpu)
                # Compute accuracy of padded label
                pred = torch.max(out1, dim=1)[1]
                acc_pad = 1 - torch.mean((pred != label1).float())
                # Compute accuracy over the unpadded labels. This is the more reliable number.
                label = label1.cpu().numpy().reshape([maxLen, bSize])
                pred = pred.cpu().numpy().reshape([maxLen, bSize])
                label = E.unpack_padded_sequence(label, lengths)
                pred = E.unpack_padded_sequence(pred, lengths)
                acc_nopad = E.accuracy(pred, label)
                # Record evaluation information
                supporter.send_report({
                    'dev_acc_pad': float(acc_pad),
                    'dev_acc_nopad': acc_nopad
                })
                ut = time.time() - start
                usedTime += ut
                batchLoss = float(loss.cpu().detach().numpy())
                print(
                    "(Validating) Epoch:{}/{}% Chunk:{}/{}% Iter:{} Used-time:{}s Batch-loss:{} Speed:{}iters/s"
                    .format(e, int(100 * dev.epochProgress), dev.chunk,
                            int(100 * dev.chunkProgress), i, int(usedTime),
                            "%.4f" % (batchLoss), "%.2f" % (1 / ut)),
                    end=" " * 5 + '\r')
                i += 1
                # If forwarded all data, break
                if dev.isNewEpoch:
                    break
            print()
            # Compute the WER score from the third epoch onwards (e >= 2)
            if e >= 2:
                minWER = wer_fun(model, testFeat, normalizeBias)
                supporter.send_report({'test_wer': minWER})
        # one epoch is over so collect information
        supporter.collect_report(plot=True)

        # Save model
        def saveFunc(archs):
            fileName, model = archs
            torch.save(model.state_dict(), fileName)

        supporter.save_arch(saveFunc, arch={'GRU': model})
        # Change learning rate
        if len(lr) > 0 and supporter.judge('epoch', '>=', lr[0][0]):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr[0][1]
            lr.pop(0)

    print("GRU Acoustic Model training done.")
    print("The final model has been saved as:", supporter.finalArch["GRU"])
    print('Over System Time:', datetime.datetime.now().strftime("%Y-%m-%d %X"))
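# The GRU class instantiated above is likewise not defined in this excerpt. Below
# is a minimal sketch, assuming a PyTorch model with a compatible interface:
# input of shape (maxLen, batch, featDim), two heads of log-probabilities
# flattened to (maxLen * batch, outDim) so that nn.NLLLoss can be applied
# directly, and the (is_training, device) keywords the training loop passes in.
# The hidden size, layer count and bidirectionality are illustrative assumptions.
import torch
import torch.nn as nn

class GRU(nn.Module):
    def __init__(self, featDim, outDimPdf, outDimPho,
                 hiddenNode=512, layers=4, dropout=0.2):
        super(GRU, self).__init__()
        self.gru = nn.GRU(featDim, hiddenNode, num_layers=layers,
                          dropout=dropout, bidirectional=True)
        self.outPdf = nn.Linear(2 * hiddenNode, outDimPdf)
        self.outPho = nn.Linear(2 * hiddenNode, outDimPho)

    def forward(self, x, is_training=True, device=-1):
        # x: (maxLen, batch, featDim). The extra keyword arguments only keep
        # the call signature compatible with the training loop above.
        h, _ = self.gru(x)                   # (maxLen, batch, 2 * hiddenNode)
        h = h.reshape(-1, h.shape[-1])       # (maxLen * batch, 2 * hiddenNode)
        out1 = torch.log_softmax(self.outPdf(h), dim=1)
        out2 = torch.log_softmax(self.outPho(h), dim=1)
        return out1, out2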