Example #1
def create_fgraph(inputs, outputs, validate=True):
    e = FunctionGraph(inputs, outputs, clone=False)
    e.attach_feature(DestroyHandler())
    e.attach_feature(ReplaceValidate())
    if validate:
        e.validate()
    return e
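A minimal usage sketch of this helper, assuming Aesara is installed and create_fgraph is in scope; the small graph built here is purely illustrative:

import aesara.tensor as at

# Build a tiny symbolic graph: a vector input and a scalar output.
x = at.vector("x")
y = at.exp(x).sum()

# Wrap it in a validated FunctionGraph with destroy/replace handling attached.
fg = create_fgraph([x], [y])
print(fg.toposort())  # the Apply nodes in dependency order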
Example #2
    def test_verbose(self, capsys):
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        fg = FunctionGraph([var1, var2], [var3], clone=False)

        rv_feature = ReplaceValidate()
        fg.attach_feature(rv_feature)
        rv_feature.replace_all_validate(fg, [(var3, var1)],
                                        reason="test-reason",
                                        verbose=True)

        capres = capsys.readouterr()
        assert capres.err == ""
        assert "optimizer: rewrite test-reason replaces Op1.0 with var1" in capres.out

        class TestFeature(Feature):
            def validate(self, *args):
                raise Exception()

        fg.attach_feature(TestFeature())

        with pytest.raises(Exception):
            rv_feature.replace_all_validate(fg, [(var3, var1)],
                                            reason="test-reason",
                                            verbose=True)

        capres = capsys.readouterr()
        assert "optimizer: validate failed on node Op1.0" in capres.out
Example #3
def get_jaxified_graph(
    inputs: Optional[List[TensorVariable]] = None,
    outputs: Optional[List[TensorVariable]] = None,
) -> Callable:
    """Compile an Aesara graph into an optimized JAX function"""

    graph = _replace_shared_variables(outputs)

    fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
    # We need to add a Supervisor to the fgraph to be able to run the
    # JAX sequential optimizer without warnings. We made sure there
    # are no mutable input variables, so we only need to check for
    # "destroyers". This should be automatically handled by Aesara
    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
    fgraph.attach_feature(
        Supervisor(
            input
            for input in fgraph.inputs
            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
        )
    )
    mode.JAX.optimizer.optimize(fgraph)

    # We now jaxify the optimized fgraph
    return jax_funcify(fgraph)
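A hedged usage sketch of the function above, assuming it and its module-level imports (mode, jax_funcify, _replace_shared_variables, Supervisor) are in scope and that the graph contains no shared variables; the tensors below are illustrative:

import numpy as np
import aesara.tensor as at

x = at.vector("x")
y = at.exp(x).sum()

# The result is a JAX-compatible callable that evaluates the optimized graph
# and returns one value per entry in `outputs`.
jax_fn = get_jaxified_graph(inputs=[x], outputs=[y])
print(jax_fn(np.arange(3.0)))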
Example #4
    def test_no_merge(self):
        x, y, z = inputs()
        e = op1(op3(op2(x, y)), op3(op2(y, x)))
        g = FunctionGraph([x, y, z], [e])
        g.attach_feature(AssertNoChanges())
        MergeOptimizer().optimize(g)
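The inputs, op1 through op3, MergeOptimizer, and AssertNoChanges names are module-level fixtures and imports not shown in this excerpt. A rough sketch of what a feature like AssertNoChanges could look like, assuming the Aesara Feature callback API (a reconstruction, not the original fixture):

from aesara.graph.features import Feature


class AssertNoChanges(Feature):
    # Fails the test if any node input is replaced after attachment.

    def on_change_input(self, fgraph, node, i, old_var, new_var, reason=None):
        raise AssertionError("the graph should not have been modified")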
Example #5
    def test_straightforward(self):
        class MyType(Type):
            def __init__(self, name):
                self.name = name

            def filter(self, *args, **kwargs):
                raise NotImplementedError()

            def __str__(self):
                return self.name

            def __repr__(self):
                return self.name

            def __eq__(self, other):
                return isinstance(other, MyType)

        class MyOp(Op):

            __props__ = ("nin", "name")

            def __init__(self, nin, name):
                self.nin = nin
                self.name = name

            def make_node(self, *inputs):
                def as_variable(x):
                    assert isinstance(x, Variable)
                    return x

                assert len(inputs) == self.nin
                inputs = list(map(as_variable, inputs))
                for input in inputs:
                    if not isinstance(input.type, MyType):
                        raise Exception("Error 1")
                outputs = [MyType(self.name + "_R")()]
                return Apply(self, inputs, outputs)

            def __str__(self):
                return self.name

            def perform(self, *args, **kwargs):
                raise NotImplementedError()

        sigmoid = MyOp(1, "Sigmoid")
        add = MyOp(2, "Add")
        dot = MyOp(2, "Dot")

        def MyVariable(name):
            return Variable(MyType(name), None, None)

        def inputs():
            x = MyVariable("x")
            y = MyVariable("y")
            z = MyVariable("z")
            return x, y, z

        x, y, z = inputs()
        e0 = dot(y, z)
        e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
        g = FunctionGraph([x, y, z], [e], clone=False)
        g.attach_feature(NodeFinder())

        assert hasattr(g, "get_nodes")
        for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
            if not len([t for t in g.get_nodes(type)]) == num:
                raise Exception("Expected: %i times %s" % (num, type))
        new_e0 = add(y, z)
        assert e0.owner in g.get_nodes(dot)
        assert new_e0.owner not in g.get_nodes(add)
        g.replace(e0, new_e0)
        assert e0.owner not in g.get_nodes(dot)
        assert new_e0.owner in g.get_nodes(add)
        for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
            if not len([t for t in g.get_nodes(type)]) == num:
                raise Exception("Expected: %i times %s" % (num, type))