예제 #1
0
파일: blas.py 프로젝트: Jackwangyang/Theano
 def perform(self, node, inputs, outputs):
     """Run the GEMM update ``alpha * dot(A, B) + beta * C`` via ``blas.gemm``.

     ``inputs`` is unpacked as ``(C, alpha, A, B, beta)`` and the result is
     stored in ``outputs[0][0]``.  ``overwrite_c`` follows ``self.inplace``,
     except that it is forced to ``False`` when ``C`` is neither F- nor
     C-contiguous (``C.flags.forc`` is false).
     """
     C, alpha, A, B, beta = inputs
     # Requesting in-place output is only kept for an F/C-contiguous C.
     if self.inplace and not C.flags.forc:
         overwrite = False
     else:
         overwrite = self.inplace
     result = blas.gemm(alpha, A, B, beta, C, overwrite_c=overwrite)
     outputs[0][0] = result
예제 #2
0
def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,
         init_res, alpha=1.0, beta=0.0):
    """Check that ``gblas.gemm`` agrees with the host ``fblas`` gemm.

    Operands ``A`` and ``B`` are produced by ``gen_gpuarray`` as
    (host, device) pairs.  Their shapes are swapped when the matching entry
    of ``trans`` is true.  The same product is computed on the host
    (``fblas.sgemm`` for ``dtype == 'float32'``, ``fblas.dgemm`` for any
    other dtype) and on the device (``gblas.gemm``).  The two results are
    then compared with ``numpy.testing.assert_allclose(..., rtol=1e-6)``.

    Parameters
    ----------
    m, n, k : int
        The result is ``(m, n)``; ``k`` is the shared inner dimension.
    dtype : str
        Element type forwarded to ``gen_gpuarray``; also selects the host
        routine.
    order : sequence
        Memory orders for ``A``, ``B`` and (if generated) ``C``.
    trans : sequence
        ``trans[0]`` / ``trans[1]`` are the transpose flags for ``A`` / ``B``.
    offseted_o, sliced
        Forwarded to ``gen_gpuarray`` for ``A`` and ``B`` only.
    overwrite : bool
        Passed as ``overwrite_c`` to both gemm calls.
    init_res : bool
        When true an initial ``(m, n)`` ``C`` is generated; otherwise
        ``None`` is passed as ``C`` to both calls.
    alpha, beta : float
        Scalars forwarded to both gemm calls.
    """
    trans_a, trans_b = trans[0], trans[1]
    shape_a = (k, m) if trans_a else (m, k)
    shape_b = (n, k) if trans_b else (k, n)

    host_a, dev_a = gen_gpuarray(shape_a, dtype, order=order[0],
                                 offseted_outer=offseted_o,
                                 sliced=sliced, ctx=context)
    host_b, dev_b = gen_gpuarray(shape_b, dtype, order=order[1],
                                 offseted_outer=offseted_o,
                                 sliced=sliced, ctx=context)
    host_c = dev_c = None
    if init_res:
        host_c, dev_c = gen_gpuarray((m, n), dtype, order=order[2],
                                     ctx=context)

    host_gemm = fblas.sgemm if dtype == 'float32' else fblas.dgemm
    host_res = host_gemm(alpha, host_a, host_b, beta, host_c,
                         trans_a=trans_a, trans_b=trans_b,
                         overwrite_c=overwrite)
    dev_res = gblas.gemm(alpha, dev_a, dev_b, beta, dev_c,
                         trans_a=trans_a, trans_b=trans_b,
                         overwrite_c=overwrite)

    numpy.testing.assert_allclose(host_res, numpy.asarray(dev_res), rtol=1e-6)
예제 #3
0
 def perform(self, node, inputs, outputs):
     """Store ``blas.gemm(alpha, A, B, beta, C)`` in ``outputs[0][0]``.

     The inputs arrive as ``(C, alpha, A, B, beta)``.  In-place output
     (``overwrite_c``) is requested only when the op is ``inplace`` and
     ``C.flags.forc`` reports that ``C`` is F- or C-contiguous.
     """
     C, alpha, A, B, beta = inputs
     overwrite = self.inplace
     if overwrite:
         # Drop the in-place request for a C that is not F/C-contiguous.
         overwrite = overwrite if C.flags.forc else False
     outputs[0][0] = blas.gemm(alpha, A, B, beta, C, overwrite_c=overwrite)
예제 #4
0
파일: blas.py 프로젝트: yyq90/grammarVAE
    def perform(self, node, inputs, outputs):
        """Compute the matrix product of ``x`` and ``y`` into a new buffer.

        A ``(x.shape[0], y.shape[1])`` array with ``x``'s dtype and context
        is allocated with ``pygpu.empty``.  It is then filled by
        ``blas.gemm`` with ``alpha=1``, ``beta=0`` and ``overwrite_c=True``,
        and the result is stored in ``outputs[0][0]``.  Because ``beta`` is 0,
        the uninitialised buffer contents are presumably not read (standard
        BLAS convention).
        """
        x, y = inputs
        result_shape = (x.shape[0], y.shape[1])
        buf = pygpu.empty(result_shape, dtype=x.dtype, context=x.context)
        outputs[0][0] = blas.gemm(1., x, y, 0., buf, overwrite_c=True)
예제 #5
0
파일: blas.py 프로젝트: datascibox/Theano
    def perform(self, node, inputs, outputs):
        """Store ``dot(x, y)`` in ``outputs[0][0]``.

        The product is written by ``blas.gemm`` (``alpha=1``, ``beta=0``)
        into a freshly allocated ``pygpu.empty`` array of shape
        ``(x.shape[0], y.shape[1])``, using ``x``'s dtype and context.
        """
        x, y = inputs
        alpha, beta = 1., 0.
        rows, cols = x.shape[0], y.shape[1]
        target = pygpu.empty((rows, cols), dtype=x.dtype, context=x.context)
        product = blas.gemm(alpha, x, y, beta, target, overwrite_c=True)
        outputs[0][0] = product
예제 #6
0
def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,
         init_res, alpha=1.0, beta=0.0):
    """Assert that host and GPU gemm produce matching results.

    ``A`` (``(m, k)``, or ``(k, m)`` if ``trans[0]``) and ``B`` (``(k, n)``,
    or ``(n, k)`` if ``trans[1]``) are generated as (host, device) pairs
    with ``gen_gpuarray``.  An ``(m, n)`` ``C`` is generated only when
    ``init_res`` is true; otherwise ``None`` is used.  The host side uses
    ``fblas.sgemm`` for ``'float32'`` and ``fblas.dgemm`` for every other
    dtype; the device side uses ``gblas.gemm``.  Both receive the same
    ``alpha``, ``beta``, transpose flags and ``overwrite_c=overwrite``.
    Results are compared with ``rtol=1e-6``.
    """
    def make_operand(shape, which):
        # Operands A and B share the offset/slicing options.
        return gen_gpuarray(shape, dtype, order=order[which],
                            offseted_outer=offseted_o,
                            sliced=sliced, ctx=context)

    cpu_a, gpu_a = make_operand((k, m) if trans[0] else (m, k), 0)
    cpu_b, gpu_b = make_operand((n, k) if trans[1] else (k, n), 1)
    if init_res:
        cpu_c, gpu_c = gen_gpuarray((m, n), dtype, order=order[2],
                                    ctx=context)
    else:
        cpu_c = gpu_c = None

    gemm_kwargs = dict(trans_a=trans[0], trans_b=trans[1],
                       overwrite_c=overwrite)
    if dtype == 'float32':
        cpu_gemm = fblas.sgemm
    else:
        cpu_gemm = fblas.dgemm
    cpu_out = cpu_gemm(alpha, cpu_a, cpu_b, beta, cpu_c, **gemm_kwargs)
    gpu_out = gblas.gemm(alpha, gpu_a, gpu_b, beta, gpu_c, **gemm_kwargs)

    numpy.testing.assert_allclose(cpu_out, numpy.asarray(gpu_out), rtol=1e-6)
예제 #7
0
 def perform(self, node, inputs, outputs):
     """Run the GEMM update ``alpha * dot(A, B) + beta * C`` via ``blas.gemm``.

     ``inputs`` is unpacked as ``(C, alpha, A, B, beta)`` and the result is
     stored in ``outputs[0][0]``.

     When the op is ``inplace``, ``blas.gemm`` is asked to overwrite ``C``.
     This only happens if ``C`` is F- or C-contiguous (``C.flags.forc``).
     For a strided ``C`` the call falls back to ``overwrite_c=False``, which
     presumably makes pygpu work on a copy instead of writing into a
     non-contiguous output buffer.
     """
     C, alpha, A, B, beta = inputs
     inplace = self.inplace
     # An in-place GEMM output must be contiguous (F or C order); otherwise
     # do not request overwriting C.
     if inplace and not C.flags.forc:
         inplace = False
     outputs[0][0] = blas.gemm(alpha,
                               A,
                               B,
                               beta,
                               C,
                               overwrite_c=inplace)
예제 #8
0
파일: blas.py 프로젝트: SamuelZeng/Theano
 def perform(self, node, inputs, outputs):
     """Store ``blas.gemm(alpha, A, B, beta, C)`` in ``outputs[0][0]``.

     The inputs arrive as ``(C, alpha, A, B, beta)``.  In-place output
     (``overwrite_c=True``) is requested only when the op is ``inplace``
     *and* ``C.flags.forc`` reports that ``C`` is F- or C-contiguous.
     Previously ``self.inplace`` was forwarded unconditionally, even for a
     strided ``C``; such a ``C`` now gets ``overwrite_c=False``.
     """
     C, alpha, A, B, beta = inputs
     inplace = self.inplace
     # Never ask pygpu to overwrite a C that is not F/C-contiguous.
     if inplace and not C.flags.forc:
         inplace = False
     outputs[0][0] = blas.gemm(alpha, A, B, beta, C,
                               overwrite_c=inplace)