def testSharedVariable(self):
    """Check that per-device creators built on one store yield one variable.

    A creator function is made for each of three device indices from the same
    store dict. A variable named "foo" is then created under each creator's
    scope, and the variables obtained for indices 1 and 2 must be the very
    same object as the one obtained for index 0.
    """
    device_count = 3
    store = {}
    creators = [
        shared_variable_creator.make_fn(store, device_index)
        for device_index in range(device_count)
    ]

    # Create "foo" once under each creator, in device-index order.
    created = []
    for creator in creators:
        with variable_scope.variable_creator_scope(creator):
            created.append(variable_scope.variable(1.0, name="foo"))

    # Every later device must get back the object created for device 0.
    canonical = created[0]
    for duplicate in created[1:]:
        self.assertIs(duplicate, canonical)
예제 #2
0
    def testSharedVariable(self):
        """Variables made via creators sharing a store are one object.

        Three creator functions (device indices 0, 1, 2) are built from a
        single store dict; creating "foo" under each of them must return the
        identical variable object every time.
        """
        num_devices = 3
        store = {}
        creators = [
            shared_variable_creator.make_fn(store, device_id)
            for device_id in range(num_devices)
        ]

        def create_foo_under(creator):
            # Create the variable while `creator` is the active creator scope.
            with variable_scope.variable_creator_scope(creator):
                return variable_scope.variable(1.0, name="foo")

        primary, second, third = [create_foo_under(c) for c in creators]

        # The variables for devices 1 and 2 should be the device-0 variable.
        self.assertIs(second, primary)
        self.assertIs(third, primary)
def _call_for_each_replica(distribution, fn, args, kwargs):
    """Run `fn` in separate threads, once per replica/worker device.

    One `_MirroredReplicaThread` is created for each entry of
    `distribution.worker_devices`. The calling thread then drives those
    threads in rounds through their `should_run` / `has_paused` events. After
    a round, either every thread reports `done`, or every thread is paused in
    a `merge_call()`; in the latter case the merge function is invoked once
    from this function and its result is split per device and handed back to
    the threads.

    Args:
      distribution: the DistributionStrategy object.
      fn: function to run (will be run once per device, each in its own
        thread).
      args: positional arguments for `fn`; passed through
        `values.select_device` for each device before being given to the
        thread.
      kwargs: keyword arguments for `fn`, selected per device the same way.

    Returns:
      Merged return value of `fn` across all replicas, or `None` when
      `coord.should_stop()` becomes true while the threads are being driven
      (the `finally` clause still releases and joins the threads first).

    Raises:
      RuntimeError: If fn() calls get_replica_context().merge_call() a
        different number of times from the available devices.
    """
    # TODO(josh11b): Add this option once we add synchronization to variable
    # creation. Until then, this is pretty unsafe to use.
    # While this is hard-coded to False, only the sequential `else` branch of
    # the scheduling loop below is ever taken.
    run_concurrently = False
    if not context.executing_eagerly():
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()

    coord = coordinator.Coordinator(
        clean_stop_exception_types=(_RequestedStop, ))

    # A single dict handed to every replica's variable creator function.
    # NOTE(review): presumably this is what lets later replicas look up
    # variables created by the first one -- confirm in shared_variable_creator.
    shared_variable_store = {}

    # TODO(isaprykin): Create these threads once instead of during every run()
    # call.
    threads = []
    for index, d in enumerate(distribution.worker_devices):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        t = MirroredStrategy._MirroredReplicaThread(  # pylint: disable=protected-access
            distribution, coord, d, variable_creator_fn, fn,
            *values.select_device(d, args), **values.select_device(d, kwargs))
        threads.append(t)

    for t in threads:
        t.start()

    # When `fn` starts `should_run` event is set on _MirroredReplicaThread
    # (`MRT`) threads. The execution waits until
    # `MRT.has_paused` is set, which indicates that either `fn` is
    # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
    # complete, then `MRT.done` is set to True.  Otherwise, arguments
    # of `get_replica_context().merge_call` from all paused threads are grouped
    # and the `merge_fn` is performed.  Results of the
    # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
    # Each such `get_replica_context().merge_call` call returns the
    # `MRT.merge_result` for that thread when `MRT.should_run` event
    # is reset again. Execution of `fn` resumes.

    try:
        with coord.stop_on_exception():
            all_done = False
            while not all_done and not coord.should_stop():
                # One entry per thread for this round: that thread's `done`
                # flag as observed right after it paused.
                done = []
                if run_concurrently:
                    # Release every thread first, then wait for each to pause.
                    for t in threads:
                        t.should_run.set()
                    for t in threads:
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                else:
                    # Release one thread and wait for it to pause before
                    # releasing the next, so only one runs at a time.
                    for t in threads:
                        t.should_run.set()
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                if coord.should_stop():
                    return None
                all_done = all(done)
                if not all_done:
                    # Some threads finished while others are still paused in a
                    # merge_call(): the replicas disagree on the call count.
                    if any(done):
                        raise RuntimeError(
                            "Some replicas made a different number of "
                            "replica_context().merge_call() calls.")
                    # get_replica_context().merge_call() case
                    # Gather each thread's merge arguments, keyed by device,
                    # and regroup them into the arguments for a single call.
                    merge_args = values.regroup(
                        {t.device: t.merge_args
                         for t in threads})
                    merge_kwargs = values.regroup(
                        {t.device: t.merge_kwargs
                         for t in threads})
                    # We capture the name_scope of the MRT when we call merge_fn
                    # to ensure that if we have opened a name scope in the MRT,
                    # it will be respected when executing the merge function. We only
                    # capture the name_scope from the first MRT and assume it is
                    # the same for all other MRTs.
                    mtt_captured_name_scope = threads[0].captured_name_scope
                    with ops.name_scope(mtt_captured_name_scope):
                        # Only the first thread's merge_fn is called; it
                        # receives the regrouped arguments of all threads.
                        merge_result = threads[0].merge_fn(
                            distribution, *merge_args, **merge_kwargs)
                    # Hand each thread the slice of the result for its device.
                    for t in threads:
                        t.merge_result = values.select_device(
                            t.device, merge_result)
    finally:
        # Runs on every exit path (normal completion, early `return None`, or
        # an exception): set `should_run` on all threads, then join them.
        # NOTE(review): presumably the set() is there so that no thread stays
        # blocked waiting on `should_run` -- confirm in _MirroredReplicaThread.
        for t in threads:
            t.should_run.set()
        coord.join(threads)

    # Combine the per-device return values of `fn` into one merged value.
    return values.regroup({t.device: t.main_result for t in threads})
예제 #4
0
def _call_for_each_tower(distribution, fn, *args, **kwargs):
  """Run `fn` in separate threads, once per tower/worker device.

  One `_MirroredTowerThread` is created for each entry of
  `distribution.worker_devices`. The calling thread then drives those threads
  in rounds through their `should_run` / `has_paused` events. After a round,
  either every thread reports `done`, or every thread is paused in a
  `merge_call()`; in the latter case the merge function is invoked once from
  this function and its result is split per device and handed back.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    *args: positional arguments for `fn`
    **kwargs: keyword arguments for `fn`.
        `"run_concurrently"`: Boolean indicating whether executions of `fn`
           can be run concurrently (under eager execution only), defaults to
           `True`. This key is removed from `kwargs` before the remaining
           keyword arguments are passed on.

  Returns:
    Merged return value of `fn` across all towers, or `None` when
    `coord.should_stop()` becomes true while the threads are being driven
    (the `finally` clause still releases and joins the threads first).

  Raises:
    RuntimeError: If fn() calls get_tower_context().merge_call() a different
        number of times from the available devices.
  """
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    # In eager mode an explicit `run_concurrently=None` is treated as True.
    run_concurrently = True

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  # A single dict handed to every tower's variable creator function.
  # NOTE(review): presumably this is what lets later towers look up variables
  # created by the first one -- confirm in shared_variable_creator.
  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
        distribution, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts `should_run` event is set on _MirroredTowerThread
  # (`MTT`) threads. The execution waits until
  # `MTT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_tower_context().merge_call()` is called.  If `fn` is
  # complete, then `MTT.done` is set to True.  Otherwise, arguments
  # of `get_tower_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed.  Results of the
  # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
  # Each such `get_tower_context().merge_call` call returns the
  # `MTT.merge_result` for that thread when `MTT.should_run` event
  # is reset again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        # One entry per thread for this round: that thread's `done` flag as
        # observed right after it paused.
        done = []
        if run_concurrently:
          # Release every thread first, then wait for each of them to pause.
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          # Release one thread and wait for it to pause before releasing the
          # next, so only one tower thread runs at a time.
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          # Some threads finished while others are still paused in a
          # merge_call(): the towers disagree on the number of calls.
          if any(done):
            raise RuntimeError("Some towers made a different number of "
                               "tower_context().merge_call() calls.")
          # get_tower_context().merge_call() case
          # Gather each thread's merge arguments, keyed by device, and regroup
          # them into the arguments for a single merge call.
          merge_args = values.regroup({t.device: t.merge_args for t in threads})
          merge_kwargs = values.regroup(
              {t.device: t.merge_kwargs for t in threads})
          # We capture the name_scope of the MTT when we call merge_fn
          # to ensure that if we have opened a name scope in the MTT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MTT and assume it is
          # the same for all other MTTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          with ops.name_scope(mtt_captured_name_scope):
            # Only the first thread's merge_fn is called; it receives the
            # regrouped arguments of all threads.
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          # Hand each thread the slice of the result for its own device.
          for t in threads:
            t.merge_result = values.select_device(t.device, merge_result)
  finally:
    # Runs on every exit path (normal completion, early `return None`, or an
    # exception): set `should_run` on all threads, then join them.
    # NOTE(review): presumably the set() is there so that no thread stays
    # blocked waiting on `should_run` -- confirm in _MirroredTowerThread.
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  # Combine the per-device return values of `fn` into one merged value.
  return values.regroup({t.device: t.main_result for t in threads})
예제 #5
0
    def _call_for_each_tower(self, fn, *args, **kwargs):
        """Run `fn` in separate threads, once per tower/worker device.

        One `_MirroredTowerThread` is created for each entry of
        `self._devices`. The calling thread then drives those threads in
        rounds through their `should_run` / `has_paused` events. After a
        round, either every thread reports `done`, or every thread is paused
        in a `merge_call()`; in the latter case the merge function is invoked
        once from this method and its result is split per device and handed
        back to the threads.

        Args:
          fn: function to run (will be run once per device, each in its own
            thread).
          *args: positional arguments for `fn`
          **kwargs: keyword arguments for `fn`.
              `"run_concurrently"`: Boolean indicating whether executions of
                 `fn` can be run concurrently (under eager execution only),
                 defaults to `True`. This key is removed from `kwargs` before
                 the remaining keyword arguments are passed on.

        Returns:
          Merged return value of `fn` across all towers, or `None` when
          `coord.should_stop()` becomes true while the threads are being
          driven (the `finally` clause still releases and joins the threads
          first).

        Raises:
          RuntimeError: If fn() calls get_tower_context().merge_call() a
              different number of times on different devices.
        """
        run_concurrently = kwargs.pop("run_concurrently", True)
        if not context.executing_eagerly():
            # Lots of TF library code isn't thread-safe in graph mode, and
            # there is little to be gained by turning on multithreading when
            # constructing a graph.
            run_concurrently = False
            # Needed for per-thread device, etc. contexts in graph mode.
            ops.get_default_graph().switch_to_thread_local()
        elif run_concurrently is None:
            # In eager mode an explicit `run_concurrently=None` means True.
            run_concurrently = True

        coord = coordinator.Coordinator(
            clean_stop_exception_types=(_RequestedStop, ))

        # A single dict handed to every tower's variable creator function.
        # NOTE(review): presumably this is what lets later towers look up
        # variables created by the first one -- confirm in
        # shared_variable_creator.
        shared_variable_store = {}

        # TODO(isaprykin): Create these threads once instead of during every run()
        # call.
        threads = []
        for index, d in enumerate(self._devices):
            variable_creator_fn = shared_variable_creator.make_fn(
                shared_variable_store, index)
            t = MirroredStrategy._MirroredTowerThread(
                self, coord, d, variable_creator_fn, fn,
                *values.select_device(d, args),
                **values.select_device(d, kwargs))
            threads.append(t)

        for t in threads:
            t.start()

        # When `fn` starts `should_run` event is set on _MirroredTowerThread
        # (`MTT`) threads. The execution waits until
        # `MTT.has_paused` is set, which indicates that either `fn` is
        # complete or a `get_tower_context().merge_call()` is called.  If `fn` is
        # complete, then `MTT.done` is set to True.  Otherwise, arguments
        # of `get_tower_context().merge_call` from all paused threads are grouped
        # and the `merge_fn` is performed.  Results of the
        # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
        # Each such `get_tower_context().merge_call` call returns the
        # `MTT.merge_result` for that thread when `MTT.should_run` event
        # is reset again. Execution of `fn` resumes.

        try:
            with coord.stop_on_exception():
                all_done = False
                while not all_done and not coord.should_stop():
                    # One entry per thread for this round: that thread's
                    # `done` flag as observed right after it paused.
                    done = []
                    if run_concurrently:
                        # Release every thread first, then wait for each of
                        # them to pause.
                        for t in threads:
                            t.should_run.set()
                        for t in threads:
                            t.has_paused.wait()
                            t.has_paused.clear()
                            if coord.should_stop():
                                return None
                            done.append(t.done)
                    else:
                        # Release one thread and wait for it to pause before
                        # releasing the next, so only one runs at a time.
                        for t in threads:
                            t.should_run.set()
                            t.has_paused.wait()
                            t.has_paused.clear()
                            if coord.should_stop():
                                return None
                            done.append(t.done)
                    if coord.should_stop():
                        return None
                    all_done = all(done)
                    if not all_done:
                        # Some threads finished while others are still paused
                        # in a merge_call(): the towers disagree on the
                        # number of calls.
                        if any(done):
                            raise RuntimeError(
                                "Some towers made a different number of "
                                "tower_context().merge_call() calls.")
                        # get_tower_context().merge_call() case
                        # Gather each thread's merge arguments, keyed by
                        # device, and regroup them for a single merge call.
                        merge_args = values.regroup(
                            {t.device: t.merge_args
                             for t in threads})
                        merge_kwargs = values.regroup(
                            {t.device: t.merge_kwargs
                             for t in threads})
                        # Only the first thread's merge_fn is called; it
                        # receives the regrouped arguments of all threads.
                        merge_result = threads[0].merge_fn(
                            self, *merge_args, **merge_kwargs)
                        # Hand each thread the slice of the result for its
                        # own device.
                        for t in threads:
                            t.merge_result = values.select_device(
                                t.device, merge_result)
        finally:
            # Runs on every exit path (normal completion, early `return None`,
            # or an exception): set `should_run` on all threads, then join.
            # NOTE(review): presumably the set() is there so that no thread
            # stays blocked waiting on `should_run` -- confirm in
            # _MirroredTowerThread.
            for t in threads:
                t.should_run.set()
            coord.join(threads)

        # Combine the per-device return values of `fn` into one merged value.
        return values.regroup({t.device: t.main_result for t in threads})