Exemplo n.º 1
0
 def test_matrix_mul_algo_t_big_odd(self):
     """Compare ``dmul_cython_omp(..., b_trans=1)`` with ``va @ vb.T``.

     Two non-square random float64 matrices (30x41 and 50x41) are
     multiplied, and one sub-test is run for each ``algo`` in 0, 1, 2.
     """
     left = numpy.random.randn(30, 41).astype(numpy.float64)
     right = numpy.random.randn(50, 41).astype(numpy.float64)
     # Reference product computed by numpy on the explicit transpose.
     expected = numpy.matmul(left, right.T)
     for algo in (0, 1, 2):
         with self.subTest(algo=algo):
             got = dmul_cython_omp(left, right, algo=algo, b_trans=1)
             assert_almost_equal(expected, got)
Exemplo n.º 2
0
 def test_matrix_mul_algo_para(self):
     """Compare ``dmul_cython_omp(..., parallel=1)`` with ``va @ vb``.

     Small random float64 matrices (3x4 and 4x5) are multiplied, and
     one sub-test is run for each ``algo`` in 0, 1.
     """
     left = numpy.random.randn(3, 4).astype(numpy.float64)
     right = numpy.random.randn(4, 5).astype(numpy.float64)
     # Reference product computed by numpy.
     expected = numpy.matmul(left, right)
     for algo in (0, 1):
         with self.subTest(algo=algo):
             got = dmul_cython_omp(left, right, algo=algo, parallel=1)
             assert_almost_equal(expected, got)
Exemplo n.º 3
0
 def test_matrix_mul(self):
     """Compare ``dmul_cython_omp`` called with default options to ``va @ vb``.

     Uses random float64 matrices of shapes 3x4 and 4x5.
     """
     left = numpy.random.randn(3, 4).astype(numpy.float64)
     right = numpy.random.randn(4, 5).astype(numpy.float64)
     expected = numpy.matmul(left, right)
     assert_almost_equal(expected, dmul_cython_omp(left, right))
Exemplo n.º 4
0
 def test_matrix_mul_fail(self):
     """Expect ``dmul_cython_omp`` to raise ``RuntimeError`` for ``algo=4``.

     The inputs are valid random float64 matrices (3x4 and 4x5); only
     the ``algo`` value is expected to trigger the error.
     """
     shape_a, shape_b = (3, 4), (4, 5)
     left = numpy.random.randn(*shape_a).astype(numpy.float64)
     right = numpy.random.randn(*shape_b).astype(numpy.float64)
     # Callable form of assertRaises: the call happens inside the check.
     self.assertRaises(RuntimeError, dmul_cython_omp, left, right, algo=4)
Exemplo n.º 5
0
##############################
# Other scenarios
# +++++++++++++++
#
# 3 different algorithms, each of them parallelized.
# See :func:`dmul_cython_omp
# <td3a_cpp.tutorial.mul_cython_omp.dmul_cython_omp>`.

# Benchmark every (algo, parallel) combination on square random float64
# matrices, one size per element of ``sets``. Each combination yields one
# DataFrame appended to ``dfs`` and tagged in its 'fct' column.
# NOTE(review): only algo 0 and 1 are measured here although the text above
# mentions 3 algorithms -- confirm whether algo=2 should be included.
for algo in range(0, 2):
    for parallel in (0, 1):
        print("algo=%d parallel=%d" % (algo, parallel))
        ctxs = [
            dict(va=numpy.random.randn(n, n).astype(numpy.float64),
                 vb=numpy.random.randn(n, n).astype(numpy.float64),
                 # Bind the loop variables as default arguments so each
                 # callable keeps its own configuration even if it is
                 # invoked after ``algo``/``parallel`` have changed
                 # (avoids the late-binding closure pitfall).
                 mul=lambda x, y, algo=algo, parallel=parallel: (
                     dmul_cython_omp(x, y, algo=algo, parallel=parallel)),
                 x_name=n) for n in sets
        ]

        # 'mul(va, vb)' is evaluated by measure_time_dim in each context.
        res = list(measure_time_dim('mul(va, vb)', ctxs, verbose=1))
        dfs.append(DataFrame(res))
        dfs[-1]['fct'] = 'a=%d-p=%d' % (algo, parallel)
        # Show the last two rows of the newest result frame.
        pprint.pprint(dfs[-1].tail(n=2))

########################################
# One remaining issue
# +++++++++++++++++++
#
# Can you find it in :func:`dmul_cython_omp
# <td3a_cpp.tutorial.mul_cython_omp.dmul_cython_omp>`?