def test_jax_Alloc():
    """Check that ``alloc`` and ``AllocEmpty`` graphs agree between JAX and Python.

    Cases covered, in order:
    * ``alloc`` of a constant (0.0, then 1.1) into a static (2, 3) shape;
    * ``AllocEmpty("float32")`` of shape (2, 3), compared on shape/dtype only;
    * ``alloc`` broadcasting a symbolic scalar to length 20;
    * ``alloc`` broadcasting a symbolic length-10 vector to shape (20, 10).
    """

    def same_shape_and_dtype(x, y):
        # Only metadata is compared for AllocEmpty outputs; the element values
        # are presumably unspecified (uninitialized buffer) — TODO confirm.
        (jax_out,) = x
        (py_out,) = y
        return jax_out.shape == py_out.shape and jax_out.dtype == py_out.dtype

    # Constant fill values with a static output shape.
    zeros_out = aet.alloc(0.0, 2, 3)
    zeros_graph = FunctionGraph([], [zeros_out])
    (jax_res,) = compare_jax_and_py(zeros_graph, [])
    assert jax_res.shape == (2, 3)

    filled_out = aet.alloc(1.1, 2, 3)
    filled_graph = FunctionGraph([], [filled_out])
    compare_jax_and_py(filled_graph, [])

    # Empty allocation: compare with the shape/dtype-only predicate.
    empty_out = aet.AllocEmpty("float32")(2, 3)
    empty_graph = FunctionGraph([], [empty_out])
    compare_jax_and_py(empty_graph, [], assert_fn=same_shape_and_dtype)

    # Symbolic scalar fill value broadcast to a 1-d output.
    fill_scalar = scalar("a")
    scalar_out = aet.alloc(fill_scalar, 20)
    scalar_graph = FunctionGraph([fill_scalar], [scalar_out])
    compare_jax_and_py(scalar_graph, [10.0])

    # Symbolic vector fill value broadcast along a new leading axis.
    fill_vector = vector("a")
    vector_out = aet.alloc(fill_vector, 20, 10)
    vector_graph = FunctionGraph([fill_vector], [vector_out])
    compare_jax_and_py(vector_graph, [np.ones(10, dtype=config.floatX)])
def make_c_gemv_destructive(fgraph, node):
    """Replace a non-inplace ``CGemv`` node with its inplace counterpart.

    If the destination input (input 0) is produced by an ``AllocEmpty`` whose
    output is consumed by more than one client, a fresh ``AllocEmpty`` with the
    same dtype and shape inputs is used instead. The inplace op then writes into
    a buffer no other node reads.

    Parameters
    ----------
    fgraph
        The function graph containing ``node``; only ``fgraph.clients`` is read.
    node
        The candidate apply node.

    Returns
    -------
    list or None
        ``[cgemv_inplace(*inputs)]`` when ``node`` is a non-inplace ``CGemv``,
        otherwise ``None`` (no rewrite).
    """
    op = node.op
    # Guard clause: only non-inplace CGemv nodes are rewritten.
    if not isinstance(op, CGemv) or op.inplace:
        return None

    new_inputs = list(node.inputs)
    out_buffer = new_inputs[0]
    producer = out_buffer.owner

    # A shared AllocEmpty output must not be overwritten in place, so give this
    # node its own freshly allocated destination.
    if (
        producer
        and isinstance(producer.op, at.AllocEmpty)
        and len(fgraph.clients[out_buffer]) > 1
    ):
        new_inputs[0] = at.AllocEmpty(out_buffer.dtype)(*producer.inputs)

    return [cgemv_inplace(*new_inputs)]