def compute_objectives(self, predictions, targets, stage="train"):
        if stage == "train":
            p_ctc, p_seq, wav_lens = predictions
        else:
            p_ctc, p_seq, wav_lens, hyps = predictions

        ids, phns, phn_lens = targets
        phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

        # Convert relative phoneme lengths to absolute lengths
        abs_length = torch.round(phn_lens * phns.shape[1])

        # Append eos token at the end of the label sequences
        phns_with_eos = append_eos_token(
            phns, length=abs_length, eos_index=params.eos_index
        )

        # Convert back to a SpeechBrain-style relative length (including eos)
        rel_length = (abs_length + 1) / phns_with_eos.shape[1]

        loss_ctc = params.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
        loss_seq = params.seq_cost(p_seq, phns_with_eos, length=rel_length)
        loss = params.ctc_weight * loss_ctc + (1 - params.ctc_weight) * loss_seq

        stats = {}
        if stage != "train":
            ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
            sequence = convert_index_to_lab(hyps, ind2lab)
            phns = undo_padding(phns, phn_lens)
            phns = convert_index_to_lab(phns, ind2lab)
            per_stats = edit_distance.wer_details_for_batch(
                ids, phns, sequence, compute_alignments=True
            )
            stats["PER"] = per_stats
        return loss, stats
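
For reference, the PER computation above boils down to a single call; below is a minimal, standalone sketch assuming SpeechBrain's edit_distance utilities are importable (the ids and token lists are made-up illustrations):

from speechbrain.utils import edit_distance

# Made-up batch: references and hypotheses are lists of phoneme strings,
# as produced by convert_index_to_lab in the snippet above.
ids = ["utt1", "utt2"]
refs = [["sil", "ah", "t"], ["sil", "b", "iy"]]
hyps = [["sil", "ah", "d"], ["sil", "b", "iy"]]

per_stats = edit_distance.wer_details_for_batch(
    ids, refs, hyps, compute_alignments=True
)
for detail in per_stats:
    print(detail["WER"])  # per-utterance error rate, read the same way as above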
Example #2
    def append(
        self,
        ids,
        predict,
        target,
        predict_len=None,
        target_len=None,
        ind2lab=None,
    ):
        """Add stats to the relevant containers.

        * See MetricStats.append()

        Arguments
        ---------
        ids : list
            List of ids corresponding to utterances.
        predict : torch.tensor
            A predicted output, for comparison with the target output.
        target : torch.tensor
            The correct reference output, for comparison with the prediction.
        predict_len : torch.tensor
            The predictions' relative lengths, used to undo padding if
            there is padding present in the predictions.
        target_len : torch.tensor
            The target outputs' relative lengths, used to undo padding if
            there is padding present in the target.
        ind2lab : callable
            Callable that maps from indices to labels, operating on batches,
            for writing alignments.
        """
        self.ids.extend(ids)

        if predict_len is not None:
            predict = undo_padding(predict, predict_len)

        if target_len is not None:
            target = undo_padding(target, target_len)

        if ind2lab is not None:
            predict = ind2lab(predict)
            target = ind2lab(target)

        if self.merge_tokens:
            predict = merge_char(predict)
            target = merge_char(target)

        if self.split_tokens:
            predict = split_word(predict)
            target = split_word(target)

        scores = edit_distance.wer_details_for_batch(
            ids, target, predict, compute_alignments=True
        )

        self.scores.extend(scores)
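
This append signature looks like SpeechBrain's ErrorRateStats-style metric; the sketch below assumes that class is available with this interface, and the tensors, lengths, and toy ind2lab mapping are made up for illustration:

import torch
from speechbrain.utils.metric_stats import ErrorRateStats

per_metric = ErrorRateStats()

# Made-up padded index tensors plus relative lengths.
ids = ["utt1", "utt2"]
predict = torch.tensor([[1, 2, 3, 0], [1, 2, 0, 0]])
target = torch.tensor([[1, 2, 4, 0], [1, 2, 0, 0]])
lens = torch.tensor([0.75, 0.5])

# Toy ind2lab: maps batches of index lists to batches of label strings.
ind2lab = lambda batch: [[str(index) for index in seq] for seq in batch]

per_metric.append(ids, predict, target,
                  predict_len=lens, target_len=lens, ind2lab=ind2lab)
print(per_metric.summarize())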
Example #3
    def _check_coverage_from_bpe(self, list_csv_files=[]):
        """Logging the accuracy of the BPE model to recover words from the training text.

        Arguments
        ---------
        csv_list_to_check : list,
            List of the csv file which is used for checking the accuracy of recovering words from the tokenizer.
        """
        for csv_file in list_csv_files:
            if os.path.isfile(os.path.abspath(csv_file)):
                logger.info(
                    "==== Accuracy checking for recovering text from tokenizer ==="
                )
                fcsv_file = open(csv_file, "r")
                reader = csv.reader(fcsv_file)
                headers = next(reader, None)
                if self.csv_read not in headers:
                    raise ValueError(self.csv_read + " must exist in: " +
                                     csv_file)
                index_label = headers.index(self.csv_read)
                wrong_recover_list = []
                for row in reader:
                    row = row[index_label]
                    if self.char_format_input:
                        (row, ) = merge_char([row.split()])
                        row = " ".join(row)
                    row = row.split("\n")[0]
                    encoded_id = self.sp.encode_as_ids(row)
                    decode_text = self.sp.decode_ids(encoded_id)
                    (details, ) = edit_distance.wer_details_for_batch(
                        ["utt1"],
                        [row.split(" ")],
                        [decode_text.split(" ")],
                        compute_alignments=True,
                    )
                    if details["WER"] > 0:
                        for align in details["alignment"]:
                            if align[0] != "=" and align[1] is not None:
                                if align[1] not in wrong_recover_list:
                                    wrong_recover_list.append(align[1])
                fcsv_file.close()
                logger.info("recover words from: " + csv_file)
                if len(wrong_recover_list) > 0:
                    logger.warn("Wrong recover words: " +
                                str(len(wrong_recover_list)))
                    logger.warn("Tokenizer vocab size: " +
                                str(self.sp.vocab_size()))
                    logger.warn("accuracy recovering words: " +
                                str(1 - float(len(wrong_recover_list)) /
                                    self.sp.vocab_size()))
                else:
                    logger.info("Wrong recover words: 0")
                    logger.warning("accuracy recovering words: " + str(1.0))
            else:
                logger.info("No accuracy recover checking for" + csv_file)
Example #4
    def compute_forward_tea(self, x, y, init_params=False):
        ids, wavs, wav_lens = x
        ids, phns, phn_lens = y

        wavs, wav_lens = wavs.to(params.device), wav_lens.to(params.device)
        phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

        if hasattr(params, "augmentation"):
            wavs = params.augmentation(wavs, wav_lens, init_params)
        feats = params.compute_features(wavs, init_params)
        feats = params.normalize(feats, wav_lens)
        apply_softmax = torch.nn.Softmax(dim=-1)

        ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
        phns_decode = undo_padding(phns, phn_lens)
        phns_decode = convert_index_to_lab(phns_decode, ind2lab)

        # Run inference with each teacher model
        tea_dict_list = []
        for num in range(params.num_tea):
            tea_dict = {}
            self.tea_modules_list[num].eval()
            with torch.no_grad():
                x_tea = tea_enc_list[num](feats, init_params=init_params)
                ctc_logits_tea = tea_ctc_lin_list[num](x_tea, init_params)

                # output layer for ctc log-probabilities
                p_ctc_tea = params.log_softmax(ctc_logits_tea / params.T)

                # Prepend bos token at the beginning
                y_in_tea = prepend_bos_token(phns, bos_index=params.bos_index)
                e_in_tea = tea_emb_list[num](y_in_tea, init_params=init_params)
                h_tea, _ = tea_dec_list[num](
                    e_in_tea, x_tea, wav_lens, init_params
                )

                # output layer for seq2seq log-probabilities
                seq_logits_tea = tea_seq_lin_list[num](h_tea, init_params)
                p_seq_tea = apply_softmax(seq_logits_tea / params.T)

                # WER from output layer of CTC
                sequence_ctc = ctc_greedy_decode(
                    p_ctc_tea, wav_lens, blank_id=params.blank_index
                )
                sequence_ctc = convert_index_to_lab(sequence_ctc, ind2lab)
                per_stats_ctc = edit_distance.wer_details_for_batch(
                    ids, phns_decode, sequence_ctc, compute_alignments=False
                )

                wer_ctc_tea = []
                for item in per_stats_ctc:
                    wer_ctc_tea.append(item["WER"])

                wer_ctc_tea = exclude_wer(wer_ctc_tea)
                wer_ctc_tea = np.expand_dims(wer_ctc_tea, axis=0)

                # WER from output layer of CE
                _, predictions = p_seq_tea.max(dim=-1)
                hyps = batch_filter_seq2seq_output(
                    predictions, eos_id=params.eos_index
                )
                sequence_ce = convert_index_to_lab(hyps, ind2lab)
                per_stats_ce = edit_distance.wer_details_for_batch(
                    ids, phns_decode, sequence_ce, compute_alignments=False
                )

                wer_tea = []
                for item in per_stats_ce:
                    wer_tea.append(item["WER"])

                wer_tea = exclude_wer(wer_tea)
                wer_tea = np.expand_dims(wer_tea, axis=0)

            # Save this teacher's outputs and WER stats into a dict
            tea_dict["p_ctc_tea"] = p_ctc_tea.cpu().numpy()
            tea_dict["p_seq_tea"] = p_seq_tea.cpu().numpy()
            tea_dict["wer_ctc_tea"] = wer_ctc_tea
            tea_dict["wer_tea"] = wer_tea
            tea_dict_list.append(tea_dict)

        return tea_dict_list
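
Both teacher heads above are softened with a temperature params.T before being stored; a minimal sketch of that temperature-scaled softmax (the temperature value and tensor shapes are illustrative):

import torch

T = 2.0  # hypothetical value of params.T
logits = torch.randn(4, 12, 40)  # (batch, time, classes), made-up sizes

# Higher T flattens the distribution, exposing more of the teacher's
# relative preferences among wrong classes ("dark knowledge").
p_seq_soft = torch.nn.functional.softmax(logits / T, dim=-1)
p_ctc_soft = torch.nn.functional.log_softmax(logits / T, dim=-1)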
Example #5
    def compute_objectives(self,
                           predictions,
                           targets,
                           data_dict,
                           batch_id,
                           stage="train"):
        if stage == "train":
            p_ctc, p_seq, wav_lens = predictions
        else:
            p_ctc, p_seq, wav_lens, hyps = predictions

        ids, phns, phn_lens = targets
        phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

        # Convert relative phoneme lengths to absolute lengths
        abs_length = torch.round(phn_lens * phns.shape[1])

        # Append eos token at the end of the label sequences
        phns_with_eos = append_eos_token(phns,
                                         length=abs_length,
                                         eos_index=params.eos_index)

        # Convert back to a SpeechBrain-style relative length (including eos)
        rel_length = (abs_length + 1) / phns_with_eos.shape[1]

        # normal supervised training
        loss_ctc_nor = params.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
        loss_seq_nor = params.seq_cost(p_seq, phns_with_eos, length=rel_length)

        # load teacher inference results
        item_tea_list = [None, None, None, None]
        for tea_num in range(params.num_tea):
            for i in range(4):
                item_tea = data_dict[str(batch_id)][tea_name[tea_num]][
                    tea_keys[i]][()]

                if tea_keys[i].startswith("wer"):
                    item_tea = torch.tensor(item_tea)
                else:
                    item_tea = torch.from_numpy(item_tea)

                item_tea = item_tea.to(params.device)
                item_tea = torch.unsqueeze(item_tea, 0)
                if tea_num == 0:
                    item_tea_list[i] = item_tea
                else:
                    item_tea_list[i] = torch.cat([item_tea_list[i], item_tea],
                                                 0)

        p_ctc_tea = item_tea_list[0]
        p_seq_tea = item_tea_list[1]
        wer_ctc_tea = item_tea_list[2]
        wer_tea = item_tea_list[3]

        # Strategy "average": average losses of teachers when doing distillation.
        # Strategy "best": choosing the best teacher based on WER.
        # Strategy "weighted": assigning weights to teachers based on WER.
        if params.strategy == "best":
            # tea_ce for kd
            wer_scores, indx = torch.min(wer_tea, dim=0)
            indx = list(indx.cpu().numpy())

            # select the best teacher for each sentence
            tea_seq2seq_pout = None
            for stn_indx, tea_indx in enumerate(indx):
                s2s_one = p_seq_tea[tea_indx][stn_indx]
                s2s_one = torch.unsqueeze(s2s_one, 0)
                if stn_indx == 0:
                    tea_seq2seq_pout = s2s_one
                else:
                    tea_seq2seq_pout = torch.cat([tea_seq2seq_pout, s2s_one],
                                                 0)

        apply_softmax = torch.nn.Softmax(dim=0)

        if params.strategy == "best" or params.strategy == "weighted":
            # mean wer for ctc
            tea_wer_ctc_mean = wer_ctc_tea.mean(1)
            tea_acc_main = 100 - tea_wer_ctc_mean

            # normalise weights via Softmax function
            tea_acc_softmax = apply_softmax(tea_acc_main)

        if params.strategy == "weighted":
            # mean wer for ce
            tea_wer_mean = wer_tea.mean(1)
            tea_acc_ce_main = 100 - tea_wer_mean

            # normalise weights via Softmax function
            tea_acc_ce_softmax = apply_softmax(tea_acc_ce_main)

        # Per-teacher KD losses
        ctc_loss_list = None
        ce_loss_list = None
        for tea_num in range(params.num_tea):
            # ctc
            p_ctc_tea_one = p_ctc_tea[tea_num]
            # calculate CTC distillation loss of one teacher
            loss_ctc_one = params.ctc_cost_kd(p_ctc, p_ctc_tea_one, wav_lens)
            loss_ctc_one = torch.unsqueeze(loss_ctc_one, 0)
            if tea_num == 0:
                ctc_loss_list = loss_ctc_one
            else:
                ctc_loss_list = torch.cat([ctc_loss_list, loss_ctc_one])

            # ce
            p_seq_tea_one = p_seq_tea[tea_num]
            # calculate CE distillation loss of one teacher
            loss_seq_one = params.seq_cost_kd(p_seq, p_seq_tea_one, rel_length)
            loss_seq_one = torch.unsqueeze(loss_seq_one, 0)
            if tea_num == 0:
                ce_loss_list = loss_seq_one
            else:
                ce_loss_list = torch.cat([ce_loss_list, loss_seq_one])

        # Combine per-teacher KD losses according to the chosen strategy
        if params.strategy == "average":
            # get average value of losses from all teachers (CTC and CE loss)
            ctc_loss_kd = ctc_loss_list.mean(0)
            seq2seq_loss_kd = ce_loss_list.mean(0)
        else:
            # assign weights to different teachers (CTC loss)
            ctc_loss_kd = (tea_acc_softmax * ctc_loss_list).sum(0)
            if params.strategy == "best":
                # only use the best teacher to compute CE loss
                seq2seq_loss_kd = params.seq_cost_kd(p_seq, tea_seq2seq_pout,
                                                     rel_length)
            if params.strategy == "weighted":
                # assign weights to different teachers (CE loss)
                seq2seq_loss_kd = (tea_acc_ce_softmax * ce_loss_list).sum(0)

        # Total loss: combine the KD loss with the normal supervised loss
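        # The T^2 factor keeps the KD gradient magnitude comparable to the
        # supervised term: gradients through soft targets produced at
        # temperature T scale roughly as 1/T^2 (Hinton et al., 2015).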
        loss_ctc = (params.Temperature * params.Temperature * params.alpha *
                    ctc_loss_kd + (1 - params.alpha) * loss_ctc_nor)
        loss_seq = (params.Temperature * params.Temperature * params.alpha *
                    seq2seq_loss_kd + (1 - params.alpha) * loss_seq_nor)

        loss = params.ctc_weight * loss_ctc + (1 -
                                               params.ctc_weight) * loss_seq

        stats = {}
        if stage != "train":
            ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
            sequence = convert_index_to_lab(hyps, ind2lab)
            phns = undo_padding(phns, phn_lens)
            phns = convert_index_to_lab(phns, ind2lab)
            per_stats = edit_distance.wer_details_for_batch(
                ids, phns, sequence, compute_alignments=True)
            stats["PER"] = per_stats
        return loss, stats
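
The "weighted" strategy above amounts to a softmax over per-teacher accuracies (100 minus mean WER); a standalone sketch with made-up numbers:

import torch

# Hypothetical per-teacher WER, shape (num_teachers, batch_size).
wer_tea = torch.tensor([[20.0, 30.0], [10.0, 15.0], [40.0, 35.0]])
kd_loss_per_teacher = torch.tensor([1.2, 0.9, 1.5])  # one KD loss per teacher

# Teachers with lower WER receive larger weights.
tea_acc = 100.0 - wer_tea.mean(dim=1)
weights = torch.softmax(tea_acc, dim=0)
weighted_kd_loss = (weights * kd_loss_per_teacher).sum()
print(weights, weighted_kd_loss)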