def _assert_in_default_state(t):
  """Asserts that the thread is in the default (no-strategy) state.

  Checks that:
    * the current tower context is the default tower context;
    * there is no cross-tower context;
    * the current distribution strategy is the default strategy;
    * `has_distribution_strategy()` reports False.

  Args:
    t: A test-case object providing `assertIs` and `assertFalse`.
  """
  ctx = distribution_strategy_context
  # Tower side: the default tower context is active, with no cross-tower one.
  t.assertIs(ctx._get_default_tower_context(),  # pylint: disable=protected-access
             ctx.get_tower_context())
  t.assertIs(None, ctx.get_cross_tower_context())
  # Strategy side: the default strategy is current, yet it does not count
  # as "having" a distribution strategy.
  t.assertIs(ctx._get_default_distribution_strategy(),  # pylint: disable=protected-access
             ctx.get_distribution_strategy())
  t.assertFalse(ctx.has_distribution_strategy())
def merge_fn(dist, s):
  """Asserts it runs in the default strategy's cross-tower context.

  NOTE(review): `self` is not a parameter; it is presumably captured from an
  enclosing test method — confirm against the surrounding code.

  Args:
    dist: The strategy handed to this function; expected to be the default
      distribution strategy.
    s: A value appended to the "foo_" prefix.

  Returns:
    `"foo_" + s`.
  """
  ctx = distribution_strategy_context
  default_strategy = ctx._get_default_distribution_strategy()  # pylint: disable=protected-access
  self.assertIs(default_strategy, dist)
  # No tower context; `dist` is both the cross-tower context and the strategy.
  self.assertIs(None, ctx.get_tower_context())
  self.assertIs(dist, ctx.get_cross_tower_context())
  self.assertIs(dist, ctx.get_distribution_strategy())
  # The default strategy does not count as having a distribution strategy.
  self.assertFalse(ctx.has_distribution_strategy())
  return "foo_" + s
def _require_distribution_strategy_scope(distribution_strategy):
  """Verify in a `distribution_strategy.scope()` in this thread.

  Args:
    distribution_strategy: The strategy whose scope must be active.

  Raises:
    RuntimeError: If the thread's current strategy is the default strategy
      (no scope entered), or is some other strategy than
      `distribution_strategy`.
  """
  current = _get_per_thread_mode().distribution_strategy
  if current is not distribution_strategy:
    default = distribution_strategy_context._get_default_distribution_strategy()  # pylint: disable=protected-access
    # Being on the default strategy means no scope was entered at all.
    if current is default:
      raise RuntimeError(
          'Need to be inside "with distribution_strategy.scope()" for %s' %
          (distribution_strategy,))
    raise RuntimeError(
        "Mixing different DistributionStrategy objects: %s is not %s" %
        (current, distribution_strategy))
def _require_distribution_strategy_scope(distribution_strategy):
  """Verify in a `distribution_strategy.scope()` in this thread.

  Args:
    distribution_strategy: The strategy whose scope must be active.

  Raises:
    RuntimeError: If no scope is active (the current strategy is the
      default one) or a different strategy's scope is active.
  """
  mode = _get_per_thread_mode()
  active = mode.distribution_strategy
  if active is distribution_strategy:
    return
  # Choose the message: the default strategy means no scope was entered,
  # anything else means a different strategy's scope is active.
  if active is distribution_strategy_context._get_default_distribution_strategy():  # pylint: disable=protected-access
    message = ('Need to be inside "with distribution_strategy.scope()" for %s'
               % (distribution_strategy,))
  else:
    message = ("Mixing different DistributionStrategy objects: %s is not %s"
               % (active, distribution_strategy))
  raise RuntimeError(message)
def _require_cross_tower_context(distribution_strategy):
  """Verify in cross-tower context for `distribution_strategy`.

  Args:
    distribution_strategy: The strategy whose cross-tower context is required.

  Raises:
    RuntimeError: If this thread is not inside `distribution_strategy.scope()`
      (either no scope, or a different strategy's scope), or if it is inside
      the right scope but not in cross-tower context.
  """
  context = _get_per_thread_mode()
  if context.cross_tower_context is distribution_strategy:
    return
  # We have an error to report. Scope problems take precedence; delegate to
  # the shared scope check so the messages stay in one place. It returns
  # normally only when `distribution_strategy` is the active strategy.
  _require_distribution_strategy_scope(distribution_strategy)
  # Right scope, but not cross-tower, so we must be in a tower context.
  assert context.cross_tower_context is None
  raise RuntimeError("Method requires being in cross-tower context, use "
                     "get_tower_context().merge_call()")
def _require_cross_tower_context(distribution_strategy):
  """Verify in cross-tower context for `distribution_strategy`.

  Args:
    distribution_strategy: The strategy whose cross-tower context is required.

  Raises:
    RuntimeError: If this thread is not inside `distribution_strategy.scope()`
      (either no scope, or a different strategy's scope), or if it is inside
      the right scope but not in cross-tower context.
  """
  context = _get_per_thread_mode()
  if context.cross_tower_context is distribution_strategy:
    return
  # We have an error to report. Scope problems take precedence; delegate to
  # the shared scope check instead of duplicating its branches and messages.
  # It returns normally only when `distribution_strategy` is active.
  _require_distribution_strategy_scope(distribution_strategy)
  # Right scope, but not cross-tower, so we must be in a tower context.
  assert context.cross_tower_context is None
  raise RuntimeError("Method requires being in cross-tower context, use "
                     "get_tower_context().merge_call()")