コード例 #1
0
def test_conversion(convert):
    """Exercise `convert` before and after registering conversion methods."""
    # A float converts to itself and to `object`, but not to `int`.
    for target in (float, object):
        assert convert(1.0, target) == 1.0
    with pytest.raises(TypeError):
        convert(1.0, int)

    # An `Re` instance is accepted as `Re` and as `Num`, but not as `FP`.
    re_instance = Re()
    for target in (Re, Num):
        assert convert(re_instance, target) == re_instance
    with pytest.raises(TypeError):
        convert(re_instance, FP)

    # Registering float -> int via `add_conversion_method` leaves the
    # existing conversions untouched.
    add_conversion_method(float, int, lambda _: 2.0)
    for target, expected in ((float, 1.0), (object, 1.0), (int, 2.0)):
        assert convert(1.0, target) == expected

    # Registering Num -> FP via the `conversion_method` decorator.
    @conversion_method(Num, FP)
    def num_to_fp(x):
        return 3.0

    for target, expected in ((Re, re_instance), (Num, re_instance), (FP, 3.0)):
        assert convert(re_instance, target) == expected
コード例 #2
0
def test_inheritance(convert):
    """Check that `promote` applies the conversion methods registered to `Num`."""
    # Every pair among `Num`, `FP` and `Re` promotes to `Num`.
    for left, right in ((Num, FP), (Num, Re), (FP, Re)):
        add_promotion_rule(left, right, Num)

    # The conversions return marker strings that reveal which method ran.
    add_conversion_method(FP, Num, lambda x: "Num from FP")
    add_conversion_method(Re, Num, lambda x: "Num from Re")

    num = Num()
    assert promote(num, FP()) == (num, "Num from FP")
    assert promote(Re(), num) == ("Num from Re", num)
    assert promote(Re(), FP()) == ("Num from Re", "Num from FP")
コード例 #3
0
def test_conversion(convert):
    """Check that `return_type=int` converts a str result once str -> int exists."""
    dispatcher = Dispatcher()

    @dispatcher({int, str}, return_type=int)
    def identity(x):
        return x

    # With no str -> int conversion registered, returning a str fails.
    assert identity(1) == 1
    with pytest.raises(TypeError):
        identity('1')

    add_conversion_method(str, int, int)

    # The str result is now converted to the declared return type.
    for argument in (1, '1'):
        assert identity(argument) == 1
コード例 #4
0
def test_conversion(convert):
    """Check that a `-> int` annotation converts a str result once str -> int exists."""
    dispatcher = Dispatcher()

    @dispatcher
    def identity(x: Union[int, str]) -> int:
        return x

    # With no str -> int conversion registered, returning a str fails.
    assert identity(1) == 1
    with pytest.raises(TypeError):
        identity("1")

    add_conversion_method(str, int, int)

    # The str result is now converted to the annotated return type.
    for argument in (1, "1"):
        assert identity(argument) == 1
コード例 #5
0
def test_promotion(convert):
    """Exercise `promote` as promotion rules and conversion methods are added."""

    def check(*cases):
        # Each case is `(args, expected)`: promoting `args` must give `expected`.
        for args, expected in cases:
            assert promote(*args) == expected

    def check_fails(*cases):
        # Each case is a tuple of arguments that must not be promotable.
        for args in cases:
            with pytest.raises(TypeError):
                promote(*args)

    # Without any rules, same-typed arguments come back unchanged.
    for args in ((), (1, ), (1.0, ), (1, 1), (1.0, 1.0), (1, 1, 1), (1.0, 1.0, 1.0)):
        assert promote(*args) == args
    check_fails((1, 1.0), (1.0, 1))

    # A promotion rule on its own is not enough: a conversion is also needed.
    add_promotion_rule(int, float, float)
    check_fails((1, 1.0), (1.0, 1))

    add_conversion_method(int, float, float)
    check(
        ((1, 1.0), (1.0, 1.0)),
        ((1, 1, 1.0), (1.0, 1.0, 1.0)),
        ((1.0, 1.0, 1), (1.0, 1.0, 1.0)),
    )

    # Strings cannot yet be promoted together with numbers.
    check_fails((1, "1"), ("1", 1), (1.0, "1"), ("1", 1.0))

    # Promote strings with ints or floats, converting strings with `float`.
    add_promotion_rule(str, Union[int, float], Union[int, float])
    add_conversion_method(str, Union[int, float], float)
    check(
        ((1, "1", "1"), (1.0, 1.0, 1.0)),
        (("1", 1, 1), (1.0, 1.0, 1.0)),
        ((1.0, "1", 1), (1.0, 1.0, 1.0)),
        (("1", 1.0, 1), (1.0, 1.0, 1.0)),
    )

    # The more specific str -> float method takes over for strings.
    add_promotion_rule(str, int, float)
    add_promotion_rule(str, float, float)
    add_conversion_method(str, float, lambda x: "lel")
    check(
        ((1, "1", 1.0), (1.0, "lel", 1.0)),
        (("1", 1, 1.0), ("lel", 1.0, 1.0)),
        ((1.0, "1", 1), (1.0, "lel", 1.0)),
        (("1", 1.0, "1"), ("lel", 1.0, "lel")),
    )
コード例 #6
0
def test_promotion(convert):
    """Exercise `promote` (set-based type API) as rules and conversions are added."""

    def expect(*cases):
        # Each case pairs the arguments with the promoted result.
        for arguments, result in cases:
            assert promote(*arguments) == result

    def expect_type_error(*cases):
        # Each case lists arguments whose promotion must raise `TypeError`.
        for arguments in cases:
            with pytest.raises(TypeError):
                promote(*arguments)

    # Without any rules, same-typed arguments come back unchanged.
    for arguments in ((), (1, ), (1., ), (1, 1), (1., 1.), (1, 1, 1), (1., 1., 1.)):
        assert promote(*arguments) == arguments
    expect_type_error((1, 1.), (1., 1))

    # A promotion rule on its own is not enough: a conversion is also needed.
    add_promotion_rule(int, float, float)
    expect_type_error((1, 1.), (1., 1))

    add_conversion_method(int, float, float)
    expect(
        ((1, 1.), (1., 1.)),
        ((1, 1, 1.), (1., 1., 1.)),
        ((1., 1., 1), (1., 1., 1.)),
    )

    # Strings cannot yet be promoted together with numbers.
    expect_type_error((1, '1'), ('1', 1), (1., '1'), ('1', 1.))

    # Promote strings with ints or floats, converting strings with `float`.
    add_promotion_rule(str, {int, float}, {int, float})
    add_conversion_method(str, {int, float}, float)
    expect(
        ((1, '1', '1'), (1., 1., 1.)),
        (('1', 1, 1), (1., 1., 1.)),
        ((1., '1', 1), (1., 1., 1.)),
        (('1', 1., 1), (1., 1., 1.)),
    )

    # The more specific str -> float method takes over for strings.
    add_promotion_rule(str, int, float)
    add_promotion_rule(str, float, float)
    add_conversion_method(str, float, lambda x: 'lel')
    expect(
        ((1, '1', 1.), (1., 'lel', 1.)),
        (('1', 1, 1.), ('lel', 1., 1.)),
        ((1., '1', 1), (1., 'lel', 1.)),
        (('1', 1., '1'), ('lel', 1., 'lel')),
    )
コード例 #7
0
ファイル: matrix.py プロジェクト: wubizhi/stheno

@B.shape.extend(LowRank)
def shape(a):
    """Shape of a low-rank matrix.

    The row count is the leading dimension of `a.left`, and the column count
    is the leading dimension of `a.right`. Presumably this is because the
    matrix is `left @ right.T`; confirm against `LowRank`.
    """
    num_rows = B.shape(a.left)[0]
    num_cols = B.shape(a.right)[0]
    return num_rows, num_cols


@B.shape.extend(Woodbury)
def shape(a):
    """Shape of a Woodbury matrix, taken to be the shape of its `lr` component."""
    low_rank_part = a.lr
    return B.shape(low_rank_part)


# Set up promotion and conversion of matrices as a fallback mechanism: when a
# `Dense` matrix is promoted together with a `B.Numeric`, both become
# `B.Numeric`, and the `Dense` side is converted by calling `dense` on it.

add_promotion_rule(B.Numeric, Dense, B.Numeric)
add_conversion_method(Dense, B.Numeric, dense)

# Simplify addition and multiplication between matrices.


@mul.extend(Dense, Dense)
def mul(a, b):
    """Multiply two `Dense` matrices with `*` on their `dense` forms, giving a `Dense`."""
    left, right = dense(a), dense(b)
    return Dense(left * right)


@mul.extend(Dense, Diagonal)
def mul(a, b):
    """Multiply a `Dense` matrix by a `Diagonal` one.

    The result is a `Diagonal` built from the product of `a`'s diagonal and
    `b.diag`, with the same shape as `a`.
    """
    diagonal = B.diag(a) * b.diag
    return Diagonal(diagonal, *B.shape(a))


@mul.extend(Diagonal, Dense)
コード例 #8
0
ファイル: generic.py プロジェクト: wesselb/lab
@abstract(promote=2)
def quantile(a, q, axis: Union[Int, None] = None):
    """Compute quantiles.

    Args:
        a (tensor): Tensor to compute quantiles of.
        q (tensor): Quantiles to compute. Must be numbers in `[0, 1]`.
        axis (int, optional): Axis to compute quantiles along. Defaults to `None`.

    Returns:
        tensor: Quantiles.
    """
    # NOTE(review): the body is intentionally empty. Presumably `@abstract`
    # supplies the behaviour: it promotes the first two arguments (`a` and `q`,
    # per `promote=2`) and redispatches to backend-specific implementations.
    # Confirm against the definition of `abstract`.


NPOrNum = Union[NPNumeric, Number]  #: Type NumPy numeric or number.
# Conversion methods from each framework's numeric type to `NPOrNum`:
# - autograd: read the wrapped `_value` attribute;
# - TensorFlow: call `.numpy()` (presumably requires eager tensors; confirm);
# - PyTorch: detach from the autograd graph and move to the CPU before `.numpy()`;
# - JAX: copy into a NumPy array with `np.array`.
add_conversion_method(AGNumeric, NPOrNum, lambda x: x._value)
add_conversion_method(TFNumeric, NPOrNum, lambda x: x.numpy())
add_conversion_method(TorchNumeric, NPOrNum, lambda x: x.detach().cpu().numpy())
add_conversion_method(JAXNumeric, NPOrNum, np.array)


@dispatch
def to_numpy(a):
    """Convert an object to NumPy.

    Args:
        a (object): Object to convert.

    Returns:
        `np.ndarray`: `a` as NumPy.
    """
    # NOTE(review): as shown, this body has no `return` statement, so this
    # untyped overload returns `None`. That contradicts the "Returns" section
    # above. The snippet looks truncated here; confirm against the full source
    # before relying on it (other `@dispatch` overloads may handle specific
    # types).
コード例 #9
0
# Numeric types per backend, as aliased unions.
# NOTE(review): `_ag_tensor`, `_tf_tensor`, `_torch_tensor`, `_jax_tensor`, etc.
# are defined outside this view; presumably they are the frameworks' tensor
# types. Confirm.
NPNumeric = Union(np.ndarray, alias="NPNumeric")
AGNumeric = Union(_ag_tensor, alias="AGNumeric")
TFNumeric = Union(_tf_tensor, _tf_variable, _tf_indexedslices, alias="TFNumeric")
TorchNumeric = Union(_torch_tensor, alias="TorchNumeric")
JAXNumeric = Union(_jax_tensor, _jax_tracer, alias="JAXNumeric")
# Any supported numeric: plain numbers plus every backend's numeric type.
Numeric = Union(
    Number, NPNumeric, AGNumeric, TFNumeric, JAXNumeric, TorchNumeric, alias="Numeric"
)

# Define corresponding promotion rules and conversion methods.
# A NumPy array mixed with a TensorFlow, PyTorch or JAX numeric promotes to that
# framework's type.
add_promotion_rule(NPNumeric, TFNumeric, TFNumeric)
add_promotion_rule(NPNumeric, TorchNumeric, TorchNumeric)
add_promotion_rule(NPNumeric, JAXNumeric, JAXNumeric)
# A TensorFlow tensor mixed with a TensorFlow variable promotes to `TFNumeric`.
add_promotion_rule(_tf_tensor, _tf_variable, TFNumeric)
# Each conversion builds the framework object from the NumPy array by calling
# the named function through `_module_call`. Presumably this imports the module
# lazily and calls `module.function(x)`; confirm.
add_conversion_method(
    NPNumeric, TFNumeric, lambda x: _module_call("tensorflow", "constant", x)
)
add_conversion_method(
    NPNumeric, TorchNumeric, lambda x: _module_call("torch", "tensor", x)
)
add_conversion_method(
    NPNumeric, JAXNumeric, lambda x: _module_call("jax.numpy", "asarray", x)
)

# Data types:
# Dtype types per backend; autograd reuses the NumPy dtype types.
NPDType = Union(type, np.dtype, alias="NPDType")
AGDType = Union(NPDType, alias="AGDType")
TFDType = Union(_tf_dtype, alias="TFDType")
TorchDType = Union(_torch_dtype, alias="TorchDType")
JAXDType = Union(_jax_dtype, alias="JAXDType")
DType = Union(NPDType, TFDType, TorchDType, JAXDType, alias="DType")