def testWhile(self):
    """Check forward values and gradients through a WhileNet with a nested IfNet.

    The loop condition increments ``i`` before testing ``i <= 2``. Starting
    from ``i = 0``, the body therefore runs for ``i = 1`` and ``i = 2``. Each
    pass squares ``x``. The nested IfNet squares ``y`` while ``i < 2`` (first
    pass only) and cubes ``z`` otherwise (second pass only). The final values
    are x = 2^4, y = 3^2, z = 2^3. The gradients of s = x + y + z are then
    checked with respect to the initial x, y and z.
    """
    with NetBuilder(_use_control_ops=True) as nb:
        ops.Copy(ops.Const(0), "i")
        ops.Copy(ops.Const(1), "one")
        ops.Copy(ops.Const(2), "two")
        ops.Copy(ops.Const(2.0), "x")
        ops.Copy(ops.Const(3.0), "y")
        ops.Copy(ops.Const(2.0), "z")
        # raises x to the power of 4 and y to the power of 2
        # and z to the power of 3
        with ops.WhileNet():
            with ops.Condition():
                # Increment first, then test, so the body sees i = 1 and i = 2.
                ops.Add(["i", "one"], "i")
                ops.LE(["i", "two"])
            ops.Pow("x", "x", exponent=2.0)
            with ops.IfNet(ops.LT(["i", "two"])):
                ops.Pow("y", "y", exponent=2.0)
            with ops.Else():
                ops.Pow("z", "z", exponent=3.0)

        # s is built from the values left by the loop.
        ops.Add(["x", "y"], "x_plus_y")
        ops.Add(["x_plus_y", "z"], "s")

    nets = nb.get()
    # A bare ``assert`` is stripped under ``python -O``. Use the unittest
    # assertion so that a multi-net result always fails the test.
    self.assertEqual(len(nets), 1, "Expected a single net produced")
    net = nets[0]

    net.AddGradientOperators(["s"])
    workspace.RunNetOnce(net)
    # (x^4)' = 4x^3 = 32 at x = 2
    self.assertAlmostEqual(workspace.FetchBlob("x_grad"), 32)
    self.assertAlmostEqual(workspace.FetchBlob("x"), 16)
    # (y^2)' = 2y = 6 at y = 3
    self.assertAlmostEqual(workspace.FetchBlob("y_grad"), 6)
    self.assertAlmostEqual(workspace.FetchBlob("y"), 9)
    # (z^3)' = 3z^2 = 12 at z = 2
    self.assertAlmostEqual(workspace.FetchBlob("z_grad"), 12)
    self.assertAlmostEqual(workspace.FetchBlob("z"), 8)
# Both 'If' and 'While' operators support backpropagation. To illustrate how
# backpropagation through control ops works, consider the following example:

# In[14]:

import numpy as np

# _use_control_ops=True forces NetBuilder to output a single net as a result.
# "x" is external to NetBuilder, so we let nb know about it through the
# initial_scope param.
FeedBlob("x", np.array(0.5, dtype='float32'))

with NetBuilder(_use_control_ops=True, initial_scope=["x"]) as nb:
    ops.Const(0.0, blob_out="zero")
    ops.Const(1.0, blob_out="one")
    ops.Const(4.0, blob_out="y")
    ops.Const(0.0, blob_out="z")
    # z = y^2 when x > 0, otherwise z = y^3.
    with ops.IfNet(ops.GT(["x", "zero"])):
        ops.Pow("y", "z", exponent=2.0)
    with ops.Else():
        ops.Pow("y", "z", exponent=3.0)

nets = nb.get()
assert len(nets) == 1, "Expected a single net produced"
net = nets[0]

# Keep the returned gradient map; later cells may use it to look up
# gradient blobs (assumption - not used in this cell).
grad_map = net.AddGradientOperators(["z"])

# Output blob "z", as a function of "y", depends on the value of blob "x":
# if "x" is greater than zero, then z = y^2; otherwise z = y^3.

# In[15]:

RunNetOnce(net)
print("x = ", FetchBlob("x"))
print("y = ", FetchBlob("y"))
print("z = ", FetchBlob("z"))