def test_load_netcdf(self):
    """Load ``self.NETCDF_VARIABLE`` from ``self.NETCDF_PATH`` with several options.

    Covers the default call, a positive split axis, a negative split axis and
    an explicit dtype. For each, it checks the result type, the global shape
    and the heat dtype. Depending on the case, it also checks the local torch
    dtype, the local shape or the values against ``self.IRIS``.
    Skipped when heat was built without NetCDF support.
    """
    # netcdf support is optional; report a skip instead of a silent pass
    if not ht.io.supports_netcdf():
        self.skipTest("NetCDF support is not available")

    # default parameters
    iris = ht.load_netcdf(self.NETCDF_PATH, self.NETCDF_VARIABLE, device=ht_device)
    self.assertIsInstance(iris, ht.DNDarray)
    self.assertEqual(iris.shape, self.IRIS.shape)
    self.assertEqual(iris.dtype, ht.float32)
    self.assertEqual(iris._DNDarray__array.dtype, torch.float32)
    self.assertTrue((self.IRIS == iris._DNDarray__array).all())

    # positive split axis: dimension 0 may be distributed, dimension 1 is whole
    iris = ht.load_netcdf(
        self.NETCDF_PATH, self.NETCDF_VARIABLE, split=0, device=ht_device
    )
    self.assertIsInstance(iris, ht.DNDarray)
    self.assertEqual(iris.shape, self.IRIS.shape)
    self.assertEqual(iris.dtype, ht.float32)
    lshape = iris.lshape
    self.assertLessEqual(lshape[0], self.IRIS.shape[0])
    self.assertEqual(lshape[1], self.IRIS.shape[1])

    # negative split axis: the last dimension may be distributed.
    # Pass the device explicitly, like every other load in this test, so this
    # case runs on the same device under test.
    iris = ht.load_netcdf(
        self.NETCDF_PATH, self.NETCDF_VARIABLE, split=-1, device=ht_device
    )
    self.assertIsInstance(iris, ht.DNDarray)
    self.assertEqual(iris.shape, self.IRIS.shape)
    self.assertEqual(iris.dtype, ht.float32)
    lshape = iris.lshape
    self.assertEqual(lshape[0], self.IRIS.shape[0])
    self.assertLessEqual(lshape[1], self.IRIS.shape[1])

    # different data type: both the heat dtype and the local torch dtype follow it
    iris = ht.load_netcdf(
        self.NETCDF_PATH, self.NETCDF_VARIABLE, dtype=ht.int8, device=ht_device
    )
    self.assertIsInstance(iris, ht.DNDarray)
    self.assertEqual(iris.shape, self.IRIS.shape)
    self.assertEqual(iris.dtype, ht.int8)
    self.assertEqual(iris._DNDarray__array.dtype, torch.int8)
def test_load_netcdf_exception(self):
    """Check the errors ``ht.load_netcdf`` raises for bad arguments.

    A non-string path, a non-string variable name or a non-integer split
    must raise ``TypeError``. A nonexistent file or variable must raise
    ``IOError``. Skipped when heat was built without NetCDF support.
    """
    # netcdf support is optional; report a skip instead of a silent pass
    if not ht.io.supports_netcdf():
        self.skipTest("NetCDF support is not available")

    # improper argument types
    with self.assertRaises(TypeError):
        ht.load_netcdf(1, "data")
    with self.assertRaises(TypeError):
        ht.load_netcdf("iris.nc", variable=1)
    with self.assertRaises(TypeError):
        ht.load_netcdf("iris.nc", variable="data", split=1.0)

    # file or variable does not exist
    with self.assertRaises(IOError):
        ht.load_netcdf("foo.nc", variable="data")
    # NOTE(review): "iris.nc" is a relative path. If it does not exist in the
    # working directory, the IOError here comes from the missing file rather
    # than the missing variable, and this case duplicates the one above.
    # Consider self.NETCDF_PATH once the exception type for a missing
    # variable in an existing file is confirmed.
    with self.assertRaises(IOError):
        ht.load_netcdf("iris.nc", variable="foo")