def test_conversion(convert):
    """Check `convert` before and after conversion methods are registered.

    Args:
        convert: Conversion callable under test, supplied by the test
            harness (presumably a pytest fixture -- confirm).
    """
    # Basic conversion: the value's own type and `object` are accepted as
    # targets, whereas `int` is rejected for a float.
    for target in (float, object):
        assert convert(1.0, target) == 1.0
    with pytest.raises(TypeError):
        convert(1.0, int)

    # Conversion with inheritance: an `Re` instance converts to `Re` and to
    # `Num`, but not to `FP`.
    real = Re()
    for target in (Re, Num):
        assert convert(real, target) == real
    with pytest.raises(TypeError):
        convert(real, FP)

    # `add_conversion_method`: once a float-to-int method is registered, the
    # previously failing conversion yields that method's result.
    add_conversion_method(float, int, lambda _: 2.0)
    for target, expected in ((float, 1.0), (object, 1.0), (int, 2.0)):
        assert convert(1.0, target) == expected

    # `conversion_method`: the decorator form registers a `Num`-to-`FP` method.
    @conversion_method(Num, FP)
    def num_to_fp(x):
        return 3.0

    for target, expected in ((Re, real), (Num, real), (FP, 3.0)):
        assert convert(real, target) == expected
def test_inheritance(convert):
    """Check that `promote` applies the conversion registered per source type.

    Args:
        convert: Not referenced in the body; presumably requested as a
            fixture for its setup side effects -- confirm.
    """
    from_fp = "Num from FP"
    from_re = "Num from Re"

    # Every pairing among `Num`, `FP` and `Re` promotes to `Num`.
    for first, second in ((Num, FP), (Num, Re), (FP, Re)):
        add_promotion_rule(first, second, Num)

    # Each source type gets its own marker string so the assertions below can
    # tell which conversion was applied.
    add_conversion_method(FP, Num, lambda _: from_fp)
    add_conversion_method(Re, Num, lambda _: from_re)

    num = Num()

    # A `Num` argument is returned as is; the other argument is converted.
    assert promote(num, FP()) == (num, from_fp)
    assert promote(Re(), num) == (from_re, num)

    # With neither argument a `Num`, both are converted.
    assert promote(Re(), FP()) == (from_re, from_fp)
def test_conversion(convert):
    """Check that a declared return type is enforced through conversion.

    Args:
        convert: Not referenced in the body; presumably requested as a
            fixture for its setup side effects -- confirm.
    """
    dispatcher = Dispatcher()

    # Identity function accepting ints and strings, declared to return `int`.
    @dispatcher({int, str}, return_type=int)
    def f(x):
        return x

    # An int satisfies the return type; a string does not.
    assert f(1) == 1
    with pytest.raises(TypeError):
        f('1')

    # After registering `int` as the str-to-int conversion, the string result
    # comes back as an int.
    add_conversion_method(str, int, int)
    for argument in (1, '1'):
        assert f(argument) == 1
def test_conversion(convert):
    """Check that an annotated return type is enforced through conversion.

    Args:
        convert: Not referenced in the body; presumably requested as a
            fixture for its setup side effects -- confirm.
    """
    dispatcher = Dispatcher()

    # Identity function accepting ints and strings, annotated to return `int`.
    @dispatcher
    def f(x: Union[int, str]) -> int:
        return x

    # An int satisfies the return annotation; a string does not.
    assert f(1) == 1
    with pytest.raises(TypeError):
        f("1")

    # After registering `int` as the str-to-int conversion, the string result
    # comes back as an int.
    add_conversion_method(str, int, int)
    for argument in (1, "1"):
        assert f(argument) == 1
def test_promotion(convert):
    """Walk `promote` through successive rule and conversion registrations.

    The registrations are stateful, so the order of the sections below matters.

    Args:
        convert: Not referenced in the body; presumably requested as a
            fixture for its setup side effects -- confirm.
    """

    def fails(*args):
        # Promotion of `args` must be rejected with a `TypeError`.
        with pytest.raises(TypeError):
            promote(*args)

    # No arguments, one argument, or arguments of one type come back unchanged.
    assert promote() == ()
    assert promote(1) == (1,)
    assert promote(1.0) == (1.0,)
    assert promote(1, 1) == (1, 1)
    assert promote(1.0, 1.0) == (1.0, 1.0)
    assert promote(1, 1, 1) == (1, 1, 1)
    assert promote(1.0, 1.0, 1.0) == (1.0, 1.0, 1.0)

    # Mixed ints and floats fail without any rule.
    fails(1, 1.0)
    fails(1.0, 1)

    # A promotion rule alone is not enough; a conversion method is needed too.
    add_promotion_rule(int, float, float)
    fails(1, 1.0)
    fails(1.0, 1)

    add_conversion_method(int, float, float)
    all_floats = (1.0, 1.0, 1.0)
    assert promote(1, 1.0) == (1.0, 1.0)
    assert promote(1, 1, 1.0) == all_floats
    assert promote(1.0, 1.0, 1) == all_floats

    # Strings cannot yet be promoted with numbers.
    fails(1, "1")
    fails("1", 1)
    fails(1.0, "1")
    fails("1", 1.0)

    # Register promotion and conversion of strings against the int/float union.
    add_promotion_rule(str, Union[int, float], Union[int, float])
    add_conversion_method(str, Union[int, float], float)
    assert promote(1, "1", "1") == all_floats
    assert promote("1", 1, 1) == all_floats
    assert promote(1.0, "1", 1) == all_floats
    assert promote("1", 1.0, 1) == all_floats

    # Register rules for str with int and with float individually, plus a
    # str-to-float conversion returning a marker; strings then show up as
    # that marker.
    add_promotion_rule(str, int, float)
    add_promotion_rule(str, float, float)
    add_conversion_method(str, float, lambda _: "lel")
    assert promote(1, "1", 1.0) == (1.0, "lel", 1.0)
    assert promote("1", 1, 1.0) == ("lel", 1.0, 1.0)
    assert promote(1.0, "1", 1) == (1.0, "lel", 1.0)
    assert promote("1", 1.0, "1") == ("lel", 1.0, "lel")
def test_promotion(convert):
    """Walk `promote` through successive rule and conversion registrations.

    The registrations are stateful, so the order of the sections below matters.

    Args:
        convert: Not referenced in the body; presumably requested as a
            fixture for its setup side effects -- confirm.
    """
    # No arguments, one argument, or arguments of one type come back unchanged.
    for same in [(), (1,), (1.,), (1, 1), (1., 1.), (1, 1, 1), (1., 1., 1.)]:
        assert promote(*same) == same

    int_float_mixes = [(1, 1.), (1., 1)]

    # Mixed ints and floats fail without any rule.
    for args in int_float_mixes:
        with pytest.raises(TypeError):
            promote(*args)

    # A promotion rule alone is not enough; a conversion method is needed too.
    add_promotion_rule(int, float, float)
    for args in int_float_mixes:
        with pytest.raises(TypeError):
            promote(*args)

    add_conversion_method(int, float, float)
    for args in [(1, 1.), (1, 1, 1.), (1., 1., 1)]:
        assert promote(*args) == (1.,) * len(args)

    # Strings cannot yet be promoted with numbers.
    for args in [(1, '1'), ('1', 1), (1., '1'), ('1', 1.)]:
        with pytest.raises(TypeError):
            promote(*args)

    # Register promotion and conversion of strings against the int/float set.
    add_promotion_rule(str, {int, float}, {int, float})
    add_conversion_method(str, {int, float}, float)
    for args in [(1, '1', '1'), ('1', 1, 1), (1., '1', 1), ('1', 1., 1)]:
        assert promote(*args) == (1., 1., 1.)

    # Register rules for str with int and with float individually, plus a
    # str-to-float conversion returning a marker; strings then show up as
    # that marker.
    add_promotion_rule(str, int, float)
    add_promotion_rule(str, float, float)
    add_conversion_method(str, float, lambda _: 'lel')
    marker_cases = [
        ((1, '1', 1.), (1., 'lel', 1.)),
        (('1', 1, 1.), ('lel', 1., 1.)),
        ((1., '1', 1), (1., 'lel', 1.)),
        (('1', 1., '1'), ('lel', 1., 'lel')),
    ]
    for args, expected in marker_cases:
        assert promote(*args) == expected
# Shape of a `LowRank`: first dimension of `left` by first dimension of `right`.
@B.shape.extend(LowRank)
def shape(a):
    return B.shape(a.left)[0], B.shape(a.right)[0]


# Shape of a `Woodbury`: delegated to its `lr` attribute.
@B.shape.extend(Woodbury)
def shape(a):
    return B.shape(a.lr)


# Setup promotion and conversion of matrices as a fallback mechanism.
# A `B.Numeric` paired with a `Dense` promotes to `B.Numeric`, and `dense` is
# registered as the conversion from `Dense` to `B.Numeric`.
add_promotion_rule(B.Numeric, Dense, B.Numeric)
add_conversion_method(Dense, B.Numeric, dense)


# Simplify addition and multiplication between matrices.

# Dense times dense: multiply the `dense` forms with `*` and wrap in `Dense`.
@mul.extend(Dense, Dense)
def mul(a, b):
    return Dense(dense(a) * dense(b))


# Dense times diagonal: only `B.diag(a)` and `b.diag` enter the product; the
# result is a `Diagonal` constructed with the shape of `a`.
@mul.extend(Dense, Diagonal)
def mul(a, b):
    return Diagonal(B.diag(a) * b.diag, *B.shape(a))


@mul.extend(Diagonal, Dense)
@abstract(promote=2)
def quantile(a, q, axis: Union[Int, None] = None):
    """Compute quantiles.

    Args:
        a (tensor): Tensor to compute quantiles of.
        q (tensor): Quantiles to compute. Must be numbers in `[0, 1]`.
        axis (int, optional): Axis to compute quantiles along. Defaults to
            `None`.

    Returns:
        tensor: Quantiles.
    """


NPOrNum = Union[NPNumeric, Number]  #: Type NumPy numeric or number.

# Register how each backend's numeric type is converted to `NPOrNum`.
# AutoGrad: read the `_value` attribute.
add_conversion_method(AGNumeric, NPOrNum, lambda x: x._value)
# TensorFlow: call `.numpy()`.
add_conversion_method(TFNumeric, NPOrNum, lambda x: x.numpy())
# PyTorch: detach, move to the CPU, and then call `.numpy()`.
add_conversion_method(TorchNumeric, NPOrNum, lambda x: x.detach().cpu().numpy())
# JAX: pass the value to `np.array`.
add_conversion_method(JAXNumeric, NPOrNum, np.array)


@dispatch
def to_numpy(a):
    """Convert an object to NumPy.

    Args:
        a (object): Object to convert.

    Returns:
        `np.ndarray`: `a` as NumPy.
    """
# Numeric types: one aliased union per backend, plus `Numeric`, which combines
# all of them with `Number`.
NPNumeric = Union(np.ndarray, alias="NPNumeric")
AGNumeric = Union(_ag_tensor, alias="AGNumeric")
TFNumeric = Union(_tf_tensor, _tf_variable, _tf_indexedslices, alias="TFNumeric")
TorchNumeric = Union(_torch_tensor, alias="TorchNumeric")
JAXNumeric = Union(_jax_tensor, _jax_tracer, alias="JAXNumeric")
Numeric = Union(
    Number, NPNumeric, AGNumeric, TFNumeric, JAXNumeric, TorchNumeric, alias="Numeric"
)

# Define corresponding promotion rules and conversion methods.
# A NumPy numeric paired with a TensorFlow, PyTorch or JAX numeric promotes to
# that backend's numeric type.
add_promotion_rule(NPNumeric, TFNumeric, TFNumeric)
add_promotion_rule(NPNumeric, TorchNumeric, TorchNumeric)
add_promotion_rule(NPNumeric, JAXNumeric, JAXNumeric)
# The `_tf_tensor` and `_tf_variable` types promote to the `TFNumeric` union.
add_promotion_rule(_tf_tensor, _tf_variable, TFNumeric)
# Each conversion forwards the NumPy value to a named function of the target
# backend via `_module_call` (presumably resolved lazily at call time --
# confirm against `_module_call`).
add_conversion_method(
    NPNumeric, TFNumeric, lambda x: _module_call("tensorflow", "constant", x)
)
add_conversion_method(
    NPNumeric, TorchNumeric, lambda x: _module_call("torch", "tensor", x)
)
add_conversion_method(
    NPNumeric, JAXNumeric, lambda x: _module_call("jax.numpy", "asarray", x)
)

# Data types:
# `NPDType` accepts both `type` and `np.dtype`; `AGDType` is a union of just
# `NPDType`; `DType` combines the NumPy, TensorFlow, PyTorch and JAX data types.
NPDType = Union(type, np.dtype, alias="NPDType")
AGDType = Union(NPDType, alias="AGDType")
TFDType = Union(_tf_dtype, alias="TFDType")
TorchDType = Union(_torch_dtype, alias="TorchDType")
JAXDType = Union(_jax_dtype, alias="JAXDType")
DType = Union(NPDType, TFDType, TorchDType, JAXDType, alias="DType")