Beispiel #1
0
    def test_create_plan_from_proto_correctly(self):
        """Round-trip a task's Plan through its proto and compare the result.

        Builds a task named 'my_task' (two instances) on node 'trainer' with
        task-init, instance-init, loop, instance-exit and task-exit sections.
        Its step is wrapped in a ``core.Plan``, and a second plan is rebuilt
        with ``core.Plan.create_from_proto`` from the first plan's proto.
        The test checks that step counts, proto network/execution-step counts,
        the step name, and the net count match. It also checks that each
        rebuilt net's name starts with the corresponding original net's name.
        """
        from caffe2.python.net_builder import ops
        with Node('trainer'), Task(name='my_task', num_instances=2) as task:
            with ops.task_init():
                globl = ops.Const(0)
            with ops.task_instance_init():
                local = ops.Const(0)
            with ops.loop(100):
                ops.Copy(globl, local)
            with ops.task_instance_exit():
                ops.Add([globl, local], [globl])
            with ops.task_exit():
                ops.Mul([globl, globl], [globl])

        plan = core.Plan(task.get_step())
        test_plan = core.Plan.create_from_proto(plan.Proto())

        self.assertEqual(len(plan.Steps()), 1)
        self.assertEqual(len(test_plan.Steps()), 1)
        self.assertEqual(len(plan.Proto().network), 9)
        self.assertEqual(len(test_plan.Proto().network), 9)
        self.assertEqual(len(plan.Proto().execution_step), 1)
        self.assertEqual(len(test_plan.Proto().execution_step), 1)
        self.assertEqual(plan.Steps()[0].Name(), test_plan.Steps()[0].Name())
        self.assertEqual(len(plan.Nets()), len(test_plan.Nets()))
        # The net counts were just asserted equal, so zip pairs every net.
        for net_1, net_2 in zip(plan.Nets(), test_plan.Nets()):
            # When we create Net for test_plan, we end up with a new Net
            # name with a postfix, so only the original-length prefix of the
            # rebuilt name is compared.
            trim_size = len(net_1.Name())
            self.assertEqual(net_1.Name(), net_2.Name()[:trim_size])
Beispiel #2
0
    def testWhile(self):
        """Check WhileNet / IfNet / Else control ops and their gradients.

        ``i`` starts at 0. Each loop condition increments ``i`` and tests
        ``i <= 2``, so the body runs for i = 1 and i = 2. On each pass ``x``
        is squared. ``y`` is squared when ``i < 2`` (only the first pass),
        and ``z`` is cubed otherwise (only the second pass). With x=2, y=3
        and z=2 this gives x=16, y=9 and z=8. The gradients of
        ``s = x + y + z`` are compared with the analytic derivatives noted
        inline.
        """
        with NetBuilder(_use_control_ops=True) as nb:
            ops.Copy(ops.Const(0), "i")
            ops.Copy(ops.Const(1), "one")
            ops.Copy(ops.Const(2), "two")
            ops.Copy(ops.Const(2.0), "x")
            ops.Copy(ops.Const(3.0), "y")
            ops.Copy(ops.Const(2.0), "z")
            # raises x to the power of 4 and y to the power of 2
            # and z to the power of 3
            with ops.WhileNet():
                with ops.Condition():
                    ops.Add(["i", "one"], "i")
                    ops.LE(["i", "two"])
                ops.Pow("x", "x", exponent=2.0)
                with ops.IfNet(ops.LT(["i", "two"])):
                    ops.Pow("y", "y", exponent=2.0)
                with ops.Else():
                    ops.Pow("z", "z", exponent=3.0)

            ops.Add(["x", "y"], "x_plus_y")
            ops.Add(["x_plus_y", "z"], "s")

        # Use the TestCase assertion (not a bare assert) so the check still
        # runs under `python -O` and matches the assertions below.
        self.assertEqual(len(nb.get()), 1, "Expected a single net produced")
        net = nb.get()[0]

        net.AddGradientOperators(["s"])
        workspace.RunNetOnce(net)
        # (x^4)' = 4x^3
        self.assertAlmostEqual(workspace.FetchBlob("x_grad"), 32)
        self.assertAlmostEqual(workspace.FetchBlob("x"), 16)
        # (y^2)' = 2y
        self.assertAlmostEqual(workspace.FetchBlob("y_grad"), 6)
        self.assertAlmostEqual(workspace.FetchBlob("y"), 9)
        # (z^3)' = 3z^2
        self.assertAlmostEqual(workspace.FetchBlob("z_grad"), 12)
        self.assertAlmostEqual(workspace.FetchBlob("z"), 8)
Beispiel #3
0
# In[1]:

from caffe2.python import workspace
from caffe2.python.core import Plan, to_execution_step, Net
from caffe2.python.net_builder import ops, NetBuilder

# In[2]:

with NetBuilder() as nb:
    # Scalar blobs, created in this order: the two branch values
    # ("zero", "one"), the input "x" tested against zero, and the output
    # "y" that one of the branches overwrites.
    for value, blob_name in (
        (0.0, "zero"),
        (1.0, "one"),
        (0.5, "x"),
        (0.0, "y"),
    ):
        ops.Const(value, blob_out=blob_name)
    # 'then' branch: y <- one when x > zero; 'else' branch: y <- zero.
    with ops.IfNet(ops.GT(["x", "zero"])):
        ops.Copy("one", "y")
    with ops.Else():
        ops.Copy("zero", "y")

# Note the usage of NetBuilder's ops.IfNet and ops.Else calls. ops.IfNet accepts a blob reference or a blob name as input and expects that blob to hold a scalar value convertible to bool. The optional ops.Else is written at the same nesting level as ops.IfNet and immediately follows the corresponding ops.IfNet. Let's execute the resulting net (as an execution step) and check the values of the blobs.

# In[3]:

# Wrap everything built by `nb` into one execution step of a plan, then run
# the plan in a newly constructed workspace.
plan = Plan('if_net_test')
step = to_execution_step(nb)
plan.AddStep(step)
ws = workspace.C.Workspace()
ws.run(plan)
# Print the input and the branch result; since x (0.5) > 0, the 'then'
# branch is expected to have copied "one" into y.
for label, blob_name in (('x = ', "x"), ('y = ', "y")):
    print(label, ws.blobs[blob_name].fetch())

# Before going further, it's important to understand the semantics of execution blocks (the 'then' and 'else' branches in the example above), i.e. how reads from and writes to global blobs (defined outside the block) and local blobs (defined inside the block) are handled.