def testMultipleCSExecutionsRequestSameResource(self):
    """Checks when two CriticalSections may execute on the same resource."""
    conflict_pattern = "requested exclusive resource access"
    owner_cs = critical_section_ops.CriticalSection()
    other_cs = critical_section_ops.CriticalSection()

    exclusive_var = resource_variable_ops.ResourceVariable(0.0, name="v")
    owner_cs.execute(lambda: exclusive_var + 1)
    # A repeat access from the CriticalSection that already uses the
    # resource is accepted.
    owner_cs.execute(lambda: exclusive_var - 1)
    # By default, a different CriticalSection touching the same resource
    # is rejected.
    with self.assertRaisesRegexp(ValueError, conflict_pattern):
        other_cs.execute(lambda: exclusive_var + 1)
    # Opting out of exclusive access on the second call does not help.
    with self.assertRaisesRegexp(ValueError, conflict_pattern):
        other_cs.execute(
            lambda: exclusive_var + 1, exclusive_resource_access=False)

    shared_var = resource_variable_ops.ResourceVariable(0.0, name="v2")
    owner_cs.execute(lambda: shared_var + 1, exclusive_resource_access=False)
    # Two non-exclusive requests from different CriticalSections coexist.
    other_cs.execute(lambda: shared_var + 1, exclusive_resource_access=False)

    # A later request that demands exclusive resource access is rejected.
    with self.assertRaisesRegexp(ValueError, conflict_pattern):
        other_cs.execute(lambda: shared_var + 1)
 def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
   """Expects ValueError when fn re-enters a section with the same shared_name."""
   # Per the original note: eager mode does not handle this properly and
   # eager users would simply deadlock; the error at least makes the
   # problem easier to debug.
   outer_cs = critical_section_ops.CriticalSection(shared_name="cs")
   alias_cs = critical_section_ops.CriticalSection(shared_name="cs")

   def reenter(value):
     # Runs inside outer_cs but executes on a second object that was
     # created with the same shared_name.
     return alias_cs.execute(lambda inner: inner + 1, value)

   expected_error = (
       r"attempts to directly access the CriticalSection in which it "
       r"would be running")
   with self.assertRaisesRegexp(ValueError, expected_error):
     outer_cs.execute(reenter, 1.0)
    def testInsideFunction(self):
        """Runs cs.execute inside a dataset map function and reads back 1."""
        if test_util.is_gpu_available():
            self.skipTest(
                "b/123899495: Colocation errors for critical sections in map on GPU"
            )
        section = critical_section_ops.CriticalSection()
        device_name = "/gpu:0" if test_util.is_gpu_available() else "/cpu:0"
        with ops.device(device_name):
            counter = resource_variable_ops.ResourceVariable(1)

        def read_counter():
            return counter.read_value()

        # map() creates a TensorFlow function.
        dataset = dataset_ops.Dataset.range(1)
        if not test_util.is_gpu_available():
            dataset = dataset.map(lambda _: section.execute(read_counter))
        else:
            dataset = dataset.apply(prefetching_ops.copy_to_device("/gpu:0"))
            dataset = dataset.apply(
                prefetching_ops.map_on_gpu(
                    lambda _: section.execute(read_counter)))

        def fetch_first_element():
            # Graph mode needs the variable and iterator initialized first.
            if not context.executing_eagerly():
                iterator = dataset.make_initializable_iterator()
                self.evaluate([counter.initializer, iterator.initializer])
                return self.evaluate(iterator.get_next())
            return self.evaluate(dataset.make_one_shot_iterator().get_next())

        self.assertEqual(1, fetch_first_element())
Example #4
0
    def testCreateCriticalSectionRaw(self):
        """Drives execute_in_critical_section directly with a Defun."""
        section = critical_section_ops.CriticalSection(name="cs")
        accumulator = resource_variable_ops.ResourceVariable(0.0, name="v")

        @function.Defun(dtypes.float32, dtypes.float32)
        def fn(a, b):
            # Read the value, add a * b, and return the value read before
            # the add; control dependencies order read -> add -> return.
            before = accumulator.read_value()
            with ops.control_dependencies([before]):
                after = accumulator.assign_add(a * b)
                with ops.control_dependencies([after]):
                    return array_ops.identity(before)

        def run_in_section(defun, *call_args):
            # Output dtypes come from the function signature; output shapes
            # are left unknown.
            outputs = defun.definition.signature.output_arg
            out_types = [out.type for out in outputs]
            out_shapes = [tensor_shape.TensorShape(None) for _ in outputs]
            return resource_variable_ops.execute_in_critical_section(
                critical_section=section._handle,
                arguments=list(call_args) + defun.captured_inputs,
                f=defun,
                output_types=out_types,
                output_shapes=out_shapes)

        call_count = 1000
        results = []
        for _ in range(call_count):
            results.append(run_in_section(fn, 1.0, 2.0)[0])
        self.evaluate(accumulator.initializer)
        observed = self.evaluate(results)
        # Each call adds 2.0, so the sorted pre-add values are 0, 2, 4, ...
        expected = [2.0 * step for step in range(call_count)]
        self.assertAllClose(expected, sorted(observed))
Example #5
0
 def testCollection(self):
     """Checks the section and its execution appear in graph collections."""
     section = critical_section_ops.CriticalSection(name="cs")
     registered_sections = ops.get_collection(
         critical_section_ops.CRITICAL_SECTIONS)
     self.assertIn(section, registered_sections)

     execution_op = section.execute(lambda x: x + 1, 1.0).op
     registered_ops = []
     for signature in ops.get_collection(
             critical_section_ops.CRITICAL_SECTION_EXECUTIONS):
         registered_ops.append(signature.op)
     self.assertIn(execution_op, registered_ops)
Example #6
0
    def testRecursiveCriticalSectionAccessIsIllegal(self):
        """Expects ValueError when fn executes on its own CriticalSection."""
        section = critical_section_ops.CriticalSection(name="cs")

        def reenter(value):
            # Calls execute on the same section it is being run in.
            return section.execute(lambda inner: inner + 1, value)

        expected_error = (
            r"attempts to access the CriticalSection in which it would be running"
        )
        with self.assertRaisesRegexp(ValueError, expected_error):
            section.execute(reenter, 1.0)
 def testCollection(self):
   """Checks the section and its MutexLock op appear in graph collections."""
   section = critical_section_ops.CriticalSection(shared_name="cs")
   registered_sections = ops.get_collection(
       critical_section_ops.CRITICAL_SECTIONS)
   self.assertIn(section, registered_sections)

   result = section.execute(lambda x: x + 1, 1.0, name="my_execute")
   # Pick the first graph op whose name contains "my_execute" and whose
   # type contains "MutexLock".
   lock_candidates = [
       op for op in result.graph.get_operations()
       if "my_execute" in op.name and "MutexLock" in op.type
   ]
   lock_op = lock_candidates[0]

   registered_ops = []
   for signature in ops.get_collection(
       critical_section_ops.CRITICAL_SECTION_EXECUTIONS):
     registered_ops.append(signature.op)
   self.assertIn(lock_op, registered_ops)
 def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
   """Executes a section in a while_loop body that reads an outside variable."""
   section = critical_section_ops.CriticalSection()
   outside_var = resource_variable_ops.ResourceVariable(0, name="v")

   # Make sure that the control dependencies on the variable do not cause
   # issues in the lock_op's automatic control dependency adder.
   #
   # Note, the variable must be a resource variable (or something
   # similar), otherwise it gets hoisted into the while_loop by the time
   # control dependencies are added to the lock_op.
   def keep_going(i):
     return i < 10

   def step(i):
     return section.execute(lambda j: outside_var + j + 1, i)

   final_value = control_flow_ops.while_loop(keep_going, step, [0])
   self.evaluate(outside_var.initializer)
   self.assertEqual(10, self.evaluate(final_value))
  def testCriticalSectionInParallelDoesntDeadlockOnError(self):
    """Repeatedly evaluates a failing execution and expects an op error."""
    # Per the original note: not run in eager mode, because eager does not
    # run fn() in parallel, which is where the deadlock could potentially
    # occur (in graph mode).
    section = critical_section_ops.CriticalSection(shared_name="cs")
    shared_var = resource_variable_ops.ResourceVariable(0.0, name="v")

    def guarded_read(index):
      # The Assert condition is false for even indices.
      check = control_flow_ops.Assert((index % 2) == 1, ["Error"])
      with ops.control_dependencies([check]):
        return shared_var.read_value()

    call_count = 2
    executions = []
    for index in range(call_count):
      executions.append(section.execute(guarded_read, index))
    self.evaluate(shared_var.initializer)
    for _ in range(100):
      with self.assertRaisesOpError("Error"):
        self.evaluate(executions)
  def testCreateCriticalSection(self):
    """Runs 100 read-then-add executions and checks the values read."""
    section = critical_section_ops.CriticalSection(shared_name="cs")
    accumulator = resource_variable_ops.ResourceVariable(0.0, name="v")

    def add_and_return_previous(a, b):
      # Read the value, add a * b, and return the value read before the
      # add; control dependencies order read -> add -> return.
      before = accumulator.value()
      with ops.control_dependencies([before]):
        after = accumulator.assign_add(a * b)
        with ops.control_dependencies([after]):
          return array_ops.identity(before)

    call_count = 100
    executions = []
    for _ in range(call_count):
      executions.append(section.execute(add_and_return_previous, 1.0, 2.0))
    self.evaluate(accumulator.initializer)
    observed = self.evaluate(executions)
    # Each call adds 2.0, so the sorted pre-add values are 0, 2, 4, ...
    expected = [2.0 * step for step in range(call_count)]
    self.assertAllClose(expected, sorted(observed))
  def testCreateCriticalSectionFnReturnsOp(self):
    """Runs 100 executions whose fn returns a no_op and checks the total."""
    section = critical_section_ops.CriticalSection(shared_name="cs")
    accumulator = resource_variable_ops.ResourceVariable(0.0, name="v")

    def add_and_return_no_op(a, b):
      # Read, then add a * b, then return a no_op that depends on the add.
      before = accumulator.read_value()
      with ops.control_dependencies([before]):
        after = accumulator.assign_add(a * b)
        with ops.control_dependencies([after]):
          return control_flow_ops.no_op()

    call_count = 100
    executions = []
    for _ in range(call_count):
      executions.append(section.execute(add_and_return_no_op, 1.0, 2.0))
    self.evaluate(accumulator.initializer)
    self.evaluate(executions)
    total = self.evaluate(accumulator)
    # Every execution adds 1.0 * 2.0 to the variable.
    self.assertAllClose(2.0 * call_count, total)
Example #12
0
    def testInsideFunction(self):
        """Runs cs.execute inside Dataset.map and checks the element is 1."""
        section = critical_section_ops.CriticalSection()
        counter = resource_variable_ops.ResourceVariable(1)

        def read_counter():
            return counter.read_value()

        # map() creates a TensorFlow function.
        base_dataset = dataset_ops.Dataset.range(1)
        dataset = base_dataset.map(lambda _: section.execute(read_counter))

        def fetch_first_element():
            # Graph mode needs the variable and iterator initialized first.
            if not context.executing_eagerly():
                iterator = dataset.make_initializable_iterator()
                self.evaluate([counter.initializer, iterator.initializer])
                return self.evaluate(iterator.get_next())
            return self.evaluate(dataset.make_one_shot_iterator().get_next())

        self.assertEqual(1, fetch_first_element())
  def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
    """Checks executions that capture another execution's output evaluate.

    Four executions on the same CriticalSection capture, directly or through
    an identity op, the output of an earlier execution on that section. Each
    one is evaluated and must produce 2.0.
    """
    # This one is subtle; and we're being overly cautious here.  The
    # deadlock we are ensuring we catch is:
    #
    # to_capture = CS[lambda x: x + 1](1.0)
    # deadlocked = CS[lambda x: x + to_capture](1.0)
    #
    # This would have caused a deadlock because executing `deadlocked` will
    # lock the mutex on CS; but then due to dependencies, will attempt
    # to compute `to_capture`.  This computation requires locking CS,
    # but that is not possible now because CS is already locked by
    # `deadlocked`.
    #
    # We check that CriticalSection.execute properly inserts new
    # control dependencies to its lock to ensure all captured
    # operations are finished before anything runs within the critical section.
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    fn = array_ops.identity
    to_capture = cs.execute(fn, 1.0)
    fn_captures = lambda x: x + to_capture
    to_capture_too = array_ops.identity(to_capture)

    # Case 0: the captured tensor is used with no explicit dependency.
    ex_0 = cs.execute(fn_captures, 1.0)

    # Case 1: an explicit control dependency on the captured tensor.
    with ops.control_dependencies([to_capture]):
      # This is OK because to_capture will execute before this next call
      ex_1 = cs.execute(fn_captures, 1.0)

    dependency = array_ops.identity(to_capture)

    fn_captures_dependency = lambda x: x + dependency

    # Case 2: the capture goes through an identity of the earlier output.
    ex_2 = cs.execute(fn_captures_dependency, 1.0)

    # Case 3: as case 2, under a control dependency on another identity.
    with ops.control_dependencies([to_capture_too]):
      ex_3 = cs.execute(fn_captures_dependency, 1.0)

    # Ensure there's no actual deadlock on to_execute.
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(2.0, self.evaluate(ex_0))
    self.assertEqual(2.0, self.evaluate(ex_1))
    self.assertEqual(2.0, self.evaluate(ex_2))
    self.assertEqual(2.0, self.evaluate(ex_3))
    def testCriticalSectionWithControlFlow(self):
        """Combines cs.execute with cond both outside and inside fn."""
        for outer_cond in (False, True):
            for inner_cond in (False, True):
                section = critical_section_ops.CriticalSection(
                    shared_name="cs")
                accumulator = resource_variable_ops.ResourceVariable(
                    0.0, name="v")
                call_count = 100

                # pylint: disable=cell-var-from-loop
                def maybe_add(a, b):
                    before = accumulator.read_value()

                    def add_and_return_previous():
                        # Read -> add -> return the value read before the add.
                        with ops.control_dependencies([before]):
                            after = accumulator.assign_add(a * b)
                            with ops.control_dependencies([after]):
                                return array_ops.identity(before)

                    def return_unchanged():
                        return before

                    return control_flow_ops.cond(
                        array_ops.identity(inner_cond),
                        add_and_return_previous, return_unchanged)

                def run_in_section():
                    return section.execute(maybe_add, 1.0, 2.0)

                results = []
                for _ in range(call_count):
                    results.append(
                        control_flow_ops.cond(
                            array_ops.identity(outer_cond),
                            run_in_section, accumulator.read_value))
                # pylint: enable=cell-var-from-loop

                self.evaluate(accumulator.initializer)
                observed = self.evaluate(results)
                if not (inner_cond and outer_cond):
                    # The add is skipped, so every read sees 0.
                    self.assertAllClose([0] * call_count, observed)
                else:
                    # Each call adds 2.0; sorted pre-add values are 0, 2, ...
                    expected = [2.0 * step for step in range(call_count)]
                    self.assertAllClose(expected, sorted(observed))
  def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
    """Checks three while_loop bodies that feed cs.execute output back in.

    Each body returns `(i + 1, cs.execute(...))`, so the execution output
    becomes the next iteration's `j`. Every variant is run for 1000
    iterations with `parallel_iterations=25` and must finish at (1000, 1000).
    """
    cs = critical_section_ops.CriticalSection(shared_name="cs")

    def run_loop_and_check(body, label):
      # Builds the loop around `body`, logs before and after evaluating it,
      # and asserts the final loop variables.
      # NOTE(review): the logged test name ends in "DoesNotDeadlock", which
      # differs from this method's name -- presumably an earlier name; the
      # text is kept verbatim.
      (i_n, j_n) = control_flow_ops.while_loop(
          lambda i, _: i < 1000,
          body,
          [0, 0],
          parallel_iterations=25)
      logging.warning(
          "\n==============\nRunning "
          "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
          "%s'\n"
          "==============\n", label)
      self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
      logging.warning(
          "\n==============\nSuccessfully finished running "
          "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
          "%s'\n"
          "==============\n", label)

    def body_implicit_capture(i, j):
      # This would have caused a deadlock if not for logic in execute
      # that inserts additional control dependencies onto the lock op:
      #   * Loop body argument j is captured by fn()
      #   * i is running in parallel to move forward the execution
      #   * j is not being checked by the predicate function
      #   * output of cs.execute() is returned as next j.
      fn = lambda: j + 1
      return (i + 1, cs.execute(fn))

    run_loop_and_check(body_implicit_capture, "body_implicit_capture")

    def body_implicit_capture_protected(i, j):
      # This version is ok because we manually add a control
      # dependency on j, which is an argument to the while_loop body
      # and captured by fn.
      fn = lambda: j + 1
      with ops.control_dependencies([j]):
        return (i + 1, cs.execute(fn))

    run_loop_and_check(
        body_implicit_capture_protected, "body_implicit_capture_protected")

    def body_args_capture(i, j):
      # This version is ok because j is an argument to fn and we can
      # ensure there's a control dependency on j.
      fn = lambda x: x + 1
      return (i + 1, cs.execute(fn, j))

    run_loop_and_check(body_args_capture, "body_args_capture")