def validate_not_in_strategy_scope():
  """Ensures fit/eval/predict are not invoked inside a strategy's scope.

  The check only fires when a distribution strategy is present and the
  current context is that strategy's cross-replica context.

  Raises:
    RuntimeError: if a distribution strategy is active and execution is in
      its cross-replica context.
  """
  ctx = distribution_strategy_context
  # `and` short-circuits, so in_cross_replica_context() is only queried when
  # a strategy exists -- the same call order as a nested `if`.
  inside_strategy_scope = (ctx.has_distribution_strategy() and
                           ctx.in_cross_replica_context())
  if not inside_strategy_scope:
    return
  raise RuntimeError(
      'Fit/Eval/Predict should not be run inside the tf.distribute.Strategy'
      ' scope. Only model creation and compilation should be in '
      'tf.distribute.Strategy scope.')
def _assert_in_default_state(t):
  """Asserts that no non-default distribution strategy is in effect.

  Checks, in order: the current replica context is the default replica
  context; there is no cross-replica context; execution is not in a
  cross-replica context; the current strategy is the default strategy; and
  `has_distribution_strategy()` reports False.

  Args:
    t: test-case object providing `assertIs` and `assertFalse`.
  """
  ctx = distribution_strategy_context
  default_replica_ctx = ctx._get_default_replica_context()
  t.assertIs(default_replica_ctx, ctx.get_replica_context())
  t.assertIs(None, ctx.get_cross_replica_context())
  t.assertFalse(ctx.in_cross_replica_context())
  default_strategy = ctx._get_default_distribution_strategy()
  t.assertIs(default_strategy, ctx.get_distribution_strategy())
  t.assertFalse(ctx.has_distribution_strategy())
def merge_fn(dist, s):
  """Merge function that checks it runs in the default strategy's
  cross-replica context.

  Args:
    dist: strategy argument; asserted to be the default distribution
      strategy.
    s: string suffix used to build the return value.

  Returns:
    The string `"foo_" + s`.
  """
  self.assertIs(
      distribution_strategy_context._get_default_distribution_strategy(),
      dist)
  # Inside the merge there is no replica context, only the cross-replica one.
  self.assertIsNone(distribution_strategy_context.get_replica_context())
  self.assertIs(
      dist, distribution_strategy_context.get_cross_replica_context())
  self.assertTrue(distribution_strategy_context.in_cross_replica_context())
  self.assertIs(
      dist, distribution_strategy_context.get_distribution_strategy())
  # The default strategy is current, yet it does not count as "having" a
  # distribution strategy.
  self.assertFalse(distribution_strategy_context.has_distribution_strategy())
  return "foo_" + s
def run_fn():
  """Checks the replica-context state seen by a function run under `dist`.

  Verifies that a replica context exists and that there is no cross-replica
  context. Verifies that `dist` is the active strategy. Checks that
  `merge_call` returns the expected value. Checks that a variable created here
  matches `_get_test_variable("bar", AUTO, NONE)`.
  """
  replica_context = distribution_strategy_context.get_replica_context()
  # assertIsNotNone reports the offending value on failure, unlike
  # assertTrue(x is not None), which only reports "False is not true".
  self.assertIsNotNone(replica_context)
  self.assertIsNone(
      distribution_strategy_context.get_cross_replica_context())
  self.assertFalse(distribution_strategy_context.in_cross_replica_context())
  self.assertTrue(distribution_strategy_context.has_distribution_strategy())
  self.assertIs(
      dist, distribution_strategy_context.get_distribution_strategy())
  # NOTE(review): relies on the strategy under test returning `test_arg`
  # from merge_call(None, ...) -- behavior defined outside this view.
  self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
  expected_value = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  self.assertDictEqual(expected_value,
                       variable_scope.variable(1.0, name="bar"))
def testScope(self):
  """Entering `dist.scope()` switches to its cross-replica context.

  Also checks that a variable created inside the scope matches
  `_get_test_variable("baz", AUTO, NONE)`, and that the default state is
  restored after the scope exits.
  """
  _assert_in_default_state(self)
  dist = _TestStrategy()
  with dist.scope():
    # Inside the scope there is no replica context; `dist` is the
    # cross-replica context and the active strategy.
    self.assertIsNone(distribution_strategy_context.get_replica_context())
    self.assertIs(
        dist, distribution_strategy_context.get_cross_replica_context())
    self.assertTrue(distribution_strategy_context.in_cross_replica_context())
    self.assertTrue(
        distribution_strategy_context.has_distribution_strategy())
    self.assertIs(
        dist, distribution_strategy_context.get_distribution_strategy())
    expected_value = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    self.assertDictEqual(expected_value,
                         variable_scope.variable(1.0, name="baz"))
  # Leaving the scope must restore the default (no-strategy) state.
  _assert_in_default_state(self)