Example 1

    # Excerpt from a k2 test case; self.devices is assumed to be
    # prepared by the enclosing test class (CPU, plus CUDA devices
    # when available).
    def test_create_sparse(self):
        s = '''
            0 1 10 0.1
            0 1 11 0.2
            1 2 20 0.3
            2 3 21 0.4
            2 3 24 0.5
            3 4 -1 0.6
            4
        '''
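        # Each line of s above is "src_state dest_state label score"; the
        # arc labeled -1 enters the final state, and the trailing "4" on
        # its own line is the final state's number.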

        for device in self.devices:
            fsa = k2.Fsa.from_str(s).to(device)
            fsa.phones = torch.tensor([10, 11, 20, 21, 24, -1],
                                      dtype=torch.int32,
                                      device=device)
            fsa.seqframes = torch.tensor([0, 0, 1, 2, 2, 3],
                                         dtype=torch.int32,
                                         device=device)
            fsa.requires_grad_(True)

            tensor = k2.create_sparse(rows=fsa.seqframes,
                                      cols=fsa.phones,
                                      values=fsa.scores,
                                      size=(6, 25),
                                      min_col_index=0)
            assert tensor.device == device
            assert tensor.is_sparse
            assert torch.allclose(tensor._indices()[0],
                                  fsa.seqframes[:-1].to(torch.int64))
            assert torch.allclose(tensor._indices()[1],
                                  fsa.phones[:-1].to(torch.int64))
            assert torch.allclose(tensor._values(), fsa.scores[:-1])
            assert tensor.requires_grad == fsa.requires_grad
            assert tensor.dtype == fsa.scores.dtype
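
The test relies on min_col_index=0 filtering out the final arc, whose
phone label is -1 (hence the [:-1] slices in the assertions). A minimal
standalone sketch of that behavior, with made-up values and using only
the k2.create_sparse signature shown above:

import torch
import k2

rows = torch.tensor([0, 0, 1], dtype=torch.int32)
cols = torch.tensor([10, 11, -1], dtype=torch.int32)
vals = torch.tensor([0.1, 0.2, 0.3])
t = k2.create_sparse(rows=rows, cols=cols, values=vals,
                     size=(2, 12), min_col_index=0)
# The (1, -1) entry is dropped because -1 < min_col_index, so the
# resulting sparse tensor has two nonzeros, at (0, 10) and (0, 11).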
Example 2

from typing import List, Optional, Tuple, Union

import k2
import torch
def create_sparse_wrapped(
    indices: List[torch.Tensor],
    values: torch.Tensor,
    size: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None,
    min_col_index: Optional[int] = None,
) -> torch.Tensor:
    """Wraps up k2.create_sparse to create 2- or 3-dimensional sparse tensors.
    """
    assert size is None or len(indices) == len(size)

    if len(indices) == 2:
        return k2.create_sparse(
            rows=indices[0], cols=indices[1], values=values, size=size, min_col_index=min_col_index,
        )
    elif len(indices) == 3:
        assert indices[0].ndim == indices[1].ndim == indices[2].ndim == 1
        assert indices[0].numel() == indices[1].numel() == indices[2].numel() == values.numel()

        if min_col_index is not None:
            assert isinstance(min_col_index, int)
            kept_indices = indices[-1] >= min_col_index
            indices = [i[kept_indices] for i in indices]
            values = values[kept_indices]
        if size is not None:
            return torch.sparse_coo_tensor(
                torch.stack(indices), values, size=size, device=values.device, requires_grad=values.requires_grad,
            )
        else:
            return torch.sparse_coo_tensor(
                torch.stack(indices), values, device=values.device, requires_grad=values.requires_grad,
            )
    else:
        raise ValueError(f"len(indices) = {len(indices)}")
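
A minimal usage sketch for the 3-D branch (made-up indices and values):

import torch

indices = [torch.tensor([0, 0, 1]),
           torch.tensor([1, 2, 0]),
           torch.tensor([3, -1, 2])]
values = torch.tensor([0.5, 0.25, 0.75])
t = create_sparse_wrapped(indices, values,
                          size=(2, 3, 4), min_col_index=0)
# The entry whose last index is -1 is filtered out, leaving two
# nonzeros at (0, 1, 3) and (1, 0, 2).
print(t._nnz())  # -> 2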
Example 3

from typing import Dict, Optional

import k2
import torch
from torch.nn.utils import clip_grad_value_

# AcousticModel, MmiMbrTrainingGraphCompiler and
# get_tot_objf_and_num_frames are assumed to be provided by the
# enclosing project.
def get_loss(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiMbrTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None,
             den_scale: float = 1.0):
    assert P.device == device
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num_graph, den_graph, decoding_graph = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num_graph, den_graph, decoding_graph = graph_compiler.compile(
                texts, P)

    assert num_graph.requires_grad == is_training
    assert den_graph.requires_grad is False
    assert decoding_graph.requires_grad is False
    assert (len(decoding_graph.shape) == 2
            or decoding_graph.shape == (1, None, None))

    num_graph = num_graph.to(device)
    den_graph = den_graph.to(device)

    decoding_graph = decoding_graph.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device
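    # In k2, the third positional argument to intersect_dense is the
    # output beam (10.0 here); intersect_dense_pruned additionally takes
    # search_beam, output_beam, min_active_states and max_active_states
    # (20.0, 7.0, 30 and 10000 below).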

    num_lats = k2.intersect_dense(num_graph,
                                  dense_fsa_vec,
                                  10.0,
                                  seqframe_idx_name='seqframe_idx')

    mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                         dense_fsa_vec,
                                         20.0,
                                         7.0,
                                         30,
                                         10000,
                                         seqframe_idx_name='seqframe_idx')

    if True:
        # WARNING: the else branch is not working at present (the total loss is not stable)
        den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
    else:
        # in this case, we can remove den_graph
        den_lats = mbr_lats

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    if den_lats is mbr_lats:
        # Some entries in den_tot_scores may be -inf.
        # The corresponding sequences are discarded/ignored.
        finite_indexes = torch.isfinite(den_tot_scores)
        den_tot_scores = den_tot_scores[finite_indexes]
        num_tot_scores = num_tot_scores[finite_indexes]
    else:
        finite_indexes = None

    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2],
                                               finite_indexes)

    num_rows = dense_fsa_vec.scores.shape[0]
    num_cols = dense_fsa_vec.scores.shape[1] - 1
    mbr_num_sparse = k2.create_sparse(rows=num_lats.seqframe_idx,
                                      cols=num_lats.phones,
                                      values=num_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    mbr_den_sparse = k2.create_sparse(rows=mbr_lats.seqframe_idx,
                                      cols=mbr_lats.phones,
                                      values=mbr_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)
    # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
    # we cannot use (mbr_num_sparse - mbr_den_sparse) here
    #
    # The following works only for torch >= 1.7.0
    mbr_loss = torch.sparse.sum(
        k2.sparse.abs((mbr_num_sparse + (-mbr_den_sparse)).coalesce()))

    mmi_loss = -tot_score

    total_loss = mmi_loss + mbr_loss

    if is_training:
        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        mmi_loss.detach().cpu().item(),
        mbr_loss.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
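
The MBR term subtracts the two sparse tensors by adding a negation,
since, as the NOTE in the code says, PyTorch's sparse autograd cannot
handle (mbr_num_sparse - mbr_den_sparse) directly; k2.sparse.abs then
takes the element-wise absolute value. A toy sketch of the same
computation, with made-up values (like the code above, it assumes
torch >= 1.7):

import torch
import k2

a = torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 2]]),
                            torch.tensor([0.5, 0.25]), size=(2, 3))
b = torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 2]]),
                            torch.tensor([0.75, 0.25]), size=(2, 3))
loss = torch.sparse.sum(k2.sparse.abs((a + (-b)).coalesce()))
print(loss)  # tensor(0.2500): |0.5 - 0.75| + |0.25 - 0.25|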