def test_hdf5_dataset_int32_zlib_start_stop(self):
  """Test case for HDF5Dataset (start/stop API) with zlib compression.

  Renamed from ``test_hdf5_dataset_int32_zlib``: another method in this
  file uses that exact name, and Python silently lets the later class
  attribute shadow the earlier one, so one of the two tests never ran.
  This variant exercises the single-column ``start``/``stop`` constructor
  followed by ``unbatch``; the other exercises the columns/types/shapes
  constructor.
  """
  # Note the file is generated with tdset.h5:
  # with h5py.File('compressed_h5.h5', 'w') as output_f:
  #   output_f.create_dataset(
  #       '/dset1', data=h5f['/dset1'][()], compression='gzip')
  filename = os.path.join(
      os.path.dirname(os.path.abspath(__file__)),
      "test_hdf5", "compressed_h5.h5")
  filename = "file://" + filename
  column = '/dset1'
  dtype = dtypes.int32
  shape = tf.TensorShape([None, 20])
  # unbatch() turns the [None, 20] batches into individual 20-element rows.
  dataset = hdf5_io.HDF5Dataset(
      filename, column, start=0, stop=10, dtype=dtype,
      shape=shape).apply(tf.data.experimental.unbatch())
  iterator = data.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(init_op)
    # Row i of /dset1 holds the consecutive integers [i, i + 20).
    for i in range(10):
      v0 = list([np.asarray([v for v in range(i, i + 20)])])
      vv = sess.run(get_next)
      self.assertAllEqual(v0, [vv])
    # Exactly 10 rows were requested (stop=10); the 11th pull must fail.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def test_hdf5_dataset(self):
  """Test case for HDF5Dataset."""
  path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "test_hdf5", "tdset.h5")
  uri = "file://" + path
  # Read the float64 dataset /dset2 one (1, 20) batch at a time.
  dataset = hdf5_io.HDF5Dataset(
      [uri], ['/dset2'], [dtypes.float64], [(1, 20)], batch=1)
  iterator = data.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(init_op)
    # Row r, column c of /dset2 holds r + 1e-04 * c.
    for row in range(30):
      expected = [np.asarray(
          [[row + 1e-04 * col for col in range(20)]], dtype=np.float64)]
      actual = sess.run(get_next)
      self.assertAllEqual(expected, actual)
    # All 30 rows consumed; the next pull must raise OutOfRangeError.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def test_hdf5_dataset_int32(self):
  """Test case for HDF5Dataset."""
  path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "test_hdf5", "tdset.h5")
  uri = "file://" + path
  # Single-column start/stop API; unbatch() yields one 20-int row at a time.
  dataset = hdf5_io.HDF5Dataset(
      uri, '/dset1',
      start=0, stop=10,
      dtype=dtypes.int32,
      shape=tf.TensorShape([None, 20])).apply(
          tf.data.experimental.unbatch())
  iterator = data.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(init_op)
    # Row r of /dset1 holds the consecutive integers [r, r + 20).
    for row in range(10):
      expected = list([np.asarray([v for v in range(row, row + 20)])])
      actual = sess.run(get_next)
      self.assertAllEqual(expected, [actual])
    # stop=10 bounds the dataset to 10 rows; one more pull must fail.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def test_hdf5_read_dataset():
  """test_hdf5_list_dataset"""
  base = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "test_hdf5", "tdset.h5")
  # Exercise both the plain path and the file:// URI form.
  for uri in [base, "file://" + base]:
    specs = hdf5_io.list_hdf5_datasets(uri)
    assert specs['/dset1'].dtype == tf.int32
    assert specs['/dset1'].shape == tf.TensorShape([10, 20])
    assert specs['/dset2'].dtype == tf.float64
    assert specs['/dset2'].shape == tf.TensorShape([30, 20])

    # Eager read of the full /dset1 tensor.
    p1 = hdf5_io.read_hdf5(uri, specs['/dset1'])
    assert p1.dtype == tf.int32
    assert p1.shape == tf.TensorShape([10, 20])
    # Row r of /dset1 holds the consecutive integers [r, r + 20).
    for row in range(10):
      expected = list([np.asarray([v for v in range(row, row + 20)])])
      assert np.all(p1[row].numpy() == expected)

    # Same data streamed through HDF5Dataset, one row per element.
    dataset = hdf5_io.HDF5Dataset(uri, '/dset1').apply(
        tf.data.experimental.unbatch())
    for row, element in enumerate(dataset):
      expected = list([np.asarray([v for v in range(row, row + 20)])])
      assert np.all(element.numpy() == expected)
def test_hdf5_dataset_int32_zlib(self):
  """Test case for HDF5Dataset with zlib."""
  # Note the file is generated with tdset.h5:
  # with h5py.File('compressed_h5.h5', 'w') as output_f:
  #   output_f.create_dataset(
  #       '/dset1', data=h5f['/dset1'][()], compression='gzip')
  path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)),
      "test_hdf5", "compressed_h5.h5")
  uri = "file://" + path
  # columns/types/shapes constructor over the gzip-compressed dataset.
  dataset = hdf5_io.HDF5Dataset(
      [uri], ['/dset1'], [dtypes.int32], [(1, 20)])
  iterator = data.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(init_op)
    # Row r holds the consecutive integers [r, r + 20); compression must
    # be transparent to the reader.
    for row in range(10):
      expected = list([np.asarray([v for v in range(row, row + 20)])])
      actual = sess.run(get_next)
      self.assertAllEqual(expected, actual)
    # 10 rows total; the next pull must raise OutOfRangeError.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def test_hdf5_invalid_dataset(self):
  """test_hdf5_invalid_dataset"""
  path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "test_hdf5", "tdset.h5")
  uri = "file://" + path
  # Columns that do not exist in the file; construction succeeds, the
  # failure surfaces only when the pipeline is executed.
  dataset = hdf5_io.HDF5Dataset(
      [uri],
      ['/invalid', '/invalid2'],
      [dtypes.int32, dtypes.int32],
      [(1, 20), (1, 30)])
  iterator = data.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(init_op)
    # The first missing column is reported by name.
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        "unable to open dataset /invalid"):
      sess.run(get_next)
def test_hdf5_invalid_dataset_start_stop(self):
  """Test missing-column error for the start/stop HDF5Dataset API.

  Renamed from ``test_hdf5_invalid_dataset``: another method in this file
  uses that exact name, and the later class attribute silently shadows the
  earlier one, so one of the two tests never ran. This variant exercises
  the single-column ``start``/``stop`` constructor; the other exercises
  the columns/types/shapes constructor.
  """
  filename = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "test_hdf5", "tdset.h5")
  filename = "file://" + filename
  # '/invalid' does not exist in tdset.h5; construction succeeds and the
  # failure surfaces only when the pipeline is executed.
  dataset = hdf5_io.HDF5Dataset(
      filename, '/invalid',
      dtype=dtypes.int32,
      shape=tf.TensorShape([1, 20]),
      start=0, stop=10)
  iterator = data.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(init_op)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "unable to open dataset"):
      sess.run(get_next)
def test_hdf5_dataset_binary(self):
  """Test case for HDF5Dataset."""
  path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "test_hdf5", "tbinary.h5")
  uri = "file://" + path
  # Three scalar columns of differing numeric dtypes read in lockstep.
  dataset = hdf5_io.HDF5Dataset(
      [uri],
      ['integer', 'float', 'double'],
      [dtypes.int32, dtypes.float32, dtypes.float64],
      [(1), (1), (1)])
  iterator = data.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(init_op)
    # Each column holds the values 1..6 in its own dtype.
    for value in range(1, 7):
      actual = sess.run(get_next)
      self.assertAllEqual(
          (value, np.float32(value), np.float64(value)), actual)
    # Six rows total; the next pull must raise OutOfRangeError.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)