def test_get_dtype_from_model_func():
    """Check ``get_dtype_from_model_func`` on fitted, unfitted and plain callables.

    A fitted model's ``predict`` should report the dtype the model was fitted
    with (float32 or float64). An unfitted model and an arbitrary Python
    function should both yield ``None``.
    """
    X, y = make_regression(n_samples=81, n_features=10, noise=0.1,
                           random_state=42, dtype=np.float32)

    # checking model with float32 dtype
    model_f32 = reg().fit(X, y)
    assert get_dtype_from_model_func(model_f32.predict) == np.float32

    # checking model with float64 dtype
    X = X.astype(np.float64)
    y = y.astype(np.float64)
    model_f64 = reg().fit(X, y)
    assert get_dtype_from_model_func(model_f64.predict) == np.float64

    # checking model that has not been fitted yet
    model_not_fit = reg()
    assert get_dtype_from_model_func(model_not_fit.predict) is None

    # checking arbitrary function. Deliberately not named ``dummy_func``: other
    # tests in this file use a module-level ``dummy_func`` with different
    # behavior, and shadowing it here only invites confusion.
    def _double(x):
        return x + x

    assert get_dtype_from_model_func(_double) is None
def test_model_func_call_gpu():
    """Exercise ``model_func_call`` with GPU and non-GPU model functions.

    The result is expected to be a CuPy array for a cuML regressor's
    ``predict``, for ``dummy_func`` called on a NumPy copy of the data with
    ``gpu_model=False``, and for a cuML PCA's ``transform``. Passing
    ``dummy_func`` on the device data with ``gpu_model=True`` must raise
    ``TypeError``.
    """
    features, target = make_regression(n_samples=81, n_features=10,
                                       noise=0.1, random_state=42,
                                       dtype=np.float32)

    # cuML regressor flagged as a GPU model.
    regressor = reg().fit(features, target)
    predicted = model_func_call(X=features, model_func=regressor.predict,
                                gpu_model=True)
    assert isinstance(predicted, cp.ndarray)

    # Non-GPU path: host copy of the data, output still expected on device.
    host_features = cp.asnumpy(features)
    host_result = model_func_call(X=host_features, model_func=dummy_func,
                                  gpu_model=False)
    assert isinstance(host_result, cp.ndarray)

    # Claiming dummy_func is a GPU model is rejected.
    with pytest.raises(TypeError):
        model_func_call(X=features, model_func=dummy_func, gpu_model=True)

    # A transformer (not only predictors) also goes through the GPU path.
    pca = PCA(n_components=10).fit(features)
    projected = model_func_call(X=features, model_func=pca.transform,
                                gpu_model=True)
    assert isinstance(projected, cp.ndarray)
def test_get_gpu_tag_from_model_func():
    """Check tag lookup through ``get_tag_from_model_func``.

    A cuML regressor's ``predict`` exposes ``preferred_input_order == 'F'``
    and an ``X_types_gpu`` list containing ``'2darray'``. For ``dummy_func``
    and for the ``skreg`` estimator, the supplied defaults are returned.
    """
    # test getting the gpu tags from the model that we use in explainers
    cuml_estimator = reg()

    input_order = get_tag_from_model_func(func=cuml_estimator.predict,
                                          tag='preferred_input_order',
                                          default='C')
    assert input_order == 'F'

    gpu_types = get_tag_from_model_func(func=cuml_estimator.predict,
                                        tag='X_types_gpu',
                                        default=False)
    assert isinstance(gpu_types, list)
    assert '2darray' in gpu_types

    # checking arbitrary function: both lookups fall back to the defaults
    fallback_order = get_tag_from_model_func(func=dummy_func,
                                             tag='preferred_input_order',
                                             default='C')
    assert fallback_order == 'C'

    fallback_types = get_tag_from_model_func(func=dummy_func,
                                             tag='X_types_gpu',
                                             default=False)
    assert fallback_types is False

    # non-cuML estimator: no GPU input types reported
    sk_estimator = skreg()
    sk_types = get_tag_from_model_func(func=sk_estimator.predict,
                                       tag='X_types_gpu',
                                       default=False)
    assert sk_types is False