def test_storage_copy(tmp_path):
    """test the copy function of StorageBase

    A `MemoryStorage` holding two time steps is copied once without an explicit
    target and once into every available storage type; each copy must have the
    same length, times, and stored fields as the source.
    """
    template = ScalarField(UnitGrid([2]))

    # label -> factory creating the copy target; `None` means that no target
    # object is handed to `copy`
    targets = {"None": None, "MemoryStorage": MemoryStorage}
    if module_available("h5py"):
        hdf_path = tmp_path / "test_storage_apply.hdf5"
        targets["FileStorage"] = functools.partial(FileStorage, hdf_path)

    # fill the source storage with two distinguishable time steps
    source = MemoryStorage()
    source.start_writing(template, info={"b": 2})
    for t, values in enumerate(([0, 1], [1, 2])):
        source.append(template.copy(data=np.array(values)), t)
    source.end_writing()

    for label, factory in targets.items():
        target = factory() if factory is not None else None
        duplicate = source.copy(out=target)

        if factory is not None:
            # an explicitly supplied target must be the object that is returned
            assert duplicate is target
        assert len(duplicate) == 2
        np.testing.assert_allclose(duplicate.times, source.times)
        for index in range(2):
            assert duplicate[index] == source[index], label

    # test empty storage
    empty = MemoryStorage()
    assert len(empty.copy()) == 0
def test_field_type_guessing():
    """test the ability to guess the field type

    After `_field` and `info` of a storage have been cleared, item access must
    still return the stored field for single fields (scalar, vector, and rank-2
    tensor). For a stored field collection, item access must raise a
    `RuntimeError` instead.
    """
    # The grid is the same for every case: create it once up front instead of
    # rebuilding it per iteration and relying on the loop variable surviving
    # past the end of the loop for the collection case below.
    grid = UnitGrid([3])

    for cls in [ScalarField, VectorField, Tensor2Field]:
        field = cls.random_normal(grid)
        s = MemoryStorage()
        s.start_writing(field)
        s.append(field, 0)
        s.append(field, 1)

        # delete information
        s._field = None
        s.info = {}

        assert not s.has_collection
        assert len(s) == 2
        assert s[0] == field

    field = FieldCollection([ScalarField(grid), VectorField(grid)])
    s = MemoryStorage()
    s.start_writing(field)
    s.append(field, 0)

    assert s.has_collection

    # delete information
    s._field = None
    s.info = {}

    # item access is evaluated only for the exception it is expected to raise
    with pytest.raises(RuntimeError):
        s[0]
def test_memory_storage():
    """test methods specific to memory storage

    Covers the alternative constructors `MemoryStorage.from_fields` and
    `MemoryStorage.from_collection`.
    """
    sf = ScalarField(UnitGrid([1]))

    def record(values):
        """store `sf` in a new storage, set to each value in turn, at t=0, 1, ..."""
        storage = MemoryStorage()
        storage.start_writing(sf)
        for t, value in enumerate(values):
            sf.data = value
            storage.append(sf, t)
        return storage

    first = record([0, 2])
    second = record([1, 3])

    # test from_fields
    rebuilt = MemoryStorage.from_fields(first.times, [first[i] for i in range(2)])
    assert rebuilt.times == first.times
    np.testing.assert_allclose(rebuilt.data, first.data)

    # test from_collection
    combined = MemoryStorage.from_collection([first, second])
    assert combined.times == first.times
    # flattened, the combined data is expected to read 0, 1, 2, 3, i.e., for
    # each time step the value of `first` followed by the value of `second`
    np.testing.assert_allclose(np.ravel(combined.data), np.arange(4))