Example #1
0
    def test_make_node_shared(self):
        """Check that `OpFromGraph.make_node` accepts replacement shared inputs.

        An `OpFromGraph` that closes over the shared variable ``y`` is applied
        once, then its node is rebuilt with a clone of ``y`` standing in for
        the original. The rebuilt `Op` must keep its constructor kwargs, list
        the clone as its only shared input, and pick up the clone's value when
        the compiled function runs.
        """
        x = at.scalar("x")
        y = shared(1.0, name="y")

        # `y` is not passed explicitly, so it is collected as a shared input.
        ofg = OpFromGraph([x], [x + y], on_unused_input="ignore")
        assert ofg.shared_inputs == [y]

        applied = ofg(x)

        replacement = y.clone()
        assert replacement != y
        replacement.name = "y_clone"

        # Keep the explicit input from the first application and substitute
        # the clone for the shared one.
        explicit_inputs = applied.owner.inputs[:1]
        rebuilt = ofg.make_node(*explicit_inputs, replacement).outputs[0]
        rebuilt_op = rebuilt.owner.op

        assert "on_unused_input" in rebuilt_op.kwargs
        assert rebuilt_op.shared_inputs == [replacement]

        compiled = function([x], rebuilt)
        assert np.array_equal(compiled(1.0), 2.0)

        # Changing the clone's storage must be visible to the compiled graph.
        replacement.set_value(2.0)
        assert np.array_equal(compiled(1.0), 3.0)
Example #2
0
    def test_shared_to_nonshared_input(self):
        """Check that a shared input of `OpFromGraph` can be replaced by a regular one.

        The `OpFromGraph` has no explicit inputs and only outputs the shared
        variable ``y``. Calling ``make_node`` with the symbolic scalar ``x`` in
        place of ``y`` must produce an `Op` with no shared inputs whose output
        follows ``x``.
        """
        x = at.scalar("x")
        y = shared(1.0, name="y")

        # With no explicit inputs, the captured `y` is the only (shared) input.
        ofg = OpFromGraph([], [y])
        assert ofg.shared_inputs == [y]

        from_shared = function([], ofg())
        assert np.array_equal(from_shared(), 1.0)

        # `make_node` returns an Apply node; its Op should no longer track `y`.
        node = ofg.make_node(x)
        assert node.op.shared_inputs == []

        from_input = function([x], node.outputs[0])
        assert np.array_equal(from_input(np.array(1.0, dtype=config.floatX)), 1.0)