コード例 #1
0
def validate_not_in_strategy_scope():
    """Raise if fit/eval/predict is invoked directly under a strategy scope.

    Only model creation and compilation belong inside a
    ``tf.distribute.Strategy`` scope. This check fires when a distribution
    strategy is active (``has_distribution_strategy()``) and the caller is
    in cross-replica context.

    Raises:
      RuntimeError: if both conditions above hold.
    """
    # Short-circuit `and` keeps the original ordering: the cross-replica
    # query only runs when a strategy is actually present.
    in_strategy_scope = (
        distribution_strategy_context.has_distribution_strategy()
        and distribution_strategy_context.in_cross_replica_context())
    if not in_strategy_scope:
        return
    raise RuntimeError(
        'Fit/Eval/Predict should not be run inside the tf.distribute.Strategy'
        ' scope. Only model creation and compilation should be in '
        'tf.distribute.Strategy scope.')
コード例 #2
0
def _assert_in_default_state(t):
    """Assert that no distribution strategy scope or replica context is active.

    Args:
      t: the test case whose ``assertIs`` / ``assertFalse`` are used.
    """
    ctx = distribution_strategy_context
    # Outside any scope the current replica context is the default one...
    default_replica = ctx._get_default_replica_context()
    t.assertIs(default_replica, ctx.get_replica_context())
    # ...and there is no cross-replica context.
    t.assertIs(None, ctx.get_cross_replica_context())
    t.assertFalse(ctx.in_cross_replica_context())
    # The reported strategy is the default one, which does not count as
    # "having" a distribution strategy.
    default_strategy = ctx._get_default_distribution_strategy()
    t.assertIs(default_strategy, ctx.get_distribution_strategy())
    t.assertFalse(ctx.has_distribution_strategy())
コード例 #3
0
 def merge_fn(dist, s):
     """Check cross-replica state under the default strategy and tag ``s``.

     Asserts that ``dist`` is the default distribution strategy, that there
     is no replica context, that ``dist`` is both the cross-replica context
     and the current strategy, and that the default strategy does not count
     for ``has_distribution_strategy()``. Uses ``self`` from the enclosing
     test scope.

     Returns:
       ``"foo_" + s``.
     """
     ctx = distribution_strategy_context
     self.assertIs(ctx._get_default_distribution_strategy(), dist)
     self.assertIs(None, ctx.get_replica_context())
     self.assertIs(dist, ctx.get_cross_replica_context())
     self.assertTrue(ctx.in_cross_replica_context())
     self.assertIs(dist, ctx.get_distribution_strategy())
     self.assertFalse(ctx.has_distribution_strategy())
     return "foo_" + s
コード例 #4
0
 def run_fn():
     """Check replica-context state under ``dist`` and exercise merge/variables.

     Uses ``self`` and ``dist`` from the enclosing scope. Asserts that:
     - a replica context is active and there is no cross-replica context;
     - ``dist`` is the active (non-default) distribution strategy;
     - ``merge_call(None, test_arg="foo")`` returns ``"foo"``;
     - creating variable ``"bar"`` yields an object equal (via
       ``assertDictEqual``) to ``_get_test_variable`` with AUTO
       synchronization and NONE aggregation.
     """
     replica_context = distribution_strategy_context.get_replica_context()
     # assertIsNotNone reports the offending value on failure; the previous
     # assertTrue(x is not None) only said "False is not true".
     self.assertIsNotNone(replica_context)
     self.assertIs(
         None,
         distribution_strategy_context.get_cross_replica_context())
     self.assertFalse(
         distribution_strategy_context.in_cross_replica_context())
     self.assertTrue(
         distribution_strategy_context.has_distribution_strategy())
     self.assertIs(
         dist,
         distribution_strategy_context.get_distribution_strategy())
     self.assertEqual("foo",
                      replica_context.merge_call(None, test_arg="foo"))
     expected_value = _get_test_variable(
         "bar", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="bar"))
コード例 #5
0
 def testScope(self):
     """Entering ``scope()`` makes the strategy the cross-replica context.

     Checks the default state before and after the scope. Inside the scope it
     checks that:
     - there is no replica context;
     - the strategy is both the cross-replica context and the current
       strategy;
     - a variable created there matches ``_get_test_variable`` with AUTO
       synchronization and NONE aggregation.
     """
     _assert_in_default_state(self)
     strategy = _TestStrategy()
     ctx = distribution_strategy_context
     with strategy.scope():
         self.assertIs(None, ctx.get_replica_context())
         self.assertIs(strategy, ctx.get_cross_replica_context())
         self.assertTrue(ctx.in_cross_replica_context())
         self.assertTrue(ctx.has_distribution_strategy())
         self.assertIs(strategy, ctx.get_distribution_strategy())
         # Build the expectation before creating the variable, as before.
         expected = _get_test_variable(
             "baz", variable_scope.VariableSynchronization.AUTO,
             variable_scope.VariableAggregation.NONE)
         created = variable_scope.variable(1.0, name="baz")
         self.assertDictEqual(expected, created)
     _assert_in_default_state(self)