def test_expansion_fail():
    """Check that a DataDict with mismatched inner shapes refuses to expand.

    Every field has two records, but ``a``/``b`` hold two values per record
    while ``x``/``y`` hold three. The test asserts that such data still
    validates, yet is reported as non-expandable and ``expand()`` raises
    ``ValueError``. Presumably the 2-vs-3 mismatch is the reason; confirm
    against ``DataDict.is_expandable``.
    """
    def grid(n_rows, n_cols):
        # A fresh array per call, so no two fields share one buffer.
        return np.arange(n_rows * n_cols).reshape(n_rows, n_cols)

    fields = {
        'a': {'values': grid(2, 2)},
        'b': {'values': grid(2, 2), 'axes': ['a']},
        'x': {'values': grid(2, 3)},
        'y': {'values': grid(2, 3), 'axes': ['x']},
    }
    data = DataDict(**fields)

    assert data.validate()
    assert not data.is_expandable()

    with pytest.raises(ValueError):
        data.expand()
def test_expansion_simple():
    """Test whether simple expansion of nested parameters works.

    The fixture is a 3x3x3 meshgrid (``indexing='ij'``). ``a`` is stored
    unexpanded, with one scalar per record (3 records). ``x``, ``y`` and ``z``
    each carry a 3x3 block per record. After ``expand()`` every field must
    hold ``aaa.size`` flat records equal to the corresponding flattened
    meshgrid array.
    """
    a = np.arange(3)
    x = np.arange(3)
    y = np.arange(7, 10)
    aaa, xxx, yyy = np.meshgrid(a, x, y, indexing='ij')
    zzz = aaa + xxx * yyy

    dd = DataDict(
        a=dict(values=a),
        x=dict(values=xxx),
        y=dict(values=yyy),
        z=dict(values=zzz),
    )
    assert dd.validate()
    assert dd.nrecords() == 3
    assert dd._inner_shapes() == dict(a=tuple(), x=(3, 3), y=(3, 3), z=(3, 3))
    assert dd.is_expandable()
    assert not dd.is_expanded()

    dd2 = dd.expand()
    assert dd2.is_expanded()
    assert dd2.nrecords() == aaa.size
    # Check every field, 'y' included, against its flattened meshgrid array.
    assert np.all(np.isclose(dd2.data_vals('a'), aaa.reshape(-1)))
    assert np.all(np.isclose(dd2.data_vals('x'), xxx.reshape(-1)))
    assert np.all(np.isclose(dd2.data_vals('y'), yyy.reshape(-1)))
    assert np.all(np.isclose(dd2.data_vals('z'), zzz.reshape(-1)))
    assert set(dd2.shapes().values()) == {(aaa.size, )}