def test_datatype(): """Tests get and set_datatype""" assert isinstance(pf.get_datatype(), torch.dtype) assert pf.get_datatype() == torch.float32 pf.set_datatype(torch.float64) assert isinstance(pf.get_datatype(), torch.dtype) assert pf.get_datatype() == torch.float64 pf.set_datatype(torch.float32) with pytest.raises(TypeError): pf.set_datatype("lala")
def pytest_runtest_setup(item): pf.set_backend("tensorflow") pf.set_datatype(None)
def pytest_runtest_setup(item): pf.set_backend("pytorch") pf.set_datatype(None)