Ejemplo n.º 1
0
 def testCollection(self):
     """Verify the section and its lock op are recorded in graph collections.

     The CriticalSection itself must appear in the CRITICAL_SECTIONS
     collection, and the MutexLock op created by `execute` must be the op of
     one of the signatures in the CRITICAL_SECTION_EXECUTIONS collection.
     """
     section = critical_section_ops.CriticalSection(shared_name="cs")
     registered_sections = ops.get_collection(
         critical_section_ops.CRITICAL_SECTIONS)
     self.assertIn(section, registered_sections)

     def add(x):
         return x + 1

     execute = section.execute(lambda: add(1.0), name="my_execute")

     # Collect the lock op(s) created for the execution named "my_execute".
     lock_operations = []
     for operation in execute.graph.get_operations():
         if "my_execute" in operation.name and "MutexLock" in operation.type:
             lock_operations.append(operation)
     execute_op = lock_operations[0]

     recorded_ops = [
         signature.op for signature in ops.get_collection(
             critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
     ]
     self.assertIn(execute_op, recorded_ops)
Ejemplo n.º 2
0
    def testRecursiveCriticalSectionAccessIsIllegal(self):
        """Check that executing a critical section inside itself is rejected.

        The outer `cs.execute` runs a function that calls `cs.execute` on the
        same CriticalSection; this is expected to raise a ValueError whose
        message starts with "Attempting to lock a CriticalSection in which
        we are".
        """
        # This does not work properly in eager mode.  Eager users will
        # just hit a deadlock if they do this.  But at least it'll be easier
        # to debug.
        cs = critical_section_ops.CriticalSection()
        add = lambda y: y + 1

        def fn(x):
            # Invoked from within the outer cs.execute below, so this is a
            # nested execute on the same section.
            return cs.execute(lambda: add(x))

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(
                ValueError,
                r"Attempting to lock a CriticalSection in which we are"):
            cs.execute(lambda: fn(1.0))
Ejemplo n.º 3
0
    def testCreateCriticalSection(self):
        """Run 100 executions of a read-then-add function and check outputs.

        Each execution returns the variable's value as read before `a * b`
        (here 2.0) is added, so the sorted results are expected to be
        0, 2, 4, ... for the 100 executions.
        """
        section = critical_section_ops.CriticalSection(shared_name="cs")
        v = resource_variable_ops.ResourceVariable(0.0, name="v")

        def fn(a, b):
            before = v.value()
            # The add is ordered after the read, and the returned identity
            # is ordered after the add.
            with ops.control_dependencies([before]):
                updated = v.assign_add(a * b)
                with ops.control_dependencies([updated]):
                    snapshot = array_ops.identity(before)
            return snapshot

        num_concurrent = 100
        executions = []
        for _ in range(num_concurrent):
            executions.append(section.execute(lambda: fn(1.0, 2.0)))

        self.evaluate(v.initializer)
        observed = sorted(self.evaluate(executions))
        expected = [2.0 * step for step in range(num_concurrent)]
        self.assertAllClose(expected, observed)
Ejemplo n.º 4
0
    def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
        """Execute a critical section that reads an outer variable in a loop.

        The loop starts at 0, continues while `i < 10`, and each iteration
        replaces `i` with `v + i + 1` computed inside the critical section;
        the final value is expected to be 10.
        """
        section = critical_section_ops.CriticalSection()
        v = resource_variable_ops.ResourceVariable(0, name="v")

        # The control dependencies on v must not confuse the automatic
        # control dependency adder used for the lock op.
        #
        # v has to be a resource variable (or something similar) here;
        # otherwise it is hoisted into the while_loop by the time control
        # dependencies are added to the lock op.
        def body(i):
            def add_j(j):
                return v + j + 1

            return section.execute(lambda: add_j(i))

        def keep_going(i):
            return i < 10

        out = control_flow_ops.while_loop(keep_going, body, [0])
        self.evaluate(v.initializer)
        final_value = self.evaluate(out)
        self.assertEqual(10, final_value)
Ejemplo n.º 5
0
    def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
        """Check executions that capture another execution's output.

        Four executions on the same CriticalSection capture, directly or
        through an identity op, the result of an earlier execution on that
        section; each is expected to evaluate to 2.0.
        """
        # This case is subtle, and the check is deliberately cautious.  The
        # deadlock being guarded against looks like:
        #
        # to_capture = CS[lambda x: x + 1](1.0)
        # deadlocked = CS[lambda x: x + to_capture](1.0)
        #
        # Running `deadlocked` would lock the mutex on CS and then, because
        # of its dependencies, try to compute `to_capture`.  Computing
        # `to_capture` needs the CS lock too, which cannot be taken because
        # `deadlocked` already holds it.
        #
        # The test verifies that CriticalSection.execute adds control
        # dependencies to its lock so that every captured operation has
        # finished before anything runs inside the critical section.
        section = critical_section_ops.CriticalSection(shared_name="cs")
        fn = array_ops.identity
        to_capture = section.execute(lambda: fn(1.0))

        def fn_captures(x):
            return x + to_capture

        to_capture_too = array_ops.identity(to_capture)

        ex_0 = section.execute(lambda: fn_captures(1.0))

        with ops.control_dependencies([to_capture]):
            # Fine: to_capture is forced to run before this execution.
            ex_1 = section.execute(lambda: fn_captures(1.0))

        dependency = array_ops.identity(to_capture)

        def fn_captures_dependency(x):
            return x + dependency

        ex_2 = section.execute(lambda: fn_captures_dependency(1.0))

        with ops.control_dependencies([to_capture_too]):
            ex_3 = section.execute(lambda: fn_captures_dependency(1.0))

        # Evaluating each execution confirms none of them actually deadlocks.
        for execution in (ex_0, ex_1, ex_2, ex_3):
            self.assertEqual(2.0, self.evaluate(execution))
Ejemplo n.º 6
0
    def testCreateCriticalSectionFnReturnsOp(self):
        """Run 100 executions whose function returns an op, not a tensor.

        Each execution adds `a * b` (here 2.0) to the variable and returns a
        no-op, so after evaluating all of them the variable is expected to
        hold 2.0 * 100.
        """
        section = critical_section_ops.CriticalSection(shared_name="cs")
        v = resource_variable_ops.ResourceVariable(0.0, name="v")

        def fn_return_op(a, b):
            before = v.read_value()
            # The add is ordered after the read, and the returned no-op is
            # ordered after the add.
            with ops.control_dependencies([before]):
                updated = v.assign_add(a * b)
                with ops.control_dependencies([updated]):
                    done = control_flow_ops.no_op()
            return done

        num_concurrent = 100
        executions = []
        for _ in range(num_concurrent):
            executions.append(section.execute(lambda: fn_return_op(1.0, 2.0)))

        self.evaluate(v.initializer)
        self.evaluate(executions)
        final_v = self.evaluate(v)
        expected_total = 2.0 * num_concurrent
        self.assertAllClose(expected_total, final_v)
Ejemplo n.º 7
0
    def testCriticalSectionWithControlFlow(self):
        """Combine critical-section execution with outer and inner `cond`s.

        For every (outer_cond, inner_cond) pair, 100 conds are built: the
        outer cond either executes the critical section or just reads the
        variable, and inside the section an inner cond either adds 2.0 and
        returns the previous value or returns the value unchanged.  Only
        when both conditions are true are the results expected to be
        0, 2, 4, ...; otherwise all results are expected to be 0.
        """
        for outer_cond in [False, True]:
            for inner_cond in [False, True]:
                section = critical_section_ops.CriticalSection(
                    shared_name="cs")
                v = resource_variable_ops.ResourceVariable(0.0, name="v")
                num_concurrent = 100

                # pylint: disable=cell-var-from-loop
                def fn(a, b):
                    before = v.read_value()

                    def add_then_return_previous():
                        with ops.control_dependencies([before]):
                            updated = v.assign_add(a * b)
                            with ops.control_dependencies([updated]):
                                return array_ops.identity(before)

                    def return_previous():
                        return before

                    inner_pred = array_ops.identity(inner_cond)
                    return control_flow_ops.cond(
                        inner_pred, add_then_return_previous, return_previous)

                def execute():
                    return section.execute(lambda: fn(1.0, 2.0))

                results = []
                for _ in range(num_concurrent):
                    outer_pred = array_ops.identity(outer_cond)
                    results.append(
                        control_flow_ops.cond(outer_pred, execute,
                                              v.read_value))
                # pylint: enable=cell-var-from-loop

                self.evaluate(v.initializer)
                r_value = self.evaluate(results)
                if outer_cond and inner_cond:
                    expected = [2.0 * step for step in range(num_concurrent)]
                    self.assertAllClose(expected, sorted(r_value))
                else:
                    self.assertAllClose([0] * num_concurrent, r_value)
Ejemplo n.º 8
0
    def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
        """Check while_loop bodies that feed cs.execute output back as input.

        Three loop-body variants are each run for 1000 iterations with
        parallel_iterations=25, starting from (0, 0).  Every variant
        increments `i` directly and increments `j` inside the critical
        section, so the final (i, j) is expected to be (1000, 1000).
        Progress banners are logged before and after each evaluation.
        """
        cs = critical_section_ops.CriticalSection(shared_name="cs")
        test_name = "testRecursiveCriticalSectionAccessWithinLoopIsProtected"
        banner = "\n==============\n%s '%s %s'\n==============\n"

        def run_loop_and_check(body):
            # Builds the loop for `body`, evaluates it and checks the result;
            # shared by all three variants below.
            (i_n, j_n) = control_flow_ops.while_loop(lambda i, _: i < 1000,
                                                     body, [0, 0],
                                                     parallel_iterations=25)
            # For consistency between eager and graph mode.
            i_n = array_ops.identity(i_n)
            logging.warning(banner, "Running", test_name, body.__name__)
            self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
            logging.warning(banner, "Successfully finished running",
                            test_name, body.__name__)

        def body_implicit_capture(i, j):
            # This would have caused a deadlock if not for logic in execute
            # that inserts additional control dependencies onto the lock op:
            #   * Loop body argument j is captured by fn()
            #   * i is running in parallel to move forward the execution
            #   * j is not being checked by the predicate function
            #   * output of cs.execute() is returned as next j.
            fn = lambda: j + 1
            return (i + 1, cs.execute(fn))

        run_loop_and_check(body_implicit_capture)

        def body_implicit_capture_protected(i, j):
            # This version is ok because we manually add a control
            # dependency on j, which is an argument to the while_loop body
            # and captured by fn.
            fn = lambda: j + 1
            with ops.control_dependencies([j]):
                return (i + 1, cs.execute(fn))

        run_loop_and_check(body_implicit_capture_protected)

        def body_args_capture(i, j):
            # This version is ok because j is an argument to fn and we can
            # ensure there's a control dependency on j.
            fn = lambda x: x + 1
            return (i + 1, cs.execute(lambda: fn(j)))

        run_loop_and_check(body_args_capture)
Ejemplo n.º 9
0
 def testCreateCriticalSection(self):
     """Create a shared-name CriticalSection and execute a constant fn in it."""
     section = critical_section_ops.CriticalSection(shared_name="cs")

     def constant_fn():
         return 1.0

     section.execute(constant_fn)