Example #1
def _compute_mmi_loss_exact_non_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    This version is more readable, though it invokes k2.intersect_dense twice.

    Note:
      It uses less memory than the optimized version, at the cost of speed
      (it is slower).
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=10.0)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
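
For context, a minimal sketch of how this helper might be called from a training step. The names `model`, `feature`, `texts`, `supervision_segments` and `graph_compiler` are assumed to exist and are not defined in the example above; the shapes follow the conventions used throughout these examples.

# Hypothetical call site (not part of the example above).
nnet_output = model(feature)                 # acoustic model output, [N, C, T]
nnet_output = nnet_output.permute(0, 2, 1)   # k2.DenseFsaVec expects [N, T, C]

mmi_loss, num_frames, all_frames = _compute_mmi_loss_exact_non_optimized(
    nnet_output=nnet_output,
    texts=texts,
    supervision_segments=supervision_segments,
    graph_compiler=graph_compiler,
    den_scale=1.0)

# MMI is an objective to maximize, so the training loss is its negation.
(-mmi_loss).backward()
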
Example #2
def _compute_mmi_loss_pruned(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    `pruned` means it uses k2.intersect_dense_pruned to compute
    the denominator lattice.

    Note:
      It uses the least memory of the three variants, but the loss is not
      exact due to pruning.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)

    # the values for search_beam/output_beam/min_active_states/max_active_states
    # are not tuned. You may want to tune them.
    den_lats = k2.intersect_dense_pruned(den_graphs,
                                         dense_fsa_vec,
                                         search_beam=20.0,
                                         output_beam=7.0,
                                         min_active_states=30,
                                         max_active_states=10000)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
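
The three variants shown in these examples trade memory for speed and exactness. Purely as an illustration (this dispatcher is hypothetical and not part of the code above), selecting between them at the call site could look like this:

def compute_mmi_loss(loss_type: str, **kwargs):
    # Hypothetical helper: dispatches to one of the three variants by name.
    # 'exact_optimized'     - fastest, highest memory use
    # 'exact_non_optimized' - slower, lower memory use
    # 'pruned'              - lowest memory use, but the loss is approximate
    if loss_type == 'exact_optimized':
        return _compute_mmi_loss_exact_optimized(**kwargs)
    if loss_type == 'exact_non_optimized':
        # This variant does not take P, so drop it if it was passed.
        kwargs.pop('P', None)
        return _compute_mmi_loss_exact_non_optimized(**kwargs)
    if loss_type == 'pruned':
        return _compute_mmi_loss_pruned(**kwargs)
    raise ValueError(f'Unknown MMI loss type: {loss_type}')
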
Example #3
def _compute_mmi_loss_exact_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    The function name contains `exact`, which means it uses a version of
    intersection without pruning.

    `optimized` in the function name means this function is optimized in
    that it calls k2.intersect_dense only once.

    Note:
      It is faster than the non-optimized version, at the cost of using
      more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C].
      texts:
        The transcripts. Each element is a string of space-separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
      graph_compiler:
        Used to build num_graphs and den_graphs.
      P:
        Represents a bigram FSA.
      den_scale:
        The scale applied to the denominator tot_scores.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # The aux_labels of num_graphs is a k2.RaggedInt,
    # but for den_graphs it is a torch.Tensor.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.cat() later.
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # The motivation to concatenate num_graphs and den_graphs
    # is to reduce the number of calls to k2.intersect_dense.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.
    #
    # The following code computes a_to_b_map

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
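
To make the index bookkeeping above concrete, here is a small self-contained sketch (plain PyTorch, no k2 required) that reproduces the interleaving and the a_to_b_map for num_fsas = 3:

import torch

num_fsas = 3

# Positions of the numerator graphs in the concatenated FsaVec: [0, 1, 2]
num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

# The single denominator graph sits at position num_fsas; replicate it: [3, 3, 3]
den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

# Interleave so that the num and den graphs of each utterance are adjacent.
num_den_graphs_indexes = torch.stack(
    [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1)
print(num_den_graphs_indexes.tolist())  # [0, 3, 1, 3, 2, 3]

# Graphs 2*i and 2*i + 1 are both intersected with dense FSA i,
# so a_to_b_map repeats each utterance index twice.
a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).repeat(2, 1).t().reshape(-1)
print(a_to_b_map.tolist())  # [0, 0, 1, 1, 2, 2]
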
Example #4
class ASR:
    """
    This class is a high-level wrapper for K2 acoustic models that simplifies inference:
    reading models, computing posteriors, decoding, alignments, etc.

    Currently it only works with the Conformer model and a very specific HMM topology.
    It could be the basis for a more generic entry point to Snow(Ice?)fall.
    """
    def __init__(
        self,
        lang_dir: Pathlike,
        scripted_model_path: Optional[Pathlike] = None,
        model_dir: Optional[Pathlike] = None,
        average_epochs: Sequence[int] = (7, 8, 9),
        device: Union[str, torch.device] = 'cpu',
        sampling_rate: int = 16000,
    ):
        if isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        self.sampling_rate = sampling_rate
        self.extractor = Fbank(FbankConfig(num_mel_bins=80))
        self.lexicon = Lexicon(lang_dir)
        phone_ids = self.lexicon.phone_symbols()
        self.P = create_bigram_phone_lm(phone_ids)

        if model_dir is not None:
            # Read model from regular checkpoints, assume it's a Conformer
            self.model = Conformer(num_features=80,
                                   num_classes=len(phone_ids) + 1,
                                   num_decoder_layers=0)
            self.P.scores = torch.zeros_like(self.P.scores)
            self.model.P_scores = torch.nn.Parameter(self.P.scores.clone(),
                                                     requires_grad=False)
            average_checkpoint(filenames=[
                model_dir / f'epoch-{n}.pt' for n in average_epochs
            ],
                               model=self.model)
        elif scripted_model_path is not None:
            # Read model from a serialized TorchScript module, no assumptions needed
            self.model = torch.jit.load(scripted_model_path)
        else:
            raise ValueError(
                "One of scripted_model_path or model_dir needs to be provided."
            )

        # Freeze the params by default.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.compiler = MmiTrainingGraphCompiler(lexicon=self.lexicon,
                                                 device=self.device)
        self.HLG = k2.Fsa.from_dict(torch.load(lang_dir / 'HLG.pt')).to(
            self.device)

    def compute_features(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[0].sampling_rate == self.sampling_rate, \
            f'{cuts[0].sampling_rate} != {self.sampling_rate}'
        otf = OnTheFlyFeatures(self.extractor)
        # feats: (batch, seq_len, n_feats)
        feats, _ = otf(cuts)
        return feats

    def compute_posteriors(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Run the forward pass of the acoustic model and return a tensor representing a batch of phone posteriorgrams.
        """
        # Extract feats
        # (batch, seq_len, num_feats)
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[0].sampling_rate == self.sampling_rate, \
            f'{cuts[0].sampling_rate} != {self.sampling_rate}'
        otf = OnTheFlyFeatures(self.extractor)
        # feats: (batch, seq_len, n_feats)
        feats, _ = otf(cuts)
        # feats: (batch, n_feats, seq_len)
        feats = feats.permute(0, 2, 1)

        # Compute AM posteriors
        # posteriors: (batch, n_phones, ~seq_len / 4)
        posteriors, _, _ = self.model(feats)
        # returns: (batch, ~seq_len / 4, n_phones)
        return posteriors.permute(0, 2, 1)

    def decode(
            self, cuts: Union[AnyCut,
                              CutSet]) -> List[Tuple[List[str], List[str]]]:
        """
        Perform decoding with an n-gram language model (HLG graph).
        Doesn't support rescoring at this time.
        """
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        word_results = []
        # Hacky way to get batch quickly... we may need to improve on this.
        batch = K2SpeechRecognitionDataset(cuts,
                                           input_strategy=OnTheFlyFeatures(
                                               self.extractor),
                                           check_inputs=False)[list(cuts.ids)]
        features = batch['inputs'].permute(0, 2, 1).to(
            self.device)  # (B, T, F) -> (B, F, T)
        supervision_segments, texts = encode_supervisions(
            batch['supervisions'])

        # Forward pass through the acoustic model
        posteriors, _, _ = self.model(features)
        posteriors = posteriors.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

        # Wrapping into k2 "dense FSA" (representing PPG as a dense graph)
        dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

        # The actual decoding starts here:
        # First, we intersect the HLG and the PPG
        # with default pruning/beam search params from snowfall
        # The result is a batch of graphs (lattices)
        lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8,
                                             30, 10000)
        # ... then we find the shortest paths in the lattices ...
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        # ... and convert them to words with a convenience wrapper from snowfall
        hyps = get_texts(best_paths, torch.arange(len(texts)))

        # Here we read out the words from the best path graphs
        for i in range(len(texts)):
            hyp_words = [self.lexicon.words.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            word_results.append((ref_words, hyp_words))
        return word_results

    def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Perform forced alignment and return a tensor that represents a batch of frame-level alignments:
        >>> alignments = torch.tensor([
        ...     [0, 0, 0, 1, 57, 57, 35, 35, 35, ...],
        ...     [...],
        ...     ...
        ... ])

        :return: an int32 tensor with shape ``(batch_size, num_frames)``.
        """
        # Extract feats
        # (batch, seq_len, num_feats)
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[0].sampling_rate == self.sampling_rate, \
            f'{cuts[0].sampling_rate} != {self.sampling_rate}'

        cuts = cuts.map_supervisions(self.normalize_text)

        otf = OnTheFlyFeatures(self.extractor)
        feats, _ = otf(cuts)
        feats = feats.permute(0, 2, 1)
        texts = [' '.join(s.text for s in cut.supervisions) for cut in cuts]

        # Compute AM posteriors
        # (batch, seq_len ~/ 4, num_phones)
        posteriors, _, _ = self.model(feats)
        # Note: we are using "dummy" supervisions so that the aligner also considers
        # the padding area. We can adjust that behaviour if needed by passing actual
        # supervision segments, but then we will have a ragged tensor (will need to
        # pad the alignments themselves).
        sups = self.dummy_supervisions(feats)
        posteriors_fsa = k2.DenseFsaVec(posteriors.permute(0, 2, 1), sups)

        # Intersection with ground truth transcript graphs
        num, den = self.compiler.compile(texts, self.P)
        alignment = k2.intersect_dense(num, posteriors_fsa, output_beam=10.0)
        best_path = k2.shortest_path(alignment, use_double_scores=True)

        # Retrieve sequences of phone IDs per frame
        # (batch, seq_len ~/ 4) -- dtype int32 (num phone labels)
        frame_labels = torch.stack(
            [best_path[i].labels[:-1] for i in range(best_path.shape[0])])
        return frame_labels

    def align_ctm(self, cuts: Union[CutSet,
                                    AnyCut]) -> List[List[AlignmentItem]]:
        """
        Perform forced alignment and parse the phones into a CTM-like format:
            >>> [[0.0, 0.12, 'SIL'], [0.12, 0.2, 'AH0'], ...]
        """
        # TODO: I am not sure that this method is extracting the alignment 100% correctly:
        #       need to revise...
        # TODO: when K2/Snowfall has a standard way of indicating what is silence,
        #       or we update the model, update the constants below.
        EPS = 0
        SIL = 1
        non_speech = {EPS, SIL}

        def to_s(n: int) -> float:
            FRAME_SHIFT = 0.04  # 0.01 * 4 subsampling
            return round(n * FRAME_SHIFT, ndigits=3)

        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])

        # Uppercase and remove punctuation
        cuts = cuts.map_supervisions(self.normalize_text)
        alignments = self.align(cuts).tolist()

        ctm_alis = []
        for cut, alignment in zip(cuts, alignments):
            # First we determine the silence regions at the beginning and the end:
            # we assume that every SIL and <eps> before the first phone, and after the last phone,
            # are representing silence.
            first_speech_idx = [
                idx for idx, s in enumerate(alignment) if s not in non_speech
            ][0]
            last_speech_idx = [
                idx for idx, s in reversed(list(enumerate(alignment)))
                if s not in non_speech
            ][0]
            speech_ali = alignment[first_speech_idx:last_speech_idx]
            ctm_ali = [
                AlignmentItem(start=0.0,
                              duration=to_s(first_speech_idx),
                              symbol=self.lexicon.phones[SIL])
            ]

            # Then, we iterate over the speech region: since the K2 model uses 2-state HMM
            # topology that allows blank (<eps>) to follow a phone symbol, we treat <eps>
            # as continuation of the "previous" phone.
            # TODO: I think this implementation is wrong in that it merges repeating phones...
            #       Will fix.
            # TODO: I think it could be simplified by using some smart semi-ring and FSA operations...
            start = first_speech_idx
            prev_s = speech_ali[0]
            curr_s = speech_ali[0]
            cntr = 1
            for s in speech_ali[1:]:
                curr_s = s if s != EPS else curr_s
                if curr_s != prev_s:
                    ctm_ali.append(
                        AlignmentItem(start=to_s(start),
                                      duration=to_s(cntr),
                                      symbol=self.lexicon.phones[prev_s]))
                    start = start + cntr
                    prev_s = curr_s
                    cntr = 1
                else:
                    cntr += 1
            if cntr:
                ctm_ali.append(
                    AlignmentItem(start=to_s(start),
                                  duration=to_s(cntr),
                                  symbol=self.lexicon.phones[prev_s]))

            speech_end_timestamp = to_s(last_speech_idx)
            if speech_end_timestamp > cut.duration:
                logging.warning(
                    f"speech_end_timestamp ({speech_end_timestamp}) is greater than "
                    f"cut.duration ({cut.duration}). Skipping cut {cut.id}"
                )
                ctm_alis.append(None)
                continue

            ctm_ali.append(
                AlignmentItem(start=speech_end_timestamp,
                              duration=round(cut.duration -
                                             speech_end_timestamp,
                                             ndigits=8),
                              symbol=self.lexicon.phones[SIL]))
            ctm_alis.append(ctm_ali)

        return ctm_alis

    def plot_alignments(self, cut: AnyCut):
        import matplotlib.pyplot as plt
        feats = self.compute_features(cut)
        phone_ids = self.align(cut)
        fig, axes = plt.subplots(2,
                                 squeeze=True,
                                 sharey=True,
                                 figsize=(10, 14))
        axes[0].imshow(np.flipud(feats[0].T))
        axes[1].imshow(
            torch.nn.functional.one_hot(
                phone_ids.repeat_interleave(4).to(torch.int64)).T)
        return fig, axes

    def plot_posteriors(self, cut: AnyCut):
        import matplotlib.pyplot as plt
        feats = self.compute_features(cut)
        posteriors = self.compute_posteriors(cut)
        fig, axes = plt.subplots(2,
                                 squeeze=True,
                                 sharey=True,
                                 figsize=(10, 14))
        axes[0].imshow(np.flipud(feats[0].T))
        axes[1].imshow(posteriors[0].exp().repeat_interleave(4, 1))
        return fig, axes

    @staticmethod
    def dummy_supervisions(feats):
        def size_after_conv(size, num_layers=2):
            for i in range(num_layers):
                size = (size - 1) // 2
            return size

        return torch.tensor([[
            i,
            size_after_conv(2, num_layers=2),
            size_after_conv(feats.shape[2] - 2, num_layers=2)
        ] for i in range(feats.size(0))],
                            dtype=torch.int32).clamp(min=0)

    @staticmethod
    def normalize_text(supervision):
        text = re.sub(r'[^\w\s]', '', supervision.text.upper())
        return fastcopy(supervision, text=text)
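
A hedged usage sketch for the class above. The paths are placeholders, and `cuts` is assumed to be a lhotse CutSet (or a single cut) prepared elsewhere; only the ASR methods shown earlier are used.

# Hypothetical usage; adjust paths to your experiment directory.
asr = ASR(lang_dir='exp/lang_nosp',
          model_dir='exp/conformer_mmi',
          device='cuda' if torch.cuda.is_available() else 'cpu')

results = asr.decode(cuts)        # list of (ref_words, hyp_words) pairs
for ref, hyp in results:
    print('REF:', ' '.join(ref))
    print('HYP:', ' '.join(map(str, hyp)))

ctm = asr.align_ctm(cuts)         # per-cut lists of AlignmentItem entries
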
Example #5
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    subsampling_factor = model.module.subsampling_factor if isinstance(
        model, DDP) else model.subsampling_factor
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-tot_score).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if (tb_writer is not None and global_batch_idx_train is not None
                and global_batch_idx_train % 200 == 0):
            # Once in a while we will perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
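
A minimal sketch of how this function might be driven from a training loop, assuming `model`, `P`, `graph_compiler`, `optimizer`, `device` and a `train_loader` yielding batches in the expected format are constructed elsewhere:

model.train()
for batch_idx, batch in enumerate(train_loader):
    objf, frames, _ = get_objf(batch=batch,
                               model=model,
                               P=P,
                               device=device,
                               graph_compiler=graph_compiler,
                               is_training=True,
                               optimizer=optimizer,
                               global_batch_idx_train=batch_idx)
    if batch_idx % 10 == 0:
        # objf is the negated MMI score summed over the batch.
        print(f'batch {batch_idx}: objf per frame = {objf / frames:.4f}')
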
Example #6
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
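
For completeness, a hedged sketch of using this simpler variant for validation: with is_training=False no optimizer is needed and no gradients are computed. `valid_loader`, `model`, `P`, `graph_compiler` and `device` are assumed to exist.

model.eval()
total_objf, total_frames = 0.0, 0.0
for batch in valid_loader:
    objf, frames, _ = get_objf(batch=batch,
                               model=model,
                               P=P,
                               device=device,
                               graph_compiler=graph_compiler,
                               is_training=False)
    total_objf += objf
    total_frames += frames
print(f'validation objf per frame: {total_objf / total_frames:.4f}')
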