def monitor_asr_train_progress(tensors: list, labels: list):
    """Report the word error rate of a single training batch.

    The greedy CTC output in ``tensors[0]`` is turned into text by
    ``__ctc_decoder_predictions_tensor``. The padded target ids are mapped back
    to reference strings through ``labels``. The batch WER and the first
    prediction/reference pair are then printed via ``print_once``.

    Args:
        tensors: three tensors ``(predictions, targets, target_lengths)``.
            Row ``i`` of ``targets`` is truncated to ``target_lengths[i]`` ids
            before decoding.
        labels: the vocabulary; target id ``k`` decodes to ``labels[k]``.

    Returns:
        The word error rate (first element of ``word_error_rate``'s result).
    """
    id_to_label = dict(enumerate(labels))
    with torch.no_grad():
        padded_targets = tensors[1].long().cpu()
        lengths = tensors[2].long().cpu()

        references = []
        for row_idx, row in enumerate(padded_targets):
            # Keep only this utterance's real ids, dropping trailing padding.
            valid_ids = row[:lengths[row_idx].item()].numpy().tolist()
            references.append(''.join(id_to_label[token] for token in valid_ids))

        hypotheses = __ctc_decoder_predictions_tensor(tensors[0], labels=labels)

    wer, _scores, _num_words = word_error_rate(hypotheses, references)
    print_once('{0}: {1}'.format("training_batch_WER", wer))
    print_once('Prediction: {0}'.format(hypotheses[0]))
    print_once('Reference: {0}'.format(references[0]))
    return wer
def process_evaluation_epoch(global_vars: dict, tag=None):
    """Compute the final word error rate of an evaluation run.

    The rate is computed only from the transcripts stored in ``global_vars``.
    No loss is returned, and nothing is reduced across distributed workers.

    Args:
        global_vars: evaluation state; must provide ``'predictions'``
            (hypothesis transcripts) and ``'transcripts'`` (references).
        tag: unused; kept for call-site compatibility.

    Returns:
        The word error rate (first element of ``word_error_rate``'s result).
    """
    wer, _scores, _num_words = word_error_rate(
        hypotheses=global_vars['predictions'],
        references=global_vars['transcripts'],
    )
    return wer
def compute_batch_wer(logits, targets, target_lengths, preproc, reduction='mean'):
    """Compute per-utterance word error rates for a batch of greedy predictions.

    Args:
        logits: model scores. The last dimension is argmax-ed to get predicted
            token ids (presumably (batch, time, vocab) — confirm against caller).
        targets: reference token ids, one row per utterance.
        target_lengths: reference lengths used to build the prediction mask.
        preproc: object exposing ``decode_ids(ids)`` that turns a row of ids
            into a transcript string.
        reduction: ``'mean'`` or ``'sum'`` to reduce over the batch, or
            ``None`` / ``'none'`` to return the per-utterance tensor.

    Returns:
        A float tensor: a scalar for ``'mean'``/``'sum'``, otherwise one WER
        per utterance.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """
    # Validate before any decoding work, so a typo fails fast instead of
    # silently returning None.
    if reduction not in (None, 'none', 'sum', 'mean'):
        raise ValueError("Unsupported reduction: {0!r}".format(reduction))

    predicted_ids = logits.cpu().argmax(dim=-1).long()
    targets = targets.cpu().long()
    # NOTE(review): make_mask_from_lengths is defined elsewhere. This assumes
    # its True entries mark padding positions to be zeroed, and that its
    # (batch, max(target_lengths)) shape lines up with predicted_ids — confirm.
    # The references below are decoded without masking.
    mask = make_mask_from_lengths(target_lengths, max(target_lengths))
    predicted_ids.masked_fill_(mask, 0)

    hypotheses = [preproc.decode_ids(row) for row in predicted_ids]
    references = [preproc.decode_ids(row) for row in targets]

    # word_error_rate takes lists of transcripts and returns
    # (wer, scores, num_words), as at its other call sites in this file. Score
    # each utterance as a one-element batch and keep only the rate.
    per_utterance = [
        word_error_rate([hyp], [ref])[0]
        for hyp, ref in zip(hypotheses, references)
    ]
    wer = torch.FloatTensor(per_utterance)

    if reduction == 'sum':
        return wer.sum()
    if reduction == 'mean':
        return wer.mean()
    return wer
def _all_reduce_sum(value):
    """Return the sum of the Python scalar ``value`` over all distributed ranks."""
    # NOTE(review): the buffer goes on the current CUDA device, as in the
    # original code. This presumably assumes a GPU-backed (e.g. NCCL) process
    # group — confirm.
    buffer = torch.tensor(value).cuda()
    torch.distributed.all_reduce(buffer)
    return buffer.item()


def process_evaluation_epoch2(global_vars: dict, tag=None):
    """Combine evaluation results into the final WER, loss and CER.

    Args:
        global_vars: evaluation state. Reads ``'predictions'`` (hypotheses),
            ``'transcripts'`` (references) and, if present, ``'EvalLoss'``
            (a list of loss tensors that are stacked and averaged).
        tag: unused; kept for call-site compatibility.

    Returns:
        ``(wer, eloss, cer)``:
        - ``wer``: word error rate.
        - ``eloss``: mean evaluation loss, or ``None`` when ``'EvalLoss'`` is
          absent.
        - ``cer``: the value returned by ``get_distance`` (presumably the
          character error rate).
        When a distributed process group is initialized, ``wer`` and ``eloss``
        are aggregated over all ranks.
    """
    if 'EvalLoss' in global_vars:
        eloss = torch.mean(torch.stack(global_vars['EvalLoss'])).item()
    else:
        eloss = None

    hypotheses = global_vars['predictions']
    references = global_vars['transcripts']
    wer, scores, num_words = word_error_rate(hypotheses=hypotheses, references=references)
    # Character error rate (added by lnw).
    cer = get_distance(ref_labels=references, hyp_labels=hypotheses)

    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    if distributed:
        if eloss is not None:
            # Pre-divide by the world size so the summed result is the mean
            # loss over ranks.
            eloss = _all_reduce_sum(eloss / torch.distributed.get_world_size())
        # Sum raw error counts and word counts, then recompute the global WER.
        # Averaging per-rank rates instead would weight ranks equally,
        # regardless of how many words each scored.
        scores = _all_reduce_sum(scores)
        num_words = _all_reduce_sum(num_words)
        wer = scores * 1.0 / num_words
        # NOTE(review): cer is still this rank's local value; it is not
        # reduced across ranks like wer and eloss.
    return wer, eloss, cer
# Stream every file in `filenames` to the inference server in fixed-size
# batches. If reference transcripts are available, print the resulting WER.
# NOTE(review): `audio_idx`, `filenames`, `transcripts`, `FLAGS`,
# `speech_client`, `AudioSegment` and `np` must be defined before this point;
# that setup is not visible here.
last_request = False
predictions = []
while not last_request:
    batch_audio_samples = []
    batch_filenames = []
    # Always collect FLAGS.batch_size files. If audio_idx wraps back to 0
    # mid-batch, the remaining slots are filled with files from the start of
    # the list again, so more files than `filenames` holds may be sent.
    for idx in range(FLAGS.batch_size):
        filename = filenames[audio_idx]
        print("Reading audio file: ", filename)
        audio = AudioSegment.from_file(filename, offset=0, duration=FLAGS.fixed_size).samples
        if FLAGS.fixed_size:
            # NOTE(review): the same FLAGS.fixed_size is passed as `duration`
            # above and as the target length for np.resize here. Confirm both
            # expect the same unit (seconds vs. samples).
            audio = np.resize(audio, FLAGS.fixed_size)
        # Advance cyclically; returning to 0 means every file has been read once.
        audio_idx = (audio_idx + 1) % len(filenames)
        if audio_idx == 0:
            last_request = True
        batch_audio_samples.append(audio)
        batch_filenames.append(filename)
    predictions += speech_client.recognize(batch_audio_samples, batch_filenames)
if transcripts:
    # Flatten one level of nesting. This presumably means recognize() returns
    # a list of transcripts per input; verify.
    predictions = [x for l in predictions for x in l]
    from metrics import word_error_rate
    # NOTE(review): because of the wrap-around padding above, predictions may
    # be longer than transcripts. This relies on how word_error_rate handles
    # the length mismatch — confirm.
    wer, scores, num_words = word_error_rate(predictions, transcripts)
    print(wer)