def testCreateWorkspace(self):
    """Create a second workspace and check that blobs are isolated per workspace.

    Starts in the "default" workspace and runs a net that fills "testblob".
    Then checks that the blob is absent after switching to a newly created
    "test" workspace, and present again after switching back. Also checks
    that switching to an unknown workspace without the create flag raises
    RuntimeError, and that exactly two workspaces exist at the end.
    """
    # Only the default workspace is expected to exist when the test starts.
    workspaces = workspace.Workspaces()
    self.assertEqual(len(workspaces), 1)
    self.assertEqual(workspaces[0], "default")

    self.net = core.Net("test-net")
    self.net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0)
    self.assertEqual(
        workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
    self.assertEqual(workspace.HasBlob("testblob"), True)

    # "test" does not exist yet; passing True as the second argument lets
    # SwitchWorkspace create it. Blobs from "default" must not leak into it.
    self.assertEqual(workspace.SwitchWorkspace("test", True), True)
    self.assertEqual(workspace.HasBlob("testblob"), False)
    self.assertEqual(workspace.SwitchWorkspace("default"), True)
    self.assertEqual(workspace.HasBlob("testblob"), True)

    # Without the create flag, switching to an unknown workspace must fail.
    # assertRaises reports a clear message if no exception is raised.
    with self.assertRaises(RuntimeError):
        workspace.SwitchWorkspace("non-existing")

    # Exactly the two workspaces used above should now exist.
    # NOTE(review): the "test" workspace is left behind after this test;
    # confirm other tests do not depend on the workspace count.
    workspaces = workspace.Workspaces()
    self.assertEqual(len(workspaces), 2)
    self.assertEqual(sorted(workspaces), ["default", "test"])
def testRunPlan(self):
    """Run a single-step plan wrapping self.net and check that "testblob" appears.

    NOTE(review): self.net is not built in this method. It is presumably
    prepared elsewhere (e.g. setUp) with an operator producing "testblob";
    confirm.
    """
    test_plan = core.Plan("test-plan")
    test_plan.AddNets([self.net])
    # A single execution step that runs self.net.
    only_step = core.ExecutionStep("test-step", self.net)
    test_plan.AddStep(only_step)

    serialized_plan = test_plan.Proto().SerializeToString()
    plan_ok = workspace.RunPlan(serialized_plan)
    self.assertEqual(plan_ok, True)

    blob_present = workspace.HasBlob("testblob")
    self.assertEqual(blob_present, True)
def testRunOperatorOnce(self):
    """Run only the first operator of self.net and inspect the resulting blobs.

    NOTE(review): assumes self.net (set up outside this method) has a first
    operator that outputs "testblob", and that the workspace starts empty;
    confirm.
    """
    first_op_bytes = self.net.Proto().op[0].SerializeToString()
    op_ok = workspace.RunOperatorOnce(first_op_bytes)
    self.assertEqual(op_ok, True)
    self.assertEqual(workspace.HasBlob("testblob"), True)

    # The operator's output must be the one and only blob in the workspace.
    current_blobs = workspace.Blobs()
    self.assertEqual(len(current_blobs), 1)
    self.assertEqual(current_blobs[0], "testblob")
def testResetWorkspace(self):
    """Check that ResetWorkspace removes blobs created by a previous net run.

    NOTE(review): relies on self.net (set up outside this method) producing
    "testblob"; confirm.
    """
    net_bytes = self.net.Proto().SerializeToString()
    self.assertEqual(workspace.RunNetOnce(net_bytes), True)
    self.assertEqual(workspace.HasBlob("testblob"), True)

    # After a reset, the blob created above must be gone.
    reset_ok = workspace.ResetWorkspace()
    self.assertEqual(reset_ok, True)
    self.assertEqual(workspace.HasBlob("testblob"), False)
def testWorkspaceHasBlobWithNonexistingName(self):
    """HasBlob must report False for a blob name that was never created."""
    found = workspace.HasBlob("non-existing")
    self.assertEqual(found, False)