Example #1
    def append(
        self,
        ids,
        predict,
        target,
        predict_len=None,
        target_len=None,
        ind2lab=None,
    ):
        """Add stats to the relevant containers.

        * See MetricStats.append()

        Arguments
        ---------
        ids : list
            List of ids corresponding to utterances.
        predict : torch.tensor
            A predicted output, for comparison with the target output.
        target : torch.tensor
            The correct reference output, for comparison with the prediction.
        predict_len : torch.tensor
            The predictions' relative lengths, used to undo padding if
            there is padding present in the predictions.
        target_len : torch.tensor
            The target outputs' relative lengths, used to undo padding if
            there is padding present in the target.
        ind2lab : callable
            Callable that maps from indices to labels, operating on batches,
            for writing alignments.
        """
        self.ids.extend(ids)

        if predict_len is not None:
            predict = undo_padding(predict, predict_len)

        if target_len is not None:
            target = undo_padding(target, target_len)

        if ind2lab is not None:
            predict = ind2lab(predict)
            target = ind2lab(target)

        if self.merge_tokens:
            predict = merge_char(predict)
            target = merge_char(target)

        if self.split_tokens:
            predict = split_word(predict)
            target = split_word(target)

        scores = edit_distance.wer_details_for_batch(
            ids, target, predict, compute_alignments=True
        )

        self.scores.extend(scores)
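
The undo_padding call above relies on SpeechBrain's relative-length convention: each row of a padded (batch, time) tensor is trimmed back to its true length. A minimal sketch of that assumed behavior, shown for illustration only:

import torch

def undo_padding_sketch(batch, lengths):
    """Trim each padded row to its true length and return plain Python lists."""
    batch_max_len = batch.shape[1]
    as_list = []
    for seq, rel_len in zip(batch, lengths):
        abs_len = int(torch.round(rel_len * batch_max_len))
        as_list.append(seq[:abs_len].tolist())
    return as_list

# Two predictions padded to length 4, with true lengths 2 and 3
preds = torch.tensor([[5, 7, 0, 0], [1, 2, 3, 0]])
rel_lens = torch.tensor([0.5, 0.75])
print(undo_padding_sketch(preds, rel_lens))  # [[5, 7], [1, 2, 3]]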
Example #2
    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC) given predictions and targets."""

        p_ctc, wav_lens = predictions

        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

        if stage != sb.Stage.TRAIN:
            # Decode token terms to words
            sequence = sb.decoders.ctc_greedy_decode(
                p_ctc, wav_lens, blank_id=self.hparams.blank_index)

            predicted_words = self.tokenizer(sequence, task="decode_from_list")

            # Convert indices to words
            target_words = undo_padding(tokens, tokens_lens)
            target_words = self.tokenizer(target_words,
                                          task="decode_from_list")

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss
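
For context, a hedged sketch of how the wer_metric and cer_metric used above are typically created and summarized around compute_objectives in SpeechBrain recipes. These would be methods of the same Brain subclass; the use of ErrorRateStats and summarize("error_rate") follows common recipes, but the exact setup in this project is an assumption.

import speechbrain as sb
from speechbrain.utils.metric_stats import ErrorRateStats

def on_stage_start(self, stage, epoch):
    # Fresh metric containers for every validation/test pass
    if stage != sb.Stage.TRAIN:
        self.wer_metric = ErrorRateStats()
        self.cer_metric = ErrorRateStats(split_tokens=True)

def on_stage_end(self, stage, stage_loss, epoch):
    if stage != sb.Stage.TRAIN:
        stage_stats = {
            "loss": stage_loss,
            "WER": self.wer_metric.summarize("error_rate"),
            "CER": self.cer_metric.summarize("error_rate"),
        }
        return stage_stats  # typically logged and used for checkpointing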
Example #3
    def compute_objectives(self, predictions, targets, stage="train"):
        if stage == "train":
            p_ctc, p_seq, wav_lens = predictions
        else:
            p_ctc, p_seq, wav_lens, hyps = predictions

        ids, phns, phn_lens = targets
        phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

        # Convert the relative lengths to absolute lengths (eos adds one below)
        abs_length = torch.round(phn_lens * phns.shape[1])

        # Append eos token at the end of the label sequences
        phns_with_eos = append_eos_token(
            phns, length=abs_length, eos_index=params.eos_index
        )

        # convert to speechbrain-style relative length
        rel_length = (abs_length + 1) / phns.shape[1]

        loss_ctc = params.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
        loss_seq = params.seq_cost(p_seq, phns_with_eos, length=rel_length)
        loss = params.ctc_weight * loss_ctc + (1 - params.ctc_weight) * loss_seq

        stats = {}
        if stage != "train":
            ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
            sequence = convert_index_to_lab(hyps, ind2lab)
            phns = undo_padding(phns, phn_lens)
            phns = convert_index_to_lab(phns, ind2lab)
            per_stats = edit_distance.wer_details_for_batch(
                ids, phns, sequence, compute_alignments=True
            )
            stats["PER"] = per_stats
        return loss, stats
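
append_eos_token is used above (and in later examples) without being shown. A minimal illustrative sketch of what it is assumed to do: add one time step to the padded labels and write the eos index right after each sequence's true length.

import torch

def append_eos_token_sketch(label, length, eos_index):
    """Return a copy of label with one extra time step and eos appended per row."""
    batch, time = label.shape
    out = torch.zeros(batch, time + 1, dtype=label.dtype)
    out[:, :time] = label
    for i, abs_len in enumerate(length.long()):
        out[i, abs_len] = eos_index
    return out

labels = torch.tensor([[4, 9, 0], [3, 5, 7]])   # row 0 padded to width 3
abs_length = torch.tensor([2.0, 3.0])
print(append_eos_token_sketch(labels, abs_length, eos_index=1))
# tensor([[4, 9, 1, 0],
#         [3, 5, 7, 1]])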
Example #4
    def compute_objectives(self, predictions, targets, stage):
        """Computes the loss (NLL) given predictions and targets."""

        if (stage == sb.Stage.TRAIN
                and self.batch_count % show_results_every != 0):
            p_seq, decoded_transcript_lens = predictions
        else:
            p_seq, decoded_transcript_lens, predicted_tokens = predictions

        ids, target_semantics, target_semantics_lens = targets
        target_tokens, target_token_lens = self.hparams.tokenizer(
            target_semantics,
            target_semantics_lens,
            self.hparams.ind2lab,
            task="encode",
        )
        target_tokens = target_tokens.to(self.device)
        target_token_lens = target_token_lens.to(self.device)

        # Convert the relative lengths to absolute lengths (eos adds one below)
        abs_length = torch.round(target_token_lens * target_tokens.shape[1])

        # Append eos token at the end of the label sequences
        target_tokens_with_eos = sb.dataio.dataio.append_eos_token(
            target_tokens, length=abs_length, eos_index=self.hparams.eos_index)

        # Convert to speechbrain-style relative length
        rel_length = (abs_length + 1) / target_tokens_with_eos.shape[1]
        loss_seq = self.hparams.seq_cost(p_seq,
                                         target_tokens_with_eos,
                                         length=rel_length)

        # (No ctc loss)
        loss = loss_seq

        if (stage != sb.Stage.TRAIN
                or self.batch_count % show_results_every == 0):
            # Decode token terms to words
            predicted_semantics = self.hparams.tokenizer(
                predicted_tokens, task="decode_from_list")

            # Convert indices to words
            target_semantics = undo_padding(target_semantics,
                                            target_semantics_lens)
            target_semantics = sb.dataio.dataio.convert_index_to_lab(
                target_semantics, self.hparams.ind2lab)
            for i in range(len(target_semantics)):
                print(" ".join(predicted_semantics[i]).replace("|", ","))
                print(" ".join(target_semantics[i]).replace("|", ","))
                print("")

            if stage != sb.Stage.TRAIN:
                self.wer_metric.append(ids, predicted_semantics,
                                       target_semantics)
                self.cer_metric.append(ids, predicted_semantics,
                                       target_semantics)

        return loss
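
A toy worked example of the length bookkeeping above (numbers are made up): with a batch padded to 4 tokens and relative lengths [0.5, 1.0], the absolute lengths round to [2, 4]; appending eos makes the padded width 5, so the relative lengths passed to seq_cost become (abs_length + 1) / 5.

import torch

target_token_lens = torch.tensor([0.5, 1.0])
T = 4                                            # padded width before eos
abs_length = torch.round(target_token_lens * T)  # tensor([2., 4.])
rel_length = (abs_length + 1) / (T + 1)          # tensor([0.6000, 1.0000])
print(abs_length, rel_length)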
Example #5
    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets."""

        current_epoch = self.hparams.epoch_counter.current
        if stage == sb.Stage.TRAIN:
            if current_epoch <= self.hparams.number_of_ctc_epochs:
                p_ctc, p_seq, wav_lens = predictions
            else:
                p_seq, wav_lens = predictions
        else:
            p_seq, wav_lens, predicted_tokens = predictions

        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        loss_seq = self.hparams.seq_cost(
            p_seq, tokens_eos, length=tokens_eos_lens
        )

        # Add ctc loss if necessary
        if (
            stage == sb.Stage.TRAIN
            and current_epoch <= self.hparams.number_of_ctc_epochs
        ):
            loss_ctc = self.hparams.ctc_cost(
                p_ctc, tokens, wav_lens, tokens_lens
            )
            loss = self.hparams.ctc_weight * loss_ctc
            loss += (1 - self.hparams.ctc_weight) * loss_seq
        else:
            loss = loss_seq

        if stage != sb.Stage.TRAIN:
            # Decode token terms to words
            predicted_words = self.tokenizer(
                predicted_tokens, task="decode_from_list"
            )

            # Convert indices to words
            target_words = undo_padding(tokens, tokens_lens)
            target_words = self.tokenizer(target_words, task="decode_from_list")

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss
Example #6
    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets."""

        current_epoch = self.hparams.epoch_counter.current
        wav_lens = predictions["wav_lens"]

        tokens, tokens_lens = batch.tokens
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
            tokens = torch.cat([tokens, tokens], dim=0)
            tokens_lens = torch.cat([tokens_lens, tokens_lens])
            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
            tokens_eos_lens = torch.cat([tokens_eos_lens, tokens_eos_lens])

        loss = self.hparams.seq_cost(
            predictions["p_seq"], tokens_eos, length=tokens_eos_lens
        )

        # Add ctc loss if necessary
        if (
            stage == sb.Stage.TRAIN
            and current_epoch <= self.hparams.number_of_ctc_epochs
        ):
            loss_ctc = self.hparams.ctc_cost(
                predictions["p_ctc"], tokens, wav_lens, tokens_lens
            )
            loss *= 1 - self.hparams.ctc_weight
            loss += self.hparams.ctc_weight * loss_ctc

        if stage != sb.Stage.TRAIN:
            # Decode token terms to words
            predicted_words = self.tokenizer(
                predictions["p_tokens"], task="decode_from_list"
            )

            # Convert indices to words
            target_words = undo_padding(tokens, tokens_lens)
            target_words = self.tokenizer(target_words, task="decode_from_list")

            self.wer_metric.append(batch.id, predicted_words, target_words)
            self.cer_metric.append(batch.id, predicted_words, target_words)

        return loss
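
The doubling of the token tensors above only makes sense together with the matching step in compute_forward, where clean and corrupted waveforms are concatenated along the batch dimension. A hedged sketch of that companion step (the function name is illustrative; the env_corrupt usage follows common SpeechBrain recipes):

import torch
import speechbrain as sb

def duplicate_batch_for_corruption(modules, wavs, wav_lens, stage):
    """Concatenate clean and corrupted waveforms so the batch size matches
    the doubled target tensors in compute_objectives (illustrative sketch)."""
    if hasattr(modules, "env_corrupt") and stage == sb.Stage.TRAIN:
        wavs_noise = modules.env_corrupt(wavs, wav_lens)
        wavs = torch.cat([wavs, wavs_noise], dim=0)  # batch size becomes 2B
        wav_lens = torch.cat([wav_lens, wav_lens])
    return wavs, wav_lens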
Example #7
File: train.py Project: loicCQAM/musicSep
    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets."""

        (
            p_ctc,
            p_seq,
            wav_lens,
            predicted_tokens,
        ) = predictions

        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        loss_seq = self.hparams.seq_cost(p_seq,
                                         tokens_eos,
                                         length=tokens_eos_lens)
        loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
        loss = (self.hparams.ctc_weight * loss_ctc +
                (1 - self.hparams.ctc_weight) * loss_seq)

        if stage != sb.Stage.TRAIN:
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if current_epoch % valid_search_interval == 0 or (
                    stage == sb.Stage.TEST):
                # Decode token terms to words
                predicted_words = self.tokenizer(predicted_tokens,
                                                 task="decode_from_list")

                # Convert indices to words
                target_words = undo_padding(tokens, tokens_lens)
                target_words = self.tokenizer(target_words,
                                              task="decode_from_list")
                self.wer_metric.append(ids, predicted_words, target_words)
                self.cer_metric.append(ids, predicted_words, target_words)

            # compute the accuracy of the one-step-forward prediction
            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
        return loss
Example #8
    def compute_objectives(self, predictions, targets, stage):
        """Computes the loss (NLL) given predictions and targets."""

        if (
            stage == sb.Stage.TRAIN
            and self.batch_count % show_results_every != 0
        ):
            p_seq, decoded_transcript_lens = predictions
        else:
            p_seq, decoded_transcript_lens, predicted_tokens = predictions

        ids, target_semantics, target_semantics_lens = targets
        target_tokens, target_token_lens = self.hparams.tokenizer(
            target_semantics,
            target_semantics_lens,
            self.hparams.ind2lab,
            task="encode",
        )
        target_tokens = target_tokens.to(self.device)
        target_token_lens = target_token_lens.to(self.device)

        # Convert the relative lengths to absolute lengths (eos adds one below)
        abs_length = torch.round(target_token_lens * target_tokens.shape[1])

        # Append eos token at the end of the label sequences
        target_tokens_with_eos = sb.dataio.dataio.append_eos_token(
            target_tokens, length=abs_length, eos_index=self.hparams.eos_index
        )

        # Convert to speechbrain-style relative length
        rel_length = (abs_length + 1) / target_tokens_with_eos.shape[1]
        loss_seq = self.hparams.seq_cost(
            p_seq, target_tokens_with_eos, length=rel_length
        )

        # (No ctc loss)
        loss = loss_seq

        if (
            stage != sb.Stage.TRAIN
            or self.batch_count % show_results_every == 0
        ):
            # Decode token terms to words
            predicted_semantics = self.hparams.tokenizer(
                predicted_tokens, task="decode_from_list"
            )

            # Convert indices to words
            target_semantics = undo_padding(
                target_semantics, target_semantics_lens
            )
            target_semantics = sb.dataio.dataio.convert_index_to_lab(
                target_semantics, self.hparams.ind2lab
            )
            for i in range(len(target_semantics)):
                print(" ".join(predicted_semantics[i]).replace("|", ","))
                print(" ".join(target_semantics[i]).replace("|", ","))
                print("")

            if stage != sb.Stage.TRAIN:
                self.wer_metric.append(
                    ids, predicted_semantics, target_semantics
                )
                self.cer_metric.append(
                    ids, predicted_semantics, target_semantics
                )

            if stage == sb.Stage.TEST:
                # write to "predictions.jsonl"
                with jsonlines.open(
                    hparams["output_folder"] + "/predictions.jsonl", mode="a"
                ) as writer:
                    for i in range(len(predicted_semantics)):
                        try:
                            # "|" stands in for "," in the decoded semantics
                            parsed = ast.literal_eval(
                                " ".join(predicted_semantics[i]).replace(
                                    "|", ","
                                )
                            )
                        except SyntaxError:
                            # Fall back when the output is not a valid dictionary
                            parsed = {
                                "scenario": "none",
                                "action": "none",
                                "entities": [],
                            }
                        parsed["file"] = id_to_file[ids[i]]
                        writer.write(parsed)

        return loss
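
A toy illustration of the parsing step above: the decoded semantics use "|" in place of ",", so after joining the tokens and replacing "|" the string becomes a Python dict literal that ast.literal_eval can parse (the prediction below is made up).

import ast

predicted_semantics_i = ["{'scenario':", "'alarm'|", "'action':", "'set'|", "'entities':", "[]}"]
parsed = ast.literal_eval(" ".join(predicted_semantics_i).replace("|", ","))
print(parsed)  # {'scenario': 'alarm', 'action': 'set', 'entities': []}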
Example #9
    def compute_objectives(self, predictions, batch, stage):
        """Compute possibly several loss terms: enhance, mimic, ctc, seq"""

        # Do not augment targets
        clean_wavs, clean_feats, lens = self.prepare_feats(batch.clean_sig,
                                                           augment=False)
        loss = 0

        # Compute enhancement loss
        if self.hparams.enhance_weight > 0:
            enhance_loss = self.hparams.enhance_loss(predictions["feats"],
                                                     clean_feats, lens)
            loss += self.hparams.enhance_weight * enhance_loss

            if stage != sb.Stage.TRAIN:
                self.enh_metrics.append(batch.id, predictions["feats"],
                                        clean_feats, lens)
                self.stoi_metrics.append(
                    ids=batch.id,
                    predict=predictions["wavs"],
                    target=clean_wavs,
                    lengths=lens,
                )
                self.pesq_metrics.append(
                    ids=batch.id,
                    predict=predictions["wavs"],
                    target=clean_wavs,
                    lengths=lens,
                )

        # Compute mimic loss
        if self.hparams.mimic_weight > 0:
            clean_embed = self.modules.src_embedding.CNN(clean_feats)
            enh_embed = self.modules.src_embedding.CNN(predictions["feats"])
            mimic_loss = self.hparams.mimic_loss(enh_embed, clean_embed, lens)
            loss += self.hparams.mimic_weight * mimic_loss

            if stage != sb.Stage.TRAIN:
                self.mimic_metrics.append(batch.id, enh_embed, clean_embed,
                                          lens)

        # Compute hard ASR loss
        if self.hparams.ctc_weight > 0 and (
                not hasattr(self.hparams, "ctc_epochs") or
                self.hparams.epoch_counter.current < self.hparams.ctc_epochs):
            tokens, token_lens = self.prepare_targets(batch.tokens)
            ctc_loss = self.hparams.ctc_loss(predictions["ctc_pout"], tokens,
                                             lens, token_lens)
            loss += self.hparams.ctc_weight * ctc_loss

            if stage != sb.Stage.TRAIN and self.hparams.seq_weight == 0:
                predict = sb.decoders.ctc_greedy_decode(
                    predictions["ctc_pout"], lens, blank_id=-1)
                self.err_rate_metrics.append(
                    ids=batch.id,
                    predict=predict,
                    target=tokens,
                    target_len=token_lens,
                    ind2lab=self.hparams.ind2lab,
                )

        # Compute nll loss for seq2seq model
        if self.hparams.seq_weight > 0:

            tokens, token_lens = self.prepare_targets(batch.tokens_eos)
            seq_loss = self.hparams.seq_loss(predictions["seq_pout"], tokens,
                                             token_lens)
            loss += self.hparams.seq_weight * seq_loss

            if stage != sb.Stage.TRAIN and self.hparams.target_type == "wrd":
                pred_words = self.tokenizer(predictions["hyps"],
                                            task="decode_from_list")
                target_words = self.tokenizer(undo_padding(*batch.tokens),
                                              task="decode_from_list")
                self.err_rate_metrics.append(batch.id, pred_words,
                                             target_words)
            elif stage != sb.Stage.TRAIN:
                self.err_rate_metrics.append(
                    ids=batch.id,
                    predict=predictions["hyps"],
                    target=tokens,
                    target_len=token_lens,
                    ind2lab=self.tokenizer.decode_ndim,
                )

        return loss
Example #10
    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (Transducer+(CTC+NLL)) given predictions and targets."""

        ids = batch.id
        current_epoch = self.hparams.epoch_counter.current
        tokens, token_lens = batch.tokens
        tokens_eos, token_eos_lens = batch.tokens_eos

        if stage == sb.Stage.TRAIN:
            if len(predictions) == 4:
                p_ctc, p_ce, p_transducer, wav_lens = predictions
                CTC_loss = self.hparams.ctc_cost(
                    p_ctc, tokens, wav_lens, token_lens
                )
                CE_loss = self.hparams.ce_cost(
                    p_ce, tokens_eos, length=token_eos_lens
                )
                loss_transducer = self.hparams.transducer_cost(
                    p_transducer, tokens, wav_lens, token_lens
                )
                loss = (
                    self.hparams.ctc_weight * CTC_loss
                    + self.hparams.ce_weight * CE_loss
                    + (1 - (self.hparams.ctc_weight + self.hparams.ce_weight))
                    * loss_transducer
                )
            elif len(predictions) == 3:
                # one of the 2 heads (CTC or CE) is still computed
                # CTC alive
                if current_epoch <= self.hparams.number_of_ctc_epochs:
                    p_ctc, p_transducer, wav_lens = predictions
                    CTC_loss = self.hparams.ctc_cost(
                        p_ctc, tokens, wav_lens, token_lens
                    )
                    loss_transducer = self.hparams.transducer_cost(
                        p_transducer, tokens, wav_lens, token_lens
                    )
                    loss = (
                        self.hparams.ctc_weight * CTC_loss
                        + (1 - self.hparams.ctc_weight) * loss_transducer
                    )
                # CE for decoder alive
                else:
                    p_ce, p_transducer, wav_lens = predictions
                    CE_loss = self.hparams.ce_cost(
                        p_ce, tokens_eos, length=token_eos_lens
                    )
                    loss_transducer = self.hparams.transducer_cost(
                        p_transducer, tokens, wav_lens, token_lens
                    )
                    loss = (
                        self.hparams.ce_weight * CE_loss
                        + (1 - self.hparams.ce_weight) * loss_transducer
                    )
            else:
                p_transducer, wav_lens = predictions
                loss = self.hparams.transducer_cost(
                    p_transducer, tokens, wav_lens, token_lens
                )
        else:
            p_transducer, wav_lens, predicted_tokens = predictions
            loss = self.hparams.transducer_cost(
                p_transducer, tokens, wav_lens, token_lens
            )

        if stage != sb.Stage.TRAIN:

            # Decode token terms to words
            predicted_words = self.tokenizer(
                predicted_tokens, task="decode_from_list"
            )

            # Convert indices to words
            target_words = undo_padding(tokens, token_lens)
            target_words = self.tokenizer(target_words, task="decode_from_list")

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss
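
A quick sanity check of the three-way weighting in the full-training branch above (numbers are made up): with ctc_weight = 0.3 and ce_weight = 0.2, the transducer term receives the remaining 0.5, so the three weights sum to one.

ctc_weight, ce_weight = 0.3, 0.2
ctc_loss, ce_loss, transducer_loss = 2.0, 1.5, 1.0
loss = (
    ctc_weight * ctc_loss
    + ce_weight * ce_loss
    + (1 - (ctc_weight + ce_weight)) * transducer_loss
)
print(loss)  # 0.6 + 0.3 + 0.5 = 1.4 (up to float rounding)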
Example #11
    def compute_objectives(self, predictions, batch, stage):
        """Compute possibly several loss terms: enhance, mimic, ctc, seq"""

        # Do not augment targets
        clean_wavs, clean_feats, lens = self.prepare_feats(batch.clean_sig,
                                                           augment=False)
        loss = 0

        # Compute enhancement loss
        if self.hparams.enhance_weight > 0:
            enhance_loss = self.hparams.enhance_loss(predictions["feats"],
                                                     clean_feats, lens)
            loss += self.hparams.enhance_weight * enhance_loss

            if stage != sb.Stage.TRAIN:
                self.enh_metrics.append(batch.id, predictions["feats"],
                                        clean_feats, lens)
                self.stoi_metrics.append(
                    ids=batch.id,
                    predict=predictions["wavs"],
                    target=clean_wavs,
                    lengths=lens,
                )
                self.pesq_metrics.append(
                    ids=batch.id,
                    predict=predictions["wavs"],
                    target=clean_wavs,
                    lengths=lens,
                )

                if hasattr(self.hparams, "enh_dir"):
                    abs_lens = lens * predictions["wavs"].size(1)
                    for i, uid in enumerate(batch.id):
                        length = int(abs_lens[i])
                        wav = predictions["wavs"][i, :length].unsqueeze(0)
                        path = os.path.join(self.hparams.enh_dir, uid + ".wav")
                        torchaudio.save(path, wav.cpu(), sample_rate=16000)

        # Compute mimic loss
        if self.hparams.mimic_weight > 0:
            clean_embed = self.modules.src_embedding.CNN(clean_feats)
            enh_embed = self.modules.src_embedding.CNN(predictions["feats"])
            mimic_loss = self.hparams.mimic_loss(enh_embed, clean_embed, lens)
            loss += self.hparams.mimic_weight * mimic_loss

            if stage != sb.Stage.TRAIN:
                self.mimic_metrics.append(batch.id, enh_embed, clean_embed,
                                          lens)

        # Compute hard ASR loss
        if self.hparams.ctc_weight > 0 and (
                not hasattr(self.hparams, "ctc_epochs") or
                self.hparams.epoch_counter.current < self.hparams.ctc_epochs):
            tokens, token_lens = self.prepare_targets(batch.tokens)
            ctc_loss = self.hparams.ctc_loss(predictions["ctc_pout"], tokens,
                                             lens, token_lens)
            loss += self.hparams.ctc_weight * ctc_loss

            if stage != sb.Stage.TRAIN and self.hparams.seq_weight == 0:
                predict = sb.decoders.ctc_greedy_decode(
                    predictions["ctc_pout"], lens, blank_id=-1)
                self.err_rate_metrics.append(
                    ids=batch.id,
                    predict=predict,
                    target=tokens,
                    target_len=token_lens,
                    ind2lab=self.hparams.ind2lab,
                )

        # Compute nll loss for seq2seq model
        if self.hparams.seq_weight > 0:

            tokens, token_lens = self.prepare_targets(batch.tokens_eos)
            seq_loss = self.hparams.seq_loss(predictions["seq_pout"], tokens,
                                             token_lens)
            loss += self.hparams.seq_weight * seq_loss

            if stage != sb.Stage.TRAIN:
                if hasattr(self.hparams, "asr_pretrained"):
                    pred_words = [
                        self.token_encoder.decode_ids(token_seq)
                        for token_seq in predictions["hyps"]
                    ]
                    target_words = [
                        self.token_encoder.decode_ids(token_seq)
                        for token_seq in undo_padding(*batch.tokens)
                    ]
                    self.err_rate_metrics.append(batch.id, pred_words,
                                                 target_words)
                else:
                    self.err_rate_metrics.append(
                        ids=batch.id,
                        predict=predictions["hyps"],
                        target=tokens,
                        target_len=token_lens,
                        ind2lab=self.token_encoder.decode_ndim,
                    )

        return loss
Example #12
    def compute_forward_tea(self, x, y, init_params=False):
        ids, wavs, wav_lens = x
        ids, phns, phn_lens = y

        wavs, wav_lens = wavs.to(params.device), wav_lens.to(params.device)
        phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

        if hasattr(params, "augmentation"):
            wavs = params.augmentation(wavs, wav_lens, init_params)
        feats = params.compute_features(wavs, init_params)
        feats = params.normalize(feats, wav_lens)
        apply_softmax = torch.nn.Softmax(dim=-1)

        ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
        phns_decode = undo_padding(phns, phn_lens)
        phns_decode = convert_index_to_lab(phns_decode, ind2lab)

        # run inference for each teacher model
        tea_dict_list = []
        for num in range(params.num_tea):
            tea_dict = {}
            self.tea_modules_list[num].eval()
            with torch.no_grad():
                x_tea = tea_enc_list[num](feats, init_params=init_params)
                ctc_logits_tea = tea_ctc_lin_list[num](x_tea, init_params)

                # output layer for ctc log-probabilities
                p_ctc_tea = params.log_softmax(ctc_logits_tea / params.T)

                # Prepend bos token at the beginning
                y_in_tea = prepend_bos_token(phns, bos_index=params.bos_index)
                e_in_tea = tea_emb_list[num](y_in_tea, init_params=init_params)
                h_tea, _ = tea_dec_list[num](
                    e_in_tea, x_tea, wav_lens, init_params
                )

                # output layer for seq2seq log-probabilities
                seq_logits_tea = tea_seq_lin_list[num](h_tea, init_params)
                p_seq_tea = apply_softmax(seq_logits_tea / params.T)

                # WER from output layer of CTC
                sequence_ctc = ctc_greedy_decode(
                    p_ctc_tea, wav_lens, blank_id=params.blank_index
                )
                sequence_ctc = convert_index_to_lab(sequence_ctc, ind2lab)
                per_stats_ctc = edit_distance.wer_details_for_batch(
                    ids, phns_decode, sequence_ctc, compute_alignments=False
                )

                wer_ctc_tea = []
                for item in per_stats_ctc:
                    wer_ctc_tea.append(item["WER"])

                wer_ctc_tea = exclude_wer(wer_ctc_tea)
                wer_ctc_tea = np.expand_dims(wer_ctc_tea, axis=0)

                # WER from output layer of CE
                _, predictions = p_seq_tea.max(dim=-1)
                hyps = batch_filter_seq2seq_output(
                    predictions, eos_id=params.eos_index
                )
                sequence_ce = convert_index_to_lab(hyps, ind2lab)
                per_stats_ce = edit_distance.wer_details_for_batch(
                    ids, phns_decode, sequence_ce, compute_alignments=False
                )

                wer_tea = []
                for item in per_stats_ce:
                    wer_tea.append(item["WER"])

                wer_tea = exclude_wer(wer_tea)
                wer_tea = np.expand_dims(wer_tea, axis=0)

            # save the variables into dict
            tea_dict["p_ctc_tea"] = p_ctc_tea.cpu().numpy()
            tea_dict["p_seq_tea"] = p_seq_tea.cpu().numpy()
            tea_dict["wer_ctc_tea"] = wer_ctc_tea
            tea_dict["wer_tea"] = wer_tea
            tea_dict_list.append(tea_dict)

        return tea_dict_list
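
Example #13 below reads these teacher outputs back as data_dict[str(batch_id)][teacher_name][key][()], which suggests they are stored in an HDF5 file between the teacher pass and distillation. A hedged sketch of that saving step (the h5py layout is an assumption inferred from the indexing pattern, not taken from the recipe):

import h5py
import numpy as np

def save_teacher_outputs(path, batch_id, tea_dict_list, tea_names):
    """Store one batch of teacher outputs under /<batch_id>/<teacher>/<key>."""
    with h5py.File(path, "a") as f:
        batch_group = f.create_group(str(batch_id))
        for name, tea_dict in zip(tea_names, tea_dict_list):
            teacher_group = batch_group.create_group(name)
            for key, value in tea_dict.items():
                teacher_group.create_dataset(key, data=np.asarray(value))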
Example #13
    def compute_objectives(self,
                           predictions,
                           targets,
                           data_dict,
                           batch_id,
                           stage="train"):
        if stage == "train":
            p_ctc, p_seq, wav_lens = predictions
        else:
            p_ctc, p_seq, wav_lens, hyps = predictions

        ids, phns, phn_lens = targets
        phns, phn_lens = phns.to(params.device), phn_lens.to(params.device)

        # Convert the relative lengths to absolute lengths (eos adds one below)
        abs_length = torch.round(phn_lens * phns.shape[1])

        # Append eos token at the end of the label sequences
        phns_with_eos = append_eos_token(phns,
                                         length=abs_length,
                                         eos_index=params.eos_index)

        # convert to speechbrain-style relative length
        rel_length = (abs_length + 1) / phns.shape[1]

        # normal supervised training
        loss_ctc_nor = params.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
        loss_seq_nor = params.seq_cost(p_seq, phns_with_eos, length=rel_length)

        # load teacher inference results
        item_tea_list = [None, None, None, None]
        for tea_num in range(params.num_tea):
            for i in range(4):
                item_tea = data_dict[str(batch_id)][tea_name[tea_num]][
                    tea_keys[i]][()]

                if tea_keys[i].startswith("wer"):
                    item_tea = torch.tensor(item_tea)
                else:
                    item_tea = torch.from_numpy(item_tea)

                item_tea = item_tea.to(params.device)
                item_tea = torch.unsqueeze(item_tea, 0)
                if tea_num == 0:
                    item_tea_list[i] = item_tea
                else:
                    item_tea_list[i] = torch.cat([item_tea_list[i], item_tea],
                                                 0)

        p_ctc_tea = item_tea_list[0]
        p_seq_tea = item_tea_list[1]
        wer_ctc_tea = item_tea_list[2]
        wer_tea = item_tea_list[3]

        # Strategy "average": average the losses of all teachers during distillation.
        # Strategy "best": choose the best teacher for each sentence based on WER.
        # Strategy "weighted": assign weights to teachers based on WER.
        if params.strategy == "best":
            # tea_ce for kd
            wer_scores, indx = torch.min(wer_tea, dim=0)
            indx = list(indx.cpu().numpy())

            # select the best teacher for each sentence
            tea_seq2seq_pout = None
            for stn_indx, tea_indx in enumerate(indx):
                s2s_one = p_seq_tea[tea_indx][stn_indx]
                s2s_one = torch.unsqueeze(s2s_one, 0)
                if stn_indx == 0:
                    tea_seq2seq_pout = s2s_one
                else:
                    tea_seq2seq_pout = torch.cat([tea_seq2seq_pout, s2s_one],
                                                 0)

        apply_softmax = torch.nn.Softmax(dim=0)

        if params.strategy == "best" or params.strategy == "weighted":
            # mean wer for ctc
            tea_wer_ctc_mean = wer_ctc_tea.mean(1)
            tea_acc_main = 100 - tea_wer_ctc_mean

            # normalise weights via Softmax function
            tea_acc_softmax = apply_softmax(tea_acc_main)

        if params.strategy == "weighted":
            # mean wer for ce
            tea_wer_mean = wer_tea.mean(1)
            tea_acc_ce_main = 100 - tea_wer_mean

            # normalise weights via Softmax function
            tea_acc_ce_softmax = apply_softmax(tea_acc_ce_main)

        # kd loss
        ctc_loss_list = None
        ce_loss_list = None
        for tea_num in range(params.num_tea):
            # ctc
            p_ctc_tea_one = p_ctc_tea[tea_num]
            # calculate CTC distillation loss of one teacher
            loss_ctc_one = params.ctc_cost_kd(p_ctc, p_ctc_tea_one, wav_lens)
            loss_ctc_one = torch.unsqueeze(loss_ctc_one, 0)
            if tea_num == 0:
                ctc_loss_list = loss_ctc_one
            else:
                ctc_loss_list = torch.cat([ctc_loss_list, loss_ctc_one])

            # ce
            p_seq_tea_one = p_seq_tea[tea_num]
            # calculate CE distillation loss of one teacher
            loss_seq_one = params.seq_cost_kd(p_seq, p_seq_tea_one, rel_length)
            loss_seq_one = torch.unsqueeze(loss_seq_one, 0)
            if tea_num == 0:
                ce_loss_list = loss_seq_one
            else:
                ce_loss_list = torch.cat([ce_loss_list, loss_seq_one])

        # kd loss
        if params.strategy == "average":
            # get average value of losses from all teachers (CTC and CE loss)
            ctc_loss_kd = ctc_loss_list.mean(0)
            seq2seq_loss_kd = ce_loss_list.mean(0)
        else:
            # assign weights to different teachers (CTC loss)
            ctc_loss_kd = (tea_acc_softmax * ctc_loss_list).sum(0)
            if params.strategy == "best":
                # only use the best teacher to compute CE loss
                seq2seq_loss_kd = params.seq_cost_kd(p_seq, tea_seq2seq_pout,
                                                     rel_length)
            if params.strategy == "weighted":
                # assign weights to different teachers (CE loss)
                seq2seq_loss_kd = (tea_acc_ce_softmax * ce_loss_list).sum(0)

        # total loss
        # combine normal supervised training
        loss_ctc = (params.Temperature * params.Temperature * params.alpha *
                    ctc_loss_kd + (1 - params.alpha) * loss_ctc_nor)
        loss_seq = (params.Temperature * params.Temperature * params.alpha *
                    seq2seq_loss_kd + (1 - params.alpha) * loss_seq_nor)

        loss = params.ctc_weight * loss_ctc + (1 -
                                               params.ctc_weight) * loss_seq

        stats = {}
        if stage != "train":
            ind2lab = params.train_loader.label_dict["phn"]["index2lab"]
            sequence = convert_index_to_lab(hyps, ind2lab)
            phns = undo_padding(phns, phn_lens)
            phns = convert_index_to_lab(phns, ind2lab)
            per_stats = edit_distance.wer_details_for_batch(
                ids, phns, sequence, compute_alignments=True)
            stats["PER"] = per_stats
        return loss, stats
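
A toy illustration of the "weighted" strategy above: each teacher's weight is a softmax over its accuracy proxy (100 minus its mean WER on the batch), so lower-WER teachers dominate the distillation loss. The numbers below are made up.

import torch

wer_ctc_tea = torch.tensor([[20.0, 30.0],   # teacher 0: per-utterance WER
                            [10.0, 12.0]])  # teacher 1: per-utterance WER
tea_acc_main = 100 - wer_ctc_tea.mean(1)    # tensor([75., 89.])
tea_acc_softmax = torch.softmax(tea_acc_main, dim=0)
print(tea_acc_softmax)  # teacher 1 receives nearly all of the weight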