Exemplo n.º 1
0
    def test_gemm_1(self):
        """Check the GEMM implementations on 1x1 operands.

        For each dtype and each combination of transposition flags,
        ``gemm_dot``, ``pygemm`` and (float32 only) ``sgemm`` must all
        return ``ta @ tb`` as computed by numpy.
        """
        A = numpy.arange(1).reshape((1, 1)) + 1
        B = numpy.arange(1).reshape((1, 1)) + 10
        for dtype in [numpy.float32, numpy.float64, numpy.int64]:
            a = A.astype(dtype)
            b = B.astype(dtype)
            for t1 in [False, True]:
                for t2 in [False, True]:
                    with self.subTest(dtype=dtype,
                                      transA=t1,
                                      transB=t2,
                                      shapeA=a.shape,
                                      shapeB=b.shape):
                        ta = a.T if t1 else a
                        tb = b.T if t2 else b
                        exp = ta @ tb
                        got = gemm_dot(a, b, t1, t2)
                        self.assertEqualArray(exp, got)

                        # exp = ta @ tb with ta (M, K) and tb (K, N);
                        # derived from the operands, not hard-coded.
                        M, K = ta.shape
                        N = tb.shape[1]
                        # Row-major a and b read as column-major
                        # matrices: leading dimension = row length.
                        lda, ldb, ldc = a.shape[1], b.shape[1], N

                        # Row-major C = op(a) @ op(b) is the column-major
                        # C^T = op(b)^T @ op(a)^T, hence swapped operands,
                        # swapped flags and sizes (N, M, K). Assumes
                        # pygemm uses the reference BLAS argument order.
                        # zeros, not empty: beta=0, but an implementation
                        # computing 0 * C could propagate NaN garbage.
                        c = numpy.zeros(M * N, dtype=a.dtype)
                        pygemm(t2, t1, N, M, K, 1., b.ravel(), ldb, a.ravel(),
                               lda, 0., c, ldc)
                        cc = c.reshape((M, N))
                        self.assertEqualArray(exp, cc)

                        if dtype == numpy.float32:
                            # sgemm(alpha, a, b, beta, c, trans_a, trans_b)
                            res = sgemm(1, a, b, 0, cc, t1, t2)
                            self.assertEqualArray(exp, res)
Exemplo n.º 2
0
    def test_gemm_323(self):
        """Check the GEMM implementations on (2, 3) and (3, 2) operands.

        Transposition combinations for which ``ta @ tb`` is undefined
        (numpy raises ``ValueError``) are skipped. For the others, the
        numpy product is the reference for ``pygemm``, ``gemm_dot`` and,
        for float32 only, ``sgemm``.
        """
        A = numpy.arange(6).reshape((2, 3)) + 1
        B = numpy.arange(6).reshape((3, 2)) + 10
        for dtype in [numpy.float32, numpy.float64, numpy.int64]:
            a = A.astype(dtype)
            b = B.astype(dtype)
            for t1 in [False, True]:
                for t2 in [False, True]:
                    with self.subTest(dtype=dtype,
                                      transA=t1,
                                      transB=t2,
                                      shapeA=a.shape,
                                      shapeB=b.shape):
                        ta = a.T if t1 else a
                        tb = b.T if t2 else b
                        try:
                            exp = ta @ tb
                        except ValueError:
                            # Incompatible shapes for this combination.
                            continue

                        # exp = ta @ tb with ta (M, K) and tb (K, N).
                        M, K = ta.shape
                        N = tb.shape[1]
                        # a and b are row-major; read as column-major
                        # matrices their leading dimension is the row
                        # length whether or not they are transposed,
                        # i.e. a.shape[1] (not a.shape[0]) and b.shape[1].
                        lda = a.shape[1]
                        ldb = b.shape[1]
                        ldc = N

                        # Row-major C = op(a) @ op(b) is the column-major
                        # C^T = op(b)^T @ op(a)^T, hence swapped operands,
                        # swapped flags and sizes (N, M, K). Assumes
                        # pygemm uses the reference BLAS argument order.
                        # zeros, not empty: beta=0, but an implementation
                        # computing 0 * C could propagate NaN garbage.
                        c = numpy.zeros(M * N, dtype=a.dtype)
                        pygemm(t2, t1, N, M, K, 1., b.ravel(), ldb, a.ravel(),
                               lda, 0., c, ldc)
                        cc = c.reshape((M, N))
                        self.assertEqualArray(exp, cc)

                        if dtype == numpy.float32:
                            # sgemm(alpha, a, b, beta, c, trans_a, trans_b)
                            res = sgemm(1, a, b, 0, cc, t1, t2)
                            self.assertEqualArray(exp, res)

                            # Same call with overwrite_c=1 (assuming a
                            # scipy-style sgemm signature). cc is C-ordered,
                            # so the wrapper presumably computes into a copy
                            # instead of cc itself: only the returned array
                            # is checked, instead of silently ignoring a
                            # failed in-place comparison.
                            cc[:, :] = 0
                            res = sgemm(1, a, b, 0, cc, t1, t2, 1)
                            self.assertEqualArray(exp, res)

                        got = gemm_dot(a, b, t1, t2)
                        self.assertEqualArray(exp, got)