def create_fgraph(inputs, outputs, validate=True):
    """Build a non-cloned ``FunctionGraph`` with standard features attached.

    The graph is created with ``clone=False`` and receives a
    ``DestroyHandler`` followed by a ``ReplaceValidate`` feature.

    Parameters
    ----------
    inputs
        Passed straight through as the graph inputs.
    outputs
        Passed straight through as the graph outputs.
    validate
        When truthy, ``validate()`` is called on the graph before returning.

    Returns
    -------
    The newly created ``FunctionGraph``.
    """
    fgraph = FunctionGraph(inputs, outputs, clone=False)
    # Instantiate each feature right before attaching it so that the
    # construct/attach sequence is DestroyHandler first, ReplaceValidate second.
    for feature_cls in (DestroyHandler, ReplaceValidate):
        fgraph.attach_feature(feature_cls())
    if not validate:
        return fgraph
    fgraph.validate()
    return fgraph
def test_verbose(self, capsys):
    """Check the messages printed by ``replace_all_validate(..., verbose=True)``.

    First a replacement is performed and the "rewrite ... replaces ..."
    message is expected on stdout with nothing on stderr.  Then a feature
    whose ``validate`` always raises is attached, the same replacement is
    attempted again, and the "validate failed" message is expected on stdout.
    """
    first = MyVariable("var1")
    second = MyVariable("var2")
    out = op1(second, first)

    fgraph = FunctionGraph([first, second], [out], clone=False)
    replacer = ReplaceValidate()
    fgraph.attach_feature(replacer)

    def replace_verbosely():
        # A fresh replacement list is built on every call, as in the
        # two separate call sites this helper stands in for.
        replacer.replace_all_validate(
            fgraph, [(out, first)], reason="test-reason", verbose=True
        )

    replace_verbosely()
    captured = capsys.readouterr()
    assert captured.err == ""
    assert "optimizer: rewrite test-reason replaces Op1.0 with var1" in captured.out

    class TestFeature(Feature):
        def validate(self, *args):
            # Unconditionally fail so the replacement cannot validate.
            raise Exception()

    fgraph.attach_feature(TestFeature())

    with pytest.raises(Exception):
        replace_verbosely()

    captured = capsys.readouterr()
    assert "optimizer: validate failed on node Op1.0" in captured.out
def get_jaxified_graph(
    inputs: Optional[List[TensorVariable]] = None,
    outputs: Optional[List[TensorVariable]] = None,
) -> List[TensorVariable]:
    """Compile an Aesara graph into an optimized JAX function.

    ``outputs`` is first passed through ``_replace_shared_variables``; the
    result, together with ``inputs``, is used to build a cloned
    ``FunctionGraph``.  That graph is optimized with ``mode.JAX.optimizer``
    and then handed to ``jax_funcify``, whose result is returned.
    """
    shared_free_outputs = _replace_shared_variables(outputs)
    fgraph = FunctionGraph(inputs=inputs, outputs=shared_free_outputs, clone=True)

    def _is_destroyed(variable):
        # Only consult ``has_destroyers`` when the graph actually exposes a
        # ``destroyers`` attribute.
        return hasattr(fgraph, "destroyers") and fgraph.has_destroyers([variable])

    # A Supervisor must be attached to the fgraph so that the JAX sequential
    # optimizer can run without warnings.  Mutable input variables have
    # already been ruled out, so checking for "destroyers" is all that is
    # left.  Aesara should take care of this automatically once
    # https://github.com/aesara-devs/aesara/issues/637 is fixed.
    protected_inputs = (
        variable for variable in fgraph.inputs if not _is_destroyed(variable)
    )
    fgraph.attach_feature(Supervisor(protected_inputs))

    mode.JAX.optimizer.optimize(fgraph)

    # Jaxify the graph now that it has been optimized.
    return jax_funcify(fgraph)
def test_no_merge(self):
    """Run ``MergeOptimizer`` on a graph with ``AssertNoChanges`` attached.

    The two branches fed to ``op1`` differ only in the argument order given
    to ``op2`` (``op2(x, y)`` versus ``op2(y, x)``).
    """
    x, y, z = inputs()

    left_branch = op3(op2(x, y))
    right_branch = op3(op2(y, x))
    out = op1(left_branch, right_branch)

    fgraph = FunctionGraph([x, y, z], [out])
    # Presumably this feature makes the test fail if the optimizer alters
    # the graph -- confirm against the AssertNoChanges definition.
    fgraph.attach_feature(AssertNoChanges())
    MergeOptimizer().optimize(fgraph)
def test_straightforward(self):
    """Check ``get_nodes`` results before and after replacing a variable.

    A small graph is built from locally defined ``Add``, ``Sigmoid`` and
    ``Dot`` ops, a ``NodeFinder`` feature is attached, and the number of
    nodes reported per op is verified (3 add, 3 sigmoid, 2 dot).  One ``dot``
    result is then replaced by an ``add`` result and the counts are verified
    again (4 add, 3 sigmoid, 1 dot).
    """

    class MyType(Type):
        def __init__(self, name):
            self.name = name

        def filter(self, *args, **kwargs):
            raise NotImplementedError()

        def __str__(self):
            return self.name

        def __repr__(self):
            return self.name

        def __eq__(self, other):
            # Any two MyType instances compare equal, whatever their names.
            return isinstance(other, MyType)

    class MyOp(Op):
        __props__ = ("nin", "name")

        def __init__(self, nin, name):
            self.nin = nin
            self.name = name

        def make_node(self, *inputs):
            assert len(inputs) == self.nin

            # First pass: every argument must already be a Variable.
            variables = []
            for candidate in inputs:
                assert isinstance(candidate, Variable)
                variables.append(candidate)

            # Second pass: every variable must carry a MyType type.
            if any(not isinstance(var.type, MyType) for var in variables):
                raise Exception("Error 1")

            result = MyType(self.name + "_R")()
            return Apply(self, variables, [result])

        def __str__(self):
            return self.name

        def perform(self, *args, **kwargs):
            raise NotImplementedError()

    sigmoid = MyOp(1, "Sigmoid")
    add = MyOp(2, "Add")
    dot = MyOp(2, "Dot")

    def MyVariable(name):
        return Variable(MyType(name), None, None)

    x, y, z = (MyVariable(label) for label in "xyz")

    e0 = dot(y, z)
    sig_x = sigmoid(x)
    sig_sig_z = sigmoid(sigmoid(z))
    sigmoid_sum = add(sig_x, sig_sig_z)
    dot_term = dot(add(x, y), e0)
    e = add(sigmoid_sum, dot_term)

    g = FunctionGraph([x, y, z], [e], clone=False)
    g.attach_feature(NodeFinder())
    assert hasattr(g, "get_nodes")

    def check_counts(expected):
        # ``expected`` is a sequence of (op, number of nodes) pairs.
        for op, count in expected:
            if len(list(g.get_nodes(op))) != count:
                raise Exception("Expected: %i times %s" % (count, op))

    check_counts(((add, 3), (sigmoid, 3), (dot, 2)))

    new_e0 = add(y, z)
    assert e0.owner in g.get_nodes(dot)
    assert new_e0.owner not in g.get_nodes(add)

    g.replace(e0, new_e0)

    assert e0.owner not in g.get_nodes(dot)
    assert new_e0.owner in g.get_nodes(add)

    check_counts(((add, 4), (sigmoid, 3), (dot, 1)))