def append(
    self,
    ids,
    predict,
    target,
    predict_len=None,
    target_len=None,
    ind2lab=None,
):
    """Add stats to the relevant containers.

    * See MetricStats.append()

    Arguments
    ---------
    ids : list
        List of ids corresponding to utterances.
    predict : torch.tensor
        A predicted output, for comparison with the target output.
    target : torch.tensor
        The correct reference output, for comparison with the prediction.
    predict_len : torch.tensor
        The predictions' relative lengths, used to undo padding if
        there is padding present in the predictions.
    target_len : torch.tensor
        The target outputs' relative lengths, used to undo padding if
        there is padding present in the targets.
    ind2lab : callable
        Callable that maps from indices to labels, operating on batches,
        for writing alignments.
    """
    self.ids.extend(ids)

    if predict_len is not None:
        predict = undo_padding(predict, predict_len)
    if target_len is not None:
        target = undo_padding(target, target_len)

    if ind2lab is not None:
        predict = ind2lab(predict)
        target = ind2lab(target)

    if self.merge_tokens:
        predict = merge_char(predict)
        target = merge_char(target)

    if self.split_tokens:
        predict = split_word(predict)
        target = split_word(target)

    scores = edit_distance.wer_details_for_batch(ids, target, predict, True)
    self.scores.extend(scores)
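# ---------------------------------------------------------------------
# Illustrative sketch (not the library implementation) of how a relative
# length undoes padding: each row of a (batch, time) tensor is cut to
# round(rel_len * max_len) items and returned as a plain Python list,
# which is the shape of data the metric code above works with.
import torch

def undo_padding_sketch(batch, rel_lens):
    """Truncate each padded row to its true length (toy version)."""
    max_len = batch.shape[1]
    return [
        seq[: int(torch.round(rel_len * max_len))].tolist()
        for seq, rel_len in zip(batch, rel_lens)
    ]

# Two sequences padded to length 4; relative lengths 0.75 and 0.5.
padded = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
assert undo_padding_sketch(padded, torch.tensor([0.75, 0.5])) == [
    [1, 2, 3],
    [4, 5],
]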
def compute_objectives(self, predictions, batch, stage):
    """Computes the loss (CTC) given predictions and targets."""
    p_ctc, wav_lens = predictions
    ids = batch.id
    tokens_eos, tokens_eos_lens = batch.tokens_eos
    tokens, tokens_lens = batch.tokens

    loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

    if stage != sb.Stage.TRAIN:
        # Decode token terms to words
        sequence = sb.decoders.ctc_greedy_decode(
            p_ctc, wav_lens, blank_id=self.hparams.blank_index
        )
        predicted_words = self.tokenizer(sequence, task="decode_from_list")

        # Convert indices to words
        target_words = undo_padding(tokens, tokens_lens)
        target_words = self.tokenizer(target_words, task="decode_from_list")

        self.wer_metric.append(ids, predicted_words, target_words)
        self.cer_metric.append(ids, predicted_words, target_words)

    return loss
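# ---------------------------------------------------------------------
# Conceptual sketch of what ctc_greedy_decode does for one utterance
# (hand-rolled, not the SpeechBrain implementation): take the argmax
# token per frame, collapse consecutive repeats, then drop blanks.
import torch

def greedy_ctc_sketch(log_probs, blank_id):
    """log_probs: (time, vocab) tensor for a single utterance."""
    out, prev = [], None
    for tok in log_probs.argmax(dim=-1).tolist():
        if tok != prev and tok != blank_id:
            out.append(tok)
        prev = tok
    return out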
def compute_objectives(self, predictions, targets, stage="train"):
    if stage == "train":
        p_ctc, p_seq, wav_lens = predictions
    else:
        p_ctc, p_seq, wav_lens, hyps = predictions

    ids, phns, phn_lens = targets
    phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

    # Convert relative lengths to absolute lengths (in tokens)
    abs_length = torch.round(phn_lens * phns.shape[1])

    # Append eos token at the end of the label sequences
    phns_with_eos = append_eos_token(
        phns, length=abs_length, eos_index=params.eos_index
    )

    # Convert back to a speechbrain-style relative length; the
    # denominator must be the widened (with-eos) tensor so that a
    # full-length sequence keeps a ratio of exactly 1.0
    rel_length = (abs_length + 1) / phns_with_eos.shape[1]

    loss_ctc = params.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
    loss_seq = params.seq_cost(p_seq, phns_with_eos, length=rel_length)
    loss = params.ctc_weight * loss_ctc + (1 - params.ctc_weight) * loss_seq

    stats = {}
    if stage != "train":
        ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
        sequence = convert_index_to_lab(hyps, ind2lab)
        phns = undo_padding(phns, phn_lens)
        phns = convert_index_to_lab(phns, ind2lab)
        per_stats = edit_distance.wer_details_for_batch(
            ids, phns, sequence, compute_alignments=True
        )
        stats["PER"] = per_stats

    return loss, stats
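# ---------------------------------------------------------------------
# Worked example of the eos/relative-length arithmetic above (a sketch
# that hand-rolls the widening append_eos_token performs). With T = 3
# and relative lengths [1.0, 2/3], abs_length is [3, 2]; appending eos
# widens the tensor to T + 1 = 4 columns, so the new relative length is
# (abs_length + 1) / 4, keeping every ratio <= 1.0.
import torch

eos_index = 42
phns = torch.tensor([[7, 8, 9], [5, 6, 0]])         # padded to T = 3
phn_lens = torch.tensor([1.0, 2 / 3])
abs_length = torch.round(phn_lens * phns.shape[1])  # tensor([3., 2.])

pad_col = torch.zeros(phns.shape[0], 1, dtype=phns.dtype)
phns_with_eos = torch.cat([phns, pad_col], dim=1)
for row, length in enumerate(abs_length.long()):
    phns_with_eos[row, length] = eos_index
# phns_with_eos -> [[7, 8, 9, 42], [5, 6, 42, 0]]

rel_length = (abs_length + 1) / phns_with_eos.shape[1]  # tensor([1.00, 0.75])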
def compute_objectives(self, predictions, targets, stage):
    """Computes the loss (NLL) given predictions and targets."""
    if (
        stage == sb.Stage.TRAIN
        and self.batch_count % show_results_every != 0
    ):
        p_seq, decoded_transcript_lens = predictions
    else:
        p_seq, decoded_transcript_lens, predicted_tokens = predictions

    ids, target_semantics, target_semantics_lens = targets
    target_tokens, target_token_lens = self.hparams.tokenizer(
        target_semantics,
        target_semantics_lens,
        self.hparams.ind2lab,
        task="encode",
    )
    target_tokens = target_tokens.to(self.device)
    target_token_lens = target_token_lens.to(self.device)

    # Convert relative lengths to absolute lengths (in tokens)
    abs_length = torch.round(target_token_lens * target_tokens.shape[1])

    # Append eos token at the end of the label sequences
    target_tokens_with_eos = sb.dataio.dataio.append_eos_token(
        target_tokens, length=abs_length, eos_index=self.hparams.eos_index
    )

    # Convert back to a speechbrain-style relative length
    rel_length = (abs_length + 1) / target_tokens_with_eos.shape[1]
    loss_seq = self.hparams.seq_cost(
        p_seq, target_tokens_with_eos, length=rel_length
    )

    # (No ctc loss)
    loss = loss_seq

    if (
        stage != sb.Stage.TRAIN
        or self.batch_count % show_results_every == 0
    ):
        # Decode token terms to words
        predicted_semantics = self.hparams.tokenizer(
            predicted_tokens, task="decode_from_list"
        )

        # Convert indices to words
        target_semantics = undo_padding(
            target_semantics, target_semantics_lens
        )
        target_semantics = sb.dataio.dataio.convert_index_to_lab(
            target_semantics, self.hparams.ind2lab
        )

        for i in range(len(target_semantics)):
            print(" ".join(predicted_semantics[i]).replace("|", ","))
            print(" ".join(target_semantics[i]).replace("|", ","))
            print("")

        if stage != sb.Stage.TRAIN:
            self.wer_metric.append(ids, predicted_semantics, target_semantics)
            self.cer_metric.append(ids, predicted_semantics, target_semantics)

    return loss
def compute_objectives(self, predictions, batch, stage):
    """Computes the loss (CTC+NLL) given predictions and targets."""
    current_epoch = self.hparams.epoch_counter.current
    if stage == sb.Stage.TRAIN:
        if current_epoch <= self.hparams.number_of_ctc_epochs:
            p_ctc, p_seq, wav_lens = predictions
        else:
            p_seq, wav_lens = predictions
    else:
        p_seq, wav_lens, predicted_tokens = predictions

    ids = batch.id
    tokens_eos, tokens_eos_lens = batch.tokens_eos
    tokens, tokens_lens = batch.tokens

    loss_seq = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)

    # Add ctc loss if necessary
    if (
        stage == sb.Stage.TRAIN
        and current_epoch <= self.hparams.number_of_ctc_epochs
    ):
        loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
        loss = self.hparams.ctc_weight * loss_ctc
        loss += (1 - self.hparams.ctc_weight) * loss_seq
    else:
        loss = loss_seq

    if stage != sb.Stage.TRAIN:
        # Decode token terms to words
        predicted_words = self.tokenizer(
            predicted_tokens, task="decode_from_list"
        )

        # Convert indices to words
        target_words = undo_padding(tokens, tokens_lens)
        target_words = self.tokenizer(target_words, task="decode_from_list")

        self.wer_metric.append(ids, predicted_words, target_words)
        self.cer_metric.append(ids, predicted_words, target_words)

    return loss
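# ---------------------------------------------------------------------
# Minimal sketch of the loss schedule above (names are illustrative):
# while the epoch counter is within number_of_ctc_epochs, the loss is a
# convex combination of the CTC and NLL terms; afterwards only the
# seq2seq term is kept.
def combine_losses(loss_ctc, loss_seq, ctc_weight, epoch, ctc_epochs):
    if loss_ctc is not None and epoch <= ctc_epochs:
        return ctc_weight * loss_ctc + (1 - ctc_weight) * loss_seq
    return loss_seq

# e.g. ctc_weight = 0.3 gives loss = 0.3 * loss_ctc + 0.7 * loss_seq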
def compute_objectives(self, predictions, batch, stage):
    """Computes the loss (CTC+NLL) given predictions and targets."""
    current_epoch = self.hparams.epoch_counter.current
    wav_lens = predictions["wav_lens"]
    tokens, tokens_lens = batch.tokens
    tokens_eos, tokens_eos_lens = batch.tokens_eos

    # Duplicate the labels when env_corrupt has doubled the inputs
    if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
        tokens = torch.cat([tokens, tokens], dim=0)
        tokens_lens = torch.cat([tokens_lens, tokens_lens])
        tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
        tokens_eos_lens = torch.cat([tokens_eos_lens, tokens_eos_lens])

    loss = self.hparams.seq_cost(
        predictions["p_seq"], tokens_eos, length=tokens_eos_lens
    )

    # Add ctc loss if necessary
    if (
        stage == sb.Stage.TRAIN
        and current_epoch <= self.hparams.number_of_ctc_epochs
    ):
        loss_ctc = self.hparams.ctc_cost(
            predictions["p_ctc"], tokens, wav_lens, tokens_lens
        )
        loss *= 1 - self.hparams.ctc_weight
        loss += self.hparams.ctc_weight * loss_ctc

    if stage != sb.Stage.TRAIN:
        # Decode token terms to words
        predicted_words = self.tokenizer(
            predictions["p_tokens"], task="decode_from_list"
        )

        # Convert indices to words
        target_words = undo_padding(tokens, tokens_lens)
        target_words = self.tokenizer(target_words, task="decode_from_list")

        self.wer_metric.append(batch.id, predicted_words, target_words)
        self.cer_metric.append(batch.id, predicted_words, target_words)

    return loss
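# ---------------------------------------------------------------------
# Sketch of why the labels are duplicated above: when env_corrupt
# appends a corrupted copy of every waveform during training, the batch
# dimension doubles, so the targets must be doubled in the same order.
import torch

tokens = torch.tensor([[1, 2], [3, 4]])      # (batch=2, len)
tokens = torch.cat([tokens, tokens], dim=0)  # (batch=4, len): clean + corrupt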
def compute_objectives(self, predictions, batch, stage):
    """Computes the loss (CTC+NLL) given predictions and targets."""
    p_ctc, p_seq, wav_lens, predicted_tokens = predictions

    ids = batch.id
    tokens_eos, tokens_eos_lens = batch.tokens_eos
    tokens, tokens_lens = batch.tokens

    loss_seq = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)
    loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
    loss = (
        self.hparams.ctc_weight * loss_ctc
        + (1 - self.hparams.ctc_weight) * loss_seq
    )

    if stage != sb.Stage.TRAIN:
        current_epoch = self.hparams.epoch_counter.current
        valid_search_interval = self.hparams.valid_search_interval
        if (
            current_epoch % valid_search_interval == 0
            or stage == sb.Stage.TEST
        ):
            # Decode token terms to words
            predicted_words = self.tokenizer(
                predicted_tokens, task="decode_from_list"
            )

            # Convert indices to words
            target_words = undo_padding(tokens, tokens_lens)
            target_words = self.tokenizer(
                target_words, task="decode_from_list"
            )

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        # Compute the accuracy of the one-step-forward prediction
        self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)

    return loss
def compute_objectives(self, predictions, targets, stage):
    """Computes the loss (NLL) given predictions and targets."""
    if (
        stage == sb.Stage.TRAIN
        and self.batch_count % show_results_every != 0
    ):
        p_seq, decoded_transcript_lens = predictions
    else:
        p_seq, decoded_transcript_lens, predicted_tokens = predictions

    ids, target_semantics, target_semantics_lens = targets
    target_tokens, target_token_lens = self.hparams.tokenizer(
        target_semantics,
        target_semantics_lens,
        self.hparams.ind2lab,
        task="encode",
    )
    target_tokens = target_tokens.to(self.device)
    target_token_lens = target_token_lens.to(self.device)

    # Convert relative lengths to absolute lengths (in tokens)
    abs_length = torch.round(target_token_lens * target_tokens.shape[1])

    # Append eos token at the end of the label sequences
    target_tokens_with_eos = sb.dataio.dataio.append_eos_token(
        target_tokens, length=abs_length, eos_index=self.hparams.eos_index
    )

    # Convert back to a speechbrain-style relative length
    rel_length = (abs_length + 1) / target_tokens_with_eos.shape[1]
    loss_seq = self.hparams.seq_cost(
        p_seq, target_tokens_with_eos, length=rel_length
    )

    # (No ctc loss)
    loss = loss_seq

    if (
        stage != sb.Stage.TRAIN
        or self.batch_count % show_results_every == 0
    ):
        # Decode token terms to words
        predicted_semantics = self.hparams.tokenizer(
            predicted_tokens, task="decode_from_list"
        )

        # Convert indices to words
        target_semantics = undo_padding(
            target_semantics, target_semantics_lens
        )
        target_semantics = sb.dataio.dataio.convert_index_to_lab(
            target_semantics, self.hparams.ind2lab
        )

        for i in range(len(target_semantics)):
            print(" ".join(predicted_semantics[i]).replace("|", ","))
            print(" ".join(target_semantics[i]).replace("|", ","))
            print("")

        if stage != sb.Stage.TRAIN:
            self.wer_metric.append(ids, predicted_semantics, target_semantics)
            self.cer_metric.append(ids, predicted_semantics, target_semantics)

        if stage == sb.Stage.TEST:
            # Write the predicted dictionaries to "predictions.jsonl"
            with jsonlines.open(
                hparams["output_folder"] + "/predictions.jsonl", mode="a"
            ) as writer:
                for i in range(len(predicted_semantics)):
                    try:
                        prediction = ast.literal_eval(
                            " ".join(predicted_semantics[i]).replace("|", ",")
                        )
                    except SyntaxError:
                        # Fallback when the output is not a valid dictionary
                        prediction = {
                            "scenario": "none",
                            "action": "none",
                            "entities": [],
                        }
                    prediction["file"] = id_to_file[ids[i]]
                    writer.write(prediction)

    return loss
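# ---------------------------------------------------------------------
# Sketch of the recovery step above (toy string, not recipe output):
# the decoded semantics are re-joined with "," and parsed back into a
# dictionary with ast.literal_eval, hence the SyntaxError fallback for
# malformed model output.
import ast

text = "{'scenario': 'alarm', 'action': 'set_alarm', 'entities': []}"
parsed = ast.literal_eval(text)
assert parsed["scenario"] == "alarm"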
def compute_objectives(self, predictions, batch, stage):
    """Compute possibly several loss terms: enhance, mimic, ctc, seq"""
    # Do not augment targets
    clean_wavs, clean_feats, lens = self.prepare_feats(
        batch.clean_sig, augment=False
    )
    loss = 0

    # Compute enhancement loss
    if self.hparams.enhance_weight > 0:
        enhance_loss = self.hparams.enhance_loss(
            predictions["feats"], clean_feats, lens
        )
        loss += self.hparams.enhance_weight * enhance_loss

        if stage != sb.Stage.TRAIN:
            self.enh_metrics.append(
                batch.id, predictions["feats"], clean_feats, lens
            )
            self.stoi_metrics.append(
                ids=batch.id,
                predict=predictions["wavs"],
                target=clean_wavs,
                lengths=lens,
            )
            self.pesq_metrics.append(
                ids=batch.id,
                predict=predictions["wavs"],
                target=clean_wavs,
                lengths=lens,
            )

    # Compute mimic loss
    if self.hparams.mimic_weight > 0:
        clean_embed = self.modules.src_embedding.CNN(clean_feats)
        enh_embed = self.modules.src_embedding.CNN(predictions["feats"])
        mimic_loss = self.hparams.mimic_loss(enh_embed, clean_embed, lens)
        loss += self.hparams.mimic_weight * mimic_loss

        if stage != sb.Stage.TRAIN:
            self.mimic_metrics.append(batch.id, enh_embed, clean_embed, lens)

    # Compute hard ASR loss
    if self.hparams.ctc_weight > 0 and (
        not hasattr(self.hparams, "ctc_epochs")
        or self.hparams.epoch_counter.current < self.hparams.ctc_epochs
    ):
        tokens, token_lens = self.prepare_targets(batch.tokens)
        ctc_loss = self.hparams.ctc_loss(
            predictions["ctc_pout"], tokens, lens, token_lens
        )
        loss += self.hparams.ctc_weight * ctc_loss

        if stage != sb.Stage.TRAIN and self.hparams.seq_weight == 0:
            predict = sb.decoders.ctc_greedy_decode(
                predictions["ctc_pout"], lens, blank_id=-1
            )
            self.err_rate_metrics.append(
                ids=batch.id,
                predict=predict,
                target=tokens,
                target_len=token_lens,
                ind2lab=self.hparams.ind2lab,
            )

    # Compute nll loss for seq2seq model
    if self.hparams.seq_weight > 0:
        tokens, token_lens = self.prepare_targets(batch.tokens_eos)
        seq_loss = self.hparams.seq_loss(
            predictions["seq_pout"], tokens, token_lens
        )
        loss += self.hparams.seq_weight * seq_loss

        if stage != sb.Stage.TRAIN and self.hparams.target_type == "wrd":
            pred_words = self.tokenizer(
                predictions["hyps"], task="decode_from_list"
            )
            target_words = self.tokenizer(
                undo_padding(*batch.tokens), task="decode_from_list"
            )
            self.err_rate_metrics.append(batch.id, pred_words, target_words)
        elif stage != sb.Stage.TRAIN:
            self.err_rate_metrics.append(
                ids=batch.id,
                predict=predictions["hyps"],
                target=tokens,
                target_len=token_lens,
                ind2lab=self.tokenizer.decode_ndim,
            )

    return loss
def compute_objectives(self, predictions, batch, stage):
    """Computes the loss (Transducer+(CTC+NLL)) given predictions and targets."""
    ids = batch.id
    current_epoch = self.hparams.epoch_counter.current
    tokens, token_lens = batch.tokens
    tokens_eos, token_eos_lens = batch.tokens_eos

    if stage == sb.Stage.TRAIN:
        if len(predictions) == 4:
            # Both auxiliary heads (CTC and CE) are active
            p_ctc, p_ce, p_transducer, wav_lens = predictions
            CTC_loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, token_lens)
            CE_loss = self.hparams.ce_cost(p_ce, tokens_eos, length=token_eos_lens)
            loss_transducer = self.hparams.transducer_cost(
                p_transducer, tokens, wav_lens, token_lens
            )
            loss = (
                self.hparams.ctc_weight * CTC_loss
                + self.hparams.ce_weight * CE_loss
                + (1 - (self.hparams.ctc_weight + self.hparams.ce_weight))
                * loss_transducer
            )
        elif len(predictions) == 3:
            # One of the two heads (CTC or CE) is still computed
            # CTC alive
            if current_epoch <= self.hparams.number_of_ctc_epochs:
                p_ctc, p_transducer, wav_lens = predictions
                CTC_loss = self.hparams.ctc_cost(
                    p_ctc, tokens, wav_lens, token_lens
                )
                loss_transducer = self.hparams.transducer_cost(
                    p_transducer, tokens, wav_lens, token_lens
                )
                loss = (
                    self.hparams.ctc_weight * CTC_loss
                    + (1 - self.hparams.ctc_weight) * loss_transducer
                )
            # CE for decoder alive
            else:
                p_ce, p_transducer, wav_lens = predictions
                CE_loss = self.hparams.ce_cost(
                    p_ce, tokens_eos, length=token_eos_lens
                )
                loss_transducer = self.hparams.transducer_cost(
                    p_transducer, tokens, wav_lens, token_lens
                )
                # Complement of the CE weight, mirroring the CTC branch
                loss = (
                    self.hparams.ce_weight * CE_loss
                    + (1 - self.hparams.ce_weight) * loss_transducer
                )
        else:
            p_transducer, wav_lens = predictions
            loss = self.hparams.transducer_cost(
                p_transducer, tokens, wav_lens, token_lens
            )
    else:
        p_transducer, wav_lens, predicted_tokens = predictions
        loss = self.hparams.transducer_cost(
            p_transducer, tokens, wav_lens, token_lens
        )

    if stage != sb.Stage.TRAIN:
        # Decode token terms to words
        predicted_words = self.tokenizer(
            predicted_tokens, task="decode_from_list"
        )

        # Convert indices to words
        target_words = undo_padding(tokens, token_lens)
        target_words = self.tokenizer(target_words, task="decode_from_list")

        self.wer_metric.append(ids, predicted_words, target_words)
        self.cer_metric.append(ids, predicted_words, target_words)

    return loss
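# ---------------------------------------------------------------------
# Sketch of the three-head weighting used in the first branch above
# (function name is illustrative): the transducer term receives
# whatever weight the auxiliary CTC and CE heads leave over, so the
# three coefficients sum to 1.
def transducer_multitask_loss(ctc, ce, transducer, ctc_w, ce_w):
    assert 0 <= ctc_w + ce_w <= 1, "auxiliary weights must leave room"
    return ctc_w * ctc + ce_w * ce + (1 - (ctc_w + ce_w)) * transducer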
def compute_objectives(self, predictions, batch, stage):
    """Compute possibly several loss terms: enhance, mimic, ctc, seq"""
    # Do not augment targets
    clean_wavs, clean_feats, lens = self.prepare_feats(
        batch.clean_sig, augment=False
    )
    loss = 0

    # Compute enhancement loss
    if self.hparams.enhance_weight > 0:
        enhance_loss = self.hparams.enhance_loss(
            predictions["feats"], clean_feats, lens
        )
        loss += self.hparams.enhance_weight * enhance_loss

        if stage != sb.Stage.TRAIN:
            self.enh_metrics.append(
                batch.id, predictions["feats"], clean_feats, lens
            )
            self.stoi_metrics.append(
                ids=batch.id,
                predict=predictions["wavs"],
                target=clean_wavs,
                lengths=lens,
            )
            self.pesq_metrics.append(
                ids=batch.id,
                predict=predictions["wavs"],
                target=clean_wavs,
                lengths=lens,
            )

            # Write enhanced waveforms to disk when an output dir is given
            if hasattr(self.hparams, "enh_dir"):
                abs_lens = lens * predictions["wavs"].size(1)
                for i, uid in enumerate(batch.id):
                    length = int(abs_lens[i])
                    wav = predictions["wavs"][i, :length].unsqueeze(0)
                    path = os.path.join(self.hparams.enh_dir, uid + ".wav")
                    torchaudio.save(path, wav.cpu(), sample_rate=16000)

    # Compute mimic loss
    if self.hparams.mimic_weight > 0:
        clean_embed = self.modules.src_embedding.CNN(clean_feats)
        enh_embed = self.modules.src_embedding.CNN(predictions["feats"])
        mimic_loss = self.hparams.mimic_loss(enh_embed, clean_embed, lens)
        loss += self.hparams.mimic_weight * mimic_loss

        if stage != sb.Stage.TRAIN:
            self.mimic_metrics.append(batch.id, enh_embed, clean_embed, lens)

    # Compute hard ASR loss
    if self.hparams.ctc_weight > 0 and (
        not hasattr(self.hparams, "ctc_epochs")
        or self.hparams.epoch_counter.current < self.hparams.ctc_epochs
    ):
        tokens, token_lens = self.prepare_targets(batch.tokens)
        ctc_loss = self.hparams.ctc_loss(
            predictions["ctc_pout"], tokens, lens, token_lens
        )
        loss += self.hparams.ctc_weight * ctc_loss

        if stage != sb.Stage.TRAIN and self.hparams.seq_weight == 0:
            predict = sb.decoders.ctc_greedy_decode(
                predictions["ctc_pout"], lens, blank_id=-1
            )
            self.err_rate_metrics.append(
                ids=batch.id,
                predict=predict,
                target=tokens,
                target_len=token_lens,
                ind2lab=self.hparams.ind2lab,
            )

    # Compute nll loss for seq2seq model
    if self.hparams.seq_weight > 0:
        tokens, token_lens = self.prepare_targets(batch.tokens_eos)
        seq_loss = self.hparams.seq_loss(
            predictions["seq_pout"], tokens, token_lens
        )
        loss += self.hparams.seq_weight * seq_loss

        if stage != sb.Stage.TRAIN:
            if hasattr(self.hparams, "asr_pretrained"):
                pred_words = [
                    self.token_encoder.decode_ids(token_seq)
                    for token_seq in predictions["hyps"]
                ]
                target_words = [
                    self.token_encoder.decode_ids(token_seq)
                    for token_seq in undo_padding(*batch.tokens)
                ]
                self.err_rate_metrics.append(batch.id, pred_words, target_words)
            else:
                self.err_rate_metrics.append(
                    ids=batch.id,
                    predict=predictions["hyps"],
                    target=tokens,
                    target_len=token_lens,
                    ind2lab=self.token_encoder.decode_ndim,
                )

    return loss
def compute_forward_tea(self, x, y, init_params=False):
    ids, wavs, wav_lens = x
    ids, phns, phn_lens = y

    wavs, wav_lens = wavs.to(params.device), wav_lens.to(params.device)
    phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

    if hasattr(params, "augmentation"):
        wavs = params.augmentation(wavs, wav_lens, init_params)
    feats = params.compute_features(wavs, init_params)
    feats = params.normalize(feats, wav_lens)

    apply_softmax = torch.nn.Softmax(dim=-1)

    ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
    phns_decode = undo_padding(phns, phn_lens)
    phns_decode = convert_index_to_lab(phns_decode, ind2lab)

    # Run inference with each teacher model
    tea_dict_list = []
    for num in range(params.num_tea):
        tea_dict = {}
        self.tea_modules_list[num].eval()
        with torch.no_grad():
            x_tea = tea_enc_list[num](feats, init_params=init_params)
            ctc_logits_tea = tea_ctc_lin_list[num](x_tea, init_params)

            # Output layer for ctc log-probabilities (softened by T)
            p_ctc_tea = params.log_softmax(ctc_logits_tea / params.T)

            # Prepend bos token at the beginning
            y_in_tea = prepend_bos_token(phns, bos_index=params.bos_index)
            e_in_tea = tea_emb_list[num](y_in_tea, init_params=init_params)
            h_tea, _ = tea_dec_list[num](e_in_tea, x_tea, wav_lens, init_params)

            # Output layer for seq2seq probabilities (softened by T)
            seq_logits_tea = tea_seq_lin_list[num](h_tea, init_params)
            p_seq_tea = apply_softmax(seq_logits_tea / params.T)

            # WER from the CTC output layer
            sequence_ctc = ctc_greedy_decode(
                p_ctc_tea, wav_lens, blank_id=params.blank_index
            )
            sequence_ctc = convert_index_to_lab(sequence_ctc, ind2lab)
            per_stats_ctc = edit_distance.wer_details_for_batch(
                ids, phns_decode, sequence_ctc, compute_alignments=False
            )
            wer_ctc_tea = [item["WER"] for item in per_stats_ctc]
            wer_ctc_tea = exclude_wer(wer_ctc_tea)
            wer_ctc_tea = np.expand_dims(wer_ctc_tea, axis=0)

            # WER from the CE (seq2seq) output layer
            _, predictions = p_seq_tea.max(dim=-1)
            hyps = batch_filter_seq2seq_output(
                predictions, eos_id=params.eos_index
            )
            sequence_ce = convert_index_to_lab(hyps, ind2lab)
            per_stats_ce = edit_distance.wer_details_for_batch(
                ids, phns_decode, sequence_ce, compute_alignments=False
            )
            wer_tea = [item["WER"] for item in per_stats_ce]
            wer_tea = exclude_wer(wer_tea)
            wer_tea = np.expand_dims(wer_tea, axis=0)

        # Save the teacher outputs into a dict
        tea_dict["p_ctc_tea"] = p_ctc_tea.cpu().numpy()
        tea_dict["p_seq_tea"] = p_seq_tea.cpu().numpy()
        tea_dict["wer_ctc_tea"] = wer_ctc_tea
        tea_dict["wer_tea"] = wer_tea
        tea_dict_list.append(tea_dict)

    return tea_dict_list
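# ---------------------------------------------------------------------
# Sketch of the temperature softening applied to the teacher logits
# above: dividing by T > 1 flattens the distribution before the
# (log-)softmax, the standard knowledge-distillation trick.
import torch

logits = torch.tensor([[4.0, 1.0, -2.0]])
print(torch.softmax(logits, dim=-1))        # sharp: ~[0.95, 0.05, 0.00]
print(torch.softmax(logits / 2.0, dim=-1))  # softer: ~[0.79, 0.18, 0.04]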
def compute_objectives(
    self, predictions, targets, data_dict, batch_id, stage="train"
):
    if stage == "train":
        p_ctc, p_seq, wav_lens = predictions
    else:
        p_ctc, p_seq, wav_lens, hyps = predictions

    ids, phns, phn_lens = targets
    phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

    # Convert relative lengths to absolute lengths (in tokens)
    abs_length = torch.round(phn_lens * phns.shape[1])

    # Append eos token at the end of the label sequences
    phns_with_eos = append_eos_token(
        phns, length=abs_length, eos_index=params.eos_index
    )

    # Convert back to a speechbrain-style relative length; the
    # denominator must be the widened (with-eos) tensor so that a
    # full-length sequence keeps a ratio of exactly 1.0
    rel_length = (abs_length + 1) / phns_with_eos.shape[1]

    # Normal supervised training losses
    loss_ctc_nor = params.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
    loss_seq_nor = params.seq_cost(p_seq, phns_with_eos, length=rel_length)

    # Load teacher inference results
    item_tea_list = [None, None, None, None]
    for tea_num in range(params.num_tea):
        for i in range(4):
            item_tea = data_dict[str(batch_id)][tea_name[tea_num]][
                tea_keys[i]
            ][()]
            if tea_keys[i].startswith("wer"):
                item_tea = torch.tensor(item_tea)
            else:
                item_tea = torch.from_numpy(item_tea)
            item_tea = item_tea.to(params.device)
            item_tea = torch.unsqueeze(item_tea, 0)
            if tea_num == 0:
                item_tea_list[i] = item_tea
            else:
                item_tea_list[i] = torch.cat([item_tea_list[i], item_tea], 0)
    p_ctc_tea = item_tea_list[0]
    p_seq_tea = item_tea_list[1]
    wer_ctc_tea = item_tea_list[2]
    wer_tea = item_tea_list[3]

    # Strategy "average": average the losses of all teachers.
    # Strategy "best": choose the best teacher based on WER.
    # Strategy "weighted": assign weights to teachers based on WER.
    if params.strategy == "best":
        # tea_ce for kd
        wer_scores, indx = torch.min(wer_tea, dim=0)
        indx = list(indx.cpu().numpy())

        # Select the best teacher for each sentence
        tea_seq2seq_pout = None
        for stn_indx, tea_indx in enumerate(indx):
            s2s_one = p_seq_tea[tea_indx][stn_indx]
            s2s_one = torch.unsqueeze(s2s_one, 0)
            if stn_indx == 0:
                tea_seq2seq_pout = s2s_one
            else:
                tea_seq2seq_pout = torch.cat([tea_seq2seq_pout, s2s_one], 0)

    apply_softmax = torch.nn.Softmax(dim=0)

    if params.strategy == "best" or params.strategy == "weighted":
        # Mean wer for ctc
        tea_wer_ctc_mean = wer_ctc_tea.mean(1)
        tea_acc_main = 100 - tea_wer_ctc_mean

        # Normalise the weights via the softmax function
        tea_acc_softmax = apply_softmax(tea_acc_main)

    if params.strategy == "weighted":
        # Mean wer for ce
        tea_wer_mean = wer_tea.mean(1)
        tea_acc_ce_main = 100 - tea_wer_mean

        # Normalise the weights via the softmax function
        tea_acc_ce_softmax = apply_softmax(tea_acc_ce_main)

    # Distillation losses, one per teacher
    ctc_loss_list = None
    ce_loss_list = None
    for tea_num in range(params.num_tea):
        # Calculate the CTC distillation loss of one teacher
        p_ctc_tea_one = p_ctc_tea[tea_num]
        loss_ctc_one = params.ctc_cost_kd(p_ctc, p_ctc_tea_one, wav_lens)
        loss_ctc_one = torch.unsqueeze(loss_ctc_one, 0)
        if tea_num == 0:
            ctc_loss_list = loss_ctc_one
        else:
            ctc_loss_list = torch.cat([ctc_loss_list, loss_ctc_one])

        # Calculate the CE distillation loss of one teacher
        p_seq_tea_one = p_seq_tea[tea_num]
        loss_seq_one = params.seq_cost_kd(p_seq, p_seq_tea_one, rel_length)
        loss_seq_one = torch.unsqueeze(loss_seq_one, 0)
        if tea_num == 0:
            ce_loss_list = loss_seq_one
        else:
            ce_loss_list = torch.cat([ce_loss_list, loss_seq_one])

    # Combine the per-teacher distillation losses
    if params.strategy == "average":
        # Average the losses from all teachers (CTC and CE)
        ctc_loss_kd = ctc_loss_list.mean(0)
        seq2seq_loss_kd = ce_loss_list.mean(0)
    else:
        # Assign weights to the different teachers (CTC loss)
        ctc_loss_kd = (tea_acc_softmax * ctc_loss_list).sum(0)
        if params.strategy == "best":
            # Only use the best teacher to compute the CE loss
            seq2seq_loss_kd = params.seq_cost_kd(
                p_seq, tea_seq2seq_pout, rel_length
            )
        if params.strategy == "weighted":
            # Assign weights to the different teachers (CE loss)
            seq2seq_loss_kd = (tea_acc_ce_softmax * ce_loss_list).sum(0)

    # Total loss: combine the distillation and supervised terms
    loss_ctc = (
        params.Temperature ** 2 * params.alpha * ctc_loss_kd
        + (1 - params.alpha) * loss_ctc_nor
    )
    loss_seq = (
        params.Temperature ** 2 * params.alpha * seq2seq_loss_kd
        + (1 - params.alpha) * loss_seq_nor
    )
    loss = params.ctc_weight * loss_ctc + (1 - params.ctc_weight) * loss_seq

    stats = {}
    if stage != "train":
        ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
        sequence = convert_index_to_lab(hyps, ind2lab)
        phns = undo_padding(phns, phn_lens)
        phns = convert_index_to_lab(phns, ind2lab)
        per_stats = edit_distance.wer_details_for_batch(
            ids, phns, sequence, compute_alignments=True
        )
        stats["PER"] = per_stats

    return loss, stats