def train_model():
    """Train the Chainer MLP acoustic model on TIMIT fMLLR features.

    All settings are read from the module-level ``args`` namespace. The run
    configuration is printed and also written to ``<args.outDir>/configure``.
    Each epoch then:

    * trains on the train split,
    * measures loss/accuracy on the dev split,
    * decodes the test split and keeps the lowest WER over the language
      weights ``args.minLmwt .. args.maxLmwt``,
    * reports everything through ``E.Supporter`` and saves the model.
    """

    global args

    print("\n############## Parameters Configure ##############")

    # Show configure information and write them to file
    def configLog(message, f):
        print(message)
        f.write(message + '\n')

    # `with` guarantees the configure file is closed even if logging fails.
    with open(args.outDir + '/configure', "w") as f:
        configLog(
            'Start System Time:{}'.format(
                datetime.datetime.now().strftime("%Y-%m-%d %X")), f)
        configLog('Host Name:{}'.format(socket.gethostname()), f)
        configLog('Fix Random Seed:{}'.format(args.randomSeed), f)
        configLog('Mini Batch Size:{}'.format(args.batchSize), f)
        configLog('GPU ID:{}'.format(args.gpu), f)
        configLog('Train Epochs:{}'.format(args.epoch), f)
        configLog('Output Folder:{}'.format(args.outDir), f)
        configLog('Dropout Rate:{}'.format(args.dropout), f)
        configLog('Use CMVN:{}'.format(args.useCMVN), f)
        configLog('Splice N Frames:{}'.format(args.splice), f)
        configLog('Add N Deltas:{}'.format(args.delta), f)
        configLog('Normalize Chunk:{}'.format(args.normalizeChunk), f)
        configLog('Normalize AMP:{}'.format(args.normalizeAMP), f)
        configLog('Decode Minimum Active:{}'.format(args.minActive), f)
        configLog('Decode Maximum Active:{}'.format(args.maxActive), f)
        configLog('Decode Maximum Memory:{}'.format(args.maxMemory), f)
        configLog('Decode Beam:{}'.format(args.beam), f)
        configLog('Decode Lattice Beam:{}'.format(args.latBeam), f)
        configLog('Decode Acoustic Weight:{}'.format(args.acwt), f)
        configLog('Decode minimum Language Weight:{}'.format(args.minLmwt), f)
        configLog('Decode maximum Language Weight:{}'.format(args.maxLmwt), f)

    print("\n############## Train DNN Acoustic Model ##############")

    #------------------------ STEP 1: Prepare Training and Validation Data -----------------------------

    print('Prepare Data Iterator...')
    # Prepare fMLLR feature files
    trainScpFile = args.TIMITpath + '/data-fmllr-tri3/train/feats.scp'
    devScpFile = args.TIMITpath + '/data-fmllr-tri3/dev/feats.scp'
    # Prepare training labels (alignment data): pdf ids and phone ids
    trainAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali/ali.*.gz'
    trainLabelPdf = E.get_ali(trainAliFile)
    trainLabelPho = E.get_ali(trainAliFile, returnPhone=True)
    # Shift phone ids down by one (presumably so they become 0-based class
    # indices, Kaldi phone id 0 being <eps> -- confirm against phones.txt).
    for i in trainLabelPho.keys():
        trainLabelPho[i] = trainLabelPho[i] - 1
    # Prepare validation labels (alignment data)
    devAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_dev/ali.*.gz'
    devLabelPdf = E.get_ali(devAliFile)
    devLabelPho = E.get_ali(devAliFile, returnPhone=True)
    for i in devLabelPho.keys():
        devLabelPho[i] = devLabelPho[i] - 1
    # Prepare CMVN files
    trainUttSpk = args.TIMITpath + '/data-fmllr-tri3/train/utt2spk'
    trainCmvnState = args.TIMITpath + '/data-fmllr-tri3/train/cmvn.ark'
    devUttSpk = args.TIMITpath + '/data-fmllr-tri3/dev/utt2spk'
    devCmvnState = args.TIMITpath + '/data-fmllr-tri3/dev/cmvn.ark'

    # Now we try to make training-iterator and validation-iterator.
    # Firstly, customize a function to process feature data.
    def loadChunkData(iterator, feat, otherArgs):
        """Preprocess one chunk of features for ``E.DataIterator``.

        Applies CMVN / deltas / splicing / normalization according to
        ``args``. It then appends the pdf and phone labels as the last two
        columns and merges all utterances into one trainable array.
        ``otherArgs`` is ``(uttSpk, cmvnState, labelPdf, labelPho)``.
        """
        # <feat> is a KaldiArk object
        global args
        uttSpk, cmvnState, labelPdf, labelPho = otherArgs
        # use CMVN
        if args.useCMVN:
            feat = E.use_cmvn(feat, cmvnState, uttSpk)
        # Add delta
        if args.delta > 0:
            feat = E.add_delta(feat, args.delta)
        # Splice front-back n frames
        if args.splice > 0:
            feat = feat.splice(args.splice)
        # Transform to KaldiDict
        feat = feat.array
        # Normalize
        if args.normalizeChunk:
            feat = feat.normalize()
        # Concatenate data-label in dimension: [..features.., pdf, phone]
        datas = feat.concat([labelPdf, labelPho], axis=1)
        # Transform to trainable numpy data
        datas, _ = datas.merge(keepDim=False, sort=False)
        return datas

    # Then get data iterators
    train = E.DataIterator(trainScpFile,
                           loadChunkData,
                           args.batchSize,
                           chunks=5,
                           shuffle=True,
                           otherArgs=(trainUttSpk, trainCmvnState,
                                      trainLabelPdf, trainLabelPho))
    print('Generate train dataset done. Chunks:{} / Batch size:{}'.format(
        train.chunks, train.batchSize))
    dev = E.DataIterator(devScpFile,
                         loadChunkData,
                         args.batchSize,
                         chunks=1,
                         shuffle=False,
                         otherArgs=(devUttSpk, devCmvnState, devLabelPdf,
                                    devLabelPho))
    print(
        'Generate validation dataset done. Chunks:{} / Batch size:{}.'.format(
            dev.chunks, dev.batchSize))

    #--------------------------------- STEP 2: Prepare Model --------------------------

    print('Prepare Model...')
    # Initialize model. The base feature dimension is hard-coded to 40 and
    # grows with every delta order and every spliced frame.
    featDim = 40
    if args.delta > 0:
        featDim *= (args.delta + 1)
    if args.splice > 0:
        featDim *= (2 * args.splice + 1)
    model = MLP(featDim, trainLabelPdf.target, trainLabelPho.target)
    if args.gpu >= 0:
        model.to_gpu(args.gpu)
    # Initialize optimizer. `lr` is a list of (epoch, newLR) pairs: the first
    # entry is the initial rate, the rest are applied in the training loop.
    lr = [(0, 0.08), (10, 0.04), (15, 0.02), (17, 0.01), (19, 0.005),
          (22, 0.0025), (25, 0.001)]
    print('Learning Rate (epoch,newLR):', lr)
    optimizer = chainer.optimizers.MomentumSGD(lr[0][1], momentum=0.0)
    lr.pop(0)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(0.0))
    # Prepare a supporter to help handling training information.
    supporter = E.Supporter(args.outDir)

    #------------------ STEP 3: Prepare Decoding Test Data and Function ------------------

    print('Prepare decoding test data...')
    # fMLLR file of test data
    testFilePath = args.TIMITpath + '/data-fmllr-tri3/test/feats.scp'
    testFeat = E.load(testFilePath)
    # Using CMVN
    if args.useCMVN:
        testUttSpk = args.TIMITpath + '/data-fmllr-tri3/test/utt2spk'
        testCmvnState = args.TIMITpath + '/data-fmllr-tri3/test/cmvn.ark'
        testFeat = E.use_cmvn(testFeat, testCmvnState, testUttSpk)
    # Add delta
    if args.delta > 0:
        testFeat = E.add_delta(testFeat, args.delta)
    # Splice frames
    if args.splice > 0:
        testFeat = testFeat.splice(args.splice)
    # Transform to array
    testFeat = testFeat.array
    # Normalize
    if args.normalizeChunk:
        testFeat = testFeat.normalize()
    # Normalize acoustic model output
    if args.normalizeAMP:
        # Compute pdf counts in order to normalize acoustic model posterior probability.
        countFile = args.outDir + '/pdfs_counts.txt'
        # Get statistics file (computed once, then reused)
        if not os.path.isfile(countFile):
            _ = E.analyze_counts(aliFile=trainAliFile, outFile=countFile)
        with open(countFile) as f:
            line = f.readline().strip().strip("[]").strip()
        # Get AMP bias value: log prior of every pdf, subtracted from the
        # network's log-posteriors before decoding.
        counts = np.array(list(map(float, line.split())), dtype=np.float32)
        normalizeBias = np.log(counts / np.sum(counts))
    else:
        normalizeBias = 0

    # Now, design a function to compute WER score
    def wer_fun(model, testFeat, normalizeBias):
        """Decode ``testFeat`` with ``model`` and return the lowest WER.

        The lowest WER is taken over the language weights
        ``args.minLmwt .. args.maxLmwt`` (inclusive). Returns None if that
        range is empty.
        """
        global args
        # Use decode test data to forward network
        temp = E.KaldiDict()
        print('(testing) Forward network', end=" " * 20 + '\r')
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            for utt in testFeat.keys():
                data = cp.array(testFeat[utt], dtype=cp.float32)
                out1, out2 = model(data)
                out = F.log_softmax(out1, axis=1)
                out.to_cpu()
                temp[utt] = out.array - normalizeBias
        # Transform KaldiDict to KaldiArk format
        print('(testing) Transform to ark', end=" " * 20 + '\r')
        amp = temp.ark
        # Decoding to obtain a lattice
        hmm = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_test/final.mdl'
        hclg = args.TIMITpath + '/exp/tri3/graph/HCLG.fst'
        lexicon = args.TIMITpath + '/exp/tri3/graph/words.txt'
        print('(testing) Generate Lattice', end=" " * 20 + '\r')
        lattice = E.decode_lattice(amp, hmm, hclg, lexicon, args.minActive,
                                   args.maxActive, args.maxMemory, args.beam,
                                   args.latBeam, args.acwt)
        # Sweep the language weight over [minLmwt, maxLmwt] and get the 1-best words.
        print('(testing) Get 1-best words', end=" " * 20 + '\r')
        outs = lattice.get_1best(lmwt=args.minLmwt,
                                 maxLmwt=args.maxLmwt,
                                 outFile=args.outDir + '/outRaw.txt')
        # If the reference file does not exist yet, make it.
        phonemap = args.TIMITpath + '/conf/phones.60-48-39.map'
        outFilter = args.TIMITpath + '/local/timit_norm_trans.pl -i - -m {} -from 48 -to 39'.format(
            phonemap)
        if not os.path.isfile(args.outDir + '/test_filt.txt'):
            refText = args.TIMITpath + '/data/test/text'
            cmd = 'cat {} | {} > {}/test_filt.txt'.format(
                refText, outFilter, args.outDir)
            (_, _) = E.run_shell_cmd(cmd)
        # Score WER and find the smallest one.
        print('(testing) Score', end=" " * 20 + '\r')
        minWER = None
        for k in range(args.minLmwt, args.maxLmwt + 1, 1):
            cmd = 'cat {} | {} > {}/test_prediction_filt.txt'.format(
                outs[k], outFilter, args.outDir)
            (_, _) = E.run_shell_cmd(cmd)
            os.remove(outs[k])
            score = E.wer('{}/test_filt.txt'.format(args.outDir),
                          "{}/test_prediction_filt.txt".format(args.outDir),
                          mode='all')
            if minWER is None or score['WER'] < minWER:
                minWER = score['WER']
        return minWER

    #-------------------------- STEP 4: Train Model ---------------------------

    # During the first epoch the epoch size is computed gradually, so the progress information will be inaccurate.
    print('Now Start to Train')
    print(
        'Note that: The first epoch will be doing the statistics of total data size gradually.'
    )
    print(
        'Note that: We will evaluate the WER of test dataset after epoch which will cost a few seconds.'
    )

    # Preprocess batch data obtained from the data iterator
    def convert(batch):
        """Split a merged batch into (features, pdf labels, phone labels).

        The last two columns are the labels appended by ``loadChunkData``.
        """
        batch = cp.array(batch, dtype=cp.float32)
        data = batch[:, 0:-2]
        label1 = cp.array(batch[:, -2], dtype=cp.int32)
        label2 = cp.array(batch[:, -1], dtype=cp.int32)
        return data, label1, label2

    # We will save model during training loop, so prepare a model-save function
    def saveFunc(fileName, model):
        """Save a CPU copy of ``model`` to ``fileName`` in NPZ format."""
        global args
        copymodel = model.copy()
        if args.gpu >= 0:
            copymodel.to_cpu()
        chainer.serializers.save_npz(fileName, copymodel)

    # Start training loop
    for e in range(args.epoch):
        supporter.send_report({'epoch': e})
        print()
        i = 1
        usedTime = 0
        # Train
        while True:
            start = time.time()
            # Get data >> Forward network >> Loss back propagation >> Update
            batch = train.next()
            data, label1, label2 = convert(batch)
            # Chainer's configuration key is lowercase 'train'; the former
            # 'Train' set an unrelated attribute that no link reads.
            with chainer.using_config('train', True):
                h1, h2 = model(data)
                L1 = F.softmax_cross_entropy(h1, label1)
                L2 = F.softmax_cross_entropy(h2, label2)
                loss = L1 + L2
                acc = F.accuracy(F.softmax(h1, axis=1), label1)
            model.cleargrads()
            loss.backward()
            optimizer.update()
            # Compute time cost
            ut = time.time() - start
            usedTime += ut
            # Guard against a zero interval on coarse-resolution clocks.
            speed = 1 / ut if ut > 0 else float('inf')
            print(
                "(training) Epoch:{}>>>{}% Chunk:{}>>>{}% Iter:{} Used-time:{}s Batch-loss:{} Speed:{}iters/s"
                .format(e, int(100 * train.epochProgress), train.chunk,
                        int(100 * train.chunkProgress), i, int(usedTime),
                        "%.4f" % (float(loss.array)), "%.2f" % speed),
                end=" " * 5 + '\r')
            i += 1
            supporter.send_report({'train_loss': loss, 'train_acc': acc})
            # If forward all data, break
            if train.isNewEpoch:
                break
        print()
        i = 1
        usedTime = 0
        # Validate
        while True:
            start = time.time()
            # Get data >> Forward network >> Score
            batch = dev.next()
            data, label1, label2 = convert(batch)
            with chainer.using_config('train',
                                      False), chainer.no_backprop_mode():
                h1, h2 = model(data)
                loss = F.softmax_cross_entropy(h1, label1)
                acc = F.accuracy(F.softmax(h1, axis=1), label1)
            # Compute time cost
            ut = time.time() - start
            usedTime += ut
            speed = 1 / ut if ut > 0 else float('inf')
            print(
                "(Validating) Epoch:{}>>>{}% Chunk:{}>>>{}% Iter:{} Used-time:{}s Batch-loss:{} Speed:{}iters/s"
                .format(e, int(100 * dev.epochProgress), dev.chunk,
                        int(100 * dev.chunkProgress), i, int(usedTime),
                        "%.4f" % (float(loss.array)), "%.2f" % speed),
                end=" " * 5 + '\r')
            i += 1
            supporter.send_report({'dev_loss': loss, 'dev_acc': acc})
            # If forward all data, break
            if dev.isNewEpoch:
                break
        print()
        # Compute WER score
        WERscore = wer_fun(model, testFeat, normalizeBias)
        supporter.send_report({'test_wer': WERscore, 'lr': optimizer.lr})
        # Collect all information of this epoch that is reported before, and show them at display
        supporter.collect_report(plot=True)
        # Save model
        supporter.save_model(saveFunc, models={'MLP': model})
        # Change learning rate once the next scheduled epoch is reached
        if len(lr) > 0 and supporter.judge('epoch', '>=', lr[0][0]):
            optimizer.lr = lr[0][1]
            lr.pop(0)

    print("DNN Acoustic Model training done.")
    print("The final model has been saved as:", supporter.finalModel)
    print('Over System Time:', datetime.datetime.now().strftime("%Y-%m-%d %X"))
def decode_test(outDimPdf=1968, outDimPho=48):
    """Decode the TIMIT test set with a pretrained MLP and report the best WER.

    All settings are read from the module-level ``args`` namespace.
    ``args.preModel`` must name an existing file that
    ``chainer.serializers.load_npz`` can load into ``MLP``.

    Args:
        outDimPdf: Output size of the model's pdf head; must match the
            pretrained checkpoint (presumably the number of pdfs of the
            alignment model -- confirm against the checkpoint).
        outDimPho: Output size of the model's phone head; must match the
            pretrained checkpoint.

    Raises:
        Exception: if ``args.preModel`` is empty or is not an existing file.
    """

    global args

    if args.preModel == '':
        raise Exception("Expected Pretrained Model.")
    elif not os.path.isfile(args.preModel):
        raise Exception("No such file:{}.".format(args.preModel))

    print("\n############## Parameters Configure ##############")

    # Show configure information and write them to file
    def configLog(message, f):
        print(message)
        f.write(message + '\n')

    # `with` guarantees the configure file is closed even if logging fails.
    with open(args.outDir + '/configure', "w") as f:
        configLog(
            'Start System Time:{}'.format(
                datetime.datetime.now().strftime("%Y-%m-%d %X")), f)
        configLog('Host Name:{}'.format(socket.gethostname()), f)
        configLog('Fix Random Seed:{}'.format(args.randomSeed), f)
        configLog('GPU ID:{}'.format(args.gpu), f)
        configLog('Pretrained Model:{}'.format(args.preModel), f)
        configLog('Output Folder:{}'.format(args.outDir), f)
        configLog('Use CMVN:{}'.format(args.useCMVN), f)
        configLog('Splice N Frames:{}'.format(args.splice), f)
        configLog('Add N Deltas:{}'.format(args.delta), f)
        configLog('Normalize Chunk:{}'.format(args.normalizeChunk), f)
        configLog('Normalize AMP:{}'.format(args.normalizeAMP), f)
        configLog('Decode Minimum Active:{}'.format(args.minActive), f)
        configLog('Decode Maximum Active:{}'.format(args.maxActive), f)
        configLog('Decode Maximum Memory:{}'.format(args.maxMemory), f)
        configLog('Decode Beam:{}'.format(args.beam), f)
        configLog('Decode Lattice Beam:{}'.format(args.latBeam), f)
        configLog('Decode Acoustic Weight:{}'.format(args.acwt), f)
        configLog('Decode minimum Language Weight:{}'.format(args.minLmwt), f)
        configLog('Decode maximum Language Weight:{}'.format(args.maxLmwt), f)

    print("\n############## Decode Test ##############")

    #------------------ STEP 1: Load Pretrained Model ------------------

    print('Load Model...')
    # Initialize model. The base feature dimension is hard-coded to 40 and
    # grows with every delta order and every spliced frame; it must match the
    # setting used when the checkpoint was trained.
    featDim = 40
    if args.delta > 0:
        featDim *= (args.delta + 1)
    if args.splice > 0:
        featDim *= (2 * args.splice + 1)
    model = MLP(featDim, outDimPdf, outDimPho)
    chainer.serializers.load_npz(args.preModel, model)
    if args.gpu >= 0:
        model.to_gpu(args.gpu)

    #------------------ STEP 2: Prepare Test Data ------------------

    print('Prepare decode test data...')
    # Fmllr file
    testFilePath = args.TIMITpath + '/data-fmllr-tri3/test/feats.scp'
    testFeat = E.load(testFilePath)
    # Use CMVN
    if args.useCMVN:
        testUttSpk = args.TIMITpath + '/data-fmllr-tri3/test/utt2spk'
        testCmvnState = args.TIMITpath + '/data-fmllr-tri3/test/cmvn.ark'
        testFeat = E.use_cmvn(testFeat, testCmvnState, testUttSpk)
    # Add delta
    if args.delta > 0:
        testFeat = E.add_delta(testFeat, args.delta)
    # Splice frames
    if args.splice > 0:
        testFeat = testFeat.splice(args.splice)
    # Transform to array
    testFeat = testFeat.array
    # Normalize
    if args.normalizeChunk:
        testFeat = testFeat.normalize()
    # Normalize acoustic model output
    if args.normalizeAMP:
        # Compute pdf counts in order to normalize acoustic model posterior probability.
        countFile = args.outDir + '/pdfs_counts.txt'
        # Get statistics file (computed once, then reused)
        if not os.path.isfile(countFile):
            trainAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali/ali.*.gz'
            _ = E.analyze_counts(aliFile=trainAliFile, outFile=countFile)
        with open(countFile) as f:
            line = f.readline().strip().strip("[]").strip()
        # Get AMP bias value: log prior of every pdf, subtracted from the
        # network's log-posteriors before decoding.
        counts = np.array(list(map(float, line.split())), dtype=np.float32)
        normalizeBias = np.log(counts / np.sum(counts))
    else:
        normalizeBias = 0

    #------------------ STEP 3: Decode  ------------------

    temp = E.KaldiDict()
    print('Compute Test WER: Forward network', end=" " * 20 + '\r')
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        for utt in testFeat.keys():
            data = cp.array(testFeat[utt], dtype=cp.float32)
            out1, out2 = model(data)
            out = F.log_softmax(out1, axis=1)
            out.to_cpu()
            temp[utt] = out.array - normalizeBias
    # Transform KaldiDict to KaldiArk format
    print('Compute Test WER: Transform to ark', end=" " * 20 + '\r')
    amp = temp.ark
    # Decode and obtain lattice
    hmm = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_test/final.mdl'
    hclg = args.TIMITpath + '/exp/tri3/graph/HCLG.fst'
    lexicon = args.TIMITpath + '/exp/tri3/graph/words.txt'
    print('Compute Test WER: Generate Lattice', end=" " * 20 + '\r')
    lattice = E.decode_lattice(amp, hmm, hclg, lexicon, args.minActive,
                               args.maxActive, args.maxMemory, args.beam,
                               args.latBeam, args.acwt)
    # Sweep the language weight over [minLmwt, maxLmwt] and get the 1-best words.
    print('Compute Test WER: Get 1Best', end=" " * 20 + '\r')
    outs = lattice.get_1best(lmwt=args.minLmwt,
                             maxLmwt=args.maxLmwt,
                             outFile=args.outDir + '/outRaw.txt')

    #------------------ STEP 4: Score  ------------------

    # If the reference file does not exist yet, make it.
    phonemap = args.TIMITpath + '/conf/phones.60-48-39.map'
    outFilter = args.TIMITpath + '/local/timit_norm_trans.pl -i - -m {} -from 48 -to 39'.format(
        phonemap)
    if not os.path.isfile(args.outDir + '/test_filt.txt'):
        refText = args.TIMITpath + '/data/test/text'
        cmd = 'cat {} | {} > {}/test_filt.txt'.format(refText, outFilter,
                                                      args.outDir)
        (_, _) = E.run_shell_cmd(cmd)
    # Score WER for every language weight and keep (bestWER, bestLmwt).
    # NOTE: the 'tanslation' spelling is kept so existing output paths do not change.
    print('Compute Test WER: compute WER', end=" " * 20 + '\r')
    minWER = (None, None)
    for k in range(args.minLmwt, args.maxLmwt + 1, 1):
        cmd = 'cat {} | {} > {}/tanslation_{}.txt'.format(
            outs[k], outFilter, args.outDir, k)
        (_, _) = E.run_shell_cmd(cmd)
        os.remove(outs[k])
        score = E.wer('{}/test_filt.txt'.format(args.outDir),
                      "{}/tanslation_{}.txt".format(args.outDir, k),
                      mode='all')
        if minWER[0] is None or score['WER'] < minWER[0]:
            minWER = (score['WER'], k)

    # Report the language weight that achieved the best WER; the loop variable
    # `k` is always the last weight tried, not the best one.
    print("Best WER:{}% at {}/tanslation_{}.txt".format(
        minWER[0], args.outDir, minWER[1]))
def train_model():

    global args

    print("\n############## Parameters Configure ##############")

    # Show configure information and write them to file
    def configLog(message, f):
        print(message)
        f.write(message + '\n')

    f = open(args.outDir + '/configure', "w")
    configLog(
        'Start System Time:{}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %X")), f)
    configLog('Host Name:{}'.format(socket.gethostname()), f)
    configLog('Fix Random Seed:{}'.format(args.randomSeed), f)
    configLog('Mini Batch Size:{}'.format(args.batchSize), f)
    configLog('GPU ID:{}'.format(args.gpu), f)
    configLog('Train Epochs:{}'.format(args.epoch), f)
    configLog('Output Folder:{}'.format(args.outDir), f)
    configLog('GRU layers:{}'.format(args.layer), f)
    configLog('GRU hidden nodes:{}'.format(args.hiddenNode), f)
    configLog('GRU dropout:{}'.format(args.dropout), f)
    configLog('Use CMVN:{}'.format(args.useCMVN), f)
    configLog('Splice N Frames:{}'.format(args.splice), f)
    configLog('Add N Deltas:{}'.format(args.delta), f)
    configLog('Normalize Chunk:{}'.format(args.normalizeChunk), f)
    configLog('Normalize AMP:{}'.format(args.normalizeAMP), f)
    configLog('Decode Minimum Active:{}'.format(args.minActive), f)
    configLog('Decode Maximum Active:{}'.format(args.maxActive), f)
    configLog('Decode Maximum Memory:{}'.format(args.maxMemory), f)
    configLog('Decode Beam:{}'.format(args.beam), f)
    configLog('Decode Lattice Beam:{}'.format(args.latBeam), f)
    configLog('Decode Acoustic Weight:{}'.format(args.acwt), f)
    configLog('Decode minimum Language Weight:{}'.format(args.minLmwt), f)
    configLog('Decode maximum Language Weight:{}'.format(args.maxLmwt), f)
    f.close()

    print("\n############## Train GRU Acoustic Model ##############")

    #----------------- STEP 1: Prepare Train Data -----------------
    print('Prepare data iterator...')
    # Fmllr data file
    trainScpFile = args.TIMITpath + '/data-fmllr-tri3/train/feats.scp'
    devScpFile = args.TIMITpath + '/data-fmllr-tri3/dev/feats.scp'
    # Alignment label file
    trainAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali/ali.*.gz'
    devAliFile = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_dev/ali.*.gz'
    # Load label
    trainLabelPdf = E.load_ali(trainAliFile)
    trainLabelPho = E.load_ali(trainAliFile, returnPhone=True)
    for i in trainLabelPho.keys():
        trainLabelPho[i] = trainLabelPho[i] - 1
    devLabelPdf = E.load_ali(devAliFile)
    devLabelPho = E.load_ali(devAliFile, returnPhone=True)
    for i in devLabelPho.keys():
        devLabelPho[i] = devLabelPho[i] - 1
    # CMVN file
    trainUttSpk = args.TIMITpath + '/data-fmllr-tri3/train/utt2spk'
    trainCmvnState = args.TIMITpath + '/data-fmllr-tri3/train/cmvn.ark'
    devUttSpk = args.TIMITpath + '/data-fmllr-tri3/dev/utt2spk'
    devCmvnState = args.TIMITpath + '/data-fmllr-tri3/dev/cmvn.ark'

    # Design a process function
    def loadChunkData(iterator, feat, otherArgs):
        # <feat> is KaldiArk object
        global args
        uttSpk, cmvnState, labelPdf, labelPho, toDo = otherArgs
        # use CMVN
        if args.useCMVN:
            feat = E.use_cmvn(feat, cmvnState, uttSpk)
        # Add delta
        if args.delta > 0:
            feat = E.add_delta(feat, args.delta)
        # Splice front-back n frames
        if args.splice > 0:
            feat = feat.splice(args.splice)
        # Transform to KaldiDict and sort them by frame length
        feat = feat.array.sort(by='frame')
        # Normalize
        if args.normalizeChunk:
            feat = feat.normalize()
        # Concatenate label
        datas = feat.concat([labelPdf, labelPho], axis=1)
        # cut frames
        if toDo == 'train':
            if iterator.epoch >= 4:
                datas = datas.cut(1000)
            elif iterator.epoch >= 3:
                datas = datas.cut(800)
            elif iterator.epoch >= 2:
                datas = datas.cut(400)
            elif iterator.epoch >= 1:
                datas = datas.cut(200)
            elif iterator.epoch >= 0:
                datas = datas.cut(100)
        # Transform trainable numpy data
        datas, _ = datas.merge(keepDim=True, sortFrame=True)
        return datas

    # Make data iterator
    train = E.DataIterator(trainScpFile,
                           loadChunkData,
                           args.batchSize,
                           chunks=5,
                           shuffle=False,
                           otherArgs=(trainUttSpk, trainCmvnState,
                                      trainLabelPdf, trainLabelPho, 'train'))
    print('Generate train dataset. Chunks:{} / Batch size:{}'.format(
        train.chunks, train.batchSize))
    dev = E.DataIterator(devScpFile,
                         loadChunkData,
                         args.batchSize,
                         chunks=1,
                         shuffle=False,
                         otherArgs=(devUttSpk, devCmvnState, devLabelPdf,
                                    devLabelPho, 'dev'))
    print('Generate validation dataset. Chunks:{} / Batch size:{}.'.format(
        dev.chunks, dev.batchSize))
    print("Done.")

    print('Prepare model...')
    featDim = 40
    if args.delta > 0:
        featDim *= (args.delta + 1)
    if args.splice > 0:
        featDim *= (2 * args.splice + 1)
    model = GRU(featDim, trainLabelPdf.target, trainLabelPho.target)
    lossfunc = nn.NLLLoss()
    if args.gpu >= 0:
        model = model.cuda(args.gpu)
        lossfunc = lossfunc.cuda(args.gpu)
    print('Generate model done.')

    print('Prepare optimizer and supporter...')
    #lr = [(0,0.5),(8,0.25),(13,0.125),(15,0.07),(17,0.035),(20,0.02),(23,0.01)]
    lr = [(0, 0.0004)]
    print('Learning Rate:', lr)
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=lr[0][1],
                                    alpha=0.95,
                                    eps=1e-8,
                                    weight_decay=0,
                                    momentum=0,
                                    centered=False)
    lr.pop(0)
    supporter = E.Supporter(args.outDir)
    print('Done.')

    print('Prepare test data...')
    # Fmllr file
    testFilePath = args.TIMITpath + '/data-fmllr-tri3/test/feats.scp'
    testFeat = E.load(testFilePath)
    # Use CMVN
    if args.useCMVN:
        testUttSpk = args.TIMITpath + '/data-fmllr-tri3/test/utt2spk'
        testCmvnState = args.TIMITpath + '/data-fmllr-tri3/test/cmvn.ark'
        testFeat = E.use_cmvn(testFeat, testCmvnState, testUttSpk)
    # Add delta
    if args.delta > 0:
        testFeat = E.add_delta(testFeat, args.delta)
    # Splice frames
    if args.splice > 0:
        testFeat = testFeat.splice(args.splice)
    # Transform to array
    testFeat = testFeat.array
    # Normalize
    if args.normalizeChunk:
        testFeat = testFeat.normalize()
    # Normalize acoustic model output
    if args.normalizeAMP:
        # compute pdf counts in order to normalize acoustic model posterior probability.
        countFile = args.outDir + '/pdfs_counts.txt'
        if not os.path.isfile(countFile):
            _ = E.analyze_counts(aliFile=trainAliFile, outFile=countFile)
        with open(countFile) as f:
            line = f.readline().strip().strip("[]").strip()
        counts = np.array(list(map(float, line.split())), dtype=np.float32)
        normalizeBias = np.log(counts / np.sum(counts))
    else:
        normalizeBias = 0
    print('Done.')

    print('Prepare test data decode and score function...')

    # Design a function to compute WER of test data
    def wer_fun(model, feat, normalizeBias):
        global args
        # Tranform the formate of KaldiDict feature data in order to forward network
        temp = E.KaldiDict()
        utts = feat.utts
        with torch.no_grad():
            for index, utt in enumerate(utts):
                data = torch.Tensor(feat[utt][:, np.newaxis, :])
                data = torch.autograd.Variable(data)
                if args.gpu >= 0:
                    data = data.cuda(args.gpu)
                out1, out2 = model(data, is_training=False, device=args.gpu)
                out = out1.cpu().detach().numpy() - normalizeBias
                temp[utt] = out
                print("(testing) Forward network {}/{}".format(
                    index, len(utts)),
                      end=" " * 20 + '\r')
        # Tansform KaldiDict to KaldiArk format
        print('(testing) Transform to ark', end=" " * 20 + '\r')
        amp = temp.ark
        # Decode and obtain lattice
        hmm = args.TIMITpath + '/exp/dnn4_pretrain-dbn_dnn_ali_test/final.mdl'
        hclg = args.TIMITpath + '/exp/tri3/graph/HCLG.fst'
        lexicon = args.TIMITpath + '/exp/tri3/graph/words.txt'
        print('(testing) Generate Lattice', end=" " * 20 + '\r')
        lattice = E.decode_lattice(amp, hmm, hclg, lexicon, args.minActive,
                                   args.maxActive, args.maxMemory, args.beam,
                                   args.latBeam, args.acwt)
        # Change language weight from 1 to 10, get the 1best words.
        print('(testing) Get 1-best words', end=" " * 20 + '\r')
        outs = lattice.get_1best(lmwt=args.minLmwt,
                                 maxLmwt=args.maxLmwt,
                                 outFile=args.outDir + '/outRaw')
        # If reference file is not existed, make it.
        phonemap = args.TIMITpath + '/conf/phones.60-48-39.map'
        outFilter = args.TIMITpath + '/local/timit_norm_trans.pl -i - -m {} -from 48 -to 39'.format(
            phonemap)
        if not os.path.isfile(args.outDir + '/test_filt.txt'):
            refText = args.TIMITpath + '/data/test/text'
            cmd = 'cat {} | {} > {}/test_filt.txt'.format(
                refText, outFilter, args.outDir)
            (_, _) = E.run_shell_cmd(cmd)
        # Score WER and find the smallest one.
        print('(testing) Score', end=" " * 20 + '\r')
        minWER = None
        for k in range(args.minLmwt, args.maxLmwt + 1, 1):
            cmd = 'cat {} | {} > {}/test_prediction_filt.txt'.format(
                outs[k], outFilter, args.outDir)
            (_, _) = E.run_shell_cmd(cmd)
            os.remove(outs[k])
            score = E.wer('{}/test_filt.txt'.format(args.outDir),
                          "{}/test_prediction_filt.txt".format(args.outDir),
                          mode='all')
            if minWER == None or score['WER'] < minWER:
                minWER = score['WER']
        os.remove("{}/test_prediction_filt.txt".format(args.outDir))
        return minWER

    print('Done.')

    print('Now Start to Train')
    for e in range(args.epoch):
        print()
        i = 0
        usedTime = 0
        supporter.send_report({'epoch': e})
        # Train
        model.train()
        while True:
            start = time.time()
            # Get batch data and label
            batch = train.next()
            batch, lengths = E.pad_sequence(batch, shuffle=True, pad=0)
            batch = torch.Tensor(batch)
            data, label1, label2 = batch[:, :, 0:-2], batch[:, :,
                                                            -2], batch[:, :,
                                                                       -1]
            data = torch.autograd.Variable(data)
            label1 = torch.autograd.Variable(label1).view(-1).long()
            label2 = torch.autograd.Variable(label2).view(-1).long()
            # Send to GPU if use
            if args.gpu >= 0:
                data = data.cuda(args.gpu)
                label1 = label1.cuda(args.gpu)
                label2 = label2.cuda(args.gpu)
            # Clear grad
            optimizer.zero_grad()
            # Forward model
            out1, out2 = model(data, is_training=True, device=args.gpu)
            # Loss back propagation
            loss1 = lossfunc(out1, label1)
            loss2 = lossfunc(out2, label2)
            loss = loss1 + loss2
            loss.backward()
            # Update parameter
            optimizer.step()
            # Compute accuracy
            pred = torch.max(out1, dim=1)[1]
            acc = 1 - torch.mean((pred != label1).float())
            # Record train information
            supporter.send_report({
                'train_loss': float(loss),
                'train_acc': float(acc)
            })
            ut = time.time() - start
            usedTime += ut
            batchLoss = float(loss.cpu().detach().numpy())
            print(
                "(training) Epoch:{}/{}% Chunk:{}/{}% Iter:{} Used-time:{}s Batch-loss:{} Speed:{}iters/s"
                .format(e, int(100 * train.epochProgress), train.chunk,
                        int(100 * train.chunkProgress), i, int(usedTime),
                        "%.4f" % (batchLoss), "%.2f" % (1 / ut)),
                end=" " * 5 + '\r')
            i += 1
            # If forwarded all data, break
            if train.isNewEpoch:
                break
        # Evaluation
        model.eval()
        with torch.no_grad():
            while True:
                start = time.time()
                # Get batch data and label
                batch = dev.next()
                batch, lengths = E.pad_sequence(batch, shuffle=True, pad=0)
                maxLen, bSize, _ = batch.shape
                batch = torch.Tensor(batch)
                data, label1, label2 = batch[:, :,
                                             0:-2], batch[:, :,
                                                          -2], batch[:, :, -1]
                data = torch.autograd.Variable(data)
                label1 = torch.autograd.Variable(label1).view(-1).long()
                # Send to GPU if use
                if args.gpu >= 0:
                    data = data.cuda(args.gpu)
                    label1 = label1.cuda(args.gpu)
                # Forward model
                out1, out2 = model(data, is_training=False, device=args.gpu)
                # Compute accuracy of padded label
                pred = torch.max(out1, dim=1)[1]
                acc_pad = 1 - torch.mean((pred != label1).float())
                # Compute accuracy of not padded label. This should be more correct.
                label = label1.cpu().numpy().reshape([maxLen, bSize])
                pred = pred.cpu().numpy().reshape([maxLen, bSize])
                label = E.unpack_padded_sequence(label, lengths)
                pred = E.unpack_padded_sequence(pred, lengths)
                acc_nopad = E.accuracy(pred, label)
                # Record evaluation information
                supporter.send_report({
                    'dev_acc_pad': float(acc_pad),
                    'dev_acc_nopad': acc_nopad
                })
                ut = time.time() - start
                usedTime += ut
                batchLoss = float(loss.cpu().detach().numpy())
                print(
                    "(Validating) Epoch:{}/{}% Chunk:{}/{}% Iter:{} Used-time:{}s Batch-loss:{} Speed:{}iters/s"
                    .format(e, int(100 * dev.epochProgress), dev.chunk,
                            int(100 * dev.chunkProgress), i, int(usedTime),
                            "%.4f" % (batchLoss), "%.2f" % (1 / ut)),
                    end=" " * 5 + '\r')
                i += 1
                # If forwarded all data, break
                if dev.isNewEpoch:
                    break
            print()
            # We compute WER score from 4th epoch
            if e >= 2:
                minWER = wer_fun(model, testFeat, normalizeBias)
                supporter.send_report({'test_wer': minWER})
        # one epoch is over so collect information
        supporter.collect_report(plot=True)

        # Save model
        def saveFunc(archs):
            """Write one model's parameters to disk with ``torch.save``.

            Args:
                archs: a two-item ``(fileName, model)`` sequence, presumably
                    supplied by ``supporter.save_arch``. Only the model's
                    ``state_dict()`` is serialized, not the module object.
            """
            # Unpack rather than index, so a malformed ``archs`` still fails loudly.
            target_path, net = archs
            weights = net.state_dict()
            torch.save(weights, target_path)

        supporter.save_arch(saveFunc, arch={'GRU': model})
        # Change learning rate
        if len(lr) > 0 and supporter.judge('epoch', '>=', lr[0][0]):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr[0][1]
            lr.pop(0)

    print("GRU Acoustic Model training done.")
    print("The final model has been saved as:", supporter.finalArch["GRU"])
    print('Over System Time:', datetime.datetime.now().strftime("%Y-%m-%d %X"))