def test_elements_tuple_of_arrays(shapes, dtype, data):
    """Check that drawn arrays only use values from an explicit pool.

    A pool of candidate values is drawn from ``real_from_dtype(dtype)`` and
    wrapped in ``sampled_from``. That strategy is passed as the ``elements``
    argument of ``gu._tuple_of_arrays``. The drawn tuple of arrays is then
    checked by ``validate_elements`` against the same pool and dtype.
    """
    pool = data.draw(real_from_dtype(dtype))
    arrays = data.draw(
        gu._tuple_of_arrays(shapes, dtype, elements=sampled_from(pool))
    )
    validate_elements(arrays, choices=pool, dtype=dtype)
def test_shapes_tuple_of_arrays(shapes, dtype, unique, data):
    """Check that ``_tuple_of_arrays`` yields one array per requested shape.

    The arguments are passed as plain (non-strategy) values. Elements come
    from ``from_dtype(np.dtype(dtype))``. The drawn arrays are checked with
    ``validate_elements`` for ``dtype`` and ``unique``. There must be exactly
    ``len(shapes)`` arrays, and each must match its requested shape.
    """
    strategy = gu._tuple_of_arrays(
        shapes, dtype, elements=from_dtype(np.dtype(dtype)), unique=unique
    )
    arrays = data.draw(strategy)
    validate_elements(arrays, dtype=dtype, unique=unique)

    assert len(arrays) == len(shapes)
    # With lengths known equal, a list comparison checks every shape pairwise.
    assert [np.shape(arr) for arr in arrays] == [tuple(req) for req in shapes]
def test_just_shapes_tuple_of_arrays(shapes, dtype, unique, data):
    """Repeat the shape check with every argument passed as a strategy.

    ``shapes``, ``dtype`` and ``unique`` are wrapped in ``just(...)``. This
    ensures ``_tuple_of_arrays`` also accepts strategies, not only concrete
    values. The assertions use the unwrapped values: elements are validated
    for ``dtype`` and ``unique``, there must be exactly ``len(shapes)``
    arrays, and each must match its requested shape.
    """
    strategy = gu._tuple_of_arrays(
        just(shapes),
        just(dtype),
        elements=from_dtype(np.dtype(dtype)),
        unique=just(unique),
    )
    arrays = data.draw(strategy)
    validate_elements(arrays, dtype=dtype, unique=unique)

    assert len(arrays) == len(shapes)
    # With lengths known equal, a list comparison checks every shape pairwise.
    assert [np.shape(arr) for arr in arrays] == [tuple(req) for req in shapes]
def test_bcast_tuple_of_arrays(args, data):
    """Test broadcasting of the arguments to ``gu._tuple_of_arrays``.

    This is kind of crazy, since it uses gufuncs to test itself. There is
    some awkwardness here because there are a lot of corner cases when
    dealing with object types in the numpy extension. For completeness, we
    should probably write a test like this for the other functions. However,
    they always just pass ``dtype``, ``elements`` and ``unique`` through to
    ``_tuple_of_arrays`` anyway, so this should give pretty good coverage.

    Parameters
    ----------
    args : tuple
        ``(shapes, dtype, elements, unique)``. Each is used below as a numpy
        array (``.ravel()``, ``.shape``, ``.size`` and ``.item()`` are called
        on them). NOTE(review): presumably drawn by a gufunc-argument
        strategy in the test decorator -- TODO confirm.
    data :
        Hypothesis ``data()`` object, used for the interactive draw.
    """
    shapes, dtype, elements, unique = args
    shapes = shapes.ravel()
    # Need to squeeze out the trailing axis due to weird behaviour of object
    # arrays. np.squeeze(..., -1) raises unless the last axis has length 1,
    # so dtype/elements are assumed to carry a trailing unit axis here.
    dtype = np.squeeze(dtype, -1)
    elements = np.squeeze(elements, -1)
    # Only the *shape* of the drawn `elements` array is used. The actual
    # element strategies are rebuilt from the broadcast dtypes below.
    # NOTE(review): max() compares shape tuples lexicographically, which
    # picks the broadcast shape only if both are 0-d/1-d and compatible --
    # TODO confirm the upstream strategy guarantees this.
    elements_shape = max(dtype.shape, elements.shape)
    dtype_ = np.broadcast_to(dtype, elements_shape)
    if elements_shape == ():
        # Scalar case: a single element strategy shared by every array.
        elements = from_dtype(dtype_.item())
    else:
        # One element strategy per entry of the broadcast dtype array.
        elements = [from_dtype(dd) for dd in dtype_]
    # Broadcast the (raveled, hence 1-d) shapes array to the largest argument
    # shape, so that len(shapes) is the number of arrays to expect.
    shapes_shape = max(shapes.shape, dtype.shape, elements_shape, unique.shape)
    shapes = np.broadcast_to(shapes, shapes_shape)
    S = gu._tuple_of_arrays(shapes, dtype, elements=elements, unique=unique)
    X = data.draw(S)
    # One drawn array per shape, each with exactly the requested shape.
    assert len(shapes) == len(X)
    for spec, drawn in zip(shapes, X):
        assert tuple(spec) == np.shape(drawn)
    # A size-1 dtype/unique applies to every array (broadcast). Otherwise
    # index out the entry that belongs to array ii.
    for ii, xx in enumerate(X):
        dd = dtype[ii] if dtype.size > 1 else dtype.item()
        uu = unique[ii] if unique.size > 1 else unique.item()
        validate_elements([xx], dtype=dd, unique=uu)