def testIgnorePackedVariableInSaveContext(self, distribution):
        """A packed variable is not exposed while a SaveContext is active.

        Enables packed variables on ``distribution`` through a private flag and
        creates a variable under its scope. Then checks two things:
        ``_packed_variable`` is a ``PackedDistributedVariable`` there, and it
        reads as ``None`` inside a SaveContext built from default
        ``SaveOptions``.

        NOTE(review): the private flag is set but never reset; presumably each
        test receives its own strategy instance -- confirm.
        """
        distribution._enable_packed_variable_in_eager_mode = True
        with distribution.scope():
            var = variables_lib.Variable(0)
            self.assertIsInstance(
                var._packed_variable, packed.PackedDistributedVariable)

        default_options = save_options.SaveOptions()
        with save_context.save_context(default_options):
            # Inside the SaveContext the packed variable must not be visible.
            self.assertIsNone(var._packed_variable)
Ejemplo n.º 2
0
    def testRetraceOnSavingFirstTraceOutsideScope(self, distribution):
        """A SaveContext forces exactly one retrace, which is then reused.

        ``func`` closes over a variable created under ``distribution.scope()``
        and is first traced outside any SaveContext. The first call inside a
        SaveContext (default ``SaveOptions``) must add one trace. A second call
        inside another such SaveContext must add none.
        """
        with distribution.scope():
            var = variables.Variable(0.)

        # Mutable cell so the traced Python body can bump the counter.
        trace_calls = [0]

        @def_function.function
        def func():
            trace_calls[0] += 1
            return var + 1.

        def call_in_save_context_and_check(expected_new_traces):
            # Calls func under a fresh SaveContext and checks how many new
            # traces that call caused.
            before = trace_calls[0]
            with save_context.save_context(save_options.SaveOptions()):
                func()
            self.assertEqual(before + expected_new_traces, trace_calls[0])

        # Initial trace, outside of any SaveContext.
        func()
        # Entering a SaveContext for the first time triggers a retrace...
        call_in_save_context_and_check(1)
        # ...which is cached and reused the next time.
        call_in_save_context_and_check(0)
Ejemplo n.º 3
0
  def testCacheWithinSaveContext(self):
    """Concrete-function caching distinguishes SaveContext variable policies.

    Outside a SaveContext, repeated ``get_concrete_function`` calls with the
    same input return the identical object. Calls made inside a SaveContext
    using ``EXPAND_DISTRIBUTED_VARIABLES`` or ``NONE`` must not return that
    cached object.
    """

    @def_function.function
    def func(x):
      return 2 * x

    def concrete():
      return func.get_concrete_function(constant_op.constant(2.))

    func_a = concrete()
    func_b = concrete()
    self.assertIs(func_a, func_b)

    policy = save_options.VariablePolicy
    expand_options = save_options.SaveOptions(
        experimental_variable_policy=policy.EXPAND_DISTRIBUTED_VARIABLES)
    with save_context.save_context(expand_options):
      func_c = concrete()

    none_options = save_options.SaveOptions(
        experimental_variable_policy=policy.NONE)
    with save_context.save_context(none_options):
      func_d = concrete()

    self.assertIsNot(func_a, func_c)
    self.assertIsNot(func_a, func_d)
Ejemplo n.º 4
0
    def testCacheWithinSaveContext(self):
        """A default SaveContext reuses the already-cached concrete function.

        Two ``get_concrete_function`` calls with the same constant input must
        return the identical object. A call made inside a SaveContext built
        from default ``SaveOptions`` must return that same object too.
        """
        @def_function.function
        def func(x):
            return 2 * x

        def concrete():
            return func.get_concrete_function(constant_op.constant(2.))

        first = concrete()
        second = concrete()
        self.assertIs(first, second)

        with save_context.save_context(save_options.SaveOptions()):
            in_save_context = concrete()

        self.assertIs(first, in_save_context)
Ejemplo n.º 5
0
      def thread_fn():
        """Worker body: checks that SaveContext state is per-thread.

        Uses ``entered_context_in_thread`` and ``continue_thread`` from the
        enclosing scope. The worker starts outside any SaveContext and enters
        its own with ``save_debug_info=False``. It then signals the spawning
        thread and blocks until told to continue. After leaving its context,
        it checks that it is outside any SaveContext again.
        """

        def expect_no_save_context():
          # Outside a SaveContext, options are unavailable.
          self.assertFalse(save_context.in_save_context())
          with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()

        expect_no_save_context()

        thread_options = save_options.SaveOptions(save_debug_info=False)
        with save_context.save_context(thread_options):
          self.assertTrue(save_context.in_save_context())
          # save_debug_info has a different value in this thread.
          self.assertFalse(save_context.get_save_options().save_debug_info)
          entered_context_in_thread.set()
          continue_thread.wait()

        expect_no_save_context()
Ejemplo n.º 6
0
  def test_multi_thread(self):
    """SaveContext state is tracked independently per thread.

    The main thread enters a SaveContext with ``save_debug_info=True``. A
    worker thread then:
      - checks that it starts outside any SaveContext;
      - enters its own context with ``save_debug_info=False``;
      - exits that context again.
    The main thread checks that its own context is unaffected both while the
    worker is inside its context and after the worker has exited.
    """
    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
      save_context.get_save_options()

    options = save_options.SaveOptions(save_debug_info=True)
    with save_context.save_context(options):
      self.assertTrue(save_context.in_save_context())
      self.assertTrue(save_context.get_save_options().save_debug_info)

      entered_context_in_thread = threading.Event()
      continue_thread = threading.Event()
      # threading.Thread does not propagate exceptions (including failed
      # assertions) to the joining thread. The worker records them here so
      # they can be re-raised on the main thread after join().
      thread_errors = []

      def thread_fn():
        try:
          self.assertFalse(save_context.in_save_context())
          with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()

          thread_options = save_options.SaveOptions(save_debug_info=False)
          with save_context.save_context(thread_options):
            self.assertTrue(save_context.in_save_context())
            # save_debug_info has a different value in this thread.
            self.assertFalse(save_context.get_save_options().save_debug_info)
            entered_context_in_thread.set()
            continue_thread.wait()

          self.assertFalse(save_context.in_save_context())
          with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()
        except Exception as e:  # pylint: disable=broad-except
          thread_errors.append(e)
        finally:
          # Release the main thread even if an assertion failed before the
          # event was set; otherwise its wait() below would block forever.
          entered_context_in_thread.set()

      t = threading.Thread(target=thread_fn)
      t.start()
      try:
        entered_context_in_thread.wait()
        # Another thread shouldn't affect this thread.
        self.assertTrue(save_context.in_save_context())
        self.assertTrue(save_context.get_save_options().save_debug_info)
      finally:
        # Always let the worker finish, even if an assertion above failed;
        # otherwise it stays blocked in continue_thread.wait() and never exits.
        continue_thread.set()
        t.join()

      # Surface any failure that happened inside the worker thread.
      if thread_errors:
        raise thread_errors[0]
      # Another thread exiting SaveContext shouldn't affect this thread.
      self.assertTrue(save_context.in_save_context())
      self.assertTrue(save_context.get_save_options().save_debug_info)

    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
      save_context.get_save_options()
 def _test(f, v):
     """Trace ``f`` under a SaveContext and check the resulting graph.

     The concrete function is obtained inside a SaveContext whose variable
     policy is ``VariablePolicy.NONE``. The check then verifies that the graph:
       - has no op with a device annotation, and
       - has at most one capture. If there is one, the first element of that
         capture entry is ``v._primary.handle``.
     The function may have no captures at all, e.g. for ``v.aggregation``.
     """
     policy_none = save_options.VariablePolicy.NONE
     traced = def_function.function(lambda: _discard_return(f))
     with save_context.save_context(
             save_options.SaveOptions(experimental_variable_policy=policy_none)):
         # Obtain the concrete function while the SaveContext is active.
         graph = traced.get_concrete_function().graph

     # The graph should contain no device.
     for op in graph.get_operations():
         self.assertEqual(op.device, "", msg=str(op))

     # Only the primary variable's handle may be captured.
     captures = list(graph.captures)
     self.assertLessEqual(len(captures), 1)
     if graph.captures:
         self.assertIs(captures[0][0], v._primary.handle)
Ejemplo n.º 8
0
 def test_enter_multiple(self):
   """Entering a SaveContext while one is already active raises ValueError."""
   opts = save_options.SaveOptions()
   with self.assertRaisesRegex(ValueError, 'already in a SaveContext'):
     # Multi-item `with` enters left to right, i.e. the second context is
     # entered while the first is active, exactly like nested statements.
     with save_context.save_context(opts), save_context.save_context(opts):
       pass