def decode_wer(pred, truth, id_dict, space=0, emp_t=1, is_list=True, verbose=2):
    """Decode predictions and references and return the mean WER in percent.

    Args:
        pred: Model outputs. When ``is_list`` is true this is a sequence of
            batches (each a sequence of per-utterance outputs); otherwise a
            flat sequence of per-utterance outputs.
        truth: Reference sequences with the same nesting as ``pred``.
        id_dict: Id-to-token mapping forwarded to ``decoder``/``ans_decoder``.
        space: Forwarded to ``ans_decoder`` as ``space``.
        emp_t: Forwarded to ``decoder`` as ``emp_token``.
        is_list: Whether ``pred``/``truth`` are lists of batches.
        verbose: Number of scored utterances to print for inspection.

    Returns:
        ``100 * mean(WER)`` over the utterances whose decoded reference is
        neither empty nor a single space (other utterances are skipped).

    Raises:
        ZeroDivisionError: if no utterance has a usable reference.
    """
    def _iter_pairs():
        # Yield (prediction, reference) pairs for both input layouts so a
        # single scoring loop (and a single print format) serves both.
        if is_list:
            for j, pred_batch in enumerate(pred):
                truth_batch = truth[j]
                for i, utt_pred in enumerate(pred_batch):
                    yield utt_pred, truth_batch[i]
        else:
            for i, utt_pred in enumerate(pred):
                yield utt_pred, truth[i]

    total_wer = 0.0
    n_scored = 0
    for utt_pred, utt_truth in _iter_pairs():
        _, text_pred = decoder(utt_pred, id_dict, emp_token=emp_t)
        text_truth = ans_decoder(utt_truth, id_dict, space=space)
        # jiwer cannot score against an empty reference; skip such utterances.
        if len(text_truth) == 0 or text_truth == ' ':
            continue
        utt_wer = jiwer.wer(text_truth, text_pred)
        total_wer += utt_wer
        if n_scored < verbose:
            print('text pred:', text_pred, ', text true:', text_truth, ', wer:', utt_wer)
        n_scored += 1

    if n_scored == 0:
        raise ZeroDivisionError(
            'decode_wer: no utterance had a non-empty reference transcript')
    return 100 * (total_wer / n_scored)
def compute_wer(data_dir, split):
    """Compute per-turn and overall WER of ASR output against aligned references.

    Reads ``<split>_aligned.tsv`` (reference tokens in ``da_token``) and
    ``<split>_asr.tsv`` (ASR tokens in ``da_token``) from ``data_dir``.
    Reference tokens equal to ``"<MISSED>"`` are ignored. For every turn only
    the ASR hypothesis with the smallest numeric ``sent_id`` suffix is kept.

    Returns:
        ``(wer_dict, overall_wer)``. ``wer_dict`` maps each turn id to its
        WER, or NaN when the turn has no reference tokens (jiwer cannot score
        an empty reference). ``overall_wer`` is the WER of all turns
        concatenated in sorted turn-id order.
    """
    filename = os.path.join(data_dir, split + '_aligned.tsv')
    df_orig = pd.read_csv(filename, sep="\t")
    df_orig['orig_id'] = df_orig.apply(
        lambda row: '{}_{}_{}'.format(row.filenum, row.true_speaker,
                                      str(row.turn_id).zfill(4)),
        axis=1)

    filename = os.path.join(data_dir, split + '_asr.tsv')
    df_asr = pd.read_csv(filename, sep="\t")
    df_asr['asr_hyp'] = df_asr.sent_id.apply(lambda x: int(x.split('-')[-1]))
    dfs = []
    for _, df_sent in df_asr.groupby('orig_id'):
        dfs.append(df_sent[df_sent.asr_hyp == df_sent.asr_hyp.min()])
    df_asr = pd.concat(dfs).reset_index()

    # Group once (row order within a group is preserved) instead of
    # boolean-masking both full frames for every turn.
    ref_frame = df_orig[df_orig.da_token != "<MISSED>"]
    ref_tokens = {orig_id: group.da_token.tolist()
                  for orig_id, group in ref_frame.groupby('orig_id')}
    asr_tokens = {orig_id: group.da_token.tolist()
                  for orig_id, group in df_asr.groupby('orig_id')}

    wer_dict = {}
    overall_ref = []
    overall_hyp = []
    # Sorted so the concatenated corpus (and therefore overall_wer) does not
    # depend on set/hash iteration order.
    for orig_id in sorted(set(df_orig.orig_id)):
        ref = ' '.join(ref_tokens.get(orig_id, [])).replace(" '", "'")
        hyp = ' '.join(asr_tokens.get(orig_id, []))
        overall_ref.append(ref)
        overall_hyp.append(hyp)
        if ref.strip():
            # jiwer.wer takes (reference, hypothesis); the arguments were
            # previously swapped for the per-turn score.
            wer_dict[orig_id] = jiwer.wer(ref, hyp)
        else:
            wer_dict[orig_id] = float('nan')
    overall_wer = jiwer.wer(' '.join(overall_ref), ' '.join(overall_hyp))
    return wer_dict, overall_wer
def compute_wer_council():
    """Print council WER on the dev and test splits, with and without an LM."""
    gt_dev, gt_test = load_council_ground_truth()
    nolm_dev, nolm_test = load_council_prediction_no_lm()
    lm_dev, lm_test = load_council_prediction_with_lm()

    sections = [
        ('WER, dev', gt_dev, nolm_dev, lm_dev),
        ('WER, test:', gt_test, nolm_test, lm_test),
    ]
    for position, (title, reference, plain, rescored) in enumerate(sections):
        if position:
            print('-----')
        print(title)
        print(wer(reference, plain), 'no lm')
        print(wer(reference, rescored), 'with lm')
def compare_asr(s_wav, t_wav):
    """Transcribe two recordings and compare the target transcript to the source's.

    Returns:
        On success a tuple ``(word_wer, char_wer)``. ``char_wer`` is WER
        computed over the space-separated characters of each transcript, with
        the original spaces dropped. On ``sr.UnknownValueError`` the list
        ``[1., 1.]``; on any other exception the list ``[-1., -1.]``.
    """
    try:
        gt = asr(s_wav)
        recog = asr(t_wav)
        word_err = wer(gt, recog)
        char_err = wer(' '.join(c for c in gt if c != ' '),
                       ' '.join(c for c in recog if c != ' '))
        err_result = word_err, char_err
    except sr.UnknownValueError:
        err_result = [1., 1.]
    except Exception:
        # Any other failure is reported with a sentinel. Unlike the former
        # bare ``except:``, this no longer swallows KeyboardInterrupt/SystemExit.
        err_result = [-1., -1.]
    return err_result
def test_standardize(self):
    """``standardize=True`` must score the same as the hand-normalised strings."""
    standardized = jiwer.wer("he's my neminis", "he is my <unk> [laughter]",
                             standardize=True)
    # Hand-normalised equivalent: contraction expanded, <unk>/[laughter] dropped.
    manual = jiwer.wer("he is my neminis", "he is my")
    self.assertEqual(standardized, manual)
def test_words_to_filter(self):
    """Filtering words must score the same as removing them by hand."""
    filtered = jiwer.wer("yhe about that bug", "yeah about that bug",
                         words_to_filter=["yhe", "yeah"])
    # Hand-filtered equivalent: both filtered words removed.
    by_hand = jiwer.wer("about that bug", "about that bug")
    self.assertEqual(filtered, by_hand)
def probability_analysis(sentences, predictions_sentences, selected_sentences, x):
    '''Print an analysis of sample ``x``.

    Shows the ground-truth sentence, the model prediction and the
    "50% threshold" (probability-selected) sentence. Both the prediction and
    the threshold sentence are scored with WER against the ground truth.
    Headings are printed bold using ANSI escape codes.
    '''
    from jiwer import wer

    bold, reset = '\033[1m', '\033[0m'
    divider = "-----------------------------------------------------------------------------"

    def title(label):
        print(bold + label + reset)

    title('Ground truth: ')
    print(sentences[x])
    print('')
    title('Model prediction: ')
    print(predictions_sentences[x])
    print(divider)
    print(bold + 'WER model: ' + reset, wer(sentences[x], predictions_sentences[x]))
    print(divider)
    print('')
    title('50% treshold: ')
    print(selected_sentences[x])
    print(divider)
    print(bold + 'WER 50% treshold: ' + reset, wer(sentences[x], selected_sentences[x]))
    print(divider)
def batch_metrics_asr(refs, hyps):
    """Compute macro- and micro-averaged label/segmentation metrics.

    ``refs`` and ``hyps`` are parallel sequences. Each element is a sequence
    of per-token label strings for one instance. Macro scores are the mean of
    the per-instance values from ``instance_metrics_asr``. Micro scores are
    computed once, on the labels of all instances concatenated.

    Returns:
        dict with "Macro <M>" and "Micro <M>" entries for M in
        LWER, LER, SER, NSER, DAER.

    NOTE(review): ``instance_metrics_asr`` must return only keys present in
    ``score_lists`` (LWER, LER, SER, NSER, DAER); any other key raises KeyError.
    """
    score_lists = {"LWER": [], "LER": [], "SER": [], "NSER": [], "DAER": []}
    for ref_labels, hyp_labels in zip(refs, hyps):
        this_metrics = instance_metrics_asr(ref_labels, hyp_labels)
        for k, v in this_metrics.items():
            score_lists[k].append(v)
    # Micro metrics operate on all instances' labels concatenated.
    flattened_refs = [label for ref in refs for label in ref]
    flattened_hyps = [label for hyp in hyps for label in hyp]
    # LWER: jiwer.wer over the labels with every "I" label removed
    # (presumably "inside segment" labels — confirm).
    flat_ref_short = [x for x in flattened_refs if x != "I"]
    flat_hyp_short = [x for x in flattened_hyps if x != "I"]
    # NOTE(review): jiwer.wer receives lists of labels here; whether a list is
    # treated as one sequence or as one sentence per element depends on the
    # installed jiwer version — confirm.
    lwer = jiwer.wer(flat_ref_short, flat_hyp_short)
    # LER: jiwer.wer over all labels.
    ler = jiwer.wer(flattened_refs, flattened_hyps)
    # Positions of labels containing "E" (presumably segment-end labels).
    t_ids = [i for i, t in enumerate(flattened_refs) if "E" in t]
    r_ids = [i for i, r in enumerate(flattened_hyps) if "E" in r]
    # SER: sum of each boundary's distance to the nearest boundary on the other
    # side (both directions), halved and divided by the total label count.
    # NOTE(review): min() raises ValueError when one of t_ids / r_ids is empty
    # and the other is not.
    s = 0
    for t in t_ids:
        s += min([abs(r - t) for r in r_ids])
    for r in r_ids:
        s += min([abs(r - t) for t in t_ids])
    ser = s / 2 / len(flattened_refs)
    # NSER: difference in the number of boundaries, relative to the reference.
    # NOTE(review): ZeroDivisionError when the references contain no "E" label.
    nser = abs(len(t_ids) - len(r_ids)) / len(t_ids)
    # DAER: every token takes the label of the boundary that closes its
    # segment, then the expanded sequences are scored. Tokens after the last
    # boundary are not included.
    new_ref = []
    new_hyp = []
    offset = 0
    for i in t_ids:
        new_ref += [flattened_refs[i]] * (i - offset + 1)
        offset = i + 1
    offset = 0
    for i in r_ids:
        new_hyp += [flattened_hyps[i]] * (i - offset + 1)
        offset = i + 1
    daer = jiwer.wer(new_ref, new_hyp)
    return {
        "Macro LWER": np.mean(score_lists["LWER"]),
        "Micro LWER": lwer,
        "Macro LER": np.mean(score_lists["LER"]),
        "Micro LER": ler,
        "Macro SER": np.mean(score_lists["SER"]),
        "Micro SER": ser,
        "Macro NSER": np.mean(score_lists["NSER"]),
        "Micro NSER": nser,
        "Macro DAER": np.mean(score_lists["DAER"]),
        "Micro DAER": daer,
    }
def wer_evaluation(ostt, asr):
    """Score ASR output against OStt references with a single corpus-level WER.

    Each side is Moses-detokenized sentence by sentence and concatenated
    (each sentence prefixed by one space). The result is passed through
    ``text_preprocessing`` and scored with ``wer``.

    :param ostt: the list of OSt (OStt) reference sentences
    :param asr: the list of ASR sentences
    :return: the WER score
    """
    detokenize = MosesDetokenizer().detokenize

    def _as_text(sentences):
        return ''.join(' ' + detokenize(sentence) for sentence in sentences)

    ostt_string = _as_text(ostt)
    asr_string = _as_text(asr)

    asr_string = text_preprocessing(asr_string)
    ostt_string = text_preprocessing(ostt_string)
    return wer(ostt_string, asr_string)
def best_phrase(ground_truth, hyp):
    """Find the span of ``hyp`` that best matches ``ground_truth`` by WER.

    Candidate spans start at every token of ``hyp`` and contain between 0 and
    ``len(ground_truth.split())`` tokens. Ties in WER go to the shortest
    span, then to the earliest start position.

    Returns:
        ``(phrase, wer)``, or ``("", math.inf)`` when ``hyp`` has no tokens.

    Raises:
        AttributeError: if ``hyp`` has no ``split`` method. The offending value
            is printed first.
    """
    ground_truth_len = len(ground_truth.split())
    try:
        hyp_tokens = hyp.split()
    except AttributeError:
        # Show the offending value before re-raising.
        print(hyp)
        raise

    opt_phrase = ""
    opt_phrase_len = math.inf
    opt_wer = math.inf
    n_tokens = len(hyp_tokens)
    for start_pos in range(n_tokens):
        # Spans that would run past the end of hyp repeat the last real span
        # with the same length, which the strict tie-break below never picks.
        last_end = min(start_pos + ground_truth_len, n_tokens)
        for end_pos in range(start_pos, last_end + 1):
            this_phrase = " ".join(hyp_tokens[start_pos:end_pos])
            this_phrase_len = end_pos - start_pos
            this_wer = wer(ground_truth, this_phrase)
            if (this_wer < opt_wer) or (this_wer == opt_wer
                                        and this_phrase_len < opt_phrase_len):
                opt_phrase = this_phrase
                opt_phrase_len = this_phrase_len
                opt_wer = this_wer
    return (opt_phrase, opt_wer)
def eval(self):
    """Evaluate word segmentation on ``self._dataset`` and store averaged metrics.

    For each sentence (truncated to ``self.limit_len`` when set), spaces are
    removed, the model predicts tags, and the re-spaced prediction is scored
    against the original text.

    The averages are stored as:
      * ``self._wer_score``: mean WER.
      * ``self._corrected_sent_cnt``: fraction of sentences whose words match exactly.
      * ``self._acc_score`` and ``self._f1_score``: mean tag accuracy / F1.

    Raises:
        ValueError: if scoring any sentence fails (chained to the original
            error), or if the dataset is empty.
    """
    logger.info('now evaluate!')
    self._model.eval()
    wer_score = 0.
    acc_score = 0.
    f1_score = 0.
    corrected_sent_cnt = 0.
    score_failure_cnt = 0
    n_scored = 0
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for text in tqdm(self._dataset, desc='evaluation steps',
                         total=len(self._dataset)):
            if self.limit_len is not None:
                text = text[:self.limit_len]
            try:
                unspaced_text = unspacing(text.strip())
                tokenized_text = text_to_list(unspaced_text)
                input_batch = torch.Tensor(
                    [self._input_vocab.to_indices(tokenized_text)]).long()
                _, tag_seq = self._model(input_batch)
                labeled_tag_seq = self._tag_vocab.to_tokens(tag_seq[0].tolist())
                pred_text = segment_word_by_tags(unspaced_text, labeled_tag_seq)
                wer_score += jiwer.wer(text.strip(), pred_text.strip())
                if text.split() == pred_text.split():
                    corrected_sent_cnt += 1
                _, labels = labelize(text, bi_tags_only=True)
                labels = list(labels)
                # Map the model's tags onto the B/I scheme of the labels:
                # E -> I, S -> B, <pad> -> I.
                labeled_tag_seq = ' '.join(labeled_tag_seq).replace(
                    'E', 'I').replace('S', 'B').replace('<pad>', 'I').split()
                acc_score += acc(labeled_tag_seq, labels)
                f1_score += f1(labeled_tag_seq, labels, labels=['B', 'I'])
            except Exception as e:
                score_failure_cnt += 1
                logger.warning(
                    "Error message while calculating wer score: {}".format(e))
                logger.info(
                    'wer score failure {} times'.format(score_failure_cnt))
                raise ValueError(
                    'evaluation failed while scoring a sentence') from e
            n_scored += 1

    # Averages are taken once, after the loop, over the scored sentences.
    if n_scored == 0:
        raise ValueError('evaluation dataset is empty')
    self._wer_score = wer_score / n_scored
    self._corrected_sent_cnt = corrected_sent_cnt / n_scored
    self._acc_score = acc_score / n_scored
    self._f1_score = f1_score / n_scored
    logger.info('evaluation done!')
def test_files(sound_dir: str, stt: SpeechToTextV1):
    """Transcribe every recording found under ``sound_dir`` and print WER and SER.

    :param sound_dir: directory scanned by ``_gather_sound_files`` for
        ``(recording_path, gold_transcript)`` pairs.
    :param stt: speech-to-text client passed to ``recognize_audio``.
    :return: None. Prints the corpus WER and the sentence error rate (the
        fraction of utterances whose cleaned transcripts differ).
    :raises ValueError: if no recordings are found.
    """
    gold_sents = []
    hypo_sents = []
    utt_errors = 0
    for recording_path, gold_transcript in _gather_sound_files(sound_dir):
        recognized_transcript = recognize_audio(stt, recording_path)[0][0].strip()
        gold_sents.append(gold_transcript)
        hypo_sents.append(recognized_transcript)
        if _str_clean(gold_transcript) != _str_clean(recognized_transcript):
            utt_errors += 1

    if not gold_sents:
        raise ValueError(f'no sound files found in {sound_dir!r}')
    corpus_wer = jiwer.wer(gold_sents, hypo_sents,
                           hypothesis_transform=_str_clean,
                           truth_transform=_str_clean)
    print(f'WER: {corpus_wer}\nSER: {utt_errors / len(gold_sents)}')
def evaluate_step(self, batch):
    """Run one evaluation step.

    Returns ``(loss, wer, predicted_sequences, reference_sequences)``. The
    prediction comes from greedy decoding.
    """
    xs, ys, xlen, ylen = [tensor.to(device) for tensor in batch]
    ys = ys[:, :ylen.max()].contiguous()

    feats = self.frontend(xs).permute(0, 2, 1)
    # Rescale input lengths so the longest maps onto dim 1 of the frontend
    # output, then trim the padding beyond the new maximum.
    frames = feats.shape[1]
    xlen = torch.floor(xlen.float() / (xlen.max().item() / frames)).int()
    feats = feats[:, :xlen.max()].contiguous()

    loss = self.model(feats, ys, xlen, ylen)
    if FLAGS.multi_gpu:
        # DataParallel returns one loss per device and wraps the real model.
        loss = loss.mean()
        core = self.model.module
    else:
        core = self.model
    ys_hat, _ = core.greedy_decode(feats, xlen)

    pred_seq = self.tokenizer.decode_plus(ys_hat)
    true_seq = self.tokenizer.decode_plus(ys.cpu().numpy())
    score = jiwer.wer(true_seq, pred_seq)
    return loss.item(), score, pred_seq, true_seq
def upload_and_record(db, sid, path):
    """Upload a student's reading, transcribe it, and record speed and accuracy.

    The WAV at ``path`` is uploaded to the ``speech_to_text_class`` bucket
    and transcribed. The following are then stored via ``db.record_reading``:
      * the transcript;
      * the reading speed: ``reading_len`` per minute of audio, as an int;
      * ``(1 - WER) * 100`` against the stored reading text, rounded to
        4 decimals. It is stored under the name ``word_error_rate`` although
        higher means better.
      * the recogniser's confidence.
    """
    import WavInfo
    import cloud_speech
    from jiwer import wer

    print('begin to get reading')
    unit, reading_len, content = db.get_reading_content(sid)
    print('finish get reading')
    print('begin to upload google cloud')
    destination_blob_name = 'unit ' + str(unit) + '/' + sid + '.wav'
    cloud_speech.upload_blob("speech_to_text_class", path, destination_blob_name)
    print('begin to transcribe file')
    gcs_url = "gs://speech_to_text_class/" + destination_blob_name
    transcript, confidence = cloud_speech.transcribe_gcs(gcs_url)
    print('finis transcribe')
    print('Start calculating reading speed and word error rate')
    reading_time = WavInfo.get_wav_time(path)
    reading_speed = int(reading_len / (reading_time / 60.0))
    # jiwer compares text with text. Previously the reference was encoded to
    # bytes, which jiwer rejects on Python 3; decode only if it arrives as bytes.
    reference = content.decode('utf-8') if isinstance(content, bytes) else content
    word_error_rate = (1 - wer(reference, str(transcript))) * 100
    word_error_rate = float('%.4f' % word_error_rate)
    db.record_reading(sid, unit, transcript, reading_speed, word_error_rate,
                      confidence)
def run_batch(real_word_batch, word_batch, pron_batch, avg_per, avg_wer, n):
    """Greedy-decode pronunciations for one batch and update running averages.

    ``self`` is not a parameter; it is read from an enclosing scope (this is
    presumably a nested helper — confirm).

    Returns the updated ``(avg_per, avg_wer, n)``. ``avg_per`` is the running
    mean of ``wer`` between the lexicon and decoded phoneme lists. ``avg_wer``
    is the running mean of exact-mismatch indicators. ``n`` is the number of
    words seen so far.
    """
    word_batch, pron_batch = self.prepare_batches(word_batch, pron_batch)
    # Decode up to 20% more steps than the prepared target length.
    steps = int(pron_batch.shape[0] * 1.2)
    # Every column starts with the start-of-pronunciation symbol; shape (1, batch).
    decoded = torch.LongTensor(
        [[self.phoneme_start_idx] * word_batch.shape[1]]).to(self.device)
    for _ in range(steps):
        logits = self.G2P_model(word_batch, decoded, device=self.device)
        next_ids = torch.argmax(logits[-1, :, :], dim=-1).unsqueeze(0)
        decoded = torch.cat((decoded, next_ids), dim=0)

    for col, word in enumerate(real_word_batch):
        reference = self.lexicon[word]
        phonemes = []
        for idx in decoded[1:, col]:
            idx = idx.item()
            # pad_idx is <SIL>, which is in index_mapping, so it needs its own stop check.
            if idx in self.index_mapping and idx != self.phoneme_pad_idx:
                phonemes.append(self.index_mapping[idx])
            else:
                break
        avg_per += (wer(reference, phonemes) - avg_per) / (n + 1)
        avg_wer += (int(reference != phonemes) - avg_wer) / (n + 1)
        n += 1
    return avg_per, avg_wer, n
def wer(ref_string, string):
    """Word error rate between the speech content of two strings.

    Both inputs pass through ``extract_speech_content`` and are normalised
    with ``JIWER_TRANSFORM``. ``ref_string`` is the reference.
    """
    reference = extract_speech_content(ref_string)
    hypothesis = extract_speech_content(string)
    shared = dict(truth_transform=JIWER_TRANSFORM,
                  hypothesis_transform=JIWER_TRANSFORM)
    return jiwer.wer(reference, hypothesis, **shared)
def on_test_message(self, bus, message):
    """Handle a GStreamer bus message from the DeepSpeech test pipeline.

    Only final (``intermediate == False``) ``deepspeech`` messages are
    handled. Their text is appended to ``self.recognised_text`` and the
    progress bar is advanced. While samples remain, the next one is started.
    After the last sample, the accuracy ``100 - WER * 100`` against the
    punctuation-stripped, lower-cased, ASCII-only ``self.test_text`` is shown
    on the pre- or post-training label, and the phase flags are advanced.
    """
    structure = message.get_structure()
    is_final_result = (structure
                       and structure.get_name() == "deepspeech"
                       and structure.get_value("intermediate") == False)  # noqa: E712
    if not is_final_result:
        return

    self.recognised_text += structure.get_value("text") + "\n"
    self.training_progress.set_fraction(
        (self.testing_sample + 1) / self.sample_id)
    if self.testing_sample < self.sample_id - 1:
        self.test_pipeline.set_state(Gst.State.NULL)
        self.test_sample(self.testing_sample + 1)
        return

    self.test_text = jiwer.RemovePunctuation()(
        self.test_text).lower().encode("ascii", "ignore").decode()
    print("Expected:", self.test_text)
    print("Got:", self.recognised_text)
    accuracy = 100 - jiwer.wer(
        self.test_text.replace("\n", " "),
        self.recognised_text.replace("\n", " ").replace("'", "")) * 100
    if self.pretraining:
        pretraining_accuracy_label = self.builder.get_object(
            "pretraining_accuracy_label")
        pretraining_accuracy_label.set_text("%.2f%%" % accuracy)
        self.pretraining = False
        self.training = True
        status_label = self.builder.get_object("status_label")
        status_label.set_text("Training...")
        self.training_progress.set_fraction(0)
    if self.posttraining:
        posttraining_accuracy_label = self.builder.get_object(
            "posttraining_accuracy_label")
        posttraining_accuracy_label.set_text("%.2f%%" % accuracy)
        spinner = self.builder.get_object("spinner")
        # Previously `spinner.set_active = False`, which only replaced an
        # attribute on the Python wrapper and never stopped the widget. Set
        # the GObject "active" property instead.
        spinner.props.active = False
        self.posttraining = False
def read_and_tokenize(self):
    '''Read ``self.data_path`` and tokenize its source/target sentence pairs.

    Each line is split on ','. Field 0 is the source sentence and field 1 the
    target. Tokenization follows ``self.encoding_type``: 'word'
    (``word_tokenize``), 'char' (``self.char_tokenize``) or 'bpe' (BERT
    WordPiece). When ``self.filter_threshold`` is truthy, pairs with
    ``wer(target_tokens, source_tokens)`` above it are dropped. Targets are
    wrapped in '<s>' ... '</s>'.

    Returns:
        ``(source_sentences, target_sentences)``: lists of token lists.

    Raises:
        NotImplementedError: for an unsupported ``encoding_type`` (on the
            first line read).
    '''
    source_sentences = []
    target_sentences = []
    # The BERT tokenizer is only needed for 'bpe'. Load it lazily so 'word'
    # and 'char' runs do not pay for (or depend on) the pretrained vocabulary.
    tokenizer = None
    # Read-only: the former 'r+' also required write permission on the file.
    with open(self.data_path, 'r') as f:
        for line in f:
            line = line.split(',')
            if self.encoding_type == 'word':
                source_tokens = word_tokenize(line[0])
                target_tokens = word_tokenize(line[1])
            elif self.encoding_type == 'char':
                source_tokens = self.char_tokenize(line[0])
                target_tokens = self.char_tokenize(line[1])
            elif self.encoding_type == 'bpe':
                if tokenizer is None:
                    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                source_tokens = tokenizer.tokenize(line[0])
                target_tokens = tokenizer.tokenize(line[1])
                # TODO: add a custom BPE tokenizer
            else:
                raise NotImplementedError(
                    'Dataset only supports character-based (char), word-based (word) or bpe encoding '
                )
            if self.filter_threshold:
                if wer(target_tokens, source_tokens) > self.filter_threshold:
                    continue
            source_sentences.append(source_tokens)
            target_sentences.append(['<s>'] + target_tokens + ['</s>'])
    return source_sentences, target_sentences
def compute_cer(predictions, references, concatenate_texts=False):
    """Character error rate of ``predictions`` against ``references``.

    Both sides are turned into character sequences by ``cer_transform``.
    With ``concatenate_texts`` the whole corpus is scored in one ``jiwer.wer``
    call. Otherwise errors (S + D + I) and reference characters (S + D + H)
    are summed over the pairs and divided.
    """
    transforms = dict(truth_transform=cer_transform,
                      hypothesis_transform=cer_transform)
    if concatenate_texts:
        return jiwer.wer(references, predictions, **transforms)

    errors = 0
    reference_chars = 0
    for hypothesis, reference in zip(predictions, references):
        measures = jiwer.compute_measures(reference, hypothesis, **transforms)
        subs, dels = measures["substitutions"], measures["deletions"]
        errors += subs + dels + measures["insertions"]
        reference_chars += subs + dels + measures["hits"]
    return errors / reference_chars
def WER_by_mwersegmenter(ostt, asr, SLTev_home, temp_folder):
    """Average per-segment WER after mwerSegmenter resegmentation.

    The ASR output is resegmented to match the OStt segments. Each
    reference/segment pair is Moses-detokenized and passed through
    ``text_preprocessing`` before ``wer`` is computed. The mean over all
    segments is returned.

    :param ostt: the list of OSt (OStt) sentences
    :param asr: the list of ASR sentences
    :param SLTev_home: path of the SLTev files (/path/to/mwerSegmenter)
    :param temp_folder: name of the temporary (UUID-named) folder
    :return: the mean WER score
    """
    segments, _ = segmentation_by_mwersegmenter(ostt, asr, SLTev_home, temp_folder)
    detokenize = MosesDetokenizer().detokenize

    def _clean(tokens):
        return text_preprocessing(detokenize(tokens))

    scores = [wer(_clean(reference), _clean(segments[idx]))
              for idx, reference in enumerate(ostt)]
    return sum(scores) / len(scores)
def evaluate(valid_loader, model, criterion, device, id_letters):
    """Evaluate ``model`` on ``valid_loader`` and print the mean loss and WER.

    The loss is averaged over batches. WER is averaged over individual
    utterances decoded with ``greedy_decoder``.
    """
    model.eval()
    total_loss = 0
    utterance_wers = []
    with torch.no_grad():
        for melspec, tokens, target_len, padded_len in valid_loader:
            melspec = melspec.to(device)
            tokens = tokens.to(device)
            log_probs = F.log_softmax(
                model(melspec.unsqueeze(1).transpose(2, 3)), dim=2)
            total_loss += criterion(log_probs.transpose(0, 1), tokens,
                                    padded_len, target_len).item()
            preds, targets = greedy_decoder(log_probs, tokens, target_len,
                                            id_letters)
            utterance_wers.extend(jiwer.wer(targets[k], pred)
                                  for k, pred in enumerate(preds))
    total_loss /= len(valid_loader)
    avg_wer = sum(utterance_wers) / len(utterance_wers)
    print('Validation: Average loss: {:.4f}, Average WER: {:.4f}\n'.format(
        total_loss, avg_wer))
def student_conversation(self, student_conversation_log):
    """Store one conversation log in ``conversation_record``.

    ``student_conversation_log`` is a JSON string with parallel lists
    ``student``, ``character``, ``student_say`` and ``character_say``. One
    row is inserted per utterance. The row holds the first student id and
    character name, both utterances, the WER of the student's utterance
    against the character's, and NOW().

    On ``NameError`` or ``MySQLdb.OperationalError``, ``self.operation_error()``
    is called and the insert is retried once.
    """
    import json
    from jiwer import wer

    conversation = json.loads(student_conversation_log)
    sid = str(conversation['student'][0])
    character = str(conversation['character'][0])
    rows = []
    for i in range(len(conversation['student'])):
        stu_say = conversation['student_say'][i]
        chr_say = conversation['character_say'][i]
        word_error_rate = wer(str(chr_say), str(stu_say))
        rows.append((sid, character, stu_say, chr_say, word_error_rate))

    # Values are bound by the driver instead of being formatted into the SQL
    # text. The former hand-escaping of "'" allowed injection through
    # utterances containing backslashes.
    sql_sentence = ("INSERT INTO conversation_record(stu_id, character_name, stu_say, character_say , "
                    "word_error_rate, datetime) VALUES (%s, %s, %s, %s, %s, NOW())")
    print(sql_sentence)
    try:
        self.cursor.executemany(sql_sentence, rows)
        self.db.commit()
    except (NameError, MySQLdb.OperationalError):
        # A tuple is required here. The old `except A, B:` form caught only
        # NameError and rebound it onto MySQLdb.OperationalError (a
        # SyntaxError on Python 3).
        self.operation_error()
        self.cursor.executemany(sql_sentence, rows)
        self.db.commit()
def _calc_metrics(self, ground_truth, hypothesis):
    """Return ``(mer, wer, wil, wip)`` of ``hypothesis`` against ``ground_truth``.

    Both sides are lower-cased, whitespace-normalised and split on single
    spaces before alignment.
    """
    transformation = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.RemoveWhiteSpace(replace_by_space=" "),
        # Without stripping, leading/trailing whitespace becomes empty "words"
        # after splitting on " ". RemoveEmptyStrings drops any that remain.
        jiwer.Strip(),
        jiwer.SentencesToListOfWords(word_delimiter=" "),
        jiwer.RemoveEmptyStrings(),
    ])
    # One alignment yields all four measures. Calling jiwer.mer/wer/wil/wip
    # separately aligned the same pair four times.
    measures = jiwer.compute_measures(ground_truth, hypothesis,
                                      truth_transform=transformation,
                                      hypothesis_transform=transformation)
    return measures["mer"], measures["wer"], measures["wil"], measures["wip"]
def add_scores_into_csv(excel_file, actual_col, generated_col, arabic=True):
    """Score generated text against reference text in a CSV and write the scores back.

    For every row, '.', '?' and '!' are removed from both columns and the
    text is lower-cased. Two columns are then (over)written:
      * ``scores``: ``get_similarity_score(actual, generated)``;
      * ``word_error_rate``: ``1 - min(WER, 1)`` (higher is better despite the name).
    Each row's texts and value are printed, followed by the mean unclipped
    WER in percent. The file is rewritten in place.

    ``arabic`` is accepted for backward compatibility and is unused.
    """
    excel = pd.read_csv(excel_file, encoding="ISO-8859-1")
    drop_punct = str.maketrans("", "", ".?!")
    total_wer = 0
    n_rows = len(excel)
    for index in range(n_rows):
        actual = str(excel.at[index, actual_col]).translate(drop_punct).lower()
        generated = str(excel.at[index, generated_col]).translate(drop_punct).lower()
        score = get_similarity_score(actual, generated)
        word_er = wer(actual, generated, standardize=True)
        total_wer += word_er
        word_er = 1 if word_er > 1 else word_er
        excel.at[index, "word_error_rate"] = 1 - word_er
        excel.at[index, "scores"] = score
        print("actual: ", actual)
        print("generated: ", generated)
        print(1 - word_er)

    if n_rows:
        print((total_wer / n_rows) * 100)
    # Write with the encoding used for reading, so text round-trips
    # byte-for-byte instead of being re-encoded to UTF-8. Skip the index,
    # which would otherwise add an unnamed column on every run.
    excel.to_csv(excel_file, encoding="ISO-8859-1", index=False)
def computeER(chunkId=None):
    """Recompute CER and WER for rows of ``benchmark_deepspeech`` and store them.

    Scores every verified row, or only the row whose id is ``chunkId`` when
    given. Column 6 (DeepSpeech transcript) is compared with column 7
    (reference transcript); rows whose reference is empty or NULL are
    skipped. Each update is committed individually and printed. On a
    database error the transaction is rolled back, the error printed, and
    the process exits with status 1.
    """
    if chunkId is None:
        sql = "select * from benchmark_deepspeech where is_verified = true"
        params = None
    else:
        # Bound parameter instead of string concatenation (SQL injection).
        sql = "select * from benchmark_deepspeech where id = %s"
        params = (chunkId,)

    # Defined before `try` so the handlers below never see an unbound name
    # when connect() itself fails.
    con = None
    try:
        con = psycopg2.connect("host='192.168.0.102' dbname='sales' user='******' password='******'")
        cur = con.cursor()
        cur.execute(sql, params)
        cur2 = con.cursor()
        while True:
            row = cur.fetchone()
            if row is None:
                break
            row_id, ds_trans, real_trans = row[0], row[6], row[7]
            if not real_trans:
                continue
            cer = bd.cer(real_trans, ds_trans)
            weri = wer(real_trans, ds_trans)
            cur2.execute(
                "update benchmark_deepspeech set cer = %s, wer = %s where id = %s",
                (float(cer), float(weri), row_id))
            # cursor.query holds the statement as sent, with values bound.
            print(cur2.query.decode())
            con.commit()
    except psycopg2.DatabaseError as e:
        if con is not None:
            con.rollback()
        print(e)
        sys.exit(1)
    finally:
        if con is not None:
            con.close()
def wer(clean_speech, denoised_speech):
    """Compute the word error rate (WER) for one data point.

    Both waveforms are transcribed with ``wer_model``/``wer_tokenizer``
    using argmax decoding. The transcript of ``clean_speech`` is the
    reference; the transcript of ``denoised_speech`` is the hypothesis.

    Returns:
        The WER, or ``None`` when ``jiwer.wer`` raises ``ValueError`` (e.g.
        when no words are predicted).
    """
    def _transcribe(speech):
        # Greedy (argmax) transcription of a single waveform.
        input_values = wer_tokenizer(speech, return_tensors="pt").input_values
        logits = wer_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        return wer_tokenizer.batch_decode(predicted_ids)[0]

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        transcript_clean = _transcribe(clean_speech)
        transcript_estimate = _transcribe(denoised_speech)
    try:
        wer_val = jiwer.wer(transcript_clean, transcript_estimate)
    except ValueError:
        wer_val = None
    return wer_val
def train(net, train_loader, optimizer, criterion_ctc, criterion_cnn, epoch, writer):
    """Train ``net`` for one epoch and log running loss and WER to ``writer``.

    The total loss is the sum of:
      * CTC loss;
      * ``criterion_cnn`` on the raw, concatenated and part logits, with the
        argmax of the CTC output as targets;
      * the ranking loss.
    Batches whose gradients contain NaN are skipped (their index is
    printed). Running averages are logged about ten times per epoch.
    """
    net.train()
    running_loss = 0.0
    running_wer = 0.0
    # Log about ten times per epoch, and at least every batch for loaders
    # with fewer than 10 batches (len // 10 == 0 made the modulo raise).
    log_every = max(1, len(train_loader) // 10)
    for batch_idx, batch in enumerate(train_loader):
        inputs = batch['videos'].cuda()
        targets = batch['annotations'].permute(1, 0).contiguous().cuda()
        input_lens = batch['video_lens'].cuda()
        target_lens = batch['anno_lens'].cuda()
        n, t, c, h, w = inputs.size()

        optimizer.zero_grad()
        raw_logits, concat_logits, part_logits, top_n_prob, outs = net(inputs)
        # Frame-level pseudo-targets: argmax of the CTC output, with the first
        # two dims swapped and flattened.
        cnn_targets = outs.max(-1)[1].permute(1, 0).contiguous().view(-1).data
        loss_raw = criterion_cnn(raw_logits, cnn_targets)
        loss_concat = criterion_cnn(concat_logits, cnn_targets)
        # Every one of the PROPOSAL_NUM parts of a frame shares that frame's
        # target; both part losses use the same flattened views.
        part_logits_flat = part_logits.view(n * t * PROPOSAL_NUM, -1)
        part_targets_flat = cnn_targets.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)
        loss_partcls = criterion_cnn(part_logits_flat, part_targets_flat)
        part_targets = list_loss(part_logits_flat, part_targets_flat).view(
            n * t, PROPOSAL_NUM)
        loss_rank = ranking_loss(top_n_prob, part_targets)
        loss_ctc = criterion_ctc(outs, targets, input_lens, target_lens)
        loss_total = loss_raw + loss_concat + loss_partcls + loss_rank + loss_ctc
        loss_total.backward()

        # Skip batches whose gradients exploded to NaN.
        if any(param.grad is not None and torch.isnan(param.grad).any()
               for _, param in net.named_parameters()):
            print(batch_idx)
            continue
        optimizer.step()

        # Greedy CTC decoding of the same argmax path: collapse repeats and
        # drop index VocabSize (presumably the blank — confirm).
        hyp_text = ' '.join(
            [TRG.vocab.itos[i] for i, _ in groupby(cnn_targets) if i != VocabSize])
        ref_text = ' '.join([TRG.vocab.itos[i] for i in targets.view(-1)])
        running_wer += wer(ref_text, hyp_text, standardize=True)
        running_loss += loss_total.item()
        if batch_idx % log_every == log_every - 1:
            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('train loss', running_loss / log_every, global_step)
            writer.add_scalar('train wer', running_wer / log_every, global_step)
            running_loss = 0.0
            running_wer = 0.0
def val(net, val_loader, criterion, epoch, writer, device):
    """Evaluate ``net`` on ``val_loader``.

    Returns ``(mean_loss, mean_wer)``, both averaged per batch. When
    ``writer`` is truthy, both are also logged under 'val wer' and 'val loss'.
    """
    net.eval()
    total_wer = 0.0
    total_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            videos = batch['videos'].to(device)
            labels = batch['annotations'].permute(1, 0).contiguous().to(device)
            video_lens = batch['video_lens'].to(device)
            label_lens = batch['annotation_lens'].to(device)

            logits = net(videos)
            loss = criterion(logits, labels, video_lens, label_lens)

            # Greedy decoding: argmax path, collapse repeats, drop index
            # VocabSize (presumably the CTC blank — confirm).
            best_path = logits.max(-1)[1].permute(1, 0).contiguous().view(-1)
            hypothesis = ' '.join(TRG.vocab.itos[k]
                                  for k, _ in groupby(best_path) if k != VocabSize)
            reference = ' '.join(TRG.vocab.itos[k] for k in labels.view(-1))
            total_wer += wer(reference, hypothesis, standardize=True)
            total_loss += loss.item()
    total_wer /= len(val_loader)
    total_loss /= len(val_loader)
    if writer:
        writer.add_scalar('val wer', total_wer, epoch)
        writer.add_scalar('val loss', total_loss, epoch)
    return total_loss, total_wer
def test(opt):
    """Evaluate a checkpoint on the test split and print overall and per-person WER.

    Decoding is greedy and autoregressive for ``embeddings.size(1)`` steps,
    seeded with the '*' symbol. '^' characters are removed from predictions.
    People are identified by the second-to-last '/'-separated component of
    each item path.
    """
    _, _, test_set = utils.get_dataset(opt)
    test_loader = DataLoader(test_set, batch_size=opt.batch_size, shuffle=False,
                             num_workers=opt.num_workers, drop_last=False)
    encode, middle, decode = utils.get_model(opt)
    checkpoint = torch.load(opt.load, map_location='cpu')
    encode.load_state_dict(checkpoint['encode'])
    middle.load_state_dict(checkpoint['mid_net'])
    decode.load_state_dict(checkpoint['decode'])
    encode, middle, decode = encode.eval(), middle.eval(), decode.eval()
    if opt.gpu:
        encode, middle, decode = encode.cuda(), middle.cuda(), decode.cuda()
    assert not opt.gpus
    wer_list = []
    person_list = []
    with torch.no_grad():
        for pack in tqdm(test_loader):
            v = pack[0]
            text = pack[3]
            items = pack[4]
            if opt.gpu:
                v = v.cuda()
            embeddings = encode(v)
            embeddings, enc_mask = middle(embeddings)
            batch_size, steps = embeddings.size(0), embeddings.size(1)
            # Allocate decoder buffers where the embeddings live. The former
            # unconditional .cuda() crashed whenever opt.gpu was false.
            res = torch.zeros(batch_size, steps, device=embeddings.device)
            input_align = torch.zeros(batch_size, steps, device=embeddings.device)
            input_align[:, 0] = test_set.character_dict['*']
            for i in range(steps):
                digits = decode.forward_infer(embeddings, input_align, enc_mask)
                pred = torch.argmax(digits, -1)
                res[:, i] = pred[:, i]
                if i < steps - 1:
                    input_align[:, i + 1] = pred[:, i]
            pred = [''.join(test_set.idx_dict[c.item()] for c in row).replace('^', '')
                    for row in res]
            wer_list.extend(wer(p, t) for p, t in zip(pred, text))
            person_list += [x.split('/')[-2] for x in items]
    wer_list = np.array(wer_list)
    person_list = np.array(person_list)
    print('overall wer:{:.4f}'.format(np.mean(wer_list)))
    # Sorted for a stable, readable report order.
    for person in sorted(set(person_list)):
        print('{} wer:{:.4f}'.format(person, np.mean(wer_list[person_list == person])))
def build_vocab_wer(self):
    """Fill ``self.wer_dict[a][b]`` with ``wer(str(a), str(b))`` for every vocab pair.

    ``self.wer_dict`` is reset first. ``self.load()`` is called when the
    instance has no ``vocab`` attribute yet.
    """
    self.wer_dict = {}
    if not hasattr(self, 'vocab'):
        self.load()
    table = self.wer_dict
    for source in self.vocab:
        row = table[source] = {}
        for target in self.vocab:
            row[target] = wer(str(source), str(target))