def testSetStrategyInScope(self):
  """Check that `experimental_set_strategy` refuses to run inside a scope.

  While `dist.scope()` is active, each of these calls must raise
  `RuntimeError`:
    * setting a different (freshly built) strategy,
    * re-setting the strategy whose scope is active,
    * clearing the strategy with `None`.
  The test also asserts the default state before and after the scope.
  """
  _assert_in_default_state(self)
  # Regex matched against the RuntimeError message.
  expected_message = (
      "Must not be called inside a `tf.distribute.Strategy` scope")
  dist = _TestStrategy()
  with dist.scope():
    # Candidates are constructed inside the scope, as in the original
    # sequence of checks.
    candidates = (("new strategy", _TestStrategy()),
                  ("active strategy", dist),
                  ("None", None))
    for label, candidate in candidates:
      # subTest labels a failure with the candidate that did not raise.
      with self.subTest(candidate=label):
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
        # (removed in Python 3.12).
        with self.assertRaisesRegex(RuntimeError, expected_message):
          ds_context.experimental_set_strategy(candidate)
  _assert_in_default_state(self)
Example 2
0
 def testSetStrategyInScope(self):
   """Check that `experimental_set_strategy` refuses to run inside a scope.

   While `dist.scope()` is active, each of these calls must raise
   `RuntimeError`:
     * setting a different (freshly built) strategy,
     * re-setting the strategy whose scope is active,
     * clearing the strategy with `None`.
   The test also asserts the default state before and after the scope.
   """
   _assert_in_default_state(self)
   # Regex matched against the RuntimeError message.
   expected_message = (
       "Must not be called inside a `tf.distribute.Strategy` scope")
   dist = _TestStrategy()
   with dist.scope():
     # Candidates are constructed inside the scope, as in the original
     # sequence of checks.
     candidates = (("new strategy", _TestStrategy()),
                   ("active strategy", dist),
                   ("None", None))
     for label, candidate in candidates:
       # subTest labels a failure with the candidate that did not raise.
       with self.subTest(candidate=label):
         # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
         # (removed in Python 3.12).
         with self.assertRaisesRegex(RuntimeError, expected_message):
           ds_context.experimental_set_strategy(candidate)
   _assert_in_default_state(self)
Example 3
0
 def testSetStrategy(self):
   """Check setting, replacing, and clearing the global strategy.

   Sets `dist` globally and verifies the resulting cross-replica context and
   variable creation. It then swaps in `dist2`, clears the strategy with
   `None`, and asserts the default state is back.
   """
   _assert_in_default_state(self)
   dist = _TestStrategy()
   dist2 = _TestStrategy()
   ds_context.experimental_set_strategy(dist)
   try:
     # With a global strategy set there is no replica context; the strategy
     # is reported as the cross-replica context.
     self.assertIsNone(ds_context.get_replica_context())
     self.assertIs(dist, ds_context.get_cross_replica_context())
     self.assertTrue(ds_context.in_cross_replica_context())
     self.assertTrue(ds_context.has_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     # Under `dist`, variable_scope.variable is expected to return the same
     # dict that _get_test_variable builds for these arguments.
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
     # Replacing one global strategy with another is allowed outside a scope.
     ds_context.experimental_set_strategy(dist2)
     self.assertIs(dist2, ds_context.get_strategy())
   finally:
     # Always clear the global strategy, even when an assertion above fails,
     # so it does not leak into later tests.
     ds_context.experimental_set_strategy(None)
   _assert_in_default_state(self)
 def testSetStrategy(self):
   """Check setting, replacing, and clearing the global strategy.

   Sets `dist` globally and verifies the resulting cross-replica context and
   variable creation. It then swaps in `dist2`, clears the strategy with
   `None`, and asserts the default state is back.
   """
   _assert_in_default_state(self)
   dist = _TestStrategy()
   dist2 = _TestStrategy()
   ds_context.experimental_set_strategy(dist)
   try:
     # With a global strategy set there is no replica context; the strategy
     # is reported as the cross-replica context.
     self.assertIsNone(ds_context.get_replica_context())
     self.assertIs(dist, ds_context.get_cross_replica_context())
     self.assertTrue(ds_context.in_cross_replica_context())
     self.assertTrue(ds_context.has_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     # Under `dist`, variable_scope.variable is expected to return the same
     # dict that _get_test_variable builds for these arguments.
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
     # Replacing one global strategy with another is allowed outside a scope.
     ds_context.experimental_set_strategy(dist2)
     self.assertIs(dist2, ds_context.get_strategy())
   finally:
     # Always clear the global strategy, even when an assertion above fails,
     # so it does not leak into later tests.
     ds_context.experimental_set_strategy(None)
   _assert_in_default_state(self)