Example #1
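These snippets are test methods extracted from a block-sparse matrix test suite; the class and module names match the pytorch_block_sparse project. A minimal sketch of the imports they appear to assume (the module paths are a guess, not confirmed by the snippets themselves):

import tempfile

import torch
from torch import tensor

# Assumed locations -- pytorch_block_sparse provides BlockSparseMatrix, but the
# emulator and optimizer paths below are unverified assumptions.
from pytorch_block_sparse import BlockSparseMatrix
from pytorch_block_sparse.block_sparse import BlockSparseMatrixEmulator
from pytorch_block_sparse.sparse_optimizer import (
    MagnitudeSparseOptimizerStrategy, SparseOptimizer)
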
    def test0(self):
        tests = [
            dict(size=[128, 64],
                 blocks=[
                     (0, 0),
                     (1, 0),
                     (2, 0),
                     (0, 1),
                 ],
                 row_start_ends_a=tensor([0, 2, 3, 4, 4]),
                 cols_a=tensor([[0, 0], [1, 1], [0, 2], [0, 3]]),
                 col_start_ends_b=tensor([0, 3, 4]),
                 rows_b=tensor([[0, 0], [1, 2], [2, 3], [0, 1]]))
        ]
        block_shape = (32, 32)
        device = "cuda"
        for test_info in tests:
            size = test_info["size"]
            blocks = test_info["blocks"]
            bsm = BlockSparseMatrix.randn((size[0], size[1]),
                                          None,
                                          blocks=blocks,
                                          block_shape=block_shape,
                                          device=device)

            for key in test_info:
                if "row" in key or "col" in key:
                    bsm_a = getattr(bsm, key)
                    ref = test_info[key].to(device=device, dtype=torch.int32)
                    check = (bsm_a == ref).all()
                    if not check:
                        raise Exception(
                            f"Non matching attribute {key}:\n{bsm_a}\n!=\n{ref} (ref)."
                        )
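
For reference, row_start_ends_a behaves like a CSR row pointer over the grid of 32x32 blocks: block row i owns the data blocks numbered row_start_ends_a[i] through row_start_ends_a[i+1]. A minimal sketch of how the expected tensor above can be derived, assuming blocks holds (block_row, block_col) pairs (block_row_start_ends is a hypothetical helper, not part of the library):

import torch

def block_row_start_ends(blocks, grid_rows):
    # count the blocks in each block-row, then prefix-sum into a CSR-style pointer
    counts = torch.zeros(grid_rows + 1, dtype=torch.int32)
    for block_row, _block_col in blocks:
        counts[block_row + 1] += 1
    return counts.cumsum(0, dtype=torch.int32)

# 128x64 matrix with 32x32 blocks -> a 4x2 grid of block positions
print(block_row_start_ends([(0, 0), (1, 0), (2, 0), (0, 1)], grid_rows=4))
# tensor([0, 2, 3, 4, 4], dtype=torch.int32)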
Example #2
    @classmethod
    def help_randn(
            cls,
            shape,
            n_blocks,
            blocks=None,
            block_shape=(32, 32),
            device="cuda",
            positive=False,
    ):
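        # The CUDA implementation may reject configurations that the emulator
        # accepts; fall back to None so callers can still compare behaviours.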
        try:
            real = BlockSparseMatrix.randn(shape,
                                           n_blocks,
                                           blocks,
                                           block_shape,
                                           device=device,
                                           positive=positive)
        except Exception:
            real = None
        emul = BlockSparseMatrixEmulator.randn(shape,
                                               n_blocks,
                                               blocks,
                                               block_shape,
                                               device=device,
                                               positive=positive)

        return real, emul
Example #3
    def test1(self):
        sizes = [(32, 32), (64, 32), (32, 64), (64, 64), (256, 64)]
        for size in sizes:
            print(f"size={size}")
            block_shape = (32, 32)
            block_count = size[0] * size[1] // (block_shape[0] *
                                                block_shape[1])
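            # one block per 32x32 tile: the matrix is fully dense in block terms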
            device = "cuda"

            bsm = BlockSparseMatrix.randn(size,
                                          block_count,
                                          block_shape=block_shape,
                                          device=device)
            a = bsm.to_dense()
            bsm.check_with_dense(a)

            bsm2 = BlockSparseMatrix.from_dense(a,
                                                block_shape,
                                                block_count=None)
            bsm2.check_with_dense(a)

            a2 = bsm2.to_dense()

            if not (a == a2).all():
                print((a == a2)[::8, ::8])
                raise Exception(
                    "Non matching matrices, BlockSparseMatrix.from_dense is not correct."
                )
Example #4
    def test0(self):
        sizes = [64, 64]
        block_size = (32, 32)
        block_count = 2
        bsm = BlockSparseMatrix.randn(sizes, block_count, blocks=None, block_shape=block_size, device="cuda")

        with tempfile.NamedTemporaryFile() as tf:
            torch.save(bsm, tf.name)

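            # Note: recent PyTorch releases default torch.load to weights_only=True,
            # which rejects pickled objects like this one; older releases load it as-is.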
            bsm2 = torch.load(tf.name)

            self.assertTrue((bsm.to_dense() == bsm2.to_dense()).all())
Example #5
    def test_block_norm(self):
        nblocks = 6
        block_shape = (32, 32)
        bsm = BlockSparseMatrix.randn((256, 256),
                                      nblocks,
                                      block_shape=block_shape,
                                      device="cuda")
        n = bsm.block_norm()
        self.assertEqual(n.dim(), 1)
        self.assertEqual(n.shape[0], nblocks)

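        # block_norm() should match the Frobenius norm of each block, recomputed
        # here by flattening every block and reducing manually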
        d = bsm.data.reshape(-1, block_shape[0] * block_shape[1])
        d = (d * d).sum(-1).sqrt()
        self.assertTrue(d.isclose(n).all())
Example #6
    def test_sparse_optimizer(self):
        size = (256, 256)
        block_count = 32
        cleanup_ratio = 0.1
        block_shape = (32, 32)
        bsm = BlockSparseMatrix.randn(size, block_count, block_shape=block_shape, device="cuda")
        dense0 = bsm.to_dense()

        so = SparseOptimizer([bsm], lr=cleanup_ratio)

        so.step()

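        # assumption, inferred from the "* 2": each cleaned-up block yields two
        # differences, the removed block plus the freshly allocated one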
        expected_block_changes = int(cleanup_ratio * block_count) * 2
        self.check_differences(bsm, dense0, expected_block_changes)
Example #7
    def test0(self):
        size = (256, 256)
        block_count = 32
        cleanup_ratio = 0.1
        block_shape = (32, 32)
        bsm = BlockSparseMatrix.randn(size, block_count, block_shape=block_shape, device="cuda")

        dense0 = bsm.to_dense()

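        # presumably prunes the lowest-magnitude blocks and reallocates them
        # elsewhere (inferred from the class name, not verified)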
        strategy = MagnitudeSparseOptimizerStrategy(cleanup_ratio)
        strategy.run(bsm)

        expected_block_changes = int(cleanup_ratio * block_count) * 2
        self.check_differences(bsm, dense0, expected_block_changes)
Example #8
    def helper(
        self,
        sizes,
        block_size,
        block_count=None,
        density=None,
        blocks=None,
        iterations=1,
        device="cuda",
        transpose=True,
        verbose=False,
    ):
        if isinstance(sizes[0], tuple):
            sizes_0 = sizes[0]
        else:
            sizes_0 = (sizes[0], )

        # Build positive matrix to easily check results
        if transpose:
            a = torch.randn(sizes_0 + (sizes[1], ), device=device).abs()
        else:
            a = torch.randn(sizes_0 + (sizes[2], ), device=device).abs()

        # torch.set_printoptions(precision=10, edgeitems=100000, linewidth=10000)
        if verbose:
            print("a=", a, "\n")

        if block_count is None and blocks is None:
            # derive the block count from the requested density of the weight matrix
            total_block_count = sizes[1] * sizes[2] // (block_size[0] * block_size[1])
            block_count = int(total_block_count * density)

        bsm = BlockSparseMatrix.randn(
            (sizes[2], sizes[1]),
            block_count,
            blocks=blocks,
            block_shape=block_size,
            device=device,
            positive=True,
        )  # Build positive matrix to easily check results

        dbsm = bsm.to_dense()
        if verbose:
            print("b=", dbsm, "\n")
            print("a.shape", a.shape)
        bsm.check_with_dense(dbsm)

        timings = {}
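        # only "pytorch" and "cutlass" run here; the "cublas" and "cuda"
        # branches below are alternative paths kept for experimentation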
        for kind in ["pytorch", "cutlass"]:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()

            for i in range(iterations):
                if kind == "pytorch":
                    if transpose:
                        dbsm_ = dbsm.t()
                    else:
                        dbsm_ = dbsm
                    c = a.matmul(dbsm_)

                    if verbose:
                        print("c=", c, "\n")

                elif kind == "cutlass":
                    c = bsm.reverse_matmul(a, transpose)
                elif kind == "cublas":
                    import block_sparse_native

                    prr = torch.zeros((sizes[2], sizes[0]), device=device)
                    prr = prr.t()
                    _ = block_sparse_native.blocksparse_matmul_transpose_dense(
                        a, dbsm, prr)
                elif kind == "cuda":
                    c = bsm.matmul_cuda(a)

            end.record()
            torch.cuda.synchronize()
            elapsed = start.elapsed_time(end)

            timing = dict(kind=kind, elapsed=elapsed, result=c)
            timings[kind] = timing

        if "pytorch" in timings:
            c0 = timings["pytorch"]["result"]
            for k, t in timings.items():
                if k == "pytorch":
                    t["comparison"] = True
                    continue
                c = t["result"]
                torch.set_printoptions(precision=8,
                                       edgeitems=100000,
                                       linewidth=10000)
                stride = 32
                shift = 0
                c_ = c[shift::stride, shift::stride]
                c0_ = c0[shift::stride, shift::stride]
                if verbose:
                    print("c shape", c.shape)
                    print("c\n", c_)
                    print("c0\n", c0_)
                    print("c!=0\n", (c_ != 0).long())
                    print("c0!=0\n", (c0_ != 0).long())
                    print("equals\n", ((c_ - c0_).abs() < 1e-06).long())
                    print("equals nonzero\n",
                          ((c_ - c0_).abs() > 1e-06).nonzero().t())

                atol = 1e-8
                rtol = 1e-5
                # Matrices are positive, so a plain isclose comparison is sufficient
                s = c.isclose(c0, rtol=rtol, atol=atol).all()
                if not s.item():
                    print(
                        f"max difference for {t['kind']} = {(c - c0).abs().max()},"
                        f" max_values={c.abs().max()}, {c0.abs().max()}")
                    diff = (c - c0).abs() / (atol + rtol * c0.abs())
                    t["comparison"] = False
                    raise Exception(
                        f"Comparison NOK: reverse_matmul issue for {k} sizes={sizes},"
                        f" density={density}, block_count={block_count},"
                        f" diff={diff}, blocks={blocks}, transpose={transpose}")
                else:
                    if verbose:
                        print(f"Comparison OK for reverse_matmul for {k}")
                        print("max difference %s=" % t["kind"],
                              (c - c0).abs().max())
                    t["comparison"] = True
                if verbose:
                    print("c_cutlass=", c)
        torch.set_printoptions(profile="default")

        return timings
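
A hypothetical invocation of this helper, for illustration only (the shapes and density are made up; timings["cutlass"]["comparison"] is set by the comparison loop above):

    def test_helper_example(self):
        # (8, 256) input times a 50%-dense (512, 256) block-sparse matrix,
        # 4 timed iterations of dense PyTorch vs. the cutlass kernel
        timings = self.helper((8, 256, 512), (32, 32), density=0.5, iterations=4)
        self.assertTrue(timings["cutlass"]["comparison"])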
Example #9
    def test_block_replace(self):
        tests = [
            dict(
                size=[128, 64],
                blocks=[
                    (0, 0),
                    (1, 0),
                    (2, 0),
                    (0, 1),
                ],
                block_info=[(0, 0), (0, 1), (1, 0), (2, 0)],
                block_replace=[
                    (3, 1, 0),
                    (2, 1, 2),
                    (1, 1, 1),
                ],  # row, col, block_index
                after=dict(
                    row_start_ends_a=[0, 0, 1, 3, 4],
                    cols_a=[[1, 1], [0, 3], [1, 2], [1, 0]],
                    block_mask=[[0, 0], [0, 1], [1, 1], [0, 1]],
                ),
            ),
            dict(
                size=[128, 64],
                blocks=[
                    (0, 0),
                    (1, 0),
                    (2, 0),
                    (0, 1),
                ],
                block_info=[(0, 0), (0, 1), (1, 0), (2, 0)],
                block_replace=[(0, 1, 0)],  # row, col, block_index
                error="Block position (0,1) was already used",
            ),
        ]
        block_shape = (32, 32)
        device = "cuda"
        verbose = False
        for test_info in tests:
            size = test_info["size"]
            blocks = test_info["blocks"]
            block_replace = torch.tensor(test_info["block_replace"])
            bsm = BlockSparseMatrix.randn(
                (size[0], size[1]),
                None,
                blocks=blocks,
                block_shape=block_shape,
                device=device,
                positive=True,
            )
            bsm.check_ = True

            if verbose:
                print(block_replace)
                block_mask0 = bsm.block_mask_build(None)
                print(block_mask0)

            dbsm0 = bsm.to_dense()
            block_positions = bsm.build_coo_block_index().t()
            for i, b in enumerate(test_info["block_info"]):
                block_position = tuple(block_positions[i].cpu().numpy())
                self.assertEqual(b, block_position)

            try:
                bsm.block_replace(block_replace)
            except Exception as e:
                if test_info.get("error") == str(e):
                    continue
                raise

            for k, v in test_info["after"].items():
                if k != "block_mask":
                    r = getattr(bsm, k)
                else:
                    r = bsm.block_mask_build(None).long()
                v = torch.tensor(v, device=r.device)

                self.assertTrue((r == v).all())

            dbsm = bsm.to_dense()
            bsm.check_with_dense(dbsm)

            # Check changed positions
            bs = block_shape
            for b in block_replace:
                block_index = b[2]
                bp = block_positions[block_index]
                block0 = dbsm0[bp[0] * bs[0]:(bp[0] + 1) * bs[0],
                               bp[1] * bs[1]:(bp[1] + 1) * bs[1]]
                block = dbsm[b[0] * bs[0]:(b[0] + 1) * bs[0],
                             b[1] * bs[1]:(b[1] + 1) * bs[1]]

                self.assertTrue((block0 == block).all())

            # Check unchanged positions
            for i, b in enumerate(block_positions):
                if i not in block_replace[:, 2]:
                    bp = b
                    block0 = dbsm0[bp[0] * bs[0]:(bp[0] + 1) * bs[0],
                                   bp[1] * bs[1]:(bp[1] + 1) * bs[1]]
                    block = dbsm[b[0] * bs[0]:(b[0] + 1) * bs[0],
                                 b[1] * bs[1]:(b[1] + 1) * bs[1]]
                    self.assertTrue((block0 == block).all())

            # Check that empty positions are indeed empty
            block_mask = bsm.block_mask_build(None)

            if verbose:
                print(block_mask)

            block_mask = block_mask.repeat_interleave(
                block_shape[0], dim=0).repeat_interleave(block_shape[1],
                                                         dim=1).float()
            self.assertEqual((dbsm * (1 - block_mask)).abs().sum(), 0)

            # Part 2: check multiplication behaviour
            a = torch.randn((1, size[1]), device=bsm.data.device).abs()

            c = bsm.reverse_matmul(a, transpose=True)
            c_0 = a.matmul(dbsm.t())

            # Basic check
            self.assertTrue(torch.isclose(c, c_0).all())

            # Check matmul with sparse support
            b = torch.randn((1, size[0]), device=bsm.data.device).abs()

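            # matmul_with_output_sparse_support should compute b.t() @ a but keep only
            # the entries at existing block positions (checked below via the ones mask)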
            bsm.matmul_with_output_sparse_support(b, a, overwrite_data=True)
            dbsm_back = bsm.to_dense()
            dbsm0_back = b.t().mm(a)
            dbsm0_back = dbsm0_back * bsm.to_dense(
                data_replace=torch.ones_like(bsm.data))

            self.assertTrue(dbsm0_back.isclose(dbsm_back).all())