Code example #1
    def test_1d_non_contiguous(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            for dtype in [torch.int32, torch.int64]:
                num_rows = torch.randint(20, 2000, size=(1,)).item()
                stride = torch.randint(2, num_rows // 10 + 1, size=(1,)).item()

                a = torch.randint(-1000,
                                  1000,
                                  size=(num_rows,),
                                  dtype=dtype,
                                  device=device)
                a = a[::stride]
                assert a.is_contiguous() is False
                assert a.dtype == dtype

                b = torch.randint(-1,
                                  a.shape[0],
                                  size=(10000,),
                                  dtype=torch.int32,
                                  device=device)

                c = k2.index_select(a, b)

                padded_a = torch.cat([torch.tensor([0]).to(a), a])
                expected = padded_a.index_select(0, (b + 1).to(torch.int64))

                assert torch.allclose(c, expected)

            for dtype in [torch.float32, torch.float64]:
                num_rows = torch.randint(20, 2000, size=(1,)).item()
                stride = torch.randint(2, num_rows // 10 + 1, size=(1,)).item()

                a_contiguous = torch.rand(num_rows, dtype=dtype, device=device)
                a = a_contiguous[::stride]
                a.requires_grad_(True)
                assert a.is_contiguous() is False
                assert a.dtype == dtype

                b = torch.randint(-1,
                                  a.shape[0],
                                  size=(10000,),
                                  dtype=torch.int32,
                                  device=device)

                c = k2.index_select(a, b)

                new_a = a_contiguous[::stride]
                new_a.requires_grad_(True)
                padded_a = torch.cat([torch.tensor([0]).to(a), new_a])
                expected = padded_a.index_select(0, (b + 1).to(torch.int64))

                c.sum().backward()
                expected.sum().backward()

                assert torch.allclose(c, expected)
                assert torch.allclose(a.grad, new_a.grad)
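
Note: every test on this page validates k2.index_select against the same plain-PyTorch reference. Judging from the padded_a / (b + 1) construction above, an index of -1 selects a zero entry (a zero row in the 2-D case). A minimal, self-contained sketch of that reference computation:

import torch

# Prepend a zero entry and shift every index by one, so that index -1
# lands on the zero entry. This mirrors the `expected` computation above.
a = torch.tensor([10, 20, 30], dtype=torch.int32)
b = torch.tensor([0, -1, 2], dtype=torch.int32)

padded_a = torch.cat([torch.tensor([0]).to(a), a])            # [0, 10, 20, 30]
expected = padded_a.index_select(0, (b + 1).to(torch.int64))  # [10, 0, 30]
print(expected)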
Code example #2
    def test_1d(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            for dtype in [torch.int32, torch.int64]:
                num_rows = torch.randint(1, 2000, size=(1,)).item()
                a = torch.randint(-1000,
                                  1000,
                                  size=(num_rows,),
                                  dtype=dtype,
                                  device=device)
                assert a.is_contiguous()
                assert a.dtype == dtype

                num_indexes = torch.randint(1, 200000, size=(1,)).item()
                b = torch.randint(-1,
                                  num_rows,
                                  size=(num_indexes,),
                                  dtype=torch.int32,
                                  device=device)

                c = k2.index_select(a, b)
                assert c.dtype == a.dtype
                assert c.numel() == b.numel()

                padded_a = torch.cat([torch.tensor([0]).to(a), a])
                expected = padded_a.index_select(0, (b + 1).to(torch.int64))

                assert torch.allclose(c, expected)

            for dtype in [torch.float32, torch.float64]:
                num_rows = torch.randint(1, 2000, size=(1,)).item()
                a = torch.rand(num_rows,
                               dtype=dtype,
                               device=device,
                               requires_grad=True)
                assert a.is_contiguous()
                assert a.dtype == dtype

                num_indexes = torch.randint(1, 200000, size=(1,)).item()
                b = torch.randint(-1,
                                  num_rows,
                                  size=(num_indexes,),
                                  dtype=torch.int32,
                                  device=device)

                c = k2.index_select(a, b)
                assert c.dtype == a.dtype
                c.sum().backward()

                new_a = a.detach().requires_grad_(True)
                padded_a = torch.cat([torch.tensor([0]).to(new_a), new_a])
                expected = padded_a.index_select(0, (b + 1).to(torch.int64))
                expected.sum().backward()

                assert torch.allclose(c, expected)
                assert torch.allclose(a.grad, new_a.grad)
Code example #3
    def test_1d(self):
        a = torch.tensor([
            10,  # 0
            -1,  # 1
            100,  # 2
            0,  # 3
            3,  # 4
            9,  # 5
            12,  # 6
        ]).to(torch.int32)

        b = torch.tensor([0, -1, 0, 2, 1, 3, 6, -1, 5, 6, 0,
                          2]).to(torch.int32)

        c = k2.index_select(a, b)
        assert c.dtype == a.dtype
        assert c.numel() == b.numel()

        padded_a = torch.cat([torch.tensor([0]).to(a), a])
        expected = padded_a.index_select(0, (b + 1).to(torch.int64))

        assert torch.allclose(c, expected)

        if torch.cuda.is_available():
            device = torch.device('cuda', 0)
            a = a.to(device)
            b = b.to(device)
            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            assert c.is_cuda
            assert c.numel() == b.numel()
            assert torch.allclose(c, expected.to(c))

        # now for float32
        a = a.to(torch.float32).requires_grad_(True)
        c = k2.index_select(a, b)
        assert c.dtype == a.dtype
        c.sum().backward()

        new_a = a.detach().requires_grad_(True)
        padded_a = torch.cat([torch.tensor([0]).to(new_a), new_a])
        expected = padded_a.index_select(0, (b + 1).to(torch.int64))
        expected.sum().backward()

        assert torch.allclose(a.grad, new_a.grad)

        # now for cpu
        a.grad = None
        c = k2.index_select(a.cpu(), b.cpu())
        assert c.dtype == a.dtype
        c.sum().backward()

        assert torch.allclose(a.grad, new_a.grad.to(a.grad))
Code example #4
    def test_2d(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            for dtype in [torch.int32, torch.int64]:
                num_rows = torch.randint(1, 2000, size=(1,)).item()
                num_cols = torch.randint(1, 2000, size=(1,)).item()
                a = torch.randint(-1000,
                                  1000,
                                  size=(num_rows, num_cols),
                                  dtype=dtype,
                                  device=device)
                b = torch.randint(-1,
                                  num_rows,
                                  size=(10000,),
                                  dtype=torch.int32,
                                  device=device)
                assert a.is_contiguous()
                assert a.dtype == dtype

                c = k2.index_select(a, b)

                assert c.dtype == a.dtype
                assert c.device == a.device
                assert c.shape[1] == a.shape[1]
                assert c.shape[0] == b.shape[0]

                padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(a), a])
                expected = padded_a.index_select(0, (b + 1).to(torch.int64))

                assert torch.allclose(c, expected)

            for dtype in [torch.float32, torch.float64]:
                src = a.to(dtype).requires_grad_(True)
                assert src.is_contiguous()
                assert b.is_contiguous()
                c = k2.index_select(src, b)

                assert c.dtype == src.dtype
                c.sum().backward()

                new_src = src.detach().requires_grad_(True)
                padded_src = torch.cat(
                    [torch.zeros(1, src.shape[1]).to(new_src), new_src])
                expected = padded_src.index_select(0, (b + 1).to(torch.int64))
                expected.sum().backward()

                assert torch.allclose(c, expected)
                assert torch.allclose(src.grad, new_src.grad)
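
In the 2-D tests the same reference trick pads with a zero row instead of a scalar zero. A minimal sketch:

import torch

a = torch.tensor([[1, 2], [3, 4]], dtype=torch.int32)
b = torch.tensor([1, -1, 0], dtype=torch.int32)

padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(a), a])   # zero row first
expected = padded_a.index_select(0, (b + 1).to(torch.int64))
print(expected)  # [[3, 4], [0, 0], [1, 2]]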
Code example #5
File: index_select_test.py  Project: qijiaxing/k2
    def test_1d_non_contiguous(self):
        a = torch.arange(20).to(torch.int32)[::2]
        b = torch.tensor([
            -1, -2, -1, -2, 1, -2, 2, -2, 0, -2, 1, -2, 5, -2, 3, -2, 9, -2,
            -1, -2, 8, -2, 7, -2, 7, -2, 6, -2
        ]).to(torch.int32)[::2]
        padded_a = torch.cat([torch.tensor([0]).to(a), a])
        assert a.is_contiguous() is False
        assert b.is_contiguous() is False
        c = k2.index_select(a, b)
        expected = padded_a.index_select(0, (b + 1).to(torch.int64))
        assert torch.allclose(c, expected)

        a = a.to(torch.float32).requires_grad_(True)
        c = k2.index_select(a, b)

        new_a = a.detach().clone().requires_grad_(True)
        padded_a = torch.cat([torch.tensor([0]).to(a), new_a])
        expected = padded_a.index_select(0, (b + 1).to(torch.int64))

        assert torch.allclose(c, expected.to(c))

        c.sum().backward()
        expected.sum().backward()

        assert torch.allclose(a.grad, new_a.grad)

        # now for cuda
        if torch.cuda.is_available():
            device = torch.device('cuda', 0)
            b = b.to(device)

            a.requires_grad_(False)
            a = a.to(device).requires_grad_(True)
            c = k2.index_select(a, b)

            new_a.requires_grad_(False)
            new_a = new_a.to(device).requires_grad_(True)
            padded_a = torch.cat([torch.tensor([0]).to(a), new_a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))

            assert torch.allclose(c, expected.to(c))

            c.sum().backward()
            expected.sum().backward()

            assert torch.allclose(a.grad, new_a.grad)
Code example #6
    def test_2d_non_contiguous(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            num_rows = torch.randint(20, 2000, size=(1,)).item()
            num_cols = torch.randint(1, 2000, size=(1,)).item()
            # num_rows >= 20 guarantees num_rows // 10 + 1 > 2, so the
            # randint below always has a non-empty range (with the original
            # randint(1, 2000) it could crash for num_rows < 21).
            stride = torch.randint(2, num_rows // 10 + 1, size=(1,)).item()
            a = torch.randint(-1000,
                              1000,
                              size=(num_rows, num_cols),
                              dtype=torch.int32,
                              device=device).contiguous()
            a = a[::stride]
            num_rows = a.shape[0]
            b = torch.randint(-1,
                              num_rows,
                              size=(10000,),
                              dtype=torch.int32,
                              device=device)
            assert a.is_contiguous() is False
            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            assert c.device == a.device
            assert c.shape[1] == a.shape[1]
            assert c.shape[0] == b.shape[0]

            padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(a), a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            assert torch.allclose(c, expected)

            # now for float32. Note that the dtype cast below copies `a`
            # into a new tensor, so this branch operates on contiguous data.
            a = a.to(torch.float32).requires_grad_(True)
            assert a.is_contiguous()
            assert b.is_contiguous()
            c = k2.index_select(a, b)

            assert c.dtype == a.dtype
            c.sum().backward()

            new_a = a.detach().requires_grad_(True)
            padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(new_a), new_a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            expected.sum().backward()

            assert torch.allclose(a.grad, new_a.grad)
Code example #7
    def test_1d_empty_index(self):
        for device in self.devices:
            for dtype in [torch.int32, torch.int64]:
                num_rows = torch.randint(0, 10, size=(1, )).item()
                a = torch.randint(-1000,
                                  1000,
                                  size=(num_rows, ),
                                  dtype=dtype,
                                  device=device)
                assert a.is_contiguous()
                assert a.dtype == dtype

                b = torch.empty(0, dtype=torch.int32, device=device)

                c = k2.index_select(a, b)
                assert c.dtype == a.dtype
                assert c.numel() == b.numel()

            for dtype in [torch.float32, torch.float64]:
                num_rows = torch.randint(0, 10, size=(1, )).item()
                a = torch.rand(num_rows,
                               dtype=dtype,
                               device=device,
                               requires_grad=True)
                assert a.is_contiguous()
                assert a.dtype == dtype

                b = torch.empty(0, dtype=torch.int32, device=device)

                c = k2.index_select(a, b)
                assert c.dtype == a.dtype
                c.sum().backward()

                new_a = a.detach().requires_grad_(True)
                padded_a = torch.cat([torch.tensor([0]).to(new_a), new_a])
                expected = padded_a.index_select(0, (b + 1).to(torch.int64))
                expected.sum().backward()

                assert torch.allclose(c, expected)
                assert torch.allclose(a.grad, new_a.grad)
Code example #8
    def test_sort_sublist_descending(self):
        for device in self.devices:
            for dtype in self.dtypes:
                src = k2.RaggedTensor([[3, 2], [], [1, 5, 2]],
                                      dtype).to(device)
                src_clone = src.clone()
                new2old = src.sort_(descending=True, need_new2old_indexes=True)
                sorted_src = k2.RaggedTensor([[3, 2], [], [5, 2, 1]],
                                             dtype=dtype).to(device)
                expected_new2old = torch.tensor([0, 1, 3, 4, 2],
                                                device=device,
                                                dtype=torch.int32)
                assert src == sorted_src
                assert torch.all(torch.eq(new2old, expected_new2old))

                expected_sorted = k2.index_select(src_clone.values, new2old)
                sorted_values = src.values  # renamed to avoid shadowing the builtin `sorted`
                assert torch.all(torch.eq(expected_sorted, sorted_values))
Code example #9
def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 500,
):
    """Wrap k2.intersect_device

    This is a wrapper of k2.intersect_device and its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    CUDA OOM error.
    The arguments and return value of this function are the same as
    k2.intersect_device.

    NOTE: You can decrease batch_size in case of CUDA out of memory error.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)

        fsas = k2.index_fsa(b_fsas, indexes)
        b_to_a = k2.index_select(b_to_a_map, indexes)
        path_lats = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lats)

    return k2.cat(ans)
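
The batch splitting above is plain index arithmetic and can be checked without k2. A small sketch (make_splits is a name introduced here purely for illustration):

import math

def make_splits(num_items: int, batch_size: int):
    # Same arithmetic as the splits loop in _intersect_device above.
    num_batches = int(math.ceil(float(num_items) / batch_size))
    return [(i * batch_size, min((i + 1) * batch_size, num_items))
            for i in range(num_batches)]

print(make_splits(1200, 500))  # [(0, 500), (500, 1000), (1000, 1200)]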
Code example #10
def my_func(src: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    return k2.index_select(src.to(torch.float32), index)
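
A hypothetical call site for my_func, assuming the index conventions the tests above rely on (int32 indexes, -1 selecting a zero entry):

import torch
import k2

src = torch.tensor([10, 20, 30])                     # cast to float32 inside my_func
index = torch.tensor([2, -1, 0], dtype=torch.int32)
out = my_func(src, index)                            # expected: [30.0, 0.0, 10.0]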
Code example #11
File: index_select_test.py  Project: entn-at/k2
    def test_2d_non_contiguous(self):
        for device in self.devices:
            for dtype in [torch.int32, torch.int64]:
                num_rows = torch.randint(20, 2000, size=(1, )).item()
                num_cols = torch.randint(1, 2000, size=(1, )).item()
                stride = torch.randint(2, num_rows // 10 + 1,
                                       size=(1, )).item()
                a = torch.randint(-1000,
                                  1000,
                                  size=(num_rows, num_cols),
                                  dtype=dtype,
                                  device=device).contiguous()
                a = a[::stride]
                num_rows = a.shape[0]
                b = torch.randint(-1,
                                  num_rows,
                                  size=(10000, ),
                                  dtype=torch.int32,
                                  device=device)
                assert a.is_contiguous() is False
                assert a.dtype == dtype

                c = k2.index_select(a, b)
                assert c.dtype == a.dtype
                assert c.device == a.device
                assert c.shape[1] == a.shape[1]
                assert c.shape[0] == b.shape[0]

                padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(a), a])
                expected = padded_a.index_select(0, (b + 1).to(torch.int64))

                assert torch.allclose(c, expected)

            for dtype in [torch.float32, torch.float64]:
                num_rows = torch.randint(20, 2000, size=(1, )).item()
                num_cols = torch.randint(1, 2000, size=(1, )).item()
                stride = torch.randint(2, num_rows // 10 + 1,
                                       size=(1, )).item()
                a = torch.randint(-1000,
                                  1000,
                                  size=(num_rows, num_cols),
                                  dtype=dtype,
                                  device=device).contiguous()
                a = a[::stride]
                num_rows = a.shape[0]
                b = torch.randint(-1,
                                  num_rows,
                                  size=(10000, ),
                                  dtype=torch.int32,
                                  device=device)
                assert a.is_contiguous() is False
                assert a.dtype == dtype
                a.requires_grad_(True)

                c = k2.index_select(a, b)
                assert c.dtype == a.dtype
                assert c.device == a.device
                assert c.shape[1] == a.shape[1]
                assert c.shape[0] == b.shape[0]

                c.sum().backward()

                new_a = a.detach().requires_grad_(True)
                padded_a = torch.cat(
                    [torch.zeros(1, a.shape[1]).to(new_a), new_a])
                expected = padded_a.index_select(0, (b + 1).to(torch.int64))
                expected.sum().backward()

                assert torch.allclose(c, expected)
                assert torch.allclose(a.grad, new_a.grad)
Code example #12
    def forward(
        self,
        log_probs: torch.Tensor,
        targets: torch.Tensor,
        input_lengths: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # returns (total_loss, valid_mask)
        assert self.graph_compiler is not None
        boosted = self.boost_coeff != 0.0
        if self.blank != 0:
            # rearrange log_probs so that the blank symbol is at index 0
            # and shift targets to emulate blank == 0
            log_probs, targets = make_blank_first(self.blank, log_probs, targets)
        supervisions, order = create_supervision(input_lengths)
        order = order.long()
        targets = targets[order]
        target_lengths = target_lengths[order]

        if log_probs.device != self.graph_compiler.device:
            self.graph_compiler.to(log_probs.device)

        num_graphs, den_graph = self.graph_compiler.compile(
            targets + 1 if self.pad_fsavec else targets, target_lengths
        )

        dense_fsa_vec = (
            prep_padded_densefsavec(log_probs, supervisions)
            if self.pad_fsavec
            else k2.DenseFsaVec(log_probs, supervisions)
        )

        num_tot_scores, den_tot_scores, num_lats, den_lats = self.intersect_calc_scores(
            dense_fsa_vec, num_graphs, den_graph, boosted
        )

        tot_scores = num_tot_scores - den_tot_scores
        mmi_tot_scores, mmi_valid_mask = get_tot_objf_and_finite_mask(tot_scores, self.reduction)

        if boosted:
            assert num_lats is not None and den_lats is not None

            size = (
                dense_fsa_vec.dim0(),
                dense_fsa_vec.scores.shape[0],
                dense_fsa_vec.scores.shape[1] - 1,
            )
            row_ids = dense_fsa_vec.dense_fsa_vec.shape().row_ids(1)
            num_sparse = create_sparse_wrapped(
                indices=[
                    k2.index_select(row_ids, num_lats.seqframe_idx),
                    num_lats.seqframe_idx,
                    num_lats.phones,
                ],
                values=num_lats.get_arc_post(False, True).exp(),
                size=size,
                min_col_index=0,
            )
            den_sparse = create_sparse_wrapped(
                indices=[
                    k2.index_select(row_ids, den_lats.seqframe_idx),
                    den_lats.seqframe_idx,
                    den_lats.phones,
                ],
                values=den_lats.get_arc_post(False, True).exp(),
                size=size,
                min_col_index=0,
            )

            # NOTE: Due to limited support of PyTorch's autograd for sparse
            # tensors, we cannot use (num_sparse - den_sparse) here.
            # TODO (alaptev): propose sparse_abs to k2
            acc_loss = torch.sparse.sum(
                sparse_abs((num_sparse + (-den_sparse)).coalesce()), (1, 2)
            ).to_dense()
            acc_tot_scores, acc_valid_mask = get_tot_objf_and_finite_mask(acc_loss, self.reduction)
            valid_mask = mmi_valid_mask & acc_valid_mask
            total_loss = self.boost_coeff * acc_tot_scores[valid_mask] - mmi_tot_scores[valid_mask]
        else:
            valid_mask = mmi_valid_mask
            total_loss = -mmi_tot_scores[mmi_valid_mask]
        return total_loss, valid_mask
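
sparse_abs above is a helper from the surrounding project and is not shown in this snippet. A hypothetical sketch of what it might look like (not the project's actual implementation):

import torch

def sparse_abs(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical: rebuild a coalesced sparse COO tensor with the
    # absolute values of its entries.
    t = t if t.is_coalesced() else t.coalesce()
    return torch.sparse_coo_tensor(t.indices(), t.values().abs(), t.size())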