def test_match_same_illegal(self):
    """A pattern constraint can reject matches whose two inputs are the same variable."""
    x, y, z = inputs()
    e = op2(op1(x, x), op1(x, y))
    g = FunctionGraph([x, y, z], [e])

    def constraint(r):
        # Only replace `Op1` nodes whose two inputs are distinct variables;
        # per the final assertion, `Op1(x, x)` is left alone and only
        # `Op1(x, y)` becomes `Op3(x, y)`.
        return r.owner.inputs[0] is not r.owner.inputs[1]

    PatternOptimizer(
        {"pattern": (op1, "x", "y"), "constraint": constraint},
        (op3, "x", "y"),
    ).optimize(g)
    assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_jax_shape_ops():
    """Compare the JAX and Python backends on `Shape`, `Shape_i` and `SpecifyShape`."""
    x_np = np.zeros((20, 3))

    def check(out, **kwargs):
        # Wrap a constant-only graph and compare both backends on it
        graph = FunctionGraph([], [out])
        compare_jax_and_py(graph, [], **kwargs)

    check(Shape()(aet.as_tensor_variable(x_np)), must_be_device_array=False)
    check(Shape_i(1)(aet.as_tensor_variable(x_np)), must_be_device_array=False)
    check(SpecifyShape()(aet.as_tensor_variable(x_np), (20, 3)))

    with config.change_flags(compute_test_value="off"):
        # A shape that does not match the (20, 3) constant is expected to
        # make the comparison raise an `AssertionError`.
        wrong_shape = SpecifyShape()(aet.as_tensor_variable(x_np), (2, 3))
        wrong_fg = FunctionGraph([], [wrong_shape])
        with pytest.raises(AssertionError):
            compare_jax_and_py(wrong_fg, [])
def test_constraints(self):
    """A constraint on a sub-pattern restricts which `Op1` nodes get rewritten."""
    x, y, z = inputs()
    e = op4(op1(op2(x, y)), op1(op1(x, y)))
    g = FunctionGraph([x, y, z], [e])

    def input_is_op2(r):
        # Only replace when the matched input was produced by `op2`
        return r.owner.op == op2

    in_pattern = (op1, {"pattern": "1", "constraint": input_is_op2})
    out_pattern = (op3, "1")
    PatternOptimizer(in_pattern, out_pattern).optimize(g)

    assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
def test_nnet():
    """Compare the JAX and Python backends on several activation functions of a vector."""
    x = vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)

    activations = (
        sigmoid,
        aet_nnet.ultra_fast_sigmoid,
        softplus,
        aet_nnet.softmax,
        aet_nnet.logsoftmax,
    )
    for activation in activations:
        fgraph = FunctionGraph([x], [activation(x)])
        test_vals = [get_test_value(inp) for inp in fgraph.inputs]
        compare_jax_and_py(fgraph, test_vals)
def test_extra_ops_omni():
    """Compare the JAX and Python backends on extra ops that are given constant inputs."""

    def compare(outputs, **kwargs):
        # Build an input-less graph and compare both backends on it
        fgraph = FunctionGraph([], outputs)
        compare_jax_and_py(
            fgraph, [get_test_value(i) for i in fgraph.inputs], **kwargs
        )

    # `bartlett` cannot take symbolic input here, so it is given a constant.
    c = aet.as_tensor(5)
    compare([aet_extra_ops.bartlett(c)])

    # `np.prod` replaces `np.product`, which is deprecated and removed in NumPy 2.0.
    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    compare(
        [aet_extra_ops.ravel_multi_index(multi_index, (3, 4))],
        must_be_device_array=False,
    )

    # NOTE(review): the inputs are "concrete", yet this case reportedly still
    # had problems; keep it under test.
    out = aet_extra_ops.Unique()(
        aet.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
    )
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
def test_contains(self):
    """Variables and apply nodes of a graph are reported as members via `in`."""
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)

    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    for member in (var1, var3, var3.owner, var5, var5.owner):
        assert member in fg
def test_jax_Join():
    """Compare the JAX and Python backends on `join` along axis 0 and axis 1."""
    a = matrix("a")
    b = matrix("b")

    def to_floatX(arr):
        return arr.astype(config.floatX)

    cases = [
        (
            0,
            [
                (np.c_[[1.0, 2.0, 3.0]], np.c_[[4.0, 5.0, 6.0]]),
                # Inputs may differ in length along the joined axis
                (np.c_[[1.0, 2.0, 3.0]], np.c_[[4.0, 5.0]]),
            ],
        ),
        (
            1,
            [
                (np.c_[[1.0, 2.0, 3.0]], np.c_[[4.0, 5.0, 6.0]]),
                (np.c_[[1.0, 2.0], [3.0, 4.0]], np.c_[[5.0, 6.0]]),
            ],
        ),
    ]

    for axis, value_pairs in cases:
        x_fg = FunctionGraph([a, b], [aet.join(axis, a, b)])
        for a_val, b_val in value_pairs:
            compare_jax_and_py(x_fg, [to_floatX(a_val), to_floatX(b_val)])
def test_jax_FunctionGraph_once():
    """Make sure that an output is only computed once when it's referenced multiple times."""
    from aesara.link.jax.dispatch import jax_funcify

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        def __init__(self):
            super().__init__()
            # Incremented each time the JAX implementation of this instance runs
            self.called = 0

        def make_node(self, *args):
            return Apply(self, list(args), [inp.type() for inp in args])

        def perform(self, node, inputs, outputs):
            # Identity op: forward each input unchanged to the matching output.
            # (`Op.perform` receives the `Apply` node as its first argument.)
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp

    @jax_funcify.register(TestOp)
    def jax_funcify_TestOp(op, **kwargs):
        def func(*args, op=op):
            op.called += 1
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    # `q + r` is referenced twice, but each `TestOp` should still run only once
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_jx = jax_funcify(out_fg)

    x_val = np.r_[1, 2].astype(config.floatX)
    y_val = np.r_[2, 3].astype(config.floatX)

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 1
    assert op2.called == 1

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 2
    assert op2.called == 2
def test_local_optimizer():
    """`local_optimizer` validates `tracks` and only transforms nodes whose op is tracked."""
    # An empty `tracks` list is rejected
    with pytest.raises(ValueError):

        @local_optimizer([])
        def local_bad_1(fgraph, node):
            return node.outputs

    # `tracks` accepts `Op` instances and `Op` subclasses (see below); `None` is rejected
    with pytest.raises(TypeError):

        @local_optimizer([None])
        def local_bad_2(fgraph, node):
            return node.outputs

    x = MyVariable("x")
    y = MyVariable("y")

    o1 = op1(x, y)

    class MyNewOp(MyOp):
        pass

    o2 = MyNewOp("MyNewOp")(x, y)

    class MyNewOp2(MyOp):
        pass

    o3 = MyNewOp2("MyNewOp2")(x, y)

    fgraph = FunctionGraph([x, y], [o1, o2, o3], clone=False)

    # Mutable counter shared with the optimizer through the closure
    hits = [0]

    @local_optimizer([op1, MyNewOp])
    def local_opt_1(fgraph, node):
        hits[0] += 1
        return node.outputs

    # This is allowed by the `op1` instance in `tracks`
    local_opt_1.transform(fgraph, fgraph.outputs[0].owner)
    assert hits[0] == 1

    # This is allowed by the `MyNewOp` class in `tracks`
    local_opt_1.transform(fgraph, fgraph.outputs[1].owner)
    assert hits[0] == 2

    # `MyNewOp2` is neither `op1` nor `MyNewOp`, so `tracks` rejects it
    local_opt_1.transform(fgraph, fgraph.outputs[2].owner)
    assert hits[0] == 2
def test_random_stats(at_dist, dist_params, rng, size):
    """Compare summary statistics of JAX and Python samples drawn from ``at_dist``.

    Parameters
    ----------
    at_dist
        Random-variable constructor under test. It is called as
        ``at_dist(*dist_params, rng=rng, size=size)``. Earlier versions
        ignored this argument and always used ``normal``.
    dist_params
        Positional distribution parameters.
    rng
        Shared RNG variable used by the sampler.
    size
        Number of samples to draw.
    """
    # The RNG states are not 1:1, so the best we can do is check some summary
    # statistics of the samples
    out = at_dist(*dist_params, rng=rng, size=size)
    fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)

    def assert_fn(x, y):
        (x,) = x
        (y,) = y
        assert x.dtype.kind == y.dtype.kind

        # Compare means to fewer decimals when `floatX` is not float64
        d = 2 if config.floatX == "float64" else 1
        np.testing.assert_array_almost_equal(np.abs(x.mean()), np.abs(y.mean()), d)

    compare_jax_and_py(fgraph, [], assert_fn=assert_fn)
def test_fill(self):
    """`Second` broadcasts ``y`` into ``x`` for every linker/op/type combination."""
    combos = zip(
        self.linkers,
        [self.op, self.cop],
        [self.type, self.ctype],
        [self.rand_val, self.rand_cval],
    )
    for make_linker, elemwise_op, tensor_type, random_value in combos:
        x = tensor_type(aesara.config.floatX, (False, False))("x")
        y = tensor_type(aesara.config.floatX, (True, True))("y")
        # `{0: 0}` presumably maps output 0 onto input 0 (in place); the
        # assertion below checks that `xv` itself was overwritten.
        e = elemwise_op(aes.Second(aes.transfer_type(0)), {0: 0})(x, y)
        graph = FunctionGraph([x, y], [e])
        f = make_function(make_linker().accept(graph))

        xv = random_value((5, 5))
        yv = random_value((1, 1))
        f(xv, yv)
        assert (xv == yv).all()
def test_jax_Dimshuffle():
    """Compare the JAX and Python backends on transposes and `DimShuffle`s."""

    def check(inp, out, value):
        # Compare both backends on the single-input graph `inp -> out`
        graph = FunctionGraph([inp], [out])
        compare_jax_and_py(graph, [value.astype(config.floatX)])

    a_aet = matrix("a")
    check(a_aet, a_aet.T, np.c_[[1.0, 2.0], [3.0, 4.0]])
    check(a_aet, a_aet.dimshuffle([0, 1, "x"]), np.c_[[1.0, 2.0], [3.0, 4.0]])

    # Drop a broadcastable dimension, first via the method, then via the `Op`
    col_aet = tensor(dtype=config.floatX, broadcastable=[False, True])
    check(col_aet, col_aet.dimshuffle((0,)), np.c_[[1.0, 2.0, 3.0, 4.0]])

    col_aet = tensor(dtype=config.floatX, broadcastable=[False, True])
    check(
        col_aet,
        aet_elemwise.DimShuffle([False, True], (0,))(col_aet),
        np.c_[[1.0, 2.0, 3.0, 4.0]],
    )
def test_jax_FunctionGraph_names():
    """Check the parameter names of a JAX-converted graph.

    Inputs named ``"1x"``, ``"def"``, or with no name are expected to fall back
    to their ``auto_name``. The input named ``"_"`` keeps its name.
    """
    import inspect

    from aesara.link.jax.dispatch import jax_funcify

    x = scalar("1x")
    y = scalar("_")
    z = scalar()
    q = scalar("def")

    out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
    out_jx = jax_funcify(out_fg)

    param_names = tuple(inspect.signature(out_jx).parameters)
    expected_names = (x.auto_name, "_", z.auto_name, q.auto_name)
    assert expected_names == param_names

    # The graph is an identity on its inputs
    assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4)
def test_jax_Composite(x, y, x_val, y_val):
    """Compare the JAX and Python backends on an `Elemwise` of a scalar `Composite`."""
    x_s = aes.float64("x")
    y_s = aes.float64("y")
    scalar_expr = x_s + y_s * 2 + aes.exp(x_s - y_s)
    comp_op = Elemwise(Composite([x_s, y_s], [scalar_expr]))

    out_fg = FunctionGraph([x, y], [comp_op(x, y)])
    test_input_vals = [val.astype(config.floatX) for val in (x_val, y_val)]
    compare_jax_and_py(out_fg, test_input_vals)
def test_patternsub_invalid_dtype(out_pattern):
    """`PatternSub` must not replace a node with one whose output dtype differs."""
    # PatternSub would wrongly return output of different dtype as the original node
    x = MyVariable("x")
    e = op_cast_type2(x)
    fg = FunctionGraph([x], [e])

    opt = EquilibriumOptimizer(
        [
            PatternSub(
                (op_cast_type2, "x"),
                out_pattern,
            )
        ],
        max_use_ratio=1,
    )
    opt.optimize(fg)

    # Inspect the output's owner rather than `fg.apply_nodes.pop()`, which
    # would remove a node from the graph's own node set while checking it.
    assert fg.outputs[0].owner.op == op_cast_type2
def test_weird_strides(self):
    """Elementwise `add` is correct when one operand is a transposed view."""
    shape = (2, 2, 2, 2, 2)
    broadcastable = (False,) * len(shape)
    combos = zip(
        self.linkers,
        [self.op, self.cop],
        [self.type, self.ctype],
        [self.rand_val, self.rand_cval],
    )
    for make_linker, elemwise_op, tensor_type, random_value in combos:
        x = tensor_type(aesara.config.floatX, broadcastable)("x")
        y = tensor_type(aesara.config.floatX, broadcastable)("y")
        e = elemwise_op(aes.add)(x, y)
        f = make_function(make_linker().accept(FunctionGraph([x, y], [e])))

        xv = random_value(shape)
        # `transpose` permutes the strides, so `yv` is not laid out like `xv`
        yv = random_value(shape).transpose(4, 0, 3, 1, 2)
        expected = xv + yv
        assert (f(xv, yv) == expected).all()
def schedule(self, fgraph: FunctionGraph) -> typing.List[Apply]:
    """Return the evaluation order for the nodes of ``fgraph``.

    If ``self._scheduler`` is callable, it computes the order. Otherwise
    ``fgraph.toposort()`` is used.

    Parameters
    ----------
    fgraph : FunctionGraph
        A graph to compute the schedule for.

    Returns
    -------
    nodes : list of Apply nodes
        The result of the scheduling or toposort operation.
    """
    scheduler = self._scheduler
    if not callable(scheduler):
        return fgraph.toposort()
    return scheduler(fgraph)
def test_patternsub_different_output_lengths():
    """`PatternSub` won't replace nodes with ones that have a different number of outputs."""
    # Rewrite `op1(x)` to plain `x`
    ps = PatternSub((op1, "x"), "x", name="ps")
    opt = in2out(ps)

    x = MyVariable("x")
    # Only the first output of the multi-output node feeds `op1`
    e1, _ = op_multiple_outputs(x)
    o = op1(e1)

    fgraph = FunctionGraph(inputs=[x], outputs=[o])
    opt.optimize(fgraph)

    assert fgraph.outputs[0].owner.op == op1
def schedule(self, fgraph: FunctionGraph) -> List[Apply]:
    """Runs the scheduler (if set) or the toposort on the FunctionGraph.

    If ``self._scheduler`` is callable, it is called with ``fgraph``.
    Otherwise ``fgraph.toposort()`` is returned.

    Parameters
    ----------
    fgraph : :py:class:`aesara.graph.fg.FunctionGraph`
        A graph to compute the schedule for.

    Returns
    -------
    nodes : list of :py:class:`aesara.graph.basic.Apply` nodes
        The result of the scheduling or toposort operation.
    """
    if callable(self._scheduler):
        return self._scheduler(fgraph)
    return fgraph.toposort()
def test_multiple_merges(self):
    """Check the printed graph after `MergeOptimizer` when sub-expressions are shared."""
    x, y, z = inputs()
    e1 = op1(x, y)
    e2 = op2(op3(x), y, z)
    e = op1(e1, op4(e2, e1), op1(e2))
    g = FunctionGraph([x, y, z], [e])
    MergeOptimizer().optimize(g)

    # note: graph.as_string can only produce the following two possibilities, but if
    # the implementation was to change there are 6 other acceptable answers.
    acceptable = (
        "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))",
        "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))",
    )
    assert str(g) in acceptable
def test_1(self):
    """`EquilibriumOptimizer` keeps applying chained `PatternSub`s until nothing changes."""
    x, y, z = map(MyVariable, "xyz")
    e = op3(op4(x, y))
    g = FunctionGraph([x, y, z], [e])

    rewrites = [
        PatternSub((op1, "x", "y"), (op2, "x", "y")),
        PatternSub((op4, "x", "y"), (op1, "x", "y")),
        PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")),
    ]
    EquilibriumOptimizer(rewrites, max_use_ratio=10).optimize(g)

    assert str(g) == "FunctionGraph(Op2(x, y))"
def test_optimize_graph():
    """`optimize_graph` applies `custom_opt` to a bare variable or to a `FunctionGraph`."""
    x, y = vectors("xy")

    @optimizer
    def custom_opt(fgraph):
        # Replace `x` with `y` in the graph
        fgraph.replace(x, y, import_missing=True)

    assert optimize_graph(x, custom_opt=custom_opt) is y

    fgraph = FunctionGraph(outputs=[x], clone=False)
    fgraph_opt = optimize_graph(fgraph, custom_opt=custom_opt)
    assert fgraph_opt.outputs[0] is y
def test_KanrenRelationSub_filters():
    """`results_filter` receives the relation's results; a false `node_filter` blocks the rewrite."""
    x_at = at.vector("x")
    y_at = at.vector("y")
    z_at = at.vector("z")
    A_at = at.matrix("A")

    # Register commutativity/associativity facts used by `eq_assoccomm`
    fact(commutative, _dot)
    fact(commutative, at.add)
    fact(associative, at.add)

    Z_at = A_at.dot((x_at + y_at) + z_at)

    fgraph = FunctionGraph(outputs=[Z_at], clone=False)

    def distributes(in_lv, out_lv):
        A_lv, x_lv, y_lv, z_lv = vars(4)
        # lhs == A * (x + y + z)
        lhs = etuple(_dot, A_lv, etuple(at.add, x_lv, etuple(at.add, y_lv, z_lv)))
        return lall(
            eq_assoccomm(lhs, in_lv),
            # This relation does nothing but provide us with a means of
            # generating associative-commutative matches in the `kanren`
            # output.
            eq((A_lv, x_lv, y_lv, z_lv), out_lv),
        )

    def results_filter(results):
        evaluated = [eval_if_etuple(v) for v in results]

        # Make sure that at least a couple permutations are present
        expected_permutations = (
            (A_at, x_at, y_at, z_at),
            (A_at, y_at, x_at, z_at),
            (A_at, z_at, x_at, y_at),
        )
        for permutation in expected_permutations:
            assert permutation in evaluated
        return None

    KanrenRelationSub(distributes, results_filter=results_filter).transform(
        fgraph, fgraph.outputs[0].owner
    )

    blocked = KanrenRelationSub(distributes, node_filter=lambda x: False)
    res = blocked.transform(fgraph, fgraph.outputs[0].owner)
    assert res is False
def test_one_assert_merge(self):
    """Merging two `dot`s, only one of which has an `Assert`, keeps the `Assert`."""
    # Merge two nodes, one has assert, the other not.
    x1 = matrix("x1")
    x2 = matrix("x2")
    e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
    g = FunctionGraph([x1, x2], [e], clone=False)
    MergeOptimizer().optimize(g)

    out_node = g.outputs[0].owner
    assert out_node.op == add

    add_inputs = out_node.inputs
    lhs_dot = add_inputs[0]
    assert isinstance(lhs_dot.owner.op, Dot)

    # Confirm that the `Assert`s are correct
    expected_assert = assert_op(x1, (x1 > x2).all())
    assert equal_computations([lhs_dot.owner.inputs[0]], [expected_assert])

    # Confirm the merge
    assert lhs_dot is add_inputs[1]
def test_c_views(self):
    """The C linker's output for `x[None]` is a valid view of a broadcast input."""
    x_at = vector()
    graph = FunctionGraph([x_at], [x_at[None]])
    thunk, inputs, outputs = CLinker().accept(graph).make_thunk()

    # This is a little hackish, but we're hoping that--by running this more than
    # a few times--we're more likely to run into random memory that isn't the same
    # as the broadcasted value; that way, we'll be able to tell that we're getting
    # junk data from a poorly constructed array view.
    x_val = np.broadcast_to(2039, (5000,))
    in_storage = inputs[0].storage
    out_storage = outputs[0].storage

    for _ in range(1000):
        in_storage[0] = x_val
        thunk()

        # Make sure it's a view of the original data
        assert np.shares_memory(x_val, out_storage[0])
        # Confirm the broadcasted value in the output
        assert np.array_equiv(out_storage[0], 2039)
def test_pickle(self):
    """A `FunctionGraph` survives a pickle round trip with matching types throughout."""
    var1 = op1()
    var2 = op2()
    var3 = op1(var1)
    var4 = op2(var3, var2)
    func = FunctionGraph([var1, var2], [var4])

    new_func = pickle.loads(pickle.dumps(func))

    # Compare classes by identity (`is`) consistently, rather than with `==`.
    assert all(
        type(a) is type(b)  # noqa: E721
        for a, b in zip(func.inputs, new_func.inputs)
    )
    assert all(
        type(a) is type(b)  # noqa: E721
        for a, b in zip(func.outputs, new_func.outputs)
    )
    assert all(
        type(a.op) is type(b.op)  # noqa: E721
        for a, b in zip(func.apply_nodes, new_func.apply_nodes)
    )
    assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
def test_extra_ops():
    """Compare the JAX and Python backends on `extra_ops`, including unsupported ones."""
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    def compare(inputs, outputs, **kwargs):
        # Build the graph and compare both backends using the inputs' test values
        fgraph = FunctionGraph(inputs, outputs)
        compare_jax_and_py(
            fgraph, [get_test_value(i) for i in fgraph.inputs], **kwargs
        )

    compare([a], [at_extra_ops.cumsum(a, axis=0)])
    compare([a], [at_extra_ops.cumprod(a, axis=1)])
    compare([a], [at_extra_ops.diff(a, n=2, axis=1)])
    compare([a], [at_extra_ops.repeat(a, (3, 3), axis=1)])

    # Each of the following is expected to raise `NotImplementedError`
    # somewhere between building the op and comparing the backends.
    c = at.as_tensor(5)

    with pytest.raises(NotImplementedError):
        compare([a], [at_extra_ops.fill_diagonal(a, c)])

    with pytest.raises(NotImplementedError):
        compare([a], [at_extra_ops.fill_diagonal_offset(a, c, c)])

    with pytest.raises(NotImplementedError):
        compare([a], [at_extra_ops.Unique(axis=1)(a)])

    # `np.prod` replaces `np.product`, which is deprecated and removed in NumPy 2.0.
    indices = np.arange(np.prod((3, 4)))
    out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
    # `unravel_index` returns several outputs, so `out` is used as the output list
    compare([], out, must_be_device_array=False)
def test_KanrenRelationSub_dot():
    """Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached."""
    x_at = at.vector("x")
    c_at = at.vector("c")
    d_at = at.vector("d")
    A_at = at.matrix("A")
    B_at = at.matrix("B")

    Z_at = A_at.dot(x_at + B_at.dot(c_at + d_at))

    fgraph = FunctionGraph(outputs=[Z_at], clone=False)
    assert isinstance(fgraph.outputs[0].owner.op, Dot)

    def distributes(in_lv, out_lv):
        # lhs == A * (x + b)
        lhs = etuple(_dot, var("A"), etuple(at.add, var("x"), var("b")))
        # rhs == A * x + A * b
        rhs = etuple(
            at.add,
            etuple(_dot, var("A"), var("x")),
            etuple(_dot, var("A"), var("b")),
        )
        return lall(eq(lhs, in_lv), eq(rhs, out_lv))

    distribute_opt = EquilibriumOptimizer(
        [KanrenRelationSub(distributes)], max_use_ratio=10
    )
    fgraph_opt = optimize_graph(fgraph, custom_opt=distribute_opt)
    (expr_opt,) = fgraph_opt.outputs

    # Expected normal form: A * x + (A * B * c + A * B * d), distributed
    outer_inputs = expr_opt.owner.inputs
    assert expr_opt.owner.op == at.add
    assert isinstance(outer_inputs[0].owner.op, Dot)
    assert fgraph_opt.inputs[0] is A_at
    assert outer_inputs[0].owner.inputs[0].name == "A"

    inner_node = outer_inputs[1].owner
    assert inner_node.op == at.add
    assert isinstance(inner_node.inputs[0].owner.op, Dot)
    assert isinstance(inner_node.inputs[1].owner.op, Dot)
def test_fgraph_to_python_multiline_str():
    """Make sure that multiline `__str__` values are supported by `fgraph_to_python`."""
    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        def __init__(self):
            super().__init__()

        def make_node(self, *args):
            return Apply(self, list(args), [inp.type() for inp in args])

        def perform(self, node, inputs, outputs):
            # Identity op: forward each input unchanged to the matching output.
            # (`Op.perform` receives the `Apply` node as its first argument.)
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp

        def __str__(self):
            # Deliberately multi-line, to exercise comment generation
            return "Test\nOp()"

    @to_python.register(TestOp)
    def to_python_TestOp(op, **kwargs):
        def func(*args, op=op):
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_py = fgraph_to_python(out_fg, to_python)

    out_py_src = inspect.getsource(out_py)

    # Each line of the multi-line op name must be emitted as its own comment line
    assert (
        """
    # Elemwise{add,no_inplace}(Test
    # Op().0, Test
    # Op().1)
    """
        in out_py_src
    )
def test_1(self):
    """`EquilibriumOptimizer` keeps applying chained `PatternSub`s until nothing changes."""
    x, y, z = map(MyVariable, "xyz")
    # TODO FIXME: These `Op`s don't have matching/consistent `__prop__`s
    # and `__init__`s, so they can't be `etuplized` correctly
    e = op3(op4(x, y))
    g = FunctionGraph([x, y, z], [e])

    rewrites = [
        PatternSub((op1, "x", "y"), (op2, "x", "y")),
        PatternSub((op4, "x", "y"), (op1, "x", "y")),
        PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")),
    ]
    EquilibriumOptimizer(rewrites, max_use_ratio=10).optimize(g)

    assert str(g) == "FunctionGraph(Op2(x, y))"