def testPackedVariable(self, distribution):
    with distribution.scope():
        v0 = variables_lib.Variable(0.)
    # Packed variables are disabled by default.
    self.assertIsNone(v0._packed_var)

    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
        v1 = variables_lib.Variable(0)
        self.assertIsInstance(v1._packed_var,
                              packed.PackedDistributedVariable)

    # Assign a distinct value on every replica except the first.
    devices = v1._devices
    for i in range(1, len(devices)):
        with distribute_lib.ReplicaContext(distribution, i):
            v1.assign(i)

    # Outside a replica context, reads resolve to the first device.
    val = v1._get()
    self.assertIsInstance(val, packed.PackedVarAndDevice)
    self.assertEqual(val.device, devices[0])
    self.assertEqual(self.evaluate(val.read_value()), 0)

    # Inside a replica context, reads resolve to that replica's device.
    for i in range(0, len(devices)):
        with distribute_lib.ReplicaContext(distribution, i):
            val = v1._get()
            self.assertIsInstance(val, packed.PackedVarAndDevice)
            self.assertEqual(val.device, devices[i])
            self.assertEqual(self.evaluate(val.read_value()), i)
Example #2
def _call_for_each_replica(self, fn, args, kwargs):
    with distribute_lib.ReplicaContext(
            self._container_strategy(),
            replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
        # TODO(rchao): Support multi-replica per worker or sync-group.
        return distribute_utils.regroup((fn(*args, **kwargs),))
Example #3
  def schedule(self, fn, args=None, kwargs=None):
    """Schedules `fn` to be dispatched to a worker for execution asynchronously.

    When calling `schedule` with a function `fn`, `fn` will be executed on a
    remote worker at some later time. The process is asynchronous, meaning
    `schedule` returns immediately, possibly without having the result ready
    yet. `schedule` returns a structure of `RemoteValue` objects, which wrap the
    output of the function. Call `fetch()` on a `RemoteValue` to wait for the
    function execution to finish and retrieve its output from the remote worker.

    `schedule` guarantees that `fn` will be executed on a worker at least once;
    it may be executed more than once if its corresponding worker fails in the
    middle of execution. Note that since a worker can fail at any point while
    executing the function, the function may be partially executed; however,
    `Client` guarantees that in those events the function will eventually be
    fully executed, possibly on a different worker that is available.

    If any previously scheduled function raises an error, `schedule` will fail
    by raising any one of those errors, and clear the errors collected so far.
    There are two implications when this happens: 1) the user should call
    `schedule` with `fn` again to re-schedule, and 2) some of the previously
    scheduled functions may not have been executed. The user can call `fetch` on
    the returned `RemoteValue`s to inspect whether they have executed, failed,
    or been cancelled, and reschedule the corresponding functions if needed.

    When `schedule` raises, it guarantees that there is no function that is
    still being executed.

    At this time, there is no support for assigning a function to a particular
    worker, or for prioritizing certain workers.

    `args` and `kwargs` are the arguments passed into `fn` when `fn` is executed
    on a worker. They can be `PerWorkerValues`, a collection of values each of
    which represents a component specific to a worker; in this case, the
    argument will be substituted with the corresponding component on the target
    worker. Arguments that are not `PerWorkerValues` are passed into `fn` as-is.

    Args:
      fn: A `tf.function`; the function to be dispatched to a worker for
        execution asynchronously.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A structure of `RemoteValue` objects.

    Raises:
      Exception: one of the exceptions caught by the client from any previously
        scheduled function, since the last time an error was thrown or since
        the beginning of the program.
    """
    # TODO(b/160702436): Invoke `strategy.run` for user's function so it enters
    # a `ReplicaContext` in a logically correct way.
    with distribute_lib.ReplicaContext(
        self._strategy,
        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
      with self._translate_parameter_server_failure():
        return self.cluster.schedule(fn, args=args, kwargs=kwargs)
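
A minimal usage sketch of the contract described in the docstring above, assuming `client` is an already-constructed `Client` connected to a running parameter-server cluster; the scheduled function is a stand-in for a real training step, and construction of `PerWorkerValues` arguments is elided.

import tensorflow as tf

@tf.function
def train_step():
    # Stand-in computation for a real per-worker training step.
    return tf.constant(1.) + tf.constant(2.)

# `schedule` returns immediately with a `RemoteValue` wrapping the output;
# `PerWorkerValues` arguments (not shown) would be resolved per worker.
remote_value = client.schedule(train_step)

# `fetch` blocks until some worker has finished executing `train_step`
# (possibly after retries on worker failure) and returns its output.
result = remote_value.fetch()  # => 3.0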
Example #4
def _call_for_each_replica(self, fn, args, kwargs):
    with distribute_lib.ReplicaContext(
            self._container_strategy(), replica_id_in_sync_group=0), \
            ops.device(self._ipu_device):
        # Make sure it is compiled as a single engine when called in graph mode.
        # This is similar to the mechanism used by xla.compile.
        xla_context = control_flow_ops.XLAControlFlowContext()
        try:
            xla_context.Enter()
            _validate_function_for_arguments(fn, args, kwargs)
            return fn(*args, **kwargs)
        finally:
            xla_context.Exit()
Example #5
def _call_for_each_replica(self, fn, args, kwargs):
    with distribute_lib.ReplicaContext(
            self._container_strategy(), replica_id_in_sync_group=0), \
            ops.device(self._ipu_device):
        return fn(*args, **kwargs)