Exemple #1
0
def test_elements_tuple_of_arrays(shapes, dtype, data):
    """Check `_tuple_of_arrays` when elements are sampled from a fixed pool.

    A pool of candidate values is drawn from ``real_from_dtype(dtype)``, the
    strategy under test is restricted to that pool via ``sampled_from``, and
    the drawn result is handed to ``validate_elements`` together with the pool.

    Parameters
    ----------
    shapes
        Shape specification forwarded unchanged to ``gu._tuple_of_arrays``.
    dtype
        Dtype forwarded to ``real_from_dtype``, ``gu._tuple_of_arrays`` and
        ``validate_elements``.
    data
        Object exposing ``draw(strategy)`` (presumably Hypothesis's ``data()``
        fixture -- confirm against the decorator, which is not visible here).
    """
    pool = data.draw(real_from_dtype(dtype))

    strategy = gu._tuple_of_arrays(shapes, dtype, elements=sampled_from(pool))
    arrays = data.draw(strategy)

    validate_elements(arrays, choices=pool, dtype=dtype)
Exemple #2
0
def test_shapes_tuple_of_arrays(shapes, dtype, unique, data):
    """Check that `_tuple_of_arrays` yields one array per requested shape.

    The arguments are passed to ``gu._tuple_of_arrays`` as plain values. The
    drawn result is first passed to ``validate_elements`` and then compared,
    entry by entry, against the requested shapes.

    Parameters
    ----------
    shapes
        Sized iterable of shape specifications; each entry is converted with
        ``tuple`` before being compared to the drawn array's shape.
    dtype
        Anything accepted by ``np.dtype``; used to build the elements strategy.
    unique
        Forwarded to ``gu._tuple_of_arrays`` and ``validate_elements``.
    data
        Object exposing ``draw(strategy)`` (presumably Hypothesis's ``data()``
        fixture -- confirm against the decorator, which is not visible here).
    """
    strategy = gu._tuple_of_arrays(
        shapes,
        dtype,
        elements=from_dtype(np.dtype(dtype)),
        unique=unique,
    )
    arrays = data.draw(strategy)

    validate_elements(arrays, dtype=dtype, unique=unique)

    # One output per requested shape, each with exactly that shape.
    assert len(arrays) == len(shapes)
    for expected, arr in zip(shapes, arrays):
        assert np.shape(arr) == tuple(expected)
Exemple #3
0
def test_just_shapes_tuple_of_arrays(shapes, dtype, unique, data):
    """Check `_tuple_of_arrays` when its arguments are given as strategies.

    Performs the same checks as the plain-value shape test, except that
    ``shapes``, ``dtype`` and ``unique`` are each wrapped in ``just(...)``
    before being passed in, to make sure strategy-valued arguments are handled.

    Parameters
    ----------
    shapes
        Sized iterable of shape specifications; each entry is converted with
        ``tuple`` before being compared to the drawn array's shape.
    dtype
        Anything accepted by ``np.dtype``; used to build the elements strategy.
    unique
        Wrapped in ``just`` for the strategy under test and passed as-is to
        ``validate_elements``.
    data
        Object exposing ``draw(strategy)`` (presumably Hypothesis's ``data()``
        fixture -- confirm against the decorator, which is not visible here).
    """
    # Same scenario as with plain values, but every argument is a strategy.
    strategy = gu._tuple_of_arrays(
        just(shapes),
        just(dtype),
        elements=from_dtype(np.dtype(dtype)),
        unique=just(unique),
    )
    arrays = data.draw(strategy)

    validate_elements(arrays, dtype=dtype, unique=unique)

    # One output per requested shape, each with exactly that shape.
    assert len(arrays) == len(shapes)
    for expected, arr in zip(shapes, arrays):
        assert np.shape(arr) == tuple(expected)
Exemple #4
0
def test_bcast_tuple_of_arrays(args, data):
    """Test broadcasting of the arguments to `_tuple_of_arrays`.

    Kind of crazy, since it uses gufuncs to test itself. There is some
    awkwardness here because there are a lot of corner cases when dealing
    with object types in the numpy extension.

    For completeness, we should probably write a function like this for the
    other functions, but they always just pass dtype, elements, unique to
    `_tuple_of_arrays` anyway, so this should be pretty good.

    Parameters
    ----------
    args
        4-tuple ``(shapes, dtype, elements, unique)`` of array-like objects;
        the code relies on their ``shape``/``size``/``item``/``ravel`` members
        (presumably numpy arrays -- confirm against the decorator, which is
        not visible here).
    data
        Object exposing ``draw(strategy)`` (presumably Hypothesis's ``data()``
        fixture -- confirm against the decorator, which is not visible here).
    """
    shapes, dtype, elements, unique = args

    flat_shapes = shapes.ravel()
    # Drop the trailing axis of the object-valued inputs (needed because of
    # odd behaviour with object arrays).
    dtype = np.squeeze(dtype, -1)
    raw_elements = np.squeeze(elements, -1)

    # Build one elements strategy per (broadcast) dtype entry, or a single
    # strategy when everything is 0-dimensional.
    elements_shape = max(dtype.shape, raw_elements.shape)
    bcast_dtype = np.broadcast_to(dtype, elements_shape)
    if elements_shape != ():
        elements_arg = [from_dtype(entry) for entry in bcast_dtype]
    else:
        elements_arg = from_dtype(bcast_dtype.item())

    # Broadcast the shapes up to the largest shape among all the arguments.
    target_shape = max(
        flat_shapes.shape, dtype.shape, elements_shape, unique.shape
    )
    bcast_shapes = np.broadcast_to(flat_shapes, target_shape)

    strategy = gu._tuple_of_arrays(
        bcast_shapes, dtype, elements=elements_arg, unique=unique
    )
    arrays = data.draw(strategy)

    assert len(arrays) == len(bcast_shapes)
    for expected, arr in zip(bcast_shapes, arrays):
        assert np.shape(arr) == tuple(expected)

    # Validate each array against its own dtype/unique entry, falling back to
    # the single shared value when the argument has at most one entry.
    for idx, arr in enumerate(arrays):
        if dtype.size > 1:
            arr_dtype = dtype[idx]
        else:
            arr_dtype = dtype.item()
        if unique.size > 1:
            arr_unique = unique[idx]
        else:
            arr_unique = unique.item()
        validate_elements([arr], dtype=arr_dtype, unique=arr_unique)