Example #1
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph`.
        '''
        assert P.is_cpu()

        ctc_topo_P = k2.intersect(self.ctc_topo, P).invert_()
        ctc_topo_P = k2.connect(ctc_topo_P)

        num_graphs = k2.create_fsa_vec(
            [self.compile_one_and_cache(text) for text in texts])

        num = k2.compose(ctc_topo_P, num_graphs)
        num = k2.connect(num)
        num = k2.arc_sort(num)

        den = k2.create_fsa_vec([ctc_topo_P.detach()] * len(texts))

        return num, den
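A hedged usage sketch of `compile` (the names `compiler` and `phone_ids` are assumptions, and `create_bigram_phone_lm` is the helper referenced in the docstring, not defined here):

import torch

P = create_bigram_phone_lm(phone_ids)     # hypothetical helper and arguments
P.scores = torch.zeros_like(P.scores)     # e.g. start from flat scores
num, den = compiler.compile(['hello world', 'foo bar'], P)
assert num.shape[0] == den.shape[0] == 2  # one graph pair per transcript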
Example #2
    def __init__(self,
                 lexicon: Lexicon,
                 P: k2.Fsa,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          lexicon:
            A Lexicon object. Its `L_inv` has words as labels and phones as
            aux_labels; it also holds the phone and word symbol tables.
          P:
            A phone bigram LM if the pronunciations in the lexicon are in phones;
            a word piece bigram if the pronunciations in the lexicon are word pieces.
          device:
            The target device that all FSAs should be moved to.
          oov:
            Out of vocabulary word.
        '''
        self.lexicon = lexicon
        L_inv = self.lexicon.L_inv.to(device)
        P = P.to(device)

        # Arc-sort only if it is not already arc-sorted.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert L_inv.requires_grad is False

        assert oov in self.lexicon.words

        self.L_inv = L_inv
        self.oov_id = self.lexicon.words[oov]
        self.oov = oov
        self.device = device

        phone_symbols = get_phone_symbols(self.lexicon.phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
        assert ctc_topo.requires_grad is False

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert())

        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
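The heart of the constructor is composing the inverted CTC topology with P. A self-contained toy sketch of the same pattern, assuming a k2 version that provides `k2.ctc_topo` (standing in here for the repo's `build_ctc_topo` helper):

import k2

# A toy acceptor over phones {1, 2}, standing in for the bigram LM P.
P = k2.Fsa.from_str('''
0 1 1 0.1
0 1 2 0.2
1 2 -1 0.0
2
''')
ctc_topo = k2.ctc_topo(max_token=2)            # 0 is the blank symbol
ctc_topo_inv = k2.arc_sort(ctc_topo.invert())
P_with_self_loops = k2.add_epsilon_self_loops(k2.arc_sort(P))
ctc_topo_P = k2.intersect(ctc_topo_inv,
                          P_with_self_loops,
                          treat_epsilons_specially=False).invert()
ctc_topo_P = k2.arc_sort(ctc_topo_P)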
Example #3
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['features']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert LG.is_cuda()
        assert LG.device == nnet_output.device, \
            f"Check failed: LG.device ({LG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(LG, dense_fsa_vec, 2000.0, 20.0,
                                             30, 300)
        best_paths = k2.shortest_path(lattices, use_float_scores=True)
        best_paths = best_paths.to('cpu')
        assert best_paths.shape[0] == len(texts)

        for i in range(len(texts)):
            hyp_words = [
                symbols.get(x) for x in best_paths[i].aux_labels if x > 0
            ]
            results.append((texts[i].split(' '), hyp_words))

        if batch_idx % 10 == 0:
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, len(dataloader),
                float(batch_idx) / len(dataloader) * 100))

    return results
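A self-contained sketch of the pruned-intersection core of the loop above, with toy sizes and a trivial CTC topology standing in for LG (the beam values here are illustrative, not the ones used above):

import k2
import torch

log_probs = torch.randn(1, 20, 3).log_softmax(-1)    # (N, T, C)
# One supervision: sequence 0, starting at frame 0, 20 frames long.
supervision_segments = torch.tensor([[0, 0, 20]], dtype=torch.int32)
dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

graph = k2.arc_sort(k2.ctc_topo(max_token=2))        # stand-in for LG
lattices = k2.intersect_dense_pruned(graph, dense_fsa_vec,
                                     search_beam=20.0,
                                     output_beam=8.0,
                                     min_active_states=30,
                                     max_active_states=10000)
best_paths = k2.shortest_path(lattices, use_double_scores=True)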
Example #4
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa ``G``.
    Involves arc sorting, intersection, determinization, removal of disambiguation symbols
    and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
            and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    Returns:
        An ``Fsa`` representing the decoding graph LG.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
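A usage sketch under assumed on-disk artifacts (the paths are hypothetical, and `find_first_disambig_symbol` is the repo helper used the same way in Example #7):

import k2
import torch

L = k2.Fsa.from_dict(torch.load('data/lang_nosp/L.pt'))         # hypothetical path
G = k2.Fsa.from_dict(torch.load('data/lang_nosp/G.pt'))         # hypothetical path
phones = k2.SymbolTable.from_file('data/lang_nosp/phones.txt')  # hypothetical path
words = k2.SymbolTable.from_file('data/lang_nosp/words.txt')    # hypothetical path
LG = compile_LG(L, G,
                labels_disambig_id_start=find_first_disambig_symbol(phones),
                aux_labels_disambig_id_start=find_first_disambig_symbol(words))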
Example #5
    def format_output(self, num_frames: List[int]) -> Fsa:
        """
        Generate the lattice Fsa from what has been decoded so far.

        Note:
          The attributes of the generated lattice are a union of the
          attributes of all the decoding graphs. For example, if `self`
          contains three individual streams and each stream has its own
          decoding graph, where graph[0] has attributes attr1, attr2;
          graph[1] has attributes attr1, attr3; and graph[2] has attributes
          attr3, attr4; then the generated lattice has attributes attr1,
          attr2, attr3 and attr4.

        Args:
          num_frames:
            A List containing the number of frames we want to gather for each
            stream (i.e., no more than the number of frames received so far
            for the corresponding stream). It MUST satisfy
            `len(num_frames) == self.num_streams`.
        Returns:
          Return the lattice Fsa with all the attributes propagated.
          The returned Fsa has 3 axes with `fsa.dim0 == self.num_streams`.
        """
        assert len(num_frames) == self.num_streams

        ragged_arcs, out_map = self.streams.format_output(num_frames)
        fsa = Fsa(ragged_arcs)

        # propagate attributes
        tensor_attr_info = dict()
        # gather the attribute info of all the decoding graphs
        for i in range(self.num_streams):
            src = self.src_streams[i].fsa
            for name, value in src.named_tensor_attr(include_scores=False):
                if name not in tensor_attr_info:
                    filler = 0
                    if isinstance(value, Tensor):
                        filler = float(src.get_filler(name))
                        dtype = value.dtype
                        tensor_type = "Tensor"
                    else:
                        assert isinstance(value, k2.RaggedTensor)
                        # Only integer-type ragged attributes are supported now
                        assert value.dtype == torch.int32
                        assert value.num_axes == 2
                        dtype = torch.int32
                        tensor_type = "RaggedTensor"
                    tensor_attr_info[name] = {
                        "filler": filler,
                        "dtype": dtype,
                        "tensor_type": tensor_type,
                    }
        # combine the attributes propagating from different decoding graphs
        for name, info in tensor_attr_info.items():
            values = list()
            start = 0
            for i in range(self.num_streams):
                src = self.src_streams[i].fsa
                device = self.device
                num_arcs = fsa[i].num_arcs
                arc_map = out_map[start:start + num_arcs]
                start = start + num_arcs
                if hasattr(src, name):
                    value = getattr(src, name)
                    if info["tensor_type"] == "Tensor":
                        assert isinstance(value, Tensor)
                        new_value = index_select(value,
                                                 arc_map,
                                                 default_value=info["filler"])
                    else:
                        assert isinstance(value, RaggedTensor)
                        # Only integer-type ragged attributes are supported now
                        assert value.num_axes == 2
                        assert value.dtype == torch.int32
                        new_value, _ = value.index(arc_map,
                                                   axis=0,
                                                   need_value_indexes=False)
                else:
                    if info["tensor_type"] == "Tensor":
                        # fill with the filler value
                        new_value = torch.tensor(
                            [info["filler"]] * num_arcs,
                            dtype=info["dtype"],
                            device=device,
                        )
                    else:
                        # fill with empty RaggedTensor
                        new_value = RaggedTensor(
                            torch.empty(
                                (num_arcs, 0),
                                dtype=info["dtype"],
                                device=device,
                            ))
                values.append(new_value)
            if info["tensor_type"] == "Tensor":
                new_value = torch.cat(values)
            else:
                new_value = k2.ragged.cat(values, axis=0)
            setattr(fsa, name, new_value)

        # set non_tensor_attrs
        for i in range(self.num_streams):
            src = self.src_streams[i].fsa
            for name, value in src.named_non_tensor_attr():
                setattr(fsa, name, value)

        return fsa
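The per-arc gathering above hinges on `k2.index_select`, which maps each output arc back to a source arc and substitutes a filler wherever the map entry is -1. A toy sketch of that primitive:

import torch
from k2 import index_select

src = torch.tensor([10.0, 20.0, 30.0])                 # per-arc attribute
arc_map = torch.tensor([2, -1, 0], dtype=torch.int32)  # -1: no source arc
out = index_select(src, arc_map, default_value=0.0)
# out is tensor([30., 0., 10.])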
Example #6
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiMbrTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int):
    total_loss, total_mmi_loss, total_mbr_loss, total_frames, total_all_frames = 0., 0., 0., 0., 0.
    valid_average_loss = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.requires_grad is True

        curr_batch_mmi_loss, curr_batch_mbr_loss, curr_batch_frames, curr_batch_all_frames = get_loss(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            optimizer=optimizer)

        total_mmi_loss += curr_batch_mmi_loss
        total_mbr_loss += curr_batch_mbr_loss
        curr_batch_loss = curr_batch_mmi_loss + curr_batch_mbr_loss
        total_loss += curr_batch_loss
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info('batch {}, epoch {}/{} '
                         'global average loss: {:.6f}, '
                         'global average mmi loss: {:.6f}, '
                         'global average mbr loss: {:.6f} over {} '
                         'frames ({:.1f}% kept), '
                         'current batch average loss: {:.6f}, '
                         'current batch average mmi loss: {:.6f}, '
                         'current batch average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept) '
                         'avg time waiting for batch {:.3f}s'.format(
                             batch_idx, current_epoch, num_epochs, total_loss /
                             total_frames, total_mmi_loss / total_frames,
                             total_mbr_loss / total_frames, total_frames,
                             100.0 * total_frames / total_all_frames,
                             curr_batch_loss / (curr_batch_frames + 0.001),
                             curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                             curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                             curr_batch_frames,
                             100.0 * curr_batch_frames / curr_batch_all_frames,
                             time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_loss',
                                 total_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_average_mmi_loss',
                                 total_mmi_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_average_mbr_loss',
                                 total_mbr_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_loss',
                                 curr_batch_loss / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

            tb_writer.add_scalar(
                'train/current_batch_average_mmi_loss',
                curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)

            tb_writer.add_scalar(
                'train/current_batch_average_mbr_loss',
                curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 3000 == 0:
            (total_valid_loss, total_valid_mmi_loss, total_valid_mbr_loss,
             total_valid_frames, total_valid_all_frames) = get_validation_loss(
                 dataloader=valid_dataloader,
                 model=model,
                 P=P,
                 device=device,
                 graph_compiler=graph_compiler)
            valid_average_loss = total_valid_loss / total_valid_frames
            model.train()
            logging.info('Validation average loss: {:.6f}, '
                         'Validation average mmi loss: {:.6f}, '
                         'Validation average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept)'.format(
                             total_valid_loss / total_valid_frames,
                             total_valid_mmi_loss / total_valid_frames,
                             total_valid_mbr_loss / total_valid_frames,
                             total_valid_frames, 100.0 * total_valid_frames /
                             total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_loss',
                                 total_valid_loss / total_valid_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_valid_average_mmi_loss',
                                 total_valid_mmi_loss / total_valid_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_valid_average_mbr_loss',
                                 total_valid_mbr_loss / total_valid_frames,
                                 global_batch_idx_train)

        prev_timestamp = datetime.now()
    return total_loss / total_frames, valid_average_loss, global_batch_idx_train
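The `P.set_scores_stochastic_(model.P_scores)` call at the top of the loop is what makes the phone LM trainable: it writes a per-state normalized version of the given scores into P while keeping autograd connectivity. A hypothetical minimal sketch (assumes `P` is a k2.Fsa):

import torch

# One trainable entry per arc of P (hypothetical setup).
P_scores = torch.nn.Parameter(torch.zeros(P.num_arcs))
P.set_scores_stochastic_(P_scores)
assert P.requires_grad is True  # gradients flow back into P_scores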
Example #7
    def __init__(self,
                 L_inv: k2.Fsa,
                 L_disambig: k2.Fsa,
                 G: k2.Fsa,
                 phones: k2.SymbolTable,
                 words: k2.SymbolTable,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          L_inv:
            Its labels are words, while its aux_labels are phones.
          L_disambig:
            L with disambig symbols. Its labels are phones and aux_labels
            are words.
          G:
            The language model.
          phones:
            The phone symbol table.
          words:
            The word symbol table.
          device:
            The target device that all FSAs should be moved to.
          oov:
            Out of vocabulary word.
        '''

        L_inv = L_inv.to(device)
        G = G.to(device)

        # Arc-sort only if not already arc-sorted.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        if G.properties & k2.fsa_properties.ARC_SORTED == 0:
            G = k2.arc_sort(G)

        assert L_inv.requires_grad is False
        assert G.requires_grad is False

        assert oov in words

        L = L_inv.invert()
        L = k2.arc_sort(L)

        self.L_inv = L_inv
        self.L = L
        self.phones = phones
        self.words = words
        self.device = device
        self.oov_id = self.words[oov]

        phone_symbols = get_phone_symbols(phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = k2.arc_sort(
            build_ctc_topo(phone_symbols_with_blank).to(device))
        assert ctc_topo.requires_grad is False

        self.ctc_topo = ctc_topo
        self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert())

        lang_dir = Path('data/lang_nosp')
        if not (lang_dir / 'HLG_uni.pt').exists():
            logging.info("Composing (ctc_topo, L_disambig, G)")
            first_phone_disambig_id = find_first_disambig_symbol(phones)
            first_word_disambig_id = find_first_disambig_symbol(words)
            # decoding_graph is the result of composing (ctc_topo, L_disambig, G)
            decoding_graph = compile_HLG(
                L=L_disambig.to('cpu'),
                G=G.to('cpu'),
                H=ctc_topo.to('cpu'),
                labels_disambig_id_start=first_phone_disambig_id,
                aux_labels_disambig_id_start=first_word_disambig_id)
            torch.save(decoding_graph.as_dict(), lang_dir / 'HLG_uni.pt')
        else:
            logging.info("Loading pre-compiled HLG")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lang_dir / 'HLG_uni.pt'))

        assert hasattr(decoding_graph, 'phones')

        self.decoding_graph = decoding_graph.to(device)
Example #8
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, ali_model: Optional[AcousticModel],
                    P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, accum_grad: int,
                    den_scale: float, att_rate: float, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int, world_size: int,
                    scaler: GradScaler):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader
        valid_dataloader: Validation dataloader
        model: Acoustic model to be trained
        ali_model: Optional auxiliary acoustic model (may be None)
        P: An FSA representing the bigram phone LM
        device: Training device, torch.device("cpu") or torch.device("cuda", device_id)
        graph_compiler: MMI training graph compiler
        optimizer: Training optimizer
        accum_grad: Number of batches over which gradients are accumulated
        den_scale: Denominator scale in the MMI loss
        att_rate: Attention loss rate; final loss is att_rate * att_loss + (1 - att_rate) * other_loss
        current_epoch: Current training epoch, for logging only
        tb_writer: Tensorboard SummaryWriter
        num_epochs: Total number of training epochs, for logging only
        global_batch_idx_train: Global training batch index before this epoch, for logging only
        world_size: Number of distributed training processes
        scaler: GradScaler for automatic mixed precision training

    Returns:
        A tuple of 3 scalars: (total_objf / total_frames, valid_average_objf, global_batch_idx_train)
        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if forward_count == 1 or accum_grad == 1:
            P.set_scores_stochastic_(model.module.P_scores)
            assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer,
            scaler=scaler)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)

                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                ali_model=ali_model,
                P=P,
                device=device,
                graph_compiler=graph_compiler,
                scaler=scaler)
            if world_size > 1:
                s = torch.tensor([
                    total_valid_objf, total_valid_frames,
                    total_valid_all_frames
                ]).to(device)

                dist.all_reduce(s, op=dist.ReduceOp.SUM)
                total_valid_objf, total_valid_frames, total_valid_all_frames = s.cpu(
                ).tolist()

            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.module.write_tensorboard_diagnostics(
                    tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
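The `forward_count`/`is_update` bookkeeping above implements gradient accumulation: gradients from `accum_grad` consecutive batches are combined before a single optimizer step (the actual backward/step logic lives in `get_objf`, which is not shown). A generic sketch of the pattern, with `loss_fn` hypothetical:

accum_grad = 4
forward_count = 0
for batch in dataloader:
    forward_count += 1
    is_update = forward_count == accum_grad
    if is_update:
        forward_count = 0
    loss = loss_fn(batch) / accum_grad  # scale so summed grads average out
    loss.backward()                     # grads accumulate across batches
    if is_update:
        optimizer.step()
        optimizer.zero_grad()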
Example #9
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
    Returns:
      An FsaVec containing the best decoding path for each sequence
      in the input lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.

    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
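The `b_to_a_map` argument tells `k2.intersect_device` which FSA in its first argument each lattice in the second should be intersected with; with a single shared G it is simply all zeros. Sketched in isolation:

import torch

num_seqs = 4  # number of lattices in the batch
# Every lattice pairs with G_with_epsilon_loops[0].
b_to_a_map = torch.zeros(num_seqs, dtype=torch.int32)
# Note: with sorted_match_a=True, the first argument (here G) must be
# arc-sorted.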
Example #10
def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                              ) -> Dict[str, k2.Fsa]:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict of FsaVec, whose keys are lm_scale values and whose values are
      the best decoding paths for each sequence in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ', inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats.to('cpu')).to(device))

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = dict()
    #
    # The following implements
    # scores = (scores - lm_scores)/lm_scale + lm_scores
    #        = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    #
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
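The comment block above rewrites the rescoring formula so that each lm_scale only rescales the saved scores; the identity is easy to check numerically with plain tensors:

import torch

scores = torch.tensor([1.5, -0.3])
lm_scores = torch.tensor([0.5, 0.2])
lm_scale = 0.7
lhs = (scores - lm_scores) / lm_scale + lm_scores
rhs = scores / lm_scale + lm_scores * (1.0 - 1.0 / lm_scale)
assert torch.allclose(lhs, rhs)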
Example #11
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: Optional[SummaryWriter], num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if isinstance(model, DDP):
            P.set_scores_stochastic_(model.module.P_scores)
        else:
            P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu()
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0 and dist.get_rank() == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)
                tb_writer.add_scalar('train/current_batch_average_objf',
                                     curr_batch_objf / (curr_batch_frames + 0.001),
                                     global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 1000 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            if dist.get_rank() == 0:
                logging.info(
                    'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                    .format(
                        valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                (model.module if isinstance(model, DDP) else
                 model).write_tensorboard_diagnostics(
                     tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
Example #12
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph by composing a topology fst ``H`` with a lexicon
    fst ``L`` and a language model fsa ``G``.
    Involves arc sorting, composition, determinization, removal of
    disambiguation symbols and epsilon removal.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
            and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        H:  An ``Fsa`` that represents a specific topology used to convert the network
            outputs to a sequence of phones.
            Typically, it's a CTC topology fst, in which when 0 appears on the left
            side, it represents the blank symbol; when it appears on the right side,
            it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    Returns:
        An ``Fsa`` representing the decoding graph HLG.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consists of two parts:
    #  scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    G.lm_scores = G.scores.clone()

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo LG")
    HLG = k2.compose(H, LG, inner_labels='phones')

    logging.info("Connecting LG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting LG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'LG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )

    return HLG
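The `inner_labels='phones'` argument makes `k2.compose` keep the symbols shared between H and LG as an extra arc attribute on the result; Example #7 relies on this when it asserts `hasattr(decoding_graph, 'phones')` after calling `compile_HLG`. A brief call-site sketch mirroring Example #7:

HLG = compile_HLG(L=L_disambig, G=G, H=ctc_topo,
                  labels_disambig_id_start=first_phone_disambig_id,
                  aux_labels_disambig_id_start=first_word_disambig_id)
assert hasattr(HLG, 'phones')  # attached by k2.compose(..., inner_labels='phones')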
Example #13
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int, global_batch_idx_valid: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    ragged_shape = P.arcs.shape().to(device)
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu()
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, P, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            global_batch_idx_valid += 1
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_objf',
                                 total_valid_objf / total_valid_frames,
                                 global_batch_idx_valid)
        prev_timestamp = datetime.now()
    return total_objf / total_frames
Example #14
def compile_LG(L: Fsa, G: Fsa, ctc_topo_inv: Fsa,
               labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa ``G``,
    then composes the result with the inverted CTC topology ``ctc_topo_inv``.
    Involves arc sorting, intersection, determinization, removal of
    disambiguation symbols and epsilon removal.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
            and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        ctc_topo_inv:  Epsilons are in `aux_labels` and `labels` contain phone IDs.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    Returns:
        An ``Fsa`` representing the decoding graph (ctc_topo composed with LG).
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.debug("Removing epsilons")
    LG = k2.remove_epsilons_iterative_tropical(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)

    logging.debug("Composing")
    LG = k2.compose(ctc_topo_inv, LG)

    logging.debug("Connecting")
    LG = k2.connect(LG)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG