def test_promotion(convert):
    """Exercise `promote` while promotion rules and conversion methods are registered.

    Registrations made in this test affect every later `promote` call in it, so the
    checks are order-dependent. Registering a promotion rule alone is not enough for
    mixed arguments; promotion only succeeds once a conversion method is also added.
    """

    def expect(*cases):
        # Each case is ``(args, expected)``; ``promote(*args)`` must equal ``expected``.
        for args, expected in cases:
            assert promote(*args) == expected

    def expect_type_error(*cases):
        # Each case is a tuple of arguments for which ``promote`` must raise TypeError.
        for args in cases:
            with pytest.raises(TypeError):
                promote(*args)

    # Initially, same-type arguments come back unchanged and mixed int/float raise.
    expect(
        ((), ()),
        ((1,), (1,)),
        ((1.0,), (1.0,)),
        ((1, 1), (1, 1)),
        ((1.0, 1.0), (1.0, 1.0)),
        ((1, 1, 1), (1, 1, 1)),
        ((1.0, 1.0, 1.0), (1.0, 1.0, 1.0)),
    )
    expect_type_error((1, 1.0), (1.0, 1))

    # A promotion rule without a matching conversion method still fails.
    add_promotion_rule(int, float, float)
    expect_type_error((1, 1.0), (1.0, 1))

    # With the int -> float conversion registered, ints are promoted to floats.
    add_conversion_method(int, float, float)
    expect(
        ((1, 1.0), (1.0, 1.0)),
        ((1, 1, 1.0), (1.0, 1.0, 1.0)),
        ((1.0, 1.0, 1), (1.0, 1.0, 1.0)),
    )
    expect_type_error((1, "1"), ("1", 1), (1.0, "1"), ("1", 1.0))

    # A union-typed rule and conversion let strings mix with ints and floats.
    add_promotion_rule(str, Union[int, float], Union[int, float])
    add_conversion_method(str, Union[int, float], float)
    expect(
        ((1, "1", "1"), (1.0, 1.0, 1.0)),
        (("1", 1, 1), (1.0, 1.0, 1.0)),
        ((1.0, "1", 1), (1.0, 1.0, 1.0)),
        (("1", 1.0, 1), (1.0, 1.0, 1.0)),
    )

    # After str/int and str/float rules plus a str -> float conversion that
    # returns "lel", strings are converted by that new method.
    add_promotion_rule(str, int, float)
    add_promotion_rule(str, float, float)
    add_conversion_method(str, float, lambda x: "lel")
    expect(
        ((1, "1", 1.0), (1.0, "lel", 1.0)),
        (("1", 1, 1.0), ("lel", 1.0, 1.0)),
        ((1.0, "1", 1), (1.0, "lel", 1.0)),
        (("1", 1.0, "1"), ("lel", 1.0, "lel")),
    )
def test_promotion(convert):
    """Exercise `promote` while promotion rules and conversion methods are registered.

    Registrations made in this test affect every later `promote` call in it, so the
    checks are order-dependent. Registering a promotion rule alone is not enough for
    mixed arguments; promotion only succeeds once a conversion method is also added.
    Type unions are given as sets, e.g. ``{int, float}``.
    """

    def expect(*cases):
        # Each case is ``(args, expected)``; ``promote(*args)`` must equal ``expected``.
        for args, expected in cases:
            assert promote(*args) == expected

    def expect_type_error(*cases):
        # Each case is a tuple of arguments for which ``promote`` must raise TypeError.
        for args in cases:
            with pytest.raises(TypeError):
                promote(*args)

    # Initially, same-type arguments come back unchanged and mixed int/float raise.
    expect(
        ((), ()),
        ((1,), (1,)),
        ((1.,), (1.,)),
        ((1, 1), (1, 1)),
        ((1., 1.), (1., 1.)),
        ((1, 1, 1), (1, 1, 1)),
        ((1., 1., 1.), (1., 1., 1.)),
    )
    expect_type_error((1, 1.), (1., 1))

    # A promotion rule without a matching conversion method still fails.
    add_promotion_rule(int, float, float)
    expect_type_error((1, 1.), (1., 1))

    # With the int -> float conversion registered, ints are promoted to floats.
    add_conversion_method(int, float, float)
    expect(
        ((1, 1.), (1., 1.)),
        ((1, 1, 1.), (1., 1., 1.)),
        ((1., 1., 1), (1., 1., 1.)),
    )
    expect_type_error((1, '1'), ('1', 1), (1., '1'), ('1', 1.))

    # A set-typed rule and conversion let strings mix with ints and floats.
    add_promotion_rule(str, {int, float}, {int, float})
    add_conversion_method(str, {int, float}, float)
    expect(
        ((1, '1', '1'), (1., 1., 1.)),
        (('1', 1, 1), (1., 1., 1.)),
        ((1., '1', 1), (1., 1., 1.)),
        (('1', 1., 1), (1., 1., 1.)),
    )

    # After str/int and str/float rules plus a str -> float conversion that
    # returns 'lel', strings are converted by that new method.
    add_promotion_rule(str, int, float)
    add_promotion_rule(str, float, float)
    add_conversion_method(str, float, lambda x: 'lel')
    expect(
        ((1, '1', 1.), (1., 'lel', 1.)),
        (('1', 1, 1.), ('lel', 1., 1.)),
        ((1., '1', 1), (1., 'lel', 1.)),
        (('1', 1., '1'), ('lel', 1., 'lel')),
    )
def test_inheritance(convert):
    """Check promotion between `Num`, `FP` and `Re`, all of which promote to `Num`.

    `FP` and `Re` values are converted to `Num` by methods that return marker
    strings, so each assert can tell which conversion ran. The `Num` instance is
    expected back as-is (compared by equality).
    """
    for first, second in ((Num, FP), (Num, Re), (FP, Re)):
        add_promotion_rule(first, second, Num)

    add_conversion_method(FP, Num, lambda x: "Num from FP")
    add_conversion_method(Re, Num, lambda x: "Num from Re")

    n = Num()
    assert promote(n, FP()) == (n, "Num from FP")
    assert promote(Re(), n) == ("Num from Re", n)
    assert promote(Re(), FP()) == ("Num from Re", "Num from FP")
return a.rows, a.cols


@B.shape.extend(LowRank)
def shape(a):
    """Shape of a `LowRank` matrix.

    The first dimension is taken from `a.left` and the second from the first
    dimension of `a.right` (presumably the matrix is `left @ right^T` -- confirm).
    """
    return B.shape(a.left)[0], B.shape(a.right)[0]


@B.shape.extend(Woodbury)
def shape(a):
    """Shape of a `Woodbury` matrix, delegated to its `lr` component."""
    return B.shape(a.lr)


# Set up promotion and conversion of matrices as a fallback mechanism: a `Dense`
# combined with any `B.Numeric` promotes to `B.Numeric`, and `Dense` is
# converted to `B.Numeric` with `dense`.
add_promotion_rule(B.Numeric, Dense, B.Numeric)
add_conversion_method(Dense, B.Numeric, dense)


# Simplify addition and multiplication between matrices.
@mul.extend(Dense, Dense)
def mul(a, b):
    """Multiply two `Dense` matrices via their dense representations.

    NOTE(review): `*` on the `dense(...)` results is presumably element-wise --
    confirm against the backend semantics.
    """
    return Dense(dense(a) * dense(b))


@mul.extend(Dense, Diagonal)
def mul(a, b):
    """Multiply a `Dense` by a `Diagonal`; the result is a `Diagonal`.

    Its diagonal is `B.diag(a) * b.diag`, and `B.shape(a)` is passed as the
    remaining constructor arguments (presumably rows and cols).
    """
    return Diagonal(B.diag(a) * b.diag, *B.shape(a))
import lab as B
from plum import add_promotion_rule, conversion_method

from .constant import Zero, Constant
from .lowrank import LowRank
from .matrix import AbstractMatrix, Dense

__all__ = []

# A matrix combined with a numeric, or with another matrix, promotes to
# `AbstractMatrix`.
add_promotion_rule(AbstractMatrix, B.Numeric, AbstractMatrix)
add_promotion_rule(AbstractMatrix, AbstractMatrix, AbstractMatrix)


@conversion_method(B.Numeric, AbstractMatrix)
def convert(x):
    """Convert a numeric `x` to a matrix.

    Rank-0 inputs become `Zero(B.dtype(x), 1, 1)` when `x` is a `B.Number`
    equal to zero, and `Constant(x, 1, 1)` otherwise (presumably 1 x 1
    matrices). Rank-2 inputs are wrapped in `Dense`.

    Raises:
        RuntimeError: If `x` has any rank other than 0 or 2.
    """
    if B.rank(x) == 0:
        # Only genuine scalar numbers are checked for zero; other rank-0
        # inputs always become a `Constant`.
        if isinstance(x, B.Number) and x == 0:
            return Zero(B.dtype(x), 1, 1)
        else:
            return Constant(x, 1, 1)
    elif B.rank(x) == 2:
        return Dense(x)
    else:
        raise RuntimeError(
            f"Cannot convert rank {B.rank(x)} input to a matrix.")


@conversion_method(Constant, LowRank)
def constant_to_lowrank(a):
    """Convert a `Constant` matrix to a `LowRank` matrix."""
    dtype = B.dtype(a)
    rows, cols = B.shape(a)
def _numpy_scalar_types(kind):
    """Return the list of NumPy scalar types of one `kind`.

    Args:
        kind (str): One of ``"int"``, ``"uint"``, ``"float"`` or ``"complex"``.

    Returns:
        list: NumPy scalar types of that kind, in increasing size.

    `np.sctypes` is used when it exists (NumPy < 2.0). It was removed in
    NumPy 2.0, where accessing it raises `AttributeError`; in that case the
    equivalent list is built from the explicitly named scalar types.
    """
    try:
        return list(np.sctypes[kind])
    except AttributeError:
        pass
    names = {
        "int": ("int8", "int16", "int32", "int64", "intp"),
        "uint": ("uint8", "uint16", "uint32", "uint64", "uintp"),
        "float": ("float16", "float32", "float64", "longdouble"),
        "complex": ("complex64", "complex128", "clongdouble"),
    }[kind]
    types = []
    for name in names:
        scalar_type = getattr(np, name, None)
        # Platform aliases (e.g. `intp` is `int64` on most 64-bit systems) would
        # otherwise appear twice.
        if scalar_type is not None and scalar_type not in types:
            types.append(scalar_type)
    return types


Int = Union(
    *([int, Dimension] + _numpy_scalar_types("int") + _numpy_scalar_types("uint")),
    alias="Int",
)
Float = Union(*([float] + _numpy_scalar_types("float")), alias="Float")
Complex = Union(*([complex] + _numpy_scalar_types("complex")), alias="Complex")
Bool = Union(bool, np.bool_, alias="Bool")
Number = Union(Int, Bool, Float, Complex, alias="Number")

NPNumeric = Union(np.ndarray, alias="NPNumeric")
AGNumeric = Union(_ag_tensor, alias="AGNumeric")
TFNumeric = Union(_tf_tensor, _tf_variable, _tf_indexedslices, alias="TFNumeric")
TorchNumeric = Union(_torch_tensor, alias="TorchNumeric")
JAXNumeric = Union(_jax_tensor, _jax_tracer, alias="JAXNumeric")
Numeric = Union(
    Number, NPNumeric, AGNumeric, TFNumeric, JAXNumeric, TorchNumeric, alias="Numeric"
)

# Define corresponding promotion rules and conversion methods: NumPy arrays
# mixed with a framework type promote to that framework's type.
add_promotion_rule(NPNumeric, TFNumeric, TFNumeric)
add_promotion_rule(NPNumeric, TorchNumeric, TorchNumeric)
add_promotion_rule(NPNumeric, JAXNumeric, JAXNumeric)
add_promotion_rule(_tf_tensor, _tf_variable, TFNumeric)
add_conversion_method(
    NPNumeric, TFNumeric, lambda x: _module_call("tensorflow", "constant", x)
)
add_conversion_method(
    NPNumeric, TorchNumeric, lambda x: _module_call("torch", "tensor", x)
)
add_conversion_method(
    NPNumeric, JAXNumeric, lambda x: _module_call("jax.numpy", "asarray", x)
)

# Data types:
NPDType = Union(type, np.dtype, alias="NPDType")