def test_make_node_shared(self):
    """Check that `OpFromGraph.make_node` accepts new shared inputs.

    An `OpFromGraph` is built over ``x + y``, where ``y`` is a shared
    variable. The op is then re-applied through ``make_node`` with a clone
    of ``y`` in the shared-input slot. The resulting op must:

    * keep the ``on_unused_input`` keyword it was built with,
    * report the clone as its only shared input, and
    * compile to a function whose output follows the clone's value.
    """
    x = at.scalar("x")
    y = shared(1.0, name="y")

    ofg = OpFromGraph([x], [x + y], on_unused_input="ignore")
    assert ofg.shared_inputs == [y]

    applied = ofg(x)

    # The clone must be a different variable from the original shared `y`.
    y_clone = y.clone()
    assert y_clone != y
    y_clone.name = "y_clone"

    # Keep the first input of the original application and supply the
    # cloned shared variable after it.
    replacement_inputs = [*applied.owner.inputs[:1], y_clone]
    out_new = ofg.make_node(*replacement_inputs).outputs[0]
    new_op = out_new.owner.op

    assert "on_unused_input" in new_op.kwargs
    assert new_op.shared_inputs == [y_clone]

    compiled = function([x], out_new)
    assert np.array_equal(compiled(1.0), 2.0)

    # Changing the clone's stored value must change the compiled result.
    y_clone.set_value(2.0)
    assert np.array_equal(compiled(1.0), 3.0)
def test_shared_to_nonshared_input(self):
    """Check that a shared input of an `OpFromGraph` can be swapped for a non-shared one.

    The op initially wraps only the shared variable ``y``. Re-applying it
    through ``make_node`` with the symbolic scalar ``x`` must produce an op
    with no shared inputs. The compiled function must then take ``x`` as an
    explicit argument.
    """
    x = at.scalar("x")
    y = shared(1.0, name="y")

    ofg = OpFromGraph([], [y])
    assert ofg.shared_inputs == [y]

    # With the shared input in place, the op just yields the value of `y`.
    shared_fn = function([], ofg())
    assert np.array_equal(shared_fn(), 1.0)

    # Replace the shared variable with the explicit symbolic input `x`.
    replaced_node = ofg.make_node(x)
    assert replaced_node.op.shared_inputs == []

    replaced_fn = function([x], replaced_node.outputs[0])
    # Build the argument with floatX, the dtype `at.scalar` uses by default.
    x_value = np.array(1.0, dtype=config.floatX)
    assert np.array_equal(replaced_fn(x_value), 1.0)