def compute_objectives(self, predictions, targets, stage="train"):
    if stage == "train":
        p_ctc, p_seq, wav_lens = predictions
    else:
        p_ctc, p_seq, wav_lens, hyps = predictions

    ids, phns, phn_lens = targets
    phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

    # Add phn_lens by one for eos token
    abs_length = torch.round(phn_lens * phns.shape[1])

    # Append eos token at the end of the label sequences
    phns_with_eos = append_eos_token(
        phns, length=abs_length, eos_index=params.eos_index
    )

    # Convert to speechbrain-style relative length
    rel_length = (abs_length + 1) / phns.shape[1]

    loss_ctc = params.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
    loss_seq = params.seq_cost(p_seq, phns_with_eos, length=rel_length)
    loss = params.ctc_weight * loss_ctc + (1 - params.ctc_weight) * loss_seq

    stats = {}
    if stage != "train":
        ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
        sequence = convert_index_to_lab(hyps, ind2lab)
        phns = undo_padding(phns, phn_lens)
        phns = convert_index_to_lab(phns, ind2lab)
        per_stats = edit_distance.wer_details_for_batch(
            ids, phns, sequence, compute_alignments=True
        )
        stats["PER"] = per_stats

    return loss, stats
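
# The per_stats stored above is a list of per-utterance dicts produced by
# speechbrain.utils.edit_distance.wer_details_for_batch (each dict carries at
# least a "WER" entry, plus an "alignment" when compute_alignments=True).
# A minimal sketch, assuming SpeechBrain's wer_summary helper is available for
# aggregating such details into a corpus-level rate; the reference and
# hypothesis tokens below are made up for illustration.
from speechbrain.utils import edit_distance


def example_per_summary():
    refs = [["b", "l", "ah"], ["s", "i", "l"]]
    hyps = [["b", "l", "ah"], ["s", "i", "l", "t"]]
    details = edit_distance.wer_details_for_batch(
        ["utt1", "utt2"], refs, hyps, compute_alignments=True
    )
    summary = edit_distance.wer_summary(details)
    # Corpus-level error rate in percent (assumed "WER" key, as used above)
    return summary["WER"]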
def append(
    self,
    ids,
    predict,
    target,
    predict_len=None,
    target_len=None,
    ind2lab=None,
):
    """Add stats to the relevant containers.

    * See MetricStats.append()

    Arguments
    ---------
    ids : list
        List of ids corresponding to utterances.
    predict : torch.tensor
        A predicted output, for comparison with the target output.
    target : torch.tensor
        The correct reference output, for comparison with the prediction.
    predict_len : torch.tensor
        The predictions' relative lengths, used to undo padding if
        there is padding present in the predictions.
    target_len : torch.tensor
        The target outputs' relative lengths, used to undo padding if
        there is padding present in the targets.
    ind2lab : callable
        Callable that maps from indices to labels, operating on batches,
        for writing alignments.
    """
    self.ids.extend(ids)

    if predict_len is not None:
        predict = undo_padding(predict, predict_len)

    if target_len is not None:
        target = undo_padding(target, target_len)

    if ind2lab is not None:
        predict = ind2lab(predict)
        target = ind2lab(target)

    if self.merge_tokens:
        predict = merge_char(predict)
        target = merge_char(target)

    if self.split_tokens:
        predict = split_word(predict)
        target = split_word(target)

    scores = edit_distance.wer_details_for_batch(ids, target, predict, True)
    self.scores.extend(scores)
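
# A minimal usage sketch for the append()/summarize() pattern above, assuming
# the class is instantiated like SpeechBrain's ErrorRateStats
# (speechbrain.utils.metric_stats). The tensors, label map, and the
# "error_rate" summary field are illustrative assumptions.
import torch
from speechbrain.utils.metric_stats import ErrorRateStats


def example_error_rate_stats():
    per_metric = ErrorRateStats()
    ind2lab = {0: "a", 1: "b", 2: "c"}
    predict = torch.tensor([[0, 1, 2], [1, 2, 0]])
    target = torch.tensor([[0, 1, 2], [1, 2, 2]])
    lens = torch.tensor([1.0, 1.0])  # relative lengths (no padding here)
    per_metric.append(
        ids=["utt1", "utt2"],
        predict=predict,
        target=target,
        predict_len=lens,
        target_len=lens,
        ind2lab=lambda batch: [[ind2lab[int(x)] for x in seq] for seq in batch],
    )
    return per_metric.summarize("error_rate")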
def _check_coverage_from_bpe(self, list_csv_files=[]):
    """Logs the accuracy of the BPE model at recovering words from the
    training text.

    Arguments
    ---------
    list_csv_files : list
        List of the csv files used for checking the accuracy of
        recovering words from the tokenizer.
    """
    for csv_file in list_csv_files:
        if os.path.isfile(os.path.abspath(csv_file)):
            logger.info(
                "==== Accuracy checking for recovering text from tokenizer ==="
            )
            fcsv_file = open(csv_file, "r")
            reader = csv.reader(fcsv_file)
            headers = next(reader, None)
            if self.csv_read not in headers:
                raise ValueError(
                    self.csv_read + " must exist in: " + csv_file
                )
            index_label = headers.index(self.csv_read)
            wrong_recover_list = []
            for row in reader:
                row = row[index_label]
                if self.char_format_input:
                    (row,) = merge_char([row.split()])
                    row = " ".join(row)
                row = row.split("\n")[0]
                encoded_id = self.sp.encode_as_ids(row)
                decode_text = self.sp.decode_ids(encoded_id)
                (details,) = edit_distance.wer_details_for_batch(
                    ["utt1"],
                    [row.split(" ")],
                    [decode_text.split(" ")],
                    compute_alignments=True,
                )
                if details["WER"] > 0:
                    # Collect words that the tokenizer fails to recover exactly
                    for align in details["alignment"]:
                        if align[0] != "=" and align[1] is not None:
                            if align[1] not in wrong_recover_list:
                                wrong_recover_list.append(align[1])
            fcsv_file.close()
            logger.info("recover words from: " + csv_file)
            if len(wrong_recover_list) > 0:
                logger.warning(
                    "Wrong recover words: " + str(len(wrong_recover_list))
                )
                logger.warning(
                    "Tokenizer vocab size: " + str(self.sp.vocab_size())
                )
                logger.warning(
                    "accuracy recovering words: "
                    + str(
                        1
                        - float(len(wrong_recover_list)) / self.sp.vocab_size()
                    )
                )
            else:
                logger.info("Wrong recover words: 0")
                logger.warning("accuracy recovering words: " + str(1.0))
        else:
            logger.info("No accuracy recover checking for " + csv_file)
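
# A minimal sketch of the round-trip check performed above, using the
# sentencepiece Python API directly (encode_as_ids / decode_ids, as called on
# self.sp). The model path and the sample sentence are placeholders.
import sentencepiece as spm
from speechbrain.utils import edit_distance


def example_roundtrip_check(model_path="tokenizer.model"):
    sp = spm.SentencePieceProcessor()
    sp.load(model_path)
    sentence = "the quick brown fox"
    ids = sp.encode_as_ids(sentence)
    decoded = sp.decode_ids(ids)
    (details,) = edit_distance.wer_details_for_batch(
        ["utt1"],
        [sentence.split(" ")],
        [decoded.split(" ")],
        compute_alignments=True,
    )
    # A nonzero WER here means some words cannot be recovered exactly
    return details["WER"]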
def compute_forward_tea(self, x, y, init_params=False):
    ids, wavs, wav_lens = x
    ids, phns, phn_lens = y

    wavs, wav_lens = wavs.to(params.device), wav_lens.to(params.device)
    phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

    if hasattr(params, "augmentation"):
        wavs = params.augmentation(wavs, wav_lens, init_params)

    feats = params.compute_features(wavs, init_params)
    feats = params.normalize(feats, wav_lens)

    apply_softmax = torch.nn.Softmax(dim=-1)

    ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
    phns_decode = undo_padding(phns, phn_lens)
    phns_decode = convert_index_to_lab(phns_decode, ind2lab)

    # Run inference with each teacher model
    tea_dict_list = []
    for num in range(params.num_tea):
        tea_dict = {}
        self.tea_modules_list[num].eval()
        with torch.no_grad():
            x_tea = tea_enc_list[num](feats, init_params=init_params)
            ctc_logits_tea = tea_ctc_lin_list[num](x_tea, init_params)

            # Output layer for ctc log-probabilities
            p_ctc_tea = params.log_softmax(ctc_logits_tea / params.T)

            # Prepend bos token at the beginning
            y_in_tea = prepend_bos_token(phns, bos_index=params.bos_index)
            e_in_tea = tea_emb_list[num](y_in_tea, init_params=init_params)
            h_tea, _ = tea_dec_list[num](
                e_in_tea, x_tea, wav_lens, init_params
            )

            # Output layer for seq2seq log-probabilities
            seq_logits_tea = tea_seq_lin_list[num](h_tea, init_params)
            p_seq_tea = apply_softmax(seq_logits_tea / params.T)

            # WER from output layer of CTC
            sequence_ctc = ctc_greedy_decode(
                p_ctc_tea, wav_lens, blank_id=params.blank_index
            )
            sequence_ctc = convert_index_to_lab(sequence_ctc, ind2lab)
            per_stats_ctc = edit_distance.wer_details_for_batch(
                ids, phns_decode, sequence_ctc, compute_alignments=False
            )

            wer_ctc_tea = []
            for item in per_stats_ctc:
                wer_ctc_tea.append(item["WER"])

            wer_ctc_tea = exclude_wer(wer_ctc_tea)
            wer_ctc_tea = np.expand_dims(wer_ctc_tea, axis=0)

            # WER from output layer of CE
            _, predictions = p_seq_tea.max(dim=-1)
            hyps = batch_filter_seq2seq_output(
                predictions, eos_id=params.eos_index
            )
            sequence_ce = convert_index_to_lab(hyps, ind2lab)
            per_stats_ce = edit_distance.wer_details_for_batch(
                ids, phns_decode, sequence_ce, compute_alignments=False
            )

            wer_tea = []
            for item in per_stats_ce:
                wer_tea.append(item["WER"])

            wer_tea = exclude_wer(wer_tea)
            wer_tea = np.expand_dims(wer_tea, axis=0)

        # Save the variables for this teacher into a dict
        tea_dict["p_ctc_tea"] = p_ctc_tea.cpu().numpy()
        tea_dict["p_seq_tea"] = p_seq_tea.cpu().numpy()
        tea_dict["wer_ctc_tea"] = wer_ctc_tea
        tea_dict["wer_tea"] = wer_tea
        tea_dict_list.append(tea_dict)

    return tea_dict_list
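
# compute_forward_tea() returns one dict per teacher, and the distillation
# loss in compute_objectives() below reads those arrays back via
# data_dict[str(batch_id)][tea_name[...]][tea_keys[...]][()]. A minimal sketch
# of how such a file could be written with h5py; the file name and the
# per-teacher group names are assumptions inferred from that read pattern,
# while the key names come from the tea_dict built above.
import h5py


def example_save_teacher_outputs(tea_dict_list, batch_id, path="tea_infer.hdf5"):
    tea_keys = ["p_ctc_tea", "p_seq_tea", "wer_ctc_tea", "wer_tea"]
    with h5py.File(path, "a") as f:
        batch_grp = f.create_group(str(batch_id))
        for num, tea_dict in enumerate(tea_dict_list):
            # Group name per teacher is a placeholder ("t0", "t1", ...)
            tea_grp = batch_grp.create_group("t{}".format(num))
            for key in tea_keys:
                tea_grp.create_dataset(key, data=tea_dict[key])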
def compute_objectives(
    self, predictions, targets, data_dict, batch_id, stage="train"
):
    if stage == "train":
        p_ctc, p_seq, wav_lens = predictions
    else:
        p_ctc, p_seq, wav_lens, hyps = predictions

    ids, phns, phn_lens = targets
    phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

    # Add phn_lens by one for eos token
    abs_length = torch.round(phn_lens * phns.shape[1])

    # Append eos token at the end of the label sequences
    phns_with_eos = append_eos_token(
        phns, length=abs_length, eos_index=params.eos_index
    )

    # Convert to speechbrain-style relative length
    rel_length = (abs_length + 1) / phns.shape[1]

    # Normal supervised training
    loss_ctc_nor = params.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
    loss_seq_nor = params.seq_cost(p_seq, phns_with_eos, length=rel_length)

    # Load teacher inference results
    item_tea_list = [None, None, None, None]
    for tea_num in range(params.num_tea):
        for i in range(4):
            item_tea = data_dict[str(batch_id)][tea_name[tea_num]][
                tea_keys[i]
            ][()]
            if tea_keys[i].startswith("wer"):
                item_tea = torch.tensor(item_tea)
            else:
                item_tea = torch.from_numpy(item_tea)
            item_tea = item_tea.to(params.device)
            item_tea = torch.unsqueeze(item_tea, 0)
            if tea_num == 0:
                item_tea_list[i] = item_tea
            else:
                item_tea_list[i] = torch.cat(
                    [item_tea_list[i], item_tea], 0
                )

    p_ctc_tea = item_tea_list[0]
    p_seq_tea = item_tea_list[1]
    wer_ctc_tea = item_tea_list[2]
    wer_tea = item_tea_list[3]

    # Strategy "average": average the losses of all teachers during distillation.
    # Strategy "best": choose the best teacher based on WER.
    # Strategy "weighted": assign weights to teachers based on WER.
    if params.strategy == "best":
        # tea_ce for kd
        wer_scores, indx = torch.min(wer_tea, dim=0)
        indx = list(indx.cpu().numpy())

        # Select the best teacher for each sentence
        tea_seq2seq_pout = None
        for stn_indx, tea_indx in enumerate(indx):
            s2s_one = p_seq_tea[tea_indx][stn_indx]
            s2s_one = torch.unsqueeze(s2s_one, 0)
            if stn_indx == 0:
                tea_seq2seq_pout = s2s_one
            else:
                tea_seq2seq_pout = torch.cat([tea_seq2seq_pout, s2s_one], 0)

    apply_softmax = torch.nn.Softmax(dim=0)

    if params.strategy == "best" or params.strategy == "weighted":
        # Mean WER for CTC
        tea_wer_ctc_mean = wer_ctc_tea.mean(1)
        tea_acc_main = 100 - tea_wer_ctc_mean

        # Normalise weights via softmax
        tea_acc_softmax = apply_softmax(tea_acc_main)

    if params.strategy == "weighted":
        # Mean WER for CE
        tea_wer_mean = wer_tea.mean(1)
        tea_acc_ce_main = 100 - tea_wer_mean

        # Normalise weights via softmax
        tea_acc_ce_softmax = apply_softmax(tea_acc_ce_main)

    # KD losses per teacher
    ctc_loss_list = None
    ce_loss_list = None
    for tea_num in range(params.num_tea):
        # CTC: distillation loss against this teacher
        p_ctc_tea_one = p_ctc_tea[tea_num]
        loss_ctc_one = params.ctc_cost_kd(p_ctc, p_ctc_tea_one, wav_lens)
        loss_ctc_one = torch.unsqueeze(loss_ctc_one, 0)
        if tea_num == 0:
            ctc_loss_list = loss_ctc_one
        else:
            ctc_loss_list = torch.cat([ctc_loss_list, loss_ctc_one])

        # CE: distillation loss against this teacher
        p_seq_tea_one = p_seq_tea[tea_num]
        loss_seq_one = params.seq_cost_kd(p_seq, p_seq_tea_one, rel_length)
        loss_seq_one = torch.unsqueeze(loss_seq_one, 0)
        if tea_num == 0:
            ce_loss_list = loss_seq_one
        else:
            ce_loss_list = torch.cat([ce_loss_list, loss_seq_one])

    # Combine the per-teacher KD losses according to the strategy
    if params.strategy == "average":
        # Average the losses from all teachers (CTC and CE loss)
        ctc_loss_kd = ctc_loss_list.mean(0)
        seq2seq_loss_kd = ce_loss_list.mean(0)
    else:
        # Assign weights to the different teachers (CTC loss)
        ctc_loss_kd = (tea_acc_softmax * ctc_loss_list).sum(0)
        if params.strategy == "best":
            # Only use the best teacher to compute the CE loss
            seq2seq_loss_kd = params.seq_cost_kd(
                p_seq, tea_seq2seq_pout, rel_length
            )
        if params.strategy == "weighted":
            # Assign weights to the different teachers (CE loss)
            seq2seq_loss_kd = (tea_acc_ce_softmax * ce_loss_list).sum(0)

    # Total loss: combine distillation with normal supervised training
    loss_ctc = (
        params.Temperature * params.Temperature * params.alpha * ctc_loss_kd
        + (1 - params.alpha) * loss_ctc_nor
    )
    loss_seq = (
        params.Temperature * params.Temperature * params.alpha * seq2seq_loss_kd
        + (1 - params.alpha) * loss_seq_nor
    )
    loss = params.ctc_weight * loss_ctc + (1 - params.ctc_weight) * loss_seq

    stats = {}
    if stage != "train":
        ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
        sequence = convert_index_to_lab(hyps, ind2lab)
        phns = undo_padding(phns, phn_lens)
        phns = convert_index_to_lab(phns, ind2lab)
        per_stats = edit_distance.wer_details_for_batch(
            ids, phns, sequence, compute_alignments=True
        )
        stats["PER"] = per_stats

    return loss, stats