Ejemplo n.º 1
0
def test_promotion(convert):
    """Exercise ``promote`` as promotion rules and conversion methods are added.

    ``convert`` is a fixture requested by this test; its setup lives elsewhere.
    """
    # With nothing registered, promoting values that share a type is a no-op.
    same_type_cases = [
        (),
        (1,),
        (1.0,),
        (1, 1),
        (1.0, 1.0),
        (1, 1, 1),
        (1.0, 1.0, 1.0),
    ]
    for args in same_type_cases:
        assert promote(*args) == args

    def expect_type_error(*cases):
        # Every argument tuple in ``cases`` must make ``promote`` raise.
        for case in cases:
            with pytest.raises(TypeError):
                promote(*case)

    # Mixing int and float fails while no rule exists...
    expect_type_error((1, 1.0), (1.0, 1))

    # ...and still fails with a promotion rule but no conversion method.
    add_promotion_rule(int, float, float)
    expect_type_error((1, 1.0), (1.0, 1))

    # With an int -> float conversion in place, mixed arguments become floats.
    add_conversion_method(int, float, float)
    assert promote(1, 1.0) == (1.0, 1.0)
    assert promote(1, 1, 1.0) == (1.0, 1.0, 1.0)
    assert promote(1.0, 1.0, 1) == (1.0, 1.0, 1.0)

    # Strings cannot be promoted against numbers yet.
    expect_type_error((1, "1"), ("1", 1), (1.0, "1"), ("1", 1.0))

    add_promotion_rule(str, Union[int, float], Union[int, float])
    add_conversion_method(str, Union[int, float], float)

    for mixed in [(1, "1", "1"), ("1", 1, 1), (1.0, "1", 1), ("1", 1.0, 1)]:
        assert promote(*mixed) == (1.0, 1.0, 1.0)

    # The assertions below expect the str -> float method ("lel") to be used.
    add_promotion_rule(str, int, float)
    add_promotion_rule(str, float, float)
    add_conversion_method(str, float, lambda x: "lel")

    assert promote(1, "1", 1.0) == (1.0, "lel", 1.0)
    assert promote("1", 1, 1.0) == ("lel", 1.0, 1.0)
    assert promote(1.0, "1", 1) == (1.0, "lel", 1.0)
    assert promote("1", 1.0, "1") == ("lel", 1.0, "lel")
def test_promotion(convert):
    """Exercise ``promote`` using set-style type groups such as ``{int, float}``.

    ``convert`` is a fixture requested by this test; its setup lives elsewhere.

    NOTE(review): this redefines ``test_promotion``. If it shares a module with
    another function of the same name, the earlier one is shadowed and never
    collected by pytest. Confirm whether both are meant to coexist.
    """
    # With nothing registered, promoting values that share a type is a no-op.
    same_type_cases = [
        (),
        (1,),
        (1.,),
        (1, 1),
        (1., 1.),
        (1, 1, 1),
        (1., 1., 1.),
    ]
    for args in same_type_cases:
        assert promote(*args) == args

    def expect_type_error(*cases):
        # Every argument tuple in ``cases`` must make ``promote`` raise.
        for case in cases:
            with pytest.raises(TypeError):
                promote(*case)

    # Mixing int and float fails while no rule exists...
    expect_type_error((1, 1.), (1., 1))

    # ...and still fails with a promotion rule but no conversion method.
    add_promotion_rule(int, float, float)
    expect_type_error((1, 1.), (1., 1))

    # With an int -> float conversion in place, mixed arguments become floats.
    add_conversion_method(int, float, float)
    assert promote(1, 1.) == (1., 1.)
    assert promote(1, 1, 1.) == (1., 1., 1.)
    assert promote(1., 1., 1) == (1., 1., 1.)

    # Strings cannot be promoted against numbers yet.
    expect_type_error((1, '1'), ('1', 1), (1., '1'), ('1', 1.))

    add_promotion_rule(str, {int, float}, {int, float})
    add_conversion_method(str, {int, float}, float)

    for mixed in [(1, '1', '1'), ('1', 1, 1), (1., '1', 1), ('1', 1., 1)]:
        assert promote(*mixed) == (1., 1., 1.)

    # The assertions below expect the str -> float method ('lel') to be used.
    add_promotion_rule(str, int, float)
    add_promotion_rule(str, float, float)
    add_conversion_method(str, float, lambda x: 'lel')

    assert promote(1, '1', 1.) == (1., 'lel', 1.)
    assert promote('1', 1, 1.) == ('lel', 1., 1.)
    assert promote(1., '1', 1) == (1., 'lel', 1.)
    assert promote('1', 1., '1') == ('lel', 1., 'lel')
Ejemplo n.º 3
0
def test_inheritance(convert):
    """Check that ``FP`` and ``Re`` values are converted when promoted to ``Num``."""
    # All three pairings promote to Num.
    for left, right in [(Num, FP), (Num, Re), (FP, Re)]:
        add_promotion_rule(left, right, Num)
    add_conversion_method(FP, Num, lambda x: "Num from FP")
    add_conversion_method(Re, Num, lambda x: "Num from Re")

    num = Num()
    assert promote(num, FP()) == (num, "Num from FP")
    assert promote(Re(), num) == ("Num from Re", num)
    assert promote(Re(), FP()) == ("Num from Re", "Num from FP")
Ejemplo n.º 4
0
    return a.rows, a.cols


@B.shape.extend(LowRank)
def shape(a):
    """Return ``(rows, cols)`` from the leading dimensions of the two factors."""
    n_rows = B.shape(a.left)[0]
    n_cols = B.shape(a.right)[0]
    return n_rows, n_cols


@B.shape.extend(Woodbury)
def shape(a):
    """Return the shape of the ``lr`` component of ``a``."""
    component = a.lr
    return B.shape(component)


# Set up promotion and conversion of matrices as a fallback mechanism.

# Mixing a `B.Numeric` with a `Dense` promotes both to `B.Numeric`. A `Dense`
# is turned into a `B.Numeric` by calling `dense` on it.
add_promotion_rule(B.Numeric, Dense, B.Numeric)
add_conversion_method(Dense, B.Numeric, dense)

# Simplify addition and multiplication between matrices.


@mul.extend(Dense, Dense)
def mul(a, b):
    """Multiply the dense forms of ``a`` and ``b`` with ``*`` and wrap the result in ``Dense``."""
    product = dense(a) * dense(b)
    return Dense(product)


@mul.extend(Dense, Diagonal)
def mul(a, b):
    """Multiply a ``Dense`` by a ``Diagonal``; only the diagonal of ``a`` is used.

    The result is a ``Diagonal`` with the same shape as ``a``.
    """
    diagonal = B.diag(a) * b.diag
    return Diagonal(diagonal, *B.shape(a))

Ejemplo n.º 5
0
import lab as B
from plum import add_promotion_rule, conversion_method

from .constant import Zero, Constant
from .lowrank import LowRank
from .matrix import AbstractMatrix, Dense

# Nothing is exported by `from ... import *`; the statements below register
# promotion rules as module-level side effects.
__all__ = []

# Mixing an AbstractMatrix with a B.Numeric, or two AbstractMatrix values,
# promotes to AbstractMatrix.
add_promotion_rule(AbstractMatrix, B.Numeric, AbstractMatrix)
add_promotion_rule(AbstractMatrix, AbstractMatrix, AbstractMatrix)


@conversion_method(B.Numeric, AbstractMatrix)
def convert(x):
    """Convert a numeric object into a matrix type.

    Rank-0 inputs become a 1x1 ``Zero`` when ``x`` is a ``B.Number`` equal to
    zero, and a 1x1 ``Constant`` otherwise. Rank-2 inputs are wrapped in
    ``Dense``.

    Raises:
        RuntimeError: If ``x`` has any rank other than 0 or 2.
    """
    rank = B.rank(x)
    if rank == 2:
        return Dense(x)
    if rank != 0:
        raise RuntimeError(
            f"Cannot convert rank {rank} input to a matrix.")

    # Scalar case: a literal numeric zero gets the dedicated Zero type.
    is_zero_number = isinstance(x, B.Number) and x == 0
    if is_zero_number:
        return Zero(B.dtype(x), 1, 1)
    return Constant(x, 1, 1)


@conversion_method(Constant, LowRank)
def constant_to_lowrank(a):
    """Registered conversion from ``Constant`` to ``LowRank``.

    NOTE(review): the visible body only reads the dtype and shape of ``a`` and
    has no ``return``, so as shown it returns ``None``. The rest of the body
    appears to be missing from this snippet; confirm against the source.
    """
    dtype = B.dtype(a)
    rows, cols = B.shape(a)
Ejemplo n.º 6
0
# `np.sctypes` was removed in NumPy 2.0, and accessing it raises AttributeError.
# In that case, fall back to the sized scalar types that NumPy 1.x listed there,
# so the unions below can be built on either major version.
_np_sctypes = getattr(np, "sctypes", None) or {
    "int": [np.int8, np.int16, np.int32, np.int64],
    "uint": [np.uint8, np.uint16, np.uint32, np.uint64],
    "float": [np.float16, np.float32, np.float64, np.longdouble],
    "complex": [np.complex64, np.complex128, np.clongdouble],
}

# Scalar unions: Python builtins together with the NumPy scalar types.
Int = Union(
    *([int, Dimension] + _np_sctypes["int"] + _np_sctypes["uint"]), alias="Int"
)
Float = Union(*([float] + _np_sctypes["float"]), alias="Float")
Complex = Union(*([complex] + _np_sctypes["complex"]), alias="Complex")
Bool = Union(bool, np.bool_, alias="Bool")
Number = Union(Int, Bool, Float, Complex, alias="Number")

# Array/tensor unions per framework. The `_ag_*`, `_tf_*`, `_torch_*` and
# `_jax_*` names are defined elsewhere (presumably framework tensor types or
# placeholders for them — confirm at their definitions).
NPNumeric = Union(np.ndarray, alias="NPNumeric")
AGNumeric = Union(_ag_tensor, alias="AGNumeric")
TFNumeric = Union(_tf_tensor, _tf_variable, _tf_indexedslices, alias="TFNumeric")
TorchNumeric = Union(_torch_tensor, alias="TorchNumeric")
JAXNumeric = Union(_jax_tensor, _jax_tracer, alias="JAXNumeric")

# Anything numeric: scalars or any supported framework's arrays.
Numeric = Union(
    Number, NPNumeric, AGNumeric, TFNumeric, JAXNumeric, TorchNumeric, alias="Numeric"
)

# Define corresponding promotion rules and conversion methods.
# A NumPy array mixed with a TF, Torch or JAX value promotes to that
# framework's numeric type; a TF tensor mixed with a TF variable gives TFNumeric.
add_promotion_rule(NPNumeric, TFNumeric, TFNumeric)
add_promotion_rule(NPNumeric, TorchNumeric, TorchNumeric)
add_promotion_rule(NPNumeric, JAXNumeric, JAXNumeric)
add_promotion_rule(_tf_tensor, _tf_variable, TFNumeric)
# The conversions go through `_module_call` with a module name and a function
# name. Presumably this resolves e.g. `tensorflow.constant` and calls it on
# `x`; see `_module_call` to confirm.
add_conversion_method(
    NPNumeric, TFNumeric, lambda x: _module_call("tensorflow", "constant", x)
)
add_conversion_method(
    NPNumeric, TorchNumeric, lambda x: _module_call("torch", "tensor", x)
)
add_conversion_method(
    NPNumeric, JAXNumeric, lambda x: _module_call("jax.numpy", "asarray", x)
)

# Data types:
# NumPy data types: either a `type` object or an `np.dtype` instance.
NPDType = Union(type, np.dtype, alias="NPDType")