Example #1
0
    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log loss, WER and CER.

        Args:
            batch: A 4-element sequence unpacked as
                ``(inputs, targets, input_lengths, target_lengths)``.
            batch_idx: Index of the batch; not used here.

        Returns:
            The loss value produced by ``self.loss_fn``.
        """
        # Special-token indices used by the decoder (<EOS>) and the WER
        # computation (space).
        space_ix = args["CHAR_TO_INDEX"][" "]
        eos_ix = args["CHAR_TO_INDEX"]["<EOS>"]

        inputs, targets, input_lens, target_lens = batch
        inputs = inputs.float()
        targets = targets.int()
        input_lens = input_lens.int()
        target_lens = target_lens.int()

        model_out = self.model(inputs)
        # cuDNN is turned off only around the loss call -- presumably to avoid
        # cuDNN-specific restrictions of the loss kernel (TODO confirm).
        with torch.backends.cudnn.flags(enabled=False):
            loss = self.loss_fn(model_out, targets, input_lens, target_lens)

        # Decode from a detached tensor so metric computation is kept out of
        # the autograd graph.
        predictions, prediction_lens = ctc_greedy_decode(
            model_out.detach(), input_lens, eos_ix)
        cer = compute_cer(predictions, targets, prediction_lens, target_lens)
        wer = compute_wer(predictions, targets, prediction_lens, target_lens,
                          space_ix)

        for metric_name, metric_value in (("train_loss", loss),
                                          ("train_wer", wer),
                                          ("train_cer", cer)):
            self.log(metric_name, metric_value, prog_bar=True)
        return loss
Example #2
0
def compute_wer_checker():
    """Print the result of ``compute_wer`` for a fixed set of sample strings.

    This is a manual sanity check. Every string is mapped to character
    indices through ``args["CHAR_TO_INDEX"]``, with ``"~"`` standing in for
    the ``<EOS>`` token. All predictions (and, separately, all targets) are
    concatenated into one flat 1-D tensor, and the per-string lengths are
    passed alongside -- the layout this checker feeds to ``compute_wer``.

    The lengths are derived from the strings themselves rather than being
    hard-coded, so editing a sample string cannot leave them out of sync.
    """
    preds = [
        " SOMETH'NG  '  NEE DS TO BE D'NE   ABOUT IT~",
        "FUNCTION CHECKING INITIATED~", "    '   ~", "~"
    ]
    trgts = [
        "SOMETHING NEEDS TO BE DONE ABOUT IT~", "FUNCTION CHECKING INITIATED~",
        "SOME ARBIT STRING~", "ARBIT STRING~"
    ]

    char_to_index = args["CHAR_TO_INDEX"]
    eos_ix = char_to_index["<EOS>"]

    def _encode_concat(texts):
        """Return the index encodings of *texts* concatenated into one list."""
        flat = []
        for text in texts:
            # "~" is the in-string marker for the <EOS> token.
            flat.extend(eos_ix if char == "~" else char_to_index[char]
                        for char in text)
        return flat

    predictionBatch = torch.tensor(_encode_concat(preds))
    targetBatch = torch.tensor(_encode_concat(trgts))
    predictionLenBatch = torch.tensor([len(text) for text in preds])
    targetLenBatch = torch.tensor([len(text) for text in trgts])

    print(
        compute_wer(predictionBatch, targetBatch, predictionLenBatch,
                    targetLenBatch, char_to_index[" "]))
Example #3
0
    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and log loss, WER and CER.

        Args:
            batch: A 4-element sequence unpacked as
                ``(inputBatch, targetBatch, inputLenBatch, targetLenBatch)``.
            batch_idx: Index of the batch; not used here.

        Returns:
            The loss value produced by ``self.loss_fn``.

        Raises:
            ValueError: If the decode scheme is ``"search"`` but ``evalParams``
                lacks ``"beamSearchParams"`` or ``"lm"``, or if the decode
                scheme is not recognised.
        """
        evalParams = {"decodeScheme": "greedy",
                      "spaceIx": args["CHAR_TO_INDEX"][" "],
                      "eosIx": args["CHAR_TO_INDEX"]["<EOS>"]}

        inputBatch, targetBatch, inputLenBatch, targetLenBatch = batch
        inputBatch, targetBatch = inputBatch.float(), targetBatch.int()
        inputLenBatch, targetLenBatch = inputLenBatch.int(), targetLenBatch.int()

        outputBatch = self.model(inputBatch)
        # cuDNN is disabled only around the loss call; presumably to avoid
        # cuDNN-specific restrictions of the loss kernel (TODO confirm).
        with torch.backends.cudnn.flags(enabled=False):
            evalLoss = self.loss_fn(outputBatch, targetBatch, inputLenBatch, targetLenBatch)

        decodeScheme = evalParams["decodeScheme"]
        if decodeScheme == "greedy":
            predictionBatch, predictionLenBatch = ctc_greedy_decode(outputBatch,
                                                                    inputLenBatch,
                                                                    evalParams["eosIx"])
        elif decodeScheme == "search":
            # Beam search needs configuration that evalParams does not define
            # by default; report exactly what is missing instead of failing
            # with a bare KeyError on the lookups below.
            missing = [key for key in ("beamSearchParams", "lm") if key not in evalParams]
            if missing:
                raise ValueError("Decode scheme 'search' requires evalParams keys: %s"
                                 % ", ".join(missing))
            predictionBatch, predictionLenBatch = ctc_search_decode(outputBatch,
                                                                    inputLenBatch,
                                                                    evalParams["beamSearchParams"],
                                                                    evalParams["spaceIx"],
                                                                    evalParams["eosIx"],
                                                                    evalParams["lm"])
        else:
            # Raise instead of print + exit(): exit() would tear down the whole
            # training process from inside a validation step, and the
            # site-provided exit() is not intended for use in programs.
            raise ValueError("Invalid Decode Scheme: %r" % (decodeScheme,))

        evalCER = compute_cer(predictionBatch,
                              targetBatch,
                              predictionLenBatch,
                              targetLenBatch)
        evalWER = compute_wer(predictionBatch,
                              targetBatch,
                              predictionLenBatch,
                              targetLenBatch,
                              evalParams["spaceIx"])

        self.log('val_loss', evalLoss, prog_bar=True)
        self.log('val_wer', evalWER, prog_bar=True)
        self.log('val_cer', evalCER, prog_bar=True)
        return evalLoss
Example #4
0
                # Decode the model output using the scheme selected by
                # args["TEST_DEMO_DECODING"] ("greedy" or "search").
                if args["TEST_DEMO_DECODING"] == "greedy":
                    # Greedy decoding; the <EOS> index is passed to the decoder.
                    predictionBatch, predictionLenBatch = ctc_greedy_decode(outputBatch, inputLenBatch, args["CHAR_TO_INDEX"]["<EOS>"])

                elif args["TEST_DEMO_DECODING"] == "search":
                    # Beam-search decoding configured from args. `lm` comes from
                    # the enclosing scope (presumably a language model -- confirm).
                    beamSearchParams = {"beamWidth":args["BEAM_WIDTH"], "alpha":args["LM_WEIGHT_ALPHA"], "beta":args["LENGTH_PENALTY_BETA"],
                                        "threshProb":args["THRESH_PROBABILITY"]}
                    predictionBatch, predictionLenBatch = ctc_search_decode(outputBatch, inputLenBatch, beamSearchParams,
                                                                            args["CHAR_TO_INDEX"][" "], args["CHAR_TO_INDEX"]["<EOS>"], lm)

                else:
                    # NOTE(review): exit() terminates the whole process; raising
                    # an exception (e.g. ValueError) would be more appropriate.
                    print("Invalid Decode Scheme")
                    exit()

                # Compute character and word error rates for this sample; the
                # space index is what compute_wer receives as its last argument.
                cer = compute_cer(predictionBatch, targetBatch, predictionLenBatch, targetLenBatch)
                wer = compute_wer(predictionBatch, targetBatch, predictionLenBatch, targetLenBatch, args["CHAR_TO_INDEX"][" "])

                # Convert character indices back to text.
                # `[:]` is a no-op; `[:-1]` drops the final element of the tensor.
                # NOTE(review): presumably this strips the trailing <EOS> of a
                # single-sample batch whose indices are flattened -- this would
                # be wrong for batch size > 1; TODO confirm.
                pred = predictionBatch[:][:-1]
                trgt = targetBatch[:][:-1]
                pred = "".join([args["INDEX_TO_CHAR"][ix] for ix in pred.tolist()])
                trgt = "".join([args["INDEX_TO_CHAR"][ix] for ix in trgt.tolist()])

                # Print the result for this input; `file` comes from the
                # enclosing scope.
                print("File: %s" %(file))
                print("Prediction: %s" %(pred))
                print("Target: %s" %(trgt))
                print("CER: %.3f  WER: %.3f" %(cer, wer))
                print("\n")