Exemplo n.º 1
0
    def testWhile(self):
        """Check backpropagation through a WhileNet that contains an IfNet/Else.

        Starting from i = 0, the loop condition increments i and then tests
        i <= 2, so the loop body runs for i = 1 and i = 2:

        * "x" is squared on every iteration: x = 2 -> 2^4 = 16.
        * "y" is squared only while i < 2 (first iteration): y = 3 -> 3^2 = 9.
        * "z" is cubed in the Else branch (second iteration): z = 2 -> 2^3 = 8.

        The gradients of s = x + y + z are then compared with the analytic
        derivatives evaluated at the initial values of x, y and z.
        """
        with NetBuilder(_use_control_ops=True) as nb:
            ops.Copy(ops.Const(0), "i")
            ops.Copy(ops.Const(1), "one")
            ops.Copy(ops.Const(2), "two")
            ops.Copy(ops.Const(2.0), "x")
            ops.Copy(ops.Const(3.0), "y")
            ops.Copy(ops.Const(2.0), "z")
            # raises x to the power of 4, y to the power of 2
            # and z to the power of 3
            with ops.WhileNet():
                with ops.Condition():
                    # Increment i, then compare i <= 2; the loop keeps
                    # running while this comparison holds.
                    ops.Add(["i", "one"], "i")
                    ops.LE(["i", "two"])
                ops.Pow("x", "x", exponent=2.0)
                with ops.IfNet(ops.LT(["i", "two"])):
                    ops.Pow("y", "y", exponent=2.0)
                with ops.Else():
                    ops.Pow("z", "z", exponent=3.0)

            ops.Add(["x", "y"], "x_plus_y")
            ops.Add(["x_plus_y", "z"], "s")

        nets = nb.get()
        # Use the TestCase assertion rather than a bare ``assert``: bare
        # asserts are removed under ``python -O``, which would silently skip
        # this check.
        self.assertEqual(len(nets), 1, "Expected a single net produced")
        net = nets[0]

        net.AddGradientOperators(["s"])
        workspace.RunNetOnce(net)
        # (x^4)' = 4x^3 = 32 at x = 2
        self.assertAlmostEqual(workspace.FetchBlob("x_grad"), 32)
        self.assertAlmostEqual(workspace.FetchBlob("x"), 16)
        # (y^2)' = 2y = 6 at y = 3
        self.assertAlmostEqual(workspace.FetchBlob("y_grad"), 6)
        self.assertAlmostEqual(workspace.FetchBlob("y"), 9)
        # (z^3)' = 3z^2 = 12 at z = 2
        self.assertAlmostEqual(workspace.FetchBlob("z_grad"), 12)
        self.assertAlmostEqual(workspace.FetchBlob("z"), 8)
Exemplo n.º 2
0
# Both 'If' and 'While' operators support backpropagation. To illustrate how backpropagation works with control ops, let's consider the following example:

# In[14]:

import numpy as np
# _use_control_ops=True forces NetBuilder to output a single net as a result.
# "x" is external to NetBuilder, so nb is told about it through the
# initial_scope param.
FeedBlob("x", np.array(0.5, dtype='float32'))
with NetBuilder(_use_control_ops=True, initial_scope=["x"]) as nb:
    # Constants read by the ops below. "z" is created here, before the
    # If/Else, and is then written by whichever branch runs.
    # NOTE(review): "one" is not referenced anywhere in this cell.
    ops.Const(0.0, blob_out="zero")
    ops.Const(1.0, blob_out="one")
    ops.Const(4.0, blob_out="y")
    ops.Const(0.0, blob_out="z")
    # If x > 0: z = y^2; otherwise: z = y^3.
    with ops.IfNet(ops.GT(["x", "zero"])):
        ops.Pow("y", "z", exponent=2.0)
    with ops.Else():
        ops.Pow("y", "z", exponent=3.0)

# With _use_control_ops=True exactly one net is expected.
assert len(nb.get()) == 1, "Expected a single net produced"
net = nb.get()[0]
# Append backward ops for "z". grad_map is not used in this cell; it is
# presumably a mapping from blob names to their gradient blobs -- confirm
# against the Caffe2 Net.AddGradientOperators documentation.
grad_map = net.AddGradientOperators(["z"])

# Output blob "z", as a function of "y", depends on the value of blob "x": if "x" is greater than zero, then "z = y^2"; otherwise, "z = y^3".

# In[15]:

# Execute the net (forward ops plus the gradient ops appended by
# AddGradientOperators), then show each blob's value in x, y, z order.
RunNetOnce(net)
for label, blob_name in (("x = ", "x"), ("y = ", "y"), ("z = ", "z")):
    print(label, FetchBlob(blob_name))