def test_jax_multioutput():
    """Compare the JAX and Python results of a graph with two outputs."""
    first = vector("x")
    second = vector("y")

    # Attach test values; they double as the inputs fed to the comparison.
    first.tag.test_value = np.asarray([1.0, 2.0]).astype(config.floatX)
    second.tag.test_value = np.asarray([3.0, 4.0]).astype(config.floatX)

    outputs = [
        cosh(first**2 + second / 3.0),
        cosh(first / 3.0 + second**2),
    ]
    fgraph = FunctionGraph([first, second], outputs)

    input_values = [get_test_value(inp) for inp in fgraph.inputs]
    compare_jax_and_py(fgraph, input_values)
def test_reallocation():
    """Check that the non-lazy, non-C VM shares storage between variables.

    A small scalar graph is compiled with ``VMLinker`` (once with
    ``allow_gc=False`` and once with ``allow_gc=True``) using a mode that
    excludes "fusion" and "inplace".  After one call, at least two variables
    must hold the very same stored object, and the storage map must contain
    fewer distinct storage cells than entries.
    """
    x = scalar("x")
    y = scalar("y")
    z = tanh(3 * x + y) + cosh(x + 5 * y)

    def check_storage(storage_map):
        """Look for a stored value that is shared by two variables.

        Returns ``[True, value]`` for the first non-constant variable whose
        stored object is also held by another variable, and
        ``[False, None]`` when no such pair exists.
        """
        for var in storage_map:
            if isinstance(var, TensorConstant):
                continue
            value = storage_map[var][0]
            # Test against ``None`` explicitly: the truth value of an array
            # depends on its contents (and is ambiguous for more than one
            # element), so ``if value`` is not a reliable "is set" check.
            if value is None:
                continue
            for other in storage_map:
                if other is not var and storage_map[other][0] is value:
                    return [True, storage_map[other][0]]
        return [False, None]

    # The functionality is currently implemented for the non-lazy, non-C VM
    # only.
    linkers = [
        VMLinker(allow_gc=False, lazy=False, use_cloop=False),
        VMLinker(allow_gc=True, lazy=False, use_cloop=False),
    ]
    for linker in linkers:
        m = get_mode(Mode(linker=linker))
        m = m.excluding("fusion", "inplace")

        f = function([x, y], z, name="test_reduce_memory", mode=m)
        output = f(1, 2)
        assert output
        storage_map = f.fn.storage_map

        # Some value object must be reachable through two different variables.
        assert check_storage(storage_map)[0]
        # Shared cells mean fewer distinct cell objects than map entries.
        assert len({id(v) for v in storage_map.values()}) < len(storage_map)
def test_jax_basic():
    """Run a series of basic graphs through ``compare_jax_and_py``.

    The graphs cover an elementwise expression combined with
    ``set_subtensor``/``inc_subtensor`` and slicing, ``clip``, ``diagonal``,
    Cholesky factorization, ``solve``, ``diag``, ``det`` and
    ``matrix_inverse``.
    """
    rng = np.random.default_rng(28494)

    def noisy_identity():
        # 10x10 identity plus small noise drawn from ``rng``.
        noise = rng.standard_normal(size=(10, 10)) * 0.01
        return (np.eye(10) + noise).astype(config.floatX)

    def counting_matrix():
        # 0..99 laid out as a 10x10 matrix.
        return np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)

    def counting_vector():
        # 0..9 as a vector.
        return np.arange(10).astype(config.floatX)

    def check(graph_inputs, graph_output, values):
        # Wrap a single output in a `FunctionGraph` and compare the backends.
        graph = FunctionGraph(graph_inputs, [graph_output])
        return compare_jax_and_py(graph, values)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    updated = aet_subtensor.set_subtensor(z[0], -10.0)
    updated = aet_subtensor.inc_subtensor(updated[0, 1], 2.0)
    sliced = updated[:5, :3]
    sliced_fg = FunctionGraph([x, y], [sliced])

    test_input_vals = [
        np.tile(np.arange(start, start + 10), (10, 1)).astype(config.floatX)
        for start in (0, 10)
    ]
    (jax_res,) = compare_jax_and_py(sliced_fg, test_input_vals)

    # The `Subtensor` slice must yield the requested shape
    assert jax_res.shape == (5, 3)

    # The `IncSubtensor` operations must have set, then incremented, row 0
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    check([x, y], clip(x, y, 5), test_input_vals)

    check([x], aet.diagonal(x, 0), [counting_matrix()])

    check([x], aet_slinalg.cholesky(x), [noisy_identity()])

    # `lower=False` variant; an earlier note here questioned whether this
    # case works yet.
    check([x], aet_slinalg.Cholesky(lower=False)(x), [noisy_identity()])

    check(
        [x, b],
        aet_slinalg.solve(x, b),
        [np.eye(10).astype(config.floatX), counting_vector()],
    )

    check([b], aet.diag(b), [counting_vector()])

    check([x], aet_nlinalg.det(x), [counting_matrix()])

    check([x], aet_nlinalg.matrix_inverse(x), [noisy_identity()])