예제 #1
0
    def test_fmmv_input_device(self, s_A, s_B, v, Adt, Bdt, vo, vdt, kernel,
                               s_expected_fmmv):
        """Exercise ``kernel.mmv`` with every input already placed on ``cuda:0``.

        The two sparse operands are built from ``s_A[0]`` and ``s_B[0]`` and the
        dense operand from ``v``, all created directly on the CUDA device. The
        options are ``self.basic_options`` with ``use_cpu=False``.

        ``_run_fmmv_test`` is invoked twice with the same operands, tolerance
        and options: first with ``out=None``, then with a preallocated output
        tensor living on the same device as the inputs.
        """
        gpu = "cuda:0"
        lhs = fix_sparse_mat(s_A[0], dtype=Adt, device=gpu)
        rhs = fix_sparse_mat(s_B[0], dtype=Bdt, device=gpu)
        vec = fix_mat(v, dtype=vdt, order=vo, copy=True, device=gpu)

        options = dataclasses.replace(self.basic_options, use_cpu=False)
        # Tolerance is selected from the dtype of the first sparse operand.
        tolerance = choose_on_dtype(lhs.dtype)

        def check(out):
            # Single call site so both variants share identical arguments.
            _run_fmmv_test(kernel.mmv,
                           s_expected_fmmv, (lhs, rhs, vec),
                           out=out,
                           rtol=tolerance,
                           opt=options)

        # Variant 1: no output buffer supplied.
        check(None)

        # Variant 2: caller-provided buffer with as many rows as `lhs` and as
        # many columns as `vec`, allocated only after the first variant ran.
        check(torch.empty(lhs.shape[0],
                          vec.shape[1],
                          dtype=lhs.dtype,
                          device=gpu))
예제 #2
0
    def test_dfmmv(self, s_A, s_B, v, w, Adt, Bdt, vo, vdt, wo, wdt, kernel, s_e_dfmmv, cpu):
        """Exercise ``kernel.dmmv`` on sparse operands, with and without ``out``.

        The sparse operands come from ``s_A[0]`` and ``s_B[0]``; ``v`` and ``w``
        are converted with ``fix_mat`` using their requested order and dtype.
        The options are ``self.basic_options`` with ``use_cpu`` set to ``cpu``.

        ``_run_fmmv_test`` is invoked twice with the same operands, tolerance
        and options: first with ``out=None``, then with a preallocated tensor.
        """
        sparse_a = fix_sparse_mat(s_A[0], dtype=Adt)
        sparse_b = fix_sparse_mat(s_B[0], dtype=Bdt)
        dense_v = fix_mat(v, order=vo, dtype=vdt)
        dense_w = fix_mat(w, order=wo, dtype=wdt)
        operands = (sparse_a, sparse_b, dense_v, dense_w)

        options = dataclasses.replace(self.basic_options, use_cpu=cpu)
        # Tolerance is selected from the dtype of the first sparse operand.
        tolerance = choose_on_dtype(sparse_a.dtype)
        shared = {"rtol": tolerance, "opt": options}

        # Variant 1: no output buffer supplied.
        _run_fmmv_test(kernel.dmmv, s_e_dfmmv, operands, out=None, **shared)

        # Variant 2: caller-provided output buffer.
        # NOTE(review): `m` and `t` are not bound inside this method; they are
        # presumably size constants from the enclosing module -- confirm.
        preallocated = torch.empty(m, t, dtype=sparse_a.dtype)
        _run_fmmv_test(kernel.dmmv, s_e_dfmmv, operands, out=preallocated, **shared)