예제 #1
0
    def check(self, expected, debug=True):
        """
        Run `is_same_graph` on every listed case and assert its result.

        :param expected: An iterable of triples
        (v1, v2, ((g1, o1), ..., (gN, oN))) where:
            - `v1` and `v2` are the two Variables (graphs) being compared
            - `gj` is a `givens` dictionary passed to `is_same_graph`
            - `oj` is the result `is_same_graph(v1, v2, givens=gj)` must
              produce

        :param debug: Forwarded to `is_same_graph`. When True, both
        implementations of `is_same_graph` are exercised.

        Every case is evaluated twice, once as (v1, v2) and once with the
        arguments swapped, and both results must equal `oj`.
        """
        for lhs, rhs, cases in expected:
            for givens, wanted in cases:
                # Original order first, then swapped: the answer must not
                # depend on which graph is passed first.
                for first, second in ((lhs, rhs), (rhs, lhs)):
                    same = is_same_graph(first, second, givens=givens,
                                         debug=debug)
                    assert same == wanted
예제 #2
0
    def check(self, expected, debug=True):
        """
        Run `is_same_graph` on every listed case and assert its result.

        :param expected: An iterable of triples
        (v1, v2, ((g1, o1), ..., (gN, oN))) where:
            - `v1` and `v2` are the two Variables (graphs) being compared
            - `gj` is a `givens` dictionary passed to `is_same_graph`
            - `oj` is the result `is_same_graph(v1, v2, givens=gj)` must
              produce

        :param debug: Forwarded to `is_same_graph`. When True, both
        implementations of `is_same_graph` are exercised.

        Every case is evaluated twice, once as (v1, v2) and once with the
        arguments swapped, and both results must equal `oj`.
        """
        for lhs, rhs, cases in expected:
            for givens, wanted in cases:
                # Original order first, then swapped: the answer must not
                # depend on which graph is passed first.
                for first, second in ((lhs, rhs), (rhs, lhs)):
                    same = is_same_graph(first, second, givens=givens,
                                         debug=debug)
                    assert same == wanted
예제 #3
0
def test_saved_inner_graph():
    """Check that applying a recurrent brick keeps its inner graph.

    Applies a `SimpleRecurrent` brick to a 3D tensor variable, then
    verifies that the resulting application call exposes non-empty
    `inner_inputs` and `inner_outputs`, that the inner graph carries the
    `recurrent.apply` annotation, and that its first output matches a
    graph built by a direct, non-iterating call to `recurrent.apply`.
    """
    inputs = tensor.tensor3()
    recurrent = SimpleRecurrent(dim=3, activation=Tanh())
    outputs = recurrent.apply(inputs)

    application_call = get_application_call(outputs)
    # Both ends of the inner graph must have been recorded.
    assert application_call.inner_inputs
    assert application_call.inner_outputs

    cg = ComputationGraph(application_call.inner_outputs)
    # The inner scan graph must be annotated with `recurrent.apply`.
    annotated = VariableFilter(applications=[recurrent.apply])(cg)
    assert len(annotated) == 3

    # The saved inner graph must be equivalent to what a stand-alone,
    # single-step call of `recurrent.apply` produces.
    saved_output = application_call.inner_outputs[0]
    standalone_output = recurrent.apply(*application_call.inner_inputs,
                                        iterate=False)
    assert is_same_graph(saved_output, standalone_output)