コード例 #1
0
    def test_while_net(self):
        """Run a NetBuilder while-loop through a Plan and check the final blobs.

        The condition clause increments the counter and lets the loop continue
        while the counter is below 7, so the body sees counter values 1..6 and
        the accumulator finishes at 1 + 2 + ... + 6 = 21.
        """
        with NetBuilder() as builder:
            counter = ops.Const(0)
            total = ops.Const(0)
            with ops.WhileNet():
                # Condition clause: increment first, then compare against 7.
                with ops.Condition():
                    ops.Add([counter, ops.Const(1)], [counter])
                    ops.LT([counter, ops.Const(7)])
                # Loop body: fold the current counter into the running total.
                ops.Add([counter, total], [total])

        loop_plan = Plan('while_net_test')
        loop_plan.AddStep(to_execution_step(builder))
        scratch_ws = workspace.C.Workspace()
        scratch_ws.run(loop_plan)

        # Blobs are looked up by name in the workspace that executed the plan.
        fetched_counter, fetched_total = (
            scratch_ws.blobs[str(blob)].fetch() for blob in (counter, total)
        )

        self.assertEqual(fetched_counter, 7)
        self.assertEqual(fetched_total, 21)
コード例 #2
0
    def testWhile(self):
        """Check gradients through a While loop that contains an If/Else branch.

        The loop body runs for i = 1 and i = 2. On both iterations "x" is
        squared (x -> x^4 overall); "y" is squared only while i < 2 (once),
        and "z" is cubed on the remaining iteration (once). The sum
        s = x + y + z is then differentiated with respect to the inputs.
        """
        with NetBuilder(_use_control_ops=True) as nb:
            ops.Copy(ops.Const(0), "i")
            ops.Copy(ops.Const(1), "one")
            ops.Copy(ops.Const(2), "two")
            ops.Copy(ops.Const(2.0), "x")
            ops.Copy(ops.Const(3.0), "y")
            ops.Copy(ops.Const(2.0), "z")
            # raises x to the power of 4 and y to the power of 2
            # and z to the power of 3
            with ops.WhileNet():
                # Condition clause: increment "i", then continue while i <= 2.
                with ops.Condition():
                    ops.Add(["i", "one"], "i")
                    ops.LE(["i", "two"])
                ops.Pow("x", "x", exponent=2.0)
                with ops.IfNet(ops.LT(["i", "two"])):
                    ops.Pow("y", "y", exponent=2.0)
                with ops.Else():
                    ops.Pow("z", "z", exponent=3.0)

            ops.Add(["x", "y"], "x_plus_y")
            ops.Add(["x_plus_y", "z"], "s")

        # Use a unittest assertion rather than a bare `assert`, which is
        # stripped under `python -O`; also fetch the net list only once.
        nets = nb.get()
        self.assertEqual(len(nets), 1, "Expected a single net produced")
        net = nets[0]

        net.AddGradientOperators(["s"])
        workspace.RunNetOnce(net)
        # (x^4)' = 4x^3
        self.assertAlmostEqual(workspace.FetchBlob("x_grad"), 32)
        self.assertAlmostEqual(workspace.FetchBlob("x"), 16)
        # (y^2)' = 2y
        self.assertAlmostEqual(workspace.FetchBlob("y_grad"), 6)
        self.assertAlmostEqual(workspace.FetchBlob("y"), 9)
        # (z^3)' = 3z^2
        self.assertAlmostEqual(workspace.FetchBlob("z_grad"), 12)
        self.assertAlmostEqual(workspace.FetchBlob("z"), 8)
コード例 #3
0
# Run the model's param_init_net once, then its main net once.
# NOTE(review): `model` is defined outside this excerpt.
RunNetOnce(model.param_init_net)
RunNetOnce(model.net)
# Fetch and print the resulting "x" and "y" blobs.
print("x = ", FetchBlob("x"))
print("y = ", FetchBlob("y"))

# ### Loops

# Another important control flow operator is 'While', which allows repeated execution of a fragment of a net. Let's consider NetBuilder's version of While first.

# In[9]:

# Build a while-loop with NetBuilder: "i" is the counter, "y" the accumulator.
with NetBuilder() as nb:
    ops.Const(0, blob_out="i")
    ops.Const(0, blob_out="y")
    with ops.WhileNet():
        # Condition clause: increment "i", then test i <= 7. The LE op is the
        # last operator in the clause, so its output is the loop condition.
        with ops.Condition():
            ops.Add(["i", ops.Const(1)], ["i"])
            ops.LE(["i", ops.Const(7)])
        # Loop body: y = i + y.
        ops.Add(["i", "y"], ["y"])

# This example illustrates the usage of a 'while' loop with NetBuilder. As with an 'If' operator, standard block semantic rules apply. Note the usage of the ops.Condition clause, which should immediately follow ops.WhileNet and contains code that is executed before each iteration. The last operator in the condition clause is expected to have a single boolean output that determines whether another iteration is executed.

# In the example above we increment the counter ("i") before each iteration and accumulate its values in the "y" blob. The loop's body is executed 7 times. The resulting blob values are:

# In[10]:

# Wrap the NetBuilder result `nb` in a single execution step of a new plan.
plan = Plan('while_net_test')
plan.AddStep(to_execution_step(nb))
# Create a Workspace object and run the plan in it.
ws = workspace.C.Workspace()
ws.run(plan)