def test_fmmv_input_device(self, s_A, s_B, v, Adt, Bdt, vo, vdt, kernel, s_expected_fmmv):
    """Exercise ``kernel.mmv`` with all inputs created on ``cuda:0``.

    Both sparse operands and the dense vector are built directly on the
    CUDA device and ``use_cpu`` is disabled in the options. The inputs are
    then handed to ``_run_fmmv_test`` together with ``s_expected_fmmv``
    twice: first with ``out=None``, then with a pre-allocated output tensor.
    """
    device = "cuda:0"
    mat_a = fix_sparse_mat(s_A[0], dtype=Adt, device=device)
    mat_b = fix_sparse_mat(s_B[0], dtype=Bdt, device=device)
    vec = fix_mat(v, dtype=vdt, order=vo, copy=True, device=device)
    options = dataclasses.replace(self.basic_options, use_cpu=False)
    tolerance = choose_on_dtype(mat_a.dtype)

    def check(out):
        # Single place for the call so both variants share identical arguments.
        _run_fmmv_test(
            kernel.mmv,
            s_expected_fmmv,
            (mat_a, mat_b, vec),
            out=out,
            rtol=tolerance,
            opt=options,
        )

    # Variant 1: no output buffer supplied.
    check(None)

    # Variant 2: caller-provided output buffer on the same device as the inputs.
    buffer = torch.empty(mat_a.shape[0], vec.shape[1], dtype=mat_a.dtype, device=device)
    check(buffer)
def test_dfmmv(self, s_A, s_B, v, w, Adt, Bdt, vo, vdt, wo, wdt, kernel, s_e_dfmmv, cpu):
    """Exercise ``kernel.dmmv`` on sparse operands with dense ``v`` and ``w``.

    The operands are converted to the requested dtypes (and memory order for
    the dense ones), ``use_cpu`` is set from the ``cpu`` argument, and the
    inputs are handed to ``_run_fmmv_test`` together with ``s_e_dfmmv``
    twice: first with ``out=None``, then with a pre-allocated output tensor.
    """
    mat_a = fix_sparse_mat(s_A[0], dtype=Adt)
    mat_b = fix_sparse_mat(s_B[0], dtype=Bdt)
    vec_v = fix_mat(v, order=vo, dtype=vdt)
    vec_w = fix_mat(w, order=wo, dtype=wdt)
    options = dataclasses.replace(self.basic_options, use_cpu=cpu)
    tolerance = choose_on_dtype(mat_a.dtype)

    def check(out):
        # Single place for the call so both variants share identical arguments.
        _run_fmmv_test(
            kernel.dmmv,
            s_e_dfmmv,
            (mat_a, mat_b, vec_v, vec_w),
            out=out,
            rtol=tolerance,
            opt=options,
        )

    # Variant 1: no output buffer supplied.
    check(None)

    # Variant 2: caller-provided output buffer.
    # NOTE(review): `m` and `t` are not defined in this method; presumably
    # module-level size constants - confirm they match the expected output shape.
    buffer = torch.empty(m, t, dtype=mat_a.dtype)
    check(buffer)