Ejemplo n.º 1
0
 def test2(self):
     """Check that a single-block 32x32 CUDA BlockSparseMatrix from ``zeros`` is hashable."""
     matrix = BlockSparseMatrix.zeros(
         (32, 32), 1, block_shape=(32, 32), device="cuda"
     )
     # Only the absence of an exception matters here; the hash value is unused.
     hash(matrix)
Ejemplo n.º 2
0
    def helper_(self,
                sizes,
                block_size,
                block_count=None,
                blocks=None,
                density=None,
                iterations=1,
                non_contiguous_a=False,
                non_contiguous_b=False):
        """Time ``matmul_with_output_sparse_support`` and cross-check it against dense PyTorch.

        Builds two non-negative random CUDA matrices ``a`` and ``b``, computes
        ``b^T @ a`` densely with ``torch.mm``, and also runs
        ``bsm.matmul_with_output_sparse_support(b, a, overwrite_data=True)``
        on a ``BlockSparseMatrix`` of shape ``(sizes[2], sizes[1])``. The sparse
        result is then compared to the dense one on the entries where the
        densified sparse result is non-zero.

        Args:
            sizes: ``(rows, a_cols, b_cols)``. ``sizes[0]`` may itself be a tuple
                of leading dimensions; they are flattened for the dense matmul.
            block_size: block shape passed to ``BlockSparseMatrix.zeros``.
            block_count: number of blocks for the sparse output.
            blocks: explicit blocks for the sparse output.
            density: fraction of all blocks to use; only consulted when both
                ``block_count`` and ``blocks`` are ``None``.
            iterations: number of runs of each kind inside the timed region.
            non_contiguous_a: give ``a`` a transposed (non-contiguous) layout.
            non_contiguous_b: give ``b`` a transposed (non-contiguous) layout.

        Returns:
            dict mapping each kind ("cutlass", "pytorch") to a dict with keys
            ``kind``, ``elapsed`` (milliseconds, from CUDA events), ``output``
            and ``comparison``.

        Raises:
            ValueError: if ``iterations < 1``, or if none of ``block_count``,
                ``blocks`` and ``density`` is given.
            Exception: if the sparse result does not match the dense reference.
        """
        if iterations < 1:
            # With zero iterations the output `c` would never be assigned below.
            raise ValueError("iterations must be >= 1, got %r" % (iterations, ))

        device = "cuda"

        if isinstance(sizes[0], tuple):
            sizes_0 = sizes[0]
        else:
            sizes_0 = (sizes[0], )

        # Build non-negative matrices so that entries inside stored output blocks
        # are (almost surely) non-zero, which makes the masked check below valid.
        a = torch.randn(sizes_0 + (sizes[1], ), device=device).abs()
        b = torch.randn(sizes_0 + (sizes[2], ), device=device).abs()

        # Same values, but with the last two dims laid out transposed in memory.
        if non_contiguous_a:
            a = a.transpose(-2, -1).contiguous().transpose(-2, -1)

        if non_contiguous_b:
            b = b.transpose(-2, -1).contiguous().transpose(-2, -1)

        if block_count is None and blocks is None:
            if density is None:
                raise ValueError(
                    "one of block_count, blocks or density must be given")
            total_block_count = (sizes[1] * sizes[2] / block_size[0] /
                                 block_size[1])
            block_count = int(total_block_count * density)

        bsm = BlockSparseMatrix.zeros((sizes[2], sizes[1]),
                                      block_count,
                                      blocks,
                                      block_size,
                                      device=device)

        results = {}

        # Order matters only for timing: the sparse kernel runs first.
        kinds = ["cutlass", "pytorch"]
        for kind in kinds:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            for _ in range(iterations):
                if kind == "pytorch":
                    # Flatten leading dims, then compute b^T @ a densely.
                    aa = a.reshape(-1, a.shape[-1])
                    bb = b.reshape(-1, b.shape[-1])
                    c = bb.t().mm(aa)
                elif kind == "cutlass":
                    # Presumably writes the product into bsm's stored blocks in
                    # place (overwrite_data=True) — NOTE(review): confirm.
                    bsm.matmul_with_output_sparse_support(b,
                                                          a,
                                                          overwrite_data=True)
                    c = bsm

            end.record()
            # CUDA events are recorded asynchronously; wait before reading them.
            torch.cuda.synchronize()
            elapsed = start.elapsed_time(end)

            results[kind] = dict(kind=kind, elapsed=elapsed, output=c)

        if "pytorch" in results:
            c0 = results["pytorch"]["output"]

            for k, t in results.items():
                if k == "pytorch":
                    t["comparison"] = True
                    continue
                c = t["output"]

                c_dense = c.to_dense()

                # Compare only the entries the sparse output actually holds.
                c0_ = c0 * (c_dense != 0)

                if not c_dense.isclose(c0_, rtol=1e-4).all().item():
                    print("max difference %s=" % t["kind"],
                          float((c_dense - c0_).abs().max()),
                          float(c.data.abs().max()))
                    raise Exception(
                        "Comparison NOK : matmul_with_output_sparse_support issue for %s"
                        % (k, ))
                t["comparison"] = True

        return results