예제 #1
0
    def training_step(self, batch, batch_idx):
        """Run one training step: forward pass, loss, and greedy-decoded CER/WER.

        Args:
            batch: 4-tuple ``(inputs, targets, input_lengths, target_lengths)``.
                Inputs are cast to float; targets and both length tensors to int.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The value returned by ``self.loss_fn``; it is also logged as
            ``train_loss`` alongside ``train_wer`` and ``train_cer``.
        """
        char_to_index = args["CHAR_TO_INDEX"]
        space_ix = char_to_index[" "]
        eos_ix = char_to_index["<EOS>"]

        inputs, targets, input_lens, target_lens = batch
        inputs = inputs.float()
        targets = targets.int()
        input_lens = input_lens.int()
        target_lens = target_lens.int()

        logits = self.model(inputs)
        # cuDNN is switched off only around the loss call. The reason is not stated
        # here -- presumably a cuDNN CTC-kernel restriction; TODO confirm.
        with torch.backends.cudnn.flags(enabled=False):
            loss = self.loss_fn(logits, targets, input_lens, target_lens)

        # Metrics are computed from a detached copy so decoding never holds on to
        # the autograd graph.
        preds, pred_lens = ctc_greedy_decode(logits.detach(), input_lens, eos_ix)
        cer = compute_cer(preds, targets, pred_lens, target_lens)
        wer = compute_wer(preds, targets, pred_lens, target_lens, space_ix)

        for metric_name, metric_value in (("train_loss", loss),
                                          ("train_wer", wer),
                                          ("train_cer", cer)):
            self.log(metric_name, metric_value, prog_bar=True)
        return loss
예제 #2
0
    def validation_step(self, batch, batch_idx):
        """Run one validation step: forward pass, loss, and greedy-decoded CER/WER.

        Args:
            batch: 4-tuple ``(inputBatch, targetBatch, inputLenBatch, targetLenBatch)``.
                Inputs are cast to float; targets and both length tensors to int.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The value returned by ``self.loss_fn``; it is also logged as
            ``val_loss`` alongside ``val_wer`` and ``val_cer``.
        """
        spaceIx = args["CHAR_TO_INDEX"][" "]
        eosIx = args["CHAR_TO_INDEX"]["<EOS>"]

        inputBatch, targetBatch, inputLenBatch, targetLenBatch = batch
        inputBatch, targetBatch = inputBatch.float(), targetBatch.int()
        inputLenBatch, targetLenBatch = inputLenBatch.int(), targetLenBatch.int()

        outputBatch = self.model(inputBatch)
        # cuDNN is switched off only around the loss call. The reason is not stated
        # here -- presumably a cuDNN CTC-kernel restriction; TODO confirm.
        with torch.backends.cudnn.flags(enabled=False):
            evalLoss = self.loss_fn(outputBatch, targetBatch, inputLenBatch, targetLenBatch)

        # Only greedy decoding is wired up in this step. Beam search would also need
        # beam-search parameters and a language model, neither of which is supplied
        # here. Decoding uses a detached tensor so it never holds the autograd graph.
        predictionBatch, predictionLenBatch = ctc_greedy_decode(outputBatch.detach(),
                                                                inputLenBatch,
                                                                eosIx)

        evalCER = compute_cer(predictionBatch,
                              targetBatch,
                              predictionLenBatch,
                              targetLenBatch)
        evalWER = compute_wer(predictionBatch,
                              targetBatch,
                              predictionLenBatch,
                              targetLenBatch,
                              spaceIx)

        self.log('val_loss', evalLoss, prog_bar=True)
        self.log('val_wer', evalWER, prog_bar=True)
        self.log('val_cer', evalCER, prog_bar=True)
        return evalLoss
예제 #3
0
def compute_cer_checker():
    """Print ``compute_cer`` on a small hand-written batch as a sanity check.

    Predictions and targets are written as strings in which ``"~"`` stands for
    the ``<EOS>`` token. Each batch is flattened into one 1-D index tensor (all
    sequences concatenated), and the per-sequence lengths -- each counting its
    EOS token -- are passed alongside. That is the layout this checker feeds to
    ``compute_cer``.
    """
    preds = ["SOMETIN'  ' NEDSS~", "   ALRIT ~", "CHEK DON~", "EXACT SAME~"]
    trgts = ["SOMETHING NEEDS~", "ALRIGHT~", "CHECK DONE~", "EXACT SAME~"]

    charToIndex = args["CHAR_TO_INDEX"]

    def encode(strings):
        # Concatenate every sequence into one flat list of character indices,
        # mapping the "~" placeholder to the <EOS> index.
        return [charToIndex["<EOS>"] if char == "~" else charToIndex[char]
                for string in strings for char in string]

    predictionBatch = torch.tensor(encode(preds))
    targetBatch = torch.tensor(encode(trgts))
    # Lengths are derived from the strings rather than hand-counted, so editing a
    # sample cannot silently desynchronise them.
    predictionLenBatch = torch.tensor([len(pred) for pred in preds])
    targetLenBatch = torch.tensor([len(trgt) for trgt in trgts])

    print(
        compute_cer(predictionBatch, targetBatch, predictionLenBatch,
                    targetLenBatch))
예제 #4
0
                # Obtain the prediction using the CTC decoder selected by args["TEST_DEMO_DECODING"].
                # NOTE(review): outputBatch, inputLenBatch, targetBatch, lm and file come from the
                # enclosing (not shown) scope.
                if args["TEST_DEMO_DECODING"] == "greedy":
                    predictionBatch, predictionLenBatch = ctc_greedy_decode(outputBatch, inputLenBatch, args["CHAR_TO_INDEX"]["<EOS>"])

                elif args["TEST_DEMO_DECODING"] == "search":
                    # Beam-search settings are read from args; decoding also receives the space
                    # and <EOS> indices plus the language model `lm`.
                    beamSearchParams = {"beamWidth":args["BEAM_WIDTH"], "alpha":args["LM_WEIGHT_ALPHA"], "beta":args["LENGTH_PENALTY_BETA"],
                                        "threshProb":args["THRESH_PROBABILITY"]}
                    predictionBatch, predictionLenBatch = ctc_search_decode(outputBatch, inputLenBatch, beamSearchParams,
                                                                            args["CHAR_TO_INDEX"][" "], args["CHAR_TO_INDEX"]["<EOS>"], lm)

                else:
                    # NOTE(review): exit() is the interactive-site builtin; raising ValueError (or
                    # sys.exit) would be more robust than print-and-exit.
                    print("Invalid Decode Scheme")
                    exit()

                # Computing CER and WER (the space index is passed to WER for word splitting).
                cer = compute_cer(predictionBatch, targetBatch, predictionLenBatch, targetLenBatch)
                wer = compute_wer(predictionBatch, targetBatch, predictionLenBatch, targetLenBatch, args["CHAR_TO_INDEX"][" "])

                # Converting character indices back to characters.
                # `[:]` is a no-op; `[:-1]` drops only the final element of the flat tensor --
                # presumably the trailing <EOS> index. This strips EOS correctly only when the
                # batch holds a single sequence (the demo appears to handle one file at a time;
                # TODO confirm).
                pred = predictionBatch[:][:-1]
                trgt = targetBatch[:][:-1]
                pred = "".join([args["INDEX_TO_CHAR"][ix] for ix in pred.tolist()])
                trgt = "".join([args["INDEX_TO_CHAR"][ix] for ix in trgt.tolist()])

                # Printing the predictions.
                print("File: %s" %(file))
                print("Prediction: %s" %(pred))
                print("Target: %s" %(trgt))
                print("CER: %.3f  WER: %.3f" %(cer, wer))
                print("\n")