def test_hdf_target_float_dense():
    """
    Round-trip dense float32 "data" and "classes" streams through HDFDatasetWriter and HDFDataset.

    A single-sequence StaticDataset is dumped to the file returned by ``_get_tmp_file``.
    The dtype, dim and shape reported by the HDFDataset opened on that file, and the dtype
    of the arrays it returns for sequence 0, must agree with the source dataset.
    """
    from GeneratingDataset import StaticDataset

    keys = ("data", "classes")
    source = StaticDataset([{
        "data": numpy.array([[1, 2, 3], [2, 3, 4]], dtype="float32"),
        "classes": numpy.array([[-1, 5], [-2, 4], [-3, 2]], dtype="float32"),
    }])

    # Reference values taken from the in-memory dataset before anything is written.
    src_dtype = {key: source.get_data_dtype(key) for key in keys}
    assert src_dtype["data"] == "float32"
    assert src_dtype["classes"] == "float32"
    src_shape = {key: source.get_data_shape(key) for key in keys}
    assert src_shape["data"] == [3]
    assert src_shape["classes"] == [2]

    # Dump the dataset to HDF.
    path = _get_tmp_file(suffix=".hdf")
    writer = HDFDatasetWriter(filename=path)
    writer.dump_from_dataset(source, use_progress_bar=False)
    writer.close()

    # Open the written file again and compare the reported metadata.
    loaded = HDFDataset(files=[path])
    loaded.initialize()
    loaded.init_seq_order(epoch=1)

    loaded_dtype = {key: loaded.get_data_dtype(key) for key in keys}
    for key in keys:
        assert loaded_dtype[key] == src_dtype[key]

    loaded_dim = {key: loaded.get_data_dim(key) for key in keys}
    for key in keys:
        # The dim is expected to equal the last entry of the source shape.
        assert loaded_dim[key] == src_shape[key][-1]

    loaded_shape = {key: loaded.get_data_shape(key) for key in keys}
    for key in keys:
        assert loaded_shape[key] == src_shape[key]

    # The arrays actually returned for the first sequence must keep the dtype as well.
    loaded.load_seqs(0, 1)
    arrays = {key: loaded.get_data(0, key) for key in keys}
    for key in keys:
        assert arrays[key].dtype == src_dtype[key]
def test_hdf_data_short_int_dtype():
    """
    Round-trip uint8 "data" and int16 "classes" streams through HDFDatasetWriter and HDFDataset.

    The source StaticDataset is created with ``output_dim={"data": (255, 1), "classes": (10, 1)}``.
    After dumping to the file returned by ``_get_tmp_file`` and opening it as an HDFDataset,
    the dtypes must be unchanged, the dims must be 255 and 10, the shapes must be ``[]``,
    and the arrays returned for sequence 0 must keep the source dtypes.
    """
    from GeneratingDataset import StaticDataset

    keys = ("data", "classes")
    sequence = {
        "data": numpy.array([1, 2, 3], dtype="uint8"),
        "classes": numpy.array([-1, 5], dtype="int16"),
    }
    source = StaticDataset([sequence], output_dim={"data": (255, 1), "classes": (10, 1)})

    # Reference dtypes from the in-memory dataset.
    src_dtype = {key: source.get_data_dtype(key) for key in keys}
    assert src_dtype["data"] == "uint8"
    assert src_dtype["classes"] == "int16"

    # Dump the dataset to HDF.
    path = _get_tmp_file(suffix=".hdf")
    writer = HDFDatasetWriter(filename=path)
    writer.dump_from_dataset(source, use_progress_bar=False)
    writer.close()

    # Open the written file again and compare the reported metadata.
    loaded = HDFDataset(files=[path])
    loaded.initialize()
    loaded.init_seq_order(epoch=1)

    loaded_dtype = {key: loaded.get_data_dtype(key) for key in keys}
    for key in keys:
        assert loaded_dtype[key] == src_dtype[key]

    loaded_dim = {key: loaded.get_data_dim(key) for key in keys}
    assert loaded_dim["data"] == 255
    assert loaded_dim["classes"] == 10

    loaded_shape = {key: loaded.get_data_shape(key) for key in keys}
    assert loaded_shape["data"] == []
    assert loaded_shape["classes"] == []

    # The arrays actually returned for the first sequence must keep the short int dtypes.
    loaded.load_seqs(0, 1)
    arrays = {key: loaded.get_data(0, key) for key in keys}
    for key in keys:
        assert arrays[key].dtype == src_dtype[key]
def test_hdf_data_target_int32():
    """
    Check that int32 "classes" values survive a round trip through HDFDatasetWriter and HDFDataset.

    The targets are 2147483647, 2147483646 and 2147483645, i.e. the largest int32 value and
    the two below it. The test asserts that the HDFDataset opened on the written file reports
    the same dtype, reports shape ``[]`` for "classes", and returns an array for sequence 0
    with the same dtype, shape and values as the source sequence.
    """
    from GeneratingDataset import StaticDataset

    dataset = StaticDataset(
        [{
            "data": numpy.array([1, 2, 3], dtype="uint8"),
            "classes": numpy.array([2147483647, 2147483646, 2147483645], dtype="int32"),
        }],
        output_dim={"data": (255, 1), "classes": (10, 1)})
    dataset.initialize()
    dataset.init_seq_order(epoch=0)
    dataset.load_seqs(0, 1)

    # Sanity-check the source sequence before writing anything.
    orig_classes_dtype = dataset.get_data_dtype("classes")
    orig_classes_seq = dataset.get_data(0, "classes")
    assert orig_classes_seq.shape == (3,) and orig_classes_seq[0] == 2147483647
    assert orig_classes_seq.dtype == orig_classes_dtype == "int32"

    # Dump the dataset to HDF.
    hdf_fn = _get_tmp_file(suffix=".hdf")
    hdf_writer = HDFDatasetWriter(filename=hdf_fn)
    hdf_writer.dump_from_dataset(dataset, use_progress_bar=False)
    hdf_writer.close()

    # Open the written file again and compare metadata and content.
    hdf_dataset = HDFDataset(files=[hdf_fn])
    hdf_dataset.initialize()
    hdf_dataset.init_seq_order(epoch=1)

    hdf_classes_dtype = hdf_dataset.get_data_dtype("classes")
    assert hdf_classes_dtype == orig_classes_dtype
    hdf_classes_shape = hdf_dataset.get_data_shape("classes")
    assert hdf_classes_shape == []

    hdf_dataset.load_seqs(0, 1)
    hdf_data_classes = hdf_dataset.get_data(0, "classes")
    assert hdf_data_classes.dtype == orig_classes_dtype
    # numpy.array_equal requires identical shapes as well as identical values. An elementwise
    # `==` would broadcast silently, and on a shape mismatch it would fail with an unrelated
    # error and not with a plain assertion failure.
    assert numpy.array_equal(hdf_data_classes, orig_classes_seq)