def batch_cer_accuracy(preds, labels, label_lengths):
    """Return the summed character error rate (CER) over one batch.

    Every prediction in ``preds`` is decoded with ``get_most_probable`` and
    scored against its reference transcript with ``cerCalc``.

    Args:
        preds: Model output for the batch, handed unchanged to
            ``get_most_probable``.
        labels: Tensor of target token ids. It is walked with a running
            offset, so it presumably holds all target sequences concatenated
            back to back (CTC-style) -- TODO confirm against the data loader.
        label_lengths: Tensor giving the length of each target sequence, in
            the same order as the decoded predictions.

    Returns:
        ``np.sum`` of the per-sample CER values -- a total, not a mean.
        Despite the "accuracy" in the name this is an error measure;
        presumably the caller divides by the sample count.
    """
    hypotheses = get_most_probable(preds)
    flat_targets = labels.tolist()
    errors = []
    start = 0
    for sample_no, target_len in enumerate(label_lengths.cpu().tolist()):
        stop = start + target_len
        # Arguments are evaluated left to right: the hypothesis is looked up
        # before the reference slice is decoded.
        errors.append(
            cerCalc(
                hypotheses[sample_no],
                sequence_to_string(flat_targets[start:stop]),
            )
        )
        start = stop
    return np.sum(errors)
def validate_update_function(engine, batch):
    """Run one validation step and occasionally print sample transcripts.

    Args:
        engine: Not used by this function. The ``(engine, batch)`` signature
            suggests it is a PyTorch-Ignite process function -- TODO confirm.
        batch: 4-tuple ``(img, labels, label_lengths, image_lengths)``;
            ``image_lengths`` is unpacked but not used.

    Returns:
        ``(y_pred, labels, label_lengths)``: the raw model output together
        with the unchanged targets, for downstream metric computation.
    """
    import torch  # local import keeps this function self-contained

    img, labels, label_lengths, _image_lengths = batch

    # Validation never back-propagates, so do not build an autograd graph:
    # this saves memory and compute on every validation batch.
    # NOTE(review): this function does not call model.eval(); confirm the
    # model is switched to eval mode before validation runs.
    with torch.no_grad():
        y_pred = model(img.to(device))

    # np.random.rand() is uniform on [0, 1), so this branch fires for about
    # 1% of batches: print each decoded prediction next to its reference.
    if np.random.rand() > 0.99:
        pred_sentences = get_most_probable(y_pred)
        labels_list = labels.tolist()
        idx = 0
        # Targets are sliced with a running offset, so ``labels`` presumably
        # holds all sequences concatenated -- TODO confirm.
        for i, length in enumerate(label_lengths.cpu().tolist()):
            pred_sentence = pred_sentences[i]
            gt_sentence = sequence_to_string(labels_list[idx:idx + length])
            idx += length
            print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
    return (y_pred, labels, label_lengths)
if not data_enc: return self.__getitem__(np.random.choice(range(len(self)))) data = msgpack.unpackb(data_enc, object_hook=m.decode, raw=False) img = data['img'] label = data['label'] if self.transform is not None: img = self.transform(img, self.epoch) if self.target_transform is not None: label = self.target_transform(label) return (img, label) if __name__ == "__main__": from utils.config import lmdb_root_path from datasets.librispeech import sequence_to_string lmdb_commonvoice_root_path = "lmdb-databases-common_voice" lmdb_airtel_root_path = "lmdb-databases-airtel" trainCleanPath = os.path.join(lmdb_root_path, 'train-labelled') trainOtherPath = os.path.join(lmdb_root_path, 'train-unlabelled') trainCommonVoicePath = os.path.join(lmdb_commonvoice_root_path, 'train-labelled-en') testAirtelPath = os.path.join(lmdb_airtel_root_path, 'test-labelled-en') roots = [trainCleanPath, trainOtherPath, trainCommonVoicePath] dataset = lmdbMultiDataset(roots=[testAirtelPath]) print( sequence_to_string(dataset[np.random.choice( len(dataset))][1].tolist()))