Example #1
0
    def test_property_values(self):
        """`values` can be mutated element-wise but never rebound."""
        for device in self.devices:
            for dtype in self.dtypes:
                ragged = k2r.RaggedTensor(
                    [[1], [2], [], [3, 4]], device=device, dtype=dtype
                )
                flat = torch.tensor([1, 2, 3, 4], dtype=dtype, device=device)
                assert torch.all(torch.eq(ragged.values, flat))

                # Rebinding the read-only `values` attribute must fail.
                with self.assertRaises(AttributeError):
                    ragged.values = 10

                # Writing through `values` is allowed and must be visible
                # in the ragged tensor itself.
                ragged.values[0] = 10
                ragged.values[-1] *= 2

                mutated = k2r.RaggedTensor(
                    [[10], [2], [], [3, 8]], dtype=dtype, device=device
                )
                assert ragged == mutated

                # Any further element change breaks the equality again.
                ragged.values[0] = 1
                assert ragged != mutated
Example #2
0
    def test_getitem_scalar(self):
        """Integer indexing on axis 0 returns the matching sub-tensor."""
        # Expected string form of src[0] and src[1], in that order.
        sublists = ("[[1 2] [] [10]]", "[[3] [5]]")
        for device in self.devices:
            for dtype in self.dtypes:
                src = k2r.RaggedTensor(
                    "[ [[1 2] [] [10]] [[3] [5]] ]", dtype=dtype
                ).to(device)

                for idx, text in enumerate(sublists):
                    picked = src[idx]
                    wanted = k2r.RaggedTensor(text, dtype=dtype).to(device)
                    assert picked == wanted
Example #3
0
    def test_tot_size_2axes(self):
        """Check `tot_size` on both axes of a 2-axis ragged tensor."""
        # Index = axis; the tensor below has 3 rows and 5 elements in total.
        sizes_by_axis = (3, 5)
        for device in self.devices:
            for dtype in self.dtypes:
                ragged = k2r.RaggedTensor(
                    "[ [1 2 3] [] [5 8] ]", dtype=dtype
                ).to(device)

                for axis, size in enumerate(sizes_by_axis):
                    assert ragged.tot_size(axis) == size
Example #4
0
    def test_clone(self):
        """A clone starts equal to its source but is independent of it."""
        original = k2r.RaggedTensor([[1, 2], [], [3]])
        duplicate = original.clone()
        assert original == duplicate

        # Changing the source must not be reflected in the clone.
        original.values[0] = 10
        assert original != duplicate
Example #5
0
    def test_sum_no_grad(self):
        """`sum()` yields one total per row; the empty row sums to 0."""
        for device in self.devices:
            for dtype in self.dtypes:
                ragged = k2r.RaggedTensor(
                    [[1, 2], [], [5]], dtype=dtype
                ).to(device)
                row_sums = ragged.sum()

                wanted = torch.tensor([3, 0, 5], dtype=dtype, device=device)
                assert torch.all(torch.eq(row_sums, wanted))
Example #6
0
    def test_tot_size_3axes(self):
        """Check `tot_size` on every axis of a 3-axis ragged tensor."""
        # Index = axis: 2 top-level lists, 8 sub-lists, 10 elements.
        sizes_by_axis = (2, 8, 10)
        for device in self.devices:
            for dtype in self.dtypes:
                ragged = k2r.RaggedTensor(
                    "[ [[1 2 3] [] [5 8]] [[] [1 5 9 10 -1] [] [] []] ]",
                    dtype=dtype,
                ).to(device)

                for axis, size in enumerate(sizes_by_axis):
                    assert ragged.tot_size(axis) == size
Example #7
0
    def test_setstate_3axes(self):
        """Round-trip a 3-axis ragged tensor through torch.save/torch.load.

        The tensor is moved to each device under test before saving, and
        the temporary file is removed even when saving or loading fails.
        """
        for device in self.devices:
            for dtype in self.dtypes:
                a = k2r.RaggedTensor(
                    "[ [[1] [2 3] []]  [[10]]]", dtype=dtype
                ).to(device)

                # Only the path is needed; close the descriptor right away
                # so torch can reopen the file by name.
                fid, tmp_filename = tempfile.mkstemp()
                os.close(fid)
                try:
                    torch.save(a, tmp_filename)
                    # NOTE(review): recent PyTorch releases default
                    # torch.load to weights_only=True, which may reject
                    # custom classes -- confirm against supported versions.
                    b = torch.load(tmp_filename)
                finally:
                    os.remove(tmp_filename)

                assert a == b
Example #8
0
    def test_grad(self):
        """`requires_grad` toggles via assignment and `requires_grad_`."""
        ragged = k2r.RaggedTensor([[1, 2], [10], []], dtype=torch.float32)

        # A freshly built tensor has no gradient and does not ask for one.
        assert ragged.grad is None
        assert ragged.requires_grad is False

        # Plain attribute assignment.
        ragged.requires_grad = True
        assert ragged.requires_grad is True

        # In-place setter, switching off and then on again.
        for flag in (False, True):
            ragged.requires_grad_(flag)
            assert ragged.requires_grad is flag
Example #9
0
    def test_setstate_2axes(self):
        """Round-trip a 2-axis ragged tensor through torch.save/torch.load.

        The temporary file is removed even when saving or loading fails,
        so a failing run does not leave files behind.
        """
        for device in self.devices:
            for dtype in self.dtypes:
                a = k2r.RaggedTensor([[1], [2, 3], []], dtype=dtype)
                a = a.to(device)

                # Only the path is needed; close the descriptor right away
                # so torch can reopen the file by name.
                fid, tmp_filename = tempfile.mkstemp()
                os.close(fid)
                try:
                    torch.save(a, tmp_filename)
                    # NOTE(review): recent PyTorch releases default
                    # torch.load to weights_only=True, which may reject
                    # custom classes -- confirm against supported versions.
                    b = torch.load(tmp_filename)
                finally:
                    os.remove(tmp_filename)

                # It checks both dtype and device, not just value
                assert a == b
Example #10
0
    def test_getstate_2axes(self):
        """`__getstate__` of a 2-axis tensor is a 3-tuple.

        Layout checked here: (row_splits, "row_ids1", values).
        """
        for device in self.devices:
            for dtype in self.dtypes:
                ragged = k2r.RaggedTensor(
                    [[1, 2], [3], []], dtype=dtype
                ).to(device)

                state = ragged.__getstate__()
                assert isinstance(state, tuple)
                assert len(state) == 3
                row_splits, ids_tag, values = state

                wanted_splits = torch.tensor(
                    [0, 2, 3, 3], dtype=torch.int32, device=device
                )
                assert torch.all(torch.eq(row_splits, wanted_splits))
                assert ids_tag == "row_ids1"
                assert torch.all(torch.eq(values, ragged.values))
Example #11
0
    def test_sum_with_grad(self):
        """Gradients flow from per-row sums back to the elements."""
        for device in self.devices:
            # Only floating-point dtypes are exercised here.
            for dtype in (torch.float32, torch.float64):
                ragged = k2r.RaggedTensor(
                    [[1, 2], [], [5]], dtype=dtype
                ).to(device)
                ragged.requires_grad_(True)

                row_sums = ragged.sum()
                wanted_sums = torch.tensor(
                    [3, 0, 5], dtype=dtype, device=device
                )
                assert torch.all(torch.eq(row_sums, wanted_sums))

                # Rows are weighted 10, 20 and 30, so each element's
                # gradient is the weight of the row it belongs to; the
                # empty middle row contributes no elements.
                loss = row_sums[0] * 10 + row_sums[1] * 20 + row_sums[2] * 30
                loss.backward()

                wanted_grad = torch.tensor(
                    [10, 10, 30], device=device, dtype=dtype
                )
                assert torch.all(torch.eq(ragged.grad, wanted_grad))
Example #12
0
    def test_create_ragged_tensor_from_string(self):
        """Build ragged tensors from strings with both constructors.

        Each of `k2r.create_ragged_tensor` and `k2r.RaggedTensor` is run
        on a 2-axis and a 3-axis string, for every device and dtype under
        test. The 2-axis result must also equal the tensor built from the
        equivalent nested list.
        """
        funcs = [k2r.create_ragged_tensor, k2r.RaggedTensor]
        for func in funcs:
            for device in self.devices:
                for dtype in self.dtypes:
                    a = func(
                        [[1], [2, 3, 4, 5], []], dtype=dtype, device=device
                    )
                    b = func("[[1] [2 3 4 5] []]", dtype=dtype, device=device)
                    assert a == b
                    assert b.dim0 == 3
                    assert a.dtype == dtype
                    assert a.device == device

                    # Use `func` here as well, so the 3-axis string is
                    # parsed by the constructor under test instead of
                    # always by k2r.RaggedTensor.
                    b = func(
                        "[[[1] [2 3] []] [[10]]]", dtype=dtype, device=device
                    )
                    assert b.num_axes == 3
                    assert b.dim0 == 2
                    assert b.dtype == dtype
                    assert b.device == device
Example #13
0
    def test_getitem_1d_tensor(self):
        """Index one ragged tensor with the rows of another.

        Each row ``a[i]`` is compared against an int32 tensor and is then
        used as an index into ``b``; the result must hold the rows of
        ``b`` it names, in that order.
        """
        index_rows = [[1, 2, 0], [0, 1], [2, 3]]
        source_rows = [
            [10, 20],  # row 0
            [300],  # row 1
            [-10, 0, -1],  # row 2
            [-2, 4, 5],  # row 3
        ]
        for device in self.devices:
            for dtype in self.dtypes:
                a = k2r.RaggedTensor(index_rows, device=device)
                b = k2.RaggedTensor(source_rows, dtype=dtype, device=device)

                for pos, picks in enumerate(index_rows):
                    index = torch.tensor(
                        picks, dtype=torch.int32, device=device
                    )
                    assert torch.all(torch.eq(a[pos], index))

                    # The expected result lists the selected source rows.
                    expected = k2.RaggedTensor(
                        [source_rows[p] for p in picks],
                        dtype=dtype,
                        device=device,
                    )
                    assert b[a[pos]] == expected
Example #14
0
    def test_getstate_3axes(self):
        """`__getstate__` of a 3-axis tensor is a 5-tuple.

        Layout checked here:
        (row_splits1, "row_ids1", row_splits2, "row_ids2", values).
        """
        for device in self.devices:
            for dtype in self.dtypes:
                ragged = k2r.RaggedTensor(
                    "[[[1 2] [3] []] [[4] [5 6]]]", dtype=dtype
                ).to(device)

                state = ragged.__getstate__()
                assert isinstance(state, tuple)
                assert len(state) == 5
                splits1, tag1, splits2, tag2, values = state

                wanted_splits1 = torch.tensor(
                    [0, 3, 5], dtype=torch.int32, device=device
                )
                wanted_splits2 = torch.tensor(
                    [0, 2, 3, 3, 4, 6], dtype=torch.int32, device=device
                )

                assert torch.all(torch.eq(splits1, wanted_splits1))
                assert tag1 == "row_ids1"
                assert torch.all(torch.eq(splits2, wanted_splits2))
                assert tag2 == "row_ids2"
                assert torch.all(torch.eq(values, ragged.values))