    def testBlocksparseMatMul(self):
        # Config flags (one, bench, l2norm, out, depth, dtypes, conf, am) are
        # module-level settings shared by the tests in this file.

        # Build a sparse block layout from a scale-free random graph, add the
        # identity so the diagonal blocks are present, and fully connect the
        # first m (hub) nodes.
        n, m = 56 * 8, 8
        layout = networkx.generators.barabasi_albert_graph(n, m)
        # Alternative small-world layout:
        # layout = networkx.generators.random_graphs.watts_strogatz_graph(n, m * 2, 0.5)
        layout = networkx.adjacency_matrix(layout).toarray().astype(
            np.int32) + np.eye(n, dtype=np.int32)
        layout[0:m, 0:m] = 1
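
        # Illustrative sketch (not executed; hypothetical sizes): the same
        # recipe on a tiny graph yields a symmetric 0/1 mask with a populated
        # diagonal:
        #   g  = networkx.generators.barabasi_albert_graph(8, 2)
        #   bl = networkx.adjacency_matrix(g).toarray().astype(np.int32)
        #   bl += np.eye(8, dtype=np.int32)
        #   assert (bl == bl.T).all() and bl.diagonal().all()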

        blocks = layout.sum()
        n = layout.shape[0]
        # Report layout density (%): each nonzero entry marks one block, so
        # density is blocks / n**2.  Also report the largest block column
        # degree.
        print(100 * blocks / n**2)
        print(layout.sum(axis=0).max())

        with self.test_session(config=conf) as sess, tf.device("/gpu:0"):

            # Sweep block sizes and feature axes.
            for bsize, axis in ((32, 1), (32, 0), (16, 0), (8, 0)):

                bsmm = BlocksparseMatMul(layout,
                                         block_size=bsize,
                                         feature_axis=axis,
                                         name="test")

                if one:
                    # All-ones weights make kernel mismatches easy to eyeball.
                    W = np.ones(bsmm.w_shape, dtype=np.float32)
                else:
                    W = np.random.uniform(-1.0, 1.0,
                                          bsmm.w_shape).astype(np.float32)

                # Disabled dense reference: scatter the block weights into a
                # full (C, K) matrix.
                # WW = np.zeros((bsmm.C, bsmm.K), dtype=np.float32)
                # for w, (c, k) in enumerate(bsmm.updat_list):
                #     WW[c*bsize:(c+1)*bsize, k*bsize:(k+1)*bsize] = W[w,:,:]

                w = tf.constant(W)

                # Disabled check that GPU and CPU identity initializers agree:
                # s1 = sess.run( bsmm.identity_init(gpu=True)(bsmm.w_shape) )
                # s2 = bsmm.identity_init(gpu=False)(bsmm.w_shape)
                # print("identity_init: ", (s1 - s2).max())

                for N in (64,):  # other minibatch sizes worth sweeping: 128, 64, 32, 16, 1

                    if one:
                        X = np.ones(bsmm.i_shape(N), dtype=np.float32)
                        E = np.ones(bsmm.o_shape(N), dtype=np.float32)
                    else:
                        X = np.random.uniform(
                            -1.0, 1.0, bsmm.i_shape(N)).astype(np.float32)
                        E = np.random.uniform(
                            -1.0, 1.0, bsmm.o_shape(N)).astype(np.float32)

                    x = tf.constant(X)
                    e = tf.constant(E)

                    for dtF, dtB in dtypes:

                        # dtF/dtB are the forward/backward compute dtypes;
                        # Params counts trainable weights (bsize^2 per block).
                        print("Axis:%d Bsize:%2d N:%d F:%s B:%s Params:%d" %
                              (axis, bsize, N, dtF.name, dtB.name,
                               bsize * bsize * blocks))

                        # compute in tensorflow
                        if l2norm:
                            w2 = bsmm.l2_normalize(w, dtype=dtF)
                        else:
                            w2 = ew.float_cast(w, dtype=dtF)

                        y = ew.float_cast(x, dtype=dtF)

                        for j in range(depth):
                            # When benchmarking, only time the final layer.
                            repeat = bench if bench and j == depth - 1 else 0
                            y = bsmm(y, w2, dw_dtype=dtF, bench=repeat)

                        y = ew.float_cast(y, dtype=tf.float32, dx_dtype=dtB)
                        if bench:
                            sess.run(y)

                        d = tf.gradients(y, [x, w], e, aggregation_method=am)
                        if depth > 1:
                            # Combine the per-layer gradients of the shared
                            # weights.
                            d[1] = group_param_grads(d[1], 8)

                        y, (dx, dw) = sess.run([y, d])
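
                        # y, dx, dw now hold device results; below they are
                        # recomputed in numpy for comparison.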

                        if not bench:
                            # compute in numpy
                            if l2norm:
                                W2 = bsmm.l2_normalize_test(W)
                            else:
                                W2 = W

                            # Disabled dense-reference checks (see WW above):
                            # YY = np.dot(WW.T, X)
                            # ZZ = np.dot(WW  , E)
                            # uu = np.dot( X  , E.T)
                            # UU = np.zeros(bsmm.w_shape, dtype=np.float32)
                            # for w, (c, k) in enumerate(bsmm.updat_list):
                            #     UU[w,:,:] = uu[c*bsize:(c+1)*bsize, k*bsize:(k+1)*bsize]

                            # Forward pass: keep every intermediate activation
                            # so the backward pass can form weight gradients.
                            Ys = [X]
                            for j in range(depth):
                                Ys.append(bsmm.fprop_test(Ys[-1], W2))
                            Y = Ys.pop()

                            # Backward pass: accumulate DW across layers while
                            # propagating DX back from the error signal E.
                            DW = np.zeros(bsmm.w_shape, dtype=np.float32)
                            DX = E
                            for j in range(depth):
                                DW += bsmm.updat_test(Ys.pop(), DX)
                                DX = bsmm.bprop_test(DX, W2)
                            if l2norm:
                                DW = bsmm.l2_normalize_grad_test(W, DW)
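
                            # Schematically, matching the disabled dense
                            # reference above (feature-major):
                            #   Y = W^T X,  DX = W E,  DW = X E^T,
                            # restricted to blocks present in the layout.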

                            for op, cpuA, devA in (
                                    # ("YY:", YY,  y),
                                    # ("ZZ:", ZZ, dx),
                                    # ("UU:", UU, dw),
                                (" y:", Y, y),
                                ("dx:", DX, dx),
                                ("dw:", DW, dw),
                            ):

                                difA = abs(cpuA - devA)

                                # max_err = max|cpu - dev| / mean|cpu|
                                avgval = np.average(abs(cpuA))
                                maxdif = difA.max()
                                max_err = maxdif if avgval == 0 else maxdif / avgval

                                # l2_err = ||cpu - dev|| / ||cpu||; the epsilon
                                # guards against an all-zero reference.
                                l2_err = np.sqrt(
                                    np.square(difA).sum()) / np.sqrt(
                                        np.square(cpuA).sum() + 1e-12)

                                print("%s max_err%%:%11.8f L2_err: %12.10f" %
                                      (op, 100 * max_err, l2_err))

                                # Tolerance assertion, disabled in favor of the
                                # printed error statistics:
                                # rtol = 1e-4 if dtF is tf.float32 else 1e-1
                                # self.assertAllClose(devA, cpuA, rtol=rtol, atol=rtol)
                                if out:
                                    # Dump error/reference/device matrices for
                                    # offline inspection, then stop.
                                    dim = bsmm.K if op == "dw:" else N
                                    np.savetxt("out.txt",
                                               difA.reshape((-1, dim)),
                                               fmt='%5.1f')
                                    np.savetxt("outC.txt",
                                               cpuA.reshape((-1, dim)),
                                               fmt='%5.1f')
                                    np.savetxt("outD.txt",
                                               devA.reshape((-1, dim)),
                                               fmt='%5.1f')
                                    exit()
                            print("")
    def atestBlocksparseMatMulGated(self):

        with self.test_session(config=conf) as sess, tf.device("/gpu:0"):

            # Gated test configuration: K features split into 8-wide blocks.
            N = 128
            K = 8 * 56 * 2 * 4
            n = K // 8
            m = 30
            dtype = tf.bfloat16
            repeat = 10000

            # Same scale-free layout recipe as the ungated test above.
            layout = networkx.generators.barabasi_albert_graph(n, m)
            layout = networkx.adjacency_matrix(layout).toarray().astype(
                np.int32) + np.eye(n, dtype=np.int32)
            layout[0:m, 0:m] = 1

            blocks = layout.sum()
            n = layout.shape[0]
            # Report layout density (%) and the largest block column degree.
            print(100 * blocks / n**2)
            print(layout.sum(axis=0).max())

            # Dense rectangular alternative:
            # layout = np.ones((112, 32), dtype=np.int32)
            bsmm = BlocksparseMatMul(layout,
                                     block_size=8,
                                     feature_axis=0,
                                     name="test")

            if one:
                X = np.ones(bsmm.i_shape(N), dtype=np.float32)
                E = np.ones(bsmm.o_shape(N), dtype=np.float32)
                W = np.ones(bsmm.w_shape, dtype=np.float32)
                G = np.ones(bsmm.blocks, dtype=np.float32)
            else:
                X = np.random.uniform(-1.0, 1.0,
                                      bsmm.i_shape(N)).astype(np.float32)
                E = np.random.uniform(-1.0, 1.0,
                                      bsmm.o_shape(N)).astype(np.float32)
                W = np.random.uniform(-1.0, 1.0,
                                      bsmm.w_shape).astype(np.float32)
                G = np.random.uniform(0.0, 1.0, bsmm.blocks).astype(np.float32)

            # NOTE: overrides the random gates above with all-ones for the
            # current debugging pass; remove to exercise non-trivial gates.
            G = np.ones(bsmm.blocks, dtype=np.float32)
            # Other disabled gate patterns:
            # for w, (c, k) in enumerate(bsmm.updat_list):
            #     G[w] = (c & 1) ^ (k & 1) ^ 1
            # G[::2] = 0.0

            # Disabled gate-grid visualization:
            # block = dict()
            # for w, (c, k) in enumerate(bsmm.updat_list):
            #     block[(c,k)] = w
            # grid = []
            # for c in range(bsmm.CB):
            #     row = []
            #     for k in range(bsmm.KB):
            #         row.append(G[block[(c,k)]])
            #     grid.append(row)
            # for row in grid:
            #     print(row)

            x = tf.constant(X)
            e = tf.constant(E)
            w = tf.constant(W)
            g = tf.constant(G)

            # Cast to the compute dtype, run one gated matmul (benchmarked),
            # then cast back to fp32 for the comparison below.
            w2 = ew.float_cast(w, dtype=dtype)
            y = ew.float_cast(x, dtype=dtype)

            y = bsmm(y, w2, gate=g, bench=repeat)

            y = ew.float_cast(y, dtype=tf.float32, dx_dtype=dtype)

            d = tf.gradients(y, [x, w], e)

            y, (dx, dw) = sess.run([y, d])

            # The GPU kernel doesn't touch zero-gate blocks, so their dw
            # entries would need zeroing before comparison:
            # for b in range(bsmm.blocks):
            #     if G[b] == 0.0:
            #         dw[b,:,:] = 0.0

            # Numpy reference with gating applied.
            Y = bsmm.fprop_test(X, W, gate=G)
            DX = bsmm.bprop_test(E, W, gate=G)
            DW = bsmm.updat_test(X, E, gate=G)
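
            # With the all-ones gate override above, these should match the
            # ungated results exactly.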

            for op, cpuA, devA in (
                (" y:", Y, y),
                ("dx:", DX, dx),
                ("dw:", DW, dw),
            ):

                difA = abs(cpuA - devA)

                # Same metrics as the ungated test: relative max error and
                # relative L2 error.
                avgval = np.average(abs(cpuA))
                maxdif = difA.max()
                max_err = maxdif if avgval == 0 else maxdif / avgval

                l2_err = np.sqrt(np.square(difA).sum()) / np.sqrt(
                    np.square(cpuA).sum() + 1e-12)

                print("%s max_err%%:%11.8f L2_err: %12.10f" %
                      (op, 100 * max_err, l2_err))

                if out:
                    # Dump error/reference/device matrices for inspection,
                    # then stop.
                    dim = K if op == "dw:" else N
                    np.savetxt("out.txt", difA.reshape((-1, dim)), fmt='%5.1f')
                    np.savetxt("outC.txt",
                               cpuA.reshape((-1, dim)),
                               fmt='%5.1f')
                    np.savetxt("outD.txt",
                               devA.reshape((-1, dim)),
                               fmt='%5.1f')
                    exit()