def testMultipleCSExecutionsRequestSameResource(self):
  """Checks when two CriticalSections may touch the same resource variable.

  A second CriticalSection is expected to be rejected with a ValueError
  whenever either side asked for exclusive resource access; sharing is only
  accepted when both executions pass `exclusive_resource_access=False`.
  """
  conflict_regex = "requested exclusive resource access"
  owner_cs = critical_section_ops.CriticalSection()
  other_cs = critical_section_ops.CriticalSection()

  guarded_var = resource_variable_ops.ResourceVariable(0.0, name="v")
  owner_cs.execute(lambda: guarded_var + 1)
  # The same CriticalSection may keep using the resource it already claimed.
  owner_cs.execute(lambda: guarded_var - 1)

  # A different CriticalSection is refused by default ...
  with self.assertRaisesRegexp(ValueError, conflict_regex):
    other_cs.execute(lambda: guarded_var + 1)

  # ... and is still refused when it waives exclusive access itself.
  with self.assertRaisesRegexp(ValueError, conflict_regex):
    other_cs.execute(
        lambda: guarded_var + 1, exclusive_resource_access=False)

  shared_var = resource_variable_ops.ResourceVariable(0.0, name="v2")
  owner_cs.execute(lambda: shared_var + 1, exclusive_resource_access=False)
  # Two non-exclusive requests from different CriticalSections can coexist.
  other_cs.execute(lambda: shared_var + 1, exclusive_resource_access=False)

  # A later request that requires exclusive resource access is rejected.
  with self.assertRaisesRegexp(ValueError, conflict_regex):
    other_cs.execute(lambda: shared_var + 1)
def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
  """Expects a ValueError when fn re-enters via a twin with the same name.

  Two CriticalSection objects are built with the same `shared_name`; running
  a function inside one that executes on the other must be rejected.
  """
  # This does not work properly in eager mode: eager users will simply hit a
  # deadlock if they do this.  But at least it will be easier to debug.
  outer_cs = critical_section_ops.CriticalSection(shared_name="cs")
  twin_cs = critical_section_ops.CriticalSection(shared_name="cs")

  def fn(x):
    # Re-enters the critical section through the second Python object.
    return twin_cs.execute(lambda x: x + 1, x)

  expected_error = (
      r"attempts to directly access the CriticalSection in which it "
      r"would be running")
  with self.assertRaisesRegexp(ValueError, expected_error):
    outer_cs.execute(fn, 1.0)
def testInsideFunction(self):
  """Runs cs.execute inside a Dataset map function and reads a variable.

  The test is skipped when a GPU is available (see the bug id in the skip
  message); the GPU branches below are kept for when that skip is lifted.
  """
  # NOTE(review): another `testInsideFunction` appears later in this chunk.
  # If both live in the same class only the later definition is collected --
  # confirm against the full file.
  if test_util.is_gpu_available():
    self.skipTest(
        "b/123899495: Colocation errors for critical sections in map on GPU")

  section = critical_section_ops.CriticalSection()

  if test_util.is_gpu_available():
    variable_device = "/gpu:0"
  else:
    variable_device = "/cpu:0"
  with ops.device(variable_device):
    var = resource_variable_ops.ResourceVariable(1)

  def fn():
    return var.read_value()

  # map() creates a TensorFlow function.
  dataset = dataset_ops.Dataset.range(1)
  if test_util.is_gpu_available():
    dataset = dataset.apply(prefetching_ops.copy_to_device("/gpu:0"))
    dataset = dataset.apply(
        prefetching_ops.map_on_gpu(lambda _: section.execute(fn)))
  else:
    dataset = dataset.map(lambda _: section.execute(fn))

  def get_first():
    # Graph mode needs explicit initialization of the variable and iterator.
    if not context.executing_eagerly():
      iterator = dataset.make_initializable_iterator()
      self.evaluate([var.initializer, iterator.initializer])
      return self.evaluate(iterator.get_next())
    return self.evaluate(dataset.make_one_shot_iterator().get_next())

  self.assertEqual(1, get_first())
def testCreateCriticalSectionRaw(self):
  """Drives `execute_in_critical_section` directly with a Defun.

  Each execution returns the variable's value before adding `a * b`, so 1000
  executions are expected to yield 0, 2, 4, ... once sorted.
  """
  cs = critical_section_ops.CriticalSection(name="cs")
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  @function.Defun(dtypes.float32, dtypes.float32)
  def fn(a, b):
    before = v.read_value()
    # Order the read before the update, and the update before the return.
    with ops.control_dependencies([before]):
      updated = v.assign_add(a * b)
    with ops.control_dependencies([updated]):
      return array_ops.identity(before)

  def execute(fn, *args):
    # Output dtypes come from the function signature; shapes are left unknown.
    signature_outputs = fn.definition.signature.output_arg
    out_types = [arg.type for arg in signature_outputs]
    out_shapes = [tensor_shape.TensorShape(None) for _ in signature_outputs]
    return resource_variable_ops.execute_in_critical_section(
        critical_section=cs._handle,
        arguments=list(args) + fn.captured_inputs,
        f=fn,
        output_types=out_types,
        output_shapes=out_shapes)

  num_concurrent = 1000
  pending = []
  for _ in range(num_concurrent):
    pending.append(execute(fn, 1.0, 2.0)[0])

  self.evaluate(v.initializer)
  observed = self.evaluate(pending)
  expected = [2.0 * step for step in range(num_concurrent)]
  self.assertAllClose(expected, sorted(observed))
def testCollection(self):
  """Checks the section and its execute op are recorded in graph collections."""
  # NOTE(review): another `testCollection` appears later in this chunk.  If
  # both live in the same class only the later definition is collected --
  # confirm against the full file.
  section = critical_section_ops.CriticalSection(name="cs")
  registered_sections = ops.get_collection(
      critical_section_ops.CRITICAL_SECTIONS)
  self.assertIn(section, registered_sections)

  execute_op = section.execute(lambda x: x + 1, 1.0).op
  registered_ops = [
      signature.op for signature in ops.get_collection(
          critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
  ]
  self.assertIn(execute_op, registered_ops)
def testRecursiveCriticalSectionAccessIsIllegal(self):
  """Expects a ValueError when fn executes on the section it runs inside."""
  section = critical_section_ops.CriticalSection(name="cs")

  def fn(x):
    # Re-enters the very same CriticalSection object.
    return section.execute(lambda x: x + 1, x)

  expected_error = (
      r"attempts to access the CriticalSection in which it would be running")
  with self.assertRaisesRegexp(ValueError, expected_error):
    section.execute(fn, 1.0)
def testCollection(self):
  """Checks the section and its lock op are recorded in graph collections.

  The execute op is located by scanning the graph for an op whose name
  contains "my_execute" and whose type contains "MutexLock".
  """
  section = critical_section_ops.CriticalSection(shared_name="cs")
  registered_sections = ops.get_collection(
      critical_section_ops.CRITICAL_SECTIONS)
  self.assertIn(section, registered_sections)

  result = section.execute(lambda x: x + 1, 1.0, name="my_execute")
  lock_ops = [
      op for op in result.graph.get_operations()
      if "my_execute" in op.name and "MutexLock" in op.type
  ]
  execute_op = lock_ops[0]

  registered_ops = [
      signature.op for signature in ops.get_collection(
          critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
  ]
  self.assertIn(execute_op, registered_ops)
def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
  """Executes a section inside a while_loop body that uses an outer variable."""
  section = critical_section_ops.CriticalSection()
  var = resource_variable_ops.ResourceVariable(0, name="v")

  # The control dependencies on the variable must not trip up the lock_op's
  # automatic control dependency adder.
  #
  # Note: the variable has to be a resource variable (or something similar);
  # otherwise it gets hoisted into the while_loop by the time control
  # dependencies are added to the lock_op.
  def keep_going(i):
    return i < 10

  def step(i):
    return section.execute(lambda j: var + j + 1, i)

  final_value = control_flow_ops.while_loop(keep_going, step, [0])

  self.evaluate(var.initializer)
  self.assertEqual(10, self.evaluate(final_value))
def testCriticalSectionInParallelDoesntDeadlockOnError(self):
  """Repeatedly runs two executions, one failing its Assert, expecting an error.

  The pair is evaluated 100 times and each evaluation must raise an op error
  containing "Error".
  """
  # Not run in eager mode: eager does not run fn() in parallel, and parallel
  # execution is where the deadlock could potentially occur (in graph mode).
  section = critical_section_ops.CriticalSection(shared_name="cs")
  var = resource_variable_ops.ResourceVariable(0.0, name="v")

  def fn(i):
    # The assertion only holds for odd `i`.
    must_be_odd = control_flow_ops.Assert((i % 2) == 1, ["Error"])
    with ops.control_dependencies([must_be_odd]):
      return var.read_value()

  num_concurrent = 2
  executions = []
  for index in range(num_concurrent):
    executions.append(section.execute(fn, index))

  self.evaluate(var.initializer)
  attempts = 100
  for _ in range(attempts):
    with self.assertRaisesOpError("Error"):
      self.evaluate(executions)
def testCreateCriticalSection(self):
  """Runs 100 read-then-add executions and checks the values that were read.

  Each execution returns the variable's value before adding `a * b`, so the
  sorted results are expected to be 0, 2, 4, ...
  """
  section = critical_section_ops.CriticalSection(shared_name="cs")
  var = resource_variable_ops.ResourceVariable(0.0, name="v")

  def fn(a, b):
    before = var.value()
    # Order the read before the update, and the update before the return.
    with ops.control_dependencies([before]):
      updated = var.assign_add(a * b)
    with ops.control_dependencies([updated]):
      return array_ops.identity(before)

  num_concurrent = 100
  executions = []
  for _ in range(num_concurrent):
    executions.append(section.execute(fn, 1.0, 2.0))

  self.evaluate(var.initializer)
  observed = self.evaluate(executions)
  expected = [2.0 * step for step in range(num_concurrent)]
  self.assertAllClose(expected, sorted(observed))
def testCreateCriticalSectionFnReturnsOp(self):
  """Runs 100 executions whose fn returns a no_op and checks the final sum."""
  section = critical_section_ops.CriticalSection(shared_name="cs")
  var = resource_variable_ops.ResourceVariable(0.0, name="v")

  def fn_return_op(a, b):
    before = var.read_value()
    # Order the read before the update, and the update before the no_op.
    with ops.control_dependencies([before]):
      updated = var.assign_add(a * b)
    with ops.control_dependencies([updated]):
      return control_flow_ops.no_op()

  num_concurrent = 100
  executions = []
  for _ in range(num_concurrent):
    executions.append(section.execute(fn_return_op, 1.0, 2.0))

  self.evaluate(var.initializer)
  self.evaluate(executions)
  # Every execution added 1.0 * 2.0 to the variable.
  self.assertAllClose(2.0 * num_concurrent, self.evaluate(var))
def testInsideFunction(self):
  """Runs cs.execute inside a Dataset map function and reads a variable."""
  section = critical_section_ops.CriticalSection()
  var = resource_variable_ops.ResourceVariable(1)

  def fn():
    return var.read_value()

  # map() creates a TensorFlow function.
  source = dataset_ops.Dataset.range(1)
  dataset = source.map(lambda _: section.execute(fn))

  def get_first():
    # Graph mode needs explicit initialization of the variable and iterator.
    if not context.executing_eagerly():
      iterator = dataset.make_initializable_iterator()
      self.evaluate([var.initializer, iterator.initializer])
      return self.evaluate(iterator.get_next())
    return self.evaluate(dataset.make_one_shot_iterator().get_next())

  self.assertEqual(1, get_first())
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
  """Checks that capturing another execution's output does not deadlock.

  Four executions capture, directly or through an identity op, the output of
  an earlier execution on the same CriticalSection; each must evaluate to 2.0.
  """
  # This one is subtle; and we're being overly cautious here.  The
  # deadlock we are ensuring we catch is:
  #
  #   to_capture = CS[lambda x: x + 1](1.0)
  #   deadlocked = CS[lambda x: x + to_capture](1.0)
  #
  # This would have caused a deadlock because executing `deadlocked` will
  # lock the mutex on CS; but then due to dependencies, will attempt
  # to compute `to_capture`.  This computation requires locking CS,
  # but that is not possible now because CS is already locked by
  # `deadlocked`.
  #
  # We check that CriticalSection.execute properly inserts new
  # control dependencies to its lock to ensure all captured
  # operations are finished before anything runs within the critical section.
  cs = critical_section_ops.CriticalSection(shared_name="cs")
  fn = array_ops.identity
  to_capture = cs.execute(fn, 1.0)
  fn_captures = lambda x: x + to_capture
  to_capture_too = array_ops.identity(to_capture)

  ex_0 = cs.execute(fn_captures, 1.0)

  with ops.control_dependencies([to_capture]):
    # This is OK because to_capture will execute before this next call.
    ex_1 = cs.execute(fn_captures, 1.0)

  dependency = array_ops.identity(to_capture)

  fn_captures_dependency = lambda x: x + dependency

  ex_2 = cs.execute(fn_captures_dependency, 1.0)

  with ops.control_dependencies([to_capture_too]):
    ex_3 = cs.execute(fn_captures_dependency, 1.0)

  # Ensure none of the executions actually deadlocks.  `assertEqual` replaces
  # the deprecated `assertEquals` alias; results are evaluated one at a time,
  # in the same order as before.
  for execution in (ex_0, ex_1, ex_2, ex_3):
    self.assertEqual(2.0, self.evaluate(execution))
def testCriticalSectionWithControlFlow(self):
  """Combines cs.execute with an outer cond and a cond inside the section.

  For every (outer_cond, inner_cond) pair, 100 conds are built.  The variable
  is only updated when both conditions are true; in that case the sorted
  results are expected to be 0, 2, 4, ..., otherwise every result is 0.
  """
  for outer_cond in (False, True):
    for inner_cond in (False, True):
      section = critical_section_ops.CriticalSection(shared_name="cs")
      var = resource_variable_ops.ResourceVariable(0.0, name="v")
      num_concurrent = 100

      # pylint: disable=cell-var-from-loop
      def fn(a, b):
        before = var.read_value()

        def true_fn():
          # Order the read before the update, and the update before the
          # returned value.
          with ops.control_dependencies([before]):
            updated = var.assign_add(a * b)
          with ops.control_dependencies([updated]):
            return array_ops.identity(before)

        return control_flow_ops.cond(
            array_ops.identity(inner_cond), true_fn, lambda: before)

      def execute():
        return section.execute(fn, 1.0, 2.0)

      results = []
      for _ in range(num_concurrent):
        results.append(
            control_flow_ops.cond(
                array_ops.identity(outer_cond), execute, var.read_value))
      # pylint: enable=cell-var-from-loop

      self.evaluate(var.initializer)
      observed = self.evaluate(results)

      if not (inner_cond and outer_cond):
        # No update ever ran, so every read saw the initial value.
        self.assertAllClose([0] * num_concurrent, observed)
      else:
        self.assertAllClose(
            [2.0 * step for step in range(num_concurrent)], sorted(observed))
def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
  """Checks that cs.execute inside a parallel while_loop does not deadlock.

  Three loop bodies feed the loop variable `j` into the critical section in
  different ways (implicit capture, implicit capture plus a manual control
  dependency, and as an explicit argument).  Each loop runs 1000 iterations
  with `parallel_iterations=25` and must end with both counters at 1000.
  """
  cs = critical_section_ops.CriticalSection(shared_name="cs")

  # NOTE(review): the test name written to the log differs from this method's
  # name -- presumably left over from a rename.  The text is kept unchanged;
  # confirm before altering it.
  banner = (
      "\n==============\n%s "
      "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
      "%s'\n"
      "==============\n")

  def run_and_check(body, variant):
    """Builds the loop for `body`, evaluates it and checks both counters."""
    (i_n, j_n) = control_flow_ops.while_loop(
        lambda i, _: i < 1000,
        body,
        [0, 0],
        parallel_iterations=25)
    # `warning` replaces the deprecated `warn` alias.
    logging.warning(banner, "Running", variant)
    # `assertEqual` replaces the deprecated `assertEquals` alias.
    self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
    logging.warning(banner, "Successfully finished running", variant)

  def body_implicit_capture(i, j):
    # This would have caused a deadlock if not for logic in execute
    # that inserts additional control dependencies onto the lock op:
    #   * Loop body argument j is captured by fn()
    #   * i is running in parallel to move forward the execution
    #   * j is not being checked by the predicate function
    #   * output of cs.execute() is returned as next j.
    fn = lambda: j + 1
    return (i + 1, cs.execute(fn))

  run_and_check(body_implicit_capture, "body_implicit_capture")

  def body_implicit_capture_protected(i, j):
    # This version is ok because we manually add a control
    # dependency on j, which is an argument to the while_loop body
    # and captured by fn.
    fn = lambda: j + 1
    with ops.control_dependencies([j]):
      return (i + 1, cs.execute(fn))

  run_and_check(
      body_implicit_capture_protected, "body_implicit_capture_protected")

  def body_args_capture(i, j):
    # This version is ok because j is an argument to fn and we can
    # ensure there's a control dependency on j.
    fn = lambda x: x + 1
    return (i + 1, cs.execute(fn, j))

  run_and_check(body_args_capture, "body_args_capture")