def check(self, expected, debug=True):
    """
    Run `is_same_graph` over a table of cases and assert every outcome.

    :param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN)))
        with:
        - `v1` and `v2` two Variables (the graphs to be compared)
        - `gj` a `givens` dictionary to give as input to `is_same_graph`
        - `oj` the expected output of `is_same_graph(v1, v2, givens=gj)`

    :param debug: Forwarded to `is_same_graph` as its `debug` argument. If
        True, then we make sure we are testing both implementations of
        `is_same_graph`.

    Each case is evaluated twice: once as (v1, v2) and once with the two
    graphs swapped. Both calls must produce the same expected output.
    """
    for left, right, cases in expected:
        for givens, wanted in cases:
            # Original operand order first, then the swapped one.
            for first, second in ((left, right), (right, left)):
                outcome = is_same_graph(
                    first, second, givens=givens, debug=debug)
                assert outcome == wanted
def test_saved_inner_graph():
    """Check that a recurrent application keeps its original inner graph."""
    sequence = tensor.tensor3()
    rnn = SimpleRecurrent(dim=3, activation=Tanh())
    call = get_application_call(rnn.apply(sequence))

    # Both sides of the inner graph must have been recorded on the call.
    assert call.inner_inputs
    assert call.inner_outputs

    # The inner scan graph should carry the `rnn.apply` annotation.
    inner_graph = ComputationGraph(call.inner_outputs)
    select_applied = VariableFilter(applications=[rnn.apply])
    annotated = select_applied(inner_graph)
    assert len(annotated) == 3

    # The saved inner graph must match the graph obtained by calling
    # `rnn.apply` on its own with `iterate=False`.
    saved_output = call.inner_outputs[0]
    standalone_output = rnn.apply(*call.inner_inputs, iterate=False)
    assert is_same_graph(saved_output, standalone_output)