def testIgnorePackedVariableInSaveContext(self, distribution):
    """Packed variables are hidden while a SaveContext is active.

    With packed variables enabled on `distribution`, a variable created in its
    scope exposes a `PackedDistributedVariable`; inside a SaveContext the same
    attribute must read as None.
    """
    # Turn on packed variables so the variable below gets a packed component.
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
        v = variables_lib.Variable(0)
        self.assertIsInstance(v._packed_variable,
                              packed.PackedDistributedVariable)

    with save_context.save_context(save_options.SaveOptions()):
        self.assertIsNone(v._packed_variable)
def testRetraceOnSavingFirstTraceOutsideScope(self, distribution):
    """A function first traced outside a SaveContext retraces exactly once.

    The first call inside a SaveContext must trigger one new trace; a later
    call inside another SaveContext must reuse it.
    """
    with distribution.scope():
        v = variables.Variable(0.)

    # Mutable cell so the traced Python body can bump the counter.
    trace_calls = [0]

    @def_function.function
    def func():
        trace_calls[0] += 1
        return v + 1.

    def new_traces_under_save_context():
        # Number of traces triggered by one call inside a SaveContext.
        start = trace_calls[0]
        with save_context.save_context(save_options.SaveOptions()):
            func()
        return trace_calls[0] - start

    func()
    self.assertEqual(1, new_traces_under_save_context())
    self.assertEqual(0, new_traces_under_save_context())
def testCacheWithinSaveContext(self):
    """Concrete functions traced under a SaveContext are cached separately.

    Outside any SaveContext repeated lookups return the same concrete
    function; inside a SaveContext with an explicit variable policy (either
    EXPAND_DISTRIBUTED_VARIABLES or NONE) a different one is returned.
    """
    @def_function.function
    def func(x):
        return 2 * x

    def lookup():
        return func.get_concrete_function(constant_op.constant(2.))

    def lookup_with_policy(policy):
        opts = save_options.SaveOptions(experimental_variable_policy=policy)
        with save_context.save_context(opts):
            return lookup()

    func_a = lookup()
    func_b = lookup()
    self.assertIs(func_a, func_b)

    policy_enum = save_options.VariablePolicy
    func_c = lookup_with_policy(policy_enum.EXPAND_DISTRIBUTED_VARIABLES)
    func_d = lookup_with_policy(policy_enum.NONE)
    self.assertIsNot(func_a, func_c)
    self.assertIsNot(func_a, func_d)
def testCacheWithinSaveContext(self):
    """A default SaveContext reuses the concrete function traced outside it."""
    @def_function.function
    def func(x):
        return 2 * x

    def lookup():
        return func.get_concrete_function(constant_op.constant(2.))

    outside = lookup()
    self.assertIs(outside, lookup())

    with save_context.save_context(save_options.SaveOptions()):
        inside = lookup()
    self.assertIs(outside, inside)
def thread_fn():
    """Worker body: this thread's SaveContext is independent of others.

    Relies on `self`, `entered_context_in_thread` and `continue_thread`
    being provided by the enclosing scope (not visible here).
    """
    def expect_outside_save_context():
        self.assertFalse(save_context.in_save_context())
        with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()

    expect_outside_save_context()

    thread_options = save_options.SaveOptions(save_debug_info=False)
    with save_context.save_context(thread_options):
        self.assertTrue(save_context.in_save_context())
        # save_debug_info has a different value in this thread.
        self.assertFalse(save_context.get_save_options().save_debug_info)
        # Tell the other thread we are inside, then hold the context open
        # until it lets us continue.
        entered_context_in_thread.set()
        continue_thread.wait()

    expect_outside_save_context()
def test_multi_thread(self):
    """SaveContext state is per-thread.

    The main thread enters a SaveContext with save_debug_info=True while a
    worker thread enters its own SaveContext with save_debug_info=False; each
    thread must only observe its own context, before, during and after the
    other thread's context.

    Assertion failures raised in the worker are collected and re-raised on
    the main thread (Thread.join() does not propagate them), and both events
    are always released so a failure on either side cannot hang the test.
    """
    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
        save_context.get_save_options()

    options = save_options.SaveOptions(save_debug_info=True)
    with save_context.save_context(options):
        self.assertTrue(save_context.in_save_context())
        self.assertTrue(save_context.get_save_options().save_debug_info)

        entered_context_in_thread = threading.Event()
        continue_thread = threading.Event()
        # Exceptions from the worker thread would otherwise be lost.
        thread_errors = []

        def thread_fn():
            try:
                self.assertFalse(save_context.in_save_context())
                with self.assertRaisesRegex(ValueError,
                                            'not in a SaveContext'):
                    save_context.get_save_options()

                thread_options = save_options.SaveOptions(
                    save_debug_info=False)
                with save_context.save_context(thread_options):
                    self.assertTrue(save_context.in_save_context())
                    # save_debug_info has a different value in this thread.
                    self.assertFalse(
                        save_context.get_save_options().save_debug_info)
                    entered_context_in_thread.set()
                    continue_thread.wait()

                self.assertFalse(save_context.in_save_context())
                with self.assertRaisesRegex(ValueError,
                                            'not in a SaveContext'):
                    save_context.get_save_options()
            except Exception as e:  # pylint: disable=broad-except
                thread_errors.append(e)
            finally:
                # Unblock the main thread even if the worker failed before
                # reaching the set() inside its SaveContext.
                entered_context_in_thread.set()

        t = threading.Thread(target=thread_fn)
        t.start()
        try:
            entered_context_in_thread.wait()
            # Another thread shouldn't affect this thread.
            self.assertTrue(save_context.in_save_context())
            self.assertTrue(save_context.get_save_options().save_debug_info)
        finally:
            # Always release and join the worker so a main-thread failure
            # cannot leave it blocked forever on continue_thread.
            continue_thread.set()
            t.join()

        if thread_errors:
            raise thread_errors[0]

        # Another thread exiting SaveContext shouldn't affect this thread.
        self.assertTrue(save_context.in_save_context())
        self.assertTrue(save_context.get_save_options().save_debug_info)

    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
        save_context.get_save_options()
def _test(f, v):
    """Trace `f` under a SaveContext and inspect the resulting graph.

    Verifies that the traced graph:
      - contains no device annotations, and
      - captures at most one tensor, which must be `v._primary.handle`.

    Relies on `self` and `_discard_return` from the enclosing scope.
    """
    traced = def_function.function(lambda: _discard_return(f))
    no_policy = save_options.VariablePolicy.NONE
    with save_context.save_context(
            save_options.SaveOptions(experimental_variable_policy=no_policy)):
        graph = traced.get_concrete_function().graph
        # The graph should contain no device.
        for operation in graph.get_operations():
            self.assertEqual(operation.device, "", msg=str(operation))
        # Only the primary variable may be captured. Some functions (e.g.
        # reading v.aggregation) capture nothing at all.
        captured = list(graph.captures)
        self.assertLessEqual(len(captured), 1)
        if captured:
            self.assertIs(captured[0][0], v._primary.handle)
def test_enter_multiple(self):
    """Entering a SaveContext while one is already active raises."""
    options = save_options.SaveOptions()
    enter = save_context.save_context
    with self.assertRaisesRegex(ValueError, 'already in a SaveContext'):
        # Two managers in one `with` are entered left to right, i.e. nested.
        with enter(options), enter(options):
            pass