Example #1
0
def batch_cer_accuracy(preds, labels, label_lengths):
    """Return the per-sample ``cerCalc`` scores of a batch, summed.

    Despite the name, this is a *sum* of error scores (``np.sum``), not a
    mean and not an accuracy; any normalisation is left to the caller.

    Args:
        preds: model output, decoded to one string per sample by
            ``get_most_probable`` (presumably CTC-style logits — confirm).
        labels: every target sequence of the batch concatenated into one
            flat tensor; sample ``k`` occupies the next ``label_lengths[k]``
            entries.
        label_lengths: tensor holding the length of each target sequence.

    Returns:
        The sum of ``cerCalc(prediction, reference)`` over all samples.
    """
    decoded = get_most_probable(preds)
    flat_targets = labels.tolist()
    per_sample = []
    start = 0
    for sample_idx, n_tokens in enumerate(label_lengths.cpu().tolist()):
        hypothesis = decoded[sample_idx]
        # Carve this sample's target out of the concatenated label stream.
        stop = start + n_tokens
        reference = sequence_to_string(flat_targets[start:stop])
        per_sample.append(cerCalc(hypothesis, reference))
        start = stop
    return np.sum(per_sample)
Example #2
0
 def validate_update_function(engine, batch):
     """Run one validation forward pass; occasionally print a decoded sample.

     Args:
         engine: unused in this body; present to satisfy the update-function
             signature expected by the caller.
         batch: a 4-tuple ``(img, labels, label_lengths, image_lengths)``.
             ``labels`` holds all target sequences concatenated, and
             ``label_lengths`` gives the length of each one.
             ``image_lengths`` is unpacked but not used here.

     Returns:
         ``(y_pred, labels, label_lengths)``, presumably consumed by the
         validation metrics/loss — confirm against the caller.
     """
     import torch  # local import keeps this block self-contained

     img, labels, label_lengths, image_lengths = batch
     # Validation needs no gradients. Disabling autograd here avoids building
     # and holding a forward graph for every validation batch.
     with torch.no_grad():
         y_pred = model(img.to(device))
     # Spot-check: print decoded predictions for roughly 1% of batches.
     if np.random.rand() > 0.99:
         pred_sentences = get_most_probable(y_pred)
         labels_list = labels.tolist()
         idx = 0
         for i, length in enumerate(label_lengths.cpu().tolist()):
             pred_sentence = pred_sentences[i]
             # Slice this sample's target out of the concatenated labels.
             gt_sentence = sequence_to_string(labels_list[idx:idx + length])
             idx += length
             print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
     return (y_pred, labels, label_lengths)
Example #3
0
            if not data_enc:
                return self.__getitem__(np.random.choice(range(len(self))))
            data = msgpack.unpackb(data_enc, object_hook=m.decode, raw=False)

            img = data['img']
            label = data['label']

            if self.transform is not None:
                img = self.transform(img, self.epoch)

            if self.target_transform is not None:
                label = self.target_transform(label)

            return (img, label)


if __name__ == "__main__":
    # Ad-hoc smoke test: open an LMDB-backed dataset and print the decoded
    # target of one randomly chosen sample.
    from utils.config import lmdb_root_path
    from datasets.librispeech import sequence_to_string
    # Relative paths to additional LMDB database roots.
    lmdb_commonvoice_root_path = "lmdb-databases-common_voice"
    lmdb_airtel_root_path = "lmdb-databases-airtel"
    trainCleanPath = os.path.join(lmdb_root_path, 'train-labelled')
    trainOtherPath = os.path.join(lmdb_root_path, 'train-unlabelled')
    trainCommonVoicePath = os.path.join(lmdb_commonvoice_root_path,
                                        'train-labelled-en')
    testAirtelPath = os.path.join(lmdb_airtel_root_path, 'test-labelled-en')
    # NOTE(review): `roots` is not used in the visible code; the dataset
    # below reads only the Airtel test set. Pass `roots` in or drop it.
    roots = [trainCleanPath, trainOtherPath, trainCommonVoicePath]
    dataset = lmdbMultiDataset(roots=[testAirtelPath])
    # Element [1] of a sample is presumably the label tensor — confirm
    # against lmdbMultiDataset.__getitem__. np.random.choice(len(dataset))
    # raises ValueError if the dataset is empty.
    print(
        sequence_to_string(dataset[np.random.choice(
            len(dataset))][1].tolist()))