def testCollection(self):
  """Checks that a section and its execution appear in the graph collections.

  The CriticalSection object is expected in the `CRITICAL_SECTIONS`
  collection, and the op behind a named `execute` call is expected among the
  ops of the signatures in `CRITICAL_SECTION_EXECUTIONS`.
  """
  section = critical_section_ops.CriticalSection(shared_name="cs")
  registered_sections = ops.get_collection(
      critical_section_ops.CRITICAL_SECTIONS)
  self.assertIn(section, registered_sections)

  def increment(x):
    return x + 1

  result = section.execute(lambda: increment(1.0), name="my_execute")

  # Pick the first op in the result's graph whose name mentions the execute
  # call and whose type mentions "MutexLock".
  matching_ops = [
      op for op in result.graph.get_operations()
      if "my_execute" in op.name and "MutexLock" in op.type
  ]
  lock_op = matching_ops[0]

  recorded_ops = [
      signature.op for signature in ops.get_collection(
          critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
  ]
  self.assertIn(lock_op, recorded_ops)
def testRecursiveCriticalSectionAccessIsIllegal(self):
  """Checks that re-entering a CriticalSection from inside itself raises.

  A function run through `cs.execute` that itself calls `cs.execute` on the
  same section is expected to fail with a `ValueError` whose message matches
  "Attempting to lock a CriticalSection in which we are".
  """
  # This does not work properly in eager mode.  Eager users will
  # just hit a deadlock if they do this.  But at least it'll be easier
  # to debug.
  cs = critical_section_ops.CriticalSection()
  add = lambda y: y + 1

  def fn(x):
    return cs.execute(lambda: add(x))

  # assertRaisesRegex is the supported spelling; the assertRaisesRegexp alias
  # is deprecated and was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(
      ValueError, r"Attempting to lock a CriticalSection in which we are"):
    cs.execute(lambda: fn(1.0))
def testCreateCriticalSection(self):
  """Runs 100 read-then-add executions through one CriticalSection.

  Each execution returns the variable's value read before adding 2.0, so the
  sorted results are expected to be 0.0, 2.0, 4.0, ... with no duplicates.
  """
  section = critical_section_ops.CriticalSection(shared_name="cs")
  counter = resource_variable_ops.ResourceVariable(0.0, name="v")
  num_concurrent = 100

  def read_then_add(a, b):
    before = counter.value()
    # Order the update after the read, and the returned identity after the
    # update, so the returned tensor carries the pre-update value.
    with ops.control_dependencies([before]):
      updated = counter.assign_add(a * b)
      with ops.control_dependencies([updated]):
        return array_ops.identity(before)

  results = []
  for _ in range(num_concurrent):
    results.append(section.execute(lambda: read_then_add(1.0, 2.0)))

  self.evaluate(counter.initializer)
  values = self.evaluate(results)
  expected = [2.0 * i for i in range(num_concurrent)]
  self.assertAllClose(expected, sorted(values))
def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
  """Executes a section inside a while_loop body that reads an outer variable.

  The loop starts at 0, runs while `i < 10`, and each step replaces `i` with
  `v + i + 1` computed inside the critical section; the result should be 10.
  """
  section = critical_section_ops.CriticalSection()
  var = resource_variable_ops.ResourceVariable(0, name="v")

  # The control dependencies on the variable must not trip up the automatic
  # control dependency adder of the lock op.
  #
  # The variable has to be a resource variable (or something similar);
  # otherwise it gets hoisted into the while_loop by the time control
  # dependencies are added to the lock op.
  def loop_body(i):

    def add_to_var():
      return var + i + 1

    return section.execute(add_to_var)

  def keep_going(i):
    return i < 10

  result = control_flow_ops.while_loop(keep_going, loop_body, [0])
  self.evaluate(var.initializer)
  self.assertEqual(10, self.evaluate(result))
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
  """Checks that capturing another execution's output does not deadlock.

  Every execution below adds 1.0 to a tensor that (directly or through an
  identity) comes from an earlier execution of the same section, so each one
  is expected to evaluate to 2.0.
  """
  # Subtle case, handled with deliberate over-caution.  The deadlock being
  # guarded against looks like:
  #
  #   to_capture = CS[lambda x: x + 1](1.0)
  #   deadlocked = CS[lambda x: x + to_capture](1.0)
  #
  # Running `deadlocked` would lock the mutex on CS and then, through its
  # dependencies, try to compute `to_capture`.  That computation also needs
  # the lock on CS, which `deadlocked` is already holding.
  #
  # CriticalSection.execute is expected to add control dependencies onto its
  # lock so that all captured operations finish before anything runs inside
  # the critical section.
  section = critical_section_ops.CriticalSection(shared_name="cs")
  captured = section.execute(lambda: array_ops.identity(1.0))

  def add_captured(x):
    return x + captured

  captured_copy = array_ops.identity(captured)

  ex_0 = section.execute(lambda: add_captured(1.0))

  with ops.control_dependencies([captured]):
    # Fine: the captured tensor is computed before this next call.
    ex_1 = section.execute(lambda: add_captured(1.0))

  dependency = array_ops.identity(captured)

  def add_dependency(x):
    return x + dependency

  ex_2 = section.execute(lambda: add_dependency(1.0))

  with ops.control_dependencies([captured_copy]):
    ex_3 = section.execute(lambda: add_dependency(1.0))

  # Evaluating each result confirms that none of them actually deadlocks.
  for execution in (ex_0, ex_1, ex_2, ex_3):
    self.assertEqual(2.0, self.evaluate(execution))
def testCreateCriticalSectionFnReturnsOp(self):
  """Runs 100 executions whose function returns an op instead of a tensor.

  Each execution adds 2.0 to the variable and returns a no_op, so after
  evaluating all of them the variable is expected to hold 200.0.
  """
  section = critical_section_ops.CriticalSection(shared_name="cs")
  var = resource_variable_ops.ResourceVariable(0.0, name="v")
  num_concurrent = 100

  def add_and_return_no_op(a, b):
    before = var.read_value()
    # Sequence read -> add -> no_op through nested control dependencies.
    with ops.control_dependencies([before]):
      updated = var.assign_add(a * b)
      with ops.control_dependencies([updated]):
        return control_flow_ops.no_op()

  executions = []
  for _ in range(num_concurrent):
    executions.append(section.execute(lambda: add_and_return_no_op(1.0, 2.0)))

  self.evaluate(var.initializer)
  self.evaluate(executions)
  self.assertAllClose(2.0 * num_concurrent, self.evaluate(var))
def testCriticalSectionWithControlFlow(self):
  """Combines critical-section execution with outer and inner `cond` ops.

  For every (outer_cond, inner_cond) pair, 100 outer conds either execute
  the section or just read the variable; inside the section an inner cond
  either adds 2.0 and returns the old value or returns the value unchanged.
  Only when both conditions are true should the results be 0, 2, 4, ...;
  otherwise every result should be 0.
  """
  num_concurrent = 100
  cond_pairs = [(outer, inner)
                for outer in (False, True)
                for inner in (False, True)]

  for outer_cond, inner_cond in cond_pairs:
    section = critical_section_ops.CriticalSection(shared_name="cs")
    var = resource_variable_ops.ResourceVariable(0.0, name="v")

    # pylint: disable=cell-var-from-loop
    def fn(a, b):
      before = var.read_value()

      def add_then_return_old():
        with ops.control_dependencies([before]):
          updated = var.assign_add(a * b)
          with ops.control_dependencies([updated]):
            return array_ops.identity(before)

      return control_flow_ops.cond(
          array_ops.identity(inner_cond), add_then_return_old, lambda: before)

    def run_in_section():
      return section.execute(lambda: fn(1.0, 2.0))

    results = []
    for _ in range(num_concurrent):
      results.append(
          control_flow_ops.cond(
              array_ops.identity(outer_cond), run_in_section, var.read_value))
    # pylint: enable=cell-var-from-loop

    self.evaluate(var.initializer)
    values = self.evaluate(results)
    if not (inner_cond and outer_cond):
      self.assertAllClose([0] * num_concurrent, values)
    else:
      self.assertAllClose(
          [2.0 * i for i in range(num_concurrent)], sorted(values))
def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
  """Checks that executing a section inside a parallel while_loop completes.

  Three loop bodies are exercised, each feeding the output of `cs.execute`
  back in as the next value of `j`.  Every loop runs while `i < 1000` with
  `parallel_iterations=25` and is expected to finish with `(i, j)` equal to
  `(1000, 1000)`.
  """
  cs = critical_section_ops.CriticalSection(shared_name="cs")
  test_name = "testRecursiveCriticalSectionAccessWithinLoopIsProtected"

  def run_loop_and_check(body):
    """Builds the loop around `body`, evaluates it, and logs before/after."""
    (i_n, j_n) = control_flow_ops.while_loop(
        lambda i, _: i < 1000, body, [0, 0], parallel_iterations=25)
    # For consistency between eager and graph mode.
    i_n = array_ops.identity(i_n)
    # The banners bracket the evaluation so a hang can be attributed to a
    # specific loop body from the log output.
    logging.warning(
        "\n==============\nRunning "
        "'%s "
        "%s'\n"
        "==============\n", test_name, body.__name__)
    self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
    logging.warning(
        "\n==============\nSuccessfully finished running "
        "'%s "
        "%s'\n"
        "==============\n", test_name, body.__name__)

  def body_implicit_capture(i, j):
    # This would have caused a deadlock if not for logic in execute
    # that inserts additional control dependencies onto the lock op:
    #   * Loop body argument j is captured by fn()
    #   * i is running in parallel to move forward the execution
    #   * j is not being checked by the predicate function
    #   * output of cs.execute() is returned as next j.
    fn = lambda: j + 1
    return (i + 1, cs.execute(fn))

  run_loop_and_check(body_implicit_capture)

  def body_implicit_capture_protected(i, j):
    # This version is ok because we manually add a control
    # dependency on j, which is an argument to the while_loop body
    # and captured by fn.
    fn = lambda: j + 1
    with ops.control_dependencies([j]):
      return (i + 1, cs.execute(fn))

  run_loop_and_check(body_implicit_capture_protected)

  def body_args_capture(i, j):
    # This version is ok because j is an argument to fn and we can
    # ensure there's a control dependency on j.
    fn = lambda x: x + 1
    return (i + 1, cs.execute(lambda: fn(j)))

  run_loop_and_check(body_args_capture)
def testCreateCriticalSectionExecuteConstant(self):
  """Smoke test: builds a CriticalSection and executes a constant function.

  This method was previously named `testCreateCriticalSection`, the same name
  as an earlier test in this file.  Within one class the later definition
  replaces the earlier one, so the earlier test would never run; the distinct
  name lets both be collected.
  """
  cs = critical_section_ops.CriticalSection(shared_name="cs")
  cs.execute(lambda: 1.0)