コード例 #1
0
def _assert_in_default_state(t):
  """Asserts that the current thread sees only the default contexts.

  Args:
    t: The test case used to report failures. It must provide `assertIs`
      and `assertFalse`; presumably a `tf.test.TestCase`, but verify
      against callers.
  """
  ctx = distribution_strategy_context
  # The tower context in effect must be the default tower context, and
  # there must be no cross-tower context.
  default_tower = ctx._get_default_tower_context()
  t.assertIs(default_tower, ctx.get_tower_context())
  t.assertIs(None, ctx.get_cross_tower_context())
  # The strategy returned must be the default one, and the default does not
  # count as "having" a distribution strategy.
  default_strategy = ctx._get_default_distribution_strategy()
  t.assertIs(default_strategy, ctx.get_distribution_strategy())
  t.assertFalse(ctx.has_distribution_strategy())
コード例 #2
0
 def merge_fn(dist, s):
   """Checks the cross-tower state under the default strategy, then tags `s`.

   Args:
     dist: The strategy handed to this function. It is asserted to be the
       default distribution strategy. Presumably supplied by a `merge_call`
       in the enclosing test; that call is not visible here.
     s: The string to prefix.

   Returns:
     `"foo_" + s`.
   """
   ctx = distribution_strategy_context
   expected_strategy = ctx._get_default_distribution_strategy()
   self.assertIs(expected_strategy, dist)
   # Expected state: no tower context, and `dist` is both the cross-tower
   # context and the current strategy.
   self.assertIs(None, ctx.get_tower_context())
   self.assertIs(dist, ctx.get_cross_tower_context())
   self.assertIs(dist, ctx.get_distribution_strategy())
   # Even with `dist` active, has_distribution_strategy() is expected to be
   # False here because `dist` is the default strategy.
   self.assertFalse(ctx.has_distribution_strategy())
   tagged = "foo_" + s
   return tagged
コード例 #3
0
def _require_distribution_strategy_scope(distribution_strategy):
  """Raises unless this thread is inside `distribution_strategy.scope()`.

  Args:
    distribution_strategy: The strategy whose scope must be active in the
      current thread.

  Raises:
    RuntimeError: If the thread's active strategy is the default strategy
      (no scope has been entered), or if it is a different strategy object.
  """
  active = _get_per_thread_mode().distribution_strategy
  if active is not distribution_strategy:
    # Choose the message: "no scope at all" vs. "someone else's scope".
    default = distribution_strategy_context._get_default_distribution_strategy()  # pylint: disable=protected-access
    if active is default:
      raise RuntimeError(
          'Need to be inside "with distribution_strategy.scope()" for %s' %
          (distribution_strategy,))
    raise RuntimeError(
        "Mixing different DistributionStrategy objects: %s is not %s" %
        (active, distribution_strategy))
コード例 #4
0
ファイル: distribute.py プロジェクト: AnishShah/tensorflow
def _require_distribution_strategy_scope(distribution_strategy):
  """Raises unless this thread is inside `distribution_strategy.scope()`.

  Args:
    distribution_strategy: The strategy whose scope must be active in the
      current thread.

  Raises:
    RuntimeError: If the thread's active strategy is the default strategy
      (no scope has been entered), or if it is a different strategy object.
  """
  active = _get_per_thread_mode().distribution_strategy
  if active is not distribution_strategy:
    # Choose the message: "no scope at all" vs. "someone else's scope".
    default = distribution_strategy_context._get_default_distribution_strategy()  # pylint: disable=protected-access
    if active is default:
      raise RuntimeError(
          'Need to be inside "with distribution_strategy.scope()" for %s' %
          (distribution_strategy,))
    raise RuntimeError(
        "Mixing different DistributionStrategy objects: %s is not %s" %
        (active, distribution_strategy))
コード例 #5
0
def _require_cross_tower_context(distribution_strategy):
  """Raises unless this thread is in cross-tower context for the strategy.

  Args:
    distribution_strategy: The strategy whose cross-tower context is
      required.

  Raises:
    RuntimeError: If the thread is not inside `distribution_strategy.scope()`
      (the default strategy is active), if a different strategy is active, or
      if the right strategy is active but the thread is not in its
      cross-tower context.
  """
  mode = _get_per_thread_mode()
  if mode.cross_tower_context is distribution_strategy:
    return

  active = mode.distribution_strategy
  if active is not distribution_strategy:
    # Wrong (or no) strategy scope: pick the matching message.
    default = distribution_strategy_context._get_default_distribution_strategy()  # pylint: disable=protected-access
    if active is default:
      raise RuntimeError(
          'Need to be inside "with distribution_strategy.scope()" for %s' %
          (distribution_strategy,))
    raise RuntimeError(
        "Mixing different DistributionStrategy objects: %s is not %s" %
        (active, distribution_strategy))

  # Right strategy scope but not cross-tower; the invariant is that no
  # cross-tower context is set at this point.
  assert mode.cross_tower_context is None
  raise RuntimeError("Method requires being in cross-tower context, use "
                     "get_tower_context().merge_call()")
コード例 #6
0
ファイル: distribute.py プロジェクト: AnishShah/tensorflow
def _require_cross_tower_context(distribution_strategy):
  """Raises unless this thread is in cross-tower context for the strategy.

  Args:
    distribution_strategy: The strategy whose cross-tower context is
      required.

  Raises:
    RuntimeError: If the thread is not inside `distribution_strategy.scope()`
      (the default strategy is active), if a different strategy is active, or
      if the right strategy is active but the thread is not in its
      cross-tower context.
  """
  mode = _get_per_thread_mode()
  if mode.cross_tower_context is distribution_strategy:
    return

  active = mode.distribution_strategy
  if active is not distribution_strategy:
    # Wrong (or no) strategy scope: pick the matching message.
    default = distribution_strategy_context._get_default_distribution_strategy()  # pylint: disable=protected-access
    if active is default:
      raise RuntimeError(
          'Need to be inside "with distribution_strategy.scope()" for %s' %
          (distribution_strategy,))
    raise RuntimeError(
        "Mixing different DistributionStrategy objects: %s is not %s" %
        (active, distribution_strategy))

  # Right strategy scope but not cross-tower; the invariant is that no
  # cross-tower context is set at this point.
  assert mode.cross_tower_context is None
  raise RuntimeError("Method requires being in cross-tower context, use "
                     "get_tower_context().merge_call()")